├── .github ├── auto_assign-issues.yml └── auto_assign.yml ├── LICENSE ├── MANIFEST.in ├── README.md ├── assets ├── corgi_eiffel_tower.png ├── corgi_eiffel_tower_box_1.png ├── corgi_eiffel_tower_step10.png ├── image_slider_cropped.gif ├── img2img_1.gif ├── inpaint_1.gif ├── pixel_attributions_1.png ├── pixel_attributions_inpaint_1.png └── token_attributions_1.png ├── notebooks ├── stable_diffusion_example_colab.ipynb ├── stable_diffusion_img2img_example.ipynb └── stable_diffusion_inpaint_example.ipynb ├── requirements.txt ├── setup.py └── src └── diffusers_interpret ├── __init__.py ├── attribution.py ├── data.py ├── dataviz └── image-slider │ ├── css │ └── index.css │ ├── index.html │ └── js │ └── index.js ├── explainer.py ├── explainers ├── __init__.py ├── latent_diffusion.py └── stable_diffusion.py ├── generated_images.py ├── pixel_attributions.py ├── saliency_map.py ├── token_attributions.py └── utils.py /.github/auto_assign-issues.yml: -------------------------------------------------------------------------------- 1 | # If enabled, auto-assigns users when a new issue is created 2 | # Defaults to true, allows you to install the app globally, and disable on a per-repo basis 3 | addAssignees: true 4 | 5 | # The list of users to assign to new issues. 6 | # If empty or not provided, the repository owner is assigned 7 | assignees: 8 | - JoaoLages 9 | -------------------------------------------------------------------------------- /.github/auto_assign.yml: -------------------------------------------------------------------------------- 1 | # Set to true to add reviewers to pull requests 2 | addReviewers: true 3 | 4 | # Set to true to add assignees to pull requests 5 | addAssignees: false 6 | 7 | # A list of reviewers to be added to pull requests (GitHub user name) 8 | reviewers: 9 | - JoaoLages 10 | 11 | 12 | # A list of keywords to be skipped the process that add reviewers if pull requests include it 13 | #skipKeywords: 14 | # - wip 15 | 16 | # A number of reviewers added to the pull request 17 | # Set 0 to add all the reviewers (default: 0) 18 | numberOfReviewers: 0 -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 João Lages 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
-------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include requirements.txt 2 | include LICENSE 3 | include README.md 4 | recursive-include src/diffusers_interpret/dataviz * 5 | 6 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
2 | 3 | # Diffusers-Interpret 🤗🧨🕵️‍♀️ 4 | 5 | ![PyPI Latest Package Version](https://img.shields.io/pypi/v/diffusers-interpret?logo=pypi&style=flat&color=orange) ![GitHub License](https://img.shields.io/github/license/JoaoLages/diffusers-interpret?logo=github&style=flat&color=green) 6 | 7 | `diffusers-interpret` is a model explainability tool built on top of [🤗 Diffusers](https://github.com/huggingface/diffusers) 8 |
9 | 10 | ## Installation 11 | 12 | Install directly from PyPI: 13 | 14 | pip install --upgrade diffusers-interpret 15 | 16 | ## Usage 17 | 18 | Let's see how we can interpret the **[new 🎨🎨🎨 Stable Diffusion](https://github.com/huggingface/diffusers#new--stable-diffusion-is-now-fully-compatible-with-diffusers)!** 19 | 20 | 1. [Explanations for StableDiffusionPipeline](#explanations-for-stablediffusionpipeline) 21 | 2. [Explanations for StableDiffusionImg2ImgPipeline](#explanations-for-stablediffusionimg2imgpipeline) 22 | 3. [Explanations for StableDiffusionInpaintPipeline](#explanations-for-stablediffusioninpaintpipeline) 23 | 24 | ### Explanations for StableDiffusionPipeline 25 | [![Open In Collab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/JoaoLages/diffusers-interpret/blob/main/notebooks/stable_diffusion_example_colab.ipynb) 26 | 27 | ```python 28 | import torch 29 | from diffusers import StableDiffusionPipeline 30 | from diffusers_interpret import StableDiffusionPipelineExplainer 31 | 32 | pipe = StableDiffusionPipeline.from_pretrained( 33 | "CompVis/stable-diffusion-v1-4", 34 | use_auth_token=True, 35 | revision='fp16', 36 | torch_dtype=torch.float16 37 | ).to('cuda') 38 | 39 | # optional: reduce memory requirement with a speed trade off 40 | pipe.enable_attention_slicing() 41 | 42 | # pass pipeline to the explainer class 43 | explainer = StableDiffusionPipelineExplainer(pipe) 44 | 45 | # generate an image with `explainer` 46 | prompt = "A cute corgi with the Eiffel Tower in the background" 47 | with torch.autocast('cuda'): 48 | output = explainer( 49 | prompt, 50 | num_inference_steps=15 51 | ) 52 | ``` 53 | 54 | If you are having GPU memory problems, try reducing `n_last_diffusion_steps_to_consider_for_attributions`, `height`, `width` and/or `num_inference_steps`. 55 | ``` 56 | output = explainer( 57 | prompt, 58 | num_inference_steps=15, 59 | height=448, 60 | width=448, 61 | n_last_diffusion_steps_to_consider_for_attributions=5 62 | ) 63 | ``` 64 | 65 | You can completely deactivate token/pixel attributions computation by passing `n_last_diffusion_steps_to_consider_for_attributions=0`. 
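For example, a minimal sketch (reusing the `pipe`, `explainer` and `prompt` defined above) of a memory-friendly call that skips the attribution computation entirely:

```python
# no token/pixel attributions are computed, only the image generation runs
with torch.autocast('cuda'):
    output = explainer(
        prompt,
        num_inference_steps=15,
        height=448,
        width=448,
        n_last_diffusion_steps_to_consider_for_attributions=0
    )
output.image  # the generated image is still available
```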
66 | 67 | Gradient checkpointing also reduces GPU usage, but makes computations a bit slower: 68 | ``` 69 | explainer = StableDiffusionPipelineExplainer(pipe, gradient_checkpointing=True) 70 | ``` 71 | 72 | To see the final generated image: 73 | ```python 74 | output.image 75 | ``` 76 | 77 | ![](assets/corgi_eiffel_tower.png) 78 | 79 | You can also check all the images that the diffusion process generated at the end of each step: 80 | ```python 81 | output.all_images_during_generation.show() 82 | ``` 83 | ![](assets/image_slider_cropped.gif) 84 | 85 | To analyse how a token in the input `prompt` influenced the generation, you can study the token attribution scores: 86 | ```python 87 | >>> output.token_attributions # (token, attribution) 88 | [('a', 1063.0526), 89 | ('cute', 415.62888), 90 | ('corgi', 6430.694), 91 | ('with', 1874.0208), 92 | ('the', 1223.2847), 93 | ('eiffel', 4756.4556), 94 | ('tower', 4490.699), 95 | ('in', 2463.1294), 96 | ('the', 655.4624), 97 | ('background', 3997.9395)] 98 | ``` 99 | 100 | Or their computed normalized version, in percentage: 101 | ```python 102 | >>> output.token_attributions.normalized # (token, attribution_percentage) 103 | [('a', 3.884), 104 | ('cute', 1.519), 105 | ('corgi', 23.495), 106 | ('with', 6.847), 107 | ('the', 4.469), 108 | ('eiffel', 17.378), 109 | ('tower', 16.407), 110 | ('in', 8.999), 111 | ('the', 2.395), 112 | ('background', 14.607)] 113 | ``` 114 | 115 | Or plot them! 116 | ```python 117 | output.token_attributions.plot(normalize=True) 118 | ``` 119 | ![](assets/token_attributions_1.png) 120 | 121 | 122 | `diffusers-interpret` also computes these token/pixel attributions for generating a particular part of the image. 123 | 124 | To do that, call `explainer` with a particular 2D bounding box defined in `explanation_2d_bounding_box`: 125 | 126 | ```python 127 | with torch.autocast('cuda'): 128 | output = explainer( 129 | prompt, 130 | num_inference_steps=15, 131 | explanation_2d_bounding_box=((70, 180), (400, 435)), # (upper left corner, bottom right corner) 132 | ) 133 | output.image 134 | ``` 135 | ![](assets/corgi_eiffel_tower_box_1.png) 136 | 137 | The generated image now has a **red bounding box** to indicate the region of the image that is being explained. 138 | 139 | The attributions are now computed only for the area specified in the image. 
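Under the hood, the attribution code simply restricts the generated image tensor to that bounding box before back-propagating (see `gradients_attribution` in `src/diffusers_interpret/attribution.py`). Conceptually, it does something like this hypothetical helper:

```python
import torch

def crop_to_bounding_box(pred_image: torch.Tensor, bounding_box) -> torch.Tensor:
    # pred_image has shape (height, width, channels);
    # bounding_box is (upper left corner, bottom right corner)
    upper_left, bottom_right = bounding_box
    return pred_image[upper_left[0]:bottom_right[0], upper_left[1]:bottom_right[1], :]
```

Only the cropped region contributes to the gradients that produce the attributions below.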
140 | 141 | ```python 142 | >>> output.token_attributions.normalized # (token, attribution_percentage) 143 | [('a', 1.891), 144 | ('cute', 1.344), 145 | ('corgi', 23.115), 146 | ('with', 11.995), 147 | ('the', 7.981), 148 | ('eiffel', 5.162), 149 | ('tower', 11.603), 150 | ('in', 11.99), 151 | ('the', 1.87), 152 | ('background', 23.05)] 153 | ``` 154 | 155 | ### Explanations for StableDiffusionImg2ImgPipeline 156 | [![Open In Collab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/JoaoLages/diffusers-interpret/blob/main/notebooks/stable_diffusion_img2img_example.ipynb) 157 | 158 | ```python 159 | import torch 160 | import requests 161 | from PIL import Image 162 | from io import BytesIO 163 | from diffusers import StableDiffusionImg2ImgPipeline 164 | from diffusers_interpret import StableDiffusionImg2ImgPipelineExplainer 165 | 166 | 167 | pipe = StableDiffusionImg2ImgPipeline.from_pretrained( 168 | "CompVis/stable-diffusion-v1-4", 169 | use_auth_token=True, 170 | ).to('cuda') 171 | 172 | explainer = StableDiffusionImg2ImgPipelineExplainer(pipe) 173 | 174 | prompt = "A fantasy landscape, trending on artstation" 175 | 176 | # let's download an initial image 177 | url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg" 178 | 179 | response = requests.get(url) 180 | init_image = Image.open(BytesIO(response.content)).convert("RGB") 181 | init_image = init_image.resize((448, 448)) 182 | 183 | with torch.autocast('cuda'): 184 | output = explainer( 185 | prompt=prompt, init_image=init_image, strength=0.75 186 | ) 187 | ``` 188 | 189 | `output` will have all the properties that were presented for [StableDiffusionPipeline](#explanations-for-stablediffusionpipeline). 190 | For example, to see the gif version of all the images during generation: 191 | ```python 192 | output.all_images_during_generation.gif() 193 | ``` 194 | ![](assets/img2img_1.gif) 195 | 196 | Additionally, it is also possible to visualize pixel attributions of the input image as a saliency map: 197 | ```python 198 | output.input_saliency_map.show() 199 | ``` 200 | ![](assets/pixel_attributions_1.png) 201 | 202 | or access their values directly: 203 | ```python 204 | >>> output.pixel_attributions 205 | array([[ 1.2714844 , 4.15625 , 7.8203125 , ..., 2.7753906 , 206 | 2.1308594 , 0.66552734], 207 | [ 5.5078125 , 11.1953125 , 4.8125 , ..., 5.6367188 , 208 | 6.8828125 , 3.0136719 ], 209 | ..., 210 | [ 0.21386719, 1.8867188 , 2.2109375 , ..., 3.0859375 , 211 | 2.7421875 , 0.7871094 ], 212 | [ 0.85791016, 0.6694336 , 1.71875 , ..., 3.8496094 , 213 | 1.4589844 , 0.5727539 ]], dtype=float32) 214 | ``` 215 | or the normalized version: 216 | ```python 217 | >>> output.pixel_attributions.normalized 218 | array([[7.16054201e-05, 2.34065039e-04, 4.40411852e-04, ..., 219 | 1.56300011e-04, 1.20002325e-04, 3.74801020e-05], 220 | [3.10180156e-04, 6.30479713e-04, 2.71022669e-04, ..., 221 | 3.17439699e-04, 3.87615233e-04, 1.69719147e-04], 222 | ..., 223 | [1.20442292e-05, 1.06253210e-04, 1.24512037e-04, ..., 224 | 1.73788882e-04, 1.54430119e-04, 4.43271674e-05], 225 | [4.83144104e-05, 3.77000870e-05, 9.67938031e-05, ..., 226 | 2.16796136e-04, 8.21647482e-05, 3.22554370e-05]], dtype=float32) 227 | ``` 228 | 229 | **Note:** Passing `explanation_2d_bounding_box` to the `explainer` will also change these values to explain a specific part of the **output** image. 
230 | The attributions are always calculated for the model's input (image and text) with respect to the output image. 231 | 232 | ### Explanations for StableDiffusionInpaintPipeline 233 | [![Open In Collab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/JoaoLages/diffusers-interpret/blob/main/notebooks/stable_diffusion_inpaint_example.ipynb) 234 | 235 | Same as [StableDiffusionImg2ImgPipeline](#explanations-for-stablediffusionimg2imgpipeline), but now we also pass a `mask_image` argument to `explainer`. 236 | 237 | ```python 238 | import torch 239 | import requests 240 | from PIL import Image 241 | from io import BytesIO 242 | from diffusers import StableDiffusionInpaintPipeline 243 | from diffusers_interpret import StableDiffusionInpaintPipelineExplainer 244 | 245 | 246 | def download_image(url): 247 | response = requests.get(url) 248 | return Image.open(BytesIO(response.content)).convert("RGB") 249 | 250 | 251 | pipe = StableDiffusionInpaintPipeline.from_pretrained( 252 | "CompVis/stable-diffusion-v1-4", 253 | use_auth_token=True, 254 | ).to('cuda') 255 | 256 | explainer = StableDiffusionInpaintPipelineExplainer(pipe) 257 | 258 | prompt = "a cat sitting on a bench" 259 | 260 | img_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png" 261 | mask_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png" 262 | 263 | init_image = download_image(img_url).resize((448, 448)) 264 | mask_image = download_image(mask_url).resize((448, 448)) 265 | 266 | with torch.autocast('cuda'): 267 | output = explainer( 268 | prompt=prompt, init_image=init_image, mask_image=mask_image, strength=0.75 269 | ) 270 | ``` 271 | 272 | `output` will have all the properties that were presented for [StableDiffusionImg2ImgPipeline](#explanations-for-stablediffusionimg2imgpipeline) and [StableDiffusionPipeline](#explanations-for-stablediffusionpipeline). 273 | For example, to see the gif version of all the images during generation: 274 | ```python 275 | output.all_images_during_generation.gif() 276 | ``` 277 | ![](assets/inpaint_1.gif) 278 | 279 | The only difference in `output` is that we can now see the masked part of the image: 280 | ```python 281 | output.input_saliency_map.show() 282 | ``` 283 | ![](assets/pixel_attributions_inpaint_1.png) 284 | 285 | Check other functionalities and more implementation examples [here](https://github.com/JoaoLages/diffusers-interpret/blob/main/notebooks/). 
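All explainers also accept an `attribution_method` argument that controls how gradients are aggregated into token/pixel attributions (see `AttributionMethods` and `AttributionAlgorithm` in `src/diffusers_interpret/data.py`). A sketch of how it could be used, continuing the inpainting example above and importing from the `data` module (these classes are not re-exported at the package root):

```python
from diffusers_interpret.data import AttributionMethods, AttributionAlgorithm

with torch.autocast('cuda'):
    output = explainer(
        prompt=prompt,
        init_image=init_image,
        mask_image=mask_image,
        strength=0.75,
        attribution_method=AttributionMethods(
            tokens_attribution_method=AttributionAlgorithm.GRAD_X_INPUT,  # default for tokens
            pixels_attribution_method=AttributionAlgorithm.MAX_GRAD  # default for pixels
        )
    )
```

Passing a single string instead, e.g. `attribution_method="max_grad"`, applies the same algorithm to both token and pixel attributions.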
286 | 287 | ## Future Development 288 | - [x] ~~Add interactive display of all the images that were generated in the diffusion process~~ 289 | - [x] ~~Add explainer for StableDiffusionImg2ImgPipeline~~ 290 | - [x] ~~Add explainer for StableDiffusionInpaintPipeline~~ 291 | - [ ] Add attentions visualization 292 | - [ ] Add unit tests 293 | - [ ] Website for documentation 294 | - [ ] Do not require another generation every time the `explanation_2d_bounding_box` argument is changed 295 | - [ ] Add interactive bounding-box and token attributions visualization 296 | - [ ] Add more explainability methods 297 | 298 | ## Contributing 299 | Feel free to open an [Issue](https://github.com/JoaoLages/diffusers-interpret/issues) or create a [Pull Request](https://github.com/JoaoLages/diffusers-interpret/pulls) and let's get started 🚀 300 | 301 | ## Credits 302 | 303 | A special thanks to: 304 | - [@andrewizbatista](https://github.com/andrewizbatista) for creating a great [image slider](https://github.com/JoaoLages/diffusers-interpret/pull/1) to show all the generated images during diffusion! 💪 305 | - [@TomPham97](https://github.com/TomPham97) for README improvements, the [GIF visualization](https://github.com/JoaoLages/diffusers-interpret/pull/9) and the [token attributions plot](https://github.com/JoaoLages/diffusers-interpret/pull/13) 😁 306 | -------------------------------------------------------------------------------- /assets/corgi_eiffel_tower.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JoaoLages/diffusers-interpret/2dd01d6a494dd0bc18b5c2fa99dda7132ae03f42/assets/corgi_eiffel_tower.png -------------------------------------------------------------------------------- /assets/corgi_eiffel_tower_box_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JoaoLages/diffusers-interpret/2dd01d6a494dd0bc18b5c2fa99dda7132ae03f42/assets/corgi_eiffel_tower_box_1.png -------------------------------------------------------------------------------- /assets/corgi_eiffel_tower_step10.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JoaoLages/diffusers-interpret/2dd01d6a494dd0bc18b5c2fa99dda7132ae03f42/assets/corgi_eiffel_tower_step10.png -------------------------------------------------------------------------------- /assets/image_slider_cropped.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JoaoLages/diffusers-interpret/2dd01d6a494dd0bc18b5c2fa99dda7132ae03f42/assets/image_slider_cropped.gif -------------------------------------------------------------------------------- /assets/img2img_1.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JoaoLages/diffusers-interpret/2dd01d6a494dd0bc18b5c2fa99dda7132ae03f42/assets/img2img_1.gif -------------------------------------------------------------------------------- /assets/inpaint_1.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JoaoLages/diffusers-interpret/2dd01d6a494dd0bc18b5c2fa99dda7132ae03f42/assets/inpaint_1.gif -------------------------------------------------------------------------------- /assets/pixel_attributions_1.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/JoaoLages/diffusers-interpret/2dd01d6a494dd0bc18b5c2fa99dda7132ae03f42/assets/pixel_attributions_1.png -------------------------------------------------------------------------------- /assets/pixel_attributions_inpaint_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JoaoLages/diffusers-interpret/2dd01d6a494dd0bc18b5c2fa99dda7132ae03f42/assets/pixel_attributions_inpaint_1.png -------------------------------------------------------------------------------- /assets/token_attributions_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JoaoLages/diffusers-interpret/2dd01d6a494dd0bc18b5c2fa99dda7132ae03f42/assets/token_attributions_1.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | transformers>=4.21.1 2 | setuptools>=49.6.0 3 | torch>=1.9.1 4 | diffusers~=0.3.0 5 | scipy>=1.7.3 6 | ftfy>=6.1.1 7 | cmapy>=0.6.6 8 | matplotlib>=3.5.3 9 | opencv-python>=4.6.0 -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | 4 | with open('README.md', encoding='utf-8') as f: 5 | long_description = f.read() 6 | 7 | with open('requirements.txt', encoding='utf-8') as f: 8 | required = f.read().splitlines() 9 | 10 | setup( 11 | name='diffusers-interpret', 12 | version='0.5.0', 13 | description='diffusers-interpret: model explainability for 🤗 Diffusers', 14 | long_description=long_description, 15 | long_description_content_type='text/markdown', 16 | url='https://github.com/JoaoLages/diffusers-interpret', 17 | author='Joao Lages', 18 | author_email='joaop.glages@gmail.com', 19 | license='MIT', 20 | packages=find_packages('src'), 21 | package_dir={'': 'src'}, 22 | include_package_data=True, 23 | install_requires=required 24 | ) -------------------------------------------------------------------------------- /src/diffusers_interpret/__init__.py: -------------------------------------------------------------------------------- 1 | from .explainer import BasePipelineExplainer, BasePipelineImg2ImgExplainer 2 | from .explainers.latent_diffusion import LDMTextToImagePipelineExplainer 3 | from .explainers.stable_diffusion import StableDiffusionPipelineExplainer, StableDiffusionImg2ImgPipelineExplainer, \ 4 | StableDiffusionInpaintPipelineExplainer 5 | from .data import PipelineExplainerOutput, PipelineImg2ImgExplainerOutput -------------------------------------------------------------------------------- /src/diffusers_interpret/attribution.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple, Optional, List 2 | 3 | import torch 4 | 5 | from diffusers_interpret.data import AttributionAlgorithm 6 | 7 | 8 | def gradients_attribution( 9 | pred_logits: torch.Tensor, 10 | input_embeds: Tuple[torch.Tensor], 11 | attribution_algorithms: List[AttributionAlgorithm], 12 | explanation_2d_bounding_box: Optional[Tuple[Tuple[int, int], Tuple[int, int]]] = None, 13 | retain_graph: bool = False 14 | ) -> List[torch.Tensor]: 15 | # TODO: add description 16 | 17 | assert len(pred_logits.shape) == 3 18 | if explanation_2d_bounding_box: 19 | upper_left, bottom_right = explanation_2d_bounding_box 20 | 
pred_logits = pred_logits[upper_left[0]: bottom_right[0], upper_left[1]: bottom_right[1], :] 21 | 22 | assert len(input_embeds) == len(attribution_algorithms) 23 | 24 | # Construct tuple of scalar tensors with all `pred_logits` 25 | # The code below is equivalent to `tuple_of_pred_logits = tuple(torch.flatten(pred_logits))`, 26 | # but for some reason the gradient calculation is way faster if the tensor is flattened like this 27 | tuple_of_pred_logits = [] 28 | for x in pred_logits: 29 | for y in x: 30 | for z in y: 31 | tuple_of_pred_logits.append(z) 32 | tuple_of_pred_logits = tuple(tuple_of_pred_logits) 33 | 34 | # get the sum of back-prop gradients for all predictions with respect to the inputs 35 | if torch.is_autocast_enabled(): 36 | # FP16 may cause NaN gradients https://github.com/pytorch/pytorch/issues/40497 37 | # TODO: this is still an issue, the code below does not solve it 38 | with torch.autocast(input_embeds[0].device.type, enabled=False): 39 | grads = torch.autograd.grad(tuple_of_pred_logits, input_embeds, retain_graph=retain_graph) 40 | else: 41 | grads = torch.autograd.grad(tuple_of_pred_logits, input_embeds, retain_graph=retain_graph) 42 | 43 | if torch.isnan(grads[-1]).any(): 44 | raise RuntimeError( 45 | "Found NaNs while calculating gradients. " 46 | "This is a known issue of FP16 (https://github.com/pytorch/pytorch/issues/40497).\n" 47 | "Try to rerun the code or deactivate FP16 to not face this issue again." 48 | ) 49 | 50 | # Aggregate gradients over the last (embedding/channel) dimension 51 | aggregated_grads = [] 52 | for grad, inp, attr_alg in zip(grads, input_embeds, attribution_algorithms): 53 | 54 | if attr_alg == AttributionAlgorithm.GRAD_X_INPUT: 55 | aggregated_grads.append(torch.norm(grad * inp, dim=-1)) 56 | elif attr_alg == AttributionAlgorithm.MAX_GRAD: 57 | aggregated_grads.append(grad.abs().max(-1).values) 58 | elif attr_alg == AttributionAlgorithm.MEAN_GRAD: 59 | aggregated_grads.append(grad.abs().mean(-1))  # unlike max/min, mean returns a plain tensor 60 | elif attr_alg == AttributionAlgorithm.MIN_GRAD: 61 | aggregated_grads.append(grad.abs().min(-1).values) 62 | else: 63 | raise NotImplementedError(f"aggregation type `{attr_alg}` not implemented") 64 | 65 | return aggregated_grads -------------------------------------------------------------------------------- /src/diffusers_interpret/data.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | from dataclasses import dataclass 3 | from enum import Enum 4 | from typing import Union, List, Optional, Tuple, Any 5 | 6 | import numpy as np 7 | import torch 8 | from PIL.Image import Image 9 | 10 | from diffusers_interpret.generated_images import GeneratedImages 11 | from diffusers_interpret.pixel_attributions import PixelAttributions 12 | from diffusers_interpret.saliency_map import SaliencyMap 13 | from diffusers_interpret.token_attributions import TokenAttributions 14 | 15 | 16 | @dataclass 17 | class BaseMimicPipelineCallOutput: 18 | """ 19 | Output class for BasePipelineExplainer._mimic_pipeline_call 20 | 21 | Args: 22 | images (`List[Image]` or `torch.Tensor`) 23 | List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width, 24 | num_channels)`. PIL images or numpy array represent the denoised images of the diffusion pipeline. 25 | nsfw_content_detected (`Optional[List[bool]]`) 26 | List of flags denoting whether the corresponding generated image likely represents "not-safe-for-work" 27 | (nsfw) content. 
28 | all_images_during_generation (`Optional[Union[List[List[Image]]], List[torch.Tensor]]`) 29 | A list with all the batch images generated during diffusion 30 | """ 31 | images: Union[List[Image], torch.Tensor] 32 | nsfw_content_detected: Optional[List[bool]] = None 33 | all_images_during_generation: Optional[Union[List[List[Image]], List[torch.Tensor]]] = None 34 | 35 | def __getitem__(self, item): 36 | return getattr(self, item) 37 | 38 | def __setitem__(self, key, value): 39 | setattr(self, key, value) 40 | 41 | 42 | @dataclass 43 | class PipelineExplainerOutput: 44 | """ 45 | Output class for BasePipelineExplainer.__call__ if `init_image=None` and `explanation_2d_bounding_box=None` 46 | 47 | Args: 48 | image (`Image` or `torch.Tensor`) 49 | The denoised PIL output image or torch.Tensor of shape `(height, width, num_channels)`. 50 | nsfw_content_detected (`Optional[bool]`) 51 | A flag denoting whether the generated image likely represents "not-safe-for-work" 52 | (nsfw) content. 53 | all_images_during_generation (`Optional[Union[GeneratedImages, List[torch.Tensor]]]`) 54 | A GeneratedImages object to visualize all the generated images during diffusion OR a list of tensors of those images 55 | token_attributions (`Optional[TokenAttributions]`) 56 | TokenAttributions that contains a list of tuples with (token, token_attribution) 57 | """ 58 | image: Union[Image, torch.Tensor] 59 | nsfw_content_detected: Optional[bool] = None 60 | all_images_during_generation: Optional[Union[GeneratedImages, List[torch.Tensor]]] = None 61 | token_attributions: Optional[TokenAttributions] = None 62 | 63 | def __getitem__(self, item): 64 | return getattr(self, item) 65 | 66 | def __setitem__(self, key, value): 67 | setattr(self, key, value) 68 | 69 | def __getattr__(self, attr): 70 | if attr == 'normalized_token_attributions': 71 | warnings.warn( 72 | f"`normalized_token_attributions` is deprecated as an attribute of `{self.__class__.__name__}` " 73 | f"and will be removed in a future version. Consider using `output.token_attributions.normalized` instead", 74 | DeprecationWarning, stacklevel=2 75 | ) 76 | return self.token_attributions.normalized 77 | raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{attr}'") 78 | 79 | 80 | @dataclass 81 | class PipelineExplainerForBoundingBoxOutput(PipelineExplainerOutput): 82 | """ 83 | Output class for BasePipelineExplainer.__call__ if `init_image=None` and `explanation_2d_bounding_box is not None` 84 | 85 | Args: 86 | image (`Image` or `torch.Tensor`) 87 | The denoised PIL output image or torch.Tensor of shape `(height, width, num_channels)`. 88 | nsfw_content_detected (`Optional[bool]`) 89 | A flag denoting whether the generated image likely represents "not-safe-for-work" 90 | (nsfw) content. 91 | all_images_during_generation (`Optional[Union[GeneratedImages, List[torch.Tensor]]]`) 92 | A GeneratedImages object to visualize all the generated images during diffusion OR a list of tensors of those images 93 | token_attributions (`Optional[TokenAttributions]`) 94 | TokenAttributions that contains a list of tuples with (token, token_attribution) 95 | explanation_2d_bounding_box: (`Tuple[Tuple[int, int], Tuple[int, int]]`) 96 | Tuple with the bounding box coordinates where the attributions were calculated for. 97 | The tuple is like (upper left corner, bottom right corner). 
Example: `((0, 0), (300, 300))` 98 | """ 99 | explanation_2d_bounding_box: Tuple[Tuple[int, int], Tuple[int, int]] = None # (upper left corner, bottom right corner) 100 | 101 | 102 | @dataclass 103 | class PipelineImg2ImgExplainerOutput(PipelineExplainerOutput): 104 | """ 105 | Output class for BasePipelineExplainer.__call__ if `init_image is not None` and `explanation_2d_bounding_box=None` 106 | 107 | Args: 108 | image (`Image` or `torch.Tensor`) 109 | The denoised PIL output image or torch.Tensor of shape `(height, width, num_channels)`. 110 | nsfw_content_detected (`Optional[bool]`) 111 | A flag denoting whether the generated image likely represents "not-safe-for-work" 112 | (nsfw) content. 113 | all_images_during_generation (`Optional[Union[GeneratedImages, List[torch.Tensor]]]`) 114 | A GeneratedImages object to visualize all the generated images during diffusion OR a list of tensors of those images 115 | token_attributions (`Optional[TokenAttributions]`) 116 | TokenAttributions that contains a list of tuples with (token, token_attribution) 117 | pixel_attributions (`Optional[PixelAttributions]`) 118 | PixelAttributions that is a numpy array of shape `(height, width)` with an attribution score per pixel in the input image 119 | input_saliency_map (`Optional[SaliencyMap]`) 120 | A SaliencyMap object to visualize the pixel attributions of the input image 121 | """ 122 | pixel_attributions: Optional[PixelAttributions] = None 123 | 124 | def __getattr__(self, attr): 125 | if attr == 'normalized_pixel_attributions': 126 | warnings.warn( 127 | f"`normalized_pixel_attributions` is deprecated as an attribute of `{self.__class__.__name__}` " 128 | f"and will be removed in a future version. Consider using `output.pixel_attributions.normalized` instead", 129 | DeprecationWarning, stacklevel=2 130 | ) 131 | return self.pixel_attributions.normalized 132 | elif attr == 'input_saliency_map': 133 | return self.pixel_attributions.saliency_map 134 | return super().__getattr__(attr) 135 | 136 | 137 | @dataclass 138 | class PipelineImg2ImgExplainerForBoundingBoxOutputOutput(PipelineExplainerForBoundingBoxOutput, PipelineImg2ImgExplainerOutput): 139 | """ 140 | Output class for BasePipelineExplainer.__call__ if `init_image is not None` and `explanation_2d_bounding_box is not None` 141 | 142 | Args: 143 | image (`Image` or `torch.Tensor`) 144 | The denoised PIL output image or torch.Tensor of shape `(height, width, num_channels)`. 145 | nsfw_content_detected (`Optional[bool]`) 146 | A flag denoting whether the generated image likely represents "not-safe-for-work" 147 | (nsfw) content. 148 | all_images_during_generation (`Optional[Union[GeneratedImages, List[torch.Tensor]]]`) 149 | A GeneratedImages object to visualize all the generated images during diffusion OR a list of tensors of those images 150 | token_attributions (`Optional[TokenAttributions]`) 151 | TokenAttributions that contains a list of tuples with (token, token_attribution) 152 | pixel_attributions (`Optional[PixelAttributions]`) 153 | PixelAttributions that is a numpy array of shape `(height, width)` with an attribution score per pixel in the input image 154 | input_saliency_map (`Optional[SaliencyMap]`) 155 | A SaliencyMap object to visualize the pixel attributions of the input image 156 | explanation_2d_bounding_box: (`Tuple[Tuple[int, int], Tuple[int, int]]`) 157 | Tuple with the bounding box coordinates where the attributions were calculated for. 158 | The tuple is like (upper left corner, bottom right corner). 
Example: `((0, 0), (300, 300))` 159 | """ 160 | pass 161 | 162 | 163 | class ExplicitEnum(str, Enum): 164 | """ 165 | Enum with more explicit error message for missing values. 166 | """ 167 | 168 | @classmethod 169 | def _missing_(cls, value): 170 | raise ValueError( 171 | f"{value} is not a valid {cls.__name__}, please select one of {list(cls._value2member_map_.keys())}" 172 | ) 173 | 174 | 175 | class AttributionAlgorithm(ExplicitEnum): 176 | """ 177 | Possible values for `tokens_attribution_method` and `pixels_attribution_method` arguments in `AttributionMethods` 178 | """ 179 | GRAD_X_INPUT = "grad_x_input" 180 | MAX_GRAD = "max_grad" 181 | MEAN_GRAD = "mean_grad" 182 | MIN_GRAD = "min_grad" 183 | 184 | 185 | @dataclass 186 | class AttributionMethods: 187 | tokens_attribution_method: Union[str, AttributionAlgorithm] = AttributionAlgorithm.GRAD_X_INPUT 188 | pixels_attribution_method: Optional[Union[str, AttributionAlgorithm]] = AttributionAlgorithm.MAX_GRAD -------------------------------------------------------------------------------- /src/diffusers_interpret/dataviz/image-slider/css/index.css: -------------------------------------------------------------------------------- 1 | :root { 2 | --animation-time: 100ms; 3 | --image-size: 296px; 4 | --loading-margin: 340px; 5 | --error-margin: 390px; 6 | --border-radius: 4px; 7 | --color-primary: #ff6347; 8 | --color-primary-hover: #e46b55; 9 | --color-primary-active: #9acd32; 10 | --color-primary-disabled: #aa8983; 11 | --color-loading: #aa8983; 12 | } 13 | 14 | html { 15 | font-family: system-ui, -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, 16 | Oxygen, Ubuntu, Cantarell, "Open Sans", "Helvetica Neue", sans-serif; 17 | } 18 | 19 | #slider { 20 | display: none; 21 | flex-direction: row; 22 | justify-content: flex-start; 23 | align-items: flex-start; 24 | user-select: none; 25 | } 26 | 27 | #error { 28 | display: none; 29 | flex-direction: row; 30 | justify-content: flex-start; 31 | align-items: center; 32 | user-select: none; 33 | height: 200px; 34 | padding-left: var(--error-margin); 35 | } 36 | 37 | #error span { 38 | font-size: 1.2rem; 39 | color: var(--color-primary); 40 | font-weight: bold; 41 | letter-spacing: 0.5px; 42 | } 43 | 44 | #loading { 45 | display: flex; 46 | flex-direction: row; 47 | justify-content: flex-start; 48 | align-items: center; 49 | user-select: none; 50 | height: 200px; 51 | opacity: 0.5; 52 | padding-left: var(--loading-margin); 53 | -webkit-animation: pulsate 1s ease-out; 54 | -moz-animation: pulsate 1s ease-out; 55 | -ms-animation: pulsate 1s ease-out; 56 | -o-animation: pulsate 1s ease-out; 57 | animation: pulsate 1s ease-out; 58 | -webkit-animation-iteration-count: infinite; 59 | -moz-animation-iteration-count: infinite; 60 | -ms-animation-iteration-count: infinite; 61 | -o-animation-iteration-count: infinite; 62 | animation-iteration-count: infinite; 63 | } 64 | 65 | #loading span { 66 | font-size: 1rem; 67 | margin-left: 0.5rem; 68 | color: var(--color-loading); 69 | letter-spacing: 1px; 70 | } 71 | 72 | .slider-item { 73 | padding: 5px; 74 | text-align: center; 75 | } 76 | 77 | .slider-image { 78 | display: block; 79 | width: var(--image-size); 80 | height: var(--image-size); 81 | margin-bottom: 10px; 82 | background-color: #222; 83 | background-position: center; 84 | background-repeat: no-repeat; 85 | background-size: contain; 86 | -webkit-transition: all var(--animation-time) linear; 87 | -moz-transition: all var(--animation-time) linear; 88 | -ms-transition: all var(--animation-time) linear; 
89 | -o-transition: all var(--animation-time) linear; 90 | transition: all var(--animation-time) linear; 91 | } 92 | 93 | .slider-title { 94 | display: block; 95 | font-size: 1rem; 96 | font-weight: bold; 97 | user-select: none; 98 | } 99 | 100 | .slider-iteration { 101 | display: block; 102 | font-size: 0.8rem; 103 | margin-top: 6px; 104 | } 105 | 106 | .slide-actions { 107 | display: flex; 108 | flex-direction: row; 109 | justify-content: space-between; 110 | align-items: center; 111 | background-color: var(--color-primary); 112 | border-radius: var(--border-radius); 113 | color: #fff; 114 | } 115 | 116 | button { 117 | color: #fff; 118 | background-color: var(--color-primary); 119 | border: 0; 120 | font-size: 0.8rem; 121 | padding-top: 14px; 122 | padding-bottom: 14px; 123 | cursor: pointer; 124 | user-select: none; 125 | -webkit-transition: all var(--animation-time) linear; 126 | -moz-transition: all var(--animation-time) linear; 127 | -ms-transition: all var(--animation-time) linear; 128 | -o-transition: all var(--animation-time) linear; 129 | transition: all var(--animation-time) linear; 130 | } 131 | 132 | button:disabled { 133 | cursor: default; 134 | background-color: var(--color-primary-disabled) !important; 135 | } 136 | 137 | button:hover { 138 | background-color: var(--color-primary-hover); 139 | } 140 | 141 | button:active { 142 | background-color: var(--color-primary-active); 143 | } 144 | 145 | .slider-iterations { 146 | font-size: 0.8rem; 147 | font-family: monospace; 148 | user-select: none; 149 | cursor: default; 150 | } 151 | 152 | #slider-action-prev { 153 | border-top-left-radius: var(--border-radius); 154 | border-bottom-left-radius: var(--border-radius); 155 | padding-left: 12px; 156 | padding-right: 18px; 157 | } 158 | 159 | #slider-action-next { 160 | border-top-right-radius: var(--border-radius); 161 | border-bottom-right-radius: var(--border-radius); 162 | padding-left: 18px; 163 | padding-right: 12px; 164 | } 165 | 166 | #loading-ripple { 167 | display: inline-block; 168 | position: relative; 169 | width: 80px; 170 | height: 80px; 171 | } 172 | 173 | #loading-ripple div { 174 | position: absolute; 175 | border: 4px solid var(--color-loading); 176 | opacity: 1; 177 | border-radius: 50%; 178 | animation: ripple 1s cubic-bezier(0, 0.2, 0.8, 1) infinite; 179 | } 180 | 181 | #loading-ripple div:nth-child(2) { 182 | animation-delay: -0.5s; 183 | } 184 | 185 | @keyframes ripple { 186 | 0% { 187 | top: 36px; 188 | left: 36px; 189 | width: 0; 190 | height: 0; 191 | opacity: 0; 192 | } 193 | 4.9% { 194 | top: 36px; 195 | left: 36px; 196 | width: 0; 197 | height: 0; 198 | opacity: 0; 199 | } 200 | 5% { 201 | top: 36px; 202 | left: 36px; 203 | width: 0; 204 | height: 0; 205 | opacity: 1; 206 | } 207 | 100% { 208 | top: 0px; 209 | left: 0px; 210 | width: 72px; 211 | height: 72px; 212 | opacity: 0; 213 | } 214 | } 215 | 216 | @keyframes pulsate { 217 | 0% { 218 | opacity: 0.5; 219 | } 220 | 50% { 221 | opacity: 1; 222 | } 223 | 100% { 224 | opacity: 0.5; 225 | } 226 | } 227 | -------------------------------------------------------------------------------- /src/diffusers_interpret/dataviz/image-slider/index.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | Image Slider 4 | 5 | 6 | 7 | 8 |
9 |
10 |
11 |
12 |
13 | Preparing Data... 14 |
15 |
16 | Unexpected Error 17 |
18 |
19 |
20 |
21 | First Iteration 22 |
23 |
24 |
25 |
26 | 27 | 28 | 0 29 | / 30 | 0 31 | 32 | 33 |
34 |
35 |
36 |
37 | Final Iteration 38 |
39 |
40 | 41 | 42 | 43 | 44 | -------------------------------------------------------------------------------- /src/diffusers_interpret/dataviz/image-slider/js/index.js: -------------------------------------------------------------------------------- 1 | // @ts-check 2 | 3 | ((d) => { 4 | /** 5 | * Constants 6 | */ 7 | const ID_SLIDER = "slider"; 8 | const ID_LOADING = "loading"; 9 | const ID_ERROR = "error"; 10 | // 11 | const ID_BUTTON_PREV = "slider-action-prev"; 12 | const ID_BUTTON_NEXT = "slider-action-next"; 13 | // 14 | const ID_IMAGE_FIRST = "slider-image-first"; 15 | const ID_IMAGE_CURRENT = "slider-image-current"; 16 | const ID_IMAGE_FINAL = "slider-image-final"; 17 | // 18 | const ID_ITERATIONS_CURRENT = "slider-iterations-current"; 19 | const ID_ITERATIONS_FINAL = "slider-iterations-final"; 20 | 21 | /** 22 | * @type {{image: string}[]} 23 | */ 24 | let imageList = []; 25 | 26 | /** 27 | * @type {number} 28 | */ 29 | let currentIndex = 0; 30 | 31 | /** 32 | * Initialize the Image Slider 33 | * 34 | * @param {string} jsonPayload 35 | */ 36 | function initialize(jsonPayload) { 37 | const isOK = parseJSONPayload(jsonPayload); 38 | 39 | if (isOK) { 40 | handleSuccessState(); 41 | } else { 42 | handleErrorState(); 43 | } 44 | } 45 | 46 | /** 47 | * Parse the JSON payload 48 | * 49 | * @param {string} jsonPayload 50 | * 51 | * @return {boolean} 52 | */ 53 | function parseJSONPayload(jsonPayload) { 54 | if (Array.isArray(jsonPayload)) { 55 | imageList = jsonPayload; 56 | 57 | return true; 58 | } 59 | 60 | return false; 61 | } 62 | 63 | /** 64 | * Render the application success state 65 | */ 66 | function handleSuccessState() { 67 | /** 68 | * Update First Image 69 | */ 70 | const imageFirst = imageList[0] ?? {}; 71 | updateImageAttributes(ID_IMAGE_FIRST, { 72 | backgroundImage: imageFirst?.image ?? null, 73 | }); 74 | 75 | /** 76 | * Update Current Image 77 | */ 78 | const imageCurrent = imageList[currentIndex] ?? {}; 79 | updateImageAttributes(ID_IMAGE_CURRENT, { 80 | backgroundImage: imageCurrent?.image ?? null, 81 | }); 82 | 83 | /** 84 | * Update Final Image 85 | */ 86 | const imageFinal = imageList[imageList.length - 1] ?? {}; 87 | updateImageAttributes(ID_IMAGE_FINAL, { 88 | backgroundImage: imageFinal?.image ?? 
null, 89 | }); 90 | 91 | /** 92 | * Update the iteration values 93 | */ 94 | updateIterationValues(); 95 | 96 | /** 97 | * Set Prev Button initial state 98 | */ 99 | updateButtonAttributes(ID_BUTTON_PREV, { 100 | disabled: currentIndex === 0, 101 | }); 102 | 103 | /** 104 | * Set Next Button initial state 105 | */ 106 | const imageLen = imageList.length; 107 | updateButtonAttributes(ID_BUTTON_NEXT, { 108 | disabled: currentIndex === imageLen - 1, 109 | }); 110 | 111 | /** 112 | * Initialize a `click` event in the Prev Button 113 | */ 114 | const $actionPrev = d.getElementById(ID_BUTTON_PREV); 115 | if ($actionPrev) $actionPrev.addEventListener("click", prevImageAction); 116 | 117 | /** 118 | * Initialize a `click` event in the Next Button 119 | */ 120 | const $actionNext = d.getElementById(ID_BUTTON_NEXT); 121 | if ($actionNext) $actionNext.addEventListener("click", nextImageAction); 122 | 123 | hideElement(ID_LOADING); 124 | showElement(ID_SLIDER); 125 | } 126 | 127 | /** 128 | * Render the application error state 129 | */ 130 | function handleErrorState() { 131 | hideElement(ID_LOADING); 132 | showElement(ID_ERROR); 133 | } 134 | 135 | /** 136 | * Click the `Prev` image button 137 | */ 138 | function prevImageAction() { 139 | const canGoPrev = currentIndex > 0; 140 | 141 | if (canGoPrev) { 142 | currentIndex--; 143 | 144 | const backgroundImage = imageList[currentIndex]?.image ?? null; 145 | 146 | updateImageAttributes(ID_IMAGE_CURRENT, { backgroundImage }); 147 | updateButtonAttributes(ID_BUTTON_NEXT, { disabled: false }); 148 | updateIterationValues(); 149 | 150 | const disablePrev = currentIndex === 0; 151 | 152 | if (disablePrev) { 153 | updateButtonAttributes(ID_BUTTON_PREV, { disabled: true }); 154 | } else { 155 | updateButtonAttributes(ID_BUTTON_NEXT, { disabled: false }); 156 | } 157 | } 158 | } 159 | 160 | /** 161 | * Click the `Next` image button 162 | */ 163 | function nextImageAction() { 164 | const imageLen = imageList.length; 165 | const canGoNext = currentIndex < imageLen - 1; 166 | 167 | if (canGoNext) { 168 | currentIndex++; 169 | 170 | const backgroundImage = imageList[currentIndex]?.image ?? null; 171 | 172 | updateImageAttributes(ID_IMAGE_CURRENT, { backgroundImage }); 173 | updateButtonAttributes(ID_BUTTON_PREV, { disabled: false }); 174 | updateIterationValues(); 175 | 176 | const disableNext = currentIndex === imageLen - 1; 177 | 178 | if (disableNext) { 179 | updateButtonAttributes(ID_BUTTON_NEXT, { disabled: true }); 180 | } else { 181 | updateButtonAttributes(ID_BUTTON_PREV, { disabled: false }); 182 | } 183 | } 184 | } 185 | 186 | /** 187 | * Update the iteration values 188 | */ 189 | function updateIterationValues() { 190 | const $iterationCurrent = d.getElementById(ID_ITERATIONS_CURRENT); 191 | const $iterationFinal = d.getElementById(ID_ITERATIONS_FINAL); 192 | 193 | if ($iterationCurrent) { 194 | $iterationCurrent.innerText = `${currentIndex + 1}`; 195 | } 196 | 197 | if ($iterationFinal) { 198 | const len = imageList?.length ?? 
0; 199 | 200 | $iterationFinal.innerText = `${len}`; 201 | } 202 | } 203 | 204 | /** 205 | * Show an element 206 | * 207 | * @param {ID_SLIDER | ID_LOADING | ID_ERROR} id 208 | */ 209 | function showElement(id) { 210 | const $element = d.getElementById(id); 211 | 212 | if ($element) $element.style.display = "flex"; 213 | } 214 | 215 | /** 216 | * Hide an element 217 | * 218 | * @param {ID_SLIDER | ID_LOADING | ID_ERROR} id 219 | */ 220 | function hideElement(id) { 221 | const $element = d.getElementById(id); 222 | 223 | if ($element) $element.style.display = "none"; 224 | } 225 | 226 | /** 227 | * Update the Image attributes 228 | * 229 | * @param {ID_IMAGE_FIRST | ID_IMAGE_CURRENT | ID_IMAGE_FINAL} id 230 | * @param {{ backgroundImage: string | null}} options 231 | */ 232 | function updateImageAttributes(id, options) { 233 | const { backgroundImage } = options ?? {}; 234 | 235 | if (id) { 236 | const $img = d.getElementById(id); 237 | 238 | if ($img && backgroundImage) { 239 | $img.style.backgroundImage = `url("${backgroundImage}")`; 240 | } 241 | } 242 | } 243 | 244 | /** 245 | * Update the Prev/Next Button attributes 246 | * 247 | * @param {ID_BUTTON_PREV | ID_BUTTON_NEXT} id 248 | * @param {{disabled: boolean}} options 249 | */ 250 | function updateButtonAttributes(id, options) { 251 | const { disabled } = options ?? {}; 252 | 253 | if (id) { 254 | const $button = d.getElementById(id); 255 | 256 | // @ts-ignore 257 | if ($button) $button.disabled = disabled; 258 | } 259 | } 260 | 261 | /** 262 | * Trigger the `INITIALIZE_IS_READY` event when the Document is ready. 263 | */ 264 | d.addEventListener("DOMContentLoaded", function isReady() { 265 | const $body = d.querySelector("body"); 266 | 267 | if ($body) { 268 | const e = new CustomEvent("INITIALIZE_IS_READY", { 269 | detail: { initialize }, 270 | }); 271 | 272 | $body.dispatchEvent(e); 273 | } 274 | }); 275 | })(document); 276 | -------------------------------------------------------------------------------- /src/diffusers_interpret/explainer.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from typing import List, Optional, Union, Tuple, Set, Dict, Any 3 | 4 | import torch 5 | from PIL import ImageDraw 6 | from PIL.Image import Image 7 | from diffusers import DiffusionPipeline 8 | from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img import preprocess 9 | from transformers import BatchEncoding, PreTrainedTokenizerBase 10 | 11 | from diffusers_interpret.attribution import gradients_attribution 12 | from diffusers_interpret.data import PipelineExplainerOutput, PipelineImg2ImgExplainerOutput, \ 13 | BaseMimicPipelineCallOutput, AttributionMethods, AttributionAlgorithm, PipelineExplainerForBoundingBoxOutput, \ 14 | PipelineImg2ImgExplainerForBoundingBoxOutputOutput 15 | from diffusers_interpret.generated_images import GeneratedImages 16 | from diffusers_interpret.pixel_attributions import PixelAttributions 17 | from diffusers_interpret.saliency_map import SaliencyMap 18 | from diffusers_interpret.token_attributions import TokenAttributions 19 | from diffusers_interpret.utils import clean_token_from_prefixes_and_suffixes 20 | 21 | 22 | class BasePipelineExplainer(ABC): 23 | """ 24 | Core base class to explain all DiffusionPipeline: text2img, img2img and inpaint pipelines 25 | """ 26 | 27 | def __init__(self, pipe: DiffusionPipeline, verbose: bool = True, gradient_checkpointing: bool = False) -> None: 28 | self.pipe = pipe 29 | 
self.verbose = verbose 30 | self.pipe._progress_bar_config = { 31 | **(getattr(self.pipe, '_progress_bar_config', {}) or {}), 32 | 'disable': not verbose 33 | } 34 | self.gradient_checkpointing = gradient_checkpointing 35 | if self.gradient_checkpointing: 36 | self.gradient_checkpointing_enable() 37 | 38 | def _preprocess_input( 39 | self, 40 | prompt: str, 41 | init_image: Optional[Union[torch.FloatTensor, Image]] = None, 42 | mask_image: Optional[Union[torch.FloatTensor, Image]] = None 43 | ) -> Tuple[Any, Any, Any]: 44 | return prompt, init_image, mask_image 45 | 46 | def __call__( 47 | self, 48 | prompt: str, 49 | init_image: Optional[Union[torch.FloatTensor, Image]] = None, 50 | mask_image: Optional[Union[torch.FloatTensor, Image]] = None, 51 | attribution_method: Union[str, AttributionMethods] = None, 52 | explanation_2d_bounding_box: Optional[Tuple[Tuple[int, int], Tuple[int, int]]] = None, 53 | consider_special_tokens: bool = False, 54 | clean_token_prefixes_and_suffixes: bool = True, 55 | run_safety_checker: bool = False, 56 | n_last_diffusion_steps_to_consider_for_attributions: Optional[int] = None, 57 | get_images_for_all_inference_steps: bool = True, 58 | output_type: Optional[str] = 'pil', 59 | **kwargs 60 | ) -> Union[ 61 | PipelineExplainerOutput, 62 | PipelineExplainerForBoundingBoxOutput, 63 | PipelineImg2ImgExplainerOutput, 64 | PipelineImg2ImgExplainerForBoundingBoxOutputOutput 65 | ]: 66 | """ 67 | Calls a DiffusionPipeline and generates explanations for a given prompt. 68 | 69 | Args: 70 | prompt (`str`): 71 | Input string for the diffusion model 72 | init_image (`torch.FloatTensor` or `PIL.Image.Image`, *optional*): 73 | `Image`, or tensor representing an image batch, that will be used as the starting point for the process. 74 | If provided, output will be of type `PipelineImg2ImgExplainerOutput` or `PipelineImg2ImgExplainerForBoundingBoxOutputOutput`. 75 | mask_image (`torch.FloatTensor` or `PIL.Image.Image`, *optional*): 76 | `Image`, or tensor representing an image batch, to mask `init_image`. White pixels in the mask will be 77 | replaced by noise and therefore repainted, while black pixels will be preserved. The mask image will be 78 | converted to a single channel (luminance) before use. 79 | attribution_method (`Union[str, AttributionMethods]`, *optional*): 80 | `AttributionMethods` or `str` with the attribution algorithms to compute. 81 | Only one algorithm per type of attribution. If `str` is provided, the same algorithm 82 | will be applied to calculate both token and pixel attributions. 83 | explanation_2d_bounding_box (`Tuple[Tuple[int, int], Tuple[int, int]]`, *optional*): 84 | Tuple with the bounding box coordinates to calculate attributions for. 85 | The tuple is like (upper left corner, bottom right corner). Example: `((0, 0), (300, 300))` 86 | If this argument is provided, the output will be of type `PipelineExplainerForBoundingBoxOutput` 87 | or `PipelineImg2ImgExplainerForBoundingBoxOutputOutput`- 88 | consider_special_tokens (bool, defaults to `True`): 89 | If True, token attributions will also show attributions for `pipe.tokenizer.SPECIAL_TOKENS_ATTRIBUTES` 90 | clean_token_prefixes_and_suffixes (bool, defaults to `True`): 91 | If True, tries to clean prefixes and suffixes added by the `pipe.tokenizer`. 92 | run_safety_checker (bool, defaults to `False`): 93 | If True, will run the NSFW checker and return a black image if the safety checker says so. 
94 | n_last_diffusion_steps_to_consider_for_attributions (int, *optional*): 95 | If not provided, it will calculate explanations for the output image based on all the diffusion steps. 96 | If given a number, it will only use the last provided diffusion steps. 97 | Set to `n_last_diffusion_steps_to_consider_for_attributions=0` for deactivating attributions calculation. 98 | get_images_for_all_inference_steps (bool, defaults to `True`): 99 | If True, will return all the images during diffusion in `output.all_images_during_generation` 100 | output_type (str, *optional*, defaults to `"pil"`): 101 | The output format of the generated image. Choose between 102 | [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `torch.Tensor`. 103 | **kwargs: 104 | Used to pass more arguments to DiffusionPipeline.__call__. 105 | Returns: 106 | [`PipelineExplainerOutput`], [`PipelineExplainerForBoundingBoxOutput`], 107 | [`PipelineImg2ImgExplainerOutput`] or [`PipelineImg2ImgExplainerForBoundingBoxOutputOutput`] 108 | 109 | [`PipelineExplainerOutput`] if `init_image=None` and `explanation_2d_bounding_box=None` 110 | [`PipelineExplainerForBoundingBoxOutput`] if `init_image=None` and `explanation_2d_bounding_box is not None` 111 | [`PipelineImg2ImgExplainerOutput`] if `init_image is not None` and `explanation_2d_bounding_box=None` 112 | [`PipelineImg2ImgExplainerForBoundingBoxOutputOutput`] if `init_image is not None` and `explanation_2d_bounding_box is not None` 113 | """ 114 | 115 | attribution_method = attribution_method or AttributionMethods() 116 | 117 | if isinstance(attribution_method, str): 118 | attribution_method = AttributionMethods( 119 | tokens_attribution_method=AttributionAlgorithm(attribution_method), 120 | pixels_attribution_method=AttributionAlgorithm(attribution_method) 121 | ) 122 | else: 123 | if not isinstance(attribution_method, AttributionMethods): 124 | raise ValueError("`attribution_method` has to be of type `str` or `AttributionMethods`") 125 | 126 | for k in ['tokens_attribution_method', 'pixels_attribution_method']: 127 | v = getattr(attribution_method, k) 128 | if not isinstance(v, AttributionAlgorithm): 129 | setattr(attribution_method, k, AttributionAlgorithm(v)) 130 | 131 | if isinstance(prompt, str): 132 | batch_size = 1 # TODO: make compatible with bigger batch sizes 133 | elif isinstance(prompt, list) and len(prompt) > 0 and isinstance(prompt[0], str): 134 | batch_size = len(prompt) 135 | raise NotImplementedError("Passing a list of strings in `prompt` is still not implemented yet.") 136 | else: 137 | raise ValueError(f"`prompt` has to be of type `str` but is {type(prompt)}") 138 | 139 | # TODO: add asserts for out of bounds 140 | if explanation_2d_bounding_box: 141 | pass 142 | 143 | prompt, init_image, mask_image = self._preprocess_input(prompt=prompt, init_image=init_image, mask_image=mask_image) 144 | 145 | # get prompt text embeddings 146 | tokens, text_input, text_embeddings = self.get_prompt_tokens_token_ids_and_embeds(prompt=prompt) 147 | 148 | # Enable gradient, if `n_last_diffusion_steps_to_consider_for_attributions > 0` 149 | calculate_attributions = n_last_diffusion_steps_to_consider_for_attributions is None \ 150 | or n_last_diffusion_steps_to_consider_for_attributions > 0 151 | if not calculate_attributions: 152 | torch.set_grad_enabled(False) 153 | else: 154 | torch.set_grad_enabled(True) 155 | 156 | # Get prediction with their associated gradients 157 | output: BaseMimicPipelineCallOutput = self._mimic_pipeline_call( 158 | 
text_input=text_input, 159 | text_embeddings=text_embeddings, 160 | init_image=init_image, 161 | mask_image=mask_image, 162 | batch_size=batch_size, 163 | output_type=None, 164 | run_safety_checker=run_safety_checker, 165 | n_last_diffusion_steps_to_consider_for_attributions=n_last_diffusion_steps_to_consider_for_attributions, 166 | get_images_for_all_inference_steps=get_images_for_all_inference_steps, 167 | **kwargs 168 | ) 169 | 170 | # transform BaseMimicPipelineCallOutput to PipelineExplainerOutput or PipelineExplainerForBoundingBoxOutput 171 | output_kwargs = { 172 | 'image': output.images[0], 173 | 'nsfw_content_detected': output.nsfw_content_detected, 174 | 'all_images_during_generation': output.all_images_during_generation, 175 | } 176 | if explanation_2d_bounding_box is not None: 177 | output['explanation_2d_bounding_box'] = explanation_2d_bounding_box 178 | output: PipelineExplainerForBoundingBoxOutput = PipelineExplainerForBoundingBoxOutput(**output_kwargs) 179 | else: 180 | output: PipelineExplainerOutput = PipelineExplainerOutput(**output_kwargs) 181 | 182 | if output.nsfw_content_detected: 183 | raise Exception( 184 | "NSFW content was detected, it is not possible to provide an explanation. " 185 | "Try to set `run_safety_checker=False` if you really want to skip the NSFW safety check." 186 | ) 187 | 188 | # Calculate primary attribution scores 189 | if calculate_attributions: 190 | output: Union[PipelineExplainerOutput, PipelineImg2ImgExplainerOutput] = self._get_attributions( 191 | output=output, 192 | attribution_method=attribution_method, 193 | tokens=tokens, 194 | text_embeddings=text_embeddings, 195 | init_image=init_image, 196 | mask_image=mask_image, 197 | explanation_2d_bounding_box=explanation_2d_bounding_box, 198 | consider_special_tokens=consider_special_tokens, 199 | clean_token_prefixes_and_suffixes=clean_token_prefixes_and_suffixes, 200 | n_last_diffusion_steps_to_consider_for_attributions=n_last_diffusion_steps_to_consider_for_attributions, 201 | **kwargs 202 | ) 203 | 204 | if batch_size == 1: 205 | # squash batch dimension 206 | for k in ['nsfw_content_detected', 'token_attributions', 'pixel_attributions']: 207 | if getattr(output, k, None) is not None: 208 | output[k] = output[k][0] 209 | if output.all_images_during_generation: 210 | output.all_images_during_generation = [b[0] for b in output.all_images_during_generation] 211 | 212 | else: 213 | raise NotImplementedError 214 | 215 | # convert to PIL Image if requested 216 | # also draw bounding box in the last image if requested 217 | if output.all_images_during_generation or output_type == "pil": 218 | all_images = GeneratedImages( 219 | all_generated_images=output.all_images_during_generation or [output.image], 220 | pipe=self.pipe, 221 | remove_batch_dimension=batch_size==1, 222 | prepare_image_slider=bool(output.all_images_during_generation) 223 | ) 224 | if output.all_images_during_generation: 225 | output.all_images_during_generation = all_images 226 | image = output.all_images_during_generation[-1] 227 | else: 228 | image = all_images[-1] 229 | 230 | if explanation_2d_bounding_box: 231 | draw = ImageDraw.Draw(image) 232 | draw.rectangle(explanation_2d_bounding_box, outline="red") 233 | 234 | if output_type == "pil": 235 | output.image = image 236 | 237 | return output 238 | 239 | def _post_process_token_attributions( 240 | self, 241 | output: PipelineExplainerOutput, 242 | tokens: List[List[str]], 243 | token_attributions: torch.Tensor, 244 | consider_special_tokens: bool, 245 | 
clean_token_prefixes_and_suffixes: bool 246 | ) -> PipelineExplainerOutput: 247 | # remove special tokens 248 | assert len(token_attributions) == len(tokens) 249 | output.token_attributions = [] 250 | for image_token_attributions, image_tokens in zip(token_attributions, tokens): 251 | assert len(image_token_attributions) == len(image_tokens) 252 | 253 | # Add token attributions 254 | output.token_attributions.append([]) 255 | for attr, token in zip(image_token_attributions, image_tokens): 256 | if consider_special_tokens or token not in self.special_tokens_attributes: 257 | 258 | if clean_token_prefixes_and_suffixes: 259 | token = clean_token_from_prefixes_and_suffixes(token) 260 | 261 | output.token_attributions[-1].append( 262 | (token, attr) 263 | ) 264 | 265 | output.token_attributions[-1] = TokenAttributions(output.token_attributions[-1]) 266 | 267 | return output 268 | 269 | def _get_attributions( 270 | self, 271 | output: Union[PipelineExplainerOutput, PipelineExplainerForBoundingBoxOutput], 272 | attribution_method: AttributionMethods, 273 | tokens: List[List[str]], 274 | text_embeddings: torch.Tensor, 275 | init_image: Optional[torch.FloatTensor] = None, 276 | mask_image: Optional[Union[torch.FloatTensor, Image]] = None, 277 | explanation_2d_bounding_box: Optional[Tuple[Tuple[int, int], Tuple[int, int]]] = None, 278 | consider_special_tokens: bool = False, 279 | clean_token_prefixes_and_suffixes: bool = True, 280 | n_last_diffusion_steps_to_consider_for_attributions: Optional[int] = None, 281 | **kwargs 282 | ) -> Union[ 283 | PipelineExplainerOutput, 284 | PipelineExplainerForBoundingBoxOutput, 285 | PipelineImg2ImgExplainerOutput, 286 | PipelineImg2ImgExplainerForBoundingBoxOutputOutput 287 | ]: 288 | if self.verbose: 289 | print("Calculating token attributions... 
", end='') 290 | 291 | token_attributions = gradients_attribution( 292 | pred_logits=output.image, 293 | input_embeds=(text_embeddings,), 294 | attribution_algorithms=[attribution_method.tokens_attribution_method], 295 | explanation_2d_bounding_box=explanation_2d_bounding_box 296 | )[0].detach().cpu().numpy() 297 | 298 | output = self._post_process_token_attributions( 299 | output=output, 300 | tokens=tokens, 301 | token_attributions=token_attributions, 302 | consider_special_tokens=consider_special_tokens, 303 | clean_token_prefixes_and_suffixes=clean_token_prefixes_and_suffixes 304 | ) 305 | 306 | if self.verbose: 307 | print("Done!") 308 | 309 | return output 310 | 311 | @property 312 | def special_tokens_attributes(self) -> Set[str]: 313 | 314 | # remove verbosity 315 | verbose = self.tokenizer.verbose 316 | self.tokenizer.verbose = False 317 | 318 | # get special tokens 319 | special_tokens = [] 320 | for attr in self.tokenizer.SPECIAL_TOKENS_ATTRIBUTES: 321 | t = getattr(self.tokenizer, attr, None) 322 | 323 | if isinstance(t, str): 324 | special_tokens.append(t) 325 | elif isinstance(t, list) and len(t) > 0 and isinstance(t[0], str): 326 | special_tokens += t 327 | 328 | # reset verbosity 329 | self.tokenizer.verbose = verbose 330 | 331 | return set(special_tokens) 332 | 333 | def gradient_checkpointing_enable(self) -> None: 334 | self.gradient_checkpointing = True 335 | 336 | def gradient_checkpointing_disable(self) -> None: 337 | self.gradient_checkpointing = False 338 | 339 | @property 340 | @abstractmethod 341 | def tokenizer(self) -> PreTrainedTokenizerBase: 342 | raise NotImplementedError 343 | 344 | @abstractmethod 345 | def get_prompt_tokens_token_ids_and_embeds(self, prompt: Union[str, List[str]]) -> Tuple[List[List[str]], BatchEncoding, torch.Tensor]: 346 | raise NotImplementedError 347 | 348 | @abstractmethod 349 | def _mimic_pipeline_call( 350 | self, 351 | text_input: BatchEncoding, 352 | text_embeddings: torch.Tensor, 353 | batch_size: int, 354 | init_image: Optional[torch.FloatTensor] = None, 355 | mask_image: Optional[Union[torch.FloatTensor, Image]] = None, 356 | height: Optional[int] = 512, 357 | width: Optional[int] = 512, 358 | strength: float = 0.8, 359 | num_inference_steps: Optional[int] = 50, 360 | guidance_scale: Optional[float] = 7.5, 361 | eta: Optional[float] = 0.0, 362 | generator: Optional[torch.Generator] = None, 363 | latents: Optional[torch.FloatTensor] = None, 364 | output_type: Optional[str] = 'pil', 365 | return_dict: bool = True, 366 | run_safety_checker: bool = True, 367 | n_last_diffusion_steps_to_consider_for_attributions: Optional[int] = None, 368 | get_images_for_all_inference_steps: bool = False 369 | ) -> Union[ 370 | BaseMimicPipelineCallOutput, 371 | Tuple[Union[List[Image], torch.Tensor], Optional[Union[List[List[Image]], List[torch.Tensor]]], Optional[List[bool]]] 372 | ]: 373 | r""" 374 | Mimics DiffusionPipeline.__call__ but adds extra functionality to calculate explanations. 375 | 376 | Args: 377 | text_input (`BatchEncoding`): 378 | Tokenized input string. 379 | text_embeddings (`torch.Tensor`): 380 | Output of the text encoder. 381 | batch_size (`int`): 382 | Batch size to be used. 383 | init_image (`torch.FloatTensor`, *optional*): 384 | `Image`, or tensor representing an image batch, that will be used as the starting point for the 385 | process. 386 | mask_image (`torch.FloatTensor` or `PIL.Image.Image`, *optional*): 387 | `Image`, or tensor representing an image batch, to mask `init_image`. 
White pixels in the mask will be
388 |                 replaced by noise and therefore repainted, while black pixels will be preserved. The mask image will be
389 |                 converted to a single channel (luminance) before use.
390 |             strength (`float`, *optional*, defaults to 0.8):
391 |                 Conceptually, indicates how much to inpaint the masked area. Must be between 0 and 1. When `strength`
392 |                 is 1, the denoising process will be run on the masked area for the full number of iterations specified
393 |                 in `num_inference_steps`. `init_image` will be used as a reference for the masked area, adding more
394 |                 noise to that region the larger the `strength`. If `strength` is 0, no inpainting will occur.
395 |             num_inference_steps (`int`, *optional*, defaults to 50):
396 |                 The reference number of denoising steps. More denoising steps usually lead to a higher quality image at
397 |                 the expense of slower inference. This parameter will be modulated by `strength`, as explained above.
398 |             guidance_scale (`float`, *optional*, defaults to 7.5):
399 |                 Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
400 |                 `guidance_scale` is defined as `w` of equation 2 of the [Imagen
401 |                 Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
402 |                 1`. A higher guidance scale encourages the model to generate images that are closely linked to the text `prompt`,
403 |                 usually at the expense of lower image quality.
404 |             eta (`float`, *optional*, defaults to 0.0):
405 |                 Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
406 |                 [`schedulers.DDIMScheduler`], will be ignored for others.
407 |             generator (`torch.Generator`, *optional*):
408 |                 A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
409 |                 deterministic.
410 |             latents (`torch.FloatTensor`, *optional*):
411 |                 Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
412 |                 generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
413 |                 tensor will be generated by sampling using the supplied random `generator`.
414 |             output_type (`str`, *optional*, defaults to `"pil"`):
415 |                 The output format of the generated image. Choose between
416 |                 [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.ndarray`.
417 |             return_dict (`bool`, *optional*, defaults to `True`):
418 |                 Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
419 |                 plain tuple.
420 | 
421 |         Returns:
422 |             [`BaseMimicPipelineCallOutput`] or `tuple`:
423 |             [`BaseMimicPipelineCallOutput`] if `return_dict` is True, otherwise a `tuple`.
424 |             When returning a tuple, the first element is a list with the generated images,
425 |             the second element contains all the generated images during the diffusion process, and the third element is a
426 |             list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
427 |             (nsfw) content, according to the `safety_checker`.
428 | """ 429 | raise NotImplementedError 430 | 431 | 432 | class BasePipelineImg2ImgExplainer(BasePipelineExplainer): 433 | """ 434 | Core base class to explain img2img and inpaint pipelines 435 | """ 436 | def _preprocess_input( 437 | self, 438 | prompt: str, 439 | init_image: Optional[Union[torch.FloatTensor, Image]] = None, 440 | mask_image: Optional[Union[torch.FloatTensor, Image]] = None 441 | ) -> Tuple[Any, Any, Any]: 442 | """ 443 | Converts input image to tensor 444 | """ 445 | prompt, init_image, mask_image = super()._preprocess_input( 446 | prompt=prompt, init_image=init_image, mask_image=mask_image 447 | ) 448 | if init_image is None: 449 | raise TypeError("missing 1 required positional argument: 'init_image'") 450 | 451 | init_image = preprocess(init_image).to(self.pipe.device).permute(0, 2, 3, 1) 452 | init_image.requires_grad = True 453 | 454 | return prompt, init_image, mask_image 455 | 456 | def _get_attributions( 457 | self, 458 | output: Union[PipelineExplainerOutput, PipelineExplainerForBoundingBoxOutput], 459 | attribution_method: AttributionMethods, 460 | tokens: List[List[str]], 461 | text_embeddings: torch.Tensor, 462 | init_image: Optional[torch.FloatTensor] = None, 463 | mask_image: Optional[Union[torch.FloatTensor, Image]] = None, 464 | explanation_2d_bounding_box: Optional[Tuple[Tuple[int, int], Tuple[int, int]]] = None, 465 | consider_special_tokens: bool = False, 466 | clean_token_prefixes_and_suffixes: bool = True, 467 | n_last_diffusion_steps_to_consider_for_attributions: Optional[int] = None, 468 | **kwargs 469 | ) -> Union[ 470 | PipelineExplainerOutput, 471 | PipelineExplainerForBoundingBoxOutput, 472 | PipelineImg2ImgExplainerOutput, 473 | PipelineImg2ImgExplainerForBoundingBoxOutputOutput 474 | ]: 475 | if init_image is None: 476 | raise TypeError("missing 1 required positional argument: 'init_image'") 477 | 478 | input_embeds = (text_embeddings,) 479 | if n_last_diffusion_steps_to_consider_for_attributions is None: 480 | input_embeds = (text_embeddings, init_image) 481 | 482 | if self.verbose: 483 | if n_last_diffusion_steps_to_consider_for_attributions is None: 484 | print("Calculating token and image pixel attributions... ", end='') 485 | else: 486 | print( 487 | "Can't calculate image pixel attributions " 488 | "with a specified `n_last_diffusion_steps_to_consider_for_attributions`. " 489 | "Set `n_last_diffusion_steps_to_consider_for_attributions=None` " 490 | "if you wish to calculate image pixel attributions" 491 | ) 492 | print("Calculating token attributions... 
", end='') 493 | 494 | attributions = gradients_attribution( 495 | pred_logits=output.image, 496 | input_embeds=input_embeds, 497 | attribution_algorithms=[ 498 | attribution_method.tokens_attribution_method, attribution_method.pixels_attribution_method 499 | ], 500 | explanation_2d_bounding_box=explanation_2d_bounding_box 501 | ) 502 | 503 | token_attributions = attributions[0].detach().cpu().numpy() 504 | 505 | pixel_attributions = None 506 | if n_last_diffusion_steps_to_consider_for_attributions is None: 507 | pixel_attributions = attributions[1].detach().cpu().numpy() 508 | 509 | output = self._post_process_token_attributions( 510 | output=output, 511 | tokens=tokens, 512 | token_attributions=token_attributions, 513 | consider_special_tokens=consider_special_tokens, 514 | clean_token_prefixes_and_suffixes=clean_token_prefixes_and_suffixes 515 | ) 516 | 517 | # removes preprocessing done in diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.preprocess 518 | init_image = (init_image + 1.0) / 2.0 519 | 520 | # add batch dimension to mask if needed 521 | masks = mask_image 522 | if isinstance(masks, Image): 523 | masks = [masks] 524 | elif torch.is_tensor(masks) and len(masks.shape) == 3: 525 | masks = masks.unsqueeze(0) 526 | 527 | # construct PixelAttributions objects 528 | images = init_image.detach().cpu().numpy() 529 | assert len(images) == len(pixel_attributions) 530 | if masks is not None: 531 | assert len(images) == len(masks) 532 | pixel_attributions = [ 533 | PixelAttributions( 534 | attr, 535 | saliency_map=SaliencyMap( 536 | image=img, 537 | pixel_attributions=attr, 538 | mask=mask 539 | ) 540 | ) for img, attr, mask in zip(images, pixel_attributions, masks or [None] * len(images)) 541 | ] 542 | 543 | output_kwargs = { 544 | 'image': output.image, 545 | 'nsfw_content_detected': output.nsfw_content_detected, 546 | 'all_images_during_generation': output.all_images_during_generation, 547 | 'token_attributions': output.token_attributions, 548 | 'pixel_attributions': pixel_attributions 549 | } 550 | if explanation_2d_bounding_box is not None: 551 | output_kwargs['explanation_2d_bounding_box'] = explanation_2d_bounding_box 552 | output = PipelineImg2ImgExplainerForBoundingBoxOutputOutput(**output_kwargs) 553 | else: 554 | output = PipelineImg2ImgExplainerOutput(**output_kwargs) 555 | 556 | if self.verbose: 557 | print("Done!") 558 | 559 | return output 560 | -------------------------------------------------------------------------------- /src/diffusers_interpret/explainers/__init__.py: -------------------------------------------------------------------------------- 1 | from .latent_diffusion import LDMTextToImagePipelineExplainer 2 | from .stable_diffusion import StableDiffusionPipelineExplainer, StableDiffusionImg2ImgPipelineExplainer, \ 3 | StableDiffusionInpaintPipelineExplainer -------------------------------------------------------------------------------- /src/diffusers_interpret/explainers/latent_diffusion.py: -------------------------------------------------------------------------------- 1 | import inspect 2 | from typing import List, Optional, Union, Tuple 3 | 4 | import torch 5 | from torch.utils.checkpoint import checkpoint 6 | 7 | from diffusers import LDMTextToImagePipeline 8 | from transformers import BatchEncoding, PreTrainedTokenizerBase 9 | 10 | from diffusers_interpret import BasePipelineExplainer 11 | from diffusers_interpret.explainer import BaseMimicPipelineCallOutput 12 | from diffusers_interpret.utils import transform_images_to_pil_format 
13 | 14 | 15 | class LDMTextToImagePipelineExplainer(BasePipelineExplainer): 16 | pipe: LDMTextToImagePipeline 17 | 18 | @property 19 | def tokenizer(self) -> PreTrainedTokenizerBase: 20 | return self.pipe.tokenizer 21 | 22 | def get_prompt_tokens_token_ids_and_embeds(self, prompt: Union[str, List[str]]) -> Tuple[List[List[str]], BatchEncoding, torch.Tensor]: 23 | text_input = self.pipe.tokenizer(prompt, padding="max_length", max_length=77, return_tensors="pt") 24 | text_embeddings = self.pipe.bert(text_input.input_ids.to(self.pipe.device))[0] 25 | tokens = [self.pipe.tokenizer.convert_ids_to_tokens(sample) for sample in text_input['input_ids']] 26 | return tokens, text_input, text_embeddings 27 | 28 | def gradient_checkpointing_enable(self) -> None: 29 | self.pipe.bert.gradient_checkpointing_enable() 30 | super().gradient_checkpointing_enable() 31 | 32 | def gradient_checkpointing_disable(self) -> None: 33 | self.pipe.bert.gradient_checkpointing_disable() 34 | super().gradient_checkpointing_disable() 35 | 36 | def _mimic_pipeline_call( 37 | self, 38 | text_input: BatchEncoding, 39 | text_embeddings: torch.Tensor, 40 | batch_size: int, 41 | height: Optional[int] = 512, 42 | width: Optional[int] = 512, 43 | num_inference_steps: Optional[int] = 50, 44 | guidance_scale: Optional[float] = 7.5, 45 | eta: Optional[float] = 0.0, 46 | generator: Optional[torch.Generator] = None, 47 | latents: Optional[torch.FloatTensor] = None, 48 | output_type: Optional[str] = 'pil', 49 | return_dict: bool = True, 50 | run_safety_checker: bool = True, 51 | n_last_diffusion_steps_to_consider_for_attributions: Optional[int] = None, 52 | get_images_for_all_inference_steps: bool = False 53 | ) -> BaseMimicPipelineCallOutput: 54 | # TODO: add description 55 | 56 | if height % 8 != 0 or width % 8 != 0: 57 | raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") 58 | 59 | if not return_dict: 60 | raise NotImplementedError( 61 | "`return_dict=False` not available in LDMTextToImagePipelineExplainer._mimic_pipeline_call" 62 | ) 63 | 64 | if latents is not None: 65 | raise NotImplementedError( 66 | "Can't provide `latents` to LDMTextToImagePipelineExplainer._mimic_pipeline_call" 67 | ) 68 | 69 | # get unconditional embeddings for classifier free guidance 70 | if guidance_scale != 1.0: 71 | uncond_input = self.pipe.tokenizer([""] * batch_size, padding="max_length", max_length=77, return_tensors="pt") 72 | uncond_embeddings = self.pipe.bert(uncond_input.input_ids.to(self.pipe.device))[0] 73 | 74 | # get the initial random noise 75 | latents = torch.randn( 76 | (batch_size, self.pipe.unet.in_channels, height // 8, width // 8), 77 | generator=generator, 78 | ) 79 | latents = latents.to(self.pipe.device) 80 | 81 | self.pipe.scheduler.set_timesteps(num_inference_steps) 82 | 83 | # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature 84 | accepts_eta = "eta" in set(inspect.signature(self.pipe.scheduler.step).parameters.keys()) 85 | 86 | extra_kwargs = {} 87 | if accepts_eta: 88 | extra_kwargs["eta"] = eta 89 | 90 | def decode_latents(latents: torch.Tensor, pipe: LDMTextToImagePipeline) -> Tuple[torch.Tensor, Optional[List[bool]]]: 91 | # scale and decode the image latents with vae 92 | latents = 1 / 0.18215 * latents 93 | if not self.gradient_checkpointing or not torch.is_grad_enabled(): 94 | image = pipe.vqvae.decode(latents).sample 95 | else: 96 | image = checkpoint(pipe.vqvae.decode, latents, use_reentrant=False).sample 97 | 98 | image = (image 
/ 2 + 0.5).clamp(0, 1) 99 | image = image.permute(0, 2, 3, 1) 100 | 101 | has_nsfw_concept = None 102 | if run_safety_checker: 103 | image = image.detach().cpu().numpy() 104 | safety_cheker_input = pipe.feature_extractor( 105 | pipe.numpy_to_pil(image), return_tensors="pt" 106 | ).to(pipe.device) 107 | image, has_nsfw_concept = pipe.safety_checker( 108 | images=image, clip_input=safety_cheker_input.pixel_values 109 | ) 110 | 111 | return image, has_nsfw_concept 112 | 113 | all_generated_images = [] if get_images_for_all_inference_steps else None 114 | for i, t in enumerate(self.pipe.progress_bar(self.pipe.scheduler.timesteps)): 115 | 116 | if n_last_diffusion_steps_to_consider_for_attributions: 117 | if i < len(self.pipe.scheduler.timesteps) - n_last_diffusion_steps_to_consider_for_attributions: 118 | torch.set_grad_enabled(False) 119 | else: 120 | torch.set_grad_enabled(True) 121 | 122 | # decode latents 123 | if get_images_for_all_inference_steps: 124 | with torch.no_grad(): 125 | image, _ = decode_latents(latents=latents, pipe=self.pipe) 126 | all_generated_images.append(image) 127 | 128 | if guidance_scale == 1.0: 129 | # guidance_scale of 1 means no guidance 130 | latents_input = latents 131 | context = text_embeddings 132 | else: 133 | # For classifier free guidance, we need to do two forward passes. 134 | # Here we concatenate the unconditional and text embeddings into a single batch 135 | # to avoid doing two forward passes 136 | latents_input = torch.cat([latents] * 2) 137 | context = torch.cat([uncond_embeddings, text_embeddings]) 138 | 139 | # predict the noise residual 140 | if not self.gradient_checkpointing or not torch.is_grad_enabled(): 141 | noise_pred = self.pipe.unet(latents_input, t, context).sample 142 | else: 143 | noise_pred = checkpoint( 144 | self.pipe.unet.forward, latents_input, t, context, use_reentrant=False 145 | ).sample 146 | 147 | # perform guidance 148 | if guidance_scale != 1.0: 149 | noise_pred_uncond, noise_prediction_text = noise_pred.chunk(2) 150 | noise_pred = noise_pred_uncond + guidance_scale * (noise_prediction_text - noise_pred_uncond) 151 | 152 | # compute the previous noisy sample x_t -> x_t-1 153 | latents = self.pipe.scheduler.step(noise_pred, t, latents, **extra_kwargs).prev_sample 154 | 155 | image, has_nsfw_concept = decode_latents(latents=latents, pipe=self.pipe) 156 | if all_generated_images: 157 | all_generated_images.append(image) 158 | 159 | if output_type == "pil": 160 | if all_generated_images: 161 | all_generated_images = transform_images_to_pil_format(all_generated_images, self.pipe) 162 | image = all_generated_images[-1] 163 | else: 164 | image = transform_images_to_pil_format([image], self.pipe)[0] 165 | 166 | return BaseMimicPipelineCallOutput( 167 | images=image, nsfw_content_detected=has_nsfw_concept, 168 | all_images_during_generation=all_generated_images 169 | ) 170 | -------------------------------------------------------------------------------- /src/diffusers_interpret/explainers/stable_diffusion.py: -------------------------------------------------------------------------------- 1 | import inspect 2 | from typing import List, Optional, Union, Tuple 3 | 4 | import torch 5 | from PIL.Image import Image 6 | from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_inpaint import preprocess_mask 7 | from torch.utils.checkpoint import checkpoint 8 | from diffusers import StableDiffusionPipeline, LMSDiscreteScheduler, StableDiffusionImg2ImgPipeline, \ 9 | StableDiffusionInpaintPipeline 10 | from transformers import 
BatchEncoding, PreTrainedTokenizerBase 11 | 12 | from diffusers_interpret import BasePipelineExplainer 13 | from diffusers_interpret.explainer import BaseMimicPipelineCallOutput, BasePipelineImg2ImgExplainer 14 | from diffusers_interpret.utils import transform_images_to_pil_format 15 | 16 | 17 | def decode_latents( 18 | latents: torch.Tensor, 19 | pipe: Union[StableDiffusionImg2ImgPipeline, StableDiffusionPipeline], 20 | gradient_checkpointing: bool, 21 | run_safety_checker: bool 22 | ) -> Tuple[torch.Tensor, Optional[List[bool]]]: 23 | # scale and decode the image latents with vae 24 | latents = 1 / 0.18215 * latents 25 | if not gradient_checkpointing or not torch.is_grad_enabled(): 26 | image = pipe.vae.decode(latents.to(pipe.vae.dtype)).sample 27 | else: 28 | image = checkpoint(pipe.vae.decode, latents.to(pipe.vae.dtype), use_reentrant=False).sample 29 | 30 | image = (image / 2 + 0.5).clamp(0, 1) 31 | image = image.permute(0, 2, 3, 1) 32 | 33 | has_nsfw_concept = None 34 | if run_safety_checker: 35 | image = image.detach().cpu().numpy() 36 | safety_cheker_input = pipe.feature_extractor( 37 | pipe.numpy_to_pil(image), return_tensors="pt" 38 | ).to(pipe.device) 39 | image, has_nsfw_concept = pipe.safety_checker( 40 | images=image, clip_input=safety_cheker_input.pixel_values 41 | ) 42 | 43 | return image, has_nsfw_concept 44 | 45 | 46 | class BaseStableDiffusionPipelineExplainer(BasePipelineExplainer): 47 | pipe: Union[StableDiffusionPipeline, StableDiffusionImg2ImgPipeline] 48 | 49 | @property 50 | def tokenizer(self) -> PreTrainedTokenizerBase: 51 | return self.pipe.tokenizer 52 | 53 | def get_prompt_tokens_token_ids_and_embeds(self, prompt: Union[str, List[str]]) -> Tuple[ 54 | List[List[str]], BatchEncoding, torch.Tensor]: 55 | text_input = self.pipe.tokenizer( 56 | prompt, 57 | padding="max_length", 58 | max_length=self.pipe.tokenizer.model_max_length, 59 | truncation=True, 60 | return_tensors="pt", 61 | ) 62 | text_embeddings = self.pipe.text_encoder(text_input.input_ids.to(self.pipe.device))[0] 63 | tokens = [self.pipe.tokenizer.convert_ids_to_tokens(sample) for sample in text_input['input_ids']] 64 | return tokens, text_input, text_embeddings 65 | 66 | def gradient_checkpointing_enable(self) -> None: 67 | self.pipe.text_encoder.gradient_checkpointing_enable() 68 | super().gradient_checkpointing_enable() 69 | 70 | def gradient_checkpointing_disable(self) -> None: 71 | self.pipe.text_encoder.gradient_checkpointing_disable() 72 | super().gradient_checkpointing_disable() 73 | 74 | 75 | class StableDiffusionPipelineExplainer(BaseStableDiffusionPipelineExplainer): 76 | pipe: StableDiffusionPipeline 77 | 78 | def _mimic_pipeline_call( 79 | self, 80 | text_input: BatchEncoding, 81 | text_embeddings: torch.Tensor, 82 | batch_size: int, 83 | init_image: Optional[torch.FloatTensor] = None, 84 | mask_image: Optional[Union[torch.FloatTensor, Image]] = None, 85 | height: Optional[int] = 512, 86 | width: Optional[int] = 512, 87 | strength: float = 0.8, 88 | num_inference_steps: Optional[int] = 50, 89 | guidance_scale: Optional[float] = 7.5, 90 | eta: Optional[float] = 0.0, 91 | generator: Optional[torch.Generator] = None, 92 | latents: Optional[torch.FloatTensor] = None, 93 | output_type: Optional[str] = 'pil', 94 | return_dict: bool = True, 95 | run_safety_checker: bool = True, 96 | n_last_diffusion_steps_to_consider_for_attributions: Optional[int] = None, 97 | get_images_for_all_inference_steps: bool = False 98 | ) -> Union[ 99 | BaseMimicPipelineCallOutput, 100 | Tuple[Union[List[Image], 
torch.Tensor], Optional[Union[List[List[Image]], List[torch.Tensor]]], Optional[ 101 | List[bool]]] 102 | ]: 103 | if height % 8 != 0 or width % 8 != 0: 104 | raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") 105 | 106 | if init_image is not None: 107 | if mask_image is not None: 108 | raise ValueError( 109 | "`init_image` and `mask_image` were passed to StableDiffusionPipelineExplainer and are not expected.\n" 110 | "Were you trying to use StableDiffusionInpaintPipelineExplainer ?" 111 | ) 112 | else: 113 | raise ValueError( 114 | "`init_image` was passed to StableDiffusionPipelineExplainer and is not expected.\n" 115 | "Were you trying to use StableDiffusionImg2ImgPipelineExplainer ?" 116 | ) 117 | 118 | # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) 119 | # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` 120 | # corresponds to doing no classifier free guidance. 121 | do_classifier_free_guidance = guidance_scale > 1.0 122 | # get unconditional embeddings for classifier free guidance 123 | if do_classifier_free_guidance: 124 | max_length = text_input.input_ids.shape[-1] 125 | uncond_input = self.pipe.tokenizer( 126 | [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt" 127 | ) 128 | uncond_embeddings = self.pipe.text_encoder(uncond_input.input_ids.to(self.pipe.device))[0] 129 | # For classifier free guidance, we need to do two forward passes. 130 | # Here we concatenate the unconditional and text embeddings into a single batch 131 | # to avoid doing two forward passes 132 | text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) 133 | 134 | # get the initial random noise unless the user supplied it 135 | # Unlike in other pipelines, latents need to be generated in the target device 136 | # for 1-to-1 results reproducibility with the CompVis implementation. 137 | # However this currently doesn't work in `mps`. 138 | latents_device = "cpu" if self.pipe.device.type == "mps" else self.pipe.device 139 | latents_shape = (batch_size, self.pipe.unet.in_channels, height // 8, width // 8) 140 | if latents is None: 141 | latents = torch.randn( 142 | latents_shape, 143 | generator=generator, 144 | device=latents_device, 145 | ) 146 | else: 147 | if latents.shape != latents_shape: 148 | raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}") 149 | latents = latents.to(self.pipe.device) 150 | 151 | # set timesteps 152 | accepts_offset = "offset" in set(inspect.signature(self.pipe.scheduler.set_timesteps).parameters.keys()) 153 | extra_set_kwargs = {} 154 | if accepts_offset: 155 | extra_set_kwargs["offset"] = 1 156 | 157 | self.pipe.scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs) 158 | 159 | # if we use LMSDiscreteScheduler, let's make sure latents are multiplied by sigmas 160 | if isinstance(self.pipe.scheduler, LMSDiscreteScheduler): 161 | latents = latents * self.pipe.scheduler.sigmas[0] 162 | 163 | # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature 164 | # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. 
165 | # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 166 | # and should be between [0, 1] 167 | accepts_eta = "eta" in set(inspect.signature(self.pipe.scheduler.step).parameters.keys()) 168 | extra_step_kwargs = {} 169 | if accepts_eta: 170 | extra_step_kwargs["eta"] = eta 171 | 172 | all_generated_images = [] if get_images_for_all_inference_steps else None 173 | for i, t in enumerate(self.pipe.progress_bar(self.pipe.scheduler.timesteps)): 174 | 175 | if n_last_diffusion_steps_to_consider_for_attributions: 176 | if i < len(self.pipe.scheduler.timesteps) - n_last_diffusion_steps_to_consider_for_attributions: 177 | torch.set_grad_enabled(False) 178 | else: 179 | torch.set_grad_enabled(True) 180 | 181 | # decode latents 182 | if get_images_for_all_inference_steps: 183 | with torch.no_grad(): 184 | image, _ = decode_latents( 185 | latents=latents, pipe=self.pipe, 186 | gradient_checkpointing=self.gradient_checkpointing, run_safety_checker=run_safety_checker 187 | ) 188 | all_generated_images.append(image) 189 | 190 | # expand the latents if we are doing classifier free guidance 191 | latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents 192 | if isinstance(self.pipe.scheduler, LMSDiscreteScheduler): 193 | sigma = self.pipe.scheduler.sigmas[i] 194 | latent_model_input = latent_model_input / ((sigma**2 + 1) ** 0.5) 195 | 196 | # predict the noise residual 197 | if not self.gradient_checkpointing or not torch.is_grad_enabled(): 198 | noise_pred = self.pipe.unet(latent_model_input, t, text_embeddings).sample 199 | else: 200 | noise_pred = checkpoint( 201 | self.pipe.unet.forward, latent_model_input, t, text_embeddings, use_reentrant=False 202 | ).sample 203 | 204 | # perform guidance 205 | if do_classifier_free_guidance: 206 | noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) 207 | noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) 208 | 209 | # compute the previous noisy sample x_t -> x_t-1 210 | if isinstance(self.pipe.scheduler, LMSDiscreteScheduler): 211 | latents = self.pipe.scheduler.step(noise_pred, i, latents, **extra_step_kwargs).prev_sample 212 | else: 213 | latents = self.pipe.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample 214 | 215 | image, has_nsfw_concept = decode_latents( 216 | latents=latents, pipe=self.pipe, 217 | gradient_checkpointing=self.gradient_checkpointing, run_safety_checker=run_safety_checker 218 | ) 219 | if all_generated_images: 220 | all_generated_images.append(image) 221 | 222 | if output_type == "pil": 223 | if all_generated_images: 224 | all_generated_images = transform_images_to_pil_format(all_generated_images, self.pipe) 225 | image = all_generated_images[-1] 226 | else: 227 | image = transform_images_to_pil_format([image], self.pipe)[0] 228 | 229 | if return_dict: 230 | return BaseMimicPipelineCallOutput( 231 | images=image, nsfw_content_detected=has_nsfw_concept, 232 | all_images_during_generation=all_generated_images 233 | ) 234 | else: 235 | return (image, all_generated_images, has_nsfw_concept) 236 | 237 | 238 | class StableDiffusionImg2ImgPipelineExplainer(BasePipelineImg2ImgExplainer, BaseStableDiffusionPipelineExplainer): 239 | pipe: Union[StableDiffusionImg2ImgPipeline, StableDiffusionInpaintPipeline] 240 | 241 | def _mimic_pipeline_call( 242 | self, 243 | text_input: BatchEncoding, 244 | text_embeddings: torch.Tensor, 245 | batch_size: int, 246 | init_image: Optional[torch.FloatTensor] = None, 247 | mask_image: 
Optional[Union[torch.FloatTensor, Image]] = None, 248 | height: Optional[int] = 512, 249 | width: Optional[int] = 512, 250 | strength: float = 0.8, 251 | num_inference_steps: Optional[int] = 50, 252 | guidance_scale: Optional[float] = 7.5, 253 | eta: Optional[float] = 0.0, 254 | generator: Optional[torch.Generator] = None, 255 | latents: Optional[torch.FloatTensor] = None, 256 | output_type: Optional[str] = 'pil', 257 | return_dict: bool = True, 258 | run_safety_checker: bool = True, 259 | n_last_diffusion_steps_to_consider_for_attributions: Optional[int] = None, 260 | get_images_for_all_inference_steps: bool = False 261 | ) -> Union[ 262 | BaseMimicPipelineCallOutput, 263 | Tuple[Union[List[Image], torch.Tensor], Optional[Union[List[List[Image]], List[torch.Tensor]]], Optional[ 264 | List[bool]]] 265 | ]: 266 | 267 | if latents is not None: 268 | raise ValueError( 269 | f"`latents` was passed to {self.__class__.__name__} and it is not expected." 270 | ) 271 | 272 | # set timesteps 273 | accepts_offset = "offset" in set(inspect.signature(self.pipe.scheduler.set_timesteps).parameters.keys()) 274 | extra_set_kwargs = {} 275 | offset = 0 276 | if accepts_offset: 277 | offset = 1 278 | extra_set_kwargs["offset"] = 1 279 | 280 | self.pipe.scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs) 281 | 282 | # save all generated images during diffusion, if get_images_for_all_inference_steps 283 | all_generated_images = [(init_image / 2 + 0.5).clamp(0, 1)] if get_images_for_all_inference_steps else None 284 | 285 | # encode the init image into latents and scale the latents 286 | init_latent_dist = self.pipe.vae.encode(init_image.permute(0, 3, 1, 2)).latent_dist 287 | init_latents = init_latent_dist.sample(generator=generator) 288 | init_latents = 0.18215 * init_latents 289 | 290 | # expand init_latents for batch_size 291 | init_latents = torch.cat([init_latents] * batch_size) 292 | init_latents_orig = init_latents 293 | 294 | mask = None 295 | if mask_image is not None: 296 | # preprocess mask 297 | mask = preprocess_mask(mask_image).to(self.pipe.device) 298 | mask = torch.cat([mask] * batch_size) 299 | 300 | # check sizes 301 | if not mask.shape == init_latents.shape: 302 | raise ValueError("The mask and init_image should be the same size!") 303 | 304 | # get the original timestep using init_timestep 305 | init_timestep = int(num_inference_steps * strength) + offset 306 | init_timestep = min(init_timestep, num_inference_steps) 307 | if isinstance(self.pipe.scheduler, LMSDiscreteScheduler): 308 | timesteps = torch.tensor( 309 | [num_inference_steps - init_timestep] * batch_size, dtype=torch.long, device=self.pipe.device 310 | ) 311 | else: 312 | timesteps = self.pipe.scheduler.timesteps[-init_timestep] 313 | timesteps = torch.tensor([timesteps] * batch_size, dtype=torch.long, device=self.pipe.device) 314 | 315 | # add noise to latents using the timesteps 316 | noise = torch.randn(init_latents.shape, generator=generator, device=self.pipe.device) 317 | init_latents = self.pipe.scheduler.add_noise(init_latents, noise, timesteps).to(self.pipe.device) 318 | 319 | # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) 320 | # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` 321 | # corresponds to doing no classifier free guidance. 
322 | do_classifier_free_guidance = guidance_scale > 1.0 323 | # get unconditional embeddings for classifier free guidance 324 | if do_classifier_free_guidance: 325 | max_length = text_input.input_ids.shape[-1] 326 | uncond_input = self.pipe.tokenizer( 327 | [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt" 328 | ) 329 | uncond_embeddings = self.pipe.text_encoder(uncond_input.input_ids.to(self.pipe.device))[0] 330 | 331 | # For classifier free guidance, we need to do two forward passes. 332 | # Here we concatenate the unconditional and text embeddings into a single batch 333 | # to avoid doing two forward passes 334 | text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) 335 | 336 | # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature 337 | # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. 338 | # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 339 | # and should be between [0, 1] 340 | accepts_eta = "eta" in set(inspect.signature(self.pipe.scheduler.step).parameters.keys()) 341 | extra_step_kwargs = {} 342 | if accepts_eta: 343 | extra_step_kwargs["eta"] = eta 344 | 345 | latents = init_latents 346 | t_start = max(num_inference_steps - init_timestep + offset, 0) 347 | for i, t in enumerate(self.pipe.progress_bar(self.pipe.scheduler.timesteps[t_start:])): 348 | t_index = t_start + i 349 | 350 | if n_last_diffusion_steps_to_consider_for_attributions: 351 | if t_index < len(self.pipe.scheduler.timesteps) - n_last_diffusion_steps_to_consider_for_attributions: 352 | torch.set_grad_enabled(False) 353 | else: 354 | torch.set_grad_enabled(True) 355 | 356 | # decode latents 357 | if get_images_for_all_inference_steps: 358 | with torch.no_grad(): 359 | image, _ = decode_latents( 360 | latents=latents, pipe=self.pipe, 361 | gradient_checkpointing=self.gradient_checkpointing, run_safety_checker=run_safety_checker 362 | ) 363 | all_generated_images.append(image) 364 | 365 | # expand the latents if we are doing classifier free guidance 366 | latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents 367 | 368 | # if we use LMSDiscreteScheduler, let's make sure latents are multiplied by sigmas 369 | if isinstance(self.pipe.scheduler, LMSDiscreteScheduler): 370 | sigma = self.pipe.scheduler.sigmas[t_index] 371 | # the model input needs to be scaled to match the continuous ODE formulation in K-LMS 372 | latent_model_input = latent_model_input / ((sigma ** 2 + 1) ** 0.5) 373 | latent_model_input = latent_model_input.to(self.pipe.unet.dtype) 374 | t = t.to(self.pipe.unet.dtype) 375 | 376 | # predict the noise residual 377 | if not self.gradient_checkpointing or not torch.is_grad_enabled(): 378 | noise_pred = self.pipe.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample 379 | else: 380 | noise_pred = checkpoint( 381 | self.pipe.unet.forward, latent_model_input, t, text_embeddings, use_reentrant=False 382 | ).sample 383 | 384 | # perform guidance 385 | if do_classifier_free_guidance: 386 | noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) 387 | noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) 388 | 389 | # compute the previous noisy sample x_t -> x_t-1 390 | if isinstance(self.pipe.scheduler, LMSDiscreteScheduler): 391 | latents = self.pipe.scheduler.step(noise_pred, t_index, latents, **extra_step_kwargs).prev_sample 392 | else: 393 | latents = 
self.pipe.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample 394 | 395 | # masking 396 | if mask is not None: 397 | init_latents_proper = self.pipe.scheduler.add_noise(init_latents_orig, noise, t) 398 | latents = (init_latents_proper * mask) + (latents * (1 - mask)) 399 | 400 | image, has_nsfw_concept = decode_latents( 401 | latents=latents, pipe=self.pipe, 402 | gradient_checkpointing=self.gradient_checkpointing, run_safety_checker=run_safety_checker 403 | ) 404 | if all_generated_images: 405 | all_generated_images.append(image) 406 | 407 | if output_type == "pil": 408 | if all_generated_images: 409 | all_generated_images = transform_images_to_pil_format(all_generated_images, self.pipe) 410 | image = all_generated_images[-1] 411 | else: 412 | image = transform_images_to_pil_format([image], self.pipe)[0] 413 | 414 | if return_dict: 415 | return BaseMimicPipelineCallOutput( 416 | images=image, nsfw_content_detected=has_nsfw_concept, 417 | all_images_during_generation=all_generated_images 418 | ) 419 | else: 420 | return (image, all_generated_images, has_nsfw_concept) 421 | 422 | 423 | class StableDiffusionInpaintPipelineExplainer(StableDiffusionImg2ImgPipelineExplainer): 424 | # Actually the same as StableDiffusionImg2ImgPipelineExplainer 425 | pass -------------------------------------------------------------------------------- /src/diffusers_interpret/generated_images.py: -------------------------------------------------------------------------------- 1 | import base64 2 | import json 3 | import os 4 | import random 5 | from typing import List, Union 6 | 7 | import torch 8 | from IPython import display as d 9 | from PIL.Image import Image 10 | from diffusers import DiffusionPipeline 11 | 12 | import diffusers_interpret 13 | from diffusers_interpret.utils import transform_images_to_pil_format 14 | 15 | 16 | class GeneratedImages: 17 | def __init__( 18 | self, 19 | all_generated_images: List[torch.Tensor], 20 | pipe: DiffusionPipeline, 21 | remove_batch_dimension: bool = True, 22 | prepare_image_slider: bool = True 23 | ) -> None: 24 | 25 | assert all_generated_images, "Can't create GeneratedImages object with empty `all_generated_images`" 26 | 27 | # Convert images to PIL and draw box if requested 28 | self.images = [] 29 | for list_im in transform_images_to_pil_format(all_generated_images, pipe): 30 | batch_images = [] 31 | for im in list_im: 32 | batch_images.append(im) 33 | 34 | if remove_batch_dimension: 35 | self.images.extend(batch_images) 36 | else: 37 | self.images.append(batch_images) 38 | 39 | self.loading_iframe = None 40 | self.image_slider_iframe = None 41 | if prepare_image_slider: 42 | self.prepare_image_slider() 43 | 44 | def prepare_image_slider(self) -> None: 45 | """ 46 | Creates auxiliary HTML file to be displayed in self.__repr__ 47 | """ 48 | 49 | # Get data dir 50 | image_slider_dir = os.path.join(os.path.dirname(diffusers_interpret.__file__), "dataviz", "image-slider") 51 | 52 | # Convert images to base64 53 | json_payload = [] 54 | for i, image in enumerate(self.images): 55 | image.save(f"{image_slider_dir}/to_delete.png") 56 | with open(f"{image_slider_dir}/to_delete.png", "rb") as image_file: 57 | json_payload.append( 58 | {"image": "data:image/png;base64," + base64.b64encode(image_file.read()).decode('utf-8')} 59 | ) 60 | os.remove(f"{image_slider_dir}/to_delete.png") 61 | 62 | # get HTML file 63 | with open(os.path.join(image_slider_dir, "index.html")) as fp: 64 | html = fp.read() 65 | 66 | # get CSS file 67 | with 
open(os.path.join(image_slider_dir, "css/index.css")) as fp: 68 | css = fp.read() 69 | 70 | # get JS file 71 | with open(os.path.join(image_slider_dir, "js/index.js")) as fp: 72 | js = fp.read() 73 | 74 | # replace CSS text in CSS file 75 | html = html.replace("""""", 76 | f"""""") 77 | 78 | # replace JS text in HTML file 79 | html = html.replace("""""", "" 80 | f"""""") 81 | 82 | # get html with image slider JS call 83 | index = html.find("") 84 | add = """ 85 | 98 | """ % json.dumps(json_payload) 99 | html_with_image_slider = html[:index] + add + html[index:] 100 | 101 | # save files and load IFrame to be displayed in self.__repr__ 102 | with open(os.path.join(image_slider_dir, "loading.html"), 'w') as fp: 103 | fp.write(html) 104 | with open(os.path.join(image_slider_dir, "final.html"), 'w') as fp: 105 | fp.write(html_with_image_slider) 106 | 107 | self.loading_iframe = d.IFrame( 108 | os.path.relpath( 109 | os.path.join(os.path.dirname(diffusers_interpret.__file__), "dataviz", "image-slider", "loading.html"), 110 | '.' 111 | ), 112 | width="100%", height="400px" 113 | ) 114 | 115 | self.image_slider_iframe = d.IFrame( 116 | os.path.relpath( 117 | os.path.join(os.path.dirname(diffusers_interpret.__file__), "dataviz", "image-slider", "final.html"), 118 | '.' 119 | ), 120 | width="100%", height="400px" 121 | ) 122 | 123 | def __getitem__(self, item: int) -> Union[Image, List[Image]]: 124 | return self.images[item] 125 | 126 | def show(self, width: Union[str, int] = "100%", height: Union[str, int] = "400px") -> None: 127 | 128 | if len(self.images) == 0: 129 | raise Exception("`self.images` is an empty list, can't show any images") 130 | 131 | if isinstance(self.images[0], list): 132 | raise NotImplementedError("GeneratedImages.show visualization is not supported " 133 | "when `self.images` is a list of lists of images") 134 | 135 | if self.image_slider_iframe is None: 136 | self.prepare_image_slider() 137 | 138 | # display loading 139 | self.loading_iframe.width = width 140 | self.loading_iframe.height = height 141 | display = d.display(self.loading_iframe, display_id=random.randint(0, 9999999)) 142 | 143 | # display image slider 144 | self.image_slider_iframe.width = width 145 | self.image_slider_iframe.height = height 146 | display.update(self.image_slider_iframe) 147 | 148 | def gif(self, file_name: str = "diffusion_process.gif", duration: int = 400, show: bool = True) -> None: 149 | 150 | if len(self.images) == 0: 151 | raise Exception("`self.images` is an empty list, can't show any images") 152 | if isinstance(self.images[0], list): 153 | raise NotImplementedError("GeneratedImages.gif is not supported " 154 | "when `self.images` is a list of lists of images") 155 | ''' 156 | Generate and display a GIF from the denoising process 157 | ''' 158 | self[0].save(file_name, 159 | save_all = True, 160 | append_images = self[1:], 161 | optimize = False, 162 | duration = duration, 163 | loop = 0) 164 | if show: 165 | d.display(d.Image(file_name)) 166 | -------------------------------------------------------------------------------- /src/diffusers_interpret/pixel_attributions.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Union 2 | 3 | import numpy as np 4 | 5 | from diffusers_interpret.saliency_map import SaliencyMap 6 | 7 | 8 | class PixelAttributions(np.ndarray): 9 | def __new__(cls, pixel_attributions: np.ndarray, saliency_map: SaliencyMap) -> "PixelAttributions": 10 | # Construct new ndarray 11 | obj = 
np.asarray(pixel_attributions).view(cls) 12 | obj.pixel_attributions = pixel_attributions 13 | obj.normalized = 100 * (pixel_attributions / pixel_attributions.sum()) 14 | obj.saliency_map = saliency_map 15 | 16 | # Calculate normalized 17 | obj.normalized = 100 * (pixel_attributions / pixel_attributions.sum()) 18 | 19 | return obj 20 | 21 | def __getitem__(self, item: Union[str, int]) -> Any: 22 | return getattr(self, item) if isinstance(item, str) else self.pixel_attributions[item] 23 | 24 | def __setitem__(self, key: Union[str, int], value: Any) -> None: 25 | setattr(self, key, value) 26 | 27 | def __repr__(self) -> str: 28 | return self.pixel_attributions.__repr__() -------------------------------------------------------------------------------- /src/diffusers_interpret/saliency_map.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Union, List 2 | 3 | import cv2 4 | import cmapy 5 | import numpy as np 6 | import torch 7 | from PIL.Image import Image 8 | from matplotlib import pyplot as plt 9 | 10 | 11 | class SaliencyMap: 12 | def __init__( 13 | self, 14 | image: np.ndarray, 15 | pixel_attributions: np.ndarray, 16 | mask: Optional[Union[torch.FloatTensor, Image]] = None 17 | ): 18 | 19 | if mask is not None: 20 | if torch.is_tensor(mask): 21 | mask = mask.detach().cpu().numpy() 22 | else: # List[Image] 23 | mask = np.float32(mask) 24 | 25 | self.img = np.float32(image) 26 | self.pixel_attributions = pixel_attributions 27 | self.mask = mask 28 | 29 | def show(self, cmap='jet', image_weight=0.5, tight=True, apply_mask=True, **kwargs) -> None: 30 | 31 | saliency_map = cv2.applyColorMap(np.uint8(self.pixel_attributions), cmapy.cmap(cmap)) 32 | saliency_map = cv2.cvtColor(saliency_map, cv2.COLOR_BGR2RGB) 33 | saliency_map = np.float32(saliency_map) / 255.0 34 | 35 | img = self.img 36 | if self.mask is not None and apply_mask: 37 | img = np.array(self.img) * (1 - self.mask / 255) # np.array so that we copy `img` and don't change it 38 | saliency_map *= (1 - self.mask / 255) 39 | 40 | overlayed = (1 - image_weight) * saliency_map + image_weight * img 41 | overlayed = overlayed / np.max(overlayed) 42 | overlayed = np.uint8(255 * overlayed) 43 | 44 | # Visualize the image and the saliency map 45 | fig, ax = plt.subplots(1, 3, **kwargs) 46 | ax[0].imshow(img) 47 | ax[0].axis('off') 48 | ax[0].title.set_text('Image') 49 | 50 | ax[1].imshow(saliency_map) 51 | ax[1].axis('off') 52 | ax[1].title.set_text('Pixel attributions') 53 | 54 | ax[2].imshow(overlayed) 55 | ax[2].axis('off') 56 | ax[2].title.set_text('Image Overlayed') 57 | 58 | if tight: 59 | plt.tight_layout() 60 | plt.show() -------------------------------------------------------------------------------- /src/diffusers_interpret/token_attributions.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple, List, Any, Union 2 | import matplotlib.pyplot as plt 3 | 4 | 5 | class TokenAttributions(list): 6 | def __init__(self, token_attributions: List[Tuple[str, float]]) -> None: 7 | super().__init__(token_attributions) 8 | self.token_attributions = token_attributions 9 | 10 | # Calculate normalized 11 | total = sum([attr for _, attr in token_attributions]) 12 | self.normalized = [ 13 | (token, round(100 * attr / total, 3)) 14 | for token, attr in token_attributions 15 | ] 16 | 17 | def __getitem__(self, item: Union[str, int]) -> Any: 18 | return getattr(self, item) if isinstance(item, str) else 
self.token_attributions[item]
19 | 
20 |     def __setitem__(self, key: Union[str, int], value: Any) -> None:
21 |         setattr(self, key, value)
22 | 
23 |     def plot(self, plot_type: str = 'barh', normalize: bool = False, **plot_kwargs) -> None:
24 |         '''
25 |         Plot the token attributions for a comparative view.
26 |         Available plot types include bar chart, horizontal bar chart, and pie chart.
27 |         '''
28 | 
29 |         attrs = self.normalized if normalize else self.token_attributions
30 |         tokens, attributions = list(zip(*attrs))
31 |         prefix = 'normalized ' if normalize else ''
32 | 
33 |         # get arguments from plot_kwargs
34 |         xlabel = plot_kwargs.get('xlabel')
35 |         ylabel = plot_kwargs.get('ylabel')
36 |         title = plot_kwargs.get('title') or f'{prefix.title()}Token Attributions'
37 | 
38 |         if plot_type == 'bar':
39 |             # Bar chart
40 |             plt.bar(tokens, attributions)
41 |             plt.xlabel(xlabel or 'tokens')
42 |             plt.ylabel(ylabel or f'{prefix}attribution value')
43 | 
44 |         elif plot_type == 'barh':
45 |             # Horizontal bar chart
46 |             plt.barh(tokens, attributions)
47 |             plt.xlabel(xlabel or f'{prefix}attribution value')
48 |             plt.ylabel(ylabel or 'tokens')
49 |             plt.gca().invert_yaxis()  # to have the order of tokens from top to bottom
50 | 
51 |         elif plot_type == 'pie':
52 |             # Pie chart
53 |             plot_kwargs = {
54 |                 'startangle': 90, 'counterclock': False, 'labels': tokens,
55 |                 'autopct': '%1.1f%%', 'pctdistance': 0.8,
56 |                 **plot_kwargs
57 |             }
58 |             plt.pie(attributions, **plot_kwargs)
59 |             if xlabel:
60 |                 plt.xlabel(xlabel)
61 |             if ylabel:
62 |                 plt.ylabel(ylabel)
63 | 
64 |         else:
65 |             raise NotImplementedError(
66 |                 f"`plot_type={plot_type}` is not implemented. Choose one of: ['bar', 'barh', 'pie']"
67 |             )
68 | 
69 |         # set title and show
70 |         plt.title(title)
71 |         plt.show()
72 | 
--------------------------------------------------------------------------------
/src/diffusers_interpret/utils.py:
--------------------------------------------------------------------------------
1 | from typing import List
2 | 
3 | import torch
4 | from PIL.Image import Image
5 | from diffusers import DiffusionPipeline
6 | 
7 | 
8 | def clean_token_from_prefixes_and_suffixes(token: str) -> str:
9 |     """
10 |     Removes all the known token prefixes and suffixes
11 | 
12 |     Args:
13 |         token (`str`): string with token
14 | 
15 |     Returns:
16 |         `str`: clean token
17 |     """
18 | 
19 |     # removes T5 prefix
20 |     token = token.lstrip('▁')
21 | 
22 |     # removes BERT/GPT-2 prefix
23 |     token = token.lstrip('Ġ')
24 | 
25 |     # removes CLIP suffix
26 |     token = token.rstrip('</w>')
27 | 
28 |     return token
29 | 
30 | 
31 | def transform_images_to_pil_format(all_generated_images: List[torch.Tensor], pipe: DiffusionPipeline) -> List[List[Image]]:
32 |     pil_images = []
33 |     for im in all_generated_images:
34 |         if isinstance(im, torch.Tensor):
35 |             im = im.detach().cpu().numpy()
36 |         im = pipe.numpy_to_pil(im)
37 |         pil_images.append(im)
38 |     return pil_images
--------------------------------------------------------------------------------
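
Usage sketch (not part of the repository files above). The snippet below shows how the pieces in this dump are typically wired together: a `diffusers` pipeline is wrapped by an explainer, the explainer is called with a prompt, and the returned output exposes `token_attributions` (a `TokenAttributions`) and `all_images_during_generation` (a `GeneratedImages`). It assumes `StableDiffusionPipelineExplainer` is re-exported from the package root and that its constructor takes the pipeline directly; neither the constructor nor the `__call__` signature is shown in this excerpt, and the checkpoint id and prompt are illustrative only.

```python
import torch
from diffusers import StableDiffusionPipeline
from diffusers_interpret import StableDiffusionPipelineExplainer

# Load a Stable Diffusion pipeline (checkpoint id is illustrative).
pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4")
pipe = pipe.to("cuda" if torch.cuda.is_available() else "cpu")

# Wrap the pipeline in the explainer (assumed constructor signature).
explainer = StableDiffusionPipelineExplainer(pipe)

# Generate an image and compute attributions, restricting gradients to the last
# few diffusion steps to keep memory usage down (see the docstring of
# `n_last_diffusion_steps_to_consider_for_attributions` above).
output = explainer(
    "A corgi in front of the Eiffel Tower",
    n_last_diffusion_steps_to_consider_for_attributions=5,
    get_images_for_all_inference_steps=True,
)

output.image                                    # final PIL image (output_type defaults to "pil")
output.token_attributions.normalized            # per-token attributions, normalized to sum to 100
output.token_attributions.plot(normalize=True)  # horizontal bar chart (TokenAttributions.plot)
output.all_images_during_generation.gif()       # GIF of the denoising process (GeneratedImages.gif)
```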