├── .gitignore
├── Corpus.pyx
├── LICENSE
├── README.md
├── images
│   ├── 01.skipgram-objective.png
│   ├── 01.skipgram-prepare-data.png
│   ├── 02.skipgram-objective.png
│   ├── 03.glove-objective.png
│   ├── 03.glove-weighting-function.png
│   ├── 04.window-classifier-architecture.png
│   ├── 04.window-data.png
│   ├── 05.neural-dparser-architecture.png
│   ├── 05.transition-based-parse.png
│   ├── 06.rnnlm-architecture.png
│   ├── 07.attention-mechanism.png
│   ├── 07.pad_to_sequence.png
│   ├── 07.seq2seq.png
│   ├── 08.cnn-for-text-architecture.png
│   ├── 09.rntn-layer.png
│   └── 10.dmn-architecture.png
├── notebooks
│   ├── 01.Skip-gram-Naive-Softmax.ipynb
│   ├── 02.Skip-gram-Negative-Sampling.ipynb
│   ├── 03.GloVe.ipynb
│   ├── 04.Window-Classifier-for-NER.ipynb
│   ├── 05.Neural-Dependancy-Parser.ipynb
│   ├── 06.RNN-Language-Model.ipynb
│   ├── 07.Neural-Machine-Translation-with-Attention.ipynb
│   ├── 08.CNN-for-Text-Classification.ipynb
│   ├── 09.Recursive-NN-for-Sentiment-Classification.ipynb
│   └── 10.Dynamic-Memory-Network-for-Question-Answering.ipynb
├── script
│   ├── docker-compose.yml
│   └── prepare_dataset.sh
└── setup.py
/.gitignore:
--------------------------------------------------------------------------------
1 | *.txt
2 | __pycache__
3 | .ipynb_checkpoints
4 | build/
5 | dataset/
6 | model/
7 | *.cpp
8 | *.so
9 | *.bin
10 | *.pkl
11 |
--------------------------------------------------------------------------------
/Corpus.pyx:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import scipy.sparse as sp
3 | from libcpp.vector cimport vector
4 | from libcpp.string cimport string
5 | from libcpp.map cimport map
6 |
7 | # C++ extension module for text preprocessing
8 |
9 | def make_dictionary(vocab):
10 |     cdef int vocab_size = len(vocab)
11 |     cdef int i
12 |     dictionary={}
13 |     inv_dictionary={}
14 |     for i in range(vocab_size):
15 |         dictionary[vocab[i]] = i
16 |         inv_dictionary[i]=vocab[i]
17 |     return dictionary,inv_dictionary
18 |
19 |
20 | def make_window_data(sents,window_size):
21 |     pass
22 |
23 |
24 | def make_coo_matrix(corpus,dictionary):
25 |     cdef int matrix_size = len(dictionary)
26 |     pass
27 |
28 |
29 | def getBatch_FromBucket(batch_size,buckets):
30 |     i=0
31 |     bucket_mask =[False for _ in range(len(buckets))]
32 |     indices = [[0,batch_size] for _ in range(len(buckets))]
33 |     is_done=False
34 |     while not is_done:
35 |         batch = buckets[i][indices[i][0]:indices[i][1]]
36 |         temp = indices[i][1]
37 |         indices[i][1]= indices[i][1]+batch_size
38 |         indices[i][0] = temp
39 |
40 |         i = (i+1)%len(buckets)
41 |         while bucket_mask[i]:
42 |             i = (i+1)%len(buckets)
43 |
44 |         if indices[i][1]>len(buckets[i]):
45 |             bucket_mask[i]= True
46 |             if bucket_mask.count(True)==len(buckets):
47 |                 is_done=True
48 |             else:
49 |                 while bucket_mask[i]:
50 |                     i = (i+1)%len(buckets)
51 |         yield batch
--------------------------------------------------------------------------------
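Note on the stubs in `Corpus.pyx`: `make_window_data` and `make_coo_matrix` are left as `pass` in the source, so their intended behavior is not specified. The following is a minimal pure-Python sketch of what such helpers could look like, assuming that window data means (center, context) pairs and that the matrix holds word co-occurrence counts. The `window_size` default on `make_coo_matrix` and the returned triplet layout are assumptions for illustration, not the author's implementation.

```python
from collections import Counter


def make_window_data(sents, window_size):
    """Return (center, context) pairs for every token in every sentence.

    Assumption: the context of a token is every other token at most
    `window_size` positions away within the same sentence.
    """
    pairs = []
    for sent in sents:
        for i, center in enumerate(sent):
            lo = max(0, i - window_size)
            hi = min(len(sent), i + window_size + 1)
            for j in range(lo, hi):
                if j != i:
                    pairs.append((center, sent[j]))
    return pairs


def make_coo_matrix(corpus, dictionary, window_size=1):
    """Return co-occurrence counts as COO triplets (rows, cols, values).

    `window_size` is a hypothetical extra parameter; the original stub
    only takes `corpus` and `dictionary`. Words missing from
    `dictionary` are skipped.
    """
    counts = Counter()
    for center, context in make_window_data(corpus, window_size):
        if center in dictionary and context in dictionary:
            counts[(dictionary[center], dictionary[context])] += 1
    entries = sorted(counts.items())  # deterministic (row, col) order
    rows = [key[0] for key, _ in entries]
    cols = [key[1] for key, _ in entries]
    values = [value for _, value in entries]
    return rows, cols, values
```

If SciPy is available, the triplets can be turned into a sparse matrix with `scipy.sparse.coo_matrix((values, (rows, cols)), shape=(n, n))`, where `n = len(dictionary)`.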
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2017 SungDong Kim
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # DeepNLP-models-Pytorch
2 |
3 | PyTorch implementations of various deep NLP models from cs-224n (Stanford University: NLP with Deep Learning).
4 |
5 | - This is not for PyTorch beginners. If this is your first time using PyTorch, I recommend these [awesome tutorials](#references).
6 | - If you're interested in deep NLP, I strongly recommend working through this awesome lecture series.
7 |
8 | * cs-224n-slides
9 | * cs-224n-videos
10 |
11 | This material is not perfect, but it should help your study and research :) Please feel free to open pull requests!
12 |
13 |
14 |
15 | ## Contents
16 |
17 | | Model | Links |
18 | | ------------- |:-------------:|
19 | | 01. Skip-gram-Naive-Softmax | [notebook / data / paper] |
20 | | 02. Skip-gram-Negative-Sampling | [notebook / data / paper] |
21 | | 03. GloVe | [notebook / data / paper] |
22 | | 04. Window-Classifier-for-NER | [notebook / data / paper] |
23 | | 05. Neural-Dependancy-Parser | [notebook / data / paper] |
24 | | 06. RNN-Language-Model | [notebook / data / paper] |
25 | | 07. Neural-Machine-Translation-with-Attention | [notebook / data / paper] |
26 | | 08. CNN-for-Text-Classification | [notebook / data / paper] |
27 | | 09. Recursive-NN-for-Sentiment-Classification | [notebook / data / paper] |
28 | | 10. Dynamic-Memory-Network-for-Question-Answering | [notebook / data / paper] |
29 |
30 |
31 | ## Requirements
32 |
33 | - Python 3.5
34 | - Pytorch 0.2+
35 | - nltk 3.2.2
36 | - gensim 2.2.0
37 | - sklearn_crfsuite
38 |
39 |
40 | ## Getting started
41 |
42 | `git clone https://github.com/DSKSD/cs-224n-Pytorch.git`
43 |
44 | ### prepare dataset
45 |
46 | ````
47 | cd script
48 | chmod u+x prepare_dataset.sh
49 | ./prepare_dataset.sh
50 | ````
51 |
52 | ### docker env
53 | Ubuntu 16.04 and Python 3.5.2 with various ML/DL packages, including tensorflow, sklearn, and pytorch.
54 |
55 | `docker pull dsksd/deepstudy:0.2`
56 |
57 | ````
58 | pip3 install docker-compose
59 | cd script
60 | docker-compose up -d
61 | ````
62 |
63 | ### cloud setting
64 |
65 | `not yet`
66 |
67 | ## References
68 |
69 | * practical-pytorch
70 | * DeepLearningForNLPInPytorch
71 | * pytorch-tutorial
72 | * pytorch-examples
73 |
74 | ## Author
75 |
76 | Sungdong Kim / @DSKSD
--------------------------------------------------------------------------------
/images/01.skipgram-objective.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DSKSD/DeepNLP-models-Pytorch/7fec64d72615933e8f4ea499c2dbaa42508f4017/images/01.skipgram-objective.png
--------------------------------------------------------------------------------
/images/01.skipgram-prepare-data.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DSKSD/DeepNLP-models-Pytorch/7fec64d72615933e8f4ea499c2dbaa42508f4017/images/01.skipgram-prepare-data.png
--------------------------------------------------------------------------------
/images/02.skipgram-objective.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DSKSD/DeepNLP-models-Pytorch/7fec64d72615933e8f4ea499c2dbaa42508f4017/images/02.skipgram-objective.png
--------------------------------------------------------------------------------
/images/03.glove-objective.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DSKSD/DeepNLP-models-Pytorch/7fec64d72615933e8f4ea499c2dbaa42508f4017/images/03.glove-objective.png
--------------------------------------------------------------------------------
/images/03.glove-weighting-function.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DSKSD/DeepNLP-models-Pytorch/7fec64d72615933e8f4ea499c2dbaa42508f4017/images/03.glove-weighting-function.png
--------------------------------------------------------------------------------
/images/04.window-classifier-architecture.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DSKSD/DeepNLP-models-Pytorch/7fec64d72615933e8f4ea499c2dbaa42508f4017/images/04.window-classifier-architecture.png
--------------------------------------------------------------------------------
/images/04.window-data.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DSKSD/DeepNLP-models-Pytorch/7fec64d72615933e8f4ea499c2dbaa42508f4017/images/04.window-data.png
--------------------------------------------------------------------------------
/images/05.neural-dparser-architecture.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DSKSD/DeepNLP-models-Pytorch/7fec64d72615933e8f4ea499c2dbaa42508f4017/images/05.neural-dparser-architecture.png
--------------------------------------------------------------------------------
/images/05.transition-based-parse.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DSKSD/DeepNLP-models-Pytorch/7fec64d72615933e8f4ea499c2dbaa42508f4017/images/05.transition-based-parse.png
--------------------------------------------------------------------------------
/images/06.rnnlm-architecture.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DSKSD/DeepNLP-models-Pytorch/7fec64d72615933e8f4ea499c2dbaa42508f4017/images/06.rnnlm-architecture.png
--------------------------------------------------------------------------------
/images/07.attention-mechanism.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DSKSD/DeepNLP-models-Pytorch/7fec64d72615933e8f4ea499c2dbaa42508f4017/images/07.attention-mechanism.png
--------------------------------------------------------------------------------
/images/07.pad_to_sequence.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DSKSD/DeepNLP-models-Pytorch/7fec64d72615933e8f4ea499c2dbaa42508f4017/images/07.pad_to_sequence.png
--------------------------------------------------------------------------------
/images/07.seq2seq.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DSKSD/DeepNLP-models-Pytorch/7fec64d72615933e8f4ea499c2dbaa42508f4017/images/07.seq2seq.png
--------------------------------------------------------------------------------
/images/08.cnn-for-text-architecture.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DSKSD/DeepNLP-models-Pytorch/7fec64d72615933e8f4ea499c2dbaa42508f4017/images/08.cnn-for-text-architecture.png
--------------------------------------------------------------------------------
/images/09.rntn-layer.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DSKSD/DeepNLP-models-Pytorch/7fec64d72615933e8f4ea499c2dbaa42508f4017/images/09.rntn-layer.png
--------------------------------------------------------------------------------
/images/10.dmn-architecture.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DSKSD/DeepNLP-models-Pytorch/7fec64d72615933e8f4ea499c2dbaa42508f4017/images/10.dmn-architecture.png
--------------------------------------------------------------------------------
/notebooks/01.Skip-gram-Naive-Softmax.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "# 1. Skip-gram with naive softmax "
8 | ]
9 | },
10 | {
11 | "cell_type": "markdown",
12 | "metadata": {},
13 | "source": [
14 | "I recommend you take a look at these materials first."
15 | ]
16 | },
17 | {
18 | "cell_type": "markdown",
19 | "metadata": {},
20 | "source": [
21 | "* http://web.stanford.edu/class/cs224n/lectures/cs224n-2017-lecture2.pdf\n",
22 | "* https://arxiv.org/abs/1301.3781\n",
23 | "* http://mccormickml.com/2016/04/19/word2vec-tutorial-the-skip-gram-model/"
24 | ]
25 | },
26 | {
27 | "cell_type": "code",
28 | "execution_count": 1,
29 | "metadata": {
30 | "collapsed": true
31 | },
32 | "outputs": [],
33 | "source": [
34 | "import torch\n",
35 | "import torch.nn as nn\n",
36 | "from torch.autograd import Variable\n",
37 | "import torch.optim as optim\n",
38 | "import torch.nn.functional as F\n",
39 | "import nltk\n",
40 | "import random\n",
41 | "import numpy as np\n",
42 | "from collections import Counter\n",
43 | "flatten = lambda l: [item for sublist in l for item in sublist]\n",
44 | "random.seed(1024)"
45 | ]
46 | },
47 | {
48 | "cell_type": "code",
49 | "execution_count": 2,
50 | "metadata": {
51 | "collapsed": false
52 | },
53 | "outputs": [
54 | {
55 | "name": "stdout",
56 | "output_type": "stream",
57 | "text": [
58 | "0.3.0.post4\n",
59 | "3.2.4\n"
60 | ]
61 | }
62 | ],
63 | "source": [
64 | "print(torch.__version__)\n",
65 | "print(nltk.__version__)"
66 | ]
67 | },
68 | {
69 | "cell_type": "code",
70 | "execution_count": 3,
71 | "metadata": {
72 | "collapsed": true
73 | },
74 | "outputs": [],
75 | "source": [
76 | "USE_CUDA = torch.cuda.is_available()\n",
77 | "gpus = [0]\n",
78 | "if USE_CUDA: torch.cuda.set_device(gpus[0])\n",
79 | "\n",
80 | "FloatTensor = torch.cuda.FloatTensor if USE_CUDA else torch.FloatTensor\n",
81 | "LongTensor = torch.cuda.LongTensor if USE_CUDA else torch.LongTensor\n",
82 | "ByteTensor = torch.cuda.ByteTensor if USE_CUDA else torch.ByteTensor"
83 | ]
84 | },
85 | {
86 | "cell_type": "code",
87 | "execution_count": 4,
88 | "metadata": {
89 | "collapsed": true
90 | },
91 | "outputs": [],
92 | "source": [
93 | "def getBatch(batch_size, train_data):\n",
94 | "    random.shuffle(train_data)\n",
95 | "    sindex = 0\n",
96 | "    eindex = batch_size\n",
97 | "    while eindex < len(train_data):\n",
98 | "        batch = train_data[sindex: eindex]\n",
99 | "        temp = eindex\n",
100 | "        eindex = eindex + batch_size\n",
101 | "        sindex = temp\n",
102 | "        yield batch\n",
103 | "    \n",
104 | "    if eindex >= len(train_data):\n",
105 | "        batch = train_data[sindex:]\n",
106 | "        yield batch"
107 | ]
108 | },
109 | {
110 | "cell_type": "code",
111 | "execution_count": 5,
112 | "metadata": {
113 | "collapsed": true
114 | },
115 | "outputs": [],
116 | "source": [
117 | "def prepare_sequence(seq, word2index):\n",
118 | "    idxs = list(map(lambda w: word2index[w] if word2index.get(w) is not None else word2index[\"\"], seq))\n",
119 | "    return Variable(LongTensor(idxs))\n",
120 | "\n",
121 | "def prepare_word(word, word2index):\n",
122 | "    return Variable(LongTensor([word2index[word]]) if word2index.get(word) is not None else LongTensor([word2index[\"\"]]))"
123 | ]
124 | },
125 | {
126 | "cell_type": "markdown",
127 | "metadata": {},
128 | "source": [
129 | "## Data load and Preprocessing "
130 | ]
131 | },
132 | {
133 | "cell_type": "markdown",
134 | "metadata": {},
135 | "source": [
136 | "### Load corpus : Gutenberg corpus"
137 | ]
138 | },
139 | {
140 | "cell_type": "markdown",
141 | "metadata": {},
142 | "source": [
143 | "If you don't have the Gutenberg corpus, download it first with nltk.download('gutenberg')."
144 | ]
145 | },
146 | {
147 | "cell_type": "code",
148 | "execution_count": 6,
149 | "metadata": {
150 | "collapsed": false
151 | },
152 | "outputs": [
153 | {
154 | "data": {
155 | "text/plain": [
156 | "['austen-emma.txt',\n",
157 | " 'austen-persuasion.txt',\n",
158 | " 'austen-sense.txt',\n",
159 | " 'bible-kjv.txt',\n",
160 | " 'blake-poems.txt',\n",
161 | " 'bryant-stories.txt',\n",
162 | " 'burgess-busterbrown.txt',\n",
163 | " 'carroll-alice.txt',\n",
164 | " 'chesterton-ball.txt',\n",
165 | " 'chesterton-brown.txt',\n",
166 | " 'chesterton-thursday.txt',\n",
167 | " 'edgeworth-parents.txt',\n",
168 | " 'melville-moby_dick.txt',\n",
169 | " 'milton-paradise.txt',\n",
170 | " 'shakespeare-caesar.txt',\n",
171 | " 'shakespeare-hamlet.txt',\n",
172 | " 'shakespeare-macbeth.txt',\n",
173 | " 'whitman-leaves.txt']"
174 | ]
175 | },
176 | "execution_count": 6,
177 | "metadata": {},
178 | "output_type": "execute_result"
179 | }
180 | ],
181 | "source": [
182 | "nltk.corpus.gutenberg.fileids()"
183 | ]
184 | },
185 | {
186 | "cell_type": "code",
187 | "execution_count": 7,
188 | "metadata": {
189 | "collapsed": true
190 | },
191 | "outputs": [],
192 | "source": [
193 | "corpus = list(nltk.corpus.gutenberg.sents('melville-moby_dick.txt'))[:100] # sampling sentences for test\n",
194 | "corpus = [[word.lower() for word in sent] for sent in corpus]"
195 | ]
196 | },
197 | {
198 | "cell_type": "markdown",
199 | "metadata": {},
200 | "source": [
201 | "### Extract Stopwords from unigram distribution's tails"
202 | ]
203 | },
204 | {
205 | "cell_type": "code",
206 | "execution_count": 8,
207 | "metadata": {
208 | "collapsed": true
209 | },
210 | "outputs": [],
211 | "source": [
212 | "word_count = Counter(flatten(corpus))\n",
213 | "border = int(len(word_count) * 0.01) "
214 | ]
215 | },
216 | {
217 | "cell_type": "code",
218 | "execution_count": 9,
219 | "metadata": {
220 | "collapsed": true
221 | },
222 | "outputs": [],
223 | "source": [
224 | "stopwords = word_count.most_common()[:border] + list(reversed(word_count.most_common()))[:border]"
225 | ]
226 | },
227 | {
228 | "cell_type": "code",
229 | "execution_count": 10,
230 | "metadata": {
231 | "collapsed": true
232 | },
233 | "outputs": [],
234 | "source": [
235 | "stopwords = [s[0] for s in stopwords]"
236 | ]
237 | },
238 | {
239 | "cell_type": "code",
240 | "execution_count": 11,
241 | "metadata": {
242 | "collapsed": false
243 | },
244 | "outputs": [
245 | {
246 | "data": {
247 | "text/plain": [
248 | "[',', '.', 'the', 'of', 'and', 'baleine', '--(', 'fat', 'oil', 'boiling']"
249 | ]
250 | },
251 | "execution_count": 11,
252 | "metadata": {},
253 | "output_type": "execute_result"
254 | }
255 | ],
256 | "source": [
257 | "stopwords"
258 | ]
259 | },
260 | {
261 | "cell_type": "markdown",
262 | "metadata": {},
263 | "source": [
264 | "### Build vocab"
265 | ]
266 | },
267 | {
268 | "cell_type": "code",
269 | "execution_count": 12,
270 | "metadata": {
271 | "collapsed": true
272 | },
273 | "outputs": [],
274 | "source": [
275 | "vocab = list(set(flatten(corpus)) - set(stopwords))\n",
276 | "vocab.append('')"
277 | ]
278 | },
279 | {
280 | "cell_type": "code",
281 | "execution_count": 13,
282 | "metadata": {
283 | "collapsed": false
284 | },
285 | "outputs": [
286 | {
287 | "name": "stdout",
288 | "output_type": "stream",
289 | "text": [
290 | "592 583\n"
291 | ]
292 | }
293 | ],
294 | "source": [
295 | "print(len(set(flatten(corpus))), len(vocab))"
296 | ]
297 | },
298 | {
299 | "cell_type": "code",
300 | "execution_count": 14,
301 | "metadata": {
302 | "collapsed": true
303 | },
304 | "outputs": [],
305 | "source": [
306 | "word2index = {'' : 0} \n",
307 | "\n",
308 | "for vo in vocab:\n",
309 | "    if word2index.get(vo) is None:\n",
310 | "        word2index[vo] = len(word2index)\n",
311 | "\n",
312 | "index2word = {v:k for k, v in word2index.items()} "
313 | ]
314 | },
315 | {
316 | "cell_type": "markdown",
317 | "metadata": {},
318 | "source": [
319 | "### Prepare train data "
320 | ]
321 | },
322 | {
323 | "cell_type": "markdown",
324 | "metadata": {},
325 | "source": [
326 | "window data example"
327 | ]
328 | },
329 | {
330 | "cell_type": "markdown",
331 | "metadata": {},
332 | "source": [
333 | "![window data example](../images/01.skipgram-prepare-data.png)\n",
334 | "borrowed image from http://mccormickml.com/2016/04/19/word2vec-tutorial-the-skip-gram-model/"
335 | ]
336 | },
337 | {
338 | "cell_type": "code",
339 | "execution_count": 15,
340 | "metadata": {
341 | "collapsed": true
342 | },
343 | "outputs": [],
344 | "source": [
345 | "WINDOW_SIZE = 3\n",
346 | "windows = flatten([list(nltk.ngrams([''] * WINDOW_SIZE + c + [''] * WINDOW_SIZE, WINDOW_SIZE * 2 + 1)) for c in corpus])"
347 | ]
348 | },
349 | {
350 | "cell_type": "code",
351 | "execution_count": 16,
352 | "metadata": {
353 | "collapsed": false
354 | },
355 | "outputs": [
356 | {
357 | "data": {
358 | "text/plain": [
359 | "('', '', '', '[', 'moby', 'dick', 'by')"
360 | ]
361 | },
362 | "execution_count": 16,
363 | "metadata": {},
364 | "output_type": "execute_result"
365 | }
366 | ],
367 | "source": [
368 | "windows[0]"
369 | ]
370 | },
371 | {
372 | "cell_type": "code",
373 | "execution_count": 17,
374 | "metadata": {
375 | "collapsed": false
376 | },
377 | "outputs": [
378 | {
379 | "name": "stdout",
380 | "output_type": "stream",
381 | "text": [
382 | "[('[', 'moby'), ('[', 'dick'), ('[', 'by'), ('moby', '['), ('moby', 'dick'), ('moby', 'by')]\n"
383 | ]
384 | }
385 | ],
386 | "source": [
387 | "train_data = []\n",
388 | "\n",
389 | "for window in windows:\n",
390 | "    for i in range(WINDOW_SIZE * 2 + 1):\n",
391 | "        if i == WINDOW_SIZE or window[i] == '': \n",
392 | "            continue\n",
393 | "        train_data.append((window[WINDOW_SIZE], window[i]))\n",
394 | "\n",
395 | "print(train_data[:WINDOW_SIZE * 2])"
396 | ]
397 | },
398 | {
399 | "cell_type": "code",
400 | "execution_count": 18,
401 | "metadata": {
402 | "collapsed": true
403 | },
404 | "outputs": [],
405 | "source": [
406 | "X_p = []\n",
407 | "y_p = []"
408 | ]
409 | },
410 | {
411 | "cell_type": "code",
412 | "execution_count": 19,
413 | "metadata": {
414 | "collapsed": false
415 | },
416 | "outputs": [
417 | {
418 | "data": {
419 | "text/plain": [
420 | "('[', 'moby')"
421 | ]
422 | },
423 | "execution_count": 19,
424 | "metadata": {},
425 | "output_type": "execute_result"
426 | }
427 | ],
428 | "source": [
429 | "train_data[0]"
430 | ]
431 | },
432 | {
433 | "cell_type": "code",
434 | "execution_count": 20,
435 | "metadata": {
436 | "collapsed": true
437 | },
438 | "outputs": [],
439 | "source": [
440 | "for tr in train_data:\n",
441 | "    X_p.append(prepare_word(tr[0], word2index).view(1, -1))\n",
442 | "    y_p.append(prepare_word(tr[1], word2index).view(1, -1))"
443 | ]
444 | },
445 | {
446 | "cell_type": "code",
447 | "execution_count": 21,
448 | "metadata": {
449 | "collapsed": true
450 | },
451 | "outputs": [],
452 | "source": [
453 | "train_data = list(zip(X_p, y_p))"
454 | ]
455 | },
456 | {
457 | "cell_type": "code",
458 | "execution_count": 22,
459 | "metadata": {
460 | "collapsed": false
461 | },
462 | "outputs": [
463 | {
464 | "data": {
465 | "text/plain": [
466 | "7606"
467 | ]
468 | },
469 | "execution_count": 22,
470 | "metadata": {},
471 | "output_type": "execute_result"
472 | }
473 | ],
474 | "source": [
475 | "len(train_data)"
476 | ]
477 | },
478 | {
479 | "cell_type": "markdown",
480 | "metadata": {},
481 | "source": [
482 | "## Modeling"
483 | ]
484 | },
485 | {
486 | "cell_type": "markdown",
487 | "metadata": {},
488 | "source": [
489 | "![skip-gram objective](../images/01.skipgram-objective.png)\n",
490 | "borrowed image from http://web.stanford.edu/class/cs224n/lectures/cs224n-2017-lecture2.pdf"
491 | ]
492 | },
493 | {
494 | "cell_type": "code",
495 | "execution_count": 59,
496 | "metadata": {
497 | "collapsed": true
498 | },
499 | "outputs": [],
500 | "source": [
501 | "class Skipgram(nn.Module):\n",
502 | "    \n",
503 | "    def __init__(self, vocab_size, projection_dim):\n",
504 | "        super(Skipgram,self).__init__()\n",
505 | "        self.embedding_v = nn.Embedding(vocab_size, projection_dim)\n",
506 | "        self.embedding_u = nn.Embedding(vocab_size, projection_dim)\n",
507 | "\n",
508 | "        self.embedding_v.weight.data.uniform_(-1, 1) # init\n",
509 | "        self.embedding_u.weight.data.uniform_(0, 0) # init\n",
510 | "        #self.out = nn.Linear(projection_dim,vocab_size)\n",
511 | "    def forward(self, center_words,target_words, outer_words):\n",
512 | "        center_embeds = self.embedding_v(center_words) # B x 1 x D\n",
513 | "        target_embeds = self.embedding_u(target_words) # B x 1 x D\n",
514 | "        outer_embeds = self.embedding_u(outer_words) # B x V x D\n",
515 | "        \n",
516 | "        scores = target_embeds.bmm(center_embeds.transpose(1, 2)).squeeze(2) # Bx1xD * BxDx1 => Bx1\n",
517 | "        norm_scores = outer_embeds.bmm(center_embeds.transpose(1, 2)).squeeze(2) # BxVxD * BxDx1 => BxV\n",
518 | "        \n",
519 | "        nll = -torch.mean(torch.log(torch.exp(scores)/torch.sum(torch.exp(norm_scores), 1).unsqueeze(1))) # log-softmax\n",
520 | "        \n",
521 | "        return nll # negative log likelihood\n",
522 | "    \n",
523 | "    def prediction(self, inputs):\n",
524 | "        embeds = self.embedding_v(inputs)\n",
525 | "        \n",
526 | "        return embeds "
527 | ]
528 | },
529 | {
530 | "cell_type": "markdown",
531 | "metadata": {},
532 | "source": [
533 | "## Train "
534 | ]
535 | },
536 | {
537 | "cell_type": "code",
538 | "execution_count": 60,
539 | "metadata": {
540 | "collapsed": true
541 | },
542 | "outputs": [],
543 | "source": [
544 | "EMBEDDING_SIZE = 30\n",
545 | "BATCH_SIZE = 256\n",
546 | "EPOCH = 100"
547 | ]
548 | },
549 | {
550 | "cell_type": "code",
551 | "execution_count": 61,
552 | "metadata": {
553 | "collapsed": true
554 | },
555 | "outputs": [],
556 | "source": [
557 | "losses = []\n",
558 | "model = Skipgram(len(word2index), EMBEDDING_SIZE)\n",
559 | "if USE_CUDA:\n",
560 | " model = model.cuda()\n",
561 | "optimizer = optim.Adam(model.parameters(), lr=0.01)"
562 | ]
563 | },
564 | {
565 | "cell_type": "code",
566 | "execution_count": 62,
567 | "metadata": {
568 | "collapsed": false
569 | },
570 | "outputs": [
571 | {
572 | "name": "stdout",
573 | "output_type": "stream",
574 | "text": [
575 | "Epoch : 0, mean_loss : 6.20\n",
576 | "Epoch : 10, mean_loss : 4.38\n",
577 | "Epoch : 20, mean_loss : 3.48\n",
578 | "Epoch : 30, mean_loss : 3.31\n",
579 | "Epoch : 40, mean_loss : 3.26\n",
580 | "Epoch : 50, mean_loss : 3.24\n",
581 | "Epoch : 60, mean_loss : 3.22\n",
582 | "Epoch : 70, mean_loss : 3.22\n",
583 | "Epoch : 80, mean_loss : 3.21\n",
584 | "Epoch : 90, mean_loss : 3.20\n"
585 | ]
586 | }
587 | ],
588 | "source": [
589 | "for epoch in range(EPOCH):\n",
590 | "    for i, batch in enumerate(getBatch(BATCH_SIZE, train_data)):\n",
591 | "        \n",
592 | "        inputs, targets = zip(*batch)\n",
593 | "        \n",
594 | "        inputs = torch.cat(inputs) # B x 1\n",
595 | "        targets = torch.cat(targets) # B x 1\n",
596 | "        vocabs = prepare_sequence(list(vocab), word2index).expand(inputs.size(0), len(vocab)) # B x V\n",
597 | "        model.zero_grad()\n",
598 | "\n",
599 | "        loss = model(inputs, targets, vocabs)\n",
600 | "        \n",
601 | "        loss.backward()\n",
602 | "        optimizer.step()\n",
603 | "        \n",
604 | "        losses.append(loss.data.tolist()[0])\n",
605 | "\n",
606 | "    if epoch % 10 == 0:\n",
607 | "        print(\"Epoch : %d, mean_loss : %.02f\" % (epoch,np.mean(losses)))\n",
608 | "        losses = []"
609 | ]
610 | },
611 | {
612 | "cell_type": "markdown",
613 | "metadata": {},
614 | "source": [
615 | "## Test"
616 | ]
617 | },
618 | {
619 | "cell_type": "code",
620 | "execution_count": 63,
621 | "metadata": {
622 | "collapsed": true
623 | },
624 | "outputs": [],
625 | "source": [
626 | "def word_similarity(target, vocab):\n",
627 | "    if USE_CUDA:\n",
628 | "        target_V = model.prediction(prepare_word(target, word2index))\n",
629 | "    else:\n",
630 | "        target_V = model.prediction(prepare_word(target, word2index))\n",
631 | "    similarities = []\n",
632 | "    for i in range(len(vocab)):\n",
633 | "        if vocab[i] == target: continue\n",
634 | "        \n",
635 | "        if USE_CUDA:\n",
636 | "            vector = model.prediction(prepare_word(list(vocab)[i], word2index))\n",
637 | "        else:\n",
638 | "            vector = model.prediction(prepare_word(list(vocab)[i], word2index))\n",
639 | "        cosine_sim = F.cosine_similarity(target_V, vector).data.tolist()[0] \n",
640 | "        similarities.append([vocab[i], cosine_sim])\n",
641 | "    return sorted(similarities, key=lambda x: x[1], reverse=True)[:10] # sort by similarity"
642 | ]
643 | },
644 | {
645 | "cell_type": "code",
646 | "execution_count": 64,
647 | "metadata": {
648 | "collapsed": false
649 | },
650 | "outputs": [
651 | {
652 | "data": {
653 | "text/plain": [
654 | "'least'"
655 | ]
656 | },
657 | "execution_count": 64,
658 | "metadata": {},
659 | "output_type": "execute_result"
660 | }
661 | ],
662 | "source": [
663 | "test = random.choice(list(vocab))\n",
664 | "test"
665 | ]
666 | },
667 | {
668 | "cell_type": "code",
669 | "execution_count": 65,
670 | "metadata": {
671 | "collapsed": false
672 | },
673 | "outputs": [
674 | {
675 | "data": {
676 | "text/plain": [
677 | "[['at', 0.8147411346435547],\n",
678 | " ['every', 0.7143548130989075],\n",
679 | " ['case', 0.6975079774856567],\n",
680 | " ['secure', 0.6121522188186646],\n",
681 | " ['heart', 0.5974172949790955],\n",
682 | " ['including', 0.5867112278938293],\n",
683 | " ['please', 0.5557640194892883],\n",
684 | " ['has', 0.5536234974861145],\n",
685 | " ['while', 0.5366998314857483],\n",
686 | " ['you', 0.509368896484375]]"
687 | ]
688 | },
689 | "execution_count": 65,
690 | "metadata": {},
691 | "output_type": "execute_result"
692 | }
693 | ],
694 | "source": [
695 | "word_similarity(test, vocab)"
696 | ]
697 | },
698 | {
699 | "cell_type": "code",
700 | "execution_count": null,
701 | "metadata": {
702 | "collapsed": true
703 | },
704 | "outputs": [],
705 | "source": []
706 | }
707 | ],
708 | "metadata": {
709 | "kernelspec": {
710 | "display_name": "Python 3",
711 | "language": "python",
712 | "name": "python3"
713 | },
714 | "language_info": {
715 | "codemirror_mode": {
716 | "name": "ipython",
717 | "version": 3
718 | },
719 | "file_extension": ".py",
720 | "mimetype": "text/x-python",
721 | "name": "python",
722 | "nbconvert_exporter": "python",
723 | "pygments_lexer": "ipython3",
724 | "version": "3.5.2"
725 | }
726 | },
727 | "nbformat": 4,
728 | "nbformat_minor": 2
729 | }
730 |
--------------------------------------------------------------------------------
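Note on the loss in `Skipgram.forward`: the notebook computes the negative log likelihood as `-log(exp(score) / sum(exp(norm_scores)))`, which exponentiates raw scores and can overflow when scores grow large. The sketch below shows the same quantity computed with the log-sum-exp trick. It is written in plain Python for illustration only; the function names `naive_softmax_nll` and `mean_nll` are hypothetical and do not appear in the notebook, and this is not a drop-in replacement for the PyTorch code.

```python
import math


def naive_softmax_nll(target_score, all_scores):
    """Return -log(exp(target_score) / sum_j exp(all_scores[j])).

    Subtracting the maximum score before exponentiating keeps every
    exponent at or below zero, so math.exp cannot overflow.
    """
    m = max(all_scores)
    log_norm = m + math.log(sum(math.exp(s - m) for s in all_scores))
    return log_norm - target_score


def mean_nll(target_scores, score_rows):
    """Average the per-example loss over a batch.

    target_scores[b] is the score of the true context word for example b;
    score_rows[b] holds the scores of all vocabulary words for example b.
    """
    losses = [naive_softmax_nll(t, row) for t, row in zip(target_scores, score_rows)]
    return sum(losses) / len(losses)
```

With two equally scored words, `naive_softmax_nll(0.0, [0.0, 0.0])` equals `log(2)`, and the function still returns a finite value for scores around 1000, where `math.exp(1000.0)` alone would raise `OverflowError`.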
/notebooks/02.Skip-gram-Negative-Sampling.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "# 2. Skip-gram with negative sampling"
8 | ]
9 | },
10 | {
11 | "cell_type": "markdown",
12 | "metadata": {},
13 | "source": [
14 | "I recommend you take a look at these materials first."
15 | ]
16 | },
17 | {
18 | "cell_type": "markdown",
19 | "metadata": {},
20 | "source": [
21 | "* http://web.stanford.edu/class/cs224n/lectures/cs224n-2017-lecture3.pdf\n",
22 | "* http://papers.nips.cc/paper/5021-distributed-representations-of-words-and-phrases-and-their-compositionality.pdf"
23 | ]
24 | },
25 | {
26 | "cell_type": "code",
27 | "execution_count": 1,
28 | "metadata": {
29 | "collapsed": true
30 | },
31 | "outputs": [],
32 | "source": [
33 | "import torch\n",
34 | "import torch.nn as nn\n",
35 | "from torch.autograd import Variable\n",
36 | "import torch.optim as optim\n",
37 | "import torch.nn.functional as F\n",
38 | "import nltk\n",
39 | "import random\n",
40 | "import numpy as np\n",
41 | "from collections import Counter\n",
42 | "flatten = lambda l: [item for sublist in l for item in sublist]\n",
43 | "random.seed(1024)"
44 | ]
45 | },
46 | {
47 | "cell_type": "code",
48 | "execution_count": 2,
49 | "metadata": {
50 | "collapsed": false
51 | },
52 | "outputs": [
53 | {
54 | "name": "stdout",
55 | "output_type": "stream",
56 | "text": [
57 | "0.3.0.post4\n",
58 | "3.2.4\n"
59 | ]
60 | }
61 | ],
62 | "source": [
63 | "print(torch.__version__)\n",
64 | "print(nltk.__version__)"
65 | ]
66 | },
67 | {
68 | "cell_type": "code",
69 | "execution_count": 3,
70 | "metadata": {
71 | "collapsed": true
72 | },
73 | "outputs": [],
74 | "source": [
75 | "USE_CUDA = torch.cuda.is_available()\n",
76 | "gpus = [0]\n",
77 | "if USE_CUDA: torch.cuda.set_device(gpus[0])\n",
78 | "\n",
79 | "FloatTensor = torch.cuda.FloatTensor if USE_CUDA else torch.FloatTensor\n",
80 | "LongTensor = torch.cuda.LongTensor if USE_CUDA else torch.LongTensor\n",
81 | "ByteTensor = torch.cuda.ByteTensor if USE_CUDA else torch.ByteTensor"
82 | ]
83 | },
84 | {
85 | "cell_type": "code",
86 | "execution_count": 4,
87 | "metadata": {
88 | "collapsed": true
89 | },
90 | "outputs": [],
91 | "source": [
92 | "def getBatch(batch_size, train_data):\n",
93 | "    random.shuffle(train_data)\n",
94 | "    sindex = 0\n",
95 | "    eindex = batch_size\n",
96 | "    while eindex < len(train_data):\n",
97 | "        batch = train_data[sindex: eindex]\n",
98 | "        temp = eindex\n",
99 | "        eindex = eindex + batch_size\n",
100 | "        sindex = temp\n",
101 | "        yield batch\n",
102 | "    \n",
103 | "    if eindex >= len(train_data):\n",
104 | "        batch = train_data[sindex:]\n",
105 | "        yield batch"
106 | ]
107 | },
108 | {
109 | "cell_type": "code",
110 | "execution_count": 5,
111 | "metadata": {
112 | "collapsed": true
113 | },
114 | "outputs": [],
115 | "source": [
116 | "def prepare_sequence(seq, word2index):\n",
117 | "    idxs = list(map(lambda w: word2index[w] if word2index.get(w) is not None else word2index[\"\"], seq))\n",
118 | "    return Variable(LongTensor(idxs))\n",
119 | "\n",
120 | "def prepare_word(word, word2index):\n",
121 | "    return Variable(LongTensor([word2index[word]]) if word2index.get(word) is not None else LongTensor([word2index[\"\"]]))"
122 | ]
123 | },
124 | {
125 | "cell_type": "markdown",
126 | "metadata": {},
127 | "source": [
128 | "## Data load and Preprocessing "
129 | ]
130 | },
131 | {
132 | "cell_type": "code",
133 | "execution_count": 6,
134 | "metadata": {
135 | "collapsed": true
136 | },
137 | "outputs": [],
138 | "source": [
139 | "corpus = list(nltk.corpus.gutenberg.sents('melville-moby_dick.txt'))[:500]\n",
140 | "corpus = [[word.lower() for word in sent] for sent in corpus]"
141 | ]
142 | },
143 | {
144 | "cell_type": "markdown",
145 | "metadata": {},
146 | "source": [
147 | "### Exclude sparse words "
148 | ]
149 | },
150 | {
151 | "cell_type": "code",
152 | "execution_count": 7,
153 | "metadata": {
154 | "collapsed": true
155 | },
156 | "outputs": [],
157 | "source": [
158 | "word_count = Counter(flatten(corpus))"
159 | ]
160 | },
161 | {
162 | "cell_type": "code",
163 | "execution_count": 8,
164 | "metadata": {
165 | "collapsed": true
166 | },
167 | "outputs": [],
168 | "source": [
169 | "MIN_COUNT = 3\n",
170 | "exclude = []"
171 | ]
172 | },
173 | {
174 | "cell_type": "code",
175 | "execution_count": 9,
176 | "metadata": {
177 | "collapsed": true
178 | },
179 | "outputs": [],
180 | "source": [
181 | "for w, c in word_count.items():\n",
182 | "    if c < MIN_COUNT:\n",
183 | "        exclude.append(w)"
184 | ]
185 | },
186 | {
187 | "cell_type": "markdown",
188 | "metadata": {},
189 | "source": [
190 | "### Prepare train data "
191 | ]
192 | },
193 | {
194 | "cell_type": "code",
195 | "execution_count": 10,
196 | "metadata": {
197 | "collapsed": true
198 | },
199 | "outputs": [],
200 | "source": [
201 | "vocab = list(set(flatten(corpus)) - set(exclude))"
202 | ]
203 | },
204 | {
205 | "cell_type": "code",
206 | "execution_count": 11,
207 | "metadata": {
208 | "collapsed": true
209 | },
210 | "outputs": [],
211 | "source": [
212 | "word2index = {}\n",
213 | "for vo in vocab:\n",
214 | "    if word2index.get(vo) is None:\n",
215 | "        word2index[vo] = len(word2index)\n",
216 | "\n",
217 | "index2word = {v:k for k, v in word2index.items()}"
218 | ]
219 | },
220 | {
221 | "cell_type": "code",
222 | "execution_count": 12,
223 | "metadata": {
224 | "collapsed": true
225 | },
226 | "outputs": [],
227 | "source": [
228 | "WINDOW_SIZE = 5\n",
229 | "windows = flatten([list(nltk.ngrams([''] * WINDOW_SIZE + c + [''] * WINDOW_SIZE, WINDOW_SIZE * 2 + 1)) for c in corpus])\n",
230 | "\n",
231 | "train_data = []\n",
232 | "\n",
233 | "for window in windows:\n",
234 | "    for i in range(WINDOW_SIZE * 2 + 1):\n",
235 | "        if window[i] in exclude or window[WINDOW_SIZE] in exclude:\n",
236 | "            continue # min_count\n",
237 | "        if i == WINDOW_SIZE or window[i] == '':\n",
238 | "            continue\n",
239 | "        train_data.append((window[WINDOW_SIZE], window[i]))\n",
240 | "\n",
241 | "X_p = []\n",
242 | "y_p = []\n",
243 | "\n",
244 | "for tr in train_data:\n",
245 | "    X_p.append(prepare_word(tr[0], word2index).view(1, -1))\n",
246 | "    y_p.append(prepare_word(tr[1], word2index).view(1, -1))\n",
247 | "\n",
248 | "train_data = list(zip(X_p, y_p))"
249 | ]
250 | },
251 | {
252 | "cell_type": "code",
253 | "execution_count": 13,
254 | "metadata": {
255 | "collapsed": false
256 | },
257 | "outputs": [
258 | {
259 | "data": {
260 | "text/plain": [
261 | "50242"
262 | ]
263 | },
264 | "execution_count": 13,
265 | "metadata": {},
266 | "output_type": "execute_result"
267 | }
268 | ],
269 | "source": [
270 | "len(train_data)"
271 | ]
272 | },
273 | {
274 | "cell_type": "markdown",
275 | "metadata": {},
276 | "source": [
277 | "### Build Unigram Distribution**0.75 "
278 | ]
279 | },
280 | {
281 | "cell_type": "markdown",
282 | "metadata": {},
283 | "source": [
284 | "$$P(w)=U(w)^{3/4}/Z$$"
285 | ]
286 | },
287 | {
288 | "cell_type": "code",
289 | "execution_count": 14,
290 | "metadata": {
291 | "collapsed": true
292 | },
293 | "outputs": [],
294 | "source": [
295 | "Z = 0.001"
296 | ]
297 | },
298 | {
299 | "cell_type": "code",
300 | "execution_count": 15,
301 | "metadata": {
302 | "collapsed": true
303 | },
304 | "outputs": [],
305 | "source": [
306 | "word_count = Counter(flatten(corpus))\n",
307 | "num_total_words = sum([c for w, c in word_count.items() if w not in exclude])"
308 | ]
309 | },
310 | {
311 | "cell_type": "code",
312 | "execution_count": 16,
313 | "metadata": {
314 | "collapsed": true
315 | },
316 | "outputs": [],
317 | "source": [
318 | "unigram_table = []\n",
319 | "\n",
320 | "for vo in vocab:\n",
321 | "    unigram_table.extend([vo] * int(((word_count[vo]/num_total_words)**0.75)/Z))"
322 | ]
323 | },
324 | {
325 | "cell_type": "code",
326 | "execution_count": 17,
327 | "metadata": {
328 | "collapsed": false
329 | },
330 | "outputs": [
331 | {
332 | "name": "stdout",
333 | "output_type": "stream",
334 | "text": [
335 | "478 3500\n"
336 | ]
337 | }
338 | ],
339 | "source": [
340 | "print(len(vocab), len(unigram_table))"
341 | ]
342 | },
343 | {
344 | "cell_type": "markdown",
345 | "metadata": {},
346 | "source": [
347 | "### Negative Sampling "
348 | ]
349 | },
350 | {
351 | "cell_type": "code",
352 | "execution_count": 18,
353 | "metadata": {
354 | "collapsed": true
355 | },
356 | "outputs": [],
357 | "source": [
358 | "def negative_sampling(targets, unigram_table, k):\n",
359 | "    batch_size = targets.size(0)\n",
360 | "    neg_samples = []\n",
361 | "    for i in range(batch_size):\n",
362 | "        nsample = []\n",
363 | "        target_index = targets[i].data.cpu().tolist()[0] if USE_CUDA else targets[i].data.tolist()[0]\n",
364 | "        while len(nsample) < k: # num of sampling\n",
365 | "            neg = random.choice(unigram_table)\n",
366 | "            if word2index[neg] == target_index:\n",
367 | "                continue\n",
368 | "            nsample.append(neg)\n",
369 | "        neg_samples.append(prepare_sequence(nsample, word2index).view(1, -1))\n",
370 | "    \n",
371 | "    return torch.cat(neg_samples)"
372 | ]
373 | },
374 | {
375 | "cell_type": "markdown",
376 | "metadata": {},
377 | "source": [
378 | "## Modeling "
379 | ]
380 | },
381 | {
382 | "cell_type": "markdown",
383 | "metadata": {},
384 | "source": [
385 | "\n",
386 | "borrowed image from http://web.stanford.edu/class/cs224n/lectures/cs224n-2017-lecture3.pdf"
387 | ]
388 | },
389 | {
390 | "cell_type": "code",
391 | "execution_count": 19,
392 | "metadata": {
393 | "collapsed": true
394 | },
395 | "outputs": [],
396 | "source": [
397 | "class SkipgramNegSampling(nn.Module):\n",
398 | "    \n",
399 | "    def __init__(self, vocab_size, projection_dim):\n",
400 | "        super(SkipgramNegSampling, self).__init__()\n",
401 | "        self.embedding_v = nn.Embedding(vocab_size, projection_dim) # center embedding\n",
402 | "        self.embedding_u = nn.Embedding(vocab_size, projection_dim) # out embedding\n",
403 | "        self.logsigmoid = nn.LogSigmoid()\n",
404 | "        \n",
405 | "        initrange = (2.0 / (vocab_size + projection_dim))**0.5 # Xavier init\n",
406 | "        self.embedding_v.weight.data.uniform_(-initrange, initrange) # init\n",
407 | "        self.embedding_u.weight.data.uniform_(-0.0, 0.0) # init\n",
408 | "    \n",
409 | "    def forward(self, center_words, target_words, negative_words):\n",
410 | "        center_embeds = self.embedding_v(center_words) # B x 1 x D\n",
411 | "        target_embeds = self.embedding_u(target_words) # B x 1 x D\n",
412 | "        \n",
413 | "        neg_embeds = -self.embedding_u(negative_words) # B x K x D\n",
414 | "        \n",
415 | "        positive_score = target_embeds.bmm(center_embeds.transpose(1, 2)).squeeze(2) # Bx1\n",
416 | "        negative_score = torch.sum(neg_embeds.bmm(center_embeds.transpose(1, 2)).squeeze(2), 1).view(negative_words.size(0), -1) # BxK -> Bx1\n",
417 | "        \n",
418 | "        loss = self.logsigmoid(positive_score) + self.logsigmoid(negative_score)\n",
419 | "        \n",
420 | "        return -torch.mean(loss)\n",
421 | "    \n",
422 | "    def prediction(self, inputs):\n",
423 | "        embeds = self.embedding_v(inputs)\n",
424 | "        \n",
425 | "        return embeds"
426 | ]
427 | },
428 | {
429 | "cell_type": "markdown",
430 | "metadata": {},
431 | "source": [
432 | "## Train "
433 | ]
434 | },
435 | {
436 | "cell_type": "code",
437 | "execution_count": 68,
438 | "metadata": {
439 | "collapsed": true
440 | },
441 | "outputs": [],
442 | "source": [
443 | "EMBEDDING_SIZE = 30 \n",
444 | "BATCH_SIZE = 256\n",
445 | "EPOCH = 100\n",
446 | "NEG = 10 # Num of Negative Sampling"
447 | ]
448 | },
449 | {
450 | "cell_type": "code",
451 | "execution_count": 69,
452 | "metadata": {
453 | "collapsed": true
454 | },
455 | "outputs": [],
456 | "source": [
457 | "losses = []\n",
458 | "model = SkipgramNegSampling(len(word2index), EMBEDDING_SIZE)\n",
459 | "if USE_CUDA:\n",
460 | " model = model.cuda()\n",
461 | "optimizer = optim.Adam(model.parameters(), lr=0.001)"
462 | ]
463 | },
464 | {
465 | "cell_type": "code",
466 | "execution_count": 70,
467 | "metadata": {
468 | "collapsed": false
469 | },
470 | "outputs": [
471 | {
472 | "name": "stdout",
473 | "output_type": "stream",
474 | "text": [
475 | "Epoch : 0, mean_loss : 1.06\n",
476 | "Epoch : 10, mean_loss : 0.86\n",
477 | "Epoch : 20, mean_loss : 0.79\n",
478 | "Epoch : 30, mean_loss : 0.74\n",
479 | "Epoch : 40, mean_loss : 0.71\n",
480 | "Epoch : 50, mean_loss : 0.69\n",
481 | "Epoch : 60, mean_loss : 0.67\n",
482 | "Epoch : 70, mean_loss : 0.65\n",
483 | "Epoch : 80, mean_loss : 0.64\n",
484 | "Epoch : 90, mean_loss : 0.63\n"
485 | ]
486 | }
487 | ],
488 | "source": [
489 | "for epoch in range(EPOCH):\n",
490 | "    for i,batch in enumerate(getBatch(BATCH_SIZE, train_data)):\n",
491 | "        \n",
492 | "        inputs, targets = zip(*batch)\n",
493 | "        \n",
494 | "        inputs = torch.cat(inputs) # B x 1\n",
495 | "        targets = torch.cat(targets) # B x 1\n",
496 | "        negs = negative_sampling(targets, unigram_table, NEG)\n",
497 | "        model.zero_grad()\n",
498 | "\n",
499 | "        loss = model(inputs, targets, negs)\n",
500 | "        \n",
501 | "        loss.backward()\n",
502 | "        optimizer.step()\n",
503 | "        \n",
504 | "        losses.append(loss.data.tolist()[0])\n",
505 | "    if epoch % 10 == 0:\n",
506 | "        print(\"Epoch : %d, mean_loss : %.02f\" % (epoch, np.mean(losses)))\n",
507 | "        losses = []"
508 | ]
509 | },
510 | {
511 | "cell_type": "markdown",
512 | "metadata": {},
513 | "source": [
514 | "## Test "
515 | ]
516 | },
517 | {
518 | "cell_type": "code",
519 | "execution_count": 71,
520 | "metadata": {
521 | "collapsed": true
522 | },
523 | "outputs": [],
524 | "source": [
525 | "def word_similarity(target, vocab):\n",
526 | "    if USE_CUDA:\n",
527 | "        target_V = model.prediction(prepare_word(target, word2index))\n",
528 | "    else:\n",
529 | "        target_V = model.prediction(prepare_word(target, word2index))\n",
530 | "    similarities = []\n",
531 | "    for i in range(len(vocab)):\n",
532 | "        if vocab[i] == target:\n",
533 | "            continue\n",
534 | "        \n",
535 | "        if USE_CUDA:\n",
536 | "            vector = model.prediction(prepare_word(list(vocab)[i], word2index))\n",
537 | "        else:\n",
538 | "            vector = model.prediction(prepare_word(list(vocab)[i], word2index))\n",
539 | "        \n",
540 | "        cosine_sim = F.cosine_similarity(target_V, vector).data.tolist()[0]\n",
541 | "        similarities.append([vocab[i], cosine_sim])\n",
542 | "    return sorted(similarities, key=lambda x: x[1], reverse=True)[:10]"
543 | ]
544 | },
545 | {
546 | "cell_type": "code",
547 | "execution_count": 212,
548 | "metadata": {
549 | "collapsed": false
550 | },
551 | "outputs": [
552 | {
553 | "data": {
554 | "text/plain": [
555 | "'passengers'"
556 | ]
557 | },
558 | "execution_count": 212,
559 | "metadata": {},
560 | "output_type": "execute_result"
561 | }
562 | ],
563 | "source": [
564 | "test = random.choice(list(vocab))\n",
565 | "test"
566 | ]
567 | },
568 | {
569 | "cell_type": "code",
570 | "execution_count": 213,
571 | "metadata": {
572 | "collapsed": false
573 | },
574 | "outputs": [
575 | {
576 | "data": {
577 | "text/plain": [
578 | "[['am', 0.7353377342224121],\n",
579 | " ['passenger', 0.7154150605201721],\n",
580 | " ['cook', 0.6829826831817627],\n",
581 | " ['new', 0.6648461818695068],\n",
582 | " ['bedford', 0.6283411383628845],\n",
583 | " ['besides', 0.5972960591316223],\n",
584 | " ['themselves', 0.5964340567588806],\n",
585 | " ['grow', 0.5957046151161194],\n",
586 | " ['tell', 0.5952941179275513],\n",
587 | " ['get', 0.5943044424057007]]"
588 | ]
589 | },
590 | "execution_count": 213,
591 | "metadata": {},
592 | "output_type": "execute_result"
593 | }
594 | ],
595 | "source": [
596 | "word_similarity(test, vocab)"
597 | ]
598 | },
599 | {
600 | "cell_type": "code",
601 | "execution_count": null,
602 | "metadata": {
603 | "collapsed": true
604 | },
605 | "outputs": [],
606 | "source": []
607 | }
608 | ],
609 | "metadata": {
610 | "kernelspec": {
611 | "display_name": "Python 3",
612 | "language": "python",
613 | "name": "python3"
614 | },
615 | "language_info": {
616 | "codemirror_mode": {
617 | "name": "ipython",
618 | "version": 3
619 | },
620 | "file_extension": ".py",
621 | "mimetype": "text/x-python",
622 | "name": "python",
623 | "nbconvert_exporter": "python",
624 | "pygments_lexer": "ipython3",
625 | "version": "3.5.2"
626 | }
627 | },
628 | "nbformat": 4,
629 | "nbformat_minor": 2
630 | }
631 |
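The unigram table in the notebook above approximates $P(w)=U(w)^{3/4}/Z$ by repeating each word in a list. As a minimal sketch, assuming Python 3.6+ for `random.choices`, the same distribution can be normalized explicitly and sampled by weight. The helper names `build_noise_distribution` and `sample_negatives` and the toy token list are illustrative assumptions, not part of the notebook.

```python
import random
from collections import Counter


def build_noise_distribution(tokens, power=0.75):
    """Return (words, probs) with probs[i] proportional to count(words[i]) ** power."""
    counts = Counter(tokens)
    words = sorted(counts)  # sorted only to make the order deterministic
    weights = [counts[w] ** power for w in words]
    z = sum(weights)  # normalizer Z, so the probabilities sum to 1
    return words, [wt / z for wt in weights]


def sample_negatives(words, probs, target, k, rng=random):
    """Draw k words from the noise distribution, skipping the positive target."""
    negatives = []
    while len(negatives) < k:
        candidate = rng.choices(words, weights=probs, k=1)[0]
        if candidate != target:
            negatives.append(candidate)
    return negatives


# Toy data (hypothetical), standing in for the flattened corpus.
tokens = "the whale the sea the ship a whale".split()
words, probs = build_noise_distribution(tokens)
negatives = sample_negatives(words, probs, target="whale", k=5, rng=random.Random(0))
```

Sampling by weight avoids the rounding that `int(...)` introduces when filling a table, at the cost of a slower draw per sample. The loop assumes the vocabulary holds at least one word other than the target.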
--------------------------------------------------------------------------------
/notebooks/03.GloVe.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "# 3. GloVe: Global Vectors for Word Representation"
8 | ]
9 | },
10 | {
11 | "cell_type": "markdown",
12 | "metadata": {},
13 | "source": [
14 | "I recommend you take a look at these materials first."
15 | ]
16 | },
17 | {
18 | "cell_type": "markdown",
19 | "metadata": {},
20 | "source": [
21 | "* http://web.stanford.edu/class/cs224n/lectures/cs224n-2017-lecture3.pdf\n",
22 | "* https://nlp.stanford.edu/pubs/glove.pdf"
23 | ]
24 | },
25 | {
26 | "cell_type": "code",
27 | "execution_count": 1,
28 | "metadata": {
29 | "collapsed": true
30 | },
31 | "outputs": [],
32 | "source": [
33 | "import torch\n",
34 | "import torch.nn as nn\n",
35 | "from torch.autograd import Variable\n",
36 | "import torch.optim as optim\n",
37 | "import torch.nn.functional as F\n",
38 | "import nltk\n",
39 | "import random\n",
40 | "import numpy as np\n",
41 | "from collections import Counter\n",
42 | "flatten = lambda l: [item for sublist in l for item in sublist]\n",
43 | "random.seed(1024)"
44 | ]
45 | },
46 | {
47 | "cell_type": "code",
48 | "execution_count": 2,
49 | "metadata": {
50 | "collapsed": false
51 | },
52 | "outputs": [
53 | {
54 | "name": "stdout",
55 | "output_type": "stream",
56 | "text": [
57 | "0.3.0.post4\n",
58 | "3.2.4\n"
59 | ]
60 | }
61 | ],
62 | "source": [
63 | "print(torch.__version__)\n",
64 | "print(nltk.__version__)"
65 | ]
66 | },
67 | {
68 | "cell_type": "code",
69 | "execution_count": 3,
70 | "metadata": {
71 | "collapsed": true
72 | },
73 | "outputs": [],
74 | "source": [
75 | "USE_CUDA = torch.cuda.is_available()\n",
76 | "gpus = [0]\n",
77 | "if USE_CUDA: torch.cuda.set_device(gpus[0])  # select a GPU only when CUDA is available\n",
78 | "\n",
79 | "FloatTensor = torch.cuda.FloatTensor if USE_CUDA else torch.FloatTensor\n",
80 | "LongTensor = torch.cuda.LongTensor if USE_CUDA else torch.LongTensor\n",
81 | "ByteTensor = torch.cuda.ByteTensor if USE_CUDA else torch.ByteTensor"
82 | ]
83 | },
84 | {
85 | "cell_type": "code",
86 | "execution_count": 4,
87 | "metadata": {
88 | "collapsed": true
89 | },
90 | "outputs": [],
91 | "source": [
92 | "def getBatch(batch_size, train_data):\n",
93 | "    random.shuffle(train_data)\n",
94 | "    sindex = 0\n",
95 | "    eindex = batch_size\n",
96 | "    while eindex < len(train_data):\n",
97 | "        batch = train_data[sindex:eindex]\n",
98 | "        temp = eindex\n",
99 | "        eindex = eindex + batch_size\n",
100 | "        sindex = temp\n",
101 | "        yield batch\n",
102 | "    \n",
103 | "    if eindex >= len(train_data):\n",
104 | "        batch = train_data[sindex:]\n",
105 | "        yield batch"
106 | ]
107 | },
108 | {
109 | "cell_type": "code",
110 | "execution_count": 5,
111 | "metadata": {
112 | "collapsed": true
113 | },
114 | "outputs": [],
115 | "source": [
116 | "def prepare_sequence(seq, word2index):\n",
117 | "    idxs = list(map(lambda w: word2index[w] if word2index.get(w) is not None else word2index[\"\"], seq))\n",
118 | "    return Variable(LongTensor(idxs))\n",
119 | "\n",
120 | "def prepare_word(word, word2index):\n",
121 | "    return Variable(LongTensor([word2index[word]]) if word2index.get(word) is not None else LongTensor([word2index[\"\"]]))"
122 | ]
123 | },
124 | {
125 | "cell_type": "markdown",
126 | "metadata": {},
127 | "source": [
128 | "## Data load and Preprocessing "
129 | ]
130 | },
131 | {
132 | "cell_type": "code",
133 | "execution_count": 6,
134 | "metadata": {
135 | "collapsed": true
136 | },
137 | "outputs": [],
138 | "source": [
139 | "corpus = list(nltk.corpus.gutenberg.sents('melville-moby_dick.txt'))[:500]\n",
140 | "corpus = [[word.lower() for word in sent] for sent in corpus]"
141 | ]
142 | },
143 | {
144 | "cell_type": "markdown",
145 | "metadata": {},
146 | "source": [
147 | "### Build vocab"
148 | ]
149 | },
150 | {
151 | "cell_type": "code",
152 | "execution_count": 7,
153 | "metadata": {
154 | "collapsed": true
155 | },
156 | "outputs": [],
157 | "source": [
158 | "vocab = list(set(flatten(corpus)))"
159 | ]
160 | },
161 | {
162 | "cell_type": "code",
163 | "execution_count": 8,
164 | "metadata": {
165 | "collapsed": true
166 | },
167 | "outputs": [],
168 | "source": [
169 | "word2index = {}\n",
170 | "for vo in vocab:\n",
171 | "    if word2index.get(vo) is None:\n",
172 | "        word2index[vo] = len(word2index)\n",
173 | "\n",
174 | "index2word={v:k for k, v in word2index.items()}"
175 | ]
176 | },
177 | {
178 | "cell_type": "code",
179 | "execution_count": 9,
180 | "metadata": {
181 | "collapsed": true
182 | },
183 | "outputs": [],
184 | "source": [
185 | "WINDOW_SIZE = 5\n",
186 | "windows = flatten([list(nltk.ngrams([''] * WINDOW_SIZE + c + [''] * WINDOW_SIZE, WINDOW_SIZE * 2 + 1)) for c in corpus])\n",
187 | "\n",
188 | "window_data = []\n",
189 | "\n",
190 | "for window in windows:\n",
191 | "    for i in range(WINDOW_SIZE * 2 + 1):\n",
192 | "        if i == WINDOW_SIZE or window[i] == '':\n",
193 | "            continue\n",
194 | "        window_data.append((window[WINDOW_SIZE], window[i]))\n"
195 | ]
196 | },
197 | {
198 | "cell_type": "markdown",
199 | "metadata": {},
200 | "source": [
201 | "### Weighting Function "
202 | ]
203 | },
204 | {
205 | "cell_type": "markdown",
206 | "metadata": {},
207 | "source": [
208 | "\n",
209 | "borrowed image from https://nlp.stanford.edu/pubs/glove.pdf"
210 | ]
211 | },
212 | {
213 | "cell_type": "code",
214 | "execution_count": 10,
215 | "metadata": {
216 | "collapsed": true
217 | },
218 | "outputs": [],
219 | "source": [
220 | "def weighting(w_i, w_j):\n",
221 | "    try:\n",
222 | "        x_ij = X_ik[(w_i, w_j)]\n",
223 | "    except KeyError: # pair never co-occurred\n",
224 | "        x_ij = 1\n",
225 | "    \n",
226 | "    x_max = 100 # fixed in paper\n",
227 | "    alpha = 0.75\n",
228 | "    \n",
229 | "    if x_ij < x_max:\n",
230 | "        result = (x_ij/x_max)**alpha\n",
231 | "    else:\n",
232 | "        result = 1\n",
233 | "    \n",
234 | "    return result"
235 | ]
236 | },
237 | {
238 | "cell_type": "markdown",
239 | "metadata": {},
240 | "source": [
241 | "### Build Co-occurrence Matrix X"
242 | ]
243 | },
244 | {
245 | "cell_type": "markdown",
246 | "metadata": {},
247 | "source": [
248 | "Because the model's complexity depends on the number of nonzero elements of X, it is important to determine whether a tighter bound can be placed on that number."
249 | ]
250 | },
251 | {
252 | "cell_type": "code",
253 | "execution_count": 11,
254 | "metadata": {
255 | "collapsed": true
256 | },
257 | "outputs": [],
258 | "source": [
259 | "X_i = Counter(flatten(corpus)) # X_i"
260 | ]
261 | },
262 | {
263 | "cell_type": "code",
264 | "execution_count": 12,
265 | "metadata": {
266 | "collapsed": true
267 | },
268 | "outputs": [],
269 | "source": [
270 | "X_ik_window_5 = Counter(window_data) # Co-occurrence in window size 5"
271 | ]
272 | },
273 | {
274 | "cell_type": "code",
275 | "execution_count": 13,
276 | "metadata": {
277 | "collapsed": true
278 | },
279 | "outputs": [],
280 | "source": [
281 | "X_ik = {}\n",
282 | "weighting_dic = {}"
283 | ]
284 | },
285 | {
286 | "cell_type": "code",
287 | "execution_count": 14,
288 | "metadata": {
289 | "collapsed": true
290 | },
291 | "outputs": [],
292 | "source": [
293 | "from itertools import combinations_with_replacement"
294 | ]
295 | },
296 | {
297 | "cell_type": "code",
298 | "execution_count": 15,
299 | "metadata": {
300 | "collapsed": true
301 | },
302 | "outputs": [],
303 | "source": [
304 | "for bigram in combinations_with_replacement(vocab, 2):\n",
305 | "    if X_ik_window_5.get(bigram) is not None: # nonzero elements\n",
306 | "        co_occer = X_ik_window_5[bigram]\n",
307 | "        X_ik[bigram] = co_occer + 1 # log(Xik) -> log(Xik+1) to prevent divergence\n",
308 | "        X_ik[(bigram[1],bigram[0])] = co_occer+1\n",
309 | "    else:\n",
310 | "        pass\n",
311 | "    \n",
312 | "    weighting_dic[bigram] = weighting(bigram[0], bigram[1])\n",
313 | "    weighting_dic[(bigram[1], bigram[0])] = weighting(bigram[1], bigram[0])"
314 | ]
315 | },
316 | {
317 | "cell_type": "code",
318 | "execution_count": 16,
319 | "metadata": {
320 | "collapsed": false
321 | },
322 | "outputs": [
323 | {
324 | "name": "stdout",
325 | "output_type": "stream",
326 | "text": [
327 | "(',', 'was')\n",
328 | "True\n"
329 | ]
330 | }
331 | ],
332 | "source": [
333 | "test = random.choice(window_data)\n",
334 | "print(test)\n",
335 | "try:\n",
336 | "    print(X_ik[(test[0], test[1])] == X_ik[(test[1], test[0])])\n",
337 | "except KeyError: # pair is not stored in X_ik\n",
338 | "    pass"
339 | ]
340 | },
341 | {
342 | "cell_type": "markdown",
343 | "metadata": {},
344 | "source": [
345 | "### Prepare train data"
346 | ]
347 | },
348 | {
349 | "cell_type": "code",
350 | "execution_count": 17,
351 | "metadata": {
352 | "collapsed": false
353 | },
354 | "outputs": [
355 | {
356 | "name": "stdout",
357 | "output_type": "stream",
358 | "text": [
359 | "(Variable containing:\n",
360 | " 703\n",
361 | "[torch.cuda.LongTensor of size 1x1 (GPU 0)]\n",
362 | ", Variable containing:\n",
363 | " 23\n",
364 | "[torch.cuda.LongTensor of size 1x1 (GPU 0)]\n",
365 | ", Variable containing:\n",
366 | " 0.6931\n",
367 | "[torch.cuda.FloatTensor of size 1x1 (GPU 0)]\n",
368 | ", Variable containing:\n",
369 | "1.00000e-02 *\n",
370 | " 5.3183\n",
371 | "[torch.cuda.FloatTensor of size 1x1 (GPU 0)]\n",
372 | ")\n"
373 | ]
374 | }
375 | ],
376 | "source": [
377 | "u_p = [] # center vec\n",
378 | "v_p = [] # context vec\n",
379 | "co_p = [] # log(x_ij)\n",
380 | "weight_p = [] # f(x_ij)\n",
381 | "\n",
382 | "for pair in window_data:\n",
383 | "    u_p.append(prepare_word(pair[0], word2index).view(1, -1))\n",
384 | "    v_p.append(prepare_word(pair[1], word2index).view(1, -1))\n",
385 | "    \n",
386 | "    try:\n",
387 | "        cooc = X_ik[pair]\n",
388 | "    except KeyError: # pair is not stored in X_ik\n",
389 | "        cooc = 1\n",
390 | "\n",
391 | "    co_p.append(torch.log(Variable(FloatTensor([cooc]))).view(1, -1))\n",
392 | "    weight_p.append(Variable(FloatTensor([weighting_dic[pair]])).view(1, -1))\n",
393 | "\n",
394 | "train_data = list(zip(u_p, v_p, co_p, weight_p))\n",
395 | "del u_p\n",
396 | "del v_p\n",
397 | "del co_p\n",
398 | "del weight_p\n",
399 | "print(train_data[0]) # tuple (center vec i, context vec j, log(x_ij), weight f(x_ij))"
400 | ]
401 | },
402 | {
403 | "cell_type": "markdown",
404 | "metadata": {},
405 | "source": [
406 | "## Modeling "
407 | ]
408 | },
409 | {
410 | "cell_type": "markdown",
411 | "metadata": {},
412 | "source": [
413 | "\n",
414 | "borrowed image from https://nlp.stanford.edu/pubs/glove.pdf"
415 | ]
416 | },
417 | {
418 | "cell_type": "code",
419 | "execution_count": 19,
420 | "metadata": {
421 | "collapsed": true
422 | },
423 | "outputs": [],
424 | "source": [
425 | "class GloVe(nn.Module):\n",
426 | "    \n",
427 | "    def __init__(self, vocab_size,projection_dim):\n",
428 | "        super(GloVe,self).__init__()\n",
429 | "        self.embedding_v = nn.Embedding(vocab_size, projection_dim) # center embedding\n",
430 | "        self.embedding_u = nn.Embedding(vocab_size, projection_dim) # out embedding\n",
431 | "        \n",
432 | "        self.v_bias = nn.Embedding(vocab_size, 1)\n",
433 | "        self.u_bias = nn.Embedding(vocab_size, 1)\n",
434 | "        \n",
435 | "        initrange = (2.0 / (vocab_size + projection_dim))**0.5 # Xavier init\n",
436 | "        self.embedding_v.weight.data.uniform_(-initrange, initrange) # init\n",
437 | "        self.embedding_u.weight.data.uniform_(-initrange, initrange) # init\n",
438 | "        self.v_bias.weight.data.uniform_(-initrange, initrange) # init\n",
439 | "        self.u_bias.weight.data.uniform_(-initrange, initrange) # init\n",
440 | "    \n",
441 | "    def forward(self, center_words, target_words, coocs, weights):\n",
442 | "        center_embeds = self.embedding_v(center_words) # B x 1 x D\n",
443 | "        target_embeds = self.embedding_u(target_words) # B x 1 x D\n",
444 | "        \n",
445 | "        center_bias = self.v_bias(center_words).squeeze(1)\n",
446 | "        target_bias = self.u_bias(target_words).squeeze(1)\n",
447 | "        \n",
448 | "        inner_product = target_embeds.bmm(center_embeds.transpose(1, 2)).squeeze(2) # Bx1\n",
449 | "        \n",
450 | "        loss = weights*torch.pow(inner_product +center_bias + target_bias - coocs, 2)\n",
451 | "        \n",
452 | "        return torch.sum(loss)\n",
453 | "    \n",
454 | "    def prediction(self, inputs):\n",
455 | "        v_embeds = self.embedding_v(inputs) # B x 1 x D\n",
456 | "        u_embeds = self.embedding_u(inputs) # B x 1 x D\n",
457 | "        \n",
458 | "        return v_embeds+u_embeds # final embed"
459 | ]
460 | },
461 | {
462 | "cell_type": "markdown",
463 | "metadata": {},
464 | "source": [
465 | "## Train "
466 | ]
467 | },
468 | {
469 | "cell_type": "code",
470 | "execution_count": 22,
471 | "metadata": {
472 | "collapsed": true
473 | },
474 | "outputs": [],
475 | "source": [
476 | "EMBEDDING_SIZE = 50\n",
477 | "BATCH_SIZE = 256\n",
478 | "EPOCH = 50"
479 | ]
480 | },
481 | {
482 | "cell_type": "code",
483 | "execution_count": 23,
484 | "metadata": {
485 | "collapsed": true
486 | },
487 | "outputs": [],
488 | "source": [
489 | "losses = []\n",
490 | "model = GloVe(len(word2index), EMBEDDING_SIZE)\n",
491 | "if USE_CUDA:\n",
492 | " model = model.cuda()\n",
493 | "optimizer = optim.Adam(model.parameters(), lr=0.001)"
494 | ]
495 | },
496 | {
497 | "cell_type": "code",
498 | "execution_count": 24,
499 | "metadata": {
500 | "collapsed": false
501 | },
502 | "outputs": [
503 | {
504 | "name": "stdout",
505 | "output_type": "stream",
506 | "text": [
507 | "Epoch : 0, mean_loss : 236.10\n",
508 | "Epoch : 10, mean_loss : 2.27\n",
509 | "Epoch : 20, mean_loss : 0.53\n",
510 | "Epoch : 30, mean_loss : 0.12\n",
511 | "Epoch : 40, mean_loss : 0.04\n"
512 | ]
513 | }
514 | ],
515 | "source": [
516 | "for epoch in range(EPOCH):\n",
517 | "    for i,batch in enumerate(getBatch(BATCH_SIZE, train_data)):\n",
518 | "        \n",
519 | "        inputs, targets, coocs, weights = zip(*batch)\n",
520 | "        \n",
521 | "        inputs = torch.cat(inputs) # B x 1\n",
522 | "        targets = torch.cat(targets) # B x 1\n",
523 | "        coocs = torch.cat(coocs)\n",
524 | "        weights = torch.cat(weights)\n",
525 | "        model.zero_grad()\n",
526 | "\n",
527 | "        loss = model(inputs, targets, coocs, weights)\n",
528 | "        \n",
529 | "        loss.backward()\n",
530 | "        optimizer.step()\n",
531 | "        \n",
532 | "        losses.append(loss.data.tolist()[0])\n",
533 | "    if epoch % 10 == 0:\n",
534 | "        print(\"Epoch : %d, mean_loss : %.02f\" % (epoch, np.mean(losses)))\n",
535 | "        losses = []"
536 | ]
537 | },
538 | {
539 | "cell_type": "markdown",
540 | "metadata": {},
541 | "source": [
542 | "## Test "
543 | ]
544 | },
545 | {
546 | "cell_type": "code",
547 | "execution_count": 25,
548 | "metadata": {
549 | "collapsed": true
550 | },
551 | "outputs": [],
552 | "source": [
553 | "def word_similarity(target, vocab):\n",
554 | "    if USE_CUDA:\n",
555 | "        target_V = model.prediction(prepare_word(target, word2index))\n",
556 | "    else:\n",
557 | "        target_V = model.prediction(prepare_word(target, word2index))\n",
558 | "    similarities = []\n",
559 | "    for i in range(len(vocab)):\n",
560 | "        if vocab[i] == target:\n",
561 | "            continue\n",
562 | "        \n",
563 | "        if USE_CUDA:\n",
564 | "            vector = model.prediction(prepare_word(list(vocab)[i], word2index))\n",
565 | "        else:\n",
566 | "            vector = model.prediction(prepare_word(list(vocab)[i], word2index))\n",
567 | "        \n",
568 | "        cosine_sim = F.cosine_similarity(target_V, vector).data.tolist()[0]\n",
569 | "        similarities.append([vocab[i], cosine_sim])\n",
570 | "    return sorted(similarities, key=lambda x: x[1], reverse=True)[:10]"
571 | ]
572 | },
573 | {
574 | "cell_type": "code",
575 | "execution_count": 86,
576 | "metadata": {
577 | "collapsed": false
578 | },
579 | "outputs": [
580 | {
581 | "data": {
582 | "text/plain": [
583 | "'spiral'"
584 | ]
585 | },
586 | "execution_count": 86,
587 | "metadata": {},
588 | "output_type": "execute_result"
589 | }
590 | ],
591 | "source": [
592 | "test = random.choice(list(vocab))\n",
593 | "test"
594 | ]
595 | },
596 | {
597 | "cell_type": "code",
598 | "execution_count": 87,
599 | "metadata": {
600 | "collapsed": false
601 | },
602 | "outputs": [
603 | {
604 | "data": {
605 | "text/plain": [
606 | "[['horns', 0.9727935194969177],\n",
607 | " ['swords', 0.9076412916183472],\n",
608 | " ['hooked', 0.8984033465385437],\n",
609 | " ['thar', 0.8066437244415283],\n",
610 | " ['montaigne', 0.8062068819999695],\n",
611 | " ['rabelais', 0.789764940738678],\n",
612 | " ['orion', 0.7886737585067749],\n",
613 | " ['isaiah', 0.780662477016449],\n",
614 | " ['hamlet', 0.7799868583679199],\n",
615 | " ['colnett', 0.7792885899543762]]"
616 | ]
617 | },
618 | "execution_count": 87,
619 | "metadata": {},
620 | "output_type": "execute_result"
621 | }
622 | ],
623 | "source": [
624 | "word_similarity(test, vocab)"
625 | ]
626 | },
627 | {
628 | "cell_type": "markdown",
629 | "metadata": {
630 | "collapsed": true
631 | },
632 | "source": [
633 | "## TODO"
634 | ]
635 | },
636 | {
637 | "cell_type": "markdown",
638 | "metadata": {},
639 | "source": [
640 | "* Use a sparse matrix to build the co-occurrence matrix for memory efficiency"
641 | ]
642 | },
643 | {
644 | "cell_type": "markdown",
645 | "metadata": {},
646 | "source": [
647 | "## Suggested Readings"
648 | ]
649 | },
650 | {
651 | "cell_type": "markdown",
652 | "metadata": {},
653 | "source": [
654 | "* Word embeddings in 2017: Trends and future directions"
655 | ]
656 | }
657 | ],
658 | "metadata": {
659 | "kernelspec": {
660 | "display_name": "Python 3",
661 | "language": "python",
662 | "name": "python3"
663 | },
664 | "language_info": {
665 | "codemirror_mode": {
666 | "name": "ipython",
667 | "version": 3
668 | },
669 | "file_extension": ".py",
670 | "mimetype": "text/x-python",
671 | "name": "python",
672 | "nbconvert_exporter": "python",
673 | "pygments_lexer": "ipython3",
674 | "version": "3.5.2"
675 | }
676 | },
677 | "nbformat": 4,
678 | "nbformat_minor": 2
679 | }
680 |
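The TODO in the GloVe notebook above mentions building the co-occurrence matrix sparsely. A minimal sketch of that idea, assuming a plain `collections.Counter` keyed by `(row, col)` index pairs so that only nonzero counts are stored. The function name `build_sparse_cooccurrence`, the toy corpus, and the window size are illustrative assumptions, not code from the notebook.

```python
from collections import Counter


def build_sparse_cooccurrence(sentences, word2index, window_size=2):
    """Count symmetric co-occurrences, storing only nonzero (row, col) entries."""
    counts = Counter()
    for sent in sentences:
        ids = [word2index[w] for w in sent]
        for pos, center in enumerate(ids):
            lo = max(0, pos - window_size)
            hi = min(len(ids), pos + window_size + 1)
            for ctx_pos in range(lo, hi):
                if ctx_pos != pos:  # skip the center word itself
                    counts[(center, ids[ctx_pos])] += 1
    return counts


# Hypothetical toy data, only to show the storage saving.
toy_corpus = [["the", "white", "whale"], ["the", "whale", "swims"]]
toy_vocab = sorted({w for sent in toy_corpus for w in sent})
toy_word2index = {w: i for i, w in enumerate(toy_vocab)}

cooc = build_sparse_cooccurrence(toy_corpus, toy_word2index, window_size=1)
dense_cells = len(toy_vocab) ** 2
print(len(cooc), "nonzero entries versus", dense_cells, "cells in a dense matrix")
```

A dictionary of pairs like this maps directly onto the coordinate (COO) layout, so the keys and values could later be handed to a sparse-matrix library if one is available.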
--------------------------------------------------------------------------------
/notebooks/04.Window-Classifier-for-NER.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "# 4. Word Window Classification and Neural Networks "
8 | ]
9 | },
10 | {
11 | "cell_type": "markdown",
12 | "metadata": {},
13 | "source": [
14 | "I recommend taking a look at these materials first."
15 | ]
16 | },
17 | {
18 | "cell_type": "markdown",
19 | "metadata": {},
20 | "source": [
21 | "* http://web.stanford.edu/class/cs224n/lectures/cs224n-2017-lecture4.pdf\n",
22 | "* https://en.wikipedia.org/wiki/Named-entity_recognition"
23 | ]
24 | },
25 | {
26 | "cell_type": "code",
27 | "execution_count": 1,
28 | "metadata": {
29 | "collapsed": true
30 | },
31 | "outputs": [],
32 | "source": [
33 | "import torch\n",
34 | "import torch.nn as nn\n",
35 | "from torch.autograd import Variable\n",
36 | "import torch.optim as optim\n",
37 | "import torch.nn.functional as F\n",
38 | "import nltk\n",
39 | "import random\n",
40 | "import numpy as np\n",
41 | "from collections import Counter\n",
42 | "flatten = lambda l: [item for sublist in l for item in sublist]\n",
43 | "from sklearn_crfsuite import metrics\n",
44 | "random.seed(1024)"
45 | ]
46 | },
47 | {
48 | "cell_type": "markdown",
49 | "metadata": {},
50 | "source": [
51 | "You also need the latest version of sklearn_crfsuite to print the classification report."
52 | ]
53 | },
54 | {
55 | "cell_type": "code",
56 | "execution_count": 2,
57 | "metadata": {
58 | "collapsed": false
59 | },
60 | "outputs": [
61 | {
62 | "name": "stdout",
63 | "output_type": "stream",
64 | "text": [
65 | "0.3.0.post4\n",
66 | "3.2.4\n"
67 | ]
68 | }
69 | ],
70 | "source": [
71 | "print(torch.__version__)\n",
72 | "print(nltk.__version__)"
73 | ]
74 | },
75 | {
76 | "cell_type": "code",
77 | "execution_count": 3,
78 | "metadata": {
79 | "collapsed": true
80 | },
81 | "outputs": [],
82 | "source": [
83 | "USE_CUDA = torch.cuda.is_available()\n",
84 | "gpus = [0]\n",
85 | "if USE_CUDA: torch.cuda.set_device(gpus[0])  # select a GPU only when CUDA is available\n",
86 | "\n",
87 | "FloatTensor = torch.cuda.FloatTensor if USE_CUDA else torch.FloatTensor\n",
88 | "LongTensor = torch.cuda.LongTensor if USE_CUDA else torch.LongTensor\n",
89 | "ByteTensor = torch.cuda.ByteTensor if USE_CUDA else torch.ByteTensor"
90 | ]
91 | },
92 | {
93 | "cell_type": "code",
94 | "execution_count": 4,
95 | "metadata": {
96 | "collapsed": true
97 | },
98 | "outputs": [],
99 | "source": [
100 | "def getBatch(batch_size, train_data):\n",
101 | "    random.shuffle(train_data)\n",
102 | "    sindex = 0\n",
103 | "    eindex = batch_size\n",
104 | "    while eindex < len(train_data):\n",
105 | "        batch = train_data[sindex: eindex]\n",
106 | "        temp = eindex\n",
107 | "        eindex = eindex + batch_size\n",
108 | "        sindex = temp\n",
109 | "        yield batch\n",
110 | "    \n",
111 | "    if eindex >= len(train_data):\n",
112 | "        batch = train_data[sindex:]\n",
113 | "        yield batch"
114 | ]
115 | },
116 | {
117 | "cell_type": "code",
118 | "execution_count": 5,
119 | "metadata": {
120 | "collapsed": true
121 | },
122 | "outputs": [],
123 | "source": [
124 | "def prepare_sequence(seq, word2index):\n",
125 | "    idxs = list(map(lambda w: word2index[w] if word2index.get(w) is not None else word2index[\"<UNK>\"], seq))\n",
126 | "    return Variable(LongTensor(idxs))\n",
127 | "\n",
128 | "def prepare_word(word, word2index):\n",
129 | "    return Variable(LongTensor([word2index[word]]) if word2index.get(word) is not None else LongTensor([word2index[\"<UNK>\"]]))\n",
130 | "\n",
131 | "def prepare_tag(tag,tag2index):\n",
132 | "    return Variable(LongTensor([tag2index[tag]]))"
133 | ]
134 | },
135 | {
136 | "cell_type": "markdown",
137 | "metadata": {},
138 | "source": [
139 | "## Data load and Preprocessing "
140 | ]
141 | },
142 | {
143 | "cell_type": "markdown",
144 | "metadata": {},
145 | "source": [
146 | "CoNLL-2002 Shared Task: Language-Independent Named Entity Recognition\n",
147 | "https://www.clips.uantwerpen.be/conll2002/ner/"
148 | ]
149 | },
150 | {
151 | "cell_type": "code",
152 | "execution_count": 6,
153 | "metadata": {
154 | "collapsed": true
155 | },
156 | "outputs": [],
157 | "source": [
158 | "corpus = nltk.corpus.conll2002.iob_sents()"
159 | ]
160 | },
161 | {
162 | "cell_type": "code",
163 | "execution_count": 7,
164 | "metadata": {
165 | "collapsed": true
166 | },
167 | "outputs": [],
168 | "source": [
169 | "data = []\n",
170 | "for cor in corpus:\n",
171 | "    sent, _, tag = list(zip(*cor))\n",
172 | "    data.append([sent, tag])"
173 | ]
174 | },
175 | {
176 | "cell_type": "code",
177 | "execution_count": 8,
178 | "metadata": {
179 | "collapsed": false
180 | },
181 | "outputs": [
182 | {
183 | "name": "stdout",
184 | "output_type": "stream",
185 | "text": [
186 | "35651\n",
187 | "[('Sao', 'Paulo', '(', 'Brasil', ')', ',', '23', 'may', '(', 'EFECOM', ')', '.'), ('B-LOC', 'I-LOC', 'O', 'B-LOC', 'O', 'O', 'O', 'O', 'O', 'B-ORG', 'O', 'O')]\n"
188 | ]
189 | }
190 | ],
191 | "source": [
192 | "print(len(data))\n",
193 | "print(data[0])"
194 | ]
195 | },
196 | {
197 | "cell_type": "markdown",
198 | "metadata": {},
199 | "source": [
200 | "### Build Vocab"
201 | ]
202 | },
203 | {
204 | "cell_type": "code",
205 | "execution_count": 9,
206 | "metadata": {
207 | "collapsed": true
208 | },
209 | "outputs": [],
210 | "source": [
211 | "sents,tags = list(zip(*data))\n",
212 | "vocab = list(set(flatten(sents)))\n",
213 | "tagset = list(set(flatten(tags)))"
214 | ]
215 | },
216 | {
217 | "cell_type": "code",
218 | "execution_count": 10,
219 | "metadata": {
220 | "collapsed": true
221 | },
222 | "outputs": [],
223 | "source": [
224 | "word2index={'<UNK>' : 0, '<DUMMY>' : 1} # <UNK> stands for unknown words; <DUMMY> pads the start and end of a sentence\n",
225 | "for vo in vocab:\n",
226 | "    if word2index.get(vo) is None:\n",
227 | "        word2index[vo] = len(word2index)\n",
228 | "index2word = {v:k for k, v in word2index.items()}\n",
229 | "\n",
230 | "tag2index = {}\n",
231 | "for tag in tagset:\n",
232 | "    if tag2index.get(tag) is None:\n",
233 | "        tag2index[tag] = len(tag2index)\n",
234 | "index2tag={v:k for k, v in tag2index.items()}"
235 | ]
236 | },
237 | {
238 | "cell_type": "markdown",
239 | "metadata": {},
240 | "source": [
241 | "### Prepare data"
242 | ]
243 | },
244 | {
245 | "cell_type": "markdown",
246 | "metadata": {},
247 | "source": [
248 | "Example : Classify 'Paris' in the context of this sentence with window length 2"
249 | ]
250 | },
251 | {
252 | "cell_type": "markdown",
253 | "metadata": {},
254 | "source": [
255 | "[Figure: word-window example; see images/04.window-data.png]"
256 | ]
257 | },
258 | {
259 | "cell_type": "markdown",
260 | "metadata": {},
261 | "source": [
262 | "borrowed image from http://web.stanford.edu/class/cs224n/lectures/cs224n-2017-lecture4.pdf"
263 | ]
264 | },
265 | {
266 | "cell_type": "code",
267 | "execution_count": 11,
268 | "metadata": {
269 | "collapsed": true
270 | },
271 | "outputs": [],
272 | "source": [
273 | "WINDOW_SIZE = 2\n",
274 | "windows = []"
275 | ]
276 | },
277 | {
278 | "cell_type": "code",
279 | "execution_count": 12,
280 | "metadata": {
281 | "collapsed": true
282 | },
283 | "outputs": [],
284 | "source": [
285 | "for sample in data:\n",
286 | "    dummy = ['<DUMMY>'] * WINDOW_SIZE\n",
287 | "    window = list(nltk.ngrams(dummy + list(sample[0]) + dummy, WINDOW_SIZE * 2 + 1))\n",
288 | "    windows.extend([[list(window[i]), sample[1][i]] for i in range(len(sample[0]))])"
289 | ]
290 | },
291 | {
292 | "cell_type": "code",
293 | "execution_count": 13,
294 | "metadata": {
295 | "collapsed": false
296 | },
297 | "outputs": [
298 | {
299 | "data": {
300 | "text/plain": [
301 | "[['<DUMMY>', '<DUMMY>', 'Sao', 'Paulo', '('], 'B-LOC']"
302 | ]
303 | },
304 | "execution_count": 13,
305 | "metadata": {},
306 | "output_type": "execute_result"
307 | }
308 | ],
309 | "source": [
310 | "windows[0]"
311 | ]
312 | },
313 | {
314 | "cell_type": "code",
315 | "execution_count": 14,
316 | "metadata": {
317 | "collapsed": false
318 | },
319 | "outputs": [
320 | {
321 | "data": {
322 | "text/plain": [
323 | "678377"
324 | ]
325 | },
326 | "execution_count": 14,
327 | "metadata": {},
328 | "output_type": "execute_result"
329 | }
330 | ],
331 | "source": [
332 | "len(windows)"
333 | ]
334 | },
335 | {
336 | "cell_type": "code",
337 | "execution_count": 15,
338 | "metadata": {
339 | "collapsed": true
340 | },
341 | "outputs": [],
342 | "source": [
343 | "random.shuffle(windows)\n",
344 | "\n",
345 | "train_data = windows[:int(len(windows) * 0.9)]\n",
346 | "test_data = windows[int(len(windows) * 0.9):]"
347 | ]
348 | },
349 | {
350 | "cell_type": "markdown",
351 | "metadata": {},
352 | "source": [
353 | "## Modeling "
354 | ]
355 | },
356 | {
357 | "cell_type": "markdown",
358 | "metadata": {},
359 | "source": [
360 | "[Figure: window classifier architecture; see images/04.window-classifier-architecture.png]\n",
361 | "borrowed image from http://web.stanford.edu/class/cs224n/lectures/cs224n-2017-lecture4.pdf"
362 | ]
363 | },
364 | {
365 | "cell_type": "code",
366 | "execution_count": 16,
367 | "metadata": {
368 | "collapsed": true
369 | },
370 | "outputs": [],
371 | "source": [
372 | "class WindowClassifier(nn.Module): \n",
373 | "    def __init__(self, vocab_size, embedding_size, window_size, hidden_size, output_size):\n",
374 | "\n",
375 | "        super(WindowClassifier, self).__init__()\n",
376 | "        \n",
377 | "        self.embed = nn.Embedding(vocab_size, embedding_size)\n",
378 | "        self.h_layer1 = nn.Linear(embedding_size * (window_size * 2 + 1), hidden_size)\n",
379 | "        self.h_layer2 = nn.Linear(hidden_size, hidden_size)\n",
380 | "        self.o_layer = nn.Linear(hidden_size, output_size)\n",
381 | "        self.relu = nn.ReLU()\n",
382 | "        self.softmax = nn.LogSoftmax(dim=1)\n",
383 | "        self.dropout = nn.Dropout(0.3)\n",
384 | "        \n",
385 | "    def forward(self, inputs, is_training=False): \n",
386 | "        embeds = self.embed(inputs) # BxWxD\n",
387 | "        concated = embeds.view(-1, embeds.size(1)*embeds.size(2)) # Bx(W*D)\n",
388 | "        h0 = self.relu(self.h_layer1(concated))\n",
389 | "        if is_training:\n",
390 | "            h0 = self.dropout(h0)\n",
391 | "        h1 = self.relu(self.h_layer2(h0))\n",
392 | "        if is_training:\n",
393 | "            h1 = self.dropout(h1)\n",
394 | "        out = self.softmax(self.o_layer(h1))\n",
395 | "        return out"
396 | ]
397 | },
398 | {
399 | "cell_type": "code",
400 | "execution_count": 20,
401 | "metadata": {
402 | "collapsed": true
403 | },
404 | "outputs": [],
405 | "source": [
406 | "BATCH_SIZE = 128\n",
407 | "EMBEDDING_SIZE = 50 # x (WINDOW_SIZE*2+1) = 250\n",
408 | "HIDDEN_SIZE = 300\n",
409 | "EPOCH = 3\n",
410 | "LEARNING_RATE = 0.001"
411 | ]
412 | },
413 | {
414 | "cell_type": "markdown",
415 | "metadata": {},
416 | "source": [
417 | "## Training "
418 | ]
419 | },
420 | {
421 | "cell_type": "markdown",
422 | "metadata": {},
423 | "source": [
424 | "Training takes a while if you use only the CPU."
425 | ]
426 | },
427 | {
428 | "cell_type": "code",
429 | "execution_count": 22,
430 | "metadata": {
431 | "collapsed": true
432 | },
433 | "outputs": [],
434 | "source": [
435 | "model = WindowClassifier(len(word2index), EMBEDDING_SIZE, WINDOW_SIZE, HIDDEN_SIZE, len(tag2index))\n",
436 | "if USE_CUDA:\n",
437 | "    model = model.cuda()\n",
438 | "loss_function = nn.CrossEntropyLoss()\n",
439 | "optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)"
440 | ]
441 | },
442 | {
443 | "cell_type": "code",
444 | "execution_count": 23,
445 | "metadata": {
446 | "collapsed": false
447 | },
448 | "outputs": [
449 | {
450 | "name": "stdout",
451 | "output_type": "stream",
452 | "text": [
453 | "[0/3] mean_loss : 2.25\n",
454 | "[0/3] mean_loss : 0.47\n",
455 | "[0/3] mean_loss : 0.36\n",
456 | "[0/3] mean_loss : 0.31\n",
457 | "[0/3] mean_loss : 0.28\n",
458 | "[1/3] mean_loss : 0.22\n",
459 | "[1/3] mean_loss : 0.21\n",
460 | "[1/3] mean_loss : 0.21\n",
461 | "[1/3] mean_loss : 0.19\n",
462 | "[1/3] mean_loss : 0.19\n",
463 | "[2/3] mean_loss : 0.12\n",
464 | "[2/3] mean_loss : 0.15\n",
465 | "[2/3] mean_loss : 0.15\n",
466 | "[2/3] mean_loss : 0.14\n",
467 | "[2/3] mean_loss : 0.14\n"
468 | ]
469 | }
470 | ],
471 | "source": [
472 | "for epoch in range(EPOCH):\n",
473 | "    losses = []\n",
474 | "    for i,batch in enumerate(getBatch(BATCH_SIZE, train_data)):\n",
475 | "        x,y=list(zip(*batch))\n",
476 | "        inputs = torch.cat([prepare_sequence(sent, word2index).view(1, -1) for sent in x])\n",
477 | "        targets = torch.cat([prepare_tag(tag, tag2index) for tag in y])\n",
478 | "        model.zero_grad()\n",
479 | "        preds = model(inputs, is_training=True)\n",
480 | "        loss = loss_function(preds, targets)\n",
481 | "        losses.append(loss.data.tolist()[0])\n",
482 | "        loss.backward()\n",
483 | "        optimizer.step()\n",
484 | "\n",
485 | "        if i % 1000 == 0:\n",
486 | "            print(\"[%d/%d] mean_loss : %0.2f\" %(epoch, EPOCH, np.mean(losses)))\n",
487 | "            losses = []"
488 | ]
489 | },
490 | {
491 | "cell_type": "markdown",
492 | "metadata": {},
493 | "source": [
494 | "## Test "
495 | ]
496 | },
497 | {
498 | "cell_type": "code",
499 | "execution_count": 24,
500 | "metadata": {
501 | "collapsed": true
502 | },
503 | "outputs": [],
504 | "source": [
505 | "for_f1_score = []"
506 | ]
507 | },
508 | {
509 | "cell_type": "code",
510 | "execution_count": 25,
511 | "metadata": {
512 | "collapsed": false
513 | },
514 | "outputs": [
515 | {
516 | "name": "stdout",
517 | "output_type": "stream",
518 | "text": [
519 | "95.69120551903063\n"
520 | ]
521 | }
522 | ],
523 | "source": [
524 | "accuracy = 0\n",
525 | "for test in test_data:\n",
526 | "    x, y = test[0], test[1]\n",
527 | "    input_ = prepare_sequence(x, word2index).view(1, -1)\n",
528 | "\n",
529 | "    i = model(input_).max(1)[1]\n",
530 | "    pred = index2tag[i.data.tolist()[0]]\n",
531 | "    for_f1_score.append([pred, y])\n",
532 | "    if pred == y:\n",
533 | "        accuracy += 1\n",
534 | "\n",
535 | "print(accuracy/len(test_data) * 100)"
536 | ]
537 | },
538 | {
539 | "cell_type": "markdown",
540 | "metadata": {},
541 | "source": [
542 | "The accuracy is high because most labels are the 'O' tag, so we also need to measure the F1 score."
543 | ]
544 | },
545 | {
546 | "cell_type": "markdown",
547 | "metadata": {},
548 | "source": [
549 | "### Print classification report "
550 | ]
551 | },
552 | {
553 | "cell_type": "code",
554 | "execution_count": 26,
555 | "metadata": {
556 | "collapsed": true
557 | },
558 | "outputs": [],
559 | "source": [
560 | "y_pred, y_test = list(zip(*for_f1_score))"
561 | ]
562 | },
563 | {
564 | "cell_type": "code",
565 | "execution_count": 27,
566 | "metadata": {
567 | "collapsed": true
568 | },
569 | "outputs": [],
570 | "source": [
571 | "sorted_labels = sorted(\n",
572 | " list(set(y_test) - {'O'}),\n",
573 | " key=lambda name: (name[1:], name[0])\n",
574 | ")"
575 | ]
576 | },
577 | {
578 | "cell_type": "code",
579 | "execution_count": 28,
580 | "metadata": {
581 | "collapsed": false
582 | },
583 | "outputs": [
584 | {
585 | "data": {
586 | "text/plain": [
587 | "['B-LOC', 'I-LOC', 'B-MISC', 'I-MISC', 'B-ORG', 'I-ORG', 'B-PER', 'I-PER']"
588 | ]
589 | },
590 | "execution_count": 28,
591 | "metadata": {},
592 | "output_type": "execute_result"
593 | }
594 | ],
595 | "source": [
596 | "sorted_labels"
597 | ]
598 | },
599 | {
600 | "cell_type": "code",
601 | "execution_count": 29,
602 | "metadata": {
603 | "collapsed": true
604 | },
605 | "outputs": [],
606 | "source": [
607 | "y_pred = [[y] for y in y_pred] # wrap each label in a list because the sklearn_crfsuite.metrics functions flatten their inputs\n",
608 | "y_test = [[y] for y in y_test]"
609 | ]
610 | },
611 | {
612 | "cell_type": "code",
613 | "execution_count": 30,
614 | "metadata": {
615 | "collapsed": false
616 | },
617 | "outputs": [
618 | {
619 | "name": "stdout",
620 | "output_type": "stream",
621 | "text": [
622 | " precision recall f1-score support\n",
623 | "\n",
624 | " B-LOC 0.802 0.636 0.710 1085\n",
625 | " I-LOC 0.732 0.457 0.562 311\n",
626 | " B-MISC 0.750 0.378 0.503 801\n",
627 | " I-MISC 0.679 0.331 0.445 641\n",
628 | " B-ORG 0.723 0.738 0.730 1430\n",
629 | " I-ORG 0.710 0.700 0.705 969\n",
630 | " B-PER 0.782 0.773 0.777 1268\n",
631 | " I-PER 0.853 0.871 0.861 950\n",
632 | "\n",
633 | "avg / total 0.759 0.656 0.693 7455\n",
634 | "\n"
635 | ]
636 | }
637 | ],
638 | "source": [
639 | "print(metrics.flat_classification_report(\n",
640 | " y_test, y_pred, labels = sorted_labels, digits=3\n",
641 | "))"
642 | ]
643 | },
644 | {
645 | "cell_type": "markdown",
646 | "metadata": {
647 | "collapsed": true
648 | },
649 | "source": [
650 | "### TODO"
651 | ]
652 | },
653 | {
654 | "cell_type": "markdown",
655 | "metadata": {},
656 | "source": [
657 | "* Use a max-margin objective function: http://pytorch.org/docs/master/nn.html#multilabelmarginloss"
658 | ]
659 | },
660 | {
661 | "cell_type": "code",
662 | "execution_count": null,
663 | "metadata": {
664 | "collapsed": true
665 | },
666 | "outputs": [],
667 | "source": []
668 | }
669 | ],
670 | "metadata": {
671 | "kernelspec": {
672 | "display_name": "Python 3",
673 | "language": "python",
674 | "name": "python3"
675 | },
676 | "language_info": {
677 | "codemirror_mode": {
678 | "name": "ipython",
679 | "version": 3
680 | },
681 | "file_extension": ".py",
682 | "mimetype": "text/x-python",
683 | "name": "python",
684 | "nbconvert_exporter": "python",
685 | "pygments_lexer": "ipython3",
686 | "version": "3.5.2"
687 | }
688 | },
689 | "nbformat": 4,
690 | "nbformat_minor": 2
691 | }
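The "Prepare data" cells of the window-classifier notebook above pad each sentence on both sides and then slide a window of `2 * WINDOW_SIZE + 1` tokens over it, pairing every window with the tag of its center word. The following is a minimal sketch of the same idea in plain Python, without `nltk.ngrams`. The helper name `make_windows` and the pad string `<DUMMY>` are assumptions made for illustration.

```python
def make_windows(tokens, tags, window_size=2, pad="<DUMMY>"):
    """Return [window, tag] pairs; the word being classified sits at index window_size."""
    padded = [pad] * window_size + list(tokens) + [pad] * window_size
    width = 2 * window_size + 1
    return [[padded[i:i + width], tags[i]] for i in range(len(tokens))]


# Shortened version of the first CoNLL-2002 sentence printed in the notebook.
tokens = ("Sao", "Paulo", "(", "Brasil", ")")
tags = ("B-LOC", "I-LOC", "O", "B-LOC", "O")

windows = make_windows(tokens, tags)
for window, tag in windows:
    print(window, "->", tag)
```

Each sentence of length n yields exactly n windows, which is why the notebook ends up with far more training examples than sentences.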
692 |
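The test section of the window-classifier notebook above notes that accuracy looks high because most labels are 'O'. A small sketch with made-up labels shows the effect: a predictor that always answers 'O' scores well on accuracy but gets an F1 of zero on an entity tag. `precision_recall_f1` is a hypothetical helper written for this illustration; it is not part of sklearn_crfsuite.

```python
def precision_recall_f1(y_true, y_pred, label):
    """Per-label precision, recall and F1 computed from raw counts."""
    pairs = list(zip(y_true, y_pred))
    tp = sum(1 for t, p in pairs if t == label and p == label)
    fp = sum(1 for t, p in pairs if t != label and p == label)
    fn = sum(1 for t, p in pairs if t == label and p != label)
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    if precision + recall == 0:
        return precision, recall, 0.0
    return precision, recall, 2 * precision * recall / (precision + recall)


# Made-up, heavily imbalanced labels: 95 'O' tokens and 5 location tokens.
y_true = ["O"] * 95 + ["B-LOC"] * 5
y_pred = ["O"] * 100  # a degenerate model that never predicts an entity

accuracy = sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)
print("accuracy:", accuracy)
print("B-LOC precision/recall/F1:", precision_recall_f1(y_true, y_pred, "B-LOC"))
```

This is the reason the notebook excludes 'O' from `sorted_labels` before printing the per-tag report.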
--------------------------------------------------------------------------------
/notebooks/06.RNN-Language-Model.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "# 6. Recurrent Neural Networks and Language Models"
8 | ]
9 | },
10 | {
11 | "cell_type": "markdown",
12 | "metadata": {},
13 | "source": [
14 | "* http://web.stanford.edu/class/cs224n/lectures/cs224n-2017-lecture8.pdf\n",
15 | "* http://web.stanford.edu/class/cs224n/lectures/cs224n-2017-lecture9.pdf\n",
16 | "* http://colah.github.io/posts/2015-08-Understanding-LSTMs/\n",
17 | "* https://github.com/pytorch/examples/tree/master/word_language_model\n",
18 | "* https://github.com/yunjey/pytorch-tutorial/blob/master/tutorials/02-intermediate/language_model"
19 | ]
20 | },
21 | {
22 | "cell_type": "code",
23 | "execution_count": 1,
24 | "metadata": {
25 | "collapsed": true
26 | },
27 | "outputs": [],
28 | "source": [
29 | "import torch\n",
30 | "import torch.nn as nn\n",
31 | "from torch.autograd import Variable\n",
32 | "import torch.optim as optim\n",
33 | "import torch.nn.functional as F\n",
34 | "import nltk\n",
35 | "import random\n",
36 | "import numpy as np\n",
37 | "from collections import Counter, OrderedDict\n",
39 | "from copy import deepcopy\n",
40 | "flatten = lambda l: [item for sublist in l for item in sublist]\n",
41 | "random.seed(1024)"
42 | ]
43 | },
44 | {
45 | "cell_type": "code",
46 | "execution_count": 2,
47 | "metadata": {
48 | "collapsed": true
49 | },
50 | "outputs": [],
51 | "source": [
52 | "USE_CUDA = torch.cuda.is_available()\n",
53 | "gpus = [0]\n",
54 | "if USE_CUDA: torch.cuda.set_device(gpus[0])  # select a GPU only when CUDA is available\n",
55 | "\n",
56 | "FloatTensor = torch.cuda.FloatTensor if USE_CUDA else torch.FloatTensor\n",
57 | "LongTensor = torch.cuda.LongTensor if USE_CUDA else torch.LongTensor\n",
58 | "ByteTensor = torch.cuda.ByteTensor if USE_CUDA else torch.ByteTensor"
59 | ]
60 | },
61 | {
62 | "cell_type": "code",
63 | "execution_count": 4,
64 | "metadata": {
65 | "collapsed": true
66 | },
67 | "outputs": [],
68 | "source": [
69 | "def prepare_sequence(seq, to_index):\n",
70 | "    idxs = list(map(lambda w: to_index[w] if to_index.get(w) is not None else to_index[\"<unk>\"], seq))\n",
71 | "    return LongTensor(idxs)"
72 | ]
73 | },
74 | {
75 | "cell_type": "markdown",
76 | "metadata": {},
77 | "source": [
78 | "## Data load and Preprocessing"
79 | ]
80 | },
81 | {
82 | "cell_type": "markdown",
83 | "metadata": {},
84 | "source": [
85 | "### Penn TreeBank"
86 | ]
87 | },
88 | {
89 | "cell_type": "code",
90 | "execution_count": 5,
91 | "metadata": {
92 | "collapsed": true
93 | },
94 | "outputs": [],
95 | "source": [
96 | "def prepare_ptb_dataset(filename, word2index=None):\n",
97 | "    corpus = open(filename, 'r', encoding='utf-8').readlines()\n",
98 | "    corpus = flatten([co.strip().split() + ['</s>'] for co in corpus])\n",
99 | "    \n",
100 | "    if word2index is None:\n",
101 | "        vocab = list(set(corpus))\n",
102 | "        word2index = {'<unk>': 0}\n",
103 | "        for vo in vocab:\n",
104 | "            if word2index.get(vo) is None:\n",
105 | "                word2index[vo] = len(word2index)\n",
106 | "    \n",
107 | "    return prepare_sequence(corpus, word2index), word2index"
108 | ]
109 | },
110 | {
111 | "cell_type": "code",
112 | "execution_count": 175,
113 | "metadata": {
114 | "collapsed": true
115 | },
116 | "outputs": [],
117 | "source": [
118 | "# borrowed code from https://github.com/pytorch/examples/tree/master/word_language_model\n",
119 | "\n",
120 | "def batchify(data, bsz):\n",
121 | "    # Work out how cleanly we can divide the dataset into bsz parts.\n",
122 | "    nbatch = data.size(0) // bsz\n",
123 | "    # Trim off any extra elements that wouldn't cleanly fit (remainders).\n",
124 | "    data = data.narrow(0, 0, nbatch * bsz)\n",
125 | "    # Evenly divide the data across the bsz batches.\n",
126 | "    data = data.view(bsz, -1).contiguous()\n",
127 | "    if USE_CUDA:\n",
128 | "        data = data.cuda()\n",
129 | "    return data"
130 | ]
131 | },
132 | {
133 | "cell_type": "code",
134 | "execution_count": 176,
135 | "metadata": {
136 | "collapsed": true
137 | },
138 | "outputs": [],
139 | "source": [
140 | "def getBatch(data, seq_length):\n",
141 | "    for i in range(0, data.size(1) - seq_length, seq_length):\n",
142 | "        inputs = Variable(data[:, i: i + seq_length])\n",
143 | "        targets = Variable(data[:, (i + 1): (i + 1) + seq_length].contiguous())\n",
144 | "        yield (inputs, targets)"
145 | ]
146 | },
147 | {
148 | "cell_type": "code",
149 | "execution_count": 177,
150 | "metadata": {
151 | "collapsed": true
152 | },
153 | "outputs": [],
154 | "source": [
155 | "train_data, word2index = prepare_ptb_dataset('../dataset/ptb/ptb.train.txt')\n",
156 | "dev_data, _ = prepare_ptb_dataset('../dataset/ptb/ptb.valid.txt', word2index)\n",
157 | "test_data, _ = prepare_ptb_dataset('../dataset/ptb/ptb.test.txt', word2index)"
158 | ]
159 | },
160 | {
161 | "cell_type": "code",
162 | "execution_count": 178,
163 | "metadata": {
164 | "collapsed": false
165 | },
166 | "outputs": [
167 | {
168 | "data": {
169 | "text/plain": [
170 | "10000"
171 | ]
172 | },
173 | "execution_count": 178,
174 | "metadata": {},
175 | "output_type": "execute_result"
176 | }
177 | ],
178 | "source": [
179 | "len(word2index)"
180 | ]
181 | },
182 | {
183 | "cell_type": "code",
184 | "execution_count": 179,
185 | "metadata": {
186 | "collapsed": true
187 | },
188 | "outputs": [],
189 | "source": [
190 | "index2word = {v:k for k, v in word2index.items()}"
191 | ]
192 | },
193 | {
194 | "cell_type": "markdown",
195 | "metadata": {},
196 | "source": [
197 | "## Modeling "
198 | ]
199 | },
200 | {
201 | "cell_type": "markdown",
202 | "metadata": {},
203 | "source": [
204 | "[Figure: RNN language model architecture; see images/06.rnnlm-architecture.png]\n",
205 | "borrowed image from http://web.stanford.edu/class/cs224n/lectures/cs224n-2017-lecture8.pdf"
206 | ]
207 | },
208 | {
209 | "cell_type": "code",
210 | "execution_count": 180,
211 | "metadata": {
212 | "collapsed": true
213 | },
214 | "outputs": [],
215 | "source": [
216 | "class LanguageModel(nn.Module): \n",
217 | "    def __init__(self, vocab_size, embedding_size, hidden_size, n_layers=1, dropout_p=0.5):\n",
218 | "\n",
219 | "        super(LanguageModel, self).__init__()\n",
220 | "        self.n_layers = n_layers\n",
221 | "        self.hidden_size = hidden_size\n",
222 | "        self.embed = nn.Embedding(vocab_size, embedding_size)\n",
223 | "        self.rnn = nn.LSTM(embedding_size, hidden_size, n_layers, batch_first=True)\n",
224 | "        self.linear = nn.Linear(hidden_size, vocab_size)\n",
225 | "        self.dropout = nn.Dropout(dropout_p)\n",
226 | "        \n",
227 | "    def init_weight(self):\n",
228 | "        self.embed.weight = nn.init.xavier_uniform(self.embed.weight)\n",
229 | "        self.linear.weight = nn.init.xavier_uniform(self.linear.weight)\n",
230 | "        self.linear.bias.data.fill_(0)\n",
231 | "        \n",
232 | "    def init_hidden(self,batch_size):\n",
233 | "        hidden = Variable(torch.zeros(self.n_layers,batch_size,self.hidden_size))\n",
234 | "        context = Variable(torch.zeros(self.n_layers,batch_size,self.hidden_size))\n",
235 | "        return (hidden.cuda(), context.cuda()) if USE_CUDA else (hidden, context)\n",
236 | "    \n",
237 | "    def detach_hidden(self, hiddens):\n",
238 | "        return tuple([hidden.detach() for hidden in hiddens])\n",
239 | "    \n",
240 | "    def forward(self, inputs, hidden, is_training=False): \n",
241 | "\n",
242 | "        embeds = self.embed(inputs)\n",
243 | "        if is_training:\n",
244 | "            embeds = self.dropout(embeds)\n",
245 | "        out,hidden = self.rnn(embeds, hidden)\n",
246 | "        return self.linear(out.contiguous().view(out.size(0) * out.size(1), -1)), hidden"
247 | ]
248 | },
249 | {
250 | "cell_type": "markdown",
251 | "metadata": {},
252 | "source": [
253 | "## Train "
254 | ]
255 | },
256 | {
257 | "cell_type": "markdown",
258 | "metadata": {},
259 | "source": [
260 | "Training takes a while..."
261 | ]
262 | },
263 | {
264 | "cell_type": "code",
265 | "execution_count": 181,
266 | "metadata": {
267 | "collapsed": true
268 | },
269 | "outputs": [],
270 | "source": [
271 | "EMBED_SIZE = 128\n",
272 | "HIDDEN_SIZE = 1024\n",
273 | "NUM_LAYER = 1\n",
274 | "LR = 0.01\n",
275 | "SEQ_LENGTH = 30 # for bptt\n",
276 | "BATCH_SIZE = 20\n",
277 | "EPOCH = 40\n",
278 | "RESCHEDULED = False"
279 | ]
280 | },
281 | {
282 | "cell_type": "code",
283 | "execution_count": 182,
284 | "metadata": {
285 | "collapsed": true
286 | },
287 | "outputs": [],
288 | "source": [
289 | "train_data = batchify(train_data, BATCH_SIZE)\n",
290 | "dev_data = batchify(dev_data, BATCH_SIZE//2)\n",
291 | "test_data = batchify(test_data, BATCH_SIZE//2)"
292 | ]
293 | },
294 | {
295 | "cell_type": "code",
296 | "execution_count": 185,
297 | "metadata": {
298 | "collapsed": true
299 | },
300 | "outputs": [],
301 | "source": [
302 | "model = LanguageModel(len(word2index), EMBED_SIZE, HIDDEN_SIZE, NUM_LAYER, 0.5)\n",
303 | "model.init_weight() \n",
304 | "if USE_CUDA:\n",
305 | "    model = model.cuda()\n",
306 | "loss_function = nn.CrossEntropyLoss()\n",
307 | "optimizer = optim.Adam(model.parameters(), lr=LR)"
308 | ]
309 | },
310 | {
311 | "cell_type": "code",
312 | "execution_count": 186,
313 | "metadata": {
314 | "collapsed": false
315 | },
316 | "outputs": [
317 | {
318 | "name": "stdout",
319 | "output_type": "stream",
320 | "text": [
321 | "[00/40] mean_loss : 9.45, Perplexity : 12712.23\n",
322 | "[00/40] mean_loss : 5.88, Perplexity : 358.21\n",
323 | "[00/40] mean_loss : 5.55, Perplexity : 256.44\n",
324 | "[01/40] mean_loss : 5.38, Perplexity : 217.46\n",
325 | "[01/40] mean_loss : 5.21, Perplexity : 182.41\n",
326 | "[01/40] mean_loss : 5.10, Perplexity : 164.39\n",
327 | "[02/40] mean_loss : 5.08, Perplexity : 160.87\n",
328 | "[02/40] mean_loss : 4.99, Perplexity : 147.18\n",
329 | "[02/40] mean_loss : 4.92, Perplexity : 136.52\n",
330 | "[03/40] mean_loss : 4.92, Perplexity : 136.64\n",
331 | "[03/40] mean_loss : 4.86, Perplexity : 129.32\n",
332 | "[03/40] mean_loss : 4.80, Perplexity : 121.46\n",
333 | "[04/40] mean_loss : 4.80, Perplexity : 121.91\n",
334 | "[04/40] mean_loss : 4.77, Perplexity : 117.64\n",
335 | "[04/40] mean_loss : 4.71, Perplexity : 111.22\n",
336 | "[05/40] mean_loss : 4.72, Perplexity : 112.01\n",
337 | "[05/40] mean_loss : 4.70, Perplexity : 109.46\n",
338 | "[05/40] mean_loss : 4.64, Perplexity : 103.96\n",
339 | "[06/40] mean_loss : 4.66, Perplexity : 105.25\n",
340 | "[06/40] mean_loss : 4.64, Perplexity : 103.63\n",
341 | "[06/40] mean_loss : 4.60, Perplexity : 99.00\n",
342 | "[07/40] mean_loss : 4.60, Perplexity : 99.89\n",
343 | "[07/40] mean_loss : 4.59, Perplexity : 98.97\n",
344 | "[07/40] mean_loss : 4.55, Perplexity : 94.97\n",
345 | "[08/40] mean_loss : 4.56, Perplexity : 95.54\n",
346 | "[08/40] mean_loss : 4.56, Perplexity : 95.67\n",
347 | "[08/40] mean_loss : 4.52, Perplexity : 91.98\n",
348 | "[09/40] mean_loss : 4.53, Perplexity : 92.61\n",
349 | "[09/40] mean_loss : 4.53, Perplexity : 92.79\n",
350 | "[09/40] mean_loss : 4.50, Perplexity : 89.63\n",
351 | "[10/40] mean_loss : 4.50, Perplexity : 90.13\n",
352 | "[10/40] mean_loss : 4.50, Perplexity : 90.19\n",
353 | "[10/40] mean_loss : 4.47, Perplexity : 87.11\n",
354 | "[11/40] mean_loss : 4.48, Perplexity : 88.11\n",
355 | "[11/40] mean_loss : 4.48, Perplexity : 88.26\n",
356 | "[11/40] mean_loss : 4.45, Perplexity : 86.05\n",
357 | "[12/40] mean_loss : 4.46, Perplexity : 86.81\n",
358 | "[12/40] mean_loss : 4.47, Perplexity : 87.03\n",
359 | "[12/40] mean_loss : 4.43, Perplexity : 84.04\n",
360 | "[13/40] mean_loss : 4.45, Perplexity : 85.27\n",
361 | "[13/40] mean_loss : 4.45, Perplexity : 85.83\n",
362 | "[13/40] mean_loss : 4.42, Perplexity : 83.33\n",
363 | "[14/40] mean_loss : 4.43, Perplexity : 84.15\n",
364 | "[14/40] mean_loss : 4.43, Perplexity : 84.31\n",
365 | "[14/40] mean_loss : 4.41, Perplexity : 82.29\n",
366 | "[15/40] mean_loss : 4.43, Perplexity : 83.82\n",
367 | "[15/40] mean_loss : 4.43, Perplexity : 83.70\n",
368 | "[15/40] mean_loss : 4.40, Perplexity : 81.59\n",
369 | "[16/40] mean_loss : 4.42, Perplexity : 83.06\n",
370 | "[16/40] mean_loss : 4.42, Perplexity : 83.29\n",
371 | "[16/40] mean_loss : 4.39, Perplexity : 80.89\n",
372 | "[17/40] mean_loss : 4.41, Perplexity : 82.44\n",
373 | "[17/40] mean_loss : 4.41, Perplexity : 82.51\n",
374 | "[17/40] mean_loss : 4.39, Perplexity : 80.59\n",
375 | "[18/40] mean_loss : 4.40, Perplexity : 81.59\n",
376 | "[18/40] mean_loss : 4.41, Perplexity : 82.21\n",
377 | "[18/40] mean_loss : 4.38, Perplexity : 79.87\n",
378 | "[19/40] mean_loss : 4.40, Perplexity : 81.43\n",
379 | "[19/40] mean_loss : 4.40, Perplexity : 81.67\n",
380 | "[19/40] mean_loss : 4.37, Perplexity : 79.28\n",
381 | "[20/40] mean_loss : 4.40, Perplexity : 81.18\n",
382 | "[20/40] mean_loss : 4.40, Perplexity : 81.17\n",
383 | "[20/40] mean_loss : 4.37, Perplexity : 79.11\n",
384 | "[21/40] mean_loss : 4.40, Perplexity : 81.44\n",
385 | "[21/40] mean_loss : 4.34, Perplexity : 76.43\n",
386 | "[21/40] mean_loss : 4.21, Perplexity : 67.17\n",
387 | "[22/40] mean_loss : 4.26, Perplexity : 70.84\n",
388 | "[22/40] mean_loss : 4.26, Perplexity : 70.75\n",
389 | "[22/40] mean_loss : 4.17, Perplexity : 64.99\n",
390 | "[23/40] mean_loss : 4.22, Perplexity : 68.36\n",
391 | "[23/40] mean_loss : 4.22, Perplexity : 67.82\n",
392 | "[23/40] mean_loss : 4.15, Perplexity : 63.74\n",
393 | "[24/40] mean_loss : 4.20, Perplexity : 66.66\n",
394 | "[24/40] mean_loss : 4.20, Perplexity : 66.43\n",
395 | "[24/40] mean_loss : 4.14, Perplexity : 62.85\n",
396 | "[25/40] mean_loss : 4.18, Perplexity : 65.53\n",
397 | "[25/40] mean_loss : 4.17, Perplexity : 64.99\n",
398 | "[25/40] mean_loss : 4.13, Perplexity : 61.94\n",
399 | "[26/40] mean_loss : 4.17, Perplexity : 64.61\n",
400 | "[26/40] mean_loss : 4.16, Perplexity : 64.34\n",
401 | "[26/40] mean_loss : 4.12, Perplexity : 61.27\n",
402 | "[27/40] mean_loss : 4.15, Perplexity : 63.73\n",
403 | "[27/40] mean_loss : 4.15, Perplexity : 63.32\n",
404 | "[27/40] mean_loss : 4.11, Perplexity : 60.87\n",
405 | "[28/40] mean_loss : 4.14, Perplexity : 62.96\n",
406 | "[28/40] mean_loss : 4.14, Perplexity : 63.01\n",
407 | "[28/40] mean_loss : 4.10, Perplexity : 60.33\n",
408 | "[29/40] mean_loss : 4.14, Perplexity : 62.54\n",
409 | "[29/40] mean_loss : 4.13, Perplexity : 62.36\n",
410 | "[29/40] mean_loss : 4.10, Perplexity : 60.06\n",
411 | "[30/40] mean_loss : 4.13, Perplexity : 62.05\n",
412 | "[30/40] mean_loss : 4.13, Perplexity : 61.91\n",
413 | "[30/40] mean_loss : 4.09, Perplexity : 59.46\n",
414 | "[31/40] mean_loss : 4.12, Perplexity : 61.45\n",
415 | "[31/40] mean_loss : 4.11, Perplexity : 61.24\n",
416 | "[31/40] mean_loss : 4.08, Perplexity : 59.12\n",
417 | "[32/40] mean_loss : 4.11, Perplexity : 61.03\n",
418 | "[32/40] mean_loss : 4.11, Perplexity : 60.88\n",
419 | "[32/40] mean_loss : 4.07, Perplexity : 58.69\n",
420 | "[33/40] mean_loss : 4.11, Perplexity : 60.71\n",
421 | "[33/40] mean_loss : 4.10, Perplexity : 60.57\n",
422 | "[33/40] mean_loss : 4.07, Perplexity : 58.38\n",
423 | "[34/40] mean_loss : 4.10, Perplexity : 60.33\n",
424 | "[34/40] mean_loss : 4.10, Perplexity : 60.23\n",
425 | "[34/40] mean_loss : 4.06, Perplexity : 58.06\n",
426 | "[35/40] mean_loss : 4.09, Perplexity : 60.00\n",
427 | "[35/40] mean_loss : 4.09, Perplexity : 59.74\n",
428 | "[35/40] mean_loss : 4.06, Perplexity : 57.75\n",
429 | "[36/40] mean_loss : 4.09, Perplexity : 59.58\n",
430 | "[36/40] mean_loss : 4.09, Perplexity : 59.47\n",
431 | "[36/40] mean_loss : 4.05, Perplexity : 57.59\n",
432 | "[37/40] mean_loss : 4.08, Perplexity : 59.30\n",
433 | "[37/40] mean_loss : 4.08, Perplexity : 59.11\n",
434 | "[37/40] mean_loss : 4.05, Perplexity : 57.11\n",
435 | "[38/40] mean_loss : 4.08, Perplexity : 58.98\n",
436 | "[38/40] mean_loss : 4.07, Perplexity : 58.70\n",
437 | "[38/40] mean_loss : 4.04, Perplexity : 57.10\n",
438 | "[39/40] mean_loss : 4.07, Perplexity : 58.79\n",
439 | "[39/40] mean_loss : 4.07, Perplexity : 58.58\n",
440 | "[39/40] mean_loss : 4.04, Perplexity : 56.79\n"
441 | ]
442 | }
443 | ],
444 | "source": [
445 | "for epoch in range(EPOCH):\n",
446 | "    total_loss = 0\n",
447 | "    losses = []\n",
448 | "    hidden = model.init_hidden(BATCH_SIZE)\n",
449 | "    for i,batch in enumerate(getBatch(train_data, SEQ_LENGTH)):\n",
450 | "        inputs, targets = batch\n",
451 | "        hidden = model.detach_hidden(hidden)\n",
452 | "        model.zero_grad()\n",
453 | "        preds, hidden = model(inputs, hidden, True)\n",
454 | "\n",
455 | "        loss = loss_function(preds, targets.view(-1))\n",
456 | "        losses.append(loss.data[0])\n",
457 | "        loss.backward()\n",
458 | "        torch.nn.utils.clip_grad_norm(model.parameters(), 0.5) # gradient clipping\n",
459 | "        optimizer.step()\n",
460 | "\n",
461 | "        if i > 0 and i % 500 == 0:\n",
462 | "            print(\"[%02d/%d] mean_loss : %0.2f, Perplexity : %0.2f\" % (epoch,EPOCH, np.mean(losses), np.exp(np.mean(losses))))\n",
463 | "            losses = []\n",
464 | "    \n",
465 | "    # learning rate annealing: divide LR by 10 once, halfway through training\n",
466 | "    # Alternatively, use a scheduler: http://pytorch.org/docs/master/optim.html#how-to-adjust-learning-rate\n",
467 | "    if not RESCHEDULED and epoch == EPOCH//2:\n",
468 | "        LR *= 0.1\n",
469 | "        optimizer = optim.Adam(model.parameters(), lr=LR)\n",
470 | "        RESCHEDULED = True"
471 | ]
472 | },
473 | {
474 | "cell_type": "markdown",
475 | "metadata": {},
476 | "source": [
477 | "### Test "
478 | ]
479 | },
480 | {
481 | "cell_type": "code",
482 | "execution_count": 189,
483 | "metadata": {
484 | "collapsed": false
485 | },
486 | "outputs": [
487 | {
488 | "name": "stdout",
489 | "output_type": "stream",
490 | "text": [
491 | "Test Perplexity : 155.89\n"
492 | ]
493 | }
494 | ],
495 | "source": [
496 | "total_loss = 0\n",
497 | "hidden = model.init_hidden(BATCH_SIZE//2)\n",
498 | "for batch in getBatch(test_data, SEQ_LENGTH):\n",
499 | "    inputs,targets = batch\n",
500 | "    \n",
501 | "    hidden = model.detach_hidden(hidden)\n",
502 | "    model.zero_grad()\n",
503 | "    preds, hidden = model(inputs, hidden)\n",
504 | "    total_loss += inputs.size(1) * loss_function(preds, targets.view(-1)).data\n",
505 | "\n",
506 | "total_loss = total_loss[0]/test_data.size(1)\n",
507 | "print(\"Test Perplexity : %5.2f\" % (np.exp(total_loss)))"
508 | ]
509 | },
510 | {
511 | "cell_type": "markdown",
512 | "metadata": {
513 | "collapsed": true
514 | },
515 | "source": [
516 | "## Further topics"
517 | ]
518 | },
519 | {
520 | "cell_type": "markdown",
521 | "metadata": {},
522 | "source": [
523 | "* Pointer Sentinel Mixture Models\n",
524 | "* Regularizing and Optimizing LSTM Language Models"
525 | ]
526 | },
527 | {
528 | "cell_type": "code",
529 | "execution_count": null,
530 | "metadata": {
531 | "collapsed": true
532 | },
533 | "outputs": [],
534 | "source": []
535 | }
536 | ],
537 | "metadata": {
538 | "kernelspec": {
539 | "display_name": "Python 3",
540 | "language": "python",
541 | "name": "python3"
542 | },
543 | "language_info": {
544 | "codemirror_mode": {
545 | "name": "ipython",
546 | "version": 3
547 | },
548 | "file_extension": ".py",
549 | "mimetype": "text/x-python",
550 | "name": "python",
551 | "nbconvert_exporter": "python",
552 | "pygments_lexer": "ipython3",
553 | "version": "3.5.2"
554 | }
555 | },
556 | "nbformat": 4,
557 | "nbformat_minor": 2
558 | }
559 |
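
The perplexity values printed by the training and test cells of the notebook above are the exponential of the mean cross-entropy loss, and the test cell weights each batch loss by its sequence length before averaging. A minimal, framework-free sketch of that arithmetic follows; the function name `perplexity` and its two list arguments are illustrative assumptions, not part of the notebook.

```python
import math


def perplexity(batch_losses, batch_lengths):
    """Length-weighted perplexity from per-batch mean losses (in nats).

    batch_losses  : mean cross-entropy of each batch (hypothetical input)
    batch_lengths : number of time steps in each batch (hypothetical input)
    """
    total_steps = sum(batch_lengths)
    # Weight every batch loss by its length, as the test cell does
    weighted_sum = sum(loss * n for loss, n in zip(batch_losses, batch_lengths))
    return math.exp(weighted_sum / total_steps)


# A mean loss of 4.40 corresponds to a perplexity of roughly 81,
# which matches the magnitude of the values in the training log.
print(perplexity([4.40], [30]))
```

Weighting by length matters only when the last batch is shorter than the others; with equal-length batches it reduces to the plain mean used in the training loop.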
--------------------------------------------------------------------------------
/notebooks/08.CNN-for-Text-Classification.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "# 8. Convolutional Neural Networks"
8 | ]
9 | },
10 | {
11 | "cell_type": "markdown",
12 | "metadata": {},
13 | "source": [
14 | "I recommend you take a look at these materials first."
15 | ]
16 | },
17 | {
18 | "cell_type": "markdown",
19 | "metadata": {},
20 | "source": [
21 | "* http://web.stanford.edu/class/cs224n/lectures/cs224n-2017-lecture13-CNNs.pdf\n",
22 | "* http://www.aclweb.org/anthology/D14-1181\n",
23 | "* https://github.com/Shawn1993/cnn-text-classification-pytorch\n",
24 | "* http://cogcomp.org/Data/QA/QC/"
25 | ]
26 | },
27 | {
28 | "cell_type": "code",
29 | "execution_count": 58,
30 | "metadata": {
31 | "collapsed": true
32 | },
33 | "outputs": [],
34 | "source": [
35 | "import torch\n",
36 | "import torch.nn as nn\n",
37 | "from torch.autograd import Variable\n",
38 | "import torch.optim as optim\n",
39 | "import torch.nn.functional as F\n",
40 | "import nltk\n",
41 | "import random\n",
42 | "import numpy as np\n",
43 | "from collections import Counter, OrderedDict\n",
45 | "import re\n",
46 | "from copy import deepcopy\n",
47 | "flatten = lambda l: [item for sublist in l for item in sublist]\n",
48 | "random.seed(1024)"
49 | ]
50 | },
51 | {
52 | "cell_type": "code",
53 | "execution_count": 2,
54 | "metadata": {
55 | "collapsed": true
56 | },
57 | "outputs": [],
58 | "source": [
59 | "USE_CUDA = torch.cuda.is_available()\n",
60 | "gpus = [0]\n",
61 | "if USE_CUDA:\n",
62 | "    torch.cuda.set_device(gpus[0]) # only select a GPU when CUDA is available\n",
63 | "FloatTensor = torch.cuda.FloatTensor if USE_CUDA else torch.FloatTensor\n",
64 | "LongTensor = torch.cuda.LongTensor if USE_CUDA else torch.LongTensor\n",
65 | "ByteTensor = torch.cuda.ByteTensor if USE_CUDA else torch.ByteTensor"
66 | ]
67 | },
68 | {
69 | "cell_type": "code",
70 | "execution_count": 3,
71 | "metadata": {
72 | "collapsed": true
73 | },
74 | "outputs": [],
75 | "source": [
76 | "def getBatch(batch_size, train_data):\n",
77 | "    random.shuffle(train_data)\n",
78 | "    sindex = 0\n",
79 | "    eindex = batch_size\n",
80 | "    while eindex < len(train_data):\n",
81 | "        batch = train_data[sindex: eindex]\n",
82 | "        temp = eindex\n",
83 | "        eindex = eindex + batch_size\n",
84 | "        sindex = temp\n",
85 | "        yield batch\n",
86 | "    \n",
87 | "    if eindex >= len(train_data):\n",
88 | "        batch = train_data[sindex:]\n",
89 | "        yield batch"
90 | ]
91 | },
92 | {
93 | "cell_type": "code",
94 | "execution_count": 110,
95 | "metadata": {
96 | "collapsed": true
97 | },
98 | "outputs": [],
99 | "source": [
100 | "def pad_to_batch(batch):\n",
101 | "    x,y = zip(*batch)\n",
102 | "    max_x = max([s.size(1) for s in x])\n",
103 | "    x_p = []\n",
104 | "    for i in range(len(batch)):\n",
105 | "        if x[i].size(1) < max_x:\n",
106 | "            x_p.append(torch.cat([x[i], Variable(LongTensor([word2index['<PAD>']] * (max_x - x[i].size(1)))).view(1, -1)], 1))\n",
107 | "        else:\n",
108 | "            x_p.append(x[i])\n",
109 | "    return torch.cat(x_p), torch.cat(y).view(-1)"
110 | ]
111 | },
112 | {
113 | "cell_type": "code",
114 | "execution_count": 20,
115 | "metadata": {
116 | "collapsed": true
117 | },
118 | "outputs": [],
119 | "source": [
120 | "def prepare_sequence(seq, to_index):\n",
121 | "    idxs = list(map(lambda w: to_index[w] if to_index.get(w) is not None else to_index[\"<UNK>\"], seq))\n",
122 | "    return Variable(LongTensor(idxs))"
123 | ]
124 | },
125 | {
126 | "cell_type": "markdown",
127 | "metadata": {},
128 | "source": [
129 | "## Data load & Preprocessing"
130 | ]
131 | },
132 | {
133 | "cell_type": "markdown",
134 | "metadata": {},
135 | "source": [
136 | "### TREC question dataset (http://cogcomp.org/Data/QA/QC/)"
137 | ]
138 | },
139 | {
140 | "cell_type": "markdown",
141 | "metadata": {},
142 | "source": [
143 | "The task is to\n",
144 | "classify a question into one of 6 question\n",
145 | "types (whether the question is about a person,\n",
146 | "a location, numeric information, etc.)."
147 | ]
148 | },
149 | {
150 | "cell_type": "code",
151 | "execution_count": 53,
152 | "metadata": {
153 | "collapsed": true
154 | },
155 | "outputs": [],
156 | "source": [
157 | "data = open('../dataset/train_5500.label.txt', 'r', encoding='latin-1').readlines()"
158 | ]
159 | },
160 | {
161 | "cell_type": "code",
162 | "execution_count": 54,
163 | "metadata": {
164 | "collapsed": true
165 | },
166 | "outputs": [],
167 | "source": [
168 | "data = [[d.split(':')[1][:-1], d.split(':')[0]] for d in data]"
169 | ]
170 | },
171 | {
172 | "cell_type": "code",
173 | "execution_count": 61,
174 | "metadata": {
175 | "collapsed": true
176 | },
177 | "outputs": [],
178 | "source": [
179 | "X, y = list(zip(*data))\n",
180 | "X = list(X)"
181 | ]
182 | },
183 | {
184 | "cell_type": "markdown",
185 | "metadata": {},
186 | "source": [
187 | "### Num masking "
188 | ]
189 | },
190 | {
191 | "cell_type": "markdown",
192 | "metadata": {},
193 | "source": [
194 | "Replacing every digit with `#` shrinks the vocabulary, e.g. `my birthday is 12.22` becomes `my birthday is ##.##`."
195 | ]
196 | },
197 | {
198 | "cell_type": "code",
199 | "execution_count": 62,
200 | "metadata": {
201 | "collapsed": true
202 | },
203 | "outputs": [],
204 | "source": [
205 | "for i, x in enumerate(X):\n",
206 | "    X[i] = re.sub('\\d', '#', x).split()"
207 | ]
208 | },
209 | {
210 | "cell_type": "markdown",
211 | "metadata": {},
212 | "source": [
213 | "### Build Vocab "
214 | ]
215 | },
216 | {
217 | "cell_type": "code",
218 | "execution_count": 63,
219 | "metadata": {
220 | "collapsed": true
221 | },
222 | "outputs": [],
223 | "source": [
224 | "vocab = list(set(flatten(X)))"
225 | ]
226 | },
227 | {
228 | "cell_type": "code",
229 | "execution_count": 64,
230 | "metadata": {
231 | "collapsed": false
232 | },
233 | "outputs": [
234 | {
235 | "data": {
236 | "text/plain": [
237 | "9117"
238 | ]
239 | },
240 | "execution_count": 64,
241 | "metadata": {},
242 | "output_type": "execute_result"
243 | }
244 | ],
245 | "source": [
246 | "len(vocab)"
247 | ]
248 | },
249 | {
250 | "cell_type": "code",
251 | "execution_count": 31,
252 | "metadata": {
253 | "collapsed": false
254 | },
255 | "outputs": [
256 | {
257 | "data": {
258 | "text/plain": [
259 | "6"
260 | ]
261 | },
262 | "execution_count": 31,
263 | "metadata": {},
264 | "output_type": "execute_result"
265 | }
266 | ],
267 | "source": [
268 | "len(set(y)) # number of classes"
269 | ]
270 | },
271 | {
272 | "cell_type": "code",
273 | "execution_count": 94,
274 | "metadata": {
275 | "collapsed": true
276 | },
277 | "outputs": [],
278 | "source": [
279 | "word2index={'<PAD>': 0, '<UNK>': 1}\n",
280 | "\n",
281 | "for vo in vocab:\n",
282 | "    if word2index.get(vo) is None:\n",
283 | "        word2index[vo] = len(word2index)\n",
284 | "    \n",
285 | "index2word = {v:k for k, v in word2index.items()}\n",
286 | "\n",
287 | "target2index = {}\n",
288 | "\n",
289 | "for cl in set(y):\n",
290 | "    if target2index.get(cl) is None:\n",
291 | "        target2index[cl] = len(target2index)\n",
292 | "\n",
293 | "index2target = {v:k for k, v in target2index.items()}"
294 | ]
295 | },
296 | {
297 | "cell_type": "code",
298 | "execution_count": 95,
299 | "metadata": {
300 | "collapsed": true
301 | },
302 | "outputs": [],
303 | "source": [
304 | "X_p, y_p = [], []\n",
305 | "for pair in zip(X,y):\n",
306 | "    X_p.append(prepare_sequence(pair[0], word2index).view(1, -1))\n",
307 | "    y_p.append(Variable(LongTensor([target2index[pair[1]]])).view(1, -1))\n",
308 | " \n",
309 | "data_p = list(zip(X_p, y_p))\n",
310 | "random.shuffle(data_p)\n",
311 | "\n",
312 | "train_data = data_p[: int(len(data_p) * 0.9)]\n",
313 | "test_data = data_p[int(len(data_p) * 0.9):]"
314 | ]
315 | },
316 | {
317 | "cell_type": "markdown",
318 | "metadata": {},
319 | "source": [
320 | "### Load pretrained word vectors"
321 | ]
322 | },
323 | {
324 | "cell_type": "markdown",
325 | "metadata": {},
326 | "source": [
327 | "You can download the pretrained word vectors from https://github.com/mmihaltz/word2vec-GoogleNews-vectors"
328 | ]
329 | },
330 | {
331 | "cell_type": "code",
332 | "execution_count": 41,
333 | "metadata": {
334 | "collapsed": true
335 | },
336 | "outputs": [],
337 | "source": [
338 | "import gensim"
339 | ]
340 | },
341 | {
342 | "cell_type": "code",
343 | "execution_count": 43,
344 | "metadata": {
345 | "collapsed": true
346 | },
347 | "outputs": [],
348 | "source": [
349 | "model = gensim.models.KeyedVectors.load_word2vec_format('../dataset/GoogleNews-vectors-negative300.bin', binary=True)"
350 | ]
351 | },
352 | {
353 | "cell_type": "code",
354 | "execution_count": 48,
355 | "metadata": {
356 | "collapsed": false
357 | },
358 | "outputs": [
359 | {
360 | "data": {
361 | "text/plain": [
362 | "3000000"
363 | ]
364 | },
365 | "execution_count": 48,
366 | "metadata": {},
367 | "output_type": "execute_result"
368 | }
369 | ],
370 | "source": [
371 | "len(model.index2word)"
372 | ]
373 | },
374 | {
375 | "cell_type": "code",
376 | "execution_count": 96,
377 | "metadata": {
378 | "collapsed": true
379 | },
380 | "outputs": [],
381 | "source": [
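382 | "pretrained = []\n",
383 | "# iterate in index order so that row i holds the vector for word index i\n",
384 | "for idx in range(len(index2word)):\n",
385 | "    try:\n",
386 | "        pretrained.append(model[index2word[idx]]) # look up by word, not by integer index\n",
387 | "    except KeyError: # word missing from the pretrained vocabulary\n",
388 | "        pretrained.append(np.random.randn(300))\n",
389 | "    \n",
390 | "pretrained_vectors = np.vstack(pretrained)"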
391 | ]
392 | },
393 | {
394 | "cell_type": "markdown",
395 | "metadata": {},
396 | "source": [
397 | "## Modeling "
398 | ]
399 | },
400 | {
401 | "cell_type": "markdown",
402 | "metadata": {},
403 | "source": [
404 | "<img src=\"../images/08.cnn-for-text-architecture.png\">\n",
405 | "Image borrowed from http://www.aclweb.org/anthology/D14-1181"
406 | ]
407 | },
408 | {
409 | "cell_type": "code",
410 | "execution_count": 117,
411 | "metadata": {
412 | "collapsed": true
413 | },
414 | "outputs": [],
415 | "source": [
416 | "class CNNClassifier(nn.Module):\n",
417 | "    \n",
418 | "    def __init__(self, vocab_size, embedding_dim, output_size, kernel_dim=100, kernel_sizes=(3, 4, 5), dropout=0.5):\n",
419 | "        super(CNNClassifier,self).__init__()\n",
420 | "\n",
421 | "        self.embedding = nn.Embedding(vocab_size, embedding_dim)\n",
422 | "        self.convs = nn.ModuleList([nn.Conv2d(1, kernel_dim, (K, embedding_dim)) for K in kernel_sizes])\n",
423 | "\n",
424 | "        # kernel_size = (K, D)\n",
425 | "        self.dropout = nn.Dropout(dropout)\n",
426 | "        self.fc = nn.Linear(len(kernel_sizes) * kernel_dim, output_size)\n",
427 | "    \n",
428 | "    \n",
429 | "    def init_weights(self, pretrained_word_vectors, is_static=False):\n",
430 | "        self.embedding.weight = nn.Parameter(torch.from_numpy(pretrained_word_vectors).float())\n",
431 | "        if is_static:\n",
432 | "            self.embedding.weight.requires_grad = False\n",
433 | "\n",
434 | "\n",
435 | "    def forward(self, inputs, is_training=False):\n",
436 | "        inputs = self.embedding(inputs).unsqueeze(1) # (B,1,T,D)\n",
437 | "        inputs = [F.relu(conv(inputs)).squeeze(3) for conv in self.convs] #[(N,Co,W), ...]*len(Ks)\n",
438 | "        inputs = [F.max_pool1d(i, i.size(2)).squeeze(2) for i in inputs] #[(N,Co), ...]*len(Ks)\n",
439 | "\n",
440 | "        concated = torch.cat(inputs, 1)\n",
441 | "\n",
442 | "        if is_training:\n",
443 | "            concated = self.dropout(concated) # (N,len(Ks)*Co)\n",
444 | "        out = self.fc(concated) \n",
445 | "        return F.log_softmax(out,1)"
446 | ]
447 | },
448 | {
449 | "cell_type": "markdown",
450 | "metadata": {},
451 | "source": [
452 | "## Train "
453 | ]
454 | },
455 | {
456 | "cell_type": "markdown",
457 | "metadata": {},
458 | "source": [
459 | "Training takes a while if you use only the CPU."
460 | ]
461 | },
462 | {
463 | "cell_type": "code",
464 | "execution_count": 145,
465 | "metadata": {
466 | "collapsed": true
467 | },
468 | "outputs": [],
469 | "source": [
470 | "EPOCH = 5\n",
471 | "BATCH_SIZE = 50\n",
472 | "KERNEL_SIZES = [3,4,5]\n",
473 | "KERNEL_DIM = 100\n",
474 | "LR = 0.001"
475 | ]
476 | },
477 | {
478 | "cell_type": "code",
479 | "execution_count": 146,
480 | "metadata": {
481 | "collapsed": true
482 | },
483 | "outputs": [],
484 | "source": [
485 | "model = CNNClassifier(len(word2index), 300, len(target2index), KERNEL_DIM, KERNEL_SIZES)\n",
486 | "model.init_weights(pretrained_vectors) # initialize embedding matrix using pretrained vectors\n",
487 | "\n",
488 | "if USE_CUDA:\n",
489 | "    model = model.cuda()\n",
490 | "    \n",
491 | "loss_function = nn.NLLLoss() # the model already returns log-probabilities (log_softmax)\n",
492 | "optimizer = optim.Adam(model.parameters(), lr=LR)"
493 | ]
494 | },
495 | {
496 | "cell_type": "code",
497 | "execution_count": 147,
498 | "metadata": {
499 | "collapsed": false
500 | },
501 | "outputs": [
502 | {
503 | "name": "stdout",
504 | "output_type": "stream",
505 | "text": [
506 | "[0/5] mean_loss : 2.13\n",
507 | "[1/5] mean_loss : 0.12\n",
508 | "[2/5] mean_loss : 0.08\n",
509 | "[3/5] mean_loss : 0.02\n",
510 | "[4/5] mean_loss : 0.05\n"
511 | ]
512 | }
513 | ],
514 | "source": [
515 | "for epoch in range(EPOCH):\n",
516 | "    losses = []\n",
517 | "    for i,batch in enumerate(getBatch(BATCH_SIZE, train_data)):\n",
518 | "        inputs,targets = pad_to_batch(batch)\n",
519 | "        \n",
520 | "        model.zero_grad()\n",
521 | "        preds = model(inputs, True)\n",
522 | "        \n",
523 | "        loss = loss_function(preds, targets)\n",
524 | "        losses.append(loss.data.tolist()[0])\n",
525 | "        loss.backward()\n",
526 | "        \n",
527 | "        #for param in model.parameters():\n",
528 | "        #    param.grad.data.clamp_(-3, 3)\n",
529 | "        \n",
530 | "        optimizer.step()\n",
531 | "        \n",
532 | "        if i % 100 == 0:\n",
533 | "            print(\"[%d/%d] mean_loss : %0.2f\" %(epoch, EPOCH, np.mean(losses)))\n",
534 | "            losses = []"
535 | ]
536 | },
537 | {
538 | "cell_type": "markdown",
539 | "metadata": {},
540 | "source": [
541 | "## Test "
542 | ]
543 | },
544 | {
545 | "cell_type": "code",
546 | "execution_count": 150,
547 | "metadata": {
548 | "collapsed": true
549 | },
550 | "outputs": [],
551 | "source": [
552 | "accuracy = 0"
553 | ]
554 | },
555 | {
556 | "cell_type": "code",
557 | "execution_count": 151,
558 | "metadata": {
559 | "collapsed": false
560 | },
561 | "outputs": [
562 | {
563 | "name": "stdout",
564 | "output_type": "stream",
565 | "text": [
566 | "97.61904761904762\n"
567 | ]
568 | }
569 | ],
570 | "source": [
571 | "for test in test_data:\n",
572 | "    pred = model(test[0]).max(1)[1]\n",
573 | "    pred = pred.data.tolist()[0]\n",
574 | "    target = test[1].data.tolist()[0][0]\n",
575 | "    if pred == target:\n",
576 | "        accuracy += 1\n",
577 | "\n",
578 | "print(accuracy/len(test_data) * 100)"
579 | ]
580 | },
581 | {
582 | "cell_type": "markdown",
583 | "metadata": {
584 | "collapsed": true
585 | },
586 | "source": [
587 | "## Further topics "
588 | ]
589 | },
590 | {
591 | "cell_type": "markdown",
592 | "metadata": {},
593 | "source": [
594 | "* Character-Aware Neural Language Models\n",
595 | "* Character level CNN for text classification"
596 | ]
597 | },
598 | {
599 | "cell_type": "markdown",
600 | "metadata": {},
601 | "source": [
602 | "## Suggested Reading"
603 | ]
604 | },
605 | {
606 | "cell_type": "markdown",
607 | "metadata": {},
608 | "source": [
609 | "* https://blog.statsbot.co/text-classifier-algorithms-in-machine-learning-acc115293278\n",
610 | "* Bag of Tricks for Efficient Text Classification\n",
611 | "* Which Encoding is the Best for Text Classification in Chinese, English, Japanese and Korean?"
612 | ]
613 | },
614 | {
615 | "cell_type": "code",
616 | "execution_count": null,
617 | "metadata": {
618 | "collapsed": true
619 | },
620 | "outputs": [],
621 | "source": []
622 | }
623 | ],
624 | "metadata": {
625 | "kernelspec": {
626 | "display_name": "Python 3",
627 | "language": "python",
628 | "name": "python3"
629 | },
630 | "language_info": {
631 | "codemirror_mode": {
632 | "name": "ipython",
633 | "version": 3
634 | },
635 | "file_extension": ".py",
636 | "mimetype": "text/x-python",
637 | "name": "python",
638 | "nbconvert_exporter": "python",
639 | "pygments_lexer": "ipython3",
640 | "version": "3.5.2"
641 | }
642 | },
643 | "nbformat": 4,
644 | "nbformat_minor": 2
645 | }
646 |
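
The preprocessing steps of the notebook above (digit masking, vocabulary building with reserved padding and unknown indices, and padding each batch to a common length) can be sketched without PyTorch as follows. This is a simplified illustration under stated assumptions: the helper names `mask_digits`, `build_vocab` and `encode_and_pad` are hypothetical, the token strings `<PAD>` and `<UNK>` are assumed names for the two reserved entries, and plain lists stand in for the notebook's `LongTensor` variables.

```python
import re

PAD, UNK = "<PAD>", "<UNK>"  # assumed names for the two reserved tokens


def mask_digits(sentence):
    """Replace every digit with '#' and split on whitespace."""
    return re.sub(r"\d", "#", sentence).split()


def build_vocab(tokenized_sentences):
    """Give each distinct token an index; 0 and 1 are reserved for PAD and UNK."""
    word2index = {PAD: 0, UNK: 1}
    for sentence in tokenized_sentences:
        for word in sentence:
            if word not in word2index:
                word2index[word] = len(word2index)
    return word2index


def encode_and_pad(tokenized_sentences, word2index):
    """Map tokens to indices and right-pad every sentence to the longest one."""
    max_len = max(len(s) for s in tokenized_sentences)
    batch = []
    for sentence in tokenized_sentences:
        # Unknown words fall back to the UNK index
        ids = [word2index.get(w, word2index[UNK]) for w in sentence]
        batch.append(ids + [word2index[PAD]] * (max_len - len(ids)))
    return batch


sentences = [mask_digits("my birthday is 12.22"), mask_digits("who are you")]
vocab = build_vocab(sentences)
print(encode_and_pad(sentences, vocab))
```

Unlike `list(set(...))`, which the notebook uses to collect the vocabulary, this sketch assigns indices in first-seen order, so the mapping is the same on every run.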
--------------------------------------------------------------------------------
/notebooks/09.Recursive-NN-for-Sentiment-Classification.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "# 9. Recursive Neural Networks and Constituency Parsing"
8 | ]
9 | },
10 | {
11 | "cell_type": "markdown",
12 | "metadata": {},
13 | "source": [
14 | "I recommend you take a look at these materials first."
15 | ]
16 | },
17 | {
18 | "cell_type": "markdown",
19 | "metadata": {},
20 | "source": [
21 | "* http://web.stanford.edu/class/cs224n/lectures/cs224n-2017-lecture14-TreeRNNs.pdf\n",
22 | "* https://nlp.stanford.edu/~socherr/EMNLP2013_RNTN.pdf"
23 | ]
24 | },
25 | {
26 | "cell_type": "code",
27 | "execution_count": 4,
28 | "metadata": {
29 | "collapsed": true
30 | },
31 | "outputs": [],
32 | "source": [
33 | "import torch\n",
34 | "import torch.nn as nn\n",
35 | "from torch.autograd import Variable\n",
36 | "import torch.optim as optim\n",
37 | "import torch.nn.functional as F\n",
38 | "import nltk\n",
39 | "import random\n",
40 | "import numpy as np\n",
41 | "from collections import Counter, OrderedDict\n",
43 | "from copy import deepcopy\n",
44 | "import os\n",
45 | "from IPython.display import Image, display\n",
46 | "from nltk.draw import TreeWidget\n",
47 | "from nltk.draw.util import CanvasFrame\n",
48 | "from nltk.tree import Tree as nltkTree\n",
49 | "flatten = lambda l: [item for sublist in l for item in sublist]\n",
50 | "random.seed(1024)"
51 | ]
52 | },
53 | {
54 | "cell_type": "code",
55 | "execution_count": 2,
56 | "metadata": {
57 | "collapsed": true
58 | },
59 | "outputs": [],
60 | "source": [
61 | "USE_CUDA = torch.cuda.is_available()\n",
62 | "gpus = [0]\n",
63 | "if USE_CUDA:\n",
64 | "    torch.cuda.set_device(gpus[0]) # only select a GPU when CUDA is available\n",
65 | "FloatTensor = torch.cuda.FloatTensor if USE_CUDA else torch.FloatTensor\n",
66 | "LongTensor = torch.cuda.LongTensor if USE_CUDA else torch.LongTensor\n",
67 | "ByteTensor = torch.cuda.ByteTensor if USE_CUDA else torch.ByteTensor"
68 | ]
69 | },
70 | {
71 | "cell_type": "code",
72 | "execution_count": 3,
73 | "metadata": {
74 | "collapsed": true
75 | },
76 | "outputs": [],
77 | "source": [
78 | "def getBatch(batch_size, train_data):\n",
79 | "    random.shuffle(train_data)\n",
80 | "    sindex = 0\n",
81 | "    eindex = batch_size\n",
82 | "    while eindex < len(train_data):\n",
83 | "        batch = train_data[sindex: eindex]\n",
84 | "        temp = eindex\n",
85 | "        eindex = eindex + batch_size\n",
86 | "        sindex = temp\n",
87 | "        yield batch\n",
88 | "    \n",
89 | "    if eindex >= len(train_data):\n",
90 | "        batch = train_data[sindex:]\n",
91 | "        yield batch"
92 | ]
93 | },
94 | {
95 | "cell_type": "code",
96 | "execution_count": 5,
97 | "metadata": {
98 | "collapsed": true
99 | },
100 | "outputs": [],
101 | "source": [
102 | "# Borrowed from https://stackoverflow.com/questions/31779707/how-do-you-make-nltk-draw-trees-that-are-inline-in-ipython-jupyter\n",
103 | "\n",
104 | "def draw_nltk_tree(tree):\n",
105 | "    cf = CanvasFrame()\n",
106 | "    tc = TreeWidget(cf.canvas(), tree)\n",
107 | "    tc['node_font'] = 'arial 15 bold'\n",
108 | "    tc['leaf_font'] = 'arial 15'\n",
109 | "    tc['node_color'] = '#005990'\n",
110 | "    tc['leaf_color'] = '#3F8F57'\n",
111 | "    tc['line_color'] = '#175252'\n",
112 | "    cf.add_widget(tc, 50, 50)\n",
113 | "    cf.print_to_file('tmp_tree_output.ps')\n",
114 | "    cf.destroy()\n",
115 | "    os.system('convert tmp_tree_output.ps tmp_tree_output.png')\n",
116 | "    display(Image(filename='tmp_tree_output.png'))\n",
117 | "    os.system('rm tmp_tree_output.ps tmp_tree_output.png')"
118 | ]
119 | },
120 | {
121 | "cell_type": "markdown",
122 | "metadata": {},
123 | "source": [
124 | "## Data load and Preprocessing"
125 | ]
126 | },
127 | {
128 | "cell_type": "markdown",
129 | "metadata": {},
130 | "source": [
131 | "### Stanford Sentiment Treebank (https://nlp.stanford.edu/sentiment/index.html)"
132 | ]
133 | },
134 | {
135 | "cell_type": "code",
136 | "execution_count": 10,
137 | "metadata": {
138 | "collapsed": false
139 | },
140 | "outputs": [
141 | {
142 | "name": "stdout",
143 | "output_type": "stream",
144 | "text": [
145 | "(3 (2 (1 Deflated) (2 (2 ending) (2 aside))) (4 (2 ,) (4 (2 there) (3 (3 (2 's) (3 (2 much) (2 (2 to) (3 (3 recommend) (2 (2 the) (2 film)))))) (2 .)))))\n",
146 | "\n"
147 | ]
148 | }
149 | ],
150 | "source": [
151 | "sample = random.choice(open('../dataset/trees/train.txt', 'r', encoding='utf-8').readlines())\n",
152 | "print(sample)"
153 | ]
154 | },
155 | {
156 | "cell_type": "code",
157 | "execution_count": 11,
158 | "metadata": {
159 | "collapsed": false
160 | },
161 | "outputs": [
162 | {
163 | "data": {
164 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAhEAAAGVCAMAAAB3iVNzAAAJJGlDQ1BpY2MAAHjalZVnUJNZF8fv\n8zzphUASQodQQ5EqJYCUEFoo0quoQOidUEVsiLgCK4qINEUQUUDBVSmyVkSxsCgoYkE3yCKgrBtX\nERWUF/Sd0Xnf2Q/7n7n3/OY/Z+4995wPFwCCOFgSvLQnJqULvJ3smIFBwUzwg8L4aSkcT0838I96\nPwyg5XhvBfj3IkREpvGX4sLSyuWnCNIBgLKXWDMrPWWZDy8xPTz+K59dZsFSgUt8Y5mjv/Ho15xv\nLPqa4+vNXXoVCgAcKfoHDv+B/3vvslQ4gvTYqMhspk9yVHpWmCCSmbbcCR6Xy/QUJEfFJkT+UPC/\nSv4HpUdmpy9HbnLKBkFsdEw68/8ONTIwNATfZ/HW62uPIUb//85nWd+95HoA2LMAIHu+e+GVAHTu\nAED68XdPbamvlHwAOu7wMwSZ3zzU8oYGBEABdCADFIEq0AS6wAiYAUtgCxyAC/AAviAIrAN8EAMS\ngQBkgVywDRSAIrAH7AdVoBY0gCbQCk6DTnAeXAHXwW1wFwyDJ0AIJsArIALvwTwEQViIDNEgGUgJ\nUod0ICOIDVlDDpAb5A0FQaFQNJQEZUC50HaoCCqFqqA6qAn6BToHXYFuQoPQI2gMmob+hj7BCEyC\n6bACrAHrw2yYA7vCvvBaOBpOhXPgfHg3XAHXwyfgDvgKfBsehoXwK3gWAQgRYSDKiC7CRriIBxKM\nRCECZDNSiJQj9Ugr0o30IfcQITKDfERhUDQUE6WLskQ5o/xQfFQqajOqGFWFOo7qQPWi7qHGUCLU\nFzQZLY/WQVugeehAdDQ6C12ALkc3otvR19DD6An0ewwGw8CwMGYYZ0wQJg6zEVOMOYhpw1zGDGLG\nMbNYLFYGq4O1wnpgw7Dp2AJsJfYE9hJ2CDuB/YAj4pRwRjhHXDAuCZeHK8c14y7ihnCTuHm8OF4d\nb4H3wEfgN+BL8A34bvwd/AR+niBBYBGsCL6EOMI2QgWhlXCNMEp4SyQSVYjmRC9iLHErsYJ4iniD\nOEb8SKKStElcUggpg7SbdIx0mfSI9JZMJmuQbcnB5HTybnIT+Sr5GfmDGE1MT4wnFiG2RaxarENs\nSOw1BU9Rp3Ao6yg5lHLKGcodyow4XlxDnCseJr5ZvFr8nPiI+KwETcJQwkMiUaJYolnipsQUFUvV\noDpQI6j51CPUq9RxGkJTpXFpfNp2WgPtGm2CjqGz6Dx6HL2IfpI+QBdJUiWNJf0lsyWrJS9IChkI\nQ4PBYyQwShinGQ8Yn6QUpDhSkVK7pFqlhqTmpOWkbaUjpQul26SHpT/JMGUcZOJl9sp0yjyVRclq\ny3rJZskekr0mOyNHl7OU48sVyp2WeywPy2vLe8tvlD8i3y8/q6Co4KSQolCpcFVhRpGhaKsYp1im\neFFxWommZK0Uq1SmdEnpJVOSyWEmMCuYvUyRsryys3KGcp3ygPK8CkvFTyVPpU3lqSpBla0apVqm\n2qMqUlNSc1fLVWtRe6yOV2erx6gfUO9Tn9NgaQRo7NTo1JhiSbN4rBxWC2tUk6xpo5mqWa95Xwuj\nxdaK1zqodVcb1jbRjtGu1r6jA+uY6sTqHNQZXIFeYb4iaUX9ihFdki5HN1O3RXdMj6Hnppen16n3\nWl9NP1h/r36f/hcDE4MEgwaDJ4ZUQxfDPMNuw7+NtI34RtVG91eSVzqu3LKya+UbYx3jSONDxg9N\naCbuJjtNekw+m5qZCkxbTafN1MxCzWrMRth0tie7mH3DHG1uZ77F/Lz5RwtTi3SL0xZ/Wepaxls2\nW06tYq2KXNWwatxKxSrMqs5KaM20DrU+bC20UbYJs6m3eW6rahth22g7ydHixHFOcF7bGdgJ7Nrt\n5rgW3E3cy/aIvZN9of2AA9XBz6HK4Z
mjimO0Y4ujyMnEaaPTZWe0s6vzXucRngKPz2viiVzMXDa5\n9LqSXH1cq1yfu2m7Cdy63WF3F/d97qOr1Vcnre70AB48j30eTz1Znqmev3phvDy9qr1eeBt653r3\n+dB81vs0+7z3tfMt8X3ip+mX4dfjT/EP8W/ynwuwDygNEAbqB24KvB0kGxQb1BWMDfYPbgyeXeOw\nZv+aiRCTkIKQB2tZa7PX3lwnuy5h3YX1lPVh68+EokMDQptDF8I8wurDZsN54TXhIj6Xf4D/KsI2\noixiOtIqsjRyMsoqqjRqKtoqel/0dIxNTHnMTCw3tir2TZxzXG3cXLxH/LH4xYSAhLZEXGJo4rkk\nalJ8Um+yYnJ28mCKTkpBijDVInV/qkjgKmhMg9LWpnWl05c+xf4MzYwdGWOZ1pnVmR+y/LPOZEtk\nJ2X3b9DesGvDZI5jztGNqI38jT25yrnbcsc2cTbVbYY2h2/u2aK6JX/LxFanrce3EbbFb/stzyCv\nNO/d9oDt3fkK+Vvzx3c47WgpECsQFIzstNxZ+xPqp9ifBnat3FW560thROGtIoOi8qKFYn7xrZ8N\nf674eXF31O6BEtOSQ3swe5L2PNhrs/d4qURpTun4Pvd9HWXMssKyd/vX779Zblxee4BwIOOAsMKt\noqtSrXJP5UJVTNVwtV11W418za6auYMRB4cO2R5qrVWoLar9dDj28MM6p7qOeo368iOYI5lHXjT4\nN/QdZR9tapRtLGr8fCzpmPC49/HeJrOmpmb55pIWuCWjZfpEyIm7J+1PdrXqtta1MdqKToFTGade\n/hL6y4PTrqd7zrDPtJ5VP1vTTmsv7IA6NnSIOmM6hV1BXYPnXM71dFt2t/+q9+ux88rnqy9IXii5\nSLiYf3HxUs6l2cspl2euRF8Z71nf8+Rq4NX7vV69A9dcr9247nj9ah+n79INqxvnb1rcPHeLfavz\ntuntjn6T/vbfTH5rHzAd6Lhjdqfrrvnd7sFVgxeHbIau3LO/d/0+7/7t4dXDgw/8HjwcCRkRPox4\nOPUo4dGbx5mP559sHUWPFj4Vf1r+TP5Z/e9av7cJTYUXxuzH+p/7PH8yzh9/9UfaHwsT+S/IL8on\nlSabpoymzk87Tt99ueblxKuUV/MzBX9K/FnzWvP12b9s/+oXBYom3gjeLP5d/Fbm7bF3xu96Zj1n\nn71PfD8/V/hB5sPxj+yPfZ8CPk3OZy1gFyo+a33u/uL6ZXQxcXHxPy6ikLxyKdSVAAAAIGNIUk0A\nAHomAACAhAAA+gAAAIDoAAB1MAAA6mAAADqYAAAXcJy6UTwAAACrUExURf///wBZkABZkABZkABZ\nkABZkABZkABZkABZkABZkABZkABZkBdSUhdSUhdTUxdSUhdSUhdSUhdSUhdSUhdSUhdSUhdSUhdS\nUhdSUhdSUhdSUhdSUgBZkABZkBdSUgBZkBdSUj+PVz+PVz+PVz+PVz+PVz+PVz+PVz+PVz+PVz+P\nVz+PVz+PVz+PVz+PVxdSUhdSUhdSUhdTUxdSUhhUVABZkBdSUj+PV////2WyAXMAAAA1dFJOUwAR\niO6ZuyJVZjPdqkR3dZnMu6OIZjMRqt3uVSJEzMd31rsziN2ZIkQRqlXuZsx3r+HSW+wg7fJpnQAA\nAAFiS0dEAIgFHUgAAAAJcEhZcwAAAEgAAABIAEbJaz4AAAAHdElNRQfhCwINNSfD9n+TAAASs0lE\nQVR42u2dCZuiuBqFU9pVXVWtrQ6uZffMZVFwq5m7SP7/P7skiFuxhiAJnPd5ulAgIZpD8sVODoQA\nAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA\nAAAAAAAAAAAAAAAAAAAAaDZPne6x++2p7mIAZfh27D53j891FwMoQ6fzQp6Ox7qLAZTiO9oIcEUn\niC
Ne6i4EUIjX5273te5CAKV4O3brLgJQhvfjE3lBZAnOPPPR54+6iwGU4ekbfqECAAAAgCi9/s9+\nr+5CADXo9QdD3x/9MfL94QCyaDcnMRjjCX87GRuQRWu5E8MFyKJ1THvGLF4MF06ymBm9ad3FBVXC\nxDD3FyPjY5Lj7MmHMVr4c8iimVzEsCyUbglZNA9RMVyALBpDWJVlxFBNXqAGqrmvy7c3oAaqbuQh\nC4143Pig2LgF1EAdvyHk+W0D1EDNvzMm/v4JakCZH50hCxUYqyGGCydZjOsuR2vpKSQG1UsFAABA\nUZ5+/aq7CK1EXReQH1geVgcdZV1Afh2hiDp47bwp6gLy3lWyWK3g5fhedxG+0jl+hyJq4un3UT3P\nh5fuDwJF1MPbu4KCID+6L1BEPbx2uwoO8t6Ovzud47GjoFabztvxXUXjKDbQYKg4CGo478f354C3\nussRB3qNOjjdiwp2HFAEAAAAoAG9n0pOTZmOf46x9uvxLI354s/F3FBrycT0wxj68z/n/tD4gCoe\nyHQ88kfBjRht1aBnjPzFbMw0uhzPFv7IULINayC9wVXbwNqKQf3f/KQ/8/1Zf5K+C1RAjAJuFFJH\nkZIbhKtmA1RBUi9RX+8xHQ8ygoYwtBgo07c1idS2oIbeI39dZ+sGFCZHjT+09yjcHyDYlEneXuFB\nvYdwzIhgUw6F7v2qe4/StzqCzZII1HBlvYescADBpjCivUAFvYfsWkSwKUCpe11q71FRS49gswgS\nalRO71FxNIhgMxeyWv2y+TzoJkawmYUhMTJkbY0hWIwHdvRhmCJY0ObTk3tTimbXe/BgYDpGRAEA\nAIVpryFHez95Kr/e5S53EHQeqcOwBFYkcfx1fJb6vQg6j9RhWAIrklj+epW7JErQeaQOwxJYkSQh\n/XsRdB55sGEJrEgSkf29CDqPPNiwBFYkyUj+XgSdRx5tWAIrkmTkfi+CziOPNiyBFUkKUhUh6Dzy\ncMMSWJEk8YvfKR1p96eg80g9hiXoNWLohLdKR1Z+gs4j9RiWQBEAAACAaix/6jFfrdf/2cdzqGPx\npebWX/xLbF6d3GKkMflgDw/1R3+MfH+OR4h+xZd4Ty9HizH5WAwFJqzJLEYS015/wGQwix4wvPzg\njxAdDvDY+it8efMNjcWMfbFTwzcKf8ESixEHe34xaxiMmCe99canx9Zj6iVHWlVMhouP08ve5eXD\ni3HPqWFgzy9epp5msIdFDmd4QLmsqjD82VXDYCxGxb7YKhTBnlg7L3TzXzUl7e1G5FTFZDi/zScI\nKQpFmHIVcbnje8Xv+KBVMdocdY4kVEVs5FAswpRRDM5EUlTQ3qhTQlX05vO4TApFmBKKETUMMuuw\njVFn6apIqfgCEWa5Ylwqrop2vmVR56jk+sfefJhSC7kjTNFiLHuPatxbE3WWU8R05qenzxthChQj\nbBjmFTUMCZ+3BVFnKUUE0WPmF5MvwixUjFPDMBrUdLs2O+osoYjpLNf9nyvCzFsMhe7RpkadA2FF\njHP/CpUjwswuxvL0P1R1NQzxNDDqNEQVMVj0C1xlMShVjPHpf6hqbxjiiaLOcd0FqZVivwYuS9pc\n9dVvmac6FBIAAEQxqUlMK+0EK3deQDdsGuCs1le7XGp5hLqXHeZ9ouuDdzx9K+QBUvB0qakfxose\nxTxh0+CPtaHbyy6X7bqqdIveJ0pRxO9iHiAFT5ea+mGwlWkaFPOEHVb3ZucRsnXobuW5QaNhskr3\nVkHrsSXsvX06xs9xtimK6HSeyEv+RVEFT5ea+lG8sGI+2hdFnJMiLLon7s4l68Pq3EasHI/sqcff\nR8e8QBbeJkUR/Cso5gFSzjLkwYYjovylWxtBgmZgswq2W3pWhLcOt+x9dGxPg53rdEU8/T5+L1CC\ngqdLTf0ono/vesURhCvCpBw3UsR6w99yRUTHwtNTFfGr2y3ivlDw
dKmpH0YnkIQulojnXmNLTDvc\nFSnC3KzPbUR0LFsRr8ffRRb3FzxdaupH8qZH50YukaUT/DuEuyJFsHpfR73G6RjvNawURQR1VKR5\nLHi61NSP4lf3XYcAOOI0+txZTAmfwfhic1aEY5O1ySJOHl2Gx9a7FRuDJCoi+OC/mQdIzqvnOd2W\ndbG6eOry0ee3usuRE/4L1S78hSoYWdKgp4gU4Tr0YNlBgHEIR5/sWDj6TP7NMnLtyXn1PKdTW9LF\nauPlx1GjX6jUx/TqLgFQCm9TdwkAAKBaemOjaVMuC1FwNp7w5D0Jqatm0jPYHFx/9PffI59NxTV6\nis7+qxQogs2jGxujIZ8a3o8mGS7ZbPG57w/ZHG2FpghXT7sVMfkwmCuFPxokzAbOPKF5tFQRxZqA\nuEaksbROEUHlzgTDhHOgMWty6NkeRUirzxKa0oEWKKKaNr+xoWejFfGAuLB5oWczFfHoG7hJoWfT\nFFFrJ9+I0LMxilCoNvQOPfVXhKottq6hp86K0CKq06KQ12ipCP1uP1Ubshi0U8RU4y76KtjRQsea\nUHvgWJ6gwai7CAAAIIZlZqwDvMGll1QFLgA0gq8lzz2l3nMvqfJg5j81F3rZfSShtreKWIXlTMXs\nTaQqQi+7jySU8VZhpiOHLSEb16S7T8KXiJmflPcal307utleKtFmq8MsQvZsG7pTRKmi7JLh9iYu\nDdI61uX81dah/DqOTQqimd1HEsp4q5grj3zSNaGHNd/azGTACRVx3ues2YLBKMme7oOaNPkaY2/D\nV5meU0XZpRDamwRKYmuXL5ffu8QOrufu9iIfQx+7jzRU8FaxeO0dbEJZRbDFw8x45NRGnPeFbccl\nEeH1atFw8WjwMkp1zi4Frgh2nnt1Pl8aysVgCwSdOtl9pKCEt4ob+oyY4dCCqcAmYaWdzAbO20vX\n760cloaJ5GDzFennVOfsUi95yuz6/PBSIcU/hU52H8mo4a1yrucCijAP1unter/a2beKyHfJiyLu\nLi+KPnYfiSjirWJFpodRldz0GtE+3pqfK5vvjt6GrleXXmObeclrRdxeflc4qmRoZveRhDLeKiaP\n5qxz7ds7N4wirxSxOtxElkH1e5876n2ywYLt8MgyShVll4J7Ms8LvZCuL29zO73Pgp9AM7uPBNTx\nVuGWl5+XNoK9N/fmjSLYvqvRp8WMMdeOGY1Cg5o9p4qyS+PARp8kVMTN5dnoc2cXNptohN2HLt4q\nV2x3dZcAKANzyGXmqACcsE8GygDowbTfx/QjcGYyWPzzz2Kg4Xy6O1RYE90AxiN/1iOkN9P/2WrG\noHwebWdpLBbG8utrPTFGdZdAd760C6f2QlegiFJM+/OY2CGIKebaRplQRAmSaz5eKVoARQiT0Tvo\nGmX2oAgh8kSQekaZUIQIue9/DaNMKKIwxWIE7aLMnl93CTSjeA1rFmVCEYUQ7AV0ijKhiPyUiRT1\niTKhiLyUvs81iTKXUEQuljJiARaDqN9OQBH5kDNemPbr/hwAAAAAeCyyfDaa4ToC5PlsNMN1BEjz\n2WiI6wgIkeWz0QzXESDLZ6MhriNAms9GM1xHgEyfjQa4jgBZPhsNcR0B0nw29HAdcYU8dtqFNJ8N\nLVxHbJm+rwAAAACQB1xHGoMUn42Pmf/PP/7so+4PAyRQ3mdjaYRTNdlkSz1mZYM0Sq6Yno5H/mg8\njXkDNKWUIr42C1GDAbRFXBHT/jw2dAiCCr0WhIIbRBWRVu9JWgE6ILSGPrtvQJipLcUVkTN+RJip\nKUUVUeTmR5ipI4VWTBcPEBBmakcBRYjVLsJMzciriDI9AMJMnciliNJRIsJMfcjhsyHnFkeYqQtZ\niljKCwNYIILOQ39kDhVgPAIAAPpiiT83GjSSVeYKBlmWI1XkBuSTvaZFluVIFbkB+bgmsQ6U8kfZ\nxyPJcqSS3EAlODYhe5r6bHpZ
liNV5Aaks/vMOkOS5UgluQH57HfOapt2gizLkSpyA1XgbW3nkHxY\nouWI9NxAVXg08XcJSZYjleQGKmEXRJZbuk44KstypIrcQDWw0echMZCQZjlSQW4AAAAAaBnLfh+T\nqHRHiuUIZ9kf+v/+tz+EKPRGkiLYbGyuBaYLiEJnZChiOp75V9O5IQqtKa2IUA538/IhCn0pp4hY\nOYRAFJpSRhEfbO1PyqodiEJHhBURyGExyHx6NEShHWKK4HLIuRQMotALAUX0CsghBKLQiKKKmLAV\nvwILRSEKXSikiEAO/kzYFQCi0IL8iignhxCIQn1yKmIpQQ6nnCAKtcmjCFaJMi1kIAqVyVbEsorq\n46KAJjSlort5CfsRAADQGcusuwRALdwYv4lORSYhT1VlDCQSo4hOVSYh3+A+oggb16S7veVQxw0V\nwP5YG0o3XvByH+y/MZt47byRpyqWaXU6L9VkDApCD2vySc01WTlnRXjOyvM2JnF3gSw2zn2SykxC\nvqONUAC6P/0JlBApgi8VXrM2Yx3TdTz9PlbiCRF0SN9e6v46AOHOAezPlSIiB7NzL3LN23s1giCv\nz7AfUYGiinjtdn9VVZa3Y7furwN8VcSe9RpBNLm24xTxdnyvpmV/P8LXTg2uFbGmn2R9YJHlJogs\nD3GKeD++M5MQ+V5Cz3z0+aPurwPcKIJ87qjpmlejT3KviJNJiPyOA/64AAAAAChB72fmijCgJEWe\nIZub6Xjo/+kP8fw/HfHlT6pbGou5MSVTY77Ag0L1w5fduH+M/NFHzGugCXIV8aVdOLUXQB9kKqI3\niIkdWEyR7TsAlEGaIlJqPlYpQFFGchSR0TsgytQHKYrIE0EiytSE8orIff8jytSCUUl/w0IxAqJM\nDSilCIEaRpSpOiUUIdgLIMpUG2FFlIkUEWUqzEBIEaXvc0SZyiLkeGlIiAVYDCLtWQ+gZnpyxguS\nsgEAAADA4/GEXGekuYXAdUQ5tlQklSy3kMrsTIAoNqXUJZ8OpRurQDJZbiGV2ZkAYfhSQLon3mbn\nFUspyy2kMjsTIARTxGETvLDotkg6aW4hVdmZAEGYIqjNXoV/8yLLLaQyOxMgiKgiJLmFVGlnAoQQ\n6zVkuYVUZmcChHGpR/ZBZLk2nQKpZLmFVGZnAoTxDtHoc10glSy3kMrsTAAAAADQBqb/wby6hiDH\nhmQ6/O8QkmgGUhSxHA7/N8Tz3ZqBDEVMFoPg72AxqfvDAAlIUEQoCEiiIZRXRCQISKIZlFbEx+Ky\nVMNYYHmX9pRVxNgfJ74DOlJSEfcSgCS0p5wivgoAktCdUoqIixuu4wqgIWUUET+2uIw9gI6UUETS\nYBOS0BpxRST/+gBJ6IyoIqZp/4+xHOL/vbRFVBGj1DqfDkd1fzDwYCbpjcAUv2cDAEBrMYusOAct\ngD+O9oosBxE4jDSdO0VkOYjAYURJtuY+qMitQ3crjxDboc6e8O3OJsQ1XYeaa5PutkGfsDWpY9mU\nLyKOEphucPCT+RnR3f5OEVkOInAYURKXblzP3blkfViRz2DrBvVqB//cQBJuUPtbtiRw5QQtgMm2\nq2DHmkQJCD2syWewY7WzvNV9r0GyHUTgMKIcLg3CwU1Qt8ykitUxcdfhNlABWzlM6P5kM7Anpx3u\nOQHfyXYESiHrr4rIchCBw4h6sLomJuW4kasE39o0PMjqmSvitGUvogRh6MC2NvkaWWY6iMBhREFC\nRVwrgeRRRGRIkqqILAcROIyoCK/jzSF8w91G9m64XR2SFRElOCsirtfIchCBw4iS8Dp26SfxVpsg\nRmSR5TaMLKmdrIgowVkRq93as+8UkeUgAocRJeF1zAaT3GWEjT6DsSTf2iRZEecEkSL46NO8tTLK\nchCBwwgAAAAAAMhP1mw8OU4lQB8yFQF/kZaRqQg8yq1lQBHgFigC3AJFgFuyFDGCIloGFAFugS
LA\nLZmKgK9Iy4AiwC1QBLgFigC3ZCliAEW0jCxFGFAEAAAAAL5gmcnHTGrylWFwIGkNZrQQJBaXWh5x\nPRKzlhw0E4tmKCLcQhFtwaWU2i7dO9QJ+gVvRelhe3eY9xpsueC1ZwloLKwRcHcbj2ycoAdZedx0\n5PZwpIhrzxLQWLgimAiCFxYXw8G+O3xWxJVnCWgsLj0vOud9BO8mbg+fFUGu1hiDpnKriNjDUESr\nuFaERbdxh6GIVsGig0gRxDwwbzvr5jCBIlrGgY0+SVjPbPTJbUrOQBEAAAAAAAAAAABQnf8DmwKy\nJJha+OIAAAAldEVYdGRhdGU6Y3JlYXRlADIwMTctMTEtMDJUMjI6NTM6MzkrMDk6MDBXyJTvAAAA\nJXRFWHRkYXRlOm1vZGlmeQAyMDE3LTExLTAyVDIyOjUzOjM5KzA5OjAwJpUsUwAAACN0RVh0cHM6\nSGlSZXNCb3VuZGluZ0JveAA1Mjl4NDA1LTI2NC0yMDKzNU7+AAAAHHRFWHRwczpMZXZlbABBZG9i\nZS0zLjAgRVBTRi0zLjAKm3C74wAAACJ0RVh0cHM6U3BvdENvbG9yLTAAZm9udCBMaWJlcmF0aW9u\nU2Fuc/4Zp8YAAAAASUVORK5CYII=\n",
165 | "text/plain": [
166 | ""
167 | ]
168 | },
169 | "metadata": {},
170 | "output_type": "display_data"
171 | }
172 | ],
173 | "source": [
174 | "draw_nltk_tree(nltkTree.fromstring(sample))"
175 | ]
176 | },
177 | {
178 | "cell_type": "markdown",
179 | "metadata": {},
180 | "source": [
181 | "### Tree Class "
182 | ]
183 | },
184 | {
185 | "cell_type": "markdown",
186 | "metadata": {},
187 | "source": [
188 | "Code borrowed from https://github.com/bogatyy/cs224d/tree/master/assignment3"
189 | ]
190 | },
191 | {
192 | "cell_type": "code",
193 | "execution_count": 10,
194 | "metadata": {
195 | "collapsed": true
196 | },
197 | "outputs": [],
198 | "source": [
199 | "class Node:  # a node in the tree\n",
200 | "    def __init__(self, label, word=None):\n",
201 | "        self.label = label\n",
202 | "        self.word = word\n",
203 | "        self.parent = None  # reference to parent\n",
204 | "        self.left = None  # reference to left child\n",
205 | "        self.right = None  # reference to right child\n",
206 | "        # True if this node is a leaf (this could also be derived from\n",
207 | "        # whether the node has a word)\n",
208 | "        self.isLeaf = False\n",
209 | "        # A flag marking that forward prop has finished on this node is not\n",
210 | "        # stored here: there are many ways to implement the recursion, and\n",
211 | "        # the one used in this notebook does not require such a flag.\n",
212 | "\n",
213 | "    def __str__(self):\n",
214 | "        if self.isLeaf:\n",
215 | "            return '[{0}:{1}]'.format(self.word, self.label)\n",
216 | "        return '({0} <- [{1}:{2}] -> {3})'.format(self.left, self.word, self.label, self.right)\n",
217 | "\n",
218 | "\n",
219 | "class Tree:\n",
220 | "\n",
221 | "    def __init__(self, treeString, openChar='(', closeChar=')'):\n",
222 | "        tokens = []\n",
223 | "        self.open = openChar\n",
224 | "        self.close = closeChar\n",
225 | "        for toks in treeString.strip().split():\n",
226 | "            tokens += list(toks)\n",
227 | "        self.root = self.parse(tokens)\n",
228 | "        # get list of labels as obtained through a post-order traversal\n",
229 | "        self.labels = get_labels(self.root)\n",
230 | "        self.num_words = len(self.labels)\n",
231 | "\n",
232 | "    def parse(self, tokens, parent=None):\n",
233 | "        assert tokens[0] == self.open, \"Malformed tree\"\n",
234 | "        assert tokens[-1] == self.close, \"Malformed tree\"\n",
235 | "\n",
236 | "        split = 2  # position after open and label\n",
237 | "        countOpen = countClose = 0\n",
238 | "\n",
239 | "        if tokens[split] == self.open:\n",
240 | "            countOpen += 1\n",
241 | "            split += 1\n",
242 | "        # Find where left child and right child split\n",
243 | "        while countOpen != countClose:\n",
244 | "            if tokens[split] == self.open:\n",
245 | "                countOpen += 1\n",
246 | "            if tokens[split] == self.close:\n",
247 | "                countClose += 1\n",
248 | "            split += 1\n",
249 | "\n",
250 | "        # New node\n",
251 | "        node = Node(int(tokens[1]))  # zero index labels\n",
252 | "\n",
253 | "        node.parent = parent\n",
254 | "\n",
255 | "        # leaf Node\n",
256 | "        if countOpen == 0:\n",
257 | "            node.word = ''.join(tokens[2: -1]).lower()  # lower case?\n",
258 | "            node.isLeaf = True\n",
259 | "            return node\n",
260 | "\n",
261 | "        node.left = self.parse(tokens[2: split], parent=node)\n",
262 | "        node.right = self.parse(tokens[split: -1], parent=node)\n",
263 | "\n",
264 | "        return node\n",
265 | "\n",
266 | "    def get_words(self):\n",
267 | "        leaves = getLeaves(self.root)\n",
268 | "        words = [node.word for node in leaves]\n",
269 | "        return words\n",
270 | "\n",
271 | "def get_labels(node):\n",
272 | "    if node is None:\n",
273 | "        return []\n",
274 | "    return get_labels(node.left) + get_labels(node.right) + [node.label]\n",
275 | "\n",
276 | "def getLeaves(node):\n",
277 | "    if node is None:\n",
278 | "        return []\n",
279 | "    if node.isLeaf:\n",
280 | "        return [node]\n",
281 | "    else:\n",
282 | "        return getLeaves(node.left) + getLeaves(node.right)\n",
283 | "\n",
284 | "\n",
285 | "def loadTrees(dataSet='train'):\n",
286 | "    \"\"\"\n",
287 | "    Loads training trees. Maps leaf node words to word ids.\n",
288 | "    \"\"\"\n",
289 | "    file = '../dataset/trees/%s.txt' % dataSet\n",
290 | "    print(\"Loading %s trees..\" % dataSet)\n",
291 | "    with open(file, 'r', encoding='utf-8') as fid:\n",
292 | "        trees = [Tree(l) for l in fid.readlines()]\n",
293 | "\n",
294 | "    return trees"
295 | ]
296 | },
297 | {
298 | "cell_type": "code",
299 | "execution_count": 11,
300 | "metadata": {
301 | "collapsed": false
302 | },
303 | "outputs": [
304 | {
305 | "name": "stdout",
306 | "output_type": "stream",
307 | "text": [
308 | "Loading train trees..\n"
309 | ]
310 | }
311 | ],
312 | "source": [
313 | "train_data = loadTrees('train')"
314 | ]
315 | },
316 | {
317 | "cell_type": "markdown",
318 | "metadata": {},
319 | "source": [
320 | "### Build Vocab "
321 | ]
322 | },
323 | {
324 | "cell_type": "code",
325 | "execution_count": 12,
326 | "metadata": {
327 | "collapsed": true
328 | },
329 | "outputs": [],
330 | "source": [
331 | "vocab = list(set(flatten([t.get_words() for t in train_data])))"
332 | ]
333 | },
334 | {
335 | "cell_type": "code",
336 | "execution_count": 13,
337 | "metadata": {
338 | "collapsed": true
339 | },
340 | "outputs": [],
341 | "source": [
342 | "word2index = {'<UNK>': 0}\n",
343 | "for vo in vocab:\n",
344 | "    if word2index.get(vo) is None:\n",
345 | "        word2index[vo] = len(word2index)\n",
346 | "\n",
347 | "index2word = {v:k for k, v in word2index.items()}"
348 | ]
349 | },
350 | {
351 | "cell_type": "markdown",
352 | "metadata": {},
353 | "source": [
354 | "## Modeling "
355 | ]
356 | },
357 | {
358 | "cell_type": "markdown",
359 | "metadata": {},
360 | "source": [
361 | "<img src=\"../images/09.rntn-layer.png\">\n",
362 | "Image borrowed from https://nlp.stanford.edu/~socherr/EMNLP2013_RNTN.pdf"
363 | ]
364 | },
365 | {
366 | "cell_type": "code",
367 | "execution_count": 14,
368 | "metadata": {
369 | "collapsed": true
370 | },
371 | "outputs": [],
372 | "source": [
373 | "class RNTN(nn.Module):\n",
374 | "\n",
375 | "    def __init__(self, word2index, hidden_size, output_size):\n",
376 | "        super(RNTN,self).__init__()\n",
377 | "\n",
378 | "        self.word2index = word2index\n",
379 | "        self.embed = nn.Embedding(len(word2index), hidden_size)\n",
380 | "#         self.V = nn.ModuleList([nn.Linear(hidden_size*2,hidden_size*2) for _ in range(hidden_size)])\n",
381 | "#         self.W = nn.Linear(hidden_size*2,hidden_size)\n",
382 | "        self.V = nn.ParameterList([nn.Parameter(torch.randn(hidden_size * 2, hidden_size * 2)) for _ in range(hidden_size)]) # Tensor\n",
383 | "        self.W = nn.Parameter(torch.randn(hidden_size * 2, hidden_size))\n",
384 | "        self.b = nn.Parameter(torch.randn(1, hidden_size))\n",
385 | "#         self.W_out = nn.Parameter(torch.randn(hidden_size,output_size))\n",
386 | "        self.W_out = nn.Linear(hidden_size, output_size)\n",
387 | "\n",
388 | "    def init_weight(self):\n",
389 | "        nn.init.xavier_uniform(self.embed.state_dict()['weight'])\n",
390 | "        nn.init.xavier_uniform(self.W_out.state_dict()['weight'])\n",
391 | "        for param in self.V.parameters():\n",
392 | "            nn.init.xavier_uniform(param)\n",
393 | "        nn.init.xavier_uniform(self.W)\n",
394 | "        self.b.data.fill_(0)\n",
395 | "#         nn.init.xavier_uniform(self.W_out)\n",
396 | "\n",
397 | "    def tree_propagation(self, node):\n",
398 | "\n",
399 | "        recursive_tensor = OrderedDict()\n",
400 | "        current = None\n",
401 | "        if node.isLeaf:\n",
402 | "            tensor = Variable(LongTensor([self.word2index[node.word]])) if node.word in self.word2index \\\n",
403 | "                else Variable(LongTensor([self.word2index['<UNK>']]))\n",
404 | "            current = self.embed(tensor) # 1xD\n",
405 | "        else:\n",
406 | "            recursive_tensor.update(self.tree_propagation(node.left))\n",
407 | "            recursive_tensor.update(self.tree_propagation(node.right))\n",
408 | "\n",
409 | "            concated = torch.cat([recursive_tensor[node.left], recursive_tensor[node.right]], 1) # 1x2D\n",
410 | "            xVx = []\n",
411 | "            for i, v in enumerate(self.V):\n",
412 | "#                 xVx.append(torch.matmul(v(concated),concated.transpose(0,1)))\n",
413 | "                xVx.append(torch.matmul(torch.matmul(concated, v), concated.transpose(0, 1)))\n",
414 | "\n",
415 | "            xVx = torch.cat(xVx, 1) # 1xD\n",
416 | "#             Wx = self.W(concated)\n",
417 | "            Wx = torch.matmul(concated, self.W) # 1xD\n",
418 | "\n",
419 | "            current = F.tanh(xVx + Wx + self.b) # 1xD\n",
420 | "        recursive_tensor[node] = current\n",
421 | "        return recursive_tensor\n",
422 | "\n",
423 | "    def forward(self, Trees, root_only=False):\n",
424 | "\n",
425 | "        propagated = []\n",
426 | "        if not isinstance(Trees, list):\n",
427 | "            Trees = [Trees]\n",
428 | "\n",
429 | "        for tree in Trees:  # lower-case name avoids shadowing the Tree class\n",
430 | "            recursive_tensor = self.tree_propagation(tree.root)\n",
431 | "            if root_only:\n",
432 | "                recursive_tensor = recursive_tensor[tree.root]\n",
433 | "                propagated.append(recursive_tensor)\n",
434 | "            else:\n",
435 | "                recursive_tensor = [tensor for node,tensor in recursive_tensor.items()]\n",
436 | "                propagated.extend(recursive_tensor)\n",
437 | "\n",
438 | "        propagated = torch.cat(propagated) # (num_of_node in batch, D)\n",
439 | "\n",
440 | "#         return F.log_softmax(propagated.matmul(self.W_out))\n",
441 | "        return F.log_softmax(self.W_out(propagated),1)"
442 | ]
443 | },
444 | {
445 | "cell_type": "markdown",
446 | "metadata": {},
447 | "source": [
448 | "## Training "
449 | ]
450 | },
451 | {
452 | "cell_type": "markdown",
453 | "metadata": {},
454 | "source": [
455 | "Training takes a while. The model builds its computational graph dynamically for each tree, so it is difficult to train in batches."
456 | ]
457 | },
458 | {
459 | "cell_type": "code",
460 | "execution_count": 15,
461 | "metadata": {
462 | "collapsed": true
463 | },
464 | "outputs": [],
465 | "source": [
466 | "HIDDEN_SIZE = 30\n",
467 | "ROOT_ONLY = False\n",
468 | "BATCH_SIZE = 20\n",
469 | "EPOCH = 20\n",
470 | "LR = 0.01\n",
471 | "LAMBDA = 1e-5\n",
472 | "RESCHEDULED = False"
473 | ]
474 | },
475 | {
476 | "cell_type": "code",
477 | "execution_count": 18,
478 | "metadata": {
479 | "collapsed": true
480 | },
481 | "outputs": [],
482 | "source": [
483 | "model = RNTN(word2index, HIDDEN_SIZE,5)\n",
484 | "model.init_weight()\n",
485 | "if USE_CUDA:\n",
486 | " model = model.cuda()\n",
487 | "\n",
488 | "loss_function = nn.CrossEntropyLoss()\n",
489 | "optimizer = optim.Adam(model.parameters(), lr=LR)"
490 | ]
491 | },
492 | {
493 | "cell_type": "code",
494 | "execution_count": 19,
495 | "metadata": {
496 | "collapsed": false
497 | },
498 | "outputs": [
499 | {
500 | "name": "stdout",
501 | "output_type": "stream",
502 | "text": [
503 | "[0/20] mean_loss : 1.62\n",
504 | "[0/20] mean_loss : 1.25\n",
505 | "[0/20] mean_loss : 0.95\n",
506 | "[0/20] mean_loss : 0.90\n",
507 | "[0/20] mean_loss : 0.88\n",
508 | "[1/20] mean_loss : 0.88\n",
509 | "[1/20] mean_loss : 0.84\n",
510 | "[1/20] mean_loss : 0.83\n",
511 | "[1/20] mean_loss : 0.82\n",
512 | "[1/20] mean_loss : 0.82\n",
513 | "[2/20] mean_loss : 0.81\n",
514 | "[2/20] mean_loss : 0.79\n",
515 | "[2/20] mean_loss : 0.78\n",
516 | "[2/20] mean_loss : 0.76\n",
517 | "[2/20] mean_loss : 0.75\n",
518 | "[3/20] mean_loss : 0.68\n",
519 | "[3/20] mean_loss : 0.73\n",
520 | "[3/20] mean_loss : 0.74\n",
521 | "[3/20] mean_loss : 0.72\n",
522 | "[3/20] mean_loss : 0.72\n",
523 | "[4/20] mean_loss : 0.74\n",
524 | "[4/20] mean_loss : 0.69\n",
525 | "[4/20] mean_loss : 0.69\n",
526 | "[4/20] mean_loss : 0.68\n",
527 | "[4/20] mean_loss : 0.67\n",
528 | "[5/20] mean_loss : 0.73\n",
529 | "[5/20] mean_loss : 0.65\n",
530 | "[5/20] mean_loss : 0.64\n",
531 | "[5/20] mean_loss : 0.64\n",
532 | "[5/20] mean_loss : 0.65\n",
533 | "[6/20] mean_loss : 0.67\n",
534 | "[6/20] mean_loss : 0.62\n",
535 | "[6/20] mean_loss : 0.62\n",
536 | "[6/20] mean_loss : 0.62\n",
537 | "[6/20] mean_loss : 0.62\n",
538 | "[7/20] mean_loss : 0.57\n",
539 | "[7/20] mean_loss : 0.59\n",
540 | "[7/20] mean_loss : 0.59\n",
541 | "[7/20] mean_loss : 0.59\n",
542 | "[7/20] mean_loss : 0.59\n",
543 | "[8/20] mean_loss : 0.60\n",
544 | "[8/20] mean_loss : 0.58\n",
545 | "[8/20] mean_loss : 0.59\n",
546 | "[8/20] mean_loss : 0.60\n",
547 | "[8/20] mean_loss : 0.60\n",
548 | "[9/20] mean_loss : 0.52\n",
549 | "[9/20] mean_loss : 0.58\n",
550 | "[9/20] mean_loss : 0.60\n",
551 | "[9/20] mean_loss : 0.59\n",
552 | "[9/20] mean_loss : 0.59\n",
553 | "[10/20] mean_loss : 0.56\n",
554 | "[10/20] mean_loss : 0.56\n",
555 | "[10/20] mean_loss : 0.56\n",
556 | "[10/20] mean_loss : 0.56\n",
557 | "[10/20] mean_loss : 0.56\n",
558 | "[11/20] mean_loss : 0.52\n",
559 | "[11/20] mean_loss : 0.54\n",
560 | "[11/20] mean_loss : 0.54\n",
561 | "[11/20] mean_loss : 0.54\n",
562 | "[11/20] mean_loss : 0.55\n",
563 | "[12/20] mean_loss : 0.55\n",
564 | "[12/20] mean_loss : 0.53\n",
565 | "[12/20] mean_loss : 0.53\n",
566 | "[12/20] mean_loss : 0.53\n",
567 | "[12/20] mean_loss : 0.53\n",
568 | "[13/20] mean_loss : 0.59\n",
569 | "[13/20] mean_loss : 0.52\n",
570 | "[13/20] mean_loss : 0.52\n",
571 | "[13/20] mean_loss : 0.53\n",
572 | "[13/20] mean_loss : 0.53\n",
573 | "[14/20] mean_loss : 0.49\n",
574 | "[14/20] mean_loss : 0.51\n",
575 | "[14/20] mean_loss : 0.51\n",
576 | "[14/20] mean_loss : 0.52\n",
577 | "[14/20] mean_loss : 0.52\n",
578 | "[15/20] mean_loss : 0.43\n",
579 | "[15/20] mean_loss : 0.51\n",
580 | "[15/20] mean_loss : 0.51\n",
581 | "[15/20] mean_loss : 0.51\n",
582 | "[15/20] mean_loss : 0.51\n",
583 | "[16/20] mean_loss : 0.46\n",
584 | "[16/20] mean_loss : 0.50\n",
585 | "[16/20] mean_loss : 0.50\n",
586 | "[16/20] mean_loss : 0.50\n",
587 | "[16/20] mean_loss : 0.50\n",
588 | "[17/20] mean_loss : 0.50\n",
589 | "[17/20] mean_loss : 0.50\n",
590 | "[17/20] mean_loss : 0.50\n",
591 | "[17/20] mean_loss : 0.50\n",
592 | "[17/20] mean_loss : 0.51\n",
593 | "[18/20] mean_loss : 0.46\n",
594 | "[18/20] mean_loss : 0.50\n",
595 | "[18/20] mean_loss : 0.50\n",
596 | "[18/20] mean_loss : 0.49\n",
597 | "[18/20] mean_loss : 0.49\n",
598 | "[19/20] mean_loss : 0.49\n",
599 | "[19/20] mean_loss : 0.49\n",
600 | "[19/20] mean_loss : 0.49\n",
601 | "[19/20] mean_loss : 0.50\n",
602 | "[19/20] mean_loss : 0.50\n"
603 | ]
604 | }
605 | ],
606 | "source": [
607 | "for epoch in range(EPOCH):\n",
608 | "    losses = []\n",
609 | "\n",
610 | "    # learning rate annealing\n",
611 | "    if not RESCHEDULED and epoch == EPOCH//2:\n",
612 | "        LR *= 0.1\n",
613 | "        optimizer = optim.Adam(model.parameters(), lr=LR, weight_decay=LAMBDA) # L2 norm\n",
614 | "        RESCHEDULED = True\n",
615 | "\n",
616 | "    for i, batch in enumerate(getBatch(BATCH_SIZE, train_data)):\n",
617 | "\n",
618 | "        if ROOT_ONLY:\n",
619 | "            labels = [tree.labels[-1] for tree in batch]\n",
620 | "            labels = Variable(LongTensor(labels))\n",
621 | "        else:\n",
622 | "            labels = [tree.labels for tree in batch]\n",
623 | "            labels = Variable(LongTensor(flatten(labels)))\n",
624 | "\n",
625 | "        model.zero_grad()\n",
626 | "        preds = model(batch, ROOT_ONLY)\n",
627 | "\n",
628 | "        loss = loss_function(preds, labels)\n",
629 | "        losses.append(loss.data.tolist()[0])\n",
630 | "\n",
631 | "        loss.backward()\n",
632 | "        optimizer.step()\n",
633 | "\n",
634 | "        if i % 100 == 0:\n",
635 | "            print('[%d/%d] mean_loss : %.2f' % (epoch, EPOCH, np.mean(losses)))\n",
636 | "            losses = []\n",
637 | " "
638 | ]
639 | },
640 | {
641 | "cell_type": "markdown",
642 | "metadata": {},
643 | "source": [
644 | "Convergence of the model is unstable and depends on the initial values. I tried 5~6 times to get this result."
645 | ]
646 | },
647 | {
648 | "cell_type": "markdown",
649 | "metadata": {},
650 | "source": [
651 | "## Test"
652 | ]
653 | },
654 | {
655 | "cell_type": "code",
656 | "execution_count": 20,
657 | "metadata": {
658 | "collapsed": false
659 | },
660 | "outputs": [
661 | {
662 | "name": "stdout",
663 | "output_type": "stream",
664 | "text": [
665 | "Loading test trees..\n"
666 | ]
667 | }
668 | ],
669 | "source": [
670 | "test_data = loadTrees('test')"
671 | ]
672 | },
673 | {
674 | "cell_type": "code",
675 | "execution_count": 21,
676 | "metadata": {
677 | "collapsed": true
678 | },
679 | "outputs": [],
680 | "source": [
681 | "accuracy = 0\n",
682 | "num_node = 0"
683 | ]
684 | },
685 | {
686 | "cell_type": "markdown",
687 | "metadata": {},
688 | "source": [
689 | "### Fine-grained all"
690 | ]
691 | },
692 | {
693 | "cell_type": "markdown",
694 | "metadata": {},
695 | "source": [
696 | "In the paper, they achieved 80.2% accuracy."
697 | ]
698 | },
699 | {
700 | "cell_type": "code",
701 | "execution_count": 23,
702 | "metadata": {
703 | "collapsed": false
704 | },
705 | "outputs": [
706 | {
707 | "name": "stdout",
708 | "output_type": "stream",
709 | "text": [
710 | "79.33705899068254\n"
711 | ]
712 | }
713 | ],
714 | "source": [
715 | "for test in test_data:\n",
716 | "    model.zero_grad()\n",
717 | "    preds = model(test, ROOT_ONLY)\n",
718 | "    labels = test.labels[-1:] if ROOT_ONLY else test.labels\n",
719 | "    for pred, label in zip(preds.max(1)[1].data.tolist(), labels):\n",
720 | "        num_node += 1\n",
721 | "        if pred == label:\n",
722 | "            accuracy += 1\n",
723 | "\n",
724 | "print(accuracy/num_node * 100)"
725 | ]
726 | },
727 | {
728 | "cell_type": "markdown",
729 | "metadata": {},
730 | "source": [
731 | "## TODO "
732 | ]
733 | },
734 | {
735 | "cell_type": "markdown",
736 | "metadata": {},
737 | "source": [
738 | "* https://github.com/nearai/pytorch-tools # Dynamic batch using TensorFold"
739 | ]
740 | },
741 | {
742 | "cell_type": "markdown",
743 | "metadata": {
744 | "collapsed": true
745 | },
746 | "source": [
747 | "## Further topics "
748 | ]
749 | },
750 | {
751 | "cell_type": "markdown",
752 | "metadata": {},
753 | "source": [
754 | "* Improved Semantic Representations From Tree-Structured Long Short-Term Memory Networks\n",
755 | "* A Fast Unified Model for Parsing and Sentence Understanding(SPINN)\n",
756 | "* Posting about SPINN"
757 | ]
758 | },
759 | {
760 | "cell_type": "code",
761 | "execution_count": null,
762 | "metadata": {
763 | "collapsed": true
764 | },
765 | "outputs": [],
766 | "source": []
767 | }
768 | ],
769 | "metadata": {
770 | "kernelspec": {
771 | "display_name": "Python 3",
772 | "language": "python",
773 | "name": "python3"
774 | },
775 | "language_info": {
776 | "codemirror_mode": {
777 | "name": "ipython",
778 | "version": 3
779 | },
780 | "file_extension": ".py",
781 | "mimetype": "text/x-python",
782 | "name": "python",
783 | "nbconvert_exporter": "python",
784 | "pygments_lexer": "ipython3",
785 | "version": "3.5.2"
786 | }
787 | },
788 | "nbformat": 4,
789 | "nbformat_minor": 2
790 | }
791 |
--------------------------------------------------------------------------------
/notebooks/10.Dynamic-Memory-Network-for-Question-Answering.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "# 10. Dynamic Memory Networks for Question Answering"
8 | ]
9 | },
10 | {
11 | "cell_type": "markdown",
12 | "metadata": {},
13 | "source": [
14 | "I recommend you take a look at these materials first."
15 | ]
16 | },
17 | {
18 | "cell_type": "markdown",
19 | "metadata": {},
20 | "source": [
21 | "* http://web.stanford.edu/class/cs224n/lectures/cs224n-2017-lecture16-DMN-QA.pdf\n",
22 | "* https://arxiv.org/abs/1506.07285"
23 | ]
24 | },
25 | {
26 | "cell_type": "code",
27 | "execution_count": 1,
28 | "metadata": {
29 | "collapsed": true
30 | },
31 | "outputs": [],
32 | "source": [
33 | "import torch\n",
34 | "import torch.nn as nn\n",
35 | "from torch.autograd import Variable\n",
36 | "import torch.optim as optim\n",
37 | "import torch.nn.functional as F\n",
38 | "import nltk\n",
39 | "import random\n",
40 | "import numpy as np\n",
41 | "from collections import Counter, OrderedDict\n",
43 | "from copy import deepcopy\n",
44 | "import os\n",
45 | "import re\n",
46 | "import unicodedata\n",
47 | "flatten = lambda l: [item for sublist in l for item in sublist]\n",
48 | "\n",
49 | "from torch.nn.utils.rnn import PackedSequence, pack_padded_sequence\n",
50 | "random.seed(1024)"
51 | ]
52 | },
53 | {
54 | "cell_type": "code",
55 | "execution_count": 2,
56 | "metadata": {
57 | "collapsed": true
58 | },
59 | "outputs": [],
60 | "source": [
61 | "USE_CUDA = torch.cuda.is_available()\n",
62 | "gpus = [0]\n",
63 | "if USE_CUDA: torch.cuda.set_device(gpus[0])  # only select a GPU when CUDA is available\n",
64 | "\n",
65 | "FloatTensor = torch.cuda.FloatTensor if USE_CUDA else torch.FloatTensor\n",
66 | "LongTensor = torch.cuda.LongTensor if USE_CUDA else torch.LongTensor\n",
67 | "ByteTensor = torch.cuda.ByteTensor if USE_CUDA else torch.ByteTensor"
68 | ]
69 | },
70 | {
71 | "cell_type": "code",
72 | "execution_count": 3,
73 | "metadata": {
74 | "collapsed": true
75 | },
76 | "outputs": [],
77 | "source": [
78 | "def getBatch(batch_size, train_data):\n",
79 | "    random.shuffle(train_data)\n",
80 | "    sindex = 0\n",
81 | "    eindex = batch_size\n",
82 | "    while eindex < len(train_data):\n",
83 | "        batch = train_data[sindex: eindex]\n",
84 | "        temp = eindex\n",
85 | "        eindex = eindex + batch_size\n",
86 | "        sindex = temp\n",
87 | "        yield batch\n",
88 | "\n",
89 | "    if eindex >= len(train_data):\n",
90 | "        batch = train_data[sindex:]\n",
91 | "        yield batch"
92 | ]
93 | },
94 | {
95 | "cell_type": "code",
96 | "execution_count": 4,
97 | "metadata": {
98 | "collapsed": true
99 | },
100 | "outputs": [],
101 | "source": [
102 | "def pad_to_batch(batch, w_to_ix): # for bAbI dataset\n",
103 | "    fact,q,a = list(zip(*batch))\n",
104 | "    max_fact = max([len(f) for f in fact])\n",
105 | "    max_len = max([f.size(1) for f in flatten(fact)])\n",
106 | "    max_q = max([qq.size(1) for qq in q])\n",
107 | "    max_a = max([aa.size(1) for aa in a])\n",
108 | "\n",
109 | "    facts, fact_masks, q_p, a_p = [], [], [], []\n",
110 | "    for i in range(len(batch)):\n",
111 | "        fact_p_t = []\n",
112 | "        for j in range(len(fact[i])):\n",
113 | "            if fact[i][j].size(1) < max_len:\n",
114 | "                fact_p_t.append(torch.cat([fact[i][j], Variable(LongTensor([w_to_ix['<PAD>']] * (max_len - fact[i][j].size(1)))).view(1, -1)], 1))\n",
115 | "            else:\n",
116 | "                fact_p_t.append(fact[i][j])\n",
117 | "\n",
118 | "        while len(fact_p_t) < max_fact:\n",
119 | "            fact_p_t.append(Variable(LongTensor([w_to_ix['<PAD>']] * max_len)).view(1, -1))\n",
120 | "\n",
121 | "        fact_p_t = torch.cat(fact_p_t)\n",
122 | "        facts.append(fact_p_t)\n",
123 | "        fact_masks.append(torch.cat([Variable(ByteTensor(tuple(map(lambda s: s ==0, t.data))), volatile=False) for t in fact_p_t]).view(fact_p_t.size(0), -1))\n",
124 | "\n",
125 | "        if q[i].size(1) < max_q:\n",
126 | "            q_p.append(torch.cat([q[i], Variable(LongTensor([w_to_ix['<PAD>']] * (max_q - q[i].size(1)))).view(1, -1)], 1))\n",
127 | "        else:\n",
128 | "            q_p.append(q[i])\n",
129 | "\n",
130 | "        if a[i].size(1) < max_a:\n",
131 | "            a_p.append(torch.cat([a[i], Variable(LongTensor([w_to_ix['<PAD>']] * (max_a - a[i].size(1)))).view(1, -1)], 1))\n",
132 | "        else:\n",
133 | "            a_p.append(a[i])\n",
134 | "\n",
135 | "    questions = torch.cat(q_p)\n",
136 | "    answers = torch.cat(a_p)\n",
137 | "    question_masks = torch.cat([Variable(ByteTensor(tuple(map(lambda s: s ==0, t.data))), volatile=False) for t in questions]).view(questions.size(0), -1)\n",
138 | "\n",
139 | "    return facts, fact_masks, questions, question_masks, answers"
140 | ]
141 | },
142 | {
143 | "cell_type": "code",
144 | "execution_count": 5,
145 | "metadata": {
146 | "collapsed": true
147 | },
148 | "outputs": [],
149 | "source": [
150 | "def prepare_sequence(seq, to_index):\n",
151 | "    idxs = list(map(lambda w: to_index[w] if to_index.get(w) is not None else to_index[\"<UNK>\"], seq))\n",
152 | "    return Variable(LongTensor(idxs))"
153 | ]
154 | },
155 | {
156 | "cell_type": "markdown",
157 | "metadata": {},
158 | "source": [
159 | "## Data load and Preprocessing "
160 | ]
161 | },
162 | {
163 | "cell_type": "markdown",
164 | "metadata": {},
165 | "source": [
166 | "### bAbI dataset (https://research.fb.com/downloads/babi/)"
167 | ]
168 | },
169 | {
170 | "cell_type": "code",
171 | "execution_count": 24,
172 | "metadata": {
173 | "collapsed": true
174 | },
175 | "outputs": [],
176 | "source": [
177 | "def bAbI_data_load(path):\n",
178 | "    try:\n",
179 | "        data = open(path).readlines()\n",
180 | "    except IOError:\n",
181 | "        print(\"Such a file does not exist at {}\".format(path))\n",
182 | "        return None\n",
183 | "\n",
184 | "    data = [d[:-1] for d in data]  # strip the trailing newline\n",
185 | "    data_p = []\n",
186 | "    fact = []\n",
187 | "    qa = []\n",
188 | "    try:\n",
189 | "        for d in data:\n",
190 | "            index = d.split(' ')[0]\n",
191 | "            if index == '1':  # a line numbered 1 starts a new story\n",
192 | "                fact = []\n",
193 | "                qa = []\n",
194 | "            if '?' in d:  # question line: question \\t answer \\t supporting facts\n",
195 | "                temp = d.split('\\t')\n",
196 | "                q = temp[0].strip().replace('?', '').split(' ')[1:] + ['?']\n",
197 | "                a = temp[1].split() + ['</s>']\n",
198 | "                stemp = deepcopy(fact)\n",
199 | "                data_p.append([stemp, q, a])\n",
200 | "            else:\n",
201 | "                tokens = d.replace('.', '').split(' ')[1:] + ['</s>']\n",
202 | "                fact.append(tokens)\n",
203 | "    except Exception:\n",
204 | "        print(\"Please check the data is right\")\n",
205 | "        return None\n",
206 | "    return data_p"
207 | ]
208 | },
209 | {
210 | "cell_type": "code",
211 | "execution_count": 25,
212 | "metadata": {
213 | "collapsed": true
214 | },
215 | "outputs": [],
216 | "source": [
217 | "train_data = bAbI_data_load('../dataset/bAbI/en-10k/qa5_three-arg-relations_train.txt')"
218 | ]
219 | },
220 | {
221 | "cell_type": "code",
222 | "execution_count": 26,
223 | "metadata": {
224 | "collapsed": false
225 | },
226 | "outputs": [
227 | {
228 | "data": {
229 | "text/plain": [
230 | "[[['Bill', 'travelled', 'to', 'the', 'office', '</s>'],\n",
231 | " ['Bill', 'picked', 'up', 'the', 'football', 'there', '</s>'],\n",
232 | " ['Bill', 'went', 'to', 'the', 'bedroom', '</s>'],\n",
233 | " ['Bill', 'gave', 'the', 'football', 'to', 'Fred', '</s>']],\n",
234 | " ['What', 'did', 'Bill', 'give', 'to', 'Fred', '?'],\n",
235 | " ['football', '</s>']]"
236 | ]
237 | },
238 | "execution_count": 26,
239 | "metadata": {},
240 | "output_type": "execute_result"
241 | }
242 | ],
243 | "source": [
244 | "train_data[0]"
245 | ]
246 | },
247 | {
248 | "cell_type": "code",
249 | "execution_count": 11,
250 | "metadata": {
251 | "collapsed": true
252 | },
253 | "outputs": [],
254 | "source": [
255 | "fact,q,a = list(zip(*train_data))"
256 | ]
257 | },
258 | {
259 | "cell_type": "code",
260 | "execution_count": 12,
261 | "metadata": {
262 | "collapsed": true
263 | },
264 | "outputs": [],
265 | "source": [
266 | "vocab = list(set(flatten(flatten(fact)) + flatten(q) + flatten(a)))"
267 | ]
268 | },
269 | {
270 | "cell_type": "code",
271 | "execution_count": 13,
272 | "metadata": {
273 | "collapsed": true
274 | },
275 | "outputs": [],
276 | "source": [
277 | "word2index={'<PAD>': 0, '<UNK>': 1, '<s>': 2, '</s>': 3}\n",
278 | "for vo in vocab:\n",
279 | "    if word2index.get(vo) is None:\n",
280 | "        word2index[vo] = len(word2index)\n",
281 | "index2word = {v:k for k, v in word2index.items()}"
282 | ]
283 | },
284 | {
285 | "cell_type": "code",
286 | "execution_count": 14,
287 | "metadata": {
288 | "collapsed": false
289 | },
290 | "outputs": [
291 | {
292 | "data": {
293 | "text/plain": [
294 | "44"
295 | ]
296 | },
297 | "execution_count": 14,
298 | "metadata": {},
299 | "output_type": "execute_result"
300 | }
301 | ],
302 | "source": [
303 | "len(word2index)"
304 | ]
305 | },
306 | {
307 | "cell_type": "code",
308 | "execution_count": 15,
309 | "metadata": {
310 | "collapsed": true
311 | },
312 | "outputs": [],
313 | "source": [
314 | "for t in train_data:\n",
315 | "    for i,fact in enumerate(t[0]):\n",
316 | "        t[0][i] = prepare_sequence(fact, word2index).view(1, -1)\n",
317 | "\n",
318 | "    t[1] = prepare_sequence(t[1], word2index).view(1, -1)\n",
319 | "    t[2] = prepare_sequence(t[2], word2index).view(1, -1)"
320 | ]
321 | },
322 | {
323 | "cell_type": "markdown",
324 | "metadata": {},
325 | "source": [
326 | "## Modeling "
327 | ]
328 | },
329 | {
330 | "cell_type": "markdown",
331 | "metadata": {},
332 | "source": [
333 | "<img src=\"../images/10.dmn-architecture.png\">\n",
334 | "Image borrowed from https://arxiv.org/pdf/1506.07285.pdf"
335 | ]
336 | },
337 | {
338 | "cell_type": "code",
339 | "execution_count": 16,
340 | "metadata": {
341 | "collapsed": true
342 | },
343 | "outputs": [],
344 | "source": [
345 | "class DMN(nn.Module):\n",
346 | "    def __init__(self, input_size, hidden_size, output_size, dropout_p=0.1):\n",
347 | "        super(DMN, self).__init__()\n",
348 | "        \n",
349 | "        self.hidden_size = hidden_size\n",
350 | "        self.embed = nn.Embedding(input_size, hidden_size, padding_idx=0) #sparse=True)\n",
351 | "        self.input_gru = nn.GRU(hidden_size, hidden_size, batch_first=True)\n",
352 | "        self.question_gru = nn.GRU(hidden_size, hidden_size, batch_first=True)\n",
353 | "        \n",
354 | "        self.gate = nn.Sequential(\n",
355 | "            nn.Linear(hidden_size * 4, hidden_size),\n",
356 | "            nn.Tanh(),\n",
357 | "            nn.Linear(hidden_size, 1),\n",
358 | "            nn.Sigmoid()\n",
359 | "        )\n",
360 | "        \n",
361 | "        self.attention_grucell = nn.GRUCell(hidden_size, hidden_size)\n",
362 | "        self.memory_grucell = nn.GRUCell(hidden_size, hidden_size)\n",
363 | "        self.answer_grucell = nn.GRUCell(hidden_size * 2, hidden_size)\n",
364 | "        self.answer_fc = nn.Linear(hidden_size, output_size)\n",
365 | "        \n",
366 | "        self.dropout = nn.Dropout(dropout_p)\n",
367 | "        \n",
368 | "    def init_hidden(self, inputs):\n",
369 | "        hidden = Variable(torch.zeros(1, inputs.size(0), self.hidden_size))\n",
370 | "        return hidden.cuda() if USE_CUDA else hidden\n",
371 | "    \n",
372 | "    def init_weight(self):\n",
373 | "        nn.init.xavier_uniform(self.embed.state_dict()['weight'])\n",
374 | "        \n",
375 | "        for name, param in self.input_gru.state_dict().items():\n",
376 | "            if 'weight' in name: nn.init.xavier_normal(param)\n",
377 | "        for name, param in self.question_gru.state_dict().items():\n",
378 | "            if 'weight' in name: nn.init.xavier_normal(param)\n",
379 | "        for name, param in self.gate.state_dict().items():\n",
380 | "            if 'weight' in name: nn.init.xavier_normal(param)\n",
381 | "        for name, param in self.attention_grucell.state_dict().items():\n",
382 | "            if 'weight' in name: nn.init.xavier_normal(param)\n",
383 | "        for name, param in self.memory_grucell.state_dict().items():\n",
384 | "            if 'weight' in name: nn.init.xavier_normal(param)\n",
385 | "        for name, param in self.answer_grucell.state_dict().items():\n",
386 | "            if 'weight' in name: nn.init.xavier_normal(param)\n",
387 | "        \n",
388 | "        nn.init.xavier_normal(self.answer_fc.state_dict()['weight'])\n",
389 | "        self.answer_fc.bias.data.fill_(0)\n",
390 | "        \n",
391 | "    def forward(self, facts, fact_masks, questions, question_masks, num_decode, episodes=3, is_training=False):\n",
392 | "        \"\"\"\n",
393 | "        facts : (B,T_C,T_I) / LongTensor in List # batch_size, num_of_facts, length_of_each_fact(padded)\n",
394 | "        fact_masks : (B,T_C,T_I) / ByteTensor in List # batch_size, num_of_facts, length_of_each_fact(padded)\n",
395 | "        questions : (B,T_Q) / LongTensor # batch_size, question_length\n",
396 | "        question_masks : (B,T_Q) / ByteTensor # batch_size, question_length\n",
397 | "        \"\"\"\n",
398 | "        # Input Module\n",
399 | "        C = [] # encoded facts\n",
400 | "        for fact, fact_mask in zip(facts, fact_masks):\n",
401 | "            embeds = self.embed(fact)\n",
402 | "            if is_training:\n",
403 | "                embeds = self.dropout(embeds)\n",
404 | "            hidden = self.init_hidden(fact)\n",
405 | "            outputs, hidden = self.input_gru(embeds, hidden)\n",
406 | "            real_hidden = []\n",
407 | "\n",
408 | "            for i, o in enumerate(outputs): # B,T,D\n",
409 | "                real_length = fact_mask[i].data.tolist().count(0) \n",
410 | "                real_hidden.append(o[real_length - 1])\n",
411 | "\n",
412 | "            C.append(torch.cat(real_hidden).view(fact.size(0), -1).unsqueeze(0))\n",
413 | "        \n",
414 | "        encoded_facts = torch.cat(C) # B,T_C,D\n",
415 | "        \n",
416 | "        # Question Module\n",
417 | "        embeds = self.embed(questions)\n",
418 | "        if is_training:\n",
419 | "            embeds = self.dropout(embeds)\n",
420 | "        hidden = self.init_hidden(questions)\n",
421 | "        outputs, hidden = self.question_gru(embeds, hidden)\n",
422 | "        \n",
423 | "        if isinstance(question_masks, torch.autograd.variable.Variable):\n",
424 | "            real_question = []\n",
425 | "            for i, o in enumerate(outputs): # B,T,D\n",
426 | "                real_length = question_masks[i].data.tolist().count(0) \n",
427 | "                real_question.append(o[real_length - 1])\n",
428 | "            encoded_question = torch.cat(real_question).view(questions.size(0), -1) # B,D\n",
429 | "        else: # for inference mode\n",
430 | "            encoded_question = hidden.squeeze(0) # B,D\n",
431 | "        \n",
432 | "        # Episodic Memory Module\n",
433 | "        memory = encoded_question\n",
434 | "        T_C = encoded_facts.size(1)\n",
435 | "        B = encoded_facts.size(0)\n",
436 | "        for i in range(episodes):\n",
437 | "            hidden = self.init_hidden(encoded_facts.transpose(0, 1)[0]).squeeze(0) # B,D\n",
438 | "            for t in range(T_C):\n",
439 | "                #TODO: fact masking\n",
440 | "                #TODO: gate function => softmax\n",
441 | "                z = torch.cat([\n",
442 | "                    encoded_facts.transpose(0, 1)[t] * encoded_question, # B,D , element-wise product\n",
443 | "                    encoded_facts.transpose(0, 1)[t] * memory, # B,D , element-wise product\n",
444 | "                    torch.abs(encoded_facts.transpose(0,1)[t] - encoded_question), # B,D\n",
445 | "                    torch.abs(encoded_facts.transpose(0,1)[t] - memory) # B,D\n",
446 | "                ], 1)\n",
447 | "                g_t = self.gate(z) # B,1 scalar\n",
448 | "                hidden = g_t * self.attention_grucell(encoded_facts.transpose(0, 1)[t], hidden) + (1 - g_t) * hidden\n",
449 | "            \n",
450 | "            e = hidden\n",
451 | "            memory = self.memory_grucell(e, memory)\n",
452 | "        \n",
453 | "        # Answer Module\n",
454 | "        answer_hidden = memory\n",
455 | "        start_decode = Variable(LongTensor([[word2index['']] * memory.size(0)])).transpose(0, 1)\n",
456 | "        y_t_1 = self.embed(start_decode).squeeze(1) # B,D\n",
457 | "        \n",
458 | "        decodes = []\n",
459 | "        for t in range(num_decode):\n",
460 | "            answer_hidden = self.answer_grucell(torch.cat([y_t_1, encoded_question], 1), answer_hidden)\n",
461 | "            decodes.append(F.log_softmax(self.answer_fc(answer_hidden),1))\n",
462 | "        return torch.cat(decodes, 1).view(B * num_decode, -1)\n"
463 | ]
464 | },
465 | {
466 | "cell_type": "markdown",
467 | "metadata": {},
468 | "source": [
469 | "## Train "
470 | ]
471 | },
472 | {
473 | "cell_type": "markdown",
474 | "metadata": {},
475 | "source": [
476 | "Training takes a while if you use only the CPU."
477 | ]
478 | },
479 | {
480 | "cell_type": "code",
481 | "execution_count": 17,
482 | "metadata": {
483 | "collapsed": true
484 | },
485 | "outputs": [],
486 | "source": [
487 | "HIDDEN_SIZE = 80\n",
488 | "BATCH_SIZE = 64\n",
489 | "LR = 0.001\n",
490 | "EPOCH = 50\n",
491 | "NUM_EPISODE = 3\n",
492 | "EARLY_STOPPING = False"
493 | ]
494 | },
495 | {
496 | "cell_type": "code",
497 | "execution_count": 18,
498 | "metadata": {
499 | "collapsed": true
500 | },
501 | "outputs": [],
502 | "source": [
503 | "model = DMN(len(word2index), HIDDEN_SIZE, len(word2index))\n",
504 | "model.init_weight()\n",
505 | "if USE_CUDA:\n",
506 | " model = model.cuda()\n",
507 | "\n",
508 | "loss_function = nn.CrossEntropyLoss(ignore_index=0)\n",
509 | "optimizer = optim.Adam(model.parameters(), lr=LR)"
510 | ]
511 | },
512 | {
513 | "cell_type": "code",
514 | "execution_count": 19,
515 | "metadata": {
516 | "collapsed": false
517 | },
518 | "outputs": [
519 | {
520 | "name": "stdout",
521 | "output_type": "stream",
522 | "text": [
523 | "[0/50] mean_loss : 3.86\n",
524 | "[0/50] mean_loss : 1.32\n",
525 | "[1/50] mean_loss : 0.68\n",
526 | "[1/50] mean_loss : 0.65\n",
527 | "[2/50] mean_loss : 0.62\n",
528 | "[2/50] mean_loss : 0.65\n",
529 | "[3/50] mean_loss : 0.65\n",
530 | "[3/50] mean_loss : 0.64\n",
531 | "[4/50] mean_loss : 0.60\n",
532 | "[4/50] mean_loss : 0.62\n",
533 | "[5/50] mean_loss : 0.63\n",
534 | "[5/50] mean_loss : 0.61\n",
535 | "[6/50] mean_loss : 0.60\n",
536 | "[6/50] mean_loss : 0.61\n",
537 | "[7/50] mean_loss : 0.63\n",
538 | "[7/50] mean_loss : 0.60\n",
539 | "[8/50] mean_loss : 0.62\n",
540 | "[8/50] mean_loss : 0.60\n",
541 | "[9/50] mean_loss : 0.58\n",
542 | "[9/50] mean_loss : 0.60\n",
543 | "[10/50] mean_loss : 0.60\n",
544 | "[10/50] mean_loss : 0.60\n",
545 | "[11/50] mean_loss : 0.62\n",
546 | "[11/50] mean_loss : 0.60\n",
547 | "[12/50] mean_loss : 0.61\n",
548 | "[12/50] mean_loss : 0.60\n",
549 | "[13/50] mean_loss : 0.57\n",
550 | "[13/50] mean_loss : 0.60\n",
551 | "[14/50] mean_loss : 0.59\n",
552 | "[14/50] mean_loss : 0.60\n",
553 | "[15/50] mean_loss : 0.61\n",
554 | "[15/50] mean_loss : 0.60\n",
555 | "[16/50] mean_loss : 0.59\n",
556 | "[16/50] mean_loss : 0.60\n",
557 | "[17/50] mean_loss : 0.59\n",
558 | "[17/50] mean_loss : 0.60\n",
559 | "[18/50] mean_loss : 0.51\n",
560 | "[18/50] mean_loss : 0.50\n",
561 | "[19/50] mean_loss : 0.44\n",
562 | "[19/50] mean_loss : 0.37\n",
563 | "[20/50] mean_loss : 0.30\n",
564 | "[20/50] mean_loss : 0.33\n",
565 | "[21/50] mean_loss : 0.31\n",
566 | "[21/50] mean_loss : 0.31\n",
567 | "[22/50] mean_loss : 0.29\n",
568 | "[22/50] mean_loss : 0.31\n",
569 | "[23/50] mean_loss : 0.29\n",
570 | "[23/50] mean_loss : 0.31\n",
571 | "[24/50] mean_loss : 0.24\n",
572 | "[24/50] mean_loss : 0.31\n",
573 | "[25/50] mean_loss : 0.30\n",
574 | "[25/50] mean_loss : 0.30\n",
575 | "[26/50] mean_loss : 0.14\n",
576 | "[26/50] mean_loss : 0.16\n",
577 | "[27/50] mean_loss : 0.12\n",
578 | "[27/50] mean_loss : 0.15\n",
579 | "[28/50] mean_loss : 0.18\n",
580 | "[28/50] mean_loss : 0.14\n",
581 | "[29/50] mean_loss : 0.12\n",
582 | "[29/50] mean_loss : 0.14\n",
583 | "[30/50] mean_loss : 0.14\n",
584 | "[30/50] mean_loss : 0.14\n",
585 | "[31/50] mean_loss : 0.13\n",
586 | "[31/50] mean_loss : 0.14\n",
587 | "[32/50] mean_loss : 0.11\n",
588 | "[32/50] mean_loss : 0.13\n",
589 | "[33/50] mean_loss : 0.08\n",
590 | "[33/50] mean_loss : 0.06\n",
591 | "[34/50] mean_loss : 0.01\n",
592 | "[34/50] mean_loss : 0.03\n",
593 | "[35/50] mean_loss : 0.01\n",
594 | "Early Stopping!\n"
595 | ]
596 | }
597 | ],
598 | "source": [
599 | "for epoch in range(EPOCH):\n",
600 | "    losses = []\n",
601 | "    if EARLY_STOPPING: \n",
602 | "        break\n",
603 | "    \n",
604 | "    for i,batch in enumerate(getBatch(BATCH_SIZE, train_data)):\n",
605 | "        facts, fact_masks, questions, question_masks, answers = pad_to_batch(batch, word2index)\n",
606 | "        \n",
607 | "        model.zero_grad()\n",
608 | "        pred = model(facts, fact_masks, questions, question_masks, answers.size(1), NUM_EPISODE, True)\n",
609 | "        loss = loss_function(pred, answers.view(-1))\n",
610 | "        losses.append(loss.data.tolist()[0])\n",
611 | "        \n",
612 | "        loss.backward()\n",
613 | "        optimizer.step()\n",
614 | "        \n",
615 | "        if i % 100 == 0:\n",
616 | "            print(\"[%d/%d] mean_loss : %0.2f\" %(epoch, EPOCH, np.mean(losses)))\n",
617 | "            \n",
618 | "            if np.mean(losses) < 0.01:\n",
619 | "                EARLY_STOPPING = True\n",
620 | "                print(\"Early Stopping!\")\n",
621 | "                break\n",
622 | "            losses = []"
623 | ]
624 | },
625 | {
626 | "cell_type": "markdown",
627 | "metadata": {},
628 | "source": [
629 | "## Test "
630 | ]
631 | },
632 | {
633 | "cell_type": "code",
634 | "execution_count": 21,
635 | "metadata": {
636 | "collapsed": true
637 | },
638 | "outputs": [],
639 | "source": [
640 | "def pad_to_fact(fact, x_to_ix): # this is for inference\n",
641 | "    \n",
642 | "    max_x = max([s.size(1) for s in fact])\n",
643 | "    x_p = []\n",
644 | "    for i in range(len(fact)):\n",
645 | "        if fact[i].size(1) < max_x:\n",
646 | "            x_p.append(torch.cat([fact[i], Variable(LongTensor([x_to_ix['']] * (max_x - fact[i].size(1)))).view(1, -1)], 1))\n",
647 | "        else:\n",
648 | "            x_p.append(fact[i])\n",
649 | "        \n",
650 | "    fact = torch.cat(x_p)\n",
651 | "    fact_mask = torch.cat([Variable(ByteTensor(tuple(map(lambda s: s ==0, t.data))), volatile=False) for t in fact]).view(fact.size(0), -1)\n",
652 | "    return fact, fact_mask"
653 | ]
654 | },
655 | {
656 | "cell_type": "markdown",
657 | "metadata": {},
658 | "source": [
659 | "### Prepare Test data "
660 | ]
661 | },
662 | {
663 | "cell_type": "code",
664 | "execution_count": 27,
665 | "metadata": {
666 | "collapsed": true
667 | },
668 | "outputs": [],
669 | "source": [
670 | "test_data = bAbI_data_load('../dataset/bAbI/en-10k/qa5_three-arg-relations_test.txt')"
671 | ]
672 | },
673 | {
674 | "cell_type": "code",
675 | "execution_count": 28,
676 | "metadata": {
677 | "collapsed": true
678 | },
679 | "outputs": [],
680 | "source": [
681 | "for t in test_data:\n",
682 | "    for i, fact in enumerate(t[0]):\n",
683 | "        t[0][i] = prepare_sequence(fact, word2index).view(1, -1)\n",
684 | "    \n",
685 | "    t[1] = prepare_sequence(t[1], word2index).view(1, -1)\n",
686 | "    t[2] = prepare_sequence(t[2], word2index).view(1, -1)"
687 | ]
688 | },
689 | {
690 | "cell_type": "markdown",
691 | "metadata": {},
692 | "source": [
693 | "### Accuracy "
694 | ]
695 | },
696 | {
697 | "cell_type": "code",
698 | "execution_count": 31,
699 | "metadata": {
700 | "collapsed": true
701 | },
702 | "outputs": [],
703 | "source": [
704 | "accuracy = 0"
705 | ]
706 | },
707 | {
708 | "cell_type": "code",
709 | "execution_count": 32,
710 | "metadata": {
711 | "collapsed": false
712 | },
713 | "outputs": [
714 | {
715 | "name": "stdout",
716 | "output_type": "stream",
717 | "text": [
718 | "97.39999999999999\n"
719 | ]
720 | }
721 | ],
722 | "source": [
723 | "for t in test_data:\n",
724 | "    fact, fact_mask = pad_to_fact(t[0], word2index)\n",
725 | "    question = t[1]\n",
726 | "    question_mask = Variable(ByteTensor([0] * t[1].size(1)), volatile=False).unsqueeze(0)\n",
727 | "    answer = t[2].squeeze(0)\n",
728 | "    \n",
729 | "    model.zero_grad()\n",
730 | "    pred = model([fact], [fact_mask], question, question_mask, answer.size(0), NUM_EPISODE)\n",
731 | "    if pred.max(1)[1].data.tolist() == answer.data.tolist():\n",
732 | "        accuracy += 1\n",
733 | "\n",
734 | "print(accuracy/len(test_data) * 100)"
735 | ]
736 | },
737 | {
738 | "cell_type": "markdown",
739 | "metadata": {},
740 | "source": [
741 | "### Sample test result "
742 | ]
743 | },
744 | {
745 | "cell_type": "code",
746 | "execution_count": 34,
747 | "metadata": {
748 | "collapsed": false
749 | },
750 | "outputs": [
751 | {
752 | "name": "stdout",
753 | "output_type": "stream",
754 | "text": [
755 | "Facts : \n",
756 | "Bill went back to the bedroom \n",
757 | "Mary went to the office \n",
758 | "Jeff journeyed to the kitchen \n",
759 | "Fred journeyed to the kitchen \n",
760 | "Fred got the milk there \n",
761 | "Fred handed the milk to Jeff \n",
762 | "Jeff passed the milk to Fred \n",
763 | "Fred gave the milk to Jeff \n",
764 | "\n",
765 | "Question : Who received the milk ?\n",
766 | "\n",
767 | "Answer : Jeff \n",
768 | "Prediction : Jeff \n"
769 | ]
770 | }
771 | ],
772 | "source": [
773 | "t = random.choice(test_data)\n",
774 | "fact, fact_mask = pad_to_fact(t[0], word2index)\n",
775 | "question = t[1]\n",
776 | "question_mask = Variable(ByteTensor([0] * t[1].size(1)), volatile=False).unsqueeze(0)\n",
777 | "answer = t[2].squeeze(0)\n",
778 | "\n",
779 | "model.zero_grad()\n",
780 | "pred = model([fact], [fact_mask], question, question_mask, answer.size(0), NUM_EPISODE)\n",
781 | "\n",
782 | "print(\"Facts : \")\n",
783 | "print('\\n'.join([' '.join(list(map(lambda x: index2word[x],f))) for f in fact.data.tolist()]))\n",
784 | "print(\"\")\n",
785 | "print(\"Question : \",' '.join(list(map(lambda x: index2word[x], question.data.tolist()[0]))))\n",
786 | "print(\"\")\n",
787 | "print(\"Answer : \",' '.join(list(map(lambda x: index2word[x], answer.data.tolist()))))\n",
788 | "print(\"Prediction : \",' '.join(list(map(lambda x: index2word[x], pred.max(1)[1].data.tolist()))))"
789 | ]
790 | },
791 | {
792 | "cell_type": "markdown",
793 | "metadata": {
794 | "collapsed": true
795 | },
796 | "source": [
797 | "## Further topics "
798 | ]
799 | },
800 | {
801 | "cell_type": "markdown",
802 | "metadata": {},
803 | "source": [
804 | "* Dynamic Memory Networks for Visual and Textual Question Answering (DMN+)\n",
805 | "* DMN+ PyTorch implementation\n",
806 | "* Dynamic Coattention Networks For Question Answering\n",
807 | "* DCN+: Mixed Objective and Deep Residual Coattention for Question Answering"
808 | ]
809 | },
810 | {
811 | "cell_type": "code",
812 | "execution_count": null,
813 | "metadata": {
814 | "collapsed": true
815 | },
816 | "outputs": [],
817 | "source": []
818 | }
819 | ],
820 | "metadata": {
821 | "kernelspec": {
822 | "display_name": "Python 3",
823 | "language": "python",
824 | "name": "python3"
825 | },
826 | "language_info": {
827 | "codemirror_mode": {
828 | "name": "ipython",
829 | "version": 3
830 | },
831 | "file_extension": ".py",
832 | "mimetype": "text/x-python",
833 | "name": "python",
834 | "nbconvert_exporter": "python",
835 | "pygments_lexer": "ipython3",
836 | "version": "3.5.2"
837 | }
838 | },
839 | "nbformat": 4,
840 | "nbformat_minor": 2
841 | }
842 |
--------------------------------------------------------------------------------
/script/docker-compose.yml:
--------------------------------------------------------------------------------
1 | cs224n:
2 |   image: dsksd/deepstudy:0.2
3 |   ports:
4 |     - "8888:8888"
5 |     - "6006:6006"
6 |   volumes:
7 |     - ../:/notebooks/work
8 |   command: /run.sh
9 |   tty: true
--------------------------------------------------------------------------------
/script/prepare_dataset.sh:
--------------------------------------------------------------------------------
1 | #!/bin/sh
2 |
3 | OUT_DIR="${1:-../dataset}"
4 |
5 | mkdir -v -p "$OUT_DIR"
6 |
7 | echo "download machine translation dataset from http://www.manythings.org/anki/..."
8 | curl -o "$OUT_DIR/fra-eng.zip" "http://www.manythings.org/anki/fra-eng.zip"
9 | unzip "$OUT_DIR/fra-eng.zip" -d "$OUT_DIR"
10 | rm "$OUT_DIR/fra-eng.zip"
11 | rm "$OUT_DIR/_about.txt"
12 | mv "$OUT_DIR/fra.txt" "$OUT_DIR/eng-fra.txt"
13 |
14 | echo "download ptb dataset... (clone from https://github.com/tomsercu/lstm/tree/master/data)"
15 | mkdir -v -p "$OUT_DIR/ptb"
16 | wget "https://raw.githubusercontent.com/tomsercu/lstm/master/data/ptb.train.txt" -P "$OUT_DIR/ptb"
17 | wget "https://raw.githubusercontent.com/tomsercu/lstm/master/data/ptb.valid.txt" -P "$OUT_DIR/ptb"
18 | wget "https://raw.githubusercontent.com/tomsercu/lstm/master/data/ptb.test.txt" -P "$OUT_DIR/ptb"
19 |
20 | echo "download TREC question dataset..."
21 | curl -o "$OUT_DIR/train_5500.label.txt" "http://cogcomp.org/Data/QA/QC/train_5500.label"
22 |
23 | echo "download Stanford sentiment treebank..."
24 | curl -o "$OUT_DIR/trainDevTestTrees_PTB.zip" "https://nlp.stanford.edu/sentiment/trainDevTestTrees_PTB.zip"
25 | unzip "$OUT_DIR/trainDevTestTrees_PTB.zip" -d "$OUT_DIR"
26 | rm "$OUT_DIR/trainDevTestTrees_PTB.zip"
27 |
28 | echo "download bAbI dataset..."
29 | curl -o "$OUT_DIR/tasks_1-20_v1-2.tar.gz" "http://www.thespermwhale.com/jaseweston/babi/tasks_1-20_v1-2.tar.gz"
30 | tar zxvf "$OUT_DIR/tasks_1-20_v1-2.tar.gz"
31 | mv "tasks_1-20_v1-2" "$OUT_DIR/bAbI"
32 | rm "$OUT_DIR/tasks_1-20_v1-2.tar.gz"
33 |
34 | echo "download nltk dataset..."
35 | python3 -c "import nltk;nltk.download('gutenberg');nltk.download('brown');nltk.download('conll2002');nltk.download('timit')"
36 |
37 | echo "download dependency parser dataset... (clone from https://github.com/rguthrie3/DeepDependencyParsingProblemSet)"
38 | mkdir -v -p "$OUT_DIR/dparser"
39 | wget "https://raw.githubusercontent.com/rguthrie3/DeepDependencyParsingProblemSet/master/data/train.txt" -P "$OUT_DIR/dparser"
40 | wget "https://raw.githubusercontent.com/rguthrie3/DeepDependencyParsingProblemSet/master/data/vocab.txt" -P "$OUT_DIR/dparser"
41 | wget "https://raw.githubusercontent.com/rguthrie3/DeepDependencyParsingProblemSet/master/data/dev.txt" -P "$OUT_DIR/dparser"
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from distutils.core import setup
2 | from Cython.Build import cythonize
3 | from Cython.Distutils import build_ext
4 | from distutils.extension import Extension
5 |
6 | setup(name="Corpus",cmdclass = {'build_ext': build_ext},ext_modules = [Extension("Corpus", ["Corpus.pyx"],language='c++')])
--------------------------------------------------------------------------------