├── Procfile ├── .gitattributes ├── requirements.txt ├── word_embedding_gif.gif ├── setup.sh ├── glove2word2vec_model.sav ├── train_model.py ├── README.md └── app.py /Procfile: -------------------------------------------------------------------------------- 1 | web: sh setup.sh && streamlit run app.py 2 | -------------------------------------------------------------------------------- /.gitattributes: -------------------------------------------------------------------------------- 1 | glove2word2vec_model.sav filter=lfs diff=lfs merge=lfs -text 2 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy==1.18.1 2 | streamlit==0.63.0 3 | plotly==4.9.0 4 | scikit-learn==0.22.1 5 | 6 | -------------------------------------------------------------------------------- /word_embedding_gif.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mrubenw/word-embedding-visualization/HEAD/word_embedding_gif.gif -------------------------------------------------------------------------------- /setup.sh: -------------------------------------------------------------------------------- 1 | mkdir -p ~/.streamlit/ 2 | echo "\ 3 | [server]\n\ 4 | headless = true\n\ 5 | port = $PORT\n\ 6 | enableCORS = false\n\ 7 | \n\ 8 | " > ~/.streamlit/config.toml -------------------------------------------------------------------------------- /glove2word2vec_model.sav: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:4d85c21e1fc97a1a65e955ce89e4415126ebe77a6610c5a3c78ea8fd75829e2f 3 | size 181494860 4 | -------------------------------------------------------------------------------- /train_model.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 
| """ 3 | """ 4 | 5 | import numpy as np 6 | 7 | import pickle 8 | import matplotlib.pyplot as plt 9 | plt.style.use('ggplot') 10 | 11 | from sklearn.decomposition import PCA 12 | 13 | from gensim.test.utils import datapath, get_tmpfile 14 | from gensim.models import KeyedVectors 15 | from gensim.scripts.glove2word2vec import glove2word2vec 16 | from mpl_toolkits.mplot3d import Axes3D 17 | 18 | glove_file = datapath('C:/Users/ASUS/glove.6B.100d.txt') 19 | word2vec_glove_file = get_tmpfile("glove.6B.100d.word2vec.txt") 20 | glove2word2vec(glove_file, word2vec_glove_file) 21 | 22 | model = KeyedVectors.load_word2vec_format(word2vec_glove_file) 23 | 24 | filename = 'glove2word2vec_model.sav' 25 | pickle.dump(model, open(filename, 'wb')) -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Word Embedding Visualization 2 | 3 | This is a repo to visualize the word embedding in 2D or 3D with either Principal Component Analysis (PCA) or t-Distributed Stochastic Neighbor Embedding (t-SNE). 4 | 5 | Below is the snapshot of the web app to visualize the word embedding. 6 | 7 |

![Word embedding visualization demo](word_embedding_gif.gif)
10 | 11 | ## Files 12 | 13 | - train_model.py: Python file to load the pre-trained GloVe word embedding model. 14 | - app.py: Python file to create the word embedding visualization web app. 15 | - glove2word2vec_model.sav: Saved pre-trained word embedding model. 16 | 17 | To execute the web app, go to the working directory of the app.py and type the following command in the conda environment: 18 | ``` 19 | streamlit run app.py 20 | ``` 21 | -------------------------------------------------------------------------------- /app.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Sat Aug 15 17:38:30 2020 4 | 5 | @author: ASUS 6 | """ 7 | # Import dependencies 8 | import plotly 9 | import plotly.graph_objs as go 10 | import numpy as np 11 | import pickle 12 | import streamlit as st 13 | from sklearn.decomposition import PCA 14 | from sklearn.manifold import TSNE 15 | 16 | filename = 'glove2word2vec_model.sav' 17 | model = pickle.load(open(filename, 'rb')) 18 | 19 | @st.cache 20 | def append_list(sim_words, words): 21 | 22 | list_of_words = [] 23 | 24 | for i in range(len(sim_words)): 25 | 26 | sim_words_list = list(sim_words[i]) 27 | sim_words_list.append(words) 28 | sim_words_tuple = tuple(sim_words_list) 29 | list_of_words.append(sim_words_tuple) 30 | 31 | return list_of_words 32 | 33 | 34 | def display_scatterplot_3D(model, user_input=None, words=None, label=None, color_map=None, annotation='On', dim_red = 'PCA', perplexity = 0, learning_rate = 0, iteration = 0, topn=0, sample=10): 35 | 36 | if words == None: 37 | if sample > 0: 38 | words = np.random.choice(list(model.vocab.keys()), sample) 39 | else: 40 | words = [ word for word in model.vocab ] 41 | 42 | word_vectors = np.array([model[w] for w in words]) 43 | 44 | if dim_red == 'PCA': 45 | three_dim = PCA(random_state=0).fit_transform(word_vectors)[:,:3] 46 | else: 47 | three_dim = TSNE(n_components = 3, random_state=0, perplexity 
= perplexity, learning_rate = learning_rate, n_iter = iteration).fit_transform(word_vectors)[:,:3] 48 | 49 | color = 'blue' 50 | quiver = go.Cone( 51 | x = [0,0,0], 52 | y = [0,0,0], 53 | z = [0,0,0], 54 | u = [1.5,0,0], 55 | v = [0,1.5,0], 56 | w = [0,0,1.5], 57 | anchor = "tail", 58 | colorscale = [[0, color] , [1, color]], 59 | showscale = False 60 | ) 61 | 62 | data = [quiver] 63 | 64 | count = 0 65 | for i in range (len(user_input)): 66 | 67 | trace = go.Scatter3d( 68 | x = three_dim[count:count+topn,0], 69 | y = three_dim[count:count+topn,1], 70 | z = three_dim[count:count+topn,2], 71 | text = words[count:count+topn] if annotation == 'On' else '', 72 | name = user_input[i], 73 | textposition = "top center", 74 | textfont_size = 30, 75 | mode = 'markers+text', 76 | marker = { 77 | 'size': 10, 78 | 'opacity': 0.8, 79 | 'color': 2 80 | } 81 | 82 | ) 83 | 84 | data.append(trace) 85 | count = count+topn 86 | 87 | trace_input = go.Scatter3d( 88 | x = three_dim[count:,0], 89 | y = three_dim[count:,1], 90 | z = three_dim[count:,2], 91 | text = words[count:], 92 | name = 'input words', 93 | textposition = "top center", 94 | textfont_size = 30, 95 | mode = 'markers+text', 96 | marker = { 97 | 'size': 10, 98 | 'opacity': 1, 99 | 'color': 'black' 100 | } 101 | ) 102 | 103 | data.append(trace_input) 104 | 105 | # Configure the layout. 
106 | layout = go.Layout( 107 | margin = {'l': 0, 'r': 0, 'b': 0, 't': 0}, 108 | showlegend=True, 109 | legend=dict( 110 | x=1, 111 | y=0.5, 112 | font=dict( 113 | family="Courier New", 114 | size=25, 115 | color="black" 116 | )), 117 | font = dict( 118 | family = " Courier New ", 119 | size = 15), 120 | autosize = False, 121 | width = 1000, 122 | height = 1000 123 | ) 124 | 125 | 126 | plot_figure = go.Figure(data = data, layout = layout) 127 | 128 | st.plotly_chart(plot_figure) 129 | 130 | def horizontal_bar(word, similarity): 131 | 132 | similarity = [ round(elem, 2) for elem in similarity ] 133 | 134 | data = go.Bar( 135 | x= similarity, 136 | y= word, 137 | orientation='h', 138 | text = similarity, 139 | marker_color= 4, 140 | textposition='auto') 141 | 142 | layout = go.Layout( 143 | font = dict(size=20), 144 | xaxis = dict(showticklabels=False, automargin=True), 145 | yaxis = dict(showticklabels=True, automargin=True,autorange="reversed"), 146 | margin = dict(t=20, b= 20, r=10) 147 | ) 148 | 149 | plot_figure = go.Figure(data = data, layout = layout) 150 | st.plotly_chart(plot_figure) 151 | 152 | def display_scatterplot_2D(model, user_input=None, words=None, label=None, color_map=None, annotation='On', dim_red = 'PCA', perplexity = 0, learning_rate = 0, iteration = 0, topn=0, sample=10): 153 | 154 | if words == None: 155 | if sample > 0: 156 | words = np.random.choice(list(model.vocab.keys()), sample) 157 | else: 158 | words = [ word for word in model.vocab ] 159 | 160 | word_vectors = np.array([model[w] for w in words]) 161 | 162 | if dim_red == 'PCA': 163 | two_dim = PCA(random_state=0).fit_transform(word_vectors)[:,:2] 164 | else: 165 | two_dim = TSNE(random_state=0, perplexity = perplexity, learning_rate = learning_rate, n_iter = iteration).fit_transform(word_vectors)[:,:2] 166 | 167 | 168 | data = [] 169 | count = 0 170 | for i in range (len(user_input)): 171 | 172 | trace = go.Scatter( 173 | x = two_dim[count:count+topn,0], 174 | y = 
two_dim[count:count+topn,1], 175 | text = words[count:count+topn] if annotation == 'On' else '', 176 | name = user_input[i], 177 | textposition = "top center", 178 | textfont_size = 20, 179 | mode = 'markers+text', 180 | marker = { 181 | 'size': 15, 182 | 'opacity': 0.8, 183 | 'color': 2 184 | } 185 | 186 | ) 187 | 188 | data.append(trace) 189 | count = count+topn 190 | 191 | trace_input = go.Scatter( 192 | x = two_dim[count:,0], 193 | y = two_dim[count:,1], 194 | text = words[count:], 195 | name = 'input words', 196 | textposition = "top center", 197 | textfont_size = 20, 198 | mode = 'markers+text', 199 | marker = { 200 | 'size': 25, 201 | 'opacity': 1, 202 | 'color': 'black' 203 | } 204 | ) 205 | 206 | data.append(trace_input) 207 | 208 | # Configure the layout. 209 | layout = go.Layout( 210 | margin = {'l': 0, 'r': 0, 'b': 0, 't': 0}, 211 | showlegend=True, 212 | hoverlabel=dict( 213 | bgcolor="white", 214 | font_size=20, 215 | font_family="Courier New"), 216 | legend=dict( 217 | x=1, 218 | y=0.5, 219 | font=dict( 220 | family="Courier New", 221 | size=25, 222 | color="black" 223 | )), 224 | font = dict( 225 | family = " Courier New ", 226 | size = 15), 227 | autosize = False, 228 | width = 1000, 229 | height = 1000 230 | ) 231 | 232 | 233 | plot_figure = go.Figure(data = data, layout = layout) 234 | 235 | st.plotly_chart(plot_figure) 236 | 237 | dim_red = st.sidebar.selectbox( 238 | 'Select dimension reduction method', 239 | ('PCA','TSNE')) 240 | dimension = st.sidebar.radio( 241 | "Select the dimension of the visualization", 242 | ('2D', '3D')) 243 | user_input = st.sidebar.text_input("Type the word that you want to investigate. 
You can type more than one word by separating one word with other with comma (,)",'') 244 | top_n = st.sidebar.slider('Select the amount of words associated with the input words you want to visualize ', 245 | 5, 100, (5)) 246 | annotation = st.sidebar.radio( 247 | "Enable or disable the annotation on the visualization", 248 | ('On', 'Off')) 249 | 250 | if dim_red == 'TSNE': 251 | perplexity = st.sidebar.slider('Adjust the perplexity. The perplexity is related to the number of nearest neighbors that is used in other manifold learning algorithms. Larger datasets usually require a larger perplexity', 252 | 5, 50, (30)) 253 | 254 | learning_rate = st.sidebar.slider('Adjust the learning rate', 255 | 10, 1000, (200)) 256 | 257 | iteration = st.sidebar.slider('Adjust the number of iteration', 258 | 250, 100000, (1000)) 259 | 260 | else: 261 | perplexity = 0 262 | learning_rate = 0 263 | iteration = 0 264 | 265 | if user_input == '': 266 | 267 | similar_word = None 268 | labels = None 269 | color_map = None 270 | 271 | else: 272 | 273 | user_input = [x.strip() for x in user_input.split(',')] 274 | result_word = [] 275 | 276 | for words in user_input: 277 | 278 | sim_words = model.most_similar(words, topn = top_n) 279 | sim_words = append_list(sim_words, words) 280 | 281 | result_word.extend(sim_words) 282 | 283 | similar_word = [word[0] for word in result_word] 284 | similarity = [word[1] for word in result_word] 285 | similar_word.extend(user_input) 286 | labels = [word[2] for word in result_word] 287 | label_dict = dict([(y,x+1) for x,y in enumerate(set(labels))]) 288 | color_map = [label_dict[x] for x in labels] 289 | 290 | 291 | st.title('Word Embedding Visualization Based on Cosine Similarity') 292 | 293 | st.header('This is a web app to visualize the word embedding.') 294 | st.markdown('First, choose which dimension of visualization that you want to see. There are two options: 2D and 3D.') 295 | 296 | st.markdown('Next, type the word that you want to investigate. 
You can type more than one word by separating one word with other with comma (,).') 297 | 298 | st.markdown('With the slider in the sidebar, you can pick the amount of words associated with the input word you want to visualize. This is done by computing the cosine similarity between vectors of words in embedding space.') 299 | st.markdown('Lastly, you have an option to enable or disable the text annotation in the visualization.') 300 | 301 | if dimension == '2D': 302 | st.header('2D Visualization') 303 | st.write('For more detail about each point (just in case it is difficult to read the annotation), you can hover around each points to see the words. You can expand the visualization by clicking expand symbol in the top right corner of the visualization.') 304 | display_scatterplot_2D(model, user_input, similar_word, labels, color_map, annotation, dim_red, perplexity, learning_rate, iteration, top_n) 305 | else: 306 | st.header('3D Visualization') 307 | st.write('For more detail about each point (just in case it is difficult to read the annotation), you can hover around each points to see the words. You can expand the visualization by clicking expand symbol in the top right corner of the visualization.') 308 | display_scatterplot_3D(model, user_input, similar_word, labels, color_map, annotation, dim_red, perplexity, learning_rate, iteration, top_n) 309 | 310 | st.header('The Top 5 Most Similar Words for Each Input') 311 | count=0 312 | for i in range (len(user_input)): 313 | 314 | st.write('The most similar words from '+str(user_input[i])+' are:') 315 | horizontal_bar(similar_word[count:count+5], similarity[count:count+5]) 316 | 317 | count = count+top_n 318 | --------------------------------------------------------------------------------