├── .gitignore ├── Corpus.pyx ├── LICENSE ├── README.md ├── images ├── 01.skipgram-objective.png ├── 01.skipgram-prepare-data.png ├── 02.skipgram-objective.png ├── 03.glove-objective.png ├── 03.glove-weighting-function.png ├── 04.window-classifier-architecture.png ├── 04.window-data.png ├── 05.neural-dparser-architecture.png ├── 05.transition-based-parse.png ├── 06.rnnlm-architecture.png ├── 07.attention-mechanism.png ├── 07.pad_to_sequence.png ├── 07.seq2seq.png ├── 08.cnn-for-text-architecture.png ├── 09.rntn-layer.png └── 10.dmn-architecture.png ├── notebooks ├── 01.Skip-gram-Naive-Softmax.ipynb ├── 02.Skip-gram-Negative-Sampling.ipynb ├── 03.GloVe.ipynb ├── 04.Window-Classifier-for-NER.ipynb ├── 05.Neural-Dependancy-Parser.ipynb ├── 06.RNN-Language-Model.ipynb ├── 07.Neural-Machine-Translation-with-Attention.ipynb ├── 08.CNN-for-Text-Classification.ipynb ├── 09.Recursive-NN-for-Sentiment-Classification.ipynb └── 10.Dynamic-Memory-Network-for-Question-Answering.ipynb ├── script ├── docker-compose.yml └── prepare_dataset.sh └── setup.py /.gitignore: -------------------------------------------------------------------------------- 1 | *.txt 2 | __pycache__ 3 | .ipynb_checkpoints 4 | build/ 5 | dataset/ 6 | model/ 7 | *.cpp 8 | *.so 9 | *.bin 10 | *.pkl 11 | -------------------------------------------------------------------------------- /Corpus.pyx: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import scipy.sparse as sp 3 | from libcpp.vector cimport vector 4 | from libcpp.string cimport string 5 | from libcpp.map cimport map 6 | 7 | # C++ extension module for text preprocessing 8 | 9 | def make_dictionary(vocab): 10 | cdef int vocab_size = len(vocab) 11 | cdef int i 12 | dictionary={} 13 | inv_dictionary={} 14 | for i in range(vocab_size): 15 | dictionary[vocab[i]] = i 16 | inv_dictionary[i]=vocab[i] 17 | return dictionary,inv_dictionary 18 | 19 | 20 | def make_window_data(sents,window_size): 21 | pass 22 | 23 | 24 | def
make_coo_matrix(corpus,dictionary): 25 | cdef int matrix_size = len(dictionary) 26 | pass 27 | 28 | 29 | def getBatch_FromBucket(batch_size,buckets): 30 | i=0 31 | bucket_mask =[False for _ in range(len(buckets))] 32 | indices = [[0,batch_size] for _ in range(len(buckets))] 33 | is_done=False 34 | while is_done==False: 35 | batch = buckets[i][indices[i][0]:indices[i][1]] 36 | temp = indices[i][1] 37 | indices[i][1]= indices[i][1]+batch_size 38 | indices[i][0] = temp 39 | 40 | i = (i+1)%len(buckets) 41 | while bucket_mask[i]: 42 | i = (i+1)%len(buckets) 43 | 44 | if indices[i][1]>len(buckets[i]): 45 | bucket_mask[i]= True 46 | if bucket_mask.count(True)==len(buckets): 47 | is_done=True 48 | else: 49 | while bucket_mask[i]: 50 | i = (i+1)%len(buckets) 51 | yield batch -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2017 SungDong Kim 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # DeepNLP-models-Pytorch 2 | 3 | Pytorch implementations of various Deep NLP models in cs-224n(Stanford Univ: NLP with Deep Learning) 4 | 5 | - This is not for Pytorch beginners. If this is your first time using Pytorch, I recommend these [awesome tutorials](#references). 6 | - If you're interested in DeepNLP, I strongly recommend working through this awesome lecture. 7 | 8 | * cs-224n-slides 9 | * cs-224n-videos 10 | 11 | This material is not perfect but will help your study and research:) Please feel free to send pull requests!! 12 | 13 |
14 | 15 | ## Contents 16 | 17 | | Model | Links | 18 | | ------------- |:-------------:| 19 | | 01. Skip-gram-Naive-Softmax | [notebook / data / paper] | 20 | | 02. Skip-gram-Negative-Sampling | [notebook / data / paper] | 21 | | 03. GloVe | [notebook / data / paper] | 22 | | 04. Window-Classifier-for-NER | [notebook / data / paper] | 23 | | 05. Neural-Dependancy-Parser | [notebook / data / paper] | 24 | | 06. RNN-Language-Model | [notebook / data / paper] | 25 | | 07. Neural-Machine-Translation-with-Attention | [notebook / data / paper] | 26 | | 08. CNN-for-Text-Classification | [notebook / data / paper] | 27 | | 09. Recursive-NN-for-Sentiment-Classification | [notebook / data / paper] | 28 | | 10. Dynamic-Memory-Network-for-Question-Answering | [notebook / data / paper] | 29 | 30 | 31 | ## Requirements 32 | 33 | - Python 3.5 34 | - Pytorch 0.2 35 | - nltk 3.2.2 36 | - gensim 2.2.0 37 | - sklearn_crfsuite 38 | 39 | 40 | ## Getting started 41 | 42 | `git clone https://github.com/DSKSD/cs-224n-Pytorch.git` 43 | 44 | ### prepare dataset 45 | 46 | ```` 47 | cd script 48 | chmod u+x prepare_dataset.sh 49 | ./prepare_dataset.sh 50 | ```` 51 | 52 | ### docker env 53 | Ubuntu 16.04 with Python 3.5.2 and various ML/DL packages, including tensorflow, sklearn, and pytorch 54 | 55 | `docker pull dsksd/deepstudy:0.2` 56 | 57 | ```` 58 | pip3 install docker-compose 59 | cd script 60 | docker-compose up -d 61 | ```` 62 | 63 | ### cloud setting 64 | 65 | `not yet` 66 | 67 | ## References 68 | 69 | * practical-pytorch 70 | * DeepLearningForNLPInPytorch 71 | * pytorch-tutorial 72 | * pytorch-examples 73 | 74 | ## Author 75 | 76 | Sungdong Kim / @DSKSD -------------------------------------------------------------------------------- /images/01.skipgram-objective.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Shawn1993/DeepNLP-models-Pytorch/ceeb4221b176790229cd20c6ca4c05f625bdf02e/images/01.skipgram-objective.png
-------------------------------------------------------------------------------- /images/01.skipgram-prepare-data.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Shawn1993/DeepNLP-models-Pytorch/ceeb4221b176790229cd20c6ca4c05f625bdf02e/images/01.skipgram-prepare-data.png -------------------------------------------------------------------------------- /images/02.skipgram-objective.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Shawn1993/DeepNLP-models-Pytorch/ceeb4221b176790229cd20c6ca4c05f625bdf02e/images/02.skipgram-objective.png -------------------------------------------------------------------------------- /images/03.glove-objective.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Shawn1993/DeepNLP-models-Pytorch/ceeb4221b176790229cd20c6ca4c05f625bdf02e/images/03.glove-objective.png -------------------------------------------------------------------------------- /images/03.glove-weighting-function.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Shawn1993/DeepNLP-models-Pytorch/ceeb4221b176790229cd20c6ca4c05f625bdf02e/images/03.glove-weighting-function.png -------------------------------------------------------------------------------- /images/04.window-classifier-architecture.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Shawn1993/DeepNLP-models-Pytorch/ceeb4221b176790229cd20c6ca4c05f625bdf02e/images/04.window-classifier-architecture.png -------------------------------------------------------------------------------- /images/04.window-data.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Shawn1993/DeepNLP-models-Pytorch/ceeb4221b176790229cd20c6ca4c05f625bdf02e/images/04.window-data.png -------------------------------------------------------------------------------- /images/05.neural-dparser-architecture.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Shawn1993/DeepNLP-models-Pytorch/ceeb4221b176790229cd20c6ca4c05f625bdf02e/images/05.neural-dparser-architecture.png -------------------------------------------------------------------------------- /images/05.transition-based-parse.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Shawn1993/DeepNLP-models-Pytorch/ceeb4221b176790229cd20c6ca4c05f625bdf02e/images/05.transition-based-parse.png -------------------------------------------------------------------------------- /images/06.rnnlm-architecture.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Shawn1993/DeepNLP-models-Pytorch/ceeb4221b176790229cd20c6ca4c05f625bdf02e/images/06.rnnlm-architecture.png -------------------------------------------------------------------------------- /images/07.attention-mechanism.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Shawn1993/DeepNLP-models-Pytorch/ceeb4221b176790229cd20c6ca4c05f625bdf02e/images/07.attention-mechanism.png -------------------------------------------------------------------------------- /images/07.pad_to_sequence.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Shawn1993/DeepNLP-models-Pytorch/ceeb4221b176790229cd20c6ca4c05f625bdf02e/images/07.pad_to_sequence.png -------------------------------------------------------------------------------- /images/07.seq2seq.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/Shawn1993/DeepNLP-models-Pytorch/ceeb4221b176790229cd20c6ca4c05f625bdf02e/images/07.seq2seq.png -------------------------------------------------------------------------------- /images/08.cnn-for-text-architecture.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Shawn1993/DeepNLP-models-Pytorch/ceeb4221b176790229cd20c6ca4c05f625bdf02e/images/08.cnn-for-text-architecture.png -------------------------------------------------------------------------------- /images/09.rntn-layer.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Shawn1993/DeepNLP-models-Pytorch/ceeb4221b176790229cd20c6ca4c05f625bdf02e/images/09.rntn-layer.png -------------------------------------------------------------------------------- /images/10.dmn-architecture.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Shawn1993/DeepNLP-models-Pytorch/ceeb4221b176790229cd20c6ca4c05f625bdf02e/images/10.dmn-architecture.png -------------------------------------------------------------------------------- /notebooks/01.Skip-gram-Naive-Softmax.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# 1. Skip-gram with naive softmax " 8 | ] 9 | }, 10 | { 11 | "cell_type": "markdown", 12 | "metadata": {}, 13 | "source": [ 14 | "I recommend you take a look at these materials first." 
15 | ] 16 | }, 17 | { 18 | "cell_type": "markdown", 19 | "metadata": {}, 20 | "source": [ 21 | "* http://web.stanford.edu/class/cs224n/lectures/cs224n-2017-lecture2.pdf\n", 22 | "* https://arxiv.org/abs/1301.3781\n", 23 | "* http://mccormickml.com/2016/04/19/word2vec-tutorial-the-skip-gram-model/" 24 | ] 25 | }, 26 | { 27 | "cell_type": "code", 28 | "execution_count": 1, 29 | "metadata": { 30 | "collapsed": true 31 | }, 32 | "outputs": [], 33 | "source": [ 34 | "import torch\n", 35 | "import torch.nn as nn\n", 36 | "from torch.autograd import Variable\n", 37 | "import torch.optim as optim\n", 38 | "import torch.nn.functional as F\n", 39 | "import nltk\n", 40 | "import random\n", 41 | "import numpy as np\n", 42 | "from collections import Counter\n", 43 | "flatten = lambda l: [item for sublist in l for item in sublist]" 44 | ] 45 | }, 46 | { 47 | "cell_type": "code", 48 | "execution_count": 2, 49 | "metadata": { 50 | "collapsed": false 51 | }, 52 | "outputs": [ 53 | { 54 | "name": "stdout", 55 | "output_type": "stream", 56 | "text": [ 57 | "0.2.0+751198f\n", 58 | "3.2.4\n" 59 | ] 60 | } 61 | ], 62 | "source": [ 63 | "print(torch.__version__)\n", 64 | "print(nltk.__version__)" 65 | ] 66 | }, 67 | { 68 | "cell_type": "code", 69 | "execution_count": 3, 70 | "metadata": { 71 | "collapsed": true 72 | }, 73 | "outputs": [], 74 | "source": [ 75 | "USE_CUDA = torch.cuda.is_available()\n", 76 | "\n", 77 | "FloatTensor = torch.cuda.FloatTensor if USE_CUDA else torch.FloatTensor\n", 78 | "LongTensor = torch.cuda.LongTensor if USE_CUDA else torch.LongTensor\n", 79 | "ByteTensor = torch.cuda.ByteTensor if USE_CUDA else torch.ByteTensor" 80 | ] 81 | }, 82 | { 83 | "cell_type": "code", 84 | "execution_count": 4, 85 | "metadata": { 86 | "collapsed": true 87 | }, 88 | "outputs": [], 89 | "source": [ 90 | "def getBatch(batch_size,train_data):\n", 91 | " random.shuffle(train_data)\n", 92 | " sindex=0\n", 93 | " eindex=batch_size\n", 94 | " while eindex < len(train_data):\n", 95 | " 
batch = train_data[sindex:eindex]\n", 96 | " temp = eindex\n", 97 | " eindex = eindex+batch_size\n", 98 | " sindex = temp\n", 99 | " yield batch\n", 100 | " \n", 101 | " if eindex >= len(train_data):\n", 102 | " batch = train_data[sindex:]\n", 103 | " yield batch" 104 | ] 105 | }, 106 | { 107 | "cell_type": "code", 108 | "execution_count": 5, 109 | "metadata": { 110 | "collapsed": true 111 | }, 112 | "outputs": [], 113 | "source": [ 114 | "def prepare_sequence(seq, word2index):\n", 115 | " idxs = list(map(lambda w: word2index[w] if w in word2index.keys() else word2index[\"\"], seq))\n", 116 | " return Variable(LongTensor(idxs))\n", 117 | "\n", 118 | "def prepare_word(word,word2index):\n", 119 | " return Variable(LongTensor([word2index[word]]) if word in word2index.keys() else LongTensor([word2index[\"\"]]))" 120 | ] 121 | }, 122 | { 123 | "cell_type": "markdown", 124 | "metadata": {}, 125 | "source": [ 126 | "## Data load and Preprocessing " 127 | ] 128 | }, 129 | { 130 | "cell_type": "markdown", 131 | "metadata": {}, 132 | "source": [ 133 | "### Load corpus : Gutenberg corpus" 134 | ] 135 | }, 136 | { 137 | "cell_type": "markdown", 138 | "metadata": {}, 139 | "source": [ 140 | "If you don't have gutenberg corpus, you can download it first using nltk.download()" 141 | ] 142 | }, 143 | { 144 | "cell_type": "code", 145 | "execution_count": 6, 146 | "metadata": { 147 | "collapsed": false 148 | }, 149 | "outputs": [ 150 | { 151 | "data": { 152 | "text/plain": [ 153 | "['austen-emma.txt',\n", 154 | " 'austen-persuasion.txt',\n", 155 | " 'austen-sense.txt',\n", 156 | " 'bible-kjv.txt',\n", 157 | " 'blake-poems.txt',\n", 158 | " 'bryant-stories.txt',\n", 159 | " 'burgess-busterbrown.txt',\n", 160 | " 'carroll-alice.txt',\n", 161 | " 'chesterton-ball.txt',\n", 162 | " 'chesterton-brown.txt',\n", 163 | " 'chesterton-thursday.txt',\n", 164 | " 'edgeworth-parents.txt',\n", 165 | " 'melville-moby_dick.txt',\n", 166 | " 'milton-paradise.txt',\n", 167 | " 
'shakespeare-caesar.txt',\n", 168 | " 'shakespeare-hamlet.txt',\n", 169 | " 'shakespeare-macbeth.txt',\n", 170 | " 'whitman-leaves.txt']" 171 | ] 172 | }, 173 | "execution_count": 6, 174 | "metadata": {}, 175 | "output_type": "execute_result" 176 | } 177 | ], 178 | "source": [ 179 | "nltk.corpus.gutenberg.fileids()" 180 | ] 181 | }, 182 | { 183 | "cell_type": "code", 184 | "execution_count": 7, 185 | "metadata": { 186 | "collapsed": true 187 | }, 188 | "outputs": [], 189 | "source": [ 190 | "corpus = list(nltk.corpus.gutenberg.sents('melville-moby_dick.txt'))[:100] # sampling sentences for test\n", 191 | "corpus = [[word.lower() for word in sent] for sent in corpus]" 192 | ] 193 | }, 194 | { 195 | "cell_type": "markdown", 196 | "metadata": {}, 197 | "source": [ 198 | "### Extract Stopwords from unigram distribution's tails" 199 | ] 200 | }, 201 | { 202 | "cell_type": "code", 203 | "execution_count": 8, 204 | "metadata": { 205 | "collapsed": true 206 | }, 207 | "outputs": [], 208 | "source": [ 209 | "word_count = Counter(flatten(corpus))\n", 210 | "border =int(len(word_count)*0.01) " 211 | ] 212 | }, 213 | { 214 | "cell_type": "code", 215 | "execution_count": 9, 216 | "metadata": { 217 | "collapsed": true 218 | }, 219 | "outputs": [], 220 | "source": [ 221 | "stopwords = word_count.most_common()[:border]+list(reversed(word_count.most_common()))[:border]" 222 | ] 223 | }, 224 | { 225 | "cell_type": "code", 226 | "execution_count": 10, 227 | "metadata": { 228 | "collapsed": true 229 | }, 230 | "outputs": [], 231 | "source": [ 232 | "stopwords = [s[0] for s in stopwords]" 233 | ] 234 | }, 235 | { 236 | "cell_type": "code", 237 | "execution_count": 11, 238 | "metadata": { 239 | "collapsed": false 240 | }, 241 | "outputs": [ 242 | { 243 | "data": { 244 | "text/plain": [ 245 | "[',', '.', 'the', 'of', 'and', 'baleine', '--(', 'fat', 'oil', 'boiling']" 246 | ] 247 | }, 248 | "execution_count": 11, 249 | "metadata": {}, 250 | "output_type": "execute_result" 251 | } 252 | ], 
253 | "source": [ 254 | "stopwords" 255 | ] 256 | }, 257 | { 258 | "cell_type": "markdown", 259 | "metadata": {}, 260 | "source": [ 261 | "### Build vocab" 262 | ] 263 | }, 264 | { 265 | "cell_type": "code", 266 | "execution_count": 12, 267 | "metadata": { 268 | "collapsed": true 269 | }, 270 | "outputs": [], 271 | "source": [ 272 | "vocab = list(set(flatten(corpus))-set(stopwords))\n", 273 | "vocab.append('')" 274 | ] 275 | }, 276 | { 277 | "cell_type": "code", 278 | "execution_count": 13, 279 | "metadata": { 280 | "collapsed": false 281 | }, 282 | "outputs": [ 283 | { 284 | "name": "stdout", 285 | "output_type": "stream", 286 | "text": [ 287 | "592 583\n" 288 | ] 289 | } 290 | ], 291 | "source": [ 292 | "print(len(set(flatten(corpus))),len(vocab))" 293 | ] 294 | }, 295 | { 296 | "cell_type": "code", 297 | "execution_count": 14, 298 | "metadata": { 299 | "collapsed": true 300 | }, 301 | "outputs": [], 302 | "source": [ 303 | "word2index = {'' : 0} \n", 304 | "\n", 305 | "for vo in vocab:\n", 306 | " if vo not in word2index.keys():\n", 307 | " word2index[vo]=len(word2index)\n", 308 | "\n", 309 | "index2word = {v:k for k,v in word2index.items()} " 310 | ] 311 | }, 312 | { 313 | "cell_type": "markdown", 314 | "metadata": {}, 315 | "source": [ 316 | "### Prepare train data " 317 | ] 318 | }, 319 | { 320 | "cell_type": "markdown", 321 | "metadata": {}, 322 | "source": [ 323 | "window data example" 324 | ] 325 | }, 326 | { 327 | "cell_type": "markdown", 328 | "metadata": {}, 329 | "source": [ 330 | "\n", 331 | "
borrowed image from http://mccormickml.com/2016/04/19/word2vec-tutorial-the-skip-gram-model/
" 332 | ] 333 | }, 334 | { 335 | "cell_type": "code", 336 | "execution_count": 15, 337 | "metadata": { 338 | "collapsed": true 339 | }, 340 | "outputs": [], 341 | "source": [ 342 | "WINDOW_SIZE = 3\n", 343 | "windows = flatten([list(nltk.ngrams(['']*WINDOW_SIZE+c+['']*WINDOW_SIZE,WINDOW_SIZE*2+1)) for c in corpus])" 344 | ] 345 | }, 346 | { 347 | "cell_type": "code", 348 | "execution_count": 16, 349 | "metadata": { 350 | "collapsed": false 351 | }, 352 | "outputs": [ 353 | { 354 | "data": { 355 | "text/plain": [ 356 | "('', '', '', '[', 'moby', 'dick', 'by')" 357 | ] 358 | }, 359 | "execution_count": 16, 360 | "metadata": {}, 361 | "output_type": "execute_result" 362 | } 363 | ], 364 | "source": [ 365 | "windows[0]" 366 | ] 367 | }, 368 | { 369 | "cell_type": "code", 370 | "execution_count": 17, 371 | "metadata": { 372 | "collapsed": false 373 | }, 374 | "outputs": [ 375 | { 376 | "name": "stdout", 377 | "output_type": "stream", 378 | "text": [ 379 | "[('[', 'moby'), ('[', 'dick'), ('[', 'by'), ('moby', '['), ('moby', 'dick'), ('moby', 'by')]\n" 380 | ] 381 | } 382 | ], 383 | "source": [ 384 | "train_data = []\n", 385 | "\n", 386 | "for window in windows:\n", 387 | " for i in range(WINDOW_SIZE*2+1):\n", 388 | " if i==WINDOW_SIZE or window[i]=='': continue\n", 389 | " train_data.append((window[WINDOW_SIZE],window[i]))\n", 390 | "\n", 391 | "print(train_data[:WINDOW_SIZE*2])" 392 | ] 393 | }, 394 | { 395 | "cell_type": "code", 396 | "execution_count": 18, 397 | "metadata": { 398 | "collapsed": true 399 | }, 400 | "outputs": [], 401 | "source": [ 402 | "X_p=[]\n", 403 | "y_p=[]" 404 | ] 405 | }, 406 | { 407 | "cell_type": "code", 408 | "execution_count": 19, 409 | "metadata": { 410 | "collapsed": false 411 | }, 412 | "outputs": [ 413 | { 414 | "data": { 415 | "text/plain": [ 416 | "('[', 'moby')" 417 | ] 418 | }, 419 | "execution_count": 19, 420 | "metadata": {}, 421 | "output_type": "execute_result" 422 | } 423 | ], 424 | "source": [ 425 | "train_data[0]" 426 | ] 427 
| }, 428 | { 429 | "cell_type": "code", 430 | "execution_count": 20, 431 | "metadata": { 432 | "collapsed": false 433 | }, 434 | "outputs": [], 435 | "source": [ 436 | "for tr in train_data:\n", 437 | " X_p.append(prepare_word(tr[0],word2index).view(1,-1))\n", 438 | " y_p.append(prepare_word(tr[1],word2index).view(1,-1))" 439 | ] 440 | }, 441 | { 442 | "cell_type": "code", 443 | "execution_count": 21, 444 | "metadata": { 445 | "collapsed": false 446 | }, 447 | "outputs": [], 448 | "source": [ 449 | "train_data = list(zip(X_p,y_p))" 450 | ] 451 | }, 452 | { 453 | "cell_type": "code", 454 | "execution_count": 22, 455 | "metadata": { 456 | "collapsed": false 457 | }, 458 | "outputs": [ 459 | { 460 | "data": { 461 | "text/plain": [ 462 | "7606" 463 | ] 464 | }, 465 | "execution_count": 22, 466 | "metadata": {}, 467 | "output_type": "execute_result" 468 | } 469 | ], 470 | "source": [ 471 | "len(train_data)" 472 | ] 473 | }, 474 | { 475 | "cell_type": "markdown", 476 | "metadata": {}, 477 | "source": [ 478 | "## Modeling" 479 | ] 480 | }, 481 | { 482 | "cell_type": "markdown", 483 | "metadata": {}, 484 | "source": [ 485 | "\n", 486 | "
borrowed image from http://web.stanford.edu/class/cs224n/lectures/cs224n-2017-lecture2.pdf
" 487 | ] 488 | }, 489 | { 490 | "cell_type": "code", 491 | "execution_count": 59, 492 | "metadata": { 493 | "collapsed": true 494 | }, 495 | "outputs": [], 496 | "source": [ 497 | "class Skipgram(nn.Module):\n", 498 | " \n", 499 | " def __init__(self, vocab_size,projection_dim):\n", 500 | " super(Skipgram,self).__init__()\n", 501 | " self.embedding_v = nn.Embedding(vocab_size, projection_dim)\n", 502 | " self.embedding_u = nn.Embedding(vocab_size, projection_dim)\n", 503 | "\n", 504 | " self.embedding_v.weight.data.uniform_(-1, 1) # init\n", 505 | " self.embedding_u.weight.data.uniform_(0, 0) # init\n", 506 | " #self.out = nn.Linear(projection_dim,vocab_size)\n", 507 | " def forward(self, center_words,target_words, outer_words):\n", 508 | " center_embeds = self.embedding_v(center_words) # B x 1 x D\n", 509 | " target_embeds = self.embedding_u(target_words) # B x 1 x D\n", 510 | " outer_embeds = self.embedding_u(outer_words) # B x V x D\n", 511 | " \n", 512 | " scores = target_embeds.bmm(center_embeds.transpose(1,2)).squeeze(2) # Bx1xD * BxDx1 => Bx1\n", 513 | " norm_scores = outer_embeds.bmm(center_embeds.transpose(1,2)).squeeze(2) # BxVxD * BxDx1 => BxV\n", 514 | " \n", 515 | " nll = -torch.mean(torch.log(torch.exp(scores)/torch.sum(torch.exp(norm_scores),1).unsqueeze(1))) # log-softmax\n", 516 | " \n", 517 | " return nll # negative log likelihood\n", 518 | " \n", 519 | " def prediction(self, inputs):\n", 520 | " embeds = self.embedding_v(inputs)\n", 521 | " \n", 522 | " return embeds " 523 | ] 524 | }, 525 | { 526 | "cell_type": "markdown", 527 | "metadata": {}, 528 | "source": [ 529 | "## Train " 530 | ] 531 | }, 532 | { 533 | "cell_type": "code", 534 | "execution_count": 60, 535 | "metadata": { 536 | "collapsed": true 537 | }, 538 | "outputs": [], 539 | "source": [ 540 | "EMBEDDING_SIZE = 30\n", 541 | "BATCH_SIZE = 256\n", 542 | "EPOCH = 100" 543 | ] 544 | }, 545 | { 546 | "cell_type": "code", 547 | "execution_count": 61, 548 | "metadata": { 549 | "collapsed": 
true 550 | }, 551 | "outputs": [], 552 | "source": [ 553 | "losses = []\n", 554 | "model = Skipgram(len(word2index),EMBEDDING_SIZE)\n", 555 | "if USE_CUDA:\n", 556 | " model = model.cuda()\n", 557 | "optimizer = optim.Adam(model.parameters(), lr=0.01)" 558 | ] 559 | }, 560 | { 561 | "cell_type": "code", 562 | "execution_count": 62, 563 | "metadata": { 564 | "collapsed": false 565 | }, 566 | "outputs": [ 567 | { 568 | "name": "stdout", 569 | "output_type": "stream", 570 | "text": [ 571 | "Epoch : 0, mean_loss : 6.20\n", 572 | "Epoch : 10, mean_loss : 4.38\n", 573 | "Epoch : 20, mean_loss : 3.48\n", 574 | "Epoch : 30, mean_loss : 3.31\n", 575 | "Epoch : 40, mean_loss : 3.26\n", 576 | "Epoch : 50, mean_loss : 3.24\n", 577 | "Epoch : 60, mean_loss : 3.22\n", 578 | "Epoch : 70, mean_loss : 3.22\n", 579 | "Epoch : 80, mean_loss : 3.21\n", 580 | "Epoch : 90, mean_loss : 3.20\n" 581 | ] 582 | } 583 | ], 584 | "source": [ 585 | "for epoch in range(EPOCH):\n", 586 | " for i,batch in enumerate(getBatch(BATCH_SIZE,train_data)):\n", 587 | " \n", 588 | " inputs, targets = zip(*batch)\n", 589 | " \n", 590 | " inputs = torch.cat(inputs) # B x 1\n", 591 | " targets = torch.cat(targets) # B x 1\n", 592 | " vocabs = prepare_sequence(list(vocab),word2index).expand(inputs.size(0),len(vocab)) # B x V\n", 593 | " model.zero_grad()\n", 594 | "\n", 595 | " loss = model(inputs,targets,vocabs)\n", 596 | " \n", 597 | " loss.backward()\n", 598 | " optimizer.step()\n", 599 | " \n", 600 | " losses.append(loss.data.tolist()[0])\n", 601 | "\n", 602 | " if epoch % 10==0:\n", 603 | " print(\"Epoch : %d, mean_loss : %.02f\" % (epoch,np.mean(losses)))\n", 604 | " losses=[]" 605 | ] 606 | }, 607 | { 608 | "cell_type": "markdown", 609 | "metadata": {}, 610 | "source": [ 611 | "## Test" 612 | ] 613 | }, 614 | { 615 | "cell_type": "code", 616 | "execution_count": 63, 617 | "metadata": { 618 | "collapsed": true 619 | }, 620 | "outputs": [], 621 | "source": [ 622 | "def word_similarity(target,vocab):\n", 
623 | " if USE_CUDA:\n", 624 | " target_V = model.prediction(prepare_word(target,word2index))\n", 625 | " else:\n", 626 | " target_V = model.prediction(prepare_word(target,word2index))\n", 627 | " similarities=[]\n", 628 | " for i in range(len(vocab)):\n", 629 | " if vocab[i] == target: continue\n", 630 | " \n", 631 | " if USE_CUDA:\n", 632 | " vector = model.prediction(prepare_word(list(vocab)[i],word2index))\n", 633 | " else:\n", 634 | " vector = model.prediction(prepare_word(list(vocab)[i],word2index))\n", 635 | " cosine_sim = F.cosine_similarity(target_V,vector).data.tolist()[0] \n", 636 | " similarities.append([vocab[i],cosine_sim])\n", 637 | " return sorted(similarities, key=lambda x: x[1], reverse=True)[:10] # sort by similarity" 638 | ] 639 | }, 640 | { 641 | "cell_type": "code", 642 | "execution_count": 64, 643 | "metadata": { 644 | "collapsed": false 645 | }, 646 | "outputs": [ 647 | { 648 | "data": { 649 | "text/plain": [ 650 | "'least'" 651 | ] 652 | }, 653 | "execution_count": 64, 654 | "metadata": {}, 655 | "output_type": "execute_result" 656 | } 657 | ], 658 | "source": [ 659 | "test = random.choice(list(vocab))\n", 660 | "test" 661 | ] 662 | }, 663 | { 664 | "cell_type": "code", 665 | "execution_count": 65, 666 | "metadata": { 667 | "collapsed": false 668 | }, 669 | "outputs": [ 670 | { 671 | "data": { 672 | "text/plain": [ 673 | "[['at', 0.8147411346435547],\n", 674 | " ['every', 0.7143548130989075],\n", 675 | " ['case', 0.6975079774856567],\n", 676 | " ['secure', 0.6121522188186646],\n", 677 | " ['heart', 0.5974172949790955],\n", 678 | " ['including', 0.5867112278938293],\n", 679 | " ['please', 0.5557640194892883],\n", 680 | " ['has', 0.5536234974861145],\n", 681 | " ['while', 0.5366998314857483],\n", 682 | " ['you', 0.509368896484375]]" 683 | ] 684 | }, 685 | "execution_count": 65, 686 | "metadata": {}, 687 | "output_type": "execute_result" 688 | } 689 | ], 690 | "source": [ 691 | "word_similarity(test,vocab)" 692 | ] 693 | }, 694 | { 695 | 
"cell_type": "code", 696 | "execution_count": null, 697 | "metadata": { 698 | "collapsed": true 699 | }, 700 | "outputs": [], 701 | "source": [] 702 | } 703 | ], 704 | "metadata": { 705 | "kernelspec": { 706 | "display_name": "Python 3", 707 | "language": "python", 708 | "name": "python3" 709 | }, 710 | "language_info": { 711 | "codemirror_mode": { 712 | "name": "ipython", 713 | "version": 3 714 | }, 715 | "file_extension": ".py", 716 | "mimetype": "text/x-python", 717 | "name": "python", 718 | "nbconvert_exporter": "python", 719 | "pygments_lexer": "ipython3", 720 | "version": "3.5.2" 721 | } 722 | }, 723 | "nbformat": 4, 724 | "nbformat_minor": 2 725 | } 726 | -------------------------------------------------------------------------------- /notebooks/02.Skip-gram-Negative-Sampling.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# 2. Skip-gram with negative sampling" 8 | ] 9 | }, 10 | { 11 | "cell_type": "markdown", 12 | "metadata": {}, 13 | "source": [ 14 | "I recommend you take a look at these material first." 
15 | ] 16 | }, 17 | { 18 | "cell_type": "markdown", 19 | "metadata": {}, 20 | "source": [ 21 | "* http://web.stanford.edu/class/cs224n/lectures/cs224n-2017-lecture3.pdf\n", 22 | "* http://papers.nips.cc/paper/5021-distributed-representations-of-words-and-phrases-and-their-compositionality.pdf" 23 | ] 24 | }, 25 | { 26 | "cell_type": "code", 27 | "execution_count": 1, 28 | "metadata": { 29 | "collapsed": true 30 | }, 31 | "outputs": [], 32 | "source": [ 33 | "import torch\n", 34 | "import torch.nn as nn\n", 35 | "from torch.autograd import Variable\n", 36 | "import torch.optim as optim\n", 37 | "import torch.nn.functional as F\n", 38 | "import nltk\n", 39 | "import random\n", 40 | "import numpy as np\n", 41 | "from collections import Counter\n", 42 | "flatten = lambda l: [item for sublist in l for item in sublist]" 43 | ] 44 | }, 45 | { 46 | "cell_type": "code", 47 | "execution_count": 2, 48 | "metadata": { 49 | "collapsed": false 50 | }, 51 | "outputs": [ 52 | { 53 | "name": "stdout", 54 | "output_type": "stream", 55 | "text": [ 56 | "0.2.0+751198f\n", 57 | "3.2.4\n" 58 | ] 59 | } 60 | ], 61 | "source": [ 62 | "print(torch.__version__)\n", 63 | "print(nltk.__version__)" 64 | ] 65 | }, 66 | { 67 | "cell_type": "code", 68 | "execution_count": 3, 69 | "metadata": { 70 | "collapsed": true 71 | }, 72 | "outputs": [], 73 | "source": [ 74 | "USE_CUDA = torch.cuda.is_available()\n", 75 | "\n", 76 | "FloatTensor = torch.cuda.FloatTensor if USE_CUDA else torch.FloatTensor\n", 77 | "LongTensor = torch.cuda.LongTensor if USE_CUDA else torch.LongTensor\n", 78 | "ByteTensor = torch.cuda.ByteTensor if USE_CUDA else torch.ByteTensor" 79 | ] 80 | }, 81 | { 82 | "cell_type": "code", 83 | "execution_count": 4, 84 | "metadata": { 85 | "collapsed": true 86 | }, 87 | "outputs": [], 88 | "source": [ 89 | "def getBatch(batch_size,train_data):\n", 90 | " random.shuffle(train_data)\n", 91 | " sindex=0\n", 92 | " eindex=batch_size\n", 93 | " while eindex < len(train_data):\n", 94 | " batch = 
train_data[sindex:eindex]\n", 95 | " temp = eindex\n", 96 | " eindex = eindex+batch_size\n", 97 | " sindex = temp\n", 98 | " yield batch\n", 99 | " \n", 100 | " if eindex >= len(train_data):\n", 101 | " batch = train_data[sindex:]\n", 102 | " yield batch" 103 | ] 104 | }, 105 | { 106 | "cell_type": "code", 107 | "execution_count": 5, 108 | "metadata": { 109 | "collapsed": true 110 | }, 111 | "outputs": [], 112 | "source": [ 113 | "def prepare_sequence(seq, word2index):\n", 114 | " idxs = list(map(lambda w: word2index[w] if w in word2index.keys() else word2index[\"\"], seq))\n", 115 | " return Variable(LongTensor(idxs))\n", 116 | "\n", 117 | "def prepare_word(word,word2index):\n", 118 | " return Variable(LongTensor([word2index[word]]) if word in word2index.keys() else LongTensor([word2index[\"\"]]))" 119 | ] 120 | }, 121 | { 122 | "cell_type": "markdown", 123 | "metadata": {}, 124 | "source": [ 125 | "## Data load and Preprocessing " 126 | ] 127 | }, 128 | { 129 | "cell_type": "code", 130 | "execution_count": 6, 131 | "metadata": { 132 | "collapsed": true 133 | }, 134 | "outputs": [], 135 | "source": [ 136 | "corpus = list(nltk.corpus.gutenberg.sents('melville-moby_dick.txt'))[:500]\n", 137 | "corpus = [[word.lower() for word in sent] for sent in corpus]" 138 | ] 139 | }, 140 | { 141 | "cell_type": "markdown", 142 | "metadata": {}, 143 | "source": [ 144 | "### Exclude sparse words " 145 | ] 146 | }, 147 | { 148 | "cell_type": "code", 149 | "execution_count": 7, 150 | "metadata": { 151 | "collapsed": true 152 | }, 153 | "outputs": [], 154 | "source": [ 155 | "word_count = Counter(flatten(corpus))" 156 | ] 157 | }, 158 | { 159 | "cell_type": "code", 160 | "execution_count": 8, 161 | "metadata": { 162 | "collapsed": true 163 | }, 164 | "outputs": [], 165 | "source": [ 166 | "MIN_COUNT=3\n", 167 | "exclude=[]" 168 | ] 169 | }, 170 | { 171 | "cell_type": "code", 172 | "execution_count": 9, 173 | "metadata": { 174 | "collapsed": true 175 | }, 176 | "outputs": [], 177 | 
"source": [ 178 | "for w,c in word_count.items():\n", 179 | " if c']*WINDOW_SIZE+c+['']*WINDOW_SIZE,WINDOW_SIZE*2+1)) for c in corpus])\n", 227 | "\n", 228 | "train_data = []\n", 229 | "\n", 230 | "for window in windows:\n", 231 | " for i in range(WINDOW_SIZE*2+1):\n", 232 | " if window[i] in exclude or window[WINDOW_SIZE] in exclude: continue # min_count\n", 233 | " if i==WINDOW_SIZE or window[i]=='': continue\n", 234 | " train_data.append((window[WINDOW_SIZE],window[i]))\n", 235 | "\n", 236 | "X_p=[]\n", 237 | "y_p=[]\n", 238 | "\n", 239 | "for tr in train_data:\n", 240 | " X_p.append(prepare_word(tr[0],word2index).view(1,-1))\n", 241 | " y_p.append(prepare_word(tr[1],word2index).view(1,-1))\n", 242 | " \n", 243 | "train_data = list(zip(X_p,y_p))" 244 | ] 245 | }, 246 | { 247 | "cell_type": "code", 248 | "execution_count": 13, 249 | "metadata": { 250 | "collapsed": false 251 | }, 252 | "outputs": [ 253 | { 254 | "data": { 255 | "text/plain": [ 256 | "50242" 257 | ] 258 | }, 259 | "execution_count": 13, 260 | "metadata": {}, 261 | "output_type": "execute_result" 262 | } 263 | ], 264 | "source": [ 265 | "len(train_data)" 266 | ] 267 | }, 268 | { 269 | "cell_type": "markdown", 270 | "metadata": {}, 271 | "source": [ 272 | "### Build Unigram Distribution**0.75 " 273 | ] 274 | }, 275 | { 276 | "cell_type": "markdown", 277 | "metadata": {}, 278 | "source": [ 279 | "$$P(w)=U(w)^{3/4}/Z$$" 280 | ] 281 | }, 282 | { 283 | "cell_type": "code", 284 | "execution_count": 14, 285 | "metadata": { 286 | "collapsed": true 287 | }, 288 | "outputs": [], 289 | "source": [ 290 | "Z = 0.001" 291 | ] 292 | }, 293 | { 294 | "cell_type": "code", 295 | "execution_count": 15, 296 | "metadata": { 297 | "collapsed": true 298 | }, 299 | "outputs": [], 300 | "source": [ 301 | "word_count = Counter(flatten(corpus))\n", 302 | "num_total_words = sum([c for w,c in word_count.items() if w not in exclude])" 303 | ] 304 | }, 305 | { 306 | "cell_type": "code", 307 | "execution_count": 16, 308 | 
"metadata": { 309 | "collapsed": true 310 | }, 311 | "outputs": [], 312 | "source": [ 313 | "unigram_table=[]\n", 314 | "\n", 315 | "for vo in vocab:\n", 316 | " unigram_table.extend([vo]*int(((word_count[vo]/num_total_words)**0.75)/Z))" 317 | ] 318 | }, 319 | { 320 | "cell_type": "code", 321 | "execution_count": 17, 322 | "metadata": { 323 | "collapsed": false 324 | }, 325 | "outputs": [ 326 | { 327 | "name": "stdout", 328 | "output_type": "stream", 329 | "text": [ 330 | "478 3500\n" 331 | ] 332 | } 333 | ], 334 | "source": [ 335 | "print(len(vocab),len(unigram_table))" 336 | ] 337 | }, 338 | { 339 | "cell_type": "markdown", 340 | "metadata": {}, 341 | "source": [ 342 | "### Negative Sampling " 343 | ] 344 | }, 345 | { 346 | "cell_type": "code", 347 | "execution_count": 18, 348 | "metadata": { 349 | "collapsed": true 350 | }, 351 | "outputs": [], 352 | "source": [ 353 | "def negative_sampling(targets,unigram_table,k):\n", 354 | " batch_size = targets.size(0)\n", 355 | " neg_samples=[]\n", 356 | " for i in range(batch_size):\n", 357 | " nsample=[]\n", 358 | " target_index = targets[i].data.cpu().tolist()[0] if USE_CUDA else targets[i].data.tolist()[0]\n", 359 | " while len(nsample)\n", 380 | "
borrowed image from http://web.stanford.edu/class/cs224n/lectures/cs224n-2017-lecture3.pdf
class SkipgramNegSampling(nn.Module):
    """Skip-gram word-embedding model trained with negative sampling.

    Two embedding tables of shape (vocab_size, projection_dim) are kept:
    ``embedding_v`` for center words and ``embedding_u`` for context
    ("out") words, the latter also being used for the negative samples.
    """

    def __init__(self, vocab_size, projection_dim):
        super(SkipgramNegSampling, self).__init__()
        self.embedding_v = nn.Embedding(vocab_size, projection_dim)  # center embedding
        self.embedding_u = nn.Embedding(vocab_size, projection_dim)  # out embedding
        self.logsigmoid = nn.LogSigmoid()

        initrange = (2.0 / (vocab_size + projection_dim)) ** 0.5  # Xavier init
        self.embedding_v.weight.data.uniform_(-initrange, initrange)
        # uniform_(-0.0, 0.0) leaves every out-embedding weight at exactly zero.
        self.embedding_u.weight.data.uniform_(-0.0, 0.0)

    def forward(self, center_words, target_words, negative_words):
        """Return the batch-mean negative-sampling loss as a scalar.

        Args:
            center_words: index tensor, B x 1.
            target_words: index tensor, B x 1 (observed context words).
            negative_words: index tensor, B x K (sampled negatives).
        """
        center_embeds = self.embedding_v(center_words)  # B x 1 x D
        target_embeds = self.embedding_u(target_words)  # B x 1 x D
        neg_embeds = -self.embedding_u(negative_words)  # B x K x D

        center_t = center_embeds.transpose(1, 2)  # B x D x 1

        positive_score = target_embeds.bmm(center_t).squeeze(2)  # B x 1
        # The batch size must come from the argument: the previous code read the
        # notebook-level global `negs`, which fails (or silently uses a stale
        # batch) whenever the model is called outside the training cell.
        negative_score = torch.sum(
            neg_embeds.bmm(center_t).squeeze(2), 1
        ).view(negative_words.size(0), -1)  # B x K -> B x 1

        # NOTE(review): the K negative scores are summed *before* the
        # log-sigmoid. The objective in Mikolov et al. (2013) sums one
        # log-sigmoid term per negative sample; confirm which is intended.
        loss = self.logsigmoid(positive_score) + self.logsigmoid(negative_score)

        return -torch.mean(loss)

    def prediction(self, inputs):
        """Return the center-word embeddings for the given word indices."""
        return self.embedding_v(inputs)
439 | "EPOCH = 100\n", 440 | "NEG=10 # Num of Negative Sampling" 441 | ] 442 | }, 443 | { 444 | "cell_type": "code", 445 | "execution_count": 69, 446 | "metadata": { 447 | "collapsed": true 448 | }, 449 | "outputs": [], 450 | "source": [ 451 | "losses = []\n", 452 | "model = SkipgramNegSampling(len(word2index),EMBEDDING_SIZE)\n", 453 | "if USE_CUDA:\n", 454 | " model = model.cuda()\n", 455 | "optimizer = optim.Adam(model.parameters(), lr=0.001)" 456 | ] 457 | }, 458 | { 459 | "cell_type": "code", 460 | "execution_count": 70, 461 | "metadata": { 462 | "collapsed": false 463 | }, 464 | "outputs": [ 465 | { 466 | "name": "stdout", 467 | "output_type": "stream", 468 | "text": [ 469 | "Epoch : 0, mean_loss : 1.06\n", 470 | "Epoch : 10, mean_loss : 0.86\n", 471 | "Epoch : 20, mean_loss : 0.79\n", 472 | "Epoch : 30, mean_loss : 0.74\n", 473 | "Epoch : 40, mean_loss : 0.71\n", 474 | "Epoch : 50, mean_loss : 0.69\n", 475 | "Epoch : 60, mean_loss : 0.67\n", 476 | "Epoch : 70, mean_loss : 0.65\n", 477 | "Epoch : 80, mean_loss : 0.64\n", 478 | "Epoch : 90, mean_loss : 0.63\n" 479 | ] 480 | } 481 | ], 482 | "source": [ 483 | "for epoch in range(EPOCH):\n", 484 | " for i,batch in enumerate(getBatch(BATCH_SIZE,train_data)):\n", 485 | " \n", 486 | " inputs, targets = zip(*batch)\n", 487 | " \n", 488 | " inputs = torch.cat(inputs) # B x 1\n", 489 | " targets = torch.cat(targets) # B x 1\n", 490 | " negs = negative_sampling(targets,unigram_table,NEG)\n", 491 | " model.zero_grad()\n", 492 | "\n", 493 | " loss = model(inputs,targets,negs)\n", 494 | " \n", 495 | " loss.backward()\n", 496 | " optimizer.step()\n", 497 | " \n", 498 | " losses.append(loss.data.tolist()[0])\n", 499 | " if epoch % 10==0:\n", 500 | " print(\"Epoch : %d, mean_loss : %.02f\" % (epoch,np.mean(losses)))\n", 501 | " losses=[]" 502 | ] 503 | }, 504 | { 505 | "cell_type": "markdown", 506 | "metadata": {}, 507 | "source": [ 508 | "## Test " 509 | ] 510 | }, 511 | { 512 | "cell_type": "code", 513 | "execution_count": 
def word_similarity(target, vocab):
    """Return the 10 vocabulary words most cosine-similar to ``target``.

    Relies on the notebook-level ``model``, ``prepare_word`` and
    ``word2index``. Each word's vector is ``model.prediction`` of its index.

    Args:
        target: the query word; entries of ``vocab`` equal to it are skipped.
        vocab: sequence of candidate words.

    Returns:
        A list of at most 10 ``[word, cosine_similarity]`` pairs, sorted by
        similarity in descending order.
    """
    # The original branched on USE_CUDA here, but both branches were identical.
    target_V = model.prediction(prepare_word(target, word2index))

    similarities = []
    # Iterate the words directly instead of rebuilding list(vocab) per index.
    for word in vocab:
        if word == target:
            continue

        vector = model.prediction(prepare_word(word, word2index))
        cosine_sim = F.cosine_similarity(target_V, vector).data.tolist()[0]
        similarities.append([word, cosine_sim])

    return sorted(similarities, key=lambda x: x[1], reverse=True)[:10]
{}, 585 | "output_type": "execute_result" 586 | } 587 | ], 588 | "source": [ 589 | "word_similarity(test,vocab)" 590 | ] 591 | }, 592 | { 593 | "cell_type": "code", 594 | "execution_count": null, 595 | "metadata": { 596 | "collapsed": true 597 | }, 598 | "outputs": [], 599 | "source": [] 600 | } 601 | ], 602 | "metadata": { 603 | "kernelspec": { 604 | "display_name": "Python 3", 605 | "language": "python", 606 | "name": "python3" 607 | }, 608 | "language_info": { 609 | "codemirror_mode": { 610 | "name": "ipython", 611 | "version": 3 612 | }, 613 | "file_extension": ".py", 614 | "mimetype": "text/x-python", 615 | "name": "python", 616 | "nbconvert_exporter": "python", 617 | "pygments_lexer": "ipython3", 618 | "version": "3.5.2" 619 | } 620 | }, 621 | "nbformat": 4, 622 | "nbformat_minor": 2 623 | } 624 | -------------------------------------------------------------------------------- /notebooks/03.GloVe.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# 3. GloVe: Global Vectors for Word Representation" 8 | ] 9 | }, 10 | { 11 | "cell_type": "markdown", 12 | "metadata": {}, 13 | "source": [ 14 | "I recommend you take a look at these materials first." 
15 | ] 16 | }, 17 | { 18 | "cell_type": "markdown", 19 | "metadata": {}, 20 | "source": [ 21 | "* http://web.stanford.edu/class/cs224n/lectures/cs224n-2017-lecture3.pdf\n", 22 | "* https://nlp.stanford.edu/pubs/glove.pdf" 23 | ] 24 | }, 25 | { 26 | "cell_type": "code", 27 | "execution_count": 2, 28 | "metadata": { 29 | "collapsed": true 30 | }, 31 | "outputs": [], 32 | "source": [ 33 | "import torch\n", 34 | "import torch.nn as nn\n", 35 | "from torch.autograd import Variable\n", 36 | "import torch.optim as optim\n", 37 | "import torch.nn.functional as F\n", 38 | "import nltk\n", 39 | "import random\n", 40 | "import numpy as np\n", 41 | "from collections import Counter\n", 42 | "flatten = lambda l: [item for sublist in l for item in sublist]" 43 | ] 44 | }, 45 | { 46 | "cell_type": "code", 47 | "execution_count": 2, 48 | "metadata": { 49 | "collapsed": false 50 | }, 51 | "outputs": [ 52 | { 53 | "name": "stdout", 54 | "output_type": "stream", 55 | "text": [ 56 | "0.2.0+751198f\n", 57 | "3.2.4\n" 58 | ] 59 | } 60 | ], 61 | "source": [ 62 | "print(torch.__version__)\n", 63 | "print(nltk.__version__)" 64 | ] 65 | }, 66 | { 67 | "cell_type": "code", 68 | "execution_count": 3, 69 | "metadata": { 70 | "collapsed": true 71 | }, 72 | "outputs": [], 73 | "source": [ 74 | "USE_CUDA = torch.cuda.is_available()\n", 75 | "\n", 76 | "FloatTensor = torch.cuda.FloatTensor if USE_CUDA else torch.FloatTensor\n", 77 | "LongTensor = torch.cuda.LongTensor if USE_CUDA else torch.LongTensor\n", 78 | "ByteTensor = torch.cuda.ByteTensor if USE_CUDA else torch.ByteTensor" 79 | ] 80 | }, 81 | { 82 | "cell_type": "code", 83 | "execution_count": 4, 84 | "metadata": { 85 | "collapsed": true 86 | }, 87 | "outputs": [], 88 | "source": [ 89 | "def getBatch(batch_size,train_data):\n", 90 | " random.shuffle(train_data)\n", 91 | " sindex=0\n", 92 | " eindex=batch_size\n", 93 | " while eindex < len(train_data):\n", 94 | " batch = train_data[sindex:eindex]\n", 95 | " temp = eindex\n", 96 | " eindex = 
eindex+batch_size\n", 97 | " sindex = temp\n", 98 | " yield batch\n", 99 | " \n", 100 | " if eindex >= len(train_data):\n", 101 | " batch = train_data[sindex:]\n", 102 | " yield batch" 103 | ] 104 | }, 105 | { 106 | "cell_type": "code", 107 | "execution_count": 5, 108 | "metadata": { 109 | "collapsed": true 110 | }, 111 | "outputs": [], 112 | "source": [ 113 | "def prepare_sequence(seq, word2index):\n", 114 | " idxs = list(map(lambda w: word2index[w] if w in word2index.keys() else word2index[\"\"], seq))\n", 115 | " return Variable(LongTensor(idxs))\n", 116 | "\n", 117 | "def prepare_word(word,word2index):\n", 118 | " return Variable(LongTensor([word2index[word]]) if word in word2index.keys() else LongTensor([word2index[\"\"]]))" 119 | ] 120 | }, 121 | { 122 | "cell_type": "markdown", 123 | "metadata": {}, 124 | "source": [ 125 | "## Data load and Preprocessing " 126 | ] 127 | }, 128 | { 129 | "cell_type": "code", 130 | "execution_count": 6, 131 | "metadata": { 132 | "collapsed": true 133 | }, 134 | "outputs": [], 135 | "source": [ 136 | "corpus = list(nltk.corpus.gutenberg.sents('melville-moby_dick.txt'))[:500]\n", 137 | "corpus = [[word.lower() for word in sent] for sent in corpus]" 138 | ] 139 | }, 140 | { 141 | "cell_type": "markdown", 142 | "metadata": {}, 143 | "source": [ 144 | "### Build vocab" 145 | ] 146 | }, 147 | { 148 | "cell_type": "code", 149 | "execution_count": 7, 150 | "metadata": { 151 | "collapsed": true 152 | }, 153 | "outputs": [], 154 | "source": [ 155 | "vocab = list(set(flatten(corpus)))" 156 | ] 157 | }, 158 | { 159 | "cell_type": "code", 160 | "execution_count": 8, 161 | "metadata": { 162 | "collapsed": true 163 | }, 164 | "outputs": [], 165 | "source": [ 166 | "word2index={}\n", 167 | "for vo in vocab:\n", 168 | " if vo not in word2index.keys():\n", 169 | " word2index[vo]=len(word2index)\n", 170 | " \n", 171 | "index2word={v:k for k,v in word2index.items()}" 172 | ] 173 | }, 174 | { 175 | "cell_type": "code", 176 | "execution_count": 9, 
177 | "metadata": { 178 | "collapsed": true 179 | }, 180 | "outputs": [], 181 | "source": [ 182 | "WINDOW_SIZE = 5\n", 183 | "windows = flatten([list(nltk.ngrams(['']*WINDOW_SIZE+c+['']*WINDOW_SIZE,WINDOW_SIZE*2+1)) for c in corpus])\n", 184 | "\n", 185 | "window_data = []\n", 186 | "\n", 187 | "for window in windows:\n", 188 | " for i in range(WINDOW_SIZE*2+1):\n", 189 | " if i==WINDOW_SIZE or window[i]=='': continue\n", 190 | " window_data.append((window[WINDOW_SIZE],window[i]))\n" 191 | ] 192 | }, 193 | { 194 | "cell_type": "markdown", 195 | "metadata": {}, 196 | "source": [ 197 | "### Weighting Function " 198 | ] 199 | }, 200 | { 201 | "cell_type": "markdown", 202 | "metadata": {}, 203 | "source": [ 204 | "\n", 205 | "
borrowed image from https://nlp.stanford.edu/pubs/glove.pdf
" 206 | ] 207 | }, 208 | { 209 | "cell_type": "code", 210 | "execution_count": 10, 211 | "metadata": { 212 | "collapsed": true 213 | }, 214 | "outputs": [], 215 | "source": [ 216 | "def weighting(w_i,w_j):\n", 217 | " try:\n", 218 | " x_ij = X_ik[(w_i,w_j)]\n", 219 | " except:\n", 220 | " x_ij = 1\n", 221 | " \n", 222 | " x_max = 100 #100 # fixed in paper\n", 223 | " alpha = 0.75\n", 224 | " \n", 225 | " if x_ij < x_max:\n", 226 | " result = (x_ij/x_max)**alpha\n", 227 | " else:\n", 228 | " result = 1\n", 229 | " \n", 230 | " return result" 231 | ] 232 | }, 233 | { 234 | "cell_type": "markdown", 235 | "metadata": {}, 236 | "source": [ 237 | "### Build Co-occurrence Matrix X" 238 | ] 239 | }, 240 | { 241 | "cell_type": "markdown", 242 | "metadata": {}, 243 | "source": [ 244 | "Because of model complexity, it is important to determine whether a tighter bound can be placed on the number of nonzero elements of X." 245 | ] 246 | }, 247 | { 248 | "cell_type": "code", 249 | "execution_count": 11, 250 | "metadata": { 251 | "collapsed": true 252 | }, 253 | "outputs": [], 254 | "source": [ 255 | "X_i = Counter(flatten(corpus)) # X_i" 256 | ] 257 | }, 258 | { 259 | "cell_type": "code", 260 | "execution_count": 12, 261 | "metadata": { 262 | "collapsed": true 263 | }, 264 | "outputs": [], 265 | "source": [ 266 | "X_ik_window_5 = Counter(window_data) # Co-occurrence in window size 5" 267 | ] 268 | }, 269 | { 270 | "cell_type": "code", 271 | "execution_count": 13, 272 | "metadata": { 273 | "collapsed": true 274 | }, 275 | "outputs": [], 276 | "source": [ 277 | "X_ik={}\n", 278 | "weighting_dic={}" 279 | ] 280 | }, 281 | { 282 | "cell_type": "code", 283 | "execution_count": 14, 284 | "metadata": { 285 | "collapsed": true 286 | }, 287 | "outputs": [], 288 | "source": [ 289 | "from itertools import combinations_with_replacement" 290 | ] 291 | }, 292 | { 293 | "cell_type": "code", 294 | "execution_count": 15, 295 | "metadata": { 296 | "collapsed": true 297 | }, 298 | "outputs": [], 299 
| "source": [ 300 | "for bigram in combinations_with_replacement(vocab, 2):\n", 301 | " if bigram in X_ik_window_5.keys(): # nonzero elements\n", 302 | " co_occer = X_ik_window_5[bigram]\n", 303 | " X_ik[bigram]=co_occer+1 # log(Xik) -> log(Xik+1) to prevent divergence\n", 304 | " X_ik[(bigram[1],bigram[0])]=co_occer+1\n", 305 | " else:\n", 306 | " pass\n", 307 | " \n", 308 | " weighting_dic[bigram] = weighting(bigram[0],bigram[1])\n", 309 | " weighting_dic[(bigram[1],bigram[0])] = weighting(bigram[1],bigram[0])" 310 | ] 311 | }, 312 | { 313 | "cell_type": "code", 314 | "execution_count": 16, 315 | "metadata": { 316 | "collapsed": false 317 | }, 318 | "outputs": [ 319 | { 320 | "name": "stdout", 321 | "output_type": "stream", 322 | "text": [ 323 | "(',', 'was')\n", 324 | "True\n" 325 | ] 326 | } 327 | ], 328 | "source": [ 329 | "test = random.choice(window_data)\n", 330 | "print(test)\n", 331 | "try:\n", 332 | " print(X_ik[(test[0],test[1])]==X_ik[(test[1],test[0])])\n", 333 | "except:\n", 334 | " 1" 335 | ] 336 | }, 337 | { 338 | "cell_type": "markdown", 339 | "metadata": {}, 340 | "source": [ 341 | "### Prepare train data" 342 | ] 343 | }, 344 | { 345 | "cell_type": "code", 346 | "execution_count": 17, 347 | "metadata": { 348 | "collapsed": false 349 | }, 350 | "outputs": [ 351 | { 352 | "name": "stdout", 353 | "output_type": "stream", 354 | "text": [ 355 | "(Variable containing:\n", 356 | " 703\n", 357 | "[torch.cuda.LongTensor of size 1x1 (GPU 0)]\n", 358 | ", Variable containing:\n", 359 | " 23\n", 360 | "[torch.cuda.LongTensor of size 1x1 (GPU 0)]\n", 361 | ", Variable containing:\n", 362 | " 0.6931\n", 363 | "[torch.cuda.FloatTensor of size 1x1 (GPU 0)]\n", 364 | ", Variable containing:\n", 365 | "1.00000e-02 *\n", 366 | " 5.3183\n", 367 | "[torch.cuda.FloatTensor of size 1x1 (GPU 0)]\n", 368 | ")\n" 369 | ] 370 | } 371 | ], 372 | "source": [ 373 | "u_p=[] # center vec\n", 374 | "v_p=[] # context vec\n", 375 | "co_p=[] # log(x_ij)\n", 376 | "weight_p=[] # 
f(x_ij)\n", 377 | "\n", 378 | "for pair in window_data: \n", 379 | " u_p.append(prepare_word(pair[0],word2index).view(1,-1))\n", 380 | " v_p.append(prepare_word(pair[1],word2index).view(1,-1))\n", 381 | " \n", 382 | " try:\n", 383 | " cooc = X_ik[pair]\n", 384 | " except:\n", 385 | " cooc = 1\n", 386 | "\n", 387 | " co_p.append(torch.log(Variable(FloatTensor([cooc]))).view(1,-1))\n", 388 | " weight_p.append(Variable(FloatTensor([weighting_dic[pair]])).view(1,-1))\n", 389 | " \n", 390 | "train_data = list(zip(u_p,v_p,co_p,weight_p))\n", 391 | "del u_p\n", 392 | "del v_p\n", 393 | "del co_p\n", 394 | "del weight_p\n", 395 | "print(train_data[0]) # tuple (center vec i, context vec j log(x_ij), weight f(w_ij))" 396 | ] 397 | }, 398 | { 399 | "cell_type": "markdown", 400 | "metadata": {}, 401 | "source": [ 402 | "## Modeling " 403 | ] 404 | }, 405 | { 406 | "cell_type": "markdown", 407 | "metadata": {}, 408 | "source": [ 409 | "\n", 410 | "
borrowed image from https://nlp.stanford.edu/pubs/glove.pdf
class GloVe(nn.Module):
    """GloVe model: center/context embedding tables plus per-word biases.

    ``forward`` returns the summed weighted squared error between
    ``u_j . v_i + b_i + b_j`` and the supplied co-occurrence targets.
    """

    def __init__(self, vocab_size, projection_dim):
        super(GloVe, self).__init__()
        self.embedding_v = nn.Embedding(vocab_size, projection_dim)  # center embedding
        self.embedding_u = nn.Embedding(vocab_size, projection_dim)  # out embedding

        self.v_bias = nn.Embedding(vocab_size, 1)
        self.u_bias = nn.Embedding(vocab_size, 1)

        # Xavier-style bound shared by all four tables; the tables are
        # re-initialised in the same order as they were created.
        bound = (2.0 / (vocab_size + projection_dim)) ** 0.5
        for table in (self.embedding_v, self.embedding_u, self.v_bias, self.u_bias):
            table.weight.data.uniform_(-bound, bound)

    def forward(self, center_words, target_words, coocs, weights):
        """Return the scalar loss summed over the batch.

        ``center_words`` / ``target_words`` are B x 1 index tensors; ``coocs``
        and ``weights`` are combined element-wise with the B x 1 prediction.
        """
        v = self.embedding_v(center_words)  # B x 1 x D
        u = self.embedding_u(target_words)  # B x 1 x D

        b_center = self.v_bias(center_words).squeeze(1)
        b_target = self.u_bias(target_words).squeeze(1)

        dot = u.bmm(v.transpose(1, 2)).squeeze(2)  # B x 1
        residual = dot + b_center + b_target - coocs

        return torch.sum(weights * torch.pow(residual, 2))

    def prediction(self, inputs):
        """Return the final embedding: center vector plus context vector."""
        return self.embedding_v(inputs) + self.embedding_u(inputs)  # B x 1 x D
| ] 463 | }, 464 | { 465 | "cell_type": "code", 466 | "execution_count": 22, 467 | "metadata": { 468 | "collapsed": true 469 | }, 470 | "outputs": [], 471 | "source": [ 472 | "EMBEDDING_SIZE = 50\n", 473 | "BATCH_SIZE = 256\n", 474 | "EPOCH = 50" 475 | ] 476 | }, 477 | { 478 | "cell_type": "code", 479 | "execution_count": 23, 480 | "metadata": { 481 | "collapsed": false 482 | }, 483 | "outputs": [], 484 | "source": [ 485 | "losses = []\n", 486 | "model = GloVe(len(word2index),EMBEDDING_SIZE)\n", 487 | "if USE_CUDA:\n", 488 | " model = model.cuda()\n", 489 | "optimizer = optim.Adam(model.parameters(), lr=0.001)" 490 | ] 491 | }, 492 | { 493 | "cell_type": "code", 494 | "execution_count": 24, 495 | "metadata": { 496 | "collapsed": false 497 | }, 498 | "outputs": [ 499 | { 500 | "name": "stdout", 501 | "output_type": "stream", 502 | "text": [ 503 | "Epoch : 0, mean_loss : 236.10\n", 504 | "Epoch : 10, mean_loss : 2.27\n", 505 | "Epoch : 20, mean_loss : 0.53\n", 506 | "Epoch : 30, mean_loss : 0.12\n", 507 | "Epoch : 40, mean_loss : 0.04\n" 508 | ] 509 | } 510 | ], 511 | "source": [ 512 | "for epoch in range(EPOCH):\n", 513 | " for i,batch in enumerate(getBatch(BATCH_SIZE,train_data)):\n", 514 | " \n", 515 | " inputs, targets, coocs, weights = zip(*batch)\n", 516 | " \n", 517 | " inputs = torch.cat(inputs) # B x 1\n", 518 | " targets = torch.cat(targets) # B x 1\n", 519 | " coocs = torch.cat(coocs)\n", 520 | " weights = torch.cat(weights)\n", 521 | " model.zero_grad()\n", 522 | "\n", 523 | " loss = model(inputs,targets,coocs,weights)\n", 524 | " \n", 525 | " loss.backward()\n", 526 | " optimizer.step()\n", 527 | " \n", 528 | " losses.append(loss.data.tolist()[0])\n", 529 | " if epoch % 10==0:\n", 530 | " print(\"Epoch : %d, mean_loss : %.02f\" % (epoch,np.mean(losses)))\n", 531 | " losses=[]" 532 | ] 533 | }, 534 | { 535 | "cell_type": "markdown", 536 | "metadata": {}, 537 | "source": [ 538 | "## Test " 539 | ] 540 | }, 541 | { 542 | "cell_type": "code", 543 | 
def word_similarity(target, vocab):
    """Return the 10 vocabulary words most cosine-similar to ``target``.

    Relies on the notebook-level ``model``, ``prepare_word`` and
    ``word2index``. Each word's vector is ``model.prediction`` of its index.

    Args:
        target: the query word; entries of ``vocab`` equal to it are skipped.
        vocab: sequence of candidate words.

    Returns:
        A list of at most 10 ``[word, cosine_similarity]`` pairs, sorted by
        similarity in descending order.
    """
    # The original branched on USE_CUDA here, but both branches were identical.
    target_V = model.prediction(prepare_word(target, word2index))

    similarities = []
    # Iterate the words directly instead of rebuilding list(vocab) per index.
    for word in vocab:
        if word == target:
            continue

        vector = model.prediction(prepare_word(word, word2index))
        cosine_sim = F.cosine_similarity(target_V, vector).data.tolist()[0]
        similarities.append([word, cosine_sim])

    return sorted(similarities, key=lambda x: x[1], reverse=True)[:10]
614 | "metadata": {}, 615 | "output_type": "execute_result" 616 | } 617 | ], 618 | "source": [ 619 | "word_similarity(test,vocab)" 620 | ] 621 | }, 622 | { 623 | "cell_type": "markdown", 624 | "metadata": { 625 | "collapsed": true 626 | }, 627 | "source": [ 628 | "## TODO" 629 | ] 630 | }, 631 | { 632 | "cell_type": "markdown", 633 | "metadata": {}, 634 | "source": [ 635 | "* Use a sparse matrix to build the co-occurrence matrix for memory efficiency" 636 | ] 637 | }, 638 | { 639 | "cell_type": "markdown", 640 | "metadata": {}, 641 | "source": [ 642 | "## Suggested Readings" 643 | ] 644 | }, 645 | { 646 | "cell_type": "markdown", 647 | "metadata": {}, 648 | "source": [ 649 | "* Word embeddings in 2017: Trends and future directions" 650 | ] 651 | } 652 | ], 653 | "metadata": { 654 | "kernelspec": { 655 | "display_name": "Python 3", 656 | "language": "python", 657 | "name": "python3" 658 | }, 659 | "language_info": { 660 | "codemirror_mode": { 661 | "name": "ipython", 662 | "version": 3 663 | }, 664 | "file_extension": ".py", 665 | "mimetype": "text/x-python", 666 | "name": "python", 667 | "nbconvert_exporter": "python", 668 | "pygments_lexer": "ipython3", 669 | "version": "3.5.2" 670 | } 671 | }, 672 | "nbformat": 4, 673 | "nbformat_minor": 2 674 | } 675 | -------------------------------------------------------------------------------- /notebooks/04.Window-Classifier-for-NER.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# 4. Word Window Classification and Neural Networks " 8 | ] 9 | }, 10 | { 11 | "cell_type": "markdown", 12 | "metadata": {}, 13 | "source": [ 14 | "I recommend you take a look at these materials first." 
15 | ] 16 | }, 17 | { 18 | "cell_type": "markdown", 19 | "metadata": {}, 20 | "source": [ 21 | "* http://web.stanford.edu/class/cs224n/lectures/cs224n-2017-lecture4.pdf\n", 22 | "* https://en.wikipedia.org/wiki/Named-entity_recognition" 23 | ] 24 | }, 25 | { 26 | "cell_type": "code", 27 | "execution_count": 1, 28 | "metadata": { 29 | "collapsed": true 30 | }, 31 | "outputs": [], 32 | "source": [ 33 | "import torch\n", 34 | "import torch.nn as nn\n", 35 | "from torch.autograd import Variable\n", 36 | "import torch.optim as optim\n", 37 | "import torch.nn.functional as F\n", 38 | "import nltk\n", 39 | "import random\n", 40 | "import numpy as np\n", 41 | "from collections import Counter\n", 42 | "flatten = lambda l: [item for sublist in l for item in sublist]\n", 43 | "from sklearn_crfsuite import metrics" 44 | ] 45 | }, 46 | { 47 | "cell_type": "markdown", 48 | "metadata": {}, 49 | "source": [ 50 | "You also need the latest version of sklearn_crfsuite to print the confusion matrix" 51 | ] 52 | }, 53 | { 54 | "cell_type": "code", 55 | "execution_count": 2, 56 | "metadata": { 57 | "collapsed": false 58 | }, 59 | "outputs": [ 60 | { 61 | "name": "stdout", 62 | "output_type": "stream", 63 | "text": [ 64 | "0.2.0+751198f\n", 65 | "3.2.4\n" 66 | ] 67 | } 68 | ], 69 | "source": [ 70 | "print(torch.__version__)\n", 71 | "print(nltk.__version__)" 72 | ] 73 | }, 74 | { 75 | "cell_type": "code", 76 | "execution_count": 3, 77 | "metadata": { 78 | "collapsed": true 79 | }, 80 | "outputs": [], 81 | "source": [ 82 | "USE_CUDA = torch.cuda.is_available()\n", 83 | "\n", 84 | "FloatTensor = torch.cuda.FloatTensor if USE_CUDA else torch.FloatTensor\n", 85 | "LongTensor = torch.cuda.LongTensor if USE_CUDA else torch.LongTensor\n", 86 | "ByteTensor = torch.cuda.ByteTensor if USE_CUDA else torch.ByteTensor" 87 | ] 88 | }, 89 | { 90 | "cell_type": "code", 91 | "execution_count": 4, 92 | "metadata": { 93 | "collapsed": true 94 | }, 95 | "outputs": [], 96 | "source": [ 97 | "def 
getBatch(batch_size,train_data):\n", 98 | " random.shuffle(train_data)\n", 99 | " sindex=0\n", 100 | " eindex=batch_size\n", 101 | " while eindex < len(train_data):\n", 102 | " batch = train_data[sindex:eindex]\n", 103 | " temp = eindex\n", 104 | " eindex = eindex+batch_size\n", 105 | " sindex = temp\n", 106 | " yield batch\n", 107 | " \n", 108 | " if eindex >= len(train_data):\n", 109 | " batch = train_data[sindex:]\n", 110 | " yield batch" 111 | ] 112 | }, 113 | { 114 | "cell_type": "code", 115 | "execution_count": 5, 116 | "metadata": { 117 | "collapsed": true 118 | }, 119 | "outputs": [], 120 | "source": [ 121 | "def prepare_sequence(seq, word2index):\n", 122 | " idxs = list(map(lambda w: word2index[w] if w in word2index.keys() else word2index[\"\"], seq))\n", 123 | " return Variable(LongTensor(idxs))\n", 124 | "\n", 125 | "def prepare_word(word,word2index):\n", 126 | " return Variable(LongTensor([word2index[word]]) if word in word2index.keys() else LongTensor([word2index[\"\"]]))\n", 127 | "\n", 128 | "def prepare_tag(tag,tag2index):\n", 129 | " return Variable(LongTensor([tag2index[tag]]))" 130 | ] 131 | }, 132 | { 133 | "cell_type": "markdown", 134 | "metadata": {}, 135 | "source": [ 136 | "## Data load and Preprocessing " 137 | ] 138 | }, 139 | { 140 | "cell_type": "markdown", 141 | "metadata": {}, 142 | "source": [ 143 | "CoNLL-2002 Shared Task: Language-Independent Named Entity Recognition
\n", 144 | "https://www.clips.uantwerpen.be/conll2002/ner/" 145 | ] 146 | }, 147 | { 148 | "cell_type": "code", 149 | "execution_count": 6, 150 | "metadata": { 151 | "collapsed": true 152 | }, 153 | "outputs": [], 154 | "source": [ 155 | "corpus = nltk.corpus.conll2002.iob_sents()" 156 | ] 157 | }, 158 | { 159 | "cell_type": "code", 160 | "execution_count": 7, 161 | "metadata": { 162 | "collapsed": true 163 | }, 164 | "outputs": [], 165 | "source": [ 166 | "data=[]\n", 167 | "for cor in corpus:\n", 168 | " sent,_,tag = list(zip(*cor))\n", 169 | " data.append([sent,tag])" 170 | ] 171 | }, 172 | { 173 | "cell_type": "code", 174 | "execution_count": 8, 175 | "metadata": { 176 | "collapsed": false 177 | }, 178 | "outputs": [ 179 | { 180 | "name": "stdout", 181 | "output_type": "stream", 182 | "text": [ 183 | "35651\n", 184 | "[('Sao', 'Paulo', '(', 'Brasil', ')', ',', '23', 'may', '(', 'EFECOM', ')', '.'), ('B-LOC', 'I-LOC', 'O', 'B-LOC', 'O', 'O', 'O', 'O', 'O', 'B-ORG', 'O', 'O')]\n" 185 | ] 186 | } 187 | ], 188 | "source": [ 189 | "print(len(data))\n", 190 | "print(data[0])" 191 | ] 192 | }, 193 | { 194 | "cell_type": "markdown", 195 | "metadata": {}, 196 | "source": [ 197 | "### Build Vocab" 198 | ] 199 | }, 200 | { 201 | "cell_type": "code", 202 | "execution_count": 9, 203 | "metadata": { 204 | "collapsed": true 205 | }, 206 | "outputs": [], 207 | "source": [ 208 | "sents,tags = list(zip(*data))\n", 209 | "vocab = list(set(flatten(sents)))\n", 210 | "tagset = list(set(flatten(tags)))" 211 | ] 212 | }, 213 | { 214 | "cell_type": "code", 215 | "execution_count": 10, 216 | "metadata": { 217 | "collapsed": true 218 | }, 219 | "outputs": [], 220 | "source": [ 221 | "word2index={'' : 0, '' : 1} # dummy token is for start or end of sentence\n", 222 | "for vo in vocab:\n", 223 | " if vo not in word2index.keys():\n", 224 | " word2index[vo]=len(word2index)\n", 225 | "index2word = {v:k for k,v in word2index.items()}\n", 226 | "\n", 227 | "tag2index = {}\n", 228 | "for tag in 
tagset:\n", 229 | " if tag not in tag2index.keys():\n", 230 | " tag2index[tag]=len(tag2index)\n", 231 | "index2tag={v:k for k,v in tag2index.items()}" 232 | ] 233 | }, 234 | { 235 | "cell_type": "markdown", 236 | "metadata": {}, 237 | "source": [ 238 | "### Prepare data" 239 | ] 240 | }, 241 | { 242 | "cell_type": "markdown", 243 | "metadata": {}, 244 | "source": [ 245 | "
Example: classify 'Paris' in the context of this sentence with a window length of 2
" 246 | ] 247 | }, 248 | { 249 | "cell_type": "markdown", 250 | "metadata": {}, 251 | "source": [ 252 | "" 253 | ] 254 | }, 255 | { 256 | "cell_type": "markdown", 257 | "metadata": {}, 258 | "source": [ 259 | "
Image borrowed from http://web.stanford.edu/class/cs224n/lectures/cs224n-2017-lecture4.pdf
" 260 | ] 261 | }, 262 | { 263 | "cell_type": "code", 264 | "execution_count": 11, 265 | "metadata": { 266 | "collapsed": true 267 | }, 268 | "outputs": [], 269 | "source": [ 270 | "WINDOW_SIZE=2\n", 271 | "windows=[]" 272 | ] 273 | }, 274 | { 275 | "cell_type": "code", 276 | "execution_count": 12, 277 | "metadata": { 278 | "collapsed": true 279 | }, 280 | "outputs": [], 281 | "source": [ 282 | "for sample in data:\n", 283 | " dummy=['']*WINDOW_SIZE\n", 284 | " window = list(nltk.ngrams(dummy+list(sample[0])+dummy,WINDOW_SIZE*2+1))\n", 285 | " windows.extend([[list(window[i]),sample[1][i]] for i in range(len(sample[0]))])" 286 | ] 287 | }, 288 | { 289 | "cell_type": "code", 290 | "execution_count": 13, 291 | "metadata": { 292 | "collapsed": false 293 | }, 294 | "outputs": [ 295 | { 296 | "data": { 297 | "text/plain": [ 298 | "[['', '', 'Sao', 'Paulo', '('], 'B-LOC']" 299 | ] 300 | }, 301 | "execution_count": 13, 302 | "metadata": {}, 303 | "output_type": "execute_result" 304 | } 305 | ], 306 | "source": [ 307 | "windows[0]" 308 | ] 309 | }, 310 | { 311 | "cell_type": "code", 312 | "execution_count": 14, 313 | "metadata": { 314 | "collapsed": false 315 | }, 316 | "outputs": [ 317 | { 318 | "data": { 319 | "text/plain": [ 320 | "678377" 321 | ] 322 | }, 323 | "execution_count": 14, 324 | "metadata": {}, 325 | "output_type": "execute_result" 326 | } 327 | ], 328 | "source": [ 329 | "len(windows)" 330 | ] 331 | }, 332 | { 333 | "cell_type": "code", 334 | "execution_count": 15, 335 | "metadata": { 336 | "collapsed": true 337 | }, 338 | "outputs": [], 339 | "source": [ 340 | "random.shuffle(windows)\n", 341 | "\n", 342 | "train_data = windows[:int(len(windows)*0.9)]\n", 343 | "test_data = windows[int(len(windows)*0.9):]" 344 | ] 345 | }, 346 | { 347 | "cell_type": "markdown", 348 | "metadata": {}, 349 | "source": [ 350 | "## Modeling " 351 | ] 352 | }, 353 | { 354 | "cell_type": "markdown", 355 | "metadata": {}, 356 | "source": [ 357 | "\n", 358 | "
Image borrowed from http://web.stanford.edu/class/cs224n/lectures/cs224n-2017-lecture4.pdf
" 359 | ] 360 | }, 361 | { 362 | "cell_type": "code", 363 | "execution_count": 16, 364 | "metadata": { 365 | "collapsed": true 366 | }, 367 | "outputs": [], 368 | "source": [ 369 | "class WindowClassifier(nn.Module): \n", 370 | " def __init__(self,vocab_size,embedding_size,window_size,hidden_size,output_size):\n", 371 | "\n", 372 | " super(WindowClassifier, self).__init__()\n", 373 | " \n", 374 | " self.embed = nn.Embedding(vocab_size,embedding_size)\n", 375 | " self.h_layer1 = nn.Linear(embedding_size*(window_size*2+1), hidden_size)\n", 376 | " self.h_layer2 = nn.Linear(hidden_size, hidden_size)\n", 377 | " self.o_layer = nn.Linear(hidden_size,output_size)\n", 378 | " self.relu = nn.ReLU()\n", 379 | " self.softmax = nn.LogSoftmax()\n", 380 | " self.dropout = nn.Dropout(0.3)\n", 381 | " \n", 382 | " def forward(self, inputs,is_training=False): \n", 383 | " embeds = self.embed(inputs) # BxWxD\n", 384 | " concated = embeds.view(-1,embeds.size(1)*embeds.size(2)) # Bx(W*D)\n", 385 | " h0 = self.relu(self.h_layer1(concated))\n", 386 | " if is_training:\n", 387 | " h0 = self.dropout(h0)\n", 388 | " h1 = self.relu(self.h_layer2(h0))\n", 389 | " if is_training:\n", 390 | " h1 = self.dropout(h1)\n", 391 | " out = self.softmax(self.o_layer(h1))\n", 392 | " return out" 393 | ] 394 | }, 395 | { 396 | "cell_type": "code", 397 | "execution_count": 20, 398 | "metadata": { 399 | "collapsed": true 400 | }, 401 | "outputs": [], 402 | "source": [ 403 | "BATCH_SIZE=128\n", 404 | "EMBEDDING_SIZE=50 # x (WINDOW_SIZE*2+1) = 250\n", 405 | "HIDDEN_SIZE=300\n", 406 | "EPOCH=3\n", 407 | "LEARNING_RATE = 0.001" 408 | ] 409 | }, 410 | { 411 | "cell_type": "markdown", 412 | "metadata": {}, 413 | "source": [ 414 | "## Training " 415 | ] 416 | }, 417 | { 418 | "cell_type": "markdown", 419 | "metadata": {}, 420 | "source": [ 421 | "It takes for a while if you use just cpu." 
422 | ] 423 | }, 424 | { 425 | "cell_type": "code", 426 | "execution_count": 22, 427 | "metadata": { 428 | "collapsed": true 429 | }, 430 | "outputs": [], 431 | "source": [ 432 | "model = WindowClassifier(len(word2index),EMBEDDING_SIZE,WINDOW_SIZE,HIDDEN_SIZE,len(tag2index))\n", 433 | "if USE_CUDA:\n", 434 | " model = model.cuda()\n", 435 | "loss_function = nn.CrossEntropyLoss()\n", 436 | "optimizer = optim.Adam(model.parameters(),lr=LEARNING_RATE)" 437 | ] 438 | }, 439 | { 440 | "cell_type": "code", 441 | "execution_count": 23, 442 | "metadata": { 443 | "collapsed": false 444 | }, 445 | "outputs": [ 446 | { 447 | "name": "stdout", 448 | "output_type": "stream", 449 | "text": [ 450 | "[0/3] mean_loss : 2.25\n", 451 | "[0/3] mean_loss : 0.47\n", 452 | "[0/3] mean_loss : 0.36\n", 453 | "[0/3] mean_loss : 0.31\n", 454 | "[0/3] mean_loss : 0.28\n", 455 | "[1/3] mean_loss : 0.22\n", 456 | "[1/3] mean_loss : 0.21\n", 457 | "[1/3] mean_loss : 0.21\n", 458 | "[1/3] mean_loss : 0.19\n", 459 | "[1/3] mean_loss : 0.19\n", 460 | "[2/3] mean_loss : 0.12\n", 461 | "[2/3] mean_loss : 0.15\n", 462 | "[2/3] mean_loss : 0.15\n", 463 | "[2/3] mean_loss : 0.14\n", 464 | "[2/3] mean_loss : 0.14\n" 465 | ] 466 | } 467 | ], 468 | "source": [ 469 | "for epoch in range(EPOCH):\n", 470 | " losses=[]\n", 471 | " for i,batch in enumerate(getBatch(BATCH_SIZE,train_data)):\n", 472 | " x,y=list(zip(*batch))\n", 473 | " inputs = torch.cat([prepare_sequence(sent,word2index).view(1,-1) for sent in x])\n", 474 | " targets = torch.cat([prepare_tag(tag,tag2index) for tag in y])\n", 475 | " model.zero_grad()\n", 476 | " preds = model(inputs,is_training=True)\n", 477 | " loss = loss_function(preds,targets)\n", 478 | " losses.append(loss.data.tolist()[0])\n", 479 | " loss.backward()\n", 480 | " optimizer.step()\n", 481 | "\n", 482 | " if i % 1000==0:\n", 483 | " print(\"[%d/%d] mean_loss : %0.2f\" %(epoch,EPOCH,np.mean(losses)))\n", 484 | " losses=[]" 485 | ] 486 | }, 487 | { 488 | "cell_type": 
"markdown", 489 | "metadata": {}, 490 | "source": [ 491 | "## Test " 492 | ] 493 | }, 494 | { 495 | "cell_type": "code", 496 | "execution_count": 24, 497 | "metadata": { 498 | "collapsed": true 499 | }, 500 | "outputs": [], 501 | "source": [ 502 | "for_f1_score=[]" 503 | ] 504 | }, 505 | { 506 | "cell_type": "code", 507 | "execution_count": 25, 508 | "metadata": { 509 | "collapsed": false 510 | }, 511 | "outputs": [ 512 | { 513 | "name": "stdout", 514 | "output_type": "stream", 515 | "text": [ 516 | "95.69120551903063\n" 517 | ] 518 | } 519 | ], 520 | "source": [ 521 | "accuracy=0\n", 522 | "for test in test_data:\n", 523 | " x,y = test[0],test[1]\n", 524 | " input_ = prepare_sequence(x,word2index).view(1,-1)\n", 525 | "\n", 526 | " i = model(input_).max(1)[1]\n", 527 | " pred = index2tag[i.data.tolist()[0]]\n", 528 | " for_f1_score.append([pred,y])\n", 529 | " if pred==y:\n", 530 | " accuracy+=1\n", 531 | "\n", 532 | "print(accuracy/len(test_data)*100)" 533 | ] 534 | }, 535 | { 536 | "cell_type": "markdown", 537 | "metadata": {}, 538 | "source": [ 539 | "This high score is because most of labels are 'O' tag. So we need to measure f1 score." 
540 | ] 541 | }, 542 | { 543 | "cell_type": "markdown", 544 | "metadata": {}, 545 | "source": [ 546 | "### Print Confusion matrix " 547 | ] 548 | }, 549 | { 550 | "cell_type": "code", 551 | "execution_count": 26, 552 | "metadata": { 553 | "collapsed": true 554 | }, 555 | "outputs": [], 556 | "source": [ 557 | "y_pred, y_test = list(zip(*for_f1_score))" 558 | ] 559 | }, 560 | { 561 | "cell_type": "code", 562 | "execution_count": 27, 563 | "metadata": { 564 | "collapsed": true 565 | }, 566 | "outputs": [], 567 | "source": [ 568 | "sorted_labels = sorted(\n", 569 | " list(set(y_test)-{'O'}),\n", 570 | " key=lambda name: (name[1:], name[0])\n", 571 | ")" 572 | ] 573 | }, 574 | { 575 | "cell_type": "code", 576 | "execution_count": 28, 577 | "metadata": { 578 | "collapsed": false 579 | }, 580 | "outputs": [ 581 | { 582 | "data": { 583 | "text/plain": [ 584 | "['B-LOC', 'I-LOC', 'B-MISC', 'I-MISC', 'B-ORG', 'I-ORG', 'B-PER', 'I-PER']" 585 | ] 586 | }, 587 | "execution_count": 28, 588 | "metadata": {}, 589 | "output_type": "execute_result" 590 | } 591 | ], 592 | "source": [ 593 | "sorted_labels" 594 | ] 595 | }, 596 | { 597 | "cell_type": "code", 598 | "execution_count": 29, 599 | "metadata": { 600 | "collapsed": true 601 | }, 602 | "outputs": [], 603 | "source": [ 604 | "y_pred = [[y] for y in y_pred] # this is because sklearn_crfsuite.metrics function flatten inputs\n", 605 | "y_test = [[y] for y in y_test]" 606 | ] 607 | }, 608 | { 609 | "cell_type": "code", 610 | "execution_count": 30, 611 | "metadata": { 612 | "collapsed": false 613 | }, 614 | "outputs": [ 615 | { 616 | "name": "stdout", 617 | "output_type": "stream", 618 | "text": [ 619 | " precision recall f1-score support\n", 620 | "\n", 621 | " B-LOC 0.802 0.636 0.710 1085\n", 622 | " I-LOC 0.732 0.457 0.562 311\n", 623 | " B-MISC 0.750 0.378 0.503 801\n", 624 | " I-MISC 0.679 0.331 0.445 641\n", 625 | " B-ORG 0.723 0.738 0.730 1430\n", 626 | " I-ORG 0.710 0.700 0.705 969\n", 627 | " B-PER 0.782 0.773 0.777 
1268\n", 628 | " I-PER 0.853 0.871 0.861 950\n", 629 | "\n", 630 | "avg / total 0.759 0.656 0.693 7455\n", 631 | "\n" 632 | ] 633 | } 634 | ], 635 | "source": [ 636 | "print(metrics.flat_classification_report(\n", 637 | " y_test, y_pred, labels=sorted_labels, digits=3\n", 638 | "))" 639 | ] 640 | }, 641 | { 642 | "cell_type": "markdown", 643 | "metadata": { 644 | "collapsed": true 645 | }, 646 | "source": [ 647 | "### TODO" 648 | ] 649 | }, 650 | { 651 | "cell_type": "markdown", 652 | "metadata": {}, 653 | "source": [ 654 | "* use max-margin objective function http://pytorch.org/docs/master/nn.html#multilabelmarginloss" 655 | ] 656 | }, 657 | { 658 | "cell_type": "code", 659 | "execution_count": null, 660 | "metadata": { 661 | "collapsed": true 662 | }, 663 | "outputs": [], 664 | "source": [] 665 | } 666 | ], 667 | "metadata": { 668 | "kernelspec": { 669 | "display_name": "Python 3", 670 | "language": "python", 671 | "name": "python3" 672 | }, 673 | "language_info": { 674 | "codemirror_mode": { 675 | "name": "ipython", 676 | "version": 3 677 | }, 678 | "file_extension": ".py", 679 | "mimetype": "text/x-python", 680 | "name": "python", 681 | "nbconvert_exporter": "python", 682 | "pygments_lexer": "ipython3", 683 | "version": "3.5.2" 684 | } 685 | }, 686 | "nbformat": 4, 687 | "nbformat_minor": 2 688 | } 689 | -------------------------------------------------------------------------------- /notebooks/06.RNN-Language-Model.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# 6. 
Recurrent Neural Networks and Language Models" 8 | ] 9 | }, 10 | { 11 | "cell_type": "markdown", 12 | "metadata": {}, 13 | "source": [ 14 | "* http://web.stanford.edu/class/cs224n/lectures/cs224n-2017-lecture8.pdf\n", 15 | "* http://web.stanford.edu/class/cs224n/lectures/cs224n-2017-lecture9.pdf\n", 16 | "* http://colah.github.io/posts/2015-08-Understanding-LSTMs/\n", 17 | "* https://github.com/pytorch/examples/tree/master/word_language_model\n", 18 | "* https://github.com/yunjey/pytorch-tutorial/blob/master/tutorials/02-intermediate/language_model" 19 | ] 20 | }, 21 | { 22 | "cell_type": "code", 23 | "execution_count": 1, 24 | "metadata": { 25 | "collapsed": true 26 | }, 27 | "outputs": [], 28 | "source": [ 29 | "import torch\n", 30 | "import torch.nn as nn\n", 31 | "from torch.autograd import Variable\n", 32 | "import torch.optim as optim\n", 33 | "import torch.nn.functional as F\n", 34 | "import nltk\n", 35 | "import random\n", 36 | "import numpy as np\n", 37 | "from collections import Counter, OrderedDict\n", 38 | "import nltk\n", 39 | "from copy import deepcopy\n", 40 | "flatten = lambda l: [item for sublist in l for item in sublist]" 41 | ] 42 | }, 43 | { 44 | "cell_type": "code", 45 | "execution_count": 2, 46 | "metadata": { 47 | "collapsed": true 48 | }, 49 | "outputs": [], 50 | "source": [ 51 | "USE_CUDA = torch.cuda.is_available()\n", 52 | "\n", 53 | "FloatTensor = torch.cuda.FloatTensor if USE_CUDA else torch.FloatTensor\n", 54 | "LongTensor = torch.cuda.LongTensor if USE_CUDA else torch.LongTensor\n", 55 | "ByteTensor = torch.cuda.ByteTensor if USE_CUDA else torch.ByteTensor" 56 | ] 57 | }, 58 | { 59 | "cell_type": "code", 60 | "execution_count": 4, 61 | "metadata": { 62 | "collapsed": true 63 | }, 64 | "outputs": [], 65 | "source": [ 66 | "def prepare_sequence(seq, to_index):\n", 67 | " idxs = list(map(lambda w: to_index[w] if w in to_index.keys() else to_index[\"\"], seq))\n", 68 | " return LongTensor(idxs)" 69 | ] 70 | }, 71 | { 72 | "cell_type": 
"markdown", 73 | "metadata": {}, 74 | "source": [ 75 | "## Data load and Preprocessing" 76 | ] 77 | }, 78 | { 79 | "cell_type": "markdown", 80 | "metadata": {}, 81 | "source": [ 82 | "### Penn TreeBank" 83 | ] 84 | }, 85 | { 86 | "cell_type": "code", 87 | "execution_count": 5, 88 | "metadata": { 89 | "collapsed": true 90 | }, 91 | "outputs": [], 92 | "source": [ 93 | "def prepare_ptb_dataset(filename,word2index=None):\n", 94 | " corpus = open(filename,'r',encoding='utf-8').readlines()\n", 95 | " corpus = flatten([co.strip().split() + [''] for co in corpus])\n", 96 | " \n", 97 | " if word2index==None:\n", 98 | " vocab = list(set(corpus))\n", 99 | " word2index={'':0}\n", 100 | " for vo in vocab:\n", 101 | " if vo not in word2index.keys():\n", 102 | " word2index[vo]=len(word2index)\n", 103 | " \n", 104 | " return prepare_sequence(corpus,word2index), word2index" 105 | ] 106 | }, 107 | { 108 | "cell_type": "code", 109 | "execution_count": 175, 110 | "metadata": { 111 | "collapsed": true 112 | }, 113 | "outputs": [], 114 | "source": [ 115 | "# borrowed code from https://github.com/pytorch/examples/tree/master/word_language_model\n", 116 | "\n", 117 | "def batchify(data, bsz):\n", 118 | " # Work out how cleanly we can divide the dataset into bsz parts.\n", 119 | " nbatch = data.size(0) // bsz\n", 120 | " # Trim off any extra elements that wouldn't cleanly fit (remainders).\n", 121 | " data = data.narrow(0, 0, nbatch * bsz)\n", 122 | " # Evenly divide the data across the bsz batches.\n", 123 | " data = data.view(bsz, -1).contiguous()\n", 124 | " if USE_CUDA:\n", 125 | " data = data.cuda()\n", 126 | " return data" 127 | ] 128 | }, 129 | { 130 | "cell_type": "code", 131 | "execution_count": 176, 132 | "metadata": { 133 | "collapsed": true 134 | }, 135 | "outputs": [], 136 | "source": [ 137 | "def getBatch(data,seq_length):\n", 138 | " for i in range(0, data.size(1) - seq_length, seq_length):\n", 139 | " inputs = Variable(data[:, i:i+seq_length])\n", 140 | " targets = 
Variable(data[:, (i+1):(i+1)+seq_length].contiguous())\n", 141 | " yield (inputs,targets)" 142 | ] 143 | }, 144 | { 145 | "cell_type": "code", 146 | "execution_count": 177, 147 | "metadata": { 148 | "collapsed": false 149 | }, 150 | "outputs": [], 151 | "source": [ 152 | "train_data, word2index= prepare_ptb_dataset('../dataset/ptb/ptb.train.txt',)\n", 153 | "dev_data , _ = prepare_ptb_dataset('../dataset/ptb/ptb.valid.txt',word2index)\n", 154 | "test_data, _ = prepare_ptb_dataset('../dataset/ptb/ptb.test.txt',word2index)" 155 | ] 156 | }, 157 | { 158 | "cell_type": "code", 159 | "execution_count": 178, 160 | "metadata": { 161 | "collapsed": false 162 | }, 163 | "outputs": [ 164 | { 165 | "data": { 166 | "text/plain": [ 167 | "10000" 168 | ] 169 | }, 170 | "execution_count": 178, 171 | "metadata": {}, 172 | "output_type": "execute_result" 173 | } 174 | ], 175 | "source": [ 176 | "len(word2index)" 177 | ] 178 | }, 179 | { 180 | "cell_type": "code", 181 | "execution_count": 179, 182 | "metadata": { 183 | "collapsed": true 184 | }, 185 | "outputs": [], 186 | "source": [ 187 | "index2word = {v:k for k,v in word2index.items()}" 188 | ] 189 | }, 190 | { 191 | "cell_type": "markdown", 192 | "metadata": {}, 193 | "source": [ 194 | "## Modeling " 195 | ] 196 | }, 197 | { 198 | "cell_type": "markdown", 199 | "metadata": {}, 200 | "source": [ 201 | "\n", 202 | "
Image borrowed from http://web.stanford.edu/class/cs224n/lectures/cs224n-2017-lecture8.pdf
" 203 | ] 204 | }, 205 | { 206 | "cell_type": "code", 207 | "execution_count": 180, 208 | "metadata": { 209 | "collapsed": false 210 | }, 211 | "outputs": [], 212 | "source": [ 213 | "class LanguageModel(nn.Module): \n", 214 | " def __init__(self,vocab_size,embedding_size,hidden_size,n_layers=1,dropout_p=0.5):\n", 215 | "\n", 216 | " super(LanguageModel, self).__init__()\n", 217 | " self.n_layers = n_layers\n", 218 | " self.hidden_size = hidden_size\n", 219 | " self.embed = nn.Embedding(vocab_size,embedding_size)\n", 220 | " self.rnn = nn.LSTM(embedding_size,hidden_size,n_layers,batch_first=True)\n", 221 | " self.linear = nn.Linear(hidden_size,vocab_size)\n", 222 | " self.dropout = nn.Dropout(dropout_p)\n", 223 | " \n", 224 | " def init_weight(self):\n", 225 | " self.embed.weight = nn.init.xavier_uniform(self.embed.weight)\n", 226 | " self.linear.weight = nn.init.xavier_uniform(self.linear.weight)\n", 227 | " self.linear.bias.data.fill_(0)\n", 228 | " \n", 229 | " def init_hidden(self,batch_size):\n", 230 | " hidden = Variable(torch.zeros(self.n_layers,batch_size,self.hidden_size))\n", 231 | " context = Variable(torch.zeros(self.n_layers,batch_size,self.hidden_size))\n", 232 | " return (hidden.cuda(), context.cuda()) if USE_CUDA else (hidden,context)\n", 233 | " \n", 234 | " def detach_hidden(self,hiddens):\n", 235 | " return tuple([hidden.detach() for hidden in hiddens])\n", 236 | " \n", 237 | " def forward(self, inputs,hidden,is_training=False): \n", 238 | "\n", 239 | " embeds = self.embed(inputs)\n", 240 | " if is_training:\n", 241 | " embeds = self.dropout(embeds)\n", 242 | " out,hidden = self.rnn(embeds,hidden)\n", 243 | " return self.linear(out.contiguous().view(out.size(0)*out.size(1),-1)), hidden" 244 | ] 245 | }, 246 | { 247 | "cell_type": "markdown", 248 | "metadata": {}, 249 | "source": [ 250 | "## Train " 251 | ] 252 | }, 253 | { 254 | "cell_type": "markdown", 255 | "metadata": {}, 256 | "source": [ 257 | "It takes for a while..." 
258 | ] 259 | }, 260 | { 261 | "cell_type": "code", 262 | "execution_count": 181, 263 | "metadata": { 264 | "collapsed": true 265 | }, 266 | "outputs": [], 267 | "source": [ 268 | "EMBED_SIZE=128\n", 269 | "HIDDEN_SIZE=1024\n", 270 | "NUM_LAYER=1\n", 271 | "LR = 0.01\n", 272 | "SEQ_LENGTH = 30 # for bptt\n", 273 | "BATCH_SIZE = 20\n", 274 | "EPOCH = 40\n", 275 | "RESCHEDULED=False" 276 | ] 277 | }, 278 | { 279 | "cell_type": "code", 280 | "execution_count": 182, 281 | "metadata": { 282 | "collapsed": true 283 | }, 284 | "outputs": [], 285 | "source": [ 286 | "train_data = batchify(train_data,BATCH_SIZE)\n", 287 | "dev_data = batchify(dev_data,BATCH_SIZE//2)\n", 288 | "test_data = batchify(test_data,BATCH_SIZE//2)" 289 | ] 290 | }, 291 | { 292 | "cell_type": "code", 293 | "execution_count": 185, 294 | "metadata": { 295 | "collapsed": false 296 | }, 297 | "outputs": [], 298 | "source": [ 299 | "model = LanguageModel(len(word2index),EMBED_SIZE,HIDDEN_SIZE,NUM_LAYER,0.5)\n", 300 | "model.init_weight() \n", 301 | "if USE_CUDA:\n", 302 | " model = model.cuda()\n", 303 | "loss_function = nn.CrossEntropyLoss()\n", 304 | "optimizer = optim.Adam(model.parameters(),lr=LR)" 305 | ] 306 | }, 307 | { 308 | "cell_type": "code", 309 | "execution_count": 186, 310 | "metadata": { 311 | "collapsed": false 312 | }, 313 | "outputs": [ 314 | { 315 | "name": "stdout", 316 | "output_type": "stream", 317 | "text": [ 318 | "[00/40] mean_loss : 9.45, Perplexity : 12712.23\n", 319 | "[00/40] mean_loss : 5.88, Perplexity : 358.21\n", 320 | "[00/40] mean_loss : 5.55, Perplexity : 256.44\n", 321 | "[01/40] mean_loss : 5.38, Perplexity : 217.46\n", 322 | "[01/40] mean_loss : 5.21, Perplexity : 182.41\n", 323 | "[01/40] mean_loss : 5.10, Perplexity : 164.39\n", 324 | "[02/40] mean_loss : 5.08, Perplexity : 160.87\n", 325 | "[02/40] mean_loss : 4.99, Perplexity : 147.18\n", 326 | "[02/40] mean_loss : 4.92, Perplexity : 136.52\n", 327 | "[03/40] mean_loss : 4.92, Perplexity : 136.64\n", 328 | 
"[03/40] mean_loss : 4.86, Perplexity : 129.32\n", 329 | "[03/40] mean_loss : 4.80, Perplexity : 121.46\n", 330 | "[04/40] mean_loss : 4.80, Perplexity : 121.91\n", 331 | "[04/40] mean_loss : 4.77, Perplexity : 117.64\n", 332 | "[04/40] mean_loss : 4.71, Perplexity : 111.22\n", 333 | "[05/40] mean_loss : 4.72, Perplexity : 112.01\n", 334 | "[05/40] mean_loss : 4.70, Perplexity : 109.46\n", 335 | "[05/40] mean_loss : 4.64, Perplexity : 103.96\n", 336 | "[06/40] mean_loss : 4.66, Perplexity : 105.25\n", 337 | "[06/40] mean_loss : 4.64, Perplexity : 103.63\n", 338 | "[06/40] mean_loss : 4.60, Perplexity : 99.00\n", 339 | "[07/40] mean_loss : 4.60, Perplexity : 99.89\n", 340 | "[07/40] mean_loss : 4.59, Perplexity : 98.97\n", 341 | "[07/40] mean_loss : 4.55, Perplexity : 94.97\n", 342 | "[08/40] mean_loss : 4.56, Perplexity : 95.54\n", 343 | "[08/40] mean_loss : 4.56, Perplexity : 95.67\n", 344 | "[08/40] mean_loss : 4.52, Perplexity : 91.98\n", 345 | "[09/40] mean_loss : 4.53, Perplexity : 92.61\n", 346 | "[09/40] mean_loss : 4.53, Perplexity : 92.79\n", 347 | "[09/40] mean_loss : 4.50, Perplexity : 89.63\n", 348 | "[10/40] mean_loss : 4.50, Perplexity : 90.13\n", 349 | "[10/40] mean_loss : 4.50, Perplexity : 90.19\n", 350 | "[10/40] mean_loss : 4.47, Perplexity : 87.11\n", 351 | "[11/40] mean_loss : 4.48, Perplexity : 88.11\n", 352 | "[11/40] mean_loss : 4.48, Perplexity : 88.26\n", 353 | "[11/40] mean_loss : 4.45, Perplexity : 86.05\n", 354 | "[12/40] mean_loss : 4.46, Perplexity : 86.81\n", 355 | "[12/40] mean_loss : 4.47, Perplexity : 87.03\n", 356 | "[12/40] mean_loss : 4.43, Perplexity : 84.04\n", 357 | "[13/40] mean_loss : 4.45, Perplexity : 85.27\n", 358 | "[13/40] mean_loss : 4.45, Perplexity : 85.83\n", 359 | "[13/40] mean_loss : 4.42, Perplexity : 83.33\n", 360 | "[14/40] mean_loss : 4.43, Perplexity : 84.15\n", 361 | "[14/40] mean_loss : 4.43, Perplexity : 84.31\n", 362 | "[14/40] mean_loss : 4.41, Perplexity : 82.29\n", 363 | "[15/40] mean_loss : 4.43, 
Perplexity : 83.82\n", 364 | "[15/40] mean_loss : 4.43, Perplexity : 83.70\n", 365 | "[15/40] mean_loss : 4.40, Perplexity : 81.59\n", 366 | "[16/40] mean_loss : 4.42, Perplexity : 83.06\n", 367 | "[16/40] mean_loss : 4.42, Perplexity : 83.29\n", 368 | "[16/40] mean_loss : 4.39, Perplexity : 80.89\n", 369 | "[17/40] mean_loss : 4.41, Perplexity : 82.44\n", 370 | "[17/40] mean_loss : 4.41, Perplexity : 82.51\n", 371 | "[17/40] mean_loss : 4.39, Perplexity : 80.59\n", 372 | "[18/40] mean_loss : 4.40, Perplexity : 81.59\n", 373 | "[18/40] mean_loss : 4.41, Perplexity : 82.21\n", 374 | "[18/40] mean_loss : 4.38, Perplexity : 79.87\n", 375 | "[19/40] mean_loss : 4.40, Perplexity : 81.43\n", 376 | "[19/40] mean_loss : 4.40, Perplexity : 81.67\n", 377 | "[19/40] mean_loss : 4.37, Perplexity : 79.28\n", 378 | "[20/40] mean_loss : 4.40, Perplexity : 81.18\n", 379 | "[20/40] mean_loss : 4.40, Perplexity : 81.17\n", 380 | "[20/40] mean_loss : 4.37, Perplexity : 79.11\n", 381 | "[21/40] mean_loss : 4.40, Perplexity : 81.44\n", 382 | "[21/40] mean_loss : 4.34, Perplexity : 76.43\n", 383 | "[21/40] mean_loss : 4.21, Perplexity : 67.17\n", 384 | "[22/40] mean_loss : 4.26, Perplexity : 70.84\n", 385 | "[22/40] mean_loss : 4.26, Perplexity : 70.75\n", 386 | "[22/40] mean_loss : 4.17, Perplexity : 64.99\n", 387 | "[23/40] mean_loss : 4.22, Perplexity : 68.36\n", 388 | "[23/40] mean_loss : 4.22, Perplexity : 67.82\n", 389 | "[23/40] mean_loss : 4.15, Perplexity : 63.74\n", 390 | "[24/40] mean_loss : 4.20, Perplexity : 66.66\n", 391 | "[24/40] mean_loss : 4.20, Perplexity : 66.43\n", 392 | "[24/40] mean_loss : 4.14, Perplexity : 62.85\n", 393 | "[25/40] mean_loss : 4.18, Perplexity : 65.53\n", 394 | "[25/40] mean_loss : 4.17, Perplexity : 64.99\n", 395 | "[25/40] mean_loss : 4.13, Perplexity : 61.94\n", 396 | "[26/40] mean_loss : 4.17, Perplexity : 64.61\n", 397 | "[26/40] mean_loss : 4.16, Perplexity : 64.34\n", 398 | "[26/40] mean_loss : 4.12, Perplexity : 61.27\n", 399 | "[27/40] 
mean_loss : 4.15, Perplexity : 63.73\n", 400 | "[27/40] mean_loss : 4.15, Perplexity : 63.32\n", 401 | "[27/40] mean_loss : 4.11, Perplexity : 60.87\n", 402 | "[28/40] mean_loss : 4.14, Perplexity : 62.96\n", 403 | "[28/40] mean_loss : 4.14, Perplexity : 63.01\n", 404 | "[28/40] mean_loss : 4.10, Perplexity : 60.33\n", 405 | "[29/40] mean_loss : 4.14, Perplexity : 62.54\n", 406 | "[29/40] mean_loss : 4.13, Perplexity : 62.36\n", 407 | "[29/40] mean_loss : 4.10, Perplexity : 60.06\n", 408 | "[30/40] mean_loss : 4.13, Perplexity : 62.05\n", 409 | "[30/40] mean_loss : 4.13, Perplexity : 61.91\n", 410 | "[30/40] mean_loss : 4.09, Perplexity : 59.46\n", 411 | "[31/40] mean_loss : 4.12, Perplexity : 61.45\n", 412 | "[31/40] mean_loss : 4.11, Perplexity : 61.24\n", 413 | "[31/40] mean_loss : 4.08, Perplexity : 59.12\n", 414 | "[32/40] mean_loss : 4.11, Perplexity : 61.03\n", 415 | "[32/40] mean_loss : 4.11, Perplexity : 60.88\n", 416 | "[32/40] mean_loss : 4.07, Perplexity : 58.69\n", 417 | "[33/40] mean_loss : 4.11, Perplexity : 60.71\n", 418 | "[33/40] mean_loss : 4.10, Perplexity : 60.57\n", 419 | "[33/40] mean_loss : 4.07, Perplexity : 58.38\n", 420 | "[34/40] mean_loss : 4.10, Perplexity : 60.33\n", 421 | "[34/40] mean_loss : 4.10, Perplexity : 60.23\n", 422 | "[34/40] mean_loss : 4.06, Perplexity : 58.06\n", 423 | "[35/40] mean_loss : 4.09, Perplexity : 60.00\n", 424 | "[35/40] mean_loss : 4.09, Perplexity : 59.74\n", 425 | "[35/40] mean_loss : 4.06, Perplexity : 57.75\n", 426 | "[36/40] mean_loss : 4.09, Perplexity : 59.58\n", 427 | "[36/40] mean_loss : 4.09, Perplexity : 59.47\n", 428 | "[36/40] mean_loss : 4.05, Perplexity : 57.59\n", 429 | "[37/40] mean_loss : 4.08, Perplexity : 59.30\n", 430 | "[37/40] mean_loss : 4.08, Perplexity : 59.11\n", 431 | "[37/40] mean_loss : 4.05, Perplexity : 57.11\n", 432 | "[38/40] mean_loss : 4.08, Perplexity : 58.98\n", 433 | "[38/40] mean_loss : 4.07, Perplexity : 58.70\n", 434 | "[38/40] mean_loss : 4.04, Perplexity : 
57.10\n", 435 | "[39/40] mean_loss : 4.07, Perplexity : 58.79\n", 436 | "[39/40] mean_loss : 4.07, Perplexity : 58.58\n", 437 | "[39/40] mean_loss : 4.04, Perplexity : 56.79\n" 438 | ] 439 | } 440 | ], 441 | "source": [ 442 | "for epoch in range(EPOCH):\n", 443 | " total_loss = 0\n", 444 | " losses=[]\n", 445 | " hidden = model.init_hidden(BATCH_SIZE)\n", 446 | " for i,batch in enumerate(getBatch(train_data,SEQ_LENGTH)):\n", 447 | " inputs, targets = batch\n", 448 | " hidden = model.detach_hidden(hidden)\n", 449 | " model.zero_grad()\n", 450 | " preds,hidden = model(inputs,hidden,True)\n", 451 | "\n", 452 | " loss = loss_function(preds,targets.view(-1))\n", 453 | " losses.append(loss.data[0])\n", 454 | " loss.backward()\n", 455 | " torch.nn.utils.clip_grad_norm(model.parameters(),0.5) # gradient clipping\n", 456 | " optimizer.step()\n", 457 | "\n", 458 | " if i>0 and i % 500==0:\n", 459 | " print(\"[%02d/%d] mean_loss : %0.2f, Perplexity : %0.2f\" % (epoch,EPOCH, \\\n", 460 | " np.mean(losses),np.exp(np.mean(losses))))\n", 461 | " losses=[]\n", 462 | " \n", 463 | " # learning rate anealing\n", 464 | " # You can use http://pytorch.org/docs/master/optim.html#how-to-adjust-learning-rate\n", 465 | " if RESCHEDULED==False and epoch==EPOCH//2:\n", 466 | " LR=LR*0.1\n", 467 | " optimizer = optim.Adam(model.parameters(),lr=LR)\n", 468 | " RESCHEDULED=True" 469 | ] 470 | }, 471 | { 472 | "cell_type": "markdown", 473 | "metadata": {}, 474 | "source": [ 475 | "### Test " 476 | ] 477 | }, 478 | { 479 | "cell_type": "code", 480 | "execution_count": 189, 481 | "metadata": { 482 | "collapsed": false 483 | }, 484 | "outputs": [ 485 | { 486 | "name": "stdout", 487 | "output_type": "stream", 488 | "text": [ 489 | "Test Perpelexity : 155.89\n" 490 | ] 491 | } 492 | ], 493 | "source": [ 494 | "total_loss = 0\n", 495 | "hidden = model.init_hidden(BATCH_SIZE//2)\n", 496 | "for batch in getBatch(test_data,SEQ_LENGTH):\n", 497 | " inputs,targets = batch\n", 498 | " \n", 499 | " hidden = 
model.detach_hidden(hidden)\n", 500 | " model.zero_grad()\n", 501 | " preds,hidden = model(inputs,hidden)\n", 502 | " total_loss += inputs.size(1) * loss_function(preds, targets.view(-1)).data\n", 503 | "\n", 504 | "total_loss = total_loss[0]/test_data.size(1)\n", 505 | "print(\"Test Perpelexity : %5.2f\" % (np.exp(total_loss)))" 506 | ] 507 | }, 508 | { 509 | "cell_type": "markdown", 510 | "metadata": { 511 | "collapsed": true 512 | }, 513 | "source": [ 514 | "## Further topics" 515 | ] 516 | }, 517 | { 518 | "cell_type": "markdown", 519 | "metadata": {}, 520 | "source": [ 521 | "* Pointer Sentinel Mixture Models\n", 522 | "* Regularizing and Optimizing LSTM Language Models" 523 | ] 524 | }, 525 | { 526 | "cell_type": "code", 527 | "execution_count": null, 528 | "metadata": { 529 | "collapsed": true 530 | }, 531 | "outputs": [], 532 | "source": [] 533 | } 534 | ], 535 | "metadata": { 536 | "kernelspec": { 537 | "display_name": "Python 3", 538 | "language": "python", 539 | "name": "python3" 540 | }, 541 | "language_info": { 542 | "codemirror_mode": { 543 | "name": "ipython", 544 | "version": 3 545 | }, 546 | "file_extension": ".py", 547 | "mimetype": "text/x-python", 548 | "name": "python", 549 | "nbconvert_exporter": "python", 550 | "pygments_lexer": "ipython3", 551 | "version": "3.5.2" 552 | } 553 | }, 554 | "nbformat": 4, 555 | "nbformat_minor": 2 556 | } 557 | -------------------------------------------------------------------------------- /notebooks/08.CNN-for-Text-Classification.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# 8. Convolutional Neural Networks" 8 | ] 9 | }, 10 | { 11 | "cell_type": "markdown", 12 | "metadata": {}, 13 | "source": [ 14 | "I recommend you take a look at these materials first." 
15 | ] 16 | }, 17 | { 18 | "cell_type": "markdown", 19 | "metadata": {}, 20 | "source": [ 21 | "* http://web.stanford.edu/class/cs224n/lectures/cs224n-2017-lecture13-CNNs.pdf\n", 22 | "* http://www.aclweb.org/anthology/D14-1181\n", 23 | "* https://github.com/Shawn1993/cnn-text-classification-pytorch\n", 24 | "* http://cogcomp.org/Data/QA/QC/" 25 | ] 26 | }, 27 | { 28 | "cell_type": "code", 29 | "execution_count": 58, 30 | "metadata": { 31 | "collapsed": true 32 | }, 33 | "outputs": [], 34 | "source": [ 35 | "import torch\n", 36 | "import torch.nn as nn\n", 37 | "from torch.autograd import Variable\n", 38 | "import torch.optim as optim\n", 39 | "import torch.nn.functional as F\n", 40 | "import nltk\n", 41 | "import random\n", 42 | "import numpy as np\n", 43 | "from collections import Counter, OrderedDict\n", 44 | "import nltk\n", 45 | "import re\n", 46 | "from copy import deepcopy\n", 47 | "flatten = lambda l: [item for sublist in l for item in sublist]" 48 | ] 49 | }, 50 | { 51 | "cell_type": "code", 52 | "execution_count": 2, 53 | "metadata": { 54 | "collapsed": true 55 | }, 56 | "outputs": [], 57 | "source": [ 58 | "USE_CUDA = torch.cuda.is_available()\n", 59 | "\n", 60 | "FloatTensor = torch.cuda.FloatTensor if USE_CUDA else torch.FloatTensor\n", 61 | "LongTensor = torch.cuda.LongTensor if USE_CUDA else torch.LongTensor\n", 62 | "ByteTensor = torch.cuda.ByteTensor if USE_CUDA else torch.ByteTensor" 63 | ] 64 | }, 65 | { 66 | "cell_type": "code", 67 | "execution_count": 3, 68 | "metadata": { 69 | "collapsed": true 70 | }, 71 | "outputs": [], 72 | "source": [ 73 | "def getBatch(batch_size,train_data):\n", 74 | " random.shuffle(train_data)\n", 75 | " sindex=0\n", 76 | " eindex=batch_size\n", 77 | " while eindex < len(train_data):\n", 78 | " batch = train_data[sindex:eindex]\n", 79 | " temp = eindex\n", 80 | " eindex = eindex+batch_size\n", 81 | " sindex = temp\n", 82 | " yield batch\n", 83 | " \n", 84 | " if eindex >= len(train_data):\n", 85 | " batch = 
train_data[sindex:]\n", 86 | " yield batch" 87 | ] 88 | }, 89 | { 90 | "cell_type": "code", 91 | "execution_count": 110, 92 | "metadata": { 93 | "collapsed": true 94 | }, 95 | "outputs": [], 96 | "source": [ 97 | "def pad_to_batch(batch):\n", 98 | " x,y = zip(*batch)\n", 99 | " max_x = max([s.size(1) for s in x])\n", 100 | " x_p=[]\n", 101 | " for i in range(len(batch)):\n", 102 | " if x[i].size(1)']]*(max_x-x[i].size(1)))).view(1,-1)],1))\n", 104 | " else:\n", 105 | " x_p.append(x[i])\n", 106 | " return torch.cat(x_p),torch.cat(y).view(-1)" 107 | ] 108 | }, 109 | { 110 | "cell_type": "code", 111 | "execution_count": 20, 112 | "metadata": { 113 | "collapsed": true 114 | }, 115 | "outputs": [], 116 | "source": [ 117 | "def prepare_sequence(seq, to_index):\n", 118 | " idxs = list(map(lambda w: to_index[w] if w in to_index.keys() else to_index[\"\"], seq))\n", 119 | " return Variable(LongTensor(idxs))" 120 | ] 121 | }, 122 | { 123 | "cell_type": "markdown", 124 | "metadata": {}, 125 | "source": [ 126 | "## Data load & Preprocessing" 127 | ] 128 | }, 129 | { 130 | "cell_type": "markdown", 131 | "metadata": {}, 132 | "source": [ 133 | "### TREC question dataset(http://cogcomp.org/Data/QA/QC/)" 134 | ] 135 | }, 136 | { 137 | "cell_type": "markdown", 138 | "metadata": {}, 139 | "source": [ 140 | "Task involves\n", 141 | "classifying a question into 6 question\n", 142 | "types (whether the question is about person,\n", 143 | "location, numeric information, etc.)" 144 | ] 145 | }, 146 | { 147 | "cell_type": "code", 148 | "execution_count": 53, 149 | "metadata": { 150 | "collapsed": true 151 | }, 152 | "outputs": [], 153 | "source": [ 154 | "data = open('../dataset/train_5500.label.txt','r',encoding='latin-1').readlines()" 155 | ] 156 | }, 157 | { 158 | "cell_type": "code", 159 | "execution_count": 54, 160 | "metadata": { 161 | "collapsed": true 162 | }, 163 | "outputs": [], 164 | "source": [ 165 | "data = [[d.split(':')[1][:-1],d.split(':')[0]] for d in data]" 166 | ] 167 | 
}, 168 | { 169 | "cell_type": "code", 170 | "execution_count": 61, 171 | "metadata": { 172 | "collapsed": true 173 | }, 174 | "outputs": [], 175 | "source": [ 176 | "X,y = list(zip(*data))\n", 177 | "X = list(X)" 178 | ] 179 | }, 180 | { 181 | "cell_type": "markdown", 182 | "metadata": {}, 183 | "source": [ 184 | "### Num masking " 185 | ] 186 | }, 187 | { 188 | "cell_type": "markdown", 189 | "metadata": {}, 190 | "source": [ 191 | "It reduces the search space. ex. my birthday is 12.22 ==> my birthday is ##.##" 192 | ] 193 | }, 194 | { 195 | "cell_type": "code", 196 | "execution_count": 62, 197 | "metadata": { 198 | "collapsed": true 199 | }, 200 | "outputs": [], 201 | "source": [ 202 | "for i,x in enumerate(X):\n", 203 | " X[i] = re.sub('\\d','#',x).split()" 204 | ] 205 | }, 206 | { 207 | "cell_type": "markdown", 208 | "metadata": {}, 209 | "source": [ 210 | "### Build Vocab " 211 | ] 212 | }, 213 | { 214 | "cell_type": "code", 215 | "execution_count": 63, 216 | "metadata": { 217 | "collapsed": true 218 | }, 219 | "outputs": [], 220 | "source": [ 221 | "vocab = list(set(flatten(X)))" 222 | ] 223 | }, 224 | { 225 | "cell_type": "code", 226 | "execution_count": 64, 227 | "metadata": { 228 | "collapsed": false 229 | }, 230 | "outputs": [ 231 | { 232 | "data": { 233 | "text/plain": [ 234 | "9117" 235 | ] 236 | }, 237 | "execution_count": 64, 238 | "metadata": {}, 239 | "output_type": "execute_result" 240 | } 241 | ], 242 | "source": [ 243 | "len(vocab)" 244 | ] 245 | }, 246 | { 247 | "cell_type": "code", 248 | "execution_count": 31, 249 | "metadata": { 250 | "collapsed": false 251 | }, 252 | "outputs": [ 253 | { 254 | "data": { 255 | "text/plain": [ 256 | "6" 257 | ] 258 | }, 259 | "execution_count": 31, 260 | "metadata": {}, 261 | "output_type": "execute_result" 262 | } 263 | ], 264 | "source": [ 265 | "len(set(y)) # num of class" 266 | ] 267 | }, 268 | { 269 | "cell_type": "code", 270 | "execution_count": 94, 271 | "metadata": { 272 | "collapsed": true 273 | }, 274 
| "outputs": [], 275 | "source": [ 276 | "word2index={'':0,'':1}\n", 277 | "\n", 278 | "for vo in vocab:\n", 279 | " if vo not in word2index.keys():\n", 280 | " word2index[vo]=len(word2index)\n", 281 | " \n", 282 | "index2word = {v:k for k,v in word2index.items()}\n", 283 | "\n", 284 | "target2index = {}\n", 285 | "\n", 286 | "for cl in set(y):\n", 287 | " if cl not in target2index.keys():\n", 288 | " target2index[cl]=len(target2index)\n", 289 | "\n", 290 | "index2target = {v:k for k,v in target2index.items()}" 291 | ] 292 | }, 293 | { 294 | "cell_type": "code", 295 | "execution_count": 95, 296 | "metadata": { 297 | "collapsed": true 298 | }, 299 | "outputs": [], 300 | "source": [ 301 | "X_p,y_p=[],[]\n", 302 | "for pair in zip(X,y):\n", 303 | " X_p.append(prepare_sequence(pair[0],word2index).view(1,-1))\n", 304 | " y_p.append(Variable(LongTensor([target2index[pair[1]]])).view(1,-1))\n", 305 | " \n", 306 | "data_p = list(zip(X_p,y_p))\n", 307 | "random.shuffle(data_p)\n", 308 | "\n", 309 | "train_data = data_p[:int(len(data_p)*0.9)]\n", 310 | "test_data = data_p[int(len(data_p)*0.9):]" 311 | ] 312 | }, 313 | { 314 | "cell_type": "markdown", 315 | "metadata": {}, 316 | "source": [ 317 | "### Load Pretrained word vector" 318 | ] 319 | }, 320 | { 321 | "cell_type": "markdown", 322 | "metadata": {}, 323 | "source": [ 324 | "you can download pretrained word vector from here https://github.com/mmihaltz/word2vec-GoogleNews-vectors " 325 | ] 326 | }, 327 | { 328 | "cell_type": "code", 329 | "execution_count": 41, 330 | "metadata": { 331 | "collapsed": true 332 | }, 333 | "outputs": [], 334 | "source": [ 335 | "import gensim" 336 | ] 337 | }, 338 | { 339 | "cell_type": "code", 340 | "execution_count": 43, 341 | "metadata": { 342 | "collapsed": true 343 | }, 344 | "outputs": [], 345 | "source": [ 346 | "model = gensim.models.KeyedVectors.load_word2vec_format('../dataset/GoogleNews-vectors-negative300.bin', binary=True)" 347 | ] 348 | }, 349 | { 350 | "cell_type": "code", 351 
| "execution_count": 48, 352 | "metadata": { 353 | "collapsed": false 354 | }, 355 | "outputs": [ 356 | { 357 | "data": { 358 | "text/plain": [ 359 | "3000000" 360 | ] 361 | }, 362 | "execution_count": 48, 363 | "metadata": {}, 364 | "output_type": "execute_result" 365 | } 366 | ], 367 | "source": [ 368 | "len(model.index2word)" 369 | ] 370 | }, 371 | { 372 | "cell_type": "code", 373 | "execution_count": 96, 374 | "metadata": { 375 | "collapsed": true 376 | }, 377 | "outputs": [], 378 | "source": [ 379 | "pretrained = []\n", 380 | "\n", 381 | "for i in range(len(word2index)):\n", 382 | " try:\n", 383 | " pretrained.append(model[word2index[i]])\n", 384 | " except:\n", 385 | " pretrained.append(np.random.randn(300))\n", 386 | " \n", 387 | "pretrained_vectors = np.vstack(pretrained)" 388 | ] 389 | }, 390 | { 391 | "cell_type": "markdown", 392 | "metadata": {}, 393 | "source": [ 394 | "## Modeling " 395 | ] 396 | }, 397 | { 398 | "cell_type": "markdown", 399 | "metadata": {}, 400 | "source": [ 401 | "\n", 402 | "
Image borrowed from http://www.aclweb.org/anthology/D14-1181
" 403 | ] 404 | }, 405 | { 406 | "cell_type": "code", 407 | "execution_count": 117, 408 | "metadata": { 409 | "collapsed": true 410 | }, 411 | "outputs": [], 412 | "source": [ 413 | "class CNNClassifier(nn.Module):\n", 414 | " \n", 415 | " def __init__(self, vocab_size,embedding_dim,output_size,kernel_dim=100,kernel_sizes=[3,4,5],dropout=0.5):\n", 416 | " super(CNNClassifier,self).__init__()\n", 417 | "\n", 418 | " self.embedding = nn.Embedding(vocab_size, embedding_dim)\n", 419 | " self.convs = nn.ModuleList([nn.Conv2d(1, kernel_dim, (K, embedding_dim)) for K in kernel_sizes])\n", 420 | "\n", 421 | " # kernal_size = (K,D) \n", 422 | " self.dropout = nn.Dropout(dropout)\n", 423 | " self.fc = nn.Linear(len(kernel_sizes)*kernel_dim, output_size)\n", 424 | " \n", 425 | " \n", 426 | " def init_weights(self,pretrained_word_vectors,is_static=False):\n", 427 | " self.embedding.weight = nn.Parameter(torch.from_numpy(pretrained_word_vectors).float())\n", 428 | " if is_static:\n", 429 | " self.embedding.weight.requires_grad = False\n", 430 | "\n", 431 | "\n", 432 | " def forward(self, inputs,is_training=False):\n", 433 | " inputs = self.embedding(inputs).unsqueeze(1) # (B,1,T,D)\n", 434 | " inputs = [F.relu(conv(inputs)).squeeze(3) for conv in self.convs] #[(N,Co,W), ...]*len(Ks)\n", 435 | " inputs = [F.max_pool1d(i, i.size(2)).squeeze(2) for i in inputs] #[(N,Co), ...]*len(Ks)\n", 436 | "\n", 437 | " concated = torch.cat(inputs, 1)\n", 438 | "\n", 439 | " if is_training:\n", 440 | " concated = self.dropout(concated) # (N,len(Ks)*Co)\n", 441 | " out = self.fc(concated) \n", 442 | " return F.log_softmax(out)" 443 | ] 444 | }, 445 | { 446 | "cell_type": "markdown", 447 | "metadata": {}, 448 | "source": [ 449 | "## Train " 450 | ] 451 | }, 452 | { 453 | "cell_type": "markdown", 454 | "metadata": {}, 455 | "source": [ 456 | "It takes for a while if you use just cpu." 
457 | ] 458 | }, 459 | { 460 | "cell_type": "code", 461 | "execution_count": 145, 462 | "metadata": { 463 | "collapsed": true 464 | }, 465 | "outputs": [], 466 | "source": [ 467 | "EPOCH=5\n", 468 | "BATCH_SIZE=50\n", 469 | "KERNEL_SIZES = [3,4,5]\n", 470 | "KERNEL_DIM = 100\n", 471 | "LR = 0.001" 472 | ] 473 | }, 474 | { 475 | "cell_type": "code", 476 | "execution_count": 146, 477 | "metadata": { 478 | "collapsed": true 479 | }, 480 | "outputs": [], 481 | "source": [ 482 | "model = CNNClassifier(len(word2index),300,len(target2index),KERNEL_DIM,KERNEL_SIZES)\n", 483 | "model.init_weights(pretrained_vectors) # initialize embedding matrix using pretrained vectors\n", 484 | "\n", 485 | "if USE_CUDA:\n", 486 | " model = model.cuda()\n", 487 | " \n", 488 | "loss_function = nn.CrossEntropyLoss()\n", 489 | "optimizer = optim.Adam(model.parameters(),lr=LR)" 490 | ] 491 | }, 492 | { 493 | "cell_type": "code", 494 | "execution_count": 147, 495 | "metadata": { 496 | "collapsed": false 497 | }, 498 | "outputs": [ 499 | { 500 | "name": "stdout", 501 | "output_type": "stream", 502 | "text": [ 503 | "[0/5] mean_loss : 2.13\n", 504 | "[1/5] mean_loss : 0.12\n", 505 | "[2/5] mean_loss : 0.08\n", 506 | "[3/5] mean_loss : 0.02\n", 507 | "[4/5] mean_loss : 0.05\n" 508 | ] 509 | } 510 | ], 511 | "source": [ 512 | "for epoch in range(EPOCH):\n", 513 | " losses=[]\n", 514 | " for i,batch in enumerate(getBatch(BATCH_SIZE,train_data)):\n", 515 | " inputs,targets = pad_to_batch(batch)\n", 516 | " \n", 517 | " model.zero_grad()\n", 518 | " preds = model(inputs,True)\n", 519 | " \n", 520 | " loss = loss_function(preds,targets)\n", 521 | " losses.append(loss.data.tolist()[0])\n", 522 | " loss.backward()\n", 523 | " \n", 524 | " #for param in model.parameters():\n", 525 | " # param.grad.data.clamp_(-3, 3)\n", 526 | " \n", 527 | " optimizer.step()\n", 528 | " \n", 529 | " if i % 100==0:\n", 530 | " print(\"[%d/%d] mean_loss : %0.2f\" %(epoch,EPOCH,np.mean(losses)))\n", 531 | " losses=[]" 532 | ] 
533 | }, 534 | { 535 | "cell_type": "markdown", 536 | "metadata": {}, 537 | "source": [ 538 | "## Test " 539 | ] 540 | }, 541 | { 542 | "cell_type": "code", 543 | "execution_count": 150, 544 | "metadata": { 545 | "collapsed": true 546 | }, 547 | "outputs": [], 548 | "source": [ 549 | "accuracy=0" 550 | ] 551 | }, 552 | { 553 | "cell_type": "code", 554 | "execution_count": 151, 555 | "metadata": { 556 | "collapsed": false 557 | }, 558 | "outputs": [ 559 | { 560 | "name": "stdout", 561 | "output_type": "stream", 562 | "text": [ 563 | "97.61904761904762\n" 564 | ] 565 | } 566 | ], 567 | "source": [ 568 | "for test in test_data:\n", 569 | " pred = model(test[0]).max(1)[1]\n", 570 | " pred = pred.data.tolist()[0]\n", 571 | " target = test[1].data.tolist()[0][0]\n", 572 | " if pred == target:\n", 573 | " accuracy+=1\n", 574 | "\n", 575 | "print(accuracy/len(test_data)*100)" 576 | ] 577 | }, 578 | { 579 | "cell_type": "markdown", 580 | "metadata": { 581 | "collapsed": true 582 | }, 583 | "source": [ 584 | "## Further topics " 585 | ] 586 | }, 587 | { 588 | "cell_type": "markdown", 589 | "metadata": {}, 590 | "source": [ 591 | "* Character-Aware Neural Language Models\n", 592 | "* Character level CNN for text classification" 593 | ] 594 | }, 595 | { 596 | "cell_type": "markdown", 597 | "metadata": {}, 598 | "source": [ 599 | "## Suggested Reading" 600 | ] 601 | }, 602 | { 603 | "cell_type": "markdown", 604 | "metadata": {}, 605 | "source": [ 606 | "* https://blog.statsbot.co/text-classifier-algorithms-in-machine-learning-acc115293278\n", 607 | "* Bag of Tricks for Efficient Text Classification\n", 608 | "* Which Encoding is the Best for Text Classification in Chinese, English, Japanese and Korean?" 
609 | ] 610 | }, 611 | { 612 | "cell_type": "code", 613 | "execution_count": null, 614 | "metadata": { 615 | "collapsed": true 616 | }, 617 | "outputs": [], 618 | "source": [] 619 | } 620 | ], 621 | "metadata": { 622 | "kernelspec": { 623 | "display_name": "Python 3", 624 | "language": "python", 625 | "name": "python3" 626 | }, 627 | "language_info": { 628 | "codemirror_mode": { 629 | "name": "ipython", 630 | "version": 3 631 | }, 632 | "file_extension": ".py", 633 | "mimetype": "text/x-python", 634 | "name": "python", 635 | "nbconvert_exporter": "python", 636 | "pygments_lexer": "ipython3", 637 | "version": "3.5.2" 638 | } 639 | }, 640 | "nbformat": 4, 641 | "nbformat_minor": 2 642 | } 643 | -------------------------------------------------------------------------------- /notebooks/09.Recursive-NN-for-Sentiment-Classification.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# 9. Recursive Neural Networks and Constituency Parsing" 8 | ] 9 | }, 10 | { 11 | "cell_type": "markdown", 12 | "metadata": {}, 13 | "source": [ 14 | "I recommend you take a look at these materials first." 
15 | ] 16 | }, 17 | { 18 | "cell_type": "markdown", 19 | "metadata": {}, 20 | "source": [ 21 | "* http://web.stanford.edu/class/cs224n/lectures/cs224n-2017-lecture14-TreeRNNs.pdf\n", 22 | "* https://nlp.stanford.edu/~socherr/EMNLP2013_RNTN.pdf" 23 | ] 24 | }, 25 | { 26 | "cell_type": "code", 27 | "execution_count": 4, 28 | "metadata": { 29 | "collapsed": true 30 | }, 31 | "outputs": [], 32 | "source": [ 33 | "import torch\n", 34 | "import torch.nn as nn\n", 35 | "from torch.autograd import Variable\n", 36 | "import torch.optim as optim\n", 37 | "import torch.nn.functional as F\n", 38 | "import nltk\n", 39 | "import random\n", 40 | "import numpy as np\n", 41 | "from collections import Counter, OrderedDict\n", 42 | "import nltk\n", 43 | "from copy import deepcopy\n", 44 | "import os\n", 45 | "from IPython.display import Image, display\n", 46 | "from nltk.draw import TreeWidget\n", 47 | "from nltk.draw.util import CanvasFrame\n", 48 | "from nltk.tree import Tree as nltkTree\n", 49 | "flatten = lambda l: [item for sublist in l for item in sublist]" 50 | ] 51 | }, 52 | { 53 | "cell_type": "code", 54 | "execution_count": 2, 55 | "metadata": { 56 | "collapsed": true 57 | }, 58 | "outputs": [], 59 | "source": [ 60 | "USE_CUDA = torch.cuda.is_available()\n", 61 | "\n", 62 | "FloatTensor = torch.cuda.FloatTensor if USE_CUDA else torch.FloatTensor\n", 63 | "LongTensor = torch.cuda.LongTensor if USE_CUDA else torch.LongTensor\n", 64 | "ByteTensor = torch.cuda.ByteTensor if USE_CUDA else torch.ByteTensor" 65 | ] 66 | }, 67 | { 68 | "cell_type": "code", 69 | "execution_count": 3, 70 | "metadata": { 71 | "collapsed": true 72 | }, 73 | "outputs": [], 74 | "source": [ 75 | "def getBatch(batch_size,train_data):\n", 76 | " random.shuffle(train_data)\n", 77 | " sindex=0\n", 78 | " eindex=batch_size\n", 79 | " while eindex < len(train_data):\n", 80 | " batch = train_data[sindex:eindex]\n", 81 | " temp = eindex\n", 82 | " eindex = eindex+batch_size\n", 83 | " sindex = temp\n", 84 | " 
yield batch\n", 85 | " \n", 86 | " if eindex >= len(train_data):\n", 87 | " batch = train_data[sindex:]\n", 88 | " yield batch" 89 | ] 90 | }, 91 | { 92 | "cell_type": "code", 93 | "execution_count": 5, 94 | "metadata": { 95 | "collapsed": true 96 | }, 97 | "outputs": [], 98 | "source": [ 99 | "# Borrowed from https://stackoverflow.com/questions/31779707/how-do-you-make-nltk-draw-trees-that-are-inline-in-ipython-jupyter\n", 100 | "\n", 101 | "def draw_nltk_tree(tree):\n", 102 | " cf = CanvasFrame()\n", 103 | " tc = TreeWidget(cf.canvas(), tree)\n", 104 | " tc['node_font'] = 'arial 15 bold'\n", 105 | " tc['leaf_font'] = 'arial 15'\n", 106 | " tc['node_color'] = '#005990'\n", 107 | " tc['leaf_color'] = '#3F8F57'\n", 108 | " tc['line_color'] = '#175252'\n", 109 | " cf.add_widget(tc, 50, 50)\n", 110 | " cf.print_to_file('tmp_tree_output.ps')\n", 111 | " cf.destroy()\n", 112 | " os.system('convert tmp_tree_output.ps tmp_tree_output.png')\n", 113 | " display(Image(filename='tmp_tree_output.png'))\n", 114 | " os.system('rm tmp_tree_output.ps tmp_tree_output.png')" 115 | ] 116 | }, 117 | { 118 | "cell_type": "markdown", 119 | "metadata": {}, 120 | "source": [ 121 | "## Data load and Preprocessing" 122 | ] 123 | }, 124 | { 125 | "cell_type": "markdown", 126 | "metadata": {}, 127 | "source": [ 128 | "### Stanford Sentiment Treebank(https://nlp.stanford.edu/sentiment/index.html)" 129 | ] 130 | }, 131 | { 132 | "cell_type": "code", 133 | "execution_count": 10, 134 | "metadata": { 135 | "collapsed": false 136 | }, 137 | "outputs": [ 138 | { 139 | "name": "stdout", 140 | "output_type": "stream", 141 | "text": [ 142 | "(3 (2 (1 Deflated) (2 (2 ending) (2 aside))) (4 (2 ,) (4 (2 there) (3 (3 (2 's) (3 (2 much) (2 (2 to) (3 (3 recommend) (2 (2 the) (2 film)))))) (2 .)))))\n", 143 | "\n" 144 | ] 145 | } 146 | ], 147 | "source": [ 148 | "sample = random.choice(open('../dataset/trees/train.txt','r',encoding='utf-8').readlines())\n", 149 | "print(sample)" 150 | ] 151 | }, 152 | { 153 | 
"cell_type": "code", 154 | "execution_count": 11, 155 | "metadata": { 156 | "collapsed": false 157 | }, 158 | "outputs": [ 159 | { 160 | "data": { 161 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAhEAAAGVCAMAAAB3iVNzAAAJJGlDQ1BpY2MAAHjalZVnUJNZF8fv\n8zzphUASQodQQ5EqJYCUEFoo0quoQOidUEVsiLgCK4qINEUQUUDBVSmyVkSxsCgoYkE3yCKgrBtX\nERWUF/Sd0Xnf2Q/7n7n3/OY/Z+4995wPFwCCOFgSvLQnJqULvJ3smIFBwUzwg8L4aSkcT0838I96\nPwyg5XhvBfj3IkREpvGX4sLSyuWnCNIBgLKXWDMrPWWZDy8xPTz+K59dZsFSgUt8Y5mjv/Ho15xv\nLPqa4+vNXXoVCgAcKfoHDv+B/3vvslQ4gvTYqMhspk9yVHpWmCCSmbbcCR6Xy/QUJEfFJkT+UPC/\nSv4HpUdmpy9HbnLKBkFsdEw68/8ONTIwNATfZ/HW62uPIUb//85nWd+95HoA2LMAIHu+e+GVAHTu\nAED68XdPbamvlHwAOu7wMwSZ3zzU8oYGBEABdCADFIEq0AS6wAiYAUtgCxyAC/AAviAIrAN8EAMS\ngQBkgVywDRSAIrAH7AdVoBY0gCbQCk6DTnAeXAHXwW1wFwyDJ0AIJsArIALvwTwEQViIDNEgGUgJ\nUod0ICOIDVlDDpAb5A0FQaFQNJQEZUC50HaoCCqFqqA6qAn6BToHXYFuQoPQI2gMmob+hj7BCEyC\n6bACrAHrw2yYA7vCvvBaOBpOhXPgfHg3XAHXwyfgDvgKfBsehoXwK3gWAQgRYSDKiC7CRriIBxKM\nRCECZDNSiJQj9Ugr0o30IfcQITKDfERhUDQUE6WLskQ5o/xQfFQqajOqGFWFOo7qQPWi7qHGUCLU\nFzQZLY/WQVugeehAdDQ6C12ALkc3otvR19DD6An0ewwGw8CwMGYYZ0wQJg6zEVOMOYhpw1zGDGLG\nMbNYLFYGq4O1wnpgw7Dp2AJsJfYE9hJ2CDuB/YAj4pRwRjhHXDAuCZeHK8c14y7ihnCTuHm8OF4d\nb4H3wEfgN+BL8A34bvwd/AR+niBBYBGsCL6EOMI2QgWhlXCNMEp4SyQSVYjmRC9iLHErsYJ4iniD\nOEb8SKKStElcUggpg7SbdIx0mfSI9JZMJmuQbcnB5HTybnIT+Sr5GfmDGE1MT4wnFiG2RaxarENs\nSOw1BU9Rp3Ao6yg5lHLKGcodyow4XlxDnCseJr5ZvFr8nPiI+KwETcJQwkMiUaJYolnipsQUFUvV\noDpQI6j51CPUq9RxGkJTpXFpfNp2WgPtGm2CjqGz6Dx6HL2IfpI+QBdJUiWNJf0lsyWrJS9IChkI\nQ4PBYyQwShinGQ8Yn6QUpDhSkVK7pFqlhqTmpOWkbaUjpQul26SHpT/JMGUcZOJl9sp0yjyVRclq\ny3rJZskekr0mOyNHl7OU48sVyp2WeywPy2vLe8tvlD8i3y8/q6Co4KSQolCpcFVhRpGhaKsYp1im\neFFxWommZK0Uq1SmdEnpJVOSyWEmMCuYvUyRsryys3KGcp3ygPK8CkvFTyVPpU3lqSpBla0apVqm\n2qMqUlNSc1fLVWtRe6yOV2erx6gfUO9Tn9NgaQRo7NTo1JhiSbN4rBxWC2tUk6xpo5mqWa95Xwuj\nxdaK1zqodVcb1jbRjtGu1r6jA+uY6sTqHNQZXIFeYb4iaUX9ihFdki5HN1O3RXdMj6Hnppen16n3\nWl9NP1h/r36f/hcDE4MEgwaDJ4ZUQxfDPMNuw7+NtI34RtVG91eSVzqu3LKya+UbYx3jSONDxg9N\naCbuJjtNekw+m5qZCkxbTafN1MxCzWrMRth0tie
7mH3DHG1uZ77F/Lz5RwtTi3SL0xZ/Wepaxls2\nW06tYq2KXNWwatxKxSrMqs5KaM20DrU+bC20UbYJs6m3eW6rahth22g7ydHixHFOcF7bGdgJ7Nrt\n5rgW3E3cy/aIvZN9of2AA9XBz6HK4ZmjimO0Y4ujyMnEaaPTZWe0s6vzXucRngKPz2viiVzMXDa5\n9LqSXH1cq1yfu2m7Cdy63WF3F/d97qOr1Vcnre70AB48j30eTz1Znqmev3phvDy9qr1eeBt653r3\n+dB81vs0+7z3tfMt8X3ip+mX4dfjT/EP8W/ynwuwDygNEAbqB24KvB0kGxQb1BWMDfYPbgyeXeOw\nZv+aiRCTkIKQB2tZa7PX3lwnuy5h3YX1lPVh68+EokMDQptDF8I8wurDZsN54TXhIj6Xf4D/KsI2\noixiOtIqsjRyMsoqqjRqKtoqel/0dIxNTHnMTCw3tir2TZxzXG3cXLxH/LH4xYSAhLZEXGJo4rkk\nalJ8Um+yYnJ28mCKTkpBijDVInV/qkjgKmhMg9LWpnWl05c+xf4MzYwdGWOZ1pnVmR+y/LPOZEtk\nJ2X3b9DesGvDZI5jztGNqI38jT25yrnbcsc2cTbVbYY2h2/u2aK6JX/LxFanrce3EbbFb/stzyCv\nNO/d9oDt3fkK+Vvzx3c47WgpECsQFIzstNxZ+xPqp9ifBnat3FW560thROGtIoOi8qKFYn7xrZ8N\nf674eXF31O6BEtOSQ3swe5L2PNhrs/d4qURpTun4Pvd9HWXMssKyd/vX779Zblxee4BwIOOAsMKt\noqtSrXJP5UJVTNVwtV11W418za6auYMRB4cO2R5qrVWoLar9dDj28MM6p7qOeo368iOYI5lHXjT4\nN/QdZR9tapRtLGr8fCzpmPC49/HeJrOmpmb55pIWuCWjZfpEyIm7J+1PdrXqtta1MdqKToFTGade\n/hL6y4PTrqd7zrDPtJ5VP1vTTmsv7IA6NnSIOmM6hV1BXYPnXM71dFt2t/+q9+ux88rnqy9IXii5\nSLiYf3HxUs6l2cspl2euRF8Z71nf8+Rq4NX7vV69A9dcr9247nj9ah+n79INqxvnb1rcPHeLfavz\ntuntjn6T/vbfTH5rHzAd6Lhjdqfrrvnd7sFVgxeHbIau3LO/d/0+7/7t4dXDgw/8HjwcCRkRPox4\nOPUo4dGbx5mP559sHUWPFj4Vf1r+TP5Z/e9av7cJTYUXxuzH+p/7PH8yzh9/9UfaHwsT+S/IL8on\nlSabpoymzk87Tt99ueblxKuUV/MzBX9K/FnzWvP12b9s/+oXBYom3gjeLP5d/Fbm7bF3xu96Zj1n\nn71PfD8/V/hB5sPxj+yPfZ8CPk3OZy1gFyo+a33u/uL6ZXQxcXHxPy6ikLxyKdSVAAAAIGNIUk0A\nAHomAACAhAAA+gAAAIDoAAB1MAAA6mAAADqYAAAXcJy6UTwAAACrUExURf///wBZkABZkABZkABZ\nkABZkABZkABZkABZkABZkABZkABZkBdSUhdSUhdTUxdSUhdSUhdSUhdSUhdSUhdSUhdSUhdSUhdS\nUhdSUhdSUhdSUhdSUgBZkABZkBdSUgBZkBdSUj+PVz+PVz+PVz+PVz+PVz+PVz+PVz+PVz+PVz+P\nVz+PVz+PVz+PVz+PVxdSUhdSUhdSUhdTUxdSUhhUVABZkBdSUj+PV////2WyAXMAAAA1dFJOUwAR\niO6ZuyJVZjPdqkR3dZnMu6OIZjMRqt3uVSJEzMd31rsziN2ZIkQRqlXuZsx3r+HSW+wg7fJpnQAA\nAAFiS0dEAIgFHUgAAAAJcEhZcwAAAEgAAABIAEbJaz4AAAAHdElNRQfhCwINNSfD9n+TAAASs0lE\nQVR42u2dCZuiuBqFU9pVXVWtrQ6uZffMZVFwq5m7SP7/P7skiFuxhiAJnPd5ulAgIZpD8sVODoQA\nAAAAAAAAAAA
AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA\nAAAAAAAAAAAAAAAAAAAAaDZPne6x++2p7mIAZfh27D53j891FwMoQ6fzQp6Ox7qLAZTiO9oIcEUn\niCNe6i4EUIjX5273te5CAKV4O3brLgJQhvfjE3lBZAnOPPPR54+6iwGU4ekbfqECAAAAgCi9/s9+\nr+5CADXo9QdD3x/9MfL94QCyaDcnMRjjCX87GRuQRWu5E8MFyKJ1THvGLF4MF06ymBm9ad3FBVXC\nxDD3FyPjY5Lj7MmHMVr4c8iimVzEsCyUbglZNA9RMVyALBpDWJVlxFBNXqAGqrmvy7c3oAaqbuQh\nC4143Pig2LgF1EAdvyHk+W0D1EDNvzMm/v4JakCZH50hCxUYqyGGCydZjOsuR2vpKSQG1UsFAABA\nUZ5+/aq7CK1EXReQH1geVgcdZV1Afh2hiDp47bwp6gLy3lWyWK3g5fhedxG+0jl+hyJq4un3UT3P\nh5fuDwJF1MPbu4KCID+6L1BEPbx2uwoO8t6Ovzud47GjoFabztvxXUXjKDbQYKg4CGo478f354C3\nussRB3qNOjjdiwp2HFAEAAAAoAG9n0pOTZmOf46x9uvxLI354s/F3FBrycT0wxj68z/n/tD4gCoe\nyHQ88kfBjRht1aBnjPzFbMw0uhzPFv7IULINayC9wVXbwNqKQf3f/KQ/8/1Zf5K+C1RAjAJuFFJH\nkZIbhKtmA1RBUi9RX+8xHQ8ygoYwtBgo07c1idS2oIbeI39dZ+sGFCZHjT+09yjcHyDYlEneXuFB\nvYdwzIhgUw6F7v2qe4/StzqCzZII1HBlvYescADBpjCivUAFvYfsWkSwKUCpe11q71FRS49gswgS\nalRO71FxNIhgMxeyWv2y+TzoJkawmYUhMTJkbY0hWIwHdvRhmCJY0ObTk3tTimbXe/BgYDpGRAEA\nAIVpryFHez95Kr/e5S53EHQeqcOwBFYkcfx1fJb6vQg6j9RhWAIrklj+epW7JErQeaQOwxJYkSQh\n/XsRdB55sGEJrEgSkf29CDqPPNiwBFYkyUj+XgSdRx5tWAIrkmTkfi+CziOPNiyBFUkKUhUh6Dzy\ncMMSWJEk8YvfKR1p96eg80g9hiXoNWLohLdKR1Z+gs4j9RiWQBEAAACAaix/6jFfrdf/2cdzqGPx\npebWX/xLbF6d3GKkMflgDw/1R3+MfH+OR4h+xZd4Ty9HizH5WAwFJqzJLEYS015/wGQwix4wvPzg\njxAdDvDY+it8efMNjcWMfbFTwzcKf8ESixEHe34xaxiMmCe99canx9Zj6iVHWlVMhouP08ve5eXD\ni3HPqWFgzy9epp5msIdFDmd4QLmsqjD82VXDYCxGxb7YKhTBnlg7L3TzXzUl7e1G5FTFZDi/zScI\nKQpFmHIVcbnje8Xv+KBVMdocdY4kVEVs5FAswpRRDM5EUlTQ3qhTQlX05vO4TApFmBKKETUMMuuw\njVFn6apIqfgCEWa5Ylwqrop2vmVR56jk+sfefJhSC7kjTNFiLHuPatxbE3WWU8R05qenzxthChQj\nbBjmFTUMCZ+3BVFnKUUE0WPmF5MvwixUjFPDMBrUdLs2O+osoYjpLNf9nyvCzFsMhe7RpkadA2FF\njHP/CpUjwswuxvL0P1R1NQzxNDDqNEQVMVj0C1xlMShVjPHpf6hqbxjiiaLOcd0FqZVivwYuS9pc\n9dVvmac6FBIAAEQxqUlMK+0EK3deQDdsGuCs1le7XGp5hLqXHeZ9ouuDdzx9K+QBUvB0qakfxose\nxTxh0+CPtaHbyy6X7bqqdIveJ0pRxO9iHiAFT5ea+mGwlWkaFPOEHVb3ZucRsnXobuW5QaNhskr3\nVkHrsSXsvX06xs9xtimK6HSeyEv+RVEFT5ea+lG8sGI+2hdFnJMiLLon7s4l6
8Pq3EasHI/sqcff\nR8e8QBbeJkUR/Cso5gFSzjLkwYYjovylWxtBgmZgswq2W3pWhLcOt+x9dGxPg53rdEU8/T5+L1CC\ngqdLTf0ono/vesURhCvCpBw3UsR6w99yRUTHwtNTFfGr2y3ivlDwdKmpH0YnkIQulojnXmNLTDvc\nFSnC3KzPbUR0LFsRr8ffRRb3FzxdaupH8qZH50YukaUT/DuEuyJFsHpfR73G6RjvNawURQR1VKR5\nLHi61NSP4lf3XYcAOOI0+txZTAmfwfhic1aEY5O1ySJOHl2Gx9a7FRuDJCoi+OC/mQdIzqvnOd2W\ndbG6eOry0ee3usuRE/4L1S78hSoYWdKgp4gU4Tr0YNlBgHEIR5/sWDj6TP7NMnLtyXn1PKdTW9LF\nauPlx1GjX6jUx/TqLgFQCm9TdwkAAKBaemOjaVMuC1FwNp7w5D0Jqatm0jPYHFx/9PffI59NxTV6\nis7+qxQogs2jGxujIZ8a3o8mGS7ZbPG57w/ZHG2FpghXT7sVMfkwmCuFPxokzAbOPKF5tFQRxZqA\nuEaksbROEUHlzgTDhHOgMWty6NkeRUirzxKa0oEWKKKaNr+xoWejFfGAuLB5oWczFfHoG7hJoWfT\nFFFrJ9+I0LMxilCoNvQOPfVXhKottq6hp86K0CKq06KQ12ipCP1uP1Ubshi0U8RU4y76KtjRQsea\nUHvgWJ6gwai7CAAAIIZlZqwDvMGll1QFLgA0gq8lzz2l3nMvqfJg5j81F3rZfSShtreKWIXlTMXs\nTaQqQi+7jySU8VZhpiOHLSEb16S7T8KXiJmflPcal307utleKtFmq8MsQvZsG7pTRKmi7JLh9iYu\nDdI61uX81dah/DqOTQqimd1HEsp4q5grj3zSNaGHNd/azGTACRVx3ues2YLBKMme7oOaNPkaY2/D\nV5meU0XZpRDamwRKYmuXL5ffu8QOrufu9iIfQx+7jzRU8FaxeO0dbEJZRbDFw8x45NRGnPeFbccl\nEeH1atFw8WjwMkp1zi4Frgh2nnt1Pl8aysVgCwSdOtl9pKCEt4ob+oyY4dCCqcAmYaWdzAbO20vX\n760cloaJ5GDzFennVOfsUi95yuz6/PBSIcU/hU52H8mo4a1yrucCijAP1unter/a2beKyHfJiyLu\nLi+KPnYfiSjirWJFpodRldz0GtE+3pqfK5vvjt6GrleXXmObeclrRdxeflc4qmRoZveRhDLeKiaP\n5qxz7ds7N4wirxSxOtxElkH1e5876n2ywYLt8MgyShVll4J7Ms8LvZCuL29zO73Pgp9AM7uPBNTx\nVuGWl5+XNoK9N/fmjSLYvqvRp8WMMdeOGY1Cg5o9p4qyS+PARp8kVMTN5dnoc2cXNptohN2HLt4q\nV2x3dZcAKANzyGXmqACcsE8GygDowbTfx/QjcGYyWPzzz2Kg4Xy6O1RYE90AxiN/1iOkN9P/2WrG\noHwebWdpLBbG8utrPTFGdZdAd760C6f2QlegiFJM+/OY2CGIKebaRplQRAmSaz5eKVoARQiT0Tvo\nGmX2oAgh8kSQekaZUIQIue9/DaNMKKIwxWIE7aLMnl93CTSjeA1rFmVCEYUQ7AV0ijKhiPyUiRT1\niTKhiLyUvs81iTKXUEQuljJiARaDqN9OQBH5kDNemPbr/hwAAAAAeCyyfDaa4ToC5PlsNMN1BEjz\n2WiI6wgIkeWz0QzXESDLZ6MhriNAms9GM1xHgEyfjQa4jgBZPhsNcR0B0nw29HAdcYU8dtqFNJ8N\nLVxHbJm+rwAAAACQB1xHGoMUn42Pmf/PP/7so+4PAyRQ3mdjaYRTNdlkSz1mZYM0Sq6Yno5H/mg8\njXkDNKWUIr42C1GDAbRFXBHT/jw2dAiCCr0WhIIbRBWRVu9JWgE6ILSGPrtvQJipLcUVkTN+RJip\nKUUVUeTmR5ipI4VWTBcPEBBmakcBRYjVL
sJMzciriDI9AMJMnciliNJRIsJMfcjhsyHnFkeYqQtZ\niljKCwNYIILOQ39kDhVgPAIAAPpiiT83GjSSVeYKBlmWI1XkBuSTvaZFluVIFbkB+bgmsQ6U8kfZ\nxyPJcqSS3EAlODYhe5r6bHpZliNV5Aaks/vMOkOS5UgluQH57HfOapt2gizLkSpyA1XgbW3nkHxY\nouWI9NxAVXg08XcJSZYjleQGKmEXRJZbuk44KstypIrcQDWw0echMZCQZjlSQW4AAAAAaBnLfh+T\nqHRHiuUIZ9kf+v/+tz+EKPRGkiLYbGyuBaYLiEJnZChiOp75V9O5IQqtKa2IUA538/IhCn0pp4hY\nOYRAFJpSRhEfbO1PyqodiEJHhBURyGExyHx6NEShHWKK4HLIuRQMotALAUX0CsghBKLQiKKKmLAV\nvwILRSEKXSikiEAO/kzYFQCi0IL8iignhxCIQn1yKmIpQQ6nnCAKtcmjCFaJMi1kIAqVyVbEsorq\n46KAJjSlort5CfsRAADQGcusuwRALdwYv4lORSYhT1VlDCQSo4hOVSYh3+A+oggb16S7veVQxw0V\nwP5YG0o3XvByH+y/MZt47byRpyqWaXU6L9VkDApCD2vySc01WTlnRXjOyvM2JnF3gSw2zn2SykxC\nvqONUAC6P/0JlBApgi8VXrM2Yx3TdTz9PlbiCRF0SN9e6v46AOHOAezPlSIiB7NzL3LN23s1giCv\nz7AfUYGiinjtdn9VVZa3Y7furwN8VcSe9RpBNLm24xTxdnyvpmV/P8LXTg2uFbGmn2R9YJHlJogs\nD3GKeD++M5MQ+V5Cz3z0+aPurwPcKIJ87qjpmlejT3KviJNJiPyOA/64AAAAAChB72fmijCgJEWe\nIZub6Xjo/+kP8fw/HfHlT6pbGou5MSVTY77Ag0L1w5fduH+M/NFHzGugCXIV8aVdOLUXQB9kKqI3\niIkdWEyR7TsAlEGaIlJqPlYpQFFGchSR0TsgytQHKYrIE0EiytSE8orIff8jytSCUUl/w0IxAqJM\nDSilCIEaRpSpOiUUIdgLIMpUG2FFlIkUEWUqzEBIEaXvc0SZyiLkeGlIiAVYDCLtWQ+gZnpyxguS\nsgEAAADA4/GEXGekuYXAdUQ5tlQklSy3kMrsTIAoNqXUJZ8OpRurQDJZbiGV2ZkAYfhSQLon3mbn\nFUspyy2kMjsTIARTxGETvLDotkg6aW4hVdmZAEGYIqjNXoV/8yLLLaQyOxMgiKgiJLmFVGlnAoQQ\n6zVkuYVUZmcChHGpR/ZBZLk2nQKpZLmFVGZnAoTxDtHoc10glSy3kMrsTAAAAADQBqb/wby6hiDH\nhmQ6/O8QkmgGUhSxHA7/N8Tz3ZqBDEVMFoPg72AxqfvDAAlIUEQoCEiiIZRXRCQISKIZlFbEx+Ky\nVMNYYHmX9pRVxNgfJ74DOlJSEfcSgCS0p5wivgoAktCdUoqIixuu4wqgIWUUET+2uIw9gI6UUETS\nYBOS0BpxRST/+gBJ6IyoIqZp/4+xHOL/vbRFVBGj1DqfDkd1fzDwYCbpjcAUv2cDAEBrMYusOAct\ngD+O9oosBxE4jDSdO0VkOYjAYURJtuY+qMitQ3crjxDboc6e8O3OJsQ1XYeaa5PutkGfsDWpY9mU\nLyKOEphucPCT+RnR3f5OEVkOInAYURKXblzP3blkfViRz2DrBvVqB//cQBJuUPtbtiRw5QQtgMm2\nq2DHmkQJCD2syWewY7WzvNV9r0GyHUTgMKIcLg3CwU1Qt8ykitUxcdfhNlABWzlM6P5kM7Anpx3u\nOQHfyXYESiHrr4rIchCBw4h6sLomJuW4kasE39o0PMjqmSvitGUvogRh6MC2NvkaWWY6iMBhREFC\nRVwrgeRRRGRIkqqILAcROIyoCK/jzSF8w91G9m64XR2SFRElOCsirtfIchCBw4iS8Dp26SfxVpsg\nRmSR5
TaMLKmdrIgowVkRq93as+8UkeUgAocRJeF1zAaT3GWEjT6DsSTf2iRZEecEkSL46NO8tTLK\nchCBwwgAAAAAAMhP1mw8OU4lQB8yFQF/kZaRqQg8yq1lQBHgFigC3AJFgFuyFDGCIloGFAFugSLA\nLZmKgK9Iy4AiwC1QBLgFigC3ZCliAEW0jCxFGFAEAAAAAL5gmcnHTGrylWFwIGkNZrQQJBaXWh5x\nPRKzlhw0E4tmKCLcQhFtwaWU2i7dO9QJ+gVvRelhe3eY9xpsueC1ZwloLKwRcHcbj2ycoAdZedx0\n5PZwpIhrzxLQWLgimAiCFxYXw8G+O3xWxJVnCWgsLj0vOud9BO8mbg+fFUGu1hiDpnKriNjDUESr\nuFaERbdxh6GIVsGig0gRxDwwbzvr5jCBIlrGgY0+SVjPbPTJbUrOQBEAAAAAAAAAAABQnf8DmwKy\nJJha+OIAAAAldEVYdGRhdGU6Y3JlYXRlADIwMTctMTEtMDJUMjI6NTM6MzkrMDk6MDBXyJTvAAAA\nJXRFWHRkYXRlOm1vZGlmeQAyMDE3LTExLTAyVDIyOjUzOjM5KzA5OjAwJpUsUwAAACN0RVh0cHM6\nSGlSZXNCb3VuZGluZ0JveAA1Mjl4NDA1LTI2NC0yMDKzNU7+AAAAHHRFWHRwczpMZXZlbABBZG9i\nZS0zLjAgRVBTRi0zLjAKm3C74wAAACJ0RVh0cHM6U3BvdENvbG9yLTAAZm9udCBMaWJlcmF0aW9u\nU2Fuc/4Zp8YAAAAASUVORK5CYII=\n", 162 | "text/plain": [ 163 | "" 164 | ] 165 | }, 166 | "metadata": {}, 167 | "output_type": "display_data" 168 | } 169 | ], 170 | "source": [ 171 | "draw_nltk_tree(nltkTree.fromstring(sample))" 172 | ] 173 | }, 174 | { 175 | "cell_type": "markdown", 176 | "metadata": {}, 177 | "source": [ 178 | "### Tree Class " 179 | ] 180 | }, 181 | { 182 | "cell_type": "markdown", 183 | "metadata": {}, 184 | "source": [ 185 | "borrowed code from https://github.com/bogatyy/cs224d/tree/master/assignment3" 186 | ] 187 | }, 188 | { 189 | "cell_type": "code", 190 | "execution_count": 10, 191 | "metadata": { 192 | "collapsed": false 193 | }, 194 | "outputs": [], 195 | "source": [ 196 | "class Node: # a node in the tree\n", 197 | " def __init__(self, label, word=None):\n", 198 | " self.label = label\n", 199 | " self.word = word\n", 200 | " self.parent = None # reference to parent\n", 201 | " self.left = None # reference to left child\n", 202 | " self.right = None # reference to right child\n", 203 | " # true if I am a leaf (could have probably derived this from if I have\n", 204 | " # a word)\n", 205 | " self.isLeaf = False\n", 206 | " # true if we have finished performing fowardprop on this 
node (note,\n", 207 | " # there are many ways to implement the recursion.. some might not\n", 208 | " # require this flag)\n", 209 | "\n", 210 | " def __str__(self):\n", 211 | " if self.isLeaf:\n", 212 | " return '[{0}:{1}]'.format(self.word, self.label)\n", 213 | " return '({0} <- [{1}:{2}] -> {3})'.format(self.left, self.word, self.label, self.right)\n", 214 | "\n", 215 | "\n", 216 | "class Tree:\n", 217 | "\n", 218 | " def __init__(self, treeString, openChar='(', closeChar=')'):\n", 219 | " tokens = []\n", 220 | " self.open = '('\n", 221 | " self.close = ')'\n", 222 | " for toks in treeString.strip().split():\n", 223 | " tokens += list(toks)\n", 224 | " self.root = self.parse(tokens)\n", 225 | " # get list of labels as obtained through a post-order traversal\n", 226 | " self.labels = get_labels(self.root)\n", 227 | " self.num_words = len(self.labels)\n", 228 | "\n", 229 | " def parse(self, tokens, parent=None):\n", 230 | " assert tokens[0] == self.open, \"Malformed tree\"\n", 231 | " assert tokens[-1] == self.close, \"Malformed tree\"\n", 232 | "\n", 233 | " split = 2 # position after open and label\n", 234 | " countOpen = countClose = 0\n", 235 | "\n", 236 | " if tokens[split] == self.open:\n", 237 | " countOpen += 1\n", 238 | " split += 1\n", 239 | " # Find where left child and right child split\n", 240 | " while countOpen != countClose:\n", 241 | " if tokens[split] == self.open:\n", 242 | " countOpen += 1\n", 243 | " if tokens[split] == self.close:\n", 244 | " countClose += 1\n", 245 | " split += 1\n", 246 | "\n", 247 | " # New node\n", 248 | " node = Node(int(tokens[1])) # zero index labels\n", 249 | "\n", 250 | " node.parent = parent\n", 251 | "\n", 252 | " # leaf Node\n", 253 | " if countOpen == 0:\n", 254 | " node.word = ''.join(tokens[2:-1]).lower() # lower case?\n", 255 | " node.isLeaf = True\n", 256 | " return node\n", 257 | "\n", 258 | " node.left = self.parse(tokens[2:split], parent=node)\n", 259 | " node.right = self.parse(tokens[split:-1], 
parent=node)\n", 260 | "\n", 261 | " return node\n", 262 | "\n", 263 | " def get_words(self):\n", 264 | " leaves = getLeaves(self.root)\n", 265 | " words = [node.word for node in leaves]\n", 266 | " return words\n", 267 | "\n", 268 | "def get_labels(node):\n", 269 | " if node is None:\n", 270 | " return []\n", 271 | " return get_labels(node.left) + get_labels(node.right) + [node.label]\n", 272 | "\n", 273 | "def getLeaves(node):\n", 274 | " if node is None:\n", 275 | " return []\n", 276 | " if node.isLeaf:\n", 277 | " return [node]\n", 278 | " else:\n", 279 | " return getLeaves(node.left) + getLeaves(node.right)\n", 280 | "\n", 281 | " \n", 282 | "def loadTrees(dataSet='train'):\n", 283 | " \"\"\"\n", 284 | " Loads training trees. Maps leaf node words to word ids.\n", 285 | " \"\"\"\n", 286 | " file = '../dataset/trees/%s.txt' % dataSet\n", 287 | " print(\"Loading %s trees..\" % dataSet)\n", 288 | " with open(file, 'r',encoding='utf-8') as fid:\n", 289 | " trees = [Tree(l) for l in fid.readlines()]\n", 290 | "\n", 291 | " return trees" 292 | ] 293 | }, 294 | { 295 | "cell_type": "code", 296 | "execution_count": 11, 297 | "metadata": { 298 | "collapsed": false 299 | }, 300 | "outputs": [ 301 | { 302 | "name": "stdout", 303 | "output_type": "stream", 304 | "text": [ 305 | "Loading train trees..\n" 306 | ] 307 | } 308 | ], 309 | "source": [ 310 | "train_data = loadTrees('train')" 311 | ] 312 | }, 313 | { 314 | "cell_type": "markdown", 315 | "metadata": {}, 316 | "source": [ 317 | "### Build Vocab " 318 | ] 319 | }, 320 | { 321 | "cell_type": "code", 322 | "execution_count": 12, 323 | "metadata": { 324 | "collapsed": false 325 | }, 326 | "outputs": [], 327 | "source": [ 328 | "vocab = list(set(flatten([t.get_words() for t in train_data])))" 329 | ] 330 | }, 331 | { 332 | "cell_type": "code", 333 | "execution_count": 13, 334 | "metadata": { 335 | "collapsed": true 336 | }, 337 | "outputs": [], 338 | "source": [ 339 | "word2index={'':0}\n", 340 | "for vo in vocab:\n", 
341 | " if vo not in word2index.keys():\n", 342 | " word2index[vo]=len(word2index)\n", 343 | " \n", 344 | "index2word = {v:k for k,v in word2index.items()}" 345 | ] 346 | }, 347 | { 348 | "cell_type": "markdown", 349 | "metadata": {}, 350 | "source": [ 351 | "## Modeling " 352 | ] 353 | }, 354 | { 355 | "cell_type": "markdown", 356 | "metadata": {}, 357 | "source": [ 358 | "\n", 359 | "
borrowed image from https://nlp.stanford.edu/~socherr/EMNLP2013_RNTN.pdf
" 360 | ] 361 | }, 362 | { 363 | "cell_type": "code", 364 | "execution_count": 14, 365 | "metadata": { 366 | "collapsed": false 367 | }, 368 | "outputs": [], 369 | "source": [ 370 | "class RNTN(nn.Module):\n", 371 | " \n", 372 | " def __init__(self,word2index,hidden_size,output_size):\n", 373 | " super(RNTN,self).__init__()\n", 374 | " \n", 375 | " self.word2index = word2index\n", 376 | " self.embed = nn.Embedding(len(word2index),hidden_size)\n", 377 | "# self.V = nn.ModuleList([nn.Linear(hidden_size*2,hidden_size*2) for _ in range(hidden_size)])\n", 378 | "# self.W = nn.Linear(hidden_size*2,hidden_size)\n", 379 | " self.V = nn.ParameterList([nn.Parameter(torch.randn(hidden_size*2,hidden_size*2)) for _ in range(hidden_size)]) # Tensor\n", 380 | " self.W = nn.Parameter(torch.randn(hidden_size*2,hidden_size))\n", 381 | " self.b = nn.Parameter(torch.randn(1,hidden_size))\n", 382 | "# self.W_out = nn.Parameter(torch.randn(hidden_size,output_size))\n", 383 | " self.W_out = nn.Linear(hidden_size,output_size)\n", 384 | " \n", 385 | " def init_weight(self):\n", 386 | " nn.init.xavier_uniform(self.embed.state_dict()['weight'])\n", 387 | " nn.init.xavier_uniform(self.W_out.state_dict()['weight'])\n", 388 | " for param in self.V.parameters():\n", 389 | " nn.init.xavier_uniform(param)\n", 390 | " nn.init.xavier_uniform(self.W)\n", 391 | " self.b.data.fill_(0)\n", 392 | "# nn.init.xavier_uniform(self.W_out)\n", 393 | " \n", 394 | " def tree_propagation(self,node):\n", 395 | " \n", 396 | " recursive_tensor = OrderedDict()\n", 397 | " current=None\n", 398 | " if node.isLeaf:\n", 399 | " tensor = Variable(LongTensor([self.word2index[node.word]])) if node.word in self.word2index.keys() \\\n", 400 | " else Variable(LongTensor([self.word2index['']]))\n", 401 | " current = self.embed(tensor) # 1xD\n", 402 | " else:\n", 403 | " recursive_tensor.update(self.tree_propagation(node.left))\n", 404 | " recursive_tensor.update(self.tree_propagation(node.right))\n", 405 | " \n", 406 | " 
concated = torch.cat([recursive_tensor[node.left],recursive_tensor[node.right]],1) # 1x2D\n", 407 | " xVx=[] \n", 408 | " for i,v in enumerate(self.V):\n", 409 | "# xVx.append(torch.matmul(v(concated),concated.transpose(0,1)))\n", 410 | " xVx.append(torch.matmul(torch.matmul(concated,v),concated.transpose(0,1)))\n", 411 | " \n", 412 | " xVx = torch.cat(xVx,1) # 1xD\n", 413 | "# Wx = self.W(concated)\n", 414 | " Wx = torch.matmul(concated,self.W) # 1xD\n", 415 | "\n", 416 | " current = F.tanh(xVx+Wx+self.b) # 1xD\n", 417 | " recursive_tensor[node]=current\n", 418 | " return recursive_tensor\n", 419 | " \n", 420 | " def forward(self,Trees,root_only=False):\n", 421 | " \n", 422 | " propagated=[]\n", 423 | " if not isinstance(Trees,list):\n", 424 | " Trees = [Trees]\n", 425 | " \n", 426 | " for Tree in Trees:\n", 427 | " recursive_tensor = self.tree_propagation(Tree.root)\n", 428 | " if root_only:\n", 429 | " recursive_tensor = recursive_tensor[Tree.root]\n", 430 | " propagated.append(recursive_tensor)\n", 431 | " else:\n", 432 | " recursive_tensor = [tensor for node,tensor in recursive_tensor.items()]\n", 433 | " propagated.extend(recursive_tensor)\n", 434 | " \n", 435 | " propagated = torch.cat(propagated) # (num_of_node in batch, D)\n", 436 | " \n", 437 | "# return F.log_softmax(propagated.matmul(self.W_out))\n", 438 | " return F.log_softmax(self.W_out(propagated))" 439 | ] 440 | }, 441 | { 442 | "cell_type": "markdown", 443 | "metadata": {}, 444 | "source": [ 445 | "## Training " 446 | ] 447 | }, 448 | { 449 | "cell_type": "markdown", 450 | "metadata": {}, 451 | "source": [ 452 | "Training takes a while... The model builds its computational graph dynamically, so its computation is difficult to train in batches." 
453 | ] 454 | }, 455 | { 456 | "cell_type": "code", 457 | "execution_count": 15, 458 | "metadata": { 459 | "collapsed": true 460 | }, 461 | "outputs": [], 462 | "source": [ 463 | "HIDDEN_SIZE = 30\n", 464 | "ROOT_ONLY = False\n", 465 | "BATCH_SIZE = 20\n", 466 | "EPOCH = 20\n", 467 | "LR = 0.01\n", 468 | "LAMBDA = 1e-5\n", 469 | "RESCHEDULED=False" 470 | ] 471 | }, 472 | { 473 | "cell_type": "code", 474 | "execution_count": 18, 475 | "metadata": { 476 | "collapsed": false 477 | }, 478 | "outputs": [], 479 | "source": [ 480 | "model = RNTN(word2index,HIDDEN_SIZE,5)\n", 481 | "model.init_weight()\n", 482 | "if USE_CUDA:\n", 483 | " model = model.cuda()\n", 484 | "\n", 485 | "loss_function = nn.CrossEntropyLoss()\n", 486 | "optimizer = optim.Adam(model.parameters(),lr=LR)" 487 | ] 488 | }, 489 | { 490 | "cell_type": "code", 491 | "execution_count": 19, 492 | "metadata": { 493 | "collapsed": false 494 | }, 495 | "outputs": [ 496 | { 497 | "name": "stdout", 498 | "output_type": "stream", 499 | "text": [ 500 | "[0/20] mean_loss : 1.62\n", 501 | "[0/20] mean_loss : 1.25\n", 502 | "[0/20] mean_loss : 0.95\n", 503 | "[0/20] mean_loss : 0.90\n", 504 | "[0/20] mean_loss : 0.88\n", 505 | "[1/20] mean_loss : 0.88\n", 506 | "[1/20] mean_loss : 0.84\n", 507 | "[1/20] mean_loss : 0.83\n", 508 | "[1/20] mean_loss : 0.82\n", 509 | "[1/20] mean_loss : 0.82\n", 510 | "[2/20] mean_loss : 0.81\n", 511 | "[2/20] mean_loss : 0.79\n", 512 | "[2/20] mean_loss : 0.78\n", 513 | "[2/20] mean_loss : 0.76\n", 514 | "[2/20] mean_loss : 0.75\n", 515 | "[3/20] mean_loss : 0.68\n", 516 | "[3/20] mean_loss : 0.73\n", 517 | "[3/20] mean_loss : 0.74\n", 518 | "[3/20] mean_loss : 0.72\n", 519 | "[3/20] mean_loss : 0.72\n", 520 | "[4/20] mean_loss : 0.74\n", 521 | "[4/20] mean_loss : 0.69\n", 522 | "[4/20] mean_loss : 0.69\n", 523 | "[4/20] mean_loss : 0.68\n", 524 | "[4/20] mean_loss : 0.67\n", 525 | "[5/20] mean_loss : 0.73\n", 526 | "[5/20] mean_loss : 0.65\n", 527 | "[5/20] mean_loss : 0.64\n", 528 | 
"[5/20] mean_loss : 0.64\n", 529 | "[5/20] mean_loss : 0.65\n", 530 | "[6/20] mean_loss : 0.67\n", 531 | "[6/20] mean_loss : 0.62\n", 532 | "[6/20] mean_loss : 0.62\n", 533 | "[6/20] mean_loss : 0.62\n", 534 | "[6/20] mean_loss : 0.62\n", 535 | "[7/20] mean_loss : 0.57\n", 536 | "[7/20] mean_loss : 0.59\n", 537 | "[7/20] mean_loss : 0.59\n", 538 | "[7/20] mean_loss : 0.59\n", 539 | "[7/20] mean_loss : 0.59\n", 540 | "[8/20] mean_loss : 0.60\n", 541 | "[8/20] mean_loss : 0.58\n", 542 | "[8/20] mean_loss : 0.59\n", 543 | "[8/20] mean_loss : 0.60\n", 544 | "[8/20] mean_loss : 0.60\n", 545 | "[9/20] mean_loss : 0.52\n", 546 | "[9/20] mean_loss : 0.58\n", 547 | "[9/20] mean_loss : 0.60\n", 548 | "[9/20] mean_loss : 0.59\n", 549 | "[9/20] mean_loss : 0.59\n", 550 | "[10/20] mean_loss : 0.56\n", 551 | "[10/20] mean_loss : 0.56\n", 552 | "[10/20] mean_loss : 0.56\n", 553 | "[10/20] mean_loss : 0.56\n", 554 | "[10/20] mean_loss : 0.56\n", 555 | "[11/20] mean_loss : 0.52\n", 556 | "[11/20] mean_loss : 0.54\n", 557 | "[11/20] mean_loss : 0.54\n", 558 | "[11/20] mean_loss : 0.54\n", 559 | "[11/20] mean_loss : 0.55\n", 560 | "[12/20] mean_loss : 0.55\n", 561 | "[12/20] mean_loss : 0.53\n", 562 | "[12/20] mean_loss : 0.53\n", 563 | "[12/20] mean_loss : 0.53\n", 564 | "[12/20] mean_loss : 0.53\n", 565 | "[13/20] mean_loss : 0.59\n", 566 | "[13/20] mean_loss : 0.52\n", 567 | "[13/20] mean_loss : 0.52\n", 568 | "[13/20] mean_loss : 0.53\n", 569 | "[13/20] mean_loss : 0.53\n", 570 | "[14/20] mean_loss : 0.49\n", 571 | "[14/20] mean_loss : 0.51\n", 572 | "[14/20] mean_loss : 0.51\n", 573 | "[14/20] mean_loss : 0.52\n", 574 | "[14/20] mean_loss : 0.52\n", 575 | "[15/20] mean_loss : 0.43\n", 576 | "[15/20] mean_loss : 0.51\n", 577 | "[15/20] mean_loss : 0.51\n", 578 | "[15/20] mean_loss : 0.51\n", 579 | "[15/20] mean_loss : 0.51\n", 580 | "[16/20] mean_loss : 0.46\n", 581 | "[16/20] mean_loss : 0.50\n", 582 | "[16/20] mean_loss : 0.50\n", 583 | "[16/20] mean_loss : 0.50\n", 584 | 
"[16/20] mean_loss : 0.50\n", 585 | "[17/20] mean_loss : 0.50\n", 586 | "[17/20] mean_loss : 0.50\n", 587 | "[17/20] mean_loss : 0.50\n", 588 | "[17/20] mean_loss : 0.50\n", 589 | "[17/20] mean_loss : 0.51\n", 590 | "[18/20] mean_loss : 0.46\n", 591 | "[18/20] mean_loss : 0.50\n", 592 | "[18/20] mean_loss : 0.50\n", 593 | "[18/20] mean_loss : 0.49\n", 594 | "[18/20] mean_loss : 0.49\n", 595 | "[19/20] mean_loss : 0.49\n", 596 | "[19/20] mean_loss : 0.49\n", 597 | "[19/20] mean_loss : 0.49\n", 598 | "[19/20] mean_loss : 0.50\n", 599 | "[19/20] mean_loss : 0.50\n" 600 | ] 601 | } 602 | ], 603 | "source": [ 604 | "for epoch in range(EPOCH):\n", 605 | " losses=[]\n", 606 | " \n", 607 | " # learning rate annealing\n", 608 | " if RESCHEDULED==False and epoch==EPOCH//2:\n", 609 | " LR=LR*0.1\n", 610 | " optimizer = optim.Adam(model.parameters(),lr=LR,weight_decay=LAMBDA) # L2 norm\n", 611 | " RESCHEDULED=True\n", 612 | " \n", 613 | " for i, batch in enumerate(getBatch(BATCH_SIZE,train_data)):\n", 614 | " \n", 615 | " if ROOT_ONLY:\n", 616 | " labels = [tree.labels[-1] for tree in batch]\n", 617 | " labels = Variable(LongTensor(labels))\n", 618 | " else:\n", 619 | " labels = [tree.labels for tree in batch]\n", 620 | " labels = Variable(LongTensor(flatten(labels)))\n", 621 | " \n", 622 | " model.zero_grad()\n", 623 | " preds = model(batch,ROOT_ONLY)\n", 624 | " \n", 625 | " loss = loss_function(preds,labels)\n", 626 | " losses.append(loss.data.tolist()[0])\n", 627 | " \n", 628 | " loss.backward()\n", 629 | " optimizer.step()\n", 630 | " \n", 631 | " if i % 100==0:\n", 632 | " print('[%d/%d] mean_loss : %.2f' % (epoch,EPOCH,np.mean(losses)))\n", 633 | " losses=[]\n", 634 | " " 635 | ] 636 | }, 637 | { 638 | "cell_type": "markdown", 639 | "metadata": {}, 640 | "source": [ 641 | "The convergence of the model is unstable and depends on the initial values. I had to run the training 5~6 times to get this result." 
642 | ] 643 | }, 644 | { 645 | "cell_type": "markdown", 646 | "metadata": {}, 647 | "source": [ 648 | "## Test" 649 | ] 650 | }, 651 | { 652 | "cell_type": "code", 653 | "execution_count": 20, 654 | "metadata": { 655 | "collapsed": false 656 | }, 657 | "outputs": [ 658 | { 659 | "name": "stdout", 660 | "output_type": "stream", 661 | "text": [ 662 | "Loading test trees..\n" 663 | ] 664 | } 665 | ], 666 | "source": [ 667 | "test_data = loadTrees('test')" 668 | ] 669 | }, 670 | { 671 | "cell_type": "code", 672 | "execution_count": 21, 673 | "metadata": { 674 | "collapsed": true 675 | }, 676 | "outputs": [], 677 | "source": [ 678 | "accuracy=0\n", 679 | "num_node=0" 680 | ] 681 | }, 682 | { 683 | "cell_type": "markdown", 684 | "metadata": {}, 685 | "source": [ 686 | "### Fine-grained all" 687 | ] 688 | }, 689 | { 690 | "cell_type": "markdown", 691 | "metadata": {}, 692 | "source": [ 693 | "In the paper, they achieved 80.2% accuracy. " 694 | ] 695 | }, 696 | { 697 | "cell_type": "code", 698 | "execution_count": 23, 699 | "metadata": { 700 | "collapsed": false 701 | }, 702 | "outputs": [ 703 | { 704 | "name": "stdout", 705 | "output_type": "stream", 706 | "text": [ 707 | "79.33705899068254\n" 708 | ] 709 | } 710 | ], 711 | "source": [ 712 | "for test in test_data:\n", 713 | " model.zero_grad()\n", 714 | " preds = model(test,ROOT_ONLY)\n", 715 | " labels = test.labels[-1:] if ROOT_ONLY else test.labels\n", 716 | " for pred,label in zip(preds.max(1)[1].data.tolist(),labels):\n", 717 | " num_node+=1\n", 718 | " if pred==label:\n", 719 | " accuracy+=1\n", 720 | "\n", 721 | "print(accuracy/num_node*100)" 722 | ] 723 | }, 724 | { 725 | "cell_type": "markdown", 726 | "metadata": {}, 727 | "source": [ 728 | "## TODO " 729 | ] 730 | }, 731 | { 732 | "cell_type": "markdown", 733 | "metadata": {}, 734 | "source": [ 735 | "* https://github.com/nearai/pytorch-tools # Dynamic batch using TensorFold" 736 | ] 737 | }, 738 | { 739 | "cell_type": "markdown", 740 | "metadata": { 741 | 
"collapsed": true 742 | }, 743 | "source": [ 744 | "## Further topics " 745 | ] 746 | }, 747 | { 748 | "cell_type": "markdown", 749 | "metadata": {}, 750 | "source": [ 751 | "* Improved Semantic Representations From Tree-Structured Long Short-Term Memory Networks\n", 752 | "* A Fast Unified Model for Parsing and Sentence Understanding (SPINN)\n", 753 | "* Posting about SPINN" 754 | ] 755 | }, 756 | { 757 | "cell_type": "code", 758 | "execution_count": null, 759 | "metadata": { 760 | "collapsed": true 761 | }, 762 | "outputs": [], 763 | "source": [] 764 | } 765 | ], 766 | "metadata": { 767 | "kernelspec": { 768 | "display_name": "Python 3", 769 | "language": "python", 770 | "name": "python3" 771 | }, 772 | "language_info": { 773 | "codemirror_mode": { 774 | "name": "ipython", 775 | "version": 3 776 | }, 777 | "file_extension": ".py", 778 | "mimetype": "text/x-python", 779 | "name": "python", 780 | "nbconvert_exporter": "python", 781 | "pygments_lexer": "ipython3", 782 | "version": "3.5.2" 783 | } 784 | }, 785 | "nbformat": 4, 786 | "nbformat_minor": 2 787 | } 788 | -------------------------------------------------------------------------------- /notebooks/10.Dynamic-Memory-Network-for-Question-Answering.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# 10. \tDynamic Memory Networks for Question Answering" 8 | ] 9 | }, 10 | { 11 | "cell_type": "markdown", 12 | "metadata": {}, 13 | "source": [ 14 | "I recommend you take a look at these materials first." 
15 | ] 16 | }, 17 | { 18 | "cell_type": "markdown", 19 | "metadata": {}, 20 | "source": [ 21 | "* http://web.stanford.edu/class/cs224n/lectures/cs224n-2017-lecture16-DMN-QA.pdf\n", 22 | "* https://arxiv.org/abs/1506.07285" 23 | ] 24 | }, 25 | { 26 | "cell_type": "code", 27 | "execution_count": 1, 28 | "metadata": { 29 | "collapsed": true 30 | }, 31 | "outputs": [], 32 | "source": [ 33 | "import torch\n", 34 | "import torch.nn as nn\n", 35 | "from torch.autograd import Variable\n", 36 | "import torch.optim as optim\n", 37 | "import torch.nn.functional as F\n", 38 | "import nltk\n", 39 | "import random\n", 40 | "import numpy as np\n", 41 | "from collections import Counter, OrderedDict\n", 42 | "import nltk\n", 43 | "from copy import deepcopy\n", 44 | "import os\n", 45 | "import re\n", 46 | "import unicodedata\n", 47 | "flatten = lambda l: [item for sublist in l for item in sublist]\n", 48 | "\n", 49 | "from torch.nn.utils.rnn import PackedSequence,pack_padded_sequence" 50 | ] 51 | }, 52 | { 53 | "cell_type": "code", 54 | "execution_count": 2, 55 | "metadata": { 56 | "collapsed": true 57 | }, 58 | "outputs": [], 59 | "source": [ 60 | "USE_CUDA = torch.cuda.is_available()\n", 61 | "\n", 62 | "FloatTensor = torch.cuda.FloatTensor if USE_CUDA else torch.FloatTensor\n", 63 | "LongTensor = torch.cuda.LongTensor if USE_CUDA else torch.LongTensor\n", 64 | "ByteTensor = torch.cuda.ByteTensor if USE_CUDA else torch.ByteTensor" 65 | ] 66 | }, 67 | { 68 | "cell_type": "code", 69 | "execution_count": 3, 70 | "metadata": { 71 | "collapsed": true 72 | }, 73 | "outputs": [], 74 | "source": [ 75 | "def getBatch(batch_size,train_data):\n", 76 | " random.shuffle(train_data)\n", 77 | " sindex=0\n", 78 | " eindex=batch_size\n", 79 | " while eindex < len(train_data):\n", 80 | " batch = train_data[sindex:eindex]\n", 81 | " temp = eindex\n", 82 | " eindex = eindex+batch_size\n", 83 | " sindex = temp\n", 84 | " yield batch\n", 85 | " \n", 86 | " if eindex >= len(train_data):\n", 87 | " batch 
= train_data[sindex:]\n", 88 | " yield batch" 89 | ] 90 | }, 91 | { 92 | "cell_type": "code", 93 | "execution_count": 4, 94 | "metadata": { 95 | "collapsed": true 96 | }, 97 | "outputs": [], 98 | "source": [ 99 | "def pad_to_batch(batch,w_to_ix): # for bAbI dataset\n", 100 | " fact,q,a = list(zip(*batch))\n", 101 | " max_fact = max([len(f) for f in fact])\n", 102 | " max_len = max([f.size(1) for f in flatten(fact)])\n", 103 | " max_q = max([qq.size(1) for qq in q])\n", 104 | " max_a = max([aa.size(1) for aa in a])\n", 105 | " \n", 106 | " facts,fact_masks,q_p,a_p=[],[],[],[]\n", 107 | " for i in range(len(batch)):\n", 108 | " fact_p_t=[]\n", 109 | " for j in range(len(fact[i])):\n", 110 | " if fact[i][j].size(1)']]*(max_len-fact[i][j].size(1)))).view(1,-1)],1))\n", 112 | " else:\n", 113 | " fact_p_t.append(fact[i][j])\n", 114 | "\n", 115 | " while len(fact_p_t)']]*max_len)).view(1,-1))\n", 117 | "\n", 118 | " fact_p_t = torch.cat(fact_p_t)\n", 119 | " facts.append(fact_p_t)\n", 120 | " fact_masks.append(torch.cat([Variable(ByteTensor(tuple(map(lambda s: s ==0, t.data))),volatile=False) for t in fact_p_t]).view(fact_p_t.size(0),-1))\n", 121 | "\n", 122 | " if q[i].size(1)']]*(max_q-q[i].size(1)))).view(1,-1)],1))\n", 124 | " else:\n", 125 | " q_p.append(q[i])\n", 126 | "\n", 127 | " if a[i].size(1)']]*(max_a-a[i].size(1)))).view(1,-1)],1))\n", 129 | " else:\n", 130 | " a_p.append(a[i])\n", 131 | "\n", 132 | " questions = torch.cat(q_p)\n", 133 | " answers = torch.cat(a_p)\n", 134 | " question_masks = torch.cat([Variable(ByteTensor(tuple(map(lambda s: s ==0, t.data))),volatile=False) for t in questions]).view(questions.size(0),-1)\n", 135 | " \n", 136 | " return facts, fact_masks, questions, question_masks, answers" 137 | ] 138 | }, 139 | { 140 | "cell_type": "code", 141 | "execution_count": 5, 142 | "metadata": { 143 | "collapsed": true 144 | }, 145 | "outputs": [], 146 | "source": [ 147 | "def prepare_sequence(seq, to_index):\n", 148 | " idxs = list(map(lambda w: 
to_index[w] if w in to_index.keys() else to_index[\"\"], seq))\n", 149 | " return Variable(LongTensor(idxs))" 150 | ] 151 | }, 152 | { 153 | "cell_type": "markdown", 154 | "metadata": {}, 155 | "source": [ 156 | "## Data load and Preprocessing " 157 | ] 158 | }, 159 | { 160 | "cell_type": "markdown", 161 | "metadata": {}, 162 | "source": [ 163 | "### bAbI dataset(https://research.fb.com/downloads/babi/)" 164 | ] 165 | }, 166 | { 167 | "cell_type": "code", 168 | "execution_count": 24, 169 | "metadata": { 170 | "collapsed": false 171 | }, 172 | "outputs": [], 173 | "source": [ 174 | "def bAbI_data_load(path):\n", 175 | " try:\n", 176 | " data = open(path).readlines()\n", 177 | " except:\n", 178 | " print(\"Such a file does not exist at %s\".format(path))\n", 179 | " return None\n", 180 | " \n", 181 | " data = [d[:-1] for d in data]\n", 182 | " data_p=[]\n", 183 | " fact=[]\n", 184 | " qa=[]\n", 185 | " try:\n", 186 | " for d in data:\n", 187 | " index = d.split(' ')[0]\n", 188 | " if index=='1':\n", 189 | " fact=[]\n", 190 | " qa=[]\n", 191 | " if '?' 
in d:\n", 192 | " temp = d.split('\\t')\n", 193 | " q = temp[0].strip().replace('?','').split(' ')[1:]+['?']\n", 194 | " a = temp[1].split()+['']\n", 195 | " stemp = deepcopy(fact)\n", 196 | " data_p.append([stemp,q,a])\n", 197 | " else:\n", 198 | " tokens = d.replace('.','').split(' ')[1:]+['']\n", 199 | " fact.append(tokens)\n", 200 | " except:\n", 201 | " print(\"Please check the data is right\")\n", 202 | " return None\n", 203 | " return data_p" 204 | ] 205 | }, 206 | { 207 | "cell_type": "code", 208 | "execution_count": 25, 209 | "metadata": { 210 | "collapsed": true 211 | }, 212 | "outputs": [], 213 | "source": [ 214 | "train_data = bAbI_data_load('../dataset/corpus/bAbI/en-10k/qa5_three-arg-relations_train.txt')" 215 | ] 216 | }, 217 | { 218 | "cell_type": "code", 219 | "execution_count": 26, 220 | "metadata": { 221 | "collapsed": false 222 | }, 223 | "outputs": [ 224 | { 225 | "data": { 226 | "text/plain": [ 227 | "[[['Bill', 'travelled', 'to', 'the', 'office', ''],\n", 228 | " ['Bill', 'picked', 'up', 'the', 'football', 'there', ''],\n", 229 | " ['Bill', 'went', 'to', 'the', 'bedroom', ''],\n", 230 | " ['Bill', 'gave', 'the', 'football', 'to', 'Fred', '']],\n", 231 | " ['What', 'did', 'Bill', 'give', 'to', 'Fred', '?'],\n", 232 | " ['football', '']]" 233 | ] 234 | }, 235 | "execution_count": 26, 236 | "metadata": {}, 237 | "output_type": "execute_result" 238 | } 239 | ], 240 | "source": [ 241 | "train_data[0]" 242 | ] 243 | }, 244 | { 245 | "cell_type": "code", 246 | "execution_count": 11, 247 | "metadata": { 248 | "collapsed": true 249 | }, 250 | "outputs": [], 251 | "source": [ 252 | "fact,q,a = list(zip(*train_data))" 253 | ] 254 | }, 255 | { 256 | "cell_type": "code", 257 | "execution_count": 12, 258 | "metadata": { 259 | "collapsed": true 260 | }, 261 | "outputs": [], 262 | "source": [ 263 | "vocab = list(set(flatten(flatten(fact))+flatten(q)+flatten(a)))" 264 | ] 265 | }, 266 | { 267 | "cell_type": "code", 268 | "execution_count": 13, 269 | 
"metadata": { 270 | "collapsed": true 271 | }, 272 | "outputs": [], 273 | "source": [ 274 | "word2index={'':0,'':1,'':2,'':3}\n", 275 | "for vo in vocab:\n", 276 | " if vo not in word2index.keys():\n", 277 | " word2index[vo]=len(word2index)\n", 278 | "index2word = {v:k for k,v in word2index.items()}" 279 | ] 280 | }, 281 | { 282 | "cell_type": "code", 283 | "execution_count": 14, 284 | "metadata": { 285 | "collapsed": false 286 | }, 287 | "outputs": [ 288 | { 289 | "data": { 290 | "text/plain": [ 291 | "44" 292 | ] 293 | }, 294 | "execution_count": 14, 295 | "metadata": {}, 296 | "output_type": "execute_result" 297 | } 298 | ], 299 | "source": [ 300 | "len(word2index)" 301 | ] 302 | }, 303 | { 304 | "cell_type": "code", 305 | "execution_count": 15, 306 | "metadata": { 307 | "collapsed": true 308 | }, 309 | "outputs": [], 310 | "source": [ 311 | "for t in train_data:\n", 312 | " for i,fact in enumerate(t[0]):\n", 313 | " t[0][i] = prepare_sequence(fact,word2index).view(1,-1)\n", 314 | " \n", 315 | " t[1] = prepare_sequence(t[1],word2index).view(1,-1)\n", 316 | " t[2] = prepare_sequence(t[2],word2index).view(1,-1)" 317 | ] 318 | }, 319 | { 320 | "cell_type": "markdown", 321 | "metadata": {}, 322 | "source": [ 323 | "## Modeling " 324 | ] 325 | }, 326 | { 327 | "cell_type": "markdown", 328 | "metadata": {}, 329 | "source": [ 330 | "\n", 331 | "
borrowed image from https://arxiv.org/pdf/1506.07285.pdf
" 332 | ] 333 | }, 334 | { 335 | "cell_type": "code", 336 | "execution_count": 16, 337 | "metadata": { 338 | "collapsed": true 339 | }, 340 | "outputs": [], 341 | "source": [ 342 | "class DMN(nn.Module):\n", 343 | " def __init__(self, input_size,hidden_size,output_size,dropout_p=0.1):\n", 344 | " super(DMN, self).__init__()\n", 345 | " \n", 346 | " self.hidden_size = hidden_size\n", 347 | " self.embed = nn.Embedding(input_size, hidden_size, padding_idx=0) #sparse=True)\n", 348 | " self.input_gru = nn.GRU(hidden_size, hidden_size,batch_first=True)\n", 349 | " self.question_gru = nn.GRU(hidden_size, hidden_size,batch_first=True)\n", 350 | " \n", 351 | " self.gate = nn.Sequential(\n", 352 | " nn.Linear(hidden_size*4,hidden_size),\n", 353 | " nn.Tanh(),\n", 354 | " nn.Linear(hidden_size,1),\n", 355 | " nn.Sigmoid()\n", 356 | " )\n", 357 | " \n", 358 | " self.attention_grucell = nn.GRUCell(hidden_size, hidden_size)\n", 359 | " self.memory_grucell = nn.GRUCell(hidden_size,hidden_size)\n", 360 | " self.answer_grucell = nn.GRUCell(hidden_size*2, hidden_size)\n", 361 | " self.answer_fc = nn.Linear(hidden_size,output_size)\n", 362 | " \n", 363 | " self.dropout = nn.Dropout(dropout_p)\n", 364 | " \n", 365 | " def init_hidden(self,inputs):\n", 366 | " hidden = Variable(torch.zeros(1,inputs.size(0),self.hidden_size))\n", 367 | " return hidden.cuda() if USE_CUDA else hidden\n", 368 | " \n", 369 | " def init_weight(self):\n", 370 | " nn.init.xavier_uniform(self.embed.state_dict()['weight'])\n", 371 | " \n", 372 | " for name, param in self.input_gru.state_dict().items():\n", 373 | " if 'weight' in name: nn.init.xavier_normal(param)\n", 374 | " for name, param in self.question_gru.state_dict().items():\n", 375 | " if 'weight' in name: nn.init.xavier_normal(param)\n", 376 | " for name, param in self.gate.state_dict().items():\n", 377 | " if 'weight' in name: nn.init.xavier_normal(param)\n", 378 | " for name, param in self.attention_grucell.state_dict().items():\n", 379 | " if 
'weight' in name: nn.init.xavier_normal(param)\n", 380 | " for name, param in self.memory_grucell.state_dict().items():\n", 381 | " if 'weight' in name: nn.init.xavier_normal(param)\n", 382 | " for name, param in self.answer_grucell.state_dict().items():\n", 383 | " if 'weight' in name: nn.init.xavier_normal(param)\n", 384 | " \n", 385 | " nn.init.xavier_normal(self.answer_fc.state_dict()['weight'])\n", 386 | " self.answer_fc.bias.data.fill_(0)\n", 387 | " \n", 388 | " def forward(self,facts,fact_masks,questions,question_masks,num_decode,episodes=3,is_training=False):\n", 389 | " \"\"\"\n", 390 | " facts : (B,T_C,T_I) / LongTensor in List # batch_size, num_of_facts, length_of_each_fact(padded)\n", 391 | " fact_masks : (B,T_C,T_I) / ByteTensor in List # batch_size, num_of_facts, length_of_each_fact(padded)\n", 392 | " questions : (B,T_Q) / LongTensor # batch_size, question_length\n", 393 | " question_masks : (B,T_Q) / ByteTensor # batch_size, question_length\n", 394 | " \"\"\"\n", 395 | " # Input Module\n", 396 | " C=[] # encoded facts\n", 397 | " for fact,fact_mask in zip(facts,fact_masks):\n", 398 | " embeds = self.embed(fact)\n", 399 | " if is_training:\n", 400 | " embeds = self.dropout(embeds)\n", 401 | " hidden = self.init_hidden(fact)\n", 402 | " outputs,hidden = self.input_gru(embeds,hidden)\n", 403 | " real_hidden=[]\n", 404 | "\n", 405 | " for i,o in enumerate(outputs): # B,T,D\n", 406 | " real_length = fact_mask[i].data.tolist().count(0) \n", 407 | " real_hidden.append(o[real_length-1])\n", 408 | "\n", 409 | " C.append(torch.cat(real_hidden).view(fact.size(0),-1).unsqueeze(0))\n", 410 | " \n", 411 | " encoded_facts = torch.cat(C) # B,T_C,D\n", 412 | " \n", 413 | " # Question Module\n", 414 | " embeds = self.embed(questions)\n", 415 | " if training:\n", 416 | " embeds = self.dropout(embeds)\n", 417 | " hidden = self.init_hidden(questions)\n", 418 | " outputs, hidden = self.question_gru(embeds,hidden)\n", 419 | " \n", 420 | " if 
isinstance(question_masks,torch.autograd.variable.Variable):\n", 421 | " real_question=[]\n", 422 | " for i,o in enumerate(outputs): # B,T,D\n", 423 | " real_length = question_masks[i].data.tolist().count(0) \n", 424 | " real_question.append(o[real_length-1])\n", 425 | " encoded_question = torch.cat(real_question).view(questions.size(0),-1) # B,D\n", 426 | " else: # for inference mode\n", 427 | " encoded_question = hidden.squeeze(0) # B,D\n", 428 | " \n", 429 | " # Episodic Memory Module\n", 430 | " memory = encoded_question\n", 431 | " T_C = encoded_facts.size(1)\n", 432 | " B = encoded_facts.size(0)\n", 433 | " for i in range(episodes):\n", 434 | " hidden = self.init_hidden(encoded_facts.transpose(0,1)[0]).squeeze(0) # B,D\n", 435 | " for t in range(T_C):\n", 436 | " #TODO: fact masking\n", 437 | " #TODO: gate function => softmax\n", 438 | " z = torch.cat([\n", 439 | " encoded_facts.transpose(0,1)[t]*encoded_question, # B,D , element-wise product\n", 440 | " encoded_facts.transpose(0,1)[t]*memory, # B,D , element-wise product\n", 441 | " torch.abs(encoded_facts.transpose(0,1)[t]-encoded_question), # B,D\n", 442 | " torch.abs(encoded_facts.transpose(0,1)[t]-memory) # B,D\n", 443 | " ],1)\n", 444 | " g_t = self.gate(z) # B,1 scalar\n", 445 | " hidden = g_t*self.attention_grucell(encoded_facts.transpose(0,1)[t],hidden) + (1-g_t)*hidden\n", 446 | " \n", 447 | " e = hidden\n", 448 | " memory = self.memory_grucell(e,memory)\n", 449 | " \n", 450 | " # Answer Module\n", 451 | " answer_hidden = memory\n", 452 | " start_decode = Variable(LongTensor([[word2index['']]*memory.size(0)])).transpose(0,1)\n", 453 | " y_t_1 = self.embed(start_decode).squeeze(1) # B,D\n", 454 | " \n", 455 | " decodes=[]\n", 456 | " for t in range(num_decode):\n", 457 | " answer_hidden = self.answer_grucell(torch.cat([y_t_1,encoded_question],1),answer_hidden)\n", 458 | " decodes.append(F.log_softmax(self.answer_fc(answer_hidden)))\n", 459 | " return torch.cat(decodes,1).view(B*num_decode,-1)\n" 460 
| ] 461 | }, 462 | { 463 | "cell_type": "markdown", 464 | "metadata": {}, 465 | "source": [ 466 | "## Train " 467 | ] 468 | }, 469 | { 470 | "cell_type": "markdown", 471 | "metadata": {}, 472 | "source": [ 473 | "Training takes a while if you use only the CPU." 474 | ] 475 | }, 476 | { 477 | "cell_type": "code", 478 | "execution_count": 17, 479 | "metadata": { 480 | "collapsed": true 481 | }, 482 | "outputs": [], 483 | "source": [ 484 | "HIDDEN_SIZE=80\n", 485 | "BATCH_SIZE=64\n", 486 | "LR=0.001\n", 487 | "EPOCH=50\n", 488 | "NUM_EPISODE=3\n", 489 | "EARLY_STOPPING=False" 490 | ] 491 | }, 492 | { 493 | "cell_type": "code", 494 | "execution_count": 18, 495 | "metadata": { 496 | "collapsed": false 497 | }, 498 | "outputs": [], 499 | "source": [ 500 | "model = DMN(len(word2index),HIDDEN_SIZE,len(word2index))\n", 501 | "model.init_weight()\n", 502 | "if USE_CUDA:\n", 503 | " model = model.cuda()\n", 504 | "\n", 505 | "loss_function = nn.CrossEntropyLoss(ignore_index=0)\n", 506 | "optimizer = optim.Adam(model.parameters(),lr=LR)" 507 | ] 508 | }, 509 | { 510 | "cell_type": "code", 511 | "execution_count": 19, 512 | "metadata": { 513 | "collapsed": false 514 | }, 515 | "outputs": [ 516 | { 517 | "name": "stdout", 518 | "output_type": "stream", 519 | "text": [ 520 | "[0/50] mean_loss : 3.86\n", 521 | "[0/50] mean_loss : 1.32\n", 522 | "[1/50] mean_loss : 0.68\n", 523 | "[1/50] mean_loss : 0.65\n", 524 | "[2/50] mean_loss : 0.62\n", 525 | "[2/50] mean_loss : 0.65\n", 526 | "[3/50] mean_loss : 0.65\n", 527 | "[3/50] mean_loss : 0.64\n", 528 | "[4/50] mean_loss : 0.60\n", 529 | "[4/50] mean_loss : 0.62\n", 530 | "[5/50] mean_loss : 0.63\n", 531 | "[5/50] mean_loss : 0.61\n", 532 | "[6/50] mean_loss : 0.60\n", 533 | "[6/50] mean_loss : 0.61\n", 534 | "[7/50] mean_loss : 0.63\n", 535 | "[7/50] mean_loss : 0.60\n", 536 | "[8/50] mean_loss : 0.62\n", 537 | "[8/50] mean_loss : 0.60\n", 538 | "[9/50] mean_loss : 0.58\n", 539 | "[9/50] mean_loss : 0.60\n", 540 | "[10/50] mean_loss : 
0.60\n", 541 | "[10/50] mean_loss : 0.60\n", 542 | "[11/50] mean_loss : 0.62\n", 543 | "[11/50] mean_loss : 0.60\n", 544 | "[12/50] mean_loss : 0.61\n", 545 | "[12/50] mean_loss : 0.60\n", 546 | "[13/50] mean_loss : 0.57\n", 547 | "[13/50] mean_loss : 0.60\n", 548 | "[14/50] mean_loss : 0.59\n", 549 | "[14/50] mean_loss : 0.60\n", 550 | "[15/50] mean_loss : 0.61\n", 551 | "[15/50] mean_loss : 0.60\n", 552 | "[16/50] mean_loss : 0.59\n", 553 | "[16/50] mean_loss : 0.60\n", 554 | "[17/50] mean_loss : 0.59\n", 555 | "[17/50] mean_loss : 0.60\n", 556 | "[18/50] mean_loss : 0.51\n", 557 | "[18/50] mean_loss : 0.50\n", 558 | "[19/50] mean_loss : 0.44\n", 559 | "[19/50] mean_loss : 0.37\n", 560 | "[20/50] mean_loss : 0.30\n", 561 | "[20/50] mean_loss : 0.33\n", 562 | "[21/50] mean_loss : 0.31\n", 563 | "[21/50] mean_loss : 0.31\n", 564 | "[22/50] mean_loss : 0.29\n", 565 | "[22/50] mean_loss : 0.31\n", 566 | "[23/50] mean_loss : 0.29\n", 567 | "[23/50] mean_loss : 0.31\n", 568 | "[24/50] mean_loss : 0.24\n", 569 | "[24/50] mean_loss : 0.31\n", 570 | "[25/50] mean_loss : 0.30\n", 571 | "[25/50] mean_loss : 0.30\n", 572 | "[26/50] mean_loss : 0.14\n", 573 | "[26/50] mean_loss : 0.16\n", 574 | "[27/50] mean_loss : 0.12\n", 575 | "[27/50] mean_loss : 0.15\n", 576 | "[28/50] mean_loss : 0.18\n", 577 | "[28/50] mean_loss : 0.14\n", 578 | "[29/50] mean_loss : 0.12\n", 579 | "[29/50] mean_loss : 0.14\n", 580 | "[30/50] mean_loss : 0.14\n", 581 | "[30/50] mean_loss : 0.14\n", 582 | "[31/50] mean_loss : 0.13\n", 583 | "[31/50] mean_loss : 0.14\n", 584 | "[32/50] mean_loss : 0.11\n", 585 | "[32/50] mean_loss : 0.13\n", 586 | "[33/50] mean_loss : 0.08\n", 587 | "[33/50] mean_loss : 0.06\n", 588 | "[34/50] mean_loss : 0.01\n", 589 | "[34/50] mean_loss : 0.03\n", 590 | "[35/50] mean_loss : 0.01\n", 591 | "Early Stopping!\n" 592 | ] 593 | } 594 | ], 595 | "source": [ 596 | "for epoch in range(EPOCH):\n", 597 | " losses=[]\n", 598 | " if EARLY_STOPPING: break\n", 599 | " \n", 600 | " for 
i,batch in enumerate(getBatch(BATCH_SIZE,train_data)):\n", 601 | " facts, fact_masks, questions, question_masks, answers = pad_to_batch(batch,word2index)\n", 602 | " \n", 603 | " model.zero_grad()\n", 604 | " pred = model(facts,fact_masks,questions,question_masks,answers.size(1),NUM_EPISODE,True)\n", 605 | " loss = loss_function(pred,answers.view(-1))\n", 606 | " losses.append(loss.data.tolist()[0])\n", 607 | " \n", 608 | " loss.backward()\n", 609 | " optimizer.step()\n", 610 | " \n", 611 | " if i % 100==0:\n", 612 | " print(\"[%d/%d] mean_loss : %0.2f\" %(epoch,EPOCH,np.mean(losses)))\n", 613 | " \n", 614 | " if np.mean(losses)<0.01:\n", 615 | " EARLY_STOPPING=True\n", 616 | " print(\"Early Stopping!\")\n", 617 | " break\n", 618 | " losses=[]" 619 | ] 620 | }, 621 | { 622 | "cell_type": "markdown", 623 | "metadata": {}, 624 | "source": [ 625 | "## Test " 626 | ] 627 | }, 628 | { 629 | "cell_type": "code", 630 | "execution_count": 21, 631 | "metadata": { 632 | "collapsed": true 633 | }, 634 | "outputs": [], 635 | "source": [ 636 | "def pad_to_fact(fact,x_to_ix): # this is for inference\n", 637 | " \n", 638 | " max_x = max([s.size(1) for s in fact])\n", 639 | " x_p=[]\n", 640 | " for i in range(len(fact)):\n", 641 | " if fact[i].size(1)<max_x:\n", 642 | " x_p.append(torch.cat([fact[i],Variable(LongTensor([x_to_ix['<PAD>']]*(max_x-fact[i].size(1)))).view(1,-1)],1))\n", 643 | " else:\n", 644 | " x_p.append(fact[i])\n", 645 | " \n", 646 | " fact = torch.cat(x_p)\n", 647 | " fact_mask = torch.cat([Variable(ByteTensor(tuple(map(lambda s: s ==0, t.data))),volatile=False) for t in fact]).view(fact.size(0),-1)\n", 648 | " return fact,fact_mask" 649 | ] 650 | }, 651 | { 652 | "cell_type": "markdown", 653 | "metadata": {}, 654 | "source": [ 655 | "### Prepare Test data " 656 | ] 657 | }, 658 | { 659 | "cell_type": "code", 660 | "execution_count": 27, 661 | "metadata": { 662 | "collapsed": false 663 | }, 664 | "outputs": [], 665 | "source": [ 666 | "test_data = bAbI_data_load('../dataset/bAbI/en-10k/qa5_three-arg-relations_test.txt')" 667 | ] 668 | }, 669 | { 
670 | "cell_type": "code", 671 | "execution_count": 28, 672 | "metadata": { 673 | "collapsed": true 674 | }, 675 | "outputs": [], 676 | "source": [ 677 | "for t in test_data:\n", 678 | " for i,fact in enumerate(t[0]):\n", 679 | " t[0][i] = prepare_sequence(fact,word2index).view(1,-1)\n", 680 | " \n", 681 | " t[1] = prepare_sequence(t[1],word2index).view(1,-1)\n", 682 | " t[2] = prepare_sequence(t[2],word2index).view(1,-1)" 683 | ] 684 | }, 685 | { 686 | "cell_type": "markdown", 687 | "metadata": {}, 688 | "source": [ 689 | "### Accuracy " 690 | ] 691 | }, 692 | { 693 | "cell_type": "code", 694 | "execution_count": 31, 695 | "metadata": { 696 | "collapsed": true 697 | }, 698 | "outputs": [], 699 | "source": [ 700 | "accuracy=0" 701 | ] 702 | }, 703 | { 704 | "cell_type": "code", 705 | "execution_count": 32, 706 | "metadata": { 707 | "collapsed": false 708 | }, 709 | "outputs": [ 710 | { 711 | "name": "stdout", 712 | "output_type": "stream", 713 | "text": [ 714 | "97.39999999999999\n" 715 | ] 716 | } 717 | ], 718 | "source": [ 719 | "for t in test_data:\n", 720 | " fact, fact_mask = pad_to_fact(t[0],word2index)\n", 721 | " question = t[1]\n", 722 | " question_mask = Variable(ByteTensor([0]*t[1].size(1)),volatile=False).unsqueeze(0)\n", 723 | " answer = t[2].squeeze(0)\n", 724 | " \n", 725 | " model.zero_grad()\n", 726 | " pred = model([fact],[fact_mask],question,question_mask,answer.size(0),NUM_EPISODE)\n", 727 | " if pred.max(1)[1].data.tolist()==answer.data.tolist():\n", 728 | " accuracy+=1\n", 729 | "\n", 730 | "print(accuracy/len(test_data)*100)" 731 | ] 732 | }, 733 | { 734 | "cell_type": "markdown", 735 | "metadata": {}, 736 | "source": [ 737 | "### Sample test result " 738 | ] 739 | }, 740 | { 741 | "cell_type": "code", 742 | "execution_count": 34, 743 | "metadata": { 744 | "collapsed": false 745 | }, 746 | "outputs": [ 747 | { 748 | "name": "stdout", 749 | "output_type": "stream", 750 | "text": [ 751 | "Facts : \n", 752 | "Bill went back to the bedroom \n", 
753 | "Mary went to the office \n", 754 | "Jeff journeyed to the kitchen \n", 755 | "Fred journeyed to the kitchen \n", 756 | "Fred got the milk there \n", 757 | "Fred handed the milk to Jeff \n", 758 | "Jeff passed the milk to Fred \n", 759 | "Fred gave the milk to Jeff \n", 760 | "\n", 761 | "Question : Who received the milk ?\n", 762 | "\n", 763 | "Answer : Jeff \n", 764 | "Prediction : Jeff \n" 765 | ] 766 | } 767 | ], 768 | "source": [ 769 | "t = random.choice(test_data)\n", 770 | "fact, fact_mask = pad_to_fact(t[0],word2index)\n", 771 | "question = t[1]\n", 772 | "question_mask = Variable(ByteTensor([0]*t[1].size(1)),volatile=False).unsqueeze(0)\n", 773 | "answer = t[2].squeeze(0)\n", 774 | "\n", 775 | "model.zero_grad()\n", 776 | "pred = model([fact],[fact_mask],question,question_mask,answer.size(0),NUM_EPISODE)\n", 777 | "\n", 778 | "print(\"Facts : \")\n", 779 | "print('\\n'.join([' '.join(list(map(lambda x : index2word[x],f))) for f in fact.data.tolist()]))\n", 780 | "print(\"\")\n", 781 | "print(\"Question : \",' '.join(list(map(lambda x:index2word[x],question.data.tolist()[0]))))\n", 782 | "print(\"\")\n", 783 | "print(\"Answer : \",' '.join(list(map(lambda x:index2word[x],answer.data.tolist()))))\n", 784 | "print(\"Prediction : \",' '.join(list(map(lambda x:index2word[x],pred.max(1)[1].data.tolist()))))" 785 | ] 786 | }, 787 | { 788 | "cell_type": "markdown", 789 | "metadata": { 790 | "collapsed": true 791 | }, 792 | "source": [ 793 | "## Further topics " 794 | ] 795 | }, 796 | { 797 | "cell_type": "markdown", 798 | "metadata": {}, 799 | "source": [ 800 | "* Dynamic Memory Networks for Visual and Textual Question Answering(DMN+)\n", 801 | "* DMN+ Pytorch implementation\n", 802 | "* Dynamic Coattention Networks For Question Answering\n", 803 | "* DCN+: Mixed Objective and Deep Residual Coattention for Question Answering" 804 | ] 805 | }, 806 | { 807 | "cell_type": "code", 808 | "execution_count": null, 809 | "metadata": { 810 | "collapsed": true 811 | 
}, 812 | "outputs": [], 813 | "source": [] 814 | } 815 | ], 816 | "metadata": { 817 | "kernelspec": { 818 | "display_name": "Python 3", 819 | "language": "python", 820 | "name": "python3" 821 | }, 822 | "language_info": { 823 | "codemirror_mode": { 824 | "name": "ipython", 825 | "version": 3 826 | }, 827 | "file_extension": ".py", 828 | "mimetype": "text/x-python", 829 | "name": "python", 830 | "nbconvert_exporter": "python", 831 | "pygments_lexer": "ipython3", 832 | "version": "3.5.2" 833 | } 834 | }, 835 | "nbformat": 4, 836 | "nbformat_minor": 2 837 | } 838 | -------------------------------------------------------------------------------- /script/docker-compose.yml: -------------------------------------------------------------------------------- 1 | cs224n: 2 | image: dsksd/deepstudy:0.2 3 | ports: 4 | - "8888:8888" 5 | - "6006:6006" 6 | volumes: 7 | - ../:/notebooks/work 8 | command: /run.sh 9 | tty: true -------------------------------------------------------------------------------- /script/prepare_dataset.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | # Downloads and unpacks the datasets listed below into OUT_DIR (first argument; defaults to ../dataset). 3 | OUT_DIR="${1:-../dataset}" 4 | # Quote OUT_DIR everywhere so a path containing spaces is not split into several arguments. 5 | mkdir -v -p "$OUT_DIR" 6 | 7 | echo "download machine translation dataset from http://www.manythings.org/anki/..." 8 | curl -o "$OUT_DIR/fra-eng.zip" "http://www.manythings.org/anki/fra-eng.zip" 9 | unzip "$OUT_DIR/fra-eng.zip" -d "$OUT_DIR" 10 | rm "$OUT_DIR/fra-eng.zip" 11 | rm "$OUT_DIR/_about.txt" 12 | mv "$OUT_DIR/fra.txt" "$OUT_DIR/eng-fra.txt" 13 | 14 | echo "download ptb dataset... 
(clone from https://github.com/tomsercu/lstm/tree/master/data)" 15 | mkdir -v -p "$OUT_DIR/ptb" 16 | wget "https://raw.githubusercontent.com/tomsercu/lstm/master/data/ptb.train.txt" -P "$OUT_DIR/ptb" 17 | wget "https://raw.githubusercontent.com/tomsercu/lstm/master/data/ptb.valid.txt" -P "$OUT_DIR/ptb" 18 | wget "https://raw.githubusercontent.com/tomsercu/lstm/master/data/ptb.test.txt" -P "$OUT_DIR/ptb" 19 | 20 | echo "download TREC question dataset..." 21 | curl -o "$OUT_DIR/train_5500.label.txt" "http://cogcomp.org/Data/QA/QC/train_5500.label" 22 | 23 | echo "download Stanford sentiment treebank..." 24 | curl -o "$OUT_DIR/trainDevTestTrees_PTB.zip" "https://nlp.stanford.edu/sentiment/trainDevTestTrees_PTB.zip" 25 | unzip "$OUT_DIR/trainDevTestTrees_PTB.zip" -d "$OUT_DIR" 26 | rm "$OUT_DIR/trainDevTestTrees_PTB.zip" 27 | # Unpack the bAbI archive inside OUT_DIR so no stray directory is left in the caller's working directory. 28 | echo "download bAbI dataset..." 29 | curl -o "$OUT_DIR/tasks_1-20_v1-2.tar.gz" "http://www.thespermwhale.com/jaseweston/babi/tasks_1-20_v1-2.tar.gz" 30 | tar zxvf "$OUT_DIR/tasks_1-20_v1-2.tar.gz" -C "$OUT_DIR" 31 | mv "$OUT_DIR/tasks_1-20_v1-2" "$OUT_DIR/bAbI" 32 | rm "$OUT_DIR/tasks_1-20_v1-2.tar.gz" 33 | 34 | echo "download nltk dataset..." 35 | python3 -c "import nltk;nltk.download('gutenberg');nltk.download('brown');nltk.download('conll2002');nltk.download('timit')" 36 | 37 | echo "download dependency parser dataset... 
(clone from https://github.com/rguthrie3/DeepDependencyParsingProblemSet)" 38 | mkdir -v -p "$OUT_DIR/dparser" 39 | wget "https://raw.githubusercontent.com/rguthrie3/DeepDependencyParsingProblemSet/master/data/train.txt" -P "$OUT_DIR/dparser" 40 | wget "https://raw.githubusercontent.com/rguthrie3/DeepDependencyParsingProblemSet/master/data/vocab.txt" -P "$OUT_DIR/dparser" 41 | wget "https://raw.githubusercontent.com/rguthrie3/DeepDependencyParsingProblemSet/master/data/dev.txt" -P "$OUT_DIR/dparser" -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from distutils.core import setup 2 | from Cython.Build import cythonize 3 | from Cython.Distutils import build_ext 4 | from distutils.extension import Extension 5 | 6 | setup(name="Corpus",cmdclass = {'build_ext': build_ext},ext_modules = [Extension("Corpus", ["Corpus.pyx"],language='c++')]) --------------------------------------------------------------------------------