├── .gitignore
├── Corpus.pyx
├── LICENSE
├── README.md
├── images
│   ├── 01.skipgram-objective.png
│   ├── 01.skipgram-prepare-data.png
│   ├── 02.skipgram-objective.png
│   ├── 03.glove-objective.png
│   ├── 03.glove-weighting-function.png
│   ├── 04.window-classifier-architecture.png
│   ├── 04.window-data.png
│   ├── 05.neural-dparser-architecture.png
│   ├── 05.transition-based-parse.png
│   ├── 06.rnnlm-architecture.png
│   ├── 07.attention-mechanism.png
│   ├── 07.pad_to_sequence.png
│   ├── 07.seq2seq.png
│   ├── 08.cnn-for-text-architecture.png
│   ├── 09.rntn-layer.png
│   └── 10.dmn-architecture.png
├── notebooks
│   ├── 01.Skip-gram-Naive-Softmax.ipynb
│   ├── 02.Skip-gram-Negative-Sampling.ipynb
│   ├── 03.GloVe.ipynb
│   ├── 04.Window-Classifier-for-NER.ipynb
│   ├── 05.Neural-Dependancy-Parser.ipynb
│   ├── 06.RNN-Language-Model.ipynb
│   ├── 07.Neural-Machine-Translation-with-Attention.ipynb
│   ├── 08.CNN-for-Text-Classification.ipynb
│   ├── 09.Recursive-NN-for-Sentiment-Classification.ipynb
│   └── 10.Dynamic-Memory-Network-for-Question-Answering.ipynb
├── script
│   ├── docker-compose.yml
│   └── prepare_dataset.sh
└── setup.py
/.gitignore:
--------------------------------------------------------------------------------
1 | *.txt
2 | __pycache__
3 | .ipynb_checkpoints
4 | build/
5 | dataset/
6 | model/
7 | *.cpp
8 | *.so
9 | *.bin
10 | *.pkl
11 |
--------------------------------------------------------------------------------
/Corpus.pyx:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import scipy.sparse as sp
3 | from libcpp.vector cimport vector
4 | from libcpp.string cimport string
5 | from libcpp.map cimport map
6 |
7 | # Text preprocessing C++ extension module
8 |
9 | def make_dictionary(vocab):
10 | cdef int vocab_size = len(vocab)
11 | cdef int i
12 | dictionary={}
13 | inv_dictionary={}
14 | for i in range(vocab_size):
15 | dictionary[vocab[i]] = i
16 | inv_dictionary[i]=vocab[i]
17 | return dictionary,inv_dictionary
18 |
19 |
20 | def make_window_data(sents,window_size):
21 | pass
22 |
23 |
24 | def make_coo_matrix(corpus,dictionary):
25 | cdef int matrix_size = len(dictionary)
26 | pass
27 |
28 |
29 | def getBatch_FromBucket(batch_size,buckets):
30 | i=0
31 | bucket_mask =[False for _ in range(len(buckets))]
32 | indices = [[0,batch_size] for _ in range(len(buckets))]
33 | is_done=False
34 | while is_done==False:
35 | batch = buckets[i][indices[i][0]:indices[i][1]]
36 | temp = indices[i][1]
37 | indices[i][1]= indices[i][1]+batch_size
38 | indices[i][0] = temp
39 |
40 | i = (i+1)%len(buckets)
41 | while bucket_mask[i]:
42 | i = (i+1)%len(buckets)
43 |
44 | if indices[i][1]>len(buckets[i]):
45 | bucket_mask[i]= True
46 | if bucket_mask.count(True)==len(buckets):
47 | is_done=True
48 | else:
49 | while bucket_mask[i]:
50 | i = (i+1)%len(buckets)
51 | yield batch
--------------------------------------------------------------------------------
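The generator `getBatch_FromBucket` in `Corpus.pyx` cycles over the buckets and slices one batch at a time from each. The idea can be sketched in plain Python as below. This is a simplified illustration, not the author's exact logic: the name `batches_from_buckets` is hypothetical, and unlike the listing above it also yields the final partial batch of each bucket.

```python
def batches_from_buckets(batch_size, buckets):
    """Yield batches by visiting the buckets in round-robin order.

    Each bucket is a list of examples (for instance, sequences of similar
    length). A bucket is skipped once all of its examples have been consumed.
    """
    offsets = [0] * len(buckets)  # next unread position in each bucket
    remaining = True
    while remaining:
        remaining = False
        for b, bucket in enumerate(buckets):
            start = offsets[b]
            if start >= len(bucket):
                continue  # this bucket is exhausted
            yield bucket[start:start + batch_size]
            offsets[b] = start + batch_size
            remaining = True


if __name__ == "__main__":
    # Two toy buckets of different sizes, batch size 2.
    for batch in batches_from_buckets(2, [[1, 2, 3], [4, 5, 6, 7, 8]]):
        print(batch)
```

Keeping one offset per bucket avoids the separate mask and index-pair lists, and every batch still draws from a single bucket, which is the point of bucketing.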
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2017 SungDong Kim
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # DeepNLP-models-Pytorch
2 |
3 | Pytorch implementations of various Deep NLP models from cs-224n (Stanford Univ: NLP with Deep Learning)
4 |
5 | - This is not for Pytorch beginners. If this is your first time using Pytorch, I recommend these [awesome tutorials](#references).
6 | - If you're interested in DeepNLP, I strongly recommend working through this awesome lecture series.
7 |
8 | * cs-224n-slides
9 | * cs-224n-videos
10 |
11 | This material is not perfect, but it should help your study and research :) Please feel free to open pull requests!
12 |
13 |
14 |
15 | ## Contents
16 |
17 | | Model | Links |
18 | | ------------- |:-------------:|
19 | | 01. Skip-gram-Naive-Softmax | [notebook / data / paper] |
20 | | 02. Skip-gram-Negative-Sampling | [notebook / data / paper] |
21 | | 03. GloVe | [notebook / data / paper] |
22 | | 04. Window-Classifier-for-NER | [notebook / data / paper] |
23 | | 05. Neural-Dependancy-Parser | [notebook / data / paper] |
24 | | 06. RNN-Language-Model | [notebook / data / paper] |
25 | | 07. Neural-Machine-Translation-with-Attention | [notebook / data / paper] |
26 | | 08. CNN-for-Text-Classification | [notebook / data / paper] |
27 | | 09. Recursive-NN-for-Sentiment-Classification | [notebook / data / paper] |
28 | | 10. Dynamic-Memory-Network-for-Question-Answering | [notebook / data / paper] |
29 |
30 |
31 | ## Requirements
32 |
33 | - Python 3.5
34 | - Pytorch 0.2
35 | - nltk 3.2.2
36 | - gensim 2.2.0
37 | - sklearn_crfsuite
38 |
39 |
40 | ## Getting started
41 |
42 | `git clone https://github.com/DSKSD/cs-224n-Pytorch.git`
43 |
44 | ### prepare dataset
45 |
46 | ````
47 | cd script
48 | chmod u+x prepare_dataset.sh
49 | ./prepare_dataset.sh
50 | ````
51 |
52 | ### docker env
53 | Ubuntu 16.04 and Python 3.5.2 with various ML/DL packages, including tensorflow, sklearn, and pytorch
54 |
55 | `docker pull dsksd/deepstudy:0.2`
56 |
57 | ````
58 | pip3 install docker-compose
59 | cd script
60 | docker-compose up -d
61 | ````
62 |
63 | ### cloud setting
64 |
65 | `not yet`
66 |
67 | ## References
68 |
69 | * practical-pytorch
70 | * DeepLearningForNLPInPytorch
71 | * pytorch-tutorial
72 | * pytorch-examples
73 |
74 | ## Author
75 |
76 | Sungdong Kim / @DSKSD
--------------------------------------------------------------------------------
/images/01.skipgram-objective.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Shawn1993/DeepNLP-models-Pytorch/ceeb4221b176790229cd20c6ca4c05f625bdf02e/images/01.skipgram-objective.png
--------------------------------------------------------------------------------
/images/01.skipgram-prepare-data.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Shawn1993/DeepNLP-models-Pytorch/ceeb4221b176790229cd20c6ca4c05f625bdf02e/images/01.skipgram-prepare-data.png
--------------------------------------------------------------------------------
/images/02.skipgram-objective.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Shawn1993/DeepNLP-models-Pytorch/ceeb4221b176790229cd20c6ca4c05f625bdf02e/images/02.skipgram-objective.png
--------------------------------------------------------------------------------
/images/03.glove-objective.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Shawn1993/DeepNLP-models-Pytorch/ceeb4221b176790229cd20c6ca4c05f625bdf02e/images/03.glove-objective.png
--------------------------------------------------------------------------------
/images/03.glove-weighting-function.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Shawn1993/DeepNLP-models-Pytorch/ceeb4221b176790229cd20c6ca4c05f625bdf02e/images/03.glove-weighting-function.png
--------------------------------------------------------------------------------
/images/04.window-classifier-architecture.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Shawn1993/DeepNLP-models-Pytorch/ceeb4221b176790229cd20c6ca4c05f625bdf02e/images/04.window-classifier-architecture.png
--------------------------------------------------------------------------------
/images/04.window-data.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Shawn1993/DeepNLP-models-Pytorch/ceeb4221b176790229cd20c6ca4c05f625bdf02e/images/04.window-data.png
--------------------------------------------------------------------------------
/images/05.neural-dparser-architecture.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Shawn1993/DeepNLP-models-Pytorch/ceeb4221b176790229cd20c6ca4c05f625bdf02e/images/05.neural-dparser-architecture.png
--------------------------------------------------------------------------------
/images/05.transition-based-parse.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Shawn1993/DeepNLP-models-Pytorch/ceeb4221b176790229cd20c6ca4c05f625bdf02e/images/05.transition-based-parse.png
--------------------------------------------------------------------------------
/images/06.rnnlm-architecture.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Shawn1993/DeepNLP-models-Pytorch/ceeb4221b176790229cd20c6ca4c05f625bdf02e/images/06.rnnlm-architecture.png
--------------------------------------------------------------------------------
/images/07.attention-mechanism.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Shawn1993/DeepNLP-models-Pytorch/ceeb4221b176790229cd20c6ca4c05f625bdf02e/images/07.attention-mechanism.png
--------------------------------------------------------------------------------
/images/07.pad_to_sequence.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Shawn1993/DeepNLP-models-Pytorch/ceeb4221b176790229cd20c6ca4c05f625bdf02e/images/07.pad_to_sequence.png
--------------------------------------------------------------------------------
/images/07.seq2seq.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Shawn1993/DeepNLP-models-Pytorch/ceeb4221b176790229cd20c6ca4c05f625bdf02e/images/07.seq2seq.png
--------------------------------------------------------------------------------
/images/08.cnn-for-text-architecture.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Shawn1993/DeepNLP-models-Pytorch/ceeb4221b176790229cd20c6ca4c05f625bdf02e/images/08.cnn-for-text-architecture.png
--------------------------------------------------------------------------------
/images/09.rntn-layer.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Shawn1993/DeepNLP-models-Pytorch/ceeb4221b176790229cd20c6ca4c05f625bdf02e/images/09.rntn-layer.png
--------------------------------------------------------------------------------
/images/10.dmn-architecture.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Shawn1993/DeepNLP-models-Pytorch/ceeb4221b176790229cd20c6ca4c05f625bdf02e/images/10.dmn-architecture.png
--------------------------------------------------------------------------------
/notebooks/01.Skip-gram-Naive-Softmax.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "# 1. Skip-gram with naive softmax "
8 | ]
9 | },
10 | {
11 | "cell_type": "markdown",
12 | "metadata": {},
13 | "source": [
14 | "I recommend you take a look at these materials first."
15 | ]
16 | },
17 | {
18 | "cell_type": "markdown",
19 | "metadata": {},
20 | "source": [
21 | "* http://web.stanford.edu/class/cs224n/lectures/cs224n-2017-lecture2.pdf\n",
22 | "* https://arxiv.org/abs/1301.3781\n",
23 | "* http://mccormickml.com/2016/04/19/word2vec-tutorial-the-skip-gram-model/"
24 | ]
25 | },
26 | {
27 | "cell_type": "code",
28 | "execution_count": 1,
29 | "metadata": {
30 | "collapsed": true
31 | },
32 | "outputs": [],
33 | "source": [
34 | "import torch\n",
35 | "import torch.nn as nn\n",
36 | "from torch.autograd import Variable\n",
37 | "import torch.optim as optim\n",
38 | "import torch.nn.functional as F\n",
39 | "import nltk\n",
40 | "import random\n",
41 | "import numpy as np\n",
42 | "from collections import Counter\n",
43 | "flatten = lambda l: [item for sublist in l for item in sublist]"
44 | ]
45 | },
46 | {
47 | "cell_type": "code",
48 | "execution_count": 2,
49 | "metadata": {
50 | "collapsed": false
51 | },
52 | "outputs": [
53 | {
54 | "name": "stdout",
55 | "output_type": "stream",
56 | "text": [
57 | "0.2.0+751198f\n",
58 | "3.2.4\n"
59 | ]
60 | }
61 | ],
62 | "source": [
63 | "print(torch.__version__)\n",
64 | "print(nltk.__version__)"
65 | ]
66 | },
67 | {
68 | "cell_type": "code",
69 | "execution_count": 3,
70 | "metadata": {
71 | "collapsed": true
72 | },
73 | "outputs": [],
74 | "source": [
75 | "USE_CUDA = torch.cuda.is_available()\n",
76 | "\n",
77 | "FloatTensor = torch.cuda.FloatTensor if USE_CUDA else torch.FloatTensor\n",
78 | "LongTensor = torch.cuda.LongTensor if USE_CUDA else torch.LongTensor\n",
79 | "ByteTensor = torch.cuda.ByteTensor if USE_CUDA else torch.ByteTensor"
80 | ]
81 | },
82 | {
83 | "cell_type": "code",
84 | "execution_count": 4,
85 | "metadata": {
86 | "collapsed": true
87 | },
88 | "outputs": [],
89 | "source": [
90 | "def getBatch(batch_size,train_data):\n",
91 | " random.shuffle(train_data)\n",
92 | " sindex=0\n",
93 | " eindex=batch_size\n",
94 | " while eindex < len(train_data):\n",
95 | " batch = train_data[sindex:eindex]\n",
96 | " temp = eindex\n",
97 | " eindex = eindex+batch_size\n",
98 | " sindex = temp\n",
99 | " yield batch\n",
100 | " \n",
101 | " if eindex >= len(train_data):\n",
102 | " batch = train_data[sindex:]\n",
103 | " yield batch"
104 | ]
105 | },
106 | {
107 | "cell_type": "code",
108 | "execution_count": 5,
109 | "metadata": {
110 | "collapsed": true
111 | },
112 | "outputs": [],
113 | "source": [
114 | "def prepare_sequence(seq, word2index):\n",
115 | " idxs = list(map(lambda w: word2index[w] if w in word2index.keys() else word2index[\"\"], seq))\n",
116 | " return Variable(LongTensor(idxs))\n",
117 | "\n",
118 | "def prepare_word(word,word2index):\n",
119 | " return Variable(LongTensor([word2index[word]]) if word in word2index.keys() else LongTensor([word2index[\"\"]]))"
120 | ]
121 | },
122 | {
123 | "cell_type": "markdown",
124 | "metadata": {},
125 | "source": [
126 | "## Data load and Preprocessing "
127 | ]
128 | },
129 | {
130 | "cell_type": "markdown",
131 | "metadata": {},
132 | "source": [
133 | "### Load corpus : Gutenberg corpus"
134 | ]
135 | },
136 | {
137 | "cell_type": "markdown",
138 | "metadata": {},
139 | "source": [
140 | "If you don't have the Gutenberg corpus, download it first using nltk.download()"
141 | ]
142 | },
143 | {
144 | "cell_type": "code",
145 | "execution_count": 6,
146 | "metadata": {
147 | "collapsed": false
148 | },
149 | "outputs": [
150 | {
151 | "data": {
152 | "text/plain": [
153 | "['austen-emma.txt',\n",
154 | " 'austen-persuasion.txt',\n",
155 | " 'austen-sense.txt',\n",
156 | " 'bible-kjv.txt',\n",
157 | " 'blake-poems.txt',\n",
158 | " 'bryant-stories.txt',\n",
159 | " 'burgess-busterbrown.txt',\n",
160 | " 'carroll-alice.txt',\n",
161 | " 'chesterton-ball.txt',\n",
162 | " 'chesterton-brown.txt',\n",
163 | " 'chesterton-thursday.txt',\n",
164 | " 'edgeworth-parents.txt',\n",
165 | " 'melville-moby_dick.txt',\n",
166 | " 'milton-paradise.txt',\n",
167 | " 'shakespeare-caesar.txt',\n",
168 | " 'shakespeare-hamlet.txt',\n",
169 | " 'shakespeare-macbeth.txt',\n",
170 | " 'whitman-leaves.txt']"
171 | ]
172 | },
173 | "execution_count": 6,
174 | "metadata": {},
175 | "output_type": "execute_result"
176 | }
177 | ],
178 | "source": [
179 | "nltk.corpus.gutenberg.fileids()"
180 | ]
181 | },
182 | {
183 | "cell_type": "code",
184 | "execution_count": 7,
185 | "metadata": {
186 | "collapsed": true
187 | },
188 | "outputs": [],
189 | "source": [
190 | "corpus = list(nltk.corpus.gutenberg.sents('melville-moby_dick.txt'))[:100] # sampling sentences for test\n",
191 | "corpus = [[word.lower() for word in sent] for sent in corpus]"
192 | ]
193 | },
194 | {
195 | "cell_type": "markdown",
196 | "metadata": {},
197 | "source": [
198 | "### Extract Stopwords from unigram distribution's tails"
199 | ]
200 | },
201 | {
202 | "cell_type": "code",
203 | "execution_count": 8,
204 | "metadata": {
205 | "collapsed": true
206 | },
207 | "outputs": [],
208 | "source": [
209 | "word_count = Counter(flatten(corpus))\n",
210 | "border =int(len(word_count)*0.01) "
211 | ]
212 | },
213 | {
214 | "cell_type": "code",
215 | "execution_count": 9,
216 | "metadata": {
217 | "collapsed": true
218 | },
219 | "outputs": [],
220 | "source": [
221 | "stopwords = word_count.most_common()[:border]+list(reversed(word_count.most_common()))[:border]"
222 | ]
223 | },
224 | {
225 | "cell_type": "code",
226 | "execution_count": 10,
227 | "metadata": {
228 | "collapsed": true
229 | },
230 | "outputs": [],
231 | "source": [
232 | "stopwords = [s[0] for s in stopwords]"
233 | ]
234 | },
235 | {
236 | "cell_type": "code",
237 | "execution_count": 11,
238 | "metadata": {
239 | "collapsed": false
240 | },
241 | "outputs": [
242 | {
243 | "data": {
244 | "text/plain": [
245 | "[',', '.', 'the', 'of', 'and', 'baleine', '--(', 'fat', 'oil', 'boiling']"
246 | ]
247 | },
248 | "execution_count": 11,
249 | "metadata": {},
250 | "output_type": "execute_result"
251 | }
252 | ],
253 | "source": [
254 | "stopwords"
255 | ]
256 | },
257 | {
258 | "cell_type": "markdown",
259 | "metadata": {},
260 | "source": [
261 | "### Build vocab"
262 | ]
263 | },
264 | {
265 | "cell_type": "code",
266 | "execution_count": 12,
267 | "metadata": {
268 | "collapsed": true
269 | },
270 | "outputs": [],
271 | "source": [
272 | "vocab = list(set(flatten(corpus))-set(stopwords))\n",
273 | "vocab.append('')"
274 | ]
275 | },
276 | {
277 | "cell_type": "code",
278 | "execution_count": 13,
279 | "metadata": {
280 | "collapsed": false
281 | },
282 | "outputs": [
283 | {
284 | "name": "stdout",
285 | "output_type": "stream",
286 | "text": [
287 | "592 583\n"
288 | ]
289 | }
290 | ],
291 | "source": [
292 | "print(len(set(flatten(corpus))),len(vocab))"
293 | ]
294 | },
295 | {
296 | "cell_type": "code",
297 | "execution_count": 14,
298 | "metadata": {
299 | "collapsed": true
300 | },
301 | "outputs": [],
302 | "source": [
303 | "word2index = {'' : 0} \n",
304 | "\n",
305 | "for vo in vocab:\n",
306 | " if vo not in word2index.keys():\n",
307 | " word2index[vo]=len(word2index)\n",
308 | "\n",
309 | "index2word = {v:k for k,v in word2index.items()} "
310 | ]
311 | },
312 | {
313 | "cell_type": "markdown",
314 | "metadata": {},
315 | "source": [
316 | "### Prepare train data "
317 | ]
318 | },
319 | {
320 | "cell_type": "markdown",
321 | "metadata": {},
322 | "source": [
323 | "window data example"
324 | ]
325 | },
326 | {
327 | "cell_type": "markdown",
328 | "metadata": {},
329 | "source": [
330 | "\n",
331 | "borrowed image from http://mccormickml.com/2016/04/19/word2vec-tutorial-the-skip-gram-model/"
332 | ]
333 | },
334 | {
335 | "cell_type": "code",
336 | "execution_count": 15,
337 | "metadata": {
338 | "collapsed": true
339 | },
340 | "outputs": [],
341 | "source": [
342 | "WINDOW_SIZE = 3\n",
343 | "windows = flatten([list(nltk.ngrams(['']*WINDOW_SIZE+c+['']*WINDOW_SIZE,WINDOW_SIZE*2+1)) for c in corpus])"
344 | ]
345 | },
346 | {
347 | "cell_type": "code",
348 | "execution_count": 16,
349 | "metadata": {
350 | "collapsed": false
351 | },
352 | "outputs": [
353 | {
354 | "data": {
355 | "text/plain": [
356 | "('', '', '', '[', 'moby', 'dick', 'by')"
357 | ]
358 | },
359 | "execution_count": 16,
360 | "metadata": {},
361 | "output_type": "execute_result"
362 | }
363 | ],
364 | "source": [
365 | "windows[0]"
366 | ]
367 | },
368 | {
369 | "cell_type": "code",
370 | "execution_count": 17,
371 | "metadata": {
372 | "collapsed": false
373 | },
374 | "outputs": [
375 | {
376 | "name": "stdout",
377 | "output_type": "stream",
378 | "text": [
379 | "[('[', 'moby'), ('[', 'dick'), ('[', 'by'), ('moby', '['), ('moby', 'dick'), ('moby', 'by')]\n"
380 | ]
381 | }
382 | ],
383 | "source": [
384 | "train_data = []\n",
385 | "\n",
386 | "for window in windows:\n",
387 | " for i in range(WINDOW_SIZE*2+1):\n",
388 | " if i==WINDOW_SIZE or window[i]=='': continue\n",
389 | " train_data.append((window[WINDOW_SIZE],window[i]))\n",
390 | "\n",
391 | "print(train_data[:WINDOW_SIZE*2])"
392 | ]
393 | },
394 | {
395 | "cell_type": "code",
396 | "execution_count": 18,
397 | "metadata": {
398 | "collapsed": true
399 | },
400 | "outputs": [],
401 | "source": [
402 | "X_p=[]\n",
403 | "y_p=[]"
404 | ]
405 | },
406 | {
407 | "cell_type": "code",
408 | "execution_count": 19,
409 | "metadata": {
410 | "collapsed": false
411 | },
412 | "outputs": [
413 | {
414 | "data": {
415 | "text/plain": [
416 | "('[', 'moby')"
417 | ]
418 | },
419 | "execution_count": 19,
420 | "metadata": {},
421 | "output_type": "execute_result"
422 | }
423 | ],
424 | "source": [
425 | "train_data[0]"
426 | ]
427 | },
428 | {
429 | "cell_type": "code",
430 | "execution_count": 20,
431 | "metadata": {
432 | "collapsed": false
433 | },
434 | "outputs": [],
435 | "source": [
436 | "for tr in train_data:\n",
437 | " X_p.append(prepare_word(tr[0],word2index).view(1,-1))\n",
438 | " y_p.append(prepare_word(tr[1],word2index).view(1,-1))"
439 | ]
440 | },
441 | {
442 | "cell_type": "code",
443 | "execution_count": 21,
444 | "metadata": {
445 | "collapsed": false
446 | },
447 | "outputs": [],
448 | "source": [
449 | "train_data = list(zip(X_p,y_p))"
450 | ]
451 | },
452 | {
453 | "cell_type": "code",
454 | "execution_count": 22,
455 | "metadata": {
456 | "collapsed": false
457 | },
458 | "outputs": [
459 | {
460 | "data": {
461 | "text/plain": [
462 | "7606"
463 | ]
464 | },
465 | "execution_count": 22,
466 | "metadata": {},
467 | "output_type": "execute_result"
468 | }
469 | ],
470 | "source": [
471 | "len(train_data)"
472 | ]
473 | },
474 | {
475 | "cell_type": "markdown",
476 | "metadata": {},
477 | "source": [
478 | "## Modeling"
479 | ]
480 | },
481 | {
482 | "cell_type": "markdown",
483 | "metadata": {},
484 | "source": [
485 | "\n",
486 | "borrowed image from http://web.stanford.edu/class/cs224n/lectures/cs224n-2017-lecture2.pdf"
487 | ]
488 | },
489 | {
490 | "cell_type": "code",
491 | "execution_count": 59,
492 | "metadata": {
493 | "collapsed": true
494 | },
495 | "outputs": [],
496 | "source": [
497 | "class Skipgram(nn.Module):\n",
498 | " \n",
499 | " def __init__(self, vocab_size,projection_dim):\n",
500 | " super(Skipgram,self).__init__()\n",
501 | " self.embedding_v = nn.Embedding(vocab_size, projection_dim)\n",
502 | " self.embedding_u = nn.Embedding(vocab_size, projection_dim)\n",
503 | "\n",
504 | " self.embedding_v.weight.data.uniform_(-1, 1) # init\n",
505 | " self.embedding_u.weight.data.uniform_(0, 0) # init\n",
506 | " #self.out = nn.Linear(projection_dim,vocab_size)\n",
507 | " def forward(self, center_words,target_words, outer_words):\n",
508 | " center_embeds = self.embedding_v(center_words) # B x 1 x D\n",
509 | " target_embeds = self.embedding_u(target_words) # B x 1 x D\n",
510 | " outer_embeds = self.embedding_u(outer_words) # B x V x D\n",
511 | " \n",
512 | " scores = target_embeds.bmm(center_embeds.transpose(1,2)).squeeze(2) # Bx1xD * BxDx1 => Bx1\n",
513 | " norm_scores = outer_embeds.bmm(center_embeds.transpose(1,2)).squeeze(2) # BxVxD * BxDx1 => BxV\n",
514 | " \n",
515 | " nll = -torch.mean(torch.log(torch.exp(scores)/torch.sum(torch.exp(norm_scores),1).unsqueeze(1))) # log-softmax\n",
516 | " \n",
517 | " return nll # negative log likelihood\n",
518 | " \n",
519 | " def prediction(self, inputs):\n",
520 | " embeds = self.embedding_v(inputs)\n",
521 | " \n",
522 | " return embeds "
523 | ]
524 | },
525 | {
526 | "cell_type": "markdown",
527 | "metadata": {},
528 | "source": [
529 | "## Train "
530 | ]
531 | },
532 | {
533 | "cell_type": "code",
534 | "execution_count": 60,
535 | "metadata": {
536 | "collapsed": true
537 | },
538 | "outputs": [],
539 | "source": [
540 | "EMBEDDING_SIZE = 30\n",
541 | "BATCH_SIZE = 256\n",
542 | "EPOCH = 100"
543 | ]
544 | },
545 | {
546 | "cell_type": "code",
547 | "execution_count": 61,
548 | "metadata": {
549 | "collapsed": true
550 | },
551 | "outputs": [],
552 | "source": [
553 | "losses = []\n",
554 | "model = Skipgram(len(word2index),EMBEDDING_SIZE)\n",
555 | "if USE_CUDA:\n",
556 | " model = model.cuda()\n",
557 | "optimizer = optim.Adam(model.parameters(), lr=0.01)"
558 | ]
559 | },
560 | {
561 | "cell_type": "code",
562 | "execution_count": 62,
563 | "metadata": {
564 | "collapsed": false
565 | },
566 | "outputs": [
567 | {
568 | "name": "stdout",
569 | "output_type": "stream",
570 | "text": [
571 | "Epoch : 0, mean_loss : 6.20\n",
572 | "Epoch : 10, mean_loss : 4.38\n",
573 | "Epoch : 20, mean_loss : 3.48\n",
574 | "Epoch : 30, mean_loss : 3.31\n",
575 | "Epoch : 40, mean_loss : 3.26\n",
576 | "Epoch : 50, mean_loss : 3.24\n",
577 | "Epoch : 60, mean_loss : 3.22\n",
578 | "Epoch : 70, mean_loss : 3.22\n",
579 | "Epoch : 80, mean_loss : 3.21\n",
580 | "Epoch : 90, mean_loss : 3.20\n"
581 | ]
582 | }
583 | ],
584 | "source": [
585 | "for epoch in range(EPOCH):\n",
586 | " for i,batch in enumerate(getBatch(BATCH_SIZE,train_data)):\n",
587 | " \n",
588 | " inputs, targets = zip(*batch)\n",
589 | " \n",
590 | " inputs = torch.cat(inputs) # B x 1\n",
591 | " targets = torch.cat(targets) # B x 1\n",
592 | " vocabs = prepare_sequence(list(vocab),word2index).expand(inputs.size(0),len(vocab)) # B x V\n",
593 | " model.zero_grad()\n",
594 | "\n",
595 | " loss = model(inputs,targets,vocabs)\n",
596 | " \n",
597 | " loss.backward()\n",
598 | " optimizer.step()\n",
599 | " \n",
600 | " losses.append(loss.data.tolist()[0])\n",
601 | "\n",
602 | " if epoch % 10==0:\n",
603 | " print(\"Epoch : %d, mean_loss : %.02f\" % (epoch,np.mean(losses)))\n",
604 | " losses=[]"
605 | ]
606 | },
607 | {
608 | "cell_type": "markdown",
609 | "metadata": {},
610 | "source": [
611 | "## Test"
612 | ]
613 | },
614 | {
615 | "cell_type": "code",
616 | "execution_count": 63,
617 | "metadata": {
618 | "collapsed": true
619 | },
620 | "outputs": [],
621 | "source": [
622 | "def word_similarity(target,vocab):\n",
623 | " if USE_CUDA:\n",
624 | " target_V = model.prediction(prepare_word(target,word2index))\n",
625 | " else:\n",
626 | " target_V = model.prediction(prepare_word(target,word2index))\n",
627 | " similarities=[]\n",
628 | " for i in range(len(vocab)):\n",
629 | " if vocab[i] == target: continue\n",
630 | " \n",
631 | " if USE_CUDA:\n",
632 | " vector = model.prediction(prepare_word(list(vocab)[i],word2index))\n",
633 | " else:\n",
634 | " vector = model.prediction(prepare_word(list(vocab)[i],word2index))\n",
635 | " cosine_sim = F.cosine_similarity(target_V,vector).data.tolist()[0] \n",
636 | " similarities.append([vocab[i],cosine_sim])\n",
637 | " return sorted(similarities, key=lambda x: x[1], reverse=True)[:10] # sort by similarity"
638 | ]
639 | },
640 | {
641 | "cell_type": "code",
642 | "execution_count": 64,
643 | "metadata": {
644 | "collapsed": false
645 | },
646 | "outputs": [
647 | {
648 | "data": {
649 | "text/plain": [
650 | "'least'"
651 | ]
652 | },
653 | "execution_count": 64,
654 | "metadata": {},
655 | "output_type": "execute_result"
656 | }
657 | ],
658 | "source": [
659 | "test = random.choice(list(vocab))\n",
660 | "test"
661 | ]
662 | },
663 | {
664 | "cell_type": "code",
665 | "execution_count": 65,
666 | "metadata": {
667 | "collapsed": false
668 | },
669 | "outputs": [
670 | {
671 | "data": {
672 | "text/plain": [
673 | "[['at', 0.8147411346435547],\n",
674 | " ['every', 0.7143548130989075],\n",
675 | " ['case', 0.6975079774856567],\n",
676 | " ['secure', 0.6121522188186646],\n",
677 | " ['heart', 0.5974172949790955],\n",
678 | " ['including', 0.5867112278938293],\n",
679 | " ['please', 0.5557640194892883],\n",
680 | " ['has', 0.5536234974861145],\n",
681 | " ['while', 0.5366998314857483],\n",
682 | " ['you', 0.509368896484375]]"
683 | ]
684 | },
685 | "execution_count": 65,
686 | "metadata": {},
687 | "output_type": "execute_result"
688 | }
689 | ],
690 | "source": [
691 | "word_similarity(test,vocab)"
692 | ]
693 | },
694 | {
695 | "cell_type": "code",
696 | "execution_count": null,
697 | "metadata": {
698 | "collapsed": true
699 | },
700 | "outputs": [],
701 | "source": []
702 | }
703 | ],
704 | "metadata": {
705 | "kernelspec": {
706 | "display_name": "Python 3",
707 | "language": "python",
708 | "name": "python3"
709 | },
710 | "language_info": {
711 | "codemirror_mode": {
712 | "name": "ipython",
713 | "version": 3
714 | },
715 | "file_extension": ".py",
716 | "mimetype": "text/x-python",
717 | "name": "python",
718 | "nbconvert_exporter": "python",
719 | "pygments_lexer": "ipython3",
720 | "version": "3.5.2"
721 | }
722 | },
723 | "nbformat": 4,
724 | "nbformat_minor": 2
725 | }
726 |
--------------------------------------------------------------------------------
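The two core steps of this notebook, building (center, context) pairs from padded windows and scoring a pair with a softmax over the whole vocabulary, can be sketched without PyTorch. This is a minimal illustration under stated assumptions: `skipgram_pairs` and `naive_softmax_nll` are hypothetical helper names, vectors are plain Python lists, and the max-subtraction for numerical stability is an addition that the notebook's `forward` does not perform.

```python
import math

PAD = ""  # padding token; the notebook listing above pads with the empty string


def skipgram_pairs(sentence, window_size):
    """Return (center, context) pairs for tokens within window_size of each center."""
    padded = [PAD] * window_size + sentence + [PAD] * window_size
    pairs = []
    for c in range(window_size, len(padded) - window_size):
        for offset in range(-window_size, window_size + 1):
            context = padded[c + offset]
            if offset == 0 or context == PAD:
                continue  # skip the center itself and padding positions
            pairs.append((padded[c], context))
    return pairs


def naive_softmax_nll(center_vec, outer_vecs, target_index):
    """Negative log-likelihood of one context word under a full-vocabulary softmax.

    center_vec   : embedding v of the center word
    outer_vecs   : one outer embedding u per vocabulary word
    target_index : position of the observed context word in outer_vecs
    """
    scores = [sum(u * v for u, v in zip(outer, center_vec)) for outer in outer_vecs]
    max_score = max(scores)  # subtract the max before exponentiating (stability)
    log_norm = max_score + math.log(sum(math.exp(s - max_score) for s in scores))
    return log_norm - scores[target_index]


if __name__ == "__main__":
    print(skipgram_pairs(["[", "moby", "dick", "by"], 3)[:6])
    # With all outer vectors at zero, every word is equally likely.
    print(naive_softmax_nll([1.0, 2.0], [[0.0, 0.0]] * 4, 0))
```

When every outer vector is zero, as with the `uniform_(0, 0)` initialisation of `embedding_u`, all scores are equal and the loss is log V for a vocabulary of size V. The normaliser sums over the entire vocabulary for each pair, which is the cost that negative sampling in the next notebook avoids.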
/notebooks/02.Skip-gram-Negative-Sampling.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "# 2. Skip-gram with negative sampling"
8 | ]
9 | },
10 | {
11 | "cell_type": "markdown",
12 | "metadata": {},
13 | "source": [
14 | "I recommend you take a look at these materials first."
15 | ]
16 | },
17 | {
18 | "cell_type": "markdown",
19 | "metadata": {},
20 | "source": [
21 | "* http://web.stanford.edu/class/cs224n/lectures/cs224n-2017-lecture3.pdf\n",
22 | "* http://papers.nips.cc/paper/5021-distributed-representations-of-words-and-phrases-and-their-compositionality.pdf"
23 | ]
24 | },
25 | {
26 | "cell_type": "code",
27 | "execution_count": 1,
28 | "metadata": {
29 | "collapsed": true
30 | },
31 | "outputs": [],
32 | "source": [
33 | "import torch\n",
34 | "import torch.nn as nn\n",
35 | "from torch.autograd import Variable\n",
36 | "import torch.optim as optim\n",
37 | "import torch.nn.functional as F\n",
38 | "import nltk\n",
39 | "import random\n",
40 | "import numpy as np\n",
41 | "from collections import Counter\n",
42 | "flatten = lambda l: [item for sublist in l for item in sublist]"
43 | ]
44 | },
45 | {
46 | "cell_type": "code",
47 | "execution_count": 2,
48 | "metadata": {
49 | "collapsed": false
50 | },
51 | "outputs": [
52 | {
53 | "name": "stdout",
54 | "output_type": "stream",
55 | "text": [
56 | "0.2.0+751198f\n",
57 | "3.2.4\n"
58 | ]
59 | }
60 | ],
61 | "source": [
62 | "print(torch.__version__)\n",
63 | "print(nltk.__version__)"
64 | ]
65 | },
66 | {
67 | "cell_type": "code",
68 | "execution_count": 3,
69 | "metadata": {
70 | "collapsed": true
71 | },
72 | "outputs": [],
73 | "source": [
74 | "USE_CUDA = torch.cuda.is_available()\n",
75 | "\n",
76 | "FloatTensor = torch.cuda.FloatTensor if USE_CUDA else torch.FloatTensor\n",
77 | "LongTensor = torch.cuda.LongTensor if USE_CUDA else torch.LongTensor\n",
78 | "ByteTensor = torch.cuda.ByteTensor if USE_CUDA else torch.ByteTensor"
79 | ]
80 | },
81 | {
82 | "cell_type": "code",
83 | "execution_count": 4,
84 | "metadata": {
85 | "collapsed": true
86 | },
87 | "outputs": [],
88 | "source": [
89 | "def getBatch(batch_size,train_data):\n",
90 | "    random.shuffle(train_data)\n",
91 | "    sindex=0\n",
92 | "    eindex=batch_size\n",
93 | "    while eindex < len(train_data):\n",
94 | "        batch = train_data[sindex:eindex]\n",
95 | "        temp = eindex\n",
96 | "        eindex = eindex+batch_size\n",
97 | "        sindex = temp\n",
98 | "        yield batch\n",
99 | "    \n",
100 | "    if eindex >= len(train_data):\n",
101 | "        batch = train_data[sindex:]\n",
102 | "        yield batch"
103 | ]
104 | },
105 | {
106 | "cell_type": "code",
107 | "execution_count": 5,
108 | "metadata": {
109 | "collapsed": true
110 | },
111 | "outputs": [],
112 | "source": [
113 | "def prepare_sequence(seq, word2index):\n",
114 | "    idxs = list(map(lambda w: word2index[w] if w in word2index.keys() else word2index[\"<UNK>\"], seq))\n",
115 | "    return Variable(LongTensor(idxs))\n",
116 | "\n",
117 | "def prepare_word(word,word2index):\n",
118 | "    return Variable(LongTensor([word2index[word]]) if word in word2index.keys() else LongTensor([word2index[\"<UNK>\"]]))"
119 | ]
120 | },
121 | {
122 | "cell_type": "markdown",
123 | "metadata": {},
124 | "source": [
125 | "## Data load and Preprocessing "
126 | ]
127 | },
128 | {
129 | "cell_type": "code",
130 | "execution_count": 6,
131 | "metadata": {
132 | "collapsed": true
133 | },
134 | "outputs": [],
135 | "source": [
136 | "corpus = list(nltk.corpus.gutenberg.sents('melville-moby_dick.txt'))[:500]\n",
137 | "corpus = [[word.lower() for word in sent] for sent in corpus]"
138 | ]
139 | },
140 | {
141 | "cell_type": "markdown",
142 | "metadata": {},
143 | "source": [
144 | "### Exclude sparse words "
145 | ]
146 | },
147 | {
148 | "cell_type": "code",
149 | "execution_count": 7,
150 | "metadata": {
151 | "collapsed": true
152 | },
153 | "outputs": [],
154 | "source": [
155 | "word_count = Counter(flatten(corpus))"
156 | ]
157 | },
158 | {
159 | "cell_type": "code",
160 | "execution_count": 8,
161 | "metadata": {
162 | "collapsed": true
163 | },
164 | "outputs": [],
165 | "source": [
166 | "MIN_COUNT=3\n",
167 | "exclude=[]"
168 | ]
169 | },
170 | {
171 | "cell_type": "code",
172 | "execution_count": 9,
173 | "metadata": {
174 | "collapsed": true
175 | },
176 | "outputs": [],
177 | "source": [
178 | "for w,c in word_count.items():\n",
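179 | "    if c<MIN_COUNT:\n",
180 | "        exclude.append(w)"
181 | ]
182 | },
183 | {
184 | "cell_type": "code",
185 | "execution_count": 10,
186 | "metadata": {
187 | "collapsed": true
188 | },
189 | "outputs": [],
190 | "source": [
191 | "vocab = list(set(flatten(corpus))-set(exclude))"
192 | ]
193 | },
194 | {
195 | "cell_type": "code",
196 | "execution_count": 11,
197 | "metadata": {
198 | "collapsed": true
199 | },
200 | "outputs": [],
201 | "source": [
202 | "word2index={}\n",
203 | "for vo in vocab:\n",
204 | "    if vo not in word2index.keys():\n",
205 | "        word2index[vo]=len(word2index)\n",
206 | "    \n",
207 | "index2word={v:k for k,v in word2index.items()}"
208 | ]
209 | },
210 | {
211 | "cell_type": "markdown",
212 | "metadata": {},
213 | "source": [
214 | "### Prepare train data "
215 | ]
216 | },
217 | {
218 | "cell_type": "code",
219 | "execution_count": 12,
220 | "metadata": {
221 | "collapsed": true
222 | },
223 | "outputs": [],
224 | "source": [
225 | "WINDOW_SIZE = 5\n",
226 | "windows = flatten([list(nltk.ngrams(['<DUMMY>']*WINDOW_SIZE+c+['<DUMMY>']*WINDOW_SIZE,WINDOW_SIZE*2+1)) for c in corpus])\n",
227 | "\n",
228 | "train_data = []\n",
229 | "\n",
230 | "for window in windows:\n",
231 | "    for i in range(WINDOW_SIZE*2+1):\n",
232 | "        if window[i] in exclude or window[WINDOW_SIZE] in exclude: continue # min_count\n",
233 | "        if i==WINDOW_SIZE or window[i]=='<DUMMY>': continue\n",
234 | "        train_data.append((window[WINDOW_SIZE],window[i]))\n",
235 | "\n",
236 | "X_p=[]\n",
237 | "y_p=[]\n",
238 | "\n",
239 | "for tr in train_data:\n",
240 | "    X_p.append(prepare_word(tr[0],word2index).view(1,-1))\n",
241 | "    y_p.append(prepare_word(tr[1],word2index).view(1,-1))\n",
242 | "    \n",
243 | "train_data = list(zip(X_p,y_p))"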
244 | ]
245 | },
246 | {
247 | "cell_type": "code",
248 | "execution_count": 13,
249 | "metadata": {
250 | "collapsed": false
251 | },
252 | "outputs": [
253 | {
254 | "data": {
255 | "text/plain": [
256 | "50242"
257 | ]
258 | },
259 | "execution_count": 13,
260 | "metadata": {},
261 | "output_type": "execute_result"
262 | }
263 | ],
264 | "source": [
265 | "len(train_data)"
266 | ]
267 | },
268 | {
269 | "cell_type": "markdown",
270 | "metadata": {},
271 | "source": [
272 | "### Build Unigram Distribution**0.75 "
273 | ]
274 | },
275 | {
276 | "cell_type": "markdown",
277 | "metadata": {},
278 | "source": [
279 | "$$P(w)=U(w)^{3/4}/Z$$"
280 | ]
281 | },
282 | {
283 | "cell_type": "code",
284 | "execution_count": 14,
285 | "metadata": {
286 | "collapsed": true
287 | },
288 | "outputs": [],
289 | "source": [
290 | "Z = 0.001"
291 | ]
292 | },
293 | {
294 | "cell_type": "code",
295 | "execution_count": 15,
296 | "metadata": {
297 | "collapsed": true
298 | },
299 | "outputs": [],
300 | "source": [
301 | "word_count = Counter(flatten(corpus))\n",
302 | "num_total_words = sum([c for w,c in word_count.items() if w not in exclude])"
303 | ]
304 | },
305 | {
306 | "cell_type": "code",
307 | "execution_count": 16,
308 | "metadata": {
309 | "collapsed": true
310 | },
311 | "outputs": [],
312 | "source": [
313 | "unigram_table=[]\n",
314 | "\n",
315 | "for vo in vocab:\n",
316 | " unigram_table.extend([vo]*int(((word_count[vo]/num_total_words)**0.75)/Z))"
317 | ]
318 | },
319 | {
320 | "cell_type": "code",
321 | "execution_count": 17,
322 | "metadata": {
323 | "collapsed": false
324 | },
325 | "outputs": [
326 | {
327 | "name": "stdout",
328 | "output_type": "stream",
329 | "text": [
330 | "478 3500\n"
331 | ]
332 | }
333 | ],
334 | "source": [
335 | "print(len(vocab),len(unigram_table))"
336 | ]
337 | },
338 | {
339 | "cell_type": "markdown",
340 | "metadata": {},
341 | "source": [
342 | "### Negative Sampling "
343 | ]
344 | },
345 | {
346 | "cell_type": "code",
347 | "execution_count": 18,
348 | "metadata": {
349 | "collapsed": true
350 | },
351 | "outputs": [],
352 | "source": [
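353 | "def negative_sampling(targets,unigram_table,k):\n",
354 | "    batch_size = targets.size(0)\n",
355 | "    neg_samples=[]\n",
356 | "    for i in range(batch_size):\n",
357 | "        nsample=[]\n",
358 | "        target_index = targets[i].data.cpu().tolist()[0] if USE_CUDA else targets[i].data.tolist()[0]\n",
359 | "        while len(nsample)<k: # draw k negatives per target\n",
360 | "            neg = random.choice(unigram_table)\n",
361 | "            if word2index[neg]==target_index: # skip the positive target itself\n",
362 | "                continue\n",
363 | "            nsample.append(neg)\n",
364 | "        neg_samples.append(prepare_sequence(nsample,word2index).view(1,-1))\n",
365 | "    return torch.cat(neg_samples) # B x K"
366 | ]
367 | },
368 | {
369 | "cell_type": "markdown",
370 | "metadata": {},
371 | "source": [
372 | "## Modeling "
373 | ]
374 | },
375 | {
376 | "cell_type": "markdown",
377 | "metadata": {},
378 | "source": [
379 | "<img src=\"../images/02.skipgram-objective.png\">\n",
380 | "borrowed image from http://web.stanford.edu/class/cs224n/lectures/cs224n-2017-lecture3.pdf"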
381 | ]
382 | },
383 | {
384 | "cell_type": "code",
385 | "execution_count": 19,
386 | "metadata": {
387 | "collapsed": true
388 | },
389 | "outputs": [],
390 | "source": [
391 | "class SkipgramNegSampling(nn.Module):\n",
392 | "    \n",
393 | "    def __init__(self, vocab_size,projection_dim):\n",
394 | "        super(SkipgramNegSampling,self).__init__()\n",
395 | "        self.embedding_v = nn.Embedding(vocab_size, projection_dim) # center embedding\n",
396 | "        self.embedding_u = nn.Embedding(vocab_size, projection_dim) # out embedding\n",
397 | "        self.logsigmoid = nn.LogSigmoid()\n",
398 | "        \n",
399 | "        initrange = (2.0 / (vocab_size+projection_dim))**0.5 # Xavier init\n",
400 | "        self.embedding_v.weight.data.uniform_(-initrange, initrange) # init\n",
401 | "        self.embedding_u.weight.data.uniform_(-0.0, 0.0) # init (all zeros)\n",
402 | "        \n",
403 | "    def forward(self, center_words,target_words,negative_words):\n",
404 | "        center_embeds = self.embedding_v(center_words) # B x 1 x D\n",
405 | "        target_embeds = self.embedding_u(target_words) # B x 1 x D\n",
406 | "        \n",
407 | "        neg_embeds = -self.embedding_u(negative_words) # B x K x D\n",
408 | "        \n",
409 | "        positive_score = target_embeds.bmm(center_embeds.transpose(1,2)).squeeze(2) # Bx1\n",
410 | "        negative_score = torch.sum(neg_embeds.bmm(center_embeds.transpose(1,2)).squeeze(2),1).view(negative_words.size(0),-1) # BxK -> Bx1\n",
411 | "        # NOTE: the K negative scores are summed before the log-sigmoid; the paper's objective applies log-sigmoid to each negative score and then sums\n",
412 | "        loss = self.logsigmoid(positive_score) + self.logsigmoid(negative_score)\n",
413 | "        \n",
414 | "        return -torch.mean(loss)\n",
415 | "    \n",
416 | "    def prediction(self, inputs):\n",
417 | "        embeds = self.embedding_v(inputs)\n",
418 | "        \n",
419 | "        return embeds"
420 | ]
421 | },
422 | {
423 | "cell_type": "markdown",
424 | "metadata": {},
425 | "source": [
426 | "## Train "
427 | ]
428 | },
429 | {
430 | "cell_type": "code",
431 | "execution_count": 68,
432 | "metadata": {
433 | "collapsed": true
434 | },
435 | "outputs": [],
436 | "source": [
437 | "EMBEDDING_SIZE = 30 \n",
438 | "BATCH_SIZE = 256\n",
439 | "EPOCH = 100\n",
440 | "NEG=10 # Num of Negative Sampling"
441 | ]
442 | },
443 | {
444 | "cell_type": "code",
445 | "execution_count": 69,
446 | "metadata": {
447 | "collapsed": true
448 | },
449 | "outputs": [],
450 | "source": [
451 | "losses = []\n",
452 | "model = SkipgramNegSampling(len(word2index),EMBEDDING_SIZE)\n",
453 | "if USE_CUDA:\n",
454 | " model = model.cuda()\n",
455 | "optimizer = optim.Adam(model.parameters(), lr=0.001)"
456 | ]
457 | },
458 | {
459 | "cell_type": "code",
460 | "execution_count": 70,
461 | "metadata": {
462 | "collapsed": false
463 | },
464 | "outputs": [
465 | {
466 | "name": "stdout",
467 | "output_type": "stream",
468 | "text": [
469 | "Epoch : 0, mean_loss : 1.06\n",
470 | "Epoch : 10, mean_loss : 0.86\n",
471 | "Epoch : 20, mean_loss : 0.79\n",
472 | "Epoch : 30, mean_loss : 0.74\n",
473 | "Epoch : 40, mean_loss : 0.71\n",
474 | "Epoch : 50, mean_loss : 0.69\n",
475 | "Epoch : 60, mean_loss : 0.67\n",
476 | "Epoch : 70, mean_loss : 0.65\n",
477 | "Epoch : 80, mean_loss : 0.64\n",
478 | "Epoch : 90, mean_loss : 0.63\n"
479 | ]
480 | }
481 | ],
482 | "source": [
483 | "for epoch in range(EPOCH):\n",
484 | "    for i,batch in enumerate(getBatch(BATCH_SIZE,train_data)):\n",
485 | "        \n",
486 | "        inputs, targets = zip(*batch)\n",
487 | "        \n",
488 | "        inputs = torch.cat(inputs) # B x 1\n",
489 | "        targets = torch.cat(targets) # B x 1\n",
490 | "        negs = negative_sampling(targets,unigram_table,NEG)\n",
491 | "        model.zero_grad()\n",
492 | "\n",
493 | "        loss = model(inputs,targets,negs)\n",
494 | "        \n",
495 | "        loss.backward()\n",
496 | "        optimizer.step()\n",
497 | "        \n",
498 | "        losses.append(loss.data.tolist()[0])\n",
499 | "    if epoch % 10==0:\n",
500 | "        print(\"Epoch : %d, mean_loss : %.02f\" % (epoch,np.mean(losses)))\n",
501 | "        losses=[]"
502 | ]
503 | },
504 | {
505 | "cell_type": "markdown",
506 | "metadata": {},
507 | "source": [
508 | "## Test "
509 | ]
510 | },
511 | {
512 | "cell_type": "code",
513 | "execution_count": 71,
514 | "metadata": {
515 | "collapsed": true
516 | },
517 | "outputs": [],
518 | "source": [
519 | "def word_similarity(target,vocab):\n",
520 | "    if USE_CUDA:\n",
521 | "        target_V = model.prediction(prepare_word(target,word2index))\n",
522 | "    else:\n",
523 | "        target_V = model.prediction(prepare_word(target,word2index))\n",
524 | "    similarities=[]\n",
525 | "    for i in range(len(vocab)):\n",
526 | "        if vocab[i] == target: continue\n",
527 | "        \n",
528 | "        if USE_CUDA:\n",
529 | "            vector = model.prediction(prepare_word(list(vocab)[i],word2index))\n",
530 | "        else:\n",
531 | "            vector = model.prediction(prepare_word(list(vocab)[i],word2index))\n",
532 | "        \n",
533 | "        cosine_sim = F.cosine_similarity(target_V,vector).data.tolist()[0]\n",
534 | "        similarities.append([vocab[i],cosine_sim])\n",
535 | "    return sorted(similarities, key=lambda x: x[1], reverse=True)[:10]"
536 | ]
537 | },
538 | {
539 | "cell_type": "code",
540 | "execution_count": 212,
541 | "metadata": {
542 | "collapsed": false
543 | },
544 | "outputs": [
545 | {
546 | "data": {
547 | "text/plain": [
548 | "'passengers'"
549 | ]
550 | },
551 | "execution_count": 212,
552 | "metadata": {},
553 | "output_type": "execute_result"
554 | }
555 | ],
556 | "source": [
557 | "test = random.choice(list(vocab))\n",
558 | "test"
559 | ]
560 | },
561 | {
562 | "cell_type": "code",
563 | "execution_count": 213,
564 | "metadata": {
565 | "collapsed": false
566 | },
567 | "outputs": [
568 | {
569 | "data": {
570 | "text/plain": [
571 | "[['am', 0.7353377342224121],\n",
572 | " ['passenger', 0.7154150605201721],\n",
573 | " ['cook', 0.6829826831817627],\n",
574 | " ['new', 0.6648461818695068],\n",
575 | " ['bedford', 0.6283411383628845],\n",
576 | " ['besides', 0.5972960591316223],\n",
577 | " ['themselves', 0.5964340567588806],\n",
578 | " ['grow', 0.5957046151161194],\n",
579 | " ['tell', 0.5952941179275513],\n",
580 | " ['get', 0.5943044424057007]]"
581 | ]
582 | },
583 | "execution_count": 213,
584 | "metadata": {},
585 | "output_type": "execute_result"
586 | }
587 | ],
588 | "source": [
589 | "word_similarity(test,vocab)"
590 | ]
591 | },
592 | {
593 | "cell_type": "code",
594 | "execution_count": null,
595 | "metadata": {
596 | "collapsed": true
597 | },
598 | "outputs": [],
599 | "source": []
600 | }
601 | ],
602 | "metadata": {
603 | "kernelspec": {
604 | "display_name": "Python 3",
605 | "language": "python",
606 | "name": "python3"
607 | },
608 | "language_info": {
609 | "codemirror_mode": {
610 | "name": "ipython",
611 | "version": 3
612 | },
613 | "file_extension": ".py",
614 | "mimetype": "text/x-python",
615 | "name": "python",
616 | "nbconvert_exporter": "python",
617 | "pygments_lexer": "ipython3",
618 | "version": "3.5.2"
619 | }
620 | },
621 | "nbformat": 4,
622 | "nbformat_minor": 2
623 | }
624 |
--------------------------------------------------------------------------------
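Note: the following is a minimal, framework-free sketch of the two ideas the notebook above relies on: drawing negatives from a unigram table built from relative frequency raised to 0.75, and the per-pair negative-sampling loss. It is not the notebook's code. The helper names (`build_unigram_table`, `sample_negatives`, `sgns_pair_loss`) are hypothetical, and plain Python lists stand in for embedding vectors.

```python
import math
import random


def build_unigram_table(counts, power=0.75, z=0.001):
    """Repeat each word int((relative_freq ** power) / z) times."""
    total = sum(counts.values())
    table = []
    for word, count in counts.items():
        table.extend([word] * int(((count / total) ** power) / z))
    return table


def sample_negatives(table, target, k, rng=random):
    """Draw k words from the table, rejecting the positive target.

    Assumes the table holds at least one word other than the target;
    otherwise this loop would never terminate.
    """
    negatives = []
    while len(negatives) < k:
        word = rng.choice(table)
        if word != target:
            negatives.append(word)
    return negatives


def log_sigmoid(x):
    """Numerically stable log(sigmoid(x))."""
    if x >= 0:
        return -math.log1p(math.exp(-x))
    return x - math.log1p(math.exp(x))


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def sgns_pair_loss(v_center, u_target, u_negatives):
    """-[log s(u_o . v_c) + sum_k log s(-u_k . v_c)] for one (center, target) pair."""
    positive = log_sigmoid(dot(u_target, v_center))
    negative = sum(log_sigmoid(-dot(u, v_center)) for u in u_negatives)
    return -(positive + negative)
```

In this sketch the log-sigmoid is applied to each negative score before summing. With an all-zero center vector every score is 0, so the loss for K negatives is (1 + K) * log 2, which makes a convenient sanity check.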
/notebooks/03.GloVe.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "# 3. GloVe: Global Vectors for Word Representation"
8 | ]
9 | },
10 | {
11 | "cell_type": "markdown",
12 | "metadata": {},
13 | "source": [
14 | "I recommend you take a look at these materials first."
15 | ]
16 | },
17 | {
18 | "cell_type": "markdown",
19 | "metadata": {},
20 | "source": [
21 | "* http://web.stanford.edu/class/cs224n/lectures/cs224n-2017-lecture3.pdf\n",
22 | "* https://nlp.stanford.edu/pubs/glove.pdf"
23 | ]
24 | },
25 | {
26 | "cell_type": "code",
27 | "execution_count": 2,
28 | "metadata": {
29 | "collapsed": true
30 | },
31 | "outputs": [],
32 | "source": [
33 | "import torch\n",
34 | "import torch.nn as nn\n",
35 | "from torch.autograd import Variable\n",
36 | "import torch.optim as optim\n",
37 | "import torch.nn.functional as F\n",
38 | "import nltk\n",
39 | "import random\n",
40 | "import numpy as np\n",
41 | "from collections import Counter\n",
42 | "flatten = lambda l: [item for sublist in l for item in sublist]"
43 | ]
44 | },
45 | {
46 | "cell_type": "code",
47 | "execution_count": 2,
48 | "metadata": {
49 | "collapsed": false
50 | },
51 | "outputs": [
52 | {
53 | "name": "stdout",
54 | "output_type": "stream",
55 | "text": [
56 | "0.2.0+751198f\n",
57 | "3.2.4\n"
58 | ]
59 | }
60 | ],
61 | "source": [
62 | "print(torch.__version__)\n",
63 | "print(nltk.__version__)"
64 | ]
65 | },
66 | {
67 | "cell_type": "code",
68 | "execution_count": 3,
69 | "metadata": {
70 | "collapsed": true
71 | },
72 | "outputs": [],
73 | "source": [
74 | "USE_CUDA = torch.cuda.is_available()\n",
75 | "\n",
76 | "FloatTensor = torch.cuda.FloatTensor if USE_CUDA else torch.FloatTensor\n",
77 | "LongTensor = torch.cuda.LongTensor if USE_CUDA else torch.LongTensor\n",
78 | "ByteTensor = torch.cuda.ByteTensor if USE_CUDA else torch.ByteTensor"
79 | ]
80 | },
81 | {
82 | "cell_type": "code",
83 | "execution_count": 4,
84 | "metadata": {
85 | "collapsed": true
86 | },
87 | "outputs": [],
88 | "source": [
89 | "def getBatch(batch_size,train_data):\n",
90 | "    random.shuffle(train_data)\n",
91 | "    sindex=0\n",
92 | "    eindex=batch_size\n",
93 | "    while eindex < len(train_data):\n",
94 | "        batch = train_data[sindex:eindex]\n",
95 | "        temp = eindex\n",
96 | "        eindex = eindex+batch_size\n",
97 | "        sindex = temp\n",
98 | "        yield batch\n",
99 | "    \n",
100 | "    if eindex >= len(train_data):\n",
101 | "        batch = train_data[sindex:]\n",
102 | "        yield batch"
103 | ]
104 | },
105 | {
106 | "cell_type": "code",
107 | "execution_count": 5,
108 | "metadata": {
109 | "collapsed": true
110 | },
111 | "outputs": [],
112 | "source": [
113 | "def prepare_sequence(seq, word2index):\n",
114 | "    idxs = list(map(lambda w: word2index[w] if w in word2index.keys() else word2index[\"<UNK>\"], seq))\n",
115 | "    return Variable(LongTensor(idxs))\n",
116 | "\n",
117 | "def prepare_word(word,word2index):\n",
118 | "    return Variable(LongTensor([word2index[word]]) if word in word2index.keys() else LongTensor([word2index[\"<UNK>\"]]))"
119 | ]
120 | },
121 | {
122 | "cell_type": "markdown",
123 | "metadata": {},
124 | "source": [
125 | "## Data load and Preprocessing "
126 | ]
127 | },
128 | {
129 | "cell_type": "code",
130 | "execution_count": 6,
131 | "metadata": {
132 | "collapsed": true
133 | },
134 | "outputs": [],
135 | "source": [
136 | "corpus = list(nltk.corpus.gutenberg.sents('melville-moby_dick.txt'))[:500]\n",
137 | "corpus = [[word.lower() for word in sent] for sent in corpus]"
138 | ]
139 | },
140 | {
141 | "cell_type": "markdown",
142 | "metadata": {},
143 | "source": [
144 | "### Build vocab"
145 | ]
146 | },
147 | {
148 | "cell_type": "code",
149 | "execution_count": 7,
150 | "metadata": {
151 | "collapsed": true
152 | },
153 | "outputs": [],
154 | "source": [
155 | "vocab = list(set(flatten(corpus)))"
156 | ]
157 | },
158 | {
159 | "cell_type": "code",
160 | "execution_count": 8,
161 | "metadata": {
162 | "collapsed": true
163 | },
164 | "outputs": [],
165 | "source": [
166 | "word2index={}\n",
167 | "for vo in vocab:\n",
168 | " if vo not in word2index.keys():\n",
169 | " word2index[vo]=len(word2index)\n",
170 | " \n",
171 | "index2word={v:k for k,v in word2index.items()}"
172 | ]
173 | },
174 | {
175 | "cell_type": "code",
176 | "execution_count": 9,
177 | "metadata": {
178 | "collapsed": true
179 | },
180 | "outputs": [],
181 | "source": [
182 | "WINDOW_SIZE = 5\n",
183 | "windows = flatten([list(nltk.ngrams(['<DUMMY>']*WINDOW_SIZE+c+['<DUMMY>']*WINDOW_SIZE,WINDOW_SIZE*2+1)) for c in corpus])\n",
184 | "\n",
185 | "window_data = []\n",
186 | "\n",
187 | "for window in windows:\n",
188 | "    for i in range(WINDOW_SIZE*2+1):\n",
189 | "        if i==WINDOW_SIZE or window[i]=='<DUMMY>': continue\n",
190 | "        window_data.append((window[WINDOW_SIZE],window[i]))\n"
191 | ]
192 | },
193 | {
194 | "cell_type": "markdown",
195 | "metadata": {},
196 | "source": [
197 | "### Weighting Function "
198 | ]
199 | },
200 | {
201 | "cell_type": "markdown",
202 | "metadata": {},
203 | "source": [
204 | "<img src=\"../images/03.glove-weighting-function.png\">\n",
205 | "borrowed image from https://nlp.stanford.edu/pubs/glove.pdf"
206 | ]
207 | },
208 | {
209 | "cell_type": "code",
210 | "execution_count": 10,
211 | "metadata": {
212 | "collapsed": true
213 | },
214 | "outputs": [],
215 | "source": [
216 | "def weighting(w_i,w_j):\n",
217 | "    try:\n",
218 | "        x_ij = X_ik[(w_i,w_j)]\n",
219 | "    except KeyError: # pair never co-occurs\n",
220 | "        x_ij = 1\n",
221 | "    \n",
222 | "    x_max = 100 # fixed in paper\n",
223 | "    alpha = 0.75\n",
224 | "    \n",
225 | "    if x_ij < x_max:\n",
226 | "        result = (x_ij/x_max)**alpha\n",
227 | "    else:\n",
228 | "        result = 1\n",
229 | "    \n",
230 | "    return result"
231 | ]
232 | },
233 | {
234 | "cell_type": "markdown",
235 | "metadata": {},
236 | "source": [
237 | "### Build Co-occurrence Matrix X"
238 | ]
239 | },
240 | {
241 | "cell_type": "markdown",
242 | "metadata": {},
243 | "source": [
244 | "Because of model complexity, it is important to determine whether a tighter bound can be placed on the number of nonzero elements of X."
245 | ]
246 | },
247 | {
248 | "cell_type": "code",
249 | "execution_count": 11,
250 | "metadata": {
251 | "collapsed": true
252 | },
253 | "outputs": [],
254 | "source": [
255 | "X_i = Counter(flatten(corpus)) # X_i"
256 | ]
257 | },
258 | {
259 | "cell_type": "code",
260 | "execution_count": 12,
261 | "metadata": {
262 | "collapsed": true
263 | },
264 | "outputs": [],
265 | "source": [
266 | "X_ik_window_5 = Counter(window_data) # Co-occurrence within a window of size 5"
267 | ]
268 | },
269 | {
270 | "cell_type": "code",
271 | "execution_count": 13,
272 | "metadata": {
273 | "collapsed": true
274 | },
275 | "outputs": [],
276 | "source": [
277 | "X_ik={}\n",
278 | "weighting_dic={}"
279 | ]
280 | },
281 | {
282 | "cell_type": "code",
283 | "execution_count": 14,
284 | "metadata": {
285 | "collapsed": true
286 | },
287 | "outputs": [],
288 | "source": [
289 | "from itertools import combinations_with_replacement"
290 | ]
291 | },
292 | {
293 | "cell_type": "code",
294 | "execution_count": 15,
295 | "metadata": {
296 | "collapsed": true
297 | },
298 | "outputs": [],
299 | "source": [
300 | "for bigram in combinations_with_replacement(vocab, 2):\n",
301 | "    if bigram in X_ik_window_5.keys(): # nonzero elements\n",
302 | "        co_occer = X_ik_window_5[bigram]\n",
303 | "        X_ik[bigram]=co_occer+1 # log(Xik) -> log(Xik+1) to prevent divergence\n",
304 | "        X_ik[(bigram[1],bigram[0])]=co_occer+1\n",
305 | "    else:\n",
306 | "        pass\n",
307 | "    \n",
308 | "    weighting_dic[bigram] = weighting(bigram[0],bigram[1])\n",
309 | "    weighting_dic[(bigram[1],bigram[0])] = weighting(bigram[1],bigram[0])"
310 | ]
311 | },
312 | {
313 | "cell_type": "code",
314 | "execution_count": 16,
315 | "metadata": {
316 | "collapsed": false
317 | },
318 | "outputs": [
319 | {
320 | "name": "stdout",
321 | "output_type": "stream",
322 | "text": [
323 | "(',', 'was')\n",
324 | "True\n"
325 | ]
326 | }
327 | ],
328 | "source": [
329 | "test = random.choice(window_data)\n",
330 | "print(test)\n",
331 | "try:\n",
332 | " print(X_ik[(test[0],test[1])]==X_ik[(test[1],test[0])])\n",
333 | "except:\n",
334 | " 1"
335 | ]
336 | },
337 | {
338 | "cell_type": "markdown",
339 | "metadata": {},
340 | "source": [
341 | "### Prepare train data"
342 | ]
343 | },
344 | {
345 | "cell_type": "code",
346 | "execution_count": 17,
347 | "metadata": {
348 | "collapsed": false
349 | },
350 | "outputs": [
351 | {
352 | "name": "stdout",
353 | "output_type": "stream",
354 | "text": [
355 | "(Variable containing:\n",
356 | " 703\n",
357 | "[torch.cuda.LongTensor of size 1x1 (GPU 0)]\n",
358 | ", Variable containing:\n",
359 | " 23\n",
360 | "[torch.cuda.LongTensor of size 1x1 (GPU 0)]\n",
361 | ", Variable containing:\n",
362 | " 0.6931\n",
363 | "[torch.cuda.FloatTensor of size 1x1 (GPU 0)]\n",
364 | ", Variable containing:\n",
365 | "1.00000e-02 *\n",
366 | " 5.3183\n",
367 | "[torch.cuda.FloatTensor of size 1x1 (GPU 0)]\n",
368 | ")\n"
369 | ]
370 | }
371 | ],
372 | "source": [
373 | "u_p=[] # center vec\n",
374 | "v_p=[] # context vec\n",
375 | "co_p=[] # log(x_ij)\n",
376 | "weight_p=[] # f(x_ij)\n",
377 | "\n",
378 | "for pair in window_data: \n",
379 | "    u_p.append(prepare_word(pair[0],word2index).view(1,-1))\n",
380 | "    v_p.append(prepare_word(pair[1],word2index).view(1,-1))\n",
381 | "    \n",
382 | "    try:\n",
383 | "        cooc = X_ik[pair]\n",
384 | "    except KeyError:\n",
385 | "        cooc = 1\n",
386 | "\n",
387 | "    co_p.append(torch.log(Variable(FloatTensor([cooc]))).view(1,-1))\n",
388 | "    weight_p.append(Variable(FloatTensor([weighting_dic[pair]])).view(1,-1))\n",
389 | "    \n",
390 | "train_data = list(zip(u_p,v_p,co_p,weight_p))\n",
391 | "del u_p\n",
392 | "del v_p\n",
393 | "del co_p\n",
394 | "del weight_p\n",
395 | "print(train_data[0]) # tuple (center vec i, context vec j, log(x_ij), weight f(x_ij))"
396 | ]
397 | },
398 | {
399 | "cell_type": "markdown",
400 | "metadata": {},
401 | "source": [
402 | "## Modeling "
403 | ]
404 | },
405 | {
406 | "cell_type": "markdown",
407 | "metadata": {},
408 | "source": [
409 | "<img src=\"../images/03.glove-objective.png\">\n",
410 | "borrowed image from https://nlp.stanford.edu/pubs/glove.pdf"
411 | ]
412 | },
413 | {
414 | "cell_type": "code",
415 | "execution_count": 19,
416 | "metadata": {
417 | "collapsed": true
418 | },
419 | "outputs": [],
420 | "source": [
421 | "class GloVe(nn.Module):\n",
422 | "    \n",
423 | "    def __init__(self, vocab_size,projection_dim):\n",
424 | "        super(GloVe,self).__init__()\n",
425 | "        self.embedding_v = nn.Embedding(vocab_size, projection_dim) # center embedding\n",
426 | "        self.embedding_u = nn.Embedding(vocab_size, projection_dim) # out embedding\n",
427 | "        \n",
428 | "        self.v_bias = nn.Embedding(vocab_size,1)\n",
429 | "        self.u_bias = nn.Embedding(vocab_size,1)\n",
430 | "        \n",
431 | "        initrange = (2.0 / (vocab_size+projection_dim))**0.5 # Xavier init\n",
432 | "        self.embedding_v.weight.data.uniform_(-initrange, initrange) # init\n",
433 | "        self.embedding_u.weight.data.uniform_(-initrange, initrange) # init\n",
434 | "        self.v_bias.weight.data.uniform_(-initrange, initrange) # init\n",
435 | "        self.u_bias.weight.data.uniform_(-initrange, initrange) # init\n",
436 | "        \n",
437 | "    def forward(self, center_words,target_words,coocs,weights):\n",
438 | "        center_embeds = self.embedding_v(center_words) # B x 1 x D\n",
439 | "        target_embeds = self.embedding_u(target_words) # B x 1 x D\n",
440 | "        \n",
441 | "        center_bias = self.v_bias(center_words).squeeze(1)\n",
442 | "        target_bias = self.u_bias(target_words).squeeze(1)\n",
443 | "        \n",
444 | "        inner_product = target_embeds.bmm(center_embeds.transpose(1,2)).squeeze(2) # Bx1\n",
445 | "        \n",
446 | "        loss = weights*torch.pow(inner_product +center_bias + target_bias - coocs,2)\n",
447 | "        \n",
448 | "        return torch.sum(loss)\n",
449 | "    \n",
450 | "    def prediction(self, inputs):\n",
451 | "        v_embeds = self.embedding_v(inputs) # B x 1 x D\n",
452 | "        u_embeds = self.embedding_u(inputs) # B x 1 x D\n",
453 | "        \n",
454 | "        return v_embeds+u_embeds # final embed"
455 | ]
456 | },
457 | {
458 | "cell_type": "markdown",
459 | "metadata": {},
460 | "source": [
461 | "## Train "
462 | ]
463 | },
464 | {
465 | "cell_type": "code",
466 | "execution_count": 22,
467 | "metadata": {
468 | "collapsed": true
469 | },
470 | "outputs": [],
471 | "source": [
472 | "EMBEDDING_SIZE = 50\n",
473 | "BATCH_SIZE = 256\n",
474 | "EPOCH = 50"
475 | ]
476 | },
477 | {
478 | "cell_type": "code",
479 | "execution_count": 23,
480 | "metadata": {
481 | "collapsed": false
482 | },
483 | "outputs": [],
484 | "source": [
485 | "losses = []\n",
486 | "model = GloVe(len(word2index),EMBEDDING_SIZE)\n",
487 | "if USE_CUDA:\n",
488 | " model = model.cuda()\n",
489 | "optimizer = optim.Adam(model.parameters(), lr=0.001)"
490 | ]
491 | },
492 | {
493 | "cell_type": "code",
494 | "execution_count": 24,
495 | "metadata": {
496 | "collapsed": false
497 | },
498 | "outputs": [
499 | {
500 | "name": "stdout",
501 | "output_type": "stream",
502 | "text": [
503 | "Epoch : 0, mean_loss : 236.10\n",
504 | "Epoch : 10, mean_loss : 2.27\n",
505 | "Epoch : 20, mean_loss : 0.53\n",
506 | "Epoch : 30, mean_loss : 0.12\n",
507 | "Epoch : 40, mean_loss : 0.04\n"
508 | ]
509 | }
510 | ],
511 | "source": [
512 | "for epoch in range(EPOCH):\n",
513 | "    for i,batch in enumerate(getBatch(BATCH_SIZE,train_data)):\n",
514 | "        \n",
515 | "        inputs, targets, coocs, weights = zip(*batch)\n",
516 | "        \n",
517 | "        inputs = torch.cat(inputs) # B x 1\n",
518 | "        targets = torch.cat(targets) # B x 1\n",
519 | "        coocs = torch.cat(coocs)\n",
520 | "        weights = torch.cat(weights)\n",
521 | "        model.zero_grad()\n",
522 | "\n",
523 | "        loss = model(inputs,targets,coocs,weights)\n",
524 | "        \n",
525 | "        loss.backward()\n",
526 | "        optimizer.step()\n",
527 | "        \n",
528 | "        losses.append(loss.data.tolist()[0])\n",
529 | "    if epoch % 10==0:\n",
530 | "        print(\"Epoch : %d, mean_loss : %.02f\" % (epoch,np.mean(losses)))\n",
531 | "        losses=[]"
532 | ]
533 | },
534 | {
535 | "cell_type": "markdown",
536 | "metadata": {},
537 | "source": [
538 | "## Test "
539 | ]
540 | },
541 | {
542 | "cell_type": "code",
543 | "execution_count": 25,
544 | "metadata": {
545 | "collapsed": true
546 | },
547 | "outputs": [],
548 | "source": [
549 | "def word_similarity(target,vocab):\n",
550 | "    if USE_CUDA:\n",
551 | "        target_V = model.prediction(prepare_word(target,word2index))\n",
552 | "    else:\n",
553 | "        target_V = model.prediction(prepare_word(target,word2index))\n",
554 | "    similarities=[]\n",
555 | "    for i in range(len(vocab)):\n",
556 | "        if vocab[i] == target: continue\n",
557 | "        \n",
558 | "        if USE_CUDA:\n",
559 | "            vector = model.prediction(prepare_word(list(vocab)[i],word2index))\n",
560 | "        else:\n",
561 | "            vector = model.prediction(prepare_word(list(vocab)[i],word2index))\n",
562 | "        \n",
563 | "        cosine_sim = F.cosine_similarity(target_V,vector).data.tolist()[0] \n",
564 | "        similarities.append([vocab[i],cosine_sim])\n",
565 | "    return sorted(similarities, key=lambda x: x[1], reverse=True)[:10]"
566 | ]
567 | },
568 | {
569 | "cell_type": "code",
570 | "execution_count": 86,
571 | "metadata": {
572 | "collapsed": false
573 | },
574 | "outputs": [
575 | {
576 | "data": {
577 | "text/plain": [
578 | "'spiral'"
579 | ]
580 | },
581 | "execution_count": 86,
582 | "metadata": {},
583 | "output_type": "execute_result"
584 | }
585 | ],
586 | "source": [
587 | "test = random.choice(list(vocab))\n",
588 | "test"
589 | ]
590 | },
591 | {
592 | "cell_type": "code",
593 | "execution_count": 87,
594 | "metadata": {
595 | "collapsed": false
596 | },
597 | "outputs": [
598 | {
599 | "data": {
600 | "text/plain": [
601 | "[['horns', 0.9727935194969177],\n",
602 | " ['swords', 0.9076412916183472],\n",
603 | " ['hooked', 0.8984033465385437],\n",
604 | " ['thar', 0.8066437244415283],\n",
605 | " ['montaigne', 0.8062068819999695],\n",
606 | " ['rabelais', 0.789764940738678],\n",
607 | " ['orion', 0.7886737585067749],\n",
608 | " ['isaiah', 0.780662477016449],\n",
609 | " ['hamlet', 0.7799868583679199],\n",
610 | " ['colnett', 0.7792885899543762]]"
611 | ]
612 | },
613 | "execution_count": 87,
614 | "metadata": {},
615 | "output_type": "execute_result"
616 | }
617 | ],
618 | "source": [
619 | "word_similarity(test,vocab)"
620 | ]
621 | },
622 | {
623 | "cell_type": "markdown",
624 | "metadata": {
625 | "collapsed": true
626 | },
627 | "source": [
628 | "## TODO"
629 | ]
630 | },
631 | {
632 | "cell_type": "markdown",
633 | "metadata": {},
634 | "source": [
635 | "* Use a sparse matrix to build the co-occurrence matrix for memory efficiency"
636 | ]
637 | },
638 | {
639 | "cell_type": "markdown",
640 | "metadata": {},
641 | "source": [
642 | "## Suggested Readings"
643 | ]
644 | },
645 | {
646 | "cell_type": "markdown",
647 | "metadata": {},
648 | "source": [
649 | "* Word embeddings in 2017: Trends and future directions"
650 | ]
651 | }
652 | ],
653 | "metadata": {
654 | "kernelspec": {
655 | "display_name": "Python 3",
656 | "language": "python",
657 | "name": "python3"
658 | },
659 | "language_info": {
660 | "codemirror_mode": {
661 | "name": "ipython",
662 | "version": 3
663 | },
664 | "file_extension": ".py",
665 | "mimetype": "text/x-python",
666 | "name": "python",
667 | "nbconvert_exporter": "python",
668 | "pygments_lexer": "ipython3",
669 | "version": "3.5.2"
670 | }
671 | },
672 | "nbformat": 4,
673 | "nbformat_minor": 2
674 | }
675 |
--------------------------------------------------------------------------------
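Note: as a small worked example of the weighting function and objective used in the GloVe notebook above, the sketch below computes the weight f(x) and the weighted squared error for a single word pair in plain Python. It is an illustration under stated assumptions, not the notebook's implementation: the function names `glove_weight` and `glove_pair_loss` are hypothetical, and lists of floats stand in for embedding vectors.

```python
import math


def glove_weight(x, x_max=100.0, alpha=0.75):
    """f(x) = (x / x_max) ** alpha below x_max, capped at 1 otherwise."""
    return (x / x_max) ** alpha if x < x_max else 1.0


def glove_pair_loss(w_i, w_j, b_i, b_j, x_ij, x_max=100.0, alpha=0.75):
    """f(x_ij) * (w_i . w_j + b_i + b_j - log x_ij) ** 2 for one pair.

    x_ij must be positive because of the logarithm.
    """
    prediction = sum(a * b for a, b in zip(w_i, w_j)) + b_i + b_j
    return glove_weight(x_ij, x_max, alpha) * (prediction - math.log(x_ij)) ** 2
```

For the sample printed in the notebook, a count of 2 gives log(2) ≈ 0.6931 and a weight of (2/100)^0.75 ≈ 0.0532, which agrees with the displayed tensors.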
/notebooks/04.Window-Classifier-for-NER.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "# 4. Word Window Classification and Neural Networks "
8 | ]
9 | },
10 | {
11 | "cell_type": "markdown",
12 | "metadata": {},
13 | "source": [
14 | "I recommend you take a look at these materials first."
15 | ]
16 | },
17 | {
18 | "cell_type": "markdown",
19 | "metadata": {},
20 | "source": [
21 | "* http://web.stanford.edu/class/cs224n/lectures/cs224n-2017-lecture4.pdf\n",
22 | "* https://en.wikipedia.org/wiki/Named-entity_recognition"
23 | ]
24 | },
25 | {
26 | "cell_type": "code",
27 | "execution_count": 1,
28 | "metadata": {
29 | "collapsed": true
30 | },
31 | "outputs": [],
32 | "source": [
33 | "import torch\n",
34 | "import torch.nn as nn\n",
35 | "from torch.autograd import Variable\n",
36 | "import torch.optim as optim\n",
37 | "import torch.nn.functional as F\n",
38 | "import nltk\n",
39 | "import random\n",
40 | "import numpy as np\n",
41 | "from collections import Counter\n",
42 | "flatten = lambda l: [item for sublist in l for item in sublist]\n",
43 | "from sklearn_crfsuite import metrics"
44 | ]
45 | },
46 | {
47 | "cell_type": "markdown",
48 | "metadata": {},
49 | "source": [
50 | "You also need the latest version of sklearn_crfsuite to print the confusion matrix."
51 | ]
52 | },
53 | {
54 | "cell_type": "code",
55 | "execution_count": 2,
56 | "metadata": {
57 | "collapsed": false
58 | },
59 | "outputs": [
60 | {
61 | "name": "stdout",
62 | "output_type": "stream",
63 | "text": [
64 | "0.2.0+751198f\n",
65 | "3.2.4\n"
66 | ]
67 | }
68 | ],
69 | "source": [
70 | "print(torch.__version__)\n",
71 | "print(nltk.__version__)"
72 | ]
73 | },
74 | {
75 | "cell_type": "code",
76 | "execution_count": 3,
77 | "metadata": {
78 | "collapsed": true
79 | },
80 | "outputs": [],
81 | "source": [
82 | "USE_CUDA = torch.cuda.is_available()\n",
83 | "\n",
84 | "FloatTensor = torch.cuda.FloatTensor if USE_CUDA else torch.FloatTensor\n",
85 | "LongTensor = torch.cuda.LongTensor if USE_CUDA else torch.LongTensor\n",
86 | "ByteTensor = torch.cuda.ByteTensor if USE_CUDA else torch.ByteTensor"
87 | ]
88 | },
89 | {
90 | "cell_type": "code",
91 | "execution_count": 4,
92 | "metadata": {
93 | "collapsed": true
94 | },
95 | "outputs": [],
96 | "source": [
97 | "def getBatch(batch_size,train_data):\n",
98 | "    random.shuffle(train_data)\n",
99 | "    sindex=0\n",
100 | "    eindex=batch_size\n",
101 | "    while eindex < len(train_data):\n",
102 | "        batch = train_data[sindex:eindex]\n",
103 | "        temp = eindex\n",
104 | "        eindex = eindex+batch_size\n",
105 | "        sindex = temp\n",
106 | "        yield batch\n",
107 | "    \n",
108 | "    if eindex >= len(train_data):\n",
109 | "        batch = train_data[sindex:]\n",
110 | "        yield batch"
111 | ]
112 | },
113 | {
114 | "cell_type": "code",
115 | "execution_count": 5,
116 | "metadata": {
117 | "collapsed": true
118 | },
119 | "outputs": [],
120 | "source": [
121 | "def prepare_sequence(seq, word2index):\n",
122 | "    idxs = list(map(lambda w: word2index[w] if w in word2index.keys() else word2index[\"<UNK>\"], seq))\n",
123 | "    return Variable(LongTensor(idxs))\n",
124 | "\n",
125 | "def prepare_word(word,word2index):\n",
126 | "    return Variable(LongTensor([word2index[word]]) if word in word2index.keys() else LongTensor([word2index[\"<UNK>\"]]))\n",
127 | "\n",
128 | "def prepare_tag(tag,tag2index):\n",
129 | "    return Variable(LongTensor([tag2index[tag]]))"
130 | ]
131 | },
132 | {
133 | "cell_type": "markdown",
134 | "metadata": {},
135 | "source": [
136 | "## Data load and Preprocessing "
137 | ]
138 | },
139 | {
140 | "cell_type": "markdown",
141 | "metadata": {},
142 | "source": [
143 | "CoNLL-2002 Shared Task: Language-Independent Named Entity Recognition<br>\n",
144 | "https://www.clips.uantwerpen.be/conll2002/ner/"
145 | ]
146 | },
147 | {
148 | "cell_type": "code",
149 | "execution_count": 6,
150 | "metadata": {
151 | "collapsed": true
152 | },
153 | "outputs": [],
154 | "source": [
155 | "corpus = nltk.corpus.conll2002.iob_sents()"
156 | ]
157 | },
158 | {
159 | "cell_type": "code",
160 | "execution_count": 7,
161 | "metadata": {
162 | "collapsed": true
163 | },
164 | "outputs": [],
165 | "source": [
166 | "data=[]\n",
167 | "for cor in corpus:\n",
168 | "    sent,_,tag = list(zip(*cor))\n",
169 | "    data.append([sent,tag])"
170 | ]
171 | },
172 | {
173 | "cell_type": "code",
174 | "execution_count": 8,
175 | "metadata": {
176 | "collapsed": false
177 | },
178 | "outputs": [
179 | {
180 | "name": "stdout",
181 | "output_type": "stream",
182 | "text": [
183 | "35651\n",
184 | "[('Sao', 'Paulo', '(', 'Brasil', ')', ',', '23', 'may', '(', 'EFECOM', ')', '.'), ('B-LOC', 'I-LOC', 'O', 'B-LOC', 'O', 'O', 'O', 'O', 'O', 'B-ORG', 'O', 'O')]\n"
185 | ]
186 | }
187 | ],
188 | "source": [
189 | "print(len(data))\n",
190 | "print(data[0])"
191 | ]
192 | },
193 | {
194 | "cell_type": "markdown",
195 | "metadata": {},
196 | "source": [
197 | "### Build Vocab"
198 | ]
199 | },
200 | {
201 | "cell_type": "code",
202 | "execution_count": 9,
203 | "metadata": {
204 | "collapsed": true
205 | },
206 | "outputs": [],
207 | "source": [
208 | "sents,tags = list(zip(*data))\n",
209 | "vocab = list(set(flatten(sents)))\n",
210 | "tagset = list(set(flatten(tags)))"
211 | ]
212 | },
213 | {
214 | "cell_type": "code",
215 | "execution_count": 10,
216 | "metadata": {
217 | "collapsed": true
218 | },
219 | "outputs": [],
220 | "source": [
221 | "word2index={'<UNK>' : 0, '<DUMMY>' : 1} # the dummy token pads the start and end of each sentence\n",
222 | "for vo in vocab:\n",
223 | "    if vo not in word2index.keys():\n",
224 | "        word2index[vo]=len(word2index)\n",
225 | "index2word = {v:k for k,v in word2index.items()}\n",
226 | "\n",
227 | "tag2index = {}\n",
228 | "for tag in tagset:\n",
229 | "    if tag not in tag2index.keys():\n",
230 | "        tag2index[tag]=len(tag2index)\n",
231 | "index2tag={v:k for k,v in tag2index.items()}"
232 | ]
233 | },
234 | {
235 | "cell_type": "markdown",
236 | "metadata": {},
237 | "source": [
238 | "### Prepare data"
239 | ]
240 | },
241 | {
242 | "cell_type": "markdown",
243 | "metadata": {},
244 | "source": [
245 | "Example : Classify 'Paris' in the context of this sentence with window length 2"
246 | ]
247 | },
248 | {
249 | "cell_type": "markdown",
250 | "metadata": {},
251 | "source": [
252 | "<img src=\"../images/04.window-data.png\">"
253 | ]
254 | },
255 | {
256 | "cell_type": "markdown",
257 | "metadata": {},
258 | "source": [
259 | "borrowed image from http://web.stanford.edu/class/cs224n/lectures/cs224n-2017-lecture4.pdf"
260 | ]
261 | },
262 | {
263 | "cell_type": "code",
264 | "execution_count": 11,
265 | "metadata": {
266 | "collapsed": true
267 | },
268 | "outputs": [],
269 | "source": [
270 | "WINDOW_SIZE=2\n",
271 | "windows=[]"
272 | ]
273 | },
274 | {
275 | "cell_type": "code",
276 | "execution_count": 12,
277 | "metadata": {
278 | "collapsed": true
279 | },
280 | "outputs": [],
281 | "source": [
282 | "for sample in data:\n",
283 | "    dummy=['<DUMMY>']*WINDOW_SIZE\n",
284 | "    window = list(nltk.ngrams(dummy+list(sample[0])+dummy,WINDOW_SIZE*2+1))\n",
285 | "    windows.extend([[list(window[i]),sample[1][i]] for i in range(len(sample[0]))])"
286 | ]
287 | },
288 | {
289 | "cell_type": "code",
290 | "execution_count": 13,
291 | "metadata": {
292 | "collapsed": false
293 | },
294 | "outputs": [
295 | {
296 | "data": {
297 | "text/plain": [
298 | "[['<DUMMY>', '<DUMMY>', 'Sao', 'Paulo', '('], 'B-LOC']"
299 | ]
300 | },
301 | "execution_count": 13,
302 | "metadata": {},
303 | "output_type": "execute_result"
304 | }
305 | ],
306 | "source": [
307 | "windows[0]"
308 | ]
309 | },
310 | {
311 | "cell_type": "code",
312 | "execution_count": 14,
313 | "metadata": {
314 | "collapsed": false
315 | },
316 | "outputs": [
317 | {
318 | "data": {
319 | "text/plain": [
320 | "678377"
321 | ]
322 | },
323 | "execution_count": 14,
324 | "metadata": {},
325 | "output_type": "execute_result"
326 | }
327 | ],
328 | "source": [
329 | "len(windows)"
330 | ]
331 | },
332 | {
333 | "cell_type": "code",
334 | "execution_count": 15,
335 | "metadata": {
336 | "collapsed": true
337 | },
338 | "outputs": [],
339 | "source": [
340 | "random.shuffle(windows)\n",
341 | "\n",
342 | "train_data = windows[:int(len(windows)*0.9)]\n",
343 | "test_data = windows[int(len(windows)*0.9):]"
344 | ]
345 | },
346 | {
347 | "cell_type": "markdown",
348 | "metadata": {},
349 | "source": [
350 | "## Modeling "
351 | ]
352 | },
353 | {
354 | "cell_type": "markdown",
355 | "metadata": {},
356 | "source": [
357 | "<img src=\"../images/04.window-classifier-architecture.png\">\n",
358 | "borrowed image from http://web.stanford.edu/class/cs224n/lectures/cs224n-2017-lecture4.pdf"
359 | ]
360 | },
361 | {
362 | "cell_type": "code",
363 | "execution_count": 16,
364 | "metadata": {
365 | "collapsed": true
366 | },
367 | "outputs": [],
368 | "source": [
369 | "class WindowClassifier(nn.Module): \n",
370 | "    def __init__(self,vocab_size,embedding_size,window_size,hidden_size,output_size):\n",
371 | "\n",
372 | "        super(WindowClassifier, self).__init__()\n",
373 | "        \n",
374 | "        self.embed = nn.Embedding(vocab_size,embedding_size)\n",
375 | "        self.h_layer1 = nn.Linear(embedding_size*(window_size*2+1), hidden_size)\n",
376 | "        self.h_layer2 = nn.Linear(hidden_size, hidden_size)\n",
377 | "        self.o_layer = nn.Linear(hidden_size,output_size)\n",
378 | "        self.relu = nn.ReLU()\n",
379 | "        self.softmax = nn.LogSoftmax()\n",
380 | "        self.dropout = nn.Dropout(0.3)\n",
381 | "        \n",
382 | "    def forward(self, inputs,is_training=False): \n",
383 | "        embeds = self.embed(inputs) # BxWxD\n",
384 | "        concated = embeds.view(-1,embeds.size(1)*embeds.size(2)) # Bx(W*D)\n",
385 | "        h0 = self.relu(self.h_layer1(concated))\n",
386 | "        if is_training:\n",
387 | "            h0 = self.dropout(h0)\n",
388 | "        h1 = self.relu(self.h_layer2(h0))\n",
389 | "        if is_training:\n",
390 | "            h1 = self.dropout(h1)\n",
391 | "        out = self.softmax(self.o_layer(h1))\n",
392 | "        return out"
393 | ]
394 | },
395 | {
396 | "cell_type": "code",
397 | "execution_count": 20,
398 | "metadata": {
399 | "collapsed": true
400 | },
401 | "outputs": [],
402 | "source": [
403 | "BATCH_SIZE=128\n",
404 | "EMBEDDING_SIZE=50 # x (WINDOW_SIZE*2+1) = 250\n",
405 | "HIDDEN_SIZE=300\n",
406 | "EPOCH=3\n",
407 | "LEARNING_RATE = 0.001"
408 | ]
409 | },
410 | {
411 | "cell_type": "markdown",
412 | "metadata": {},
413 | "source": [
414 | "## Training "
415 | ]
416 | },
417 | {
418 | "cell_type": "markdown",
419 | "metadata": {},
420 | "source": [
421 | "Training takes a while if you use only the CPU."
422 | ]
423 | },
424 | {
425 | "cell_type": "code",
426 | "execution_count": 22,
427 | "metadata": {
428 | "collapsed": true
429 | },
430 | "outputs": [],
431 | "source": [
432 | "model = WindowClassifier(len(word2index),EMBEDDING_SIZE,WINDOW_SIZE,HIDDEN_SIZE,len(tag2index))\n",
433 | "if USE_CUDA:\n",
434 | "    model = model.cuda()\n",
435 | "loss_function = nn.CrossEntropyLoss()\n",
436 | "optimizer = optim.Adam(model.parameters(),lr=LEARNING_RATE)"
437 | ]
438 | },
439 | {
440 | "cell_type": "code",
441 | "execution_count": 23,
442 | "metadata": {
443 | "collapsed": false
444 | },
445 | "outputs": [
446 | {
447 | "name": "stdout",
448 | "output_type": "stream",
449 | "text": [
450 | "[0/3] mean_loss : 2.25\n",
451 | "[0/3] mean_loss : 0.47\n",
452 | "[0/3] mean_loss : 0.36\n",
453 | "[0/3] mean_loss : 0.31\n",
454 | "[0/3] mean_loss : 0.28\n",
455 | "[1/3] mean_loss : 0.22\n",
456 | "[1/3] mean_loss : 0.21\n",
457 | "[1/3] mean_loss : 0.21\n",
458 | "[1/3] mean_loss : 0.19\n",
459 | "[1/3] mean_loss : 0.19\n",
460 | "[2/3] mean_loss : 0.12\n",
461 | "[2/3] mean_loss : 0.15\n",
462 | "[2/3] mean_loss : 0.15\n",
463 | "[2/3] mean_loss : 0.14\n",
464 | "[2/3] mean_loss : 0.14\n"
465 | ]
466 | }
467 | ],
468 | "source": [
469 | "for epoch in range(EPOCH):\n",
470 | "    losses=[]\n",
471 | "    for i,batch in enumerate(getBatch(BATCH_SIZE,train_data)):\n",
472 | "        x,y=list(zip(*batch))\n",
473 | "        inputs = torch.cat([prepare_sequence(sent,word2index).view(1,-1) for sent in x])\n",
474 | "        targets = torch.cat([prepare_tag(tag,tag2index) for tag in y])\n",
475 | "        model.zero_grad()\n",
476 | "        preds = model(inputs,is_training=True)\n",
477 | "        loss = loss_function(preds,targets)\n",
478 | "        losses.append(loss.data.tolist()[0])\n",
479 | "        loss.backward()\n",
480 | "        optimizer.step()\n",
481 | "\n",
482 | "        if i % 1000==0:\n",
483 | "            print(\"[%d/%d] mean_loss : %0.2f\" %(epoch,EPOCH,np.mean(losses)))\n",
484 | "            losses=[]"
485 | ]
486 | },
487 | {
488 | "cell_type": "markdown",
489 | "metadata": {},
490 | "source": [
491 | "## Test "
492 | ]
493 | },
494 | {
495 | "cell_type": "code",
496 | "execution_count": 24,
497 | "metadata": {
498 | "collapsed": true
499 | },
500 | "outputs": [],
501 | "source": [
502 | "for_f1_score=[]"
503 | ]
504 | },
505 | {
506 | "cell_type": "code",
507 | "execution_count": 25,
508 | "metadata": {
509 | "collapsed": false
510 | },
511 | "outputs": [
512 | {
513 | "name": "stdout",
514 | "output_type": "stream",
515 | "text": [
516 | "95.69120551903063\n"
517 | ]
518 | }
519 | ],
520 | "source": [
521 | "accuracy=0\n",
522 | "for test in test_data:\n",
523 | "    x,y = test[0],test[1]\n",
524 | "    input_ = prepare_sequence(x,word2index).view(1,-1)\n",
525 | "\n",
526 | "    i = model(input_).max(1)[1]\n",
527 | "    pred = index2tag[i.data.tolist()[0]]\n",
528 | "    for_f1_score.append([pred,y])\n",
529 | "    if pred==y:\n",
530 | "        accuracy+=1\n",
531 | "\n",
532 | "print(accuracy/len(test_data)*100)"
533 | ]
534 | },
535 | {
536 | "cell_type": "markdown",
537 | "metadata": {},
538 | "source": [
539 | "The accuracy is this high because most of the labels are the 'O' tag, so we need to measure the F1 score on the entity tags instead."
540 | ]
541 | },
542 | {
543 | "cell_type": "markdown",
544 | "metadata": {},
545 | "source": [
546 | "### Print classification report "
547 | ]
548 | },
549 | {
550 | "cell_type": "code",
551 | "execution_count": 26,
552 | "metadata": {
553 | "collapsed": true
554 | },
555 | "outputs": [],
556 | "source": [
557 | "y_pred, y_test = list(zip(*for_f1_score))"
558 | ]
559 | },
560 | {
561 | "cell_type": "code",
562 | "execution_count": 27,
563 | "metadata": {
564 | "collapsed": true
565 | },
566 | "outputs": [],
567 | "source": [
568 | "sorted_labels = sorted(\n",
569 | "    list(set(y_test)-{'O'}),\n",
570 | "    key=lambda name: (name[1:], name[0])\n",
571 | ")"
572 | ]
573 | },
574 | {
575 | "cell_type": "code",
576 | "execution_count": 28,
577 | "metadata": {
578 | "collapsed": false
579 | },
580 | "outputs": [
581 | {
582 | "data": {
583 | "text/plain": [
584 | "['B-LOC', 'I-LOC', 'B-MISC', 'I-MISC', 'B-ORG', 'I-ORG', 'B-PER', 'I-PER']"
585 | ]
586 | },
587 | "execution_count": 28,
588 | "metadata": {},
589 | "output_type": "execute_result"
590 | }
591 | ],
592 | "source": [
593 | "sorted_labels"
594 | ]
595 | },
596 | {
597 | "cell_type": "code",
598 | "execution_count": 29,
599 | "metadata": {
600 | "collapsed": true
601 | },
602 | "outputs": [],
603 | "source": [
604 | "y_pred = [[y] for y in y_pred] # wrap each label in a list because the sklearn_crfsuite.metrics functions flatten their inputs\n",
605 | "y_test = [[y] for y in y_test]"
606 | ]
607 | },
608 | {
609 | "cell_type": "code",
610 | "execution_count": 30,
611 | "metadata": {
612 | "collapsed": false
613 | },
614 | "outputs": [
615 | {
616 | "name": "stdout",
617 | "output_type": "stream",
618 | "text": [
619 | " precision recall f1-score support\n",
620 | "\n",
621 | " B-LOC 0.802 0.636 0.710 1085\n",
622 | " I-LOC 0.732 0.457 0.562 311\n",
623 | " B-MISC 0.750 0.378 0.503 801\n",
624 | " I-MISC 0.679 0.331 0.445 641\n",
625 | " B-ORG 0.723 0.738 0.730 1430\n",
626 | " I-ORG 0.710 0.700 0.705 969\n",
627 | " B-PER 0.782 0.773 0.777 1268\n",
628 | " I-PER 0.853 0.871 0.861 950\n",
629 | "\n",
630 | "avg / total 0.759 0.656 0.693 7455\n",
631 | "\n"
632 | ]
633 | }
634 | ],
635 | "source": [
636 | "print(metrics.flat_classification_report(\n",
637 | "    y_test, y_pred, labels=sorted_labels, digits=3\n",
638 | "))"
639 | ]
640 | },
641 | {
642 | "cell_type": "markdown",
643 | "metadata": {
644 | "collapsed": true
645 | },
646 | "source": [
647 | "### TODO"
648 | ]
649 | },
650 | {
651 | "cell_type": "markdown",
652 | "metadata": {},
653 | "source": [
654 | "* use max-margin objective function http://pytorch.org/docs/master/nn.html#multilabelmarginloss"
655 | ]
656 | },
657 | {
658 | "cell_type": "code",
659 | "execution_count": null,
660 | "metadata": {
661 | "collapsed": true
662 | },
663 | "outputs": [],
664 | "source": []
665 | }
666 | ],
667 | "metadata": {
668 | "kernelspec": {
669 | "display_name": "Python 3",
670 | "language": "python",
671 | "name": "python3"
672 | },
673 | "language_info": {
674 | "codemirror_mode": {
675 | "name": "ipython",
676 | "version": 3
677 | },
678 | "file_extension": ".py",
679 | "mimetype": "text/x-python",
680 | "name": "python",
681 | "nbconvert_exporter": "python",
682 | "pygments_lexer": "ipython3",
683 | "version": "3.5.2"
684 | }
685 | },
686 | "nbformat": 4,
687 | "nbformat_minor": 2
688 | }
689 |
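
Note on the "Prepare data" cell in the notebook above: the window construction can be sketched without nltk as follows. This is only an illustrative sketch, not the notebook's code; the helper name `make_windows` and the `<DUMMY>` padding token name are assumptions made for the example.

```python
# Illustrative sketch: pure-Python version of the centered-window construction.
# `make_windows` and the "<DUMMY>" padding token are assumed names for this example.
def make_windows(sentence, tags, window_size=2, pad="<DUMMY>"):
    """Return [[window_tokens, tag], ...] with one centered window per token."""
    # Pad both ends so that the first and last tokens also get a full window.
    padded = [pad] * window_size + list(sentence) + [pad] * window_size
    width = 2 * window_size + 1
    # The window starting at padded[i] is centered on sentence[i].
    return [[padded[i:i + width], tags[i]] for i in range(len(sentence))]


sent = ("Sao", "Paulo", "(", "Brasil", ")")
tags = ("B-LOC", "I-LOC", "O", "B-LOC", "O")
windows = make_windows(sent, tags)
print(windows[0])
```

Each sentence of length n yields exactly n windows of 2 * window_size + 1 tokens, which is why the first hidden layer of the classifier takes embedding_size * (window_size * 2 + 1) inputs.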
--------------------------------------------------------------------------------
/notebooks/06.RNN-Language-Model.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "# 6. Recurrent Neural Networks and Language Models"
8 | ]
9 | },
10 | {
11 | "cell_type": "markdown",
12 | "metadata": {},
13 | "source": [
14 | "* http://web.stanford.edu/class/cs224n/lectures/cs224n-2017-lecture8.pdf\n",
15 | "* http://web.stanford.edu/class/cs224n/lectures/cs224n-2017-lecture9.pdf\n",
16 | "* http://colah.github.io/posts/2015-08-Understanding-LSTMs/\n",
17 | "* https://github.com/pytorch/examples/tree/master/word_language_model\n",
18 | "* https://github.com/yunjey/pytorch-tutorial/blob/master/tutorials/02-intermediate/language_model"
19 | ]
20 | },
21 | {
22 | "cell_type": "code",
23 | "execution_count": 1,
24 | "metadata": {
25 | "collapsed": true
26 | },
27 | "outputs": [],
28 | "source": [
29 | "import torch\n",
30 | "import torch.nn as nn\n",
31 | "from torch.autograd import Variable\n",
32 | "import torch.optim as optim\n",
33 | "import torch.nn.functional as F\n",
34 | "import nltk\n",
35 | "import random\n",
36 | "import numpy as np\n",
37 | "from collections import Counter, OrderedDict\n",
38 | "import nltk\n",
39 | "from copy import deepcopy\n",
40 | "flatten = lambda l: [item for sublist in l for item in sublist]"
41 | ]
42 | },
43 | {
44 | "cell_type": "code",
45 | "execution_count": 2,
46 | "metadata": {
47 | "collapsed": true
48 | },
49 | "outputs": [],
50 | "source": [
51 | "USE_CUDA = torch.cuda.is_available()\n",
52 | "\n",
53 | "FloatTensor = torch.cuda.FloatTensor if USE_CUDA else torch.FloatTensor\n",
54 | "LongTensor = torch.cuda.LongTensor if USE_CUDA else torch.LongTensor\n",
55 | "ByteTensor = torch.cuda.ByteTensor if USE_CUDA else torch.ByteTensor"
56 | ]
57 | },
58 | {
59 | "cell_type": "code",
60 | "execution_count": 4,
61 | "metadata": {
62 | "collapsed": true
63 | },
64 | "outputs": [],
65 | "source": [
66 | "def prepare_sequence(seq, to_index):\n",
67 | "    idxs = list(map(lambda w: to_index[w] if w in to_index.keys() else to_index[\"<unk>\"], seq))\n",
68 | "    return LongTensor(idxs)"
69 | ]
70 | },
71 | {
72 | "cell_type": "markdown",
73 | "metadata": {},
74 | "source": [
75 | "## Data load and Preprocessing"
76 | ]
77 | },
78 | {
79 | "cell_type": "markdown",
80 | "metadata": {},
81 | "source": [
82 | "### Penn TreeBank"
83 | ]
84 | },
85 | {
86 | "cell_type": "code",
87 | "execution_count": 5,
88 | "metadata": {
89 | "collapsed": true
90 | },
91 | "outputs": [],
92 | "source": [
93 | "def prepare_ptb_dataset(filename,word2index=None):\n",
94 | "    corpus = open(filename,'r',encoding='utf-8').readlines()\n",
95 | "    corpus = flatten([co.strip().split() + ['</s>'] for co in corpus])\n",
96 | "    \n",
97 | "    if word2index is None:\n",
98 | "        vocab = list(set(corpus))\n",
99 | "        word2index={'<unk>':0}\n",
100 | "        for vo in vocab:\n",
101 | "            if vo not in word2index.keys():\n",
102 | "                word2index[vo]=len(word2index)\n",
103 | "    \n",
104 | "    return prepare_sequence(corpus,word2index), word2index"
105 | ]
106 | },
107 | {
108 | "cell_type": "code",
109 | "execution_count": 175,
110 | "metadata": {
111 | "collapsed": true
112 | },
113 | "outputs": [],
114 | "source": [
115 | "# borrowed code from https://github.com/pytorch/examples/tree/master/word_language_model\n",
116 | "\n",
117 | "def batchify(data, bsz):\n",
118 | "    # Work out how cleanly we can divide the dataset into bsz parts.\n",
119 | "    nbatch = data.size(0) // bsz\n",
120 | "    # Trim off any extra elements that wouldn't cleanly fit (remainders).\n",
121 | "    data = data.narrow(0, 0, nbatch * bsz)\n",
122 | "    # Evenly divide the data across the bsz batches.\n",
123 | "    data = data.view(bsz, -1).contiguous()\n",
124 | "    if USE_CUDA:\n",
125 | "        data = data.cuda()\n",
126 | "    return data"
127 | ]
128 | },
129 | {
130 | "cell_type": "code",
131 | "execution_count": 176,
132 | "metadata": {
133 | "collapsed": true
134 | },
135 | "outputs": [],
136 | "source": [
137 | "def getBatch(data,seq_length):\n",
138 | "    for i in range(0, data.size(1) - seq_length, seq_length):\n",
139 | "        inputs = Variable(data[:, i:i+seq_length])\n",
140 | "        targets = Variable(data[:, (i+1):(i+1)+seq_length].contiguous())\n",
141 | "        yield (inputs,targets)"
142 | ]
143 | },
144 | {
145 | "cell_type": "code",
146 | "execution_count": 177,
147 | "metadata": {
148 | "collapsed": false
149 | },
150 | "outputs": [],
151 | "source": [
152 | "train_data, word2index= prepare_ptb_dataset('../dataset/ptb/ptb.train.txt',)\n",
153 | "dev_data , _ = prepare_ptb_dataset('../dataset/ptb/ptb.valid.txt',word2index)\n",
154 | "test_data, _ = prepare_ptb_dataset('../dataset/ptb/ptb.test.txt',word2index)"
155 | ]
156 | },
157 | {
158 | "cell_type": "code",
159 | "execution_count": 178,
160 | "metadata": {
161 | "collapsed": false
162 | },
163 | "outputs": [
164 | {
165 | "data": {
166 | "text/plain": [
167 | "10000"
168 | ]
169 | },
170 | "execution_count": 178,
171 | "metadata": {},
172 | "output_type": "execute_result"
173 | }
174 | ],
175 | "source": [
176 | "len(word2index)"
177 | ]
178 | },
179 | {
180 | "cell_type": "code",
181 | "execution_count": 179,
182 | "metadata": {
183 | "collapsed": true
184 | },
185 | "outputs": [],
186 | "source": [
187 | "index2word = {v:k for k,v in word2index.items()}"
188 | ]
189 | },
190 | {
191 | "cell_type": "markdown",
192 | "metadata": {},
193 | "source": [
194 | "## Modeling "
195 | ]
196 | },
197 | {
198 | "cell_type": "markdown",
199 | "metadata": {},
200 | "source": [
201 | "<img src=\"../images/06.rnnlm-architecture.png\">\n",
202 | "borrowed image from http://web.stanford.edu/class/cs224n/lectures/cs224n-2017-lecture8.pdf"
203 | ]
204 | },
205 | {
206 | "cell_type": "code",
207 | "execution_count": 180,
208 | "metadata": {
209 | "collapsed": false
210 | },
211 | "outputs": [],
212 | "source": [
213 | "class LanguageModel(nn.Module): \n",
214 | "    def __init__(self,vocab_size,embedding_size,hidden_size,n_layers=1,dropout_p=0.5):\n",
215 | "\n",
216 | "        super(LanguageModel, self).__init__()\n",
217 | "        self.n_layers = n_layers\n",
218 | "        self.hidden_size = hidden_size\n",
219 | "        self.embed = nn.Embedding(vocab_size,embedding_size)\n",
220 | "        self.rnn = nn.LSTM(embedding_size,hidden_size,n_layers,batch_first=True)\n",
221 | "        self.linear = nn.Linear(hidden_size,vocab_size)\n",
222 | "        self.dropout = nn.Dropout(dropout_p)\n",
223 | "        \n",
224 | "    def init_weight(self):\n",
225 | "        self.embed.weight = nn.init.xavier_uniform(self.embed.weight)\n",
226 | "        self.linear.weight = nn.init.xavier_uniform(self.linear.weight)\n",
227 | "        self.linear.bias.data.fill_(0)\n",
228 | "        \n",
229 | "    def init_hidden(self,batch_size):\n",
230 | "        hidden = Variable(torch.zeros(self.n_layers,batch_size,self.hidden_size))\n",
231 | "        context = Variable(torch.zeros(self.n_layers,batch_size,self.hidden_size))\n",
232 | "        return (hidden.cuda(), context.cuda()) if USE_CUDA else (hidden,context)\n",
233 | "    \n",
234 | "    def detach_hidden(self,hiddens):\n",
235 | "        return tuple([hidden.detach() for hidden in hiddens])\n",
236 | "    \n",
237 | "    def forward(self, inputs,hidden,is_training=False): \n",
238 | "\n",
239 | "        embeds = self.embed(inputs)\n",
240 | "        if is_training:\n",
241 | "            embeds = self.dropout(embeds)\n",
242 | "        out,hidden = self.rnn(embeds,hidden)\n",
243 | "        return self.linear(out.contiguous().view(out.size(0)*out.size(1),-1)), hidden"
244 | ]
245 | },
246 | {
247 | "cell_type": "markdown",
248 | "metadata": {},
249 | "source": [
250 | "## Train "
251 | ]
252 | },
253 | {
254 | "cell_type": "markdown",
255 | "metadata": {},
256 | "source": [
257 | "It takes a while..."
258 | ]
259 | },
260 | {
261 | "cell_type": "code",
262 | "execution_count": 181,
263 | "metadata": {
264 | "collapsed": true
265 | },
266 | "outputs": [],
267 | "source": [
268 | "EMBED_SIZE=128\n",
269 | "HIDDEN_SIZE=1024\n",
270 | "NUM_LAYER=1\n",
271 | "LR = 0.01\n",
272 | "SEQ_LENGTH = 30 # for bptt\n",
273 | "BATCH_SIZE = 20\n",
274 | "EPOCH = 40\n",
275 | "RESCHEDULED=False"
276 | ]
277 | },
278 | {
279 | "cell_type": "code",
280 | "execution_count": 182,
281 | "metadata": {
282 | "collapsed": true
283 | },
284 | "outputs": [],
285 | "source": [
286 | "train_data = batchify(train_data,BATCH_SIZE)\n",
287 | "dev_data = batchify(dev_data,BATCH_SIZE//2)\n",
288 | "test_data = batchify(test_data,BATCH_SIZE//2)"
289 | ]
290 | },
291 | {
292 | "cell_type": "code",
293 | "execution_count": 185,
294 | "metadata": {
295 | "collapsed": false
296 | },
297 | "outputs": [],
298 | "source": [
299 | "model = LanguageModel(len(word2index),EMBED_SIZE,HIDDEN_SIZE,NUM_LAYER,0.5)\n",
300 | "model.init_weight() \n",
301 | "if USE_CUDA:\n",
302 | "    model = model.cuda()\n",
303 | "loss_function = nn.CrossEntropyLoss()\n",
304 | "optimizer = optim.Adam(model.parameters(),lr=LR)"
305 | ]
306 | },
307 | {
308 | "cell_type": "code",
309 | "execution_count": 186,
310 | "metadata": {
311 | "collapsed": false
312 | },
313 | "outputs": [
314 | {
315 | "name": "stdout",
316 | "output_type": "stream",
317 | "text": [
318 | "[00/40] mean_loss : 9.45, Perplexity : 12712.23\n",
319 | "[00/40] mean_loss : 5.88, Perplexity : 358.21\n",
320 | "[00/40] mean_loss : 5.55, Perplexity : 256.44\n",
321 | "[01/40] mean_loss : 5.38, Perplexity : 217.46\n",
322 | "[01/40] mean_loss : 5.21, Perplexity : 182.41\n",
323 | "[01/40] mean_loss : 5.10, Perplexity : 164.39\n",
324 | "[02/40] mean_loss : 5.08, Perplexity : 160.87\n",
325 | "[02/40] mean_loss : 4.99, Perplexity : 147.18\n",
326 | "[02/40] mean_loss : 4.92, Perplexity : 136.52\n",
327 | "[03/40] mean_loss : 4.92, Perplexity : 136.64\n",
328 | "[03/40] mean_loss : 4.86, Perplexity : 129.32\n",
329 | "[03/40] mean_loss : 4.80, Perplexity : 121.46\n",
330 | "[04/40] mean_loss : 4.80, Perplexity : 121.91\n",
331 | "[04/40] mean_loss : 4.77, Perplexity : 117.64\n",
332 | "[04/40] mean_loss : 4.71, Perplexity : 111.22\n",
333 | "[05/40] mean_loss : 4.72, Perplexity : 112.01\n",
334 | "[05/40] mean_loss : 4.70, Perplexity : 109.46\n",
335 | "[05/40] mean_loss : 4.64, Perplexity : 103.96\n",
336 | "[06/40] mean_loss : 4.66, Perplexity : 105.25\n",
337 | "[06/40] mean_loss : 4.64, Perplexity : 103.63\n",
338 | "[06/40] mean_loss : 4.60, Perplexity : 99.00\n",
339 | "[07/40] mean_loss : 4.60, Perplexity : 99.89\n",
340 | "[07/40] mean_loss : 4.59, Perplexity : 98.97\n",
341 | "[07/40] mean_loss : 4.55, Perplexity : 94.97\n",
342 | "[08/40] mean_loss : 4.56, Perplexity : 95.54\n",
343 | "[08/40] mean_loss : 4.56, Perplexity : 95.67\n",
344 | "[08/40] mean_loss : 4.52, Perplexity : 91.98\n",
345 | "[09/40] mean_loss : 4.53, Perplexity : 92.61\n",
346 | "[09/40] mean_loss : 4.53, Perplexity : 92.79\n",
347 | "[09/40] mean_loss : 4.50, Perplexity : 89.63\n",
348 | "[10/40] mean_loss : 4.50, Perplexity : 90.13\n",
349 | "[10/40] mean_loss : 4.50, Perplexity : 90.19\n",
350 | "[10/40] mean_loss : 4.47, Perplexity : 87.11\n",
351 | "[11/40] mean_loss : 4.48, Perplexity : 88.11\n",
352 | "[11/40] mean_loss : 4.48, Perplexity : 88.26\n",
353 | "[11/40] mean_loss : 4.45, Perplexity : 86.05\n",
354 | "[12/40] mean_loss : 4.46, Perplexity : 86.81\n",
355 | "[12/40] mean_loss : 4.47, Perplexity : 87.03\n",
356 | "[12/40] mean_loss : 4.43, Perplexity : 84.04\n",
357 | "[13/40] mean_loss : 4.45, Perplexity : 85.27\n",
358 | "[13/40] mean_loss : 4.45, Perplexity : 85.83\n",
359 | "[13/40] mean_loss : 4.42, Perplexity : 83.33\n",
360 | "[14/40] mean_loss : 4.43, Perplexity : 84.15\n",
361 | "[14/40] mean_loss : 4.43, Perplexity : 84.31\n",
362 | "[14/40] mean_loss : 4.41, Perplexity : 82.29\n",
363 | "[15/40] mean_loss : 4.43, Perplexity : 83.82\n",
364 | "[15/40] mean_loss : 4.43, Perplexity : 83.70\n",
365 | "[15/40] mean_loss : 4.40, Perplexity : 81.59\n",
366 | "[16/40] mean_loss : 4.42, Perplexity : 83.06\n",
367 | "[16/40] mean_loss : 4.42, Perplexity : 83.29\n",
368 | "[16/40] mean_loss : 4.39, Perplexity : 80.89\n",
369 | "[17/40] mean_loss : 4.41, Perplexity : 82.44\n",
370 | "[17/40] mean_loss : 4.41, Perplexity : 82.51\n",
371 | "[17/40] mean_loss : 4.39, Perplexity : 80.59\n",
372 | "[18/40] mean_loss : 4.40, Perplexity : 81.59\n",
373 | "[18/40] mean_loss : 4.41, Perplexity : 82.21\n",
374 | "[18/40] mean_loss : 4.38, Perplexity : 79.87\n",
375 | "[19/40] mean_loss : 4.40, Perplexity : 81.43\n",
376 | "[19/40] mean_loss : 4.40, Perplexity : 81.67\n",
377 | "[19/40] mean_loss : 4.37, Perplexity : 79.28\n",
378 | "[20/40] mean_loss : 4.40, Perplexity : 81.18\n",
379 | "[20/40] mean_loss : 4.40, Perplexity : 81.17\n",
380 | "[20/40] mean_loss : 4.37, Perplexity : 79.11\n",
381 | "[21/40] mean_loss : 4.40, Perplexity : 81.44\n",
382 | "[21/40] mean_loss : 4.34, Perplexity : 76.43\n",
383 | "[21/40] mean_loss : 4.21, Perplexity : 67.17\n",
384 | "[22/40] mean_loss : 4.26, Perplexity : 70.84\n",
385 | "[22/40] mean_loss : 4.26, Perplexity : 70.75\n",
386 | "[22/40] mean_loss : 4.17, Perplexity : 64.99\n",
387 | "[23/40] mean_loss : 4.22, Perplexity : 68.36\n",
388 | "[23/40] mean_loss : 4.22, Perplexity : 67.82\n",
389 | "[23/40] mean_loss : 4.15, Perplexity : 63.74\n",
390 | "[24/40] mean_loss : 4.20, Perplexity : 66.66\n",
391 | "[24/40] mean_loss : 4.20, Perplexity : 66.43\n",
392 | "[24/40] mean_loss : 4.14, Perplexity : 62.85\n",
393 | "[25/40] mean_loss : 4.18, Perplexity : 65.53\n",
394 | "[25/40] mean_loss : 4.17, Perplexity : 64.99\n",
395 | "[25/40] mean_loss : 4.13, Perplexity : 61.94\n",
396 | "[26/40] mean_loss : 4.17, Perplexity : 64.61\n",
397 | "[26/40] mean_loss : 4.16, Perplexity : 64.34\n",
398 | "[26/40] mean_loss : 4.12, Perplexity : 61.27\n",
399 | "[27/40] mean_loss : 4.15, Perplexity : 63.73\n",
400 | "[27/40] mean_loss : 4.15, Perplexity : 63.32\n",
401 | "[27/40] mean_loss : 4.11, Perplexity : 60.87\n",
402 | "[28/40] mean_loss : 4.14, Perplexity : 62.96\n",
403 | "[28/40] mean_loss : 4.14, Perplexity : 63.01\n",
404 | "[28/40] mean_loss : 4.10, Perplexity : 60.33\n",
405 | "[29/40] mean_loss : 4.14, Perplexity : 62.54\n",
406 | "[29/40] mean_loss : 4.13, Perplexity : 62.36\n",
407 | "[29/40] mean_loss : 4.10, Perplexity : 60.06\n",
408 | "[30/40] mean_loss : 4.13, Perplexity : 62.05\n",
409 | "[30/40] mean_loss : 4.13, Perplexity : 61.91\n",
410 | "[30/40] mean_loss : 4.09, Perplexity : 59.46\n",
411 | "[31/40] mean_loss : 4.12, Perplexity : 61.45\n",
412 | "[31/40] mean_loss : 4.11, Perplexity : 61.24\n",
413 | "[31/40] mean_loss : 4.08, Perplexity : 59.12\n",
414 | "[32/40] mean_loss : 4.11, Perplexity : 61.03\n",
415 | "[32/40] mean_loss : 4.11, Perplexity : 60.88\n",
416 | "[32/40] mean_loss : 4.07, Perplexity : 58.69\n",
417 | "[33/40] mean_loss : 4.11, Perplexity : 60.71\n",
418 | "[33/40] mean_loss : 4.10, Perplexity : 60.57\n",
419 | "[33/40] mean_loss : 4.07, Perplexity : 58.38\n",
420 | "[34/40] mean_loss : 4.10, Perplexity : 60.33\n",
421 | "[34/40] mean_loss : 4.10, Perplexity : 60.23\n",
422 | "[34/40] mean_loss : 4.06, Perplexity : 58.06\n",
423 | "[35/40] mean_loss : 4.09, Perplexity : 60.00\n",
424 | "[35/40] mean_loss : 4.09, Perplexity : 59.74\n",
425 | "[35/40] mean_loss : 4.06, Perplexity : 57.75\n",
426 | "[36/40] mean_loss : 4.09, Perplexity : 59.58\n",
427 | "[36/40] mean_loss : 4.09, Perplexity : 59.47\n",
428 | "[36/40] mean_loss : 4.05, Perplexity : 57.59\n",
429 | "[37/40] mean_loss : 4.08, Perplexity : 59.30\n",
430 | "[37/40] mean_loss : 4.08, Perplexity : 59.11\n",
431 | "[37/40] mean_loss : 4.05, Perplexity : 57.11\n",
432 | "[38/40] mean_loss : 4.08, Perplexity : 58.98\n",
433 | "[38/40] mean_loss : 4.07, Perplexity : 58.70\n",
434 | "[38/40] mean_loss : 4.04, Perplexity : 57.10\n",
435 | "[39/40] mean_loss : 4.07, Perplexity : 58.79\n",
436 | "[39/40] mean_loss : 4.07, Perplexity : 58.58\n",
437 | "[39/40] mean_loss : 4.04, Perplexity : 56.79\n"
438 | ]
439 | }
440 | ],
441 | "source": [
442 | "for epoch in range(EPOCH):\n",
443 | "    total_loss = 0\n",
444 | "    losses=[]\n",
445 | "    hidden = model.init_hidden(BATCH_SIZE)\n",
446 | "    for i,batch in enumerate(getBatch(train_data,SEQ_LENGTH)):\n",
447 | "        inputs, targets = batch\n",
448 | "        hidden = model.detach_hidden(hidden)\n",
449 | "        model.zero_grad()\n",
450 | "        preds,hidden = model(inputs,hidden,True)\n",
451 | "\n",
452 | "        loss = loss_function(preds,targets.view(-1))\n",
453 | "        losses.append(loss.data[0])\n",
454 | "        loss.backward()\n",
455 | "        torch.nn.utils.clip_grad_norm(model.parameters(),0.5) # gradient clipping\n",
456 | "        optimizer.step()\n",
457 | "\n",
458 | "        if i>0 and i % 500==0:\n",
459 | "            print(\"[%02d/%d] mean_loss : %0.2f, Perplexity : %0.2f\" % (epoch,EPOCH, \\\n",
460 | "                  np.mean(losses),np.exp(np.mean(losses))))\n",
461 | "            losses=[]\n",
462 | "    \n",
463 | "    # learning rate annealing: cut the learning rate by 10x once, halfway through training\n",
464 | "    # You can use http://pytorch.org/docs/master/optim.html#how-to-adjust-learning-rate\n",
465 | "    if RESCHEDULED==False and epoch==EPOCH//2:\n",
466 | "        LR=LR*0.1\n",
467 | "        optimizer = optim.Adam(model.parameters(),lr=LR)\n",
468 | "        RESCHEDULED=True"
469 | ]
470 | },
471 | {
472 | "cell_type": "markdown",
473 | "metadata": {},
474 | "source": [
475 | "### Test "
476 | ]
477 | },
478 | {
479 | "cell_type": "code",
480 | "execution_count": 189,
481 | "metadata": {
482 | "collapsed": false
483 | },
484 | "outputs": [
485 | {
486 | "name": "stdout",
487 | "output_type": "stream",
488 | "text": [
489 | "Test Perplexity : 155.89\n"
490 | ]
491 | }
492 | ],
493 | "source": [
494 | "total_loss = 0\n",
495 | "hidden = model.init_hidden(BATCH_SIZE//2)\n",
496 | "for batch in getBatch(test_data,SEQ_LENGTH):\n",
497 | "    inputs,targets = batch\n",
498 | "    \n",
499 | "    hidden = model.detach_hidden(hidden)\n",
500 | "    model.zero_grad()\n",
501 | "    preds,hidden = model(inputs,hidden)\n",
502 | "    total_loss += inputs.size(1) * loss_function(preds, targets.view(-1)).data\n",
503 | "\n",
504 | "total_loss = total_loss[0]/test_data.size(1)\n",
505 | "print(\"Test Perplexity : %5.2f\" % (np.exp(total_loss)))"
506 | ]
507 | },
508 | {
509 | "cell_type": "markdown",
510 | "metadata": {
511 | "collapsed": true
512 | },
513 | "source": [
514 | "## Further topics"
515 | ]
516 | },
517 | {
518 | "cell_type": "markdown",
519 | "metadata": {},
520 | "source": [
521 | "* Pointer Sentinel Mixture Models\n",
522 | "* Regularizing and Optimizing LSTM Language Models"
523 | ]
524 | },
525 | {
526 | "cell_type": "code",
527 | "execution_count": null,
528 | "metadata": {
529 | "collapsed": true
530 | },
531 | "outputs": [],
532 | "source": []
533 | }
534 | ],
535 | "metadata": {
536 | "kernelspec": {
537 | "display_name": "Python 3",
538 | "language": "python",
539 | "name": "python3"
540 | },
541 | "language_info": {
542 | "codemirror_mode": {
543 | "name": "ipython",
544 | "version": 3
545 | },
546 | "file_extension": ".py",
547 | "mimetype": "text/x-python",
548 | "name": "python",
549 | "nbconvert_exporter": "python",
550 | "pygments_lexer": "ipython3",
551 | "version": "3.5.2"
552 | }
553 | },
554 | "nbformat": 4,
555 | "nbformat_minor": 2
556 | }
557 |
--------------------------------------------------------------------------------
/notebooks/08.CNN-for-Text-Classification.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "# 8. Convolutional Neural Networks"
8 | ]
9 | },
10 | {
11 | "cell_type": "markdown",
12 | "metadata": {},
13 | "source": [
14 | "I recommend you take a look at these materials first."
15 | ]
16 | },
17 | {
18 | "cell_type": "markdown",
19 | "metadata": {},
20 | "source": [
21 | "* http://web.stanford.edu/class/cs224n/lectures/cs224n-2017-lecture13-CNNs.pdf\n",
22 | "* http://www.aclweb.org/anthology/D14-1181\n",
23 | "* https://github.com/Shawn1993/cnn-text-classification-pytorch\n",
24 | "* http://cogcomp.org/Data/QA/QC/"
25 | ]
26 | },
27 | {
28 | "cell_type": "code",
29 | "execution_count": 58,
30 | "metadata": {
31 | "collapsed": true
32 | },
33 | "outputs": [],
34 | "source": [
35 | "import torch\n",
36 | "import torch.nn as nn\n",
37 | "from torch.autograd import Variable\n",
38 | "import torch.optim as optim\n",
39 | "import torch.nn.functional as F\n",
40 | "import nltk\n",
41 | "import random\n",
42 | "import numpy as np\n",
43 | "from collections import Counter, OrderedDict\n",
44 | "import nltk\n",
45 | "import re\n",
46 | "from copy import deepcopy\n",
47 | "flatten = lambda l: [item for sublist in l for item in sublist]"
48 | ]
49 | },
50 | {
51 | "cell_type": "code",
52 | "execution_count": 2,
53 | "metadata": {
54 | "collapsed": true
55 | },
56 | "outputs": [],
57 | "source": [
58 | "USE_CUDA = torch.cuda.is_available()\n",
59 | "\n",
60 | "FloatTensor = torch.cuda.FloatTensor if USE_CUDA else torch.FloatTensor\n",
61 | "LongTensor = torch.cuda.LongTensor if USE_CUDA else torch.LongTensor\n",
62 | "ByteTensor = torch.cuda.ByteTensor if USE_CUDA else torch.ByteTensor"
63 | ]
64 | },
65 | {
66 | "cell_type": "code",
67 | "execution_count": 3,
68 | "metadata": {
69 | "collapsed": true
70 | },
71 | "outputs": [],
72 | "source": [
73 | "def getBatch(batch_size,train_data):\n",
74 | "    random.shuffle(train_data)\n",
75 | "    sindex=0\n",
76 | "    eindex=batch_size\n",
77 | "    while eindex < len(train_data):\n",
78 | "        batch = train_data[sindex:eindex]\n",
79 | "        temp = eindex\n",
80 | "        eindex = eindex+batch_size\n",
81 | "        sindex = temp\n",
82 | "        yield batch\n",
83 | "    \n",
84 | "    if eindex >= len(train_data):\n",
85 | "        batch = train_data[sindex:]\n",
86 | "        yield batch"
87 | ]
88 | },
89 | {
90 | "cell_type": "code",
91 | "execution_count": 110,
92 | "metadata": {
93 | "collapsed": true
94 | },
95 | "outputs": [],
96 | "source": [
97 | "def pad_to_batch(batch):\n",
98 | "    x,y = zip(*batch)\n",
99 | "    max_x = max([s.size(1) for s in x])\n",
100 | "    x_p=[]\n",
101 | "    for i in range(len(batch)):\n",
102 | "        if x[i].size(1)<max_x:\n",
103 | "            x_p.append(torch.cat([x[i],Variable(LongTensor([word2index['<PAD>']]*(max_x-x[i].size(1)))).view(1,-1)],1))\n",
104 | "        else:\n",
105 | "            x_p.append(x[i])\n",
106 | "    return torch.cat(x_p),torch.cat(y).view(-1)"
107 | ]
108 | },
109 | {
110 | "cell_type": "code",
111 | "execution_count": 20,
112 | "metadata": {
113 | "collapsed": true
114 | },
115 | "outputs": [],
116 | "source": [
117 | "def prepare_sequence(seq, to_index):\n",
118 | "    idxs = list(map(lambda w: to_index[w] if w in to_index.keys() else to_index[\"<UNK>\"], seq))\n",
119 | "    return Variable(LongTensor(idxs))"
120 | ]
121 | },
122 | {
123 | "cell_type": "markdown",
124 | "metadata": {},
125 | "source": [
126 | "## Data load & Preprocessing"
127 | ]
128 | },
129 | {
130 | "cell_type": "markdown",
131 | "metadata": {},
132 | "source": [
133 | "### TREC question dataset (http://cogcomp.org/Data/QA/QC/)"
134 | ]
135 | },
136 | {
137 | "cell_type": "markdown",
138 | "metadata": {},
139 | "source": [
140 | "The task is to\n",
141 | "classify a question into one of 6 question\n",
142 | "types (whether the question is about a person,\n",
143 | "a location, numeric information, etc.)."
144 | ]
145 | },
146 | {
147 | "cell_type": "code",
148 | "execution_count": 53,
149 | "metadata": {
150 | "collapsed": true
151 | },
152 | "outputs": [],
153 | "source": [
154 | "data = open('../dataset/train_5500.label.txt','r',encoding='latin-1').readlines()"
155 | ]
156 | },
157 | {
158 | "cell_type": "code",
159 | "execution_count": 54,
160 | "metadata": {
161 | "collapsed": true
162 | },
163 | "outputs": [],
164 | "source": [
165 | "data = [[d.split(':')[1][:-1],d.split(':')[0]] for d in data]"
166 | ]
167 | },
168 | {
169 | "cell_type": "code",
170 | "execution_count": 61,
171 | "metadata": {
172 | "collapsed": true
173 | },
174 | "outputs": [],
175 | "source": [
176 | "X,y = list(zip(*data))\n",
177 | "X = list(X)"
178 | ]
179 | },
180 | {
181 | "cell_type": "markdown",
182 | "metadata": {},
183 | "source": [
184 | "### Num masking "
185 | ]
186 | },
187 | {
188 | "cell_type": "markdown",
189 | "metadata": {},
190 | "source": [
191 | "Masking digits reduces the vocabulary size (the search space), e.g. my birthday is 12.22 ==> my birthday is ##.##"
192 | ]
193 | },
194 | {
195 | "cell_type": "code",
196 | "execution_count": 62,
197 | "metadata": {
198 | "collapsed": true
199 | },
200 | "outputs": [],
201 | "source": [
202 | "for i,x in enumerate(X):\n",
203 | "    X[i] = re.sub('\\d','#',x).split()"
204 | ]
205 | },
206 | {
207 | "cell_type": "markdown",
208 | "metadata": {},
209 | "source": [
210 | "### Build Vocab "
211 | ]
212 | },
213 | {
214 | "cell_type": "code",
215 | "execution_count": 63,
216 | "metadata": {
217 | "collapsed": true
218 | },
219 | "outputs": [],
220 | "source": [
221 | "vocab = list(set(flatten(X)))"
222 | ]
223 | },
224 | {
225 | "cell_type": "code",
226 | "execution_count": 64,
227 | "metadata": {
228 | "collapsed": false
229 | },
230 | "outputs": [
231 | {
232 | "data": {
233 | "text/plain": [
234 | "9117"
235 | ]
236 | },
237 | "execution_count": 64,
238 | "metadata": {},
239 | "output_type": "execute_result"
240 | }
241 | ],
242 | "source": [
243 | "len(vocab)"
244 | ]
245 | },
246 | {
247 | "cell_type": "code",
248 | "execution_count": 31,
249 | "metadata": {
250 | "collapsed": false
251 | },
252 | "outputs": [
253 | {
254 | "data": {
255 | "text/plain": [
256 | "6"
257 | ]
258 | },
259 | "execution_count": 31,
260 | "metadata": {},
261 | "output_type": "execute_result"
262 | }
263 | ],
264 | "source": [
265 | "len(set(y)) # number of classes"
266 | ]
267 | },
268 | {
269 | "cell_type": "code",
270 | "execution_count": 94,
271 | "metadata": {
272 | "collapsed": true
273 | },
274 | "outputs": [],
275 | "source": [
276 | "word2index={'<PAD>':0,'<UNK>':1}\n",
277 | "\n",
278 | "for vo in vocab:\n",
279 | "    if vo not in word2index.keys():\n",
280 | "        word2index[vo]=len(word2index)\n",
281 | "    \n",
282 | "index2word = {v:k for k,v in word2index.items()}\n",
283 | "\n",
284 | "target2index = {}\n",
285 | "\n",
286 | "for cl in set(y):\n",
287 | "    if cl not in target2index.keys():\n",
288 | "        target2index[cl]=len(target2index)\n",
289 | "\n",
290 | "index2target = {v:k for k,v in target2index.items()}"
291 | ]
292 | },
293 | {
294 | "cell_type": "code",
295 | "execution_count": 95,
296 | "metadata": {
297 | "collapsed": true
298 | },
299 | "outputs": [],
300 | "source": [
301 | "X_p,y_p=[],[]\n",
302 | "for pair in zip(X,y):\n",
303 | "    X_p.append(prepare_sequence(pair[0],word2index).view(1,-1))\n",
304 | "    y_p.append(Variable(LongTensor([target2index[pair[1]]])).view(1,-1))\n",
305 | "    \n",
306 | "data_p = list(zip(X_p,y_p))\n",
307 | "random.shuffle(data_p)\n",
308 | "\n",
309 | "train_data = data_p[:int(len(data_p)*0.9)]\n",
310 | "test_data = data_p[int(len(data_p)*0.9):]"
311 | ]
312 | },
313 | {
314 | "cell_type": "markdown",
315 | "metadata": {},
316 | "source": [
317 | "### Load Pretrained word vector"
318 | ]
319 | },
320 | {
321 | "cell_type": "markdown",
322 | "metadata": {},
323 | "source": [
324 | "You can download the pretrained word vectors from https://github.com/mmihaltz/word2vec-GoogleNews-vectors"
325 | ]
326 | },
327 | {
328 | "cell_type": "code",
329 | "execution_count": 41,
330 | "metadata": {
331 | "collapsed": true
332 | },
333 | "outputs": [],
334 | "source": [
335 | "import gensim"
336 | ]
337 | },
338 | {
339 | "cell_type": "code",
340 | "execution_count": 43,
341 | "metadata": {
342 | "collapsed": true
343 | },
344 | "outputs": [],
345 | "source": [
346 | "model = gensim.models.KeyedVectors.load_word2vec_format('../dataset/GoogleNews-vectors-negative300.bin', binary=True)"
347 | ]
348 | },
349 | {
350 | "cell_type": "code",
351 | "execution_count": 48,
352 | "metadata": {
353 | "collapsed": false
354 | },
355 | "outputs": [
356 | {
357 | "data": {
358 | "text/plain": [
359 | "3000000"
360 | ]
361 | },
362 | "execution_count": 48,
363 | "metadata": {},
364 | "output_type": "execute_result"
365 | }
366 | ],
367 | "source": [
368 | "len(model.index2word)"
369 | ]
370 | },
371 | {
372 | "cell_type": "code",
373 | "execution_count": 96,
374 | "metadata": {
375 | "collapsed": true
376 | },
377 | "outputs": [],
378 | "source": [
379 | "pretrained = []\n",
380 | "\n",
381 | "for i in range(len(word2index)):\n",
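382 | "    try:\n",
383 | "        pretrained.append(model[index2word[i]]) # look up by word string; word2index[i] would always raise KeyError\n",
384 | "    except KeyError: # word is missing from the pretrained vocabulary\n",
385 | "        pretrained.append(np.random.randn(300))\n",
386 | "    \n",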
387 | "pretrained_vectors = np.vstack(pretrained)"
388 | ]
389 | },
390 | {
391 | "cell_type": "markdown",
392 | "metadata": {},
393 | "source": [
394 | "## Modeling "
395 | ]
396 | },
397 | {
398 | "cell_type": "markdown",
399 | "metadata": {},
400 | "source": [
401 | "<img src=\"../images/08.cnn-for-text-architecture.png\">\n",
402 | "Image borrowed from http://www.aclweb.org/anthology/D14-1181"
403 | ]
404 | },
405 | {
406 | "cell_type": "code",
407 | "execution_count": 117,
408 | "metadata": {
409 | "collapsed": true
410 | },
411 | "outputs": [],
412 | "source": [
413 | "class CNNClassifier(nn.Module):\n",
414 | "    \n",
415 | "    def __init__(self, vocab_size,embedding_dim,output_size,kernel_dim=100,kernel_sizes=[3,4,5],dropout=0.5):\n",
416 | "        super(CNNClassifier,self).__init__()\n",
417 | "\n",
418 | "        self.embedding = nn.Embedding(vocab_size, embedding_dim)\n",
419 | "        self.convs = nn.ModuleList([nn.Conv2d(1, kernel_dim, (K, embedding_dim)) for K in kernel_sizes])\n",
420 | "\n",
421 | "        # kernel_size = (K,D) \n",
422 | "        self.dropout = nn.Dropout(dropout)\n",
423 | "        self.fc = nn.Linear(len(kernel_sizes)*kernel_dim, output_size)\n",
424 | "    \n",
425 | "    \n",
426 | "    def init_weights(self,pretrained_word_vectors,is_static=False):\n",
427 | "        self.embedding.weight = nn.Parameter(torch.from_numpy(pretrained_word_vectors).float())\n",
428 | "        if is_static:\n",
429 | "            self.embedding.weight.requires_grad = False\n",
430 | "\n",
431 | "\n",
432 | "    def forward(self, inputs,is_training=False):\n",
433 | "        inputs = self.embedding(inputs).unsqueeze(1) # (B,1,T,D)\n",
434 | "        inputs = [F.relu(conv(inputs)).squeeze(3) for conv in self.convs] #[(N,Co,W), ...]*len(Ks)\n",
435 | "        inputs = [F.max_pool1d(i, i.size(2)).squeeze(2) for i in inputs] #[(N,Co), ...]*len(Ks)\n",
436 | "\n",
437 | "        concated = torch.cat(inputs, 1)\n",
438 | "\n",
439 | "        if is_training:\n",
440 | "            concated = self.dropout(concated) # (N,len(Ks)*Co)\n",
441 | "        out = self.fc(concated) \n",
442 | "        return F.log_softmax(out)"
443 | ]
444 | },
445 | {
446 | "cell_type": "markdown",
447 | "metadata": {},
448 | "source": [
449 | "## Train "
450 | ]
451 | },
452 | {
453 | "cell_type": "markdown",
454 | "metadata": {},
455 | "source": [
456 | "Training takes a while if you only use a CPU."
457 | ]
458 | },
459 | {
460 | "cell_type": "code",
461 | "execution_count": 145,
462 | "metadata": {
463 | "collapsed": true
464 | },
465 | "outputs": [],
466 | "source": [
467 | "EPOCH=5\n",
468 | "BATCH_SIZE=50\n",
469 | "KERNEL_SIZES = [3,4,5]\n",
470 | "KERNEL_DIM = 100\n",
471 | "LR = 0.001"
472 | ]
473 | },
474 | {
475 | "cell_type": "code",
476 | "execution_count": 146,
477 | "metadata": {
478 | "collapsed": true
479 | },
480 | "outputs": [],
481 | "source": [
482 | "model = CNNClassifier(len(word2index),300,len(target2index),KERNEL_DIM,KERNEL_SIZES)\n",
483 | "model.init_weights(pretrained_vectors) # initialize embedding matrix using pretrained vectors\n",
484 | "\n",
485 | "if USE_CUDA:\n",
486 | " model = model.cuda()\n",
487 | " \n",
488 | "loss_function = nn.CrossEntropyLoss()\n",
489 | "optimizer = optim.Adam(model.parameters(),lr=LR)"
490 | ]
491 | },
492 | {
493 | "cell_type": "code",
494 | "execution_count": 147,
495 | "metadata": {
496 | "collapsed": false
497 | },
498 | "outputs": [
499 | {
500 | "name": "stdout",
501 | "output_type": "stream",
502 | "text": [
503 | "[0/5] mean_loss : 2.13\n",
504 | "[1/5] mean_loss : 0.12\n",
505 | "[2/5] mean_loss : 0.08\n",
506 | "[3/5] mean_loss : 0.02\n",
507 | "[4/5] mean_loss : 0.05\n"
508 | ]
509 | }
510 | ],
511 | "source": [
512 | "for epoch in range(EPOCH):\n",
513 | "    losses=[]\n",
514 | "    for i,batch in enumerate(getBatch(BATCH_SIZE,train_data)):\n",
515 | "        inputs,targets = pad_to_batch(batch)\n",
516 | "        \n",
517 | "        model.zero_grad()\n",
518 | "        preds = model(inputs,True)\n",
519 | "        \n",
520 | "        loss = loss_function(preds,targets)\n",
521 | "        losses.append(loss.data.tolist()[0])\n",
522 | "        loss.backward()\n",
523 | "        \n",
524 | "        #for param in model.parameters():\n",
525 | "        #    param.grad.data.clamp_(-3, 3)\n",
526 | "        \n",
527 | "        optimizer.step()\n",
528 | "        \n",
529 | "        if i % 100==0:\n",
530 | "            print(\"[%d/%d] mean_loss : %0.2f\" %(epoch,EPOCH,np.mean(losses)))\n",
531 | "            losses=[]"
532 | ]
533 | },
534 | {
535 | "cell_type": "markdown",
536 | "metadata": {},
537 | "source": [
538 | "## Test "
539 | ]
540 | },
541 | {
542 | "cell_type": "code",
543 | "execution_count": 150,
544 | "metadata": {
545 | "collapsed": true
546 | },
547 | "outputs": [],
548 | "source": [
549 | "accuracy=0"
550 | ]
551 | },
552 | {
553 | "cell_type": "code",
554 | "execution_count": 151,
555 | "metadata": {
556 | "collapsed": false
557 | },
558 | "outputs": [
559 | {
560 | "name": "stdout",
561 | "output_type": "stream",
562 | "text": [
563 | "97.61904761904762\n"
564 | ]
565 | }
566 | ],
567 | "source": [
568 | "for test in test_data:\n",
569 | "    pred = model(test[0]).max(1)[1]\n",
570 | "    pred = pred.data.tolist()[0]\n",
571 | "    target = test[1].data.tolist()[0][0]\n",
572 | "    if pred == target:\n",
573 | "        accuracy+=1\n",
574 | "\n",
575 | "print(accuracy/len(test_data)*100)"
576 | ]
577 | },
578 | {
579 | "cell_type": "markdown",
580 | "metadata": {
581 | "collapsed": true
582 | },
583 | "source": [
584 | "## Further topics "
585 | ]
586 | },
587 | {
588 | "cell_type": "markdown",
589 | "metadata": {},
590 | "source": [
591 | "* Character-Aware Neural Language Models\n",
592 | "* Character level CNN for text classification"
593 | ]
594 | },
595 | {
596 | "cell_type": "markdown",
597 | "metadata": {},
598 | "source": [
599 | "## Suggested Reading"
600 | ]
601 | },
602 | {
603 | "cell_type": "markdown",
604 | "metadata": {},
605 | "source": [
606 | "* https://blog.statsbot.co/text-classifier-algorithms-in-machine-learning-acc115293278\n",
607 | "* Bag of Tricks for Efficient Text Classification\n",
608 | "* Which Encoding is the Best for Text Classification in Chinese, English, Japanese and Korean?"
609 | ]
610 | },
611 | {
612 | "cell_type": "code",
613 | "execution_count": null,
614 | "metadata": {
615 | "collapsed": true
616 | },
617 | "outputs": [],
618 | "source": []
619 | }
620 | ],
621 | "metadata": {
622 | "kernelspec": {
623 | "display_name": "Python 3",
624 | "language": "python",
625 | "name": "python3"
626 | },
627 | "language_info": {
628 | "codemirror_mode": {
629 | "name": "ipython",
630 | "version": 3
631 | },
632 | "file_extension": ".py",
633 | "mimetype": "text/x-python",
634 | "name": "python",
635 | "nbconvert_exporter": "python",
636 | "pygments_lexer": "ipython3",
637 | "version": "3.5.2"
638 | }
639 | },
640 | "nbformat": 4,
641 | "nbformat_minor": 2
642 | }
643 |
--------------------------------------------------------------------------------
/notebooks/09.Recursive-NN-for-Sentiment-Classification.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "# 9. Recursive Neural Networks and Constituency Parsing"
8 | ]
9 | },
10 | {
11 | "cell_type": "markdown",
12 | "metadata": {},
13 | "source": [
14 | "I recommend you take a look at these materials first."
15 | ]
16 | },
17 | {
18 | "cell_type": "markdown",
19 | "metadata": {},
20 | "source": [
21 | "* http://web.stanford.edu/class/cs224n/lectures/cs224n-2017-lecture14-TreeRNNs.pdf\n",
22 | "* https://nlp.stanford.edu/~socherr/EMNLP2013_RNTN.pdf"
23 | ]
24 | },
25 | {
26 | "cell_type": "code",
27 | "execution_count": 4,
28 | "metadata": {
29 | "collapsed": true
30 | },
31 | "outputs": [],
32 | "source": [
33 | "import torch\n",
34 | "import torch.nn as nn\n",
35 | "from torch.autograd import Variable\n",
36 | "import torch.optim as optim\n",
37 | "import torch.nn.functional as F\n",
38 | "import nltk\n",
39 | "import random\n",
40 | "import numpy as np\n",
41 | "from collections import Counter, OrderedDict\n",
42 | "import nltk\n",
43 | "from copy import deepcopy\n",
44 | "import os\n",
45 | "from IPython.display import Image, display\n",
46 | "from nltk.draw import TreeWidget\n",
47 | "from nltk.draw.util import CanvasFrame\n",
48 | "from nltk.tree import Tree as nltkTree\n",
49 | "flatten = lambda l: [item for sublist in l for item in sublist]"
50 | ]
51 | },
52 | {
53 | "cell_type": "code",
54 | "execution_count": 2,
55 | "metadata": {
56 | "collapsed": true
57 | },
58 | "outputs": [],
59 | "source": [
60 | "USE_CUDA = torch.cuda.is_available()\n",
61 | "\n",
62 | "FloatTensor = torch.cuda.FloatTensor if USE_CUDA else torch.FloatTensor\n",
63 | "LongTensor = torch.cuda.LongTensor if USE_CUDA else torch.LongTensor\n",
64 | "ByteTensor = torch.cuda.ByteTensor if USE_CUDA else torch.ByteTensor"
65 | ]
66 | },
67 | {
68 | "cell_type": "code",
69 | "execution_count": 3,
70 | "metadata": {
71 | "collapsed": true
72 | },
73 | "outputs": [],
74 | "source": [
75 | "def getBatch(batch_size,train_data):\n",
76 | "    random.shuffle(train_data)\n",
77 | "    sindex=0\n",
78 | "    eindex=batch_size\n",
79 | "    while eindex < len(train_data):\n",
80 | "        batch = train_data[sindex:eindex]\n",
81 | "        temp = eindex\n",
82 | "        eindex = eindex+batch_size\n",
83 | "        sindex = temp\n",
84 | "        yield batch\n",
85 | "    \n",
86 | "    if eindex >= len(train_data):\n",
87 | "        batch = train_data[sindex:]\n",
88 | "        yield batch"
89 | ]
90 | },
91 | {
92 | "cell_type": "code",
93 | "execution_count": 5,
94 | "metadata": {
95 | "collapsed": true
96 | },
97 | "outputs": [],
98 | "source": [
99 | "# Borrowed from https://stackoverflow.com/questions/31779707/how-do-you-make-nltk-draw-trees-that-are-inline-in-ipython-jupyter\n",
100 | "\n",
101 | "def draw_nltk_tree(tree):\n",
102 | "    cf = CanvasFrame()\n",
103 | "    tc = TreeWidget(cf.canvas(), tree)\n",
104 | "    tc['node_font'] = 'arial 15 bold'\n",
105 | "    tc['leaf_font'] = 'arial 15'\n",
106 | "    tc['node_color'] = '#005990'\n",
107 | "    tc['leaf_color'] = '#3F8F57'\n",
108 | "    tc['line_color'] = '#175252'\n",
109 | "    cf.add_widget(tc, 50, 50)\n",
110 | "    cf.print_to_file('tmp_tree_output.ps')\n",
111 | "    cf.destroy()\n",
112 | "    os.system('convert tmp_tree_output.ps tmp_tree_output.png')\n",
113 | "    display(Image(filename='tmp_tree_output.png'))\n",
114 | "    os.system('rm tmp_tree_output.ps tmp_tree_output.png')"
115 | ]
116 | },
117 | {
118 | "cell_type": "markdown",
119 | "metadata": {},
120 | "source": [
121 | "## Data load and Preprocessing"
122 | ]
123 | },
124 | {
125 | "cell_type": "markdown",
126 | "metadata": {},
127 | "source": [
128 | "### Stanford Sentiment Treebank (https://nlp.stanford.edu/sentiment/index.html)"
129 | ]
130 | },
131 | {
132 | "cell_type": "code",
133 | "execution_count": 10,
134 | "metadata": {
135 | "collapsed": false
136 | },
137 | "outputs": [
138 | {
139 | "name": "stdout",
140 | "output_type": "stream",
141 | "text": [
142 | "(3 (2 (1 Deflated) (2 (2 ending) (2 aside))) (4 (2 ,) (4 (2 there) (3 (3 (2 's) (3 (2 much) (2 (2 to) (3 (3 recommend) (2 (2 the) (2 film)))))) (2 .)))))\n",
143 | "\n"
144 | ]
145 | }
146 | ],
147 | "source": [
148 | "sample = random.choice(open('../dataset/trees/train.txt','r',encoding='utf-8').readlines())\n",
149 | "print(sample)"
150 | ]
151 | },
152 | {
153 | "cell_type": "code",
154 | "execution_count": 11,
155 | "metadata": {
156 | "collapsed": false
157 | },
158 | "outputs": [
159 | {
160 | "data": {
162 | "text/plain": [
163 | ""
164 | ]
165 | },
166 | "metadata": {},
167 | "output_type": "display_data"
168 | }
169 | ],
170 | "source": [
171 | "draw_nltk_tree(nltkTree.fromstring(sample))"
172 | ]
173 | },
174 | {
175 | "cell_type": "markdown",
176 | "metadata": {},
177 | "source": [
178 | "### Tree Class "
179 | ]
180 | },
181 | {
182 | "cell_type": "markdown",
183 | "metadata": {},
184 | "source": [
185 | "borrowed code from https://github.com/bogatyy/cs224d/tree/master/assignment3"
186 | ]
187 | },
188 | {
189 | "cell_type": "code",
190 | "execution_count": 10,
191 | "metadata": {
192 | "collapsed": false
193 | },
194 | "outputs": [],
195 | "source": [
196 | "class Node: # a node in the tree\n",
197 | "    def __init__(self, label, word=None):\n",
198 | "        self.label = label\n",
199 | "        self.word = word\n",
200 | "        self.parent = None # reference to parent\n",
201 | "        self.left = None # reference to left child\n",
202 | "        self.right = None # reference to right child\n",
203 | "        # true if I am a leaf (could have probably derived this from if I have\n",
204 | "        # a word)\n",
205 | "        self.isLeaf = False\n",
206 | "        # true if we have finished performing forwardprop on this node (note,\n",
207 | "        # there are many ways to implement the recursion.. some might not\n",
208 | "        # require this flag)\n",
209 | "\n",
210 | "    def __str__(self):\n",
211 | "        if self.isLeaf:\n",
212 | "            return '[{0}:{1}]'.format(self.word, self.label)\n",
213 | "        return '({0} <- [{1}:{2}] -> {3})'.format(self.left, self.word, self.label, self.right)\n",
214 | "\n",
215 | "\n",
216 | "class Tree:\n",
217 | "\n",
218 | "    def __init__(self, treeString, openChar='(', closeChar=')'):\n",
219 | "        tokens = []\n",
220 | "        self.open = openChar\n",
221 | "        self.close = closeChar\n",
222 | "        for toks in treeString.strip().split():\n",
223 | "            tokens += list(toks)\n",
224 | "        self.root = self.parse(tokens)\n",
225 | "        # get list of labels as obtained through a post-order traversal\n",
226 | "        self.labels = get_labels(self.root)\n",
227 | "        self.num_words = len(self.labels)\n",
228 | "\n",
229 | "    def parse(self, tokens, parent=None):\n",
230 | "        assert tokens[0] == self.open, \"Malformed tree\"\n",
231 | "        assert tokens[-1] == self.close, \"Malformed tree\"\n",
232 | "\n",
233 | "        split = 2 # position after open and label\n",
234 | "        countOpen = countClose = 0\n",
235 | "\n",
236 | "        if tokens[split] == self.open:\n",
237 | "            countOpen += 1\n",
238 | "            split += 1\n",
239 | "        # Find where left child and right child split\n",
240 | "        while countOpen != countClose:\n",
241 | "            if tokens[split] == self.open:\n",
242 | "                countOpen += 1\n",
243 | "            if tokens[split] == self.close:\n",
244 | "                countClose += 1\n",
245 | "            split += 1\n",
246 | "\n",
247 | "        # New node\n",
248 | "        node = Node(int(tokens[1])) # zero index labels\n",
249 | "\n",
250 | "        node.parent = parent\n",
251 | "\n",
252 | "        # leaf Node\n",
253 | "        if countOpen == 0:\n",
254 | "            node.word = ''.join(tokens[2:-1]).lower() # lower case?\n",
255 | "            node.isLeaf = True\n",
256 | "            return node\n",
257 | "\n",
258 | "        node.left = self.parse(tokens[2:split], parent=node)\n",
259 | "        node.right = self.parse(tokens[split:-1], parent=node)\n",
260 | "\n",
261 | "        return node\n",
262 | "\n",
263 | "    def get_words(self):\n",
264 | "        leaves = getLeaves(self.root)\n",
265 | "        words = [node.word for node in leaves]\n",
266 | "        return words\n",
267 | "\n",
268 | "def get_labels(node):\n",
269 | "    if node is None:\n",
270 | "        return []\n",
271 | "    return get_labels(node.left) + get_labels(node.right) + [node.label]\n",
272 | "\n",
273 | "def getLeaves(node):\n",
274 | "    if node is None:\n",
275 | "        return []\n",
276 | "    if node.isLeaf:\n",
277 | "        return [node]\n",
278 | "    else:\n",
279 | "        return getLeaves(node.left) + getLeaves(node.right)\n",
280 | "\n",
281 | " \n",
282 | "def loadTrees(dataSet='train'):\n",
283 | "    \"\"\"\n",
284 | "    Loads training trees. Maps leaf node words to word ids.\n",
285 | "    \"\"\"\n",
286 | "    file = '../dataset/trees/%s.txt' % dataSet\n",
287 | "    print(\"Loading %s trees..\" % dataSet)\n",
288 | "    with open(file, 'r',encoding='utf-8') as fid:\n",
289 | "        trees = [Tree(l) for l in fid.readlines()]\n",
290 | "\n",
291 | "    return trees"
292 | ]
293 | },
294 | {
295 | "cell_type": "code",
296 | "execution_count": 11,
297 | "metadata": {
298 | "collapsed": false
299 | },
300 | "outputs": [
301 | {
302 | "name": "stdout",
303 | "output_type": "stream",
304 | "text": [
305 | "Loading train trees..\n"
306 | ]
307 | }
308 | ],
309 | "source": [
310 | "train_data = loadTrees('train')"
311 | ]
312 | },
313 | {
314 | "cell_type": "markdown",
315 | "metadata": {},
316 | "source": [
317 | "### Build Vocab "
318 | ]
319 | },
320 | {
321 | "cell_type": "code",
322 | "execution_count": 12,
323 | "metadata": {
324 | "collapsed": false
325 | },
326 | "outputs": [],
327 | "source": [
328 | "vocab = list(set(flatten([t.get_words() for t in train_data])))"
329 | ]
330 | },
331 | {
332 | "cell_type": "code",
333 | "execution_count": 13,
334 | "metadata": {
335 | "collapsed": true
336 | },
337 | "outputs": [],
338 | "source": [
339 | "word2index={'':0}\n",
340 | "for vo in vocab:\n",
341 | "    if vo not in word2index.keys():\n",
342 | "        word2index[vo]=len(word2index)\n",
343 | " \n",
344 | "index2word = {v:k for k,v in word2index.items()}"
345 | ]
346 | },
347 | {
348 | "cell_type": "markdown",
349 | "metadata": {},
350 | "source": [
351 | "## Modeling "
352 | ]
353 | },
354 | {
355 | "cell_type": "markdown",
356 | "metadata": {},
357 | "source": [
358 | "\n",
359 | "borrowed image from https://nlp.stanford.edu/~socherr/EMNLP2013_RNTN.pdf"
360 | ]
361 | },
362 | {
363 | "cell_type": "code",
364 | "execution_count": 14,
365 | "metadata": {
366 | "collapsed": false
367 | },
368 | "outputs": [],
369 | "source": [
370 | "class RNTN(nn.Module):\n",
371 | " \n",
372 | "    def __init__(self,word2index,hidden_size,output_size):\n",
373 | "        super(RNTN,self).__init__()\n",
374 | " \n",
375 | "        self.word2index = word2index\n",
376 | "        self.embed = nn.Embedding(len(word2index),hidden_size)\n",
377 | "#         self.V = nn.ModuleList([nn.Linear(hidden_size*2,hidden_size*2) for _ in range(hidden_size)])\n",
378 | "#         self.W = nn.Linear(hidden_size*2,hidden_size)\n",
379 | "        self.V = nn.ParameterList([nn.Parameter(torch.randn(hidden_size*2,hidden_size*2)) for _ in range(hidden_size)]) # Tensor\n",
380 | "        self.W = nn.Parameter(torch.randn(hidden_size*2,hidden_size))\n",
381 | "        self.b = nn.Parameter(torch.randn(1,hidden_size))\n",
382 | "#         self.W_out = nn.Parameter(torch.randn(hidden_size,output_size))\n",
383 | "        self.W_out = nn.Linear(hidden_size,output_size)\n",
384 | " \n",
385 | "    def init_weight(self):\n",
386 | "        nn.init.xavier_uniform(self.embed.state_dict()['weight'])\n",
387 | "        nn.init.xavier_uniform(self.W_out.state_dict()['weight'])\n",
388 | "        for param in self.V.parameters():\n",
389 | "            nn.init.xavier_uniform(param)\n",
390 | "        nn.init.xavier_uniform(self.W)\n",
391 | "        self.b.data.fill_(0)\n",
392 | "#         nn.init.xavier_uniform(self.W_out)\n",
393 | " \n",
394 | "    def tree_propagation(self,node):\n",
395 | " \n",
396 | "        recursive_tensor = OrderedDict()\n",
397 | "        current=None\n",
398 | "        if node.isLeaf:\n",
399 | "            tensor = Variable(LongTensor([self.word2index[node.word]])) if node.word in self.word2index.keys() \\\n",
400 | "                          else Variable(LongTensor([self.word2index['']]))\n",
401 | "            current = self.embed(tensor) # 1xD\n",
402 | "        else:\n",
403 | "            recursive_tensor.update(self.tree_propagation(node.left))\n",
404 | "            recursive_tensor.update(self.tree_propagation(node.right))\n",
405 | " \n",
406 | "            concated = torch.cat([recursive_tensor[node.left],recursive_tensor[node.right]],1) # 1x2D\n",
407 | "            xVx=[] \n",
408 | "            for i,v in enumerate(self.V):\n",
409 | "#                 xVx.append(torch.matmul(v(concated),concated.transpose(0,1)))\n",
410 | "                xVx.append(torch.matmul(torch.matmul(concated,v),concated.transpose(0,1)))\n",
411 | " \n",
412 | "            xVx = torch.cat(xVx,1) # 1xD\n",
413 | "#             Wx = self.W(concated)\n",
414 | "            Wx = torch.matmul(concated,self.W) # 1xD\n",
415 | "\n",
416 | "            current = F.tanh(xVx+Wx+self.b) # 1xD\n",
417 | "        recursive_tensor[node]=current\n",
418 | "        return recursive_tensor\n",
419 | " \n",
420 | "    def forward(self,Trees,root_only=False):\n",
421 | " \n",
422 | "        propagated=[]\n",
423 | "        if not isinstance(Trees,list):\n",
424 | "            Trees = [Trees]\n",
425 | " \n",
426 | "        for Tree in Trees:\n",
427 | "            recursive_tensor = self.tree_propagation(Tree.root)\n",
428 | "            if root_only:\n",
429 | "                recursive_tensor = recursive_tensor[Tree.root]\n",
430 | "                propagated.append(recursive_tensor)\n",
431 | "            else:\n",
432 | "                recursive_tensor = [tensor for node,tensor in recursive_tensor.items()]\n",
433 | "                propagated.extend(recursive_tensor)\n",
434 | " \n",
435 | "        propagated = torch.cat(propagated) # (num_of_node in batch, D)\n",
436 | " \n",
437 | "#         return F.log_softmax(propagated.matmul(self.W_out))\n",
438 | "        return F.log_softmax(self.W_out(propagated))"
439 | ]
440 | },
441 | {
442 | "cell_type": "markdown",
443 | "metadata": {},
444 | "source": [
445 | "## Training "
446 | ]
447 | },
448 | {
449 | "cell_type": "markdown",
450 | "metadata": {},
451 | "source": [
452 | "Training takes a while. Because the model builds its computational graph dynamically for each tree, its computation is difficult to batch."
453 | ]
454 | },
455 | {
456 | "cell_type": "code",
457 | "execution_count": 15,
458 | "metadata": {
459 | "collapsed": true
460 | },
461 | "outputs": [],
462 | "source": [
463 | "HIDDEN_SIZE = 30\n",
464 | "ROOT_ONLY = False\n",
465 | "BATCH_SIZE = 20\n",
466 | "EPOCH = 20\n",
467 | "LR = 0.01\n",
468 | "LAMBDA = 1e-5\n",
469 | "RESCHEDULED=False"
470 | ]
471 | },
472 | {
473 | "cell_type": "code",
474 | "execution_count": 18,
475 | "metadata": {
476 | "collapsed": false
477 | },
478 | "outputs": [],
479 | "source": [
480 | "model = RNTN(word2index,HIDDEN_SIZE,5)\n",
481 | "model.init_weight()\n",
482 | "if USE_CUDA:\n",
483 | " model = model.cuda()\n",
484 | "\n",
485 | "loss_function = nn.CrossEntropyLoss()\n",
486 | "optimizer = optim.Adam(model.parameters(),lr=LR)"
487 | ]
488 | },
489 | {
490 | "cell_type": "code",
491 | "execution_count": 19,
492 | "metadata": {
493 | "collapsed": false
494 | },
495 | "outputs": [
496 | {
497 | "name": "stdout",
498 | "output_type": "stream",
499 | "text": [
500 | "[0/20] mean_loss : 1.62\n",
501 | "[0/20] mean_loss : 1.25\n",
502 | "[0/20] mean_loss : 0.95\n",
503 | "[0/20] mean_loss : 0.90\n",
504 | "[0/20] mean_loss : 0.88\n",
505 | "[1/20] mean_loss : 0.88\n",
506 | "[1/20] mean_loss : 0.84\n",
507 | "[1/20] mean_loss : 0.83\n",
508 | "[1/20] mean_loss : 0.82\n",
509 | "[1/20] mean_loss : 0.82\n",
510 | "[2/20] mean_loss : 0.81\n",
511 | "[2/20] mean_loss : 0.79\n",
512 | "[2/20] mean_loss : 0.78\n",
513 | "[2/20] mean_loss : 0.76\n",
514 | "[2/20] mean_loss : 0.75\n",
515 | "[3/20] mean_loss : 0.68\n",
516 | "[3/20] mean_loss : 0.73\n",
517 | "[3/20] mean_loss : 0.74\n",
518 | "[3/20] mean_loss : 0.72\n",
519 | "[3/20] mean_loss : 0.72\n",
520 | "[4/20] mean_loss : 0.74\n",
521 | "[4/20] mean_loss : 0.69\n",
522 | "[4/20] mean_loss : 0.69\n",
523 | "[4/20] mean_loss : 0.68\n",
524 | "[4/20] mean_loss : 0.67\n",
525 | "[5/20] mean_loss : 0.73\n",
526 | "[5/20] mean_loss : 0.65\n",
527 | "[5/20] mean_loss : 0.64\n",
528 | "[5/20] mean_loss : 0.64\n",
529 | "[5/20] mean_loss : 0.65\n",
530 | "[6/20] mean_loss : 0.67\n",
531 | "[6/20] mean_loss : 0.62\n",
532 | "[6/20] mean_loss : 0.62\n",
533 | "[6/20] mean_loss : 0.62\n",
534 | "[6/20] mean_loss : 0.62\n",
535 | "[7/20] mean_loss : 0.57\n",
536 | "[7/20] mean_loss : 0.59\n",
537 | "[7/20] mean_loss : 0.59\n",
538 | "[7/20] mean_loss : 0.59\n",
539 | "[7/20] mean_loss : 0.59\n",
540 | "[8/20] mean_loss : 0.60\n",
541 | "[8/20] mean_loss : 0.58\n",
542 | "[8/20] mean_loss : 0.59\n",
543 | "[8/20] mean_loss : 0.60\n",
544 | "[8/20] mean_loss : 0.60\n",
545 | "[9/20] mean_loss : 0.52\n",
546 | "[9/20] mean_loss : 0.58\n",
547 | "[9/20] mean_loss : 0.60\n",
548 | "[9/20] mean_loss : 0.59\n",
549 | "[9/20] mean_loss : 0.59\n",
550 | "[10/20] mean_loss : 0.56\n",
551 | "[10/20] mean_loss : 0.56\n",
552 | "[10/20] mean_loss : 0.56\n",
553 | "[10/20] mean_loss : 0.56\n",
554 | "[10/20] mean_loss : 0.56\n",
555 | "[11/20] mean_loss : 0.52\n",
556 | "[11/20] mean_loss : 0.54\n",
557 | "[11/20] mean_loss : 0.54\n",
558 | "[11/20] mean_loss : 0.54\n",
559 | "[11/20] mean_loss : 0.55\n",
560 | "[12/20] mean_loss : 0.55\n",
561 | "[12/20] mean_loss : 0.53\n",
562 | "[12/20] mean_loss : 0.53\n",
563 | "[12/20] mean_loss : 0.53\n",
564 | "[12/20] mean_loss : 0.53\n",
565 | "[13/20] mean_loss : 0.59\n",
566 | "[13/20] mean_loss : 0.52\n",
567 | "[13/20] mean_loss : 0.52\n",
568 | "[13/20] mean_loss : 0.53\n",
569 | "[13/20] mean_loss : 0.53\n",
570 | "[14/20] mean_loss : 0.49\n",
571 | "[14/20] mean_loss : 0.51\n",
572 | "[14/20] mean_loss : 0.51\n",
573 | "[14/20] mean_loss : 0.52\n",
574 | "[14/20] mean_loss : 0.52\n",
575 | "[15/20] mean_loss : 0.43\n",
576 | "[15/20] mean_loss : 0.51\n",
577 | "[15/20] mean_loss : 0.51\n",
578 | "[15/20] mean_loss : 0.51\n",
579 | "[15/20] mean_loss : 0.51\n",
580 | "[16/20] mean_loss : 0.46\n",
581 | "[16/20] mean_loss : 0.50\n",
582 | "[16/20] mean_loss : 0.50\n",
583 | "[16/20] mean_loss : 0.50\n",
584 | "[16/20] mean_loss : 0.50\n",
585 | "[17/20] mean_loss : 0.50\n",
586 | "[17/20] mean_loss : 0.50\n",
587 | "[17/20] mean_loss : 0.50\n",
588 | "[17/20] mean_loss : 0.50\n",
589 | "[17/20] mean_loss : 0.51\n",
590 | "[18/20] mean_loss : 0.46\n",
591 | "[18/20] mean_loss : 0.50\n",
592 | "[18/20] mean_loss : 0.50\n",
593 | "[18/20] mean_loss : 0.49\n",
594 | "[18/20] mean_loss : 0.49\n",
595 | "[19/20] mean_loss : 0.49\n",
596 | "[19/20] mean_loss : 0.49\n",
597 | "[19/20] mean_loss : 0.49\n",
598 | "[19/20] mean_loss : 0.50\n",
599 | "[19/20] mean_loss : 0.50\n"
600 | ]
601 | }
602 | ],
603 | "source": [
604 | "for epoch in range(EPOCH):\n",
605 | "    losses=[]\n",
606 | " \n",
607 | "    # learning rate annealing\n",
608 | "    if RESCHEDULED==False and epoch==EPOCH//2:\n",
609 | "        LR=LR*0.1\n",
610 | "        optimizer = optim.Adam(model.parameters(),lr=LR,weight_decay=LAMBDA) # L2 norm\n",
611 | "        RESCHEDULED=True\n",
612 | " \n",
613 | "    for i, batch in enumerate(getBatch(BATCH_SIZE,train_data)):\n",
614 | " \n",
615 | "        if ROOT_ONLY:\n",
616 | "            labels = [tree.labels[-1] for tree in batch]\n",
617 | "            labels = Variable(LongTensor(labels))\n",
618 | "        else:\n",
619 | "            labels = [tree.labels for tree in batch]\n",
620 | "            labels = Variable(LongTensor(flatten(labels)))\n",
621 | " \n",
622 | "        model.zero_grad()\n",
623 | "        preds = model(batch,ROOT_ONLY)\n",
624 | " \n",
625 | "        loss = loss_function(preds,labels)\n",
626 | "        losses.append(loss.data.tolist()[0])\n",
627 | " \n",
628 | "        loss.backward()\n",
629 | "        optimizer.step()\n",
630 | " \n",
631 | "        if i % 100==0:\n",
632 | "            print('[%d/%d] mean_loss : %.2f' % (epoch,EPOCH,np.mean(losses)))\n",
633 | "            losses=[]\n",
634 | " "
635 | ]
636 | },
637 | {
638 | "cell_type": "markdown",
639 | "metadata": {},
640 | "source": [
641 | "Convergence is unstable and depends on the initial parameter values. I had to retrain 5 to 6 times to get this result."
642 | ]
643 | },
644 | {
645 | "cell_type": "markdown",
646 | "metadata": {},
647 | "source": [
648 | "## Test"
649 | ]
650 | },
651 | {
652 | "cell_type": "code",
653 | "execution_count": 20,
654 | "metadata": {
655 | "collapsed": false
656 | },
657 | "outputs": [
658 | {
659 | "name": "stdout",
660 | "output_type": "stream",
661 | "text": [
662 | "Loading test trees..\n"
663 | ]
664 | }
665 | ],
666 | "source": [
667 | "test_data = loadTrees('test')"
668 | ]
669 | },
670 | {
671 | "cell_type": "code",
672 | "execution_count": 21,
673 | "metadata": {
674 | "collapsed": true
675 | },
676 | "outputs": [],
677 | "source": [
678 | "accuracy=0\n",
679 | "num_node=0"
680 | ]
681 | },
682 | {
683 | "cell_type": "markdown",
684 | "metadata": {},
685 | "source": [
686 | "### Fine-grained all"
687 | ]
688 | },
689 | {
690 | "cell_type": "markdown",
691 | "metadata": {},
692 | "source": [
693 | "In the paper, they achieved 80.2% accuracy. "
694 | ]
695 | },
696 | {
697 | "cell_type": "code",
698 | "execution_count": 23,
699 | "metadata": {
700 | "collapsed": false
701 | },
702 | "outputs": [
703 | {
704 | "name": "stdout",
705 | "output_type": "stream",
706 | "text": [
707 | "79.33705899068254\n"
708 | ]
709 | }
710 | ],
711 | "source": [
712 | "for test in test_data:\n",
713 | "    model.zero_grad()\n",
714 | "    preds = model(test,ROOT_ONLY)\n",
715 | "    labels = test.labels[-1:] if ROOT_ONLY else test.labels\n",
716 | "    for pred,label in zip(preds.max(1)[1].data.tolist(),labels):\n",
717 | "        num_node+=1\n",
718 | "        if pred==label:\n",
719 | "            accuracy+=1\n",
720 | "\n",
721 | "print(accuracy/num_node*100)"
722 | ]
723 | },
724 | {
725 | "cell_type": "markdown",
726 | "metadata": {},
727 | "source": [
728 | "## TODO "
729 | ]
730 | },
731 | {
732 | "cell_type": "markdown",
733 | "metadata": {},
734 | "source": [
735 | "* https://github.com/nearai/pytorch-tools # Dynamic batch using TensorFold"
736 | ]
737 | },
738 | {
739 | "cell_type": "markdown",
740 | "metadata": {
741 | "collapsed": true
742 | },
743 | "source": [
744 | "## Further topics "
745 | ]
746 | },
747 | {
748 | "cell_type": "markdown",
749 | "metadata": {},
750 | "source": [
751 | "* Improved Semantic Representations From Tree-Structured Long Short-Term Memory Networks\n",
752 | "* A Fast Unified Model for Parsing and Sentence Understanding(SPINN)\n",
753 | "* Posting about SPINN"
754 | ]
755 | },
756 | {
757 | "cell_type": "code",
758 | "execution_count": null,
759 | "metadata": {
760 | "collapsed": true
761 | },
762 | "outputs": [],
763 | "source": []
764 | }
765 | ],
766 | "metadata": {
767 | "kernelspec": {
768 | "display_name": "Python 3",
769 | "language": "python",
770 | "name": "python3"
771 | },
772 | "language_info": {
773 | "codemirror_mode": {
774 | "name": "ipython",
775 | "version": 3
776 | },
777 | "file_extension": ".py",
778 | "mimetype": "text/x-python",
779 | "name": "python",
780 | "nbconvert_exporter": "python",
781 | "pygments_lexer": "ipython3",
782 | "version": "3.5.2"
783 | }
784 | },
785 | "nbformat": 4,
786 | "nbformat_minor": 2
787 | }
788 |
--------------------------------------------------------------------------------
/notebooks/10.Dynamic-Memory-Network-for-Question-Answering.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "# 10. Dynamic Memory Networks for Question Answering"
8 | ]
9 | },
10 | {
11 | "cell_type": "markdown",
12 | "metadata": {},
13 | "source": [
14 | "I recommend you take a look at these materials first."
15 | ]
16 | },
17 | {
18 | "cell_type": "markdown",
19 | "metadata": {},
20 | "source": [
21 | "* http://web.stanford.edu/class/cs224n/lectures/cs224n-2017-lecture16-DMN-QA.pdf\n",
22 | "* https://arxiv.org/abs/1506.07285"
23 | ]
24 | },
25 | {
26 | "cell_type": "code",
27 | "execution_count": 1,
28 | "metadata": {
29 | "collapsed": true
30 | },
31 | "outputs": [],
32 | "source": [
33 | "import torch\n",
34 | "import torch.nn as nn\n",
35 | "from torch.autograd import Variable\n",
36 | "import torch.optim as optim\n",
37 | "import torch.nn.functional as F\n",
38 | "import nltk\n",
39 | "import random\n",
40 | "import numpy as np\n",
41 | "from collections import Counter, OrderedDict\n",
42 | "import nltk\n",
43 | "from copy import deepcopy\n",
44 | "import os\n",
45 | "import re\n",
46 | "import unicodedata\n",
47 | "flatten = lambda l: [item for sublist in l for item in sublist]\n",
48 | "\n",
49 | "from torch.nn.utils.rnn import PackedSequence,pack_padded_sequence"
50 | ]
51 | },
52 | {
53 | "cell_type": "code",
54 | "execution_count": 2,
55 | "metadata": {
56 | "collapsed": true
57 | },
58 | "outputs": [],
59 | "source": [
60 | "USE_CUDA = torch.cuda.is_available()\n",
61 | "\n",
62 | "FloatTensor = torch.cuda.FloatTensor if USE_CUDA else torch.FloatTensor\n",
63 | "LongTensor = torch.cuda.LongTensor if USE_CUDA else torch.LongTensor\n",
64 | "ByteTensor = torch.cuda.ByteTensor if USE_CUDA else torch.ByteTensor"
65 | ]
66 | },
67 | {
68 | "cell_type": "code",
69 | "execution_count": 3,
70 | "metadata": {
71 | "collapsed": true
72 | },
73 | "outputs": [],
74 | "source": [
75 | "def getBatch(batch_size,train_data):\n",
76 | "    random.shuffle(train_data)\n",
77 | "    sindex=0\n",
78 | "    eindex=batch_size\n",
79 | "    while eindex < len(train_data):\n",
80 | "        batch = train_data[sindex:eindex]\n",
81 | "        temp = eindex\n",
82 | "        eindex = eindex+batch_size\n",
83 | "        sindex = temp\n",
84 | "        yield batch\n",
85 | " \n",
86 | "    if eindex >= len(train_data):\n",
87 | "        batch = train_data[sindex:]\n",
88 | "        yield batch"
89 | ]
90 | },
91 | {
92 | "cell_type": "code",
93 | "execution_count": 4,
94 | "metadata": {
95 | "collapsed": true
96 | },
97 | "outputs": [],
98 | "source": [
99 | "def pad_to_batch(batch,w_to_ix): # for bAbI dataset\n",
100 | "    fact,q,a = list(zip(*batch))\n",
101 | "    max_fact = max([len(f) for f in fact])\n",
102 | "    max_len = max([f.size(1) for f in flatten(fact)])\n",
103 | "    max_q = max([qq.size(1) for qq in q])\n",
104 | "    max_a = max([aa.size(1) for aa in a])\n",
105 | " \n",
106 | "    facts,fact_masks,q_p,a_p=[],[],[],[]\n",
107 | "    for i in range(len(batch)):\n",
108 | "        fact_p_t=[]\n",
109 | "        for j in range(len(fact[i])):\n",
110 | "            if fact[i][j].size(1)<max_len:\n",
111 | "                fact_p_t.append(torch.cat([fact[i][j],Variable(LongTensor([w_to_ix['<PAD>']]*(max_len-fact[i][j].size(1)))).view(1,-1)],1))\n",
112 | "            else:\n",
113 | "                fact_p_t.append(fact[i][j])\n",
114 | "\n",
115 | "        while len(fact_p_t)<max_fact:\n",
116 | "            fact_p_t.append(Variable(LongTensor([w_to_ix['<PAD>']]*max_len)).view(1,-1))\n",
117 | "\n",
118 | "        fact_p_t = torch.cat(fact_p_t)\n",
119 | "        facts.append(fact_p_t)\n",
120 | "        fact_masks.append(torch.cat([Variable(ByteTensor(tuple(map(lambda s: s ==0, t.data))),volatile=False) for t in fact_p_t]).view(fact_p_t.size(0),-1))\n",
121 | "\n",
122 | "        if q[i].size(1)<max_q:\n",
123 | "            q_p.append(torch.cat([q[i],Variable(LongTensor([w_to_ix['<PAD>']]*(max_q-q[i].size(1)))).view(1,-1)],1))\n",
124 | "        else:\n",
125 | "            q_p.append(q[i])\n",
126 | "\n",
127 | "        if a[i].size(1)<max_a:\n",
128 | "            a_p.append(torch.cat([a[i],Variable(LongTensor([w_to_ix['<PAD>']]*(max_a-a[i].size(1)))).view(1,-1)],1))\n",
129 | "        else:\n",
130 | "            a_p.append(a[i])\n",
131 | "\n",
132 | "    questions = torch.cat(q_p)\n",
133 | "    answers = torch.cat(a_p)\n",
134 | "    question_masks = torch.cat([Variable(ByteTensor(tuple(map(lambda s: s ==0, t.data))),volatile=False) for t in questions]).view(questions.size(0),-1)\n",
135 | " \n",
136 | "    return facts, fact_masks, questions, question_masks, answers"
137 | ]
138 | },
139 | {
140 | "cell_type": "code",
141 | "execution_count": 5,
142 | "metadata": {
143 | "collapsed": true
144 | },
145 | "outputs": [],
146 | "source": [
147 | "def prepare_sequence(seq, to_index):\n",
148 | "    idxs = list(map(lambda w: to_index[w] if w in to_index.keys() else to_index[\"<UNK>\"], seq))\n",
149 | "    return Variable(LongTensor(idxs))"
150 | ]
151 | },
152 | {
153 | "cell_type": "markdown",
154 | "metadata": {},
155 | "source": [
156 | "## Data load and Preprocessing "
157 | ]
158 | },
159 | {
160 | "cell_type": "markdown",
161 | "metadata": {},
162 | "source": [
163 | "### bAbI dataset(https://research.fb.com/downloads/babi/)"
164 | ]
165 | },
166 | {
167 | "cell_type": "code",
168 | "execution_count": 24,
169 | "metadata": {
170 | "collapsed": false
171 | },
172 | "outputs": [],
173 | "source": [
174 | "def bAbI_data_load(path):\n",
175 | "    try:\n",
176 | "        data = open(path).readlines()\n",
177 | "    except:\n",
178 | "        print(\"Such a file does not exist at {}\".format(path))\n",
179 | "        return None\n",
180 | " \n",
181 | "    data = [d[:-1] for d in data]\n",
182 | "    data_p=[]\n",
183 | "    fact=[]\n",
184 | "    qa=[]\n",
185 | "    try:\n",
186 | "        for d in data:\n",
187 | "            index = d.split(' ')[0]\n",
188 | "            if index=='1':\n",
189 | "                fact=[]\n",
190 | "                qa=[]\n",
191 | "            if '?' in d:\n",
192 | "                temp = d.split('\\t')\n",
193 | "                q = temp[0].strip().replace('?','').split(' ')[1:]+['?']\n",
194 | "                a = temp[1].split()+['</s>']\n",
195 | "                stemp = deepcopy(fact)\n",
196 | "                data_p.append([stemp,q,a])\n",
197 | "            else:\n",
198 | "                tokens = d.replace('.','').split(' ')[1:]+['</s>']\n",
199 | "                fact.append(tokens)\n",
200 | "    except:\n",
201 | "        print(\"Please check the data is right\")\n",
202 | "        return None\n",
203 | "    return data_p"
204 | ]
205 | },
206 | {
207 | "cell_type": "code",
208 | "execution_count": 25,
209 | "metadata": {
210 | "collapsed": true
211 | },
212 | "outputs": [],
213 | "source": [
214 | "train_data = bAbI_data_load('../dataset/corpus/bAbI/en-10k/qa5_three-arg-relations_train.txt')"
215 | ]
216 | },
217 | {
218 | "cell_type": "code",
219 | "execution_count": 26,
220 | "metadata": {
221 | "collapsed": false
222 | },
223 | "outputs": [
224 | {
225 | "data": {
226 | "text/plain": [
227 | "[[['Bill', 'travelled', 'to', 'the', 'office', '</s>'],\n",
228 | " ['Bill', 'picked', 'up', 'the', 'football', 'there', '</s>'],\n",
229 | " ['Bill', 'went', 'to', 'the', 'bedroom', '</s>'],\n",
230 | " ['Bill', 'gave', 'the', 'football', 'to', 'Fred', '</s>']],\n",
231 | " ['What', 'did', 'Bill', 'give', 'to', 'Fred', '?'],\n",
232 | " ['football', '</s>']]"
233 | ]
234 | },
235 | "execution_count": 26,
236 | "metadata": {},
237 | "output_type": "execute_result"
238 | }
239 | ],
240 | "source": [
241 | "train_data[0]"
242 | ]
243 | },
244 | {
245 | "cell_type": "code",
246 | "execution_count": 11,
247 | "metadata": {
248 | "collapsed": true
249 | },
250 | "outputs": [],
251 | "source": [
252 | "fact,q,a = list(zip(*train_data))"
253 | ]
254 | },
255 | {
256 | "cell_type": "code",
257 | "execution_count": 12,
258 | "metadata": {
259 | "collapsed": true
260 | },
261 | "outputs": [],
262 | "source": [
263 | "vocab = list(set(flatten(flatten(fact))+flatten(q)+flatten(a)))"
264 | ]
265 | },
266 | {
267 | "cell_type": "code",
268 | "execution_count": 13,
269 | "metadata": {
270 | "collapsed": true
271 | },
272 | "outputs": [],
273 | "source": [
274 | "word2index={'<PAD>':0,'<UNK>':1,'<s>':2,'</s>':3}\n",
275 | "for vo in vocab:\n",
276 | "    if vo not in word2index.keys():\n",
277 | "        word2index[vo]=len(word2index)\n",
278 | "index2word = {v:k for k,v in word2index.items()}"
279 | ]
280 | },
281 | {
282 | "cell_type": "code",
283 | "execution_count": 14,
284 | "metadata": {
285 | "collapsed": false
286 | },
287 | "outputs": [
288 | {
289 | "data": {
290 | "text/plain": [
291 | "44"
292 | ]
293 | },
294 | "execution_count": 14,
295 | "metadata": {},
296 | "output_type": "execute_result"
297 | }
298 | ],
299 | "source": [
300 | "len(word2index)"
301 | ]
302 | },
303 | {
304 | "cell_type": "code",
305 | "execution_count": 15,
306 | "metadata": {
307 | "collapsed": true
308 | },
309 | "outputs": [],
310 | "source": [
311 | "for t in train_data:\n",
312 | "    for i,fact in enumerate(t[0]):\n",
313 | "        t[0][i] = prepare_sequence(fact,word2index).view(1,-1)\n",
314 | " \n",
315 | "    t[1] = prepare_sequence(t[1],word2index).view(1,-1)\n",
316 | "    t[2] = prepare_sequence(t[2],word2index).view(1,-1)"
317 | ]
318 | },
319 | {
320 | "cell_type": "markdown",
321 | "metadata": {},
322 | "source": [
323 | "## Modeling "
324 | ]
325 | },
326 | {
327 | "cell_type": "markdown",
328 | "metadata": {},
329 | "source": [
330 | "\n",
331 | "borrowed image from https://arxiv.org/pdf/1506.07285.pdf"
332 | ]
333 | },
334 | {
335 | "cell_type": "code",
336 | "execution_count": 16,
337 | "metadata": {
338 | "collapsed": true
339 | },
340 | "outputs": [],
341 | "source": [
342 | "class DMN(nn.Module):\n",
343 | " def __init__(self, input_size,hidden_size,output_size,dropout_p=0.1):\n",
344 | " super(DMN, self).__init__()\n",
345 | " \n",
346 | " self.hidden_size = hidden_size\n",
347 | " self.embed = nn.Embedding(input_size, hidden_size, padding_idx=0) #sparse=True)\n",
348 | " self.input_gru = nn.GRU(hidden_size, hidden_size,batch_first=True)\n",
349 | " self.question_gru = nn.GRU(hidden_size, hidden_size,batch_first=True)\n",
350 | " \n",
351 | " self.gate = nn.Sequential(\n",
352 | " nn.Linear(hidden_size*4,hidden_size),\n",
353 | " nn.Tanh(),\n",
354 | " nn.Linear(hidden_size,1),\n",
355 | " nn.Sigmoid()\n",
356 | " )\n",
357 | " \n",
358 | " self.attention_grucell = nn.GRUCell(hidden_size, hidden_size)\n",
359 | " self.memory_grucell = nn.GRUCell(hidden_size,hidden_size)\n",
360 | " self.answer_grucell = nn.GRUCell(hidden_size*2, hidden_size)\n",
361 | " self.answer_fc = nn.Linear(hidden_size,output_size)\n",
362 | " \n",
363 | " self.dropout = nn.Dropout(dropout_p)\n",
364 | " \n",
365 | " def init_hidden(self,inputs):\n",
366 | " hidden = Variable(torch.zeros(1,inputs.size(0),self.hidden_size))\n",
367 | " return hidden.cuda() if USE_CUDA else hidden\n",
368 | " \n",
369 | " def init_weight(self):\n",
370 | " nn.init.xavier_uniform(self.embed.state_dict()['weight'])\n",
371 | " \n",
372 | " for name, param in self.input_gru.state_dict().items():\n",
373 | " if 'weight' in name: nn.init.xavier_normal(param)\n",
374 | " for name, param in self.question_gru.state_dict().items():\n",
375 | " if 'weight' in name: nn.init.xavier_normal(param)\n",
376 | " for name, param in self.gate.state_dict().items():\n",
377 | " if 'weight' in name: nn.init.xavier_normal(param)\n",
378 | " for name, param in self.attention_grucell.state_dict().items():\n",
379 | " if 'weight' in name: nn.init.xavier_normal(param)\n",
380 | " for name, param in self.memory_grucell.state_dict().items():\n",
381 | " if 'weight' in name: nn.init.xavier_normal(param)\n",
382 | " for name, param in self.answer_grucell.state_dict().items():\n",
383 | " if 'weight' in name: nn.init.xavier_normal(param)\n",
384 | " \n",
385 | " nn.init.xavier_normal(self.answer_fc.state_dict()['weight'])\n",
386 | " self.answer_fc.bias.data.fill_(0)\n",
387 | " \n",
388 | " def forward(self,facts,fact_masks,questions,question_masks,num_decode,episodes=3,is_training=False):\n",
389 | " \"\"\"\n",
390 | " facts : (B,T_C,T_I) / LongTensor in List # batch_size, num_of_facts, length_of_each_fact(padded)\n",
391 | " fact_masks : (B,T_C,T_I) / ByteTensor in List # batch_size, num_of_facts, length_of_each_fact(padded)\n",
392 | " questions : (B,T_Q) / LongTensor # batch_size, question_length\n",
393 | " question_masks : (B,T_Q) / ByteTensor # batch_size, question_length\n",
394 | " \"\"\"\n",
395 | " # Input Module\n",
396 | " C=[] # encoded facts\n",
397 | " for fact,fact_mask in zip(facts,fact_masks):\n",
398 | " embeds = self.embed(fact)\n",
399 | " if is_training:\n",
400 | " embeds = self.dropout(embeds)\n",
401 | " hidden = self.init_hidden(fact)\n",
402 | " outputs,hidden = self.input_gru(embeds,hidden)\n",
403 | " real_hidden=[]\n",
404 | "\n",
405 | " for i,o in enumerate(outputs): # B,T,D\n",
406 | " real_length = fact_mask[i].data.tolist().count(0) \n",
407 | " real_hidden.append(o[real_length-1])\n",
408 | "\n",
409 | " C.append(torch.cat(real_hidden).view(fact.size(0),-1).unsqueeze(0))\n",
410 | " \n",
411 | " encoded_facts = torch.cat(C) # B,T_C,D\n",
412 | " \n",
413 | " # Question Module\n",
414 | " embeds = self.embed(questions)\n",
415 | " if training:\n",
416 | " embeds = self.dropout(embeds)\n",
417 | " hidden = self.init_hidden(questions)\n",
418 | " outputs, hidden = self.question_gru(embeds,hidden)\n",
419 | " \n",
420 | " if isinstance(question_masks,torch.autograd.variable.Variable):\n",
421 | " real_question=[]\n",
422 | " for i,o in enumerate(outputs): # B,T,D\n",
423 | " real_length = question_masks[i].data.tolist().count(0) \n",
424 | " real_question.append(o[real_length-1])\n",
425 | " encoded_question = torch.cat(real_question).view(questions.size(0),-1) # B,D\n",
426 | " else: # for inference mode\n",
427 | " encoded_question = hidden.squeeze(0) # B,D\n",
428 | " \n",
429 | " # Episodic Memory Module\n",
430 | " memory = encoded_question\n",
431 | " T_C = encoded_facts.size(1)\n",
432 | " B = encoded_facts.size(0)\n",
433 | " for i in range(episodes):\n",
434 | " hidden = self.init_hidden(encoded_facts.transpose(0,1)[0]).squeeze(0) # B,D\n",
435 | " for t in range(T_C):\n",
436 | " #TODO: fact masking\n",
437 | " #TODO: gate function => softmax\n",
438 | " z = torch.cat([\n",
439 | " encoded_facts.transpose(0,1)[t]*encoded_question, # B,D , element-wise product\n",
440 | " encoded_facts.transpose(0,1)[t]*memory, # B,D , element-wise product\n",
441 | " torch.abs(encoded_facts.transpose(0,1)[t]-encoded_question), # B,D\n",
442 | " torch.abs(encoded_facts.transpose(0,1)[t]-memory) # B,D\n",
443 | " ],1)\n",
444 | " g_t = self.gate(z) # B,1 scalar\n",
445 | " hidden = g_t*self.attention_grucell(encoded_facts.transpose(0,1)[t],hidden) + (1-g_t)*hidden\n",
446 | " \n",
447 | " e = hidden\n",
448 | " memory = self.memory_grucell(e,memory)\n",
449 | " \n",
450 | " # Answer Module\n",
451 | " answer_hidden = memory\n",
452 | " start_decode = Variable(LongTensor([[word2index['<s>']]*memory.size(0)])).transpose(0,1)\n",
453 | " y_t_1 = self.embed(start_decode).squeeze(1) # B,D\n",
454 | " \n",
455 | " decodes=[]\n",
456 | " for t in range(num_decode):\n",
457 | " answer_hidden = self.answer_grucell(torch.cat([y_t_1,encoded_question],1),answer_hidden)\n",
458 | " decodes.append(F.log_softmax(self.answer_fc(answer_hidden)))\n",
459 | " return torch.cat(decodes,1).view(B*num_decode,-1)\n"
460 | ]
461 | },
462 | {
463 | "cell_type": "markdown",
464 | "metadata": {},
465 | "source": [
466 | "## Train "
467 | ]
468 | },
469 | {
470 | "cell_type": "markdown",
471 | "metadata": {},
472 | "source": [
473 | "Training takes a while if you use only the CPU."
474 | ]
475 | },
476 | {
477 | "cell_type": "code",
478 | "execution_count": 17,
479 | "metadata": {
480 | "collapsed": true
481 | },
482 | "outputs": [],
483 | "source": [
484 | "HIDDEN_SIZE=80\n",
485 | "BATCH_SIZE=64\n",
486 | "LR=0.001\n",
487 | "EPOCH=50\n",
488 | "NUM_EPISODE=3\n",
489 | "EARLY_STOPPING=False"
490 | ]
491 | },
492 | {
493 | "cell_type": "code",
494 | "execution_count": 18,
495 | "metadata": {
496 | "collapsed": false
497 | },
498 | "outputs": [],
499 | "source": [
500 | "model = DMN(len(word2index),HIDDEN_SIZE,len(word2index))\n",
501 | "model.init_weight()\n",
502 | "if USE_CUDA:\n",
503 | " model = model.cuda()\n",
504 | "\n",
505 | "loss_function = nn.CrossEntropyLoss(ignore_index=0)\n",
506 | "optimizer = optim.Adam(model.parameters(),lr=LR)"
507 | ]
508 | },
509 | {
510 | "cell_type": "code",
511 | "execution_count": 19,
512 | "metadata": {
513 | "collapsed": false
514 | },
515 | "outputs": [
516 | {
517 | "name": "stdout",
518 | "output_type": "stream",
519 | "text": [
520 | "[0/50] mean_loss : 3.86\n",
521 | "[0/50] mean_loss : 1.32\n",
522 | "[1/50] mean_loss : 0.68\n",
523 | "[1/50] mean_loss : 0.65\n",
524 | "[2/50] mean_loss : 0.62\n",
525 | "[2/50] mean_loss : 0.65\n",
526 | "[3/50] mean_loss : 0.65\n",
527 | "[3/50] mean_loss : 0.64\n",
528 | "[4/50] mean_loss : 0.60\n",
529 | "[4/50] mean_loss : 0.62\n",
530 | "[5/50] mean_loss : 0.63\n",
531 | "[5/50] mean_loss : 0.61\n",
532 | "[6/50] mean_loss : 0.60\n",
533 | "[6/50] mean_loss : 0.61\n",
534 | "[7/50] mean_loss : 0.63\n",
535 | "[7/50] mean_loss : 0.60\n",
536 | "[8/50] mean_loss : 0.62\n",
537 | "[8/50] mean_loss : 0.60\n",
538 | "[9/50] mean_loss : 0.58\n",
539 | "[9/50] mean_loss : 0.60\n",
540 | "[10/50] mean_loss : 0.60\n",
541 | "[10/50] mean_loss : 0.60\n",
542 | "[11/50] mean_loss : 0.62\n",
543 | "[11/50] mean_loss : 0.60\n",
544 | "[12/50] mean_loss : 0.61\n",
545 | "[12/50] mean_loss : 0.60\n",
546 | "[13/50] mean_loss : 0.57\n",
547 | "[13/50] mean_loss : 0.60\n",
548 | "[14/50] mean_loss : 0.59\n",
549 | "[14/50] mean_loss : 0.60\n",
550 | "[15/50] mean_loss : 0.61\n",
551 | "[15/50] mean_loss : 0.60\n",
552 | "[16/50] mean_loss : 0.59\n",
553 | "[16/50] mean_loss : 0.60\n",
554 | "[17/50] mean_loss : 0.59\n",
555 | "[17/50] mean_loss : 0.60\n",
556 | "[18/50] mean_loss : 0.51\n",
557 | "[18/50] mean_loss : 0.50\n",
558 | "[19/50] mean_loss : 0.44\n",
559 | "[19/50] mean_loss : 0.37\n",
560 | "[20/50] mean_loss : 0.30\n",
561 | "[20/50] mean_loss : 0.33\n",
562 | "[21/50] mean_loss : 0.31\n",
563 | "[21/50] mean_loss : 0.31\n",
564 | "[22/50] mean_loss : 0.29\n",
565 | "[22/50] mean_loss : 0.31\n",
566 | "[23/50] mean_loss : 0.29\n",
567 | "[23/50] mean_loss : 0.31\n",
568 | "[24/50] mean_loss : 0.24\n",
569 | "[24/50] mean_loss : 0.31\n",
570 | "[25/50] mean_loss : 0.30\n",
571 | "[25/50] mean_loss : 0.30\n",
572 | "[26/50] mean_loss : 0.14\n",
573 | "[26/50] mean_loss : 0.16\n",
574 | "[27/50] mean_loss : 0.12\n",
575 | "[27/50] mean_loss : 0.15\n",
576 | "[28/50] mean_loss : 0.18\n",
577 | "[28/50] mean_loss : 0.14\n",
578 | "[29/50] mean_loss : 0.12\n",
579 | "[29/50] mean_loss : 0.14\n",
580 | "[30/50] mean_loss : 0.14\n",
581 | "[30/50] mean_loss : 0.14\n",
582 | "[31/50] mean_loss : 0.13\n",
583 | "[31/50] mean_loss : 0.14\n",
584 | "[32/50] mean_loss : 0.11\n",
585 | "[32/50] mean_loss : 0.13\n",
586 | "[33/50] mean_loss : 0.08\n",
587 | "[33/50] mean_loss : 0.06\n",
588 | "[34/50] mean_loss : 0.01\n",
589 | "[34/50] mean_loss : 0.03\n",
590 | "[35/50] mean_loss : 0.01\n",
591 | "Early Stopping!\n"
592 | ]
593 | }
594 | ],
595 | "source": [
596 | "for epoch in range(EPOCH):\n",
597 | " losses=[]\n",
598 | " if EARLY_STOPPING: break\n",
599 | " \n",
600 | " for i,batch in enumerate(getBatch(BATCH_SIZE,train_data)):\n",
601 | " facts, fact_masks, questions, question_masks, answers = pad_to_batch(batch,word2index)\n",
602 | " \n",
603 | " model.zero_grad()\n",
604 | " pred = model(facts,fact_masks,questions,question_masks,answers.size(1),NUM_EPISODE,True)\n",
605 | " loss = loss_function(pred,answers.view(-1))\n",
606 | " losses.append(loss.data.tolist()[0])\n",
607 | " \n",
608 | " loss.backward()\n",
609 | " optimizer.step()\n",
610 | " \n",
611 | " if i % 100==0:\n",
612 | " print(\"[%d/%d] mean_loss : %0.2f\" %(epoch,EPOCH,np.mean(losses)))\n",
613 | " \n",
614 | " if np.mean(losses)<0.01:\n",
615 | " EARLY_STOPPING=True\n",
616 | " print(\"Early Stopping!\")\n",
617 | " break\n",
618 | " losses=[]"
619 | ]
620 | },
621 | {
622 | "cell_type": "markdown",
623 | "metadata": {},
624 | "source": [
625 | "## Test "
626 | ]
627 | },
628 | {
629 | "cell_type": "code",
630 | "execution_count": 21,
631 | "metadata": {
632 | "collapsed": true
633 | },
634 | "outputs": [],
635 | "source": [
636 | "def pad_to_fact(fact,x_to_ix): # this is for inference\n",
637 | " \n",
638 | " max_x = max([s.size(1) for s in fact])\n",
639 | " x_p=[]\n",
640 | " for i in range(len(fact)):\n",
641 | " if fact[i].size(1)<max_x:\n",
642 | " x_p.append(torch.cat([fact[i],Variable(LongTensor([x_to_ix['<PAD>']]*(max_x-fact[i].size(1)))).view(1,-1)],1))\n",
643 | " else:\n",
644 | " x_p.append(fact[i])\n",
645 | " \n",
646 | " fact = torch.cat(x_p)\n",
647 | " fact_mask = torch.cat([Variable(ByteTensor(tuple(map(lambda s: s ==0, t.data))),volatile=False) for t in fact]).view(fact.size(0),-1)\n",
648 | " return fact,fact_mask"
649 | ]
650 | },
651 | {
652 | "cell_type": "markdown",
653 | "metadata": {},
654 | "source": [
655 | "### Prepare Test data "
656 | ]
657 | },
658 | {
659 | "cell_type": "code",
660 | "execution_count": 27,
661 | "metadata": {
662 | "collapsed": false
663 | },
664 | "outputs": [],
665 | "source": [
666 | "test_data = bAbI_data_load('../dataset/bAbI/en-10k/qa5_three-arg-relations_test.txt')"
667 | ]
668 | },
669 | {
670 | "cell_type": "code",
671 | "execution_count": 28,
672 | "metadata": {
673 | "collapsed": true
674 | },
675 | "outputs": [],
676 | "source": [
677 | "for t in test_data:\n",
678 | " for i,fact in enumerate(t[0]):\n",
679 | " t[0][i] = prepare_sequence(fact,word2index).view(1,-1)\n",
680 | " \n",
681 | " t[1] = prepare_sequence(t[1],word2index).view(1,-1)\n",
682 | " t[2] = prepare_sequence(t[2],word2index).view(1,-1)"
683 | ]
684 | },
685 | {
686 | "cell_type": "markdown",
687 | "metadata": {},
688 | "source": [
689 | "### Accuracy "
690 | ]
691 | },
692 | {
693 | "cell_type": "code",
694 | "execution_count": 31,
695 | "metadata": {
696 | "collapsed": true
697 | },
698 | "outputs": [],
699 | "source": [
700 | "accuracy=0"
701 | ]
702 | },
703 | {
704 | "cell_type": "code",
705 | "execution_count": 32,
706 | "metadata": {
707 | "collapsed": false
708 | },
709 | "outputs": [
710 | {
711 | "name": "stdout",
712 | "output_type": "stream",
713 | "text": [
714 | "97.39999999999999\n"
715 | ]
716 | }
717 | ],
718 | "source": [
719 | "for t in test_data:\n",
720 | " fact, fact_mask = pad_to_fact(t[0],word2index)\n",
721 | " question = t[1]\n",
722 | " question_mask = Variable(ByteTensor([0]*t[1].size(1)),volatile=False).unsqueeze(0)\n",
723 | " answer = t[2].squeeze(0)\n",
724 | " \n",
725 | " model.zero_grad()\n",
726 | " pred = model([fact],[fact_mask],question,question_mask,answer.size(0),NUM_EPISODE)\n",
727 | " if pred.max(1)[1].data.tolist()==answer.data.tolist():\n",
728 | " accuracy+=1\n",
729 | "\n",
730 | "print(accuracy/len(test_data)*100)"
731 | ]
732 | },
733 | {
734 | "cell_type": "markdown",
735 | "metadata": {},
736 | "source": [
737 | "### Sample test result "
738 | ]
739 | },
740 | {
741 | "cell_type": "code",
742 | "execution_count": 34,
743 | "metadata": {
744 | "collapsed": false
745 | },
746 | "outputs": [
747 | {
748 | "name": "stdout",
749 | "output_type": "stream",
750 | "text": [
751 | "Facts : \n",
752 | "Bill went back to the bedroom \n",
753 | "Mary went to the office \n",
754 | "Jeff journeyed to the kitchen \n",
755 | "Fred journeyed to the kitchen \n",
756 | "Fred got the milk there \n",
757 | "Fred handed the milk to Jeff \n",
758 | "Jeff passed the milk to Fred \n",
759 | "Fred gave the milk to Jeff \n",
760 | "\n",
761 | "Question : Who received the milk ?\n",
762 | "\n",
763 | "Answer : Jeff \n",
764 | "Prediction : Jeff \n"
765 | ]
766 | }
767 | ],
768 | "source": [
769 | "t = random.choice(test_data)\n",
770 | "fact, fact_mask = pad_to_fact(t[0],word2index)\n",
771 | "question = t[1]\n",
772 | "question_mask = Variable(ByteTensor([0]*t[1].size(1)),volatile=False).unsqueeze(0)\n",
773 | "answer = t[2].squeeze(0)\n",
774 | "\n",
775 | "model.zero_grad()\n",
776 | "pred = model([fact],[fact_mask],question,question_mask,answer.size(0),NUM_EPISODE)\n",
777 | "\n",
778 | "print(\"Facts : \")\n",
779 | "print('\\n'.join([' '.join(list(map(lambda x : index2word[x],f))) for f in fact.data.tolist()]))\n",
780 | "print(\"\")\n",
781 | "print(\"Question : \",' '.join(list(map(lambda x:index2word[x],question.data.tolist()[0]))))\n",
782 | "print(\"\")\n",
783 | "print(\"Answer : \",' '.join(list(map(lambda x:index2word[x],answer.data.tolist()))))\n",
784 | "print(\"Prediction : \",' '.join(list(map(lambda x:index2word[x],pred.max(1)[1].data.tolist()))))"
785 | ]
786 | },
787 | {
788 | "cell_type": "markdown",
789 | "metadata": {
790 | "collapsed": true
791 | },
792 | "source": [
793 | "## Further topics "
794 | ]
795 | },
796 | {
797 | "cell_type": "markdown",
798 | "metadata": {},
799 | "source": [
800 | "* Dynamic Memory Networks for Visual and Textual Question Answering(DMN+)\n",
801 | "* DMN+ Pytorch implementation\n",
802 | "* Dynamic Coattention Networks For Question Answering\n",
803 | "* DCN+: Mixed Objective and Deep Residual Coattention for Question Answering"
804 | ]
805 | },
806 | {
807 | "cell_type": "code",
808 | "execution_count": null,
809 | "metadata": {
810 | "collapsed": true
811 | },
812 | "outputs": [],
813 | "source": []
814 | }
815 | ],
816 | "metadata": {
817 | "kernelspec": {
818 | "display_name": "Python 3",
819 | "language": "python",
820 | "name": "python3"
821 | },
822 | "language_info": {
823 | "codemirror_mode": {
824 | "name": "ipython",
825 | "version": 3
826 | },
827 | "file_extension": ".py",
828 | "mimetype": "text/x-python",
829 | "name": "python",
830 | "nbconvert_exporter": "python",
831 | "pygments_lexer": "ipython3",
832 | "version": "3.5.2"
833 | }
834 | },
835 | "nbformat": 4,
836 | "nbformat_minor": 2
837 | }
838 |
--------------------------------------------------------------------------------
/script/docker-compose.yml:
--------------------------------------------------------------------------------
1 | cs224n:
2 | image: dsksd/deepstudy:0.2
3 | ports:
4 | - "8888:8888"
5 | - "6006:6006"
6 | volumes:
7 | - ../:/notebooks/work
8 | command: /run.sh
9 | tty: true
--------------------------------------------------------------------------------
/script/prepare_dataset.sh:
--------------------------------------------------------------------------------
1 | #!/bin/sh
2 |
3 | OUT_DIR="${1:-../dataset}"
4 |
5 | mkdir -v -p "$OUT_DIR"
6 |
7 | echo "download machine translation dataset from http://www.manythings.org/anki/..."
8 | curl -o "$OUT_DIR/fra-eng.zip" "http://www.manythings.org/anki/fra-eng.zip"
9 | unzip "$OUT_DIR/fra-eng.zip" -d "$OUT_DIR"
10 | rm "$OUT_DIR/fra-eng.zip"
11 | rm "$OUT_DIR/_about.txt"
12 | mv "$OUT_DIR/fra.txt" "$OUT_DIR/eng-fra.txt"
13 |
14 | echo "download ptb dataset... (clone from https://github.com/tomsercu/lstm/tree/master/data)"
15 | mkdir -v -p "$OUT_DIR/ptb"
16 | wget "https://raw.githubusercontent.com/tomsercu/lstm/master/data/ptb.train.txt" -P "$OUT_DIR/ptb"
17 | wget "https://raw.githubusercontent.com/tomsercu/lstm/master/data/ptb.valid.txt" -P "$OUT_DIR/ptb"
18 | wget "https://raw.githubusercontent.com/tomsercu/lstm/master/data/ptb.test.txt" -P "$OUT_DIR/ptb"
19 |
20 | echo "download TREC question dataset..."
21 | curl -o "$OUT_DIR/train_5500.label.txt" "http://cogcomp.org/Data/QA/QC/train_5500.label"
22 |
23 | echo "download Stanford sentiment treebank..."
24 | curl -o "$OUT_DIR/trainDevTestTrees_PTB.zip" "https://nlp.stanford.edu/sentiment/trainDevTestTrees_PTB.zip"
25 | unzip "$OUT_DIR/trainDevTestTrees_PTB.zip" -d "$OUT_DIR"
26 | rm "$OUT_DIR/trainDevTestTrees_PTB.zip"
27 |
28 | echo "download bAbI dataset..."
29 | curl -o "$OUT_DIR/tasks_1-20_v1-2.tar.gz" "http://www.thespermwhale.com/jaseweston/babi/tasks_1-20_v1-2.tar.gz"
30 | tar zxvf "$OUT_DIR/tasks_1-20_v1-2.tar.gz"
31 | mv "tasks_1-20_v1-2" "$OUT_DIR/bAbI"
32 | rm "$OUT_DIR/tasks_1-20_v1-2.tar.gz"
33 |
34 | echo "download nltk dataset..."
35 | python3 -c "import nltk;nltk.download('gutenberg');nltk.download('brown');nltk.download('conll2002');nltk.download('timit')"
36 |
37 | echo "download dependency parser dataset... (clone from https://github.com/rguthrie3/DeepDependencyParsingProblemSet)"
38 | mkdir -v -p "$OUT_DIR/dparser"
39 | wget "https://raw.githubusercontent.com/rguthrie3/DeepDependencyParsingProblemSet/master/data/train.txt" -P "$OUT_DIR/dparser"
40 | wget "https://raw.githubusercontent.com/rguthrie3/DeepDependencyParsingProblemSet/master/data/vocab.txt" -P "$OUT_DIR/dparser"
41 | wget "https://raw.githubusercontent.com/rguthrie3/DeepDependencyParsingProblemSet/master/data/dev.txt" -P "$OUT_DIR/dparser"
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from distutils.core import setup
2 | from Cython.Build import cythonize
3 | from Cython.Distutils import build_ext
4 | from distutils.extension import Extension
5 |
6 | setup(name="Corpus",cmdclass = {'build_ext': build_ext},ext_modules = [Extension("Corpus", ["Corpus.pyx"],language='c++')])
--------------------------------------------------------------------------------