├── .gitignore ├── LICENSE ├── README.md ├── images ├── 00006-2623256163.png ├── 00006-2623256163_beach.png ├── 00006-2623256163_cat.png └── 00006-2623256163_sunglasses.png ├── install.py └── scripts ├── daam ├── __init__.py ├── evaluate.py ├── experiment.py ├── hook.py ├── trace.py └── utils.py └── daam_script.py /.gitignore: -------------------------------------------------------------------------------- 1 | .vscode/ 2 | .env 3 | 4 | __pycache__/ -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | The MIT License (MIT) 2 | 3 | Modified work Copyright (c) 2022 kousw 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | 23 | 24 | ### castorini/daam 25 | 26 | License: MIT 27 | By: Castorini 28 | Repository: 29 | 30 | >MIT License 31 | > 32 | >Copyright (c) 2022 Castorini 33 | > 34 | >Permission is hereby granted, free of charge, to any person obtaining a copy 35 | >of this software and associated documentation files (the "Software"), to deal 36 | >in the Software without restriction, including without limitation the rights 37 | >to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 38 | >copies of the Software, and to permit persons to whom the Software is 39 | >furnished to do so, subject to the following conditions: 40 | > 41 | >The above copyright notice and this permission notice shall be included in all 42 | >copies or substantial portions of the Software. 43 | > 44 | >THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 45 | >IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 46 | >FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 47 | >AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 48 | >LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 49 | >OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 50 | >SOFTWARE. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # DAAM Extension for Stable Diffusion Web UI 2 | 3 | This is a port of [DAAM](https://github.com/castorini/daam) for Stable Diffusion Web UI. 4 | 5 | *Now support for SDXL!* 6 | 7 | # Update 8 | - `2024/01/28` Fixed an issue where it did not work with the latest version of webui. Added support for SDXL. 9 | 10 | # Setup and Running 11 | 12 | Clone this repository to extension folder. 13 | 14 | # How to use 15 | 16 | Select "Daam script" from the script drop-down. Enter the 'attention text' (must be a string contained in the prompt ) and run. 17 | An overlapping image with a heatmap for each attention will be generated along with the original image. 18 | Images will now be created in the default output directory. 19 | 20 | Attention text is divided by commas, but multiple words without commas are recognized as a single sequence. 21 | If you type "cat" for attention text, then all the tokens matching "cat" will be retrieved and combined into attention. 22 | If you type "cute cat", only tokens with "cute" and "cat" in sequence will be retrieved and only their attention will be output. 23 | 24 | # Sample 25 | 26 | prompt : "A photo of a cute cat wearing sunglasses relaxing on a beach" 27 | 28 | attention text: "cat, sunglasses, beach" 29 | 30 | output images: orginal, cat, sunglasses, beach 31 | 32 | 33 | 34 | 35 | 36 | 37 | # Tutorial 38 | 39 | - [Easiest way to use DAAM script tutorial](https://www.youtube.com/watch?v=XiKyEKJrTLQ) 40 | 41 | [![image.png](https://s3.amazonaws.com/moonup/production/uploads/1675628788246-6345bd89fe134dfd7a0dba40.png)](https://www.youtube.com/watch?v=XiKyEKJrTLQ) 42 | 43 | # Notice 44 | At the moment, this works well with the Stable Diffusion 1.5 model. 45 | However, in the Stable Diffusion 2.0 model this seems to be working a little less well. 46 | 47 | -------------------------------------------------------------------------------- /images/00006-2623256163.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kousw/stable-diffusion-webui-daam/b23fb574bf691f0bdf503e5617a0b3578160c7a1/images/00006-2623256163.png -------------------------------------------------------------------------------- /images/00006-2623256163_beach.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kousw/stable-diffusion-webui-daam/b23fb574bf691f0bdf503e5617a0b3578160c7a1/images/00006-2623256163_beach.png -------------------------------------------------------------------------------- /images/00006-2623256163_cat.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kousw/stable-diffusion-webui-daam/b23fb574bf691f0bdf503e5617a0b3578160c7a1/images/00006-2623256163_cat.png -------------------------------------------------------------------------------- /images/00006-2623256163_sunglasses.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kousw/stable-diffusion-webui-daam/b23fb574bf691f0bdf503e5617a0b3578160c7a1/images/00006-2623256163_sunglasses.png -------------------------------------------------------------------------------- /install.py: -------------------------------------------------------------------------------- 1 | import launch 2 | 3 | 4 | def check_matplotlib(): 5 | if not launch.is_installed("matplotlib"): 6 | return False 7 | 8 | try: 9 | import matplotlib 10 | except ImportError: 11 | return False 12 | 13 | if hasattr(matplotlib, "__version_info__"): 14 | version = matplotlib.__version_info__ 15 | version = (version.major, version.minor, version.micro) 16 | return version >= (3, 6, 2) 17 | return False 18 | 19 | 20 | if not check_matplotlib(): 21 | launch.run_pip("install matplotlib==3.6.2", desc="Installing matplotlib==3.6.2") 22 | -------------------------------------------------------------------------------- /scripts/daam/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | from .hook import * 3 | from .utils import * 4 | from .evaluate import * 5 | from .experiment import * 6 | from .trace import * -------------------------------------------------------------------------------- /scripts/daam/evaluate.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | from collections import defaultdict 3 | from typing import List 4 | 5 | from scipy.optimize import linear_sum_assignment 6 | import PIL.Image as Image 7 | import numpy as np 8 | import torch 9 | import torch.nn.functional as F 10 | 11 | 12 | __all__ = ['compute_iou', 'MeanEvaluator'] 13 | 14 | 15 | def compute_iou(a: torch.Tensor, b: torch.Tensor) -> float: 16 | if a.shape[0] != b.shape[0]: 17 | a = F.interpolate(a.unsqueeze(0).unsqueeze(0).float(), size=b.shape, mode='bicubic').squeeze() 18 | a[a < 1] = 0 19 | a[a >= 1] = 1 20 | 21 | intersection = (a * b).sum() 22 | union = a.sum() + b.sum() - intersection 23 | 24 | return (intersection / (union + 1e-8)).item() 25 | 26 | 27 | def load_mask(path: str) -> torch.Tensor: 28 | mask = np.array(Image.open(path)) 29 | mask = torch.from_numpy(mask).float()[:, :, 3] # use alpha channel 30 | mask = (mask > 0).float() 31 | 32 | return mask 33 | 34 | 35 | class UnsupervisedEvaluator: 36 | def __init__(self, name: str = 'UnsupervisedEvaluator'): 37 | self.name = name 38 | self.ious = defaultdict(list) 39 | self.num_samples = 0 40 | 41 | def log_iou(self, preds: torch.Tensor | List[torch.Tensor], truth: torch.Tensor, gt_idx: int = 0, pred_idx: int = 0): 42 | if not isinstance(preds, list): 43 | preds = [preds] 44 | 45 | iou = max(compute_iou(pred, truth) for pred in preds) 46 | self.ious[gt_idx].append((pred_idx, iou)) 47 | 48 | @property 49 | def mean_iou(self) -> float: 50 | n = max(max(self.ious), max([y[0] for x in self.ious.values() for y in x])) + 1 51 | iou_matrix = np.zeros((n, n)) 52 | count_matrix = np.zeros((n, n)) 53 | 54 | for gt_idx, ious in self.ious.items(): 55 | for pred_idx, iou in ious: 56 | iou_matrix[gt_idx, pred_idx] += iou 57 | count_matrix[gt_idx, pred_idx] += 1 58 | 59 | row_ind, col_ind = linear_sum_assignment(iou_matrix, maximize=True) 60 | return iou_matrix[row_ind, col_ind].sum() / count_matrix[row_ind, col_ind].sum() 61 | 62 | def increment(self): 63 | self.num_samples += 1 64 | 65 | def __len__(self) -> int: 66 | return self.num_samples 67 | 68 | def __str__(self): 69 | return f'{self.name}<{self.mean_iou:.4f} (mIoU) {len(self)} samples>' 70 | 71 | 72 | class MeanEvaluator: 73 | def __init__(self, name: str = 'MeanEvaluator'): 74 | self.ious: List[float] = [] 75 | self.intensities: List[float] = [] 76 | self.name = name 77 | 78 | def log_iou(self, preds: torch.Tensor | List[torch.Tensor], truth: torch.Tensor): 79 | if not isinstance(preds, list): 80 | preds = [preds] 81 | 82 | self.ious.append(max(compute_iou(pred, truth) for pred in preds)) 83 | return self 84 | 85 | def log_intensity(self, pred: torch.Tensor): 86 | self.intensities.append(pred.mean().item()) 87 | return self 88 | 89 | @property 90 | def mean_iou(self) -> float: 91 | return np.mean(self.ious) 92 | 93 | @property 94 | def mean_intensity(self) -> float: 95 | return np.mean(self.intensities) 96 | 97 | @property 98 | def ci95_miou(self) -> float: 99 | return 1.96 * np.std(self.ious) / np.sqrt(len(self.ious)) 100 | 101 | def __len__(self) -> int: 102 | return max(len(self.ious), len(self.intensities)) 103 | 104 | def __str__(self): 105 | return f'{self.name}<{self.mean_iou:.4f} (±{self.ci95_miou:.3f} mIoU) {self.mean_intensity:.4f} (mInt) {len(self)} samples>' 106 | 107 | 108 | if __name__ == '__main__': 109 | mask = load_mask('truth/output/452/sink.gt.png') 110 | 111 | print(MeanEvaluator().log_iou(mask, mask)) 112 | -------------------------------------------------------------------------------- /scripts/daam/experiment.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | from pathlib import Path 3 | from typing import List, Optional, Dict, Any 4 | from dataclasses import dataclass 5 | import json 6 | 7 | from transformers import PreTrainedTokenizer, AutoTokenizer 8 | import PIL.Image 9 | import numpy as np 10 | import torch 11 | 12 | from .evaluate import load_mask 13 | from .utils import plot_overlay_heat_map, expand_image 14 | 15 | __all__ = ['GenerationExperiment', 'COCO80_LABELS', 'COCOSTUFF27_LABELS', 'COCO80_INDICES', 'build_word_list_coco80'] 16 | 17 | 18 | COCO80_LABELS: List[str] = [ 19 | 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck', 'boat', 'traffic light', 20 | 'fire hydrant', 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow', 21 | 'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee', 22 | 'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard', 23 | 'tennis racket', 'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple', 24 | 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch', 25 | 'potted plant', 'bed', 'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone', 26 | 'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase', 'scissors', 'teddy bear', 27 | 'hair drier', 'toothbrush' 28 | ] 29 | 30 | COCO80_INDICES: Dict[str, int] = {x: i for i, x in enumerate(COCO80_LABELS)} 31 | 32 | UNUSED_LABELS: List[str] = [f'__unused_{i}__' for i in range(1, 200)] 33 | 34 | COCOSTUFF27_LABELS: List[str] = [ 35 | 'electronic', 'appliance', 'food', 'furniture', 'indoor', 'kitchen', 'accessory', 'animal', 'outdoor', 'person', 36 | 'sports', 'vehicle', 'ceiling', 'floor', 'food', 'furniture', 'rawmaterial', 'textile', 'wall', 'window', 37 | 'building', 'ground', 'plant', 'sky', 'solid', 'structural', 'water' 38 | ] 39 | 40 | COCO80_ONTOLOGY = { 41 | 'two-wheeled vehicle': ['bicycle', 'motorcycle'], 42 | 'vehicle': ['two-wheeled vehicle', 'four-wheeled vehicle'], 43 | 'four-wheeled vehicle': ['bus', 'truck', 'car'], 44 | 'four-legged animals': ['livestock', 'pets', 'wild animals'], 45 | 'livestock': ['cow', 'horse', 'sheep'], 46 | 'pets': ['cat', 'dog'], 47 | 'wild animals': ['elephant', 'bear', 'zebra', 'giraffe'], 48 | 'bags': ['backpack', 'handbag', 'suitcase'], 49 | 'sports boards': ['snowboard', 'surfboard', 'skateboard'], 50 | 'utensils': ['fork', 'knife', 'spoon'], 51 | 'receptacles': ['bowl', 'cup'], 52 | 'fruits': ['banana', 'apple', 'orange'], 53 | 'foods': ['fruits', 'meals', 'desserts'], 54 | 'meals': ['sandwich', 'hot dog', 'pizza'], 55 | 'desserts': ['cake', 'donut'], 56 | 'furniture': ['chair', 'couch', 'bench'], 57 | 'electronics': ['monitors', 'appliances'], 58 | 'monitors': ['tv', 'cell phone', 'laptop'], 59 | 'appliances': ['oven', 'toaster', 'refrigerator'] 60 | } 61 | 62 | COCO80_TO_27 = { 63 | 'bicycle': 'vehicle', 'car': 'vehicle', 'motorcycle': 'vehicle', 'airplane': 'vehicle', 'bus': 'vehicle', 64 | 'train': 'vehicle', 'truck': 'vehicle', 'boat': 'vehicle', 'traffic light': 'accessory', 'fire hydrant': 'accessory', 65 | 'stop sign': 'accessory', 'parking meter': 'accessory', 'bench': 'furniture', 'bird': 'animal', 'cat': 'animal', 66 | 'dog': 'animal', 'horse': 'animal', 'sheep': 'animal', 'cow': 'animal', 'elephant': 'animal', 'bear': 'animal', 67 | 'zebra': 'animal', 'giraffe': 'animal', 'backpack': 'accessory', 'umbrella': 'accessory', 'handbag': 'accessory', 68 | 'tie': 'accessory', 'suitcase': 'accessory', 'frisbee': 'sports', 'skis': 'sports', 'snowboard': 'sports', 69 | 'sports ball': 'sports', 'kite': 'sports', 'baseball bat': 'sports', 'baseball glove': 'sports', 70 | 'skateboard': 'sports', 'surfboard': 'sports', 'tennis racket': 'sports', 'bottle': 'food', 'wine glass': 'food', 71 | 'cup': 'food', 'fork': 'food', 'knife': 'food', 'spoon': 'food', 'bowl': 'food', 'banana': 'food', 'apple': 'food', 72 | 'sandwich': 'food', 'orange': 'food', 'broccoli': 'food', 'carrot': 'food', 'hot dog': 'food', 'pizza': 'food', 73 | 'donut': 'food', 'cake': 'food', 'chair': 'furniture', 'couch': 'furniture', 'potted plant': 'plant', 74 | 'bed': 'furniture', 'dining table': 'furniture', 'toilet': 'furniture', 'tv': 'electronic', 'laptop': 'electronic', 75 | 'mouse': 'electronic', 'remote': 'electronic', 'keyboard': 'electronic', 'cell phone': 'electronic', 76 | 'microwave': 'appliance', 'oven': 'appliance', 'toaster': 'appliance', 'sink': 'appliance', 77 | 'refrigerator': 'appliance', 'book': 'indoor', 'clock': 'indoor', 'vase': 'indoor', 'scissors': 'indoor', 78 | 'teddy bear': 'indoor', 'hair drier': 'indoor', 'toothbrush': 'indoor' 79 | } 80 | 81 | 82 | def build_word_list_coco80() -> Dict[str, List[str]]: 83 | words_map = COCO80_ONTOLOGY.copy() 84 | words_map = {k: v for k, v in words_map.items() if not any(item in COCO80_ONTOLOGY for item in v)} 85 | 86 | return words_map 87 | 88 | 89 | def _add_mask(masks: Dict[str, torch.Tensor], word: str, mask: torch.Tensor, simplify80: bool = False) -> Dict[str, torch.Tensor]: 90 | if simplify80: 91 | word = COCO80_TO_27.get(word, word) 92 | 93 | if word in masks: 94 | masks[word] = masks[word.lower()] + mask 95 | masks[word].clamp_(0, 1) 96 | else: 97 | masks[word] = mask 98 | 99 | return masks 100 | 101 | 102 | @dataclass 103 | class GenerationExperiment: 104 | """Class to hold experiment parameters. Pickleable.""" 105 | id: str 106 | image: PIL.Image.Image 107 | global_heat_map: torch.Tensor 108 | seed: int 109 | prompt: str 110 | 111 | path: Optional[Path] = None 112 | truth_masks: Optional[Dict[str, torch.Tensor]] = None 113 | prediction_masks: Optional[Dict[str, torch.Tensor]] = None 114 | annotations: Optional[Dict[str, Any]] = None 115 | subtype: Optional[str] = '.' 116 | 117 | def nsfw(self) -> bool: 118 | return np.sum(np.array(self.image)) == 0 119 | 120 | def heat_map(self, tokenizer: AutoTokenizer): 121 | from .trace import HeatMap 122 | return HeatMap(tokenizer, self.prompt, self.global_heat_map) 123 | 124 | def save(self, path: str = None): 125 | if path is None: 126 | path = self.path 127 | else: 128 | path = Path(path) / self.id 129 | 130 | (path / self.subtype).mkdir(parents=True, exist_ok=True) 131 | torch.save(self, path / self.subtype / 'generation.pt') 132 | self.image.save(path / self.subtype / 'output.png') 133 | 134 | with (path / 'prompt.txt').open('w') as f: 135 | f.write(self.prompt) 136 | 137 | with (path / 'seed.txt').open('w') as f: 138 | f.write(str(self.seed)) 139 | 140 | if self.truth_masks is not None: 141 | for name, mask in self.truth_masks.items(): 142 | im = PIL.Image.fromarray((mask * 255).unsqueeze(-1).expand(-1, -1, 4).byte().numpy()) 143 | im.save(path / f'{name.lower()}.gt.png') 144 | 145 | self.save_annotations() 146 | 147 | def save_annotations(self, path: Path = None): 148 | if path is None: 149 | path = self.path 150 | 151 | if self.annotations is not None: 152 | with (path / 'annotations.json').open('w') as f: 153 | json.dump(self.annotations, f) 154 | 155 | def _load_truth_masks(self, simplify80: bool = False) -> Dict[str, torch.Tensor]: 156 | masks = {} 157 | 158 | for mask_path in self.path.glob('*.gt.png'): 159 | word = mask_path.name.split('.gt.png')[0].lower() 160 | mask = load_mask(str(mask_path)) 161 | _add_mask(masks, word, mask, simplify80) 162 | 163 | return masks 164 | 165 | def _load_pred_masks(self, pred_prefix, composite=False, simplify80=False, vocab=None): 166 | # type: (str, bool, bool, List[str] | None) -> Dict[str, torch.Tensor] 167 | masks = {} 168 | 169 | if vocab is None: 170 | vocab = UNUSED_LABELS 171 | 172 | if composite: 173 | try: 174 | im = PIL.Image.open(self.path / self.subtype / f'composite.{pred_prefix}.pred.png') 175 | im = np.array(im) 176 | 177 | for mask_idx in np.unique(im): 178 | mask = torch.from_numpy((im == mask_idx).astype(np.float32)) 179 | _add_mask(masks, vocab[mask_idx], mask, simplify80) 180 | except FileNotFoundError: 181 | pass 182 | else: 183 | for mask_path in (self.path / self.subtype).glob(f'*.{pred_prefix}.pred.png'): 184 | mask = load_mask(str(mask_path)) 185 | word = mask_path.name.split(f'.{pred_prefix}.pred')[0].lower() 186 | _add_mask(masks, word, mask, simplify80) 187 | 188 | return masks 189 | 190 | def clear_prediction_masks(self, name: str): 191 | path = self if isinstance(self, Path) else self.path 192 | path = path / self.subtype 193 | 194 | for mask_path in path.glob(f'*.{name}.pred.png'): 195 | mask_path.unlink() 196 | 197 | def save_prediction_mask(self, mask: torch.Tensor, word: str, name: str): 198 | path = self if isinstance(self, Path) else self.path 199 | im = PIL.Image.fromarray((mask * 255).unsqueeze(-1).expand(-1, -1, 4).cpu().byte().numpy()) 200 | im.save(path / self.subtype / f'{word.lower()}.{name}.pred.png') 201 | 202 | def save_heat_map(self, tokenizer: PreTrainedTokenizer, word: str, crop: int = None) -> Path: 203 | from .trace import HeatMap # because of cyclical import 204 | heat_map = HeatMap(tokenizer, self.prompt, self.global_heat_map) 205 | heat_map = expand_image(heat_map.compute_word_heat_map(word)) 206 | path = self.path / self.subtype / f'{word.lower()}.heat_map.png' 207 | plot_overlay_heat_map(self.image, heat_map, word, path, crop=crop) 208 | 209 | return path 210 | 211 | def save_all_heat_maps(self, tokenizer: PreTrainedTokenizer, crop: int = None) -> Dict[str, Path]: 212 | path_map = {} 213 | 214 | for word in self.prompt.split(' '): 215 | try: 216 | path = self.save_heat_map(tokenizer, word, crop=crop) 217 | path_map[word] = path 218 | except: 219 | pass 220 | 221 | return path_map 222 | 223 | @staticmethod 224 | def contains_truth_mask(path: str | Path, prompt_id: str = None) -> bool: 225 | if prompt_id is None: 226 | return any(Path(path).glob('*.gt.png')) 227 | else: 228 | return any((Path(path) / prompt_id).glob('*.gt.png')) 229 | 230 | @staticmethod 231 | def read_seed(path: str | Path, prompt_id: str = None) -> int: 232 | if prompt_id is None: 233 | return int(Path(path).joinpath('seed.txt').read_text()) 234 | else: 235 | return int(Path(path).joinpath(prompt_id).joinpath('seed.txt').read_text()) 236 | 237 | @staticmethod 238 | def has_annotations(path: str | Path) -> bool: 239 | return Path(path).joinpath('annotations.json').exists() 240 | 241 | @staticmethod 242 | def has_experiment(path: str | Path, prompt_id: str) -> bool: 243 | return (Path(path) / prompt_id / 'generation.pt').exists() 244 | 245 | @staticmethod 246 | def read_prompt(path: str | Path, prompt_id: str = None) -> str: 247 | if prompt_id is None: 248 | prompt_id = '.' 249 | 250 | with (Path(path) / prompt_id / 'prompt.txt').open('r') as f: 251 | return f.read().strip() 252 | 253 | def _try_load_annotations(self): 254 | if not (self.path / 'annotations.json').exists(): 255 | return None 256 | 257 | return json.load((self.path / 'annotations.json').open()) 258 | 259 | def annotate(self, key: str, value: Any) -> 'GenerationExperiment': 260 | if self.annotations is None: 261 | self.annotations = {} 262 | 263 | self.annotations[key] = value 264 | 265 | return self 266 | 267 | @classmethod 268 | def load( 269 | cls, 270 | path, 271 | pred_prefix='daam', 272 | composite=False, 273 | simplify80=False, 274 | vocab=None, 275 | subtype='.', 276 | all_subtypes=False 277 | ): 278 | # type: (str, str, bool, bool, List[str] | None, str, bool) -> GenerationExperiment | List[GenerationExperiment] 279 | if all_subtypes: 280 | experiments = [] 281 | 282 | for directory in Path(path).iterdir(): 283 | if not directory.is_dir(): 284 | continue 285 | 286 | try: 287 | experiments.append(cls.load( 288 | path, 289 | pred_prefix=pred_prefix, 290 | composite=composite, 291 | simplify80=simplify80, 292 | vocab=vocab, 293 | subtype=directory.name 294 | )) 295 | except: 296 | pass 297 | 298 | return experiments 299 | 300 | path = Path(path) 301 | exp = torch.load(path / subtype / 'generation.pt') 302 | exp.subtype = subtype 303 | exp.path = path 304 | exp.truth_masks = exp._load_truth_masks(simplify80=simplify80) 305 | exp.prediction_masks = exp._load_pred_masks(pred_prefix, composite=composite, simplify80=simplify80, vocab=vocab) 306 | exp.annotations = exp._try_load_annotations() 307 | 308 | return exp 309 | -------------------------------------------------------------------------------- /scripts/daam/hook.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | from typing import List, Generic, TypeVar, Callable, Union, Any 3 | import functools 4 | import itertools 5 | from ldm.modules.attention import SpatialTransformer 6 | 7 | from ldm.modules.diffusionmodules.openaimodel import UNetModel 8 | from ldm.modules.attention import CrossAttention 9 | import torch.nn as nn 10 | 11 | 12 | __all__ = ['ObjectHooker', 'ModuleLocator', 'AggregateHooker', 'UNetCrossAttentionLocator'] 13 | 14 | 15 | ModuleType = TypeVar('ModuleType') 16 | ModuleListType = TypeVar('ModuleListType', bound=List) 17 | 18 | 19 | class ModuleLocator(Generic[ModuleType]): 20 | def locate(self, model: nn.Module) -> List[ModuleType]: 21 | raise NotImplementedError 22 | 23 | 24 | class ObjectHooker(Generic[ModuleType]): 25 | def __init__(self, module: ModuleType): 26 | self.module: ModuleType = module 27 | self.hooked = False 28 | self.old_state = dict() 29 | 30 | def __enter__(self): 31 | self.hook() 32 | return self 33 | 34 | def __exit__(self, exc_type, exc_val, exc_tb): 35 | self.unhook() 36 | 37 | def hook(self): 38 | if self.hooked: 39 | raise RuntimeError('Already hooked module') 40 | 41 | self.old_state = dict() 42 | self.hooked = True 43 | self._hook_impl() 44 | 45 | return self 46 | 47 | def unhook(self): 48 | if not self.hooked: 49 | raise RuntimeError('Module is not hooked') 50 | 51 | for k, v in self.old_state.items(): 52 | if k.startswith('old_fn_'): 53 | setattr(self.module, k[7:], v) 54 | 55 | self.hooked = False 56 | self._unhook_impl() 57 | 58 | return self 59 | 60 | def monkey_patch(self, fn_name, fn): 61 | self.old_state[f'old_fn_{fn_name}'] = getattr(self.module, fn_name) 62 | setattr(self.module, fn_name, functools.partial(fn, self.module)) 63 | 64 | def monkey_super(self, fn_name, *args, **kwargs): 65 | return self.old_state[f'old_fn_{fn_name}'](*args, **kwargs) 66 | 67 | def _hook_impl(self): 68 | raise NotImplementedError 69 | 70 | def _unhook_impl(self): 71 | pass 72 | 73 | 74 | class AggregateHooker(ObjectHooker[ModuleListType]): 75 | def _hook_impl(self): 76 | for h in self.module: 77 | h.hook() 78 | 79 | def _unhook_impl(self): 80 | for h in self.module: 81 | h.unhook() 82 | 83 | def register_hook(self, hook: ObjectHooker): 84 | self.module.append(hook) 85 | 86 | 87 | class UNetCrossAttentionLocator(ModuleLocator[CrossAttention]): 88 | def locate(self, model: UNetModel, layer_idx: int) -> List[CrossAttention]: 89 | """ 90 | Locate all cross-attention modules in a UNetModel. 91 | 92 | Args: 93 | model (`UNetModel`): The model to locate the cross-attention modules in. 94 | 95 | Returns: 96 | `List[CrossAttention]`: The list of cross-attention modules. 97 | """ 98 | blocks = [] 99 | 100 | for i, unet_block in enumerate(itertools.chain(model.input_blocks, [model.middle_block], model.output_blocks)): 101 | # if 'CrossAttn' in unet_block.__class__.__name__: 102 | if not layer_idx or i == layer_idx: 103 | for module in unet_block.modules(): 104 | if module.__class__.__name__ == "SpatialTransformer": 105 | spatial_transformer = module 106 | for basic_transformer_block in spatial_transformer.transformer_blocks: 107 | blocks.append(basic_transformer_block.attn2) 108 | 109 | 110 | return blocks 111 | -------------------------------------------------------------------------------- /scripts/daam/trace.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | from collections import defaultdict 3 | from copy import deepcopy 4 | from pathlib import Path 5 | from typing import List, Type, Any, Literal, Dict 6 | import math 7 | from modules.devices import device 8 | 9 | from ldm.models.diffusion.ddpm import DiffusionWrapper, LatentDiffusion 10 | from ldm.modules.diffusionmodules.openaimodel import UNetModel 11 | from ldm.modules.attention import CrossAttention, default, exists 12 | 13 | import numba 14 | import numpy as np 15 | import torch 16 | import torch.nn.functional as F 17 | from torch import nn, einsum 18 | from einops import rearrange, repeat 19 | 20 | from .experiment import COCO80_LABELS 21 | from .hook import ObjectHooker, AggregateHooker, UNetCrossAttentionLocator 22 | from .utils import compute_token_merge_indices, PromptAnalyzer 23 | 24 | 25 | __all__ = ['trace', 'DiffusionHeatMapHooker', 'HeatMap', 'MmDetectHeatMap'] 26 | 27 | 28 | class UNetForwardHooker(ObjectHooker[UNetModel]): 29 | def __init__(self, module: UNetModel, heat_maps: defaultdict(defaultdict)): 30 | super().__init__(module) 31 | self.all_heat_maps = [] 32 | self.heat_maps = heat_maps 33 | 34 | def _hook_impl(self): 35 | self.monkey_patch('forward', self._forward) 36 | 37 | def _unhook_impl(self): 38 | pass 39 | 40 | def _forward(hk_self, self, *args, **kwargs): 41 | super_return = hk_self.monkey_super('forward', *args, **kwargs) 42 | hk_self.all_heat_maps.append(deepcopy(hk_self.heat_maps)) 43 | hk_self.heat_maps.clear() 44 | 45 | return super_return 46 | 47 | 48 | class HeatMap: 49 | def __init__(self, prompt_analyzer: PromptAnalyzer, prompt: str, heat_maps: torch.Tensor): 50 | self.prompt_analyzer = prompt_analyzer.create(prompt) 51 | self.heat_maps = heat_maps 52 | self.prompt = prompt 53 | 54 | def compute_word_heat_map(self, word: str, word_idx: int = None) -> torch.Tensor: 55 | merge_idxs, _ = self.prompt_analyzer.calc_word_indecies(word) 56 | # print("merge_idxs", merge_idxs) 57 | if len(merge_idxs) == 0: 58 | return None 59 | 60 | return self.heat_maps[merge_idxs].mean(0) 61 | 62 | 63 | class MmDetectHeatMap: 64 | def __init__(self, pred_file: str | Path, threshold: float = 0.95): 65 | @numba.njit 66 | def _compute_mask(masks: np.ndarray, bboxes: np.ndarray): 67 | x_any = np.any(masks, axis=1) 68 | y_any = np.any(masks, axis=2) 69 | num_masks = len(bboxes) 70 | 71 | for idx in range(num_masks): 72 | x = np.where(x_any[idx, :])[0] 73 | y = np.where(y_any[idx, :])[0] 74 | bboxes[idx, :4] = np.array([x[0], y[0], x[-1] + 1, y[-1] + 1], dtype=np.float32) 75 | 76 | pred_file = Path(pred_file) 77 | self.word_masks: Dict[str, torch.Tensor] = defaultdict(lambda: 0) 78 | bbox_result, masks = torch.load(pred_file) 79 | labels = [np.full(bbox.shape[0], i, dtype=np.int32) for i, bbox in enumerate(bbox_result)] 80 | labels = np.concatenate(labels) 81 | bboxes = np.vstack(bbox_result) 82 | 83 | if masks is not None and bboxes[:, :4].sum() == 0: 84 | _compute_mask(masks, bboxes) 85 | scores = bboxes[:, -1] 86 | inds = scores > threshold 87 | labels = labels[inds] 88 | masks = masks[inds, ...] 89 | 90 | for lbl, mask in zip(labels, masks): 91 | self.word_masks[COCO80_LABELS[lbl]] |= torch.from_numpy(mask) 92 | 93 | self.word_masks = {k: v.float() for k, v in self.word_masks.items()} 94 | 95 | def compute_word_heat_map(self, word: str) -> torch.Tensor: 96 | return self.word_masks[word] 97 | 98 | 99 | class DiffusionHeatMapHooker(AggregateHooker): 100 | def __init__(self, model: LatentDiffusion, heigth : int, width : int, context_size : int = 77, weighted: bool = False, layer_idx: int = None, head_idx: int = None): 101 | heat_maps = defaultdict(lambda: defaultdict(list)) # batch index, factor, attention 102 | modules = [UNetCrossAttentionHooker(x, heigth, width, heat_maps, context_size=context_size, weighted=weighted, head_idx=head_idx) for x in UNetCrossAttentionLocator().locate(model.model.diffusion_model, layer_idx)] 103 | self.forward_hook = UNetForwardHooker(model.model.diffusion_model, heat_maps) 104 | modules.append(self.forward_hook) 105 | 106 | self.height = heigth 107 | self.width = width 108 | self.model = model 109 | self.last_prompt = '' 110 | 111 | super().__init__(modules) 112 | 113 | 114 | 115 | @property 116 | def all_heat_maps(self): 117 | return self.forward_hook.all_heat_maps 118 | 119 | def reset(self): 120 | map(lambda module: module.reset(), self.module) 121 | return self.forward_hook.all_heat_maps.clear() 122 | 123 | def compute_global_heat_map(self, prompt_analyzer, prompt, batch_index, time_weights=None, time_idx=None, last_n=None, first_n=None, factors=None): 124 | # type: (PromptAnalyzer, str, int, int, int, int, int, List[float]) -> HeatMap 125 | """ 126 | Compute the global heat map for the given prompt, aggregating across time (inference steps) and space (different 127 | spatial transformer block heat maps). 128 | 129 | Args: 130 | prompt: The prompt to compute the heat map for. 131 | time_weights: The weights to apply to each time step. If None, all time steps are weighted equally. 132 | time_idx: The time step to compute the heat map for. If None, the heat map is computed for all time steps. 133 | Mutually exclusive with `last_n` and `first_n`. 134 | last_n: The number of last n time steps to use. If None, the heat map is computed for all time steps. 135 | Mutually exclusive with `time_idx`. 136 | first_n: The number of first n time steps to use. If None, the heat map is computed for all time steps. 137 | Mutually exclusive with `time_idx`. 138 | factors: Restrict the application to heat maps with spatial factors in this set. If `None`, use all sizes. 139 | """ 140 | if len(self.forward_hook.all_heat_maps) == 0: 141 | return None 142 | 143 | if time_weights is None: 144 | time_weights = [1.0] * len(self.forward_hook.all_heat_maps) 145 | 146 | time_weights = np.array(time_weights) 147 | time_weights /= time_weights.sum() 148 | all_heat_maps = self.forward_hook.all_heat_maps 149 | 150 | if time_idx is not None: 151 | heat_maps = [all_heat_maps[time_idx]] 152 | else: 153 | heat_maps = all_heat_maps[-last_n:] if last_n is not None else all_heat_maps 154 | heat_maps = heat_maps[:first_n] if first_n is not None else heat_maps 155 | 156 | 157 | if factors is None: 158 | factors = {1, 2, 4, 8, 16, 32} 159 | else: 160 | factors = set(factors) 161 | 162 | all_merges = [] 163 | 164 | for batch_to_heat_maps in heat_maps: 165 | 166 | if not (batch_index in batch_to_heat_maps): 167 | continue 168 | 169 | merge_list = [] 170 | 171 | factors_to_heat_maps = batch_to_heat_maps[batch_index] 172 | 173 | for k, heat_map in factors_to_heat_maps.items(): 174 | # heat_map shape: (tokens, 1, height, width) 175 | # each v is a heat map tensor for a layer of factor size k across the tokens 176 | if k in factors: 177 | merge_list.append(torch.stack(heat_map, 0).mean(0)) 178 | 179 | if len(merge_list) > 0: 180 | all_merges.append(merge_list) 181 | 182 | maps = torch.stack([torch.stack(x, 0) for x in all_merges], dim=0) 183 | maps = maps.sum(0).to(device).sum(2).sum(0) 184 | 185 | return HeatMap(prompt_analyzer, prompt, maps) 186 | 187 | 188 | class UNetCrossAttentionHooker(ObjectHooker[CrossAttention]): 189 | def __init__(self, module: CrossAttention, img_height : int, img_width : int, heat_maps: defaultdict(defaultdict), context_size: int = 77, weighted: bool = False, head_idx: int = 0): 190 | super().__init__(module) 191 | self.heat_maps = heat_maps 192 | self.context_size = context_size 193 | self.weighted = weighted 194 | self.head_idx = head_idx 195 | self.img_height = img_height 196 | self.img_width = img_width 197 | self.calledCount = 0 198 | 199 | def reset(self): 200 | self.heat_maps.clear() 201 | self.calledCount = 0 202 | 203 | @torch.no_grad() 204 | def _up_sample_attn(self, x, value, factor, method='bicubic'): 205 | # type: (torch.Tensor, torch.Tensor, int, Literal['bicubic', 'conv']) -> torch.Tensor 206 | # x shape: (heads, height * width, tokens) 207 | """ 208 | Up samples the attention map in x using interpolation to the maximum size of (64, 64), as assumed in the Stable 209 | Diffusion model. 210 | 211 | Args: 212 | x (`torch.Tensor`): cross attention slice/map between the words and the tokens. 213 | value (`torch.Tensor`): the value tensor. 214 | method (`str`): the method to use; one of `'bicubic'` or `'conv'`. 215 | 216 | Returns: 217 | `torch.Tensor`: the up-sampled attention map of shape (tokens, 1, height, width). 218 | """ 219 | weight = torch.full((factor, factor), 1 / factor ** 2, device=x.device) 220 | weight = weight.view(1, 1, factor, factor) 221 | 222 | h = int(math.sqrt ( (self.img_height * x.size(1)) / self.img_width)) 223 | w = int(self.img_width * h / self.img_height) 224 | 225 | h_fix = w_fix = 64 226 | if h >= w: 227 | w_fix = int((w * h_fix) / h) 228 | else: 229 | h_fix = int((h * w_fix) / w) 230 | 231 | maps = [] 232 | x = x.permute(2, 0, 1) 233 | value = value.permute(1, 0, 2) 234 | weights = 1 235 | 236 | with torch.cuda.amp.autocast(dtype=torch.float32): 237 | for map_ in x: 238 | map_ = map_.unsqueeze(1).view(map_.size(0), 1, h, w) 239 | 240 | if method == 'bicubic': 241 | map_ = F.interpolate(map_, size=(h_fix, w_fix), mode='bicubic') 242 | maps.append(map_.squeeze(1)) 243 | else: 244 | maps.append(F.conv_transpose2d(map_, weight, stride=factor).squeeze(1)) 245 | 246 | if self.weighted: 247 | weights = value.norm(p=1, dim=-1, keepdim=True).unsqueeze(-1) 248 | 249 | maps = torch.stack(maps, 0) # shape: (tokens, heads, height, width) 250 | 251 | if self.head_idx: 252 | maps = maps[:, self.head_idx:self.head_idx+1, :, :] 253 | 254 | return (weights * maps).sum(1, keepdim=True).cpu() 255 | 256 | def _forward(hk_self, self, x, context=None, mask=None, additional_tokens=None): 257 | 258 | if additional_tokens is not None: 259 | # get the number of masked tokens at the beginning of the output sequence 260 | n_tokens_to_mask = additional_tokens.shape[1] 261 | # add additional token 262 | x = torch.cat([additional_tokens, x], dim=1) 263 | 264 | hk_self.calledCount += 1 265 | batch_size, sequence_length, _ = x.shape 266 | h = self.heads 267 | 268 | q = self.to_q(x) 269 | context = default(context, x) 270 | k = self.to_k(context) 271 | v = self.to_v(context) 272 | 273 | dim = q.shape[-1] 274 | 275 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v)) 276 | 277 | sim = einsum('b i d, b j d -> b i j', q, k) * self.scale 278 | 279 | if exists(mask): 280 | mask = rearrange(mask, 'b ... -> b (...)') 281 | max_neg_value = -torch.finfo(sim.dtype).max 282 | mask = repeat(mask, 'b j -> (b h) () j', h=h) 283 | sim.masked_fill_(~mask, max_neg_value) 284 | 285 | out = hk_self._hooked_attention(self, q, k, v, batch_size, sequence_length, dim) 286 | 287 | if additional_tokens is not None: 288 | # remove additional token 289 | out = out[:, n_tokens_to_mask:] 290 | 291 | return self.to_out(out) 292 | 293 | ### forward implemetation of diffuser CrossAttention 294 | # def forward(self, hidden_states, context=None, mask=None): 295 | # batch_size, sequence_length, _ = hidden_states.shape 296 | 297 | # query = self.to_q(hidden_states) 298 | # context = context if context is not None else hidden_states 299 | # key = self.to_k(context) 300 | # value = self.to_v(context) 301 | 302 | # dim = query.shape[-1] 303 | 304 | # query = self.reshape_heads_to_batch_dim(query) 305 | # key = self.reshape_heads_to_batch_dim(key) 306 | # value = self.reshape_heads_to_batch_dim(value) 307 | 308 | # # TODO(PVP) - mask is currently never used. Remember to re-implement when used 309 | 310 | # # attention, what we cannot get enough of 311 | # if self._use_memory_efficient_attention_xformers: 312 | # hidden_states = self._memory_efficient_attention_xformers(query, key, value) 313 | # # Some versions of xformers return output in fp32, cast it back to the dtype of the input 314 | # hidden_states = hidden_states.to(query.dtype) 315 | # else: 316 | # if self._slice_size is None or query.shape[0] // self._slice_size == 1: 317 | # hidden_states = self._attention(query, key, value) 318 | # else: 319 | # hidden_states = self._sliced_attention(query, key, value, sequence_length, dim) 320 | 321 | # # linear proj 322 | # hidden_states = self.to_out[0](hidden_states) 323 | # # dropout 324 | # hidden_states = self.to_out[1](hidden_states) 325 | # return hidden_states 326 | 327 | def _hooked_attention(hk_self, self, query, key, value, batch_size, sequence_length, dim, use_context: bool = True): 328 | """ 329 | Monkey-patched version of :py:func:`.CrossAttention._attention` to capture attentions and aggregate them. 330 | 331 | Args: 332 | hk_self (`UNetCrossAttentionHooker`): pointer to the hook itself. 333 | self (`CrossAttention`): pointer to the module. 334 | query (`torch.Tensor`): the query tensor. 335 | key (`torch.Tensor`): the key tensor. 336 | value (`torch.Tensor`): the value tensor. 337 | batch_size (`int`): the batch size 338 | use_context (`bool`): whether to check if the resulting attention slices are between the words and the image 339 | """ 340 | batch_size_attention = query.shape[0] 341 | hidden_states = torch.zeros( 342 | (batch_size_attention, sequence_length, dim // self.heads), device=query.device, dtype=query.dtype 343 | ) 344 | slice_size = hidden_states.shape[0] // batch_size # self._slice_size if self._slice_size is not None else hidden_states.shape[0] 345 | 346 | def calc_factor_base(w, h): 347 | z = max(w/64, h/64) 348 | factor_b = min(w, h) * z 349 | return factor_b 350 | 351 | factor_base = calc_factor_base(hk_self.img_width, hk_self.img_height) 352 | 353 | for batch_index in range(hidden_states.shape[0] // slice_size): 354 | start_idx = batch_index * slice_size 355 | end_idx = (batch_index + 1) * slice_size 356 | attn_slice = ( 357 | torch.einsum("b i d, b j d -> b i j", query[start_idx:end_idx], key[start_idx:end_idx]) * self.scale 358 | ) 359 | factor = int(math.sqrt(factor_base // attn_slice.shape[1])) 360 | attn_slice = attn_slice.softmax(-1) 361 | hid_states = torch.einsum("b i j, b j d -> b i d", attn_slice, value[start_idx:end_idx]) 362 | 363 | if use_context and hk_self.calledCount % 2 == 1 and attn_slice.shape[-1] == hk_self.context_size: 364 | if factor >= 1: 365 | factor //= 1 366 | maps = hk_self._up_sample_attn(attn_slice, value, factor) 367 | hk_self.heat_maps[batch_index][factor].append(maps) 368 | 369 | hidden_states[start_idx:end_idx] = hid_states 370 | 371 | # reshape hidden_states 372 | hidden_states = hk_self.reshape_batch_dim_to_heads(self, hidden_states) 373 | return hidden_states 374 | 375 | def reshape_batch_dim_to_heads(hk_self, self, tensor): 376 | batch_size, seq_len, dim = tensor.shape 377 | head_size = self.heads 378 | tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim) 379 | tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size) 380 | return tensor 381 | 382 | def _hook_impl(self): 383 | self.monkey_patch('forward', self._forward) 384 | 385 | @property 386 | def num_heat_maps(self): 387 | return len(next(iter(self.heat_maps.values()))) 388 | 389 | 390 | trace: Type[DiffusionHeatMapHooker] = DiffusionHeatMapHooker 391 | -------------------------------------------------------------------------------- /scripts/daam/utils.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | from itertools import chain 3 | from functools import lru_cache 4 | from pathlib import Path 5 | import random 6 | import re 7 | from typing import Union 8 | 9 | from PIL import Image, ImageFont, ImageDraw 10 | # from fonts.ttf import Roboto 11 | from modules.paths_internal import roboto_ttf_file 12 | import matplotlib.pyplot as plt 13 | from matplotlib import cm 14 | import numpy as np 15 | # import spacy 16 | import torch 17 | import torch.nn.functional as F 18 | from modules.devices import dtype 19 | 20 | from ldm.modules.encoders.modules import FrozenCLIPEmbedder, FrozenOpenCLIPEmbedder 21 | import open_clip.tokenizer 22 | from modules.sd_hijack_clip import FrozenCLIPEmbedderWithCustomWordsBase, FrozenCLIPEmbedderWithCustomWords 23 | from modules.sd_hijack_open_clip import FrozenOpenCLIPEmbedderWithCustomWords 24 | from modules.shared import opts 25 | from sgm.modules import GeneralConditioner 26 | 27 | __all__ = ['expand_image', 'set_seed', 'escape_prompt', 'calc_context_size', 'compute_token_merge_indices', 'compute_token_merge_indices_with_tokenizer', 'image_overlay_heat_map', 'plot_overlay_heat_map', 'plot_mask_heat_map', 'PromptAnalyzer'] 28 | 29 | def expand_image(im: torch.Tensor, h = 512, w = 512, absolute: bool = False, threshold: float = None) -> torch.Tensor: 30 | 31 | im = im.unsqueeze(0).unsqueeze(0) 32 | im = F.interpolate(im.float().detach(), size=(h, w), mode='bicubic') 33 | 34 | if not absolute: 35 | im = (im - im.min()) / (im.max() - im.min() + 1e-8) 36 | 37 | if threshold: 38 | im = (im > threshold).float() 39 | 40 | # im = im.cpu().detach() 41 | 42 | return im.squeeze() 43 | 44 | def _write_on_image(img, caption, font_size = 32): 45 | ix,iy = img.size 46 | draw = ImageDraw.Draw(img) 47 | margin=2 48 | fontsize=font_size 49 | draw = ImageDraw.Draw(img) 50 | font = ImageFont.truetype(roboto_ttf_file, fontsize) 51 | text_height=iy-60 52 | tx = draw.textbbox((0,0),caption,font) 53 | draw.text((int((ix-tx[2])/2),text_height+margin),caption,(0,0,0),font=font) 54 | draw.text((int((ix-tx[2])/2),text_height-margin),caption,(0,0,0),font=font) 55 | draw.text((int((ix-tx[2])/2+margin),text_height),caption,(0,0,0),font=font) 56 | draw.text((int((ix-tx[2])/2-margin),text_height),caption,(0,0,0),font=font) 57 | draw.text((int((ix-tx[2])/2),text_height), caption,(255,255,255),font=font) 58 | return img 59 | 60 | def image_overlay_heat_map(img, heat_map, word=None, out_file=None, crop=None, alpha=0.5, caption=None, image_scale=1.0): 61 | # type: (Image.Image | np.ndarray, torch.Tensor, str, Path, int, float, str, float) -> Image.Image 62 | assert(img is not None) 63 | 64 | if heat_map is not None: 65 | shape : torch.Size = heat_map.shape 66 | # heat_map = heat_map.unsqueeze(-1).expand(shape[0], shape[1], 3).clone() 67 | heat_map = _convert_heat_map_colors(heat_map) 68 | heat_map = heat_map.to('cpu').detach().numpy().copy().astype(np.uint8) 69 | heat_map_img = Image.fromarray(heat_map) 70 | 71 | img = Image.blend(img, heat_map_img, alpha) 72 | else: 73 | img = img.copy() 74 | 75 | if caption: 76 | img = _write_on_image(img, caption) 77 | 78 | if image_scale != 1.0: 79 | x, y = img.size 80 | size = (int(x * image_scale), int(y * image_scale)) 81 | img = img.resize(size, Image.BICUBIC) 82 | 83 | return img 84 | 85 | 86 | def _convert_heat_map_colors(heat_map : torch.Tensor): 87 | def get_color(value): 88 | return np.array(cm.turbo(value / 255)[0:3]) 89 | 90 | color_map = np.array([ get_color(i) * 255 for i in range(256) ]) 91 | color_map = torch.tensor(color_map, device=heat_map.device, dtype=dtype) 92 | 93 | heat_map = (heat_map * 255).long() 94 | 95 | return color_map[heat_map] 96 | 97 | def plot_overlay_heat_map(im, heat_map, word=None, out_file=None, crop=None): 98 | # type: (Image.Image | np.ndarray, torch.Tensor, str, Path, int) -> None 99 | plt.clf() 100 | plt.rcParams.update({'font.size': 24}) 101 | 102 | im = np.array(im) 103 | if crop is not None: 104 | heat_map = heat_map.squeeze()[crop:-crop, crop:-crop] 105 | im = im[crop:-crop, crop:-crop] 106 | 107 | plt.imshow(heat_map.squeeze().cpu().numpy(), cmap='jet') 108 | im = torch.from_numpy(im).float() / 255 109 | im = torch.cat((im, (1 - heat_map.unsqueeze(-1))), dim=-1) 110 | plt.imshow(im) 111 | 112 | if word is not None: 113 | plt.title(word) 114 | 115 | if out_file is not None: 116 | plt.savefig(out_file) 117 | 118 | 119 | def plot_mask_heat_map(im: Image.Image, heat_map: torch.Tensor, threshold: float = 0.4): 120 | im = torch.from_numpy(np.array(im)).float() / 255 121 | mask = (heat_map.squeeze() > threshold).float() 122 | im = im * mask.unsqueeze(-1) 123 | plt.imshow(im) 124 | 125 | 126 | def set_seed(seed: int) -> torch.Generator: 127 | random.seed(seed) 128 | np.random.seed(seed) 129 | torch.manual_seed(seed) 130 | torch.cuda.manual_seed_all(seed) 131 | 132 | gen = torch.Generator(device='cuda') 133 | gen.manual_seed(seed) 134 | 135 | return gen 136 | 137 | def calc_context_size(token_length : int): 138 | len_check = 0 if (token_length - 1) < 0 else token_length - 1 139 | return ((int)(len_check // 75) + 1) * 77 140 | 141 | def escape_prompt(prompt): 142 | if type(prompt) is str: 143 | prompt = prompt.lower() 144 | prompt = re.sub(r"[\(\)\[\]]", "", prompt) 145 | prompt = re.sub(r":\d+\.*\d*", "", prompt) 146 | return prompt 147 | elif type(prompt) is list: 148 | prompt_new = [] 149 | for i in range(len(prompt)): 150 | prompt_new.append(escape_prompt(prompt[i])) 151 | return prompt_new 152 | 153 | 154 | def compute_token_merge_indices(model, prompt: str, word: str, word_idx: int = None): 155 | 156 | clip = None 157 | tokenize = None 158 | if type(model.cond_stage_model.wrapped) == FrozenCLIPEmbedder: 159 | clip : FrozenCLIPEmbedder = model.cond_stage_model.wrapped 160 | tokenize = clip.tokenizer.tokenize 161 | elif type(model.cond_stage_model.wrapped) == FrozenOpenCLIPEmbedder: 162 | clip : FrozenOpenCLIPEmbedder = model.cond_stage_model.wrapped 163 | tokenize = open_clip.tokenizer._tokenizer.encode 164 | else: 165 | assert False 166 | 167 | escaped_prompt = escape_prompt(prompt) 168 | # escaped_prompt = re.sub(r"[_-]", " ", escaped_prompt) 169 | tokens : list = tokenize(escaped_prompt) 170 | word = word.lower() 171 | merge_idxs = [] 172 | 173 | needles = tokenize(word) 174 | 175 | if len(needles) == 0: 176 | return [] 177 | 178 | for i, token in enumerate(tokens): 179 | if needles[0] == token and len(needles) > 1: 180 | next = i + 1 181 | success = True 182 | for needle in needles[1:]: 183 | if next >= len(tokens) or needle != tokens[next]: 184 | success = False 185 | break 186 | next += 1 187 | 188 | # append consecutive indexes if all pass 189 | if success: 190 | merge_idxs.extend(list(range(i, next))) 191 | 192 | elif needles[0] == token: 193 | merge_idxs.append(i) 194 | 195 | idxs = [] 196 | for x in merge_idxs: 197 | seq = (int)(x / 75) 198 | if seq == 0: 199 | idxs.append(x + 1) # padding 200 | else: 201 | idxs.append(x + 1 + seq*2) # If tokens exceed 75, they are split. 202 | 203 | return idxs 204 | 205 | def compute_token_merge_indices_with_tokenizer(tokenizer, prompt: str, word: str, word_idx: int = None, limit : int = -1): 206 | 207 | escaped_prompt = escape_prompt(prompt) 208 | # escaped_prompt = re.sub(r"[_-]", " ", escaped_prompt) 209 | tokens : list = tokenizer.tokenize(escaped_prompt) 210 | word = word.lower() 211 | merge_idxs = [] 212 | 213 | needles = tokenizer.tokenize(word) 214 | 215 | if len(needles) == 0: 216 | return [] 217 | 218 | limit_count = 0 219 | for i, token in enumerate(tokens): 220 | if needles[0] == token and len(needles) > 1: 221 | next = i + 1 222 | success = True 223 | for needle in needles[1:]: 224 | if next >= len(tokens) or needle != tokens[next]: 225 | success = False 226 | break 227 | next += 1 228 | 229 | # append consecutive indexes if all pass 230 | if success: 231 | merge_idxs.extend(list(range(i, next))) 232 | if limit > 0: 233 | limit_count += 1 234 | if limit_count >= limit: 235 | break 236 | 237 | elif needles[0] == token: 238 | merge_idxs.append(i) 239 | if limit > 0: 240 | limit_count += 1 241 | if limit_count >= limit: 242 | break 243 | 244 | idxs = [] 245 | for x in merge_idxs: 246 | seq = (int)(x / 75) 247 | if seq == 0: 248 | idxs.append(x + 1) # padding 249 | else: 250 | idxs.append(x + 1 + seq*2) # If tokens exceed 75, they are split. 251 | 252 | return idxs 253 | 254 | nlp = None 255 | 256 | 257 | @lru_cache(maxsize=100000) 258 | def cached_nlp(prompt: str, type='en_core_web_md'): 259 | global nlp 260 | 261 | # if nlp is None: 262 | # nlp = spacy.load(type) 263 | 264 | return nlp(prompt) 265 | 266 | class PromptAnalyzer: 267 | def __init__(self, clip : Union[FrozenCLIPEmbedderWithCustomWordsBase, GeneralConditioner], text : str): 268 | use_old = opts.use_old_emphasis_implementation 269 | assert not use_old, "use_old_emphasis_implementation is not supported" 270 | 271 | self.clip = clip 272 | # self.id_start = clip.id_start 273 | # self.id_end = clip.id_end 274 | self.is_open_clip = True if type(clip) == FrozenOpenCLIPEmbedderWithCustomWords else False 275 | self.is_sdxl = True if type(clip) == GeneralConditioner else False 276 | self.used_custom_terms = [] 277 | self.hijack_comments = [] 278 | 279 | chunks, token_count = self.tokenize_line(text) 280 | 281 | self.token_count = token_count 282 | self.fixes = list(chain.from_iterable(chunk.fixes for chunk in chunks)) 283 | self.context_size = calc_context_size(token_count) 284 | 285 | tokens = list(chain.from_iterable(chunk.tokens for chunk in chunks)) 286 | multipliers = list(chain.from_iterable(chunk.multipliers for chunk in chunks)) 287 | print(tokens, multipliers) 288 | print(len(tokens), len(multipliers)) 289 | 290 | self.tokens = tokens # [] 291 | self.multipliers = multipliers # [] 292 | # for i in range(self.context_size // 77): 293 | # self.tokens.extend([self.id_start] + tokens[i*75:i*75+75] + [self.id_end]) 294 | # self.multipliers.extend([1.0] + multipliers[i*75:i*75+75]+ [1.0]) 295 | 296 | def create(self, text : str): 297 | return PromptAnalyzer(self.clip, text) 298 | 299 | def tokenize_line(self, line): 300 | chunks, token_count = self.clip.tokenize_line(line) 301 | return chunks, token_count 302 | 303 | def process_text(self, texts): 304 | batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = self.clip.process_text(texts) 305 | print(batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count) 306 | return batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count 307 | 308 | def encode(self, text : str): 309 | return self.clip.tokenize([text])[0] 310 | 311 | def calc_word_indecies(self, word : str, limit : int = -1, start_pos = 0): 312 | word = word.lower() 313 | merge_idxs = [] 314 | 315 | tokens = self.tokens 316 | needles = self.encode(word) 317 | 318 | limit_count = 0 319 | current_pos = 0 320 | for i, token in enumerate(tokens): 321 | current_pos = i 322 | if i < start_pos: 323 | continue 324 | 325 | if needles[0] == token and len(needles) > 1: 326 | next = i + 1 327 | success = True 328 | for needle in needles[1:]: 329 | if next >= len(tokens) or needle != tokens[next]: 330 | success = False 331 | break 332 | next += 1 333 | 334 | # append consecutive indexes if all pass 335 | if success: 336 | merge_idxs.extend(list(range(i, next))) 337 | if limit > 0: 338 | limit_count += 1 339 | if limit_count >= limit: 340 | break 341 | 342 | elif needles[0] == token: 343 | merge_idxs.append(i) 344 | if limit > 0: 345 | limit_count += 1 346 | if limit_count >= limit: 347 | break 348 | 349 | return merge_idxs, current_pos -------------------------------------------------------------------------------- /scripts/daam_script.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | from collections import defaultdict 3 | import os 4 | import re 5 | import traceback 6 | 7 | import gradio as gr 8 | import modules.images as images 9 | import modules.scripts as scripts 10 | import torch 11 | from ldm.modules.encoders.modules import FrozenCLIPEmbedder, FrozenOpenCLIPEmbedder 12 | import sgm 13 | from sgm.modules import GeneralConditioner 14 | import open_clip.tokenizer 15 | from modules import script_callbacks 16 | from modules import script_callbacks, sd_hijack_clip, sd_hijack_open_clip 17 | from modules.processing import (Processed, StableDiffusionProcessing, fix_seed, 18 | process_images) 19 | from modules.shared import cmd_opts, opts, state 20 | import modules.shared as shared 21 | from PIL import Image 22 | 23 | from scripts.daam import trace, utils 24 | 25 | before_image_saved_handler = None 26 | 27 | class Script(scripts.Script): 28 | 29 | GRID_LAYOUT_AUTO = "Auto" 30 | GRID_LAYOUT_PREVENT_EMPTY = "Prevent Empty Spot" 31 | GRID_LAYOUT_BATCH_LENGTH_AS_ROW = "Batch Length As Row" 32 | 33 | 34 | def title(self): 35 | return "Daam script" 36 | 37 | def show(self, is_img2img): 38 | return scripts.AlwaysVisible 39 | 40 | def ui(self, is_img2img): 41 | with gr.Group(): 42 | with gr.Accordion("Attention Heatmap", open=False): 43 | attention_texts = gr.Text(label='Attention texts for visualization. (comma separated)', value='') 44 | 45 | with gr.Row(): 46 | hide_images = gr.Checkbox(label='Hide heatmap images', value=False) 47 | 48 | dont_save_images = gr.Checkbox(label='Do not save heatmap images', value=False) 49 | 50 | hide_caption = gr.Checkbox(label='Hide caption', value=False) 51 | 52 | with gr.Row(): 53 | use_grid = gr.Checkbox(label='Use grid (output to grid dir)', value=False) 54 | 55 | grid_layouyt = gr.Dropdown( 56 | [Script.GRID_LAYOUT_AUTO, Script.GRID_LAYOUT_PREVENT_EMPTY, Script.GRID_LAYOUT_BATCH_LENGTH_AS_ROW], label="Grid layout", 57 | value=Script.GRID_LAYOUT_AUTO 58 | ) 59 | 60 | with gr.Row(): 61 | alpha = gr.Slider(label='Heatmap blend alpha', value=0.5, minimum=0, maximum=1, step=0.01) 62 | 63 | heatmap_image_scale = gr.Slider(label='Heatmap image scale', value=1.0, minimum=0.1, maximum=1, step=0.025) 64 | 65 | with gr.Row(): 66 | trace_each_layers = gr.Checkbox(label = 'Trace each layers', value=False) 67 | 68 | layers_as_row = gr.Checkbox(label = 'Use layers as row instead of Batch Length', value=False) 69 | 70 | 71 | self.tracers = None 72 | 73 | return [attention_texts, hide_images, dont_save_images, hide_caption, use_grid, grid_layouyt, alpha, heatmap_image_scale, trace_each_layers, layers_as_row] 74 | 75 | def process(self, 76 | p : StableDiffusionProcessing, 77 | attention_texts : str, 78 | hide_images : bool, 79 | dont_save_images : bool, 80 | hide_caption : bool, 81 | use_grid : bool, 82 | grid_layouyt :str, 83 | alpha : float, 84 | heatmap_image_scale : float, 85 | trace_each_layers : bool, 86 | layers_as_row: bool): 87 | 88 | self.enabled = False # in case the assert fails 89 | assert opts.samples_save, "Cannot run Daam script. Enable 'Always save all generated images' setting." 90 | 91 | self.images = [] 92 | self.hide_images = hide_images 93 | self.dont_save_images = dont_save_images 94 | self.hide_caption = hide_caption 95 | self.alpha = alpha 96 | self.use_grid = use_grid 97 | self.grid_layouyt = grid_layouyt 98 | self.heatmap_image_scale = heatmap_image_scale 99 | self.heatmap_images = dict() 100 | 101 | self.attentions = [s.strip() for s in attention_texts.split(",") if s.strip()] 102 | self.enabled = len(self.attentions) > 0 103 | 104 | fix_seed(p) 105 | 106 | def process_batch(self, 107 | p : StableDiffusionProcessing, 108 | attention_texts : str, 109 | hide_images : bool, 110 | dont_save_images : bool, 111 | hide_caption : bool, 112 | use_grid : bool, 113 | grid_layouyt :str, 114 | alpha : float, 115 | heatmap_image_scale : float, 116 | trace_each_layers : bool, 117 | layers_as_row: bool, 118 | prompts, 119 | **kwargs): 120 | 121 | if not self.enabled: 122 | return 123 | 124 | styled_prompt = prompts[0] 125 | 126 | embedder = None 127 | if type(p.sd_model.cond_stage_model) == sd_hijack_clip.FrozenCLIPEmbedderWithCustomWords or \ 128 | type(p.sd_model.cond_stage_model) == sd_hijack_open_clip.FrozenOpenCLIPEmbedderWithCustomWords: 129 | embedder = p.sd_model.cond_stage_model 130 | elif type(p.sd_model.cond_stage_model) == GeneralConditioner: 131 | conditioner = p.sd_model.cond_stage_model 132 | print(conditioner.embedders) 133 | embedder = conditioner.embedders[0] 134 | else: 135 | assert False, f"Embedder '{type(p.sd_model.cond_stage_model)}' is not supported." 136 | 137 | clip = None 138 | tokenize = None 139 | clip_type = type(embedder.wrapped) 140 | if clip_type == FrozenCLIPEmbedder: 141 | clip : FrozenCLIPEmbedder = embedder.wrapped 142 | tokenize = clip.tokenizer.tokenize 143 | elif clip_type == FrozenOpenCLIPEmbedder: 144 | clip : FrozenOpenCLIPEmbedder = embedder.wrapped 145 | tokenize = open_clip.tokenizer._tokenizer.encode 146 | elif clip_type == sgm.modules.encoders.modules.FrozenCLIPEmbedder: 147 | clip : sgm.modules.encoders.modules.FrozenCLIPEmbedder = embedder.wrapped 148 | tokenize = clip.tokenizer.tokenize 149 | else: 150 | assert False, f"CLIP '{clip_type}' is not supported." 151 | 152 | tokens = tokenize(utils.escape_prompt(styled_prompt)) 153 | context_size = utils.calc_context_size(len(tokens)) 154 | 155 | prompt_analyzer = utils.PromptAnalyzer(embedder, styled_prompt) 156 | self.prompt_analyzer = prompt_analyzer 157 | context_size = prompt_analyzer.context_size 158 | 159 | print(f"daam run with context_size={prompt_analyzer.context_size}, token_count={prompt_analyzer.token_count}") 160 | print(f"remade_tokens={prompt_analyzer.tokens}, multipliers={prompt_analyzer.multipliers}") 161 | print(f"hijack_comments={prompt_analyzer.hijack_comments}, used_custom_terms={prompt_analyzer.used_custom_terms}") 162 | print(f"fixes={prompt_analyzer.fixes}") 163 | 164 | if any(item[0] in self.attentions for item in self.prompt_analyzer.used_custom_terms): 165 | print("Embedding heatmap cannot be shown.") 166 | 167 | global before_image_saved_handler 168 | before_image_saved_handler = lambda params : self.before_image_saved(params) 169 | 170 | with torch.no_grad(): 171 | # cannot trace the same block from two tracers 172 | if trace_each_layers: 173 | num_input = len(p.sd_model.model.diffusion_model.input_blocks) 174 | num_output = len(p.sd_model.model.diffusion_model.output_blocks) 175 | self.tracers = [trace(p.sd_model, p.height, p.width, context_size, layer_idx=i) for i in range(num_input + num_output + 1)] 176 | self.attn_captions = [f"IN{i:02d}" for i in range(num_input)] + ["MID"] + [f"OUT{i:02d}" for i in range(num_output)] 177 | else: 178 | self.tracers = [trace(p.sd_model, p.height, p.width, context_size)] 179 | self.attn_captions = [""] 180 | 181 | for tracer in self.tracers: 182 | tracer.hook() 183 | 184 | def postprocess(self, p, processed, 185 | attention_texts : str, 186 | hide_images : bool, 187 | dont_save_images : bool, 188 | hide_caption : bool, 189 | use_grid : bool, 190 | grid_layouyt :str, 191 | alpha : float, 192 | heatmap_image_scale : float, 193 | trace_each_layers : bool, 194 | layers_as_row: bool, 195 | **kwargs): 196 | if self.enabled == False: 197 | return 198 | 199 | for trace in self.tracers: 200 | trace.unhook() 201 | self.tracers = None 202 | 203 | initial_info = None 204 | 205 | if initial_info is None: 206 | initial_info = processed.info 207 | 208 | self.images += processed.images 209 | 210 | global before_image_saved_handler 211 | before_image_saved_handler = None 212 | 213 | if layers_as_row: 214 | images_list = [] 215 | for i in range(p.batch_size * p.n_iter): 216 | imgs = [] 217 | for k in sorted(self.heatmap_images.keys()): 218 | imgs += [self.heatmap_images[k][len(self.attentions)*i + j] for j in range(len(self.attentions))] 219 | images_list.append(imgs) 220 | else: 221 | images_list = [self.heatmap_images[k] for k in sorted(self.heatmap_images.keys())] 222 | 223 | for img_list in images_list: 224 | 225 | if img_list and self.use_grid: 226 | 227 | grid_layout = self.grid_layouyt 228 | if grid_layout == Script.GRID_LAYOUT_AUTO: 229 | if p.batch_size * p.n_iter == 1: 230 | grid_layout = Script.GRID_LAYOUT_PREVENT_EMPTY 231 | else: 232 | grid_layout = Script.GRID_LAYOUT_BATCH_LENGTH_AS_ROW 233 | 234 | if grid_layout == Script.GRID_LAYOUT_PREVENT_EMPTY: 235 | grid_img = images.image_grid(img_list) 236 | elif grid_layout == Script.GRID_LAYOUT_BATCH_LENGTH_AS_ROW: 237 | if layers_as_row: 238 | batch_size = len(self.attentions) 239 | rows = len(self.heatmap_images) 240 | else: 241 | batch_size = p.batch_size 242 | rows = p.batch_size * p.n_iter 243 | grid_img = images.image_grid(img_list, batch_size=batch_size, rows=rows) 244 | else: 245 | pass 246 | 247 | if not self.dont_save_images: 248 | images.save_image(grid_img, p.outpath_grids, "grid_daam", grid=True, p=p) 249 | 250 | if not self.hide_images: 251 | processed.images.insert(0, grid_img) 252 | processed.index_of_first_image += 1 253 | processed.infotexts.insert(0, processed.infotexts[0]) 254 | 255 | else: 256 | if not self.hide_images: 257 | processed.images[:0] = img_list 258 | processed.index_of_first_image += len(img_list) 259 | processed.infotexts[:0] = [processed.infotexts[0]] * len(img_list) 260 | 261 | return processed 262 | 263 | def before_image_saved(self, params : script_callbacks.ImageSaveParams): 264 | batch_pos = -1 265 | if params.p.batch_size > 1: 266 | match = re.search(r"Batch pos: (\d+)", params.pnginfo['parameters']) 267 | if match: 268 | batch_pos = int(match.group(1)) 269 | else: 270 | batch_pos = 0 271 | 272 | if batch_pos < 0: 273 | return 274 | 275 | if self.tracers is not None and len(self.attentions) > 0: 276 | for i, tracer in enumerate(self.tracers): 277 | with torch.no_grad(): 278 | styled_prompot = shared.prompt_styles.apply_styles_to_prompt(params.p.prompt, params.p.styles) 279 | try: 280 | global_heat_map = tracer.compute_global_heat_map(self.prompt_analyzer, styled_prompot, batch_pos) 281 | except: 282 | continue 283 | 284 | if i not in self.heatmap_images: 285 | self.heatmap_images[i] = [] 286 | 287 | if global_heat_map is not None: 288 | heatmap_images = [] 289 | for attention in self.attentions: 290 | 291 | img_size = params.image.size 292 | caption = attention + (" " + self.attn_captions[i] if self.attn_captions[i] else "") if not self.hide_caption else None 293 | 294 | heat_map = global_heat_map.compute_word_heat_map(attention) 295 | if heat_map is None : print(f"No heatmaps for '{attention}'") 296 | 297 | heat_map_img = utils.expand_image(heat_map, img_size[1], img_size[0]) if heat_map is not None else None 298 | img : Image.Image = utils.image_overlay_heat_map(params.image, heat_map_img, alpha=self.alpha, caption=caption, image_scale=self.heatmap_image_scale) 299 | 300 | fullfn_without_extension, extension = os.path.splitext(params.filename) 301 | full_filename = fullfn_without_extension + "_" + attention + ("_" + self.attn_captions[i] if self.attn_captions[i] else "") + extension 302 | 303 | if self.use_grid: 304 | heatmap_images.append(img) 305 | else: 306 | heatmap_images.append(img) 307 | if not self.dont_save_images: 308 | img.save(full_filename) 309 | 310 | self.heatmap_images[i] += heatmap_images 311 | 312 | self.heatmap_images = {j:self.heatmap_images[j] for j in self.heatmap_images.keys() if self.heatmap_images[j]} 313 | 314 | # if it is last batch pos, clear heatmaps 315 | if batch_pos == params.p.batch_size - 1: 316 | for tracer in self.tracers: 317 | tracer.reset() 318 | 319 | return 320 | 321 | 322 | def handle_before_image_saved(params : script_callbacks.ImageSaveParams): 323 | 324 | if before_image_saved_handler is not None and callable(before_image_saved_handler): 325 | before_image_saved_handler(params) 326 | 327 | return 328 | 329 | script_callbacks.on_before_image_saved(handle_before_image_saved) 330 | --------------------------------------------------------------------------------