├── LICENSE
├── README.md
├── exec
│   ├── plot_heatmap.jar
│   └── plot_heatmap.py
├── multihead-att-java
│   ├── ActionLabel.java
│   ├── DataObject.java
│   ├── HeatmapPanel.java
│   ├── MainFrame.java
│   ├── MainPanel.java
│   └── Utils.java
└── toydata
    ├── figures
    │   ├── java-heatmap1.png
    │   ├── java-heatmap2.png
    │   ├── java-heatmap3.png
    │   ├── java-heatmap4.png
    │   ├── py-heatmap.png
    │   ├── py-heatmap1.png
    │   └── py-heatmap2.png
    └── toy.attention
/LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files.
29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 
61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. 
In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. 
We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Attention-Visualization 2 | Visualization for simple attention and Google's multi-head attention. 3 | 4 | ## Requirements 5 | 6 | - python (with numpy and matplotlib) 7 | - jdk1.8 8 | 9 | ## Usage 10 | 11 | 1\. Python version (for simple attention only): 12 | 13 | ``` bash 14 | python exec/plot_heatmap.py --input xxx.attention 15 | ``` 16 | 17 | 2\. Java version (for both simple attention and Google's multi-head attention): 18 | ``` bash 19 | java -jar exec/plot_heatmap.jar 20 | ``` 21 | then select the attention file in the GUI.
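Both tools read the JSON layout described in the Data Format section. As a minimal sketch (Python 3 with numpy; `write_attention_file`, `KNOWN_KINDS` and `VALUE_NDIM` are hypothetical names, not part of this repository), attention weights from your own model could be written to an `.attention` file with basic checks on the attention names and the dimensionality of `value`:

```python
import json

import numpy as np

# Hypothetical constant: substrings the Java viewer matches in attention names.
KNOWN_KINDS = ("encoder_decoder_attention", "encoder_self_attention",
               "decoder_self_attention")
# Hypothetical constant: number of dimensions "value" must have per type.
VALUE_NDIM = {"simple": 2, "multihead": 3}


def write_attention_file(path, samples):
    """Write `samples` to `path` in the .attention JSON layout.

    Each sample is a dict with "source", "translation" and "attentions";
    each attention is a dict with "name", "type" and "value".
    """
    data = {}
    for i, sample in enumerate(samples):
        for att in sample["attentions"]:
            name, kind = att["name"], att["type"]
            if not any(k in name for k in KNOWN_KINDS):
                raise ValueError("unrecognized attention name: %r" % name)
            if kind not in VALUE_NDIM:
                raise ValueError("type must be 'simple' or 'multihead', got %r" % kind)
            # A ragged (non-rectangular) value also raises ValueError here.
            ndim = np.asarray(att["value"], dtype=float).ndim
            if ndim != VALUE_NDIM[kind]:
                raise ValueError("%s attention %r needs %d dimensions, got %d"
                                 % (kind, name, VALUE_NDIM[kind], ndim))
        data[str(i)] = sample  # samples are keyed "0", "1", ...
    with open(path, "w", encoding="utf-8") as f:
        json.dump(data, f, ensure_ascii=False)
```

For example, after `write_attention_file("my.attention", samples)`, running `python exec/plot_heatmap.py --input my.attention` should plot each sample's first attention, provided it is a `simple` attention whose shape matches the token counts.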
22 | 23 | 24 | ## Data Format 25 | 26 | The attention file must be in JSON format, and its name should end with the ".attention" extension (especially when using exec/plot_heatmap.jar): 27 | 28 | ``` python 29 | { "0": { 30 | "source": " ", # the source sentence (without and symbols) 31 | "translation": " ", # the target sentence (without and symbols) 32 | "attentions": [ # various attention results 33 | { 34 | "name": " ", # a unique name for this attention 35 | "type": " ", # the type of this attention (simple or multihead) 36 | "value": [...] # the attention weights, a json array 37 | }, # end of one attention result 38 | {...}, ...] # end of various attention results 39 | }, # end of the first sample 40 | "1":{ 41 | ... 42 | }, # end of the second sample 43 | ... 44 | } # end of file 45 | ``` 46 | 47 | Note that the Java tool decides which sentences to display by matching hard-coded substrings, so the `name` of each attention must contain "encoder_decoder_attention", "encoder_self_attention" or "decoder_self_attention", according to its actual role. 48 | 49 | The `value` has shape [length_queries, length_keys] when `type` is `simple`, and [num_heads, length_queries, length_keys] when `type` is `multihead`. Samples are keyed by the strings "0", "1", ..., and the Java tool reads the list of attention names from sample "0". 50 | 51 | For more details, see [attention.py](https://github.com/zhaocq-nlp/NJUNMT-tf/blob/master/njunmt/inference/attention.py). 52 | 53 | ## Demo 54 | 55 | The `toydata/toy.attention` file was generated with [NJUNMT-tf](https://github.com/zhaocq-nlp/NJUNMT-tf) by an NMT model with a self-attention encoder, Bahdanau's attention and an RNN decoder. 56 | 57 | 1\. Run the Python version (for simple attention only): 58 | 59 | ``` bash 60 | python exec/plot_heatmap.py --input toydata/toy.attention 61 | ``` 62 | It plots the traditional attention heatmap: 63 | 64 |
65 |

66 |
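For simple attention, `plot_heatmap.py` labels the columns with the source tokens followed by `SEQUENCE_END`, and adds `SEQUENCE_END` to the row (translation) labels only when `value` has one more row than the translation has tokens. A minimal Python 3 sketch of that label alignment, with an extra shape check that the script itself does not perform (`heatmap_labels` is a hypothetical helper):

```python
import numpy as np

END = "SEQUENCE_END"  # end-of-sequence label used by plot_heatmap.py


def heatmap_labels(sample):
    """Return (matrix, row_labels, column_labels) for a sample's first attention.

    Rows are decoding steps (target tokens); columns are source tokens.
    """
    attention = sample["attentions"][0]
    if attention["type"] != "simple":
        raise ValueError("only 'simple' attention can be drawn as a 2-D heatmap")
    matrix = np.asarray(attention["value"], dtype=float)
    if matrix.ndim != 2:
        raise ValueError("simple attention must have shape [length_queries, length_keys]")

    columns = sample["source"].split() + [END]
    rows = sample["translation"].split()
    if matrix.shape[0] == len(rows) + 1:
        rows.append(END)  # extra row: the decoding step that produced the end symbol

    # Hypothetical extra check (not in plot_heatmap.py): labels must fit the matrix.
    if matrix.shape != (len(rows), len(columns)):
        raise ValueError("value shape %s does not match %d x %d labels"
                         % (matrix.shape, len(rows), len(columns)))
    return matrix, rows, columns


sample = {
    "source": "a b",
    "translation": "x",
    "attentions": [{"name": "encoder_decoder_attention", "type": "simple",
                    "value": [[0.7, 0.2, 0.1], [0.1, 0.1, 0.8]]}],
}
m, rows, cols = heatmap_labels(sample)
print(rows, cols)  # ['x', 'SEQUENCE_END'] ['a', 'b', 'SEQUENCE_END']
```

The returned matrix and labels correspond to the `mma`, `target_labels` and `source_labels` arguments of `plot_head_map`.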
67 | 68 | 69 | 2\. To use the Java version (for both simple attention and Google's multi-head attention), run 70 | ``` bash 71 | java -jar exec/plot_heatmap.jar 72 | ``` 73 | then select `toydata/toy.attention` in the GUI. 74 | 75 |
76 |

77 |
78 |
79 |

80 |
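The Java viewer chooses what to put on each side purely from the hard-coded substring in the attention `name` (see `HeatmapPanel.display`). The Python sketch below mirrors that dispatch in the same order; `SIDE_RULES` and `gui_sides` are hypothetical names, and the sketch omits the extra boundary token that the Java code appears to add to each side.

```python
# Hypothetical table: (substring in the attention name, left sentence, right sentence),
# checked in the same order as HeatmapPanel.display.
SIDE_RULES = (
    ("encoder_decoder_attention", "translation", "source"),
    ("encoder_self_attention", "source", "source"),
    ("decoder_self_attention", "translation", "translation"),
)


def gui_sides(sample, attention_name):
    """Return (left_tokens, right_tokens), i.e. (queries, keys), for one sample."""
    for substring, left, right in SIDE_RULES:
        if substring in attention_name:
            return sample[left].split(), sample[right].split()
    # HeatmapPanel.display prints an error and exits in this case.
    raise ValueError("attention name %r contains none of: %s"
                     % (attention_name, ", ".join(rule[0] for rule in SIDE_RULES)))


sample = {"source": "s1 s2 s3", "translation": "t1 t2"}
print(gui_sides(sample, "encoder_decoder_attention"))
# (['t1', 't2'], ['s1', 's2', 's3'])
```

Because matching is by substring, names such as `encoder_self_attention0` and `encoder_self_attention1` share the same layout and differ only in their weights.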
81 | 82 | The words on the left are the attention queries, and the words on the right are the attention keys. Click a word on the left to see its heatmap: 83 | 84 |
85 |

86 |
87 | 88 | This shows the traditional `encoder_decoder_attention` for the word "obtained": the darker a line or square, the higher the attention weight. 89 | 90 | Next, select `encoder_self_attention0` under the menu bar, then click "获得" ("obtain") on the left. 91 | 92 |
93 |

94 |
95 | 96 | This shows the multi-head attention for the word "获得" ("obtain"): the weights of heads 0 to 7 are drawn on the right as one column of squares per head, and the lines point to the top five keys by weight summed over all heads. -------------------------------------------------------------------------------- /exec/plot_heatmap.jar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhaocq-nlp/Attention-Visualization/66ca6ba3aa88a0450f3c5d99d1afc84481cc4d88/exec/plot_heatmap.jar -------------------------------------------------------------------------------- /exec/plot_heatmap.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from __future__ import absolute_import 3 | 4 | import numpy 5 | import matplotlib.pyplot as plt 6 | import json 7 | import argparse 8 | 9 | 10 | # input: 11 | # alignment matrix - numpy array 12 | # shape (target tokens + eos, number of hidden source states = source tokens + eos) 13 | # one line corresponds to one decoding step producing one target token 14 | # each line has the attention model weights corresponding to that decoding step 15 | # each float on a line is the attention model weight for a corresponding source state.
16 | # plot: a heat map of the alignment matrix 17 | # x axis: the source tokens (alignment is to source hidden state that roughly corresponds to a source token) 18 | # y axis: the target tokens 19 | 20 | # http://stackoverflow.com/questions/14391959/heatmap-in-matplotlib-with-pcolor 21 | def plot_head_map(mma, target_labels, source_labels): 22 | fig, ax = plt.subplots() 23 | heatmap = ax.pcolor(mma, cmap=plt.cm.Blues) 24 | 25 | # put the major ticks at the middle of each cell 26 | ax.set_xticks(numpy.arange(mma.shape[1]) + 0.5, minor=False) 27 | ax.set_yticks(numpy.arange(mma.shape[0]) + 0.5, minor=False) 28 | 29 | # without this, matplotlib adds extra blank columns/rows 30 | # http://stackoverflow.com/questions/31601351/why-does-this-matplotlib-heatmap-have-an-extra-blank-column 31 | ax.set_xlim(0, int(mma.shape[1])) 32 | ax.set_ylim(0, int(mma.shape[0])) 33 | 34 | # want a more natural, table-like display 35 | ax.invert_yaxis() 36 | ax.xaxis.tick_top() 37 | 38 | # source words -> column labels 39 | ax.set_xticklabels(source_labels, minor=False) 40 | # target words -> row labels 41 | ax.set_yticklabels(target_labels, minor=False) 42 | 43 | plt.xticks(rotation=45) 44 | 45 | # plt.tight_layout() 46 | plt.show() 47 | 48 | 49 | # column labels -> source words 50 | # row labels -> target words 51 | 52 | def read_alignment_matrix(f): 53 | header = f.readline().decode('utf-8').strip().split('|||') 54 | if header[0] == '': 55 | return None, None, None, None 56 | sid = int(header[0].strip()) 57 | # number of tokens in source and translation, +1 for eos 58 | src_count, trg_count = map(int, header[-1].split()) 59 | # source words (the header is already decoded above) 60 | source_labels = header[3].split() 61 | # source_labels.append('') 62 | # target words 63 | target_labels = header[1].split() 64 | target_labels.append('') 65 | 66 | mm = [] 67 | for r in range(trg_count): 68 | alignment = list(map(float, f.readline().decode('utf-8').split())) 69 | mm.append(alignment) 70 | mma = numpy.array(mm) 71 | return sid,
mma, target_labels, source_labels 72 | 73 | 74 | def read_plot_alignment_matrices(f, start=0): 75 | attentions = json.load(f)  # the encoding argument was removed from json.load in Python 3.9 76 | 77 | for idx, att in sorted(attentions.items(), key=lambda kv: int(kv[0])):  # keys are "0", "1", ...: plot in numeric order 78 | if int(idx) < start: continue 79 | source_labels = att["source"].split() + ["SEQUENCE_END"] 80 | target_labels = att["translation"].split() 81 | att_list = att["attentions"] 82 | assert att_list[0]["type"] == "simple", "Do not use this tool for multihead attention." 83 | mma = numpy.array(att_list[0]["value"]) 84 | if mma.shape[0] == len(target_labels) + 1: 85 | target_labels += ["SEQUENCE_END"] 86 | 87 | plot_head_map(mma, target_labels, source_labels) 88 | 89 | 90 | if __name__ == "__main__": 91 | parser = argparse.ArgumentParser() 92 | parser.add_argument('--input', '-i', type=argparse.FileType("rb"), 93 | default="trans.att", 94 | metavar='PATH', 95 | help="Input .attention file (default: trans.att)") 96 | parser.add_argument('--start', type=int, default=0) 97 | 98 | args = parser.parse_args() 99 | 100 | read_plot_alignment_matrices(args.input, args.start) 101 | -------------------------------------------------------------------------------- /multihead-att-java/ActionLabel.java: -------------------------------------------------------------------------------- 1 | import javax.swing.*; 2 | import javax.swing.border.BevelBorder; 3 | import javax.swing.event.MouseInputListener; 4 | import java.awt.*; 5 | import java.awt.event.MouseEvent; 6 | import java.awt.event.MouseListener; 7 | import java.awt.event.MouseMotionListener; 8 | 9 | 10 | public class ActionLabel extends JLabel { 11 | 12 | private boolean isActive = false; 13 | private HeatmapPanel parent = null; 14 | private String currentText = ""; 15 | private int id = -1; 16 | 17 | // private HintPanel hintPanel = null; 18 | 19 | public ActionLabel(HeatmapPanel parent, Integer id, String text, int horizontalAlignment) { 20 | super(text, horizontalAlignment); 21 | this.parent = parent; 22 | this.currentText = text; 23 | this.id = id;
24 | LabelMouseListener listener = new LabelMouseListener(); 25 | this.addMouseListener(listener); 26 | 27 | } 28 | 29 | @Override 30 | public void setText(String text) { 31 | super.setText(text); 32 | this.currentText = text; 33 | if (text != null && text.length() > 0) { 34 | isActive = true; 35 | } else { 36 | isActive = false; 37 | } 38 | } 39 | 40 | @Override 41 | public void setBounds(int x, int y, int width, int height) { 42 | super.setBounds(x, y, width, height); 43 | // if(hintPanel == null){ 44 | // hintPanel = new HintPanel(x - 50, (int) y - 25, 200, 30); 45 | // } 46 | } 47 | 48 | private class HintPanel extends JPanel{ 49 | 50 | JLabel label = new JLabel("", JLabel.CENTER); 51 | public HintPanel(int x, int y, int width, int height){ 52 | super(); 53 | super.setBounds(x, y, width, height); 54 | super.add(label); 55 | label.setBounds(0, 0, width, height); 56 | label.setFont(new Font("TimesRoman", Font.PLAIN, 20)); 57 | } 58 | 59 | public void setText(String text){ 60 | label.setText(text); 61 | } 62 | 63 | @Override 64 | protected void paintComponent(Graphics g) { 65 | super.paintComponent(g); 66 | Graphics2D g2d = (Graphics2D) g.create(); // draw on a copy so the caller's Graphics is left untouched 67 | g2d.setComposite(AlphaComposite.getInstance(AlphaComposite.SRC_OVER, 0.7f)); 68 | g2d.setColor(Color.YELLOW); 69 | g2d.fill(new Rectangle(0, 0, getWidth(), getHeight())); // local coordinates, not getBounds() 70 | g2d.dispose(); // dispose only the copy created above 71 | } 72 | } 73 | 74 | private class LabelMouseListener implements MouseListener { 75 | 76 | @Override 77 | public void mouseClicked(MouseEvent e) { 78 | if (isActive) { 79 | parent.setWordIndex(id); 80 | } 81 | } 82 | 83 | @Override 84 | public void mouseEntered(MouseEvent e) { 85 | if (isActive) { 86 | // if (hintPanel != null) { 87 | // parent.add(hintPanel); 88 | // parent.flushPanel(); 89 | // } 90 | ((JLabel) e.getComponent()).setBorder(new BevelBorder(BevelBorder.RAISED, null, null, null, null)); 91 | ((JLabel) e.getComponent()).setCursor(new Cursor(Cursor.HAND_CURSOR)); 92 | } 93 | } 94 | 95 | @Override 96 | public void mouseExited(MouseEvent e) { 97 | if (isActive) { 98 | //
parent.remove(hintPanel); 99 | // parent.flushPanel(); 100 | ((JLabel) e.getComponent()).setBorder(null); 101 | } 102 | } 103 | 104 | @Override 105 | public void mousePressed(MouseEvent e) { 106 | if (isActive) { 107 | ((JLabel) e.getComponent()).setBorder(new BevelBorder(BevelBorder.LOWERED, null, null, null, null)); 108 | } 109 | } 110 | 111 | @Override 112 | public void mouseReleased(MouseEvent e) { 113 | if (isActive) { 114 | ((JLabel) e.getComponent()).setBorder(new BevelBorder(BevelBorder.RAISED, null, null, null, null)); 115 | } 116 | } 117 | } 118 | 119 | } 120 | -------------------------------------------------------------------------------- /multihead-att-java/DataObject.java: -------------------------------------------------------------------------------- 1 | import org.json.JSONArray; 2 | import org.json.JSONObject; 3 | 4 | import java.io.BufferedReader; 5 | import java.io.FileInputStream; 6 | import java.io.InputStreamReader; 7 | import java.util.ArrayList; 8 | import java.util.Collections; 9 | import java.util.List; 10 | 11 | public class DataObject { 12 | 13 | public int numSamples = 0; 14 | public List<String> attentionFieldList; 15 | 16 | 17 | private JSONObject dataObject = null; 18 | 19 | public DataObject(String filename){ 20 | this.reload(filename); 21 | } 22 | 23 | public JSONObject get(int index){ 24 | try { 25 | JSONObject obj = this.dataObject.getJSONObject(String.format("%d", index)); 26 | return obj; 27 | } catch(Exception e){ 28 | e.printStackTrace(); 29 | System.exit(1); 30 | } 31 | return null; 32 | } 33 | 34 | public String getAttentionType(int index, String attentionField){ 35 | JSONObject instanceObj = this.get(index); 36 | try{ 37 | JSONArray attLists = instanceObj.getJSONArray("attentions"); 38 | for (int i = 0; i < attLists.length(); ++i) { 39 | JSONObject obj = attLists.getJSONObject(i); 40 | if(obj.getString("name").equals(attentionField)){ 41 | return obj.getString("type"); 42 | } 43 | } 44 | } catch(Exception e){ 45 |
e.printStackTrace(); 46 | System.exit(1); 47 | } 48 | System.err.println(String.format("No attention named %s", attentionField)); System.exit(1); 49 | return null; 50 | } 51 | 52 | 53 | public JSONArray getAttentionWeight(int index, String attentionField){ 54 | JSONObject instanceObj = this.get(index); 55 | try{ 56 | JSONArray attLists = instanceObj.getJSONArray("attentions"); 57 | for (int i = 0; i < attLists.length(); ++i) { 58 | JSONObject obj = attLists.getJSONObject(i); 59 | if(obj.getString("name").equals(attentionField)){ 60 | return obj.getJSONArray("value"); 61 | } 62 | } 63 | } catch(Exception e){ 64 | e.printStackTrace(); 65 | System.exit(1); 66 | } 67 | System.err.println(String.format("No attention named %s", attentionField)); System.exit(1); 68 | return null; 69 | } 70 | 71 | public void reload(String filename){ 72 | try { 73 | BufferedReader br = new BufferedReader(new InputStreamReader(new FileInputStream(filename), "utf-8")); 74 | String str = br.lines().collect(java.util.stream.Collectors.joining("\n")); // read the whole file, not only its first line 75 | br.close(); 76 | this.dataObject = new JSONObject(str); 77 | 78 | this.numSamples = this.dataObject.length(); 79 | this.attentionFieldList = new ArrayList<String>(); 80 | JSONArray attArray = this.dataObject.getJSONObject("0").getJSONArray("attentions"); 81 | for (int i = 0; i < attArray.length(); ++i) { 82 | JSONObject obj = attArray.getJSONObject(i); 83 | String type = obj.getString("type"); 84 | if (type.equals("multihead") || type.equals("simple")) { 85 | this.attentionFieldList.add(obj.getString("name")); 86 | } else{ 87 | System.err.println(String.format("Unsupported attention type: %s (expected simple or multihead)", type)); 88 | System.exit(1); 89 | } 90 | } 91 | Collections.sort(this.attentionFieldList); 92 | } catch (Exception e){ 93 | e.printStackTrace(); 94 | } 95 | } 96 | } 97 | -------------------------------------------------------------------------------- /multihead-att-java/HeatmapPanel.java: -------------------------------------------------------------------------------- 1 | 2 | import org.json.JSONArray; 3 | import org.json.JSONObject; 4 | 5 | import javax.swing.*; 6 | import java.awt.*; 7 | import java.util.ArrayList; 8 | import java.util.Arrays; 9 | 10 | 11 | public
class HeatmapPanel extends JPanel { 12 | 13 | private MainFrame parent = null; 14 | private DataObject dataObject = null; 15 | private int preDefineMaxLength = 50; 16 | 17 | public int currentSampleId = 0; 18 | public String currentAttentionName = ""; 19 | public int wordIndex = -1; 20 | 21 | 22 | // word labels (left: queries, right: keys) and their index labels 23 | private ArrayList<JLabel> leftLabelList = new ArrayList<JLabel>(); 24 | private ArrayList<JLabel> rightLabelList = new ArrayList<JLabel>(); 25 | private ArrayList<JLabel> leftNumLabelList = new ArrayList<JLabel>(); 26 | private ArrayList<JLabel> rightNumLabelList = new ArrayList<JLabel>(); 27 | 28 | public HeatmapPanel(MainFrame parent, DataObject dataObject, int preDefineMaxLength) { 29 | super(); 30 | this.parent = parent; 31 | this.dataObject = dataObject; 32 | this.preDefineMaxLength = preDefineMaxLength; 33 | this.setBounds(0, 0, 768, 3400); 34 | this.setLayout(null); 35 | this.addEmptyLabels(); 36 | } 37 | 38 | public void flushPanel() { 39 | this.parent.flushFrame(); 40 | } 41 | 42 | private void addEmptyLabels() { 43 | for (int i = 0; i < this.preDefineMaxLength; ++i) { 44 | JLabel label = new ActionLabel(this, i, "", JLabel.RIGHT); 45 | label.setFont(new Font("TimesRoman", Font.PLAIN, 18)); 46 | label.setBounds(100, 50 + i * 20 + 7 * i, 140, 20); 47 | this.add(label); 48 | this.leftLabelList.add(label); 49 | } 50 | for (int i = 0; i < this.preDefineMaxLength; ++i) { 51 | JLabel label = new JLabel("", JLabel.LEFT); 52 | label.setFont(new Font("TimesRoman", Font.PLAIN, 18)); 53 | label.setBounds(500, 50 + i * 20 + 7 * i, 140, 20); 54 | this.add(label); 55 | this.rightLabelList.add(label); 56 | } 57 | for (int i = 0; i < this.preDefineMaxLength; ++i) { 58 | JLabel label = new JLabel("", JLabel.RIGHT); 59 | label.setFont(new Font("TimesRoman", Font.PLAIN, 18)); 60 | label.setBounds(60, 50 + i * 20 + 7 * i, 30, 20); 61 | this.add(label); 62 | this.leftNumLabelList.add(label); 63 | } 64 | for (int i = 0; i < this.preDefineMaxLength; ++i) { 65 | JLabel label = new JLabel("", JLabel.LEFT); 66 | label.setFont(new
Font("TimesRoman", Font.PLAIN, 18)); 67 | label.setBounds(640, 50 + i * 20 + 7 * i, 30, 20); 68 | this.add(label); 69 | this.rightNumLabelList.add(label); 70 | } 71 | } 72 | 73 | public void setWordIndex(int wordIndex) { 74 | this.wordIndex = wordIndex; 75 | this.parent.flushFrame(); 76 | } 77 | 78 | public void display(String currentAttentionName) { 79 | int curPrefixIndex = this.currentAttentionName.indexOf("_attention"); 80 | int nextPrefixIndex = currentAttentionName.indexOf("_attention"); // -1 if absent, e.g. for the initial empty name 81 | if (curPrefixIndex >= 0 && nextPrefixIndex >= 0 && this.currentAttentionName.substring(0, curPrefixIndex).equals( 82 | currentAttentionName.substring(0, nextPrefixIndex))) { 83 | this.currentAttentionName = currentAttentionName; 84 | } else { 85 | this.currentAttentionName = currentAttentionName; 86 | this.display(this.currentSampleId); 87 | } 88 | } 89 | 90 | public void display(int sampleId) { 91 | this.currentSampleId = sampleId; 92 | this.wordIndex = -1; 93 | String left = ""; 94 | String right = ""; 95 | try { 96 | JSONObject currentObj = this.dataObject.get(this.currentSampleId); 97 | String source = currentObj.getString("source"); 98 | String target = currentObj.getString("translation"); 99 | if (this.currentAttentionName.contains("encoder_decoder_attention")) { 100 | left = target; 101 | right = source; 102 | } else if (this.currentAttentionName.contains("encoder_self_attention")) { 103 | left = source; 104 | right = source; 105 | } else if (this.currentAttentionName.contains("decoder_self_attention")) { 106 | left = target; 107 | right = target; 108 | } else { 109 | System.err.println("Unknown attention name: " + this.currentAttentionName); 110 | System.exit(1); 111 | } 112 | if (this.currentAttentionName.contains("decoder_self_attention")) { 113 | left = " " + left; 114 | right = " " + right; 115 | } else { 116 | left += " "; 117 | right += " "; 118 | } 119 | 120 | } catch (Exception e) { 121 | e.printStackTrace(); 122 | } 123 | String[] leftTokens = left.trim().split(" "); 124 | String[] rightTokens = right.trim().split(" "); 125 |
        int auxIndex = this.leftLabelList.size();
        // Create extra label rows on demand; each row is 27 px tall (20 px label + 7 px gap).
        while (auxIndex < leftTokens.length) {
            JLabel label = new ActionLabel(this, auxIndex, "", JLabel.RIGHT);
            label.setFont(new Font("TimesRoman", Font.PLAIN, 18));
            label.setBounds(100, 50 + auxIndex * 20 + 7 * auxIndex, 140, 20);
            this.add(label);
            this.leftLabelList.add(label);

            JLabel numLabel = new JLabel("", JLabel.RIGHT);
            numLabel.setFont(new Font("TimesRoman", Font.PLAIN, 18));
            numLabel.setBounds(60, 50 + auxIndex * 20 + 7 * auxIndex, 30, 20);
            this.add(numLabel);
            this.leftNumLabelList.add(numLabel);
            ++auxIndex;
        }
        auxIndex = this.rightLabelList.size();
        while (auxIndex < rightTokens.length) {
            JLabel label = new JLabel("", JLabel.LEFT);
            label.setFont(new Font("TimesRoman", Font.PLAIN, 18));
            label.setBounds(500, 50 + auxIndex * 20 + 7 * auxIndex, 130, 20);
            this.add(label);
            this.rightLabelList.add(label);

            JLabel numlabel = new JLabel("", JLabel.LEFT);
            numlabel.setFont(new Font("TimesRoman", Font.PLAIN, 18));
            numlabel.setBounds(640, 50 + auxIndex * 20 + 7 * auxIndex, 30, 20);
            this.add(numlabel);
            this.rightNumLabelList.add(numlabel);
            ++auxIndex;
        }
        // Show the tokens (and their positions) of the current sample; blank out unused rows.
        for (int i = 0; i < this.leftLabelList.size(); ++i) {
            if (i < leftTokens.length) {
                this.leftLabelList.get(i).setText(leftTokens[i]);
                this.leftNumLabelList.get(i).setText(String.format("%d", i));
            } else {
                this.leftLabelList.get(i).setText("");
                this.leftNumLabelList.get(i).setText("");
            }
        }
        for (int i = 0; i < this.rightLabelList.size(); ++i) {
            if (i < rightTokens.length) {
                this.rightLabelList.get(i).setText(rightTokens[i]);
                this.rightNumLabelList.get(i).setText(String.format("%d", i));
            } else {
                this.rightLabelList.get(i).setText("");
                this.rightNumLabelList.get(i).setText("");
            }
        }
        // this.validate();
        // this.invalidate();
        // this.repaint();
    }

    @Override
    public void paint(Graphics g) {
        super.paint(g);
        if (this.wordIndex < 0) {
            return;
        }
        // Draw on a copy of the Graphics context: the stroke and composite set below do
        // not leak, and only the copy is disposed (Swing owns g and disposes it itself).
        Graphics2D g2 = (Graphics2D) g.create();
        try {
            g2.setColor(Color.BLUE);
            // g2.setColor(Color.CYAN);
            g2.setStroke(new BasicStroke(3.0f));
            JSONArray attArray = this.dataObject.getAttentionWeight(this.currentSampleId, this.currentAttentionName);
            String attType = this.dataObject.getAttentionType(this.currentSampleId, this.currentAttentionName);
            // Per-target-token scores for the selected word (summed over heads).
            double[] accumulateScores = null;
            if (attType.equals("multihead")) {
                // One 20x20 cell per head and target token; opacity follows the weight.
                for (int head = 0; head < attArray.length(); ++head) {
                    try {
                        JSONArray att = (JSONArray) ((JSONArray) attArray.get(head)).get(this.wordIndex);
                        if (accumulateScores == null) {
                            accumulateScores = new double[att.length()]; // zero-initialized
                        }
                        for (int idx = 0; idx < att.length(); ++idx) {
                            accumulateScores[idx] += att.getDouble(idx);
                            g2.setComposite(AlphaComposite.getInstance(
                                    AlphaComposite.SRC_OVER, (float) (att.getDouble(idx) * 0.7)));
                            g2.fillRect(500 + head * 20, 50 + 27 * idx, 20, 20);
                        }
                    } catch (Exception e) {
                        e.printStackTrace();
                    }
                }
            } else if (attType.equals("simple")) {
                // A single attention distribution: one wide cell per target token.
                try {
                    JSONArray att = (JSONArray) attArray.get(this.wordIndex);
                    accumulateScores = new double[att.length()];
                    for (int idx = 0; idx < att.length(); ++idx) {
                        accumulateScores[idx] = att.getDouble(idx);
                        g2.setComposite(AlphaComposite.getInstance(
                                AlphaComposite.SRC_OVER, (float) (att.getDouble(idx) * 0.7)));
                        g2.fillRect(500 + 20, 50 + 27 * idx, 20 * 8, 20);
                    }
                } catch (Exception e) {
                    e.printStackTrace();
                }
            }
            if (accumulateScores == null) {
                return; // unknown attention type, or the weights could not be read
            }
            // Connect the selected word to its (up to) 5 best-scoring target tokens.
            int[] topIndexes = Utils.topIndexes(accumulateScores, 5);
            if (topIndexes.length == 0 || accumulateScores[topIndexes[0]] <= 0.0) {
                return; // nothing to draw; also avoids a division by zero below
            }
            // The strongest line gets alpha 0.6, the others proportionally less.
            double multiplier = 0.6 / accumulateScores[topIndexes[0]];
            for (int top : topIndexes) {
                double prob = accumulateScores[top] * multiplier;
                g2.setComposite(AlphaComposite.getInstance(
                        AlphaComposite.SRC_OVER, (float) prob));
                g2.drawLine(240, 60 + this.wordIndex * 27, 500, 60 + top * 27);
            }
        } finally {
            g2.dispose();
        }
    }
}

--------------------------------------------------------------------------------
/multihead-att-java/MainFrame.java:
--------------------------------------------------------------------------------
import java.awt.*;
import javax.swing.*;
import javax.swing.filechooser.FileFilter;

import java.awt.event.ActionEvent;
import java.awt.event.ActionListener;
import java.io.File;


public class MainFrame extends JDialog {

    // for main panel
    private JTabbedPane mainTabbedPane = new JTabbedPane(JTabbedPane.TOP);

    // for menu
    private JFileChooser fileChooser = new JFileChooser();

    private int panelCount = 0;

    public MainFrame(String name) throws Exception {
        this.setTitle(name);
        this.setDefaultCloseOperation(JFrame.DISPOSE_ON_CLOSE);
        this.setMinimumSize(new Dimension(800, 600));
        this.setVisible(true);
        this.setLocation((Toolkit.getDefaultToolkit().getScreenSize().width - this.getWidth()) / 2,
                (Toolkit.getDefaultToolkit().getScreenSize().height - this.getHeight()) / 2);
        this.setResizable(true);

        // Only offer directories and *.attention files in the "Open..." dialog.
        FileFilter attFileFilter = new FileFilter() {
            @Override
            public boolean accept(File f) {
                String name = f.getName();
                return f.isDirectory() || name.endsWith(".attention");
            }

            @Override
            public String getDescription() {
                return "*.attention";
            }
        };
        fileChooser.setFileFilter(attFileFilter);
        fileChooser.addChoosableFileFilter(attFileFilter);

        this.createMenuBar();


        Container contentPane = this.getContentPane();
        contentPane.add(this.mainTabbedPane);
        this.flushFrame();
        System.out.println("Finish");
    }

    public void openFile(String filename) {
        try {
            DataObject dataObject = new DataObject(filename);
            HeatmapPanel heatmapPanel = new HeatmapPanel(this, dataObject, 30);
            MainPanel panel = new MainPanel(new JPanel(), dataObject);
            panel.addHeatmapPanel(heatmapPanel);
            this.mainTabbedPane.add(panel, Utils.extractFilePrefix(filename));
            this.mainTabbedPane.setSelectedComponent(panel);
            this.flushFrame();
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    public void flushFrame() {
        this.validate();
        this.invalidate();
        this.repaint();
    }


    private void createMenuBar() {
        JMenu menu = new JMenu("File");
        JMenuItem openItem = new JMenuItem("Open...");
        menu.add(openItem);
        JMenuBar br = new JMenuBar();
        br.add(menu);

        JMenuItem closeItem = new JMenuItem("Close Tab");
        menu.add(closeItem);

        JMenuItem quitItem = new JMenuItem("Quit");
        menu.add(quitItem);

        closeItem.setEnabled(false);

        openItem.addActionListener(new ActionListener() {
            @Override
            public void actionPerformed(ActionEvent e) {
                int state = fileChooser.showOpenDialog(null);
                // Bail out on CANCEL_OPTION and on ERROR_OPTION alike.
                if (state != JFileChooser.APPROVE_OPTION) {
                    return;
                }
                String filename = fileChooser.getSelectedFile().getAbsolutePath();
                openFile(filename);
                // openFile() swallows its errors, so count the tabs that actually exist.
                panelCount = mainTabbedPane.getTabCount();
                closeItem.setEnabled(panelCount > 0);
            }
        });

        closeItem.addActionListener(new ActionListener() {
            @Override
            public void actionPerformed(ActionEvent e) {
                if (closeItem.isEnabled()) {
                    mainTabbedPane.remove(mainTabbedPane.getSelectedComponent());
                    panelCount = mainTabbedPane.getTabCount();
                    if (panelCount == 0) {
                        closeItem.setEnabled(false);
                    }
                }
            }
        });

        quitItem.addActionListener(new ActionListener() {
            @Override
            public void actionPerformed(ActionEvent e) {
                dispose();
            }
        });

        this.setJMenuBar(br);
    }


    public static void main(String[] args) {
        // Swing components should be created on the Event Dispatch Thread.
        SwingUtilities.invokeLater(new Runnable() {
            @Override
            public void run() {
                try {
                    new MainFrame("Heatmap");
                } catch (Exception e) {
                    e.printStackTrace();
                }
            }
        });
    }
}

--------------------------------------------------------------------------------
/multihead-att-java/MainPanel.java:
--------------------------------------------------------------------------------
import javax.swing.*;
import java.awt.*;
import java.awt.event.ActionEvent;
import java.awt.event.ActionListener;

public class MainPanel extends JScrollPane {


    private JPanel mainPanel = null;
    private HeatmapPanel heatmapPanel = null;
    private JComboBox sampleComboBox = null;
    private JComboBox attentionComboBox = null;
    private DataObject dataObject = null;

    private JButton prevButton = null;
    private JButton nextButton = null;

    public MainPanel(JPanel panel, DataObject dataObj) {
        super(panel);
        this.mainPanel = panel;
        this.mainPanel.setPreferredSize(new Dimension(768, 3500));
        this.mainPanel.setLayout(null);
        this.setBounds(0, 0, 200, 200);
        this.setBackground(Color.WHITE);
        this.setOpaque(true);

        this.dataObject = dataObj;
        this.createPopupMenu();
    }

    public void addHeatmapPanel(HeatmapPanel panel) {
        this.mainPanel.add(panel);
        this.heatmapPanel = panel;
        heatmapPanel.currentAttentionName = (String) this.attentionComboBox.getSelectedItem();
        heatmapPanel.display(Integer.parseInt((String) this.sampleComboBox.getSelectedItem()));
    }

    public void flushPanel() {
        this.mainPanel.validate();
        this.mainPanel.invalidate();
        this.mainPanel.repaint();
    }

    // Builds the control row at the top: sample selector, prev/next buttons, attention selector.
    public void createPopupMenu() {
        JLabel sampleLabel = new JLabel("Sample: ", JLabel.RIGHT);
        this.mainPanel.add(sampleLabel);
        sampleLabel.setFont(new Font("TimesRoman", Font.PLAIN, 16));
        sampleLabel.setBounds(50, 5, 60, 30);

        this.prevButton = new JButton("prev");
        this.mainPanel.add(this.prevButton);
        this.prevButton.setFont(new Font("TimesRoman", Font.PLAIN, 16));
        this.prevButton.setBounds(185, 5, 60, 30);

        this.nextButton = new JButton("next");
        this.mainPanel.add(this.nextButton);
        this.nextButton.setFont(new Font("TimesRoman", Font.PLAIN, 16));
        this.nextButton.setBounds(250, 5, 60, 30);

        JLabel attLabel = new JLabel("Displaying: ", JLabel.RIGHT);
        this.mainPanel.add(attLabel);
        attLabel.setFont(new Font("TimesRoman", Font.PLAIN, 16));
        attLabel.setBounds(290, 5, 120, 30);

        this.sampleComboBox = new JComboBox();
        for (int i = 0; i < this.dataObject.numSamples; ++i) {
            this.sampleComboBox.addItem(String.format("%d", i));
        }
        this.mainPanel.add(sampleComboBox);
        this.sampleComboBox.setBounds(110, 5, 70, 30);
        this.sampleComboBox.setSelectedIndex(0);

        this.attentionComboBox = new JComboBox();
        this.mainPanel.add(attentionComboBox);
        for (int i = 0; i < this.dataObject.attentionFieldList.size(); ++i) {
            this.attentionComboBox.addItem(this.dataObject.attentionFieldList.get(i));
        }
        this.attentionComboBox.setBounds(410, 5, 250, 30);
        this.attentionComboBox.setSelectedIndex(0);

        this.sampleComboBox.addActionListener(
                new ActionListener() {
                    @Override
                    public void actionPerformed(ActionEvent e) {
                        heatmapPanel.display(Integer.parseInt((String) sampleComboBox.getSelectedItem()));
                        flushPanel();
                    }
                });
        this.attentionComboBox.addActionListener(
                new ActionListener() {
                    @Override
                    public void actionPerformed(ActionEvent e) {
                        heatmapPanel.display((String) attentionComboBox.getSelectedItem());
                        flushPanel();
                    }
                });

        // setSelectedIndex() fires the sample combo box's ActionListener above, which
        // redraws the heatmap, so the buttons only need to move the selection.
        this.prevButton.addActionListener(new ActionListener() {
            @Override
            public void actionPerformed(ActionEvent e) {
                if (heatmapPanel.currentSampleId > 0) {
                    sampleComboBox.setSelectedIndex(heatmapPanel.currentSampleId - 1);
                }
            }
        });
        this.nextButton.addActionListener(new ActionListener() {
            @Override
            public void actionPerformed(ActionEvent e) {
                if (heatmapPanel.currentSampleId < dataObject.numSamples - 1) {
                    sampleComboBox.setSelectedIndex(heatmapPanel.currentSampleId + 1);
                }
            }
        });

    }


}

--------------------------------------------------------------------------------
/multihead-att-java/Utils.java:
--------------------------------------------------------------------------------
import java.io.File;
import java.util.Arrays;

public final class Utils {
    /**
     * Returns the indexes of the (at most) topk largest values, largest first.
     * Equal values are all kept (lower index first), and the result is never
     * padded with zeros when values has fewer than topk elements.
     */
    public static int[] topIndexes(double[] values, int topk) {
        Integer[] order = new Integer[values.length];
        for (int i = 0; i < values.length; ++i) {
            order[i] = i;
        }
        // Arrays.sort on objects is stable: ties stay in ascending index order.
        Arrays.sort(order, (a, b) -> Double.compare(values[b], values[a]));
        int n = Math.min(Math.max(topk, 0), values.length);
        int[] indexes = new int[n];
        for (int i = 0; i < n; ++i) {
            indexes[i] = order[i];
        }
        return indexes;
    }

    public static void main(String[] args) {
        // Quick check: prints [2, 3, 4]
        System.out.println(Arrays.toString(topIndexes(new double[]{1., 2., 3., 2.5, 2.1}, 3)));
    }

    public static String extractFilePrefix(String path) {
        // File.getName() uses the platform separator, so Windows paths work too.
        String filename = new File(path.trim()).getName();
        int tmp = filename.indexOf(".attention");
        // Fall back to the whole name instead of throwing when the suffix is missing.
        return tmp >= 0 ? filename.substring(0, tmp) : filename;
    }
}

--------------------------------------------------------------------------------
/toydata/figures/java-heatmap1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhaocq-nlp/Attention-Visualization/66ca6ba3aa88a0450f3c5d99d1afc84481cc4d88/toydata/figures/java-heatmap1.png
--------------------------------------------------------------------------------
/toydata/figures/java-heatmap2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhaocq-nlp/Attention-Visualization/66ca6ba3aa88a0450f3c5d99d1afc84481cc4d88/toydata/figures/java-heatmap2.png
--------------------------------------------------------------------------------
/toydata/figures/java-heatmap3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhaocq-nlp/Attention-Visualization/66ca6ba3aa88a0450f3c5d99d1afc84481cc4d88/toydata/figures/java-heatmap3.png
--------------------------------------------------------------------------------
/toydata/figures/java-heatmap4.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhaocq-nlp/Attention-Visualization/66ca6ba3aa88a0450f3c5d99d1afc84481cc4d88/toydata/figures/java-heatmap4.png
--------------------------------------------------------------------------------
/toydata/figures/py-heatmap.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhaocq-nlp/Attention-Visualization/66ca6ba3aa88a0450f3c5d99d1afc84481cc4d88/toydata/figures/py-heatmap.png
--------------------------------------------------------------------------------
/toydata/figures/py-heatmap1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhaocq-nlp/Attention-Visualization/66ca6ba3aa88a0450f3c5d99d1afc84481cc4d88/toydata/figures/py-heatmap1.png
--------------------------------------------------------------------------------
/toydata/figures/py-heatmap2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhaocq-nlp/Attention-Visualization/66ca6ba3aa88a0450f3c5d99d1afc84481cc4d88/toydata/figures/py-heatmap2.png
--------------------------------------------------------------------------------
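The line-drawing step at the end of `HeatmapPanel.paint` sums the selected word's attention row over the heads, asks `Utils.topIndexes` for the five highest-scoring target tokens, and gives the strongest line an alpha of 0.6 and the others proportionally less. A minimal standalone sketch of that arithmetic, runnable without Swing or a `.attention` file, might look like the following. The class name `AttentionLineMath` and the plain `double[][]` input (one row per head, standing in for the `JSONArray` the panel reads) are hypothetical and not part of the repository.

```java
import java.util.Arrays;

/**
 * Hypothetical helper (not part of the repository) that mirrors the arithmetic
 * HeatmapPanel.paint uses to choose and shade the attention lines.
 */
public final class AttentionLineMath {

    /** Sums one attention row per head into one score per target token (assumes >= 1 head). */
    static double[] accumulate(double[][] headRows) {
        double[] scores = new double[headRows[0].length];
        for (double[] row : headRows) {
            for (int i = 0; i < scores.length; ++i) {
                scores[i] += row[i];
            }
        }
        return scores;
    }

    /** Indexes of the (at most) topk largest values, largest first; ties keep index order. */
    static int[] topIndexes(double[] values, int topk) {
        Integer[] order = new Integer[values.length];
        for (int i = 0; i < values.length; ++i) {
            order[i] = i;
        }
        // Arrays.sort on objects is stable, so equal values stay in ascending index order.
        Arrays.sort(order, (a, b) -> Double.compare(values[b], values[a]));
        int n = Math.min(Math.max(topk, 0), values.length);
        int[] result = new int[n];
        for (int i = 0; i < n; ++i) {
            result[i] = order[i];
        }
        return result;
    }

    /**
     * Alpha per selected line: the best line gets maxAlpha, the others scale linearly.
     * Dividing every score by their sum first would not change these ratios.
     * Assumes top is non-empty and scores[top[0]] > 0.
     */
    static float[] lineAlphas(double[] scores, int[] top, double maxAlpha) {
        double best = scores[top[0]];
        float[] alphas = new float[top.length];
        for (int i = 0; i < top.length; ++i) {
            alphas[i] = (float) (maxAlpha * scores[top[i]] / best);
        }
        return alphas;
    }

    public static void main(String[] args) {
        // Two made-up heads over four target tokens.
        double[][] heads = {
                {0.125, 0.5, 0.25, 0.125},
                {0.25, 0.25, 0.25, 0.25},
        };
        double[] scores = accumulate(heads);
        int[] top = topIndexes(scores, 5);
        float[] alphas = lineAlphas(scores, top, 0.6);
        System.out.println(Arrays.toString(top) + " -> " + Arrays.toString(alphas));
    }
}
```

With these inputs the tokens at positions 0 and 3 tie at 0.375; both are kept, in index order, and only four indexes come back because there are only four target tokens, so the alphas come out at roughly 0.6, 0.4, 0.3 and 0.3.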