├── LICENSE
├── README.md
├── exec
│   ├── plot_heatmap.jar
│   └── plot_heatmap.py
├── multihead-att-java
│   ├── ActionLabel.java
│   ├── DataObject.java
│   ├── HeatmapPanel.java
│   ├── MainFrame.java
│   ├── MainPanel.java
│   └── Utils.java
└── toydata
    ├── figures
    │   ├── java-heatmap1.png
    │   ├── java-heatmap2.png
    │   ├── java-heatmap3.png
    │   ├── java-heatmap4.png
    │   ├── py-heatmap.png
    │   ├── py-heatmap1.png
    │   └── py-heatmap2.png
    └── toy.attention
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Attention-Visualization
2 | Visualization for simple attention and Google's multi-head attention.
3 |
4 | ## Requirements
5 |
6 | - python
7 | - jdk1.8
8 |
9 | ## Usage
10 |
11 | 1\. Python version (for simple attention only):
12 |
13 | ``` bash
14 | python exec/plot_heatmap.py --input xxx.attention
15 | ```
16 |
17 | 2\. Java version (for both simple attention and Google's multi-head attention):
18 | ``` bash
19 | java -jar exec/plot_heatmap.jar
20 | ```
21 | then select the attention file in the GUI.
22 |
23 |
24 | ## Data Format
25 |
26 | The attention file name should end with the ".attention" extension (this matters in particular for exec/plot_heatmap.jar), and the file content should be in JSON format:
27 |
28 | ``` python
29 | { "0": {
30 | "source": " ", # the source sentence (without begin/end-of-sentence symbols)
31 | "translation": " ", # the target sentence (without begin/end-of-sentence symbols)
32 | "attentions": [ # various attention results
33 | {
34 | "name": " ", # a unique name for this attention
35 | "type": " ", # the type of this attention (simple or multihead)
36 | "value": [...] # the attention weights, a json array
37 | }, # end of one attention result
38 | {...}, ...] # end of various attention results
39 | }, # end of the first sample
40 | "1":{
41 | ...
42 | }, # end of the second sample
43 | ...
44 | } # end of file
45 | ```
46 |
47 | Note that, because the names are hard-coded in the viewer, the `name` of each attention must contain the substring "encoder_decoder_attention", "encoder_self_attention" or "decoder_self_attention", according to what the attention actually represents.
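
The substring check is performed in `multihead-att-java/HeatmapPanel.java` (`display`), which also decides which sentence is shown on each side. A rough Python paraphrase of that logic follows; the function name `sides_for` is hypothetical, and raising an exception stands in for the viewer's exit-on-error behaviour.

``` python
def sides_for(name, source, translation):
    """Return (left_text, right_text): queries on the left, keys on the right."""
    if "encoder_decoder_attention" in name:
        # target words query the source words
        return translation, source
    if "encoder_self_attention" in name:
        return source, source
    if "decoder_self_attention" in name:
        return translation, translation
    # Assumption of this sketch: raise instead of terminating the process.
    raise ValueError("unrecognised attention name: %r" % name)


left, right = sides_for("encoder_self_attention0", "src sentence", "trg sentence")
```

A name such as `encoder_self_attention0` matches because only the substring is tested, so a numeric suffix can be used to tell layers apart.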
48 |
49 | The `value` has shape [length_queries, length_keys] when `type`=simple and has shape [num_heads, length_queries, length_keys] when `type`=multihead.
50 |
51 | For more details, see [attention.py](https://github.com/zhaocq-nlp/NJUNMT-tf/blob/master/njunmt/inference/attention.py).
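
As an illustration of the layout and shapes described above, here is a minimal sketch that builds a one-sample structure with uniform placeholder weights. The helper names (`uniform_rows`, `build_toy_attention`, `to_single_line_json`) and the toy tokens are hypothetical. The extra position per sentence assumes an end-of-sequence token, which is how `exec/plot_heatmap.py` labels the source side.

``` python
import json


def uniform_rows(num_queries, num_keys):
    """Placeholder weights: every query attends uniformly over all keys."""
    return [[1.0 / num_keys] * num_keys for _ in range(num_queries)]


def build_toy_attention(source, translation, num_heads=2):
    """Return a dict in the layout described above, holding a single sample."""
    # Assumption: one extra position for an end-of-sequence token on each side.
    num_src = len(source.split()) + 1
    num_trg = len(translation.split()) + 1
    return {
        "0": {
            "source": source,
            "translation": translation,
            "attentions": [
                {   # simple: [length_queries, length_keys]
                    "name": "encoder_decoder_attention",
                    "type": "simple",
                    "value": uniform_rows(num_trg, num_src),
                },
                {   # multihead: [num_heads, length_queries, length_keys]
                    "name": "encoder_self_attention0",
                    "type": "multihead",
                    "value": [uniform_rows(num_src, num_src)
                              for _ in range(num_heads)],
                },
            ],
        }
    }


def to_single_line_json(data):
    """Serialize without indentation so the whole document is one line."""
    return json.dumps(data, ensure_ascii=False)


toy = build_toy_attention("a b c", "x y")
text = to_single_line_json(toy)
```

Writing `text` to a file whose name ends in `.attention` gives an input for either tool. The single-line form is deliberate: `DataObject.reload` in the Java sources reads only the first line of the file before parsing it.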
52 |
53 | ## Demo
54 |
55 | `toydata/toy.attention` was generated by an NMT model with a self-attention encoder, Bahdanau's attention and an RNN decoder, using [NJUNMT-tf](https://github.com/zhaocq-nlp/NJUNMT-tf).
56 |
57 | 1\. Execute the python version (for simple attention only):
58 |
59 | ``` bash
60 | python exec/plot_heatmap.py --input toydata/toy.attention
61 | ```
62 | It will plot the traditional attention heatmap:
63 |
64 |
65 |

66 |
67 |
68 |
69 | 2\. For the Java version (for both simple attention and Google's multi-head attention), execute
70 | ``` bash
71 | java -jar exec/plot_heatmap.jar
72 | ```
73 | then select `toydata/toy.attention` in the GUI.
74 |
75 |
76 |

77 |
78 |
79 |

80 |
81 |
82 | The words on the left side are the attention "queries", and the attention "keys" are on the right. Click on a word on the left side to see the heatmap:
83 |
84 |
85 |

86 |
87 |
88 | This shows the traditional `encoder_decoder_attention` for the word "obtained". The color depth of the lines and squares indicates the degree of attention.
89 |
90 | Next, select `encoder_self_attention0` under the menu bar, then click on "获得" on the left.
91 |
92 |
93 |

94 |
95 |
96 | It shows the multi-head attention of the word "获得". Attention weights of head0 - head7 are displayed on the right.
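
For multi-head attention, `HeatmapPanel.paint` in the Java sources appears to derive the opacity of the connecting lines by summing the selected query's weights over all heads, keeping the five largest keys, and rescaling so that the strongest line has opacity 0.6. The sketch below paraphrases that computation in Python. `top_indexes` stands in for `Utils.topIndexes`, whose source is not shown here, so its behaviour (indices of the k largest values, largest first) is an assumption.

``` python
def top_indexes(scores, k):
    """Indices of the k largest scores, largest first (assumed behaviour)."""
    order = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    return order[:k]


def line_opacities(heads, word_index, k=5, top_alpha=0.6):
    """Map key index -> line opacity for one query word.

    heads has shape [num_heads][length_queries][length_keys].
    """
    num_keys = len(heads[0][word_index])
    accumulated = [0.0] * num_keys
    for head in heads:
        # add this head's weights for the selected query row
        for idx, weight in enumerate(head[word_index]):
            accumulated[idx] += weight
    total = sum(accumulated)
    top = top_indexes(accumulated, k)
    # rescale so that the strongest key gets exactly top_alpha
    multiplier = top_alpha / (accumulated[top[0]] / total)
    return {idx: accumulated[idx] / total * multiplier for idx in top}


# Two heads, one query, three keys (made-up numbers for illustration).
example_heads = [[[0.5, 0.25, 0.25]], [[0.5, 0.5, 0.0]]]
alphas = line_opacities(example_heads, word_index=0)
```

The per-head squares are drawn separately in the viewer, each with an opacity of 0.7 times the raw weight, so only the lines use this aggregated score.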
--------------------------------------------------------------------------------
/exec/plot_heatmap.jar:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhaocq-nlp/Attention-Visualization/66ca6ba3aa88a0450f3c5d99d1afc84481cc4d88/exec/plot_heatmap.jar
--------------------------------------------------------------------------------
/exec/plot_heatmap.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | from __future__ import absolute_import
3 |
4 | import numpy
5 | import matplotlib.pyplot as plt
6 | import json
7 | import argparse
8 |
9 |
10 | # input:
11 | # alignment matrix - numpy array
12 | # shape (target tokens + eos, number of hidden source states = source tokens +eos)
13 | # one line corresponds to one decoding step producing one target token
14 | # each line has the attention model weights corresponding to that decoding step
15 | # each float on a line is the attention model weight for a corresponding source state.
16 | # plot: a heat map of the alignment matrix
17 | # x axis are the source tokens (alignment is to source hidden state that roughly corresponds to a source token)
18 | # y axis are the target tokens
19 |
20 | # http://stackoverflow.com/questions/14391959/heatmap-in-matplotlib-with-pcolor
21 | def plot_head_map(mma, target_labels, source_labels):
22 |     fig, ax = plt.subplots()
23 |     heatmap = ax.pcolor(mma, cmap=plt.cm.Blues)
24 |
25 |     # put the major ticks at the middle of each cell
26 |     ax.set_xticks(numpy.arange(mma.shape[1]) + 0.5, minor=False)
27 |     ax.set_yticks(numpy.arange(mma.shape[0]) + 0.5, minor=False)
28 |
29 |     # without this I get some extra columns rows
30 |     # http://stackoverflow.com/questions/31601351/why-does-this-matplotlib-heatmap-have-an-extra-blank-column
31 |     ax.set_xlim(0, int(mma.shape[1]))
32 |     ax.set_ylim(0, int(mma.shape[0]))
33 |
34 |     # want a more natural, table-like display
35 |     ax.invert_yaxis()
36 |     ax.xaxis.tick_top()
37 |
38 |     # source words -> column labels
39 |     ax.set_xticklabels(source_labels, minor=False)
40 |     # target words -> row labels
41 |     ax.set_yticklabels(target_labels, minor=False)
42 |
43 |     plt.xticks(rotation=45)
44 |
45 |     # plt.tight_layout()
46 |     plt.show()
47 |
48 |
49 | # column labels -> target words
50 | # row labels -> source words
51 |
52 | def read_alignment_matrix(f):
53 |     header = f.readline().strip().split('|||')
54 |     if header[0] == '':
55 |         return None, None, None, None
56 |     sid = int(header[0].strip())
57 |     # number of tokens in source and translation +1 for eos
58 |     src_count, trg_count = map(int, header[-1].split())
59 |     # source words
60 |     source_labels = header[3].decode('UTF-8').split()
61 |     # source_labels.append('')
62 |     # target words
63 |     target_labels = header[1].decode('UTF-8').split()
64 |     target_labels.append('')
65 |
66 |     mm = []
67 |     for r in range(trg_count):
68 |         alignment = list(map(float, f.readline().strip().split()))
69 |         mm.append(alignment)
70 |     mma = numpy.array(mm)
71 |     return sid, mma, target_labels, source_labels
72 |
73 |
74 | def read_plot_alignment_matrices(f, start=0):
75 |     attentions = json.load(f)
76 |
77 |     for idx, att in attentions.items():
78 |         if int(idx) < start: continue
79 |         source_labels = att["source"].split() + ["SEQUENCE_END"]
80 |         target_labels = att["translation"].split()
81 |         att_list = att["attentions"]
82 |         assert att_list[0]["type"] == "simple", "Do not use this tool for multihead attention."
83 |         mma = numpy.array(att_list[0]["value"])
84 |         if mma.shape[0] == len(target_labels) + 1:
85 |             target_labels += ["SEQUENCE_END"]
86 |
87 |         plot_head_map(mma, target_labels, source_labels)
88 |
89 |
90 | if __name__ == "__main__":
91 |     parser = argparse.ArgumentParser()
92 |     parser.add_argument('--input', '-i', type=argparse.FileType("rb"),
93 |                         default="trans.att",
94 |                         metavar='PATH',
95 |                         help="Input attention file (default: trans.att)")
96 |     parser.add_argument('--start', type=int, default=0)
97 |
98 |     args = parser.parse_args()
99 |
100 |     read_plot_alignment_matrices(args.input, args.start)
101 |
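
The label handling in `read_plot_alignment_matrices` can be tried without matplotlib. The following is a minimal sketch using only the standard library; the generator name `iter_simple_attentions` is hypothetical. JSON object keys are strings, so the sketch converts them with `int()` before comparing against `start`, and visiting samples in numeric order is a choice of this sketch rather than something the script guarantees.

``` python
def iter_simple_attentions(attentions, start=0):
    """Yield (sample_id, matrix, target_labels, source_labels) per sample.

    `attentions` is the dict obtained by parsing a .attention file.
    """
    for key in sorted(attentions, key=int):
        if int(key) < start:
            continue
        att = attentions[key]
        first = att["attentions"][0]
        if first["type"] != "simple":
            raise ValueError("sample %s: first attention is not 'simple'" % key)
        matrix = first["value"]
        # the source side always gets an end-of-sequence label
        source_labels = att["source"].split() + ["SEQUENCE_END"]
        target_labels = att["translation"].split()
        # the target side gets one only if the matrix has an extra row
        if len(matrix) == len(target_labels) + 1:
            target_labels = target_labels + ["SEQUENCE_END"]
        yield int(key), matrix, target_labels, source_labels
```

Each yielded tuple carries the arguments that `plot_head_map` expects (after converting `matrix` to a numpy array), which makes it easy to check that the row and column counts match the labels before plotting.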
--------------------------------------------------------------------------------
/multihead-att-java/ActionLabel.java:
--------------------------------------------------------------------------------
1 | import javax.swing.*;
2 | import javax.swing.border.BevelBorder;
3 | import javax.swing.event.MouseInputListener;
4 | import java.awt.*;
5 | import java.awt.event.MouseEvent;
6 | import java.awt.event.MouseListener;
7 | import java.awt.event.MouseMotionListener;
8 |
9 |
10 | public class ActionLabel extends JLabel {
11 |
12 | private boolean isActive = false;
13 | private HeatmapPanel parent = null;
14 | private String currentText = "";
15 | private int id = -1;
16 |
17 | // private HintPanel hintPanel = null;
18 |
19 | public ActionLabel(HeatmapPanel parent, Integer id, String text, int horizontalAlignment) {
20 | super(text, horizontalAlignment);
21 | this.parent = parent;
22 | this.currentText = text;
23 | this.id = id;
24 | LabelMouseListener listener = new LabelMouseListener();
25 | this.addMouseListener(listener);
26 |
27 | }
28 |
29 | @Override
30 | public void setText(String text) {
31 | super.setText(text);
32 | this.currentText = text;
33 | if (text.length() > 0) {
34 | isActive = true;
35 | } else {
36 | isActive = false;
37 | }
38 | }
39 |
40 | @Override
41 | public void setBounds(int x, int y, int width, int height) {
42 | super.setBounds(x, y, width, height);
43 | // if(hintPanel == null){
44 | // hintPanel = new HintPanel(x - 50, (int) y - 25, 200, 30);
45 | // }
46 | }
47 |
48 | private class HintPanel extends JPanel{
49 |
50 | JLabel label = new JLabel("", JLabel.CENTER);
51 | public HintPanel(int x, int y, int width, int height){
52 | super();
53 | super.setBounds(x, y, width, height);
54 | super.add(label);
55 | label.setBounds(0, 0, width, height);
56 | label.setFont(new Font("TimesRoman", Font.PLAIN, 20));
57 | }
58 |
59 | public void setText(String text){
60 | label.setText(text);
61 | }
62 |
63 | @Override
64 | protected void paintComponent(Graphics g) {
65 | super.paintComponent(g);
66 | Graphics2D g2d = (Graphics2D) g;
67 | g2d.setComposite(AlphaComposite.getInstance(AlphaComposite.SRC_OVER, 0.7f));
68 | g2d.setColor(Color.YELLOW);
69 | g2d.fill(getBounds());
70 | g2d.dispose();
71 | }
72 | }
73 |
74 | private class LabelMouseListener implements MouseListener {
75 |
76 | @Override
77 | public void mouseClicked(MouseEvent e) {
78 | if (isActive) {
79 | parent.setWordIndex(id);
80 | }
81 | }
82 |
83 | @Override
84 | public void mouseEntered(MouseEvent e) {
85 | if (isActive) {
86 | // if (hintPanel != null) {
87 | // parent.add(hintPanel);
88 | // parent.flushPanel();
89 | // }
90 | ((JLabel) e.getComponent()).setBorder(new BevelBorder(BevelBorder.RAISED, null, null, null, null));
91 | ((JLabel) e.getComponent()).setCursor(new Cursor(Cursor.HAND_CURSOR));
92 | }
93 | }
94 |
95 | @Override
96 | public void mouseExited(MouseEvent e) {
97 | if (isActive) {
98 | // parent.remove(hintPanel);
99 | // parent.flushPanel();
100 | ((JLabel) e.getComponent()).setBorder(null);
101 | }
102 | }
103 |
104 | @Override
105 | public void mousePressed(MouseEvent e) {
106 | if (isActive) {
107 | ((JLabel) e.getComponent()).setBorder(new BevelBorder(BevelBorder.LOWERED, null, null, null, null));
108 | }
109 | }
110 |
111 | @Override
112 | public void mouseReleased(MouseEvent e) {
113 | if (isActive) {
114 | ((JLabel) e.getComponent()).setBorder(new BevelBorder(BevelBorder.RAISED, null, null, null, null));
115 | }
116 | }
117 | }
118 |
119 | }
120 |
--------------------------------------------------------------------------------
/multihead-att-java/DataObject.java:
--------------------------------------------------------------------------------
1 | import org.json.JSONArray;
2 | import org.json.JSONObject;
3 |
4 | import java.io.BufferedReader;
5 | import java.io.FileInputStream;
6 | import java.io.InputStreamReader;
7 | import java.util.ArrayList;
8 | import java.util.Collections;
9 | import java.util.List;
10 |
11 | public class DataObject {
12 |
13 | public int numSamples = 0;
14 | public List<String> attentionFieldList;
15 |
16 |
17 | private JSONObject dataObject = null;
18 |
19 | public DataObject(String filename){
20 | this.reload(filename);
21 | }
22 |
23 | public JSONObject get(int index){
24 | try {
25 | JSONObject obj = this.dataObject.getJSONObject(String.format("%d", index));
26 | return obj;
27 | } catch(Exception e){
28 | e.printStackTrace();
29 | System.exit(0);
30 | }
31 | return null;
32 | }
33 |
34 | public String getAttentionType(int index, String attentionField){
35 | JSONObject instanceObj = this.get(index);
36 | try{
37 | JSONArray attLists = instanceObj.getJSONArray("attentions");
38 | for (int i = 0; i < attLists.length(); ++i) {
39 | JSONObject obj = attLists.getJSONObject(i);
40 | if(obj.getString("name").equals(attentionField)){
41 | return obj.getString("type");
42 | }
43 | }
44 | } catch(Exception e){
45 | e.printStackTrace();
46 | System.exit(0);
47 | }
48 | System.exit(0);
49 | return null;
50 | }
51 |
52 |
53 | public JSONArray getAttentionWeight(int index, String attentionField){
54 | JSONObject instanceObj = this.get(index);
55 | try{
56 | JSONArray attLists = instanceObj.getJSONArray("attentions");
57 | for (int i = 0; i < attLists.length(); ++i) {
58 | JSONObject obj = attLists.getJSONObject(i);
59 | if(obj.getString("name").equals(attentionField)){
60 | return obj.getJSONArray("value");
61 | }
62 | }
63 | } catch(Exception e){
64 | e.printStackTrace();
65 | System.exit(0);
66 | }
67 | System.exit(0);
68 | return null;
69 | }
70 |
71 | public void reload(String filename){
72 | try {
73 | BufferedReader br = new BufferedReader(new InputStreamReader(new FileInputStream(filename), "utf-8"));
74 | String str = br.readLine();
75 | br.close();
76 | this.dataObject = new JSONObject(str);
77 |
78 | this.numSamples = this.dataObject.length();
79 | this.attentionFieldList = new ArrayList<String>();
80 | JSONArray attArray = this.dataObject.getJSONObject("0").getJSONArray("attentions");
81 | for (int i = 0; i < attArray.length(); ++i) {
82 | JSONObject obj = attArray.getJSONObject(i);
83 | String type =obj.getString("type");
84 | if (type.equals("multihead") || type.equals("simple")) {
85 | this.attentionFieldList.add(obj.getString("name"));
86 | } else{
87 | System.err.println(String.format("Error with type: %s", type));
88 | System.exit(0);
89 | }
90 | }
91 | Collections.sort(this.attentionFieldList);
92 | } catch (Exception e){
93 | e.printStackTrace();
94 | }
95 | }
96 | }
97 |
--------------------------------------------------------------------------------
/multihead-att-java/HeatmapPanel.java:
--------------------------------------------------------------------------------
1 |
2 | import org.json.JSONArray;
3 | import org.json.JSONObject;
4 |
5 | import javax.swing.*;
6 | import java.awt.*;
7 | import java.util.ArrayList;
8 | import java.util.Arrays;
9 |
10 |
11 | public class HeatmapPanel extends JPanel {
12 |
13 | private MainFrame parent = null;
14 | private DataObject dataObject = null;
15 | private int preDefineMaxLength = 50;
16 |
17 | public int currentSampleId = 0;
18 | public String currentAttentionName = "";
19 | public int wordIndex = -1;
20 |
21 |
22 | // left
23 | private ArrayList<JLabel> leftLabelList = new ArrayList<JLabel>();
24 | private ArrayList<JLabel> rightLabelList = new ArrayList<JLabel>();
25 | private ArrayList<JLabel> leftNumLabelList = new ArrayList<JLabel>();
26 | private ArrayList<JLabel> rightNumLabelList = new ArrayList<JLabel>();
27 |
28 | public HeatmapPanel(MainFrame parent, DataObject dataObject, int preDefineMaxLength) {
29 | super();
30 | this.parent = parent;
31 | this.dataObject = dataObject;
32 | this.preDefineMaxLength = preDefineMaxLength;
33 | this.setBounds(0, 0, 768, 3400);
34 | this.setLayout(null);
35 | this.addEmptyLabels();
36 | }
37 |
38 | public void flushPanel() {
39 | this.parent.flushFrame();
40 | }
41 |
42 | private void addEmptyLabels() {
43 | for (int i = 0; i < this.preDefineMaxLength; ++i) {
44 | JLabel label = new ActionLabel(this, i, "", JLabel.RIGHT);
45 | label.setFont(new Font("TimesRoman", Font.PLAIN, 18));
46 | label.setBounds(100, 50 + i * 20 + 7 * i, 140, 20);
47 | this.add(label);
48 | this.leftLabelList.add(label);
49 | }
50 | for (int i = 0; i < this.preDefineMaxLength; ++i) {
51 | JLabel label = new JLabel("", JLabel.LEFT);
52 | label.setFont(new Font("TimesRoman", Font.PLAIN, 18));
53 | label.setBounds(500, 50 + i * 20 + 7 * i, 140, 20);
54 | this.add(label);
55 | this.rightLabelList.add(label);
56 | }
57 | for (int i = 0; i < this.preDefineMaxLength; ++i) {
58 | JLabel label = new JLabel("", JLabel.RIGHT);
59 | label.setFont(new Font("TimesRoman", Font.PLAIN, 18));
60 | label.setBounds(60, 50 + i * 20 + 7 * i, 30, 20);
61 | this.add(label);
62 | this.leftNumLabelList.add(label);
63 | }
64 | for (int i = 0; i < this.preDefineMaxLength; ++i) {
65 | JLabel label = new JLabel("", JLabel.LEFT);
66 | label.setFont(new Font("TimesRoman", Font.PLAIN, 18));
67 | label.setBounds(640, 50 + i * 20 + 7 * i, 30, 20);
68 | this.add(label);
69 | this.rightNumLabelList.add(label);
70 | }
71 | }
72 |
73 | public void setWordIndex(int wordIndex) {
74 | this.wordIndex = wordIndex;
75 | this.parent.flushFrame();
76 | }
77 |
78 | public void display(String currentAttentionName) {
79 | int curPrefixIndex = this.currentAttentionName.indexOf("_attention");
80 | int nextPrefixIndex = currentAttentionName.indexOf("_attention");
81 | if (this.currentAttentionName.substring(0, curPrefixIndex).equals(
82 | currentAttentionName.substring(0, nextPrefixIndex))) {
83 | this.currentAttentionName = currentAttentionName;
84 | } else {
85 | this.currentAttentionName = currentAttentionName;
86 | this.display(this.currentSampleId);
87 | }
88 | }
89 |
90 | public void display(int sampleId) {
91 | this.currentSampleId = sampleId;
92 | this.wordIndex = -1;
93 | String left = "";
94 | String right = "";
95 | try {
96 | JSONObject currentObj = this.dataObject.get(this.currentSampleId);
97 | String source = currentObj.getString("source");
98 | String target = currentObj.getString("translation");
99 | if (this.currentAttentionName.contains("encoder_decoder_attention")) {
100 | left = target;
101 | right = source;
102 | } else if (this.currentAttentionName.contains("encoder_self_attention")) {
103 | left = source;
104 | right = source;
105 | } else if (this.currentAttentionName.contains("decoder_self_attention")) {
106 | left = target;
107 | right = target;
108 | } else {
109 | System.err.println("Error name with attention");
110 | System.exit(0);
111 | }
112 | if (this.currentAttentionName.contains("decoder_self_attention")) {
113 | left = " " + left;
114 | right = " " + right;
115 | } else {
116 | left += " ";
117 | right += " ";
118 | }
119 |
120 | } catch (Exception e) {
121 | e.printStackTrace();
122 | }
123 | String[] leftTokens = left.trim().split(" ");
124 | String[] rightTokens = right.trim().split(" ");
125 | int auxIndex = this.leftLabelList.size();
126 | while (auxIndex < leftTokens.length) {
127 | JLabel label = new ActionLabel(this, auxIndex, "", JLabel.RIGHT);
128 | label.setFont(new Font("TimesRoman", Font.PLAIN, 18));
129 | label.setBounds(100, 50 + auxIndex * 20 + 7 * auxIndex, 140, 20);
130 | this.add(label);
131 | this.leftLabelList.add(label);
132 |
133 | JLabel numLabel = new JLabel("", JLabel.RIGHT);
134 | numLabel.setFont(new Font("TimesRoman", Font.PLAIN, 18));
135 | numLabel.setBounds(60, 50 + auxIndex * 20 + 7 * auxIndex, 30, 20);
136 | this.add(numLabel);
137 | this.leftNumLabelList.add(numLabel);
138 | ++auxIndex;
139 | }
140 | auxIndex = this.rightLabelList.size();
141 | while (auxIndex < rightTokens.length) {
142 | JLabel label = new JLabel("", JLabel.LEFT);
143 | label.setFont(new Font("TimesRoman", Font.PLAIN, 18));
144 | label.setBounds(500, 50 + auxIndex * 20 + 7 * auxIndex, 130, 20);
145 | this.add(label);
146 | this.rightLabelList.add(label);
147 |
148 | JLabel numlabel = new JLabel("", JLabel.LEFT);
149 | numlabel.setFont(new Font("TimesRoman", Font.PLAIN, 18));
150 | numlabel.setBounds(640, 50 + auxIndex * 20 + 7 * auxIndex, 30, 20);
151 | this.add(numlabel);
152 | this.rightNumLabelList.add(numlabel);
153 | ++auxIndex;
154 | }
155 | for (int i = 0; i < this.leftLabelList.size(); ++i) {
156 | if (i < leftTokens.length) {
157 | this.leftLabelList.get(i).setText(leftTokens[i]);
158 | this.leftNumLabelList.get(i).setText(String.format("%d", i));
159 | } else {
160 | this.leftLabelList.get(i).setText("");
161 | this.leftNumLabelList.get(i).setText("");
162 | }
163 | }
164 | for (int i = 0; i < this.rightLabelList.size(); ++i) {
165 | if (i < rightTokens.length) {
166 | this.rightLabelList.get(i).setText(rightTokens[i]);
167 | this.rightNumLabelList.get(i).setText(String.format("%d", i));
168 | } else {
169 | this.rightLabelList.get(i).setText("");
170 | this.rightNumLabelList.get(i).setText("");
171 | }
172 | }
173 | // this.validate();
174 | // this.invalidate();
175 | // this.repaint();
176 | }
177 |
178 | @Override
179 | public void paint(Graphics g) {
180 | super.paint(g);
181 | if (this.wordIndex < 0) {
182 | return;
183 | }
184 | Graphics2D g2 = (Graphics2D) g;
185 | g2.setColor(Color.BLUE);
186 | // g2.setColor(Color.CYAN);
187 | JSONArray attArray = this.dataObject.getAttentionWeight(this.currentSampleId, this.currentAttentionName);
188 | String attType = this.dataObject.getAttentionType(this.currentSampleId, this.currentAttentionName);
189 | if (attType.equals("multihead")) {
190 | double[] accumulateScores = null;
191 | g2.setStroke(new BasicStroke(3.0f));
192 | for (int head = 0; head < attArray.length(); ++head) {
193 | JSONArray att = null;
194 | try {
195 | att = (JSONArray) ((JSONArray) (attArray.get(head))).get(this.wordIndex);
196 | if (head == 0) {
197 | accumulateScores = new double[att.length()];
198 | Arrays.fill(accumulateScores, 0.0);
199 | }
200 | for (int idx = 0; idx < att.length(); ++idx) {
201 | accumulateScores[idx] += att.getDouble(idx);
202 | g2.setComposite(AlphaComposite.getInstance(
203 | AlphaComposite.SRC_OVER, (float) (att.getDouble(idx) * 0.7)));
204 | g2.fillRect(500 + head * 20, 50 + 27 * idx, 20, 20);
205 | }
206 | } catch (Exception e) {
207 | e.printStackTrace();
208 | }
209 | }
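210 | if (accumulateScores == null) { return; } // no heads, or the first head failed to parse
211 | double sum = 0.0;
212 | for (double score : accumulateScores) { sum += score; }
213 | if (sum <= 0.0) { return; } // all-zero weights: nothing to connect, and avoids dividing by zero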
214 | int[] topIndexes = Utils.topIndexes(accumulateScores, 5);
215 | double multiplier = 0.6 / (accumulateScores[topIndexes[0]] / sum);
216 | for (int idx = 0; idx < topIndexes.length; ++idx) {
217 | double prob = accumulateScores[topIndexes[idx]] / sum * multiplier;
218 | g2.setComposite(AlphaComposite.getInstance(
219 | AlphaComposite.SRC_OVER, (float) (prob)));
220 | g2.drawLine(240, 60 + this.wordIndex * 27, 500, 60 + topIndexes[idx] * 27);
221 | }
222 | } else if (attType.equals("simple")) {
223 | double[] accumulateScores = null;
224 | g2.setStroke(new BasicStroke(3.0f));
225 | JSONArray att = null;
226 | try {
227 | att = (JSONArray) (attArray.get(this.wordIndex));
228 | accumulateScores = new double[att.length()];
229 | Arrays.fill(accumulateScores, 0.0);
230 | for (int idx = 0; idx < att.length(); ++idx) {
231 | accumulateScores[idx] = att.getDouble(idx);
232 | g2.setComposite(AlphaComposite.getInstance(
233 | AlphaComposite.SRC_OVER, (float) (att.getDouble(idx) * 0.7)));
234 | g2.fillRect(500 + 20, 50 + 27 * idx, 20 * 8, 20);
235 | }
236 | } catch (Exception e) {
237 | e.printStackTrace();
238 | }
239 | if (accumulateScores == null) { return; } // weights failed to parse; nothing to draw
240 | int[] topIndexes = Utils.topIndexes(accumulateScores, 5);
241 | double multiplier = 0.6 / accumulateScores[topIndexes[0]];
242 | for (int idx = 0; idx < topIndexes.length; ++idx) {
243 | double prob = accumulateScores[topIndexes[idx]] * multiplier;
244 | g2.setComposite(AlphaComposite.getInstance(
245 | AlphaComposite.SRC_OVER, (float) (prob)));
246 | g2.drawLine(240, 60 + this.wordIndex * 27, 500, 60 + topIndexes[idx] * 27);
247 | }
248 | }
249 | // Do not dispose g here: this method did not create it, so the caller owns it.
250 | }
251 | }
252 |
--------------------------------------------------------------------------------
/multihead-att-java/MainFrame.java:
--------------------------------------------------------------------------------
1 | import java.awt.*;
2 | import javax.swing.*;
3 | import javax.swing.filechooser.FileFilter;
4 |
5 | import java.awt.event.ActionEvent;
6 | import java.awt.event.ActionListener;
7 | import java.io.File;
8 |
9 |
10 | public class MainFrame extends JDialog {
11 |
12 | // for main panel
13 | private JTabbedPane mainTabbedPane = new JTabbedPane(JTabbedPane.TOP);
14 |
15 | // for menu
16 | private JFileChooser fileChooser = new JFileChooser();
17 |
18 | private int panelCount = 0;
19 |
20 | public MainFrame(String name) throws Exception {
21 | this.setTitle(name);
22 | this.setDefaultCloseOperation(WindowConstants.DISPOSE_ON_CLOSE);
23 | this.setMinimumSize(new Dimension(800, 600));//setSize();
24 | this.setVisible(true);
25 | this.setLocation((Toolkit.getDefaultToolkit().getScreenSize().width - this.getWidth()) / 2,
26 | (Toolkit.getDefaultToolkit().getScreenSize().height - this.getHeight()) / 2);
27 | this.setResizable(true);
28 |
29 | FileFilter attFileFilter = new FileFilter() {
30 | @Override
31 | public boolean accept(File f) {
32 | String name = f.getName();
33 | return f.isDirectory() || name.endsWith(".attention");
34 | }
35 |
36 | @Override
37 | public String getDescription() {
38 | return "*.attention";
39 | }
40 | };
41 | fileChooser.setFileFilter(attFileFilter);
42 | fileChooser.addChoosableFileFilter(attFileFilter);
43 |
44 | this.createMenuBar();
45 |
46 |
47 | Container contentPane = this.getContentPane();
48 | contentPane.add(this.mainTabbedPane);
49 | this.flushFrame();
50 | System.out.println("Finish");
51 | }
52 |
53 | public void openFile(String filename) {
54 | try {
55 | DataObject dataObject = new DataObject(filename);
56 | HeatmapPanel heatmapPanel = new HeatmapPanel(this, dataObject, 30);
57 | MainPanel panel = new MainPanel(new JPanel(), dataObject);
58 | panel.addHeatmapPanel(heatmapPanel);
59 | this.mainTabbedPane.add(panel, Utils.extractFilePrefix(filename));
60 | this.mainTabbedPane.setSelectedComponent(panel);
61 | this.flushFrame();
62 | } catch (Exception e) {
63 | e.printStackTrace();
64 | }
65 | }
66 |
67 | public void flushFrame() {
68 | this.invalidate(); // mark the layout stale first...
69 | this.validate(); // ...then lay it out again, so the frame ends up valid
70 | this.repaint();
71 | }
72 |
73 |
74 | private void createMenuBar() {
75 | JMenu menu = new JMenu("File");
76 | JMenuItem openItem = new JMenuItem("Open...");
77 | menu.add(openItem);
78 | JMenuBar br = new JMenuBar();
79 | br.add(menu);
80 |
81 | JMenuItem closeItem = new JMenuItem("Close Tab");
82 | menu.add(closeItem);
83 |
84 | JMenuItem quitItem = new JMenuItem("Quit");
85 | menu.add(quitItem);
86 |
87 | closeItem.setEnabled(false);
88 |
89 | openItem.addActionListener(new ActionListener() {
90 | @Override
91 | public void actionPerformed(ActionEvent e) {
92 | int state = fileChooser.showOpenDialog(null);
93 | if (state != JFileChooser.APPROVE_OPTION) { // covers both CANCEL_OPTION and ERROR_OPTION
94 | return;
95 | } else {
96 | String filename = fileChooser.getSelectedFile().getAbsolutePath();
97 | openFile(filename);
98 | ++panelCount;
99 | closeItem.setEnabled(true);
100 | }
101 | }
102 | });
103 |
104 | closeItem.addActionListener(new ActionListener() {
105 | @Override
106 | public void actionPerformed(ActionEvent e) {
107 | if(closeItem.isEnabled()){
108 | mainTabbedPane.remove(mainTabbedPane.getSelectedComponent());
109 | --panelCount;
110 | if (panelCount == 0){
111 | closeItem.setEnabled(false);
112 | }
113 | }
114 |
115 | }
116 | });
117 |
118 | quitItem.addActionListener(new ActionListener() {
119 | @Override
120 | public void actionPerformed(ActionEvent e) {
121 | dispose();
122 | }
123 | });
124 |
125 | this.setJMenuBar(br);
126 | }
127 |
128 |
129 | public static void main(String[] args) throws Exception {
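130 | SwingUtilities.invokeLater(() -> { try { new MainFrame("Heatmap"); } catch (Exception e) { e.printStackTrace(); } }); // build the UI on the event dispatch thread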
131 | }
132 | }
133 |
--------------------------------------------------------------------------------
/multihead-att-java/MainPanel.java:
--------------------------------------------------------------------------------
1 | import javax.swing.*;
2 | import java.awt.*;
3 | import java.awt.event.ActionEvent;
4 | import java.awt.event.ActionListener;
5 |
6 | public class MainPanel extends JScrollPane {
7 |
8 |
9 | private JPanel mainPanel = null;
10 | private HeatmapPanel heatmapPanel = null;
11 | private JComboBox sampleComboBox = null;
12 | private JComboBox attentionComboBox = null;
13 | private DataObject dataObject = null;
14 |
15 | private JButton prevButton = null;
16 | private JButton nextButton = null;
17 |
18 | public MainPanel(JPanel panel, DataObject dataObj) {
19 | super(panel);
20 | this.mainPanel = panel;
21 | this.mainPanel.setPreferredSize(new Dimension(768, 3500));
22 | this.mainPanel.setLayout(null);
23 | this.setBounds(0, 0, 200, 200);
24 | this.setBackground(Color.WHITE);
25 | this.setOpaque(true);
26 |
27 | this.dataObject = dataObj;
28 | this.createPopupMenu();
29 | }
30 |
31 | public void addHeatmapPanel(HeatmapPanel panel) {
32 | this.mainPanel.add(panel);
33 | this.heatmapPanel = panel;
34 | heatmapPanel.currentAttentionName = (String) this.attentionComboBox.getSelectedItem();
35 | heatmapPanel.display(Integer.parseInt((String) this.sampleComboBox.getSelectedItem()));
36 | }
37 |
38 | public void flushPanel() {
39 | this.mainPanel.invalidate(); // mark the layout stale first...
40 | this.mainPanel.validate(); // ...then lay it out again, so the panel ends up valid
41 | this.mainPanel.repaint();
42 | }
43 |
44 | public void createPopupMenu() {
45 | JLabel sampleLabel = new JLabel("Sample: ", JLabel.RIGHT);
46 | this.mainPanel.add(sampleLabel);
47 | sampleLabel.setFont(new Font("TimesRoman", Font.PLAIN, 16));
48 | sampleLabel.setBounds(50, 5, 60, 30);
49 |
50 | this.prevButton = new JButton("prev");
51 | this.mainPanel.add(this.prevButton);
52 | this.prevButton.setFont(new Font("TimesRoman", Font.PLAIN, 16));
53 | this.prevButton.setBounds(185, 5, 60, 30);
54 |
55 | this.nextButton = new JButton("next");
56 | this.mainPanel.add(this.nextButton);
57 | this.nextButton.setFont(new Font("TimesRoman", Font.PLAIN, 16));
58 | this.nextButton.setBounds(250, 5, 60, 30);
59 |
60 | JLabel attLabel = new JLabel("Displaying: ", JLabel.RIGHT);
61 | this.mainPanel.add(attLabel);
62 | attLabel.setFont(new Font("TimesRoman", Font.PLAIN, 16));
63 | attLabel.setBounds(290, 5, 120, 30);
64 | // }
65 | this.sampleComboBox = new JComboBox();
66 | for (int i = 0; i < this.dataObject.numSamples; ++i) {
67 | this.sampleComboBox.addItem(String.format("%d", i));
68 | }
69 | this.mainPanel.add(sampleComboBox);
70 | this.sampleComboBox.setBounds(110, 5, 70, 30);
71 | this.sampleComboBox.setSelectedIndex(0);
72 |
73 | this.attentionComboBox = new JComboBox();
74 | this.mainPanel.add(attentionComboBox);
75 | for (int i = 0; i < this.dataObject.attentionFieldList.size(); ++i) {
76 | this.attentionComboBox.addItem(this.dataObject.attentionFieldList.get(i));
77 | }
78 | this.attentionComboBox.setBounds(410, 5, 250, 30);
79 | this.attentionComboBox.setSelectedIndex(0);
80 |
81 | this.sampleComboBox.addActionListener(
82 | new ActionListener() {
83 | @Override
84 | public void actionPerformed(ActionEvent e) {
85 | heatmapPanel.display(Integer.parseInt((String) sampleComboBox.getSelectedItem()));
86 | flushPanel();
87 | }
88 | });
89 | this.attentionComboBox.addActionListener(
90 | new ActionListener() {
91 | @Override
92 | public void actionPerformed(ActionEvent e) {
93 | heatmapPanel.display((String) attentionComboBox.getSelectedItem());
94 | flushPanel();
95 | }
96 | });
97 |
98 | this.prevButton.addActionListener(new ActionListener() {
99 | @Override
100 | public void actionPerformed(ActionEvent e) {
101 | if (heatmapPanel.currentSampleId > 0){
102 | int prev = heatmapPanel.currentSampleId - 1;
103 | sampleComboBox.setSelectedIndex(prev);
104 | heatmapPanel.display(prev);
105 | flushPanel();
106 | }
107 | }
108 | });
109 | this.nextButton.addActionListener(new ActionListener() {
110 | @Override
111 | public void actionPerformed(ActionEvent e) {
112 | if(heatmapPanel.currentSampleId < dataObject.numSamples - 1){
113 | int next = heatmapPanel.currentSampleId + 1;
114 | sampleComboBox.setSelectedIndex(next);
115 | heatmapPanel.display(next);
116 | flushPanel();
117 | }
118 | }
119 | });
120 |
121 | }
122 |
123 |
124 | }
125 |
--------------------------------------------------------------------------------
/multihead-att-java/Utils.java:
--------------------------------------------------------------------------------
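1 | import java.util.Arrays;
2 | import java.util.Comparator;
3 |
4 | public final class Utils {
5 | // Returns the indexes of the (at most) topk largest values, largest first.
6 | public static int[] topIndexes(double[] values, int topk) {
7 | Integer[] order = new Integer[values.length];
8 | for (int i = 0; i < values.length; ++i) {
9 | order[i] = i;
10 | }
11 | // Sort indexes by value, descending. Equal values are all kept (a map keyed by value would drop them).
12 | Arrays.sort(order, Comparator.comparingDouble((Integer i) -> -values[i]));
13 | int[] indexes = new int[Math.min(topk, values.length)]; // no zero padding when values is short
14 | for (int cnt = 0; cnt < indexes.length; ++cnt) {
15 | indexes[cnt] = order[cnt];
16 | }
17 | return indexes;
18 | }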
19 |
20 | public static void main(String[] args) {
21 | System.out.println(java.util.Arrays.toString(topIndexes(new double[]{1., 2., 3., 2.5, 2.1}, 3)));
22 | }
23 |
24 | public static String extractFilePrefix(String path) {
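25 | String[] paths = path.trim().split("[/\\\\]"); // accept both "/" and "\" separators
26 | String filename = paths[paths.length - 1];
27 | int tmp = filename.indexOf(".attention");
28 | return tmp < 0 ? filename : filename.substring(0, tmp); // keep the full name if the suffix is absent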
29 | }
30 | }
31 |
--------------------------------------------------------------------------------
/toydata/figures/java-heatmap1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhaocq-nlp/Attention-Visualization/66ca6ba3aa88a0450f3c5d99d1afc84481cc4d88/toydata/figures/java-heatmap1.png
--------------------------------------------------------------------------------
/toydata/figures/java-heatmap2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhaocq-nlp/Attention-Visualization/66ca6ba3aa88a0450f3c5d99d1afc84481cc4d88/toydata/figures/java-heatmap2.png
--------------------------------------------------------------------------------
/toydata/figures/java-heatmap3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhaocq-nlp/Attention-Visualization/66ca6ba3aa88a0450f3c5d99d1afc84481cc4d88/toydata/figures/java-heatmap3.png
--------------------------------------------------------------------------------
/toydata/figures/java-heatmap4.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhaocq-nlp/Attention-Visualization/66ca6ba3aa88a0450f3c5d99d1afc84481cc4d88/toydata/figures/java-heatmap4.png
--------------------------------------------------------------------------------
/toydata/figures/py-heatmap.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhaocq-nlp/Attention-Visualization/66ca6ba3aa88a0450f3c5d99d1afc84481cc4d88/toydata/figures/py-heatmap.png
--------------------------------------------------------------------------------
/toydata/figures/py-heatmap1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhaocq-nlp/Attention-Visualization/66ca6ba3aa88a0450f3c5d99d1afc84481cc4d88/toydata/figures/py-heatmap1.png
--------------------------------------------------------------------------------
/toydata/figures/py-heatmap2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhaocq-nlp/Attention-Visualization/66ca6ba3aa88a0450f3c5d99d1afc84481cc4d88/toydata/figures/py-heatmap2.png
--------------------------------------------------------------------------------