├── .gitignore ├── Corpus.pyx ├── LICENSE ├── README.md ├── images ├── 01.skipgram-objective.png ├── 01.skipgram-prepare-data.png ├── 02.skipgram-objective.png ├── 03.glove-objective.png ├── 03.glove-weighting-function.png ├── 04.window-classifier-architecture.png ├── 04.window-data.png ├── 05.neural-dparser-architecture.png ├── 05.transition-based-parse.png ├── 06.rnnlm-architecture.png ├── 07.attention-mechanism.png ├── 07.pad_to_sequence.png ├── 07.seq2seq.png ├── 08.cnn-for-text-architecture.png ├── 09.rntn-layer.png └── 10.dmn-architecture.png ├── notebooks ├── 01.Skip-gram-Naive-Softmax.ipynb ├── 02.Skip-gram-Negative-Sampling.ipynb ├── 03.GloVe.ipynb ├── 04.Window-Classifier-for-NER.ipynb ├── 05.Neural-Dependancy-Parser.ipynb ├── 06.RNN-Language-Model.ipynb ├── 07.Neural-Machine-Translation-with-Attention.ipynb ├── 08.CNN-for-Text-Classification.ipynb ├── 09.Recursive-NN-for-Sentiment-Classification.ipynb └── 10.Dynamic-Memory-Network-for-Question-Answering.ipynb ├── script ├── docker-compose.yml └── prepare_dataset.sh └── setup.py /.gitignore: -------------------------------------------------------------------------------- 1 | *.txt 2 | __pycache__ 3 | .ipynb_checkpoints 4 | build/ 5 | dataset/ 6 | model/ 7 | *.cpp 8 | *.so 9 | *.bin 10 | *.pkl 11 | -------------------------------------------------------------------------------- /Corpus.pyx: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import scipy.sparse as sp 3 | from libcpp.vector cimport vector 4 | from libcpp.string cimport string 5 | from libcpp.map cimport map 6 | 7 | # Text preprocessing C++ extension module 8 | 9 | def make_dictionary(vocab): 10 | cdef int vocab_size = len(vocab) 11 | cdef int i 12 | dictionary={} 13 | inv_dictionary={} 14 | for i in range(vocab_size): 15 | dictionary[vocab[i]] = i 16 | inv_dictionary[i]=vocab[i] 17 | return dictionary,inv_dictionary 18 | 19 | 20 | def make_window_data(sents,window_size): 21 | pass 22 | 23 | 24 | def 
make_coo_matrix(corpus,dictionary): 25 | cdef int matrix_size = len(dictionary) 26 | pass 27 | 28 | 29 | def getBatch_FromBucket(batch_size,buckets): 30 | i=0 31 | bucket_mask =[False for _ in range(len(buckets))] 32 | indices = [[0,batch_size] for _ in range(len(buckets))] 33 | is_done=False 34 | while is_done==False: 35 | batch = buckets[i][indices[i][0]:indices[i][1]] 36 | temp = indices[i][1] 37 | indices[i][1]= indices[i][1]+batch_size 38 | indices[i][0] = temp 39 | 40 | i = (i+1)%len(buckets) 41 | while bucket_mask[i]: 42 | i = (i+1)%len(buckets) 43 | 44 | if indices[i][1]>len(buckets[i]): 45 | bucket_mask[i]= True 46 | if bucket_mask.count(True)==len(buckets): 47 | is_done=True 48 | else: 49 | while bucket_mask[i]: 50 | i = (i+1)%len(buckets) 51 | yield batch -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2017 SungDong Kim 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # DeepNLP-models-Pytorch 2 | 3 | Pytorch implementations of various Deep NLP models in cs-224n(Stanford Univ: NLP with Deep Learning) 4 | 5 | - This is not for Pytorch beginners. If this is your first time using Pytorch, I recommend these [awesome tutorials](#references). 6 | - If you're interested in DeepNLP, I strongly recommend working through this awesome lecture. 7 | 8 | * cs-224n-slides 9 | * cs-224n-videos 10 | 11 | This material is not perfect but will help your study and research:) Please feel free to open pull requests!! 12 | 13 |
14 | 15 | ## Contents 16 | 17 | | Model | Links | 18 | | ------------- |:-------------:| 19 | | 01. Skip-gram-Naive-Softmax | [notebook / data / paper] | 20 | | 02. Skip-gram-Negative-Sampling | [notebook / data / paper] | 21 | | 03. GloVe | [notebook / data / paper] | 22 | | 04. Window-Classifier-for-NER | [notebook / data / paper] | 23 | | 05. Neural-Dependancy-Parser | [notebook / data / paper] | 24 | | 06. RNN-Language-Model | [notebook / data / paper] | 25 | | 07. Neural-Machine-Translation-with-Attention | [notebook / data / paper] | 26 | | 08. CNN-for-Text-Classification | [notebook / data / paper] | 27 | | 09. Recursive-NN-for-Sentiment-Classification | [notebook / data / paper] | 28 | | 10. Dynamic-Memory-Network-for-Question-Answering | [notebook / data / paper] | 29 | 30 | 31 | ## Requirements 32 | 33 | - Python 3.5 34 | - Pytorch 0.2+ 35 | - nltk 3.2.2 36 | - gensim 2.2.0 37 | - sklearn_crfsuite 38 | 39 | 40 | ## Getting started 41 | 42 | `git clone https://github.com/DSKSD/cs-224n-Pytorch.git` 43 | 44 | ### prepare dataset 45 | 46 | ```` 47 | cd script 48 | chmod u+x prepare_dataset.sh 49 | ./prepare_dataset.sh 50 | ```` 51 | 52 | ### docker env 53 | ubuntu 16.04 python 3.5.2 with various ML/DL packages including tensorflow, sklearn, pytorch 54 | 55 | `docker pull dsksd/deepstudy:0.2` 56 | 57 | ```` 58 | pip3 install docker-compose 59 | cd script 60 | docker-compose up -d 61 | ```` 62 | 63 | ### cloud setting 64 | 65 | `not yet` 66 | 67 | ## References 68 | 69 | * practical-pytorch 70 | * DeepLearningForNLPInPytorch 71 | * pytorch-tutorial 72 | * pytorch-examples 73 | 74 | ## Author 75 | 76 | Sungdong Kim / @DSKSD -------------------------------------------------------------------------------- /images/01.skipgram-objective.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DSKSD/DeepNLP-models-Pytorch/7fec64d72615933e8f4ea499c2dbaa42508f4017/images/01.skipgram-objective.png 
-------------------------------------------------------------------------------- /images/01.skipgram-prepare-data.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DSKSD/DeepNLP-models-Pytorch/7fec64d72615933e8f4ea499c2dbaa42508f4017/images/01.skipgram-prepare-data.png -------------------------------------------------------------------------------- /images/02.skipgram-objective.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DSKSD/DeepNLP-models-Pytorch/7fec64d72615933e8f4ea499c2dbaa42508f4017/images/02.skipgram-objective.png -------------------------------------------------------------------------------- /images/03.glove-objective.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DSKSD/DeepNLP-models-Pytorch/7fec64d72615933e8f4ea499c2dbaa42508f4017/images/03.glove-objective.png -------------------------------------------------------------------------------- /images/03.glove-weighting-function.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DSKSD/DeepNLP-models-Pytorch/7fec64d72615933e8f4ea499c2dbaa42508f4017/images/03.glove-weighting-function.png -------------------------------------------------------------------------------- /images/04.window-classifier-architecture.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DSKSD/DeepNLP-models-Pytorch/7fec64d72615933e8f4ea499c2dbaa42508f4017/images/04.window-classifier-architecture.png -------------------------------------------------------------------------------- /images/04.window-data.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/DSKSD/DeepNLP-models-Pytorch/7fec64d72615933e8f4ea499c2dbaa42508f4017/images/04.window-data.png -------------------------------------------------------------------------------- /images/05.neural-dparser-architecture.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DSKSD/DeepNLP-models-Pytorch/7fec64d72615933e8f4ea499c2dbaa42508f4017/images/05.neural-dparser-architecture.png -------------------------------------------------------------------------------- /images/05.transition-based-parse.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DSKSD/DeepNLP-models-Pytorch/7fec64d72615933e8f4ea499c2dbaa42508f4017/images/05.transition-based-parse.png -------------------------------------------------------------------------------- /images/06.rnnlm-architecture.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DSKSD/DeepNLP-models-Pytorch/7fec64d72615933e8f4ea499c2dbaa42508f4017/images/06.rnnlm-architecture.png -------------------------------------------------------------------------------- /images/07.attention-mechanism.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DSKSD/DeepNLP-models-Pytorch/7fec64d72615933e8f4ea499c2dbaa42508f4017/images/07.attention-mechanism.png -------------------------------------------------------------------------------- /images/07.pad_to_sequence.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DSKSD/DeepNLP-models-Pytorch/7fec64d72615933e8f4ea499c2dbaa42508f4017/images/07.pad_to_sequence.png -------------------------------------------------------------------------------- /images/07.seq2seq.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/DSKSD/DeepNLP-models-Pytorch/7fec64d72615933e8f4ea499c2dbaa42508f4017/images/07.seq2seq.png -------------------------------------------------------------------------------- /images/08.cnn-for-text-architecture.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DSKSD/DeepNLP-models-Pytorch/7fec64d72615933e8f4ea499c2dbaa42508f4017/images/08.cnn-for-text-architecture.png -------------------------------------------------------------------------------- /images/09.rntn-layer.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DSKSD/DeepNLP-models-Pytorch/7fec64d72615933e8f4ea499c2dbaa42508f4017/images/09.rntn-layer.png -------------------------------------------------------------------------------- /images/10.dmn-architecture.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DSKSD/DeepNLP-models-Pytorch/7fec64d72615933e8f4ea499c2dbaa42508f4017/images/10.dmn-architecture.png -------------------------------------------------------------------------------- /notebooks/01.Skip-gram-Naive-Softmax.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# 1. Skip-gram with naive softmax " 8 | ] 9 | }, 10 | { 11 | "cell_type": "markdown", 12 | "metadata": {}, 13 | "source": [ 14 | "I recommend you take a look at these materials first." 
15 | ] 16 | }, 17 | { 18 | "cell_type": "markdown", 19 | "metadata": {}, 20 | "source": [ 21 | "* http://web.stanford.edu/class/cs224n/lectures/cs224n-2017-lecture2.pdf\n", 22 | "* https://arxiv.org/abs/1301.3781\n", 23 | "* http://mccormickml.com/2016/04/19/word2vec-tutorial-the-skip-gram-model/" 24 | ] 25 | }, 26 | { 27 | "cell_type": "code", 28 | "execution_count": 1, 29 | "metadata": { 30 | "collapsed": true 31 | }, 32 | "outputs": [], 33 | "source": [ 34 | "import torch\n", 35 | "import torch.nn as nn\n", 36 | "from torch.autograd import Variable\n", 37 | "import torch.optim as optim\n", 38 | "import torch.nn.functional as F\n", 39 | "import nltk\n", 40 | "import random\n", 41 | "import numpy as np\n", 42 | "from collections import Counter\n", 43 | "flatten = lambda l: [item for sublist in l for item in sublist]\n", 44 | "random.seed(1024)" 45 | ] 46 | }, 47 | { 48 | "cell_type": "code", 49 | "execution_count": 2, 50 | "metadata": { 51 | "collapsed": false 52 | }, 53 | "outputs": [ 54 | { 55 | "name": "stdout", 56 | "output_type": "stream", 57 | "text": [ 58 | "0.3.0.post4\n", 59 | "3.2.4\n" 60 | ] 61 | } 62 | ], 63 | "source": [ 64 | "print(torch.__version__)\n", 65 | "print(nltk.__version__)" 66 | ] 67 | }, 68 | { 69 | "cell_type": "code", 70 | "execution_count": 3, 71 | "metadata": { 72 | "collapsed": true 73 | }, 74 | "outputs": [], 75 | "source": [ 76 | "USE_CUDA = torch.cuda.is_available()\n", 77 | "gpus = [0]\n", 78 | "torch.cuda.set_device(gpus[0])\n", 79 | "\n", 80 | "FloatTensor = torch.cuda.FloatTensor if USE_CUDA else torch.FloatTensor\n", 81 | "LongTensor = torch.cuda.LongTensor if USE_CUDA else torch.LongTensor\n", 82 | "ByteTensor = torch.cuda.ByteTensor if USE_CUDA else torch.ByteTensor" 83 | ] 84 | }, 85 | { 86 | "cell_type": "code", 87 | "execution_count": 4, 88 | "metadata": { 89 | "collapsed": true 90 | }, 91 | "outputs": [], 92 | "source": [ 93 | "def getBatch(batch_size, train_data):\n", 94 | " random.shuffle(train_data)\n", 95 | " sindex 
= 0\n", 96 | " eindex = batch_size\n", 97 | " while eindex < len(train_data):\n", 98 | " batch = train_data[sindex: eindex]\n", 99 | " temp = eindex\n", 100 | " eindex = eindex + batch_size\n", 101 | " sindex = temp\n", 102 | " yield batch\n", 103 | " \n", 104 | " if eindex >= len(train_data):\n", 105 | " batch = train_data[sindex:]\n", 106 | " yield batch" 107 | ] 108 | }, 109 | { 110 | "cell_type": "code", 111 | "execution_count": 5, 112 | "metadata": { 113 | "collapsed": true 114 | }, 115 | "outputs": [], 116 | "source": [ 117 | "def prepare_sequence(seq, word2index):\n", 118 | " idxs = list(map(lambda w: word2index[w] if word2index.get(w) is not None else word2index[\"\"], seq))\n", 119 | " return Variable(LongTensor(idxs))\n", 120 | "\n", 121 | "def prepare_word(word, word2index):\n", 122 | " return Variable(LongTensor([word2index[word]]) if word2index.get(word) is not None else LongTensor([word2index[\"\"]]))" 123 | ] 124 | }, 125 | { 126 | "cell_type": "markdown", 127 | "metadata": {}, 128 | "source": [ 129 | "## Data load and Preprocessing " 130 | ] 131 | }, 132 | { 133 | "cell_type": "markdown", 134 | "metadata": {}, 135 | "source": [ 136 | "### Load corpus : Gutenberg corpus" 137 | ] 138 | }, 139 | { 140 | "cell_type": "markdown", 141 | "metadata": {}, 142 | "source": [ 143 | "If you don't have gutenberg corpus, you can download it first using nltk.download()" 144 | ] 145 | }, 146 | { 147 | "cell_type": "code", 148 | "execution_count": 6, 149 | "metadata": { 150 | "collapsed": false 151 | }, 152 | "outputs": [ 153 | { 154 | "data": { 155 | "text/plain": [ 156 | "['austen-emma.txt',\n", 157 | " 'austen-persuasion.txt',\n", 158 | " 'austen-sense.txt',\n", 159 | " 'bible-kjv.txt',\n", 160 | " 'blake-poems.txt',\n", 161 | " 'bryant-stories.txt',\n", 162 | " 'burgess-busterbrown.txt',\n", 163 | " 'carroll-alice.txt',\n", 164 | " 'chesterton-ball.txt',\n", 165 | " 'chesterton-brown.txt',\n", 166 | " 'chesterton-thursday.txt',\n", 167 | " 
'edgeworth-parents.txt',\n", 168 | " 'melville-moby_dick.txt',\n", 169 | " 'milton-paradise.txt',\n", 170 | " 'shakespeare-caesar.txt',\n", 171 | " 'shakespeare-hamlet.txt',\n", 172 | " 'shakespeare-macbeth.txt',\n", 173 | " 'whitman-leaves.txt']" 174 | ] 175 | }, 176 | "execution_count": 6, 177 | "metadata": {}, 178 | "output_type": "execute_result" 179 | } 180 | ], 181 | "source": [ 182 | "nltk.corpus.gutenberg.fileids()" 183 | ] 184 | }, 185 | { 186 | "cell_type": "code", 187 | "execution_count": 7, 188 | "metadata": { 189 | "collapsed": true 190 | }, 191 | "outputs": [], 192 | "source": [ 193 | "corpus = list(nltk.corpus.gutenberg.sents('melville-moby_dick.txt'))[:100] # sampling sentences for test\n", 194 | "corpus = [[word.lower() for word in sent] for sent in corpus]" 195 | ] 196 | }, 197 | { 198 | "cell_type": "markdown", 199 | "metadata": {}, 200 | "source": [ 201 | "### Extract Stopwords from unigram distribution's tails" 202 | ] 203 | }, 204 | { 205 | "cell_type": "code", 206 | "execution_count": 8, 207 | "metadata": { 208 | "collapsed": true 209 | }, 210 | "outputs": [], 211 | "source": [ 212 | "word_count = Counter(flatten(corpus))\n", 213 | "border = int(len(word_count) * 0.01) " 214 | ] 215 | }, 216 | { 217 | "cell_type": "code", 218 | "execution_count": 9, 219 | "metadata": { 220 | "collapsed": true 221 | }, 222 | "outputs": [], 223 | "source": [ 224 | "stopwords = word_count.most_common()[:border] + list(reversed(word_count.most_common()))[:border]" 225 | ] 226 | }, 227 | { 228 | "cell_type": "code", 229 | "execution_count": 10, 230 | "metadata": { 231 | "collapsed": true 232 | }, 233 | "outputs": [], 234 | "source": [ 235 | "stopwords = [s[0] for s in stopwords]" 236 | ] 237 | }, 238 | { 239 | "cell_type": "code", 240 | "execution_count": 11, 241 | "metadata": { 242 | "collapsed": false 243 | }, 244 | "outputs": [ 245 | { 246 | "data": { 247 | "text/plain": [ 248 | "[',', '.', 'the', 'of', 'and', 'baleine', '--(', 'fat', 'oil', 'boiling']" 249 | ] 
250 | }, 251 | "execution_count": 11, 252 | "metadata": {}, 253 | "output_type": "execute_result" 254 | } 255 | ], 256 | "source": [ 257 | "stopwords" 258 | ] 259 | }, 260 | { 261 | "cell_type": "markdown", 262 | "metadata": {}, 263 | "source": [ 264 | "### Build vocab" 265 | ] 266 | }, 267 | { 268 | "cell_type": "code", 269 | "execution_count": 12, 270 | "metadata": { 271 | "collapsed": true 272 | }, 273 | "outputs": [], 274 | "source": [ 275 | "vocab = list(set(flatten(corpus)) - set(stopwords))\n", 276 | "vocab.append('')" 277 | ] 278 | }, 279 | { 280 | "cell_type": "code", 281 | "execution_count": 13, 282 | "metadata": { 283 | "collapsed": false 284 | }, 285 | "outputs": [ 286 | { 287 | "name": "stdout", 288 | "output_type": "stream", 289 | "text": [ 290 | "592 583\n" 291 | ] 292 | } 293 | ], 294 | "source": [ 295 | "print(len(set(flatten(corpus))), len(vocab))" 296 | ] 297 | }, 298 | { 299 | "cell_type": "code", 300 | "execution_count": 14, 301 | "metadata": { 302 | "collapsed": true 303 | }, 304 | "outputs": [], 305 | "source": [ 306 | "word2index = {'' : 0} \n", 307 | "\n", 308 | "for vo in vocab:\n", 309 | " if word2index.get(vo) is None:\n", 310 | " word2index[vo] = len(word2index)\n", 311 | "\n", 312 | "index2word = {v:k for k, v in word2index.items()} " 313 | ] 314 | }, 315 | { 316 | "cell_type": "markdown", 317 | "metadata": {}, 318 | "source": [ 319 | "### Prepare train data " 320 | ] 321 | }, 322 | { 323 | "cell_type": "markdown", 324 | "metadata": {}, 325 | "source": [ 326 | "window data example" 327 | ] 328 | }, 329 | { 330 | "cell_type": "markdown", 331 | "metadata": {}, 332 | "source": [ 333 | "\n", 334 | "
borrowed image from http://mccormickml.com/2016/04/19/word2vec-tutorial-the-skip-gram-model/
" 335 | ] 336 | }, 337 | { 338 | "cell_type": "code", 339 | "execution_count": 15, 340 | "metadata": { 341 | "collapsed": true 342 | }, 343 | "outputs": [], 344 | "source": [ 345 | "WINDOW_SIZE = 3\n", 346 | "windows = flatten([list(nltk.ngrams([''] * WINDOW_SIZE + c + [''] * WINDOW_SIZE, WINDOW_SIZE * 2 + 1)) for c in corpus])" 347 | ] 348 | }, 349 | { 350 | "cell_type": "code", 351 | "execution_count": 16, 352 | "metadata": { 353 | "collapsed": false 354 | }, 355 | "outputs": [ 356 | { 357 | "data": { 358 | "text/plain": [ 359 | "('', '', '', '[', 'moby', 'dick', 'by')" 360 | ] 361 | }, 362 | "execution_count": 16, 363 | "metadata": {}, 364 | "output_type": "execute_result" 365 | } 366 | ], 367 | "source": [ 368 | "windows[0]" 369 | ] 370 | }, 371 | { 372 | "cell_type": "code", 373 | "execution_count": 17, 374 | "metadata": { 375 | "collapsed": false 376 | }, 377 | "outputs": [ 378 | { 379 | "name": "stdout", 380 | "output_type": "stream", 381 | "text": [ 382 | "[('[', 'moby'), ('[', 'dick'), ('[', 'by'), ('moby', '['), ('moby', 'dick'), ('moby', 'by')]\n" 383 | ] 384 | } 385 | ], 386 | "source": [ 387 | "train_data = []\n", 388 | "\n", 389 | "for window in windows:\n", 390 | " for i in range(WINDOW_SIZE * 2 + 1):\n", 391 | " if i == WINDOW_SIZE or window[i] == '': \n", 392 | " continue\n", 393 | " train_data.append((window[WINDOW_SIZE], window[i]))\n", 394 | "\n", 395 | "print(train_data[:WINDOW_SIZE * 2])" 396 | ] 397 | }, 398 | { 399 | "cell_type": "code", 400 | "execution_count": 18, 401 | "metadata": { 402 | "collapsed": true 403 | }, 404 | "outputs": [], 405 | "source": [ 406 | "X_p = []\n", 407 | "y_p = []" 408 | ] 409 | }, 410 | { 411 | "cell_type": "code", 412 | "execution_count": 19, 413 | "metadata": { 414 | "collapsed": false 415 | }, 416 | "outputs": [ 417 | { 418 | "data": { 419 | "text/plain": [ 420 | "('[', 'moby')" 421 | ] 422 | }, 423 | "execution_count": 19, 424 | "metadata": {}, 425 | "output_type": "execute_result" 426 | } 427 | ], 428 | 
"source": [ 429 | "train_data[0]" 430 | ] 431 | }, 432 | { 433 | "cell_type": "code", 434 | "execution_count": 20, 435 | "metadata": { 436 | "collapsed": true 437 | }, 438 | "outputs": [], 439 | "source": [ 440 | "for tr in train_data:\n", 441 | " X_p.append(prepare_word(tr[0], word2index).view(1, -1))\n", 442 | " y_p.append(prepare_word(tr[1], word2index).view(1, -1))" 443 | ] 444 | }, 445 | { 446 | "cell_type": "code", 447 | "execution_count": 21, 448 | "metadata": { 449 | "collapsed": true 450 | }, 451 | "outputs": [], 452 | "source": [ 453 | "train_data = list(zip(X_p, y_p))" 454 | ] 455 | }, 456 | { 457 | "cell_type": "code", 458 | "execution_count": 22, 459 | "metadata": { 460 | "collapsed": false 461 | }, 462 | "outputs": [ 463 | { 464 | "data": { 465 | "text/plain": [ 466 | "7606" 467 | ] 468 | }, 469 | "execution_count": 22, 470 | "metadata": {}, 471 | "output_type": "execute_result" 472 | } 473 | ], 474 | "source": [ 475 | "len(train_data)" 476 | ] 477 | }, 478 | { 479 | "cell_type": "markdown", 480 | "metadata": {}, 481 | "source": [ 482 | "## Modeling" 483 | ] 484 | }, 485 | { 486 | "cell_type": "markdown", 487 | "metadata": {}, 488 | "source": [ 489 | "\n", 490 | "
borrowed image from http://web.stanford.edu/class/cs224n/lectures/cs224n-2017-lecture2.pdf
" 491 | ] 492 | }, 493 | { 494 | "cell_type": "code", 495 | "execution_count": 59, 496 | "metadata": { 497 | "collapsed": true 498 | }, 499 | "outputs": [], 500 | "source": [ 501 | "class Skipgram(nn.Module):\n", 502 | " \n", 503 | " def __init__(self, vocab_size, projection_dim):\n", 504 | " super(Skipgram,self).__init__()\n", 505 | " self.embedding_v = nn.Embedding(vocab_size, projection_dim)\n", 506 | " self.embedding_u = nn.Embedding(vocab_size, projection_dim)\n", 507 | "\n", 508 | " self.embedding_v.weight.data.uniform_(-1, 1) # init\n", 509 | " self.embedding_u.weight.data.uniform_(0, 0) # init\n", 510 | " #self.out = nn.Linear(projection_dim,vocab_size)\n", 511 | " def forward(self, center_words,target_words, outer_words):\n", 512 | " center_embeds = self.embedding_v(center_words) # B x 1 x D\n", 513 | " target_embeds = self.embedding_u(target_words) # B x 1 x D\n", 514 | " outer_embeds = self.embedding_u(outer_words) # B x V x D\n", 515 | " \n", 516 | " scores = target_embeds.bmm(center_embeds.transpose(1, 2)).squeeze(2) # Bx1xD * BxDx1 => Bx1\n", 517 | " norm_scores = outer_embeds.bmm(center_embeds.transpose(1, 2)).squeeze(2) # BxVxD * BxDx1 => BxV\n", 518 | " \n", 519 | " nll = -torch.mean(torch.log(torch.exp(scores)/torch.sum(torch.exp(norm_scores), 1).unsqueeze(1))) # log-softmax\n", 520 | " \n", 521 | " return nll # negative log likelihood\n", 522 | " \n", 523 | " def prediction(self, inputs):\n", 524 | " embeds = self.embedding_v(inputs)\n", 525 | " \n", 526 | " return embeds " 527 | ] 528 | }, 529 | { 530 | "cell_type": "markdown", 531 | "metadata": {}, 532 | "source": [ 533 | "## Train " 534 | ] 535 | }, 536 | { 537 | "cell_type": "code", 538 | "execution_count": 60, 539 | "metadata": { 540 | "collapsed": true 541 | }, 542 | "outputs": [], 543 | "source": [ 544 | "EMBEDDING_SIZE = 30\n", 545 | "BATCH_SIZE = 256\n", 546 | "EPOCH = 100" 547 | ] 548 | }, 549 | { 550 | "cell_type": "code", 551 | "execution_count": 61, 552 | "metadata": { 553 | 
"collapsed": true 554 | }, 555 | "outputs": [], 556 | "source": [ 557 | "losses = []\n", 558 | "model = Skipgram(len(word2index), EMBEDDING_SIZE)\n", 559 | "if USE_CUDA:\n", 560 | " model = model.cuda()\n", 561 | "optimizer = optim.Adam(model.parameters(), lr=0.01)" 562 | ] 563 | }, 564 | { 565 | "cell_type": "code", 566 | "execution_count": 62, 567 | "metadata": { 568 | "collapsed": false 569 | }, 570 | "outputs": [ 571 | { 572 | "name": "stdout", 573 | "output_type": "stream", 574 | "text": [ 575 | "Epoch : 0, mean_loss : 6.20\n", 576 | "Epoch : 10, mean_loss : 4.38\n", 577 | "Epoch : 20, mean_loss : 3.48\n", 578 | "Epoch : 30, mean_loss : 3.31\n", 579 | "Epoch : 40, mean_loss : 3.26\n", 580 | "Epoch : 50, mean_loss : 3.24\n", 581 | "Epoch : 60, mean_loss : 3.22\n", 582 | "Epoch : 70, mean_loss : 3.22\n", 583 | "Epoch : 80, mean_loss : 3.21\n", 584 | "Epoch : 90, mean_loss : 3.20\n" 585 | ] 586 | } 587 | ], 588 | "source": [ 589 | "for epoch in range(EPOCH):\n", 590 | " for i, batch in enumerate(getBatch(BATCH_SIZE, train_data)):\n", 591 | " \n", 592 | " inputs, targets = zip(*batch)\n", 593 | " \n", 594 | " inputs = torch.cat(inputs) # B x 1\n", 595 | " targets = torch.cat(targets) # B x 1\n", 596 | " vocabs = prepare_sequence(list(vocab), word2index).expand(inputs.size(0), len(vocab)) # B x V\n", 597 | " model.zero_grad()\n", 598 | "\n", 599 | " loss = model(inputs, targets, vocabs)\n", 600 | " \n", 601 | " loss.backward()\n", 602 | " optimizer.step()\n", 603 | " \n", 604 | " losses.append(loss.data.tolist()[0])\n", 605 | "\n", 606 | " if epoch % 10 == 0:\n", 607 | " print(\"Epoch : %d, mean_loss : %.02f\" % (epoch,np.mean(losses)))\n", 608 | " losses = []" 609 | ] 610 | }, 611 | { 612 | "cell_type": "markdown", 613 | "metadata": {}, 614 | "source": [ 615 | "## Test" 616 | ] 617 | }, 618 | { 619 | "cell_type": "code", 620 | "execution_count": 63, 621 | "metadata": { 622 | "collapsed": true 623 | }, 624 | "outputs": [], 625 | "source": [ 626 | "def 
word_similarity(target, vocab):\n", 627 | " if USE_CUDA:\n", 628 | " target_V = model.prediction(prepare_word(target, word2index))\n", 629 | " else:\n", 630 | " target_V = model.prediction(prepare_word(target, word2index))\n", 631 | " similarities = []\n", 632 | " for i in range(len(vocab)):\n", 633 | " if vocab[i] == target: continue\n", 634 | " \n", 635 | " if USE_CUDA:\n", 636 | " vector = model.prediction(prepare_word(list(vocab)[i], word2index))\n", 637 | " else:\n", 638 | " vector = model.prediction(prepare_word(list(vocab)[i], word2index))\n", 639 | " cosine_sim = F.cosine_similarity(target_V, vector).data.tolist()[0] \n", 640 | " similarities.append([vocab[i], cosine_sim])\n", 641 | " return sorted(similarities, key=lambda x: x[1], reverse=True)[:10] # sort by similarity" 642 | ] 643 | }, 644 | { 645 | "cell_type": "code", 646 | "execution_count": 64, 647 | "metadata": { 648 | "collapsed": false 649 | }, 650 | "outputs": [ 651 | { 652 | "data": { 653 | "text/plain": [ 654 | "'least'" 655 | ] 656 | }, 657 | "execution_count": 64, 658 | "metadata": {}, 659 | "output_type": "execute_result" 660 | } 661 | ], 662 | "source": [ 663 | "test = random.choice(list(vocab))\n", 664 | "test" 665 | ] 666 | }, 667 | { 668 | "cell_type": "code", 669 | "execution_count": 65, 670 | "metadata": { 671 | "collapsed": false 672 | }, 673 | "outputs": [ 674 | { 675 | "data": { 676 | "text/plain": [ 677 | "[['at', 0.8147411346435547],\n", 678 | " ['every', 0.7143548130989075],\n", 679 | " ['case', 0.6975079774856567],\n", 680 | " ['secure', 0.6121522188186646],\n", 681 | " ['heart', 0.5974172949790955],\n", 682 | " ['including', 0.5867112278938293],\n", 683 | " ['please', 0.5557640194892883],\n", 684 | " ['has', 0.5536234974861145],\n", 685 | " ['while', 0.5366998314857483],\n", 686 | " ['you', 0.509368896484375]]" 687 | ] 688 | }, 689 | "execution_count": 65, 690 | "metadata": {}, 691 | "output_type": "execute_result" 692 | } 693 | ], 694 | "source": [ 695 | "word_similarity(test, 
vocab)" 696 | ] 697 | }, 698 | { 699 | "cell_type": "code", 700 | "execution_count": null, 701 | "metadata": { 702 | "collapsed": true 703 | }, 704 | "outputs": [], 705 | "source": [] 706 | } 707 | ], 708 | "metadata": { 709 | "kernelspec": { 710 | "display_name": "Python 3", 711 | "language": "python", 712 | "name": "python3" 713 | }, 714 | "language_info": { 715 | "codemirror_mode": { 716 | "name": "ipython", 717 | "version": 3 718 | }, 719 | "file_extension": ".py", 720 | "mimetype": "text/x-python", 721 | "name": "python", 722 | "nbconvert_exporter": "python", 723 | "pygments_lexer": "ipython3", 724 | "version": "3.5.2" 725 | } 726 | }, 727 | "nbformat": 4, 728 | "nbformat_minor": 2 729 | } 730 | -------------------------------------------------------------------------------- /notebooks/02.Skip-gram-Negative-Sampling.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# 2. Skip-gram with negative sampling" 8 | ] 9 | }, 10 | { 11 | "cell_type": "markdown", 12 | "metadata": {}, 13 | "source": [ 14 | "I recommend you take a look at these materials first." 
15 | ] 16 | }, 17 | { 18 | "cell_type": "markdown", 19 | "metadata": {}, 20 | "source": [ 21 | "* http://web.stanford.edu/class/cs224n/lectures/cs224n-2017-lecture3.pdf\n", 22 | "* http://papers.nips.cc/paper/5021-distributed-representations-of-words-and-phrases-and-their-compositionality.pdf" 23 | ] 24 | }, 25 | { 26 | "cell_type": "code", 27 | "execution_count": 1, 28 | "metadata": { 29 | "collapsed": true 30 | }, 31 | "outputs": [], 32 | "source": [ 33 | "import torch\n", 34 | "import torch.nn as nn\n", 35 | "from torch.autograd import Variable\n", 36 | "import torch.optim as optim\n", 37 | "import torch.nn.functional as F\n", 38 | "import nltk\n", 39 | "import random\n", 40 | "import numpy as np\n", 41 | "from collections import Counter\n", 42 | "flatten = lambda l: [item for sublist in l for item in sublist]\n", 43 | "random.seed(1024)" 44 | ] 45 | }, 46 | { 47 | "cell_type": "code", 48 | "execution_count": 2, 49 | "metadata": { 50 | "collapsed": false 51 | }, 52 | "outputs": [ 53 | { 54 | "name": "stdout", 55 | "output_type": "stream", 56 | "text": [ 57 | "0.3.0.post4\n", 58 | "3.2.4\n" 59 | ] 60 | } 61 | ], 62 | "source": [ 63 | "print(torch.__version__)\n", 64 | "print(nltk.__version__)" 65 | ] 66 | }, 67 | { 68 | "cell_type": "code", 69 | "execution_count": 3, 70 | "metadata": { 71 | "collapsed": true 72 | }, 73 | "outputs": [], 74 | "source": [ 75 | "USE_CUDA = torch.cuda.is_available()\n", 76 | "gpus = [0]\n", 77 | "torch.cuda.set_device(gpus[0])\n", 78 | "\n", 79 | "FloatTensor = torch.cuda.FloatTensor if USE_CUDA else torch.FloatTensor\n", 80 | "LongTensor = torch.cuda.LongTensor if USE_CUDA else torch.LongTensor\n", 81 | "ByteTensor = torch.cuda.ByteTensor if USE_CUDA else torch.ByteTensor" 82 | ] 83 | }, 84 | { 85 | "cell_type": "code", 86 | "execution_count": 4, 87 | "metadata": { 88 | "collapsed": true 89 | }, 90 | "outputs": [], 91 | "source": [ 92 | "def getBatch(batch_size, train_data):\n", 93 | " random.shuffle(train_data)\n", 94 | " sindex = 
0\n", 95 | " eindex = batch_size\n", 96 | " while eindex < len(train_data):\n", 97 | " batch = train_data[sindex: eindex]\n", 98 | " temp = eindex\n", 99 | " eindex = eindex + batch_size\n", 100 | " sindex = temp\n", 101 | " yield batch\n", 102 | " \n", 103 | " if eindex >= len(train_data):\n", 104 | " batch = train_data[sindex:]\n", 105 | " yield batch" 106 | ] 107 | }, 108 | { 109 | "cell_type": "code", 110 | "execution_count": 5, 111 | "metadata": { 112 | "collapsed": true 113 | }, 114 | "outputs": [], 115 | "source": [ 116 | "def prepare_sequence(seq, word2index):\n", 117 | " idxs = list(map(lambda w: word2index[w] if word2index.get(w) is not None else word2index[\"\"], seq))\n", 118 | " return Variable(LongTensor(idxs))\n", 119 | "\n", 120 | "def prepare_word(word, word2index):\n", 121 | " return Variable(LongTensor([word2index[word]]) if word2index.get(word) is not None else LongTensor([word2index[\"\"]]))" 122 | ] 123 | }, 124 | { 125 | "cell_type": "markdown", 126 | "metadata": {}, 127 | "source": [ 128 | "## Data load and Preprocessing " 129 | ] 130 | }, 131 | { 132 | "cell_type": "code", 133 | "execution_count": 6, 134 | "metadata": { 135 | "collapsed": true 136 | }, 137 | "outputs": [], 138 | "source": [ 139 | "corpus = list(nltk.corpus.gutenberg.sents('melville-moby_dick.txt'))[:500]\n", 140 | "corpus = [[word.lower() for word in sent] for sent in corpus]" 141 | ] 142 | }, 143 | { 144 | "cell_type": "markdown", 145 | "metadata": {}, 146 | "source": [ 147 | "### Exclude sparse words " 148 | ] 149 | }, 150 | { 151 | "cell_type": "code", 152 | "execution_count": 7, 153 | "metadata": { 154 | "collapsed": true 155 | }, 156 | "outputs": [], 157 | "source": [ 158 | "word_count = Counter(flatten(corpus))" 159 | ] 160 | }, 161 | { 162 | "cell_type": "code", 163 | "execution_count": 8, 164 | "metadata": { 165 | "collapsed": true 166 | }, 167 | "outputs": [], 168 | "source": [ 169 | "MIN_COUNT = 3\n", 170 | "exclude = []" 171 | ] 172 | }, 173 | { 174 | 
"cell_type": "code", 175 | "execution_count": 9, 176 | "metadata": { 177 | "collapsed": true 178 | }, 179 | "outputs": [], 180 | "source": [ 181 | "for w, c in word_count.items():\n", 182 | " if c < MIN_COUNT:\n", 183 | " exclude.append(w)" 184 | ] 185 | }, 186 | { 187 | "cell_type": "markdown", 188 | "metadata": {}, 189 | "source": [ 190 | "### Prepare train data " 191 | ] 192 | }, 193 | { 194 | "cell_type": "code", 195 | "execution_count": 10, 196 | "metadata": { 197 | "collapsed": true 198 | }, 199 | "outputs": [], 200 | "source": [ 201 | "vocab = list(set(flatten(corpus)) - set(exclude))" 202 | ] 203 | }, 204 | { 205 | "cell_type": "code", 206 | "execution_count": 11, 207 | "metadata": { 208 | "collapsed": true 209 | }, 210 | "outputs": [], 211 | "source": [ 212 | "word2index = {}\n", 213 | "for vo in vocab:\n", 214 | " if word2index.get(vo) is None:\n", 215 | " word2index[vo] = len(word2index)\n", 216 | " \n", 217 | "index2word = {v:k for k, v in word2index.items()}" 218 | ] 219 | }, 220 | { 221 | "cell_type": "code", 222 | "execution_count": 12, 223 | "metadata": { 224 | "collapsed": true 225 | }, 226 | "outputs": [], 227 | "source": [ 228 | "WINDOW_SIZE = 5\n", 229 | "windows = flatten([list(nltk.ngrams([''] * WINDOW_SIZE + c + [''] * WINDOW_SIZE, WINDOW_SIZE * 2 + 1)) for c in corpus])\n", 230 | "\n", 231 | "train_data = []\n", 232 | "\n", 233 | "for window in windows:\n", 234 | " for i in range(WINDOW_SIZE * 2 + 1):\n", 235 | " if window[i] in exclude or window[WINDOW_SIZE] in exclude: \n", 236 | " continue # min_count\n", 237 | " if i == WINDOW_SIZE or window[i] == '': \n", 238 | " continue\n", 239 | " train_data.append((window[WINDOW_SIZE], window[i]))\n", 240 | "\n", 241 | "X_p = []\n", 242 | "y_p = []\n", 243 | "\n", 244 | "for tr in train_data:\n", 245 | " X_p.append(prepare_word(tr[0], word2index).view(1, -1))\n", 246 | " y_p.append(prepare_word(tr[1], word2index).view(1, -1))\n", 247 | " \n", 248 | "train_data = list(zip(X_p, y_p))" 249 | ] 250 | }, 
251 | { 252 | "cell_type": "code", 253 | "execution_count": 13, 254 | "metadata": { 255 | "collapsed": false 256 | }, 257 | "outputs": [ 258 | { 259 | "data": { 260 | "text/plain": [ 261 | "50242" 262 | ] 263 | }, 264 | "execution_count": 13, 265 | "metadata": {}, 266 | "output_type": "execute_result" 267 | } 268 | ], 269 | "source": [ 270 | "len(train_data)" 271 | ] 272 | }, 273 | { 274 | "cell_type": "markdown", 275 | "metadata": {}, 276 | "source": [ 277 | "### Build Unigram Distribution**0.75 " 278 | ] 279 | }, 280 | { 281 | "cell_type": "markdown", 282 | "metadata": {}, 283 | "source": [ 284 | "$$P(w)=U(w)^{3/4}/Z$$" 285 | ] 286 | }, 287 | { 288 | "cell_type": "code", 289 | "execution_count": 14, 290 | "metadata": { 291 | "collapsed": true 292 | }, 293 | "outputs": [], 294 | "source": [ 295 | "Z = 0.001" 296 | ] 297 | }, 298 | { 299 | "cell_type": "code", 300 | "execution_count": 15, 301 | "metadata": { 302 | "collapsed": true 303 | }, 304 | "outputs": [], 305 | "source": [ 306 | "word_count = Counter(flatten(corpus))\n", 307 | "num_total_words = sum([c for w, c in word_count.items() if w not in exclude])" 308 | ] 309 | }, 310 | { 311 | "cell_type": "code", 312 | "execution_count": 16, 313 | "metadata": { 314 | "collapsed": true 315 | }, 316 | "outputs": [], 317 | "source": [ 318 | "unigram_table = []\n", 319 | "\n", 320 | "for vo in vocab:\n", 321 | " unigram_table.extend([vo] * int(((word_count[vo]/num_total_words)**0.75)/Z))" 322 | ] 323 | }, 324 | { 325 | "cell_type": "code", 326 | "execution_count": 17, 327 | "metadata": { 328 | "collapsed": false 329 | }, 330 | "outputs": [ 331 | { 332 | "name": "stdout", 333 | "output_type": "stream", 334 | "text": [ 335 | "478 3500\n" 336 | ] 337 | } 338 | ], 339 | "source": [ 340 | "print(len(vocab), len(unigram_table))" 341 | ] 342 | }, 343 | { 344 | "cell_type": "markdown", 345 | "metadata": {}, 346 | "source": [ 347 | "### Negative Sampling " 348 | ] 349 | }, 350 | { 351 | "cell_type": "code", 352 | 
"execution_count": 18, 353 | "metadata": { 354 | "collapsed": true 355 | }, 356 | "outputs": [], 357 | "source": [ 358 | "def negative_sampling(targets, unigram_table, k):\n", 359 | " batch_size = targets.size(0)\n", 360 | " neg_samples = []\n", 361 | " for i in range(batch_size):\n", 362 | " nsample = []\n", 363 | " target_index = targets[i].data.cpu().tolist()[0] if USE_CUDA else targets[i].data.tolist()[0]\n", 364 | " while len(nsample) < k: # num of sampling\n", 365 | " neg = random.choice(unigram_table)\n", 366 | " if word2index[neg] == target_index:\n", 367 | " continue\n", 368 | " nsample.append(neg)\n", 369 | " neg_samples.append(prepare_sequence(nsample, word2index).view(1, -1))\n", 370 | " \n", 371 | " return torch.cat(neg_samples)" 372 | ] 373 | }, 374 | { 375 | "cell_type": "markdown", 376 | "metadata": {}, 377 | "source": [ 378 | "## Modeling " 379 | ] 380 | }, 381 | { 382 | "cell_type": "markdown", 383 | "metadata": {}, 384 | "source": [ 385 | "\n", 386 | "
borrowed image from http://web.stanford.edu/class/cs224n/lectures/cs224n-2017-lecture3.pdf
" 387 | ] 388 | }, 389 | { 390 | "cell_type": "code", 391 | "execution_count": 19, 392 | "metadata": { 393 | "collapsed": true 394 | }, 395 | "outputs": [], 396 | "source": [ 397 | "class SkipgramNegSampling(nn.Module):\n", 398 | " \n", 399 | " def __init__(self, vocab_size, projection_dim):\n", 400 | " super(SkipgramNegSampling, self).__init__()\n", 401 | " self.embedding_v = nn.Embedding(vocab_size, projection_dim) # center embedding\n", 402 | " self.embedding_u = nn.Embedding(vocab_size, projection_dim) # out embedding\n", 403 | " self.logsigmoid = nn.LogSigmoid()\n", 404 | " \n", 405 | " initrange = (2.0 / (vocab_size + projection_dim))**0.5 # Xavier init\n", 406 | " self.embedding_v.weight.data.uniform_(-initrange, initrange) # init\n", 407 | " self.embedding_u.weight.data.uniform_(-0.0, 0.0) # init\n", 408 | " \n", 409 | " def forward(self, center_words, target_words, negative_words):\n", 410 | " center_embeds = self.embedding_v(center_words) # B x 1 x D\n", 411 | " target_embeds = self.embedding_u(target_words) # B x 1 x D\n", 412 | " \n", 413 | " neg_embeds = -self.embedding_u(negative_words) # B x K x D\n", 414 | " \n", 415 | " positive_score = target_embeds.bmm(center_embeds.transpose(1, 2)).squeeze(2) # Bx1\n", 416 | " negative_score = torch.sum(neg_embeds.bmm(center_embeds.transpose(1, 2)).squeeze(2), 1).view(negs.size(0), -1) # BxK -> Bx1\n", 417 | " \n", 418 | " loss = self.logsigmoid(positive_score) + self.logsigmoid(negative_score)\n", 419 | " \n", 420 | " return -torch.mean(loss)\n", 421 | " \n", 422 | " def prediction(self, inputs):\n", 423 | " embeds = self.embedding_v(inputs)\n", 424 | " \n", 425 | " return embeds" 426 | ] 427 | }, 428 | { 429 | "cell_type": "markdown", 430 | "metadata": {}, 431 | "source": [ 432 | "## Train " 433 | ] 434 | }, 435 | { 436 | "cell_type": "code", 437 | "execution_count": 68, 438 | "metadata": { 439 | "collapsed": true 440 | }, 441 | "outputs": [], 442 | "source": [ 443 | "EMBEDDING_SIZE = 30 \n", 444 | "BATCH_SIZE 
= 256\n", 445 | "EPOCH = 100\n", 446 | "NEG = 10 # Num of Negative Sampling" 447 | ] 448 | }, 449 | { 450 | "cell_type": "code", 451 | "execution_count": 69, 452 | "metadata": { 453 | "collapsed": true 454 | }, 455 | "outputs": [], 456 | "source": [ 457 | "losses = []\n", 458 | "model = SkipgramNegSampling(len(word2index), EMBEDDING_SIZE)\n", 459 | "if USE_CUDA:\n", 460 | " model = model.cuda()\n", 461 | "optimizer = optim.Adam(model.parameters(), lr=0.001)" 462 | ] 463 | }, 464 | { 465 | "cell_type": "code", 466 | "execution_count": 70, 467 | "metadata": { 468 | "collapsed": false 469 | }, 470 | "outputs": [ 471 | { 472 | "name": "stdout", 473 | "output_type": "stream", 474 | "text": [ 475 | "Epoch : 0, mean_loss : 1.06\n", 476 | "Epoch : 10, mean_loss : 0.86\n", 477 | "Epoch : 20, mean_loss : 0.79\n", 478 | "Epoch : 30, mean_loss : 0.74\n", 479 | "Epoch : 40, mean_loss : 0.71\n", 480 | "Epoch : 50, mean_loss : 0.69\n", 481 | "Epoch : 60, mean_loss : 0.67\n", 482 | "Epoch : 70, mean_loss : 0.65\n", 483 | "Epoch : 80, mean_loss : 0.64\n", 484 | "Epoch : 90, mean_loss : 0.63\n" 485 | ] 486 | } 487 | ], 488 | "source": [ 489 | "for epoch in range(EPOCH):\n", 490 | " for i,batch in enumerate(getBatch(BATCH_SIZE, train_data)):\n", 491 | " \n", 492 | " inputs, targets = zip(*batch)\n", 493 | " \n", 494 | " inputs = torch.cat(inputs) # B x 1\n", 495 | " targets = torch.cat(targets) # B x 1\n", 496 | " negs = negative_sampling(targets, unigram_table, NEG)\n", 497 | " model.zero_grad()\n", 498 | "\n", 499 | " loss = model(inputs, targets, negs)\n", 500 | " \n", 501 | " loss.backward()\n", 502 | " optimizer.step()\n", 503 | " \n", 504 | " losses.append(loss.data.tolist()[0])\n", 505 | " if epoch % 10 == 0:\n", 506 | " print(\"Epoch : %d, mean_loss : %.02f\" % (epoch, np.mean(losses)))\n", 507 | " losses = []" 508 | ] 509 | }, 510 | { 511 | "cell_type": "markdown", 512 | "metadata": {}, 513 | "source": [ 514 | "## Test " 515 | ] 516 | }, 517 | { 518 | "cell_type": "code", 
519 | "execution_count": 71, 520 | "metadata": { 521 | "collapsed": true 522 | }, 523 | "outputs": [], 524 | "source": [ 525 | "def word_similarity(target, vocab):\n", 526 | " if USE_CUDA:\n", 527 | " target_V = model.prediction(prepare_word(target, word2index))\n", 528 | " else:\n", 529 | " target_V = model.prediction(prepare_word(target, word2index))\n", 530 | " similarities = []\n", 531 | " for i in range(len(vocab)):\n", 532 | " if vocab[i] == target: \n", 533 | " continue\n", 534 | " \n", 535 | " if USE_CUDA:\n", 536 | " vector = model.prediction(prepare_word(list(vocab)[i], word2index))\n", 537 | " else:\n", 538 | " vector = model.prediction(prepare_word(list(vocab)[i], word2index))\n", 539 | " \n", 540 | " cosine_sim = F.cosine_similarity(target_V, vector).data.tolist()[0]\n", 541 | " similarities.append([vocab[i], cosine_sim])\n", 542 | " return sorted(similarities, key=lambda x: x[1], reverse=True)[:10]" 543 | ] 544 | }, 545 | { 546 | "cell_type": "code", 547 | "execution_count": 212, 548 | "metadata": { 549 | "collapsed": false 550 | }, 551 | "outputs": [ 552 | { 553 | "data": { 554 | "text/plain": [ 555 | "'passengers'" 556 | ] 557 | }, 558 | "execution_count": 212, 559 | "metadata": {}, 560 | "output_type": "execute_result" 561 | } 562 | ], 563 | "source": [ 564 | "test = random.choice(list(vocab))\n", 565 | "test" 566 | ] 567 | }, 568 | { 569 | "cell_type": "code", 570 | "execution_count": 213, 571 | "metadata": { 572 | "collapsed": false 573 | }, 574 | "outputs": [ 575 | { 576 | "data": { 577 | "text/plain": [ 578 | "[['am', 0.7353377342224121],\n", 579 | " ['passenger', 0.7154150605201721],\n", 580 | " ['cook', 0.6829826831817627],\n", 581 | " ['new', 0.6648461818695068],\n", 582 | " ['bedford', 0.6283411383628845],\n", 583 | " ['besides', 0.5972960591316223],\n", 584 | " ['themselves', 0.5964340567588806],\n", 585 | " ['grow', 0.5957046151161194],\n", 586 | " ['tell', 0.5952941179275513],\n", 587 | " ['get', 0.5943044424057007]]" 588 | ] 589 | }, 
590 | "execution_count": 213, 591 | "metadata": {}, 592 | "output_type": "execute_result" 593 | } 594 | ], 595 | "source": [ 596 | "word_similarity(test, vocab)" 597 | ] 598 | }, 599 | { 600 | "cell_type": "code", 601 | "execution_count": null, 602 | "metadata": { 603 | "collapsed": true 604 | }, 605 | "outputs": [], 606 | "source": [] 607 | } 608 | ], 609 | "metadata": { 610 | "kernelspec": { 611 | "display_name": "Python 3", 612 | "language": "python", 613 | "name": "python3" 614 | }, 615 | "language_info": { 616 | "codemirror_mode": { 617 | "name": "ipython", 618 | "version": 3 619 | }, 620 | "file_extension": ".py", 621 | "mimetype": "text/x-python", 622 | "name": "python", 623 | "nbconvert_exporter": "python", 624 | "pygments_lexer": "ipython3", 625 | "version": "3.5.2" 626 | } 627 | }, 628 | "nbformat": 4, 629 | "nbformat_minor": 2 630 | } 631 | -------------------------------------------------------------------------------- /notebooks/03.GloVe.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# 3. GloVe: Global Vectors for Word Representation" 8 | ] 9 | }, 10 | { 11 | "cell_type": "markdown", 12 | "metadata": {}, 13 | "source": [ 14 | "I recommend you take a look at these materials first." 
15 | ] 16 | }, 17 | { 18 | "cell_type": "markdown", 19 | "metadata": {}, 20 | "source": [ 21 | "* http://web.stanford.edu/class/cs224n/lectures/cs224n-2017-lecture3.pdf\n", 22 | "* https://nlp.stanford.edu/pubs/glove.pdf" 23 | ] 24 | }, 25 | { 26 | "cell_type": "code", 27 | "execution_count": 1, 28 | "metadata": { 29 | "collapsed": true 30 | }, 31 | "outputs": [], 32 | "source": [ 33 | "import torch\n", 34 | "import torch.nn as nn\n", 35 | "from torch.autograd import Variable\n", 36 | "import torch.optim as optim\n", 37 | "import torch.nn.functional as F\n", 38 | "import nltk\n", 39 | "import random\n", 40 | "import numpy as np\n", 41 | "from collections import Counter\n", 42 | "flatten = lambda l: [item for sublist in l for item in sublist]\n", 43 | "random.seed(1024)" 44 | ] 45 | }, 46 | { 47 | "cell_type": "code", 48 | "execution_count": 2, 49 | "metadata": { 50 | "collapsed": false 51 | }, 52 | "outputs": [ 53 | { 54 | "name": "stdout", 55 | "output_type": "stream", 56 | "text": [ 57 | "0.3.0.post4\n", 58 | "3.2.4\n" 59 | ] 60 | } 61 | ], 62 | "source": [ 63 | "print(torch.__version__)\n", 64 | "print(nltk.__version__)" 65 | ] 66 | }, 67 | { 68 | "cell_type": "code", 69 | "execution_count": 3, 70 | "metadata": { 71 | "collapsed": true 72 | }, 73 | "outputs": [], 74 | "source": [ 75 | "USE_CUDA = torch.cuda.is_available()\n", 76 | "gpus = [0]\n", 77 | "torch.cuda.set_device(gpus[0])\n", 78 | "\n", 79 | "FloatTensor = torch.cuda.FloatTensor if USE_CUDA else torch.FloatTensor\n", 80 | "LongTensor = torch.cuda.LongTensor if USE_CUDA else torch.LongTensor\n", 81 | "ByteTensor = torch.cuda.ByteTensor if USE_CUDA else torch.ByteTensor" 82 | ] 83 | }, 84 | { 85 | "cell_type": "code", 86 | "execution_count": 4, 87 | "metadata": { 88 | "collapsed": true 89 | }, 90 | "outputs": [], 91 | "source": [ 92 | "def getBatch(batch_size, train_data):\n", 93 | " random.shuffle(train_data)\n", 94 | " sindex = 0\n", 95 | " eindex = batch_size\n", 96 | " while eindex < 
len(train_data):\n", 97 | " batch = train_data[sindex:eindex]\n", 98 | " temp = eindex\n", 99 | " eindex = eindex + batch_size\n", 100 | " sindex = temp\n", 101 | " yield batch\n", 102 | " \n", 103 | " if eindex >= len(train_data):\n", 104 | " batch = train_data[sindex:]\n", 105 | " yield batch" 106 | ] 107 | }, 108 | { 109 | "cell_type": "code", 110 | "execution_count": 5, 111 | "metadata": { 112 | "collapsed": true 113 | }, 114 | "outputs": [], 115 | "source": [ 116 | "def prepare_sequence(seq, word2index):\n", 117 | " idxs = list(map(lambda w: word2index[w] if word2index.get(w) is not None else word2index[\"\"], seq))\n", 118 | " return Variable(LongTensor(idxs))\n", 119 | "\n", 120 | "def prepare_word(word, word2index):\n", 121 | " return Variable(LongTensor([word2index[word]]) if word2index.get(word) is not None else LongTensor([word2index[\"\"]]))" 122 | ] 123 | }, 124 | { 125 | "cell_type": "markdown", 126 | "metadata": {}, 127 | "source": [ 128 | "## Data load and Preprocessing " 129 | ] 130 | }, 131 | { 132 | "cell_type": "code", 133 | "execution_count": 6, 134 | "metadata": { 135 | "collapsed": true 136 | }, 137 | "outputs": [], 138 | "source": [ 139 | "corpus = list(nltk.corpus.gutenberg.sents('melville-moby_dick.txt'))[:500]\n", 140 | "corpus = [[word.lower() for word in sent] for sent in corpus]" 141 | ] 142 | }, 143 | { 144 | "cell_type": "markdown", 145 | "metadata": {}, 146 | "source": [ 147 | "### Build vocab" 148 | ] 149 | }, 150 | { 151 | "cell_type": "code", 152 | "execution_count": 7, 153 | "metadata": { 154 | "collapsed": true 155 | }, 156 | "outputs": [], 157 | "source": [ 158 | "vocab = list(set(flatten(corpus)))" 159 | ] 160 | }, 161 | { 162 | "cell_type": "code", 163 | "execution_count": 8, 164 | "metadata": { 165 | "collapsed": true 166 | }, 167 | "outputs": [], 168 | "source": [ 169 | "word2index = {}\n", 170 | "for vo in vocab:\n", 171 | " if word2index.get(vo) is None:\n", 172 | " word2index[vo] = len(word2index)\n", 173 | " \n", 174 | 
"index2word={v:k for k, v in word2index.items()}" 175 | ] 176 | }, 177 | { 178 | "cell_type": "code", 179 | "execution_count": 9, 180 | "metadata": { 181 | "collapsed": true 182 | }, 183 | "outputs": [], 184 | "source": [ 185 | "WINDOW_SIZE = 5\n", 186 | "windows = flatten([list(nltk.ngrams([''] * WINDOW_SIZE + c + [''] * WINDOW_SIZE, WINDOW_SIZE * 2 + 1)) for c in corpus])\n", 187 | "\n", 188 | "window_data = []\n", 189 | "\n", 190 | "for window in windows:\n", 191 | " for i in range(WINDOW_SIZE * 2 + 1):\n", 192 | " if i == WINDOW_SIZE or window[i] == '': \n", 193 | " continue\n", 194 | " window_data.append((window[WINDOW_SIZE], window[i]))\n" 195 | ] 196 | }, 197 | { 198 | "cell_type": "markdown", 199 | "metadata": {}, 200 | "source": [ 201 | "### Weighting Function " 202 | ] 203 | }, 204 | { 205 | "cell_type": "markdown", 206 | "metadata": {}, 207 | "source": [ 208 | "\n", 209 | "
borrowed image from https://nlp.stanford.edu/pubs/glove.pdf
" 210 | ] 211 | }, 212 | { 213 | "cell_type": "code", 214 | "execution_count": 10, 215 | "metadata": { 216 | "collapsed": true 217 | }, 218 | "outputs": [], 219 | "source": [ 220 | "def weighting(w_i, w_j):\n", 221 | " try:\n", 222 | " x_ij = X_ik[(w_i, w_j)]\n", 223 | " except:\n", 224 | " x_ij = 1\n", 225 | " \n", 226 | " x_max = 100 #100 # fixed in paper\n", 227 | " alpha = 0.75\n", 228 | " \n", 229 | " if x_ij < x_max:\n", 230 | " result = (x_ij/x_max)**alpha\n", 231 | " else:\n", 232 | " result = 1\n", 233 | " \n", 234 | " return result" 235 | ] 236 | }, 237 | { 238 | "cell_type": "markdown", 239 | "metadata": {}, 240 | "source": [ 241 | "### Build Co-occurence Matrix X" 242 | ] 243 | }, 244 | { 245 | "cell_type": "markdown", 246 | "metadata": {}, 247 | "source": [ 248 | "Because of model complexity, It is important to determine whether a tighter bound can be placed on the number of nonzero elements of X." 249 | ] 250 | }, 251 | { 252 | "cell_type": "code", 253 | "execution_count": 11, 254 | "metadata": { 255 | "collapsed": true 256 | }, 257 | "outputs": [], 258 | "source": [ 259 | "X_i = Counter(flatten(corpus)) # X_i" 260 | ] 261 | }, 262 | { 263 | "cell_type": "code", 264 | "execution_count": 12, 265 | "metadata": { 266 | "collapsed": true 267 | }, 268 | "outputs": [], 269 | "source": [ 270 | "X_ik_window_5 = Counter(window_data) # Co-occurece in window size 5" 271 | ] 272 | }, 273 | { 274 | "cell_type": "code", 275 | "execution_count": 13, 276 | "metadata": { 277 | "collapsed": true 278 | }, 279 | "outputs": [], 280 | "source": [ 281 | "X_ik = {}\n", 282 | "weighting_dic = {}" 283 | ] 284 | }, 285 | { 286 | "cell_type": "code", 287 | "execution_count": 14, 288 | "metadata": { 289 | "collapsed": true 290 | }, 291 | "outputs": [], 292 | "source": [ 293 | "from itertools import combinations_with_replacement" 294 | ] 295 | }, 296 | { 297 | "cell_type": "code", 298 | "execution_count": 15, 299 | "metadata": { 300 | "collapsed": true 301 | }, 302 | "outputs": 
[], 303 | "source": [ 304 | "for bigram in combinations_with_replacement(vocab, 2):\n", 305 | " if X_ik_window_5.get(bigram) is not None: # nonzero elements\n", 306 | " co_occer = X_ik_window_5[bigram]\n", 307 | " X_ik[bigram] = co_occer + 1 # log(Xik) -> log(Xik+1) to prevent divergence\n", 308 | " X_ik[(bigram[1],bigram[0])] = co_occer+1\n", 309 | " else:\n", 310 | " pass\n", 311 | " \n", 312 | " weighting_dic[bigram] = weighting(bigram[0], bigram[1])\n", 313 | " weighting_dic[(bigram[1], bigram[0])] = weighting(bigram[1], bigram[0])" 314 | ] 315 | }, 316 | { 317 | "cell_type": "code", 318 | "execution_count": 16, 319 | "metadata": { 320 | "collapsed": false 321 | }, 322 | "outputs": [ 323 | { 324 | "name": "stdout", 325 | "output_type": "stream", 326 | "text": [ 327 | "(',', 'was')\n", 328 | "True\n" 329 | ] 330 | } 331 | ], 332 | "source": [ 333 | "test = random.choice(window_data)\n", 334 | "print(test)\n", 335 | "try:\n", 336 | " print(X_ik[(test[0], test[1])] == X_ik[(test[1], test[0])])\n", 337 | "except:\n", 338 | " 1" 339 | ] 340 | }, 341 | { 342 | "cell_type": "markdown", 343 | "metadata": {}, 344 | "source": [ 345 | "### Prepare train data" 346 | ] 347 | }, 348 | { 349 | "cell_type": "code", 350 | "execution_count": 17, 351 | "metadata": { 352 | "collapsed": false 353 | }, 354 | "outputs": [ 355 | { 356 | "name": "stdout", 357 | "output_type": "stream", 358 | "text": [ 359 | "(Variable containing:\n", 360 | " 703\n", 361 | "[torch.cuda.LongTensor of size 1x1 (GPU 0)]\n", 362 | ", Variable containing:\n", 363 | " 23\n", 364 | "[torch.cuda.LongTensor of size 1x1 (GPU 0)]\n", 365 | ", Variable containing:\n", 366 | " 0.6931\n", 367 | "[torch.cuda.FloatTensor of size 1x1 (GPU 0)]\n", 368 | ", Variable containing:\n", 369 | "1.00000e-02 *\n", 370 | " 5.3183\n", 371 | "[torch.cuda.FloatTensor of size 1x1 (GPU 0)]\n", 372 | ")\n" 373 | ] 374 | } 375 | ], 376 | "source": [ 377 | "u_p = [] # center vec\n", 378 | "v_p = [] # context vec\n", 379 | "co_p = [] # 
log(x_ij)\n", 380 | "weight_p = [] # f(x_ij)\n", 381 | "\n", 382 | "for pair in window_data: \n", 383 | " u_p.append(prepare_word(pair[0], word2index).view(1, -1))\n", 384 | " v_p.append(prepare_word(pair[1], word2index).view(1, -1))\n", 385 | " \n", 386 | " try:\n", 387 | " cooc = X_ik[pair]\n", 388 | " except:\n", 389 | " cooc = 1\n", 390 | "\n", 391 | " co_p.append(torch.log(Variable(FloatTensor([cooc]))).view(1, -1))\n", 392 | " weight_p.append(Variable(FloatTensor([weighting_dic[pair]])).view(1, -1))\n", 393 | " \n", 394 | "train_data = list(zip(u_p, v_p, co_p, weight_p))\n", 395 | "del u_p\n", 396 | "del v_p\n", 397 | "del co_p\n", 398 | "del weight_p\n", 399 | "print(train_data[0]) # tuple (center vec i, context vec j log(x_ij), weight f(w_ij))" 400 | ] 401 | }, 402 | { 403 | "cell_type": "markdown", 404 | "metadata": {}, 405 | "source": [ 406 | "## Modeling " 407 | ] 408 | }, 409 | { 410 | "cell_type": "markdown", 411 | "metadata": {}, 412 | "source": [ 413 | "\n", 414 | "
borrowed image from https://nlp.stanford.edu/pubs/glove.pdf
" 415 | ] 416 | }, 417 | { 418 | "cell_type": "code", 419 | "execution_count": 19, 420 | "metadata": { 421 | "collapsed": true 422 | }, 423 | "outputs": [], 424 | "source": [ 425 | "class GloVe(nn.Module):\n", 426 | " \n", 427 | " def __init__(self, vocab_size,projection_dim):\n", 428 | " super(GloVe,self).__init__()\n", 429 | " self.embedding_v = nn.Embedding(vocab_size, projection_dim) # center embedding\n", 430 | " self.embedding_u = nn.Embedding(vocab_size, projection_dim) # out embedding\n", 431 | " \n", 432 | " self.v_bias = nn.Embedding(vocab_size, 1)\n", 433 | " self.u_bias = nn.Embedding(vocab_size, 1)\n", 434 | " \n", 435 | " initrange = (2.0 / (vocab_size + projection_dim))**0.5 # Xavier init\n", 436 | " self.embedding_v.weight.data.uniform_(-initrange, initrange) # init\n", 437 | " self.embedding_u.weight.data.uniform_(-initrange, initrange) # init\n", 438 | " self.v_bias.weight.data.uniform_(-initrange, initrange) # init\n", 439 | " self.u_bias.weight.data.uniform_(-initrange, initrange) # init\n", 440 | " \n", 441 | " def forward(self, center_words, target_words, coocs, weights):\n", 442 | " center_embeds = self.embedding_v(center_words) # B x 1 x D\n", 443 | " target_embeds = self.embedding_u(target_words) # B x 1 x D\n", 444 | " \n", 445 | " center_bias = self.v_bias(center_words).squeeze(1)\n", 446 | " target_bias = self.u_bias(target_words).squeeze(1)\n", 447 | " \n", 448 | " inner_product = target_embeds.bmm(center_embeds.transpose(1, 2)).squeeze(2) # Bx1\n", 449 | " \n", 450 | " loss = weights*torch.pow(inner_product +center_bias + target_bias - coocs, 2)\n", 451 | " \n", 452 | " return torch.sum(loss)\n", 453 | " \n", 454 | " def prediction(self, inputs):\n", 455 | " v_embeds = self.embedding_v(inputs) # B x 1 x D\n", 456 | " u_embeds = self.embedding_u(inputs) # B x 1 x D\n", 457 | " \n", 458 | " return v_embeds+u_embeds # final embed" 459 | ] 460 | }, 461 | { 462 | "cell_type": "markdown", 463 | "metadata": {}, 464 | "source": [ 465 | "## 
Train " 466 | ] 467 | }, 468 | { 469 | "cell_type": "code", 470 | "execution_count": 22, 471 | "metadata": { 472 | "collapsed": true 473 | }, 474 | "outputs": [], 475 | "source": [ 476 | "EMBEDDING_SIZE = 50\n", 477 | "BATCH_SIZE = 256\n", 478 | "EPOCH = 50" 479 | ] 480 | }, 481 | { 482 | "cell_type": "code", 483 | "execution_count": 23, 484 | "metadata": { 485 | "collapsed": true 486 | }, 487 | "outputs": [], 488 | "source": [ 489 | "losses = []\n", 490 | "model = GloVe(len(word2index), EMBEDDING_SIZE)\n", 491 | "if USE_CUDA:\n", 492 | " model = model.cuda()\n", 493 | "optimizer = optim.Adam(model.parameters(), lr=0.001)" 494 | ] 495 | }, 496 | { 497 | "cell_type": "code", 498 | "execution_count": 24, 499 | "metadata": { 500 | "collapsed": false 501 | }, 502 | "outputs": [ 503 | { 504 | "name": "stdout", 505 | "output_type": "stream", 506 | "text": [ 507 | "Epoch : 0, mean_loss : 236.10\n", 508 | "Epoch : 10, mean_loss : 2.27\n", 509 | "Epoch : 20, mean_loss : 0.53\n", 510 | "Epoch : 30, mean_loss : 0.12\n", 511 | "Epoch : 40, mean_loss : 0.04\n" 512 | ] 513 | } 514 | ], 515 | "source": [ 516 | "for epoch in range(EPOCH):\n", 517 | " for i,batch in enumerate(getBatch(BATCH_SIZE, train_data)):\n", 518 | " \n", 519 | " inputs, targets, coocs, weights = zip(*batch)\n", 520 | " \n", 521 | " inputs = torch.cat(inputs) # B x 1\n", 522 | " targets = torch.cat(targets) # B x 1\n", 523 | " coocs = torch.cat(coocs)\n", 524 | " weights = torch.cat(weights)\n", 525 | " model.zero_grad()\n", 526 | "\n", 527 | " loss = model(inputs, targets, coocs, weights)\n", 528 | " \n", 529 | " loss.backward()\n", 530 | " optimizer.step()\n", 531 | " \n", 532 | " losses.append(loss.data.tolist()[0])\n", 533 | " if epoch % 10 == 0:\n", 534 | " print(\"Epoch : %d, mean_loss : %.02f\" % (epoch, np.mean(losses)))\n", 535 | " losses = []" 536 | ] 537 | }, 538 | { 539 | "cell_type": "markdown", 540 | "metadata": {}, 541 | "source": [ 542 | "## Test " 543 | ] 544 | }, 545 | { 546 | "cell_type": 
"code", 547 | "execution_count": 25, 548 | "metadata": { 549 | "collapsed": true 550 | }, 551 | "outputs": [], 552 | "source": [ 553 | "def word_similarity(target, vocab):\n", 554 | " if USE_CUDA:\n", 555 | " target_V = model.prediction(prepare_word(target, word2index))\n", 556 | " else:\n", 557 | " target_V = model.prediction(prepare_word(target, word2index))\n", 558 | " similarities = []\n", 559 | " for i in range(len(vocab)):\n", 560 | " if vocab[i] == target: \n", 561 | " continue\n", 562 | " \n", 563 | " if USE_CUDA:\n", 564 | " vector = model.prediction(prepare_word(list(vocab)[i], word2index))\n", 565 | " else:\n", 566 | " vector = model.prediction(prepare_word(list(vocab)[i], word2index))\n", 567 | " \n", 568 | " cosine_sim = F.cosine_similarity(target_V, vector).data.tolist()[0] \n", 569 | " similarities.append([vocab[i], cosine_sim])\n", 570 | " return sorted(similarities, key=lambda x: x[1], reverse=True)[:10]" 571 | ] 572 | }, 573 | { 574 | "cell_type": "code", 575 | "execution_count": 86, 576 | "metadata": { 577 | "collapsed": false 578 | }, 579 | "outputs": [ 580 | { 581 | "data": { 582 | "text/plain": [ 583 | "'spiral'" 584 | ] 585 | }, 586 | "execution_count": 86, 587 | "metadata": {}, 588 | "output_type": "execute_result" 589 | } 590 | ], 591 | "source": [ 592 | "test = random.choice(list(vocab))\n", 593 | "test" 594 | ] 595 | }, 596 | { 597 | "cell_type": "code", 598 | "execution_count": 87, 599 | "metadata": { 600 | "collapsed": false 601 | }, 602 | "outputs": [ 603 | { 604 | "data": { 605 | "text/plain": [ 606 | "[['horns', 0.9727935194969177],\n", 607 | " ['swords', 0.9076412916183472],\n", 608 | " ['hooked', 0.8984033465385437],\n", 609 | " ['thar', 0.8066437244415283],\n", 610 | " ['montaigne', 0.8062068819999695],\n", 611 | " ['rabelais', 0.789764940738678],\n", 612 | " ['orion', 0.7886737585067749],\n", 613 | " ['isaiah', 0.780662477016449],\n", 614 | " ['hamlet', 0.7799868583679199],\n", 615 | " ['colnett', 0.7792885899543762]]" 616 | ] 
617 | }, 618 | "execution_count": 87, 619 | "metadata": {}, 620 | "output_type": "execute_result" 621 | } 622 | ], 623 | "source": [ 624 | "word_similarity(test, vocab)" 625 | ] 626 | }, 627 | { 628 | "cell_type": "markdown", 629 | "metadata": { 630 | "collapsed": true 631 | }, 632 | "source": [ 633 | "## TODO" 634 | ] 635 | }, 636 | { 637 | "cell_type": "markdown", 638 | "metadata": {}, 639 | "source": [ 640 | "* Use a sparse matrix to build the co-occurrence matrix for memory efficiency" 641 | ] 642 | }, 643 | { 644 | "cell_type": "markdown", 645 | "metadata": {}, 646 | "source": [ 647 | "## Suggested Readings" 648 | ] 649 | }, 650 | { 651 | "cell_type": "markdown", 652 | "metadata": {}, 653 | "source": [ 654 | "* Word embeddings in 2017: Trends and future directions" 655 | ] 656 | } 657 | ], 658 | "metadata": { 659 | "kernelspec": { 660 | "display_name": "Python 3", 661 | "language": "python", 662 | "name": "python3" 663 | }, 664 | "language_info": { 665 | "codemirror_mode": { 666 | "name": "ipython", 667 | "version": 3 668 | }, 669 | "file_extension": ".py", 670 | "mimetype": "text/x-python", 671 | "name": "python", 672 | "nbconvert_exporter": "python", 673 | "pygments_lexer": "ipython3", 674 | "version": "3.5.2" 675 | } 676 | }, 677 | "nbformat": 4, 678 | "nbformat_minor": 2 679 | } 680 | -------------------------------------------------------------------------------- /notebooks/04.Window-Classifier-for-NER.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# 4. Word Window Classification and Neural Networks " 8 | ] 9 | }, 10 | { 11 | "cell_type": "markdown", 12 | "metadata": {}, 13 | "source": [ 14 | "I recommend you take a look at these materials first." 
15 | ] 16 | }, 17 | { 18 | "cell_type": "markdown", 19 | "metadata": {}, 20 | "source": [ 21 | "* http://web.stanford.edu/class/cs224n/lectures/cs224n-2017-lecture4.pdf\n", 22 | "* https://en.wikipedia.org/wiki/Named-entity_recognition" 23 | ] 24 | }, 25 | { 26 | "cell_type": "code", 27 | "execution_count": 1, 28 | "metadata": { 29 | "collapsed": true 30 | }, 31 | "outputs": [], 32 | "source": [ 33 | "import torch\n", 34 | "import torch.nn as nn\n", 35 | "from torch.autograd import Variable\n", 36 | "import torch.optim as optim\n", 37 | "import torch.nn.functional as F\n", 38 | "import nltk\n", 39 | "import random\n", 40 | "import numpy as np\n", 41 | "from collections import Counter\n", 42 | "flatten = lambda l: [item for sublist in l for item in sublist]\n", 43 | "from sklearn_crfsuite import metrics\n", 44 | "random.seed(1024)" 45 | ] 46 | }, 47 | { 48 | "cell_type": "markdown", 49 | "metadata": {}, 50 | "source": [ 51 | "You also need the latest version of sklearn_crfsuite to print the confusion matrix" 52 | ] 53 | }, 54 | { 55 | "cell_type": "code", 56 | "execution_count": 2, 57 | "metadata": { 58 | "collapsed": false 59 | }, 60 | "outputs": [ 61 | { 62 | "name": "stdout", 63 | "output_type": "stream", 64 | "text": [ 65 | "0.3.0.post4\n", 66 | "3.2.4\n" 67 | ] 68 | } 69 | ], 70 | "source": [ 71 | "print(torch.__version__)\n", 72 | "print(nltk.__version__)" 73 | ] 74 | }, 75 | { 76 | "cell_type": "code", 77 | "execution_count": 3, 78 | "metadata": { 79 | "collapsed": true 80 | }, 81 | "outputs": [], 82 | "source": [ 83 | "USE_CUDA = torch.cuda.is_available()\n", 84 | "gpus = [0]\n", 85 | "torch.cuda.set_device(gpus[0])\n", 86 | "\n", 87 | "FloatTensor = torch.cuda.FloatTensor if USE_CUDA else torch.FloatTensor\n", 88 | "LongTensor = torch.cuda.LongTensor if USE_CUDA else torch.LongTensor\n", 89 | "ByteTensor = torch.cuda.ByteTensor if USE_CUDA else torch.ByteTensor" 90 | ] 91 | }, 92 | { 93 | "cell_type": "code", 94 | "execution_count": 4, 95 | "metadata": { 96 | 
"collapsed": true 97 | }, 98 | "outputs": [], 99 | "source": [ 100 | "def getBatch(batch_size, train_data):\n", 101 | " random.shuffle(train_data)\n", 102 | " sindex = 0\n", 103 | " eindex = batch_size\n", 104 | " while eindex < len(train_data):\n", 105 | " batch = train_data[sindex: eindex]\n", 106 | " temp = eindex\n", 107 | " eindex = eindex + batch_size\n", 108 | " sindex = temp\n", 109 | " yield batch\n", 110 | " \n", 111 | " if eindex >= len(train_data):\n", 112 | " batch = train_data[sindex:]\n", 113 | " yield batch" 114 | ] 115 | }, 116 | { 117 | "cell_type": "code", 118 | "execution_count": 5, 119 | "metadata": { 120 | "collapsed": true 121 | }, 122 | "outputs": [], 123 | "source": [ 124 | "def prepare_sequence(seq, word2index):\n", 125 | " idxs = list(map(lambda w: word2index[w] if word2index.get(w) is not None else word2index[\"\"], seq))\n", 126 | " return Variable(LongTensor(idxs))\n", 127 | "\n", 128 | "def prepare_word(word, word2index):\n", 129 | " return Variable(LongTensor([word2index[word]]) if word2index.get(word) is not None else LongTensor([word2index[\"\"]]))\n", 130 | "\n", 131 | "def prepare_tag(tag,tag2index):\n", 132 | " return Variable(LongTensor([tag2index[tag]]))" 133 | ] 134 | }, 135 | { 136 | "cell_type": "markdown", 137 | "metadata": {}, 138 | "source": [ 139 | "## Data load and Preprocessing " 140 | ] 141 | }, 142 | { 143 | "cell_type": "markdown", 144 | "metadata": {}, 145 | "source": [ 146 | "CoNLL-2002 Shared Task: Language-Independent Named Entity Recognition
\n", 147 | "https://www.clips.uantwerpen.be/conll2002/ner/" 148 | ] 149 | }, 150 | { 151 | "cell_type": "code", 152 | "execution_count": 6, 153 | "metadata": { 154 | "collapsed": true 155 | }, 156 | "outputs": [], 157 | "source": [ 158 | "corpus = nltk.corpus.conll2002.iob_sents()" 159 | ] 160 | }, 161 | { 162 | "cell_type": "code", 163 | "execution_count": 7, 164 | "metadata": { 165 | "collapsed": true 166 | }, 167 | "outputs": [], 168 | "source": [ 169 | "data = []\n", 170 | "for cor in corpus:\n", 171 | " sent, _, tag = list(zip(*cor))\n", 172 | " data.append([sent, tag])" 173 | ] 174 | }, 175 | { 176 | "cell_type": "code", 177 | "execution_count": 8, 178 | "metadata": { 179 | "collapsed": false 180 | }, 181 | "outputs": [ 182 | { 183 | "name": "stdout", 184 | "output_type": "stream", 185 | "text": [ 186 | "35651\n", 187 | "[('Sao', 'Paulo', '(', 'Brasil', ')', ',', '23', 'may', '(', 'EFECOM', ')', '.'), ('B-LOC', 'I-LOC', 'O', 'B-LOC', 'O', 'O', 'O', 'O', 'O', 'B-ORG', 'O', 'O')]\n" 188 | ] 189 | } 190 | ], 191 | "source": [ 192 | "print(len(data))\n", 193 | "print(data[0])" 194 | ] 195 | }, 196 | { 197 | "cell_type": "markdown", 198 | "metadata": {}, 199 | "source": [ 200 | "### Build Vocab" 201 | ] 202 | }, 203 | { 204 | "cell_type": "code", 205 | "execution_count": 9, 206 | "metadata": { 207 | "collapsed": true 208 | }, 209 | "outputs": [], 210 | "source": [ 211 | "sents,tags = list(zip(*data))\n", 212 | "vocab = list(set(flatten(sents)))\n", 213 | "tagset = list(set(flatten(tags)))" 214 | ] 215 | }, 216 | { 217 | "cell_type": "code", 218 | "execution_count": 10, 219 | "metadata": { 220 | "collapsed": true 221 | }, 222 | "outputs": [], 223 | "source": [ 224 | "word2index={'' : 0, '' : 1} # dummy token is for start or end of sentence\n", 225 | "for vo in vocab:\n", 226 | " if word2index.get(vo) is None:\n", 227 | " word2index[vo] = len(word2index)\n", 228 | "index2word = {v:k for k, v in word2index.items()}\n", 229 | "\n", 230 | "tag2index = {}\n", 231 | "for 
tag in tagset:\n", 232 | " if tag2index.get(tag) is None:\n", 233 | " tag2index[tag] = len(tag2index)\n", 234 | "index2tag={v:k for k, v in tag2index.items()}" 235 | ] 236 | }, 237 | { 238 | "cell_type": "markdown", 239 | "metadata": {}, 240 | "source": [ 241 | "### Prepare data" 242 | ] 243 | }, 244 | { 245 | "cell_type": "markdown", 246 | "metadata": {}, 247 | "source": [ 248 | "
Example: classify 'Paris' in the context of this sentence with a window length of 2
" 249 | ] 250 | }, 251 | { 252 | "cell_type": "markdown", 253 | "metadata": {}, 254 | "source": [ 255 | "" 256 | ] 257 | }, 258 | { 259 | "cell_type": "markdown", 260 | "metadata": {}, 261 | "source": [ 262 | "
Image borrowed from http://web.stanford.edu/class/cs224n/lectures/cs224n-2017-lecture4.pdf
" 263 | ] 264 | }, 265 | { 266 | "cell_type": "code", 267 | "execution_count": 11, 268 | "metadata": { 269 | "collapsed": true 270 | }, 271 | "outputs": [], 272 | "source": [ 273 | "WINDOW_SIZE = 2\n", 274 | "windows = []" 275 | ] 276 | }, 277 | { 278 | "cell_type": "code", 279 | "execution_count": 12, 280 | "metadata": { 281 | "collapsed": true 282 | }, 283 | "outputs": [], 284 | "source": [ 285 | "for sample in data:\n", 286 | " dummy = [''] * WINDOW_SIZE\n", 287 | " window = list(nltk.ngrams(dummy + list(sample[0]) + dummy, WINDOW_SIZE * 2 + 1))\n", 288 | " windows.extend([[list(window[i]), sample[1][i]] for i in range(len(sample[0]))])" 289 | ] 290 | }, 291 | { 292 | "cell_type": "code", 293 | "execution_count": 13, 294 | "metadata": { 295 | "collapsed": false 296 | }, 297 | "outputs": [ 298 | { 299 | "data": { 300 | "text/plain": [ 301 | "[['', '', 'Sao', 'Paulo', '('], 'B-LOC']" 302 | ] 303 | }, 304 | "execution_count": 13, 305 | "metadata": {}, 306 | "output_type": "execute_result" 307 | } 308 | ], 309 | "source": [ 310 | "windows[0]" 311 | ] 312 | }, 313 | { 314 | "cell_type": "code", 315 | "execution_count": 14, 316 | "metadata": { 317 | "collapsed": false 318 | }, 319 | "outputs": [ 320 | { 321 | "data": { 322 | "text/plain": [ 323 | "678377" 324 | ] 325 | }, 326 | "execution_count": 14, 327 | "metadata": {}, 328 | "output_type": "execute_result" 329 | } 330 | ], 331 | "source": [ 332 | "len(windows)" 333 | ] 334 | }, 335 | { 336 | "cell_type": "code", 337 | "execution_count": 15, 338 | "metadata": { 339 | "collapsed": true 340 | }, 341 | "outputs": [], 342 | "source": [ 343 | "random.shuffle(windows)\n", 344 | "\n", 345 | "train_data = windows[:int(len(windows) * 0.9)]\n", 346 | "test_data = windows[int(len(windows) * 0.9):]" 347 | ] 348 | }, 349 | { 350 | "cell_type": "markdown", 351 | "metadata": {}, 352 | "source": [ 353 | "## Modeling " 354 | ] 355 | }, 356 | { 357 | "cell_type": "markdown", 358 | "metadata": {}, 359 | "source": [ 360 | "\n", 361 | "
Image borrowed from http://web.stanford.edu/class/cs224n/lectures/cs224n-2017-lecture4.pdf
" 362 | ] 363 | }, 364 | { 365 | "cell_type": "code", 366 | "execution_count": 16, 367 | "metadata": { 368 | "collapsed": true 369 | }, 370 | "outputs": [], 371 | "source": [ 372 | "class WindowClassifier(nn.Module): \n", 373 | " def __init__(self, vocab_size, embedding_size, window_size, hidden_size, output_size):\n", 374 | "\n", 375 | " super(WindowClassifier, self).__init__()\n", 376 | " \n", 377 | " self.embed = nn.Embedding(vocab_size, embedding_size)\n", 378 | " self.h_layer1 = nn.Linear(embedding_size * (window_size * 2 + 1), hidden_size)\n", 379 | " self.h_layer2 = nn.Linear(hidden_size, hidden_size)\n", 380 | " self.o_layer = nn.Linear(hidden_size, output_size)\n", 381 | " self.relu = nn.ReLU()\n", 382 | " self.softmax = nn.LogSoftmax(dim=1)\n", 383 | " self.dropout = nn.Dropout(0.3)\n", 384 | " \n", 385 | " def forward(self, inputs, is_training=False): \n", 386 | " embeds = self.embed(inputs) # BxWxD\n", 387 | " concated = embeds.view(-1, embeds.size(1)*embeds.size(2)) # Bx(W*D)\n", 388 | " h0 = self.relu(self.h_layer1(concated))\n", 389 | " if is_training:\n", 390 | " h0 = self.dropout(h0)\n", 391 | " h1 = self.relu(self.h_layer2(h0))\n", 392 | " if is_training:\n", 393 | " h1 = self.dropout(h1)\n", 394 | " out = self.softmax(self.o_layer(h1))\n", 395 | " return out" 396 | ] 397 | }, 398 | { 399 | "cell_type": "code", 400 | "execution_count": 20, 401 | "metadata": { 402 | "collapsed": true 403 | }, 404 | "outputs": [], 405 | "source": [ 406 | "BATCH_SIZE = 128\n", 407 | "EMBEDDING_SIZE = 50 # x (WINDOW_SIZE*2+1) = 250\n", 408 | "HIDDEN_SIZE = 300\n", 409 | "EPOCH = 3\n", 410 | "LEARNING_RATE = 0.001" 411 | ] 412 | }, 413 | { 414 | "cell_type": "markdown", 415 | "metadata": {}, 416 | "source": [ 417 | "## Training " 418 | ] 419 | }, 420 | { 421 | "cell_type": "markdown", 422 | "metadata": {}, 423 | "source": [ 424 | "It takes for a while if you use just cpu." 
425 | ] 426 | }, 427 | { 428 | "cell_type": "code", 429 | "execution_count": 22, 430 | "metadata": { 431 | "collapsed": true 432 | }, 433 | "outputs": [], 434 | "source": [ 435 | "model = WindowClassifier(len(word2index), EMBEDDING_SIZE, WINDOW_SIZE, HIDDEN_SIZE, len(tag2index))\n", 436 | "if USE_CUDA:\n", 437 | " model = model.cuda()\n", 438 | "loss_function = nn.CrossEntropyLoss()\n", 439 | "optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)" 440 | ] 441 | }, 442 | { 443 | "cell_type": "code", 444 | "execution_count": 23, 445 | "metadata": { 446 | "collapsed": false 447 | }, 448 | "outputs": [ 449 | { 450 | "name": "stdout", 451 | "output_type": "stream", 452 | "text": [ 453 | "[0/3] mean_loss : 2.25\n", 454 | "[0/3] mean_loss : 0.47\n", 455 | "[0/3] mean_loss : 0.36\n", 456 | "[0/3] mean_loss : 0.31\n", 457 | "[0/3] mean_loss : 0.28\n", 458 | "[1/3] mean_loss : 0.22\n", 459 | "[1/3] mean_loss : 0.21\n", 460 | "[1/3] mean_loss : 0.21\n", 461 | "[1/3] mean_loss : 0.19\n", 462 | "[1/3] mean_loss : 0.19\n", 463 | "[2/3] mean_loss : 0.12\n", 464 | "[2/3] mean_loss : 0.15\n", 465 | "[2/3] mean_loss : 0.15\n", 466 | "[2/3] mean_loss : 0.14\n", 467 | "[2/3] mean_loss : 0.14\n" 468 | ] 469 | } 470 | ], 471 | "source": [ 472 | "for epoch in range(EPOCH):\n", 473 | " losses = []\n", 474 | " for i,batch in enumerate(getBatch(BATCH_SIZE, train_data)):\n", 475 | " x,y=list(zip(*batch))\n", 476 | " inputs = torch.cat([prepare_sequence(sent, word2index).view(1, -1) for sent in x])\n", 477 | " targets = torch.cat([prepare_tag(tag, tag2index) for tag in y])\n", 478 | " model.zero_grad()\n", 479 | " preds = model(inputs, is_training=True)\n", 480 | " loss = loss_function(preds, targets)\n", 481 | " losses.append(loss.data.tolist()[0])\n", 482 | " loss.backward()\n", 483 | " optimizer.step()\n", 484 | "\n", 485 | " if i % 1000 == 0:\n", 486 | " print(\"[%d/%d] mean_loss : %0.2f\" %(epoch, EPOCH, np.mean(losses)))\n", 487 | " losses = []" 488 | ] 489 | }, 490 | { 491 | 
"cell_type": "markdown", 492 | "metadata": {}, 493 | "source": [ 494 | "## Test " 495 | ] 496 | }, 497 | { 498 | "cell_type": "code", 499 | "execution_count": 24, 500 | "metadata": { 501 | "collapsed": true 502 | }, 503 | "outputs": [], 504 | "source": [ 505 | "for_f1_score = []" 506 | ] 507 | }, 508 | { 509 | "cell_type": "code", 510 | "execution_count": 25, 511 | "metadata": { 512 | "collapsed": false 513 | }, 514 | "outputs": [ 515 | { 516 | "name": "stdout", 517 | "output_type": "stream", 518 | "text": [ 519 | "95.69120551903063\n" 520 | ] 521 | } 522 | ], 523 | "source": [ 524 | "accuracy = 0\n", 525 | "for test in test_data:\n", 526 | " x, y = test[0], test[1]\n", 527 | " input_ = prepare_sequence(x, word2index).view(1, -1)\n", 528 | "\n", 529 | " i = model(input_).max(1)[1]\n", 530 | " pred = index2tag[i.data.tolist()[0]]\n", 531 | " for_f1_score.append([pred, y])\n", 532 | " if pred == y:\n", 533 | " accuracy += 1\n", 534 | "\n", 535 | "print(accuracy/len(test_data) * 100)" 536 | ] 537 | }, 538 | { 539 | "cell_type": "markdown", 540 | "metadata": {}, 541 | "source": [ 542 | "This high score is because most of labels are 'O' tag. So we need to measure f1 score." 
543 | ] 544 | }, 545 | { 546 | "cell_type": "markdown", 547 | "metadata": {}, 548 | "source": [ 549 | "### Print Confusion matrix " 550 | ] 551 | }, 552 | { 553 | "cell_type": "code", 554 | "execution_count": 26, 555 | "metadata": { 556 | "collapsed": true 557 | }, 558 | "outputs": [], 559 | "source": [ 560 | "y_pred, y_test = list(zip(*for_f1_score))" 561 | ] 562 | }, 563 | { 564 | "cell_type": "code", 565 | "execution_count": 27, 566 | "metadata": { 567 | "collapsed": true 568 | }, 569 | "outputs": [], 570 | "source": [ 571 | "sorted_labels = sorted(\n", 572 | " list(set(y_test) - {'O'}),\n", 573 | " key=lambda name: (name[1:], name[0])\n", 574 | ")" 575 | ] 576 | }, 577 | { 578 | "cell_type": "code", 579 | "execution_count": 28, 580 | "metadata": { 581 | "collapsed": false 582 | }, 583 | "outputs": [ 584 | { 585 | "data": { 586 | "text/plain": [ 587 | "['B-LOC', 'I-LOC', 'B-MISC', 'I-MISC', 'B-ORG', 'I-ORG', 'B-PER', 'I-PER']" 588 | ] 589 | }, 590 | "execution_count": 28, 591 | "metadata": {}, 592 | "output_type": "execute_result" 593 | } 594 | ], 595 | "source": [ 596 | "sorted_labels" 597 | ] 598 | }, 599 | { 600 | "cell_type": "code", 601 | "execution_count": 29, 602 | "metadata": { 603 | "collapsed": true 604 | }, 605 | "outputs": [], 606 | "source": [ 607 | "y_pred = [[y] for y in y_pred] # this is because sklearn_crfsuite.metrics function flatten inputs\n", 608 | "y_test = [[y] for y in y_test]" 609 | ] 610 | }, 611 | { 612 | "cell_type": "code", 613 | "execution_count": 30, 614 | "metadata": { 615 | "collapsed": false 616 | }, 617 | "outputs": [ 618 | { 619 | "name": "stdout", 620 | "output_type": "stream", 621 | "text": [ 622 | " precision recall f1-score support\n", 623 | "\n", 624 | " B-LOC 0.802 0.636 0.710 1085\n", 625 | " I-LOC 0.732 0.457 0.562 311\n", 626 | " B-MISC 0.750 0.378 0.503 801\n", 627 | " I-MISC 0.679 0.331 0.445 641\n", 628 | " B-ORG 0.723 0.738 0.730 1430\n", 629 | " I-ORG 0.710 0.700 0.705 969\n", 630 | " B-PER 0.782 0.773 0.777 
1268\n", 631 | " I-PER 0.853 0.871 0.861 950\n", 632 | "\n", 633 | "avg / total 0.759 0.656 0.693 7455\n", 634 | "\n" 635 | ] 636 | } 637 | ], 638 | "source": [ 639 | "print(metrics.flat_classification_report(\n", 640 | " y_test, y_pred, labels = sorted_labels, digits=3\n", 641 | "))" 642 | ] 643 | }, 644 | { 645 | "cell_type": "markdown", 646 | "metadata": { 647 | "collapsed": true 648 | }, 649 | "source": [ 650 | "### TODO" 651 | ] 652 | }, 653 | { 654 | "cell_type": "markdown", 655 | "metadata": {}, 656 | "source": [ 657 | "* use max-margin objective function http://pytorch.org/docs/master/nn.html#multilabelmarginloss" 658 | ] 659 | }, 660 | { 661 | "cell_type": "code", 662 | "execution_count": null, 663 | "metadata": { 664 | "collapsed": true 665 | }, 666 | "outputs": [], 667 | "source": [] 668 | } 669 | ], 670 | "metadata": { 671 | "kernelspec": { 672 | "display_name": "Python 3", 673 | "language": "python", 674 | "name": "python3" 675 | }, 676 | "language_info": { 677 | "codemirror_mode": { 678 | "name": "ipython", 679 | "version": 3 680 | }, 681 | "file_extension": ".py", 682 | "mimetype": "text/x-python", 683 | "name": "python", 684 | "nbconvert_exporter": "python", 685 | "pygments_lexer": "ipython3", 686 | "version": "3.5.2" 687 | } 688 | }, 689 | "nbformat": 4, 690 | "nbformat_minor": 2 691 | } 692 | -------------------------------------------------------------------------------- /notebooks/06.RNN-Language-Model.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# 6. 
Recurrent Neural Networks and Language Models" 8 | ] 9 | }, 10 | { 11 | "cell_type": "markdown", 12 | "metadata": {}, 13 | "source": [ 14 | "* http://web.stanford.edu/class/cs224n/lectures/cs224n-2017-lecture8.pdf\n", 15 | "* http://web.stanford.edu/class/cs224n/lectures/cs224n-2017-lecture9.pdf\n", 16 | "* http://colah.github.io/posts/2015-08-Understanding-LSTMs/\n", 17 | "* https://github.com/pytorch/examples/tree/master/word_language_model\n", 18 | "* https://github.com/yunjey/pytorch-tutorial/blob/master/tutorials/02-intermediate/language_model" 19 | ] 20 | }, 21 | { 22 | "cell_type": "code", 23 | "execution_count": 1, 24 | "metadata": { 25 | "collapsed": true 26 | }, 27 | "outputs": [], 28 | "source": [ 29 | "import torch\n", 30 | "import torch.nn as nn\n", 31 | "from torch.autograd import Variable\n", 32 | "import torch.optim as optim\n", 33 | "import torch.nn.functional as F\n", 34 | "import nltk\n", 35 | "import random\n", 36 | "import numpy as np\n", 37 | "from collections import Counter, OrderedDict\n", 38 | "import nltk\n", 39 | "from copy import deepcopy\n", 40 | "flatten = lambda l: [item for sublist in l for item in sublist]\n", 41 | "random.seed(1024)" 42 | ] 43 | }, 44 | { 45 | "cell_type": "code", 46 | "execution_count": 2, 47 | "metadata": { 48 | "collapsed": true 49 | }, 50 | "outputs": [], 51 | "source": [ 52 | "USE_CUDA = torch.cuda.is_available()\n", 53 | "gpus = [0]\n", 54 | "torch.cuda.set_device(gpus[0])\n", 55 | "\n", 56 | "FloatTensor = torch.cuda.FloatTensor if USE_CUDA else torch.FloatTensor\n", 57 | "LongTensor = torch.cuda.LongTensor if USE_CUDA else torch.LongTensor\n", 58 | "ByteTensor = torch.cuda.ByteTensor if USE_CUDA else torch.ByteTensor" 59 | ] 60 | }, 61 | { 62 | "cell_type": "code", 63 | "execution_count": 4, 64 | "metadata": { 65 | "collapsed": true 66 | }, 67 | "outputs": [], 68 | "source": [ 69 | "def prepare_sequence(seq, to_index):\n", 70 | " idxs = list(map(lambda w: to_index[w] if to_index.get(w) is not None else 
to_index[\"\"], seq))\n", 71 | " return LongTensor(idxs)" 72 | ] 73 | }, 74 | { 75 | "cell_type": "markdown", 76 | "metadata": {}, 77 | "source": [ 78 | "## Data load and Preprocessing" 79 | ] 80 | }, 81 | { 82 | "cell_type": "markdown", 83 | "metadata": {}, 84 | "source": [ 85 | "### Penn TreeBank" 86 | ] 87 | }, 88 | { 89 | "cell_type": "code", 90 | "execution_count": 5, 91 | "metadata": { 92 | "collapsed": true 93 | }, 94 | "outputs": [], 95 | "source": [ 96 | "def prepare_ptb_dataset(filename, word2index=None):\n", 97 | " corpus = open(filename, 'r', encoding='utf-8').readlines()\n", 98 | " corpus = flatten([co.strip().split() + [''] for co in corpus])\n", 99 | " \n", 100 | " if word2index == None:\n", 101 | " vocab = list(set(corpus))\n", 102 | " word2index = {'': 0}\n", 103 | " for vo in vocab:\n", 104 | " if word2index.get(vo) is None:\n", 105 | " word2index[vo] = len(word2index)\n", 106 | " \n", 107 | " return prepare_sequence(corpus, word2index), word2index" 108 | ] 109 | }, 110 | { 111 | "cell_type": "code", 112 | "execution_count": 175, 113 | "metadata": { 114 | "collapsed": true 115 | }, 116 | "outputs": [], 117 | "source": [ 118 | "# borrowed code from https://github.com/pytorch/examples/tree/master/word_language_model\n", 119 | "\n", 120 | "def batchify(data, bsz):\n", 121 | " # Work out how cleanly we can divide the dataset into bsz parts.\n", 122 | " nbatch = data.size(0) // bsz\n", 123 | " # Trim off any extra elements that wouldn't cleanly fit (remainders).\n", 124 | " data = data.narrow(0, 0, nbatch * bsz)\n", 125 | " # Evenly divide the data across the bsz batches.\n", 126 | " data = data.view(bsz, -1).contiguous()\n", 127 | " if USE_CUDA:\n", 128 | " data = data.cuda()\n", 129 | " return data" 130 | ] 131 | }, 132 | { 133 | "cell_type": "code", 134 | "execution_count": 176, 135 | "metadata": { 136 | "collapsed": true 137 | }, 138 | "outputs": [], 139 | "source": [ 140 | "def getBatch(data, seq_length):\n", 141 | " for i in range(0, data.size(1) 
- seq_length, seq_length):\n", 142 | " inputs = Variable(data[:, i: i + seq_length])\n", 143 | " targets = Variable(data[:, (i + 1): (i + 1) + seq_length].contiguous())\n", 144 | " yield (inputs, targets)" 145 | ] 146 | }, 147 | { 148 | "cell_type": "code", 149 | "execution_count": 177, 150 | "metadata": { 151 | "collapsed": true 152 | }, 153 | "outputs": [], 154 | "source": [ 155 | "train_data, word2index = prepare_ptb_dataset('../dataset/ptb/ptb.train.txt',)\n", 156 | "dev_data , _ = prepare_ptb_dataset('../dataset/ptb/ptb.valid.txt', word2index)\n", 157 | "test_data, _ = prepare_ptb_dataset('../dataset/ptb/ptb.test.txt', word2index)" 158 | ] 159 | }, 160 | { 161 | "cell_type": "code", 162 | "execution_count": 178, 163 | "metadata": { 164 | "collapsed": false 165 | }, 166 | "outputs": [ 167 | { 168 | "data": { 169 | "text/plain": [ 170 | "10000" 171 | ] 172 | }, 173 | "execution_count": 178, 174 | "metadata": {}, 175 | "output_type": "execute_result" 176 | } 177 | ], 178 | "source": [ 179 | "len(word2index)" 180 | ] 181 | }, 182 | { 183 | "cell_type": "code", 184 | "execution_count": 179, 185 | "metadata": { 186 | "collapsed": true 187 | }, 188 | "outputs": [], 189 | "source": [ 190 | "index2word = {v:k for k, v in word2index.items()}" 191 | ] 192 | }, 193 | { 194 | "cell_type": "markdown", 195 | "metadata": {}, 196 | "source": [ 197 | "## Modeling " 198 | ] 199 | }, 200 | { 201 | "cell_type": "markdown", 202 | "metadata": {}, 203 | "source": [ 204 | "\n", 205 | "
Image borrowed from http://web.stanford.edu/class/cs224n/lectures/cs224n-2017-lecture8.pdf
" 206 | ] 207 | }, 208 | { 209 | "cell_type": "code", 210 | "execution_count": 180, 211 | "metadata": { 212 | "collapsed": true 213 | }, 214 | "outputs": [], 215 | "source": [ 216 | "class LanguageModel(nn.Module): \n", 217 | " def __init__(self, vocab_size, embedding_size, hidden_size, n_layers=1, dropout_p=0.5):\n", 218 | "\n", 219 | " super(LanguageModel, self).__init__()\n", 220 | " self.n_layers = n_layers\n", 221 | " self.hidden_size = hidden_size\n", 222 | " self.embed = nn.Embedding(vocab_size, embedding_size)\n", 223 | " self.rnn = nn.LSTM(embedding_size, hidden_size, n_layers, batch_first=True)\n", 224 | " self.linear = nn.Linear(hidden_size, vocab_size)\n", 225 | " self.dropout = nn.Dropout(dropout_p)\n", 226 | " \n", 227 | " def init_weight(self):\n", 228 | " self.embed.weight = nn.init.xavier_uniform(self.embed.weight)\n", 229 | " self.linear.weight = nn.init.xavier_uniform(self.linear.weight)\n", 230 | " self.linear.bias.data.fill_(0)\n", 231 | " \n", 232 | " def init_hidden(self,batch_size):\n", 233 | " hidden = Variable(torch.zeros(self.n_layers,batch_size,self.hidden_size))\n", 234 | " context = Variable(torch.zeros(self.n_layers,batch_size,self.hidden_size))\n", 235 | " return (hidden.cuda(), context.cuda()) if USE_CUDA else (hidden, context)\n", 236 | " \n", 237 | " def detach_hidden(self, hiddens):\n", 238 | " return tuple([hidden.detach() for hidden in hiddens])\n", 239 | " \n", 240 | " def forward(self, inputs, hidden, is_training=False): \n", 241 | "\n", 242 | " embeds = self.embed(inputs)\n", 243 | " if is_training:\n", 244 | " embeds = self.dropout(embeds)\n", 245 | " out,hidden = self.rnn(embeds, hidden)\n", 246 | " return self.linear(out.contiguous().view(out.size(0) * out.size(1), -1)), hidden" 247 | ] 248 | }, 249 | { 250 | "cell_type": "markdown", 251 | "metadata": {}, 252 | "source": [ 253 | "## Train " 254 | ] 255 | }, 256 | { 257 | "cell_type": "markdown", 258 | "metadata": {}, 259 | "source": [ 260 | "It takes for a while..." 
261 | ] 262 | }, 263 | { 264 | "cell_type": "code", 265 | "execution_count": 181, 266 | "metadata": { 267 | "collapsed": true 268 | }, 269 | "outputs": [], 270 | "source": [ 271 | "EMBED_SIZE = 128\n", 272 | "HIDDEN_SIZE = 1024\n", 273 | "NUM_LAYER = 1\n", 274 | "LR = 0.01\n", 275 | "SEQ_LENGTH = 30 # for bptt\n", 276 | "BATCH_SIZE = 20\n", 277 | "EPOCH = 40\n", 278 | "RESCHEDULED = False" 279 | ] 280 | }, 281 | { 282 | "cell_type": "code", 283 | "execution_count": 182, 284 | "metadata": { 285 | "collapsed": true 286 | }, 287 | "outputs": [], 288 | "source": [ 289 | "train_data = batchify(train_data, BATCH_SIZE)\n", 290 | "dev_data = batchify(dev_data, BATCH_SIZE//2)\n", 291 | "test_data = batchify(test_data, BATCH_SIZE//2)" 292 | ] 293 | }, 294 | { 295 | "cell_type": "code", 296 | "execution_count": 185, 297 | "metadata": { 298 | "collapsed": true 299 | }, 300 | "outputs": [], 301 | "source": [ 302 | "model = LanguageModel(len(word2index), EMBED_SIZE, HIDDEN_SIZE, NUM_LAYER, 0.5)\n", 303 | "model.init_weight() \n", 304 | "if USE_CUDA:\n", 305 | " model = model.cuda()\n", 306 | "loss_function = nn.CrossEntropyLoss()\n", 307 | "optimizer = optim.Adam(model.parameters(), lr=LR)" 308 | ] 309 | }, 310 | { 311 | "cell_type": "code", 312 | "execution_count": 186, 313 | "metadata": { 314 | "collapsed": false 315 | }, 316 | "outputs": [ 317 | { 318 | "name": "stdout", 319 | "output_type": "stream", 320 | "text": [ 321 | "[00/40] mean_loss : 9.45, Perplexity : 12712.23\n", 322 | "[00/40] mean_loss : 5.88, Perplexity : 358.21\n", 323 | "[00/40] mean_loss : 5.55, Perplexity : 256.44\n", 324 | "[01/40] mean_loss : 5.38, Perplexity : 217.46\n", 325 | "[01/40] mean_loss : 5.21, Perplexity : 182.41\n", 326 | "[01/40] mean_loss : 5.10, Perplexity : 164.39\n", 327 | "[02/40] mean_loss : 5.08, Perplexity : 160.87\n", 328 | "[02/40] mean_loss : 4.99, Perplexity : 147.18\n", 329 | "[02/40] mean_loss : 4.92, Perplexity : 136.52\n", 330 | "[03/40] mean_loss : 4.92, Perplexity : 
136.64\n", 331 | "[03/40] mean_loss : 4.86, Perplexity : 129.32\n", 332 | "[03/40] mean_loss : 4.80, Perplexity : 121.46\n", 333 | "[04/40] mean_loss : 4.80, Perplexity : 121.91\n", 334 | "[04/40] mean_loss : 4.77, Perplexity : 117.64\n", 335 | "[04/40] mean_loss : 4.71, Perplexity : 111.22\n", 336 | "[05/40] mean_loss : 4.72, Perplexity : 112.01\n", 337 | "[05/40] mean_loss : 4.70, Perplexity : 109.46\n", 338 | "[05/40] mean_loss : 4.64, Perplexity : 103.96\n", 339 | "[06/40] mean_loss : 4.66, Perplexity : 105.25\n", 340 | "[06/40] mean_loss : 4.64, Perplexity : 103.63\n", 341 | "[06/40] mean_loss : 4.60, Perplexity : 99.00\n", 342 | "[07/40] mean_loss : 4.60, Perplexity : 99.89\n", 343 | "[07/40] mean_loss : 4.59, Perplexity : 98.97\n", 344 | "[07/40] mean_loss : 4.55, Perplexity : 94.97\n", 345 | "[08/40] mean_loss : 4.56, Perplexity : 95.54\n", 346 | "[08/40] mean_loss : 4.56, Perplexity : 95.67\n", 347 | "[08/40] mean_loss : 4.52, Perplexity : 91.98\n", 348 | "[09/40] mean_loss : 4.53, Perplexity : 92.61\n", 349 | "[09/40] mean_loss : 4.53, Perplexity : 92.79\n", 350 | "[09/40] mean_loss : 4.50, Perplexity : 89.63\n", 351 | "[10/40] mean_loss : 4.50, Perplexity : 90.13\n", 352 | "[10/40] mean_loss : 4.50, Perplexity : 90.19\n", 353 | "[10/40] mean_loss : 4.47, Perplexity : 87.11\n", 354 | "[11/40] mean_loss : 4.48, Perplexity : 88.11\n", 355 | "[11/40] mean_loss : 4.48, Perplexity : 88.26\n", 356 | "[11/40] mean_loss : 4.45, Perplexity : 86.05\n", 357 | "[12/40] mean_loss : 4.46, Perplexity : 86.81\n", 358 | "[12/40] mean_loss : 4.47, Perplexity : 87.03\n", 359 | "[12/40] mean_loss : 4.43, Perplexity : 84.04\n", 360 | "[13/40] mean_loss : 4.45, Perplexity : 85.27\n", 361 | "[13/40] mean_loss : 4.45, Perplexity : 85.83\n", 362 | "[13/40] mean_loss : 4.42, Perplexity : 83.33\n", 363 | "[14/40] mean_loss : 4.43, Perplexity : 84.15\n", 364 | "[14/40] mean_loss : 4.43, Perplexity : 84.31\n", 365 | "[14/40] mean_loss : 4.41, Perplexity : 82.29\n", 366 | "[15/40] 
mean_loss : 4.43, Perplexity : 83.82\n", 367 | "[15/40] mean_loss : 4.43, Perplexity : 83.70\n", 368 | "[15/40] mean_loss : 4.40, Perplexity : 81.59\n", 369 | "[16/40] mean_loss : 4.42, Perplexity : 83.06\n", 370 | "[16/40] mean_loss : 4.42, Perplexity : 83.29\n", 371 | "[16/40] mean_loss : 4.39, Perplexity : 80.89\n", 372 | "[17/40] mean_loss : 4.41, Perplexity : 82.44\n", 373 | "[17/40] mean_loss : 4.41, Perplexity : 82.51\n", 374 | "[17/40] mean_loss : 4.39, Perplexity : 80.59\n", 375 | "[18/40] mean_loss : 4.40, Perplexity : 81.59\n", 376 | "[18/40] mean_loss : 4.41, Perplexity : 82.21\n", 377 | "[18/40] mean_loss : 4.38, Perplexity : 79.87\n", 378 | "[19/40] mean_loss : 4.40, Perplexity : 81.43\n", 379 | "[19/40] mean_loss : 4.40, Perplexity : 81.67\n", 380 | "[19/40] mean_loss : 4.37, Perplexity : 79.28\n", 381 | "[20/40] mean_loss : 4.40, Perplexity : 81.18\n", 382 | "[20/40] mean_loss : 4.40, Perplexity : 81.17\n", 383 | "[20/40] mean_loss : 4.37, Perplexity : 79.11\n", 384 | "[21/40] mean_loss : 4.40, Perplexity : 81.44\n", 385 | "[21/40] mean_loss : 4.34, Perplexity : 76.43\n", 386 | "[21/40] mean_loss : 4.21, Perplexity : 67.17\n", 387 | "[22/40] mean_loss : 4.26, Perplexity : 70.84\n", 388 | "[22/40] mean_loss : 4.26, Perplexity : 70.75\n", 389 | "[22/40] mean_loss : 4.17, Perplexity : 64.99\n", 390 | "[23/40] mean_loss : 4.22, Perplexity : 68.36\n", 391 | "[23/40] mean_loss : 4.22, Perplexity : 67.82\n", 392 | "[23/40] mean_loss : 4.15, Perplexity : 63.74\n", 393 | "[24/40] mean_loss : 4.20, Perplexity : 66.66\n", 394 | "[24/40] mean_loss : 4.20, Perplexity : 66.43\n", 395 | "[24/40] mean_loss : 4.14, Perplexity : 62.85\n", 396 | "[25/40] mean_loss : 4.18, Perplexity : 65.53\n", 397 | "[25/40] mean_loss : 4.17, Perplexity : 64.99\n", 398 | "[25/40] mean_loss : 4.13, Perplexity : 61.94\n", 399 | "[26/40] mean_loss : 4.17, Perplexity : 64.61\n", 400 | "[26/40] mean_loss : 4.16, Perplexity : 64.34\n", 401 | "[26/40] mean_loss : 4.12, Perplexity : 
61.27\n", 402 | "[27/40] mean_loss : 4.15, Perplexity : 63.73\n", 403 | "[27/40] mean_loss : 4.15, Perplexity : 63.32\n", 404 | "[27/40] mean_loss : 4.11, Perplexity : 60.87\n", 405 | "[28/40] mean_loss : 4.14, Perplexity : 62.96\n", 406 | "[28/40] mean_loss : 4.14, Perplexity : 63.01\n", 407 | "[28/40] mean_loss : 4.10, Perplexity : 60.33\n", 408 | "[29/40] mean_loss : 4.14, Perplexity : 62.54\n", 409 | "[29/40] mean_loss : 4.13, Perplexity : 62.36\n", 410 | "[29/40] mean_loss : 4.10, Perplexity : 60.06\n", 411 | "[30/40] mean_loss : 4.13, Perplexity : 62.05\n", 412 | "[30/40] mean_loss : 4.13, Perplexity : 61.91\n", 413 | "[30/40] mean_loss : 4.09, Perplexity : 59.46\n", 414 | "[31/40] mean_loss : 4.12, Perplexity : 61.45\n", 415 | "[31/40] mean_loss : 4.11, Perplexity : 61.24\n", 416 | "[31/40] mean_loss : 4.08, Perplexity : 59.12\n", 417 | "[32/40] mean_loss : 4.11, Perplexity : 61.03\n", 418 | "[32/40] mean_loss : 4.11, Perplexity : 60.88\n", 419 | "[32/40] mean_loss : 4.07, Perplexity : 58.69\n", 420 | "[33/40] mean_loss : 4.11, Perplexity : 60.71\n", 421 | "[33/40] mean_loss : 4.10, Perplexity : 60.57\n", 422 | "[33/40] mean_loss : 4.07, Perplexity : 58.38\n", 423 | "[34/40] mean_loss : 4.10, Perplexity : 60.33\n", 424 | "[34/40] mean_loss : 4.10, Perplexity : 60.23\n", 425 | "[34/40] mean_loss : 4.06, Perplexity : 58.06\n", 426 | "[35/40] mean_loss : 4.09, Perplexity : 60.00\n", 427 | "[35/40] mean_loss : 4.09, Perplexity : 59.74\n", 428 | "[35/40] mean_loss : 4.06, Perplexity : 57.75\n", 429 | "[36/40] mean_loss : 4.09, Perplexity : 59.58\n", 430 | "[36/40] mean_loss : 4.09, Perplexity : 59.47\n", 431 | "[36/40] mean_loss : 4.05, Perplexity : 57.59\n", 432 | "[37/40] mean_loss : 4.08, Perplexity : 59.30\n", 433 | "[37/40] mean_loss : 4.08, Perplexity : 59.11\n", 434 | "[37/40] mean_loss : 4.05, Perplexity : 57.11\n", 435 | "[38/40] mean_loss : 4.08, Perplexity : 58.98\n", 436 | "[38/40] mean_loss : 4.07, Perplexity : 58.70\n", 437 | "[38/40] mean_loss : 
4.04, Perplexity : 57.10\n", 438 | "[39/40] mean_loss : 4.07, Perplexity : 58.79\n", 439 | "[39/40] mean_loss : 4.07, Perplexity : 58.58\n", 440 | "[39/40] mean_loss : 4.04, Perplexity : 56.79\n" 441 | ] 442 | } 443 | ], 444 | "source": [ 445 | "for epoch in range(EPOCH):\n", 446 | " total_loss = 0\n", 447 | " losses = []\n", 448 | " hidden = model.init_hidden(BATCH_SIZE)\n", 449 | " for i,batch in enumerate(getBatch(train_data, SEQ_LENGTH)):\n", 450 | " inputs, targets = batch\n", 451 | " hidden = model.detach_hidden(hidden)\n", 452 | " model.zero_grad()\n", 453 | " preds, hidden = model(inputs, hidden, True)\n", 454 | "\n", 455 | " loss = loss_function(preds, targets.view(-1))\n", 456 | " losses.append(loss.data[0])\n", 457 | " loss.backward()\n", 458 | " torch.nn.utils.clip_grad_norm(model.parameters(), 0.5) # gradient clipping\n", 459 | " optimizer.step()\n", 460 | "\n", 461 | " if i > 0 and i % 500 == 0:\n", 462 | " print(\"[%02d/%d] mean_loss : %0.2f, Perplexity : %0.2f\" % (epoch,EPOCH, np.mean(losses), np.exp(np.mean(losses))))\n", 463 | " losses = []\n", 464 | " \n", 465 | " # learning rate anealing\n", 466 | " # You can use http://pytorch.org/docs/master/optim.html#how-to-adjust-learning-rate\n", 467 | " if RESCHEDULED == False and epoch == EPOCH//2:\n", 468 | " LR *= 0.1\n", 469 | " optimizer = optim.Adam(model.parameters(), lr=LR)\n", 470 | " RESCHEDULED = True" 471 | ] 472 | }, 473 | { 474 | "cell_type": "markdown", 475 | "metadata": {}, 476 | "source": [ 477 | "### Test " 478 | ] 479 | }, 480 | { 481 | "cell_type": "code", 482 | "execution_count": 189, 483 | "metadata": { 484 | "collapsed": false 485 | }, 486 | "outputs": [ 487 | { 488 | "name": "stdout", 489 | "output_type": "stream", 490 | "text": [ 491 | "Test Perpelexity : 155.89\n" 492 | ] 493 | } 494 | ], 495 | "source": [ 496 | "total_loss = 0\n", 497 | "hidden = model.init_hidden(BATCH_SIZE//2)\n", 498 | "for batch in getBatch(test_data, SEQ_LENGTH):\n", 499 | " inputs,targets = batch\n", 500 
| " \n", 501 | " hidden = model.detach_hidden(hidden)\n", 502 | " model.zero_grad()\n", 503 | " preds, hidden = model(inputs, hidden)\n", 504 | " total_loss += inputs.size(1) * loss_function(preds, targets.view(-1)).data\n", 505 | "\n", 506 | "total_loss = total_loss[0]/test_data.size(1)\n", 507 | "print(\"Test Perpelexity : %5.2f\" % (np.exp(total_loss)))" 508 | ] 509 | }, 510 | { 511 | "cell_type": "markdown", 512 | "metadata": { 513 | "collapsed": true 514 | }, 515 | "source": [ 516 | "## Further topics" 517 | ] 518 | }, 519 | { 520 | "cell_type": "markdown", 521 | "metadata": {}, 522 | "source": [ 523 | "* Pointer Sentinel Mixture Models\n", 524 | "* Regularizing and Optimizing LSTM Language Models" 525 | ] 526 | }, 527 | { 528 | "cell_type": "code", 529 | "execution_count": null, 530 | "metadata": { 531 | "collapsed": true 532 | }, 533 | "outputs": [], 534 | "source": [] 535 | } 536 | ], 537 | "metadata": { 538 | "kernelspec": { 539 | "display_name": "Python 3", 540 | "language": "python", 541 | "name": "python3" 542 | }, 543 | "language_info": { 544 | "codemirror_mode": { 545 | "name": "ipython", 546 | "version": 3 547 | }, 548 | "file_extension": ".py", 549 | "mimetype": "text/x-python", 550 | "name": "python", 551 | "nbconvert_exporter": "python", 552 | "pygments_lexer": "ipython3", 553 | "version": "3.5.2" 554 | } 555 | }, 556 | "nbformat": 4, 557 | "nbformat_minor": 2 558 | } 559 | -------------------------------------------------------------------------------- /notebooks/08.CNN-for-Text-Classification.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# 8. Convolutional Neural Networks" 8 | ] 9 | }, 10 | { 11 | "cell_type": "markdown", 12 | "metadata": {}, 13 | "source": [ 14 | "I recommend you take a look at these materials first." 
15 | ] 16 | }, 17 | { 18 | "cell_type": "markdown", 19 | "metadata": {}, 20 | "source": [ 21 | "* http://web.stanford.edu/class/cs224n/lectures/cs224n-2017-lecture13-CNNs.pdf\n", 22 | "* http://www.aclweb.org/anthology/D14-1181\n", 23 | "* https://github.com/Shawn1993/cnn-text-classification-pytorch\n", 24 | "* http://cogcomp.org/Data/QA/QC/" 25 | ] 26 | }, 27 | { 28 | "cell_type": "code", 29 | "execution_count": 58, 30 | "metadata": { 31 | "collapsed": true 32 | }, 33 | "outputs": [], 34 | "source": [ 35 | "import torch\n", 36 | "import torch.nn as nn\n", 37 | "from torch.autograd import Variable\n", 38 | "import torch.optim as optim\n", 39 | "import torch.nn.functional as F\n", 40 | "import nltk\n", 41 | "import random\n", 42 | "import numpy as np\n", 43 | "from collections import Counter, OrderedDict\n", 44 | "import nltk\n", 45 | "import re\n", 46 | "from copy import deepcopy\n", 47 | "flatten = lambda l: [item for sublist in l for item in sublist]\n", 48 | "random.seed(1024)" 49 | ] 50 | }, 51 | { 52 | "cell_type": "code", 53 | "execution_count": 2, 54 | "metadata": { 55 | "collapsed": true 56 | }, 57 | "outputs": [], 58 | "source": [ 59 | "USE_CUDA = torch.cuda.is_available()\n", 60 | "gpus = [0]\n", 61 | "torch.cuda.set_device(gpus[0])\n", 62 | "\n", 63 | "FloatTensor = torch.cuda.FloatTensor if USE_CUDA else torch.FloatTensor\n", 64 | "LongTensor = torch.cuda.LongTensor if USE_CUDA else torch.LongTensor\n", 65 | "ByteTensor = torch.cuda.ByteTensor if USE_CUDA else torch.ByteTensor" 66 | ] 67 | }, 68 | { 69 | "cell_type": "code", 70 | "execution_count": 3, 71 | "metadata": { 72 | "collapsed": true 73 | }, 74 | "outputs": [], 75 | "source": [ 76 | "def getBatch(batch_size, train_data):\n", 77 | " random.shuffle(train_data)\n", 78 | " sindex = 0\n", 79 | " eindex = batch_size\n", 80 | " while eindex < len(train_data):\n", 81 | " batch = train_data[sindex: eindex]\n", 82 | " temp = eindex\n", 83 | " eindex = eindex + batch_size\n", 84 | " sindex = temp\n", 85 | 
" yield batch\n", 86 | " \n", 87 | " if eindex >= len(train_data):\n", 88 | " batch = train_data[sindex:]\n", 89 | " yield batch" 90 | ] 91 | }, 92 | { 93 | "cell_type": "code", 94 | "execution_count": 110, 95 | "metadata": { 96 | "collapsed": true 97 | }, 98 | "outputs": [], 99 | "source": [ 100 | "def pad_to_batch(batch):\n", 101 | " x,y = zip(*batch)\n", 102 | " max_x = max([s.size(1) for s in x])\n", 103 | " x_p = []\n", 104 | " for i in range(len(batch)):\n", 105 | " if x[i].size(1) < max_x:\n", 106 | " x_p.append(torch.cat([x[i], Variable(LongTensor([word2index['']] * (max_x - x[i].size(1)))).view(1, -1)], 1))\n", 107 | " else:\n", 108 | " x_p.append(x[i])\n", 109 | " return torch.cat(x_p), torch.cat(y).view(-1)" 110 | ] 111 | }, 112 | { 113 | "cell_type": "code", 114 | "execution_count": 20, 115 | "metadata": { 116 | "collapsed": true 117 | }, 118 | "outputs": [], 119 | "source": [ 120 | "def prepare_sequence(seq, to_index):\n", 121 | " idxs = list(map(lambda w: to_index[w] if to_index.get(w) is not None else to_index[\"\"], seq))\n", 122 | " return Variable(LongTensor(idxs))" 123 | ] 124 | }, 125 | { 126 | "cell_type": "markdown", 127 | "metadata": {}, 128 | "source": [ 129 | "## Data load & Preprocessing" 130 | ] 131 | }, 132 | { 133 | "cell_type": "markdown", 134 | "metadata": {}, 135 | "source": [ 136 | "### TREC question dataset(http://cogcomp.org/Data/QA/QC/)" 137 | ] 138 | }, 139 | { 140 | "cell_type": "markdown", 141 | "metadata": {}, 142 | "source": [ 143 | "Task involves\n", 144 | "classifying a question into 6 question\n", 145 | "types (whether the question is about person,\n", 146 | "location, numeric information, etc.)" 147 | ] 148 | }, 149 | { 150 | "cell_type": "code", 151 | "execution_count": 53, 152 | "metadata": { 153 | "collapsed": true 154 | }, 155 | "outputs": [], 156 | "source": [ 157 | "data = open('../dataset/train_5500.label.txt', 'r', encoding='latin-1').readlines()" 158 | ] 159 | }, 160 | { 161 | "cell_type": "code", 162 | 
"execution_count": 54, 163 | "metadata": { 164 | "collapsed": true 165 | }, 166 | "outputs": [], 167 | "source": [ 168 | "data = [[d.split(':')[1][:-1], d.split(':')[0]] for d in data]" 169 | ] 170 | }, 171 | { 172 | "cell_type": "code", 173 | "execution_count": 61, 174 | "metadata": { 175 | "collapsed": true 176 | }, 177 | "outputs": [], 178 | "source": [ 179 | "X, y = list(zip(*data))\n", 180 | "X = list(X)" 181 | ] 182 | }, 183 | { 184 | "cell_type": "markdown", 185 | "metadata": {}, 186 | "source": [ 187 | "### Num masking " 188 | ] 189 | }, 190 | { 191 | "cell_type": "markdown", 192 | "metadata": {}, 193 | "source": [ 194 | "It reduces the search space. ex. my birthday is 12.22 ==> my birthday is ##.##" 195 | ] 196 | }, 197 | { 198 | "cell_type": "code", 199 | "execution_count": 62, 200 | "metadata": { 201 | "collapsed": true 202 | }, 203 | "outputs": [], 204 | "source": [ 205 | "for i, x in enumerate(X):\n", 206 | " X[i] = re.sub('\\d', '#', x).split()" 207 | ] 208 | }, 209 | { 210 | "cell_type": "markdown", 211 | "metadata": {}, 212 | "source": [ 213 | "### Build Vocab " 214 | ] 215 | }, 216 | { 217 | "cell_type": "code", 218 | "execution_count": 63, 219 | "metadata": { 220 | "collapsed": true 221 | }, 222 | "outputs": [], 223 | "source": [ 224 | "vocab = list(set(flatten(X)))" 225 | ] 226 | }, 227 | { 228 | "cell_type": "code", 229 | "execution_count": 64, 230 | "metadata": { 231 | "collapsed": false 232 | }, 233 | "outputs": [ 234 | { 235 | "data": { 236 | "text/plain": [ 237 | "9117" 238 | ] 239 | }, 240 | "execution_count": 64, 241 | "metadata": {}, 242 | "output_type": "execute_result" 243 | } 244 | ], 245 | "source": [ 246 | "len(vocab)" 247 | ] 248 | }, 249 | { 250 | "cell_type": "code", 251 | "execution_count": 31, 252 | "metadata": { 253 | "collapsed": false 254 | }, 255 | "outputs": [ 256 | { 257 | "data": { 258 | "text/plain": [ 259 | "6" 260 | ] 261 | }, 262 | "execution_count": 31, 263 | "metadata": {}, 264 | "output_type": "execute_result" 265 
| } 266 | ], 267 | "source": [ 268 | "len(set(y)) # num of class" 269 | ] 270 | }, 271 | { 272 | "cell_type": "code", 273 | "execution_count": 94, 274 | "metadata": { 275 | "collapsed": true 276 | }, 277 | "outputs": [], 278 | "source": [ 279 | "word2index={'': 0, '': 1}\n", 280 | "\n", 281 | "for vo in vocab:\n", 282 | " if word2index.get(vo) is None:\n", 283 | " word2index[vo] = len(word2index)\n", 284 | " \n", 285 | "index2word = {v:k for k, v in word2index.items()}\n", 286 | "\n", 287 | "target2index = {}\n", 288 | "\n", 289 | "for cl in set(y):\n", 290 | " if target2index.get(cl) is None:\n", 291 | " target2index[cl] = len(target2index)\n", 292 | "\n", 293 | "index2target = {v:k for k, v in target2index.items()}" 294 | ] 295 | }, 296 | { 297 | "cell_type": "code", 298 | "execution_count": 95, 299 | "metadata": { 300 | "collapsed": true 301 | }, 302 | "outputs": [], 303 | "source": [ 304 | "X_p, y_p = [], []\n", 305 | "for pair in zip(X,y):\n", 306 | " X_p.append(prepare_sequence(pair[0], word2index).view(1, -1))\n", 307 | " y_p.append(Variable(LongTensor([target2index[pair[1]]])).view(1, -1))\n", 308 | " \n", 309 | "data_p = list(zip(X_p, y_p))\n", 310 | "random.shuffle(data_p)\n", 311 | "\n", 312 | "train_data = data_p[: int(len(data_p) * 0.9)]\n", 313 | "test_data = data_p[int(len(data_p) * 0.9):]" 314 | ] 315 | }, 316 | { 317 | "cell_type": "markdown", 318 | "metadata": {}, 319 | "source": [ 320 | "### Load Pretrained word vector" 321 | ] 322 | }, 323 | { 324 | "cell_type": "markdown", 325 | "metadata": {}, 326 | "source": [ 327 | "you can download pretrained word vector from here https://github.com/mmihaltz/word2vec-GoogleNews-vectors " 328 | ] 329 | }, 330 | { 331 | "cell_type": "code", 332 | "execution_count": 41, 333 | "metadata": { 334 | "collapsed": true 335 | }, 336 | "outputs": [], 337 | "source": [ 338 | "import gensim" 339 | ] 340 | }, 341 | { 342 | "cell_type": "code", 343 | "execution_count": 43, 344 | "metadata": { 345 | "collapsed": true 346 | 
}, 347 | "outputs": [], 348 | "source": [ 349 | "model = gensim.models.KeyedVectors.load_word2vec_format('../dataset/GoogleNews-vectors-negative300.bin', binary=True)" 350 | ] 351 | }, 352 | { 353 | "cell_type": "code", 354 | "execution_count": 48, 355 | "metadata": { 356 | "collapsed": false 357 | }, 358 | "outputs": [ 359 | { 360 | "data": { 361 | "text/plain": [ 362 | "3000000" 363 | ] 364 | }, 365 | "execution_count": 48, 366 | "metadata": {}, 367 | "output_type": "execute_result" 368 | } 369 | ], 370 | "source": [ 371 | "len(model.index2word)" 372 | ] 373 | }, 374 | { 375 | "cell_type": "code", 376 | "execution_count": 96, 377 | "metadata": { 378 | "collapsed": true 379 | }, 380 | "outputs": [], 381 | "source": [ 382 | "pretrained = []\n", 383 | "\n", 384 | "for key in word2index.keys():\n", 385 | " try:\n", 386 | " pretrained.append(model[word2index[key]])\n", 387 | " except:\n", 388 | " pretrained.append(np.random.randn(300))\n", 389 | " \n", 390 | "pretrained_vectors = np.vstack(pretrained)" 391 | ] 392 | }, 393 | { 394 | "cell_type": "markdown", 395 | "metadata": {}, 396 | "source": [ 397 | "## Modeling " 398 | ] 399 | }, 400 | { 401 | "cell_type": "markdown", 402 | "metadata": {}, 403 | "source": [ 404 | "\n", 405 | "
Image borrowed from http://www.aclweb.org/anthology/D14-1181
" 406 | ] 407 | }, 408 | { 409 | "cell_type": "code", 410 | "execution_count": 117, 411 | "metadata": { 412 | "collapsed": true 413 | }, 414 | "outputs": [], 415 | "source": [ 416 | "class CNNClassifier(nn.Module):\n", 417 | " \n", 418 | " def __init__(self, vocab_size, embedding_dim, output_size, kernel_dim=100, kernel_sizes=(3, 4, 5), dropout=0.5):\n", 419 | " super(CNNClassifier,self).__init__()\n", 420 | "\n", 421 | " self.embedding = nn.Embedding(vocab_size, embedding_dim)\n", 422 | " self.convs = nn.ModuleList([nn.Conv2d(1, kernel_dim, (K, embedding_dim)) for K in kernel_sizes])\n", 423 | "\n", 424 | " # kernal_size = (K,D) \n", 425 | " self.dropout = nn.Dropout(dropout)\n", 426 | " self.fc = nn.Linear(len(kernel_sizes) * kernel_dim, output_size)\n", 427 | " \n", 428 | " \n", 429 | " def init_weights(self, pretrained_word_vectors, is_static=False):\n", 430 | " self.embedding.weight = nn.Parameter(torch.from_numpy(pretrained_word_vectors).float())\n", 431 | " if is_static:\n", 432 | " self.embedding.weight.requires_grad = False\n", 433 | "\n", 434 | "\n", 435 | " def forward(self, inputs, is_training=False):\n", 436 | " inputs = self.embedding(inputs).unsqueeze(1) # (B,1,T,D)\n", 437 | " inputs = [F.relu(conv(inputs)).squeeze(3) for conv in self.convs] #[(N,Co,W), ...]*len(Ks)\n", 438 | " inputs = [F.max_pool1d(i, i.size(2)).squeeze(2) for i in inputs] #[(N,Co), ...]*len(Ks)\n", 439 | "\n", 440 | " concated = torch.cat(inputs, 1)\n", 441 | "\n", 442 | " if is_training:\n", 443 | " concated = self.dropout(concated) # (N,len(Ks)*Co)\n", 444 | " out = self.fc(concated) \n", 445 | " return F.log_softmax(out,1)" 446 | ] 447 | }, 448 | { 449 | "cell_type": "markdown", 450 | "metadata": {}, 451 | "source": [ 452 | "## Train " 453 | ] 454 | }, 455 | { 456 | "cell_type": "markdown", 457 | "metadata": {}, 458 | "source": [ 459 | "It takes a while if you use only the CPU." 
460 | ] 461 | }, 462 | { 463 | "cell_type": "code", 464 | "execution_count": 145, 465 | "metadata": { 466 | "collapsed": true 467 | }, 468 | "outputs": [], 469 | "source": [ 470 | "EPOCH = 5\n", 471 | "BATCH_SIZE = 50\n", 472 | "KERNEL_SIZES = [3,4,5]\n", 473 | "KERNEL_DIM = 100\n", 474 | "LR = 0.001" 475 | ] 476 | }, 477 | { 478 | "cell_type": "code", 479 | "execution_count": 146, 480 | "metadata": { 481 | "collapsed": true 482 | }, 483 | "outputs": [], 484 | "source": [ 485 | "model = CNNClassifier(len(word2index), 300, len(target2index), KERNEL_DIM, KERNEL_SIZES)\n", 486 | "model.init_weights(pretrained_vectors) # initialize embedding matrix using pretrained vectors\n", 487 | "\n", 488 | "if USE_CUDA:\n", 489 | " model = model.cuda()\n", 490 | " \n", 491 | "loss_function = nn.CrossEntropyLoss()\n", 492 | "optimizer = optim.Adam(model.parameters(), lr=LR)" 493 | ] 494 | }, 495 | { 496 | "cell_type": "code", 497 | "execution_count": 147, 498 | "metadata": { 499 | "collapsed": false 500 | }, 501 | "outputs": [ 502 | { 503 | "name": "stdout", 504 | "output_type": "stream", 505 | "text": [ 506 | "[0/5] mean_loss : 2.13\n", 507 | "[1/5] mean_loss : 0.12\n", 508 | "[2/5] mean_loss : 0.08\n", 509 | "[3/5] mean_loss : 0.02\n", 510 | "[4/5] mean_loss : 0.05\n" 511 | ] 512 | } 513 | ], 514 | "source": [ 515 | "for epoch in range(EPOCH):\n", 516 | " losses = []\n", 517 | " for i,batch in enumerate(getBatch(BATCH_SIZE, train_data)):\n", 518 | " inputs,targets = pad_to_batch(batch)\n", 519 | " \n", 520 | " model.zero_grad()\n", 521 | " preds = model(inputs, True)\n", 522 | " \n", 523 | " loss = loss_function(preds, targets)\n", 524 | " losses.append(loss.data.tolist()[0])\n", 525 | " loss.backward()\n", 526 | " \n", 527 | " #for param in model.parameters():\n", 528 | " # param.grad.data.clamp_(-3, 3)\n", 529 | " \n", 530 | " optimizer.step()\n", 531 | " \n", 532 | " if i % 100 == 0:\n", 533 | " print(\"[%d/%d] mean_loss : %0.2f\" %(epoch, EPOCH, np.mean(losses)))\n", 534 | " 
losses = []" 535 | ] 536 | }, 537 | { 538 | "cell_type": "markdown", 539 | "metadata": {}, 540 | "source": [ 541 | "## Test " 542 | ] 543 | }, 544 | { 545 | "cell_type": "code", 546 | "execution_count": 150, 547 | "metadata": { 548 | "collapsed": true 549 | }, 550 | "outputs": [], 551 | "source": [ 552 | "accuracy = 0" 553 | ] 554 | }, 555 | { 556 | "cell_type": "code", 557 | "execution_count": 151, 558 | "metadata": { 559 | "collapsed": false 560 | }, 561 | "outputs": [ 562 | { 563 | "name": "stdout", 564 | "output_type": "stream", 565 | "text": [ 566 | "97.61904761904762\n" 567 | ] 568 | } 569 | ], 570 | "source": [ 571 | "for test in test_data:\n", 572 | " pred = model(test[0]).max(1)[1]\n", 573 | " pred = pred.data.tolist()[0]\n", 574 | " target = test[1].data.tolist()[0][0]\n", 575 | " if pred == target:\n", 576 | " accuracy += 1\n", 577 | "\n", 578 | "print(accuracy/len(test_data) * 100)" 579 | ] 580 | }, 581 | { 582 | "cell_type": "markdown", 583 | "metadata": { 584 | "collapsed": true 585 | }, 586 | "source": [ 587 | "## Further topics " 588 | ] 589 | }, 590 | { 591 | "cell_type": "markdown", 592 | "metadata": {}, 593 | "source": [ 594 | "* Character-Aware Neural Language Models\n", 595 | "* Character level CNN for text classification" 596 | ] 597 | }, 598 | { 599 | "cell_type": "markdown", 600 | "metadata": {}, 601 | "source": [ 602 | "## Suggested Reading" 603 | ] 604 | }, 605 | { 606 | "cell_type": "markdown", 607 | "metadata": {}, 608 | "source": [ 609 | "* https://blog.statsbot.co/text-classifier-algorithms-in-machine-learning-acc115293278\n", 610 | "* Bag of Tricks for Efficient Text Classification\n", 611 | "* Which Encoding is the Best for Text Classification in Chinese, English, Japanese and Korean?" 
612 | ] 613 | }, 614 | { 615 | "cell_type": "code", 616 | "execution_count": null, 617 | "metadata": { 618 | "collapsed": true 619 | }, 620 | "outputs": [], 621 | "source": [] 622 | } 623 | ], 624 | "metadata": { 625 | "kernelspec": { 626 | "display_name": "Python 3", 627 | "language": "python", 628 | "name": "python3" 629 | }, 630 | "language_info": { 631 | "codemirror_mode": { 632 | "name": "ipython", 633 | "version": 3 634 | }, 635 | "file_extension": ".py", 636 | "mimetype": "text/x-python", 637 | "name": "python", 638 | "nbconvert_exporter": "python", 639 | "pygments_lexer": "ipython3", 640 | "version": "3.5.2" 641 | } 642 | }, 643 | "nbformat": 4, 644 | "nbformat_minor": 2 645 | } 646 | -------------------------------------------------------------------------------- /notebooks/09.Recursive-NN-for-Sentiment-Classification.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# 9. Recursive Neural Networks and Constituency Parsing" 8 | ] 9 | }, 10 | { 11 | "cell_type": "markdown", 12 | "metadata": {}, 13 | "source": [ 14 | "I recommend you take a look at these materials first." 
15 | ] 16 | }, 17 | { 18 | "cell_type": "markdown", 19 | "metadata": {}, 20 | "source": [ 21 | "* http://web.stanford.edu/class/cs224n/lectures/cs224n-2017-lecture14-TreeRNNs.pdf\n", 22 | "* https://nlp.stanford.edu/~socherr/EMNLP2013_RNTN.pdf" 23 | ] 24 | }, 25 | { 26 | "cell_type": "code", 27 | "execution_count": 4, 28 | "metadata": { 29 | "collapsed": true 30 | }, 31 | "outputs": [], 32 | "source": [ 33 | "import torch\n", 34 | "import torch.nn as nn\n", 35 | "from torch.autograd import Variable\n", 36 | "import torch.optim as optim\n", 37 | "import torch.nn.functional as F\n", 38 | "import nltk\n", 39 | "import random\n", 40 | "import numpy as np\n", 41 | "from collections import Counter, OrderedDict\n", 42 | "import nltk\n", 43 | "from copy import deepcopy\n", 44 | "import os\n", 45 | "from IPython.display import Image, display\n", 46 | "from nltk.draw import TreeWidget\n", 47 | "from nltk.draw.util import CanvasFrame\n", 48 | "from nltk.tree import Tree as nltkTree\n", 49 | "flatten = lambda l: [item for sublist in l for item in sublist]\n", 50 | "random.seed(1024)" 51 | ] 52 | }, 53 | { 54 | "cell_type": "code", 55 | "execution_count": 2, 56 | "metadata": { 57 | "collapsed": true 58 | }, 59 | "outputs": [], 60 | "source": [ 61 | "USE_CUDA = torch.cuda.is_available()\n", 62 | "gpus = [0]\n", 63 | "torch.cuda.set_device(gpus[0])\n", 64 | "\n", 65 | "FloatTensor = torch.cuda.FloatTensor if USE_CUDA else torch.FloatTensor\n", 66 | "LongTensor = torch.cuda.LongTensor if USE_CUDA else torch.LongTensor\n", 67 | "ByteTensor = torch.cuda.ByteTensor if USE_CUDA else torch.ByteTensor" 68 | ] 69 | }, 70 | { 71 | "cell_type": "code", 72 | "execution_count": 3, 73 | "metadata": { 74 | "collapsed": true 75 | }, 76 | "outputs": [], 77 | "source": [ 78 | "def getBatch(batch_size, train_data):\n", 79 | " random.shuffle(train_data)\n", 80 | " sindex = 0\n", 81 | " eindex = batch_size\n", 82 | " while eindex < len(train_data):\n", 83 | " batch = train_data[sindex: eindex]\n", 
84 | " temp = eindex\n", 85 | " eindex = eindex + batch_size\n", 86 | " sindex = temp\n", 87 | " yield batch\n", 88 | " \n", 89 | " if eindex >= len(train_data):\n", 90 | " batch = train_data[sindex:]\n", 91 | " yield batch" 92 | ] 93 | }, 94 | { 95 | "cell_type": "code", 96 | "execution_count": 5, 97 | "metadata": { 98 | "collapsed": true 99 | }, 100 | "outputs": [], 101 | "source": [ 102 | "# Borrowed from https://stackoverflow.com/questions/31779707/how-do-you-make-nltk-draw-trees-that-are-inline-in-ipython-jupyter\n", 103 | "\n", 104 | "def draw_nltk_tree(tree):\n", 105 | " cf = CanvasFrame()\n", 106 | " tc = TreeWidget(cf.canvas(), tree)\n", 107 | " tc['node_font'] = 'arial 15 bold'\n", 108 | " tc['leaf_font'] = 'arial 15'\n", 109 | " tc['node_color'] = '#005990'\n", 110 | " tc['leaf_color'] = '#3F8F57'\n", 111 | " tc['line_color'] = '#175252'\n", 112 | " cf.add_widget(tc, 50, 50)\n", 113 | " cf.print_to_file('tmp_tree_output.ps')\n", 114 | " cf.destroy()\n", 115 | " os.system('convert tmp_tree_output.ps tmp_tree_output.png')\n", 116 | " display(Image(filename='tmp_tree_output.png'))\n", 117 | " os.system('rm tmp_tree_output.ps tmp_tree_output.png')" 118 | ] 119 | }, 120 | { 121 | "cell_type": "markdown", 122 | "metadata": {}, 123 | "source": [ 124 | "## Data load and Preprocessing" 125 | ] 126 | }, 127 | { 128 | "cell_type": "markdown", 129 | "metadata": {}, 130 | "source": [ 131 | "### Stanford Sentiment Treebank(https://nlp.stanford.edu/sentiment/index.html)" 132 | ] 133 | }, 134 | { 135 | "cell_type": "code", 136 | "execution_count": 10, 137 | "metadata": { 138 | "collapsed": false 139 | }, 140 | "outputs": [ 141 | { 142 | "name": "stdout", 143 | "output_type": "stream", 144 | "text": [ 145 | "(3 (2 (1 Deflated) (2 (2 ending) (2 aside))) (4 (2 ,) (4 (2 there) (3 (3 (2 's) (3 (2 much) (2 (2 to) (3 (3 recommend) (2 (2 the) (2 film)))))) (2 .)))))\n", 146 | "\n" 147 | ] 148 | } 149 | ], 150 | "source": [ 151 | "sample = 
random.choice(open('../dataset/trees/train.txt', 'r', encoding='utf-8').readlines())\n", 152 | "print(sample)" 153 | ] 154 | }, 155 | { 156 | "cell_type": "code", 157 | "execution_count": 11, 158 | "metadata": { 159 | "collapsed": false 160 | }, 161 | "outputs": [ 162 | { 163 | "data": { 164 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAhEAAAGVCAMAAAB3iVNzAAAJJGlDQ1BpY2MAAHjalZVnUJNZF8fv\n8zzphUASQodQQ5EqJYCUEFoo0quoQOidUEVsiLgCK4qINEUQUUDBVSmyVkSxsCgoYkE3yCKgrBtX\nERWUF/Sd0Xnf2Q/7n7n3/OY/Z+4995wPFwCCOFgSvLQnJqULvJ3smIFBwUzwg8L4aSkcT0838I96\nPwyg5XhvBfj3IkREpvGX4sLSyuWnCNIBgLKXWDMrPWWZDy8xPTz+K59dZsFSgUt8Y5mjv/Ho15xv\nLPqa4+vNXXoVCgAcKfoHDv+B/3vvslQ4gvTYqMhspk9yVHpWmCCSmbbcCR6Xy/QUJEfFJkT+UPC/\nSv4HpUdmpy9HbnLKBkFsdEw68/8ONTIwNATfZ/HW62uPIUb//85nWd+95HoA2LMAIHu+e+GVAHTu\nAED68XdPbamvlHwAOu7wMwSZ3zzU8oYGBEABdCADFIEq0AS6wAiYAUtgCxyAC/AAviAIrAN8EAMS\ngQBkgVywDRSAIrAH7AdVoBY0gCbQCk6DTnAeXAHXwW1wFwyDJ0AIJsArIALvwTwEQViIDNEgGUgJ\nUod0ICOIDVlDDpAb5A0FQaFQNJQEZUC50HaoCCqFqqA6qAn6BToHXYFuQoPQI2gMmob+hj7BCEyC\n6bACrAHrw2yYA7vCvvBaOBpOhXPgfHg3XAHXwyfgDvgKfBsehoXwK3gWAQgRYSDKiC7CRriIBxKM\nRCECZDNSiJQj9Ugr0o30IfcQITKDfERhUDQUE6WLskQ5o/xQfFQqajOqGFWFOo7qQPWi7qHGUCLU\nFzQZLY/WQVugeehAdDQ6C12ALkc3otvR19DD6An0ewwGw8CwMGYYZ0wQJg6zEVOMOYhpw1zGDGLG\nMbNYLFYGq4O1wnpgw7Dp2AJsJfYE9hJ2CDuB/YAj4pRwRjhHXDAuCZeHK8c14y7ihnCTuHm8OF4d\nb4H3wEfgN+BL8A34bvwd/AR+niBBYBGsCL6EOMI2QgWhlXCNMEp4SyQSVYjmRC9iLHErsYJ4iniD\nOEb8SKKStElcUggpg7SbdIx0mfSI9JZMJmuQbcnB5HTybnIT+Sr5GfmDGE1MT4wnFiG2RaxarENs\nSOw1BU9Rp3Ao6yg5lHLKGcodyow4XlxDnCseJr5ZvFr8nPiI+KwETcJQwkMiUaJYolnipsQUFUvV\noDpQI6j51CPUq9RxGkJTpXFpfNp2WgPtGm2CjqGz6Dx6HL2IfpI+QBdJUiWNJf0lsyWrJS9IChkI\nQ4PBYyQwShinGQ8Yn6QUpDhSkVK7pFqlhqTmpOWkbaUjpQul26SHpT/JMGUcZOJl9sp0yjyVRclq\ny3rJZskekr0mOyNHl7OU48sVyp2WeywPy2vLe8tvlD8i3y8/q6Co4KSQolCpcFVhRpGhaKsYp1im\neFFxWommZK0Uq1SmdEnpJVOSyWEmMCuYvUyRsryys3KGcp3ygPK8CkvFTyVPpU3lqSpBla0apVqm\n2qMqUlNSc1fLVWtRe6yOV2erx6gfUO9Tn9NgaQRo7NTo1JhiSbN4rBxWC2tUk6xpo5mqWa95Xwuj\nxdaK1zqodVcb1jbRjtGu1r6jA+uY6sTqHNQZXIFeYb4iaUX9ihFdk
i5HN1O3RXdMj6Hnppen16n3\nWl9NP1h/r36f/hcDE4MEgwaDJ4ZUQxfDPMNuw7+NtI34RtVG91eSVzqu3LKya+UbYx3jSONDxg9N\naCbuJjtNekw+m5qZCkxbTafN1MxCzWrMRth0tie7mH3DHG1uZ77F/Lz5RwtTi3SL0xZ/Wepaxls2\nW06tYq2KXNWwatxKxSrMqs5KaM20DrU+bC20UbYJs6m3eW6rahth22g7ydHixHFOcF7bGdgJ7Nrt\n5rgW3E3cy/aIvZN9of2AA9XBz6HK4ZmjimO0Y4ujyMnEaaPTZWe0s6vzXucRngKPz2viiVzMXDa5\n9LqSXH1cq1yfu2m7Cdy63WF3F/d97qOr1Vcnre70AB48j30eTz1Znqmev3phvDy9qr1eeBt653r3\n+dB81vs0+7z3tfMt8X3ip+mX4dfjT/EP8W/ynwuwDygNEAbqB24KvB0kGxQb1BWMDfYPbgyeXeOw\nZv+aiRCTkIKQB2tZa7PX3lwnuy5h3YX1lPVh68+EokMDQptDF8I8wurDZsN54TXhIj6Xf4D/KsI2\noixiOtIqsjRyMsoqqjRqKtoqel/0dIxNTHnMTCw3tir2TZxzXG3cXLxH/LH4xYSAhLZEXGJo4rkk\nalJ8Um+yYnJ28mCKTkpBijDVInV/qkjgKmhMg9LWpnWl05c+xf4MzYwdGWOZ1pnVmR+y/LPOZEtk\nJ2X3b9DesGvDZI5jztGNqI38jT25yrnbcsc2cTbVbYY2h2/u2aK6JX/LxFanrce3EbbFb/stzyCv\nNO/d9oDt3fkK+Vvzx3c47WgpECsQFIzstNxZ+xPqp9ifBnat3FW560thROGtIoOi8qKFYn7xrZ8N\nf674eXF31O6BEtOSQ3swe5L2PNhrs/d4qURpTun4Pvd9HWXMssKyd/vX779Zblxee4BwIOOAsMKt\noqtSrXJP5UJVTNVwtV11W418za6auYMRB4cO2R5qrVWoLar9dDj28MM6p7qOeo368iOYI5lHXjT4\nN/QdZR9tapRtLGr8fCzpmPC49/HeJrOmpmb55pIWuCWjZfpEyIm7J+1PdrXqtta1MdqKToFTGade\n/hL6y4PTrqd7zrDPtJ5VP1vTTmsv7IA6NnSIOmM6hV1BXYPnXM71dFt2t/+q9+ux88rnqy9IXii5\nSLiYf3HxUs6l2cspl2euRF8Z71nf8+Rq4NX7vV69A9dcr9247nj9ah+n79INqxvnb1rcPHeLfavz\ntuntjn6T/vbfTH5rHzAd6Lhjdqfrrvnd7sFVgxeHbIau3LO/d/0+7/7t4dXDgw/8HjwcCRkRPox4\nOPUo4dGbx5mP559sHUWPFj4Vf1r+TP5Z/e9av7cJTYUXxuzH+p/7PH8yzh9/9UfaHwsT+S/IL8on\nlSabpoymzk87Tt99ueblxKuUV/MzBX9K/FnzWvP12b9s/+oXBYom3gjeLP5d/Fbm7bF3xu96Zj1n\nn71PfD8/V/hB5sPxj+yPfZ8CPk3OZy1gFyo+a33u/uL6ZXQxcXHxPy6ikLxyKdSVAAAAIGNIUk0A\nAHomAACAhAAA+gAAAIDoAAB1MAAA6mAAADqYAAAXcJy6UTwAAACrUExURf///wBZkABZkABZkABZ\nkABZkABZkABZkABZkABZkABZkABZkBdSUhdSUhdTUxdSUhdSUhdSUhdSUhdSUhdSUhdSUhdSUhdS\nUhdSUhdSUhdSUhdSUgBZkABZkBdSUgBZkBdSUj+PVz+PVz+PVz+PVz+PVz+PVz+PVz+PVz+PVz+P\nVz+PVz+PVz+PVz+PVxdSUhdSUhdSUhdTUxdSUhhUVABZkBdSUj+PV////2WyAXMAAAA1dFJOUwAR\niO6ZuyJVZjPdqkR3dZnMu6OIZjMRqt3uVSJEzMd31rsziN2ZIkQRqlXuZsx3r+HSW+wg7fJpnQAA\nAAFiS0dEAIgFHUgAAAAJcEhZc
wAAAEgAAABIAEbJaz4AAAAHdElNRQfhCwINNSfD9n+TAAASs0lE\nQVR42u2dCZuiuBqFU9pVXVWtrQ6uZffMZVFwq5m7SP7/P7skiFuxhiAJnPd5ulAgIZpD8sVODoQA\nAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA\nAAAAAAAAAAAAAAAAAAAAaDZPne6x++2p7mIAZfh27D53j891FwMoQ6fzQp6Ox7qLAZTiO9oIcEUn\niCNe6i4EUIjX5273te5CAKV4O3brLgJQhvfjE3lBZAnOPPPR54+6iwGU4ekbfqECAAAAgCi9/s9+\nr+5CADXo9QdD3x/9MfL94QCyaDcnMRjjCX87GRuQRWu5E8MFyKJ1THvGLF4MF06ymBm9ad3FBVXC\nxDD3FyPjY5Lj7MmHMVr4c8iimVzEsCyUbglZNA9RMVyALBpDWJVlxFBNXqAGqrmvy7c3oAaqbuQh\nC4143Pig2LgF1EAdvyHk+W0D1EDNvzMm/v4JakCZH50hCxUYqyGGCydZjOsuR2vpKSQG1UsFAABA\nUZ5+/aq7CK1EXReQH1geVgcdZV1Afh2hiDp47bwp6gLy3lWyWK3g5fhedxG+0jl+hyJq4un3UT3P\nh5fuDwJF1MPbu4KCID+6L1BEPbx2uwoO8t6Ovzud47GjoFabztvxXUXjKDbQYKg4CGo478f354C3\nussRB3qNOjjdiwp2HFAEAAAAoAG9n0pOTZmOf46x9uvxLI354s/F3FBrycT0wxj68z/n/tD4gCoe\nyHQ88kfBjRht1aBnjPzFbMw0uhzPFv7IULINayC9wVXbwNqKQf3f/KQ/8/1Zf5K+C1RAjAJuFFJH\nkZIbhKtmA1RBUi9RX+8xHQ8ygoYwtBgo07c1idS2oIbeI39dZ+sGFCZHjT+09yjcHyDYlEneXuFB\nvYdwzIhgUw6F7v2qe4/StzqCzZII1HBlvYescADBpjCivUAFvYfsWkSwKUCpe11q71FRS49gswgS\nalRO71FxNIhgMxeyWv2y+TzoJkawmYUhMTJkbY0hWIwHdvRhmCJY0ObTk3tTimbXe/BgYDpGRAEA\nAIVpryFHez95Kr/e5S53EHQeqcOwBFYkcfx1fJb6vQg6j9RhWAIrklj+epW7JErQeaQOwxJYkSQh\n/XsRdB55sGEJrEgSkf29CDqPPNiwBFYkyUj+XgSdRx5tWAIrkmTkfi+CziOPNiyBFUkKUhUh6Dzy\ncMMSWJEk8YvfKR1p96eg80g9hiXoNWLohLdKR1Z+gs4j9RiWQBEAAACAaix/6jFfrdf/2cdzqGPx\npebWX/xLbF6d3GKkMflgDw/1R3+MfH+OR4h+xZd4Ty9HizH5WAwFJqzJLEYS015/wGQwix4wvPzg\njxAdDvDY+it8efMNjcWMfbFTwzcKf8ESixEHe34xaxiMmCe99canx9Zj6iVHWlVMhouP08ve5eXD\ni3HPqWFgzy9epp5msIdFDmd4QLmsqjD82VXDYCxGxb7YKhTBnlg7L3TzXzUl7e1G5FTFZDi/zScI\nKQpFmHIVcbnje8Xv+KBVMdocdY4kVEVs5FAswpRRDM5EUlTQ3qhTQlX05vO4TApFmBKKETUMMuuw\njVFn6apIqfgCEWa5Ylwqrop2vmVR56jk+sfefJhSC7kjTNFiLHuPatxbE3WWU8R05qenzxthChQj\nbBjmFTUMCZ+3BVFnKUUE0WPmF5MvwixUjFPDMBrUdLs2O+osoYjpLNf9nyvCzFsMhe7RpkadA2FF\njHP/CpUjwswuxvL0P1R1NQzxNDDqNEQVMVj0C1xlMShVjPHpf6hqbxjiiaLOcd0FqZVivwYuS9pc\n9dVvmac6FBIAAEQxqUlMK+0EK3deQDdsGuCs1le7XGp5hLqXHeZ9ouuDdzx9K+QBUvB0qakfxos
e\nxTxh0+CPtaHbyy6X7bqqdIveJ0pRxO9iHiAFT5ea+mGwlWkaFPOEHVb3ZucRsnXobuW5QaNhskr3\nVkHrsSXsvX06xs9xtimK6HSeyEv+RVEFT5ea+lG8sGI+2hdFnJMiLLon7s4l68Pq3EasHI/sqcff\nR8e8QBbeJkUR/Cso5gFSzjLkwYYjovylWxtBgmZgswq2W3pWhLcOt+x9dGxPg53rdEU8/T5+L1CC\ngqdLTf0ono/vesURhCvCpBw3UsR6w99yRUTHwtNTFfGr2y3ivlDwdKmpH0YnkIQulojnXmNLTDvc\nFSnC3KzPbUR0LFsRr8ffRRb3FzxdaupH8qZH50YukaUT/DuEuyJFsHpfR73G6RjvNawURQR1VKR5\nLHi61NSP4lf3XYcAOOI0+txZTAmfwfhic1aEY5O1ySJOHl2Gx9a7FRuDJCoi+OC/mQdIzqvnOd2W\ndbG6eOry0ee3usuRE/4L1S78hSoYWdKgp4gU4Tr0YNlBgHEIR5/sWDj6TP7NMnLtyXn1PKdTW9LF\nauPlx1GjX6jUx/TqLgFQCm9TdwkAAKBaemOjaVMuC1FwNp7w5D0Jqatm0jPYHFx/9PffI59NxTV6\nis7+qxQogs2jGxujIZ8a3o8mGS7ZbPG57w/ZHG2FpghXT7sVMfkwmCuFPxokzAbOPKF5tFQRxZqA\nuEaksbROEUHlzgTDhHOgMWty6NkeRUirzxKa0oEWKKKaNr+xoWejFfGAuLB5oWczFfHoG7hJoWfT\nFFFrJ9+I0LMxilCoNvQOPfVXhKottq6hp86K0CKq06KQ12ipCP1uP1Ubshi0U8RU4y76KtjRQsea\nUHvgWJ6gwai7CAAAIIZlZqwDvMGll1QFLgA0gq8lzz2l3nMvqfJg5j81F3rZfSShtreKWIXlTMXs\nTaQqQi+7jySU8VZhpiOHLSEb16S7T8KXiJmflPcal307utleKtFmq8MsQvZsG7pTRKmi7JLh9iYu\nDdI61uX81dah/DqOTQqimd1HEsp4q5grj3zSNaGHNd/azGTACRVx3ues2YLBKMme7oOaNPkaY2/D\nV5meU0XZpRDamwRKYmuXL5ffu8QOrufu9iIfQx+7jzRU8FaxeO0dbEJZRbDFw8x45NRGnPeFbccl\nEeH1atFw8WjwMkp1zi4Frgh2nnt1Pl8aysVgCwSdOtl9pKCEt4ob+oyY4dCCqcAmYaWdzAbO20vX\n760cloaJ5GDzFennVOfsUi95yuz6/PBSIcU/hU52H8mo4a1yrucCijAP1unter/a2beKyHfJiyLu\nLi+KPnYfiSjirWJFpodRldz0GtE+3pqfK5vvjt6GrleXXmObeclrRdxeflc4qmRoZveRhDLeKiaP\n5qxz7ds7N4wirxSxOtxElkH1e5876n2ywYLt8MgyShVll4J7Ms8LvZCuL29zO73Pgp9AM7uPBNTx\nVuGWl5+XNoK9N/fmjSLYvqvRp8WMMdeOGY1Cg5o9p4qyS+PARp8kVMTN5dnoc2cXNptohN2HLt4q\nV2x3dZcAKANzyGXmqACcsE8GygDowbTfx/QjcGYyWPzzz2Kg4Xy6O1RYE90AxiN/1iOkN9P/2WrG\noHwebWdpLBbG8utrPTFGdZdAd760C6f2QlegiFJM+/OY2CGIKebaRplQRAmSaz5eKVoARQiT0Tvo\nGmX2oAgh8kSQekaZUIQIue9/DaNMKKIwxWIE7aLMnl93CTSjeA1rFmVCEYUQ7AV0ijKhiPyUiRT1\niTKhiLyUvs81iTKXUEQuljJiARaDqN9OQBH5kDNemPbr/hwAAAAAeCyyfDaa4ToC5PlsNMN1BEjz\n2WiI6wgIkeWz0QzXESDLZ6MhriNAms9GM1xHgEyfjQa4jgBZPhsNcR0B0nw29HAdcYU8dtqFNJ8N\nLVxHbJm+rwAAAACQB1xHGoMUn42Pmf/PP/7so+4PAyRQ3md
jaYRTNdlkSz1mZYM0Sq6Yno5H/mg8\njXkDNKWUIr42C1GDAbRFXBHT/jw2dAiCCr0WhIIbRBWRVu9JWgE6ILSGPrtvQJipLcUVkTN+RJip\nKUUVUeTmR5ipI4VWTBcPEBBmakcBRYjVLsJMzciriDI9AMJMnciliNJRIsJMfcjhsyHnFkeYqQtZ\niljKCwNYIILOQ39kDhVgPAIAAPpiiT83GjSSVeYKBlmWI1XkBuSTvaZFluVIFbkB+bgmsQ6U8kfZ\nxyPJcqSS3EAlODYhe5r6bHpZliNV5Aaks/vMOkOS5UgluQH57HfOapt2gizLkSpyA1XgbW3nkHxY\nouWI9NxAVXg08XcJSZYjleQGKmEXRJZbuk44KstypIrcQDWw0echMZCQZjlSQW4AAAAAaBnLfh+T\nqHRHiuUIZ9kf+v/+tz+EKPRGkiLYbGyuBaYLiEJnZChiOp75V9O5IQqtKa2IUA538/IhCn0pp4hY\nOYRAFJpSRhEfbO1PyqodiEJHhBURyGExyHx6NEShHWKK4HLIuRQMotALAUX0CsghBKLQiKKKmLAV\nvwILRSEKXSikiEAO/kzYFQCi0IL8iignhxCIQn1yKmIpQQ6nnCAKtcmjCFaJMi1kIAqVyVbEsorq\n46KAJjSlort5CfsRAADQGcusuwRALdwYv4lORSYhT1VlDCQSo4hOVSYh3+A+oggb16S7veVQxw0V\nwP5YG0o3XvByH+y/MZt47byRpyqWaXU6L9VkDApCD2vySc01WTlnRXjOyvM2JnF3gSw2zn2SykxC\nvqONUAC6P/0JlBApgi8VXrM2Yx3TdTz9PlbiCRF0SN9e6v46AOHOAezPlSIiB7NzL3LN23s1giCv\nz7AfUYGiinjtdn9VVZa3Y7furwN8VcSe9RpBNLm24xTxdnyvpmV/P8LXTg2uFbGmn2R9YJHlJogs\nD3GKeD++M5MQ+V5Cz3z0+aPurwPcKIJ87qjpmlejT3KviJNJiPyOA/64AAAAAChB72fmijCgJEWe\nIZub6Xjo/+kP8fw/HfHlT6pbGou5MSVTY77Ag0L1w5fduH+M/NFHzGugCXIV8aVdOLUXQB9kKqI3\niIkdWEyR7TsAlEGaIlJqPlYpQFFGchSR0TsgytQHKYrIE0EiytSE8orIff8jytSCUUl/w0IxAqJM\nDSilCIEaRpSpOiUUIdgLIMpUG2FFlIkUEWUqzEBIEaXvc0SZyiLkeGlIiAVYDCLtWQ+gZnpyxguS\nsgEAAADA4/GEXGekuYXAdUQ5tlQklSy3kMrsTIAoNqXUJZ8OpRurQDJZbiGV2ZkAYfhSQLon3mbn\nFUspyy2kMjsTIARTxGETvLDotkg6aW4hVdmZAEGYIqjNXoV/8yLLLaQyOxMgiKgiJLmFVGlnAoQQ\n6zVkuYVUZmcChHGpR/ZBZLk2nQKpZLmFVGZnAoTxDtHoc10glSy3kMrsTAAAAADQBqb/wby6hiDH\nhmQ6/O8QkmgGUhSxHA7/N8Tz3ZqBDEVMFoPg72AxqfvDAAlIUEQoCEiiIZRXRCQISKIZlFbEx+Ky\nVMNYYHmX9pRVxNgfJ74DOlJSEfcSgCS0p5wivgoAktCdUoqIixuu4wqgIWUUET+2uIw9gI6UUETS\nYBOS0BpxRST/+gBJ6IyoIqZp/4+xHOL/vbRFVBGj1DqfDkd1fzDwYCbpjcAUv2cDAEBrMYusOAct\ngD+O9oosBxE4jDSdO0VkOYjAYURJtuY+qMitQ3crjxDboc6e8O3OJsQ1XYeaa5PutkGfsDWpY9mU\nLyKOEphucPCT+RnR3f5OEVkOInAYURKXblzP3blkfViRz2DrBvVqB//cQBJuUPtbtiRw5QQtgMm2\nq2DHmkQJCD2syWewY7WzvNV9r0GyHUTgMKIcLg3CwU1Qt8ykitUxcdfhNlABWzlM6P5kM7Anpx3u\nOQHfyXYESiHrr4rIchC
Bw4h6sLomJuW4kasE39o0PMjqmSvitGUvogRh6MC2NvkaWWY6iMBhREFC\nRVwrgeRRRGRIkqqILAcROIyoCK/jzSF8w91G9m64XR2SFRElOCsirtfIchCBw4iS8Dp26SfxVpsg\nRmSR5TaMLKmdrIgowVkRq93as+8UkeUgAocRJeF1zAaT3GWEjT6DsSTf2iRZEecEkSL46NO8tTLK\nchCBwwgAAAAAAMhP1mw8OU4lQB8yFQF/kZaRqQg8yq1lQBHgFigC3AJFgFuyFDGCIloGFAFugSLA\nLZmKgK9Iy4AiwC1QBLgFigC3ZCliAEW0jCxFGFAEAAAAAL5gmcnHTGrylWFwIGkNZrQQJBaXWh5x\nPRKzlhw0E4tmKCLcQhFtwaWU2i7dO9QJ+gVvRelhe3eY9xpsueC1ZwloLKwRcHcbj2ycoAdZedx0\n5PZwpIhrzxLQWLgimAiCFxYXw8G+O3xWxJVnCWgsLj0vOud9BO8mbg+fFUGu1hiDpnKriNjDUESr\nuFaERbdxh6GIVsGig0gRxDwwbzvr5jCBIlrGgY0+SVjPbPTJbUrOQBEAAAAAAAAAAABQnf8DmwKy\nJJha+OIAAAAldEVYdGRhdGU6Y3JlYXRlADIwMTctMTEtMDJUMjI6NTM6MzkrMDk6MDBXyJTvAAAA\nJXRFWHRkYXRlOm1vZGlmeQAyMDE3LTExLTAyVDIyOjUzOjM5KzA5OjAwJpUsUwAAACN0RVh0cHM6\nSGlSZXNCb3VuZGluZ0JveAA1Mjl4NDA1LTI2NC0yMDKzNU7+AAAAHHRFWHRwczpMZXZlbABBZG9i\nZS0zLjAgRVBTRi0zLjAKm3C74wAAACJ0RVh0cHM6U3BvdENvbG9yLTAAZm9udCBMaWJlcmF0aW9u\nU2Fuc/4Zp8YAAAAASUVORK5CYII=\n", 165 | "text/plain": [ 166 | "" 167 | ] 168 | }, 169 | "metadata": {}, 170 | "output_type": "display_data" 171 | } 172 | ], 173 | "source": [ 174 | "draw_nltk_tree(nltkTree.fromstring(sample))" 175 | ] 176 | }, 177 | { 178 | "cell_type": "markdown", 179 | "metadata": {}, 180 | "source": [ 181 | "### Tree Class " 182 | ] 183 | }, 184 | { 185 | "cell_type": "markdown", 186 | "metadata": {}, 187 | "source": [ 188 | "borrowed code from https://github.com/bogatyy/cs224d/tree/master/assignment3" 189 | ] 190 | }, 191 | { 192 | "cell_type": "code", 193 | "execution_count": 10, 194 | "metadata": { 195 | "collapsed": true 196 | }, 197 | "outputs": [], 198 | "source": [ 199 | "class Node: # a node in the tree\n", 200 | " def __init__(self, label, word=None):\n", 201 | " self.label = label\n", 202 | " self.word = word\n", 203 | " self.parent = None # reference to parent\n", 204 | " self.left = None # reference to left child\n", 205 | " self.right = None # reference to right child\n", 206 | " # true if I am a leaf (could have probably derived this 
from if I have\n", 207 | " # a word)\n", 208 | " self.isLeaf = False\n", 209 | " # true if we have finished performing fowardprop on this node (note,\n", 210 | " # there are many ways to implement the recursion.. some might not\n", 211 | " # require this flag)\n", 212 | "\n", 213 | " def __str__(self):\n", 214 | " if self.isLeaf:\n", 215 | " return '[{0}:{1}]'.format(self.word, self.label)\n", 216 | " return '({0} <- [{1}:{2}] -> {3})'.format(self.left, self.word, self.label, self.right)\n", 217 | "\n", 218 | "\n", 219 | "class Tree:\n", 220 | "\n", 221 | " def __init__(self, treeString, openChar='(', closeChar=')'):\n", 222 | " tokens = []\n", 223 | " self.open = '('\n", 224 | " self.close = ')'\n", 225 | " for toks in treeString.strip().split():\n", 226 | " tokens += list(toks)\n", 227 | " self.root = self.parse(tokens)\n", 228 | " # get list of labels as obtained through a post-order traversal\n", 229 | " self.labels = get_labels(self.root)\n", 230 | " self.num_words = len(self.labels)\n", 231 | "\n", 232 | " def parse(self, tokens, parent=None):\n", 233 | " assert tokens[0] == self.open, \"Malformed tree\"\n", 234 | " assert tokens[-1] == self.close, \"Malformed tree\"\n", 235 | "\n", 236 | " split = 2 # position after open and label\n", 237 | " countOpen = countClose = 0\n", 238 | "\n", 239 | " if tokens[split] == self.open:\n", 240 | " countOpen += 1\n", 241 | " split += 1\n", 242 | " # Find where left child and right child split\n", 243 | " while countOpen != countClose:\n", 244 | " if tokens[split] == self.open:\n", 245 | " countOpen += 1\n", 246 | " if tokens[split] == self.close:\n", 247 | " countClose += 1\n", 248 | " split += 1\n", 249 | "\n", 250 | " # New node\n", 251 | " node = Node(int(tokens[1])) # zero index labels\n", 252 | "\n", 253 | " node.parent = parent\n", 254 | "\n", 255 | " # leaf Node\n", 256 | " if countOpen == 0:\n", 257 | " node.word = ''.join(tokens[2: -1]).lower() # lower case?\n", 258 | " node.isLeaf = True\n", 259 | " return 
node\n", 260 | "\n", 261 | " node.left = self.parse(tokens[2: split], parent=node)\n", 262 | " node.right = self.parse(tokens[split: -1], parent=node)\n", 263 | "\n", 264 | " return node\n", 265 | "\n", 266 | " def get_words(self):\n", 267 | " leaves = getLeaves(self.root)\n", 268 | " words = [node.word for node in leaves]\n", 269 | " return words\n", 270 | "\n", 271 | "def get_labels(node):\n", 272 | " if node is None:\n", 273 | " return []\n", 274 | " return get_labels(node.left) + get_labels(node.right) + [node.label]\n", 275 | "\n", 276 | "def getLeaves(node):\n", 277 | " if node is None:\n", 278 | " return []\n", 279 | " if node.isLeaf:\n", 280 | " return [node]\n", 281 | " else:\n", 282 | " return getLeaves(node.left) + getLeaves(node.right)\n", 283 | "\n", 284 | " \n", 285 | "def loadTrees(dataSet='train'):\n", 286 | " \"\"\"\n", 287 | " Loads training trees. Maps leaf node words to word ids.\n", 288 | " \"\"\"\n", 289 | " file = '../dataset/trees/%s.txt' % dataSet\n", 290 | " print(\"Loading %s trees..\" % dataSet)\n", 291 | " with open(file, 'r', encoding='utf-8') as fid:\n", 292 | " trees = [Tree(l) for l in fid.readlines()]\n", 293 | "\n", 294 | " return trees" 295 | ] 296 | }, 297 | { 298 | "cell_type": "code", 299 | "execution_count": 11, 300 | "metadata": { 301 | "collapsed": false 302 | }, 303 | "outputs": [ 304 | { 305 | "name": "stdout", 306 | "output_type": "stream", 307 | "text": [ 308 | "Loading train trees..\n" 309 | ] 310 | } 311 | ], 312 | "source": [ 313 | "train_data = loadTrees('train')" 314 | ] 315 | }, 316 | { 317 | "cell_type": "markdown", 318 | "metadata": {}, 319 | "source": [ 320 | "### Build Vocab " 321 | ] 322 | }, 323 | { 324 | "cell_type": "code", 325 | "execution_count": 12, 326 | "metadata": { 327 | "collapsed": true 328 | }, 329 | "outputs": [], 330 | "source": [ 331 | "vocab = list(set(flatten([t.get_words() for t in train_data])))" 332 | ] 333 | }, 334 | { 335 | "cell_type": "code", 336 | "execution_count": 13, 337 | 
"metadata": { 338 | "collapsed": true 339 | }, 340 | "outputs": [], 341 | "source": [ 342 | "word2index = {'': 0}\n", 343 | "for vo in vocab:\n", 344 | " if word2index.get(vo) is None:\n", 345 | " word2index[vo] = len(word2index)\n", 346 | " \n", 347 | "index2word = {v:k for k, v in word2index.items()}" 348 | ] 349 | }, 350 | { 351 | "cell_type": "markdown", 352 | "metadata": {}, 353 | "source": [ 354 | "## Modeling " 355 | ] 356 | }, 357 | { 358 | "cell_type": "markdown", 359 | "metadata": {}, 360 | "source": [ 361 | "\n", 362 | "
borrowed image from https://nlp.stanford.edu/~socherr/EMNLP2013_RNTN.pdf
" 363 | ] 364 | }, 365 | { 366 | "cell_type": "code", 367 | "execution_count": 14, 368 | "metadata": { 369 | "collapsed": true 370 | }, 371 | "outputs": [], 372 | "source": [ 373 | "class RNTN(nn.Module):\n", 374 | " \n", 375 | " def __init__(self, word2index, hidden_size, output_size):\n", 376 | " super(RNTN,self).__init__()\n", 377 | " \n", 378 | " self.word2index = word2index\n", 379 | " self.embed = nn.Embedding(len(word2index), hidden_size)\n", 380 | "# self.V = nn.ModuleList([nn.Linear(hidden_size*2,hidden_size*2) for _ in range(hidden_size)])\n", 381 | "# self.W = nn.Linear(hidden_size*2,hidden_size)\n", 382 | " self.V = nn.ParameterList([nn.Parameter(torch.randn(hidden_size * 2, hidden_size * 2)) for _ in range(hidden_size)]) # Tensor\n", 383 | " self.W = nn.Parameter(torch.randn(hidden_size * 2, hidden_size))\n", 384 | " self.b = nn.Parameter(torch.randn(1, hidden_size))\n", 385 | "# self.W_out = nn.Parameter(torch.randn(hidden_size,output_size))\n", 386 | " self.W_out = nn.Linear(hidden_size, output_size)\n", 387 | " \n", 388 | " def init_weight(self):\n", 389 | " nn.init.xavier_uniform(self.embed.state_dict()['weight'])\n", 390 | " nn.init.xavier_uniform(self.W_out.state_dict()['weight'])\n", 391 | " for param in self.V.parameters():\n", 392 | " nn.init.xavier_uniform(param)\n", 393 | " nn.init.xavier_uniform(self.W)\n", 394 | " self.b.data.fill_(0)\n", 395 | "# nn.init.xavier_uniform(self.W_out)\n", 396 | " \n", 397 | " def tree_propagation(self, node):\n", 398 | " \n", 399 | " recursive_tensor = OrderedDict()\n", 400 | " current = None\n", 401 | " if node.isLeaf:\n", 402 | " tensor = Variable(LongTensor([self.word2index[node.word]])) if node.word in self.word2index.keys() \\\n", 403 | " else Variable(LongTensor([self.word2index['']]))\n", 404 | " current = self.embed(tensor) # 1xD\n", 405 | " else:\n", 406 | " recursive_tensor.update(self.tree_propagation(node.left))\n", 407 | " recursive_tensor.update(self.tree_propagation(node.right))\n", 408 | " 
\n", 409 | " concated = torch.cat([recursive_tensor[node.left], recursive_tensor[node.right]], 1) # 1x2D\n", 410 | " xVx = [] \n", 411 | " for i, v in enumerate(self.V):\n", 412 | "# xVx.append(torch.matmul(v(concated),concated.transpose(0,1)))\n", 413 | " xVx.append(torch.matmul(torch.matmul(concated, v), concated.transpose(0, 1)))\n", 414 | " \n", 415 | " xVx = torch.cat(xVx, 1) # 1xD\n", 416 | "# Wx = self.W(concated)\n", 417 | " Wx = torch.matmul(concated, self.W) # 1xD\n", 418 | "\n", 419 | " current = F.tanh(xVx + Wx + self.b) # 1xD\n", 420 | " recursive_tensor[node] = current\n", 421 | " return recursive_tensor\n", 422 | " \n", 423 | " def forward(self, Trees, root_only=False):\n", 424 | " \n", 425 | " propagated = []\n", 426 | " if not isinstance(Trees, list):\n", 427 | " Trees = [Trees]\n", 428 | " \n", 429 | " for Tree in Trees:\n", 430 | " recursive_tensor = self.tree_propagation(Tree.root)\n", 431 | " if root_only:\n", 432 | " recursive_tensor = recursive_tensor[Tree.root]\n", 433 | " propagated.append(recursive_tensor)\n", 434 | " else:\n", 435 | " recursive_tensor = [tensor for node,tensor in recursive_tensor.items()]\n", 436 | " propagated.extend(recursive_tensor)\n", 437 | " \n", 438 | " propagated = torch.cat(propagated) # (num_of_node in batch, D)\n", 439 | " \n", 440 | "# return F.log_softmax(propagated.matmul(self.W_out))\n", 441 | " return F.log_softmax(self.W_out(propagated),1)" 442 | ] 443 | }, 444 | { 445 | "cell_type": "markdown", 446 | "metadata": {}, 447 | "source": [ 448 | "## Training " 449 | ] 450 | }, 451 | { 452 | "cell_type": "markdown", 453 | "metadata": {}, 454 | "source": [ 455 | "Training takes a while. The model builds its computational graph dynamically, so it is difficult to train in batches." 
456 | ] 457 | }, 458 | { 459 | "cell_type": "code", 460 | "execution_count": 15, 461 | "metadata": { 462 | "collapsed": true 463 | }, 464 | "outputs": [], 465 | "source": [ 466 | "HIDDEN_SIZE = 30\n", 467 | "ROOT_ONLY = False\n", 468 | "BATCH_SIZE = 20\n", 469 | "EPOCH = 20\n", 470 | "LR = 0.01\n", 471 | "LAMBDA = 1e-5\n", 472 | "RESCHEDULED = False" 473 | ] 474 | }, 475 | { 476 | "cell_type": "code", 477 | "execution_count": 18, 478 | "metadata": { 479 | "collapsed": true 480 | }, 481 | "outputs": [], 482 | "source": [ 483 | "model = RNTN(word2index, HIDDEN_SIZE,5)\n", 484 | "model.init_weight()\n", 485 | "if USE_CUDA:\n", 486 | " model = model.cuda()\n", 487 | "\n", 488 | "loss_function = nn.CrossEntropyLoss()\n", 489 | "optimizer = optim.Adam(model.parameters(), lr=LR)" 490 | ] 491 | }, 492 | { 493 | "cell_type": "code", 494 | "execution_count": 19, 495 | "metadata": { 496 | "collapsed": false 497 | }, 498 | "outputs": [ 499 | { 500 | "name": "stdout", 501 | "output_type": "stream", 502 | "text": [ 503 | "[0/20] mean_loss : 1.62\n", 504 | "[0/20] mean_loss : 1.25\n", 505 | "[0/20] mean_loss : 0.95\n", 506 | "[0/20] mean_loss : 0.90\n", 507 | "[0/20] mean_loss : 0.88\n", 508 | "[1/20] mean_loss : 0.88\n", 509 | "[1/20] mean_loss : 0.84\n", 510 | "[1/20] mean_loss : 0.83\n", 511 | "[1/20] mean_loss : 0.82\n", 512 | "[1/20] mean_loss : 0.82\n", 513 | "[2/20] mean_loss : 0.81\n", 514 | "[2/20] mean_loss : 0.79\n", 515 | "[2/20] mean_loss : 0.78\n", 516 | "[2/20] mean_loss : 0.76\n", 517 | "[2/20] mean_loss : 0.75\n", 518 | "[3/20] mean_loss : 0.68\n", 519 | "[3/20] mean_loss : 0.73\n", 520 | "[3/20] mean_loss : 0.74\n", 521 | "[3/20] mean_loss : 0.72\n", 522 | "[3/20] mean_loss : 0.72\n", 523 | "[4/20] mean_loss : 0.74\n", 524 | "[4/20] mean_loss : 0.69\n", 525 | "[4/20] mean_loss : 0.69\n", 526 | "[4/20] mean_loss : 0.68\n", 527 | "[4/20] mean_loss : 0.67\n", 528 | "[5/20] mean_loss : 0.73\n", 529 | "[5/20] mean_loss : 0.65\n", 530 | "[5/20] mean_loss : 0.64\n", 
531 | "[5/20] mean_loss : 0.64\n", 532 | "[5/20] mean_loss : 0.65\n", 533 | "[6/20] mean_loss : 0.67\n", 534 | "[6/20] mean_loss : 0.62\n", 535 | "[6/20] mean_loss : 0.62\n", 536 | "[6/20] mean_loss : 0.62\n", 537 | "[6/20] mean_loss : 0.62\n", 538 | "[7/20] mean_loss : 0.57\n", 539 | "[7/20] mean_loss : 0.59\n", 540 | "[7/20] mean_loss : 0.59\n", 541 | "[7/20] mean_loss : 0.59\n", 542 | "[7/20] mean_loss : 0.59\n", 543 | "[8/20] mean_loss : 0.60\n", 544 | "[8/20] mean_loss : 0.58\n", 545 | "[8/20] mean_loss : 0.59\n", 546 | "[8/20] mean_loss : 0.60\n", 547 | "[8/20] mean_loss : 0.60\n", 548 | "[9/20] mean_loss : 0.52\n", 549 | "[9/20] mean_loss : 0.58\n", 550 | "[9/20] mean_loss : 0.60\n", 551 | "[9/20] mean_loss : 0.59\n", 552 | "[9/20] mean_loss : 0.59\n", 553 | "[10/20] mean_loss : 0.56\n", 554 | "[10/20] mean_loss : 0.56\n", 555 | "[10/20] mean_loss : 0.56\n", 556 | "[10/20] mean_loss : 0.56\n", 557 | "[10/20] mean_loss : 0.56\n", 558 | "[11/20] mean_loss : 0.52\n", 559 | "[11/20] mean_loss : 0.54\n", 560 | "[11/20] mean_loss : 0.54\n", 561 | "[11/20] mean_loss : 0.54\n", 562 | "[11/20] mean_loss : 0.55\n", 563 | "[12/20] mean_loss : 0.55\n", 564 | "[12/20] mean_loss : 0.53\n", 565 | "[12/20] mean_loss : 0.53\n", 566 | "[12/20] mean_loss : 0.53\n", 567 | "[12/20] mean_loss : 0.53\n", 568 | "[13/20] mean_loss : 0.59\n", 569 | "[13/20] mean_loss : 0.52\n", 570 | "[13/20] mean_loss : 0.52\n", 571 | "[13/20] mean_loss : 0.53\n", 572 | "[13/20] mean_loss : 0.53\n", 573 | "[14/20] mean_loss : 0.49\n", 574 | "[14/20] mean_loss : 0.51\n", 575 | "[14/20] mean_loss : 0.51\n", 576 | "[14/20] mean_loss : 0.52\n", 577 | "[14/20] mean_loss : 0.52\n", 578 | "[15/20] mean_loss : 0.43\n", 579 | "[15/20] mean_loss : 0.51\n", 580 | "[15/20] mean_loss : 0.51\n", 581 | "[15/20] mean_loss : 0.51\n", 582 | "[15/20] mean_loss : 0.51\n", 583 | "[16/20] mean_loss : 0.46\n", 584 | "[16/20] mean_loss : 0.50\n", 585 | "[16/20] mean_loss : 0.50\n", 586 | "[16/20] mean_loss : 0.50\n", 587 | 
"[16/20] mean_loss : 0.50\n", 588 | "[17/20] mean_loss : 0.50\n", 589 | "[17/20] mean_loss : 0.50\n", 590 | "[17/20] mean_loss : 0.50\n", 591 | "[17/20] mean_loss : 0.50\n", 592 | "[17/20] mean_loss : 0.51\n", 593 | "[18/20] mean_loss : 0.46\n", 594 | "[18/20] mean_loss : 0.50\n", 595 | "[18/20] mean_loss : 0.50\n", 596 | "[18/20] mean_loss : 0.49\n", 597 | "[18/20] mean_loss : 0.49\n", 598 | "[19/20] mean_loss : 0.49\n", 599 | "[19/20] mean_loss : 0.49\n", 600 | "[19/20] mean_loss : 0.49\n", 601 | "[19/20] mean_loss : 0.50\n", 602 | "[19/20] mean_loss : 0.50\n" 603 | ] 604 | } 605 | ], 606 | "source": [ 607 | "for epoch in range(EPOCH):\n", 608 | " losses = []\n", 609 | " \n", 610 | " # learning rate annealing\n", 611 | " if RESCHEDULED == False and epoch == EPOCH//2:\n", 612 | " LR *= 0.1\n", 613 | " optimizer = optim.Adam(model.parameters(), lr=LR, weight_decay=LAMBDA) # L2 norm\n", 614 | " RESCHEDULED = True\n", 615 | " \n", 616 | " for i, batch in enumerate(getBatch(BATCH_SIZE, train_data)):\n", 617 | " \n", 618 | " if ROOT_ONLY:\n", 619 | " labels = [tree.labels[-1] for tree in batch]\n", 620 | " labels = Variable(LongTensor(labels))\n", 621 | " else:\n", 622 | " labels = [tree.labels for tree in batch]\n", 623 | " labels = Variable(LongTensor(flatten(labels)))\n", 624 | " \n", 625 | " model.zero_grad()\n", 626 | " preds = model(batch, ROOT_ONLY)\n", 627 | " \n", 628 | " loss = loss_function(preds, labels)\n", 629 | " losses.append(loss.data.tolist()[0])\n", 630 | " \n", 631 | " loss.backward()\n", 632 | " optimizer.step()\n", 633 | " \n", 634 | " if i % 100 == 0:\n", 635 | " print('[%d/%d] mean_loss : %.2f' % (epoch, EPOCH, np.mean(losses)))\n", 636 | " losses = []\n", 637 | " " 638 | ] 639 | }, 640 | { 641 | "cell_type": "markdown", 642 | "metadata": {}, 643 | "source": [ 644 | "The convergence of the model is unstable according to the initial values. I tried to 5~6 times for this." 
645 | ] 646 | }, 647 | { 648 | "cell_type": "markdown", 649 | "metadata": {}, 650 | "source": [ 651 | "## Test" 652 | ] 653 | }, 654 | { 655 | "cell_type": "code", 656 | "execution_count": 20, 657 | "metadata": { 658 | "collapsed": false 659 | }, 660 | "outputs": [ 661 | { 662 | "name": "stdout", 663 | "output_type": "stream", 664 | "text": [ 665 | "Loading test trees..\n" 666 | ] 667 | } 668 | ], 669 | "source": [ 670 | "test_data = loadTrees('test')" 671 | ] 672 | }, 673 | { 674 | "cell_type": "code", 675 | "execution_count": 21, 676 | "metadata": { 677 | "collapsed": true 678 | }, 679 | "outputs": [], 680 | "source": [ 681 | "accuracy = 0\n", 682 | "num_node = 0" 683 | ] 684 | }, 685 | { 686 | "cell_type": "markdown", 687 | "metadata": {}, 688 | "source": [ 689 | "### Fine-grained all" 690 | ] 691 | }, 692 | { 693 | "cell_type": "markdown", 694 | "metadata": {}, 695 | "source": [ 696 | "In the paper, they achieved 80.2 accuracy. " 697 | ] 698 | }, 699 | { 700 | "cell_type": "code", 701 | "execution_count": 23, 702 | "metadata": { 703 | "collapsed": false 704 | }, 705 | "outputs": [ 706 | { 707 | "name": "stdout", 708 | "output_type": "stream", 709 | "text": [ 710 | "79.33705899068254\n" 711 | ] 712 | } 713 | ], 714 | "source": [ 715 | "for test in test_data:\n", 716 | " model.zero_grad()\n", 717 | " preds = model(test, ROOT_ONLY)\n", 718 | " labels = test.labels[-1:] if ROOT_ONLY else test.labels\n", 719 | " for pred, label in zip(preds.max(1)[1].data.tolist(), labels):\n", 720 | " num_node += 1\n", 721 | " if pred == label:\n", 722 | " accuracy += 1\n", 723 | "\n", 724 | "print(accuracy/num_node * 100)" 725 | ] 726 | }, 727 | { 728 | "cell_type": "markdown", 729 | "metadata": {}, 730 | "source": [ 731 | "## TODO " 732 | ] 733 | }, 734 | { 735 | "cell_type": "markdown", 736 | "metadata": {}, 737 | "source": [ 738 | "* https://github.com/nearai/pytorch-tools # Dynamic batch using TensorFold" 739 | ] 740 | }, 741 | { 742 | "cell_type": "markdown", 743 | "metadata": { 
744 | "collapsed": true 745 | }, 746 | "source": [ 747 | "## Further topics " 748 | ] 749 | }, 750 | { 751 | "cell_type": "markdown", 752 | "metadata": {}, 753 | "source": [ 754 | "* Improved Semantic Representations From Tree-Structured Long Short-Term Memory Networks\n", 755 | "* A Fast Unified Model for Parsing and Sentence Understanding(SPINN)\n", 756 | "* Posting about SPINN" 757 | ] 758 | }, 759 | { 760 | "cell_type": "code", 761 | "execution_count": null, 762 | "metadata": { 763 | "collapsed": true 764 | }, 765 | "outputs": [], 766 | "source": [] 767 | } 768 | ], 769 | "metadata": { 770 | "kernelspec": { 771 | "display_name": "Python 3", 772 | "language": "python", 773 | "name": "python3" 774 | }, 775 | "language_info": { 776 | "codemirror_mode": { 777 | "name": "ipython", 778 | "version": 3 779 | }, 780 | "file_extension": ".py", 781 | "mimetype": "text/x-python", 782 | "name": "python", 783 | "nbconvert_exporter": "python", 784 | "pygments_lexer": "ipython3", 785 | "version": "3.5.2" 786 | } 787 | }, 788 | "nbformat": 4, 789 | "nbformat_minor": 2 790 | } 791 | -------------------------------------------------------------------------------- /notebooks/10.Dynamic-Memory-Network-for-Question-Answering.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# 10. \tDynamic Memory Networks for Question Answering" 8 | ] 9 | }, 10 | { 11 | "cell_type": "markdown", 12 | "metadata": {}, 13 | "source": [ 14 | "I recommend you take a look at these materials first." 
15 | ] 16 | }, 17 | { 18 | "cell_type": "markdown", 19 | "metadata": {}, 20 | "source": [ 21 | "* http://web.stanford.edu/class/cs224n/lectures/cs224n-2017-lecture16-DMN-QA.pdf\n", 22 | "* https://arxiv.org/abs/1506.07285" 23 | ] 24 | }, 25 | { 26 | "cell_type": "code", 27 | "execution_count": 1, 28 | "metadata": { 29 | "collapsed": true 30 | }, 31 | "outputs": [], 32 | "source": [ 33 | "import torch\n", 34 | "import torch.nn as nn\n", 35 | "from torch.autograd import Variable\n", 36 | "import torch.optim as optim\n", 37 | "import torch.nn.functional as F\n", 38 | "import nltk\n", 39 | "import random\n", 40 | "import numpy as np\n", 41 | "from collections import Counter, OrderedDict\n", 42 | "import nltk\n", 43 | "from copy import deepcopy\n", 44 | "import os\n", 45 | "import re\n", 46 | "import unicodedata\n", 47 | "flatten = lambda l: [item for sublist in l for item in sublist]\n", 48 | "\n", 49 | "from torch.nn.utils.rnn import PackedSequence, pack_padded_sequence\n", 50 | "random.seed(1024)" 51 | ] 52 | }, 53 | { 54 | "cell_type": "code", 55 | "execution_count": 2, 56 | "metadata": { 57 | "collapsed": true 58 | }, 59 | "outputs": [], 60 | "source": [ 61 | "USE_CUDA = torch.cuda.is_available()\n", 62 | "gpus = [0]\n", 63 | "torch.cuda.set_device(gpus[0])\n", 64 | "\n", 65 | "FloatTensor = torch.cuda.FloatTensor if USE_CUDA else torch.FloatTensor\n", 66 | "LongTensor = torch.cuda.LongTensor if USE_CUDA else torch.LongTensor\n", 67 | "ByteTensor = torch.cuda.ByteTensor if USE_CUDA else torch.ByteTensor" 68 | ] 69 | }, 70 | { 71 | "cell_type": "code", 72 | "execution_count": 3, 73 | "metadata": { 74 | "collapsed": true 75 | }, 76 | "outputs": [], 77 | "source": [ 78 | "def getBatch(batch_size, train_data):\n", 79 | " random.shuffle(train_data)\n", 80 | " sindex = 0\n", 81 | " eindex = batch_size\n", 82 | " while eindex < len(train_data):\n", 83 | " batch = train_data[sindex: eindex]\n", 84 | " temp = eindex\n", 85 | " eindex = eindex + batch_size\n", 86 | " sindex 
= temp\n", 87 | " yield batch\n", 88 | " \n", 89 | " if eindex >= len(train_data):\n", 90 | " batch = train_data[sindex:]\n", 91 | " yield batch" 92 | ] 93 | }, 94 | { 95 | "cell_type": "code", 96 | "execution_count": 4, 97 | "metadata": { 98 | "collapsed": true 99 | }, 100 | "outputs": [], 101 | "source": [ 102 | "def pad_to_batch(batch, w_to_ix): # for bAbI dataset\n", 103 | " fact,q,a = list(zip(*batch))\n", 104 | " max_fact = max([len(f) for f in fact])\n", 105 | " max_len = max([f.size(1) for f in flatten(fact)])\n", 106 | " max_q = max([qq.size(1) for qq in q])\n", 107 | " max_a = max([aa.size(1) for aa in a])\n", 108 | " \n", 109 | " facts, fact_masks, q_p, a_p = [], [], [], []\n", 110 | " for i in range(len(batch)):\n", 111 | " fact_p_t = []\n", 112 | " for j in range(len(fact[i])):\n", 113 | " if fact[i][j].size(1) < max_len:\n", 114 | " fact_p_t.append(torch.cat([fact[i][j], Variable(LongTensor([w_to_ix['']] * (max_len - fact[i][j].size(1)))).view(1, -1)], 1))\n", 115 | " else:\n", 116 | " fact_p_t.append(fact[i][j])\n", 117 | "\n", 118 | " while len(fact_p_t) < max_fact:\n", 119 | " fact_p_t.append(Variable(LongTensor([w_to_ix['']] * max_len)).view(1, -1))\n", 120 | "\n", 121 | " fact_p_t = torch.cat(fact_p_t)\n", 122 | " facts.append(fact_p_t)\n", 123 | " fact_masks.append(torch.cat([Variable(ByteTensor(tuple(map(lambda s: s ==0, t.data))), volatile=False) for t in fact_p_t]).view(fact_p_t.size(0), -1))\n", 124 | "\n", 125 | " if q[i].size(1) < max_q:\n", 126 | " q_p.append(torch.cat([q[i], Variable(LongTensor([w_to_ix['']] * (max_q - q[i].size(1)))).view(1, -1)], 1))\n", 127 | " else:\n", 128 | " q_p.append(q[i])\n", 129 | "\n", 130 | " if a[i].size(1) < max_a:\n", 131 | " a_p.append(torch.cat([a[i], Variable(LongTensor([w_to_ix['']] * (max_a - a[i].size(1)))).view(1, -1)], 1))\n", 132 | " else:\n", 133 | " a_p.append(a[i])\n", 134 | "\n", 135 | " questions = torch.cat(q_p)\n", 136 | " answers = torch.cat(a_p)\n", 137 | " question_masks = 
torch.cat([Variable(ByteTensor(tuple(map(lambda s: s ==0, t.data))), volatile=False) for t in questions]).view(questions.size(0), -1)\n", 138 | " \n", 139 | " return facts, fact_masks, questions, question_masks, answers" 140 | ] 141 | }, 142 | { 143 | "cell_type": "code", 144 | "execution_count": 5, 145 | "metadata": { 146 | "collapsed": true 147 | }, 148 | "outputs": [], 149 | "source": [ 150 | "def prepare_sequence(seq, to_index):\n", 151 | " idxs = list(map(lambda w: to_index[w] if to_index.get(w) is not None else to_index[\"\"], seq))\n", 152 | " return Variable(LongTensor(idxs))" 153 | ] 154 | }, 155 | { 156 | "cell_type": "markdown", 157 | "metadata": {}, 158 | "source": [ 159 | "## Data load and Preprocessing " 160 | ] 161 | }, 162 | { 163 | "cell_type": "markdown", 164 | "metadata": {}, 165 | "source": [ 166 | "### bAbI dataset(https://research.fb.com/downloads/babi/)" 167 | ] 168 | }, 169 | { 170 | "cell_type": "code", 171 | "execution_count": 24, 172 | "metadata": { 173 | "collapsed": true 174 | }, 175 | "outputs": [], 176 | "source": [ 177 | "def bAbI_data_load(path):\n", 178 | " try:\n", 179 | " data = open(path).readlines()\n", 180 | " except:\n", 181 | " print(\"Such a file does not exist at %s\".format(path))\n", 182 | " return None\n", 183 | " \n", 184 | " data = [d[:-1] for d in data]\n", 185 | " data_p = []\n", 186 | " fact = []\n", 187 | " qa = []\n", 188 | " try:\n", 189 | " for d in data:\n", 190 | " index = d.split(' ')[0]\n", 191 | " if index == '1':\n", 192 | " fact = []\n", 193 | " qa = []\n", 194 | " if '?' 
in d:\n", 195 | " temp = d.split('\\t')\n", 196 | " q = temp[0].strip().replace('?', '').split(' ')[1:] + ['?']\n", 197 | " a = temp[1].split() + ['']\n", 198 | " stemp = deepcopy(fact)\n", 199 | " data_p.append([stemp, q, a])\n", 200 | " else:\n", 201 | " tokens = d.replace('.', '').split(' ')[1:] + ['']\n", 202 | " fact.append(tokens)\n", 203 | " except:\n", 204 | " print(\"Please check the data is right\")\n", 205 | " return None\n", 206 | " return data_p" 207 | ] 208 | }, 209 | { 210 | "cell_type": "code", 211 | "execution_count": 25, 212 | "metadata": { 213 | "collapsed": true 214 | }, 215 | "outputs": [], 216 | "source": [ 217 | "train_data = bAbI_data_load('../dataset/bAbI/en-10k/qa5_three-arg-relations_train.txt')" 218 | ] 219 | }, 220 | { 221 | "cell_type": "code", 222 | "execution_count": 26, 223 | "metadata": { 224 | "collapsed": false 225 | }, 226 | "outputs": [ 227 | { 228 | "data": { 229 | "text/plain": [ 230 | "[[['Bill', 'travelled', 'to', 'the', 'office', ''],\n", 231 | " ['Bill', 'picked', 'up', 'the', 'football', 'there', ''],\n", 232 | " ['Bill', 'went', 'to', 'the', 'bedroom', ''],\n", 233 | " ['Bill', 'gave', 'the', 'football', 'to', 'Fred', '']],\n", 234 | " ['What', 'did', 'Bill', 'give', 'to', 'Fred', '?'],\n", 235 | " ['football', '']]" 236 | ] 237 | }, 238 | "execution_count": 26, 239 | "metadata": {}, 240 | "output_type": "execute_result" 241 | } 242 | ], 243 | "source": [ 244 | "train_data[0]" 245 | ] 246 | }, 247 | { 248 | "cell_type": "code", 249 | "execution_count": 11, 250 | "metadata": { 251 | "collapsed": true 252 | }, 253 | "outputs": [], 254 | "source": [ 255 | "fact,q,a = list(zip(*train_data))" 256 | ] 257 | }, 258 | { 259 | "cell_type": "code", 260 | "execution_count": 12, 261 | "metadata": { 262 | "collapsed": true 263 | }, 264 | "outputs": [], 265 | "source": [ 266 | "vocab = list(set(flatten(flatten(fact)) + flatten(q) + flatten(a)))" 267 | ] 268 | }, 269 | { 270 | "cell_type": "code", 271 | "execution_count": 13, 272 | 
"metadata": { 273 | "collapsed": true 274 | }, 275 | "outputs": [], 276 | "source": [ 277 | "word2index={'': 0, '': 1, '': 2, '': 3}\n", 278 | "for vo in vocab:\n", 279 | " if word2index.get(vo) is None:\n", 280 | " word2index[vo] = len(word2index)\n", 281 | "index2word = {v:k for k, v in word2index.items()}" 282 | ] 283 | }, 284 | { 285 | "cell_type": "code", 286 | "execution_count": 14, 287 | "metadata": { 288 | "collapsed": false 289 | }, 290 | "outputs": [ 291 | { 292 | "data": { 293 | "text/plain": [ 294 | "44" 295 | ] 296 | }, 297 | "execution_count": 14, 298 | "metadata": {}, 299 | "output_type": "execute_result" 300 | } 301 | ], 302 | "source": [ 303 | "len(word2index)" 304 | ] 305 | }, 306 | { 307 | "cell_type": "code", 308 | "execution_count": 15, 309 | "metadata": { 310 | "collapsed": true 311 | }, 312 | "outputs": [], 313 | "source": [ 314 | "for t in train_data:\n", 315 | " for i,fact in enumerate(t[0]):\n", 316 | " t[0][i] = prepare_sequence(fact, word2index).view(1, -1)\n", 317 | " \n", 318 | " t[1] = prepare_sequence(t[1], word2index).view(1, -1)\n", 319 | " t[2] = prepare_sequence(t[2], word2index).view(1, -1)" 320 | ] 321 | }, 322 | { 323 | "cell_type": "markdown", 324 | "metadata": {}, 325 | "source": [ 326 | "## Modeling " 327 | ] 328 | }, 329 | { 330 | "cell_type": "markdown", 331 | "metadata": {}, 332 | "source": [ 333 | "\n", 334 | "
# Image credit (caption from the preceding markdown cell):
# borrowed image from https://arxiv.org/pdf/1506.07285.pdf


class DMN(nn.Module):
    """Dynamic Memory Network for question answering.

    ``forward`` runs four stages in order: an input module (GRU over every
    fact), a question module (GRU over the question), an episodic memory
    module (gated attention GRU cell over the encoded facts, repeated for
    several episodes) and an answer module (GRU-cell decoder).  A single
    embedding table is shared by facts, questions and decoded words.
    """

    def __init__(self, input_size, hidden_size, output_size, dropout_p=0.1):
        """Create the layers.

        Args:
            input_size: number of rows of the embedding table.
            hidden_size: size of the embeddings and of every recurrent state.
            output_size: number of classes scored by the answer layer.
            dropout_p: dropout probability used on embeddings when
                ``forward`` is called with ``is_training=True``.
        """
        super(DMN, self).__init__()

        self.hidden_size = hidden_size
        # Index 0 is the padding row (the `sparse=True` option was left disabled).
        self.embed = nn.Embedding(input_size, hidden_size, padding_idx=0)
        self.input_gru = nn.GRU(hidden_size, hidden_size, batch_first=True)
        self.question_gru = nn.GRU(hidden_size, hidden_size, batch_first=True)

        # Attention gate: 4 * hidden interaction features -> one value in (0, 1).
        self.gate = nn.Sequential(
            nn.Linear(hidden_size * 4, hidden_size),
            nn.Tanh(),
            nn.Linear(hidden_size, 1),
            nn.Sigmoid()
        )

        self.attention_grucell = nn.GRUCell(hidden_size, hidden_size)
        self.memory_grucell = nn.GRUCell(hidden_size, hidden_size)
        # The decoder input is [previous word embedding ; encoded question].
        self.answer_grucell = nn.GRUCell(hidden_size * 2, hidden_size)
        self.answer_fc = nn.Linear(hidden_size, output_size)

        self.dropout = nn.Dropout(dropout_p)

    def init_hidden(self, inputs):
        """Return a zero state of shape (1, inputs.size(0), hidden_size), on GPU if USE_CUDA."""
        hidden = Variable(torch.zeros(1, inputs.size(0), self.hidden_size))
        return hidden.cuda() if USE_CUDA else hidden

    def init_weight(self):
        """Xavier-initialise every weight tensor and zero the answer-layer bias."""
        nn.init.xavier_uniform(self.embed.state_dict()['weight'])

        # Same modules, in the same order, as the original six hand-written loops.
        recurrent_and_gate = (
            self.input_gru,
            self.question_gru,
            self.gate,
            self.attention_grucell,
            self.memory_grucell,
            self.answer_grucell,
        )
        for module in recurrent_and_gate:
            for name, param in module.state_dict().items():
                if 'weight' in name:
                    nn.init.xavier_normal(param)

        nn.init.xavier_normal(self.answer_fc.state_dict()['weight'])
        self.answer_fc.bias.data.fill_(0)

    @staticmethod
    def _last_real_state(outputs, masks):
        """Select, for every row, the GRU output at its last unmasked step.

        Args:
            outputs: (N, T, D) GRU outputs (batch first).
            masks: (N, T) mask whose zero entries mark real (non-padded) steps.

        Returns:
            (N, D) tensor.  A row with no zero in its mask yields index -1,
            i.e. the output at the final step.
        """
        picked = []
        for row, states in enumerate(outputs):  # states: (T, D)
            real_length = masks[row].data.tolist().count(0)
            picked.append(states[real_length - 1])
        return torch.cat(picked).view(outputs.size(0), -1)

    def forward(self, facts, fact_masks, questions, question_masks, num_decode, episodes=3, is_training=False):
        """Answer a batch of questions about a batch of fact lists.

        Args:
            facts: list of B LongTensors, each (T_C, T_I): the facts of one
                sample, padded to a common length.
            fact_masks: list of B ByteTensors, each (T_C, T_I), matching ``facts``.
            questions: (B, T_Q) LongTensor.
            question_masks: (B, T_Q) ByteTensor; when it is not a ``Variable``
                the final question GRU state is used instead of the mask.
            num_decode: number of answer tokens to decode.
            episodes: number of passes of the episodic memory module.
            is_training: apply dropout to the fact and question embeddings.

        Returns:
            (B * num_decode, output_size) log-probabilities, ordered sample by
            sample and, within a sample, step by step.
        """
        # Input Module
        C = []  # encoded facts, one (1, T_C, D) entry per sample
        for fact, fact_mask in zip(facts, fact_masks):
            embeds = self.embed(fact)
            if is_training:
                embeds = self.dropout(embeds)
            hidden = self.init_hidden(fact)
            outputs, hidden = self.input_gru(embeds, hidden)
            C.append(self._last_real_state(outputs, fact_mask).unsqueeze(0))

        encoded_facts = torch.cat(C)  # B,T_C,D

        # Question Module
        embeds = self.embed(questions)
        if is_training:
            embeds = self.dropout(embeds)
        hidden = self.init_hidden(questions)
        outputs, hidden = self.question_gru(embeds, hidden)

        if isinstance(question_masks, torch.autograd.variable.Variable):
            encoded_question = self._last_real_state(outputs, question_masks)  # B,D
        else:  # for inference mode
            encoded_question = hidden.squeeze(0)  # B,D

        # Episodic Memory Module
        memory = encoded_question
        T_C = encoded_facts.size(1)
        B = encoded_facts.size(0)
        facts_by_step = encoded_facts.transpose(0, 1)  # T_C,B,D (computed once)
        for _ in range(episodes):
            hidden = self.init_hidden(encoded_facts).squeeze(0)  # B,D
            for t in range(T_C):
                #TODO: fact masking
                #TODO: gate function => softmax
                fact_t = facts_by_step[t]  # B,D
                z = torch.cat([
                    fact_t * encoded_question,             # B,D , element-wise product
                    fact_t * memory,                       # B,D , element-wise product
                    torch.abs(fact_t - encoded_question),  # B,D
                    torch.abs(fact_t - memory)             # B,D
                ], 1)
                g_t = self.gate(z)  # B,1 scalar
                # The gate interpolates between the updated and the previous state.
                hidden = g_t * self.attention_grucell(fact_t, hidden) + (1 - g_t) * hidden

            e = hidden
            memory = self.memory_grucell(e, memory)

        # Answer Module
        answer_hidden = memory
        # NOTE(review): the vocabulary key below is an empty string; it looks
        # like an angle-bracket start token may have been lost when this file
        # was extracted - confirm against the vocabulary construction.
        start_decode = Variable(LongTensor([[word2index['']] * memory.size(0)])).transpose(0, 1)
        y_t_1 = self.embed(start_decode).squeeze(1)  # B,D

        decodes = []
        for t in range(num_decode):
            answer_hidden = self.answer_grucell(torch.cat([y_t_1, encoded_question], 1), answer_hidden)
            scores = F.log_softmax(self.answer_fc(answer_hidden), 1)  # B,output_size
            decodes.append(scores)
            if t + 1 < num_decode:
                # Feed the greedily chosen word back in as the next input.
                # Previously y_t_1 was never updated, so every step saw the
                # start token; with num_decode == 1 nothing changes.
                y_t_1 = self.embed(scores.max(1)[1].view(-1, 1)).squeeze(1)  # B,D
        return torch.cat(decodes, 1).view(B * num_decode, -1)


# ## Train
# It takes a while if you use just the CPU.

HIDDEN_SIZE = 80
BATCH_SIZE = 64
LR = 0.001
EPOCH = 50
NUM_EPISODE = 3
EARLY_STOPPING = False

model = DMN(len(word2index), HIDDEN_SIZE, len(word2index))
model.init_weight()
if USE_CUDA:
    model = model.cuda()

# The model already returns log-probabilities; CrossEntropyLoss applies
# log-softmax once more, which leaves normalised log-probabilities unchanged.
# Targets equal to 0 (padding) are ignored.
loss_function = nn.CrossEntropyLoss(ignore_index=0)
optimizer = optim.Adam(model.parameters(), lr=LR)

# Recorded output of the training loop that follows (first part):
# [0/50] mean_loss : 3.86
# [0/50] mean_loss : 1.32
# [1/50] mean_loss : 0.68
# [1/50] mean_loss : 0.65
# [2/50] mean_loss : 0.62
# [2/50] mean_loss : 0.65
# [3/50] mean_loss : 0.65
# [3/50] mean_loss : 0.64
# [4/50] mean_loss : 0.60
# [4/50] mean_loss : 0.62
# [5/50] mean_loss : 0.63
# [5/50] mean_loss : 0.61
# [6/50] mean_loss : 0.60
# [6/50] mean_loss : 0.61
# [7/50] mean_loss : 0.63
# [7/50] mean_loss : 0.60
# [8/50] mean_loss : 0.62
# [8/50] mean_loss : 0.60
# (log continues)
# Recorded output of the training loop below (continued):
# [9/50] mean_loss : 0.58
# [9/50] mean_loss : 0.60
# [10/50] mean_loss : 0.60
# [10/50] mean_loss : 0.60
# [11/50] mean_loss : 0.62
# [11/50] mean_loss : 0.60
# [12/50] mean_loss : 0.61
# [12/50] mean_loss : 0.60
# [13/50] mean_loss : 0.57
# [13/50] mean_loss : 0.60
# [14/50] mean_loss : 0.59
# [14/50] mean_loss : 0.60
# [15/50] mean_loss : 0.61
# [15/50] mean_loss : 0.60
# [16/50] mean_loss : 0.59
# [16/50] mean_loss : 0.60
# [17/50] mean_loss : 0.59
# [17/50] mean_loss : 0.60
# [18/50] mean_loss : 0.51
# [18/50] mean_loss : 0.50
# [19/50] mean_loss : 0.44
# [19/50] mean_loss : 0.37
# [20/50] mean_loss : 0.30
# [20/50] mean_loss : 0.33
# [21/50] mean_loss : 0.31
# [21/50] mean_loss : 0.31
# [22/50] mean_loss : 0.29
# [22/50] mean_loss : 0.31
# [23/50] mean_loss : 0.29
# [23/50] mean_loss : 0.31
# [24/50] mean_loss : 0.24
# [24/50] mean_loss : 0.31
# [25/50] mean_loss : 0.30
# [25/50] mean_loss : 0.30
# [26/50] mean_loss : 0.14
# [26/50] mean_loss : 0.16
# [27/50] mean_loss : 0.12
# [27/50] mean_loss : 0.15
# [28/50] mean_loss : 0.18
# [28/50] mean_loss : 0.14
# [29/50] mean_loss : 0.12
# [29/50] mean_loss : 0.14
# [30/50] mean_loss : 0.14
# [30/50] mean_loss : 0.14
# [31/50] mean_loss : 0.13
# [31/50] mean_loss : 0.14
# [32/50] mean_loss : 0.11
# [32/50] mean_loss : 0.13
# [33/50] mean_loss : 0.08
# [33/50] mean_loss : 0.06
# [34/50] mean_loss : 0.01
# [34/50] mean_loss : 0.03
# [35/50] mean_loss : 0.01
# Early Stopping!

# Training loop.  Every 100 batches the mean of the losses collected since the
# previous report is printed; training stops once that mean is below 0.01.
# NOTE(review): the nesting below was reconstructed from the statement order
# because the extracted source lost its indentation - confirm that the
# early-stopping check belongs inside the reporting branch.
for epoch in range(EPOCH):
    losses = []
    if EARLY_STOPPING:
        break

    for i, batch in enumerate(getBatch(BATCH_SIZE, train_data)):
        facts, fact_masks, questions, question_masks, answers = pad_to_batch(batch, word2index)

        model.zero_grad()
        pred = model(facts, fact_masks, questions, question_masks, answers.size(1), NUM_EPISODE, True)
        loss = loss_function(pred, answers.view(-1))
        # Works whether the loss is a 1-element tensor or a 0-dim tensor
        # (the former `loss.data.tolist()[0]` required the 1-element form).
        losses.append(float(loss.data.cpu().numpy().reshape(-1)[0]))

        loss.backward()
        optimizer.step()

        if i % 100 == 0:
            mean_loss = np.mean(losses)
            print("[%d/%d] mean_loss : %0.2f" %(epoch, EPOCH, mean_loss))

            if mean_loss < 0.01:
                EARLY_STOPPING = True
                print("Early Stopping!")
                break
            losses = []


# ## Test

def pad_to_fact(fact, x_to_ix):  # this is for inference
    """Pad the facts of one sample to a common length and build their mask.

    Args:
        fact: list of (1, T_i) index tensors, one per fact.
        x_to_ix: vocabulary mapping used to look up the padding index.

    Returns:
        (fact, fact_mask): the facts stacked into one (num_facts, max_len)
        tensor, and a ByteTensor of the same shape that is 1 wherever the
        token index equals 0.
    """
    max_x = max([s.size(1) for s in fact])
    x_p = []
    for single in fact:
        missing = max_x - single.size(1)
        if missing > 0:
            # NOTE(review): the key is an empty string; it looks like an
            # angle-bracket padding token may have been lost when this file
            # was extracted.  The mask below tests for index 0, so the
            # padding index is presumably 0 - confirm.
            padding = Variable(LongTensor([x_to_ix['']] * missing)).view(1, -1)
            x_p.append(torch.cat([single, padding], 1))
        else:
            x_p.append(single)

    fact = torch.cat(x_p)
    fact_mask = torch.cat([Variable(ByteTensor(tuple(map(lambda s: s == 0, t.data))), volatile=False) for t in fact]).view(fact.size(0), -1)
    return fact, fact_mask


def _inference_inputs(sample):
    """Turn one prepared test sample into the tensors the model needs at inference."""
    fact, fact_mask = pad_to_fact(sample[0], word2index)
    question = sample[1]
    # A single unpadded question: every position is real, so the mask is all zeros.
    question_mask = Variable(ByteTensor([0] * sample[1].size(1)), volatile=False).unsqueeze(0)
    answer = sample[2].squeeze(0)
    return fact, fact_mask, question, question_mask, answer


def _to_words(indices):
    """Join the words for a list of vocabulary indices with single spaces."""
    return ' '.join(index2word[x] for x in indices)


# ### Prepare Test data

test_data = bAbI_data_load('../dataset/bAbI/en-10k/qa5_three-arg-relations_test.txt')

# Replace the token lists of every sample (facts, question, answer) in place
# with (1, T) index tensors.
for t in test_data:
    for i, fact in enumerate(t[0]):
        t[0][i] = prepare_sequence(fact, word2index).view(1, -1)

    t[1] = prepare_sequence(t[1], word2index).view(1, -1)
    t[2] = prepare_sequence(t[2], word2index).view(1, -1)


# ### Accuracy

accuracy = 0

# A sample counts as correct only when every decoded token matches the answer.
for t in test_data:
    fact, fact_mask, question, question_mask, answer = _inference_inputs(t)

    model.zero_grad()
    pred = model([fact], [fact_mask], question, question_mask, answer.size(0), NUM_EPISODE)
    if pred.max(1)[1].data.tolist() == answer.data.tolist():
        accuracy += 1

print(accuracy/len(test_data) * 100)
# Recorded output:
# 97.39999999999999


# ### Sample test result

t = random.choice(test_data)
fact, fact_mask, question, question_mask, answer = _inference_inputs(t)

model.zero_grad()
pred = model([fact], [fact_mask], question, question_mask, answer.size(0), NUM_EPISODE)

print("Facts : ")
print('\n'.join([_to_words(f) for f in fact.data.tolist()]))
print("")
print("Question : ", _to_words(question.data.tolist()[0]))
print("")
print("Answer : ", _to_words(answer.data.tolist()))
print("Prediction : ", _to_words(pred.max(1)[1].data.tolist()))
# Recorded output of one run:
# Facts :
# Bill went back to the bedroom
# Mary went to the office
# Jeff journeyed to the kitchen
# Fred journeyed to the kitchen
# Fred got the milk there
# Fred handed the milk to Jeff
# Jeff passed the milk to Fred
# Fred gave the milk to Jeff
#
# Question : Who received the milk ?
#
# Answer : Jeff
# Prediction : Jeff


# ## Further topics
# * Dynamic Memory Networks for Visual and Textual Question Answering(DMN+)
# * DMN+ Pytorch implementation
# * Dynamic Coattention
Networks For Question Answering\n", 807 | "* DCN+: Mixed Objective and Deep Residual Coattention for Question Answering" 808 | ] 809 | }, 810 | { 811 | "cell_type": "code", 812 | "execution_count": null, 813 | "metadata": { 814 | "collapsed": true 815 | }, 816 | "outputs": [], 817 | "source": [] 818 | } 819 | ], 820 | "metadata": { 821 | "kernelspec": { 822 | "display_name": "Python 3", 823 | "language": "python", 824 | "name": "python3" 825 | }, 826 | "language_info": { 827 | "codemirror_mode": { 828 | "name": "ipython", 829 | "version": 3 830 | }, 831 | "file_extension": ".py", 832 | "mimetype": "text/x-python", 833 | "name": "python", 834 | "nbconvert_exporter": "python", 835 | "pygments_lexer": "ipython3", 836 | "version": "3.5.2" 837 | } 838 | }, 839 | "nbformat": 4, 840 | "nbformat_minor": 2 841 | } 842 | -------------------------------------------------------------------------------- /script/docker-compose.yml: -------------------------------------------------------------------------------- 1 | cs224n: 2 | image: dsksd/deepstudy:0.2 3 | ports: 4 | - "8888:8888" 5 | - "6006:6006" 6 | volumes: 7 | - ../:/notebooks/work 8 | command: /run.sh 9 | tty: true -------------------------------------------------------------------------------- /script/prepare_dataset.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | OUT_DIR="${1:-../dataset}" 4 | 5 | mkdir -v -p $OUT_DIR 6 | 7 | echo "download machine translation dataset from http://www.manythings.org/anki/..." 8 | curl -o "$OUT_DIR/fra-eng.zip" "http://www.manythings.org/anki/fra-eng.zip" 9 | unzip "$OUT_DIR/fra-eng.zip" -d "$OUT_DIR" 10 | rm "$OUT_DIR/fra-eng.zip" 11 | rm "$OUT_DIR/_about.txt" 12 | mv "$OUT_DIR/fra.txt" "$OUT_DIR/eng-fra.txt" 13 | 14 | echo "download ptb dataset... 
(clone from https://github.com/tomsercu/lstm/tree/master/data" 15 | mkdir -v -p "$OUT_DIR/ptb" 16 | wget "https://raw.githubusercontent.com/tomsercu/lstm/master/data/ptb.train.txt" -P "$OUT_DIR/ptb" 17 | wget "https://raw.githubusercontent.com/tomsercu/lstm/master/data/ptb.valid.txt" -P "$OUT_DIR/ptb" 18 | wget "https://raw.githubusercontent.com/tomsercu/lstm/master/data/ptb.test.txt" -P "$OUT_DIR/ptb" 19 | 20 | echo "download TREC question dataset..." 21 | curl -o "$OUT_DIR/train_5500.label.txt" "http://cogcomp.org/Data/QA/QC/train_5500.label" 22 | 23 | echo "download Stanford sentment treebank..." 24 | curl -o "$OUT_DIR/trainDevTestTrees_PTB.zip" "https://nlp.stanford.edu/sentiment/trainDevTestTrees_PTB.zip" 25 | unzip "$OUT_DIR/trainDevTestTrees_PTB.zip" -d "$OUT_DIR" 26 | rm "$OUT_DIR/trainDevTestTrees_PTB.zip" 27 | 28 | echo "download bAbI dataset..." 29 | curl -o "$OUT_DIR/tasks_1-20_v1-2.tar.gz" "http://www.thespermwhale.com/jaseweston/babi/tasks_1-20_v1-2.tar.gz" 30 | tar zxvf "$OUT_DIR/tasks_1-20_v1-2.tar.gz" 31 | mv "tasks_1-20_v1-2" "$OUT_DIR/bAbI" 32 | rm "$OUT_DIR/tasks_1-20_v1-2.tar.gz" 33 | 34 | echo "download nltk dataset..." 35 | python3 -c "import nltk;nltk.download('gutenberg');nltk.download('brown');nltk.download('conll2002');nltk.download('timit')" 36 | 37 | echo "download dependency parser dataset... 
(clone from https://github.com/rguthrie3/DeepDependencyParsingProblemSet" 38 | mkdir -v -p "$OUT_DIR/dparser" 39 | wget "https://raw.githubusercontent.com/rguthrie3/DeepDependencyParsingProblemSet/master/data/train.txt" -P "$OUT_DIR/dparser" 40 | wget "https://raw.githubusercontent.com/rguthrie3/DeepDependencyParsingProblemSet/master/data/vocab.txt" -P "$OUT_DIR/dparser" 41 | wget "https://raw.githubusercontent.com/rguthrie3/DeepDependencyParsingProblemSet/master/data/dev.txt" -P "$OUT_DIR/dparser" -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from distutils.core import setup 2 | from Cython.Build import cythonize 3 | from Cython.Distutils import build_ext 4 | from distutils.extension import Extension 5 | 6 | setup(name="Corpus",cmdclass = {'build_ext': build_ext},ext_modules = [Extension("Corpus", ["Corpus.pyx"],language='c++')]) --------------------------------------------------------------------------------