├── LICENSE ├── README.md ├── appendix ├── keras_cnn.ipynb ├── seq2seq_nmt.ipynb └── tensorboard_word_embeddings.ipynb ├── ch10 ├── bleu_score_example.ipynb ├── neural_machine_translation.ipynb ├── neural_machine_translation_attention.ipynb ├── nmt_with_pretrained_wordvecs.ipynb └── word2vec.py ├── ch11 └── tv_embeddings.ipynb ├── ch2 ├── tensorflow_introduction.ipynb ├── test1.txt ├── test2.txt └── test3.txt ├── ch3 ├── ch3_word2vec.ipynb └── ch3_wordnet.ipynb ├── ch4 ├── ch4_document_embedding.ipynb ├── ch4_glove.ipynb ├── ch4_word2vec_extended.ipynb └── ch4_word2vec_improvements.ipynb ├── ch5 ├── cnn_sentence_classification.ipynb └── image_classification_mnist.ipynb ├── ch6 ├── rnn_language_bigram.ipynb └── rnn_language_bigram_multilayer.ipynb ├── ch8 ├── embeddings.npy ├── lstm_extensions.ipynb ├── lstm_word2vec.ipynb ├── lstm_word2vec_rnn_api.ipynb ├── lstms_for_text_generation.ipynb ├── plot_perplexity_over_time.ipynb └── word2vec.py └── ch9 ├── correct_spellings.py ├── image_caption_data └── class_names.txt ├── lstm_image_caption.ipynb ├── lstm_image_caption_pretrained_wordvecs_rnn_api.ipynb └── word2vec.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 Packt 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | # Natural Language Processing with TensorFlow 5 | This is the code repository for [Natural Language Processing with TensorFlow](https://www.packtpub.com/application-development/natural-language-processing-tensorflow?utm_source=github&utm_medium=repository&utm_campaign=9781788478311), published by [Packt](https://www.packtpub.com/?utm_source=github). It contains all the supporting project files necessary to work through the book from start to finish. 6 | ## About the Book 7 | Natural language processing (NLP) supplies the majority of data available to deep learning applications, while TensorFlow is the most important deep learning framework currently available. Natural Language Processing with TensorFlow brings TensorFlow and NLP together to give you invaluable tools to work with the immense volume of unstructured data in today’s data streams, and apply these tools to specific NLP tasks. 
8 | 9 | Thushan Ganegedara starts by giving you a grounding in NLP and TensorFlow basics. You'll then learn how to use Word2vec, including advanced extensions, to create word embeddings that turn sequences of words into vectors accessible to deep learning algorithms. Chapters on classical deep learning algorithms, like convolutional neural networks (CNN) and recurrent neural networks (RNN), demonstrate important NLP tasks such as sentence classification and language generation. You will learn how to apply high-performance RNN models, like long short-term memory (LSTM) cells, to NLP tasks. You will also explore neural machine translation and implement a neural machine translator. 10 | 11 | After reading this book, you will have an understanding of NLP and the skills to apply TensorFlow in deep learning NLP applications and to perform specific NLP tasks. 12 | 13 | ## Instructions and Navigations 14 | All of the code is organized into folders named by chapter, for example, ch2, ch3, and so on, plus an appendix folder. 15 | 16 | 17 | 18 | The code will look like the following (a runnable, slightly expanded sketch of this snippet appears at the end of this README): 19 | ``` 20 | graph = tf.Graph() # Creates a graph 21 | session = tf.InteractiveSession(graph=graph) # Creates a session 22 | ``` 23 | 24 | To get the most out of this book, we assume the following from the reader: 25 | * A solid will and an ambition to learn the modern ways of NLP 26 | * Familiarity with basic Python syntax and data structures (for example, lists and dictionaries) 27 | * A good understanding of basic mathematics (for example, matrix/vector multiplication) 28 | * (Optional) Advanced mathematics knowledge (for example, derivative calculation) to understand a handful of subsections that cover the details of how certain learning models overcome potential practical issues faced during training 29 | * (Optional) Willingness to read research papers for advances/details in systems beyond what the book covers 30 | 31 | ## Related Products 32 | * [Hands-On Deep Learning with TensorFlow](https://www.packtpub.com/big-data-and-business-intelligence/hands-deep-learning-tensorflow?utm_source=github&utm_medium=repository&utm_campaign=9781787282773) 33 | 34 | * [Deep Learning with TensorFlow - Second Edition](https://www.packtpub.com/big-data-and-business-intelligence/deep-learning-tensorflow-second-edition?utm_source=github&utm_medium=repository&utm_campaign=9781788831109) 35 | 36 | * [Beginning Application Development with TensorFlow and Keras](https://www.packtpub.com/application-development/beginning-application-development-tensorflow-and-keras?utm_source=github&utm_medium=repository&utm_campaign=9781789537291) 37 | ### Download a free PDF 38 | 39 | If you have already purchased a print or Kindle version of this book, you can get a DRM-free PDF version at no cost.
Simply click on the link to claim your free PDF.
40 | https://packt.link/free-ebook/9781788478311
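### Expanded example of the code snippet

The two-line snippet under *Instructions and Navigations* only creates a graph and a session. The following is a minimal, runnable sketch (not part of the original repository) of how that pair is typically used; it assumes TensorFlow 1.x, which the book's notebooks target, and the names `x`, `W`, and `h` are illustrative:

```
import tensorflow as tf  # TensorFlow 1.x

graph = tf.Graph()                            # Creates a graph
session = tf.InteractiveSession(graph=graph)  # Creates a session (and makes `graph` the default graph)

# Ops defined from here on are added to `graph`
x = tf.placeholder(tf.float32, shape=[1, 3], name='x')   # input placeholder
W = tf.Variable(tf.ones(shape=[3, 1]), name='W')         # a trainable variable
h = tf.matmul(x, W)                                      # a simple operation on the graph

tf.global_variables_initializer().run()                  # initialize W in this session
print(session.run(h, feed_dict={x: [[1.0, 2.0, 3.0]]}))  # prints [[6.]]
session.close()
```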
-------------------------------------------------------------------------------- /appendix/keras_cnn.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [ 8 | { 9 | "name": "stderr", 10 | "output_type": "stream", 11 | "text": [ 12 | "c:\\users\\thushan\\documents\\python_virtualenvs\\tensorflow_venv\\lib\\site-packages\\h5py\\__init__.py:36: FutureWarning: Conversion of the second argument of issubdtype from `float` to `np.floating` is deprecated. In future, it will be treated as `np.float64 == np.dtype(float).type`.\n", 13 | " from ._conv import register_converters as _register_converters\n" 14 | ] 15 | } 16 | ], 17 | "source": [ 18 | "import tensorflow as tf\n", 19 | "import numpy as np\n", 20 | "from tensorflow.python.keras.models import Sequential\n", 21 | "from tensorflow.python.keras.layers import Conv2D,Dense,MaxPool2D,BatchNormalization,Flatten\n", 22 | "from tensorflow.python.keras import backend as K\n", 23 | "\n", 24 | "# Required for Data downaload and preparation\n", 25 | "import struct\n", 26 | "import gzip\n", 27 | "import os\n", 28 | "from six.moves.urllib.request import urlretrieve" 29 | ] 30 | }, 31 | { 32 | "cell_type": "markdown", 33 | "metadata": {}, 34 | "source": [ 35 | "## Lolading Data\n", 36 | "\n", 37 | "Here we download (if needed) the MNIST dataset and, perform reshaping and normalization. Also we conver the labels to one hot encoded vectors." 38 | ] 39 | }, 40 | { 41 | "cell_type": "code", 42 | "execution_count": 2, 43 | "metadata": {}, 44 | "outputs": [ 45 | { 46 | "name": "stdout", 47 | "output_type": "stream", 48 | "text": [ 49 | "Found and verified train-images-idx3-ubyte.gz\n", 50 | "Found and verified train-labels-idx1-ubyte.gz\n", 51 | "Found and verified t10k-images-idx3-ubyte.gz\n", 52 | "Found and verified t10k-labels-idx1-ubyte.gz\n", 53 | "\n", 54 | "Reading files train-images-idx3-ubyte.gz and train-labels-idx1-ubyte.gz\n", 55 | "60000 28 28\n", 56 | "(Images) Returned a tensor of shape (60000, 28, 28, 1)\n", 57 | "(Labels) Returned a tensor of shape: 60000\n", 58 | "Sample labels: [5 0 4 1 9 2 1 3 1 4]\n", 59 | "\n", 60 | "Reading files t10k-images-idx3-ubyte.gz and t10k-labels-idx1-ubyte.gz\n", 61 | "10000 28 28\n", 62 | "(Images) Returned a tensor of shape (10000, 28, 28, 1)\n", 63 | "(Labels) Returned a tensor of shape: 10000\n", 64 | "Sample labels: [7 2 1 0 4 1 4 9 5 9]\n", 65 | "\n", 66 | "Train size: 60000\n", 67 | "\n", 68 | "Test size: 10000\n" 69 | ] 70 | } 71 | ], 72 | "source": [ 73 | "def maybe_download(url, filename, expected_bytes, force=False):\n", 74 | " \"\"\"Download a file if not present, and make sure it's the right size.\"\"\"\n", 75 | " if force or not os.path.exists(filename):\n", 76 | " print('Attempting to download:', filename) \n", 77 | " filename, _ = urlretrieve(url + filename, filename)\n", 78 | " print('\\nDownload Complete!')\n", 79 | " statinfo = os.stat(filename)\n", 80 | " if statinfo.st_size == expected_bytes:\n", 81 | " print('Found and verified', filename)\n", 82 | " else:\n", 83 | " raise Exception(\n", 84 | " 'Failed to verify ' + filename + '. 
Can you get to it with a browser?')\n", 85 | " return filename\n", 86 | "\n", 87 | "\n", 88 | "def read_mnist(fname_img, fname_lbl, one_hot=False):\n", 89 | " print('\\nReading files %s and %s'%(fname_img, fname_lbl))\n", 90 | " \n", 91 | " # Processing images\n", 92 | " with gzip.open(fname_img) as fimg: \n", 93 | " magic, num, rows, cols = struct.unpack(\">IIII\", fimg.read(16))\n", 94 | " print(num,rows,cols)\n", 95 | " img = (np.frombuffer(fimg.read(num*rows*cols), dtype=np.uint8).reshape(num, rows, cols,1)).astype(np.float32)\n", 96 | " print('(Images) Returned a tensor of shape ',img.shape)\n", 97 | " \n", 98 | " #img = (img - np.mean(img)) /np.std(img)\n", 99 | " img *= 1.0 / 255.0\n", 100 | " \n", 101 | " # Processing labels\n", 102 | " with gzip.open(fname_lbl) as flbl:\n", 103 | " # flbl.read(8) reads upto 8 bytes\n", 104 | " magic, num = struct.unpack(\">II\", flbl.read(8)) \n", 105 | " lbl = np.frombuffer(flbl.read(num), dtype=np.int8)\n", 106 | " if one_hot:\n", 107 | " one_hot_lbl = np.zeros(shape=(num,10),dtype=np.float32)\n", 108 | " one_hot_lbl[np.arange(num),lbl] = 1.0\n", 109 | " print('(Labels) Returned a tensor of shape: %s'%lbl.shape)\n", 110 | " print('Sample labels: ',lbl[:10])\n", 111 | " \n", 112 | " if not one_hot:\n", 113 | " return img, lbl\n", 114 | " else:\n", 115 | " return img, one_hot_lbl\n", 116 | " \n", 117 | " \n", 118 | "# Download data if needed\n", 119 | "url = 'http://yann.lecun.com/exdb/mnist/'\n", 120 | "# training data\n", 121 | "maybe_download(url,'train-images-idx3-ubyte.gz',9912422)\n", 122 | "maybe_download(url,'train-labels-idx1-ubyte.gz',28881)\n", 123 | "# testing data\n", 124 | "maybe_download(url,'t10k-images-idx3-ubyte.gz',1648877)\n", 125 | "maybe_download(url,'t10k-labels-idx1-ubyte.gz',4542)\n", 126 | "\n", 127 | "# Read the training and testing data \n", 128 | "train_inputs, train_labels = read_mnist('train-images-idx3-ubyte.gz', 'train-labels-idx1-ubyte.gz',True)\n", 129 | "test_inputs, test_labels = read_mnist('t10k-images-idx3-ubyte.gz', 't10k-labels-idx1-ubyte.gz',True)\n", 130 | "\n", 131 | "\n", 132 | "print('\\nTrain size: ', train_inputs.shape[0])\n", 133 | "print('\\nTest size: ', test_inputs.shape[0])" 134 | ] 135 | }, 136 | { 137 | "cell_type": "markdown", 138 | "metadata": {}, 139 | "source": [ 140 | "## Data Generators for MNIST\n", 141 | "\n", 142 | "Here we have the logic to iterate through each training, validation and testing datasets, in `batch_size` size strides." 
143 | ] 144 | }, 145 | { 146 | "cell_type": "code", 147 | "execution_count": 3, 148 | "metadata": { 149 | "collapsed": true 150 | }, 151 | "outputs": [], 152 | "source": [ 153 | "train_index, test_index = 0,0\n", 154 | "\n", 155 | "def get_train_batch(images, labels, batch_size):\n", 156 | " global train_index\n", 157 | " batch = images[train_index:train_index+batch_size,:,:,:], labels[train_index:train_index+batch_size,:]\n", 158 | " train_index = (train_index + batch_size)%(images.shape[0] - batch_size)\n", 159 | " return batch\n", 160 | "\n", 161 | "\n", 162 | "def get_test_batch(images, labels, batch_size):\n", 163 | " global test_index\n", 164 | " batch = images[test_index:test_index+batch_size,:,:,:], labels[test_index:test_index+batch_size,:]\n", 165 | " test_index = (test_index + batch_size)%(images.shape[0] - batch_size)\n", 166 | " return batch" 167 | ] 168 | }, 169 | { 170 | "cell_type": "code", 171 | "execution_count": 4, 172 | "metadata": { 173 | "collapsed": true 174 | }, 175 | "outputs": [], 176 | "source": [ 177 | "config = tf.ConfigProto(allow_soft_placement=True)\n", 178 | "# Good practice to use this to avoid any surprising errors thrown by TensorFlow\n", 179 | "config.gpu_options.allow_growth = True \n", 180 | "config.gpu_options.per_process_gpu_memory_fraction = 0.9 # Making sure Tensorflow doesn't overflow the GPU\n", 181 | "sess = tf.Session(config=config)\n", 182 | "K.set_session(sess)" 183 | ] 184 | }, 185 | { 186 | "cell_type": "code", 187 | "execution_count": 5, 188 | "metadata": { 189 | "collapsed": true 190 | }, 191 | "outputs": [], 192 | "source": [ 193 | "# Define a sequential model\n", 194 | "model = Sequential()\n", 195 | "\n", 196 | "# Added a convolution layer\n", 197 | "model.add(Conv2D(32, 3, activation='relu', input_shape=[28, 28, 1]))\n", 198 | "\n", 199 | "# Add a max pool lyer\n", 200 | "model.add(MaxPool2D())\n", 201 | "\n", 202 | "# Add a batch norm layer\n", 203 | "model.add(BatchNormalization())\n", 204 | "\n", 205 | "# Convolution layer\n", 206 | "model.add(Conv2D(64, 3, activation='relu'))\n", 207 | "# Max pool layer\n", 208 | "model.add(MaxPool2D())\n", 209 | "# Add a batch norm layer\n", 210 | "model.add(BatchNormalization())\n", 211 | "\n", 212 | "# More convolution, max pool, batch norm\n", 213 | "model.add(Conv2D(128, 3, activation='relu'))\n", 214 | "model.add(MaxPool2D())\n", 215 | "model.add(BatchNormalization())\n", 216 | "\n", 217 | "model.add(Flatten())\n", 218 | "\n", 219 | "model.add(Dense(256, activation='relu'))\n", 220 | "model.add(Dense(10, activation='softmax'))\n", 221 | "\n", 222 | "model.compile(\n", 223 | " optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy']\n", 224 | ")" 225 | ] 226 | }, 227 | { 228 | "cell_type": "code", 229 | "execution_count": 6, 230 | "metadata": {}, 231 | "outputs": [ 232 | { 233 | "name": "stdout", 234 | "output_type": "stream", 235 | "text": [ 236 | "Training for epoch: 0\n", 237 | "Epoch 1/1\n", 238 | "60000/60000 [==============================] - 13s 210us/step - loss: 0.1245 - acc: 0.9628\n", 239 | "10000/10000 [==============================] - 1s 67us/step\n", 240 | "\tEpoch ( 0 ) Test accuracy: 0.9757000070810318\n", 241 | "Training for epoch: 1\n", 242 | "Epoch 1/1\n", 243 | "60000/60000 [==============================] - 11s 182us/step - loss: 0.0452 - acc: 0.9859\n", 244 | "10000/10000 [==============================] - 1s 55us/step\n", 245 | "\tEpoch ( 1 ) Test accuracy: 0.9827000075578689\n", 246 | "Training for epoch: 2\n", 247 | "Epoch 1/1\n", 248 | "60000/60000 
[==============================] - 11s 186us/step - loss: 0.0294 - acc: 0.9910\n", 249 | "10000/10000 [==============================] - 1s 56us/step\n", 250 | "\tEpoch ( 2 ) Test accuracy: 0.9847000110149383\n", 251 | "Training for epoch: 3\n", 252 | "Epoch 1/1\n", 253 | "60000/60000 [==============================] - 11s 184us/step - loss: 0.0255 - acc: 0.9918\n", 254 | "10000/10000 [==============================] - 1s 56us/step\n", 255 | "\tEpoch ( 3 ) Test accuracy: 0.9859000080823899\n", 256 | "Training for epoch: 4\n", 257 | "Epoch 1/1\n", 258 | "60000/60000 [==============================] - 11s 182us/step - loss: 0.0197 - acc: 0.9937\n", 259 | "10000/10000 [==============================] - 1s 56us/step\n", 260 | "\tEpoch ( 4 ) Test accuracy: 0.9872000062465668\n" 261 | ] 262 | } 263 | ], 264 | "source": [ 265 | "n_epochs = 5 # Number of epochs the training runs for\n", 266 | "\n", 267 | "n_train = 55000\n", 268 | "n_test = 10000\n", 269 | "\n", 270 | "batch_size = 100\n", 271 | "\n", 272 | "x_train, y_train = train_inputs, train_labels\n", 273 | " \n", 274 | "x_test, y_test = test_inputs, test_labels\n", 275 | " \n", 276 | "for epoch in range(n_epochs):\n", 277 | "\n", 278 | " print('Training for epoch: ',epoch) \n", 279 | " \n", 280 | " # Training for a single epoch\n", 281 | " model.fit(x_train, y_train, batch_size = batch_size)\n", 282 | " \n", 283 | " # Testing phase\n", 284 | " # Returns a list where first item is loss and second is accuracy\n", 285 | " test_acc = model.evaluate(x_test, y_test, batch_size=batch_size) \n", 286 | " print('\\tEpoch (', epoch ,') Test accuracy: ',test_acc[1])" 287 | ] 288 | }, 289 | { 290 | "cell_type": "code", 291 | "execution_count": null, 292 | "metadata": { 293 | "collapsed": true 294 | }, 295 | "outputs": [], 296 | "source": [] 297 | } 298 | ], 299 | "metadata": { 300 | "kernelspec": { 301 | "display_name": "Python 3", 302 | "language": "python", 303 | "name": "python3" 304 | }, 305 | "language_info": { 306 | "codemirror_mode": { 307 | "name": "ipython", 308 | "version": 3 309 | }, 310 | "file_extension": ".py", 311 | "mimetype": "text/x-python", 312 | "name": "python", 313 | "nbconvert_exporter": "python", 314 | "pygments_lexer": "ipython3", 315 | "version": "3.5.2" 316 | } 317 | }, 318 | "nbformat": 4, 319 | "nbformat_minor": 2 320 | } 321 | -------------------------------------------------------------------------------- /appendix/tensorboard_word_embeddings.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Visualizing Word Embeddings on the Tensorboard" 8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": 17, 13 | "metadata": { 14 | "collapsed": true 15 | }, 16 | "outputs": [], 17 | "source": [ 18 | "import numpy as np\n", 19 | "import tensorflow as tf\n", 20 | "import os\n", 21 | "import zipfile\n", 22 | "from tensorflow.contrib.tensorboard.plugins import projector\n", 23 | "import csv\n" 24 | ] 25 | }, 26 | { 27 | "cell_type": "markdown", 28 | "metadata": {}, 29 | "source": [ 30 | "## Read the GloVe file\n", 31 | "\n", 32 | "Here we first need to download the GloVe word embeddings (`glove.6B.zip`) found at this [website](https://nlp.stanford.edu/projects/glove/). Then we read the GloVe file to get the first 50000 words in the file. 
We will be using 50 dimensional word vectors" 33 | ] 34 | }, 35 | { 36 | "cell_type": "code", 37 | "execution_count": 18, 38 | "metadata": {}, 39 | "outputs": [ 40 | { 41 | "name": "stdout", 42 | "output_type": "stream", 43 | "text": [ 44 | ".....\tDone\n" 45 | ] 46 | } 47 | ], 48 | "source": [ 49 | "vocabulary_size = 50000\n", 50 | "\n", 51 | "pret_embeddings = np.empty(shape=(vocabulary_size,50),dtype=np.float32)\n", 52 | "\n", 53 | "words = [] \n", 54 | "\n", 55 | "word_idx = 0\n", 56 | "# Open the zip file\n", 57 | "with zipfile.ZipFile('glove.6B.zip') as glovezip:\n", 58 | " # Read the file with 50 dimensional embeddings\n", 59 | " with glovezip.open('glove.6B.50d.txt') as glovefile:\n", 60 | " # Read line by line\n", 61 | " for li, line in enumerate(glovefile):\n", 62 | " # Print progress\n", 63 | " if (li+1)%10000==0: print('.',end='')\n", 64 | " \n", 65 | " # Get the word and the corresponding vector\n", 66 | " line_tokens = line.decode('utf-8').split(' ')\n", 67 | " word = line_tokens[0]\n", 68 | " vector = [float(v) for v in line_tokens[1:]]\n", 69 | " \n", 70 | " assert len(vector)==50\n", 71 | " words.append(word)\n", 72 | " # Update the embedding matrix\n", 73 | " pret_embeddings[word_idx,:] = np.array(vector)\n", 74 | " word_idx += 1\n", 75 | " # If the first 50000 words being read, finish\n", 76 | " if word_idx == vocabulary_size:\n", 77 | " break\n", 78 | " \n", 79 | "print('\\tDone')" 80 | ] 81 | }, 82 | { 83 | "cell_type": "markdown", 84 | "metadata": {}, 85 | "source": [ 86 | "## Create TensorFlow Variable\n", 87 | "\n", 88 | "Here we create a TensorFlow variable to store the embeddings we read above and save it to the disk. This is necessary for the visualization." 89 | ] 90 | }, 91 | { 92 | "cell_type": "code", 93 | "execution_count": 19, 94 | "metadata": { 95 | "collapsed": true 96 | }, 97 | "outputs": [], 98 | "source": [ 99 | "# Create a directory to save our model\n", 100 | "log_dir = 'models'\n", 101 | "if not os.path.exists(log_dir):\n", 102 | " os.mkdir(log_dir)\n", 103 | "\n", 104 | "tf.reset_default_graph()\n", 105 | "\n", 106 | "# Create a Tensorflow variable initialized with the word embedings we just read in\n", 107 | "embeddings = tf.get_variable('embeddings',shape=[vocabulary_size, 50],\n", 108 | " initializer=tf.constant_initializer(pret_embeddings))\n", 109 | "\n", 110 | "session = tf.InteractiveSession()\n", 111 | "\n", 112 | "tf.global_variables_initializer().run()\n", 113 | "\n", 114 | "# Define a saver, that will save the Tensorflow variables to a given location\n", 115 | "saver = tf.train.Saver({'embeddings':embeddings})\n", 116 | "# Save the file\n", 117 | "saver.save(session, os.path.join(log_dir, \"model.ckpt\"), 0)\n", 118 | "\n", 119 | "# Define metadata for word embeddings\n", 120 | "with open(os.path.join(log_dir,'metadata.tsv'), 'w',encoding='utf-8') as csvfile:\n", 121 | " writer = csv.writer(csvfile, delimiter='\\t',\n", 122 | " quotechar='|', quoting=csv.QUOTE_MINIMAL)\n", 123 | " writer.writerow(['Word','Word ID'])\n", 124 | " for wi,w in enumerate(words):\n", 125 | " writer.writerow([w,wi])" 126 | ] 127 | }, 128 | { 129 | "cell_type": "markdown", 130 | "metadata": {}, 131 | "source": [ 132 | "## Define the configuration to tell the Tensorboard where and what to look" 133 | ] 134 | }, 135 | { 136 | "cell_type": "code", 137 | "execution_count": 20, 138 | "metadata": { 139 | "collapsed": true 140 | }, 141 | "outputs": [], 142 | "source": [ 143 | "config = projector.ProjectorConfig()\n", 144 | "\n", 145 | "# You can add multiple 
embeddings. Here we add only one.\n", 146 | "embedding_config = config.embeddings.add()\n", 147 | "embedding_config.tensor_name = embeddings.name\n", 148 | "# Link this tensor to its metadata file (e.g. labels).\n", 149 | "embedding_config.metadata_path = 'metadata.tsv'\n", 150 | "\n", 151 | "# Use the same LOG_DIR where you stored your checkpoint.\n", 152 | "summary_writer = tf.summary.FileWriter(log_dir)\n", 153 | "\n", 154 | "# The next line writes a projector_config.pbtxt in the LOG_DIR. TensorBoard will\n", 155 | "# read this file during startup.\n", 156 | "projector.visualize_embeddings(summary_writer, config)" 157 | ] 158 | }, 159 | { 160 | "cell_type": "code", 161 | "execution_count": null, 162 | "metadata": { 163 | "collapsed": true 164 | }, 165 | "outputs": [], 166 | "source": [] 167 | } 168 | ], 169 | "metadata": { 170 | "kernelspec": { 171 | "display_name": "Python 3", 172 | "language": "python", 173 | "name": "python3" 174 | }, 175 | "language_info": { 176 | "codemirror_mode": { 177 | "name": "ipython", 178 | "version": 3 179 | }, 180 | "file_extension": ".py", 181 | "mimetype": "text/x-python", 182 | "name": "python", 183 | "nbconvert_exporter": "python", 184 | "pygments_lexer": "ipython3", 185 | "version": "3.5.2" 186 | } 187 | }, 188 | "nbformat": 4, 189 | "nbformat_minor": 2 190 | } 191 | -------------------------------------------------------------------------------- /ch10/bleu_score_example.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "## Performance Evaluation of MT\n", 8 | "### BLEU Score: Adequacy and Fluency of Translations\n", 9 | "\n", 10 | "1. Calculate the modified n-gram precision for the full corpus for all n=1...N\n", 11 | "2. Calculate the geometric mean (gm-precision) of all the precisions\n", 12 | "3. Calculate the brevity penalty (bp) for the full corpus\n", 13 | "3. 
Calculate the BLEU by bp * gm-precision\n", 14 | "\n", 15 | "Example Calculation:\n", 16 | "\n", 17 | "* Candidate1: the the the the the the the\n", 18 | "* Candidate2: the cat is on the mat\n", 19 | "* Ref1: the cat sat on the mat\n", 20 | "* Ref2: there is a cat on the mat\n", 21 | "\n", 22 | "#### Modified 1-gram Precision (Measures adequacy)\n", 23 | "* Candidate1: $\\frac{2}{7}$\n", 24 | "* Candidate2: $\\frac{5}{5}$\n", 25 | "\n", 26 | "#### Modified 2-gram Precision (Measures fluency)\n", 27 | "* Candidate1: $\\frac{0}{1}$\n", 28 | "* Candidate2: $\\frac{3}{5}$\n" 29 | ] 30 | }, 31 | { 32 | "cell_type": "code", 33 | "execution_count": 1, 34 | "metadata": { 35 | "collapsed": true 36 | }, 37 | "outputs": [], 38 | "source": [ 39 | "import numpy as np\n", 40 | "from collections import Counter" 41 | ] 42 | }, 43 | { 44 | "cell_type": "code", 45 | "execution_count": 11, 46 | "metadata": {}, 47 | "outputs": [ 48 | { 49 | "name": "stdout", 50 | "output_type": "stream", 51 | "text": [ 52 | "Calculating modified 1-gram precision\n", 53 | "\tReference sentence: the cat sat on the mat\n", 54 | "\t Candidate sentence: the the the\n", 55 | "\tReference sentence: there is a cat on the mat\n", 56 | "\t Candidate sentence: the a cat\n", 57 | "Calculating modified 2-gram precision\n", 58 | "\tReference sentence: the cat sat on the mat\n", 59 | "\t Candidate sentence: the the the\n", 60 | "\tReference sentence: there is a cat on the mat\n", 61 | "\t Candidate sentence: the a cat\n", 62 | "Calculating modified 3-gram precision\n", 63 | "\tReference sentence: the cat sat on the mat\n", 64 | "\t Candidate sentence: the the the\n", 65 | "\tReference sentence: there is a cat on the mat\n", 66 | "\t Candidate sentence: the a cat\n", 67 | "\n", 68 | "BLEU-3: 8.568589920310384e-35\n", 69 | "\n", 70 | "Calculating modified 1-gram precision\n", 71 | "\tReference sentence: the cat sat on the mat\n", 72 | "\t Candidate sentence: the dog on the mat\n", 73 | "\tReference sentence: there is a cat on the mat\n", 74 | "\t Candidate sentence: there is cat on the mat\n", 75 | "Calculating modified 2-gram precision\n", 76 | "\tReference sentence: the cat sat on the mat\n", 77 | "\t Candidate sentence: the dog on the mat\n", 78 | "\tReference sentence: there is a cat on the mat\n", 79 | "\t Candidate sentence: there is cat on the mat\n", 80 | "Calculating modified 3-gram precision\n", 81 | "\tReference sentence: the cat sat on the mat\n", 82 | "\t Candidate sentence: the dog on the mat\n", 83 | "\tReference sentence: there is a cat on the mat\n", 84 | "\t Candidate sentence: there is cat on the mat\n", 85 | "\n", 86 | "BLEU-3: 0.5319658954895262\n" 87 | ] 88 | } 89 | ], 90 | "source": [ 91 | "def unique_n_gram_string(n_gram):\n", 92 | " string = ''\n", 93 | " for g in n_gram[:-1]:\n", 94 | " string += str(g)+'-'\n", 95 | " \n", 96 | " string += str(n_gram[-1])\n", 97 | " return string\n", 98 | "\n", 99 | "def calculate_mod_n_gram_precision(n_gram, refs, cands):\n", 100 | " \n", 101 | " denominator = 0.0\n", 102 | " \n", 103 | " tot_bleu = 0.0\n", 104 | " \n", 105 | " tot_ref_length, tot_cand_length = 0, 0\n", 106 | " for ref, cand in zip(refs, cands):\n", 107 | " \n", 108 | " print('\\tReference sentence: ',' '.join([reverse_test_dict[r] for r in ref]))\n", 109 | " print('\\t Candidate sentence: ',' '.join([reverse_test_dict[c] for c in cand]))\n", 110 | "\n", 111 | " denominator += max(cand.size + 1 - n_gram,1)\n", 112 | " tot_ref_length += ref.size\n", 113 | " tot_cand_length += cand.size\n", 114 | " \n", 115 | " # find 
unique n-grams in predicted\n", 116 | " cand_n_grams = [unique_n_gram_string(cand[w_i:w_i+n_gram]) for w_i in range(cand.size + 1 - n_gram)]\n", 117 | " cand_n_grams = list(set(cand_n_grams))\n", 118 | "\n", 119 | " occurences_for_unique_grams = dict(zip(cand_n_grams,[0 for _ in cand_n_grams]))\n", 120 | "\n", 121 | " ref_n_grams = [unique_n_gram_string(ref[w_i:w_i+n_gram]) for w_i in range(ref.size + 1 - n_gram)]\n", 122 | " ref_counts = Counter(ref_n_grams)\n", 123 | "\n", 124 | " # iterates through every n_gram in the predicted\n", 125 | " for w_i in range(cand.size + 1 - n_gram): \n", 126 | " c_gram = cand[w_i:w_i+n_gram]\n", 127 | " gram_string = unique_n_gram_string(c_gram)\n", 128 | " \n", 129 | " for ref_i in range(ref.size + 1 - n_gram):\n", 130 | "\n", 131 | " r_gram = ref[ref_i:ref_i+n_gram]\n", 132 | "\n", 133 | " found_gram_in_actual = int(np.prod(c_gram == r_gram))\n", 134 | "\n", 135 | " occurences_for_unique_grams[gram_string] += found_gram_in_actual\n", 136 | "\n", 137 | " \n", 138 | " for g, occ in occurences_for_unique_grams.items():\n", 139 | " g_bleu = float(occ)\n", 140 | " if g in ref_counts:\n", 141 | " g_bleu = min(g_bleu,ref_counts[g])\n", 142 | "\n", 143 | " tot_bleu += g_bleu\n", 144 | "\n", 145 | " mod_n_prec = tot_bleu/denominator\n", 146 | " \n", 147 | " \n", 148 | " return mod_n_prec, tot_ref_length, tot_cand_length\n", 149 | "\n", 150 | "\n", 151 | "def calculate_bleu(refs, cands, high_n):\n", 152 | " weight = 1.0/high_n # using the same weight for all mod n_gram precisions\n", 153 | " \n", 154 | " tot_precision = []\n", 155 | " for n in range(1,high_n+1): \n", 156 | " print('Calculating modified %d-gram precision'%n)\n", 157 | " prec, tot_ref_length, tot_cand_length = calculate_mod_n_gram_precision(n,refs,cands)\n", 158 | " tot_precision.append(weight*np.log(prec + 1e-100))\n", 159 | " \n", 160 | " brevity_penalty = 1.0\n", 161 | " \n", 162 | " if tot_cand_length <= tot_ref_length:\n", 163 | " brevity_penalty = np.exp(1.0-(tot_ref_length*1.0/max(tot_cand_length,1)))\n", 164 | " \n", 165 | " bleu = brevity_penalty * np.exp(np.sum(tot_precision))\n", 166 | "\n", 167 | " return bleu\n", 168 | "\n", 169 | "test_dict = {'the':10,'cat':11,'sat':12,'on':13,'mat':14,'is':15,'there':16,'a':17,'dog':18}\n", 170 | "reverse_test_dict = dict(zip(test_dict.values(),test_dict.keys()))\n", 171 | "\n", 172 | "sample_text_refs = [['the','cat','sat','on','the','mat'],['there','is','a','cat','on','the','mat']]\n", 173 | "sample_refs = []\n", 174 | "for r in sample_text_refs:\n", 175 | " sample_refs.append(np.asarray([test_dict[w] for w in r],dtype=np.int32))\n", 176 | "\n", 177 | "sample_text_cands_1 = [['the','the','the'],['the','a','cat']]\n", 178 | "sample_cands_1 = []\n", 179 | "for c1 in sample_text_cands_1:\n", 180 | " sample_cands_1.append(np.asarray([test_dict[w] for w in c1],dtype=np.int32))\n", 181 | "\n", 182 | "\n", 183 | "sample_text_cands_2 = [['the','dog','on','the','mat'],['there','is','cat','on','the','mat']]\n", 184 | "sample_cands_2 = []\n", 185 | "for c2 in sample_text_cands_2:\n", 186 | " sample_cands_2.append(np.asarray([test_dict[w] for w in c2],dtype=np.int32))\n", 187 | "\n", 188 | "\n", 189 | "b1 = calculate_bleu(sample_refs,sample_cands_1,3)\n", 190 | "print('\\nBLEU-3: ',b1)\n", 191 | "print()\n", 192 | "\n", 193 | "b2 = calculate_bleu(sample_refs,sample_cands_2,3)\n", 194 | "print('\\nBLEU-3: ',b2)" 195 | ] 196 | }, 197 | { 198 | "cell_type": "code", 199 | "execution_count": null, 200 | "metadata": { 201 | "collapsed": true 202 | }, 203 | 
"outputs": [], 204 | "source": [] 205 | } 206 | ], 207 | "metadata": { 208 | "kernelspec": { 209 | "display_name": "Python 3", 210 | "language": "python", 211 | "name": "python3" 212 | }, 213 | "language_info": { 214 | "codemirror_mode": { 215 | "name": "ipython", 216 | "version": 3 217 | }, 218 | "file_extension": ".py", 219 | "mimetype": "text/x-python", 220 | "name": "python", 221 | "nbconvert_exporter": "python", 222 | "pygments_lexer": "ipython3", 223 | "version": "3.5.2" 224 | } 225 | }, 226 | "nbformat": 4, 227 | "nbformat_minor": 2 228 | } 229 | -------------------------------------------------------------------------------- /ch10/word2vec.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as np 3 | import math 4 | sentence_cursors = None 5 | tot_sentences = None 6 | src_max_sent_length, tgt_max_sent_length = 0, 0 7 | src_dictionary, tgt_dictionary = {}, {} 8 | src_reverse_dictionary, tgt_reverse_dictionary = {},{} 9 | train_inputs, train_outputs = None, None 10 | embedding_size = None # Dimension of the embedding vector. 11 | vocabulary_size = None 12 | def define_data_and_hyperparameters( 13 | _tot_sentences, _src_max, _tgt_max, _src_dict, _tgt_dict, 14 | _src_rev_dict, _tgt_rev_dict, _tr_inp, _tr_out, _emb_size, _vocab_size): 15 | global tot_sentences, sentence_cursors 16 | global src_max_sent_length, tgt_max_sent_length 17 | global src_dictionary, tgt_dictionary 18 | global src_reverse_dictionary, tgt_reverse_dictionary 19 | global train_inputs, train_outputs 20 | global embedding_size, vocabulary_size 21 | 22 | embedding_size = _emb_size 23 | vocabulary_size = _vocab_size 24 | src_max_sent_length, tgt_max_sent_length = _src_max, _tgt_max 25 | 26 | src_dictionary = _src_dict 27 | tgt_dictionary = _tgt_dict 28 | 29 | src_reverse_dictionary = _src_rev_dict 30 | tgt_reverse_dictionary = _tgt_rev_dict 31 | 32 | train_inputs = _tr_inp 33 | train_outputs = _tr_out 34 | 35 | tot_sentences = _tot_sentences 36 | sentence_cursors = [0 for _ in range(tot_sentences)] 37 | 38 | 39 | def generate_batch_for_word2vec(batch_size, window_size, is_source): 40 | # window_size is the amount of words we're looking at from each side of a given word 41 | # creates a single batch 42 | global sentence_cursors 43 | global src_dictionary, tgt_dictionary 44 | global train_inputs, train_outputs 45 | span = 2 * window_size + 1 # [ skip_window target skip_window ] 46 | 47 | batch = np.ndarray(shape=(batch_size, span - 1), dtype=np.int32) 48 | labels = np.ndarray(shape=(batch_size, 1), dtype=np.int32) 49 | # e.g if skip_window = 2 then span = 5 50 | # span is the length of the whole frame we are considering for a single word (left + word + right) 51 | # skip_window is the length of one side 52 | 53 | sentence_ids_for_batch = np.random.randint(0, tot_sentences, batch_size) 54 | 55 | for b_i in range(batch_size): 56 | sent_id = sentence_ids_for_batch[b_i] 57 | 58 | if is_source: 59 | buffer = train_inputs[sent_id, sentence_cursors[sent_id]:sentence_cursors[sent_id] + span] 60 | else: 61 | buffer = train_outputs[sent_id, sentence_cursors[sent_id]:sentence_cursors[sent_id] + span] 62 | assert buffer.size == span, 'Buffer length (%d), Current data index (%d), Span(%d)' % ( 63 | buffer.size, sentence_cursors[sent_id], span) 64 | # If we only have EOS tokesn in the sampled text, we sample a new one 65 | if is_source: 66 | while np.all(buffer == src_dictionary['']): 67 | # reset the sentence_cursors for that cap_id 68 | 
sentence_cursors[sent_id] = 0 69 | # sample a new cap_id 70 | sent_id = np.random.randint(0, tot_sentences) 71 | buffer = train_inputs[sent_id, sentence_cursors[sent_id]:sentence_cursors[sent_id] + span] 72 | else: 73 | while np.all(buffer == tgt_dictionary['']): 74 | # reset the sentence_cursors for that cap_id 75 | sentence_cursors[sent_id] = 0 76 | # sample a new cap_id 77 | sent_id = np.random.randint(0, tot_sentences) 78 | buffer = train_outputs[sent_id, sentence_cursors[sent_id]:sentence_cursors[sent_id] + span] 79 | 80 | # fill left and right sides of batch 81 | batch[b_i, :window_size] = buffer[:window_size] 82 | batch[b_i, window_size:] = buffer[window_size + 1:] 83 | 84 | labels[b_i, 0] = buffer[window_size] 85 | 86 | # increase the corresponding index 87 | if is_source: 88 | sentence_cursors[sent_id] = (sentence_cursors[sent_id] + 1) % (src_max_sent_length - span) 89 | else: 90 | sentence_cursors[sent_id] = (sentence_cursors[sent_id] + 1) % (tgt_max_sent_length - span) 91 | 92 | assert batch.shape[0] == batch_size and batch.shape[1] == span - 1 93 | return batch, labels 94 | 95 | 96 | def print_some_batches(): 97 | global sentence_cursors, tot_sentences 98 | global src_reverse_dictionary 99 | 100 | for window_size in [1, 2]: 101 | sentence_cursors = [0 for _ in range(tot_sentences)] 102 | batch, labels = generate_batch_for_word2vec(batch_size=8, window_size=window_size, is_source=True) 103 | print('\nwith window_size = %d:' % (window_size)) 104 | print(' batch:', [[src_reverse_dictionary[bii] for bii in bi] for bi in batch]) 105 | print(' labels:', [src_reverse_dictionary[li] for li in labels.reshape(8)]) 106 | 107 | sentence_cursors = [0 for _ in range(tot_sentences)] 108 | 109 | batch_size, window_size = None, None 110 | valid_size, valid_window, valid_examples = None, None, None 111 | num_sampled = None 112 | 113 | train_dataset, train_labels = None, None 114 | valid_dataset = None 115 | 116 | softmax_weights, softmax_biases = None, None 117 | 118 | loss, optimizer, similarity, normalized_embeddings = None, None, None, None 119 | 120 | def define_word2vec_tensorflow(batch_size): 121 | 122 | global embedding_size, window_size 123 | global valid_size, valid_window, valid_examples 124 | global num_sampled 125 | global train_dataset, train_labels 126 | global valid_dataset 127 | global softmax_weights, softmax_biases 128 | global loss, optimizer, similarity 129 | global vocabulary_size, embedding_size 130 | global normalized_embeddings 131 | 132 | 133 | window_size = 2 # How many words to consider left and right. 134 | # We pick a random validation set to sample nearest neighbors. here we limit the 135 | # validation samples to the words that have a low numeric ID, which by 136 | # construction are also the most frequent. 137 | valid_size = 20 # Random set of words to evaluate similarity on. 138 | valid_window = 100 # Only pick dev samples in the head of the distribution. 139 | # pick 16 samples from 100 140 | valid_examples = np.array(np.random.randint(0, valid_window, valid_size // 2)) 141 | valid_examples = np.append(valid_examples, np.random.randint(1000, 1000 + valid_window, valid_size // 2)) 142 | num_sampled = 32 # Number of negative examples to sample. 143 | 144 | tf.reset_default_graph() 145 | 146 | # Input data. 147 | train_dataset = tf.placeholder(tf.int32, shape=[batch_size, 2 * window_size]) 148 | train_labels = tf.placeholder(tf.int32, shape=[batch_size, 1]) 149 | valid_dataset = tf.constant(valid_examples, dtype=tf.int32) 150 | 151 | # Variables. 
152 | # embedding, vector for each word in the vocabulary 153 | embeddings = tf.Variable(tf.random_uniform([vocabulary_size, embedding_size], -1.0, 1.0, dtype=tf.float32)) 154 | softmax_weights = tf.Variable(tf.truncated_normal([vocabulary_size, embedding_size], 155 | stddev=1.0 / math.sqrt(embedding_size), dtype=tf.float32)) 156 | softmax_biases = tf.Variable(tf.zeros([vocabulary_size], dtype=tf.float32)) 157 | 158 | # Model. 159 | # Look up embeddings for inputs. 160 | # this might efficiently find the embeddings for given ids (traind dataset) 161 | # manually doing this might not be efficient given there are 50000 entries in embeddings 162 | stacked_embedings = None 163 | print('Defining %d embedding lookups representing each word in the context' % (2 * window_size)) 164 | for i in range(2 * window_size): 165 | embedding_i = tf.nn.embedding_lookup(embeddings, train_dataset[:, i]) 166 | x_size, y_size = embedding_i.get_shape().as_list() 167 | if stacked_embedings is None: 168 | stacked_embedings = tf.reshape(embedding_i, [x_size, y_size, 1]) 169 | else: 170 | stacked_embedings = tf.concat(axis=2, 171 | values=[stacked_embedings, tf.reshape(embedding_i, [x_size, y_size, 1])]) 172 | 173 | assert stacked_embedings.get_shape().as_list()[2] == 2 * window_size 174 | print("Stacked embedding size: %s" % stacked_embedings.get_shape().as_list()) 175 | mean_embeddings = tf.reduce_mean(stacked_embedings, 2, keepdims=False) 176 | print("Reduced mean embedding size: %s" % mean_embeddings.get_shape().as_list()) 177 | 178 | # Compute the softmax loss, using a sample of the negative labels each time. 179 | # inputs are embeddings of the train words 180 | # with this loss we optimize weights, biases, embeddings 181 | 182 | loss = tf.reduce_mean( 183 | tf.nn.sampled_softmax_loss(weights=softmax_weights, biases=softmax_biases, inputs=mean_embeddings, 184 | labels=train_labels, num_sampled=num_sampled, num_classes=vocabulary_size)) 185 | 186 | # Optimizer. 187 | # Note: The optimizer will optimize the softmax_weights AND the embeddings. 188 | optimizer = tf.train.AdamOptimizer(0.001).minimize(loss) 189 | 190 | # Compute the similarity between minibatch examples and all embeddings. 
191 | # We use the cosine distance: 192 | norm = tf.sqrt(tf.reduce_sum(tf.square(embeddings), 1, keepdims=True)) 193 | normalized_embeddings = embeddings / norm 194 | valid_embeddings = tf.nn.embedding_lookup(normalized_embeddings, valid_dataset) 195 | similarity = tf.matmul(valid_embeddings, tf.transpose(normalized_embeddings)) 196 | 197 | 198 | def run_word2vec_source(batch_size): 199 | global embedding_size, window_size 200 | global valid_size, valid_window, valid_examples 201 | global num_sampled 202 | global train_dataset, train_labels 203 | global valid_dataset 204 | global softmax_weights, softmax_biases 205 | global loss, optimizer, similarity, normalized_embeddings 206 | global src_reverse_dictionary 207 | global vocabulary_size, embedding_size 208 | 209 | num_steps = 100001 210 | 211 | config=tf.ConfigProto(allow_soft_placement=True) 212 | config.gpu_options.allow_growth = True 213 | 214 | with tf.Session(config=config) as session: 215 | tf.global_variables_initializer().run() 216 | print('Initialized') 217 | average_loss = 0 218 | for step in range(num_steps): 219 | 220 | batch_data, batch_labels = generate_batch_for_word2vec(batch_size, window_size, is_source=True) 221 | feed_dict = {train_dataset: batch_data, train_labels: batch_labels} 222 | _, l = session.run([optimizer, loss], feed_dict=feed_dict) 223 | average_loss += l 224 | if (step + 1) % 2000 == 0: 225 | if step > 0: 226 | average_loss = average_loss / 2000 227 | # The average loss is an estimate of the loss over the last 2000 batches. 228 | print('Average loss at step %d: %f' % (step + 1, average_loss)) 229 | average_loss = 0 230 | # note that this is expensive (~20% slowdown if computed every 500 steps) 231 | if (step + 1) % 10000 == 0: 232 | sim = similarity.eval() 233 | for i in range(valid_size): 234 | valid_word = src_reverse_dictionary[valid_examples[i]] 235 | top_k = 8 # number of nearest neighbors 236 | nearest = (-sim[i, :]).argsort()[1:top_k + 1] 237 | log = 'Nearest to %s:' % valid_word 238 | for k in range(top_k): 239 | close_word = src_reverse_dictionary[nearest[k]] 240 | log = '%s %s,' % (log, close_word) 241 | print(log) 242 | cbow_final_embeddings = normalized_embeddings.eval() 243 | 244 | np.save('de-embeddings.npy', cbow_final_embeddings) 245 | 246 | def run_word2vec_target(batch_size): 247 | global embedding_size, window_size 248 | global valid_size, valid_window, valid_examples 249 | global num_sampled 250 | global train_dataset, train_labels 251 | global valid_dataset 252 | global softmax_weights, softmax_biases 253 | global loss, optimizer, similarity, normalized_embeddings 254 | global tgt_reverse_dictionary 255 | global vocabulary_size, embedding_size 256 | 257 | num_steps = 100001 258 | 259 | config=tf.ConfigProto(allow_soft_placement=True) 260 | config.gpu_options.allow_growth = True 261 | with tf.Session(config=config) as session: 262 | tf.global_variables_initializer().run() 263 | print('Initialized') 264 | average_loss = 0 265 | for step in range(num_steps): 266 | 267 | batch_data, batch_labels = generate_batch_for_word2vec(batch_size, window_size, is_source=False) 268 | feed_dict = {train_dataset: batch_data, train_labels: batch_labels} 269 | _, l = session.run([optimizer, loss], feed_dict=feed_dict) 270 | average_loss += l 271 | if (step + 1) % 2000 == 0: 272 | if step > 0: 273 | average_loss = average_loss / 2000 274 | # The average loss is an estimate of the loss over the last 2000 batches. 
275 | print('Average loss at step %d: %f' % (step + 1, average_loss)) 276 | average_loss = 0 277 | # note that this is expensive (~20% slowdown if computed every 500 steps) 278 | if (step + 1) % 10000 == 0: 279 | sim = similarity.eval() 280 | for i in range(valid_size): 281 | valid_word = tgt_reverse_dictionary[valid_examples[i]] 282 | top_k = 8 # number of nearest neighbors 283 | nearest = (-sim[i, :]).argsort()[1:top_k + 1] 284 | log = 'Nearest to %s:' % valid_word 285 | for k in range(top_k): 286 | close_word = tgt_reverse_dictionary[nearest[k]] 287 | log = '%s %s,' % (log, close_word) 288 | print(log) 289 | cbow_final_embeddings = normalized_embeddings.eval() 290 | 291 | np.save('en-embeddings.npy', cbow_final_embeddings) -------------------------------------------------------------------------------- /ch2/test1.txt: -------------------------------------------------------------------------------- 1 | 0.1,0.2,0.3,0.4,0.5,0.6,0.7,0.8,0.9,1.0 2 | 0.1,0.2,0.3,0.4,0.5,0.6,0.7,0.8,0.9,1.0 3 | 0.1,0.2,0.3,0.4,0.5,0.6,0.7,0.8,0.9,1.0 4 | 0.1,0.2,0.3,0.4,0.5,0.6,0.7,0.8,0.9,1.0 5 | 0.1,0.2,0.3,0.4,0.5,0.6,0.7,0.8,0.9,1.0 6 | -------------------------------------------------------------------------------- /ch2/test2.txt: -------------------------------------------------------------------------------- 1 | 0.1,0.1,0.1,0.1,0.1,0.1,0.1,0.1,0.1,0.1 2 | 0.1,0.1,0.1,0.1,0.1,0.1,0.1,0.1,0.1,0.1 3 | 0.1,0.1,0.1,0.1,0.1,0.1,0.1,0.1,0.1,0.1 4 | 0.1,0.1,0.1,0.1,0.1,0.1,0.1,0.1,0.1,0.1 5 | 0.1,0.1,0.1,0.1,0.1,0.1,0.1,0.1,0.1,0.1 6 | -------------------------------------------------------------------------------- /ch2/test3.txt: -------------------------------------------------------------------------------- 1 | 1.0,0.9,0.8,0.7,0.6,0.5,0.4,0.3,0.2,0.1 2 | 1.0,0.9,0.8,0.7,0.6,0.5,0.4,0.3,0.2,0.1 3 | 1.0,0.9,0.8,0.7,0.6,0.5,0.4,0.3,0.2,0.1 4 | 1.0,0.9,0.8,0.7,0.6,0.5,0.4,0.3,0.2,0.1 5 | 1.0,0.9,0.8,0.7,0.6,0.5,0.4,0.3,0.2,0.1 6 | -------------------------------------------------------------------------------- /ch3/ch3_wordnet.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": { 7 | "collapsed": true 8 | }, 9 | "outputs": [], 10 | "source": [ 11 | "# You first need to download the wordnet following these commands \n", 12 | "# before importing it\n", 13 | "import nltk\n", 14 | "nltk.download('wordnet')" 15 | ] 16 | }, 17 | { 18 | "cell_type": "code", 19 | "execution_count": 2, 20 | "metadata": {}, 21 | "outputs": [], 22 | "source": [ 23 | "# you will need to download the wordnet corpus from nltk using nltk.download()\n", 24 | "from nltk.corpus import wordnet as wn" 25 | ] 26 | }, 27 | { 28 | "cell_type": "markdown", 29 | "metadata": {}, 30 | "source": [ 31 | "## Various Synset Relationships\n", 32 | "Here we will look at what lemmas, hypernyms, hyponyms, meronyms and holonyms look like" 33 | ] 34 | }, 35 | { 36 | "cell_type": "code", 37 | "execution_count": 3, 38 | "metadata": {}, 39 | "outputs": [ 40 | { 41 | "name": "stdout", 42 | "output_type": "stream", 43 | "text": [ 44 | "All the available Synsets for car\n", 45 | "\t [Synset('car.n.01'), Synset('car.n.02'), Synset('car.n.03'), Synset('car.n.04'), Synset('cable_car.n.01')] \n", 46 | "\n", 47 | "Example definitions of available Synsets ...\n", 48 | "\t car.n.01 : a motor vehicle with four wheels; usually propelled by an internal combustion engine\n", 49 | "\t car.n.02 : a wheeled vehicle adapted to the rails of railroad\n", 50 | 
"\t car.n.03 : the compartment that is suspended from an airship and that carries personnel and the cargo and the power plant\n", 51 | "\n", 52 | "\n", 53 | "Example lemmas for the Synset car.n.03\n", 54 | "\t ['car', 'auto', 'automobile'] \n", 55 | "\n", 56 | "Hypernyms of the Synset car.n.01\n", 57 | "\t motor_vehicle.n.01 \n", 58 | "\n", 59 | "Hyponyms of the Synset car.n.01\n", 60 | "\t ['ambulance.n.01', 'beach_wagon.n.01', 'bus.n.04'] \n", 61 | "\n", 62 | "Holonyms (Part) of the Synset car.n.03\n", 63 | "\t ['airship.n.01'] \n", 64 | "\n", 65 | "Meronyms (Part) of the Synset car.n.01\n", 66 | "\t ['accelerator.n.01', 'air_bag.n.01', 'auto_accessory.n.01'] \n", 67 | "\n" 68 | ] 69 | } 70 | ], 71 | "source": [ 72 | "# shows all the available synsets\n", 73 | "word = 'car'\n", 74 | "car_syns = wn.synsets(word)\n", 75 | "print('All the available Synsets for ',word)\n", 76 | "print('\\t',car_syns,'\\n')\n", 77 | "\n", 78 | "# The definition of the first two synsets\n", 79 | "syns_defs = [car_syns[i].definition() for i in range(len(car_syns))]\n", 80 | "print('Example definitions of available Synsets ...')\n", 81 | "for i in range(3):\n", 82 | " print('\\t',car_syns[i].name(),': ',syns_defs[i])\n", 83 | "print('\\n')\n", 84 | "\n", 85 | "# Get the lemmas for the first Synset\n", 86 | "print('Example lemmas for the Synset ',car_syns[i].name())\n", 87 | "car_lemmas = car_syns[0].lemmas()[:3]\n", 88 | "print('\\t',[lemma.name() for lemma in car_lemmas],'\\n')\n", 89 | "\n", 90 | "# Let us get hypernyms for a Synset (general superclass)\n", 91 | "syn = car_syns[0]\n", 92 | "print('Hypernyms of the Synset ',syn.name())\n", 93 | "print('\\t',syn.hypernyms()[0].name(),'\\n')\n", 94 | "\n", 95 | "# Let us get hyponyms for a Synset (specific subclass)\n", 96 | "syn = car_syns[0]\n", 97 | "print('Hyponyms of the Synset ',syn.name())\n", 98 | "print('\\t',[hypo.name() for hypo in syn.hyponyms()[:3]],'\\n')\n", 99 | "\n", 100 | "# Let us get part-holonyms for a Synset (specific subclass)\n", 101 | "# also there is another holonym category called \"substance-holonyms\"\n", 102 | "syn = car_syns[2]\n", 103 | "print('Holonyms (Part) of the Synset ',syn.name())\n", 104 | "print('\\t',[holo.name() for holo in syn.part_holonyms()],'\\n')\n", 105 | "\n", 106 | "# Let us get meronyms for a Synset (specific subclass)\n", 107 | "# also there is another meronym category called \"substance-meronyms\"\n", 108 | "syn = car_syns[0]\n", 109 | "print('Meronyms (Part) of the Synset ',syn.name())\n", 110 | "print('\\t',[mero.name() for mero in syn.part_meronyms()[:3]],'\\n')" 111 | ] 112 | }, 113 | { 114 | "cell_type": "markdown", 115 | "metadata": {}, 116 | "source": [ 117 | "## Similarity between Synsets" 118 | ] 119 | }, 120 | { 121 | "cell_type": "code", 122 | "execution_count": 4, 123 | "metadata": {}, 124 | "outputs": [ 125 | { 126 | "name": "stdout", 127 | "output_type": "stream", 128 | "text": [ 129 | "Word Similarity (car)<->(lorry): 0.6956521739130435\n", 130 | "Word Similarity (car)<->(tree): 0.38095238095238093\n" 131 | ] 132 | } 133 | ], 134 | "source": [ 135 | "word1, word2, word3 = 'car','lorry','tree'\n", 136 | "w1_syns, w2_syns, w3_syns = wn.synsets(word1), wn.synsets(word2), wn.synsets(word3)\n", 137 | "\n", 138 | "print('Word Similarity (%s)<->(%s): '%(word1,word2),wn.wup_similarity(w1_syns[0], w2_syns[0]))\n", 139 | "print('Word Similarity (%s)<->(%s): '%(word1,word3),wn.wup_similarity(w1_syns[0], w3_syns[0]))" 140 | ] 141 | }, 142 | { 143 | "cell_type": "code", 144 | "execution_count": null, 145 | 
"metadata": { 146 | "collapsed": true 147 | }, 148 | "outputs": [], 149 | "source": [] 150 | } 151 | ], 152 | "metadata": { 153 | "kernelspec": { 154 | "display_name": "Python 3", 155 | "language": "python", 156 | "name": "python3" 157 | }, 158 | "language_info": { 159 | "codemirror_mode": { 160 | "name": "ipython", 161 | "version": 3 162 | }, 163 | "file_extension": ".py", 164 | "mimetype": "text/x-python", 165 | "name": "python", 166 | "nbconvert_exporter": "python", 167 | "pygments_lexer": "ipython3", 168 | "version": "3.5.2" 169 | } 170 | }, 171 | "nbformat": 4, 172 | "nbformat_minor": 2 173 | } 174 | -------------------------------------------------------------------------------- /ch4/ch4_glove.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# GloVe: Global Vectors for Word2Vec" 8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": 1, 13 | "metadata": {}, 14 | "outputs": [ 15 | { 16 | "name": "stderr", 17 | "output_type": "stream", 18 | "text": [ 19 | "c:\\users\\thushan\\documents\\python_virtualenvs\\tensorflow_venv\\lib\\site-packages\\h5py\\__init__.py:36: FutureWarning: Conversion of the second argument of issubdtype from `float` to `np.floating` is deprecated. In future, it will be treated as `np.float64 == np.dtype(float).type`.\n", 20 | " from ._conv import register_converters as _register_converters\n" 21 | ] 22 | } 23 | ], 24 | "source": [ 25 | "# These are all the modules we'll be using later. Make sure you can import them\n", 26 | "# before proceeding further.\n", 27 | "%matplotlib inline\n", 28 | "from __future__ import print_function\n", 29 | "import collections\n", 30 | "import math\n", 31 | "import numpy as np\n", 32 | "import os\n", 33 | "import random\n", 34 | "import tensorflow as tf\n", 35 | "import bz2\n", 36 | "from matplotlib import pylab\n", 37 | "from six.moves import range\n", 38 | "from six.moves.urllib.request import urlretrieve\n", 39 | "from sklearn.manifold import TSNE\n", 40 | "from sklearn.cluster import KMeans\n", 41 | "from scipy.sparse import lil_matrix\n", 42 | "import nltk # standard preprocessing\n", 43 | "import operator # sorting items in dictionary by value\n", 44 | "#nltk.download() #tokenizers/punkt/PY3/english.pickle\n", 45 | "from math import ceil" 46 | ] 47 | }, 48 | { 49 | "cell_type": "markdown", 50 | "metadata": {}, 51 | "source": [ 52 | "## Dataset\n", 53 | "This code downloads a [dataset](http://www.evanjones.ca/software/wikipedia2text.html) consisting of several Wikipedia articles totaling up to roughly 61 megabytes. Additionally the code makes sure the file has the correct size after downloading it." 
54 | ] 55 | }, 56 | { 57 | "cell_type": "code", 58 | "execution_count": 2, 59 | "metadata": {}, 60 | "outputs": [ 61 | { 62 | "name": "stdout", 63 | "output_type": "stream", 64 | "text": [ 65 | "Found and verified wikipedia2text-extracted.txt.bz2\n" 66 | ] 67 | } 68 | ], 69 | "source": [ 70 | "url = 'http://www.evanjones.ca/software/'\n", 71 | "\n", 72 | "def maybe_download(filename, expected_bytes):\n", 73 | " \"\"\"Download a file if not present, and make sure it's the right size.\"\"\"\n", 74 | " if not os.path.exists(filename):\n", 75 | " filename, _ = urlretrieve(url + filename, filename)\n", 76 | " statinfo = os.stat(filename)\n", 77 | " if statinfo.st_size == expected_bytes:\n", 78 | " print('Found and verified %s' % filename)\n", 79 | " else:\n", 80 | " print(statinfo.st_size)\n", 81 | " raise Exception(\n", 82 | " 'Failed to verify ' + filename + '. Can you get to it with a browser?')\n", 83 | " return filename\n", 84 | "\n", 85 | "filename = maybe_download('wikipedia2text-extracted.txt.bz2', 18377035)" 86 | ] 87 | }, 88 | { 89 | "cell_type": "markdown", 90 | "metadata": {}, 91 | "source": [ 92 | "## Read Data with Preprocessing with NLTK\n", 93 | "Reads data as it is to a string, convert to lower-case and tokenize it using the nltk library. This code reads data in 1MB portions as processing the full text at once slows down the task and returns a list of words" 94 | ] 95 | }, 96 | { 97 | "cell_type": "code", 98 | "execution_count": 3, 99 | "metadata": {}, 100 | "outputs": [ 101 | { 102 | "name": "stdout", 103 | "output_type": "stream", 104 | "text": [ 105 | "Reading data...\n", 106 | "Data size 3360286\n", 107 | "Example words (start): ['propaganda', 'is', 'a', 'concerted', 'set', 'of', 'messages', 'aimed', 'at', 'influencing']\n", 108 | "Example words (end): ['favorable', 'long-term', 'outcomes', 'for', 'around', 'half', 'of', 'those', 'diagnosed', 'with']\n" 109 | ] 110 | } 111 | ], 112 | "source": [ 113 | "def read_data(filename):\n", 114 | " \"\"\"\n", 115 | " Extract the first file enclosed in a zip file as a list of words\n", 116 | " and pre-processes it using the nltk python library\n", 117 | " \"\"\"\n", 118 | "\n", 119 | " with bz2.BZ2File(filename) as f:\n", 120 | "\n", 121 | " data = []\n", 122 | " file_size = os.stat(filename).st_size\n", 123 | " chunk_size = 1024 * 1024 # reading 1 MB at a time as the dataset is moderately large\n", 124 | " print('Reading data...')\n", 125 | " for i in range(ceil(file_size//chunk_size)+1):\n", 126 | " bytes_to_read = min(chunk_size,file_size-(i*chunk_size))\n", 127 | " file_string = f.read(bytes_to_read).decode('utf-8')\n", 128 | " file_string = file_string.lower()\n", 129 | " # tokenizes a string to words residing in a list\n", 130 | " file_string = nltk.word_tokenize(file_string)\n", 131 | " data.extend(file_string)\n", 132 | " return data\n", 133 | "\n", 134 | "words = read_data(filename)\n", 135 | "print('Data size %d' % len(words))\n", 136 | "token_count = len(words)\n", 137 | "\n", 138 | "print('Example words (start): ',words[:10])\n", 139 | "print('Example words (end): ',words[-10:])" 140 | ] 141 | }, 142 | { 143 | "cell_type": "markdown", 144 | "metadata": {}, 145 | "source": [ 146 | "## Building the Dictionaries\n", 147 | "Builds the following. To understand each of these elements, let us also assume the text \"I like to go to school\"\n", 148 | "\n", 149 | "* `dictionary`: maps a string word to an ID (e.g. {I:0, like:1, to:2, go:3, school:4})\n", 150 | "* `reverse_dictionary`: maps an ID to a string word (e.g. 
{0:I, 1:like, 2:to, 3:go, 4:school}\n", 151 | "* `count`: List of list of (word, frequency) elements (e.g. [(I,1),(like,1),(to,2),(go,1),(school,1)]\n", 152 | "* `data` : Contain the string of text we read, where string words are replaced with word IDs (e.g. [0, 1, 2, 3, 2, 4])\n", 153 | "\n", 154 | "It also introduces an additional special token `UNK` to denote rare words to are too rare to make use of." 155 | ] 156 | }, 157 | { 158 | "cell_type": "code", 159 | "execution_count": 4, 160 | "metadata": {}, 161 | "outputs": [ 162 | { 163 | "name": "stdout", 164 | "output_type": "stream", 165 | "text": [ 166 | "Most common words (+UNK) [['UNK', 69215], ('the', 226881), (',', 184013), ('.', 120944), ('of', 116323)]\n", 167 | "Sample data [1730, 9, 8, 16741, 223, 4, 5169, 4509, 26, 11641]\n" 168 | ] 169 | } 170 | ], 171 | "source": [ 172 | "# we restrict our vocabulary size to 50000\n", 173 | "vocabulary_size = 50000 \n", 174 | "\n", 175 | "def build_dataset(words):\n", 176 | " count = [['UNK', -1]]\n", 177 | " # Gets only the vocabulary_size most common words as the vocabulary\n", 178 | " # All the other words will be replaced with UNK token\n", 179 | " count.extend(collections.Counter(words).most_common(vocabulary_size - 1))\n", 180 | " dictionary = dict()\n", 181 | "\n", 182 | " # Create an ID for each word by giving the current length of the dictionary\n", 183 | " # And adding that item to the dictionary\n", 184 | " for word, _ in count:\n", 185 | " dictionary[word] = len(dictionary)\n", 186 | " \n", 187 | " data = list()\n", 188 | " unk_count = 0\n", 189 | " # Traverse through all the text we have and produce a list\n", 190 | " # where each element corresponds to the ID of the word found at that index\n", 191 | " for word in words:\n", 192 | " # If word is in the dictionary use the word ID,\n", 193 | " # else use the ID of the special token \"UNK\"\n", 194 | " if word in dictionary:\n", 195 | " index = dictionary[word]\n", 196 | " else:\n", 197 | " index = 0 # dictionary['UNK']\n", 198 | " unk_count = unk_count + 1\n", 199 | " data.append(index)\n", 200 | " \n", 201 | " # update the count variable with the number of UNK occurences\n", 202 | " count[0][1] = unk_count\n", 203 | " \n", 204 | " reverse_dictionary = dict(zip(dictionary.values(), dictionary.keys())) \n", 205 | " # Make sure the dictionary is of size of the vocabulary\n", 206 | " assert len(dictionary) == vocabulary_size\n", 207 | " \n", 208 | " return data, count, dictionary, reverse_dictionary\n", 209 | "\n", 210 | "data, count, dictionary, reverse_dictionary = build_dataset(words)\n", 211 | "print('Most common words (+UNK)', count[:5])\n", 212 | "print('Sample data', data[:10])\n", 213 | "del words # Hint to reduce memory." 214 | ] 215 | }, 216 | { 217 | "cell_type": "markdown", 218 | "metadata": {}, 219 | "source": [ 220 | "## Generating Batches of Data for GloVe\n", 221 | "Generates a batch or target words (`batch`) and a batch of corresponding context words (`labels`). It reads `2*window_size+1` words at a time (called a `span`) and create `2*window_size` datapoints in a single span. The function continue in this manner until `batch_size` datapoints are created. Everytime we reach the end of the word sequence, we start from beginning. 
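
To make the weighting scheme concrete: each context word receives weight 1/|distance to the target|, so immediate neighbours get 1.0 and the outermost words in the window get 1/window_size. The toy sketch below is plain Python, independent of the notebook's `data`/`dictionary` variables, and reproduces the `window_size = 2` numbers printed by the code cell that follows; the hard-coded token list is just the first span of the corpus.

```
# Toy illustration of the 1/distance context weights used by generate_batch.
toy_tokens = ['propaganda', 'is', 'a', 'concerted', 'set']  # one span for window_size = 2
window_size = 2
target_pos = window_size  # middle of the span -> 'a'

batch, labels, weights = [], [], []
for j in range(2 * window_size + 1):
    if j == target_pos:                 # skip the target word itself
        continue
    batch.append(toy_tokens[target_pos])
    labels.append(toy_tokens[j])
    weights.append(1.0 / abs(j - target_pos))

print(batch)    # ['a', 'a', 'a', 'a']
print(labels)   # ['propaganda', 'is', 'concerted', 'set']
print(weights)  # [0.5, 1.0, 1.0, 0.5]
```
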
" 222 | ] 223 | }, 224 | { 225 | "cell_type": "code", 226 | "execution_count": 5, 227 | "metadata": {}, 228 | "outputs": [ 229 | { 230 | "name": "stdout", 231 | "output_type": "stream", 232 | "text": [ 233 | "data: ['propaganda', 'is', 'a', 'concerted', 'set', 'of', 'messages', 'aimed']\n", 234 | "\n", 235 | "with window_size = 2:\n", 236 | " batch: ['a', 'a', 'a', 'a', 'concerted', 'concerted', 'concerted', 'concerted']\n", 237 | " labels: ['propaganda', 'is', 'concerted', 'set', 'is', 'a', 'set', 'of']\n", 238 | " weights: [0.5, 1.0, 1.0, 0.5, 0.5, 1.0, 1.0, 0.5]\n", 239 | "\n", 240 | "with window_size = 4:\n", 241 | " batch: ['set', 'set', 'set', 'set', 'set', 'set', 'set', 'set']\n", 242 | " labels: ['propaganda', 'is', 'a', 'concerted', 'of', 'messages', 'aimed', 'at']\n", 243 | " weights: [0.25, 0.33333334, 0.5, 1.0, 1.0, 0.5, 0.33333334, 0.25]\n" 244 | ] 245 | } 246 | ], 247 | "source": [ 248 | "data_index = 0\n", 249 | "\n", 250 | "def generate_batch(batch_size, window_size):\n", 251 | " # data_index is updated by 1 everytime we read a data point\n", 252 | " global data_index \n", 253 | " \n", 254 | " # two numpy arras to hold target words (batch)\n", 255 | " # and context words (labels)\n", 256 | " batch = np.ndarray(shape=(batch_size), dtype=np.int32)\n", 257 | " labels = np.ndarray(shape=(batch_size, 1), dtype=np.int32)\n", 258 | " weights = np.ndarray(shape=(batch_size), dtype=np.float32)\n", 259 | "\n", 260 | " # span defines the total window size, where\n", 261 | " # data we consider at an instance looks as follows. \n", 262 | " # [ skip_window target skip_window ]\n", 263 | " span = 2 * window_size + 1 \n", 264 | " \n", 265 | " # The buffer holds the data contained within the span\n", 266 | " buffer = collections.deque(maxlen=span)\n", 267 | " \n", 268 | " # Fill the buffer and update the data_index\n", 269 | " for _ in range(span):\n", 270 | " buffer.append(data[data_index])\n", 271 | " data_index = (data_index + 1) % len(data)\n", 272 | " \n", 273 | " # This is the number of context words we sample for a single target word\n", 274 | " num_samples = 2*window_size \n", 275 | "\n", 276 | " # We break the batch reading into two for loops\n", 277 | " # The inner for loop fills in the batch and labels with \n", 278 | " # num_samples data points using data contained withing the span\n", 279 | " # The outper for loop repeat this for batch_size//num_samples times\n", 280 | " # to produce a full batch\n", 281 | " for i in range(batch_size // num_samples):\n", 282 | " k=0\n", 283 | " # avoid the target word itself as a prediction\n", 284 | " # fill in batch and label numpy arrays\n", 285 | " for j in list(range(window_size))+list(range(window_size+1,2*window_size+1)):\n", 286 | " batch[i * num_samples + k] = buffer[window_size]\n", 287 | " labels[i * num_samples + k, 0] = buffer[j]\n", 288 | " weights[i * num_samples + k] = abs(1.0/(j - window_size))\n", 289 | " k += 1 \n", 290 | " \n", 291 | " # Everytime we read num_samples data points,\n", 292 | " # we have created the maximum number of datapoints possible\n", 293 | " # withing a single span, so we need to move the span by 1\n", 294 | " # to create a fresh new span\n", 295 | " buffer.append(data[data_index])\n", 296 | " data_index = (data_index + 1) % len(data)\n", 297 | " return batch, labels, weights\n", 298 | "\n", 299 | "print('data:', [reverse_dictionary[di] for di in data[:8]])\n", 300 | "\n", 301 | "for window_size in [2, 4]:\n", 302 | " data_index = 0\n", 303 | " batch, labels, weights = generate_batch(batch_size=8, 
window_size=window_size)\n", 304 | " print('\\nwith window_size = %d:' %window_size)\n", 305 | " print(' batch:', [reverse_dictionary[bi] for bi in batch])\n", 306 | " print(' labels:', [reverse_dictionary[li] for li in labels.reshape(8)])\n", 307 | " print(' weights:', [w for w in weights])" 308 | ] 309 | }, 310 | { 311 | "cell_type": "markdown", 312 | "metadata": {}, 313 | "source": [ 314 | "## Creating the Word Co-Occurance Matrix\n", 315 | "Why GloVe shine above context window based method is that it employs global statistics of the corpus in to the model (according to authors). This is done by using information from the word co-occurance matrix to optimize the word vectors. Basically, the X(i,j) entry of the co-occurance matrix says how frequent word i to appear near j. We also use a weighting mechanishm to give more weight to words close together than to ones further-apart (from experiments section of the paper)." 316 | ] 317 | }, 318 | { 319 | "cell_type": "code", 320 | "execution_count": 6, 321 | "metadata": {}, 322 | "outputs": [ 323 | { 324 | "name": "stdout", 325 | "output_type": "stream", 326 | "text": [ 327 | "(50000, 50000)\n", 328 | "Running 420035 iterations to compute the co-occurance matrix\n", 329 | "\tFinished 100000 iterations\n", 330 | "\tFinished 200000 iterations\n", 331 | "\tFinished 300000 iterations\n", 332 | "\tFinished 400000 iterations\n", 333 | "Sample chunks of co-occurance matrix\n", 334 | "\n", 335 | "Target Word: \"UNK\"\n", 336 | "Context word:\",\"(id:2,count:3482.30), \"UNK\"(id:0,count:2164.01), \"the\"(id:1,count:2020.93), \"and\"(id:5,count:1454.50), \".\"(id:3,count:1310.58), \"of\"(id:4,count:1086.33), \"(\"(id:13,count:1047.17), \")\"(id:12,count:831.17), \"in\"(id:6,count:776.17), \"a\"(id:8,count:624.50), \n", 337 | "\n", 338 | "Target Word: \"imagery\"\n", 339 | "Context word:\"and\"(id:5,count:1.25), \"UNK\"(id:0,count:1.00), \"generated\"(id:3145,count:1.00), \"demonstrates\"(id:10422,count:1.00), \"explored\"(id:5276,count:1.00), \"horrific\"(id:16241,count:1.00), \"(\"(id:13,count:1.00), \",\"(id:2,count:0.58), \"computer\"(id:936,count:0.50), \"goya\"(id:22688,count:0.50), \n", 340 | "\n", 341 | "Target Word: \"defining\"\n", 342 | "Context word:\"the\"(id:1,count:5.00), \"of\"(id:4,count:2.83), \"and\"(id:5,count:1.50), \"feature\"(id:1397,count:1.00), \"other\"(id:42,count:1.00), \"influence\"(id:452,count:1.00), \"it\"(id:24,count:1.00), \"or\"(id:29,count:1.00), \"than\"(id:62,count:1.00), \"moments\"(id:7053,count:1.00), \n", 343 | "\n", 344 | "Target Word: \"liberalism\"\n", 345 | "Context word:\"of\"(id:4,count:5.75), \".\"(id:3,count:3.25), \"forms\"(id:423,count:1.50), \"the\"(id:1,count:1.08), \"western\"(id:216,count:1.00), \"within\"(id:152,count:1.00), \"are\"(id:22,count:1.00), \"may\"(id:73,count:1.00), \"on\"(id:18,count:1.00), \"<\"(id:1716,count:1.00), \n", 346 | "\n", 347 | "Target Word: \"rampant\"\n", 348 | "Context word:\",\"(id:2,count:2.08), \"and\"(id:5,count:1.25), \"UNK\"(id:0,count:1.00), \"government\"(id:84,count:1.00), \"also\"(id:37,count:1.00), \"were\"(id:31,count:1.00), \"throughout\"(id:308,count:1.00), \".\"(id:3,count:1.00), \"the\"(id:1,count:0.83), \"expenditures\"(id:10039,count:0.50), \n", 349 | "\n", 350 | "Target Word: \"and\"\n", 351 | "Context word:\",\"(id:2,count:3990.46), \"the\"(id:1,count:2566.00), \"UNK\"(id:0,count:1488.59), \"of\"(id:4,count:1001.66), \".\"(id:3,count:894.16), \"in\"(id:6,count:728.16), \"to\"(id:7,count:555.67), \"a\"(id:8,count:549.25), \")\"(id:12,count:412.92), 
\"and\"(id:5,count:318.50), \n", 352 | "\n", 353 | "Target Word: \"in\"\n", 354 | "Context word:\"the\"(id:1,count:3765.79), \",\"(id:2,count:1934.93), \".\"(id:3,count:1836.76), \"UNK\"(id:0,count:776.17), \"of\"(id:4,count:747.16), \"and\"(id:5,count:723.41), \"a\"(id:8,count:685.08), \"to\"(id:7,count:425.67), \"in\"(id:6,count:316.00), \"was\"(id:11,count:290.08), \n", 355 | "\n", 356 | "Target Word: \"to\"\n", 357 | "Context word:\"the\"(id:1,count:2449.92), \",\"(id:2,count:990.33), \".\"(id:3,count:687.00), \"a\"(id:8,count:613.00), \"be\"(id:30,count:573.75), \"and\"(id:5,count:527.33), \"UNK\"(id:0,count:470.42), \"of\"(id:4,count:457.09), \"in\"(id:6,count:403.67), \"is\"(id:9,count:282.67), \n", 358 | "\n", 359 | "Target Word: \"a\"\n", 360 | "Context word:\",\"(id:2,count:1496.51), \"of\"(id:4,count:1298.42), \".\"(id:3,count:907.00), \"in\"(id:6,count:713.08), \"the\"(id:1,count:640.42), \"to\"(id:7,count:625.92), \"as\"(id:10,count:614.67), \"UNK\"(id:0,count:602.92), \"and\"(id:5,count:583.08), \"is\"(id:9,count:558.25), \n", 361 | "\n", 362 | "Target Word: \"is\"\n", 363 | "Context word:\"the\"(id:1,count:1062.00), \",\"(id:2,count:651.92), \".\"(id:3,count:567.50), \"a\"(id:8,count:504.00), \"it\"(id:24,count:381.92), \"of\"(id:4,count:340.67), \"UNK\"(id:0,count:298.83), \"to\"(id:7,count:261.42), \"in\"(id:6,count:237.42), \"and\"(id:5,count:232.08), \n" 364 | ] 365 | } 366 | ], 367 | "source": [ 368 | "# We are creating the co-occurance matrix as a compressed sparse colum matrix from scipy. \n", 369 | "cooc_data_index = 0\n", 370 | "dataset_size = len(data) # We iterate through the full text\n", 371 | "skip_window = 4 # How many words to consider left and right.\n", 372 | "\n", 373 | "# The sparse matrix that stores the word co-occurences\n", 374 | "cooc_mat = lil_matrix((vocabulary_size, vocabulary_size), dtype=np.float32)\n", 375 | "\n", 376 | "print(cooc_mat.shape)\n", 377 | "def generate_cooc(batch_size,skip_window):\n", 378 | " '''\n", 379 | " Generate co-occurence matrix by processing batches of data\n", 380 | " '''\n", 381 | " data_index = 0\n", 382 | " print('Running %d iterations to compute the co-occurance matrix'%(dataset_size//batch_size))\n", 383 | " for i in range(dataset_size//batch_size):\n", 384 | " # Printing progress\n", 385 | " if i>0 and i%100000==0:\n", 386 | " print('\\tFinished %d iterations'%i)\n", 387 | " \n", 388 | " # Generating a single batch of data\n", 389 | " batch, labels, weights = generate_batch(batch_size, skip_window)\n", 390 | " labels = labels.reshape(-1)\n", 391 | " \n", 392 | " # Incrementing the sparse matrix entries accordingly\n", 393 | " for inp,lbl,w in zip(batch,labels,weights): \n", 394 | " cooc_mat[inp,lbl] += (1.0*w)\n", 395 | "\n", 396 | "# Generate the matrix\n", 397 | "generate_cooc(8,skip_window) \n", 398 | "\n", 399 | "# Just printing some parts of co-occurance matrix\n", 400 | "print('Sample chunks of co-occurance matrix')\n", 401 | "\n", 402 | "\n", 403 | "# Basically calculates the highest cooccurance of several chosen word\n", 404 | "for i in range(10):\n", 405 | " idx_target = i\n", 406 | " \n", 407 | " # get the ith row of the sparse matrix and make it dense\n", 408 | " ith_row = cooc_mat.getrow(idx_target) \n", 409 | " ith_row_dense = ith_row.toarray('C').reshape(-1) \n", 410 | " \n", 411 | " # select target words only with a reasonable words around it.\n", 412 | " while np.sum(ith_row_dense)<10 or np.sum(ith_row_dense)>50000:\n", 413 | " # Choose a random word\n", 414 | " idx_target = 
np.random.randint(0,vocabulary_size)\n", 415 | " \n", 416 | " # get the ith row of the sparse matrix and make it dense\n", 417 | " ith_row = cooc_mat.getrow(idx_target) \n", 418 | " ith_row_dense = ith_row.toarray('C').reshape(-1) \n", 419 | " \n", 420 | " print('\\nTarget Word: \"%s\"'%reverse_dictionary[idx_target])\n", 421 | " \n", 422 | " sort_indices = np.argsort(ith_row_dense).reshape(-1) # indices with highest count of ith_row_dense\n", 423 | " sort_indices = np.flip(sort_indices,axis=0) # reverse the array (to get max values to the start)\n", 424 | "\n", 425 | " # printing several context words to make sure cooc_mat is correct\n", 426 | " print('Context word:',end='')\n", 427 | " for j in range(10): \n", 428 | " idx_context = sort_indices[j] \n", 429 | " print('\"%s\"(id:%d,count:%.2f), '%(reverse_dictionary[idx_context],idx_context,ith_row_dense[idx_context]),end='')\n", 430 | " print()" 431 | ] 432 | }, 433 | { 434 | "cell_type": "markdown", 435 | "metadata": {}, 436 | "source": [ 437 | "## GloVe Algorithm" 438 | ] 439 | }, 440 | { 441 | "cell_type": "markdown", 442 | "metadata": { 443 | "collapsed": true 444 | }, 445 | "source": [ 446 | "### Defining Hyperparameters\n", 447 | "\n", 448 | "Here we define several hyperparameters including `batch_size` (amount of samples in a single batch) `embedding_size` (size of embedding vectors) `window_size` (context window size)." 449 | ] 450 | }, 451 | { 452 | "cell_type": "code", 453 | "execution_count": 7, 454 | "metadata": { 455 | "collapsed": true 456 | }, 457 | "outputs": [], 458 | "source": [ 459 | "batch_size = 128 # Data points in a single batch\n", 460 | "embedding_size = 128 # Dimension of the embedding vector.\n", 461 | "window_size = 4 # How many words to consider left and right.\n", 462 | "\n", 463 | "# We pick a random validation set to sample nearest neighbors\n", 464 | "valid_size = 16 # Random set of words to evaluate similarity on.\n", 465 | "# We sample valid datapoints randomly from a large window without always being deterministic\n", 466 | "valid_window = 50\n", 467 | "\n", 468 | "# When selecting valid examples, we select some of the most frequent words as well as\n", 469 | "# some moderately rare words as well\n", 470 | "valid_examples = np.array(random.sample(range(valid_window), valid_size))\n", 471 | "valid_examples = np.append(valid_examples,random.sample(range(1000, 1000+valid_window), valid_size),axis=0)\n", 472 | "\n", 473 | "num_sampled = 32 # Number of negative examples to sample.\n", 474 | "\n", 475 | "epsilon = 1 # used for the stability of log in the loss function" 476 | ] 477 | }, 478 | { 479 | "cell_type": "markdown", 480 | "metadata": {}, 481 | "source": [ 482 | "### Defining Inputs and Outputs\n", 483 | "\n", 484 | "Here we define placeholders for feeding in training inputs and outputs (each of size `batch_size`) and a constant tensor to contain validation examples." 
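
For readers less familiar with the TF1-style placeholder workflow used from here on, the following minimal, self-contained example shows how a placeholder defined at graph-construction time is later fed through `feed_dict` at `session.run` time, exactly as the training loop does further below. The names and the trivial computation are illustrative only.

```
# Minimal reminder of the placeholder/feed_dict mechanics used by the cells below.
import numpy as np
import tensorflow as tf

tf.reset_default_graph()
word_ids = tf.placeholder(tf.int32, shape=[4], name='word_ids')  # analogous to train_dataset
doubled = word_ids * 2                                           # stand-in for the real graph

with tf.Session() as session:
    out = session.run(doubled, feed_dict={word_ids: np.array([1, 2, 3, 4], dtype=np.int32)})
    print(out)  # -> [2 4 6 8]
```
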
485 | ] 486 | }, 487 | { 488 | "cell_type": "code", 489 | "execution_count": 8, 490 | "metadata": { 491 | "collapsed": true 492 | }, 493 | "outputs": [], 494 | "source": [ 495 | "tf.reset_default_graph()\n", 496 | "\n", 497 | "# Training input data (target word IDs).\n", 498 | "train_dataset = tf.placeholder(tf.int32, shape=[batch_size])\n", 499 | "# Training input label data (context word IDs)\n", 500 | "train_labels = tf.placeholder(tf.int32, shape=[batch_size])\n", 501 | "# Validation input data, we don't need a placeholder\n", 502 | "# as we have already defined the IDs of the words selected\n", 503 | "# as validation data\n", 504 | "valid_dataset = tf.constant(valid_examples, dtype=tf.int32)" 505 | ] 506 | }, 507 | { 508 | "cell_type": "markdown", 509 | "metadata": {}, 510 | "source": [ 511 | "### Defining Model Parameters and Other Variables\n", 512 | "We now define four TensorFlow variables which is composed of an embedding layer, a bias for each input and output words." 513 | ] 514 | }, 515 | { 516 | "cell_type": "code", 517 | "execution_count": 9, 518 | "metadata": { 519 | "collapsed": true 520 | }, 521 | "outputs": [], 522 | "source": [ 523 | "# Variables.\n", 524 | "in_embeddings = tf.Variable(\n", 525 | " tf.random_uniform([vocabulary_size, embedding_size], -1.0, 1.0),name='embeddings')\n", 526 | "in_bias_embeddings = tf.Variable(tf.random_uniform([vocabulary_size],0.0,0.01,dtype=tf.float32),name='embeddings_bias')\n", 527 | "\n", 528 | "out_embeddings = tf.Variable(\n", 529 | " tf.random_uniform([vocabulary_size, embedding_size], -1.0, 1.0),name='embeddings')\n", 530 | "out_bias_embeddings = tf.Variable(tf.random_uniform([vocabulary_size],0.0,0.01,dtype=tf.float32),name='embeddings_bias')" 531 | ] 532 | }, 533 | { 534 | "cell_type": "markdown", 535 | "metadata": {}, 536 | "source": [ 537 | "### Defining the Model Computations\n", 538 | "\n", 539 | "We first defing a lookup function to fetch the corresponding embedding vectors for a set of given inputs. Then we define a placeholder that takes in the weights for a given batch of data points (`weights_x`) and co-occurence matrix weights (`x_ij`). `weights_x` measures the importance of a data point with respect to how much those two words co-occur and `x_ij` denotes the co-occurence matrix value for the row and column denoted by the words in a datapoint. With these defined, we can define the loss as shown below. For exact details refer Chapter 4 text." 540 | ] 541 | }, 542 | { 543 | "cell_type": "code", 544 | "execution_count": 10, 545 | "metadata": { 546 | "collapsed": true 547 | }, 548 | "outputs": [], 549 | "source": [ 550 | "# Look up embeddings for inputs and outputs\n", 551 | "# Have two seperate embedding vector spaces for inputs and outputs\n", 552 | "embed_in = tf.nn.embedding_lookup(in_embeddings, train_dataset)\n", 553 | "embed_out = tf.nn.embedding_lookup(out_embeddings, train_labels)\n", 554 | "embed_bias_in = tf.nn.embedding_lookup(in_bias_embeddings,train_dataset)\n", 555 | "embed_bias_out = tf.nn.embedding_lookup(out_bias_embeddings,train_labels)\n", 556 | "\n", 557 | "# weights used in the cost function\n", 558 | "weights_x = tf.placeholder(tf.float32,shape=[batch_size],name='weights_x') \n", 559 | "# Cooccurence value for that position\n", 560 | "x_ij = tf.placeholder(tf.float32,shape=[batch_size],name='x_ij')\n", 561 | "\n", 562 | "# Compute the loss defined in the paper. 
Note that \n", 563 | "# I'm not following the exact equation given (which is computing a pair of words at a time)\n", 564 | "# I'm calculating the loss for a batch at one time, but the calculations are identical.\n", 565 | "# I also made an assumption about the bias, that it is a smaller type of embedding\n", 566 | "loss = tf.reduce_mean(\n", 567 | " weights_x * (tf.reduce_sum(embed_in*embed_out,axis=1) + embed_bias_in + embed_bias_out - tf.log(epsilon+x_ij))**2)\n" 568 | ] 569 | }, 570 | { 571 | "cell_type": "markdown", 572 | "metadata": {}, 573 | "source": [ 574 | "### Calculating Word Similarities \n", 575 | "We calculate the similarity between two given words in terms of the cosine distance. To do this efficiently we use matrix operations to do so, as shown below." 576 | ] 577 | }, 578 | { 579 | "cell_type": "code", 580 | "execution_count": 11, 581 | "metadata": { 582 | "collapsed": true 583 | }, 584 | "outputs": [], 585 | "source": [ 586 | "# Compute the similarity between minibatch examples and all embeddings.\n", 587 | "# We use the cosine distance:\n", 588 | "embeddings = (in_embeddings + out_embeddings)/2.0\n", 589 | "norm = tf.sqrt(tf.reduce_sum(tf.square(embeddings), 1, keepdims=True))\n", 590 | "normalized_embeddings = embeddings / norm\n", 591 | "valid_embeddings = tf.nn.embedding_lookup(\n", 592 | "normalized_embeddings, valid_dataset)\n", 593 | "similarity = tf.matmul(valid_embeddings, tf.transpose(normalized_embeddings))" 594 | ] 595 | }, 596 | { 597 | "cell_type": "markdown", 598 | "metadata": {}, 599 | "source": [ 600 | "### Model Parameter Optimizer\n", 601 | "\n", 602 | "We then define a constant learning rate and an optimizer which uses the Adagrad method. Feel free to experiment with other optimizers listed [here](https://www.tensorflow.org/api_guides/python/train)." 603 | ] 604 | }, 605 | { 606 | "cell_type": "code", 607 | "execution_count": 12, 608 | "metadata": { 609 | "collapsed": true 610 | }, 611 | "outputs": [], 612 | "source": [ 613 | "# Optimizer.\n", 614 | "optimizer = tf.train.AdagradOptimizer(1.0).minimize(loss)" 615 | ] 616 | }, 617 | { 618 | "cell_type": "markdown", 619 | "metadata": {}, 620 | "source": [ 621 | "## Running the GloVe Algorithm\n", 622 | "\n", 623 | "Here we run the GloVe algorithm we defined above. Specifically, we first initialize variables, and then train the algorithm for many steps (`num_steps`). And every few steps we evaluate the algorithm on a fixed validation set and print out the words that appear to be closest for a given set of words." 
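
For reference, the loss built in the "Defining the Model Computations" cell above corresponds to the GloVe objective of Pennington et al. (2014), with the notebook's epsilon added inside the log for numerical stability. The weighting function matches the `point_weight` computation in the training loop below (cut-off 100, exponent 0.75); note the notebook averages these terms over a mini-batch rather than summing over all word pairs.

$$
J = \sum_{i,j} f(X_{ij})\left(w_i^{\top}\tilde{w}_j + b_i + \tilde{b}_j - \log(\epsilon + X_{ij})\right)^2,
\qquad
f(x) = \begin{cases}(x/x_{\max})^{0.75} & x < x_{\max}\\ 1 & \text{otherwise}\end{cases}
$$

Here $w_i, b_i$ are the input embedding and bias, $\tilde{w}_j, \tilde{b}_j$ the output embedding and bias, $X_{ij}$ the co-occurrence count, $x_{\max}=100$ and $\epsilon=1$ in this notebook.
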
624 | ] 625 | }, 626 | { 627 | "cell_type": "code", 628 | "execution_count": 13, 629 | "metadata": {}, 630 | "outputs": [ 631 | { 632 | "name": "stdout", 633 | "output_type": "stream", 634 | "text": [ 635 | "Initialized\n", 636 | "Average loss at step 0: 9.578778\n", 637 | "Nearest to it: karol, burgh, destabilise, armchair, crook, roguery, one-sixth, swains,\n", 638 | "Nearest to that: wmap, partake, ahmadi, armstrong, memberships, forza, director-general, condo,\n", 639 | "Nearest to has: mentality, vastly, approaches, bulwark, enzymes, originally, privatize, reunify,\n", 640 | "Nearest to but: inhabited, potrero, trust, memory, curran, philips, p.m.s, pagoda,\n", 641 | "Nearest to city: seals, counter-revolution, tubular, kayaking, central, 1568, override, buckland,\n", 642 | "Nearest to this: dispersion, intermarriage, dialysis, moguls, aldermen, alcoholic, codes, farallon,\n", 643 | "Nearest to UNK: 40.3, tatsam, jupiter, verify, unequal, berliners, march, 1559,\n", 644 | "Nearest to by: functionalists, synthesised, palladius, chiapas, synaptic, sumner, raining, valued,\n", 645 | "Nearest to or: amherst, 'mother, epiglottis, wen, stanislaus, trafford, cuticle, reminded,\n", 646 | "Nearest to been: 640,961., depression-era, uniquely, mami, 375,000, stickiness, medium-sized, amor,\n", 647 | "Nearest to with: anti-statist, pitigliano, branches, reparations, acquittal, frowned, pishpek, left-leaning,\n", 648 | "Nearest to be: i-20, kevin, greased, rightly, conductors, hypercholesterolemia, pedro, douaumont,\n", 649 | "Nearest to as: gabon, horda, mead, protruding, soundtrack, algeria, 48, macon,\n", 650 | "Nearest to at: kambula, tisa, spelled, 130,000, 2008, organisers, |jul_rec_lo_°f, arrows,\n", 651 | "Nearest to ,: is, of, its, malton, martinů, retiree, reliant, uri,\n", 652 | "Nearest to its: of, ,, galleon, gitlow, rugby-playing, varanasi, fono, clusters,\n", 653 | "Average loss at step 2000: 0.739107\n", 654 | "Average loss at step 4000: 0.091107\n", 655 | "Average loss at step 6000: 0.068614\n", 656 | "Average loss at step 8000: 0.076040\n", 657 | "Average loss at step 10000: 0.058149\n", 658 | "Nearest to it: was, is, that, not, a, in, to, .,\n", 659 | "Nearest to that: is, was, the, a, ., ,, to, in,\n", 660 | "Nearest to has: is, it, that, a, been, was, to, mentality,\n", 661 | "Nearest to but: with, said, trust, mating, not, squamous, war—the, r101,\n", 662 | "Nearest to city: of, 's, counter-revolution, the, professed, ., equilibrium, seals,\n", 663 | "Nearest to this: is, ., for, in, was, the, a, that,\n", 664 | "Nearest to UNK: and, ,, (, in, the, ., ), a,\n", 665 | "Nearest to by: the, and, ,, ., in, was, of, a,\n", 666 | "Nearest to or: UNK, ,, and, a, cuticle, donnchad, ``, 'mother,\n", 667 | "Nearest to been: have, had, to, has, be, was, that, it,\n", 668 | "Nearest to with: ,, and, a, the, in, of, for, .,\n", 669 | "Nearest to be: to, have, that, a, for, not, can, been,\n", 670 | "Nearest to as: a, ,, for, and, UNK, ``, is, in,\n", 671 | "Nearest to at: the, of, ., in, ,, and, 's, UNK,\n", 672 | "Nearest to ,: and, UNK, in, the, ., a, of, for,\n", 673 | "Nearest to its: compacted, for, puzzling, buddha, bjorn, d'etat, tēōtl, encapsulated,\n", 674 | "Average loss at step 12000: 0.048867\n", 675 | "Average loss at step 14000: 0.102374\n", 676 | "Average loss at step 16000: 0.047017\n", 677 | "Average loss at step 18000: 0.041279\n", 678 | "Average loss at step 20000: 0.065086\n", 679 | "Nearest to it: is, was, not, that, a, to, he, has,\n", 680 | "Nearest to that: is, 
was, the, a, to, ,, ., and,\n", 681 | "Nearest to has: it, been, was, a, is, that, to, .,\n", 682 | "Nearest to but: with, not, which, that, ,, said, mating, trust,\n", 683 | "Nearest to city: of, 's, the, ., in, counter-revolution, for, at,\n", 684 | "Nearest to this: ., is, was, for, in, the, it, of,\n", 685 | "Nearest to UNK: and, ,, (, ), in, the, a, .,\n", 686 | "Nearest to by: the, in, ,, and, was, ., of, a,\n", 687 | "Nearest to or: UNK, ,, a, and, (, ``, ), with,\n", 688 | "Nearest to been: have, had, has, be, to, was, that, not,\n", 689 | "Nearest to with: and, ,, a, of, the, in, for, UNK,\n", 690 | "Nearest to be: to, have, that, a, not, from, is, been,\n", 691 | "Nearest to as: a, ,, for, and, is, UNK, the, of,\n", 692 | "Nearest to at: the, ., of, in, 's, and, ,, by,\n", 693 | "Nearest to ,: and, UNK, in, the, a, ., of, with,\n", 694 | "Nearest to its: for, with, and, compacted, of, his, tēōtl, ,,\n", 695 | "Average loss at step 22000: 0.036469\n", 696 | "Average loss at step 24000: 0.037744\n", 697 | "Average loss at step 26000: 0.035548\n", 698 | "Average loss at step 28000: 0.035010\n", 699 | "Average loss at step 30000: 0.038970\n", 700 | "Nearest to it: is, was, not, that, has, he, a, this,\n", 701 | "Nearest to that: is, was, the, to, ,, ., a, and,\n", 702 | "Nearest to has: it, is, was, been, a, that, have, to,\n", 703 | "Nearest to but: which, with, not, ,, was, that, is, it,\n", 704 | "Nearest to city: of, 's, the, ., in, at, from, world,\n", 705 | "Nearest to this: is, ., was, in, for, it, the, a,\n", 706 | "Nearest to UNK: and, ,, (, ), in, ., the, a,\n", 707 | "Nearest to by: the, was, in, ,, and, ., of, a,\n", 708 | "Nearest to or: UNK, (, ,, ``, and, a, ), with,\n", 709 | "Nearest to been: have, had, has, be, to, was, that, not,\n", 710 | "Nearest to with: ,, and, a, of, the, in, ., UNK,\n", 711 | "Nearest to be: to, have, from, not, that, a, is, can,\n", 712 | "Nearest to as: a, ,, for, an, and, is, UNK, with,\n", 713 | "Nearest to at: the, of, ., in, 's, ,, and, UNK,\n", 714 | "Nearest to ,: and, UNK, in, the, ., a, with, of,\n", 715 | "Nearest to its: for, with, his, and, to, of, the, ,,\n", 716 | "Average loss at step 32000: 0.033023\n", 717 | "Average loss at step 34000: 0.031445\n", 718 | "Average loss at step 36000: 0.030053\n", 719 | "Average loss at step 38000: 0.028875\n", 720 | "Average loss at step 40000: 0.028649\n", 721 | "Nearest to it: is, was, not, that, has, a, also, he,\n", 722 | "Nearest to that: is, was, to, a, the, it, ,, .,\n", 723 | "Nearest to has: it, was, is, been, a, that, had, also,\n", 724 | "Nearest to but: which, not, with, that, ,, was, it, is,\n", 725 | "Nearest to city: 's, of, the, ., in, at, world, from,\n", 726 | "Nearest to this: is, ., was, in, for, it, the, a,\n", 727 | "Nearest to UNK: and, ,, ), (, in, the, ., a,\n", 728 | "Nearest to by: was, ,, the, in, and, ., of, to,\n", 729 | "Nearest to or: UNK, (, ,, a, ), and, ``, with,\n", 730 | "Nearest to been: have, had, has, be, to, was, not, that,\n", 731 | "Nearest to with: and, ,, a, of, the, in, for, UNK,\n", 732 | "Nearest to be: to, have, from, that, not, a, can, is,\n", 733 | "Nearest to as: a, ,, an, for, and, is, such, to,\n", 734 | "Nearest to at: the, ., of, in, 's, ,, by, and,\n", 735 | "Nearest to ,: and, UNK, in, the, a, ., with, of,\n", 736 | "Nearest to its: for, and, with, to, his, of, the, ,,\n", 737 | "Average loss at step 42000: 0.037198\n", 738 | "Average loss at step 44000: 0.027172\n", 739 | "Average loss at step 46000: 0.027344\n", 740 | "Average loss 
at step 48000: 0.028739\n", 741 | "Average loss at step 50000: 0.105829\n", 742 | "Nearest to it: is, was, that, not, has, he, also, a,\n", 743 | "Nearest to that: is, was, a, to, it, by, the, .,\n", 744 | "Nearest to has: it, is, been, was, that, a, have, had,\n", 745 | "Nearest to but: which, not, with, ,, it, that, was, is,\n", 746 | "Nearest to city: of, the, 's, ., in, from, at, is,\n", 747 | "Nearest to this: is, ., in, for, was, it, the, a,\n", 748 | "Nearest to UNK: (, and, ,, from, ., at, by, a,\n", 749 | "Nearest to by: the, ,, and, ., a, in, was, of,\n", 750 | "Nearest to or: ``, ), (, ,, a, with, and, '',\n", 751 | "Nearest to been: have, has, had, be, to, that, was, also,\n", 752 | "Nearest to with: and, ,, a, for, the, in, of, .,\n", 753 | "Nearest to be: to, have, not, from, a, that, is, can,\n", 754 | "Nearest to as: a, an, ,, for, is, and, to, such,\n", 755 | "Nearest to at: the, of, ., in, by, 's, ,, and,\n", 756 | "Nearest to ,: and, in, the, ., a, of, with, for,\n", 757 | "Nearest to its: for, with, and, ,, his, the, of, in,\n", 758 | "Average loss at step 52000: 0.111760\n", 759 | "Average loss at step 54000: 0.031062\n", 760 | "Average loss at step 56000: 0.070919\n", 761 | "Average loss at step 58000: 0.027815\n", 762 | "Average loss at step 60000: 0.025161\n", 763 | "Nearest to it: is, was, that, not, has, also, he, a,\n", 764 | "Nearest to that: is, was, it, to, the, ., a, by,\n", 765 | "Nearest to has: it, is, was, been, a, that, had, also,\n", 766 | "Nearest to but: which, not, with, ,, it, was, that, is,\n", 767 | "Nearest to city: 's, of, the, ., in, from, at, world,\n", 768 | "Nearest to this: is, ., was, in, it, for, the, at,\n", 769 | "Nearest to UNK: (, ), and, ,, the, a, by, .,\n", 770 | "Nearest to by: the, was, in, ., ,, and, of, is,\n", 771 | "Nearest to or: (, ), ,, a, ``, and, with, '',\n", 772 | "Nearest to been: have, has, had, be, was, also, that, to,\n", 773 | "Nearest to with: and, ,, a, of, in, for, the, .,\n", 774 | "Nearest to be: to, have, can, not, from, a, is, that,\n", 775 | "Nearest to as: a, an, ,, for, such, and, is, to,\n", 776 | "Nearest to at: the, ., of, in, 's, by, ,, is,\n", 777 | "Nearest to ,: and, in, the, ., a, with, of, for,\n", 778 | "Nearest to its: for, with, and, to, ,, of, the, his,\n", 779 | "Average loss at step 62000: 0.024341\n", 780 | "Average loss at step 64000: 0.024122\n", 781 | "Average loss at step 66000: 0.023625\n", 782 | "Average loss at step 68000: 0.023307\n", 783 | "Average loss at step 70000: 0.023168\n", 784 | "Nearest to it: is, was, not, that, has, also, a, he,\n", 785 | "Nearest to that: is, was, to, the, it, ., a, ,,\n", 786 | "Nearest to has: it, was, been, had, a, is, that, have,\n", 787 | "Nearest to but: which, not, with, ,, it, was, that, is,\n", 788 | "Nearest to city: of, 's, the, ., in, at, world, from,\n", 789 | "Nearest to this: is, ., was, in, it, for, the, at,\n", 790 | "Nearest to UNK: (, ), and, ,, by, or, the, in,\n", 791 | "Nearest to by: the, ,, was, ., and, in, of, a,\n", 792 | "Nearest to or: (, ), a, ,, and, ``, UNK, with,\n", 793 | "Nearest to been: have, has, had, be, was, also, to, that,\n", 794 | "Nearest to with: and, ,, a, the, in, of, for, .,\n", 795 | "Nearest to be: to, have, can, not, a, that, is, from,\n", 796 | "Nearest to as: a, an, for, ,, such, and, is, to,\n", 797 | "Nearest to at: the, of, ., in, 's, by, ,, is,\n", 798 | "Nearest to ,: and, in, the, ., a, of, with, for,\n", 799 | "Nearest to its: for, with, and, their, his, ,, of, the,\n" 800 | ] 801 | }, 802 | { 
803 | "name": "stdout", 804 | "output_type": "stream", 805 | "text": [ 806 | "Average loss at step 72000: 0.022413\n", 807 | "Average loss at step 74000: 0.021599\n", 808 | "Average loss at step 76000: 0.021968\n", 809 | "Average loss at step 78000: 0.021922\n", 810 | "Average loss at step 80000: 0.021073\n", 811 | "Nearest to it: is, was, not, that, also, has, a, this,\n", 812 | "Nearest to that: is, was, it, to, ., the, ,, a,\n", 813 | "Nearest to has: it, been, was, is, also, that, had, a,\n", 814 | "Nearest to but: which, not, with, it, ,, was, that, and,\n", 815 | "Nearest to city: of, 's, the, ., in, at, world, from,\n", 816 | "Nearest to this: is, ., was, in, it, for, at, the,\n", 817 | "Nearest to UNK: (, ), and, ,, or, a, ., by,\n", 818 | "Nearest to by: the, ,, was, and, ., in, a, of,\n", 819 | "Nearest to or: (, UNK, a, ), ,, and, ``, with,\n", 820 | "Nearest to been: have, has, had, also, be, to, that, was,\n", 821 | "Nearest to with: ,, and, a, the, of, in, for, .,\n", 822 | "Nearest to be: to, have, can, not, from, that, is, a,\n", 823 | "Nearest to as: a, an, ,, such, for, and, ., is,\n", 824 | "Nearest to at: of, the, ., in, 's, ,, and, by,\n", 825 | "Nearest to ,: and, in, the, ., a, with, of, for,\n", 826 | "Nearest to its: for, and, with, their, ,, his, to, the,\n", 827 | "Average loss at step 82000: 0.021116\n", 828 | "Average loss at step 84000: 0.020798\n", 829 | "Average loss at step 86000: 0.020017\n", 830 | "Average loss at step 88000: 0.019837\n", 831 | "Average loss at step 90000: 0.019543\n", 832 | "Nearest to it: is, was, that, also, not, has, this, a,\n", 833 | "Nearest to that: was, is, to, the, ., it, a, ,,\n", 834 | "Nearest to has: it, been, was, is, a, had, also, that,\n", 835 | "Nearest to but: which, not, ,, with, it, was, and, that,\n", 836 | "Nearest to city: of, 's, the, ., in, new, world, at,\n", 837 | "Nearest to this: is, ., was, it, in, for, the, at,\n", 838 | "Nearest to UNK: (, and, ), ,, in, or, a, .,\n", 839 | "Nearest to by: the, was, ,, in, ., and, a, of,\n", 840 | "Nearest to or: (, UNK, ), ``, a, ,, and, with,\n", 841 | "Nearest to been: have, has, had, also, be, that, was, to,\n", 842 | "Nearest to with: and, ,, a, the, of, in, for, .,\n", 843 | "Nearest to be: to, have, can, not, that, from, is, would,\n", 844 | "Nearest to as: a, an, ,, such, for, and, is, the,\n", 845 | "Nearest to at: of, the, ., in, 's, ,, and, by,\n", 846 | "Nearest to ,: and, in, the, ., a, with, of, UNK,\n", 847 | "Nearest to its: for, and, their, with, his, ,, the, of,\n", 848 | "Average loss at step 92000: 0.019305\n", 849 | "Average loss at step 94000: 0.019555\n", 850 | "Average loss at step 96000: 0.019266\n", 851 | "Average loss at step 98000: 0.018803\n", 852 | "Average loss at step 100000: 0.018488\n", 853 | "Nearest to it: is, was, also, that, not, has, this, a,\n", 854 | "Nearest to that: was, is, to, it, the, a, ., ,,\n", 855 | "Nearest to has: it, been, was, had, also, is, that, a,\n", 856 | "Nearest to but: which, not, ,, it, with, was, and, a,\n", 857 | "Nearest to city: of, 's, the, ., in, is, new, world,\n", 858 | "Nearest to this: is, ., was, it, in, for, the, at,\n", 859 | "Nearest to UNK: (, and, ), ,, or, a, the, .,\n", 860 | "Nearest to by: the, ., was, ,, and, in, of, a,\n", 861 | "Nearest to or: UNK, (, ``, a, ), ,, and, with,\n", 862 | "Nearest to been: have, has, had, also, be, was, that, to,\n", 863 | "Nearest to with: and, ,, a, the, of, in, for, .,\n", 864 | "Nearest to be: to, have, can, not, would, from, that, a,\n", 865 | "Nearest 
to as: a, such, an, ,, for, is, and, to,\n", 866 | "Nearest to at: of, ., the, in, 's, by, ,, and,\n", 867 | "Nearest to ,: and, in, the, ., a, with, UNK, of,\n", 868 | "Nearest to its: for, their, and, with, his, ,, to, the,\n" 869 | ] 870 | } 871 | ], 872 | "source": [ 873 | "num_steps = 100001\n", 874 | "glove_loss = []\n", 875 | "\n", 876 | "average_loss = 0\n", 877 | "with tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) as session:\n", 878 | " \n", 879 | " tf.global_variables_initializer().run()\n", 880 | " print('Initialized')\n", 881 | " \n", 882 | " for step in range(num_steps):\n", 883 | " \n", 884 | " # generate a single batch (data,labels,co-occurance weights)\n", 885 | " batch_data, batch_labels, batch_weights = generate_batch(\n", 886 | " batch_size, skip_window) \n", 887 | " \n", 888 | " # Computing the weights required by the loss function\n", 889 | " batch_weights = [] # weighting used in the loss function\n", 890 | " batch_xij = [] # weighted frequency of finding i near j\n", 891 | " \n", 892 | " # Compute the weights for each datapoint in the batch\n", 893 | " for inp,lbl in zip(batch_data,batch_labels.reshape(-1)): \n", 894 | " point_weight = (cooc_mat[inp,lbl]/100.0)**0.75 if cooc_mat[inp,lbl]<100.0 else 1.0 \n", 895 | " batch_weights.append(point_weight)\n", 896 | " batch_xij.append(cooc_mat[inp,lbl])\n", 897 | " batch_weights = np.clip(batch_weights,-100,1)\n", 898 | " batch_xij = np.asarray(batch_xij)\n", 899 | " \n", 900 | " # Populate the feed_dict and run the optimizer (minimize loss)\n", 901 | " # and compute the loss. Specifically we provide\n", 902 | " # train_dataset/train_labels: training inputs and training labels\n", 903 | " # weights_x: measures the importance of a data point with respect to how much those two words co-occur\n", 904 | " # x_ij: co-occurence matrix value for the row and column denoted by the words in a datapoint\n", 905 | " feed_dict = {train_dataset : batch_data.reshape(-1), train_labels : batch_labels.reshape(-1),\n", 906 | " weights_x:batch_weights,x_ij:batch_xij}\n", 907 | " _, l = session.run([optimizer, loss], feed_dict=feed_dict)\n", 908 | " \n", 909 | " # Update the average loss variable\n", 910 | " average_loss += l\n", 911 | " if step % 2000 == 0:\n", 912 | " if step > 0:\n", 913 | " average_loss = average_loss / 2000\n", 914 | " # The average loss is an estimate of the loss over the last 2000 batches.\n", 915 | " print('Average loss at step %d: %f' % (step, average_loss))\n", 916 | " glove_loss.append(average_loss)\n", 917 | " average_loss = 0\n", 918 | " \n", 919 | " # Here we compute the top_k closest words for a given validation word\n", 920 | " # in terms of the cosine distance\n", 921 | " # We do this for all the words in the validation set\n", 922 | " # Note: This is an expensive step\n", 923 | " if step % 10000 == 0:\n", 924 | " sim = similarity.eval()\n", 925 | " for i in range(valid_size):\n", 926 | " valid_word = reverse_dictionary[valid_examples[i]]\n", 927 | " top_k = 8 # number of nearest neighbors\n", 928 | " nearest = (-sim[i, :]).argsort()[1:top_k+1]\n", 929 | " log = 'Nearest to %s:' % valid_word\n", 930 | " for k in range(top_k):\n", 931 | " close_word = reverse_dictionary[nearest[k]]\n", 932 | " log = '%s %s,' % (log, close_word)\n", 933 | " print(log)\n", 934 | " \n", 935 | " final_embeddings = normalized_embeddings.eval()\n" 936 | ] 937 | }, 938 | { 939 | "cell_type": "code", 940 | "execution_count": null, 941 | "metadata": { 942 | "collapsed": true 943 | }, 944 | "outputs": [], 945 | "source": 
[] 946 | } 947 | ], 948 | "metadata": { 949 | "kernelspec": { 950 | "display_name": "Python 3", 951 | "language": "python", 952 | "name": "python3" 953 | }, 954 | "language_info": { 955 | "codemirror_mode": { 956 | "name": "ipython", 957 | "version": 3 958 | }, 959 | "file_extension": ".py", 960 | "mimetype": "text/x-python", 961 | "name": "python", 962 | "nbconvert_exporter": "python", 963 | "pygments_lexer": "ipython3", 964 | "version": "3.5.2" 965 | } 966 | }, 967 | "nbformat": 4, 968 | "nbformat_minor": 2 969 | } 970 | -------------------------------------------------------------------------------- /ch4/ch4_word2vec_extended.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Extended Word2vec and GloVe" 8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": 1, 13 | "metadata": {}, 14 | "outputs": [ 15 | { 16 | "name": "stderr", 17 | "output_type": "stream", 18 | "text": [ 19 | "c:\\users\\thushan\\documents\\python_virtualenvs\\tensorflow_venv\\lib\\site-packages\\h5py\\__init__.py:36: FutureWarning: Conversion of the second argument of issubdtype from `float` to `np.floating` is deprecated. In future, it will be treated as `np.float64 == np.dtype(float).type`.\n", 20 | " from ._conv import register_converters as _register_converters\n" 21 | ] 22 | } 23 | ], 24 | "source": [ 25 | "# These are all the modules we'll be using later. Make sure you can import them\n", 26 | "# before proceeding further.\n", 27 | "%matplotlib inline\n", 28 | "from __future__ import print_function\n", 29 | "import collections\n", 30 | "import math\n", 31 | "import numpy as np\n", 32 | "import os\n", 33 | "import random\n", 34 | "import tensorflow as tf\n", 35 | "import bz2\n", 36 | "from matplotlib import pylab\n", 37 | "from six.moves import range\n", 38 | "from six.moves.urllib.request import urlretrieve\n", 39 | "from sklearn.manifold import TSNE\n", 40 | "from sklearn.cluster import KMeans\n", 41 | "import nltk # standard preprocessing\n", 42 | "import operator # sorting items in dictionary by value\n", 43 | "#nltk.download() #tokenizers/punkt/PY3/english.pickle\n", 44 | "from math import ceil\n", 45 | "import csv" 46 | ] 47 | }, 48 | { 49 | "cell_type": "markdown", 50 | "metadata": {}, 51 | "source": [ 52 | "## Dataset\n", 53 | "This code downloads a [dataset](http://www.evanjones.ca/software/wikipedia2text.html) consisting of several Wikipedia articles totaling up to roughly 61 megabytes. Additionally the code makes sure the file has the correct size after downloading it." 54 | ] 55 | }, 56 | { 57 | "cell_type": "code", 58 | "execution_count": 2, 59 | "metadata": {}, 60 | "outputs": [ 61 | { 62 | "name": "stdout", 63 | "output_type": "stream", 64 | "text": [ 65 | "Found and verified wikipedia2text-extracted.txt.bz2\n" 66 | ] 67 | } 68 | ], 69 | "source": [ 70 | "url = 'http://www.evanjones.ca/software/'\n", 71 | "\n", 72 | "def maybe_download(filename, expected_bytes):\n", 73 | " \"\"\"Download a file if not present, and make sure it's the right size.\"\"\"\n", 74 | " if not os.path.exists(filename):\n", 75 | " filename, _ = urlretrieve(url + filename, filename)\n", 76 | " statinfo = os.stat(filename)\n", 77 | " if statinfo.st_size == expected_bytes:\n", 78 | " print('Found and verified %s' % filename)\n", 79 | " else:\n", 80 | " print(statinfo.st_size)\n", 81 | " raise Exception(\n", 82 | " 'Failed to verify ' + filename + '. 
Can you get to it with a browser?')\n", 83 | " return filename\n", 84 | "\n", 85 | "filename = maybe_download('wikipedia2text-extracted.txt.bz2', 18377035)" 86 | ] 87 | }, 88 | { 89 | "cell_type": "markdown", 90 | "metadata": {}, 91 | "source": [ 92 | "## Read Data with Preprocessing with NLTK\n", 93 | "Reads data as it is to a string, convert to lower-case and tokenize it using the nltk library. This code reads data in 1MB portions as processing the full text at once slows down the task and returns a list of words" 94 | ] 95 | }, 96 | { 97 | "cell_type": "code", 98 | "execution_count": 3, 99 | "metadata": {}, 100 | "outputs": [ 101 | { 102 | "name": "stdout", 103 | "output_type": "stream", 104 | "text": [ 105 | "Reading data...\n", 106 | "Data size 3360286\n", 107 | "Example words (start): ['propaganda', 'is', 'a', 'concerted', 'set', 'of', 'messages', 'aimed', 'at', 'influencing']\n", 108 | "Example words (end): ['favorable', 'long-term', 'outcomes', 'for', 'around', 'half', 'of', 'those', 'diagnosed', 'with']\n" 109 | ] 110 | } 111 | ], 112 | "source": [ 113 | "def read_data(filename):\n", 114 | " \"\"\"\n", 115 | " Extract the first file enclosed in a zip file as a list of words\n", 116 | " and pre-processes it using the nltk python library\n", 117 | " \"\"\"\n", 118 | "\n", 119 | " with bz2.BZ2File(filename) as f:\n", 120 | "\n", 121 | " data = []\n", 122 | " file_size = os.stat(filename).st_size\n", 123 | " chunk_size = 1024 * 1024 # reading 1 MB at a time as the dataset is moderately large\n", 124 | " print('Reading data...')\n", 125 | " for i in range(ceil(file_size//chunk_size)+1):\n", 126 | " bytes_to_read = min(chunk_size,file_size-(i*chunk_size))\n", 127 | " file_string = f.read(bytes_to_read).decode('utf-8')\n", 128 | " file_string = file_string.lower()\n", 129 | " # tokenizes a string to words residing in a list\n", 130 | " file_string = nltk.word_tokenize(file_string)\n", 131 | " data.extend(file_string)\n", 132 | " return data\n", 133 | "\n", 134 | "words = read_data(filename)\n", 135 | "print('Data size %d' % len(words))\n", 136 | "token_count = len(words)\n", 137 | "\n", 138 | "print('Example words (start): ',words[:10])\n", 139 | "print('Example words (end): ',words[-10:])" 140 | ] 141 | }, 142 | { 143 | "cell_type": "markdown", 144 | "metadata": {}, 145 | "source": [ 146 | "## Building the Dictionaries\n", 147 | "Builds the following. To understand each of these elements, let us also assume the text \"I like to go to school\"\n", 148 | "\n", 149 | "* `dictionary`: maps a string word to an ID (e.g. {I:0, like:1, to:2, go:3, school:4})\n", 150 | "* `reverse_dictionary`: maps an ID to a string word (e.g. {0:I, 1:like, 2:to, 3:go, 4:school}\n", 151 | "* `count`: List of list of (word, frequency) elements (e.g. [(I,1),(like,1),(to,2),(go,1),(school,1)]\n", 152 | "* `data` : Contain the string of text we read, where string words are replaced with word IDs (e.g. [0, 1, 2, 3, 2, 4])\n", 153 | "\n", 154 | "It also introduces an additional special token `UNK` to denote rare words to are too rare to make use of." 
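
To make the four structures above concrete, here is a tiny self-contained run of the same logic on the toy sentence from the description. Only the toy input and the reduced vocabulary size are invented; the exact ordering of tied counts may vary, hence the "e.g." in the comments.

```
# Toy version of build_dataset on "I like to go to school".
import collections

words = ['i', 'like', 'to', 'go', 'to', 'school']
vocabulary_size = 6  # UNK plus the five distinct toy words

count = [['UNK', -1]]
count.extend(collections.Counter(words).most_common(vocabulary_size - 1))
dictionary = {word: idx for idx, (word, _) in enumerate(count)}
reverse_dictionary = {idx: word for word, idx in dictionary.items()}
data = [dictionary.get(word, 0) for word in words]

print(dictionary)             # e.g. {'UNK': 0, 'to': 1, 'i': 2, 'like': 3, 'go': 4, 'school': 5}
print(data)                   # e.g. [2, 3, 1, 4, 1, 5]
print(reverse_dictionary[1])  # e.g. 'to'
```
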
155 | ] 156 | }, 157 | { 158 | "cell_type": "code", 159 | "execution_count": 4, 160 | "metadata": {}, 161 | "outputs": [ 162 | { 163 | "name": "stdout", 164 | "output_type": "stream", 165 | "text": [ 166 | "Most common words (+UNK) [['UNK', 69215], ('the', 226881), (',', 184013), ('.', 120944), ('of', 116323)]\n", 167 | "Sample data [1728, 9, 8, 17488, 223, 4, 5211, 4461, 26, 11637]\n" 168 | ] 169 | } 170 | ], 171 | "source": [ 172 | "\n", 173 | "vocabulary_size = 50000\n", 174 | "\n", 175 | "def build_dataset(words):\n", 176 | " count = [['UNK', -1]]\n", 177 | " count.extend(collections.Counter(words).most_common(vocabulary_size - 1))\n", 178 | " dictionary = dict()\n", 179 | " for word, _ in count:\n", 180 | " dictionary[word] = len(dictionary)\n", 181 | " data = list()\n", 182 | " unk_count = 0\n", 183 | " for word in words:\n", 184 | " if word in dictionary:\n", 185 | " index = dictionary[word]\n", 186 | " else:\n", 187 | " index = 0 # dictionary['UNK']\n", 188 | " unk_count = unk_count + 1\n", 189 | " data.append(index)\n", 190 | " count[0][1] = unk_count\n", 191 | " reverse_dictionary = dict(zip(dictionary.values(), dictionary.keys())) \n", 192 | " assert len(dictionary) == vocabulary_size\n", 193 | " return data, count, dictionary, reverse_dictionary\n", 194 | "\n", 195 | "data, count, dictionary, reverse_dictionary = build_dataset(words)\n", 196 | "print('Most common words (+UNK)', count[:5])\n", 197 | "print('Sample data', data[:10])\n", 198 | "del words # Hint to reduce memory." 199 | ] 200 | }, 201 | { 202 | "cell_type": "code", 203 | "execution_count": 5, 204 | "metadata": {}, 205 | "outputs": [ 206 | { 207 | "name": "stdout", 208 | "output_type": "stream", 209 | "text": [ 210 | "data: ['propaganda', 'is', 'a', 'concerted', 'set', 'of', 'messages', 'aimed']\n", 211 | "\n", 212 | "with window_size = 1:\n", 213 | " batch: ['is', 'a', 'concerted', 'set', 'of', 'messages', 'aimed', 'at']\n", 214 | " labels: [['propaganda', 'a'], ['is', 'concerted'], ['a', 'set'], ['concerted', 'of'], ['set', 'messages'], ['of', 'aimed'], ['messages', 'at'], ['aimed', 'influencing']]\n", 215 | "\n", 216 | "with window_size = 2:\n", 217 | " batch: ['a', 'concerted', 'set', 'of', 'messages', 'aimed', 'at', 'influencing']\n", 218 | " labels: [['propaganda', 'is', 'concerted', 'set'], ['is', 'a', 'set', 'of'], ['a', 'concerted', 'of', 'messages'], ['concerted', 'set', 'messages', 'aimed'], ['set', 'of', 'aimed', 'at'], ['of', 'messages', 'at', 'influencing'], ['messages', 'aimed', 'influencing', 'the'], ['aimed', 'at', 'the', 'opinions']]\n" 219 | ] 220 | } 221 | ], 222 | "source": [ 223 | "data_index = 0\n", 224 | "\n", 225 | "def generate_batch(batch_size, window_size):\n", 226 | " global data_index\n", 227 | "\n", 228 | " # two numpy arras to hold target words (batch)\n", 229 | " # and context words (labels)\n", 230 | " # Note that the labels array has 2*window_size columns\n", 231 | " batch = np.ndarray(shape=(batch_size), dtype=np.int32)\n", 232 | " labels = np.ndarray(shape=(batch_size, 2*window_size), dtype=np.int32)\n", 233 | " \n", 234 | " # span defines the total window size, where\n", 235 | " # data we consider at an instance looks as follows. 
\n", 236 | " # [ skip_window target skip_window ]\n", 237 | " span = 2 * window_size + 1 # [ skip_window target skip_window ]\n", 238 | " \n", 239 | " buffer = collections.deque(maxlen=span)\n", 240 | " \n", 241 | " # Fill the buffer and update the data_index\n", 242 | " for _ in range(span):\n", 243 | " buffer.append(data[data_index])\n", 244 | " data_index = (data_index + 1) % len(data)\n", 245 | " \n", 246 | " # for a full length of batch size, we do the following\n", 247 | " # make the target word the i th input word (i th row of batch)\n", 248 | " # make all the context words the columns of labels array\n", 249 | " # Update the data index and the buffer \n", 250 | " for i in range(batch_size):\n", 251 | " batch[i] = buffer[window_size]\n", 252 | " labels[i, :] = [buffer[span_idx] for span_idx in list(range(0,window_size))+ list(range(window_size+1,span))]\n", 253 | " buffer.append(data[data_index])\n", 254 | " data_index = (data_index + 1) % len(data)\n", 255 | " \n", 256 | " return batch, labels\n", 257 | "\n", 258 | "print('data:', [reverse_dictionary[di] for di in data[:8]])\n", 259 | "\n", 260 | "for window_size in [1,2]:\n", 261 | " data_index = 0\n", 262 | " batch, labels = generate_batch(batch_size=8, window_size=window_size)\n", 263 | " print('\\nwith window_size = %d:' % window_size)\n", 264 | " print(' batch:', [reverse_dictionary[bi] for bi in batch])\n", 265 | " print(' labels:', [[reverse_dictionary[li] for li in lbls] for lbls in labels])" 266 | ] 267 | }, 268 | { 269 | "cell_type": "markdown", 270 | "metadata": {}, 271 | "source": [ 272 | "# Structured Skip-Gram Algorithm\n", 273 | "The basic idea behind the structured skip-gram algorithm is to pay attention to the position of the context words during learning. Giving the algorithm the power to distinguish between words falling very close to the target word and the ones that fall far away from the context words allow the structured skip-gram model to learn better word vectors ([Paper](http://www.cs.cmu.edu/~lingwang/papers/naacl2015.pdf)). You can learn about this algorithm in more detail in Chapter 4 text." 274 | ] 275 | }, 276 | { 277 | "cell_type": "markdown", 278 | "metadata": {}, 279 | "source": [ 280 | "### Defining Hyperparameters\n", 281 | "\n", 282 | "Here we define several hyperparameters including `batch_size` (amount of samples in a single batch) `embedding_size` (size of embedding vectors) `window_size` (context window size)." 
283 | ] 284 | }, 285 | { 286 | "cell_type": "code", 287 | "execution_count": 6, 288 | "metadata": { 289 | "collapsed": true 290 | }, 291 | "outputs": [], 292 | "source": [ 293 | "batch_size = 128 # Data points in a single batch\n", 294 | "embedding_size = 128 # Dimension of the embedding vector.\n", 295 | "window_size = 2 # How many words to consider left and right.\n", 296 | "\n", 297 | "# We pick a random validation set to sample nearest neighbors\n", 298 | "valid_size = 16 # Random set of words to evaluate similarity on.\n", 299 | "# We sample valid datapoints randomly from a large window without always being deterministic\n", 300 | "valid_window = 50\n", 301 | "\n", 302 | "# When selecting valid examples, we select some of the most frequent words as well as\n", 303 | "# some moderately rare words as well\n", 304 | "valid_examples = np.array(random.sample(range(valid_window), valid_size))\n", 305 | "valid_examples = np.append(valid_examples,random.sample(range(1000, 1000+valid_window), valid_size),axis=0)\n", 306 | "\n", 307 | "num_sampled = 32 # Number of negative examples to sample." 308 | ] 309 | }, 310 | { 311 | "cell_type": "markdown", 312 | "metadata": {}, 313 | "source": [ 314 | "### Defining Inputs and Outputs\n", 315 | "\n", 316 | "Here we define placeholders for feeding in training inputs and outputs (each of size `batch_size`) and a constant tensor to contain validation examples." 317 | ] 318 | }, 319 | { 320 | "cell_type": "code", 321 | "execution_count": 7, 322 | "metadata": { 323 | "collapsed": true 324 | }, 325 | "outputs": [], 326 | "source": [ 327 | "tf.reset_default_graph()\n", 328 | "\n", 329 | "# Training input data (target word IDs).\n", 330 | "train_dataset = tf.placeholder(tf.int32, shape=[batch_size])\n", 331 | "# Training input label data (context word IDs)\n", 332 | "train_labels = [tf.placeholder(tf.int32, shape=[batch_size, 1]) for _ in range(2*window_size)]\n", 333 | "# Validation input data, we don't need a placeholder\n", 334 | "# as we have already defined the IDs of the words selected\n", 335 | "# as validation data\n", 336 | "valid_dataset = tf.constant(valid_examples, dtype=tf.int32)" 337 | ] 338 | }, 339 | { 340 | "cell_type": "markdown", 341 | "metadata": {}, 342 | "source": [ 343 | "### Defining Model Parameters and Other Variables\n", 344 | "We now define several TensorFlow variables such as an embedding layer (`embeddings`) and neural network parameters (`softmax_weights` and `softmax_biases`). Note that the softmax weights is `2*window_size` larger than the original skip-gram algorithms's softmax weights." 345 | ] 346 | }, 347 | { 348 | "cell_type": "code", 349 | "execution_count": 8, 350 | "metadata": { 351 | "collapsed": true 352 | }, 353 | "outputs": [], 354 | "source": [ 355 | "embeddings = tf.Variable(\n", 356 | "tf.random_uniform([vocabulary_size, embedding_size], -1.0, 1.0))\n", 357 | "softmax_weights = [tf.Variable(\n", 358 | "tf.truncated_normal([vocabulary_size, embedding_size],\n", 359 | " stddev=0.5 / math.sqrt(embedding_size))) for _ in range(2*window_size)]\n", 360 | "softmax_biases = [tf.Variable(tf.random_uniform([vocabulary_size],0.0,0.01)) for _ in range(2*window_size)]\n" 361 | ] 362 | }, 363 | { 364 | "cell_type": "markdown", 365 | "metadata": {}, 366 | "source": [ 367 | "### Defining the Model Computations\n", 368 | "\n", 369 | "We first defing a lookup function to fetch the corresponding embedding vectors for a set of given inputs. 
With that, we define negative sampling loss function `tf.nn.sampled_softmax_loss` which takes in the embedding vectors and previously defined neural network parameters." 370 | ] 371 | }, 372 | { 373 | "cell_type": "code", 374 | "execution_count": 9, 375 | "metadata": {}, 376 | "outputs": [ 377 | { 378 | "name": "stdout", 379 | "output_type": "stream", 380 | "text": [ 381 | "WARNING:tensorflow:From c:\\users\\thushan\\documents\\python_virtualenvs\\tensorflow_venv\\lib\\site-packages\\tensorflow\\python\\ops\\nn_impl.py:1344: softmax_cross_entropy_with_logits (from tensorflow.python.ops.nn_ops) is deprecated and will be removed in a future version.\n", 382 | "Instructions for updating:\n", 383 | "\n", 384 | "Future major versions of TensorFlow will allow gradients to flow\n", 385 | "into the labels input on backprop by default.\n", 386 | "\n", 387 | "See @{tf.nn.softmax_cross_entropy_with_logits_v2}.\n", 388 | "\n" 389 | ] 390 | } 391 | ], 392 | "source": [ 393 | "\n", 394 | "# Model.\n", 395 | "# Look up embeddings for inputs.\n", 396 | "embed = tf.nn.embedding_lookup(embeddings, train_dataset)\n", 397 | "\n", 398 | "# You might see the warning when running the line below\n", 399 | "# WARNING:tensorflow:From c:\\...\\lib\\site-packages\\tensorflow\\python\\ops\\nn_impl.py:1346: \n", 400 | "#softmax_cross_entropy_with_logits (from tensorflow.python.ops.nn_ops) is deprecated and \n", 401 | "# will be removed in a future version.\n", 402 | "# This is due to the sampled_softmax_loss function using a deprecated function internally\n", 403 | "# therefore, this is not an error in the code and you can ignore this error\n", 404 | "\n", 405 | "# Compute the softmax loss, using a sample of the negative labels each time.\n", 406 | "loss = tf.reduce_sum(\n", 407 | "[\n", 408 | " tf.reduce_mean(tf.nn.sampled_softmax_loss(weights=softmax_weights[wi], biases=softmax_biases[wi], inputs=embed,\n", 409 | " labels=train_labels[wi], num_sampled=num_sampled, num_classes=vocabulary_size))\n", 410 | " for wi in range(window_size*2)\n", 411 | "]\n", 412 | ")\n", 413 | "\n" 414 | ] 415 | }, 416 | { 417 | "cell_type": "markdown", 418 | "metadata": {}, 419 | "source": [ 420 | "### Calculating Word Similarities \n", 421 | "We calculate the similarity between two given words in terms of the cosine distance. To do this efficiently we use matrix operations to do so, as shown below." 422 | ] 423 | }, 424 | { 425 | "cell_type": "code", 426 | "execution_count": 10, 427 | "metadata": { 428 | "collapsed": true 429 | }, 430 | "outputs": [], 431 | "source": [ 432 | "# Compute the similarity between minibatch examples and all embeddings.\n", 433 | "# We use the cosine distance:\n", 434 | "norm = tf.sqrt(tf.reduce_sum(tf.square(embeddings), 1, keepdims=True))\n", 435 | "normalized_embeddings = embeddings / norm\n", 436 | "valid_embeddings = tf.nn.embedding_lookup(\n", 437 | "normalized_embeddings, valid_dataset)\n", 438 | "similarity = tf.matmul(valid_embeddings, tf.transpose(normalized_embeddings))\n" 439 | ] 440 | }, 441 | { 442 | "cell_type": "markdown", 443 | "metadata": {}, 444 | "source": [ 445 | "### Model Parameter Optimizer\n", 446 | "\n", 447 | "We then define a constant learning rate and an optimizer which uses the Adagrad method. Feel free to experiment with other optimizers listed [here](https://www.tensorflow.org/api_guides/python/train)." 
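
The note above invites experimenting with other optimizers. As one hedged illustration (not what the notebook uses), the Adagrad line in the next cell could be swapped for Adam or RMSProp with a smaller learning rate; the toy quadratic loss below is only there to keep the snippet self-contained.

```
# Illustration of swapping the optimizer; the notebook itself uses tf.train.AdagradOptimizer(1.0).
import tensorflow as tf

tf.reset_default_graph()
w = tf.Variable(5.0)
dummy_loss = tf.square(w - 1.0)  # stand-in for the notebook's sampled softmax loss

optimizer = tf.train.AdamOptimizer(learning_rate=0.001).minimize(dummy_loss)
# optimizer = tf.train.RMSPropOptimizer(learning_rate=0.001).minimize(dummy_loss)
```
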
448 | ] 449 | }, 450 | { 451 | "cell_type": "code", 452 | "execution_count": 11, 453 | "metadata": { 454 | "collapsed": true 455 | }, 456 | "outputs": [], 457 | "source": [ 458 | "\n", 459 | "# Optimizer.\n", 460 | "optimizer = tf.train.AdagradOptimizer(1.0).minimize(loss)\n" 461 | ] 462 | }, 463 | { 464 | "cell_type": "markdown", 465 | "metadata": {}, 466 | "source": [ 467 | "## Running the Structured Skip-gram Algorithm\n", 468 | "\n", 469 | "Here we run the structured skip-gram algorithm we defined above. Specifically, we first initialize variables, and then train the algorithm for many steps (`num_steps`). And every few steps we evaluate the algorithm on a fixed validation set and print out the words that appear to be closest for a given set of words." 470 | ] 471 | }, 472 | { 473 | "cell_type": "code", 474 | "execution_count": 12, 475 | "metadata": { 476 | "scrolled": true 477 | }, 478 | "outputs": [ 479 | { 480 | "name": "stdout", 481 | "output_type": "stream", 482 | "text": [ 483 | "Initialized\n", 484 | "Average loss at step 2000: 14.825290\n", 485 | "Average loss at step 4000: 12.924444\n", 486 | "Average loss at step 6000: 12.492212\n", 487 | "Average loss at step 8000: 12.194010\n", 488 | "Average loss at step 10000: 11.985265\n", 489 | "Nearest to .: ;, ,, of, :, and, reclassify, '', in,\n", 490 | "Nearest to which: but, that, who, and, it, what, where, then,\n", 491 | "Nearest to an: a, the, chong, its, constrained, rockwell, spartan, cigars,\n", 492 | "Nearest to as: creating, by, 29.9, kunda, ravens, diracodon, including, attractive,\n", 493 | "Nearest to be: been, have, being, daly, was, spectroscopic, often, were,\n", 494 | "Nearest to first: last, second, next, only, same, main, original, late,\n", 495 | "Nearest to ,: ;, ., (, and, :, of, —, ''protecteur,\n", 496 | "Nearest to (: ;, ,, na+, 30.1, travis, :, per, dram,\n", 497 | "Nearest to from: in, into, across, by, at, resident, lcd, between,\n", 498 | "Nearest to for: soo, of, over, follower, bien, among, inequalities, introductions,\n", 499 | "Nearest to ;: ., ,, :, (, —, ..., magill, one-man,\n", 500 | "Nearest to have: had, has, were, are, be, 7-6, year.the, make,\n", 501 | "Nearest to UNK: artificially, postulated, disasters, cooling, tselinograd, indefinite, enthusiastic, mass-marketed,\n", 502 | "Nearest to or: and, centred, hematoma, allowing, preeminent, than, resident, tubulin,\n", 503 | "Nearest to :: ;, ., dexter, (, stallion, methodologies, ,, une,\n", 504 | "Nearest to the: a, its, their, his, any, this, an, another,\n", 505 | "Average loss at step 12000: 11.888148\n", 506 | "Average loss at step 14000: 11.728705\n", 507 | "Average loss at step 16000: 11.675454\n", 508 | "Average loss at step 18000: 11.627067\n", 509 | "Average loss at step 20000: 11.579372\n", 510 | "Nearest to .: ;, ,, :, —, and, of, (, pear-shaped,\n", 511 | "Nearest to which: that, but, who, where, and, can, this, ecstasy,\n", 512 | "Nearest to an: a, the, rockwell, irwin, spartan, reclassify, oyo, corrugated,\n", 513 | "Nearest to as: kunda, called, yorkville, ordinances, attractive, creating, diracodon, revitalized,\n", 514 | "Nearest to be: been, being, have, become, were, write, amnh, clearly,\n", 515 | "Nearest to first: last, second, next, only, final, original, drury, main,\n", 516 | "Nearest to ,: ;, ., —, -, :, (, and, turkish-cypriot,\n", 517 | "Nearest to (: ;, -, ,, —, or, per, travis, :,\n", 518 | "Nearest to from: across, into, in, by, between, towards, through, on,\n", 519 | "Nearest to for: during, bien, yamaha, with, 
within, after, soo, among,\n", 520 | "Nearest to ;: ., ,, :, —, (, consumes, -, censured,\n", 521 | "Nearest to have: had, has, were, be, having, are, martyred, preserve,\n", 522 | "Nearest to UNK: tselinograd, 10,000, 1835., 4, quilting, r, jewellery, deadline,\n", 523 | "Nearest to or: and, (, buddhist, hematoma, centred, preeminent, retain, angélil,\n", 524 | "Nearest to :: ., ;, ,, —, pan-slavic, consumes, resorts, reunify,\n", 525 | "Nearest to the: its, a, their, his, an, her, typing, paroled,\n", 526 | "Average loss at step 22000: 11.470342\n", 527 | "Average loss at step 24000: 11.460939\n", 528 | "Average loss at step 26000: 11.296547\n", 529 | "Average loss at step 28000: 10.891410\n", 530 | "Average loss at step 30000: 10.739479\n", 531 | "Nearest to .: ;, ,, alfa, —, --, outcroppings, conventions, shaved,\n", 532 | "Nearest to which: who, that, but, and, where, whom, shoppers, repeal,\n", 533 | "Nearest to an: spartan, rockwell, pablo, cheadle, novgorodians, irwin, resnick, corrugated,\n", 534 | "Nearest to as: creating, yorkville, kunda, sài, f-75, including, retains, rossini,\n", 535 | "Nearest to be: been, being, become, spectroscopic, clearly, remain, write, kieft,\n", 536 | "Nearest to first: last, next, second, final, best, only, third, highest,\n", 537 | "Nearest to ,: ;, ., —, -, hazardous, –, ''protecteur, inflicted,\n", 538 | "Nearest to (: -, na+, travis, =, quis, 4-d, preemptive, sarawak,\n", 539 | "Nearest to from: across, in, longevity, overcome, snoop, via, missile, panicked,\n", 540 | "Nearest to for: bien, water…and, nasty, hectare, of, goryeo, keller, −300,\n", 541 | "Nearest to ;: ., ,, —, -, --, one-man, :, mucous,\n", 542 | "Nearest to have: had, has, having, rely, are, western-style, year.the, were,\n", 543 | "Nearest to UNK: blue, tarnów, monkees, tselinograd, silent, 1.1, artificially, 300,\n", 544 | "Nearest to or: and, preeminent, formatted, landmass, langston, morton, erysipelas, tubulin,\n", 545 | "Nearest to :: dexter, freshest, une, -, cowdery, ;, include, aerostatic/aerodynamic,\n", 546 | "Nearest to the: its, a, his, their, harriet, debbie, our, tranquilizer,\n", 547 | "Average loss at step 32000: 10.743503\n", 548 | "Average loss at step 34000: 10.741862\n", 549 | "Average loss at step 36000: 10.708315\n", 550 | "Average loss at step 38000: 10.610916\n", 551 | "Average loss at step 40000: 10.676803\n", 552 | "Nearest to .: ;, ,, :, and, 1932–33, harvested, cornelius, kalmykova,\n", 553 | "Nearest to which: that, who, whom, but, what, where, repeal, ecstasy,\n", 554 | "Nearest to an: rockwell, sarai, constrained, irwin, resnick, open-spandrel, spartan, corrugated,\n", 555 | "Nearest to as: yorkville, self-esteem, kunda, attractive, |mar_lo_°c, ravens, f-75, disperse,\n", 556 | "Nearest to be: been, being, fully, remain, amnh, have, was, clearly,\n", 557 | "Nearest to first: last, second, next, earliest, only, final, original, best,\n", 558 | "Nearest to ,: ;, ., —, and, -, ''protecteur, shipboard, –,\n", 559 | "Nearest to (: -, —, dram, approximately, =, 30.1, –, na+,\n", 560 | "Nearest to from: into, across, collects, lcd, documenting, mastiff, deep-seated, accommodating,\n", 561 | "Nearest to for: nasty, soo, introductions, in, water…and, yamaha, among, tumwater,\n", 562 | "Nearest to ;: ,, ., —, superfluous, petitioner, pro-russian, complains, fund-raising,\n", 563 | "Nearest to have: had, has, are, having, contain, were, martyred, apply,\n", 564 | "Nearest to UNK: r, tselinograd, perch, zha, re-instated, eighth, 300, speculates,\n", 565 | 
"Nearest to or: and, somebody, formatted, nor, dat, preeminent, 4-5, landmass,\n", 566 | "Nearest to :: dexter, consumes, word, providers, stallion, differentiating, pan-slavic, .,\n", 567 | "Nearest to the: a, their, your, his, its, any, zaidi, generalfeldmarschall,\n", 568 | "Average loss at step 42000: 10.604452\n", 569 | "Average loss at step 44000: 10.669711\n", 570 | "Average loss at step 46000: 10.638800\n", 571 | "Average loss at step 48000: 10.602861\n", 572 | "Average loss at step 50000: 10.685731\n", 573 | "Nearest to .: ;, ,, :, —, photographing, interviewing, shias, in,\n", 574 | "Nearest to which: that, and, but, where, what, whom, who, ecstasy,\n", 575 | "Nearest to an: rockwell, corrugated, reclassify, boise, irwin, novgorodians, resnick, the,\n", 576 | "Nearest to as: self-esteem, kunda, triploid, attractive, ravens, |mar_lo_°c, racquet, yorkville,\n", 577 | "Nearest to be: been, being, easily, replace, surpass, remain, solve, readily,\n", 578 | "Nearest to first: last, next, second, earliest, final, fourth, third, only,\n", 579 | "Nearest to ,: —, ;, ., (, and, in, djurgårdens, shipboard,\n", 580 | "Nearest to (: —, -, dram, ,, –, ''hancock, approximately, ;,\n", 581 | "Nearest to from: into, in, through, lcd, across, sault, liaison, towards,\n", 582 | "Nearest to for: nasty, yamaha, introductions, soo, during, arbitrarily, bien, in,\n", 583 | "Nearest to ;: ., ,, —, :, -, consumes, (, than,\n", 584 | "Nearest to have: had, has, having, are, were, contain, martyred, rely,\n", 585 | "Nearest to UNK: hawkeye, silent, tselinograd, brown, non-living, aesthetics, d, here,\n", 586 | "Nearest to or: and, formosan, dat, desc, nor, preeminent, containing, formatted,\n", 587 | "Nearest to :: ., ;, differentiating, consumes, resorts, dexter, cowdery, methodologies,\n", 588 | "Nearest to the: a, its, their, this, horsetails, his, acelhuate, delagoa,\n", 589 | "Average loss at step 52000: 10.430200\n", 590 | "Average loss at step 54000: 10.324997\n", 591 | "Average loss at step 56000: 10.216399\n", 592 | "Average loss at step 58000: 10.217039\n", 593 | "Average loss at step 60000: 10.210400\n", 594 | "Nearest to .: ,, ;, albinus, of, :, ?, 'big, matsui,\n", 595 | "Nearest to which: that, who, whom, what, shoppers, but, sheikh, repeal,\n", 596 | "Nearest to an: rockwell, resnick, spartan, open-spandrel, kant, irwin, corrugated, takings,\n", 597 | "Nearest to as: ravens, self-esteem, blaming, sài, beginning, result, creating, attractive,\n", 598 | "Nearest to be: been, have, being, easily, clearly, fully, replace, grow,\n", 599 | "Nearest to first: last, second, earliest, next, only, final, best, original,\n", 600 | "Nearest to ,: ., ;, —, theobromine, -, ''protecteur, cabled, :,\n", 601 | "Nearest to (: [, -, 405, bernard, adventurers, dram, horace, 30.1,\n", 602 | "Nearest to from: jawbone, metrovick, overcome, lcd, replacing, across, into, in,\n", 603 | "Nearest to for: introductions, seeker, spion, reactor, nasty, smelting, bien, rehabilitated,\n", 604 | "Nearest to ;: ., ,, —, khaldun, prowess, -, avellaneda, :,\n", 605 | "Nearest to have: has, had, be, having, dumps, rely, apply, contain,\n", 606 | "Nearest to UNK: tselinograd, 1800, re-instated, -1, fostered, tarnów, r., cobo,\n", 607 | "Nearest to or: and, formatted, landmass, centred, meaning, preeminent, hematoma, reciting,\n", 608 | "Nearest to :: termed, dexter, providers, freshest, ., teufel, stallion, nickname,\n", 609 | "Nearest to the: a, glial, blackburn, our, appease, 'the, atheistic, various,\n" 610 | ] 611 | }, 612 | 
{ 613 | "name": "stdout", 614 | "output_type": "stream", 615 | "text": [ 616 | "Average loss at step 62000: 10.223157\n", 617 | "Average loss at step 64000: 10.105503\n", 618 | "Average loss at step 66000: 10.191790\n", 619 | "Average loss at step 68000: 10.157220\n", 620 | "Average loss at step 70000: 10.154481\n", 621 | "Nearest to .: ,, ;, of, and, in, that, verbiage, rostand,\n", 622 | "Nearest to which: that, who, and, ecstasy, whom, but, repeal, whose,\n", 623 | "Nearest to an: rockwell, resnick, spartan, irwin, sarai, cavitation, boise, novgorodians,\n", 624 | "Nearest to as: ravens, blaming, sài, kunda, rossini, medial, continuous-wave, result,\n", 625 | "Nearest to be: been, being, customisation, easily, replace, surpass, occur, were,\n", 626 | "Nearest to first: last, second, next, earliest, fourth, final, only, oldest,\n", 627 | "Nearest to ,: ., ;, and, —, of, langevin, recuperating, assaulted,\n", 628 | "Nearest to (: -, 1187., wander, ;, holmgard, —, eucharistic, mib,\n", 629 | "Nearest to from: in, towards, lcd, longevity, accommodating, accra, into, rampa,\n", 630 | "Nearest to for: yamaha, introductions, bien, nasty, during, fucking, gourmet, dislodge,\n", 631 | "Nearest to ;: ., ,, pro-russian, —, (, consumes, ?, :,\n", 632 | "Nearest to have: had, has, were, are, having, rely, brutish, be,\n", 633 | "Nearest to UNK: silent, hellene, weak, berger, hardiness, headingley, bone, 39,\n", 634 | "Nearest to or: and, formatted, preeminent, sax, pre-s2, reciting, baleen, buddhist,\n", 635 | "Nearest to :: reunify, termed, differentiating, replicates, consumes, liberalised, teufel, ;,\n", 636 | "Nearest to the: its, a, their, glial, these, 1937–1945, this, his,\n", 637 | "Average loss at step 72000: 10.217530\n", 638 | "Average loss at step 74000: 10.146726\n", 639 | "Average loss at step 76000: 10.247005\n", 640 | "Average loss at step 78000: 10.026597\n", 641 | "Average loss at step 80000: 9.882595\n", 642 | "Nearest to .: ;, ,, 1924., 10., 2003., 2006., 2004., 1983.,\n", 643 | "Nearest to which: that, whom, who, 35.6, where, shoppers, roney, but,\n", 644 | "Nearest to an: rockwell, resnick, irwin, corrugated, novgorodians, spartan, cheadle, sarai,\n", 645 | "Nearest to as: yorkville, quintessentially, |mar_lo_°c, sài, self-esteem, escapes, mississippians, thessaly,\n", 646 | "Nearest to be: been, being, surpass, easily, customisation, replace, deliberately, occur,\n", 647 | "Nearest to first: last, second, next, fourth, only, earliest, final, oldest,\n", 648 | "Nearest to ,: ., ;, and, melinda, refuelling, —, apostate, hunslet,\n", 649 | "Nearest to (: -, na+, dram, chanute, indented, lihue, 4-d, approximately,\n", 650 | "Nearest to from: in, lampboard, deep-seated, waterways, across, israeli-palestinian, cambodian, lcd,\n", 651 | "Nearest to for: water…and, nasty, concealing, γαλαξίας, yamaha, bien, keller, kopfstein,\n", 652 | "Nearest to ;: ., ,, deco, --, —, one-man, penned, mucous,\n", 653 | "Nearest to have: had, has, having, rely, contain, year.the, were, spend,\n", 654 | "Nearest to UNK: unsuited, 1.1, 99, tarnów, tselinograd, schlich, monkees, natasha,\n", 655 | "Nearest to or: formatted, and, preeminent, plus, landmass, pre-s2, semivowels, lukewarm,\n", 656 | "Nearest to :: differentiating, une, termed, freshest, aerostatic/aerodynamic, dexter, dragged, rhinos,\n", 657 | "Nearest to the: a, its, non-agricultural, forster, an, shawnee, tanzanian, paroled,\n", 658 | "Average loss at step 82000: 9.922878\n", 659 | "Average loss at step 84000: 9.897537\n", 660 | "Average 
loss at step 86000: 9.913045\n", 661 | "Average loss at step 88000: 9.824237\n", 662 | "Average loss at step 90000: 9.811843\n", 663 | "Nearest to .: ,, ;, jazzy, 2001., bethad, 1821., supertankers, align=,\n", 664 | "Nearest to which: that, who, whom, what, but, shoppers, where, and,\n", 665 | "Nearest to an: rockwell, resnick, sarai, corrugated, thane, open-spandrel, fended, entremeses,\n", 666 | "Nearest to as: self-esteem, |mar_lo_°c, thelma, triploid, kunda, ravens, yorkville, quintessentially,\n", 667 | "Nearest to be: been, being, surpass, become, fully, grow, regain, remain,\n", 668 | "Nearest to first: second, last, earliest, next, fourth, oldest, final, facultatively,\n", 669 | "Nearest to ,: ., ;, —, posit, and, -, 802.11b, of,\n", 670 | "Nearest to (: -, —, 30.1, teddy, investment-grade, 'scouse, <, 405,\n", 671 | "Nearest to from: lampboard, into, mastiff, in, deep-seated, fasi, lcd, alchemical,\n", 672 | "Nearest to for: yamaha, soo, introductions, nasty, water…and, electrical, fattened, reactor,\n", 673 | "Nearest to ;: ,, ., —, complains, superfluous, pro-russian, scorers, transliterations,\n", 674 | "Nearest to have: had, has, exist, contain, having, are, contribute, represent,\n", 675 | "Nearest to UNK: re-instated, reuters, loftus, -1, gaston, multi-instrumentalist, harvard, tarnów,\n", 676 | "Nearest to or: formatted, and, landmass, 4-5, lampooned, buddhist, preeminent, nor,\n", 677 | "Nearest to :: freshest, dexter, providers, differentiating, actor-managers, stanislaus, retorted, aerostatic/aerodynamic,\n", 678 | "Nearest to the: a, his, paroled, its, newly-created, delagoa, stormtrooper, woda,\n", 679 | "Average loss at step 92000: 9.857963\n", 680 | "Average loss at step 94000: 9.855468\n", 681 | "Average loss at step 96000: 9.892065\n", 682 | "Average loss at step 98000: 9.858063\n", 683 | "Average loss at step 100000: 9.912151\n", 684 | "Nearest to .: ;, ,, pear-shaped, of, d'etat, :, and, seaways,\n", 685 | "Nearest to which: that, whom, what, where, who, alushta, redfin, 35.6,\n", 686 | "Nearest to an: rockwell, corrugated, resnick, irwin, boise, 40.4, novgorodians, thane,\n", 687 | "Nearest to as: kunda, triploid, yorkville, racquet, quintessentially, ravens, self-esteem, result,\n", 688 | "Nearest to be: been, being, surpass, replace, easily, was, formally, partially,\n", 689 | "Nearest to first: last, second, earliest, next, fourth, oldest, final, best,\n", 690 | "Nearest to ,: ., ;, —, ''protecteur, and, diphthongisation, chokai, ultimatetv,\n", 691 | "Nearest to (: —, dram, bernard, -, 30.1, –, sarawak, toray,\n", 692 | "Nearest to from: lampboard, phosphorus, rossby, lighter-than-air, lcd, wendt, alchemical, longevity,\n", 693 | "Nearest to for: yamaha, introductions, nasty, freest, seeker, water…and, mistook, reminded,\n", 694 | "Nearest to ;: ., ,, :, durant, --, —, basel-landschaft, >,\n", 695 | "Nearest to have: has, had, contain, having, spend, contribute, rely, 've,\n", 696 | "Nearest to UNK: silent, darts, tselinograd, 4th, berger, jewellery, honour, claudius,\n", 697 | "Nearest to or: formatted, and, slew, desc, sax, preeminent, pre-s2, meaning,\n", 698 | "Nearest to :: differentiating, ;, consumes, pour, freshest, mattila, termed, recounting,\n", 699 | "Nearest to the: a, his, its, their, 1959-1960, this, species-rich, 1.88,\n" 700 | ] 701 | } 702 | ], 703 | "source": [ 704 | "num_steps = 100001\n", 705 | "decay_learning_rate_every = 2000\n", 706 | "skip_gram_loss = [] # Collect the sequential loss values for plotting purposes\n", 707 | "\n", 708 
| "with tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) as session:\n", 709 | " tf.global_variables_initializer().run()\n", 710 | " print('Initialized')\n", 711 | " average_loss = 0\n", 712 | " for step in range(num_steps):\n", 713 | " batch_data, batch_labels = generate_batch(\n", 714 | " batch_size, window_size)\n", 715 | " feed_dict = {train_dataset : batch_data}\n", 716 | " for wi in range(2*window_size):\n", 717 | " feed_dict.update({train_labels[wi]:np.reshape(batch_labels[:,wi],(-1,1))})\n", 718 | " \n", 719 | " _, l = session.run([optimizer, loss], feed_dict=feed_dict)\n", 720 | " average_loss += l\n", 721 | " \n", 722 | " if (step+1) % 2000 == 0:\n", 723 | " if step > 0:\n", 724 | " average_loss = average_loss / 2000\n", 725 | " # The average loss is an estimate of the loss over the last 2000 batches.\n", 726 | " print('Average loss at step %d: %f' % (step+1, average_loss))\n", 727 | " skip_gram_loss.append(average_loss)\n", 728 | " average_loss = 0\n", 729 | " # note that this is expensive (~20% slowdown if computed every 500 steps)\n", 730 | " if (step+1) % 10000 == 0:\n", 731 | " sim = similarity.eval()\n", 732 | " for i in range(valid_size):\n", 733 | " valid_word = reverse_dictionary[valid_examples[i]]\n", 734 | " top_k = 8 # number of nearest neighbors\n", 735 | " nearest = (-sim[i, :]).argsort()[1:top_k+1]\n", 736 | " log = 'Nearest to %s:' % valid_word\n", 737 | " for k in range(top_k):\n", 738 | " close_word = reverse_dictionary[nearest[k]]\n", 739 | " log = '%s %s,' % (log, close_word)\n", 740 | " print(log)\n", 741 | " skip_gram_final_embeddings = normalized_embeddings.eval()\n", 742 | "\n", 743 | "# We will save the word vectors learned and the loss over time\n", 744 | "# as this information is required later for comparisons\n", 745 | "np.save('struct_skip_embeddings',skip_gram_final_embeddings)\n", 746 | "\n", 747 | "with open('struct_skip_losses.csv', 'wt') as f:\n", 748 | " writer = csv.writer(f, delimiter=',')\n", 749 | " writer.writerow(skip_gram_loss)" 750 | ] 751 | }, 752 | { 753 | "cell_type": "code", 754 | "execution_count": null, 755 | "metadata": { 756 | "collapsed": true 757 | }, 758 | "outputs": [], 759 | "source": [] 760 | } 761 | ], 762 | "metadata": { 763 | "kernelspec": { 764 | "display_name": "Python 3", 765 | "language": "python", 766 | "name": "python3" 767 | }, 768 | "language_info": { 769 | "codemirror_mode": { 770 | "name": "ipython", 771 | "version": 3 772 | }, 773 | "file_extension": ".py", 774 | "mimetype": "text/x-python", 775 | "name": "python", 776 | "nbconvert_exporter": "python", 777 | "pygments_lexer": "ipython3", 778 | "version": "3.5.2" 779 | } 780 | }, 781 | "nbformat": 4, 782 | "nbformat_minor": 2 783 | } 784 | -------------------------------------------------------------------------------- /ch5/cnn_sentence_classification.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Sentence Classification with Convolution Neural Networks\n", 8 | "[Paper](https://arxiv.org/pdf/1408.5882.pdf): Convolutional Neural Networks for Sentence Classification by Yoon Kim" 9 | ] 10 | }, 11 | { 12 | "cell_type": "code", 13 | "execution_count": 1, 14 | "metadata": {}, 15 | "outputs": [ 16 | { 17 | "name": "stderr", 18 | "output_type": "stream", 19 | "text": [ 20 | "c:\\users\\thushan\\documents\\python_virtualenvs\\tensorflow_venv\\lib\\site-packages\\h5py\\__init__.py:36: FutureWarning: Conversion of the 
second argument of issubdtype from `float` to `np.floating` is deprecated. In future, it will be treated as `np.float64 == np.dtype(float).type`.\n", 21 | " from ._conv import register_converters as _register_converters\n" 22 | ] 23 | } 24 | ], 25 | "source": [ 26 | "# These are all the modules we'll be using later. Make sure you can import them\n", 27 | "# before proceeding further.\n", 28 | "%matplotlib inline\n", 29 | "from __future__ import print_function\n", 30 | "import collections\n", 31 | "import math\n", 32 | "import numpy as np\n", 33 | "import os\n", 34 | "import random\n", 35 | "import tensorflow as tf\n", 36 | "import zipfile\n", 37 | "from matplotlib import pylab\n", 38 | "from six.moves import range\n", 39 | "from six.moves.urllib.request import urlretrieve\n", 40 | "import tensorflow as tf" 41 | ] 42 | }, 43 | { 44 | "cell_type": "markdown", 45 | "metadata": {}, 46 | "source": [ 47 | "## Downloading and Checking the Dataset\n", 48 | "This [dataset](Dataset: http://cogcomp.cs.illinois.edu/Data/QA/QC/) is composed of questions as inputs and their respective type as the output. For example, (e.g. Who was Abraham Lincon?) and the output or label would be Human." 49 | ] 50 | }, 51 | { 52 | "cell_type": "code", 53 | "execution_count": 2, 54 | "metadata": {}, 55 | "outputs": [ 56 | { 57 | "name": "stdout", 58 | "output_type": "stream", 59 | "text": [ 60 | "question-classif-data\\train_1000.label\n", 61 | "Found and verified question-classif-data\\train_1000.label\n", 62 | "question-classif-data\\TREC_10.label\n", 63 | "Found and verified question-classif-data\\TREC_10.label\n" 64 | ] 65 | } 66 | ], 67 | "source": [ 68 | "url = 'http://cogcomp.org/Data/QA/QC/'\n", 69 | "dir_name = 'question-classif-data'\n", 70 | "\n", 71 | "def maybe_download(dir_name, filename, expected_bytes):\n", 72 | " \"\"\"Download a file if not present, and make sure it's the right size.\"\"\"\n", 73 | " if not os.path.exists(dir_name):\n", 74 | " os.mkdir(dir_name)\n", 75 | " if not os.path.exists(os.path.join(dir_name,filename)):\n", 76 | " filename, _ = urlretrieve(url + filename, os.path.join(dir_name,filename))\n", 77 | " print(os.path.join(dir_name,filename))\n", 78 | " statinfo = os.stat(os.path.join(dir_name,filename))\n", 79 | " if statinfo.st_size == expected_bytes:\n", 80 | " print('Found and verified %s' % os.path.join(dir_name,filename))\n", 81 | " else:\n", 82 | " print(statinfo.st_size)\n", 83 | " raise Exception(\n", 84 | " 'Failed to verify ' + os.path.join(dir_name,filename) + '. 
Can you get to it with a browser?')\n", 85 | " return filename\n", 86 | "\n", 87 | "filename = maybe_download(dir_name, 'train_1000.label', 60774)\n", 88 | "test_filename = maybe_download(dir_name, 'TREC_10.label',23354)" 89 | ] 90 | }, 91 | { 92 | "cell_type": "code", 93 | "execution_count": 3, 94 | "metadata": {}, 95 | "outputs": [ 96 | { 97 | "name": "stdout", 98 | "output_type": "stream", 99 | "text": [ 100 | "Files found and verified.\n" 101 | ] 102 | } 103 | ], 104 | "source": [ 105 | "# Check the existence of files\n", 106 | "filenames = ['train_1000.label','TREC_10.label']\n", 107 | "num_files = len(filenames)\n", 108 | "for i in range(len(filenames)):\n", 109 | " file_exists = os.path.isfile(os.path.join(dir_name,filenames[i]))\n", 110 | " assert file_exists\n", 111 | "print('Files found and verified.')" 112 | ] 113 | }, 114 | { 115 | "cell_type": "markdown", 116 | "metadata": {}, 117 | "source": [ 118 | "## Loading and Preprocessing Data\n", 119 | "Below we load the text into the program and do some simple preprocessing on data" 120 | ] 121 | }, 122 | { 123 | "cell_type": "code", 124 | "execution_count": 4, 125 | "metadata": {}, 126 | "outputs": [ 127 | { 128 | "name": "stdout", 129 | "output_type": "stream", 130 | "text": [ 131 | "\n", 132 | "Processing file question-classif-data\\train_1000.label\n", 133 | "\tQuestion 0: ['manner', 'how', 'did', 'serfdom', 'develop', 'in', 'and', 'then', 'leave', 'russia', '?']\n", 134 | "\tLabel 0: DESC\n", 135 | "\n", 136 | "\tQuestion 1: ['cremat', 'what', 'films', 'featured', 'the', 'character', 'popeye', 'doyle', '?']\n", 137 | "\tLabel 1: ENTY\n", 138 | "\n", 139 | "\tQuestion 2: ['manner', 'how', 'can', 'i', 'find', 'a', 'list', 'of', 'celebrities', \"'\", 'real', 'names', '?']\n", 140 | "\tLabel 2: DESC\n", 141 | "\n", 142 | "\tQuestion 3: ['animal', 'what', 'fowl', 'grabs', 'the', 'spotlight', 'after', 'the', 'chinese', 'year', 'of', 'the', 'monkey', '?']\n", 143 | "\tLabel 3: ENTY\n", 144 | "\n", 145 | "\tQuestion 4: ['exp', 'what', 'is', 'the', 'full', 'form', 'of', '.com', '?']\n", 146 | "\tLabel 4: ABBR\n", 147 | "\n", 148 | "\n", 149 | "Processing file question-classif-data\\TREC_10.label\n", 150 | "\tQuestion 0: ['manner', 'how', 'did', 'serfdom', 'develop', 'in', 'and', 'then', 'leave', 'russia', '?']\n", 151 | "\tLabel 0: DESC\n", 152 | "\n", 153 | "\tQuestion 1: ['cremat', 'what', 'films', 'featured', 'the', 'character', 'popeye', 'doyle', '?']\n", 154 | "\tLabel 1: ENTY\n", 155 | "\n", 156 | "\tQuestion 2: ['manner', 'how', 'can', 'i', 'find', 'a', 'list', 'of', 'celebrities', \"'\", 'real', 'names', '?']\n", 157 | "\tLabel 2: DESC\n", 158 | "\n", 159 | "\tQuestion 3: ['animal', 'what', 'fowl', 'grabs', 'the', 'spotlight', 'after', 'the', 'chinese', 'year', 'of', 'the', 'monkey', '?']\n", 160 | "\tLabel 3: ENTY\n", 161 | "\n", 162 | "\tQuestion 4: ['exp', 'what', 'is', 'the', 'full', 'form', 'of', '.com', '?']\n", 163 | "\tLabel 4: ABBR\n", 164 | "\n", 165 | "Max Sentence Length: 33\n", 166 | "\n", 167 | "Normalizing all sentences to same length\n" 168 | ] 169 | } 170 | ], 171 | "source": [ 172 | "# Records the maximum length of the sentences\n", 173 | "# as we need to pad shorter sentences accordingly\n", 174 | "max_sent_length = 0 \n", 175 | "\n", 176 | "def read_data(filename):\n", 177 | " '''\n", 178 | " Read data from a file with given filename\n", 179 | " Returns a list of strings where each string is a lower case word\n", 180 | " '''\n", 181 | " global max_sent_length\n", 182 | " questions = []\n", 183 | " labels = 
[]\n", 184 | " with open(filename,'r',encoding='latin-1') as f: \n", 185 | " for row in f:\n", 186 | " row_str = row.split(\":\")\n", 187 | " lb,q = row_str[0],row_str[1]\n", 188 | " q = q.lower()\n", 189 | " labels.append(lb)\n", 190 | " questions.append(q.split()) \n", 191 | " if len(questions[-1])>max_sent_length:\n", 192 | " max_sent_length = len(questions[-1])\n", 193 | " return questions,labels\n", 194 | "\n", 195 | "# Process train and Test data\n", 196 | "for i in range(num_files): \n", 197 | " print('\\nProcessing file %s'%os.path.join(dir_name,filenames[i]))\n", 198 | " if i==0:\n", 199 | " # Processing training data\n", 200 | " train_questions,train_labels = read_data(os.path.join(dir_name,filenames[i]))\n", 201 | " # Making sure we got all the questions and corresponding labels\n", 202 | " assert len(train_questions)==len(train_labels)\n", 203 | " elif i==1:\n", 204 | " # Processing testing data\n", 205 | " test_questions,test_labels = read_data(os.path.join(dir_name,filenames[i]))\n", 206 | " # Making sure we got all the questions and corresponding labels.\n", 207 | " assert len(test_questions)==len(test_labels)\n", 208 | " \n", 209 | " # Print some data to see everything is okey\n", 210 | " for j in range(5):\n", 211 | " print('\\tQuestion %d: %s' %(j,train_questions[j]))\n", 212 | " print('\\tLabel %d: %s\\n'%(j,train_labels[j]))\n", 213 | " \n", 214 | "print('Max Sentence Length: %d'%max_sent_length)\n", 215 | "print('\\nNormalizing all sentences to same length')" 216 | ] 217 | }, 218 | { 219 | "cell_type": "markdown", 220 | "metadata": {}, 221 | "source": [ 222 | "## Padding Shorter Sentences\n", 223 | "We use padding to pad short sentences so that all the sentences are of the same length." 224 | ] 225 | }, 226 | { 227 | "cell_type": "code", 228 | "execution_count": 5, 229 | "metadata": {}, 230 | "outputs": [ 231 | { 232 | "name": "stdout", 233 | "output_type": "stream", 234 | "text": [ 235 | "Train questions padded\n", 236 | "\n", 237 | "Test questions padded\n", 238 | "\n", 239 | "Sample test question: %s ['dist', 'how', 'far', 'is', 'it', 'from', 'denver', 'to', 'aspen', '?', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD']\n" 240 | ] 241 | } 242 | ], 243 | "source": [ 244 | "# Padding training data\n", 245 | "for qi,que in enumerate(train_questions):\n", 246 | " for _ in range(max_sent_length-len(que)):\n", 247 | " que.append('PAD')\n", 248 | " assert len(que)==max_sent_length\n", 249 | " train_questions[qi] = que\n", 250 | "print('Train questions padded')\n", 251 | "\n", 252 | "# Padding testing data\n", 253 | "for qi,que in enumerate(test_questions):\n", 254 | " for _ in range(max_sent_length-len(que)):\n", 255 | " que.append('PAD')\n", 256 | " assert len(que)==max_sent_length\n", 257 | " test_questions[qi] = que\n", 258 | "print('\\nTest questions padded') \n", 259 | "\n", 260 | "# Printing a test question to see if everything is correct\n", 261 | "print('\\nSample test question: %s',test_questions[0])" 262 | ] 263 | }, 264 | { 265 | "cell_type": "markdown", 266 | "metadata": {}, 267 | "source": [ 268 | "## Building the Dictionaries\n", 269 | "Builds the following. To understand each of these elements, let us also assume the text \"I like to go to school\"\n", 270 | "\n", 271 | "* `dictionary`: maps a string word to an ID (e.g. {I:0, like:1, to:2, go:3, school:4})\n", 272 | "* `reverse_dictionary`: maps an ID to a string word (e.g. 
{0:I, 1:like, 2:to, 3:go, 4:school}\n", 273 | "* `count`: List of list of (word, frequency) elements (e.g. [(I,1),(like,1),(to,2),(go,1),(school,1)]\n", 274 | "* `data` : Contain the string of text we read, where string words are replaced with word IDs (e.g. [0, 1, 2, 3, 2, 4])\n", 275 | "\n", 276 | "We do not replace rare words with \"UNK\" because the vocabulary is already quite small." 277 | ] 278 | }, 279 | { 280 | "cell_type": "code", 281 | "execution_count": 6, 282 | "metadata": {}, 283 | "outputs": [ 284 | { 285 | "name": "stdout", 286 | "output_type": "stream", 287 | "text": [ 288 | "49500 Words found.\n", 289 | "Found 3369 words in the vocabulary. \n", 290 | "All words (count) [('PAD', 34407), ('?', 1454), ('the', 999), ('what', 963), ('is', 587)]\n", 291 | "\n", 292 | "0th entry in dictionary: %s PAD\n", 293 | "\n", 294 | "Sample data [38, 12, 19, 1977, 1118, 6, 28, 2230, 3107, 686, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]\n", 295 | "\n", 296 | "Sample data [44, 3, 881, 2852, 2, 173, 2113, 2996, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]\n", 297 | "\n", 298 | "Vocabulary: 3369\n", 299 | "\n", 300 | "Number of training questions: 1000\n", 301 | "Number of testing questions: 500\n" 302 | ] 303 | } 304 | ], 305 | "source": [ 306 | "def build_dataset(questions):\n", 307 | " words = []\n", 308 | " data_list = []\n", 309 | " count = []\n", 310 | " \n", 311 | " # First create a large list with all the words in all the questions\n", 312 | " for d in questions:\n", 313 | " words.extend(d)\n", 314 | " print('%d Words found.'%len(words)) \n", 315 | " print('Found %d words in the vocabulary. '%len(collections.Counter(words).most_common()))\n", 316 | " \n", 317 | " # Sort words by there frequency\n", 318 | " count.extend(collections.Counter(words).most_common())\n", 319 | " \n", 320 | " # Create an ID for each word by giving the current length of the dictionary\n", 321 | " # And adding that item to the dictionary\n", 322 | " dictionary = dict()\n", 323 | " for word, _ in count:\n", 324 | " dictionary[word] = len(dictionary)\n", 325 | " \n", 326 | " # Traverse through all the text and \n", 327 | " # replace the string words with the ID \n", 328 | " # of the word found at that index\n", 329 | " for d in questions:\n", 330 | " data = list()\n", 331 | " for word in d:\n", 332 | " index = dictionary[word] \n", 333 | " data.append(index)\n", 334 | " \n", 335 | " data_list.append(data)\n", 336 | " \n", 337 | " reverse_dictionary = dict(zip(dictionary.values(), dictionary.keys())) \n", 338 | " \n", 339 | " return data_list, count, dictionary, reverse_dictionary\n", 340 | "\n", 341 | "# Create a dataset with both train and test questions\n", 342 | "all_questions = list(train_questions)\n", 343 | "all_questions.extend(test_questions)\n", 344 | "\n", 345 | "# Use the above created dataset to build the vocabulary\n", 346 | "all_question_ind, count, dictionary, reverse_dictionary = build_dataset(all_questions)\n", 347 | "\n", 348 | "# Print some statistics about the processed data\n", 349 | "print('All words (count)', count[:5])\n", 350 | "print('\\n0th entry in dictionary: %s',reverse_dictionary[0])\n", 351 | "print('\\nSample data', all_question_ind[0])\n", 352 | "print('\\nSample data', all_question_ind[1])\n", 353 | "print('\\nVocabulary: ',len(dictionary))\n", 354 | "vocabulary_size = len(dictionary)\n", 355 | "\n", 356 | "print('\\nNumber of training questions: ',len(train_questions))\n", 357 | "print('Number of testing questions: 
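A self-contained version of what `build_dataset()` returns, run on the toy sentence from the description above (an illustrative sketch; the exact IDs depend on the frequency ordering, since the real vocabulary is sorted by word counts).

```
import collections

words = "I like to go to school".split()

count = collections.Counter(words).most_common()            # [(word, frequency), ...]
dictionary = {word: idx for idx, (word, _) in enumerate(count)}
reverse_dictionary = {idx: word for word, idx in dictionary.items()}
data = [dictionary[w] for w in words]                       # the text as word IDs

print(count)   # e.g. [('to', 2), ('I', 1), ('like', 1), ('go', 1), ('school', 1)]
print(data)    # six IDs, with the two occurrences of 'to' sharing one ID
assert [reverse_dictionary[i] for i in data] == words
```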
',len(test_questions))" 358 | ] 359 | }, 360 | { 361 | "cell_type": "markdown", 362 | "metadata": {}, 363 | "source": [ 364 | "## Generating Batches of Data\n", 365 | "Below I show the code to generate a batch of data from a given set of questions and labels." 366 | ] 367 | }, 368 | { 369 | "cell_type": "code", 370 | "execution_count": 7, 371 | "metadata": {}, 372 | "outputs": [ 373 | { 374 | "name": "stdout", 375 | "output_type": "stream", 376 | "text": [ 377 | "Sample batch labels\n", 378 | "[3 4 3 4 5 2 2 2 3 2 0 3 2 2 4 1]\n", 379 | "[3 0 3 3 0 4 2 3 3 4 2 1 4 1 5 4]\n" 380 | ] 381 | } 382 | ], 383 | "source": [ 384 | "batch_size = 16 # We process 16 questions at a time\n", 385 | "sent_length = max_sent_length\n", 386 | "\n", 387 | "num_classes = 6 # Number of classes\n", 388 | "# All the types of question that are in the dataset\n", 389 | "all_labels = ['NUM','LOC','HUM','DESC','ENTY','ABBR'] \n", 390 | "\n", 391 | "class BatchGenerator(object):\n", 392 | " '''\n", 393 | " Generates a batch of data\n", 394 | " '''\n", 395 | " def __init__(self,batch_size,questions,labels):\n", 396 | " self.questions = questions\n", 397 | " self.labels = labels\n", 398 | " self.text_size = len(questions)\n", 399 | " self.batch_size = batch_size\n", 400 | " self.data_index = 0\n", 401 | " assert len(self.questions)==len(self.labels)\n", 402 | " \n", 403 | " def generate_batch(self):\n", 404 | " '''\n", 405 | " Data generation function. This outputs two matrices\n", 406 | " inputs: a batch of questions where each question is a tensor of size\n", 407 | " [sent_length, vocabulary_size] with each word one-hot-encoded\n", 408 | " labels_ohe: one-hot-encoded labels corresponding to the questions in inputs\n", 409 | " '''\n", 410 | " global sent_length,num_classes\n", 411 | " global dictionary, all_labels\n", 412 | " \n", 413 | " # Numpy arrays holding input and label data\n", 414 | " inputs = np.zeros((self.batch_size,sent_length,vocabulary_size),dtype=np.float32)\n", 415 | " labels_ohe = np.zeros((self.batch_size,num_classes),dtype=np.float32)\n", 416 | " \n", 417 | " # When we reach the end of the dataset\n", 418 | " # start from beginning\n", 419 | " if self.data_index + self.batch_size >= self.text_size:\n", 420 | " self.data_index = 0\n", 421 | " \n", 422 | " # For each question in the dataset\n", 423 | " for qi,que in enumerate(self.questions[self.data_index:self.data_index+self.batch_size]):\n", 424 | " # For each word in the question\n", 425 | " for wi,word in enumerate(que): \n", 426 | " # Set the element at the word ID index to 1\n", 427 | " # this gives the one-hot-encoded vector of that word\n", 428 | " inputs[qi,wi,dictionary[word]] = 1.0\n", 429 | " \n", 430 | " # Set the index corrsponding to that particular class to 1\n", 431 | " labels_ohe[qi,all_labels.index(self.labels[self.data_index + qi])] = 1.0\n", 432 | " \n", 433 | " # Update the data index to get the next batch of data\n", 434 | " self.data_index = (self.data_index + self.batch_size)%self.text_size\n", 435 | " \n", 436 | " return inputs,labels_ohe\n", 437 | " \n", 438 | " def return_index(self):\n", 439 | " # Get the current index of data\n", 440 | " return self.data_index\n", 441 | "\n", 442 | "# Test our batch generator\n", 443 | "sample_gen = BatchGenerator(batch_size,train_questions,train_labels)\n", 444 | "# Generate a single batch\n", 445 | "sample_batch_inputs,sample_batch_labels = sample_gen.generate_batch()\n", 446 | "# Generate another batch\n", 447 | "sample_batch_inputs_2,sample_batch_labels_2 = 
sample_gen.generate_batch()\n", 448 | "\n", 449 | "# Make sure that we infact have the question 0 as the 0th element of our batch\n", 450 | "assert np.all(np.asarray([dictionary[w] for w in train_questions[0]],dtype=np.int32) \n", 451 | " == np.argmax(sample_batch_inputs[0,:,:],axis=1))\n", 452 | "\n", 453 | "# Print some data labels we obtained\n", 454 | "print('Sample batch labels')\n", 455 | "print(np.argmax(sample_batch_labels,axis=1))\n", 456 | "print(np.argmax(sample_batch_labels_2,axis=1))" 457 | ] 458 | }, 459 | { 460 | "cell_type": "markdown", 461 | "metadata": {}, 462 | "source": [ 463 | "## Sentence Classifying Convolution Neural Network\n", 464 | "We are going to implement a very simple CNN to classify sentences. However you will see that even with this simple structure we achieve good accuracies. Our CNN will have one layer (with 3 different parallel layers). This will be followed by a pooling-over-time layer and finally a fully connected layer that produces the logits." 465 | ] 466 | }, 467 | { 468 | "cell_type": "markdown", 469 | "metadata": {}, 470 | "source": [ 471 | "## Defining hyperparameters and inputs" 472 | ] 473 | }, 474 | { 475 | "cell_type": "code", 476 | "execution_count": 8, 477 | "metadata": { 478 | "collapsed": true 479 | }, 480 | "outputs": [], 481 | "source": [ 482 | "tf.reset_default_graph()\n", 483 | "\n", 484 | "batch_size = 32\n", 485 | "# Different filter sizes we use in a single convolution layer\n", 486 | "filter_sizes = [3,5,7] \n", 487 | "\n", 488 | "# inputs and labels\n", 489 | "sent_inputs = tf.placeholder(shape=[batch_size,sent_length,vocabulary_size],dtype=tf.float32,name='sentence_inputs')\n", 490 | "sent_labels = tf.placeholder(shape=[batch_size,num_classes],dtype=tf.float32,name='sentence_labels')\n" 491 | ] 492 | }, 493 | { 494 | "cell_type": "markdown", 495 | "metadata": {}, 496 | "source": [ 497 | "## Defining Model Parameters\n", 498 | "Our model has following parameters.\n", 499 | "* 3 sets of convolution layer weights and biases (one for each parallel layer)\n", 500 | "* 1 fully connected output layer" 501 | ] 502 | }, 503 | { 504 | "cell_type": "code", 505 | "execution_count": 9, 506 | "metadata": { 507 | "collapsed": true 508 | }, 509 | "outputs": [], 510 | "source": [ 511 | "# 3 filters with different context window sizes (3,5,7)\n", 512 | "# Each of this filter spans the full one-hot-encoded length of each word and the context window width\n", 513 | "\n", 514 | "# Weights of the first parallel layer\n", 515 | "w1 = tf.Variable(tf.truncated_normal([filter_sizes[0],vocabulary_size,1],stddev=0.02,dtype=tf.float32),name='weights_1')\n", 516 | "b1 = tf.Variable(tf.random_uniform([1],0,0.01,dtype=tf.float32),name='bias_1')\n", 517 | "\n", 518 | "# Weights of the second parallel layer\n", 519 | "w2 = tf.Variable(tf.truncated_normal([filter_sizes[1],vocabulary_size,1],stddev=0.02,dtype=tf.float32),name='weights_2')\n", 520 | "b2 = tf.Variable(tf.random_uniform([1],0,0.01,dtype=tf.float32),name='bias_2')\n", 521 | "\n", 522 | "# Weights of the third parallel layer\n", 523 | "w3 = tf.Variable(tf.truncated_normal([filter_sizes[2],vocabulary_size,1],stddev=0.02,dtype=tf.float32),name='weights_3')\n", 524 | "b3 = tf.Variable(tf.random_uniform([1],0,0.01,dtype=tf.float32),name='bias_3')\n", 525 | "\n", 526 | "# Fully connected layer\n", 527 | "w_fc1 = tf.Variable(tf.truncated_normal([len(filter_sizes),num_classes],stddev=0.5,dtype=tf.float32),name='weights_fulcon_1')\n", 528 | "b_fc1 = 
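The essential step inside `BatchGenerator.generate_batch()` above is turning each padded question into a `[sent_length, vocabulary_size]` one-hot matrix. A standalone sketch with a made-up miniature vocabulary (the real vocabulary has 3,369 entries):

```
import numpy as np

dictionary = {'PAD': 0, '?': 1, 'what': 2, 'is': 3, 'the': 4,
              'full': 5, 'form': 6, 'of': 7, '.com': 8, 'exp': 9}
question = ['exp', 'what', 'is', 'the', 'full', 'form', 'of', '.com', '?', 'PAD', 'PAD', 'PAD']

sent_length, vocabulary_size = len(question), len(dictionary)
one_hot = np.zeros((sent_length, vocabulary_size), dtype=np.float32)
for wi, word in enumerate(question):
    one_hot[wi, dictionary[word]] = 1.0

# Recovering the word IDs with argmax, as the notebook's sanity check does
assert np.all(np.argmax(one_hot, axis=1) == np.array([dictionary[w] for w in question]))
print(one_hot.shape)   # (12, 10)
```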
tf.Variable(tf.random_uniform([num_classes],0,0.01,dtype=tf.float32),name='bias_fulcon_1')" 529 | ] 530 | }, 531 | { 532 | "cell_type": "markdown", 533 | "metadata": {}, 534 | "source": [ 535 | "## Defining Inference of the CNN\n", 536 | "Here we define the CNN inference logic. First compute the convolution output for each parallel layer within the convolution layer. Then perform pooling-over-time over all the convolution outputs. Finally feed the output of the pooling layer to a fully connected layer to obtain the output logits." 537 | ] 538 | }, 539 | { 540 | "cell_type": "code", 541 | "execution_count": 10, 542 | "metadata": { 543 | "collapsed": true 544 | }, 545 | "outputs": [], 546 | "source": [ 547 | "# Calculate the output for all the filters with a stride 1\n", 548 | "# We use relu activation as the activation function\n", 549 | "h1_1 = tf.nn.relu(tf.nn.conv1d(sent_inputs,w1,stride=1,padding='SAME') + b1)\n", 550 | "h1_2 = tf.nn.relu(tf.nn.conv1d(sent_inputs,w2,stride=1,padding='SAME') + b2)\n", 551 | "h1_3 = tf.nn.relu(tf.nn.conv1d(sent_inputs,w3,stride=1,padding='SAME') + b3)\n", 552 | "\n", 553 | "# Pooling over time operation\n", 554 | "\n", 555 | "# This is doing the max pooling. Thereare two options to do the max pooling\n", 556 | "# 1. Use tf.nn.max_pool operation on a tensor made by concatenating h1_1,h1_2,h1_3 and converting that tensor to 4D\n", 557 | "# (Because max_pool takes a tensor of rank >= 4 )\n", 558 | "# 2. Do the max pooling separately for each filter output and combine them using tf.concat \n", 559 | "# (this is the one used in the code)\n", 560 | "\n", 561 | "h2_1 = tf.reduce_max(h1_1,axis=1)\n", 562 | "h2_2 = tf.reduce_max(h1_2,axis=1)\n", 563 | "h2_3 = tf.reduce_max(h1_3,axis=1)\n", 564 | "\n", 565 | "h2 = tf.concat([h2_1,h2_2,h2_3],axis=1)\n", 566 | "\n", 567 | "# Calculate the fully connected layer output (no activation)\n", 568 | "# Note: since h2 is 2d [batch_size,number of parallel filters] \n", 569 | "# reshaping the output is not required as it usually do in CNNs\n", 570 | "logits = tf.matmul(h2,w_fc1) + b_fc1" 571 | ] 572 | }, 573 | { 574 | "cell_type": "markdown", 575 | "metadata": {}, 576 | "source": [ 577 | "## Model Loss and the Optimizer\n", 578 | "We compute the cross entropy loss and use the momentum optimizer (which works better than standard gradient descent) to optimize our model" 579 | ] 580 | }, 581 | { 582 | "cell_type": "code", 583 | "execution_count": 11, 584 | "metadata": { 585 | "collapsed": true 586 | }, 587 | "outputs": [], 588 | "source": [ 589 | "# Loss (Cross-Entropy)\n", 590 | "loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(labels=sent_labels,logits=logits))\n", 591 | "\n", 592 | "# Momentum Optimizer\n", 593 | "optimizer = tf.train.MomentumOptimizer(learning_rate=0.01,momentum=0.9).minimize(loss)" 594 | ] 595 | }, 596 | { 597 | "cell_type": "markdown", 598 | "metadata": {}, 599 | "source": [ 600 | "## Model Predictions\n", 601 | "Note that we are not getting the raw predictions, but the index of the maximally activated element in the prediction vector." 602 | ] 603 | }, 604 | { 605 | "cell_type": "code", 606 | "execution_count": 12, 607 | "metadata": { 608 | "collapsed": true 609 | }, 610 | "outputs": [], 611 | "source": [ 612 | "predictions = tf.argmax(tf.nn.softmax(logits),axis=1)" 613 | ] 614 | }, 615 | { 616 | "cell_type": "markdown", 617 | "metadata": {}, 618 | "source": [ 619 | "## Running Our Model to Classify Sentences\n", 620 | "\n", 621 | "Below we run our algorithm for 50 epochs. 
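Before the training run, it helps to trace the tensor shapes through the inference code above. The sketch below fakes the three convolution outputs with random values (only the shapes matter) and omits the biases; it is an illustration, not the notebook's code.

```
import numpy as np

batch_size, sent_length, num_classes = 32, 33, 6

# Each parallel conv1d ('SAME' padding, stride 1, one output channel)
# returns a [batch_size, sent_length, 1] tensor.
h1_1, h1_2, h1_3 = (np.random.randn(batch_size, sent_length, 1) for _ in range(3))

# Pooling over time: the maximum activation across the sentence (axis 1)
h2_1, h2_2, h2_3 = (h.max(axis=1) for h in (h1_1, h1_2, h1_3))   # each [batch_size, 1]
h2 = np.concatenate([h2_1, h2_2, h2_3], axis=1)                  # [batch_size, 3]

# Fully connected layer: one weight per parallel filter, per class (bias omitted)
w_fc1 = np.random.randn(3, num_classes)
logits = h2 @ w_fc1                                              # [batch_size, num_classes]
print(h2.shape, logits.shape)   # (32, 3) (32, 6)
```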
With the provided hyperparameters you should achieve around 90% accuracy on the test set. However you are welcome to play around with the hyperparameters." 622 | ] 623 | }, 624 | { 625 | "cell_type": "code", 626 | "execution_count": 13, 627 | "metadata": {}, 628 | "outputs": [ 629 | { 630 | "name": "stdout", 631 | "output_type": "stream", 632 | "text": [ 633 | "Initialized\n", 634 | "\n", 635 | "Train Loss at Epoch 0: 1.75\n", 636 | "Test accuracy at Epoch 0: 13.333\n", 637 | "Train Loss at Epoch 1: 1.69\n", 638 | "Test accuracy at Epoch 1: 13.333\n", 639 | "Train Loss at Epoch 2: 1.63\n", 640 | "Test accuracy at Epoch 2: 26.875\n", 641 | "Train Loss at Epoch 3: 1.58\n", 642 | "Test accuracy at Epoch 3: 28.542\n", 643 | "Train Loss at Epoch 4: 1.53\n", 644 | "Test accuracy at Epoch 4: 30.417\n", 645 | "Train Loss at Epoch 5: 1.49\n", 646 | "Test accuracy at Epoch 5: 34.792\n", 647 | "Train Loss at Epoch 6: 1.45\n", 648 | "Test accuracy at Epoch 6: 40.833\n", 649 | "Train Loss at Epoch 7: 1.42\n", 650 | "Test accuracy at Epoch 7: 45.625\n", 651 | "Train Loss at Epoch 8: 1.39\n", 652 | "Test accuracy at Epoch 8: 47.083\n", 653 | "Train Loss at Epoch 9: 1.37\n", 654 | "Test accuracy at Epoch 9: 48.542\n", 655 | "Train Loss at Epoch 10: 1.34\n", 656 | "Test accuracy at Epoch 10: 48.750\n", 657 | "Train Loss at Epoch 11: 1.33\n", 658 | "Test accuracy at Epoch 11: 48.750\n", 659 | "Train Loss at Epoch 12: 1.29\n", 660 | "Test accuracy at Epoch 12: 50.208\n", 661 | "Train Loss at Epoch 13: 1.27\n", 662 | "Test accuracy at Epoch 13: 53.958\n", 663 | "Train Loss at Epoch 14: 1.23\n", 664 | "Test accuracy at Epoch 14: 57.708\n", 665 | "Train Loss at Epoch 15: 1.16\n", 666 | "Test accuracy at Epoch 15: 61.667\n", 667 | "Train Loss at Epoch 16: 1.10\n", 668 | "Test accuracy at Epoch 16: 65.208\n", 669 | "Train Loss at Epoch 17: 1.04\n", 670 | "Test accuracy at Epoch 17: 64.375\n", 671 | "Train Loss at Epoch 18: 0.98\n", 672 | "Test accuracy at Epoch 18: 63.750\n", 673 | "Train Loss at Epoch 19: 0.93\n", 674 | "Test accuracy at Epoch 19: 63.125\n", 675 | "Train Loss at Epoch 20: 0.88\n", 676 | "Test accuracy at Epoch 20: 63.333\n", 677 | "Train Loss at Epoch 21: 0.83\n", 678 | "Test accuracy at Epoch 21: 63.333\n", 679 | "Train Loss at Epoch 22: 0.80\n", 680 | "Test accuracy at Epoch 22: 63.542\n", 681 | "Train Loss at Epoch 23: 0.77\n", 682 | "Test accuracy at Epoch 23: 65.000\n", 683 | "Train Loss at Epoch 24: 0.74\n", 684 | "Test accuracy at Epoch 24: 69.583\n", 685 | "Train Loss at Epoch 25: 0.70\n", 686 | "Test accuracy at Epoch 25: 72.500\n", 687 | "Train Loss at Epoch 26: 0.67\n", 688 | "Test accuracy at Epoch 26: 75.208\n", 689 | "Train Loss at Epoch 27: 0.64\n", 690 | "Test accuracy at Epoch 27: 76.667\n", 691 | "Train Loss at Epoch 28: 0.61\n", 692 | "Test accuracy at Epoch 28: 78.125\n", 693 | "Train Loss at Epoch 29: 0.58\n", 694 | "Test accuracy at Epoch 29: 80.417\n", 695 | "Train Loss at Epoch 30: 0.55\n", 696 | "Test accuracy at Epoch 30: 82.083\n", 697 | "Train Loss at Epoch 31: 0.53\n", 698 | "Test accuracy at Epoch 31: 83.125\n", 699 | "Train Loss at Epoch 32: 0.51\n", 700 | "Test accuracy at Epoch 32: 83.542\n", 701 | "Train Loss at Epoch 33: 0.48\n", 702 | "Test accuracy at Epoch 33: 84.167\n", 703 | "Train Loss at Epoch 34: 0.47\n", 704 | "Test accuracy at Epoch 34: 85.000\n", 705 | "Train Loss at Epoch 35: 0.44\n", 706 | "Test accuracy at Epoch 35: 85.417\n", 707 | "Train Loss at Epoch 36: 0.43\n", 708 | "Test accuracy at Epoch 36: 85.625\n", 709 | "Train Loss at Epoch 37: 
0.42\n", 710 | "Test accuracy at Epoch 37: 85.833\n", 711 | "Train Loss at Epoch 38: 0.41\n", 712 | "Test accuracy at Epoch 38: 86.667\n", 713 | "Train Loss at Epoch 39: 0.39\n", 714 | "Test accuracy at Epoch 39: 87.292\n", 715 | "Train Loss at Epoch 40: 0.38\n", 716 | "Test accuracy at Epoch 40: 87.292\n", 717 | "Train Loss at Epoch 41: 0.36\n", 718 | "Test accuracy at Epoch 41: 87.500\n", 719 | "Train Loss at Epoch 42: 0.36\n", 720 | "Test accuracy at Epoch 42: 87.917\n", 721 | "Train Loss at Epoch 43: 0.34\n", 722 | "Test accuracy at Epoch 43: 88.542\n", 723 | "Train Loss at Epoch 44: 0.33\n", 724 | "Test accuracy at Epoch 44: 88.542\n", 725 | "Train Loss at Epoch 45: 0.32\n", 726 | "Test accuracy at Epoch 45: 88.542\n", 727 | "Train Loss at Epoch 46: 0.32\n", 728 | "Test accuracy at Epoch 46: 88.333\n", 729 | "Train Loss at Epoch 47: 0.31\n", 730 | "Test accuracy at Epoch 47: 88.542\n", 731 | "Train Loss at Epoch 48: 0.30\n", 732 | "Test accuracy at Epoch 48: 88.542\n", 733 | "Train Loss at Epoch 49: 0.29\n", 734 | "Test accuracy at Epoch 49: 88.750\n" 735 | ] 736 | } 737 | ], 738 | "source": [ 739 | "# With filter widths [3,5,7] and batch_size 32 the algorithm \n", 740 | "# achieves around ~90% accuracy on test dataset (50 epochs). \n", 741 | "# From batch sizes [16,32,64] I found 32 to give best performance\n", 742 | "\n", 743 | "session = tf.InteractiveSession()\n", 744 | "\n", 745 | "num_steps = 50 # Number of epochs the algorithm runs for\n", 746 | "\n", 747 | "# Initialize all variables\n", 748 | "tf.global_variables_initializer().run()\n", 749 | "print('Initialized\\n')\n", 750 | "\n", 751 | "# Define data batch generators for train and test data\n", 752 | "train_gen = BatchGenerator(batch_size,train_questions,train_labels)\n", 753 | "test_gen = BatchGenerator(batch_size,test_questions,test_labels)\n", 754 | "\n", 755 | "# How often do we compute the test accuracy\n", 756 | "test_interval = 1\n", 757 | "\n", 758 | "# Compute accuracy for a given set of predictions and labels\n", 759 | "def accuracy(labels,preds):\n", 760 | " return np.sum(np.argmax(labels,axis=1)==preds)/labels.shape[0]\n", 761 | "\n", 762 | "# Running the algorithm\n", 763 | "for step in range(num_steps):\n", 764 | " avg_loss = []\n", 765 | " \n", 766 | " # A single traverse through the whole training set\n", 767 | " for tr_i in range((len(train_questions)//batch_size)-1):\n", 768 | " # Get a batch of data\n", 769 | " tr_inputs, tr_labels = train_gen.generate_batch()\n", 770 | " # Optimize the network and compute the loss\n", 771 | " l,_ = session.run([loss,optimizer],feed_dict={sent_inputs: tr_inputs, sent_labels: tr_labels})\n", 772 | " avg_loss.append(l)\n", 773 | "\n", 774 | " # Print average loss\n", 775 | " print('Train Loss at Epoch %d: %.2f'%(step,np.mean(avg_loss)))\n", 776 | " test_accuracy = []\n", 777 | " \n", 778 | " # Compute the test accuracy\n", 779 | " if (step+1)%test_interval==0: \n", 780 | " for ts_i in range((len(test_questions)-1)//batch_size):\n", 781 | " # Get a batch of test data\n", 782 | " ts_inputs,ts_labels = test_gen.generate_batch()\n", 783 | " # Get predictions for that batch\n", 784 | " preds = session.run(predictions,feed_dict={sent_inputs: ts_inputs, sent_labels: ts_labels})\n", 785 | " # Compute test accuracy\n", 786 | " test_accuracy.append(accuracy(ts_labels,preds))\n", 787 | " \n", 788 | " # Display the mean test accuracy\n", 789 | " print('Test accuracy at Epoch %d: %.3f'%(step,np.mean(test_accuracy)*100.0))" 790 | ] 791 | }, 792 | { 793 | "cell_type": "code", 794 | 
"execution_count": null, 795 | "metadata": { 796 | "collapsed": true 797 | }, 798 | "outputs": [], 799 | "source": [] 800 | } 801 | ], 802 | "metadata": { 803 | "kernelspec": { 804 | "display_name": "Python 3", 805 | "language": "python", 806 | "name": "python3" 807 | }, 808 | "language_info": { 809 | "codemirror_mode": { 810 | "name": "ipython", 811 | "version": 3 812 | }, 813 | "file_extension": ".py", 814 | "mimetype": "text/x-python", 815 | "name": "python", 816 | "nbconvert_exporter": "python", 817 | "pygments_lexer": "ipython3", 818 | "version": "3.5.2" 819 | } 820 | }, 821 | "nbformat": 4, 822 | "nbformat_minor": 2 823 | } 824 | -------------------------------------------------------------------------------- /ch8/embeddings.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PacktPublishing/Natural-Language-Processing-with-TensorFlow/1d432b7e6fceb7819a60c9fd29560c864633a25b/ch8/embeddings.npy -------------------------------------------------------------------------------- /ch8/word2vec.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | import collections 4 | import random 5 | import math 6 | 7 | data_indices = None 8 | data_list = None 9 | reverse_dictionary = None 10 | embedding_size = None 11 | vocabulary_size = None 12 | num_files = None 13 | def define_data_and_hyperparameters(_num_files,_data_list, _reverse_dictionary, _emb_size, _vocab_size): 14 | global num_files, data_indices, data_list, reverse_dictionary 15 | global embedding_size, vocabulary_size 16 | 17 | num_files = _num_files 18 | data_indices = [0 for _ in range(num_files)] 19 | data_list = _data_list 20 | reverse_dictionary = _reverse_dictionary 21 | embedding_size = _emb_size 22 | vocabulary_size = _vocab_size 23 | 24 | 25 | def generate_batch_for_word2vec(data_list, doc_id, batch_size, window_size): 26 | # window_size is the amount of words we're looking at from each side of a given word 27 | # creates a single batch 28 | # doc_id is the ID of the story we want to extract a batch from 29 | 30 | # data_indices[doc_id] is updated by 1 everytime we read a set of data point 31 | # from the document identified by doc_id 32 | global data_indices 33 | 34 | # span defines the total window size, where 35 | # data we consider at an instance looks as follows. 
36 | # [ skip_window target skip_window ] 37 | # e.g if skip_window = 2 then span = 5 38 | span = 2 * window_size + 1 39 | 40 | # two numpy arras to hold target words (batch) 41 | # and context words (labels) 42 | # Note that batch has span-1=2*window_size columns 43 | batch = np.ndarray(shape=(batch_size,span-1), dtype=np.int32) 44 | labels = np.ndarray(shape=(batch_size, 1), dtype=np.int32) 45 | 46 | # The buffer holds the data contained within the span 47 | buffer = collections.deque(maxlen=span) 48 | 49 | # Fill the buffer and update the data_index 50 | for _ in range(span): 51 | buffer.append(data_list[doc_id][data_indices[doc_id]]) 52 | data_indices[doc_id] = (data_indices[doc_id] + 1) % len(data_list[doc_id]) 53 | 54 | # Here we do the batch reading 55 | # We iterate through each batch index 56 | # For each batch index, we iterate through span elements 57 | # to fill in the columns of batch array 58 | for i in range(batch_size): 59 | target = window_size # target label at the center of the buffer 60 | target_to_avoid = [ window_size ] # we only need to know the words around a given word, not the word itself 61 | 62 | # add selected target to avoid_list for next time 63 | col_idx = 0 64 | for j in range(span): 65 | # ignore the target word when creating the batch 66 | if j==span//2: 67 | continue 68 | batch[i,col_idx] = buffer[j] 69 | col_idx += 1 70 | labels[i, 0] = buffer[target] 71 | 72 | # Everytime we read a data point, 73 | # we need to move the span by 1 74 | # to update the span 75 | buffer.append(data_list[doc_id][data_indices[doc_id]]) 76 | data_indices[doc_id] = (data_indices[doc_id] + 1) % len(data_list[doc_id]) 77 | 78 | assert batch.shape[0]==batch_size and batch.shape[1]== span-1 79 | return batch, labels 80 | 81 | def print_some_batches(): 82 | global num_files, data_list, reverse_dictionary 83 | 84 | for window_size in [1,2]: 85 | data_indices = [0 for _ in range(num_files)] 86 | batch, labels = generate_batch_for_word2vec(data_list, doc_id=0, batch_size=8, window_size=window_size) 87 | print('\nwith window_size = %d:' % (window_size)) 88 | print(' batch:', [[reverse_dictionary[bii] for bii in bi] for bi in batch]) 89 | print(' labels:', [reverse_dictionary[li] for li in labels.reshape(8)]) 90 | 91 | batch_size, embedding_size, window_size = None, None, None 92 | valid_size, valid_window, valid_examples = None, None, None 93 | num_sampled = None 94 | 95 | train_dataset, train_labels = None, None 96 | valid_dataset = None 97 | 98 | softmax_weights, softmax_biases = None, None 99 | 100 | loss, optimizer, similarity, normalized_embeddings = None, None, None, None 101 | 102 | def define_word2vec_tensorflow(): 103 | global batch_size, embedding_size, window_size 104 | global valid_size, valid_window, valid_examples 105 | global num_sampled 106 | global train_dataset, train_labels 107 | global valid_dataset 108 | global softmax_weights, softmax_biases 109 | global loss, optimizer, similarity 110 | global vocabulary_size, embedding_size 111 | global normalized_embeddings 112 | 113 | batch_size = 128 # Data points in a single batch 114 | 115 | # How many words to consider left and right. 116 | # Skip gram by design does not require to have all the context words in a given step 117 | # However, for CBOW that's a requirement, so we limit the window size 118 | window_size = 3 119 | 120 | # We pick a random validation set to sample nearest neighbors 121 | valid_size = 16 # Random set of words to evaluate similarity on. 
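To see what `generate_batch_for_word2vec()` produces, here is a trimmed, self-contained re-run of the same sliding-window logic on a toy document of word IDs (window_size = 1, so each batch row holds the two context IDs and the label is the centre word). This is an illustration of the function above, not part of the module.

```
import collections
import numpy as np

doc = [10, 11, 12, 13, 14, 15]           # a toy "document" already mapped to word IDs
window_size, batch_size = 1, 4
span = 2 * window_size + 1                # context words on both sides plus the target

batch = np.ndarray((batch_size, span - 1), dtype=np.int32)
labels = np.ndarray((batch_size, 1), dtype=np.int32)
buffer = collections.deque(maxlen=span)
data_index = 0

for _ in range(span):                     # prime the buffer with the first span words
    buffer.append(doc[data_index])
    data_index += 1

for i in range(batch_size):
    batch[i] = [buffer[j] for j in range(span) if j != span // 2]   # context word IDs
    labels[i, 0] = buffer[span // 2]                                # the centre word
    buffer.append(doc[data_index % len(doc)])                       # slide the window by one
    data_index += 1

print(batch)      # rows: [10 12], [11 13], [12 14], [13 15]
print(labels.T)   # [[11 12 13 14]]
```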
122 | # We sample valid datapoints randomly from a large window without always being deterministic 123 | valid_window = 50 124 | 125 | # When selecting valid examples, we select some of the most frequent words as well as 126 | # some moderately rare words as well 127 | valid_examples = np.array(random.sample(range(valid_window), valid_size)) 128 | valid_examples = np.append(valid_examples,random.sample(range(1000, 1000+valid_window), valid_size),axis=0) 129 | 130 | num_sampled = 32 # Number of negative examples to sample. 131 | 132 | tf.reset_default_graph() 133 | 134 | # Training input data (target word IDs). Note that it has 2*window_size columns 135 | train_dataset = tf.placeholder(tf.int32, shape=[batch_size,2*window_size]) 136 | # Training input label data (context word IDs) 137 | train_labels = tf.placeholder(tf.int32, shape=[batch_size, 1]) 138 | # Validation input data, we don't need a placeholder 139 | # as we have already defined the IDs of the words selected 140 | # as validation data 141 | valid_dataset = tf.constant(valid_examples, dtype=tf.int32) 142 | 143 | # Variables. 144 | 145 | # Embedding layer, contains the word embeddings 146 | embeddings = tf.Variable(tf.random_uniform([vocabulary_size, embedding_size], -1.0, 1.0,dtype=tf.float32)) 147 | 148 | # Softmax Weights and Biases 149 | softmax_weights = tf.Variable(tf.truncated_normal([vocabulary_size, embedding_size], 150 | stddev=0.5 / math.sqrt(embedding_size),dtype=tf.float32)) 151 | softmax_biases = tf.Variable(tf.random_uniform([vocabulary_size],0.0,0.01)) 152 | 153 | # Model. 154 | # Look up embeddings for a batch of inputs. 155 | # Here we do embedding lookups for each column in the input placeholder 156 | # and then average them to produce an embedding_size word vector 157 | stacked_embedings = None 158 | print('Defining %d embedding lookups representing each word in the context'%(2*window_size)) 159 | for i in range(2*window_size): 160 | embedding_i = tf.nn.embedding_lookup(embeddings, train_dataset[:,i]) 161 | x_size,y_size = embedding_i.get_shape().as_list() 162 | if stacked_embedings is None: 163 | stacked_embedings = tf.reshape(embedding_i,[x_size,y_size,1]) 164 | else: 165 | stacked_embedings = tf.concat(axis=2,values=[stacked_embedings,tf.reshape(embedding_i,[x_size,y_size,1])]) 166 | 167 | assert stacked_embedings.get_shape().as_list()[2]==2*window_size 168 | print("Stacked embedding size: %s"%stacked_embedings.get_shape().as_list()) 169 | mean_embeddings = tf.reduce_mean(stacked_embedings,2,keepdims=False) 170 | print("Reduced mean embedding size: %s"%mean_embeddings.get_shape().as_list()) 171 | 172 | 173 | # Compute the softmax loss, using a sample of the negative labels each time. 174 | # inputs are embeddings of the train words 175 | # with this loss we optimize weights, biases, embeddings 176 | loss = tf.reduce_mean(tf.nn.sampled_softmax_loss(weights=softmax_weights, biases=softmax_biases, inputs=mean_embeddings, 177 | labels=train_labels, num_sampled=num_sampled, num_classes=vocabulary_size)) 178 | # AdamOptimizer. 179 | optimizer = tf.train.AdamOptimizer(0.0005).minimize(loss) 180 | 181 | # Compute the similarity between minibatch examples and all embeddings. 
182 | # We use the cosine distance: 183 | norm = tf.sqrt(tf.reduce_sum(tf.square(embeddings), 1, keepdims=True)) 184 | normalized_embeddings = embeddings / norm 185 | valid_embeddings = tf.nn.embedding_lookup(normalized_embeddings, valid_dataset) 186 | similarity = tf.matmul(valid_embeddings, tf.transpose(normalized_embeddings)) 187 | 188 | 189 | def run_word2vec(): 190 | global batch_size, embedding_size, window_size 191 | global valid_size, valid_window, valid_examples 192 | global num_sampled 193 | global train_dataset, train_labels 194 | global valid_dataset 195 | global softmax_weights, softmax_biases 196 | global loss, optimizer, similarity, normalized_embeddings 197 | global data_list, num_files, reverse_dictionary 198 | global vocabulary_size, embedding_size 199 | 200 | num_steps = 10 201 | steps_per_doc = 100 202 | 203 | session = tf.InteractiveSession() 204 | 205 | # Initialize the variables in the graph 206 | tf.global_variables_initializer().run() 207 | print('Initialized') 208 | 209 | average_loss = 0 210 | 211 | for step in range(num_steps): 212 | 213 | # Iterate through the documents in a random order 214 | for doc_id in np.random.permutation(num_files): 215 | for doc_step in range(steps_per_doc): 216 | 217 | # Generate a single batch of data from a document 218 | batch_data, batch_labels = generate_batch_for_word2vec(data_list, doc_id, batch_size, window_size) 219 | 220 | # Populate the feed_dict and run the optimizer (minimize loss) 221 | # and compute the loss 222 | feed_dict = {train_dataset : batch_data, train_labels : batch_labels} 223 | _, l = session.run([optimizer, loss], feed_dict=feed_dict) 224 | 225 | average_loss += l 226 | 227 | if (step+1) % 1 == 0: 228 | if step > 0: 229 | # compute the average loss over all batches seen in this step 230 | average_loss = average_loss / (num_files*steps_per_doc) 231 | 232 | print('Average loss at step %d: %f' % (step+1, average_loss)) 233 | average_loss = 0 # reset average loss 234 | 235 | # Evaluating validation set word similarities 236 | if (step+1) % 5 == 0: 237 | sim = similarity.eval() 238 | 239 | # Here we compute the top_k closest words for a given validation word 240 | # in terms of the cosine distance 241 | # We do this for all the words in the validation set 242 | # Note: This is an expensive step 243 | for i in range(valid_size): 244 | valid_word = reverse_dictionary[valid_examples[i]] 245 | top_k = 4 # number of nearest neighbors 246 | nearest = (-sim[i, :]).argsort()[1:top_k+1] 247 | log = 'Nearest to %s:' % valid_word 248 | for k in range(top_k): 249 | close_word = reverse_dictionary[nearest[k]] 250 | log = '%s %s,' % (log, close_word) 251 | print(log) 252 | cbow_final_embeddings = normalized_embeddings.eval() 253 | 254 | # We save the embeddings as embeddings.npy 255 | np.save('embeddings',cbow_final_embeddings) -------------------------------------------------------------------------------- /ch9/correct_spellings.py: -------------------------------------------------------------------------------- 1 | from difflib import SequenceMatcher 2 | 3 | def string_similarity(a, b): 4 | return SequenceMatcher(None, a, b).ratio() 5 | 6 | def correct_wrong_word(cw,gw,cap): 7 | ''' 8 | Spelling correction logic 9 | This is a very simple logic that replaces 10 | words with incorrect spelling with the word that has the highest 11 | similarity. Some words are manually corrected as the words 12 | found to be most similar semantically did not match.
13 | ''' 14 | correct_word = None 15 | found_similar_word = False 16 | sim = string_similarity(gw,cw) 17 | if sim>0.9: 18 | if cw != 'stting' and cw != 'sittign' and cw != 'smilling' and \ 19 | cw!='skiies' and cw!='childi' and cw!='sittion' and cw!='peacefuly' and cw!='stainding' and\ 20 | cw != 'staning' and cw!='lating' and cw!='sking' and cw!='trolly' and cw!='umping' and cw!='earing' and \ 21 | cw !='baters' and cw !='talkes' and cw !='trowing' and cw !='convered' and cw !='onsie' and cw !='slying': 22 | print(gw,' ',cw,' ',sim,' (',cap,')') 23 | correct_word = gw 24 | found_similar_word = True 25 | elif cw == 'stting' or cw == 'sittign' or cw == 'sittion': 26 | correct_word = 'sitting' 27 | found_similar_word = True 28 | elif cw == 'smilling': 29 | correct_word = 'smiling' 30 | found_similar_word = True 31 | elif cw == 'skiies': 32 | correct_word = 'skis' 33 | found_similar_word = True 34 | elif cw == 'childi': 35 | correct_word = 'child' 36 | found_similar_word = True 37 | elif cw == 'peacefuly': 38 | correct_word = 'peacefully' 39 | found_similar_word = True 40 | elif cw == 'stainding' or cw == 'staning': 41 | correct_word = 'standing' 42 | found_similar_word = True 43 | elif cw == 'lating': 44 | correct_word = 'laying' 45 | found_similar_word = True 46 | elif cw == 'sking': 47 | correct_word = 'skiing' 48 | found_similar_word = True 49 | elif cw == 'trolly': 50 | correct_word = 'trolley' 51 | found_similar_word = True 52 | elif cw == 'umping': 53 | correct_word = 'jumping' 54 | found_similar_word = True 55 | elif cw == 'earing': 56 | correct_word = 'eating' 57 | found_similar_word = True 58 | elif cw == 'baters': 59 | correct_word = 'batters' 60 | found_similar_word = True 61 | elif cw == 'talkes': 62 | correct_word = 'talks' 63 | found_similar_word = True 64 | elif cw == 'trowing': 65 | correct_word = 'throwing' 66 | found_similar_word = True 67 | elif cw =='convered': 68 | correct_word = 'covered' 69 | found_similar_word = True 70 | elif cw == 'onsie': 71 | correct_word = cw 72 | found_similar_word = True 73 | elif cw =='slying': 74 | correct_word = 'flying' 75 | found_similar_word = True 76 | else: 77 | raise NotImplementedError 78 | else: 79 | correct_word = cw 80 | found_similar_word = False 81 | 82 | return correct_word, found_similar_word -------------------------------------------------------------------------------- /ch9/image_caption_data/class_names.txt: -------------------------------------------------------------------------------- 1 | tench, Tinca tinca 2 | goldfish, Carassius auratus 3 | great white shark, white shark, man-eater, man-eating shark, Carcharodon carcharias 4 | tiger shark, Galeocerdo cuvieri 5 | hammerhead, hammerhead shark 6 | electric ray, crampfish, numbfish, torpedo 7 | stingray 8 | cock 9 | hen 10 | ostrich, Struthio camelus 11 | brambling, Fringilla montifringilla 12 | goldfinch, Carduelis carduelis 13 | house finch, linnet, Carpodacus mexicanus 14 | junco, snowbird 15 | indigo bunting, indigo finch, indigo bird, Passerina cyanea 16 | robin, American robin, Turdus migratorius 17 | bulbul 18 | jay 19 | magpie 20 | chickadee 21 | water ouzel, dipper 22 | kite 23 | bald eagle, American eagle, Haliaeetus leucocephalus 24 | vulture 25 | great grey owl, great gray owl, Strix nebulosa 26 | European fire salamander, Salamandra salamandra 27 | common newt, Triturus vulgaris 28 | eft 29 | spotted salamander, Ambystoma maculatum 30 | axolotl, mud puppy, Ambystoma mexicanum 31 | bullfrog, Rana catesbeiana 32 | tree frog, tree-frog 33 | tailed frog, bell 
toad, ribbed toad, tailed toad, Ascaphus trui 34 | loggerhead, loggerhead turtle, Caretta caretta 35 | leatherback turtle, leatherback, leathery turtle, Dermochelys coriacea 36 | mud turtle 37 | terrapin 38 | box turtle, box tortoise 39 | banded gecko 40 | common iguana, iguana, Iguana iguana 41 | American chameleon, anole, Anolis carolinensis 42 | whiptail, whiptail lizard 43 | agama 44 | frilled lizard, Chlamydosaurus kingi 45 | alligator lizard 46 | Gila monster, Heloderma suspectum 47 | green lizard, Lacerta viridis 48 | African chameleon, Chamaeleo chamaeleon 49 | Komodo dragon, Komodo lizard, dragon lizard, giant lizard, Varanus komodoensis 50 | African crocodile, Nile crocodile, Crocodylus niloticus 51 | American alligator, Alligator mississipiensis 52 | triceratops 53 | thunder snake, worm snake, Carphophis amoenus 54 | ringneck snake, ring-necked snake, ring snake 55 | hognose snake, puff adder, sand viper 56 | green snake, grass snake 57 | king snake, kingsnake 58 | garter snake, grass snake 59 | water snake 60 | vine snake 61 | night snake, Hypsiglena torquata 62 | boa constrictor, Constrictor constrictor 63 | rock python, rock snake, Python sebae 64 | Indian cobra, Naja naja 65 | green mamba 66 | sea snake 67 | horned viper, cerastes, sand viper, horned asp, Cerastes cornutus 68 | diamondback, diamondback rattlesnake, Crotalus adamanteus 69 | sidewinder, horned rattlesnake, Crotalus cerastes 70 | trilobite 71 | harvestman, daddy longlegs, Phalangium opilio 72 | scorpion 73 | black and gold garden spider, Argiope aurantia 74 | barn spider, Araneus cavaticus 75 | garden spider, Aranea diademata 76 | black widow, Latrodectus mactans 77 | tarantula 78 | wolf spider, hunting spider 79 | tick 80 | centipede 81 | black grouse 82 | ptarmigan 83 | ruffed grouse, partridge, Bonasa umbellus 84 | prairie chicken, prairie grouse, prairie fowl 85 | peacock 86 | quail 87 | partridge 88 | African grey, African gray, Psittacus erithacus 89 | macaw 90 | sulphur-crested cockatoo, Kakatoe galerita, Cacatua galerita 91 | lorikeet 92 | coucal 93 | bee eater 94 | hornbill 95 | hummingbird 96 | jacamar 97 | toucan 98 | drake 99 | red-breasted merganser, Mergus serrator 100 | goose 101 | black swan, Cygnus atratus 102 | tusker 103 | echidna, spiny anteater, anteater 104 | platypus, duckbill, duckbilled platypus, duck-billed platypus, Ornithorhynchus anatinus 105 | wallaby, brush kangaroo 106 | koala, koala bear, kangaroo bear, native bear, Phascolarctos cinereus 107 | wombat 108 | jellyfish 109 | sea anemone, anemone 110 | brain coral 111 | flatworm, platyhelminth 112 | nematode, nematode worm, roundworm 113 | conch 114 | snail 115 | slug 116 | sea slug, nudibranch 117 | chiton, coat-of-mail shell, sea cradle, polyplacophore 118 | chambered nautilus, pearly nautilus, nautilus 119 | Dungeness crab, Cancer magister 120 | rock crab, Cancer irroratus 121 | fiddler crab 122 | king crab, Alaska crab, Alaskan king crab, Alaska king crab, Paralithodes camtschatica 123 | American lobster, Northern lobster, Maine lobster, Homarus americanus 124 | spiny lobster, langouste, rock lobster, crawfish, crayfish, sea crawfish 125 | crayfish, crawfish, crawdad, crawdaddy 126 | hermit crab 127 | isopod 128 | white stork, Ciconia ciconia 129 | black stork, Ciconia nigra 130 | spoonbill 131 | flamingo 132 | little blue heron, Egretta caerulea 133 | American egret, great white heron, Egretta albus 134 | bittern 135 | crane 136 | limpkin, Aramus pictus 137 | European gallinule, Porphyrio porphyrio 138 | American coot, marsh 
hen, mud hen, water hen, Fulica americana 139 | bustard 140 | ruddy turnstone, Arenaria interpres 141 | red-backed sandpiper, dunlin, Erolia alpina 142 | redshank, Tringa totanus 143 | dowitcher 144 | oystercatcher, oyster catcher 145 | pelican 146 | king penguin, Aptenodytes patagonica 147 | albatross, mollymawk 148 | grey whale, gray whale, devilfish, Eschrichtius gibbosus, Eschrichtius robustus 149 | killer whale, killer, orca, grampus, sea wolf, Orcinus orca 150 | dugong, Dugong dugon 151 | sea lion 152 | Chihuahua 153 | Japanese spaniel 154 | Maltese dog, Maltese terrier, Maltese 155 | Pekinese, Pekingese, Peke 156 | Shih-Tzu 157 | Blenheim spaniel 158 | papillon 159 | toy terrier 160 | Rhodesian ridgeback 161 | Afghan hound, Afghan 162 | basset, basset hound 163 | beagle 164 | bloodhound, sleuthhound 165 | bluetick 166 | black-and-tan coonhound 167 | Walker hound, Walker foxhound 168 | English foxhound 169 | redbone 170 | borzoi, Russian wolfhound 171 | Irish wolfhound 172 | Italian greyhound 173 | whippet 174 | Ibizan hound, Ibizan Podenco 175 | Norwegian elkhound, elkhound 176 | otterhound, otter hound 177 | Saluki, gazelle hound 178 | Scottish deerhound, deerhound 179 | Weimaraner 180 | Staffordshire bullterrier, Staffordshire bull terrier 181 | American Staffordshire terrier, Staffordshire terrier, American pit bull terrier, pit bull terrier 182 | Bedlington terrier 183 | Border terrier 184 | Kerry blue terrier 185 | Irish terrier 186 | Norfolk terrier 187 | Norwich terrier 188 | Yorkshire terrier 189 | wire-haired fox terrier 190 | Lakeland terrier 191 | Sealyham terrier, Sealyham 192 | Airedale, Airedale terrier 193 | cairn, cairn terrier 194 | Australian terrier 195 | Dandie Dinmont, Dandie Dinmont terrier 196 | Boston bull, Boston terrier 197 | miniature schnauzer 198 | giant schnauzer 199 | standard schnauzer 200 | Scotch terrier, Scottish terrier, Scottie 201 | Tibetan terrier, chrysanthemum dog 202 | silky terrier, Sydney silky 203 | soft-coated wheaten terrier 204 | West Highland white terrier 205 | Lhasa, Lhasa apso 206 | flat-coated retriever 207 | curly-coated retriever 208 | golden retriever 209 | Labrador retriever 210 | Chesapeake Bay retriever 211 | German short-haired pointer 212 | vizsla, Hungarian pointer 213 | English setter 214 | Irish setter, red setter 215 | Gordon setter 216 | Brittany spaniel 217 | clumber, clumber spaniel 218 | English springer, English springer spaniel 219 | Welsh springer spaniel 220 | cocker spaniel, English cocker spaniel, cocker 221 | Sussex spaniel 222 | Irish water spaniel 223 | kuvasz 224 | schipperke 225 | groenendael 226 | malinois 227 | briard 228 | kelpie 229 | komondor 230 | Old English sheepdog, bobtail 231 | Shetland sheepdog, Shetland sheep dog, Shetland 232 | collie 233 | Border collie 234 | Bouvier des Flandres, Bouviers des Flandres 235 | Rottweiler 236 | German shepherd, German shepherd dog, German police dog, alsatian 237 | Doberman, Doberman pinscher 238 | miniature pinscher 239 | Greater Swiss Mountain dog 240 | Bernese mountain dog 241 | Appenzeller 242 | EntleBucher 243 | boxer 244 | bull mastiff 245 | Tibetan mastiff 246 | French bulldog 247 | Great Dane 248 | Saint Bernard, St Bernard 249 | Eskimo dog, husky 250 | malamute, malemute, Alaskan malamute 251 | Siberian husky 252 | dalmatian, coach dog, carriage dog 253 | affenpinscher, monkey pinscher, monkey dog 254 | basenji 255 | pug, pug-dog 256 | Leonberg 257 | Newfoundland, Newfoundland dog 258 | Great Pyrenees 259 | Samoyed, Samoyede 260 | Pomeranian 261 | 
chow, chow chow 262 | keeshond 263 | Brabancon griffon 264 | Pembroke, Pembroke Welsh corgi 265 | Cardigan, Cardigan Welsh corgi 266 | toy poodle 267 | miniature poodle 268 | standard poodle 269 | Mexican hairless 270 | timber wolf, grey wolf, gray wolf, Canis lupus 271 | white wolf, Arctic wolf, Canis lupus tundrarum 272 | red wolf, maned wolf, Canis rufus, Canis niger 273 | coyote, prairie wolf, brush wolf, Canis latrans 274 | dingo, warrigal, warragal, Canis dingo 275 | dhole, Cuon alpinus 276 | African hunting dog, hyena dog, Cape hunting dog, Lycaon pictus 277 | hyena, hyaena 278 | red fox, Vulpes vulpes 279 | kit fox, Vulpes macrotis 280 | Arctic fox, white fox, Alopex lagopus 281 | grey fox, gray fox, Urocyon cinereoargenteus 282 | tabby, tabby cat 283 | tiger cat 284 | Persian cat 285 | Siamese cat, Siamese 286 | Egyptian cat 287 | cougar, puma, catamount, mountain lion, painter, panther, Felis concolor 288 | lynx, catamount 289 | leopard, Panthera pardus 290 | snow leopard, ounce, Panthera uncia 291 | jaguar, panther, Panthera onca, Felis onca 292 | lion, king of beasts, Panthera leo 293 | tiger, Panthera tigris 294 | cheetah, chetah, Acinonyx jubatus 295 | brown bear, bruin, Ursus arctos 296 | American black bear, black bear, Ursus americanus, Euarctos americanus 297 | ice bear, polar bear, Ursus Maritimus, Thalarctos maritimus 298 | sloth bear, Melursus ursinus, Ursus ursinus 299 | mongoose 300 | meerkat, mierkat 301 | tiger beetle 302 | ladybug, ladybeetle, lady beetle, ladybird, ladybird beetle 303 | ground beetle, carabid beetle 304 | long-horned beetle, longicorn, longicorn beetle 305 | leaf beetle, chrysomelid 306 | dung beetle 307 | rhinoceros beetle 308 | weevil 309 | fly 310 | bee 311 | ant, emmet, pismire 312 | grasshopper, hopper 313 | cricket 314 | walking stick, walkingstick, stick insect 315 | cockroach, roach 316 | mantis, mantid 317 | cicada, cicala 318 | leafhopper 319 | lacewing, lacewing fly 320 | dragonfly, darning needle, devil's darning needle, sewing needle, snake feeder, snake doctor, mosquito hawk, skeeter hawk 321 | damselfly 322 | admiral 323 | ringlet, ringlet butterfly 324 | monarch, monarch butterfly, milkweed butterfly, Danaus plexippus 325 | cabbage butterfly 326 | sulphur butterfly, sulfur butterfly 327 | lycaenid, lycaenid butterfly 328 | starfish, sea star 329 | sea urchin 330 | sea cucumber, holothurian 331 | wood rabbit, cottontail, cottontail rabbit 332 | hare 333 | Angora, Angora rabbit 334 | hamster 335 | porcupine, hedgehog 336 | fox squirrel, eastern fox squirrel, Sciurus niger 337 | marmot 338 | beaver 339 | guinea pig, Cavia cobaya 340 | sorrel 341 | zebra 342 | hog, pig, grunter, squealer, Sus scrofa 343 | wild boar, boar, Sus scrofa 344 | warthog 345 | hippopotamus, hippo, river horse, Hippopotamus amphibius 346 | ox 347 | water buffalo, water ox, Asiatic buffalo, Bubalus bubalis 348 | bison 349 | ram, tup 350 | bighorn, bighorn sheep, cimarron, Rocky Mountain bighorn, Rocky Mountain sheep, Ovis canadensis 351 | ibex, Capra ibex 352 | hartebeest 353 | impala, Aepyceros melampus 354 | gazelle 355 | Arabian camel, dromedary, Camelus dromedarius 356 | llama 357 | weasel 358 | mink 359 | polecat, fitch, foulmart, foumart, Mustela putorius 360 | black-footed ferret, ferret, Mustela nigripes 361 | otter 362 | skunk, polecat, wood pussy 363 | badger 364 | armadillo 365 | three-toed sloth, ai, Bradypus tridactylus 366 | orangutan, orang, orangutang, Pongo pygmaeus 367 | gorilla, Gorilla gorilla 368 | chimpanzee, chimp, Pan troglodytes 369 | 
gibbon, Hylobates lar 370 | siamang, Hylobates syndactylus, Symphalangus syndactylus 371 | guenon, guenon monkey 372 | patas, hussar monkey, Erythrocebus patas 373 | baboon 374 | macaque 375 | langur 376 | colobus, colobus monkey 377 | proboscis monkey, Nasalis larvatus 378 | marmoset 379 | capuchin, ringtail, Cebus capucinus 380 | howler monkey, howler 381 | titi, titi monkey 382 | spider monkey, Ateles geoffroyi 383 | squirrel monkey, Saimiri sciureus 384 | Madagascar cat, ring-tailed lemur, Lemur catta 385 | indri, indris, Indri indri, Indri brevicaudatus 386 | Indian elephant, Elephas maximus 387 | African elephant, Loxodonta africana 388 | lesser panda, red panda, panda, bear cat, cat bear, Ailurus fulgens 389 | giant panda, panda, panda bear, coon bear, Ailuropoda melanoleuca 390 | barracouta, snoek 391 | eel 392 | coho, cohoe, coho salmon, blue jack, silver salmon, Oncorhynchus kisutch 393 | rock beauty, Holocanthus tricolor 394 | anemone fish 395 | sturgeon 396 | gar, garfish, garpike, billfish, Lepisosteus osseus 397 | lionfish 398 | puffer, pufferfish, blowfish, globefish 399 | abacus 400 | abaya 401 | academic gown, academic robe, judge's robe 402 | accordion, piano accordion, squeeze box 403 | acoustic guitar 404 | aircraft carrier, carrier, flattop, attack aircraft carrier 405 | airliner 406 | airship, dirigible 407 | altar 408 | ambulance 409 | amphibian, amphibious vehicle 410 | analog clock 411 | apiary, bee house 412 | apron 413 | ashcan, trash can, garbage can, wastebin, ash bin, ash-bin, ashbin, dustbin, trash barrel, trash bin 414 | assault rifle, assault gun 415 | backpack, back pack, knapsack, packsack, rucksack, haversack 416 | bakery, bakeshop, bakehouse 417 | balance beam, beam 418 | balloon 419 | ballpoint, ballpoint pen, ballpen, Biro 420 | Band Aid 421 | banjo 422 | bannister, banister, balustrade, balusters, handrail 423 | barbell 424 | barber chair 425 | barbershop 426 | barn 427 | barometer 428 | barrel, cask 429 | barrow, garden cart, lawn cart, wheelbarrow 430 | baseball 431 | basketball 432 | bassinet 433 | bassoon 434 | bathing cap, swimming cap 435 | bath towel 436 | bathtub, bathing tub, bath, tub 437 | beach wagon, station wagon, wagon, estate car, beach waggon, station waggon, waggon 438 | beacon, lighthouse, beacon light, pharos 439 | beaker 440 | bearskin, busby, shako 441 | beer bottle 442 | beer glass 443 | bell cote, bell cot 444 | bib 445 | bicycle-built-for-two, tandem bicycle, tandem 446 | bikini, two-piece 447 | binder, ring-binder 448 | binoculars, field glasses, opera glasses 449 | birdhouse 450 | boathouse 451 | bobsled, bobsleigh, bob 452 | bolo tie, bolo, bola tie, bola 453 | bonnet, poke bonnet 454 | bookcase 455 | bookshop, bookstore, bookstall 456 | bottlecap 457 | bow 458 | bow tie, bow-tie, bowtie 459 | brass, memorial tablet, plaque 460 | brassiere, bra, bandeau 461 | breakwater, groin, groyne, mole, bulwark, seawall, jetty 462 | breastplate, aegis, egis 463 | broom 464 | bucket, pail 465 | buckle 466 | bulletproof vest 467 | bullet train, bullet 468 | butcher shop, meat market 469 | cab, hack, taxi, taxicab 470 | caldron, cauldron 471 | candle, taper, wax light 472 | cannon 473 | canoe 474 | can opener, tin opener 475 | cardigan 476 | car mirror 477 | carousel, carrousel, merry-go-round, roundabout, whirligig 478 | carpenter's kit, tool kit 479 | carton 480 | car wheel 481 | cash machine, cash dispenser, automated teller machine, automatic teller machine, automated teller, automatic teller, ATM 482 | cassette 483 | cassette player 
484 | castle 485 | catamaran 486 | CD player 487 | cello, violoncello 488 | cellular telephone, cellular phone, cellphone, cell, mobile phone 489 | chain 490 | chainlink fence 491 | chain mail, ring mail, mail, chain armor, chain armour, ring armor, ring armour 492 | chain saw, chainsaw 493 | chest 494 | chiffonier, commode 495 | chime, bell, gong 496 | china cabinet, china closet 497 | Christmas stocking 498 | church, church building 499 | cinema, movie theater, movie theatre, movie house, picture palace 500 | cleaver, meat cleaver, chopper 501 | cliff dwelling 502 | cloak 503 | clog, geta, patten, sabot 504 | cocktail shaker 505 | coffee mug 506 | coffeepot 507 | coil, spiral, volute, whorl, helix 508 | combination lock 509 | computer keyboard, keypad 510 | confectionery, confectionary, candy store 511 | container ship, containership, container vessel 512 | convertible 513 | corkscrew, bottle screw 514 | cornet, horn, trumpet, trump 515 | cowboy boot 516 | cowboy hat, ten-gallon hat 517 | cradle 518 | crane 519 | crash helmet 520 | crate 521 | crib, cot 522 | Crock Pot 523 | croquet ball 524 | crutch 525 | cuirass 526 | dam, dike, dyke 527 | desk 528 | desktop computer 529 | dial telephone, dial phone 530 | diaper, nappy, napkin 531 | digital clock 532 | digital watch 533 | dining table, board 534 | dishrag, dishcloth 535 | dishwasher, dish washer, dishwashing machine 536 | disk brake, disc brake 537 | dock, dockage, docking facility 538 | dogsled, dog sled, dog sleigh 539 | dome 540 | doormat, welcome mat 541 | drilling platform, offshore rig 542 | drum, membranophone, tympan 543 | drumstick 544 | dumbbell 545 | Dutch oven 546 | electric fan, blower 547 | electric guitar 548 | electric locomotive 549 | entertainment center 550 | envelope 551 | espresso maker 552 | face powder 553 | feather boa, boa 554 | file, file cabinet, filing cabinet 555 | fireboat 556 | fire engine, fire truck 557 | fire screen, fireguard 558 | flagpole, flagstaff 559 | flute, transverse flute 560 | folding chair 561 | football helmet 562 | forklift 563 | fountain 564 | fountain pen 565 | four-poster 566 | freight car 567 | French horn, horn 568 | frying pan, frypan, skillet 569 | fur coat 570 | garbage truck, dustcart 571 | gasmask, respirator, gas helmet 572 | gas pump, gasoline pump, petrol pump, island dispenser 573 | goblet 574 | go-kart 575 | golf ball 576 | golfcart, golf cart 577 | gondola 578 | gong, tam-tam 579 | gown 580 | grand piano, grand 581 | greenhouse, nursery, glasshouse 582 | grille, radiator grille 583 | grocery store, grocery, food market, market 584 | guillotine 585 | hair slide 586 | hair spray 587 | half track 588 | hammer 589 | hamper 590 | hand blower, blow dryer, blow drier, hair dryer, hair drier 591 | hand-held computer, hand-held microcomputer 592 | handkerchief, hankie, hanky, hankey 593 | hard disc, hard disk, fixed disk 594 | harmonica, mouth organ, harp, mouth harp 595 | harp 596 | harvester, reaper 597 | hatchet 598 | holster 599 | home theater, home theatre 600 | honeycomb 601 | hook, claw 602 | hoopskirt, crinoline 603 | horizontal bar, high bar 604 | horse cart, horse-cart 605 | hourglass 606 | iPod 607 | iron, smoothing iron 608 | jack-o'-lantern 609 | jean, blue jean, denim 610 | jeep, landrover 611 | jersey, T-shirt, tee shirt 612 | jigsaw puzzle 613 | jinrikisha, ricksha, rickshaw 614 | joystick 615 | kimono 616 | knee pad 617 | knot 618 | lab coat, laboratory coat 619 | ladle 620 | lampshade, lamp shade 621 | laptop, laptop computer 622 | lawn mower, mower 623 | lens 
cap, lens cover 624 | letter opener, paper knife, paperknife 625 | library 626 | lifeboat 627 | lighter, light, igniter, ignitor 628 | limousine, limo 629 | liner, ocean liner 630 | lipstick, lip rouge 631 | Loafer 632 | lotion 633 | loudspeaker, speaker, speaker unit, loudspeaker system, speaker system 634 | loupe, jeweler's loupe 635 | lumbermill, sawmill 636 | magnetic compass 637 | mailbag, postbag 638 | mailbox, letter box 639 | maillot 640 | maillot, tank suit 641 | manhole cover 642 | maraca 643 | marimba, xylophone 644 | mask 645 | matchstick 646 | maypole 647 | maze, labyrinth 648 | measuring cup 649 | medicine chest, medicine cabinet 650 | megalith, megalithic structure 651 | microphone, mike 652 | microwave, microwave oven 653 | military uniform 654 | milk can 655 | minibus 656 | miniskirt, mini 657 | minivan 658 | missile 659 | mitten 660 | mixing bowl 661 | mobile home, manufactured home 662 | Model T 663 | modem 664 | monastery 665 | monitor 666 | moped 667 | mortar 668 | mortarboard 669 | mosque 670 | mosquito net 671 | motor scooter, scooter 672 | mountain bike, all-terrain bike, off-roader 673 | mountain tent 674 | mouse, computer mouse 675 | mousetrap 676 | moving van 677 | muzzle 678 | nail 679 | neck brace 680 | necklace 681 | nipple 682 | notebook, notebook computer 683 | obelisk 684 | oboe, hautboy, hautbois 685 | ocarina, sweet potato 686 | odometer, hodometer, mileometer, milometer 687 | oil filter 688 | organ, pipe organ 689 | oscilloscope, scope, cathode-ray oscilloscope, CRO 690 | overskirt 691 | oxcart 692 | oxygen mask 693 | packet 694 | paddle, boat paddle 695 | paddlewheel, paddle wheel 696 | padlock 697 | paintbrush 698 | pajama, pyjama, pj's, jammies 699 | palace 700 | panpipe, pandean pipe, syrinx 701 | paper towel 702 | parachute, chute 703 | parallel bars, bars 704 | park bench 705 | parking meter 706 | passenger car, coach, carriage 707 | patio, terrace 708 | pay-phone, pay-station 709 | pedestal, plinth, footstall 710 | pencil box, pencil case 711 | pencil sharpener 712 | perfume, essence 713 | Petri dish 714 | photocopier 715 | pick, plectrum, plectron 716 | pickelhaube 717 | picket fence, paling 718 | pickup, pickup truck 719 | pier 720 | piggy bank, penny bank 721 | pill bottle 722 | pillow 723 | ping-pong ball 724 | pinwheel 725 | pirate, pirate ship 726 | pitcher, ewer 727 | plane, carpenter's plane, woodworking plane 728 | planetarium 729 | plastic bag 730 | plate rack 731 | plow, plough 732 | plunger, plumber's helper 733 | Polaroid camera, Polaroid Land camera 734 | pole 735 | police van, police wagon, paddy wagon, patrol wagon, wagon, black Maria 736 | poncho 737 | pool table, billiard table, snooker table 738 | pop bottle, soda bottle 739 | pot, flowerpot 740 | potter's wheel 741 | power drill 742 | prayer rug, prayer mat 743 | printer 744 | prison, prison house 745 | projectile, missile 746 | projector 747 | puck, hockey puck 748 | punching bag, punch bag, punching ball, punchball 749 | purse 750 | quill, quill pen 751 | quilt, comforter, comfort, puff 752 | racer, race car, racing car 753 | racket, racquet 754 | radiator 755 | radio, wireless 756 | radio telescope, radio reflector 757 | rain barrel 758 | recreational vehicle, RV, R.V. 
759 | reel 760 | reflex camera 761 | refrigerator, icebox 762 | remote control, remote 763 | restaurant, eating house, eating place, eatery 764 | revolver, six-gun, six-shooter 765 | rifle 766 | rocking chair, rocker 767 | rotisserie 768 | rubber eraser, rubber, pencil eraser 769 | rugby ball 770 | rule, ruler 771 | running shoe 772 | safe 773 | safety pin 774 | saltshaker, salt shaker 775 | sandal 776 | sarong 777 | sax, saxophone 778 | scabbard 779 | scale, weighing machine 780 | school bus 781 | schooner 782 | scoreboard 783 | screen, CRT screen 784 | screw 785 | screwdriver 786 | seat belt, seatbelt 787 | sewing machine 788 | shield, buckler 789 | shoe shop, shoe-shop, shoe store 790 | shoji 791 | shopping basket 792 | shopping cart 793 | shovel 794 | shower cap 795 | shower curtain 796 | ski 797 | ski mask 798 | sleeping bag 799 | slide rule, slipstick 800 | sliding door 801 | slot, one-armed bandit 802 | snorkel 803 | snowmobile 804 | snowplow, snowplough 805 | soap dispenser 806 | soccer ball 807 | sock 808 | solar dish, solar collector, solar furnace 809 | sombrero 810 | soup bowl 811 | space bar 812 | space heater 813 | space shuttle 814 | spatula 815 | speedboat 816 | spider web, spider's web 817 | spindle 818 | sports car, sport car 819 | spotlight, spot 820 | stage 821 | steam locomotive 822 | steel arch bridge 823 | steel drum 824 | stethoscope 825 | stole 826 | stone wall 827 | stopwatch, stop watch 828 | stove 829 | strainer 830 | streetcar, tram, tramcar, trolley, trolley car 831 | stretcher 832 | studio couch, day bed 833 | stupa, tope 834 | submarine, pigboat, sub, U-boat 835 | suit, suit of clothes 836 | sundial 837 | sunglass 838 | sunglasses, dark glasses, shades 839 | sunscreen, sunblock, sun blocker 840 | suspension bridge 841 | swab, swob, mop 842 | sweatshirt 843 | swimming trunks, bathing trunks 844 | swing 845 | switch, electric switch, electrical switch 846 | syringe 847 | table lamp 848 | tank, army tank, armored combat vehicle, armoured combat vehicle 849 | tape player 850 | teapot 851 | teddy, teddy bear 852 | television, television system 853 | tennis ball 854 | thatch, thatched roof 855 | theater curtain, theatre curtain 856 | thimble 857 | thresher, thrasher, threshing machine 858 | throne 859 | tile roof 860 | toaster 861 | tobacco shop, tobacconist shop, tobacconist 862 | toilet seat 863 | torch 864 | totem pole 865 | tow truck, tow car, wrecker 866 | toyshop 867 | tractor 868 | trailer truck, tractor trailer, trucking rig, rig, articulated lorry, semi 869 | tray 870 | trench coat 871 | tricycle, trike, velocipede 872 | trimaran 873 | tripod 874 | triumphal arch 875 | trolleybus, trolley coach, trackless trolley 876 | trombone 877 | tub, vat 878 | turnstile 879 | typewriter keyboard 880 | umbrella 881 | unicycle, monocycle 882 | upright, upright piano 883 | vacuum, vacuum cleaner 884 | vase 885 | vault 886 | velvet 887 | vending machine 888 | vestment 889 | viaduct 890 | violin, fiddle 891 | volleyball 892 | waffle iron 893 | wall clock 894 | wallet, billfold, notecase, pocketbook 895 | wardrobe, closet, press 896 | warplane, military plane 897 | washbasin, handbasin, washbowl, lavabo, wash-hand basin 898 | washer, automatic washer, washing machine 899 | water bottle 900 | water jug 901 | water tower 902 | whiskey jug 903 | whistle 904 | wig 905 | window screen 906 | window shade 907 | Windsor tie 908 | wine bottle 909 | wing 910 | wok 911 | wooden spoon 912 | wool, woolen, woollen 913 | worm fence, snake fence, snake-rail fence, Virginia fence 914 | 
wreck 915 | yawl 916 | yurt 917 | web site, website, internet site, site 918 | comic book 919 | crossword puzzle, crossword 920 | street sign 921 | traffic light, traffic signal, stoplight 922 | book jacket, dust cover, dust jacket, dust wrapper 923 | menu 924 | plate 925 | guacamole 926 | consomme 927 | hot pot, hotpot 928 | trifle 929 | ice cream, icecream 930 | ice lolly, lolly, lollipop, popsicle 931 | French loaf 932 | bagel, beigel 933 | pretzel 934 | cheeseburger 935 | hotdog, hot dog, red hot 936 | mashed potato 937 | head cabbage 938 | broccoli 939 | cauliflower 940 | zucchini, courgette 941 | spaghetti squash 942 | acorn squash 943 | butternut squash 944 | cucumber, cuke 945 | artichoke, globe artichoke 946 | bell pepper 947 | cardoon 948 | mushroom 949 | Granny Smith 950 | strawberry 951 | orange 952 | lemon 953 | fig 954 | pineapple, ananas 955 | banana 956 | jackfruit, jak, jack 957 | custard apple 958 | pomegranate 959 | hay 960 | carbonara 961 | chocolate sauce, chocolate syrup 962 | dough 963 | meat loaf, meatloaf 964 | pizza, pizza pie 965 | potpie 966 | burrito 967 | red wine 968 | espresso 969 | cup 970 | eggnog 971 | alp 972 | bubble 973 | cliff, drop, drop-off 974 | coral reef 975 | geyser 976 | lakeside, lakeshore 977 | promontory, headland, head, foreland 978 | sandbar, sand bar 979 | seashore, coast, seacoast, sea-coast 980 | valley, vale 981 | volcano 982 | ballplayer, baseball player 983 | groom, bridegroom 984 | scuba diver 985 | rapeseed 986 | daisy 987 | yellow lady's slipper, yellow lady-slipper, Cypripedium calceolus, Cypripedium parviflorum 988 | corn 989 | acorn 990 | hip, rose hip, rosehip 991 | buckeye, horse chestnut, conker 992 | coral fungus 993 | agaric 994 | gyromitra 995 | stinkhorn, carrion fungus 996 | earthstar 997 | hen-of-the-woods, hen of the woods, Polyporus frondosus, Grifola frondosa 998 | bolete 999 | ear, spike, capitulum 1000 | toilet tissue, toilet paper, bathroom tissue -------------------------------------------------------------------------------- /ch9/word2vec.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | import collections 4 | import random 5 | import math 6 | import os 7 | 8 | tot_captions, only_captions = None, None 9 | 10 | data_indices = None 11 | reverse_dictionary = None 12 | embedding_size = None 13 | vocabulary_size = None 14 | max_caption_length = None 15 | 16 | def define_data_and_hyperparameters( 17 | _tot_captions, _only_captions, _reverse_dictionary, 18 | _emb_size, _vocab_size, _max_cap_length): 19 | global data_indices, tot_captions, only_captions, reverse_dictionary 20 | global embedding_size, vocabulary_size, max_caption_length 21 | 22 | tot_captions = _tot_captions 23 | only_captions = _only_captions 24 | 25 | data_indices = [0 for _ in range(tot_captions)] 26 | reverse_dictionary = _reverse_dictionary 27 | embedding_size = _emb_size 28 | vocabulary_size = _vocab_size 29 | max_caption_length = _max_cap_length 30 | 31 | 32 | def generate_batch_for_word2vec(batch_size, window_size): 33 | # window_size is the amount of words we're looking at from each side of a given word 34 | # creates a single batch 35 | global data_indices 36 | 37 | span = 2 * window_size + 1 # [ skip_window target skip_window ] 38 | 39 | batch = np.ndarray(shape=(batch_size, span - 1), dtype=np.int32) 40 | labels = np.ndarray(shape=(batch_size, 1), dtype=np.int32) 41 | # e.g if skip_window = 2 then span = 5 42 | # span is the length of the whole frame we 
are considering for a single word (left + word + right) 43 | # skip_window is the length of one side 44 | 45 | caption_ids_for_batch = np.random.randint(0, tot_captions, batch_size) 46 | 47 | for b_i in range(batch_size): 48 | cap_id = caption_ids_for_batch[b_i] 49 | 50 | buffer = only_captions[cap_id, data_indices[cap_id]:data_indices[cap_id] + span] 51 | assert buffer.size == span, 'Buffer length (%d), Current data index (%d), Span(%d)' % ( 52 | buffer.size, data_indices[cap_id], span) 53 | # If we only have EOS tokesn in the sampled text, we sample a new one 54 | while np.all(buffer == 1): 55 | # reset the data_indices for that cap_id 56 | data_indices[cap_id] = 0 57 | # sample a new cap_id 58 | cap_id = np.random.randint(0, tot_captions) 59 | buffer = only_captions[cap_id, data_indices[cap_id]:data_indices[cap_id] + span] 60 | 61 | # fill left and right sides of batch 62 | batch[b_i, :window_size] = buffer[:window_size] 63 | batch[b_i, window_size:] = buffer[window_size + 1:] 64 | 65 | labels[b_i, 0] = buffer[window_size] 66 | 67 | # increase the corresponding index 68 | data_indices[cap_id] = (data_indices[cap_id] + 1) % (max_caption_length - span) 69 | 70 | assert batch.shape[0] == batch_size and batch.shape[1] == span - 1 71 | return batch, labels 72 | 73 | def print_some_batches(): 74 | global data_indices, reverse_dictionary 75 | 76 | for w_size in [1, 2]: 77 | data_indices = [0 for _ in range(tot_captions)] 78 | batch, labels = generate_batch_for_word2vec(batch_size=8, window_size=w_size) 79 | print('\nwith window_size = %d:' %w_size) 80 | print(' batch:', [[reverse_dictionary[bii] for bii in bi] for bi in batch]) 81 | print(' labels:', [reverse_dictionary[li] for li in labels.reshape(8)]) 82 | 83 | batch_size, embedding_size, window_size = None, None, None 84 | valid_size, valid_window, valid_examples = None, None, None 85 | num_sampled = None 86 | 87 | train_dataset, train_labels = None, None 88 | valid_dataset = None 89 | 90 | softmax_weights, softmax_biases = None, None 91 | 92 | loss, optimizer, similarity, normalized_embeddings = None, None, None, None 93 | 94 | def define_word2vec_tensorflow(batch_size): 95 | global embedding_size, window_size 96 | global valid_size, valid_window, valid_examples 97 | global num_sampled 98 | global train_dataset, train_labels 99 | global valid_dataset 100 | global softmax_weights, softmax_biases 101 | global loss, optimizer, similarity 102 | global vocabulary_size, embedding_size 103 | global normalized_embeddings 104 | 105 | # How many words to consider left and right. 106 | # Skip gram by design does not require to have all the context words in a given step 107 | # However, for CBOW that's a requirement, so we limit the window size 108 | window_size = 3 109 | 110 | # We pick a random validation set to sample nearest neighbors 111 | valid_size = 16 # Random set of words to evaluate similarity on. 112 | # We sample valid datapoints randomly from a large window without always being deterministic 113 | valid_window = 50 114 | 115 | # When selecting valid examples, we select some of the most frequent words as well as 116 | # some moderately rare words as well 117 | valid_examples = np.array(random.sample(range(valid_window), valid_size)) 118 | valid_examples = np.append(valid_examples,random.sample(range(1000, 1000+valid_window), valid_size),axis=0) 119 | 120 | num_sampled = 32 # Number of negative examples to sample. 121 | 122 | tf.reset_default_graph() 123 | 124 | # Training input data (target word IDs). 
Note that it has 2*window_size columns 125 | train_dataset = tf.placeholder(tf.int32, shape=[batch_size,2*window_size]) 126 | # Training input label data (context word IDs) 127 | train_labels = tf.placeholder(tf.int32, shape=[batch_size, 1]) 128 | # Validation input data, we don't need a placeholder 129 | # as we have already defined the IDs of the words selected 130 | # as validation data 131 | valid_dataset = tf.constant(valid_examples, dtype=tf.int32) 132 | 133 | # Variables. 134 | 135 | # Embedding layer, contains the word embeddings 136 | embeddings = tf.Variable(tf.random_uniform([vocabulary_size, embedding_size], -1.0, 1.0,dtype=tf.float32)) 137 | 138 | # Softmax Weights and Biases 139 | softmax_weights = tf.Variable(tf.truncated_normal([vocabulary_size, embedding_size], 140 | stddev=0.5 / math.sqrt(embedding_size),dtype=tf.float32)) 141 | softmax_biases = tf.Variable(tf.random_uniform([vocabulary_size],0.0,0.01)) 142 | 143 | # Model. 144 | # Look up embeddings for a batch of inputs. 145 | # Here we do embedding lookups for each column in the input placeholder 146 | # and then average them to produce an embedding_size word vector 147 | stacked_embedings = None 148 | print('Defining %d embedding lookups representing each word in the context'%(2*window_size)) 149 | for i in range(2*window_size): 150 | embedding_i = tf.nn.embedding_lookup(embeddings, train_dataset[:,i]) 151 | x_size,y_size = embedding_i.get_shape().as_list() 152 | if stacked_embedings is None: 153 | stacked_embedings = tf.reshape(embedding_i,[x_size,y_size,1]) 154 | else: 155 | stacked_embedings = tf.concat(axis=2,values=[stacked_embedings,tf.reshape(embedding_i,[x_size,y_size,1])]) 156 | 157 | assert stacked_embedings.get_shape().as_list()[2]==2*window_size 158 | print("Stacked embedding size: %s"%stacked_embedings.get_shape().as_list()) 159 | mean_embeddings = tf.reduce_mean(stacked_embedings,2,keepdims=False) 160 | print("Reduced mean embedding size: %s"%mean_embeddings.get_shape().as_list()) 161 | 162 | 163 | # Compute the softmax loss, using a sample of the negative labels each time. 164 | # inputs are embeddings of the train words 165 | # with this loss we optimize weights, biases, embeddings 166 | loss = tf.reduce_mean(tf.nn.sampled_softmax_loss(weights=softmax_weights, biases=softmax_biases, inputs=mean_embeddings, 167 | labels=train_labels, num_sampled=num_sampled, num_classes=vocabulary_size)) 168 | # AdamOptimizer. 169 | optimizer = tf.train.AdamOptimizer(0.0005).minimize(loss) 170 | 171 | # Compute the similarity between minibatch examples and all embeddings. 
172 | # We use the cosine distance: 173 | norm = tf.sqrt(tf.reduce_sum(tf.square(embeddings), 1, keepdims=True)) 174 | normalized_embeddings = embeddings / norm 175 | valid_embeddings = tf.nn.embedding_lookup(normalized_embeddings, valid_dataset) 176 | similarity = tf.matmul(valid_embeddings, tf.transpose(normalized_embeddings)) 177 | 178 | 179 | def run_word2vec(batch_size): 180 | global embedding_size, window_size 181 | global valid_size, valid_window, valid_examples 182 | global num_sampled 183 | global train_dataset, train_labels 184 | global valid_dataset 185 | global softmax_weights, softmax_biases 186 | global loss, optimizer, similarity, normalized_embeddings 187 | global data_list, num_files, reverse_dictionary 188 | global vocabulary_size, embedding_size 189 | 190 | work_dir = 'image_caption_data' 191 | num_steps = 100001 192 | 193 | session = tf.InteractiveSession() 194 | 195 | tf.global_variables_initializer().run() 196 | print('Initialized') 197 | average_loss = 0 198 | for step in range(num_steps): 199 | 200 | # Load a batch of data 201 | batch_data, batch_labels = generate_batch_for_word2vec(batch_size, window_size) 202 | 203 | # Populate the feed_dict and run the optimizer and get the loss out 204 | feed_dict = {train_dataset: batch_data, train_labels: batch_labels} 205 | _, l = session.run([optimizer, loss], feed_dict=feed_dict) 206 | 207 | average_loss += l 208 | 209 | if (step + 1) % 2000 == 0: 210 | if step > 0: 211 | # The average loss is an estimate of the loss over the last 2000 batches. 212 | average_loss = average_loss / 2000 213 | 214 | print('Average loss at step %d: %f' % (step + 1, average_loss)) 215 | average_loss = 0 # Reset average loss 216 | 217 | if (step + 1) % 10000 == 0: 218 | sim = similarity.eval() 219 | # Calculate the most similar (top_k) words 220 | # to the previosly selected set of valid words 221 | # Note that this is an expensive step 222 | for i in range(valid_size): 223 | valid_word = reverse_dictionary[valid_examples[i]] 224 | top_k = 3 # number of nearest neighbors 225 | nearest = (-sim[i, :]).argsort()[1:top_k + 1] 226 | log = 'Nearest to %s:' % valid_word 227 | for k in range(top_k): 228 | close_word = reverse_dictionary[nearest[k]] 229 | log = '%s %s,' % (log, close_word) 230 | print(log) 231 | 232 | # Get the normalized embeddings we learnt 233 | cbow_final_embeddings = normalized_embeddings.eval() 234 | 235 | # Save the embeddings to the disk as 'caption_embeddings-tmp.npy' 236 | # If you want to use this embeddings in the next steps 237 | # please change the filename to 'caption-embeddings.npy' 238 | np.save(os.path.join(work_dir,'caption-embeddings-tmp'), cbow_final_embeddings) --------------------------------------------------------------------------------
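
Both `word2vec.py` helpers above are meant to be driven from the chapter notebooks rather than run on their own. As a rough sketch of the calling order for the ch9 variant, the snippet below wires the module up with synthetic caption data; the toy array sizes, the dummy vocabulary, and the `import word2vec` path (running from inside `ch9/`) are illustrative assumptions rather than part of the book's code, and a TensorFlow 1.x environment is assumed, as used throughout the repository.

```
import os
import numpy as np
import word2vec  # assumption: this script sits in ch9/ next to word2vec.py

# Toy stand-ins for the real caption data (shapes and sizes are assumptions)
tot_captions, max_cap_length = 100, 20
vocab_size, emb_size = 2000, 128  # vocab must cover the rare-word validation IDs (up to 1049)
only_captions = np.random.randint(2, vocab_size, size=(tot_captions, max_cap_length)).astype(np.int32)
reverse_dictionary = {i: 'word_%d' % i for i in range(vocab_size)}

batch_size = 128

# 1. Register the data and hyperparameters with the module-level globals
word2vec.define_data_and_hyperparameters(
    tot_captions, only_captions, reverse_dictionary,
    emb_size, vocab_size, max_cap_length)

# 2. Sanity-check a few (context, target) batches
word2vec.print_some_batches()

# 3. Build the CBOW graph and train; run_word2vec saves the learned embeddings under image_caption_data/
word2vec.define_word2vec_tensorflow(batch_size)
os.makedirs('image_caption_data', exist_ok=True)  # already present in the repository
word2vec.run_word2vec(batch_size)  # trains for ~100k steps; lower num_steps in the module to smoke-test
```

The ch8 module follows the same pattern, except that its `define_data_and_hyperparameters` takes the number of documents and a list of per-document word-ID sequences, and its `define_word2vec_tensorflow` and `run_word2vec` take no arguments.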