├── assets
│   └── mesa-header-nz.png
├── requirements.txt
├── LICENSE
├── README.md
├── pipeline_terrain.py
└── models.py
/assets/mesa-header-nz.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PaulBorneP/MESA/HEAD/assets/mesa-header-nz.png
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | diffusers==0.32.2
2 | transformers==4.51.1
3 | accelerate==1.5.2
4 | torch --index-url https://download.pytorch.org/whl/cu124
5 |
6 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | # ADOBE RESEARCH LICENSE
2 |
3 | This license agreement (the “License”) between Adobe Inc., having a place of business at 345 Park Avenue, San Jose, California 95110-2704 (“Adobe”), and you, the individual or entity exercising rights under this License (“you” or “your”), sets forth the terms for your use of certain research materials that are owned by Adobe (the “Licensed Materials”). By exercising rights under this License, you accept and agree to be bound by its terms. If you are exercising rights under this license on behalf of an entity, then “you” means you and such entity, and you (personally) represent and warrant that you (personally) have all necessary authority to bind that entity to the terms of this License.
4 |
5 | 1. **GRANT OF LICENSE.**
6 |
7 | 1.1. Adobe grants you a nonexclusive, worldwide, royalty-free, revocable, fully paid license to (A) reproduce, use, modify, and publicly display the Licensed Materials for noncommercial research purposes only; and (B) redistribute the Licensed Materials, and modifications or derivative works thereof, for noncommercial research purposes only, provided that you give recipients a copy of this License.
8 |
9 | 1.2. You may add your own copyright statement to your modifications and may provide additional or different license terms for use, reproduction, modification, public display, and redistribution of your modifications and derivative works, provided that such license terms limit the use, reproduction, modification, public display, and redistribution of such modifications and derivative works to noncommercial research purposes only.
10 |
11 | 1.3. For purposes of this License, noncommercial research purposes include academic research and teaching but do not include commercial licensing or distribution, development of commercial products, or any other activity which results in commercial gain.
12 |
13 |
14 | 2. **OWNERSHIP AND ATTRIBUTION.** Adobe and its licensors own all right, title, and interest in the Licensed Materials. You must keep intact any copyright or other notices or disclaimers in the Licensed Materials.
15 |
16 | 3. **DISCLAIMER OF WARRANTIES.** THE LICENSED MATERIALS ARE PROVIDED “AS IS” WITHOUT WARRANTY OF ANY KIND. THE ENTIRE RISK AS TO THE RESULTS AND PERFORMANCE OF THE LICENSED MATERIALS IS ASSUMED BY YOU. ADOBE DISCLAIMS ALL WARRANTIES, EXPRESS, IMPLIED OR STATUTORY, WITH REGARD TO ANY LICENSED MATERIALS PROVIDED UNDER THIS LICENSE, INCLUDING, BUT NOT LIMITED TO, ANY IMPLIED WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, AND NONINFRINGEMENT OF THIRD-PARTY RIGHTS.
17 |
18 | 4. **LIMITATION OF LIABILITY.** IN NO EVENT WILL ADOBE BE LIABLE FOR ANY ACTUAL, INCIDENTAL, SPECIAL OR CONSEQUENTIAL DAMAGES OF ANY NATURE WHATSOEVER, INCLUDING WITHOUT LIMITATION, LOSS OF PROFITS OR OTHER COMMERCIAL LOSS, ARISING OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF ANY LICENSED MATERIALS PROVIDED UNDER THIS LICENSE, EVEN IF ADOBE HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGES.
19 |
20 | 5. **TERM AND TERMINATION.**
21 |
22 | 5.1. The License is effective upon acceptance by you and will remain in effect unless terminated earlier as permitted under this License.
23 |
24 | 5.2. If you breach any material provision of this License, then your rights will terminate immediately.
25 |
26 | 5.3. All clauses which by their nature should survive the termination of this License will survive such termination. In addition, and without limiting the generality of the preceding sentence, Sections 2 (Ownership and Attribution), 3 (Disclaimer of Warranties), and 4 (Limitation of Liability) will survive termination of this License.
27 |
28 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | # MESA: Text-Driven Terrain Generation Using Latent Diffusion and Global Copernicus Data
4 | Paul Borne--Pons, Mikolaj Czerkawski, Rosalie Martin,
5 | Romain Rouffet
6 |
7 | CVPR 2025 Workshop MORSE
8 |
9 | [Project Page](https://paulbornep.github.io/mesa-terrain/)
10 | [arXiv](https://arxiv.org/abs/2504.07210)
11 | [Model Weights](https://www.huggingface.co/NewtNewt/MESA)
12 | [Major TOM](https://www.huggingface.co/Major-TOM)
13 | [Gradio Demo](https://huggingface.co/spaces/mikonvergence/MESA)
14 |
15 |
16 |
17 | MESA is a novel generative model based on latent denoising diffusion, capable of generating 2.5D representations of terrain from text prompts supplied in natural language. The model produces two co-registered modalities: an optical image and a depth map.
18 |
19 | 
20 |
21 | ## Abstract
22 |
23 | Terrain modeling has traditionally relied on procedural techniques, which often require extensive domain expertise and handcrafted rules. In this paper, we present MESA - a novel data-centric alternative by training a diffusion model on global remote sensing data. This approach leverages large-scale geospatial information to generate high-quality terrain samples from text descriptions, showcasing a flexible and scalable solution for terrain generation. The model’s capabilities are demonstrated through extensive experiments, highlighting its ability to generate realistic and diverse terrain landscapes. The dataset produced to support this work, the Major TOM Core-DEM extension dataset, is released openly as a comprehensive resource for global terrain data. The results suggest that data-driven models, trained on remote sensing data, can provide a powerful tool for realistic terrain modeling and generation.
24 |
25 | ## Model Weights
26 |
27 | You can manually acquire the weights by downloading the model from Hugging Face:
28 |
29 | ```bash
30 | mkdir weights
31 | huggingface-cli download NewtNewt/MESA --local-dir weights
32 | ```
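
The same weights can also be fetched from Python via `huggingface_hub`, which is installed as a dependency of `diffusers`. A minimal sketch (the target directory mirrors the command above):

```python
from huggingface_hub import snapshot_download

# Download the full NewtNewt/MESA model repository into ./weights
snapshot_download(repo_id="NewtNewt/MESA", local_dir="weights")
```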
33 |
34 | ## Installation
35 |
36 | ```bash
37 | # using python 3.11.12
38 | pip install -r requirements.txt
39 | ```
40 |
41 | Note that this environment is only compatible with NVIDIA GPUs. Additionally, we recommend using a GPU with a minimum of 8GB of memory.
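
As a quick, optional sanity check of the environment (a minimal sketch, assuming the CUDA build of PyTorch pinned in `requirements.txt` is installed):

```python
import torch

# MESA inference requires an NVIDIA GPU; ~8 GB of VRAM is recommended.
assert torch.cuda.is_available(), "No CUDA-capable GPU detected."
props = torch.cuda.get_device_properties(0)
print(f"{props.name}: {props.total_memory / 1024**3:.1f} GB VRAM")
```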
42 |
43 | ## Inference
44 |
45 | ```python
46 | from MESA.pipeline_terrain import TerrainDiffusionPipeline
47 | import torch
48 |
49 | pipe = TerrainDiffusionPipeline.from_pretrained("./weights", torch_dtype=torch.float16)
50 | pipe.to("cuda")
51 |
52 | prompt = "A sentinel-2 image of montane forests and mountains in Mexico in August"
53 | image, dem = pipe(prompt, num_inference_steps=50, guidance_scale=7.5)
54 | ```
55 |
56 | Straightforward code for inference is provided in the snippet above.
57 |
58 | Alternatively, you can download and use the Gradio demo from the HF page.
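
With the default `output_type="np"`, the pipeline returns both modalities as NumPy arrays in `[0, 1]` with shape `(batch, H, W, 3)`. A minimal sketch for saving the first sample of each modality to disk (file names are illustrative):

```python
import numpy as np
from PIL import Image

# `image` and `dem` come from the pipeline call above; index 0 selects the first sample.
rgb = (image[0] * 255).round().astype(np.uint8)
elevation = (dem[0] * 255).round().astype(np.uint8)

Image.fromarray(rgb).save("terrain_rgb.png")
Image.fromarray(elevation).save("terrain_dem.png")
```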
59 |
60 | ## Citation
61 |
62 | ```latex
63 | @inproceedings{mesa2025,
64 | title={MESA: Text-Driven Terrain Generation Using Latent Diffusion and Global Copernicus Data},
65 | author={Paul Borne--Pons and Mikolaj Czerkawski and Rosalie Martin and Romain Rouffet},
66 | year={2025},
67 | booktitle={MORSE Workshop at CVPR 2025},
68 | eprint={2504.07210},
69 | url={https://arxiv.org/abs/2504.07210},}
70 | ```
71 | ## Acknowledgements
72 |
73 | This implementation builds upon Hugging Face’s [Diffusers](https://github.com/huggingface/diffusers) library. We also acknowledge [Gradio](https://www.gradio.app/) for providing an easy-to-use interface that allowed us to create the inference demos for our models.
74 |
75 | This model is the product of a collaboration between [Φ-lab, European Space Agency (ESA)](https://philab.esa.int/) and [Adobe Research (Paris, France)](https://research.adobe.com/careers/paris/).
76 |
--------------------------------------------------------------------------------
/pipeline_terrain.py:
--------------------------------------------------------------------------------
1 | ###########################################################################
2 | # References:
3 | # https://github.com/huggingface/diffusers/
4 | ###########################################################################
5 | import inspect
6 | from typing import Any, Callable, Dict, List, Optional, Union
7 |
8 | import torch
9 | from packaging import version
10 | from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
11 |
12 | from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
13 | from diffusers.configuration_utils import FrozenDict
14 | from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
15 | from diffusers.loaders import FromSingleFileMixin, IPAdapterMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin
16 | from diffusers.models import AutoencoderKL, ImageProjection, UNet2DConditionModel
17 | from diffusers.models.lora import adjust_lora_scale_text_encoder
18 | from diffusers.schedulers import KarrasDiffusionSchedulers
19 | from diffusers.utils import (
20 | USE_PEFT_BACKEND,
21 | deprecate,
22 | is_torch_xla_available,
23 | logging,
24 | replace_example_docstring,
25 | scale_lora_layers,
26 | unscale_lora_layers,
27 | )
28 | from diffusers.utils.torch_utils import randn_tensor
29 | from diffusers.pipelines.pipeline_utils import DiffusionPipeline, StableDiffusionMixin
30 | from diffusers.pipelines.stable_diffusion.pipeline_output import StableDiffusionPipelineOutput
31 | from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
32 |
33 |
34 | if is_torch_xla_available():
35 | import torch_xla.core.xla_model as xm
36 |
37 | XLA_AVAILABLE = True
38 | else:
39 | XLA_AVAILABLE = False
40 |
41 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name
42 |
43 | EXAMPLE_DOC_STRING = """
44 | Examples:
45 | ```py
46 | >>> import torch
47 | >>> from MESA.pipeline_terrain import TerrainDiffusionPipeline
48 |
49 | >>> pipe = TerrainDiffusionPipeline.from_pretrained("./weights", torch_dtype=torch.float16)
50 | >>> pipe = pipe.to("cuda")
51 |
52 | >>> prompt = "A sentinel-2 image of montane forests and mountains in Mexico in August"
53 | >>> image, dem = pipe(prompt, num_inference_steps=50, guidance_scale=7.5)
54 | ```
55 | """
56 |
57 |
58 | def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
59 | """
60 | Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
61 | Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
62 | """
63 | std_text = noise_pred_text.std(
64 | dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
65 | std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
66 | # rescale the results from guidance (fixes overexposure)
67 | noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
68 | # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
69 | noise_cfg = guidance_rescale * noise_pred_rescaled + \
70 | (1 - guidance_rescale) * noise_cfg
71 | return noise_cfg
72 |
73 |
74 | def retrieve_timesteps(
75 | scheduler,
76 | num_inference_steps: Optional[int] = None,
77 | device: Optional[Union[str, torch.device]] = None,
78 | timesteps: Optional[List[int]] = None,
79 | sigmas: Optional[List[float]] = None,
80 | **kwargs,
81 | ):
82 | """
83 | Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
84 | custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
85 |
86 | Args:
87 | scheduler (`SchedulerMixin`):
88 | The scheduler to get timesteps from.
89 | num_inference_steps (`int`):
90 | The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
91 | must be `None`.
92 | device (`str` or `torch.device`, *optional*):
93 | The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
94 | timesteps (`List[int]`, *optional*):
95 | Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
96 | `num_inference_steps` and `sigmas` must be `None`.
97 | sigmas (`List[float]`, *optional*):
98 | Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
99 | `num_inference_steps` and `timesteps` must be `None`.
100 |
101 | Returns:
102 | `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
103 | second element is the number of inference steps.
104 | """
105 | if timesteps is not None and sigmas is not None:
106 | raise ValueError(
107 | "Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
108 | if timesteps is not None:
109 | accepts_timesteps = "timesteps" in set(
110 | inspect.signature(scheduler.set_timesteps).parameters.keys())
111 | if not accepts_timesteps:
112 | raise ValueError(
113 | f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
114 | f" timestep schedules. Please check whether you are using the correct scheduler."
115 | )
116 | scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
117 | timesteps = scheduler.timesteps
118 | num_inference_steps = len(timesteps)
119 | elif sigmas is not None:
120 | accept_sigmas = "sigmas" in set(inspect.signature(
121 | scheduler.set_timesteps).parameters.keys())
122 | if not accept_sigmas:
123 | raise ValueError(
124 | f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
125 | f" sigmas schedules. Please check whether you are using the correct scheduler."
126 | )
127 | scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
128 | timesteps = scheduler.timesteps
129 | num_inference_steps = len(timesteps)
130 | else:
131 | scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
132 | timesteps = scheduler.timesteps
133 | return timesteps, num_inference_steps
134 |
135 |
136 | class TerrainDiffusionPipeline(
137 | DiffusionPipeline,
138 | StableDiffusionMixin,
139 | TextualInversionLoaderMixin,
140 | StableDiffusionLoraLoaderMixin,
141 | IPAdapterMixin,
142 | FromSingleFileMixin,
143 | ):
144 | r"""
145 | Pipeline for text-driven terrain generation, producing a co-registered optical image and depth (DEM) pair, built on Stable Diffusion.
146 |
147 | This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
148 | implemented for all pipelines (downloading, saving, running on a particular device, etc.).
149 |
150 | The pipeline also inherits the following loading methods:
151 | - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
152 | - [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for loading LoRA weights
153 | - [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for saving LoRA weights
154 | - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files
155 | - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters
156 |
157 | Args:
158 | vae ([`AutoencoderKL`]):
159 | Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
160 | text_encoder ([`~transformers.CLIPTextModel`]):
161 | Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)).
162 | tokenizer ([`~transformers.CLIPTokenizer`]):
163 | A `CLIPTokenizer` to tokenize text.
164 | unet ([`UNet2DConditionModel`]):
165 | A `UNet2DConditionModel` to denoise the encoded image latents.
166 | scheduler ([`SchedulerMixin`]):
167 | A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
168 | [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
169 | safety_checker ([`StableDiffusionSafetyChecker`]):
170 | Classification module that estimates whether generated images could be considered offensive or harmful.
171 | Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details
172 | about a model's potential harms.
173 | feature_extractor ([`~transformers.CLIPImageProcessor`]):
174 | A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
175 | """
176 |
177 | model_cpu_offload_seq = "text_encoder->image_encoder->unet->vae"
178 | _optional_components = ["safety_checker",
179 | "feature_extractor", "image_encoder"]
180 | _exclude_from_cpu_offload = ["safety_checker"]
181 | _callback_tensor_inputs = ["latents",
182 | "prompt_embeds", "negative_prompt_embeds"]
183 |
184 | def __init__(
185 | self,
186 | vae: AutoencoderKL,
187 | text_encoder: CLIPTextModel,
188 | tokenizer: CLIPTokenizer,
189 | unet: UNet2DConditionModel,
190 | scheduler: KarrasDiffusionSchedulers,
191 | safety_checker: StableDiffusionSafetyChecker,
192 | feature_extractor: CLIPImageProcessor,
193 | image_encoder: CLIPVisionModelWithProjection = None,
194 | requires_safety_checker: bool = True,
195 | ):
196 | super().__init__()
197 |
198 | if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
199 | deprecation_message = (
200 | f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
201 | f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
202 | "to update the config accordingly as leaving `steps_offset` might lead to incorrect results"
203 | " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
204 | " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
205 | " file"
206 | )
207 | deprecate("steps_offset!=1", "1.0.0",
208 | deprecation_message, standard_warn=False)
209 | new_config = dict(scheduler.config)
210 | new_config["steps_offset"] = 1
211 | scheduler._internal_dict = FrozenDict(new_config)
212 |
213 | if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
214 | deprecation_message = (
215 | f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
216 | " `clip_sample` should be set to False in the configuration file. Please make sure to update the"
217 | " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
218 | " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
219 | " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
220 | )
221 | deprecate("clip_sample not set", "1.0.0",
222 | deprecation_message, standard_warn=False)
223 | new_config = dict(scheduler.config)
224 | new_config["clip_sample"] = False
225 | scheduler._internal_dict = FrozenDict(new_config)
226 |
227 | if safety_checker is None and requires_safety_checker:
228 | logger.warning(
229 | f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
230 | " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
231 | " results in services or applications open to the public. Both the diffusers team and Hugging Face"
232 | " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
233 | " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
234 | " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
235 | )
236 |
237 | if safety_checker is not None and feature_extractor is None:
238 | raise ValueError(
239 | f"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
240 | " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
241 | )
242 |
243 | is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
244 | version.parse(unet.config._diffusers_version).base_version
245 | ) < version.parse("0.9.0.dev0")
246 | is_unet_sample_size_less_64 = hasattr(
247 | unet.config, "sample_size") and unet.config.sample_size < 64
248 | if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
249 | deprecation_message = (
250 | "The configuration file of the unet has set the default `sample_size` to smaller than"
251 | " 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the"
252 | " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
253 | " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
254 | " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
255 | " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
256 | " in the config might lead to incorrect results in future versions. If you have downloaded this"
257 | " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
258 | " the `unet/config.json` file"
259 | )
260 | deprecate("sample_size<64", "1.0.0",
261 | deprecation_message, standard_warn=False)
262 | new_config = dict(unet.config)
263 | new_config["sample_size"] = 64
264 | unet._internal_dict = FrozenDict(new_config)
265 |
266 | self.register_modules(
267 | vae=vae,
268 | text_encoder=text_encoder,
269 | tokenizer=tokenizer,
270 | unet=unet,
271 | scheduler=scheduler,
272 | safety_checker=safety_checker,
273 | feature_extractor=feature_extractor,
274 | image_encoder=image_encoder,
275 | )
276 | self.vae_scale_factor = 2 ** (
277 | len(self.vae.config.block_out_channels) - 1)
278 | self.image_processor = VaeImageProcessor(
279 | vae_scale_factor=self.vae_scale_factor)
280 | self.register_to_config(
281 | requires_safety_checker=requires_safety_checker)
282 |
283 | def _encode_prompt(
284 | self,
285 | prompt,
286 | device,
287 | num_images_per_prompt,
288 | do_classifier_free_guidance,
289 | negative_prompt=None,
290 | prompt_embeds: Optional[torch.Tensor] = None,
291 | negative_prompt_embeds: Optional[torch.Tensor] = None,
292 | lora_scale: Optional[float] = None,
293 | **kwargs,
294 | ):
295 | deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple."
296 | deprecate("_encode_prompt()", "1.0.0",
297 | deprecation_message, standard_warn=False)
298 |
299 | prompt_embeds_tuple = self.encode_prompt(
300 | prompt=prompt,
301 | device=device,
302 | num_images_per_prompt=num_images_per_prompt,
303 | do_classifier_free_guidance=do_classifier_free_guidance,
304 | negative_prompt=negative_prompt,
305 | prompt_embeds=prompt_embeds,
306 | negative_prompt_embeds=negative_prompt_embeds,
307 | lora_scale=lora_scale,
308 | **kwargs,
309 | )
310 |
311 | # concatenate for backwards comp
312 | prompt_embeds = torch.cat(
313 | [prompt_embeds_tuple[1], prompt_embeds_tuple[0]])
314 |
315 | return prompt_embeds
316 |
317 | def encode_prompt(
318 | self,
319 | prompt,
320 | device,
321 | num_images_per_prompt,
322 | do_classifier_free_guidance,
323 | negative_prompt=None,
324 | prompt_embeds: Optional[torch.Tensor] = None,
325 | negative_prompt_embeds: Optional[torch.Tensor] = None,
326 | lora_scale: Optional[float] = None,
327 | clip_skip: Optional[int] = None,
328 | ):
329 | r"""
330 | Encodes the prompt into text encoder hidden states.
331 |
332 | Args:
333 | prompt (`str` or `List[str]`, *optional*):
334 | prompt to be encoded
335 | device: (`torch.device`):
336 | torch device
337 | num_images_per_prompt (`int`):
338 | number of images that should be generated per prompt
339 | do_classifier_free_guidance (`bool`):
340 | whether to use classifier free guidance or not
341 | negative_prompt (`str` or `List[str]`, *optional*):
342 | The prompt or prompts not to guide the image generation. If not defined, one has to pass
343 | `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
344 | less than `1`).
345 | prompt_embeds (`torch.Tensor`, *optional*):
346 | Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
347 | provided, text embeddings will be generated from `prompt` input argument.
348 | negative_prompt_embeds (`torch.Tensor`, *optional*):
349 | Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
350 | weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
351 | argument.
352 | lora_scale (`float`, *optional*):
353 | A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
354 | clip_skip (`int`, *optional*):
355 | Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
356 | the output of the pre-final layer will be used for computing the prompt embeddings.
357 | """
358 | # set lora scale so that monkey patched LoRA
359 | # function of text encoder can correctly access it
360 | if lora_scale is not None and isinstance(self, StableDiffusionLoraLoaderMixin):
361 | self._lora_scale = lora_scale
362 |
363 | # dynamically adjust the LoRA scale
364 | if not USE_PEFT_BACKEND:
365 | adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
366 | else:
367 | scale_lora_layers(self.text_encoder, lora_scale)
368 |
369 | if prompt is not None and isinstance(prompt, str):
370 | batch_size = 1
371 | elif prompt is not None and isinstance(prompt, list):
372 | batch_size = len(prompt)
373 | else:
374 | batch_size = prompt_embeds.shape[0]
375 |
376 | if prompt_embeds is None:
377 | # textual inversion: process multi-vector tokens if necessary
378 | if isinstance(self, TextualInversionLoaderMixin):
379 | prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
380 |
381 | text_inputs = self.tokenizer(
382 | prompt,
383 | padding="max_length",
384 | max_length=self.tokenizer.model_max_length,
385 | truncation=True,
386 | return_tensors="pt",
387 | )
388 | text_input_ids = text_inputs.input_ids
389 | untruncated_ids = self.tokenizer(
390 | prompt, padding="longest", return_tensors="pt").input_ids
391 |
392 | if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
393 | text_input_ids, untruncated_ids
394 | ):
395 | removed_text = self.tokenizer.batch_decode(
396 | untruncated_ids[:, self.tokenizer.model_max_length - 1: -1]
397 | )
398 | logger.warning(
399 | "The following part of your input was truncated because CLIP can only handle sequences up to"
400 | f" {self.tokenizer.model_max_length} tokens: {removed_text}"
401 | )
402 |
403 | if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
404 | attention_mask = text_inputs.attention_mask.to(device)
405 | else:
406 | attention_mask = None
407 |
408 | if clip_skip is None:
409 | prompt_embeds = self.text_encoder(
410 | text_input_ids.to(device), attention_mask=attention_mask)
411 | prompt_embeds = prompt_embeds[0]
412 | else:
413 | prompt_embeds = self.text_encoder(
414 | text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True
415 | )
416 | # Access the `hidden_states` first, that contains a tuple of
417 | # all the hidden states from the encoder layers. Then index into
418 | # the tuple to access the hidden states from the desired layer.
419 | prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)]
420 | # We also need to apply the final LayerNorm here to not mess with the
421 | # representations. The `last_hidden_states` that we typically use for
422 | # obtaining the final prompt representations passes through the LayerNorm
423 | # layer.
424 | prompt_embeds = self.text_encoder.text_model.final_layer_norm(
425 | prompt_embeds)
426 |
427 | if self.text_encoder is not None:
428 | prompt_embeds_dtype = self.text_encoder.dtype
429 | elif self.unet is not None:
430 | prompt_embeds_dtype = self.unet.dtype
431 | else:
432 | prompt_embeds_dtype = prompt_embeds.dtype
433 |
434 | prompt_embeds = prompt_embeds.to(
435 | dtype=prompt_embeds_dtype, device=device)
436 |
437 | bs_embed, seq_len, _ = prompt_embeds.shape
438 | # duplicate text embeddings for each generation per prompt, using mps friendly method
439 | prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
440 | prompt_embeds = prompt_embeds.view(
441 | bs_embed * num_images_per_prompt, seq_len, -1)
442 |
443 | # get unconditional embeddings for classifier free guidance
444 | if do_classifier_free_guidance and negative_prompt_embeds is None:
445 | uncond_tokens: List[str]
446 | if negative_prompt is None:
447 | uncond_tokens = [""] * batch_size
448 | elif prompt is not None and type(prompt) is not type(negative_prompt):
449 | raise TypeError(
450 | f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
451 | f" {type(prompt)}."
452 | )
453 | elif isinstance(negative_prompt, str):
454 | uncond_tokens = [negative_prompt]
455 | elif batch_size != len(negative_prompt):
456 | raise ValueError(
457 | f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
458 | f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
459 | " the batch size of `prompt`."
460 | )
461 | else:
462 | uncond_tokens = negative_prompt
463 |
464 | # textual inversion: process multi-vector tokens if necessary
465 | if isinstance(self, TextualInversionLoaderMixin):
466 | uncond_tokens = self.maybe_convert_prompt(
467 | uncond_tokens, self.tokenizer)
468 |
469 | max_length = prompt_embeds.shape[1]
470 | uncond_input = self.tokenizer(
471 | uncond_tokens,
472 | padding="max_length",
473 | max_length=max_length,
474 | truncation=True,
475 | return_tensors="pt",
476 | )
477 |
478 | if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
479 | attention_mask = uncond_input.attention_mask.to(device)
480 | else:
481 | attention_mask = None
482 |
483 | negative_prompt_embeds = self.text_encoder(
484 | uncond_input.input_ids.to(device),
485 | attention_mask=attention_mask,
486 | )
487 | negative_prompt_embeds = negative_prompt_embeds[0]
488 |
489 | if do_classifier_free_guidance:
490 | # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
491 | seq_len = negative_prompt_embeds.shape[1]
492 |
493 | negative_prompt_embeds = negative_prompt_embeds.to(
494 | dtype=prompt_embeds_dtype, device=device)
495 |
496 | negative_prompt_embeds = negative_prompt_embeds.repeat(
497 | 1, num_images_per_prompt, 1)
498 | negative_prompt_embeds = negative_prompt_embeds.view(
499 | batch_size * num_images_per_prompt, seq_len, -1)
500 |
501 | if self.text_encoder is not None:
502 | if isinstance(self, StableDiffusionLoraLoaderMixin) and USE_PEFT_BACKEND:
503 | # Retrieve the original scale by scaling back the LoRA layers
504 | unscale_lora_layers(self.text_encoder, lora_scale)
505 |
506 | return prompt_embeds, negative_prompt_embeds
507 |
508 | def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):
509 | dtype = next(self.image_encoder.parameters()).dtype
510 |
511 | if not isinstance(image, torch.Tensor):
512 | image = self.feature_extractor(
513 | image, return_tensors="pt").pixel_values
514 |
515 | image = image.to(device=device, dtype=dtype)
516 | if output_hidden_states:
517 | image_enc_hidden_states = self.image_encoder(
518 | image, output_hidden_states=True).hidden_states[-2]
519 | image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(
520 | num_images_per_prompt, dim=0)
521 | uncond_image_enc_hidden_states = self.image_encoder(
522 | torch.zeros_like(image), output_hidden_states=True
523 | ).hidden_states[-2]
524 | uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave(
525 | num_images_per_prompt, dim=0
526 | )
527 | return image_enc_hidden_states, uncond_image_enc_hidden_states
528 | else:
529 | image_embeds = self.image_encoder(image).image_embeds
530 | image_embeds = image_embeds.repeat_interleave(
531 | num_images_per_prompt, dim=0)
532 | uncond_image_embeds = torch.zeros_like(image_embeds)
533 |
534 | return image_embeds, uncond_image_embeds
535 |
536 | def prepare_ip_adapter_image_embeds(
537 | self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance
538 | ):
539 | image_embeds = []
540 | if do_classifier_free_guidance:
541 | negative_image_embeds = []
542 | if ip_adapter_image_embeds is None:
543 | if not isinstance(ip_adapter_image, list):
544 | ip_adapter_image = [ip_adapter_image]
545 |
546 | if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers):
547 | raise ValueError(
548 | f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
549 | )
550 |
551 | for single_ip_adapter_image, image_proj_layer in zip(
552 | ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers
553 | ):
554 | output_hidden_state = not isinstance(
555 | image_proj_layer, ImageProjection)
556 | single_image_embeds, single_negative_image_embeds = self.encode_image(
557 | single_ip_adapter_image, device, 1, output_hidden_state
558 | )
559 |
560 | image_embeds.append(single_image_embeds[None, :])
561 | if do_classifier_free_guidance:
562 | negative_image_embeds.append(
563 | single_negative_image_embeds[None, :])
564 | else:
565 | for single_image_embeds in ip_adapter_image_embeds:
566 | if do_classifier_free_guidance:
567 | single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(
568 | 2)
569 | negative_image_embeds.append(single_negative_image_embeds)
570 | image_embeds.append(single_image_embeds)
571 |
572 | ip_adapter_image_embeds = []
573 | for i, single_image_embeds in enumerate(image_embeds):
574 | single_image_embeds = torch.cat(
575 | [single_image_embeds] * num_images_per_prompt, dim=0)
576 | if do_classifier_free_guidance:
577 | single_negative_image_embeds = torch.cat(
578 | [negative_image_embeds[i]] * num_images_per_prompt, dim=0)
579 | single_image_embeds = torch.cat(
580 | [single_negative_image_embeds, single_image_embeds], dim=0)
581 |
582 | single_image_embeds = single_image_embeds.to(device=device)
583 | ip_adapter_image_embeds.append(single_image_embeds)
584 |
585 | return ip_adapter_image_embeds
586 |
587 | def decode_latents(self, latents):
588 | deprecation_message = "The decode_latents method is deprecated and will be removed in 1.0.0. Please use VaeImageProcessor.postprocess(...) instead"
589 | deprecate("decode_latents", "1.0.0",
590 | deprecation_message, standard_warn=False)
591 |
592 | latents = 1 / self.vae.config.scaling_factor * latents
593 | image = self.vae.decode(latents, return_dict=False)[0]
594 | image = (image / 2 + 0.5).clamp(0, 1)
595 | # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
596 | image = image.cpu().permute(0, 2, 3, 1).float().numpy()
597 | return image
598 |
599 | def prepare_extra_step_kwargs(self, generator, eta):
600 | # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
601 | # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
602 | # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
603 | # and should be between [0, 1]
604 |
605 | accepts_eta = "eta" in set(inspect.signature(
606 | self.scheduler.step).parameters.keys())
607 | extra_step_kwargs = {}
608 | if accepts_eta:
609 | extra_step_kwargs["eta"] = eta
610 |
611 | # check if the scheduler accepts generator
612 | accepts_generator = "generator" in set(
613 | inspect.signature(self.scheduler.step).parameters.keys())
614 | if accepts_generator:
615 | extra_step_kwargs["generator"] = generator
616 | return extra_step_kwargs
617 |
618 | def check_inputs(
619 | self,
620 | prompt,
621 | height,
622 | width,
623 | callback_steps,
624 | negative_prompt=None,
625 | prompt_embeds=None,
626 | negative_prompt_embeds=None,
627 | ip_adapter_image=None,
628 | ip_adapter_image_embeds=None,
629 | callback_on_step_end_tensor_inputs=None,
630 | ):
631 | if height % 8 != 0 or width % 8 != 0:
632 | raise ValueError(
633 | f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
634 |
635 | if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):
636 | raise ValueError(
637 | f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
638 | f" {type(callback_steps)}."
639 | )
640 | if callback_on_step_end_tensor_inputs is not None and not all(
641 | k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
642 | ):
643 | raise ValueError(
644 | f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
645 | )
646 |
647 | if prompt is not None and prompt_embeds is not None:
648 | raise ValueError(
649 | f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
650 | " only forward one of the two."
651 | )
652 | elif prompt is None and prompt_embeds is None:
653 | raise ValueError(
654 | "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
655 | )
656 | elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
657 | raise ValueError(
658 | f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
659 |
660 | if negative_prompt is not None and negative_prompt_embeds is not None:
661 | raise ValueError(
662 | f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
663 | f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
664 | )
665 |
666 | if prompt_embeds is not None and negative_prompt_embeds is not None:
667 | if prompt_embeds.shape != negative_prompt_embeds.shape:
668 | raise ValueError(
669 | "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
670 | f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
671 | f" {negative_prompt_embeds.shape}."
672 | )
673 |
674 | if ip_adapter_image is not None and ip_adapter_image_embeds is not None:
675 | raise ValueError(
676 | "Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined."
677 | )
678 |
679 | if ip_adapter_image_embeds is not None:
680 | if not isinstance(ip_adapter_image_embeds, list):
681 | raise ValueError(
682 | f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}"
683 | )
684 | elif ip_adapter_image_embeds[0].ndim not in [3, 4]:
685 | raise ValueError(
686 | f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D"
687 | )
688 |
689 | def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
690 | shape = (
691 | batch_size,
692 | num_channels_latents,
693 | int(height) // self.vae_scale_factor,
694 | int(width) // self.vae_scale_factor,
695 | )
696 | if isinstance(generator, list) and len(generator) != batch_size:
697 | raise ValueError(
698 | f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
699 | f" size of {batch_size}. Make sure the batch size matches the length of the generators."
700 | )
701 |
702 | if latents is None:
703 | latents = randn_tensor(
704 | shape, generator=generator, device=device, dtype=dtype)
705 | else:
706 | latents = latents.to(device)
707 |
708 | # scale the initial noise by the standard deviation required by the scheduler
709 | latents = latents * self.scheduler.init_noise_sigma
710 | return latents
711 |
712 | # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding
713 | def get_guidance_scale_embedding(
714 | self, w: torch.Tensor, embedding_dim: int = 512, dtype: torch.dtype = torch.float32
715 | ) -> torch.Tensor:
716 | """
717 | See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298
718 |
719 | Args:
720 | w (`torch.Tensor`):
721 | Generate embedding vectors with a specified guidance scale to subsequently enrich timestep embeddings.
722 | embedding_dim (`int`, *optional*, defaults to 512):
723 | Dimension of the embeddings to generate.
724 | dtype (`torch.dtype`, *optional*, defaults to `torch.float32`):
725 | Data type of the generated embeddings.
726 |
727 | Returns:
728 | `torch.Tensor`: Embedding vectors with shape `(len(w), embedding_dim)`.
729 | """
730 | assert len(w.shape) == 1
731 | w = w * 1000.0
732 |
733 | half_dim = embedding_dim // 2
734 | emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
735 | emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)
736 | emb = w.to(dtype)[:, None] * emb[None, :]
737 | emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
738 | if embedding_dim % 2 == 1: # zero pad
739 | emb = torch.nn.functional.pad(emb, (0, 1))
740 | assert emb.shape == (w.shape[0], embedding_dim)
741 | return emb
742 |
743 | @property
744 | def guidance_scale(self):
745 | return self._guidance_scale
746 |
747 | @property
748 | def guidance_rescale(self):
749 | return self._guidance_rescale
750 |
751 | @property
752 | def clip_skip(self):
753 | return self._clip_skip
754 |
755 | # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
756 | # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
757 | # corresponds to doing no classifier free guidance.
758 | @property
759 | def do_classifier_free_guidance(self):
760 | return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None
761 |
762 | @property
763 | def cross_attention_kwargs(self):
764 | return self._cross_attention_kwargs
765 |
766 | @property
767 | def num_timesteps(self):
768 | return self._num_timesteps
769 |
770 | @property
771 | def interrupt(self):
772 | return self._interrupt
773 |
774 | def decode_rgbd(self, latents, generator, output_type="np"):
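# The latent tensor stacks two modalities along the channel axis: the first four channels
# hold the optical-image latent, the remaining channels hold the DEM latent. Each half is
# decoded separately with the shared VAE and post-processed below.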
775 | dem_latents = latents[:, 4:, :, :]
776 | img_latents = latents[:, :4, :, :]
777 | image = self.vae.decode(img_latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[
778 | 0
779 | ]
780 | dem = self.vae.decode(dem_latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[
781 | 0
782 | ]
783 | do_denormalize = [True] * image.shape[0]
784 | image = self.image_processor.postprocess(
785 | image, output_type=output_type, do_denormalize=do_denormalize)
786 | dem = self.image_processor.postprocess(
787 | dem, output_type=output_type, do_denormalize=do_denormalize)
788 | return image, dem
789 |
790 | @torch.no_grad()
791 | @replace_example_docstring(EXAMPLE_DOC_STRING)
792 | def __call__(
793 | self,
794 | prompt: Union[str, List[str]] = None,
795 | height: Optional[int] = None,
796 | width: Optional[int] = None,
797 | num_inference_steps: int = 50,
798 | timesteps: List[int] = None,
799 | sigmas: List[float] = None,
800 | guidance_scale: float = 7.5,
801 | negative_prompt: Optional[Union[str, List[str]]] = None,
802 | num_images_per_prompt: Optional[int] = 1,
803 | eta: float = 0.0,
804 | generator: Optional[Union[torch.Generator,
805 | List[torch.Generator]]] = None,
806 | latents: Optional[torch.Tensor] = None,
807 | prompt_embeds: Optional[torch.Tensor] = None,
808 | negative_prompt_embeds: Optional[torch.Tensor] = None,
809 | ip_adapter_image: Optional[PipelineImageInput] = None,
810 | ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
811 | output_type: Optional[str] = "np",
812 | return_dict: bool = True,
813 | cross_attention_kwargs: Optional[Dict[str, Any]] = None,
814 | guidance_rescale: float = 0.0,
815 | clip_skip: Optional[int] = None,
816 | callback_on_step_end: Optional[
817 | Union[Callable[[int, int, Dict], None],
818 | PipelineCallback, MultiPipelineCallbacks]
819 | ] = None,
820 | callback_on_step_end_tensor_inputs: List[str] = ["latents"],
821 | **kwargs,
822 | ):
823 | r"""
824 | The call function to the pipeline for generation.
825 |
826 | Args:
827 | prompt (`str` or `List[str]`, *optional*):
828 | The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
829 | height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
830 | The height in pixels of the generated image.
831 | width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
832 | The width in pixels of the generated image.
833 | num_inference_steps (`int`, *optional*, defaults to 50):
834 | The number of denoising steps. More denoising steps usually lead to a higher quality image at the
835 | expense of slower inference.
836 | timesteps (`List[int]`, *optional*):
837 | Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
838 | in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
839 | passed will be used. Must be in descending order.
840 | sigmas (`List[float]`, *optional*):
841 | Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
842 | their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
843 | will be used.
844 | guidance_scale (`float`, *optional*, defaults to 7.5):
845 | A higher guidance scale value encourages the model to generate images closely linked to the text
846 | `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
847 | negative_prompt (`str` or `List[str]`, *optional*):
848 | The prompt or prompts to guide what to not include in image generation. If not defined, you need to
849 | pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`).
850 | num_images_per_prompt (`int`, *optional*, defaults to 1):
851 | The number of images to generate per prompt.
852 | eta (`float`, *optional*, defaults to 0.0):
853 | Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
854 | to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
855 | generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
856 | A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
857 | generation deterministic.
858 | latents (`torch.Tensor`, *optional*):
859 | Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
860 | generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
861 | tensor is generated by sampling using the supplied random `generator`.
862 | prompt_embeds (`torch.Tensor`, *optional*):
863 | Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
864 | provided, text embeddings are generated from the `prompt` input argument.
865 | negative_prompt_embeds (`torch.Tensor`, *optional*):
866 | Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
867 | not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
868 | ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
869 | ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
870 | Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
871 | IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should
872 | contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not
873 | provided, embeddings are computed from the `ip_adapter_image` input argument.
874 | output_type (`str`, *optional*, defaults to `"np"`):
875 | The output format of the generated image. Choose between `PIL.Image` or `np.array`.
876 | return_dict (`bool`, *optional*, defaults to `True`):
877 | Unused in this pipeline; the call always returns a plain `(image, dem)` tuple (kept for API
878 | compatibility with Stable Diffusion pipelines).
879 | cross_attention_kwargs (`dict`, *optional*):
880 | A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
881 | [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
882 | guidance_rescale (`float`, *optional*, defaults to 0.0):
883 | Guidance rescale factor from [Common Diffusion Noise Schedules and Sample Steps are
884 | Flawed](https://arxiv.org/pdf/2305.08891.pdf). Guidance rescale factor should fix overexposure when
885 | using zero terminal SNR.
886 | clip_skip (`int`, *optional*):
887 | Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
888 | the output of the pre-final layer will be used for computing the prompt embeddings.
889 | callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):
890 | A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of
891 | each denoising step during the inference. with the following arguments: `callback_on_step_end(self:
892 | DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a
893 | list of all tensors as specified by `callback_on_step_end_tensor_inputs`.
894 | callback_on_step_end_tensor_inputs (`List`, *optional*):
895 | The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
896 | will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
897 | `._callback_tensor_inputs` attribute of your pipeline class.
898 |
899 | Examples:
900 |
901 | Returns:
902 | `tuple`:
903 | A `(image, dem)` tuple, where `image` is the batch of generated optical images and `dem` is the
904 | batch of co-registered depth/elevation maps, both post-processed according to `output_type`
905 | (`"np"` by default).
906 |
907 | """
908 |
909 | callback = kwargs.pop("callback", None)
910 | callback_steps = kwargs.pop("callback_steps", None)
911 |
912 | if callback is not None:
913 | deprecate(
914 | "callback",
915 | "1.0.0",
916 | "Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
917 | )
918 | if callback_steps is not None:
919 | deprecate(
920 | "callback_steps",
921 | "1.0.0",
922 | "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
923 | )
924 |
925 | if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
926 | callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
927 |
928 | # 0. Default height and width to unet
929 | height = height or self.unet.config.sample_size * self.vae_scale_factor
930 | width = width or self.unet.config.sample_size * self.vae_scale_factor
931 | # to deal with lora scaling and other possible forward hooks
932 |
933 | # 1. Check inputs. Raise error if not correct
934 | self.check_inputs(
935 | prompt,
936 | height,
937 | width,
938 | callback_steps,
939 | negative_prompt,
940 | prompt_embeds,
941 | negative_prompt_embeds,
942 | ip_adapter_image,
943 | ip_adapter_image_embeds,
944 | callback_on_step_end_tensor_inputs,
945 | )
946 |
947 | self._guidance_scale = guidance_scale
948 | self._guidance_rescale = guidance_rescale
949 | self._clip_skip = clip_skip
950 | self._cross_attention_kwargs = cross_attention_kwargs
951 | self._interrupt = False
952 |
953 | # 2. Define call parameters
954 | if prompt is not None and isinstance(prompt, str):
955 | batch_size = 1
956 | elif prompt is not None and isinstance(prompt, list):
957 | batch_size = len(prompt)
958 | else:
959 | batch_size = prompt_embeds.shape[0]
960 |
961 | device = self._execution_device
962 |
963 | # 3. Encode input prompt
964 | lora_scale = (
965 | self.cross_attention_kwargs.get(
966 | "scale", None) if self.cross_attention_kwargs is not None else None
967 | )
968 |
969 | prompt_embeds, negative_prompt_embeds = self.encode_prompt(
970 | prompt,
971 | device,
972 | num_images_per_prompt,
973 | self.do_classifier_free_guidance,
974 | negative_prompt,
975 | prompt_embeds=prompt_embeds,
976 | negative_prompt_embeds=negative_prompt_embeds,
977 | lora_scale=lora_scale,
978 | clip_skip=self.clip_skip,
979 | )
980 |
981 | # For classifier free guidance, we need to do two forward passes.
982 | # Here we concatenate the unconditional and text embeddings into a single batch
983 | # to avoid doing two forward passes
984 | if self.do_classifier_free_guidance:
985 | prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
986 |
987 | if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
988 | image_embeds = self.prepare_ip_adapter_image_embeds(
989 | ip_adapter_image,
990 | ip_adapter_image_embeds,
991 | device,
992 | batch_size * num_images_per_prompt,
993 | self.do_classifier_free_guidance,
994 | )
995 |
996 | # 4. Prepare timesteps
997 | timesteps, num_inference_steps = retrieve_timesteps(
998 | self.scheduler, num_inference_steps, device, timesteps, sigmas
999 | )
1000 |
1001 | # 5. Prepare latent variables
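# The pipeline denoises stacked image + DEM latents, hence twice the UNet's configured
# latent channel count (the two halves are split again in `decode_rgbd`).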
1002 | num_channels_latents = self.unet.config.in_channels*2
1003 | latents = self.prepare_latents(
1004 | batch_size * num_images_per_prompt,
1005 | num_channels_latents,
1006 | height,
1007 | width,
1008 | prompt_embeds.dtype,
1009 | device,
1010 | generator,
1011 | latents,
1012 | )
1013 |
1014 | # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
1015 | extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
1016 |
1017 | # 6.1 Add image embeds for IP-Adapter
1018 | added_cond_kwargs = (
1019 | {"image_embeds": image_embeds}
1020 | if (ip_adapter_image is not None or ip_adapter_image_embeds is not None)
1021 | else None
1022 | )
1023 |
1024 | # 6.2 Optionally get Guidance Scale Embedding
1025 | timestep_cond = None
1026 | if self.unet.config.time_cond_proj_dim is not None:
1027 | guidance_scale_tensor = torch.tensor(
1028 | self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt)
1029 | timestep_cond = self.get_guidance_scale_embedding(
1030 | guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
1031 | ).to(device=device, dtype=latents.dtype)
1032 |
1033 | # 7. Denoising loop
1034 | num_warmup_steps = len(timesteps) - \
1035 | num_inference_steps * self.scheduler.order
1036 | self._num_timesteps = len(timesteps)
1037 | # intermediate_latents = []
1038 | with self.progress_bar(total=num_inference_steps) as progress_bar:
1039 | for i, t in enumerate(timesteps):
1040 | if self.interrupt:
1041 | continue
1042 |
1043 | # expand the latents if we are doing classifier free guidance
1044 | latent_model_input = torch.cat(
1045 | [latents] * 2) if self.do_classifier_free_guidance else latents
1046 | latent_model_input = self.scheduler.scale_model_input(
1047 | latent_model_input, t)
1048 |
1049 | # predict the noise residual
1050 | noise_pred = self.unet(
1051 | latent_model_input,
1052 | t,
1053 | encoder_hidden_states=prompt_embeds,
1054 | timestep_cond=timestep_cond,
1055 | cross_attention_kwargs=self.cross_attention_kwargs,
1056 | added_cond_kwargs=added_cond_kwargs,
1057 | return_dict=False,
1058 | )[0]
1059 |
1060 | # perform guidance
1061 | if self.do_classifier_free_guidance:
1062 | noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
1063 | noise_pred = noise_pred_uncond + self.guidance_scale * \
1064 | (noise_pred_text - noise_pred_uncond)
1065 |
1066 | if self.do_classifier_free_guidance and self.guidance_rescale > 0.0:
1067 | # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
1068 | noise_pred = rescale_noise_cfg(
1069 | noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale)
1070 |
1071 | # compute the previous noisy sample x_t -> x_t-1
1072 | scheduler_output = self.scheduler.step(
1073 | noise_pred, t, latents, **extra_step_kwargs, return_dict=True)
1074 | latents = scheduler_output.prev_sample
1075 | # if i % 10 == 0:
1076 | # intermediate_latents.append(scheduler_output.pred_original_sample)
1077 | if callback_on_step_end is not None:
1078 | callback_kwargs = {}
1079 | for k in callback_on_step_end_tensor_inputs:
1080 | callback_kwargs[k] = locals()[k]
1081 | callback_outputs = callback_on_step_end(
1082 | self, i, t, callback_kwargs)
1083 |
1084 | latents = callback_outputs.pop("latents", latents)
1085 | prompt_embeds = callback_outputs.pop(
1086 | "prompt_embeds", prompt_embeds)
1087 | negative_prompt_embeds = callback_outputs.pop(
1088 | "negative_prompt_embeds", negative_prompt_embeds)
1089 |
1090 | # call the callback, if provided
1091 | if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
1092 | progress_bar.update()
1093 | if callback is not None and i % callback_steps == 0:
1094 | step_idx = i // getattr(self.scheduler, "order", 1)
1095 | callback(step_idx, t, latents)
1096 |
1097 | if XLA_AVAILABLE:
1098 | xm.mark_step()
1099 |
1100 | image, dem = self.decode_rgbd(latents, generator, output_type)
1101 |
1102 |         # intermediate = [self.decode_rgbd(latent, generator, output_type) for latent in intermediate_latents]
1103 |
1104 |         # Offload all models
1105 | self.maybe_free_model_hooks()
1106 |
1107 | return image, dem
1108 |
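# ---------------------------------------------------------------------------
# Illustrative usage sketch (kept as a comment on purpose). `TerrainPipeline`
# and the checkpoint path are placeholders -- substitute the pipeline class
# defined above and a real checkpoint directory:
#
#     pipe = TerrainPipeline.from_pretrained("path/to/checkpoint").to("cuda")
#     image, dem = pipe(
#         prompt="a glaciated mountain ridge above a braided river valley",
#         num_inference_steps=50,
#     )
#     # `image` is the generated RGB view, `dem` the matching elevation map.
# ---------------------------------------------------------------------------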
--------------------------------------------------------------------------------
/models.py:
--------------------------------------------------------------------------------
1 | ###########################################################################
2 | # References:
3 | # https://github.com/huggingface/diffusers/
4 | ###########################################################################
5 |
6 | from dataclasses import dataclass
7 | from typing import Any, Dict, List, Optional, Tuple, Union
8 |
9 | import torch
10 | import torch.nn as nn
11 |
12 | from diffusers.configuration_utils import ConfigMixin, register_to_config
13 | from diffusers.loaders import PeftAdapterMixin, UNet2DConditionLoadersMixin
14 | from diffusers.loaders.single_file_model import FromOriginalModelMixin
15 | from diffusers.utils import USE_PEFT_BACKEND, BaseOutput, deprecate, logging, scale_lora_layers, unscale_lora_layers
16 | from diffusers.models.activations import get_activation
17 | from diffusers.models.attention_processor import (
18 | ADDED_KV_ATTENTION_PROCESSORS,
19 | CROSS_ATTENTION_PROCESSORS,
20 | Attention,
21 | AttentionProcessor,
22 | AttnAddedKVProcessor,
23 | AttnProcessor,
24 | FusedAttnProcessor2_0,
25 | )
26 | from diffusers.models.embeddings import (
27 | GaussianFourierProjection,
28 | GLIGENTextBoundingboxProjection,
29 | ImageHintTimeEmbedding,
30 | ImageProjection,
31 | ImageTimeEmbedding,
32 | TextImageProjection,
33 | TextImageTimeEmbedding,
34 | TextTimeEmbedding,
35 | TimestepEmbedding,
36 | Timesteps,
37 | )
38 | from diffusers.models.modeling_utils import ModelMixin
39 | from diffusers.models.unets.unet_2d_blocks import (
40 | get_down_block,
41 | get_mid_block,
42 | get_up_block,
43 | )
44 |
45 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name
46 |
47 |
48 | @dataclass
49 | class UNet2DConditionOutput(BaseOutput):
50 | """
51 | The output of [`UNet2DConditionModel`].
52 |
53 | Args:
54 | sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)`):
55 | The hidden states output conditioned on `encoder_hidden_states` input. Output of last layer of model.
56 | """
57 |
58 | sample: torch.Tensor = None
59 |
60 |
61 | class UNetDEMConditionModel(
62 | ModelMixin, ConfigMixin, FromOriginalModelMixin, UNet2DConditionLoadersMixin, PeftAdapterMixin
63 | ):
64 | r"""
65 | A conditional 2D UNet model that takes a noisy sample, conditional state, and a timestep and returns a sample
66 | shaped output.
67 |
68 |     This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented
69 | for all models (such as downloading or saving).
70 |
71 | Parameters:
72 | sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
73 | Height and width of input/output sample.
74 | in_channels (`int`, *optional*, defaults to 4): Number of channels in the input sample.
75 | out_channels (`int`, *optional*, defaults to 4): Number of channels in the output.
76 | center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample.
77 | flip_sin_to_cos (`bool`, *optional*, defaults to `True`):
78 | Whether to flip the sin to cos in the time embedding.
79 | freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding.
80 | down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
81 | The tuple of downsample blocks to use.
82 | mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2DCrossAttn"`):
83 | Block type for middle of UNet, it can be one of `UNetMidBlock2DCrossAttn`, `UNetMidBlock2D`, or
84 | `UNetMidBlock2DSimpleCrossAttn`. If `None`, the mid block layer is skipped.
85 | up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")`):
86 | The tuple of upsample blocks to use.
87 | only_cross_attention(`bool` or `Tuple[bool]`, *optional*, default to `False`):
88 | Whether to include self-attention in the basic transformer blocks, see
89 | [`~models.attention.BasicTransformerBlock`].
90 | block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
91 | The tuple of output channels for each block.
92 | layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block.
93 | downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution.
94 | mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block.
95 | dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
96 | act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
97 | norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization.
98 |             If `None`, normalization and activation layers are skipped in post-processing.
99 | norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization.
100 | cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280):
101 | The dimension of the cross attention features.
102 | transformer_layers_per_block (`int`, `Tuple[int]`, or `Tuple[Tuple]` , *optional*, defaults to 1):
103 | The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for
104 | [`~models.unets.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unets.unet_2d_blocks.CrossAttnUpBlock2D`],
105 | [`~models.unets.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
106 | reverse_transformer_layers_per_block : (`Tuple[Tuple]`, *optional*, defaults to None):
107 | The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`], in the upsampling
108 | blocks of the U-Net. Only relevant if `transformer_layers_per_block` is of type `Tuple[Tuple]` and for
109 | [`~models.unets.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unets.unet_2d_blocks.CrossAttnUpBlock2D`],
110 | [`~models.unets.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
111 | encoder_hid_dim (`int`, *optional*, defaults to None):
112 | If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim`
113 | dimension to `cross_attention_dim`.
114 | encoder_hid_dim_type (`str`, *optional*, defaults to `None`):
115 | If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text
116 |             embeddings of dimension `cross_attention_dim` according to `encoder_hid_dim_type`.
117 | attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads.
118 | num_attention_heads (`int`, *optional*):
119 | The number of attention heads. If not defined, defaults to `attention_head_dim`
120 | resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config
121 | for ResNet blocks (see [`~models.resnet.ResnetBlock2D`]). Choose from `default` or `scale_shift`.
122 | class_embed_type (`str`, *optional*, defaults to `None`):
123 | The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`,
124 | `"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`.
125 | addition_embed_type (`str`, *optional*, defaults to `None`):
126 | Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or
127 | "text". "text" will use the `TextTimeEmbedding` layer.
128 | addition_time_embed_dim: (`int`, *optional*, defaults to `None`):
129 | Dimension for the timestep embeddings.
130 | num_class_embeds (`int`, *optional*, defaults to `None`):
131 | Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
132 | class conditioning with `class_embed_type` equal to `None`.
133 | time_embedding_type (`str`, *optional*, defaults to `positional`):
134 | The type of position embedding to use for timesteps. Choose from `positional` or `fourier`.
135 | time_embedding_dim (`int`, *optional*, defaults to `None`):
136 | An optional override for the dimension of the projected time embedding.
137 | time_embedding_act_fn (`str`, *optional*, defaults to `None`):
138 | Optional activation function to use only once on the time embeddings before they are passed to the rest of
139 | the UNet. Choose from `silu`, `mish`, `gelu`, and `swish`.
140 | timestep_post_act (`str`, *optional*, defaults to `None`):
141 | The second activation function to use in timestep embedding. Choose from `silu`, `mish` and `gelu`.
142 | time_cond_proj_dim (`int`, *optional*, defaults to `None`):
143 | The dimension of `cond_proj` layer in the timestep embedding.
144 | conv_in_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_in` layer.
145 | conv_out_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_out` layer.
146 | projection_class_embeddings_input_dim (`int`, *optional*): The dimension of the `class_labels` input when
147 | `class_embed_type="projection"`. Required when `class_embed_type="projection"`.
148 | class_embeddings_concat (`bool`, *optional*, defaults to `False`): Whether to concatenate the time
149 | embeddings with the class embeddings.
150 | mid_block_only_cross_attention (`bool`, *optional*, defaults to `None`):
151 | Whether to use cross attention with the mid block when using the `UNetMidBlock2DSimpleCrossAttn`. If
152 | `only_cross_attention` is given as a single boolean and `mid_block_only_cross_attention` is `None`, the
153 |             `only_cross_attention` value is used as the value for `mid_block_only_cross_attention`. Defaults to `False`
154 | otherwise.
155 | """
156 |
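    # This class mirrors diffusers' UNet2DConditionModel, but the input stem, the
    # first down block, the last up block and the output projection are duplicated
    # into image- and DEM-specific branches (`*_img` / `*_dem`); see `__init__` and
    # `forward` below.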
157 | _supports_gradient_checkpointing = True
158 | _no_split_modules = ["BasicTransformerBlock", "ResnetBlock2D", "CrossAttnUpBlock2D"]
159 |
160 | @register_to_config
161 | def __init__(
162 | self,
163 | sample_size: Optional[int] = None,
164 | in_channels: int = 8,
165 | out_channels: int = 8,
166 | center_input_sample: bool = False,
167 | flip_sin_to_cos: bool = True,
168 | freq_shift: int = 0,
169 | down_block_types: Tuple[str] = (
170 | "CrossAttnDownBlock2D",
171 | "CrossAttnDownBlock2D",
172 | "CrossAttnDownBlock2D",
173 | "DownBlock2D",
174 | ),
175 | mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn",
176 | up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"),
177 | only_cross_attention: Union[bool, Tuple[bool]] = False,
178 | block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
179 | layers_per_block: Union[int, Tuple[int]] = 2,
180 | downsample_padding: int = 1,
181 | mid_block_scale_factor: float = 1,
182 | dropout: float = 0.0,
183 | act_fn: str = "silu",
184 | norm_num_groups: Optional[int] = 32,
185 | norm_eps: float = 1e-5,
186 | cross_attention_dim: Union[int, Tuple[int]] = 1280,
187 | transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]] = 1,
188 | reverse_transformer_layers_per_block: Optional[Tuple[Tuple[int]]] = None,
189 | encoder_hid_dim: Optional[int] = None,
190 | encoder_hid_dim_type: Optional[str] = None,
191 | attention_head_dim: Union[int, Tuple[int]] = 8,
192 | num_attention_heads: Optional[Union[int, Tuple[int]]] = None,
193 | dual_cross_attention: bool = False,
194 | use_linear_projection: bool = False,
195 | class_embed_type: Optional[str] = None,
196 | addition_embed_type: Optional[str] = None,
197 | addition_time_embed_dim: Optional[int] = None,
198 | num_class_embeds: Optional[int] = None,
199 | upcast_attention: bool = False,
200 | resnet_time_scale_shift: str = "default",
201 | resnet_skip_time_act: bool = False,
202 | resnet_out_scale_factor: float = 1.0,
203 | time_embedding_type: str = "positional",
204 | time_embedding_dim: Optional[int] = None,
205 | time_embedding_act_fn: Optional[str] = None,
206 | timestep_post_act: Optional[str] = None,
207 | time_cond_proj_dim: Optional[int] = None,
208 | conv_in_kernel: int = 3,
209 | conv_out_kernel: int = 3,
210 | projection_class_embeddings_input_dim: Optional[int] = None,
211 | attention_type: str = "default",
212 | class_embeddings_concat: bool = False,
213 | mid_block_only_cross_attention: Optional[bool] = None,
214 | cross_attention_norm: Optional[str] = None,
215 | addition_embed_type_num_heads: int = 64,
216 | ):
217 | super().__init__()
218 |
219 | self.sample_size = sample_size
220 |
221 | if num_attention_heads is not None:
222 | raise ValueError(
223 | "At the moment it is not possible to define the number of attention heads via `num_attention_heads` because of a naming issue as described in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131. Passing `num_attention_heads` will only be supported in diffusers v0.19."
224 | )
225 |
226 | # If `num_attention_heads` is not defined (which is the case for most models)
227 | # it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
228 | # The reason for this behavior is to correct for incorrectly named variables that were introduced
229 | # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
230 | # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
231 | # which is why we correct for the naming here.
232 | num_attention_heads = num_attention_heads or attention_head_dim
233 |
234 | # Check inputs
235 | self._check_config(
236 | down_block_types=down_block_types,
237 | up_block_types=up_block_types,
238 | only_cross_attention=only_cross_attention,
239 | block_out_channels=block_out_channels,
240 | layers_per_block=layers_per_block,
241 | cross_attention_dim=cross_attention_dim,
242 | transformer_layers_per_block=transformer_layers_per_block,
243 | reverse_transformer_layers_per_block=reverse_transformer_layers_per_block,
244 | attention_head_dim=attention_head_dim,
245 | num_attention_heads=num_attention_heads,
246 | )
247 |
248 | # input
249 | conv_in_padding = (conv_in_kernel - 1) // 2
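        # two parallel input stems with identical shapes: one for the image latents,
        # one for the DEM latents (cf. the channel split in `forward`)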
250 | self.conv_in_img = nn.Conv2d(
251 | in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding
252 | )
253 | self.conv_in_dem = nn.Conv2d(
254 | in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding
255 | )
256 |
257 | # time
258 | time_embed_dim, timestep_input_dim = self._set_time_proj(
259 | time_embedding_type,
260 | block_out_channels=block_out_channels,
261 | flip_sin_to_cos=flip_sin_to_cos,
262 | freq_shift=freq_shift,
263 | time_embedding_dim=time_embedding_dim,
264 | )
265 |
266 | self.time_embedding = TimestepEmbedding(
267 | timestep_input_dim,
268 | time_embed_dim,
269 | act_fn=act_fn,
270 | post_act_fn=timestep_post_act,
271 | cond_proj_dim=time_cond_proj_dim,
272 | )
273 |
274 | self._set_encoder_hid_proj(
275 | encoder_hid_dim_type,
276 | cross_attention_dim=cross_attention_dim,
277 | encoder_hid_dim=encoder_hid_dim,
278 | )
279 |
280 | # class embedding
281 | self._set_class_embedding(
282 | class_embed_type,
283 | act_fn=act_fn,
284 | num_class_embeds=num_class_embeds,
285 | projection_class_embeddings_input_dim=projection_class_embeddings_input_dim,
286 | time_embed_dim=time_embed_dim,
287 | timestep_input_dim=timestep_input_dim,
288 | )
289 |
290 | self._set_add_embedding(
291 | addition_embed_type,
292 | addition_embed_type_num_heads=addition_embed_type_num_heads,
293 | addition_time_embed_dim=addition_time_embed_dim,
294 | cross_attention_dim=cross_attention_dim,
295 | encoder_hid_dim=encoder_hid_dim,
296 | flip_sin_to_cos=flip_sin_to_cos,
297 | freq_shift=freq_shift,
298 | projection_class_embeddings_input_dim=projection_class_embeddings_input_dim,
299 | time_embed_dim=time_embed_dim,
300 | )
301 |
302 | if time_embedding_act_fn is None:
303 | self.time_embed_act = None
304 | else:
305 | self.time_embed_act = get_activation(time_embedding_act_fn)
306 |
307 | self.down_blocks = nn.ModuleList([])
308 | self.up_blocks = nn.ModuleList([])
309 |
310 | if isinstance(only_cross_attention, bool):
311 | if mid_block_only_cross_attention is None:
312 | mid_block_only_cross_attention = only_cross_attention
313 |
314 | only_cross_attention = [only_cross_attention] * len(down_block_types)
315 |
316 | if mid_block_only_cross_attention is None:
317 | mid_block_only_cross_attention = False
318 |
319 | if isinstance(num_attention_heads, int):
320 | num_attention_heads = (num_attention_heads,) * len(down_block_types)
321 |
322 | if isinstance(attention_head_dim, int):
323 | attention_head_dim = (attention_head_dim,) * len(down_block_types)
324 |
325 | if isinstance(cross_attention_dim, int):
326 | cross_attention_dim = (cross_attention_dim,) * len(down_block_types)
327 |
328 | if isinstance(layers_per_block, int):
329 | layers_per_block = [layers_per_block] * len(down_block_types)
330 |
331 | if isinstance(transformer_layers_per_block, int):
332 | transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)
333 |
334 | if class_embeddings_concat:
335 | # The time embeddings are concatenated with the class embeddings. The dimension of the
336 | # time embeddings passed to the down, middle, and up blocks is twice the dimension of the
337 | # regular time embeddings
338 | blocks_time_embed_dim = time_embed_dim * 2
339 | else:
340 | blocks_time_embed_dim = time_embed_dim
341 |
342 | # down
343 | output_channel = block_out_channels[0]
344 |
345 |
346 | for i, down_block_type in enumerate(down_block_types):
347 | input_channel = output_channel
348 | output_channel = block_out_channels[i]
349 | is_final_block = i == len(block_out_channels) - 1
350 |             down_block_kwargs = {"down_block_type": down_block_type,
351 |                                  "num_layers": layers_per_block[i],
352 |                                  "transformer_layers_per_block": transformer_layers_per_block[i],
353 |                                  "in_channels": input_channel,
354 |                                  "out_channels": output_channel,
355 |                                  "temb_channels": blocks_time_embed_dim,
356 |                                  "add_downsample": not is_final_block,
357 |                                  "resnet_eps": norm_eps,
358 |                                  "resnet_act_fn": act_fn,
359 |                                  "resnet_groups": norm_num_groups,
360 |                                  "cross_attention_dim": cross_attention_dim[i],
361 |                                  "num_attention_heads": num_attention_heads[i],
362 |                                  "downsample_padding": downsample_padding,
363 |                                  "dual_cross_attention": dual_cross_attention,
364 |                                  "use_linear_projection": use_linear_projection,
365 |                                  "only_cross_attention": only_cross_attention[i],
366 |                                  "upcast_attention": upcast_attention,
367 |                                  "resnet_time_scale_shift": resnet_time_scale_shift,
368 |                                  "attention_type": attention_type,
369 |                                  "resnet_skip_time_act": resnet_skip_time_act,
370 |                                  "resnet_out_scale_factor": resnet_out_scale_factor,
371 |                                  "cross_attention_norm": cross_attention_norm,
372 |                                  "attention_head_dim": attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
373 |                                  "dropout": dropout}
374 |
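            # the first down block is duplicated into modality-specific heads below;
            # the remaining down blocks (and the mid block) are shared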
375 | if i == 0:
376 | self.head_img = get_down_block(**down_block_kwargs)
377 | # same architecture as the head_img but different weights
378 | self.head_dem = get_down_block(**down_block_kwargs)
379 | # elif i == 1:
380 | # down_block_kwargs["in_channels"] = input_channel *2 # concatenate the output of the head_img and head_dem
381 | # down_block = get_down_block(**down_block_kwargs)
382 | # self.down_blocks.append(down_block)
383 | else:
384 | down_block = get_down_block(**down_block_kwargs)
385 | self.down_blocks.append(down_block)
386 |
387 | # mid
388 | self.mid_block = get_mid_block(
389 | mid_block_type,
390 | temb_channels=blocks_time_embed_dim,
391 | in_channels=block_out_channels[-1],
392 | resnet_eps=norm_eps,
393 | resnet_act_fn=act_fn,
394 | resnet_groups=norm_num_groups,
395 | output_scale_factor=mid_block_scale_factor,
396 | transformer_layers_per_block=transformer_layers_per_block[-1],
397 | num_attention_heads=num_attention_heads[-1],
398 | cross_attention_dim=cross_attention_dim[-1],
399 | dual_cross_attention=dual_cross_attention,
400 | use_linear_projection=use_linear_projection,
401 | mid_block_only_cross_attention=mid_block_only_cross_attention,
402 | upcast_attention=upcast_attention,
403 | resnet_time_scale_shift=resnet_time_scale_shift,
404 | attention_type=attention_type,
405 | resnet_skip_time_act=resnet_skip_time_act,
406 | cross_attention_norm=cross_attention_norm,
407 | attention_head_dim=attention_head_dim[-1],
408 | dropout=dropout,
409 | )
410 |
411 | # count how many layers upsample the images
412 | self.num_upsamplers = 0
413 |
414 | # up
415 | reversed_block_out_channels = list(reversed(block_out_channels))
416 | reversed_num_attention_heads = list(reversed(num_attention_heads))
417 | reversed_layers_per_block = list(reversed(layers_per_block))
418 | reversed_cross_attention_dim = list(reversed(cross_attention_dim))
419 | reversed_transformer_layers_per_block = (
420 | list(reversed(transformer_layers_per_block))
421 | if reverse_transformer_layers_per_block is None
422 | else reverse_transformer_layers_per_block
423 | )
424 | only_cross_attention = list(reversed(only_cross_attention))
425 |
426 | output_channel = reversed_block_out_channels[0]
427 | for i, up_block_type in enumerate(up_block_types):
428 | is_final_block = i == len(block_out_channels) - 1
429 |
430 | prev_output_channel = output_channel
431 | output_channel = reversed_block_out_channels[i]
432 | input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
433 |
434 | # add upsample block for all BUT final layer
435 | if not is_final_block:
436 | add_upsample = True
437 | self.num_upsamplers += 1
438 | else:
439 | add_upsample = False
440 |
441 |             up_block_kwargs = {"up_block_type": up_block_type,
442 |                                "num_layers": reversed_layers_per_block[i] + 1,
443 |                                "transformer_layers_per_block": reversed_transformer_layers_per_block[i],
444 |                                "in_channels": input_channel,
445 |                                "out_channels": output_channel,
446 |                                "prev_output_channel": prev_output_channel,
447 |                                "temb_channels": blocks_time_embed_dim,
448 |                                "add_upsample": add_upsample,
449 |                                "resnet_eps": norm_eps,
450 |                                "resnet_act_fn": act_fn,
451 |                                "resolution_idx": i,
452 |                                "resnet_groups": norm_num_groups,
453 |                                "cross_attention_dim": reversed_cross_attention_dim[i],
454 |                                "num_attention_heads": reversed_num_attention_heads[i],
455 |                                "dual_cross_attention": dual_cross_attention,
456 |                                "use_linear_projection": use_linear_projection,
457 |                                "only_cross_attention": only_cross_attention[i],
458 |                                "upcast_attention": upcast_attention,
459 |                                "resnet_time_scale_shift": resnet_time_scale_shift,
460 |                                "attention_type": attention_type,
461 |                                "resnet_skip_time_act": resnet_skip_time_act,
462 |                                "resnet_out_scale_factor": resnet_out_scale_factor,
463 |                                "cross_attention_norm": cross_attention_norm,
464 |                                "attention_head_dim": attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
465 |                                "dropout": dropout}
466 |
467 | # if i == len(block_out_channels) - 2:
468 | # up_block_kwargs["in_channels"] = input_channel*2
469 | # up_block = get_up_block(**up_block_kwargs)
470 | # self.up_blocks.append(up_block)
471 |
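            # mirroring the input heads, the last up block is duplicated into
            # image- and DEM-specific output heads; earlier up blocks are shared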
472 |             if is_final_block:
473 |
474 | self.head_out_img = get_up_block(**up_block_kwargs)
475 | self.head_out_dem = get_up_block(**up_block_kwargs)
476 |
477 |             else:
478 | up_block = get_up_block(**up_block_kwargs)
479 | self.up_blocks.append(up_block)
480 |
481 |
482 | # out
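        # separate output normalization and projection convolutions for each stream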
483 | if norm_num_groups is not None:
484 | self.conv_norm_out_img = nn.GroupNorm(
485 | num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps
486 | )
487 | self.conv_norm_out_dem = nn.GroupNorm(
488 | num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps
489 | )
490 |
491 | self.conv_act = get_activation(act_fn)
492 |
493 | else:
494 | self.conv_norm_out_img = None
495 | self.conv_norm_out_dem = None
496 | self.conv_act = None
497 |
498 | conv_out_padding = (conv_out_kernel - 1) // 2
499 | self.conv_out_img = nn.Conv2d(
500 | block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding
501 | )
502 | self.conv_out_dem = nn.Conv2d(
503 | block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding
504 | )
505 |
506 | self._set_pos_net_if_use_gligen(attention_type=attention_type, cross_attention_dim=cross_attention_dim)
507 |
508 | def _check_config(
509 | self,
510 | down_block_types: Tuple[str],
511 | up_block_types: Tuple[str],
512 | only_cross_attention: Union[bool, Tuple[bool]],
513 | block_out_channels: Tuple[int],
514 | layers_per_block: Union[int, Tuple[int]],
515 | cross_attention_dim: Union[int, Tuple[int]],
516 | transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple[int]]],
517 | reverse_transformer_layers_per_block: bool,
518 | attention_head_dim: int,
519 | num_attention_heads: Optional[Union[int, Tuple[int]]],
520 | ):
521 | if len(down_block_types) != len(up_block_types):
522 | raise ValueError(
523 | f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
524 | )
525 |
526 | if len(block_out_channels) != len(down_block_types):
527 | raise ValueError(
528 | f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
529 | )
530 |
531 | if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types):
532 | raise ValueError(
533 | f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
534 | )
535 |
536 | if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
537 | raise ValueError(
538 | f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
539 | )
540 |
541 | if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types):
542 | raise ValueError(
543 | f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}."
544 | )
545 |
546 | if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types):
547 | raise ValueError(
548 | f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}."
549 | )
550 |
551 | if not isinstance(layers_per_block, int) and len(layers_per_block) != len(down_block_types):
552 | raise ValueError(
553 | f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}."
554 | )
555 | if isinstance(transformer_layers_per_block, list) and reverse_transformer_layers_per_block is None:
556 | for layer_number_per_block in transformer_layers_per_block:
557 | if isinstance(layer_number_per_block, list):
558 | raise ValueError("Must provide 'reverse_transformer_layers_per_block` if using asymmetrical UNet.")
559 |
560 | def _set_time_proj(
561 | self,
562 | time_embedding_type: str,
563 | block_out_channels: int,
564 | flip_sin_to_cos: bool,
565 | freq_shift: float,
566 | time_embedding_dim: int,
567 | ) -> Tuple[int, int]:
568 | if time_embedding_type == "fourier":
569 | time_embed_dim = time_embedding_dim or block_out_channels[0] * 2
570 | if time_embed_dim % 2 != 0:
571 | raise ValueError(f"`time_embed_dim` should be divisible by 2, but is {time_embed_dim}.")
572 | self.time_proj = GaussianFourierProjection(
573 | time_embed_dim // 2, set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos
574 | )
575 | timestep_input_dim = time_embed_dim
576 | elif time_embedding_type == "positional":
577 | time_embed_dim = time_embedding_dim or block_out_channels[0] * 4
578 |
579 | self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
580 | timestep_input_dim = block_out_channels[0]
581 | else:
582 | raise ValueError(
583 | f"{time_embedding_type} does not exist. Please make sure to use one of `fourier` or `positional`."
584 | )
585 |
586 | return time_embed_dim, timestep_input_dim
587 |
588 | def _set_encoder_hid_proj(
589 | self,
590 | encoder_hid_dim_type: Optional[str],
591 | cross_attention_dim: Union[int, Tuple[int]],
592 | encoder_hid_dim: Optional[int],
593 | ):
594 | if encoder_hid_dim_type is None and encoder_hid_dim is not None:
595 | encoder_hid_dim_type = "text_proj"
596 | self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type)
597 | logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.")
598 |
599 | if encoder_hid_dim is None and encoder_hid_dim_type is not None:
600 | raise ValueError(
601 | f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}."
602 | )
603 |
604 | if encoder_hid_dim_type == "text_proj":
605 | self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim)
606 | elif encoder_hid_dim_type == "text_image_proj":
607 | # image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much
608 | # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
609 | # case when `addition_embed_type == "text_image_proj"` (Kandinsky 2.1)`
610 | self.encoder_hid_proj = TextImageProjection(
611 | text_embed_dim=encoder_hid_dim,
612 | image_embed_dim=cross_attention_dim,
613 | cross_attention_dim=cross_attention_dim,
614 | )
615 | elif encoder_hid_dim_type == "image_proj":
616 | # Kandinsky 2.2
617 | self.encoder_hid_proj = ImageProjection(
618 | image_embed_dim=encoder_hid_dim,
619 | cross_attention_dim=cross_attention_dim,
620 | )
621 | elif encoder_hid_dim_type is not None:
622 | raise ValueError(
623 | f"`encoder_hid_dim_type`: {encoder_hid_dim_type} must be None, 'text_proj', 'text_image_proj', or 'image_proj'."
624 | )
625 | else:
626 | self.encoder_hid_proj = None
627 |
628 | def _set_class_embedding(
629 | self,
630 | class_embed_type: Optional[str],
631 | act_fn: str,
632 | num_class_embeds: Optional[int],
633 | projection_class_embeddings_input_dim: Optional[int],
634 | time_embed_dim: int,
635 | timestep_input_dim: int,
636 | ):
637 | if class_embed_type is None and num_class_embeds is not None:
638 | self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
639 | elif class_embed_type == "timestep":
640 | self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim, act_fn=act_fn)
641 | elif class_embed_type == "identity":
642 | self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
643 | elif class_embed_type == "projection":
644 | if projection_class_embeddings_input_dim is None:
645 | raise ValueError(
646 | "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set"
647 | )
648 | # The projection `class_embed_type` is the same as the timestep `class_embed_type` except
649 | # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings
650 | # 2. it projects from an arbitrary input dimension.
651 | #
652 | # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations.
653 | # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings.
654 | # As a result, `TimestepEmbedding` can be passed arbitrary vectors.
655 | self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
656 | elif class_embed_type == "simple_projection":
657 | if projection_class_embeddings_input_dim is None:
658 | raise ValueError(
659 | "`class_embed_type`: 'simple_projection' requires `projection_class_embeddings_input_dim` be set"
660 | )
661 | self.class_embedding = nn.Linear(projection_class_embeddings_input_dim, time_embed_dim)
662 | else:
663 | self.class_embedding = None
664 |
665 | def _set_add_embedding(
666 | self,
667 | addition_embed_type: str,
668 | addition_embed_type_num_heads: int,
669 | addition_time_embed_dim: Optional[int],
670 | flip_sin_to_cos: bool,
671 | freq_shift: float,
672 | cross_attention_dim: Optional[int],
673 | encoder_hid_dim: Optional[int],
674 | projection_class_embeddings_input_dim: Optional[int],
675 | time_embed_dim: int,
676 | ):
677 | if addition_embed_type == "text":
678 | if encoder_hid_dim is not None:
679 | text_time_embedding_from_dim = encoder_hid_dim
680 | else:
681 | text_time_embedding_from_dim = cross_attention_dim
682 |
683 | self.add_embedding = TextTimeEmbedding(
684 | text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads
685 | )
686 | elif addition_embed_type == "text_image":
687 | # text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much
688 | # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
689 | # case when `addition_embed_type == "text_image"` (Kandinsky 2.1)`
690 | self.add_embedding = TextImageTimeEmbedding(
691 | text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim
692 | )
693 | elif addition_embed_type == "text_time":
694 | self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift)
695 | self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
696 | elif addition_embed_type == "image":
697 | # Kandinsky 2.2
698 | self.add_embedding = ImageTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim)
699 | elif addition_embed_type == "image_hint":
700 | # Kandinsky 2.2 ControlNet
701 | self.add_embedding = ImageHintTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim)
702 | elif addition_embed_type is not None:
703 | raise ValueError(
704 | f"`addition_embed_type`: {addition_embed_type} must be None, 'text', 'text_image', 'text_time', 'image', or 'image_hint'."
705 | )
706 |
707 | def _set_pos_net_if_use_gligen(self, attention_type: str, cross_attention_dim: int):
708 | if attention_type in ["gated", "gated-text-image"]:
709 | positive_len = 768
710 | if isinstance(cross_attention_dim, int):
711 | positive_len = cross_attention_dim
712 | elif isinstance(cross_attention_dim, (list, tuple)):
713 | positive_len = cross_attention_dim[0]
714 |
715 | feature_type = "text-only" if attention_type == "gated" else "text-image"
716 | self.position_net = GLIGENTextBoundingboxProjection(
717 | positive_len=positive_len, out_dim=cross_attention_dim, feature_type=feature_type
718 | )
719 |
720 | @property
721 | def attn_processors(self) -> Dict[str, AttentionProcessor]:
722 | r"""
723 | Returns:
724 |             `dict` of attention processors: A dictionary containing all attention processors used in the model,
725 |             indexed by their weight names.
726 | """
727 | # set recursively
728 | processors = {}
729 |
730 | def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
731 | if hasattr(module, "get_processor"):
732 | processors[f"{name}.processor"] = module.get_processor()
733 |
734 | for sub_name, child in module.named_children():
735 | fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
736 |
737 | return processors
738 |
739 | for name, module in self.named_children():
740 | fn_recursive_add_processors(name, module, processors)
741 |
742 | return processors
743 |
744 | def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
745 | r"""
746 | Sets the attention processor to use to compute attention.
747 |
748 | Parameters:
749 | processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
750 | The instantiated processor class or a dictionary of processor classes that will be set as the processor
751 | for **all** `Attention` layers.
752 |
753 | If `processor` is a dict, the key needs to define the path to the corresponding cross attention
754 | processor. This is strongly recommended when setting trainable attention processors.
755 |
756 | """
757 | count = len(self.attn_processors.keys())
758 |
759 | if isinstance(processor, dict) and len(processor) != count:
760 | raise ValueError(
761 | f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
762 | f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
763 | )
764 |
765 | def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
766 | if hasattr(module, "set_processor"):
767 | if not isinstance(processor, dict):
768 | module.set_processor(processor)
769 | else:
770 | module.set_processor(processor.pop(f"{name}.processor"))
771 |
772 | for sub_name, child in module.named_children():
773 | fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
774 |
775 | for name, module in self.named_children():
776 | fn_recursive_attn_processor(name, module, processor)
777 |
778 | def set_default_attn_processor(self):
779 | """
780 | Disables custom attention processors and sets the default attention implementation.
781 | """
782 | if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
783 | processor = AttnAddedKVProcessor()
784 | elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
785 | processor = AttnProcessor()
786 | else:
787 | raise ValueError(
788 | f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
789 | )
790 |
791 | self.set_attn_processor(processor)
792 |
793 | def set_attention_slice(self, slice_size: Union[str, int, List[int]] = "auto"):
794 | r"""
795 | Enable sliced attention computation.
796 |
797 | When this option is enabled, the attention module splits the input tensor in slices to compute attention in
798 | several steps. This is useful for saving some memory in exchange for a small decrease in speed.
799 |
800 | Args:
801 | slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
802 | When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If
803 | `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is
804 | provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
805 | must be a multiple of `slice_size`.
806 | """
807 | sliceable_head_dims = []
808 |
809 | def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module):
810 | if hasattr(module, "set_attention_slice"):
811 | sliceable_head_dims.append(module.sliceable_head_dim)
812 |
813 | for child in module.children():
814 | fn_recursive_retrieve_sliceable_dims(child)
815 |
816 | # retrieve number of attention layers
817 | for module in self.children():
818 | fn_recursive_retrieve_sliceable_dims(module)
819 |
820 | num_sliceable_layers = len(sliceable_head_dims)
821 |
822 | if slice_size == "auto":
823 | # half the attention head size is usually a good trade-off between
824 | # speed and memory
825 | slice_size = [dim // 2 for dim in sliceable_head_dims]
826 | elif slice_size == "max":
827 | # make smallest slice possible
828 | slice_size = num_sliceable_layers * [1]
829 |
830 | slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
831 |
832 | if len(slice_size) != len(sliceable_head_dims):
833 | raise ValueError(
834 | f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
835 | f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
836 | )
837 |
838 | for i in range(len(slice_size)):
839 | size = slice_size[i]
840 | dim = sliceable_head_dims[i]
841 | if size is not None and size > dim:
842 | raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
843 |
844 | # Recursively walk through all the children.
845 | # Any children which exposes the set_attention_slice method
846 | # gets the message
847 | def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
848 | if hasattr(module, "set_attention_slice"):
849 | module.set_attention_slice(slice_size.pop())
850 |
851 | for child in module.children():
852 | fn_recursive_set_attention_slice(child, slice_size)
853 |
854 | reversed_slice_size = list(reversed(slice_size))
855 | for module in self.children():
856 | fn_recursive_set_attention_slice(module, reversed_slice_size)
857 |
858 | def _set_gradient_checkpointing(self, module, value=False):
859 | if hasattr(module, "gradient_checkpointing"):
860 | module.gradient_checkpointing = value
861 |
862 | def enable_freeu(self, s1: float, s2: float, b1: float, b2: float):
863 | r"""Enables the FreeU mechanism from https://arxiv.org/abs/2309.11497.
864 |
865 | The suffixes after the scaling factors represent the stage blocks where they are being applied.
866 |
867 | Please refer to the [official repository](https://github.com/ChenyangSi/FreeU) for combinations of values that
868 | are known to work well for different pipelines such as Stable Diffusion v1, v2, and Stable Diffusion XL.
869 |
870 | Args:
871 | s1 (`float`):
872 | Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to
873 | mitigate the "oversmoothing effect" in the enhanced denoising process.
874 | s2 (`float`):
875 | Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to
876 | mitigate the "oversmoothing effect" in the enhanced denoising process.
877 | b1 (`float`): Scaling factor for stage 1 to amplify the contributions of backbone features.
878 | b2 (`float`): Scaling factor for stage 2 to amplify the contributions of backbone features.
879 | """
880 | for i, upsample_block in enumerate(self.up_blocks):
881 | setattr(upsample_block, "s1", s1)
882 | setattr(upsample_block, "s2", s2)
883 | setattr(upsample_block, "b1", b1)
884 | setattr(upsample_block, "b2", b2)
885 |
886 | def disable_freeu(self):
887 | """Disables the FreeU mechanism."""
888 | freeu_keys = {"s1", "s2", "b1", "b2"}
889 | for i, upsample_block in enumerate(self.up_blocks):
890 | for k in freeu_keys:
891 | if hasattr(upsample_block, k) or getattr(upsample_block, k, None) is not None:
892 | setattr(upsample_block, k, None)
893 |
894 | def fuse_qkv_projections(self):
895 | """
896 | Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
897 | are fused. For cross-attention modules, key and value projection matrices are fused.
898 |
899 |
900 |
901 | This API is 🧪 experimental.
902 |
903 |
904 | """
905 | self.original_attn_processors = None
906 |
907 | for _, attn_processor in self.attn_processors.items():
908 | if "Added" in str(attn_processor.__class__.__name__):
909 | raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
910 |
911 | self.original_attn_processors = self.attn_processors
912 |
913 | for module in self.modules():
914 | if isinstance(module, Attention):
915 | module.fuse_projections(fuse=True)
916 |
917 | self.set_attn_processor(FusedAttnProcessor2_0())
918 |
919 | def unfuse_qkv_projections(self):
920 | """Disables the fused QKV projection if enabled.
921 |
922 |
923 |
924 | This API is 🧪 experimental.
925 |
926 |
927 |
928 | """
929 | if self.original_attn_processors is not None:
930 | self.set_attn_processor(self.original_attn_processors)
931 |
932 | def get_time_embed(
933 | self, sample: torch.Tensor, timestep: Union[torch.Tensor, float, int]
934 | ) -> Optional[torch.Tensor]:
935 | timesteps = timestep
936 | if not torch.is_tensor(timesteps):
937 | # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
938 | # This would be a good case for the `match` statement (Python 3.10+)
939 | is_mps = sample.device.type == "mps"
940 | if isinstance(timestep, float):
941 | dtype = torch.float32 if is_mps else torch.float64
942 | else:
943 | dtype = torch.int32 if is_mps else torch.int64
944 | timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
945 | elif len(timesteps.shape) == 0:
946 | timesteps = timesteps[None].to(sample.device)
947 |
948 | # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
949 | timesteps = timesteps.expand(sample.shape[0])
950 |
951 | t_emb = self.time_proj(timesteps)
952 | # `Timesteps` does not contain any weights and will always return f32 tensors
953 | # but time_embedding might actually be running in fp16. so we need to cast here.
954 | # there might be better ways to encapsulate this.
955 | t_emb = t_emb.to(dtype=sample.dtype)
956 | return t_emb
957 |
958 | def get_class_embed(self, sample: torch.Tensor, class_labels: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
959 | class_emb = None
960 | if self.class_embedding is not None:
961 | if class_labels is None:
962 | raise ValueError("class_labels should be provided when num_class_embeds > 0")
963 |
964 | if self.config.class_embed_type == "timestep":
965 | class_labels = self.time_proj(class_labels)
966 |
967 | # `Timesteps` does not contain any weights and will always return f32 tensors
968 | # there might be better ways to encapsulate this.
969 | class_labels = class_labels.to(dtype=sample.dtype)
970 |
971 | class_emb = self.class_embedding(class_labels).to(dtype=sample.dtype)
972 | return class_emb
973 |
974 | def get_aug_embed(
975 | self, emb: torch.Tensor, encoder_hidden_states: torch.Tensor, added_cond_kwargs: Dict[str, Any]
976 | ) -> Optional[torch.Tensor]:
977 | aug_emb = None
978 | if self.config.addition_embed_type == "text":
979 | aug_emb = self.add_embedding(encoder_hidden_states)
980 | elif self.config.addition_embed_type == "text_image":
981 | # Kandinsky 2.1 - style
982 | if "image_embeds" not in added_cond_kwargs:
983 | raise ValueError(
984 | f"{self.__class__} has the config param `addition_embed_type` set to 'text_image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
985 | )
986 |
987 | image_embs = added_cond_kwargs.get("image_embeds")
988 | text_embs = added_cond_kwargs.get("text_embeds", encoder_hidden_states)
989 | aug_emb = self.add_embedding(text_embs, image_embs)
990 | elif self.config.addition_embed_type == "text_time":
991 | # SDXL - style
992 | if "text_embeds" not in added_cond_kwargs:
993 | raise ValueError(
994 | f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
995 | )
996 | text_embeds = added_cond_kwargs.get("text_embeds")
997 | if "time_ids" not in added_cond_kwargs:
998 | raise ValueError(
999 | f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
1000 | )
1001 | time_ids = added_cond_kwargs.get("time_ids")
1002 | time_embeds = self.add_time_proj(time_ids.flatten())
1003 | time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))
1004 | add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
1005 | add_embeds = add_embeds.to(emb.dtype)
1006 | aug_emb = self.add_embedding(add_embeds)
1007 | elif self.config.addition_embed_type == "image":
1008 | # Kandinsky 2.2 - style
1009 | if "image_embeds" not in added_cond_kwargs:
1010 | raise ValueError(
1011 | f"{self.__class__} has the config param `addition_embed_type` set to 'image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
1012 | )
1013 | image_embs = added_cond_kwargs.get("image_embeds")
1014 | aug_emb = self.add_embedding(image_embs)
1015 | elif self.config.addition_embed_type == "image_hint":
1016 | # Kandinsky 2.2 ControlNet - style
1017 | if "image_embeds" not in added_cond_kwargs or "hint" not in added_cond_kwargs:
1018 | raise ValueError(
1019 | f"{self.__class__} has the config param `addition_embed_type` set to 'image_hint' which requires the keyword arguments `image_embeds` and `hint` to be passed in `added_cond_kwargs`"
1020 | )
1021 | image_embs = added_cond_kwargs.get("image_embeds")
1022 | hint = added_cond_kwargs.get("hint")
1023 | aug_emb = self.add_embedding(image_embs, hint)
1024 | return aug_emb
1025 |
1026 | def process_encoder_hidden_states(
1027 | self, encoder_hidden_states: torch.Tensor, added_cond_kwargs: Dict[str, Any]
1028 | ) -> torch.Tensor:
1029 | if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_proj":
1030 | encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states)
1031 | elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_image_proj":
1032 | # Kandinsky 2.1 - style
1033 | if "image_embeds" not in added_cond_kwargs:
1034 | raise ValueError(
1035 | f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'text_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
1036 | )
1037 |
1038 | image_embeds = added_cond_kwargs.get("image_embeds")
1039 | encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states, image_embeds)
1040 | elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "image_proj":
1041 | # Kandinsky 2.2 - style
1042 | if "image_embeds" not in added_cond_kwargs:
1043 | raise ValueError(
1044 | f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'image_proj' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
1045 | )
1046 | image_embeds = added_cond_kwargs.get("image_embeds")
1047 | encoder_hidden_states = self.encoder_hid_proj(image_embeds)
1048 | elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "ip_image_proj":
1049 | if "image_embeds" not in added_cond_kwargs:
1050 | raise ValueError(
1051 | f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'ip_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
1052 | )
1053 |
1054 | if hasattr(self, "text_encoder_hid_proj") and self.text_encoder_hid_proj is not None:
1055 | encoder_hidden_states = self.text_encoder_hid_proj(encoder_hidden_states)
1056 |
1057 | image_embeds = added_cond_kwargs.get("image_embeds")
1058 | image_embeds = self.encoder_hid_proj(image_embeds)
1059 | encoder_hidden_states = (encoder_hidden_states, image_embeds)
1060 | return encoder_hidden_states
1061 |
1062 | def forward(
1063 | self,
1064 | sample: torch.Tensor,
1065 | timestep: Union[torch.Tensor, float, int],
1066 | encoder_hidden_states: torch.Tensor,
1067 | class_labels: Optional[torch.Tensor] = None,
1068 | timestep_cond: Optional[torch.Tensor] = None,
1069 | attention_mask: Optional[torch.Tensor] = None,
1070 | cross_attention_kwargs: Optional[Dict[str, Any]] = None,
1071 | added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
1072 | down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
1073 | mid_block_additional_residual: Optional[torch.Tensor] = None,
1074 | down_intrablock_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
1075 | encoder_attention_mask: Optional[torch.Tensor] = None,
1076 | return_dict: bool = True,
1077 | ) -> Union[UNet2DConditionOutput, Tuple]:
1078 | r"""
1079 | The [`UNet2DConditionModel`] forward method.
1080 |
1081 | Args:
1082 | sample (`torch.Tensor`):
1083 | The noisy input tensor with the following shape `(batch, channel, height, width)`.
1084 | timestep (`torch.Tensor` or `float` or `int`): The number of timesteps to denoise an input.
1085 | encoder_hidden_states (`torch.Tensor`):
1086 | The encoder hidden states with shape `(batch, sequence_length, feature_dim)`.
1087 | class_labels (`torch.Tensor`, *optional*, defaults to `None`):
1088 | Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.
1089 | timestep_cond: (`torch.Tensor`, *optional*, defaults to `None`):
1090 | Conditional embeddings for timestep. If provided, the embeddings will be summed with the samples passed
1091 | through the `self.time_embedding` layer to obtain the timestep embeddings.
1092 | attention_mask (`torch.Tensor`, *optional*, defaults to `None`):
1093 | An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
1094 | is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
1095 | negative values to the attention scores corresponding to "discard" tokens.
1096 | cross_attention_kwargs (`dict`, *optional*):
1097 | A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
1098 | `self.processor` in
1099 | [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
1100 | added_cond_kwargs: (`dict`, *optional*):
1101 | A kwargs dictionary containing additional embeddings that if specified are added to the embeddings that
1102 | are passed along to the UNet blocks.
1103 | down_block_additional_residuals: (`tuple` of `torch.Tensor`, *optional*):
1104 | A tuple of tensors that if specified are added to the residuals of down unet blocks.
1105 | mid_block_additional_residual: (`torch.Tensor`, *optional*):
1106 | A tensor that if specified is added to the residual of the middle unet block.
1107 | down_intrablock_additional_residuals (`tuple` of `torch.Tensor`, *optional*):
1108 | additional residuals to be added within UNet down blocks, for example from T2I-Adapter side model(s)
1109 | encoder_attention_mask (`torch.Tensor`):
1110 | A cross-attention mask of shape `(batch, sequence_length)` is applied to `encoder_hidden_states`. If
1111 | `True` the mask is kept, otherwise if `False` it is discarded. Mask will be converted into a bias,
1112 | which adds large negative values to the attention scores corresponding to "discard" tokens.
1113 | return_dict (`bool`, *optional*, defaults to `True`):
1114 | Whether or not to return a [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
1115 | tuple.
1116 |
1117 | Returns:
1118 | [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
1119 | If `return_dict` is True, an [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] is returned,
1120 | otherwise a `tuple` is returned where the first element is the sample tensor.
1121 | """
1122 |         # By default samples have to be at least a multiple of the overall upsampling factor.
1123 | # The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
1124 | # However, the upsampling interpolation output size can be forced to fit any upsampling size
1125 | # on the fly if necessary.
1126 | default_overall_up_factor = 2**self.num_upsamplers
1127 |
1128 | # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
1129 | forward_upsample_size = False
1130 | upsample_size = None
1131 |
1132 | for dim in sample.shape[-2:]:
1133 | if dim % default_overall_up_factor != 0:
1134 | # Forward upsample size to force interpolation output size.
1135 | forward_upsample_size = True
1136 | break
1137 |
1138 | # ensure attention_mask is a bias, and give it a singleton query_tokens dimension
1139 | # expects mask of shape:
1140 | # [batch, key_tokens]
1141 | # adds singleton query_tokens dimension:
1142 | # [batch, 1, key_tokens]
1143 | # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
1144 | # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
1145 | # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
1146 | if attention_mask is not None:
1147 | # assume that mask is expressed as:
1148 | # (1 = keep, 0 = discard)
1149 | # convert mask into a bias that can be added to attention scores:
1150 | # (keep = +0, discard = -10000.0)
1151 | attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
1152 | attention_mask = attention_mask.unsqueeze(1)
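            # Worked example (illustrative): a key-padding mask [[1, 1, 0]] becomes the bias
            # [[[0.0, 0.0, -10000.0]]] with shape [batch=1, 1, key_tokens=3], so the "discard" token
            # receives a near-zero attention weight after the softmax.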
1153 |
1154 | # convert encoder_attention_mask to a bias the same way we do for attention_mask
1155 | if encoder_attention_mask is not None:
1156 | encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0
1157 | encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
1158 |
1159 | # 0. center input if necessary
1160 | if self.config.center_input_sample:
1161 | sample = 2 * sample - 1.0
1162 |
1163 | # 1. time
1164 | t_emb = self.get_time_embed(sample=sample, timestep=timestep)
1165 | emb = self.time_embedding(t_emb, timestep_cond)
1166 |
1167 | class_emb = self.get_class_embed(sample=sample, class_labels=class_labels)
1168 | if class_emb is not None:
1169 | if self.config.class_embeddings_concat:
1170 | emb = torch.cat([emb, class_emb], dim=-1)
1171 | else:
1172 | emb = emb + class_emb
1173 |
1174 | aug_emb = self.get_aug_embed(
1175 | emb=emb, encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs
1176 | )
1177 | if self.config.addition_embed_type == "image_hint":
1178 | aug_emb, hint = aug_emb
1179 | sample = torch.cat([sample, hint], dim=1)
1180 |
1181 |
1182 | emb = emb + aug_emb if aug_emb is not None else emb
1183 |
1184 | if self.time_embed_act is not None:
1185 | emb = self.time_embed_act(emb)
1186 |
1187 | encoder_hidden_states = self.process_encoder_hidden_states(
1188 | encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs
1189 | )
1190 |
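        # Split the concatenated latent into its two modalities. The image latent is assumed to occupy
        # the first 4 channels (the usual Stable Diffusion VAE latent size); the remaining channels hold
        # the DEM latent.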
1191 | sample_img = sample[:, :4, :, :]
1192 | sample_dem = sample[:, 4:, :, :]
1193 | # 2. pre-process using the two different heads
1194 | sample_img = self.conv_in_img(sample_img)
1195 | sample_dem = self.conv_in_dem(sample_dem)
1196 |
1197 | # 2.5 GLIGEN position net
1198 | if cross_attention_kwargs is not None and cross_attention_kwargs.get("gligen", None) is not None:
1199 | cross_attention_kwargs = cross_attention_kwargs.copy()
1200 | gligen_args = cross_attention_kwargs.pop("gligen")
1201 | cross_attention_kwargs["gligen"] = {"objs": self.position_net(**gligen_args)}
1202 |
1203 | # 3. down
1204 | # we're popping the `scale` instead of getting it because otherwise `scale` will be propagated
1205 | # to the internal blocks and will raise deprecation warnings. this will be confusing for our users.
1206 | if cross_attention_kwargs is not None:
1207 | cross_attention_kwargs = cross_attention_kwargs.copy()
1208 | lora_scale = cross_attention_kwargs.pop("scale", 1.0)
1209 | else:
1210 | lora_scale = 1.0
1211 |
1212 | if USE_PEFT_BACKEND:
1213 | # weight the lora layers by setting `lora_scale` for each PEFT layer
1214 | scale_lora_layers(self, lora_scale)
1215 |
1216 | is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None
1217 | # using new arg down_intrablock_additional_residuals for T2I-Adapters, to distinguish from controlnets
1218 | is_adapter = down_intrablock_additional_residuals is not None
1219 |         if is_adapter:
1220 |             raise NotImplementedError("T2I-Adapter style additional residuals are not supported by this model.")
1221 |
1222 |
1223 |         # run each modality through its own input head before the shared down blocks
1224 | head_img_res_sample = (sample_img,)
1225 | # RGB head
1226 | if hasattr(self.head_img, "has_cross_attention") and self.head_img.has_cross_attention:
1227 | # For t2i-adapter CrossAttnDownBlock2D
1228 | additional_residuals = {}
1229 | sample_img, res_samples_img = self.head_img(
1230 | hidden_states=sample_img,
1231 | temb=emb,
1232 | encoder_hidden_states=encoder_hidden_states,
1233 | attention_mask=attention_mask,
1234 | cross_attention_kwargs=cross_attention_kwargs,
1235 | encoder_attention_mask=encoder_attention_mask,
1236 | **additional_residuals,
1237 | )
1238 | else:
1239 |             sample_img, res_samples_img = self.head_img(hidden_states=sample_img, temb=emb)
1240 | head_img_res_sample += res_samples_img[:2]
1241 |
1242 |
1243 |
1244 | head_dem_res_sample = (sample_dem,)
1245 | # DEM head
1246 | if hasattr(self.head_dem, "has_cross_attention") and self.head_dem.has_cross_attention:
1247 | # For t2i-adapter CrossAttnDownBlock2D
1248 | additional_residuals = {}
1249 |
1250 | sample_dem, res_samples_dem = self.head_dem(
1251 | hidden_states=sample_dem,
1252 | temb=emb,
1253 | encoder_hidden_states=encoder_hidden_states,
1254 | attention_mask=attention_mask,
1255 | cross_attention_kwargs=cross_attention_kwargs,
1256 | encoder_attention_mask=encoder_attention_mask,
1257 | **additional_residuals,
1258 | )
1259 | else:
1260 |             # sample_dem, res_samples_dem = self.head_dem(hidden_states=sample_dem, temb=emb)
1261 |             sample_dem, res_samples_dem = self.head_img(hidden_states=sample_dem, temb=emb)  # shared weights with the RGB head
1262 |
1263 | head_dem_res_sample += res_samples_dem[:2]
1264 |
1265 |         # average the outputs of the two heads and pass the fused features through the shared down blocks
1266 |         sample = (sample_img + sample_dem) / 2
1268 | res_samples_img_dem = (res_samples_img[2] + res_samples_dem[2]) / 2
1269 | down_block_res_samples = (res_samples_img_dem,)
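        # Only the deepest residual of each head is fused (averaged) and handed to the shared trunk as a
        # skip connection; the two shallower residuals stay modality-specific and are consumed later by
        # head_out_img / head_out_dem.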
1270 |
1271 |
1272 | for downsample_block in self.down_blocks:
1273 | if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
1274 | # For t2i-adapter CrossAttnDownBlock2D
1275 | additional_residuals = {}
1276 | if is_adapter and len(down_intrablock_additional_residuals) > 0:
1277 | additional_residuals["additional_residuals"] = down_intrablock_additional_residuals.pop(0)
1278 |
1279 | sample, res_samples = downsample_block(
1280 | hidden_states=sample,
1281 | temb=emb,
1282 | encoder_hidden_states=encoder_hidden_states,
1283 | attention_mask=attention_mask,
1284 | cross_attention_kwargs=cross_attention_kwargs,
1285 | encoder_attention_mask=encoder_attention_mask,
1286 | **additional_residuals,
1287 | )
1288 | else:
1289 | sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
1290 | if is_adapter and len(down_intrablock_additional_residuals) > 0:
1291 | sample += down_intrablock_additional_residuals.pop(0)
1292 |
1293 | down_block_res_samples += res_samples
1294 |
1295 | if is_controlnet:
1296 | new_down_block_res_samples = ()
1297 |
1298 | for down_block_res_sample, down_block_additional_residual in zip(
1299 | down_block_res_samples, down_block_additional_residuals
1300 | ):
1301 | down_block_res_sample = down_block_res_sample + down_block_additional_residual
1302 | new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,)
1303 |
1304 | down_block_res_samples = new_down_block_res_samples
1305 |
1306 |
1307 | # 4. mid
1308 | if self.mid_block is not None:
1309 | if hasattr(self.mid_block, "has_cross_attention") and self.mid_block.has_cross_attention:
1310 | sample = self.mid_block(
1311 | sample,
1312 | emb,
1313 | encoder_hidden_states=encoder_hidden_states,
1314 | attention_mask=attention_mask,
1315 | cross_attention_kwargs=cross_attention_kwargs,
1316 | encoder_attention_mask=encoder_attention_mask,
1317 | )
1318 | else:
1319 | sample = self.mid_block(sample, emb)
1320 |
1321 | # To support T2I-Adapter-XL
1322 | if (
1323 | is_adapter
1324 | and len(down_intrablock_additional_residuals) > 0
1325 | and sample.shape == down_intrablock_additional_residuals[0].shape
1326 | ):
1327 | sample += down_intrablock_additional_residuals.pop(0)
1328 |
1329 | if is_controlnet:
1330 | sample = sample + mid_block_additional_residual
1331 |
1332 | # 5. up
1333 | for i, upsample_block in enumerate(self.up_blocks):
1334 |
1335 | res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
1336 | down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
1337 |
1338 | # if we have not reached the final block and need to forward the
1339 | # upsample size, we do it here
1340 | if forward_upsample_size:
1341 | upsample_size = down_block_res_samples[-1].shape[2:]
1342 | if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
1343 | sample = upsample_block(
1344 | hidden_states=sample,
1345 | temb=emb,
1346 | res_hidden_states_tuple=res_samples,
1347 | encoder_hidden_states=encoder_hidden_states,
1348 | cross_attention_kwargs=cross_attention_kwargs,
1349 | upsample_size=upsample_size,
1350 | attention_mask=attention_mask,
1351 | encoder_attention_mask=encoder_attention_mask,
1352 | )
1353 |
1354 | else:
1355 | sample = upsample_block(
1356 | hidden_states=sample,
1357 | temb=emb,
1358 | res_hidden_states_tuple=res_samples,
1359 | upsample_size=upsample_size)
1360 |
1361 |
1362 |         # 6. post-process: run the shared features through the RGB and DEM output heads
1363 |
1364 | sample_img = sample
1365 |
1366 | if hasattr(self.head_out_img, "has_cross_attention") and self.head_out_img.has_cross_attention:
1367 | sample_img = self.head_out_img(
1368 | hidden_states=sample_img,
1369 | temb=emb,
1370 | res_hidden_states_tuple=head_img_res_sample,
1371 | encoder_hidden_states=encoder_hidden_states,
1372 | cross_attention_kwargs=cross_attention_kwargs,
1373 | upsample_size=upsample_size,
1374 | attention_mask=attention_mask,
1375 | encoder_attention_mask=encoder_attention_mask,
1376 | )
1377 | else:
1378 |             sample_img = self.head_out_img(
1379 |                 hidden_states=sample_img,
1380 | temb=emb,
1381 | res_hidden_states_tuple=head_img_res_sample,
1382 | upsample_size=upsample_size,
1383 | )
1384 | if self.conv_norm_out_img:
1385 | sample_img = self.conv_norm_out_img(sample_img)
1386 | sample_img = self.conv_act(sample_img)
1387 | sample_img = self.conv_out_img(sample_img)
1388 |
1389 | sample_dem = sample
1390 |
1391 | if hasattr(self.head_out_dem, "has_cross_attention") and self.head_out_dem.has_cross_attention:
1392 | sample_dem = self.head_out_dem(
1393 | hidden_states=sample_dem,
1394 | temb=emb,
1395 | res_hidden_states_tuple=head_dem_res_sample,
1396 | encoder_hidden_states=encoder_hidden_states,
1397 | cross_attention_kwargs=cross_attention_kwargs,
1398 | upsample_size=upsample_size,
1399 | attention_mask=attention_mask,
1400 | encoder_attention_mask=encoder_attention_mask,
1401 | )
1402 | else:
1403 |             sample_dem = self.head_out_dem(
1404 |                 hidden_states=sample_dem,
1405 | temb=emb,
1406 | res_hidden_states_tuple=head_dem_res_sample,
1407 | upsample_size=upsample_size,
1408 | )
1409 |
1410 | if self.conv_norm_out_dem:
1411 | sample_dem = self.conv_norm_out_dem(sample_dem)
1412 | sample_dem = self.conv_act(sample_dem)
1413 | sample_dem = self.conv_out_dem(sample_dem)
1414 |
1415 |         sample = torch.cat([sample_img, sample_dem], dim=1)  # re-join the RGB and DEM predictions along the channel axis
1416 |
1417 | if USE_PEFT_BACKEND:
1418 | # remove `lora_scale` from each PEFT layer
1419 | unscale_lora_layers(self, lora_scale)
1420 |
1421 | if not return_dict:
1422 | return (sample,)
1423 |
1424 | return UNet2DConditionOutput(sample=sample)
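    # Usage sketch (illustrative, not part of the model). Assuming the RGB and DEM latents have 4
    # channels each, a text encoder hidden size of 768, and `model` an instance of this class, a single
    # denoising step could look like:
    #
    #     noisy = torch.randn(2, 8, 64, 64)        # concatenated RGB + DEM latents
    #     text = torch.randn(2, 77, 768)           # text-encoder hidden states
    #     out = model(noisy, timestep=10, encoder_hidden_states=text).sample
    #     rgb_pred, dem_pred = out[:, :4], out[:, 4:]   # the output keeps the same channel split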
1425 |
1426 |
1427 |
1428 | def load_weights_from_pretrained(pretrain_model, model_dem):
1429 | dem_state_dict = model_dem.state_dict()
1430 | for name, param in pretrain_model.named_parameters():
1431 | block = name.split(".")[0]
1432 | if block == "conv_in":
1433 | new_name_img = name.replace("conv_in","conv_in_img")
1434 | dem_state_dict[new_name_img] = param
1435 | new_name_dem = name.replace("conv_in","conv_in_dem")
1436 | dem_state_dict[new_name_dem] = param
1437 | if block == "down_blocks":
1438 | block_num = int(name.split(".")[1])
1439 | if block_num == 0:
1440 | new_name_img = name.replace("down_blocks.0","head_img")
1441 | dem_state_dict[new_name_img] = param
1442 | new_name_dem = name.replace("down_blocks.0","head_dem")
1443 | dem_state_dict[new_name_dem] = param
1444 | elif block_num > 0:
1445 | new_name = name.replace(f"down_blocks.{block_num}",f"down_blocks.{block_num-1}")
1446 | dem_state_dict[new_name] = param
1447 | if block == "mid_block":
1448 | dem_state_dict[name] = param
1449 | if block == "time_embedding":
1450 | dem_state_dict[name] = param
1451 | if block == "up_blocks":
1452 | block_num = int(name.split(".")[1])
1453 | if block_num == 3:
1454 | new_name = name.replace("up_blocks.3","head_out_img")
1455 | dem_state_dict[new_name] = param
1456 | new_name = name.replace("up_blocks.3","head_out_dem")
1457 | dem_state_dict[new_name] = param
1458 | else:
1459 | dem_state_dict[name] = param
1460 | if block == "conv_out":
1461 | new_name = name.replace("conv_out","conv_out_img")
1462 | dem_state_dict[new_name] = param
1463 | new_name = name.replace("conv_out","conv_out_dem")
1464 | dem_state_dict[new_name] = param
1465 | if block == "conv_norm_out":
1466 | new_name = name.replace("conv_norm_out","conv_norm_out_img")
1467 | dem_state_dict[new_name] = param
1468 | new_name = name.replace("conv_norm_out","conv_norm_out_dem")
1469 | dem_state_dict[new_name] = param
1470 |
1471 | model_dem.load_state_dict(dem_state_dict)
1472 |
1473 | return model_dem
1474 |
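# Usage sketch (illustrative): initialise the dual-head model from a pretrained single-branch UNet.
# `DualHeadUNet` is a placeholder for the dual-head UNet class defined earlier in this file, and the
# checkpoint id is only an example; both are assumptions to adapt to your setup.
#
#     from diffusers import UNet2DConditionModel
#
#     pretrained = UNet2DConditionModel.from_pretrained(
#         "stable-diffusion-v1-5/stable-diffusion-v1-5", subfolder="unet"
#     )
#     dual_head = DualHeadUNet(...)  # construct with a config compatible with the pretrained UNet
#     dual_head = load_weights_from_pretrained(pretrained, dual_head)
#
# The mapping duplicates `conv_in`, `down_blocks.0`, `up_blocks.3`, `conv_norm_out` and `conv_out` into
# the RGB/DEM heads, shifts the remaining down blocks down by one index, and copies the mid block, time
# embedding and remaining up blocks unchanged.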
--------------------------------------------------------------------------------