├── .gitignore
├── images
│   ├── 00006-2623256163.png
│   ├── 00006-2623256163_cat.png
│   ├── 00006-2623256163_beach.png
│   └── 00006-2623256163_sunglasses.png
├── scripts
│   ├── daam
│   │   ├── __init__.py
│   │   ├── hook.py
│   │   ├── evaluate.py
│   │   ├── utils.py
│   │   ├── experiment.py
│   │   └── trace.py
│   └── daam_script.py
├── install.py
├── README.md
└── LICENSE
/.gitignore:
--------------------------------------------------------------------------------
1 | .vscode/
2 | .env
3 |
4 | __pycache__/
--------------------------------------------------------------------------------
/images/00006-2623256163.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kousw/stable-diffusion-webui-daam/HEAD/images/00006-2623256163.png
--------------------------------------------------------------------------------
/images/00006-2623256163_cat.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kousw/stable-diffusion-webui-daam/HEAD/images/00006-2623256163_cat.png
--------------------------------------------------------------------------------
/images/00006-2623256163_beach.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kousw/stable-diffusion-webui-daam/HEAD/images/00006-2623256163_beach.png
--------------------------------------------------------------------------------
/images/00006-2623256163_sunglasses.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kousw/stable-diffusion-webui-daam/HEAD/images/00006-2623256163_sunglasses.png
--------------------------------------------------------------------------------
/scripts/daam/__init__.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 | from .hook import *
3 | from .utils import *
4 | from .evaluate import *
5 | from .experiment import *
6 | from .trace import *
--------------------------------------------------------------------------------
/install.py:
--------------------------------------------------------------------------------
1 | import launch
2 |
3 |
4 | def check_matplotlib():
5 | if not launch.is_installed("matplotlib"):
6 | return False
7 |
8 | try:
9 | import matplotlib
10 | except ImportError:
11 | return False
12 |
13 | if hasattr(matplotlib, "__version_info__"):
14 | version = matplotlib.__version_info__
15 | version = (version.major, version.minor, version.micro)
16 | return version >= (3, 6, 2)
17 | return False
18 |
19 |
20 | if not check_matplotlib():
21 | launch.run_pip("install matplotlib==3.6.2", desc="Installing matplotlib==3.6.2")
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # DAAM Extension for Stable Diffusion Web UI
2 |
3 | This is a port of [DAAM](https://github.com/castorini/daam) for Stable Diffusion Web UI.
4 |
5 | *Now supports SDXL!*
6 |
7 | # Update
8 | - `2024/01/28` Fixed an issue where it did not work with the latest version of webui. Added support for SDXL.
9 |
10 | # Setup and Running
11 |
12 | Clone this repository into the `extensions` folder of the Web UI.
13 |
14 | # How to use
15 |
16 | Select "Daam script" from the script drop-down, enter the 'attention text' (it must be a string contained in the prompt), and run.
17 | An overlay image with a heatmap for each attention text is generated along with the original image.
18 | The images are saved to the default output directory.
19 |
20 | Attention text is split on commas; multiple words without a comma between them are treated as a single sequence.
21 | If you enter "cat" as the attention text, every token matching "cat" is retrieved and their attention maps are combined.
22 | If you enter "cute cat", only occurrences of "cute" and "cat" as consecutive tokens are retrieved, and only their attention is output.
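
The sketch below (illustrative only; the extension matches against the CLIP tokenizer of the loaded model, not a plain whitespace split) shows the consecutive-token rule:

```python
def find_matches(prompt_tokens, needle_tokens):
    """Return the start indices where needle_tokens occurs as a consecutive run."""
    n = len(needle_tokens)
    return [i for i in range(len(prompt_tokens) - n + 1)
            if prompt_tokens[i:i + n] == needle_tokens]

prompt = "a photo of a cute cat wearing sunglasses relaxing on a beach".split()
print(find_matches(prompt, ["cat"]))          # -> [5]  (every "cat" occurrence matches)
print(find_matches(prompt, ["cute", "cat"]))  # -> [4]  (only the consecutive pair matches)
```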
23 |
24 | # Sample
25 |
26 | prompt : "A photo of a cute cat wearing sunglasses relaxing on a beach"
27 |
28 | attention text: "cat, sunglasses, beach"
29 |
30 | output images: original, cat, sunglasses, beach
31 |
32 |
33 |
34 |
35 |
36 |
37 | # Tutorial
38 |
39 | - [Easiest way to use DAAM script tutorial](https://www.youtube.com/watch?v=XiKyEKJrTLQ)
40 |
41 | [![DAAM script tutorial](https://img.youtube.com/vi/XiKyEKJrTLQ/0.jpg)](https://www.youtube.com/watch?v=XiKyEKJrTLQ)
42 |
43 | # Notice
44 | At the moment, this works well with Stable Diffusion 1.5 models.
45 | With Stable Diffusion 2.0 models, however, the results seem somewhat less accurate.
46 |
47 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | The MIT License (MIT)
2 |
3 | Modified work Copyright (c) 2022 kousw
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
23 |
24 | ### castorini/daam
25 |
26 | License: MIT
27 | By: Castorini
28 | Repository: <https://github.com/castorini/daam>
29 |
30 | >MIT License
31 | >
32 | >Copyright (c) 2022 Castorini
33 | >
34 | >Permission is hereby granted, free of charge, to any person obtaining a copy
35 | >of this software and associated documentation files (the "Software"), to deal
36 | >in the Software without restriction, including without limitation the rights
37 | >to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
38 | >copies of the Software, and to permit persons to whom the Software is
39 | >furnished to do so, subject to the following conditions:
40 | >
41 | >The above copyright notice and this permission notice shall be included in all
42 | >copies or substantial portions of the Software.
43 | >
44 | >THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
45 | >IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
46 | >FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
47 | >AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
48 | >LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
49 | >OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
50 | >SOFTWARE.
--------------------------------------------------------------------------------
/scripts/daam/hook.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 | from typing import List, Generic, TypeVar, Callable, Union, Any
3 | import functools
4 | import itertools
5 | from ldm.modules.attention import SpatialTransformer
6 |
7 | from ldm.modules.diffusionmodules.openaimodel import UNetModel
8 | from ldm.modules.attention import CrossAttention
9 | import torch.nn as nn
10 |
11 |
12 | __all__ = ['ObjectHooker', 'ModuleLocator', 'AggregateHooker', 'UNetCrossAttentionLocator']
13 |
14 |
15 | ModuleType = TypeVar('ModuleType')
16 | ModuleListType = TypeVar('ModuleListType', bound=List)
17 |
18 |
19 | class ModuleLocator(Generic[ModuleType]):
20 | def locate(self, model: nn.Module) -> List[ModuleType]:
21 | raise NotImplementedError
22 |
23 |
24 | class ObjectHooker(Generic[ModuleType]):
25 | def __init__(self, module: ModuleType):
26 | self.module: ModuleType = module
27 | self.hooked = False
28 | self.old_state = dict()
29 |
30 | def __enter__(self):
31 | self.hook()
32 | return self
33 |
34 | def __exit__(self, exc_type, exc_val, exc_tb):
35 | self.unhook()
36 |
37 | def hook(self):
38 | if self.hooked:
39 | raise RuntimeError('Already hooked module')
40 |
41 | self.old_state = dict()
42 | self.hooked = True
43 | self._hook_impl()
44 |
45 | return self
46 |
47 | def unhook(self):
48 | if not self.hooked:
49 | raise RuntimeError('Module is not hooked')
50 |
51 | for k, v in self.old_state.items():
52 | if k.startswith('old_fn_'):
53 | setattr(self.module, k[7:], v)
54 |
55 | self.hooked = False
56 | self._unhook_impl()
57 |
58 | return self
59 |
60 | def monkey_patch(self, fn_name, fn):
61 | self.old_state[f'old_fn_{fn_name}'] = getattr(self.module, fn_name)
62 | setattr(self.module, fn_name, functools.partial(fn, self.module))
63 |
64 | def monkey_super(self, fn_name, *args, **kwargs):
65 | return self.old_state[f'old_fn_{fn_name}'](*args, **kwargs)
66 |
67 | def _hook_impl(self):
68 | raise NotImplementedError
69 |
70 | def _unhook_impl(self):
71 | pass
72 |
73 |
74 | class AggregateHooker(ObjectHooker[ModuleListType]):
75 | def _hook_impl(self):
76 | for h in self.module:
77 | h.hook()
78 |
79 | def _unhook_impl(self):
80 | for h in self.module:
81 | h.unhook()
82 |
83 | def register_hook(self, hook: ObjectHooker):
84 | self.module.append(hook)
85 |
86 |
87 | class UNetCrossAttentionLocator(ModuleLocator[CrossAttention]):
88 | def locate(self, model: UNetModel, layer_idx: int) -> List[CrossAttention]:
89 | """
90 | Locate all cross-attention modules in a UNetModel.
91 |
92 | Args:
93 | model (`UNetModel`): The model to locate the cross-attention modules in.
94 |             layer_idx (`int`): If not None, only the UNet block at this index is searched; otherwise all blocks are searched.
95 | Returns:
96 | `List[CrossAttention]`: The list of cross-attention modules.
97 | """
98 | blocks = []
99 |
100 | for i, unet_block in enumerate(itertools.chain(model.input_blocks, [model.middle_block], model.output_blocks)):
101 | # if 'CrossAttn' in unet_block.__class__.__name__:
102 |             if layer_idx is None or i == layer_idx:  # 0 is a valid block index, so test against None
103 | for module in unet_block.modules():
104 | if module.__class__.__name__ == "SpatialTransformer":
105 | spatial_transformer = module
106 | for basic_transformer_block in spatial_transformer.transformer_blocks:
107 | blocks.append(basic_transformer_block.attn2)
108 |
109 |
110 | return blocks
111 |
--------------------------------------------------------------------------------
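
A minimal usage sketch for the `ObjectHooker` pattern above (illustrative only; it assumes the Web UI environment so that the `ldm` imports inside `scripts/daam/hook.py` resolve, and `LoggingHooker` is a hypothetical name):

```python
import torch
import torch.nn as nn

from scripts.daam.hook import ObjectHooker


class LoggingHooker(ObjectHooker[nn.Linear]):
    """Logs every forward() call of the wrapped module while hooked."""

    def _hook_impl(self):
        # Swap `forward` for `_forward`; the original is stashed in old_state.
        self.monkey_patch('forward', self._forward)

    def _forward(hk_self, self, x):
        out = hk_self.monkey_super('forward', x)  # call the original forward
        print('forward called, output shape:', tuple(out.shape))
        return out


layer = nn.Linear(4, 2)
with LoggingHooker(layer):      # __enter__ -> hook(), __exit__ -> unhook()
    layer(torch.zeros(1, 4))    # logged
layer(torch.zeros(1, 4))        # original behaviour restored
```
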
/scripts/daam/evaluate.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 | from collections import defaultdict
3 | from typing import List
4 |
5 | from scipy.optimize import linear_sum_assignment
6 | import PIL.Image as Image
7 | import numpy as np
8 | import torch
9 | import torch.nn.functional as F
10 |
11 |
12 | __all__ = ['compute_iou', 'MeanEvaluator']
13 |
14 |
15 | def compute_iou(a: torch.Tensor, b: torch.Tensor) -> float:
16 | if a.shape[0] != b.shape[0]:
17 | a = F.interpolate(a.unsqueeze(0).unsqueeze(0).float(), size=b.shape, mode='bicubic').squeeze()
18 | a[a < 1] = 0
19 | a[a >= 1] = 1
20 |
21 | intersection = (a * b).sum()
22 | union = a.sum() + b.sum() - intersection
23 |
24 | return (intersection / (union + 1e-8)).item()
25 |
26 |
27 | def load_mask(path: str) -> torch.Tensor:
28 | mask = np.array(Image.open(path))
29 | mask = torch.from_numpy(mask).float()[:, :, 3] # use alpha channel
30 | mask = (mask > 0).float()
31 |
32 | return mask
33 |
34 |
35 | class UnsupervisedEvaluator:
36 | def __init__(self, name: str = 'UnsupervisedEvaluator'):
37 | self.name = name
38 | self.ious = defaultdict(list)
39 | self.num_samples = 0
40 |
41 | def log_iou(self, preds: torch.Tensor | List[torch.Tensor], truth: torch.Tensor, gt_idx: int = 0, pred_idx: int = 0):
42 | if not isinstance(preds, list):
43 | preds = [preds]
44 |
45 | iou = max(compute_iou(pred, truth) for pred in preds)
46 | self.ious[gt_idx].append((pred_idx, iou))
47 |
48 | @property
49 | def mean_iou(self) -> float:
50 | n = max(max(self.ious), max([y[0] for x in self.ious.values() for y in x])) + 1
51 | iou_matrix = np.zeros((n, n))
52 | count_matrix = np.zeros((n, n))
53 |
54 | for gt_idx, ious in self.ious.items():
55 | for pred_idx, iou in ious:
56 | iou_matrix[gt_idx, pred_idx] += iou
57 | count_matrix[gt_idx, pred_idx] += 1
58 |
59 | row_ind, col_ind = linear_sum_assignment(iou_matrix, maximize=True)
60 | return iou_matrix[row_ind, col_ind].sum() / count_matrix[row_ind, col_ind].sum()
61 |
62 | def increment(self):
63 | self.num_samples += 1
64 |
65 | def __len__(self) -> int:
66 | return self.num_samples
67 |
68 | def __str__(self):
69 | return f'{self.name}<{self.mean_iou:.4f} (mIoU) {len(self)} samples>'
70 |
71 |
72 | class MeanEvaluator:
73 | def __init__(self, name: str = 'MeanEvaluator'):
74 | self.ious: List[float] = []
75 | self.intensities: List[float] = []
76 | self.name = name
77 |
78 | def log_iou(self, preds: torch.Tensor | List[torch.Tensor], truth: torch.Tensor):
79 | if not isinstance(preds, list):
80 | preds = [preds]
81 |
82 | self.ious.append(max(compute_iou(pred, truth) for pred in preds))
83 | return self
84 |
85 | def log_intensity(self, pred: torch.Tensor):
86 | self.intensities.append(pred.mean().item())
87 | return self
88 |
89 | @property
90 | def mean_iou(self) -> float:
91 | return np.mean(self.ious)
92 |
93 | @property
94 | def mean_intensity(self) -> float:
95 | return np.mean(self.intensities)
96 |
97 | @property
98 | def ci95_miou(self) -> float:
99 | return 1.96 * np.std(self.ious) / np.sqrt(len(self.ious))
100 |
101 | def __len__(self) -> int:
102 | return max(len(self.ious), len(self.intensities))
103 |
104 | def __str__(self):
105 | return f'{self.name}<{self.mean_iou:.4f} (±{self.ci95_miou:.3f} mIoU) {self.mean_intensity:.4f} (mInt) {len(self)} samples>'
106 |
107 |
108 | if __name__ == '__main__':
109 | mask = load_mask('truth/output/452/sink.gt.png')
110 |
111 | print(MeanEvaluator().log_iou(mask, mask))
112 |
--------------------------------------------------------------------------------
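
A quick numeric check of `compute_iou` above (illustrative only; importing `scripts.daam.evaluate` also runs the package `__init__`, so this assumes the Web UI environment):

```python
import torch

from scripts.daam.evaluate import compute_iou

a = torch.tensor([[1., 1.],
                  [0., 0.]])
b = torch.tensor([[1., 0.],
                  [0., 0.]])

# intersection = 1, union = 2 + 1 - 1 = 2, so the IoU is 0.5 (minus the 1e-8 smoothing term)
print(compute_iou(a, b))
```
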
/scripts/daam/utils.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 | from itertools import chain
3 | from functools import lru_cache
4 | from pathlib import Path
5 | import random
6 | import re
7 | from typing import Union
8 |
9 | from PIL import Image, ImageFont, ImageDraw
10 | # from fonts.ttf import Roboto
11 | from modules.paths_internal import roboto_ttf_file
12 | import matplotlib.pyplot as plt
13 | from matplotlib import cm
14 | import numpy as np
15 | # import spacy
16 | import torch
17 | import torch.nn.functional as F
18 | from modules.devices import dtype
19 |
20 | from ldm.modules.encoders.modules import FrozenCLIPEmbedder, FrozenOpenCLIPEmbedder
21 | import open_clip.tokenizer
22 | from modules.sd_hijack_clip import FrozenCLIPEmbedderWithCustomWordsBase, FrozenCLIPEmbedderWithCustomWords
23 | from modules.sd_hijack_open_clip import FrozenOpenCLIPEmbedderWithCustomWords
24 | from modules.shared import opts
25 | from sgm.modules import GeneralConditioner
26 |
27 | __all__ = ['expand_image', 'set_seed', 'escape_prompt', 'calc_context_size', 'compute_token_merge_indices', 'compute_token_merge_indices_with_tokenizer', 'image_overlay_heat_map', 'plot_overlay_heat_map', 'plot_mask_heat_map', 'PromptAnalyzer']
28 |
29 | def expand_image(im: torch.Tensor, h = 512, w = 512, absolute: bool = False, threshold: float = None) -> torch.Tensor:
30 |
31 | im = im.unsqueeze(0).unsqueeze(0)
32 | im = F.interpolate(im.float().detach(), size=(h, w), mode='bicubic')
33 |
34 | if not absolute:
35 | im = (im - im.min()) / (im.max() - im.min() + 1e-8)
36 |
37 | if threshold:
38 | im = (im > threshold).float()
39 |
40 | # im = im.cpu().detach()
41 |
42 | return im.squeeze()
43 |
44 | def _write_on_image(img, caption, font_size = 32):
45 | ix,iy = img.size
46 | draw = ImageDraw.Draw(img)
47 | margin=2
48 | fontsize=font_size
49 | draw = ImageDraw.Draw(img)
50 | font = ImageFont.truetype(roboto_ttf_file, fontsize)
51 | text_height=iy-60
52 | tx = draw.textbbox((0,0),caption,font)
53 | draw.text((int((ix-tx[2])/2),text_height+margin),caption,(0,0,0),font=font)
54 | draw.text((int((ix-tx[2])/2),text_height-margin),caption,(0,0,0),font=font)
55 | draw.text((int((ix-tx[2])/2+margin),text_height),caption,(0,0,0),font=font)
56 | draw.text((int((ix-tx[2])/2-margin),text_height),caption,(0,0,0),font=font)
57 | draw.text((int((ix-tx[2])/2),text_height), caption,(255,255,255),font=font)
58 | return img
59 |
60 | def image_overlay_heat_map(img, heat_map, word=None, out_file=None, crop=None, alpha=0.5, caption=None, image_scale=1.0):
61 | # type: (Image.Image | np.ndarray, torch.Tensor, str, Path, int, float, str, float) -> Image.Image
62 | assert(img is not None)
63 |
64 | if heat_map is not None:
65 | shape : torch.Size = heat_map.shape
66 | # heat_map = heat_map.unsqueeze(-1).expand(shape[0], shape[1], 3).clone()
67 | heat_map = _convert_heat_map_colors(heat_map)
68 | heat_map = heat_map.to('cpu').detach().numpy().copy().astype(np.uint8)
69 | heat_map_img = Image.fromarray(heat_map)
70 |
71 | img = Image.blend(img, heat_map_img, alpha)
72 | else:
73 | img = img.copy()
74 |
75 | if caption:
76 | img = _write_on_image(img, caption)
77 |
78 | if image_scale != 1.0:
79 | x, y = img.size
80 | size = (int(x * image_scale), int(y * image_scale))
81 | img = img.resize(size, Image.BICUBIC)
82 |
83 | return img
84 |
85 |
86 | def _convert_heat_map_colors(heat_map : torch.Tensor):
87 | def get_color(value):
88 | return np.array(cm.turbo(value / 255)[0:3])
89 |
90 | color_map = np.array([ get_color(i) * 255 for i in range(256) ])
91 | color_map = torch.tensor(color_map, device=heat_map.device, dtype=dtype)
92 |
93 | heat_map = (heat_map * 255).long()
94 |
95 | return color_map[heat_map]
96 |
97 | def plot_overlay_heat_map(im, heat_map, word=None, out_file=None, crop=None):
98 | # type: (Image.Image | np.ndarray, torch.Tensor, str, Path, int) -> None
99 | plt.clf()
100 | plt.rcParams.update({'font.size': 24})
101 |
102 | im = np.array(im)
103 | if crop is not None:
104 | heat_map = heat_map.squeeze()[crop:-crop, crop:-crop]
105 | im = im[crop:-crop, crop:-crop]
106 |
107 | plt.imshow(heat_map.squeeze().cpu().numpy(), cmap='jet')
108 | im = torch.from_numpy(im).float() / 255
109 | im = torch.cat((im, (1 - heat_map.unsqueeze(-1))), dim=-1)
110 | plt.imshow(im)
111 |
112 | if word is not None:
113 | plt.title(word)
114 |
115 | if out_file is not None:
116 | plt.savefig(out_file)
117 |
118 |
119 | def plot_mask_heat_map(im: Image.Image, heat_map: torch.Tensor, threshold: float = 0.4):
120 | im = torch.from_numpy(np.array(im)).float() / 255
121 | mask = (heat_map.squeeze() > threshold).float()
122 | im = im * mask.unsqueeze(-1)
123 | plt.imshow(im)
124 |
125 |
126 | def set_seed(seed: int) -> torch.Generator:
127 | random.seed(seed)
128 | np.random.seed(seed)
129 | torch.manual_seed(seed)
130 | torch.cuda.manual_seed_all(seed)
131 |
132 | gen = torch.Generator(device='cuda')
133 | gen.manual_seed(seed)
134 |
135 | return gen
136 |
137 | def calc_context_size(token_length : int):
138 | len_check = 0 if (token_length - 1) < 0 else token_length - 1
139 |     return ((int)(len_check // 75) + 1) * 77  # one 77-slot chunk per 75 prompt tokens (start/end tokens added per chunk); e.g. 75 tokens -> 77, 76 -> 154
140 |
141 | def escape_prompt(prompt):
142 | if type(prompt) is str:
143 | prompt = prompt.lower()
144 | prompt = re.sub(r"[\(\)\[\]]", "", prompt)
145 | prompt = re.sub(r":\d+\.*\d*", "", prompt)
146 | return prompt
147 | elif type(prompt) is list:
148 | prompt_new = []
149 | for i in range(len(prompt)):
150 | prompt_new.append(escape_prompt(prompt[i]))
151 | return prompt_new
152 |
153 |
154 | def compute_token_merge_indices(model, prompt: str, word: str, word_idx: int = None):
155 |
156 | clip = None
157 | tokenize = None
158 | if type(model.cond_stage_model.wrapped) == FrozenCLIPEmbedder:
159 | clip : FrozenCLIPEmbedder = model.cond_stage_model.wrapped
160 | tokenize = clip.tokenizer.tokenize
161 | elif type(model.cond_stage_model.wrapped) == FrozenOpenCLIPEmbedder:
162 | clip : FrozenOpenCLIPEmbedder = model.cond_stage_model.wrapped
163 | tokenize = open_clip.tokenizer._tokenizer.encode
164 | else:
165 | assert False
166 |
167 | escaped_prompt = escape_prompt(prompt)
168 | # escaped_prompt = re.sub(r"[_-]", " ", escaped_prompt)
169 | tokens : list = tokenize(escaped_prompt)
170 | word = word.lower()
171 | merge_idxs = []
172 |
173 | needles = tokenize(word)
174 |
175 | if len(needles) == 0:
176 | return []
177 |
178 | for i, token in enumerate(tokens):
179 | if needles[0] == token and len(needles) > 1:
180 | next = i + 1
181 | success = True
182 | for needle in needles[1:]:
183 | if next >= len(tokens) or needle != tokens[next]:
184 | success = False
185 | break
186 | next += 1
187 |
188 | # append consecutive indexes if all pass
189 | if success:
190 | merge_idxs.extend(list(range(i, next)))
191 |
192 | elif needles[0] == token:
193 | merge_idxs.append(i)
194 |
195 | idxs = []
196 | for x in merge_idxs:
197 | seq = (int)(x / 75)
198 | if seq == 0:
199 | idxs.append(x + 1) # padding
200 | else:
201 | idxs.append(x + 1 + seq*2) # If tokens exceed 75, they are split.
202 |
203 | return idxs
204 |
205 | def compute_token_merge_indices_with_tokenizer(tokenizer, prompt: str, word: str, word_idx: int = None, limit : int = -1):
206 |
207 | escaped_prompt = escape_prompt(prompt)
208 | # escaped_prompt = re.sub(r"[_-]", " ", escaped_prompt)
209 | tokens : list = tokenizer.tokenize(escaped_prompt)
210 | word = word.lower()
211 | merge_idxs = []
212 |
213 | needles = tokenizer.tokenize(word)
214 |
215 | if len(needles) == 0:
216 | return []
217 |
218 | limit_count = 0
219 | for i, token in enumerate(tokens):
220 | if needles[0] == token and len(needles) > 1:
221 | next = i + 1
222 | success = True
223 | for needle in needles[1:]:
224 | if next >= len(tokens) or needle != tokens[next]:
225 | success = False
226 | break
227 | next += 1
228 |
229 | # append consecutive indexes if all pass
230 | if success:
231 | merge_idxs.extend(list(range(i, next)))
232 | if limit > 0:
233 | limit_count += 1
234 | if limit_count >= limit:
235 | break
236 |
237 | elif needles[0] == token:
238 | merge_idxs.append(i)
239 | if limit > 0:
240 | limit_count += 1
241 | if limit_count >= limit:
242 | break
243 |
244 | idxs = []
245 | for x in merge_idxs:
246 | seq = (int)(x / 75)
247 | if seq == 0:
248 | idxs.append(x + 1) # padding
249 | else:
250 | idxs.append(x + 1 + seq*2) # If tokens exceed 75, they are split.
251 |
252 | return idxs
253 |
254 | nlp = None
255 |
256 |
257 | @lru_cache(maxsize=100000)
258 | def cached_nlp(prompt: str, type='en_core_web_md'):
259 | global nlp
260 |
261 | # if nlp is None:
262 | # nlp = spacy.load(type)
263 |
264 | return nlp(prompt)
265 |
266 | class PromptAnalyzer:
267 | def __init__(self, clip : Union[FrozenCLIPEmbedderWithCustomWordsBase, GeneralConditioner], text : str):
268 | use_old = opts.use_old_emphasis_implementation
269 | assert not use_old, "use_old_emphasis_implementation is not supported"
270 |
271 | self.clip = clip
272 | # self.id_start = clip.id_start
273 | # self.id_end = clip.id_end
274 | self.is_open_clip = True if type(clip) == FrozenOpenCLIPEmbedderWithCustomWords else False
275 | self.is_sdxl = True if type(clip) == GeneralConditioner else False
276 | self.used_custom_terms = []
277 | self.hijack_comments = []
278 |
279 | chunks, token_count = self.tokenize_line(text)
280 |
281 | self.token_count = token_count
282 | self.fixes = list(chain.from_iterable(chunk.fixes for chunk in chunks))
283 | self.context_size = calc_context_size(token_count)
284 |
285 | tokens = list(chain.from_iterable(chunk.tokens for chunk in chunks))
286 | multipliers = list(chain.from_iterable(chunk.multipliers for chunk in chunks))
287 | print(tokens, multipliers)
288 | print(len(tokens), len(multipliers))
289 |
290 | self.tokens = tokens # []
291 | self.multipliers = multipliers # []
292 | # for i in range(self.context_size // 77):
293 | # self.tokens.extend([self.id_start] + tokens[i*75:i*75+75] + [self.id_end])
294 | # self.multipliers.extend([1.0] + multipliers[i*75:i*75+75]+ [1.0])
295 |
296 | def create(self, text : str):
297 | return PromptAnalyzer(self.clip, text)
298 |
299 | def tokenize_line(self, line):
300 | chunks, token_count = self.clip.tokenize_line(line)
301 | return chunks, token_count
302 |
303 | def process_text(self, texts):
304 | batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = self.clip.process_text(texts)
305 | print(batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count)
306 | return batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count
307 |
308 | def encode(self, text : str):
309 | return self.clip.tokenize([text])[0]
310 |
311 | def calc_word_indecies(self, word : str, limit : int = -1, start_pos = 0):
312 | word = word.lower()
313 | merge_idxs = []
314 |
315 | tokens = self.tokens
316 | needles = self.encode(word)
317 |
318 | limit_count = 0
319 | current_pos = 0
320 | for i, token in enumerate(tokens):
321 | current_pos = i
322 | if i < start_pos:
323 | continue
324 |
325 | if needles[0] == token and len(needles) > 1:
326 | next = i + 1
327 | success = True
328 | for needle in needles[1:]:
329 | if next >= len(tokens) or needle != tokens[next]:
330 | success = False
331 | break
332 | next += 1
333 |
334 | # append consecutive indexes if all pass
335 | if success:
336 | merge_idxs.extend(list(range(i, next)))
337 | if limit > 0:
338 | limit_count += 1
339 | if limit_count >= limit:
340 | break
341 |
342 | elif needles[0] == token:
343 | merge_idxs.append(i)
344 | if limit > 0:
345 | limit_count += 1
346 | if limit_count >= limit:
347 | break
348 |
349 | return merge_idxs, current_pos
--------------------------------------------------------------------------------
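
A small worked example of `compute_token_merge_indices_with_tokenizer` above (illustrative only; assumes the Web UI environment so the `modules`/`ldm` imports in `utils.py` resolve; the whitespace "tokenizer" merely stands in for a real CLIP tokenizer):

```python
from types import SimpleNamespace

from scripts.daam.utils import compute_token_merge_indices_with_tokenizer

# Whitespace split stands in for the CLIP tokenizer, purely for illustration.
fake_tokenizer = SimpleNamespace(tokenize=str.split)

# "cute cat" must appear as consecutive tokens; the returned indices are shifted
# by one for the start-of-text token at the head of each 77-slot chunk.
print(compute_token_merge_indices_with_tokenizer(fake_tokenizer, "a cute cat on a beach", "cute cat"))
# -> [2, 3]
```
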
/scripts/daam/experiment.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 | from pathlib import Path
3 | from typing import List, Optional, Dict, Any
4 | from dataclasses import dataclass
5 | import json
6 |
7 | from transformers import PreTrainedTokenizer, AutoTokenizer
8 | import PIL.Image
9 | import numpy as np
10 | import torch
11 |
12 | from .evaluate import load_mask
13 | from .utils import plot_overlay_heat_map, expand_image
14 |
15 | __all__ = ['GenerationExperiment', 'COCO80_LABELS', 'COCOSTUFF27_LABELS', 'COCO80_INDICES', 'build_word_list_coco80']
16 |
17 |
18 | COCO80_LABELS: List[str] = [
19 | 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck', 'boat', 'traffic light',
20 | 'fire hydrant', 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow',
21 | 'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee',
22 | 'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard',
23 | 'tennis racket', 'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple',
24 | 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch',
25 | 'potted plant', 'bed', 'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone',
26 | 'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase', 'scissors', 'teddy bear',
27 | 'hair drier', 'toothbrush'
28 | ]
29 |
30 | COCO80_INDICES: Dict[str, int] = {x: i for i, x in enumerate(COCO80_LABELS)}
31 |
32 | UNUSED_LABELS: List[str] = [f'__unused_{i}__' for i in range(1, 200)]
33 |
34 | COCOSTUFF27_LABELS: List[str] = [
35 | 'electronic', 'appliance', 'food', 'furniture', 'indoor', 'kitchen', 'accessory', 'animal', 'outdoor', 'person',
36 | 'sports', 'vehicle', 'ceiling', 'floor', 'food', 'furniture', 'rawmaterial', 'textile', 'wall', 'window',
37 | 'building', 'ground', 'plant', 'sky', 'solid', 'structural', 'water'
38 | ]
39 |
40 | COCO80_ONTOLOGY = {
41 | 'two-wheeled vehicle': ['bicycle', 'motorcycle'],
42 | 'vehicle': ['two-wheeled vehicle', 'four-wheeled vehicle'],
43 | 'four-wheeled vehicle': ['bus', 'truck', 'car'],
44 | 'four-legged animals': ['livestock', 'pets', 'wild animals'],
45 | 'livestock': ['cow', 'horse', 'sheep'],
46 | 'pets': ['cat', 'dog'],
47 | 'wild animals': ['elephant', 'bear', 'zebra', 'giraffe'],
48 | 'bags': ['backpack', 'handbag', 'suitcase'],
49 | 'sports boards': ['snowboard', 'surfboard', 'skateboard'],
50 | 'utensils': ['fork', 'knife', 'spoon'],
51 | 'receptacles': ['bowl', 'cup'],
52 | 'fruits': ['banana', 'apple', 'orange'],
53 | 'foods': ['fruits', 'meals', 'desserts'],
54 | 'meals': ['sandwich', 'hot dog', 'pizza'],
55 | 'desserts': ['cake', 'donut'],
56 | 'furniture': ['chair', 'couch', 'bench'],
57 | 'electronics': ['monitors', 'appliances'],
58 | 'monitors': ['tv', 'cell phone', 'laptop'],
59 | 'appliances': ['oven', 'toaster', 'refrigerator']
60 | }
61 |
62 | COCO80_TO_27 = {
63 | 'bicycle': 'vehicle', 'car': 'vehicle', 'motorcycle': 'vehicle', 'airplane': 'vehicle', 'bus': 'vehicle',
64 | 'train': 'vehicle', 'truck': 'vehicle', 'boat': 'vehicle', 'traffic light': 'accessory', 'fire hydrant': 'accessory',
65 | 'stop sign': 'accessory', 'parking meter': 'accessory', 'bench': 'furniture', 'bird': 'animal', 'cat': 'animal',
66 | 'dog': 'animal', 'horse': 'animal', 'sheep': 'animal', 'cow': 'animal', 'elephant': 'animal', 'bear': 'animal',
67 | 'zebra': 'animal', 'giraffe': 'animal', 'backpack': 'accessory', 'umbrella': 'accessory', 'handbag': 'accessory',
68 | 'tie': 'accessory', 'suitcase': 'accessory', 'frisbee': 'sports', 'skis': 'sports', 'snowboard': 'sports',
69 | 'sports ball': 'sports', 'kite': 'sports', 'baseball bat': 'sports', 'baseball glove': 'sports',
70 | 'skateboard': 'sports', 'surfboard': 'sports', 'tennis racket': 'sports', 'bottle': 'food', 'wine glass': 'food',
71 | 'cup': 'food', 'fork': 'food', 'knife': 'food', 'spoon': 'food', 'bowl': 'food', 'banana': 'food', 'apple': 'food',
72 | 'sandwich': 'food', 'orange': 'food', 'broccoli': 'food', 'carrot': 'food', 'hot dog': 'food', 'pizza': 'food',
73 | 'donut': 'food', 'cake': 'food', 'chair': 'furniture', 'couch': 'furniture', 'potted plant': 'plant',
74 | 'bed': 'furniture', 'dining table': 'furniture', 'toilet': 'furniture', 'tv': 'electronic', 'laptop': 'electronic',
75 | 'mouse': 'electronic', 'remote': 'electronic', 'keyboard': 'electronic', 'cell phone': 'electronic',
76 | 'microwave': 'appliance', 'oven': 'appliance', 'toaster': 'appliance', 'sink': 'appliance',
77 | 'refrigerator': 'appliance', 'book': 'indoor', 'clock': 'indoor', 'vase': 'indoor', 'scissors': 'indoor',
78 | 'teddy bear': 'indoor', 'hair drier': 'indoor', 'toothbrush': 'indoor'
79 | }
80 |
81 |
82 | def build_word_list_coco80() -> Dict[str, List[str]]:
83 | words_map = COCO80_ONTOLOGY.copy()
84 | words_map = {k: v for k, v in words_map.items() if not any(item in COCO80_ONTOLOGY for item in v)}
85 |
86 | return words_map
87 |
88 |
89 | def _add_mask(masks: Dict[str, torch.Tensor], word: str, mask: torch.Tensor, simplify80: bool = False) -> Dict[str, torch.Tensor]:
90 | if simplify80:
91 | word = COCO80_TO_27.get(word, word)
92 |
93 | if word in masks:
94 |         masks[word] = masks[word] + mask
95 | masks[word].clamp_(0, 1)
96 | else:
97 | masks[word] = mask
98 |
99 | return masks
100 |
101 |
102 | @dataclass
103 | class GenerationExperiment:
104 | """Class to hold experiment parameters. Pickleable."""
105 | id: str
106 | image: PIL.Image.Image
107 | global_heat_map: torch.Tensor
108 | seed: int
109 | prompt: str
110 |
111 | path: Optional[Path] = None
112 | truth_masks: Optional[Dict[str, torch.Tensor]] = None
113 | prediction_masks: Optional[Dict[str, torch.Tensor]] = None
114 | annotations: Optional[Dict[str, Any]] = None
115 | subtype: Optional[str] = '.'
116 |
117 | def nsfw(self) -> bool:
118 | return np.sum(np.array(self.image)) == 0
119 |
120 | def heat_map(self, tokenizer: AutoTokenizer):
121 | from .trace import HeatMap
122 | return HeatMap(tokenizer, self.prompt, self.global_heat_map)
123 |
124 | def save(self, path: str = None):
125 | if path is None:
126 | path = self.path
127 | else:
128 | path = Path(path) / self.id
129 |
130 | (path / self.subtype).mkdir(parents=True, exist_ok=True)
131 | torch.save(self, path / self.subtype / 'generation.pt')
132 | self.image.save(path / self.subtype / 'output.png')
133 |
134 | with (path / 'prompt.txt').open('w') as f:
135 | f.write(self.prompt)
136 |
137 | with (path / 'seed.txt').open('w') as f:
138 | f.write(str(self.seed))
139 |
140 | if self.truth_masks is not None:
141 | for name, mask in self.truth_masks.items():
142 | im = PIL.Image.fromarray((mask * 255).unsqueeze(-1).expand(-1, -1, 4).byte().numpy())
143 | im.save(path / f'{name.lower()}.gt.png')
144 |
145 | self.save_annotations()
146 |
147 | def save_annotations(self, path: Path = None):
148 | if path is None:
149 | path = self.path
150 |
151 | if self.annotations is not None:
152 | with (path / 'annotations.json').open('w') as f:
153 | json.dump(self.annotations, f)
154 |
155 | def _load_truth_masks(self, simplify80: bool = False) -> Dict[str, torch.Tensor]:
156 | masks = {}
157 |
158 | for mask_path in self.path.glob('*.gt.png'):
159 | word = mask_path.name.split('.gt.png')[0].lower()
160 | mask = load_mask(str(mask_path))
161 | _add_mask(masks, word, mask, simplify80)
162 |
163 | return masks
164 |
165 | def _load_pred_masks(self, pred_prefix, composite=False, simplify80=False, vocab=None):
166 | # type: (str, bool, bool, List[str] | None) -> Dict[str, torch.Tensor]
167 | masks = {}
168 |
169 | if vocab is None:
170 | vocab = UNUSED_LABELS
171 |
172 | if composite:
173 | try:
174 | im = PIL.Image.open(self.path / self.subtype / f'composite.{pred_prefix}.pred.png')
175 | im = np.array(im)
176 |
177 | for mask_idx in np.unique(im):
178 | mask = torch.from_numpy((im == mask_idx).astype(np.float32))
179 | _add_mask(masks, vocab[mask_idx], mask, simplify80)
180 | except FileNotFoundError:
181 | pass
182 | else:
183 | for mask_path in (self.path / self.subtype).glob(f'*.{pred_prefix}.pred.png'):
184 | mask = load_mask(str(mask_path))
185 | word = mask_path.name.split(f'.{pred_prefix}.pred')[0].lower()
186 | _add_mask(masks, word, mask, simplify80)
187 |
188 | return masks
189 |
190 | def clear_prediction_masks(self, name: str):
191 | path = self if isinstance(self, Path) else self.path
192 | path = path / self.subtype
193 |
194 | for mask_path in path.glob(f'*.{name}.pred.png'):
195 | mask_path.unlink()
196 |
197 | def save_prediction_mask(self, mask: torch.Tensor, word: str, name: str):
198 | path = self if isinstance(self, Path) else self.path
199 | im = PIL.Image.fromarray((mask * 255).unsqueeze(-1).expand(-1, -1, 4).cpu().byte().numpy())
200 | im.save(path / self.subtype / f'{word.lower()}.{name}.pred.png')
201 |
202 | def save_heat_map(self, tokenizer: PreTrainedTokenizer, word: str, crop: int = None) -> Path:
203 | from .trace import HeatMap # because of cyclical import
204 | heat_map = HeatMap(tokenizer, self.prompt, self.global_heat_map)
205 | heat_map = expand_image(heat_map.compute_word_heat_map(word))
206 | path = self.path / self.subtype / f'{word.lower()}.heat_map.png'
207 | plot_overlay_heat_map(self.image, heat_map, word, path, crop=crop)
208 |
209 | return path
210 |
211 | def save_all_heat_maps(self, tokenizer: PreTrainedTokenizer, crop: int = None) -> Dict[str, Path]:
212 | path_map = {}
213 |
214 | for word in self.prompt.split(' '):
215 | try:
216 | path = self.save_heat_map(tokenizer, word, crop=crop)
217 | path_map[word] = path
218 | except:
219 | pass
220 |
221 | return path_map
222 |
223 | @staticmethod
224 | def contains_truth_mask(path: str | Path, prompt_id: str = None) -> bool:
225 | if prompt_id is None:
226 | return any(Path(path).glob('*.gt.png'))
227 | else:
228 | return any((Path(path) / prompt_id).glob('*.gt.png'))
229 |
230 | @staticmethod
231 | def read_seed(path: str | Path, prompt_id: str = None) -> int:
232 | if prompt_id is None:
233 | return int(Path(path).joinpath('seed.txt').read_text())
234 | else:
235 | return int(Path(path).joinpath(prompt_id).joinpath('seed.txt').read_text())
236 |
237 | @staticmethod
238 | def has_annotations(path: str | Path) -> bool:
239 | return Path(path).joinpath('annotations.json').exists()
240 |
241 | @staticmethod
242 | def has_experiment(path: str | Path, prompt_id: str) -> bool:
243 | return (Path(path) / prompt_id / 'generation.pt').exists()
244 |
245 | @staticmethod
246 | def read_prompt(path: str | Path, prompt_id: str = None) -> str:
247 | if prompt_id is None:
248 | prompt_id = '.'
249 |
250 | with (Path(path) / prompt_id / 'prompt.txt').open('r') as f:
251 | return f.read().strip()
252 |
253 | def _try_load_annotations(self):
254 | if not (self.path / 'annotations.json').exists():
255 | return None
256 |
257 | return json.load((self.path / 'annotations.json').open())
258 |
259 | def annotate(self, key: str, value: Any) -> 'GenerationExperiment':
260 | if self.annotations is None:
261 | self.annotations = {}
262 |
263 | self.annotations[key] = value
264 |
265 | return self
266 |
267 | @classmethod
268 | def load(
269 | cls,
270 | path,
271 | pred_prefix='daam',
272 | composite=False,
273 | simplify80=False,
274 | vocab=None,
275 | subtype='.',
276 | all_subtypes=False
277 | ):
278 | # type: (str, str, bool, bool, List[str] | None, str, bool) -> GenerationExperiment | List[GenerationExperiment]
279 | if all_subtypes:
280 | experiments = []
281 |
282 | for directory in Path(path).iterdir():
283 | if not directory.is_dir():
284 | continue
285 |
286 | try:
287 | experiments.append(cls.load(
288 | path,
289 | pred_prefix=pred_prefix,
290 | composite=composite,
291 | simplify80=simplify80,
292 | vocab=vocab,
293 | subtype=directory.name
294 | ))
295 | except:
296 | pass
297 |
298 | return experiments
299 |
300 | path = Path(path)
301 | exp = torch.load(path / subtype / 'generation.pt')
302 | exp.subtype = subtype
303 | exp.path = path
304 | exp.truth_masks = exp._load_truth_masks(simplify80=simplify80)
305 | exp.prediction_masks = exp._load_pred_masks(pred_prefix, composite=composite, simplify80=simplify80, vocab=vocab)
306 | exp.annotations = exp._try_load_annotations()
307 |
308 | return exp
309 |
--------------------------------------------------------------------------------
/scripts/daam_script.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 | from collections import defaultdict
3 | import os
4 | import re
5 | import traceback
6 |
7 | import gradio as gr
8 | import modules.images as images
9 | import modules.scripts as scripts
10 | import torch
11 | from ldm.modules.encoders.modules import FrozenCLIPEmbedder, FrozenOpenCLIPEmbedder
12 | import sgm
13 | from sgm.modules import GeneralConditioner
14 | import open_clip.tokenizer
15 | from modules import script_callbacks
16 | from modules import script_callbacks, sd_hijack_clip, sd_hijack_open_clip
17 | from modules.processing import (Processed, StableDiffusionProcessing, fix_seed,
18 | process_images)
19 | from modules.shared import cmd_opts, opts, state
20 | import modules.shared as shared
21 | from PIL import Image
22 |
23 | from scripts.daam import trace, utils
24 |
25 | before_image_saved_handler = None
26 |
27 | class Script(scripts.Script):
28 |
29 | GRID_LAYOUT_AUTO = "Auto"
30 | GRID_LAYOUT_PREVENT_EMPTY = "Prevent Empty Spot"
31 | GRID_LAYOUT_BATCH_LENGTH_AS_ROW = "Batch Length As Row"
32 |
33 |
34 | def title(self):
35 | return "Daam script"
36 |
37 | def show(self, is_img2img):
38 | return scripts.AlwaysVisible
39 |
40 | def ui(self, is_img2img):
41 | with gr.Group():
42 | with gr.Accordion("Attention Heatmap", open=False):
43 | attention_texts = gr.Text(label='Attention texts for visualization. (comma separated)', value='')
44 |
45 | with gr.Row():
46 | hide_images = gr.Checkbox(label='Hide heatmap images', value=False)
47 |
48 | dont_save_images = gr.Checkbox(label='Do not save heatmap images', value=False)
49 |
50 | hide_caption = gr.Checkbox(label='Hide caption', value=False)
51 |
52 | with gr.Row():
53 | use_grid = gr.Checkbox(label='Use grid (output to grid dir)', value=False)
54 |
55 | grid_layouyt = gr.Dropdown(
56 | [Script.GRID_LAYOUT_AUTO, Script.GRID_LAYOUT_PREVENT_EMPTY, Script.GRID_LAYOUT_BATCH_LENGTH_AS_ROW], label="Grid layout",
57 | value=Script.GRID_LAYOUT_AUTO
58 | )
59 |
60 | with gr.Row():
61 | alpha = gr.Slider(label='Heatmap blend alpha', value=0.5, minimum=0, maximum=1, step=0.01)
62 |
63 | heatmap_image_scale = gr.Slider(label='Heatmap image scale', value=1.0, minimum=0.1, maximum=1, step=0.025)
64 |
65 | with gr.Row():
66 |                     trace_each_layers = gr.Checkbox(label='Trace each layer', value=False)
67 |
68 | layers_as_row = gr.Checkbox(label = 'Use layers as row instead of Batch Length', value=False)
69 |
70 |
71 | self.tracers = None
72 |
73 | return [attention_texts, hide_images, dont_save_images, hide_caption, use_grid, grid_layouyt, alpha, heatmap_image_scale, trace_each_layers, layers_as_row]
74 |
75 | def process(self,
76 | p : StableDiffusionProcessing,
77 | attention_texts : str,
78 | hide_images : bool,
79 | dont_save_images : bool,
80 | hide_caption : bool,
81 | use_grid : bool,
82 | grid_layouyt :str,
83 | alpha : float,
84 | heatmap_image_scale : float,
85 | trace_each_layers : bool,
86 | layers_as_row: bool):
87 |
88 | self.enabled = False # in case the assert fails
89 | assert opts.samples_save, "Cannot run Daam script. Enable 'Always save all generated images' setting."
90 |
91 | self.images = []
92 | self.hide_images = hide_images
93 | self.dont_save_images = dont_save_images
94 | self.hide_caption = hide_caption
95 | self.alpha = alpha
96 | self.use_grid = use_grid
97 | self.grid_layouyt = grid_layouyt
98 | self.heatmap_image_scale = heatmap_image_scale
99 | self.heatmap_images = dict()
100 |
101 | self.attentions = [s.strip() for s in attention_texts.split(",") if s.strip()]
102 | self.enabled = len(self.attentions) > 0
103 |
104 | fix_seed(p)
105 |
106 | def process_batch(self,
107 | p : StableDiffusionProcessing,
108 | attention_texts : str,
109 | hide_images : bool,
110 | dont_save_images : bool,
111 | hide_caption : bool,
112 | use_grid : bool,
113 | grid_layouyt :str,
114 | alpha : float,
115 | heatmap_image_scale : float,
116 | trace_each_layers : bool,
117 | layers_as_row: bool,
118 | prompts,
119 | **kwargs):
120 |
121 | if not self.enabled:
122 | return
123 |
124 | styled_prompt = prompts[0]
125 |
126 | embedder = None
127 | if type(p.sd_model.cond_stage_model) == sd_hijack_clip.FrozenCLIPEmbedderWithCustomWords or \
128 | type(p.sd_model.cond_stage_model) == sd_hijack_open_clip.FrozenOpenCLIPEmbedderWithCustomWords:
129 | embedder = p.sd_model.cond_stage_model
130 | elif type(p.sd_model.cond_stage_model) == GeneralConditioner:
131 | conditioner = p.sd_model.cond_stage_model
132 | print(conditioner.embedders)
133 | embedder = conditioner.embedders[0]
134 | else:
135 | assert False, f"Embedder '{type(p.sd_model.cond_stage_model)}' is not supported."
136 |
137 | clip = None
138 | tokenize = None
139 | clip_type = type(embedder.wrapped)
140 | if clip_type == FrozenCLIPEmbedder:
141 | clip : FrozenCLIPEmbedder = embedder.wrapped
142 | tokenize = clip.tokenizer.tokenize
143 | elif clip_type == FrozenOpenCLIPEmbedder:
144 | clip : FrozenOpenCLIPEmbedder = embedder.wrapped
145 | tokenize = open_clip.tokenizer._tokenizer.encode
146 | elif clip_type == sgm.modules.encoders.modules.FrozenCLIPEmbedder:
147 | clip : sgm.modules.encoders.modules.FrozenCLIPEmbedder = embedder.wrapped
148 | tokenize = clip.tokenizer.tokenize
149 | else:
150 | assert False, f"CLIP '{clip_type}' is not supported."
151 |
152 | tokens = tokenize(utils.escape_prompt(styled_prompt))
153 | context_size = utils.calc_context_size(len(tokens))
154 |
155 | prompt_analyzer = utils.PromptAnalyzer(embedder, styled_prompt)
156 | self.prompt_analyzer = prompt_analyzer
157 | context_size = prompt_analyzer.context_size
158 |
159 | print(f"daam run with context_size={prompt_analyzer.context_size}, token_count={prompt_analyzer.token_count}")
160 | print(f"remade_tokens={prompt_analyzer.tokens}, multipliers={prompt_analyzer.multipliers}")
161 | print(f"hijack_comments={prompt_analyzer.hijack_comments}, used_custom_terms={prompt_analyzer.used_custom_terms}")
162 | print(f"fixes={prompt_analyzer.fixes}")
163 |
164 | if any(item[0] in self.attentions for item in self.prompt_analyzer.used_custom_terms):
165 | print("Embedding heatmap cannot be shown.")
166 |
167 | global before_image_saved_handler
168 | before_image_saved_handler = lambda params : self.before_image_saved(params)
169 |
170 | with torch.no_grad():
171 | # cannot trace the same block from two tracers
172 | if trace_each_layers:
173 | num_input = len(p.sd_model.model.diffusion_model.input_blocks)
174 | num_output = len(p.sd_model.model.diffusion_model.output_blocks)
175 | self.tracers = [trace(p.sd_model, p.height, p.width, context_size, layer_idx=i) for i in range(num_input + num_output + 1)]
176 | self.attn_captions = [f"IN{i:02d}" for i in range(num_input)] + ["MID"] + [f"OUT{i:02d}" for i in range(num_output)]
177 | else:
178 | self.tracers = [trace(p.sd_model, p.height, p.width, context_size)]
179 | self.attn_captions = [""]
180 |
181 | for tracer in self.tracers:
182 | tracer.hook()
183 |
184 | def postprocess(self, p, processed,
185 | attention_texts : str,
186 | hide_images : bool,
187 | dont_save_images : bool,
188 | hide_caption : bool,
189 | use_grid : bool,
190 | grid_layouyt :str,
191 | alpha : float,
192 | heatmap_image_scale : float,
193 | trace_each_layers : bool,
194 | layers_as_row: bool,
195 | **kwargs):
196 |         if not self.enabled:
197 | return
198 |
199 |         for tracer in self.tracers:
200 |             tracer.unhook()
201 | self.tracers = None
202 |
203 | initial_info = None
204 |
205 | if initial_info is None:
206 | initial_info = processed.info
207 |
208 | self.images += processed.images
209 |
210 | global before_image_saved_handler
211 | before_image_saved_handler = None
212 |
213 | if layers_as_row:
214 | images_list = []
215 | for i in range(p.batch_size * p.n_iter):
216 | imgs = []
217 | for k in sorted(self.heatmap_images.keys()):
218 | imgs += [self.heatmap_images[k][len(self.attentions)*i + j] for j in range(len(self.attentions))]
219 | images_list.append(imgs)
220 | else:
221 | images_list = [self.heatmap_images[k] for k in sorted(self.heatmap_images.keys())]
222 |
223 | for img_list in images_list:
224 |
225 | if img_list and self.use_grid:
226 |
227 | grid_layout = self.grid_layouyt
228 | if grid_layout == Script.GRID_LAYOUT_AUTO:
229 | if p.batch_size * p.n_iter == 1:
230 | grid_layout = Script.GRID_LAYOUT_PREVENT_EMPTY
231 | else:
232 | grid_layout = Script.GRID_LAYOUT_BATCH_LENGTH_AS_ROW
233 |
234 | if grid_layout == Script.GRID_LAYOUT_PREVENT_EMPTY:
235 | grid_img = images.image_grid(img_list)
236 | elif grid_layout == Script.GRID_LAYOUT_BATCH_LENGTH_AS_ROW:
237 | if layers_as_row:
238 | batch_size = len(self.attentions)
239 | rows = len(self.heatmap_images)
240 | else:
241 | batch_size = p.batch_size
242 | rows = p.batch_size * p.n_iter
243 | grid_img = images.image_grid(img_list, batch_size=batch_size, rows=rows)
244 | else:
245 | pass
246 |
247 | if not self.dont_save_images:
248 | images.save_image(grid_img, p.outpath_grids, "grid_daam", grid=True, p=p)
249 |
250 | if not self.hide_images:
251 | processed.images.insert(0, grid_img)
252 | processed.index_of_first_image += 1
253 | processed.infotexts.insert(0, processed.infotexts[0])
254 |
255 | else:
256 | if not self.hide_images:
257 | processed.images[:0] = img_list
258 | processed.index_of_first_image += len(img_list)
259 | processed.infotexts[:0] = [processed.infotexts[0]] * len(img_list)
260 |
261 | return processed
262 |
263 | def before_image_saved(self, params : script_callbacks.ImageSaveParams):
264 | batch_pos = -1
265 | if params.p.batch_size > 1:
266 | match = re.search(r"Batch pos: (\d+)", params.pnginfo['parameters'])
267 | if match:
268 | batch_pos = int(match.group(1))
269 | else:
270 | batch_pos = 0
271 |
272 | if batch_pos < 0:
273 | return
274 |
275 | if self.tracers is not None and len(self.attentions) > 0:
276 | for i, tracer in enumerate(self.tracers):
277 | with torch.no_grad():
278 |                     styled_prompt = shared.prompt_styles.apply_styles_to_prompt(params.p.prompt, params.p.styles)
279 |                     try:
280 |                         global_heat_map = tracer.compute_global_heat_map(self.prompt_analyzer, styled_prompt, batch_pos)
281 | except:
282 | continue
283 |
284 | if i not in self.heatmap_images:
285 | self.heatmap_images[i] = []
286 |
287 | if global_heat_map is not None:
288 | heatmap_images = []
289 | for attention in self.attentions:
290 |
291 | img_size = params.image.size
292 | caption = attention + (" " + self.attn_captions[i] if self.attn_captions[i] else "") if not self.hide_caption else None
293 |
294 | heat_map = global_heat_map.compute_word_heat_map(attention)
295 | if heat_map is None : print(f"No heatmaps for '{attention}'")
296 |
297 | heat_map_img = utils.expand_image(heat_map, img_size[1], img_size[0]) if heat_map is not None else None
298 | img : Image.Image = utils.image_overlay_heat_map(params.image, heat_map_img, alpha=self.alpha, caption=caption, image_scale=self.heatmap_image_scale)
299 |
300 | fullfn_without_extension, extension = os.path.splitext(params.filename)
301 | full_filename = fullfn_without_extension + "_" + attention + ("_" + self.attn_captions[i] if self.attn_captions[i] else "") + extension
302 |
303 | if self.use_grid:
304 | heatmap_images.append(img)
305 | else:
306 | heatmap_images.append(img)
307 | if not self.dont_save_images:
308 | img.save(full_filename)
309 |
310 | self.heatmap_images[i] += heatmap_images
311 |
312 | self.heatmap_images = {j:self.heatmap_images[j] for j in self.heatmap_images.keys() if self.heatmap_images[j]}
313 |
314 | # if it is last batch pos, clear heatmaps
315 | if batch_pos == params.p.batch_size - 1:
316 | for tracer in self.tracers:
317 | tracer.reset()
318 |
319 | return
320 |
321 |
322 | def handle_before_image_saved(params : script_callbacks.ImageSaveParams):
323 |
324 | if before_image_saved_handler is not None and callable(before_image_saved_handler):
325 | before_image_saved_handler(params)
326 |
327 | return
328 |
329 | script_callbacks.on_before_image_saved(handle_before_image_saved)
330 |
--------------------------------------------------------------------------------
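
A condensed sketch of the flow that `daam_script.py` drives (illustrative only; `sd_model`, `embedder`, and `image` stand for the loaded Web UI model, its prompt embedder, and a generated PIL image, and the sampling call itself is elided):

```python
import torch

from scripts.daam import trace, utils

prompt = "A photo of a cute cat wearing sunglasses relaxing on a beach"

# PromptAnalyzer maps words in the prompt to token positions via the hijacked embedder.
analyzer = utils.PromptAnalyzer(embedder, prompt)               # `embedder` is assumed, see above

with torch.no_grad():
    tracer = trace(sd_model, 512, 512, analyzer.context_size)   # `sd_model` is assumed
    tracer.hook()
    # ... run image generation here; the hooked CrossAttention layers record heat maps ...
    global_heat_map = tracer.compute_global_heat_map(analyzer, prompt, batch_index=0)
    tracer.unhook()

heat_map = global_heat_map.compute_word_heat_map("cat")         # None if "cat" never matches
overlay = utils.image_overlay_heat_map(
    image,                                                      # `image` is assumed (PIL.Image)
    utils.expand_image(heat_map, 512, 512),
    alpha=0.5, caption="cat")
```
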
/scripts/daam/trace.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 | from collections import defaultdict
3 | from copy import deepcopy
4 | from pathlib import Path
5 | from typing import List, Type, Any, Literal, Dict
6 | import math
7 | from modules.devices import device
8 |
9 | from ldm.models.diffusion.ddpm import DiffusionWrapper, LatentDiffusion
10 | from ldm.modules.diffusionmodules.openaimodel import UNetModel
11 | from ldm.modules.attention import CrossAttention, default, exists
12 |
13 | import numba
14 | import numpy as np
15 | import torch
16 | import torch.nn.functional as F
17 | from torch import nn, einsum
18 | from einops import rearrange, repeat
19 |
20 | from .experiment import COCO80_LABELS
21 | from .hook import ObjectHooker, AggregateHooker, UNetCrossAttentionLocator
22 | from .utils import compute_token_merge_indices, PromptAnalyzer
23 |
24 |
25 | __all__ = ['trace', 'DiffusionHeatMapHooker', 'HeatMap', 'MmDetectHeatMap']
26 |
27 |
28 | class UNetForwardHooker(ObjectHooker[UNetModel]):
29 |     def __init__(self, module: UNetModel, heat_maps: Dict[int, Dict[int, List[torch.Tensor]]]):
30 | super().__init__(module)
31 | self.all_heat_maps = []
32 | self.heat_maps = heat_maps
33 |
34 | def _hook_impl(self):
35 | self.monkey_patch('forward', self._forward)
36 |
37 | def _unhook_impl(self):
38 | pass
39 |
40 | def _forward(hk_self, self, *args, **kwargs):
41 | super_return = hk_self.monkey_super('forward', *args, **kwargs)
42 | hk_self.all_heat_maps.append(deepcopy(hk_self.heat_maps))
43 | hk_self.heat_maps.clear()
44 |
45 | return super_return
46 |
47 |
48 | class HeatMap:
49 | def __init__(self, prompt_analyzer: PromptAnalyzer, prompt: str, heat_maps: torch.Tensor):
50 | self.prompt_analyzer = prompt_analyzer.create(prompt)
51 | self.heat_maps = heat_maps
52 | self.prompt = prompt
53 |
54 | def compute_word_heat_map(self, word: str, word_idx: int = None) -> torch.Tensor:
55 | merge_idxs, _ = self.prompt_analyzer.calc_word_indecies(word)
56 | # print("merge_idxs", merge_idxs)
57 | if len(merge_idxs) == 0:
58 | return None
59 |
60 | return self.heat_maps[merge_idxs].mean(0)
61 |
62 |
63 | class MmDetectHeatMap:
64 | def __init__(self, pred_file: str | Path, threshold: float = 0.95):
65 | @numba.njit
66 | def _compute_mask(masks: np.ndarray, bboxes: np.ndarray):
67 | x_any = np.any(masks, axis=1)
68 | y_any = np.any(masks, axis=2)
69 | num_masks = len(bboxes)
70 |
71 | for idx in range(num_masks):
72 | x = np.where(x_any[idx, :])[0]
73 | y = np.where(y_any[idx, :])[0]
74 | bboxes[idx, :4] = np.array([x[0], y[0], x[-1] + 1, y[-1] + 1], dtype=np.float32)
75 |
76 | pred_file = Path(pred_file)
77 | self.word_masks: Dict[str, torch.Tensor] = defaultdict(lambda: 0)
78 | bbox_result, masks = torch.load(pred_file)
79 | labels = [np.full(bbox.shape[0], i, dtype=np.int32) for i, bbox in enumerate(bbox_result)]
80 | labels = np.concatenate(labels)
81 | bboxes = np.vstack(bbox_result)
82 |
83 | if masks is not None and bboxes[:, :4].sum() == 0:
84 | _compute_mask(masks, bboxes)
85 | scores = bboxes[:, -1]
86 | inds = scores > threshold
87 | labels = labels[inds]
88 | masks = masks[inds, ...]
89 |
90 | for lbl, mask in zip(labels, masks):
91 | self.word_masks[COCO80_LABELS[lbl]] |= torch.from_numpy(mask)
92 |
93 | self.word_masks = {k: v.float() for k, v in self.word_masks.items()}
94 |
95 | def compute_word_heat_map(self, word: str) -> torch.Tensor:
96 | return self.word_masks[word]
97 |
98 |
99 | class DiffusionHeatMapHooker(AggregateHooker):
100 |     def __init__(self, model: LatentDiffusion, height: int, width: int, context_size: int = 77, weighted: bool = False, layer_idx: int = None, head_idx: int = None):
101 |         heat_maps = defaultdict(lambda: defaultdict(list)) # batch index, factor, attention
102 |         modules = [UNetCrossAttentionHooker(x, height, width, heat_maps, context_size=context_size, weighted=weighted, head_idx=head_idx) for x in UNetCrossAttentionLocator().locate(model.model.diffusion_model, layer_idx)]
103 |         self.forward_hook = UNetForwardHooker(model.model.diffusion_model, heat_maps)
104 |         modules.append(self.forward_hook)
105 |
106 |         self.height = height
107 | self.width = width
108 | self.model = model
109 | self.last_prompt = ''
110 |
111 | super().__init__(modules)
112 |
113 |
114 |
115 | @property
116 | def all_heat_maps(self):
117 | return self.forward_hook.all_heat_maps
118 |
119 |     def reset(self):
120 |         for m in self.module: getattr(m, 'reset', lambda: None)()  # UNetForwardHooker defines no reset()
121 |         self.forward_hook.all_heat_maps.clear()
122 |
123 | def compute_global_heat_map(self, prompt_analyzer, prompt, batch_index, time_weights=None, time_idx=None, last_n=None, first_n=None, factors=None):
124 |         # type: (PromptAnalyzer, str, int, List[float], int, int, int, List[int]) -> HeatMap
125 | """
126 | Compute the global heat map for the given prompt, aggregating across time (inference steps) and space (different
127 | spatial transformer block heat maps).
128 |
129 | Args:
130 | prompt: The prompt to compute the heat map for.
131 | time_weights: The weights to apply to each time step. If None, all time steps are weighted equally.
132 | time_idx: The time step to compute the heat map for. If None, the heat map is computed for all time steps.
133 | Mutually exclusive with `last_n` and `first_n`.
134 | last_n: The number of last n time steps to use. If None, the heat map is computed for all time steps.
135 | Mutually exclusive with `time_idx`.
136 | first_n: The number of first n time steps to use. If None, the heat map is computed for all time steps.
137 | Mutually exclusive with `time_idx`.
138 | factors: Restrict the application to heat maps with spatial factors in this set. If `None`, use all sizes.
139 | """
140 | if len(self.forward_hook.all_heat_maps) == 0:
141 | return None
142 |
143 | if time_weights is None:
144 | time_weights = [1.0] * len(self.forward_hook.all_heat_maps)
145 |
146 | time_weights = np.array(time_weights)
147 | time_weights /= time_weights.sum()
148 | all_heat_maps = self.forward_hook.all_heat_maps
149 |
150 | if time_idx is not None:
151 | heat_maps = [all_heat_maps[time_idx]]
152 | else:
153 | heat_maps = all_heat_maps[-last_n:] if last_n is not None else all_heat_maps
154 | heat_maps = heat_maps[:first_n] if first_n is not None else heat_maps
155 |
156 |
157 | if factors is None:
158 | factors = {1, 2, 4, 8, 16, 32}
159 | else:
160 | factors = set(factors)
161 |
162 | all_merges = []
163 |
164 | for batch_to_heat_maps in heat_maps:
165 |
166 | if batch_index not in batch_to_heat_maps:
167 | continue
168 |
169 | merge_list = []
170 |
171 | factors_to_heat_maps = batch_to_heat_maps[batch_index]
172 |
173 | for k, heat_map in factors_to_heat_maps.items():
174 | # heat_map shape: (tokens, 1, height, width)
175 | # each v is a heat map tensor for a layer of factor size k across the tokens
176 | if k in factors:
177 | merge_list.append(torch.stack(heat_map, 0).mean(0))
178 |
179 | if len(merge_list) > 0:
180 | all_merges.append(merge_list)
181 |
182 | maps = torch.stack([torch.stack(x, 0) for x in all_merges], dim=0)
183 | maps = maps.sum(0).to(device).sum(2).sum(0)
184 |
185 | return HeatMap(prompt_analyzer, prompt, maps)
186 |
187 |
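# Hedged usage sketch (assumptions: a loaded LatentDiffusion model `sd_model`, the
# prompt string `prompt`, a PromptAnalyzer instance `prompt_analyzer` built for it
# elsewhere in the extension, hook()/unhook() provided by the ObjectHooker base in
# hook.py, and HeatMap.compute_word_heat_map mirroring the upstream DAAM API).
# Not a definitive recipe, just the intended call order around sampling:
#
#   tracer = DiffusionHeatMapHooker(sd_model, 512, 512, context_size=77)
#   tracer.hook()                      # patch the UNet cross-attention modules
#   # ... run txt2img / the sampler here ...
#   global_map = tracer.compute_global_heat_map(prompt_analyzer, prompt, batch_index=0)
#   if global_map is not None:
#       word_map = global_map.compute_word_heat_map("cat")   # the word must occur in the prompt
#   tracer.unhook()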
188 | class UNetCrossAttentionHooker(ObjectHooker[CrossAttention]):
189 | def __init__(self, module: CrossAttention, img_height: int, img_width: int, heat_maps: defaultdict, context_size: int = 77, weighted: bool = False, head_idx: int = 0):
190 | super().__init__(module)
191 | self.heat_maps = heat_maps
192 | self.context_size = context_size
193 | self.weighted = weighted
194 | self.head_idx = head_idx
195 | self.img_height = img_height
196 | self.img_width = img_width
197 | self.calledCount = 0
198 |
199 | def reset(self):
200 | self.heat_maps.clear()
201 | self.calledCount = 0
202 |
203 | @torch.no_grad()
204 | def _up_sample_attn(self, x, value, factor, method='bicubic'):
205 | # type: (torch.Tensor, torch.Tensor, int, Literal['bicubic', 'conv']) -> torch.Tensor
206 | # x shape: (heads, height * width, tokens)
207 | """
208 | Up samples the attention map in x using interpolation to the maximum size of (64, 64), as assumed in the Stable
209 | Diffusion model.
210 |
211 | Args:
212 | x (`torch.Tensor`): cross attention slice/map between the words and the tokens.
213 | value (`torch.Tensor`): the value tensor.
214 | method (`str`): the method to use; one of `'bicubic'` or `'conv'`.
215 |
216 | Returns:
217 | `torch.Tensor`: the up-sampled attention map of shape (tokens, 1, height, width).
218 | """
219 | weight = torch.full((factor, factor), 1 / factor ** 2, device=x.device)
220 | weight = weight.view(1, 1, factor, factor)
221 |
222 | h = int(math.sqrt((self.img_height * x.size(1)) / self.img_width))  # recover the latent height from the flattened token count, preserving the image aspect ratio
223 | w = int(self.img_width * h / self.img_height)
224 |
225 | h_fix = w_fix = 64
226 | if h >= w:
227 | w_fix = int((w * h_fix) / h)
228 | else:
229 | h_fix = int((h * w_fix) / w)
230 |
231 | maps = []
232 | x = x.permute(2, 0, 1)
233 | value = value.permute(1, 0, 2)
234 | weights = 1
235 |
236 | with torch.cuda.amp.autocast(dtype=torch.float32):
237 | for map_ in x:
238 | map_ = map_.unsqueeze(1).view(map_.size(0), 1, h, w)
239 |
240 | if method == 'bicubic':
241 | map_ = F.interpolate(map_, size=(h_fix, w_fix), mode='bicubic')
242 | maps.append(map_.squeeze(1))
243 | else:
244 | maps.append(F.conv_transpose2d(map_, weight, stride=factor).squeeze(1))
245 |
246 | if self.weighted:
247 | weights = value.norm(p=1, dim=-1, keepdim=True).unsqueeze(-1)
248 |
249 | maps = torch.stack(maps, 0) # shape: (tokens, heads, height, width)
250 |
251 | if self.head_idx:
252 | maps = maps[:, self.head_idx:self.head_idx+1, :, :]
253 |
254 | return (weights * maps).sum(1, keepdim=True).cpu()
255 |
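# Worked sizing example for the interpolation above (illustrative numbers only):
# for a 512x768 (width x height) generation, an attention map taken at the full
# 96x64 latent grid has x.size(1) = 96 * 64 = 6144 entries, giving
# h = sqrt(768 * 6144 / 512) = 96 and w = 512 * 96 / 768 = 64. Since h >= w, the
# map is resized to (h_fix, w_fix) = (64, int(64 * 64 / 96)) = (64, 42): the
# longer side is clamped to 64 while the aspect ratio is preserved.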
256 | def _forward(hk_self, self, x, context=None, mask=None, additional_tokens=None):  # hk_self: the hook instance; self: the patched CrossAttention module
257 |
258 | if additional_tokens is not None:
259 | # get the number of masked tokens at the beginning of the output sequence
260 | n_tokens_to_mask = additional_tokens.shape[1]
261 | # add additional token
262 | x = torch.cat([additional_tokens, x], dim=1)
263 |
264 | hk_self.calledCount += 1
265 | batch_size, sequence_length, _ = x.shape
266 | h = self.heads
267 |
268 | q = self.to_q(x)
269 | context = default(context, x)
270 | k = self.to_k(context)
271 | v = self.to_v(context)
272 |
273 | dim = q.shape[-1]
274 |
275 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
276 |
277 | sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
278 |
279 | if exists(mask):
280 | mask = rearrange(mask, 'b ... -> b (...)')
281 | max_neg_value = -torch.finfo(sim.dtype).max
282 | mask = repeat(mask, 'b j -> (b h) () j', h=h)
283 | sim.masked_fill_(~mask, max_neg_value)
284 |
285 | out = hk_self._hooked_attention(self, q, k, v, batch_size, sequence_length, dim)
286 |
287 | if additional_tokens is not None:
288 | # remove additional token
289 | out = out[:, n_tokens_to_mask:]
290 |
291 | return self.to_out(out)
292 |
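# Shape sketch for the rearrange above (illustrative numbers): with batch size 2,
# 4096 latent tokens, inner dim 640 and h = 8 heads, q goes from (2, 4096, 640) to
# (2 * 8, 4096, 80) while k and v go from (2, 77, 640) to (16, 77, 80), so each
# head becomes its own batch entry before the `b i d, b j d -> b i j` similarity
# and the hooked slice-wise attention are computed.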
293 | ### reference: forward implementation of the diffusers CrossAttention
294 | # def forward(self, hidden_states, context=None, mask=None):
295 | # batch_size, sequence_length, _ = hidden_states.shape
296 |
297 | # query = self.to_q(hidden_states)
298 | # context = context if context is not None else hidden_states
299 | # key = self.to_k(context)
300 | # value = self.to_v(context)
301 |
302 | # dim = query.shape[-1]
303 |
304 | # query = self.reshape_heads_to_batch_dim(query)
305 | # key = self.reshape_heads_to_batch_dim(key)
306 | # value = self.reshape_heads_to_batch_dim(value)
307 |
308 | # # TODO(PVP) - mask is currently never used. Remember to re-implement when used
309 |
310 | # # attention, what we cannot get enough of
311 | # if self._use_memory_efficient_attention_xformers:
312 | # hidden_states = self._memory_efficient_attention_xformers(query, key, value)
313 | # # Some versions of xformers return output in fp32, cast it back to the dtype of the input
314 | # hidden_states = hidden_states.to(query.dtype)
315 | # else:
316 | # if self._slice_size is None or query.shape[0] // self._slice_size == 1:
317 | # hidden_states = self._attention(query, key, value)
318 | # else:
319 | # hidden_states = self._sliced_attention(query, key, value, sequence_length, dim)
320 |
321 | # # linear proj
322 | # hidden_states = self.to_out[0](hidden_states)
323 | # # dropout
324 | # hidden_states = self.to_out[1](hidden_states)
325 | # return hidden_states
326 |
327 | def _hooked_attention(hk_self, self, query, key, value, batch_size, sequence_length, dim, use_context: bool = True):
328 | """
329 | Monkey-patched version of :py:func:`.CrossAttention._attention` to capture attentions and aggregate them.
330 |
331 | Args:
332 | hk_self (`UNetCrossAttentionHooker`): pointer to the hook itself.
333 | self (`CrossAttention`): pointer to the module.
334 | query (`torch.Tensor`): the query tensor.
335 | key (`torch.Tensor`): the key tensor.
336 | value (`torch.Tensor`): the value tensor.
337 | batch_size (`int`): the batch size.
338 | use_context (`bool`): whether to check if the resulting attention slices are between the words and the image.
339 | """
340 | batch_size_attention = query.shape[0]
341 | hidden_states = torch.zeros(
342 | (batch_size_attention, sequence_length, dim // self.heads), device=query.device, dtype=query.dtype
343 | )
344 | slice_size = hidden_states.shape[0] // batch_size # self._slice_size if self._slice_size is not None else hidden_states.shape[0]
345 |
346 | def calc_factor_base(w, h):
347 | z = max(w/64, h/64)
348 | factor_b = min(w, h) * z
349 | return factor_b
350 |
351 | factor_base = calc_factor_base(hk_self.img_width, hk_self.img_height)
352 |
353 | for batch_index in range(hidden_states.shape[0] // slice_size):
354 | start_idx = batch_index * slice_size
355 | end_idx = (batch_index + 1) * slice_size
356 | attn_slice = (
357 | torch.einsum("b i d, b j d -> b i j", query[start_idx:end_idx], key[start_idx:end_idx]) * self.scale
358 | )
359 | factor = int(math.sqrt(factor_base // attn_slice.shape[1]))
360 | attn_slice = attn_slice.softmax(-1)
361 | hid_states = torch.einsum("b i j, b j d -> b i d", attn_slice, value[start_idx:end_idx])
362 |
363 | if use_context and hk_self.calledCount % 2 == 1 and attn_slice.shape[-1] == hk_self.context_size:
364 | if factor >= 1:
365 | # factor is this layer's linear down-sampling ratio relative to the full latent grid
366 | maps = hk_self._up_sample_attn(attn_slice, value, factor)
367 | hk_self.heat_maps[batch_index][factor].append(maps)
368 |
369 | hidden_states[start_idx:end_idx] = hid_states
370 |
371 | # reshape hidden_states
372 | hidden_states = hk_self.reshape_batch_dim_to_heads(self, hidden_states)
373 | return hidden_states
374 |
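# Worked example of the factor bookkeeping above (illustrative numbers): for a
# 512x768 image, calc_factor_base gives min(512, 768) * max(512 / 64, 768 / 64)
# = 512 * 12 = 6144. A cross-attention slice taken at the full 96x64 latent grid
# has attn_slice.shape[1] = 6144, so factor = sqrt(6144 // 6144) = 1; one UNet
# level down the grid is 48x32 = 1536 tokens and factor = sqrt(6144 // 1536) = 2,
# which is the key under which that layer's maps are accumulated in heat_maps.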
375 | def reshape_batch_dim_to_heads(hk_self, self, tensor):
376 | batch_size, seq_len, dim = tensor.shape
377 | head_size = self.heads
378 | tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
379 | tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
380 | return tensor
381 |
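# Inverse of the head split used in _forward (illustrative shapes): a tensor of
# shape (2 * 8, 4096, 80) is viewed as (2, 8, 4096, 80), permuted to
# (2, 4096, 8, 80) and flattened back to (2, 4096, 640), i.e. the per-head outputs
# are concatenated along the channel dimension again.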
382 | def _hook_impl(self):
383 | self.monkey_patch('forward', self._forward)
384 |
385 | @property
386 | def num_heat_maps(self):
387 | return len(next(iter(self.heat_maps.values())))
388 |
389 |
390 | trace: Type[DiffusionHeatMapHooker] = DiffusionHeatMapHooker
391 |
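# The `trace` alias mirrors the upstream DAAM package, where tracing is usually
# written as a context manager. A hedged sketch, assuming the ObjectHooker base in
# hook.py implements __enter__/__exit__ the same way upstream DAAM does (hook on
# enter, unhook on exit); otherwise the explicit hook()/unhook() calls shown
# earlier apply instead:
#
#   with trace(sd_model, 512, 512) as tc:
#       # ... run the sampler ...
#       heat_map = tc.compute_global_heat_map(prompt_analyzer, prompt, batch_index=0)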
--------------------------------------------------------------------------------