├── LICENSE
├── README.md
├── app
│   ├── __init__.py
│   ├── app_configuration.py
│   ├── app_layout.py
│   ├── assets
│   │   ├── intel_ai_logo.jpg
│   │   └── stylesheet-oil-and-gas.css
│   ├── database
│   │   ├── __init__.py
│   │   ├── database.py
│   │   ├── db_analyzer.py
│   │   ├── db_example.py
│   │   └── models
│   │       ├── kdvlp.py
│   │       └── vl_model.py
│   └── plot_func.py
├── assets
│   ├── logo.png
│   ├── screencast.png
│   ├── screenshot_add_ex.png
│   └── screenshot_app_launch.png
├── example_database1
│   ├── data.mdb
│   ├── faiss
│   │   ├── img_indices_0
│   │   ├── img_indices_1
│   │   ├── img_indices_10
│   │   ├── img_indices_11
│   │   ├── img_indices_12
│   │   ├── img_indices_2
│   │   ├── img_indices_3
│   │   ├── img_indices_4
│   │   ├── img_indices_5
│   │   ├── img_indices_6
│   │   ├── img_indices_7
│   │   ├── img_indices_8
│   │   ├── img_indices_9
│   │   ├── txt_indices_0
│   │   ├── txt_indices_1
│   │   ├── txt_indices_10
│   │   ├── txt_indices_11
│   │   ├── txt_indices_12
│   │   ├── txt_indices_2
│   │   ├── txt_indices_3
│   │   ├── txt_indices_4
│   │   ├── txt_indices_5
│   │   ├── txt_indices_6
│   │   ├── txt_indices_7
│   │   ├── txt_indices_8
│   │   └── txt_indices_9
│   └── lock.mdb
├── example_database2
│   ├── data.mdb
│   ├── faiss
│   │   ├── img_indices_0
│   │   ├── img_indices_1
│   │   ├── img_indices_10
│   │   ├── img_indices_11
│   │   ├── img_indices_12
│   │   ├── img_indices_2
│   │   ├── img_indices_3
│   │   ├── img_indices_4
│   │   ├── img_indices_5
│   │   ├── img_indices_6
│   │   ├── img_indices_7
│   │   ├── img_indices_8
│   │   ├── img_indices_9
│   │   ├── txt_indices_0
│   │   ├── txt_indices_1
│   │   ├── txt_indices_10
│   │   ├── txt_indices_11
│   │   ├── txt_indices_12
│   │   ├── txt_indices_2
│   │   ├── txt_indices_3
│   │   ├── txt_indices_4
│   │   ├── txt_indices_5
│   │   ├── txt_indices_6
│   │   ├── txt_indices_7
│   │   ├── txt_indices_8
│   │   └── txt_indices_9
│   └── lock.mdb
├── requirements.txt
└── run_app.py
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2022 Intel Labs
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # DISCONTINUATION OF PROJECT #
2 | This project will no longer be maintained by Intel.
3 | Intel has ceased development and contributions including, but not limited to, maintenance, bug fixes, new releases, or updates, to this project.
4 | Intel no longer accepts patches to this project.
5 | If you have an ongoing need to use this project, are interested in independently developing it, or would like to maintain patches for the open source software community, please create your own fork of this project.
6 |
7 |
8 |
9 |
10 |
11 |
12 | VL-InterpreT provides interactive visualizations for interpreting the attentions and hidden representations in vision-language transformers. It is a task-agnostic, integrated tool that:
13 | - Tracks a variety of statistics in attention heads throughout all layers for both vision and language components
14 | - Visualizes cross-modal and intra-modal attentions through easily readable heatmaps
15 | - Plots the hidden representations of vision and language tokens as they pass through the transformer layers.
16 |
17 | # Paper
18 | Our paper, *VL-InterpreT: An Interactive Visualization Tool for Interpreting Vision-Language Transformers*, won the Best Demo Award at CVPR 2022.
19 |
20 | # Screencast Video
21 | This video provides an overview of VL-InterpreT and demonstrates a few interesting examples.
22 |
23 |
24 |
25 |
26 |
27 | # Live Demo
28 | A live demo of the app (same as in the screencast video) is available here.
29 |
30 | Please watch the screencast video to get a sense of how to navigate the app. This demo contains 100 examples from the Visual Commonsense Reasoning task and shows the attention and hidden representations from the KD-VLP model.
31 |
32 | # Setup and Usage
33 | You may run VL-InterpreT together with a model of your own choice (see [*Set up a live model*](#set-up-a-live-model)), and/or with a database that contains data extracted from a model.
34 |
35 | To run VL-InterpreT with our example databases, please first clone this repository and install the dependencies. For example:
36 | ```bash
37 | git clone https://github.com/IntelLabs/VL-InterpreT.git
38 | # create and activate your virtual environment if needed, then:
39 | cd VL-InterpreT
40 | pip install -r requirements.txt
41 | ```
42 |
43 | Then you can run VL-InterpreT (replace 6006 with any port number you prefer):
44 | ```bash
45 | python run_app.py --port 6006 --database example_database2
46 | # alternatively:
47 | python run_app.py -p 6006 -d example_database2
48 | ```
49 |
50 | We have included two example databases in this repository. `example_database1` contains grey "images" and randomly generated data, and `example_database2` contains one example image+text pair that was processed by the KD-VLP model.
51 |
52 | Once the app is running, it will print the address it is being served at. Open that address in your browser to use VL-InterpreT:
53 |
54 |
55 |
56 |
57 |
58 |
59 | ## Set up a database
60 | You may extract data from a transformer in a specific format and then use VL-InterpreT to visualize it interactively.
61 |
62 | To set up such a database, please see [db_example.py](), which is an example script that creates a database (i.e., `example_database1`) with randomly generated data. To prepare the data from your own transformer, for each image+text pair you should mainly:
63 | - Extract cross-attention weights and hidden representation vectors from your transformer
64 | - For example, you may extract them from [a Huggingface model](https://huggingface.co/transformers/v3.0.2/model_doc/bert.html#transformers.BertModel.forward) by specifying `output_attentions=True` and `output_hidden_states=True` (see the sketch after this list).
65 | - Get the original input image as an array, and input tokens as a list of strings (text tokens followed by image tokens, where image tokens can be named at your discretion, e.g., "img_0", "img_1", etc.).
66 | - For the input image tokens, specify how they correspond to positions in the original image, assuming the top left corner has coordinates (0, 0).
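For instance, here is a minimal sketch of such an extraction (this snippet is illustrative and not part of this repository; it assumes a recent version of the `transformers` library and uses a text-only BERT model as a stand-in, so a vision-language model would additionally produce image tokens and their coordinates):
```python
import numpy as np
import torch
from transformers import BertTokenizer, BertModel

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertModel.from_pretrained('bert-base-uncased')

inputs = tokenizer('A dog plays in the park.', return_tensors='pt')
with torch.no_grad():
    outputs = model(**inputs, output_attentions=True, output_hidden_states=True)

# outputs.attentions: one (batch, n_heads, n_tokens, n_tokens) tensor per layer
# outputs.hidden_states: one (batch, n_tokens, hidden_size) tensor per layer, plus the input embeddings
attention = np.stack([a[0].numpy() for a in outputs.attentions])         # (n_layers, n_heads, n_tokens, n_tokens)
hidden_states = np.stack([h[0].numpy() for h in outputs.hidden_states])  # (n_layers + 1, n_tokens, hidden_size)
tokens = tokenizer.convert_ids_to_tokens(inputs['input_ids'][0].tolist())
```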
67 |
68 | Please refer to [db_example.py]() for more details. You may also look into our `example_database1` and `example_database2` for example databases that have been preprocessed.
69 |
70 | Once you have prepared the data as specified, organize it in the following format (again, see [db_example.py]() for the specifics):
71 | ```python
72 | data = [
73 | { # first example
74 | 'ex_id': 0,
75 | 'image': np.array([]),
76 | 'tokens': [],
77 | 'txt_len': 0,
78 | 'img_coords': [],
79 | 'attention': np.array([]),
80 | 'hidden_states': np.array([])
81 | },
82 | { # second example
83 | 'ex_id': 1,
84 | 'image': np.array([]),
85 | 'tokens': [],
86 | 'txt_len': 0,
87 | 'img_coords': [],
88 | 'attention': np.array([]),
89 | 'hidden_states': np.array([])
90 | },
91 | # ...
92 | ]
93 | ```
94 | Then run the following code from the `app/database` directory to create and preprocess your database:
95 | ```python
96 | import pickle
97 | from database import VliLmdb # this is in app/database/database.py
98 |
99 | # create a database
100 | db = VliLmdb(db_dir='path_to_your_database', read_only=False)
101 | # add data
102 | for ex_data in data:
103 | db[str(ex_data['ex_id'])] = pickle.dumps(ex_data, protocol=pickle.HIGHEST_PROTOCOL)
104 | # preprocess the database
105 | db.preprocess()
106 | ```
107 | Now you can visualize your data with VL-InterpreT:
108 | ```bash
109 | python run_app.py -p 6006 -d path_to_your_database
110 | ```
111 |
112 | ## Set up a live model
113 | You may also run a live transformer model together with VL-InterpreT. This way, the `Add example` functionality becomes available in the web app -- users can add image+text pairs for the transformer to process in real time and for VL-InterpreT to visualize:
114 |
115 |
116 |
117 |
118 |
119 | To add a model, you will need to define your own model class that inherits from the [VL_Model]() base class, and then implement a `data_setup` function in this class. This function should run a forward pass with your model given an input image+text pair, and return data (e.g., attention weights, hidden state vectors, etc.) in the same format specified in [*Set up a database*](#set-up-a-database) and in [db_example.py](). **Please see [vl_model.py]() for more details**. You may also refer to [kdvlp.py]() for an example model class.
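As an illustration, a minimal model class might look roughly like the sketch below. The file name, class name, and the `load_my_model`/`run_forward_pass` helpers are hypothetical placeholders for your own code; the returned dictionary follows the format described in [*Set up a database*](#set-up-a-database):
```python
# app/database/models/mymodel.py (hypothetical sketch)
from app.database.models.vl_model import VL_Model


class Mymodel(VL_Model):
    def __init__(self, checkpoint_path):
        # parameters are supplied on the command line after the model name
        self.model = self.load_my_model(checkpoint_path)  # placeholder: your own loading code

    def data_setup(self, ex_id, image_location, input_text):
        # placeholder: run a forward pass with your model and collect its internals
        image, tokens, txt_len, img_coords, attention, hidden_states = \
            self.run_forward_pass(image_location, input_text)
        return {
            'ex_id': ex_id,
            'image': image,                  # original RGB image as an array
            'tokens': tokens,                # text tokens followed by image tokens
            'txt_len': txt_len,              # number of text tokens
            'img_coords': img_coords,        # (x, y) grid coordinates of image tokens
            'attention': attention,          # (n_layers, n_heads, n_tokens, n_tokens)
            'hidden_states': hidden_states,  # (n_layers + 1, n_tokens, hidden_size)
        }
```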
120 |
121 | Additionally, please follow the naming pattern below for your script and class, so that VL-InterpreT can load your model by name:
122 | - Create a new python script in `app/database/models` for your model class, and name it in all lowercase (e.g., `app/database/models/yourmodelname.py`)
123 | - Name your model class in title case, e.g., `class Yourmodelname`. This class name should be the result of calling `'yourmodelname'.title()`, where `'yourmodelname.py'` is the name of your python script.
124 | - For example, our KD-VLP model class is defined in `app/database/models/kdvlp.py`, and named `class Kdvlp`.
125 |
126 | Once your implementation is complete, you can run VL-InterpreT with your model using:
127 | ```bash
128 | python run_app.py --port 6006 --database example_database2 --model yourmodelname your_model_parameters
129 | # alternatively:
130 | python run_app.py -p 6006 -d example_database2 -m yourmodelname your_model_parameters
131 | ```
132 |
133 | Note that the `__init__` method of your model class may take an arbitrary number of parameters, and you may specify these parameters when running VL-InterpreT by putting them after the name of your model (i.e., replacing `your_model_parameters`).
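For example, if your `__init__` were defined with the hypothetical signature `def __init__(self, device, checkpoint_path)`, you could launch the app with:
```bash
python run_app.py -p 6006 -d example_database2 -m yourmodelname cuda:0 /path/to/checkpoint.pt
```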
134 |
--------------------------------------------------------------------------------
/app/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IntelLabs/VL-InterpreT/0f9d0bb1dad13d91ce79d9453faf81fc2b9277b9/app/__init__.py
--------------------------------------------------------------------------------
/app/app_configuration.py:
--------------------------------------------------------------------------------
1 | ''' Module containing the main Dash app code.
2 |
3 | The main function in this module is configure_app(), which starts
4 | the Dash app on a Flask server and configures it with callbacks.
5 | '''
6 |
7 | import os
8 | import copy
9 | import dash
10 | from flask import Flask
11 | from dash.dependencies import Input, Output, State
12 | from dash.exceptions import PreventUpdate
13 | from termcolor import cprint
14 | from .app_layout import get_layout
15 | from .plot_func import *
16 | from app.database.db_analyzer import VliDataBaseAnalyzer
17 |
18 |
19 | def configure_app(db_dir, model=None):
20 | db = VliDataBaseAnalyzer(db_dir, read_only=(not model))
21 | n_layers = db['n_layers']
22 | app = start_app(db['n_examples'], n_layers)
23 |
24 | @app.callback(
25 | Output('add_ex_div', 'style'),
26 | Output('add_ex_toggle', 'style'),
27 | Input('add_ex_toggle', 'n_clicks'),
28 | )
29 | def toggle_add_example(n_clicks):
30 | if model is None:
31 | return {'display': 'none'}, {'display': 'none'}
32 | if not n_clicks:
33 | raise PreventUpdate
34 | if n_clicks % 2 == 0:
35 | return {'display': 'none'}, {'padding': '0 15px'}
36 | return {}, {'padding': '0 15px', 'background': 'lightgrey'}
37 |
38 |
39 | @app.callback(
40 | Output('ex_selector', 'max'),
41 | Output('ex_selector', 'value'),
42 | Input('add_ex_btn', 'n_clicks'),
43 | State('new_ex_img', 'value'),
44 | State('new_ex_txt', 'value'),
45 | )
46 | def add_example(n_clicks, image_in, text_in):
47 | if not n_clicks: # app initialization
48 | return db['n_examples'] - 1, '0'
49 | ex_id = db['n_examples']
50 | data = model.data_setup(ex_id, image_in, text_in)
51 | db.add_example(ex_id, data)
52 | return ex_id, ex_id
53 |
54 |
55 | @app.callback(
56 | Output('ex_text', 'figure'),
57 | Output('ex_txt_len', 'data'),
58 | Output('ex_sm_text', 'figure'),
59 | Output('ex_img_attn_overlay', 'figure'),
60 | [Input('ex_selector', 'value'),
61 | Input('selected_head_layer', 'data'),
62 | Input('selected_text_token', 'data'),
63 | Input('selected_img_tokens', 'data'),
64 | Input('map_smooth_checkbox', 'value')],
65 | [State('ex_txt_len', 'data')]
66 | )
67 | def display_ex_text_and_img_overlay(ex_id, head_layer, text_token, selected_img_token, smooth,
68 | txt_len):
69 | trigger = get_input_trigger(dash.callback_context)
70 | layer, head = head_layer['layer'], head_layer['head']
71 | new_state = (layer < 0 or head < 0) or (not txt_len)
72 | if (not selected_img_token) and \
73 | (not text_token or text_token['sentence'] < 0 or text_token['word'] < 0):
74 | new_state = True
75 |
76 | if new_state or trigger == 'ex_selector':
77 | text_fig, txt_len = plot_text(db, ex_id)
78 | text_attn_fig = text_fig
79 | img_attn_overlay = get_empty_fig()
80 | return text_fig, txt_len, text_attn_fig, img_attn_overlay
81 |
82 | img_token = set(map(tuple, selected_img_token))
83 | text_fig = highlight_txt_token(db, ex_id, text_token)
84 | if trigger == 'selected_text_token':
85 | img_token = None
86 | elif trigger == 'selected_img_tokens':
87 | text_token = None
88 | text_fig, txt_len = plot_text(db, ex_id)
89 |
90 | text_attn_fig = plot_attn_from_txt(db, ex_id, layer, head, text_token, img_token)
91 | img_attn_overlay = plot_attn_from_img(db, ex_id, layer, head, word_ids=text_token,
92 | img_coords=img_token, smooth=smooth)
93 |
94 | return text_fig, txt_len, text_attn_fig, img_attn_overlay
95 |
96 |
97 | @app.callback(
98 | Output('ex_img_token_overlay', 'figure'),
99 | [Input('ex_selector', 'value'),
100 | Input('selected_img_tokens', 'data')]
101 | )
102 | def display_img_token_selection(ex_id, selected_img_tokens):
103 | if not ex_id:
104 | return get_empty_fig()
105 | trigger = get_input_trigger(dash.callback_context)
106 | if trigger == 'ex_selector':
107 | token_overlay = get_overlay_fig(db, ex_id)
108 | else:
109 | token_overlay = highlight_img_grid_selection(db, ex_id, selected_img_tokens)
110 | return token_overlay
111 |
112 |
113 | @app.callback(
114 | Output('attn2img', 'style'),
115 | Output('attn2txt', 'style'),
116 | Output('attn2img_label', 'style'),
117 | Output('attn2txt_label', 'style'),
118 | Output('play_btn', 'style'),
119 | [Input('attn2img_toggle', 'value')]
120 | )
121 | def toggle_attn_to(attn_toggle):
122 | hidden_style = {'display': 'none', 'z-index': '-1'}
123 | if attn_toggle: # hide image
124 | return hidden_style, {}, {'color': 'lightgrey'}, {}, {}
125 | else: # hide text
126 | return {}, hidden_style, {}, {'color': 'lightgrey'}, hidden_style
127 |
128 |
129 | @app.callback(
130 | Output('ex_image', 'figure'),
131 | Output('ex_sm_image', 'figure'),
132 | [Input('ex_selector', 'value')]
133 | )
134 | def display_ex_image(ex_id):
135 | image_fig = plot_image(db, ex_id)
136 | sm_image_fig = copy.deepcopy(image_fig)
137 | sm_image_fig.update_layout(height=256)
138 | return image_fig, sm_image_fig
139 |
140 |
141 | @app.callback(
142 | Output('selected_img_tokens', 'data'),
143 | [Input('ex_img_token_overlay', 'clickData'),
144 | Input('ex_img_token_overlay', 'selectedData'),
145 | Input('selected_text_token', 'data'),
146 | Input('attn2img_toggle', 'value')], # clear if toggle attention to
147 | [State('selected_img_tokens', 'data')]
148 | )
149 | def save_img_selection(click_data, selected_data, txt_token, toggle, points):
150 | if points is None:
151 | return []
152 | if (not click_data) and (not selected_data):
153 | return []
154 | trigger = get_input_trigger_full(dash.callback_context).split('.')[1]
155 | points = set(map(tuple, points))
156 | if trigger == 'clickData':
157 | x = click_data['points'][0]['pointNumber']
158 | y = click_data['points'][0]['curveNumber']
159 | if (x, y) in points:
160 | points.remove((x, y))
161 | else:
162 | # points.add((x, y))
163 | points = {(x, y)} # TODO: enable a set of multiple points and show average
164 | elif trigger == 'selectedData':
165 | if selected_data:
166 | points.update([(pt['pointNumber'], pt['curveNumber']) for pt in selected_data['points']])
167 | else:
168 | return []
169 | else:
170 | return []
171 | return list(points) # list is JSON serializable
172 |
173 |
174 | @app.callback(
175 | Output('selected_text_token', 'data'),
176 | [Input('ex_text', 'clickData'), # from mouse click
177 | Input('movie_progress', 'data'), # from movie slider update
178 | Input('ex_selector', 'value'), # clear if new example selected
179 | Input('attn2img_toggle', 'value')], # clear if toggle attention to
180 | [State('auto_stepper', 'disabled')]
181 | )
182 | def save_text_click(click_data, movie_progress, ex_id, toggle, movie_disabled):
183 | trigger = get_input_trigger(dash.callback_context)
184 | if trigger == 'ex_text' and click_data:
185 | sentence_id = click_data['points'][0]['curveNumber']
186 | word_id = click_data['points'][0]['pointIndex']
187 | return {'sentence': sentence_id, 'word': word_id}
188 | if trigger == 'movie_progress':
189 | if movie_disabled or movie_progress == -1: # no change
190 | raise PreventUpdate()
191 | # convert index to sentence/word
192 | sentence, word = 0, 0
193 | txt_len = db[ex_id]['txt_len']
194 | for i, token in enumerate(db[ex_id]['tokens'][:txt_len]):
195 | if i == movie_progress:
196 | break
197 | if token == '[SEP]':
198 | sentence += 1
199 | word = -1
200 | word += 1
201 | return {'sentence': sentence, 'word': word}
202 | return None
203 |
204 |
205 | @app.callback(
206 | Output('head_summary', 'figure'),
207 | [Input('ex_selector', 'value'),
208 | Input('selected_head_layer', 'data')],
209 | [State('head_summary', 'figure')]
210 | )
211 | def display_head_summary(ex_id, head_layer, fig):
212 | layer, head = head_layer['layer'], head_layer['head']
213 | attn_type_index = get_active_figure_data(fig)
214 | data = plot_head_summary(db, ex_id, attn_type_index, layer, head)
215 | return data
216 |
217 |
218 | @app.callback(
219 | Output('selected_head_layer', 'data'),
220 | [Input('head_summary', 'clickData')]
221 | )
222 | def save_head_summary_click(click_data):
223 | if click_data:
224 | x, y = get_click_coords(click_data)
225 | return {'layer': y, 'head': x}
226 | return {'layer': 10, 'head': 7}
227 |
228 |
229 | @app.callback(
230 | Output('tsne_title', 'children'),
231 | Output('attn_title', 'children'),
232 | Output('accuracy', 'children'),
233 | [Input('layer_slider', 'value'),
234 | Input('selected_head_layer', 'data'),
235 | Input('ex_selector', 'value')]
236 | )
237 | def update_titles(tsne_layer, selected_head_layer, ex_id):
238 | if ('accuracy' not in db[ex_id]) or (db[ex_id]['accuracy'] is None):
239 | acc_text = 'unknown'
240 | else:
241 | acc_text = 'correct' if db[ex_id]['accuracy'] else 'incorrect'
242 | if tsne_layer == n_layers:
243 | tsne_title = 't-SNE Embeddings before Layer #1'
244 | else:
245 | tsne_title = f't-SNE Embeddings after Layer #{n_layers - tsne_layer}'
246 | layer, head = selected_head_layer['layer'], selected_head_layer['head']
247 | head = 'Average' if int(head) != head else f'#{head + 1}'
248 | att_title = f'Attention in Layer #{layer+1}, Head {head}'
249 | return tsne_title, att_title, acc_text
250 |
251 |
252 | # movie callbacks
253 |
254 | @app.callback(
255 | Output('movie_progress', 'data'),
256 | [Input('auto_stepper', 'n_intervals'),
257 | Input('auto_stepper', 'disabled'),
258 | Input('ex_selector', 'value')],
259 | [State('movie_progress', 'data'),
260 | State('ex_txt_len', 'data')]
261 | )
262 | def stepper_advance(n_intervals, disable, ex_id, movie_progress, txt_len):
263 | trigger = get_input_trigger_full(dash.callback_context)
264 | if trigger.startswith('ex_selector') or \
265 | (trigger == 'auto_stepper.disabled' and disable):
266 | return -1
267 | if n_intervals:
268 | return (movie_progress + 1) % txt_len
269 | return -1
270 |
271 | @app.callback(
272 | Output('play_btn', 'children'),
273 | Output('play_btn', 'n_clicks'),
274 | Output('auto_stepper', 'disabled'),
275 | [Input('play_btn', 'n_clicks'),
276 | Input('ex_selector', 'value'),
277 | Input('attn2img_toggle', 'value')],
278 | [State('auto_stepper', 'n_intervals')]
279 | )
280 | def play_btn_click(n_clicks, ex_id, attn_to_toggle, n_intervals):
281 | trigger = get_input_trigger(dash.callback_context)
282 | if (trigger == 'play_btn') and (n_intervals is not None) and \
283 | (n_clicks is not None) and n_clicks % 2 == 1:
284 | return '\u275A\u275A', n_clicks, False
285 | elif (trigger == 'ex_selector') or (trigger == 'attn2img_toggle'):
286 | return '\u25B6', 0, True
287 | return '\u25B6', n_clicks, True
288 |
289 |
290 | # TSNE
291 |
292 | @app.callback(
293 | Output('tsne_map', 'figure'),
294 | [Input('ex_selector', 'value'),
295 | Input('layer_slider', 'value'),
296 | Input('selected_text_token', 'data'),
297 | Input('selected_img_tokens', 'data')],
298 | [State('auto_stepper', 'disabled')]
299 | )
300 | def plot_tsne(ex_id, layer, text_tokens_id, img_tokens, movie_paused):
301 | if not movie_paused:
302 | raise PreventUpdate
303 | figure = show_tsne(db, ex_id, n_layers - layer, text_tokens_id, img_tokens)
304 | return figure
305 |
306 | @app.callback(
307 | Output('hover_out', 'show'),
308 | Output('hover_out', 'bbox'),
309 | Output('hover_out', 'children'),
310 | Input('tsne_map', 'hoverData'),
311 | Input('layer_slider', 'value')
312 | )
313 | def display_hover(hover_data, layer):
314 | if hover_data is None:
315 | return False, dash.no_update, dash.no_update
316 | layer = n_layers - layer
317 | point = hover_data['points'][0]
318 | ex_id, token_id = point['hovertext'].split(',')
319 | tooltip_children = plot_tooltip_content(db, ex_id, int(token_id))
320 | return True, point['bbox'], tooltip_children
321 |
322 |
323 | return app
324 |
325 |
326 | def start_app(n_examples, n_layers):
327 | print('Starting server')
328 | server = Flask(__name__)
329 | server.secret_key = os.environ.get('secret_key', 'secret')
330 |
331 | app = dash.Dash(__name__, server=server, url_base_pathname='/')
332 | app.title = 'VL-InterpreT'
333 | app.layout = get_layout(n_examples, n_layers)
334 | return app
335 |
336 | def print_page_link(hostname, port):
337 | print('\n\n')
338 | cprint('------------------------' '------------------------', 'green')
339 | cprint(f'App launched! Open http://{hostname}:{port}/ in your browser.', 'red')
340 | cprint('------------------------' '------------------------', 'green')
341 |
342 | def get_click_coords(click_data):
343 | return click_data['points'][0]['x'], click_data['points'][0]['y']
344 |
345 | def get_input_trigger(ctx):
346 | return ctx.triggered[0]['prop_id'].split('.')[0]
347 |
348 | def get_input_trigger_full(ctx):
349 | return ctx.triggered[0]['prop_id']
350 |
351 | def get_active_figure_data(fig):
352 | return fig['layout']['updatemenus'][0]['active'] if fig else 0
353 |
--------------------------------------------------------------------------------
/app/app_layout.py:
--------------------------------------------------------------------------------
1 | """ Module specifying Dash app UI layout
2 |
3 | The main function here is the get_layout() function which returns
4 | the Dash/HTML layout for VL-InterpreT.
5 | """
6 |
7 | from dash import dcc
8 | from dash import html
9 | import dash_daq as daq
10 | import base64
11 | import os
12 |
13 | intel_dark_blue = "#0168b5"
14 | intel_light_blue = "#04c7fd"
15 |
16 |
17 | def get_layout(n_examples, n_layers):
18 | logoConfig = dict(displaylogo=False, modeBarButtonsToRemove=["sendDataToCloud"])
19 | image_filename = os.path.join("app", "assets", "intel_ai_logo.jpg")
20 | encoded_image = base64.b64encode(open(image_filename, "rb").read())
21 |
22 | layout = html.Div(
23 | [
24 | # Stored values
25 | dcc.Store(id="selected_head_layer"),
26 | dcc.Store(id="selected_text_token", data={'sentence': -1, 'word': -1}),
27 | dcc.Store(id="selected_img_tokens"),
28 | dcc.Store(id="selected_token_from_matrix"),
29 | dcc.Store(id="movie_progress", data=0),
30 | dcc.Store(id="ex_txt_len"),
31 |
32 |
33 | # Header
34 | html.Div(
35 | [
36 | html.Div(
37 | [
38 | html.Img( # Intel logo
39 | src=f"data:image/png;base64,{encoded_image.decode()}",
40 | style={
41 | "display": "inline",
42 | "height": str(174 * 0.18) + "px",
43 | "width": str(600 * 0.18) + "px",
44 | "position": "relative",
45 | "padding-right": "30px",
46 | "vertical-align": "middle",
47 | },
48 | ),
49 | html.Div([
50 | html.Span("VL-Interpre", style={'color': intel_dark_blue}),
51 | html.Span("T", style={'color': intel_light_blue})
52 | ], style={"font-size": "40px", "display": "inline", "vertical-align": "middle"}
53 | ),
54 | html.Div(
55 | children=[
56 | html.Span("An Interactive Visualization Tool for "),
57 | html.Strong("Interpre", style={'color': intel_dark_blue}),
58 | html.Span("ting "),
59 | html.Strong("V", style={'color': intel_dark_blue}),
60 | html.Span("ision-"),
61 | html.Strong("L", style={'color': intel_dark_blue}),
62 | html.Span("anguage "),
63 | html.Strong("T", style={'color': intel_light_blue}),
64 | html.Span("ransformers")
65 | ],
66 | style={"font-size": "25px"}
67 | ),
68 | ],
69 | )
70 | ],
71 | style={ "text-align": "center", "margin": "2%"},
72 | ),
73 |
74 |
75 | # Example selector
76 | html.Div(
77 | [html.Label("Example ID", className="plot-label", style={"margin-right": "10px"}),
78 | dcc.Input(id="ex_selector", type="number", value=0, min=0, max=n_examples-1),
79 | html.Span(" - Model prediction ", className="plot-label", style={"font-weight": "300"}),
80 | html.Label("", id="accuracy", className="plot-label", style={"margin-right": "70px"}),
81 | html.Button('Add example', id='add_ex_toggle', style={"padding": "0 15px"}),
82 |
83 | # Add example
84 | html.Div([
85 | html.Div(
86 | [html.Div([
87 | html.Label("Image", className="plot-label", style={"margin-right": "10px"}),
88 | dcc.Input(
89 | id="new_ex_img",
90 | placeholder="Enter URL or path",
91 | style={"width": "80%", "max-width": "600px", "margin": "5px"})
92 | ]), html.Div([
93 | html.Label("Text", className="plot-label", style={"margin-right": "10px"}),
94 | dcc.Input(id="new_ex_txt",
95 | placeholder="Enter text",
96 | style={"width": "80%", "max-width": "600px", "margin": "5px"})
97 | ]),
98 | html.Button("Add", id="add_ex_btn",
99 | style={"padding": "0 15px", "margin": "5px", "background": "white"})],
100 | style={"margin": "5px 10%", "padding": "5px", "background": "#f1f1f1"})],
101 | id="add_ex_div",
102 | style={"display": "none"})
103 | ],
104 | style={"text-align": "center", "margin": "2%"},
105 | ),
106 |
107 |
108 | # Head summary matrix
109 | html.Div(
110 | [html.Label("Attention Head Summary", className="plot-label"),
111 | dcc.Graph(id="head_summary", config={"displayModeBar": False})],
112 | style={"display": "inline-block", "margin-right": "5%", "vertical-align": "top"}
113 | ),
114 |
115 |
116 | # TSNE
117 | html.Div(
118 | [
119 | html.Label("t-SNE Embeddings", id="tsne_title", className="plot-label"),
120 | html.Div([
121 | dcc.Graph(id="tsne_map", config=logoConfig, clear_on_unhover=True),
122 | dcc.Tooltip(id="hover_out", direction='bottom')
123 | ], style={"display": "inline-block", "width": "85%"}),
124 | html.Div([
125 | html.Label("Layer", style={}),
126 | html.Div([
127 | dcc.Slider(
128 | id="layer_slider", min=0, max=n_layers, step=1,
129 | value=0, included=False, vertical=True,
130 | marks={i: {
131 | "label": str(n_layers - i),
132 | "style": {"margin": "0 0 -10px -41px" if (i < n_layers - 9) else "0 0 -10px -35px"}
133 | } for i in range(n_layers + 1)}
134 | )
135 | ], style={"margin": "-10px -25px -25px 24px", })
136 | ], style={"display": "inline-block", "vertical-align": "top",
137 | "width": "44px", "height": "450px", "margin": "-10px 0 0 10px",
138 | "border": "1px solid lightgray", "border-radius": "5px"}),
139 | ],
140 | style={"width": "45%", "display": "inline-block", "vertical-align": "top"}
141 | ),
142 |
143 | html.Div(
144 | [html.Label("Attention", id="attn_title", className="plot-label"),
145 | html.Hr(style={"margin": "0 15%"})],
146 | style={"margin": "3% 0 5px 0"}),
147 |
148 |
149 | # Attention to
150 | html.Div(
151 | [html.Div(
152 | [html.Label("Attention to:", className="plot-label", style={"margin-right": "10px"}),
153 | html.Label("Image", id="attn2img_label", className="plot-label"),
154 | html.Div(
155 | daq.ToggleSwitch(id="attn2img_toggle", value=False, size=40),
156 | style={"display": "inline-block", "margin": "0 10px", "vertical-align": "top"}
157 | ),
158 | html.Label("Text", id="attn2txt_label", className="plot-label"),
159 | html.Button("\u25B6", id="play_btn", n_clicks=0, style={"display": "none"})],
160 | style={"height": "34px", "width": "400px", "text-align": "left",
161 | "display": "inline-block", "padding-left": "100px"}
162 | ),
163 | html.Div(id="attn2img", children=[ # image
164 | html.Div(
165 | dcc.Graph(id="ex_image", config={"displayModeBar": False}, style={"height": "450px"}),
166 | style={"width": "40%", "position": "absolute"}),
167 | html.Div(
168 | [dcc.Graph(id="ex_img_token_overlay", config={"displayModeBar": False})],
169 | style={"width": "40%", "position": "absolute", "height": "450px"}),
170 | ]),
171 | html.Div(id="attn2txt", children=[ # text
172 | html.Div(
173 | [dcc.Graph(id="ex_text", config={"displayModeBar": False})],
174 | style={"width": "40%", "position": "absolute", "height": "450px"})
175 | ], style={"display": "none", "z-index": "-1"})
176 | ],
177 | style={"width": "40%", "display": "inline-block", "vertical-align": "top"}
178 | ),
179 |
180 |
181 | # Attention from
182 | html.Div(
183 | [html.Div([html.Label("Attention from Text", className="plot-label", style={"height": "26px"})]),
184 | html.Div( # text display area
185 | [dcc.Graph(id="ex_sm_text", config={"displayModeBar": False}, style={"height": "174px"})]
186 | ),
187 | html.Div([ # label
188 | html.Hr(style={"height": "1px", "margin": "0"}),
189 | html.Label("Attention from Image", className="plot-label", style={"height": "26px"})
190 | ]),
191 | html.Div( # image display area
192 | [dcc.Graph(id="ex_sm_image", config={"displayModeBar": False},
193 | style={"width": "37%", "height": "256px", "margin-left": "4%", "position": "absolute"}),
194 | dcc.Graph(id="ex_img_attn_overlay", config={"displayModeBar": False},
195 | style={"width": "37%", "height": "256px", "margin-left": "4%", "position": "absolute"})]
196 | )],
197 | style={"height": "484px", "width": "45%", "display": "inline-block", "margin-left": "15px", "vertical-align": "top"},
198 | ),
199 | # Movie control (hidden)
200 | html.Div(
201 | [dcc.Interval(id="auto_stepper", interval=2500, n_intervals=0, disabled=True)],
202 | style={"display": "none"}
203 | ),
204 |
205 | # Checkbox for smoothing
206 | html.Div(
207 | [dcc.Checklist(
208 | id="map_smooth_checkbox",
209 | options=[{'label': 'Smooth attention map', 'value': 'smooth'}],
210 | value=['smooth'],
211 | style={"text-align": "center"})],
212 | style={"width": "45%", "display": "inline-block", "margin": "5px 0 2% 0", "padding-left": "40%"}
213 | ),
214 | ],
215 | style={"textAlign": "center"}
216 | )
217 |
218 | return layout
219 |
--------------------------------------------------------------------------------
/app/assets/intel_ai_logo.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IntelLabs/VL-InterpreT/0f9d0bb1dad13d91ce79d9453faf81fc2b9277b9/app/assets/intel_ai_logo.jpg
--------------------------------------------------------------------------------
/app/assets/stylesheet-oil-and-gas.css:
--------------------------------------------------------------------------------
1 | /* Table of contents
2 | ––––––––––––––––––––––––––––––––––––––––––––––––––
3 | - Grid
4 | - Base Styles
5 | - Typography
6 | - Links
7 | - Buttons
8 | - Forms
9 | - Lists
10 | - Code
11 | - Tables
12 | - Spacing
13 | - Utilities
14 | - Clearing
15 | - Media Queries
16 | */
17 |
18 |
19 | /* Grid
20 | –––––––––––––––––––––––––––––––––––––––––––––––––– */
21 | .container {
22 | position: relative;
23 | width: 100%;
24 | max-width: 960px;
25 | margin: 0 auto;
26 | padding: 0 20px;
27 | box-sizing: border-box; }
28 | .column,
29 | .columns {
30 | width: 100%;
31 | float: left;
32 | box-sizing: border-box; }
33 |
34 | /* For devices larger than 400px */
35 | @media (min-width: 400px) {
36 | .container {
37 | width: 85%;
38 | padding: 0; }
39 | }
40 |
41 | /* For devices larger than 550px */
42 | @media (min-width: 550px) {
43 | .container {
44 | width: 80%; }
45 | .column,
46 | .columns {
47 | margin-left: 0.5%; }
48 | .column:first-child,
49 | .columns:first-child {
50 | margin-left: 0; }
51 |
52 | .one.column,
53 | .one.columns { width: 8%; }
54 | .two.columns { width: 16.25%; }
55 | .three.columns { width: 22%; }
56 | .four.columns { width: 33%; }
57 | .five.columns { width: 39.3333333333%; }
58 | .six.columns { width: 49.75%; }
59 | .seven.columns { width: 56.6666666667%; }
60 | .eight.columns { width: 66.5%; }
61 | .nine.columns { width: 74.0%; }
62 | .ten.columns { width: 82.6666666667%; }
63 | .eleven.columns { width: 91.5%; }
64 | .twelve.columns { width: 100%; margin-left: 0; }
65 |
66 | .one-third.column { width: 30.6666666667%; }
67 | .two-thirds.column { width: 65.3333333333%; }
68 |
69 | .one-half.column { width: 48%; }
70 |
71 | /* Offsets */
72 | .offset-by-one.column,
73 | .offset-by-one.columns { margin-left: 8.66666666667%; }
74 | .offset-by-two.column,
75 | .offset-by-two.columns { margin-left: 17.3333333333%; }
76 | .offset-by-three.column,
77 | .offset-by-three.columns { margin-left: 26%; }
78 | .offset-by-four.column,
79 | .offset-by-four.columns { margin-left: 34.6666666667%; }
80 | .offset-by-five.column,
81 | .offset-by-five.columns { margin-left: 43.3333333333%; }
82 | .offset-by-six.column,
83 | .offset-by-six.columns { margin-left: 52%; }
84 | .offset-by-seven.column,
85 | .offset-by-seven.columns { margin-left: 60.6666666667%; }
86 | .offset-by-eight.column,
87 | .offset-by-eight.columns { margin-left: 69.3333333333%; }
88 | .offset-by-nine.column,
89 | .offset-by-nine.columns { margin-left: 78.0%; }
90 | .offset-by-ten.column,
91 | .offset-by-ten.columns { margin-left: 86.6666666667%; }
92 | .offset-by-eleven.column,
93 | .offset-by-eleven.columns { margin-left: 95.3333333333%; }
94 |
95 | .offset-by-one-third.column,
96 | .offset-by-one-third.columns { margin-left: 34.6666666667%; }
97 | .offset-by-two-thirds.column,
98 | .offset-by-two-thirds.columns { margin-left: 69.3333333333%; }
99 |
100 | .offset-by-one-half.column,
101 | .offset-by-one-half.columns { margin-left: 52%; }
102 |
103 | }
104 |
105 |
106 | /* Base Styles
107 | –––––––––––––––––––––––––––––––––––––––––––––––––– */
108 | /* NOTE
109 | html is set to 62.5% so that all the REM measurements throughout Skeleton
110 | are based on 10px sizing. So basically 1.5rem = 15px :) */
111 | html {
112 | font-size: 62.5%; }
113 | body {
114 | font-size: 1.5em; /* currently ems cause chrome bug misinterpreting rems on body element */
115 | line-height: 1.6;
116 | font-weight: 400;
117 | font-family: "Open Sans", "HelveticaNeue", "Helvetica Neue", Helvetica, Arial, sans-serif;
118 | color: rgb(50, 50, 50);
119 | margin: 0;
120 | }
121 |
122 |
123 | /* Typography
124 | –––––––––––––––––––––––––––––––––––––––––––––––––– */
125 | h1, h2, h3, h4, h5, h6 {
126 | margin-top: 0;
127 | margin-bottom: 0;
128 | font-weight: 300; }
129 | h1 { font-size: 4.5rem; line-height: 1.2; letter-spacing: -.1rem; margin-bottom: 2rem; }
130 | h2 { font-size: 3.6rem; line-height: 1.25; letter-spacing: -.1rem; margin-bottom: 1.8rem; margin-top: 1.8rem;}
131 | h3 { font-size: 3.0rem; line-height: 1.3; letter-spacing: -.1rem; margin-bottom: 1.5rem; margin-top: 1.5rem;}
132 | h4 { font-size: 2.6rem; line-height: 1.35; letter-spacing: -.08rem; margin-bottom: 1.2rem; margin-top: 1.2rem;}
133 | h5 { font-size: 2.2rem; line-height: 1.5; letter-spacing: -.05rem; margin-bottom: 0.6rem; margin-top: 0.6rem;}
134 | h6 { font-size: 2.0rem; line-height: 1.6; letter-spacing: 0; margin-bottom: 0.75rem; margin-top: 0.75rem;}
135 |
136 | p {
137 | margin-top: 0; }
138 |
139 |
140 | /* Blockquotes
141 | –––––––––––––––––––––––––––––––––––––––––––––––––– */
142 | blockquote {
143 | border-left: 4px lightgrey solid;
144 | padding-left: 1rem;
145 | margin-top: 2rem;
146 | margin-bottom: 2rem;
147 | margin-left: 0rem;
148 | }
149 |
150 |
151 | /* Links
152 | –––––––––––––––––––––––––––––––––––––––––––––––––– */
153 | a {
154 | color: #1EAEDB; }
155 | a:hover {
156 | color: #0FA0CE; }
157 |
158 |
159 | /* Buttons
160 | –––––––––––––––––––––––––––––––––––––––––––––––––– */
161 | .button,
162 | button,
163 | input[type="submit"],
164 | input[type="reset"],
165 | input[type="button"] {
166 | display: inline-block;
167 | height: 38px;
168 | padding: 0 30px;
169 | color: #555;
170 | text-align: center;
171 | font-size: 11px;
172 | font-weight: 600;
173 | line-height: 38px;
174 | letter-spacing: .1rem;
175 | text-transform: uppercase;
176 | text-decoration: none;
177 | white-space: nowrap;
178 | background-color: transparent;
179 | border-radius: 4px;
180 | border: 1px solid #bbb;
181 | cursor: pointer;
182 | box-sizing: border-box; }
183 | .button:hover,
184 | button:hover,
185 | input[type="submit"]:hover,
186 | input[type="reset"]:hover,
187 | input[type="button"]:hover,
188 | .button:focus,
189 | button:focus,
190 | input[type="submit"]:focus,
191 | input[type="reset"]:focus,
192 | input[type="button"]:focus {
193 | color: #333;
194 | border-color: #888;
195 | outline: 0; }
196 | .button.button-primary,
197 | button.button-primary,
198 | input[type="submit"].button-primary,
199 | input[type="reset"].button-primary,
200 | input[type="button"].button-primary {
201 | color: #FFF;
202 | background-color: #33C3F0;
203 | border-color: #33C3F0; }
204 | .button.button-primary:hover,
205 | button.button-primary:hover,
206 | input[type="submit"].button-primary:hover,
207 | input[type="reset"].button-primary:hover,
208 | input[type="button"].button-primary:hover,
209 | .button.button-primary:focus,
210 | button.button-primary:focus,
211 | input[type="submit"].button-primary:focus,
212 | input[type="reset"].button-primary:focus,
213 | input[type="button"].button-primary:focus {
214 | color: #FFF;
215 | background-color: #1EAEDB;
216 | border-color: #1EAEDB; }
217 |
218 | #play_btn {
219 | height: 24px;
220 | display: "inline-block";
221 | margin-left: 20px;
222 | line-height: 24px;
223 | padding: 0 25px;
224 | }
225 |
226 | /* Forms
227 | –––––––––––––––––––––––––––––––––––––––––––––––––– */
228 | input[type="email"],
229 | input[type="number"],
230 | input[type="search"],
231 | input[type="text"],
232 | input[type="tel"],
233 | input[type="url"],
234 | input[type="password"],
235 | textarea,
236 | select {
237 | height: 38px;
238 | padding: 6px 10px; /* The 6px vertically centers text on FF, ignored by Webkit */
239 | background-color: #fff;
240 | border: 1px solid #D1D1D1;
241 | border-radius: 4px;
242 | box-shadow: none;
243 | box-sizing: border-box;
244 | font-family: inherit;
245 | font-size: inherit; /*https://stackoverflow.com/questions/6080413/why-doesnt-input-inherit-the-font-from-body*/}
246 | /* Removes awkward default styles on some inputs for iOS */
247 | input[type="email"],
248 | input[type="number"],
249 | input[type="search"],
250 | input[type="text"],
251 | input[type="tel"],
252 | input[type="url"],
253 | input[type="password"],
254 | textarea {
255 | -webkit-appearance: none;
256 | -moz-appearance: none;
257 | appearance: none; }
258 | textarea {
259 | min-height: 65px;
260 | padding-top: 6px;
261 | padding-bottom: 6px; }
262 | input[type="email"]:focus,
263 | input[type="number"]:focus,
264 | input[type="search"]:focus,
265 | input[type="text"]:focus,
266 | input[type="tel"]:focus,
267 | input[type="url"]:focus,
268 | input[type="password"]:focus,
269 | textarea:focus,
270 | select:focus {
271 | border: 1px solid #33C3F0;
272 | outline: 0; }
273 | label,
274 | legend {
275 | margin-bottom: 0px; }
276 | fieldset {
277 | padding: 0;
278 | border-width: 0; }
279 | input[type="checkbox"],
280 | input[type="radio"] {
281 | display: inline; }
282 | label > .label-body {
283 | display: inline-block;
284 | margin-left: .5rem;
285 | font-weight: normal; }
286 | .plot-label {
287 | font-family: "Trebuchet MS", "verdana";
288 | font-size: 16px;
289 | justify-content: "center";
290 | font-weight: 600;
291 | display: "inline-block";
292 | }
293 |
294 | /* Lists
295 | –––––––––––––––––––––––––––––––––––––––––––––––––– */
296 | ul {
297 | list-style: circle inside; }
298 | ol {
299 | list-style: decimal inside; }
300 | ol, ul {
301 | padding-left: 0;
302 | margin-top: 0; }
303 | ul ul,
304 | ul ol,
305 | ol ol,
306 | ol ul {
307 | margin: 1.5rem 0 1.5rem 3rem;
308 | font-size: 90%; }
309 | li {
310 | margin-bottom: 1rem; }
311 |
312 |
313 | /* Tables
314 | –––––––––––––––––––––––––––––––––––––––––––––––––– */
315 | th,
316 | td {
317 | padding: 12px 15px;
318 | text-align: left;
319 | border-bottom: 1px solid #E1E1E1; }
320 | th:first-child,
321 | td:first-child {
322 | padding-left: 0; }
323 | th:last-child,
324 | td:last-child {
325 | padding-right: 0; }
326 |
327 |
328 | /* Spacing
329 | –––––––––––––––––––––––––––––––––––––––––––––––––– */
330 | button,
331 | .button {
332 | margin-bottom: 0rem; }
333 | input,
334 | textarea,
335 | select,
336 | fieldset {
337 | margin-bottom: 0rem; }
338 | pre,
339 | dl,
340 | figure,
341 | table,
342 | form {
343 | margin-bottom: 0rem; }
344 | p,
345 | ul,
346 | ol {
347 | margin-bottom: 0.75rem; }
348 |
349 | /* Utilities
350 | –––––––––––––––––––––––––––––––––––––––––––––––––– */
351 | .u-full-width {
352 | width: 100%;
353 | box-sizing: border-box; }
354 | .u-max-full-width {
355 | max-width: 100%;
356 | box-sizing: border-box; }
357 | .u-pull-right {
358 | float: right; }
359 | .u-pull-left {
360 | float: left; }
361 |
362 |
363 | /* Misc
364 | –––––––––––––––––––––––––––––––––––––––––––––––––– */
365 | hr {
366 | margin-top: 3rem;
367 | margin-bottom: 3.5rem;
368 | border-width: 0;
369 | border-top: 1px solid #E1E1E1; }
370 |
371 |
372 | /* Clearing
373 | –––––––––––––––––––––––––––––––––––––––––––––––––– */
374 |
375 | /* Self Clearing Goodness */
376 | .container:after,
377 | .row:after,
378 | .u-cf {
379 | content: "";
380 | display: table;
381 | clear: both; }
382 |
383 |
384 | /* Media Queries
385 | –––––––––––––––––––––––––––––––––––––––––––––––––– */
386 | /*
387 | Note: The best way to structure the use of media queries is to create the queries
388 | near the relevant code. For example, if you wanted to change the styles for buttons
389 | on small devices, paste the mobile query code up in the buttons section and style it
390 | there.
391 | */
392 |
393 |
394 | /* Larger than mobile */
395 | @media (min-width: 400px) {}
396 |
397 | /* Larger than phablet (also point when grid becomes active) */
398 | @media (min-width: 550px) {}
399 |
400 | /* Larger than tablet */
401 | @media (min-width: 750px) {}
402 |
403 | /* Larger than desktop */
404 | @media (min-width: 1000px) {}
405 |
406 | /* Larger than Desktop HD */
407 | @media (min-width: 1200px) {}
408 |
--------------------------------------------------------------------------------
/app/database/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IntelLabs/VL-InterpreT/0f9d0bb1dad13d91ce79d9453faf81fc2b9277b9/app/database/__init__.py
--------------------------------------------------------------------------------
/app/database/database.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | from collections import defaultdict
4 | from sklearn.manifold import TSNE
5 | import faiss
6 | from tqdm import tqdm
7 | import pickle
8 | import lmdb
9 | from lz4.frame import compress, decompress
10 | import msgpack
11 | import msgpack_numpy
12 | msgpack_numpy.patch()
13 |
14 |
15 | class VliLmdb(object):
16 | def __init__(self, db_dir, read_only=True, local_world_size=1):
17 | self.readonly = read_only
18 | self.path = db_dir
19 | if not os.path.isdir(db_dir):
20 | os.mkdir(db_dir)
21 | if read_only:
22 | readahead = not self._check_distributed(local_world_size)
23 | self.env = lmdb.open(db_dir, readonly=True, create=False, lock=False, readahead=readahead)
24 | self.txn = self.env.begin(buffers=True)
25 | self.write_cnt = None
26 | else:
27 | self.env = lmdb.open(db_dir, readonly=False, create=True, map_size=4 * 1024**4)
28 | self.txn = self.env.begin(write=True)
29 | self.write_cnt = 0
30 | if not self.txn.get('n_examples'.encode('utf-8')):
31 | self['n_examples'] = 0
32 |
33 | def _check_distributed(self, local_world_size):
34 | try:
35 | dist = local_world_size != 1
36 | except ValueError:
37 | # not using horovod
38 | dist = False
39 | return dist
40 |
41 | def __del__(self):
42 | if self.write_cnt:
43 | self.txn.commit()
44 | self.env.close()
45 |
46 | def __getitem__(self, key):
47 | value = self.txn.get(str(key).encode('utf-8'))
48 | if value is None:
49 | raise KeyError(key)
50 | return msgpack.loads(decompress(value), raw=False)
51 |
52 | def __setitem__(self, key, value):
53 | # NOTE: not thread safe
54 | if self.readonly:
55 | raise ValueError('readonly text DB')
56 | ret = self.txn.put(key.encode('utf-8'), compress(msgpack.dumps(value, use_bin_type=True)))
57 | self.write_cnt += 1
58 | if self.write_cnt % 1000 == 0:
59 | self.txn.commit()
60 | self.txn = self.env.begin(write=True)
61 | self.write_cnt = 0
62 | return ret
63 |
64 | def preprocess_example(self, ex_id, ex_data, faiss_indices, faiss_data):
65 | txt_len = ex_data['txt_len']
66 | # image tokens
67 | if ex_data['img_coords'] is not None and len(ex_data['img_coords']) > 0:
68 | ex_data['img_grid_size'] = np.max(ex_data['img_coords'], axis=0) + 1
69 | img_coords = [tuple(coords) for coords in ex_data['img_coords']]
70 | img_coords = np.array(img_coords, np.dtype([('x', int), ('y', int)]))
71 | img_sort = np.argsort(img_coords, order=('y', 'x'))
72 | ex_data['img_coords'] = np.take_along_axis(img_coords, img_sort, axis=0).tolist()
73 | img_tokens = np.array(ex_data['tokens'][txt_len:])
74 | ex_data['tokens'][txt_len:] = np.take_along_axis(img_tokens, img_sort, axis=0)
75 | ex_data['attention'][:,:,txt_len:] = np.take(ex_data['attention'][:,:,txt_len:], img_sort, axis=2)
76 | ex_data['attention'][:,:,:,txt_len:] = np.take(ex_data['attention'][:,:,:,txt_len:], img_sort, axis=3)
77 | else:
78 | ex_data['img_grid_size'] = (0, 0)
79 | print(f'Warning: Image coordinates are missing for example #{ex_id}.')
80 | # t-SNE
81 | tsne = [TSNE(n_components=2, random_state=None, n_jobs=-1).fit_transform(ex_data['hidden_states'][i])
82 | for i in range(len(ex_data['hidden_states']))]
83 | if len(faiss_indices) == 0:
84 | faiss_indices = {(layer, mod): faiss.IndexFlatL2(2) \
85 | for layer in range(len(tsne)) for mod in ('txt', 'img')}
86 | ex_data['tsne'] = tsne
87 | self[str(ex_id)] = pickle.dumps(ex_data, protocol=pickle.HIGHEST_PROTOCOL)
88 | self['n_examples'] += 1
89 | # faiss
90 | for layer in range(len(tsne)):
91 | faiss_all_tokens = [(ex_id, i) for i in range(len(ex_data['tokens']))]
92 | faiss_data[(layer, 'txt')] += faiss_all_tokens[:txt_len]
93 | faiss_data[(layer, 'img')] += faiss_all_tokens[txt_len:]
94 | tsne_txt = tsne[layer][:txt_len]
95 | tsne_img = tsne[layer][txt_len:]
96 | faiss_indices[(layer, 'txt')].add(tsne_txt)
97 | faiss_indices[(layer, 'img')].add(tsne_img)
98 | return faiss_indices, faiss_data
99 |
100 | def preprocess(self):
101 | print('Preprocessing database...')
102 | faiss_indices = {}
103 | faiss_data = defaultdict(list)
104 | n_layers, n_heads = 0, 0
105 | with self.txn.cursor() as cursor:
106 | for key, value in tqdm(cursor):
107 | ex_id = str(key, 'utf8')
108 | if not ex_id.isdigit():
109 | continue
110 | ex_data = pickle.loads(msgpack.loads(decompress(value), raw=False))
111 | faiss_indices, faiss_data = self.preprocess_example(ex_id, ex_data, faiss_indices, faiss_data)
112 |
113 | n_layers, n_heads, _, _ = ex_data['attention'].shape
114 | self['n_layers'] = n_layers
115 | self['n_heads'] = n_heads
116 |
117 | print('Generating faiss indices...')
118 | faiss_path = os.path.join(self.path, 'faiss')
119 | if not os.path.exists(faiss_path):
120 | os.makedirs(faiss_path)
121 | for layer in range(n_layers + 1):
122 | faiss.write_index(faiss_indices[(layer, 'txt')], os.path.join(faiss_path, f'txt_indices_{layer}'))
123 | faiss.write_index(faiss_indices[(layer, 'img')], os.path.join(faiss_path, f'img_indices_{layer}'))
124 | self['faiss'] = pickle.dumps(faiss_data, protocol=pickle.HIGHEST_PROTOCOL)
125 | print('Preprocessing done.')
126 |
--------------------------------------------------------------------------------
/app/database/db_analyzer.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | from scipy.stats import zscore
4 | import faiss
5 | import pickle
6 | from app.database.database import VliLmdb
7 |
8 | ATTN_MAP_SCALE_FACTOR = 64
9 |
10 | class VliDataBaseAnalyzer:
11 | def __init__(self, db_dir, read_only=True):
12 | self.db_dir = db_dir
13 | self.db = VliLmdb(db_dir, read_only=read_only)
14 | faiss_folder = os.path.join(db_dir, 'faiss')
15 | self.faiss_indices = {(layer, mod): faiss.read_index(f'{faiss_folder}/{mod}_indices_{layer}')
16 | for layer in range(self['n_layers'] + 1) for mod in ('txt', 'img')}
17 | self.faiss_data = self['faiss']
18 | self.updated = False
19 |
20 | def __getitem__(self, item):
21 | if isinstance(self.db[item], bytes):
22 | return pickle.loads(self.db[item])
23 | return self.db[item]
24 |
25 | def __del__(self):
26 | if self.updated:
27 | print('Writing new examples to database...')
28 | self.db['faiss'] = pickle.dumps(self.faiss_data, protocol=pickle.HIGHEST_PROTOCOL)
29 | faiss_path = os.path.join(self.db.path, 'faiss')
30 | for layer in range(self['n_layers'] + 1):
31 | faiss.write_index(self.faiss_indices[(layer, 'txt')], os.path.join(faiss_path, f'txt_indices_{layer}'))
32 | faiss.write_index(self.faiss_indices[(layer, 'img')], os.path.join(faiss_path, f'img_indices_{layer}'))
33 | print('Done.')
34 |
35 | def add_example(self, ex_id, ex_data):
36 | self.faiss_indices, self.faiss_data = self.db.preprocess_example(ex_id, ex_data, self.faiss_indices, self.faiss_data)
37 | self.updated = True
38 |
39 | def get_ex_attn(self, ex_id, exclude_tokens=['[SEP]', '[CLS]']):
40 | ex_attn = self[ex_id]['attention']
41 | ex_tokens = self[ex_id]['tokens']
42 | if exclude_tokens:
43 | exclude_indices = [i for i, token in enumerate(ex_tokens) if token in exclude_tokens]
44 | ex_attn = np.delete(ex_attn, exclude_indices, axis=2)
45 | ex_attn = np.delete(ex_attn, exclude_indices, axis=3)
46 | return ex_attn
47 |
48 | def get_attn_means(self, attn, normalize=True):
49 | head_avg = attn.mean(axis=(2, 3))
50 | if normalize:
51 | head_avg = zscore(head_avg)
52 | layer_avg = np.mean(head_avg, axis=1)
53 | return head_avg, layer_avg
54 |
55 | def get_attn_components_means(self, components):
56 | avg_attn = np.mean(np.array(components), axis=0)
57 | avg_attn = zscore(avg_attn) # normalize
58 | layer_avg = np.mean(avg_attn, axis=1)
59 | return avg_attn, layer_avg
60 |
61 | def img2txt_mean_attn(self, ex_id, exclude_tokens=['[SEP]', '[CLS]']):
62 | ex_attn = self.get_ex_attn(ex_id, exclude_tokens)
63 | txt_len = self[ex_id]['txt_len']
64 | img2txt_attn = ex_attn[:, :, :txt_len, txt_len:]
65 | return self.get_attn_means(img2txt_attn)
66 |
67 | def txt2img_mean_attn(self, ex_id, exclude_tokens=['[SEP]', '[CLS]']):
68 | ex_attn = self.get_ex_attn(ex_id, exclude_tokens)
69 | txt_len = self[ex_id]['txt_len']
70 | txt2img_attn = ex_attn[:, :, txt_len:, :txt_len]
71 | return self.get_attn_means(txt2img_attn)
72 |
73 | def txt2txt_mean_attn(self, ex_id, exclude_tokens=['[SEP]', '[CLS]']):
74 | ex_attn = self.get_ex_attn(ex_id, exclude_tokens)
75 | txt_len = self[ex_id]['txt_len']
76 | txt2txt_attn = ex_attn[:, :, :txt_len, :txt_len]
77 | return self.get_attn_means(txt2txt_attn)
78 |
79 | def img2img_mean_attn(self, ex_id, exclude_tokens=['[SEP]', '[CLS]']):
80 | ex_attn = self.get_ex_attn(ex_id, exclude_tokens)
81 | txt_len = self[ex_id]['txt_len']
82 | img2img_attn = ex_attn[:, :, txt_len:, txt_len:]
83 | return self.get_attn_means(img2img_attn)
84 |
85 | def img2img_mean_attn_without_self(self, ex_id, exclude_tokens=['[SEP]', '[CLS]']):
86 | ex_attn = self.get_ex_attn(ex_id, exclude_tokens)
87 | txt_len = self[ex_id]['txt_len']
88 | img2img_attn = ex_attn[:, :, txt_len:, txt_len:]
89 | _, _, attn_len, _ = img2img_attn.shape
90 | for i in range(attn_len):
91 | for j in range(attn_len):
92 | if i == j or i == j + 1 or j == i + 1:
93 | img2img_attn[:, :, i, j] = 0
94 | return self.get_attn_means(img2img_attn)
95 |
96 | def get_all_attn_stats(self, ex_id, exclude_tokens=['[SEP]', '[CLS]']):
97 | attn_stats = []
98 | for func in (self.img2txt_mean_attn,
99 | self.txt2img_mean_attn,
100 | self.img2img_mean_attn,
101 | self.img2img_mean_attn_without_self,
102 | self.txt2txt_mean_attn):
103 | attn, layer_avg = func(ex_id, exclude_tokens)
104 | attn_stats.append((attn, layer_avg))
105 | # crossmodal
106 | attn_stats.append(self.get_attn_components_means([attn_stats[0][0], attn_stats[1][0]]))
107 | # intramodal
108 | attn_stats.append(self.get_attn_components_means([attn_stats[2][0], attn_stats[3][0]]))
109 | for i, (stats, layer_avg) in enumerate(attn_stats):
110 | attn_stats[i] = np.hstack((stats, layer_avg.reshape(layer_avg.shape[0], 1)))
111 | return attn_stats
112 |
113 | def get_custom_metrics(self, ex_id):
114 | if 'custom_metrics' in self[ex_id]:
115 | cm = self[ex_id]['custom_metrics']
116 | labels, stats = [], []
117 | for metrics in cm:
118 | data = cm[metrics]
119 | mean = np.mean(data, axis=1)
120 | mean = mean.reshape(mean.shape[0], 1)
121 | data = np.hstack((data, mean))
122 | labels.append(metrics)
123 | stats.append(data)
124 | return labels, stats
125 |
126 | def find_closest_token(self, tsne, layer, mod):
127 | _, index = self.faiss_indices[(layer, mod)].search(tsne.reshape(1, -1), 1) # returns distance, index
128 | index = index.item(0)
129 | ex_id, token_id = self.faiss_data[(layer, mod)][index]
130 | return ex_id, token_id
131 |
132 | def get_txt_token_index(self, ex_id, text_tokens_id):
133 | # text_tokens_id: {'sentence': 0, 'word': 0}
134 | if not text_tokens_id:
135 | return
136 | sep_idx = [-1] + [i for i, x in enumerate(self[ex_id]['tokens']) if x == '[SEP]']
137 | sentence, word = text_tokens_id['sentence'], text_tokens_id['word']
138 | return sep_idx[sentence] + 1 + word
139 |
140 | def get_img_token_index(self, ex_id, img_coords):
141 | txt_len = self[ex_id]['txt_len']
142 | return self[ex_id]['img_coords'].index(tuple(img_coords)) + txt_len
143 |
144 | def get_img_unit_len(self, ex_id):
145 | token_grid_size = self[ex_id]['img_grid_size']
146 | image_size = self[ex_id]['image'].shape
147 | return image_size[1]//token_grid_size[0], image_size[0]//token_grid_size[1]
148 |
--------------------------------------------------------------------------------
/app/database/db_example.py:
--------------------------------------------------------------------------------
1 | import pickle
2 | import numpy as np
3 | from database import VliLmdb
4 |
5 | if __name__ == '__main__':
6 |
7 | # The following example contains randomly generated data to
8 | # illustrate what the data should look like before preprocessing
9 | example_data = []
10 | for ex_id in range(3):
11 | example_data.append(
12 | # each example should contain the following information:
13 | {
14 | # Example ID (integers starting from 0)
15 | 'ex_id': ex_id,
16 |
17 | # The original input image (RGB)
18 | # If your model preprocesses the image (e.g., resizing, padding), you may want to
19 | # use the preprocessed image instead of the original
20 | 'image': np.ones((450, 800, 3), dtype=int) * 100 * ex_id + 30,
21 |
22 | # Input tokens (text tokens followed by image tokens)
23 | 'tokens': ['[CLS]', 'text', 'input', 'for', 'example', str(ex_id), '.', '[SEP]',
24 | 'IMG_0', 'IMG_1', 'IMG_2', 'IMG_3', 'IMG_4', 'IMG_5'],
25 |
26 | # The number of text tokens
27 | 'txt_len': 8,
28 |
29 | # The (x, y) coordinates of each image token on the original image,
30 | # assuming the *top left* corner of an image is (0, 0)
31 | # The order of coordinates should correspond to how image tokens are ordered in 'tokens'
32 | 'img_coords': [(0, 0), (0, 1), (1, 0), (1, 1), (2, 0), (2, 1)],
33 |
34 | # (Optional) Whether model predicted correctly for this example
35 | 'accuracy': ex_id % 2 == 0, # either True or False
36 |
37 | # Attention weights for all attention heads in all layers
38 | # Shape: (n_layers, n_attention_heads_per_layer, n_tokens, n_tokens)
 39 |                 # n_layers and n_attention_heads_per_layer should be the same across examples
40 | # The order of columns and rows of the attention weight matrix for each head should
41 | # correspond to how tokens are ordered in 'tokens'
42 | 'attention': np.random.rand(12, 12, 14, 14),
43 |
44 | # The hidden representations for each token in the model,
45 | # both before the first layer and after each layer
46 | # Shape: (n_layers + 1, n_tokens, hidden_state_vector_size)
47 | # Note that in our demo app, hidden representations of stop words were removed
48 | # to reduce the number of displayed datapoints
49 | 'hidden_states': np.random.rand(13, 14, 768),
50 |
51 | # (Optional) Custom statistics for attention heads in all layers
52 | # Shape: (n_layers, n_attention_heads_per_layer)
53 | # The order should follow how attention heads are ordered in 'attention' matrices
54 | 'custom_metrics': {'Example Custom Metrics': np.random.rand(12, 12)}
55 | }
56 | )
57 |
58 | # Create database
59 | print('Creating database...')
60 | db = VliLmdb(db_dir='example_database1', read_only=False)
61 | for ex_id, ex_data in enumerate(example_data):
62 | # Keys must be strings
63 | db[str(ex_id)] = pickle.dumps(ex_data, protocol=pickle.HIGHEST_PROTOCOL)
64 |
65 | # Preprocess the database
66 | db.preprocess()
67 |
68 |
--------------------------------------------------------------------------------
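The required fields above are interdependent (token count, attention shape, hidden-state shape, image-token coordinates), so it can be worth verifying them before writing examples to the database. The following is a minimal sketch of such a check; the check_example helper is illustrative and not part of the repository.

# illustrative consistency check for one example dict (not part of the repository)
import numpy as np

def check_example(ex):
    n_tokens = len(ex['tokens'])
    n_img_tokens = n_tokens - ex['txt_len']
    assert len(ex['img_coords']) == n_img_tokens, 'one (x, y) coordinate per image token'
    n_layers, n_heads, rows, cols = ex['attention'].shape
    assert rows == cols == n_tokens, "'attention' must be (n_layers, n_heads, n_tokens, n_tokens)"
    hs_layers, hs_tokens, _ = np.asarray(ex['hidden_states']).shape
    assert hs_layers == n_layers + 1 and hs_tokens == n_tokens, \
        "'hidden_states' must be (n_layers + 1, n_tokens, hidden_size)"
    for name, values in ex.get('custom_metrics', {}).items():
        assert np.shape(values) == (n_layers, n_heads), f"'{name}' must be (n_layers, n_heads)"

# e.g. for the random data generated above:
#     for ex in example_data:
#         check_example(ex)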
/app/database/models/kdvlp.py:
--------------------------------------------------------------------------------
1 | '''
  2 | The following script runs a forward pass with the KD-VLP model for
3 | each given image+text pair, and produces the corresponding attention
4 | and hidden states that can be visualized with VL-InterpreT.
5 |
6 | The KD-VLP model has not been made publicly available in this repo.
7 | Please create your own model class by inheriting from VL_Model.
8 |
9 | Note that to run your own model with VL-InterpreT, you are only
 10 | required to implement the data_setup function. Most of the code in
 11 | this file is specific to our KD-VLP model and is provided just
 12 | for your reference.
13 | '''
14 |
15 |
16 | import numpy as np
17 | import torch
18 | import pickle
19 | from torch.nn.utils.rnn import pad_sequence
20 | from transformers import BertTokenizer
21 |
22 | try:
23 | from app.database.models.vl_model import VL_Model
24 | except ModuleNotFoundError:
25 | from vl_model import VL_Model
26 |
27 | import sys
28 | MODEL_DIR = '/workdisk/ccr_vislang/benchmarks/vision_language/vcr/VILLA/'
29 | sys.path.append(MODEL_DIR)
30 |
31 | from model.vcr import UniterForVisualCommonsenseReasoning
32 | from utils.const import IMG_DIM
33 | from utils.misc import remove_prefix
34 | from data import get_transform
35 |
36 |
37 | class Kdvlp(VL_Model):
38 | '''
 39 |     Running KD-VLP with VL-InterpreT:
40 | python run_app.py -p 6006 -d example_database2 \
41 | -m kdvlp /data1/users/shaoyent/e2e-vcr-kgmvm-pgm-run49-2/ckpt/model_step_8500.pt
42 | '''
43 | def __init__(self, ckpt_file, device='cuda'):
44 | self.device = torch.device(device, 0)
45 | if device == 'cuda':
46 | torch.cuda.set_device(0)
47 | self.model, self.tokenizer = self.build_model(ckpt_file)
48 |
49 |
50 | def build_model(self, ckpt_file):
51 | tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
52 | tokenizer._add_tokens([f'[PERSON_{i}]' for i in range(81)])
53 |
54 | ckpt = torch.load(ckpt_file)
55 | checkpoint = {remove_prefix(k.replace('bert', 'uniter'), 'module.') : v for k, v in ckpt.items()}
56 | model = UniterForVisualCommonsenseReasoning.from_pretrained(
57 | f'{MODEL_DIR}/config/uniter-base.json', state_dict={},
58 | img_dim=IMG_DIM)
59 | model.init_type_embedding()
60 | model.load_state_dict(checkpoint, strict=False)
61 | model.eval()
62 | model = model.to(self.device)
63 | return model, tokenizer
64 |
65 |
66 | def move_to_device(self, x):
67 | if isinstance(x, list):
68 | return [self.move_to_device(y) for y in x]
69 | elif isinstance(x, dict):
70 | new_dict = {}
71 | for k, v in x.items():
72 | new_dict[k] = self.move_to_device(x[k])
73 | return new_dict
74 | elif isinstance(x, torch.Tensor):
75 | return x.to(self.device)
76 | else:
77 | return x
78 |
79 |
80 | def build_batch(self, input_text, image, answer=None, person_info=None):
81 | if not input_text:
82 | input_text = ''
83 | if answer is None:
84 | input_ids = self.tokenizer.encode(input_text, return_tensors='pt')
85 | input_ids = pad_sequence(input_ids, batch_first=True, padding_value=0) # all*input_ids
86 | txt_type_ids = torch.zeros_like(input_ids)
87 | else:
88 | input_ids_q = self.tokenizer.convert_tokens_to_ids(self.tokenizer.tokenize(input_text))
89 | input_ids_c = self.tokenizer.convert_tokens_to_ids(self.tokenizer.tokenize(answer))
90 | input_ids = [torch.tensor(self.tokenizer.build_inputs_with_special_tokens(input_ids_q, input_ids_c))]
91 | input_ids = pad_sequence(input_ids, batch_first=True, padding_value=0) # all*input_ids
92 |
 93 |             txt_type_ids = torch.tensor((len(input_ids_q) + 2) * [0] + (len(input_ids_c) + 1) * [2])
94 |
95 | position_ids = torch.arange(0, input_ids.size(1), dtype=torch.long).unsqueeze(0)
96 | num_sents = [input_ids.size(0)]
97 | txt_lens = [i.size(0) for i in input_ids]
98 |
99 | if image is None:
100 | images_batch = None
101 | else:
102 | images_batch = torch.as_tensor(image.copy(), dtype=torch.float32)
103 | images_batch = get_transform(images_batch.permute(2, 0, 1))
104 |
105 | batch = {'input_ids': input_ids, 'txt_type_ids': txt_type_ids, 'position_ids': position_ids, 'images': images_batch,
106 | "txt_lens": txt_lens, "num_sents": num_sents, 'person_info': person_info}
107 | batch = self.move_to_device(batch)
108 |
109 | return batch
110 |
111 |
112 | def data_setup(self, ex_id, image_location, input_text):
113 | image = self.fetch_image(image_location) if image_location else None
114 |
115 | batch = self.build_batch(input_text, image, answer=None, person_info=None)
116 | scores, hidden_states, attentions = self.model(batch,
117 | compute_loss=False,
118 | output_attentions=True,
119 | output_hidden_states=True)
120 |
121 | attentions = torch.stack(attentions).transpose(1,0).detach().cpu()[0]
122 |
123 | if batch['images'] is None:
124 | img, img_coords = np.array([]), []
125 | len_img = 0
126 | else:
127 | image1, mask1 = self.model.preprocess_image(batch['images'].to(self.device))
128 | image1 = (image1 * self.model.pixel_std + self.model.pixel_mean) * mask1
129 | img = image1.cpu().numpy().astype(int).squeeze().transpose(1,2,0)
130 |
131 | h, w, _ = img.shape
132 | h0, w0 = h//64, w//64
133 | len_img = w0 * h0
134 | img_coords = np.fliplr(list(np.ndindex(h0, w0)))
135 |
136 | input_ids = batch['input_ids'].cpu()
137 | len_text = input_ids.size(1)
138 | txt_tokens = self.tokenizer.convert_ids_to_tokens(input_ids[0, :len_text])
139 |
140 | len_tokens = len_text + len_img
141 | attentions = attentions[:, :, :len_tokens, :len_tokens]
142 | hidden_states = [hs[0].detach().cpu().numpy()[:len_tokens] for hs in hidden_states]
143 |
144 | return {
145 | 'ex_id': ex_id,
146 | 'image': img,
147 | 'tokens': txt_tokens + [f'IMG_{i}' for i in range(len_img)],
148 | 'txt_len': len(txt_tokens),
149 | 'attention': attentions.detach().cpu().numpy(),
150 | 'img_coords': img_coords,
151 | 'hidden_states': hidden_states
152 | }
153 |
154 |
155 | def create_example_db():
156 | '''
157 | This function creates example_database2.
158 | '''
159 | images = [f'{MODEL_DIR}/visualization/ex.jpg']
160 | texts = ['Horses are pulling a carriage, while someone is standing on the top of a golden ball.']
161 | kdvlp = Kdvlp('/data1/users/shaoyent/e2e-vcr-kgmvm-pgm-run49-2/ckpt/model_step_8500.pt')
162 | data = [kdvlp.data_setup(i, img, txt) for i, (img, txt) in enumerate(zip(images, texts))]
163 |
164 | db = VliLmdb(db_dir='/workdisk/VL-InterpreT/example_database2', read_only=False)
165 | for i, dat in enumerate(data):
166 | db[str(i)] = pickle.dumps(dat, protocol=pickle.HIGHEST_PROTOCOL)
167 | db.preprocess()
168 |
169 |
170 | if __name__ == '__main__':
171 | sys.path.append('..')
172 | from database import VliLmdb
173 | create_example_db()
174 |
--------------------------------------------------------------------------------
/app/database/models/vl_model.py:
--------------------------------------------------------------------------------
1 | '''
2 | An abstract base class for live models that can run together with VL-InterpreT.
3 |
4 | To run your own model with VL-InterpreT, create another file your_model.py in this
5 | folder that contains a class Your_Model (use title case for the class name), which
6 | inherits from the VL_Model class and implements the data_setup method. The data_setup
7 | method should take the ID, image and text of a given example, run a forward pass for
8 | this example with your model, and return the corresponding attention, hidden states
9 | and other required data that can be visualized with VL-InterpreT.
10 | '''
11 |
12 |
13 | from abc import ABC, abstractmethod
14 | import numpy as np
15 | from PIL import Image
16 | import urllib.request
17 |
18 |
19 | class VL_Model(ABC):
20 | '''
21 | To run a live transformer with VL-InterpreT, define your own model class by inheriting
22 | from this class and implementing the data_setup method.
23 |
24 | Please follow these naming patterns to make sure your model runs easily with VL-InterpreT:
25 | - Create a new python script in this folder for your class, and name it in all lower
26 | case (e.g., yourmodelname.py)
27 | - Name your model class in title case, e.g., Yourmodelname. This class name should be
28 | the result of calling 'yourmodelname'.title(), where 'yourmodelname.py' is the name
29 | of your python script.
30 |
31 | Then you can run VL-InterpreT with your model:
32 | python run_app.py -p 6006 -d example_database2 -m yourmodelname your_model_parameters
33 | '''
34 |
35 | @abstractmethod
36 | def data_setup(self, example_id: int, image_location: str, input_text: str) -> dict:
37 | '''
38 | This method should run a forward pass with your model given the input image and
39 | text, and return the required data. See app/database/db_example.py for specifications
40 | of the return data format, and see the implementation in kdvlp.py for an example.
41 | '''
42 | return {
43 | 'ex_id': example_id,
 44 |             'image': np.array([]),
45 | 'tokens': [],
46 | 'txt_len': 0,
47 | 'img_coords': [],
 48 |             'attention': np.array([]),
 49 |             'hidden_states': np.array([])
50 | }
51 |
52 |
53 | def fetch_image(self, image_location: str):
54 | '''
 55 |         This helper function takes the path to an image (either a URL or a local path)
 56 |         and returns the image as a NumPy array.
57 | '''
58 | if image_location.startswith('http'):
59 | urllib.request.urlretrieve(image_location, 'temp.jpg')
60 | image_location = 'temp.jpg'
61 |
62 | img = Image.open(image_location).convert('RGB')
63 | img = np.array(img)
64 | return img
65 |
--------------------------------------------------------------------------------
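As a concrete illustration of the convention described in vl_model.py, a custom model could live in a new file such as app/database/models/mymodel.py and be launched with python run_app.py -p 6006 -d <your_database> -m mymodel <your_model_parameters>. The sketch below only returns random placeholder data in the db_example.py format; the file name, class name, grid size and model dimensions are illustrative assumptions, not part of the repository.

# app/database/models/mymodel.py -- illustrative sketch, not part of the repository
import numpy as np

try:
    from app.database.models.vl_model import VL_Model
except ModuleNotFoundError:
    from vl_model import VL_Model


class Mymodel(VL_Model):
    '''Placeholder model: replace the random outputs with a real forward pass.'''

    def data_setup(self, example_id, image_location, input_text):
        # fetch_image is the helper provided by VL_Model; fall back to a blank image
        image = (self.fetch_image(image_location) if image_location
                 else np.zeros((64, 64, 3), dtype=int))
        txt_tokens = ['[CLS]'] + input_text.split() + ['[SEP]']
        grid_w, grid_h = 3, 3                                  # assumed image-token grid
        img_coords = [(x, y) for y in range(grid_h) for x in range(grid_w)]
        n_tokens = len(txt_tokens) + len(img_coords)
        n_layers, n_heads, hidden = 12, 12, 768                # assumed model dimensions
        return {
            'ex_id': example_id,
            'image': image,
            'tokens': txt_tokens + [f'IMG_{i}' for i in range(len(img_coords))],
            'txt_len': len(txt_tokens),
            'img_coords': img_coords,
            'attention': np.random.rand(n_layers, n_heads, n_tokens, n_tokens),
            'hidden_states': np.random.rand(n_layers + 1, n_tokens, hidden),
        }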
/app/plot_func.py:
--------------------------------------------------------------------------------
1 | ''' Module containing plotting functions.
2 |
  3 | These plotting functions are used in app_configuration.py to
  4 | generate all the plots for the UI.
5 | '''
6 |
7 |
8 | import base64
9 | from io import BytesIO
10 | from PIL import Image
11 | from dash import html
12 |
13 | import numpy as np
14 | import matplotlib
15 | matplotlib.use('Agg')
16 | import matplotlib.pyplot as plt
17 | import matplotlib.cm as cm
18 | import plotly.graph_objects as go
19 | import scipy.ndimage
20 |
21 |
22 | IMG_LABEL_SIZE = 12
23 | LAYER_AVG_GAP = 0.8
24 | ATTN_MAP_SCALE_FACTOR = 64
25 |
26 | heatmap_layout = dict(
27 | autosize=True,
28 | font=dict(color='black'),
29 | titlefont=dict(color='black', size=14),
30 | legend=dict(font=dict(size=8), orientation='h'),
31 | )
32 |
33 | img_overlay_layout = dict(
34 | barmode='stack',
35 | bargap=0,
36 | hovermode='closest',
37 | showlegend=False,
38 | autosize=False,
39 | margin={'l': 0, 'r': 0, 't': 0, 'b': 0},
40 | paper_bgcolor='rgba(0,0,0,0)',
41 | plot_bgcolor='rgba(0,0,0,0)',
42 | xaxis=dict(visible=False, fixedrange=True),
43 | yaxis=dict(visible=False, scaleanchor='x') # constant aspect ratio
44 | )
45 |
46 | def get_sep_indices(tokens):
47 | return [-1] + [i for i, x in enumerate(tokens) if x == '[SEP]']
48 |
49 |
50 | def plot_text(db, ex_id):
51 | tokens = db[ex_id]['tokens']
52 | txt_len = db[ex_id]['txt_len']
53 | return show_texts(tokens[:txt_len]), txt_len
54 |
55 | def plot_attn_from_txt(db, ex_id, layer, head, word_ids=None, img_coords=None):
56 | ex_data = db[ex_id]
57 | txt_len = ex_data['txt_len']
58 | txt_tokens = ex_data['tokens'][:txt_len]
59 | if img_coords:
60 | w, h = db[ex_id]['img_grid_size']
 61 |         actual_img_tok_ids = [txt_len + w*(h-j-1) + i for i, j in img_coords]  # TODO: double-check image token ordering
 62 |         token_id = actual_img_tok_ids[0]  # TODO: average if multiple tokens are selected
63 | elif word_ids:
64 | token_id = db.get_txt_token_index(ex_id, word_ids)
65 | else: # nothing selected, show empty text
66 | return show_texts(txt_tokens)
67 | # get attention
68 | attn = ex_data['attention']
 69 |     if int(head) == head:
 70 |         txt_attn = attn[layer, head, token_id, :txt_len]
 71 |     else:  # average attention over all heads in this layer
 72 |         layer_attn = [attn[layer, hd, token_id, :txt_len] for hd in range(attn.shape[1])]
 73 |         txt_attn = np.mean(layer_attn, axis=0)
74 | # get colors by sentence
75 | colors = cm.get_cmap('Reds')(txt_attn * 120, 0.5)
76 | colors = ['rgba(' + ','.join(map(lambda x: str(int(x*255)), rgba[:3])) + ',' + str(rgba[3]) + ')'
77 | for rgba in colors]
78 | seps = get_sep_indices(ex_data['tokens'])
79 | colors = [colors[i+1:j+1] for i, j in zip(seps[:-1], seps[1:])]
80 | txt_fig = show_texts(txt_tokens, colors)
81 | return txt_fig
82 |
83 | def highlight_txt_token(db, ex_id, token_ids):
84 | txt_len = db[ex_id]['txt_len']
85 | txt_tokens = db[ex_id]['tokens'][:txt_len]
86 | if token_ids:
87 | seps = get_sep_indices(txt_tokens)
88 | sentence, word = token_ids['sentence'], token_ids['word']
89 | sentence_len = seps[sentence+1] - seps[sentence]
90 | colors = []
91 | for s in range(len(seps) - 1):
92 | if s == sentence:
93 | colors.append(['white'] * sentence_len)
94 | colors[-1][word] = 'orange'
95 | else:
96 | colors.append('white')
97 | else:
98 | colors = 'white'
99 | txt_fig = show_texts(txt_tokens, colors)
100 | return txt_fig
101 |
102 | def plot_image(db, ex_id):
103 | image = db[ex_id]['image']
104 | return show_img(image, bg='black', opacity=1.0)
105 |
106 | def get_empty_fig():
107 | fig = go.Figure()
108 | fig.update_layout(img_overlay_layout)
109 | return fig
110 |
111 | def get_overlay_grid(db, ex_id, color='rgba(255, 255, 255, 0)'):
112 | imgh, imgw, _ = db[ex_id]['image'].shape
113 | gridw, gridh = db[ex_id]['img_grid_size']
114 | unit_w = imgw / gridw
115 | unit_h = imgh / gridh
116 | grid_data = tuple(go.Bar(
117 | x=np.linspace(unit_w/2, imgw-unit_w/2, gridw),
118 | y=[unit_h] * gridw,
119 | hoverinfo='none',
120 | marker_line_width=0,
121 | marker_color=(color if type(color) == str else color[i]))
122 | for i in range(gridh))
123 | return grid_data
124 |
125 | def get_overlay_fig(db, ex_id, color='rgba(255, 255, 255, 0)'):
126 | grid_data = get_overlay_grid(db, ex_id, color)
127 | imgh, imgw, _ = db[ex_id]['image'].shape
128 | fig = go.Figure(
129 | data=grid_data,
130 | layout=go.Layout(
131 | xaxis={'range': (0, imgw)},
132 | yaxis={'range': (0, imgh)})
133 | )
134 | fig.update_layout(img_overlay_layout)
135 | return fig
136 |
137 | def highlight_img_grid_selection(db, ex_id, img_selection):
138 | img_selection = set(map(tuple, img_selection))
139 | w, h = db[ex_id]['img_grid_size']
140 | highlight, transparent = 'rgba(255, 165, 0, .4)', 'rgba(255, 255, 255, 0)'
141 | colors = [[(highlight if (j, i) in img_selection else transparent) for j in range(w)] for i in range(h)]
142 | return get_overlay_fig(db, ex_id, colors)
143 |
144 | def plot_attn_from_img(db, ex_id, layer, head, word_ids=None, img_coords=None, smooth=True):
145 | ex_data = db[ex_id]
146 | w0, h0 = ex_data['img_grid_size']
147 | txt_len = ex_data['txt_len']
148 | if word_ids and word_ids['sentence'] > -1 and word_ids['word'] > -1:
149 | token_id = db.get_txt_token_index(ex_id, word_ids)
150 | elif img_coords:
151 | img_token_ids = [db.get_img_token_index(ex_id, coords) for coords in img_coords]
152 | token_id = img_token_ids[0] # TODO show avg for multiple selection?
153 | else:
154 | return get_empty_fig()
155 | attn = ex_data['attention']
156 |
157 | if head < attn.shape[1]:
158 | img_attn = attn[layer, head, token_id, txt_len:(w0*h0+txt_len)].reshape(h0, w0)
159 |     else:  # show the layer average over all heads
160 |         layer_attn = [attn[layer, h, token_id, txt_len:(w0*h0+txt_len)].reshape(h0, w0)
161 |                       for h in range(attn.shape[1])]
162 | img_attn = np.mean(layer_attn, axis=0)
163 | if smooth:
164 | img_attn = scipy.ndimage.zoom(img_attn, ATTN_MAP_SCALE_FACTOR, order=1)
165 | return show_img(img_attn, opacity=0.3, bg='rgba(0,0,0,0)', hw=ex_data['image'].shape[:2])
166 |
167 | def plot_head_summary(db, ex_id, attn_type_index=0, layer=None, head=None):
168 | stats = db.get_all_attn_stats(ex_id)
169 | custom_stats = db.get_custom_metrics(ex_id)
170 | return show_head_summary(stats, custom_stats, attn_type_index, layer, head)
171 |
172 | def plot_attn_matrix(db, ex_id, layer, head, attn_type=0):
173 | txt_len = db[ex_id]['txt_len']
174 | txt_attn = db[ex_id]['attention'][layer, head, :txt_len, :txt_len]
175 | img_attn = db[ex_id]['attention'][layer, head, txt_len:, txt_len:]
176 | tokens = db[ex_id]['tokens']
177 | return show_attn_matrix([txt_attn, img_attn], tokens, layer, head, attn_type)
178 |
179 |
180 | def show_texts(tokens, colors='white'):
181 | seps = get_sep_indices(tokens)
182 | sentences = [tokens[i+1:j+1] for i, j in zip(seps[:-1], seps[1:])]
183 |
184 | fig = go.Figure()
185 | annotations = []
186 | for sen_i, sentence in enumerate(sentences):
187 | word_lengths = list(map(len, sentence))
188 | fig.add_trace(go.Bar(
189 | x=word_lengths, # TODO center at 0
190 | y=[sen_i] * len(sentence),
191 | orientation='h',
192 | marker_color=(colors if type(colors) is str else colors[sen_i]),
193 | marker_line=dict(color='rgba(255, 255, 255, 0)', width=0),
194 | hoverinfo='none'
195 | ))
196 | word_pos = np.cumsum(word_lengths) - np.array(word_lengths) / 2
197 | for word_i in range(len(sentence)):
198 | annotations.append(dict(
199 | xref='x', yref='y',
200 | x=word_pos[word_i], y=sen_i,
201 | text=sentence[word_i],
202 | showarrow=False
203 | ))
204 | fig.update_xaxes(visible=False, fixedrange=True)
205 | fig.update_yaxes(visible=False, fixedrange=True, range=(len(sentences)+.3, -len(sentences)+.7))
206 | fig.update_layout(
207 | annotations=annotations,
208 | barmode='stack',
209 | autosize=False,
210 | margin={'l': 0, 'r': 0, 't': 0, 'b': 0},
211 | showlegend=False,
212 | plot_bgcolor='white'
213 | )
214 | return fig
215 |
216 |
217 | def show_img(img, opacity, bg, hw=None):
218 | img_height, img_width = hw if hw else (img.shape[0], img.shape[1])
219 | mfig, ax = plt.subplots(figsize=(img_width/100., img_height/100.), dpi=100)
220 | plt.subplots_adjust(left=0, right=1, top=1, bottom=0, wspace=0, hspace=0)
221 | plt.axis('off')
222 | if hw:
223 | ax.imshow(img, cmap='jet', interpolation='nearest', aspect='auto')
224 | else:
225 | ax.imshow(img)
226 | img_uri = fig_to_uri(mfig)
227 | fig_width, fig_height = mfig.get_size_inches() * mfig.dpi
228 |
229 | fig = go.Figure()
230 | fig.update_xaxes(range=(0, fig_width))
231 | fig.update_yaxes(range=(0, fig_height))
232 | fig.update_layout(img_overlay_layout)
233 | fig.update_layout(
234 | autosize=True,
235 | plot_bgcolor=bg,
236 | paper_bgcolor=bg
237 | )
238 | fig.layout.images = [] # remove previous image
239 | fig.add_layout_image(dict(
240 | x=0, y=fig_height,
241 | sizex=fig_width, sizey=fig_height,
242 | xref='x', yref='y',
243 | opacity=opacity,
244 | sizing='stretch',
245 | source=img_uri
246 | ))
247 | return fig
248 |
249 |
250 | def fig_to_uri(fig, close_all=True, **save_args):
251 | out_img = BytesIO()
252 | fig.savefig(out_img, format='jpeg', **save_args)
253 | if close_all:
254 | fig.clf()
255 | plt.close('all')
256 | out_img.seek(0) # rewind file
257 | encoded = base64.b64encode(out_img.read()).decode('ascii').replace('\n', '')
258 | return 'data:image/jpeg;base64,{}'.format(encoded)
259 |
260 |
261 | def show_head_summary(attn_stats, custom_stats=None, attn_type_index=0, layer=None, head=None):
262 |     n_stats = len(attn_stats) + (len(custom_stats[0]) if custom_stats else 0)
263 | if attn_type_index >= n_stats:
264 | attn_type_index = 0
265 |
266 | num_layers, num_heads = np.shape(attn_stats[0])
267 | data = [go.Heatmap(
268 |         type='heatmap',
269 | x=list(np.arange(-0.5, num_heads-1)) + [num_heads + LAYER_AVG_GAP - 0.5],
270 | z=stats,
271 | zmin=-3, zmax=3, # fix color scale range
272 | colorscale='plasma',
273 | reversescale=False,
274 | colorbar=dict(thickness=10),
275 | visible=(i == attn_type_index)
276 | ) for i, stats in enumerate(attn_stats)]
277 |
278 |     if custom_stats:
279 |         data += [go.Heatmap(
280 |             type='heatmap',
281 |             x=list(np.arange(-0.5, num_heads-1)) + [num_heads + LAYER_AVG_GAP - 0.5],
282 |             z=stats,
283 |             colorscale='plasma',
284 |             reversescale=False,
285 |             colorbar=dict(thickness=10),
286 |             visible=(attn_type_index == len(attn_stats) + i)) for i, stats in enumerate(custom_stats[1])]
287 |
288 | if layer is not None and head is not None:
289 | x_pos = np.floor(head) + LAYER_AVG_GAP if head > num_heads - 1 else head
290 | data.append(dict(
291 | type='scattergl',
292 | x=[x_pos],
293 | y=[layer],
294 | marker=dict(
295 | color='black',
296 | symbol='x',
297 | size=15,
298 | opacity=0.6,
299 | line=dict(width=1, color='lightgrey')
300 | ),
301 | ))
302 |     layout = dict(heatmap_layout)  # copy the shared template so updates don't leak across calls
303 | layout.update({'title': ''})
304 | layout.update({
305 | 'yaxis': {
306 | 'title': 'Layer #',
307 | 'tickmode': 'array',
308 | 'ticktext': list(range(1, num_layers+1)),
309 | 'tickvals': list(range(num_layers)),
310 | 'range': (num_layers-0.5, -0.5),
311 | 'fixedrange': True},
312 | 'xaxis': {
313 | 'title': 'Attention head #',
314 | 'tickmode': 'array',
315 | 'ticktext': list(range(1, num_heads)) + ['Mean'],
316 | 'tickvals': list(range(num_heads-1)) + [num_heads + LAYER_AVG_GAP - 1],
317 | 'range': (-0.5, num_heads + LAYER_AVG_GAP - 0.5),
318 | 'fixedrange': True},
319 | 'margin': {'t': 0, 'l': 0, 'r': 0, 'b': 0},
320 | 'height': 450,
321 | 'width': 520
322 | })
323 |
324 | fig = go.Figure(data=data, layout=layout)
325 | fig.add_vrect(x0=num_heads-1.5, x1=num_heads+LAYER_AVG_GAP-1.5, fillcolor='white', line_width=0)
326 |
327 | # dropdown menu
328 | labels = ['Mean image-to-text attention (without CLS/SEP)',
329 | 'Mean text-to-image attention (without CLS/SEP)',
330 | 'Mean image-to-image attention (without CLS/SEP)',
331 | 'Mean image-to-image attention (without self/CLS/SEP)',
332 | 'Mean text-to-text attention (without CLS/SEP)',
333 | 'Mean cross-modal attention (without CLS/SEP)',
334 | 'Mean intra-modal attention (without CLS/SEP)']
335 | if custom_stats:
336 | labels += custom_stats[0]
337 | fig.update_layout(
338 | updatemenus=[dict(
339 | buttons=[
340 | dict(
341 | args=[{'visible': [True if j == i else False for j in range(len(labels))] + [True]}],
342 | label=labels[i],
343 | method='update'
344 | ) for i in range(len(labels))
345 | ],
346 | direction='down',
347 | pad={'l': 85},
348 | active=attn_type_index,
349 | showactive=True,
350 | x=0,
351 | xanchor='left',
352 | y=1.01,
353 | yanchor='bottom'
354 | )],
355 | annotations=[dict(
356 | text='Display data:', showarrow=False,
357 | x=0, y=1.027, xref='paper', yref='paper',
358 | yanchor='bottom'
359 | )]
360 | )
361 |
362 | return fig
363 |
364 |
365 | def show_attn_matrix(attns, tokens, layer, head, attn_type):
366 | txt_tokens = tokens[:attns[0].shape[0]]
367 | data = [dict(
368 | type='heatmap',
369 | z=attn,
370 | colorbar=dict(thickness=10),
371 | visible=(i == attn_type)) for i, attn in enumerate(attns)]
372 |     layout = dict(heatmap_layout)  # copy the shared template so updates don't leak across calls
373 | layout.update({
374 | 'title': f'Attention Matrix for Layer #{layer+1}, Head #{head+1}',
375 | 'title_font_size': 16,
376 | 'title_x': 0.5, 'title_y': 0.99,
377 | 'margin': {'t': 70, 'b': 10, 'l': 0, 'r': 0}
378 | })
379 | txt_attn_layout = {
380 | 'xaxis': dict(
381 | tickmode='array',
382 | tickvals=list(range(len(txt_tokens))),
383 | ticktext=txt_tokens,
384 | tickangle=45,
385 | tickfont=dict(size=12),
386 | range=(-0.5, len(txt_tokens)-0.5),
387 | fixedrange=True),
388 | 'yaxis': dict(
389 | tickmode='array',
390 | tickvals=list(range(len(txt_tokens))),
391 | ticktext=txt_tokens,
392 | tickfont=dict(size=12),
393 | range=(-0.5, len(txt_tokens)-0.5),
394 | fixedrange=True),
395 | 'width': 600,
396 | 'height': 600
397 | }
398 | img_attn_layout = {
399 | 'xaxis': dict(range=(-0.5, len(tokens)-len(txt_tokens)-0.5)),
400 | 'yaxis': dict(range=(-0.5, len(tokens)-len(txt_tokens)-0.5), scaleanchor='x'),
401 | 'width': 900,
402 | 'height': 900,
403 | 'plot_bgcolor': 'rgba(0,0,0,0)',
404 | }
405 | layout.update((txt_attn_layout, img_attn_layout)[attn_type])
406 | figure = go.Figure(dict(data=data, layout=layout))
407 |
408 | # dropdown menu
409 | figure.update_layout(
410 | updatemenus=[dict(
411 | buttons=[dict(
412 | args=[
413 | {'visible': [True, False]},
414 | txt_attn_layout
415 | ],
416 | label='Text-to-text attention',
417 | method='update'
418 | ), dict(
419 | args=[
420 | {'visible': [False, True]},
421 | img_attn_layout
422 | ],
423 | label='Image-to-image attention',
424 | method='update'
425 | )],
426 | direction='down',
427 | pad={'l': 85},
428 | active=attn_type,
429 | showactive=True,
430 | x=0,
431 | xanchor='left',
432 | y=1.01,
433 | yanchor='bottom'
434 | )],
435 | annotations=[dict(
436 | text='Display data:', showarrow=False,
437 | x=0, y=1.02,
438 | xref='paper', yref='paper',
439 | xanchor='left', yanchor='bottom'
440 | )],
441 | )
442 |
443 | return figure
444 |
445 | def show_tsne(db, ex_id, layer, text_tokens_id, img_tokens):
446 | ex_data = db[ex_id]
447 | selected_token_id = None
448 | if text_tokens_id and text_tokens_id['sentence'] >= 0 and text_tokens_id['word'] >= 0:
449 | selected_token_id = db.get_txt_token_index(ex_id, text_tokens_id)
450 | mod = 'img'
451 | elif img_tokens:
452 | selected_token_id = db.get_img_token_index(ex_id, img_tokens[0])
453 | mod = 'txt'
454 |
455 | # draw all tokens
456 | figure = go.Figure()
457 | txt_len = ex_data['txt_len']
458 | for modality in ('txt', 'img'):
459 | slicing = slice(txt_len) if modality == 'txt' else slice(txt_len, len(ex_data['tokens']))
460 | figure.add_trace(go.Scatter(
461 | x=ex_data['tsne'][layer][slicing][:,0],
462 | y=ex_data['tsne'][layer][slicing][:,1],
463 | hovertext=[f'{ex_id},{token_id}' for token_id in list(range(len(ex_data['tokens'])))[slicing]],
464 | mode='markers',
465 | name=modality,
466 | marker={'color': 'blue' if modality == 'img' else 'red'}
467 | ))
468 | if selected_token_id is not None:
469 |         # highlight the selected token
470 | token = ex_data['tokens'][selected_token_id]
471 | tsne = ex_data['tsne'][layer][selected_token_id]
472 | figure.add_trace(go.Scatter(
473 | x=[tsne[0]],
474 | y=[tsne[1]],
475 | mode='markers+text',
476 | hovertext=[f'{ex_id},{selected_token_id}'],
477 | text=token,
478 | textposition='top center',
479 | textfont=dict(size=15),
480 | marker=dict(
481 | color='orange',
482 | colorscale='Jet',
483 | showscale=False,
484 | symbol='star',
485 | size=10,
486 | opacity=1,
487 | cmin=-1.0,
488 | cmax=1.0,
489 | ),
490 | showlegend=False,
491 | ))
492 |         # highlight the closest token
493 | closest_ex_id, closest_token_id = db.find_closest_token(tsne, layer, mod)
494 | closest_token = db[closest_ex_id]['tokens'][closest_token_id]
495 | closest_tsne = db[closest_ex_id]['tsne'][layer][closest_token_id]
496 | figure.add_trace(go.Scatter(
497 | x=[closest_tsne[0]],
498 | y=[closest_tsne[1]],
499 | mode='markers+text',
500 | hovertext=[f'{closest_ex_id},{closest_token_id}'],
501 | text=closest_token + ' from ex ' + str(closest_ex_id),
502 | textposition='top center',
503 | textfont=dict(size=15),
504 | marker=dict(
505 | color='green',
506 | colorscale='Jet',
507 | showscale=False,
508 | symbol='star',
509 | size=10,
510 | opacity=1,
511 | cmin=-1.0,
512 | cmax=1.0,
513 | ),
514 | showlegend=False,
515 | ))
516 | figure.update_layout(legend=dict(font=dict(size=15), y=1, x=0.99, bgcolor='rgba(0,0,0,0)'))
517 | figure.update_layout(margin={'l': 0, 'r': 0, 'b': 0, 't': 0})
518 | figure.update_traces(hoverinfo='none', hovertemplate=None)
519 | return figure
520 |
521 |
522 | def plot_tooltip_content(db, ex_id, token_id):
523 | ex_data = db[ex_id]
524 | txt_len = ex_data['txt_len']
525 | token = ex_data['tokens'][token_id]
526 | if token_id < txt_len: # text token
527 | return [html.P(token)]
528 | # image token
529 | img_coords = ex_data['img_coords'][token_id-txt_len]
530 | img = np.copy(ex_data['image'])
531 | x_unit_len, y_unit_len = db.get_img_unit_len(ex_id)
532 | x = int(x_unit_len * img_coords[0])
533 | y_end = img.shape[0] - int(y_unit_len * img_coords[1])
534 | x_end = x + x_unit_len
535 | y = y_end - y_unit_len
536 | img[y:y_end, x:x_end, 0] = 0
537 | img[y:y_end, x:x_end, 1] = 0
538 | img[y:y_end, x:x_end, 2] = 255
539 | # dump img to base64
540 | buffer = BytesIO()
541 |     Image.fromarray(np.uint8(img)).save(buffer, format='jpeg')
542 | encoded_image = base64.b64encode(buffer.getvalue()).decode()
543 | img_url = 'data:image/jpeg;base64, ' + encoded_image
544 | return [
545 | html.Div([
546 | html.Img(
547 | src=img_url,
548 | style={'width': '300px', 'display': 'block', 'margin': '0 auto'},
549 | ),
550 |             html.P(token, style={'font-weight': 'bold', 'font-size': '7px'})
551 | ])
552 | ]
553 |
--------------------------------------------------------------------------------
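For reference, show_texts and plot_attn_from_txt group tokens into sentences by slicing between consecutive '[SEP]' markers via get_sep_indices. A small standalone illustration of that slicing (the token list here is made up):

# standalone illustration of the '[SEP]'-based sentence slicing used in plot_func.py
tokens = ['[CLS]', 'a', 'dog', '.', '[SEP]', 'it', 'runs', '.', '[SEP]']

seps = [-1] + [i for i, x in enumerate(tokens) if x == '[SEP]']     # same as get_sep_indices -> [-1, 4, 8]
sentences = [tokens[i+1:j+1] for i, j in zip(seps[:-1], seps[1:])]
# sentences == [['[CLS]', 'a', 'dog', '.', '[SEP]'], ['it', 'runs', '.', '[SEP]']]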
/assets/logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IntelLabs/VL-InterpreT/0f9d0bb1dad13d91ce79d9453faf81fc2b9277b9/assets/logo.png
--------------------------------------------------------------------------------
/assets/screencast.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IntelLabs/VL-InterpreT/0f9d0bb1dad13d91ce79d9453faf81fc2b9277b9/assets/screencast.png
--------------------------------------------------------------------------------
/assets/screenshot_add_ex.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IntelLabs/VL-InterpreT/0f9d0bb1dad13d91ce79d9453faf81fc2b9277b9/assets/screenshot_add_ex.png
--------------------------------------------------------------------------------
/assets/screenshot_app_launch.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IntelLabs/VL-InterpreT/0f9d0bb1dad13d91ce79d9453faf81fc2b9277b9/assets/screenshot_app_launch.png
--------------------------------------------------------------------------------
/example_database1/data.mdb:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IntelLabs/VL-InterpreT/0f9d0bb1dad13d91ce79d9453faf81fc2b9277b9/example_database1/data.mdb
--------------------------------------------------------------------------------
/example_database1/faiss/img_indices_0:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IntelLabs/VL-InterpreT/0f9d0bb1dad13d91ce79d9453faf81fc2b9277b9/example_database1/faiss/img_indices_0
--------------------------------------------------------------------------------
/example_database1/faiss/img_indices_1:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IntelLabs/VL-InterpreT/0f9d0bb1dad13d91ce79d9453faf81fc2b9277b9/example_database1/faiss/img_indices_1
--------------------------------------------------------------------------------
/example_database1/faiss/img_indices_10:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IntelLabs/VL-InterpreT/0f9d0bb1dad13d91ce79d9453faf81fc2b9277b9/example_database1/faiss/img_indices_10
--------------------------------------------------------------------------------
/example_database1/faiss/img_indices_11:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IntelLabs/VL-InterpreT/0f9d0bb1dad13d91ce79d9453faf81fc2b9277b9/example_database1/faiss/img_indices_11
--------------------------------------------------------------------------------
/example_database1/faiss/img_indices_12:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IntelLabs/VL-InterpreT/0f9d0bb1dad13d91ce79d9453faf81fc2b9277b9/example_database1/faiss/img_indices_12
--------------------------------------------------------------------------------
/example_database1/faiss/img_indices_2:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IntelLabs/VL-InterpreT/0f9d0bb1dad13d91ce79d9453faf81fc2b9277b9/example_database1/faiss/img_indices_2
--------------------------------------------------------------------------------
/example_database1/faiss/img_indices_3:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IntelLabs/VL-InterpreT/0f9d0bb1dad13d91ce79d9453faf81fc2b9277b9/example_database1/faiss/img_indices_3
--------------------------------------------------------------------------------
/example_database1/faiss/img_indices_4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IntelLabs/VL-InterpreT/0f9d0bb1dad13d91ce79d9453faf81fc2b9277b9/example_database1/faiss/img_indices_4
--------------------------------------------------------------------------------
/example_database1/faiss/img_indices_5:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IntelLabs/VL-InterpreT/0f9d0bb1dad13d91ce79d9453faf81fc2b9277b9/example_database1/faiss/img_indices_5
--------------------------------------------------------------------------------
/example_database1/faiss/img_indices_6:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IntelLabs/VL-InterpreT/0f9d0bb1dad13d91ce79d9453faf81fc2b9277b9/example_database1/faiss/img_indices_6
--------------------------------------------------------------------------------
/example_database1/faiss/img_indices_7:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IntelLabs/VL-InterpreT/0f9d0bb1dad13d91ce79d9453faf81fc2b9277b9/example_database1/faiss/img_indices_7
--------------------------------------------------------------------------------
/example_database1/faiss/img_indices_8:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IntelLabs/VL-InterpreT/0f9d0bb1dad13d91ce79d9453faf81fc2b9277b9/example_database1/faiss/img_indices_8
--------------------------------------------------------------------------------
/example_database1/faiss/img_indices_9:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IntelLabs/VL-InterpreT/0f9d0bb1dad13d91ce79d9453faf81fc2b9277b9/example_database1/faiss/img_indices_9
--------------------------------------------------------------------------------
/example_database1/faiss/txt_indices_0:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IntelLabs/VL-InterpreT/0f9d0bb1dad13d91ce79d9453faf81fc2b9277b9/example_database1/faiss/txt_indices_0
--------------------------------------------------------------------------------
/example_database1/faiss/txt_indices_1:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IntelLabs/VL-InterpreT/0f9d0bb1dad13d91ce79d9453faf81fc2b9277b9/example_database1/faiss/txt_indices_1
--------------------------------------------------------------------------------
/example_database1/faiss/txt_indices_10:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IntelLabs/VL-InterpreT/0f9d0bb1dad13d91ce79d9453faf81fc2b9277b9/example_database1/faiss/txt_indices_10
--------------------------------------------------------------------------------
/example_database1/faiss/txt_indices_11:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IntelLabs/VL-InterpreT/0f9d0bb1dad13d91ce79d9453faf81fc2b9277b9/example_database1/faiss/txt_indices_11
--------------------------------------------------------------------------------
/example_database1/faiss/txt_indices_12:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IntelLabs/VL-InterpreT/0f9d0bb1dad13d91ce79d9453faf81fc2b9277b9/example_database1/faiss/txt_indices_12
--------------------------------------------------------------------------------
/example_database1/faiss/txt_indices_2:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IntelLabs/VL-InterpreT/0f9d0bb1dad13d91ce79d9453faf81fc2b9277b9/example_database1/faiss/txt_indices_2
--------------------------------------------------------------------------------
/example_database1/faiss/txt_indices_3:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IntelLabs/VL-InterpreT/0f9d0bb1dad13d91ce79d9453faf81fc2b9277b9/example_database1/faiss/txt_indices_3
--------------------------------------------------------------------------------
/example_database1/faiss/txt_indices_4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IntelLabs/VL-InterpreT/0f9d0bb1dad13d91ce79d9453faf81fc2b9277b9/example_database1/faiss/txt_indices_4
--------------------------------------------------------------------------------
/example_database1/faiss/txt_indices_5:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IntelLabs/VL-InterpreT/0f9d0bb1dad13d91ce79d9453faf81fc2b9277b9/example_database1/faiss/txt_indices_5
--------------------------------------------------------------------------------
/example_database1/faiss/txt_indices_6:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IntelLabs/VL-InterpreT/0f9d0bb1dad13d91ce79d9453faf81fc2b9277b9/example_database1/faiss/txt_indices_6
--------------------------------------------------------------------------------
/example_database1/faiss/txt_indices_7:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IntelLabs/VL-InterpreT/0f9d0bb1dad13d91ce79d9453faf81fc2b9277b9/example_database1/faiss/txt_indices_7
--------------------------------------------------------------------------------
/example_database1/faiss/txt_indices_8:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IntelLabs/VL-InterpreT/0f9d0bb1dad13d91ce79d9453faf81fc2b9277b9/example_database1/faiss/txt_indices_8
--------------------------------------------------------------------------------
/example_database1/faiss/txt_indices_9:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IntelLabs/VL-InterpreT/0f9d0bb1dad13d91ce79d9453faf81fc2b9277b9/example_database1/faiss/txt_indices_9
--------------------------------------------------------------------------------
/example_database1/lock.mdb:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IntelLabs/VL-InterpreT/0f9d0bb1dad13d91ce79d9453faf81fc2b9277b9/example_database1/lock.mdb
--------------------------------------------------------------------------------
/example_database2/data.mdb:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IntelLabs/VL-InterpreT/0f9d0bb1dad13d91ce79d9453faf81fc2b9277b9/example_database2/data.mdb
--------------------------------------------------------------------------------
/example_database2/faiss/img_indices_0:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IntelLabs/VL-InterpreT/0f9d0bb1dad13d91ce79d9453faf81fc2b9277b9/example_database2/faiss/img_indices_0
--------------------------------------------------------------------------------
/example_database2/faiss/img_indices_1:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IntelLabs/VL-InterpreT/0f9d0bb1dad13d91ce79d9453faf81fc2b9277b9/example_database2/faiss/img_indices_1
--------------------------------------------------------------------------------
/example_database2/faiss/img_indices_10:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IntelLabs/VL-InterpreT/0f9d0bb1dad13d91ce79d9453faf81fc2b9277b9/example_database2/faiss/img_indices_10
--------------------------------------------------------------------------------
/example_database2/faiss/img_indices_11:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IntelLabs/VL-InterpreT/0f9d0bb1dad13d91ce79d9453faf81fc2b9277b9/example_database2/faiss/img_indices_11
--------------------------------------------------------------------------------
/example_database2/faiss/img_indices_12:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IntelLabs/VL-InterpreT/0f9d0bb1dad13d91ce79d9453faf81fc2b9277b9/example_database2/faiss/img_indices_12
--------------------------------------------------------------------------------
/example_database2/faiss/img_indices_2:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IntelLabs/VL-InterpreT/0f9d0bb1dad13d91ce79d9453faf81fc2b9277b9/example_database2/faiss/img_indices_2
--------------------------------------------------------------------------------
/example_database2/faiss/img_indices_3:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IntelLabs/VL-InterpreT/0f9d0bb1dad13d91ce79d9453faf81fc2b9277b9/example_database2/faiss/img_indices_3
--------------------------------------------------------------------------------
/example_database2/faiss/img_indices_4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IntelLabs/VL-InterpreT/0f9d0bb1dad13d91ce79d9453faf81fc2b9277b9/example_database2/faiss/img_indices_4
--------------------------------------------------------------------------------
/example_database2/faiss/img_indices_5:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IntelLabs/VL-InterpreT/0f9d0bb1dad13d91ce79d9453faf81fc2b9277b9/example_database2/faiss/img_indices_5
--------------------------------------------------------------------------------
/example_database2/faiss/img_indices_6:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IntelLabs/VL-InterpreT/0f9d0bb1dad13d91ce79d9453faf81fc2b9277b9/example_database2/faiss/img_indices_6
--------------------------------------------------------------------------------
/example_database2/faiss/img_indices_7:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IntelLabs/VL-InterpreT/0f9d0bb1dad13d91ce79d9453faf81fc2b9277b9/example_database2/faiss/img_indices_7
--------------------------------------------------------------------------------
/example_database2/faiss/img_indices_8:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IntelLabs/VL-InterpreT/0f9d0bb1dad13d91ce79d9453faf81fc2b9277b9/example_database2/faiss/img_indices_8
--------------------------------------------------------------------------------
/example_database2/faiss/img_indices_9:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IntelLabs/VL-InterpreT/0f9d0bb1dad13d91ce79d9453faf81fc2b9277b9/example_database2/faiss/img_indices_9
--------------------------------------------------------------------------------
/example_database2/faiss/txt_indices_0:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IntelLabs/VL-InterpreT/0f9d0bb1dad13d91ce79d9453faf81fc2b9277b9/example_database2/faiss/txt_indices_0
--------------------------------------------------------------------------------
/example_database2/faiss/txt_indices_1:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IntelLabs/VL-InterpreT/0f9d0bb1dad13d91ce79d9453faf81fc2b9277b9/example_database2/faiss/txt_indices_1
--------------------------------------------------------------------------------
/example_database2/faiss/txt_indices_10:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IntelLabs/VL-InterpreT/0f9d0bb1dad13d91ce79d9453faf81fc2b9277b9/example_database2/faiss/txt_indices_10
--------------------------------------------------------------------------------
/example_database2/faiss/txt_indices_11:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IntelLabs/VL-InterpreT/0f9d0bb1dad13d91ce79d9453faf81fc2b9277b9/example_database2/faiss/txt_indices_11
--------------------------------------------------------------------------------
/example_database2/faiss/txt_indices_12:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IntelLabs/VL-InterpreT/0f9d0bb1dad13d91ce79d9453faf81fc2b9277b9/example_database2/faiss/txt_indices_12
--------------------------------------------------------------------------------
/example_database2/faiss/txt_indices_2:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IntelLabs/VL-InterpreT/0f9d0bb1dad13d91ce79d9453faf81fc2b9277b9/example_database2/faiss/txt_indices_2
--------------------------------------------------------------------------------
/example_database2/faiss/txt_indices_3:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IntelLabs/VL-InterpreT/0f9d0bb1dad13d91ce79d9453faf81fc2b9277b9/example_database2/faiss/txt_indices_3
--------------------------------------------------------------------------------
/example_database2/faiss/txt_indices_4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IntelLabs/VL-InterpreT/0f9d0bb1dad13d91ce79d9453faf81fc2b9277b9/example_database2/faiss/txt_indices_4
--------------------------------------------------------------------------------
/example_database2/faiss/txt_indices_5:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IntelLabs/VL-InterpreT/0f9d0bb1dad13d91ce79d9453faf81fc2b9277b9/example_database2/faiss/txt_indices_5
--------------------------------------------------------------------------------
/example_database2/faiss/txt_indices_6:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IntelLabs/VL-InterpreT/0f9d0bb1dad13d91ce79d9453faf81fc2b9277b9/example_database2/faiss/txt_indices_6
--------------------------------------------------------------------------------
/example_database2/faiss/txt_indices_7:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IntelLabs/VL-InterpreT/0f9d0bb1dad13d91ce79d9453faf81fc2b9277b9/example_database2/faiss/txt_indices_7
--------------------------------------------------------------------------------
/example_database2/faiss/txt_indices_8:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IntelLabs/VL-InterpreT/0f9d0bb1dad13d91ce79d9453faf81fc2b9277b9/example_database2/faiss/txt_indices_8
--------------------------------------------------------------------------------
/example_database2/faiss/txt_indices_9:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IntelLabs/VL-InterpreT/0f9d0bb1dad13d91ce79d9453faf81fc2b9277b9/example_database2/faiss/txt_indices_9
--------------------------------------------------------------------------------
/example_database2/lock.mdb:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IntelLabs/VL-InterpreT/0f9d0bb1dad13d91ce79d9453faf81fc2b9277b9/example_database2/lock.mdb
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | dash>=2.0.0
2 | plotly>=5.4.0
3 | dash-daq>=0.5.0
4 | matplotlib>=3.3.3
5 | termcolor>=1.1.0
6 | numpy>=1.18.4
7 | lmdb>=1.1.1
8 | lz4>=3.1.3
9 | msgpack>=1.0.2
10 | msgpack_numpy>=0.4.7.1
11 | faiss-cpu>=1.7.1
12 | scikit-learn>=0.23.2
13 | tqdm
14 |
--------------------------------------------------------------------------------
/run_app.py:
--------------------------------------------------------------------------------
 1 | '''Driver script to start and run VL-InterpreT.'''
2 |
3 | import socket
4 | import argparse
5 | from importlib import import_module
6 |
7 | from app import app_configuration
8 |
9 |
10 | def build_model(model_name, *args, **kwargs):
11 | model = getattr(import_module(f'app.database.models.{model_name.lower()}'), model_name.title())
12 | return model(*args, **kwargs)
13 |
14 |
15 | def main(port, db_path, model_name=None, model_params=None):
 16 |     print(f'Running on port {port} with database {db_path}' +
 17 |           (f' and {model_name} model' if model_name else ''))
18 |
19 | model = build_model(model_name, *model_params) if model_name else None
20 | app = app_configuration.configure_app(db_path, model=model)
21 | hostname = socket.gethostname()
22 | ip_address = socket.gethostbyname(hostname)
23 | app_configuration.print_page_link(hostname, port)
24 | application = app.server
25 | application.run(debug=False, threaded=True, host=ip_address, port=int(port))
26 |
27 |
28 | if __name__ == '__main__':
29 | parser = argparse.ArgumentParser()
30 | parser.add_argument('-p', '--port', required=True,
31 | help='The port number to run this app on.')
32 | parser.add_argument('-d', '--database', required=True,
33 | help='The path to a database. See app/database/db_example.py')
34 | parser.add_argument('-m', '--model', nargs='+', required=False,
 35 |                         help='The name of your model, followed by any arguments '
 36 |                              'needed to initialize it.')
37 |
38 | args = parser.parse_args()
39 | if args.model:
40 | main(args.port, args.database, args.model[0], args.model[1:])
41 | else:
42 | main(args.port, args.database)
43 |
--------------------------------------------------------------------------------
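For reference, build_model resolves the -m arguments by importing the matching module under app/database/models and instantiating its title-cased class. A command such as python run_app.py -p 6006 -d example_database2 -m kdvlp <checkpoint> is therefore roughly equivalent to the sketch below (the checkpoint path is a placeholder):

from importlib import import_module

# '-m kdvlp <checkpoint>' -> module app.database.models.kdvlp, class Kdvlp
model_cls = getattr(import_module('app.database.models.kdvlp'), 'kdvlp'.title())
model = model_cls('/path/to/model_checkpoint.pt')  # placeholder checkpoint path

# the model is then passed to the Dash app:
# app = app_configuration.configure_app('example_database2', model=model)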