├── .gitignore ├── LICENSE ├── README.md ├── images ├── 00006-2623256163.png ├── 00006-2623256163_beach.png ├── 00006-2623256163_cat.png └── 00006-2623256163_sunglasses.png ├── install.py └── scripts ├── daam ├── __init__.py ├── evaluate.py ├── experiment.py ├── hook.py ├── trace.py └── utils.py └── daam_script.py /.gitignore: -------------------------------------------------------------------------------- 1 | .vscode/ 2 | .env 3 | 4 | scripts/daam/__pycache__/* 5 | __pycache__/* -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | The MIT License (MIT) 2 | 3 | Modified work Copyright (c) 2022 kousw 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | 23 | 24 | ### castorini/daam 25 | 26 | License: MIT 27 | By: Castorini 28 | Repository: 29 | 30 | >MIT License 31 | > 32 | >Copyright (c) 2022 Castorini 33 | > 34 | >Permission is hereby granted, free of charge, to any person obtaining a copy 35 | >of this software and associated documentation files (the "Software"), to deal 36 | >in the Software without restriction, including without limitation the rights 37 | >to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 38 | >copies of the Software, and to permit persons to whom the Software is 39 | >furnished to do so, subject to the following conditions: 40 | > 41 | >The above copyright notice and this permission notice shall be included in all 42 | >copies or substantial portions of the Software. 43 | > 44 | >THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 45 | >IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 46 | >FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 47 | >AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 48 | >LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 49 | >OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 50 | >SOFTWARE. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # DAAM Extension for Stable Diffusion Web UI 2 | 3 | This is a port of [DAAM](https://github.com/castorini/daam) for Stable Diffusion Web UI. 4 | 5 | # Setup and Running 6 | 7 | Clone this repository to extension folder. 
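For example, assuming the standard web UI layout, run `git clone <URL of this repository> extensions/stable-diffusion-webui-daam` from the web UI's root directory, then restart the web UI.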
8 | 9 | # How to use 10 | 11 | Open the "Attention Heatmap" accordion. Enter the 'attention text' (must be a string contained in the prompt) and run. 12 | An image overlaid with a heatmap for each attention word will be generated along with the original image. 13 | The images are saved to the default output directory (the 'Always save all generated images' setting must be enabled). 14 | 15 | The attention text is split on commas; multiple words without a comma are treated as a single sequence. 16 | If you enter "cat" as the attention text, all tokens matching "cat" are retrieved and their attention is combined. 17 | If you enter "cute cat", only tokens with "cute" and "cat" in sequence are retrieved, and only their attention is output. 18 | 19 | # Sample 20 | 21 | prompt: "A photo of a cute cat wearing sunglasses relaxing on a beach" 22 | 23 | attention text: "cat, sunglasses, beach" 24 | 25 | output images: original, cat, sunglasses, beach 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | # Notice 34 | At the moment, this works well with the Stable Diffusion 1.5 model. 35 | However, with the Stable Diffusion 2.0 model it seems to work somewhat less well. -------------------------------------------------------------------------------- /images/00006-2623256163.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/toriato/stable-diffusion-webui-daam/0906c850fb70d7e4b296f9449763d48fa8d1e687/images/00006-2623256163.png -------------------------------------------------------------------------------- /images/00006-2623256163_beach.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/toriato/stable-diffusion-webui-daam/0906c850fb70d7e4b296f9449763d48fa8d1e687/images/00006-2623256163_beach.png -------------------------------------------------------------------------------- /images/00006-2623256163_cat.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/toriato/stable-diffusion-webui-daam/0906c850fb70d7e4b296f9449763d48fa8d1e687/images/00006-2623256163_cat.png -------------------------------------------------------------------------------- /images/00006-2623256163_sunglasses.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/toriato/stable-diffusion-webui-daam/0906c850fb70d7e4b296f9449763d48fa8d1e687/images/00006-2623256163_sunglasses.png -------------------------------------------------------------------------------- /install.py: -------------------------------------------------------------------------------- 1 | import launch 2 | 3 | 4 | def check_matplotlib(): 5 | if not launch.is_installed("matplotlib"): 6 | return False 7 | 8 | try: 9 | import matplotlib 10 | except ImportError: 11 | return False 12 | 13 | if hasattr(matplotlib, "__version_info__"): 14 | version = matplotlib.__version_info__ 15 | version = (version.major, version.minor, version.micro) 16 | return version >= (3, 6, 2) 17 | return False 18 | 19 | 20 | if not check_matplotlib(): 21 | launch.run_pip("install matplotlib==3.6.2", desc="Installing matplotlib==3.6.2") 22 | -------------------------------------------------------------------------------- /scripts/daam/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | from .hook import * 3 | from .utils import * 4 | from .evaluate import * 5 | from .experiment import * 6 |
from .trace import * -------------------------------------------------------------------------------- /scripts/daam/evaluate.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | from collections import defaultdict 3 | from typing import List 4 | 5 | from scipy.optimize import linear_sum_assignment 6 | import PIL.Image as Image 7 | import numpy as np 8 | import torch 9 | import torch.nn.functional as F 10 | 11 | 12 | __all__ = ['compute_iou', 'MeanEvaluator'] 13 | 14 | 15 | def compute_iou(a: torch.Tensor, b: torch.Tensor) -> float: 16 | if a.shape[0] != b.shape[0]: 17 | a = F.interpolate(a.unsqueeze(0).unsqueeze(0).float(), size=b.shape, mode='bicubic').squeeze() 18 | a[a < 1] = 0 19 | a[a >= 1] = 1 20 | 21 | intersection = (a * b).sum() 22 | union = a.sum() + b.sum() - intersection 23 | 24 | return (intersection / (union + 1e-8)).item() 25 | 26 | 27 | def load_mask(path: str) -> torch.Tensor: 28 | mask = np.array(Image.open(path)) 29 | mask = torch.from_numpy(mask).float()[:, :, 3] # use alpha channel 30 | mask = (mask > 0).float() 31 | 32 | return mask 33 | 34 | 35 | class UnsupervisedEvaluator: 36 | def __init__(self, name: str = 'UnsupervisedEvaluator'): 37 | self.name = name 38 | self.ious = defaultdict(list) 39 | self.num_samples = 0 40 | 41 | def log_iou(self, preds: torch.Tensor | List[torch.Tensor], truth: torch.Tensor, gt_idx: int = 0, pred_idx: int = 0): 42 | if not isinstance(preds, list): 43 | preds = [preds] 44 | 45 | iou = max(compute_iou(pred, truth) for pred in preds) 46 | self.ious[gt_idx].append((pred_idx, iou)) 47 | 48 | @property 49 | def mean_iou(self) -> float: 50 | n = max(max(self.ious), max([y[0] for x in self.ious.values() for y in x])) + 1 51 | iou_matrix = np.zeros((n, n)) 52 | count_matrix = np.zeros((n, n)) 53 | 54 | for gt_idx, ious in self.ious.items(): 55 | for pred_idx, iou in ious: 56 | iou_matrix[gt_idx, pred_idx] += iou 57 | count_matrix[gt_idx, pred_idx] += 1 58 | 59 | row_ind, col_ind = linear_sum_assignment(iou_matrix, maximize=True) 60 | return iou_matrix[row_ind, col_ind].sum() / count_matrix[row_ind, col_ind].sum() 61 | 62 | def increment(self): 63 | self.num_samples += 1 64 | 65 | def __len__(self) -> int: 66 | return self.num_samples 67 | 68 | def __str__(self): 69 | return f'{self.name}<{self.mean_iou:.4f} (mIoU) {len(self)} samples>' 70 | 71 | 72 | class MeanEvaluator: 73 | def __init__(self, name: str = 'MeanEvaluator'): 74 | self.ious: List[float] = [] 75 | self.intensities: List[float] = [] 76 | self.name = name 77 | 78 | def log_iou(self, preds: torch.Tensor | List[torch.Tensor], truth: torch.Tensor): 79 | if not isinstance(preds, list): 80 | preds = [preds] 81 | 82 | self.ious.append(max(compute_iou(pred, truth) for pred in preds)) 83 | return self 84 | 85 | def log_intensity(self, pred: torch.Tensor): 86 | self.intensities.append(pred.mean().item()) 87 | return self 88 | 89 | @property 90 | def mean_iou(self) -> float: 91 | return np.mean(self.ious) 92 | 93 | @property 94 | def mean_intensity(self) -> float: 95 | return np.mean(self.intensities) 96 | 97 | @property 98 | def ci95_miou(self) -> float: 99 | return 1.96 * np.std(self.ious) / np.sqrt(len(self.ious)) 100 | 101 | def __len__(self) -> int: 102 | return max(len(self.ious), len(self.intensities)) 103 | 104 | def __str__(self): 105 | return f'{self.name}<{self.mean_iou:.4f} (±{self.ci95_miou:.3f} mIoU) {self.mean_intensity:.4f} (mInt) {len(self)} samples>' 106 | 107 | 108 | if __name__ == '__main__': 109 | 
mask = load_mask('truth/output/452/sink.gt.png') 110 | 111 | print(MeanEvaluator().log_iou(mask, mask)) 112 | -------------------------------------------------------------------------------- /scripts/daam/experiment.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | from pathlib import Path 3 | from typing import List, Optional, Dict, Any 4 | from dataclasses import dataclass 5 | import json 6 | 7 | from transformers import PreTrainedTokenizer, AutoTokenizer 8 | import PIL.Image 9 | import numpy as np 10 | import torch 11 | 12 | from .evaluate import load_mask 13 | from .utils import plot_overlay_heat_map, expand_image 14 | 15 | __all__ = ['GenerationExperiment', 'COCO80_LABELS', 'COCOSTUFF27_LABELS', 'COCO80_INDICES', 'build_word_list_coco80'] 16 | 17 | 18 | COCO80_LABELS: List[str] = [ 19 | 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck', 'boat', 'traffic light', 20 | 'fire hydrant', 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow', 21 | 'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee', 22 | 'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard', 23 | 'tennis racket', 'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple', 24 | 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch', 25 | 'potted plant', 'bed', 'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone', 26 | 'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase', 'scissors', 'teddy bear', 27 | 'hair drier', 'toothbrush' 28 | ] 29 | 30 | COCO80_INDICES: Dict[str, int] = {x: i for i, x in enumerate(COCO80_LABELS)} 31 | 32 | UNUSED_LABELS: List[str] = [f'__unused_{i}__' for i in range(1, 200)] 33 | 34 | COCOSTUFF27_LABELS: List[str] = [ 35 | 'electronic', 'appliance', 'food', 'furniture', 'indoor', 'kitchen', 'accessory', 'animal', 'outdoor', 'person', 36 | 'sports', 'vehicle', 'ceiling', 'floor', 'food', 'furniture', 'rawmaterial', 'textile', 'wall', 'window', 37 | 'building', 'ground', 'plant', 'sky', 'solid', 'structural', 'water' 38 | ] 39 | 40 | COCO80_ONTOLOGY = { 41 | 'two-wheeled vehicle': ['bicycle', 'motorcycle'], 42 | 'vehicle': ['two-wheeled vehicle', 'four-wheeled vehicle'], 43 | 'four-wheeled vehicle': ['bus', 'truck', 'car'], 44 | 'four-legged animals': ['livestock', 'pets', 'wild animals'], 45 | 'livestock': ['cow', 'horse', 'sheep'], 46 | 'pets': ['cat', 'dog'], 47 | 'wild animals': ['elephant', 'bear', 'zebra', 'giraffe'], 48 | 'bags': ['backpack', 'handbag', 'suitcase'], 49 | 'sports boards': ['snowboard', 'surfboard', 'skateboard'], 50 | 'utensils': ['fork', 'knife', 'spoon'], 51 | 'receptacles': ['bowl', 'cup'], 52 | 'fruits': ['banana', 'apple', 'orange'], 53 | 'foods': ['fruits', 'meals', 'desserts'], 54 | 'meals': ['sandwich', 'hot dog', 'pizza'], 55 | 'desserts': ['cake', 'donut'], 56 | 'furniture': ['chair', 'couch', 'bench'], 57 | 'electronics': ['monitors', 'appliances'], 58 | 'monitors': ['tv', 'cell phone', 'laptop'], 59 | 'appliances': ['oven', 'toaster', 'refrigerator'] 60 | } 61 | 62 | COCO80_TO_27 = { 63 | 'bicycle': 'vehicle', 'car': 'vehicle', 'motorcycle': 'vehicle', 'airplane': 'vehicle', 'bus': 'vehicle', 64 | 'train': 'vehicle', 'truck': 'vehicle', 'boat': 'vehicle', 'traffic light': 'accessory', 'fire hydrant': 
'accessory', 65 | 'stop sign': 'accessory', 'parking meter': 'accessory', 'bench': 'furniture', 'bird': 'animal', 'cat': 'animal', 66 | 'dog': 'animal', 'horse': 'animal', 'sheep': 'animal', 'cow': 'animal', 'elephant': 'animal', 'bear': 'animal', 67 | 'zebra': 'animal', 'giraffe': 'animal', 'backpack': 'accessory', 'umbrella': 'accessory', 'handbag': 'accessory', 68 | 'tie': 'accessory', 'suitcase': 'accessory', 'frisbee': 'sports', 'skis': 'sports', 'snowboard': 'sports', 69 | 'sports ball': 'sports', 'kite': 'sports', 'baseball bat': 'sports', 'baseball glove': 'sports', 70 | 'skateboard': 'sports', 'surfboard': 'sports', 'tennis racket': 'sports', 'bottle': 'food', 'wine glass': 'food', 71 | 'cup': 'food', 'fork': 'food', 'knife': 'food', 'spoon': 'food', 'bowl': 'food', 'banana': 'food', 'apple': 'food', 72 | 'sandwich': 'food', 'orange': 'food', 'broccoli': 'food', 'carrot': 'food', 'hot dog': 'food', 'pizza': 'food', 73 | 'donut': 'food', 'cake': 'food', 'chair': 'furniture', 'couch': 'furniture', 'potted plant': 'plant', 74 | 'bed': 'furniture', 'dining table': 'furniture', 'toilet': 'furniture', 'tv': 'electronic', 'laptop': 'electronic', 75 | 'mouse': 'electronic', 'remote': 'electronic', 'keyboard': 'electronic', 'cell phone': 'electronic', 76 | 'microwave': 'appliance', 'oven': 'appliance', 'toaster': 'appliance', 'sink': 'appliance', 77 | 'refrigerator': 'appliance', 'book': 'indoor', 'clock': 'indoor', 'vase': 'indoor', 'scissors': 'indoor', 78 | 'teddy bear': 'indoor', 'hair drier': 'indoor', 'toothbrush': 'indoor' 79 | } 80 | 81 | 82 | def build_word_list_coco80() -> Dict[str, List[str]]: 83 | words_map = COCO80_ONTOLOGY.copy() 84 | words_map = {k: v for k, v in words_map.items() if not any(item in COCO80_ONTOLOGY for item in v)} 85 | 86 | return words_map 87 | 88 | 89 | def _add_mask(masks: Dict[str, torch.Tensor], word: str, mask: torch.Tensor, simplify80: bool = False) -> Dict[str, torch.Tensor]: 90 | if simplify80: 91 | word = COCO80_TO_27.get(word, word) 92 | 93 | if word in masks: 94 | masks[word] = masks[word.lower()] + mask 95 | masks[word].clamp_(0, 1) 96 | else: 97 | masks[word] = mask 98 | 99 | return masks 100 | 101 | 102 | @dataclass 103 | class GenerationExperiment: 104 | """Class to hold experiment parameters. Pickleable.""" 105 | id: str 106 | image: PIL.Image.Image 107 | global_heat_map: torch.Tensor 108 | seed: int 109 | prompt: str 110 | 111 | path: Optional[Path] = None 112 | truth_masks: Optional[Dict[str, torch.Tensor]] = None 113 | prediction_masks: Optional[Dict[str, torch.Tensor]] = None 114 | annotations: Optional[Dict[str, Any]] = None 115 | subtype: Optional[str] = '.' 
116 | 117 | def nsfw(self) -> bool: 118 | return np.sum(np.array(self.image)) == 0 119 | 120 | def heat_map(self, tokenizer: AutoTokenizer): 121 | from .trace import HeatMap 122 | return HeatMap(tokenizer, self.prompt, self.global_heat_map) 123 | 124 | def save(self, path: str = None): 125 | if path is None: 126 | path = self.path 127 | else: 128 | path = Path(path) / self.id 129 | 130 | (path / self.subtype).mkdir(parents=True, exist_ok=True) 131 | torch.save(self, path / self.subtype / 'generation.pt') 132 | self.image.save(path / self.subtype / 'output.png') 133 | 134 | with (path / 'prompt.txt').open('w') as f: 135 | f.write(self.prompt) 136 | 137 | with (path / 'seed.txt').open('w') as f: 138 | f.write(str(self.seed)) 139 | 140 | if self.truth_masks is not None: 141 | for name, mask in self.truth_masks.items(): 142 | im = PIL.Image.fromarray((mask * 255).unsqueeze(-1).expand(-1, -1, 4).byte().numpy()) 143 | im.save(path / f'{name.lower()}.gt.png') 144 | 145 | self.save_annotations() 146 | 147 | def save_annotations(self, path: Path = None): 148 | if path is None: 149 | path = self.path 150 | 151 | if self.annotations is not None: 152 | with (path / 'annotations.json').open('w') as f: 153 | json.dump(self.annotations, f) 154 | 155 | def _load_truth_masks(self, simplify80: bool = False) -> Dict[str, torch.Tensor]: 156 | masks = {} 157 | 158 | for mask_path in self.path.glob('*.gt.png'): 159 | word = mask_path.name.split('.gt.png')[0].lower() 160 | mask = load_mask(str(mask_path)) 161 | _add_mask(masks, word, mask, simplify80) 162 | 163 | return masks 164 | 165 | def _load_pred_masks(self, pred_prefix, composite=False, simplify80=False, vocab=None): 166 | # type: (str, bool, bool, List[str] | None) -> Dict[str, torch.Tensor] 167 | masks = {} 168 | 169 | if vocab is None: 170 | vocab = UNUSED_LABELS 171 | 172 | if composite: 173 | try: 174 | im = PIL.Image.open(self.path / self.subtype / f'composite.{pred_prefix}.pred.png') 175 | im = np.array(im) 176 | 177 | for mask_idx in np.unique(im): 178 | mask = torch.from_numpy((im == mask_idx).astype(np.float32)) 179 | _add_mask(masks, vocab[mask_idx], mask, simplify80) 180 | except FileNotFoundError: 181 | pass 182 | else: 183 | for mask_path in (self.path / self.subtype).glob(f'*.{pred_prefix}.pred.png'): 184 | mask = load_mask(str(mask_path)) 185 | word = mask_path.name.split(f'.{pred_prefix}.pred')[0].lower() 186 | _add_mask(masks, word, mask, simplify80) 187 | 188 | return masks 189 | 190 | def clear_prediction_masks(self, name: str): 191 | path = self if isinstance(self, Path) else self.path 192 | path = path / self.subtype 193 | 194 | for mask_path in path.glob(f'*.{name}.pred.png'): 195 | mask_path.unlink() 196 | 197 | def save_prediction_mask(self, mask: torch.Tensor, word: str, name: str): 198 | path = self if isinstance(self, Path) else self.path 199 | im = PIL.Image.fromarray((mask * 255).unsqueeze(-1).expand(-1, -1, 4).cpu().byte().numpy()) 200 | im.save(path / self.subtype / f'{word.lower()}.{name}.pred.png') 201 | 202 | def save_heat_map(self, tokenizer: PreTrainedTokenizer, word: str, crop: int = None) -> Path: 203 | from .trace import HeatMap # because of cyclical import 204 | heat_map = HeatMap(tokenizer, self.prompt, self.global_heat_map) 205 | heat_map = expand_image(heat_map.compute_word_heat_map(word)) 206 | path = self.path / self.subtype / f'{word.lower()}.heat_map.png' 207 | plot_overlay_heat_map(self.image, heat_map, word, path, crop=crop) 208 | 209 | return path 210 | 211 | def save_all_heat_maps(self, tokenizer: 
PreTrainedTokenizer, crop: int = None) -> Dict[str, Path]: 212 | path_map = {} 213 | 214 | for word in self.prompt.split(' '): 215 | try: 216 | path = self.save_heat_map(tokenizer, word, crop=crop) 217 | path_map[word] = path 218 | except: 219 | pass 220 | 221 | return path_map 222 | 223 | @staticmethod 224 | def contains_truth_mask(path: str | Path, prompt_id: str = None) -> bool: 225 | if prompt_id is None: 226 | return any(Path(path).glob('*.gt.png')) 227 | else: 228 | return any((Path(path) / prompt_id).glob('*.gt.png')) 229 | 230 | @staticmethod 231 | def read_seed(path: str | Path, prompt_id: str = None) -> int: 232 | if prompt_id is None: 233 | return int(Path(path).joinpath('seed.txt').read_text()) 234 | else: 235 | return int(Path(path).joinpath(prompt_id).joinpath('seed.txt').read_text()) 236 | 237 | @staticmethod 238 | def has_annotations(path: str | Path) -> bool: 239 | return Path(path).joinpath('annotations.json').exists() 240 | 241 | @staticmethod 242 | def has_experiment(path: str | Path, prompt_id: str) -> bool: 243 | return (Path(path) / prompt_id / 'generation.pt').exists() 244 | 245 | @staticmethod 246 | def read_prompt(path: str | Path, prompt_id: str = None) -> str: 247 | if prompt_id is None: 248 | prompt_id = '.' 249 | 250 | with (Path(path) / prompt_id / 'prompt.txt').open('r') as f: 251 | return f.read().strip() 252 | 253 | def _try_load_annotations(self): 254 | if not (self.path / 'annotations.json').exists(): 255 | return None 256 | 257 | return json.load((self.path / 'annotations.json').open()) 258 | 259 | def annotate(self, key: str, value: Any) -> 'GenerationExperiment': 260 | if self.annotations is None: 261 | self.annotations = {} 262 | 263 | self.annotations[key] = value 264 | 265 | return self 266 | 267 | @classmethod 268 | def load( 269 | cls, 270 | path, 271 | pred_prefix='daam', 272 | composite=False, 273 | simplify80=False, 274 | vocab=None, 275 | subtype='.', 276 | all_subtypes=False 277 | ): 278 | # type: (str, str, bool, bool, List[str] | None, str, bool) -> GenerationExperiment | List[GenerationExperiment] 279 | if all_subtypes: 280 | experiments = [] 281 | 282 | for directory in Path(path).iterdir(): 283 | if not directory.is_dir(): 284 | continue 285 | 286 | try: 287 | experiments.append(cls.load( 288 | path, 289 | pred_prefix=pred_prefix, 290 | composite=composite, 291 | simplify80=simplify80, 292 | vocab=vocab, 293 | subtype=directory.name 294 | )) 295 | except: 296 | pass 297 | 298 | return experiments 299 | 300 | path = Path(path) 301 | exp = torch.load(path / subtype / 'generation.pt') 302 | exp.subtype = subtype 303 | exp.path = path 304 | exp.truth_masks = exp._load_truth_masks(simplify80=simplify80) 305 | exp.prediction_masks = exp._load_pred_masks(pred_prefix, composite=composite, simplify80=simplify80, vocab=vocab) 306 | exp.annotations = exp._try_load_annotations() 307 | 308 | return exp 309 | -------------------------------------------------------------------------------- /scripts/daam/hook.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | from typing import List, Generic, TypeVar, Callable, Union, Any 3 | import functools 4 | import itertools 5 | from ldm.modules.attention import SpatialTransformer 6 | 7 | from ldm.modules.diffusionmodules.openaimodel import UNetModel 8 | from ldm.modules.attention import CrossAttention 9 | import torch.nn as nn 10 | 11 | 12 | __all__ = ['ObjectHooker', 'ModuleLocator', 'AggregateHooker', 'UNetCrossAttentionLocator'] 13 | 
14 | 15 | ModuleType = TypeVar('ModuleType') 16 | ModuleListType = TypeVar('ModuleListType', bound=List) 17 | 18 | 19 | class ModuleLocator(Generic[ModuleType]): 20 | def locate(self, model: nn.Module) -> List[ModuleType]: 21 | raise NotImplementedError 22 | 23 | 24 | class ObjectHooker(Generic[ModuleType]): 25 | def __init__(self, module: ModuleType): 26 | self.module: ModuleType = module 27 | self.hooked = False 28 | self.old_state = dict() 29 | 30 | def __enter__(self): 31 | self.hook() 32 | return self 33 | 34 | def __exit__(self, exc_type, exc_val, exc_tb): 35 | self.unhook() 36 | 37 | def hook(self): 38 | if self.hooked: 39 | raise RuntimeError('Already hooked module') 40 | 41 | self.old_state = dict() 42 | self.hooked = True 43 | self._hook_impl() 44 | 45 | return self 46 | 47 | def unhook(self): 48 | if not self.hooked: 49 | raise RuntimeError('Module is not hooked') 50 | 51 | for k, v in self.old_state.items(): 52 | if k.startswith('old_fn_'): 53 | setattr(self.module, k[7:], v) 54 | 55 | self.hooked = False 56 | self._unhook_impl() 57 | 58 | return self 59 | 60 | def monkey_patch(self, fn_name, fn): 61 | self.old_state[f'old_fn_{fn_name}'] = getattr(self.module, fn_name) 62 | setattr(self.module, fn_name, functools.partial(fn, self.module)) 63 | 64 | def monkey_super(self, fn_name, *args, **kwargs): 65 | return self.old_state[f'old_fn_{fn_name}'](*args, **kwargs) 66 | 67 | def _hook_impl(self): 68 | raise NotImplementedError 69 | 70 | def _unhook_impl(self): 71 | pass 72 | 73 | 74 | class AggregateHooker(ObjectHooker[ModuleListType]): 75 | def _hook_impl(self): 76 | for h in self.module: 77 | h.hook() 78 | 79 | def _unhook_impl(self): 80 | for h in self.module: 81 | h.unhook() 82 | 83 | def register_hook(self, hook: ObjectHooker): 84 | self.module.append(hook) 85 | 86 | 87 | class UNetCrossAttentionLocator(ModuleLocator[CrossAttention]): 88 | def locate(self, model: UNetModel, layer_idx: int) -> List[CrossAttention]: 89 | """ 90 | Locate all cross-attention modules in a UNetModel. 91 | 92 | Args: 93 | model (`UNetModel`): The model to locate the cross-attention modules in. 94 | 95 | Returns: 96 | `List[CrossAttention]`: The list of cross-attention modules. 
97 | """ 98 | blocks = [] 99 | 100 | for i, unet_block in enumerate(itertools.chain(model.input_blocks, [model.middle_block], model.output_blocks)): 101 | # if 'CrossAttn' in unet_block.__class__.__name__: 102 | if not layer_idx or i == layer_idx: 103 | for module in unet_block.modules(): 104 | if type(module) is SpatialTransformer: 105 | spatial_transformer = module 106 | for basic_transformer_block in spatial_transformer.transformer_blocks: 107 | blocks.append(basic_transformer_block.attn2) 108 | 109 | return blocks 110 | -------------------------------------------------------------------------------- /scripts/daam/trace.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | from collections import defaultdict 3 | from copy import deepcopy 4 | from pathlib import Path 5 | from typing import List, Type, Any, Literal, Dict 6 | import math 7 | from modules.devices import device 8 | 9 | from ldm.models.diffusion.ddpm import DiffusionWrapper, LatentDiffusion 10 | from ldm.modules.diffusionmodules.openaimodel import UNetModel 11 | from ldm.modules.attention import CrossAttention, default, exists 12 | 13 | import numba 14 | import numpy as np 15 | import torch 16 | import torch.nn.functional as F 17 | from torch import nn, einsum 18 | from einops import rearrange, repeat 19 | 20 | from .experiment import COCO80_LABELS 21 | from .hook import ObjectHooker, AggregateHooker, UNetCrossAttentionLocator 22 | from .utils import compute_token_merge_indices, PromptAnalyzer 23 | 24 | 25 | __all__ = ['trace', 'DiffusionHeatMapHooker', 'HeatMap', 'MmDetectHeatMap'] 26 | 27 | 28 | class UNetForwardHooker(ObjectHooker[UNetModel]): 29 | def __init__(self, module: UNetModel, heat_maps: defaultdict(defaultdict)): 30 | super().__init__(module) 31 | self.all_heat_maps = [] 32 | self.heat_maps = heat_maps 33 | 34 | def _hook_impl(self): 35 | self.monkey_patch('forward', self._forward) 36 | 37 | def _unhook_impl(self): 38 | pass 39 | 40 | def _forward(hk_self, self, *args, **kwargs): 41 | super_return = hk_self.monkey_super('forward', *args, **kwargs) 42 | hk_self.all_heat_maps.append(deepcopy(hk_self.heat_maps)) 43 | hk_self.heat_maps.clear() 44 | 45 | return super_return 46 | 47 | 48 | class HeatMap: 49 | def __init__(self, prompt_analyzer: PromptAnalyzer, prompt: str, heat_maps: torch.Tensor): 50 | self.prompt_analyzer = prompt_analyzer.create(prompt) 51 | self.heat_maps = heat_maps 52 | self.prompt = prompt 53 | 54 | def compute_word_heat_map(self, word: str, word_idx: int = None) -> torch.Tensor: 55 | merge_idxs, _ = self.prompt_analyzer.calc_word_indecies(word) 56 | if len(merge_idxs) == 0: 57 | return None 58 | 59 | return self.heat_maps[merge_idxs].mean(0) 60 | 61 | 62 | class MmDetectHeatMap: 63 | def __init__(self, pred_file: str | Path, threshold: float = 0.95): 64 | @numba.njit 65 | def _compute_mask(masks: np.ndarray, bboxes: np.ndarray): 66 | x_any = np.any(masks, axis=1) 67 | y_any = np.any(masks, axis=2) 68 | num_masks = len(bboxes) 69 | 70 | for idx in range(num_masks): 71 | x = np.where(x_any[idx, :])[0] 72 | y = np.where(y_any[idx, :])[0] 73 | bboxes[idx, :4] = np.array([x[0], y[0], x[-1] + 1, y[-1] + 1], dtype=np.float32) 74 | 75 | pred_file = Path(pred_file) 76 | self.word_masks: Dict[str, torch.Tensor] = defaultdict(lambda: 0) 77 | bbox_result, masks = torch.load(pred_file) 78 | labels = [np.full(bbox.shape[0], i, dtype=np.int32) for i, bbox in enumerate(bbox_result)] 79 | labels = np.concatenate(labels) 80 | bboxes = 
np.vstack(bbox_result) 81 | 82 | if masks is not None and bboxes[:, :4].sum() == 0: 83 | _compute_mask(masks, bboxes) 84 | scores = bboxes[:, -1] 85 | inds = scores > threshold 86 | labels = labels[inds] 87 | masks = masks[inds, ...] 88 | 89 | for lbl, mask in zip(labels, masks): 90 | self.word_masks[COCO80_LABELS[lbl]] |= torch.from_numpy(mask) 91 | 92 | self.word_masks = {k: v.float() for k, v in self.word_masks.items()} 93 | 94 | def compute_word_heat_map(self, word: str) -> torch.Tensor: 95 | return self.word_masks[word] 96 | 97 | 98 | class DiffusionHeatMapHooker(AggregateHooker): 99 | def __init__(self, model: LatentDiffusion, heigth : int, width : int, context_size : int = 77, weighted: bool = False, layer_idx: int = None, head_idx: int = None): 100 | heat_maps = defaultdict(lambda: defaultdict(list)) # batch index, factor, attention 101 | modules = [UNetCrossAttentionHooker(x, heigth, width, heat_maps, context_size=context_size, weighted=weighted, head_idx=head_idx) for x in UNetCrossAttentionLocator().locate(model.model.diffusion_model, layer_idx)] 102 | self.forward_hook = UNetForwardHooker(model.model.diffusion_model, heat_maps) 103 | modules.append(self.forward_hook) 104 | 105 | self.height = heigth 106 | self.width = width 107 | self.model = model 108 | self.last_prompt = '' 109 | 110 | super().__init__(modules) 111 | 112 | 113 | 114 | @property 115 | def all_heat_maps(self): 116 | return self.forward_hook.all_heat_maps 117 | 118 | def reset(self): 119 | map(lambda module: module.reset(), self.module) 120 | return self.forward_hook.all_heat_maps.clear() 121 | 122 | def compute_global_heat_map(self, prompt_analyzer, prompt, batch_index, time_weights=None, time_idx=None, last_n=None, first_n=None, factors=None): 123 | # type: (PromptAnalyzer, str, int, int, int, int, int, List[float]) -> HeatMap 124 | """ 125 | Compute the global heat map for the given prompt, aggregating across time (inference steps) and space (different 126 | spatial transformer block heat maps). 127 | 128 | Args: 129 | prompt: The prompt to compute the heat map for. 130 | time_weights: The weights to apply to each time step. If None, all time steps are weighted equally. 131 | time_idx: The time step to compute the heat map for. If None, the heat map is computed for all time steps. 132 | Mutually exclusive with `last_n` and `first_n`. 133 | last_n: The number of last n time steps to use. If None, the heat map is computed for all time steps. 134 | Mutually exclusive with `time_idx`. 135 | first_n: The number of first n time steps to use. If None, the heat map is computed for all time steps. 136 | Mutually exclusive with `time_idx`. 137 | factors: Restrict the application to heat maps with spatial factors in this set. If `None`, use all sizes. 
138 | """ 139 | if len(self.forward_hook.all_heat_maps) == 0: 140 | return None 141 | 142 | if time_weights is None: 143 | time_weights = [1.0] * len(self.forward_hook.all_heat_maps) 144 | 145 | time_weights = np.array(time_weights) 146 | time_weights /= time_weights.sum() 147 | all_heat_maps = self.forward_hook.all_heat_maps 148 | 149 | if time_idx is not None: 150 | heat_maps = [all_heat_maps[time_idx]] 151 | else: 152 | heat_maps = all_heat_maps[-last_n:] if last_n is not None else all_heat_maps 153 | heat_maps = heat_maps[:first_n] if first_n is not None else heat_maps 154 | 155 | 156 | if factors is None: 157 | factors = {1, 2, 4, 8, 16, 32} 158 | else: 159 | factors = set(factors) 160 | 161 | all_merges = [] 162 | 163 | for batch_to_heat_maps in heat_maps: 164 | 165 | if not (batch_index in batch_to_heat_maps): 166 | continue 167 | 168 | merge_list = [] 169 | 170 | factors_to_heat_maps = batch_to_heat_maps[batch_index] 171 | 172 | for k, heat_map in factors_to_heat_maps.items(): 173 | # heat_map shape: (tokens, 1, height, width) 174 | # each v is a heat map tensor for a layer of factor size k across the tokens 175 | if k in factors: 176 | merge_list.append(torch.stack(heat_map, 0).mean(0)) 177 | 178 | if len(merge_list) > 0: 179 | all_merges.append(merge_list) 180 | 181 | maps = torch.stack([torch.stack(x, 0) for x in all_merges], dim=0) 182 | maps = maps.sum(0).to(device).sum(2).sum(0) 183 | 184 | return HeatMap(prompt_analyzer, prompt, maps) 185 | 186 | 187 | class UNetCrossAttentionHooker(ObjectHooker[CrossAttention]): 188 | def __init__(self, module: CrossAttention, img_height : int, img_width : int, heat_maps: defaultdict(defaultdict), context_size: int = 77, weighted: bool = False, head_idx: int = 0): 189 | super().__init__(module) 190 | self.heat_maps = heat_maps 191 | self.context_size = context_size 192 | self.weighted = weighted 193 | self.head_idx = head_idx 194 | self.img_height = img_height 195 | self.img_width = img_width 196 | self.calledCount = 0 197 | 198 | def reset(self): 199 | self.heat_maps.clear() 200 | self.calledCount = 0 201 | 202 | @torch.no_grad() 203 | def _up_sample_attn(self, x, value, factor, method='bicubic'): 204 | # type: (torch.Tensor, torch.Tensor, int, Literal['bicubic', 'conv']) -> torch.Tensor 205 | # x shape: (heads, height * width, tokens) 206 | """ 207 | Up samples the attention map in x using interpolation to the maximum size of (64, 64), as assumed in the Stable 208 | Diffusion model. 209 | 210 | Args: 211 | x (`torch.Tensor`): cross attention slice/map between the words and the tokens. 212 | value (`torch.Tensor`): the value tensor. 213 | method (`str`): the method to use; one of `'bicubic'` or `'conv'`. 214 | 215 | Returns: 216 | `torch.Tensor`: the up-sampled attention map of shape (tokens, 1, height, width). 
217 | """ 218 | weight = torch.full((factor, factor), 1 / factor ** 2, device=x.device) 219 | weight = weight.view(1, 1, factor, factor) 220 | 221 | h = int(math.sqrt ( (self.img_height * x.size(1)) / self.img_width)) 222 | w = int(self.img_width * h / self.img_height) 223 | 224 | h_fix = w_fix = 64 225 | if h >= w: 226 | w_fix = int((w * h_fix) / h) 227 | else: 228 | h_fix = int((h * w_fix) / w) 229 | 230 | maps = [] 231 | x = x.permute(2, 0, 1) 232 | value = value.permute(1, 0, 2) 233 | weights = 1 234 | 235 | with torch.cuda.amp.autocast(dtype=torch.float32): 236 | for map_ in x: 237 | map_ = map_.unsqueeze(1).view(map_.size(0), 1, h, w) 238 | 239 | if method == 'bicubic': 240 | map_ = F.interpolate(map_, size=(h_fix, w_fix), mode='bicubic') 241 | maps.append(map_.squeeze(1)) 242 | else: 243 | maps.append(F.conv_transpose2d(map_, weight, stride=factor).squeeze(1)) 244 | 245 | if self.weighted: 246 | weights = value.norm(p=1, dim=-1, keepdim=True).unsqueeze(-1) 247 | 248 | maps = torch.stack(maps, 0) # shape: (tokens, heads, height, width) 249 | 250 | if self.head_idx: 251 | maps = maps[:, self.head_idx:self.head_idx+1, :, :] 252 | 253 | return (weights * maps).sum(1, keepdim=True).cpu() 254 | 255 | def _forward(hk_self, self, x, context=None, mask=None): 256 | hk_self.calledCount += 1 257 | batch_size, sequence_length, _ = x.shape 258 | h = self.heads 259 | 260 | q = self.to_q(x) 261 | context = default(context, x) 262 | k = self.to_k(context) 263 | v = self.to_v(context) 264 | 265 | dim = q.shape[-1] 266 | 267 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v)) 268 | 269 | sim = einsum('b i d, b j d -> b i j', q, k) * self.scale 270 | 271 | if exists(mask): 272 | mask = rearrange(mask, 'b ... -> b (...)') 273 | max_neg_value = -torch.finfo(sim.dtype).max 274 | mask = repeat(mask, 'b j -> (b h) () j', h=h) 275 | sim.masked_fill_(~mask, max_neg_value) 276 | 277 | out = hk_self._hooked_attention(self, q, k, v, batch_size, sequence_length, dim) 278 | 279 | return self.to_out(out) 280 | 281 | ### forward implemetation of diffuser CrossAttention 282 | # def forward(self, hidden_states, context=None, mask=None): 283 | # batch_size, sequence_length, _ = hidden_states.shape 284 | 285 | # query = self.to_q(hidden_states) 286 | # context = context if context is not None else hidden_states 287 | # key = self.to_k(context) 288 | # value = self.to_v(context) 289 | 290 | # dim = query.shape[-1] 291 | 292 | # query = self.reshape_heads_to_batch_dim(query) 293 | # key = self.reshape_heads_to_batch_dim(key) 294 | # value = self.reshape_heads_to_batch_dim(value) 295 | 296 | # # TODO(PVP) - mask is currently never used. 
Remember to re-implement when used 297 | 298 | # # attention, what we cannot get enough of 299 | # if self._use_memory_efficient_attention_xformers: 300 | # hidden_states = self._memory_efficient_attention_xformers(query, key, value) 301 | # # Some versions of xformers return output in fp32, cast it back to the dtype of the input 302 | # hidden_states = hidden_states.to(query.dtype) 303 | # else: 304 | # if self._slice_size is None or query.shape[0] // self._slice_size == 1: 305 | # hidden_states = self._attention(query, key, value) 306 | # else: 307 | # hidden_states = self._sliced_attention(query, key, value, sequence_length, dim) 308 | 309 | # # linear proj 310 | # hidden_states = self.to_out[0](hidden_states) 311 | # # dropout 312 | # hidden_states = self.to_out[1](hidden_states) 313 | # return hidden_states 314 | 315 | def _hooked_attention(hk_self, self, query, key, value, batch_size, sequence_length, dim, use_context: bool = True): 316 | """ 317 | Monkey-patched version of :py:func:`.CrossAttention._attention` to capture attentions and aggregate them. 318 | 319 | Args: 320 | hk_self (`UNetCrossAttentionHooker`): pointer to the hook itself. 321 | self (`CrossAttention`): pointer to the module. 322 | query (`torch.Tensor`): the query tensor. 323 | key (`torch.Tensor`): the key tensor. 324 | value (`torch.Tensor`): the value tensor. 325 | batch_size (`int`): the batch size 326 | use_context (`bool`): whether to check if the resulting attention slices are between the words and the image 327 | """ 328 | batch_size_attention = query.shape[0] 329 | hidden_states = torch.zeros( 330 | (batch_size_attention, sequence_length, dim // self.heads), device=query.device, dtype=query.dtype 331 | ) 332 | slice_size = hidden_states.shape[0] // batch_size # self._slice_size if self._slice_size is not None else hidden_states.shape[0] 333 | 334 | def calc_factor_base(w, h): 335 | z = max(w/64, h/64) 336 | factor_b = min(w, h) * z 337 | return factor_b 338 | 339 | factor_base = calc_factor_base(hk_self.img_width, hk_self.img_height) 340 | 341 | for batch_index in range(hidden_states.shape[0] // slice_size): 342 | start_idx = batch_index * slice_size 343 | end_idx = (batch_index + 1) * slice_size 344 | attn_slice = ( 345 | torch.einsum("b i d, b j d -> b i j", query[start_idx:end_idx], key[start_idx:end_idx]) * self.scale 346 | ) 347 | factor = int(math.sqrt(factor_base // attn_slice.shape[1])) 348 | attn_slice = attn_slice.softmax(-1) 349 | hid_states = torch.einsum("b i j, b j d -> b i d", attn_slice, value[start_idx:end_idx]) 350 | 351 | if use_context and hk_self.calledCount % 2 == 1 and attn_slice.shape[-1] == hk_self.context_size: 352 | if factor >= 1: 353 | factor //= 1 354 | maps = hk_self._up_sample_attn(attn_slice, value, factor) 355 | hk_self.heat_maps[batch_index][factor].append(maps) 356 | 357 | hidden_states[start_idx:end_idx] = hid_states 358 | 359 | # reshape hidden_states 360 | hidden_states = hk_self.reshape_batch_dim_to_heads(self, hidden_states) 361 | return hidden_states 362 | 363 | def reshape_batch_dim_to_heads(hk_self, self, tensor): 364 | batch_size, seq_len, dim = tensor.shape 365 | head_size = self.heads 366 | tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim) 367 | tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size) 368 | return tensor 369 | 370 | def _hook_impl(self): 371 | self.monkey_patch('forward', self._forward) 372 | 373 | @property 374 | def num_heat_maps(self): 375 | return 
len(next(iter(self.heat_maps.values()))) 376 | 377 | 378 | trace: Type[DiffusionHeatMapHooker] = DiffusionHeatMapHooker 379 | -------------------------------------------------------------------------------- /scripts/daam/utils.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | from itertools import chain 3 | from functools import lru_cache 4 | from pathlib import Path 5 | import random 6 | import re 7 | 8 | from PIL import Image, ImageFont, ImageDraw 9 | from fonts.ttf import Roboto 10 | import matplotlib.pyplot as plt 11 | from matplotlib import cm 12 | import numpy as np 13 | # import spacy 14 | import torch 15 | import torch.nn.functional as F 16 | from modules.devices import dtype 17 | 18 | from ldm.modules.encoders.modules import FrozenCLIPEmbedder, FrozenOpenCLIPEmbedder 19 | import open_clip.tokenizer 20 | from modules.sd_hijack_clip import FrozenCLIPEmbedderWithCustomWordsBase, FrozenCLIPEmbedderWithCustomWords 21 | from modules.sd_hijack_open_clip import FrozenOpenCLIPEmbedderWithCustomWords 22 | from modules.shared import opts 23 | 24 | __all__ = ['expand_image', 'set_seed', 'escape_prompt', 'calc_context_size', 'compute_token_merge_indices', 'compute_token_merge_indices_with_tokenizer', 'image_overlay_heat_map', 'plot_overlay_heat_map', 'plot_mask_heat_map', 'PromptAnalyzer'] 25 | 26 | def expand_image(im: torch.Tensor, h = 512, w = 512, absolute: bool = False, threshold: float = None) -> torch.Tensor: 27 | 28 | im = im.unsqueeze(0).unsqueeze(0) 29 | im = F.interpolate(im.float().detach(), size=(h, w), mode='bicubic') 30 | 31 | if not absolute: 32 | im = (im - im.min()) / (im.max() - im.min() + 1e-8) 33 | 34 | if threshold: 35 | im = (im > threshold).float() 36 | 37 | # im = im.cpu().detach() 38 | 39 | return im.squeeze() 40 | 41 | def _write_on_image(img, caption, font_size = 32): 42 | ix,iy = img.size 43 | draw = ImageDraw.Draw(img) 44 | margin=2 45 | fontsize=font_size 46 | draw = ImageDraw.Draw(img) 47 | font = ImageFont.truetype(Roboto, fontsize) 48 | text_height=iy-60 49 | tx = draw.textbbox((0,0),caption,font) 50 | draw.text((int((ix-tx[2])/2),text_height+margin),caption,(0,0,0),font=font) 51 | draw.text((int((ix-tx[2])/2),text_height-margin),caption,(0,0,0),font=font) 52 | draw.text((int((ix-tx[2])/2+margin),text_height),caption,(0,0,0),font=font) 53 | draw.text((int((ix-tx[2])/2-margin),text_height),caption,(0,0,0),font=font) 54 | draw.text((int((ix-tx[2])/2),text_height), caption,(255,255,255),font=font) 55 | return img 56 | 57 | def image_overlay_heat_map(img, heat_map, word=None, out_file=None, crop=None, alpha=0.5, caption=None, image_scale=1.0): 58 | # type: (Image.Image | np.ndarray, torch.Tensor, str, Path, int, float, str, float) -> Image.Image 59 | assert(img is not None) 60 | 61 | if heat_map is not None: 62 | shape : torch.Size = heat_map.shape 63 | # heat_map = heat_map.unsqueeze(-1).expand(shape[0], shape[1], 3).clone() 64 | heat_map = _convert_heat_map_colors(heat_map) 65 | heat_map = heat_map.to('cpu').detach().numpy().copy().astype(np.uint8) 66 | heat_map_img = Image.fromarray(heat_map) 67 | 68 | img = Image.blend(img, heat_map_img, alpha) 69 | else: 70 | img = img.copy() 71 | 72 | if caption: 73 | img = _write_on_image(img, caption) 74 | 75 | if image_scale != 1.0: 76 | x, y = img.size 77 | size = (int(x * image_scale), int(y * image_scale)) 78 | img = img.resize(size, Image.BICUBIC) 79 | 80 | return img 81 | 82 | 83 | def _convert_heat_map_colors(heat_map : torch.Tensor): 84 | 
def get_color(value): 85 | return np.array(cm.turbo(value / 255)[0:3]) 86 | 87 | color_map = np.array([ get_color(i) * 255 for i in range(256) ]) 88 | color_map = torch.tensor(color_map, device=heat_map.device, dtype=dtype) 89 | 90 | heat_map = (heat_map * 255).long() 91 | 92 | return color_map[heat_map] 93 | 94 | def plot_overlay_heat_map(im, heat_map, word=None, out_file=None, crop=None): 95 | # type: (Image.Image | np.ndarray, torch.Tensor, str, Path, int) -> None 96 | plt.clf() 97 | plt.rcParams.update({'font.size': 24}) 98 | 99 | im = np.array(im) 100 | if crop is not None: 101 | heat_map = heat_map.squeeze()[crop:-crop, crop:-crop] 102 | im = im[crop:-crop, crop:-crop] 103 | 104 | plt.imshow(heat_map.squeeze().cpu().numpy(), cmap='jet') 105 | im = torch.from_numpy(im).float() / 255 106 | im = torch.cat((im, (1 - heat_map.unsqueeze(-1))), dim=-1) 107 | plt.imshow(im) 108 | 109 | if word is not None: 110 | plt.title(word) 111 | 112 | if out_file is not None: 113 | plt.savefig(out_file) 114 | 115 | 116 | def plot_mask_heat_map(im: Image.Image, heat_map: torch.Tensor, threshold: float = 0.4): 117 | im = torch.from_numpy(np.array(im)).float() / 255 118 | mask = (heat_map.squeeze() > threshold).float() 119 | im = im * mask.unsqueeze(-1) 120 | plt.imshow(im) 121 | 122 | 123 | def set_seed(seed: int) -> torch.Generator: 124 | random.seed(seed) 125 | np.random.seed(seed) 126 | torch.manual_seed(seed) 127 | torch.cuda.manual_seed_all(seed) 128 | 129 | gen = torch.Generator(device='cuda') 130 | gen.manual_seed(seed) 131 | 132 | return gen 133 | 134 | def calc_context_size(token_length : int): 135 | len_check = 0 if (token_length - 1) < 0 else token_length - 1 136 | return ((int)(len_check // 75) + 1) * 77 137 | 138 | def escape_prompt(prompt): 139 | if type(prompt) is str: 140 | prompt = prompt.lower() 141 | prompt = re.sub(r"[\(\)\[\]]", "", prompt) 142 | prompt = re.sub(r":\d+\.*\d*", "", prompt) 143 | return prompt 144 | elif type(prompt) is list: 145 | prompt_new = [] 146 | for i in range(len(prompt)): 147 | prompt_new.append(escape_prompt(prompt[i])) 148 | return prompt_new 149 | 150 | 151 | def compute_token_merge_indices(model, prompt: str, word: str, word_idx: int = None): 152 | 153 | clip = None 154 | tokenize = None 155 | if type(model.cond_stage_model.wrapped) == FrozenCLIPEmbedder: 156 | clip : FrozenCLIPEmbedder = model.cond_stage_model.wrapped 157 | tokenize = clip.tokenizer.tokenize 158 | elif type(model.cond_stage_model.wrapped) == FrozenOpenCLIPEmbedder: 159 | clip : FrozenOpenCLIPEmbedder = model.cond_stage_model.wrapped 160 | tokenize = open_clip.tokenizer._tokenizer.encode 161 | else: 162 | assert False 163 | 164 | escaped_prompt = escape_prompt(prompt) 165 | # escaped_prompt = re.sub(r"[_-]", " ", escaped_prompt) 166 | tokens : list = tokenize(escaped_prompt) 167 | word = word.lower() 168 | merge_idxs = [] 169 | 170 | needles = tokenize(word) 171 | 172 | if len(needles) == 0: 173 | return [] 174 | 175 | for i, token in enumerate(tokens): 176 | if needles[0] == token and len(needles) > 1: 177 | next = i + 1 178 | success = True 179 | for needle in needles[1:]: 180 | if next >= len(tokens) or needle != tokens[next]: 181 | success = False 182 | break 183 | next += 1 184 | 185 | # append consecutive indexes if all pass 186 | if success: 187 | merge_idxs.extend(list(range(i, next))) 188 | 189 | elif needles[0] == token: 190 | merge_idxs.append(i) 191 | 192 | idxs = [] 193 | for x in merge_idxs: 194 | seq = (int)(x / 75) 195 | if seq == 0: 196 | idxs.append(x + 1) # padding 197 | 
else: 198 | idxs.append(x + 1 + seq*2) # If tokens exceed 75, they are split. 199 | 200 | return idxs 201 | 202 | def compute_token_merge_indices_with_tokenizer(tokenizer, prompt: str, word: str, word_idx: int = None, limit : int = -1): 203 | 204 | escaped_prompt = escape_prompt(prompt) 205 | # escaped_prompt = re.sub(r"[_-]", " ", escaped_prompt) 206 | tokens : list = tokenizer.tokenize(escaped_prompt) 207 | word = word.lower() 208 | merge_idxs = [] 209 | 210 | needles = tokenizer.tokenize(word) 211 | 212 | if len(needles) == 0: 213 | return [] 214 | 215 | limit_count = 0 216 | for i, token in enumerate(tokens): 217 | if needles[0] == token and len(needles) > 1: 218 | next = i + 1 219 | success = True 220 | for needle in needles[1:]: 221 | if next >= len(tokens) or needle != tokens[next]: 222 | success = False 223 | break 224 | next += 1 225 | 226 | # append consecutive indexes if all pass 227 | if success: 228 | merge_idxs.extend(list(range(i, next))) 229 | if limit > 0: 230 | limit_count += 1 231 | if limit_count >= limit: 232 | break 233 | 234 | elif needles[0] == token: 235 | merge_idxs.append(i) 236 | if limit > 0: 237 | limit_count += 1 238 | if limit_count >= limit: 239 | break 240 | 241 | idxs = [] 242 | for x in merge_idxs: 243 | seq = (int)(x / 75) 244 | if seq == 0: 245 | idxs.append(x + 1) # padding 246 | else: 247 | idxs.append(x + 1 + seq*2) # If tokens exceed 75, they are split. 248 | 249 | return idxs 250 | 251 | nlp = None 252 | 253 | 254 | @lru_cache(maxsize=100000) 255 | def cached_nlp(prompt: str, type='en_core_web_md'): 256 | global nlp 257 | 258 | # if nlp is None: 259 | # nlp = spacy.load(type) 260 | 261 | return nlp(prompt) 262 | 263 | class PromptAnalyzer: 264 | def __init__(self, clip : FrozenCLIPEmbedderWithCustomWordsBase, text : str): 265 | use_old = opts.use_old_emphasis_implementation 266 | assert not use_old, "use_old_emphasis_implementation is not supported" 267 | 268 | self.clip = clip 269 | self.id_start = clip.id_start 270 | self.id_end = clip.id_end 271 | self.is_open_clip = True if type(clip) == FrozenOpenCLIPEmbedderWithCustomWords else False 272 | self.used_custom_terms = [] 273 | self.hijack_comments = [] 274 | 275 | chunks, token_count = self.tokenize_line(text) 276 | 277 | self.token_count = token_count 278 | self.fixes = list(chain.from_iterable(chunk.fixes for chunk in chunks)) 279 | self.context_size = calc_context_size(token_count) 280 | 281 | tokens = list(chain.from_iterable(chunk.tokens for chunk in chunks)) 282 | multipliers = list(chain.from_iterable(chunk.multipliers for chunk in chunks)) 283 | 284 | self.tokens = [] 285 | self.multipliers = [] 286 | for i in range(self.context_size // 77): 287 | self.tokens.extend([self.id_start] + tokens[i*75:i*75+75] + [self.id_end]) 288 | self.multipliers.extend([1.0] + multipliers[i*75:i*75+75]+ [1.0]) 289 | 290 | def create(self, text : str): 291 | return PromptAnalyzer(self.clip, text) 292 | 293 | def tokenize_line(self, line): 294 | chunks, token_count = self.clip.tokenize_line(line) 295 | return chunks, token_count 296 | 297 | def process_text(self, texts): 298 | batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = self.clip.process_text(texts) 299 | return batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count 300 | 301 | def encode(self, text : str): 302 | return self.clip.tokenize([text])[0] 303 | 304 | def calc_word_indecies(self, word : str, limit : int = -1, start_pos = 0): 305 | word = 
word.lower() 306 | merge_idxs = [] 307 | 308 | tokens = self.tokens 309 | needles = self.encode(word) 310 | 311 | limit_count = 0 312 | current_pos = 0 313 | for i, token in enumerate(tokens): 314 | current_pos = i 315 | if i < start_pos: 316 | continue 317 | 318 | if needles[0] == token and len(needles) > 1: 319 | next = i + 1 320 | success = True 321 | for needle in needles[1:]: 322 | if next >= len(tokens) or needle != tokens[next]: 323 | success = False 324 | break 325 | next += 1 326 | 327 | # append consecutive indexes if all pass 328 | if success: 329 | merge_idxs.extend(list(range(i, next))) 330 | if limit > 0: 331 | limit_count += 1 332 | if limit_count >= limit: 333 | break 334 | 335 | elif needles[0] == token: 336 | merge_idxs.append(i) 337 | if limit > 0: 338 | limit_count += 1 339 | if limit_count >= limit: 340 | break 341 | 342 | return merge_idxs, current_pos -------------------------------------------------------------------------------- /scripts/daam_script.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | from collections import defaultdict 3 | import os 4 | import re 5 | import traceback 6 | 7 | import gradio as gr 8 | import modules.images as images 9 | import modules.scripts as scripts 10 | import torch 11 | from ldm.modules.encoders.modules import FrozenCLIPEmbedder, FrozenOpenCLIPEmbedder 12 | import open_clip.tokenizer 13 | from modules import script_callbacks 14 | from modules import script_callbacks, sd_hijack_clip, sd_hijack_open_clip 15 | from modules.processing import (Processed, StableDiffusionProcessing, fix_seed, 16 | process_images) 17 | from modules.shared import cmd_opts, opts, state 18 | import modules.shared as shared 19 | from PIL import Image 20 | 21 | from scripts.daam import trace, utils 22 | 23 | before_image_saved_handler = None 24 | 25 | class Script(scripts.Script): 26 | 27 | GRID_LAYOUT_AUTO = "Auto" 28 | GRID_LAYOUT_PREVENT_EMPTY = "Prevent Empty Spot" 29 | GRID_LAYOUT_BATCH_LENGTH_AS_ROW = "Batch Length As Row" 30 | 31 | 32 | def title(self): 33 | return "Daam script" 34 | 35 | def show(self, is_img2img): 36 | return scripts.AlwaysVisible 37 | 38 | def ui(self, is_img2img): 39 | with gr.Group(): 40 | with gr.Accordion("Attention Heatmap", open=False): 41 | attention_texts = gr.Text(label='Attention texts for visualization. 
(comma separated)', value='') 42 | 43 | with gr.Row(): 44 | hide_images = gr.Checkbox(label='Hide heatmap images', value=False) 45 | 46 | dont_save_images = gr.Checkbox(label='Do not save heatmap images', value=False) 47 | 48 | hide_caption = gr.Checkbox(label='Hide caption', value=False) 49 | 50 | with gr.Row(): 51 | use_grid = gr.Checkbox(label='Use grid (output to grid dir)', value=False) 52 | 53 | grid_layouyt = gr.Dropdown( 54 | [Script.GRID_LAYOUT_AUTO, Script.GRID_LAYOUT_PREVENT_EMPTY, Script.GRID_LAYOUT_BATCH_LENGTH_AS_ROW], label="Grid layout", 55 | value=Script.GRID_LAYOUT_AUTO 56 | ) 57 | 58 | with gr.Row(): 59 | alpha = gr.Slider(label='Heatmap blend alpha', value=0.5, minimum=0, maximum=1, step=0.01) 60 | 61 | heatmap_image_scale = gr.Slider(label='Heatmap image scale', value=1.0, minimum=0.1, maximum=1, step=0.025) 62 | 63 | with gr.Row(): 64 | trace_each_layers = gr.Checkbox(label = 'Trace each layers', value=False) 65 | 66 | layers_as_row = gr.Checkbox(label = 'Use layers as row instead of Batch Length', value=False) 67 | 68 | 69 | self.tracers = None 70 | 71 | return [attention_texts, hide_images, dont_save_images, hide_caption, use_grid, grid_layouyt, alpha, heatmap_image_scale, trace_each_layers, layers_as_row] 72 | 73 | def process(self, 74 | p : StableDiffusionProcessing, 75 | attention_texts : str, 76 | hide_images : bool, 77 | dont_save_images : bool, 78 | hide_caption : bool, 79 | use_grid : bool, 80 | grid_layouyt :str, 81 | alpha : float, 82 | heatmap_image_scale : float, 83 | trace_each_layers : bool, 84 | layers_as_row: bool): 85 | 86 | self.enabled = False # in case the assert fails 87 | assert opts.samples_save, "Cannot run Daam script. Enable 'Always save all generated images' setting." 88 | 89 | self.images = [] 90 | self.hide_images = hide_images 91 | self.dont_save_images = dont_save_images 92 | self.hide_caption = hide_caption 93 | self.alpha = alpha 94 | self.use_grid = use_grid 95 | self.grid_layouyt = grid_layouyt 96 | self.heatmap_image_scale = heatmap_image_scale 97 | self.heatmap_images = dict() 98 | 99 | self.attentions = [s.strip() for s in attention_texts.split(",") if s.strip()] 100 | self.enabled = len(self.attentions) > 0 101 | 102 | fix_seed(p) 103 | 104 | def process_batch(self, 105 | p : StableDiffusionProcessing, 106 | attention_texts : str, 107 | hide_images : bool, 108 | dont_save_images : bool, 109 | hide_caption : bool, 110 | use_grid : bool, 111 | grid_layouyt :str, 112 | alpha : float, 113 | heatmap_image_scale : float, 114 | trace_each_layers : bool, 115 | layers_as_row: bool, 116 | prompts, 117 | **kwargs): 118 | 119 | if not self.enabled: 120 | return 121 | 122 | styled_prompt = prompts[0] 123 | 124 | embedder = None 125 | if type(p.sd_model.cond_stage_model) == sd_hijack_clip.FrozenCLIPEmbedderWithCustomWords or \ 126 | type(p.sd_model.cond_stage_model) == sd_hijack_open_clip.FrozenOpenCLIPEmbedderWithCustomWords: 127 | embedder = p.sd_model.cond_stage_model 128 | else: 129 | assert False, f"Embedder '{type(p.sd_model.cond_stage_model)}' is not supported." 
130 | 
131 |         clip = None
132 |         tokenize = None
133 |         if type(p.sd_model.cond_stage_model.wrapped) == FrozenCLIPEmbedder:
134 |             clip : FrozenCLIPEmbedder = p.sd_model.cond_stage_model.wrapped
135 |             tokenize = clip.tokenizer.tokenize
136 |         elif type(p.sd_model.cond_stage_model.wrapped) == FrozenOpenCLIPEmbedder:
137 |             clip : FrozenOpenCLIPEmbedder = p.sd_model.cond_stage_model.wrapped
138 |             tokenize = open_clip.tokenizer._tokenizer.encode
139 |         else:
140 |             assert False, f"CLIP wrapper '{type(p.sd_model.cond_stage_model.wrapped)}' is not supported."
141 | 
142 |         tokens = tokenize(utils.escape_prompt(styled_prompt))
143 |         context_size = utils.calc_context_size(len(tokens))
144 | 
145 |         prompt_analyzer = utils.PromptAnalyzer(embedder, styled_prompt)
146 |         self.prompt_analyzer = prompt_analyzer
147 |         context_size = prompt_analyzer.context_size
148 | 
149 |         print(f"daam run with context_size={prompt_analyzer.context_size}, token_count={prompt_analyzer.token_count}")
150 |         # print(f"remade_tokens={prompt_analyzer.tokens}, multipliers={prompt_analyzer.multipliers}")
151 |         # print(f"hijack_comments={prompt_analyzer.hijack_comments}, used_custom_terms={prompt_analyzer.used_custom_terms}")
152 |         # print(f"fixes={prompt_analyzer.fixes}")
153 | 
154 |         if any(item[0] in self.attentions for item in self.prompt_analyzer.used_custom_terms):
155 |             print("Embedding heatmap cannot be shown.")
156 | 
157 |         global before_image_saved_handler
158 |         before_image_saved_handler = lambda params : self.before_image_saved(params)
159 | 
160 |         with torch.no_grad():
161 |             # cannot trace the same block from two tracers
162 |             if trace_each_layers:
163 |                 num_input = len(p.sd_model.model.diffusion_model.input_blocks)
164 |                 num_output = len(p.sd_model.model.diffusion_model.output_blocks)
165 |                 self.tracers = [trace(p.sd_model, p.height, p.width, context_size, layer_idx=i) for i in range(num_input + num_output + 1)]
166 |                 self.attn_captions = [f"IN{i:02d}" for i in range(num_input)] + ["MID"] + [f"OUT{i:02d}" for i in range(num_output)]
167 |             else:
168 |                 self.tracers = [trace(p.sd_model, p.height, p.width, context_size)]
169 |                 self.attn_captions = [""]
170 | 
171 |             for tracer in self.tracers:
172 |                 tracer.hook()
173 | 
174 |     def postprocess(self, p, processed,
175 |                     attention_texts : str,
176 |                     hide_images : bool,
177 |                     dont_save_images : bool,
178 |                     hide_caption : bool,
179 |                     use_grid : bool,
180 |                     grid_layout : str,
181 |                     alpha : float,
182 |                     heatmap_image_scale : float,
183 |                     trace_each_layers : bool,
184 |                     layers_as_row: bool,
185 |                     **kwargs):
186 |         if not self.enabled:
187 |             return
188 | 
189 |         for tracer in self.tracers:
190 |             tracer.unhook()
191 |         self.tracers = None
192 | 
193 |         initial_info = None
194 | 
195 |         if initial_info is None:
196 |             initial_info = processed.info
197 | 
198 |         self.images += processed.images
199 | 
200 |         global before_image_saved_handler
201 |         before_image_saved_handler = None
202 | 
203 |         if layers_as_row:
204 |             images_list = []
205 |             for i in range(p.batch_size * p.n_iter):
206 |                 imgs = []
207 |                 for k in sorted(self.heatmap_images.keys()):
208 |                     imgs += [self.heatmap_images[k][len(self.attentions)*i + j] for j in range(len(self.attentions))]
209 |                 images_list.append(imgs)
210 |         else:
211 |             images_list = [self.heatmap_images[k] for k in sorted(self.heatmap_images.keys())]
212 | 
213 |         for img_list in images_list:
214 | 
215 |             if img_list and self.use_grid:
216 | 
217 |                 grid_layout = self.grid_layout
218 |                 if grid_layout == Script.GRID_LAYOUT_AUTO:
219 |                     if p.batch_size * p.n_iter == 1:
220 |                         grid_layout = Script.GRID_LAYOUT_PREVENT_EMPTY
221 |                     else:
222 |                         grid_layout = Script.GRID_LAYOUT_BATCH_LENGTH_AS_ROW
223 | 
224 |                 if grid_layout == Script.GRID_LAYOUT_PREVENT_EMPTY:
225 |                     grid_img = images.image_grid(img_list)
226 |                 elif grid_layout == Script.GRID_LAYOUT_BATCH_LENGTH_AS_ROW:
227 |                     if layers_as_row:
228 |                         batch_size = len(self.attentions)
229 |                         rows = len(self.heatmap_images)
230 |                     else:
231 |                         batch_size = p.batch_size
232 |                         rows = p.batch_size * p.n_iter
233 |                     grid_img = images.image_grid(img_list, batch_size=batch_size, rows=rows)
234 |                 else:
235 |                     continue  # unknown grid layout; skip grid output for this list
236 | 
237 |                 if not self.dont_save_images:
238 |                     images.save_image(grid_img, p.outpath_grids, "grid_daam", grid=True, p=p)
239 | 
240 |                 if not self.hide_images:
241 |                     processed.images.insert(0, grid_img)
242 |                     processed.index_of_first_image += 1
243 |                     processed.infotexts.insert(0, processed.infotexts[0])
244 | 
245 |             else:
246 |                 if not self.hide_images:
247 |                     processed.images[:0] = img_list
248 |                     processed.index_of_first_image += len(img_list)
249 |                     processed.infotexts[:0] = [processed.infotexts[0]] * len(img_list)
250 | 
251 |         return processed
252 | 
253 |     def before_image_saved(self, params : script_callbacks.ImageSaveParams):
254 |         batch_pos = -1
255 |         if params.p.batch_size > 1:
256 |             match = re.search(r"Batch pos: (\d+)", params.pnginfo['parameters'])
257 |             if match:
258 |                 batch_pos = int(match.group(1))
259 |         else:
260 |             batch_pos = 0
261 | 
262 |         if batch_pos < 0:
263 |             return
264 | 
265 |         if self.tracers is not None and len(self.attentions) > 0:
266 |             for i, tracer in enumerate(self.tracers):
267 |                 with torch.no_grad():
268 |                     styled_prompt = shared.prompt_styles.apply_styles_to_prompt(params.p.prompt, params.p.styles)
269 |                     try:
270 |                         global_heat_map = tracer.compute_global_heat_map(self.prompt_analyzer, styled_prompt, batch_pos)
271 |                     except Exception:
272 |                         continue
273 | 
274 |                     if i not in self.heatmap_images:
275 |                         self.heatmap_images[i] = []
276 | 
277 |                     if global_heat_map is not None:
278 |                         heatmap_images = []
279 |                         for attention in self.attentions:
280 | 
281 |                             img_size = params.image.size
282 |                             caption = attention + (" " + self.attn_captions[i] if self.attn_captions[i] else "") if not self.hide_caption else None
283 | 
284 |                             heat_map = global_heat_map.compute_word_heat_map(attention)
285 |                             if heat_map is None: print(f"No heatmap for '{attention}'")
286 | 
287 |                             heat_map_img = utils.expand_image(heat_map, img_size[1], img_size[0]) if heat_map is not None else None
288 |                             img : Image.Image = utils.image_overlay_heat_map(params.image, heat_map_img, alpha=self.alpha, caption=caption, image_scale=self.heatmap_image_scale)
289 | 
290 |                             fullfn_without_extension, extension = os.path.splitext(params.filename)
291 |                             full_filename = fullfn_without_extension + "_" + attention + ("_" + self.attn_captions[i] if self.attn_captions[i] else "") + extension
292 | 
293 |                             heatmap_images.append(img)
294 | 
295 |                             if not self.use_grid:
296 |                                 # when grid output is off, save each overlay image individually
297 |                                 if not self.dont_save_images:
298 |                                     img.save(full_filename)
299 | 
300 |                         self.heatmap_images[i] += heatmap_images
301 | 
302 |         self.heatmap_images = {k: v for k, v in self.heatmap_images.items() if v}
303 | 
304 |         # if this is the last batch position, reset the tracers
305 |         if batch_pos == params.p.batch_size - 1:
306 |             for tracer in self.tracers:
307 |                 tracer.reset()
308 | 
309 |         return
310 | 
311 | 
312 | def handle_before_image_saved(params : script_callbacks.ImageSaveParams):
313 | 
314 |     if before_image_saved_handler is not None and callable(before_image_saved_handler):
315 |         before_image_saved_handler(params)
316 | 
317 |     return
318 | 
319 | script_callbacks.on_before_image_saved(handle_before_image_saved)
320 | 
--------------------------------------------------------------------------------
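The span-matching helper at the end of scripts/daam/utils.py above is easier to follow in isolation. The sketch below is a minimal, standalone restatement of its strategy: encode the attention text into token ids, then scan the prompt's token ids for positions where all of those ids appear consecutively, collecting the matched indices up to an optional limit. The function name find_token_spans and the toy token ids are illustrative only; they are not part of the extension.

from typing import List, Tuple

def find_token_spans(tokens: List[int], needles: List[int],
                     limit: int = -1, start_pos: int = 0) -> Tuple[List[int], int]:
    """Collect indices where `needles` occurs as a consecutive run inside `tokens`."""
    merge_idxs: List[int] = []
    match_count = 0
    current_pos = 0
    for i in range(len(tokens)):
        current_pos = i
        if i < start_pos:
            continue
        # a match requires every needle token to appear back to back
        if tokens[i:i + len(needles)] == needles:
            merge_idxs.extend(range(i, i + len(needles)))
            match_count += 1
            if 0 < limit <= match_count:
                break
    return merge_idxs, current_pos

# Toy example: if "cute cat" encodes to [7, 9], only the adjacent pair is matched.
print(find_token_spans([3, 7, 9, 5, 9], [7, 9]))  # -> ([1, 2], 4)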
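The grid handling in Script.postprocess boils down to a small decision rule: the "Auto" layout collapses to "Prevent Empty Spot" when only a single image was generated and to "Batch Length As Row" otherwise, and with layers-as-rows enabled each traced layer becomes a row with one column per attention word. The helper below is a hedged, standalone restatement of that rule under those assumptions; resolve_grid_shape is not a function that exists in this repository.

def resolve_grid_shape(layout: str, batch_size: int, n_iter: int,
                       n_attentions: int, n_layers: int, layers_as_row: bool):
    # "Auto" picks a concrete layout from the number of generated images
    if layout == "Auto":
        layout = "Prevent Empty Spot" if batch_size * n_iter == 1 else "Batch Length As Row"
    if layout == "Prevent Empty Spot":
        return layout, None, None  # let images.image_grid choose rows/columns itself
    if layers_as_row:
        return layout, n_attentions, n_layers  # columns per row, number of rows
    return layout, batch_size, batch_size * n_iter

# e.g. a batch of 2 images, 3 attention words, 25 traced layers shown as rows:
print(resolve_grid_shape("Auto", 2, 1, 3, 25, True))  # -> ('Batch Length As Row', 3, 25)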