├── README.md
├── config.py
├── docs
│   ├── bbox_as_condition.json
│   ├── boxdiff.gif
│   ├── boxdiff.png
│   ├── example.png
│   └── visorgpt.gif
├── eval.py
├── pipeline
│   ├── gligen_pipeline_boxdiff.py
│   └── sd_pipeline_boxdiff.py
├── requirements.txt
├── run_gligen_boxdiff.py
├── run_sd_boxdiff.py
└── utils
    ├── __init__.py
    ├── drawer.py
    ├── gaussian_smoothing.py
    ├── ptp_utils.py
    └── vis_utils.py

/README.md:
--------------------------------------------------------------------------------
1 | 
2 |

BoxDiff 🎨 (ICCV 2023)

3 |

BoxDiff: Text-to-Image Synthesis with Training-Free Box-Constrained Diffusion

4 | 
5 | [Jinheng Xie](https://sierkinhane.github.io/)1  Yuexiang Li2  Yawen Huang2  Haozhe Liu2,3  Wentian Zhang2  Yefeng Zheng2  [Mike Zheng Shou](https://scholar.google.com/citations?hl=zh-CN&user=h1-3lSoAAAAJ&view_op=list_works&sortby=pubdate)1
6 | 
7 | 1 National University of Singapore  2 Tencent Jarvis Lab  3 KAUST
8 | 
9 | [![arXiv](https://img.shields.io/badge/arXiv-<2307.10816>-.svg)](https://arxiv.org/abs/2307.10816)
10 | 
11 | 
12 | 
13 | 
14 | 
15 | ### BoxDiff-SD-XL
16 | [A BoxDiff implementation based on SD-XL](https://github.com/Cominclip/BoxDiff-XL)
17 | 
18 | ### Integration in diffusers
19 | Thanks to [@zjysteven](https://github.com/zjysteven) for his efforts. Below is an example using `stable-diffusion-2-1-base`.
20 | ```
21 | import torch
22 | from PIL import Image, ImageDraw
23 | from copy import deepcopy
24 | 
25 | from examples.community.pipeline_stable_diffusion_boxdiff import StableDiffusionBoxDiffPipeline
26 | 
27 | def draw_box_with_text(img, boxes, names):
28 |     colors = ["red", "olive", "blue", "green", "orange", "brown", "cyan", "purple"]
29 |     img_new = deepcopy(img)
30 |     draw = ImageDraw.Draw(img_new)
31 | 
32 |     W, H = img.size
33 |     for bid, box in enumerate(boxes):
34 |         draw.rectangle([box[0] * W, box[1] * H, box[2] * W, box[3] * H], outline=colors[bid % len(colors)], width=4)
35 |         draw.text((box[0] * W, box[1] * H), names[bid], fill=colors[bid % len(colors)])
36 |     return img_new
37 | 
38 | pipe = StableDiffusionBoxDiffPipeline.from_pretrained(
39 |     "stabilityai/stable-diffusion-2-1-base",
40 |     torch_dtype=torch.float16,
41 | )
42 | pipe.to("cuda")
43 | 
44 | # example 1
45 | prompt = "as the aurora lights up the sky, a herd of reindeer leisurely wanders on the grassy meadow, admiring the breathtaking view, a serene lake quietly reflects the magnificent display, and in the distance, a snow-capped mountain stands majestically, fantasy, 8k, highly detailed"
46 | phrases = [
47 |     "aurora",
48 |     "reindeer",
49 |     "meadow",
50 |     "lake",
51 |     "mountain"
52 | ]
53 | boxes = [[1,3,512,202], [75,344,421,495], [1,327,508,507], [2,217,507,341], [1,135,509,242]]
54 | 
55 | # example 2
56 | # prompt = "A rabbit wearing sunglasses looks very proud"
57 | # phrases = ["rabbit", "sunglasses"]
58 | # boxes = [[67,87,366,512], [66,130,364,262]]
59 | 
60 | boxes = [[x / 512 for x in box] for box in boxes]
61 | 
62 | images = pipe(
63 |     prompt,
64 |     boxdiff_phrases=phrases,
65 |     boxdiff_boxes=boxes,
66 |     boxdiff_kwargs={
67 |         "attention_res": 16,
68 |         "normalize_eot": True
69 |     },
70 |     num_inference_steps=50,
71 |     guidance_scale=7.5,
72 |     generator=torch.manual_seed(42),
73 |     safety_checker=None
74 | ).images
75 | 
76 | draw_box_with_text(images[0], boxes, phrases).save("output.png")
77 | ```
78 | 
79 | ### Setup
80 | Note that we have only tested the code with PyTorch==1.12.0. You can build the environment via `pip` as follows:
81 | ```
82 | pip3 install -r requirements.txt
83 | ```
84 | To apply BoxDiff to the GLIGEN pipeline, please install diffusers as follows (run `pip3 install -e .` from inside the cloned `diffusers` directory):
85 | ```
86 | git clone git@github.com:gligen/diffusers.git
87 | pip3 install -e .
88 | ```
89 | 
90 | ### Usage
91 | To add spatial control to the Stable Diffusion model, you can simply use `run_sd_boxdiff.py`. For example:
92 | ```
93 | CUDA_VISIBLE_DEVICES=0 python3 run_sd_boxdiff.py --prompt "as the aurora lights up the sky, a herd of reindeer leisurely wanders on the grassy meadow, admiring the breathtaking view, a serene lake quietly reflects the magnificent display, and in the distance, a snow-capped mountain stands majestically, fantasy, 8k, highly detailed" --P 0.2 --L 1 --seeds [1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,21,22,23,24,25,26,27,28,29,30] --token_indices [3,12,21,30,46] --bbox [[1,3,512,202],[75,344,421,495],[1,327,508,507],[2,217,507,341],[1,135,509,242]] --refine False
94 | ```
95 | or another example:
96 | ```
97 | CUDA_VISIBLE_DEVICES=0 python3 run_sd_boxdiff.py --prompt "A rabbit wearing sunglasses looks very proud" --P 0.2 --L 1 --seeds [1,2,3,4,5,6,7,8,9] --token_indices [2,4] --bbox [[67,87,366,512],[66,130,364,262]]
98 | ```
99 | Note that the token indices are the positions of the words you want to control in the text prompt, and each token index has one corresponding conditioning box, given in pixel coordinates `[x1,y1,x2,y2]` on the 512x512 canvas. `P` and `L` are hyper-parameters of the proposed constraints.
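For reference, a minimal sketch (mirroring `get_indices_to_alter` in `eval.py`) for listing the token index of each word before choosing `--token_indices`; it assumes the standard `openai/clip-vit-large-patch14` tokenizer used by Stable Diffusion v1.x:
```
from transformers import CLIPTokenizer

# Print every prompt token with its index so you can pick --token_indices.
# "openai/clip-vit-large-patch14" is the SD v1.x default; adjust for other checkpoints.
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")

prompt = "A rabbit wearing sunglasses looks very proud"
input_ids = tokenizer(prompt)["input_ids"]
for idx, token_id in enumerate(input_ids):
    if 0 < idx < len(input_ids) - 1:  # skip the start-of-text and end-of-text tokens
        print(idx, tokenizer.decode(token_id))
# index 2 -> "rabbit" and index 4 -> "sunglasses", matching --token_indices [2,4] above
```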
100 | 
101 | When `--bbox` is not specified, an interface opens for drawing bounding boxes as conditions.
102 | ```
103 | CUDA_VISIBLE_DEVICES=0 python3 run_sd_boxdiff.py --prompt "A rabbit wearing sunglasses looks very proud" --P 0.2 --L 1 --seeds [1,2,3,4,5,6,7,8,9] --token_indices [2,4]
104 | ```
105 | 
106 | To add spatial control to the GLIGEN model, you can simply use `run_gligen_boxdiff.py`. For example:
107 | ```
108 | CUDA_VISIBLE_DEVICES=0 python3 run_gligen_boxdiff.py --prompt "A rabbit wearing sunglasses looks very proud" --gligen_phrases ["a rabbit","sunglasses"] --P 0.2 --L 1 --seeds [1,2,3,4,5,6,7,8,9] --token_indices [2,4] --bbox [[67,87,366,512],[66,130,364,262]] --refine False
109 | ```
110 | 
111 | The directory structure of the synthesized results is as follows:
112 | ```
113 | outputs/
114 | |-- text prompt/
115 | |   |-- 0.png
116 | |   |-- 0_canvas.png
117 | |   |-- 1.png
118 | |   |-- 1_canvas.png
119 | |   |-- ...
120 | ```
121 | ![](docs/example.png)
122 | 
123 | ### Customize Your Layout
124 | [VisorGPT](https://github.com/Sierkinhane/VisorGPT) can customize layouts as spatial conditions for image synthesis using BoxDiff.
125 | 
126 | ### Citation
127 | ```
128 | @InProceedings{Xie_2023_ICCV,
129 |     author    = {Xie, Jinheng and Li, Yuexiang and Huang, Yawen and Liu, Haozhe and Zhang, Wentian and Zheng, Yefeng and Shou, Mike Zheng},
130 |     title     = {BoxDiff: Text-to-Image Synthesis with Training-Free Box-Constrained Diffusion},
131 |     booktitle = {Proceedings of the IEEE/CVF International Conference on Computer Vision (ICCV)},
132 |     year      = {2023},
133 |     pages     = {7452-7461}
134 | }
135 | ```
136 | 
137 | Acknowledgment - the code is heavily based on the [diffusers](https://github.com/huggingface/diffusers), [prompt-to-prompt](https://github.com/google/prompt-to-prompt), and [yuval-alaluf](https://github.com/yuval-alaluf) repositories.
138 | -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, field 2 | from pathlib import Path 3 | from typing import Dict, List 4 | 5 | 6 | @dataclass 7 | class RunConfig: 8 | # Guiding text prompt 9 | prompt: str 10 | # Whether to use Stable Diffusion v2.1 11 | sd_2_1: bool = False 12 | # Which token indices to alter with attend-and-excite 13 | token_indices: List[int] = None 14 | # Which random seeds to use when generating 15 | seeds: List[int] = field(default_factory=lambda: [42]) 16 | # Path to save all outputs to 17 | output_path: Path = Path('./outputs') 18 | # Number of denoising steps 19 | n_inference_steps: int = 50 20 | # Text guidance scale 21 | guidance_scale: float = 7.5 22 | # Number of denoising steps to apply attend-and-excite 23 | max_iter_to_alter: int = 25 24 | # Resolution of UNet to compute attention maps over 25 | attention_res: int = 16 26 | # Whether to run standard SD or attend-and-excite 27 | run_standard_sd: bool = False 28 | # Dictionary defining the iterations and desired thresholds to apply iterative latent refinement in 29 | thresholds: Dict[int, float] = field(default_factory=lambda: {0: 0.05, 10: 0.5, 20: 0.8}) 30 | # Scale factor for updating the denoised latent z_t 31 | scale_factor: int = 20 32 | # Start and end values used for scaling the scale factor - decays linearly with the denoising timestep 33 | scale_range: tuple = field(default_factory=lambda: (1.0, 0.5)) 34 | # Whether to apply the Gaussian smoothing before computing the maximum attention value for each subject token 35 | smooth_attentions: bool = True 36 | # Standard deviation for the Gaussian smoothing 37 | sigma: float = 0.5 38 | # Kernel size for the Gaussian smoothing 39 | kernel_size: int = 3 40 | # Whether to save cross attention maps for the final results 41 | save_cross_attention_maps: bool = False 42 | 43 | # BoxDiff 44 | bbox: List[list] = field(default_factory=lambda: [[], []]) 45 | color: List[str] = field(default_factory=lambda: ['blue', 'red', 'purple', 'orange', 'green', 'yellow', 'black']) 46 | P: float = 0.2 47 | # number of pixels around the corner to be selected 48 | L: int = 1 49 | refine: bool = True 50 | gligen_phrases: List[str] = field(default_factory=lambda: ['', '']) 51 | n_splits: int = 4 52 | which_one: int = 1 53 | eval_output_path: Path = Path('./outputs/eval') 54 | 55 | 56 | def __post_init__(self): 57 | self.output_path.mkdir(exist_ok=True, parents=True) 58 | -------------------------------------------------------------------------------- /docs/boxdiff.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/showlab/BoxDiff/b0d5d3b534418aa3fc71b9a16e5b575c0b2ee3b6/docs/boxdiff.gif -------------------------------------------------------------------------------- /docs/boxdiff.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/showlab/BoxDiff/b0d5d3b534418aa3fc71b9a16e5b575c0b2ee3b6/docs/boxdiff.png -------------------------------------------------------------------------------- /docs/example.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/showlab/BoxDiff/b0d5d3b534418aa3fc71b9a16e5b575c0b2ee3b6/docs/example.png -------------------------------------------------------------------------------- 
/docs/visorgpt.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/showlab/BoxDiff/b0d5d3b534418aa3fc71b9a16e5b575c0b2ee3b6/docs/visorgpt.gif
--------------------------------------------------------------------------------
/eval.py:
--------------------------------------------------------------------------------
1 | 
2 | import pprint
3 | from typing import List
4 | 
5 | import pyrallis
6 | import torch
7 | from PIL import Image
8 | from config import RunConfig
9 | from pipeline.gligen_pipeline_boxdiff import BoxDiffPipeline
10 | from utils import ptp_utils, vis_utils
11 | from utils.ptp_utils import AttentionStore
12 | 
13 | import numpy as np
14 | from utils.drawer import draw_rectangle, DashedImageDraw
15 | 
16 | import warnings
17 | import json, os
18 | from tqdm import tqdm
19 | warnings.filterwarnings("ignore", category=UserWarning)
20 | 
21 | 
22 | def load_model():
23 |     device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
24 |     stable_diffusion_version = "gligen/diffusers-generation-text-box"
25 |     # If you cannot access Hugging Face from your server, you can use a locally prepared copy instead.
26 |     # stable_diffusion_version = "../../packages/diffusers/gligen_ckpts/diffusers-generation-text-box"
27 |     stable = BoxDiffPipeline.from_pretrained(stable_diffusion_version, revision="fp16", torch_dtype=torch.float16).to(device)
28 | 
29 |     return stable
30 | 
31 | 
32 | def get_indices_to_alter(stable, prompt: str) -> List[int]:
33 |     token_idx_to_word = {idx: stable.tokenizer.decode(t)
34 |                          for idx, t in enumerate(stable.tokenizer(prompt)['input_ids'])
35 |                          if 0 < idx < len(stable.tokenizer(prompt)['input_ids']) - 1}
36 |     pprint.pprint(token_idx_to_word)
37 |     token_indices = input("Please enter a comma-separated list of indices of the tokens you wish to "
38 |                           "alter (e.g., 2,5): ")
39 |     token_indices = [int(i) for i in token_indices.split(",")]
40 |     print(f"Altering tokens: {[token_idx_to_word[i] for i in token_indices]}")
41 |     return token_indices
42 | 
43 | 
44 | def run_on_prompt(prompt: List[str],
45 |                   model: BoxDiffPipeline,
46 |                   controller: AttentionStore,
47 |                   token_indices: List[int],
48 |                   seed: torch.Generator,
49 |                   config: RunConfig) -> Image.Image:
50 |     if controller is not None:
51 |         ptp_utils.register_attention_control(model, controller)
52 | 
53 |     gligen_boxes = []
54 |     for i in range(len(config.bbox)):
55 |         x1, y1, x2, y2 = config.bbox[i]
56 |         gligen_boxes.append([x1/512, y1/512, x2/512, y2/512])
57 | 
58 |     outputs = model(prompt=prompt,
59 |                     attention_store=controller,
60 |                     indices_to_alter=token_indices,
61 |                     attention_res=config.attention_res,
62 |                     guidance_scale=config.guidance_scale,
63 |                     gligen_phrases=config.gligen_phrases,
64 |                     gligen_boxes=gligen_boxes,
65 |                     gligen_scheduled_sampling_beta=0.3,
66 |                     generator=seed,
67 |                     num_inference_steps=config.n_inference_steps,
68 |                     max_iter_to_alter=config.max_iter_to_alter,
69 |                     run_standard_sd=config.run_standard_sd,
70 |                     thresholds=config.thresholds,
71 |                     scale_factor=config.scale_factor,
72 |                     scale_range=config.scale_range,
73 |                     smooth_attentions=config.smooth_attentions,
74 |                     sigma=config.sigma,
75 |                     kernel_size=config.kernel_size,
76 |                     sd_2_1=config.sd_2_1,
77 |                     bbox=config.bbox,
78 |                     config=config)
79 |     image = outputs.images[0]
80 |     return image
81 | 
82 | 
83 | @pyrallis.wrap()
84 | def main(config: RunConfig):
85 |     stable = load_model()
86 | 
87 |     # read bbox from the pre-prepared .json file
88 |     with open('docs/bbox_as_condition.json', 'r', encoding='utf8') as fp:
89 |         bbox_json = json.load(fp)
90 | 
91 |     idx = np.arange(len(bbox_json))
92 |     split_idx = list(np.array_split(idx, config.n_splits)[config.which_one - 1])
93 | 
94 |     for bidx in tqdm(split_idx):
95 | 
96 |         filename = bbox_json[bidx]['filename']
97 |         sub_dir = filename.split('/')[-2]
98 |         img_name = filename.split('/')[-1].split('.')[0]
99 | 
100 |         objects = bbox_json[bidx]['objects']
101 |         cls_name = list(objects.keys())
102 |         config.bbox = list(objects.values())
103 |         # import ipdb
104 |         # ipdb.set_trace()
105 |         text_prompt = ''
106 |         gligen_phrases = []
107 |         token_indices = []
108 |         for nidx, n in enumerate(cls_name):  # prompt is built as "a X and a Y and ...", so class-name tokens sit at indices 2, 5, 8, ...
109 |             text_prompt += f'a {n} and '
110 |             if nidx == 0:
111 |                 token_indices.append(2)
112 |             else:
113 |                 token_indices.append(5 + (nidx - 1) * 3)
114 |             gligen_phrases.append(f'a {n}')
115 |         config.prompt = text_prompt[:-5]
116 |         config.gligen_phrases = gligen_phrases
117 | 
118 |         for seed in config.seeds:
119 |             print(f"Current seed is: {seed}")
120 |             g = torch.Generator('cuda').manual_seed(seed)
121 |             controller = AttentionStore()
122 |             controller.num_uncond_att_layers = -16
123 |             image = run_on_prompt(prompt=config.prompt,
124 |                                   model=stable,
125 |                                   controller=controller,
126 |                                   token_indices=token_indices,
127 |                                   seed=g,
128 |                                   config=config)
129 | 
130 |             prompt_output_path = config.eval_output_path / sub_dir
131 |             prompt_output_path.mkdir(exist_ok=True, parents=True)
132 |             if os.path.isfile(prompt_output_path / f'{img_name}_{seed}.png'):
133 |                 continue
134 | 
135 |             image.save(prompt_output_path / f'{img_name}_{seed}.png')
136 | 
137 |             canvas = Image.fromarray(np.zeros((image.size[0], image.size[0], 3), dtype=np.uint8) + 220)
138 |             draw = DashedImageDraw(canvas)
139 | 
140 |             for i in range(len(config.bbox)):
141 |                 x1, y1, x2, y2 = config.bbox[i]
142 |                 draw.dashed_rectangle([(x1, y1), (x2, y2)], dash=(5, 5), outline=config.color[i], width=5)
143 |             canvas.save(prompt_output_path / f'{img_name}_{seed}_canvas.png')
144 | 
145 | 
146 | if __name__ == '__main__':
147 |     main()
148 | 
--------------------------------------------------------------------------------
/pipeline/gligen_pipeline_boxdiff.py:
--------------------------------------------------------------------------------
1 | 2 | from typing import Any, Callable, Dict, List, Optional, Union, Tuple 3 | 4 | import numpy as np 5 | import torch 6 | from torch.nn import functional as F 7 | 8 | from diffusers.utils import deprecate, is_accelerate_available, logging, randn_tensor, replace_example_docstring 9 | from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput 10 | 11 | from diffusers.pipelines.stable_diffusion import StableDiffusionGLIGENPipeline 12 | 13 | from utils.gaussian_smoothing import GaussianSmoothing 14 | from utils.ptp_utils import AttentionStore, aggregate_attention 15 | import PIL 16 | 17 | logger = logging.get_logger(__name__) 18 | 19 | class BoxDiffPipeline(StableDiffusionGLIGENPipeline): 20 | r""" 21 | Pipeline for text-to-image generation using Stable Diffusion. 22 | This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the 23 | library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) 24 | Args: 25 | vae ([`AutoencoderKL`]): 26 | Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. 27 | text_encoder ([`CLIPTextModel`]): 28 | Frozen text-encoder.
Stable Diffusion uses the text portion of 29 | [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically 30 | the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. 31 | tokenizer (`CLIPTokenizer`): 32 | Tokenizer of class 33 | [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). 34 | unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. 35 | scheduler ([`SchedulerMixin`]): 36 | A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of 37 | [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. 38 | safety_checker ([`StableDiffusionSafetyChecker`]): 39 | Classification module that estimates whether generated images could be considered offensive or harmful. 40 | Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details. 41 | feature_extractor ([`CLIPFeatureExtractor`]): 42 | Model that extracts features from generated images to be used as inputs for the `safety_checker`. 43 | """ 44 | _optional_components = ["safety_checker", "feature_extractor"] 45 | 46 | def _encode_prompt( 47 | self, 48 | prompt, 49 | device, 50 | num_images_per_prompt, 51 | do_classifier_free_guidance, 52 | negative_prompt=None, 53 | prompt_embeds: Optional[torch.FloatTensor] = None, 54 | negative_prompt_embeds: Optional[torch.FloatTensor] = None, 55 | ): 56 | r""" 57 | Encodes the prompt into text encoder hidden states. 58 | Args: 59 | prompt (`str` or `List[str]`, *optional*): 60 | prompt to be encoded 61 | device: (`torch.device`): 62 | torch device 63 | num_images_per_prompt (`int`): 64 | number of images that should be generated per prompt 65 | do_classifier_free_guidance (`bool`): 66 | whether to use classifier free guidance or not 67 | negative_ prompt (`str` or `List[str]`, *optional*): 68 | The prompt or prompts not to guide the image generation. If not defined, one has to pass 69 | `negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead. 70 | Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). 71 | prompt_embeds (`torch.FloatTensor`, *optional*): 72 | Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not 73 | provided, text embeddings will be generated from `prompt` input argument. 74 | negative_prompt_embeds (`torch.FloatTensor`, *optional*): 75 | Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt 76 | weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input 77 | argument. 
78 | """ 79 | if prompt is not None and isinstance(prompt, str): 80 | batch_size = 1 81 | elif prompt is not None and isinstance(prompt, list): 82 | batch_size = len(prompt) 83 | else: 84 | batch_size = prompt_embeds.shape[0] 85 | 86 | if prompt_embeds is None: 87 | text_inputs = self.tokenizer( 88 | prompt, 89 | padding="max_length", 90 | max_length=self.tokenizer.model_max_length, 91 | truncation=True, 92 | return_tensors="pt", 93 | ) 94 | text_input_ids = text_inputs.input_ids 95 | untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids 96 | 97 | if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( 98 | text_input_ids, untruncated_ids 99 | ): 100 | removed_text = self.tokenizer.batch_decode( 101 | untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] 102 | ) 103 | logger.warning( 104 | "The following part of your input was truncated because CLIP can only handle sequences up to" 105 | f" {self.tokenizer.model_max_length} tokens: {removed_text}" 106 | ) 107 | 108 | if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: 109 | attention_mask = text_inputs.attention_mask.to(device) 110 | else: 111 | attention_mask = None 112 | 113 | prompt_embeds = self.text_encoder( 114 | text_input_ids.to(device), 115 | attention_mask=attention_mask, 116 | ) 117 | prompt_embeds = prompt_embeds[0] 118 | 119 | prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) 120 | 121 | bs_embed, seq_len, _ = prompt_embeds.shape 122 | # duplicate text embeddings for each generation per prompt, using mps friendly method 123 | prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) 124 | prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) 125 | 126 | # get unconditional embeddings for classifier free guidance 127 | if do_classifier_free_guidance and negative_prompt_embeds is None: 128 | uncond_tokens: List[str] 129 | if negative_prompt is None: 130 | uncond_tokens = [""] * batch_size 131 | elif type(prompt) is not type(negative_prompt): 132 | raise TypeError( 133 | f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" 134 | f" {type(prompt)}." 135 | ) 136 | elif isinstance(negative_prompt, str): 137 | uncond_tokens = [negative_prompt] 138 | elif batch_size != len(negative_prompt): 139 | raise ValueError( 140 | f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" 141 | f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" 142 | " the batch size of `prompt`." 
143 | ) 144 | else: 145 | uncond_tokens = negative_prompt 146 | 147 | max_length = prompt_embeds.shape[1] 148 | uncond_input = self.tokenizer( 149 | uncond_tokens, 150 | padding="max_length", 151 | max_length=max_length, 152 | truncation=True, 153 | return_tensors="pt", 154 | ) 155 | 156 | if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: 157 | attention_mask = uncond_input.attention_mask.to(device) 158 | else: 159 | attention_mask = None 160 | 161 | negative_prompt_embeds = self.text_encoder( 162 | uncond_input.input_ids.to(device), 163 | attention_mask=attention_mask, 164 | ) 165 | negative_prompt_embeds = negative_prompt_embeds[0] 166 | 167 | if do_classifier_free_guidance: 168 | # duplicate unconditional embeddings for each generation per prompt, using mps friendly method 169 | seq_len = negative_prompt_embeds.shape[1] 170 | 171 | negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) 172 | 173 | negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) 174 | negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) 175 | 176 | # For classifier free guidance, we need to do two forward passes. 177 | # Here we concatenate the unconditional and text embeddings into a single batch 178 | # to avoid doing two forward passes 179 | prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) 180 | 181 | return text_inputs, prompt_embeds 182 | 183 | def _compute_max_attention_per_index(self, 184 | attention_maps: torch.Tensor, 185 | indices_to_alter: List[int], 186 | smooth_attentions: bool = False, 187 | sigma: float = 0.5, 188 | kernel_size: int = 3, 189 | normalize_eot: bool = False, 190 | bbox: List[int] = None, 191 | config=None, 192 | ) -> List[torch.Tensor]: 193 | """ Computes the maximum attention value for each of the tokens we wish to alter. 
""" 194 | last_idx = -1 195 | if normalize_eot: 196 | prompt = self.prompt 197 | if isinstance(self.prompt, list): 198 | prompt = self.prompt[0] 199 | last_idx = len(self.tokenizer(prompt)['input_ids']) - 1 200 | attention_for_text = attention_maps[:, :, 1:last_idx] 201 | attention_for_text *= 100 202 | attention_for_text = torch.nn.functional.softmax(attention_for_text, dim=-1) 203 | 204 | # Shift indices since we removed the first token 205 | indices_to_alter = [index - 1 for index in indices_to_alter] 206 | 207 | # Extract the maximum values 208 | max_indices_list_fg = [] 209 | max_indices_list_bg = [] 210 | dist_x = [] 211 | dist_y = [] 212 | 213 | cnt = 0 214 | for i in indices_to_alter: 215 | image = attention_for_text[:, :, i] 216 | 217 | box = [max(round(b / (512 / image.shape[0])), 0) for b in bbox[cnt]] 218 | x1, y1, x2, y2 = box 219 | cnt += 1 220 | 221 | # coordinates to masks 222 | obj_mask = torch.zeros_like(image) 223 | ones_mask = torch.ones([y2 - y1, x2 - x1], dtype=obj_mask.dtype).to(obj_mask.device) 224 | obj_mask[y1:y2, x1:x2] = ones_mask 225 | bg_mask = 1 - obj_mask 226 | 227 | if smooth_attentions: 228 | smoothing = GaussianSmoothing(channels=1, kernel_size=kernel_size, sigma=sigma, dim=2).cuda() 229 | input = F.pad(image.unsqueeze(0).unsqueeze(0), (1, 1, 1, 1), mode='reflect') 230 | image = smoothing(input).squeeze(0).squeeze(0) 231 | 232 | # Inner-Box constraint 233 | k = (obj_mask.sum() * config.P).long() 234 | max_indices_list_fg.append((image * obj_mask).reshape(-1).topk(k)[0].mean()) 235 | 236 | # Outer-Box constraint 237 | k = (bg_mask.sum() * config.P).long() 238 | max_indices_list_bg.append((image * bg_mask).reshape(-1).topk(k)[0].mean()) 239 | 240 | # Corner Constraint 241 | gt_proj_x = torch.max(obj_mask, dim=0)[0] 242 | gt_proj_y = torch.max(obj_mask, dim=1)[0] 243 | corner_mask_x = torch.zeros_like(gt_proj_x) 244 | corner_mask_y = torch.zeros_like(gt_proj_y) 245 | 246 | # create gt according to the number config.L 247 | N = gt_proj_x.shape[0] 248 | corner_mask_x[max(box[0] - config.L, 0): min(box[0] + config.L + 1, N)] = 1. 249 | corner_mask_x[max(box[2] - config.L, 0): min(box[2] + config.L + 1, N)] = 1. 250 | corner_mask_y[max(box[1] - config.L, 0): min(box[1] + config.L + 1, N)] = 1. 251 | corner_mask_y[max(box[3] - config.L, 0): min(box[3] + config.L + 1, N)] = 1. 252 | dist_x.append((F.l1_loss(image.max(dim=0)[0], gt_proj_x, reduction='none') * corner_mask_x).mean()) 253 | dist_y.append((F.l1_loss(image.max(dim=1)[0], gt_proj_y, reduction='none') * corner_mask_y).mean()) 254 | 255 | return max_indices_list_fg, max_indices_list_bg, dist_x, dist_y 256 | 257 | def _aggregate_and_get_max_attention_per_token(self, attention_store: AttentionStore, 258 | indices_to_alter: List[int], 259 | attention_res: int = 16, 260 | smooth_attentions: bool = False, 261 | sigma: float = 0.5, 262 | kernel_size: int = 3, 263 | normalize_eot: bool = False, 264 | bbox: List[int] = None, 265 | config=None, 266 | ): 267 | """ Aggregates the attention for each token and computes the max activation value for each token to alter. 
""" 268 | attention_maps = aggregate_attention( 269 | attention_store=attention_store, 270 | res=attention_res, 271 | from_where=("up", "down", "mid"), 272 | is_cross=True, 273 | select=0) 274 | 275 | max_attention_per_index_fg, max_attention_per_index_bg, dist_x, dist_y = self._compute_max_attention_per_index( 276 | attention_maps=attention_maps, 277 | indices_to_alter=indices_to_alter, 278 | smooth_attentions=smooth_attentions, 279 | sigma=sigma, 280 | kernel_size=kernel_size, 281 | normalize_eot=normalize_eot, 282 | bbox=bbox, 283 | config=config, 284 | ) 285 | return max_attention_per_index_fg, max_attention_per_index_bg, dist_x, dist_y 286 | 287 | @staticmethod 288 | def _compute_loss(max_attention_per_index_fg: List[torch.Tensor], max_attention_per_index_bg: List[torch.Tensor], 289 | dist_x: List[torch.Tensor], dist_y: List[torch.Tensor], return_losses: bool = False) -> torch.Tensor: 290 | """ Computes the attend-and-excite loss using the maximum attention value for each token. """ 291 | losses_fg = [max(0, 1. - curr_max) for curr_max in max_attention_per_index_fg] 292 | losses_bg = [max(0, curr_max) for curr_max in max_attention_per_index_bg] 293 | loss = sum(losses_fg) + sum(losses_bg) + sum(dist_x) + sum(dist_y) 294 | if return_losses: 295 | return max(losses_fg), losses_fg 296 | else: 297 | return max(losses_fg), loss 298 | 299 | @staticmethod 300 | def _update_latent(latents: torch.Tensor, loss: torch.Tensor, step_size: float) -> torch.Tensor: 301 | """ Update the latent according to the computed loss. """ 302 | grad_cond = torch.autograd.grad(loss.requires_grad_(True), [latents], retain_graph=True)[0] 303 | latents = latents - step_size * grad_cond 304 | return latents 305 | 306 | def _perform_iterative_refinement_step(self, 307 | latents: torch.Tensor, 308 | indices_to_alter: List[int], 309 | loss_fg: torch.Tensor, 310 | threshold: float, 311 | text_embeddings: torch.Tensor, 312 | text_input, 313 | attention_store: AttentionStore, 314 | step_size: float, 315 | t: int, 316 | attention_res: int = 16, 317 | smooth_attentions: bool = True, 318 | sigma: float = 0.5, 319 | kernel_size: int = 3, 320 | max_refinement_steps: int = 20, 321 | normalize_eot: bool = False, 322 | bbox: List[int] = None, 323 | config=None, 324 | ): 325 | """ 326 | Performs the iterative latent refinement introduced in the paper. Here, we continuously update the latent 327 | code according to our loss objective until the given threshold is reached for all tokens. 328 | """ 329 | iteration = 0 330 | target_loss = max(0, 1. 
- threshold) 331 | while loss_fg > target_loss: 332 | iteration += 1 333 | 334 | latents = latents.clone().detach().requires_grad_(True) 335 | noise_pred_text = self.unet(latents, t, encoder_hidden_states=text_embeddings[1].unsqueeze(0)).sample 336 | self.unet.zero_grad() 337 | 338 | # Get max activation value for each subject token 339 | max_attention_per_index_fg, max_attention_per_index_bg, dist_x, dist_y = self._aggregate_and_get_max_attention_per_token( 340 | attention_store=attention_store, 341 | indices_to_alter=indices_to_alter, 342 | attention_res=attention_res, 343 | smooth_attentions=smooth_attentions, 344 | sigma=sigma, 345 | kernel_size=kernel_size, 346 | normalize_eot=normalize_eot, 347 | bbox=bbox, 348 | config=config, 349 | ) 350 | 351 | loss_fg, losses_fg = self._compute_loss(max_attention_per_index_fg, max_attention_per_index_bg, dist_x, dist_y, return_losses=True) 352 | 353 | if loss_fg != 0: 354 | latents = self._update_latent(latents, loss_fg, step_size) 355 | 356 | with torch.no_grad(): 357 | noise_pred_uncond = self.unet(latents, t, encoder_hidden_states=text_embeddings[0].unsqueeze(0)).sample 358 | noise_pred_text = self.unet(latents, t, encoder_hidden_states=text_embeddings[1].unsqueeze(0)).sample 359 | 360 | try: 361 | low_token = np.argmax([l.item() if type(l) != int else l for l in losses_fg]) 362 | except Exception as e: 363 | print(e) # catch edge case :) 364 | 365 | low_token = np.argmax(losses_fg) 366 | 367 | low_word = self.tokenizer.decode(text_input.input_ids[0][indices_to_alter[low_token]]) 368 | # print(f'\t Try {iteration}. {low_word} has a max attention of {max_attention_per_index_fg[low_token]}') 369 | 370 | if iteration >= max_refinement_steps: 371 | # print(f'\t Exceeded max number of iterations ({max_refinement_steps})! ' 372 | # f'Finished with a max attention of {max_attention_per_index_fg[low_token]}') 373 | break 374 | 375 | # Run one more time but don't compute gradients and update the latents. 
376 | # We just need to compute the new loss - the grad update will occur below 377 | latents = latents.clone().detach().requires_grad_(True) 378 | noise_pred_text = self.unet(latents, t, encoder_hidden_states=text_embeddings[1].unsqueeze(0)).sample 379 | self.unet.zero_grad() 380 | 381 | # Get max activation value for each subject token 382 | max_attention_per_index_fg, max_attention_per_index_bg, dist_x, dist_y = self._aggregate_and_get_max_attention_per_token( 383 | attention_store=attention_store, 384 | indices_to_alter=indices_to_alter, 385 | attention_res=attention_res, 386 | smooth_attentions=smooth_attentions, 387 | sigma=sigma, 388 | kernel_size=kernel_size, 389 | normalize_eot=normalize_eot, 390 | bbox=bbox, 391 | config=config, 392 | ) 393 | loss_fg, losses_fg = self._compute_loss(max_attention_per_index_fg, max_attention_per_index_bg, dist_x, dist_y, return_losses=True) 394 | # print(f"\t Finished with loss of: {loss_fg}") 395 | return loss_fg, latents, max_attention_per_index_fg 396 | 397 | @torch.no_grad() 398 | def __call__( 399 | self, 400 | prompt: Union[str, List[str]], 401 | attention_store: AttentionStore, 402 | indices_to_alter: List[int], 403 | attention_res: int = 16, 404 | height: Optional[int] = None, 405 | width: Optional[int] = None, 406 | num_inference_steps: int = 50, 407 | guidance_scale: float = 7.5, 408 | gligen_scheduled_sampling_beta: float = 0.3, 409 | gligen_phrases: List[str] = None, 410 | gligen_boxes: List[List[float]] = None, 411 | gligen_inpaint_image: Optional[PIL.Image.Image] = None, 412 | negative_prompt: Optional[Union[str, List[str]]] = None, 413 | num_images_per_prompt: Optional[int] = 1, 414 | eta: float = 0.0, 415 | generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, 416 | latents: Optional[torch.FloatTensor] = None, 417 | prompt_embeds: Optional[torch.FloatTensor] = None, 418 | negative_prompt_embeds: Optional[torch.FloatTensor] = None, 419 | output_type: Optional[str] = "pil", 420 | return_dict: bool = True, 421 | callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, 422 | callback_steps: Optional[int] = 1, 423 | cross_attention_kwargs: Optional[Dict[str, Any]] = None, 424 | max_iter_to_alter: Optional[int] = 25, 425 | run_standard_sd: bool = False, 426 | thresholds: Optional[dict] = {0: 0.05, 10: 0.5, 20: 0.8}, 427 | scale_factor: int = 20, 428 | scale_range: Tuple[float, float] = (1., 0.5), 429 | smooth_attentions: bool = True, 430 | sigma: float = 0.5, 431 | kernel_size: int = 3, 432 | sd_2_1: bool = False, 433 | bbox: List[int] = None, 434 | config = None, 435 | ): 436 | r""" 437 | Function invoked when calling the pipeline for generation. 438 | Args: 439 | prompt (`str` or `List[str]`, *optional*): 440 | The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. 441 | instead. 442 | height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): 443 | The height in pixels of the generated image. 444 | width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): 445 | The width in pixels of the generated image. 446 | num_inference_steps (`int`, *optional*, defaults to 50): 447 | The number of denoising steps. More denoising steps usually lead to a higher quality image at the 448 | expense of slower inference. 449 | guidance_scale (`float`, *optional*, defaults to 7.5): 450 | Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). 
451 | `guidance_scale` is defined as `w` of equation 2. of [Imagen 452 | Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > 453 | 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, 454 | usually at the expense of lower image quality. 455 | negative_prompt (`str` or `List[str]`, *optional*): 456 | The prompt or prompts not to guide the image generation. If not defined, one has to pass 457 | `negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead. 458 | Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). 459 | num_images_per_prompt (`int`, *optional*, defaults to 1): 460 | The number of images to generate per prompt. 461 | eta (`float`, *optional*, defaults to 0.0): 462 | Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to 463 | [`schedulers.DDIMScheduler`], will be ignored for others. 464 | generator (`torch.Generator` or `List[torch.Generator]`, *optional*): 465 | One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) 466 | to make generation deterministic. 467 | latents (`torch.FloatTensor`, *optional*): 468 | Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image 469 | generation. Can be used to tweak the same generation with different prompts. If not provided, a latents 470 | tensor will ge generated by sampling using the supplied random `generator`. 471 | prompt_embeds (`torch.FloatTensor`, *optional*): 472 | Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not 473 | provided, text embeddings will be generated from `prompt` input argument. 474 | negative_prompt_embeds (`torch.FloatTensor`, *optional*): 475 | Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt 476 | weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input 477 | argument. 478 | output_type (`str`, *optional*, defaults to `"pil"`): 479 | The output format of the generate image. Choose between 480 | [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. 481 | return_dict (`bool`, *optional*, defaults to `True`): 482 | Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a 483 | plain tuple. 484 | callback (`Callable`, *optional*): 485 | A function that will be called every `callback_steps` steps during inference. The function will be 486 | called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. 487 | callback_steps (`int`, *optional*, defaults to 1): 488 | The frequency at which the `callback` function will be called. If not specified, the callback will be 489 | called at every step. 490 | cross_attention_kwargs (`dict`, *optional*): 491 | A kwargs dictionary that if specified is passed along to the `AttnProcessor` as defined under 492 | `self.processor` in 493 | [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). 494 | Examples: 495 | Returns: 496 | [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: 497 | [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. 
498 | When returning a tuple, the first element is a list with the generated images, and the second element is a 499 | list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" 500 | (nsfw) content, according to the `safety_checker`. 501 | :type attention_store: object 502 | """ 503 | # 0. Default height and width to unet 504 | height = height or self.unet.config.sample_size * self.vae_scale_factor 505 | width = width or self.unet.config.sample_size * self.vae_scale_factor 506 | 507 | # 1. Check inputs. Raise error if not correct 508 | self.check_inputs( 509 | prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds 510 | ) 511 | 512 | # 2. Define call parameters 513 | self.prompt = prompt 514 | if prompt is not None and isinstance(prompt, str): 515 | batch_size = 1 516 | elif prompt is not None and isinstance(prompt, list): 517 | batch_size = len(prompt) 518 | else: 519 | batch_size = prompt_embeds.shape[0] 520 | 521 | device = self._execution_device 522 | # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) 523 | # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` 524 | # corresponds to doing no classifier free guidance. 525 | do_classifier_free_guidance = guidance_scale > 1.0 526 | 527 | # 3. Encode input prompt 528 | text_inputs, prompt_embeds = self._encode_prompt( 529 | prompt, 530 | device, 531 | num_images_per_prompt, 532 | do_classifier_free_guidance, 533 | negative_prompt, 534 | prompt_embeds=prompt_embeds, 535 | negative_prompt_embeds=negative_prompt_embeds, 536 | ) 537 | 538 | # 4. Prepare timesteps 539 | self.scheduler.set_timesteps(num_inference_steps, device=device) 540 | timesteps = self.scheduler.timesteps 541 | 542 | # 5. 
Prepare latent variables 543 | num_channels_latents = self.unet.in_channels 544 | latents = self.prepare_latents( 545 | batch_size * num_images_per_prompt, 546 | num_channels_latents, 547 | height, 548 | width, 549 | prompt_embeds.dtype, 550 | device, 551 | generator, 552 | latents, 553 | ) 554 | 555 | def draw_inpaint_mask_from_boxes(boxes, size): 556 | inpaint_mask = torch.ones(size[0], size[1]) 557 | for box in boxes: 558 | x0, x1 = box[0] * size[0], box[2] * size[0] 559 | y0, y1 = box[1] * size[1], box[3] * size[1] 560 | inpaint_mask[int(y0):int(y1), int(x0):int(x1)] = 0 561 | return inpaint_mask 562 | 563 | # 5.1 Prepare GLIGEN variables 564 | if gligen_phrases is not None: 565 | assert len(gligen_phrases) == len(gligen_boxes) 566 | assert batch_size == 1 567 | max_objs = 30 568 | _boxes = gligen_boxes 569 | tokenizer_inputs = self.tokenizer(gligen_phrases, padding=True, return_tensors="pt").to( 570 | self.text_encoder.device) 571 | _text_embeddings = self.text_encoder(**tokenizer_inputs).pooler_output 572 | n_objs = min(len(_boxes), max_objs) 573 | device = self.text_encoder.device 574 | dtype = self.text_encoder.dtype 575 | boxes = torch.zeros(max_objs, 4, device=device, dtype=dtype) 576 | boxes[:n_objs] = torch.tensor(_boxes[:n_objs]) 577 | text_embeddings = torch.zeros(max_objs, 768, device=device, dtype=dtype) 578 | text_embeddings[:n_objs] = _text_embeddings[:n_objs] 579 | masks = torch.zeros(max_objs, device=device, dtype=dtype) 580 | masks[:n_objs] = 1 581 | 582 | repeat_batch = batch_size * num_images_per_prompt 583 | if do_classifier_free_guidance: 584 | repeat_batch = repeat_batch * 2 585 | boxes = boxes.unsqueeze(0).expand(repeat_batch, -1, -1).clone() 586 | text_embeddings = text_embeddings.unsqueeze(0).expand(repeat_batch, -1, -1).clone() 587 | masks = masks.unsqueeze(0).expand(repeat_batch, -1).clone() 588 | if do_classifier_free_guidance: 589 | masks[:repeat_batch // 2] = 0 590 | if cross_attention_kwargs is None: 591 | cross_attention_kwargs = {} 592 | cross_attention_kwargs['gligen'] = { 593 | 'boxes': boxes, 594 | 'positive_embeddings': text_embeddings, 595 | 'masks': masks 596 | } 597 | 598 | # Prepare latent variables for GLIGEN inpainting 599 | if gligen_inpaint_image is not None: 600 | if gligen_inpaint_image.size != (self.vae.sample_size, self.vae.sample_size): 601 | def crop(im, new_width, new_height): 602 | width, height = im.size 603 | left = (width - new_width) / 2 604 | top = (height - new_height) / 2 605 | right = (width + new_width) / 2 606 | bottom = (height + new_height) / 2 607 | return im.crop((left, top, right, bottom)) 608 | 609 | def target_size_center_crop(im, new_hw): 610 | width, height = im.size 611 | if width != height: 612 | im = crop(im, min(height, width), min(height, width)) 613 | return im.resize((new_hw, new_hw), PIL.Image.LANCZOS) 614 | 615 | gligen_inpaint_image = target_size_center_crop(gligen_inpaint_image, self.vae.sample_size) 616 | 617 | gligen_inpaint_image = torch.from_numpy(np.asarray(gligen_inpaint_image)) 618 | gligen_inpaint_image = gligen_inpaint_image.unsqueeze(0).permute(0, 3, 1, 2) 619 | gligen_inpaint_image = gligen_inpaint_image.to(dtype=torch.float32) / 127.5 - 1.0 620 | gligen_inpaint_image = gligen_inpaint_image.to(dtype=self.vae.dtype, device=self.vae.device) 621 | gligen_inpaint_latent = self.vae.encode(gligen_inpaint_image).latent_dist.sample() 622 | gligen_inpaint_latent = self.vae.config.scaling_factor * gligen_inpaint_latent 623 | gligen_inpaint_mask = draw_inpaint_mask_from_boxes(_boxes[:n_objs], 
gligen_inpaint_latent.shape[2:]) 624 | gligen_inpaint_mask = gligen_inpaint_mask.to(dtype=gligen_inpaint_latent.dtype, 625 | device=gligen_inpaint_latent.device) 626 | gligen_inpaint_mask = gligen_inpaint_mask[None, None] 627 | gligen_inpaint_mask_addition = torch.cat( 628 | (gligen_inpaint_latent * gligen_inpaint_mask, gligen_inpaint_mask), dim=1) 629 | gligen_inpaint_mask_addition = gligen_inpaint_mask_addition.expand(repeat_batch, -1, -1, -1).clone() 630 | 631 | num_grounding_steps = int(gligen_scheduled_sampling_beta * len(timesteps)) 632 | self.enable_fuser(True) 633 | 634 | # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline 635 | extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) 636 | 637 | scale_range = np.linspace(scale_range[0], scale_range[1], len(self.scheduler.timesteps)) 638 | 639 | if max_iter_to_alter is None: 640 | max_iter_to_alter = len(self.scheduler.timesteps) + 1 641 | 642 | # 7. Denoising loop 643 | num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order 644 | with self.progress_bar(total=num_inference_steps) as progress_bar: 645 | for i, t in enumerate(timesteps): 646 | 647 | with torch.enable_grad(): 648 | 649 | if i == num_grounding_steps: 650 | self.enable_fuser(False) 651 | 652 | if latents.shape[1] != 4: 653 | latents = torch.randn_like(latents[:, :4]) 654 | 655 | if gligen_inpaint_image is not None: 656 | gligen_inpaint_latent_with_noise = self.scheduler.add_noise( 657 | gligen_inpaint_latent, 658 | torch.randn_like(gligen_inpaint_latent), 659 | t 660 | ).expand(latents.shape[0], -1, -1, -1).clone() 661 | latents = gligen_inpaint_latent_with_noise * gligen_inpaint_mask + latents * ( 662 | 1 - gligen_inpaint_mask) 663 | 664 | latents = latents.clone().detach().requires_grad_(True) 665 | 666 | # expand the latents if we are doing classifier free guidance 667 | latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents 668 | latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) 669 | 670 | if gligen_inpaint_image is not None: 671 | latent_model_input = torch.cat((latent_model_input, gligen_inpaint_mask_addition), dim=1) 672 | 673 | # Forward pass of denoising with text conditioning 674 | noise_pred_text = self.unet(latent_model_input, t, 675 | encoder_hidden_states=prompt_embeds, cross_attention_kwargs=cross_attention_kwargs).sample 676 | self.unet.zero_grad() 677 | 678 | # Get max activation value for each subject token 679 | max_attention_per_index_fg, max_attention_per_index_bg, dist_x, dist_y = self._aggregate_and_get_max_attention_per_token( 680 | attention_store=attention_store, 681 | indices_to_alter=indices_to_alter, 682 | attention_res=attention_res, 683 | smooth_attentions=smooth_attentions, 684 | sigma=sigma, 685 | kernel_size=kernel_size, 686 | normalize_eot=sd_2_1, 687 | bbox=bbox, 688 | config=config, 689 | ) 690 | 691 | if not run_standard_sd: 692 | 693 | loss_fg, loss = self._compute_loss(max_attention_per_index_fg, max_attention_per_index_bg, dist_x, dist_y) 694 | 695 | # Refinement from attend-and-excite (not necessary) 696 | if i in thresholds.keys() and loss_fg > 1. 
- thresholds[i] and config.refine: 697 | del noise_pred_text 698 | torch.cuda.empty_cache() 699 | loss_fg, latents, max_attention_per_index_fg = self._perform_iterative_refinement_step( 700 | latents=latents, 701 | indices_to_alter=indices_to_alter, 702 | loss_fg=loss_fg, 703 | threshold=thresholds[i], 704 | text_embeddings=prompt_embeds, 705 | text_input=text_inputs, 706 | attention_store=attention_store, 707 | step_size=scale_factor * np.sqrt(scale_range[i]), 708 | t=t, 709 | attention_res=attention_res, 710 | smooth_attentions=smooth_attentions, 711 | sigma=sigma, 712 | kernel_size=kernel_size, 713 | normalize_eot=sd_2_1, 714 | bbox=bbox, 715 | config=config, 716 | ) 717 | 718 | # Perform gradient update 719 | if i < max_iter_to_alter: 720 | _, loss = self._compute_loss(max_attention_per_index_fg, max_attention_per_index_bg, dist_x, dist_y) 721 | if loss != 0: 722 | latents = self._update_latent(latents=latents, loss=loss, 723 | step_size=scale_factor * np.sqrt(scale_range[i])) 724 | 725 | # print(f'Iteration {i} | Loss: {loss:0.4f}') 726 | 727 | # predict the noise residual 728 | noise_pred = self.unet( 729 | latent_model_input, 730 | t, 731 | encoder_hidden_states=prompt_embeds, 732 | cross_attention_kwargs=cross_attention_kwargs, 733 | ).sample 734 | 735 | # perform guidance 736 | if do_classifier_free_guidance: 737 | noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) 738 | noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) 739 | 740 | # compute the previous noisy sample x_t -> x_t-1 741 | latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample 742 | 743 | # call the callback, if provided 744 | if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): 745 | progress_bar.update() 746 | if callback is not None and i % callback_steps == 0: 747 | callback(i, t, latents) 748 | 749 | # 8. Post-processing 750 | image = self.decode_latents(latents) 751 | 752 | # 9. Run safety checker 753 | image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) 754 | 755 | # 10. Convert to PIL 756 | if output_type == "pil": 757 | image = self.numpy_to_pil(image) 758 | 759 | if not return_dict: 760 | return (image, has_nsfw_concept) 761 | 762 | return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) 763 | -------------------------------------------------------------------------------- /pipeline/sd_pipeline_boxdiff.py: -------------------------------------------------------------------------------- 1 | 2 | from typing import Any, Callable, Dict, List, Optional, Union, Tuple 3 | 4 | import numpy as np 5 | import torch 6 | from torch.nn import functional as F 7 | 8 | from diffusers.utils import deprecate, is_accelerate_available, logging, randn_tensor, replace_example_docstring 9 | from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput 10 | 11 | from diffusers.pipelines.stable_diffusion import StableDiffusionPipeline 12 | 13 | from utils.gaussian_smoothing import GaussianSmoothing 14 | from utils.ptp_utils import AttentionStore, aggregate_attention 15 | 16 | logger = logging.get_logger(__name__) 17 | 18 | class BoxDiffPipeline(StableDiffusionPipeline): 19 | r""" 20 | Pipeline for text-to-image generation using Stable Diffusion. 21 | This model inherits from [`DiffusionPipeline`]. 
Check the superclass documentation for the generic methods the 22 | library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) 23 | Args: 24 | vae ([`AutoencoderKL`]): 25 | Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. 26 | text_encoder ([`CLIPTextModel`]): 27 | Frozen text-encoder. Stable Diffusion uses the text portion of 28 | [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically 29 | the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. 30 | tokenizer (`CLIPTokenizer`): 31 | Tokenizer of class 32 | [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). 33 | unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. 34 | scheduler ([`SchedulerMixin`]): 35 | A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of 36 | [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. 37 | safety_checker ([`StableDiffusionSafetyChecker`]): 38 | Classification module that estimates whether generated images could be considered offensive or harmful. 39 | Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details. 40 | feature_extractor ([`CLIPFeatureExtractor`]): 41 | Model that extracts features from generated images to be used as inputs for the `safety_checker`. 42 | """ 43 | _optional_components = ["safety_checker", "feature_extractor"] 44 | 45 | def _encode_prompt( 46 | self, 47 | prompt, 48 | device, 49 | num_images_per_prompt, 50 | do_classifier_free_guidance, 51 | negative_prompt=None, 52 | prompt_embeds: Optional[torch.FloatTensor] = None, 53 | negative_prompt_embeds: Optional[torch.FloatTensor] = None, 54 | ): 55 | r""" 56 | Encodes the prompt into text encoder hidden states. 57 | Args: 58 | prompt (`str` or `List[str]`, *optional*): 59 | prompt to be encoded 60 | device: (`torch.device`): 61 | torch device 62 | num_images_per_prompt (`int`): 63 | number of images that should be generated per prompt 64 | do_classifier_free_guidance (`bool`): 65 | whether to use classifier free guidance or not 66 | negative_ prompt (`str` or `List[str]`, *optional*): 67 | The prompt or prompts not to guide the image generation. If not defined, one has to pass 68 | `negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead. 69 | Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). 70 | prompt_embeds (`torch.FloatTensor`, *optional*): 71 | Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not 72 | provided, text embeddings will be generated from `prompt` input argument. 73 | negative_prompt_embeds (`torch.FloatTensor`, *optional*): 74 | Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt 75 | weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input 76 | argument. 
77 | """ 78 | if prompt is not None and isinstance(prompt, str): 79 | batch_size = 1 80 | elif prompt is not None and isinstance(prompt, list): 81 | batch_size = len(prompt) 82 | else: 83 | batch_size = prompt_embeds.shape[0] 84 | 85 | if prompt_embeds is None: 86 | text_inputs = self.tokenizer( 87 | prompt, 88 | padding="max_length", 89 | max_length=self.tokenizer.model_max_length, 90 | truncation=True, 91 | return_tensors="pt", 92 | ) 93 | text_input_ids = text_inputs.input_ids 94 | untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids 95 | 96 | if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( 97 | text_input_ids, untruncated_ids 98 | ): 99 | removed_text = self.tokenizer.batch_decode( 100 | untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] 101 | ) 102 | logger.warning( 103 | "The following part of your input was truncated because CLIP can only handle sequences up to" 104 | f" {self.tokenizer.model_max_length} tokens: {removed_text}" 105 | ) 106 | 107 | if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: 108 | attention_mask = text_inputs.attention_mask.to(device) 109 | else: 110 | attention_mask = None 111 | 112 | prompt_embeds = self.text_encoder( 113 | text_input_ids.to(device), 114 | attention_mask=attention_mask, 115 | ) 116 | prompt_embeds = prompt_embeds[0] 117 | 118 | prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) 119 | 120 | bs_embed, seq_len, _ = prompt_embeds.shape 121 | # duplicate text embeddings for each generation per prompt, using mps friendly method 122 | prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) 123 | prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) 124 | 125 | # get unconditional embeddings for classifier free guidance 126 | if do_classifier_free_guidance and negative_prompt_embeds is None: 127 | uncond_tokens: List[str] 128 | if negative_prompt is None: 129 | uncond_tokens = [""] * batch_size 130 | elif type(prompt) is not type(negative_prompt): 131 | raise TypeError( 132 | f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" 133 | f" {type(prompt)}." 134 | ) 135 | elif isinstance(negative_prompt, str): 136 | uncond_tokens = [negative_prompt] 137 | elif batch_size != len(negative_prompt): 138 | raise ValueError( 139 | f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" 140 | f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" 141 | " the batch size of `prompt`." 
142 | ) 143 | else: 144 | uncond_tokens = negative_prompt 145 | 146 | max_length = prompt_embeds.shape[1] 147 | uncond_input = self.tokenizer( 148 | uncond_tokens, 149 | padding="max_length", 150 | max_length=max_length, 151 | truncation=True, 152 | return_tensors="pt", 153 | ) 154 | 155 | if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: 156 | attention_mask = uncond_input.attention_mask.to(device) 157 | else: 158 | attention_mask = None 159 | 160 | negative_prompt_embeds = self.text_encoder( 161 | uncond_input.input_ids.to(device), 162 | attention_mask=attention_mask, 163 | ) 164 | negative_prompt_embeds = negative_prompt_embeds[0] 165 | 166 | if do_classifier_free_guidance: 167 | # duplicate unconditional embeddings for each generation per prompt, using mps friendly method 168 | seq_len = negative_prompt_embeds.shape[1] 169 | 170 | negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) 171 | 172 | negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) 173 | negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) 174 | 175 | # For classifier free guidance, we need to do two forward passes. 176 | # Here we concatenate the unconditional and text embeddings into a single batch 177 | # to avoid doing two forward passes 178 | prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) 179 | 180 | return text_inputs, prompt_embeds 181 | 182 | def _compute_max_attention_per_index(self, 183 | attention_maps: torch.Tensor, 184 | indices_to_alter: List[int], 185 | smooth_attentions: bool = False, 186 | sigma: float = 0.5, 187 | kernel_size: int = 3, 188 | normalize_eot: bool = False, 189 | bbox: List[int] = None, 190 | config=None, 191 | ) -> List[torch.Tensor]: 192 | """ Computes the maximum attention value for each of the tokens we wish to alter. 
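For each selected token, the cross-attention map is compared against its target box: the Inner-Box term averages the top-k attention responses inside the box (k is a fraction `config.P` of the box area), the Outer-Box term does the same for the region outside the box, and the Corner term takes an L1 distance between the x/y projections of the attention map and those of the box mask within a band of `config.L` units around the box edges.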
""" 193 | last_idx = -1 194 | if normalize_eot: 195 | prompt = self.prompt 196 | if isinstance(self.prompt, list): 197 | prompt = self.prompt[0] 198 | last_idx = len(self.tokenizer(prompt)['input_ids']) - 1 199 | attention_for_text = attention_maps[:, :, 1:last_idx] 200 | attention_for_text *= 100 201 | attention_for_text = torch.nn.functional.softmax(attention_for_text, dim=-1) 202 | 203 | # Shift indices since we removed the first token 204 | indices_to_alter = [index - 1 for index in indices_to_alter] 205 | 206 | # Extract the maximum values 207 | max_indices_list_fg = [] 208 | max_indices_list_bg = [] 209 | dist_x = [] 210 | dist_y = [] 211 | 212 | cnt = 0 213 | for i in indices_to_alter: 214 | image = attention_for_text[:, :, i] 215 | 216 | box = [max(round(b / (512 / image.shape[0])), 0) for b in bbox[cnt]] 217 | x1, y1, x2, y2 = box 218 | cnt += 1 219 | 220 | # coordinates to masks 221 | obj_mask = torch.zeros_like(image) 222 | ones_mask = torch.ones([y2 - y1, x2 - x1], dtype=obj_mask.dtype).to(obj_mask.device) 223 | obj_mask[y1:y2, x1:x2] = ones_mask 224 | bg_mask = 1 - obj_mask 225 | 226 | if smooth_attentions: 227 | smoothing = GaussianSmoothing(channels=1, kernel_size=kernel_size, sigma=sigma, dim=2).cuda() 228 | input = F.pad(image.unsqueeze(0).unsqueeze(0), (1, 1, 1, 1), mode='reflect') 229 | image = smoothing(input).squeeze(0).squeeze(0) 230 | 231 | # Inner-Box constraint 232 | k = (obj_mask.sum() * config.P).long() 233 | max_indices_list_fg.append((image * obj_mask).reshape(-1).topk(k)[0].mean()) 234 | 235 | # Outer-Box constraint 236 | k = (bg_mask.sum() * config.P).long() 237 | max_indices_list_bg.append((image * bg_mask).reshape(-1).topk(k)[0].mean()) 238 | 239 | # Corner Constraint 240 | gt_proj_x = torch.max(obj_mask, dim=0)[0] 241 | gt_proj_y = torch.max(obj_mask, dim=1)[0] 242 | corner_mask_x = torch.zeros_like(gt_proj_x) 243 | corner_mask_y = torch.zeros_like(gt_proj_y) 244 | 245 | # create gt according to the number config.L 246 | N = gt_proj_x.shape[0] 247 | corner_mask_x[max(box[0] - config.L, 0): min(box[0] + config.L + 1, N)] = 1. 248 | corner_mask_x[max(box[2] - config.L, 0): min(box[2] + config.L + 1, N)] = 1. 249 | corner_mask_y[max(box[1] - config.L, 0): min(box[1] + config.L + 1, N)] = 1. 250 | corner_mask_y[max(box[3] - config.L, 0): min(box[3] + config.L + 1, N)] = 1. 251 | dist_x.append((F.l1_loss(image.max(dim=0)[0], gt_proj_x, reduction='none') * corner_mask_x).mean()) 252 | dist_y.append((F.l1_loss(image.max(dim=1)[0], gt_proj_y, reduction='none') * corner_mask_y).mean()) 253 | 254 | return max_indices_list_fg, max_indices_list_bg, dist_x, dist_y 255 | 256 | def _aggregate_and_get_max_attention_per_token(self, attention_store: AttentionStore, 257 | indices_to_alter: List[int], 258 | attention_res: int = 16, 259 | smooth_attentions: bool = False, 260 | sigma: float = 0.5, 261 | kernel_size: int = 3, 262 | normalize_eot: bool = False, 263 | bbox: List[int] = None, 264 | config=None, 265 | ): 266 | """ Aggregates the attention for each token and computes the max activation value for each token to alter. 
""" 267 | attention_maps = aggregate_attention( 268 | attention_store=attention_store, 269 | res=attention_res, 270 | from_where=("up", "down", "mid"), 271 | is_cross=True, 272 | select=0) 273 | max_attention_per_index_fg, max_attention_per_index_bg, dist_x, dist_y = self._compute_max_attention_per_index( 274 | attention_maps=attention_maps, 275 | indices_to_alter=indices_to_alter, 276 | smooth_attentions=smooth_attentions, 277 | sigma=sigma, 278 | kernel_size=kernel_size, 279 | normalize_eot=normalize_eot, 280 | bbox=bbox, 281 | config=config, 282 | ) 283 | return max_attention_per_index_fg, max_attention_per_index_bg, dist_x, dist_y 284 | 285 | @staticmethod 286 | def _compute_loss(max_attention_per_index_fg: List[torch.Tensor], max_attention_per_index_bg: List[torch.Tensor], 287 | dist_x: List[torch.Tensor], dist_y: List[torch.Tensor], return_losses: bool = False) -> torch.Tensor: 288 | """ Computes the attend-and-excite loss using the maximum attention value for each token. """ 289 | losses_fg = [max(0, 1. - curr_max) for curr_max in max_attention_per_index_fg] 290 | losses_bg = [max(0, curr_max) for curr_max in max_attention_per_index_bg] 291 | loss = sum(losses_fg) + sum(losses_bg) + sum(dist_x) + sum(dist_y) 292 | if return_losses: 293 | return max(losses_fg), losses_fg 294 | else: 295 | return max(losses_fg), loss 296 | 297 | @staticmethod 298 | def _update_latent(latents: torch.Tensor, loss: torch.Tensor, step_size: float) -> torch.Tensor: 299 | """ Update the latent according to the computed loss. """ 300 | grad_cond = torch.autograd.grad(loss.requires_grad_(True), [latents], retain_graph=True)[0] 301 | latents = latents - step_size * grad_cond 302 | return latents 303 | 304 | def _perform_iterative_refinement_step(self, 305 | latents: torch.Tensor, 306 | indices_to_alter: List[int], 307 | loss_fg: torch.Tensor, 308 | threshold: float, 309 | text_embeddings: torch.Tensor, 310 | text_input, 311 | attention_store: AttentionStore, 312 | step_size: float, 313 | t: int, 314 | attention_res: int = 16, 315 | smooth_attentions: bool = True, 316 | sigma: float = 0.5, 317 | kernel_size: int = 3, 318 | max_refinement_steps: int = 20, 319 | normalize_eot: bool = False, 320 | bbox: List[int] = None, 321 | config=None, 322 | ): 323 | """ 324 | Performs the iterative latent refinement introduced in the paper. Here, we continuously update the latent 325 | code according to our loss objective until the given threshold is reached for all tokens. 326 | """ 327 | iteration = 0 328 | target_loss = max(0, 1. 
- threshold) 329 | while loss_fg > target_loss: 330 | iteration += 1 331 | 332 | latents = latents.clone().detach().requires_grad_(True) 333 | noise_pred_text = self.unet(latents, t, encoder_hidden_states=text_embeddings[1].unsqueeze(0)).sample 334 | self.unet.zero_grad() 335 | 336 | # Get max activation value for each subject token 337 | max_attention_per_index_fg, max_attention_per_index_bg, dist_x, dist_y = self._aggregate_and_get_max_attention_per_token( 338 | attention_store=attention_store, 339 | indices_to_alter=indices_to_alter, 340 | attention_res=attention_res, 341 | smooth_attentions=smooth_attentions, 342 | sigma=sigma, 343 | kernel_size=kernel_size, 344 | normalize_eot=normalize_eot, 345 | bbox=bbox, 346 | config=config, 347 | ) 348 | 349 | loss_fg, losses_fg = self._compute_loss(max_attention_per_index_fg, max_attention_per_index_bg, dist_x, dist_y, return_losses=True) 350 | 351 | if loss_fg != 0: 352 | latents = self._update_latent(latents, loss_fg, step_size) 353 | 354 | with torch.no_grad(): 355 | noise_pred_uncond = self.unet(latents, t, encoder_hidden_states=text_embeddings[0].unsqueeze(0)).sample 356 | noise_pred_text = self.unet(latents, t, encoder_hidden_states=text_embeddings[1].unsqueeze(0)).sample 357 | 358 | try: 359 | low_token = np.argmax([l.item() if type(l) != int else l for l in losses_fg]) 360 | except Exception as e: 361 | print(e) # catch edge case :) 362 | 363 | low_token = np.argmax(losses_fg) 364 | 365 | low_word = self.tokenizer.decode(text_input.input_ids[0][indices_to_alter[low_token]]) 366 | # print(f'\t Try {iteration}. {low_word} has a max attention of {max_attention_per_index_fg[low_token]}') 367 | 368 | if iteration >= max_refinement_steps: 369 | # print(f'\t Exceeded max number of iterations ({max_refinement_steps})! ' 370 | # f'Finished with a max attention of {max_attention_per_index_fg[low_token]}') 371 | break 372 | 373 | # Run one more time but don't compute gradients and update the latents. 
374 | # We just need to compute the new loss - the grad update will occur below 375 | latents = latents.clone().detach().requires_grad_(True) 376 | noise_pred_text = self.unet(latents, t, encoder_hidden_states=text_embeddings[1].unsqueeze(0)).sample 377 | self.unet.zero_grad() 378 | 379 | # Get max activation value for each subject token 380 | max_attention_per_index_fg, max_attention_per_index_bg, dist_x, dist_y = self._aggregate_and_get_max_attention_per_token( 381 | attention_store=attention_store, 382 | indices_to_alter=indices_to_alter, 383 | attention_res=attention_res, 384 | smooth_attentions=smooth_attentions, 385 | sigma=sigma, 386 | kernel_size=kernel_size, 387 | normalize_eot=normalize_eot, 388 | bbox=bbox, 389 | config=config, 390 | ) 391 | loss_fg, losses_fg = self._compute_loss(max_attention_per_index_fg, max_attention_per_index_bg, dist_x, dist_y, return_losses=True) 392 | # print(f"\t Finished with loss of: {loss_fg}") 393 | return loss_fg, latents, max_attention_per_index_fg 394 | 395 | @torch.no_grad() 396 | def __call__( 397 | self, 398 | prompt: Union[str, List[str]], 399 | attention_store: AttentionStore, 400 | indices_to_alter: List[int], 401 | attention_res: int = 16, 402 | height: Optional[int] = None, 403 | width: Optional[int] = None, 404 | num_inference_steps: int = 50, 405 | guidance_scale: float = 7.5, 406 | negative_prompt: Optional[Union[str, List[str]]] = None, 407 | num_images_per_prompt: Optional[int] = 1, 408 | eta: float = 0.0, 409 | generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, 410 | latents: Optional[torch.FloatTensor] = None, 411 | prompt_embeds: Optional[torch.FloatTensor] = None, 412 | negative_prompt_embeds: Optional[torch.FloatTensor] = None, 413 | output_type: Optional[str] = "pil", 414 | return_dict: bool = True, 415 | callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, 416 | callback_steps: Optional[int] = 1, 417 | cross_attention_kwargs: Optional[Dict[str, Any]] = None, 418 | max_iter_to_alter: Optional[int] = 25, 419 | run_standard_sd: bool = False, 420 | thresholds: Optional[dict] = {0: 0.05, 10: 0.5, 20: 0.8}, 421 | scale_factor: int = 20, 422 | scale_range: Tuple[float, float] = (1., 0.5), 423 | smooth_attentions: bool = True, 424 | sigma: float = 0.5, 425 | kernel_size: int = 3, 426 | sd_2_1: bool = False, 427 | bbox: List[int] = None, 428 | config = None, 429 | ): 430 | r""" 431 | Function invoked when calling the pipeline for generation. 432 | Args: 433 | prompt (`str` or `List[str]`, *optional*): 434 | The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. 435 | instead. 436 | height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): 437 | The height in pixels of the generated image. 438 | width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): 439 | The width in pixels of the generated image. 440 | num_inference_steps (`int`, *optional*, defaults to 50): 441 | The number of denoising steps. More denoising steps usually lead to a higher quality image at the 442 | expense of slower inference. 443 | guidance_scale (`float`, *optional*, defaults to 7.5): 444 | Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). 445 | `guidance_scale` is defined as `w` of equation 2. of [Imagen 446 | Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > 447 | 1`. 
Higher guidance scale encourages the model to generate images that are closely linked to the text `prompt`, 448 | usually at the expense of lower image quality. 449 | negative_prompt (`str` or `List[str]`, *optional*): 450 | The prompt or prompts not to guide the image generation. If not defined, one has to pass 451 | `negative_prompt_embeds` instead. 452 | Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). 453 | num_images_per_prompt (`int`, *optional*, defaults to 1): 454 | The number of images to generate per prompt. 455 | eta (`float`, *optional*, defaults to 0.0): 456 | Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to 457 | [`schedulers.DDIMScheduler`], will be ignored for others. 458 | generator (`torch.Generator` or `List[torch.Generator]`, *optional*): 459 | One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) 460 | to make generation deterministic. 461 | latents (`torch.FloatTensor`, *optional*): 462 | Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image 463 | generation. Can be used to tweak the same generation with different prompts. If not provided, a latents 464 | tensor will be generated by sampling using the supplied random `generator`. 465 | prompt_embeds (`torch.FloatTensor`, *optional*): 466 | Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not 467 | provided, text embeddings will be generated from the `prompt` input argument. 468 | negative_prompt_embeds (`torch.FloatTensor`, *optional*): 469 | Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt 470 | weighting. If not provided, `negative_prompt_embeds` will be generated from the `negative_prompt` input 471 | argument. 472 | output_type (`str`, *optional*, defaults to `"pil"`): 473 | The output format of the generated image. Choose between 474 | [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. 475 | return_dict (`bool`, *optional*, defaults to `True`): 476 | Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a 477 | plain tuple. 478 | callback (`Callable`, *optional*): 479 | A function that will be called every `callback_steps` steps during inference. The function will be 480 | called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. 481 | callback_steps (`int`, *optional*, defaults to 1): 482 | The frequency at which the `callback` function will be called. If not specified, the callback will be 483 | called at every step. 484 | cross_attention_kwargs (`dict`, *optional*): 485 | A kwargs dictionary that if specified is passed along to the `AttnProcessor` as defined under 486 | `self.processor` in 487 | [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). 488 | Examples: 489 | Returns: 490 | [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: 491 | [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`. 
492 | When returning a tuple, the first element is a list with the generated images, and the second element is a 493 | list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" 494 | (nsfw) content, according to the `safety_checker`. 495 | :type attention_store: object 496 | """ 497 | # 0. Default height and width to unet 498 | height = height or self.unet.config.sample_size * self.vae_scale_factor 499 | width = width or self.unet.config.sample_size * self.vae_scale_factor 500 | 501 | # 1. Check inputs. Raise error if not correct 502 | self.check_inputs( 503 | prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds 504 | ) 505 | 506 | # 2. Define call parameters 507 | self.prompt = prompt 508 | if prompt is not None and isinstance(prompt, str): 509 | batch_size = 1 510 | elif prompt is not None and isinstance(prompt, list): 511 | batch_size = len(prompt) 512 | else: 513 | batch_size = prompt_embeds.shape[0] 514 | 515 | device = self._execution_device 516 | # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) 517 | # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` 518 | # corresponds to doing no classifier free guidance. 519 | do_classifier_free_guidance = guidance_scale > 1.0 520 | 521 | # 3. Encode input prompt 522 | text_inputs, prompt_embeds = self._encode_prompt( 523 | prompt, 524 | device, 525 | num_images_per_prompt, 526 | do_classifier_free_guidance, 527 | negative_prompt, 528 | prompt_embeds=prompt_embeds, 529 | negative_prompt_embeds=negative_prompt_embeds, 530 | ) 531 | 532 | # 4. Prepare timesteps 533 | self.scheduler.set_timesteps(num_inference_steps, device=device) 534 | timesteps = self.scheduler.timesteps 535 | 536 | # 5. Prepare latent variables 537 | num_channels_latents = self.unet.in_channels 538 | latents = self.prepare_latents( 539 | batch_size * num_images_per_prompt, 540 | num_channels_latents, 541 | height, 542 | width, 543 | prompt_embeds.dtype, 544 | device, 545 | generator, 546 | latents, 547 | ) 548 | 549 | # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline 550 | extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) 551 | 552 | scale_range = np.linspace(scale_range[0], scale_range[1], len(self.scheduler.timesteps)) 553 | 554 | if max_iter_to_alter is None: 555 | max_iter_to_alter = len(self.scheduler.timesteps) + 1 556 | 557 | # 7. 
Denoising loop 558 | num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order 559 | with self.progress_bar(total=num_inference_steps) as progress_bar: 560 | for i, t in enumerate(timesteps): 561 | 562 | with torch.enable_grad(): 563 | 564 | latents = latents.clone().detach().requires_grad_(True) 565 | 566 | # Forward pass of denoising with text conditioning 567 | noise_pred_text = self.unet(latents, t, 568 | encoder_hidden_states=prompt_embeds[1].unsqueeze(0), cross_attention_kwargs=cross_attention_kwargs).sample 569 | self.unet.zero_grad() 570 | 571 | # Get max activation value for each subject token 572 | max_attention_per_index_fg, max_attention_per_index_bg, dist_x, dist_y = self._aggregate_and_get_max_attention_per_token( 573 | attention_store=attention_store, 574 | indices_to_alter=indices_to_alter, 575 | attention_res=attention_res, 576 | smooth_attentions=smooth_attentions, 577 | sigma=sigma, 578 | kernel_size=kernel_size, 579 | normalize_eot=sd_2_1, 580 | bbox=bbox, 581 | config=config, 582 | ) 583 | 584 | if not run_standard_sd: 585 | 586 | loss_fg, loss = self._compute_loss(max_attention_per_index_fg, max_attention_per_index_bg, dist_x, dist_y) 587 | 588 | # Refinement from attend-and-excite (not necessary) 589 | if i in thresholds.keys() and loss_fg > 1. - thresholds[i] and config.refine: 590 | del noise_pred_text 591 | torch.cuda.empty_cache() 592 | loss_fg, latents, max_attention_per_index_fg = self._perform_iterative_refinement_step( 593 | latents=latents, 594 | indices_to_alter=indices_to_alter, 595 | loss_fg=loss_fg, 596 | threshold=thresholds[i], 597 | text_embeddings=prompt_embeds, 598 | text_input=text_inputs, 599 | attention_store=attention_store, 600 | step_size=scale_factor * np.sqrt(scale_range[i]), 601 | t=t, 602 | attention_res=attention_res, 603 | smooth_attentions=smooth_attentions, 604 | sigma=sigma, 605 | kernel_size=kernel_size, 606 | normalize_eot=sd_2_1, 607 | bbox=bbox, 608 | config=config, 609 | ) 610 | 611 | # Perform gradient update 612 | if i < max_iter_to_alter: 613 | _, loss = self._compute_loss(max_attention_per_index_fg, max_attention_per_index_bg, dist_x, dist_y) 614 | if loss != 0: 615 | latents = self._update_latent(latents=latents, loss=loss, 616 | step_size=scale_factor * np.sqrt(scale_range[i])) 617 | 618 | # print(f'Iteration {i} | Loss: {loss:0.4f}') 619 | 620 | # expand the latents if we are doing classifier free guidance 621 | latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents 622 | latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) 623 | 624 | # predict the noise residual 625 | noise_pred = self.unet( 626 | latent_model_input, 627 | t, 628 | encoder_hidden_states=prompt_embeds, 629 | cross_attention_kwargs=cross_attention_kwargs, 630 | ).sample 631 | 632 | # perform guidance 633 | if do_classifier_free_guidance: 634 | noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) 635 | noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) 636 | 637 | # compute the previous noisy sample x_t -> x_t-1 638 | latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample 639 | 640 | # call the callback, if provided 641 | if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): 642 | progress_bar.update() 643 | if callback is not None and i % callback_steps == 0: 644 | callback(i, t, latents) 645 | 646 | # 8. 
Post-processing 647 | image = self.decode_latents(latents) 648 | 649 | # 9. Run safety checker 650 | image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) 651 | 652 | # 10. Convert to PIL 653 | if output_type == "pil": 654 | image = self.numpy_to_pil(image) 655 | 656 | if not return_dict: 657 | return (image, has_nsfw_concept) 658 | 659 | return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) 660 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | ftfy 2 | opencv-python 3 | ipywidgets 4 | matplotlib 5 | pyrallis 6 | torch==1.12.0 7 | torchvision==0.13.0 8 | transformers==4.26.0 9 | jupyter 10 | accelerate 11 | IPython 12 | -------------------------------------------------------------------------------- /run_gligen_boxdiff.py: -------------------------------------------------------------------------------- 1 | 2 | import pprint 3 | from typing import List 4 | 5 | import pyrallis 6 | import torch 7 | from PIL import Image 8 | from config import RunConfig 9 | from pipeline.gligen_pipeline_boxdiff import BoxDiffPipeline 10 | from utils import ptp_utils, vis_utils 11 | from utils.ptp_utils import AttentionStore 12 | 13 | import numpy as np 14 | from utils.drawer import draw_rectangle, DashedImageDraw 15 | 16 | import warnings 17 | warnings.filterwarnings("ignore", category=UserWarning) 18 | 19 | 20 | def load_model(): 21 | device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu') 22 | stable_diffusion_version = "gligen/diffusers-generation-text-box" 23 | # If you cannot access the huggingface on your server, you can use the local prepared one. 
24 | # stable_diffusion_version = "../../packages/diffusers/gligen_ckpts/diffusers-generation-text-box" 25 | stable = BoxDiffPipeline.from_pretrained(stable_diffusion_version, revision="fp16", torch_dtype=torch.float16).to(device) 26 | 27 | return stable 28 | 29 | 30 | def get_indices_to_alter(stable, prompt: str) -> List[int]: 31 | token_idx_to_word = {idx: stable.tokenizer.decode(t) 32 | for idx, t in enumerate(stable.tokenizer(prompt)['input_ids']) 33 | if 0 < idx < len(stable.tokenizer(prompt)['input_ids']) - 1} 34 | pprint.pprint(token_idx_to_word) 35 | token_indices = input("Please enter a comma-separated list of the indices of the tokens you wish to " 36 | "alter (e.g., 2,5): ") 37 | token_indices = [int(i) for i in token_indices.split(",")] 38 | print(f"Altering tokens: {[token_idx_to_word[i] for i in token_indices]}") 39 | return token_indices 40 | 41 | 42 | def run_on_prompt(prompt: List[str], 43 | model: BoxDiffPipeline, 44 | controller: AttentionStore, 45 | token_indices: List[int], 46 | seed: torch.Generator, 47 | config: RunConfig) -> Image.Image: 48 | if controller is not None: 49 | ptp_utils.register_attention_control(model, controller) 50 | 51 | gligen_boxes = [] 52 | for i in range(len(config.bbox)): 53 | x1, y1, x2, y2 = config.bbox[i] 54 | gligen_boxes.append([x1/512, y1/512, x2/512, y2/512]) 55 | 56 | outputs = model(prompt=prompt, 57 | attention_store=controller, 58 | indices_to_alter=token_indices, 59 | attention_res=config.attention_res, 60 | guidance_scale=config.guidance_scale, 61 | gligen_phrases=config.gligen_phrases, 62 | gligen_boxes=gligen_boxes, 63 | gligen_scheduled_sampling_beta=0.3, 64 | generator=seed, 65 | num_inference_steps=config.n_inference_steps, 66 | max_iter_to_alter=config.max_iter_to_alter, 67 | run_standard_sd=config.run_standard_sd, 68 | thresholds=config.thresholds, 69 | scale_factor=config.scale_factor, 70 | scale_range=config.scale_range, 71 | smooth_attentions=config.smooth_attentions, 72 | sigma=config.sigma, 73 | kernel_size=config.kernel_size, 74 | sd_2_1=config.sd_2_1, 75 | bbox=config.bbox, 76 | config=config) 77 | image = outputs.images[0] 78 | return image 79 | 80 | 81 | @pyrallis.wrap() 82 | def main(config: RunConfig): 83 | stable = load_model() 84 | token_indices = get_indices_to_alter(stable, config.prompt) if config.token_indices is None else config.token_indices 85 | 86 | if len(config.bbox[0]) == 0: 87 | config.bbox = draw_rectangle() 88 | 89 | images = [] 90 | for seed in config.seeds: 91 | print(f"Current seed is: {seed}") 92 | g = torch.Generator('cuda').manual_seed(seed) 93 | controller = AttentionStore() 94 | controller.num_uncond_att_layers = -16 95 | image = run_on_prompt(prompt=config.prompt, 96 | model=stable, 97 | controller=controller, 98 | token_indices=token_indices, 99 | seed=g, 100 | config=config) 101 | prompt_output_path = config.output_path / config.prompt[:100] 102 | prompt_output_path.mkdir(exist_ok=True, parents=True) 103 | image.save(prompt_output_path / f'{seed}.png') 104 | images.append(image) 105 | 106 | canvas = Image.fromarray(np.zeros((image.size[0], image.size[0], 3), dtype=np.uint8) + 220) 107 | draw = DashedImageDraw(canvas) 108 | 109 | for i in range(len(config.bbox)): 110 | x1, y1, x2, y2 = config.bbox[i] 111 | draw.dashed_rectangle([(x1, y1), (x2, y2)], dash=(5, 5), outline=config.color[i], width=5) 112 | canvas.save(prompt_output_path / f'{seed}_canvas.png') 113 | 114 | # save a grid of results across all seeds 115 | joined_image = vis_utils.get_image_grid(images) 116 | 
joined_image.save(config.output_path / f'{config.prompt}.png') 117 | 118 | 119 | if __name__ == '__main__': 120 | main() 121 | -------------------------------------------------------------------------------- /run_sd_boxdiff.py: -------------------------------------------------------------------------------- 1 | 2 | import pprint 3 | from typing import List 4 | 5 | import pyrallis 6 | import torch 7 | from PIL import Image 8 | from config import RunConfig 9 | from pipeline.sd_pipeline_boxdiff import BoxDiffPipeline 10 | from utils import ptp_utils, vis_utils 11 | from utils.ptp_utils import AttentionStore 12 | 13 | import numpy as np 14 | from utils.drawer import draw_rectangle, DashedImageDraw 15 | 16 | import warnings 17 | warnings.filterwarnings("ignore", category=UserWarning) 18 | 19 | 20 | def load_model(config: RunConfig): 21 | device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu') 22 | 23 | if config.sd_2_1: 24 | stable_diffusion_version = "stabilityai/stable-diffusion-2-1-base" 25 | else: 26 | stable_diffusion_version = "CompVis/stable-diffusion-v1-4" 27 | # If you cannot access the huggingface on your server, you can use the local prepared one. 28 | # stable_diffusion_version = "../../packages/huggingface/hub/stable-diffusion-v1-4" 29 | stable = BoxDiffPipeline.from_pretrained(stable_diffusion_version).to(device) 30 | 31 | return stable 32 | 33 | 34 | def get_indices_to_alter(stable, prompt: str) -> List[int]: 35 | token_idx_to_word = {idx: stable.tokenizer.decode(t) 36 | for idx, t in enumerate(stable.tokenizer(prompt)['input_ids']) 37 | if 0 < idx < len(stable.tokenizer(prompt)['input_ids']) - 1} 38 | pprint.pprint(token_idx_to_word) 39 | token_indices = input("Please enter a comma-separated list of the indices of the tokens you wish to " 40 | "alter (e.g., 2,5): ") 41 | token_indices = [int(i) for i in token_indices.split(",")] 42 | print(f"Altering tokens: {[token_idx_to_word[i] for i in token_indices]}") 43 | return token_indices 44 | 45 | 46 | def run_on_prompt(prompt: List[str], 47 | model: BoxDiffPipeline, 48 | controller: AttentionStore, 49 | token_indices: List[int], 50 | seed: torch.Generator, 51 | config: RunConfig) -> Image.Image: 52 | if controller is not None: 53 | ptp_utils.register_attention_control(model, controller) 54 | outputs = model(prompt=prompt, 55 | attention_store=controller, 56 | indices_to_alter=token_indices, 57 | attention_res=config.attention_res, 58 | guidance_scale=config.guidance_scale, 59 | generator=seed, 60 | num_inference_steps=config.n_inference_steps, 61 | max_iter_to_alter=config.max_iter_to_alter, 62 | run_standard_sd=config.run_standard_sd, 63 | thresholds=config.thresholds, 64 | scale_factor=config.scale_factor, 65 | scale_range=config.scale_range, 66 | smooth_attentions=config.smooth_attentions, 67 | sigma=config.sigma, 68 | kernel_size=config.kernel_size, 69 | sd_2_1=config.sd_2_1, 70 | bbox=config.bbox, 71 | config=config) 72 | image = outputs.images[0] 73 | return image 74 | 75 | 76 | @pyrallis.wrap() 77 | def main(config: RunConfig): 78 | stable = load_model(config) 79 | token_indices = get_indices_to_alter(stable, config.prompt) if config.token_indices is None else config.token_indices 80 | 81 | if len(config.bbox[0]) == 0: 82 | config.bbox = draw_rectangle() 83 | 84 | images = [] 85 | for seed in config.seeds: 86 | print(f"Current seed is: {seed}") 87 | g = torch.Generator('cuda').manual_seed(seed) 88 | controller = AttentionStore() 89 | image = run_on_prompt(prompt=config.prompt, 90 | model=stable, 91 
| controller=controller, 92 | token_indices=token_indices, 93 | seed=g, 94 | config=config) 95 | prompt_output_path = config.output_path / config.prompt[:100] 96 | prompt_output_path.mkdir(exist_ok=True, parents=True) 97 | image.save(prompt_output_path / f'{seed}.png') 98 | images.append(image) 99 | 100 | canvas = Image.fromarray(np.zeros((image.size[0], image.size[0], 3), dtype=np.uint8) + 220) 101 | draw = DashedImageDraw(canvas) 102 | 103 | for i in range(len(config.bbox)): 104 | x1, y1, x2, y2 = config.bbox[i] 105 | draw.dashed_rectangle([(x1, y1), (x2, y2)], dash=(5, 5), outline=config.color[i], width=5) 106 | canvas.save(prompt_output_path / f'{seed}_canvas.png') 107 | 108 | # save a grid of results across all seeds 109 | joined_image = vis_utils.get_image_grid(images) 110 | joined_image.save(config.output_path / f'{config.prompt}.png') 111 | 112 | 113 | if __name__ == '__main__': 114 | main() 115 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/showlab/BoxDiff/b0d5d3b534418aa3fc71b9a16e5b575c0b2ee3b6/utils/__init__.py -------------------------------------------------------------------------------- /utils/drawer.py: -------------------------------------------------------------------------------- 1 | from tkinter import * 2 | from PIL import ImageDraw as D 3 | 4 | import math 5 | class DashedImageDraw(D.ImageDraw): 6 | 7 | def thick_line(self, xy, direction, fill=None, width=0): 8 | 9 | if xy[0] != xy[1]: 10 | self.line(xy, fill=fill, width=width) 11 | else: 12 | x1, y1 = xy[0] 13 | dx1, dy1 = direction[0] 14 | dx2, dy2 = direction[1] 15 | if dy2 - dy1 < 0: 16 | x1 -= 1 17 | if dx2 - dx1 < 0: 18 | y1 -= 1 19 | if dy2 - dy1 != 0: 20 | if dx2 - dx1 != 0: 21 | k = - (dx2 - dx1) / (dy2 - dy1) 22 | a = 1 / math.sqrt(1 + k ** 2) 23 | b = (width * a - 1) / 2 24 | else: 25 | k = 0 26 | b = (width - 1) / 2 27 | x3 = x1 - math.floor(b) 28 | y3 = y1 - int(k * b) 29 | x4 = x1 + math.ceil(b) 30 | y4 = y1 + int(k * b) 31 | else: 32 | x3 = x1 33 | y3 = y1 - math.floor((width - 1) / 2) 34 | x4 = x1 35 | y4 = y1 + math.ceil((width - 1) / 2) 36 | self.line([(x3, y3), (x4, y4)], fill=fill, width=1) 37 | return 38 | 39 | def dashed_line(self, xy, dash=(2, 2), fill=None, width=0): 40 | for i in range(len(xy) - 1): 41 | x1, y1 = xy[i] 42 | x2, y2 = xy[i + 1] 43 | x_length = x2 - x1 44 | y_length = y2 - y1 45 | length = math.sqrt(x_length ** 2 + y_length ** 2) 46 | dash_enabled = True 47 | postion = 0 48 | while postion <= length: 49 | for dash_step in dash: 50 | if postion > length: 51 | break 52 | if dash_enabled: 53 | start = postion / length 54 | end = min((postion + dash_step - 1) / length, 1) 55 | self.thick_line([(round(x1 + start * x_length), 56 | round(y1 + start * y_length)), 57 | (round(x1 + end * x_length), 58 | round(y1 + end * y_length))], 59 | xy, fill, width) 60 | dash_enabled = not dash_enabled 61 | postion += dash_step 62 | return 63 | 64 | def dashed_rectangle(self, xy, dash=(2, 2), outline=None, width=0): 65 | x1, y1 = xy[0] 66 | x2, y2 = xy[1] 67 | halfwidth1 = math.floor((width - 1) / 2) 68 | halfwidth2 = math.ceil((width - 1) / 2) 69 | min_dash_gap = min(dash[1::2]) 70 | end_change1 = halfwidth1 + min_dash_gap + 1 71 | end_change2 = halfwidth2 + min_dash_gap + 1 72 | odd_width_change = (width - 1) % 2 73 | self.dashed_line([(x1 - halfwidth1, y1), (x2 - end_change1, y1)], 74 | dash, outline, width) 75 | 
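# the top edge was drawn above; the right, bottom, and left edges below reuse the same dash pattern and width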
self.dashed_line([(x2, y1 - halfwidth1), (x2, y2 - end_change1)], 76 | dash, outline, width) 77 | self.dashed_line([(x2 + halfwidth2, y2 + odd_width_change), 78 | (x1 + end_change2, y2 + odd_width_change)], 79 | dash, outline, width) 80 | self.dashed_line([(x1 + odd_width_change, y2 + halfwidth2), 81 | (x1 + odd_width_change, y1 + end_change2)], 82 | dash, outline, width) 83 | return 84 | 85 | class RectangleDrawer: 86 | def __init__(self, master): 87 | self.master = master 88 | width, height = 512, 512 89 | self.canvas = Canvas(self.master, bg='#F0FFF0', width=width, height=height) 90 | self.canvas.pack() 91 | 92 | self.rectangles = [] 93 | self.colors = ['blue', 'red', 'purple', 'orange', 'green', 'yellow', 'black'] 94 | 95 | self.canvas.bind("<ButtonPress-1>", self.on_button_press) 96 | self.canvas.bind("<B1-Motion>", self.on_move_press) 97 | self.canvas.bind("<ButtonRelease-1>", self.on_button_release) 98 | self.start_x = None 99 | self.start_y = None 100 | self.cur_rect = None 101 | self.master.update() 102 | width = self.master.winfo_width() 103 | height = self.master.winfo_height() 104 | x = (self.master.winfo_screenwidth() // 2) - (width // 2) 105 | y = (self.master.winfo_screenheight() // 2) - (height // 2) 106 | self.master.geometry('{}x{}+{}+{}'.format(width, height, x, y)) 107 | 108 | 109 | def on_button_press(self, event): 110 | self.start_x = event.x 111 | self.start_y = event.y 112 | self.cur_rect = self.canvas.create_rectangle(self.start_x, self.start_y, self.start_x, self.start_y, outline=self.colors[len(self.rectangles)%len(self.colors)], width=5, dash=(4, 4)) 113 | 114 | def on_move_press(self, event): 115 | cur_x, cur_y = (event.x, event.y) 116 | self.canvas.coords(self.cur_rect, self.start_x, self.start_y, cur_x, cur_y) 117 | 118 | def on_button_release(self, event): 119 | cur_x, cur_y = (event.x, event.y) 120 | self.rectangles.append([self.start_x, self.start_y, cur_x, cur_y]) 121 | self.cur_rect = None 122 | 123 | def get_rectangles(self): 124 | return self.rectangles 125 | 126 | 127 | def draw_rectangle(): 128 | root = Tk() 129 | root.title("Rectangle Drawer") 130 | 131 | drawer = RectangleDrawer(root) 132 | 133 | def on_enter_press(event): 134 | root.quit() 135 | 136 | root.bind('<Return>', on_enter_press) 137 | 138 | root.mainloop() 139 | rectangles = drawer.get_rectangles() 140 | 141 | new_rects = [] 142 | for r in rectangles: 143 | new_rects.extend(r) 144 | 145 | return new_rects 146 | 147 | if __name__ == '__main__': 148 | root = Tk() 149 | root.title("Rectangle Drawer") 150 | 151 | drawer = RectangleDrawer(root) 152 | 153 | def on_enter_press(event): 154 | root.quit() 155 | 156 | root.bind('<Return>', on_enter_press) 157 | 158 | root.mainloop() 159 | rectangles = drawer.get_rectangles() 160 | 161 | string = '[' 162 | for r in rectangles: 163 | string += '[' 164 | for n in r: 165 | string += str(n) 166 | string += ',' 167 | string = string[:-1] 168 | string += '],' 169 | string = string[:-1] 170 | string += ']' 171 | print("Rectangles:", string) -------------------------------------------------------------------------------- /utils/gaussian_smoothing.py: -------------------------------------------------------------------------------- 1 | import math 2 | import numbers 3 | import torch 4 | from torch import nn 5 | from torch.nn import functional as F 6 | 7 | 8 | class GaussianSmoothing(nn.Module): 9 | """ 10 | Apply gaussian smoothing on a 11 | 1d, 2d or 3d tensor. Filtering is performed separately for each channel 12 | in the input using a depthwise convolution. 
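The kernel weights are precomputed from `kernel_size` and `sigma` and registered as a fixed (non-trainable) buffer, so the module adds no learnable parameters.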
13 | Arguments: 14 | channels (int, sequence): Number of channels of the input tensors. Output will 15 | have this number of channels as well. 16 | kernel_size (int, sequence): Size of the gaussian kernel. 17 | sigma (float, sequence): Standard deviation of the gaussian kernel. 18 | dim (int, optional): The number of dimensions of the data. 19 | Default value is 2 (spatial). 20 | """ 21 | def __init__(self, channels, kernel_size, sigma, dim=2): 22 | super(GaussianSmoothing, self).__init__() 23 | if isinstance(kernel_size, numbers.Number): 24 | kernel_size = [kernel_size] * dim 25 | if isinstance(sigma, numbers.Number): 26 | sigma = [sigma] * dim 27 | 28 | # The gaussian kernel is the product of the 29 | # gaussian function of each dimension. 30 | kernel = 1 31 | meshgrids = torch.meshgrid( 32 | [ 33 | torch.arange(size, dtype=torch.float32) 34 | for size in kernel_size 35 | ] 36 | ) 37 | for size, std, mgrid in zip(kernel_size, sigma, meshgrids): 38 | mean = (size - 1) / 2 39 | kernel *= 1 / (std * math.sqrt(2 * math.pi)) * \ 40 | torch.exp(-((mgrid - mean) / (2 * std)) ** 2) 41 | 42 | # Make sure sum of values in gaussian kernel equals 1. 43 | kernel = kernel / torch.sum(kernel) 44 | 45 | # Reshape to depthwise convolutional weight 46 | kernel = kernel.view(1, 1, *kernel.size()) 47 | kernel = kernel.repeat(channels, *[1] * (kernel.dim() - 1)) 48 | 49 | self.register_buffer('weight', kernel) 50 | self.groups = channels 51 | 52 | if dim == 1: 53 | self.conv = F.conv1d 54 | elif dim == 2: 55 | self.conv = F.conv2d 56 | elif dim == 3: 57 | self.conv = F.conv3d 58 | else: 59 | raise RuntimeError( 60 | 'Only 1, 2 and 3 dimensions are supported. Received {}.'.format(dim) 61 | ) 62 | 63 | def forward(self, input): 64 | """ 65 | Apply gaussian filter to input. 66 | Arguments: 67 | input (torch.Tensor): Input to apply gaussian filter on. 68 | Returns: 69 | filtered (torch.Tensor): Filtered output. 70 | """ 71 | return self.conv(input, weight=self.weight.to(input.dtype), groups=self.groups) 72 | 73 | 74 | class AverageSmoothing(nn.Module): 75 | """ 76 | Apply average smoothing on a 77 | 1d, 2d or 3d tensor. Filtering is performed separately for each channel 78 | in the input using a depthwise convolution. 79 | Arguments: 80 | channels (int, sequence): Number of channels of the input tensors. Output will 81 | have this number of channels as well. 82 | kernel_size (int, sequence): Size of the average kernel. 83 | sigma (float, sequence): Standard deviation of the average kernel. 84 | dim (int, optional): The number of dimensions of the data. 85 | Default value is 2 (spatial). 86 | """ 87 | def __init__(self, channels, kernel_size, dim=2): 88 | super(AverageSmoothing, self).__init__() 89 | 90 | # Make sure sum of values in the average kernel equals 1. 91 | kernel = torch.ones(size=(kernel_size, kernel_size)) / (kernel_size * kernel_size) 92 | 93 | # Reshape to depthwise convolutional weight 94 | kernel = kernel.view(1, 1, *kernel.size()) 95 | kernel = kernel.repeat(channels, *[1] * (kernel.dim() - 1)) 96 | 97 | self.register_buffer('weight', kernel) 98 | self.groups = channels 99 | 100 | if dim == 1: 101 | self.conv = F.conv1d 102 | elif dim == 2: 103 | self.conv = F.conv2d 104 | elif dim == 3: 105 | self.conv = F.conv3d 106 | else: 107 | raise RuntimeError( 108 | 'Only 1, 2 and 3 dimensions are supported. Received {}.'.format(dim) 109 | ) 110 | 111 | def forward(self, input): 112 | """ 113 | Apply average filter to input. 114 | Arguments: 115 | input (torch.Tensor): Input to apply average filter on. 
116 | Returns: 117 | filtered (torch.Tensor): Filtered output. 118 | """ 119 | return self.conv(input, weight=self.weight, groups=self.groups) 120 | -------------------------------------------------------------------------------- /utils/ptp_utils.py: -------------------------------------------------------------------------------- 1 | import abc 2 | 3 | import cv2 4 | import numpy as np 5 | import torch 6 | from IPython.display import display 7 | from PIL import Image 8 | from typing import Union, Tuple, List 9 | 10 | from diffusers.models.cross_attention import CrossAttention 11 | 12 | def text_under_image(image: np.ndarray, text: str, text_color: Tuple[int, int, int] = (0, 0, 0)) -> np.ndarray: 13 | h, w, c = image.shape 14 | offset = int(h * .2) 15 | img = np.ones((h + offset, w, c), dtype=np.uint8) * 255 16 | font = cv2.FONT_HERSHEY_SIMPLEX 17 | img[:h] = image 18 | textsize = cv2.getTextSize(text, font, 1, 2)[0] 19 | text_x, text_y = (w - textsize[0]) // 2, h + offset - textsize[1] // 2 20 | cv2.putText(img, text, (text_x, text_y), font, 1, text_color, 2) 21 | return img 22 | 23 | 24 | def view_images(images: Union[np.ndarray, List], 25 | num_rows: int = 1, 26 | offset_ratio: float = 0.02, 27 | display_image: bool = True) -> Image.Image: 28 | """ Displays a list of images in a grid. """ 29 | if type(images) is list: 30 | num_empty = len(images) % num_rows 31 | elif images.ndim == 4: 32 | num_empty = images.shape[0] % num_rows 33 | else: 34 | images = [images] 35 | num_empty = 0 36 | 37 | empty_images = np.ones(images[0].shape, dtype=np.uint8) * 255 38 | images = [image.astype(np.uint8) for image in images] + [empty_images] * num_empty 39 | num_items = len(images) 40 | 41 | h, w, c = images[0].shape 42 | offset = int(h * offset_ratio) 43 | num_cols = num_items // num_rows 44 | image_ = np.ones((h * num_rows + offset * (num_rows - 1), 45 | w * num_cols + offset * (num_cols - 1), 3), dtype=np.uint8) * 255 46 | for i in range(num_rows): 47 | for j in range(num_cols): 48 | image_[i * (h + offset): i * (h + offset) + h:, j * (w + offset): j * (w + offset) + w] = images[ 49 | i * num_cols + j] 50 | 51 | pil_img = Image.fromarray(image_) 52 | if display_image: 53 | display(pil_img) 54 | return pil_img 55 | 56 | 57 | class AttendExciteCrossAttnProcessor: 58 | 59 | def __init__(self, attnstore, place_in_unet): 60 | super().__init__() 61 | self.attnstore = attnstore 62 | self.place_in_unet = place_in_unet 63 | 64 | def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None): 65 | batch_size, sequence_length, _ = hidden_states.shape 66 | attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size=1) 67 | query = attn.to_q(hidden_states) 68 | 69 | is_cross = encoder_hidden_states is not None 70 | encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states 71 | key = attn.to_k(encoder_hidden_states) 72 | value = attn.to_v(encoder_hidden_states) 73 | 74 | query = attn.head_to_batch_dim(query) 75 | key = attn.head_to_batch_dim(key) 76 | value = attn.head_to_batch_dim(value) 77 | 78 | attention_probs = attn.get_attention_scores(query, key, attention_mask) 79 | 80 | self.attnstore(attention_probs, is_cross, self.place_in_unet) 81 | 82 | hidden_states = torch.bmm(attention_probs, value) 83 | hidden_states = attn.batch_to_head_dim(hidden_states) 84 | 85 | # linear proj 86 | hidden_states = attn.to_out[0](hidden_states) 87 | # dropout 88 | hidden_states = attn.to_out[1](hidden_states) 89 | 90 
| return hidden_states 91 | 92 | 93 | def register_attention_control(model, controller): 94 | 95 | attn_procs = {} 96 | cross_att_count = 0 97 | for name in model.unet.attn_processors.keys(): 98 | cross_attention_dim = None if name.endswith("attn1.processor") else model.unet.config.cross_attention_dim 99 | if name.startswith("mid_block"): 100 | hidden_size = model.unet.config.block_out_channels[-1] 101 | place_in_unet = "mid" 102 | elif name.startswith("up_blocks"): 103 | block_id = int(name[len("up_blocks.")]) 104 | hidden_size = list(reversed(model.unet.config.block_out_channels))[block_id] 105 | place_in_unet = "up" 106 | elif name.startswith("down_blocks"): 107 | block_id = int(name[len("down_blocks.")]) 108 | hidden_size = model.unet.config.block_out_channels[block_id] 109 | place_in_unet = "down" 110 | else: 111 | continue 112 | 113 | cross_att_count += 1 114 | attn_procs[name] = AttendExciteCrossAttnProcessor( 115 | attnstore=controller, place_in_unet=place_in_unet 116 | ) 117 | model.unet.set_attn_processor(attn_procs) 118 | controller.num_att_layers = cross_att_count 119 | 120 | class AttentionControl(abc.ABC): 121 | 122 | def step_callback(self, x_t): 123 | return x_t 124 | 125 | def between_steps(self): 126 | return 127 | 128 | # @property 129 | # def num_uncond_att_layers(self): 130 | # return 0 131 | 132 | @abc.abstractmethod 133 | def forward(self, attn, is_cross: bool, place_in_unet: str): 134 | raise NotImplementedError 135 | 136 | def __call__(self, attn, is_cross: bool, place_in_unet: str): 137 | if self.cur_att_layer >= self.num_uncond_att_layers: 138 | self.forward(attn, is_cross, place_in_unet) 139 | self.cur_att_layer += 1 140 | if self.cur_att_layer == self.num_att_layers + self.num_uncond_att_layers: 141 | self.cur_att_layer = 0 142 | self.cur_step += 1 143 | self.between_steps() 144 | 145 | def reset(self): 146 | self.cur_step = 0 147 | self.cur_att_layer = 0 148 | 149 | def __init__(self): 150 | self.cur_step = 0 151 | self.num_att_layers = -1 152 | self.cur_att_layer = 0 153 | 154 | 155 | class EmptyControl(AttentionControl): 156 | 157 | def forward(self, attn, is_cross: bool, place_in_unet: str): 158 | return attn 159 | 160 | 161 | class AttentionStore(AttentionControl): 162 | 163 | @staticmethod 164 | def get_empty_store(): 165 | return {"down_cross": [], "mid_cross": [], "up_cross": [], 166 | "down_self": [], "mid_self": [], "up_self": []} 167 | 168 | def forward(self, attn, is_cross: bool, place_in_unet: str): 169 | key = f"{place_in_unet}_{'cross' if is_cross else 'self'}" 170 | if attn.shape[1] <= 32 ** 2: # avoid memory overhead 171 | self.step_store[key].append(attn) 172 | return attn 173 | 174 | def between_steps(self): 175 | self.attention_store = self.step_store 176 | if self.save_global_store: 177 | with torch.no_grad(): 178 | if len(self.global_store) == 0: 179 | self.global_store = self.step_store 180 | else: 181 | for key in self.global_store: 182 | for i in range(len(self.global_store[key])): 183 | self.global_store[key][i] += self.step_store[key][i].detach() 184 | self.step_store = self.get_empty_store() 185 | self.step_store = self.get_empty_store() 186 | 187 | def get_average_attention(self): 188 | average_attention = self.attention_store 189 | return average_attention 190 | 191 | def get_average_global_attention(self): 192 | average_attention = {key: [item / self.cur_step for item in self.global_store[key]] for key in 193 | self.attention_store} 194 | return average_attention 195 | 196 | def reset(self): 197 | super(AttentionStore, 
self).reset() 198 | self.step_store = self.get_empty_store() 199 | self.attention_store = {} 200 | self.global_store = {} 201 | 202 | def __init__(self, save_global_store=False): 203 | ''' 204 | Initialize an empty AttentionStore 205 | :param step_index: used to visualize only a specific step in the diffusion process 206 | ''' 207 | super(AttentionStore, self).__init__() 208 | self.save_global_store = save_global_store 209 | self.step_store = self.get_empty_store() 210 | self.attention_store = {} 211 | self.global_store = {} 212 | self.curr_step_index = 0 213 | self.num_uncond_att_layers = 0 214 | 215 | 216 | def aggregate_attention(attention_store: AttentionStore, 217 | res: int, 218 | from_where: List[str], 219 | is_cross: bool, 220 | select: int) -> torch.Tensor: 221 | """ Aggregates the attention across the different layers and heads at the specified resolution. """ 222 | out = [] 223 | attention_maps = attention_store.get_average_attention() 224 | 225 | num_pixels = res ** 2 226 | for location in from_where: 227 | for item in attention_maps[f"{location}_{'cross' if is_cross else 'self'}"]: 228 | if item.shape[1] == num_pixels: 229 | cross_maps = item.reshape(1, -1, res, res, item.shape[-1])[select] 230 | out.append(cross_maps) 231 | out = torch.cat(out, dim=0) 232 | out = out.sum(0) / out.shape[0] 233 | return out 234 | -------------------------------------------------------------------------------- /utils/vis_utils.py: -------------------------------------------------------------------------------- 1 | import math 2 | from typing import List 3 | from PIL import Image 4 | import cv2 5 | import numpy as np 6 | import torch 7 | 8 | from utils import ptp_utils 9 | from utils.ptp_utils import AttentionStore, aggregate_attention 10 | 11 | 12 | def show_cross_attention(prompt: str, 13 | attention_store: AttentionStore, 14 | tokenizer, 15 | indices_to_alter: List[int], 16 | res: int, 17 | from_where: List[str], 18 | select: int = 0, 19 | orig_image=None): 20 | tokens = tokenizer.encode(prompt) 21 | decoder = tokenizer.decode 22 | attention_maps = aggregate_attention(attention_store, res, from_where, True, select).detach().cpu() 23 | images = [] 24 | 25 | # show spatial attention for indices of tokens to strengthen 26 | for i in range(len(tokens)): 27 | image = attention_maps[:, :, i] 28 | if i in indices_to_alter: 29 | image = show_image_relevance(image, orig_image) 30 | image = image.astype(np.uint8) 31 | image = np.array(Image.fromarray(image).resize((res ** 2, res ** 2))) 32 | image = ptp_utils.text_under_image(image, decoder(int(tokens[i]))) 33 | images.append(image) 34 | 35 | ptp_utils.view_images(np.stack(images, axis=0)) 36 | 37 | 38 | def show_image_relevance(image_relevance, image: Image.Image, relevnace_res=16): 39 | # create heatmap from mask on image 40 | def show_cam_on_image(img, mask): 41 | heatmap = cv2.applyColorMap(np.uint8(255 * mask), cv2.COLORMAP_JET) 42 | heatmap = np.float32(heatmap) / 255 43 | cam = heatmap + np.float32(img) 44 | cam = cam / np.max(cam) 45 | return cam 46 | 47 | image = image.resize((relevnace_res ** 2, relevnace_res ** 2)) 48 | image = np.array(image) 49 | 50 | image_relevance = image_relevance.reshape(1, 1, image_relevance.shape[-1], image_relevance.shape[-1]) 51 | image_relevance = image_relevance.cuda() # because float16 precision interpolation is not supported on cpu 52 | image_relevance = torch.nn.functional.interpolate(image_relevance, size=relevnace_res ** 2, mode='bilinear') 53 | image_relevance = image_relevance.cpu() # send it back to cpu 
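# min-max normalize the relevance map and the image to [0, 1] before blending them into a heatmap overlay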
54 | image_relevance = (image_relevance - image_relevance.min()) / (image_relevance.max() - image_relevance.min()) 55 | image_relevance = image_relevance.reshape(relevnace_res ** 2, relevnace_res ** 2) 56 | image = (image - image.min()) / (image.max() - image.min()) 57 | vis = show_cam_on_image(image, image_relevance) 58 | vis = np.uint8(255 * vis) 59 | vis = cv2.cvtColor(np.array(vis), cv2.COLOR_RGB2BGR) 60 | return vis 61 | 62 | 63 | def get_image_grid(images: List[Image.Image]) -> Image: 64 | num_images = len(images) 65 | cols = int(math.ceil(math.sqrt(num_images))) 66 | rows = int(math.ceil(num_images / cols)) 67 | width, height = images[0].size 68 | grid_image = Image.new('RGB', (cols * width, rows * height)) 69 | for i, img in enumerate(images): 70 | x = i % cols 71 | y = i // cols 72 | grid_image.paste(img, (x * width, y * height)) 73 | return grid_image 74 | --------------------------------------------------------------------------------