├── .gitmodules ├── LICENSE ├── README.md ├── Security.md ├── app.py ├── demo.ipynb ├── processing_llavagemma.py ├── requirements.txt ├── utils_attn.py ├── utils_causal_discovery.py ├── utils_causal_discovery_fn.py ├── utils_gradio.py ├── utils_model.py └── utils_relevancy.py /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "causality-lab"] 2 | path = causality_lab 3 | url = https://github.com/IntelLabs/causality-lab.git 4 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. 
For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # LVLM-Interpret: An Interpretability Tool for Large Vision-Language Models 2 | [[Project Page](https://intellabs.github.io/multimodal_cognitive_ai/lvlm_interpret/)] [[Paper](https://arxiv.org/abs/2404.03118)] 3 | 4 | ## Setup 5 | 6 | - Update submodules 7 | 8 | `git submodule update --init --recursive` 9 | 10 | - Install dependencies 11 | 12 | `pip install -r requirements.txt` 13 | 14 | 15 | ## Usage 16 | 17 | Start the Gradio server: 18 | ``` 19 | python app.py --model_name_or_path Intel/llava-gemma-2b --load_8bit 20 | ``` 21 | or 22 | ``` 23 | python app.py --model_name_or_path llava-hf/llava-1.5-7b-hf --load_8bit 24 | ``` 25 | 26 | Options: 27 | ``` 28 | usage: app.py [-h] [--model_name_or_path MODEL_NAME_OR_PATH] [--host HOST] [--port PORT] [--share] [--embed] [--load_4bit] [--load_8bit] 29 | 30 | options: 31 | -h, --help show this help message and exit 32 | --model_name_or_path MODEL_NAME_OR_PATH 33 | Model name or path to load the model from 34 | --host HOST Host to run the server on 35 | --port PORT Port to run the server on 36 | --share Whether to share the server on Gradio's public server 37 | --embed Whether to run the server in an iframe 38 | --load_4bit Whether to load the model in 4bit 39 | --load_8bit Whether to load the model in 8bit 40 | 41 | ``` 42 | -------------------------------------------------------------------------------- /Security.md: -------------------------------------------------------------------------------- 1 | # Security Policy 2 | Intel is committed to rapidly addressing security vulnerabilities affecting our customers and providing clear guidance on the solution, impact, severity and mitigation. 3 | 4 | ## Reporting a Vulnerability 5 | Please report any security vulnerabilities in this project [utilizing the guidelines here](https://www.intel.com/content/www/us/en/security-center/vulnerability-handling-guidelines.html). 
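For quick experiments outside the Gradio app described in the README above, the same checkpoints can also be queried directly through the `transformers` API pinned in `requirements.txt`. The sketch below is not part of this repository: it assumes the standard `LlavaForConditionalGeneration`/`AutoProcessor` interface, and `example.jpg` is a placeholder path. Passing `output_attentions=True` to `generate` returns the per-token, per-layer attention tensors that the interpretability views in `utils_attn.py` are built from.

```python
# Hedged sketch (not part of this repo): query a LLaVA checkpoint directly.
import torch
from PIL import Image
from transformers import AutoProcessor, LlavaForConditionalGeneration

model_id = "llava-hf/llava-1.5-7b-hf"   # prompt template below matches this model
processor = AutoProcessor.from_pretrained(model_id)
model = LlavaForConditionalGeneration.from_pretrained(
    model_id, torch_dtype=torch.float16, device_map="auto"
)

image = Image.open("example.jpg")       # placeholder image path
prompt = "USER: <image>\nWhat is shown in this image? ASSISTANT:"
inputs = processor(text=prompt, images=image, return_tensors="pt").to(model.device)

out = model.generate(
    **inputs,
    max_new_tokens=64,
    output_attentions=True,             # tuple over generated tokens, each a tuple over layers
    return_dict_in_generate=True,
)
print(processor.decode(out.sequences[0], skip_special_tokens=True))
```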
6 | -------------------------------------------------------------------------------- /app.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import logging 3 | 4 | 5 | from utils_gradio import build_demo 6 | 7 | logging.basicConfig(level=logging.INFO) 8 | 9 | if __name__ == "__main__": 10 | parser = argparse.ArgumentParser() 11 | 12 | parser.add_argument("--model_name_or_path", type=str, default="Intel/llava-gemma-2b", 13 | help="Model name or path to load the model from") 14 | parser.add_argument("--host", type=str, default="0.0.0.0", 15 | help="Host to run the server on") 16 | parser.add_argument("--port", type=int, default=7860, 17 | help="Port to run the server on") 18 | parser.add_argument("--share", action="store_true", 19 | help="Whether to share the server on Gradio's public server") 20 | parser.add_argument("--embed", action="store_true", 21 | help="Whether to run the server in an iframe") 22 | parser.add_argument("--load_4bit", action="store_true", 23 | help="Whether to load the model in 4bit") 24 | parser.add_argument("--load_8bit", action="store_true", 25 | help="Whether to load the model in 8bit") 26 | parser.add_argument("--device_map", default="auto", 27 | help="Device map to use for model", choices=["auto", "cpu", "cuda", "hpu"]) 28 | args = parser.parse_args() 29 | 30 | assert not( args.load_4bit and args.load_8bit), "Cannot load both 4bit and 8bit models" 31 | 32 | demo = build_demo(args, embed_mode=False) 33 | # demo.queue(max_size=1) 34 | demo.launch( 35 | server_name=args.host, 36 | server_port=args.port, 37 | share=args.share, 38 | debug=True 39 | ) 40 | -------------------------------------------------------------------------------- /processing_llavagemma.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2023 The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """ 16 | Processor class for Llava. 17 | Modified to include support for Gemma tokenizer. 18 | """ 19 | 20 | 21 | from typing import List, Optional, Union 22 | 23 | from transformers.feature_extraction_utils import BatchFeature 24 | from transformers.image_utils import ImageInput 25 | from transformers.processing_utils import ProcessorMixin 26 | from transformers.tokenization_utils_base import PaddingStrategy, PreTokenizedInput, TextInput, TruncationStrategy 27 | from transformers.utils import TensorType 28 | 29 | 30 | class LlavaGemmaProcessor(ProcessorMixin): 31 | r""" 32 | Constructs a Llava processor which wraps a Llava image processor and a Llava tokenizer into a single processor. 33 | 34 | [`LlavaProcessor`] offers all the functionalities of [`CLIPImageProcessor`] and [`LlamaTokenizerFast`]. See the 35 | [`~LlavaProcessor.__call__`] and [`~LlavaProcessor.decode`] for more information. 36 | 37 | Args: 38 | image_processor ([`CLIPImageProcessor`], *optional*): 39 | The image processor is a required input. 
40 | tokenizer ([`LlamaTokenizerFast`], *optional*): 41 | The tokenizer is a required input. 42 | """ 43 | 44 | attributes = ["image_processor", "tokenizer"] 45 | image_processor_class = "CLIPImageProcessor" 46 | tokenizer_class = ("LlamaTokenizer", "LlamaTokenizerFast", 47 | "GemmaTokenizer", "GemmaTokenizerFast") 48 | 49 | def __init__(self, image_processor=None, tokenizer=None): 50 | super().__init__(image_processor, tokenizer) 51 | 52 | def __call__( 53 | self, 54 | text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None, 55 | images: ImageInput = None, 56 | padding: Union[bool, str, PaddingStrategy] = False, 57 | truncation: Union[bool, str, TruncationStrategy] = None, 58 | max_length=None, 59 | return_tensors: Optional[Union[str, TensorType]] = TensorType.PYTORCH, 60 | ) -> BatchFeature: 61 | """ 62 | Main method to prepare for the model one or several sequence(s) and image(s). This method forwards the `text` 63 | and `kwargs` arguments to LlamaTokenizerFast's [`~LlamaTokenizerFast.__call__`] if `text` is not `None` to encode 64 | the text. To prepare the image(s), this method forwards the `images` and `kwargs` arguments to 65 | CLIPImageProcessor's [`~CLIPImageProcessor.__call__`] if `images` is not `None`. Please refer to the docstring 66 | of the above two methods for more information. 67 | 68 | Args: 69 | text (`str`, `List[str]`, `List[List[str]]`): 70 | The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings 71 | (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set 72 | `is_split_into_words=True` (to lift the ambiguity with a batch of sequences). 73 | images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`): 74 | The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch 75 | tensor. In case of a NumPy array/PyTorch tensor, each image should be of shape (C, H, W), where C is the 76 | number of channels, H and W are image height and width. 77 | padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `False`): 78 | Select a strategy to pad the returned sequences (according to the model's padding side and padding 79 | index) among: 80 | - `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single 81 | sequence is provided). 82 | - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum 83 | acceptable input length for the model if that argument is not provided. 84 | - `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different 85 | lengths). 86 | max_length (`int`, *optional*): 87 | Maximum length of the returned list and optionally padding length (see above). 88 | truncation (`bool`, *optional*): 89 | Activates truncation to cut input sequences longer than `max_length` to `max_length`. 90 | return_tensors (`str` or [`~utils.TensorType`], *optional*): 91 | If set, will return tensors of a particular framework. Acceptable values are: 92 | 93 | - `'tf'`: Return TensorFlow `tf.constant` objects. 94 | - `'pt'`: Return PyTorch `torch.Tensor` objects. 95 | - `'np'`: Return NumPy `np.ndarray` objects. 96 | - `'jax'`: Return JAX `jnp.ndarray` objects. 97 | 98 | Returns: 99 | [`BatchFeature`]: A [`BatchFeature`] with the following fields: 100 | 101 | - **input_ids** -- List of token ids to be fed to a model.
Returned when `text` is not `None`. 102 | - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when 103 | `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not 104 | `None`). 105 | - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`. 106 | """ 107 | if images is not None: 108 | pixel_values = self.image_processor(images, return_tensors=return_tensors)["pixel_values"] 109 | else: 110 | pixel_values = None 111 | text_inputs = self.tokenizer( 112 | text, return_tensors=return_tensors, padding=padding, truncation=truncation, max_length=max_length 113 | ) 114 | 115 | return BatchFeature(data={**text_inputs, "pixel_values": pixel_values}) 116 | 117 | # Copied from transformers.models.clip.processing_clip.CLIPProcessor.batch_decode with CLIP->Llama 118 | def batch_decode(self, *args, **kwargs): 119 | """ 120 | This method forwards all its arguments to LlamaTokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please 121 | refer to the docstring of this method for more information. 122 | """ 123 | return self.tokenizer.batch_decode(*args, **kwargs) 124 | 125 | # Copied from transformers.models.clip.processing_clip.CLIPProcessor.decode with CLIP->Llama 126 | def decode(self, *args, **kwargs): 127 | """ 128 | This method forwards all its arguments to LlamaTokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to 129 | the docstring of this method for more information. 130 | """ 131 | return self.tokenizer.decode(*args, **kwargs) 132 | 133 | @property 134 | # Copied from transformers.models.clip.processing_clip.CLIPProcessor.model_input_names 135 | def model_input_names(self): 136 | tokenizer_input_names = self.tokenizer.model_input_names 137 | image_processor_input_names = self.image_processor.model_input_names 138 | return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names)) -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch 2 | torchvision 3 | transformers>=4.41.1 4 | gradio>=4.36.1 5 | spaces 6 | pillow 7 | accelerate 8 | matplotlib 9 | seaborn 10 | scipy 11 | bitsandbytes 12 | deepspeed 13 | opencv-python 14 | -------------------------------------------------------------------------------- /utils_attn.py: -------------------------------------------------------------------------------- 1 | import os, sys 2 | sys.path.append(os.getenv('LLAVA_HOME')) 3 | 4 | from collections import defaultdict 5 | import numpy as np 6 | import torch 7 | from torchvision.transforms.functional import to_pil_image 8 | import gradio as gr 9 | import PIL 10 | import matplotlib.pyplot as plt 11 | import matplotlib.gridspec as gridspec 12 | from matplotlib.colors import to_rgba 13 | 14 | import seaborn 15 | from PIL import Image, ImageDraw 16 | import pandas as pd 17 | from scipy import stats 18 | 19 | import logging 20 | 21 | logger = logging.getLogger(__name__) 22 | cmap = plt.get_cmap('jet') 23 | separators_list = ['.',',','?','!', ':', ';', '', '/', '!', '(', ')', '[', ']', '{', '}', '<', '>', '|', '\\', '-', '_', '+', '=', '*', '&', '^', '%', '$', '#', '@', '!', '~', '`', ' ', '\t', '\n', '\r', '\x0b', '\x0c'] 24 | 25 | def move_to_device(input, device='cpu'): 26 | 27 | if isinstance(input, torch.Tensor): 28 | return input.to(device).detach() 29 | elif isinstance(input, list): 30 | return [move_to_device(inp) for 
inp in input] 31 | elif isinstance(input, tuple): 32 | return tuple([move_to_device(inp) for inp in input]) 33 | elif isinstance(input, dict): 34 | return dict( ((k, move_to_device(v)) for k,v in input.items())) 35 | else: 36 | raise ValueError(f"Unknown data type for {input.type}") 37 | 38 | def convert_token2word(R_i_i, tokens, separators_list): 39 | current_count = 1 40 | current_rel_map = 0 41 | word_rel_maps = {} 42 | current_word = "" 43 | for token, rel in zip(tokens, R_i_i): 44 | if not token.startswith('▁') and token not in separators_list: 45 | current_word += token 46 | current_rel_map += rel 47 | current_count += 1 48 | else: 49 | # Otherwise, store the current word's relevancy map and start a new word 50 | word_rel_maps[current_word] = current_rel_map / current_count 51 | current_word = token 52 | current_rel_map = rel 53 | current_count = 1 54 | return list(word_rel_maps.keys()), torch.Tensor(list(word_rel_maps.values())) 55 | 56 | def draw_heatmap_on_image(mat, img_recover, normalize=True): 57 | if normalize: 58 | mat = (mat - mat.min()) / (mat.max() - mat.min()) 59 | mat = cmap(mat) #.cpu().numpy()) 60 | mat = Image.fromarray((mat[:, :, :3] * 255).astype(np.uint8)).resize((336,336), Image.BICUBIC) 61 | mat.putalpha(128) 62 | img_overlay_attn = img_recover.copy() 63 | img_overlay_attn.paste(mat, mask=mat) 64 | 65 | return img_overlay_attn 66 | 67 | def attn_update_slider(state): 68 | fn_attention = state.attention_key + '_attn.pt' 69 | attentions = torch.load(fn_attention, mmap=True) 70 | num_layers = len(attentions[0]) 71 | return state, gr.Slider(1, num_layers, value=num_layers, step=1, label="Layer") 72 | 73 | 74 | def handle_attentions_i2t(state, highlighted_text, layer_idx=32, token_idx=0): 75 | ''' 76 | Draw attention heatmaps and return as a list of PIL images 77 | ''' 78 | 79 | if not hasattr(state, 'attention_key'): 80 | return None, None, [], None 81 | layer_idx -= 1 82 | fn_attention = state.attention_key + '_attn.pt' 83 | recovered_image = state.recovered_image 84 | img_idx = state.image_idx 85 | 86 | if highlighted_text is not None: 87 | generated_text = state.output_ids_decoded 88 | token_idx_map = dict((t,i) for i,t in enumerate(generated_text)) 89 | token_idx_list = [] 90 | for item in highlighted_text: 91 | label = item['class_or_confidence'] 92 | if label is None: 93 | continue 94 | tokens = item['token'].split(' ') 95 | 96 | for tok in tokens: 97 | tok = tok.strip(' ') 98 | if tok in token_idx_map: 99 | token_idx_list.append(token_idx_map[tok]) 100 | else: 101 | logger.warning(f'{tok} not found in generated text') 102 | 103 | if not token_idx_list: 104 | logger.info(highlighted_text) 105 | logger.info(generated_text) 106 | gr.Error(f"Selected text not found in generated output") 107 | return None, None, [], None 108 | 109 | generated_text = [] 110 | for data in highlighted_text: 111 | generated_text.extend([(data['token'], None if data['class_or_confidence'] is None else "'"), (' ', None)]) 112 | else: 113 | token_idx_list = [0] 114 | 115 | generated_text = [] 116 | for text in state.output_ids_decoded: 117 | generated_text.extend([(text, None), (' ', None)]) 118 | 119 | 120 | if not os.path.exists(fn_attention): 121 | gr.Error('Attention file not found. 
Please re-run query.') 122 | else: 123 | attentions = torch.load(fn_attention) 124 | logger.info(f'Loaded attention from {fn_attention}') 125 | if len(attentions) == len(state.output_ids_decoded): 126 | gr.Error('Mismatch between lengths of attentions and output tokens') 127 | batch_size, num_heads, inp_seq_len, seq_len = attentions[0][0].shape 128 | cmap = plt.get_cmap('jet') 129 | 130 | img_attn_list = [] 131 | img_attn_mean = [] 132 | for head_idx in range(num_heads): 133 | img_attn = None 134 | for token_idx in token_idx_list: 135 | if token_idx >= len(attentions): 136 | logger.info(f'token index {token_idx} out of bounds') 137 | continue 138 | mh_attention = attentions[token_idx][layer_idx] 139 | batch_size, num_heads, inp_seq_len, seq_len = mh_attention.shape 140 | if inp_seq_len > 1: 141 | mh_attention = mh_attention[:,:,-1,:] 142 | mh_attention = mh_attention.squeeze() 143 | img_attn_token = mh_attention[head_idx, img_idx:img_idx+576].reshape(24,24).float().cpu().numpy() 144 | 145 | if img_attn is None: 146 | img_attn = img_attn_token 147 | else: 148 | img_attn += img_attn_token 149 | img_attn /= len(token_idx_list) 150 | img_overlay_attn = draw_heatmap_on_image(img_attn, recovered_image) 151 | 152 | img_attn_list.append((img_overlay_attn, f'Head_{head_idx}')) 153 | 154 | # Calculate mean attention per head 155 | # img_attn = mh_attention[head_idx, img_idx:img_idx+576].reshape(24,24).cpu().numpy() 156 | 157 | img_attn /= img_attn.max() 158 | img_attn_mean.append(img_attn.mean()) 159 | img_attn_list = [x for _, x in sorted(zip(img_attn_mean, img_attn_list), key=lambda pair: pair[0], reverse=True)] 160 | 161 | fig = plt.figure(figsize=(10, 3)) 162 | ax = seaborn.heatmap([img_attn_mean], 163 | linewidths=.3, square=True, cbar_kws={"orientation": "horizontal", "shrink":0.3} 164 | ) 165 | ax.set_xlabel('Head number') 166 | ax.set_title(f"Mean Attention between the image and the token {[state.output_ids_decoded[tok] for tok in token_idx_list]} for layer {layer_idx+1}") 167 | 168 | fig.tight_layout() 169 | 170 | return generated_text, recovered_image, img_attn_list, fig 171 | 172 | def handle_relevancy(state, type_selector,incude_text_relevancy=False): 173 | incude_text_relevancy = True 174 | logger.debug(f'incude_text_relevancy: {incude_text_relevancy}') 175 | 176 | if not hasattr(state, 'attention_key'): 177 | return [] 178 | 179 | fn_attention = state.attention_key + '_relevancy.pt' 180 | recovered_image = state.recovered_image 181 | img_idx = state.image_idx 182 | 183 | word_rel_maps = torch.load(fn_attention) 184 | if type_selector not in word_rel_maps: 185 | logger.warning(f'{type_selector} not in keys: {word_rel_maps.keys()}') 186 | return [] 187 | 188 | word_rel_map = word_rel_maps[type_selector] 189 | image_list = [] 190 | i = 0 191 | for rel_key, rel_map in word_rel_map.items(): 192 | i+=1 193 | if rel_key in separators_list: 194 | continue 195 | if (rel_map.shape[-1] != 577) and img_idx: 196 | if not incude_text_relevancy: 197 | rel_map = rel_map[-1,:][img_idx:img_idx+576].reshape(24,24).float().cpu().numpy() 198 | normalize_image_tokens = True 199 | if incude_text_relevancy: 200 | input_text_tokenized = state.input_text_tokenized 201 | input_text_tokenized_len = int(len(input_text_tokenized)) 202 | img_idx = int(img_idx) 203 | rel_maps_no_system = torch.cat((rel_map[-1,:][img_idx:img_idx+576], rel_map[-1,:][img_idx+576+3:576 + input_text_tokenized_len-1-5])) 204 | logger.debug(f'shape of rel_maps_no_system: {rel_maps_no_system.shape}') 205 | # make sure the sum of the 
relevancy scores is 1 206 | # if rel_maps_no_system.sum() != 0: 207 | # rel_maps_no_system /= rel_maps_no_system.sum() 208 | rel_maps_no_system = (rel_maps_no_system - rel_maps_no_system.min()) / (rel_maps_no_system.max() - rel_maps_no_system.min()) 209 | rel_map = rel_maps_no_system[:576].reshape(24,24).cpu().numpy() 210 | normalize_image_tokens = False 211 | else: 212 | rel_map = rel_map[0,1:].reshape(24,24).cpu().numpy() 213 | normalize_image_tokens = True 214 | rel_map = draw_heatmap_on_image(rel_map, recovered_image, normalize=normalize_image_tokens) 215 | # strip _ from all rel keys 216 | rel_key = rel_key.strip('▁').strip('_') 217 | image_list.append( (rel_map, rel_key)) 218 | 219 | return image_list 220 | 221 | def grid_size(len): 222 | n_columns = 3 if len < 16 else 4 223 | n_rows = int(np.ceil(len / n_columns)) 224 | return (n_rows, n_columns) 225 | 226 | def fig2img(fig): 227 | """Convert a Matplotlib figure to a PIL Image and return it""" 228 | import io 229 | buf = io.BytesIO() 230 | fig.savefig(buf) 231 | buf.seek(0) 232 | img = Image.open(buf) 233 | return img 234 | 235 | def handle_text_relevancy(state, type_selector): 236 | if type_selector != "llama": 237 | return [], [] 238 | else: 239 | tokens = state.output_ids_decoded 240 | fn_attention = state.attention_key + '_relevancy.pt' 241 | img_idx = state.image_idx 242 | input_text_tokenized = state.input_text_tokenized 243 | word_rel_maps = torch.load(fn_attention) 244 | 245 | input_text_tokenized_all = input_text_tokenized.copy() 246 | # loop over all output tokens 247 | word_rel_map = word_rel_maps["llama_token"] 248 | # grid_size_temp = grid_size(len(rel_scores)) 249 | all_figs = [] 250 | highlighted_tokens = [] 251 | for word, rel_map in word_rel_map.items(): 252 | if word in separators_list: 253 | continue 254 | fig, ax = plt.subplots(figsize=(5, 5)) 255 | # if the token is not a separator 256 | # if i < len(tokens) and tokens[i] not in separators_list: 257 | img_avg_rel = rel_map[-1,:][img_idx:img_idx+576].mean() 258 | img_max_rel = rel_map[-1,:][img_idx:img_idx+576].max() 259 | logger.debug(f'img_avg_rel for token {word}: {img_avg_rel}') 260 | # exclude the image tokens from the rel_scores[i] and replace all of them by a single value of the average relevancy for the image 261 | current_relevency = rel_map[-1,:][:img_idx].clone() 262 | # add the average relevancy for the image to the current_relevency tensor 263 | current_relevency = torch.cat((current_relevency, img_avg_rel.unsqueeze(0))) 264 | current_relevency = torch.cat((current_relevency, rel_map[-1,:][img_idx+576:576 + len(input_text_tokenized_all)-1])) 265 | current_relevency = current_relevency.cpu() 266 | logger.debug(f'shape of text relevancy map: {rel_map[-1,:].shape}') 267 | #rel_score_text = rel_scores[i][-1,:][:img_idx] 268 | assert len(current_relevency) == len(input_text_tokenized_all), f"The length of the relevancy score ({len(current_relevency)}) is not the same as the length of the input tokens ({len(input_text_tokenized_all)})\n{input_text_tokenized_all}" 269 | current_relevency = current_relevency[img_idx+3:-5] 270 | input_text_tokenized = input_text_tokenized_all[img_idx+3:-5] 271 | input_text_tokenized_word, current_relevency_word = convert_token2word(current_relevency, input_text_tokenized, separators_list) 272 | 273 | current_relevency_word_topk = current_relevency_word.topk(min(3,len(word_rel_map))) 274 | max_rel_scores = current_relevency_word_topk.values 275 | max_rel_scores = torch.cat((max_rel_scores, img_max_rel.unsqueeze(0).cpu())) 276 
| max_rel_scores_idx = current_relevency_word_topk.indices 277 | max_input_token = [input_text_tokenized_word[j].lstrip('▁').lstrip('_') for j in max_rel_scores_idx] 278 | 279 | # Image to text relevancy ratio 280 | # img_text_rel_ratio = max_rel_scores[-1] / current_relevency_word.mean() 281 | img_text_rel_value = stats.percentileofscore(max_rel_scores, img_max_rel.item(), kind='strict') / 100 282 | 283 | highlighted_tokens.extend( 284 | [ 285 | (word.strip('▁'), float(img_text_rel_value)), 286 | (" ", None) 287 | ] 288 | ) 289 | 290 | max_input_token.append("max_img") 291 | ax.bar(max_input_token, max_rel_scores) 292 | # ax.set_xticklabels(max_input_token, fontsize=12) 293 | 294 | # save the plot per each output token 295 | # make part of the title bold 296 | ax.set_title(f'Output Token: {word.strip("▁").strip("_")}', fontsize=15) 297 | # add labels for the x and y axis 298 | ax.set_xlabel('Input Tokens', fontsize=15) 299 | ax.set_ylabel('Relevancy Score', fontsize=15) 300 | 301 | fig.tight_layout() 302 | 303 | fig_pil = fig2img(fig) 304 | all_figs.append(fig_pil) 305 | 306 | return all_figs, highlighted_tokens 307 | 308 | def handle_image_click(image,box_grid, x, y): 309 | # Calculate which box was clicked 310 | box_width = image.size[1] // 24 311 | box_height = image.size[0] // 24 312 | 313 | box_x = x // box_width 314 | box_y = y // box_height 315 | 316 | box_grid[box_x][box_y] = not box_grid[box_x][box_y] 317 | 318 | # Add a transparent teal box to the image at the clicked location 319 | overlay = image.copy() 320 | draw = ImageDraw.Draw(overlay) 321 | indices = np.where(box_grid) 322 | for i, j in zip(*indices): 323 | draw.rectangle([(i * box_width, j * box_height), ((i + 1) * box_width, (j + 1) * box_height)], fill=(255, 100, 100, 128)) 324 | 325 | image = Image.blend(image, overlay, alpha=0.8) 326 | 327 | # Return the updated image 328 | return image, box_grid 329 | 330 | def handle_box_reset(input_image,box_grid): 331 | for i in range(24): 332 | for j in range(24): 333 | box_grid[i][j] = False 334 | try: 335 | to_return = input_image.copy() 336 | except: 337 | to_return = None 338 | return to_return, box_grid 339 | 340 | 341 | def boxes_click_handler(image, box_grid, event: gr.SelectData): 342 | if event is not None: 343 | x, y = event.index[0], event.index[1] 344 | 345 | image,box_grid = handle_image_click(image,box_grid, x, y) 346 | if x is not None and y is not None: 347 | return image,box_grid 348 | 349 | def plot_attention_analysis(state, attn_modality_select): 350 | fn_attention = state.attention_key + '_attn.pt' 351 | recovered_image = state.recovered_image 352 | img_idx = state.image_idx 353 | 354 | if os.path.exists(fn_attention): 355 | attentions = torch.load(fn_attention) 356 | logger.info(f'Loaded attention from {fn_attention}') 357 | if len(attentions) == len(state.output_ids_decoded): 358 | gr.Error('Mismatch between lengths of attentions and output tokens') 359 | 360 | num_tokens = len(attentions) 361 | num_layers = len(attentions[0]) 362 | last_mh_attention = attentions[0][-1] 363 | batch_size, num_heads, inp_seq_len, seq_len = attentions[0][0].shape 364 | generated_text = state.output_ids_decoded 365 | 366 | else: 367 | return state, None 368 | 369 | # Img2TextAns Attention 370 | heatmap_mean = defaultdict(dict) 371 | if attn_modality_select == "Image-to-Answer": 372 | for layer_idx in range(1,num_layers): 373 | for head_idx in range(num_heads): 374 | mh_attentions = [] 375 | mh_attentions = [attentions[i][layer_idx][:,:,-1,:].squeeze() for i in 
range(len(generated_text))] 376 | img_attn = torch.stack([mh_attention[head_idx, img_idx:img_idx+576].reshape(24,24) for mh_attention in mh_attentions]).float().cpu().numpy() 377 | # img_attn /= img_attn.max() 378 | heatmap_mean[layer_idx][head_idx] = img_attn.mean() # img_attn.mean((1,2)) 379 | elif attn_modality_select == "Question-to-Answer": 380 | fn_input_ids = state.attention_key + '_input_ids.pt' 381 | img_idx = state.image_idx 382 | input_ids = torch.load(fn_input_ids) 383 | len_question_only = input_ids.shape[1] - img_idx - 1 384 | for layer_idx in range(num_layers): 385 | for head_idx in range(num_heads): 386 | mh_attentions = [] 387 | mh_attentions = [attentions[i][layer_idx][:,:,-1,:].squeeze() for i in range(len(generated_text))] 388 | ques_attn = torch.stack([mh_attention[head_idx, img_idx+576:img_idx+576+len_question_only] for mh_attention in mh_attentions]).float().cpu().numpy() 389 | # ques_attn /= ques_attn.max() 390 | heatmap_mean[layer_idx][head_idx] = ques_attn.mean() 391 | heatmap_mean_df = pd.DataFrame(heatmap_mean) 392 | fig = plt.figure(figsize=(4, 4)) 393 | ax = seaborn.heatmap(heatmap_mean_df,square=True, cbar_kws={"orientation": "horizontal"}) 394 | ax.set_xlabel("Layers") 395 | ax.set_ylabel("Heads") 396 | ax.set_title(f"{attn_modality_select} Mean Attention") 397 | 398 | fig.tight_layout() 399 | return state, fig 400 | 401 | def plot_text_to_image_analysis(state, layer_idx, boxes, head_idx=1 ): 402 | 403 | fn_attention = state.attention_key + '_attn.pt' 404 | img_recover = state.recovered_image 405 | img_idx = state.image_idx 406 | generated_text = state.output_ids_decoded 407 | 408 | # Sliders start at 1 409 | head_idx -= 1 410 | layer_idx -= 1 411 | img_patches = [(j, i) for i, row in enumerate(boxes) for j, clicked in enumerate(row) if clicked] 412 | if len(img_patches) == 0: 413 | img_patches = [(5,5)] 414 | if os.path.exists(fn_attention): 415 | attentions = torch.load(fn_attention) 416 | logger.info(f'Loaded attention from {fn_attention}') 417 | if len(attentions) == len(state.output_ids_decoded): 418 | gr.Error('Mismatch between lengths of attentions and output tokens') 419 | 420 | # num_tokens = len(attentions) 421 | # num_layers = len(attentions[0]) 422 | # last_mh_attention = attentions[0][-1] 423 | batch_size, num_heads, inp_seq_len, seq_len = attentions[0][0].shape 424 | generated_text = state.output_ids_decoded 425 | 426 | else: 427 | return state, None 428 | mh_attentions = [] 429 | for head_id in range(num_heads): 430 | att_per_head = [] 431 | for i, out_att in enumerate(attentions): 432 | mh_attention = out_att[layer_idx] 433 | mh_attention = mh_attention[:, :, -1, :].unsqueeze(2) 434 | att_img = mh_attention.squeeze()[head_id, img_idx:img_idx+576].reshape(24,24) 435 | att_per_head.append(att_img) 436 | att_per_head = torch.stack(att_per_head) 437 | mh_attentions.append(att_per_head) 438 | mh_attentions = torch.stack(mh_attentions) 439 | 440 | img_mask = np.zeros((24, 24)) 441 | for img_patch in img_patches: 442 | img_mask[img_patch[0], img_patch[1]] = 1 443 | img_mask = cmap(img_mask) 444 | img_mask = Image.fromarray((img_mask[:, :, :3] * 255 ).astype(np.uint8)).resize((336,336), Image.BICUBIC) 445 | img_mask.putalpha(208) 446 | img_patch_recovered = img_recover.copy() 447 | img_patch_recovered.paste(img_mask, mask=img_mask) 448 | img_patch_recovered 449 | 450 | words = generated_text 451 | float_values = torch.mean(torch.stack([mh_attentions[head_idx, :, x, y] for x, y in img_patches]), dim=0).float().cpu() 452 | normalized_values = 
(float_values - float_values.min()) / (float_values.max() - float_values.min()) 453 | 454 | fig = plt.figure(figsize=(10, 4)) 455 | gs = gridspec.GridSpec(1, 2, width_ratios=[1, 3]) # 2 columns, first column for the image, second column for the words 456 | ax_img = plt.subplot(gs[0]) 457 | ax_img.imshow(img_patch_recovered) 458 | ax_img.axis('off') 459 | ax_words = plt.subplot(gs[1]) 460 | x_position = 0.0 461 | 462 | for word, value in zip(words, normalized_values): 463 | color = plt.get_cmap("coolwarm")(value) 464 | color = to_rgba(color, alpha=0.6) 465 | ax_words.text(x_position, 0.5, word, color=color, fontsize=14, ha='left', va='center') 466 | x_position += 0.10 467 | 468 | cax = fig.add_axes([0.1, 0.15, 0.8, 0.03]) 469 | norm = plt.Normalize(min(normalized_values), max(normalized_values)) 470 | sm = plt.cm.ScalarMappable(cmap="coolwarm", norm=norm) 471 | sm.set_array([]) 472 | cb = fig.colorbar(sm, cax=cax, orientation='horizontal') 473 | cb.set_label('Color Legend', labelpad=10, loc="center") 474 | 475 | ax_words.axis('off') 476 | plt.suptitle(f"Attention to the selected image patch(es) of head #{head_idx+1} and layer #{layer_idx+1}", fontsize=16, y=0.8, x=0.6) 477 | 478 | # attn_heatmap = plt.figure(figsize=(10, 3)) 479 | # attn_image_patch = mh_attentions[:, :, img_patch[0], img_patch[1]].cpu().mean(-1) 480 | attn_image_patch = torch.stack([mh_attentions[:, :, x, y] for x, y in img_patches]).mean(0).float().cpu().mean(-1) 481 | logger.debug(torch.stack([mh_attentions[:, :, x, y] for x, y in img_patches]).shape) 482 | logger.debug(torch.stack([mh_attentions[:, :, x, y] for x, y in img_patches]).mean(0).shape) 483 | logger.debug(attn_image_patch.shape) 484 | 485 | fig2 = plt.figure(figsize=(10, 3)) 486 | ax2 = seaborn.heatmap([attn_image_patch], 487 | linewidths=.3, square=True, cbar_kws={"orientation": "horizontal", "shrink":0.3} 488 | ) 489 | ax2.set_xlabel('Head number') 490 | ax2.set_title(f"Mean Head Attention between the image patches selected and the answer for layer {layer_idx+1}") 491 | fig2.tight_layout() 492 | return state, fig, fig2 493 | 494 | 495 | def reset_tokens(state): 496 | generated_text = [] 497 | for text in state.output_ids_decoded: 498 | generated_text.extend([(text, None), (' ', None)]) 499 | 500 | return generated_text 501 | 502 | -------------------------------------------------------------------------------- /utils_causal_discovery.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append('causality_lab') 3 | 4 | import logging 5 | import os 6 | import numpy as np 7 | import gradio as gr 8 | import torch 9 | from PIL import ImageDraw, Image 10 | from matplotlib import pyplot as plt 11 | from plot_utils import draw_graph, draw_pds_tree 12 | from causal_discovery_utils.cond_indep_tests import CondIndepParCorr 13 | 14 | logger = logging.getLogger(__name__) 15 | 16 | from utils_causal_discovery_fn import ( 17 | get_relevant_image_tokens, 18 | tokens_analysis, 19 | create_explanation, 20 | copy_sub_graph, 21 | show_tokens_on_image, 22 | calculate_explanation_pvals, 23 | get_relevant_prompt_tokens, 24 | get_relevant_text_tokens, 25 | crop_token, 26 | get_expla_set_per_rad 27 | ) 28 | 29 | 30 | def create_im_tokens_marks(orig_img, tokens_to_mark, weights=None, txt=None, txt_pos=None): 31 | im_1 = orig_img.copy() 32 | if weights is not None: 33 | im_heat = show_tokens_on_image(tokens_to_mark, im_1, weights) 34 | else: 35 | im_heat = show_tokens_on_image(tokens_to_mark, im_1) 36 | im_heat_edit = 
ImageDraw.Draw(im_heat) 37 | if isinstance(txt, str): 38 | if txt_pos is None: 39 | txt_pos = (10, 10) 40 | im_heat_edit.text(txt_pos, txt, fill=(255, 255, 255)) 41 | im_heat = im_heat_edit._image 42 | return im_heat 43 | 44 | 45 | def causality_update_dropdown(state): 46 | generated_text = state.output_ids_decoded 47 | choices = [ f'{i}_{tok}' for i,tok in enumerate(generated_text)] 48 | return state, gr.Dropdown(value=choices[0], interactive=True, scale=2, choices=choices) 49 | 50 | 51 | def handle_causal_head(state, explainers_data, head_selection, class_token_txt): 52 | recovered_image = state.recovered_image 53 | first_im_token_idx = state.image_idx 54 | 55 | token_to_explain = explainers_data[0] 56 | head_id = head_selection 57 | explainer = explainers_data[1][head_id] 58 | if explainer is None: 59 | return [], None 60 | 61 | expla_set_per_rad = get_expla_set_per_rad(explainer.results[token_to_explain]['pds_tree']) 62 | max_depth = max(expla_set_per_rad.keys()) 63 | im_heat_list = [] 64 | im_tok_rel_idx = [] 65 | for rad in range(1,max_depth+1): 66 | im_tok_rel_idx += [v-first_im_token_idx 67 | for v in expla_set_per_rad[rad] if v >= first_im_token_idx and v < (first_im_token_idx+576)] 68 | im_heat_list.append( 69 | create_im_tokens_marks(recovered_image, im_tok_rel_idx, txt='search radius: {rad}'.format(rad=rad)) 70 | ) 71 | 72 | 73 | # im_graph_list = [] 74 | # for r in range(1, 5): 75 | # expla_list = explainer.explain(token_to_explain, max_range=r)[0][0] 76 | # nodes_set = set(expla_list) 77 | # nodes_set.add(token_to_explain) 78 | # subgraph = copy_sub_graph(explainer.graph, nodes_set) 79 | # fig = draw_graph(subgraph, show=False) 80 | # fig.canvas.draw() 81 | # im_graph = Image.frombytes('RGB', fig.canvas.get_width_height(),fig.canvas.tostring_rgb()) 82 | # plt.close() 83 | # im_graph_list.append(im_graph) 84 | 85 | expla_list = explainers_data[2][head_id] 86 | node_labels = dict() 87 | for tok in expla_list: 88 | im_idx = tok - first_im_token_idx 89 | if im_idx < 0 or im_idx >= 576: # if token is not image 90 | continue 91 | im_tok = crop_token(recovered_image, im_idx, pad=2) 92 | node_labels[tok] = im_tok.resize((45, 45)) 93 | 94 | node_labels[token_to_explain] = class_token_txt.split('_')[1] 95 | 96 | nodes_set = set(expla_list) 97 | nodes_set.add(token_to_explain) 98 | fig = draw_pds_tree(explainer.results[token_to_explain]['pds_tree'], explainer.graph, node_labels=node_labels, 99 | node_size_factor=1.4) 100 | if fig is None: 101 | fig = plt.figure() 102 | fig.canvas.draw() 103 | im_graph = Image.frombytes('RGB', fig.canvas.get_width_height(),fig.canvas.tostring_rgb()) 104 | plt.close() 105 | 106 | return im_heat_list, im_graph 107 | 108 | 109 | def handle_causality(state, state_causal_explainers, token_to_explain, alpha_ext=None, att_th_ext=None): 110 | # ---***------***------***------***------***------***------***------***------***------***------***------***--- 111 | # ---***--- Results' containers ---***--- 112 | gallery_image_list = [] 113 | gallery_graph_list = [] 114 | gallery_bar_graphs = [] 115 | 116 | # ---***------***------***------***------***------***------***------***------***------***------***------***--- 117 | # ---***--- Generic app handling ---***--- 118 | if not hasattr(state, 'attention_key'): 119 | return [] 120 | 121 | # ---***------***------***------***------***------***------***------***------***------***------***------***--- 122 | # ---***--- Load attention matrix ---***--- 123 | fn_attention = state.attention_key + '_attn.pt' 124 | recovered_image = 
state.recovered_image 125 | first_im_token_idx = state.image_idx 126 | generated_text = state.output_ids_decoded 127 | 128 | if not os.path.exists(fn_attention): 129 | gr.Error('Attention file not found. Please re-run query.') 130 | else: 131 | attentions = torch.load(fn_attention) # attentions is a tuple whose length equals the number of generated tokens. 132 | 133 | last_mh_attention = attentions[-1][-1] # last generated token, last layer 134 | num_heads, _, attention_len = last_mh_attention[-1].shape 135 | full_attention = np.zeros((num_heads, attention_len, attention_len)) 136 | 137 | last_mh_attention = attentions[0][-1] # last layer's attention matrices before output generation 138 | attention_vals = last_mh_attention[0].detach().cpu().numpy() # 0 is the index for the sample in the batch. 139 | d1 = attention_vals.shape[-1] 140 | full_attention[:, :d1, :d1] = attention_vals 141 | 142 | # create one full attention matrix that includes attention to generated tokens 143 | for gen_idx in range(1, len(generated_text)): 144 | last_mh_attention = attentions[gen_idx][-1] 145 | att_np = last_mh_attention[0].detach().cpu().numpy() 146 | full_attention[:, d1, :att_np.shape[-1]] = att_np[:,0,:] 147 | d1 += 1 148 | 149 | # Sizes: 150 | # Number of heads: {num_heads}, attention size: {attention_len}x{attention_len} 151 | 152 | # ---***------***------***------***------***------***------***------***------***------***------***------***--- 153 | # ---***--- Hyper-parameters for causal discovery ---***--- 154 | threshold = 1e-5 # alpha; threshold for p-value in conditional independence testing 155 | degrees_of_freedom = 128 156 | default_search_range = 3 157 | max_num_image_tokens = 50 # number of image-tokens to consider as 'observed'. Used for calculating head importance 158 | att_th = 0.01 # threshold for attention values.
Below this value, the token is considered 'not-attented' 159 | search_range = default_search_range # causal-explanation seach-distance in the causal graph 160 | if alpha_ext is not None: 161 | threshold = alpha_ext 162 | if att_th_ext is not None: 163 | att_th = att_th_ext 164 | 165 | heads_to_analyse = list(range(num_heads)) 166 | 167 | token_to_explain = attention_len - len(generated_text) + int(token_to_explain.split('_')[0]) 168 | logger.info(f'Using token index {token_to_explain} for explaining') 169 | 170 | # ---***------***------***------***------***------***------***------***------***------***------***------***--- 171 | # ---***--- Learn causal Structure ---***--- 172 | 173 | time_struct = [] # list of runtime results for learning the structure for different heads 174 | time_reason = [] # list of runtime results for recovering an explanation for different heads 175 | 176 | expla_list_all = [None] * num_heads 177 | explainer_all = [None] * num_heads 178 | timing_all = [None] * num_heads 179 | head_importance = [0] * num_heads 180 | 181 | # state_causal_explainers[0] = token_to_explain 182 | # state_causal_explainers[1] = [] 183 | state_causal_explainers = [token_to_explain] 184 | # state_causal_explainers.append(dict()) 185 | 186 | total_weights = [0 for _ in range(576)] # weights for image tokens (24 x 24 tokens) 187 | 188 | for head_id in heads_to_analyse: # ToDo: Run in parallel (threading/multiprocessing; a worker for head) 189 | head_attention = full_attention[head_id] # alias for readability 190 | 191 | # ---***------***--- Text causal graph ---***------***--- 192 | text_expla, text_expl, timing = tokens_analysis(head_attention, list(range(first_im_token_idx+576, token_to_explain+1)), 193 | token_of_interest=token_to_explain, 194 | number_of_samples=degrees_of_freedom, p_val_thrshold=threshold, max_search_range=search_range, 195 | verbose=False) 196 | txt_node_labels = dict() 197 | for v in text_expla: 198 | # print(f'attention len: {attention_len} - Generated len: {len(generated_text)} + node: {v}, idx={attention_len - len(generated_text) + v}') 199 | idx = v - (attention_len - len(generated_text)) 200 | if idx >= 0: 201 | txt_node_labels[v] = generated_text[idx] 202 | # End: *------***--- Text causal graph ---***------***--- 203 | 204 | 205 | w = head_attention[token_to_explain, :] 206 | w_img = w[first_im_token_idx:(first_im_token_idx+576)] 207 | # im_entropy = -np.nansum(w_img * np.log(w_img)) 208 | # total_entropy = -np.nansum(w * np.log(w)) 209 | 210 | # print(f'{head_id}: total_entropy: {total_entropy}, image entropy: {im_entropy}, entropy diff: im - total: {im_entropy - total_entropy}') 211 | num_high_att = max(2, sum(w > att_th)) 212 | 213 | num_image_tokens = min(num_high_att, max_num_image_tokens) # number of image tokens to select for analysis 214 | 215 | relevant_image_idx = get_relevant_image_tokens(class_token=token_to_explain, 216 | attention_matrix=head_attention, 217 | first_token=first_im_token_idx, 218 | num_top_k_tokens=num_image_tokens) 219 | 220 | relevant_gen_idx = get_relevant_text_tokens(class_token=token_to_explain, attention_matrix=head_attention, att_th=att_th, first_image_token=first_im_token_idx) 221 | relevant_tokens = relevant_image_idx + relevant_gen_idx + [token_to_explain] 222 | 223 | # print(f'Self: {head_attention[token_to_explain, token_to_explain]}') 224 | # att_th = head_attention[token_to_explain, token_to_explain] 225 | # att_th = np.median(w[first_im_token_idx+576:]) 226 | # print(f'Attentnion threshold: {att_th}') 227 | # 
relevant_tokens = set(np.where(w >= att_th)[0]) 228 | # relevant_tokens.add(token_to_explain) 229 | # relevant_tokens = list(relevant_tokens) 230 | # relevant_tokens = [v for v in relevant_tokens if v >= first_im_token_idx] 231 | # print('relevant tokens', relevant_tokens) 232 | 233 | expla_list, explainer, timing = tokens_analysis(head_attention, relevant_tokens, 234 | token_of_interest=token_to_explain, 235 | number_of_samples=degrees_of_freedom, p_val_thrshold=threshold, max_search_range=search_range, 236 | verbose=False) 237 | 238 | expla_list_all[head_id] = expla_list 239 | explainer_all[head_id] = explainer 240 | timing_all[head_id] = timing 241 | 242 | # calculate Head Importance 243 | im_expla_tokens_list = [v for v in expla_list if (v >= first_im_token_idx) and (v < first_im_token_idx + 576)] # only image explanation 244 | ci_test = explainer.ci_test 245 | prev_num_records = ci_test.num_records 246 | ci_test.num_records = len(im_expla_tokens_list) 247 | weights_list = [] 248 | for im_expla_tok in im_expla_tokens_list: 249 | cond_set = tuple(set(im_expla_tokens_list) - {im_expla_tok}) 250 | p_val = min(ci_test.calc_statistic(im_expla_tok, token_to_explain, cond_set), 1) # avoid inf 251 | weights_list.append(1-p_val) 252 | ci_test.num_records = prev_num_records 253 | 254 | # print(f'*** Head: {head_id} -- weights: {weights_list}') 255 | # if len(weights_list) == 0: 256 | # head_importance[head_id] = 0 257 | # else: 258 | # head_importance[head_id] = np.mean(weights_list) 259 | head_importance[head_id] = max(w_img) / max(w[first_im_token_idx+576:]) 260 | 261 | for im_expla_tok, im_expla_weight in zip(im_expla_tokens_list, weights_list): 262 | total_weights[im_expla_tok-first_im_token_idx] += im_expla_weight 263 | 264 | # if len(im_expla_tokens_list) > 0: 265 | # head_importance[head_id] = np.median(w[im_expla_tokens_list]) 266 | # else: 267 | # head_importance[head_id] = 0 268 | 269 | # p_vals_dict = calculate_explanation_pvals(explainer, token_to_explain, search_range) 270 | # p_weights_im_tokens = [ 271 | # (1-p_vals_dict[v])*w[v] for v in expla_list if (v >= first_im_token_idx) and (v < first_im_token_idx + 576) 272 | # ] 273 | # if len(p_weights_im_tokens) == 0: 274 | # head_importance[head_id] = 0 275 | # else: 276 | # head_importance[head_id] = np.median(p_weights_im_tokens) 277 | 278 | # if len(expla_list) > 0: 279 | # # head_importance[head_id] = np.median(w[expla_list]) 280 | # head_importance[head_id] = np.median(sorted(w)[-max_num_image_tokens:]) 281 | # else: 282 | # head_importance[head_id] = 0 283 | 284 | txt = '{head}: {importance:.2f} / 100'.format(head=head_id, importance=head_importance[head_id]*100) 285 | logger.info(f'Head: {head_id}: importance: {txt}') 286 | 287 | 288 | time_struct.append(timing['structure']) 289 | time_reason.append(timing['reasoning']) 290 | im_expla_rel_idx = [v-first_im_token_idx for v in im_expla_tokens_list] # only image 291 | 292 | # print(f'head {head_id}, importance: {head_importance[head_id]:.3f}, above {att_th}: {num_high_att}') 293 | 294 | # plot results 295 | logger.info('Max: *******', max(total_weights)) 296 | if max(total_weights) > 0: 297 | norm_total_weights = [v/max(total_weights) for v in total_weights] 298 | else: 299 | norm_total_weights = total_weights 300 | im_t = recovered_image.copy() 301 | im_heat_total = show_tokens_on_image(list(range(576)), im_t, norm_total_weights) 302 | im_heat_edit_t = ImageDraw.Draw(im_heat_total) 303 | im_heat_edit_t.text((10, 10), txt, fill=(255, 255, 255)) 304 | im_heat_total = 
im_heat_edit_t._image 305 | 306 | fig = plt.figure() 307 | ax = fig.add_subplot(1, 1, 1) 308 | ax.bar(range(num_heads), head_importance) 309 | ax.grid(True) 310 | xmin, xmax, ymin, ymax = ax.axis() 311 | ax.axis([1, 32, -ymax*0.01, ymax]) 312 | # ax.set_position([0, 0, 1, 1]) 313 | fig.canvas.draw() 314 | im_head_importance = Image.frombytes('RGB', fig.canvas.get_width_height(),fig.canvas.tostring_rgb()) 315 | plt.close() 316 | 317 | # attentnion values 318 | fig = plt.figure() 319 | ax = fig.add_subplot(1, 1, 1) 320 | h = [max(w[first_im_token_idx:576+first_im_token_idx])] + list(w[first_im_token_idx+576:]) 321 | ax.bar(range(len(h)), h) 322 | ax.grid(True) 323 | fig.canvas.draw() 324 | im_att_bar = Image.frombytes('RGB', fig.canvas.get_width_height(),fig.canvas.tostring_rgb()) 325 | plt.close() 326 | 327 | im_heat = create_im_tokens_marks(recovered_image, im_expla_rel_idx, txt=txt) 328 | # im_1 = recovered_image.copy() 329 | # # im_heat = show_tokens_on_image(im_expla_rel_idx, im_1, weights_list) 330 | # im_heat = show_tokens_on_image(im_expla_rel_idx, im_1) 331 | # im_heat_edit = ImageDraw.Draw(im_heat) 332 | # im_heat_edit.text((10, 10), txt, fill=(255, 255, 255)) 333 | # im_heat = im_heat_edit._image 334 | 335 | fig = plt.figure() 336 | ax = fig.add_subplot(1, 1, 1) 337 | ax.plot(head_importance, '.-') 338 | ax.grid(True) 339 | fig.canvas.draw() 340 | im_pl = Image.frombytes('RGB', fig.canvas.get_width_height(),fig.canvas.tostring_rgb()) 341 | plt.close() 342 | 343 | nodes_set = set(expla_list) 344 | nodes_set.add(token_to_explain) 345 | subgraph = copy_sub_graph(explainer.graph, nodes_set) 346 | fig = draw_graph(subgraph, show=False) 347 | fig.canvas.draw() 348 | # im_graph = Image.frombytes('RGB', fig.canvas.get_width_height(),fig.canvas.tostring_rgb()) 349 | plt.close() 350 | 351 | # nodes_set = set(text_expla) 352 | # nodes_set.add(token_to_explain) 353 | # subgraph = copy_sub_graph(text_expl.graph, nodes_set) 354 | # fig = draw_graph(subgraph, show=False, node_labels=node_labels) 355 | # fig.canvas.draw() 356 | im_graph = Image.frombytes('RGB', fig.canvas.get_width_height(),fig.canvas.tostring_rgb()) 357 | # plt.close() 358 | 359 | node_labels = dict() 360 | for tok in expla_list: 361 | if tok in txt_node_labels: # if token is text 362 | node_labels[tok] = txt_node_labels[tok] 363 | continue 364 | im_idx = tok - first_im_token_idx 365 | if im_idx < 0 or im_idx >= 576: # if token is not image 366 | continue 367 | im_tok = crop_token(recovered_image, im_idx, pad=2) 368 | node_labels[tok] = im_tok.resize((45, 45)) 369 | 370 | idx = token_to_explain - (attention_len - len(generated_text)) 371 | node_labels[token_to_explain] = generated_text[idx] 372 | 373 | nodes_set = set(expla_list) 374 | nodes_set.add(token_to_explain) 375 | fig = draw_pds_tree(explainer.results[token_to_explain]['pds_tree'], explainer.graph, node_labels=node_labels, 376 | node_size_factor=1.4) 377 | if fig is None: 378 | fig = plt.figure() 379 | fig.canvas.draw() 380 | im_graph = Image.frombytes('RGB', fig.canvas.get_width_height(),fig.canvas.tostring_rgb()) 381 | plt.close() 382 | 383 | gallery_image_list.append(im_heat) 384 | gallery_graph_list.append(im_graph) 385 | gallery_bar_graphs.append(im_att_bar) 386 | # gallery_image_list.append(im_pl) 387 | 388 | state_causal_explainers.append(explainer_all) # idx 1 389 | state_causal_explainers.append(expla_list_all) # idx 2 390 | return gallery_image_list + gallery_graph_list + gallery_bar_graphs, state_causal_explainers #im_heat_total #im_head_importance 391 | 
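# A minimal usage sketch (illustrative only) of the same per-head analysis that handle_causality
# runs above, here driven on a synthetic, row-normalized attention matrix. The sizes, thresholds
# and the random matrix are assumptions for demonstration; the real caller passes one head of
# `full_attention` taken from the model, and both helpers require the causality-lab submodule
# (CLEANN, LearnStructOrderedICD) to be installed.
if __name__ == '__main__':
    import numpy as np
    from utils_causal_discovery_fn import tokens_analysis, get_relevant_image_tokens

    rng = np.random.default_rng(0)
    seq_len = 700                     # illustrative: prompt + 576 image tokens + generated tokens
    first_im_token_idx = 5            # illustrative index of the first image token
    token_to_explain = seq_len - 1    # explain the last generated token
    attn = rng.random((seq_len, seq_len))
    attn /= attn.sum(axis=1, keepdims=True)   # row-normalize, mimicking softmaxed attention

    # pick the image tokens the explained token attends to most, then learn a causal structure
    # over them and deduce an explanation set, as done per head in the loop above
    relevant_image_idx = get_relevant_image_tokens(class_token=token_to_explain,
                                                   attention_matrix=attn,
                                                   first_token=first_im_token_idx,
                                                   num_top_k_tokens=10)
    expla_list, explainer, timing = tokens_analysis(attn, relevant_image_idx + [token_to_explain],
                                                    token_of_interest=token_to_explain,
                                                    number_of_samples=128, p_val_thrshold=1e-5,
                                                    max_search_range=3, verbose=True)
    print('explanation tokens:', expla_list, 'runtimes:', timing)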
-------------------------------------------------------------------------------- /utils_causal_discovery_fn.py: -------------------------------------------------------------------------------- 1 | import time 2 | import numpy as np 3 | from itertools import combinations 4 | 5 | from PIL import Image, ImageEnhance 6 | import torch 7 | 8 | import matplotlib.pyplot as plt 9 | 10 | try: 11 | from causal_discovery_algs import LearnStructOrderedICD 12 | except ImportError: 13 | print("Warning: causal discovery pending update.") 14 | LearnStructOrderedICD = None 15 | 16 | from graphical_models import PAG 17 | from causal_reasoning import CLEANN 18 | 19 | 20 | def get_expla_set_per_rad(pds_tree): 21 | root_node = pds_tree.origin 22 | expla_lists_per_rad = {0: root_node} 23 | children = pds_tree.children 24 | rad = 1 25 | while len(children) > 0: 26 | expla_lists_per_rad[rad] = set() # initialize an explanation set at range rad 27 | children_of_children = [] 28 | for child in children: 29 | expla_lists_per_rad[rad].add(child.origin) 30 | children_of_children += child.children 31 | rad += 1 32 | children = children_of_children 33 | return expla_lists_per_rad 34 | 35 | 36 | 37 | def get_relevant_image_tokens(class_token, attention_matrix, first_token, num_top_k_tokens) -> list: 38 | """ 39 | Find the indexes of the image tokens for which the class tokens most attend (highest attention) 40 | 41 | :param class_token: 42 | :param attention_matrix: 43 | :param first_token: 44 | :param num_top_k_tokens: 45 | """ 46 | weights = attention_matrix[class_token, first_token:(first_token+576)] 47 | sorting_indexes = np.argsort(weights)[::-1] # descending sorting indexes 48 | all_indexes = list(range(576)) 49 | top_k_idx = [all_indexes[sorting_indexes[i]] for i in range(num_top_k_tokens)] 50 | top_k_idx = [t + first_token for t in top_k_idx] # add index offset 51 | return top_k_idx 52 | 53 | 54 | def get_relevant_prompt_tokens(class_token, attention_matrix, att_th, first_image_token) -> list: 55 | weights = attention_matrix[class_token, :first_image_token] 56 | relevent_prompt_tokens = np.where(weights > att_th)[0] 57 | return list(relevent_prompt_tokens) 58 | 59 | 60 | def get_relevant_text_tokens(class_token, attention_matrix, att_th, first_image_token) -> list: 61 | """ 62 | Get the indexes of the text tokens after the image (not including the prompt) 63 | for which the class tokens highly attends (attention above the threshold) 64 | """ 65 | weights = attention_matrix[class_token, (first_image_token+576):class_token] 66 | idxs = np.where(weights > att_th)[0] 67 | relevent_gen_tokens = [t + (first_image_token+576) for t in idxs] 68 | return relevent_gen_tokens 69 | 70 | 71 | def tokens_analysis(attention_matrix, tokens_idx_list, token_of_interest, 72 | number_of_samples, p_val_thrshold, max_search_range=None, verbose=True): 73 | explanation_list, cleann_explainer, runtimes = create_explanation(attention_matrix, tokens_idx_list, token_of_interest, 74 | number_of_samples, p_val_thrshold, max_search_range, 75 | verbose=verbose) 76 | explanation_list = sorted(explanation_list) 77 | if verbose: 78 | print(f'len {len(explanation_list)}', explanation_list) 79 | return explanation_list, cleann_explainer, runtimes 80 | 81 | 82 | def create_explanation(attention_matrix, tokens_idx_list, token_of_interest, 83 | number_of_samples, p_val_thrshold, max_search_range=None, verbose=True): 84 | cleann_explainer = CLEANN( 85 | attention_matrix=attention_matrix, 86 | num_samples=number_of_samples, 87 | p_val_th=p_val_thrshold, 88 
| explanation_tester=None, 89 | nodes_set=set(tokens_idx_list), 90 | 91 | ) 92 | cond_indep_test = cleann_explainer.ci_test 93 | structure_learner = LearnStructOrderedICD(set(tokens_idx_list), sorted(tokens_idx_list), cond_indep_test, 94 | is_selection_bias=False) 95 | 96 | runtimes = {'structure': None, 'reasoning': None} 97 | t0 = time.time() 98 | structure_learner.learn_structure_global() 99 | t1 = time.time() 100 | runtimes['structure'] = t1-t0 101 | if verbose: 102 | print(f'Structure learning time {t1 - t0} seconds.') 103 | 104 | cleann_explainer.graph = structure_learner.graph 105 | t0 = time.time() 106 | explanation = cleann_explainer.explain(token_of_interest, max_range=max_search_range) 107 | t1 = time.time() 108 | runtimes['reasoning'] = t1-t0 109 | if verbose: 110 | print(f'Explanation deduction time {t1 - t0} seconds.') 111 | explanation_list = [v for v in explanation[0][0]] 112 | return explanation_list, cleann_explainer, runtimes 113 | 114 | 115 | def copy_sub_graph(full_graph: PAG, nodes_of_interest: set) -> PAG: 116 | sub_graph = PAG(nodes_of_interest) 117 | sub_graph.create_empty_graph() 118 | for node_i, node_j in combinations(nodes_of_interest, 2): 119 | if full_graph.is_connected(node_i, node_j): 120 | edge_at_i = full_graph.get_edge_mark(node_j, node_i) 121 | edge_at_j = full_graph.get_edge_mark(node_i, node_j) 122 | sub_graph.add_edge(node_i, node_j, edge_at_i, edge_at_j) 123 | return sub_graph 124 | 125 | # def create_preprocessed_image(in_image): 126 | # img_std = torch.tensor(image_processor.image_std).view(3,1,1) 127 | # img_mean = torch.tensor(image_processor.image_mean).view(3,1,1) 128 | # img_recover = in_image * img_std + img_mean 129 | # return to_pil_image(img_recover) 130 | 131 | 132 | def show_tokens_on_image(selected_image_tokens, pil_image, weights=None): 133 | if weights is None or len(weights)==0: 134 | weights_n = [0.7] * len(selected_image_tokens) 135 | else: 136 | mx = 1 # max(weights) 137 | weights_n = [v/mx for v in weights] 138 | 139 | tokens_mask = np.zeros(576) 140 | for i, tok in enumerate(selected_image_tokens): 141 | tokens_mask[tok] = weights_n[i] 142 | cmap = plt.get_cmap('jet') 143 | im_mask = tokens_mask.reshape((24, 24)) 144 | im_mask = cmap(im_mask) 145 | a_im = Image.fromarray((im_mask[:, :, :3] * 255).astype(np.uint8)).resize((336, 336), Image.BICUBIC) 146 | a_im.putalpha(128) 147 | new_im = pil_image.copy() 148 | new_im.paste(a_im, mask=a_im) 149 | return new_im 150 | 151 | 152 | def calculate_explanation_pvals(explainer_instance, target_node, max_range=None): 153 | if target_node not in explainer_instance.results: 154 | raise RuntimeError("explainer should have initially been run.")
155 | if max_range is None: 156 | max_range = explainer_instance.results[target_node]['max_pds_tree_depth'] 157 | 158 | ci_test = explainer_instance.ci_test # alias 159 | pvals = dict() 160 | 161 | cond_set = () # initial conditioning set 162 | prev_res_set = set() 163 | for r in range(1, max_range): 164 | res_set = explainer_instance.explain(target_node, max_range=r)[0][0] 165 | for v in res_set.difference(prev_res_set): 166 | pvals[v] = min(ci_test.calc_statistic(v, target_node, cond_set), 1) 167 | cond_set = tuple(res_set) 168 | prev_res_set = res_set 169 | return pvals 170 | 171 | 172 | def image_token_to_xy(image_token, n_x_tokens=24, n_y_tokens=24, token_width=14, token_height=14): 173 | 174 | x_pos = (image_token % n_x_tokens) * token_width 175 | y_pos = (image_token // n_y_tokens) * token_height 176 | return x_pos, y_pos 177 | 178 | 179 | def crop_token(in_im, image_token, n_x_tokens=24, n_y_tokens=24, pad=None): 180 | im_width, im_height = in_im.size 181 | token_width = im_width // n_x_tokens 182 | token_height = im_height // n_y_tokens 183 | x_pos, y_pos = image_token_to_xy(image_token, n_x_tokens, n_y_tokens, token_width, token_height) 184 | left= x_pos 185 | right = left + token_width - 1 186 | top = y_pos 187 | bottom = top + token_height - 1 188 | im_token = in_im.crop((left, top, right, bottom)) 189 | if pad is None: 190 | return im_token 191 | else: 192 | left_pad = max(0, left-pad*token_width) 193 | right_pad = min(im_width-1, right+pad*token_width) 194 | top_pad = max(0, top-pad*token_height) 195 | bottom_pad = min(im_height-1, bottom+pad*token_height) 196 | # print(left_pad, right_pad, top_pad, bottom_pad) 197 | enhancer = ImageEnhance.Brightness(in_im) 198 | new_im = enhancer.enhance(0.5) 199 | pad_image = new_im.crop((left_pad, top_pad, right_pad, bottom_pad)) 200 | pad_image.paste(im_token, (left-left_pad, top-top_pad)) 201 | return pad_image 202 | -------------------------------------------------------------------------------- /utils_gradio.py: -------------------------------------------------------------------------------- 1 | import os 2 | import tempfile 3 | import logging 4 | 5 | import torch 6 | 7 | from PIL import Image 8 | import numpy as np 9 | import gradio as gr 10 | import spaces 11 | 12 | from torchvision.transforms.functional import to_pil_image 13 | 14 | from utils_model import get_processor_model, move_to_device, to_gradio_chatbot, process_image 15 | 16 | from utils_attn import ( 17 | handle_attentions_i2t, plot_attention_analysis, handle_relevancy, handle_text_relevancy, reset_tokens, 18 | plot_text_to_image_analysis, handle_box_reset, boxes_click_handler, attn_update_slider 19 | ) 20 | 21 | from utils_relevancy import construct_relevancy_map 22 | 23 | from utils_causal_discovery import ( 24 | handle_causality, handle_causal_head, causality_update_dropdown 25 | ) 26 | 27 | logger = logging.getLogger(__name__) 28 | 29 | N_LAYERS = 32 30 | CUR_DIR = os.path.dirname(os.path.abspath(__file__)) 31 | ROLE0 = "USER" 32 | ROLE1 = "ASSISTANT" 33 | 34 | processor = None 35 | model = None 36 | 37 | system_prompt = """You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature. 38 | If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. 
If you don't know the answer to a question, please don't share false information.""" 39 | # system_prompt = "" 40 | # system_prompt ="""A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.""" 41 | 42 | title_markdown = (""" 43 | # LVLM-Interpret: An Interpretability Tool for Large Vision-Language Models 44 | """) 45 | 46 | tos_markdown = (""" 47 | ### Terms of use 48 | By using this service, users are required to agree to the following terms: 49 | The service is a research preview intended for non-commercial use only. It only provides limited safety measures and may generate offensive content. It must not be used for any illegal, harmful, violent, racist, or sexual purposes. 50 | """) 51 | 52 | block_css = """ 53 | 54 | #image_canvas canvas { 55 | max-width: 400px !important; 56 | max-height: 400px !important; 57 | } 58 | 59 | #buttons button { 60 | min-width: min(120px,100%); 61 | } 62 | 63 | """ 64 | 65 | def clear_history(request: gr.Request): 66 | logger.info(f"clear_history. ip: {request.client.host}") 67 | state = gr.State() 68 | state.messages = [] 69 | return (state, [], "", None, None, None, None) 70 | 71 | def add_text(state, text, image, image_process_mode): 72 | global processor 73 | 74 | if True: # state is None: 75 | state = gr.State() 76 | state.messages = [] 77 | 78 | if isinstance(image, dict): 79 | image = image['composite'] 80 | background = Image.new('RGBA', image.size, (255, 255, 255)) 81 | image = Image.alpha_composite(background, image).convert('RGB') 82 | 83 | # ImageEditor does not return None image 84 | if (np.array(image)==255).all(): 85 | image = None 86 | 87 | text = text[:1536] # Hard cut-off 88 | logger.info(text) 89 | 90 | prompt_len = 0 91 | # prompt=f"[INST] {system_prompt} [/INST]\n\n" if system_prompt else "" 92 | if processor.tokenizer.chat_template is not None: 93 | prompt = processor.tokenizer.apply_chat_template( 94 | [{"role": "user", "content": "<image>\n" + text}], 95 | tokenize=False, 96 | add_generation_prompt=True, 97 | ) 98 | prompt_len += len(prompt) 99 | else: 100 | prompt = system_prompt 101 | prompt_len += len(prompt) 102 | if image is not None: 103 | msg = f"\n{ROLE0}: <image>\n{text}\n{ROLE1}:" # Ignore <image> token when calculating prompt length 104 | else: 105 | msg = f"\n{ROLE0}: {text}\n{ROLE1}: " 106 | prompt += msg 107 | prompt_len += len(msg) 108 | 109 | state.messages.append([ROLE0, (text, image, image_process_mode)]) 110 | state.messages.append([ROLE1, None]) 111 | 112 | state.prompt_len = prompt_len 113 | state.prompt = prompt 114 | state.image = process_image(image, image_process_mode, return_pil=True) 115 | 116 | return (state, to_gradio_chatbot(state), "", None) 117 | 118 | 119 | @spaces.GPU 120 | def lvlm_bot(state, temperature, top_p, max_new_tokens): 121 | prompt = state.prompt 122 | prompt_len = state.prompt_len 123 | image = state.image 124 | 125 | inputs = processor(prompt, image, return_tensors="pt").to(model.device) 126 | input_ids = inputs.input_ids 127 | img_idx = torch.where(input_ids==model.config.image_token_index)[1][0].item() 128 | do_sample = True if temperature > 0.001 else False 129 | # Generate 130 | model.enc_attn_weights = [] 131 | model.enc_attn_weights_vit = [] 132 | 133 | if model.language_model.config.model_type == "gemma": 134 | eos_token_id = processor.tokenizer('<end_of_turn>', add_special_tokens=False).input_ids[0] 135 | else: 136 | eos_token_id = processor.tokenizer.eos_token_id 137 | 138 | outputs = model.generate( 139 | 
**inputs, 140 | do_sample=do_sample, 141 | temperature=temperature, 142 | top_p=top_p, 143 | max_new_tokens=max_new_tokens, 144 | use_cache=True, 145 | output_attentions=True, 146 | return_dict_in_generate=True, 147 | output_scores=True, 148 | eos_token_id=eos_token_id 149 | ) 150 | 151 | input_ids_list = input_ids.reshape(-1).tolist() 152 | input_ids_list[img_idx] = 0 153 | input_text = processor.tokenizer.decode(input_ids_list) # eg. "<s> You are a helpful ..." 154 | if input_text.startswith("<s> "): 155 | input_text = '<s>' + input_text[4:] # Remove the first space after <s> to maintain correct length 156 | input_text_tokenized = processor.tokenizer.tokenize(input_text) # eg. ['<s>', '▁You', '▁are', '▁a', '▁helpful', ... ] 157 | input_text_tokenized[img_idx] = "average_image" 158 | 159 | output_ids = outputs.sequences.reshape(-1)[input_ids.shape[-1]:].tolist() 160 | 161 | generated_text = processor.tokenizer.decode(output_ids) 162 | output_ids_decoded = [processor.tokenizer.decode(oid).strip() for oid in output_ids] # eg. ['The', 'man', "'", 's', 'sh', 'irt', 'is', 'yellow', '.', '</s>'] 163 | generated_text_tokenized = processor.tokenizer.tokenize(generated_text) 164 | 165 | logger.info(f"Generated response: {generated_text}") 166 | logger.debug(f"output_ids_decoded: {output_ids_decoded}") 167 | logger.debug(f"generated_text_tokenized: {generated_text_tokenized}") 168 | 169 | state.messages[-1][-1] = generated_text[:-len('</s>')] if generated_text.endswith('</s>') else generated_text 170 | 171 | tempdir = os.getenv('TMPDIR', '/tmp/') 172 | tempfilename = tempfile.NamedTemporaryFile(dir=tempdir) 173 | tempfilename.close() 174 | 175 | # Save input_ids and attentions 176 | fn_input_ids = f'{tempfilename.name}_input_ids.pt' 177 | torch.save(move_to_device(input_ids, device='cpu'), fn_input_ids) 178 | fn_attention = f'{tempfilename.name}_attn.pt' 179 | torch.save(move_to_device(outputs.attentions, device='cpu'), fn_attention) 180 | logger.info(f"Saved attention to {fn_attention}") 181 | 182 | # Handle relevancy map 183 | # tokens_for_rel = tokens_for_rel[1:] 184 | word_rel_map = construct_relevancy_map( 185 | tokenizer=processor.tokenizer, 186 | model=model, 187 | input_ids=inputs.input_ids, 188 | tokens=generated_text_tokenized, 189 | outputs=outputs, 190 | output_ids=output_ids, 191 | img_idx=img_idx 192 | ) 193 | fn_relevancy = f'{tempfilename.name}_relevancy.pt' 194 | torch.save(move_to_device(word_rel_map, device='cpu'), fn_relevancy) 195 | logger.info(f"Saved relevancy map to {fn_relevancy}") 196 | model.enc_attn_weights = [] 197 | model.enc_attn_weights_vit = [] 198 | # enc_attn_weights_vit = [] 199 | # rel_maps = [] 200 | 201 | # Reconstruct processed image 202 | img_std = torch.tensor(processor.image_processor.image_std).view(3,1,1) 203 | img_mean = torch.tensor(processor.image_processor.image_mean).view(3,1,1) 204 | img_recover = inputs.pixel_values[0].cpu() * img_std + img_mean 205 | img_recover = to_pil_image(img_recover) 206 | 207 | state.recovered_image = img_recover 208 | state.input_text_tokenized = input_text_tokenized 209 | state.output_ids_decoded = output_ids_decoded 210 | state.attention_key = tempfilename.name 211 | state.image_idx = img_idx 212 | 213 | return state, to_gradio_chatbot(state) 214 | 215 | 216 | def build_demo(args, embed_mode=False): 217 | global model 218 | global processor 219 | global system_prompt 220 | global ROLE0 221 | global ROLE1 222 | 223 | if model is None: 224 | processor, model = get_processor_model(args) 225 | 226 | if 'gemma' in args.model_name_or_path: 227 | 
system_prompt = '' 228 | ROLE0 = 'user' 229 | ROLE1 = 'model' 230 | 231 | textbox = gr.Textbox(show_label=False, placeholder="Enter text and press ENTER", container=False) 232 | with gr.Blocks(title="LVLM-Interpret", theme=gr.themes.Default(), css=block_css) as demo: 233 | state = gr.State() 234 | 235 | if not embed_mode: 236 | gr.Markdown(title_markdown) 237 | 238 | with gr.Tab("Generation"): 239 | with gr.Row(): 240 | with gr.Column(scale=6): 241 | 242 | imagebox = gr.ImageEditor(type="pil", height=400, elem_id="image_canvas") 243 | 244 | 245 | with gr.Accordion("Parameters", open=False) as parameter_row: 246 | image_process_mode = gr.Radio( 247 | ["Crop", "Resize", "Pad", "Default"], 248 | value="Default", 249 | label="Preprocess for non-square image", visible=True 250 | ) 251 | temperature = gr.Slider(minimum=0.0, maximum=1.0, value=0.2, step=0.1, interactive=True, label="Temperature",) 252 | top_p = gr.Slider(minimum=0.0, maximum=1.0, value=0.7, step=0.1, interactive=True, label="Top P",) 253 | max_output_tokens = gr.Slider(minimum=0, maximum=512, value=32, step=32, interactive=True, label="Max new output tokens",) 254 | 255 | 256 | with gr.Column(scale=6): 257 | chatbot = gr.Chatbot(elem_id="chatbot", label="Chatbot", height=400) 258 | with gr.Row(): 259 | with gr.Column(scale=8): 260 | textbox.render() 261 | with gr.Column(scale=1, min_width=50): 262 | submit_btn = gr.Button(value="Send", variant="primary") 263 | with gr.Row(elem_id="buttons") as button_row: 264 | clear_btn = gr.Button(value="🗑️ Clear", interactive=True, visible=True) 265 | 266 | # with gr.Row(): 267 | # with gr.Column(scale=6): 268 | 269 | # gr.Examples(examples=[ 270 | # [f"{CUR_DIR}/examples/extreme_ironing.jpg", "What color is the man's shirt?"], 271 | # [f"{CUR_DIR}/examples/waterview.jpg", "What is in the top left of this image?"], 272 | # [f"{CUR_DIR}/examples/MMVP_34.jpg", "Is the butterfly's abdomen visible in the image?"], 273 | # ], inputs=[imagebox, textbox]) 274 | 275 | # with gr.Column(scale=6): 276 | # gr.Examples(examples=[ 277 | # [f"{CUR_DIR}/examples/MMVP_84.jpg", "Is the door of the truck cab open?"], 278 | # [f"{CUR_DIR}/examples/MMVP_173.jpg", "Is the decoration on the Easter egg flat or raised?"], 279 | # [f"{CUR_DIR}/examples/MMVP_279.jpg", "Is the elderly person standing or sitting in the picture?"], 280 | # ], inputs=[imagebox, textbox]) 281 | 282 | with gr.Tab("Attention analysis"): 283 | with gr.Row(): 284 | with gr.Column(scale=3): 285 | # attn_ana_layer = gr.Slider(1, 100, step=1, label="Layer") 286 | attn_modality_select = gr.Dropdown( 287 | choices=['Image-to-Answer', 'Question-to-Answer'], 288 | value='Image-to-Answer', 289 | interactive=True, 290 | show_label=False, 291 | container=False 292 | ) 293 | attn_ana_submit = gr.Button(value="Plot attention matrix", interactive=True) 294 | with gr.Column(scale=6): 295 | attn_ana_plot = gr.Plot(label="Attention plot") 296 | 297 | attn_ana_submit.click( 298 | plot_attention_analysis, 299 | [state, attn_modality_select], 300 | [state, attn_ana_plot] 301 | ) 302 | 303 | with gr.Tab("Attentions"): 304 | with gr.Row(): 305 | attn_select_layer = gr.Slider(1, N_LAYERS, value=32, step=1, label="Layer") 306 | with gr.Row(): 307 | with gr.Column(scale=3): 308 | imagebox_recover = gr.Image(type="pil", label='Preprocessed image', interactive=False) 309 | 310 | generated_text = gr.HighlightedText( 311 | label="Generated text (tokenized)", 312 | combine_adjacent=False, 313 | interactive=True, 314 | color_map={"label": "green"} 315 | ) 316 | with 
gr.Row(): 317 | attn_reset = gr.Button(value="Reset tokens", interactive=True) 318 | attn_submit = gr.Button(value="Plot attention", interactive=True) 319 | 320 | with gr.Column(scale=9): 321 | i2t_attn_head_mean_plot = gr.Plot(label="Image-to-Text attention average per head") 322 | i2t_attn_gallery = gr.Gallery(type="pil", label='Attention heatmaps', columns=8, interactive=False) 323 | 324 | box_states = gr.Dataframe(type="numpy", datatype="bool", row_count=24, col_count=24, visible=False) 325 | with gr.Row(equal_height=True): 326 | with gr.Column(scale=3): 327 | imagebox_recover_boxable = gr.Image(label='Patch Selector') 328 | attn_ana_head= gr.Slider(1, 40, step=1, label="Head Index") 329 | 330 | reset_boxes_btn = gr.Button(value="Reset patch selector") 331 | attn_ana_submit_2 = gr.Button(value="Plot attention matrix", interactive=True) 332 | 333 | with gr.Column(scale=9): 334 | t2i_attn_head_mean_plot = gr.Plot(label="Text-to-Image attention average per head") 335 | attn_ana_plot_2 = gr.Plot(scale=2, label="Attention plot",container=True) 336 | 337 | reset_boxes_btn.click( 338 | handle_box_reset, 339 | [imagebox_recover,box_states], 340 | [imagebox_recover_boxable, box_states] 341 | ) 342 | imagebox_recover_boxable.select(boxes_click_handler, [imagebox_recover,box_states], [imagebox_recover_boxable, box_states]) 343 | 344 | attn_reset.click( 345 | reset_tokens, 346 | [state], 347 | [generated_text] 348 | ) 349 | 350 | attn_ana_submit_2.click( 351 | plot_text_to_image_analysis, 352 | [state, attn_select_layer, box_states, attn_ana_head ], 353 | [state, attn_ana_plot_2, t2i_attn_head_mean_plot] 354 | ) 355 | 356 | 357 | attn_submit.click( 358 | handle_attentions_i2t, 359 | [state, generated_text, attn_select_layer], 360 | [generated_text, imagebox_recover, i2t_attn_gallery, i2t_attn_head_mean_plot] 361 | ) 362 | 363 | with gr.Tab("Relevancy"): 364 | with gr.Row(): 365 | relevancy_token_dropdown = gr.Dropdown( 366 | choices=['llama','vit','all'], 367 | value='llama', 368 | interactive=True, 369 | show_label=False, 370 | container=False 371 | ) 372 | relevancy_submit = gr.Button(value="Plot relevancy", interactive=True) 373 | with gr.Row(): 374 | relevancy_gallery = gr.Gallery(type="pil", label='Input image relevancy heatmaps', columns=8, interactive=False) 375 | with gr.Row(): 376 | relevancy_txt_gallery = gr.Gallery(type="pil", label='Image-text relevancy comparison', columns=8, interactive=False) 377 | #gr.Plot(label='Input text Relevancy heatmaps') 378 | with gr.Row(): 379 | relevancy_highlightedtext = gr.HighlightedText( 380 | label='Tokens with high relevancy to image' 381 | ) 382 | 383 | relevancy_submit.click( 384 | lambda state, relevancy_token_dropdown: handle_relevancy(state, relevancy_token_dropdown, incude_text_relevancy=True), 385 | #handle_relevancy, 386 | [state, relevancy_token_dropdown], 387 | [relevancy_gallery], 388 | ) 389 | relevancy_submit.click( 390 | handle_text_relevancy, 391 | [state, relevancy_token_dropdown], 392 | [relevancy_txt_gallery, relevancy_highlightedtext] 393 | ) 394 | 395 | enable_causality = False 396 | with gr.Tab("Causality"): 397 | gr.Markdown( 398 | """ 399 | ### *Coming soon* 400 | """ 401 | ) 402 | state_causal_explainers = gr.State() 403 | with gr.Row(visible=enable_causality): 404 | causality_dropdown = gr.Dropdown( 405 | choices=[], 406 | interactive=True, 407 | show_label=False, 408 | container=False, 409 | scale=2, 410 | ) 411 | causality_submit = gr.Button(value="Learn Causal Structures", interactive=True, variant='primary', scale=1) 412 
| with gr.Row(visible=enable_causality): 413 | with gr.Accordion("Hyper Parameters", open=False) as causal_parameters_row: 414 | with gr.Row(): 415 | with gr.Column(scale=2): 416 | # search_rad_slider= gr.Slider(1, 5, step=1, value=3, label="Search Radius", 417 | # info="The maximal distance on the graph from the explained token.",) 418 | att_th_slider = gr.Slider(minimum=0.0001, maximum=1-0.0001, value=0.005, step=0.0001, interactive=True, label="Raw Attention Threshold", 419 | info="A threshold for selecting tokens to be graph nodes.",) 420 | with gr.Column(scale=2): 421 | alpha_slider = gr.Slider(minimum=1e-7, maximum=1e-2, value=1e-5, step=1e-7, interactive=True, label="Statistical Test Threshold (alpha)", 422 | info="A threshold for the statistical test of conditional independence.",) 423 | # dof_slider = gr.Slider(minimum=32, maximum=1024, value=128, step=1, interactive=True, label="Degrees of Freedom", 424 | # info="Degrees of freedom of correlation matrix.") 425 | with gr.Row(visible=enable_causality): 426 | pds_plot = gr.Image(type="pil", label='Preprocessed image') 427 | causal_head_gallery = gr.Gallery(type="pil", label='Causal Head Graph', columns=8, interactive=False) 428 | with gr.Row(visible=enable_causality): 429 | causal_head_slider = gr.Slider(minimum=0, maximum=31, value=1, step=1, interactive=True, label="Head Selection") 430 | causal_head_submit = gr.Button(value="Plot Causal Head", interactive=True, scale=1) 431 | with gr.Row(visible=enable_causality): 432 | causality_gallery = gr.Gallery(type="pil", label='Causal Heatmaps', columns=8, interactive=False) 433 | 434 | causal_head_submit.click( 435 | handle_causal_head, 436 | [state, state_causal_explainers, causal_head_slider, causality_dropdown], 437 | [causal_head_gallery, pds_plot] 438 | ) 439 | 440 | causality_submit.click( 441 | handle_causality, 442 | [state, state_causal_explainers, causality_dropdown, alpha_slider, att_th_slider], 443 | [causality_gallery, state_causal_explainers] 444 | ) 445 | 446 | if not embed_mode: 447 | gr.Markdown(tos_markdown) 448 | 449 | clear_btn.click( 450 | clear_history, 451 | None, 452 | [state, chatbot, textbox, imagebox, imagebox_recover, generated_text, i2t_attn_gallery ] , 453 | queue=False 454 | ) 455 | 456 | textbox.submit( 457 | add_text, 458 | [state, textbox, imagebox, image_process_mode], 459 | [state, chatbot, textbox, imagebox], 460 | queue=False 461 | ).then( 462 | lvlm_bot, 463 | [state, temperature, top_p, max_output_tokens], 464 | [state, chatbot] , 465 | ).then( 466 | attn_update_slider, 467 | [state], 468 | [state, attn_select_layer] 469 | ).then( 470 | causality_update_dropdown, 471 | [state], 472 | [state, causality_dropdown] 473 | ) 474 | # .then( 475 | # handle_box_reset, 476 | # [imagebox_recover,box_states], 477 | # [imagebox_recover_boxable, box_states] 478 | # ).then( 479 | # handle_attentions_i2t, 480 | # [state, generated_text, attn_select_layer], 481 | # [generated_text, imagebox_recover, i2t_attn_gallery, i2t_attn_head_mean_plot] 482 | # ).then( 483 | # clear_canvas, 484 | # [], 485 | # [imagebox] 486 | # ).then( 487 | # handle_relevancy, 488 | # [state, relevancy_token_dropdown], 489 | # [relevancy_gallery] 490 | # ).then( 491 | # handle_text_relevancy, 492 | # [state, relevancy_token_dropdown], 493 | # [relevancy_txt_gallery, relevancy_highlightedtext] 494 | # ) 495 | submit_btn.click( 496 | add_text, 497 | [state, textbox, imagebox, image_process_mode], 498 | [state, chatbot, textbox, imagebox], 499 | queue=False 500 | ).then( 501 | lvlm_bot, 502 | 
[state, temperature, top_p, max_output_tokens], 503 | [state, chatbot], 504 | ).then( 505 | attn_update_slider, 506 | [state], 507 | [state, attn_select_layer] 508 | ).then( 509 | causality_update_dropdown, 510 | [state], 511 | [state, causality_dropdown] 512 | ) 513 | # .then( 514 | # causality_update_dropdown, 515 | # [state], 516 | # [causality_dropdown] 517 | # ).then( 518 | # handle_box_reset, 519 | # [imagebox_recover,box_states], 520 | # [imagebox_recover_boxable, box_states] 521 | # ).then( 522 | # plot_attention_analysis, 523 | # [state, attn_modality_select], 524 | # [state, attn_ana_plot] 525 | # ).then( 526 | # handle_relevancy, 527 | # [state, relevancy_token_dropdown], 528 | # [relevancy_gallery] 529 | # ).then( 530 | # handle_text_relevancy, 531 | # [state, relevancy_token_dropdown], 532 | # [relevancy_txt_gallery, relevancy_highlightedtext] 533 | # ) 534 | 535 | 536 | return demo 537 | 538 | -------------------------------------------------------------------------------- /utils_model.py: -------------------------------------------------------------------------------- 1 | 2 | import logging 3 | import base64 4 | from io import BytesIO 5 | from PIL import Image 6 | import torch 7 | # from torchvision.transforms.functional import to_pil_image 8 | from transformers import LlavaForConditionalGeneration, AutoProcessor 9 | from transformers import BitsAndBytesConfig 10 | 11 | func_to_enable_grad = '_sample' 12 | setattr(LlavaForConditionalGeneration, func_to_enable_grad, torch.enable_grad(getattr(LlavaForConditionalGeneration, func_to_enable_grad))) 13 | 14 | try: 15 | import intel_extension_for_pytorch as ipex 16 | except ModuleNotFoundError: 17 | pass 18 | 19 | logger = logging.getLogger(__name__) 20 | 21 | def get_processor_model(args): 22 | #outputs: attn_output, attn_weights, past_key_value 23 | processor = AutoProcessor.from_pretrained(args.model_name_or_path) 24 | 25 | if args.load_4bit: 26 | quant_config = BitsAndBytesConfig( 27 | load_in_4bit=True, 28 | bnb_4bit_quant_type="nf4", 29 | bnb_4bit_use_double_quant=True, 30 | bnb_4bit_compute_dtype=torch.bfloat16 31 | ) 32 | elif args.load_8bit: 33 | quant_config = BitsAndBytesConfig( 34 | load_in_8bit=True 35 | ) 36 | else: 37 | quant_config = None 38 | 39 | model = LlavaForConditionalGeneration.from_pretrained( 40 | args.model_name_or_path, torch_dtype=torch.bfloat16, 41 | quantization_config=quant_config, low_cpu_mem_usage=True, device_map=args.device_map 42 | ) 43 | model.vision_tower.config.output_attentions = True 44 | 45 | # Relevancy map 46 | # set hooks to get attention weights 47 | model.enc_attn_weights = [] 48 | #outputs: attn_output, attn_weights, past_key_value 49 | def forward_hook(module, inputs, output): 50 | if output[1] is None: 51 | logger.error( 52 | ("Attention weights were not returned for the encoder. " 53 | "To enable, set output_attentions=True in the forward pass of the model. ") 54 | ) 55 | return output 56 | 57 | output[1].requires_grad_(True) 58 | output[1].retain_grad() 59 | model.enc_attn_weights.append(output[1]) 60 | return output 61 | 62 | hooks_pre_encoder, hooks_encoder = [], [] 63 | for layer in model.language_model.model.layers: 64 | hook_encoder_layer = layer.self_attn.register_forward_hook(forward_hook) 65 | hooks_pre_encoder.append(hook_encoder_layer) 66 | 67 | model.enc_attn_weights_vit = [] 68 | 69 | 70 | def forward_hook_image_processor(module, inputs, output): 71 | if output[1] is None: 72 | logger.warning( 73 | ("Attention weights were not returned for the vision model. 
" 74 | "Relevancy maps will not be calculated for the vision model. " 75 | "To enable, set output_attentions=True in the forward pass of vision_tower. ") 76 | ) 77 | return output 78 | 79 | output[1].requires_grad_(True) 80 | output[1].retain_grad() 81 | model.enc_attn_weights_vit.append(output[1]) 82 | return output 83 | 84 | hooks_pre_encoder_vit = [] 85 | for layer in model.vision_tower.vision_model.encoder.layers: 86 | hook_encoder_layer_vit = layer.self_attn.register_forward_hook(forward_hook_image_processor) 87 | hooks_pre_encoder_vit.append(hook_encoder_layer_vit) 88 | 89 | return processor, model 90 | 91 | def process_image(image, image_process_mode, return_pil=False, image_format='PNG', max_len=1344, min_len=672): 92 | if image_process_mode == "Pad": 93 | def expand2square(pil_img, background_color=(122, 116, 104)): 94 | width, height = pil_img.size 95 | if width == height: 96 | return pil_img 97 | elif width > height: 98 | result = Image.new(pil_img.mode, (width, width), background_color) 99 | result.paste(pil_img, (0, (width - height) // 2)) 100 | return result 101 | else: 102 | result = Image.new(pil_img.mode, (height, height), background_color) 103 | result.paste(pil_img, ((height - width) // 2, 0)) 104 | return result 105 | image = expand2square(image) 106 | elif image_process_mode in ["Default", "Crop"]: 107 | pass 108 | elif image_process_mode == "Resize": 109 | image = image.resize((336, 336)) 110 | else: 111 | raise ValueError(f"Invalid image_process_mode: {image_process_mode}") 112 | if max(image.size) > max_len: 113 | max_hw, min_hw = max(image.size), min(image.size) 114 | aspect_ratio = max_hw / min_hw 115 | shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw)) 116 | longest_edge = int(shortest_edge * aspect_ratio) 117 | W, H = image.size 118 | if H > W: 119 | H, W = longest_edge, shortest_edge 120 | else: 121 | H, W = shortest_edge, longest_edge 122 | image = image.resize((W, H)) 123 | if return_pil: 124 | return image 125 | else: 126 | buffered = BytesIO() 127 | image.save(buffered, format=image_format) 128 | img_b64_str = base64.b64encode(buffered.getvalue()).decode() 129 | return img_b64_str 130 | 131 | 132 | def to_gradio_chatbot(state): 133 | ret = [] 134 | for i, (role, msg) in enumerate(state.messages): 135 | if i % 2 == 0: 136 | if type(msg) is tuple: 137 | msg, image, image_process_mode = msg 138 | img_b64_str = process_image( 139 | image, "Default", return_pil=False, 140 | image_format='JPEG') 141 | img_str = f'user upload image' 142 | msg = img_str + msg.replace('', '').strip() 143 | ret.append([msg, None]) 144 | else: 145 | ret.append([msg, None]) 146 | else: 147 | ret[-1][-1] = msg 148 | return ret 149 | 150 | def move_to_device(input, device='cpu'): 151 | 152 | if isinstance(input, torch.Tensor): 153 | return input.to(device).detach() 154 | elif isinstance(input, list): 155 | return [move_to_device(inp) for inp in input] 156 | elif isinstance(input, tuple): 157 | return tuple([move_to_device(inp) for inp in input]) 158 | elif isinstance(input, dict): 159 | return dict( ((k, move_to_device(v)) for k,v in input.items())) 160 | else: 161 | raise ValueError(f"Unknown data type for {input.type}") 162 | 163 | -------------------------------------------------------------------------------- /utils_relevancy.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import inspect 3 | import warnings 4 | from dataclasses import dataclass 5 | from typing import TYPE_CHECKING, Any, Callable, Dict, List, 
Optional, Tuple, Union 6 | 7 | import torch 8 | import torch.distributed as dist 9 | from torch import nn 10 | import torch.nn.functional as F 11 | 12 | from transformers.utils import logging 13 | from transformers.generation.beam_constraints import DisjunctiveConstraint, PhrasalConstraint 14 | from transformers.generation.beam_search import BeamSearchScorer, ConstrainedBeamSearchScorer 15 | from transformers import GenerationConfig 16 | from transformers.generation.logits_process import ( 17 | LogitsProcessorList, 18 | ) 19 | from transformers import ( 20 | StoppingCriteriaList) 21 | from transformers.generation.utils import GenerateOutput 22 | 23 | import gradio as gr 24 | from tqdm import tqdm 25 | 26 | logger = logging.get_logger(__name__) 27 | 28 | SEPARATORS_LIST = ['.',',','?','!', ':', ';', '', '/', '!', '(', ')', '[', ']', '{', '}', '<', '>', '|', '\\', '-', '_', '+', '=', '*', '&', '^', '%', '$', '#', '@', '!', '~', '`', ' ', '\t', '\n', '\r', '\x0b', '\x0c'] 29 | 30 | 31 | # rule 5 from paper 32 | def avg_heads(cam, grad): 33 | cam = cam.reshape(-1, cam.shape[-2], cam.shape[-1]) 34 | grad = grad.reshape(-1, grad.shape[-2], grad.shape[-1]) 35 | cam = grad * cam 36 | cam = cam.clamp(min=0).mean(dim=0) 37 | return cam 38 | 39 | # rule 6 from paper 40 | def handle_self_attention_image(R_i_i, enc_attn_weights, privious_cam=[]): 41 | if privious_cam : 42 | device = privious_cam[-1].device 43 | else: 44 | device = None 45 | for i, blk in enumerate(enc_attn_weights): 46 | grad = blk.grad.float().detach() 47 | # if model.use_lrp: # not used 48 | # cam = blk[batch_no].detach() 49 | # else: 50 | cam = blk.float().detach() # the attention of one layer 51 | if device is None: 52 | device = cam.device 53 | cam = avg_heads(cam.to(device), grad.to(device)) 54 | # rebuild the privious attenions to the same size as the current attention 55 | if len(privious_cam) != 0 and cam.shape[0] == 1: 56 | len_seq, all_len_seq = privious_cam[i].shape 57 | assert len_seq == all_len_seq, "The privious CAMs are not square" 58 | new_column = torch.zeros(len_seq, 1).to(cam.device) 59 | privious_cam[i] = torch.cat((privious_cam[i], new_column), dim=1) 60 | privious_cam[i] = torch.cat((privious_cam[i], cam), dim=0) 61 | cam = privious_cam[i] 62 | elif cam.shape[0] != 1: 63 | privious_cam.append(cam) 64 | assert cam.shape == R_i_i.shape, "The attention weights and the relevancy map are not the same size" 65 | R_i_i += torch.matmul(cam, R_i_i) 66 | del grad, cam 67 | # torch.cuda.empty_cache() 68 | 69 | return R_i_i, privious_cam 70 | 71 | def handle_self_attention_image_vit(R_i_i_init, enc_attn_weights_vit, img_idx=None, add_skip=False, normalize=False): 72 | if img_idx: 73 | R_i_i = R_i_i_init[img_idx:img_idx+576, img_idx:img_idx+576] 74 | if add_skip: 75 | R_i_i = R_i_i + torch.eye(R_i_i.shape[-1]).to(R_i_i.device) 76 | # add a first column and first row of zeros to R_i_i - option #1 77 | R_i_i = torch.cat((torch.zeros(1, R_i_i.shape[1]).to(R_i_i.device), R_i_i), dim=0) 78 | R_i_i = torch.cat((torch.zeros(R_i_i.shape[0], 1).to(R_i_i.device), R_i_i), dim=1) 79 | R_i_i[0,0] = 1 80 | else: 81 | R_i_i = R_i_i_init 82 | if normalize: 83 | R_i_i = handle_residual(R_i_i) 84 | for j, blk_vit in enumerate(enc_attn_weights_vit): #577x577, 1x576 85 | grad_vit = blk_vit.grad.float().detach() 86 | cam_vit = blk_vit.float().detach() 87 | cam_vit = avg_heads(cam_vit, grad_vit) 88 | assert cam_vit.shape == R_i_i.shape, "The vit relevancy map and the llama relevancy map are not the same size" 89 | R_i_i += torch.matmul(cam_vit, 
R_i_i) 90 | return R_i_i 91 | 92 | def compute_rollout_attention(all_layer_matrices_raw, start_layer=0, average_positive=False, add_residual=False): 93 | all_layer_matrices = [] 94 | # image average self attention in the encoder 95 | for blk in all_layer_matrices_raw: 96 | cam = blk.squeeze().detach() #16x577x577 97 | if average_positive: 98 | cam = cam.clamp(min=0).mean(dim=0) 99 | else: 100 | cam = cam.mean(dim=0) 101 | all_layer_matrices.append(cam) #577x577 102 | layer_attn_avg = [all_layer_matrices[i].detach().clone() for i in range(len(all_layer_matrices))] 103 | # adding residual consideration 104 | num_tokens = all_layer_matrices[0].shape[-1] 105 | eye = torch.eye(num_tokens).to(all_layer_matrices[0].device) #577x577 106 | if add_residual == "start": 107 | all_layer_matrices[start_layer] = eye + all_layer_matrices[start_layer] 108 | all_layer_matrices[start_layer] = all_layer_matrices[start_layer] / all_layer_matrices[start_layer].sum(dim=-1, keepdim=True) 109 | elif add_residual == "all": 110 | all_layer_matrices = [all_layer_matrices[i] + eye for i in range(len(all_layer_matrices))] 111 | all_layer_matrices = [all_layer_matrices[i] / all_layer_matrices[i].sum(dim=-1, keepdim=True) 112 | for i in range(len(all_layer_matrices))] 113 | 114 | matrices_aug = all_layer_matrices 115 | joint_attention = matrices_aug[start_layer] 116 | if start_layer == 0: 117 | for i in range(start_layer+1, len(matrices_aug)): 118 | joint_attention = matrices_aug[i].matmul(joint_attention) 119 | if start_layer == len(matrices_aug)-1: 120 | for i in range(start_layer-1, -1,-1): 121 | joint_attention = matrices_aug[i].matmul(joint_attention) 122 | return joint_attention, layer_attn_avg 123 | 124 | # normalization- eq. 8+9 125 | def handle_residual(orig_self_attention): 126 | self_attention = orig_self_attention.clone() 127 | diag_idx = range(self_attention.shape[-1]) 128 | self_attention -= torch.eye(self_attention.shape[-1]).to(self_attention.device) 129 | assert self_attention[diag_idx, diag_idx].min() >= 0 130 | sum_rows = self_attention.sum(dim=-1, keepdim=True) 131 | sum_rows[sum_rows == 0] = 1 # replace all elements equal to zero by 1 132 | self_attention = self_attention / sum_rows# this has nan elements ue to divoding by zero 133 | self_attention += torch.eye(self_attention.shape[-1]).to(self_attention.device) 134 | return self_attention 135 | 136 | def compute_word_rel_map(tokens, target_index, R_i_i, separators_list, 137 | current_rel_map, current_count, current_word, word_rel_maps): 138 | if target_index == 0: 139 | current_word = tokens[target_index] 140 | current_rel_map = R_i_i 141 | current_count = 1 142 | # If the token is a part of the current word, add its relevancy map to the current word's relevancy map 143 | else: 144 | if not tokens[target_index].startswith('▁') and tokens[target_index] not in separators_list: 145 | current_word += tokens[target_index] 146 | # If current_rel_map is smaller, pad it with zeros 147 | if current_rel_map.shape[0] < R_i_i.shape[0]: 148 | # Calculate the padding sizes 149 | padding = (0, R_i_i.shape[1] - current_rel_map.shape[1], 0, R_i_i.shape[0] - current_rel_map.shape[0]) 150 | # Pad rel_maps[1] with zeros 151 | current_rel_map = F.pad(current_rel_map, padding, "constant", 0) 152 | current_rel_map += R_i_i 153 | current_count += 1 154 | else: 155 | # Otherwise, store the current word's relevancy map and start a new word 156 | word_rel_maps[current_word] = current_rel_map / current_count 157 | current_word = tokens[target_index] 158 | current_rel_map = 
R_i_i 159 | current_count = 1 160 | return word_rel_maps, current_rel_map, current_count, current_word 161 | 162 | 163 | def construct_relevancy_map(tokenizer, model, input_ids, tokens, outputs, output_ids, img_idx, apply_normalization=True, progress=gr.Progress(track_tqdm=True)): 164 | logger.debug('Tokens: %s', tokens) 165 | enable_vit_relevancy = len(model.enc_attn_weights_vit) > 0 166 | if enable_vit_relevancy: 167 | enc_attn_weights_vit = model.enc_attn_weights_vit 168 | enc_attn_weights = model.enc_attn_weights 169 | device = outputs.attentions[-1][0][0].device 170 | 171 | # compute rollout attention 172 | # start_layer = len(enc_attn_weights_vit)-2 # the last layer is not considered for llava 173 | # rollout_vit, layer_attn_avg = compute_rollout_attention(enc_attn_weights_vit, start_layer,average_positive=False, add_residual=False) 174 | 175 | # compute relevancy maps 176 | rel_maps = [] 177 | rel_maps_all = [] 178 | rel_maps_vit = [] 179 | rel_maps_all_generated_token = [] 180 | 181 | num_generated_tokens = len(outputs.attentions) 182 | num_self_att_layers = len(outputs.attentions[0]) 183 | assert num_generated_tokens == len(outputs.scores) 184 | assert num_generated_tokens*num_self_att_layers == len(enc_attn_weights), f'{num_generated_tokens}x{num_self_att_layers} != {len(enc_attn_weights)}' 185 | # rearenge the attention weights the same as outputs.attentions 186 | enc_attn_weights = [enc_attn_weights[i*num_self_att_layers : (i+1)*num_self_att_layers] for i in range(num_generated_tokens)] 187 | 188 | 189 | assert len(tokens) == len(outputs.scores), f'Length of tokens {len(tokens)} is not equal to the length of outputs.scores {len(outputs.scores)}\ntokens: {tokens}' 190 | clean_tokens = [] 191 | 192 | # Initialize dictionaries 193 | word_rel_maps_llama, word_rel_maps_all, word_rel_maps_vit, word_rel_maps_all_generated_token = {}, {}, {}, {} 194 | word_counts = {} 195 | 196 | # Initialize the averaged attention map for the first token 197 | privious_cam = [] 198 | 199 | # Initialize current_rel_map and current_word variables 200 | current_rel_map, current_rel_map_all, current_rel_map_all_generated_token, current_rel_map_vit = None, None, None, None 201 | current_word, current_word_all, current_word_vit, current_word_all_generated_token = None, None, None, None 202 | 203 | # Initialize current_count variables 204 | current_count, current_count_vit, current_count_all, current_count_all_generated_token = 0, 0, 0, 0 205 | 206 | if enable_vit_relevancy: 207 | enc_attn_weights_vit = enc_attn_weights_vit[:-1] # last layer is not considered for llava 208 | assert len(enc_attn_weights_vit) > 0 209 | 210 | rel_maps_dict = {} 211 | logger.debug(f'Number of output scores: {len(outputs.scores)}') 212 | for target_index in tqdm(range(len(outputs.scores)), desc="Building relevancy maps"): #the last token is 213 | clean_tokens.append(tokens[target_index]) 214 | token_logits = outputs.scores[target_index] 215 | token_id = torch.tensor(output_ids[target_index]).to(device) 216 | 217 | # print out the token and its predicted id 218 | #print(f'Token: {tokens[target_index]}, Predicted ID: {token_id}') 219 | if token_id != output_ids[target_index]: 220 | logger.warning(f'The token_id_max_score is not the same as the output_id') 221 | # print the decoded token 222 | logger.warning(f'Decoded Token: {tokenizer.decode(token_id)}') 223 | logger.warning(f'The generated output: {tokens[target_index]}') 224 | # check if the output_id is the same as the token_id 225 | assert token_id == output_ids[target_index], 
"The token_id_max_score is not the same as the output_id" 226 | 227 | 228 | token_id_one_hot = torch.nn.functional.one_hot(token_id, num_classes=token_logits.size(-1)).float() 229 | token_id_one_hot = token_id_one_hot.view(1, -1) 230 | token_id_one_hot.requires_grad_(True) 231 | 232 | # Compute loss and backpropagate to get gradients on attention weights 233 | model.zero_grad() 234 | token_logits.backward(gradient=token_id_one_hot, retain_graph=True) 235 | 236 | # initialize relevancy map for llama 237 | R_i_i_init = torch.eye(enc_attn_weights[target_index][0].shape[-1], enc_attn_weights[target_index][0].shape[-1]).to(token_logits.device).float() 238 | # compute relevancy map accourding to rule #6 239 | R_i_i, privious_cam = handle_self_attention_image(R_i_i_init, enc_attn_weights[target_index], privious_cam) 240 | 241 | if enable_vit_relevancy: 242 | # initialize the vit relevancy map with the llama relevancy map 243 | R_i_i_all = handle_self_attention_image_vit(R_i_i, enc_attn_weights_vit, img_idx, add_skip=False, normalize=False) 244 | 245 | # initialize using the relevancy map of the generated token to the image - option #1 246 | R_i_i_init_vit_all = torch.eye(enc_attn_weights_vit[0].shape[-1], enc_attn_weights_vit[0].shape[-1]).to(token_logits.device).float() 247 | 248 | # add R_i_i[-1,:][img_idx:img_idx+576] to the first row and column of R_i_i_init_vit - option #2 249 | R_i_i_init_vit_all[0,1:] = R_i_i_init_vit_all[0,1:] + R_i_i[-1,:][img_idx:img_idx+576] 250 | R_i_i_init_vit_all[1:,0] = R_i_i_init_vit_all[1:,0] + R_i_i[-1,:][img_idx:img_idx+576] 251 | R_i_i_all_generated_token = handle_self_attention_image_vit(R_i_i_init_vit_all, enc_attn_weights_vit) 252 | 253 | # compute ViT relevancy map 254 | R_i_i_init_vit = torch.eye(enc_attn_weights_vit[0].shape[-1], enc_attn_weights_vit[0].shape[-1]).to(token_logits.device).float() 255 | R_i_i_vit = handle_self_attention_image_vit(R_i_i_init_vit, enc_attn_weights_vit) 256 | if apply_normalization: 257 | R_i_i = handle_residual(R_i_i) 258 | if enable_vit_relevancy: 259 | R_i_i_all = handle_residual(R_i_i_all) 260 | R_i_i_vit = handle_residual(R_i_i_vit) 261 | R_i_i_all_generated_token = handle_residual(R_i_i_all_generated_token) 262 | else: 263 | R_i_i = R_i_i - torch.eye(enc_attn_weights[target_index][0].shape[-1], enc_attn_weights[target_index][0].shape[-1]).to(token_logits.device).float() 264 | 265 | rel_maps.append(R_i_i) 266 | if enable_vit_relevancy: 267 | rel_maps_all.append(R_i_i_all) 268 | rel_maps_vit.append(R_i_i_vit) 269 | rel_maps_all_generated_token.append(R_i_i_all_generated_token) 270 | 271 | # values should be rel_maps, and the keys should be the tokens 272 | # check if this token already exsits 273 | if tokens[target_index] in rel_maps_dict.keys(): 274 | tokens[target_index] = tokens[target_index] + '_' 275 | rel_maps_dict[tokens[target_index]] = R_i_i 276 | 277 | # If the token is a part of the current word, add its relevancy map to the current word's relevancy map 278 | word_rel_maps_llama, current_rel_map, current_count, current_word = compute_word_rel_map( 279 | tokens, target_index, R_i_i, SEPARATORS_LIST, 280 | current_rel_map, current_count, current_word, word_rel_maps_llama) 281 | 282 | if enable_vit_relevancy: 283 | word_rel_maps_all, current_rel_map_all, current_count_all, current_word_all = compute_word_rel_map( 284 | tokens, target_index, R_i_i_all, SEPARATORS_LIST, 285 | current_rel_map_all, current_count_all, current_word_all, word_rel_maps_all) 286 | 287 | word_rel_maps_vit, current_rel_map_vit, 
current_count_vit, current_word_vit = compute_word_rel_map( 288 | tokens, target_index, R_i_i_vit, SEPARATORS_LIST, 289 | current_rel_map_vit, current_count_vit, current_word_vit, word_rel_maps_vit) 290 | 291 | word_rel_maps_all_generated_token, current_rel_map_all_generated_token, \ 292 | current_count_all_generated_token, current_word_all_generated_token = compute_word_rel_map( 293 | tokens, target_index, R_i_i_all_generated_token, SEPARATORS_LIST, 294 | current_rel_map_all_generated_token, current_count_all_generated_token, 295 | current_word_all_generated_token, word_rel_maps_all_generated_token 296 | ) 297 | 298 | logger.debug(f'Current word: {current_word}') 299 | 300 | # Store the last word's relevancy map 301 | word_rel_maps_llama[current_word] = current_rel_map / current_count 302 | 303 | if enable_vit_relevancy: 304 | word_rel_maps_all[current_word_all] = current_rel_map_all / current_count_all 305 | word_rel_maps_vit[current_word_vit] = current_rel_map_vit / current_count_vit 306 | word_rel_maps_all_generated_token[current_word_all_generated_token] = current_rel_map_all_generated_token / current_count_all_generated_token 307 | 308 | 309 | word_rel_maps = { 310 | "llama": word_rel_maps_llama, 311 | "llama_token":rel_maps_dict, 312 | "vit": word_rel_maps_vit, 313 | "all": word_rel_maps_all, 314 | "all_v2": word_rel_maps_all_generated_token 315 | } 316 | 317 | return word_rel_maps 318 | --------------------------------------------------------------------------------
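# A minimal, hedged usage sketch for the relevancy maps produced by construct_relevancy_map in
# utils_relevancy.py above: it extracts, for a chosen word, the relevancy of the generated token
# to the 576 image patches and reshapes it into the 24x24 patch grid used throughout the repo.
# `word_rel_maps`, `word` and `img_idx` are assumed to come from the caller (see lvlm_bot in
# utils_gradio.py); the slicing mirrors R_i_i[-1, :][img_idx:img_idx+576] used in
# handle_self_attention_image_vit.
import torch

def word_image_relevancy(word_rel_maps: dict, word: str, img_idx: int) -> torch.Tensor:
    rel = word_rel_maps["llama"][word]                # (seq_len, seq_len) relevancy matrix for the word
    image_rel = rel[-1, img_idx:img_idx + 576]        # relevancy of the explained token to each image patch
    image_rel = image_rel / (image_rel.max() + 1e-8)  # normalize to [0, 1] for display
    return image_rel.reshape(24, 24)                  # 24x24 grid, ready to overlay as a heatmap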