├── .gitignore
├── README.md
├── assets
│   ├── boy.png
│   ├── car.png
│   ├── cat.png
│   ├── surf.png
│   └── umbrella.png
├── playground.ipynb
├── setup.py
├── style.sh
└── tflcg
    ├── __init__.py
    └── layout_guidance_pipeline.py

/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 | 
6 | # C extensions
7 | *.so
8 | 
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 | 
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 | 
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 | 
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 | 
54 | # Translations
55 | *.mo
56 | *.pot
57 | 
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 | 
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 | 
68 | # Scrapy stuff:
69 | .scrapy
70 | 
71 | # Sphinx documentation
72 | docs/_build/
73 | 
74 | # PyBuilder
75 | target/
76 | 
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 | 
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 | 
84 | # pyenv
85 | .python-version
86 | 
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 | 
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 | 
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 | 
101 | # SageMath parsed files
102 | *.sage.py
103 | 
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 | 
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 | 
117 | # Rope project settings
118 | .ropeproject
119 | 
120 | # mkdocs documentation
121 | /site
122 | 
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 | 
128 | # Pyre type checker
129 | .pyre/
130 | 
131 | .idea/
132 | 
133 | sample/
134 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Diffusers based Training-Free Layout Control with Cross-Attention Guidance
2 | 
3 | This repository provides an implementation of the paper [Training-Free Layout Control with Cross-Attention Guidance](https://arxiv.org/abs/2304.03373) using 🤗 [Hugging Face](https://github.com/huggingface/diffusers) diffusion models. The code is an adaptation of the original implementation by [@silent-chen](https://github.com/silent-chen), extended here to make it more accessible. Special thanks to [@silent-chen](https://github.com/silent-chen) for sharing their work. Additionally, the Attend-and-Excite pipeline in diffusers was used as a reference.
4 | 
5 | Currently, only backward guidance is supported; forward guidance will be added soon.
6 | 
7 | ## Installation
8 | 
9 | ```shell
10 | pip install git+https://github.com/nipunjindal/diffusers-layout-guidance.git
11 | ```
12 | 
13 | ## How to use Training-Free Layout Control Guidance (TFLCG)
14 | 
15 | ```python
16 | from tflcg.layout_guidance_pipeline import LayoutGuidanceStableDiffusionPipeline
17 | from diffusers import EulerDiscreteScheduler
18 | 
19 | pipe = LayoutGuidanceStableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
20 | pipe = pipe.to("mps")  # or "cuda" / "cpu", depending on your hardware
21 | pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config)
22 | 
23 | prompt = "A cat playing with a ball"
24 | bboxes = [[0.55, 0.4, 0.95, 0.8]]  # [x_min, y_min, x_max, y_max] in relative (0-1) coordinates
25 | image = pipe(prompt, num_inference_steps=20,
26 |              token_indices=[[2]],
27 |              bboxes=bboxes).images[0]
28 | image = pipe.draw_box(image, bboxes)
29 | image.save("output.png")
30 | ```
31 | 
32 | ## Image Examples
33 | To reproduce outputs like the ones below, follow these steps:
34 | 
35 | 1. Install the required dependencies (see the Installation section above).
36 | 2. Use the provided code snippets to generate images with the desired settings. You can play around with parameters such as `max_guidance_iter`, `max_guidance_iter_per_step`, and `scale_factor` to adjust the output to your needs.
37 | 3. Keep in mind that increasing the `scale_factor` may result in a loss of fidelity in the generated images.
38 | 
39 | Feel free to experiment with the various settings and see how they affect the output. For guidance on picking `token_indices` and `bboxes`, see the short example before the Citation section. If you have any questions or run into any issues, please refer to the documentation or reach out to the community for help.
40 | 
41 | Below are some examples of the images you can generate with this pipeline:
42 | 
43 | ### Cat Playing with a Ball
44 | ![cat playing with a ball](assets/cat.png)
45 | 
46 | ### Person Holding an Umbrella in the Rain
47 | ![person holding an umbrella in the rain](assets/umbrella.png)
48 | 
49 | ### Car Driving on a Winding Road
50 | ![car driving on a winding road](assets/car.png)
51 | 
52 | ### Child Blowing Bubbles
53 | ![child blowing bubbles](assets/boy.png)
54 | 
55 | ### Person Surfing on a Wave
56 | ![person surfing on a wave](assets/surf.png)
57 | 
58 | ## Contributing and Issues
59 | Please feel free to contribute to this repository by submitting pull requests or creating issues in the [GitHub repository](https://github.com/nipunjindal/diffusers-layout-guidance). If you encounter any bugs or have suggestions for improvements, don't hesitate to open an issue. We welcome all contributions and appreciate your feedback!
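## Choosing `token_indices` and `bboxes`

The `token_indices` passed to the pipeline are positions in the CLIP-tokenized prompt (index 0 is the start-of-text token), and each bounding box is `[x_min, y_min, x_max, y_max]` in relative (0-1) image coordinates. Below is a minimal sketch of how you might look the indices up with the pipeline's `get_indices` helper and constrain two objects at once; the box coordinates and the index `6` for "ball" are illustrative assumptions and should be checked against the printed mapping for your prompt.

```python
from tflcg.layout_guidance_pipeline import LayoutGuidanceStableDiffusionPipeline

pipe = LayoutGuidanceStableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
pipe = pipe.to("cuda")  # or "mps" / "cpu", depending on your hardware

# Print the index -> token mapping for the prompt, e.g. something like
# {0: '<|startoftext|>', 1: 'a</w>', 2: 'cat</w>', ..., 6: 'ball</w>', ...}
print(pipe.get_indices("A cat playing with a ball"))

# One inner list of token indices per bounding box: here "cat" (index 2) is
# placed on the left and "ball" (assumed index 6) on the right.
token_indices = [[2], [6]]
bboxes = [[0.05, 0.35, 0.45, 0.90], [0.55, 0.40, 0.95, 0.80]]  # illustrative boxes

image = pipe("A cat playing with a ball",
             num_inference_steps=20,
             token_indices=token_indices,
             bboxes=bboxes).images[0]
image = pipe.draw_box(image, bboxes)  # overlay the boxes to inspect the layout
image.save("output_two_boxes.png")
```
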
60 | 
61 | ## Citation
62 | 
63 | ```bibtex
64 | @article{chen2023trainingfree,
65 |   title={Training-Free Layout Control with Cross-Attention Guidance},
66 |   author={Minghao Chen and Iro Laina and Andrea Vedaldi},
67 |   journal={arXiv preprint arXiv:2304.03373},
68 |   year={2023}
69 | }
70 | ```
--------------------------------------------------------------------------------
/assets/boy.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nipunjindal/diffusers-layout-guidance/9769374f8633369115e734f4b554a5944b85f729/assets/boy.png
--------------------------------------------------------------------------------
/assets/car.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nipunjindal/diffusers-layout-guidance/9769374f8633369115e734f4b554a5944b85f729/assets/car.png
--------------------------------------------------------------------------------
/assets/cat.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nipunjindal/diffusers-layout-guidance/9769374f8633369115e734f4b554a5944b85f729/assets/cat.png
--------------------------------------------------------------------------------
/assets/surf.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nipunjindal/diffusers-layout-guidance/9769374f8633369115e734f4b554a5944b85f729/assets/surf.png
--------------------------------------------------------------------------------
/assets/umbrella.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nipunjindal/diffusers-layout-guidance/9769374f8633369115e734f4b554a5944b85f729/assets/umbrella.png
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import find_packages, setup
2 | 
3 | setup(
4 |     name="tflcg",
5 |     version="0.1",
6 |     packages=find_packages(),
7 |     install_requires=[
8 |         "diffusers==0.15.0",
9 |         "Pillow==9.5.0",
10 |         "setuptools==67.6.1",
11 |         "transformers==4.28.0",
12 |     ],
13 |     author="Nipun Jindal",
14 |     author_email="jindal.nipun@gmail.com",
15 |     description="Unofficial huggingface/diffusers-based implementation of the paper Training-Free Layout Control with Cross-Attention Guidance",
16 |     url="https://github.com/nipunjindal/diffusers-layout-guidance",
17 | )
18 | 
--------------------------------------------------------------------------------
/style.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | 
3 | # Define the required tools
4 | tools=("isort" "flake8" "mypy" "black")
5 | 
6 | # Check if each tool is installed
7 | for tool in "${tools[@]}"
8 | do
9 |     if ! command -v "$tool" &> /dev/null
10 |     then
11 |         echo "$tool could not be found. Please install it using 'pip install $tool'."
12 |         exit 1
13 |     fi
14 | done
15 | 
16 | # Change to the package directory
17 | cd tflcg
18 | 
19 | # Run isort to sort imports
20 | isort .
21 | 
22 | # Run flake8 to check for syntax errors and style issues
23 | flake8 .
24 | 
25 | # Run mypy to check for type errors
26 | mypy .
27 | 
28 | # Run black to format the code
29 | black .
30 | -------------------------------------------------------------------------------- /tflcg/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nipunjindal/diffusers-layout-guidance/9769374f8633369115e734f4b554a5944b85f729/tflcg/__init__.py -------------------------------------------------------------------------------- /tflcg/layout_guidance_pipeline.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Callable, Dict, List, Optional, Union 2 | 3 | import numpy as np 4 | import torch 5 | from diffusers.loaders import TextualInversionLoaderMixin 6 | from diffusers.models import AutoencoderKL, UNet2DConditionModel 7 | from diffusers.models.attention_processor import Attention 8 | from diffusers.pipelines.stable_diffusion import ( 9 | StableDiffusionAttendAndExcitePipeline, 10 | StableDiffusionPipelineOutput, 11 | StableDiffusionSafetyChecker, 12 | ) 13 | from diffusers.schedulers import KarrasDiffusionSchedulers 14 | from diffusers.utils import logging 15 | from PIL import Image, ImageDraw 16 | from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer 17 | 18 | logger = logging.get_logger(__name__) 19 | 20 | 21 | class AttentionStore: 22 | @staticmethod 23 | def get_empty_store(): 24 | return {"down": [], "mid": [], "up": []} 25 | 26 | def __call__(self, attn, is_cross: bool, place_in_unet: str): 27 | if is_cross: 28 | if attn.shape[1] in self.attn_res: 29 | self.step_store[place_in_unet].append(attn) 30 | 31 | self.cur_att_layer += 1 32 | if self.cur_att_layer == self.num_att_layers: 33 | self.cur_att_layer = 0 34 | self.between_steps() 35 | 36 | def between_steps(self): 37 | self.attention_store = self.step_store 38 | self.step_store = self.get_empty_store() 39 | 40 | def maps(self, block_type: str): 41 | return self.attention_store[block_type] 42 | 43 | def reset(self): 44 | self.cur_att_layer = 0 45 | self.step_store = self.get_empty_store() 46 | self.attention_store = {} 47 | 48 | def __init__(self, attn_res=[256, 64]): 49 | """ 50 | Initialize an empty AttentionStore :param step_index: used to visualize only a specific step in the diffusion 51 | process 52 | """ 53 | self.num_att_layers = -1 54 | self.cur_att_layer = 0 55 | self.step_store = self.get_empty_store() 56 | self.attention_store = {} 57 | self.curr_step_index = 0 58 | self.attn_res = attn_res 59 | 60 | 61 | class LayoutGuidanceAttnProcessor: 62 | def __init__(self, attnstore, place_in_unet): 63 | super().__init__() 64 | self.attnstore = attnstore 65 | self.place_in_unet = place_in_unet 66 | 67 | def __call__( 68 | self, 69 | attn: Attention, 70 | hidden_states, 71 | encoder_hidden_states=None, 72 | attention_mask=None, 73 | ): 74 | batch_size, sequence_length, _ = hidden_states.shape 75 | attention_mask = attn.prepare_attention_mask( 76 | attention_mask, sequence_length, batch_size 77 | ) 78 | 79 | query = attn.to_q(hidden_states) 80 | 81 | is_cross = encoder_hidden_states is not None 82 | encoder_hidden_states = ( 83 | encoder_hidden_states 84 | if encoder_hidden_states is not None 85 | else hidden_states 86 | ) 87 | key = attn.to_k(encoder_hidden_states) 88 | value = attn.to_v(encoder_hidden_states) 89 | 90 | query = attn.head_to_batch_dim(query) 91 | key = attn.head_to_batch_dim(key) 92 | value = attn.head_to_batch_dim(value) 93 | 94 | attention_probs = attn.get_attention_scores(query, key, attention_mask) 95 | 96 | # only need to store attention maps during the Attend 
and Excite process 97 | if attention_probs.requires_grad: 98 | self.attnstore(attention_probs, is_cross, self.place_in_unet) 99 | 100 | hidden_states = torch.bmm(attention_probs, value) 101 | hidden_states = attn.batch_to_head_dim(hidden_states) 102 | 103 | # linear proj 104 | hidden_states = attn.to_out[0](hidden_states) 105 | # dropout 106 | hidden_states = attn.to_out[1](hidden_states) 107 | 108 | return hidden_states 109 | 110 | 111 | class LayoutGuidanceStableDiffusionPipeline(StableDiffusionAttendAndExcitePipeline): 112 | def __init__( 113 | self, 114 | vae: AutoencoderKL, 115 | text_encoder: CLIPTextModel, 116 | tokenizer: CLIPTokenizer, 117 | unet: UNet2DConditionModel, 118 | scheduler: KarrasDiffusionSchedulers, 119 | safety_checker: StableDiffusionSafetyChecker, 120 | feature_extractor: CLIPImageProcessor, 121 | requires_safety_checker: bool = True, 122 | ): 123 | super().__init__( 124 | vae, 125 | text_encoder, 126 | tokenizer, 127 | unet, 128 | scheduler, 129 | safety_checker, 130 | feature_extractor, 131 | requires_safety_checker, 132 | ) 133 | 134 | def _encode_prompt( 135 | self, 136 | prompt, 137 | device, 138 | num_images_per_prompt, 139 | do_classifier_free_guidance, 140 | negative_prompt=None, 141 | prompt_embeds: Optional[torch.FloatTensor] = None, 142 | negative_prompt_embeds: Optional[torch.FloatTensor] = None, 143 | ): 144 | r""" 145 | Encodes the prompt into text encoder hidden states. 146 | 147 | Args: 148 | prompt (`str` or `List[str]`, *optional*): 149 | prompt to be encoded 150 | device: (`torch.device`): 151 | torch device 152 | num_images_per_prompt (`int`): 153 | number of images that should be generated per prompt 154 | do_classifier_free_guidance (`bool`): 155 | whether to use classifier free guidance or not 156 | negative_prompt (`str` or `List[str]`, *optional*): 157 | The prompt or prompts not to guide the image generation. If not defined, one has to pass 158 | `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is 159 | less than `1`). 160 | prompt_embeds (`torch.FloatTensor`, *optional*): 161 | Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not 162 | provided, text embeddings will be generated from `prompt` input argument. 163 | negative_prompt_embeds (`torch.FloatTensor`, *optional*): 164 | Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt 165 | weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input 166 | argument. 
167 | """ 168 | if prompt is not None and isinstance(prompt, str): 169 | batch_size = 1 170 | elif prompt is not None and isinstance(prompt, list): 171 | batch_size = len(prompt) 172 | else: 173 | batch_size = prompt_embeds.shape[0] 174 | 175 | if prompt_embeds is None: 176 | # textual inversion: procecss multi-vector tokens if necessary 177 | if isinstance(self, TextualInversionLoaderMixin): 178 | prompt = self.maybe_convert_prompt(prompt, self.tokenizer) 179 | 180 | text_inputs = self.tokenizer( 181 | prompt, 182 | padding="max_length", 183 | max_length=self.tokenizer.model_max_length, 184 | truncation=True, 185 | return_tensors="pt", 186 | ) 187 | text_input_ids = text_inputs.input_ids 188 | untruncated_ids = self.tokenizer( 189 | prompt, padding="longest", return_tensors="pt" 190 | ).input_ids 191 | 192 | if untruncated_ids.shape[-1] >= text_input_ids.shape[ 193 | -1 194 | ] and not torch.equal(text_input_ids, untruncated_ids): 195 | removed_text = self.tokenizer.batch_decode( 196 | untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] 197 | ) 198 | logger.warning( 199 | "The following part of your input was truncated because CLIP can only handle sequences up to" 200 | f" {self.tokenizer.model_max_length} tokens: {removed_text}" 201 | ) 202 | 203 | if ( 204 | hasattr(self.text_encoder.config, "use_attention_mask") 205 | and self.text_encoder.config.use_attention_mask 206 | ): 207 | attention_mask = text_inputs.attention_mask.to(device) 208 | else: 209 | attention_mask = None 210 | 211 | prompt_embeds = self.text_encoder( 212 | text_input_ids.to(device), 213 | attention_mask=attention_mask, 214 | ) 215 | prompt_embeds = prompt_embeds[0] 216 | 217 | prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) 218 | 219 | bs_embed, seq_len, _ = prompt_embeds.shape 220 | # duplicate text embeddings for each generation per prompt, using mps friendly method 221 | prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) 222 | prompt_embeds = prompt_embeds.view( 223 | bs_embed * num_images_per_prompt, seq_len, -1 224 | ) 225 | 226 | # get unconditional embeddings for classifier free guidance 227 | if do_classifier_free_guidance and negative_prompt_embeds is None: 228 | uncond_tokens: List[str] 229 | if negative_prompt is None: 230 | uncond_tokens = [""] * batch_size 231 | elif type(prompt) is not type(negative_prompt): 232 | raise TypeError( 233 | f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" 234 | f" {type(prompt)}." 235 | ) 236 | elif isinstance(negative_prompt, str): 237 | uncond_tokens = [negative_prompt] 238 | elif batch_size != len(negative_prompt): 239 | raise ValueError( 240 | f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" 241 | f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" 242 | " the batch size of `prompt`." 
243 | ) 244 | else: 245 | uncond_tokens = negative_prompt 246 | 247 | # textual inversion: procecss multi-vector tokens if necessary 248 | if isinstance(self, TextualInversionLoaderMixin): 249 | uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) 250 | 251 | max_length = prompt_embeds.shape[1] 252 | uncond_input = self.tokenizer( 253 | uncond_tokens, 254 | padding="max_length", 255 | max_length=max_length, 256 | truncation=True, 257 | return_tensors="pt", 258 | ) 259 | 260 | if ( 261 | hasattr(self.text_encoder.config, "use_attention_mask") 262 | and self.text_encoder.config.use_attention_mask 263 | ): 264 | attention_mask = uncond_input.attention_mask.to(device) 265 | else: 266 | attention_mask = None 267 | 268 | negative_prompt_embeds = self.text_encoder( 269 | uncond_input.input_ids.to(device), 270 | attention_mask=attention_mask, 271 | ) 272 | negative_prompt_embeds = negative_prompt_embeds[0] 273 | 274 | if do_classifier_free_guidance: 275 | # duplicate unconditional embeddings for each generation per prompt, using mps friendly method 276 | seq_len = negative_prompt_embeds.shape[1] 277 | 278 | negative_prompt_embeds = negative_prompt_embeds.to( 279 | dtype=self.text_encoder.dtype, device=device 280 | ) 281 | 282 | negative_prompt_embeds = negative_prompt_embeds.repeat( 283 | 1, num_images_per_prompt, 1 284 | ) 285 | negative_prompt_embeds = negative_prompt_embeds.view( 286 | batch_size * num_images_per_prompt, seq_len, -1 287 | ) 288 | 289 | # For classifier free guidance, we need to do two forward passes. 290 | # Here we concatenate the unconditional and text embeddings into a single batch 291 | # to avoid doing two forward passes 292 | final_prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) 293 | 294 | return final_prompt_embeds, prompt_embeds 295 | 296 | def check_inputs( 297 | self, 298 | prompt, 299 | token_indices, 300 | bboxes, 301 | height, 302 | width, 303 | callback_steps, 304 | negative_prompt=None, 305 | prompt_embeds=None, 306 | negative_prompt_embeds=None, 307 | ): 308 | if height % 8 != 0 or width % 8 != 0: 309 | raise ValueError( 310 | f"`height` and `width` have to be divisible by 8 but are {height} and {width}." 311 | ) 312 | 313 | if (callback_steps is None) or ( 314 | callback_steps is not None 315 | and (not isinstance(callback_steps, int) or callback_steps <= 0) 316 | ): 317 | raise ValueError( 318 | f"`callback_steps` has to be a positive integer but is {callback_steps} of type" 319 | f" {type(callback_steps)}." 320 | ) 321 | 322 | if prompt is not None and prompt_embeds is not None: 323 | raise ValueError( 324 | f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" 325 | " only forward one of the two." 326 | ) 327 | elif prompt is None and prompt_embeds is None: 328 | raise ValueError( 329 | "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." 330 | ) 331 | elif prompt is not None and ( 332 | not isinstance(prompt, str) and not isinstance(prompt, list) 333 | ): 334 | raise ValueError( 335 | f"`prompt` has to be of type `str` or `list` but is {type(prompt)}" 336 | ) 337 | 338 | if negative_prompt is not None and negative_prompt_embeds is not None: 339 | raise ValueError( 340 | f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" 341 | f" {negative_prompt_embeds}. Please make sure to only forward one of the two." 
342 | ) 343 | 344 | if prompt_embeds is not None and negative_prompt_embeds is not None: 345 | if prompt_embeds.shape != negative_prompt_embeds.shape: 346 | raise ValueError( 347 | "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" 348 | f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" 349 | f" {negative_prompt_embeds.shape}." 350 | ) 351 | 352 | if token_indices is not None: 353 | if isinstance(token_indices, list): 354 | if isinstance(token_indices[0], list): 355 | if isinstance(token_indices[0][0], list): 356 | token_indices_batch_size = len(token_indices) 357 | elif isinstance(token_indices[0][0], int): 358 | token_indices_batch_size = 1 359 | else: 360 | raise TypeError( 361 | "`token_indices` must be a list of lists of integers or a list of integers." 362 | ) 363 | else: 364 | raise TypeError( 365 | "`token_indices` must be a list of lists of integers or a list of integers." 366 | ) 367 | else: 368 | raise TypeError( 369 | "`token_indices` must be a list of lists of integers or a list of integers." 370 | ) 371 | 372 | if bboxes is not None: 373 | if isinstance(bboxes, list): 374 | if isinstance(bboxes[0], list): 375 | if ( 376 | isinstance(bboxes[0][0], list) 377 | and len(bboxes[0][0]) == 4 378 | and all(isinstance(x, float) for x in bboxes[0][0]) 379 | ): 380 | bboxes_batch_size = len(bboxes) 381 | elif ( 382 | isinstance(bboxes[0], list) 383 | and len(bboxes[0]) == 4 384 | and all(isinstance(x, float) for x in bboxes[0]) 385 | ): 386 | bboxes_batch_size = 1 387 | else: 388 | print(isinstance(bboxes[0], list), len(bboxes[0])) 389 | raise TypeError( 390 | "`bboxes` must be a list of lists of list with four floats or a list of tuples with four floats." 391 | ) 392 | else: 393 | print(isinstance(bboxes[0], list), len(bboxes[0])) 394 | raise TypeError( 395 | "`bboxes` must be a list of lists of list with four floats or a list of tuples with four floats." 396 | ) 397 | else: 398 | print(isinstance(bboxes[0], list), len(bboxes[0])) 399 | raise TypeError( 400 | "`bboxes` must be a list of lists of list with four floats or a list of tuples with four floats." 401 | ) 402 | 403 | if prompt is not None and isinstance(prompt, str): 404 | prompt_batch_size = 1 405 | elif prompt is not None and isinstance(prompt, list): 406 | prompt_batch_size = len(prompt) 407 | elif prompt_embeds is not None: 408 | prompt_batch_size = prompt_embeds.shape[0] 409 | 410 | if token_indices_batch_size != prompt_batch_size: 411 | raise ValueError( 412 | f"token indices batch size must be same as prompt batch size. token indices batch size: {token_indices_batch_size}, prompt batch size: {prompt_batch_size}" 413 | ) 414 | 415 | if bboxes_batch_size != prompt_batch_size: 416 | raise ValueError( 417 | f"bbox batch size must be same as prompt batch size. 
bbox batch size: {bboxes_batch_size}, prompt batch size: {prompt_batch_size}" 418 | ) 419 | 420 | def _compute_loss(self, token_indices, bboxes, device) -> torch.Tensor: 421 | loss = 0 422 | object_number = len(bboxes) 423 | total_maps = 0 424 | for location in ["mid", "up"]: 425 | for attn_map_integrated in self.attention_store.maps(location): 426 | attn_map = attn_map_integrated.chunk(2)[1] 427 | 428 | b, i, j = attn_map.shape 429 | H = W = int(np.sqrt(i)) 430 | 431 | total_maps += 1 432 | for obj_idx in range(object_number): 433 | obj_loss = 0 434 | obj_box = bboxes[obj_idx] 435 | 436 | x_min, y_min, x_max, y_max = ( 437 | obj_box[0] * W, 438 | obj_box[1] * H, 439 | obj_box[2] * W, 440 | obj_box[3] * H, 441 | ) 442 | mask = torch.zeros((H, W), device=device) 443 | mask[round(y_min) : round(y_max), round(x_min) : round(x_max)] = 1 444 | 445 | for obj_position in token_indices[obj_idx]: 446 | ca_map_obj = attn_map[:, :, obj_position].reshape(b, H, W) 447 | activation_value = (ca_map_obj * mask).reshape(b, -1).sum( 448 | dim=-1 449 | ) / ca_map_obj.reshape(b, -1).sum(dim=-1) 450 | 451 | obj_loss += torch.mean((1 - activation_value) ** 2) 452 | 453 | loss += obj_loss / len(token_indices[obj_idx]) 454 | 455 | loss /= object_number * total_maps 456 | return loss 457 | 458 | def get_indices(self, prompt: str) -> Dict[str, int]: 459 | """Utility function to list the indices of the tokens you wish to alte""" 460 | ids = self.tokenizer(prompt).input_ids 461 | indices = { 462 | i: tok 463 | for tok, i in zip( 464 | self.tokenizer.convert_ids_to_tokens(ids), range(len(ids)) 465 | ) 466 | } 467 | return indices 468 | 469 | @staticmethod 470 | def draw_box(pil_img: Image, bboxes: List[List[float]]) -> Image: 471 | """Utility function to draw bbox on the image""" 472 | width, height = pil_img.size 473 | draw = ImageDraw.Draw(pil_img) 474 | 475 | for obj_box in bboxes: 476 | x_min, y_min, x_max, y_max = ( 477 | obj_box[0] * width, 478 | obj_box[1] * height, 479 | obj_box[2] * width, 480 | obj_box[3] * height, 481 | ) 482 | draw.rectangle( 483 | [int(x_min), int(y_min), int(x_max), int(y_max)], 484 | outline="red", 485 | width=4, 486 | ) 487 | 488 | return pil_img 489 | 490 | @torch.no_grad() 491 | def __call__( 492 | self, 493 | prompt: Union[str, List[str]] = None, 494 | token_indices: Union[List[List[List[int]]], List[List[int]]] = None, 495 | bboxes: Union[ 496 | List[List[List[float]]], 497 | List[List[float]], 498 | ] = None, 499 | height: Optional[int] = None, 500 | width: Optional[int] = None, 501 | num_inference_steps: int = 50, 502 | guidance_scale: float = 7.5, 503 | negative_prompt: Optional[Union[str, List[str]]] = None, 504 | num_images_per_prompt: Optional[int] = 1, 505 | eta: float = 0.0, 506 | generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, 507 | latents: Optional[torch.FloatTensor] = None, 508 | prompt_embeds: Optional[torch.FloatTensor] = None, 509 | negative_prompt_embeds: Optional[torch.FloatTensor] = None, 510 | output_type: Optional[str] = "pil", 511 | return_dict: bool = True, 512 | callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, 513 | callback_steps: int = 1, 514 | cross_attention_kwargs: Optional[Dict[str, Any]] = None, 515 | max_guidance_iter: int = 10, 516 | max_guidance_iter_per_step: int = 5, 517 | scale_factor: int = 50, 518 | ): 519 | r""" 520 | Function invoked when calling the pipeline for generation. 
521 | 522 | Args: 523 | prompt (`str` or `List[str]`, *optional*): 524 | The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. 525 | instead. 526 | token_indices (Union[List[List[List[int]]], List[List[int]]], optional): 527 | The list of the indexes in the prompt to layout. Defaults to None. 528 | bboxes (Union[List[List[List[float]]], List[List[float]]], optional): 529 | The bounding boxes of the indexes to maintain layout in the image. Defaults to None. 530 | height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): 531 | The height in pixels of the generated image. 532 | width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): 533 | The width in pixels of the generated image. 534 | num_inference_steps (`int`, *optional*, defaults to 50): 535 | The number of denoising steps. More denoising steps usually lead to a higher quality image at the 536 | expense of slower inference. 537 | guidance_scale (`float`, *optional*, defaults to 7.5): 538 | Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). 539 | `guidance_scale` is defined as `w` of equation 2. of [Imagen 540 | Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > 541 | 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, 542 | usually at the expense of lower image quality. 543 | negative_prompt (`str` or `List[str]`, *optional*): 544 | The prompt or prompts not to guide the image generation. If not defined, one has to pass 545 | `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is 546 | less than `1`). 547 | num_images_per_prompt (`int`, *optional*, defaults to 1): 548 | The number of images to generate per prompt. 549 | eta (`float`, *optional*, defaults to 0.0): 550 | Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to 551 | [`schedulers.DDIMScheduler`], will be ignored for others. 552 | generator (`torch.Generator` or `List[torch.Generator]`, *optional*): 553 | One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) 554 | to make generation deterministic. 555 | latents (`torch.FloatTensor`, *optional*): 556 | Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image 557 | generation. Can be used to tweak the same generation with different prompts. If not provided, a latents 558 | tensor will ge generated by sampling using the supplied random `generator`. 559 | prompt_embeds (`torch.FloatTensor`, *optional*): 560 | Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not 561 | provided, text embeddings will be generated from `prompt` input argument. 562 | negative_prompt_embeds (`torch.FloatTensor`, *optional*): 563 | Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt 564 | weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input 565 | argument. 566 | output_type (`str`, *optional*, defaults to `"pil"`): 567 | The output format of the generate image. Choose between 568 | [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. 
569 | return_dict (`bool`, *optional*, defaults to `True`): 570 | Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a 571 | plain tuple. 572 | callback (`Callable`, *optional*): 573 | A function that will be called every `callback_steps` steps during inference. The function will be 574 | called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. 575 | callback_steps (`int`, *optional*, defaults to 1): 576 | The frequency at which the `callback` function will be called. If not specified, the callback will be 577 | called at every step. 578 | cross_attention_kwargs (`dict`, *optional*): 579 | A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under 580 | `self.processor` in 581 | [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). 582 | max_guidance_iter (`int`, *optional*, defaults to `10`): 583 | The maximum number of iterations for the layout guidance on attention maps in diffusion mode. 584 | max_guidance_iter_per_step (`int`, *optional*, defaults to `5`): 585 | The maximum number of iterations to run during each time step for layout guidance. 586 | scale_factor (`int`, *optional*, defaults to `50`): 587 | The scale factor used to update the latents during optimization. 588 | 589 | Examples: 590 | 591 | Returns: 592 | [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: 593 | [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. 594 | When returning a tuple, the first element is a list with the generated images, and the second element is a 595 | list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" 596 | (nsfw) content, according to the `safety_checker`. 597 | """ 598 | # 0. Default height and width to unet 599 | height = height or self.unet.config.sample_size * self.vae_scale_factor 600 | width = width or self.unet.config.sample_size * self.vae_scale_factor 601 | 602 | # 1. Check inputs. Raise error if not correct 603 | self.check_inputs( 604 | prompt, 605 | token_indices, 606 | bboxes, 607 | height, 608 | width, 609 | callback_steps, 610 | negative_prompt, 611 | prompt_embeds, 612 | negative_prompt_embeds, 613 | ) 614 | 615 | # 2. Define call parameters 616 | if prompt is not None and isinstance(prompt, str): 617 | batch_size = 1 618 | elif prompt is not None and isinstance(prompt, list): 619 | batch_size = len(prompt) 620 | else: 621 | batch_size = prompt_embeds.shape[0] 622 | 623 | device = self._execution_device 624 | # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) 625 | # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` 626 | # corresponds to doing no classifier free guidance. 627 | do_classifier_free_guidance = guidance_scale > 1.0 628 | 629 | # 3. Encode input prompt 630 | prompt_embeds, cond_prompt_embeds = self._encode_prompt( 631 | prompt, 632 | device, 633 | num_images_per_prompt, 634 | do_classifier_free_guidance, 635 | negative_prompt, 636 | prompt_embeds=prompt_embeds, 637 | negative_prompt_embeds=negative_prompt_embeds, 638 | ) 639 | 640 | # 4. Prepare timesteps 641 | self.scheduler.set_timesteps(num_inference_steps, device=device) 642 | timesteps = self.scheduler.timesteps 643 | 644 | # 5. 
Prepare latent variables 645 | num_channels_latents = self.unet.config.in_channels 646 | latents = self.prepare_latents( 647 | batch_size * num_images_per_prompt, 648 | num_channels_latents, 649 | height, 650 | width, 651 | prompt_embeds.dtype, 652 | device, 653 | generator, 654 | latents, 655 | ) 656 | 657 | # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline 658 | extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) 659 | 660 | self.attention_store = AttentionStore() 661 | self.register_attention_control() 662 | 663 | # 7. Denoising loop 664 | num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order 665 | 666 | loss = torch.tensor(10000) 667 | 668 | with self.progress_bar(total=num_inference_steps) as progress_bar: 669 | for i, t in enumerate(timesteps): 670 | # Layout guidance loss optimization loop 671 | if i < max_guidance_iter: 672 | with torch.enable_grad(): 673 | latents = latents.clone().detach().requires_grad_(True) 674 | 675 | for guidance_iter in range(max_guidance_iter_per_step): 676 | if loss.item() / scale_factor < 0.2: 677 | break 678 | 679 | latent_model_input = self.scheduler.scale_model_input( 680 | latents, t 681 | ) 682 | self.unet( 683 | latent_model_input, 684 | t, 685 | encoder_hidden_states=cond_prompt_embeds, 686 | cross_attention_kwargs=cross_attention_kwargs, 687 | ) 688 | self.unet.zero_grad() 689 | 690 | loss = ( 691 | self._compute_loss(token_indices, bboxes, device) 692 | * scale_factor 693 | ) 694 | grad_cond = torch.autograd.grad( 695 | loss.requires_grad_(True), 696 | [latents], 697 | retain_graph=True, 698 | )[0] 699 | latents = ( 700 | latents - grad_cond * self.scheduler.sigmas[i] ** 2 701 | ) 702 | 703 | # expand the latents if we are doing classifier free guidance 704 | latent_model_input = ( 705 | torch.cat([latents] * 2) if do_classifier_free_guidance else latents 706 | ) 707 | latent_model_input = self.scheduler.scale_model_input( 708 | latent_model_input, t 709 | ) 710 | 711 | # predict the noise residual 712 | noise_pred = self.unet( 713 | latent_model_input, 714 | t, 715 | encoder_hidden_states=prompt_embeds, 716 | cross_attention_kwargs=cross_attention_kwargs, 717 | ).sample 718 | 719 | # perform guidance 720 | if do_classifier_free_guidance: 721 | noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) 722 | noise_pred = noise_pred_uncond + guidance_scale * ( 723 | noise_pred_text - noise_pred_uncond 724 | ) 725 | 726 | # compute the previous noisy sample x_t -> x_t-1 727 | latents = self.scheduler.step( 728 | noise_pred, t, latents, **extra_step_kwargs 729 | ).prev_sample 730 | 731 | # call the callback, if provided 732 | if i == len(timesteps) - 1 or ( 733 | (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0 734 | ): 735 | progress_bar.update() 736 | if callback is not None and i % callback_steps == 0: 737 | callback(i, t, latents) 738 | 739 | if output_type == "latent": 740 | image = latents 741 | has_nsfw_concept = None 742 | elif output_type == "pil": 743 | # 8. Post-processing 744 | image = self.decode_latents(latents) 745 | 746 | # 9. Run safety checker 747 | image, has_nsfw_concept = self.run_safety_checker( 748 | image, device, prompt_embeds.dtype 749 | ) 750 | 751 | # 10. Convert to PIL 752 | image = self.numpy_to_pil(image) 753 | else: 754 | # 8. Post-processing 755 | image = self.decode_latents(latents) 756 | 757 | # 9. 
Run safety checker 758 | image, has_nsfw_concept = self.run_safety_checker( 759 | image, device, prompt_embeds.dtype 760 | ) 761 | 762 | # Offload last model to CPU 763 | if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: 764 | self.final_offload_hook.offload() 765 | 766 | if not return_dict: 767 | return (image, has_nsfw_concept) 768 | 769 | return StableDiffusionPipelineOutput( 770 | images=image, nsfw_content_detected=has_nsfw_concept 771 | ) 772 | --------------------------------------------------------------------------------