├── assets └── mesa-header-nz.png ├── requirements.txt ├── LICENSE ├── README.md ├── pipeline_terrain.py └── models.py /assets/mesa-header-nz.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PaulBorneP/MESA/HEAD/assets/mesa-header-nz.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | diffusers==0.32.2 2 | transformers==4.51.1 3 | accelerate==1.5.2 4 | torch --index-url https://download.pytorch.org/whl/cu124 5 | 6 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | # ADOBE RESEARCH LICENSE 2 | 3 | This license agreement (the “License”) between Adobe Inc., having a place of business at 345 Park Avenue, San Jose, California 95110-2704 (“Adobe”), and you, the individual or entity exercising rights under this License (“you” or “your”), sets forth the terms for your use of certain research materials that are owned by Adobe (the “Licensed Materials”). By exercising rights under this License, you accept and agree to be bound by its terms. If you are exercising rights under this license on behalf of an entity, then “you” means you and such entity, and you (personally) represent and warrant that you (personally) have all necessary authority to bind that entity to the terms of this License. 4 | 5 | 1. **GRANT OF LICENSE.** 6 | 7 | 1.1. Adobe grants you a nonexclusive, worldwide, royalty-free, revocable, fully paid license to (A) reproduce, use, modify, and publicly display the Licensed Materials for noncommercial research purposes only; and (B) redistribute the Licensed Materials, and modifications or derivative works thereof, for noncommercial research purposes only, provided that you give recipients a copy of this License. 8 | 9 | 1.2. You may add your own copyright statement to your modifications and may provide additional or different license terms for use, reproduction, modification, public display, and redistribution of your modifications and derivative works, provided that such license terms limit the use, reproduction, modification, public display, and redistribution of such modifications and derivative works to noncommercial research purposes only. 10 | 11 | 1.3. For purposes of this License, noncommercial research purposes include academic research and teaching but do not include commercial licensing or distribution, development of commercial products, or any other activity which results in commercial gain. 12 | 13 | 14 | 2. **OWNERSHIP AND ATTRIBUTION.** Adobe and its licensors own all right, title, and interest in the Licensed Materials. You must keep intact any copyright or other notices or disclaimers in the Licensed Materials. 15 | 16 | 3. **DISCLAIMER OF WARRANTIES.** THE LICENSED MATERIALS ARE PROVIDED “AS IS” WITHOUT WARRANTY OF ANY KIND. THE ENTIRE RISK AS TO THE RESULTS AND PERFORMANCE OF THE LICENSED MATERIALS IS ASSUMED BY YOU. ADOBE DISCLAIMS ALL WARRANTIES, EXPRESS, IMPLIED OR STATUTORY, WITH REGARD TO ANY LICENSED MATERIALS PROVIDED UNDER THIS LICENSE, INCLUDING, BUT NOT LIMITED TO, ANY IMPLIED WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, AND NONINFRINGEMENT OF THIRD-PARTY RIGHTS. 17 | 18 | 4. 
**LIMITATION OF LIABILITY.** IN NO EVENT WILL ADOBE BE LIABLE FOR ANY ACTUAL, INCIDENTAL, SPECIAL OR CONSEQUENTIAL DAMAGES OF ANY NATURE WHATSOEVER, INCLUDING WITHOUT LIMITATION, LOSS OF PROFITS OR OTHER COMMERCIAL LOSS, ARISING OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF ANY LICENSED MATERIALS PROVIDED UNDER THIS LICENSE, EVEN IF ADOBE HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGES. 19 | 20 | 5. **TERM AND TERMINATION.** 21 | 22 | 5.1. The License is effective upon acceptance by you and will remain in effect unless terminated earlier as permitted under this License. 23 | 24 | 5.2. If you breach any material provision of this License, then your rights will terminate immediately. 25 | 26 | 5.3. All clauses which by their nature should survive the termination of this License will survive such termination. In addition, and without limiting the generality of the preceding sentence, Sections 2 (Ownership and Attribution), 3 (Disclaimer of Warranties), and 4 (Limitation of Liability) will survive termination of this License. 27 | 28 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | 3 |

MESA: Text-Driven Terrain Generation Using Latent Diffusion and Global Copernicus Data

4 |

Paul Borne--Pons, Mikolaj Czerkawski, Rosalie Martin, 5 | Romain Rouffet

6 | 7 |

CVPR 2025 Workshop MORSE

8 | 9 | [![HF](https://img.shields.io/badge/%F0%9F%8F%94%EF%B8%8FProject%20Page-679c39)](https://paulbornep.github.io/mesa-terrain/) 10 | [![paper](https://img.shields.io/badge/arXiv-2504.07210-D12424)](https://arxiv.org/abs/2504.07210) 11 | [![HF](https://img.shields.io/badge/%F0%9F%A4%97-Models-yellow)](https://www.huggingface.co/NewtNewt/MESA) 12 | [![HF](https://img.shields.io/badge/%F0%9F%A4%97-Datasets-yellow)](https://www.huggingface.co/Major-TOM) 13 | [![HF](https://img.shields.io/badge/%F0%9F%A4%97-Spaces_Demo-yellow)](https://huggingface.co/spaces/mikonvergence/MESA) 14 | Open In Colab 15 | 16 | 17 | MESA is a novel generative model based on latent denoising diffusion, capable of generating 2.5D representations of terrain conditioned on natural-language text prompts. The model produces two co-registered modalities: an optical image and a depth map. 18 | 19 |

20 | 21 | ## Abstract 22 | 23 | Terrain modeling has traditionally relied on procedural techniques, which often require extensive domain expertise and handcrafted rules. In this paper, we present MESA - a novel data-centric alternative by training a diffusion model on global remote sensing data. This approach leverages large-scale geospatial information to generate high-quality terrain samples from text descriptions, showcasing a flexible and scalable solution for terrain generation. The model’s capabilities are demonstrated through extensive experiments, highlighting its ability to generate realistic and diverse terrain landscapes. The dataset produced to support this work, the Major TOM Core-DEM extension dataset, is released openly as a comprehensive resource for global terrain data. The results suggest that data-driven models, trained on remote sensing data, can provide a powerful tool for realistic terrain modeling and generation. 24 | 25 | ## Model Weights 26 | 27 | You will need to manually acquire the weights by downloading the model from Hugging Face: 28 | 29 | ```bash 30 | mkdir weights 31 | huggingface-cli download NewtNewt/MESA --local-dir weights 32 | ``` 33 | 34 | ## Installation 35 | 36 | ```bash 37 | # using python 3.11.12 38 | pip install -r requirements.txt 39 | ``` 40 | 41 | Note that this environment is only compatible with NVIDIA GPUs. Additionally, we recommend using a GPU with a minimum of 8GB of memory. 42 | 43 | ## Inference 44 | 45 | ```python 46 | from MESA.pipeline_terrain import TerrainDiffusionPipeline 47 | import torch 48 | 49 | pipe = TerrainDiffusionPipeline.from_pretrained("./weights", torch_dtype=torch.float16) 50 | pipe.to("cuda") 51 | 52 | prompt = "A sentinel-2 image of montane forests and mountains in Mexico in August" 53 | image, dem = pipe(prompt, num_inference_steps=50, guidance_scale=7.5) 54 | ``` 55 | 56 | Straightforward inference code is also provided in the Open In Colab notebook, and a short example of saving the generated image and DEM is given at the end of this README. 57 | 58 | Alternatively, you can download and use the Gradio demo from the HF page. 59 | 60 | ## Citation 61 | 62 | ```bibtex 63 | @inproceedings{mesa2025, 64 | title={MESA: Text-Driven Terrain Generation Using Latent Diffusion and Global Copernicus Data}, 65 | author={Paul Borne--Pons and Mikolaj Czerkawski and Rosalie Martin and Romain Rouffet}, 66 | year={2025}, 67 | booktitle={MORSE Workshop at CVPR 2025}, 68 | eprint={2504.07210}, 69 | url={https://arxiv.org/abs/2504.07210}} 70 | ``` 71 | ## Acknowledgements 72 | 73 | This implementation builds upon Hugging Face’s [Diffusers](https://github.com/huggingface/diffusers) library. We also acknowledge [Gradio](https://www.gradio.app/) for providing an easy-to-use interface that allowed us to create the inference demos for our models. 74 | 75 | This model is the product of a collaboration between [Φ-lab, European Space Agency (ESA)](https://philab.esa.int/) and [Adobe Research (Paris, France)](https://research.adobe.com/careers/paris/).
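The pipeline returns the optical image and the DEM as a pair of arrays (NumPy by default, shaped `(batch, height, width, 3)` with values in `[0, 1]`). The sketch below assumes the default `output_type="np"` and a single prompt; the file names and the channel-averaging used to collapse the decoded DEM to a single-channel height map are illustrative choices, not part of the released code.

```python
import numpy as np
from PIL import Image

# `image` and `dem` come from the Inference snippet above:
# image, dem = pipe(prompt, num_inference_steps=50, guidance_scale=7.5)

# Save the optical image (float array in [0, 1] -> 8-bit RGB).
rgb = (image[0] * 255).round().astype(np.uint8)
Image.fromarray(rgb).save("terrain_rgb.png")

# The DEM is decoded by the same VAE and postprocessed like an RGB image;
# averaging its three channels is one simple way to obtain a single-channel
# height map for visualization (the exact DEM scaling is described in the paper).
height = (dem[0].mean(axis=-1) * 255).round().astype(np.uint8)
Image.fromarray(height).save("terrain_dem.png")
```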
76 | -------------------------------------------------------------------------------- /pipeline_terrain.py: -------------------------------------------------------------------------------- 1 | ########################################################################### 2 | # References: 3 | # https://github.com/huggingface/diffusers/ 4 | ########################################################################### 5 | import inspect 6 | from typing import Any, Callable, Dict, List, Optional, Union 7 | 8 | import torch 9 | from packaging import version 10 | from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection 11 | 12 | from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback 13 | from diffusers.configuration_utils import FrozenDict 14 | from diffusers.image_processor import PipelineImageInput, VaeImageProcessor 15 | from diffusers.loaders import FromSingleFileMixin, IPAdapterMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin 16 | from diffusers.models import AutoencoderKL, ImageProjection, UNet2DConditionModel 17 | from diffusers.models.lora import adjust_lora_scale_text_encoder 18 | from diffusers.schedulers import KarrasDiffusionSchedulers 19 | from diffusers.utils import ( 20 | USE_PEFT_BACKEND, 21 | deprecate, 22 | is_torch_xla_available, 23 | logging, 24 | replace_example_docstring, 25 | scale_lora_layers, 26 | unscale_lora_layers, 27 | ) 28 | from diffusers.utils.torch_utils import randn_tensor 29 | from diffusers.pipelines.pipeline_utils import DiffusionPipeline, StableDiffusionMixin 30 | from diffusers.pipelines.stable_diffusion.pipeline_output import StableDiffusionPipelineOutput 31 | from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker 32 | 33 | 34 | if is_torch_xla_available(): 35 | import torch_xla.core.xla_model as xm 36 | 37 | XLA_AVAILABLE = True 38 | else: 39 | XLA_AVAILABLE = False 40 | 41 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name 42 | 43 | EXAMPLE_DOC_STRING = """ 44 | Examples: 45 | ```py 46 | >>> import torch 47 | >>> from diffusers import StableDiffusionPipeline 48 | 49 | >>> pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16) 50 | >>> pipe = pipe.to("cuda") 51 | 52 | >>> prompt = "a photo of an astronaut riding a horse on mars" 53 | >>> image = pipe(prompt).images[0] 54 | ``` 55 | """ 56 | 57 | 58 | def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): 59 | """ 60 | Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and 61 | Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). 
See Section 3.4 62 | """ 63 | std_text = noise_pred_text.std( 64 | dim=list(range(1, noise_pred_text.ndim)), keepdim=True) 65 | std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True) 66 | # rescale the results from guidance (fixes overexposure) 67 | noise_pred_rescaled = noise_cfg * (std_text / std_cfg) 68 | # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images 69 | noise_cfg = guidance_rescale * noise_pred_rescaled + \ 70 | (1 - guidance_rescale) * noise_cfg 71 | return noise_cfg 72 | 73 | 74 | def retrieve_timesteps( 75 | scheduler, 76 | num_inference_steps: Optional[int] = None, 77 | device: Optional[Union[str, torch.device]] = None, 78 | timesteps: Optional[List[int]] = None, 79 | sigmas: Optional[List[float]] = None, 80 | **kwargs, 81 | ): 82 | """ 83 | Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles 84 | custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. 85 | 86 | Args: 87 | scheduler (`SchedulerMixin`): 88 | The scheduler to get timesteps from. 89 | num_inference_steps (`int`): 90 | The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` 91 | must be `None`. 92 | device (`str` or `torch.device`, *optional*): 93 | The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. 94 | timesteps (`List[int]`, *optional*): 95 | Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, 96 | `num_inference_steps` and `sigmas` must be `None`. 97 | sigmas (`List[float]`, *optional*): 98 | Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, 99 | `num_inference_steps` and `timesteps` must be `None`. 100 | 101 | Returns: 102 | `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the 103 | second element is the number of inference steps. 104 | """ 105 | if timesteps is not None and sigmas is not None: 106 | raise ValueError( 107 | "Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") 108 | if timesteps is not None: 109 | accepts_timesteps = "timesteps" in set( 110 | inspect.signature(scheduler.set_timesteps).parameters.keys()) 111 | if not accepts_timesteps: 112 | raise ValueError( 113 | f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" 114 | f" timestep schedules. Please check whether you are using the correct scheduler." 115 | ) 116 | scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) 117 | timesteps = scheduler.timesteps 118 | num_inference_steps = len(timesteps) 119 | elif sigmas is not None: 120 | accept_sigmas = "sigmas" in set(inspect.signature( 121 | scheduler.set_timesteps).parameters.keys()) 122 | if not accept_sigmas: 123 | raise ValueError( 124 | f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" 125 | f" sigmas schedules. Please check whether you are using the correct scheduler." 
126 | ) 127 | scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) 128 | timesteps = scheduler.timesteps 129 | num_inference_steps = len(timesteps) 130 | else: 131 | scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) 132 | timesteps = scheduler.timesteps 133 | return timesteps, num_inference_steps 134 | 135 | 136 | class TerrainDiffusionPipeline( 137 | DiffusionPipeline, 138 | StableDiffusionMixin, 139 | TextualInversionLoaderMixin, 140 | StableDiffusionLoraLoaderMixin, 141 | IPAdapterMixin, 142 | FromSingleFileMixin, 143 | ): 144 | r""" 145 | Pipeline for text-to-image generation using Stable Diffusion. 146 | 147 | This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods 148 | implemented for all pipelines (downloading, saving, running on a particular device, etc.). 149 | 150 | The pipeline also inherits the following loading methods: 151 | - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings 152 | - [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for loading LoRA weights 153 | - [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for saving LoRA weights 154 | - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files 155 | - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters 156 | 157 | Args: 158 | vae ([`AutoencoderKL`]): 159 | Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations. 160 | text_encoder ([`~transformers.CLIPTextModel`]): 161 | Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)). 162 | tokenizer ([`~transformers.CLIPTokenizer`]): 163 | A `CLIPTokenizer` to tokenize text. 164 | unet ([`UNet2DConditionModel`]): 165 | A `UNet2DConditionModel` to denoise the encoded image latents. 166 | scheduler ([`SchedulerMixin`]): 167 | A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of 168 | [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. 169 | safety_checker ([`StableDiffusionSafetyChecker`]): 170 | Classification module that estimates whether generated images could be considered offensive or harmful. 171 | Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details 172 | about a model's potential harms. 173 | feature_extractor ([`~transformers.CLIPImageProcessor`]): 174 | A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`. 
175 | """ 176 | 177 | model_cpu_offload_seq = "text_encoder->image_encoder->unet->vae" 178 | _optional_components = ["safety_checker", 179 | "feature_extractor", "image_encoder"] 180 | _exclude_from_cpu_offload = ["safety_checker"] 181 | _callback_tensor_inputs = ["latents", 182 | "prompt_embeds", "negative_prompt_embeds"] 183 | 184 | def __init__( 185 | self, 186 | vae: AutoencoderKL, 187 | text_encoder: CLIPTextModel, 188 | tokenizer: CLIPTokenizer, 189 | unet: UNet2DConditionModel, 190 | scheduler: KarrasDiffusionSchedulers, 191 | safety_checker: StableDiffusionSafetyChecker, 192 | feature_extractor: CLIPImageProcessor, 193 | image_encoder: CLIPVisionModelWithProjection = None, 194 | requires_safety_checker: bool = True, 195 | ): 196 | super().__init__() 197 | 198 | if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1: 199 | deprecation_message = ( 200 | f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`" 201 | f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure " 202 | "to update the config accordingly as leaving `steps_offset` might led to incorrect results" 203 | " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub," 204 | " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`" 205 | " file" 206 | ) 207 | deprecate("steps_offset!=1", "1.0.0", 208 | deprecation_message, standard_warn=False) 209 | new_config = dict(scheduler.config) 210 | new_config["steps_offset"] = 1 211 | scheduler._internal_dict = FrozenDict(new_config) 212 | 213 | if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True: 214 | deprecation_message = ( 215 | f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`." 216 | " `clip_sample` should be set to False in the configuration file. Please make sure to update the" 217 | " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in" 218 | " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very" 219 | " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file" 220 | ) 221 | deprecate("clip_sample not set", "1.0.0", 222 | deprecation_message, standard_warn=False) 223 | new_config = dict(scheduler.config) 224 | new_config["clip_sample"] = False 225 | scheduler._internal_dict = FrozenDict(new_config) 226 | 227 | if safety_checker is None and requires_safety_checker: 228 | logger.warning( 229 | f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" 230 | " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" 231 | " results in services or applications open to the public. Both the diffusers team and Hugging Face" 232 | " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" 233 | " it only for use-cases that involve analyzing network behavior or auditing its results. For more" 234 | " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." 235 | ) 236 | 237 | if safety_checker is not None and feature_extractor is None: 238 | raise ValueError( 239 | "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" 240 | " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." 
241 | ) 242 | 243 | is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse( 244 | version.parse(unet.config._diffusers_version).base_version 245 | ) < version.parse("0.9.0.dev0") 246 | is_unet_sample_size_less_64 = hasattr( 247 | unet.config, "sample_size") and unet.config.sample_size < 64 248 | if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64: 249 | deprecation_message = ( 250 | "The configuration file of the unet has set the default `sample_size` to smaller than" 251 | " 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the" 252 | " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-" 253 | " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5" 254 | " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the" 255 | " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`" 256 | " in the config might lead to incorrect results in future versions. If you have downloaded this" 257 | " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for" 258 | " the `unet/config.json` file" 259 | ) 260 | deprecate("sample_size<64", "1.0.0", 261 | deprecation_message, standard_warn=False) 262 | new_config = dict(unet.config) 263 | new_config["sample_size"] = 64 264 | unet._internal_dict = FrozenDict(new_config) 265 | 266 | self.register_modules( 267 | vae=vae, 268 | text_encoder=text_encoder, 269 | tokenizer=tokenizer, 270 | unet=unet, 271 | scheduler=scheduler, 272 | safety_checker=safety_checker, 273 | feature_extractor=feature_extractor, 274 | image_encoder=image_encoder, 275 | ) 276 | self.vae_scale_factor = 2 ** ( 277 | len(self.vae.config.block_out_channels) - 1) 278 | self.image_processor = VaeImageProcessor( 279 | vae_scale_factor=self.vae_scale_factor) 280 | self.register_to_config( 281 | requires_safety_checker=requires_safety_checker) 282 | 283 | def _encode_prompt( 284 | self, 285 | prompt, 286 | device, 287 | num_images_per_prompt, 288 | do_classifier_free_guidance, 289 | negative_prompt=None, 290 | prompt_embeds: Optional[torch.Tensor] = None, 291 | negative_prompt_embeds: Optional[torch.Tensor] = None, 292 | lora_scale: Optional[float] = None, 293 | **kwargs, 294 | ): 295 | deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple." 
296 | deprecate("_encode_prompt()", "1.0.0", 297 | deprecation_message, standard_warn=False) 298 | 299 | prompt_embeds_tuple = self.encode_prompt( 300 | prompt=prompt, 301 | device=device, 302 | num_images_per_prompt=num_images_per_prompt, 303 | do_classifier_free_guidance=do_classifier_free_guidance, 304 | negative_prompt=negative_prompt, 305 | prompt_embeds=prompt_embeds, 306 | negative_prompt_embeds=negative_prompt_embeds, 307 | lora_scale=lora_scale, 308 | **kwargs, 309 | ) 310 | 311 | # concatenate for backwards comp 312 | prompt_embeds = torch.cat( 313 | [prompt_embeds_tuple[1], prompt_embeds_tuple[0]]) 314 | 315 | return prompt_embeds 316 | 317 | def encode_prompt( 318 | self, 319 | prompt, 320 | device, 321 | num_images_per_prompt, 322 | do_classifier_free_guidance, 323 | negative_prompt=None, 324 | prompt_embeds: Optional[torch.Tensor] = None, 325 | negative_prompt_embeds: Optional[torch.Tensor] = None, 326 | lora_scale: Optional[float] = None, 327 | clip_skip: Optional[int] = None, 328 | ): 329 | r""" 330 | Encodes the prompt into text encoder hidden states. 331 | 332 | Args: 333 | prompt (`str` or `List[str]`, *optional*): 334 | prompt to be encoded 335 | device: (`torch.device`): 336 | torch device 337 | num_images_per_prompt (`int`): 338 | number of images that should be generated per prompt 339 | do_classifier_free_guidance (`bool`): 340 | whether to use classifier free guidance or not 341 | negative_prompt (`str` or `List[str]`, *optional*): 342 | The prompt or prompts not to guide the image generation. If not defined, one has to pass 343 | `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is 344 | less than `1`). 345 | prompt_embeds (`torch.Tensor`, *optional*): 346 | Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not 347 | provided, text embeddings will be generated from `prompt` input argument. 348 | negative_prompt_embeds (`torch.Tensor`, *optional*): 349 | Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt 350 | weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input 351 | argument. 352 | lora_scale (`float`, *optional*): 353 | A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. 354 | clip_skip (`int`, *optional*): 355 | Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that 356 | the output of the pre-final layer will be used for computing the prompt embeddings. 
357 | """ 358 | # set lora scale so that monkey patched LoRA 359 | # function of text encoder can correctly access it 360 | if lora_scale is not None and isinstance(self, StableDiffusionLoraLoaderMixin): 361 | self._lora_scale = lora_scale 362 | 363 | # dynamically adjust the LoRA scale 364 | if not USE_PEFT_BACKEND: 365 | adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) 366 | else: 367 | scale_lora_layers(self.text_encoder, lora_scale) 368 | 369 | if prompt is not None and isinstance(prompt, str): 370 | batch_size = 1 371 | elif prompt is not None and isinstance(prompt, list): 372 | batch_size = len(prompt) 373 | else: 374 | batch_size = prompt_embeds.shape[0] 375 | 376 | if prompt_embeds is None: 377 | # textual inversion: process multi-vector tokens if necessary 378 | if isinstance(self, TextualInversionLoaderMixin): 379 | prompt = self.maybe_convert_prompt(prompt, self.tokenizer) 380 | 381 | text_inputs = self.tokenizer( 382 | prompt, 383 | padding="max_length", 384 | max_length=self.tokenizer.model_max_length, 385 | truncation=True, 386 | return_tensors="pt", 387 | ) 388 | text_input_ids = text_inputs.input_ids 389 | untruncated_ids = self.tokenizer( 390 | prompt, padding="longest", return_tensors="pt").input_ids 391 | 392 | if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( 393 | text_input_ids, untruncated_ids 394 | ): 395 | removed_text = self.tokenizer.batch_decode( 396 | untruncated_ids[:, self.tokenizer.model_max_length - 1: -1] 397 | ) 398 | logger.warning( 399 | "The following part of your input was truncated because CLIP can only handle sequences up to" 400 | f" {self.tokenizer.model_max_length} tokens: {removed_text}" 401 | ) 402 | 403 | if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: 404 | attention_mask = text_inputs.attention_mask.to(device) 405 | else: 406 | attention_mask = None 407 | 408 | if clip_skip is None: 409 | prompt_embeds = self.text_encoder( 410 | text_input_ids.to(device), attention_mask=attention_mask) 411 | prompt_embeds = prompt_embeds[0] 412 | else: 413 | prompt_embeds = self.text_encoder( 414 | text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True 415 | ) 416 | # Access the `hidden_states` first, that contains a tuple of 417 | # all the hidden states from the encoder layers. Then index into 418 | # the tuple to access the hidden states from the desired layer. 419 | prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)] 420 | # We also need to apply the final LayerNorm here to not mess with the 421 | # representations. The `last_hidden_states` that we typically use for 422 | # obtaining the final prompt representations passes through the LayerNorm 423 | # layer. 
424 | prompt_embeds = self.text_encoder.text_model.final_layer_norm( 425 | prompt_embeds) 426 | 427 | if self.text_encoder is not None: 428 | prompt_embeds_dtype = self.text_encoder.dtype 429 | elif self.unet is not None: 430 | prompt_embeds_dtype = self.unet.dtype 431 | else: 432 | prompt_embeds_dtype = prompt_embeds.dtype 433 | 434 | prompt_embeds = prompt_embeds.to( 435 | dtype=prompt_embeds_dtype, device=device) 436 | 437 | bs_embed, seq_len, _ = prompt_embeds.shape 438 | # duplicate text embeddings for each generation per prompt, using mps friendly method 439 | prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) 440 | prompt_embeds = prompt_embeds.view( 441 | bs_embed * num_images_per_prompt, seq_len, -1) 442 | 443 | # get unconditional embeddings for classifier free guidance 444 | if do_classifier_free_guidance and negative_prompt_embeds is None: 445 | uncond_tokens: List[str] 446 | if negative_prompt is None: 447 | uncond_tokens = [""] * batch_size 448 | elif prompt is not None and type(prompt) is not type(negative_prompt): 449 | raise TypeError( 450 | f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" 451 | f" {type(prompt)}." 452 | ) 453 | elif isinstance(negative_prompt, str): 454 | uncond_tokens = [negative_prompt] 455 | elif batch_size != len(negative_prompt): 456 | raise ValueError( 457 | f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" 458 | f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" 459 | " the batch size of `prompt`." 460 | ) 461 | else: 462 | uncond_tokens = negative_prompt 463 | 464 | # textual inversion: process multi-vector tokens if necessary 465 | if isinstance(self, TextualInversionLoaderMixin): 466 | uncond_tokens = self.maybe_convert_prompt( 467 | uncond_tokens, self.tokenizer) 468 | 469 | max_length = prompt_embeds.shape[1] 470 | uncond_input = self.tokenizer( 471 | uncond_tokens, 472 | padding="max_length", 473 | max_length=max_length, 474 | truncation=True, 475 | return_tensors="pt", 476 | ) 477 | 478 | if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: 479 | attention_mask = uncond_input.attention_mask.to(device) 480 | else: 481 | attention_mask = None 482 | 483 | negative_prompt_embeds = self.text_encoder( 484 | uncond_input.input_ids.to(device), 485 | attention_mask=attention_mask, 486 | ) 487 | negative_prompt_embeds = negative_prompt_embeds[0] 488 | 489 | if do_classifier_free_guidance: 490 | # duplicate unconditional embeddings for each generation per prompt, using mps friendly method 491 | seq_len = negative_prompt_embeds.shape[1] 492 | 493 | negative_prompt_embeds = negative_prompt_embeds.to( 494 | dtype=prompt_embeds_dtype, device=device) 495 | 496 | negative_prompt_embeds = negative_prompt_embeds.repeat( 497 | 1, num_images_per_prompt, 1) 498 | negative_prompt_embeds = negative_prompt_embeds.view( 499 | batch_size * num_images_per_prompt, seq_len, -1) 500 | 501 | if self.text_encoder is not None: 502 | if isinstance(self, StableDiffusionLoraLoaderMixin) and USE_PEFT_BACKEND: 503 | # Retrieve the original scale by scaling back the LoRA layers 504 | unscale_lora_layers(self.text_encoder, lora_scale) 505 | 506 | return prompt_embeds, negative_prompt_embeds 507 | 508 | def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None): 509 | dtype = next(self.image_encoder.parameters()).dtype 510 | 511 | if not isinstance(image, 
torch.Tensor): 512 | image = self.feature_extractor( 513 | image, return_tensors="pt").pixel_values 514 | 515 | image = image.to(device=device, dtype=dtype) 516 | if output_hidden_states: 517 | image_enc_hidden_states = self.image_encoder( 518 | image, output_hidden_states=True).hidden_states[-2] 519 | image_enc_hidden_states = image_enc_hidden_states.repeat_interleave( 520 | num_images_per_prompt, dim=0) 521 | uncond_image_enc_hidden_states = self.image_encoder( 522 | torch.zeros_like(image), output_hidden_states=True 523 | ).hidden_states[-2] 524 | uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( 525 | num_images_per_prompt, dim=0 526 | ) 527 | return image_enc_hidden_states, uncond_image_enc_hidden_states 528 | else: 529 | image_embeds = self.image_encoder(image).image_embeds 530 | image_embeds = image_embeds.repeat_interleave( 531 | num_images_per_prompt, dim=0) 532 | uncond_image_embeds = torch.zeros_like(image_embeds) 533 | 534 | return image_embeds, uncond_image_embeds 535 | 536 | def prepare_ip_adapter_image_embeds( 537 | self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance 538 | ): 539 | image_embeds = [] 540 | if do_classifier_free_guidance: 541 | negative_image_embeds = [] 542 | if ip_adapter_image_embeds is None: 543 | if not isinstance(ip_adapter_image, list): 544 | ip_adapter_image = [ip_adapter_image] 545 | 546 | if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers): 547 | raise ValueError( 548 | f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters." 549 | ) 550 | 551 | for single_ip_adapter_image, image_proj_layer in zip( 552 | ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers 553 | ): 554 | output_hidden_state = not isinstance( 555 | image_proj_layer, ImageProjection) 556 | single_image_embeds, single_negative_image_embeds = self.encode_image( 557 | single_ip_adapter_image, device, 1, output_hidden_state 558 | ) 559 | 560 | image_embeds.append(single_image_embeds[None, :]) 561 | if do_classifier_free_guidance: 562 | negative_image_embeds.append( 563 | single_negative_image_embeds[None, :]) 564 | else: 565 | for single_image_embeds in ip_adapter_image_embeds: 566 | if do_classifier_free_guidance: 567 | single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk( 568 | 2) 569 | negative_image_embeds.append(single_negative_image_embeds) 570 | image_embeds.append(single_image_embeds) 571 | 572 | ip_adapter_image_embeds = [] 573 | for i, single_image_embeds in enumerate(image_embeds): 574 | single_image_embeds = torch.cat( 575 | [single_image_embeds] * num_images_per_prompt, dim=0) 576 | if do_classifier_free_guidance: 577 | single_negative_image_embeds = torch.cat( 578 | [negative_image_embeds[i]] * num_images_per_prompt, dim=0) 579 | single_image_embeds = torch.cat( 580 | [single_negative_image_embeds, single_image_embeds], dim=0) 581 | 582 | single_image_embeds = single_image_embeds.to(device=device) 583 | ip_adapter_image_embeds.append(single_image_embeds) 584 | 585 | return ip_adapter_image_embeds 586 | 587 | def decode_latents(self, latents): 588 | deprecation_message = "The decode_latents method is deprecated and will be removed in 1.0.0. Please use VaeImageProcessor.postprocess(...) 
instead" 589 | deprecate("decode_latents", "1.0.0", 590 | deprecation_message, standard_warn=False) 591 | 592 | latents = 1 / self.vae.config.scaling_factor * latents 593 | image = self.vae.decode(latents, return_dict=False)[0] 594 | image = (image / 2 + 0.5).clamp(0, 1) 595 | # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 596 | image = image.cpu().permute(0, 2, 3, 1).float().numpy() 597 | return image 598 | 599 | def prepare_extra_step_kwargs(self, generator, eta): 600 | # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature 601 | # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. 602 | # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 603 | # and should be between [0, 1] 604 | 605 | accepts_eta = "eta" in set(inspect.signature( 606 | self.scheduler.step).parameters.keys()) 607 | extra_step_kwargs = {} 608 | if accepts_eta: 609 | extra_step_kwargs["eta"] = eta 610 | 611 | # check if the scheduler accepts generator 612 | accepts_generator = "generator" in set( 613 | inspect.signature(self.scheduler.step).parameters.keys()) 614 | if accepts_generator: 615 | extra_step_kwargs["generator"] = generator 616 | return extra_step_kwargs 617 | 618 | def check_inputs( 619 | self, 620 | prompt, 621 | height, 622 | width, 623 | callback_steps, 624 | negative_prompt=None, 625 | prompt_embeds=None, 626 | negative_prompt_embeds=None, 627 | ip_adapter_image=None, 628 | ip_adapter_image_embeds=None, 629 | callback_on_step_end_tensor_inputs=None, 630 | ): 631 | if height % 8 != 0 or width % 8 != 0: 632 | raise ValueError( 633 | f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") 634 | 635 | if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0): 636 | raise ValueError( 637 | f"`callback_steps` has to be a positive integer but is {callback_steps} of type" 638 | f" {type(callback_steps)}." 639 | ) 640 | if callback_on_step_end_tensor_inputs is not None and not all( 641 | k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs 642 | ): 643 | raise ValueError( 644 | f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" 645 | ) 646 | 647 | if prompt is not None and prompt_embeds is not None: 648 | raise ValueError( 649 | f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" 650 | " only forward one of the two." 651 | ) 652 | elif prompt is None and prompt_embeds is None: 653 | raise ValueError( 654 | "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." 655 | ) 656 | elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): 657 | raise ValueError( 658 | f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") 659 | 660 | if negative_prompt is not None and negative_prompt_embeds is not None: 661 | raise ValueError( 662 | f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" 663 | f" {negative_prompt_embeds}. Please make sure to only forward one of the two." 
664 | ) 665 | 666 | if prompt_embeds is not None and negative_prompt_embeds is not None: 667 | if prompt_embeds.shape != negative_prompt_embeds.shape: 668 | raise ValueError( 669 | "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" 670 | f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" 671 | f" {negative_prompt_embeds.shape}." 672 | ) 673 | 674 | if ip_adapter_image is not None and ip_adapter_image_embeds is not None: 675 | raise ValueError( 676 | "Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined." 677 | ) 678 | 679 | if ip_adapter_image_embeds is not None: 680 | if not isinstance(ip_adapter_image_embeds, list): 681 | raise ValueError( 682 | f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}" 683 | ) 684 | elif ip_adapter_image_embeds[0].ndim not in [3, 4]: 685 | raise ValueError( 686 | f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D" 687 | ) 688 | 689 | def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): 690 | shape = ( 691 | batch_size, 692 | num_channels_latents, 693 | int(height) // self.vae_scale_factor, 694 | int(width) // self.vae_scale_factor, 695 | ) 696 | if isinstance(generator, list) and len(generator) != batch_size: 697 | raise ValueError( 698 | f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" 699 | f" size of {batch_size}. Make sure the batch size matches the length of the generators." 700 | ) 701 | 702 | if latents is None: 703 | latents = randn_tensor( 704 | shape, generator=generator, device=device, dtype=dtype) 705 | else: 706 | latents = latents.to(device) 707 | 708 | # scale the initial noise by the standard deviation required by the scheduler 709 | latents = latents * self.scheduler.init_noise_sigma 710 | return latents 711 | 712 | # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding 713 | def get_guidance_scale_embedding( 714 | self, w: torch.Tensor, embedding_dim: int = 512, dtype: torch.dtype = torch.float32 715 | ) -> torch.Tensor: 716 | """ 717 | See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298 718 | 719 | Args: 720 | w (`torch.Tensor`): 721 | Generate embedding vectors with a specified guidance scale to subsequently enrich timestep embeddings. 722 | embedding_dim (`int`, *optional*, defaults to 512): 723 | Dimension of the embeddings to generate. 724 | dtype (`torch.dtype`, *optional*, defaults to `torch.float32`): 725 | Data type of the generated embeddings. 726 | 727 | Returns: 728 | `torch.Tensor`: Embedding vectors with shape `(len(w), embedding_dim)`. 
729 | """ 730 | assert len(w.shape) == 1 731 | w = w * 1000.0 732 | 733 | half_dim = embedding_dim // 2 734 | emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1) 735 | emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb) 736 | emb = w.to(dtype)[:, None] * emb[None, :] 737 | emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) 738 | if embedding_dim % 2 == 1: # zero pad 739 | emb = torch.nn.functional.pad(emb, (0, 1)) 740 | assert emb.shape == (w.shape[0], embedding_dim) 741 | return emb 742 | 743 | @property 744 | def guidance_scale(self): 745 | return self._guidance_scale 746 | 747 | @property 748 | def guidance_rescale(self): 749 | return self._guidance_rescale 750 | 751 | @property 752 | def clip_skip(self): 753 | return self._clip_skip 754 | 755 | # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) 756 | # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` 757 | # corresponds to doing no classifier free guidance. 758 | @property 759 | def do_classifier_free_guidance(self): 760 | return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None 761 | 762 | @property 763 | def cross_attention_kwargs(self): 764 | return self._cross_attention_kwargs 765 | 766 | @property 767 | def num_timesteps(self): 768 | return self._num_timesteps 769 | 770 | @property 771 | def interrupt(self): 772 | return self._interrupt 773 | 774 | def decode_rgbd(self, latents, generator, output_type="np"): 775 | dem_latents = latents[:, 4:, :, :] 776 | img_latents = latents[:, :4, :, :] 777 | image = self.vae.decode(img_latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[ 778 | 0 779 | ] 780 | dem = self.vae.decode(dem_latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[ 781 | 0 782 | ] 783 | do_denormalize = [True] * image.shape[0] 784 | image = self.image_processor.postprocess( 785 | image, output_type=output_type, do_denormalize=do_denormalize) 786 | dem = self.image_processor.postprocess( 787 | dem, output_type=output_type, do_denormalize=do_denormalize) 788 | return image, dem 789 | 790 | @torch.no_grad() 791 | @replace_example_docstring(EXAMPLE_DOC_STRING) 792 | def __call__( 793 | self, 794 | prompt: Union[str, List[str]] = None, 795 | height: Optional[int] = None, 796 | width: Optional[int] = None, 797 | num_inference_steps: int = 50, 798 | timesteps: List[int] = None, 799 | sigmas: List[float] = None, 800 | guidance_scale: float = 7.5, 801 | negative_prompt: Optional[Union[str, List[str]]] = None, 802 | num_images_per_prompt: Optional[int] = 1, 803 | eta: float = 0.0, 804 | generator: Optional[Union[torch.Generator, 805 | List[torch.Generator]]] = None, 806 | latents: Optional[torch.Tensor] = None, 807 | prompt_embeds: Optional[torch.Tensor] = None, 808 | negative_prompt_embeds: Optional[torch.Tensor] = None, 809 | ip_adapter_image: Optional[PipelineImageInput] = None, 810 | ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None, 811 | output_type: Optional[str] = "np", 812 | return_dict: bool = True, 813 | cross_attention_kwargs: Optional[Dict[str, Any]] = None, 814 | guidance_rescale: float = 0.0, 815 | clip_skip: Optional[int] = None, 816 | callback_on_step_end: Optional[ 817 | Union[Callable[[int, int, Dict], None], 818 | PipelineCallback, MultiPipelineCallbacks] 819 | ] = None, 820 | callback_on_step_end_tensor_inputs: List[str] = ["latents"], 821 | **kwargs, 822 | ): 823 | r""" 824 | The call function to the pipeline for generation. 
825 | 826 | Args: 827 | prompt (`str` or `List[str]`, *optional*): 828 | The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`. 829 | height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): 830 | The height in pixels of the generated image. 831 | width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): 832 | The width in pixels of the generated image. 833 | num_inference_steps (`int`, *optional*, defaults to 50): 834 | The number of denoising steps. More denoising steps usually lead to a higher quality image at the 835 | expense of slower inference. 836 | timesteps (`List[int]`, *optional*): 837 | Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument 838 | in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is 839 | passed will be used. Must be in descending order. 840 | sigmas (`List[float]`, *optional*): 841 | Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in 842 | their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed 843 | will be used. 844 | guidance_scale (`float`, *optional*, defaults to 7.5): 845 | A higher guidance scale value encourages the model to generate images closely linked to the text 846 | `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. 847 | negative_prompt (`str` or `List[str]`, *optional*): 848 | The prompt or prompts to guide what to not include in image generation. If not defined, you need to 849 | pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). 850 | num_images_per_prompt (`int`, *optional*, defaults to 1): 851 | The number of images to generate per prompt. 852 | eta (`float`, *optional*, defaults to 0.0): 853 | Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies 854 | to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. 855 | generator (`torch.Generator` or `List[torch.Generator]`, *optional*): 856 | A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make 857 | generation deterministic. 858 | latents (`torch.Tensor`, *optional*): 859 | Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image 860 | generation. Can be used to tweak the same generation with different prompts. If not provided, a latents 861 | tensor is generated by sampling using the supplied random `generator`. 862 | prompt_embeds (`torch.Tensor`, *optional*): 863 | Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not 864 | provided, text embeddings are generated from the `prompt` input argument. 865 | negative_prompt_embeds (`torch.Tensor`, *optional*): 866 | Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If 867 | not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. 868 | ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters. 869 | ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*): 870 | Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of 871 | IP-adapters. 
Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should 872 | contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not 873 | provided, embeddings are computed from the `ip_adapter_image` input argument. 874 | output_type (`str`, *optional*, defaults to `"pil"`): 875 | The output format of the generated image. Choose between `PIL.Image` or `np.array`. 876 | return_dict (`bool`, *optional*, defaults to `True`): 877 | Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a 878 | plain tuple. 879 | cross_attention_kwargs (`dict`, *optional*): 880 | A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in 881 | [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). 882 | guidance_rescale (`float`, *optional*, defaults to 0.0): 883 | Guidance rescale factor from [Common Diffusion Noise Schedules and Sample Steps are 884 | Flawed](https://arxiv.org/pdf/2305.08891.pdf). Guidance rescale factor should fix overexposure when 885 | using zero terminal SNR. 886 | clip_skip (`int`, *optional*): 887 | Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that 888 | the output of the pre-final layer will be used for computing the prompt embeddings. 889 | callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*): 890 | A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of 891 | each denoising step during the inference. with the following arguments: `callback_on_step_end(self: 892 | DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a 893 | list of all tensors as specified by `callback_on_step_end_tensor_inputs`. 894 | callback_on_step_end_tensor_inputs (`List`, *optional*): 895 | The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list 896 | will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the 897 | `._callback_tensor_inputs` attribute of your pipeline class. 898 | 899 | Examples: 900 | 901 | Returns: 902 | [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: 903 | If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned, 904 | otherwise a `tuple` is returned where the first element is a list with the generated images and the 905 | second element is a list of `bool`s indicating whether the corresponding generated image contains 906 | "not-safe-for-work" (nsfw) content. 907 | """ 908 | 909 | callback = kwargs.pop("callback", None) 910 | callback_steps = kwargs.pop("callback_steps", None) 911 | 912 | if callback is not None: 913 | deprecate( 914 | "callback", 915 | "1.0.0", 916 | "Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`", 917 | ) 918 | if callback_steps is not None: 919 | deprecate( 920 | "callback_steps", 921 | "1.0.0", 922 | "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`", 923 | ) 924 | 925 | if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): 926 | callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs 927 | 928 | # 0. 
Default height and width to unet 929 | height = height or self.unet.config.sample_size * self.vae_scale_factor 930 | width = width or self.unet.config.sample_size * self.vae_scale_factor 931 | # to deal with lora scaling and other possible forward hooks 932 | 933 | # 1. Check inputs. Raise error if not correct 934 | self.check_inputs( 935 | prompt, 936 | height, 937 | width, 938 | callback_steps, 939 | negative_prompt, 940 | prompt_embeds, 941 | negative_prompt_embeds, 942 | ip_adapter_image, 943 | ip_adapter_image_embeds, 944 | callback_on_step_end_tensor_inputs, 945 | ) 946 | 947 | self._guidance_scale = guidance_scale 948 | self._guidance_rescale = guidance_rescale 949 | self._clip_skip = clip_skip 950 | self._cross_attention_kwargs = cross_attention_kwargs 951 | self._interrupt = False 952 | 953 | # 2. Define call parameters 954 | if prompt is not None and isinstance(prompt, str): 955 | batch_size = 1 956 | elif prompt is not None and isinstance(prompt, list): 957 | batch_size = len(prompt) 958 | else: 959 | batch_size = prompt_embeds.shape[0] 960 | 961 | device = self._execution_device 962 | 963 | # 3. Encode input prompt 964 | lora_scale = ( 965 | self.cross_attention_kwargs.get( 966 | "scale", None) if self.cross_attention_kwargs is not None else None 967 | ) 968 | 969 | prompt_embeds, negative_prompt_embeds = self.encode_prompt( 970 | prompt, 971 | device, 972 | num_images_per_prompt, 973 | self.do_classifier_free_guidance, 974 | negative_prompt, 975 | prompt_embeds=prompt_embeds, 976 | negative_prompt_embeds=negative_prompt_embeds, 977 | lora_scale=lora_scale, 978 | clip_skip=self.clip_skip, 979 | ) 980 | 981 | # For classifier free guidance, we need to do two forward passes. 982 | # Here we concatenate the unconditional and text embeddings into a single batch 983 | # to avoid doing two forward passes 984 | if self.do_classifier_free_guidance: 985 | prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) 986 | 987 | if ip_adapter_image is not None or ip_adapter_image_embeds is not None: 988 | image_embeds = self.prepare_ip_adapter_image_embeds( 989 | ip_adapter_image, 990 | ip_adapter_image_embeds, 991 | device, 992 | batch_size * num_images_per_prompt, 993 | self.do_classifier_free_guidance, 994 | ) 995 | 996 | # 4. Prepare timesteps 997 | timesteps, num_inference_steps = retrieve_timesteps( 998 | self.scheduler, num_inference_steps, device, timesteps, sigmas 999 | ) 1000 | 1001 | # 5. Prepare latent variables 1002 | num_channels_latents = self.unet.config.in_channels*2 1003 | latents = self.prepare_latents( 1004 | batch_size * num_images_per_prompt, 1005 | num_channels_latents, 1006 | height, 1007 | width, 1008 | prompt_embeds.dtype, 1009 | device, 1010 | generator, 1011 | latents, 1012 | ) 1013 | 1014 | # 6. Prepare extra step kwargs. 
TODO: Logic should ideally just be moved out of the pipeline 1015 | extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) 1016 | 1017 | # 6.1 Add image embeds for IP-Adapter 1018 | added_cond_kwargs = ( 1019 | {"image_embeds": image_embeds} 1020 | if (ip_adapter_image is not None or ip_adapter_image_embeds is not None) 1021 | else None 1022 | ) 1023 | 1024 | # 6.2 Optionally get Guidance Scale Embedding 1025 | timestep_cond = None 1026 | if self.unet.config.time_cond_proj_dim is not None: 1027 | guidance_scale_tensor = torch.tensor( 1028 | self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt) 1029 | timestep_cond = self.get_guidance_scale_embedding( 1030 | guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim 1031 | ).to(device=device, dtype=latents.dtype) 1032 | 1033 | # 7. Denoising loop 1034 | num_warmup_steps = len(timesteps) - \ 1035 | num_inference_steps * self.scheduler.order 1036 | self._num_timesteps = len(timesteps) 1037 | # intermediate_latents = [] 1038 | with self.progress_bar(total=num_inference_steps) as progress_bar: 1039 | for i, t in enumerate(timesteps): 1040 | if self.interrupt: 1041 | continue 1042 | 1043 | # expand the latents if we are doing classifier free guidance 1044 | latent_model_input = torch.cat( 1045 | [latents] * 2) if self.do_classifier_free_guidance else latents 1046 | latent_model_input = self.scheduler.scale_model_input( 1047 | latent_model_input, t) 1048 | 1049 | # predict the noise residual 1050 | noise_pred = self.unet( 1051 | latent_model_input, 1052 | t, 1053 | encoder_hidden_states=prompt_embeds, 1054 | timestep_cond=timestep_cond, 1055 | cross_attention_kwargs=self.cross_attention_kwargs, 1056 | added_cond_kwargs=added_cond_kwargs, 1057 | return_dict=False, 1058 | )[0] 1059 | 1060 | # perform guidance 1061 | if self.do_classifier_free_guidance: 1062 | noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) 1063 | noise_pred = noise_pred_uncond + self.guidance_scale * \ 1064 | (noise_pred_text - noise_pred_uncond) 1065 | 1066 | if self.do_classifier_free_guidance and self.guidance_rescale > 0.0: 1067 | # Based on 3.4. 
in https://arxiv.org/pdf/2305.08891.pdf 1068 | noise_pred = rescale_noise_cfg( 1069 | noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale) 1070 | 1071 | # compute the previous noisy sample x_t -> x_t-1 1072 | scheduler_output = self.scheduler.step( 1073 | noise_pred, t, latents, **extra_step_kwargs, return_dict=True) 1074 | latents = scheduler_output.prev_sample 1075 | # if i % 10 == 0: 1076 | # intermediate_latents.append(scheduler_output.pred_original_sample) 1077 | if callback_on_step_end is not None: 1078 | callback_kwargs = {} 1079 | for k in callback_on_step_end_tensor_inputs: 1080 | callback_kwargs[k] = locals()[k] 1081 | callback_outputs = callback_on_step_end( 1082 | self, i, t, callback_kwargs) 1083 | 1084 | latents = callback_outputs.pop("latents", latents) 1085 | prompt_embeds = callback_outputs.pop( 1086 | "prompt_embeds", prompt_embeds) 1087 | negative_prompt_embeds = callback_outputs.pop( 1088 | "negative_prompt_embeds", negative_prompt_embeds) 1089 | 1090 | # call the callback, if provided 1091 | if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): 1092 | progress_bar.update() 1093 | if callback is not None and i % callback_steps == 0: 1094 | step_idx = i // getattr(self.scheduler, "order", 1) 1095 | callback(step_idx, t, latents) 1096 | 1097 | if XLA_AVAILABLE: 1098 | xm.mark_step() 1099 | 1100 | image, dem = self.decode_rgbd(latents, generator, output_type) 1101 | 1102 | # intermediate = [self.decode_rgbd(latent,generator,output_type)for latent in intermediate_latents] 1103 | 1104 | # Offload all models 1105 | self.maybe_free_model_hooks() 1106 | 1107 | return image, dem 1108 | -------------------------------------------------------------------------------- /models.py: -------------------------------------------------------------------------------- 1 | ########################################################################### 2 | # References: 3 | # https://github.com/huggingface/diffusers/ 4 | ########################################################################### 5 | 6 | from dataclasses import dataclass 7 | from typing import Any, Dict, List, Optional, Tuple, Union 8 | 9 | import torch 10 | import torch.nn as nn 11 | 12 | from diffusers.configuration_utils import ConfigMixin, register_to_config 13 | from diffusers.loaders import PeftAdapterMixin, UNet2DConditionLoadersMixin 14 | from diffusers.loaders.single_file_model import FromOriginalModelMixin 15 | from diffusers.utils import USE_PEFT_BACKEND, BaseOutput, deprecate, logging, scale_lora_layers, unscale_lora_layers 16 | from diffusers.models.activations import get_activation 17 | from diffusers.models.attention_processor import ( 18 | ADDED_KV_ATTENTION_PROCESSORS, 19 | CROSS_ATTENTION_PROCESSORS, 20 | Attention, 21 | AttentionProcessor, 22 | AttnAddedKVProcessor, 23 | AttnProcessor, 24 | FusedAttnProcessor2_0, 25 | ) 26 | from diffusers.models.embeddings import ( 27 | GaussianFourierProjection, 28 | GLIGENTextBoundingboxProjection, 29 | ImageHintTimeEmbedding, 30 | ImageProjection, 31 | ImageTimeEmbedding, 32 | TextImageProjection, 33 | TextImageTimeEmbedding, 34 | TextTimeEmbedding, 35 | TimestepEmbedding, 36 | Timesteps, 37 | ) 38 | from diffusers.models.modeling_utils import ModelMixin 39 | from diffusers.models.unets.unet_2d_blocks import ( 40 | get_down_block, 41 | get_mid_block, 42 | get_up_block, 43 | ) 44 | 45 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name 46 | 47 | 48 | @dataclass 49 | class 
UNet2DConditionOutput(BaseOutput): 50 | """ 51 | The output of [`UNet2DConditionModel`]. 52 | 53 | Args: 54 | sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)`): 55 | The hidden states output conditioned on `encoder_hidden_states` input. Output of last layer of model. 56 | """ 57 | 58 | sample: torch.Tensor = None 59 | 60 | 61 | class UNetDEMConditionModel( 62 | ModelMixin, ConfigMixin, FromOriginalModelMixin, UNet2DConditionLoadersMixin, PeftAdapterMixin 63 | ): 64 | r""" 65 | A conditional 2D UNet model that takes a noisy sample, conditional state, and a timestep and returns a sample 66 | shaped output. 67 | 68 | This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented 69 | for all models (such as downloading or saving). 70 | 71 | Parameters: 72 | sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`): 73 | Height and width of input/output sample. 74 | in_channels (`int`, *optional*, defaults to 4): Number of channels in the input sample. 75 | out_channels (`int`, *optional*, defaults to 4): Number of channels in the output. 76 | center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample. 77 | flip_sin_to_cos (`bool`, *optional*, defaults to `True`): 78 | Whether to flip the sin to cos in the time embedding. 79 | freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding. 80 | down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`): 81 | The tuple of downsample blocks to use. 82 | mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2DCrossAttn"`): 83 | Block type for middle of UNet, it can be one of `UNetMidBlock2DCrossAttn`, `UNetMidBlock2D`, or 84 | `UNetMidBlock2DSimpleCrossAttn`. If `None`, the mid block layer is skipped. 85 | up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")`): 86 | The tuple of upsample blocks to use. 87 | only_cross_attention(`bool` or `Tuple[bool]`, *optional*, default to `False`): 88 | Whether to include self-attention in the basic transformer blocks, see 89 | [`~models.attention.BasicTransformerBlock`]. 90 | block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`): 91 | The tuple of output channels for each block. 92 | layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block. 93 | downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution. 94 | mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block. 95 | dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. 96 | act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use. 97 | norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization. 98 | If `None`, normalization and activation layers is skipped in post-processing. 99 | norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization. 100 | cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280): 101 | The dimension of the cross attention features. 
102 | transformer_layers_per_block (`int`, `Tuple[int]`, or `Tuple[Tuple]` , *optional*, defaults to 1): 103 | The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for 104 | [`~models.unets.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unets.unet_2d_blocks.CrossAttnUpBlock2D`], 105 | [`~models.unets.unet_2d_blocks.UNetMidBlock2DCrossAttn`]. 106 | reverse_transformer_layers_per_block : (`Tuple[Tuple]`, *optional*, defaults to None): 107 | The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`], in the upsampling 108 | blocks of the U-Net. Only relevant if `transformer_layers_per_block` is of type `Tuple[Tuple]` and for 109 | [`~models.unets.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unets.unet_2d_blocks.CrossAttnUpBlock2D`], 110 | [`~models.unets.unet_2d_blocks.UNetMidBlock2DCrossAttn`]. 111 | encoder_hid_dim (`int`, *optional*, defaults to None): 112 | If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim` 113 | dimension to `cross_attention_dim`. 114 | encoder_hid_dim_type (`str`, *optional*, defaults to `None`): 115 | If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text 116 | embeddings of dimension `cross_attention` according to `encoder_hid_dim_type`. 117 | attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads. 118 | num_attention_heads (`int`, *optional*): 119 | The number of attention heads. If not defined, defaults to `attention_head_dim` 120 | resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config 121 | for ResNet blocks (see [`~models.resnet.ResnetBlock2D`]). Choose from `default` or `scale_shift`. 122 | class_embed_type (`str`, *optional*, defaults to `None`): 123 | The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`, 124 | `"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`. 125 | addition_embed_type (`str`, *optional*, defaults to `None`): 126 | Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or 127 | "text". "text" will use the `TextTimeEmbedding` layer. 128 | addition_time_embed_dim: (`int`, *optional*, defaults to `None`): 129 | Dimension for the timestep embeddings. 130 | num_class_embeds (`int`, *optional*, defaults to `None`): 131 | Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing 132 | class conditioning with `class_embed_type` equal to `None`. 133 | time_embedding_type (`str`, *optional*, defaults to `positional`): 134 | The type of position embedding to use for timesteps. Choose from `positional` or `fourier`. 135 | time_embedding_dim (`int`, *optional*, defaults to `None`): 136 | An optional override for the dimension of the projected time embedding. 137 | time_embedding_act_fn (`str`, *optional*, defaults to `None`): 138 | Optional activation function to use only once on the time embeddings before they are passed to the rest of 139 | the UNet. Choose from `silu`, `mish`, `gelu`, and `swish`. 140 | timestep_post_act (`str`, *optional*, defaults to `None`): 141 | The second activation function to use in timestep embedding. Choose from `silu`, `mish` and `gelu`. 142 | time_cond_proj_dim (`int`, *optional*, defaults to `None`): 143 | The dimension of `cond_proj` layer in the timestep embedding. 
144 | conv_in_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_in` layer. 145 | conv_out_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_out` layer. 146 | projection_class_embeddings_input_dim (`int`, *optional*): The dimension of the `class_labels` input when 147 | `class_embed_type="projection"`. Required when `class_embed_type="projection"`. 148 | class_embeddings_concat (`bool`, *optional*, defaults to `False`): Whether to concatenate the time 149 | embeddings with the class embeddings. 150 | mid_block_only_cross_attention (`bool`, *optional*, defaults to `None`): 151 | Whether to use cross attention with the mid block when using the `UNetMidBlock2DSimpleCrossAttn`. If 152 | `only_cross_attention` is given as a single boolean and `mid_block_only_cross_attention` is `None`, the 153 | `only_cross_attention` value is used as the value for `mid_block_only_cross_attention`. Default to `False` 154 | otherwise. 155 | """ 156 | 157 | _supports_gradient_checkpointing = True 158 | _no_split_modules = ["BasicTransformerBlock", "ResnetBlock2D", "CrossAttnUpBlock2D"] 159 | 160 | @register_to_config 161 | def __init__( 162 | self, 163 | sample_size: Optional[int] = None, 164 | in_channels: int = 8, 165 | out_channels: int = 8, 166 | center_input_sample: bool = False, 167 | flip_sin_to_cos: bool = True, 168 | freq_shift: int = 0, 169 | down_block_types: Tuple[str] = ( 170 | "CrossAttnDownBlock2D", 171 | "CrossAttnDownBlock2D", 172 | "CrossAttnDownBlock2D", 173 | "DownBlock2D", 174 | ), 175 | mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn", 176 | up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"), 177 | only_cross_attention: Union[bool, Tuple[bool]] = False, 178 | block_out_channels: Tuple[int] = (320, 640, 1280, 1280), 179 | layers_per_block: Union[int, Tuple[int]] = 2, 180 | downsample_padding: int = 1, 181 | mid_block_scale_factor: float = 1, 182 | dropout: float = 0.0, 183 | act_fn: str = "silu", 184 | norm_num_groups: Optional[int] = 32, 185 | norm_eps: float = 1e-5, 186 | cross_attention_dim: Union[int, Tuple[int]] = 1280, 187 | transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]] = 1, 188 | reverse_transformer_layers_per_block: Optional[Tuple[Tuple[int]]] = None, 189 | encoder_hid_dim: Optional[int] = None, 190 | encoder_hid_dim_type: Optional[str] = None, 191 | attention_head_dim: Union[int, Tuple[int]] = 8, 192 | num_attention_heads: Optional[Union[int, Tuple[int]]] = None, 193 | dual_cross_attention: bool = False, 194 | use_linear_projection: bool = False, 195 | class_embed_type: Optional[str] = None, 196 | addition_embed_type: Optional[str] = None, 197 | addition_time_embed_dim: Optional[int] = None, 198 | num_class_embeds: Optional[int] = None, 199 | upcast_attention: bool = False, 200 | resnet_time_scale_shift: str = "default", 201 | resnet_skip_time_act: bool = False, 202 | resnet_out_scale_factor: float = 1.0, 203 | time_embedding_type: str = "positional", 204 | time_embedding_dim: Optional[int] = None, 205 | time_embedding_act_fn: Optional[str] = None, 206 | timestep_post_act: Optional[str] = None, 207 | time_cond_proj_dim: Optional[int] = None, 208 | conv_in_kernel: int = 3, 209 | conv_out_kernel: int = 3, 210 | projection_class_embeddings_input_dim: Optional[int] = None, 211 | attention_type: str = "default", 212 | class_embeddings_concat: bool = False, 213 | mid_block_only_cross_attention: Optional[bool] = None, 214 | cross_attention_norm: Optional[str] = None, 
215 | addition_embed_type_num_heads: int = 64, 216 | ): 217 | super().__init__() 218 | 219 | self.sample_size = sample_size 220 | 221 | if num_attention_heads is not None: 222 | raise ValueError( 223 | "At the moment it is not possible to define the number of attention heads via `num_attention_heads` because of a naming issue as described in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131. Passing `num_attention_heads` will only be supported in diffusers v0.19." 224 | ) 225 | 226 | # If `num_attention_heads` is not defined (which is the case for most models) 227 | # it will default to `attention_head_dim`. This looks weird upon first reading it and it is. 228 | # The reason for this behavior is to correct for incorrectly named variables that were introduced 229 | # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131 230 | # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking 231 | # which is why we correct for the naming here. 232 | num_attention_heads = num_attention_heads or attention_head_dim 233 | 234 | # Check inputs 235 | self._check_config( 236 | down_block_types=down_block_types, 237 | up_block_types=up_block_types, 238 | only_cross_attention=only_cross_attention, 239 | block_out_channels=block_out_channels, 240 | layers_per_block=layers_per_block, 241 | cross_attention_dim=cross_attention_dim, 242 | transformer_layers_per_block=transformer_layers_per_block, 243 | reverse_transformer_layers_per_block=reverse_transformer_layers_per_block, 244 | attention_head_dim=attention_head_dim, 245 | num_attention_heads=num_attention_heads, 246 | ) 247 | 248 | # input 249 | conv_in_padding = (conv_in_kernel - 1) // 2 250 | self.conv_in_img = nn.Conv2d( 251 | in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding 252 | ) 253 | self.conv_in_dem = nn.Conv2d( 254 | in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding 255 | ) 256 | 257 | # time 258 | time_embed_dim, timestep_input_dim = self._set_time_proj( 259 | time_embedding_type, 260 | block_out_channels=block_out_channels, 261 | flip_sin_to_cos=flip_sin_to_cos, 262 | freq_shift=freq_shift, 263 | time_embedding_dim=time_embedding_dim, 264 | ) 265 | 266 | self.time_embedding = TimestepEmbedding( 267 | timestep_input_dim, 268 | time_embed_dim, 269 | act_fn=act_fn, 270 | post_act_fn=timestep_post_act, 271 | cond_proj_dim=time_cond_proj_dim, 272 | ) 273 | 274 | self._set_encoder_hid_proj( 275 | encoder_hid_dim_type, 276 | cross_attention_dim=cross_attention_dim, 277 | encoder_hid_dim=encoder_hid_dim, 278 | ) 279 | 280 | # class embedding 281 | self._set_class_embedding( 282 | class_embed_type, 283 | act_fn=act_fn, 284 | num_class_embeds=num_class_embeds, 285 | projection_class_embeddings_input_dim=projection_class_embeddings_input_dim, 286 | time_embed_dim=time_embed_dim, 287 | timestep_input_dim=timestep_input_dim, 288 | ) 289 | 290 | self._set_add_embedding( 291 | addition_embed_type, 292 | addition_embed_type_num_heads=addition_embed_type_num_heads, 293 | addition_time_embed_dim=addition_time_embed_dim, 294 | cross_attention_dim=cross_attention_dim, 295 | encoder_hid_dim=encoder_hid_dim, 296 | flip_sin_to_cos=flip_sin_to_cos, 297 | freq_shift=freq_shift, 298 | projection_class_embeddings_input_dim=projection_class_embeddings_input_dim, 299 | time_embed_dim=time_embed_dim, 300 | ) 301 | 
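        # Note: with the defaults above (`time_embedding_type="positional"`,
        # `block_out_channels=(320, 640, 1280, 1280)`), the time-embedding path configured so far is,
        # schematically: timestep (B,) --time_proj--> (B, 320) --time_embedding--> (B, 1280) [= 320 * 4].
        # The optional activation set up below is applied to that embedding in `forward`
        # before it is consumed by the down/mid/up blocks.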
302 | if time_embedding_act_fn is None: 303 | self.time_embed_act = None 304 | else: 305 | self.time_embed_act = get_activation(time_embedding_act_fn) 306 | 307 | self.down_blocks = nn.ModuleList([]) 308 | self.up_blocks = nn.ModuleList([]) 309 | 310 | if isinstance(only_cross_attention, bool): 311 | if mid_block_only_cross_attention is None: 312 | mid_block_only_cross_attention = only_cross_attention 313 | 314 | only_cross_attention = [only_cross_attention] * len(down_block_types) 315 | 316 | if mid_block_only_cross_attention is None: 317 | mid_block_only_cross_attention = False 318 | 319 | if isinstance(num_attention_heads, int): 320 | num_attention_heads = (num_attention_heads,) * len(down_block_types) 321 | 322 | if isinstance(attention_head_dim, int): 323 | attention_head_dim = (attention_head_dim,) * len(down_block_types) 324 | 325 | if isinstance(cross_attention_dim, int): 326 | cross_attention_dim = (cross_attention_dim,) * len(down_block_types) 327 | 328 | if isinstance(layers_per_block, int): 329 | layers_per_block = [layers_per_block] * len(down_block_types) 330 | 331 | if isinstance(transformer_layers_per_block, int): 332 | transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types) 333 | 334 | if class_embeddings_concat: 335 | # The time embeddings are concatenated with the class embeddings. The dimension of the 336 | # time embeddings passed to the down, middle, and up blocks is twice the dimension of the 337 | # regular time embeddings 338 | blocks_time_embed_dim = time_embed_dim * 2 339 | else: 340 | blocks_time_embed_dim = time_embed_dim 341 | 342 | # down 343 | output_channel = block_out_channels[0] 344 | 345 | 346 | for i, down_block_type in enumerate(down_block_types): 347 | input_channel = output_channel 348 | output_channel = block_out_channels[i] 349 | is_final_block = i == len(block_out_channels) - 1 350 | down_block_kwargs = {"down_block_type":down_block_type, 351 | "num_layers":layers_per_block[i], 352 | "transformer_layers_per_block":transformer_layers_per_block[i], 353 | "in_channels":input_channel, 354 | "out_channels":output_channel, 355 | "temb_channels":blocks_time_embed_dim, 356 | "add_downsample":not is_final_block, 357 | "resnet_eps":norm_eps, 358 | "resnet_act_fn":act_fn, 359 | "resnet_groups":norm_num_groups, 360 | "cross_attention_dim":cross_attention_dim[i], 361 | "num_attention_heads":num_attention_heads[i], 362 | "downsample_padding":downsample_padding, 363 | "dual_cross_attention":dual_cross_attention, 364 | "use_linear_projection":use_linear_projection, 365 | "only_cross_attention":only_cross_attention[i], 366 | "upcast_attention":upcast_attention, 367 | "resnet_time_scale_shift":resnet_time_scale_shift, 368 | "attention_type":attention_type, 369 | "resnet_skip_time_act":resnet_skip_time_act, 370 | "resnet_out_scale_factor":resnet_out_scale_factor, 371 | "cross_attention_norm":cross_attention_norm, 372 | "attention_head_dim":attention_head_dim[i] if attention_head_dim[i] is not None else output_channel, 373 | "dropout":dropout} 374 | 375 | if i == 0: 376 | self.head_img = get_down_block(**down_block_kwargs) 377 | # same architecture as the head_img but different weights 378 | self.head_dem = get_down_block(**down_block_kwargs) 379 | # elif i == 1: 380 | # down_block_kwargs["in_channels"] = input_channel *2 # concatenate the output of the head_img and head_dem 381 | # down_block = get_down_block(**down_block_kwargs) 382 | # self.down_blocks.append(down_block) 383 | else: 384 | down_block = 
get_down_block(**down_block_kwargs) 385 | self.down_blocks.append(down_block) 386 | 387 | # mid 388 | self.mid_block = get_mid_block( 389 | mid_block_type, 390 | temb_channels=blocks_time_embed_dim, 391 | in_channels=block_out_channels[-1], 392 | resnet_eps=norm_eps, 393 | resnet_act_fn=act_fn, 394 | resnet_groups=norm_num_groups, 395 | output_scale_factor=mid_block_scale_factor, 396 | transformer_layers_per_block=transformer_layers_per_block[-1], 397 | num_attention_heads=num_attention_heads[-1], 398 | cross_attention_dim=cross_attention_dim[-1], 399 | dual_cross_attention=dual_cross_attention, 400 | use_linear_projection=use_linear_projection, 401 | mid_block_only_cross_attention=mid_block_only_cross_attention, 402 | upcast_attention=upcast_attention, 403 | resnet_time_scale_shift=resnet_time_scale_shift, 404 | attention_type=attention_type, 405 | resnet_skip_time_act=resnet_skip_time_act, 406 | cross_attention_norm=cross_attention_norm, 407 | attention_head_dim=attention_head_dim[-1], 408 | dropout=dropout, 409 | ) 410 | 411 | # count how many layers upsample the images 412 | self.num_upsamplers = 0 413 | 414 | # up 415 | reversed_block_out_channels = list(reversed(block_out_channels)) 416 | reversed_num_attention_heads = list(reversed(num_attention_heads)) 417 | reversed_layers_per_block = list(reversed(layers_per_block)) 418 | reversed_cross_attention_dim = list(reversed(cross_attention_dim)) 419 | reversed_transformer_layers_per_block = ( 420 | list(reversed(transformer_layers_per_block)) 421 | if reverse_transformer_layers_per_block is None 422 | else reverse_transformer_layers_per_block 423 | ) 424 | only_cross_attention = list(reversed(only_cross_attention)) 425 | 426 | output_channel = reversed_block_out_channels[0] 427 | for i, up_block_type in enumerate(up_block_types): 428 | is_final_block = i == len(block_out_channels) - 1 429 | 430 | prev_output_channel = output_channel 431 | output_channel = reversed_block_out_channels[i] 432 | input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)] 433 | 434 | # add upsample block for all BUT final layer 435 | if not is_final_block: 436 | add_upsample = True 437 | self.num_upsamplers += 1 438 | else: 439 | add_upsample = False 440 | 441 | up_block_kwargs = {"up_block_type":up_block_type, 442 | "num_layers":reversed_layers_per_block[i] + 1, 443 | "transformer_layers_per_block":reversed_transformer_layers_per_block[i], 444 | "in_channels":input_channel, 445 | "out_channels":output_channel, 446 | "prev_output_channel":prev_output_channel, 447 | "temb_channels":blocks_time_embed_dim, 448 | "add_upsample":add_upsample, 449 | "resnet_eps":norm_eps, 450 | "resnet_act_fn":act_fn, 451 | "resolution_idx":i, 452 | "resnet_groups":norm_num_groups, 453 | "cross_attention_dim":reversed_cross_attention_dim[i], 454 | "num_attention_heads":reversed_num_attention_heads[i], 455 | "dual_cross_attention":dual_cross_attention, 456 | "use_linear_projection":use_linear_projection, 457 | "only_cross_attention":only_cross_attention[i], 458 | "upcast_attention":upcast_attention, 459 | "resnet_time_scale_shift":resnet_time_scale_shift, 460 | "attention_type":attention_type, 461 | "resnet_skip_time_act":resnet_skip_time_act, 462 | "resnet_out_scale_factor":resnet_out_scale_factor, 463 | "cross_attention_norm":cross_attention_norm, 464 | "attention_head_dim":attention_head_dim[i] if attention_head_dim[i] is not None else output_channel, 465 | "dropout":dropout,} 466 | 467 | # if i == len(block_out_channels) - 2: 468 | # 
up_block_kwargs["in_channels"] = input_channel*2 469 | # up_block = get_up_block(**up_block_kwargs) 470 | # self.up_blocks.append(up_block) 471 | 472 | if is_final_block : 473 | 474 | self.head_out_img = get_up_block(**up_block_kwargs) 475 | self.head_out_dem = get_up_block(**up_block_kwargs) 476 | 477 | else : 478 | up_block = get_up_block(**up_block_kwargs) 479 | self.up_blocks.append(up_block) 480 | 481 | 482 | # out 483 | if norm_num_groups is not None: 484 | self.conv_norm_out_img = nn.GroupNorm( 485 | num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps 486 | ) 487 | self.conv_norm_out_dem = nn.GroupNorm( 488 | num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps 489 | ) 490 | 491 | self.conv_act = get_activation(act_fn) 492 | 493 | else: 494 | self.conv_norm_out_img = None 495 | self.conv_norm_out_dem = None 496 | self.conv_act = None 497 | 498 | conv_out_padding = (conv_out_kernel - 1) // 2 499 | self.conv_out_img = nn.Conv2d( 500 | block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding 501 | ) 502 | self.conv_out_dem = nn.Conv2d( 503 | block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding 504 | ) 505 | 506 | self._set_pos_net_if_use_gligen(attention_type=attention_type, cross_attention_dim=cross_attention_dim) 507 | 508 | def _check_config( 509 | self, 510 | down_block_types: Tuple[str], 511 | up_block_types: Tuple[str], 512 | only_cross_attention: Union[bool, Tuple[bool]], 513 | block_out_channels: Tuple[int], 514 | layers_per_block: Union[int, Tuple[int]], 515 | cross_attention_dim: Union[int, Tuple[int]], 516 | transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple[int]]], 517 | reverse_transformer_layers_per_block: bool, 518 | attention_head_dim: int, 519 | num_attention_heads: Optional[Union[int, Tuple[int]]], 520 | ): 521 | if len(down_block_types) != len(up_block_types): 522 | raise ValueError( 523 | f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}." 524 | ) 525 | 526 | if len(block_out_channels) != len(down_block_types): 527 | raise ValueError( 528 | f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}." 529 | ) 530 | 531 | if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types): 532 | raise ValueError( 533 | f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}." 534 | ) 535 | 536 | if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types): 537 | raise ValueError( 538 | f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}." 539 | ) 540 | 541 | if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types): 542 | raise ValueError( 543 | f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}." 
544 | ) 545 | 546 | if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types): 547 | raise ValueError( 548 | f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}." 549 | ) 550 | 551 | if not isinstance(layers_per_block, int) and len(layers_per_block) != len(down_block_types): 552 | raise ValueError( 553 | f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}." 554 | ) 555 | if isinstance(transformer_layers_per_block, list) and reverse_transformer_layers_per_block is None: 556 | for layer_number_per_block in transformer_layers_per_block: 557 | if isinstance(layer_number_per_block, list): 558 | raise ValueError("Must provide 'reverse_transformer_layers_per_block` if using asymmetrical UNet.") 559 | 560 | def _set_time_proj( 561 | self, 562 | time_embedding_type: str, 563 | block_out_channels: int, 564 | flip_sin_to_cos: bool, 565 | freq_shift: float, 566 | time_embedding_dim: int, 567 | ) -> Tuple[int, int]: 568 | if time_embedding_type == "fourier": 569 | time_embed_dim = time_embedding_dim or block_out_channels[0] * 2 570 | if time_embed_dim % 2 != 0: 571 | raise ValueError(f"`time_embed_dim` should be divisible by 2, but is {time_embed_dim}.") 572 | self.time_proj = GaussianFourierProjection( 573 | time_embed_dim // 2, set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos 574 | ) 575 | timestep_input_dim = time_embed_dim 576 | elif time_embedding_type == "positional": 577 | time_embed_dim = time_embedding_dim or block_out_channels[0] * 4 578 | 579 | self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift) 580 | timestep_input_dim = block_out_channels[0] 581 | else: 582 | raise ValueError( 583 | f"{time_embedding_type} does not exist. Please make sure to use one of `fourier` or `positional`." 584 | ) 585 | 586 | return time_embed_dim, timestep_input_dim 587 | 588 | def _set_encoder_hid_proj( 589 | self, 590 | encoder_hid_dim_type: Optional[str], 591 | cross_attention_dim: Union[int, Tuple[int]], 592 | encoder_hid_dim: Optional[int], 593 | ): 594 | if encoder_hid_dim_type is None and encoder_hid_dim is not None: 595 | encoder_hid_dim_type = "text_proj" 596 | self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type) 597 | logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.") 598 | 599 | if encoder_hid_dim is None and encoder_hid_dim_type is not None: 600 | raise ValueError( 601 | f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}." 602 | ) 603 | 604 | if encoder_hid_dim_type == "text_proj": 605 | self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim) 606 | elif encoder_hid_dim_type == "text_image_proj": 607 | # image_embed_dim DOESN'T have to be `cross_attention_dim`. 
To not clutter the __init__ too much 608 | # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use 609 | # case when `addition_embed_type == "text_image_proj"` (Kandinsky 2.1)` 610 | self.encoder_hid_proj = TextImageProjection( 611 | text_embed_dim=encoder_hid_dim, 612 | image_embed_dim=cross_attention_dim, 613 | cross_attention_dim=cross_attention_dim, 614 | ) 615 | elif encoder_hid_dim_type == "image_proj": 616 | # Kandinsky 2.2 617 | self.encoder_hid_proj = ImageProjection( 618 | image_embed_dim=encoder_hid_dim, 619 | cross_attention_dim=cross_attention_dim, 620 | ) 621 | elif encoder_hid_dim_type is not None: 622 | raise ValueError( 623 | f"`encoder_hid_dim_type`: {encoder_hid_dim_type} must be None, 'text_proj', 'text_image_proj', or 'image_proj'." 624 | ) 625 | else: 626 | self.encoder_hid_proj = None 627 | 628 | def _set_class_embedding( 629 | self, 630 | class_embed_type: Optional[str], 631 | act_fn: str, 632 | num_class_embeds: Optional[int], 633 | projection_class_embeddings_input_dim: Optional[int], 634 | time_embed_dim: int, 635 | timestep_input_dim: int, 636 | ): 637 | if class_embed_type is None and num_class_embeds is not None: 638 | self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim) 639 | elif class_embed_type == "timestep": 640 | self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim, act_fn=act_fn) 641 | elif class_embed_type == "identity": 642 | self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim) 643 | elif class_embed_type == "projection": 644 | if projection_class_embeddings_input_dim is None: 645 | raise ValueError( 646 | "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set" 647 | ) 648 | # The projection `class_embed_type` is the same as the timestep `class_embed_type` except 649 | # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings 650 | # 2. it projects from an arbitrary input dimension. 651 | # 652 | # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations. 653 | # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings. 654 | # As a result, `TimestepEmbedding` can be passed arbitrary vectors. 
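            # For example (illustrative): with `class_embed_type="projection"` and
            # `projection_class_embeddings_input_dim=D`, a (B, D) conditioning vector passed as
            # `class_labels` is mapped to a (B, time_embed_dim) embedding, which `forward` then adds to
            # (or, with `class_embeddings_concat=True`, concatenates with) the timestep embedding.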
655 | self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim) 656 | elif class_embed_type == "simple_projection": 657 | if projection_class_embeddings_input_dim is None: 658 | raise ValueError( 659 | "`class_embed_type`: 'simple_projection' requires `projection_class_embeddings_input_dim` be set" 660 | ) 661 | self.class_embedding = nn.Linear(projection_class_embeddings_input_dim, time_embed_dim) 662 | else: 663 | self.class_embedding = None 664 | 665 | def _set_add_embedding( 666 | self, 667 | addition_embed_type: str, 668 | addition_embed_type_num_heads: int, 669 | addition_time_embed_dim: Optional[int], 670 | flip_sin_to_cos: bool, 671 | freq_shift: float, 672 | cross_attention_dim: Optional[int], 673 | encoder_hid_dim: Optional[int], 674 | projection_class_embeddings_input_dim: Optional[int], 675 | time_embed_dim: int, 676 | ): 677 | if addition_embed_type == "text": 678 | if encoder_hid_dim is not None: 679 | text_time_embedding_from_dim = encoder_hid_dim 680 | else: 681 | text_time_embedding_from_dim = cross_attention_dim 682 | 683 | self.add_embedding = TextTimeEmbedding( 684 | text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads 685 | ) 686 | elif addition_embed_type == "text_image": 687 | # text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much 688 | # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use 689 | # case when `addition_embed_type == "text_image"` (Kandinsky 2.1)` 690 | self.add_embedding = TextImageTimeEmbedding( 691 | text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim 692 | ) 693 | elif addition_embed_type == "text_time": 694 | self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift) 695 | self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim) 696 | elif addition_embed_type == "image": 697 | # Kandinsky 2.2 698 | self.add_embedding = ImageTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim) 699 | elif addition_embed_type == "image_hint": 700 | # Kandinsky 2.2 ControlNet 701 | self.add_embedding = ImageHintTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim) 702 | elif addition_embed_type is not None: 703 | raise ValueError( 704 | f"`addition_embed_type`: {addition_embed_type} must be None, 'text', 'text_image', 'text_time', 'image', or 'image_hint'." 705 | ) 706 | 707 | def _set_pos_net_if_use_gligen(self, attention_type: str, cross_attention_dim: int): 708 | if attention_type in ["gated", "gated-text-image"]: 709 | positive_len = 768 710 | if isinstance(cross_attention_dim, int): 711 | positive_len = cross_attention_dim 712 | elif isinstance(cross_attention_dim, (list, tuple)): 713 | positive_len = cross_attention_dim[0] 714 | 715 | feature_type = "text-only" if attention_type == "gated" else "text-image" 716 | self.position_net = GLIGENTextBoundingboxProjection( 717 | positive_len=positive_len, out_dim=cross_attention_dim, feature_type=feature_type 718 | ) 719 | 720 | @property 721 | def attn_processors(self) -> Dict[str, AttentionProcessor]: 722 | r""" 723 | Returns: 724 | `dict` of attention processors: A dictionary containing all attention processors used in the model with 725 | indexed by its weight name. 
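        Keys follow the module path of each attention layer; for this model they typically look
        like `"head_img.attentions.0.transformer_blocks.0.attn1.processor"` (illustrative path).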
726 | """ 727 | # set recursively 728 | processors = {} 729 | 730 | def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]): 731 | if hasattr(module, "get_processor"): 732 | processors[f"{name}.processor"] = module.get_processor() 733 | 734 | for sub_name, child in module.named_children(): 735 | fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) 736 | 737 | return processors 738 | 739 | for name, module in self.named_children(): 740 | fn_recursive_add_processors(name, module, processors) 741 | 742 | return processors 743 | 744 | def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): 745 | r""" 746 | Sets the attention processor to use to compute attention. 747 | 748 | Parameters: 749 | processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): 750 | The instantiated processor class or a dictionary of processor classes that will be set as the processor 751 | for **all** `Attention` layers. 752 | 753 | If `processor` is a dict, the key needs to define the path to the corresponding cross attention 754 | processor. This is strongly recommended when setting trainable attention processors. 755 | 756 | """ 757 | count = len(self.attn_processors.keys()) 758 | 759 | if isinstance(processor, dict) and len(processor) != count: 760 | raise ValueError( 761 | f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" 762 | f" number of attention layers: {count}. Please make sure to pass {count} processor classes." 763 | ) 764 | 765 | def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): 766 | if hasattr(module, "set_processor"): 767 | if not isinstance(processor, dict): 768 | module.set_processor(processor) 769 | else: 770 | module.set_processor(processor.pop(f"{name}.processor")) 771 | 772 | for sub_name, child in module.named_children(): 773 | fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) 774 | 775 | for name, module in self.named_children(): 776 | fn_recursive_attn_processor(name, module, processor) 777 | 778 | def set_default_attn_processor(self): 779 | """ 780 | Disables custom attention processors and sets the default attention implementation. 781 | """ 782 | if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()): 783 | processor = AttnAddedKVProcessor() 784 | elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()): 785 | processor = AttnProcessor() 786 | else: 787 | raise ValueError( 788 | f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}" 789 | ) 790 | 791 | self.set_attn_processor(processor) 792 | 793 | def set_attention_slice(self, slice_size: Union[str, int, List[int]] = "auto"): 794 | r""" 795 | Enable sliced attention computation. 796 | 797 | When this option is enabled, the attention module splits the input tensor in slices to compute attention in 798 | several steps. This is useful for saving some memory in exchange for a small decrease in speed. 799 | 800 | Args: 801 | slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`): 802 | When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If 803 | `"max"`, maximum amount of memory is saved by running only one slice at a time. 
If a number is 804 | provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim` 805 | must be a multiple of `slice_size`. 806 | """ 807 | sliceable_head_dims = [] 808 | 809 | def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module): 810 | if hasattr(module, "set_attention_slice"): 811 | sliceable_head_dims.append(module.sliceable_head_dim) 812 | 813 | for child in module.children(): 814 | fn_recursive_retrieve_sliceable_dims(child) 815 | 816 | # retrieve number of attention layers 817 | for module in self.children(): 818 | fn_recursive_retrieve_sliceable_dims(module) 819 | 820 | num_sliceable_layers = len(sliceable_head_dims) 821 | 822 | if slice_size == "auto": 823 | # half the attention head size is usually a good trade-off between 824 | # speed and memory 825 | slice_size = [dim // 2 for dim in sliceable_head_dims] 826 | elif slice_size == "max": 827 | # make smallest slice possible 828 | slice_size = num_sliceable_layers * [1] 829 | 830 | slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size 831 | 832 | if len(slice_size) != len(sliceable_head_dims): 833 | raise ValueError( 834 | f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different" 835 | f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}." 836 | ) 837 | 838 | for i in range(len(slice_size)): 839 | size = slice_size[i] 840 | dim = sliceable_head_dims[i] 841 | if size is not None and size > dim: 842 | raise ValueError(f"size {size} has to be smaller or equal to {dim}.") 843 | 844 | # Recursively walk through all the children. 845 | # Any children which exposes the set_attention_slice method 846 | # gets the message 847 | def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]): 848 | if hasattr(module, "set_attention_slice"): 849 | module.set_attention_slice(slice_size.pop()) 850 | 851 | for child in module.children(): 852 | fn_recursive_set_attention_slice(child, slice_size) 853 | 854 | reversed_slice_size = list(reversed(slice_size)) 855 | for module in self.children(): 856 | fn_recursive_set_attention_slice(module, reversed_slice_size) 857 | 858 | def _set_gradient_checkpointing(self, module, value=False): 859 | if hasattr(module, "gradient_checkpointing"): 860 | module.gradient_checkpointing = value 861 | 862 | def enable_freeu(self, s1: float, s2: float, b1: float, b2: float): 863 | r"""Enables the FreeU mechanism from https://arxiv.org/abs/2309.11497. 864 | 865 | The suffixes after the scaling factors represent the stage blocks where they are being applied. 866 | 867 | Please refer to the [official repository](https://github.com/ChenyangSi/FreeU) for combinations of values that 868 | are known to work well for different pipelines such as Stable Diffusion v1, v2, and Stable Diffusion XL. 869 | 870 | Args: 871 | s1 (`float`): 872 | Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to 873 | mitigate the "oversmoothing effect" in the enhanced denoising process. 874 | s2 (`float`): 875 | Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to 876 | mitigate the "oversmoothing effect" in the enhanced denoising process. 877 | b1 (`float`): Scaling factor for stage 1 to amplify the contributions of backbone features. 878 | b2 (`float`): Scaling factor for stage 2 to amplify the contributions of backbone features. 
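        Example (values reported in the FreeU repository for Stable Diffusion 1.x, given here only as a
        starting point): `model.enable_freeu(s1=0.9, s2=0.2, b1=1.2, b2=1.4)`.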
879 | """ 880 | for i, upsample_block in enumerate(self.up_blocks): 881 | setattr(upsample_block, "s1", s1) 882 | setattr(upsample_block, "s2", s2) 883 | setattr(upsample_block, "b1", b1) 884 | setattr(upsample_block, "b2", b2) 885 | 886 | def disable_freeu(self): 887 | """Disables the FreeU mechanism.""" 888 | freeu_keys = {"s1", "s2", "b1", "b2"} 889 | for i, upsample_block in enumerate(self.up_blocks): 890 | for k in freeu_keys: 891 | if hasattr(upsample_block, k) or getattr(upsample_block, k, None) is not None: 892 | setattr(upsample_block, k, None) 893 | 894 | def fuse_qkv_projections(self): 895 | """ 896 | Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value) 897 | are fused. For cross-attention modules, key and value projection matrices are fused. 898 | 899 | 900 | 901 | This API is 🧪 experimental. 902 | 903 | 904 | """ 905 | self.original_attn_processors = None 906 | 907 | for _, attn_processor in self.attn_processors.items(): 908 | if "Added" in str(attn_processor.__class__.__name__): 909 | raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.") 910 | 911 | self.original_attn_processors = self.attn_processors 912 | 913 | for module in self.modules(): 914 | if isinstance(module, Attention): 915 | module.fuse_projections(fuse=True) 916 | 917 | self.set_attn_processor(FusedAttnProcessor2_0()) 918 | 919 | def unfuse_qkv_projections(self): 920 | """Disables the fused QKV projection if enabled. 921 | 922 | 923 | 924 | This API is 🧪 experimental. 925 | 926 | 927 | 928 | """ 929 | if self.original_attn_processors is not None: 930 | self.set_attn_processor(self.original_attn_processors) 931 | 932 | def get_time_embed( 933 | self, sample: torch.Tensor, timestep: Union[torch.Tensor, float, int] 934 | ) -> Optional[torch.Tensor]: 935 | timesteps = timestep 936 | if not torch.is_tensor(timesteps): 937 | # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can 938 | # This would be a good case for the `match` statement (Python 3.10+) 939 | is_mps = sample.device.type == "mps" 940 | if isinstance(timestep, float): 941 | dtype = torch.float32 if is_mps else torch.float64 942 | else: 943 | dtype = torch.int32 if is_mps else torch.int64 944 | timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) 945 | elif len(timesteps.shape) == 0: 946 | timesteps = timesteps[None].to(sample.device) 947 | 948 | # broadcast to batch dimension in a way that's compatible with ONNX/Core ML 949 | timesteps = timesteps.expand(sample.shape[0]) 950 | 951 | t_emb = self.time_proj(timesteps) 952 | # `Timesteps` does not contain any weights and will always return f32 tensors 953 | # but time_embedding might actually be running in fp16. so we need to cast here. 954 | # there might be better ways to encapsulate this. 955 | t_emb = t_emb.to(dtype=sample.dtype) 956 | return t_emb 957 | 958 | def get_class_embed(self, sample: torch.Tensor, class_labels: Optional[torch.Tensor]) -> Optional[torch.Tensor]: 959 | class_emb = None 960 | if self.class_embedding is not None: 961 | if class_labels is None: 962 | raise ValueError("class_labels should be provided when num_class_embeds > 0") 963 | 964 | if self.config.class_embed_type == "timestep": 965 | class_labels = self.time_proj(class_labels) 966 | 967 | # `Timesteps` does not contain any weights and will always return f32 tensors 968 | # there might be better ways to encapsulate this. 
969 | class_labels = class_labels.to(dtype=sample.dtype) 970 | 971 | class_emb = self.class_embedding(class_labels).to(dtype=sample.dtype) 972 | return class_emb 973 | 974 | def get_aug_embed( 975 | self, emb: torch.Tensor, encoder_hidden_states: torch.Tensor, added_cond_kwargs: Dict[str, Any] 976 | ) -> Optional[torch.Tensor]: 977 | aug_emb = None 978 | if self.config.addition_embed_type == "text": 979 | aug_emb = self.add_embedding(encoder_hidden_states) 980 | elif self.config.addition_embed_type == "text_image": 981 | # Kandinsky 2.1 - style 982 | if "image_embeds" not in added_cond_kwargs: 983 | raise ValueError( 984 | f"{self.__class__} has the config param `addition_embed_type` set to 'text_image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`" 985 | ) 986 | 987 | image_embs = added_cond_kwargs.get("image_embeds") 988 | text_embs = added_cond_kwargs.get("text_embeds", encoder_hidden_states) 989 | aug_emb = self.add_embedding(text_embs, image_embs) 990 | elif self.config.addition_embed_type == "text_time": 991 | # SDXL - style 992 | if "text_embeds" not in added_cond_kwargs: 993 | raise ValueError( 994 | f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`" 995 | ) 996 | text_embeds = added_cond_kwargs.get("text_embeds") 997 | if "time_ids" not in added_cond_kwargs: 998 | raise ValueError( 999 | f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`" 1000 | ) 1001 | time_ids = added_cond_kwargs.get("time_ids") 1002 | time_embeds = self.add_time_proj(time_ids.flatten()) 1003 | time_embeds = time_embeds.reshape((text_embeds.shape[0], -1)) 1004 | add_embeds = torch.concat([text_embeds, time_embeds], dim=-1) 1005 | add_embeds = add_embeds.to(emb.dtype) 1006 | aug_emb = self.add_embedding(add_embeds) 1007 | elif self.config.addition_embed_type == "image": 1008 | # Kandinsky 2.2 - style 1009 | if "image_embeds" not in added_cond_kwargs: 1010 | raise ValueError( 1011 | f"{self.__class__} has the config param `addition_embed_type` set to 'image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`" 1012 | ) 1013 | image_embs = added_cond_kwargs.get("image_embeds") 1014 | aug_emb = self.add_embedding(image_embs) 1015 | elif self.config.addition_embed_type == "image_hint": 1016 | # Kandinsky 2.2 ControlNet - style 1017 | if "image_embeds" not in added_cond_kwargs or "hint" not in added_cond_kwargs: 1018 | raise ValueError( 1019 | f"{self.__class__} has the config param `addition_embed_type` set to 'image_hint' which requires the keyword arguments `image_embeds` and `hint` to be passed in `added_cond_kwargs`" 1020 | ) 1021 | image_embs = added_cond_kwargs.get("image_embeds") 1022 | hint = added_cond_kwargs.get("hint") 1023 | aug_emb = self.add_embedding(image_embs, hint) 1024 | return aug_emb 1025 | 1026 | def process_encoder_hidden_states( 1027 | self, encoder_hidden_states: torch.Tensor, added_cond_kwargs: Dict[str, Any] 1028 | ) -> torch.Tensor: 1029 | if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_proj": 1030 | encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states) 1031 | elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_image_proj": 1032 | # Kandinsky 2.1 - style 1033 | if "image_embeds" not in added_cond_kwargs: 
1034 | raise ValueError( 1035 | f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'text_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`" 1036 | ) 1037 | 1038 | image_embeds = added_cond_kwargs.get("image_embeds") 1039 | encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states, image_embeds) 1040 | elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "image_proj": 1041 | # Kandinsky 2.2 - style 1042 | if "image_embeds" not in added_cond_kwargs: 1043 | raise ValueError( 1044 | f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'image_proj' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`" 1045 | ) 1046 | image_embeds = added_cond_kwargs.get("image_embeds") 1047 | encoder_hidden_states = self.encoder_hid_proj(image_embeds) 1048 | elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "ip_image_proj": 1049 | if "image_embeds" not in added_cond_kwargs: 1050 | raise ValueError( 1051 | f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'ip_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`" 1052 | ) 1053 | 1054 | if hasattr(self, "text_encoder_hid_proj") and self.text_encoder_hid_proj is not None: 1055 | encoder_hidden_states = self.text_encoder_hid_proj(encoder_hidden_states) 1056 | 1057 | image_embeds = added_cond_kwargs.get("image_embeds") 1058 | image_embeds = self.encoder_hid_proj(image_embeds) 1059 | encoder_hidden_states = (encoder_hidden_states, image_embeds) 1060 | return encoder_hidden_states 1061 | 1062 | def forward( 1063 | self, 1064 | sample: torch.Tensor, 1065 | timestep: Union[torch.Tensor, float, int], 1066 | encoder_hidden_states: torch.Tensor, 1067 | class_labels: Optional[torch.Tensor] = None, 1068 | timestep_cond: Optional[torch.Tensor] = None, 1069 | attention_mask: Optional[torch.Tensor] = None, 1070 | cross_attention_kwargs: Optional[Dict[str, Any]] = None, 1071 | added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None, 1072 | down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None, 1073 | mid_block_additional_residual: Optional[torch.Tensor] = None, 1074 | down_intrablock_additional_residuals: Optional[Tuple[torch.Tensor]] = None, 1075 | encoder_attention_mask: Optional[torch.Tensor] = None, 1076 | return_dict: bool = True, 1077 | ) -> Union[UNet2DConditionOutput, Tuple]: 1078 | r""" 1079 | The [`UNet2DConditionModel`] forward method. 1080 | 1081 | Args: 1082 | sample (`torch.Tensor`): 1083 | The noisy input tensor with the following shape `(batch, channel, height, width)`. 1084 | timestep (`torch.Tensor` or `float` or `int`): The number of timesteps to denoise an input. 1085 | encoder_hidden_states (`torch.Tensor`): 1086 | The encoder hidden states with shape `(batch, sequence_length, feature_dim)`. 1087 | class_labels (`torch.Tensor`, *optional*, defaults to `None`): 1088 | Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings. 1089 | timestep_cond: (`torch.Tensor`, *optional*, defaults to `None`): 1090 | Conditional embeddings for timestep. If provided, the embeddings will be summed with the samples passed 1091 | through the `self.time_embedding` layer to obtain the timestep embeddings. 1092 | attention_mask (`torch.Tensor`, *optional*, defaults to `None`): 1093 | An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. 
If `1` the mask 1094 | is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large 1095 | negative values to the attention scores corresponding to "discard" tokens. 1096 | cross_attention_kwargs (`dict`, *optional*): 1097 | A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under 1098 | `self.processor` in 1099 | [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). 1100 | added_cond_kwargs: (`dict`, *optional*): 1101 | A kwargs dictionary containing additional embeddings that if specified are added to the embeddings that 1102 | are passed along to the UNet blocks. 1103 | down_block_additional_residuals: (`tuple` of `torch.Tensor`, *optional*): 1104 | A tuple of tensors that if specified are added to the residuals of down unet blocks. 1105 | mid_block_additional_residual: (`torch.Tensor`, *optional*): 1106 | A tensor that if specified is added to the residual of the middle unet block. 1107 | down_intrablock_additional_residuals (`tuple` of `torch.Tensor`, *optional*): 1108 | additional residuals to be added within UNet down blocks, for example from T2I-Adapter side model(s) 1109 | encoder_attention_mask (`torch.Tensor`): 1110 | A cross-attention mask of shape `(batch, sequence_length)` is applied to `encoder_hidden_states`. If 1111 | `True` the mask is kept, otherwise if `False` it is discarded. Mask will be converted into a bias, 1112 | which adds large negative values to the attention scores corresponding to "discard" tokens. 1113 | return_dict (`bool`, *optional*, defaults to `True`): 1114 | Whether or not to return a [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] instead of a plain 1115 | tuple. 1116 | 1117 | Returns: 1118 | [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] or `tuple`: 1119 | If `return_dict` is True, an [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] is returned, 1120 | otherwise a `tuple` is returned where the first element is the sample tensor. 1121 | """ 1122 | # By default samples have to be AT least a multiple of the overall upsampling factor. 1123 | # The overall upsampling factor is equal to 2 ** (# num of upsampling layers). 1124 | # However, the upsampling interpolation output size can be forced to fit any upsampling size 1125 | # on the fly if necessary. 1126 | default_overall_up_factor = 2**self.num_upsamplers 1127 | 1128 | # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor` 1129 | forward_upsample_size = False 1130 | upsample_size = None 1131 | 1132 | for dim in sample.shape[-2:]: 1133 | if dim % default_overall_up_factor != 0: 1134 | # Forward upsample size to force interpolation output size. 1135 | forward_upsample_size = True 1136 | break 1137 | 1138 | # ensure attention_mask is a bias, and give it a singleton query_tokens dimension 1139 | # expects mask of shape: 1140 | # [batch, key_tokens] 1141 | # adds singleton query_tokens dimension: 1142 | # [batch, 1, key_tokens] 1143 | # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes: 1144 | # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn) 1145 | # [batch * heads, query_tokens, key_tokens] (e.g. 
xformers or classic attn) 1146 | if attention_mask is not None: 1147 | # assume that mask is expressed as: 1148 | # (1 = keep, 0 = discard) 1149 | # convert mask into a bias that can be added to attention scores: 1150 | # (keep = +0, discard = -10000.0) 1151 | attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0 1152 | attention_mask = attention_mask.unsqueeze(1) 1153 | 1154 | # convert encoder_attention_mask to a bias the same way we do for attention_mask 1155 | if encoder_attention_mask is not None: 1156 | encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0 1157 | encoder_attention_mask = encoder_attention_mask.unsqueeze(1) 1158 | 1159 | # 0. center input if necessary 1160 | if self.config.center_input_sample: 1161 | sample = 2 * sample - 1.0 1162 | 1163 | # 1. time 1164 | t_emb = self.get_time_embed(sample=sample, timestep=timestep) 1165 | emb = self.time_embedding(t_emb, timestep_cond) 1166 | 1167 | class_emb = self.get_class_embed(sample=sample, class_labels=class_labels) 1168 | if class_emb is not None: 1169 | if self.config.class_embeddings_concat: 1170 | emb = torch.cat([emb, class_emb], dim=-1) 1171 | else: 1172 | emb = emb + class_emb 1173 | 1174 | aug_emb = self.get_aug_embed( 1175 | emb=emb, encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs 1176 | ) 1177 | if self.config.addition_embed_type == "image_hint": 1178 | aug_emb, hint = aug_emb 1179 | sample = torch.cat([sample, hint], dim=1) 1180 | 1181 | 1182 | emb = emb + aug_emb if aug_emb is not None else emb 1183 | 1184 | if self.time_embed_act is not None: 1185 | emb = self.time_embed_act(emb) 1186 | 1187 | encoder_hidden_states = self.process_encoder_hidden_states( 1188 | encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs 1189 | ) 1190 | 1191 | sample_img = sample[:, :4, :, :] 1192 | sample_dem = sample[:, 4:, :, :] 1193 | # 2. pre-process using the two different heads 1194 | sample_img = self.conv_in_img(sample_img) 1195 | sample_dem = self.conv_in_dem(sample_dem) 1196 | 1197 | # 2.5 GLIGEN position net 1198 | if cross_attention_kwargs is not None and cross_attention_kwargs.get("gligen", None) is not None: 1199 | cross_attention_kwargs = cross_attention_kwargs.copy() 1200 | gligen_args = cross_attention_kwargs.pop("gligen") 1201 | cross_attention_kwargs["gligen"] = {"objs": self.position_net(**gligen_args)} 1202 | 1203 | # 3. down 1204 | # we're popping the `scale` instead of getting it because otherwise `scale` will be propagated 1205 | # to the internal blocks and will raise deprecation warnings. this will be confusing for our users. 
1206 |         if cross_attention_kwargs is not None:
1207 |             cross_attention_kwargs = cross_attention_kwargs.copy()
1208 |             lora_scale = cross_attention_kwargs.pop("scale", 1.0)
1209 |         else:
1210 |             lora_scale = 1.0
1211 |
1212 |         if USE_PEFT_BACKEND:
1213 |             # weight the lora layers by setting `lora_scale` for each PEFT layer
1214 |             scale_lora_layers(self, lora_scale)
1215 |
1216 |         is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None
1217 |         # using new arg down_intrablock_additional_residuals for T2I-Adapters, to distinguish from controlnets
1218 |         is_adapter = down_intrablock_additional_residuals is not None
1219 |         if is_adapter:  # T2I-Adapter style intra-block residuals are not supported by the dual-head UNet
1220 |             raise NotImplementedError("additional_residuals")
1221 |
1222 |
1223 |         # go through the heads
1224 |         head_img_res_sample = (sample_img,)
1225 |         # RGB head
1226 |         if hasattr(self.head_img, "has_cross_attention") and self.head_img.has_cross_attention:
1227 |             # For t2i-adapter CrossAttnDownBlock2D
1228 |             additional_residuals = {}
1229 |             sample_img, res_samples_img = self.head_img(
1230 |                 hidden_states=sample_img,
1231 |                 temb=emb,
1232 |                 encoder_hidden_states=encoder_hidden_states,
1233 |                 attention_mask=attention_mask,
1234 |                 cross_attention_kwargs=cross_attention_kwargs,
1235 |                 encoder_attention_mask=encoder_attention_mask,
1236 |                 **additional_residuals,
1237 |             )
1238 |         else:
1239 |             sample_img, res_samples_img = self.head_img(hidden_states=sample_img, temb=emb)
1240 |         head_img_res_sample += res_samples_img[:2]
1241 |
1242 |
1243 |
1244 |         head_dem_res_sample = (sample_dem,)
1245 |         # DEM head
1246 |         if hasattr(self.head_dem, "has_cross_attention") and self.head_dem.has_cross_attention:
1247 |             # For t2i-adapter CrossAttnDownBlock2D
1248 |             additional_residuals = {}
1249 |
1250 |             sample_dem, res_samples_dem = self.head_dem(
1251 |                 hidden_states=sample_dem,
1252 |                 temb=emb,
1253 |                 encoder_hidden_states=encoder_hidden_states,
1254 |                 attention_mask=attention_mask,
1255 |                 cross_attention_kwargs=cross_attention_kwargs,
1256 |                 encoder_attention_mask=encoder_attention_mask,
1257 |                 **additional_residuals,
1258 |             )
1259 |         else:
1260 |             # sample_dem, res_samples_dem = self.head_dem(hidden_states=sample_dem, temb=emb)
1261 |             sample_dem, res_samples_dem = self.head_img(hidden_states=sample_dem, temb=emb)  # weights shared with the RGB head
1262 |
1263 |         head_dem_res_sample += res_samples_dem[:2]
1264 |
1265 |         # average the two heads and pass the result through the shared down blocks
1266 |         sample = (sample_img + sample_dem) / 2
1267 |         # the last residual from each head is its downsampled output; average it to seed the shared trunk
1268 |         res_samples_img_dem = (res_samples_img[2] + res_samples_dem[2]) / 2
1269 |         down_block_res_samples = (res_samples_img_dem,)
1270 |
1271 |
1272 |         for downsample_block in self.down_blocks:
1273 |             if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
1274 |                 # For t2i-adapter CrossAttnDownBlock2D
1275 |                 additional_residuals = {}
1276 |                 if is_adapter and len(down_intrablock_additional_residuals) > 0:
1277 |                     additional_residuals["additional_residuals"] = down_intrablock_additional_residuals.pop(0)
1278 |
1279 |                 sample, res_samples = downsample_block(
1280 |                     hidden_states=sample,
1281 |                     temb=emb,
1282 |                     encoder_hidden_states=encoder_hidden_states,
1283 |                     attention_mask=attention_mask,
1284 |                     cross_attention_kwargs=cross_attention_kwargs,
1285 |                     encoder_attention_mask=encoder_attention_mask,
1286 |                     **additional_residuals,
1287 |                 )
1288 |             else:
1289 |                 sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
1290 |                 if is_adapter and
len(down_intrablock_additional_residuals) > 0: 1291 | sample += down_intrablock_additional_residuals.pop(0) 1292 | 1293 | down_block_res_samples += res_samples 1294 | 1295 | if is_controlnet: 1296 | new_down_block_res_samples = () 1297 | 1298 | for down_block_res_sample, down_block_additional_residual in zip( 1299 | down_block_res_samples, down_block_additional_residuals 1300 | ): 1301 | down_block_res_sample = down_block_res_sample + down_block_additional_residual 1302 | new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,) 1303 | 1304 | down_block_res_samples = new_down_block_res_samples 1305 | 1306 | 1307 | # 4. mid 1308 | if self.mid_block is not None: 1309 | if hasattr(self.mid_block, "has_cross_attention") and self.mid_block.has_cross_attention: 1310 | sample = self.mid_block( 1311 | sample, 1312 | emb, 1313 | encoder_hidden_states=encoder_hidden_states, 1314 | attention_mask=attention_mask, 1315 | cross_attention_kwargs=cross_attention_kwargs, 1316 | encoder_attention_mask=encoder_attention_mask, 1317 | ) 1318 | else: 1319 | sample = self.mid_block(sample, emb) 1320 | 1321 | # To support T2I-Adapter-XL 1322 | if ( 1323 | is_adapter 1324 | and len(down_intrablock_additional_residuals) > 0 1325 | and sample.shape == down_intrablock_additional_residuals[0].shape 1326 | ): 1327 | sample += down_intrablock_additional_residuals.pop(0) 1328 | 1329 | if is_controlnet: 1330 | sample = sample + mid_block_additional_residual 1331 | 1332 | # 5. up 1333 | for i, upsample_block in enumerate(self.up_blocks): 1334 | 1335 | res_samples = down_block_res_samples[-len(upsample_block.resnets) :] 1336 | down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] 1337 | 1338 | # if we have not reached the final block and need to forward the 1339 | # upsample size, we do it here 1340 | if forward_upsample_size: 1341 | upsample_size = down_block_res_samples[-1].shape[2:] 1342 | if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention: 1343 | sample = upsample_block( 1344 | hidden_states=sample, 1345 | temb=emb, 1346 | res_hidden_states_tuple=res_samples, 1347 | encoder_hidden_states=encoder_hidden_states, 1348 | cross_attention_kwargs=cross_attention_kwargs, 1349 | upsample_size=upsample_size, 1350 | attention_mask=attention_mask, 1351 | encoder_attention_mask=encoder_attention_mask, 1352 | ) 1353 | 1354 | else: 1355 | sample = upsample_block( 1356 | hidden_states=sample, 1357 | temb=emb, 1358 | res_hidden_states_tuple=res_samples, 1359 | upsample_size=upsample_size) 1360 | 1361 | 1362 | # go through each head 1363 | 1364 | sample_img = sample 1365 | 1366 | if hasattr(self.head_out_img, "has_cross_attention") and self.head_out_img.has_cross_attention: 1367 | sample_img = self.head_out_img( 1368 | hidden_states=sample_img, 1369 | temb=emb, 1370 | res_hidden_states_tuple=head_img_res_sample, 1371 | encoder_hidden_states=encoder_hidden_states, 1372 | cross_attention_kwargs=cross_attention_kwargs, 1373 | upsample_size=upsample_size, 1374 | attention_mask=attention_mask, 1375 | encoder_attention_mask=encoder_attention_mask, 1376 | ) 1377 | else: 1378 | sample_img = self.head_out_img(sample_img, 1379 | hidden_states=sample, 1380 | temb=emb, 1381 | res_hidden_states_tuple=head_img_res_sample, 1382 | upsample_size=upsample_size, 1383 | ) 1384 | if self.conv_norm_out_img: 1385 | sample_img = self.conv_norm_out_img(sample_img) 1386 | sample_img = self.conv_act(sample_img) 1387 | sample_img = self.conv_out_img(sample_img) 1388 | 1389 | 
sample_dem = sample 1390 | 1391 | if hasattr(self.head_out_dem, "has_cross_attention") and self.head_out_dem.has_cross_attention: 1392 | sample_dem = self.head_out_dem( 1393 | hidden_states=sample_dem, 1394 | temb=emb, 1395 | res_hidden_states_tuple=head_dem_res_sample, 1396 | encoder_hidden_states=encoder_hidden_states, 1397 | cross_attention_kwargs=cross_attention_kwargs, 1398 | upsample_size=upsample_size, 1399 | attention_mask=attention_mask, 1400 | encoder_attention_mask=encoder_attention_mask, 1401 | ) 1402 | else: 1403 | sample_dem = self.head_out_dem(sample_dem, 1404 | hidden_states=sample, 1405 | temb=emb, 1406 | res_hidden_states_tuple=head_dem_res_sample, 1407 | upsample_size=upsample_size, 1408 | ) 1409 | 1410 | if self.conv_norm_out_dem: 1411 | sample_dem = self.conv_norm_out_dem(sample_dem) 1412 | sample_dem = self.conv_act(sample_dem) 1413 | sample_dem = self.conv_out_dem(sample_dem) 1414 | 1415 | sample = torch.cat([sample_img,sample_dem],dim=1) 1416 | 1417 | if USE_PEFT_BACKEND: 1418 | # remove `lora_scale` from each PEFT layer 1419 | unscale_lora_layers(self, lora_scale) 1420 | 1421 | if not return_dict: 1422 | return (sample,) 1423 | 1424 | return UNet2DConditionOutput(sample=sample) 1425 | 1426 | 1427 | 1428 | def load_weights_from_pretrained(pretrain_model,model_dem): 1429 | dem_state_dict = model_dem.state_dict() 1430 | for name, param in pretrain_model.named_parameters(): 1431 | block = name.split(".")[0] 1432 | if block == "conv_in": 1433 | new_name_img = name.replace("conv_in","conv_in_img") 1434 | dem_state_dict[new_name_img] = param 1435 | new_name_dem = name.replace("conv_in","conv_in_dem") 1436 | dem_state_dict[new_name_dem] = param 1437 | if block == "down_blocks": 1438 | block_num = int(name.split(".")[1]) 1439 | if block_num == 0: 1440 | new_name_img = name.replace("down_blocks.0","head_img") 1441 | dem_state_dict[new_name_img] = param 1442 | new_name_dem = name.replace("down_blocks.0","head_dem") 1443 | dem_state_dict[new_name_dem] = param 1444 | elif block_num > 0: 1445 | new_name = name.replace(f"down_blocks.{block_num}",f"down_blocks.{block_num-1}") 1446 | dem_state_dict[new_name] = param 1447 | if block == "mid_block": 1448 | dem_state_dict[name] = param 1449 | if block == "time_embedding": 1450 | dem_state_dict[name] = param 1451 | if block == "up_blocks": 1452 | block_num = int(name.split(".")[1]) 1453 | if block_num == 3: 1454 | new_name = name.replace("up_blocks.3","head_out_img") 1455 | dem_state_dict[new_name] = param 1456 | new_name = name.replace("up_blocks.3","head_out_dem") 1457 | dem_state_dict[new_name] = param 1458 | else: 1459 | dem_state_dict[name] = param 1460 | if block == "conv_out": 1461 | new_name = name.replace("conv_out","conv_out_img") 1462 | dem_state_dict[new_name] = param 1463 | new_name = name.replace("conv_out","conv_out_dem") 1464 | dem_state_dict[new_name] = param 1465 | if block == "conv_norm_out": 1466 | new_name = name.replace("conv_norm_out","conv_norm_out_img") 1467 | dem_state_dict[new_name] = param 1468 | new_name = name.replace("conv_norm_out","conv_norm_out_dem") 1469 | dem_state_dict[new_name] = param 1470 | 1471 | model_dem.load_state_dict(dem_state_dict) 1472 | 1473 | return model_dem 1474 | --------------------------------------------------------------------------------
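A note on the dual-head `forward` in models.py: the input latent is treated as two stacked modalities, with channels `[:4]` holding the RGB latent and channels `[4:]` the DEM latent. Each modality has its own `conv_in_*` and input head block; the two streams are averaged before the shared down/mid/up trunk, the last residual from each head is likewise averaged to seed the trunk's skip connections while the first two residuals stay per-modality for the matching output head, and the two output heads are concatenated again along the channel dimension. A minimal sketch of the tensor contract, where `unet` stands for an instance of the dual-head class defined in this file (its name and constructor arguments are not shown in this excerpt, so the call itself is left commented):

```python
import torch

batch, height, width = 2, 64, 64
sample = torch.randn(batch, 8, height, width)        # [:, :4] RGB latent, [:, 4:] DEM latent (assumed 4 + 4 split)
timestep = torch.tensor([10] * batch)                 # one diffusion timestep per sample
encoder_hidden_states = torch.randn(batch, 77, 768)   # CLIP text embeddings (SD 1.x sized, assumed)

# out = unet(sample, timestep, encoder_hidden_states).sample
# out has shape (batch, rgb_out_channels + dem_out_channels, height, width),
# because the forward pass ends with torch.cat([sample_img, sample_dem], dim=1).
```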
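Both mask arguments (`attention_mask` and `encoder_attention_mask`) are converted the same way inside `forward`: a 0/1 keep-mask becomes an additive bias of 0 for kept tokens and -10000 for discarded tokens, with a singleton query dimension added so it broadcasts over the attention scores. A standalone illustration of that conversion:

```python
import torch

attention_mask = torch.tensor([[1, 1, 1, 0]])                 # keep the first three key tokens, discard the last
bias = (1 - attention_mask.to(torch.float32)) * -10000.0      # keep -> 0.0, discard -> -10000.0
bias = bias.unsqueeze(1)                                      # (batch, key_tokens) -> (batch, 1, key_tokens)
print(bias.shape)  # torch.Size([1, 1, 4])
```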
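The size check at the start of `forward` only forwards an explicit `upsample_size` when the latent height or width is not a multiple of `2 ** self.num_upsamplers`. A quick sanity check, assuming three upsamplers as in a standard Stable Diffusion UNet (the actual value is set in the constructor, which is not part of this excerpt):

```python
num_upsamplers = 3                               # assumption: standard 4-level SD UNet
default_overall_up_factor = 2 ** num_upsamplers  # 8

for dim in (64, 100):                            # example latent heights/widths
    forward_upsample_size = dim % default_overall_up_factor != 0
    print(dim, forward_upsample_size)            # 64 False, 100 True
```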
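`load_weights_from_pretrained` initialises the dual-head model from a single-modality UNet by duplicating the boundary layers and shifting the trunk: `conv_in` is copied into both `conv_in_img` and `conv_in_dem`, `down_blocks.0` into `head_img` and `head_dem`, `down_blocks.N` (N > 0) moves to `down_blocks.N-1`, `up_blocks.3` becomes `head_out_img` and `head_out_dem`, and `conv_out`/`conv_norm_out` are duplicated into their `_img`/`_dem` variants; the mid block, time embedding, and remaining up blocks are copied unchanged. A hedged usage sketch, assuming an SD 1.x UNet as the weight source and using `DualHeadUNet` as a placeholder name for the class defined in this file:

```python
from diffusers import UNet2DConditionModel

# Source of pretrained weights: any Stable-Diffusion-style UNet checkpoint
# ("runwayml/stable-diffusion-v1-5" is only an example id).
pretrained_unet = UNet2DConditionModel.from_pretrained(
    "runwayml/stable-diffusion-v1-5", subfolder="unet"
)

# `DualHeadUNet` is a placeholder for the dual-head class defined in this file;
# its constructor arguments are not shown in this excerpt.
# model_dem = DualHeadUNet(...)
# model_dem = load_weights_from_pretrained(pretrained_unet, model_dem)
```

Note that the helper iterates over `named_parameters()` only, which should be sufficient for SD 1.x UNets since their GroupNorm/LayerNorm layers carry no running-statistics buffers.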