├── .gitignore ├── LICENSE ├── MANIFEST.in ├── README.md ├── daam ├── __init__.py ├── _version.py ├── evaluate.py ├── experiment.py ├── heatmap.py ├── hook.py ├── run │ ├── __init__.py │ ├── demo.py │ ├── evaluate.py │ └── generate.py ├── trace.py └── utils.py ├── data ├── vocab-small.tsv └── vocab.tsv ├── docs ├── .buildinfo ├── .nojekyll ├── Makefile ├── _sources │ ├── daam.rst.txt │ ├── daam.run.rst.txt │ ├── index.rst.txt │ └── modules.rst.txt ├── _static │ ├── alabaster.css │ ├── basic.css │ ├── custom.css │ ├── doctools.js │ ├── documentation_options.js │ ├── file.png │ ├── jquery-3.5.1.js │ ├── jquery.js │ ├── language_data.js │ ├── minus.png │ ├── plus.png │ ├── pygments.css │ ├── searchtools.js │ ├── underscore-1.13.1.js │ └── underscore.js ├── conf.py ├── daam.html ├── daam.rst ├── daam.run.html ├── daam.run.rst ├── genindex.html ├── index.html ├── index.rst ├── make.sh ├── modules.html ├── modules.rst ├── objects.inv ├── py-modindex.html ├── search.html └── searchindex.js ├── example.jpg ├── notebooks ├── 0-setup.ipynb ├── 1-visuosyntactic-analyses.ipynb └── 2-visuosemantic-analyses.ipynb ├── requirements.txt ├── scrollbar.css └── setup.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 
101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 160 | .idea/ 161 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Castorini 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include requirements.txt 2 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # What the DAAM: Interpreting Stable Diffusion Using Cross Attention 2 | 3 | [![HF Spaces](https://img.shields.io/badge/HuggingFace%20Space-online-green.svg)](https://huggingface.co/spaces/tetrisd/Diffusion-Attentive-Attribution-Maps) [![Citation](https://img.shields.io/badge/Citation-ACL-orange.svg)](https://gist.github.com/daemon/639de6fea584d7df1a62f04a2ea0cdad) [![PyPi version](https://badgen.net/pypi/v/daam?color=blue)](https://pypi.org/project/daam) [![Downloads](https://static.pepy.tech/badge/daam)](https://pepy.tech/project/daam) 4 | 5 | ![example image](example.jpg) 6 | 7 | ### Updated to support Stable Diffusion XL (SDXL) and Diffusers 0.21.1! 8 | 9 | I regularly update this codebase. Please submit an issue if you have any questions. 10 | 11 | In [our paper](https://aclanthology.org/2023.acl-long.310), we propose diffusion attentive attribution maps (DAAM), a cross-attention-based approach for interpreting Stable Diffusion. 12 | Check out our demo: https://huggingface.co/spaces/tetrisd/Diffusion-Attentive-Attribution-Maps. 13 | See our [documentation](https://castorini.github.io/daam/), hosted on GitHub Pages, and [our Colab notebook](https://colab.research.google.com/drive/1miGauqa07uHnDoe81NmbmtTtnupmlipv?usp=sharing), updated for v0.1.0. 14 | 15 | ## Getting Started 16 | First, install [PyTorch](https://pytorch.org) for your platform. 17 | Then, install DAAM with `pip install daam`; if you want an editable install instead, run `git clone https://github.com/castorini/daam && pip install -e daam`. 18 | Finally, log in with `huggingface-cli login` to access many Stable Diffusion models -- you'll need a token from [HuggingFace.co](https://huggingface.co/). 19 | 20 | ### Running the Website Demo 21 | Simply run `daam-demo` in a shell and navigate to http://localhost:8080. 22 | The same demo that runs on HuggingFace Spaces will show up. 23 | 24 | ### Using DAAM as a CLI Utility 25 | DAAM comes with a simple generation script so you can quickly try it out. 26 | Try running: 27 | ```bash 28 | $ mkdir -p daam-test && cd daam-test 29 | $ daam "A dog running across the field." 30 | $ ls 31 | a.heat_map.png field.heat_map.png generation.pt output.png seed.txt 32 | dog.heat_map.png running.heat_map.png prompt.txt 33 | ``` 34 | Your current working directory will now contain the generated image as `output.png` and a DAAM map for every word, as well as some auxiliary data. 35 | You can see more options for `daam` by running `daam -h`. 36 | To use Stable Diffusion XL as the backend, run `daam --model xl-base-1.0 "Dog jumping"`.
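To work with those saved files programmatically, here is a minimal, untested sketch (run from inside `daam-test`) built on the `GenerationExperiment` class described below:

```python
from daam import GenerationExperiment

# Reload the artifacts that the `daam` command saved to this directory:
# generation.pt, prompt.txt, and seed.txt, alongside the heat map images.
exp = GenerationExperiment.load('.')
print(exp.prompt, exp.seed)

heat_map = exp.heat_map()  # a GlobalHeatMap built with the pickled tokenizer
```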
37 | 38 | ### Using DAAM as a Library 39 | 40 | Import and use DAAM as follows: 41 | 42 | ```python 43 | from daam import trace, set_seed 44 | from diffusers import DiffusionPipeline 45 | from matplotlib import pyplot as plt 46 | import torch 47 | 48 | 49 | model_id = 'stabilityai/stable-diffusion-xl-base-1.0' 50 | device = 'cuda' 51 | 52 | pipe = DiffusionPipeline.from_pretrained(model_id, use_auth_token=True, torch_dtype=torch.float16, use_safetensors=True, variant='fp16') 53 | pipe = pipe.to(device) 54 | 55 | prompt = 'A dog runs across the field' 56 | gen = set_seed(0) # for reproducibility 57 | 58 | with torch.no_grad(): 59 | with trace(pipe) as tc: 60 | out = pipe(prompt, num_inference_steps=50, generator=gen) 61 | heat_map = tc.compute_global_heat_map() # aggregate heat maps across timesteps and layers 62 | heat_map = heat_map.compute_word_heat_map('dog') # extract the map for 'dog' 63 | heat_map.plot_overlay(out.images[0]) 64 | plt.show() 65 | ``` 66 | 67 | You can also easily serialize and deserialize the DAAM maps: 68 | 69 | ```python 70 | from daam import GenerationExperiment, trace 71 | 72 | with trace(pipe) as tc: 73 | pipe('A dog and a cat') 74 | exp = tc.to_experiment('experiment-dir') 75 | exp.save() # experiment-dir now contains all the data and heat maps 76 | 77 | exp = GenerationExperiment.load('experiment-dir') # load the experiment 78 | ``` 79 | 80 | We'll continue adding docs. 81 | In the meantime, check out the `GenerationExperiment`, `GlobalHeatMap`, and `DiffusionHeatMapHooker` classes, as well as the `daam/run/*.py` example scripts; a short starter sketch appears after the list below. 82 | You can download the COCO-Gen dataset from the paper at http://ralphtang.com/coco-gen.tar.gz. 83 | If clicking the link doesn't work in your browser, copy and paste it into a new tab, or use a CLI utility such as `wget`. 84 | 85 | ## See Also 86 | - [DAAM-i2i](https://github.com/RishiDarkDevil/daam-i2i), an extension of DAAM to image-to-image attribution. 87 | 88 | - [Furkan's video](https://www.youtube.com/watch?v=XiKyEKJrTLQ) on easily getting started with DAAM. 89 | 90 | - [1littlecoder's video](https://www.youtube.com/watch?v=J2WtkA1Xfew) for a code demonstration and Colab notebook of an older version of DAAM.
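As promised above, here is a short, untested starter sketch for the `GlobalHeatMap` API. It assumes `pipe`, `prompt`, and the imports from the library example earlier, and plots one overlay per word that DAAM can align with the prompt:

```python
from matplotlib import pyplot as plt

with torch.no_grad(), trace(pipe) as tc:
    out = pipe(prompt, num_inference_steps=50, generator=set_seed(0))
    heat_map = tc.compute_global_heat_map()

    # parsed_heat_maps() pairs each spaCy token with its WordHeatMap,
    # silently skipping words that cannot be aligned with the prompt.
    for parsed in heat_map.parsed_heat_maps():
        fig, ax = plt.subplots()
        parsed.word_heat_map.plot_overlay(out.images[0], ax=ax)

plt.show()
```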
91 | 92 | ## Citation 93 | ``` 94 | @inproceedings{tang2023daam, 95 | title = "What the {DAAM}: Interpreting Stable Diffusion Using Cross Attention", 96 | author = "Tang, Raphael and 97 | Liu, Linqing and 98 | Pandey, Akshat and 99 | Jiang, Zhiying and 100 | Yang, Gefei and 101 | Kumar, Karun and 102 | Stenetorp, Pontus and 103 | Lin, Jimmy and 104 | Ture, Ferhan", 105 | booktitle = "Proceedings of the 61st Annual Meeting of the Association for Computational Linguistics (Volume 1: Long Papers)", 106 | year = "2023", 107 | url = "https://aclanthology.org/2023.acl-long.310", 108 | } 109 | ``` 110 | -------------------------------------------------------------------------------- /daam/__init__.py: -------------------------------------------------------------------------------- 1 | from ._version import __version__ 2 | from .experiment import * 3 | from .heatmap import * 4 | from .hook import * 5 | from .utils import * 6 | from .trace import * 7 | -------------------------------------------------------------------------------- /daam/_version.py: -------------------------------------------------------------------------------- 1 | __version__ = '0.2.0' 2 | -------------------------------------------------------------------------------- /daam/evaluate.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | from typing import List, Union 3 | 4 | from scipy.optimize import linear_sum_assignment 5 | import PIL.Image as Image 6 | import numpy as np 7 | import torch 8 | import torch.nn.functional as F 9 | 10 | 11 | __all__ = ['compute_iou', 'MeanEvaluator', 'load_mask', 'compute_ioa'] 12 | 13 | 14 | def compute_iou(a: torch.Tensor, b: torch.Tensor) -> float: 15 | if a.shape[0] != b.shape[0]: 16 | a = F.interpolate(a.unsqueeze(0).unsqueeze(0).float(), size=b.shape, mode='bicubic').squeeze() 17 | a[a < 1] = 0 18 | a[a >= 1] = 1 19 | 20 | intersection = (a * b).sum() 21 | union = a.sum() + b.sum() - intersection 22 | 23 | return (intersection / (union + 1e-8)).item() 24 | 25 | 26 | def compute_ioa(a: torch.Tensor, b: torch.Tensor) -> float: 27 | if a.shape[0] != b.shape[0]: 28 | a = F.interpolate(a.unsqueeze(0).unsqueeze(0).float(), size=b.shape, mode='bicubic').squeeze() 29 | a[a < 1] = 0 30 | a[a >= 1] = 1 31 | 32 | intersection = (a * b).sum() 33 | area = a.sum() 34 | 35 | return (intersection / (area + 1e-8)).item() 36 | 37 | 38 | def load_mask(path: str) -> torch.Tensor: 39 | mask = np.array(Image.open(path)) 40 | mask = torch.from_numpy(mask).float()[:, :, 3] # use alpha channel 41 | mask = (mask > 0).float() 42 | 43 | return mask 44 | 45 | 46 | class UnsupervisedEvaluator: 47 | def __init__(self, name: str = 'UnsupervisedEvaluator'): 48 | self.name = name 49 | self.ious = defaultdict(list) 50 | self.num_samples = 0 51 | 52 | def log_iou(self, preds: Union[torch.Tensor, List[torch.Tensor]], truth: torch.Tensor, gt_idx: int = 0, pred_idx: int = 0): 53 | if not isinstance(preds, list): 54 | preds = [preds] 55 | 56 | iou = max(compute_iou(pred, truth) for pred in preds) 57 | self.ious[gt_idx].append((pred_idx, iou)) 58 | 59 | @property 60 | def mean_iou(self) -> float: 61 | n = max(max(self.ious), max([y[0] for x in self.ious.values() for y in x])) + 1 62 | iou_matrix = np.zeros((n, n)) 63 | count_matrix = np.zeros((n, n)) 64 | 65 | for gt_idx, ious in self.ious.items(): 66 | for pred_idx, iou in ious: 67 | iou_matrix[gt_idx, pred_idx] += iou 68 | count_matrix[gt_idx, pred_idx] += 1 69 | 70 | row_ind, col_ind = 
linear_sum_assignment(iou_matrix, maximize=True) 71 | return iou_matrix[row_ind, col_ind].sum() / count_matrix[row_ind, col_ind].sum() 72 | 73 | def increment(self): 74 | self.num_samples += 1 75 | 76 | def __len__(self) -> int: 77 | return self.num_samples 78 | 79 | def __str__(self): 80 | return f'{self.name}<{self.mean_iou:.4f} (mIoU) {len(self)} samples>' 81 | 82 | 83 | class MeanEvaluator: 84 | def __init__(self, name: str = 'MeanEvaluator'): 85 | self.ious: List[float] = [] 86 | self.intensities: List[float] = [] 87 | self.name = name 88 | 89 | def log_iou(self, preds: Union[torch.Tensor, List[torch.Tensor]], truth: torch.Tensor): 90 | if not isinstance(preds, list): 91 | preds = [preds] 92 | 93 | self.ious.append(max(compute_iou(pred, truth) for pred in preds)) 94 | return self 95 | 96 | def log_intensity(self, pred: torch.Tensor): 97 | self.intensities.append(pred.mean().item()) 98 | return self 99 | 100 | @property 101 | def mean_iou(self) -> float: 102 | return np.mean(self.ious) 103 | 104 | @property 105 | def mean_intensity(self) -> float: 106 | return np.mean(self.intensities) 107 | 108 | @property 109 | def ci95_miou(self) -> float: 110 | return 1.96 * np.std(self.ious) / np.sqrt(len(self.ious)) 111 | 112 | def __len__(self) -> int: 113 | return max(len(self.ious), len(self.intensities)) 114 | 115 | def __str__(self): 116 | return f'{self.name}<{self.mean_iou:.4f} (±{self.ci95_miou:.3f} mIoU) {self.mean_intensity:.4f} (mInt) {len(self)} samples>' 117 | 118 | 119 | if __name__ == '__main__': 120 | mask = load_mask('truth/output/452/sink.gt.png') 121 | 122 | print(MeanEvaluator().log_iou(mask, mask)) 123 | -------------------------------------------------------------------------------- /daam/experiment.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from typing import List, Optional, Dict, Any, Union 3 | from dataclasses import dataclass 4 | import json 5 | 6 | from transformers import PreTrainedTokenizer, AutoTokenizer 7 | import PIL.Image 8 | import numpy as np 9 | import torch 10 | 11 | from .utils import auto_autocast 12 | from .evaluate import load_mask 13 | 14 | 15 | __all__ = ['GenerationExperiment', 'COCO80_LABELS', 'COCOSTUFF27_LABELS', 'COCO80_INDICES', 'build_word_list_coco80'] 16 | 17 | 18 | COCO80_LABELS: List[str] = [ 19 | 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck', 'boat', 'traffic light', 20 | 'fire hydrant', 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow', 21 | 'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee', 22 | 'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard', 23 | 'tennis racket', 'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple', 24 | 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch', 25 | 'potted plant', 'bed', 'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone', 26 | 'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase', 'scissors', 'teddy bear', 27 | 'hair drier', 'toothbrush' 28 | ] 29 | 30 | COCO80_INDICES: Dict[str, int] = {x: i for i, x in enumerate(COCO80_LABELS)} 31 | 32 | UNUSED_LABELS: List[str] = [f'__unused_{i}__' for i in range(1, 200)] 33 | 34 | COCOSTUFF27_LABELS: List[str] = [ 35 | 'electronic', 'appliance', 'food', 'furniture', 'indoor', 'kitchen', 
'accessory', 'animal', 'outdoor', 'person', 36 | 'sports', 'vehicle', 'ceiling', 'floor', 'food', 'furniture', 'rawmaterial', 'textile', 'wall', 'window', 37 | 'building', 'ground', 'plant', 'sky', 'solid', 'structural', 'water' 38 | ] 39 | 40 | COCO80_ONTOLOGY = { 41 | 'two-wheeled vehicle': ['bicycle', 'motorcycle'], 42 | 'vehicle': ['two-wheeled vehicle', 'four-wheeled vehicle'], 43 | 'four-wheeled vehicle': ['bus', 'truck', 'car'], 44 | 'four-legged animals': ['livestock', 'pets', 'wild animals'], 45 | 'livestock': ['cow', 'horse', 'sheep'], 46 | 'pets': ['cat', 'dog'], 47 | 'wild animals': ['elephant', 'bear', 'zebra', 'giraffe'], 48 | 'bags': ['backpack', 'handbag', 'suitcase'], 49 | 'sports boards': ['snowboard', 'surfboard', 'skateboard'], 50 | 'utensils': ['fork', 'knife', 'spoon'], 51 | 'receptacles': ['bowl', 'cup'], 52 | 'fruits': ['banana', 'apple', 'orange'], 53 | 'foods': ['fruits', 'meals', 'desserts'], 54 | 'meals': ['sandwich', 'hot dog', 'pizza'], 55 | 'desserts': ['cake', 'donut'], 56 | 'furniture': ['chair', 'couch', 'bench'], 57 | 'electronics': ['monitors', 'appliances'], 58 | 'monitors': ['tv', 'cell phone', 'laptop'], 59 | 'appliances': ['oven', 'toaster', 'refrigerator'] 60 | } 61 | 62 | COCO80_TO_27 = { 63 | 'bicycle': 'vehicle', 'car': 'vehicle', 'motorcycle': 'vehicle', 'airplane': 'vehicle', 'bus': 'vehicle', 64 | 'train': 'vehicle', 'truck': 'vehicle', 'boat': 'vehicle', 'traffic light': 'accessory', 'fire hydrant': 'accessory', 65 | 'stop sign': 'accessory', 'parking meter': 'accessory', 'bench': 'furniture', 'bird': 'animal', 'cat': 'animal', 66 | 'dog': 'animal', 'horse': 'animal', 'sheep': 'animal', 'cow': 'animal', 'elephant': 'animal', 'bear': 'animal', 67 | 'zebra': 'animal', 'giraffe': 'animal', 'backpack': 'accessory', 'umbrella': 'accessory', 'handbag': 'accessory', 68 | 'tie': 'accessory', 'suitcase': 'accessory', 'frisbee': 'sports', 'skis': 'sports', 'snowboard': 'sports', 69 | 'sports ball': 'sports', 'kite': 'sports', 'baseball bat': 'sports', 'baseball glove': 'sports', 70 | 'skateboard': 'sports', 'surfboard': 'sports', 'tennis racket': 'sports', 'bottle': 'food', 'wine glass': 'food', 71 | 'cup': 'food', 'fork': 'food', 'knife': 'food', 'spoon': 'food', 'bowl': 'food', 'banana': 'food', 'apple': 'food', 72 | 'sandwich': 'food', 'orange': 'food', 'broccoli': 'food', 'carrot': 'food', 'hot dog': 'food', 'pizza': 'food', 73 | 'donut': 'food', 'cake': 'food', 'chair': 'furniture', 'couch': 'furniture', 'potted plant': 'plant', 74 | 'bed': 'furniture', 'dining table': 'furniture', 'toilet': 'furniture', 'tv': 'electronic', 'laptop': 'electronic', 75 | 'mouse': 'electronic', 'remote': 'electronic', 'keyboard': 'electronic', 'cell phone': 'electronic', 76 | 'microwave': 'appliance', 'oven': 'appliance', 'toaster': 'appliance', 'sink': 'appliance', 77 | 'refrigerator': 'appliance', 'book': 'indoor', 'clock': 'indoor', 'vase': 'indoor', 'scissors': 'indoor', 78 | 'teddy bear': 'indoor', 'hair drier': 'indoor', 'toothbrush': 'indoor' 79 | } 80 | 81 | 82 | def build_word_list_coco80() -> Dict[str, List[str]]: 83 | words_map = COCO80_ONTOLOGY.copy() 84 | words_map = {k: v for k, v in words_map.items() if not any(item in COCO80_ONTOLOGY for item in v)} 85 | 86 | return words_map 87 | 88 | 89 | def _add_mask(masks: Dict[str, torch.Tensor], word: str, mask: torch.Tensor, simplify80: bool = False) -> Dict[str, torch.Tensor]: 90 | if simplify80: 91 | word = COCO80_TO_27.get(word, word) 92 | 93 | if word in masks: 94 | masks[word] = masks[word.lower()] + mask 
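        # add to the mask already stored for this word; the clamp on the next line keeps overlapping regions binary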
95 | masks[word].clamp_(0, 1) 96 | else: 97 | masks[word] = mask 98 | 99 | return masks 100 | 101 | 102 | @dataclass 103 | class GenerationExperiment: 104 | """Class to hold experiment parameters. Pickleable.""" 105 | image: PIL.Image.Image 106 | global_heat_map: torch.Tensor 107 | prompt: str 108 | 109 | seed: int = None 110 | id: str = '.' 111 | path: Optional[Path] = None 112 | 113 | truth_masks: Optional[Dict[str, torch.Tensor]] = None 114 | prediction_masks: Optional[Dict[str, torch.Tensor]] = None 115 | annotations: Optional[Dict[str, Any]] = None 116 | subtype: Optional[str] = '.' 117 | tokenizer: AutoTokenizer = None 118 | 119 | def __post_init__(self): 120 | if isinstance(self.path, str): 121 | self.path = Path(self.path) 122 | 123 | self.path = None if self.path is None else self.path / self.id 124 | 125 | def nsfw(self) -> bool: 126 | return np.sum(np.array(self.image)) == 0 127 | 128 | def heat_map(self, tokenizer: AutoTokenizer = None): 129 | if tokenizer is None: 130 | tokenizer = self.tokenizer 131 | 132 | from daam import GlobalHeatMap 133 | return GlobalHeatMap(tokenizer, self.prompt, self.global_heat_map) 134 | 135 | def clear_checkpoint(self): 136 | path = self if isinstance(self, Path) else self.path 137 | 138 | (path / 'generation.pt').unlink(missing_ok=True) 139 | 140 | def save(self, path: str = None, heat_maps: bool = True, tokenizer: AutoTokenizer = None): 141 | if path is None: 142 | path = self.path 143 | else: 144 | path = Path(path) / self.id 145 | 146 | if tokenizer is None: 147 | tokenizer = self.tokenizer 148 | 149 | (path / self.subtype).mkdir(parents=True, exist_ok=True) 150 | torch.save(self, path / self.subtype / 'generation.pt') 151 | self.image.save(path / self.subtype / 'output.png') 152 | 153 | with (path / 'prompt.txt').open('w') as f: 154 | f.write(self.prompt) 155 | 156 | with (path / 'seed.txt').open('w') as f: 157 | f.write(str(self.seed)) 158 | 159 | if self.truth_masks is not None: 160 | for name, mask in self.truth_masks.items(): 161 | im = PIL.Image.fromarray((mask * 255).unsqueeze(-1).expand(-1, -1, 4).byte().numpy()) 162 | im.save(path / f'{name.lower()}.gt.png') 163 | 164 | if heat_maps and tokenizer is not None: 165 | self.save_all_heat_maps(tokenizer) 166 | 167 | self.save_annotations() 168 | 169 | def save_annotations(self, path: Path = None): 170 | if path is None: 171 | path = self.path 172 | 173 | if self.annotations is not None: 174 | with (path / 'annotations.json').open('w') as f: 175 | json.dump(self.annotations, f) 176 | 177 | def _load_truth_masks(self, simplify80: bool = False) -> Dict[str, torch.Tensor]: 178 | masks = {} 179 | 180 | for mask_path in self.path.glob('*.gt.png'): 181 | word = mask_path.name.split('.gt.png')[0].lower() 182 | mask = load_mask(str(mask_path)) 183 | _add_mask(masks, word, mask, simplify80) 184 | 185 | return masks 186 | 187 | def _load_pred_masks(self, pred_prefix, composite=False, simplify80=False, vocab=None): 188 | # type: (str, bool, bool, List[str] | None) -> Dict[str, torch.Tensor] 189 | masks = {} 190 | 191 | if vocab is None: 192 | vocab = UNUSED_LABELS 193 | 194 | if composite: 195 | try: 196 | im = PIL.Image.open(self.path / self.subtype / f'composite.{pred_prefix}.pred.png') 197 | im = np.array(im) 198 | 199 | for mask_idx in np.unique(im): 200 | mask = torch.from_numpy((im == mask_idx).astype(np.float32)) 201 | _add_mask(masks, vocab[mask_idx], mask, simplify80) 202 | except FileNotFoundError: 203 | pass 204 | else: 205 | for mask_path in (self.path / 
self.subtype).glob(f'*.{pred_prefix}.pred.png'): 206 | mask = load_mask(str(mask_path)) 207 | word = mask_path.name.split(f'.{pred_prefix}.pred')[0].lower() 208 | _add_mask(masks, word, mask, simplify80) 209 | 210 | return masks 211 | 212 | def clear_prediction_masks(self, name: str): 213 | path = self if isinstance(self, Path) else self.path 214 | path = path / self.subtype 215 | 216 | for mask_path in path.glob(f'*.{name}.pred.png'): 217 | mask_path.unlink() 218 | 219 | def save_prediction_mask(self, mask: torch.Tensor, word: str, name: str): 220 | path = self if isinstance(self, Path) else self.path 221 | im = PIL.Image.fromarray((mask * 255).unsqueeze(-1).expand(-1, -1, 4).cpu().byte().numpy()) 222 | im.save(path / self.subtype / f'{word.lower()}.{name}.pred.png') 223 | 224 | def save_heat_map( 225 | self, 226 | word: str, 227 | tokenizer: PreTrainedTokenizer = None, 228 | crop: int = None, 229 | output_prefix: str = '', 230 | absolute: bool = False 231 | ) -> Path: 232 | from .trace import GlobalHeatMap # because of cyclical import 233 | 234 | if tokenizer is None: 235 | tokenizer = self.tokenizer 236 | 237 | with auto_autocast(dtype=torch.float32): 238 | path = self.path / self.subtype / f'{output_prefix}{word.lower()}.heat_map.png' 239 | heat_map = GlobalHeatMap(tokenizer, self.prompt, self.global_heat_map) 240 | heat_map.compute_word_heat_map(word).expand_as(self.image, color_normalize=not absolute, out_file=path, plot=True) 241 | 242 | return path 243 | 244 | def save_all_heat_maps(self, tokenizer: PreTrainedTokenizer = None, crop: int = None) -> Dict[str, Path]: 245 | path_map = {} 246 | 247 | if tokenizer is None: 248 | tokenizer = self.tokenizer 249 | 250 | for word in self.prompt.split(' '): 251 | try: 252 | path = self.save_heat_map(word, tokenizer, crop=crop) 253 | path_map[word] = path 254 | except: 255 | pass 256 | 257 | return path_map 258 | 259 | @staticmethod 260 | def contains_truth_mask(path: Union[str, Path], prompt_id: str = None) -> bool: 261 | if prompt_id is None: 262 | return any(Path(path).glob('*.gt.png')) 263 | else: 264 | return any((Path(path) / prompt_id).glob('*.gt.png')) 265 | 266 | @staticmethod 267 | def read_seed(path: Union[str, Path], prompt_id: str = None) -> int: 268 | if prompt_id is None: 269 | return int(Path(path).joinpath('seed.txt').read_text()) 270 | else: 271 | return int(Path(path).joinpath(prompt_id).joinpath('seed.txt').read_text()) 272 | 273 | @staticmethod 274 | def has_annotations(path: Union[str, Path]) -> bool: 275 | return Path(path).joinpath('annotations.json').exists() 276 | 277 | @staticmethod 278 | def has_experiment(path: Union[str, Path], prompt_id: str) -> bool: 279 | return (Path(path) / prompt_id / 'generation.pt').exists() 280 | 281 | @staticmethod 282 | def read_prompt(path: Union[str, Path], prompt_id: str = None) -> str: 283 | if prompt_id is None: 284 | prompt_id = '.' 
285 | 286 | with (Path(path) / prompt_id / 'prompt.txt').open('r') as f: 287 | return f.read().strip() 288 | 289 | def _try_load_annotations(self): 290 | if not (self.path / 'annotations.json').exists(): 291 | return None 292 | 293 | return json.load((self.path / 'annotations.json').open()) 294 | 295 | def annotate(self, key: str, value: Any) -> 'GenerationExperiment': 296 | if self.annotations is None: 297 | self.annotations = {} 298 | 299 | self.annotations[key] = value 300 | 301 | return self 302 | 303 | @classmethod 304 | def load( 305 | cls, 306 | path, 307 | pred_prefix='daam', 308 | composite=False, 309 | simplify80=False, 310 | vocab=None, 311 | subtype='.', 312 | all_subtypes=False 313 | ): 314 | # type: (str, str, bool, bool, List[str] | None, str, bool) -> GenerationExperiment | List[GenerationExperiment] 315 | if all_subtypes: 316 | experiments = [] 317 | 318 | for directory in Path(path).iterdir(): 319 | if not directory.is_dir(): 320 | continue 321 | 322 | try: 323 | experiments.append(cls.load( 324 | path, 325 | pred_prefix=pred_prefix, 326 | composite=composite, 327 | simplify80=simplify80, 328 | vocab=vocab, 329 | subtype=directory.name 330 | )) 331 | except: 332 | pass 333 | 334 | return experiments 335 | 336 | path = Path(path) 337 | exp = torch.load(path / subtype / 'generation.pt') 338 | exp.subtype = subtype 339 | exp.path = path 340 | exp.truth_masks = exp._load_truth_masks(simplify80=simplify80) 341 | exp.prediction_masks = exp._load_pred_masks(pred_prefix, composite=composite, simplify80=simplify80, vocab=vocab) 342 | exp.annotations = exp._try_load_annotations() 343 | 344 | return exp 345 | -------------------------------------------------------------------------------- /daam/heatmap.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | from dataclasses import dataclass 3 | from functools import lru_cache 4 | from pathlib import Path 5 | from typing import List, Any, Dict, Tuple, Set, Iterable 6 | 7 | from matplotlib import pyplot as plt 8 | import numpy as np 9 | import PIL.Image 10 | import spacy.tokens 11 | import torch 12 | import torch.nn.functional as F 13 | 14 | from .evaluate import compute_ioa 15 | from .utils import compute_token_merge_indices, cached_nlp, auto_autocast 16 | 17 | __all__ = ['GlobalHeatMap', 'RawHeatMapCollection', 'WordHeatMap', 'ParsedHeatMap', 'SyntacticHeatMapPair'] 18 | 19 | 20 | def plot_overlay_heat_map(im, heat_map, word=None, out_file=None, crop=None, color_normalize=True, ax=None): 21 | # type: (PIL.Image.Image | np.ndarray, torch.Tensor, str, Path, int, bool, plt.Axes) -> None 22 | if ax is None: 23 | plt.clf() 24 | plt.rcParams.update({'font.size': 24}) 25 | plt_ = plt 26 | else: 27 | plt_ = ax 28 | 29 | with auto_autocast(dtype=torch.float32): 30 | im = np.array(im) 31 | 32 | if crop is not None: 33 | heat_map = heat_map.squeeze()[crop:-crop, crop:-crop] 34 | im = im[crop:-crop, crop:-crop] 35 | 36 | if color_normalize: 37 | plt_.imshow(heat_map.squeeze().cpu().numpy(), cmap='jet') 38 | else: 39 | heat_map = heat_map.clamp_(min=0, max=1) 40 | plt_.imshow(heat_map.squeeze().cpu().numpy(), cmap='jet', vmin=0.0, vmax=1.0) 41 | 42 | im = torch.from_numpy(im).float() / 255 43 | im = torch.cat((im, (1 - heat_map.unsqueeze(-1))), dim=-1) 44 | plt_.imshow(im) 45 | 46 | if word is not None: 47 | if ax is None: 48 | plt.title(word) 49 | else: 50 | ax.set_title(word) 51 | 52 | if out_file is not None: 53 | plt.savefig(out_file) 54 | 55 | 56 | class WordHeatMap: 
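    """A single word's heat map.

    `heatmap` is a low-resolution 2D tensor over the latent space; `expand_as`
    bicubically upsamples it to the generated image's size, and `plot_overlay`
    renders it on top of that image.
    """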
57 | def __init__(self, heatmap: torch.Tensor, word: str = None, word_idx: int = None): 58 | self.word = word 59 | self.word_idx = word_idx 60 | self.heatmap = heatmap 61 | 62 | @property 63 | def value(self): 64 | return self.heatmap 65 | 66 | def plot_overlay(self, image, out_file=None, color_normalize=True, ax=None, **expand_kwargs): 67 | # type: (PIL.Image.Image | np.ndarray, Path, bool, plt.Axes, Dict[str, Any]) -> None 68 | plot_overlay_heat_map( 69 | image, 70 | self.expand_as(image, **expand_kwargs), 71 | word=self.word, 72 | out_file=out_file, 73 | color_normalize=color_normalize, 74 | ax=ax 75 | ) 76 | 77 | def expand_as(self, image, absolute=False, threshold=None, plot=False, **plot_kwargs): 78 | # type: (PIL.Image.Image, bool, float, bool, Dict[str, Any]) -> torch.Tensor 79 | im = self.heatmap.unsqueeze(0).unsqueeze(0) 80 | im = F.interpolate(im.float().detach(), size=(image.size[0], image.size[1]), mode='bicubic') 81 | 82 | if not absolute: 83 | im = (im - im.min()) / (im.max() - im.min() + 1e-8) 84 | 85 | if threshold: 86 | im = (im > threshold).float() 87 | 88 | im = im.cpu().detach().squeeze() 89 | 90 | if plot: 91 | self.plot_overlay(image, **plot_kwargs) 92 | 93 | return im 94 | 95 | def compute_ioa(self, other: 'WordHeatMap'): 96 | return compute_ioa(self.heatmap, other.heatmap) 97 | 98 | 99 | @dataclass 100 | class SyntacticHeatMapPair: 101 | head_heat_map: WordHeatMap 102 | dep_heat_map: WordHeatMap 103 | head_text: str 104 | dep_text: str 105 | relation: str 106 | 107 | 108 | @dataclass 109 | class ParsedHeatMap: 110 | word_heat_map: WordHeatMap 111 | token: spacy.tokens.Token 112 | 113 | 114 | class GlobalHeatMap: 115 | def __init__(self, tokenizer: Any, prompt: str, heat_maps: torch.Tensor): 116 | self.tokenizer = tokenizer 117 | self.heat_maps = heat_maps 118 | self.prompt = prompt 119 | self.compute_word_heat_map = lru_cache(maxsize=50)(self.compute_word_heat_map) 120 | 121 | def compute_word_heat_map(self, word: str, word_idx: int = None, offset_idx: int = 0) -> WordHeatMap: 122 | merge_idxs, word_idx = compute_token_merge_indices(self.tokenizer, self.prompt, word, word_idx, offset_idx) 123 | return WordHeatMap(self.heat_maps[merge_idxs].mean(0), word, word_idx) 124 | 125 | def parsed_heat_maps(self) -> Iterable[ParsedHeatMap]: 126 | for token in cached_nlp(self.prompt): 127 | try: 128 | heat_map = self.compute_word_heat_map(token.text) 129 | yield ParsedHeatMap(heat_map, token) 130 | except ValueError: 131 | pass 132 | 133 | def dependency_relations(self) -> Iterable[SyntacticHeatMapPair]: 134 | for token in cached_nlp(self.prompt): 135 | if token.dep_ != 'ROOT': 136 | try: 137 | dep_heat_map = self.compute_word_heat_map(token.text) 138 | head_heat_map = self.compute_word_heat_map(token.head.text) 139 | 140 | yield SyntacticHeatMapPair(head_heat_map, dep_heat_map, token.head.text, token.text, token.dep_) 141 | except ValueError: 142 | pass 143 | 144 | 145 | RawHeatMapKey = Tuple[int, int, int] # factor, layer, head 146 | 147 | 148 | class RawHeatMapCollection: 149 | def __init__(self): 150 | self.ids_to_heatmaps: Dict[RawHeatMapKey, torch.Tensor] = defaultdict(lambda: 0.0) 151 | self.ids_to_num_maps: Dict[RawHeatMapKey, int] = defaultdict(lambda: 0) 152 | 153 | def update(self, factor: int, layer_idx: int, head_idx: int, heatmap: torch.Tensor): 154 | with auto_autocast(dtype=torch.float32): 155 | key = (factor, layer_idx, head_idx) 156 | self.ids_to_heatmaps[key] = self.ids_to_heatmaps[key] + heatmap 157 | 158 | def factors(self) -> Set[int]: 159 | return 
set(key[0] for key in self.ids_to_heatmaps.keys()) 160 | 161 | def layers(self) -> Set[int]: 162 | return set(key[1] for key in self.ids_to_heatmaps.keys()) 163 | 164 | def heads(self) -> Set[int]: 165 | return set(key[2] for key in self.ids_to_heatmaps.keys()) 166 | 167 | def __iter__(self): 168 | return iter(self.ids_to_heatmaps.items()) 169 | 170 | def clear(self): 171 | self.ids_to_heatmaps.clear() 172 | self.ids_to_num_maps.clear() 173 | -------------------------------------------------------------------------------- /daam/hook.py: -------------------------------------------------------------------------------- 1 | from typing import List, Generic, TypeVar, Callable, Union, Any 2 | import functools 3 | import itertools 4 | 5 | from diffusers import UNet2DConditionModel 6 | from diffusers.models.attention_processor import Attention 7 | import torch.nn as nn 8 | 9 | 10 | __all__ = ['ObjectHooker', 'ModuleLocator', 'AggregateHooker', 'UNetCrossAttentionLocator'] 11 | 12 | 13 | ModuleType = TypeVar('ModuleType') 14 | ModuleListType = TypeVar('ModuleListType', bound=List) 15 | 16 | 17 | class ModuleLocator(Generic[ModuleType]): 18 | def locate(self, model: nn.Module) -> List[ModuleType]: 19 | raise NotImplementedError 20 | 21 | 22 | class ObjectHooker(Generic[ModuleType]): 23 | def __init__(self, module: ModuleType): 24 | self.module: ModuleType = module 25 | self.hooked = False 26 | self.old_state = dict() 27 | 28 | def __enter__(self): 29 | self.hook() 30 | return self 31 | 32 | def __exit__(self, exc_type, exc_val, exc_tb): 33 | self.unhook() 34 | 35 | def hook(self): 36 | if self.hooked: 37 | raise RuntimeError('Already hooked module') 38 | 39 | self.old_state = dict() 40 | self.hooked = True 41 | self._hook_impl() 42 | 43 | return self 44 | 45 | def unhook(self): 46 | if not self.hooked: 47 | raise RuntimeError('Module is not hooked') 48 | 49 | for k, v in self.old_state.items(): 50 | if k.startswith('old_fn_'): 51 | setattr(self.module, k[7:], v) 52 | 53 | self.hooked = False 54 | self._unhook_impl() 55 | 56 | return self 57 | 58 | def monkey_patch(self, fn_name, fn, strict: bool = True): 59 | try: 60 | self.old_state[f'old_fn_{fn_name}'] = getattr(self.module, fn_name) 61 | setattr(self.module, fn_name, functools.partial(fn, self.module)) 62 | except AttributeError: 63 | if strict: 64 | raise 65 | 66 | def monkey_super(self, fn_name, *args, **kwargs): 67 | return self.old_state[f'old_fn_{fn_name}'](*args, **kwargs) 68 | 69 | def _hook_impl(self): 70 | raise NotImplementedError 71 | 72 | def _unhook_impl(self): 73 | pass 74 | 75 | 76 | class AggregateHooker(ObjectHooker[ModuleListType]): 77 | def _hook_impl(self): 78 | for h in self.module: 79 | h.hook() 80 | 81 | def _unhook_impl(self): 82 | for h in self.module: 83 | h.unhook() 84 | 85 | def register_hook(self, hook: ObjectHooker): 86 | self.module.append(hook) 87 | 88 | 89 | class UNetCrossAttentionLocator(ModuleLocator[Attention]): 90 | def __init__(self, restrict: bool = None, locate_middle_block: bool = False): 91 | self.restrict = restrict 92 | self.layer_names = [] 93 | self.locate_middle_block = locate_middle_block 94 | 95 | def locate(self, model: UNet2DConditionModel) -> List[Attention]: 96 | """ 97 | Locate all cross-attention modules in a UNet2DConditionModel. 98 | 99 | Args: 100 | model (`UNet2DConditionModel`): The model to locate the cross-attention modules in. 101 | 102 | Returns: 103 | `List[Attention]`: The list of cross-attention modules. 
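        Example:
            A hypothetical sketch -- `pipeline` here stands for any loaded `StableDiffusionPipeline`:

            >>> locator = UNetCrossAttentionLocator()
            >>> cross_attentions = locator.locate(pipeline.unet)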
104 | """ 105 | self.layer_names.clear() 106 | blocks_list = [] 107 | up_names = ['up'] * len(model.up_blocks) 108 | down_names = ['down'] * len(model.down_blocks) 109 | 110 | for unet_block, name in itertools.chain( 111 | zip(model.up_blocks, up_names), 112 | zip(model.down_blocks, down_names), 113 | zip([model.mid_block], ['mid']) if self.locate_middle_block else [], 114 | ): 115 | if 'CrossAttn' in unet_block.__class__.__name__: 116 | blocks = [] 117 | 118 | for spatial_transformer in unet_block.attentions: 119 | for transformer_block in spatial_transformer.transformer_blocks: 120 | blocks.append(transformer_block.attn2) 121 | 122 | blocks = [b for idx, b in enumerate(blocks) if self.restrict is None or idx in self.restrict] 123 | names = [f'{name}-attn-{i}' for i in range(len(blocks)) if self.restrict is None or i in self.restrict] 124 | blocks_list.extend(blocks) 125 | self.layer_names.extend(names) 126 | 127 | return blocks_list 128 | -------------------------------------------------------------------------------- /daam/run/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/castorini/daam/c30493ed0154bfccb6c342400f25cc24599bb1ff/daam/run/__init__.py -------------------------------------------------------------------------------- /daam/run/demo.py: -------------------------------------------------------------------------------- 1 | import math 2 | import time 3 | from threading import Lock 4 | from typing import Any, List 5 | import argparse 6 | 7 | import numpy as np 8 | from diffusers import StableDiffusionPipeline 9 | from matplotlib import pyplot as plt 10 | import gradio as gr 11 | import torch 12 | from spacy import displacy 13 | 14 | from daam import trace 15 | from daam.utils import set_seed, cached_nlp, auto_autocast 16 | 17 | 18 | def dependency(text): 19 | doc = cached_nlp(text) 20 | svg = displacy.render(doc, style='dep', options={'compact': True, 'distance': 100}) 21 | 22 | return svg 23 | 24 | 25 | def get_tokenizing_mapping(prompt: str, tokenizer: Any) -> List[List[int]]: 26 | tokens = tokenizer.tokenize(prompt) 27 | merge_idxs = [] 28 | words = [] 29 | curr_idxs = [] 30 | curr_word = '' 31 | 32 | for i, token in enumerate(tokens): 33 | curr_idxs.append(i + 1) # because of the [CLS] token 34 | curr_word += token 35 | if '' in token: 36 | merge_idxs.append(curr_idxs) 37 | curr_idxs = [] 38 | words.append(curr_word[:-4]) 39 | curr_word = '' 40 | 41 | return merge_idxs, words 42 | 43 | 44 | def get_args(): 45 | model_id_map = { 46 | 'v1': 'runwayml/stable-diffusion-v1-5', 47 | 'v2-base': 'stabilityai/stable-diffusion-2-base', 48 | 'v2-large': 'stabilityai/stable-diffusion-2', 49 | 'v2-1-base': 'stabilityai/stable-diffusion-2-1-base', 50 | 'v2-1-large': 'stabilityai/stable-diffusion-2-1', 51 | } 52 | 53 | parser = argparse.ArgumentParser() 54 | parser.add_argument('--model', '-m', type=str, default='v2-1-base', choices=list(model_id_map.keys()), help="which diffusion model to use") 55 | parser.add_argument('--seed', '-s', type=int, default=0, help="the random seed") 56 | parser.add_argument('--port', '-p', type=int, default=8080, help="the port to launch the demo") 57 | parser.add_argument('--no-cuda', action='store_true', help="Use CPUs instead of GPUs") 58 | args = parser.parse_args() 59 | args.model = model_id_map[args.model] 60 | return args 61 | 62 | 63 | def main(): 64 | args = get_args() 65 | plt.switch_backend('agg') 66 | 67 | device = "cpu" if args.no_cuda else "cuda" 68 | pipe = 
StableDiffusionPipeline.from_pretrained(args.model, use_auth_token=True).to(device) 69 | lock = Lock() 70 | 71 | @torch.no_grad() 72 | def update_dropdown(prompt): 73 | tokens = [''] + [x.text for x in cached_nlp(prompt) if x.pos_ == 'ADJ'] 74 | return gr.Dropdown.update(choices=tokens), dependency(prompt) 75 | 76 | @torch.no_grad() 77 | def plot(prompt, choice, replaced_word, inf_steps, is_random_seed): 78 | new_prompt = prompt.replace(',', ', ').replace('.', '. ') 79 | 80 | if choice: 81 | if not replaced_word: 82 | replaced_word = '.' 83 | 84 | new_prompt = [replaced_word if tok.text == choice else tok.text for tok in cached_nlp(prompt)] 85 | new_prompt = ' '.join(new_prompt) 86 | 87 | merge_idxs, words = get_tokenizing_mapping(prompt, pipe.tokenizer) 88 | with auto_autocast(dtype=torch.float16), lock: 89 | try: 90 | plt.close('all') 91 | plt.clf() 92 | except: 93 | pass 94 | 95 | seed = int(time.time()) if is_random_seed else args.seed 96 | gen = set_seed(seed) 97 | prompt = prompt.replace(',', ', ').replace('.', '. ') # hacky fix to address later 98 | 99 | if choice: 100 | new_prompt = new_prompt.replace(',', ', ').replace('.', '. ') # hacky fix to address later 101 | 102 | with trace(pipe, save_heads=new_prompt != prompt) as tc: 103 | out = pipe(prompt, num_inference_steps=inf_steps, generator=gen) 104 | image = np.array(out.images[0]) / 255 105 | heat_map = tc.compute_global_heat_map() 106 | 107 | if new_prompt == prompt: 108 | image2 = image 109 | else: 110 | gen = set_seed(seed) 111 | 112 | with trace(pipe, load_heads=True) as tc: 113 | out2 = pipe(new_prompt, num_inference_steps=inf_steps, generator=gen) 114 | image2 = np.array(out2.images[0]) / 255 115 | else: 116 | with trace(pipe) as tc: 117 | out = pipe(prompt, num_inference_steps=inf_steps, generator=gen) 118 | image = np.array(out.images[0]) / 255 119 | heat_map = tc.compute_global_heat_map() 120 | 121 | # the main image 122 | if new_prompt == prompt: 123 | fig, ax = plt.subplots() 124 | ax.imshow(image) 125 | ax.set_xticks([]) 126 | ax.set_yticks([]) 127 | else: 128 | fig, ax = plt.subplots(1, 2) 129 | ax[0].imshow(image) 130 | 131 | if choice: 132 | ax[1].imshow(image2) 133 | 134 | ax[0].set_title(choice) 135 | ax[0].set_xticks([]) 136 | ax[0].set_yticks([]) 137 | ax[1].set_title(replaced_word) 138 | ax[1].set_xticks([]) 139 | ax[1].set_yticks([]) 140 | 141 | # the heat maps 142 | num_cells = 4 143 | w = int(num_cells * 3.5) 144 | h = math.ceil(len(words) / num_cells * 4.5) 145 | fig_soft, axs_soft = plt.subplots(math.ceil(len(words) / num_cells), num_cells, figsize=(w, h)) 146 | axs_soft = axs_soft.flatten() 147 | with torch.cuda.amp.autocast(dtype=torch.float32): 148 | for idx, parsed_map in enumerate(heat_map.parsed_heat_maps()): 149 | word_ax_soft = axs_soft[idx] 150 | word_ax_soft.set_xticks([]) 151 | word_ax_soft.set_yticks([]) 152 | parsed_map.word_heat_map.plot_overlay(out.images[0], ax=word_ax_soft) 153 | word_ax_soft.set_title(parsed_map.word_heat_map.word, fontsize=12) 154 | 155 | for idx in range(len(words), len(axs_soft)): 156 | fig_soft.delaxes(axs_soft[idx]) 157 | 158 | return fig, fig_soft 159 | 160 | with gr.Blocks(css='scrollbar.css') as demo: 161 | md = '''# DAAM: Attention Maps for Interpreting Stable Diffusion 162 | Check out the paper: [What the DAAM: Interpreting Stable Diffusion Using Cross Attention](http://arxiv.org/abs/2210.04885). 163 | See our (much cleaner) [DAAM codebase](https://github.com/castorini/daam) on GitHub. 
164 | ''' 165 | gr.Markdown(md) 166 | 167 | with gr.Row(): 168 | with gr.Column(): 169 | dropdown = gr.Dropdown([ 170 | 'An angry, bald man doing research', 171 | 'A bear and a moose', 172 | 'A blue car driving through the city', 173 | 'Monkey walking with hat', 174 | 'Doing research at Comcast Applied AI labs', 175 | 'Professor Jimmy Lin from the modern University of Waterloo', 176 | 'Yann Lecun teaching machine learning on a green chalkboard', 177 | 'A brown cat eating yummy cake for her birthday', 178 | 'A brown fox, a white dog, and a blue wolf in a green field', 179 | ], label='Examples', value='An angry, bald man doing research') 180 | 181 | text = gr.Textbox(label='Prompt', value='An angry, bald man doing research') 182 | 183 | with gr.Row(): 184 | doc = cached_nlp('An angry, bald man doing research') 185 | tokens = [''] + [x.text for x in doc if x.pos_ == 'ADJ'] 186 | dropdown2 = gr.Dropdown(tokens, label='Adjective to replace', interactive=True) 187 | text2 = gr.Textbox(label='New adjective', value='') 188 | 189 | checkbox = gr.Checkbox(value=False, label='Random seed') 190 | slider1 = gr.Slider(15, 30, value=25, interactive=True, step=1, label='Inference steps') 191 | 192 | submit_btn = gr.Button('Submit', elem_id='submit-btn') 193 | viz = gr.HTML(dependency('An angry, bald man doing research'), elem_id='viz') 194 | 195 | with gr.Column(): 196 | with gr.Tab('Images'): 197 | p0 = gr.Plot() 198 | 199 | with gr.Tab('DAAM Maps'): 200 | p1 = gr.Plot() 201 | 202 | text.change(fn=update_dropdown, inputs=[text], outputs=[dropdown2, viz]) 203 | 204 | submit_btn.click( 205 | fn=plot, 206 | inputs=[text, dropdown2, text2, slider1, checkbox], 207 | outputs=[p0, p1]) 208 | dropdown.change(lambda prompt: prompt, dropdown, text) 209 | dropdown.update() 210 | 211 | while True: 212 | try: 213 | demo.launch(server_name='0.0.0.0', server_port=args.port) 214 | except OSError: 215 | gr.close_all() 216 | except KeyboardInterrupt: 217 | gr.close_all() 218 | break 219 | 220 | 221 | if __name__ == '__main__': 222 | main() 223 | -------------------------------------------------------------------------------- /daam/run/evaluate.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | import argparse 3 | 4 | from tqdm import tqdm 5 | 6 | from daam.evaluate import MeanEvaluator, UnsupervisedEvaluator 7 | from daam.experiment import GenerationExperiment, COCOSTUFF27_LABELS, COCO80_LABELS 8 | 9 | 10 | def main(): 11 | parser = argparse.ArgumentParser() 12 | parser.add_argument('--input-folder', '-i', type=str, required=True) 13 | parser.add_argument('--pred-prefix', '-p', type=str, default='daam') 14 | parser.add_argument('--mask-type', '-m', type=str, default='word', choices=['word', 'composite']) 15 | parser.add_argument('--eval-type', '-e', type=str, default='labeled', choices=['labeled', 'unlabeled', 'hungarian']) 16 | parser.add_argument('--restrict-set', '-r', type=str, default='none', choices=['none', 'coco27', 'coco80']) 17 | parser.add_argument('--subtype', '-st', type=str, default='.') 18 | args = parser.parse_args() 19 | 20 | evaluator = MeanEvaluator() if args.eval_type != 'hungarian' else UnsupervisedEvaluator() 21 | simplify80 = False 22 | vocab = [] 23 | 24 | if args.restrict_set == 'coco27': 25 | simplify80 = True 26 | vocab = COCOSTUFF27_LABELS 27 | elif args.restrict_set == 'coco80': 28 | vocab = COCO80_LABELS 29 | 30 | if not vocab: 31 | for path in tqdm(Path(args.input_folder).glob('*')): 32 | if not path.is_dir() or not 
GenerationExperiment.contains_truth_mask(path): 33 | continue 34 | 35 | exp = GenerationExperiment.load( 36 | path, 37 | args.pred_prefix, 38 | composite=args.mask_type == 'composite', 39 | simplify80=simplify80 40 | ) 41 | 42 | vocab.extend(exp.truth_masks) 43 | vocab.extend(exp.prediction_masks) 44 | 45 | vocab = list(set(vocab)) 46 | vocab.sort() 47 | 48 | for path in tqdm(Path(args.input_folder).glob('*')): 49 | if not path.is_dir() or not GenerationExperiment.contains_truth_mask(path): 50 | continue 51 | 52 | exp = GenerationExperiment.load( 53 | path, 54 | args.pred_prefix, 55 | composite=args.mask_type == 'composite', 56 | simplify80=simplify80, 57 | vocab=vocab, 58 | subtype=args.subtype 59 | ) 60 | 61 | if args.eval_type == 'labeled': 62 | for word, mask in exp.truth_masks.items(): 63 | if word not in vocab and args.restrict_set != 'none': 64 | continue 65 | 66 | try: 67 | evaluator.log_iou(exp.prediction_masks[word], mask) 68 | evaluator.log_intensity(exp.prediction_masks[word]) 69 | except KeyError: 70 | continue 71 | elif args.eval_type == 'hungarian': 72 | for gt_word, gt_mask in exp.truth_masks.items(): 73 | if gt_word not in vocab and args.restrict_set != 'none': 74 | continue 75 | 76 | for pred_word, pred_mask in exp.prediction_masks.items(): 77 | try: 78 | evaluator.log_iou(pred_mask, gt_mask, vocab.index(gt_word), vocab.index(pred_word)) 79 | except (KeyError, ValueError): 80 | continue 81 | 82 | evaluator.increment() 83 | else: 84 | for word, mask in exp.truth_masks.items(): 85 | evaluator.log_iou(list(exp.prediction_masks.values()), mask) 86 | 87 | print(evaluator) 88 | 89 | 90 | if __name__ == '__main__': 91 | main() 92 | -------------------------------------------------------------------------------- /daam/run/generate.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | from pathlib import Path 3 | import argparse 4 | import json 5 | import random 6 | import sys 7 | import time 8 | 9 | import pandas as pd 10 | from diffusers import StableDiffusionPipeline, DiffusionPipeline 11 | from tqdm import tqdm 12 | import inflect 13 | import numpy as np 14 | import torch 15 | 16 | from daam import trace 17 | from daam.experiment import GenerationExperiment, build_word_list_coco80 18 | from daam.utils import set_seed, cached_nlp, auto_device, auto_autocast 19 | 20 | 21 | def main(): 22 | actions = ['quickgen', 'prompt', 'coco', 'template', 'cconj', 'coco-unreal', 'stdin', 'regenerate'] 23 | model_id_map = { 24 | 'v1': 'runwayml/stable-diffusion-v1-5', 25 | 'v2-base': 'stabilityai/stable-diffusion-2-base', 26 | 'v2-large': 'stabilityai/stable-diffusion-2', 27 | 'v2-1-base': 'stabilityai/stable-diffusion-2-1-base', 28 | 'v2-1-large': 'stabilityai/stable-diffusion-2-1', 29 | 'xl-base-1.0': 'stabilityai/stable-diffusion-xl-base-1.0', 30 | } 31 | 32 | parser = argparse.ArgumentParser() 33 | parser.add_argument('prompt', nargs='?', type=str) 34 | parser.add_argument('--action', '-a', type=str, choices=actions, default=actions[0]) 35 | parser.add_argument('--low-memory', action='store_true') 36 | parser.add_argument('--model', type=str, default='v2-1-base', choices=list(model_id_map.keys())) 37 | parser.add_argument('--output-folder', '-o', type=str) 38 | parser.add_argument('--input-folder', '-i', type=str, default='input') 39 | parser.add_argument('--seed', '-s', type=int, default=0) 40 | parser.add_argument('--gen-limit', type=int, default=1000) 41 | parser.add_argument('--template', type=str, 
default='{numeral} {noun}') 42 | parser.add_argument('--template-data-file', '-tdf', type=str, default='template.tsv') 43 | parser.add_argument('--seed-offset', type=int, default=0) 44 | parser.add_argument('--num-timesteps', '-n', type=int, default=30) 45 | parser.add_argument('--all-heads', action='store_true') 46 | parser.add_argument('--word', type=str) 47 | parser.add_argument('--random-seed', action='store_true') 48 | parser.add_argument('--truth-only', action='store_true') 49 | parser.add_argument('--save-heads', action='store_true') 50 | parser.add_argument('--load-heads', action='store_true') 51 | args = parser.parse_args() 52 | 53 | eng = inflect.engine() 54 | args.lemma = cached_nlp(args.word)[0].lemma_ if args.word else None 55 | model_id = model_id_map[args.model] 56 | seeds = [] 57 | 58 | if args.action.startswith('coco'): 59 | with (Path(args.input_folder) / 'captions_val2014.json').open() as f: 60 | captions = json.load(f)['annotations'] 61 | 62 | random.shuffle(captions) 63 | new_captions = [] 64 | 65 | if args.action == 'coco-unreal': 66 | pos_map = defaultdict(list) 67 | 68 | for caption in tqdm(captions): 69 | doc = cached_nlp(caption['caption']) 70 | 71 | for tok in doc: 72 | if tok.pos_ == 'ADJ' or tok.pos_ == 'NOUN': 73 | pos_map[tok.pos_].append(tok) 74 | 75 | for caption in tqdm(captions): 76 | doc = cached_nlp(caption['caption']) 77 | new_tokens = [] 78 | 79 | for tok in doc: 80 | if tok.pos_ == 'ADJ' or tok.pos_ == 'NOUN': 81 | new_tokens.append(random.choice(pos_map[tok.pos_])) 82 | 83 | new_prompt = ''.join([tok.text_with_ws for tok in new_tokens]) 84 | caption['caption'] = new_prompt 85 | 86 | print(new_prompt) 87 | 88 | new_prompt = ''.join([tok.text_with_ws for tok in new_tokens]) 89 | caption['caption'] = new_prompt 90 | 91 | print(new_prompt) 92 | new_captions.append(caption) 93 | 94 | prompts = [(caption['id'], caption['caption']) for caption in captions] 95 | elif args.action == 'stdin': 96 | prompts = [] 97 | 98 | for idx, line in enumerate(sys.stdin): 99 | prompts.append((idx, line.strip())) 100 | elif args.action == 'template': 101 | template_df = pd.read_csv(args.template_data_file, sep='\t', quoting=3) 102 | sample_dict = defaultdict(list) 103 | 104 | for name, df in template_df.groupby('pos'): 105 | sample_dict[name].extend(df['word'].tolist()) 106 | 107 | prompts = [] 108 | template_words = args.template.split() 109 | plural_numerals = {'0', '2', '3', '4', '5', '6', '7', '8', '9', 'zero', 'two', 'three', 'four', 'five', 'six', 'seven', 'eight', 'nine'} 110 | 111 | for prompt_id in range(args.gen_limit): 112 | words = [] 113 | pluralize = False 114 | 115 | for word in template_words: 116 | if word.startswith('{'): 117 | pos = word[1:-1] 118 | word = random.choice(sample_dict[pos]) 119 | 120 | if pos == 'noun' and pluralize: 121 | word = eng.plural(word) 122 | 123 | words.append(word) 124 | pluralize = word in plural_numerals 125 | 126 | prompt_id = str(prompt_id) 127 | prompts.append((prompt_id, ' '.join(words))) 128 | tqdm.write(str(prompts[-1])) 129 | elif args.action == 'cconj': 130 | words_map = build_word_list_coco80() 131 | prompts = [] 132 | 133 | for idx in range(args.gen_limit): 134 | use_cohyponym = random.random() < 0.5 135 | 136 | if use_cohyponym: 137 | c = random.choice(list(words_map.keys())) 138 | w1, w2 = np.random.choice(words_map[c], 2, replace=False) 139 | else: 140 | c1, c2 = np.random.choice(list(words_map.keys()), 2, replace=False) 141 | w1 = random.choice(words_map[c1]) 142 | w2 = random.choice(words_map[c2]) 143 | 144 | 
prompt_id = f'{"cohypo" if use_cohyponym else "diff"}-{idx}' 145 | a1 = 'an' if w1[0] in 'aeiou' else 'a' 146 | a2 = 'an' if w2[0] in 'aeiou' else 'a' 147 | prompt = f'{a1} {w1} and {a2} {w2}' 148 | prompts.append((prompt_id, prompt)) 149 | elif args.action == 'quickgen': 150 | if args.output_folder is None: 151 | args.output_folder = '.' 152 | 153 | prompts = [('.', args.prompt)] 154 | elif args.action == 'regenerate': 155 | prompts = [] 156 | 157 | for exp_folder in Path(args.input_folder).iterdir(): 158 | if not GenerationExperiment.contains_truth_mask(exp_folder) and args.truth_only: 159 | continue 160 | 161 | prompt = GenerationExperiment.read_prompt(exp_folder) 162 | prompts.append((exp_folder.name, prompt)) 163 | seeds.append(GenerationExperiment.read_seed(exp_folder)) 164 | 165 | if args.output_folder is None: 166 | args.output_folder = args.input_folder 167 | else: 168 | prompts = [('prompt', input('> '))] 169 | 170 | if args.output_folder is None: 171 | args.output_folder = 'output' 172 | 173 | new_prompts = [] 174 | 175 | if args.lemma is not None: 176 | for prompt_id, prompt in tqdm(prompts): 177 | if args.lemma not in prompt.lower(): 178 | continue 179 | 180 | doc = cached_nlp(prompt) 181 | found = False 182 | 183 | for tok in doc: 184 | if tok.lemma_.lower() == args.lemma and not found: 185 | found = True 186 | elif tok.lemma_.lower() == args.lemma: # filter out prompts with multiple instances of the word 187 | found = False 188 | break 189 | 190 | if found: 191 | new_prompts.append((prompt_id, prompt)) 192 | 193 | prompts = new_prompts 194 | 195 | prompts = prompts[:args.gen_limit] 196 | 197 | if 'xl' in model_id: 198 | pipe = DiffusionPipeline.from_pretrained( 199 | model_id, 200 | use_auth_token=True, 201 | torch_dtype=torch.float16, 202 | use_safetensors=True, variant='fp16' 203 | ) 204 | else: 205 | pipe = StableDiffusionPipeline.from_pretrained(model_id, use_auth_token=True) 206 | 207 | pipe = auto_device(pipe) 208 | 209 | with torch.no_grad(): 210 | for gen_idx, (prompt_id, prompt) in enumerate(tqdm(prompts)): 211 | seed = int(time.time()) if args.random_seed else args.seed 212 | prompt = prompt.replace(',', ' ,').replace('.', ' .').strip() 213 | 214 | if seeds and gen_idx < len(seeds): 215 | seed = seeds[gen_idx] 216 | 217 | gen = set_seed(seed) 218 | 219 | if args.action == 'cconj': 220 | seed = int(prompt_id.split('-')[1]) + args.seed_offset 221 | gen = set_seed(seed) 222 | 223 | prompt_id = str(prompt_id) 224 | 225 | with trace(pipe, low_memory=args.low_memory, save_heads=args.save_heads, load_heads=args.load_heads) as tc: 226 | out = pipe(prompt, num_inference_steps=args.num_timesteps, generator=gen, callback=tc.time_callback) 227 | exp = tc.to_experiment(args.output_folder, id=prompt_id, seed=seed) 228 | exp.save(args.output_folder, heat_maps=args.action == 'quickgen') 229 | 230 | if args.all_heads: 231 | exp.clear_checkpoint() 232 | 233 | for word in prompt.split(): 234 | if args.lemma is not None and cached_nlp(word)[0].lemma_.lower() != args.lemma: 235 | continue 236 | 237 | exp.save_heat_map(word) 238 | 239 | if args.all_heads: 240 | for head_idx in range(16): 241 | for layer_idx, layer_name in enumerate(tc.layer_names): 242 | try: 243 | heat_map = tc.compute_global_heat_map(layer_idx=layer_idx, head_idx=head_idx) 244 | exp = GenerationExperiment( 245 | path=Path(args.output_folder), 246 | id=prompt_id, 247 | global_heat_map=heat_map.heat_maps, 248 | seed=seed, 249 | prompt=prompt, 250 | image=out.images[0] 251 | ) 252 | 253 | exp.save_heat_map(word, 
output_prefix=f'l{layer_idx}-{layer_name}-h{head_idx}-') 254 | except RuntimeError: 255 | print(f'Missing ({layer_idx}, {head_idx}, {layer_name})') 256 | 257 | 258 | if __name__ == '__main__': 259 | main() 260 | -------------------------------------------------------------------------------- /daam/trace.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from typing import List, Type, Any, Dict, Tuple, Union 3 | import math 4 | 5 | from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline 6 | from diffusers.image_processor import VaeImageProcessor 7 | from diffusers.models.attention_processor import Attention 8 | import numpy as np 9 | import PIL.Image as Image 10 | import torch 11 | import torch.nn.functional as F 12 | 13 | from .utils import cache_dir, auto_autocast 14 | from .experiment import GenerationExperiment 15 | from .heatmap import RawHeatMapCollection, GlobalHeatMap 16 | from .hook import ObjectHooker, AggregateHooker, UNetCrossAttentionLocator 17 | 18 | 19 | __all__ = ['trace', 'DiffusionHeatMapHooker', 'GlobalHeatMap'] 20 | 21 | 22 | class DiffusionHeatMapHooker(AggregateHooker): 23 | def __init__( 24 | self, 25 | pipeline: Union[StableDiffusionPipeline, StableDiffusionXLPipeline], 26 | low_memory: bool = False, 27 | load_heads: bool = False, 28 | save_heads: bool = False, 29 | data_dir: str = None 30 | ): 31 | self.all_heat_maps = RawHeatMapCollection() 32 | h = (pipeline.unet.config.sample_size * pipeline.vae_scale_factor) 33 | self.latent_hw = 4096 if h == 512 or h == 1024 else 9216 # 64x64 or 96x96, depending on whether the model is 2.0-v or 2.0 34 | locate_middle = load_heads or save_heads 35 | self.locator = UNetCrossAttentionLocator(restrict={0} if low_memory else None, locate_middle_block=locate_middle) 36 | self.last_prompt: str = '' 37 | self.last_image: Image.Image = None 38 | self.time_idx = 0 39 | self._gen_idx = 0 40 | 41 | modules = [ 42 | UNetCrossAttentionHooker( 43 | x, 44 | self, 45 | layer_idx=idx, 46 | latent_hw=self.latent_hw, 47 | load_heads=load_heads, 48 | save_heads=save_heads, 49 | data_dir=data_dir 50 | ) for idx, x in enumerate(self.locator.locate(pipeline.unet)) 51 | ] 52 | 53 | modules.append(PipelineHooker(pipeline, self)) 54 | 55 | if type(pipeline) == StableDiffusionXLPipeline: 56 | modules.append(ImageProcessorHooker(pipeline.image_processor, self)) 57 | 58 | super().__init__(modules) 59 | self.pipe = pipeline 60 | 61 | def time_callback(self, *args, **kwargs): 62 | self.time_idx += 1 63 | 64 | @property 65 | def layer_names(self): 66 | return self.locator.layer_names 67 | 68 | def to_experiment(self, path, seed=None, id='.', subtype='.', **compute_kwargs): 69 | # type: (Union[Path, str], int, str, str, Dict[str, Any]) -> GenerationExperiment 70 | """Exports the last generation call to a serializable generation experiment.""" 71 | 72 | return GenerationExperiment( 73 | self.last_image, 74 | self.compute_global_heat_map(**compute_kwargs).heat_maps, 75 | self.last_prompt, 76 | seed=seed, 77 | id=id, 78 | subtype=subtype, 79 | path=path, 80 | tokenizer=self.pipe.tokenizer, 81 | ) 82 | 83 | def compute_global_heat_map(self, prompt=None, factors=None, head_idx=None, layer_idx=None, normalize=False): 84 | # type: (str, List[float], int, int, bool) -> GlobalHeatMap 85 | """ 86 | Compute the global heat map for the given prompt, aggregating across time (inference steps) and space (different 87 | spatial transformer block heat maps).
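        A minimal usage sketch (illustrative; it assumes a loaded `StableDiffusionPipeline` named
        `pipe` and uses the module-level `trace` alias defined at the bottom of this file):

            with torch.no_grad(), trace(pipe) as tc:
                out = pipe('a dog beside a gray cat', num_inference_steps=30)
                global_map = tc.compute_global_heat_map()

        The returned `GlobalHeatMap` can then be queried for per-word heat maps.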
88 | 89 | Args: 90 | prompt: The prompt to compute the heat map for. If `None`, uses the last prompt that was used for generation. 91 | factors: Restrict the application to heat maps with spatial factors in this set. If `None`, use all sizes. 92 | head_idx: Restrict the application to heat maps with this head index. If `None`, use all heads. 93 | layer_idx: Restrict the application to heat maps with this layer index. If `None`, use all layers. normalize: If `True`, normalize over the content tokens (excluding [SOS] and [PAD]) so that each spatial location forms a probability distribution. Defaults to `False`. 94 | 95 | Returns: 96 | A heat map object for computing word-level heat maps. 97 | """ 98 | heat_maps = self.all_heat_maps 99 | 100 | if prompt is None: 101 | prompt = self.last_prompt 102 | 103 | if factors is None: 104 | factors = {0, 1, 2, 4, 8, 16, 32, 64} 105 | else: 106 | factors = set(factors) 107 | 108 | all_merges = [] 109 | x = int(np.sqrt(self.latent_hw)) 110 | 111 | with auto_autocast(dtype=torch.float32): 112 | for (factor, layer, head), heat_map in heat_maps: 113 | if factor in factors and (head_idx is None or head_idx == head) and (layer_idx is None or layer_idx == layer): 114 | heat_map = heat_map.unsqueeze(1) 115 | # The clamping fixes undershoot. 116 | all_merges.append(F.interpolate(heat_map, size=(x, x), mode='bicubic').clamp_(min=0)) 117 | 118 | try: 119 | maps = torch.stack(all_merges, dim=0) 120 | except RuntimeError: 121 | if head_idx is not None or layer_idx is not None: 122 | raise RuntimeError('No heat maps found for the given parameters.') 123 | else: 124 | raise RuntimeError('No heat maps found. Did you forget to call `with trace(...)` during generation?') 125 | 126 | maps = maps.mean(0)[:, 0] 127 | maps = maps[:len(self.pipe.tokenizer.tokenize(prompt)) + 2] # 1 for SOS and 1 for padding 128 | 129 | if normalize: 130 | maps = maps / (maps[1:-1].sum(0, keepdim=True) + 1e-6) # drop out [SOS] and [PAD] for proper probabilities 131 | 132 | return GlobalHeatMap(self.pipe.tokenizer, prompt, maps) 133 | 134 | 135 | class ImageProcessorHooker(ObjectHooker[VaeImageProcessor]): 136 | def __init__(self, processor: VaeImageProcessor, parent_trace: 'trace'): 137 | super().__init__(processor) 138 | self.parent_trace = parent_trace 139 | 140 | def _hooked_postprocess(hk_self, _: VaeImageProcessor, *args, **kwargs): 141 | images = hk_self.monkey_super('postprocess', *args, **kwargs) 142 | hk_self.parent_trace.last_image = images[0] 143 | 144 | return images 145 | 146 | def _hook_impl(self): 147 | self.monkey_patch('postprocess', self._hooked_postprocess) 148 | 149 | 150 | class PipelineHooker(ObjectHooker[StableDiffusionPipeline]): 151 | def __init__(self, pipeline: StableDiffusionPipeline, parent_trace: 'trace'): 152 | super().__init__(pipeline) 153 | self.heat_maps = parent_trace.all_heat_maps 154 | self.parent_trace = parent_trace 155 | 156 | def _hooked_run_safety_checker(hk_self, self: StableDiffusionPipeline, image, *args, **kwargs): 157 | image, has_nsfw = hk_self.monkey_super('run_safety_checker', image, *args, **kwargs) 158 | 159 | if self.image_processor: 160 | if torch.is_tensor(image): 161 | images = self.image_processor.postprocess(image, output_type='pil') 162 | else: 163 | images = self.image_processor.numpy_to_pil(image) 164 | else: 165 | images = self.numpy_to_pil(image) 166 | 167 | hk_self.parent_trace.last_image = images[-1] 168 | 169 | return image, has_nsfw 170 | 171 | def _hooked_check_inputs(hk_self, _: StableDiffusionPipeline, prompt: Union[str, List[str]], *args, **kwargs): 172 | if not isinstance(prompt, str) and len(prompt) > 1: 173 | raise ValueError('Only single prompt generation is supported for
heat map computation.') 174 | elif not isinstance(prompt, str): 175 | last_prompt = prompt[0] 176 | else: 177 | last_prompt = prompt 178 | 179 | hk_self.heat_maps.clear() 180 | hk_self.parent_trace.last_prompt = last_prompt 181 | 182 | return hk_self.monkey_super('check_inputs', prompt, *args, **kwargs) 183 | 184 | def _hook_impl(self): 185 | self.monkey_patch('run_safety_checker', self._hooked_run_safety_checker, strict=False) # not present in SDXL 186 | self.monkey_patch('check_inputs', self._hooked_check_inputs) 187 | 188 | 189 | class UNetCrossAttentionHooker(ObjectHooker[Attention]): 190 | def __init__( 191 | self, 192 | module: Attention, 193 | parent_trace: 'trace', 194 | context_size: int = 77, 195 | layer_idx: int = 0, 196 | latent_hw: int = 9216, 197 | load_heads: bool = False, 198 | save_heads: bool = False, 199 | data_dir: Union[str, Path] = None, 200 | ): 201 | super().__init__(module) 202 | self.heat_maps = parent_trace.all_heat_maps 203 | self.context_size = context_size 204 | self.layer_idx = layer_idx 205 | self.latent_hw = latent_hw 206 | 207 | self.load_heads = load_heads 208 | self.save_heads = save_heads 209 | self.trace = parent_trace 210 | 211 | if data_dir is not None: 212 | data_dir = Path(data_dir) 213 | else: 214 | data_dir = cache_dir() / 'heads' 215 | 216 | self.data_dir = data_dir 217 | self.data_dir.mkdir(parents=True, exist_ok=True) 218 | 219 | @torch.no_grad() 220 | def _unravel_attn(self, x): 221 | # type: (torch.Tensor) -> torch.Tensor 222 | # x shape: (heads, height * width, tokens) 223 | """ 224 | Unravels the attention, returning it as a collection of heat maps. 225 | 226 | Args: 227 | x (`torch.Tensor`): the cross-attention map between the latent pixels and the 228 | prompt tokens, with shape (heads, height * width, tokens). 229 | 230 | Returns: 231 | `torch.Tensor`: the heat maps across heads, with shape (heads, tokens, height, width). 232 | """ 233 | h = w = int(math.sqrt(x.size(1))) 234 | maps = [] 235 | x = x.permute(2, 0, 1) 236 | 237 | with auto_autocast(dtype=torch.float32): 238 | for map_ in x: 239 | map_ = map_.view(map_.size(0), h, w) 240 | # For Instruct Pix2Pix, divide the map into three parts: text condition, image condition and unconditional, 241 | # and only keep the text condition part, which is the first of the three parts (as per the diffusers implementation).
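                # Illustrative shape note (an assumption based on Stable Diffusion v1's 8-head
                # cross attention): map_.size(0) is batch * heads, so classifier-free guidance
                # yields 2 * 8 = 16 entries (unconditional + text) while Instruct Pix2Pix yields
                # 3 * 8 = 24 (text, image, unconditional), which is what the check below detects.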
242 | if map_.size(0) == 24: 243 | map_ = map_[:((map_.size(0) // 3)+1)] # Filter out unconditional and image condition 244 | else: 245 | map_ = map_[map_.size(0) // 2:] # Filter out unconditional 246 | maps.append(map_) 247 | 248 | maps = torch.stack(maps, 0) # shape: (tokens, heads, height, width) 249 | return maps.permute(1, 0, 2, 3).contiguous() # shape: (heads, tokens, height, width) 250 | 251 | def _save_attn(self, attn_slice: torch.Tensor): 252 | torch.save(attn_slice, self.data_dir / f'{self.trace._gen_idx}.pt') 253 | 254 | def _load_attn(self) -> torch.Tensor: 255 | return torch.load(self.data_dir / f'{self.trace._gen_idx}.pt') 256 | 257 | def __call__( 258 | self, 259 | attn: Attention, 260 | hidden_states, 261 | encoder_hidden_states=None, 262 | attention_mask=None, 263 | ): 264 | """Capture attentions and aggregate them.""" 265 | batch_size, sequence_length, _ = hidden_states.shape 266 | attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) 267 | query = attn.to_q(hidden_states) 268 | 269 | if encoder_hidden_states is None: 270 | encoder_hidden_states = hidden_states 271 | elif attn.norm_cross is not None: 272 | encoder_hidden_states = attn.norm_cross(encoder_hidden_states) 273 | 274 | key = attn.to_k(encoder_hidden_states) 275 | value = attn.to_v(encoder_hidden_states) 276 | 277 | query = attn.head_to_batch_dim(query) 278 | key = attn.head_to_batch_dim(key) 279 | value = attn.head_to_batch_dim(value) 280 | 281 | attention_probs = attn.get_attention_scores(query, key, attention_mask) 282 | 283 | # DAAM save heads 284 | if self.save_heads: 285 | self._save_attn(attention_probs) 286 | elif self.load_heads: 287 | attention_probs = self._load_attn() 288 | 289 | # compute shape factor 290 | factor = int(math.sqrt(self.latent_hw // attention_probs.shape[1])) 291 | self.trace._gen_idx += 1 292 | 293 | # skip if the factor is too large (i.e., the map is too coarse) 294 | if attention_probs.shape[-1] == self.context_size and factor != 8: 295 | # shape: (batch_size, 64 // factor, 64 // factor, 77) 296 | maps = self._unravel_attn(attention_probs) 297 | 298 | for head_idx, heatmap in enumerate(maps): 299 | self.heat_maps.update(factor, self.layer_idx, head_idx, heatmap) 300 | 301 | hidden_states = torch.bmm(attention_probs, value) 302 | hidden_states = attn.batch_to_head_dim(hidden_states) 303 | 304 | # linear proj 305 | hidden_states = attn.to_out[0](hidden_states) 306 | # dropout 307 | hidden_states = attn.to_out[1](hidden_states) 308 | 309 | return hidden_states 310 | 311 | def _hook_impl(self): 312 | self.original_processor = self.module.processor 313 | self.module.set_processor(self) 314 | 315 | def _unhook_impl(self): 316 | self.module.set_processor(self.original_processor) 317 | 318 | @property 319 | def num_heat_maps(self): 320 | return len(next(iter(self.heat_maps.values()))) 321 | 322 | 323 | trace: Type[DiffusionHeatMapHooker] = DiffusionHeatMapHooker 324 | -------------------------------------------------------------------------------- /daam/utils.py: -------------------------------------------------------------------------------- 1 | from functools import lru_cache 2 | from pathlib import Path 3 | import os 4 | import sys 5 | import random 6 | from typing import TypeVar 7 | 8 | import PIL.Image 9 | import matplotlib.pyplot as plt 10 | import numpy as np 11 | import spacy 12 | import torch 13 | import torch.nn.functional as F 14 | 15 | 16 | __all__ = ['set_seed', 'compute_token_merge_indices', 'plot_mask_heat_map', 'cached_nlp', 'cache_dir', 'auto_device', 'auto_autocast'] 17 | 18 |
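# Usage sketch (illustrative, not from the original module): `auto_device` below rewrites a
# torch.device to CUDA when available or moves a module onto CUDA, and `auto_autocast`
# forwards to torch.cuda.amp.autocast with autocasting disabled on CPU-only machines:
#
#     pipe = auto_device(StableDiffusionPipeline.from_pretrained('runwayml/stable-diffusion-v1-5'))
#     with auto_autocast(dtype=torch.float16):
#         image = pipe('a photo of an astronaut').images[0]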
19 | T = TypeVar('T') 20 | 21 | 22 | def auto_device(obj: T = torch.device('cpu')) -> T: 23 | if isinstance(obj, torch.device): 24 | return torch.device('cuda' if torch.cuda.is_available() else 'cpu') 25 | 26 | if torch.cuda.is_available(): 27 | return obj.to('cuda') 28 | 29 | return obj 30 | 31 | 32 | def auto_autocast(*args, **kwargs): 33 | if not torch.cuda.is_available(): 34 | kwargs['enabled'] = False 35 | 36 | return torch.cuda.amp.autocast(*args, **kwargs) 37 | 38 | 39 | def plot_mask_heat_map(im: PIL.Image.Image, heat_map: torch.Tensor, threshold: float = 0.4): 40 | im = torch.from_numpy(np.array(im)).float() / 255 41 | mask = (heat_map.squeeze() > threshold).float() 42 | im = im * mask.unsqueeze(-1) 43 | plt.imshow(im) 44 | 45 | 46 | def set_seed(seed: int) -> torch.Generator: 47 | random.seed(seed) 48 | np.random.seed(seed) 49 | torch.manual_seed(seed) 50 | torch.cuda.manual_seed_all(seed) 51 | 52 | gen = torch.Generator(device=auto_device()) 53 | gen.manual_seed(seed) 54 | 55 | return gen 56 | 57 | 58 | def cache_dir() -> Path: 59 | # *nix 60 | if os.name == 'posix' and sys.platform != 'darwin': 61 | xdg = os.environ.get('XDG_CACHE_HOME', os.path.expanduser('~/.cache')) 62 | return Path(xdg, 'daam') 63 | elif sys.platform == 'darwin': 64 | # Mac OS 65 | return Path(os.path.expanduser('~'), 'Library/Caches/daam') 66 | else: 67 | # Windows 68 | local = os.environ.get('LOCALAPPDATA', None) \ 69 | or os.path.expanduser('~\\AppData\\Local') 70 | return Path(local, 'daam') 71 | 72 | 73 | def compute_token_merge_indices(tokenizer, prompt: str, word: str, word_idx: int = None, offset_idx: int = 0): 74 | merge_idxs = [] 75 | tokens = tokenizer.tokenize(prompt.lower()) 76 | tokens = [x.replace('</w>', '') for x in tokens] # New tokenizer uses wordpiece markers. 77 | 78 | if word_idx is None: 79 | word = word.lower() 80 | search_tokens = [x.replace('</w>', '') for x in tokenizer.tokenize(word)] # New tokenizer uses wordpiece markers. 81 | start_indices = [x + offset_idx for x in range(len(tokens)) if tokens[x:x + len(search_tokens)] == search_tokens] 82 | 83 | for indice in start_indices: 84 | merge_idxs += [i + indice for i in range(0, len(search_tokens))] 85 | 86 | if not merge_idxs: 87 | raise ValueError(f'Search word {word} not found in prompt!') 88 | else: 89 | merge_idxs.append(word_idx) 90 | 91 | return [x + 1 for x in merge_idxs], word_idx # Offset by 1.
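# A minimal sketch (illustrative, not part of the module) of how `compute_token_merge_indices`
# above maps a word to its CLIP token positions; the tokenizer checkpoint name is an assumption,
# and any tokenizer matching the traced pipeline works:
#
#     from transformers import CLIPTokenizer
#
#     tok = CLIPTokenizer.from_pretrained('openai/clip-vit-large-patch14')
#     idxs, _ = compute_token_merge_indices(tok, 'an astronaut riding a horse', 'astronaut')
#     # idxs lists every BPE piece position of 'astronaut', offset by 1 for the start token.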
92 | 93 | 94 | nlp = None 95 | 96 | 97 | @lru_cache(maxsize=100000) 98 | def cached_nlp(prompt: str, type='en_core_web_md'): 99 | global nlp 100 | 101 | if nlp is None: 102 | try: 103 | nlp = spacy.load(type) 104 | except OSError: 105 | import os 106 | os.system(f'python -m spacy download {type}') 107 | nlp = spacy.load(type) 108 | 109 | return nlp(prompt) 110 | -------------------------------------------------------------------------------- /data/vocab-small.tsv: -------------------------------------------------------------------------------- 1 | word pos 2 | oven noun 3 | toaster noun 4 | refrigerator noun 5 | tv noun 6 | cell phone noun 7 | laptop noun 8 | chair noun 9 | couch noun 10 | bench noun 11 | cake noun 12 | donut noun 13 | sandwich noun 14 | hot dog noun 15 | pizza noun 16 | banana noun 17 | apple noun 18 | orange noun 19 | bowl noun 20 | cup noun 21 | fork noun 22 | knife noun 23 | spoon noun 24 | snowboard noun 25 | surfboard noun 26 | skateboard noun 27 | backpack noun 28 | handbag noun 29 | suitcase noun 30 | elephant noun 31 | bear noun 32 | zebra noun 33 | giraffe noun 34 | cat noun 35 | dog noun 36 | cow noun 37 | horse noun 38 | sheep noun 39 | bus noun 40 | truck noun 41 | car noun 42 | bicycle noun 43 | motorcycle noun 44 | -------------------------------------------------------------------------------- /data/vocab.tsv: -------------------------------------------------------------------------------- 1 | word pos 2 | person noun 3 | bicycle noun 4 | car noun 5 | motorcycle noun 6 | airplane noun 7 | bus noun 8 | train noun 9 | truck noun 10 | boat noun 11 | traffic light noun 12 | fire hydrant noun 13 | stop sign noun 14 | parking meter noun 15 | bench noun 16 | bird noun 17 | cat noun 18 | dog noun 19 | horse noun 20 | sheep noun 21 | cow noun 22 | elephant noun 23 | bear noun 24 | zebra noun 25 | giraffe noun 26 | backpack noun 27 | umbrella noun 28 | handbag noun 29 | tie noun 30 | suitcase noun 31 | frisbee noun 32 | skis noun 33 | snowboard noun 34 | sports ball noun 35 | kite noun 36 | baseball bat noun 37 | baseball glove noun 38 | skateboard noun 39 | surfboard noun 40 | tennis racket noun 41 | bottle noun 42 | wine glass noun 43 | cup noun 44 | fork noun 45 | knife noun 46 | spoon noun 47 | bowl noun 48 | banana noun 49 | apple noun 50 | sandwich noun 51 | orange noun 52 | broccoli noun 53 | carrot noun 54 | hot dog noun 55 | pizza noun 56 | donut noun 57 | cake noun 58 | chair noun 59 | couch noun 60 | potted plant noun 61 | bed noun 62 | dining table noun 63 | toilet noun 64 | tv noun 65 | laptop noun 66 | mouse noun 67 | remote noun 68 | keyboard noun 69 | cell phone noun 70 | microwave noun 71 | oven noun 72 | toaster noun 73 | sink noun 74 | refrigerator noun 75 | book noun 76 | clock noun 77 | vase noun 78 | scissors noun 79 | teddy bear noun 80 | hair drier noun 81 | toothbrush noun 82 | one numeral 83 | two numeral 84 | three numeral 85 | four numeral 86 | five numeral 87 | six numeral 88 | seven numeral 89 | eight numeral 90 | nine numeral 91 | -------------------------------------------------------------------------------- /docs/.buildinfo: -------------------------------------------------------------------------------- 1 | # Sphinx build info version 1 2 | # This file hashes the configuration used when building these files. When it is not found, a full rebuild will be done. 
3 | config: c508eeb2efa3ecbb291988531b41d478 4 | tags: 645f666f9bcd5a90fca523b33c5a78b7 5 | -------------------------------------------------------------------------------- /docs/.nojekyll: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/castorini/daam/c30493ed0154bfccb6c342400f25cc24599bb1ff/docs/.nojekyll -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line, and also 5 | # from the environment for the first two. 6 | SPHINXOPTS ?= 7 | SPHINXBUILD ?= sphinx-build 8 | SOURCEDIR = . 9 | BUILDDIR = _build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 21 | 22 | github: 23 | @make html 24 | @cp -a _build/html/. ../docs 25 | -------------------------------------------------------------------------------- /docs/_sources/daam.rst.txt: -------------------------------------------------------------------------------- 1 | daam package 2 | ============ 3 | 4 | Subpackages 5 | ----------- 6 | 7 | .. toctree:: 8 | :maxdepth: 4 9 | 10 | daam.run 11 | 12 | Submodules 13 | ---------- 14 | 15 | daam.evaluate module 16 | -------------------- 17 | 18 | .. automodule:: daam.evaluate 19 | :members: 20 | :undoc-members: 21 | :show-inheritance: 22 | 23 | daam.experiment module 24 | ---------------------- 25 | 26 | .. automodule:: daam.experiment 27 | :members: 28 | :undoc-members: 29 | :show-inheritance: 30 | 31 | daam.hook module 32 | ---------------- 33 | 34 | .. automodule:: daam.hook 35 | :members: 36 | :undoc-members: 37 | :show-inheritance: 38 | 39 | daam.trace module 40 | ----------------- 41 | 42 | .. automodule:: daam.trace 43 | :members: 44 | :undoc-members: 45 | :show-inheritance: 46 | 47 | daam.utils module 48 | ----------------- 49 | 50 | .. automodule:: daam.utils 51 | :members: 52 | :undoc-members: 53 | :show-inheritance: 54 | 55 | Module contents 56 | --------------- 57 | 58 | .. automodule:: daam 59 | :members: 60 | :undoc-members: 61 | :show-inheritance: 62 | -------------------------------------------------------------------------------- /docs/_sources/daam.run.rst.txt: -------------------------------------------------------------------------------- 1 | daam.run package 2 | ================ 3 | 4 | Submodules 5 | ---------- 6 | 7 | daam.run.annotate module 8 | ------------------------ 9 | 10 | .. automodule:: daam.run.annotate 11 | :members: 12 | :undoc-members: 13 | :show-inheritance: 14 | 15 | daam.run.daam\_to\_mask module 16 | ------------------------------ 17 | 18 | .. automodule:: daam.run.daam_to_mask 19 | :members: 20 | :undoc-members: 21 | :show-inheritance: 22 | 23 | daam.run.evaluate module 24 | ------------------------ 25 | 26 | .. automodule:: daam.run.evaluate 27 | :members: 28 | :undoc-members: 29 | :show-inheritance: 30 | 31 | daam.run.filter\_coco module 32 | ---------------------------- 33 | 34 | .. 
automodule:: daam.run.filter_coco 35 | :members: 36 | :undoc-members: 37 | :show-inheritance: 38 | 39 | daam.run.generate module 40 | ------------------------ 41 | 42 | .. automodule:: daam.run.generate 43 | :members: 44 | :undoc-members: 45 | :show-inheritance: 46 | 47 | daam.run.test\_literacy module 48 | ------------------------------ 49 | 50 | .. automodule:: daam.run.test_literacy 51 | :members: 52 | :undoc-members: 53 | :show-inheritance: 54 | 55 | daam.run.test\_numeracy module 56 | ------------------------------ 57 | 58 | .. automodule:: daam.run.test_numeracy 59 | :members: 60 | :undoc-members: 61 | :show-inheritance: 62 | 63 | Module contents 64 | --------------- 65 | 66 | .. automodule:: daam.run 67 | :members: 68 | :undoc-members: 69 | :show-inheritance: 70 | -------------------------------------------------------------------------------- /docs/_sources/index.rst.txt: -------------------------------------------------------------------------------- 1 | 2 | Welcome to DAAM. 3 | =================================== 4 | `DAAM `_ is an attribution method and toolkit for interpreting the Stable Diffusion model. 5 | 6 | .. toctree:: 7 | :maxdepth: 2 8 | :caption: Contents: 9 | 10 | 11 | 12 | Indices and tables 13 | ================== 14 | 15 | * :ref:`genindex` 16 | * :ref:`modindex` 17 | * :ref:`search` 18 | -------------------------------------------------------------------------------- /docs/_sources/modules.rst.txt: -------------------------------------------------------------------------------- 1 | daam 2 | ==== 3 | 4 | .. toctree:: 5 | :maxdepth: 4 6 | 7 | daam 8 | -------------------------------------------------------------------------------- /docs/_static/alabaster.css: -------------------------------------------------------------------------------- 1 | @import url("basic.css"); 2 | 3 | /* -- page layout ----------------------------------------------------------- */ 4 | 5 | body { 6 | font-family: Georgia, serif; 7 | font-size: 17px; 8 | background-color: #fff; 9 | color: #000; 10 | margin: 0; 11 | padding: 0; 12 | } 13 | 14 | 15 | div.document { 16 | width: 940px; 17 | margin: 30px auto 0 auto; 18 | } 19 | 20 | div.documentwrapper { 21 | float: left; 22 | width: 100%; 23 | } 24 | 25 | div.bodywrapper { 26 | margin: 0 0 0 220px; 27 | } 28 | 29 | div.sphinxsidebar { 30 | width: 220px; 31 | font-size: 14px; 32 | line-height: 1.5; 33 | } 34 | 35 | hr { 36 | border: 1px solid #B1B4B6; 37 | } 38 | 39 | div.body { 40 | background-color: #fff; 41 | color: #3E4349; 42 | padding: 0 30px 0 30px; 43 | } 44 | 45 | div.body > .section { 46 | text-align: left; 47 | } 48 | 49 | div.footer { 50 | width: 940px; 51 | margin: 20px auto 30px auto; 52 | font-size: 14px; 53 | color: #888; 54 | text-align: right; 55 | } 56 | 57 | div.footer a { 58 | color: #888; 59 | } 60 | 61 | p.caption { 62 | font-family: inherit; 63 | font-size: inherit; 64 | } 65 | 66 | 67 | div.relations { 68 | display: none; 69 | } 70 | 71 | 72 | div.sphinxsidebar a { 73 | color: #444; 74 | text-decoration: none; 75 | border-bottom: 1px dotted #999; 76 | } 77 | 78 | div.sphinxsidebar a:hover { 79 | border-bottom: 1px solid #999; 80 | } 81 | 82 | div.sphinxsidebarwrapper { 83 | padding: 18px 10px; 84 | } 85 | 86 | div.sphinxsidebarwrapper p.logo { 87 | padding: 0; 88 | margin: -10px 0 0 0px; 89 | text-align: center; 90 | } 91 | 92 | div.sphinxsidebarwrapper h1.logo { 93 | margin-top: -10px; 94 | text-align: center; 95 | margin-bottom: 5px; 96 | text-align: left; 97 | } 98 | 99 | div.sphinxsidebarwrapper h1.logo-name { 100 | 
margin-top: 0px; 101 | } 102 | 103 | div.sphinxsidebarwrapper p.blurb { 104 | margin-top: 0; 105 | font-style: normal; 106 | } 107 | 108 | div.sphinxsidebar h3, 109 | div.sphinxsidebar h4 { 110 | font-family: Georgia, serif; 111 | color: #444; 112 | font-size: 24px; 113 | font-weight: normal; 114 | margin: 0 0 5px 0; 115 | padding: 0; 116 | } 117 | 118 | div.sphinxsidebar h4 { 119 | font-size: 20px; 120 | } 121 | 122 | div.sphinxsidebar h3 a { 123 | color: #444; 124 | } 125 | 126 | div.sphinxsidebar p.logo a, 127 | div.sphinxsidebar h3 a, 128 | div.sphinxsidebar p.logo a:hover, 129 | div.sphinxsidebar h3 a:hover { 130 | border: none; 131 | } 132 | 133 | div.sphinxsidebar p { 134 | color: #555; 135 | margin: 10px 0; 136 | } 137 | 138 | div.sphinxsidebar ul { 139 | margin: 10px 0; 140 | padding: 0; 141 | color: #000; 142 | } 143 | 144 | div.sphinxsidebar ul li.toctree-l1 > a { 145 | font-size: 120%; 146 | } 147 | 148 | div.sphinxsidebar ul li.toctree-l2 > a { 149 | font-size: 110%; 150 | } 151 | 152 | div.sphinxsidebar input { 153 | border: 1px solid #CCC; 154 | font-family: Georgia, serif; 155 | font-size: 1em; 156 | } 157 | 158 | div.sphinxsidebar hr { 159 | border: none; 160 | height: 1px; 161 | color: #AAA; 162 | background: #AAA; 163 | 164 | text-align: left; 165 | margin-left: 0; 166 | width: 50%; 167 | } 168 | 169 | div.sphinxsidebar .badge { 170 | border-bottom: none; 171 | } 172 | 173 | div.sphinxsidebar .badge:hover { 174 | border-bottom: none; 175 | } 176 | 177 | /* To address an issue with donation coming after search */ 178 | div.sphinxsidebar h3.donation { 179 | margin-top: 10px; 180 | } 181 | 182 | /* -- body styles ----------------------------------------------------------- */ 183 | 184 | a { 185 | color: #004B6B; 186 | text-decoration: underline; 187 | } 188 | 189 | a:hover { 190 | color: #6D4100; 191 | text-decoration: underline; 192 | } 193 | 194 | div.body h1, 195 | div.body h2, 196 | div.body h3, 197 | div.body h4, 198 | div.body h5, 199 | div.body h6 { 200 | font-family: Georgia, serif; 201 | font-weight: normal; 202 | margin: 30px 0px 10px 0px; 203 | padding: 0; 204 | } 205 | 206 | div.body h1 { margin-top: 0; padding-top: 0; font-size: 240%; } 207 | div.body h2 { font-size: 180%; } 208 | div.body h3 { font-size: 150%; } 209 | div.body h4 { font-size: 130%; } 210 | div.body h5 { font-size: 100%; } 211 | div.body h6 { font-size: 100%; } 212 | 213 | a.headerlink { 214 | color: #DDD; 215 | padding: 0 4px; 216 | text-decoration: none; 217 | } 218 | 219 | a.headerlink:hover { 220 | color: #444; 221 | background: #EAEAEA; 222 | } 223 | 224 | div.body p, div.body dd, div.body li { 225 | line-height: 1.4em; 226 | } 227 | 228 | div.admonition { 229 | margin: 20px 0px; 230 | padding: 10px 30px; 231 | background-color: #EEE; 232 | border: 1px solid #CCC; 233 | } 234 | 235 | div.admonition tt.xref, div.admonition code.xref, div.admonition a tt { 236 | background-color: #FBFBFB; 237 | border-bottom: 1px solid #fafafa; 238 | } 239 | 240 | div.admonition p.admonition-title { 241 | font-family: Georgia, serif; 242 | font-weight: normal; 243 | font-size: 24px; 244 | margin: 0 0 10px 0; 245 | padding: 0; 246 | line-height: 1; 247 | } 248 | 249 | div.admonition p.last { 250 | margin-bottom: 0; 251 | } 252 | 253 | div.highlight { 254 | background-color: #fff; 255 | } 256 | 257 | dt:target, .highlight { 258 | background: #FAF3E8; 259 | } 260 | 261 | div.warning { 262 | background-color: #FCC; 263 | border: 1px solid #FAA; 264 | } 265 | 266 | div.danger { 267 | background-color: #FCC; 268 | 
border: 1px solid #FAA; 269 | -moz-box-shadow: 2px 2px 4px #D52C2C; 270 | -webkit-box-shadow: 2px 2px 4px #D52C2C; 271 | box-shadow: 2px 2px 4px #D52C2C; 272 | } 273 | 274 | div.error { 275 | background-color: #FCC; 276 | border: 1px solid #FAA; 277 | -moz-box-shadow: 2px 2px 4px #D52C2C; 278 | -webkit-box-shadow: 2px 2px 4px #D52C2C; 279 | box-shadow: 2px 2px 4px #D52C2C; 280 | } 281 | 282 | div.caution { 283 | background-color: #FCC; 284 | border: 1px solid #FAA; 285 | } 286 | 287 | div.attention { 288 | background-color: #FCC; 289 | border: 1px solid #FAA; 290 | } 291 | 292 | div.important { 293 | background-color: #EEE; 294 | border: 1px solid #CCC; 295 | } 296 | 297 | div.note { 298 | background-color: #EEE; 299 | border: 1px solid #CCC; 300 | } 301 | 302 | div.tip { 303 | background-color: #EEE; 304 | border: 1px solid #CCC; 305 | } 306 | 307 | div.hint { 308 | background-color: #EEE; 309 | border: 1px solid #CCC; 310 | } 311 | 312 | div.seealso { 313 | background-color: #EEE; 314 | border: 1px solid #CCC; 315 | } 316 | 317 | div.topic { 318 | background-color: #EEE; 319 | } 320 | 321 | p.admonition-title { 322 | display: inline; 323 | } 324 | 325 | p.admonition-title:after { 326 | content: ":"; 327 | } 328 | 329 | pre, tt, code { 330 | font-family: 'Consolas', 'Menlo', 'DejaVu Sans Mono', 'Bitstream Vera Sans Mono', monospace; 331 | font-size: 0.9em; 332 | } 333 | 334 | .hll { 335 | background-color: #FFC; 336 | margin: 0 -12px; 337 | padding: 0 12px; 338 | display: block; 339 | } 340 | 341 | img.screenshot { 342 | } 343 | 344 | tt.descname, tt.descclassname, code.descname, code.descclassname { 345 | font-size: 0.95em; 346 | } 347 | 348 | tt.descname, code.descname { 349 | padding-right: 0.08em; 350 | } 351 | 352 | img.screenshot { 353 | -moz-box-shadow: 2px 2px 4px #EEE; 354 | -webkit-box-shadow: 2px 2px 4px #EEE; 355 | box-shadow: 2px 2px 4px #EEE; 356 | } 357 | 358 | table.docutils { 359 | border: 1px solid #888; 360 | -moz-box-shadow: 2px 2px 4px #EEE; 361 | -webkit-box-shadow: 2px 2px 4px #EEE; 362 | box-shadow: 2px 2px 4px #EEE; 363 | } 364 | 365 | table.docutils td, table.docutils th { 366 | border: 1px solid #888; 367 | padding: 0.25em 0.7em; 368 | } 369 | 370 | table.field-list, table.footnote { 371 | border: none; 372 | -moz-box-shadow: none; 373 | -webkit-box-shadow: none; 374 | box-shadow: none; 375 | } 376 | 377 | table.footnote { 378 | margin: 15px 0; 379 | width: 100%; 380 | border: 1px solid #EEE; 381 | background: #FDFDFD; 382 | font-size: 0.9em; 383 | } 384 | 385 | table.footnote + table.footnote { 386 | margin-top: -15px; 387 | border-top: none; 388 | } 389 | 390 | table.field-list th { 391 | padding: 0 0.8em 0 0; 392 | } 393 | 394 | table.field-list td { 395 | padding: 0; 396 | } 397 | 398 | table.field-list p { 399 | margin-bottom: 0.8em; 400 | } 401 | 402 | /* Cloned from 403 | * https://github.com/sphinx-doc/sphinx/commit/ef60dbfce09286b20b7385333d63a60321784e68 404 | */ 405 | .field-name { 406 | -moz-hyphens: manual; 407 | -ms-hyphens: manual; 408 | -webkit-hyphens: manual; 409 | hyphens: manual; 410 | } 411 | 412 | table.footnote td.label { 413 | width: .1px; 414 | padding: 0.3em 0 0.3em 0.5em; 415 | } 416 | 417 | table.footnote td { 418 | padding: 0.3em 0.5em; 419 | } 420 | 421 | dl { 422 | margin: 0; 423 | padding: 0; 424 | } 425 | 426 | dl dd { 427 | margin-left: 30px; 428 | } 429 | 430 | blockquote { 431 | margin: 0 0 0 30px; 432 | padding: 0; 433 | } 434 | 435 | ul, ol { 436 | /* Matches the 30px from the narrow-screen "li > ul" selector below */ 437 | 
margin: 10px 0 10px 30px; 438 | padding: 0; 439 | } 440 | 441 | pre { 442 | background: #EEE; 443 | padding: 7px 30px; 444 | margin: 15px 0px; 445 | line-height: 1.3em; 446 | } 447 | 448 | div.viewcode-block:target { 449 | background: #ffd; 450 | } 451 | 452 | dl pre, blockquote pre, li pre { 453 | margin-left: 0; 454 | padding-left: 30px; 455 | } 456 | 457 | tt, code { 458 | background-color: #ecf0f3; 459 | color: #222; 460 | /* padding: 1px 2px; */ 461 | } 462 | 463 | tt.xref, code.xref, a tt { 464 | background-color: #FBFBFB; 465 | border-bottom: 1px solid #fff; 466 | } 467 | 468 | a.reference { 469 | text-decoration: none; 470 | border-bottom: 1px dotted #004B6B; 471 | } 472 | 473 | /* Don't put an underline on images */ 474 | a.image-reference, a.image-reference:hover { 475 | border-bottom: none; 476 | } 477 | 478 | a.reference:hover { 479 | border-bottom: 1px solid #6D4100; 480 | } 481 | 482 | a.footnote-reference { 483 | text-decoration: none; 484 | font-size: 0.7em; 485 | vertical-align: top; 486 | border-bottom: 1px dotted #004B6B; 487 | } 488 | 489 | a.footnote-reference:hover { 490 | border-bottom: 1px solid #6D4100; 491 | } 492 | 493 | a:hover tt, a:hover code { 494 | background: #EEE; 495 | } 496 | 497 | 498 | @media screen and (max-width: 870px) { 499 | 500 | div.sphinxsidebar { 501 | display: none; 502 | } 503 | 504 | div.document { 505 | width: 100%; 506 | 507 | } 508 | 509 | div.documentwrapper { 510 | margin-left: 0; 511 | margin-top: 0; 512 | margin-right: 0; 513 | margin-bottom: 0; 514 | } 515 | 516 | div.bodywrapper { 517 | margin-top: 0; 518 | margin-right: 0; 519 | margin-bottom: 0; 520 | margin-left: 0; 521 | } 522 | 523 | ul { 524 | margin-left: 0; 525 | } 526 | 527 | li > ul { 528 | /* Matches the 30px from the "ul, ol" selector above */ 529 | margin-left: 30px; 530 | } 531 | 532 | .document { 533 | width: auto; 534 | } 535 | 536 | .footer { 537 | width: auto; 538 | } 539 | 540 | .bodywrapper { 541 | margin: 0; 542 | } 543 | 544 | .footer { 545 | width: auto; 546 | } 547 | 548 | .github { 549 | display: none; 550 | } 551 | 552 | 553 | 554 | } 555 | 556 | 557 | 558 | @media screen and (max-width: 875px) { 559 | 560 | body { 561 | margin: 0; 562 | padding: 20px 30px; 563 | } 564 | 565 | div.documentwrapper { 566 | float: none; 567 | background: #fff; 568 | } 569 | 570 | div.sphinxsidebar { 571 | display: block; 572 | float: none; 573 | width: 102.5%; 574 | margin: 50px -30px -20px -30px; 575 | padding: 10px 20px; 576 | background: #333; 577 | color: #FFF; 578 | } 579 | 580 | div.sphinxsidebar h3, div.sphinxsidebar h4, div.sphinxsidebar p, 581 | div.sphinxsidebar h3 a { 582 | color: #fff; 583 | } 584 | 585 | div.sphinxsidebar a { 586 | color: #AAA; 587 | } 588 | 589 | div.sphinxsidebar p.logo { 590 | display: none; 591 | } 592 | 593 | div.document { 594 | width: 100%; 595 | margin: 0; 596 | } 597 | 598 | div.footer { 599 | display: none; 600 | } 601 | 602 | div.bodywrapper { 603 | margin: 0; 604 | } 605 | 606 | div.body { 607 | min-height: 0; 608 | padding: 0; 609 | } 610 | 611 | .rtd_doc_footer { 612 | display: none; 613 | } 614 | 615 | .document { 616 | width: auto; 617 | } 618 | 619 | .footer { 620 | width: auto; 621 | } 622 | 623 | .footer { 624 | width: auto; 625 | } 626 | 627 | .github { 628 | display: none; 629 | } 630 | } 631 | 632 | 633 | /* misc. */ 634 | 635 | .revsys-inline { 636 | display: none!important; 637 | } 638 | 639 | /* Make nested-list/multi-paragraph items look better in Releases changelog 640 | * pages. 
Without this, docutils' magical list fuckery causes inconsistent 641 | * formatting between different release sub-lists. 642 | */ 643 | div#changelog > div.section > ul > li > p:only-child { 644 | margin-bottom: 0; 645 | } 646 | 647 | /* Hide fugly table cell borders in ..bibliography:: directive output */ 648 | table.docutils.citation, table.docutils.citation td, table.docutils.citation th { 649 | border: none; 650 | /* Below needed in some edge cases; if not applied, bottom shadows appear */ 651 | -moz-box-shadow: none; 652 | -webkit-box-shadow: none; 653 | box-shadow: none; 654 | } 655 | 656 | 657 | /* relbar */ 658 | 659 | .related { 660 | line-height: 30px; 661 | width: 100%; 662 | font-size: 0.9rem; 663 | } 664 | 665 | .related.top { 666 | border-bottom: 1px solid #EEE; 667 | margin-bottom: 20px; 668 | } 669 | 670 | .related.bottom { 671 | border-top: 1px solid #EEE; 672 | } 673 | 674 | .related ul { 675 | padding: 0; 676 | margin: 0; 677 | list-style: none; 678 | } 679 | 680 | .related li { 681 | display: inline; 682 | } 683 | 684 | nav#rellinks { 685 | float: right; 686 | } 687 | 688 | nav#rellinks li+li:before { 689 | content: "|"; 690 | } 691 | 692 | nav#breadcrumbs li+li:before { 693 | content: "\00BB"; 694 | } 695 | 696 | /* Hide certain items when printing */ 697 | @media print { 698 | div.related { 699 | display: none; 700 | } 701 | } -------------------------------------------------------------------------------- /docs/_static/custom.css: -------------------------------------------------------------------------------- 1 | /* This file intentionally left blank. */ 2 | -------------------------------------------------------------------------------- /docs/_static/doctools.js: -------------------------------------------------------------------------------- 1 | /* 2 | * doctools.js 3 | * ~~~~~~~~~~~ 4 | * 5 | * Sphinx JavaScript utilities for all documentation. 6 | * 7 | * :copyright: Copyright 2007-2022 by the Sphinx team, see AUTHORS. 8 | * :license: BSD, see LICENSE for details. 9 | * 10 | */ 11 | 12 | /** 13 | * select a different prefix for underscore 14 | */ 15 | $u = _.noConflict(); 16 | 17 | /** 18 | * make the code below compatible with browsers without 19 | * an installed firebug like debugger 20 | if (!window.console || !console.firebug) { 21 | var names = ["log", "debug", "info", "warn", "error", "assert", "dir", 22 | "dirxml", "group", "groupEnd", "time", "timeEnd", "count", "trace", 23 | "profile", "profileEnd"]; 24 | window.console = {}; 25 | for (var i = 0; i < names.length; ++i) 26 | window.console[names[i]] = function() {}; 27 | } 28 | */ 29 | 30 | /** 31 | * small helper function to urldecode strings 32 | * 33 | * See https://developer.mozilla.org/en-US/docs/Web/JavaScript/Reference/Global_Objects/decodeURIComponent#Decoding_query_parameters_from_a_URL 34 | */ 35 | jQuery.urldecode = function(x) { 36 | if (!x) { 37 | return x 38 | } 39 | return decodeURIComponent(x.replace(/\+/g, ' ')); 40 | }; 41 | 42 | /** 43 | * small helper function to urlencode strings 44 | */ 45 | jQuery.urlencode = encodeURIComponent; 46 | 47 | /** 48 | * This function returns the parsed url parameters of the 49 | * current request. Multiple values per key are supported, 50 | * it will always return arrays of strings for the value parts. 
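 * Example (illustrative): getQueryParameters('?highlight=foo%20bar') returns
 * {highlight: ['foo bar']}.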
51 | */ 52 | jQuery.getQueryParameters = function(s) { 53 | if (typeof s === 'undefined') 54 | s = document.location.search; 55 | var parts = s.substr(s.indexOf('?') + 1).split('&'); 56 | var result = {}; 57 | for (var i = 0; i < parts.length; i++) { 58 | var tmp = parts[i].split('=', 2); 59 | var key = jQuery.urldecode(tmp[0]); 60 | var value = jQuery.urldecode(tmp[1]); 61 | if (key in result) 62 | result[key].push(value); 63 | else 64 | result[key] = [value]; 65 | } 66 | return result; 67 | }; 68 | 69 | /** 70 | * highlight a given string on a jquery object by wrapping it in 71 | * span elements with the given class name. 72 | */ 73 | jQuery.fn.highlightText = function(text, className) { 74 | function highlight(node, addItems) { 75 | if (node.nodeType === 3) { 76 | var val = node.nodeValue; 77 | var pos = val.toLowerCase().indexOf(text); 78 | if (pos >= 0 && 79 | !jQuery(node.parentNode).hasClass(className) && 80 | !jQuery(node.parentNode).hasClass("nohighlight")) { 81 | var span; 82 | var isInSVG = jQuery(node).closest("body, svg, foreignObject").is("svg"); 83 | if (isInSVG) { 84 | span = document.createElementNS("http://www.w3.org/2000/svg", "tspan"); 85 | } else { 86 | span = document.createElement("span"); 87 | span.className = className; 88 | } 89 | span.appendChild(document.createTextNode(val.substr(pos, text.length))); 90 | node.parentNode.insertBefore(span, node.parentNode.insertBefore( 91 | document.createTextNode(val.substr(pos + text.length)), 92 | node.nextSibling)); 93 | node.nodeValue = val.substr(0, pos); 94 | if (isInSVG) { 95 | var rect = document.createElementNS("http://www.w3.org/2000/svg", "rect"); 96 | var bbox = node.parentElement.getBBox(); 97 | rect.x.baseVal.value = bbox.x; 98 | rect.y.baseVal.value = bbox.y; 99 | rect.width.baseVal.value = bbox.width; 100 | rect.height.baseVal.value = bbox.height; 101 | rect.setAttribute('class', className); 102 | addItems.push({ 103 | "parent": node.parentNode, 104 | "target": rect}); 105 | } 106 | } 107 | } 108 | else if (!jQuery(node).is("button, select, textarea")) { 109 | jQuery.each(node.childNodes, function() { 110 | highlight(this, addItems); 111 | }); 112 | } 113 | } 114 | var addItems = []; 115 | var result = this.each(function() { 116 | highlight(this, addItems); 117 | }); 118 | for (var i = 0; i < addItems.length; ++i) { 119 | jQuery(addItems[i].parent).before(addItems[i].target); 120 | } 121 | return result; 122 | }; 123 | 124 | /* 125 | * backward compatibility for jQuery.browser 126 | * This will be supported until firefox bug is fixed. 127 | */ 128 | if (!jQuery.browser) { 129 | jQuery.uaMatch = function(ua) { 130 | ua = ua.toLowerCase(); 131 | 132 | var match = /(chrome)[ \/]([\w.]+)/.exec(ua) || 133 | /(webkit)[ \/]([\w.]+)/.exec(ua) || 134 | /(opera)(?:.*version|)[ \/]([\w.]+)/.exec(ua) || 135 | /(msie) ([\w.]+)/.exec(ua) || 136 | ua.indexOf("compatible") < 0 && /(mozilla)(?:.*? rv:([\w.]+)|)/.exec(ua) || 137 | []; 138 | 139 | return { 140 | browser: match[ 1 ] || "", 141 | version: match[ 2 ] || "0" 142 | }; 143 | }; 144 | jQuery.browser = {}; 145 | jQuery.browser[jQuery.uaMatch(navigator.userAgent).browser] = true; 146 | } 147 | 148 | /** 149 | * Small JavaScript module for the documentation. 
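 * Exposes the global `Documentation` object: page init, i18n helpers (gettext/ngettext),
 * search-word highlighting, and keyboard navigation shortcuts.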
150 | */ 151 | var Documentation = { 152 | 153 | init : function() { 154 | this.fixFirefoxAnchorBug(); 155 | this.highlightSearchWords(); 156 | this.initIndexTable(); 157 | this.initOnKeyListeners(); 158 | }, 159 | 160 | /** 161 | * i18n support 162 | */ 163 | TRANSLATIONS : {}, 164 | PLURAL_EXPR : function(n) { return n === 1 ? 0 : 1; }, 165 | LOCALE : 'unknown', 166 | 167 | // gettext and ngettext don't access this so that the functions 168 | // can safely bound to a different name (_ = Documentation.gettext) 169 | gettext : function(string) { 170 | var translated = Documentation.TRANSLATIONS[string]; 171 | if (typeof translated === 'undefined') 172 | return string; 173 | return (typeof translated === 'string') ? translated : translated[0]; 174 | }, 175 | 176 | ngettext : function(singular, plural, n) { 177 | var translated = Documentation.TRANSLATIONS[singular]; 178 | if (typeof translated === 'undefined') 179 | return (n == 1) ? singular : plural; 180 | return translated[Documentation.PLURAL_EXPR(n)]; 181 | }, 182 | 183 | addTranslations : function(catalog) { 184 | for (var key in catalog.messages) 185 | this.TRANSLATIONS[key] = catalog.messages[key]; 186 | this.PLURAL_EXPR = new Function('n', 'return +(' + catalog.plural_expr + ')'); 187 | this.LOCALE = catalog.locale; 188 | }, 189 | 190 | /** 191 | * add context elements like header anchor links 192 | */ 193 | addContextElements : function() { 194 | $('div[id] > :header:first').each(function() { 195 | $('<a class="headerlink">\u00B6</a>'). 196 | attr('href', '#' + this.id). 197 | attr('title', _('Permalink to this headline')). 198 | appendTo(this); 199 | }); 200 | $('dt[id]').each(function() { 201 | $('<a class="headerlink">\u00B6</a>'). 202 | attr('href', '#' + this.id). 203 | attr('title', _('Permalink to this definition')). 204 | appendTo(this); 205 | }); 206 | }, 207 | 208 | /** 209 | * workaround a firefox stupidity 210 | * see: https://bugzilla.mozilla.org/show_bug.cgi?id=645075 211 | */ 212 | fixFirefoxAnchorBug : function() { 213 | if (document.location.hash && $.browser.mozilla) 214 | window.setTimeout(function() { 215 | document.location.href += ''; 216 | }, 10); 217 | }, 218 | 219 | /** 220 | * highlight the search words provided in the url in the text 221 | */ 222 | highlightSearchWords : function() { 223 | var params = $.getQueryParameters(); 224 | var terms = (params.highlight) ?
params.highlight[0].split(/\s+/) : []; 225 | if (terms.length) { 226 | var body = $('div.body'); 227 | if (!body.length) { 228 | body = $('body'); 229 | } 230 | window.setTimeout(function() { 231 | $.each(terms, function() { 232 | body.highlightText(this.toLowerCase(), 'highlighted'); 233 | }); 234 | }, 10); 235 | $('<p class="highlight-link"><a href="javascript:Documentation.' + 236 | 'hideSearchWords()">' + _('Hide Search Matches') + '</a></p>') 237 | .appendTo($('#searchbox')); 238 | } 239 | }, 240 | 241 | /** 242 | * init the domain index toggle buttons 243 | */ 244 | initIndexTable : function() { 245 | var togglers = $('img.toggler').click(function() { 246 | var src = $(this).attr('src'); 247 | var idnum = $(this).attr('id').substr(7); 248 | $('tr.cg-' + idnum).toggle(); 249 | if (src.substr(-9) === 'minus.png') 250 | $(this).attr('src', src.substr(0, src.length-9) + 'plus.png'); 251 | else 252 | $(this).attr('src', src.substr(0, src.length-8) + 'minus.png'); 253 | }).css('display', ''); 254 | if (DOCUMENTATION_OPTIONS.COLLAPSE_INDEX) { 255 | togglers.click(); 256 | } 257 | }, 258 | 259 | /** 260 | * helper function to hide the search marks again 261 | */ 262 | hideSearchWords : function() { 263 | $('#searchbox .highlight-link').fadeOut(300); 264 | $('span.highlighted').removeClass('highlighted'); 265 | var url = new URL(window.location); 266 | url.searchParams.delete('highlight'); 267 | window.history.replaceState({}, '', url); 268 | }, 269 | 270 | /** 271 | * helper function to focus on search bar 272 | */ 273 | focusSearchBar : function() { 274 | $('input[name=q]').first().focus(); 275 | }, 276 | 277 | /** 278 | * make the url absolute 279 | */ 280 | makeURL : function(relativeURL) { 281 | return DOCUMENTATION_OPTIONS.URL_ROOT + '/' + relativeURL; 282 | }, 283 | 284 | /** 285 | * get the current relative url 286 | */ 287 | getCurrentURL : function() { 288 | var path = document.location.pathname; 289 | var parts = path.split(/\//); 290 | $.each(DOCUMENTATION_OPTIONS.URL_ROOT.split(/\//), function() { 291 | if (this === '..') 292 | parts.pop(); 293 | }); 294 | var url = parts.join('/'); 295 | return path.substring(url.lastIndexOf('/') + 1, path.length - 1); 296 | }, 297 | 298 | initOnKeyListeners: function() { 299 | // only install a listener if it is really needed 300 | if (!DOCUMENTATION_OPTIONS.NAVIGATION_WITH_KEYS && 301 | !DOCUMENTATION_OPTIONS.ENABLE_SEARCH_SHORTCUTS) 302 | return; 303 | 304 | $(document).keydown(function(event) { 305 | var activeElementType = document.activeElement.tagName; 306 | // don't navigate when in search box, textarea, dropdown or button 307 | if (activeElementType !== 'TEXTAREA' && activeElementType !== 'INPUT' && activeElementType !== 'SELECT' 308 | && activeElementType !== 'BUTTON') { 309 | if (event.altKey || event.ctrlKey || event.metaKey) 310 | return; 311 | 312 | if (!event.shiftKey) { 313 | switch (event.key) { 314 | case 'ArrowLeft': 315 | if (!DOCUMENTATION_OPTIONS.NAVIGATION_WITH_KEYS) 316 | break; 317 | var prevHref = $('link[rel="prev"]').prop('href'); 318 | if (prevHref) { 319 | window.location.href = prevHref; 320 | return false; 321 | } 322 | break; 323 | case 'ArrowRight': 324 | if (!DOCUMENTATION_OPTIONS.NAVIGATION_WITH_KEYS) 325 | break; 326 | var nextHref = $('link[rel="next"]').prop('href'); 327 | if (nextHref) { 328 | window.location.href = nextHref; 329 | return false; 330 | } 331 | break; 332 | case 'Escape': 333 | if (!DOCUMENTATION_OPTIONS.ENABLE_SEARCH_SHORTCUTS) 334 | break; 335 | Documentation.hideSearchWords(); 336 | return false; 337 | } 338 | } 339 | 340 | // some keyboard layouts may need Shift to get / 341 | switch (event.key) { 342 | case '/':
343 | if (!DOCUMENTATION_OPTIONS.ENABLE_SEARCH_SHORTCUTS) 344 | break; 345 | Documentation.focusSearchBar(); 346 | return false; 347 | } 348 | } 349 | }); 350 | } 351 | }; 352 | 353 | // quick alias for translations 354 | _ = Documentation.gettext; 355 | 356 | $(document).ready(function() { 357 | Documentation.init(); 358 | }); 359 | -------------------------------------------------------------------------------- /docs/_static/documentation_options.js: -------------------------------------------------------------------------------- 1 | var DOCUMENTATION_OPTIONS = { 2 | URL_ROOT: document.getElementById("documentation_options").getAttribute('data-url_root'), 3 | VERSION: 'v0.0.4', 4 | LANGUAGE: 'None', 5 | COLLAPSE_INDEX: false, 6 | BUILDER: 'html', 7 | FILE_SUFFIX: '.html', 8 | LINK_SUFFIX: '.html', 9 | HAS_SOURCE: true, 10 | SOURCELINK_SUFFIX: '.txt', 11 | NAVIGATION_WITH_KEYS: false, 12 | SHOW_SEARCH_SUMMARY: true, 13 | ENABLE_SEARCH_SHORTCUTS: true, 14 | }; -------------------------------------------------------------------------------- /docs/_static/file.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/castorini/daam/c30493ed0154bfccb6c342400f25cc24599bb1ff/docs/_static/file.png -------------------------------------------------------------------------------- /docs/_static/language_data.js: -------------------------------------------------------------------------------- 1 | /* 2 | * language_data.js 3 | * ~~~~~~~~~~~~~~~~ 4 | * 5 | * This script contains the language-specific data used by searchtools.js, 6 | * namely the list of stopwords, stemmer, scorer and splitter. 7 | * 8 | * :copyright: Copyright 2007-2022 by the Sphinx team, see AUTHORS. 9 | * :license: BSD, see LICENSE for details. 10 | * 11 | */ 12 | 13 | var stopwords = ["a","and","are","as","at","be","but","by","for","if","in","into","is","it","near","no","not","of","on","or","such","that","the","their","then","there","these","they","this","to","was","will","with"]; 14 | 15 | 16 | /* Non-minified version is copied as a separate JS file, is available */ 17 | 18 | /** 19 | * Porter Stemmer 20 | */ 21 | var Stemmer = function() { 22 | 23 | var step2list = { 24 | ational: 'ate', 25 | tional: 'tion', 26 | enci: 'ence', 27 | anci: 'ance', 28 | izer: 'ize', 29 | bli: 'ble', 30 | alli: 'al', 31 | entli: 'ent', 32 | eli: 'e', 33 | ousli: 'ous', 34 | ization: 'ize', 35 | ation: 'ate', 36 | ator: 'ate', 37 | alism: 'al', 38 | iveness: 'ive', 39 | fulness: 'ful', 40 | ousness: 'ous', 41 | aliti: 'al', 42 | iviti: 'ive', 43 | biliti: 'ble', 44 | logi: 'log' 45 | }; 46 | 47 | var step3list = { 48 | icate: 'ic', 49 | ative: '', 50 | alize: 'al', 51 | iciti: 'ic', 52 | ical: 'ic', 53 | ful: '', 54 | ness: '' 55 | }; 56 | 57 | var c = "[^aeiou]"; // consonant 58 | var v = "[aeiouy]"; // vowel 59 | var C = c + "[^aeiouy]*"; // consonant sequence 60 | var V = v + "[aeiou]*"; // vowel sequence 61 | 62 | var mgr0 = "^(" + C + ")?" + V + C; // [C]VC... is m>0 63 | var meq1 = "^(" + C + ")?" + V + C + "(" + V + ")?$"; // [C]VC[V] is m=1 64 | var mgr1 = "^(" + C + ")?" + V + C + V + C; // [C]VCVC... is m>1 65 | var s_v = "^(" + C + ")?" 
+ v; // vowel in stem 66 | 67 | this.stemWord = function (w) { 68 | var stem; 69 | var suffix; 70 | var firstch; 71 | var origword = w; 72 | 73 | if (w.length < 3) 74 | return w; 75 | 76 | var re; 77 | var re2; 78 | var re3; 79 | var re4; 80 | 81 | firstch = w.substr(0,1); 82 | if (firstch == "y") 83 | w = firstch.toUpperCase() + w.substr(1); 84 | 85 | // Step 1a 86 | re = /^(.+?)(ss|i)es$/; 87 | re2 = /^(.+?)([^s])s$/; 88 | 89 | if (re.test(w)) 90 | w = w.replace(re,"$1$2"); 91 | else if (re2.test(w)) 92 | w = w.replace(re2,"$1$2"); 93 | 94 | // Step 1b 95 | re = /^(.+?)eed$/; 96 | re2 = /^(.+?)(ed|ing)$/; 97 | if (re.test(w)) { 98 | var fp = re.exec(w); 99 | re = new RegExp(mgr0); 100 | if (re.test(fp[1])) { 101 | re = /.$/; 102 | w = w.replace(re,""); 103 | } 104 | } 105 | else if (re2.test(w)) { 106 | var fp = re2.exec(w); 107 | stem = fp[1]; 108 | re2 = new RegExp(s_v); 109 | if (re2.test(stem)) { 110 | w = stem; 111 | re2 = /(at|bl|iz)$/; 112 | re3 = new RegExp("([^aeiouylsz])\\1$"); 113 | re4 = new RegExp("^" + C + v + "[^aeiouwxy]$"); 114 | if (re2.test(w)) 115 | w = w + "e"; 116 | else if (re3.test(w)) { 117 | re = /.$/; 118 | w = w.replace(re,""); 119 | } 120 | else if (re4.test(w)) 121 | w = w + "e"; 122 | } 123 | } 124 | 125 | // Step 1c 126 | re = /^(.+?)y$/; 127 | if (re.test(w)) { 128 | var fp = re.exec(w); 129 | stem = fp[1]; 130 | re = new RegExp(s_v); 131 | if (re.test(stem)) 132 | w = stem + "i"; 133 | } 134 | 135 | // Step 2 136 | re = /^(.+?)(ational|tional|enci|anci|izer|bli|alli|entli|eli|ousli|ization|ation|ator|alism|iveness|fulness|ousness|aliti|iviti|biliti|logi)$/; 137 | if (re.test(w)) { 138 | var fp = re.exec(w); 139 | stem = fp[1]; 140 | suffix = fp[2]; 141 | re = new RegExp(mgr0); 142 | if (re.test(stem)) 143 | w = stem + step2list[suffix]; 144 | } 145 | 146 | // Step 3 147 | re = /^(.+?)(icate|ative|alize|iciti|ical|ful|ness)$/; 148 | if (re.test(w)) { 149 | var fp = re.exec(w); 150 | stem = fp[1]; 151 | suffix = fp[2]; 152 | re = new RegExp(mgr0); 153 | if (re.test(stem)) 154 | w = stem + step3list[suffix]; 155 | } 156 | 157 | // Step 4 158 | re = /^(.+?)(al|ance|ence|er|ic|able|ible|ant|ement|ment|ent|ou|ism|ate|iti|ous|ive|ize)$/; 159 | re2 = /^(.+?)(s|t)(ion)$/; 160 | if (re.test(w)) { 161 | var fp = re.exec(w); 162 | stem = fp[1]; 163 | re = new RegExp(mgr1); 164 | if (re.test(stem)) 165 | w = stem; 166 | } 167 | else if (re2.test(w)) { 168 | var fp = re2.exec(w); 169 | stem = fp[1] + fp[2]; 170 | re2 = new RegExp(mgr1); 171 | if (re2.test(stem)) 172 | w = stem; 173 | } 174 | 175 | // Step 5 176 | re = /^(.+?)e$/; 177 | if (re.test(w)) { 178 | var fp = re.exec(w); 179 | stem = fp[1]; 180 | re = new RegExp(mgr1); 181 | re2 = new RegExp(meq1); 182 | re3 = new RegExp("^" + C + v + "[^aeiouwxy]$"); 183 | if (re.test(stem) || (re2.test(stem) && !(re3.test(stem)))) 184 | w = stem; 185 | } 186 | re = /ll$/; 187 | re2 = new RegExp(mgr1); 188 | if (re.test(w) && re2.test(w)) { 189 | re = /.$/; 190 | w = w.replace(re,""); 191 | } 192 | 193 | // and turn initial Y back to y 194 | if (firstch == "y") 195 | w = firstch.toLowerCase() + w.substr(1); 196 | return w; 197 | } 198 | } 199 | 200 | 201 | 202 | 203 | var splitChars = (function() { 204 | var result = {}; 205 | var singles = [96, 180, 187, 191, 215, 247, 749, 885, 903, 907, 909, 930, 1014, 1648, 206 | 1748, 1809, 2416, 2473, 2481, 2526, 2601, 2609, 2612, 2615, 2653, 2702, 207 | 2706, 2729, 2737, 2740, 2857, 2865, 2868, 2910, 2928, 2948, 2961, 2971, 208 | 2973, 3085, 3089, 3113, 3124, 3213, 3217, 3241, 3252, 
3295, 3341, 3345, 209 | 3369, 3506, 3516, 3633, 3715, 3721, 3736, 3744, 3748, 3750, 3756, 3761, 210 | 3781, 3912, 4239, 4347, 4681, 4695, 4697, 4745, 4785, 4799, 4801, 4823, 211 | 4881, 5760, 5901, 5997, 6313, 7405, 8024, 8026, 8028, 8030, 8117, 8125, 212 | 8133, 8181, 8468, 8485, 8487, 8489, 8494, 8527, 11311, 11359, 11687, 11695, 213 | 11703, 11711, 11719, 11727, 11735, 12448, 12539, 43010, 43014, 43019, 43587, 214 | 43696, 43713, 64286, 64297, 64311, 64317, 64319, 64322, 64325, 65141]; 215 | var i, j, start, end; 216 | for (i = 0; i < singles.length; i++) { 217 | result[singles[i]] = true; 218 | } 219 | var ranges = [[0, 47], [58, 64], [91, 94], [123, 169], [171, 177], [182, 184], [706, 709], 220 | [722, 735], [741, 747], [751, 879], [888, 889], [894, 901], [1154, 1161], 221 | [1318, 1328], [1367, 1368], [1370, 1376], [1416, 1487], [1515, 1519], [1523, 1568], 222 | [1611, 1631], [1642, 1645], [1750, 1764], [1767, 1773], [1789, 1790], [1792, 1807], 223 | [1840, 1868], [1958, 1968], [1970, 1983], [2027, 2035], [2038, 2041], [2043, 2047], 224 | [2070, 2073], [2075, 2083], [2085, 2087], [2089, 2307], [2362, 2364], [2366, 2383], 225 | [2385, 2391], [2402, 2405], [2419, 2424], [2432, 2436], [2445, 2446], [2449, 2450], 226 | [2483, 2485], [2490, 2492], [2494, 2509], [2511, 2523], [2530, 2533], [2546, 2547], 227 | [2554, 2564], [2571, 2574], [2577, 2578], [2618, 2648], [2655, 2661], [2672, 2673], 228 | [2677, 2692], [2746, 2748], [2750, 2767], [2769, 2783], [2786, 2789], [2800, 2820], 229 | [2829, 2830], [2833, 2834], [2874, 2876], [2878, 2907], [2914, 2917], [2930, 2946], 230 | [2955, 2957], [2966, 2968], [2976, 2978], [2981, 2983], [2987, 2989], [3002, 3023], 231 | [3025, 3045], [3059, 3076], [3130, 3132], [3134, 3159], [3162, 3167], [3170, 3173], 232 | [3184, 3191], [3199, 3204], [3258, 3260], [3262, 3293], [3298, 3301], [3312, 3332], 233 | [3386, 3388], [3390, 3423], [3426, 3429], [3446, 3449], [3456, 3460], [3479, 3481], 234 | [3518, 3519], [3527, 3584], [3636, 3647], [3655, 3663], [3674, 3712], [3717, 3718], 235 | [3723, 3724], [3726, 3731], [3752, 3753], [3764, 3772], [3774, 3775], [3783, 3791], 236 | [3802, 3803], [3806, 3839], [3841, 3871], [3892, 3903], [3949, 3975], [3980, 4095], 237 | [4139, 4158], [4170, 4175], [4182, 4185], [4190, 4192], [4194, 4196], [4199, 4205], 238 | [4209, 4212], [4226, 4237], [4250, 4255], [4294, 4303], [4349, 4351], [4686, 4687], 239 | [4702, 4703], [4750, 4751], [4790, 4791], [4806, 4807], [4886, 4887], [4955, 4968], 240 | [4989, 4991], [5008, 5023], [5109, 5120], [5741, 5742], [5787, 5791], [5867, 5869], 241 | [5873, 5887], [5906, 5919], [5938, 5951], [5970, 5983], [6001, 6015], [6068, 6102], 242 | [6104, 6107], [6109, 6111], [6122, 6127], [6138, 6159], [6170, 6175], [6264, 6271], 243 | [6315, 6319], [6390, 6399], [6429, 6469], [6510, 6511], [6517, 6527], [6572, 6592], 244 | [6600, 6607], [6619, 6655], [6679, 6687], [6741, 6783], [6794, 6799], [6810, 6822], 245 | [6824, 6916], [6964, 6980], [6988, 6991], [7002, 7042], [7073, 7085], [7098, 7167], 246 | [7204, 7231], [7242, 7244], [7294, 7400], [7410, 7423], [7616, 7679], [7958, 7959], 247 | [7966, 7967], [8006, 8007], [8014, 8015], [8062, 8063], [8127, 8129], [8141, 8143], 248 | [8148, 8149], [8156, 8159], [8173, 8177], [8189, 8303], [8306, 8307], [8314, 8318], 249 | [8330, 8335], [8341, 8449], [8451, 8454], [8456, 8457], [8470, 8472], [8478, 8483], 250 | [8506, 8507], [8512, 8516], [8522, 8525], [8586, 9311], [9372, 9449], [9472, 10101], 251 | [10132, 11263], [11493, 11498], [11503, 11516], [11518, 
11519], [11558, 11567], 252 | [11622, 11630], [11632, 11647], [11671, 11679], [11743, 11822], [11824, 12292], 253 | [12296, 12320], [12330, 12336], [12342, 12343], [12349, 12352], [12439, 12444], 254 | [12544, 12548], [12590, 12592], [12687, 12689], [12694, 12703], [12728, 12783], 255 | [12800, 12831], [12842, 12880], [12896, 12927], [12938, 12976], [12992, 13311], 256 | [19894, 19967], [40908, 40959], [42125, 42191], [42238, 42239], [42509, 42511], 257 | [42540, 42559], [42592, 42593], [42607, 42622], [42648, 42655], [42736, 42774], 258 | [42784, 42785], [42889, 42890], [42893, 43002], [43043, 43055], [43062, 43071], 259 | [43124, 43137], [43188, 43215], [43226, 43249], [43256, 43258], [43260, 43263], 260 | [43302, 43311], [43335, 43359], [43389, 43395], [43443, 43470], [43482, 43519], 261 | [43561, 43583], [43596, 43599], [43610, 43615], [43639, 43641], [43643, 43647], 262 | [43698, 43700], [43703, 43704], [43710, 43711], [43715, 43738], [43742, 43967], 263 | [44003, 44015], [44026, 44031], [55204, 55215], [55239, 55242], [55292, 55295], 264 | [57344, 63743], [64046, 64047], [64110, 64111], [64218, 64255], [64263, 64274], 265 | [64280, 64284], [64434, 64466], [64830, 64847], [64912, 64913], [64968, 65007], 266 | [65020, 65135], [65277, 65295], [65306, 65312], [65339, 65344], [65371, 65381], 267 | [65471, 65473], [65480, 65481], [65488, 65489], [65496, 65497]]; 268 | for (i = 0; i < ranges.length; i++) { 269 | start = ranges[i][0]; 270 | end = ranges[i][1]; 271 | for (j = start; j <= end; j++) { 272 | result[j] = true; 273 | } 274 | } 275 | return result; 276 | })(); 277 | 278 | function splitQuery(query) { 279 | var result = []; 280 | var start = -1; 281 | for (var i = 0; i < query.length; i++) { 282 | if (splitChars[query.charCodeAt(i)]) { 283 | if (start !== -1) { 284 | result.push(query.slice(start, i)); 285 | start = -1; 286 | } 287 | } else if (start === -1) { 288 | start = i; 289 | } 290 | } 291 | if (start !== -1) { 292 | result.push(query.slice(start)); 293 | } 294 | return result; 295 | } 296 | 297 | 298 | -------------------------------------------------------------------------------- /docs/_static/minus.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/castorini/daam/c30493ed0154bfccb6c342400f25cc24599bb1ff/docs/_static/minus.png -------------------------------------------------------------------------------- /docs/_static/plus.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/castorini/daam/c30493ed0154bfccb6c342400f25cc24599bb1ff/docs/_static/plus.png -------------------------------------------------------------------------------- /docs/_static/pygments.css: -------------------------------------------------------------------------------- 1 | pre { line-height: 125%; } 2 | td.linenos .normal { color: inherit; background-color: transparent; padding-left: 5px; padding-right: 5px; } 3 | span.linenos { color: inherit; background-color: transparent; padding-left: 5px; padding-right: 5px; } 4 | td.linenos .special { color: #000000; background-color: #ffffc0; padding-left: 5px; padding-right: 5px; } 5 | span.linenos.special { color: #000000; background-color: #ffffc0; padding-left: 5px; padding-right: 5px; } 6 | .highlight .hll { background-color: #ffffcc } 7 | .highlight { background: #f8f8f8; } 8 | .highlight .c { color: #8f5902; font-style: italic } /* Comment */ 9 | .highlight .err { color: #a40000; border: 1px solid #ef2929 } /* Error */ 10 
| .highlight .g { color: #000000 } /* Generic */ 11 | .highlight .k { color: #004461; font-weight: bold } /* Keyword */ 12 | .highlight .l { color: #000000 } /* Literal */ 13 | .highlight .n { color: #000000 } /* Name */ 14 | .highlight .o { color: #582800 } /* Operator */ 15 | .highlight .x { color: #000000 } /* Other */ 16 | .highlight .p { color: #000000; font-weight: bold } /* Punctuation */ 17 | .highlight .ch { color: #8f5902; font-style: italic } /* Comment.Hashbang */ 18 | .highlight .cm { color: #8f5902; font-style: italic } /* Comment.Multiline */ 19 | .highlight .cp { color: #8f5902 } /* Comment.Preproc */ 20 | .highlight .cpf { color: #8f5902; font-style: italic } /* Comment.PreprocFile */ 21 | .highlight .c1 { color: #8f5902; font-style: italic } /* Comment.Single */ 22 | .highlight .cs { color: #8f5902; font-style: italic } /* Comment.Special */ 23 | .highlight .gd { color: #a40000 } /* Generic.Deleted */ 24 | .highlight .ge { color: #000000; font-style: italic } /* Generic.Emph */ 25 | .highlight .gr { color: #ef2929 } /* Generic.Error */ 26 | .highlight .gh { color: #000080; font-weight: bold } /* Generic.Heading */ 27 | .highlight .gi { color: #00A000 } /* Generic.Inserted */ 28 | .highlight .go { color: #888888 } /* Generic.Output */ 29 | .highlight .gp { color: #745334 } /* Generic.Prompt */ 30 | .highlight .gs { color: #000000; font-weight: bold } /* Generic.Strong */ 31 | .highlight .gu { color: #800080; font-weight: bold } /* Generic.Subheading */ 32 | .highlight .gt { color: #a40000; font-weight: bold } /* Generic.Traceback */ 33 | .highlight .kc { color: #004461; font-weight: bold } /* Keyword.Constant */ 34 | .highlight .kd { color: #004461; font-weight: bold } /* Keyword.Declaration */ 35 | .highlight .kn { color: #004461; font-weight: bold } /* Keyword.Namespace */ 36 | .highlight .kp { color: #004461; font-weight: bold } /* Keyword.Pseudo */ 37 | .highlight .kr { color: #004461; font-weight: bold } /* Keyword.Reserved */ 38 | .highlight .kt { color: #004461; font-weight: bold } /* Keyword.Type */ 39 | .highlight .ld { color: #000000 } /* Literal.Date */ 40 | .highlight .m { color: #990000 } /* Literal.Number */ 41 | .highlight .s { color: #4e9a06 } /* Literal.String */ 42 | .highlight .na { color: #c4a000 } /* Name.Attribute */ 43 | .highlight .nb { color: #004461 } /* Name.Builtin */ 44 | .highlight .nc { color: #000000 } /* Name.Class */ 45 | .highlight .no { color: #000000 } /* Name.Constant */ 46 | .highlight .nd { color: #888888 } /* Name.Decorator */ 47 | .highlight .ni { color: #ce5c00 } /* Name.Entity */ 48 | .highlight .ne { color: #cc0000; font-weight: bold } /* Name.Exception */ 49 | .highlight .nf { color: #000000 } /* Name.Function */ 50 | .highlight .nl { color: #f57900 } /* Name.Label */ 51 | .highlight .nn { color: #000000 } /* Name.Namespace */ 52 | .highlight .nx { color: #000000 } /* Name.Other */ 53 | .highlight .py { color: #000000 } /* Name.Property */ 54 | .highlight .nt { color: #004461; font-weight: bold } /* Name.Tag */ 55 | .highlight .nv { color: #000000 } /* Name.Variable */ 56 | .highlight .ow { color: #004461; font-weight: bold } /* Operator.Word */ 57 | .highlight .w { color: #f8f8f8; text-decoration: underline } /* Text.Whitespace */ 58 | .highlight .mb { color: #990000 } /* Literal.Number.Bin */ 59 | .highlight .mf { color: #990000 } /* Literal.Number.Float */ 60 | .highlight .mh { color: #990000 } /* Literal.Number.Hex */ 61 | .highlight .mi { color: #990000 } /* Literal.Number.Integer */ 62 | .highlight .mo { color: #990000 } 
/* Literal.Number.Oct */ 63 | .highlight .sa { color: #4e9a06 } /* Literal.String.Affix */ 64 | .highlight .sb { color: #4e9a06 } /* Literal.String.Backtick */ 65 | .highlight .sc { color: #4e9a06 } /* Literal.String.Char */ 66 | .highlight .dl { color: #4e9a06 } /* Literal.String.Delimiter */ 67 | .highlight .sd { color: #8f5902; font-style: italic } /* Literal.String.Doc */ 68 | .highlight .s2 { color: #4e9a06 } /* Literal.String.Double */ 69 | .highlight .se { color: #4e9a06 } /* Literal.String.Escape */ 70 | .highlight .sh { color: #4e9a06 } /* Literal.String.Heredoc */ 71 | .highlight .si { color: #4e9a06 } /* Literal.String.Interpol */ 72 | .highlight .sx { color: #4e9a06 } /* Literal.String.Other */ 73 | .highlight .sr { color: #4e9a06 } /* Literal.String.Regex */ 74 | .highlight .s1 { color: #4e9a06 } /* Literal.String.Single */ 75 | .highlight .ss { color: #4e9a06 } /* Literal.String.Symbol */ 76 | .highlight .bp { color: #3465a4 } /* Name.Builtin.Pseudo */ 77 | .highlight .fm { color: #000000 } /* Name.Function.Magic */ 78 | .highlight .vc { color: #000000 } /* Name.Variable.Class */ 79 | .highlight .vg { color: #000000 } /* Name.Variable.Global */ 80 | .highlight .vi { color: #000000 } /* Name.Variable.Instance */ 81 | .highlight .vm { color: #000000 } /* Name.Variable.Magic */ 82 | .highlight .il { color: #990000 } /* Literal.Number.Integer.Long */ -------------------------------------------------------------------------------- /docs/_static/underscore.js: -------------------------------------------------------------------------------- 1 | !function(n,r){"object"==typeof exports&&"undefined"!=typeof module?module.exports=r():"function"==typeof define&&define.amd?define("underscore",r):(n="undefined"!=typeof globalThis?globalThis:n||self,function(){var t=n._,e=n._=r();e.noConflict=function(){return n._=t,e}}())}(this,(function(){ 2 | // Underscore.js 1.13.1 3 | // https://underscorejs.org 4 | // (c) 2009-2021 Jeremy Ashkenas, Julian Gonggrijp, and DocumentCloud and Investigative Reporters & Editors 5 | // Underscore may be freely distributed under the MIT license. 
6 | var n="1.13.1",r="object"==typeof self&&self.self===self&&self||"object"==typeof global&&global.global===global&&global||Function("return this")()||{},t=Array.prototype,e=Object.prototype,u="undefined"!=typeof Symbol?Symbol.prototype:null,o=t.push,i=t.slice,a=e.toString,f=e.hasOwnProperty,c="undefined"!=typeof ArrayBuffer,l="undefined"!=typeof DataView,s=Array.isArray,p=Object.keys,v=Object.create,h=c&&ArrayBuffer.isView,y=isNaN,d=isFinite,g=!{toString:null}.propertyIsEnumerable("toString"),b=["valueOf","isPrototypeOf","toString","propertyIsEnumerable","hasOwnProperty","toLocaleString"],m=Math.pow(2,53)-1;function j(n,r){return r=null==r?n.length-1:+r,function(){for(var t=Math.max(arguments.length-r,0),e=Array(t),u=0;u=0&&t<=m}}function J(n){return function(r){return null==r?void 0:r[n]}}var G=J("byteLength"),H=K(G),Q=/\[object ((I|Ui)nt(8|16|32)|Float(32|64)|Uint8Clamped|Big(I|Ui)nt64)Array\]/;var X=c?function(n){return h?h(n)&&!q(n):H(n)&&Q.test(a.call(n))}:C(!1),Y=J("length");function Z(n,r){r=function(n){for(var r={},t=n.length,e=0;e":">",'"':""","'":"'","`":"`"},Cn=Ln($n),Kn=Ln(_n($n)),Jn=tn.templateSettings={evaluate:/<%([\s\S]+?)%>/g,interpolate:/<%=([\s\S]+?)%>/g,escape:/<%-([\s\S]+?)%>/g},Gn=/(.)^/,Hn={"'":"'","\\":"\\","\r":"r","\n":"n","\u2028":"u2028","\u2029":"u2029"},Qn=/\\|'|\r|\n|\u2028|\u2029/g;function Xn(n){return"\\"+Hn[n]}var Yn=/^\s*(\w|\$)+\s*$/;var Zn=0;function nr(n,r,t,e,u){if(!(e instanceof r))return n.apply(t,u);var o=Mn(n.prototype),i=n.apply(o,u);return _(i)?i:o}var rr=j((function(n,r){var t=rr.placeholder,e=function(){for(var u=0,o=r.length,i=Array(o),a=0;a1)ur(a,r-1,t,e),u=e.length;else for(var f=0,c=a.length;f0&&(t=r.apply(this,arguments)),n<=1&&(r=null),t}}var lr=rr(cr,2);function sr(n,r,t){r=qn(r,t);for(var e,u=nn(n),o=0,i=u.length;o0?0:u-1;o>=0&&o0?a=o>=0?o:Math.max(o+f,a):f=o>=0?Math.min(o+1,f):o+f+1;else if(t&&o&&f)return e[o=t(e,u)]===u?o:-1;if(u!=u)return(o=r(i.call(e,a,f),$))>=0?o+a:-1;for(o=n>0?a:f-1;o>=0&&o0?0:i-1;for(u||(e=r[o?o[a]:a],a+=n);a>=0&&a=3;return r(n,Fn(t,u,4),e,o)}}var Ar=wr(1),xr=wr(-1);function Sr(n,r,t){var e=[];return r=qn(r,t),jr(n,(function(n,t,u){r(n,t,u)&&e.push(n)})),e}function Or(n,r,t){r=qn(r,t);for(var e=!er(n)&&nn(n),u=(e||n).length,o=0;o=0}var Br=j((function(n,r,t){var e,u;return D(r)?u=r:(r=Nn(r),e=r.slice(0,-1),r=r[r.length-1]),_r(n,(function(n){var o=u;if(!o){if(e&&e.length&&(n=In(n,e)),null==n)return;o=n[r]}return null==o?o:o.apply(n,t)}))}));function Nr(n,r){return _r(n,Rn(r))}function Ir(n,r,t){var e,u,o=-1/0,i=-1/0;if(null==r||"number"==typeof r&&"object"!=typeof n[0]&&null!=n)for(var a=0,f=(n=er(n)?n:jn(n)).length;ao&&(o=e);else r=qn(r,t),jr(n,(function(n,t,e){((u=r(n,t,e))>i||u===-1/0&&o===-1/0)&&(o=n,i=u)}));return o}function Tr(n,r,t){if(null==r||t)return er(n)||(n=jn(n)),n[Wn(n.length-1)];var e=er(n)?En(n):jn(n),u=Y(e);r=Math.max(Math.min(r,u),0);for(var o=u-1,i=0;i1&&(e=Fn(e,r[1])),r=an(n)):(e=qr,r=ur(r,!1,!1),n=Object(n));for(var u=0,o=r.length;u1&&(t=r[1])):(r=_r(ur(r,!1,!1),String),e=function(n,t){return!Er(r,t)}),Ur(n,e,t)}));function zr(n,r,t){return i.call(n,0,Math.max(0,n.length-(null==r||t?1:r)))}function Lr(n,r,t){return null==n||n.length<1?null==r||t?void 0:[]:null==r||t?n[0]:zr(n,n.length-r)}function $r(n,r,t){return i.call(n,null==r||t?1:r)}var Cr=j((function(n,r){return r=ur(r,!0,!0),Sr(n,(function(n){return!Er(r,n)}))})),Kr=j((function(n,r){return Cr(n,r)}));function Jr(n,r,t,e){A(r)||(e=t,t=r,r=!1),null!=t&&(t=qn(t,e));for(var 
u=[],o=[],i=0,a=Y(n);ir?(e&&(clearTimeout(e),e=null),a=c,i=n.apply(u,o),e||(u=o=null)):e||!1===t.trailing||(e=setTimeout(f,l)),i};return c.cancel=function(){clearTimeout(e),a=0,e=u=o=null},c},debounce:function(n,r,t){var e,u,o,i,a,f=function(){var c=zn()-u;r>c?e=setTimeout(f,r-c):(e=null,t||(i=n.apply(a,o)),e||(o=a=null))},c=j((function(c){return a=this,o=c,u=zn(),e||(e=setTimeout(f,r),t&&(i=n.apply(a,o))),i}));return c.cancel=function(){clearTimeout(e),e=o=a=null},c},wrap:function(n,r){return rr(r,n)},negate:fr,compose:function(){var n=arguments,r=n.length-1;return function(){for(var t=r,e=n[r].apply(this,arguments);t--;)e=n[t].call(this,e);return e}},after:function(n,r){return function(){if(--n<1)return r.apply(this,arguments)}},before:cr,once:lr,findKey:sr,findIndex:vr,findLastIndex:hr,sortedIndex:yr,indexOf:gr,lastIndexOf:br,find:mr,detect:mr,findWhere:function(n,r){return mr(n,Dn(r))},each:jr,forEach:jr,map:_r,collect:_r,reduce:Ar,foldl:Ar,inject:Ar,reduceRight:xr,foldr:xr,filter:Sr,select:Sr,reject:function(n,r,t){return Sr(n,fr(qn(r)),t)},every:Or,all:Or,some:Mr,any:Mr,contains:Er,includes:Er,include:Er,invoke:Br,pluck:Nr,where:function(n,r){return Sr(n,Dn(r))},max:Ir,min:function(n,r,t){var e,u,o=1/0,i=1/0;if(null==r||"number"==typeof r&&"object"!=typeof n[0]&&null!=n)for(var a=0,f=(n=er(n)?n:jn(n)).length;ae||void 0===t)return 1;if(t 3 | 4 | 5 | 6 | 7 | 8 | daam.run package — DAAM v0.0.4 documentation 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 |
[rendered Sphinx HTML page; markup stripped in this dump, leaving only bare line numbers and text nodes. Recoverable text: "daam.run package", then Submodules: daam.run.annotate.main(); daam.run.daam_to_mask.main(); daam.run.evaluate.main(); daam.run.filter_coco.main(); daam.run.generate.build_word_list_large() -> Dict[str, List[str]] and daam.run.generate.main(); daam.run.test_literacy.main(); daam.run.test_numeracy.main(); then "Module contents". The reST source follows in docs/daam.run.rst.]
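The page above lists only the entry-point signatures. As a rough illustration of how that surface is used from Python: the zero-argument call below is taken from the documented signature, while the sys.argv behaviour of main() is an assumption, not something this dump confirms.

    from daam.run import generate

    # Documented above as taking no arguments and returning Dict[str, List[str]].
    word_lists = generate.build_word_list_large()

    # generate.main() is the callable that setup.py (end of this dump) wires to
    # the `daam` console script; it presumably parses its own sys.argv
    # (assumption), so it is normally run from the shell rather than imported.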
151 | 162 | 163 | 164 | 165 | 166 | 167 | -------------------------------------------------------------------------------- /docs/daam.run.rst: -------------------------------------------------------------------------------- 1 | daam.run package 2 | ================ 3 | 4 | Submodules 5 | ---------- 6 | 7 | daam.run.annotate module 8 | ------------------------ 9 | 10 | .. automodule:: daam.run.annotate 11 | :members: 12 | :undoc-members: 13 | :show-inheritance: 14 | 15 | daam.run.daam\_to\_mask module 16 | ------------------------------ 17 | 18 | .. automodule:: daam.run.daam_to_mask 19 | :members: 20 | :undoc-members: 21 | :show-inheritance: 22 | 23 | daam.run.evaluate module 24 | ------------------------ 25 | 26 | .. automodule:: daam.run.evaluate 27 | :members: 28 | :undoc-members: 29 | :show-inheritance: 30 | 31 | daam.run.filter\_coco module 32 | ---------------------------- 33 | 34 | .. automodule:: daam.run.filter_coco 35 | :members: 36 | :undoc-members: 37 | :show-inheritance: 38 | 39 | daam.run.generate module 40 | ------------------------ 41 | 42 | .. automodule:: daam.run.generate 43 | :members: 44 | :undoc-members: 45 | :show-inheritance: 46 | 47 | daam.run.test\_literacy module 48 | ------------------------------ 49 | 50 | .. automodule:: daam.run.test_literacy 51 | :members: 52 | :undoc-members: 53 | :show-inheritance: 54 | 55 | daam.run.test\_numeracy module 56 | ------------------------------ 57 | 58 | .. automodule:: daam.run.test_numeracy 59 | :members: 60 | :undoc-members: 61 | :show-inheritance: 62 | 63 | Module contents 64 | --------------- 65 | 66 | .. automodule:: daam.run 67 | :members: 68 | :undoc-members: 69 | :show-inheritance: 70 | -------------------------------------------------------------------------------- /docs/index.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | Welcome to DAAM. — DAAM v0.0.4 documentation 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 |
[rendered Sphinx HTML page; markup stripped in this dump. Recoverable text: "Welcome to DAAM." followed by "DAAM is an attribution method and toolkit for interpreting the Stable Diffusion model." and the heading "Indices and tables". The reST source follows in docs/index.rst.]
95 | 106 | 107 | 108 | 109 | 110 | 111 | -------------------------------------------------------------------------------- /docs/index.rst: -------------------------------------------------------------------------------- 1 | 2 | Welcome to DAAM. 3 | =================================== 4 | `DAAM `_ is an attribution method and toolkit for interpreting the Stable Diffusion model. 5 | 6 | .. toctree:: 7 | :maxdepth: 2 8 | :caption: Contents: 9 | 10 | 11 | 12 | Indices and tables 13 | ================== 14 | 15 | * :ref:`genindex` 16 | * :ref:`modindex` 17 | * :ref:`search` 18 | -------------------------------------------------------------------------------- /docs/make.sh: -------------------------------------------------------------------------------- 1 | rm daam.*rst daam.rst 2 | rm -rf _build 3 | sphinx-apidoc -f -o . ../daam 4 | make html 5 | make github 6 | 7 | -------------------------------------------------------------------------------- /docs/modules.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | daam — DAAM v0.0.4 documentation 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 |
113 | 124 | 125 | 126 | 127 | 128 | 129 | -------------------------------------------------------------------------------- /docs/modules.rst: -------------------------------------------------------------------------------- 1 | daam 2 | ==== 3 | 4 | .. toctree:: 5 | :maxdepth: 4 6 | 7 | daam 8 | -------------------------------------------------------------------------------- /docs/objects.inv: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/castorini/daam/c30493ed0154bfccb6c342400f25cc24599bb1ff/docs/objects.inv -------------------------------------------------------------------------------- /docs/py-modindex.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | Python Module Index — DAAM v0.0.4 documentation 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 |
[rendered Sphinx HTML page; markup stripped in this dump. Recoverable text: "Python Module Index" with entries daam, daam.evaluate, daam.experiment, daam.hook, daam.run, daam.run.annotate, daam.run.daam_to_mask, daam.run.evaluate, daam.run.filter_coco, daam.run.generate, daam.run.test_literacy, daam.run.test_numeracy, daam.trace, daam.utils.]
167 | 175 | 176 | 177 | 178 | 179 | 180 | -------------------------------------------------------------------------------- /docs/search.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | Search — DAAM v0.0.4 documentation 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 |
[rendered Sphinx HTML page; markup stripped in this dump. Recoverable text: "Search", plus the hint "Searching for multiple words only shows matches that contain all words."]
106 | 114 | 115 | 116 | 117 | 118 | 119 | -------------------------------------------------------------------------------- /docs/searchindex.js: -------------------------------------------------------------------------------- 1 | Search.setIndex({docnames:["daam","daam.run","index","modules"],envversion:{"sphinx.domains.c":2,"sphinx.domains.changeset":1,"sphinx.domains.citation":1,"sphinx.domains.cpp":5,"sphinx.domains.index":1,"sphinx.domains.javascript":2,"sphinx.domains.math":2,"sphinx.domains.python":3,"sphinx.domains.rst":2,"sphinx.domains.std":2,sphinx:56},filenames:["daam.rst","daam.run.rst","index.rst","modules.rst"],objects:{"":[[0,0,0,"-","daam"]],"daam.evaluate":[[0,1,1,"","MeanEvaluator"],[0,4,1,"","compute_iou"]],"daam.evaluate.MeanEvaluator":[[0,2,1,"","log_intensity"],[0,2,1,"","log_iou"],[0,3,1,"","mean_intensity"],[0,3,1,"","mean_iou"]],"daam.experiment":[[0,1,1,"","GenerationExperiment"],[0,4,1,"","build_word_list_coco80"]],"daam.experiment.GenerationExperiment":[[0,2,1,"","annotate"],[0,5,1,"","annotations"],[0,2,1,"","clear_prediction_masks"],[0,2,1,"","contains_truth_mask"],[0,5,1,"","global_heat_map"],[0,2,1,"","has_annotations"],[0,2,1,"","has_experiment"],[0,5,1,"","id"],[0,5,1,"","image"],[0,2,1,"","load"],[0,2,1,"","nsfw"],[0,5,1,"","path"],[0,5,1,"","prediction_masks"],[0,5,1,"","prompt"],[0,2,1,"","read_prompt"],[0,2,1,"","read_seed"],[0,2,1,"","save"],[0,2,1,"","save_annotations"],[0,2,1,"","save_heat_map"],[0,2,1,"","save_prediction_mask"],[0,5,1,"","seed"],[0,5,1,"","truth_masks"]],"daam.hook":[[0,1,1,"","AggregateHooker"],[0,1,1,"","ModuleLocator"],[0,1,1,"","ObjectHooker"],[0,1,1,"","UNetCrossAttentionLocator"]],"daam.hook.AggregateHooker":[[0,2,1,"","register_hook"]],"daam.hook.ModuleLocator":[[0,2,1,"","locate"]],"daam.hook.ObjectHooker":[[0,2,1,"","hook"],[0,2,1,"","monkey_patch"],[0,2,1,"","monkey_super"],[0,2,1,"","unhook"]],"daam.hook.UNetCrossAttentionLocator":[[0,2,1,"","locate"]],"daam.run":[[1,0,0,"-","annotate"],[1,0,0,"-","daam_to_mask"],[1,0,0,"-","evaluate"],[1,0,0,"-","filter_coco"],[1,0,0,"-","generate"],[1,0,0,"-","test_literacy"],[1,0,0,"-","test_numeracy"]],"daam.run.annotate":[[1,4,1,"","main"]],"daam.run.daam_to_mask":[[1,4,1,"","main"]],"daam.run.evaluate":[[1,4,1,"","main"]],"daam.run.filter_coco":[[1,4,1,"","main"]],"daam.run.generate":[[1,4,1,"","build_word_list_large"],[1,4,1,"","main"]],"daam.run.test_literacy":[[1,4,1,"","main"]],"daam.run.test_numeracy":[[1,4,1,"","main"]],"daam.trace":[[0,1,1,"","DiffusionHeatMapHooker"],[0,1,1,"","HeatMap"],[0,1,1,"","MmDetectHeatMap"],[0,5,1,"","trace"]],"daam.trace.DiffusionHeatMapHooker":[[0,3,1,"","all_heat_maps"],[0,2,1,"","compute_global_heat_map"]],"daam.trace.HeatMap":[[0,2,1,"","compute_word_heat_map"]],"daam.trace.MmDetectHeatMap":[[0,2,1,"","compute_word_heat_map"]],"daam.utils":[[0,4,1,"","compute_token_merge_indices"],[0,4,1,"","expand_image"],[0,4,1,"","plot_mask_heat_map"],[0,4,1,"","plot_overlay_heat_map"],[0,4,1,"","set_seed"]],daam:[[0,0,0,"-","evaluate"],[0,0,0,"-","experiment"],[0,0,0,"-","hook"],[1,0,0,"-","run"],[0,0,0,"-","trace"],[0,0,0,"-","utils"]]},objnames:{"0":["py","module","Python module"],"1":["py","class","Python class"],"2":["py","method","Python method"],"3":["py","property","Python property"],"4":["py","function","Python function"],"5":["py","attribute","Python 
attribute"]},objtypes:{"0":"py:module","1":"py:class","2":"py:method","3":"py:property","4":"py:function","5":"py:attribute"},terms:{"0":0,"4":0,"512":0,"95":0,"class":0,"float":0,"int":0,"return":0,"static":0,If:0,The:0,_c:0,absolut:0,across:0,aggreg:0,aggregatehook:0,alia:0,all:0,all_heat_map:0,an:2,ani:0,annot:[0,3],appli:0,applic:0,ar:0,arg:0,attent:0,attribut:2,b:0,base:0,block:0,bool:0,build_word_list_coco80:0,build_word_list_larg:1,classmethod:0,clear_prediction_mask:0,composit:0,comput:0,compute_global_heat_map:0,compute_i:0,compute_token_merge_indic:0,compute_word_heat_map:0,contains_truth_mask:0,content:3,cross:0,crossattent:0,daam_to_mask:[0,3],dict:[0,1],differ:0,diffus:[0,2],diffusionheatmaphook:0,each:0,equal:0,evalu:3,expand_imag:0,experi:3,factor:0,fals:0,filter_coco:[0,3],fn:0,fn_name:0,gener:[0,3],generationexperi:0,given:0,global:0,global_heat_map:0,has_annot:0,has_experi:0,heat:0,heat_map:0,heatmap:0,hold:0,hook:3,id:0,im:0,imag:0,index:2,infer:0,interpret:2,kei:0,kwarg:0,last:0,last_n:0,list:[0,1],load:0,locat:0,log_intens:0,log_iou:0,main:1,map:0,mask:0,mean_intens:0,mean_iou:0,meanevalu:0,method:2,mmdetectheatmap:0,model:[0,2],modul:[2,3],modulelisttyp:0,moduleloc:0,moduletyp:0,monkey_patch:0,monkey_sup:0,n:0,name:0,ndarrai:0,nn:0,none:0,nsfw:0,number:0,numpi:0,object:0,objecthook:0,option:0,out:0,out_fil:0,packag:3,page:2,paramet:0,path:0,pathlib:0,pickleabl:0,pil:0,pipelin:0,pipeline_stable_diffus:0,plot_mask_heat_map:0,plot_overlay_heat_map:0,pred:0,pred_fil:0,pred_prefix:0,prediction_mask:0,pretrainedtoken:0,prompt:0,prompt_id:0,properti:0,read_prompt:0,read_se:0,register_hook:0,restrict:0,run:[0,3],save:0,save_annot:0,save_heat_map:0,save_prediction_mask:0,search:2,seed:0,set:0,set_se:0,simplify80:0,size:0,space:0,spatial:0,stabl:2,stable_diffus:0,stablediffusionpipelin:0,step:0,str:[0,1],submodul:3,subpackag:3,tensor:0,test_literaci:[0,3],test_numeraci:[0,3],thi:0,threshold:0,time:0,time_idx:0,time_weight:0,token:0,tokenization_util:0,toolkit:2,torch:0,trace:3,transform:0,truth:0,truth_mask:0,type:0,unet2dconditionmodel:0,unet_2d_condit:0,unetcrossattentionloc:0,unhook:0,union:0,us:0,util:3,valu:0,vocab:0,weight:0,word:0,word_idx:0},titles:["daam package","daam.run package","Welcome to DAAM.","daam"],titleterms:{annot:1,content:[0,1],daam:[0,1,2,3],daam_to_mask:1,evalu:[0,1],experi:0,filter_coco:1,gener:1,hook:0,indic:2,modul:[0,1],packag:[0,1],run:1,submodul:[0,1],subpackag:0,tabl:2,test_literaci:1,test_numeraci:1,trace:0,util:0,welcom:2}}) -------------------------------------------------------------------------------- /example.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/castorini/daam/c30493ed0154bfccb6c342400f25cc24599bb1ff/example.jpg -------------------------------------------------------------------------------- /notebooks/0-setup.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "65bf851a", 6 | "metadata": {}, 7 | "source": [ 8 | "# Installation of Prerequisites" 9 | ] 10 | }, 11 | { 12 | "cell_type": "code", 13 | "execution_count": 1, 14 | "id": "4616b907", 15 | "metadata": { 16 | "scrolled": true 17 | }, 18 | "outputs": [ 19 | { 20 | "name": "stdout", 21 | "output_type": "stream", 22 | "text": [ 23 | "--2023-07-07 00:30:51-- https://nlp.stanford.edu/software/stanford-corenlp-4.5.4.zip\n", 24 | "Resolving nlp.stanford.edu (nlp.stanford.edu)... 
171.64.67.140\n", 25 | "Connecting to nlp.stanford.edu (nlp.stanford.edu)|171.64.67.140|:443... connected.\n", 26 | "HTTP request sent, awaiting response... 302 FOUND\n", 27 | "Location: https://downloads.cs.stanford.edu/nlp/software/stanford-corenlp-4.5.4.zip [following]\n", 28 | "--2023-07-07 00:30:52-- https://downloads.cs.stanford.edu/nlp/software/stanford-corenlp-4.5.4.zip\n", 29 | "Resolving downloads.cs.stanford.edu (downloads.cs.stanford.edu)... 171.64.64.22\n", 30 | "Connecting to downloads.cs.stanford.edu (downloads.cs.stanford.edu)|171.64.64.22|:443... connected.\n", 31 | "HTTP request sent, awaiting response... 200 OK\n", 32 | "Length: 506470124 (483M) [application/zip]\n", 33 | "Saving to: ‘stanford-corenlp-4.5.4.zip’\n", 34 | "\n", 35 | "orenlp-4.5.4.zip 18%[==> ] 88.40M 5.08MB/s eta 60s ^C\n", 36 | "Archive: stanford-corenlp-4.5.4.zip\n", 37 | " End-of-central-directory signature not found. Either this file is not\n", 38 | " a zipfile, or it constitutes one disk of a multi-part archive. In the\n", 39 | " latter case the central directory and zipfile comment will be found on\n", 40 | " the last disk(s) of this archive.\n", 41 | "unzip: cannot find zipfile directory in one of stanford-corenlp-4.5.4.zip or\n", 42 | " stanford-corenlp-4.5.4.zip.zip, and cannot find stanford-corenlp-4.5.4.zip.ZIP, period.\n" 43 | ] 44 | } 45 | ], 46 | "source": [ 47 | "!wget https://nlp.stanford.edu/software/stanford-corenlp-4.5.4.zip\n", 48 | "!unzip stanford-corenlp-4.5.4.zip" 49 | ] 50 | }, 51 | { 52 | "cell_type": "code", 53 | "execution_count": 2, 54 | "id": "956dcdf1", 55 | "metadata": { 56 | "scrolled": true 57 | }, 58 | "outputs": [ 59 | { 60 | "name": "stdout", 61 | "output_type": "stream", 62 | "text": [ 63 | "Collecting stanza\n", 64 | " Downloading stanza-1.5.0-py3-none-any.whl (802 kB)\n", 65 | "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m802.5/802.5 kB\u001b[0m \u001b[31m4.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m00:01\u001b[0m00:01\u001b[0m\n", 66 | "\u001b[?25hCollecting emoji (from stanza)\n", 67 | " Downloading emoji-2.6.0.tar.gz (356 kB)\n", 68 | "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m356.6/356.6 kB\u001b[0m \u001b[31m4.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m00:01\u001b[0m00:01\u001b[0m\n", 69 | "\u001b[?25h Preparing metadata (setup.py) ... 
\u001b[?25ldone\n", 70 | "\u001b[?25hRequirement already satisfied: numpy in /home/ralph/miniconda3/envs/daam/lib/python3.8/site-packages (from stanza) (1.24.3)\n", 71 | "Collecting protobuf (from stanza)\n", 72 | " Downloading protobuf-4.23.4-cp37-abi3-manylinux2014_x86_64.whl (304 kB)\n", 73 | "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m304.5/304.5 kB\u001b[0m \u001b[31m4.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m00:01\u001b[0m00:01\u001b[0m\n", 74 | "\u001b[?25hRequirement already satisfied: requests in /home/ralph/miniconda3/envs/daam/lib/python3.8/site-packages (from stanza) (2.29.0)\n", 75 | "Requirement already satisfied: six in /home/ralph/miniconda3/envs/daam/lib/python3.8/site-packages (from stanza) (1.16.0)\n", 76 | "Requirement already satisfied: torch>=1.3.0 in /home/ralph/miniconda3/envs/daam/lib/python3.8/site-packages (from stanza) (2.0.1)\n", 77 | "Requirement already satisfied: tqdm in /home/ralph/miniconda3/envs/daam/lib/python3.8/site-packages (from stanza) (4.65.0)\n", 78 | "Requirement already satisfied: filelock in /home/ralph/miniconda3/envs/daam/lib/python3.8/site-packages (from torch>=1.3.0->stanza) (3.9.0)\n", 79 | "Requirement already satisfied: typing-extensions in /home/ralph/miniconda3/envs/daam/lib/python3.8/site-packages (from torch>=1.3.0->stanza) (4.6.3)\n", 80 | "Requirement already satisfied: sympy in /home/ralph/miniconda3/envs/daam/lib/python3.8/site-packages (from torch>=1.3.0->stanza) (1.11.1)\n", 81 | "Requirement already satisfied: networkx in /home/ralph/miniconda3/envs/daam/lib/python3.8/site-packages (from torch>=1.3.0->stanza) (2.8.4)\n", 82 | "Requirement already satisfied: jinja2 in /home/ralph/miniconda3/envs/daam/lib/python3.8/site-packages (from torch>=1.3.0->stanza) (3.1.2)\n", 83 | "Requirement already satisfied: charset-normalizer<4,>=2 in /home/ralph/miniconda3/envs/daam/lib/python3.8/site-packages (from requests->stanza) (2.0.4)\n", 84 | "Requirement already satisfied: idna<4,>=2.5 in /home/ralph/miniconda3/envs/daam/lib/python3.8/site-packages (from requests->stanza) (3.4)\n", 85 | "Requirement already satisfied: urllib3<1.27,>=1.21.1 in /home/ralph/miniconda3/envs/daam/lib/python3.8/site-packages (from requests->stanza) (1.26.16)\n", 86 | "Requirement already satisfied: certifi>=2017.4.17 in /home/ralph/miniconda3/envs/daam/lib/python3.8/site-packages (from requests->stanza) (2023.5.7)\n", 87 | "Requirement already satisfied: MarkupSafe>=2.0 in /home/ralph/miniconda3/envs/daam/lib/python3.8/site-packages (from jinja2->torch>=1.3.0->stanza) (2.1.1)\n", 88 | "Requirement already satisfied: mpmath>=0.19 in /home/ralph/miniconda3/envs/daam/lib/python3.8/site-packages (from sympy->torch>=1.3.0->stanza) (1.2.1)\n", 89 | "Building wheels for collected packages: emoji\n", 90 | " Building wheel for emoji (setup.py) ... 
\u001b[?25ldone\n", 91 | "\u001b[?25h Created wheel for emoji: filename=emoji-2.6.0-py2.py3-none-any.whl size=351312 sha256=438edf73fbaa4e062879aa454062e0f1172c51d84c7df39f189be5af4a37dd92\n", 92 | " Stored in directory: /home/ralph/.cache/pip/wheels/65/d8/90/e78a11fccc67c1983e5496ee1b6c831bce3185ed9dec4cd2c2\n", 93 | "Successfully built emoji\n", 94 | "Installing collected packages: protobuf, emoji, stanza\n", 95 | "Successfully installed emoji-2.6.0 protobuf-4.23.4 stanza-1.5.0\n" 96 | ] 97 | } 98 | ], 99 | "source": [ 100 | "!pip install stanza" 101 | ] 102 | }, 103 | { 104 | "cell_type": "code", 105 | "execution_count": 3, 106 | "id": "8bf1406f", 107 | "metadata": {}, 108 | "outputs": [ 109 | { 110 | "name": "stderr", 111 | "output_type": "stream", 112 | "text": [ 113 | "2023-07-07 00:32:22 WARNING: Directory stanford-corenlp-4.5.4 already exists. Please install CoreNLP to a new directory.\n" 114 | ] 115 | } 116 | ], 117 | "source": [ 118 | "import stanza\n", 119 | "stanza.install_corenlp(dir='stanford-corenlp-4.5.4')" 120 | ] 121 | }, 122 | { 123 | "cell_type": "code", 124 | "execution_count": null, 125 | "id": "321801f0", 126 | "metadata": {}, 127 | "outputs": [], 128 | "source": [ 129 | "!pip install daam==0.1.0" 130 | ] 131 | }, 132 | { 133 | "cell_type": "code", 134 | "execution_count": 4, 135 | "id": "ec55a0b3", 136 | "metadata": {}, 137 | "outputs": [ 138 | { 139 | "name": "stdout", 140 | "output_type": "stream", 141 | "text": [ 142 | "--2023-07-07 00:33:02-- http://images.cocodataset.org/annotations/annotations_trainval2014.zip\n", 143 | "Resolving images.cocodataset.org (images.cocodataset.org)... 16.182.64.121, 52.216.220.249, 54.231.170.209, ...\n", 144 | "Connecting to images.cocodataset.org (images.cocodataset.org)|16.182.64.121|:80... connected.\n", 145 | "HTTP request sent, awaiting response... 
200 OK\n", 146 | "Length: 252872794 (241M) [application/zip]\n", 147 | "Saving to: ‘annotations_trainval2014.zip’\n", 148 | "\n", 149 | "annotations_trainva 100%[===================>] 241.16M 34.5MB/s in 9.6s \n", 150 | "\n", 151 | "2023-07-07 00:33:12 (25.1 MB/s) - ‘annotations_trainval2014.zip’ saved [252872794/252872794]\n", 152 | "\n" 153 | ] 154 | } 155 | ], 156 | "source": [ 157 | "!wget http://images.cocodataset.org/annotations/annotations_trainval2014.zip" 158 | ] 159 | }, 160 | { 161 | "cell_type": "code", 162 | "execution_count": 5, 163 | "id": "1ed5edd7", 164 | "metadata": {}, 165 | "outputs": [], 166 | "source": [ 167 | "!mkdir -p coco\n", 168 | "!mv annotations_* coco" 169 | ] 170 | }, 171 | { 172 | "cell_type": "code", 173 | "execution_count": 9, 174 | "id": "2f6ed111", 175 | "metadata": {}, 176 | "outputs": [ 177 | { 178 | "name": "stdout", 179 | "output_type": "stream", 180 | "text": [ 181 | "[Errno 2] No such file or directory: 'coco'\n", 182 | "/home/ralph/programming/daam/notebooks/coco\n", 183 | "Archive: annotations_trainval2014.zip\n", 184 | " inflating: annotations/instances_train2014.json \n", 185 | " inflating: annotations/instances_val2014.json \n", 186 | " inflating: annotations/person_keypoints_train2014.json \n", 187 | " inflating: annotations/person_keypoints_val2014.json \n", 188 | " inflating: annotations/captions_train2014.json \n", 189 | " inflating: annotations/captions_val2014.json \n" 190 | ] 191 | } 192 | ], 193 | "source": [ 194 | "%cd coco\n", 195 | "!unzip annotations_*" 196 | ] 197 | } 198 | ], 199 | "metadata": { 200 | "kernelspec": { 201 | "display_name": "Python 3 (ipykernel)", 202 | "language": "python", 203 | "name": "python3" 204 | }, 205 | "language_info": { 206 | "codemirror_mode": { 207 | "name": "ipython", 208 | "version": 3 209 | }, 210 | "file_extension": ".py", 211 | "mimetype": "text/x-python", 212 | "name": "python", 213 | "nbconvert_exporter": "python", 214 | "pygments_lexer": "ipython3", 215 | "version": "3.8.17" 216 | } 217 | }, 218 | "nbformat": 4, 219 | "nbformat_minor": 5 220 | } 221 | -------------------------------------------------------------------------------- /notebooks/1-visuosyntactic-analyses.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": { 6 | "collapsed": true 7 | }, 8 | "source": [ 9 | "# Visuosyntactic Analyses" 10 | ] 11 | }, 12 | { 13 | "cell_type": "code", 14 | "execution_count": 1, 15 | "metadata": {}, 16 | "outputs": [ 17 | { 18 | "name": "stdout", 19 | "output_type": "stream", 20 | "text": [ 21 | "env: CORENLP_HOME=stanford-corenlp-4.5.4\n" 22 | ] 23 | } 24 | ], 25 | "source": [ 26 | "%env CORENLP_HOME=stanford-corenlp-4.5.4" 27 | ] 28 | }, 29 | { 30 | "cell_type": "code", 31 | "execution_count": 2, 32 | "metadata": {}, 33 | "outputs": [ 34 | { 35 | "name": "stderr", 36 | "output_type": "stream", 37 | "text": [ 38 | "2023-07-07 00:32:43 WARNING: Directory stanford-corenlp-4.5.4 already exists. 
Please install CoreNLP to a new directory.\n", 39 | "2023-07-07 00:32:43 INFO: Writing properties to tmp file: corenlp_server-8e8c4b3e34ad4e6a.props\n" 40 | ] 41 | } 42 | ], 43 | "source": [ 44 | "from stanza.server import CoreNLPClient\n", 45 | "import stanza\n", 46 | "\n", 47 | "stanza.install_corenlp(dir='stanford-corenlp-4.5.4')\n", 48 | "client = CoreNLPClient(annotators=['tokenize', 'ssplit', 'pos', 'lemma', 'ner', 'parse', 'depparse','coref'], timeout=30000, memory='6G')" 49 | ] 50 | }, 51 | { 52 | "cell_type": "markdown", 53 | "metadata": {}, 54 | "source": [ 55 | "# Generate DAAM Maps" 56 | ] 57 | }, 58 | { 59 | "cell_type": "code", 60 | "execution_count": 7, 61 | "metadata": {}, 62 | "outputs": [], 63 | "source": [ 64 | "from pathlib import Path\n", 65 | "import json\n", 66 | "\n", 67 | "annotations = json.load(Path('coco/annotations/captions_val2014.json').open())" 68 | ] 69 | }, 70 | { 71 | "cell_type": "code", 72 | "execution_count": 11, 73 | "metadata": {}, 74 | "outputs": [ 75 | { 76 | "data": { 77 | "text/plain": [ 78 | "dict_keys(['info', 'images', 'licenses', 'annotations'])" 79 | ] 80 | }, 81 | "execution_count": 11, 82 | "metadata": {}, 83 | "output_type": "execute_result" 84 | } 85 | ], 86 | "source": [ 87 | "annotations.keys()" 88 | ] 89 | }, 90 | { 91 | "cell_type": "code", 92 | "execution_count": 13, 93 | "metadata": {}, 94 | "outputs": [], 95 | "source": [ 96 | "import pandas as pd\n", 97 | "\n", 98 | "df = pd.DataFrame(annotations['annotations'])" 99 | ] 100 | }, 101 | { 102 | "cell_type": "code", 103 | "execution_count": 15, 104 | "metadata": {}, 105 | "outputs": [], 106 | "source": [ 107 | "!mkdir -p experiments/visuosyntax" 108 | ] 109 | }, 110 | { 111 | "cell_type": "code", 112 | "execution_count": 17, 113 | "metadata": {}, 114 | "outputs": [], 115 | "source": [ 116 | "df = df.sample(1500, replace=False)" 117 | ] 118 | }, 119 | { 120 | "cell_type": "code", 121 | "execution_count": 22, 122 | "metadata": {}, 123 | "outputs": [], 124 | "source": [ 125 | "import torch\n", 126 | "\n", 127 | "torch.cuda.amp.autocast().__enter__()\n", 128 | "torch.set_grad_enabled(False);" 129 | ] 130 | }, 131 | { 132 | "cell_type": "code", 133 | "execution_count": null, 134 | "metadata": {}, 135 | "outputs": [], 136 | "source": [ 137 | "from diffusers import StableDiffusionPipeline\n", 138 | "from daam import set_seed, trace\n", 139 | "\n", 140 | "pipe = StableDiffusionPipeline.from_pretrained('stabilityai/stable-diffusion-2-1-base')" 141 | ] 142 | }, 143 | { 144 | "cell_type": "code", 145 | "execution_count": 37, 146 | "metadata": {}, 147 | "outputs": [], 148 | "source": [ 149 | "pipe.to('cuda:0');" 150 | ] 151 | }, 152 | { 153 | "cell_type": "code", 154 | "execution_count": null, 155 | "metadata": { 156 | "scrolled": true 157 | }, 158 | "outputs": [], 159 | "source": [ 160 | "from tqdm import tqdm\n", 161 | "\n", 162 | "for _, row in tqdm(df.iterrows(), total=len(df)):\n", 163 | " image_id, caption = row.image_id, row.caption\n", 164 | " gen = set_seed(image_id)\n", 165 | " output_folder = Path('experiments/visuosyntax')\n", 166 | " \n", 167 | " with trace(pipe) as tc:\n", 168 | " out = pipe(caption, num_inference_steps=30, generator=gen)\n", 169 | " exp = tc.to_experiment(output_folder, id=str(image_id), seed=image_id)\n", 170 | " exp.save(output_folder, heat_maps=False)" 171 | ] 172 | }, 173 | { 174 | "cell_type": "markdown", 175 | "metadata": {}, 176 | "source": [ 177 | "# Parse and Analyze" 178 | ] 179 | }, 180 | { 181 | "cell_type": "code", 182 | "execution_count": null, 183 
| "metadata": {}, 184 | "outputs": [], 185 | "source": [ 186 | "from matplotlib import pyplot as plt\n", 187 | "from daam import GenerationExperiment\n", 188 | "\n", 189 | "def iou(a, b, t: float = 0.15) -> float:\n", 190 | " i = ((a > t) & (b > t)).float().sum()\n", 191 | " u = ((a > t) | (b > t)).float().sum()\n", 192 | " \n", 193 | " if u < 1e-6:\n", 194 | " return 0.0\n", 195 | " else:\n", 196 | " return (i / u).item()\n", 197 | "\n", 198 | "def ioa(a, b, t: float = 0.15) -> float:\n", 199 | " i = ((a > t) & (b > t)).float().sum()\n", 200 | " a = (a > t).float().sum()\n", 201 | " \n", 202 | " if a < 1e-6:\n", 203 | " return 0.0\n", 204 | " else:\n", 205 | " return (i / a).item()\n", 206 | "\n", 207 | "stats = []\n", 208 | "\n", 209 | "for path in tqdm(list(Path('experiments/visuosyntax').iterdir())):\n", 210 | " exp = GenerationExperiment.load(path)\n", 211 | " sent = client.annotate(exp.prompt).sentence[0]\n", 212 | " heat_map = exp.heat_map() \n", 213 | " word_maps = dict()\n", 214 | " \n", 215 | " for tok in sent.token:\n", 216 | " try:\n", 217 | " word_maps[tok.word] = heat_map.compute_word_heat_map(tok.word).value.cuda()\n", 218 | " except ValueError:\n", 219 | " pass \n", 220 | " \n", 221 | " for edge in sent.enhancedDependencies.edge:\n", 222 | " head = sent.token[edge.source - 1].word\n", 223 | " rel = edge.dep\n", 224 | " dep = sent.token[edge.target - 1].word\n", 225 | " \n", 226 | " try:\n", 227 | " head_heat_map = word_maps[head]\n", 228 | " dep_heat_map = word_maps[dep]\n", 229 | " except KeyError:\n", 230 | " continue\n", 231 | " \n", 232 | " stats.append(dict(\n", 233 | " rel=rel,\n", 234 | " iou=iou(head_heat_map, dep_heat_map),\n", 235 | " iod=ioa(dep_heat_map, head_heat_map),\n", 236 | " ioh=ioa(head_heat_map, dep_heat_map)\n", 237 | " ))" 238 | ] 239 | }, 240 | { 241 | "cell_type": "markdown", 242 | "metadata": {}, 243 | "source": [ 244 | "# Results" 245 | ] 246 | }, 247 | { 248 | "cell_type": "code", 249 | "execution_count": 149, 250 | "metadata": {}, 251 | "outputs": [], 252 | "source": [ 253 | "stats_df = pd.DataFrame(stats)\n", 254 | "res_df = stats_df.groupby('rel').agg(count=('rel', len), mIoU=('iou', 'mean'), mIoD=('iod', 'mean'), mIoH=('ioh', 'mean'))\n", 255 | "res_df = res_df.sort_values('count', ascending=False).iloc[:10]\n", 256 | "res_df['delta'] = (res_df['mIoH'] - res_df['mIoD']).abs()" 257 | ] 258 | }, 259 | { 260 | "cell_type": "code", 261 | "execution_count": 150, 262 | "metadata": {}, 263 | "outputs": [ 264 | { 265 | "data": { 266 | "text/html": [ 267 | "
\n", 268 | "\n", 281 | "\n", 282 | " \n", 283 | " \n", 284 | " \n", 285 | " \n", 286 | " \n", 287 | " \n", 288 | " \n", 289 | " \n", 290 | " \n", 291 | " \n", 292 | " \n", 293 | " \n", 294 | " \n", 295 | " \n", 296 | " \n", 297 | " \n", 298 | " \n", 299 | " \n", 300 | " \n", 301 | " \n", 302 | " \n", 303 | " \n", 304 | " \n", 305 | " \n", 306 | " \n", 307 | " \n", 308 | " \n", 309 | " \n", 310 | " \n", 311 | " \n", 312 | " \n", 313 | " \n", 314 | " \n", 315 | " \n", 316 | " \n", 317 | " \n", 318 | " \n", 319 | " \n", 320 | " \n", 321 | " \n", 322 | " \n", 323 | " \n", 324 | " \n", 325 | " \n", 326 | " \n", 327 | " \n", 328 | " \n", 329 | " \n", 330 | " \n", 331 | " \n", 332 | " \n", 333 | " \n", 334 | " \n", 335 | " \n", 336 | " \n", 337 | " \n", 338 | " \n", 339 | " \n", 340 | " \n", 341 | " \n", 342 | " \n", 343 | " \n", 344 | " \n", 345 | " \n", 346 | " \n", 347 | " \n", 348 | " \n", 349 | " \n", 350 | " \n", 351 | " \n", 352 | " \n", 353 | " \n", 354 | " \n", 355 | " \n", 356 | " \n", 357 | " \n", 358 | " \n", 359 | " \n", 360 | " \n", 361 | " \n", 362 | " \n", 363 | " \n", 364 | " \n", 365 | " \n", 366 | " \n", 367 | " \n", 368 | " \n", 369 | " \n", 370 | "
mIoUmIoDmIoHdelta
rel
punct0.0998572.4484100.1032952.345114
nmod:of8.65707412.85535821.9878569.132498
compound33.43411359.13079549.9851709.145626
nsubj5.02722710.69213322.71029312.018160
case3.83195218.0880065.89582912.192177
det0.44781113.0129750.65780812.355168
conj:and28.43592855.50186739.64988315.851984
acl6.45200928.69241511.10118417.591231
obj6.64195210.56667336.44249625.875823
amod14.69087845.06272019.05172026.011000
\n", 371 | "
" 372 | ], 373 | "text/plain": [ 374 | " mIoU mIoD mIoH delta\n", 375 | "rel \n", 376 | "punct 0.099857 2.448410 0.103295 2.345114\n", 377 | "nmod:of 8.657074 12.855358 21.987856 9.132498\n", 378 | "compound 33.434113 59.130795 49.985170 9.145626\n", 379 | "nsubj 5.027227 10.692133 22.710293 12.018160\n", 380 | "case 3.831952 18.088006 5.895829 12.192177\n", 381 | "det 0.447811 13.012975 0.657808 12.355168\n", 382 | "conj:and 28.435928 55.501867 39.649883 15.851984\n", 383 | "acl 6.452009 28.692415 11.101184 17.591231\n", 384 | "obj 6.641952 10.566673 36.442496 25.875823\n", 385 | "amod 14.690878 45.062720 19.051720 26.011000" 386 | ] 387 | }, 388 | "execution_count": 150, 389 | "metadata": {}, 390 | "output_type": "execute_result" 391 | } 392 | ], 393 | "source": [ 394 | "res_df.drop(columns=['count'], inplace=True)\n", 395 | "res_df = res_df.transform(lambda x: x * 100)\n", 396 | "res_df.sort_values('delta')" 397 | ] 398 | } 399 | ], 400 | "metadata": { 401 | "kernelspec": { 402 | "display_name": "Python 3 (ipykernel)", 403 | "language": "python", 404 | "name": "python3" 405 | }, 406 | "language_info": { 407 | "codemirror_mode": { 408 | "name": "ipython", 409 | "version": 3 410 | }, 411 | "file_extension": ".py", 412 | "mimetype": "text/x-python", 413 | "name": "python", 414 | "nbconvert_exporter": "python", 415 | "pygments_lexer": "ipython3", 416 | "version": "3.8.17" 417 | } 418 | }, 419 | "nbformat": 4, 420 | "nbformat_minor": 1 421 | } 422 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | scikit-image 2 | diffusers==0.21.2 3 | spacy 4 | gradio 5 | ftfy 6 | transformers==4.30.2 7 | pandas 8 | numba 9 | nltk 10 | inflect 11 | joblib 12 | accelerate==0.23.0 13 | -------------------------------------------------------------------------------- /scrollbar.css: -------------------------------------------------------------------------------- 1 | .output-html { 2 | overflow-x: auto; 3 | } 4 | 5 | .output-html::-webkit-scrollbar { 6 | -webkit-appearance: none; 7 | } 8 | 9 | .output-html::-webkit-scrollbar:vertical { 10 | width: 0px; 11 | } 12 | 13 | .output-html::-webkit-scrollbar:horizontal { 14 | height: 11px; 15 | } 16 | 17 | .output-html::-webkit-scrollbar-thumb { 18 | border-radius: 8px; 19 | border: 2px solid white; 20 | background-color: rgba(0, 0, 0, .5); 21 | } 22 | 23 | .output-html::-webkit-scrollbar-track { 24 | background-color: #fff; 25 | border-radius: 8px; 26 | } 27 | 28 | .spans { 29 | min-height: 75px; 30 | } 31 | 32 | svg { 33 | margin: auto; 34 | display: block; 35 | } 36 | 37 | #submit-btn { 38 | z-index: 999; 39 | } 40 | 41 | #viz { 42 | width: 100%; 43 | top: -30px; 44 | object-fit: scale-down; 45 | object-position: 0 100%; 46 | } 47 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import setuptools 2 | 3 | setuptools.setup( 4 | name='daam', 5 | version=eval(open('daam/_version.py').read().strip().split('=')[1]), 6 | author='Raphael Tang', 7 | license='MIT', 8 | url='https://github.com/castorini/daam', 9 | author_email='r33tang@uwaterloo.ca', 10 | description='What the DAAM: Interpreting Stable Diffusion Using Cross Attention.', 11 | install_requires=open('requirements.txt').read().strip().splitlines(), 12 | packages=setuptools.find_packages(), 13 | python_requires='>=3.8', 14 | entry_points={ 15 | 
'console_scripts': [ 16 | 'daam = daam.run.generate:main', 17 | 'daam-demo = daam.run.demo:main', 18 | ] 19 | } 20 | ) 21 | --------------------------------------------------------------------------------
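A minimal end-to-end sketch assembled from the notebook code above and the daam.trace entries in the module index; the model id and inference settings mirror notebooks/1-visuosyntactic-analyses.ipynb, while the prompt, seed, and queried word are illustrative placeholders.

    import torch
    from diffusers import StableDiffusionPipeline
    from daam import set_seed, trace

    pipe = StableDiffusionPipeline.from_pretrained('stabilityai/stable-diffusion-2-1-base')
    pipe.to('cuda:0')

    with torch.no_grad():
        with trace(pipe) as tc:
            gen = set_seed(0)  # arbitrary seed; the notebook uses set_seed(image_id)
            out = pipe('a dog running across a field', num_inference_steps=30, generator=gen)
            # Aggregate the cross-attention maps recorded during generation, then
            # pull out a single word's heat map, per the daam.trace index entries.
            global_map = tc.compute_global_heat_map()
            dog_map = global_map.compute_word_heat_map('dog')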