├── LICENSE ├── README.md ├── app ├── __init__.py ├── app_configuration.py ├── app_layout.py ├── assets │ ├── intel_ai_logo.jpg │ └── stylesheet-oil-and-gas.css ├── database │ ├── __init__.py │ ├── database.py │ ├── db_analyzer.py │ ├── db_example.py │ └── models │ │ ├── kdvlp.py │ │ └── vl_model.py └── plot_func.py ├── assets ├── logo.png ├── screencast.png ├── screenshot_add_ex.png └── screenshot_app_launch.png ├── example_database1 ├── data.mdb ├── faiss │ ├── img_indices_0 │ ├── img_indices_1 │ ├── img_indices_10 │ ├── img_indices_11 │ ├── img_indices_12 │ ├── img_indices_2 │ ├── img_indices_3 │ ├── img_indices_4 │ ├── img_indices_5 │ ├── img_indices_6 │ ├── img_indices_7 │ ├── img_indices_8 │ ├── img_indices_9 │ ├── txt_indices_0 │ ├── txt_indices_1 │ ├── txt_indices_10 │ ├── txt_indices_11 │ ├── txt_indices_12 │ ├── txt_indices_2 │ ├── txt_indices_3 │ ├── txt_indices_4 │ ├── txt_indices_5 │ ├── txt_indices_6 │ ├── txt_indices_7 │ ├── txt_indices_8 │ └── txt_indices_9 └── lock.mdb ├── example_database2 ├── data.mdb ├── faiss │ ├── img_indices_0 │ ├── img_indices_1 │ ├── img_indices_10 │ ├── img_indices_11 │ ├── img_indices_12 │ ├── img_indices_2 │ ├── img_indices_3 │ ├── img_indices_4 │ ├── img_indices_5 │ ├── img_indices_6 │ ├── img_indices_7 │ ├── img_indices_8 │ ├── img_indices_9 │ ├── txt_indices_0 │ ├── txt_indices_1 │ ├── txt_indices_10 │ ├── txt_indices_11 │ ├── txt_indices_12 │ ├── txt_indices_2 │ ├── txt_indices_3 │ ├── txt_indices_4 │ ├── txt_indices_5 │ ├── txt_indices_6 │ ├── txt_indices_7 │ ├── txt_indices_8 │ └── txt_indices_9 └── lock.mdb ├── requirements.txt └── run_app.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Intel Labs 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # DISCONTINUATION OF PROJECT # 2 | This project will no longer be maintained by Intel. 3 | Intel has ceased development and contributions including, but not limited to, maintenance, bug fixes, new releases, or updates, to this project. 4 | Intel no longer accepts patches to this project. 
5 | If you have an ongoing need to use this project, are interested in independently developing it, or would like to maintain patches for the open source software community, please create your own fork of this project. 6 | 7 |

8 | VL-InterpreT: An Interactive Visualization Tool for Interpreting Vision-Language Transformers 9 |

10 | 
11 | 
12 | VL-InterpreT provides interactive visualizations for interpreting the attentions and hidden representations in vision-language transformers. It is a task-agnostic, integrated tool that:
13 | - Tracks a variety of statistics in attention heads throughout all layers, for both the vision and language components
14 | - Visualizes cross-modal and intra-modal attentions through easily readable heatmaps
15 | - Plots the hidden representations of vision and language tokens as they pass through the transformer layers
16 | 
17 | # Paper
18 | Our paper won the Best Demo Award at CVPR 2022: VL-InterpreT: An Interactive Visualization Tool for Interpreting Vision-Language Transformers
19 | 
20 | # Screencast Video
21 | This video provides an overview of VL-InterpreT and demonstrates a few interesting examples.
22 | 
23 | 

24 | Video Demo 25 |

26 | 
27 | # Live Demo
28 | A live demo of the app (same as in the screencast video) is available here.
29 | 
30 | Please watch the screencast video to get a sense of how to navigate the app. This demo contains 100 examples from the Visual Commonsense Reasoning task and shows the attention and hidden representations from the KD-VLP model.
31 | 
32 | # Setup and Usage
33 | You may run VL-InterpreT together with a model of your own choice (see [*Set up a live model*](#set-up-a-live-model)), and/or with a database that contains data extracted from a model.
34 | 
35 | To run VL-InterpreT with our example databases, please first clone this repository and install the dependencies. For example:
36 | ```bash
37 | git clone https://github.com/IntelLabs/VL-InterpreT.git
38 | # create and activate your virtual environment if needed, then:
39 | cd VL-InterpreT
40 | pip install -r requirements.txt
41 | ```
42 | 
43 | Then you can run VL-InterpreT (replace 6006 with any port number you prefer):
44 | ```bash
45 | python run_app.py --port 6006 --database example_database2
46 | # alternatively:
47 | python run_app.py -p 6006 -d example_database2
48 | ```
49 | 
50 | We have included two example databases in this repository: `example_database1` contains grey "images" and randomly generated data, and `example_database2` contains one example image+text pair that was processed by the KD-VLP model.
51 | 
52 | Once the app is running, it will print the address at which it is being served. Open that address in your browser to use VL-InterpreT:
53 | 

55 | Screenshot for app launch 56 |

57 | 
58 | 
59 | ## Set up a database
60 | You may extract data from a transformer in a specific format, and then use VL-InterpreT to visualize it interactively.
61 | 
62 | To set up such a database, please see [db_example.py](), an example script that creates a database (i.e., `example_database1`) with randomly generated data. To prepare the data from your own transformer, for each image+text pair you should:
63 | - Extract cross-attention weights and hidden representation vectors from your transformer
64 |   - For example, you may extract them from [a Huggingface model](https://huggingface.co/transformers/v3.0.2/model_doc/bert.html#transformers.BertModel.forward) by specifying `output_attentions=True` and `output_hidden_states=True` (a minimal extraction sketch is shown further below).
65 | - Get the original input image as an array, and the input tokens as a list of strings (text tokens followed by image tokens, where image tokens can be named at your discretion, e.g., "img_0", "img_1", etc.).
66 | - For the input image tokens, specify how they correspond to positions in the original image, assuming the top-left corner has coordinates (0, 0).
67 | 
68 | Please refer to [db_example.py]() for more details. You may also look into our `example_database1` and `example_database2` for example databases that have already been preprocessed.
69 | 
70 | Once you have prepared the data as specified, organize it in the following format (again, see [db_example.py]() for the specifics):
71 | ```python
72 | data = [
73 |     { # first example
74 |         'ex_id': 0,
75 |         'image': np.array([]),
76 |         'tokens': [],
77 |         'txt_len': 0,
78 |         'img_coords': [],
79 |         'attention': np.array([]),
80 |         'hidden_states': np.array([])
81 |     },
82 |     { # second example
83 |         'ex_id': 1,
84 |         'image': np.array([]),
85 |         'tokens': [],
86 |         'txt_len': 0,
87 |         'img_coords': [],
88 |         'attention': np.array([]),
89 |         'hidden_states': np.array([])
90 |     },
91 |     # ...
92 | ]
93 | ```
94 | Then run the following code from the `app/database` directory to create and preprocess your database:
95 | ```python
96 | import pickle
97 | from database import VliLmdb  # this is in app/database/database.py
98 | 
99 | # create a database
100 | db = VliLmdb(db_dir='path_to_your_database', read_only=False)
101 | # add data
102 | for ex_data in data:
103 |     db[str(ex_data['ex_id'])] = pickle.dumps(ex_data, protocol=pickle.HIGHEST_PROTOCOL)
104 | # preprocess the database
105 | db.preprocess()
106 | ```
107 | Now you can visualize your data with VL-InterpreT:
108 | ```bash
109 | python run_app.py -p 6006 -d path_to_your_database
110 | ```
111 | 
112 | ## Set up a live model
113 | You may also run a live transformer model together with VL-InterpreT. This way, the `Add example` functionality becomes available on the web app -- users can add image+text pairs for the transformer to process in real time and for VL-InterpreT to visualize, as shown in the screenshot below.
114 | 
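Whether you precompute a database or plug in a live model, the key step is the same: a single forward pass that returns per-layer attention weights and hidden states. The following is a minimal sketch of that extraction and is not part of this repository -- it uses a plain text-only BERT from Huggingface `transformers` (a recent version is assumed) purely for illustration, with a placeholder model name and input sentence.

```python
# Hedged sketch: extracting attention weights and hidden states in the shapes
# VL-InterpreT expects. Substitute your own vision-language transformer here.
import torch
from transformers import BertModel, BertTokenizer

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertModel.from_pretrained('bert-base-uncased')
model.eval()

inputs = tokenizer('a man is riding a horse', return_tensors='pt')
with torch.no_grad():
    outputs = model(**inputs, output_attentions=True, output_hidden_states=True)

# outputs.attentions is a tuple of n_layers tensors, each (batch, n_heads, n_tokens, n_tokens)
attention = torch.stack(outputs.attentions).squeeze(1).numpy()
# outputs.hidden_states is a tuple of n_layers + 1 tensors (the embedding output plus
# the output of each layer), each (batch, n_tokens, hidden_size)
hidden_states = torch.stack(outputs.hidden_states).squeeze(1).numpy()

tokens = tokenizer.convert_ids_to_tokens(inputs['input_ids'][0].tolist())
print(tokens)               # text tokens; a VL model would append its image tokens here
print(attention.shape)      # (n_layers, n_heads, n_tokens, n_tokens)
print(hidden_states.shape)  # (n_layers + 1, n_tokens, hidden_size)
```

For a text-only model like this, `txt_len` would simply be `len(tokens)`; with a joint vision-language encoder, the same two flags also cover the image tokens, and you would additionally record the `image` array and the `img_coords` grid described above. These arrays make up one entry of the `data` list shown earlier, and the same extraction code can also live inside the `data_setup` method described below.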

116 | Screenshot for adding example 117 |

118 | 
119 | To add a model, you will need to define your own model class that inherits from the [VL_Model]() base class, and implement a `data_setup` function in this class. This function should run a forward pass with your model on the given image+text pair, and return the data (e.g., attention weights, hidden state vectors, etc.) in the required format. The returned data format is the same as the one specified in [*Set up a database*](#set-up-a-database) and in [db_example.py](). **Please see [vl_model.py]() for more details**. You may also refer to [kdvlp.py]() for an example model class.
120 | 
121 | Additionally, please follow the naming pattern below for your script and class, so that VL-InterpreT can find and load your model:
122 | - Create a new Python script in `app/database/models` for your model class, and name it in all lowercase (e.g., `app/database/models/yourmodelname.py`)
123 | - Name your model class in title case, e.g., `class Yourmodelname`. This class name should be the result of calling `'yourmodelname'.title()`, where `yourmodelname.py` is the name of your Python script.
124 | - For example, our KD-VLP model class is defined in `app/database/models/kdvlp.py`, and named `class Kdvlp`.
125 | 
126 | Once your implementation is complete, you can run VL-InterpreT with your model using:
127 | ```bash
128 | python run_app.py --port 6006 --database example_database2 --model yourmodelname your_model_parameters
129 | # alternatively:
130 | python run_app.py -p 6006 -d example_database2 -m yourmodelname your_model_parameters
131 | ```
132 | 
133 | Note that the `__init__` method of your model class may take an arbitrary number of parameters, and you may specify these parameters when running VL-InterpreT by putting them after the name of your model (i.e., replacing `your_model_parameters`).
134 | 
--------------------------------------------------------------------------------
/app/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IntelLabs/VL-InterpreT/0f9d0bb1dad13d91ce79d9453faf81fc2b9277b9/app/__init__.py
--------------------------------------------------------------------------------
/app/app_configuration.py:
--------------------------------------------------------------------------------
1 | ''' Module containing the main Dash app code.
2 | 
3 | The main function in this module is configure_app(), which starts
4 | the Dash app on a Flask server and configures it with callbacks.
5 | ''' 6 | 7 | import os 8 | import copy 9 | import dash 10 | from flask import Flask 11 | from dash.dependencies import Input, Output, State 12 | from dash.exceptions import PreventUpdate 13 | from termcolor import cprint 14 | from .app_layout import get_layout 15 | from .plot_func import * 16 | from app.database.db_analyzer import VliDataBaseAnalyzer 17 | 18 | 19 | def configure_app(db_dir, model=None): 20 | db = VliDataBaseAnalyzer(db_dir, read_only=(not model)) 21 | n_layers = db['n_layers'] 22 | app = start_app(db['n_examples'], n_layers) 23 | 24 | @app.callback( 25 | Output('add_ex_div', 'style'), 26 | Output('add_ex_toggle', 'style'), 27 | Input('add_ex_toggle', 'n_clicks'), 28 | ) 29 | def toggle_add_example(n_clicks): 30 | if model is None: 31 | return {'display': 'none'}, {'display': 'none'} 32 | if not n_clicks: 33 | raise PreventUpdate 34 | if n_clicks % 2 == 0: 35 | return {'display': 'none'}, {'padding': '0 15px'} 36 | return {}, {'padding': '0 15px', 'background': 'lightgrey'} 37 | 38 | 39 | @app.callback( 40 | Output('ex_selector', 'max'), 41 | Output('ex_selector', 'value'), 42 | Input('add_ex_btn', 'n_clicks'), 43 | State('new_ex_img', 'value'), 44 | State('new_ex_txt', 'value'), 45 | ) 46 | def add_example(n_clicks, image_in, text_in): 47 | if not n_clicks: # app initialzation 48 | return db['n_examples'] - 1, '0' 49 | ex_id = db['n_examples'] 50 | data = model.data_setup(ex_id, image_in, text_in) 51 | db.add_example(ex_id, data) 52 | return ex_id, ex_id 53 | 54 | 55 | @app.callback( 56 | Output('ex_text', 'figure'), 57 | Output('ex_txt_len', 'data'), 58 | Output('ex_sm_text', 'figure'), 59 | Output('ex_img_attn_overlay', 'figure'), 60 | [Input('ex_selector', 'value'), 61 | Input('selected_head_layer', 'data'), 62 | Input('selected_text_token', 'data'), 63 | Input('selected_img_tokens', 'data'), 64 | Input('map_smooth_checkbox', 'value')], 65 | [State('ex_txt_len', 'data')] 66 | ) 67 | def display_ex_text_and_img_overlay(ex_id, head_layer, text_token, selected_img_token, smooth, 68 | txt_len): 69 | trigger = get_input_trigger(dash.callback_context) 70 | layer, head = head_layer['layer'], head_layer['head'] 71 | new_state = (layer < 0 or head < 0) or (not txt_len) 72 | if (not selected_img_token) and \ 73 | (not text_token or text_token['sentence'] < 0 or text_token['word'] < 0): 74 | new_state = True 75 | 76 | if new_state or trigger == 'ex_selector': 77 | text_fig, txt_len = plot_text(db, ex_id) 78 | text_attn_fig = text_fig 79 | img_attn_overlay = get_empty_fig() 80 | return text_fig, txt_len, text_attn_fig, img_attn_overlay 81 | 82 | img_token = set(map(tuple, selected_img_token)) 83 | text_fig = highlight_txt_token(db, ex_id, text_token) 84 | if trigger == 'selected_text_token': 85 | img_token = None 86 | elif trigger == 'selected_img_tokens': 87 | text_token = None 88 | text_fig, txt_len = plot_text(db, ex_id) 89 | 90 | text_attn_fig = plot_attn_from_txt(db, ex_id, layer, head, text_token, img_token) 91 | img_attn_overlay = plot_attn_from_img(db, ex_id, layer, head, word_ids=text_token, 92 | img_coords=img_token, smooth=smooth) 93 | 94 | return text_fig, txt_len, text_attn_fig, img_attn_overlay 95 | 96 | 97 | @app.callback( 98 | Output('ex_img_token_overlay', 'figure'), 99 | [Input('ex_selector', 'value'), 100 | Input('selected_img_tokens', 'data')] 101 | ) 102 | def display_img_token_selection(ex_id, selected_img_tokens): 103 | if not ex_id: 104 | return get_empty_fig() 105 | trigger = get_input_trigger(dash.callback_context) 106 | if trigger == 'ex_selector': 107 
| token_overlay = get_overlay_fig(db, ex_id) 108 | else: 109 | token_overlay = highlight_img_grid_selection(db, ex_id, selected_img_tokens) 110 | return token_overlay 111 | 112 | 113 | @app.callback( 114 | Output('attn2img', 'style'), 115 | Output('attn2txt', 'style'), 116 | Output('attn2img_label', 'style'), 117 | Output('attn2txt_label', 'style'), 118 | Output('play_btn', 'style'), 119 | [Input('attn2img_toggle', 'value')] 120 | ) 121 | def toggle_attn_to(attn_toggle): 122 | hidden_style = {'display': 'none', 'z-index': '-1'} 123 | if attn_toggle: # hide image 124 | return hidden_style, {}, {'color': 'lightgrey'}, {}, {} 125 | else: # hide text 126 | return {}, hidden_style, {}, {'color': 'lightgrey'}, hidden_style 127 | 128 | 129 | @app.callback( 130 | Output('ex_image', 'figure'), 131 | Output('ex_sm_image', 'figure'), 132 | [Input('ex_selector', 'value')] 133 | ) 134 | def display_ex_image(ex_id): 135 | image_fig = plot_image(db, ex_id) 136 | sm_image_fig = copy.deepcopy(image_fig) 137 | sm_image_fig.update_layout(height=256) 138 | return image_fig, sm_image_fig 139 | 140 | 141 | @app.callback( 142 | Output('selected_img_tokens', 'data'), 143 | [Input('ex_img_token_overlay', 'clickData'), 144 | Input('ex_img_token_overlay', 'selectedData'), 145 | Input('selected_text_token', 'data'), 146 | Input('attn2img_toggle', 'value')], # clear if toggle attention to 147 | [State('selected_img_tokens', 'data')] 148 | ) 149 | def save_img_selection(click_data, selected_data, txt_token, toggle, points): 150 | if points is None: 151 | return [] 152 | if (not click_data) and (not selected_data): 153 | return [] 154 | trigger = get_input_trigger_full(dash.callback_context).split('.')[1] 155 | points = set(map(tuple, points)) 156 | if trigger == 'clickData': 157 | x = click_data['points'][0]['pointNumber'] 158 | y = click_data['points'][0]['curveNumber'] 159 | if (x, y) in points: 160 | points.remove((x, y)) 161 | else: 162 | # points.add((x, y)) 163 | points = {(x, y)} # TODO: enable a set of multiple points and show average 164 | elif trigger == 'selectedData': 165 | if selected_data: 166 | points.update([(pt['pointNumber'], pt['curveNumber']) for pt in selected_data['points']]) 167 | else: 168 | return [] 169 | else: 170 | return [] 171 | return list(points) # list is JSON serializable 172 | 173 | 174 | @app.callback( 175 | Output('selected_text_token', 'data'), 176 | [Input('ex_text', 'clickData'), # from mouse click 177 | Input('movie_progress', 'data'), # from movie slider update 178 | Input('ex_selector', 'value'), # clear if new example selected 179 | Input('attn2img_toggle', 'value')], # clear if toggle attention to 180 | [State('auto_stepper', 'disabled')] 181 | ) 182 | def save_text_click(click_data, movie_progress, ex_id, toggle, movie_disabled): 183 | trigger = get_input_trigger(dash.callback_context) 184 | if trigger == 'ex_text' and click_data: 185 | sentence_id = click_data['points'][0]['curveNumber'] 186 | word_id = click_data['points'][0]['pointIndex'] 187 | return {'sentence': sentence_id, 'word': word_id} 188 | if trigger == 'movie_progress': 189 | if movie_disabled or movie_progress == -1: # no change 190 | raise PreventUpdate() 191 | # convert index to sentence/word 192 | sentence, word = 0, 0 193 | txt_len = db[ex_id]['txt_len'] 194 | for i, token in enumerate(db[ex_id]['tokens'][:txt_len]): 195 | if i == movie_progress: 196 | break 197 | if token == '[SEP]': 198 | sentence += 1 199 | word = -1 200 | word += 1 201 | return {'sentence': sentence, 'word': word} 202 | return None 
203 | 204 | 205 | @app.callback( 206 | Output('head_summary', 'figure'), 207 | [Input('ex_selector', 'value'), 208 | Input('selected_head_layer', 'data')], 209 | [State('head_summary', 'figure')] 210 | ) 211 | def display_head_summary(ex_id, head_layer, fig): 212 | layer, head = head_layer['layer'], head_layer['head'] 213 | attn_type_index = get_active_figure_data(fig) 214 | data = plot_head_summary(db, ex_id, attn_type_index, layer, head) 215 | return data 216 | 217 | 218 | @app.callback( 219 | Output('selected_head_layer', 'data'), 220 | [Input('head_summary', 'clickData')] 221 | ) 222 | def save_head_summary_click(click_data): 223 | if click_data: 224 | x, y = get_click_coords(click_data) 225 | return {'layer': y, 'head': x} 226 | return {'layer': 10, 'head': 7} 227 | 228 | 229 | @app.callback( 230 | Output('tsne_title', 'children'), 231 | Output('attn_title', 'children'), 232 | Output('accuracy', 'children'), 233 | [Input('layer_slider', 'value'), 234 | Input('selected_head_layer', 'data'), 235 | Input('ex_selector', 'value')] 236 | ) 237 | def update_titles(tsne_layer, selected_head_layer, ex_id): 238 | if ('accuracy' not in db[ex_id]) or (db[ex_id]['accuracy'] is None): 239 | acc_text = 'unknown' 240 | else: 241 | acc_text = 'correct' if db[ex_id]['accuracy'] else 'incorrect' 242 | if tsne_layer == n_layers: 243 | tsne_title = 't-SNE Embeddings before Layer #1' 244 | else: 245 | tsne_title = f't-SNE Embeddings after Layer #{n_layers - tsne_layer}' 246 | layer, head = selected_head_layer['layer'], selected_head_layer['head'] 247 | head = 'Average' if int(head) != head else f'#{head + 1}' 248 | att_title = f'Attention in Layer #{layer+1}, Head {head}' 249 | return tsne_title, att_title, acc_text 250 | 251 | 252 | # movie callbacks 253 | 254 | @app.callback( 255 | Output('movie_progress', 'data'), 256 | [Input('auto_stepper', 'n_intervals'), 257 | Input('auto_stepper', 'disabled'), 258 | Input('ex_selector', 'value')], 259 | [State('movie_progress', 'data'), 260 | State('ex_txt_len', 'data')] 261 | ) 262 | def stepper_advance(n_intervals, disable, ex_id, movie_progress, txt_len): 263 | trigger = get_input_trigger_full(dash.callback_context) 264 | if trigger.startswith('ex_selector') or \ 265 | (trigger == 'auto_stepper.disabled' and disable): 266 | return -1 267 | if n_intervals: 268 | return (movie_progress + 1) % txt_len 269 | return -1 270 | 271 | @app.callback( 272 | Output('play_btn', 'children'), 273 | Output('play_btn', 'n_clicks'), 274 | Output('auto_stepper', 'disabled'), 275 | [Input('play_btn', 'n_clicks'), 276 | Input('ex_selector', 'value'), 277 | Input('attn2img_toggle', 'value')], 278 | [State('auto_stepper', 'n_intervals')] 279 | ) 280 | def play_btn_click(n_clicks, ex_id, attn_to_toggle, n_intervals): 281 | trigger = get_input_trigger(dash.callback_context) 282 | if (trigger == 'play_btn') and (n_intervals is not None) and \ 283 | (n_clicks is not None) and n_clicks % 2 == 1: 284 | return '\u275A\u275A', n_clicks, False 285 | elif (trigger == 'ex_selector') or (trigger == 'attn2img_toggle'): 286 | return '\u25B6', 0, True 287 | return '\u25B6', n_clicks, True 288 | 289 | 290 | # TSNE 291 | 292 | @app.callback( 293 | Output('tsne_map', 'figure'), 294 | [Input('ex_selector', 'value'), 295 | Input('layer_slider', 'value'), 296 | Input('selected_text_token', 'data'), 297 | Input('selected_img_tokens', 'data')], 298 | [State('auto_stepper', 'disabled')] 299 | ) 300 | def plot_tsne(ex_id, layer, text_tokens_id, img_tokens, movie_paused): 301 | if not movie_paused: 302 | raise 
PreventUpdate 303 | figure = show_tsne(db, ex_id, n_layers - layer, text_tokens_id, img_tokens) 304 | return figure 305 | 306 | @app.callback( 307 | Output('hover_out', 'show'), 308 | Output('hover_out', 'bbox'), 309 | Output('hover_out', 'children'), 310 | Input('tsne_map', 'hoverData'), 311 | Input('layer_slider', 'value') 312 | ) 313 | def display_hover(hover_data, layer): 314 | if hover_data is None: 315 | return False, dash.no_update, dash.no_update 316 | layer = n_layers - layer 317 | point = hover_data['points'][0] 318 | ex_id, token_id = point['hovertext'].split(',') 319 | tooltip_children = plot_tooltip_content(db, ex_id, int(token_id)) 320 | return True, point['bbox'], tooltip_children 321 | 322 | 323 | return app 324 | 325 | 326 | def start_app(n_examples, n_layers): 327 | print('Starting server') 328 | server = Flask(__name__) 329 | server.secret_key = os.environ.get('secret_key', 'secret') 330 | 331 | app = dash.Dash(__name__, server=server, url_base_pathname='/') 332 | app.title = 'VL-InterpreT' 333 | app.layout = get_layout(n_examples, n_layers) 334 | return app 335 | 336 | def print_page_link(hostname, port): 337 | print('\n\n') 338 | cprint('------------------------' '------------------------', 'green') 339 | cprint('App Launched!', 'red') 340 | cprint('------------------------' '------------------------', 'green') 341 | 342 | def get_click_coords(click_data): 343 | return click_data['points'][0]['x'], click_data['points'][0]['y'] 344 | 345 | def get_input_trigger(ctx): 346 | return ctx.triggered[0]['prop_id'].split('.')[0] 347 | 348 | def get_input_trigger_full(ctx): 349 | return ctx.triggered[0]['prop_id'] 350 | 351 | def get_active_figure_data(fig): 352 | return fig['layout']['updatemenus'][0]['active'] if fig else 0 353 | -------------------------------------------------------------------------------- /app/app_layout.py: -------------------------------------------------------------------------------- 1 | """ Module specifying Dash app UI layout 2 | 3 | The main function here is the get_layout() function which returns 4 | the Dash/HTML layout for InterpreT. 
5 | """ 6 | 7 | from dash import dcc 8 | from dash import html 9 | import dash_daq as daq 10 | import base64 11 | import os 12 | 13 | intel_dark_blue = "#0168b5" 14 | intel_light_blue = "#04c7fd" 15 | 16 | 17 | def get_layout(n_examples, n_layers): 18 | logoConfig = dict(displaylogo=False, modeBarButtonsToRemove=["sendDataToCloud"]) 19 | image_filename = os.path.join("app", "assets", "intel_ai_logo.jpg") 20 | encoded_image = base64.b64encode(open(image_filename, "rb").read()) 21 | 22 | layout = html.Div( 23 | [ 24 | # Stored values 25 | dcc.Store(id="selected_head_layer"), 26 | dcc.Store(id="selected_text_token", data={'sentence': -1, 'word': -1}), 27 | dcc.Store(id="selected_img_tokens"), 28 | dcc.Store(id="selected_token_from_matrix"), 29 | dcc.Store(id="movie_progress", data=0), 30 | dcc.Store(id="ex_txt_len"), 31 | 32 | 33 | # Header 34 | html.Div( 35 | [ 36 | html.Div( 37 | [ 38 | html.Img( # Intel logo 39 | src=f"data:image/png;base64,{encoded_image.decode()}", 40 | style={ 41 | "display": "inline", 42 | "height": str(174 * 0.18) + "px", 43 | "width": str(600 * 0.18) + "px", 44 | "position": "relative", 45 | "padding-right": "30px", 46 | "vertical-align": "middle", 47 | }, 48 | ), 49 | html.Div([ 50 | html.Span("VL-Interpre", style={'color': intel_dark_blue}), 51 | html.Span("T", style={'color': intel_light_blue}) 52 | ], style={"font-size": "40px", "display": "inline", "vertical-align": "middle"} 53 | ), 54 | html.Div( 55 | children=[ 56 | html.Span("An Interactive Visualization Tool for "), 57 | html.Strong("Interpre", style={'color': intel_dark_blue}), 58 | html.Span("ting "), 59 | html.Strong("V", style={'color': intel_dark_blue}), 60 | html.Span("ision-"), 61 | html.Strong("L", style={'color': intel_dark_blue}), 62 | html.Span("anguage "), 63 | html.Strong("T", style={'color': intel_light_blue}), 64 | html.Span("ransformers") 65 | ], 66 | style={"font-size": "25px"} 67 | ), 68 | ], 69 | ) 70 | ], 71 | style={ "text-align": "center", "margin": "2%"}, 72 | ), 73 | 74 | 75 | # Example selector 76 | html.Div( 77 | [html.Label("Example ID", className="plot-label", style={"margin-right": "10px"}), 78 | dcc.Input(id="ex_selector", type="number", value=0, min=0, max=n_examples-1), 79 | html.Span(" - Model prediction ", className="plot-label", style={"font-weight": "300"}), 80 | html.Label("", id="accuracy", className="plot-label", style={"margin-right": "70px"}), 81 | html.Button('Add example', id='add_ex_toggle', style={"padding": "0 15px"}), 82 | 83 | # Add example 84 | html.Div([ 85 | html.Div( 86 | [html.Div([ 87 | html.Label("Image", className="plot-label", style={"margin-right": "10px"}), 88 | dcc.Input( 89 | id="new_ex_img", 90 | placeholder="Enter URL or path", 91 | style={"width": "80%", "max-width": "600px", "margin": "5px"}) 92 | ]), html.Div([ 93 | html.Label("Text", className="plot-label", style={"margin-right": "10px"}), 94 | dcc.Input(id="new_ex_txt", 95 | placeholder="Enter text", 96 | style={"width": "80%", "max-width": "600px", "margin": "5px"}) 97 | ]), 98 | html.Button("Add", id="add_ex_btn", 99 | style={"padding": "0 15px", "margin": "5px", "background": "white"})], 100 | style={"margin": "5px 10%", "padding": "5px", "background": "#f1f1f1"})], 101 | id="add_ex_div", 102 | style={"display": "none"}) 103 | ], 104 | style={"text-align": "center", "margin": "2%"}, 105 | ), 106 | 107 | 108 | # Head summary matrix 109 | html.Div( 110 | [html.Label("Attention Head Summary", className="plot-label"), 111 | dcc.Graph(id="head_summary", config={"displayModeBar": False})], 112 
| style={"display": "inline-block", "margin-right": "5%", "vertical-align": "top"} 113 | ), 114 | 115 | 116 | # TSNE 117 | html.Div( 118 | [ 119 | html.Label("t-SNE Embeddings", id="tsne_title", className="plot-label"), 120 | html.Div([ 121 | dcc.Graph(id="tsne_map", config=logoConfig, clear_on_unhover=True), 122 | dcc.Tooltip(id="hover_out", direction='bottom') 123 | ], style={"display": "inline-block", "width": "85%"}), 124 | html.Div([ 125 | html.Label("Layer", style={}), 126 | html.Div([ 127 | dcc.Slider( 128 | id="layer_slider", min=0, max=n_layers, step=1, 129 | value=0, included=False, vertical=True, 130 | marks={i: { 131 | "label": str(n_layers - i), 132 | "style": {"margin": "0 0 -10px -41px" if (i < n_layers - 9) else "0 0 -10px -35px"} 133 | } for i in range(n_layers + 1)} 134 | ) 135 | ], style={"margin": "-10px -25px -25px 24px", }) 136 | ], style={"display": "inline-block", "vertical-align": "top", 137 | "width": "44px", "height": "450px", "margin": "-10px 0 0 10px", 138 | "border": "1px solid lightgray", "border-radius": "5px"}), 139 | ], 140 | style={"width": "45%", "display": "inline-block", "vertical-align": "top"} 141 | ), 142 | 143 | html.Div( 144 | [html.Label("Attention", id="attn_title", className="plot-label"), 145 | html.Hr(style={"margin": "0 15%"})], 146 | style={"margin": "3% 0 5px 0"}), 147 | 148 | 149 | # Attention to 150 | html.Div( 151 | [html.Div( 152 | [html.Label("Attention to:", className="plot-label", style={"margin-right": "10px"}), 153 | html.Label("Image", id="attn2img_label", className="plot-label"), 154 | html.Div( 155 | daq.ToggleSwitch(id="attn2img_toggle", value=False, size=40), 156 | style={"display": "inline-block", "margin": "0 10px", "vertical-align": "top"} 157 | ), 158 | html.Label("Text", id="attn2txt_label", className="plot-label"), 159 | html.Button("\u25B6", id="play_btn", n_clicks=0, style={"display": "none"})], 160 | style={"height": "34px", "width": "400px", "text-align": "left", 161 | "display": "inline-block", "padding-left": "100px"} 162 | ), 163 | html.Div(id="attn2img", children=[ # image 164 | html.Div( 165 | dcc.Graph(id="ex_image", config={"displayModeBar": False}, style={"height": "450px"}), 166 | style={"width": "40%", "position": "absolute"}), 167 | html.Div( 168 | [dcc.Graph(id="ex_img_token_overlay", config={"displayModeBar": False})], 169 | style={"width": "40%", "position": "absolute", "height": "450px"}), 170 | ]), 171 | html.Div(id="attn2txt", children=[ # text 172 | html.Div( 173 | [dcc.Graph(id="ex_text", config={"displayModeBar": False})], 174 | style={"width": "40%", "position": "absolute", "height": "450px"}) 175 | ], style={"display": "none", "z-index": "-1"}) 176 | ], 177 | style={"width": "40%", "display": "inline-block", "vertical-align": "top"} 178 | ), 179 | 180 | 181 | # Attention from 182 | html.Div( 183 | [html.Div([html.Label("Attention from Text", className="plot-label", style={"height": "26px"})]), 184 | html.Div( # text display area 185 | [dcc.Graph(id="ex_sm_text", config={"displayModeBar": False}, style={"height": "174px"})] 186 | ), 187 | html.Div([ # label 188 | html.Hr(style={"height": "1px", "margin": "0"}), 189 | html.Label("Attention from Image", className="plot-label", style={"height": "26px"}) 190 | ]), 191 | html.Div( # image display area 192 | [dcc.Graph(id="ex_sm_image", config={"displayModeBar": False}, 193 | style={"width": "37%", "height": "256px", "margin-left": "4%", "position": "absolute"}), 194 | dcc.Graph(id="ex_img_attn_overlay", config={"displayModeBar": False}, 195 | 
style={"width": "37%", "height": "256px", "margin-left": "4%", "position": "absolute"})] 196 | )], 197 | style={"height": "484px", "width": "45%", "display": "inline-block", "margin-left": "15px", "vertical-align": "top"}, 198 | ), 199 | # Movie control (hidden) 200 | html.Div( 201 | [dcc.Interval(id="auto_stepper", interval=2500, n_intervals=0, disabled=True)], 202 | style={"display": "none"} 203 | ), 204 | 205 | # Checkbox for smoothing 206 | html.Div( 207 | [dcc.Checklist( 208 | id="map_smooth_checkbox", 209 | options=[{'label': 'Smooth attention map', 'value': 'smooth'}], 210 | value=['smooth'], 211 | style={"text-align": "center"})], 212 | style={"width": "45%", "display": "inline-block", "margin": "5px 0 2% 0", "padding-left": "40%"} 213 | ), 214 | ], 215 | style={"textAlign": "center"} 216 | ) 217 | 218 | return layout 219 | -------------------------------------------------------------------------------- /app/assets/intel_ai_logo.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IntelLabs/VL-InterpreT/0f9d0bb1dad13d91ce79d9453faf81fc2b9277b9/app/assets/intel_ai_logo.jpg -------------------------------------------------------------------------------- /app/assets/stylesheet-oil-and-gas.css: -------------------------------------------------------------------------------- 1 | /* Table of contents 2 | –––––––––––––––––––––––––––––––––––––––––––––––––– 3 | - Grid 4 | - Base Styles 5 | - Typography 6 | - Links 7 | - Buttons 8 | - Forms 9 | - Lists 10 | - Code 11 | - Tables 12 | - Spacing 13 | - Utilities 14 | - Clearing 15 | - Media Queries 16 | */ 17 | 18 | 19 | /* Grid 20 | –––––––––––––––––––––––––––––––––––––––––––––––––– */ 21 | .container { 22 | position: relative; 23 | width: 100%; 24 | max-width: 960px; 25 | margin: 0 auto; 26 | padding: 0 20px; 27 | box-sizing: border-box; } 28 | .column, 29 | .columns { 30 | width: 100%; 31 | float: left; 32 | box-sizing: border-box; } 33 | 34 | /* For devices larger than 400px */ 35 | @media (min-width: 400px) { 36 | .container { 37 | width: 85%; 38 | padding: 0; } 39 | } 40 | 41 | /* For devices larger than 550px */ 42 | @media (min-width: 550px) { 43 | .container { 44 | width: 80%; } 45 | .column, 46 | .columns { 47 | margin-left: 0.5%; } 48 | .column:first-child, 49 | .columns:first-child { 50 | margin-left: 0; } 51 | 52 | .one.column, 53 | .one.columns { width: 8%; } 54 | .two.columns { width: 16.25%; } 55 | .three.columns { width: 22%; } 56 | .four.columns { width: 33%; } 57 | .five.columns { width: 39.3333333333%; } 58 | .six.columns { width: 49.75%; } 59 | .seven.columns { width: 56.6666666667%; } 60 | .eight.columns { width: 66.5%; } 61 | .nine.columns { width: 74.0%; } 62 | .ten.columns { width: 82.6666666667%; } 63 | .eleven.columns { width: 91.5%; } 64 | .twelve.columns { width: 100%; margin-left: 0; } 65 | 66 | .one-third.column { width: 30.6666666667%; } 67 | .two-thirds.column { width: 65.3333333333%; } 68 | 69 | .one-half.column { width: 48%; } 70 | 71 | /* Offsets */ 72 | .offset-by-one.column, 73 | .offset-by-one.columns { margin-left: 8.66666666667%; } 74 | .offset-by-two.column, 75 | .offset-by-two.columns { margin-left: 17.3333333333%; } 76 | .offset-by-three.column, 77 | .offset-by-three.columns { margin-left: 26%; } 78 | .offset-by-four.column, 79 | .offset-by-four.columns { margin-left: 34.6666666667%; } 80 | .offset-by-five.column, 81 | .offset-by-five.columns { margin-left: 43.3333333333%; } 82 | .offset-by-six.column, 83 | .offset-by-six.columns { 
margin-left: 52%; } 84 | .offset-by-seven.column, 85 | .offset-by-seven.columns { margin-left: 60.6666666667%; } 86 | .offset-by-eight.column, 87 | .offset-by-eight.columns { margin-left: 69.3333333333%; } 88 | .offset-by-nine.column, 89 | .offset-by-nine.columns { margin-left: 78.0%; } 90 | .offset-by-ten.column, 91 | .offset-by-ten.columns { margin-left: 86.6666666667%; } 92 | .offset-by-eleven.column, 93 | .offset-by-eleven.columns { margin-left: 95.3333333333%; } 94 | 95 | .offset-by-one-third.column, 96 | .offset-by-one-third.columns { margin-left: 34.6666666667%; } 97 | .offset-by-two-thirds.column, 98 | .offset-by-two-thirds.columns { margin-left: 69.3333333333%; } 99 | 100 | .offset-by-one-half.column, 101 | .offset-by-one-half.columns { margin-left: 52%; } 102 | 103 | } 104 | 105 | 106 | /* Base Styles 107 | –––––––––––––––––––––––––––––––––––––––––––––––––– */ 108 | /* NOTE 109 | html is set to 62.5% so that all the REM measurements throughout Skeleton 110 | are based on 10px sizing. So basically 1.5rem = 15px :) */ 111 | html { 112 | font-size: 62.5%; } 113 | body { 114 | font-size: 1.5em; /* currently ems cause chrome bug misinterpreting rems on body element */ 115 | line-height: 1.6; 116 | font-weight: 400; 117 | font-family: "Open Sans", "HelveticaNeue", "Helvetica Neue", Helvetica, Arial, sans-serif; 118 | color: rgb(50, 50, 50); 119 | margin: 0; 120 | } 121 | 122 | 123 | /* Typography 124 | –––––––––––––––––––––––––––––––––––––––––––––––––– */ 125 | h1, h2, h3, h4, h5, h6 { 126 | margin-top: 0; 127 | margin-bottom: 0; 128 | font-weight: 300; } 129 | h1 { font-size: 4.5rem; line-height: 1.2; letter-spacing: -.1rem; margin-bottom: 2rem; } 130 | h2 { font-size: 3.6rem; line-height: 1.25; letter-spacing: -.1rem; margin-bottom: 1.8rem; margin-top: 1.8rem;} 131 | h3 { font-size: 3.0rem; line-height: 1.3; letter-spacing: -.1rem; margin-bottom: 1.5rem; margin-top: 1.5rem;} 132 | h4 { font-size: 2.6rem; line-height: 1.35; letter-spacing: -.08rem; margin-bottom: 1.2rem; margin-top: 1.2rem;} 133 | h5 { font-size: 2.2rem; line-height: 1.5; letter-spacing: -.05rem; margin-bottom: 0.6rem; margin-top: 0.6rem;} 134 | h6 { font-size: 2.0rem; line-height: 1.6; letter-spacing: 0; margin-bottom: 0.75rem; margin-top: 0.75rem;} 135 | 136 | p { 137 | margin-top: 0; } 138 | 139 | 140 | /* Blockquotes 141 | –––––––––––––––––––––––––––––––––––––––––––––––––– */ 142 | blockquote { 143 | border-left: 4px lightgrey solid; 144 | padding-left: 1rem; 145 | margin-top: 2rem; 146 | margin-bottom: 2rem; 147 | margin-left: 0rem; 148 | } 149 | 150 | 151 | /* Links 152 | –––––––––––––––––––––––––––––––––––––––––––––––––– */ 153 | a { 154 | color: #1EAEDB; } 155 | a:hover { 156 | color: #0FA0CE; } 157 | 158 | 159 | /* Buttons 160 | –––––––––––––––––––––––––––––––––––––––––––––––––– */ 161 | .button, 162 | button, 163 | input[type="submit"], 164 | input[type="reset"], 165 | input[type="button"] { 166 | display: inline-block; 167 | height: 38px; 168 | padding: 0 30px; 169 | color: #555; 170 | text-align: center; 171 | font-size: 11px; 172 | font-weight: 600; 173 | line-height: 38px; 174 | letter-spacing: .1rem; 175 | text-transform: uppercase; 176 | text-decoration: none; 177 | white-space: nowrap; 178 | background-color: transparent; 179 | border-radius: 4px; 180 | border: 1px solid #bbb; 181 | cursor: pointer; 182 | box-sizing: border-box; } 183 | .button:hover, 184 | button:hover, 185 | input[type="submit"]:hover, 186 | input[type="reset"]:hover, 187 | input[type="button"]:hover, 188 | .button:focus, 189 | 
button:focus, 190 | input[type="submit"]:focus, 191 | input[type="reset"]:focus, 192 | input[type="button"]:focus { 193 | color: #333; 194 | border-color: #888; 195 | outline: 0; } 196 | .button.button-primary, 197 | button.button-primary, 198 | input[type="submit"].button-primary, 199 | input[type="reset"].button-primary, 200 | input[type="button"].button-primary { 201 | color: #FFF; 202 | background-color: #33C3F0; 203 | border-color: #33C3F0; } 204 | .button.button-primary:hover, 205 | button.button-primary:hover, 206 | input[type="submit"].button-primary:hover, 207 | input[type="reset"].button-primary:hover, 208 | input[type="button"].button-primary:hover, 209 | .button.button-primary:focus, 210 | button.button-primary:focus, 211 | input[type="submit"].button-primary:focus, 212 | input[type="reset"].button-primary:focus, 213 | input[type="button"].button-primary:focus { 214 | color: #FFF; 215 | background-color: #1EAEDB; 216 | border-color: #1EAEDB; } 217 | 218 | #play_btn { 219 | height: 24px; 220 | display: "inline-block"; 221 | margin-left: 20px; 222 | line-height: 24px; 223 | padding: 0 25px; 224 | } 225 | 226 | /* Forms 227 | –––––––––––––––––––––––––––––––––––––––––––––––––– */ 228 | input[type="email"], 229 | input[type="number"], 230 | input[type="search"], 231 | input[type="text"], 232 | input[type="tel"], 233 | input[type="url"], 234 | input[type="password"], 235 | textarea, 236 | select { 237 | height: 38px; 238 | padding: 6px 10px; /* The 6px vertically centers text on FF, ignored by Webkit */ 239 | background-color: #fff; 240 | border: 1px solid #D1D1D1; 241 | border-radius: 4px; 242 | box-shadow: none; 243 | box-sizing: border-box; 244 | font-family: inherit; 245 | font-size: inherit; /*https://stackoverflow.com/questions/6080413/why-doesnt-input-inherit-the-font-from-body*/} 246 | /* Removes awkward default styles on some inputs for iOS */ 247 | input[type="email"], 248 | input[type="number"], 249 | input[type="search"], 250 | input[type="text"], 251 | input[type="tel"], 252 | input[type="url"], 253 | input[type="password"], 254 | textarea { 255 | -webkit-appearance: none; 256 | -moz-appearance: none; 257 | appearance: none; } 258 | textarea { 259 | min-height: 65px; 260 | padding-top: 6px; 261 | padding-bottom: 6px; } 262 | input[type="email"]:focus, 263 | input[type="number"]:focus, 264 | input[type="search"]:focus, 265 | input[type="text"]:focus, 266 | input[type="tel"]:focus, 267 | input[type="url"]:focus, 268 | input[type="password"]:focus, 269 | textarea:focus, 270 | select:focus { 271 | border: 1px solid #33C3F0; 272 | outline: 0; } 273 | label, 274 | legend { 275 | margin-bottom: 0px; } 276 | fieldset { 277 | padding: 0; 278 | border-width: 0; } 279 | input[type="checkbox"], 280 | input[type="radio"] { 281 | display: inline; } 282 | label > .label-body { 283 | display: inline-block; 284 | margin-left: .5rem; 285 | font-weight: normal; } 286 | .plot-label { 287 | font-family: "Trebuchet MS", "verdana"; 288 | font-size: 16px; 289 | justify-content: "center"; 290 | font-weight: 600; 291 | display: "inline-block"; 292 | } 293 | 294 | /* Lists 295 | –––––––––––––––––––––––––––––––––––––––––––––––––– */ 296 | ul { 297 | list-style: circle inside; } 298 | ol { 299 | list-style: decimal inside; } 300 | ol, ul { 301 | padding-left: 0; 302 | margin-top: 0; } 303 | ul ul, 304 | ul ol, 305 | ol ol, 306 | ol ul { 307 | margin: 1.5rem 0 1.5rem 3rem; 308 | font-size: 90%; } 309 | li { 310 | margin-bottom: 1rem; } 311 | 312 | 313 | /* Tables 314 | 
–––––––––––––––––––––––––––––––––––––––––––––––––– */ 315 | th, 316 | td { 317 | padding: 12px 15px; 318 | text-align: left; 319 | border-bottom: 1px solid #E1E1E1; } 320 | th:first-child, 321 | td:first-child { 322 | padding-left: 0; } 323 | th:last-child, 324 | td:last-child { 325 | padding-right: 0; } 326 | 327 | 328 | /* Spacing 329 | –––––––––––––––––––––––––––––––––––––––––––––––––– */ 330 | button, 331 | .button { 332 | margin-bottom: 0rem; } 333 | input, 334 | textarea, 335 | select, 336 | fieldset { 337 | margin-bottom: 0rem; } 338 | pre, 339 | dl, 340 | figure, 341 | table, 342 | form { 343 | margin-bottom: 0rem; } 344 | p, 345 | ul, 346 | ol { 347 | margin-bottom: 0.75rem; } 348 | 349 | /* Utilities 350 | –––––––––––––––––––––––––––––––––––––––––––––––––– */ 351 | .u-full-width { 352 | width: 100%; 353 | box-sizing: border-box; } 354 | .u-max-full-width { 355 | max-width: 100%; 356 | box-sizing: border-box; } 357 | .u-pull-right { 358 | float: right; } 359 | .u-pull-left { 360 | float: left; } 361 | 362 | 363 | /* Misc 364 | –––––––––––––––––––––––––––––––––––––––––––––––––– */ 365 | hr { 366 | margin-top: 3rem; 367 | margin-bottom: 3.5rem; 368 | border-width: 0; 369 | border-top: 1px solid #E1E1E1; } 370 | 371 | 372 | /* Clearing 373 | –––––––––––––––––––––––––––––––––––––––––––––––––– */ 374 | 375 | /* Self Clearing Goodness */ 376 | .container:after, 377 | .row:after, 378 | .u-cf { 379 | content: ""; 380 | display: table; 381 | clear: both; } 382 | 383 | 384 | /* Media Queries 385 | –––––––––––––––––––––––––––––––––––––––––––––––––– */ 386 | /* 387 | Note: The best way to structure the use of media queries is to create the queries 388 | near the relevant code. For example, if you wanted to change the styles for buttons 389 | on small devices, paste the mobile query code up in the buttons section and style it 390 | there. 
391 | */ 392 | 393 | 394 | /* Larger than mobile */ 395 | @media (min-width: 400px) {} 396 | 397 | /* Larger than phablet (also point when grid becomes active) */ 398 | @media (min-width: 550px) {} 399 | 400 | /* Larger than tablet */ 401 | @media (min-width: 750px) {} 402 | 403 | /* Larger than desktop */ 404 | @media (min-width: 1000px) {} 405 | 406 | /* Larger than Desktop HD */ 407 | @media (min-width: 1200px) {} 408 | -------------------------------------------------------------------------------- /app/database/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IntelLabs/VL-InterpreT/0f9d0bb1dad13d91ce79d9453faf81fc2b9277b9/app/database/__init__.py -------------------------------------------------------------------------------- /app/database/database.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | from collections import defaultdict 4 | from sklearn.manifold import TSNE 5 | import faiss 6 | from tqdm import tqdm 7 | import pickle 8 | import lmdb 9 | from lz4.frame import compress, decompress 10 | import msgpack 11 | import msgpack_numpy 12 | msgpack_numpy.patch() 13 | 14 | 15 | class VliLmdb(object): 16 | def __init__(self, db_dir, read_only=True, local_world_size=1): 17 | self.readonly = read_only 18 | self.path = db_dir 19 | if not os.path.isdir(db_dir): 20 | os.mkdir(db_dir) 21 | if read_only: 22 | readahead = not self._check_distributed(local_world_size) 23 | self.env = lmdb.open(db_dir, readonly=True, create=False, lock=False, readahead=readahead) 24 | self.txn = self.env.begin(buffers=True) 25 | self.write_cnt = None 26 | else: 27 | self.env = lmdb.open(db_dir, readonly=False, create=True, map_size=4 * 1024**4) 28 | self.txn = self.env.begin(write=True) 29 | self.write_cnt = 0 30 | if not self.txn.get('n_examples'.encode('utf-8')): 31 | self['n_examples'] = 0 32 | 33 | def _check_distributed(self, local_world_size): 34 | try: 35 | dist = local_world_size != 1 36 | except ValueError: 37 | # not using horovod 38 | dist = False 39 | return dist 40 | 41 | def __del__(self): 42 | if self.write_cnt: 43 | self.txn.commit() 44 | self.env.close() 45 | 46 | def __getitem__(self, key): 47 | value = self.txn.get(str(key).encode('utf-8')) 48 | if value is None: 49 | raise KeyError(key) 50 | return msgpack.loads(decompress(value), raw=False) 51 | 52 | def __setitem__(self, key, value): 53 | # NOTE: not thread safe 54 | if self.readonly: 55 | raise ValueError('readonly text DB') 56 | ret = self.txn.put(key.encode('utf-8'), compress(msgpack.dumps(value, use_bin_type=True))) 57 | self.write_cnt += 1 58 | if self.write_cnt % 1000 == 0: 59 | self.txn.commit() 60 | self.txn = self.env.begin(write=True) 61 | self.write_cnt = 0 62 | return ret 63 | 64 | def preprocess_example(self, ex_id, ex_data, faiss_indices, faiss_data): 65 | txt_len = ex_data['txt_len'] 66 | # image tokens 67 | if ex_data['img_coords'] is not None and len(ex_data['img_coords']) > 0: 68 | ex_data['img_grid_size'] = np.max(ex_data['img_coords'], axis=0) + 1 69 | img_coords = [tuple(coords) for coords in ex_data['img_coords']] 70 | img_coords = np.array(img_coords, np.dtype([('x', int), ('y', int)])) 71 | img_sort = np.argsort(img_coords, order=('y', 'x')) 72 | ex_data['img_coords'] = np.take_along_axis(img_coords, img_sort, axis=0).tolist() 73 | img_tokens = np.array(ex_data['tokens'][txt_len:]) 74 | ex_data['tokens'][txt_len:] = np.take_along_axis(img_tokens, img_sort, axis=0) 
75 | ex_data['attention'][:,:,txt_len:] = np.take(ex_data['attention'][:,:,txt_len:], img_sort, axis=2) 76 | ex_data['attention'][:,:,:,txt_len:] = np.take(ex_data['attention'][:,:,:,txt_len:], img_sort, axis=3) 77 | else: 78 | ex_data['img_grid_size'] = (0, 0) 79 | print(f'Warning: Image coordinates are missing for example #{ex_id}.') 80 | # t-SNE 81 | tsne = [TSNE(n_components=2, random_state=None, n_jobs=-1).fit_transform(ex_data['hidden_states'][i]) 82 | for i in range(len(ex_data['hidden_states']))] 83 | if len(faiss_indices) == 0: 84 | faiss_indices = {(layer, mod): faiss.IndexFlatL2(2) \ 85 | for layer in range(len(tsne)) for mod in ('txt', 'img')} 86 | ex_data['tsne'] = tsne 87 | self[str(ex_id)] = pickle.dumps(ex_data, protocol=pickle.HIGHEST_PROTOCOL) 88 | self['n_examples'] += 1 89 | # faiss 90 | for layer in range(len(tsne)): 91 | faiss_all_tokens = [(ex_id, i) for i in range(len(ex_data['tokens']))] 92 | faiss_data[(layer, 'txt')] += faiss_all_tokens[:txt_len] 93 | faiss_data[(layer, 'img')] += faiss_all_tokens[txt_len:] 94 | tsne_txt = tsne[layer][:txt_len] 95 | tsne_img = tsne[layer][txt_len:] 96 | faiss_indices[(layer, 'txt')].add(tsne_txt) 97 | faiss_indices[(layer, 'img')].add(tsne_img) 98 | return faiss_indices, faiss_data 99 | 100 | def preprocess(self): 101 | print('Preprocessing database...') 102 | faiss_indices = {} 103 | faiss_data = defaultdict(list) 104 | n_layers, n_heads = 0, 0 105 | with self.txn.cursor() as cursor: 106 | for key, value in tqdm(cursor): 107 | ex_id = str(key, 'utf8') 108 | if not ex_id.isdigit(): 109 | continue 110 | ex_data = pickle.loads(msgpack.loads(decompress(value), raw=False)) 111 | faiss_indices, faiss_data = self.preprocess_example(ex_id, ex_data, faiss_indices, faiss_data) 112 | 113 | n_layers, n_heads, _, _ = ex_data['attention'].shape 114 | self['n_layers'] = n_layers 115 | self['n_heads'] = n_heads 116 | 117 | print('Generating faiss indices...') 118 | faiss_path = os.path.join(self.path, 'faiss') 119 | if not os.path.exists(faiss_path): 120 | os.makedirs(faiss_path) 121 | for layer in range(n_layers + 1): 122 | faiss.write_index(faiss_indices[(layer, 'txt')], os.path.join(faiss_path, f'txt_indices_{layer}')) 123 | faiss.write_index(faiss_indices[(layer, 'img')], os.path.join(faiss_path, f'img_indices_{layer}')) 124 | self['faiss'] = pickle.dumps(faiss_data, protocol=pickle.HIGHEST_PROTOCOL) 125 | print('Preprocessing done.') 126 | -------------------------------------------------------------------------------- /app/database/db_analyzer.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | from scipy.stats import zscore 4 | import faiss 5 | import pickle 6 | from app.database.database import VliLmdb 7 | 8 | ATTN_MAP_SCALE_FACTOR = 64 9 | 10 | class VliDataBaseAnalyzer: 11 | def __init__(self, db_dir, read_only=True): 12 | self.db_dir = db_dir 13 | self.db = VliLmdb(db_dir, read_only=read_only) 14 | faiss_folder = os.path.join(db_dir, 'faiss') 15 | self.faiss_indices = {(layer, mod): faiss.read_index(f'{faiss_folder}/{mod}_indices_{layer}') 16 | for layer in range(self['n_layers'] + 1) for mod in ('txt', 'img')} 17 | self.faiss_data = self['faiss'] 18 | self.updated = False 19 | 20 | def __getitem__(self, item): 21 | if isinstance(self.db[item], bytes): 22 | return pickle.loads(self.db[item]) 23 | return self.db[item] 24 | 25 | def __del__(self): 26 | if self.updated: 27 | print('Writing new examples to database...') 28 | self.db['faiss'] = 
pickle.dumps(self.faiss_data, protocol=pickle.HIGHEST_PROTOCOL) 29 | faiss_path = os.path.join(self.db.path, 'faiss') 30 | for layer in range(self['n_layers'] + 1): 31 | faiss.write_index(self.faiss_indices[(layer, 'txt')], os.path.join(faiss_path, f'txt_indices_{layer}')) 32 | faiss.write_index(self.faiss_indices[(layer, 'img')], os.path.join(faiss_path, f'img_indices_{layer}')) 33 | print('Done.') 34 | 35 | def add_example(self, ex_id, ex_data): 36 | self.faiss_indices, self.faiss_data = self.db.preprocess_example(ex_id, ex_data, self.faiss_indices, self.faiss_data) 37 | self.updated = True 38 | 39 | def get_ex_attn(self, ex_id, exclude_tokens=['[SEP]', '[CLS]']): 40 | ex_attn = self[ex_id]['attention'] 41 | ex_tokens = self[ex_id]['tokens'] 42 | if exclude_tokens: 43 | exclude_indices = [i for i, token in enumerate(ex_tokens) if token in exclude_tokens] 44 | ex_attn = np.delete(ex_attn, exclude_indices, axis=2) 45 | ex_attn = np.delete(ex_attn, exclude_indices, axis=3) 46 | return ex_attn 47 | 48 | def get_attn_means(self, attn, normalize=True): 49 | head_avg = attn.mean(axis=(2, 3)) 50 | if normalize: 51 | head_avg = zscore(head_avg) 52 | layer_avg = np.mean(head_avg, axis=1) 53 | return head_avg, layer_avg 54 | 55 | def get_attn_components_means(self, components): 56 | avg_attn = np.mean(np.array(components), axis=0) 57 | avg_attn = zscore(avg_attn) # normalize 58 | layer_avg = np.mean(avg_attn, axis=1) 59 | return avg_attn, layer_avg 60 | 61 | def img2txt_mean_attn(self, ex_id, exclude_tokens=['[SEP]', '[CLS]']): 62 | ex_attn = self.get_ex_attn(ex_id, exclude_tokens) 63 | txt_len = self[ex_id]['txt_len'] 64 | img2txt_attn = ex_attn[:, :, :txt_len, txt_len:] 65 | return self.get_attn_means(img2txt_attn) 66 | 67 | def txt2img_mean_attn(self, ex_id, exclude_tokens=['[SEP]', '[CLS]']): 68 | ex_attn = self.get_ex_attn(ex_id, exclude_tokens) 69 | txt_len = self[ex_id]['txt_len'] 70 | txt2img_attn = ex_attn[:, :, txt_len:, :txt_len] 71 | return self.get_attn_means(txt2img_attn) 72 | 73 | def txt2txt_mean_attn(self, ex_id, exclude_tokens=['[SEP]', '[CLS]']): 74 | ex_attn = self.get_ex_attn(ex_id, exclude_tokens) 75 | txt_len = self[ex_id]['txt_len'] 76 | txt2txt_attn = ex_attn[:, :, :txt_len, :txt_len] 77 | return self.get_attn_means(txt2txt_attn) 78 | 79 | def img2img_mean_attn(self, ex_id, exclude_tokens=['[SEP]', '[CLS]']): 80 | ex_attn = self.get_ex_attn(ex_id, exclude_tokens) 81 | txt_len = self[ex_id]['txt_len'] 82 | txt2img_attn = ex_attn[:, :, txt_len:, txt_len:] 83 | return self.get_attn_means(txt2img_attn) 84 | 85 | def img2img_mean_attn_without_self(self, ex_id, exclude_tokens=['[SEP]', '[CLS]']): 86 | ex_attn = self.get_ex_attn(ex_id, exclude_tokens) 87 | txt_len = self[ex_id]['txt_len'] 88 | txt2img_attn = ex_attn[:, :, txt_len:, txt_len:] 89 | n_layer, n_head, attn_len, attn_len = txt2img_attn.shape 90 | for i in range(attn_len): 91 | for j in range(attn_len): 92 | if i == j or i == j+1 or j == i+1: 93 | txt2img_attn[:, :, i, j ] = 0 94 | return self.get_attn_means(txt2img_attn) 95 | 96 | def get_all_attn_stats(self, ex_id, exclude_tokens=['[SEP]', '[CLS]']): 97 | attn_stats = [] 98 | for func in (self.img2txt_mean_attn, 99 | self.txt2img_mean_attn, 100 | self.img2img_mean_attn, 101 | self.img2img_mean_attn_without_self, 102 | self.txt2txt_mean_attn): 103 | attn, layer_avg = func(ex_id, exclude_tokens) 104 | attn_stats.append((attn, layer_avg)) 105 | # crossmodal 106 | attn_stats.append(self.get_attn_components_means([attn_stats[0][0], attn_stats[1][0]])) 107 | # intramodal 108 
| attn_stats.append(self.get_attn_components_means([attn_stats[2][0], attn_stats[3][0]])) 109 | for i, (stats, layer_avg) in enumerate(attn_stats): 110 | attn_stats[i] = np.hstack((stats, layer_avg.reshape(layer_avg.shape[0], 1))) 111 | return attn_stats 112 | 113 | def get_custom_metrics(self, ex_id): 114 | if 'custom_metrics' in self[ex_id]: 115 | cm = self[ex_id]['custom_metrics'] 116 | labels, stats = [], [] 117 | for metrics in cm: 118 | data = cm[metrics] 119 | mean = np.mean(data, axis=1) 120 | mean = mean.reshape(mean.shape[0], 1) 121 | data = np.hstack((data, mean)) 122 | labels.append(metrics) 123 | stats.append(data) 124 | return labels, stats 125 | 126 | def find_closest_token(self, tsne, layer, mod): 127 | _, index = self.faiss_indices[(layer, mod)].search(tsne.reshape(1, -1), 1) # returns distance, index 128 | index = index.item(0) 129 | ex_id, token_id = self.faiss_data[(layer, mod)][index] 130 | return ex_id, token_id 131 | 132 | def get_txt_token_index(self, ex_id, text_tokens_id): 133 | # text_tokens_id: {'sentence': 0, 'word': 0} 134 | if not text_tokens_id: 135 | return 136 | sep_idx = [-1] + [i for i, x in enumerate(self[ex_id]['tokens']) if x == '[SEP]'] 137 | sentence, word = text_tokens_id['sentence'], text_tokens_id['word'] 138 | return sep_idx[sentence] + 1 + word 139 | 140 | def get_img_token_index(self, ex_id, img_coords): 141 | txt_len = self[ex_id]['txt_len'] 142 | return self[ex_id]['img_coords'].index(tuple(img_coords)) + txt_len 143 | 144 | def get_img_unit_len(self, ex_id): 145 | token_grid_size = self[ex_id]['img_grid_size'] 146 | image_size = self[ex_id]['image'].shape 147 | return image_size[1]//token_grid_size[0], image_size[0]//token_grid_size[1] 148 | -------------------------------------------------------------------------------- /app/database/db_example.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | import numpy as np 3 | from database import VliLmdb 4 | 5 | if __name__ == '__main__': 6 | 7 | # The following example contains randomly generated data to 8 | # illustrate what the data should look like before preprocessing 9 | example_data = [] 10 | for ex_id in range(3): 11 | example_data.append( 12 | # each example should contain the following information: 13 | { 14 | # Example ID (integers starting from 0) 15 | 'ex_id': ex_id, 16 | 17 | # The original input image (RGB) 18 | # If your model preprocesses the image (e.g., resizing, padding), you may want to 19 | # use the preprocessed image instead of the original 20 | 'image': np.ones((450, 800, 3), dtype=int) * 100 * ex_id + 30, 21 | 22 | # Input tokens (text tokens followed by image tokens) 23 | 'tokens': ['[CLS]', 'text', 'input', 'for', 'example', str(ex_id), '.', '[SEP]', 24 | 'IMG_0', 'IMG_1', 'IMG_2', 'IMG_3', 'IMG_4', 'IMG_5'], 25 | 26 | # The number of text tokens 27 | 'txt_len': 8, 28 | 29 | # The (x, y) coordinates of each image token on the original image, 30 | # assuming the *top left* corner of an image is (0, 0) 31 | # The order of coordinates should correspond to how image tokens are ordered in 'tokens' 32 | 'img_coords': [(0, 0), (0, 1), (1, 0), (1, 1), (2, 0), (2, 1)], 33 | 34 | # (Optional) Whether model predicted correctly for this example 35 | 'accuracy': ex_id % 2 == 0, # either True or False 36 | 37 | # Attention weights for all attention heads in all layers 38 | # Shape: (n_layers, n_attention_heads_per_layer, n_tokens, n_tokens) 39 | # n_layers and n_attention_heads_per_layer should be the same accross example 40 | # 
The order of columns and rows of the attention weight matrix for each head should 41 | # correspond to how tokens are ordered in 'tokens' 42 | 'attention': np.random.rand(12, 12, 14, 14), 43 | 44 | # The hidden representations for each token in the model, 45 | # both before the first layer and after each layer 46 | # Shape: (n_layers + 1, n_tokens, hidden_state_vector_size) 47 | # Note that in our demo app, hidden representations of stop words were removed 48 | # to reduce the number of displayed datapoints 49 | 'hidden_states': np.random.rand(13, 14, 768), 50 | 51 | # (Optional) Custom statistics for attention heads in all layers 52 | # Shape: (n_layers, n_attention_heads_per_layer) 53 | # The order should follow how attention heads are ordered in 'attention' matrices 54 | 'custom_metrics': {'Example Custom Metrics': np.random.rand(12, 12)} 55 | } 56 | ) 57 | 58 | # Create database 59 | print('Creating database...') 60 | db = VliLmdb(db_dir='example_database1', read_only=False) 61 | for ex_id, ex_data in enumerate(example_data): 62 | # Keys must be strings 63 | db[str(ex_id)] = pickle.dumps(ex_data, protocol=pickle.HIGHEST_PROTOCOL) 64 | 65 | # Preprocess the database 66 | db.preprocess() 67 | 68 | -------------------------------------------------------------------------------- /app/database/models/kdvlp.py: -------------------------------------------------------------------------------- 1 | ''' 2 | The following script runs a forward pass with the KD-VLP model for 3 | each given image+text pair, and produces the corresponding attention 4 | and hidden states that can be visualized with VL-InterpreT. 5 | 6 | The KD-VLP model has not been made publicly available in this repo. 7 | Please create your own model class by inheriting from VL_Model. 8 | 9 | Note that to run your own model with VL-InterpreT, you are only 10 | required to implement the data_setup function. Most of the code in 11 | this file is specific to our KD-VLP model, and it is provided just 12 | for your reference. 

13 | ''' 14 | 15 | 16 | import numpy as np 17 | import torch 18 | import pickle 19 | from torch.nn.utils.rnn import pad_sequence 20 | from transformers import BertTokenizer 21 | 22 | try: 23 | from app.database.models.vl_model import VL_Model 24 | except ModuleNotFoundError: 25 | from vl_model import VL_Model 26 | 27 | import sys 28 | MODEL_DIR = '/workdisk/ccr_vislang/benchmarks/vision_language/vcr/VILLA/' 29 | sys.path.append(MODEL_DIR) 30 | 31 | from model.vcr import UniterForVisualCommonsenseReasoning 32 | from utils.const import IMG_DIM 33 | from utils.misc import remove_prefix 34 | from data import get_transform 35 | 36 | 37 | class Kdvlp(VL_Model): 38 | ''' 39 | Running KD-VLP with VL-Interpret: 40 | python run_app.py -p 6006 -d example_database2 \ 41 | -m kdvlp /data1/users/shaoyent/e2e-vcr-kgmvm-pgm-run49-2/ckpt/model_step_8500.pt 42 | ''' 43 | def __init__(self, ckpt_file, device='cuda'): 44 | self.device = torch.device(device, 0) 45 | if device == 'cuda': 46 | torch.cuda.set_device(0) 47 | self.model, self.tokenizer = self.build_model(ckpt_file) 48 | 49 | 50 | def build_model(self, ckpt_file): 51 | tokenizer = BertTokenizer.from_pretrained('bert-base-cased') 52 | tokenizer._add_tokens([f'[PERSON_{i}]' for i in range(81)]) 53 | 54 | ckpt = torch.load(ckpt_file) 55 | checkpoint = {remove_prefix(k.replace('bert', 'uniter'), 'module.') : v for k, v in ckpt.items()} 56 | model = UniterForVisualCommonsenseReasoning.from_pretrained( 57 | f'{MODEL_DIR}/config/uniter-base.json', state_dict={}, 58 | img_dim=IMG_DIM) 59 | model.init_type_embedding() 60 | model.load_state_dict(checkpoint, strict=False) 61 | model.eval() 62 | model = model.to(self.device) 63 | return model, tokenizer 64 | 65 | 66 | def move_to_device(self, x): 67 | if isinstance(x, list): 68 | return [self.move_to_device(y) for y in x] 69 | elif isinstance(x, dict): 70 | new_dict = {} 71 | for k, v in x.items(): 72 | new_dict[k] = self.move_to_device(x[k]) 73 | return new_dict 74 | elif isinstance(x, torch.Tensor): 75 | return x.to(self.device) 76 | else: 77 | return x 78 | 79 | 80 | def build_batch(self, input_text, image, answer=None, person_info=None): 81 | if not input_text: 82 | input_text = '' 83 | if answer is None: 84 | input_ids = self.tokenizer.encode(input_text, return_tensors='pt') 85 | input_ids = pad_sequence(input_ids, batch_first=True, padding_value=0) # all*input_ids 86 | txt_type_ids = torch.zeros_like(input_ids) 87 | else: 88 | input_ids_q = self.tokenizer.convert_tokens_to_ids(self.tokenizer.tokenize(input_text)) 89 | input_ids_c = self.tokenizer.convert_tokens_to_ids(self.tokenizer.tokenize(answer)) 90 | input_ids = [torch.tensor(self.tokenizer.build_inputs_with_special_tokens(input_ids_q, input_ids_c))] 91 | input_ids = pad_sequence(input_ids, batch_first=True, padding_value=0) # all*input_ids 92 | 93 | txt_type_ids = torch.tensor((len(input_ids_q) + 2 )* [0] + (len(input_ids_c) + 1) * [2]) 94 | 95 | position_ids = torch.arange(0, input_ids.size(1), dtype=torch.long).unsqueeze(0) 96 | num_sents = [input_ids.size(0)] 97 | txt_lens = [i.size(0) for i in input_ids] 98 | 99 | if image is None: 100 | images_batch = None 101 | else: 102 | images_batch = torch.as_tensor(image.copy(), dtype=torch.float32) 103 | images_batch = get_transform(images_batch.permute(2, 0, 1)) 104 | 105 | batch = {'input_ids': input_ids, 'txt_type_ids': txt_type_ids, 'position_ids': position_ids, 'images': images_batch, 106 | "txt_lens": txt_lens, "num_sents": num_sents, 'person_info': person_info} 107 | batch = 
self.move_to_device(batch) 108 | 109 | return batch 110 | 111 | 112 | def data_setup(self, ex_id, image_location, input_text): 113 | image = self.fetch_image(image_location) if image_location else None 114 | 115 | batch = self.build_batch(input_text, image, answer=None, person_info=None) 116 | scores, hidden_states, attentions = self.model(batch, 117 | compute_loss=False, 118 | output_attentions=True, 119 | output_hidden_states=True) 120 | 121 | attentions = torch.stack(attentions).transpose(1,0).detach().cpu()[0] 122 | 123 | if batch['images'] is None: 124 | img, img_coords = np.array([]), [] 125 | len_img = 0 126 | else: 127 | image1, mask1 = self.model.preprocess_image(batch['images'].to(self.device)) 128 | image1 = (image1 * self.model.pixel_std + self.model.pixel_mean) * mask1 129 | img = image1.cpu().numpy().astype(int).squeeze().transpose(1,2,0) 130 | 131 | h, w, _ = img.shape 132 | h0, w0 = h//64, w//64 133 | len_img = w0 * h0 134 | img_coords = np.fliplr(list(np.ndindex(h0, w0))) 135 | 136 | input_ids = batch['input_ids'].cpu() 137 | len_text = input_ids.size(1) 138 | txt_tokens = self.tokenizer.convert_ids_to_tokens(input_ids[0, :len_text]) 139 | 140 | len_tokens = len_text + len_img 141 | attentions = attentions[:, :, :len_tokens, :len_tokens] 142 | hidden_states = [hs[0].detach().cpu().numpy()[:len_tokens] for hs in hidden_states] 143 | 144 | return { 145 | 'ex_id': ex_id, 146 | 'image': img, 147 | 'tokens': txt_tokens + [f'IMG_{i}' for i in range(len_img)], 148 | 'txt_len': len(txt_tokens), 149 | 'attention': attentions.detach().cpu().numpy(), 150 | 'img_coords': img_coords, 151 | 'hidden_states': hidden_states 152 | } 153 | 154 | 155 | def create_example_db(): 156 | ''' 157 | This function creates example_database2. 158 | ''' 159 | images = [f'{MODEL_DIR}/visualization/ex.jpg'] 160 | texts = ['Horses are pulling a carriage, while someone is standing on the top of a golden ball.'] 161 | kdvlp = Kdvlp('/data1/users/shaoyent/e2e-vcr-kgmvm-pgm-run49-2/ckpt/model_step_8500.pt') 162 | data = [kdvlp.data_setup(i, img, txt) for i, (img, txt) in enumerate(zip(images, texts))] 163 | 164 | db = VliLmdb(db_dir='/workdisk/VL-InterpreT/example_database2', read_only=False) 165 | for i, dat in enumerate(data): 166 | db[str(i)] = pickle.dumps(dat, protocol=pickle.HIGHEST_PROTOCOL) 167 | db.preprocess() 168 | 169 | 170 | if __name__ == '__main__': 171 | sys.path.append('..') 172 | from database import VliLmdb 173 | create_example_db() 174 | -------------------------------------------------------------------------------- /app/database/models/vl_model.py: -------------------------------------------------------------------------------- 1 | ''' 2 | An abstract base class for live models that can run together with VL-InterpreT. 3 | 4 | To run your own model with VL-InterpreT, create another file your_model.py in this 5 | folder that contains a class Your_Model (use title case for the class name), which 6 | inherits from the VL_Model class and implements the data_setup method. The data_setup 7 | method should take the ID, image and text of a given example, run a forward pass for 8 | this example with your model, and return the corresponding attention, hidden states 9 | and other required data that can be visualized with VL-InterpreT. 
10 | ''' 11 | 12 | 13 | from abc import ABC, abstractmethod 14 | import numpy as np 15 | from PIL import Image 16 | import urllib.request 17 | 18 | 19 | class VL_Model(ABC): 20 | ''' 21 | To run a live transformer with VL-InterpreT, define your own model class by inheriting 22 | from this class and implementing the data_setup method. 23 | 24 | Please follow these naming patterns to make sure your model runs easily with VL-InterpreT: 25 | - Create a new python script in this folder for your class, and name it in all lower 26 | case (e.g., yourmodelname.py) 27 | - Name your model class in title case, e.g., Yourmodelname. This class name should be 28 | the result of calling 'yourmodelname'.title(), where 'yourmodelname.py' is the name 29 | of your python script. 30 | 31 | Then you can run VL-InterpreT with your model: 32 | python run_app.py -p 6006 -d example_database2 -m yourmodelname your_model_parameters 33 | ''' 34 | 35 | @abstractmethod 36 | def data_setup(self, example_id: int, image_location: str, input_text: str) -> dict: 37 | ''' 38 | This method should run a forward pass with your model given the input image and 39 | text, and return the required data. See app/database/db_example.py for specifications 40 | of the return data format, and see the implementation in kdvlp.py for an example. 41 | ''' 42 | return { 43 | 'ex_id': example_id, 44 | 'image': np.array([]), 45 | 'tokens': [], 46 | 'txt_len': 0, 47 | 'img_coords': [], 48 | 'attention': np.array([]), 49 | 'hidden_states': np.array([]) 50 | } 51 | 52 | 53 | def fetch_image(self, image_location: str): 54 | ''' 55 | This helper function takes the path to an image (either a URL or a local path) and 56 | returns the image as a numpy array. 57 | ''' 58 | if image_location.startswith('http'): 59 | urllib.request.urlretrieve(image_location, 'temp.jpg') 60 | image_location = 'temp.jpg' 61 | 62 | img = Image.open(image_location).convert('RGB') 63 | img = np.array(img) 64 | return img 65 | -------------------------------------------------------------------------------- /app/plot_func.py: -------------------------------------------------------------------------------- 1 | ''' Module containing plotting functions. 2 | 3 | These plotting functions are used in app_configuration.py to 4 | generate all the plots for the UI. 

5 | ''' 6 | 7 | 8 | import base64 9 | from io import BytesIO 10 | from PIL import Image 11 | from dash import html 12 | 13 | import numpy as np 14 | import matplotlib 15 | matplotlib.use('Agg') 16 | import matplotlib.pyplot as plt 17 | import matplotlib.cm as cm 18 | import plotly.graph_objects as go 19 | import scipy.ndimage 20 | 21 | 22 | IMG_LABEL_SIZE = 12 23 | LAYER_AVG_GAP = 0.8 24 | ATTN_MAP_SCALE_FACTOR = 64 25 | 26 | heatmap_layout = dict( 27 | autosize=True, 28 | font=dict(color='black'), 29 | titlefont=dict(color='black', size=14), 30 | legend=dict(font=dict(size=8), orientation='h'), 31 | ) 32 | 33 | img_overlay_layout = dict( 34 | barmode='stack', 35 | bargap=0, 36 | hovermode='closest', 37 | showlegend=False, 38 | autosize=False, 39 | margin={'l': 0, 'r': 0, 't': 0, 'b': 0}, 40 | paper_bgcolor='rgba(0,0,0,0)', 41 | plot_bgcolor='rgba(0,0,0,0)', 42 | xaxis=dict(visible=False, fixedrange=True), 43 | yaxis=dict(visible=False, scaleanchor='x') # constant aspect ratio 44 | ) 45 | 46 | def get_sep_indices(tokens): 47 | return [-1] + [i for i, x in enumerate(tokens) if x == '[SEP]'] 48 | 49 | 50 | def plot_text(db, ex_id): 51 | tokens = db[ex_id]['tokens'] 52 | txt_len = db[ex_id]['txt_len'] 53 | return show_texts(tokens[:txt_len]), txt_len 54 | 55 | def plot_attn_from_txt(db, ex_id, layer, head, word_ids=None, img_coords=None): 56 | ex_data = db[ex_id] 57 | txt_len = ex_data['txt_len'] 58 | txt_tokens = ex_data['tokens'][:txt_len] 59 | if img_coords: 60 | w, h = db[ex_id]['img_grid_size'] 61 | actual_img_tok_ids = [txt_len + w*(h-j-1) + i for i, j in img_coords] # TODO: double-check the ordering of image tokens 62 | token_id = actual_img_tok_ids[0] # TODO: avg if multiple 63 | elif word_ids: 64 | token_id = db.get_txt_token_index(ex_id, word_ids) 65 | else: # nothing selected, show empty text 66 | return show_texts(txt_tokens) 67 | # get attention 68 | attn = ex_data['attention'] 69 | if int(head) == head: 70 | txt_attn = attn[layer, head, token_id, :txt_len] 71 | else: 72 | layer_attn = [attn[layer, hd, token_id, :txt_len] for hd in range(attn.shape[1])] 73 | txt_attn = np.mean(layer_attn, axis=0) 74 | # get colors by sentence 75 | colors = cm.get_cmap('Reds')(txt_attn * 120, 0.5) 76 | colors = ['rgba(' + ','.join(map(lambda x: str(int(x*255)), rgba[:3])) + ',' + str(rgba[3]) + ')' 77 | for rgba in colors] 78 | seps = get_sep_indices(ex_data['tokens']) 79 | colors = [colors[i+1:j+1] for i, j in zip(seps[:-1], seps[1:])] 80 | txt_fig = show_texts(txt_tokens, colors) 81 | return txt_fig 82 | 83 | def highlight_txt_token(db, ex_id, token_ids): 84 | txt_len = db[ex_id]['txt_len'] 85 | txt_tokens = db[ex_id]['tokens'][:txt_len] 86 | if token_ids: 87 | seps = get_sep_indices(txt_tokens) 88 | sentence, word = token_ids['sentence'], token_ids['word'] 89 | sentence_len = seps[sentence+1] - seps[sentence] 90 | colors = [] 91 | for s in range(len(seps) - 1): 92 | if s == sentence: 93 | colors.append(['white'] * sentence_len) 94 | colors[-1][word] = 'orange' 95 | else: 96 | colors.append('white') 97 | else: 98 | colors = 'white' 99 | txt_fig = show_texts(txt_tokens, colors) 100 | return txt_fig 101 | 102 | def plot_image(db, ex_id): 103 | image = db[ex_id]['image'] 104 | return show_img(image, bg='black', opacity=1.0) 105 | 106 | def get_empty_fig(): 107 | fig = go.Figure() 108 | fig.update_layout(img_overlay_layout) 109 | return fig 110 | 111 | def get_overlay_grid(db, ex_id, color='rgba(255, 255, 255, 0)'): 112 | imgh, imgw, _ = db[ex_id]['image'].shape 113 | gridw, gridh = 

db[ex_id]['img_grid_size'] 114 | unit_w = imgw / gridw 115 | unit_h = imgh / gridh 116 | grid_data = tuple(go.Bar( 117 | x=np.linspace(unit_w/2, imgw-unit_w/2, gridw), 118 | y=[unit_h] * gridw, 119 | hoverinfo='none', 120 | marker_line_width=0, 121 | marker_color=(color if type(color) == str else color[i])) 122 | for i in range(gridh)) 123 | return grid_data 124 | 125 | def get_overlay_fig(db, ex_id, color='rgba(255, 255, 255, 0)'): 126 | grid_data = get_overlay_grid(db, ex_id, color) 127 | imgh, imgw, _ = db[ex_id]['image'].shape 128 | fig = go.Figure( 129 | data=grid_data, 130 | layout=go.Layout( 131 | xaxis={'range': (0, imgw)}, 132 | yaxis={'range': (0, imgh)}) 133 | ) 134 | fig.update_layout(img_overlay_layout) 135 | return fig 136 | 137 | def highlight_img_grid_selection(db, ex_id, img_selection): 138 | img_selection = set(map(tuple, img_selection)) 139 | w, h = db[ex_id]['img_grid_size'] 140 | highlight, transparent = 'rgba(255, 165, 0, .4)', 'rgba(255, 255, 255, 0)' 141 | colors = [[(highlight if (j, i) in img_selection else transparent) for j in range(w)] for i in range(h)] 142 | return get_overlay_fig(db, ex_id, colors) 143 | 144 | def plot_attn_from_img(db, ex_id, layer, head, word_ids=None, img_coords=None, smooth=True): 145 | ex_data = db[ex_id] 146 | w0, h0 = ex_data['img_grid_size'] 147 | txt_len = ex_data['txt_len'] 148 | if word_ids and word_ids['sentence'] > -1 and word_ids['word'] > -1: 149 | token_id = db.get_txt_token_index(ex_id, word_ids) 150 | elif img_coords: 151 | img_token_ids = [db.get_img_token_index(ex_id, coords) for coords in img_coords] 152 | token_id = img_token_ids[0] # TODO show avg for multiple selection? 153 | else: 154 | return get_empty_fig() 155 | attn = ex_data['attention'] 156 | 157 | if head < attn.shape[1]: 158 | img_attn = attn[layer, head, token_id, txt_len:(w0*h0+txt_len)].reshape(h0, w0) 159 | else: # show layer avg 160 | layer_attn = [attn[layer, h, token_id, txt_len:(w0*h0+txt_len)].reshape(h0, w0) \ 161 | for h in range(attn.shape[1])] 162 | img_attn = np.mean(layer_attn, axis=0) 163 | if smooth: 164 | img_attn = scipy.ndimage.zoom(img_attn, ATTN_MAP_SCALE_FACTOR, order=1) 165 | return show_img(img_attn, opacity=0.3, bg='rgba(0,0,0,0)', hw=ex_data['image'].shape[:2]) 166 | 167 | def plot_head_summary(db, ex_id, attn_type_index=0, layer=None, head=None): 168 | stats = db.get_all_attn_stats(ex_id) 169 | custom_stats = db.get_custom_metrics(ex_id) 170 | return show_head_summary(stats, custom_stats, attn_type_index, layer, head) 171 | 172 | def plot_attn_matrix(db, ex_id, layer, head, attn_type=0): 173 | txt_len = db[ex_id]['txt_len'] 174 | txt_attn = db[ex_id]['attention'][layer, head, :txt_len, :txt_len] 175 | img_attn = db[ex_id]['attention'][layer, head, txt_len:, txt_len:] 176 | tokens = db[ex_id]['tokens'] 177 | return show_attn_matrix([txt_attn, img_attn], tokens, layer, head, attn_type) 178 | 179 | 180 | def show_texts(tokens, colors='white'): 181 | seps = get_sep_indices(tokens) 182 | sentences = [tokens[i+1:j+1] for i, j in zip(seps[:-1], seps[1:])] 183 | 184 | fig = go.Figure() 185 | annotations = [] 186 | for sen_i, sentence in enumerate(sentences): 187 | word_lengths = list(map(len, sentence)) 188 | fig.add_trace(go.Bar( 189 | x=word_lengths, # TODO center at 0 190 | y=[sen_i] * len(sentence), 191 | orientation='h', 192 | marker_color=(colors if type(colors) is str else colors[sen_i]), 193 | marker_line=dict(color='rgba(255, 255, 255, 0)', width=0), 194 | hoverinfo='none' 195 | )) 196 | word_pos = np.cumsum(word_lengths) - 

np.array(word_lengths) / 2 197 | for word_i in range(len(sentence)): 198 | annotations.append(dict( 199 | xref='x', yref='y', 200 | x=word_pos[word_i], y=sen_i, 201 | text=sentence[word_i], 202 | showarrow=False 203 | )) 204 | fig.update_xaxes(visible=False, fixedrange=True) 205 | fig.update_yaxes(visible=False, fixedrange=True, range=(len(sentences)+.3, -len(sentences)+.7)) 206 | fig.update_layout( 207 | annotations=annotations, 208 | barmode='stack', 209 | autosize=False, 210 | margin={'l': 0, 'r': 0, 't': 0, 'b': 0}, 211 | showlegend=False, 212 | plot_bgcolor='white' 213 | ) 214 | return fig 215 | 216 | 217 | def show_img(img, opacity, bg, hw=None): 218 | img_height, img_width = hw if hw else (img.shape[0], img.shape[1]) 219 | mfig, ax = plt.subplots(figsize=(img_width/100., img_height/100.), dpi=100) 220 | plt.subplots_adjust(left=0, right=1, top=1, bottom=0, wspace=0, hspace=0) 221 | plt.axis('off') 222 | if hw: 223 | ax.imshow(img, cmap='jet', interpolation='nearest', aspect='auto') 224 | else: 225 | ax.imshow(img) 226 | img_uri = fig_to_uri(mfig) 227 | fig_width, fig_height = mfig.get_size_inches() * mfig.dpi 228 | 229 | fig = go.Figure() 230 | fig.update_xaxes(range=(0, fig_width)) 231 | fig.update_yaxes(range=(0, fig_height)) 232 | fig.update_layout(img_overlay_layout) 233 | fig.update_layout( 234 | autosize=True, 235 | plot_bgcolor=bg, 236 | paper_bgcolor=bg 237 | ) 238 | fig.layout.images = [] # remove previous image 239 | fig.add_layout_image(dict( 240 | x=0, y=fig_height, 241 | sizex=fig_width, sizey=fig_height, 242 | xref='x', yref='y', 243 | opacity=opacity, 244 | sizing='stretch', 245 | source=img_uri 246 | )) 247 | return fig 248 | 249 | 250 | def fig_to_uri(fig, close_all=True, **save_args): 251 | out_img = BytesIO() 252 | fig.savefig(out_img, format='jpeg', **save_args) 253 | if close_all: 254 | fig.clf() 255 | plt.close('all') 256 | out_img.seek(0) # rewind file 257 | encoded = base64.b64encode(out_img.read()).decode('ascii').replace('\n', '') 258 | return 'data:image/jpeg;base64,{}'.format(encoded) 259 | 260 | 261 | def show_head_summary(attn_stats, custom_stats=None, attn_type_index=0, layer=None, head=None): 262 | n_stats = len(attn_stats) + (len(custom_stats) if custom_stats else 0) 263 | if attn_type_index >= n_stats: 264 | attn_type_index = 0 265 | 266 | num_layers, num_heads = np.shape(attn_stats[0]) 267 | data = [go.Heatmap( 268 | type = 'heatmap', 269 | x=list(np.arange(-0.5, num_heads-1)) + [num_heads + LAYER_AVG_GAP - 0.5], 270 | z=stats, 271 | zmin=-3, zmax=3, # fix color scale range 272 | colorscale='plasma', 273 | reversescale=False, 274 | colorbar=dict(thickness=10), 275 | visible=(i == attn_type_index) 276 | ) for i, stats in enumerate(attn_stats)] 277 | 278 | if custom_stats: 279 | data += [go.Heatmap( 280 | type = 'heatmap', 281 | x=list(np.arange(-0.5, num_heads-1)) + [num_heads + LAYER_AVG_GAP - 0.5], 282 | z=stats, 283 | colorscale='plasma', 284 | reversescale=False, 285 | colorbar=dict(thickness=10), 286 | visible=(attn_type_index == len(data))) for stats in custom_stats[1]] 287 | 288 | if layer is not None and head is not None: 289 | x_pos = np.floor(head) + LAYER_AVG_GAP if head > num_heads - 1 else head 290 | data.append(dict( 291 | type='scattergl', 292 | x=[x_pos], 293 | y=[layer], 294 | marker=dict( 295 | color='black', 296 | symbol='x', 297 | size=15, 298 | opacity=0.6, 299 | line=dict(width=1, color='lightgrey') 300 | ), 301 | )) 302 | layout = heatmap_layout 303 | layout.update({'title': ''}) 304 | layout.update({ 305 | 'yaxis': { 306 | 
'title': 'Layer #', 307 | 'tickmode': 'array', 308 | 'ticktext': list(range(1, num_layers+1)), 309 | 'tickvals': list(range(num_layers)), 310 | 'range': (num_layers-0.5, -0.5), 311 | 'fixedrange': True}, 312 | 'xaxis': { 313 | 'title': 'Attention head #', 314 | 'tickmode': 'array', 315 | 'ticktext': list(range(1, num_heads)) + ['Mean'], 316 | 'tickvals': list(range(num_heads-1)) + [num_heads + LAYER_AVG_GAP - 1], 317 | 'range': (-0.5, num_heads + LAYER_AVG_GAP - 0.5), 318 | 'fixedrange': True}, 319 | 'margin': {'t': 0, 'l': 0, 'r': 0, 'b': 0}, 320 | 'height': 450, 321 | 'width': 520 322 | }) 323 | 324 | fig = go.Figure(data=data, layout=layout) 325 | fig.add_vrect(x0=num_heads-1.5, x1=num_heads+LAYER_AVG_GAP-1.5, fillcolor='white', line_width=0) 326 | 327 | # dropdown menu 328 | labels = ['Mean image-to-text attention (without CLS/SEP)', 329 | 'Mean text-to-image attention (without CLS/SEP)', 330 | 'Mean image-to-image attention (without CLS/SEP)', 331 | 'Mean image-to-image attention (without self/CLS/SEP)', 332 | 'Mean text-to-text attention (without CLS/SEP)', 333 | 'Mean cross-modal attention (without CLS/SEP)', 334 | 'Mean intra-modal attention (without CLS/SEP)'] 335 | if custom_stats: 336 | labels += custom_stats[0] 337 | fig.update_layout( 338 | updatemenus=[dict( 339 | buttons=[ 340 | dict( 341 | args=[{'visible': [True if j == i else False for j in range(len(labels))] + [True]}], 342 | label=labels[i], 343 | method='update' 344 | ) for i in range(len(labels)) 345 | ], 346 | direction='down', 347 | pad={'l': 85}, 348 | active=attn_type_index, 349 | showactive=True, 350 | x=0, 351 | xanchor='left', 352 | y=1.01, 353 | yanchor='bottom' 354 | )], 355 | annotations=[dict( 356 | text='Display data:', showarrow=False, 357 | x=0, y=1.027, xref='paper', yref='paper', 358 | yanchor='bottom' 359 | )] 360 | ) 361 | 362 | return fig 363 | 364 | 365 | def show_attn_matrix(attns, tokens, layer, head, attn_type): 366 | txt_tokens = tokens[:attns[0].shape[0]] 367 | data = [dict( 368 | type='heatmap', 369 | z=attn, 370 | colorbar=dict(thickness=10), 371 | visible=(i == attn_type)) for i, attn in enumerate(attns)] 372 | layout = heatmap_layout 373 | layout.update({ 374 | 'title': f'Attention Matrix for Layer #{layer+1}, Head #{head+1}', 375 | 'title_font_size': 16, 376 | 'title_x': 0.5, 'title_y': 0.99, 377 | 'margin': {'t': 70, 'b': 10, 'l': 0, 'r': 0} 378 | }) 379 | txt_attn_layout = { 380 | 'xaxis': dict( 381 | tickmode='array', 382 | tickvals=list(range(len(txt_tokens))), 383 | ticktext=txt_tokens, 384 | tickangle=45, 385 | tickfont=dict(size=12), 386 | range=(-0.5, len(txt_tokens)-0.5), 387 | fixedrange=True), 388 | 'yaxis': dict( 389 | tickmode='array', 390 | tickvals=list(range(len(txt_tokens))), 391 | ticktext=txt_tokens, 392 | tickfont=dict(size=12), 393 | range=(-0.5, len(txt_tokens)-0.5), 394 | fixedrange=True), 395 | 'width': 600, 396 | 'height': 600 397 | } 398 | img_attn_layout = { 399 | 'xaxis': dict(range=(-0.5, len(tokens)-len(txt_tokens)-0.5)), 400 | 'yaxis': dict(range=(-0.5, len(tokens)-len(txt_tokens)-0.5), scaleanchor='x'), 401 | 'width': 900, 402 | 'height': 900, 403 | 'plot_bgcolor': 'rgba(0,0,0,0)', 404 | } 405 | layout.update((txt_attn_layout, img_attn_layout)[attn_type]) 406 | figure = go.Figure(dict(data=data, layout=layout)) 407 | 408 | # dropdown menu 409 | figure.update_layout( 410 | updatemenus=[dict( 411 | buttons=[dict( 412 | args=[ 413 | {'visible': [True, False]}, 414 | txt_attn_layout 415 | ], 416 | label='Text-to-text attention', 417 | method='update' 418 | ), 
dict( 419 | args=[ 420 | {'visible': [False, True]}, 421 | img_attn_layout 422 | ], 423 | label='Image-to-image attention', 424 | method='update' 425 | )], 426 | direction='down', 427 | pad={'l': 85}, 428 | active=attn_type, 429 | showactive=True, 430 | x=0, 431 | xanchor='left', 432 | y=1.01, 433 | yanchor='bottom' 434 | )], 435 | annotations=[dict( 436 | text='Display data:', showarrow=False, 437 | x=0, y=1.02, 438 | xref='paper', yref='paper', 439 | xanchor='left', yanchor='bottom' 440 | )], 441 | ) 442 | 443 | return figure 444 | 445 | def show_tsne(db, ex_id, layer, text_tokens_id, img_tokens): 446 | ex_data = db[ex_id] 447 | selected_token_id = None 448 | if text_tokens_id and text_tokens_id['sentence'] >= 0 and text_tokens_id['word'] >= 0: 449 | selected_token_id = db.get_txt_token_index(ex_id, text_tokens_id) 450 | mod = 'img' 451 | elif img_tokens: 452 | selected_token_id = db.get_img_token_index(ex_id, img_tokens[0]) 453 | mod = 'txt' 454 | 455 | # draw all tokens 456 | figure = go.Figure() 457 | txt_len = ex_data['txt_len'] 458 | for modality in ('txt', 'img'): 459 | slicing = slice(txt_len) if modality == 'txt' else slice(txt_len, len(ex_data['tokens'])) 460 | figure.add_trace(go.Scatter( 461 | x=ex_data['tsne'][layer][slicing][:,0], 462 | y=ex_data['tsne'][layer][slicing][:,1], 463 | hovertext=[f'{ex_id},{token_id}' for token_id in list(range(len(ex_data['tokens'])))[slicing]], 464 | mode='markers', 465 | name=modality, 466 | marker={'color': 'blue' if modality == 'img' else 'red'} 467 | )) 468 | if selected_token_id is not None: 469 | # hightlight selected 470 | token = ex_data['tokens'][selected_token_id] 471 | tsne = ex_data['tsne'][layer][selected_token_id] 472 | figure.add_trace(go.Scatter( 473 | x=[tsne[0]], 474 | y=[tsne[1]], 475 | mode='markers+text', 476 | hovertext=[f'{ex_id},{selected_token_id}'], 477 | text=token, 478 | textposition='top center', 479 | textfont=dict(size=15), 480 | marker=dict( 481 | color='orange', 482 | colorscale='Jet', 483 | showscale=False, 484 | symbol='star', 485 | size=10, 486 | opacity=1, 487 | cmin=-1.0, 488 | cmax=1.0, 489 | ), 490 | showlegend=False, 491 | )) 492 | # hightlight closest 493 | closest_ex_id, closest_token_id = db.find_closest_token(tsne, layer, mod) 494 | closest_token = db[closest_ex_id]['tokens'][closest_token_id] 495 | closest_tsne = db[closest_ex_id]['tsne'][layer][closest_token_id] 496 | figure.add_trace(go.Scatter( 497 | x=[closest_tsne[0]], 498 | y=[closest_tsne[1]], 499 | mode='markers+text', 500 | hovertext=[f'{closest_ex_id},{closest_token_id}'], 501 | text=closest_token + ' from ex ' + str(closest_ex_id), 502 | textposition='top center', 503 | textfont=dict(size=15), 504 | marker=dict( 505 | color='green', 506 | colorscale='Jet', 507 | showscale=False, 508 | symbol='star', 509 | size=10, 510 | opacity=1, 511 | cmin=-1.0, 512 | cmax=1.0, 513 | ), 514 | showlegend=False, 515 | )) 516 | figure.update_layout(legend=dict(font=dict(size=15), y=1, x=0.99, bgcolor='rgba(0,0,0,0)')) 517 | figure.update_layout(margin={'l': 0, 'r': 0, 'b': 0, 't': 0}) 518 | figure.update_traces(hoverinfo='none', hovertemplate=None) 519 | return figure 520 | 521 | 522 | def plot_tooltip_content(db, ex_id, token_id): 523 | ex_data = db[ex_id] 524 | txt_len = ex_data['txt_len'] 525 | token = ex_data['tokens'][token_id] 526 | if token_id < txt_len: # text token 527 | return [html.P(token)] 528 | # image token 529 | img_coords = ex_data['img_coords'][token_id-txt_len] 530 | img = np.copy(ex_data['image']) 531 | x_unit_len, y_unit_len = 
db.get_img_unit_len(ex_id) 532 | x = int(x_unit_len * img_coords[0]) 533 | y_end = img.shape[0] - int(y_unit_len * img_coords[1]) 534 | x_end = x + x_unit_len 535 | y = y_end - y_unit_len 536 | img[y:y_end, x:x_end, 0] = 0 537 | img[y:y_end, x:x_end, 1] = 0 538 | img[y:y_end, x:x_end, 2] = 255 539 | # dump img to base64 540 | buffer = BytesIO() 541 | img = Image.fromarray(np.uint8(img)).save(buffer, format='jpeg') 542 | encoded_image = base64.b64encode(buffer.getvalue()).decode() 543 | img_url = 'data:image/jpeg;base64, ' + encoded_image 544 | return [ 545 | html.Div([ 546 | html.Img( 547 | src=img_url, 548 | style={'width': '300px', 'display': 'block', 'margin': '0 auto'}, 549 | ), 550 | html.P(token, style={'font-weight': 'bold','font-size': '7x'}) 551 | ]) 552 | ] 553 | -------------------------------------------------------------------------------- /assets/logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IntelLabs/VL-InterpreT/0f9d0bb1dad13d91ce79d9453faf81fc2b9277b9/assets/logo.png -------------------------------------------------------------------------------- /assets/screencast.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IntelLabs/VL-InterpreT/0f9d0bb1dad13d91ce79d9453faf81fc2b9277b9/assets/screencast.png -------------------------------------------------------------------------------- /assets/screenshot_add_ex.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IntelLabs/VL-InterpreT/0f9d0bb1dad13d91ce79d9453faf81fc2b9277b9/assets/screenshot_add_ex.png -------------------------------------------------------------------------------- /assets/screenshot_app_launch.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IntelLabs/VL-InterpreT/0f9d0bb1dad13d91ce79d9453faf81fc2b9277b9/assets/screenshot_app_launch.png -------------------------------------------------------------------------------- /example_database1/data.mdb: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IntelLabs/VL-InterpreT/0f9d0bb1dad13d91ce79d9453faf81fc2b9277b9/example_database1/data.mdb -------------------------------------------------------------------------------- /example_database1/faiss/img_indices_0: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IntelLabs/VL-InterpreT/0f9d0bb1dad13d91ce79d9453faf81fc2b9277b9/example_database1/faiss/img_indices_0 -------------------------------------------------------------------------------- /example_database1/faiss/img_indices_1: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IntelLabs/VL-InterpreT/0f9d0bb1dad13d91ce79d9453faf81fc2b9277b9/example_database1/faiss/img_indices_1 -------------------------------------------------------------------------------- /example_database1/faiss/img_indices_10: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IntelLabs/VL-InterpreT/0f9d0bb1dad13d91ce79d9453faf81fc2b9277b9/example_database1/faiss/img_indices_10 -------------------------------------------------------------------------------- /example_database1/faiss/img_indices_11: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/IntelLabs/VL-InterpreT/0f9d0bb1dad13d91ce79d9453faf81fc2b9277b9/example_database1/faiss/img_indices_11 -------------------------------------------------------------------------------- /example_database1/faiss/img_indices_12: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IntelLabs/VL-InterpreT/0f9d0bb1dad13d91ce79d9453faf81fc2b9277b9/example_database1/faiss/img_indices_12 -------------------------------------------------------------------------------- /example_database1/faiss/img_indices_2: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IntelLabs/VL-InterpreT/0f9d0bb1dad13d91ce79d9453faf81fc2b9277b9/example_database1/faiss/img_indices_2 -------------------------------------------------------------------------------- /example_database1/faiss/img_indices_3: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IntelLabs/VL-InterpreT/0f9d0bb1dad13d91ce79d9453faf81fc2b9277b9/example_database1/faiss/img_indices_3 -------------------------------------------------------------------------------- /example_database1/faiss/img_indices_4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IntelLabs/VL-InterpreT/0f9d0bb1dad13d91ce79d9453faf81fc2b9277b9/example_database1/faiss/img_indices_4 -------------------------------------------------------------------------------- /example_database1/faiss/img_indices_5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IntelLabs/VL-InterpreT/0f9d0bb1dad13d91ce79d9453faf81fc2b9277b9/example_database1/faiss/img_indices_5 -------------------------------------------------------------------------------- /example_database1/faiss/img_indices_6: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IntelLabs/VL-InterpreT/0f9d0bb1dad13d91ce79d9453faf81fc2b9277b9/example_database1/faiss/img_indices_6 -------------------------------------------------------------------------------- /example_database1/faiss/img_indices_7: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IntelLabs/VL-InterpreT/0f9d0bb1dad13d91ce79d9453faf81fc2b9277b9/example_database1/faiss/img_indices_7 -------------------------------------------------------------------------------- /example_database1/faiss/img_indices_8: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IntelLabs/VL-InterpreT/0f9d0bb1dad13d91ce79d9453faf81fc2b9277b9/example_database1/faiss/img_indices_8 -------------------------------------------------------------------------------- /example_database1/faiss/img_indices_9: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IntelLabs/VL-InterpreT/0f9d0bb1dad13d91ce79d9453faf81fc2b9277b9/example_database1/faiss/img_indices_9 -------------------------------------------------------------------------------- /example_database1/faiss/txt_indices_0: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/IntelLabs/VL-InterpreT/0f9d0bb1dad13d91ce79d9453faf81fc2b9277b9/example_database1/faiss/txt_indices_0 -------------------------------------------------------------------------------- /example_database1/faiss/txt_indices_1: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IntelLabs/VL-InterpreT/0f9d0bb1dad13d91ce79d9453faf81fc2b9277b9/example_database1/faiss/txt_indices_1 -------------------------------------------------------------------------------- /example_database1/faiss/txt_indices_10: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IntelLabs/VL-InterpreT/0f9d0bb1dad13d91ce79d9453faf81fc2b9277b9/example_database1/faiss/txt_indices_10 -------------------------------------------------------------------------------- /example_database1/faiss/txt_indices_11: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IntelLabs/VL-InterpreT/0f9d0bb1dad13d91ce79d9453faf81fc2b9277b9/example_database1/faiss/txt_indices_11 -------------------------------------------------------------------------------- /example_database1/faiss/txt_indices_12: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IntelLabs/VL-InterpreT/0f9d0bb1dad13d91ce79d9453faf81fc2b9277b9/example_database1/faiss/txt_indices_12 -------------------------------------------------------------------------------- /example_database1/faiss/txt_indices_2: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IntelLabs/VL-InterpreT/0f9d0bb1dad13d91ce79d9453faf81fc2b9277b9/example_database1/faiss/txt_indices_2 -------------------------------------------------------------------------------- /example_database1/faiss/txt_indices_3: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IntelLabs/VL-InterpreT/0f9d0bb1dad13d91ce79d9453faf81fc2b9277b9/example_database1/faiss/txt_indices_3 -------------------------------------------------------------------------------- /example_database1/faiss/txt_indices_4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IntelLabs/VL-InterpreT/0f9d0bb1dad13d91ce79d9453faf81fc2b9277b9/example_database1/faiss/txt_indices_4 -------------------------------------------------------------------------------- /example_database1/faiss/txt_indices_5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IntelLabs/VL-InterpreT/0f9d0bb1dad13d91ce79d9453faf81fc2b9277b9/example_database1/faiss/txt_indices_5 -------------------------------------------------------------------------------- /example_database1/faiss/txt_indices_6: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IntelLabs/VL-InterpreT/0f9d0bb1dad13d91ce79d9453faf81fc2b9277b9/example_database1/faiss/txt_indices_6 -------------------------------------------------------------------------------- /example_database1/faiss/txt_indices_7: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IntelLabs/VL-InterpreT/0f9d0bb1dad13d91ce79d9453faf81fc2b9277b9/example_database1/faiss/txt_indices_7 
-------------------------------------------------------------------------------- /example_database1/faiss/txt_indices_8: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IntelLabs/VL-InterpreT/0f9d0bb1dad13d91ce79d9453faf81fc2b9277b9/example_database1/faiss/txt_indices_8 -------------------------------------------------------------------------------- /example_database1/faiss/txt_indices_9: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IntelLabs/VL-InterpreT/0f9d0bb1dad13d91ce79d9453faf81fc2b9277b9/example_database1/faiss/txt_indices_9 -------------------------------------------------------------------------------- /example_database1/lock.mdb: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IntelLabs/VL-InterpreT/0f9d0bb1dad13d91ce79d9453faf81fc2b9277b9/example_database1/lock.mdb -------------------------------------------------------------------------------- /example_database2/data.mdb: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IntelLabs/VL-InterpreT/0f9d0bb1dad13d91ce79d9453faf81fc2b9277b9/example_database2/data.mdb -------------------------------------------------------------------------------- /example_database2/faiss/img_indices_0: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IntelLabs/VL-InterpreT/0f9d0bb1dad13d91ce79d9453faf81fc2b9277b9/example_database2/faiss/img_indices_0 -------------------------------------------------------------------------------- /example_database2/faiss/img_indices_1: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IntelLabs/VL-InterpreT/0f9d0bb1dad13d91ce79d9453faf81fc2b9277b9/example_database2/faiss/img_indices_1 -------------------------------------------------------------------------------- /example_database2/faiss/img_indices_10: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IntelLabs/VL-InterpreT/0f9d0bb1dad13d91ce79d9453faf81fc2b9277b9/example_database2/faiss/img_indices_10 -------------------------------------------------------------------------------- /example_database2/faiss/img_indices_11: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IntelLabs/VL-InterpreT/0f9d0bb1dad13d91ce79d9453faf81fc2b9277b9/example_database2/faiss/img_indices_11 -------------------------------------------------------------------------------- /example_database2/faiss/img_indices_12: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IntelLabs/VL-InterpreT/0f9d0bb1dad13d91ce79d9453faf81fc2b9277b9/example_database2/faiss/img_indices_12 -------------------------------------------------------------------------------- /example_database2/faiss/img_indices_2: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IntelLabs/VL-InterpreT/0f9d0bb1dad13d91ce79d9453faf81fc2b9277b9/example_database2/faiss/img_indices_2 -------------------------------------------------------------------------------- /example_database2/faiss/img_indices_3: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/IntelLabs/VL-InterpreT/0f9d0bb1dad13d91ce79d9453faf81fc2b9277b9/example_database2/faiss/img_indices_3 -------------------------------------------------------------------------------- /example_database2/faiss/img_indices_4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IntelLabs/VL-InterpreT/0f9d0bb1dad13d91ce79d9453faf81fc2b9277b9/example_database2/faiss/img_indices_4 -------------------------------------------------------------------------------- /example_database2/faiss/img_indices_5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IntelLabs/VL-InterpreT/0f9d0bb1dad13d91ce79d9453faf81fc2b9277b9/example_database2/faiss/img_indices_5 -------------------------------------------------------------------------------- /example_database2/faiss/img_indices_6: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IntelLabs/VL-InterpreT/0f9d0bb1dad13d91ce79d9453faf81fc2b9277b9/example_database2/faiss/img_indices_6 -------------------------------------------------------------------------------- /example_database2/faiss/img_indices_7: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IntelLabs/VL-InterpreT/0f9d0bb1dad13d91ce79d9453faf81fc2b9277b9/example_database2/faiss/img_indices_7 -------------------------------------------------------------------------------- /example_database2/faiss/img_indices_8: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IntelLabs/VL-InterpreT/0f9d0bb1dad13d91ce79d9453faf81fc2b9277b9/example_database2/faiss/img_indices_8 -------------------------------------------------------------------------------- /example_database2/faiss/img_indices_9: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IntelLabs/VL-InterpreT/0f9d0bb1dad13d91ce79d9453faf81fc2b9277b9/example_database2/faiss/img_indices_9 -------------------------------------------------------------------------------- /example_database2/faiss/txt_indices_0: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IntelLabs/VL-InterpreT/0f9d0bb1dad13d91ce79d9453faf81fc2b9277b9/example_database2/faiss/txt_indices_0 -------------------------------------------------------------------------------- /example_database2/faiss/txt_indices_1: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IntelLabs/VL-InterpreT/0f9d0bb1dad13d91ce79d9453faf81fc2b9277b9/example_database2/faiss/txt_indices_1 -------------------------------------------------------------------------------- /example_database2/faiss/txt_indices_10: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IntelLabs/VL-InterpreT/0f9d0bb1dad13d91ce79d9453faf81fc2b9277b9/example_database2/faiss/txt_indices_10 -------------------------------------------------------------------------------- /example_database2/faiss/txt_indices_11: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IntelLabs/VL-InterpreT/0f9d0bb1dad13d91ce79d9453faf81fc2b9277b9/example_database2/faiss/txt_indices_11 
-------------------------------------------------------------------------------- /example_database2/faiss/txt_indices_12: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IntelLabs/VL-InterpreT/0f9d0bb1dad13d91ce79d9453faf81fc2b9277b9/example_database2/faiss/txt_indices_12 -------------------------------------------------------------------------------- /example_database2/faiss/txt_indices_2: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IntelLabs/VL-InterpreT/0f9d0bb1dad13d91ce79d9453faf81fc2b9277b9/example_database2/faiss/txt_indices_2 -------------------------------------------------------------------------------- /example_database2/faiss/txt_indices_3: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IntelLabs/VL-InterpreT/0f9d0bb1dad13d91ce79d9453faf81fc2b9277b9/example_database2/faiss/txt_indices_3 -------------------------------------------------------------------------------- /example_database2/faiss/txt_indices_4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IntelLabs/VL-InterpreT/0f9d0bb1dad13d91ce79d9453faf81fc2b9277b9/example_database2/faiss/txt_indices_4 -------------------------------------------------------------------------------- /example_database2/faiss/txt_indices_5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IntelLabs/VL-InterpreT/0f9d0bb1dad13d91ce79d9453faf81fc2b9277b9/example_database2/faiss/txt_indices_5 -------------------------------------------------------------------------------- /example_database2/faiss/txt_indices_6: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IntelLabs/VL-InterpreT/0f9d0bb1dad13d91ce79d9453faf81fc2b9277b9/example_database2/faiss/txt_indices_6 -------------------------------------------------------------------------------- /example_database2/faiss/txt_indices_7: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IntelLabs/VL-InterpreT/0f9d0bb1dad13d91ce79d9453faf81fc2b9277b9/example_database2/faiss/txt_indices_7 -------------------------------------------------------------------------------- /example_database2/faiss/txt_indices_8: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IntelLabs/VL-InterpreT/0f9d0bb1dad13d91ce79d9453faf81fc2b9277b9/example_database2/faiss/txt_indices_8 -------------------------------------------------------------------------------- /example_database2/faiss/txt_indices_9: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IntelLabs/VL-InterpreT/0f9d0bb1dad13d91ce79d9453faf81fc2b9277b9/example_database2/faiss/txt_indices_9 -------------------------------------------------------------------------------- /example_database2/lock.mdb: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IntelLabs/VL-InterpreT/0f9d0bb1dad13d91ce79d9453faf81fc2b9277b9/example_database2/lock.mdb -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | 
dash>=2.0.0 2 | plotly>=5.4.0 3 | dash-daq>=0.5.0 4 | matplotlib>=3.3.3 5 | termcolor>=1.1.0 6 | numpy>=1.18.4 7 | lmdb>=1.1.1 8 | lz4>=3.1.3 9 | msgpack>=1.0.2 10 | msgpack_numpy>=0.4.7.1 11 | faiss-cpu>=1.7.1 12 | scikit-learn>=0.23.2 13 | tqdm 14 | -------------------------------------------------------------------------------- /run_app.py: -------------------------------------------------------------------------------- 1 | '''Driver script to start and run VL-InterpreT.''' 2 | 3 | import socket 4 | import argparse 5 | from importlib import import_module 6 | 7 | from app import app_configuration 8 | 9 | 10 | def build_model(model_name, *args, **kwargs): 11 | model = getattr(import_module(f'app.database.models.{model_name.lower()}'), model_name.title()) 12 | return model(*args, **kwargs) 13 | 14 | 15 | def main(port, db_path, model_name=None, model_params=None): 16 | print(f'Running on port {port} with database {db_path}' + 17 | (f' and {model_name} model' if model_name else '')) 18 | 19 | model = build_model(model_name, *model_params) if model_name else None 20 | app = app_configuration.configure_app(db_path, model=model) 21 | hostname = socket.gethostname() 22 | ip_address = socket.gethostbyname(hostname) 23 | app_configuration.print_page_link(hostname, port) 24 | application = app.server 25 | application.run(debug=False, threaded=True, host=ip_address, port=int(port)) 26 | 27 | 28 | if __name__ == '__main__': 29 | parser = argparse.ArgumentParser() 30 | parser.add_argument('-p', '--port', required=True, 31 | help='The port number to run this app on.') 32 | parser.add_argument('-d', '--database', required=True, 33 | help='The path to a database. See app/database/db_example.py') 34 | parser.add_argument('-m', '--model', nargs='+', required=False, 35 | help='Your model name to run, followed by the arguments ' 36 | 'that need to be passed to start the model.') 37 | 38 | args = parser.parse_args() 39 | if args.model: 40 | main(args.port, args.database, args.model[0], args.model[1:]) 41 | else: 42 | main(args.port, args.database) 43 | --------------------------------------------------------------------------------
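
Putting the pieces above together: vl_model.py defines the VL_Model interface, db_example.py documents the per-example data format, and run_app.py loads a model class by module name, so plugging a new model into VL-InterpreT amounts to implementing data_setup and returning a dict in that format. The sketch below is illustrative only and is not a file in this repository; the module/class name mymodel/Mymodel, the whitespace tokenization, the placeholder token grid, and the random arrays standing in for real model outputs are assumptions that simply follow the conventions documented in db_example.py and vl_model.py.

import numpy as np
from app.database.models.vl_model import VL_Model


class Mymodel(VL_Model):
    '''Illustrative sketch (not part of this repo). Save as app/database/models/mymodel.py
    and run: python run_app.py -p 6006 -d your_database -m mymodel <your model args>
    '''
    def __init__(self, *model_args):
        # A real implementation would load its checkpoint and tokenizer here.
        self.n_layers, self.n_heads, self.hidden_size = 12, 12, 768

    def data_setup(self, example_id, image_location, input_text):
        # Naive whitespace "tokenization", just to produce tokens in the expected layout
        txt_tokens = ['[CLS]'] + (input_text.split() if input_text else []) + ['[SEP]']
        if image_location:
            image = self.fetch_image(image_location)   # helper inherited from VL_Model
            grid_w, grid_h = 2, 3                      # placeholder image-token grid
            img_coords = [(x, y) for y in range(grid_h) for x in range(grid_w)]
        else:
            image, img_coords = np.array([]), []
        n_tokens = len(txt_tokens) + len(img_coords)
        return {
            'ex_id': example_id,
            'image': image,
            'tokens': txt_tokens + [f'IMG_{i}' for i in range(len(img_coords))],
            'txt_len': len(txt_tokens),
            'img_coords': img_coords,
            # Replace these random arrays with the model's real attention weights and
            # hidden states, shaped as documented in app/database/db_example.py
            'attention': np.random.rand(self.n_layers, self.n_heads, n_tokens, n_tokens),
            'hidden_states': np.random.rand(self.n_layers + 1, n_tokens, self.hidden_size),
        }

Following the naming convention described in vl_model.py, such a class saved as mymodel.py would be discovered by run_app.py via 'mymodel'.title() == 'Mymodel', mirroring how the KD-VLP example in kdvlp.py is launched.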