├── Procfile
├── .gitattributes
├── requirements.txt
├── word_embedding_gif.gif
├── setup.sh
├── glove2word2vec_model.sav
├── train_model.py
├── README.md
└── app.py
/Procfile:
--------------------------------------------------------------------------------
1 | web: sh setup.sh && streamlit run app.py
2 |
--------------------------------------------------------------------------------
/.gitattributes:
--------------------------------------------------------------------------------
1 | glove2word2vec_model.sav filter=lfs diff=lfs merge=lfs -text
2 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
numpy==1.18.1
streamlit==0.63.0
plotly==4.9.0
scikit-learn==0.22.1
gensim==3.8.3
--------------------------------------------------------------------------------
/word_embedding_gif.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mrubenw/word-embedding-visualization/HEAD/word_embedding_gif.gif
--------------------------------------------------------------------------------
/setup.sh:
--------------------------------------------------------------------------------
# Heroku setup: write a Streamlit config that binds to the dyno's assigned $PORT.
mkdir -p ~/.streamlit/
# Use printf instead of echo with backslash escapes: echo's handling of "\n"
# differs between shells (dash interprets it, bash does not without -e),
# so the original could write literal "\n" text into the config.
printf '%s\n' \
  "[server]" \
  "headless = true" \
  "port = $PORT" \
  "enableCORS = false" \
  > ~/.streamlit/config.toml
--------------------------------------------------------------------------------
/glove2word2vec_model.sav:
--------------------------------------------------------------------------------
1 | version https://git-lfs.github.com/spec/v1
2 | oid sha256:4d85c21e1fc97a1a65e955ce89e4415126ebe77a6610c5a3c78ea8fd75829e2f
3 | size 181494860
4 |
--------------------------------------------------------------------------------
/train_model.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
"""
Convert a pre-trained GloVe embedding file to word2vec format, load it as a
gensim KeyedVectors model, and pickle it to 'glove2word2vec_model.sav' for
use by the Streamlit app (app.py).
"""

import numpy as np

import pickle
import matplotlib.pyplot as plt
plt.style.use('ggplot')

from sklearn.decomposition import PCA

from gensim.test.utils import datapath, get_tmpfile
from gensim.models import KeyedVectors
from gensim.scripts.glove2word2vec import glove2word2vec
from mpl_toolkits.mplot3d import Axes3D

# NOTE(review): machine-specific absolute path — parameterize before reuse
# on another machine.
glove_file = datapath('C:/Users/ASUS/glove.6B.100d.txt')
word2vec_glove_file = get_tmpfile("glove.6B.100d.word2vec.txt")
# Rewrite the GloVe text file with the header line the word2vec loader expects.
glove2word2vec(glove_file, word2vec_glove_file)

model = KeyedVectors.load_word2vec_format(word2vec_glove_file)

# Persist the loaded vectors so the web app can unpickle them at startup.
filename = 'glove2word2vec_model.sav'
pickle.dump(model, open(filename, 'wb'))
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Word Embedding Visualization
2 |
3 | This is a repo to visualize the word embedding in 2D or 3D with either Principal Component Analysis (PCA) or t-Distributed Stochastic Neighbor Embedding (t-SNE).
4 |
5 | Below is the snapshot of the web app to visualize the word embedding.
6 |
7 |
8 |
9 |
10 |
11 | ## Files
12 |
- train_model.py: Python file to convert the pre-trained GloVe embeddings to word2vec format and save the loaded model.
14 | - app.py: Python file to create the word embedding visualization web app.
15 | - glove2word2vec_model.sav: Saved pre-trained word embedding model.
16 |
To execute the web app, go to the working directory of app.py and run the following command in the conda environment:
18 | ```
19 | streamlit run app.py
20 | ```
21 |
--------------------------------------------------------------------------------
/app.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | Created on Sat Aug 15 17:38:30 2020
4 |
5 | @author: ASUS
6 | """
7 | # Import dependencies
8 | import plotly
9 | import plotly.graph_objs as go
10 | import numpy as np
11 | import pickle
12 | import streamlit as st
13 | from sklearn.decomposition import PCA
14 | from sklearn.manifold import TSNE
15 |
# Load the pre-trained word vectors saved by train_model.py.
# NOTE(review): unpickling this file requires the same gensim major version
# that produced it — confirm gensim is pinned in requirements.txt.
filename = 'glove2word2vec_model.sav'
model = pickle.load(open(filename, 'rb'))
18 |
@st.cache
def append_list(sim_words, words):
    """Tag each (word, similarity) pair in *sim_words* with the query word.

    Returns a list of (word, similarity, words) tuples, preserving order.
    """
    return [tuple(pair) + (words,) for pair in sim_words]
32 |
33 |
def display_scatterplot_3D(model, user_input=None, words=None, label=None, color_map=None, annotation='On', dim_red = 'PCA', perplexity = 0, learning_rate = 0, iteration = 0, topn=0, sample=10):
    """Render a 3-D scatter plot of word vectors in Streamlit.

    Parameters
    ----------
    model : word-vector model supporting ``model[word]`` and ``model.vocab``
        (pre-4.0 gensim KeyedVectors interface).
    user_input : list of query words, or None/'' for no query. Each query
        word contributed ``topn`` neighbour entries at the front of ``words``.
    words : flat word list to plot (neighbours first, then the query words).
        If None, ``sample`` random vocabulary words are plotted instead
        (the whole vocabulary when ``sample`` is 0).
    label, color_map : accepted for interface compatibility; currently unused.
    annotation : 'On' to draw each word's text next to its point.
    dim_red : 'PCA' for PCA; any other value selects t-SNE.
    perplexity, learning_rate, iteration : t-SNE hyper-parameters
        (ignored for PCA).
    topn : neighbours per query word; used to slice ``words`` into traces.
    sample : number of random vocabulary words when ``words`` is None.
    """
    # Robustness: tolerate a missing query list (the default is None, but the
    # original body crashed on len(None)).
    if user_input is None:
        user_input = []

    if words is None:  # idiomatic identity test (was `words == None`)
        if sample > 0:
            words = np.random.choice(list(model.vocab.keys()), sample)
        else:
            words = [ word for word in model.vocab ]

    word_vectors = np.array([model[w] for w in words])

    # Reduce embeddings to 3 components for plotting.
    if dim_red == 'PCA':
        three_dim = PCA(random_state=0).fit_transform(word_vectors)[:,:3]
    else:
        three_dim = TSNE(n_components = 3, random_state=0, perplexity = perplexity, learning_rate = learning_rate, n_iter = iteration).fit_transform(word_vectors)[:,:3]

    # Draw the three axes as cones anchored at the origin.
    color = 'blue'
    quiver = go.Cone(
        x = [0,0,0],
        y = [0,0,0],
        z = [0,0,0],
        u = [1.5,0,0],
        v = [0,1.5,0],
        w = [0,0,1.5],
        anchor = "tail",
        colorscale = [[0, color] , [1, color]],
        showscale = False
        )

    data = [quiver]

    # One trace per query word: its topn neighbours occupy a contiguous slice.
    count = 0
    for i in range (len(user_input)):

                trace = go.Scatter3d(
                    x = three_dim[count:count+topn,0],
                    y = three_dim[count:count+topn,1],
                    z = three_dim[count:count+topn,2],
                    text = words[count:count+topn] if annotation == 'On' else '',
                    name = user_input[i],
                    textposition = "top center",
                    textfont_size = 30,
                    mode = 'markers+text',
                    marker = {
                        'size': 10,
                        'opacity': 0.8,
                        'color': 2
                    }

                )

                data.append(trace)
                count = count+topn

    # The remaining points are the query words themselves.
    trace_input = go.Scatter3d(
                    x = three_dim[count:,0],
                    y = three_dim[count:,1],
                    z = three_dim[count:,2],
                    text = words[count:],
                    name = 'input words',
                    textposition = "top center",
                    textfont_size = 30,
                    mode = 'markers+text',
                    marker = {
                        'size': 10,
                        'opacity': 1,
                        'color': 'black'
                    }
                    )

    data.append(trace_input)

    # Configure the layout.
    layout = go.Layout(
        margin = {'l': 0, 'r': 0, 'b': 0, 't': 0},
        showlegend=True,
        legend=dict(
        x=1,
        y=0.5,
        font=dict(
            family="Courier New",
            size=25,
            color="black"
        )),
        font = dict(
            family = " Courier New ",
            size = 15),
        autosize = False,
        width = 1000,
        height = 1000
        )


    plot_figure = go.Figure(data = data, layout = layout)

    st.plotly_chart(plot_figure)
129 |
def horizontal_bar(word, similarity):
    """Show a horizontal bar chart of cosine similarities in Streamlit.

    *word* labels the y-axis; *similarity* values (rounded to 2 decimals)
    drive both the bar lengths and the bar text.
    """
    rounded = [round(value, 2) for value in similarity]

    bars = go.Bar(
        x=rounded,
        y=word,
        orientation='h',
        text=rounded,
        marker_color=4,
        textposition='auto')

    chart_layout = go.Layout(
        font=dict(size=20),
        xaxis=dict(showticklabels=False, automargin=True),
        # reversed y-axis keeps the most similar word on top
        yaxis=dict(showticklabels=True, automargin=True, autorange="reversed"),
        margin=dict(t=20, b=20, r=10)
    )

    st.plotly_chart(go.Figure(data=bars, layout=chart_layout))
151 |
def display_scatterplot_2D(model, user_input=None, words=None, label=None, color_map=None, annotation='On', dim_red = 'PCA', perplexity = 0, learning_rate = 0, iteration = 0, topn=0, sample=10):
    """Render a 2-D scatter plot of word vectors in Streamlit.

    Parameters
    ----------
    model : word-vector model supporting ``model[word]`` and ``model.vocab``
        (pre-4.0 gensim KeyedVectors interface).
    user_input : list of query words, or None/'' for no query. Each query
        word contributed ``topn`` neighbour entries at the front of ``words``.
    words : flat word list to plot (neighbours first, then the query words).
        If None, ``sample`` random vocabulary words are plotted instead
        (the whole vocabulary when ``sample`` is 0).
    label, color_map : accepted for interface compatibility; currently unused.
    annotation : 'On' to draw each word's text next to its point.
    dim_red : 'PCA' for PCA; any other value selects t-SNE.
    perplexity, learning_rate, iteration : t-SNE hyper-parameters
        (ignored for PCA).
    topn : neighbours per query word; used to slice ``words`` into traces.
    sample : number of random vocabulary words when ``words`` is None.
    """
    # Robustness: tolerate a missing query list (the default is None, but the
    # original body crashed on len(None)).
    if user_input is None:
        user_input = []

    if words is None:  # idiomatic identity test (was `words == None`)
        if sample > 0:
            words = np.random.choice(list(model.vocab.keys()), sample)
        else:
            words = [ word for word in model.vocab ]

    word_vectors = np.array([model[w] for w in words])

    # Reduce embeddings to 2 components for plotting.
    if dim_red == 'PCA':
        two_dim = PCA(random_state=0).fit_transform(word_vectors)[:,:2]
    else:
        two_dim = TSNE(random_state=0, perplexity = perplexity, learning_rate = learning_rate, n_iter = iteration).fit_transform(word_vectors)[:,:2]


    data = []
    # One trace per query word: its topn neighbours occupy a contiguous slice.
    count = 0
    for i in range (len(user_input)):

                trace = go.Scatter(
                    x = two_dim[count:count+topn,0],
                    y = two_dim[count:count+topn,1],
                    text = words[count:count+topn] if annotation == 'On' else '',
                    name = user_input[i],
                    textposition = "top center",
                    textfont_size = 20,
                    mode = 'markers+text',
                    marker = {
                        'size': 15,
                        'opacity': 0.8,
                        'color': 2
                    }

                )

                data.append(trace)
                count = count+topn

    # The remaining points are the query words themselves.
    trace_input = go.Scatter(
                    x = two_dim[count:,0],
                    y = two_dim[count:,1],
                    text = words[count:],
                    name = 'input words',
                    textposition = "top center",
                    textfont_size = 20,
                    mode = 'markers+text',
                    marker = {
                        'size': 25,
                        'opacity': 1,
                        'color': 'black'
                    }
                    )

    data.append(trace_input)

    # Configure the layout.
    layout = go.Layout(
        margin = {'l': 0, 'r': 0, 'b': 0, 't': 0},
        showlegend=True,
        hoverlabel=dict(
            bgcolor="white",
            font_size=20,
            font_family="Courier New"),
        legend=dict(
        x=1,
        y=0.5,
        font=dict(
            family="Courier New",
            size=25,
            color="black"
        )),
        font = dict(
            family = " Courier New ",
            size = 15),
        autosize = False,
        width = 1000,
        height = 1000
        )


    plot_figure = go.Figure(data = data, layout = layout)

    st.plotly_chart(plot_figure)
236 |
# ---- Sidebar controls -------------------------------------------------------
# Dimensionality-reduction method applied to the vectors before plotting.
dim_red = st.sidebar.selectbox(
    'Select dimension reduction method',
    ('PCA','TSNE'))
# Plot in 2-D or 3-D.
dimension = st.sidebar.radio(
    "Select the dimension of the visualization",
    ('2D', '3D'))
# Comma-separated query words; the empty string means "no query yet".
user_input = st.sidebar.text_input("Type the word that you want to investigate. You can type more than one word by separating one word with other with comma (,)",'')
# Number of similar words fetched per query word (5..100, default 5).
top_n = st.sidebar.slider('Select the amount of words associated with the input words you want to visualize ',
    5, 100, (5))
# Toggle the text annotations drawn next to each plotted point.
annotation = st.sidebar.radio(
    "Enable or disable the annotation on the visualization",
    ('On', 'Off'))
249 |
# t-SNE exposes extra hyper-parameters; for PCA they are irrelevant and are
# passed through as zeros (the plotting functions ignore them in that case).
if dim_red == 'TSNE':
    perplexity = st.sidebar.slider('Adjust the perplexity. The perplexity is related to the number of nearest neighbors that is used in other manifold learning algorithms. Larger datasets usually require a larger perplexity',
    5, 50, (30))

    learning_rate = st.sidebar.slider('Adjust the learning rate',
    10, 1000, (200))

    iteration = st.sidebar.slider('Adjust the number of iteration',
    250, 100000, (1000))

else:
    perplexity = 0
    learning_rate = 0
    iteration = 0
264 |
if user_input == '':

    # No query yet: the plotting functions fall back to a random vocabulary
    # sample when they receive words=None.
    similar_word = None
    labels = None
    color_map = None

else:

    # Split the comma-separated query into individual trimmed words.
    user_input = [x.strip() for x in user_input.split(',')]
    result_word = []

    for words in user_input:

        # Fetch the top_n most similar vocabulary words for this query word.
        # NOTE(review): model.most_similar raises KeyError for a word that is
        # not in the vocabulary — consider validating the input upstream.
        sim_words = model.most_similar(words, topn = top_n)
        sim_words = append_list(sim_words, words)

        result_word.extend(sim_words)

    # result_word entries are (similar_word, similarity, query_word) tuples.
    similar_word = [word[0] for word in result_word]
    similarity = [word[1] for word in result_word]
    similar_word.extend(user_input)  # the query words themselves are plotted too
    labels = [word[2] for word in result_word]
    # Map each query word to a small integer for colour-grouping the points.
    label_dict = dict([(y,x+1) for x,y in enumerate(set(labels))])
    color_map = [label_dict[x] for x in labels]
289 |
290 |
# ---- Page body --------------------------------------------------------------
st.title('Word Embedding Visualization Based on Cosine Similarity')

st.header('This is a web app to visualize the word embedding.')
st.markdown('First, choose which dimension of visualization that you want to see. There are two options: 2D and 3D.')

st.markdown('Next, type the word that you want to investigate. You can type more than one word by separating one word with other with comma (,).')

st.markdown('With the slider in the sidebar, you can pick the amount of words associated with the input word you want to visualize. This is done by computing the cosine similarity between vectors of words in embedding space.')
st.markdown('Lastly, you have an option to enable or disable the text annotation in the visualization.')

# Dispatch to the matching scatter-plot renderer. When no query has been typed
# yet, user_input is '' and similar_word/labels/color_map are None, so the
# renderer plots a random sample of vocabulary words instead.
if dimension == '2D':
    st.header('2D Visualization')
    st.write('For more detail about each point (just in case it is difficult to read the annotation), you can hover around each points to see the words. You can expand the visualization by clicking expand symbol in the top right corner of the visualization.')
    display_scatterplot_2D(model, user_input, similar_word, labels, color_map, annotation, dim_red, perplexity, learning_rate, iteration, top_n)
else:
    st.header('3D Visualization')
    st.write('For more detail about each point (just in case it is difficult to read the annotation), you can hover around each points to see the words. You can expand the visualization by clicking expand symbol in the top right corner of the visualization.')
    display_scatterplot_3D(model, user_input, similar_word, labels, color_map, annotation, dim_red, perplexity, learning_rate, iteration, top_n)

# Each query word's neighbours occupy a contiguous block of top_n entries in
# similar_word/similarity, so [count:count+5] takes the first five of each
# block. (With user_input == '' this loop runs zero times.)
st.header('The Top 5 Most Similar Words for Each Input')
count=0
for i in range (len(user_input)):

    st.write('The most similar words from '+str(user_input[i])+' are:')
    horizontal_bar(similar_word[count:count+5], similarity[count:count+5])

    count = count+top_n
318 |
--------------------------------------------------------------------------------