├── utils ├── __init__.py ├── vis_utils.py ├── gaussian_smoothing.py ├── attention_mask.py └── ptp_utils.py ├── prompts └── demo_prompt.txt ├── README.md ├── .gitignore └── run_maskdiffusion.py /utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /prompts/demo_prompt.txt: -------------------------------------------------------------------------------- 1 | a red car and a pink sheep 2 | a blue apple and a green vase 3 | a gold car and a red clock 4 | a brown bench and a green dog 5 | a red bird and a brown bowl 6 | a green bench and a blue car 7 | a green apple and a brown sheep 8 | a green clock and a gold vase 9 | a red bowl and a blue cup 10 | a green bench and a brown clock -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # MaskDiffusion: Boosting Text-to-Image Consistency with Conditional Mask (IJCV) 2 | 3 | ## 🚩 **TODO/Updates** 4 | - [x] Basic Code. 5 | - [ ] Demo 6 | - [ ] Integration with LLM Priors. 7 | - [ ] Support for Other Pre-trained Models 8 | --- 9 | ## Setup 10 | 11 | ### Environment 12 | Our code builds on the requirements of the official [Stable Diffusion repository](https://github.com/CompVis/stable-diffusion). To set up the environment, please run: 13 | 14 | ``` 15 | conda env create -f environment/environment.yaml 16 | conda activate maskdiffusion 17 | 18 | ``` 19 | 20 | ## Testing 21 | 22 | ### Testing with the mini-testset 23 | ``` 24 | python run_maskdiffusion.py 25 | ``` 26 | 27 | 28 | 29 | 30 | ## Acknowledgements 31 | The original code builds on the [diffusers](https://github.com/huggingface/diffusers) library and [Prompt-to-Prompt](https://github.com/google/prompt-to-prompt/). When reorganizing the code for public release, I also reused code from [Attend-and-Excite](https://github.com/yuval-alaluf/Attend-and-Excite) and [DenseDiffusion](https://github.com/naver-ai/DenseDiffusion) to achieve a more concise implementation. 32 | 33 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # ignored folders 2 | datasets/* 3 | visual/* 4 | tb_logger/* 5 | model_output/* 6 | wandb/* 7 | tmp/* 8 | environment/* 9 | 10 | docs/api 11 | scripts/__init__.py 12 | 13 | *.DS_Store 14 | .idea 15 | 16 | # ignored files 17 | version.py 18 | 19 | # ignored files with suffix 20 | *.html 21 | *.png 22 | *.jpeg 23 | *.jpg 24 | *.gif 25 | *.pth 26 | *.zip 27 | *.pt 28 | # template 29 | 30 | # Byte-compiled / optimized / DLL files 31 | __pycache__/ 32 | *.py[cod] 33 | *$py.class 34 | 35 | # C extensions 36 | *.so 37 | 38 | # Distribution / packaging 39 | .Python 40 | build/ 41 | develop-eggs/ 42 | dist/ 43 | downloads/ 44 | eggs/ 45 | .eggs/ 46 | lib/ 47 | lib64/ 48 | parts/ 49 | sdist/ 50 | var/ 51 | wheels/ 52 | *.egg-info/ 53 | .installed.cfg 54 | *.egg 55 | MANIFEST 56 | 57 | # PyInstaller 58 | # Usually these files are written by a python script from a template 59 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
60 | *.manifest 61 | *.spec 62 | 63 | # Installer logs 64 | pip-log.txt 65 | pip-delete-this-directory.txt 66 | 67 | # Unit test / coverage reports 68 | htmlcov/ 69 | .tox/ 70 | .coverage 71 | .coverage.* 72 | .cache 73 | nosetests.xml 74 | coverage.xml 75 | *.cover 76 | .hypothesis/ 77 | .pytest_cache/ 78 | 79 | # Translations 80 | *.mo 81 | *.pot 82 | 83 | # Django stuff: 84 | *.log 85 | local_settings.py 86 | db.sqlite3 87 | 88 | # Flask stuff: 89 | instance/ 90 | .webassets-cache 91 | 92 | # Scrapy stuff: 93 | .scrapy 94 | 95 | # Sphinx documentation 96 | docs/_build/ 97 | 98 | # PyBuilder 99 | target/ 100 | 101 | # Jupyter Notebook 102 | .ipynb_checkpoints 103 | 104 | # pyenv 105 | .python-version 106 | 107 | # celery beat schedule file 108 | celerybeat-schedule 109 | 110 | # SageMath parsed files 111 | *.sage.py 112 | 113 | # Environments 114 | .env 115 | .venv 116 | env/ 117 | venv/ 118 | ENV/ 119 | env.bak/ 120 | venv.bak/ 121 | 122 | # Spyder project settings 123 | .spyderproject 124 | .spyproject 125 | 126 | # Rope project settings 127 | .ropeproject 128 | 129 | # mkdocs documentation 130 | /site 131 | 132 | # mypy 133 | .mypy_cache/ -------------------------------------------------------------------------------- /utils/vis_utils.py: -------------------------------------------------------------------------------- 1 | import math 2 | from typing import List 3 | from PIL import Image 4 | import cv2 5 | import numpy as np 6 | import torch 7 | 8 | from utils import ptp_utils 9 | from utils.ptp_utils import AttentionStore, aggregate_attention 10 | 11 | 12 | def show_cross_attention(prompt: str, 13 | attention_store: AttentionStore, 14 | tokenizer, 15 | indices_to_alter: List[int], 16 | res: int, 17 | from_where: List[str], 18 | select: int = 0, 19 | orig_image=None): 20 | tokens = tokenizer.encode(prompt) 21 | decoder = tokenizer.decode 22 | attention_maps = aggregate_attention(attention_store, res, from_where, True, select).detach().cpu() 23 | images = [] 24 | 25 | # show spatial attention for indices of tokens to strengthen 26 | for i in range(len(tokens)): 27 | image = attention_maps[:, :, i] 28 | if i in indices_to_alter: 29 | image = show_image_relevance(image, orig_image) 30 | image = image.astype(np.uint8) 31 | image = np.array(Image.fromarray(image).resize((res ** 2, res ** 2))) 32 | image = ptp_utils.text_under_image(image, decoder(int(tokens[i]))) 33 | images.append(image) 34 | 35 | ptp_utils.view_images(np.stack(images, axis=0)) 36 | 37 | 38 | def show_image_relevance(image_relevance, image: Image.Image, relevnace_res=16): 39 | # create heatmap from mask on image 40 | def show_cam_on_image(img, mask): 41 | heatmap = cv2.applyColorMap(np.uint8(255 * mask), cv2.COLORMAP_JET) 42 | heatmap = np.float32(heatmap) / 255 43 | cam = heatmap + np.float32(img) 44 | cam = cam / np.max(cam) 45 | return cam 46 | 47 | image = image.resize((relevnace_res ** 2, relevnace_res ** 2)) 48 | image = np.array(image) 49 | 50 | image_relevance = image_relevance.reshape(1, 1, image_relevance.shape[-1], image_relevance.shape[-1]) 51 | image_relevance = image_relevance.cuda() # because float16 precision interpolation is not supported on cpu 52 | image_relevance = torch.nn.functional.interpolate(image_relevance, size=relevnace_res ** 2, mode='bilinear') 53 | image_relevance = image_relevance.cpu() # send it back to cpu 54 | image_relevance = (image_relevance - image_relevance.min()) / (image_relevance.max() - image_relevance.min()) 55 | image_relevance = image_relevance.reshape(relevnace_res ** 2, 
relevnace_res ** 2) 56 | image = (image - image.min()) / (image.max() - image.min()) 57 | vis = show_cam_on_image(image, image_relevance) 58 | vis = np.uint8(255 * vis) 59 | vis = cv2.cvtColor(np.array(vis), cv2.COLOR_RGB2BGR) 60 | return vis 61 | 62 | 63 | def get_image_grid(images: List[Image.Image]) -> Image: 64 | num_images = len(images) 65 | cols = int(math.ceil(math.sqrt(num_images))) 66 | rows = int(math.ceil(num_images / cols)) 67 | width, height = images[0].size 68 | grid_image = Image.new('RGB', (cols * width, rows * height)) 69 | for i, img in enumerate(images): 70 | x = i % cols 71 | y = i // cols 72 | grid_image.paste(img, (x * width, y * height)) 73 | return grid_image 74 | -------------------------------------------------------------------------------- /utils/gaussian_smoothing.py: -------------------------------------------------------------------------------- 1 | import math 2 | import numbers 3 | import torch 4 | from torch import nn 5 | from torch.nn import functional as F 6 | 7 | 8 | class GaussianSmoothing(nn.Module): 9 | """ 10 | Apply gaussian smoothing on a 11 | 1d, 2d or 3d tensor. Filtering is performed seperately for each channel 12 | in the input using a depthwise convolution. 13 | Arguments: 14 | channels (int, sequence): Number of channels of the input tensors. Output will 15 | have this number of channels as well. 16 | kernel_size (int, sequence): Size of the gaussian kernel. 17 | sigma (float, sequence): Standard deviation of the gaussian kernel. 18 | dim (int, optional): The number of dimensions of the data. 19 | Default value is 2 (spatial). 20 | """ 21 | def __init__(self, channels, kernel_size, sigma, dim=2): 22 | super(GaussianSmoothing, self).__init__() 23 | if isinstance(kernel_size, numbers.Number): 24 | kernel_size = [kernel_size] * dim 25 | if isinstance(sigma, numbers.Number): 26 | sigma = [sigma] * dim 27 | 28 | # The gaussian kernel is the product of the 29 | # gaussian function of each dimension. 30 | kernel = 1 31 | meshgrids = torch.meshgrid( 32 | [ 33 | torch.arange(size, dtype=torch.float32) 34 | for size in kernel_size 35 | ] 36 | ) 37 | for size, std, mgrid in zip(kernel_size, sigma, meshgrids): 38 | mean = (size - 1) / 2 39 | kernel *= 1 / (std * math.sqrt(2 * math.pi)) * \ 40 | torch.exp(-((mgrid - mean) / (2 * std)) ** 2) 41 | 42 | # Make sure sum of values in gaussian kernel equals 1. 43 | kernel = kernel / torch.sum(kernel) 44 | 45 | # Reshape to depthwise convolutional weight 46 | kernel = kernel.view(1, 1, *kernel.size()) 47 | kernel = kernel.repeat(channels, *[1] * (kernel.dim() - 1)) 48 | 49 | self.register_buffer('weight', kernel) 50 | self.groups = channels 51 | 52 | if dim == 1: 53 | self.conv = F.conv1d 54 | elif dim == 2: 55 | self.conv = F.conv2d 56 | elif dim == 3: 57 | self.conv = F.conv3d 58 | else: 59 | raise RuntimeError( 60 | 'Only 1, 2 and 3 dimensions are supported. Received {}.'.format(dim) 61 | ) 62 | 63 | def forward(self, input): 64 | """ 65 | Apply gaussian filter to input. 66 | Arguments: 67 | input (torch.Tensor): Input to apply gaussian filter on. 68 | Returns: 69 | filtered (torch.Tensor): Filtered output. 70 | """ 71 | return self.conv(input, weight=self.weight.to(input.dtype), groups=self.groups) 72 | 73 | 74 | class AverageSmoothing(nn.Module): 75 | """ 76 | Apply average smoothing on a 77 | 1d, 2d or 3d tensor. Filtering is performed seperately for each channel 78 | in the input using a depthwise convolution. 79 | Arguments: 80 | channels (int, sequence): Number of channels of the input tensors. 
Output will 81 | have this number of channels as well. 82 | kernel_size (int, sequence): Size of the average kernel. 83 | (No sigma parameter is needed; the averaging kernel is uniform.) 84 | dim (int, optional): The number of dimensions of the data. 85 | Default value is 2 (spatial). 86 | """ 87 | def __init__(self, channels, kernel_size, dim=2): 88 | super(AverageSmoothing, self).__init__() 89 | 90 | # Make sure sum of values in average kernel equals 1. 91 | kernel = torch.ones(size=(kernel_size, kernel_size)) / (kernel_size * kernel_size) 92 | 93 | # Reshape to depthwise convolutional weight 94 | kernel = kernel.view(1, 1, *kernel.size()) 95 | kernel = kernel.repeat(channels, *[1] * (kernel.dim() - 1)) 96 | 97 | self.register_buffer('weight', kernel) 98 | self.groups = channels 99 | 100 | if dim == 1: 101 | self.conv = F.conv1d 102 | elif dim == 2: 103 | self.conv = F.conv2d 104 | elif dim == 3: 105 | self.conv = F.conv3d 106 | else: 107 | raise RuntimeError( 108 | 'Only 1, 2 and 3 dimensions are supported. Received {}.'.format(dim) 109 | ) 110 | 111 | def forward(self, input): 112 | """ 113 | Apply average filter to input. 114 | Arguments: 115 | input (torch.Tensor): Input to apply average filter on. 116 | Returns: 117 | filtered (torch.Tensor): Filtered output. 118 | """ 119 | return self.conv(input, weight=self.weight, groups=self.groups) 120 | -------------------------------------------------------------------------------- /run_maskdiffusion.py: -------------------------------------------------------------------------------- 1 | ####### 2 | # Code refactored to implement MaskDiffusion. 3 | # The original implementation was created in April 2023 for a conference submission. 4 | # The new implementation reuses code from "attend-and-excite" and "densediffusion" to achieve a simpler implementation on 2023.9.20. 5 | ####### 6 | import os 7 | from typing import Callable, Dict, List, Optional, Tuple, Union 8 | 9 | import numpy as np 10 | import torch 11 | import torch.nn.functional as nnf 12 | from PIL import Image 13 | from diffusers import DDIMScheduler, StableDiffusionPipeline 14 | 15 | from utils import ptp_utils 16 | from utils.attention_mask import MaskdiffusionStore 17 | from utils.ptp_utils import AttentionStore 18 | 19 | 20 | MY_TOKEN = '' 21 | LOW_RESOURCE = False 22 | NUM_DIFFUSION_STEPS = 50 23 | GUIDANCE_SCALE = 7.5 24 | MAX_NUM_WORDS = 77 25 | SAVE_PATH = "./generated_images/test/" 26 | 27 | 28 | # Set up the device 29 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 30 | 31 | 32 | # Load the Stable Diffusion model 33 | stable_diffusion_version = "CompVis/stable-diffusion-v1-4" 34 | stable = StableDiffusionPipeline.from_pretrained( 35 | stable_diffusion_version, safety_checker=None 36 | ).to(device) 37 | 38 | # Configure the Stable Diffusion scheduler 39 | tokenizer = stable.tokenizer 40 | stable.scheduler = DDIMScheduler.from_config(stable.scheduler.config) 41 | stable.scheduler.set_timesteps(50) 42 | 43 | 44 | 45 | @torch.no_grad() 46 | def run_on_prompt(prompts: List[str], 47 | pipe: StableDiffusionPipeline, 48 | controller: AttentionStore, 49 | seed: torch.Generator, 50 | mask_dict=None) -> Image.Image: 51 | """ 52 | Generate an image based on a list of prompts using a Stable Diffusion pipeline. 53 | 54 | Args: 55 | prompts (List[str]): List of prompts. The first is the main prompt, and the rest are sub-prompts. 56 | pipe (StableDiffusionPipeline): Pre-trained Stable Diffusion pipeline. 57 | controller (AttentionStore): Controller to manage attention control. 
58 | seed (torch.Generator): Random seed for reproducibility. 59 | mask_dict (dict, optional): Masking dictionary, if any. 60 | 61 | Returns: 62 | Image.Image: Generated image. 63 | """ 64 | 65 | # Step 1: Register attention control, if applicable 66 | if controller is not None: 67 | ptp_utils.register_MA_attention_control(pipe, controller) 68 | 69 | # Step 2: Encode text embeddings for prompts 70 | text_input = pipe.tokenizer( 71 | prompts, 72 | padding="max_length", 73 | return_length=True, 74 | return_overflowing_tokens=False, 75 | max_length=pipe.tokenizer.model_max_length, 76 | truncation=True, 77 | return_tensors="pt" 78 | ) 79 | cond_embeddings = pipe.text_encoder(text_input.input_ids.to(pipe.device))[0] 80 | 81 | # Encode unconditional (blank) embedding 82 | uncond_input = pipe.tokenizer( 83 | [""], 84 | padding="max_length", 85 | max_length=pipe.tokenizer.model_max_length, 86 | truncation=True, 87 | return_tensors="pt" 88 | ) 89 | uncond_embeddings = pipe.text_encoder(uncond_input.input_ids.to(pipe.device))[0] 90 | 91 | # Step 3: Swap sub-prompt embeddings into the main prompt 92 | def find_and_replace_embedding(main_ids, sub_ids, cond_embeddings, i): 93 | """Find and replace sub-prompt embeddings in the main prompt.""" 94 | wlen = len(sub_ids) # Length of sub-prompt tokens 95 | for j in range(len(main_ids) - wlen + 1): # Search within the main prompt tokens 96 | if torch.all(main_ids[j:j + wlen] == sub_ids): 97 | # Replace the embedding for the matched tokens 98 | cond_embeddings[0][j:j + wlen] = cond_embeddings[i][1:1 + wlen] 99 | return [k for k in range(j, j + wlen)] # Return matched indices 100 | raise ValueError(f"Sub-prompt '{prompts[i]}' not found in the main prompt!") 101 | 102 | # Extract token IDs for the main prompt and initialize lists 103 | main_ids = text_input['input_ids'][0] 104 | indice_lists = [] 105 | 106 | # Process each sub-prompt 107 | for i in range(1, len(prompts)): 108 | sub_ids = text_input['input_ids'][i][1:text_input['length'][i] - 1] # Remove special tokens 109 | indice_lists.append(find_and_replace_embedding(main_ids, sub_ids, cond_embeddings, i)) 110 | 111 | # Step 4: Build token dictionary for tracking 112 | token_dict = {indices[-1]: indices for indices in indice_lists} 113 | controller.token_dict = token_dict 114 | controller.text_cond = torch.cat([uncond_embeddings, cond_embeddings[0].unsqueeze(0)]) 115 | controller.timesteps = pipe.scheduler.timesteps 116 | controller.text_length = text_input['length'][0] 117 | 118 | print(f"Token Dictionary: {token_dict}, Main Prompt: {prompts[0]}") 119 | 120 | 121 | # Step 5: Generate image using the pipeline 122 | outputs = pipe( 123 | prompt=prompts[0], 124 | generator=seed, 125 | num_inference_steps=NUM_DIFFUSION_STEPS, 126 | guidance_scale=GUIDANCE_SCALE, 127 | height=512, 128 | width=512 129 | ) 130 | return outputs.images[0] 131 | 132 | 133 | 134 | promptsarr = [] 135 | 136 | # Read the prompt file and append each line to the list 137 | with open("prompts/demo_prompt.txt", "r") as f: 138 | for line in f: 139 | promptsarr.append(line.strip()) # strip() removes the trailing newline and surrounding whitespace 140 | 141 | 142 | # Iterate over promptsarr 143 | for i in range(len(promptsarr)): 144 | target_text = promptsarr[i] 145 | 146 | g = torch.Generator('cuda').manual_seed(0) 147 | 148 | # Split the prompt at ' and ' 149 | before_and, and_word, after_and = target_text.partition(' and ') 150 | 151 | # Create the controller and build the save path 152 | controller = MaskdiffusionStore() 153 | save_path = os.path.join(SAVE_PATH, target_text) # join the path components with os.path.join 154 | 155 | # Create the output directory 156 | os.makedirs(save_path, exist_ok=True) 157 | controller.save_path = 
save_path 158 | 159 | # Generate the image 160 | image = run_on_prompt(prompts=[target_text, before_and, after_and], 161 | pipe=stable, 162 | controller=controller, 163 | seed=g) 164 | 165 | 166 | # Save the generated image 167 | img_mask_diff = image 168 | img_mask_diff.save(os.path.join(save_path, f"save_img.png")) 169 | -------------------------------------------------------------------------------- /utils/attention_mask.py: -------------------------------------------------------------------------------- 1 | ####### 2 | # Code refactored to implement the mask mechanism. 3 | # The original implementation was created in April 2023 for a conference submission. 4 | # The implementation reuses code from "attend-and-excite" and "densediffusion" to achieve a simpler implementation on 2023.9.20. 5 | ####### 6 | 7 | 8 | from select import select 9 | from typing import Optional, Union, Tuple, List, Callable, Dict 10 | import torch 11 | from diffusers import StableDiffusionPipeline 12 | import torch.nn.functional as nnf 13 | import numpy as np 14 | import abc 15 | 16 | from utils import ptp_utils 17 | from utils.ptp_utils import AttentionStore as AttentionStore 18 | from PIL import Image 19 | device = torch.device('cuda') 20 | import math 21 | 22 | 23 | @torch.no_grad() 24 | def getGaussianKernel(ksize, sigma=0): 25 | if sigma <= 0: 26 | # Compute the default sigma from the kernel size, consistent with OpenCV 27 | sigma = 0.3 * ((ksize - 1) * 0.5 - 1) + 0.8 28 | center = ksize // 2 29 | xs = (np.arange(ksize, dtype=np.float32) - center) # horizontal distance of each element from the kernel center 30 | kernel1d = np.exp(-(xs ** 2) / (2 * sigma ** 2)) # 1D Gaussian kernel 31 | # The 2D kernel is the outer product of the 1D kernel with itself (separability of the exponential) 32 | kernel = kernel1d[..., None] @ kernel1d[None, ...] 33 | kernel = torch.from_numpy(kernel) 34 | kernel = kernel / kernel.sum() # normalize 35 | return kernel 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | def combine_color_layers(bool_layers, color_map= [(255, 0, 0), (0, 255, 0), (0, 0, 255)]): 45 | """ 46 | Combines several boolean layers each representing a color into a single RGB image. 47 | 48 | Parameters: 49 | bool_layers (numpy.ndarray): A boolean array of shape (n, h, w), each layer n representing a different color. 50 | color_map (list): List of tuples, each representing the RGB values for the corresponding layer. 51 | 52 | Returns: 53 | PIL.Image: The combined RGB image. 54 | """ 55 | bool_layers = bool_layers.cpu().numpy().astype(np.int32) 56 | 57 | if bool_layers.shape[0] > len(color_map): 58 | raise ValueError(f"Number of layers {bool_layers.shape[0]} exceeds the number of colors {len(color_map)}.") 59 | 60 | # Initialize an empty image array with zeros. 61 | h, w = bool_layers.shape[1], bool_layers.shape[2] 62 | n = bool_layers.shape[0] 63 | image_array = np.zeros((h, w, 3), dtype=np.int32) 64 | 65 | # Add each color layer to the image. 
66 | for i in range(n): 67 | color = color_map[i] 68 | layer = bool_layers[i] 69 | # print(layer) 70 | # Only add where the boolean layer is True 71 | for c in range(3): # RGB channels 72 | image_array[:, :, c] += (color[c] * layer) 73 | 74 | # Clip values to be in valid range (0-255) after combining colors 75 | image_array = np.clip(image_array, 0, 255).astype(np.uint8) 76 | 77 | return Image.fromarray(image_array, 'RGB') 78 | 79 | 80 | def aggregate_attention(attention_store: AttentionStore, res: int, from_where: List[str], is_cross: bool, select: int): 81 | out = [] 82 | attention_maps = attention_store.get_average_attention() 83 | num_pixels = res ** 2 84 | for location in from_where: 85 | for item in attention_maps[f"{location}_{'cross' if is_cross else 'self'}"]: 86 | if item.shape[1] == num_pixels: 87 | cross_maps = item.reshape(1, -1, res, res, item.shape[-1])[select] 88 | out.append(cross_maps) 89 | out = torch.cat(out, dim=0) 90 | out = out.sum(0) / out.shape[0] 91 | return out.cpu() 92 | 93 | 94 | def point_to_line_distance(point, endpoint1, endpoint2): 95 | """ 96 | Evaluate a normalized elliptical distance of a point from the segment defined by two endpoints. 97 | 98 | Args: 99 | point (tuple): The coordinates of the point (x, y). 100 | endpoint1 (tuple): The coordinates of the first endpoint (x1, y1). 101 | endpoint2 (tuple): The coordinates of the second endpoint (x2, y2). 102 | 103 | Returns: 104 | float: The ellipse-equation value; values below 1 lie inside the elliptical region. 105 | """ 106 | x1, y1 = endpoint1 107 | x2, y2 = endpoint2 108 | x, y = point 109 | a = np.sqrt((x2-x1)**2+ (y2-y1)**2)/2 110 | b = a/2.5 111 | # Numerator of the distance formula 112 | numerator = abs((y2 - y1) * x - (x2 - x1) * y + x2 * y1 - y2 * x1) 113 | # Denominator of the distance formula 114 | denominator = np.sqrt((y2 - y1)**2 + (x2 - x1)**2) 115 | if denominator == 0: 116 | raise ValueError("The two endpoints cannot be the same point.") 117 | 118 | # Calculate the distance 119 | ya = numerator / denominator 120 | 121 | # distance from the point to the midpoint of the segment 122 | l = np.sqrt((y - (y1+y2)/2)**2 + (x - (x2 + x1)/2)**2) 123 | # 124 | if l**2-y**2 <=0: 125 | xa = 0 126 | else: 127 | xa = np.sqrt(l**2-y**2) 128 | return (xa*xa) / (a*a) + (ya*ya) / (b*b) 129 | 130 | 131 | 132 | 133 | def show_self_attention_comp(attention_store: AttentionStore, res: int, from_where: List[str], 134 | max_com=10, select: int = 0): 135 | attention_maps = aggregate_attention(attention_store, res, from_where, False, select).numpy().reshape((res ** 2, res ** 2)) 136 | u, s, vh = np.linalg.svd(attention_maps - np.mean(attention_maps, axis=1, keepdims=True)) 137 | images = [] 138 | for i in range(max_com): 139 | image = vh[i].reshape(res, res) 140 | image = image - image.min() 141 | image = 255 * image / image.max() 142 | image = np.repeat(np.expand_dims(image, axis=2), 3, axis=2).astype(np.uint8) 143 | image = Image.fromarray(image).resize((256, 256)) 144 | image = np.array(image) 145 | images.append(image) 146 | ptp_utils.view_images(np.concatenate(images, axis=1)) 147 | 148 | def judge_region(x,y,h,w,object_nums): 149 | band_width = 16// object_nums 150 | if (x ) // band_width == (x + h)//band_width or (y )// band_width == (y + w) //band_width: 151 | return True 152 | else: 153 | return False 154 | 155 | def find_max_sum_submatrix(matrix, area,object_nums): 156 | max_sum = 0 157 | 158 | num_rows, num_cols = len(matrix), len(matrix[0]) 159 | # print(num_rows,num_cols) 160 | hmax = int(area**0.5) 161 | # hmin = int(area / (16 * 0.7)) 162 | possiblehws = [] 163 | for h_pos in range(4,hmax): 164 | 
w_pos = int(area/h_pos) 165 | possiblehws.append((h_pos,w_pos)) 166 | possiblehws.append((w_pos,h_pos)) 167 | xm = 0 168 | ym = 0 169 | hwm = possiblehws[0] 170 | # Iterate over every possible top-left corner of the rectangle 171 | for i in range(0,num_rows): 172 | for j in range(0,num_cols): 173 | for hw_now in possiblehws: 174 | h,w = hw_now 175 | if i+h >= num_rows-1 or j+w >= num_cols-1: 176 | continue 177 | # if not judge_region(i,j,h,w,object_nums): 178 | # continue 179 | current_sum = 0 180 | 181 | # Sum all elements inside this rectangle 182 | for ki in range(h): 183 | for kj in range(w): 184 | current_sum += matrix[i + ki][j + kj] 185 | # Update the best (maximum mean) rectangle 186 | if max_sum < current_sum/(h*w): 187 | hwm = hw_now 188 | xm = i 189 | ym = j 190 | max_sum = current_sum/(h*w) 191 | 192 | mask = torch.zeros((64,64)).to(matrix.device) 193 | real_mask = torch.zeros((64,64)).to(matrix.device) 194 | forbid_mask = torch.zeros_like(matrix) 195 | hm,wm = hwm 196 | forbid_mask[xm:xm+hm,ym:ym+wm] = 1 197 | # forbid_mask[ 198 | # (xm - 1 if xm - 1 > 0 else xm) : (xm + hm + 1 if xm + hm + 1 < num_rows else xm + hm), 199 | # (ym - 1 if ym - 1 > 0 else ym) : (ym + wm + 1 if ym + wm + 1 < num_cols else ym + wm) 200 | # ] = 1 201 | hm *= 4 202 | wm *= 4 203 | xm *= 4 204 | ym *= 4 205 | mask[xm:xm+hm,ym:ym+wm] = 1 206 | for iterh in range(hm+4): 207 | for iterw in range(wm+4): 208 | if xm+iterh-2 >= num_rows*4-1 or ym+iterw-2 >= num_cols*4-1: 209 | continue 210 | if xm+iterh-2 <0 or ym+iterw-2 < 0: 211 | continue 212 | if hm > wm: 213 | if point_to_line_distance((xm+iterh-2,ym+iterw-2), (xm,ym), (xm+hm,ym+wm))<1: 214 | real_mask[xm+iterh-2,ym+iterw-2] =1 215 | else: 216 | if point_to_line_distance((xm+iterh-2,ym+iterw-2), (xm,ym), (xm+hm,ym+wm))<1: 217 | real_mask[xm+iterh-2,ym+iterw-2] =1 218 | return mask,mask,forbid_mask #mask,real_mask,forbid_mask 219 | 220 | def resize_attn(attn_map,th,tw): 221 | h = int(math.sqrt(attn_map.shape[1])) 222 | w = h 223 | bz, _, c = attn_map.shape 224 | attn_map = attn_map.permute(0,2,1).reshape(bz,c,h,w) 225 | attn_map = torch.mean(attn_map,dim = 0,keepdim = True) 226 | resized_attn_map = nnf.interpolate(attn_map, size=(th,tw), mode='bilinear').squeeze(0) 227 | return resized_attn_map 228 | def text_list(text): 229 | text = text.replace(' ','') 230 | text = text.replace('\n','') 231 | text = text.replace('\t','') 232 | digits = text[1:-1].split(',') 233 | # import pdb; pdb.set_trace() 234 | result = [] 235 | for d in digits: 236 | result.append(int(d)) 237 | return tuple(result) 238 | class MaskdiffusionStore: 239 | 240 | @staticmethod 241 | def get_empty_store(): 242 | return {"down_cross": [], "mid_cross": [], "up_cross": [], 243 | "down_self": [], "mid_self": [], "up_self": []} 244 | @property 245 | def num_uncond_att_layers(self): 246 | return 0 247 | def forward(self, attn,is_cross: bool, place_in_unet: str): 248 | if not is_cross: 249 | if self.cur_step < 15: 250 | return self.add_self_mask(attn) 251 | else: 252 | return attn 253 | key = f"{place_in_unet}_{'cross' if is_cross else 'self'}" 254 | if self.cur_step > 1: 255 | attn = self.add_mask(attn) 256 | if attn.shape[1] == 16 ** 2: # avoid memory overhead 257 | self.step_store[key].append(attn[int(attn.size(0)/2):].softmax(-1).clone()) 258 | return attn 259 | 260 | 261 | def add_mask(self,sim): 262 | reduction_factor= (self.timesteps[self.cur_step]/1000) ** 5 263 | batch_size = int(sim.shape[0]/2) 264 | if self.cur_step == 0: 265 | return sim 266 | heads_nums = sim.shape[0]//2 267 | mask = self.mask[sim.size(1)].repeat(heads_nums,1,1) 268 | ## Calculate the minimum and maximum values 
along the token dimension. 269 | min_value = sim[batch_size:].min(-1)[0].unsqueeze(-1) 270 | max_value = sim[batch_size:].max(-1)[0].unsqueeze(-1) 271 | sim[batch_size:] += reduction_factor*1*(mask>0.05)*mask*(max_value-sim[batch_size:]) 272 | if self.cur_step <15: 273 | sim[batch_size:] = sim[batch_size:] - (mask==0)*reduction_factor*(sim[batch_size:]-min_value) + sim[batch_size:] * (mask>= 0.05) 274 | 275 | 276 | return sim 277 | def add_self_mask(self,sim): 278 | reduction_factor = 0.3*(self.timesteps[self.cur_step]/1000) ** 5 279 | batch_size = int(sim.shape[0]/2) 280 | if self.cur_step == 0: 281 | return sim 282 | 283 | heads_nums = sim.shape[0]//2 284 | mask = self.self_mask[sim.size(1)].repeat(heads_nums,1,1) 285 | 286 | min_value = sim[batch_size:].min(-1)[0].unsqueeze(-1) 287 | max_value = sim[batch_size:].max(-1)[0].unsqueeze(-1) 288 | sim[batch_size:] += (mask>0)*reduction_factor*(max_value-sim[batch_size:]) 289 | sim[batch_size:] -= (mask == 0)*reduction_factor*(sim[batch_size:]-min_value) 290 | return sim 291 | 292 | 293 | def between_steps(self): 294 | self.attention_store = self.step_store 295 | if self.save_global_store: 296 | with torch.no_grad(): 297 | if len(self.global_store) == 0: 298 | self.global_store = self.step_store 299 | else: 300 | for key in self.global_store: 301 | for i in range(len(self.global_store[key])): 302 | self.global_store[key][i] += self.step_store[key][i].detach() 303 | self.step_store = self.get_empty_store() 304 | 305 | def get_average_attention(self): 306 | average_attention = self.attention_store 307 | return average_attention 308 | 309 | def get_average_global_attention(self): 310 | average_attention = {key: [item / self.cur_step for item in self.global_store[key]] for key in 311 | self.attention_store} 312 | return average_attention 313 | def __call__(self, attn, is_cross: bool, place_in_unet: str): 314 | if self.cur_step >= self.end_step: 315 | return attn 316 | if self.cur_att_layer >= self.num_uncond_att_layers: 317 | attn = self.forward(attn, is_cross, place_in_unet) 318 | self.cur_att_layer += 1 319 | if self.cur_att_layer == self.num_att_layers + self.num_uncond_att_layers: 320 | self.cur_att_layer = 0 321 | self.cur_step += 1 322 | self.between_steps() 323 | if self.cur_step <5: 324 | self.generated_mask2() 325 | return attn 326 | def reset(self): 327 | 328 | self.cur_step = 0 329 | self.cur_att_layer = 0 330 | #### 331 | self.step_store = self.get_empty_store() 332 | self.attention_store = {} 333 | self.global_store = {} 334 | 335 | def get_input_mask(self,token_id,device): 336 | y0,x0,y1,x1 = self.input_mask[token_id] 337 | x0 = x0 // 8 338 | x1 = x1 // 8 339 | y0 = y0 // 8 340 | y1 = y1 // 8 341 | num_rows = 16 342 | num_cols = 16 343 | hm = x1-x0 344 | wm = y1-y0 345 | 346 | mask = torch.zeros((64,64)).to(device) 347 | real_mask = torch.zeros((64,64)).to(device) 348 | forbid_mask = torch.zeros((16,16)).to(device) 349 | forbid_mask[x0//8:x1//8,y0//8:y1//8] = 1 350 | mask[x0:x1,y0:y1] = 1 351 | 352 | for iterh in range(hm+4): 353 | for iterw in range(wm+4): 354 | if x0 + iterh-2 >= num_rows * 4 - 1 or y0 + iterw-2 >= num_cols * 4 - 1: 355 | continue 356 | if x0 + iterh-2 <0 or y0 + iterw-2 < 0: 357 | continue 358 | if hm > wm: 359 | if point_to_line_distance((x0+iterh-2,y0+iterw-2), (x0,y0), (x0+hm,y0+wm))<1: 360 | real_mask[x0+iterh-2,y0+iterw-2] =1 361 | else: 362 | if point_to_line_distance((x0+iterh-2,y0+iterw-2), (x0,y0), (x0+hm,y0+wm))<1: 363 | real_mask[x0+iterh-2,y0+iterw-2] =1 364 | 365 | return mask,mask,forbid_mask 366 | 
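# generated_mask2 (below) builds the conditional masks that add_mask / add_self_mask apply:
# 1. average the stored 16x16 cross-attention maps (down_cross + up_cross) into one (77, 16*16) map;
# 2. for each object token in token_dict, protect its top ~15% attention pixels by halving the other tokens' attention there;
# 3. pick a rectangular region per object, either the maximum-mean-attention rectangle (find_max_sum_submatrix, about 20% of the 16x16 area) or a user-supplied box via input_mask, and strongly suppress that region for the remaining objects;
# 4. write the boxes into 64x64 cross-attention mask maps (weight 0.20, scaled by the strength w=5 for the object's final token), form self-attention layout maps from their outer products, and store both per resolution (64, 32, 16, 8) in self.mask and self.self_mask.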
367 | def generated_mask2(self): 368 | maskres = 16 369 | ##### Strength = 5 370 | w = 5 371 | average_attention = self.get_average_global_attention() 372 | extract_attentions = average_attention["down_cross"] + average_attention["up_cross"] 373 | print(len(extract_attentions)) 374 | ## Bz*head, h*w, 77 BZ=1 375 | extract_attentions = [resize_attn(extract_attention,maskres,maskres) for extract_attention in extract_attentions] 376 | mean_attentions = torch.mean(torch.stack(extract_attentions),dim = 0,keepdim = False).reshape(77,-1) 377 | # print(mean_attentions.shape) # 77 H,W 378 | num_pixels = mean_attentions.shape[-1] 379 | token_ids = list(self.token_dict.keys()) 380 | protect_indexs = {} 381 | protect_attentions = mean_attentions.clone() 382 | ## new zeros mask_maps 383 | mask_maps = torch.zeros(1,77,64,64).to(device) 384 | mask_maps[:,0] = 0.01 385 | mask_maps[:,self.text_length-2:,:,:] += 0.20 386 | self_maps = torch.zeros(1,maskres*maskres,maskres*maskres).to(device) 387 | negative_mask_map = torch.zeros(1,77,maskres,maskres).to(device) 388 | ## perform protection 389 | for token_id in token_ids: 390 | protect_indexs[token_id] = torch.topk(mean_attentions[token_id],int(num_pixels*0.15),0).indices 391 | sub_prompt_ids = self.token_dict[token_id] 392 | minus_mask = torch.zeros_like(mean_attentions, dtype=torch.bool) 393 | 394 | minus_mask[:,protect_indexs[token_id]] = True 395 | minus_mask[sub_prompt_ids] = False 396 | mean_attentions[minus_mask] /= 2 397 | save_mask_maps = [] 398 | save_real_maps = [] 399 | def get_other_ids(token_id): 400 | id_pools = [] 401 | for iter_id in list(self.token_dict.keys()): 402 | if token_id != iter_id: 403 | id_pools += self.token_dict[iter_id] 404 | return id_pools 405 | ## select the max rectangle 406 | object_nums = len(token_ids) 407 | for token_id in token_ids: 408 | sub_prompt_ids = self.token_dict[token_id] 409 | curmap = mean_attentions[token_id].clone() 410 | if self.input_mask is None: 411 | real_mask,mask_map,forbid_mask = find_max_sum_submatrix(curmap.reshape(maskres,maskres),maskres**2*0.2,object_nums) 412 | else: 413 | real_mask,mask_map,forbid_mask = self.get_input_mask(token_id,curmap.device) 414 | # mask_maps[0,get_other_ids(token_id)] -= 1* mask_map 415 | mask_maps[0,sub_prompt_ids] = 0.20 * mask_map 416 | mask_maps[0,token_id] = 0.20 * mask_map * w 417 | # negative_mask_map[0,sub_prompt_ids] = 1-mask_map 418 | save_mask_maps.append(mask_maps[0,token_id].clone()) 419 | save_real_maps.append(real_mask.clone()) 420 | # self_maps[0,:] += mask_map.reshape(-1,1) * (1 - mask_map.reshape(1,-1)) 421 | # negative_self_mask_map.append(mask_maps[0,token_id].clone()) 422 | mean_attentions -= (forbid_mask.reshape(1,-1)>0) * mean_attentions * 3 423 | # save_mask_maps[:,token_id] = mask_maps[:,token_id] 424 | save_mask_maps = torch.stack(save_mask_maps) # N x H x W 425 | 426 | 427 | maskdict = {} 428 | sreg_maps = {} 429 | negative_mask_maps = {} 430 | sreg_maps = {} 431 | negative_self_maps = {} 432 | for r in range(4): 433 | res = int(64/np.power(2,r)) 434 | layouts_s = nnf.interpolate(save_mask_maps.unsqueeze(1),(res, res),mode='nearest') 435 | layouts_s = (layouts_s.view(layouts_s.size(0),1,-1)*layouts_s.view(layouts_s.size(0),-1,1)).sum(0).unsqueeze(0) 436 | sreg_maps[np.power(res, 2)] = layouts_s 437 | layout_c = nnf.interpolate(mask_maps,(res,res),mode='nearest').view(1,77,-1).permute(0,2,1) 438 | maskdict[np.power(res, 2)] = layout_c 439 | 440 | 441 | self.mask = maskdict 442 | self.self_mask = sreg_maps 443 | self.negative_mask_maps = 
negative_mask_maps 444 | self.negative_self_maps = negative_self_maps 445 | 446 | 447 | return None 448 | def __init__(self, save_global_store=True,end_step = 15,token_dict = None): 449 | ''' 450 | Initialize an empty AttentionStore 451 | :param step_index: used to visualize only a specific step in the diffusion process 452 | ''' 453 | # super(AttentionStore, self).__init__() 454 | self.cur_step = 0 455 | self.num_att_layers = -1 456 | self.cur_att_layer = 0 457 | #### 458 | self.save_global_store = save_global_store 459 | self.step_store = self.get_empty_store() 460 | self.attention_store = {} 461 | self.global_store = {} 462 | #### maskdiffusion parameter 463 | self.text_cond = None 464 | self.mask = None 465 | self.end_step = end_step 466 | self.token_dict = token_dict 467 | #### 468 | self.timesteps = None 469 | self.self_mask = None 470 | self.input_mask = None 471 | 472 | -------------------------------------------------------------------------------- /utils/ptp_utils.py: -------------------------------------------------------------------------------- 1 | import abc 2 | 3 | import cv2 4 | import numpy as np 5 | import torch 6 | from IPython.display import display 7 | from PIL import Image 8 | from typing import Union, Tuple, List 9 | 10 | 11 | try: 12 | from diffusers.models.cross_attention import CrossAttention 13 | except ImportError: 14 | pass 15 | else: 16 | pass 17 | 18 | 19 | def get_word_inds(text: str, word_place: int, tokenizer): 20 | split_text = text.split(" ") 21 | if type(word_place) is str: 22 | word_place = [i for i, word in enumerate(split_text) if word_place == word] 23 | elif type(word_place) is int: 24 | word_place = [word_place] 25 | out = [] 26 | if len(word_place) > 0: 27 | words_encode = [tokenizer.decode([item]).strip("#") for item in tokenizer.encode(text)][1:-1] 28 | cur_len, ptr = 0, 0 29 | for i in range(len(words_encode)): 30 | cur_len += len(words_encode[i]) 31 | if ptr in word_place: 32 | out.append(i + 1) 33 | if cur_len >= len(split_text[ptr]): 34 | ptr += 1 35 | cur_len = 0 36 | return np.array(out) 37 | 38 | 39 | def update_alpha_time_word(alpha, bounds: Union[float, Tuple[float, float]], prompt_ind: int, 40 | word_inds = None): 41 | if type(bounds) is float: 42 | bounds = 0, bounds 43 | start, end = int(bounds[0] * alpha.shape[0]), int(bounds[1] * alpha.shape[0]) 44 | if word_inds is None: 45 | word_inds = torch.arange(alpha.shape[2]) 46 | alpha[: start, prompt_ind, word_inds] = 0 47 | alpha[start: end, prompt_ind, word_inds] = 1 48 | alpha[end:, prompt_ind, word_inds] = 0 49 | return alpha 50 | 51 | 52 | def get_time_words_attention_alpha(prompts, num_steps, 53 | cross_replace_steps , 54 | tokenizer, max_num_words=77): 55 | if type(cross_replace_steps) is not dict: 56 | cross_replace_steps = {"default_": cross_replace_steps} 57 | if "default_" not in cross_replace_steps: 58 | cross_replace_steps["default_"] = (0., 1.) 
59 | alpha_time_words = torch.zeros(num_steps + 1, 1, max_num_words) 60 | for i in range(len(prompts)): 61 | alpha_time_words = update_alpha_time_word(alpha_time_words, cross_replace_steps["default_"], 62 | i) 63 | print(cross_replace_steps.items()) 64 | for key, item in cross_replace_steps.items(): 65 | if key != "default_": 66 | inds = [get_word_inds(prompts[i], key, tokenizer) for i in range(1, len(prompts))] 67 | for i, ind in enumerate(inds): 68 | if len(ind) > 0: 69 | alpha_time_words = update_alpha_time_word(alpha_time_words, item, i, ind) 70 | alpha_time_words = alpha_time_words.reshape(num_steps + 1, len(prompts), 1, 1, max_num_words) 71 | return alpha_time_words 72 | 73 | def text_under_image(image: np.ndarray, text: str, text_color: Tuple[int, int, int] = (0, 0, 0)) -> np.ndarray: 74 | h, w, c = image.shape 75 | offset = int(h * .2) 76 | img = np.ones((h + offset, w, c), dtype=np.uint8) * 255 77 | font = cv2.FONT_HERSHEY_SIMPLEX 78 | img[:h] = image 79 | textsize = cv2.getTextSize(text, font, 1, 2)[0] 80 | text_x, text_y = (w - textsize[0]) // 2, h + offset - textsize[1] // 2 81 | cv2.putText(img, text, (text_x, text_y), font, 1, text_color, 2) 82 | return img 83 | 84 | 85 | def view_images(images: Union[np.ndarray, List], 86 | num_rows: int = 1, 87 | offset_ratio: float = 0.02, 88 | display_image: bool = True) -> Image.Image: 89 | """ Displays a list of images in a grid. """ 90 | if type(images) is list: 91 | num_empty = len(images) % num_rows 92 | elif images.ndim == 4: 93 | num_empty = images.shape[0] % num_rows 94 | else: 95 | images = [images] 96 | num_empty = 0 97 | 98 | empty_images = np.ones(images[0].shape, dtype=np.uint8) * 255 99 | images = [image.astype(np.uint8) for image in images] + [empty_images] * num_empty 100 | num_items = len(images) 101 | 102 | h, w, c = images[0].shape 103 | offset = int(h * offset_ratio) 104 | num_cols = num_items // num_rows 105 | image_ = np.ones((h * num_rows + offset * (num_rows - 1), 106 | w * num_cols + offset * (num_cols - 1), 3), dtype=np.uint8) * 255 107 | for i in range(num_rows): 108 | for j in range(num_cols): 109 | image_[i * (h + offset): i * (h + offset) + h:, j * (w + offset): j * (w + offset) + w] = images[ 110 | i * num_cols + j] 111 | 112 | pil_img = Image.fromarray(image_) 113 | if display_image: 114 | display(pil_img) 115 | return pil_img 116 | 117 | 118 | class MaskdiffusionCrossAttnProcessor: 119 | 120 | def __init__(self, attnstore, place_in_unet): 121 | super().__init__() 122 | self.attnstore = attnstore 123 | self.place_in_unet = place_in_unet 124 | 125 | def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None): 126 | batch_size, sequence_length, _ = hidden_states.shape 127 | attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length) 128 | 129 | query = attn.to_q(hidden_states) 130 | 131 | is_cross = encoder_hidden_states is not None 132 | encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states 133 | key = attn.to_k(encoder_hidden_states) 134 | value = attn.to_v(encoder_hidden_states) 135 | # key_1 = attn.to_q(hidden_states) 136 | # value_1 = attn.to_q(hidden_states) 137 | query = attn.head_to_batch_dim(query) 138 | key = attn.head_to_batch_dim(key) 139 | # key_1 = attn.head_to_batch_dim(key_1) 140 | value = attn.head_to_batch_dim(value) 141 | # value_1 = attn.head_to_batch_dim(value_1) 142 | if attn.upcast_attention: 143 | query = query.float() 144 | key = key.float() 145 | dtype = query.dtype 146 | # key = 
torch.cat([key, key_1], dim=1) 147 | # value = torch.cat([value, value_1], dim=1) 148 | # print(key.shape,value.shape,query.shape) 149 | attention_scores = torch.baddbmm( 150 | torch.empty(query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device), 151 | query, 152 | key.transpose(-1, -2), 153 | beta=0, 154 | alpha=attn.scale, 155 | ) 156 | 157 | # print(attention_scores) 158 | attention_probs = attn.get_attention_scores(query, key, attention_mask) 159 | 160 | attention_scores = self.attnstore(attention_probs.clone(), attention_scores.clone() ,is_cross, self.place_in_unet) 161 | # print(attention_scores) 162 | # exit() 163 | if attention_mask is not None: 164 | attention_scores = attention_scores + attention_mask 165 | 166 | if attn.upcast_softmax: 167 | attention_scores = attention_scores.float() 168 | 169 | attention_probs = attention_scores.softmax(dim=-1) 170 | attention_probs = attention_probs.to(dtype) 171 | 172 | hidden_states = torch.bmm(attention_probs, value) 173 | hidden_states = attn.batch_to_head_dim(hidden_states) 174 | 175 | # linear proj 176 | hidden_states = attn.to_out[0](hidden_states) 177 | # dropout 178 | hidden_states = attn.to_out[1](hidden_states) 179 | 180 | return hidden_states 181 | 182 | class AttendExciteCrossAttnProcessor: 183 | 184 | def __init__(self, attnstore, place_in_unet): 185 | super().__init__() 186 | self.attnstore = attnstore 187 | self.place_in_unet = place_in_unet 188 | 189 | def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None): 190 | batch_size, sequence_length, _ = hidden_states.shape 191 | attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length) 192 | 193 | query = attn.to_q(hidden_states) 194 | 195 | is_cross = encoder_hidden_states is not None 196 | encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states 197 | key = attn.to_k(encoder_hidden_states) 198 | value = attn.to_v(encoder_hidden_states) 199 | 200 | query = attn.head_to_batch_dim(query) 201 | key = attn.head_to_batch_dim(key) 202 | value = attn.head_to_batch_dim(value) 203 | 204 | attention_probs = attn.get_attention_scores(query, key, attention_mask) 205 | 206 | self.attnstore(attention_probs, None, is_cross, self.place_in_unet) 207 | 208 | hidden_states = torch.bmm(attention_probs, value) 209 | hidden_states = attn.batch_to_head_dim(hidden_states) 210 | 211 | # linear proj 212 | hidden_states = attn.to_out[0](hidden_states) 213 | # dropout 214 | hidden_states = attn.to_out[1](hidden_states) 215 | 216 | return hidden_states 217 | 218 | class MACrossAttnProcessor: 219 | 220 | def __init__(self, attnstore, place_in_unet): 221 | super().__init__() 222 | self.attnstore = attnstore 223 | self.place_in_unet = place_in_unet 224 | 225 | def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None): 226 | batch_size, sequence_length, _ = hidden_states.shape 227 | attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length) 228 | query = attn.to_q(hidden_states) 229 | 230 | is_cross = encoder_hidden_states is not None 231 | if is_cross: 232 | # print(encoder_hidden_states.shape,self.attnstore.text_cond.shape) 233 | if self.attnstore.cur_step <= self.attnstore.end_step: 234 | encoder_hidden_states = self.attnstore.text_cond 235 | else: 236 | encoder_hidden_states = encoder_hidden_states 237 | else: 238 | encoder_hidden_states = hidden_states 239 | key = attn.to_k(encoder_hidden_states) 240 | value = 
attn.to_v(encoder_hidden_states) 241 | 242 | query = attn.head_to_batch_dim(query) 243 | key = attn.head_to_batch_dim(key) 244 | value = attn.head_to_batch_dim(value) 245 | 246 | attention_scores = torch.baddbmm( 247 | torch.empty(query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device), 248 | query, 249 | key.transpose(-1, -2), 250 | beta=0, 251 | alpha=attn.scale, 252 | ) 253 | 254 | attention_scores = self.attnstore(attention_scores, is_cross, self.place_in_unet) 255 | attention_probs = attention_scores.softmax(dim=-1) 256 | hidden_states = torch.bmm(attention_probs, value) 257 | hidden_states = attn.batch_to_head_dim(hidden_states) 258 | 259 | # linear proj 260 | hidden_states = attn.to_out[0](hidden_states) 261 | # dropout 262 | hidden_states = attn.to_out[1](hidden_states) 263 | 264 | return hidden_states 265 | 266 | def register_attend_attention_control(model, controller): 267 | 268 | attn_procs = {} 269 | cross_att_count = 0 270 | for name in model.unet.attn_processors.keys(): 271 | cross_attention_dim = None if name.endswith("attn1.processor") else model.unet.config.cross_attention_dim 272 | if name.startswith("mid_block"): 273 | hidden_size = model.unet.config.block_out_channels[-1] 274 | place_in_unet = "mid" 275 | elif name.startswith("up_blocks"): 276 | block_id = int(name[len("up_blocks.")]) 277 | hidden_size = list(reversed(model.unet.config.block_out_channels))[block_id] 278 | place_in_unet = "up" 279 | elif name.startswith("down_blocks"): 280 | block_id = int(name[len("down_blocks.")]) 281 | hidden_size = model.unet.config.block_out_channels[block_id] 282 | place_in_unet = "down" 283 | else: 284 | continue 285 | 286 | cross_att_count += 1 287 | attn_procs[name] = AttendExciteCrossAttnProcessor( 288 | attnstore=controller, place_in_unet=place_in_unet 289 | ) 290 | 291 | model.unet.set_attn_processor(attn_procs) 292 | controller.num_att_layers = cross_att_count 293 | 294 | 295 | def register_maskdiffusion_attention_control(model, controller): 296 | 297 | attn_procs = {} 298 | cross_att_count = 0 299 | for name in model.unet.attn_processors.keys(): 300 | cross_attention_dim = None if name.endswith("attn1.processor") else model.unet.config.cross_attention_dim 301 | if name.startswith("mid_block"): 302 | hidden_size = model.unet.config.block_out_channels[-1] 303 | place_in_unet = "mid" 304 | elif name.startswith("up_blocks"): 305 | block_id = int(name[len("up_blocks.")]) 306 | hidden_size = list(reversed(model.unet.config.block_out_channels))[block_id] 307 | place_in_unet = "up" 308 | elif name.startswith("down_blocks"): 309 | block_id = int(name[len("down_blocks.")]) 310 | hidden_size = model.unet.config.block_out_channels[block_id] 311 | place_in_unet = "down" 312 | else: 313 | continue 314 | 315 | cross_att_count += 1 316 | attn_procs[name] = MaskdiffusionCrossAttnProcessor( 317 | attnstore=controller, place_in_unet=place_in_unet 318 | ) 319 | 320 | model.unet.set_attn_processor(attn_procs) 321 | controller.num_att_layers = cross_att_count 322 | 323 | 324 | def register_deepfloyd_attention_control(model, controller): 325 | 326 | attn_procs = {} 327 | cross_att_count = 0 328 | for name in model.unet.attn_processors.keys(): 329 | cross_attention_dim = None if name.endswith("attn1.processor") else model.unet.config.cross_attention_dim 330 | if name.startswith("mid_block"): 331 | hidden_size = model.unet.config.block_out_channels[-1] 332 | place_in_unet = "mid" 333 | elif name.startswith("up_blocks"): 334 | block_id = int(name[len("up_blocks.")]) 335 
| hidden_size = list(reversed(model.unet.config.block_out_channels))[block_id] 336 | place_in_unet = "up" 337 | elif name.startswith("down_blocks"): 338 | block_id = int(name[len("down_blocks.")]) 339 | hidden_size = model.unet.config.block_out_channels[block_id] 340 | place_in_unet = "down" 341 | else: 342 | continue 343 | 344 | cross_att_count += 1 345 | attn_procs[name] = AttendExciteCrossAttnProcessor( 346 | attnstore=controller, place_in_unet=place_in_unet 347 | ) 348 | 349 | model.unet.set_attn_processor(attn_procs) 350 | controller.num_att_layers = cross_att_count 351 | 352 | class AttentionControl(abc.ABC): 353 | 354 | def step_callback(self, x_t): 355 | return x_t 356 | 357 | def between_steps(self): 358 | return 359 | 360 | @property 361 | def num_uncond_att_layers(self): 362 | return 0 363 | 364 | @abc.abstractmethod 365 | def forward(self, attn, is_cross: bool, place_in_unet: str): 366 | raise NotImplementedError 367 | 368 | def __call__(self, attn,attn2, is_cross: bool, place_in_unet: str): 369 | 370 | if self.cur_att_layer >= self.num_uncond_att_layers: 371 | attnnew = self.forward(attn,attn2, is_cross, place_in_unet) 372 | self.cur_att_layer += 1 373 | if self.cur_att_layer == self.num_att_layers + self.num_uncond_att_layers: 374 | self.cur_att_layer = 0 375 | self.cur_step += 1 376 | self.between_steps() 377 | return attnnew 378 | 379 | def reset(self): 380 | self.cur_step = 0 381 | self.cur_att_layer = 0 382 | 383 | def __init__(self): 384 | self.cur_step = 0 385 | self.num_att_layers = -1 386 | self.cur_att_layer = 0 387 | 388 | 389 | class EmptyControl(AttentionControl): 390 | 391 | def forward(self, attn, is_cross: bool, place_in_unet: str): 392 | return attn 393 | 394 | 395 | class AttentionStore(AttentionControl): 396 | 397 | @staticmethod 398 | def get_empty_store(): 399 | return {"down_cross": [], "mid_cross": [], "up_cross": [], 400 | "down_self": [], "mid_self": [], "up_self": []} 401 | 402 | def forward(self, attn, attn2,is_cross: bool, place_in_unet: str): 403 | key = f"{place_in_unet}_{'cross' if is_cross else 'self'}" 404 | if attn.shape[1] <= 32 ** 2: # avoid memory overhead 405 | self.step_store[key].append(attn) 406 | return attn2 407 | 408 | def between_steps(self): 409 | self.attention_store = self.step_store 410 | if self.save_global_store: 411 | with torch.no_grad(): 412 | if len(self.global_store) == 0: 413 | self.global_store = self.step_store 414 | else: 415 | for key in self.global_store: 416 | for i in range(len(self.global_store[key])): 417 | self.global_store[key][i] += self.step_store[key][i].detach() 418 | self.step_store = self.get_empty_store() 419 | 420 | def get_average_attention(self): 421 | average_attention = self.attention_store 422 | return average_attention 423 | 424 | def get_average_global_attention(self): 425 | average_attention = {key: [item / self.cur_step for item in self.global_store[key]] for key in 426 | self.attention_store} 427 | return average_attention 428 | 429 | def reset(self): 430 | super(AttentionStore, self).reset() 431 | self.step_store = self.get_empty_store() 432 | self.attention_store = {} 433 | self.global_store = {} 434 | 435 | def __init__(self, save_global_store=False): 436 | ''' 437 | Initialize an empty AttentionStore 438 | :param step_index: used to visualize only a specific step in the diffusion process 439 | ''' 440 | super(AttentionStore, self).__init__() 441 | #### 442 | self.save_global_store = save_global_store 443 | self.step_store = self.get_empty_store() 444 | self.attention_store = {} 445 | 
self.global_store = {} 446 | self.curr_step_index = 0 447 | 448 | 449 | 450 | def aggregate_attention(attention_store: AttentionStore, 451 | res: int, 452 | from_where: List[str], 453 | is_cross: bool, 454 | select: int) -> torch.Tensor: 455 | """ Aggregates the attention across the different layers and heads at the specified resolution. """ 456 | out = [] 457 | attention_maps = attention_store.get_average_attention() 458 | num_pixels = res ** 2 459 | for location in from_where: 460 | for item in attention_maps[f"{location}_{'cross' if is_cross else 'self'}"]: 461 | if item.shape[1] == num_pixels: 462 | #print(item.shape[1],num_pixels) 463 | cross_maps = item.reshape(1, -1, res, res, item.shape[-1])[select] 464 | out.append(cross_maps) 465 | out = torch.cat(out, dim=0) 466 | out = out.sum(0) / out.shape[0] 467 | return out 468 | 469 | def aggregate_attention_head(attention_store: AttentionStore, 470 | res: int, 471 | from_where: List[str], 472 | is_cross: bool, 473 | select: int) -> torch.Tensor: 474 | """ Aggregates the attention across the different layers and heads at the specified resolution. """ 475 | out = [] 476 | attention_maps = attention_store.get_average_attention() 477 | num_pixels = res ** 2 478 | for location in from_where: 479 | for item in attention_maps[f"{location}_{'cross' if is_cross else 'self'}"]: 480 | if item.shape[1] == num_pixels: 481 | cross_maps = item.reshape(1, -1, res, res, item.shape[-1])[select] 482 | out.append(cross_maps.unsqueeze(0)) 483 | out = torch.cat(out, dim=0) 484 | out = out.sum(0) / out.shape[0] 485 | return out 486 | 487 | 488 | def register_MA_attention_control(model, controller): 489 | 490 | attn_procs = {} 491 | cross_att_count = 0 492 | for name in model.unet.attn_processors.keys(): 493 | cross_attention_dim = None if name.endswith("attn1.processor") else model.unet.config.cross_attention_dim 494 | if name.startswith("mid_block"): 495 | hidden_size = model.unet.config.block_out_channels[-1] 496 | place_in_unet = "mid" 497 | elif name.startswith("up_blocks"): 498 | block_id = int(name[len("up_blocks.")]) 499 | hidden_size = list(reversed(model.unet.config.block_out_channels))[block_id] 500 | place_in_unet = "up" 501 | elif name.startswith("down_blocks"): 502 | block_id = int(name[len("down_blocks.")]) 503 | hidden_size = model.unet.config.block_out_channels[block_id] 504 | place_in_unet = "down" 505 | else: 506 | continue 507 | 508 | cross_att_count += 1 509 | attn_procs[name] = MACrossAttnProcessor( 510 | attnstore=controller, place_in_unet=place_in_unet 511 | ) 512 | 513 | model.unet.set_attn_processor(attn_procs) 514 | controller.num_att_layers = cross_att_count 515 | --------------------------------------------------------------------------------
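For reference, a minimal single-prompt usage sketch distilled from `run_maskdiffusion.py` above; the prompt string and output directory are illustrative, and since `run_maskdiffusion.py` runs its demo loop at import time, in practice you would copy `run_on_prompt` into your own script (or guard the loop with `if __name__ == "__main__":`) rather than import it as shown here.

```
import os
import torch
from run_maskdiffusion import run_on_prompt, stable  # pipeline and helper defined above
from utils.attention_mask import MaskdiffusionStore

# Any "<object A> and <object B>" prompt; the two sub-prompts guide the per-object masks.
prompt = "a red car and a pink sheep"
before_and, _, after_and = prompt.partition(' and ')

controller = MaskdiffusionStore()
save_dir = os.path.join("./generated_images/test/", prompt)
os.makedirs(save_dir, exist_ok=True)
controller.save_path = save_dir

g = torch.Generator('cuda').manual_seed(0)
image = run_on_prompt(prompts=[prompt, before_and, after_and],
                      pipe=stable,
                      controller=controller,
                      seed=g)
image.save(os.path.join(save_dir, "save_img.png"))
```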