├── 1_pretrained_vectors.ipynb
├── 2_context_vectors.ipynb
├── 3_finetuning.ipynb
└── README.md
/1_pretrained_vectors.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "nbformat": 4,
3 | "nbformat_minor": 0,
4 | "metadata": {
5 | "colab": {
6 | "name": "1-pretrained-vectors.ipynb",
7 | "version": "0.3.2",
8 | "provenance": [],
9 | "include_colab_link": true
10 | },
11 | "kernelspec": {
12 | "name": "python3",
13 | "display_name": "Python 3"
14 | },
15 | "accelerator": "GPU"
16 | },
17 | "cells": [
28 | {
29 | "cell_type": "markdown",
30 | "metadata": {
31 | "id": "GkZM-nI_5Qug",
32 | "colab_type": "text"
33 | },
34 | "source": [
35 | "# Part I: Pre-trained embeddings\n",
36 | "\n",
37 | "We are going to build some PyTorch models that are commonly used for text classification. We also need to build out some infrastructure to run these models.\n",
38 | "\n",
39 | "Once the models and the boilerplate are out of the way, we can measure the impact of pre-trained embeddings on classification tasks. Pre-training methods like word2vec are context-limited language models whose goal is to predict a word given a fixed context, or a fixed context given a word. Pre-trained embeddings are particularly useful for smaller datasets.\n",
40 | "\n",
41 | "Most of this code is inspired by or derived from [Baseline](https://github.com/dpressel/baseline/) (Pressel et al., 2018), an open source project for building and evaluating NLP models across a variety of NLP tasks. For this tutorial, we will only concern ourselves with text classification using a few useful models.\n",
42 | "\n",
43 | "## Word Embeddings in NLP\n",
44 | "\n",
45 | "We start our models with what are called \"one-hot\" vectors. A one-hot vector is notionally a sparse vector of length `|V|`, where `V` is our vocabulary. Only the entry for the word at the current position (time step) is 1, and the rest are zeros.\n",
46 | "\n",
47 | "\n",
48 | "\n",
49 | "In practice these vectors are not stored as dense vectors. Instead, they are stored as an array of indices (in PyTorch, a `torch.LongTensor`) holding each word's index in the vocabulary. This representation is not particularly helpful in DNNs, since we want continuous representations for each word.\n",
50 | "\n",
51 | "\n",
52 | "\n",
53 | "The general idea of an embedding is to project from a large one-hot vector to a compact, distributed representation with much smaller dimensionality. We can view this as multiplying a one-hot vector of length `|V|` by a `|V| x |D|` weight matrix, which produces a vector of size `|D|`. Since only a single entry of the one-hot vector is on at a time, this matrix multiply reduces to looking up a single row of the matrix.\n",
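"\n",
"As a quick illustration, the toy sketch below uses made-up sizes (`V=10`, `D=4`) and a random `nn.Embedding`. Multiplying explicit one-hot vectors by the embedding's weight matrix gives the same result as the lookup itself:\n",
"\n",
"```python\n",
"import torch\n",
"import torch.nn as nn\n",
"import torch.nn.functional as F\n",
"\n",
"V, D = 10, 4  # toy vocabulary size and embedding size\n",
"emb = nn.Embedding(V, D)\n",
"ids = torch.tensor([3, 7, 1])  # a three-word sentence as vocabulary indices\n",
"\n",
"# Explicit one-hot route: (3 x V) @ (V x D) -> (3 x D)\n",
"one_hot = F.one_hot(ids, num_classes=V).float()\n",
"via_matmul = one_hot @ emb.weight\n",
"\n",
"# Lookup route: select rows 3, 7 and 1 of the same weight matrix\n",
"via_lookup = emb(ids)\n",
"\n",
"print(torch.allclose(via_matmul, via_lookup))  # True\n",
"```\n",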
54 | "\n",
55 | "\n",
56 | "\n",
57 | "\n",
58 | "In PyTorch, this is called an `nn.Embedding`. In Torch7, it was called `nn.LookupTable`, which is arguably a better name, but that term seems to have fallen out of favor in DNN toolkits. In this tutorial we will refer to several types of embeddings. Here, we mean word vectors, which are typically lookup-table embeddings.\n",
59 | "\n",
60 | "Embeddings make up the lowest layer of a typical DNN for text. We will feed their output to some pooling mechanism that yields a fixed-length representation, followed by some number of fully connected layers.\n",
61 | "\n",
62 | "### Pre-training with Word2Vec\n",
63 | "\n",
64 | "A large body of research has gone into building distributed representations of words through pre-training. Widely used algorithms in NLP include Word2Vec, GloVe and fastText. word2vec itself is actually two different algorithms with two different objectives. Both can be thought of as non-causal language models over a fixed context window, but they are shallow models and extremely fast to train.\n",
65 | "\n",
66 | "\n",
67 | "\n",
68 | "\n",
69 | "* **CBOW objective**: given all words in a fixed context window except the middle word, predict the middle word\n",
70 | "* **SkipGram objective**: given a word in a fixed context window, predict all other words in that window\n",
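"\n",
"The toy sketch below makes the two objectives concrete. It extracts training examples from one sentence using a window of 2 words on each side. It only builds the (input, target) pairs; actual word2vec training also relies on tricks such as negative sampling or a hierarchical softmax, which are omitted here.\n",
"\n",
"```python\n",
"def training_pairs(tokens, window=2):\n",
"    # Toy illustration: build CBOW and SkipGram examples from a list of tokens\n",
"    cbow, skipgram = [], []\n",
"    for i, center in enumerate(tokens):\n",
"        lo, hi = max(0, i - window), min(len(tokens), i + window + 1)\n",
"        context = [tokens[j] for j in range(lo, hi) if j != i]\n",
"        cbow.append((context, center))                 # context -> middle word\n",
"        skipgram.extend((center, c) for c in context)  # middle word -> each context word\n",
"    return cbow, skipgram\n",
"\n",
"cbow, sg = training_pairs('the movie was really good'.split())\n",
"print(cbow[2])  # (['the', 'movie', 'really', 'good'], 'was')\n",
"print(sg[:2])   # [('the', 'movie'), ('the', 'was')]\n",
"```\n",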
71 | "\n",
72 | "Once we have trained these models, the learned distributed representation matrix can be plugged right in as our embedding weights and this often improves the model significantly.\n",
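"\n",
"The sketch below shows one way to do this in PyTorch, under some assumptions:\n",
"\n",
"* `pretrained` and `vocab` are hypothetical toy stand-ins for the GoogleNews vectors and our task vocabulary.\n",
"* Words missing from the pre-trained table get small random vectors. The `[-0.25, 0.25]` range is an arbitrary choice here.\n",
"* The padding row is kept at zero.\n",
"\n",
"```python\n",
"import numpy as np\n",
"import torch\n",
"import torch.nn as nn\n",
"\n",
"# Hypothetical stand-ins for a real pre-trained table and a real task vocabulary\n",
"pretrained = {'good': np.ones(3), 'movie': np.full(3, 2.0)}\n",
"vocab = {'<PAD>': 0, 'good': 1, 'movie': 2, 'zyzzyva': 3}\n",
"dim = 3\n",
"\n",
"# Random init for every row, then zero the padding row\n",
"weights = np.random.uniform(-0.25, 0.25, (len(vocab), dim)).astype(np.float32)\n",
"weights[0] = 0.0\n",
"# Overwrite the rows of words that have a pre-trained vector\n",
"for word, idx in vocab.items():\n",
"    if word in pretrained:\n",
"        weights[idx] = pretrained[word]\n",
"\n",
"# freeze=False lets the vectors be fine-tuned along with the classifier\n",
"embeddings = nn.Embedding.from_pretrained(torch.from_numpy(weights), freeze=False)\n",
"print(embeddings(torch.tensor([1, 2])))\n",
"```\n",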
73 | "\n",
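74 | "Before we begin, we will download the data for our experiments. This includes the SST-2 and TREC classification datasets, plus the pre-trained word2vec vectors trained on Google News."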
75 | ]
76 | },
77 | {
78 | "cell_type": "code",
79 | "metadata": {
80 | "id": "rIAhO5RRf4jL",
81 | "colab_type": "code",
82 | "colab": {
83 | "base_uri": "https://localhost:8080/",
84 | "height": 1000
85 | },
86 | "outputId": "64239e21-a3cd-45dc-998a-bb80b2bdb0ae"
87 | },
88 | "source": [
89 | "!wget https://www.dropbox.com/s/7jyi4pi894bh2qh/sst2.tar.gz?dl=1\n",
90 | "!tar -xzf 'sst2.tar.gz?dl=1'\n",
91 | "\n",
92 | "!wget https://www.dropbox.com/s/08km2ean8bkt7p3/trec.tar.gz?dl=1\n",
93 | "!tar -xzf 'trec.tar.gz?dl=1'\n",
94 | "\n",
95 | "!wget https://www.dropbox.com/s/699kgut7hdb5tg9/GoogleNews-vectors-negative300.bin.gz?dl=1\n",
96 | "!mv 'GoogleNews-vectors-negative300.bin.gz?dl=1' GoogleNews-vectors-negative300.bin.gz\n",
97 | "!gunzip GoogleNews-vectors-negative300.bin.gz"
98 | ],
99 | "execution_count": 1,
100 | "outputs": [
101 | {
102 | "output_type": "stream",
103 | "text": [
104 | "--2019-06-30 01:12:58-- https://www.dropbox.com/s/7jyi4pi894bh2qh/sst2.tar.gz?dl=1\n",
105 | "Resolving www.dropbox.com (www.dropbox.com)... 162.125.81.1, 2620:100:6031:1::a27d:5101\n",
106 | "Connecting to www.dropbox.com (www.dropbox.com)|162.125.81.1|:443... connected.\n",
107 | "HTTP request sent, awaiting response... 301 Moved Permanently\n",
108 | "Location: /s/dl/7jyi4pi894bh2qh/sst2.tar.gz [following]\n",
109 | "--2019-06-30 01:12:58-- https://www.dropbox.com/s/dl/7jyi4pi894bh2qh/sst2.tar.gz\n",
110 | "Reusing existing connection to www.dropbox.com:443.\n",
111 | "HTTP request sent, awaiting response... 302 Found\n",
112 | "Location: https://uc04e0a125eff9ec13572a9a6a9f.dl.dropboxusercontent.com/cd/0/get/AjzPSePvGb0BHoTcpOMyaVFbMZnr4Sv3S4e7fGsv-it1K8QAMhv9okwYdZyizFJPI1YLLmzxNzCo6_aqlZZdzy91IzKq8CoytjxPrD1RMFg2vw/file?dl=1# [following]\n",
113 | "--2019-06-30 01:12:59-- https://uc04e0a125eff9ec13572a9a6a9f.dl.dropboxusercontent.com/cd/0/get/AjzPSePvGb0BHoTcpOMyaVFbMZnr4Sv3S4e7fGsv-it1K8QAMhv9okwYdZyizFJPI1YLLmzxNzCo6_aqlZZdzy91IzKq8CoytjxPrD1RMFg2vw/file?dl=1\n",
114 | "Resolving uc04e0a125eff9ec13572a9a6a9f.dl.dropboxusercontent.com (uc04e0a125eff9ec13572a9a6a9f.dl.dropboxusercontent.com)... 162.125.81.6, 2620:100:6031:6::a27d:5106\n",
115 | "Connecting to uc04e0a125eff9ec13572a9a6a9f.dl.dropboxusercontent.com (uc04e0a125eff9ec13572a9a6a9f.dl.dropboxusercontent.com)|162.125.81.6|:443... connected.\n",
116 | "HTTP request sent, awaiting response... 200 OK\n",
117 | "Length: 1759259 (1.7M) [application/binary]\n",
118 | "Saving to: ‘sst2.tar.gz?dl=1’\n",
119 | "\n",
120 | "sst2.tar.gz?dl=1 100%[===================>] 1.68M 3.94MB/s in 0.4s \n",
121 | "\n",
122 | "2019-06-30 01:13:00 (3.94 MB/s) - ‘sst2.tar.gz?dl=1’ saved [1759259/1759259]\n",
123 | "\n",
124 | "--2019-06-30 01:13:04-- https://www.dropbox.com/s/08km2ean8bkt7p3/trec.tar.gz?dl=1\n",
125 | "Resolving www.dropbox.com (www.dropbox.com)... 162.125.81.1, 2620:100:6031:1::a27d:5101\n",
126 | "Connecting to www.dropbox.com (www.dropbox.com)|162.125.81.1|:443... connected.\n",
127 | "HTTP request sent, awaiting response... 301 Moved Permanently\n",
128 | "Location: /s/dl/08km2ean8bkt7p3/trec.tar.gz [following]\n",
129 | "--2019-06-30 01:13:04-- https://www.dropbox.com/s/dl/08km2ean8bkt7p3/trec.tar.gz\n",
130 | "Reusing existing connection to www.dropbox.com:443.\n",
131 | "HTTP request sent, awaiting response... 302 Found\n",
132 | "Location: https://uc8eb0b5de3f5bf250a5403c857d.dl.dropboxusercontent.com/cd/0/get/AjxktPxIg9dMjFh6y7djMj1wZ2zZfxHGq7BdLeSom_HCGa5lE95AMt4DOMcQTJtqIf72KG3bWidAK0oxCdRDXGtZ4PnqYLPRIhln_ne9KgJ50g/file?dl=1# [following]\n",
133 | "--2019-06-30 01:13:04-- https://uc8eb0b5de3f5bf250a5403c857d.dl.dropboxusercontent.com/cd/0/get/AjxktPxIg9dMjFh6y7djMj1wZ2zZfxHGq7BdLeSom_HCGa5lE95AMt4DOMcQTJtqIf72KG3bWidAK0oxCdRDXGtZ4PnqYLPRIhln_ne9KgJ50g/file?dl=1\n",
134 | "Resolving uc8eb0b5de3f5bf250a5403c857d.dl.dropboxusercontent.com (uc8eb0b5de3f5bf250a5403c857d.dl.dropboxusercontent.com)... 162.125.81.6, 2620:100:6031:6::a27d:5106\n",
135 | "Connecting to uc8eb0b5de3f5bf250a5403c857d.dl.dropboxusercontent.com (uc8eb0b5de3f5bf250a5403c857d.dl.dropboxusercontent.com)|162.125.81.6|:443... connected.\n",
136 | "HTTP request sent, awaiting response... 200 OK\n",
137 | "Length: 117253 (115K) [application/binary]\n",
138 | "Saving to: ‘trec.tar.gz?dl=1’\n",
139 | "\n",
140 | "trec.tar.gz?dl=1 100%[===================>] 114.50K --.-KB/s in 0.01s \n",
141 | "\n",
142 | "2019-06-30 01:13:05 (9.31 MB/s) - ‘trec.tar.gz?dl=1’ saved [117253/117253]\n",
143 | "\n",
144 | "--2019-06-30 01:13:07-- https://www.dropbox.com/s/699kgut7hdb5tg9/GoogleNews-vectors-negative300.bin.gz?dl=1\n",
145 | "Resolving www.dropbox.com (www.dropbox.com)... 162.125.81.1, 2620:100:6031:1::a27d:5101\n",
146 | "Connecting to www.dropbox.com (www.dropbox.com)|162.125.81.1|:443... connected.\n",
147 | "HTTP request sent, awaiting response... 301 Moved Permanently\n",
148 | "Location: /s/dl/699kgut7hdb5tg9/GoogleNews-vectors-negative300.bin.gz [following]\n",
149 | "--2019-06-30 01:13:08-- https://www.dropbox.com/s/dl/699kgut7hdb5tg9/GoogleNews-vectors-negative300.bin.gz\n",
150 | "Reusing existing connection to www.dropbox.com:443.\n",
151 | "HTTP request sent, awaiting response... 302 Found\n",
152 | "Location: https://ucc94c6cb27db59d15946e726c0e.dl.dropboxusercontent.com/cd/0/get/AjxtsvhYKA888x7WZQOTs9rdfEmxm3sk44V8o6XNM13xter70pPBaST3EOBeTVxBoSq76DyBy_ZdiawefmaObDzO0NuQa6g3qWhzEuQ-6iHVEA/file?dl=1# [following]\n",
153 | "--2019-06-30 01:13:08-- https://ucc94c6cb27db59d15946e726c0e.dl.dropboxusercontent.com/cd/0/get/AjxtsvhYKA888x7WZQOTs9rdfEmxm3sk44V8o6XNM13xter70pPBaST3EOBeTVxBoSq76DyBy_ZdiawefmaObDzO0NuQa6g3qWhzEuQ-6iHVEA/file?dl=1\n",
154 | "Resolving ucc94c6cb27db59d15946e726c0e.dl.dropboxusercontent.com (ucc94c6cb27db59d15946e726c0e.dl.dropboxusercontent.com)... 162.125.81.6, 2620:100:6031:6::a27d:5106\n",
155 | "Connecting to ucc94c6cb27db59d15946e726c0e.dl.dropboxusercontent.com (ucc94c6cb27db59d15946e726c0e.dl.dropboxusercontent.com)|162.125.81.6|:443... connected.\n",
156 | "HTTP request sent, awaiting response... 200 OK\n",
157 | "Length: 1743563840 (1.6G) [application/binary]\n",
158 | "Saving to: ‘GoogleNews-vectors-negative300.bin.gz?dl=1’\n",
159 | "\n",
160 | "GoogleNews-vectors- 100%[===================>] 1.62G 44.3MB/s in 43s \n",
161 | "\n",
162 | "2019-06-30 01:13:52 (38.5 MB/s) - ‘GoogleNews-vectors-negative300.bin.gz?dl=1’ saved [1743563840/1743563840]\n",
163 | "\n"
164 | ],
165 | "name": "stdout"
166 | }
167 | ]
168 | },
179 | {
180 | "cell_type": "markdown",
181 | "metadata": {
182 | "id": "u_ZR8A0x6XQh",
183 | "colab_type": "text"
184 | },
185 | "source": [
186 | "## First, let's do some fun stuff\n",
187 | "\n",
188 | "We will start by building out some models that we will reuse later in the tutorial. First, we will build a convolutional neural network (CNN) that can classify text. In essence, a CNN learns kernels (filters) that slide over an image or a text sequence to produce feature maps. An example of 2D filtering*:\n",
189 | "\n",
190 | "\n",
191 | "\n",
192 | "In the case of text filtering, we have a one-dimensional filter operation like this*:\n",
193 | "\n",
194 | "\n",
195 | "\n",
196 | "This type of model has often been applied to text, including by [Collobert et al. 2011](https://ronan.collobert.com/pub/matos/2011_nlp_jmlr.pdf). We will implement the variant with multiple parallel filters introduced by [Kim 2014](https://www.aclweb.org/anthology/D14-1181).\n",
197 | "\n",
198 | "### Convolutional Neural Network for Text Classification\n",
199 | "\n",
200 | "We are using PyTorch, so every layer we define will inherit from `nn.Module`.\n",
201 | "\n",
202 | "#### Convolutions (actually cross correlations)\n",
203 | "\n",
204 | "The first characteristic of this model is that we will have multiple convolutional filter lengths, with some number of filters for each length. A filter of length `K` convolved with a signal of length `T` produces an output of length `T - K + 1`. The filter hangs off the ends of the signal (e.g. when centered at position 0), so we add zero-padding there. With `K // 2` zeros on each end, a filter of length `K=3` gets a single pad value on both ends, and the output length stays `T`. For even `K` this padding yields `T + 1` outputs, which is harmless here because we pool over time anyway.\n",
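"\n",
"A toy sketch with random data can check these lengths. It compares unpadded convolutions with `padding=K // 2`. The sizes `B=2`, `C=300`, `T=7` and the 100 output channels are arbitrary.\n",
"\n",
"```python\n",
"import torch\n",
"import torch.nn as nn\n",
"\n",
"B, C, T = 2, 300, 7  # batch size, input channels, sequence length (toy values)\n",
"x = torch.randn(B, C, T)\n",
"for K in (2, 3, 4):\n",
"    unpadded = nn.Conv1d(C, 100, K)(x)\n",
"    padded = nn.Conv1d(C, 100, K, padding=K // 2)(x)\n",
"    # Output length is T + 2 * pad - K + 1\n",
"    print(K, unpadded.shape[-1], padded.shape[-1])\n",
"# Prints:\n",
"# 2 6 8\n",
"# 3 5 7\n",
"# 4 4 8\n",
"```\n",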
205 | "\n",
206 | "We are going to support multiple parallel filters, so we add a `torch.nn.Conv1d` for each filter length, each followed by a `torch.nn.ReLU` activation. Since we have more than one of these, we keep them in an `nn.ModuleList` so that PyTorch registers their parameters. When we call `forward()`, the data will be oriented as $$B \\times C \\times T$$ where `B` is the batch size, `C` is the number of channels (at the input, the embedding dimension) and `T` is the sequence length.\n",
207 | "\n",
208 | "#### Pooling\n",
209 | "\n",
210 | "Both papers mentioned above use max-over-time pooling. For each feature map, we select its maximum value over time, then concatenate the results from all of the filters. Each $$B \\times C \\times T$$ tensor is thus reduced along the time dimension to $$B \\times C$$.\n",
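"\n",
"For example, take three filter banks of 100 feature maps each, represented here by random toy tensors instead of real convolution outputs. Max-over-time pooling followed by concatenation produces a single 300-dimensional vector per example:\n",
"\n",
"```python\n",
"import torch\n",
"\n",
"B, T = 2, 7\n",
"# Stand-ins for the outputs of three filter banks, each B x 100 x T\n",
"feature_maps = [torch.randn(B, 100, T) for _ in range(3)]\n",
"pooled = [fm.max(2)[0] for fm in feature_maps]  # max over time: each B x 100\n",
"features = torch.cat(pooled, 1)                 # concatenate: B x 300\n",
"print(features.shape)  # torch.Size([2, 300])\n",
"```\n",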
211 | "\n",
212 | "*Images courtesy of http://cs231n.github.io/convolutional-networks/"
213 | ]
214 | },
215 | {
216 | "cell_type": "code",
217 | "metadata": {
218 | "id": "PTT1BSfB9kKO",
219 | "colab_type": "code",
220 | "colab": {}
221 | },
222 | "source": [
223 | "import numpy as np\n",
224 | "import torch\n",
225 | "import torch.nn as nn\n",
226 | "import torch.nn.functional as F\n",
227 | "from typing import List, Tuple\n",
228 | "import os\n",
229 | "import io\n",
230 | "import re\n",
231 | "import codecs\n",
"import csv\n",
232 | "from collections import Counter, OrderedDict\n",
233 | "from torch.utils.data import DataLoader, TensorDataset"
234 | ],
235 | "execution_count": 0,
236 | "outputs": []
237 | },
238 | {
239 | "cell_type": "code",
240 | "metadata": {
241 | "id": "rD0S18xj9qi9",
242 | "colab_type": "code",
243 | "colab": {}
244 | },
245 | "source": [
246 | "class ParallelConv(nn.Module):\n",
247 | "\n",
248 | "    def __init__(self, input_dims, filters, dropout=0.5):\n",
249 | "        super().__init__()\n",
250 | "        convs = []\n",
251 | "        self.output_dims = sum([t[1] for t in filters])\n",
252 | "        for (filter_length, output_dims) in filters:\n",
253 | "            pad = filter_length // 2  # zero-pad both ends so edge positions are covered\n",
254 | "            conv = nn.Sequential(\n",
255 | "                nn.Conv1d(input_dims, output_dims, filter_length, padding=pad),\n",
256 | "                nn.ReLU()\n",
257 | "            )\n",
258 | "            convs.append(conv)\n",
259 | "        # Wrap the convolutions in nn.ModuleList so they are registered and managed correctly\n",
260 | "        self.convs = nn.ModuleList(convs)\n",
261 | "        self.conv_drop = nn.Dropout(dropout)\n",
262 | "\n",
263 | "\n",
264 | "    def forward(self, input_bct):\n",
265 | "        mots = []\n",
266 | "        for conv in self.convs:\n",
267 | "            # In Conv1d, data is BxCxT; take the max over time (dim 2)\n",
268 | "            conv_out = conv(input_bct)\n",
269 | "            mot, _ = conv_out.max(2)\n",
270 | "            mots.append(mot)\n",
271 | "        mots = torch.cat(mots, 1)\n",
272 | "        return self.conv_drop(mots)\n",
273 | "\n",
274 | "class ConvClassifier(nn.Module):\n",
275 | "\n",
276 | "    def __init__(self, embeddings, num_classes, embed_dims,\n",
277 | "                 filters=[(2, 100), (3, 100), (4, 100)],\n",
278 | "                 dropout=0.5, hidden_units=[]):\n",
279 | "        super().__init__()\n",
280 | "        self.embeddings = embeddings\n",
281 | "        self.dropout = nn.Dropout(dropout)\n",
282 | "        self.convs = ParallelConv(embed_dims, filters, dropout)\n",
283 | "\n",
284 | "        input_units = self.convs.output_dims\n",
285 | "        output_units = self.convs.output_dims\n",
286 | "        sequence = []\n",
287 | "        for h in hidden_units:\n",
288 | "            sequence.extend([nn.Linear(input_units, h), nn.Dropout(dropout)])\n",
289 | "            input_units = h\n",
290 | "            output_units = h\n",
291 | "\n",
292 | "        sequence.append(nn.Linear(output_units, num_classes))\n",
293 | "        self.outputs = nn.Sequential(*sequence)\n",
294 | "\n",
295 | "    def forward(self, inputs):\n",
296 | "        one_hots, lengths = inputs\n",
297 | "        embed = self.dropout(self.embeddings(one_hots))  # B x T x C\n",
298 | "        embed = embed.transpose(1, 2).contiguous()  # B x C x T, as Conv1d expects\n",
299 | "        hidden = self.convs(embed)\n",
300 | "        linear = self.outputs(hidden)\n",
301 | "        return F.log_softmax(linear, dim=-1)\n"
302 | ],
303 | "execution_count": 0,
304 | "outputs": []
305 | },
306 | {
307 | "cell_type": "markdown",
308 | "metadata": {
309 | "id": "lBN5z1huBPqr",
310 | "colab_type": "text"
311 | },
312 | "source": [
313 | "### LSTM Model\n",
314 | "\n",
315 | "Our second model uses Long Short-Term Memory (LSTM) units, a type of recurrent neural network (RNN). LSTMs tend to perform very well on many NLP tasks. Text classification is a simple case: we feed in our inputs and take the final LSTM hidden state as a form of pooling. This corresponds to the **Many-to-One** image in this taxonomy from [Andrej Karpathy's 2015 blog post on using RNNs for character-level language modeling](https://karpathy.github.io/2015/05/21/rnn-effectiveness/)\n",
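"\n",
"The toy sketch below (random inputs, arbitrary sizes) shows why packing matters for this kind of pooling. With `pack_padded_sequence`, each sequence's final hidden state is taken at its true last token, not at a padding position. It assumes time-major inputs and lengths sorted longest-first, as in our classifier.\n",
"\n",
"```python\n",
"import torch\n",
"import torch.nn as nn\n",
"from torch.nn.utils.rnn import pack_padded_sequence\n",
"\n",
"torch.manual_seed(0)\n",
"lstm = nn.LSTM(input_size=4, hidden_size=3)  # time-major: inputs are T x B x D\n",
"x = torch.randn(5, 2, 4)                      # T=5, B=2; the second sequence is padded\n",
"lengths = torch.tensor([5, 3])                # true lengths, longest first\n",
"\n",
"_, (h_n, _) = lstm(pack_padded_sequence(x, lengths))\n",
"final = h_n[-1]  # B x H: last layer's state at each sequence's final real step\n",
"\n",
"# Running the second sequence alone, without padding, gives the same state\n",
"_, (h_alone, _) = lstm(x[:3, 1:2].contiguous())\n",
"print(torch.allclose(final[1], h_alone[-1, 0], atol=1e-6))  # True\n",
"```\n",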
316 | "\n",
317 | ""
318 | ]
319 | },
320 | {
321 | "cell_type": "code",
322 | "metadata": {
323 | "id": "aeQUDe_BCOYV",
324 | "colab_type": "code",
325 | "colab": {}
326 | },
327 | "source": [
328 | "class LSTMClassifier(nn.Module):\n",
329 | "\n",
330 | "    def __init__(self, embeddings, num_classes, embed_dims, rnn_units, rnn_layers=1, dropout=0.5, hidden_units=[]):\n",
331 | "        super().__init__()\n",
332 | "        self.embeddings = embeddings\n",
333 | "        self.dropout = nn.Dropout(dropout)\n",
334 | "        self.rnn = torch.nn.LSTM(embed_dims,\n",
335 | "                                 rnn_units,\n",
336 | "                                 rnn_layers,\n",
337 | "                                 dropout=dropout if rnn_layers > 1 else 0.0,  # LSTM dropout only applies between layers\n",
338 | "                                 bidirectional=False,\n",
339 | "                                 batch_first=False)\n",
340 | "        nn.init.orthogonal_(self.rnn.weight_hh_l0)\n",
341 | "        nn.init.orthogonal_(self.rnn.weight_ih_l0)\n",
342 | "        sequence = []\n",
343 | "        input_units = rnn_units\n",
344 | "        output_units = rnn_units\n",
345 | "        for h in hidden_units:\n",
346 | "            sequence.append(nn.Linear(input_units, h))\n",
347 | "            input_units = h\n",
348 | "            output_units = h\n",
349 | "\n",
350 | "        sequence.append(nn.Linear(output_units, num_classes))\n",
351 | "        self.outputs = nn.Sequential(*sequence)\n",
352 | "\n",
353 | "\n",
354 | "    def forward(self, inputs):\n",
355 | "        one_hots, lengths = inputs\n",
356 | "        embed = self.dropout(self.embeddings(one_hots))\n",
357 | "        embed = embed.transpose(0, 1)  # T x B x C: the LSTM is time-major\n",
358 | "        packed = torch.nn.utils.rnn.pack_padded_sequence(embed, lengths.tolist())\n",
359 | "        _, hidden = self.rnn(packed)\n",
360 | "        hidden = hidden[0][-1]  # h_n of the last layer: B x rnn_units\n",
361 | "        linear = self.outputs(hidden)\n",
362 | "        return F.log_softmax(linear, dim=-1)\n"
363 | ],
364 | "execution_count": 0,
365 | "outputs": []
366 | },
367 | {
368 | "cell_type": "markdown",
369 | "metadata": {
370 | "id": "JK47ZmfPCgyz",
371 | "colab_type": "text"
372 | },
373 | "source": [
374 | "## Training our model\n",
375 | "\n",
376 | "To set our model up for training (and evaluation), we need a loss function, some metrics, and an optimizer, along with some training data.\n",
377 | "\n",
378 | "### Defining Metrics\n",
379 | "\n",
380 | "For classification problems, most things we would like to know can be defined in terms of a confusion matrix.\n",
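"\n",
"Here is a small worked example with made-up counts. It shows how the main metrics fall out of a 2-class confusion matrix whose rows are the true labels and whose columns are the predictions:\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"# Toy counts: rows = truth, columns = guess\n",
"cm = np.array([[50, 10],   # truth = negative\n",
"               [5, 35]])   # truth = positive\n",
"accuracy = np.trace(cm) / cm.sum()          # correct / total\n",
"precision = np.diag(cm) / cm.sum(axis=0)    # per class: correct / predicted as that class\n",
"recall = np.diag(cm) / cm.sum(axis=1)       # per class: correct / actually that class\n",
"f1 = 2 * precision * recall / (precision + recall)\n",
"print(accuracy)  # 0.85\n",
"print(precision[1], recall[1], f1[1])  # positive class: 35/45, 35/40 and their harmonic mean\n",
"```\n",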
381 | "\n",
382 | "\n",
383 | "\n",
384 | "The class below implements a confusion matrix and provides metrics computed from it. This implementation is taken, nearly verbatim, from Baseline (https://github.com/dpressel/baseline/blob/master/python/baseline/confusion.py)"
385 | ]
386 | },
387 | {
388 | "cell_type": "code",
389 | "metadata": {
390 | "id": "kDOWjUTGQaAH",
391 | "colab_type": "code",
392 | "colab": {}
393 | },
394 | "source": [
395 | "\n",
396 | "class ConfusionMatrix:\n",
397 | "    \"\"\"Confusion matrix with metrics\n",
398 | "\n",
399 | "    This class accumulates classification output, and tracks it in a confusion matrix.\n",
400 | "    Metrics are available that use the confusion matrix\n",
401 | "    \"\"\"\n",
402 | "    def __init__(self, labels):\n",
403 | "        \"\"\"Constructor with input labels\n",
404 | "\n",
405 | "        :param labels: Either a dictionary (`k=int,v=str`) or an array of labels\n",
406 | "        \"\"\"\n",
407 | "        if type(labels) is dict:\n",
408 | "            self.labels = []\n",
409 | "            for i in range(len(labels)):\n",
410 | "                self.labels.append(labels[i])\n",
411 | "        else:\n",
412 | "            self.labels = labels\n",
413 | "        nc = len(self.labels)\n",
414 | "        self._cm = np.zeros((nc, nc), dtype=np.int64)\n",
415 | "\n",
416 | "    def add(self, truth, guess):\n",
417 | "        \"\"\"Add a single value to the confusion matrix based off `truth` and `guess`\n",
418 | "\n",
419 | "        :param truth: The real `y` value (or ground truth label)\n",
420 | "        :param guess: The guess for `y` value (or assertion)\n",
421 | "        \"\"\"\n",
422 | "\n",
423 | "        self._cm[truth, guess] += 1\n",
424 | "\n",
425 | "    def __str__(self):\n",
426 | "        values = []\n",
427 | "        width = max(8, max(len(x) for x in self.labels) + 1)\n",
428 | "        for i, label in enumerate([''] + self.labels):\n",
429 | "            values += [\"{:>{width}}\".format(label, width=width+1)]\n",
430 | "        values += ['\\n']\n",
431 | "        for i, label in enumerate(self.labels):\n",
432 | "            values += [\"{:>{width}}\".format(label, width=width+1)]\n",
433 | "            for j in range(len(self.labels)):\n",
434 | "                values += [\"{:{width}d}\".format(self._cm[i, j], width=width + 1)]\n",
435 | "            values += ['\\n']\n",
436 | "        values += ['\\n']\n",
437 | "        return ''.join(values)\n",
438 | "\n",
439 | "    def save(self, outfile):\n",
440 | "        ordered_fieldnames = OrderedDict([(\"labels\", None)] + [(l, None) for l in self.labels])\n",
441 | "        with open(outfile, 'w') as f:\n",
442 | "            dw = csv.DictWriter(f, delimiter=',', fieldnames=ordered_fieldnames)\n",
443 | "            dw.writeheader()\n",
444 | "            for index, row in enumerate(self._cm):\n",
445 | "                row_dict = {l: row[i] for i, l in enumerate(self.labels)}\n",
446 | "                row_dict.update({\"labels\": self.labels[index]})\n",
447 | "                dw.writerow(row_dict)\n",
448 | "\n",
449 | "    def reset(self):\n",
450 | "        \"\"\"Reset the matrix\n",
451 | "        \"\"\"\n",
452 | "        self._cm *= 0\n",
453 | "\n",
454 | "    def get_correct(self):\n",
455 | "        \"\"\"Get the diagonals of the confusion matrix\n",
456 | "\n",
457 | "        :return: (``int``) Number of correct classifications\n",
458 | "        \"\"\"\n",
459 | "        return self._cm.diagonal().sum()\n",
460 | "\n",
461 | "    def get_total(self):\n",
462 | "        \"\"\"Get total classifications\n",
463 | "\n",
464 | "        :return: (``int``) total classifications\n",
465 | "        \"\"\"\n",
466 | "        return self._cm.sum()\n",
467 | "\n",
468 | "    def get_acc(self):\n",
469 | "        \"\"\"Get the accuracy\n",
470 | "\n",
471 | "        :return: (``float``) accuracy\n",
472 | "        \"\"\"\n",
473 | "        return float(self.get_correct())/self.get_total()\n",
474 | "\n",
475 | "    def get_recall(self):\n",
476 | "        \"\"\"Get the recall\n",
477 | "\n",
478 | "        :return: (``float``) recall\n",
479 | "        \"\"\"\n",
480 | "        total = np.sum(self._cm, axis=1)\n",
481 | "        total = (total == 0) + total\n",
482 | "        return np.diag(self._cm) / total.astype(float)\n",
483 | "\n",
484 | "    def get_support(self):\n",
485 | "        return np.sum(self._cm, axis=1)\n",
486 | "\n",
487 | "    def get_precision(self):\n",
488 | "        \"\"\"Get the precision\n",
489 | "        :return: (``float``) precision\n",
490 | "        \"\"\"\n",
491 | "\n",
492 | "        total = np.sum(self._cm, axis=0)\n",
493 | "        total = (total == 0) + total\n",
494 | "        return np.diag(self._cm) / total.astype(float)\n",
495 | "\n",
496 | "    def get_mean_precision(self):\n",
497 | "        \"\"\"Get the mean precision across labels\n",
498 | "\n",
499 | "        :return: (``float``) mean precision\n",
500 | "        \"\"\"\n",
501 | "        return np.mean(self.get_precision())\n",
502 | "\n",
503 | "    def get_weighted_precision(self):\n",
504 | "        return np.sum(self.get_precision() * self.get_support())/float(self.get_total())\n",
505 | "\n",
506 | "    def get_mean_recall(self):\n",
507 | "        \"\"\"Get the mean recall across labels\n",
508 | "\n",
509 | "        :return: (``float``) mean recall\n",
510 | "        \"\"\"\n",
511 | "        return np.mean(self.get_recall())\n",
512 | "\n",
513 | "    def get_weighted_recall(self):\n",
514 | "        return np.sum(self.get_recall() * self.get_support())/float(self.get_total())\n",
515 | "\n",
516 | "    def get_weighted_f(self, beta=1):\n",
517 | "        return np.sum(self.get_class_f(beta) * self.get_support())/float(self.get_total())\n",
518 | "\n",
519 | "    def get_macro_f(self, beta=1):\n",
520 | "        \"\"\"Get the macro F_b, with adjustable beta (defaulting to F1)\n",
521 | "\n",
522 | "        :param beta: (``float``) defaults to 1 (F1)\n",
523 | "        :return: (``float``) macro F_b\n",
524 | "        \"\"\"\n",
525 | "        if beta < 0:\n",
526 | "            raise Exception('Beta must be greater than 0')\n",
527 | "        return np.mean(self.get_class_f(beta))\n",
528 | "\n",
529 | "    def get_class_f(self, beta=1):\n",
530 | "        p = self.get_precision()\n",
531 | "        r = self.get_recall()\n",
532 | "\n",
533 | "        b = beta*beta\n",
534 | "        d = (b * p + r)\n",
535 | "        d = (d == 0) + d\n",
536 | "\n",
537 | "        return (b + 1) * p * r / d\n",
538 | "\n",
539 | "    def get_f(self, beta=1):\n",
540 | "        \"\"\"Get 2 class F_b, with adjustable beta (defaulting to F1)\n",
541 | "\n",
542 | "        :param beta: (``float``) defaults to 1 (F1)\n",
543 | "        :return: (``float``) 2-class F_b\n",
544 | "        \"\"\"\n",
545 | "        p = self.get_precision()[1]\n",
546 | "        r = self.get_recall()[1]\n",
547 | "        if beta < 0:\n",
548 | "            raise Exception('Beta must be greater than 0')\n",
549 | "        d = (beta*beta * p + r)\n",
550 | "        if d == 0:\n",
551 | "            return 0\n",
552 | "        return (beta*beta + 1) * p * r / d\n",
553 | "\n",
554 | "    def get_all_metrics(self):\n",
555 | "        \"\"\"Make a map of metrics suitable for reporting, keyed by metric name\n",
556 | "\n",
557 | "        :return: (``dict``) Map of metrics keyed by metric names\n",
558 | "        \"\"\"\n",
559 | "        metrics = {'acc': self.get_acc()}\n",
560 | "        # If 2 class, assume second class is positive AKA 1\n",
561 | "        if len(self.labels) == 2:\n",
562 | "            metrics['precision'] = self.get_precision()[1]\n",
563 | "            metrics['recall'] = self.get_recall()[1]\n",
564 | "            metrics['f1'] = self.get_f(1)\n",
565 | "        else:\n",
566 | "            metrics['mean_precision'] = self.get_mean_precision()\n",
567 | "            metrics['mean_recall'] = self.get_mean_recall()\n",
568 | "            metrics['macro_f1'] = self.get_macro_f(1)\n",
569 | "        metrics['weighted_precision'] = self.get_weighted_precision()\n",
570 | "        metrics['weighted_recall'] = self.get_weighted_recall()\n",
571 | "        metrics['weighted_f1'] = self.get_weighted_f(1)\n",
572 | "        return metrics\n",
573 | "\n",
574 | "    def add_batch(self, truth, guess):\n",
575 | "        \"\"\"Add a batch of data to the confusion matrix\n",
576 | "\n",
577 | "        :param truth: The truth tensor\n",
578 | "        :param guess: The guess tensor\n",
579 | "        :return:\n",
580 | "        \"\"\"\n",
581 | "        for truth_i, guess_i in zip(truth, guess):\n",
582 | "            self.add(truth_i, guess_i)\n"
583 | ],
584 | "execution_count": 0,
585 | "outputs": []
586 | },
587 | {
588 | "cell_type": "markdown",
589 | "metadata": {
590 | "id": "IOfC-eQR92j9",
591 | "colab_type": "text"
592 | },
593 | "source": [
594 | "Our `Trainer` is simple, but it gets the job done. We will use PyTorch's `DataLoader` to feed batches to the trainer. The `run()` method runs a single epoch.\n",
595 | "For every batch, `update()` first sorts the examples by length, longest first, because the LSTM's `pack_padded_sequence` requires that order. It then performs one minibatch gradient update and returns the loss, the predictions and the ground truth to `run()` for tabulation. Note that the trainer assumes a GPU at `cuda:0`."
596 | ]
597 | },
598 | {
599 | "cell_type": "code",
600 | "metadata": {
601 | "id": "FaQ2BraAR-qI",
602 | "colab_type": "code",
603 | "colab": {}
604 | },
605 | "source": [
606 | "\n",
607 | "class Trainer:\n",
608 | "    def __init__(self, optimizer: torch.optim.Optimizer):\n",
609 | "        self.optimizer = optimizer\n",
610 | "\n",
611 | "    def run(self, model, labels, train, loss, batch_size):\n",
612 | "        model.train()\n",
613 | "        train_loader = DataLoader(train, batch_size=batch_size, shuffle=True)\n",
614 | "\n",
615 | "        cm = ConfusionMatrix(labels)\n",
616 | "\n",
617 | "        for batch in train_loader:\n",
618 | "            loss_value, y_pred, y_actual = self.update(model, loss, batch)\n",
619 | "            _, best = y_pred.max(1)\n",
620 | "            yt = y_actual.cpu().int().numpy()\n",
621 | "            yp = best.cpu().int().numpy()\n",
622 | "            cm.add_batch(yt, yp)\n",
623 | "\n",
624 | "        print(cm.get_all_metrics())\n",
625 | "        return cm\n",
626 | "\n",
627 | "    def update(self, model, loss, batch):\n",
628 | "        self.optimizer.zero_grad()\n",
629 | "        x, lengths, y = batch\n",
630 | "        lengths, perm_idx = lengths.sort(0, descending=True)  # longest first, for pack_padded_sequence\n",
631 | "        x_sorted = x[perm_idx]\n",
632 | "        y_sorted = y[perm_idx]\n",
633 | "        y_sorted = y_sorted.to('cuda:0')\n",
634 | "        inputs = (x_sorted.to('cuda:0'), lengths)  # lengths stay on the CPU\n",
635 | "        y_pred = model(inputs)\n",
636 | "        loss_value = loss(y_pred, y_sorted)\n",
637 | "        loss_value.backward()\n",
638 | "        self.optimizer.step()\n",
639 | "        return loss_value.item(), y_pred, y_sorted"
640 | ],
641 | "execution_count": 0,
642 | "outputs": []
643 | },
644 | {
645 | "cell_type": "markdown",
646 | "metadata": {
647 | "id": "kTV2epU5SZlg",
648 | "colab_type": "text"
649 | },
650 | "source": [
651 | "After training an epoch, we would like to measure validation performance. Our `Evaluator` is similar to our `Trainer`, but it doesn't update the model. It runs inference under `torch.no_grad()` and tabulates the results in a confusion matrix."
652 | ]
653 | },
654 | {
655 | "cell_type": "code",
656 | "metadata": {
657 | "id": "baYbdBXqSoa6",
658 | "colab_type": "code",
659 | "colab": {}
660 | },
661 | "source": [
662 | "\n",
663 | "class Evaluator:\n",
664 | "    def __init__(self):\n",
665 | "        pass\n",
666 | "\n",
667 | "    def run(self, model, labels, dataset, batch_size=1):\n",
668 | "        model.eval()\n",
669 | "        valid_loader = DataLoader(dataset, batch_size=batch_size)\n",
670 | "        cm = ConfusionMatrix(labels)\n",
671 | "        for batch in valid_loader:\n",
672 | "            y_pred, y_actual = self.inference(model, batch)\n",
673 | "            _, best = y_pred.max(1)\n",
674 | "            yt = y_actual.cpu().int().numpy()\n",
675 | "            yp = best.cpu().int().numpy()\n",
676 | "            cm.add_batch(yt, yp)\n",
677 | "        return cm\n",
678 | "\n",
679 | "    def inference(self, model, batch):\n",
680 | "        with torch.no_grad():\n",
681 | "            x, lengths, y = batch\n",
682 | "            lengths, perm_idx = lengths.sort(0, descending=True)\n",
683 | "            x_sorted = x[perm_idx]\n",
684 | "            y_sorted = y[perm_idx]\n",
685 | "            y_sorted = y_sorted.to('cuda:0')\n",
686 | "            inputs = (x_sorted.to('cuda:0'), lengths)\n",
687 | "            y_pred = model(inputs)\n",
688 | "        return y_pred, y_sorted\n"
689 | ],
690 | "execution_count": 0,
691 | "outputs": []
692 | },
693 | {
694 | "cell_type": "markdown",
695 | "metadata": {
696 | "id": "ySvFi6OVRnf5",
697 | "colab_type": "text"
698 | },
699 | "source": [
700 | "We can encapsulate training for multiple epochs, plus testing, in a single function. The best model is defined in terms of some metric (here, validation accuracy), and we only save a checkpoint when the model improves on it. At the end, the best checkpoint is reloaded and evaluated on the test set. This is a form of early stopping, and is particularly helpful on smaller datasets."
701 | ]
702 | },
703 | {
704 | "cell_type": "code",
705 | "metadata": {
706 | "id": "mOhAZmODS3N8",
707 | "colab_type": "code",
708 | "colab": {}
709 | },
710 | "source": [
711 | "def fit(model, labels, optimizer, loss, epochs, batch_size, train, valid, test):\n",
712 | "\n",
713 | " trainer = Trainer(optimizer)\n",
714 | " evaluator = Evaluator()\n",
715 | " best_acc = 0.0\n",
716 | " \n",
717 | " for epoch in range(epochs):\n",
718 | " print('EPOCH {}'.format(epoch + 1))\n",
719 | " print('=================================')\n",
720 | " print('Training Results')\n",
721 | " cm = trainer.run(model, labels, train, loss, batch_size)\n",
722 | " print('Validation Results')\n",
723 | " cm = evaluator.run(model, labels, valid)\n",
724 | " print(cm.get_all_metrics())\n",
725 | " if cm.get_acc() > best_acc:\n",
726 | " print('New best model {:.2f}'.format(cm.get_acc()))\n",
727 | " best_acc = cm.get_acc()\n",
728 | " torch.save(model.state_dict(), './checkpoint.pth')\n",
729 | " if test is not None:\n",
730 | " model.load_state_dict(torch.load('./checkpoint.pth'))\n",
731 | " cm = evaluator.run(model, labels, test)\n",
732 | " print('Final result')\n",
733 | " print(cm.get_all_metrics())\n",
734 | " return cm.get_acc()"
735 | ],
736 | "execution_count": 0,
737 | "outputs": []
738 | },
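{
"cell_type": "markdown",
"metadata": {
"colab_type": "text"
},
"source": [
"The `fit` function above runs every epoch and only checkpoints the best one. A stricter variant of early stopping also *halts* training once the validation metric has failed to improve for a set number of epochs, often called the *patience*. Below is a minimal, framework-free sketch of that bookkeeping, replayed on the (rounded) validation accuracies of the randomly initialized CNN run later in this notebook. The `EarlyStopping` class and its `patience` parameter are hypothetical helpers for illustration; they are not part of our pipeline or of PyTorch. With `patience=3`, the sketch stops after epoch 9 and keeps the epoch 6 checkpoint, which is the same checkpoint `fit` keeps after all 10 epochs."
]
},
{
"cell_type": "code",
"metadata": {
"colab_type": "code",
"colab": {}
},
"source": [
"class EarlyStopping:\n",
"    # Hypothetical helper: tracks the best validation score and signals when to stop\n",
"\n",
"    def __init__(self, patience=2):\n",
"        self.patience = patience  # epochs to wait without improvement\n",
"        self.best = float('-inf')\n",
"        self.bad_epochs = 0\n",
"\n",
"    def step(self, score):\n",
"        # Record one epoch's validation score; return (improved, should_stop)\n",
"        if score > self.best:\n",
"            self.best = score\n",
"            self.bad_epochs = 0\n",
"            return True, False\n",
"        self.bad_epochs += 1\n",
"        return False, self.bad_epochs >= self.patience\n",
"\n",
"\n",
"# Validation accuracies (rounded) of the randomly initialized CNN run below\n",
"val_accs = [0.711, 0.609, 0.714, 0.708, 0.649, 0.759, 0.743, 0.679, 0.755, 0.759]\n",
"stopper = EarlyStopping(patience=3)\n",
"for epoch, acc in enumerate(val_accs, start=1):\n",
"    improved, stop = stopper.step(acc)\n",
"    if improved:\n",
"        print('epoch {}: new best {:.3f} (save a checkpoint here)'.format(epoch, acc))\n",
"    if stop:\n",
"        print('epoch {}: no improvement for {} epochs, stopping'.format(epoch, stopper.patience))\n",
"        break"
],
"execution_count": 0,
"outputs": []
},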
739 | {
740 | "cell_type": "markdown",
741 | "metadata": {
742 | "id": "mp1tjK47TCe_",
743 | "colab_type": "text"
744 | },
745 | "source": [
746 | "### A Reader for our Data\n",
747 | "\n",
748 | "We need a reader to load our data from files and put it into a `Dataset`.\n",
749 | "\n",
750 | "The reader needs to perform a few steps:\n",
751 | "\n",
752 | "* **read in sentences and labels**: it should tokenize each sentence and record the set of labels that occur in the data\n",
753 | "* **vectorize tokens**: it should convert tokens into integer tensors that form the rows of our `TensorDataset`\n",
754 | "* **tabulate the vocabulary**: if no vectorizer is provided, we need to build a vocabulary of the words attested in the data. If a vectorizer is provided upfront, we don't need this step"
755 | ]
756 | },
757 | {
758 | "cell_type": "code",
759 | "metadata": {
760 | "id": "Uzf5KiQ1UFRU",
761 | "colab_type": "code",
762 | "colab": {}
763 | },
764 | "source": [
765 | "\n",
766 | "def whitespace_tokenizer(words: str) -> List[str]:\n",
767 | " return words.split() \n",
768 | "\n",
769 | "def sst2_tokenizer(words: str) -> List[str]:\n",
770 | " REPLACE = { \"'s\": \" 's \",\n",
771 | " \"'ve\": \" 've \",\n",
772 | " \"n't\": \" n't \",\n",
773 | " \"'re\": \" 're \",\n",
774 | " \"'d\": \" 'd \",\n",
775 | " \"'ll\": \" 'll \",\n",
776 | " \",\": \" , \",\n",
777 | " \"!\": \" ! \",\n",
778 | " }\n",
779 | " words = words.lower()\n",
780 | " words = re.sub(r\"[^A-Za-z0-9(),!?\\'\\`]\", \" \", words)\n",
781 | " for k, v in REPLACE.items():\n",
782 | " words = words.replace(k, v)\n",
783 | " return [w.strip() for w in words.split()]\n",
784 | "\n",
785 | "\n",
786 | "class Reader:\n",
787 | "\n",
788 | " def __init__(self, files, lowercase=True, min_freq=0,\n",
789 | " tokenizer=sst2_tokenizer, vectorizer=None):\n",
790 | " self.lowercase = lowercase\n",
791 | " self.tokenizer = tokenizer\n",
792 | " build_vocab = vectorizer is None\n",
793 | " self.vectorizer = vectorizer if vectorizer else self._vectorizer\n",
794 | " x = Counter()\n",
795 | " y = Counter()\n",
796 | " for file_name in files:\n",
797 | " if file_name is None:\n",
798 | " continue\n",
799 | " with codecs.open(file_name, encoding='utf-8', mode='r') as f:\n",
800 | " for line in f:\n",
801 | " words = line.split()\n",
802 | " y.update([words[0]])  # wrap in a list so the whole label is counted, not its characters\n",
803 | "\n",
804 | " if build_vocab:\n",
805 | " words = self.tokenizer(' '.join(words[1:]))\n",
806 | " words = words if not self.lowercase else [w.lower() for w in words]\n",
807 | " x.update(words)\n",
808 | " self.labels = list(y.keys())\n",
809 | "\n",
810 | " if build_vocab:\n",
811 | " x = dict(filter(lambda cnt: cnt[1] >= min_freq, x.items()))\n",
812 | " alpha = list(x.keys())\n",
813 | " alpha.sort()\n",
814 | " self.vocab = {w: i+1 for i, w in enumerate(alpha)}\n",
815 | " self.vocab['[PAD]'] = 0\n",
816 | "\n",
817 | " self.labels.sort()\n",
818 | "\n",
819 | " def _vectorizer(self, words: List[str]) -> List[int]:\n",
820 | " return [self.vocab.get(w, 0) for w in words]  # unknown words map to id 0, shared with [PAD]\n",
821 | "\n",
822 | " def load(self, filename: str) -> TensorDataset:\n",
823 | " label2index = {l: i for i, l in enumerate(self.labels)}\n",
824 | " xs = []\n",
825 | " lengths = []\n",
826 | " ys = []\n",
827 | " with codecs.open(filename, encoding='utf-8', mode='r') as f:\n",
828 | " for line in f:\n",
829 | " words = line.split()\n",
830 | " ys.append(label2index[words[0]])\n",
831 | " words = self.tokenizer(' '.join(words[1:]))\n",
832 | " words = words if not self.lowercase else [w.lower() for w in words]\n",
833 | " vec = self.vectorizer(words)\n",
834 | " lengths.append(len(vec))\n",
835 | " xs.append(torch.tensor(vec, dtype=torch.long))\n",
836 | " x_tensor = torch.nn.utils.rnn.pad_sequence(xs, batch_first=True)\n",
837 | " lengths_tensor = torch.tensor(lengths, dtype=torch.long)\n",
838 | " y_tensor = torch.tensor(ys, dtype=torch.long)\n",
839 | " return TensorDataset(x_tensor, lengths_tensor, y_tensor)\n"
840 | ],
841 | "execution_count": 0,
842 | "outputs": []
843 | },
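{
"cell_type": "markdown",
"metadata": {
"colab_type": "text"
},
"source": [
"To make the vocabulary and padding conventions concrete, here is a tiny standalone sketch of the same bookkeeping that `Reader` performs, run on a hypothetical three-sentence corpus. Word ids start at 1, and id 0 is reserved for `[PAD]`; as in `_vectorizer`, unknown words also map to 0. Each sentence is right-padded with zeros to the length of the longest one, which is what `pad_sequence(..., batch_first=True)` does with its default padding value. The sketch uses plain Python lists instead of tensors so the result is easy to inspect."
]
},
{
"cell_type": "code",
"metadata": {
"colab_type": "code",
"colab": {}
},
"source": [
"from collections import Counter\n",
"\n",
"# Hypothetical toy corpus standing in for tokenized SST-2 sentences\n",
"sentences = [['a', 'fine', 'film'], ['a', 'dull', 'film'], ['fine']]\n",
"\n",
"# Count tokens, keep those seen at least min_freq times, then sort them\n",
"min_freq = 1\n",
"counts = Counter(w for s in sentences for w in s)\n",
"kept = sorted(w for w, c in counts.items() if c >= min_freq)\n",
"vocab = {w: i + 1 for i, w in enumerate(kept)}  # real words get ids 1..N\n",
"vocab['[PAD]'] = 0  # id 0 is reserved for padding\n",
"\n",
"\n",
"def vectorize(words):\n",
"    # Unknown words fall back to id 0, the same id as [PAD]\n",
"    return [vocab.get(w, 0) for w in words]\n",
"\n",
"\n",
"vecs = [vectorize(s) for s in sentences]\n",
"lengths = [len(v) for v in vecs]\n",
"max_len = max(lengths)\n",
"# Right-pad with zeros so every row has the same length\n",
"padded = [v + [0] * (max_len - len(v)) for v in vecs]\n",
"\n",
"print(vocab)    # {'a': 1, 'dull': 2, 'film': 3, 'fine': 4, '[PAD]': 0}\n",
"print(padded)   # [[1, 4, 3], [1, 2, 3], [4, 0, 0]]\n",
"print(lengths)  # [3, 3, 1]\n",
"print(vectorize(['a', 'great', 'film']))  # [1, 0, 3]"
],
"execution_count": 0,
"outputs": []
},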
844 | {
845 | "cell_type": "markdown",
846 | "metadata": {
847 | "id": "3fMHlMQCUslk",
848 | "colab_type": "text"
849 | },
850 | "source": [
851 | "### Pre-trained Embeddings\n",
852 | "\n",
853 | "We would like to investigate how much pre-trained embeddings help our models. To do this, we need a mechanism to load pre-trained embeddings and convert them into a PyTorch `nn.Embedding` object. Specifically, we wish to support `word2vec`, `GloVe` and `fastText` embeddings. Rest assured, these are simple file formats, and you do not need any third-party dependencies to read them: we will do it by hand.\n",
854 | "\n",
855 | "For binary (`word2vec`) files, the first line is a text header with 2 space-delimited numbers: the vocabulary size and the embedding dimension. Each entry that follows is a token terminated by a space, followed by its vector stored as `embed_dim` raw 32-bit floats (`4 * embed_dim` bytes). We therefore read the token byte by byte up to the space, then read the fixed-width vector; any newline left between entries is stripped from the next token.\n",
856 | "\n",
857 | "For text files, the first line may contain the same 2 numbers as the binary header (as in `fastText` `.vec` files), but `GloVe` files omit it. We check whether the first line has exactly 2 fields. If it doesn't, the line is already a vector, and we infer the dimension from it. Every line is space-delimited, so the first field is the token and the remaining fields are the components of its vector.\n",
858 | "\n",
859 | "Notice that we pass in the vocabulary that our `Reader` has already built. For each word in it, if the word is present in the embedding file, we use its pre-trained vector; otherwise, we keep a randomly initialized vector. The `[PAD]` row is set to zeros."
860 | ]
861 | },
862 | {
863 | "cell_type": "code",
864 | "metadata": {
865 | "id": "w2nKdPwIVbFn",
866 | "colab_type": "code",
867 | "colab": {}
868 | },
869 | "source": [
870 | "def init_embeddings(vocab_size, embed_dim, unif):\n",
871 | " return np.random.uniform(-unif, unif, (vocab_size, embed_dim))\n",
872 | " \n",
873 | "\n",
874 | "class EmbeddingsReader:\n",
875 | "\n",
876 | " @staticmethod\n",
877 | " def from_text(filename, vocab, unif=0.25):\n",
878 | " \n",
879 | " with io.open(filename, \"r\", encoding=\"utf-8\") as f:\n",
880 | " for i, line in enumerate(f):\n",
881 | " line = line.rstrip(\"\\n \")\n",
882 | " values = line.split(\" \")\n",
883 | "\n",
884 | " if i == 0:\n",
885 | " # fastText style\n",
886 | " if len(values) == 2:\n",
887 | " weight = init_embeddings(len(vocab), int(values[1]), unif)  # header fields are strings\n",
888 | " continue\n",
889 | " # glove style\n",
890 | " else:\n",
891 | " weight = init_embeddings(len(vocab), len(values[1:]), unif)\n",
892 | " word = values[0]\n",
893 | " if word in vocab:\n",
894 | " vec = np.asarray(values[1:], dtype=np.float32)\n",
895 | " weight[vocab[word]] = vec\n",
896 | " if '[PAD]' in vocab:\n",
897 | " weight[vocab['[PAD]']] = 0.0\n",
898 | " \n",
899 | " embeddings = nn.Embedding(weight.shape[0], weight.shape[1])\n",
900 | " embeddings.weight = nn.Parameter(torch.from_numpy(weight).float())\n",
901 | " return embeddings, weight.shape[1]\n",
902 | " \n",
903 | " @staticmethod\n",
904 | " def from_binary(filename, vocab, unif=0.25):\n",
905 | " def read_word(f):\n",
906 | "\n",
907 | " s = bytearray()\n",
908 | " ch = f.read(1)\n",
909 | "\n",
910 | " while ch != b' ':\n",
911 | " s.extend(ch)\n",
912 | " ch = f.read(1)\n",
913 | " s = s.decode('utf-8')\n",
914 | " # Only strip out normal space and \\n not other spaces which are words.\n",
915 | " return s.strip(' \\n')\n",
916 | "\n",
917 | " vocab_size = len(vocab)\n",
918 | " with io.open(filename, \"rb\") as f:\n",
919 | " header = f.readline()\n",
920 | " file_vocab_size, embed_dim = map(int, header.split())\n",
921 | " weight = init_embeddings(len(vocab), embed_dim, unif)\n",
922 | " if '[PAD]' in vocab:\n",
923 | " weight[vocab['[PAD]']] = 0.0\n",
924 | " width = 4 * embed_dim\n",
925 | " for i in range(file_vocab_size):\n",
926 | " word = read_word(f)\n",
927 | " raw = f.read(width)\n",
928 | " if word in vocab:\n",
929 | " vec = np.frombuffer(raw, dtype=np.float32)  # fromstring is deprecated for binary data\n",
930 | " weight[vocab[word]] = vec\n",
931 | " embeddings = nn.Embedding(weight.shape[0], weight.shape[1])\n",
932 | " embeddings.weight = nn.Parameter(torch.from_numpy(weight).float())\n",
933 | " return embeddings, embed_dim\n",
934 | "\n"
935 | ],
936 | "execution_count": 0,
937 | "outputs": []
938 | },
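{
"cell_type": "markdown",
"metadata": {
"colab_type": "text"
},
"source": [
"To see the binary layout in action, here is a small self-contained sketch. It writes a two-word file in that format to an in-memory buffer, then reads it back with the same logic as `from_binary`: parse the text header, read each token up to its space, then read `4 * embed_dim` bytes as `float32`. The words and values are made up. The vectors are written in the machine's native byte order, which is what `np.frombuffer` assumes when reading them back."
]
},
{
"cell_type": "code",
"metadata": {
"colab_type": "code",
"colab": {}
},
"source": [
"import io\n",
"\n",
"import numpy as np\n",
"\n",
"# Write a tiny word2vec-style binary file in memory (hypothetical words and values):\n",
"# a text header 'vocab_size dim', then per word: the word, a space, the raw float32 bytes\n",
"dim = 3\n",
"entries = {'good': [0.1, 0.2, 0.3], 'bad': [-0.1, -0.2, -0.3]}\n",
"buf = io.BytesIO()\n",
"buf.write('{} {}'.format(len(entries), dim).encode('utf-8') + b'\\n')\n",
"for word, values in entries.items():\n",
"    buf.write(word.encode('utf-8') + b' ')\n",
"    buf.write(np.asarray(values, dtype=np.float32).tobytes())\n",
"    buf.write(b'\\n')  # word2vec files typically end each entry with a newline\n",
"buf.seek(0)\n",
"\n",
"\n",
"def read_word(f):\n",
"    # Read bytes up to the separating space, then drop the newline left by the previous entry\n",
"    s = bytearray()\n",
"    ch = f.read(1)\n",
"    while ch != b' ':\n",
"        s.extend(ch)\n",
"        ch = f.read(1)\n",
"    return s.decode('utf-8').strip(' \\n')\n",
"\n",
"\n",
"# Read it back the same way from_binary does\n",
"vocab_size, embed_dim = map(int, buf.readline().split())\n",
"loaded = {}\n",
"for _ in range(vocab_size):\n",
"    word = read_word(buf)\n",
"    raw = buf.read(4 * embed_dim)  # 4 bytes per float32 component\n",
"    loaded[word] = np.frombuffer(raw, dtype=np.float32)\n",
"\n",
"print(vocab_size, embed_dim)  # 2 3\n",
"print(sorted(loaded))  # ['bad', 'good']"
],
"execution_count": 0,
"outputs": []
},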
939 | {
940 | "cell_type": "markdown",
941 | "metadata": {
942 | "id": "SOeo4W0zSvws",
943 | "colab_type": "text"
944 | },
945 | "source": [
946 | "### Now to run some stuff!\n",
947 | "\n",
948 | "We did a lot of work to set things up, but it's mostly boilerplate and we will reuse much of it. So far, we have built 2 classifiers, code to train and evaluate them, a reader to load our data, and a reader for pre-trained embeddings.\n"
949 | ]
950 | },
951 | {
952 | "cell_type": "code",
953 | "metadata": {
954 | "id": "ahoBQQ-LddMr",
955 | "colab_type": "code",
956 | "colab": {}
957 | },
958 | "source": [
959 | "BASE = 'sst2'\n",
960 | "TRAIN = os.path.join(BASE, 'stsa.binary.phrases.train')\n",
961 | "VALID = os.path.join(BASE, 'stsa.binary.dev')\n",
962 | "TEST = os.path.join(BASE, 'stsa.binary.test')\n",
963 | "PRETRAINED_EMBEDDINGS_FILE = 'GoogleNews-vectors-negative300.bin'\n",
964 | "\n"
965 | ],
966 | "execution_count": 0,
967 | "outputs": []
968 | },
969 | {
970 | "cell_type": "markdown",
971 | "metadata": {
972 | "id": "kbWM-ycVeL5b",
973 | "colab_type": "text"
974 | },
975 | "source": [
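976 | "Let's read in our datasets. Note that the vocabulary is built over all three files, so words that appear only in the validation or test data can still receive pre-trained vectors later on:"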
977 | ]
978 | },
979 | {
980 | "cell_type": "code",
981 | "metadata": {
982 | "id": "tCwkXTO7eQEx",
983 | "colab_type": "code",
984 | "colab": {}
985 | },
986 | "source": [
987 | "r = Reader((TRAIN, VALID, TEST,))\n",
988 | "train = r.load(TRAIN)\n",
989 | "valid = r.load(VALID)\n",
990 | "test = r.load(TEST)"
991 | ],
992 | "execution_count": 0,
993 | "outputs": []
994 | },
995 | {
996 | "cell_type": "markdown",
997 | "metadata": {
998 | "id": "tJXsq3-xiATu",
999 | "colab_type": "text"
1000 | },
1001 | "source": [
1002 | "## Model trained with randomly initialized embeddings\n",
1003 | "\n",
1004 | "First, we train the CNN model with randomly initialized (not pre-trained) embeddings for 10 epochs. After each epoch, we print the training and validation performance. After the final epoch, we evaluate the best checkpoint (by validation accuracy) on the test set."
1005 | ]
1006 | },
1007 | {
1008 | "cell_type": "code",
1009 | "metadata": {
1010 | "id": "okeGF78yh3y_",
1011 | "colab_type": "code",
1012 | "outputId": "5a5fc139-8da5-400c-b5c5-de84de3d9b0a",
1013 | "colab": {
1014 | "base_uri": "https://localhost:8080/",
1015 | "height": 1000
1016 | }
1017 | },
1018 | "source": [
1019 | "embed_dim = 300\n",
1020 | "embeddings = nn.Embedding(len(r.vocab), embed_dim)\n",
1021 | "model = ConvClassifier(embeddings, len(r.labels), embed_dim)\n",
1022 | "\n",
1023 | "num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)\n",
1024 | "print(f\"Model has {num_params} parameters\") \n",
1025 | "\n",
1026 | "\n",
1027 | "model.to('cuda:0')\n",
1028 | "loss = torch.nn.NLLLoss()\n",
1029 | "loss = loss.to('cuda:0')\n",
1030 | "\n",
1031 | "learnable_params = [p for p in model.parameters() if p.requires_grad]\n",
1032 | "optimizer = torch.optim.Adadelta(learnable_params, lr=1.0)\n",
1033 | "\n",
1034 | "fit(model, r.labels, optimizer, loss, 10, 50, train, valid, test)\n"
1035 | ],
1036 | "execution_count": 13,
1037 | "outputs": [
1038 | {
1039 | "output_type": "stream",
1040 | "text": [
1041 | "Model has 5442302 parameters\n",
1042 | "EPOCH 1\n",
1043 | "=================================\n",
1044 | "Training Results\n",
1045 | "{'acc': 0.5926248359558738, 'precision': 0.6179160630500119, 'recall': 0.6762583118388982, 'f1': 0.6457721335924437}\n",
1046 | "Validation Results\n",
1047 | "{'acc': 0.7110091743119266, 'precision': 0.7622950819672131, 'recall': 0.6283783783783784, 'f1': 0.6888888888888889}\n",
1048 | "New best model 0.71\n",
1049 | "EPOCH 2\n",
1050 | "=================================\n",
1051 | "Training Results\n",
1052 | "{'acc': 0.6340224269435168, 'precision': 0.6503253333333333, 'recall': 0.7213611301734542, 'f1': 0.6840038593578207}\n",
1053 | "Validation Results\n",
1054 | "{'acc': 0.6089449541284404, 'precision': 0.5695006747638327, 'recall': 0.9504504504504504, 'f1': 0.7122362869198312}\n",
1055 | "EPOCH 3\n",
1056 | "=================================\n",
1057 | "Training Results\n",
1058 | "{'acc': 0.6637647639713621, 'precision': 0.6763477437133999, 'recall': 0.7433919401784236, 'f1': 0.7082868319298364}\n",
1059 | "Validation Results\n",
1060 | "{'acc': 0.7144495412844036, 'precision': 0.7506426735218509, 'recall': 0.6576576576576577, 'f1': 0.7010804321728692}\n",
1061 | "New best model 0.71\n",
1062 | "EPOCH 4\n",
1063 | "=================================\n",
1064 | "Training Results\n",
1065 | "{'acc': 0.6873481373682775, 'precision': 0.6971548679277991, 'recall': 0.7613289476797842, 'f1': 0.7278300606279975}\n",
1066 | "Validation Results\n",
1067 | "{'acc': 0.7075688073394495, 'precision': 0.6556836902800659, 'recall': 0.8963963963963963, 'f1': 0.7573739295908659}\n",
1068 | "EPOCH 5\n",
1069 | "=================================\n",
1070 | "Training Results\n",
1071 | "{'acc': 0.7016800717246398, 'precision': 0.7109843018933928, 'recall': 0.7695165526870016, 'f1': 0.7390933781833471}\n",
1072 | "Validation Results\n",
1073 | "{'acc': 0.6490825688073395, 'precision': 0.597457627118644, 'recall': 0.9527027027027027, 'f1': 0.734375}\n",
1074 | "EPOCH 6\n",
1075 | "=================================\n",
1076 | "Training Results\n",
1077 | "{'acc': 0.7181819363054014, 'precision': 0.7250941083778342, 'recall': 0.783998674838496, 'f1': 0.7533967777512478}\n",
1078 | "Validation Results\n",
1079 | "{'acc': 0.7591743119266054, 'precision': 0.7378048780487805, 'recall': 0.8175675675675675, 'f1': 0.7756410256410255}\n",
1080 | "New best model 0.76\n",
1081 | "EPOCH 7\n",
1082 | "=================================\n",
1083 | "Training Results\n",
1084 | "{'acc': 0.7320201140837567, 'precision': 0.7382341929658423, 'recall': 0.7932274781703306, 'f1': 0.7647434581251569}\n",
1085 | "Validation Results\n",
1086 | "{'acc': 0.7431192660550459, 'precision': 0.7370689655172413, 'recall': 0.7702702702702703, 'f1': 0.7533039647577091}\n",
1087 | "EPOCH 8\n",
1088 | "=================================\n",
1089 | "Training Results\n",
1090 | "{'acc': 0.7397253154194982, 'precision': 0.7452230704735008, 'recall': 0.7992380321351664, 'f1': 0.7712860095226134}\n",
1091 | "Validation Results\n",
1092 | "{'acc': 0.6788990825688074, 'precision': 0.6209439528023599, 'recall': 0.9481981981981982, 'f1': 0.7504456327985739}\n",
1093 | "EPOCH 9\n",
1094 | "=================================\n",
1095 | "Training Results\n",
1096 | "{'acc': 0.7485089850703603, 'precision': 0.7541725852272727, 'recall': 0.8040890697839513, 'f1': 0.7783313290958025}\n",
1097 | "Validation Results\n",
1098 | "{'acc': 0.7545871559633027, 'precision': 0.7281746031746031, 'recall': 0.8265765765765766, 'f1': 0.7742616033755274}\n",
1099 | "EPOCH 10\n",
1100 | "=================================\n",
1101 | "Training Results\n",
1102 | "{'acc': 0.7572016995621159, 'precision': 0.7618233111935491, 'recall': 0.8115431032442794, 'f1': 0.7858976121728768}\n",
1103 | "Validation Results\n",
1104 | "{'acc': 0.7591743119266054, 'precision': 0.7303149606299213, 'recall': 0.8355855855855856, 'f1': 0.7794117647058825}\n",
1105 | "Final result\n",
1106 | "{'acc': 0.7391543108182317, 'precision': 0.728421052631579, 'recall': 0.7612761276127613, 'f1': 0.7444862829478214}\n"
1107 | ],
1108 | "name": "stdout"
1109 | },
1110 | {
1111 | "output_type": "execute_result",
1112 | "data": {
1113 | "text/plain": [
1114 | "0.7391543108182317"
1115 | ]
1116 | },
1117 | "metadata": {
1118 | "tags": []
1119 | },
1120 | "execution_count": 13
1121 | }
1122 | ]
1123 | },
1124 | {
1125 | "cell_type": "markdown",
1126 | "metadata": {
1127 | "id": "a54pDMAQjXY-",
1128 | "colab_type": "text"
1129 | },
1130 | "source": [
1131 | "Yikes, that's not very encouraging! What about our LSTM?"
1132 | ]
1133 | },
1134 | {
1135 | "cell_type": "code",
1136 | "metadata": {
1137 | "id": "f8s4lSIqjc6G",
1138 | "colab_type": "code",
1139 | "outputId": "0e1b7450-210d-46cf-9c7c-d7120991fe15",
1140 | "colab": {
1141 | "base_uri": "https://localhost:8080/",
1142 | "height": 1000
1143 | }
1144 | },
1145 | "source": [
1146 | "embed_dim = 300\n",
1147 | "embeddings = nn.Embedding(len(r.vocab), embed_dim)\n",
1148 | "model = LSTMClassifier(embeddings, len(r.labels), embed_dim, 100, hidden_units=[100])\n",
1149 | "\n",
1150 | "num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)\n",
1151 | "print(f\"Model has {num_params} parameters\") \n",
1152 | "\n",
1153 | "\n",
1154 | "model.to('cuda:0')\n",
1155 | "loss = torch.nn.NLLLoss()\n",
1156 | "loss = loss.to('cuda:0')\n",
1157 | "\n",
1158 | "learnable_params = [p for p in model.parameters() if p.requires_grad]\n",
1159 | "optimizer = torch.optim.Adadelta(learnable_params, lr=1.0)\n",
1160 | "\n",
1161 | "fit(model, r.labels, optimizer, loss, 10, 50, train, valid, test)\n"
1162 | ],
1163 | "execution_count": 14,
1164 | "outputs": [
1165 | {
1166 | "output_type": "stream",
1167 | "text": [
1168 | "/usr/local/lib/python3.6/dist-packages/torch/nn/modules/rnn.py:54: UserWarning: dropout option adds dropout after all but last recurrent layer, so non-zero dropout expects num_layers greater than 1, but got dropout=0.5 and num_layers=1\n",
1169 | " \"num_layers={}\".format(dropout, num_layers))\n"
1170 | ],
1171 | "name": "stderr"
1172 | },
1173 | {
1174 | "output_type": "stream",
1175 | "text": [
1176 | "Model has 5342502 parameters\n",
1177 | "EPOCH 1\n",
1178 | "=================================\n",
1179 | "Training Results\n",
1180 | "{'acc': 0.6416496667143099, 'precision': 0.648109286089027, 'recall': 0.7600511133722994, 'f1': 0.6996307873269656}\n",
1181 | "Validation Results\n",
1182 | "{'acc': 0.7488532110091743, 'precision': 0.7300613496932515, 'recall': 0.8040540540540541, 'f1': 0.765273311897106}\n",
1183 | "New best model 0.75\n",
1184 | "EPOCH 2\n",
1185 | "=================================\n",
1186 | "Training Results\n",
1187 | "{'acc': 0.7467158690765453, 'precision': 0.7576038743550285, 'recall': 0.7921862798457133, 'f1': 0.7745092368734602}\n",
1188 | "Validation Results\n",
1189 | "{'acc': 0.7786697247706422, 'precision': 0.747534516765286, 'recall': 0.8536036036036037, 'f1': 0.7970557308096741}\n",
1190 | "New best model 0.78\n",
1191 | "EPOCH 3\n",
1192 | "=================================\n",
1193 | "Training Results\n",
1194 | "{'acc': 0.7794467327607489, 'precision': 0.7900387712496272, 'recall': 0.8149033342009986, 'f1': 0.8022784456248252}\n",
1195 | "Validation Results\n",
1196 | "{'acc': 0.783256880733945, 'precision': 0.7958236658932715, 'recall': 0.7725225225225225, 'f1': 0.7839999999999999}\n",
1197 | "New best model 0.78\n",
1198 | "EPOCH 4\n",
1199 | "=================================\n",
1200 | "Training Results\n",
1201 | "{'acc': 0.7937786671171113, 'precision': 0.8048943938623654, 'recall': 0.8242267919259803, 'f1': 0.8144458863830334}\n",
1202 | "Validation Results\n",
1203 | "{'acc': 0.7901376146788991, 'precision': 0.7713097713097713, 'recall': 0.8355855855855856, 'f1': 0.8021621621621621}\n",
1204 | "New best model 0.79\n",
1205 | "EPOCH 5\n",
1206 | "=================================\n",
1207 | "Training Results\n",
1208 | "{'acc': 0.8033159652291421, 'precision': 0.814348632359759, 'recall': 0.8313258714120069, 'f1': 0.8227496809096125}\n",
1209 | "Validation Results\n",
1210 | "{'acc': 0.7878440366972477, 'precision': 0.8018648018648019, 'recall': 0.7747747747747747, 'f1': 0.7880870561282932}\n",
1211 | "EPOCH 6\n",
1212 | "=================================\n",
1213 | "Training Results\n",
1214 | "{'acc': 0.8122165772274269, 'precision': 0.8227892183038098, 'recall': 0.8386379232826143, 'f1': 0.8306379787184175}\n",
1215 | "Validation Results\n",
1216 | "{'acc': 0.7889908256880734, 'precision': 0.8125, 'recall': 0.7612612612612613, 'f1': 0.786046511627907}\n",
1217 | "EPOCH 7\n",
1218 | "=================================\n",
1219 | "Training Results\n",
1220 | "{'acc': 0.8170501942542326, 'precision': 0.8279907814791536, 'recall': 0.841666863863319, 'f1': 0.8347728126173487}\n",
1221 | "Validation Results\n",
1222 | "{'acc': 0.8027522935779816, 'precision': 0.7995594713656388, 'recall': 0.8175675675675675, 'f1': 0.8084632516703786}\n",
1223 | "New best model 0.80\n",
1224 | "EPOCH 8\n",
1225 | "=================================\n",
1226 | "Training Results\n",
1227 | "{'acc': 0.8234690297683243, 'precision': 0.8332597224482206, 'recall': 0.8482453441870371, 'f1': 0.8406857571706653}\n",
1228 | "Validation Results\n",
1229 | "{'acc': 0.7924311926605505, 'precision': 0.8065268065268065, 'recall': 0.7792792792792793, 'f1': 0.7926689576174113}\n",
1230 | "EPOCH 9\n",
1231 | "=================================\n",
1232 | "Training Results\n",
1233 | "{'acc': 0.824742401995816, 'precision': 0.8359920588578769, 'recall': 0.846991173477839, 'f1': 0.8414556738839128}\n",
1234 | "Validation Results\n",
1235 | "{'acc': 0.7981651376146789, 'precision': 0.8116279069767441, 'recall': 0.786036036036036, 'f1': 0.7986270022883294}\n",
1236 | "EPOCH 10\n",
1237 | "=================================\n",
1238 | "Training Results\n",
1239 | "{'acc': 0.8271332233209028, 'precision': 0.8384886956115125, 'recall': 0.8486476253579119, 'f1': 0.8435375749735388}\n",
1240 | "Validation Results\n",
1241 | "{'acc': 0.7993119266055045, 'precision': 0.799554565701559, 'recall': 0.8085585585585585, 'f1': 0.8040313549832027}\n",
1242 | "Final result\n",
1243 | "{'acc': 0.8050521691378364, 'precision': 0.7991360691144709, 'recall': 0.8140814081408141, 'f1': 0.8065395095367847}\n"
1244 | ],
1245 | "name": "stdout"
1246 | },
1247 | {
1248 | "output_type": "execute_result",
1249 | "data": {
1250 | "text/plain": [
1251 | "0.8050521691378364"
1252 | ]
1253 | },
1254 | "metadata": {
1255 | "tags": []
1256 | },
1257 | "execution_count": 14
1258 | }
1259 | ]
1260 | },
1261 | {
1262 | "cell_type": "markdown",
1263 | "metadata": {
1264 | "id": "tHtNnvvoj22N",
1265 | "colab_type": "text"
1266 | },
1267 | "source": [
1268 | "## Same models with pre-trained word embeddings\n",
1269 | "\n",
1270 | "The models below are identical to the ones above; the only difference is that we initialize the embeddings with pre-trained `word2vec` vectors, loaded using our `EmbeddingsReader`. First, let's look at our CNN model again. Notice that we only run 5 epochs here instead of 10!"
1271 | ]
1272 | },
1273 | {
1274 | "cell_type": "code",
1275 | "metadata": {
1276 | "id": "cEzWqC2fkNSz",
1277 | "colab_type": "code",
1278 | "outputId": "eae932aa-fa84-41cd-e1b4-2791db35167b",
1279 | "colab": {
1280 | "base_uri": "https://localhost:8080/",
1281 | "height": 663
1282 | }
1283 | },
1284 | "source": [
1285 | "embeddings, embed_dim = EmbeddingsReader.from_binary(PRETRAINED_EMBEDDINGS_FILE, r.vocab)\n",
1286 | "model = ConvClassifier(embeddings, len(r.labels), embed_dim)\n",
1287 | "num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)\n",
1288 | "print(f\"Model has {num_params} parameters\") \n",
1289 | "\n",
1290 | "\n",
1291 | "model.to('cuda:0')\n",
1292 | "loss = torch.nn.NLLLoss()\n",
1293 | "loss = loss.to('cuda:0')\n",
1294 | "\n",
1295 | "learnable_params = [p for p in model.parameters() if p.requires_grad]\n",
1296 | "optimizer = torch.optim.Adadelta(learnable_params, lr=1.0)\n",
1297 | "\n",
1298 | "fit(model, r.labels, optimizer, loss, 5, 50, train, valid, test)"
1299 | ],
1300 | "execution_count": 16,
1301 | "outputs": [
1302 | {
1303 | "output_type": "stream",
1304 | "text": [
1305 | "/usr/local/lib/python3.6/dist-packages/ipykernel_launcher.py:60: DeprecationWarning: The binary mode of fromstring is deprecated, as it behaves surprisingly on unicode inputs. Use frombuffer instead\n"
1306 | ],
1307 | "name": "stderr"
1308 | },
1309 | {
1310 | "output_type": "stream",
1311 | "text": [
1312 | "Model has 5442302 parameters\n",
1313 | "EPOCH 1\n",
1314 | "=================================\n",
1315 | "Training Results\n",
1316 | "{'acc': 0.8329153727212484, 'precision': 0.8410831129054712, 'recall': 0.8577817742965995, 'f1': 0.8493503754817998}\n",
1317 | "Validation Results\n",
1318 | "{'acc': 0.8268348623853211, 'precision': 0.9080779944289693, 'recall': 0.7342342342342343, 'f1': 0.8119551681195517}\n",
1319 | "New best model 0.83\n",
1320 | "EPOCH 2\n",
1321 | "=================================\n",
1322 | "Training Results\n",
1323 | "{'acc': 0.8798742220085498, 'precision': 0.8858578775128565, 'recall': 0.8967793842731726, 'f1': 0.8912851750373358}\n",
1324 | "Validation Results\n",
1325 | "{'acc': 0.8474770642201835, 'precision': 0.8373101952277657, 'recall': 0.8693693693693694, 'f1': 0.8530386740331491}\n",
1326 | "New best model 0.85\n",
1327 | "EPOCH 3\n",
1328 | "=================================\n",
1329 | "Training Results\n",
1330 | "{'acc': 0.8955834773456686, 'precision': 0.9008221873462791, 'recall': 0.9100309993137556, 'f1': 0.9054031783402}\n",
1331 | "Validation Results\n",
1332 | "{'acc': 0.8486238532110092, 'precision': 0.8436123348017621, 'recall': 0.8626126126126126, 'f1': 0.8530066815144767}\n",
1333 | "New best model 0.85\n",
1334 | "EPOCH 4\n",
1335 | "=================================\n",
1336 | "Training Results\n",
1337 | "{'acc': 0.9059523654838165, 'precision': 0.9103606664948091, 'recall': 0.9192361390473035, 'f1': 0.91477687507359}\n",
1338 | "Validation Results\n",
1339 | "{'acc': 0.841743119266055, 'precision': 0.9090909090909091, 'recall': 0.7657657657657657, 'f1': 0.8312958435207825}\n",
1340 | "EPOCH 5\n",
1341 | "=================================\n",
1342 | "Training Results\n",
1343 | "{'acc': 0.9139694130793519, 'precision': 0.9181293410925474, 'recall': 0.9258856101658818, 'f1': 0.9219911634756995}\n",
1344 | "Validation Results\n",
1345 | "{'acc': 0.8428899082568807, 'precision': 0.851258581235698, 'recall': 0.8378378378378378, 'f1': 0.844494892167991}\n",
1346 | "Final result\n",
1347 | "{'acc': 0.8725974739154311, 'precision': 0.8537095088819227, 'recall': 0.8987898789878987, 'f1': 0.8756698821007501}\n"
1348 | ],
1349 | "name": "stdout"
1350 | },
1351 | {
1352 | "output_type": "execute_result",
1353 | "data": {
1354 | "text/plain": [
1355 | "0.8725974739154311"
1356 | ]
1357 | },
1358 | "metadata": {
1359 | "tags": []
1360 | },
1361 | "execution_count": 16
1362 | }
1363 | ]
1364 | },
1365 | {
1366 | "cell_type": "markdown",
1367 | "metadata": {
1368 | "id": "QcnajM3klDx9",
1369 | "colab_type": "text"
1370 | },
1371 | "source": [
1372 | "Much better! And now the LSTM!"
1373 | ]
1374 | },
1375 | {
1376 | "cell_type": "code",
1377 | "metadata": {
1378 | "id": "zqJ12jfAkxoX",
1379 | "colab_type": "code",
1380 | "outputId": "76774035-dad3-445c-a7ef-20005a636a8e",
1381 | "colab": {
1382 | "base_uri": "https://localhost:8080/",
1383 | "height": 683
1384 | }
1385 | },
1386 | "source": [
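1387 | "embeddings, embed_dim = EmbeddingsReader.from_binary(PRETRAINED_EMBEDDINGS_FILE, r.vocab)  # reload: training the CNN above updated the previous embeddings in place\n",
1388 | "model = LSTMClassifier(embeddings, len(r.labels), embed_dim, 100, hidden_units=[100])\n",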
1389 | "num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)\n",
1390 | "print(f\"Model has {num_params} parameters\") \n",
1391 | "\n",
1392 | "\n",
1393 | "model.to('cuda:0')\n",
1394 | "loss = torch.nn.NLLLoss()\n",
1395 | "loss = loss.to('cuda:0')\n",
1396 | "\n",
1397 | "learnable_params = [p for p in model.parameters() if p.requires_grad]\n",
1398 | "optimizer = torch.optim.Adadelta(learnable_params, lr=1.0)\n",
1399 | "\n",
1400 | "fit(model, r.labels, optimizer, loss, 5, 50, train, valid, test)"
1401 | ],
1402 | "execution_count": 17,
1403 | "outputs": [
1404 | {
1405 | "output_type": "stream",
1406 | "text": [
1407 | "/usr/local/lib/python3.6/dist-packages/torch/nn/modules/rnn.py:54: UserWarning: dropout option adds dropout after all but last recurrent layer, so non-zero dropout expects num_layers greater than 1, but got dropout=0.5 and num_layers=1\n",
1408 | " \"num_layers={}\".format(dropout, num_layers))\n"
1409 | ],
1410 | "name": "stderr"
1411 | },
1412 | {
1413 | "output_type": "stream",
1414 | "text": [
1415 | "Model has 5342502 parameters\n",
1416 | "EPOCH 1\n",
1417 | "=================================\n",
1418 | "Training Results\n",
1419 | "{'acc': 0.870012084042567, 'precision': 0.8834494400722794, 'recall': 0.8792683215409736, 'f1': 0.8813539220569748}\n",
1420 | "Validation Results\n",
1421 | "{'acc': 0.8371559633027523, 'precision': 0.8081632653061225, 'recall': 0.8918918918918919, 'f1': 0.8479657387580299}\n",
1422 | "New best model 0.84\n",
1423 | "EPOCH 2\n",
1424 | "=================================\n",
1425 | "Training Results\n",
1426 | "{'acc': 0.904562050908902, 'precision': 0.9189545934530094, 'recall': 0.9061028419981543, 'f1': 0.9124834677755669}\n",
1427 | "Validation Results\n",
1428 | "{'acc': 0.8646788990825688, 'precision': 0.9137055837563451, 'recall': 0.8108108108108109, 'f1': 0.8591885441527446}\n",
1429 | "New best model 0.86\n",
1430 | "EPOCH 3\n",
1431 | "=================================\n",
1432 | "Training Results\n",
1433 | "{'acc': 0.9171658372422395, 'precision': 0.9286806517895542, 'recall': 0.9197804018078989, 'f1': 0.9242090996635478}\n",
1434 | "Validation Results\n",
1435 | "{'acc': 0.8600917431192661, 'precision': 0.8340248962655602, 'recall': 0.9054054054054054, 'f1': 0.8682505399568036}\n",
1436 | "EPOCH 4\n",
1437 | "=================================\n",
1438 | "Training Results\n",
1439 | "{'acc': 0.9252738399968815, 'precision': 0.9348261076703192, 'recall': 0.9286542511654322, 'f1': 0.9317299588076782}\n",
1440 | "Validation Results\n",
1441 | "{'acc': 0.856651376146789, 'precision': 0.8444924406047516, 'recall': 0.8806306306306306, 'f1': 0.8621830209481808}\n",
1442 | "EPOCH 5\n",
1443 | "=================================\n",
1444 | "Training Results\n",
1445 | "{'acc': 0.9308091111082236, 'precision': 0.9387294497766796, 'recall': 0.9350197591045695, 'f1': 0.9368709321762635}\n",
1446 | "Validation Results\n",
1447 | "{'acc': 0.8635321100917431, 'precision': 0.9156010230179028, 'recall': 0.8063063063063063, 'f1': 0.8574850299401198}\n",
1448 | "Final result\n",
1449 | "{'acc': 0.8802855573860516, 'precision': 0.9157641395908543, 'recall': 0.8371837183718371, 'f1': 0.8747126436781609}\n"
1450 | ],
1451 | "name": "stdout"
1452 | },
1453 | {
1454 | "output_type": "execute_result",
1455 | "data": {
1456 | "text/plain": [
1457 | "0.8802855573860516"
1458 | ]
1459 | },
1460 | "metadata": {
1461 | "tags": []
1462 | },
1463 | "execution_count": 17
1464 | }
1465 | ]
1466 | },
1467 | {
1468 | "cell_type": "markdown",
1469 | "metadata": {
1470 | "id": "NpfM53E-lS9y",
1471 | "colab_type": "text"
1472 | },
1473 | "source": [
1474 | "#### A quick note about these models on this data\n",
1475 | "\n",
1476 | "Both of these models are surprisingly strong baselines, and they do fairly well on this dataset when averaged over many runs. Even with only 2-5 epochs of training, it is quite common to see scores higher than those reported in the Kim (2014) paper."
1477 | ]
1478 | },
1479 | {
1480 | "cell_type": "markdown",
1481 | "metadata": {
1482 | "id": "O4mLnxyIlM4u",
1483 | "colab_type": "text"
1484 | },
1485 | "source": [
1486 | "## Conclusions\n",
1487 | "\n",
1488 | "### It's not hard to get good performance with a Deep Learning model for Text Classification\n",
1489 | "\n",
1490 | "We saw above how to get good results on the SST-2 dataset using fairly simple deep learning models, even with very few training epochs. This behavior is not limited to a single dataset; similar results have been shown repeatedly on other text classification datasets. Also, using PyTorch, we were able to code an entire pipeline in this minimalistic notebook.\n",
1491 | "\n",
1492 | "### Pre-trained embeddings often help\n",
1493 | "\n",
1494 | "We can see that pre-trained embeddings can have a massive impact on the performance of our models, especially for smaller datasets. The `word2vec` algorithm sparked an explosion of interest in the NLP community: although pre-trained embeddings had been widely studied before that work, `word2vec` was fast to train and gave reliable results. `GloVe` and `fastText` embeddings followed shortly thereafter, and all three are in quite common use today. The code above can load any of these embedding formats and incorporate them into downstream models, often with large improvements.\n",
1495 | "\n",
1496 | "For large datasets, like those used in language modeling and neural machine translation, models are typically trained from randomly initialized embeddings, which are usually sufficient given that much training data.\n",
1497 | "\n",
1498 | "### Incorporating pre-trained embeddings into your model is simple\n",
1499 | "\n",
1500 | "The file formats are very simple to read and can be incorporated with only a few lines of code. In some cases, memory-mapping the file can speed up loading; this is implemented in [Baseline](https://github.com/dpressel/baseline/).\n",
1501 | "\n",
1502 | "### Some Further Resources\n",
1503 | "\n",
1504 | "- [The Unreasonable Effectiveness of Recurrent Neural Networks, Karpathy, 2015](https://karpathy.github.io/2015/05/21/rnn-effectiveness/)\n",
1505 | "- TensorFlow tutorial for word2vec: https://www.tensorflow.org/tutorials/representation/word2vec\n",
1506 | "- TensorFlow docs on feature columns (for images above): https://www.tensorflow.org/guide/feature_columns\n",
1507 | "- Xin Rong wrote some amazing software to visualize word embeddings and the training process. Sadly, Xin is no longer with us; he was a great researcher and an awesome guy. We miss him.\n",
1508 | " - https://ronxin.github.io/wevi/\n",
1509 | " - Accompanying talk from a2-dlearn 2015: https://www.youtube.com/channel/UCVdeq2cIxnujw2kTdzg2N5g\n",
1510 | " - https://ronxin.github.io/lamvi/dist/#model=word2vec&backend=browser&query_in=darcy&query_out=G_bennet,B_circumstances\n",
1511 | " - Accompanying paper for Lamvi [Visual Tools for Debugging Neural Language Models, Rong & Adar, 2016](http://www.cond.org/ICML16_NeuralVis.pdf)\n",
1512 | "\n"
1513 | ]
1514 | }
1515 | ]
1516 | }
--------------------------------------------------------------------------------
/2_context_vectors.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "nbformat": 4,
3 | "nbformat_minor": 0,
4 | "metadata": {
5 | "colab": {
6 | "name": "2-context-vectors.ipynb",
7 | "version": "0.3.2",
8 | "provenance": [],
9 | "include_colab_link": true
10 | },
11 | "kernelspec": {
12 | "name": "python3",
13 | "display_name": "Python 3"
14 | },
15 | "accelerator": "GPU"
16 | },
17 | "cells": [
28 | {
29 | "cell_type": "markdown",
30 | "metadata": {
31 | "id": "WGiURKhLmONv",
32 | "colab_type": "text"
33 | },
34 | "source": [
35 | "# Part II: Contextualized embeddings\n",
36 | "\n",
37 | "\n",
38 | "In this section, we are going to learn how to train an LSTM-based word-level language model. Then we will load a pre-trained language model checkpoint and use everything below its output layers as the lower layers of our previously defined classification model. We don't need to change anything else; we simply pass this whole network as the `embedding` parameter to the model.\n",
39 | "\n",
40 | "## LSTM Language Models\n",
41 | "\n",
42 | "We are going to quickly build an LSTM language model so that we can see how the training works. We train by minimizing the average per-token cross-entropy loss, and we report the perplexity, which is the exponentiated cross-entropy loss. Because the exponential is monotonic, lowering one lowers the other."
43 | ]
44 | },
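{
 "cell_type": "markdown",
 "metadata": {
  "colab_type": "text"
 },
 "source": [
  "To make the relationship between the two quantities concrete, here is a minimal sketch that uses made-up random logits and targets, not our model. It computes the mean per-token cross-entropy with `F.cross_entropy` (natural log, as in our training loop) and exponentiates it. A model that predicts the uniform distribution over a vocabulary of size `V` has perplexity exactly `V`. This is why perplexity is often read as the effective number of equally likely next-word choices. The variable names here are illustrative only."
 ]
},
{
 "cell_type": "code",
 "metadata": {
  "colab_type": "code",
  "colab": {}
 },
 "source": [
  "import math\n",
  "import torch\n",
  "import torch.nn.functional as F\n",
  "\n",
  "# Made-up batch: 4 sequences x 5 time steps over a 10-word vocabulary\n",
  "torch.manual_seed(0)\n",
  "vocab_size = 10\n",
  "logits = torch.randn(4, 5, vocab_size)\n",
  "targets = torch.randint(0, vocab_size, (4, 5))\n",
  "\n",
  "# Mean per-token cross-entropy (natural log), then perplexity = exp(cross-entropy)\n",
  "ce = F.cross_entropy(logits.view(-1, vocab_size), targets.view(-1))\n",
  "print(f'cross-entropy {ce.item():.4f}, perplexity {math.exp(ce.item()):.2f}')\n",
  "\n",
  "# Uniform predictions: cross-entropy is log(vocab_size), so perplexity is vocab_size\n",
  "uniform_logits = torch.zeros(4 * 5, vocab_size)\n",
  "uniform_ce = F.cross_entropy(uniform_logits, targets.view(-1))\n",
  "print(f'{math.exp(uniform_ce.item()):.4f}')  # 10.0000"
 ],
 "execution_count": 0,
 "outputs": []
},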
45 | {
46 | "cell_type": "code",
47 | "metadata": {
48 | "id": "PvcXdbNZq8NG",
49 | "colab_type": "code",
50 | "colab": {
51 | "base_uri": "https://localhost:8080/",
52 | "height": 238
53 | },
54 | "outputId": "78355358-4f8d-4a9a-a156-292a451e04d5"
55 | },
56 | "source": [
57 | "!wget -nc https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-2-v1.zip\n",
58 | "!unzip -o wikitext-2-v1.zip"
59 | ],
60 | "execution_count": 1,
61 | "outputs": [
62 | {
63 | "output_type": "stream",
64 | "text": [
65 | "--2019-06-30 19:10:48-- https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-2-v1.zip\n",
66 | "Resolving s3.amazonaws.com (s3.amazonaws.com)... 52.216.134.253\n",
67 | "Connecting to s3.amazonaws.com (s3.amazonaws.com)|52.216.134.253|:443... connected.\n",
68 | "HTTP request sent, awaiting response... 200 OK\n",
69 | "Length: 4475746 (4.3M) [application/zip]\n",
70 | "Saving to: ‘wikitext-2-v1.zip.5’\n",
71 | "\n",
72 | "wikitext-2-v1.zip.5 100%[===================>] 4.27M 18.2MB/s in 0.2s \n",
73 | "\n",
74 | "2019-06-30 19:10:49 (18.2 MB/s) - ‘wikitext-2-v1.zip.5’ saved [4475746/4475746]\n",
75 | "\n",
76 | "Archive: wikitext-2-v1.zip\n",
77 | "replace wikitext-2/wiki.test.tokens? [y]es, [n]o, [A]ll, [N]one, [r]ename: "
78 | ],
79 | "name": "stdout"
80 | }
81 | ]
82 | },
83 | {
84 | "cell_type": "markdown",
85 | "metadata": {
86 | "id": "TNUhkQC_dzmR",
87 | "colab_type": "text"
88 | },
89 | "source": [
90 | "Our LSTM model will be a word-based model. We start from a randomly initialized embedding, pass each timestep's embedding through our LSTM layers, and then project each output to the vocabulary size. The output projection shares its weights with the input embedding (weight tying), which requires the embedding and hidden sizes to match. At every training step, we detach the hidden states so that gradients do not flow back past the current batch (truncated backpropagation through time), but we still initialize each new batch from the previous hidden state. We also create a function that returns zeroed hidden states, which we use at the start of each epoch."
91 | ]
92 | },
93 | {
94 | "cell_type": "code",
95 | "metadata": {
96 | "id": "wl5CfAsAuaTu",
97 | "colab_type": "code",
98 | "colab": {}
99 | },
100 | "source": [
101 | "import torch\n",
102 | "import torch.nn as nn\n",
103 | "import torch.nn.functional as F\n",
104 | "from typing import List, Tuple\n",
105 | "import os\n",
106 | "import io\n",
107 | "import re\n",
108 | "import codecs\n",
109 | "import numpy as np\n",
110 | "from collections import Counter\n",
111 | "import math\n",
112 | "import time\n",
113 | "\n",
114 | "class LSTMLanguageModel(nn.Module):\n",
115 | "\n",
116 | "    def __init__(self, vocab_size, embed_dim, hidden_dim, dropout=0.5, layers=2):\n",
117 | "        super().__init__()\n",
118 | "        self.layers = layers\n",
119 | "        self.hidden_dim = hidden_dim\n",
120 | "        self.embed = nn.Embedding(vocab_size, embed_dim)\n",
121 | "        self.rnn = torch.nn.LSTM(embed_dim,\n",
122 | "                                 hidden_dim,\n",
123 | "                                 layers,\n",
124 | "                                 dropout=dropout,\n",
125 | "                                 bidirectional=False,\n",
126 | "                                 batch_first=True)\n",
127 | "        self.proj = nn.Linear(hidden_dim, vocab_size)\n",
128 | "        self.proj.bias.data.zero_()\n",
129 | "\n",
130 | "        # Tie input and output embeddings; this requires hidden_dim == embed_dim\n",
131 | "        self.proj.weight = self.embed.weight\n",
132 | "\n",
133 | "    def forward(self, x, hidden):\n",
134 | "        emb = self.embed(x)\n",
135 | "        decoded, hidden = self.rnn(emb, hidden)\n",
136 | "        return self.proj(decoded), hidden\n",
137 | "\n",
138 | "    def init_hidden(self, batchsz):\n",
139 | "        weight = next(self.parameters())\n",
140 | "        return (weight.new_zeros(self.layers, batchsz, self.hidden_dim),\n",
141 | "                weight.new_zeros(self.layers, batchsz, self.hidden_dim))\n"
142 | ],
143 | "execution_count": 0,
144 | "outputs": []
145 | },
146 | {
147 | "cell_type": "markdown",
148 | "metadata": {
149 | "id": "_fJp1BpIe3gu",
150 | "colab_type": "text"
151 | },
152 | "source": [
153 | "Our dataset reader reads in a sequence of words and vectorizes them. Ideally this is a long, contiguous stream of text (such as a book), which we read in order. Since our task is to predict the next word, the same sequence serves as both input and output: the targets are simply the inputs shifted one token ahead."
154 | ]
155 | },
156 | {
157 | "cell_type": "code",
158 | "metadata": {
159 | "id": "_C28CdeCrsj5",
160 | "colab_type": "code",
161 | "colab": {}
162 | },
163 | "source": [
164 | "\n",
165 | "\n",
166 | "\n",
167 | "class WordDatasetReader(object):\n",
168 | "    \"\"\"Read a whitespace-tokenized corpus and convert its words to tensors of vocabulary indices.\n",
169 | "    \"\"\"\n",
170 | "    def __init__(self, nctx, vectorizer=None):\n",
171 | "        self.nctx = nctx\n",
172 | "        self.num_words = {}\n",
173 | "        self.vectorizer = vectorizer if vectorizer else self._vectorizer\n",
174 | "\n",
175 | "    def build_vocab(self, files, min_freq=0):\n",
176 | "        x = Counter()\n",
177 | "\n",
178 | "        for file in files:\n",
179 | "            if file is None:\n",
180 | "                continue\n",
181 | "            self.num_words[file] = 0\n",
182 | "            with codecs.open(file, encoding='utf-8', mode='r') as f:\n",
183 | "                sentences = []\n",
184 | "                for line in f:\n",
185 | "                    split_sentence = line.split() + ['']\n",
186 | "                    self.num_words[file] += len(split_sentence)\n",
187 | "                    sentences += split_sentence\n",
188 | "            x.update(Counter(sentences))\n",
189 | "        x = dict(filter(lambda cnt: cnt[1] >= min_freq, x.items()))\n",
190 | "        alpha = list(x.keys())\n",
191 | "        alpha.sort()\n",
192 | "        self.vocab = {w: i+1 for i, w in enumerate(alpha)}\n",
193 | "        self.vocab['[PAD]'] = 0\n",
194 | "\n",
195 | "    def _vectorizer(self, words: List[str]) -> List[int]:\n",
196 | "        return [self.vocab.get(w, 0) for w in words]  # unknown words map to 0 ([PAD])\n",
197 | "\n",
198 | "\n",
199 | "    def load_features(self, filename):\n",
200 | "\n",
201 | "        with codecs.open(filename, encoding='utf-8', mode='r') as f:\n",
202 | "            sentences = []\n",
203 | "            for line in f:\n",
204 | "                sentences += line.strip().split() + ['']\n",
205 | "        return torch.tensor(self.vectorizer(sentences), dtype=torch.long)\n",
206 | "\n",
207 | "    def load(self, filename, batch_size):\n",
208 | "        x_tensor = self.load_features(filename)\n",
209 | "        rest = x_tensor.shape[0] // batch_size\n",
210 | "        # Targets are the inputs shifted by one token, so each row needs one token\n",
211 | "        # past its last full window. If rest is an exact multiple of nctx, that token\n",
212 | "        # is missing, so shorten each row by one (which drops the final window).\n",
213 | "        if rest % self.nctx == 0:\n",
214 | "            rest = rest - 1\n",
215 | "        trunc = batch_size * rest\n",
216 | "\n",
217 | "        # Keep the first batch_size * rest tokens and lay them out as batch_size contiguous rows\n",
218 | "        x_tensor = x_tensor.narrow(0, 0, trunc)\n",
219 | "        x_tensor = x_tensor.view(batch_size, -1).contiguous()\n",
220 | "        return x_tensor\n",
221 | "\n",
222 | ""
223 | ],
224 | "execution_count": 0,
225 | "outputs": []
226 | },
227 | {
228 | "cell_type": "markdown",
229 | "metadata": {
230 | "id": "EEw2NNZxfp9U",
231 | "colab_type": "text"
232 | },
233 | "source": [
234 | "This class keeps track of a running average as we go, so we don't have to accumulate and average values by hand in our loops."
235 | ]
236 | },
237 | {
238 | "cell_type": "code",
239 | "metadata": {
240 | "id": "pyuDxweqgD01",
241 | "colab_type": "code",
242 | "colab": {}
243 | },
244 | "source": [
245 | "class Average(object):\n",
246 | "    def __init__(self, name, fmt=':f'):\n",
247 | "        self.name = name\n",
248 | "        self.fmt = fmt\n",
249 | "        self.val = 0\n",
250 | "        self.avg = 0\n",
251 | "        self.sum = 0\n",
252 | "        self.count = 0\n",
253 | "\n",
254 | "    def update(self, val, n=1):\n",
255 | "        self.val = val\n",
256 | "        self.sum += val * n\n",
257 | "        self.count += n\n",
258 | "        self.avg = self.sum / self.count\n",
259 | "\n",
260 | "    def __str__(self):\n",
261 | "        fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'\n",
262 | "        return fmtstr.format(**self.__dict__)\n"
263 | ],
264 | "execution_count": 0,
265 | "outputs": []
266 | },
267 | {
268 | "cell_type": "markdown",
269 | "metadata": {
270 | "id": "JnLnechhgIqM",
271 | "colab_type": "text"
272 | },
273 | "source": [
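274 | "We are going to train on batches of contiguous text. The loader has already laid out the corpus as a `BxN` tensor of `B` contiguous streams, where `B` is the batch size we specified. At each step we slice out a `BxT` window, where `T` (`nctx`) is the number of time steps we backpropagate through; the targets are the same window shifted one token to the right."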
275 | ]
276 | },
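{
 "cell_type": "markdown",
 "metadata": {
  "colab_type": "text"
 },
 "source": [
  "Here is a small worked example of that layout, using a toy stream of made-up token ids (1 to 25) with `B = 2` and `T = 4`. It is a minimal sketch that repeats the truncation rule from `WordDatasetReader.load` and the slicing used by `LMTrainer.run`. The names `toy_stream`, `toy_batch`, `toy_nctx`, and `toy_data` are hypothetical, chosen so they do not overwrite the notebook's `batch_size` and `nctx`. Each target window is its input window shifted by one token. The `assert` checks that the final window still has a full set of targets, which is what the `rest - 1` adjustment guarantees."
 ]
},
{
 "cell_type": "code",
 "metadata": {
  "colab_type": "code",
  "colab": {}
 },
 "source": [
  "import torch\n",
  "\n",
  "# Toy stream of made-up token ids 1..25\n",
  "toy_stream = torch.arange(1, 26)\n",
  "toy_batch, toy_nctx = 2, 4\n",
  "\n",
  "# Truncation rule from WordDatasetReader.load\n",
  "rest = toy_stream.shape[0] // toy_batch   # 12 tokens per row\n",
  "if rest % toy_nctx == 0:\n",
  "    rest = rest - 1                       # 11: leave room for the shifted targets\n",
  "toy_data = toy_stream.narrow(0, 0, toy_batch * rest).view(toy_batch, -1)\n",
  "print(toy_data.tolist())\n",
  "\n",
  "# Slicing from LMTrainer.run: input windows and next-word targets\n",
  "for i in range(toy_data.shape[1] // toy_nctx):\n",
  "    x = toy_data[:, i * toy_nctx:(i + 1) * toy_nctx]\n",
  "    y = toy_data[:, i * toy_nctx + 1:(i + 1) * toy_nctx + 1]\n",
  "    assert x.shape == y.shape\n",
  "    print(i, x.tolist(), '->', y.tolist())\n",
  "# 0 [[1, 2, 3, 4], [12, 13, 14, 15]] -> [[2, 3, 4, 5], [13, 14, 15, 16]]\n",
  "# 1 [[5, 6, 7, 8], [16, 17, 18, 19]] -> [[6, 7, 8, 9], [17, 18, 19, 20]]"
 ],
 "execution_count": 0,
 "outputs": []
},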
277 | {
278 | "cell_type": "code",
279 | "metadata": {
280 | "id": "z-BDzx2hwIoS",
281 | "colab_type": "code",
282 | "colab": {}
283 | },
284 | "source": [
285 | "class SequenceCriterion(nn.Module):\n",
286 | "\n",
287 | "    def __init__(self):\n",
288 | "        super().__init__()\n",
289 | "        self.crit = nn.CrossEntropyLoss(ignore_index=0, reduction='mean')  # index 0 ([PAD] / unknown words) is ignored\n",
290 | "\n",
291 | "    def forward(self, inputs, targets):\n",
292 | "        \"\"\"Evaluate some loss over a sequence.\n",
293 | "\n",
294 | "        :param inputs: torch.FloatTensor, [..., C] The scores from the model; leading dims must match `targets`.\n",
295 | "        :param targets: torch.LongTensor, [...] The labels.\n",
296 | "\n",
297 | "        :returns: torch.FloatTensor, The mean loss over non-padding tokens.\n",
298 | "        \"\"\"\n",
299 | "        total_sz = targets.nelement()\n",
300 | "        loss = self.crit(inputs.view(total_sz, -1), targets.view(total_sz))\n",
301 | "        return loss\n",
302 | "\n",
303 | "class LMTrainer:\n",
304 | "\n",
305 | "    def __init__(self, optimizer: torch.optim.Optimizer, nctx):\n",
306 | "        self.optimizer = optimizer\n",
307 | "        self.nctx = nctx\n",
308 | "\n",
309 | "    def run(self, model, train_data, loss_function, batch_size=20, clip=0.25):\n",
310 | "        avg_loss = Average('average_train_loss')\n",
311 | "        metrics = {}\n",
312 | "        self.optimizer.zero_grad()\n",
313 | "        start = time.time()\n",
314 | "        model.train()\n",
315 | "        hidden = model.init_hidden(batch_size)\n",
316 | "        num_steps = train_data.shape[1] // self.nctx\n",
317 | "        for i in range(num_steps):\n",
318 | "            x = train_data[:, i*self.nctx:(i + 1)*self.nctx]\n",
319 | "            y = train_data[:, i*self.nctx + 1:(i + 1)*self.nctx + 1]\n",
320 | "            labels = y.to('cuda:0').transpose(0, 1).contiguous()\n",
321 | "            inputs = x.to('cuda:0')\n",
322 | "            logits, (h, c) = model(inputs, hidden)\n",
323 | "            hidden = (h.detach(), c.detach())  # truncated BPTT: keep the state, drop its history\n",
324 | "            logits = logits.transpose(0, 1).contiguous()\n",
325 | "            loss = loss_function(logits, labels)\n",
326 | "            loss.backward()\n",
327 | "\n",
328 | "            avg_loss.update(loss.item())\n",
329 | "\n",
330 | "            torch.nn.utils.clip_grad_norm_(model.parameters(), clip)\n",
331 | "            self.optimizer.step()\n",
332 | "            self.optimizer.zero_grad()\n",
333 | "            if (i + 1) % 100 == 0:\n",
334 | "                print(avg_loss)\n",
335 | "\n",
336 | "        # How much time elapsed in minutes\n",
337 | "        elapsed = (time.time() - start)/60\n",
338 | "        train_token_loss = avg_loss.avg\n",
339 | "        train_token_ppl = math.exp(train_token_loss)\n",
340 | "        metrics['train_elapsed_min'] = elapsed\n",
341 | "        metrics['average_train_loss'] = train_token_loss\n",
342 | "        metrics['train_ppl'] = train_token_ppl\n",
343 | "        return metrics\n",
344 | "\n",
345 | "class LMEvaluator:\n",
346 | "    def __init__(self, nctx):\n",
347 | "        self.nctx = nctx\n",
348 | "\n",
349 | "    def run(self, model, valid_data, loss_function, batch_size=20):\n",
350 | "        avg_valid_loss = Average('average_valid_loss')\n",
351 | "        start = time.time()\n",
352 | "        model.eval()\n",
353 | "        hidden = model.init_hidden(batch_size)\n",
354 | "        metrics = {}\n",
355 | "        num_steps = valid_data.shape[1] // self.nctx\n",
356 | "        for i in range(num_steps):\n",
357 | "\n",
358 | "            with torch.no_grad():\n",
359 | "                x = valid_data[:, i*self.nctx:(i + 1)*self.nctx]\n",
360 | "                y = valid_data[:, i*self.nctx + 1:(i + 1)*self.nctx + 1]\n",
361 | "                labels = y.to('cuda:0').transpose(0, 1).contiguous()\n",
362 | "                inputs = x.to('cuda:0')\n",
363 | "\n",
364 | "                logits, hidden = model(inputs, hidden)\n",
365 | "                logits = logits.transpose(0, 1).contiguous()\n",
366 | "                loss = loss_function(logits, labels)\n",
367 | "                avg_valid_loss.update(loss.item())\n",
368 | "\n",
369 | "        valid_token_loss = avg_valid_loss.avg\n",
370 | "        valid_token_ppl = math.exp(valid_token_loss)\n",
371 | "\n",
372 | "        elapsed = (time.time() - start)/60\n",
373 | "        metrics['valid_elapsed_min'] = elapsed\n",
374 | "\n",
375 | "        metrics['average_valid_loss'] = valid_token_loss\n",
376 | "        metrics['average_valid_word_ppl'] = valid_token_ppl\n",
377 | "        return metrics\n",
378 | "\n",
379 | "def fit_lm(model, optimizer, epochs, batch_size, nctx, train_data, valid_data):\n",
380 | "\n",
381 | "    loss = SequenceCriterion()\n",
382 | "    trainer = LMTrainer(optimizer, nctx)\n",
383 | "    evaluator = LMEvaluator(nctx)\n",
384 | "\n",
385 | "\n",
386 | "    metrics = evaluator.run(model, valid_data, loss, batch_size)\n",
387 | "\n",
388 | "    for epoch in range(epochs):\n",
389 | "\n",
390 | "        print('EPOCH {}'.format(epoch + 1))\n",
391 | "        print('=================================')\n",
392 | "        print('Training Results')\n",
393 | "        metrics = trainer.run(model, train_data, loss, batch_size)\n",
394 | "        print(metrics)\n",
395 | "        print('Validation Results')\n",
396 | "        metrics = evaluator.run(model, valid_data, loss, batch_size)\n",
397 | "        print(metrics)"
398 | ],
399 | "execution_count": 0,
400 | "outputs": []
401 | },
402 | {
403 | "cell_type": "markdown",
404 | "metadata": {
405 | "id": "iW1dxycFoihz",
406 | "colab_type": "text"
407 | },
408 | "source": [
409 | "Now we will train it on [Wikitext-2, Merity et al. 2016](https://blog.einstein.ai/the-wikitext-long-term-dependency-language-modeling-dataset/), using a batch size of 20 and backpropagating through 35 time steps (`nctx = 35`)."
410 | ]
411 | },
412 | {
413 | "cell_type": "code",
414 | "metadata": {
415 | "id": "aceEL3kxyJcq",
416 | "colab_type": "code",
417 | "colab": {}
418 | },
419 | "source": [
420 | "BASE = 'wikitext-2'\n",
421 | "TRAIN = os.path.join(BASE, 'wiki.train.tokens')\n",
422 | "VALID = os.path.join(BASE, 'wiki.valid.tokens')\n",
423 | "\n",
424 | "batch_size = 20\n",
425 | "nctx = 35\n",
426 | "reader = WordDatasetReader(nctx)\n",
427 | "reader.build_vocab((TRAIN,))\n",
428 | "\n",
429 | "train_set = reader.load(TRAIN, batch_size)\n",
430 | "valid_set = reader.load(VALID, batch_size)"
431 | ],
432 | "execution_count": 0,
433 | "outputs": []
434 | },
435 | {
436 | "cell_type": "markdown",
437 | "metadata": {
438 | "id": "UX5AqOMfo6_I",
439 | "colab_type": "text"
440 | },
441 | "source": [
442 | "Let's start with 1 epoch."
443 | ]
444 | },
445 | {
446 | "cell_type": "code",
447 | "metadata": {
448 | "id": "sKQuguge7fMY",
449 | "colab_type": "code",
450 | "colab": {
451 | "base_uri": "https://localhost:8080/",
452 | "height": 663
453 | },
454 | "outputId": "06a1ad3d-4fd4-4a6c-ed02-e0c3798f1e8c"
455 | },
456 | "source": [
457 | "\n",
458 | "model = LSTMLanguageModel(len(reader.vocab), 512, 512)\n",
459 | "model.to('cuda:0')\n",
460 | "\n",
461 | "num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)\n",
462 | "print(f\"Model has {num_params} parameters\") \n",
463 | "\n",
464 | "\n",
465 | "learnable_params = [p for p in model.parameters() if p.requires_grad]\n",
466 | "optimizer = torch.optim.Adam(learnable_params, lr=0.001)\n",
467 | "fit_lm(model, optimizer, 1, batch_size, nctx, train_set, valid_set)\n"
468 | ],
469 | "execution_count": 7,
470 | "outputs": [
471 | {
472 | "output_type": "stream",
473 | "text": [
474 | "Model has 21274623 parameters\n"
475 | ],
476 | "name": "stdout"
477 | },
478 | {
479 | "output_type": "stream",
480 | "text": [
481 | "/usr/local/lib/python3.6/dist-packages/torch/nn/_reduction.py:46: UserWarning: size_average and reduce args will be deprecated, please use reduction='mean' instead.\n",
482 | " warnings.warn(warning.format(ret))\n"
483 | ],
484 | "name": "stderr"
485 | },
486 | {
487 | "output_type": "stream",
488 | "text": [
489 | "EPOCH 1\n",
490 | "=================================\n",
491 | "Training Results\n",
492 | "average_train_loss 7.130287 (7.630262)\n",
493 | "average_train_loss 6.948112 (7.174242)\n",
494 | "average_train_loss 6.429612 (6.957250)\n",
495 | "average_train_loss 6.706582 (6.817735)\n",
496 | "average_train_loss 6.480259 (6.716661)\n",
497 | "average_train_loss 6.253604 (6.639940)\n",
498 | "average_train_loss 6.250593 (6.584427)\n",
499 | "average_train_loss 6.086081 (6.535446)\n",
500 | "average_train_loss 6.046218 (6.491021)\n",
501 | "average_train_loss 5.840661 (6.455732)\n",
502 | "average_train_loss 6.127773 (6.425025)\n",
503 | "average_train_loss 5.766460 (6.398616)\n",
504 | "average_train_loss 6.137995 (6.376816)\n",
505 | "average_train_loss 6.115303 (6.351095)\n",
506 | "average_train_loss 6.203366 (6.333509)\n",
507 | "average_train_loss 6.009459 (6.318195)\n",
508 | "average_train_loss 6.126120 (6.297565)\n",
509 | "average_train_loss 5.796104 (6.276726)\n",
510 | "average_train_loss 5.737082 (6.260670)\n",
511 | "average_train_loss 5.954897 (6.243683)\n",
512 | "average_train_loss 5.674878 (6.226430)\n",
513 | "average_train_loss 5.613625 (6.207307)\n",
514 | "average_train_loss 5.878324 (6.189868)\n",
515 | "average_train_loss 5.824322 (6.178013)\n",
516 | "average_train_loss 5.932457 (6.164326)\n",
517 | "average_train_loss 5.771354 (6.153274)\n",
518 | "average_train_loss 5.401644 (6.139012)\n",
519 | "average_train_loss 5.825085 (6.124258)\n",
520 | "average_train_loss 5.493943 (6.110777)\n",
521 | "{'train_elapsed_min': 2.2767986059188843, 'average_train_loss': 6.0969437521813985, 'train_ppl': 444.4971984264256}\n",
522 | "Validation Results\n",
523 | "{'valid_elapsed_min': 0.07908193667729696, 'average_valid_loss': 5.534342344345585, 'average_valid_word_ppl': 253.24118736148796}\n"
524 | ],
525 | "name": "stdout"
526 | }
527 | ]
528 | },
529 | {
530 | "cell_type": "markdown",
531 | "metadata": {
532 | "id": "lvywGcj9pJcA",
533 | "colab_type": "text"
534 | },
535 | "source": [
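536 | "We can sample from our language model using the code below. Starting from a seed word, it repeatedly draws the next word from the model's predicted distribution and feeds it back in as the next input."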
537 | ]
538 | },
539 | {
540 | "cell_type": "code",
541 | "metadata": {
542 | "id": "DWOgGxLbaefT",
543 | "colab_type": "code",
544 | "colab": {
545 | "base_uri": "https://localhost:8080/",
546 | "height": 34
547 | },
548 | "outputId": "e9d1f2f6-cc1c-4b8a-9469-6fdb042ebebb"
549 | },
550 | "source": [
551 | "def sample(model, index2word, start_word='the', maxlen=20):\n",
552 | "    # Draw `maxlen` words from the LM one at a time, feeding each choice back in\n",
553 | "\n",
554 | "    model.eval()\n",
555 | "    words = [start_word]\n",
556 | "    x = torch.tensor(reader.vocab.get(start_word)).long().reshape(1, 1).to('cuda:0')\n",
557 | "    hidden = model.init_hidden(1)\n",
558 | "\n",
559 | "    with torch.no_grad():\n",
560 | "        for i in range(maxlen):\n",
561 | "            output, hidden = model(x, hidden)\n",
562 | "            word_probs = F.softmax(output.squeeze(), dim=-1).cpu()\n",
563 | "            selected = torch.multinomial(word_probs, 1)[0]\n",
564 | "            x.fill_(selected)\n",
565 | "            word = index2word[selected.item()]\n",
566 | "            words.append(word)\n",
567 | "    words.append('...')\n",
568 | "    return words\n",
569 | "\n",
570 | "index2word = {i: w for w, i in reader.vocab.items()}\n",
571 | "words = sample(model, index2word)\n",
572 | "print(' '.join(words))\n"
573 | ],
574 | "execution_count": 8,
575 | "outputs": [
576 | {
577 | "output_type": "stream",
578 | "text": [
579 | "the latter story pass that would be in Park or Ireland . Like Liam Stuart illustrator , NC apologize and livestock ...\n"
580 | ],
581 | "name": "stdout"
582 | }
583 | ]
584 | },
585 | {
586 | "cell_type": "markdown",
587 | "metadata": {
588 | "id": "aXjMYTo7paIh",
589 | "colab_type": "text"
590 | },
591 | "source": [
592 | "Let's train for a few more epochs and sample again."
593 | ]
594 | },
595 | {
596 | "cell_type": "code",
597 | "metadata": {
598 | "id": "in_ftVIipSAS",
599 | "colab_type": "code",
600 | "colab": {
601 | "base_uri": "https://localhost:8080/",
602 | "height": 1000
603 | },
604 | "outputId": "b10e9154-2410-477e-8801-8c556b9492b5"
605 | },
606 | "source": [
607 | "fit_lm(model, optimizer, 3, batch_size, nctx, train_set, valid_set)"
608 | ],
609 | "execution_count": 9,
610 | "outputs": [
611 | {
612 | "output_type": "stream",
613 | "text": [
614 | "/usr/local/lib/python3.6/dist-packages/torch/nn/_reduction.py:46: UserWarning: size_average and reduce args will be deprecated, please use reduction='mean' instead.\n",
615 | " warnings.warn(warning.format(ret))\n"
616 | ],
617 | "name": "stderr"
618 | },
619 | {
620 | "output_type": "stream",
621 | "text": [
622 | "EPOCH 1\n",
623 | "=================================\n",
624 | "Training Results\n",
625 | "average_train_loss 5.888901 (5.598958)\n",
626 | "average_train_loss 5.818975 (5.584962)\n",
627 | "average_train_loss 5.503317 (5.580568)\n",
628 | "average_train_loss 5.877174 (5.584086)\n",
629 | "average_train_loss 5.637257 (5.563949)\n",
630 | "average_train_loss 5.447403 (5.540543)\n",
631 | "average_train_loss 5.460862 (5.531339)\n",
632 | "average_train_loss 5.488514 (5.525078)\n",
633 | "average_train_loss 5.359737 (5.517013)\n",
634 | "average_train_loss 5.121772 (5.509718)\n",
635 | "average_train_loss 5.441720 (5.503375)\n",
636 | "average_train_loss 5.280029 (5.499962)\n",
637 | "average_train_loss 5.543726 (5.500008)\n",
638 | "average_train_loss 5.556562 (5.494267)\n",
639 | "average_train_loss 5.593565 (5.495319)\n",
640 | "average_train_loss 5.347257 (5.496266)\n",
641 | "average_train_loss 5.519910 (5.489749)\n",
642 | "average_train_loss 5.264927 (5.483263)\n",
643 | "average_train_loss 5.207999 (5.481013)\n",
644 | "average_train_loss 5.434073 (5.476722)\n",
645 | "average_train_loss 5.112748 (5.471222)\n",
646 | "average_train_loss 5.142471 (5.463090)\n",
647 | "average_train_loss 5.362827 (5.455768)\n",
648 | "average_train_loss 5.287307 (5.454580)\n",
649 | "average_train_loss 5.420770 (5.449693)\n",
650 | "average_train_loss 5.358116 (5.448896)\n",
651 | "average_train_loss 5.019379 (5.443910)\n",
652 | "average_train_loss 5.375151 (5.437743)\n",
653 | "average_train_loss 5.061219 (5.432098)\n",
654 | "{'train_elapsed_min': 2.3279770851135253, 'average_train_loss': 5.425043830046748, 'train_ppl': 227.0212964512552}\n",
655 | "Validation Results\n",
656 | "{'valid_elapsed_min': 0.0792834202448527, 'average_valid_loss': 5.2854782812057, 'average_valid_word_ppl': 197.4485968212379}\n",
657 | "EPOCH 2\n",
658 | "=================================\n",
659 | "Training Results\n",
660 | "average_train_loss 5.594471 (5.241230)\n",
661 | "average_train_loss 5.464106 (5.238882)\n",
662 | "average_train_loss 5.156283 (5.241025)\n",
663 | "average_train_loss 5.523200 (5.251309)\n",
664 | "average_train_loss 5.250049 (5.232446)\n",
665 | "average_train_loss 5.129644 (5.205283)\n",
666 | "average_train_loss 5.083561 (5.195768)\n",
667 | "average_train_loss 5.166030 (5.192733)\n",
668 | "average_train_loss 5.086248 (5.188236)\n",
669 | "average_train_loss 4.777071 (5.182954)\n",
670 | "average_train_loss 5.161200 (5.178061)\n",
671 | "average_train_loss 5.008852 (5.176982)\n",
672 | "average_train_loss 5.142172 (5.180397)\n",
673 | "average_train_loss 5.281511 (5.176891)\n",
674 | "average_train_loss 5.325745 (5.180915)\n",
675 | "average_train_loss 5.139209 (5.184646)\n",
676 | "average_train_loss 5.210761 (5.179825)\n",
677 | "average_train_loss 5.041038 (5.176068)\n",
678 | "average_train_loss 4.949179 (5.175644)\n",
679 | "average_train_loss 5.189332 (5.173076)\n",
680 | "average_train_loss 4.856432 (5.168824)\n",
681 | "average_train_loss 4.836321 (5.162272)\n",
682 | "average_train_loss 5.064670 (5.156098)\n",
683 | "average_train_loss 4.960176 (5.156356)\n",
684 | "average_train_loss 5.104852 (5.152806)\n",
685 | "average_train_loss 5.087171 (5.153634)\n",
686 | "average_train_loss 4.799379 (5.150230)\n",
687 | "average_train_loss 5.135460 (5.145286)\n",
688 | "average_train_loss 4.837797 (5.140926)\n",
689 | "{'train_elapsed_min': 2.3252848744392396, 'average_train_loss': 5.135120418061357, 'train_ppl': 169.88477583838508}\n",
690 | "Validation Results\n",
691 | "{'valid_elapsed_min': 0.0793849547704061, 'average_valid_loss': 5.172104538640668, 'average_valid_word_ppl': 176.2854469011485}\n",
692 | "EPOCH 3\n",
693 | "=================================\n",
694 | "Training Results\n",
695 | "average_train_loss 5.397759 (5.021339)\n",
696 | "average_train_loss 5.223657 (5.019342)\n",
697 | "average_train_loss 4.937864 (5.024831)\n",
698 | "average_train_loss 5.331026 (5.037493)\n",
699 | "average_train_loss 5.013303 (5.017465)\n",
700 | "average_train_loss 4.930802 (4.989496)\n",
701 | "average_train_loss 4.877040 (4.980525)\n",
702 | "average_train_loss 5.008083 (4.979450)\n",
703 | "average_train_loss 4.872459 (4.977045)\n",
704 | "average_train_loss 4.537238 (4.972334)\n",
705 | "average_train_loss 4.895549 (4.967794)\n",
706 | "average_train_loss 4.778692 (4.967996)\n",
707 | "average_train_loss 4.913031 (4.972698)\n",
708 | "average_train_loss 4.998916 (4.970212)\n",
709 | "average_train_loss 5.101923 (4.975337)\n",
710 | "average_train_loss 4.911569 (4.980182)\n",
711 | "average_train_loss 5.020348 (4.975942)\n",
712 | "average_train_loss 4.840934 (4.973322)\n",
713 | "average_train_loss 4.820501 (4.974319)\n",
714 | "average_train_loss 4.994349 (4.972502)\n",
715 | "average_train_loss 4.695219 (4.968739)\n",
716 | "average_train_loss 4.709369 (4.962705)\n",
717 | "average_train_loss 4.882851 (4.957035)\n",
718 | "average_train_loss 4.744756 (4.957699)\n",
719 | "average_train_loss 4.943813 (4.954420)\n",
720 | "average_train_loss 4.885959 (4.955989)\n",
721 | "average_train_loss 4.626930 (4.953493)\n",
722 | "average_train_loss 4.934981 (4.949243)\n",
723 | "average_train_loss 4.673239 (4.945310)\n",
724 | "{'train_elapsed_min': 2.333573551972707, 'average_train_loss': 4.939903614627328, 'train_ppl': 139.75677840163237}\n",
725 | "Validation Results\n",
726 | "{'valid_elapsed_min': 0.07996076345443726, 'average_valid_loss': 5.096776191649899, 'average_valid_word_ppl': 163.49398352530844}\n"
727 | ],
728 | "name": "stdout"
729 | }
730 | ]
731 | },
732 | {
733 | "cell_type": "code",
734 | "metadata": {
735 | "id": "K-9gBJE9pu2G",
736 | "colab_type": "code",
737 | "colab": {
738 | "base_uri": "https://localhost:8080/",
739 | "height": 34
740 | },
741 | "outputId": "34eb8d38-1eff-4fff-dec4-458cc9367b80"
742 | },
743 | "source": [
744 | "index2word = {i: w for w, i in reader.vocab.items()}\n",
745 | "words = sample(model, index2word)\n",
746 | "print(' '.join(words))"
747 | ],
748 | "execution_count": 10,
749 | "outputs": [
750 | {
751 | "output_type": "stream",
752 | "text": [
753 | "the Supreme Court did not introduce any contact with its way and Grosser Davies of the chance of the country . ...\n"
754 | ],
755 | "name": "stdout"
756 | }
757 | ]
758 | },
759 | {
760 | "cell_type": "markdown",
761 | "metadata": {
762 | "id": "HKbIXgu6rJs8",
763 | "colab_type": "text"
764 | },
765 | "source": [
766 | "\n",
767 | "## ELMo\n",
768 | "\n",
769 | "For the rest of this section, we will focus on ELMo ([Peters et al. 2018](https://export.arxiv.org/pdf/1802.05365)), a language model with an embedding layer followed by 2 LSTM layers. At training time, ELMo is essentially two LMs: one running in the forward direction and one running in the backward direction. Their losses are combined into a single training objective. At inference time, the forward and backward hidden states at each layer are concatenated into a single bidirectional representation.\n",
770 | "\n",
771 | "In our example, we created a word-based LM. You might have been wondering what to do about words that we haven't seen yet, and that is a valid concern! (Our reader simply maps them all to index 0, which the loss ignores.) Instead of using a word embedding layer like our example above, what if we had a model that used a character compositional approach, taking the characters in each word and applying a pooling operation to yield a word representation? Such a model could handle words that it has never seen in the input before.\n",
772 | "\n",
773 | "This is exactly what ELMo does: it's based on the research of [Kim et al. 2015](https://arxiv.org/abs/1508.06615).\n",
774 | "\n",
775 | "There is a nice [slide deck by the authors here](http://www.people.fas.harvard.edu/~yoonkim/data/char-nlm-slides.pdf), but the key high-level points are listed below:\n",
776 | "\n",
777 | "### Kim Language Model\n",
778 | "\n",
779 | "* **Goal**: predict the next word in the sentence (causal LM), but account for unseen words by using a character-compositional approach that relies on the letters within the pre-segmented words. This also drastically reduces the number of parameters compared to word-level models.\n",
780 | "\n",
781 | "* **Using**: LSTM layers that take in a word representation at each position. Each word representation is fed in and used to predict the next word given the preceding context.\n",
782 | "\n",
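783 | "* **The Twist**: represent words with the character-embedding approach from [dos Santos & Zadrozny 2014](http://proceedings.mlr.press/v32/santos14.pdf) (a convolution over character embeddings followed by max-pooling), but use multiple parallel filters of different widths as in [Kim 2014](https://www.aclweb.org/anthology/D14-1181). Also, add highway layers between this character-level word representation and the LSTM.\n",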
784 | "\n",
785 | "\n",
786 | "\n",
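"The next cell is a minimal sketch of a character-compositional word encoder in the spirit of Kim et al.: character embeddings, parallel convolutions of several widths, max-pooling over character positions, and one highway layer whose output would be the LSTM input. The class name `CharCNNWordEncoder`, the filter widths and counts, and the tanh activation are illustrative assumptions rather than the exact ELMo configuration.\n"
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"import torch\n",
"import torch.nn as nn\n",
"import torch.nn.functional as F\n",
"\n",
"# Hypothetical encoder: filter sizes and activation are illustrative, not ELMo's exact setup\n",
"class CharCNNWordEncoder(nn.Module):\n",
"    def __init__(self, num_chars, char_dim=16, filters=((1, 32), (2, 32), (3, 64))):\n",
"        super().__init__()\n",
"        self.char_embed = nn.Embedding(num_chars, char_dim, padding_idx=0)\n",
"        # Parallel convolutions of different widths (w), with n filters each\n",
"        self.convs = nn.ModuleList([nn.Conv1d(char_dim, n, kernel_size=w) for w, n in filters])\n",
"        out_dim = sum(n for _, n in filters)\n",
"        # Highway layer: y = t * relu(W_h x) + (1 - t) * x, with gate t = sigmoid(W_t x)\n",
"        self.transform = nn.Linear(out_dim, out_dim)\n",
"        self.gate = nn.Linear(out_dim, out_dim)\n",
"\n",
"    def forward(self, chars):\n",
"        # chars: (B, T, C) character ids per word, 0 = padding\n",
"        B, T, C = chars.shape\n",
"        x = self.char_embed(chars.reshape(B * T, C))  # (B*T, C, D)\n",
"        x = x.transpose(1, 2)                          # (B*T, D, C), as Conv1d expects\n",
"        # Max-pool each filter over character positions: one value per filter\n",
"        pooled = [torch.tanh(conv(x)).max(dim=-1)[0] for conv in self.convs]\n",
"        h = torch.cat(pooled, dim=-1)                  # (B*T, out_dim)\n",
"        t = torch.sigmoid(self.gate(h))\n",
"        h = t * F.relu(self.transform(h)) + (1 - t) * h\n",
"        return h.view(B, T, -1)                        # one vector per word, fed to the LSTM\n",
"\n",
"encoder = CharCNNWordEncoder(num_chars=50)\n",
"chars = torch.randint(1, 50, (2, 5, 8))  # 2 sentences, 5 words, 8 characters per word\n",
"print(encoder(chars).shape)  # torch.Size([2, 5, 128])"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {},
"source": [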
787 | "### ELMo Language Model\n",
788 | "\n",
789 | "* **Goal**: predict the next word in the sentence (causal LM) on the forward sequence **and** predict the previous word in the sentence conditioned on the following context.\n",
790 | "\n",
791 | "* **Using**: LSTM layers as before, but bidirectional (a separate forward and backward LSTM stack), with the forward and backward losses combined into one joint loss.\n",
792 | "\n",
793 | "* **The Twist**: potentially use all layers of the model (we don't need the head with the large softmax over the vocabulary at the end). After pre-training, we can freeze the biLM but still extract useful information by learning a linear combination of its layers during downstream training. These scalar weights don't exist during biLM training; they are learned only for the downstream task.\n"
794 | ]
795 | },
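{
"cell_type": "markdown",
"metadata": {},
"source": [
"The learned layer combination from the last bullet can be sketched in a few lines. The next cell is a hypothetical `TinyScalarMix` module following the weighting described in the ELMo paper: a softmax-normalized weight per layer plus one global scale factor (gamma). The random tensors merely stand in for the frozen biLM's layer outputs (embedding layer plus 2 LSTM layers); in a downstream model, only these few scalars would be trained while the biLM stays frozen.\n"
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"import torch\n",
"import torch.nn as nn\n",
"\n",
"# Hypothetical re-implementation of ELMo's task-specific layer weighting\n",
"class TinyScalarMix(nn.Module):\n",
"    def __init__(self, num_layers):\n",
"        super().__init__()\n",
"        self.weights = nn.Parameter(torch.zeros(num_layers))  # softmax-normalized in forward\n",
"        self.gamma = nn.Parameter(torch.ones(1))               # global scale factor\n",
"\n",
"    def forward(self, layers):\n",
"        # layers: list of (B, T, H) tensors, one per biLM layer\n",
"        s = torch.softmax(self.weights, dim=0)\n",
"        return self.gamma * sum(w * h for w, h in zip(s, layers))\n",
"\n",
"# Stand-ins for frozen biLM outputs: embedding layer + 2 LSTM layers, each (B, T, 2 * hidden)\n",
"bilm_layers = [torch.randn(2, 7, 64) for _ in range(3)]\n",
"\n",
"mix = TinyScalarMix(num_layers=3)\n",
"elmo = mix(bilm_layers)\n",
"print(elmo.shape)  # torch.Size([2, 7, 64])\n",
"\n",
"# Gradients flow only into the mixing scalars; the biLM outputs stay fixed\n",
"elmo.sum().backward()\n",
"print(mix.weights.grad is not None, mix.gamma.grad is not None)  # True True"
],
"execution_count": null,
"outputs": []
},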
796 | {
797 | "cell_type": "markdown",
798 | "metadata": {
799 | "id": "_19LSBBrq-fD",
800 | "colab_type": "text"
801 | },
802 | "source": [
803 | "\n",
804 | "### ELMo with AllenNLP\n",
805 | "\n",
806 | "Even though ELMo is just a network like the one described above, setting it up and reloading the provided pre-trained checkpoints involves a lot of details that are not really important for demonstration purposes. So, we will just install [AllenNLP](https://github.com/allenai/allennlp) and use it as a contextual embedding layer.\n",
807 | "\n",
808 | "If you are interested in learning more about using ELMo with AllenNLP, they have provided a [tutorial here](https://github.com/allenai/allennlp/blob/master/tutorials/how_to/elmo.md).\n",
809 | "\n",
810 | "#### TensorFlow and ELMo\n",
811 | "\n",
812 | "ELMo was originally trained with TensorFlow. You can find the code to train and use it in the [bilm-tf repository](https://github.com/allenai/bilm-tf/tree/master/bilm).\n",
813 | "\n",
814 | "TF-Hub contains the [pre-trained ELMo model](https://tfhub.dev/google/elmo/2), which is easy to integrate if you are already using TensorFlow. The model takes a sequence of (mixed-case) words as input and can simply be \"glued\" into your existing model as a sub-graph.\n"
815 | ]
816 | },
817 | {
818 | "cell_type": "code",
819 | "metadata": {
820 | "id": "ziMb2Iphr4vJ",
821 | "colab_type": "code",
822 | "outputId": "53b96581-c495-41ca-c255-b1e69b321971",
823 | "colab": {
824 | "base_uri": "https://localhost:8080/",
825 | "height": 1000
826 | }
827 | },
828 | "source": [
829 | "!pip install allennlp"
830 | ],
831 | "execution_count": 11,
832 | "outputs": [
833 | {
834 | "output_type": "stream",
835 | "text": [
836 | "Collecting allennlp\n",
837 | " Using cached https://files.pythonhosted.org/packages/30/8c/72b14d20c9cbb0306939ea41109fc599302634fd5c59ccba1a659b7d0360/allennlp-0.8.4-py3-none-any.whl\n",
838 | "Collecting jsonnet>=0.10.0; sys_platform != \"win32\" (from allennlp)\n",
839 | "\u001b[?25l Downloading https://files.pythonhosted.org/packages/a9/a8/adba6cd0f84ee6ab064e7f70cd03a2836cefd2e063fd565180ec13beae93/jsonnet-0.13.0.tar.gz (255kB)\n",
840 | "\u001b[K |████████████████████████████████| 256kB 3.4MB/s \n",
841 | "\u001b[?25hCollecting numpydoc>=0.8.0 (from allennlp)\n",
842 | " Downloading https://files.pythonhosted.org/packages/6a/f3/7cfe4c616e4b9fe05540256cc9c6661c052c8a4cec2915732793b36e1843/numpydoc-0.9.1.tar.gz\n",
843 | "Collecting pytorch-pretrained-bert>=0.6.0 (from allennlp)\n",
844 | "\u001b[?25l Downloading https://files.pythonhosted.org/packages/d7/e0/c08d5553b89973d9a240605b9c12404bcf8227590de62bae27acbcfe076b/pytorch_pretrained_bert-0.6.2-py3-none-any.whl (123kB)\n",
845 | "\u001b[K |████████████████████████████████| 133kB 49.5MB/s \n",
846 | "\u001b[?25hRequirement already satisfied: sqlparse>=0.2.4 in /usr/local/lib/python3.6/dist-packages (from allennlp) (0.3.0)\n",
847 | "Requirement already satisfied: pytest in /usr/local/lib/python3.6/dist-packages (from allennlp) (3.6.4)\n",
848 | "Collecting parsimonious>=0.8.0 (from allennlp)\n",
849 | " Using cached https://files.pythonhosted.org/packages/02/fc/067a3f89869a41009e1a7cdfb14725f8ddd246f30f63c645e8ef8a1c56f4/parsimonious-0.8.1.tar.gz\n",
850 | "Collecting conllu==0.11 (from allennlp)\n",
851 | " Using cached https://files.pythonhosted.org/packages/d4/2c/856344d9b69baf5b374c395b4286626181a80f0c2b2f704914d18a1cea47/conllu-0.11-py2.py3-none-any.whl\n",
852 | "Collecting overrides (from allennlp)\n",
853 | " Downloading https://files.pythonhosted.org/packages/de/55/3100c6d14c1ed177492fcf8f07c4a7d2d6c996c0a7fc6a9a0a41308e7eec/overrides-1.9.tar.gz\n",
854 | "Collecting awscli>=1.11.91 (from allennlp)\n",
855 | "\u001b[?25l Downloading https://files.pythonhosted.org/packages/20/fa/f4b6207d59267da0be60be3df32682d2c7479122c7cb87556bd4412675fe/awscli-1.16.190-py2.py3-none-any.whl (1.7MB)\n",
856 | "\u001b[K |████████████████████████████████| 1.7MB 51.2MB/s \n",
857 | "\u001b[?25hRequirement already satisfied: spacy<2.2,>=2.0.18 in /usr/local/lib/python3.6/dist-packages (from allennlp) (2.1.4)\n",
858 | "Collecting tensorboardX>=1.2 (from allennlp)\n",
859 | "\u001b[?25l Downloading https://files.pythonhosted.org/packages/a2/57/2f0a46538295b8e7f09625da6dd24c23f9d0d7ef119ca1c33528660130d5/tensorboardX-1.7-py2.py3-none-any.whl (238kB)\n",
860 | "\u001b[K |████████████████████████████████| 245kB 52.8MB/s \n",
861 | "\u001b[?25hRequirement already satisfied: boto3 in /usr/local/lib/python3.6/dist-packages (from allennlp) (1.9.175)\n",
862 | "Requirement already satisfied: tqdm>=4.19 in /usr/local/lib/python3.6/dist-packages (from allennlp) (4.28.1)\n",
863 | "Requirement already satisfied: scipy in /usr/local/lib/python3.6/dist-packages (from allennlp) (1.3.0)\n",
864 | "Requirement already satisfied: requests>=2.18 in /usr/local/lib/python3.6/dist-packages (from allennlp) (2.21.0)\n",
865 | "Requirement already satisfied: h5py in /usr/local/lib/python3.6/dist-packages (from allennlp) (2.8.0)\n",
866 | "Requirement already satisfied: gevent>=1.3.6 in /usr/local/lib/python3.6/dist-packages (from allennlp) (1.4.0)\n",
867 | "Collecting jsonpickle (from allennlp)\n",
868 | " Downloading https://files.pythonhosted.org/packages/07/07/c157520a3ebd166c8c24c6ae0ecae7c3968eb4653ff0e5af369bb82f004d/jsonpickle-1.2-py2.py3-none-any.whl\n",
869 | "Collecting flask-cors>=3.0.7 (from allennlp)\n",
870 | " Downloading https://files.pythonhosted.org/packages/78/38/e68b11daa5d613e3a91e4bf3da76c94ac9ee0d9cd515af9c1ab80d36f709/Flask_Cors-3.0.8-py2.py3-none-any.whl\n",
871 | "Requirement already satisfied: numpy in /usr/local/lib/python3.6/dist-packages (from allennlp) (1.16.4)\n",
872 | "Requirement already satisfied: flask>=1.0.2 in /usr/local/lib/python3.6/dist-packages (from allennlp) (1.0.3)\n",
873 | "Requirement already satisfied: pytz>=2017.3 in /usr/local/lib/python3.6/dist-packages (from allennlp) (2018.9)\n",
874 | "Requirement already satisfied: nltk in /usr/local/lib/python3.6/dist-packages (from allennlp) (3.2.5)\n",
875 | "Collecting ftfy (from allennlp)\n",
876 | "\u001b[?25l Downloading https://files.pythonhosted.org/packages/8f/86/df789c5834f15ae1ca53a8d4c1fc4788676c2e32112f6a786f2625d9c6e6/ftfy-5.5.1-py3-none-any.whl (43kB)\n",
877 | "\u001b[K |████████████████████████████████| 51kB 26.0MB/s \n",
878 | "\u001b[?25hRequirement already satisfied: editdistance in /usr/local/lib/python3.6/dist-packages (from allennlp) (0.5.3)\n",
879 | "Requirement already satisfied: torch>=0.4.1 in /usr/local/lib/python3.6/dist-packages (from allennlp) (1.1.0)\n",
880 | "Collecting word2number>=1.1 (from allennlp)\n",
881 | " Downloading https://files.pythonhosted.org/packages/4a/29/a31940c848521f0725f0df6b25dca8917f13a2025b0e8fcbe5d0457e45e6/word2number-1.1.zip\n",
882 | "Collecting responses>=0.7 (from allennlp)\n",
883 | " Using cached https://files.pythonhosted.org/packages/d1/5a/b887e89925f1de7890ef298a74438371ed4ed29b33def9e6d02dc6036fd8/responses-0.10.6-py2.py3-none-any.whl\n",
884 | "Requirement already satisfied: matplotlib>=2.2.3 in /usr/local/lib/python3.6/dist-packages (from allennlp) (3.0.3)\n",
885 | "Collecting flaky (from allennlp)\n",
886 | " Downloading https://files.pythonhosted.org/packages/ae/09/94d623dda1adacd51722f3e3e0f88ba08dd030ac2b2662bfb4383096340d/flaky-3.6.0-py2.py3-none-any.whl\n",
887 | "Requirement already satisfied: scikit-learn in /usr/local/lib/python3.6/dist-packages (from allennlp) (0.21.2)\n",
888 | "Collecting unidecode (from allennlp)\n",
889 | "\u001b[?25l Downloading https://files.pythonhosted.org/packages/d0/42/d9edfed04228bacea2d824904cae367ee9efd05e6cce7ceaaedd0b0ad964/Unidecode-1.1.1-py2.py3-none-any.whl (238kB)\n",
890 | "\u001b[K |████████████████████████████████| 245kB 27.3MB/s \n",
891 | "\u001b[?25hRequirement already satisfied: sphinx>=1.6.5 in /usr/local/lib/python3.6/dist-packages (from numpydoc>=0.8.0->allennlp) (1.8.5)\n",
892 | "Requirement already satisfied: Jinja2>=2.3 in /usr/local/lib/python3.6/dist-packages (from numpydoc>=0.8.0->allennlp) (2.10.1)\n",
893 | "Collecting regex (from pytorch-pretrained-bert>=0.6.0->allennlp)\n",
894 | "\u001b[?25l Downloading https://files.pythonhosted.org/packages/6f/4e/1b178c38c9a1a184288f72065a65ca01f3154df43c6ad898624149b8b4e0/regex-2019.06.08.tar.gz (651kB)\n",
895 | "\u001b[K |████████████████████████████████| 655kB 50.1MB/s \n",
896 | "\u001b[?25hRequirement already satisfied: six>=1.10.0 in /usr/local/lib/python3.6/dist-packages (from pytest->allennlp) (1.12.0)\n",
897 | "Requirement already satisfied: py>=1.5.0 in /usr/local/lib/python3.6/dist-packages (from pytest->allennlp) (1.8.0)\n",
898 | "Requirement already satisfied: more-itertools>=4.0.0 in /usr/local/lib/python3.6/dist-packages (from pytest->allennlp) (7.0.0)\n",
899 | "Requirement already satisfied: setuptools in /usr/local/lib/python3.6/dist-packages (from pytest->allennlp) (41.0.1)\n",
900 | "Requirement already satisfied: atomicwrites>=1.0 in /usr/local/lib/python3.6/dist-packages (from pytest->allennlp) (1.3.0)\n",
901 | "Requirement already satisfied: attrs>=17.4.0 in /usr/local/lib/python3.6/dist-packages (from pytest->allennlp) (19.1.0)\n",
902 | "Requirement already satisfied: pluggy<0.8,>=0.5 in /usr/local/lib/python3.6/dist-packages (from pytest->allennlp) (0.7.1)\n",
903 | "Requirement already satisfied: PyYAML<=5.1,>=3.10; python_version != \"2.6\" in /usr/local/lib/python3.6/dist-packages (from awscli>=1.11.91->allennlp) (3.13)\n",
904 | "Requirement already satisfied: s3transfer<0.3.0,>=0.2.0 in /usr/local/lib/python3.6/dist-packages (from awscli>=1.11.91->allennlp) (0.2.1)\n",
905 | "Collecting botocore==1.12.180 (from awscli>=1.11.91->allennlp)\n",
906 | "\u001b[?25l Downloading https://files.pythonhosted.org/packages/3b/27/fa7da6feb20d1dfc0ab562226061b20da2d27ea18ca32dc764fe86704a99/botocore-1.12.180-py2.py3-none-any.whl (5.6MB)\n",
907 | "\u001b[K |████████████████████████████████| 5.6MB 35.1MB/s \n",
908 | "\u001b[?25hCollecting rsa<=3.5.0,>=3.1.2 (from awscli>=1.11.91->allennlp)\n",
909 | "\u001b[?25l Downloading https://files.pythonhosted.org/packages/e1/ae/baedc9cb175552e95f3395c43055a6a5e125ae4d48a1d7a924baca83e92e/rsa-3.4.2-py2.py3-none-any.whl (46kB)\n",
910 | "\u001b[K |████████████████████████████████| 51kB 26.7MB/s \n",
911 | "\u001b[?25hRequirement already satisfied: docutils>=0.10 in /usr/local/lib/python3.6/dist-packages (from awscli>=1.11.91->allennlp) (0.14)\n",
912 | "Collecting colorama<=0.3.9,>=0.2.5 (from awscli>=1.11.91->allennlp)\n",
913 | " Downloading https://files.pythonhosted.org/packages/db/c8/7dcf9dbcb22429512708fe3a547f8b6101c0d02137acbd892505aee57adf/colorama-0.3.9-py2.py3-none-any.whl\n",
914 | "Requirement already satisfied: thinc<7.1.0,>=7.0.2 in /usr/local/lib/python3.6/dist-packages (from spacy<2.2,>=2.0.18->allennlp) (7.0.4)\n",
915 | "Requirement already satisfied: blis<0.3.0,>=0.2.2 in /usr/local/lib/python3.6/dist-packages (from spacy<2.2,>=2.0.18->allennlp) (0.2.4)\n",
916 | "Requirement already satisfied: cymem<2.1.0,>=2.0.2 in /usr/local/lib/python3.6/dist-packages (from spacy<2.2,>=2.0.18->allennlp) (2.0.2)\n",
917 | "Requirement already satisfied: jsonschema<3.1.0,>=2.6.0 in /usr/local/lib/python3.6/dist-packages (from spacy<2.2,>=2.0.18->allennlp) (2.6.0)\n",
918 | "Requirement already satisfied: preshed<2.1.0,>=2.0.1 in /usr/local/lib/python3.6/dist-packages (from spacy<2.2,>=2.0.18->allennlp) (2.0.1)\n",
919 | "Requirement already satisfied: srsly<1.1.0,>=0.0.5 in /usr/local/lib/python3.6/dist-packages (from spacy<2.2,>=2.0.18->allennlp) (0.0.7)\n",
920 | "Requirement already satisfied: wasabi<1.1.0,>=0.2.0 in /usr/local/lib/python3.6/dist-packages (from spacy<2.2,>=2.0.18->allennlp) (0.2.2)\n",
921 | "Requirement already satisfied: murmurhash<1.1.0,>=0.28.0 in /usr/local/lib/python3.6/dist-packages (from spacy<2.2,>=2.0.18->allennlp) (1.0.2)\n",
922 | "Requirement already satisfied: plac<1.0.0,>=0.9.6 in /usr/local/lib/python3.6/dist-packages (from spacy<2.2,>=2.0.18->allennlp) (0.9.6)\n",
923 | "Requirement already satisfied: protobuf>=3.2.0 in /usr/local/lib/python3.6/dist-packages (from tensorboardX>=1.2->allennlp) (3.7.1)\n",
924 | "Requirement already satisfied: jmespath<1.0.0,>=0.7.1 in /usr/local/lib/python3.6/dist-packages (from boto3->allennlp) (0.9.4)\n",
925 | "Requirement already satisfied: urllib3<1.25,>=1.21.1 in /usr/local/lib/python3.6/dist-packages (from requests>=2.18->allennlp) (1.24.3)\n",
926 | "Requirement already satisfied: idna<2.9,>=2.5 in /usr/local/lib/python3.6/dist-packages (from requests>=2.18->allennlp) (2.8)\n",
927 | "Requirement already satisfied: chardet<3.1.0,>=3.0.2 in /usr/local/lib/python3.6/dist-packages (from requests>=2.18->allennlp) (3.0.4)\n",
928 | "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.6/dist-packages (from requests>=2.18->allennlp) (2019.6.16)\n",
929 | "Requirement already satisfied: greenlet>=0.4.14; platform_python_implementation == \"CPython\" in /usr/local/lib/python3.6/dist-packages (from gevent>=1.3.6->allennlp) (0.4.15)\n",
930 | "Requirement already satisfied: Werkzeug>=0.14 in /usr/local/lib/python3.6/dist-packages (from flask>=1.0.2->allennlp) (0.15.4)\n",
931 | "Requirement already satisfied: click>=5.1 in /usr/local/lib/python3.6/dist-packages (from flask>=1.0.2->allennlp) (7.0)\n",
932 | "Requirement already satisfied: itsdangerous>=0.24 in /usr/local/lib/python3.6/dist-packages (from flask>=1.0.2->allennlp) (1.1.0)\n",
933 | "Requirement already satisfied: wcwidth in /usr/local/lib/python3.6/dist-packages (from ftfy->allennlp) (0.1.7)\n",
934 | "Requirement already satisfied: kiwisolver>=1.0.1 in /usr/local/lib/python3.6/dist-packages (from matplotlib>=2.2.3->allennlp) (1.1.0)\n",
935 | "Requirement already satisfied: cycler>=0.10 in /usr/local/lib/python3.6/dist-packages (from matplotlib>=2.2.3->allennlp) (0.10.0)\n",
936 | "Requirement already satisfied: python-dateutil>=2.1 in /usr/local/lib/python3.6/dist-packages (from matplotlib>=2.2.3->allennlp) (2.5.3)\n",
937 | "Requirement already satisfied: pyparsing!=2.0.4,!=2.1.2,!=2.1.6,>=2.0.1 in /usr/local/lib/python3.6/dist-packages (from matplotlib>=2.2.3->allennlp) (2.4.0)\n",
938 | "Requirement already satisfied: joblib>=0.11 in /usr/local/lib/python3.6/dist-packages (from scikit-learn->allennlp) (0.13.2)\n",
939 | "Requirement already satisfied: Pygments>=2.0 in /usr/local/lib/python3.6/dist-packages (from sphinx>=1.6.5->numpydoc>=0.8.0->allennlp) (2.1.3)\n",
940 | "Requirement already satisfied: snowballstemmer>=1.1 in /usr/local/lib/python3.6/dist-packages (from sphinx>=1.6.5->numpydoc>=0.8.0->allennlp) (1.2.1)\n",
941 | "Requirement already satisfied: sphinxcontrib-websupport in /usr/local/lib/python3.6/dist-packages (from sphinx>=1.6.5->numpydoc>=0.8.0->allennlp) (1.1.2)\n",
942 | "Requirement already satisfied: packaging in /usr/local/lib/python3.6/dist-packages (from sphinx>=1.6.5->numpydoc>=0.8.0->allennlp) (19.0)\n",
943 | "Requirement already satisfied: imagesize in /usr/local/lib/python3.6/dist-packages (from sphinx>=1.6.5->numpydoc>=0.8.0->allennlp) (1.1.0)\n",
944 | "Requirement already satisfied: babel!=2.0,>=1.3 in /usr/local/lib/python3.6/dist-packages (from sphinx>=1.6.5->numpydoc>=0.8.0->allennlp) (2.7.0)\n",
945 | "Requirement already satisfied: alabaster<0.8,>=0.7 in /usr/local/lib/python3.6/dist-packages (from sphinx>=1.6.5->numpydoc>=0.8.0->allennlp) (0.7.12)\n",
946 | "Requirement already satisfied: MarkupSafe>=0.23 in /usr/local/lib/python3.6/dist-packages (from Jinja2>=2.3->numpydoc>=0.8.0->allennlp) (1.1.1)\n",
947 | "Requirement already satisfied: pyasn1>=0.1.3 in /usr/local/lib/python3.6/dist-packages (from rsa<=3.5.0,>=3.1.2->awscli>=1.11.91->allennlp) (0.4.5)\n",
948 | "Building wheels for collected packages: jsonnet, numpydoc, parsimonious, overrides, word2number, regex\n",
949 | " Building wheel for jsonnet (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
950 | " Stored in directory: /root/.cache/pip/wheels/1a/30/ab/ae4a57b1df44fa20a531edb9601b27603da8f5336225691f3f\n",
951 | " Building wheel for numpydoc (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
952 | " Stored in directory: /root/.cache/pip/wheels/51/30/d1/92a39ba40f21cb70e53f8af96eb98f002a781843c065406500\n",
953 | " Building wheel for parsimonious (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
954 | " Stored in directory: /root/.cache/pip/wheels/b7/8d/e7/a0e74217da5caeb3c1c7689639b6d28ddbf9985b840bc96a9a\n",
955 | " Building wheel for overrides (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
956 | " Stored in directory: /root/.cache/pip/wheels/8d/52/86/e5a83b1797e7d263b458d2334edd2704c78508b3eea9323718\n",
957 | " Building wheel for word2number (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
958 | " Stored in directory: /root/.cache/pip/wheels/46/2f/53/5f5c1d275492f2fce1cdab9a9bb12d49286dead829a4078e0e\n",
959 | " Building wheel for regex (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
960 | " Stored in directory: /root/.cache/pip/wheels/35/e4/80/abf3b33ba89cf65cd262af8a22a5a999cc28fbfabea6b38473\n",
961 | "Successfully built jsonnet numpydoc parsimonious overrides word2number regex\n",
962 | "Installing collected packages: jsonnet, numpydoc, regex, pytorch-pretrained-bert, parsimonious, conllu, overrides, botocore, rsa, colorama, awscli, tensorboardX, jsonpickle, flask-cors, ftfy, word2number, responses, flaky, unidecode, allennlp\n",
963 | " Found existing installation: botocore 1.12.175\n",
964 | " Uninstalling botocore-1.12.175:\n",
965 | " Successfully uninstalled botocore-1.12.175\n",
966 | " Found existing installation: rsa 4.0\n",
967 | " Uninstalling rsa-4.0:\n",
968 | " Successfully uninstalled rsa-4.0\n",
969 | "Successfully installed allennlp-0.8.4 awscli-1.16.190 botocore-1.12.180 colorama-0.3.9 conllu-0.11 flaky-3.6.0 flask-cors-3.0.8 ftfy-5.5.1 jsonnet-0.13.0 jsonpickle-1.2 numpydoc-0.9.1 overrides-1.9 parsimonious-0.8.1 pytorch-pretrained-bert-0.6.2 regex-2019.6.8 responses-0.10.6 rsa-3.4.2 tensorboardX-1.7 unidecode-1.1.1 word2number-1.1\n"
970 | ],
971 | "name": "stdout"
972 | },
973 | {
974 | "output_type": "display_data",
975 | "data": {
976 | "application/vnd.colab-display-data+json": {
977 | "pip_warning": {
978 | "packages": [
979 | "rsa"
980 | ]
981 | }
982 | }
983 | },
984 | "metadata": {
985 | "tags": []
986 | }
987 | }
988 | ]
989 | },
990 | {
991 | "cell_type": "markdown",
992 | "metadata": {
993 | "id": "usE8HacRrOfK",
994 | "colab_type": "text"
995 | },
996 | "source": [
997 | "### Approach\n",
998 | "\n",
999 | "We will use mostly the same code as in our previous classification experiments. For brevity, I have compacted it all here and omitted parts that aren't required for this section. For more information, see the previous section."
1000 | ]
1001 | },
1002 | {
1003 | "cell_type": "code",
1004 | "metadata": {
1005 | "id": "AN93oU4plXgA",
1006 | "colab_type": "code",
1007 | "colab": {}
1008 | },
1009 | "source": [
1010 | "import torch\n",
1011 | "import torch.nn as nn\n",
1012 | "import torch.nn.functional as F\n",
1013 | "from typing import List, Tuple\n",
1014 | "import os\n",
1015 | "import io\n",
1016 | "import re\n",
1017 | "import codecs, csv\n",
1018 | "import numpy as np\n",
1019 | "from collections import Counter, OrderedDict\n",
1020 | "from torch.utils.data import DataLoader, TensorDataset\n",
1021 | "\n",
1022 | "class LSTMClassifier(nn.Module):\n",
1023 | "\n",
1024 | " def __init__(self, embeddings, num_classes, embed_dims, rnn_units, rnn_layers=1, dropout=0.5, hidden_units=[]):\n",
1025 | " super().__init__()\n",
1026 | " self.embeddings = embeddings\n",
1027 | " self.dropout = nn.Dropout(dropout)\n",
1028 | " self.rnn = torch.nn.LSTM(embed_dims,\n",
1029 | " rnn_units,\n",
1030 | " rnn_layers,\n",
1031 | " dropout=dropout,\n",
1032 | " bidirectional=False,\n",
1033 | " batch_first=False)\n",
1034 | " nn.init.orthogonal_(self.rnn.weight_hh_l0)\n",
1035 | " nn.init.orthogonal_(self.rnn.weight_ih_l0)\n",
1036 | " sequence = []\n",
1037 | " input_units = rnn_units\n",
1038 | " output_units = rnn_units\n",
1039 | " for h in hidden_units:\n",
1040 | " sequence.append(nn.Linear(input_units, h))\n",
1041 | " input_units = h\n",
1042 | " output_units = h\n",
1043 | " \n",
1044 | " sequence.append(nn.Linear(output_units, num_classes))\n",
1045 | " self.outputs = nn.Sequential(*sequence)\n",
1046 | " \n",
1047 | " \n",
1048 | " def forward(self, inputs):\n",
1049 | " one_hots, lengths = inputs\n",
1050 | " embed = self.dropout(self.embeddings(one_hots))\n",
1051 | " embed = embed.transpose(0, 1)\n",
1052 | " packed = torch.nn.utils.rnn.pack_padded_sequence(embed, lengths.tolist())\n",
1053 | " _, hidden = self.rnn(packed)\n",
1054 | " hidden = hidden[0][-1]  # final hidden state of the top LSTM layer, (B, H); works for any rnn_layers\n",
1055 | " linear = self.outputs(hidden)\n",
1056 | " return F.log_softmax(linear, dim=-1)\n",
1057 | "\n",
1058 | "class ConfusionMatrix:\n",
1059 | " \"\"\"Confusion matrix with metrics\n",
1060 | "\n",
1061 | " This class accumulates classification output, and tracks it in a confusion matrix.\n",
1062 | " Metrics are available that use the confusion matrix\n",
1063 | " \"\"\"\n",
1064 | " def __init__(self, labels):\n",
1065 | " \"\"\"Constructor with input labels\n",
1066 | "\n",
1067 | " :param labels: Either a dictionary (`k=int,v=str`) or an array of labels\n",
1068 | " \"\"\"\n",
1069 | " if type(labels) is dict:\n",
1070 | " self.labels = []\n",
1071 | " for i in range(len(labels)):\n",
1072 | " self.labels.append(labels[i])\n",
1073 | " else:\n",
1074 | " self.labels = labels\n",
1075 | " nc = len(self.labels)\n",
1076 | " self._cm = np.zeros((nc, nc), dtype=np.int64)\n",
1077 | "\n",
1078 | " def add(self, truth, guess):\n",
1079 | " \"\"\"Add a single value to the confusion matrix based off `truth` and `guess`\n",
1080 | "\n",
1081 | " :param truth: The real `y` value (or ground truth label)\n",
1082 | " :param guess: The guess for `y` value (or assertion)\n",
1083 | " \"\"\"\n",
1084 | "\n",
1085 | " self._cm[truth, guess] += 1\n",
1086 | "\n",
1087 | " def __str__(self):\n",
1088 | " values = []\n",
1089 | " width = max(8, max(len(x) for x in self.labels) + 1)\n",
1090 | " for i, label in enumerate([''] + self.labels):\n",
1091 | " values += [\"{:>{width}}\".format(label, width=width+1)]\n",
1092 | " values += ['\\n']\n",
1093 | " for i, label in enumerate(self.labels):\n",
1094 | " values += [\"{:>{width}}\".format(label, width=width+1)]\n",
1095 | " for j in range(len(self.labels)):\n",
1096 | " values += [\"{:{width}d}\".format(self._cm[i, j], width=width + 1)]\n",
1097 | " values += ['\\n']\n",
1098 | " values += ['\\n']\n",
1099 | " return ''.join(values)\n",
1100 | "\n",
1101 | " def save(self, outfile):\n",
1102 | " ordered_fieldnames = OrderedDict([(\"labels\", None)] + [(l, None) for l in self.labels])\n",
1103 | " with open(outfile, 'w') as f:\n",
1104 | " dw = csv.DictWriter(f, delimiter=',', fieldnames=ordered_fieldnames)\n",
1105 | " dw.writeheader()\n",
1106 | " for index, row in enumerate(self._cm):\n",
1107 | " row_dict = {l: row[i] for i, l in enumerate(self.labels)}\n",
1108 | " row_dict.update({\"labels\": self.labels[index]})\n",
1109 | " dw.writerow(row_dict)\n",
1110 | "\n",
1111 | " def reset(self):\n",
1112 | " \"\"\"Reset the matrix\n",
1113 | " \"\"\"\n",
1114 | " self._cm *= 0\n",
1115 | "\n",
1116 | " def get_correct(self):\n",
1117 | " \"\"\"Get the diagonals of the confusion matrix\n",
1118 | "\n",
1119 | " :return: (``int``) Number of correct classifications\n",
1120 | " \"\"\"\n",
1121 | " return self._cm.diagonal().sum()\n",
1122 | "\n",
1123 | " def get_total(self):\n",
1124 | " \"\"\"Get total classifications\n",
1125 | "\n",
1126 | " :return: (``int``) total classifications\n",
1127 | " \"\"\"\n",
1128 | " return self._cm.sum()\n",
1129 | "\n",
1130 | " def get_acc(self):\n",
1131 | " \"\"\"Get the accuracy\n",
1132 | "\n",
1133 | " :return: (``float``) accuracy\n",
1134 | " \"\"\"\n",
1135 | " return float(self.get_correct())/self.get_total()\n",
1136 | "\n",
1137 | " def get_recall(self):\n",
1138 | " \"\"\"Get the recall\n",
1139 | "\n",
1140 | " :return: (``float``) recall\n",
1141 | " \"\"\"\n",
1142 | " total = np.sum(self._cm, axis=1)\n",
1143 | " total = (total == 0) + total\n",
1144 | " return np.diag(self._cm) / total.astype(float)\n",
1145 | "\n",
1146 | " def get_support(self):\n",
1147 | " return np.sum(self._cm, axis=1)\n",
1148 | "\n",
1149 | " def get_precision(self):\n",
1150 | " \"\"\"Get the precision\n",
1151 | " :return: (``float``) precision\n",
1152 | " \"\"\"\n",
1153 | "\n",
1154 | " total = np.sum(self._cm, axis=0)\n",
1155 | " total = (total == 0) + total\n",
1156 | " return np.diag(self._cm) / total.astype(float)\n",
1157 | "\n",
1158 | " def get_mean_precision(self):\n",
1159 | " \"\"\"Get the mean precision across labels\n",
1160 | "\n",
1161 | " :return: (``float``) mean precision\n",
1162 | " \"\"\"\n",
1163 | " return np.mean(self.get_precision())\n",
1164 | "\n",
1165 | " def get_weighted_precision(self):\n",
1166 | " return np.sum(self.get_precision() * self.get_support())/float(self.get_total())\n",
1167 | "\n",
1168 | " def get_mean_recall(self):\n",
1169 | " \"\"\"Get the mean recall across labels\n",
1170 | "\n",
1171 | " :return: (``float``) mean recall\n",
1172 | " \"\"\"\n",
1173 | " return np.mean(self.get_recall())\n",
1174 | "\n",
1175 | " def get_weighted_recall(self):\n",
1176 | " return np.sum(self.get_recall() * self.get_support())/float(self.get_total())\n",
1177 | "\n",
1178 | " def get_weighted_f(self, beta=1):\n",
1179 | " return np.sum(self.get_class_f(beta) * self.get_support())/float(self.get_total())\n",
1180 | "\n",
1181 | " def get_macro_f(self, beta=1):\n",
1182 | " \"\"\"Get the macro F_b, with adjustable beta (defaulting to F1)\n",
1183 | "\n",
1184 | " :param beta: (``float``) defaults to 1 (F1)\n",
1185 | " :return: (``float``) macro F_b\n",
1186 | " \"\"\"\n",
1187 | " if beta < 0:\n",
1188 | " raise Exception('Beta must be greater than 0')\n",
1189 | " return np.mean(self.get_class_f(beta))\n",
1190 | "\n",
1191 | " def get_class_f(self, beta=1):\n",
1192 | " p = self.get_precision()\n",
1193 | " r = self.get_recall()\n",
1194 | "\n",
1195 | " b = beta*beta\n",
1196 | " d = (b * p + r)\n",
1197 | " d = (d == 0) + d\n",
1198 | "\n",
1199 | " return (b + 1) * p * r / d\n",
1200 | "\n",
1201 | " def get_f(self, beta=1):\n",
1202 | " \"\"\"Get 2 class F_b, with adjustable beta (defaulting to F1)\n",
1203 | "\n",
1204 | " :param beta: (``float``) defaults to 1 (F1)\n",
1205 | " :return: (``float``) 2-class F_b\n",
1206 | " \"\"\"\n",
1207 | " p = self.get_precision()[1]\n",
1208 | " r = self.get_recall()[1]\n",
1209 | " if beta < 0:\n",
1210 | " raise Exception('Beta must be greater than 0')\n",
1211 | " d = (beta*beta * p + r)\n",
1212 | " if d == 0:\n",
1213 | " return 0\n",
1214 | " return (beta*beta + 1) * p * r / d\n",
1215 | "\n",
1216 | " def get_all_metrics(self):\n",
1217 | " \"\"\"Make a map of metrics suitable for reporting, keyed by metric name\n",
1218 | "\n",
1219 | " :return: (``dict``) Map of metrics keyed by metric names\n",
1220 | " \"\"\"\n",
1221 | " metrics = {'acc': self.get_acc()}\n",
1222 | " # If 2 class, assume second class is positive AKA 1\n",
1223 | " if len(self.labels) == 2:\n",
1224 | " metrics['precision'] = self.get_precision()[1]\n",
1225 | " metrics['recall'] = self.get_recall()[1]\n",
1226 | " metrics['f1'] = self.get_f(1)\n",
1227 | " else:\n",
1228 | " metrics['mean_precision'] = self.get_mean_precision()\n",
1229 | " metrics['mean_recall'] = self.get_mean_recall()\n",
1230 | " metrics['macro_f1'] = self.get_macro_f(1)\n",
1231 | " metrics['weighted_precision'] = self.get_weighted_precision()\n",
1232 | " metrics['weighted_recall'] = self.get_weighted_recall()\n",
1233 | " metrics['weighted_f1'] = self.get_weighted_f(1)\n",
1234 | " return metrics\n",
1235 | "\n",
1236 | " def add_batch(self, truth, guess):\n",
1237 | " \"\"\"Add a batch of data to the confusion matrix\n",
1238 | "\n",
1239 | "        :param truth: The truth tensor\n",
1240 | "        :param guess: The guess tensor\n",
1241 | "        :return:\n",
1242 | "        \"\"\"\n",
1243 | "        for truth_i, guess_i in zip(truth, guess):\n",
1244 | "            self.add(truth_i, guess_i)\n",
1245 | "\n",
1246 | "class Trainer:\n",
1247 | "    def __init__(self, optimizer: torch.optim.Optimizer):\n",
1248 | "        self.optimizer = optimizer\n",
1249 | "\n",
1250 | "    def run(self, model, labels, train, loss, batch_size):\n",
1251 | "        model.train()\n",
1252 | "        train_loader = DataLoader(train, batch_size=batch_size, shuffle=True)\n",
1253 | "\n",
1254 | "        cm = ConfusionMatrix(labels)\n",
1255 | "\n",
1256 | "        for batch in train_loader:\n",
1257 | "            loss_value, y_pred, y_actual = self.update(model, loss, batch)\n",
1258 | "            _, best = y_pred.max(1)\n",
1259 | "            yt = y_actual.cpu().int().numpy()\n",
1260 | "            yp = best.cpu().int().numpy()\n",
1261 | "            cm.add_batch(yt, yp)\n",
1262 | "\n",
1263 | "        print(cm.get_all_metrics())\n",
1264 | "        return cm\n",
1265 | "\n",
1266 | "    def update(self, model, loss, batch):\n",
1267 | "        self.optimizer.zero_grad()\n",
1268 | "        x, lengths, y = batch\n",
1269 | "        lengths, perm_idx = lengths.sort(0, descending=True)  # longest sequence first\n",
1270 | "        x_sorted = x[perm_idx]\n",
1271 | "        y_sorted = y[perm_idx]  # keep labels aligned with the reordered inputs\n",
1272 | "        y_sorted = y_sorted.to('cuda:0')\n",
1273 | "        inputs = (x_sorted.to('cuda:0'), lengths)  # lengths stay on the CPU\n",
1274 | "        y_pred = model(inputs)\n",
1275 | "        loss_value = loss(y_pred, y_sorted)\n",
1276 | "        loss_value.backward()\n",
1277 | "        self.optimizer.step()\n",
1278 | "        return loss_value.item(), y_pred, y_sorted\n",
1279 | "\n",
1280 | "class Evaluator:\n",
1281 | "    def __init__(self):\n",
1282 | "        pass\n",
1283 | "\n",
1284 | "    def run(self, model, labels, dataset, batch_size=1):\n",
1285 | "        model.eval()\n",
1286 | "        valid_loader = DataLoader(dataset, batch_size=batch_size)\n",
1287 | "        cm = ConfusionMatrix(labels)\n",
1288 | "        for batch in valid_loader:\n",
1289 | "            y_pred, y_actual = self.inference(model, batch)\n",
1290 | "            _, best = y_pred.max(1)\n",
1291 | "            yt = y_actual.cpu().int().numpy()\n",
1292 | "            yp = best.cpu().int().numpy()\n",
1293 | "            cm.add_batch(yt, yp)\n",
1294 | "        return cm\n",
1295 | "\n",
1296 | "    def inference(self, model, batch):\n",
1297 | "        with torch.no_grad():\n",
1298 | "            x, lengths, y = batch\n",
1299 | "            lengths, perm_idx = lengths.sort(0, descending=True)\n",
1300 | "            x_sorted = x[perm_idx]\n",
1301 | "            y_sorted = y[perm_idx]\n",
1302 | "            y_sorted = y_sorted.to('cuda:0')\n",
1303 | "            inputs = (x_sorted.to('cuda:0'), lengths)\n",
1304 | "            y_pred = model(inputs)\n",
1305 | "        return y_pred, y_sorted\n",
1306 | "\n",
1307 | "def fit(model, labels, optimizer, loss, epochs, batch_size, train, valid, test):\n",
1308 | "\n",
1309 | "    trainer = Trainer(optimizer)\n",
1310 | "    evaluator = Evaluator()\n",
1311 | "    best_acc = 0.0\n",
1312 | "\n",
1313 | "    for epoch in range(epochs):\n",
1314 | "        print('EPOCH {}'.format(epoch + 1))\n",
1315 | "        print('=================================')\n",
1316 | "        print('Training Results')\n",
1317 | "        cm = trainer.run(model, labels, train, loss, batch_size)\n",
1318 | "        print('Validation Results')\n",
1319 | "        cm = evaluator.run(model, labels, valid)\n",
1320 | "        print(cm.get_all_metrics())\n",
1321 | "        if cm.get_acc() > best_acc:\n",
1322 | "            print('New best model {:.2f}'.format(cm.get_acc()))\n",
1323 | "            best_acc = cm.get_acc()\n",
1324 | "            torch.save(model.state_dict(), './checkpoint.pth')\n",
1325 | "    if test:\n",
1326 | "        model.load_state_dict(torch.load('./checkpoint.pth'))\n",
1327 | "        cm = evaluator.run(model, labels, test)\n",
1328 | "        print('Final result')\n",
1329 | "        print(cm.get_all_metrics())\n",
1330 | "    return cm.get_acc()\n",
1331 | "\n",
1332 | "def whitespace_tokenizer(words: str) -> List[str]:\n",
1333 | "    return words.split()\n",
1334 | "\n",
1335 | "def sst2_tokenizer(words: str) -> List[str]:\n",
1336 | "    REPLACE = {\"'s\": \" 's \",\n",
1337 | "               \"'ve\": \" 've \",\n",
1338 | "               \"n't\": \" n't \",\n",
1339 | "               \"'re\": \" 're \",\n",
1340 | "               \"'d\": \" 'd \",\n",
1341 | "               \"'ll\": \" 'll \",\n",
1342 | "               \",\": \" , \",\n",
1343 | "               \"!\": \" ! \",\n",
1344 | "               }\n",
1345 | "    words = words.lower()  # note: lowercases even when Reader(lowercase=False)\n",
1346 | "    words = re.sub(r\"[^A-Za-z0-9(),!?\\'\\`]\", \" \", words)\n",
1347 | "    for k, v in REPLACE.items():\n",
1348 | "        words = words.replace(k, v)\n",
1349 | "    return [w.strip() for w in words.split()]\n",
1350 | "\n",
1351 | "\n",
1352 | "class Reader:\n",
1353 | "\n",
1354 | "    def __init__(self, files, lowercase=True, min_freq=0,\n",
1355 | "                 tokenizer=sst2_tokenizer, vectorizer=None):\n",
1356 | "        self.lowercase = lowercase\n",
1357 | "        self.tokenizer = tokenizer\n",
1358 | "        build_vocab = vectorizer is None\n",
1359 | "        self.vectorizer = vectorizer if vectorizer else self._vectorizer\n",
1360 | "        x = Counter()\n",
1361 | "        y = Counter()\n",
1362 | "        for file_name in files:\n",
1363 | "            if file_name is None:\n",
1364 | "                continue\n",
1365 | "            with codecs.open(file_name, encoding='utf-8', mode='r') as f:\n",
1366 | "                for line in f:\n",
1367 | "                    words = line.split()\n",
1368 | "                    y[words[0]] += 1  # count the whole label, not its characters\n",
1369 | "\n",
1370 | "                    if build_vocab:\n",
1371 | "                        words = self.tokenizer(' '.join(words[1:]))\n",
1372 | "                        words = words if not self.lowercase else [w.lower() for w in words]\n",
1373 | "                        x.update(words)\n",
1374 | "        self.labels = list(y.keys())\n",
1375 | "\n",
1376 | "        if build_vocab:\n",
1377 | "            x = dict(filter(lambda cnt: cnt[1] >= min_freq, x.items()))\n",
1378 | "            alpha = list(x.keys())\n",
1379 | "            alpha.sort()\n",
1380 | "            self.vocab = {w: i+1 for i, w in enumerate(alpha)}\n",
1381 | "            self.vocab['[PAD]'] = 0\n",
1382 | "\n",
1383 | "        self.labels.sort()\n",
1384 | "\n",
1385 | "    def _vectorizer(self, words: List[str]) -> List[int]:\n",
1386 | "        return [self.vocab.get(w, 0) for w in words]  # unknown words map to 0, same as [PAD]\n",
1387 | "\n",
1388 | "    def load(self, filename: str) -> TensorDataset:\n",
1389 | "        label2index = {l: i for i, l in enumerate(self.labels)}\n",
1390 | "        xs = []\n",
1391 | "        lengths = []\n",
1392 | "        ys = []\n",
1393 | "        with codecs.open(filename, encoding='utf-8', mode='r') as f:\n",
1394 | "            for line in f:\n",
1395 | "                words = line.split()\n",
1396 | "                ys.append(label2index[words[0]])\n",
1397 | "                words = self.tokenizer(' '.join(words[1:]))\n",
1398 | "                words = words if not self.lowercase else [w.lower() for w in words]\n",
1399 | "                vec = self.vectorizer(words)\n",
1400 | "                lengths.append(len(vec))\n",
1401 | "                xs.append(torch.as_tensor(vec, dtype=torch.long))  # accepts lists or tensors (e.g. ELMo char IDs)\n",
1402 | "        x_tensor = torch.nn.utils.rnn.pad_sequence(xs, batch_first=True)\n",
1403 | "        lengths_tensor = torch.tensor(lengths, dtype=torch.long)\n",
1404 | "        y_tensor = torch.tensor(ys, dtype=torch.long)\n",
1405 | "        return TensorDataset(x_tensor, lengths_tensor, y_tensor)"
1406 | ],
1407 | "execution_count": 0,
1408 | "outputs": []
1409 | },
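{
"cell_type": "markdown",
"metadata": {
"colab_type": "text"
},
"source": [
"Both `Trainer.update` and `Evaluator.inference` sort each batch by length, longest first, and apply the same permutation to the inputs and the labels so that predictions still line up with their targets. A likely reason for the sort is that PyTorch's `pack_padded_sequence` expects descending lengths by default (`enforce_sorted=True`). We are assuming here that `LSTMClassifier` packs its input. The toy batch below is made up for illustration. It shows the permutation and the resulting packed batch sizes, meaning how many sequences are still active at each time step."
]
},
{
"cell_type": "code",
"metadata": {
"colab_type": "code",
"colab": {}
},
"source": [
"import torch\n",
"from torch.nn.utils.rnn import pack_padded_sequence\n",
"\n",
"# Toy padded batch (0 = [PAD]) with lengths 2, 4 and 3\n",
"x = torch.tensor([[5, 6, 0, 0],\n",
"                  [1, 2, 3, 4],\n",
"                  [7, 8, 9, 0]])\n",
"lengths = torch.tensor([2, 4, 3])\n",
"y = torch.tensor([0, 1, 2])\n",
"\n",
"# Same reordering as Trainer.update: sort lengths, then permute x and y identically\n",
"lengths_sorted, perm_idx = lengths.sort(0, descending=True)\n",
"x_sorted = x[perm_idx]\n",
"y_sorted = y[perm_idx]\n",
"\n",
"# With the default enforce_sorted=True, packing needs lengths in descending order\n",
"embed = torch.nn.Embedding(10, 3, padding_idx=0)\n",
"packed = pack_padded_sequence(embed(x_sorted), lengths_sorted, batch_first=True)\n",
"\n",
"print(perm_idx.tolist())            # [1, 2, 0]\n",
"print(y_sorted.tolist())            # [1, 2, 0]\n",
"print(packed.batch_sizes.tolist())  # [3, 3, 2, 1]"
],
"execution_count": 0,
"outputs": []
},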
1410 | {
1411 | "cell_type": "markdown",
1412 | "metadata": {
1413 | "id": "fJpvG1NTxSo9",
1414 | "colab_type": "text"
1415 | },
1416 | "source": [
1417 | "### The new thing: set up to use ELMo"
1418 | ]
1419 | },
1420 | {
1421 | "cell_type": "code",
1422 | "metadata": {
1423 | "id": "HH2fR03DxeNw",
1424 | "colab_type": "code",
1425 | "colab": {}
1426 | },
1427 | "source": [
1428 | "from allennlp.modules.elmo import Elmo, batch_to_ids\n",
1429 | "\n",
1430 | "\n",
1431 | "def elmo_vectorizer(sentence):\n",
1432 | "    character_ids = batch_to_ids([sentence])  # [1, num_tokens, 50]\n",
1433 | "    return character_ids.squeeze(0)  # [num_tokens, 50]\n",
1434 | "\n",
1435 | "\n",
1436 | "class ElmoEmbedding(nn.Module):\n",
1437 | "    def __init__(self, options_file, weight_file, dropout=0.5):\n",
1438 | "        super().__init__()\n",
1439 | "        self.elmo = Elmo(options_file, weight_file, 2, dropout=dropout)  # 2 output representations\n",
1440 | "    def forward(self, xch):\n",
1441 | "        elmo = self.elmo(xch)\n",
1442 | "        e1, e2 = elmo['elmo_representations']\n",
1443 | "        mask = elmo['mask']\n",
1444 | "        embeddings = (e1 + e2) * mask.float().unsqueeze(-1)  # zero out padded positions\n",
1445 | "        return embeddings\n"
1446 | ],
1447 | "execution_count": 0,
1448 | "outputs": []
1449 | },
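{
"cell_type": "markdown",
"metadata": {
"colab_type": "text"
},
"source": [
"`ElmoEmbedding.forward` adds the two ELMo representations returned by `Elmo` and multiplies the sum by the mask, so padded token positions come out as all-zero vectors. The sketch below reproduces only that arithmetic with random stand-in tensors, so it runs without allennlp or the pretrained weights. The shapes and the mask are made up; 1024 matches the `embed_dims` used for the ELMo model later in this notebook."
]
},
{
"cell_type": "code",
"metadata": {
"colab_type": "code",
"colab": {}
},
"source": [
"import torch\n",
"\n",
"torch.manual_seed(0)\n",
"batch, time, dim = 2, 4, 1024\n",
"\n",
"# Random stand-ins for elmo['elmo_representations'] (two tensors) and elmo['mask']\n",
"e1 = torch.randn(batch, time, dim)\n",
"e2 = torch.randn(batch, time, dim)\n",
"mask = torch.tensor([[1, 1, 1, 0],   # sentence 1 has 3 tokens\n",
"                     [1, 1, 0, 0]])  # sentence 2 has 2 tokens\n",
"\n",
"# Same computation as ElmoEmbedding.forward: sum, then zero out padded positions\n",
"embeddings = (e1 + e2) * mask.float().unsqueeze(-1)\n",
"\n",
"print(embeddings.shape)                     # torch.Size([2, 4, 1024])\n",
"print(embeddings[0, 3].abs().sum().item())  # 0.0 (padding)\n",
"print(embeddings[1, 2:].abs().sum().item()) # 0.0 (padding)"
],
"execution_count": 0,
"outputs": []
},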
1450 | {
1451 | "cell_type": "markdown",
1452 | "metadata": {
1453 | "id": "Xuxm3Ugau6jW",
1454 | "colab_type": "text"
1455 | },
1456 | "source": [
1457 | "As before, we load our data with a `Reader`. This time, though, we pass it a vectorizer for ELMo. Our simple example `Reader` supports only a single input feature for the classifier, so when a vectorizer is supplied it skips building a word vocabulary. In practice, you will often want to feed both word-vector features and contextual (ELMo) features, which would require extending the code. That combination is very common: ELMo is frequently used to augment an existing word-embedding setup rather than replace it. Here, we look at ELMo features by themselves.\n"
1458 | ]
1459 | },
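{
"cell_type": "markdown",
"metadata": {
"colab_type": "text"
},
"source": [
"The text above mentions combining word-vector features with ELMo features. One common way to do this, sketched here under our own assumptions, is to concatenate the per-token vectors along the feature axis before the LSTM. `ConcatEmbeddings` is a hypothetical helper, not part of this notebook or of Baseline, and the second embedding below is only a stand-in for `ElmoEmbedding`, which would consume character IDs rather than word IDs. In a real setup the `Reader` would also have to produce both inputs, and `embed_dims` would become the sum of the feature sizes."
]
},
{
"cell_type": "code",
"metadata": {
"colab_type": "code",
"colab": {}
},
"source": [
"import torch\n",
"import torch.nn as nn\n",
"\n",
"\n",
"class ConcatEmbeddings(nn.Module):\n",
"    # Hypothetical helper: concatenate several per-token embeddings along the feature axis\n",
"    def __init__(self, *embedders):\n",
"        super().__init__()\n",
"        self.embedders = nn.ModuleList(embedders)\n",
"\n",
"    def forward(self, inputs):\n",
"        # inputs holds one tensor per embedder; every output must share [batch, time]\n",
"        return torch.cat([emb(x) for emb, x in zip(self.embedders, inputs)], dim=-1)\n",
"\n",
"\n",
"word_embed = nn.Embedding(100, 300, padding_idx=0)  # word-vector features\n",
"context_embed = nn.Embedding(100, 24)  # stand-in for ElmoEmbedding (toy size)\n",
"embed = ConcatEmbeddings(word_embed, context_embed)\n",
"\n",
"ids = torch.randint(0, 100, (2, 4))  # toy [batch, time] IDs\n",
"out = embed((ids, ids))\n",
"print(out.shape)  # torch.Size([2, 4, 324])"
],
"execution_count": 0,
"outputs": []
},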
1460 | {
1461 | "cell_type": "code",
1462 | "metadata": {
1463 | "id": "-NzW34_RwGUw",
1464 | "colab_type": "code",
1465 | "outputId": "e667b3f4-a3b3-417f-8b1c-61180acd83f1",
1466 | "colab": {
1467 | "base_uri": "https://localhost:8080/",
1468 | "height": 377
1469 | }
1470 | },
1471 | "source": [
1472 | "!wget https://www.dropbox.com/s/08km2ean8bkt7p3/trec.tar.gz?dl=1\n",
1473 | "!tar -xzf 'trec.tar.gz?dl=1'"
1474 | ],
1475 | "execution_count": 14,
1476 | "outputs": [
1477 | {
1478 | "output_type": "stream",
1479 | "text": [
1480 | "--2019-06-30 19:21:55-- https://www.dropbox.com/s/08km2ean8bkt7p3/trec.tar.gz?dl=1\n",
1481 | "Resolving www.dropbox.com (www.dropbox.com)... 162.125.8.1, 2620:100:601b:1::a27d:801\n",
1482 | "Connecting to www.dropbox.com (www.dropbox.com)|162.125.8.1|:443... connected.\n",
1483 | "HTTP request sent, awaiting response... 301 Moved Permanently\n",
1484 | "Location: /s/dl/08km2ean8bkt7p3/trec.tar.gz [following]\n",
1485 | "--2019-06-30 19:21:56-- https://www.dropbox.com/s/dl/08km2ean8bkt7p3/trec.tar.gz\n",
1486 | "Reusing existing connection to www.dropbox.com:443.\n",
1487 | "HTTP request sent, awaiting response... 302 Found\n",
1488 | "Location: https://uc7fa2ae1930db92d5916f06ba12.dl.dropboxusercontent.com/cd/0/get/Aj3XoF2sz7098a7ulJjBQP5DA6LkkkTQEAgFciDKPLgTZrHSUdejKQ7f8hkI3LiEt0BP_zf3LYg-ul8IZkevEcRCL4oxvYa8Uw-4SCn9GK2Lqw/file?dl=1# [following]\n",
1489 | "--2019-06-30 19:21:56-- https://uc7fa2ae1930db92d5916f06ba12.dl.dropboxusercontent.com/cd/0/get/Aj3XoF2sz7098a7ulJjBQP5DA6LkkkTQEAgFciDKPLgTZrHSUdejKQ7f8hkI3LiEt0BP_zf3LYg-ul8IZkevEcRCL4oxvYa8Uw-4SCn9GK2Lqw/file?dl=1\n",
1490 | "Resolving uc7fa2ae1930db92d5916f06ba12.dl.dropboxusercontent.com (uc7fa2ae1930db92d5916f06ba12.dl.dropboxusercontent.com)... 162.125.8.6, 2620:100:601b:6::a27d:806\n",
1491 | "Connecting to uc7fa2ae1930db92d5916f06ba12.dl.dropboxusercontent.com (uc7fa2ae1930db92d5916f06ba12.dl.dropboxusercontent.com)|162.125.8.6|:443... connected.\n",
1492 | "HTTP request sent, awaiting response... 200 OK\n",
1493 | "Length: 117253 (115K) [application/binary]\n",
1494 | "Saving to: ‘trec.tar.gz?dl=1’\n",
1495 | "\n",
1496 | "trec.tar.gz?dl=1 100%[===================>] 114.50K --.-KB/s in 0.07s \n",
1497 | "\n",
1498 | "2019-06-30 19:21:56 (1.71 MB/s) - ‘trec.tar.gz?dl=1’ saved [117253/117253]\n",
1499 | "\n"
1500 | ],
1501 | "name": "stdout"
1502 | }
1503 | ]
1504 | },
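{
"cell_type": "markdown",
"metadata": {
"colab_type": "text"
},
"source": [
"For contrast, when no vectorizer is passed (as in the randomly initialized baseline later in this notebook), `Reader` builds its own word vocabulary: it counts tokens across the files, drops words seen fewer than `min_freq` times, sorts the rest alphabetically, numbers them from 1, and reserves 0 for `[PAD]`. Unknown words also map to 0. The function below is a standalone sketch of that logic; `build_vocab` and the sample sentences are hypothetical."
]
},
{
"cell_type": "code",
"metadata": {
"colab_type": "code",
"colab": {}
},
"source": [
"from collections import Counter\n",
"\n",
"\n",
"def build_vocab(tokenized_sentences, min_freq=0):\n",
"    # Hypothetical standalone version of the vocabulary logic in Reader.__init__\n",
"    counts = Counter()\n",
"    for words in tokenized_sentences:\n",
"        counts.update(words)  # a list of tokens, so whole tokens are counted\n",
"    kept = sorted(w for w, c in counts.items() if c >= min_freq)\n",
"    vocab = {w: i + 1 for i, w in enumerate(kept)}  # real words start at 1\n",
"    vocab['[PAD]'] = 0  # 0 is reserved for padding\n",
"    return vocab\n",
"\n",
"\n",
"vocab = build_vocab([['how', 'many', 'feet'], ['how', 'far']], min_freq=1)\n",
"print(vocab)  # {'far': 1, 'feet': 2, 'how': 3, 'many': 4, '[PAD]': 0}\n",
"\n",
"# Like Reader._vectorizer: unknown words fall back to index 0\n",
"print([vocab.get(w, 0) for w in ['how', 'tall']])  # [3, 0]"
],
"execution_count": 0,
"outputs": []
},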
1505 | {
1506 | "cell_type": "markdown",
1507 | "metadata": {
1508 | "id": "hPy-A82s048z",
1509 | "colab_type": "text"
1510 | },
1511 | "source": [
1512 | "We set up our reader slightly differently than in the last experiment: we pass `elmo_vectorizer`, which converts each tokenized sentence into a tensor of ELMo character IDs instead of a list of word indices."
1513 | ]
1514 | },
1515 | {
1516 | "cell_type": "code",
1517 | "metadata": {
1518 | "id": "aI5r7tmlvgUu",
1519 | "colab_type": "code",
1520 | "outputId": "b4e390a9-7886-46d5-87a4-5c3b07c9602b",
1521 | "colab": {
1522 | "base_uri": "https://localhost:8080/",
1523 | "height": 54
1524 | }
1525 | },
1526 | "source": [
1527 | "BASE = 'trec'\n",
1528 | "TRAIN = os.path.join(BASE, 'trec.nodev.utf8')\n",
1529 | "VALID = os.path.join(BASE, 'trec.dev.utf8')\n",
1530 | "TEST = os.path.join(BASE, 'trec.test.utf8')\n",
1531 | "\n",
1534 | "reader = Reader((TRAIN, VALID, TEST,), lowercase=False, vectorizer=elmo_vectorizer)\n",
1535 | "train = reader.load(TRAIN)\n",
1536 | "valid = reader.load(VALID)\n",
1537 | "test = reader.load(TEST)"
1538 | ],
1539 | "execution_count": 15,
1540 | "outputs": [
1541 | {
1542 | "output_type": "stream",
1543 | "text": [
1544 | "/usr/local/lib/python3.6/dist-packages/ipykernel_launcher.py:392: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n"
1545 | ],
1546 | "name": "stderr"
1547 | }
1548 | ]
1549 | },
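{
"cell_type": "markdown",
"metadata": {
"colab_type": "text"
},
"source": [
"With `elmo_vectorizer`, each sentence becomes a `[num_tokens, 50]` tensor of character IDs rather than a list of word indices, and `Reader.load` pads those with `pad_sequence`, giving a `[batch, max_tokens, 50]` tensor plus the true lengths. The sketch below reproduces that step with random stand-in IDs; the values are arbitrary. Padding with 0 should be safe here since, as far as we know, allennlp's ELMo treats all-zero character rows as padding when it builds its mask."
]
},
{
"cell_type": "code",
"metadata": {
"colab_type": "code",
"colab": {}
},
"source": [
"import torch\n",
"from torch.nn.utils.rnn import pad_sequence\n",
"\n",
"torch.manual_seed(0)\n",
"# Stand-ins for elmo_vectorizer output: one [num_tokens, 50] tensor per sentence\n",
"sentences = [torch.randint(1, 200, (3, 50)), torch.randint(1, 200, (5, 50))]\n",
"\n",
"# As in Reader.load: record lengths, then pad along the token axis\n",
"xs = [torch.as_tensor(v, dtype=torch.long) for v in sentences]\n",
"lengths = torch.tensor([len(v) for v in xs], dtype=torch.long)\n",
"x = pad_sequence(xs, batch_first=True)  # pads with 0 by default\n",
"\n",
"print(x.shape)           # torch.Size([2, 5, 50])\n",
"print(lengths.tolist())  # [3, 5]"
],
"execution_count": 0,
"outputs": []
},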
1550 | {
1551 | "cell_type": "markdown",
1552 | "metadata": {
1553 | "id": "YmE1VJSpws30",
1554 | "colab_type": "text"
1555 | },
1556 | "source": [
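1557 | "Building the network is basically the same as before, but we use ELMo instead of word vectors. The cell below will take a few minutes to run. It downloads the pretrained ELMo weights, and every forward pass runs the full ELMo network, so this is a much larger network than before. Even so, the number of learnable parameters has not really changed: the pretrained ELMo weights stay frozen, and only the classifier and ELMo's small layer-weighting parameters are trained."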
1558 | ]
1559 | },
1560 | {
1561 | "cell_type": "code",
1562 | "metadata": {
1563 | "id": "_7y6t65zw_iV",
1564 | "colab_type": "code",
1565 | "outputId": "5e06aaca-85b9-4da8-ecd9-49fc81c553b8",
1566 | "colab": {
1567 | "base_uri": "https://localhost:8080/",
1568 | "height": 1000
1569 | }
1570 | },
1571 | "source": [
1572 | "options_file = \"https://allennlp.s3.amazonaws.com/models/elmo/2x4096_512_2048cnn_2xhighway/elmo_2x4096_512_2048cnn_2xhighway_options.json\"\n",
1573 | "weight_file = \"https://allennlp.s3.amazonaws.com/models/elmo/2x4096_512_2048cnn_2xhighway/elmo_2x4096_512_2048cnn_2xhighway_weights.hdf5\"\n",
1574 | "embeddings = ElmoEmbedding(options_file, weight_file)\n",
1575 | "model = LSTMClassifier(embeddings, len(reader.labels), embed_dims=1024, rnn_units=100, hidden_units=[100])  # this ELMo model outputs 1024-d vectors\n",
1576 | "\n",
1577 | "num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)\n",
1578 | "print(f\"Model has {num_params} parameters\") \n",
1579 | "\n",
1580 | "\n",
1581 | "model.to('cuda:0')\n",
1582 | "loss = torch.nn.NLLLoss()\n",
1583 | "loss = loss.to('cuda:0')\n",
1584 | "\n",
1585 | "learnable_params = [p for p in model.parameters() if p.requires_grad]\n",
1586 | "optimizer = torch.optim.Adadelta(learnable_params, lr=1.0)\n",
1587 | "\n",
1588 | "fit(model, reader.labels, optimizer, loss, epochs=12, batch_size=50, train=train, valid=valid, test=test)"
1589 | ],
1590 | "execution_count": 16,
1591 | "outputs": [
1592 | {
1593 | "output_type": "stream",
1594 | "text": [
1595 | "100%|██████████| 336/336 [00:00<00:00, 192499.13B/s]\n",
1596 | "100%|██████████| 374434792/374434792 [00:07<00:00, 47927932.74B/s]\n",
1597 | "/usr/local/lib/python3.6/dist-packages/torch/nn/modules/rnn.py:54: UserWarning: dropout option adds dropout after all but last recurrent layer, so non-zero dropout expects num_layers greater than 1, but got dropout=0.5 and num_layers=1\n",
1598 | " \"num_layers={}\".format(dropout, num_layers))\n"
1599 | ],
1600 | "name": "stderr"
1601 | },
1602 | {
1603 | "output_type": "stream",
1604 | "text": [
1605 | "Model has 461114 parameters\n",
1606 | "EPOCH 1\n",
1607 | "=================================\n",
1608 | "Training Results\n",
1609 | "{'acc': 0.5608, 'mean_precision': 0.6483439079531595, 'mean_recall': 0.48504498768062404, 'macro_f1': 0.49849106627634976, 'weighted_precision': 0.5726962123308302, 'weighted_recall': 0.5608, 'weighted_f1': 0.5554788741148295}\n",
1610 | "Validation Results\n",
1611 | "{'acc': 0.7942477876106194, 'mean_precision': 0.8454765420711672, 'mean_recall': 0.7033276693176708, 'macro_f1': 0.7144610587048233, 'weighted_precision': 0.8102895316760798, 'weighted_recall': 0.7942477876106194, 'weighted_f1': 0.7887777126414355}\n",
1612 | "New best model 0.79\n",
1613 | "EPOCH 2\n",
1614 | "=================================\n",
1615 | "Training Results\n",
1616 | "{'acc': 0.806, 'mean_precision': 0.799350329535837, 'mean_recall': 0.7675872074813431, 'macro_f1': 0.780728542640896, 'weighted_precision': 0.8062829252605372, 'weighted_recall': 0.806, 'weighted_f1': 0.8058397968891035}\n",
1617 | "Validation Results\n",
1618 | "{'acc': 0.8628318584070797, 'mean_precision': 0.8566120843164245, 'mean_recall': 0.7974693543452065, 'macro_f1': 0.8182667932069821, 'weighted_precision': 0.8675313347987196, 'weighted_recall': 0.8628318584070797, 'weighted_f1': 0.8625189847178025}\n",
1619 | "New best model 0.86\n",
1620 | "EPOCH 3\n",
1621 | "=================================\n",
1622 | "Training Results\n",
1623 | "{'acc': 0.8678, 'mean_precision': 0.8675015318855253, 'mean_recall': 0.8346532456259291, 'macro_f1': 0.8484927361816553, 'weighted_precision': 0.8682517001586247, 'weighted_recall': 0.8678, 'weighted_f1': 0.8677362764323896}\n",
1624 | "Validation Results\n",
1625 | "{'acc': 0.8451327433628318, 'mean_precision': 0.8284211573091326, 'mean_recall': 0.8093879960516328, 'macro_f1': 0.8110225138172149, 'weighted_precision': 0.8691115810447773, 'weighted_recall': 0.8451327433628318, 'weighted_f1': 0.8465397357783465}\n",
1626 | "EPOCH 4\n",
1627 | "=================================\n",
1628 | "Training Results\n",
1629 | "{'acc': 0.8872, 'mean_precision': 0.8764661421280517, 'mean_recall': 0.8546009991636673, 'macro_f1': 0.8643704500888516, 'weighted_precision': 0.887561002584866, 'weighted_recall': 0.8872, 'weighted_f1': 0.8872276932804481}\n",
1630 | "Validation Results\n",
1631 | "{'acc': 0.911504424778761, 'mean_precision': 0.8515226408802844, 'mean_recall': 0.8617077224611437, 'macro_f1': 0.8561887828467433, 'weighted_precision': 0.9122749445064818, 'weighted_recall': 0.911504424778761, 'weighted_f1': 0.9118158975632823}\n",
1632 | "New best model 0.91\n",
1633 | "EPOCH 5\n",
1634 | "=================================\n",
1635 | "Training Results\n",
1636 | "{'acc': 0.9034, 'mean_precision': 0.9068352283292169, 'mean_recall': 0.8843802250756597, 'macro_f1': 0.8946296708241798, 'weighted_precision': 0.9040643245149811, 'weighted_recall': 0.9034, 'weighted_f1': 0.9035884797279896}\n",
1637 | "Validation Results\n",
1638 | "{'acc': 0.8871681415929203, 'mean_precision': 0.8310659320074388, 'mean_recall': 0.841863153832931, 'macro_f1': 0.8355145420604436, 'weighted_precision': 0.8885588116644558, 'weighted_recall': 0.8871681415929203, 'weighted_f1': 0.8871217953267708}\n",
1639 | "EPOCH 6\n",
1640 | "=================================\n",
1641 | "Training Results\n",
1642 | "{'acc': 0.9136, 'mean_precision': 0.9192746333288291, 'mean_recall': 0.8914669258673943, 'macro_f1': 0.903828395837297, 'weighted_precision': 0.9139512391702285, 'weighted_recall': 0.9136, 'weighted_f1': 0.913614469191629}\n",
1643 | "Validation Results\n",
1644 | "{'acc': 0.9048672566371682, 'mean_precision': 0.8453363940567148, 'mean_recall': 0.8564313119872883, 'macro_f1': 0.8503405229048734, 'weighted_precision': 0.905338925288292, 'weighted_recall': 0.9048672566371682, 'weighted_f1': 0.9048873521303485}\n",
1645 | "EPOCH 7\n",
1646 | "=================================\n",
1647 | "Training Results\n",
1648 | "{'acc': 0.9184, 'mean_precision': 0.9217954417236368, 'mean_recall': 0.9035341951741954, 'macro_f1': 0.9119837710405331, 'weighted_precision': 0.9188094085046438, 'weighted_recall': 0.9184, 'weighted_f1': 0.9185065760698944}\n",
1649 | "Validation Results\n",
1650 | "{'acc': 0.9004424778761062, 'mean_precision': 0.8343728710441182, 'mean_recall': 0.8542340405568197, 'macro_f1': 0.8426629413676556, 'weighted_precision': 0.9019737446759757, 'weighted_recall': 0.9004424778761062, 'weighted_f1': 0.9006848775343249}\n",
1651 | "EPOCH 8\n",
1652 | "=================================\n",
1653 | "Training Results\n",
1654 | "{'acc': 0.9252, 'mean_precision': 0.9227662229391251, 'mean_recall': 0.9085845822017588, 'macro_f1': 0.9152560555320276, 'weighted_precision': 0.9254505663098069, 'weighted_recall': 0.9252, 'weighted_f1': 0.9252609572329403}\n",
1655 | "Validation Results\n",
1656 | "{'acc': 0.8960176991150443, 'mean_precision': 0.8848359324236518, 'mean_recall': 0.8515305594157283, 'macro_f1': 0.8641410893717477, 'weighted_precision': 0.897474904298379, 'weighted_recall': 0.8960176991150443, 'weighted_f1': 0.8954448264468791}\n",
1657 | "EPOCH 9\n",
1658 | "=================================\n",
1659 | "Training Results\n",
1660 | "{'acc': 0.9366, 'mean_precision': 0.9421415595699045, 'mean_recall': 0.9253828413493465, 'macro_f1': 0.9332020129586184, 'weighted_precision': 0.9367741614764589, 'weighted_recall': 0.9366, 'weighted_f1': 0.9366203849323997}\n",
1661 | "Validation Results\n",
1662 | "{'acc': 0.9004424778761062, 'mean_precision': 0.8408851907016573, 'mean_recall': 0.8542708251432938, 'macro_f1': 0.8466559111080202, 'weighted_precision': 0.9022774132643538, 'weighted_recall': 0.9004424778761062, 'weighted_f1': 0.9006261595204735}\n",
1663 | "EPOCH 10\n",
1664 | "=================================\n",
1665 | "Training Results\n",
1666 | "{'acc': 0.9422, 'mean_precision': 0.9415872377873563, 'mean_recall': 0.9301100255239593, 'macro_f1': 0.9356066415360083, 'weighted_precision': 0.9423787344008276, 'weighted_recall': 0.9422, 'weighted_f1': 0.9422531801175381}\n",
1667 | "Validation Results\n",
1668 | "{'acc': 0.9026548672566371, 'mean_precision': 0.8534388800712419, 'mean_recall': 0.855449985872144, 'macro_f1': 0.8538969412521858, 'weighted_precision': 0.9037659180936365, 'weighted_recall': 0.9026548672566371, 'weighted_f1': 0.90246529999771}\n",
1669 | "EPOCH 11\n",
1670 | "=================================\n",
1671 | "Training Results\n",
1672 | "{'acc': 0.9432, 'mean_precision': 0.9422754608090832, 'mean_recall': 0.938376139581592, 'macro_f1': 0.9402970803722553, 'weighted_precision': 0.9432734858574917, 'weighted_recall': 0.9432, 'weighted_f1': 0.943229377825017}\n",
1673 | "Validation Results\n",
1674 | "{'acc': 0.9137168141592921, 'mean_precision': 0.8628400105220431, 'mean_recall': 0.8646482805732667, 'macro_f1': 0.8633776502808389, 'weighted_precision': 0.9132203845237589, 'weighted_recall': 0.9137168141592921, 'weighted_f1': 0.9130050592497775}\n",
1675 | "New best model 0.91\n",
1676 | "EPOCH 12\n",
1677 | "=================================\n",
1678 | "Training Results\n",
1679 | "{'acc': 0.9544, 'mean_precision': 0.9557163129978826, 'mean_recall': 0.9458500359607124, 'macro_f1': 0.9506063779628039, 'weighted_precision': 0.9545506681185594, 'weighted_recall': 0.9544, 'weighted_f1': 0.9544423559597639}\n",
1680 | "Validation Results\n",
1681 | "{'acc': 0.9092920353982301, 'mean_precision': 0.8510768742634296, 'mean_recall': 0.8608961905116529, 'macro_f1': 0.8550990513587272, 'weighted_precision': 0.9106944939486582, 'weighted_recall': 0.9092920353982301, 'weighted_f1': 0.9093039549088799}\n",
1682 | "Final result\n",
1683 | "{'acc': 0.944, 'mean_precision': 0.9333687372820768, 'mean_recall': 0.9161547629123813, 'macro_f1': 0.9230157805001022, 'weighted_precision': 0.9449538854974426, 'weighted_recall': 0.944, 'weighted_f1': 0.9429751846143404}\n"
1684 | ],
1685 | "name": "stdout"
1686 | },
1687 | {
1688 | "output_type": "execute_result",
1689 | "data": {
1690 | "text/plain": [
1691 | "0.944"
1692 | ]
1693 | },
1694 | "metadata": {
1695 | "tags": []
1696 | },
1697 | "execution_count": 16
1698 | }
1699 | ]
1700 | },
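{
"cell_type": "markdown",
"metadata": {
"colab_type": "text"
},
"source": [
"The parameter count printed above only includes tensors with `requires_grad=True`, and only those are handed to the optimizer. To our knowledge, allennlp's `Elmo` keeps the pretrained weights frozen by default, which is why the count stays small even though the forward pass runs the whole pretrained network. The toy model below is a hypothetical stand-in, not the notebook's `LSTMClassifier`; it shows how freezing a large submodule separates the total parameter count from the trainable one."
]
},
{
"cell_type": "code",
"metadata": {
"colab_type": "code",
"colab": {}
},
"source": [
"import torch.nn as nn\n",
"\n",
"\n",
"class FrozenEncoderClassifier(nn.Module):\n",
"    # Hypothetical toy model: a large frozen encoder plus a small trainable head\n",
"    def __init__(self):\n",
"        super().__init__()\n",
"        self.encoder = nn.Linear(1000, 1000)  # stands in for frozen pretrained weights\n",
"        for p in self.encoder.parameters():\n",
"            p.requires_grad = False\n",
"        self.head = nn.Linear(1000, 6)  # small trainable classifier head\n",
"\n",
"\n",
"toy_model = FrozenEncoderClassifier()\n",
"total = sum(p.numel() for p in toy_model.parameters())\n",
"trainable = sum(p.numel() for p in toy_model.parameters() if p.requires_grad)\n",
"print(total, trainable)  # 1007006 6006"
],
"execution_count": 0,
"outputs": []
},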
1701 | {
1702 | "cell_type": "markdown",
1703 | "metadata": {
1704 | "id": "bgrkf2EP4htl",
1705 | "colab_type": "text"
1706 | },
1707 | "source": [
1708 | "Let's see how this result (0.944 test accuracy) compares against an otherwise similar LSTM classifier whose word embeddings are randomly initialized. We don't need an embedding as large as ELMo's 1024 dimensions here: we are using word vectors rather than character-compositional vectors, and we don't have enough data to train a huge word embedding from scratch, so we use 300 dimensions. Also, since we don't have much data, we use lowercased features. Note that with these word embedding features our model has **about 6x more trainable parameters than before** (2,801,306 vs. 461,114), because the embedding table itself is now learned. We also train it longer: 48 epochs instead of 12."
1709 | ]
1710 | },
1711 | {
1712 | "cell_type": "code",
1713 | "metadata": {
1714 | "id": "b97zMOBr3-UA",
1715 | "colab_type": "code",
1716 | "outputId": "7079df5b-96c3-48d0-895e-e03be50b212d",
1717 | "colab": {
1718 | "base_uri": "https://localhost:8080/",
1719 | "height": 1000
1720 | }
1721 | },
1722 | "source": [
1723 | "\n",
1724 | "r = Reader((TRAIN, VALID, TEST,), lowercase=True)\n",
1725 | "train = r.load(TRAIN)\n",
1726 | "valid = r.load(VALID)\n",
1727 | "test = r.load(TEST)\n",
1728 | "\n",
1729 | "embeddings = nn.Embedding(len(r.vocab), 300)\n",
1730 | "model = LSTMClassifier(embeddings, len(r.labels), embeddings.weight.shape[1], rnn_units=100, hidden_units=[100])\n",
1731 | "\n",
1732 | "num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)\n",
1733 | "print(f\"Model has {num_params} parameters\") \n",
1734 | "\n",
1735 | "\n",
1736 | "model.to('cuda:0')\n",
1737 | "loss = torch.nn.NLLLoss()\n",
1738 | "loss = loss.to('cuda:0')\n",
1739 | "\n",
1740 | "learnable_params = [p for p in model.parameters() if p.requires_grad]\n",
1741 | "optimizer = torch.optim.Adadelta(learnable_params, lr=1.0)\n",
1742 | "\n",
1743 | "fit(model, r.labels, optimizer, loss, epochs=48, batch_size=50, train=train, valid=valid, test=test)"
1744 | ],
1745 | "execution_count": 17,
1746 | "outputs": [
1747 | {
1748 | "output_type": "stream",
1749 | "text": [
1750 | "/usr/local/lib/python3.6/dist-packages/torch/nn/modules/rnn.py:54: UserWarning: dropout option adds dropout after all but last recurrent layer, so non-zero dropout expects num_layers greater than 1, but got dropout=0.5 and num_layers=1\n",
1751 | " \"num_layers={}\".format(dropout, num_layers))\n"
1752 | ],
1753 | "name": "stderr"
1754 | },
1755 | {
1756 | "output_type": "stream",
1757 | "text": [
1758 | "Model has 2801306 parameters\n",
1759 | "EPOCH 1\n",
1760 | "=================================\n",
1761 | "Training Results\n",
1762 | "{'acc': 0.2542, 'mean_precision': 0.31261376969825, 'mean_recall': 0.2142077519956256, 'macro_f1': 0.21100023435767312, 'weighted_precision': 0.24503296037652494, 'weighted_recall': 0.2542, 'weighted_f1': 0.22759490197394544}\n",
1763 | "Validation Results\n",
1764 | "{'acc': 0.31194690265486724, 'mean_precision': 0.30219607432329726, 'mean_recall': 0.2697651533307747, 'macro_f1': 0.23033350512330675, 'weighted_precision': 0.3308620779538906, 'weighted_recall': 0.31194690265486724, 'weighted_f1': 0.25984647084205925}\n",
1765 | "New best model 0.31\n",
1766 | "EPOCH 2\n",
1767 | "=================================\n",
1768 | "Training Results\n",
1769 | "{'acc': 0.3776, 'mean_precision': 0.4211727454667504, 'mean_recall': 0.3614489295644783, 'macro_f1': 0.3731848509532312, 'weighted_precision': 0.37235350929169814, 'weighted_recall': 0.3776, 'weighted_f1': 0.3631077775707718}\n",
1770 | "Validation Results\n",
1771 | "{'acc': 0.4557522123893805, 'mean_precision': 0.486368442230124, 'mean_recall': 0.44059437145396335, 'macro_f1': 0.4116212305386946, 'weighted_precision': 0.5028626993515615, 'weighted_recall': 0.4557522123893805, 'weighted_f1': 0.42325171491881114}\n",
1772 | "New best model 0.46\n",
1773 | "EPOCH 3\n",
1774 | "=================================\n",
1775 | "Training Results\n",
1776 | "{'acc': 0.5432, 'mean_precision': 0.5874298893317987, 'mean_recall': 0.5213749406751867, 'macro_f1': 0.5432030532887705, 'weighted_precision': 0.5414128746603742, 'weighted_recall': 0.5432, 'weighted_f1': 0.539731134218227}\n",
1777 | "Validation Results\n",
1778 | "{'acc': 0.6592920353982301, 'mean_precision': 0.6419746068058818, 'mean_recall': 0.6268845865056035, 'macro_f1': 0.6307977552497951, 'weighted_precision': 0.662392825769239, 'weighted_recall': 0.6592920353982301, 'weighted_f1': 0.6571270878109032}\n",
1779 | "New best model 0.66\n",
1780 | "EPOCH 4\n",
1781 | "=================================\n",
1782 | "Training Results\n",
1783 | "{'acc': 0.6652, 'mean_precision': 0.7019592671173288, 'mean_recall': 0.6360073323821094, 'macro_f1': 0.6596012872763003, 'weighted_precision': 0.6699451752521243, 'weighted_recall': 0.6652, 'weighted_f1': 0.6663461696747198}\n",
1784 | "Validation Results\n",
1785 | "{'acc': 0.7256637168141593, 'mean_precision': 0.6947344026048728, 'mean_recall': 0.684312146723145, 'macro_f1': 0.6882240153205471, 'weighted_precision': 0.7329104731535425, 'weighted_recall': 0.7256637168141593, 'weighted_f1': 0.7276818270316268}\n",
1786 | "New best model 0.73\n",
1787 | "EPOCH 5\n",
1788 | "=================================\n",
1789 | "Training Results\n",
1790 | "{'acc': 0.7262, 'mean_precision': 0.7559759521433286, 'mean_recall': 0.6964030168119725, 'macro_f1': 0.718961419314074, 'weighted_precision': 0.7314068173369025, 'weighted_recall': 0.7262, 'weighted_f1': 0.7276772804705455}\n",
1791 | "Validation Results\n",
1792 | "{'acc': 0.7699115044247787, 'mean_precision': 0.767162729738685, 'mean_recall': 0.7171683804327849, 'macro_f1': 0.7358018640759697, 'weighted_precision': 0.7869035814233268, 'weighted_recall': 0.7699115044247787, 'weighted_f1': 0.7730466432148493}\n",
1793 | "New best model 0.77\n",
1794 | "EPOCH 6\n",
1795 | "=================================\n",
1796 | "Training Results\n",
1797 | "{'acc': 0.7774, 'mean_precision': 0.7991829240246574, 'mean_recall': 0.7460179937770648, 'macro_f1': 0.7670813561197591, 'weighted_precision': 0.7829082373273452, 'weighted_recall': 0.7774, 'weighted_f1': 0.7790066442175066}\n",
1798 | "Validation Results\n",
1799 | "{'acc': 0.7942477876106194, 'mean_precision': 0.7722010233604815, 'mean_recall': 0.7413137565797516, 'macro_f1': 0.7527105440234411, 'weighted_precision': 0.7965885015875331, 'weighted_recall': 0.7942477876106194, 'weighted_f1': 0.7932362973950707}\n",
1800 | "New best model 0.79\n",
1801 | "EPOCH 7\n",
1802 | "=================================\n",
1803 | "Training Results\n",
1804 | "{'acc': 0.8094, 'mean_precision': 0.8190739759778141, 'mean_recall': 0.7893403911783444, 'macro_f1': 0.802535501291079, 'weighted_precision': 0.8128192118492741, 'weighted_recall': 0.8094, 'weighted_f1': 0.8105827362652248}\n",
1805 | "Validation Results\n",
1806 | "{'acc': 0.7876106194690266, 'mean_precision': 0.7485831299937852, 'mean_recall': 0.737125918471548, 'macro_f1': 0.739828985707053, 'weighted_precision': 0.7889226923611903, 'weighted_recall': 0.7876106194690266, 'weighted_f1': 0.7851833379237152}\n",
1807 | "EPOCH 8\n",
1808 | "=================================\n",
1809 | "Training Results\n",
1810 | "{'acc': 0.8334, 'mean_precision': 0.8425566386344959, 'mean_recall': 0.8020487930554469, 'macro_f1': 0.8188927928986053, 'weighted_precision': 0.8361427439552401, 'weighted_recall': 0.8334, 'weighted_f1': 0.8341487335405394}\n",
1811 | "Validation Results\n",
1812 | "{'acc': 0.8053097345132744, 'mean_precision': 0.7836886850840924, 'mean_recall': 0.7717037075656982, 'macro_f1': 0.7766933663665941, 'weighted_precision': 0.8124743795936054, 'weighted_recall': 0.8053097345132744, 'weighted_f1': 0.8076258267664438}\n",
1813 | "New best model 0.81\n",
1814 | "EPOCH 9\n",
1815 | "=================================\n",
1816 | "Training Results\n",
1817 | "{'acc': 0.8468, 'mean_precision': 0.8504311729666333, 'mean_recall': 0.8200134573572487, 'macro_f1': 0.8331837029506359, 'weighted_precision': 0.8494250195065668, 'weighted_recall': 0.8468, 'weighted_f1': 0.8475914482379496}\n",
1818 | "Validation Results\n",
1819 | "{'acc': 0.8163716814159292, 'mean_precision': 0.8122787276154813, 'mean_recall': 0.7728318623385028, 'macro_f1': 0.7845697172233551, 'weighted_precision': 0.841008158316879, 'weighted_recall': 0.8163716814159292, 'weighted_f1': 0.8194346806526578}\n",
1820 | "New best model 0.82\n",
1821 | "EPOCH 10\n",
1822 | "=================================\n",
1823 | "Training Results\n",
1824 | "{'acc': 0.856, 'mean_precision': 0.8583456863001512, 'mean_recall': 0.8344367321617767, 'macro_f1': 0.8452794073478773, 'weighted_precision': 0.858502410730061, 'weighted_recall': 0.856, 'weighted_f1': 0.8568463379552975}\n",
1825 | "Validation Results\n",
1826 | "{'acc': 0.827433628318584, 'mean_precision': 0.8065578820468894, 'mean_recall': 0.7875269000666668, 'macro_f1': 0.7948885464239536, 'weighted_precision': 0.8376851538502355, 'weighted_recall': 0.827433628318584, 'weighted_f1': 0.8300346063518604}\n",
1827 | "New best model 0.83\n",
1828 | "EPOCH 11\n",
1829 | "=================================\n",
1830 | "Training Results\n",
1831 | "{'acc': 0.8752, 'mean_precision': 0.8867543443057883, 'mean_recall': 0.8682828032176982, 'macro_f1': 0.876877938560822, 'weighted_precision': 0.8765970442190595, 'weighted_recall': 0.8752, 'weighted_f1': 0.8756987373269298}\n",
1832 | "Validation Results\n",
1833 | "{'acc': 0.8429203539823009, 'mean_precision': 0.810890350645565, 'mean_recall': 0.8200588537595558, 'macro_f1': 0.8091225485091558, 'weighted_precision': 0.856571882711008, 'weighted_recall': 0.8429203539823009, 'weighted_f1': 0.8457428127083414}\n",
1834 | "New best model 0.84\n",
1835 | "EPOCH 12\n",
1836 | "=================================\n",
1837 | "Training Results\n",
1838 | "{'acc': 0.8792, 'mean_precision': 0.8792557729505756, 'mean_recall': 0.8559171118516193, 'macro_f1': 0.8664273096431591, 'weighted_precision': 0.8802943928708451, 'weighted_recall': 0.8792, 'weighted_f1': 0.8794929463166824}\n",
1839 | "Validation Results\n",
1840 | "{'acc': 0.831858407079646, 'mean_precision': 0.7962677456184699, 'mean_recall': 0.8169955255150682, 'macro_f1': 0.8012212821990726, 'weighted_precision': 0.8439342848824284, 'weighted_recall': 0.831858407079646, 'weighted_f1': 0.8345906017116217}\n",
1841 | "EPOCH 13\n",
1842 | "=================================\n",
1843 | "Training Results\n",
1844 | "{'acc': 0.8982, 'mean_precision': 0.8926979300525084, 'mean_recall': 0.8720075500179675, 'macro_f1': 0.8813865977662623, 'weighted_precision': 0.8990274368829654, 'weighted_recall': 0.8982, 'weighted_f1': 0.898400074877493}\n",
1845 | "Validation Results\n",
1846 | "{'acc': 0.8407079646017699, 'mean_precision': 0.8044426437483548, 'mean_recall': 0.8230926174040855, 'macro_f1': 0.8115058510891844, 'weighted_precision': 0.8446897114706726, 'weighted_recall': 0.8407079646017699, 'weighted_f1': 0.8418088284238727}\n",
1847 | "EPOCH 14\n",
1848 | "=================================\n",
1849 | "Training Results\n",
1850 | "{'acc': 0.8994, 'mean_precision': 0.90397473118525, 'mean_recall': 0.8869886283764901, 'macro_f1': 0.8949155550364044, 'weighted_precision': 0.9000862304408612, 'weighted_recall': 0.8994, 'weighted_f1': 0.8996270341701729}\n",
1851 | "Validation Results\n",
1852 | "{'acc': 0.834070796460177, 'mean_precision': 0.7941534082526428, 'mean_recall': 0.8181802560779231, 'macro_f1': 0.802448595927142, 'weighted_precision': 0.839962405943381, 'weighted_recall': 0.834070796460177, 'weighted_f1': 0.8360121837053236}\n",
1853 | "EPOCH 15\n",
1854 | "=================================\n",
1855 | "Training Results\n",
1856 | "{'acc': 0.9078, 'mean_precision': 0.9025906984450397, 'mean_recall': 0.8840592072689094, 'macro_f1': 0.8925740985975588, 'weighted_precision': 0.9079516195367331, 'weighted_recall': 0.9078, 'weighted_f1': 0.9077917625402113}\n",
1857 | "Validation Results\n",
1858 | "{'acc': 0.8429203539823009, 'mean_precision': 0.8178604107482675, 'mean_recall': 0.8027279926613361, 'macro_f1': 0.807847400976306, 'weighted_precision': 0.8541454685154969, 'weighted_recall': 0.8429203539823009, 'weighted_f1': 0.8452297642709833}\n",
1859 | "EPOCH 16\n",
1860 | "=================================\n",
1861 | "Training Results\n",
1862 | "{'acc': 0.9114, 'mean_precision': 0.9091156125381755, 'mean_recall': 0.8954153246439646, 'macro_f1': 0.901884018226399, 'weighted_precision': 0.9120534355149411, 'weighted_recall': 0.9114, 'weighted_f1': 0.9115938328769627}\n",
1863 | "Validation Results\n",
1864 | "{'acc': 0.8584070796460177, 'mean_precision': 0.8452925879669096, 'mean_recall': 0.8140923455585375, 'macro_f1': 0.8271059195675882, 'weighted_precision': 0.866784367505974, 'weighted_recall': 0.8584070796460177, 'weighted_f1': 0.8601283537102449}\n",
1865 | "New best model 0.86\n",
1866 | "EPOCH 17\n",
1867 | "=================================\n",
1868 | "Training Results\n",
1869 | "{'acc': 0.922, 'mean_precision': 0.92318827916946, 'mean_recall': 0.9167872228173993, 'macro_f1': 0.919907344495105, 'weighted_precision': 0.9222631334663767, 'weighted_recall': 0.922, 'weighted_f1': 0.9220983595773123}\n",
1870 | "Validation Results\n",
1871 | "{'acc': 0.8495575221238938, 'mean_precision': 0.8226381745596281, 'mean_recall': 0.8064587339697171, 'macro_f1': 0.8118331344786779, 'weighted_precision': 0.8588090104120005, 'weighted_recall': 0.8495575221238938, 'weighted_f1': 0.8509990153606489}\n",
1872 | "EPOCH 18\n",
1873 | "=================================\n",
1874 | "Training Results\n",
1875 | "{'acc': 0.925, 'mean_precision': 0.9272713764228396, 'mean_recall': 0.9097196644145997, 'macro_f1': 0.9178903031944213, 'weighted_precision': 0.9256665310570704, 'weighted_recall': 0.925, 'weighted_f1': 0.9251760923127654}\n",
1876 | "Validation Results\n",
1877 | "{'acc': 0.8362831858407079, 'mean_precision': 0.8130991700945112, 'mean_recall': 0.8186032795814143, 'macro_f1': 0.8123228744939271, 'weighted_precision': 0.8468684662820305, 'weighted_recall': 0.8362831858407079, 'weighted_f1': 0.8392648607279487}\n",
1878 | "EPOCH 19\n",
1879 | "=================================\n",
1880 | "Training Results\n",
1881 | "{'acc': 0.9226, 'mean_precision': 0.9265697322783559, 'mean_recall': 0.907365970956147, 'macro_f1': 0.9162751616718158, 'weighted_precision': 0.9229974594112873, 'weighted_recall': 0.9226, 'weighted_f1': 0.922695479769021}\n",
1882 | "Validation Results\n",
1883 | "{'acc': 0.8407079646017699, 'mean_precision': 0.7961368682179472, 'mean_recall': 0.8225977371214205, 'macro_f1': 0.8033500248964449, 'weighted_precision': 0.8499048273885906, 'weighted_recall': 0.8407079646017699, 'weighted_f1': 0.8433057108007154}\n",
1884 | "EPOCH 20\n",
1885 | "=================================\n",
1886 | "Training Results\n",
1887 | "{'acc': 0.9306, 'mean_precision': 0.9325672078372992, 'mean_recall': 0.9292912885334105, 'macro_f1': 0.9308909230234336, 'weighted_precision': 0.9309373491584835, 'weighted_recall': 0.9306, 'weighted_f1': 0.9307301028725288}\n",
1888 | "Validation Results\n",
1889 | "{'acc': 0.8451327433628318, 'mean_precision': 0.8169164169164169, 'mean_recall': 0.8240567649306224, 'macro_f1': 0.81237323931053, 'weighted_precision': 0.8649006598121644, 'weighted_recall': 0.8451327433628318, 'weighted_f1': 0.8481262461147536}\n",
1890 | "EPOCH 21\n",
1891 | "=================================\n",
1892 | "Training Results\n",
1893 | "{'acc': 0.9372, 'mean_precision': 0.94778588624423, 'mean_recall': 0.931739382135247, 'macro_f1': 0.9392845360347503, 'weighted_precision': 0.9376509979046739, 'weighted_recall': 0.9372, 'weighted_f1': 0.9373365710922599}\n",
1894 | "Validation Results\n",
1895 | "{'acc': 0.8429203539823009, 'mean_precision': 0.8122060187568131, 'mean_recall': 0.8039074327168304, 'macro_f1': 0.8074842893609094, 'weighted_precision': 0.846085726902366, 'weighted_recall': 0.8429203539823009, 'weighted_f1': 0.843907236330442}\n",
1896 | "EPOCH 22\n",
1897 | "=================================\n",
1898 | "Training Results\n",
1899 | "{'acc': 0.9396, 'mean_precision': 0.9374128845763084, 'mean_recall': 0.919864102313256, 'macro_f1': 0.9279841844104714, 'weighted_precision': 0.9395460389520726, 'weighted_recall': 0.9396, 'weighted_f1': 0.9395077256432057}\n",
1900 | "Validation Results\n",
1901 | "{'acc': 0.838495575221239, 'mean_precision': 0.8166287688346512, 'mean_recall': 0.7965319356459588, 'macro_f1': 0.8036210071046136, 'weighted_precision': 0.8514493271781142, 'weighted_recall': 0.838495575221239, 'weighted_f1': 0.8411121771904451}\n",
1902 | "EPOCH 23\n",
1903 | "=================================\n",
1904 | "Training Results\n",
1905 | "{'acc': 0.9394, 'mean_precision': 0.9343864652716637, 'mean_recall': 0.9292029606780483, 'macro_f1': 0.9317217171943124, 'weighted_precision': 0.9395732746438947, 'weighted_recall': 0.9394, 'weighted_f1': 0.9394587072607552}\n",
1906 | "Validation Results\n",
1907 | "{'acc': 0.8495575221238938, 'mean_precision': 0.8152645128671007, 'mean_recall': 0.8101445689658334, 'macro_f1': 0.8116836837706645, 'weighted_precision': 0.853045900832462, 'weighted_recall': 0.8495575221238938, 'weighted_f1': 0.8499444374703017}\n",
1908 | "EPOCH 24\n",
1909 | "=================================\n",
1910 | "Training Results\n",
1911 | "{'acc': 0.9366, 'mean_precision': 0.9385826391144599, 'mean_recall': 0.9310259128633419, 'macro_f1': 0.9346948428938683, 'weighted_precision': 0.9367440103332517, 'weighted_recall': 0.9366, 'weighted_f1': 0.9366526615949582}\n",
1912 | "Validation Results\n",
1913 | "{'acc': 0.8473451327433629, 'mean_precision': 0.8025432410455261, 'mean_recall': 0.8283322432914666, 'macro_f1': 0.8061102573975748, 'weighted_precision': 0.861458116214064, 'weighted_recall': 0.8473451327433629, 'weighted_f1': 0.8502841373610759}\n",
1914 | "EPOCH 25\n",
1915 | "=================================\n",
1916 | "Training Results\n",
1917 | "{'acc': 0.9474, 'mean_precision': 0.945203027270548, 'mean_recall': 0.9397940805142109, 'macro_f1': 0.9424193517790628, 'weighted_precision': 0.9476039831561757, 'weighted_recall': 0.9474, 'weighted_f1': 0.9474674977170525}\n",
1918 | "Validation Results\n",
1919 | "{'acc': 0.8517699115044248, 'mean_precision': 0.8122309943824346, 'mean_recall': 0.8316675962171595, 'macro_f1': 0.817457525740406, 'weighted_precision': 0.8589009917207507, 'weighted_recall': 0.8517699115044248, 'weighted_f1': 0.8536305273856584}\n",
1920 | "EPOCH 26\n",
1921 | "=================================\n",
1922 | "Training Results\n",
1923 | "{'acc': 0.9412, 'mean_precision': 0.9333724069841898, 'mean_recall': 0.9327119878566131, 'macro_f1': 0.9330083105008247, 'weighted_precision': 0.9415099002878724, 'weighted_recall': 0.9412, 'weighted_f1': 0.9413101950625798}\n",
1924 | "Validation Results\n",
1925 | "{'acc': 0.8672566371681416, 'mean_precision': 0.8605083530628996, 'mean_recall': 0.8246033707390573, 'macro_f1': 0.8391902596303691, 'weighted_precision': 0.8669141290171042, 'weighted_recall': 0.8672566371681416, 'weighted_f1': 0.8665873995194234}\n",
1926 | "New best model 0.87\n",
1927 | "EPOCH 27\n",
1928 | "=================================\n",
1929 | "Training Results\n",
1930 | "{'acc': 0.9452, 'mean_precision': 0.9393570838696085, 'mean_recall': 0.9381922191949007, 'macro_f1': 0.938758641015847, 'weighted_precision': 0.9451643947972791, 'weighted_recall': 0.9452, 'weighted_f1': 0.9451693027949712}\n",
1931 | "Validation Results\n",
1932 | "{'acc': 0.8539823008849557, 'mean_precision': 0.8230172208866048, 'mean_recall': 0.8123173810979951, 'macro_f1': 0.8156989547159269, 'weighted_precision': 0.86038962724662, 'weighted_recall': 0.8539823008849557, 'weighted_f1': 0.8546422455744229}\n",
1933 | "EPOCH 28\n",
1934 | "=================================\n",
1935 | "Training Results\n",
1936 | "{'acc': 0.9496, 'mean_precision': 0.9460303491056411, 'mean_recall': 0.9315008817259457, 'macro_f1': 0.938347207012888, 'weighted_precision': 0.9496092730822427, 'weighted_recall': 0.9496, 'weighted_f1': 0.9495572903768849}\n",
1937 | "Validation Results\n",
1938 | "{'acc': 0.8495575221238938, 'mean_precision': 0.8264021292481143, 'mean_recall': 0.8075150926865313, 'macro_f1': 0.8142132100384322, 'weighted_precision': 0.8614722274611524, 'weighted_recall': 0.8495575221238938, 'weighted_f1': 0.8519763491364317}\n",
1939 | "EPOCH 29\n",
1940 | "=================================\n",
1941 | "Training Results\n",
1942 | "{'acc': 0.9492, 'mean_precision': 0.9523110976436452, 'mean_recall': 0.9431909498306, 'macro_f1': 0.9475782265270856, 'weighted_precision': 0.9492766331969764, 'weighted_recall': 0.9492, 'weighted_f1': 0.9492146339395731}\n",
1943 | "Validation Results\n",
1944 | "{'acc': 0.8517699115044248, 'mean_precision': 0.8258343506751618, 'mean_recall': 0.8103395389633586, 'macro_f1': 0.8158066635203088, 'weighted_precision': 0.860982906249134, 'weighted_recall': 0.8517699115044248, 'weighted_f1': 0.8534222923294074}\n",
1945 | "EPOCH 30\n",
1946 | "=================================\n",
1947 | "Training Results\n",
1948 | "{'acc': 0.9486, 'mean_precision': 0.9560017443521739, 'mean_recall': 0.9400692468871957, 'macro_f1': 0.9475746328501561, 'weighted_precision': 0.9488545438754485, 'weighted_recall': 0.9486, 'weighted_f1': 0.9486620015461529}\n",
1949 | "Validation Results\n",
1950 | "{'acc': 0.8517699115044248, 'mean_precision': 0.8177916897275509, 'mean_recall': 0.8317336038356844, 'macro_f1': 0.8189379757266612, 'weighted_precision': 0.8641615919629556, 'weighted_recall': 0.8517699115044248, 'weighted_f1': 0.8543175095648456}\n",
1951 | "EPOCH 31\n",
1952 | "=================================\n",
1953 | "Training Results\n",
1954 | "{'acc': 0.958, 'mean_precision': 0.9538284266182177, 'mean_recall': 0.9446486493014933, 'macro_f1': 0.9490660805682173, 'weighted_precision': 0.9581057645690427, 'weighted_recall': 0.958, 'weighted_f1': 0.958018653218701}\n",
1955 | "Validation Results\n",
1956 | "{'acc': 0.8517699115044248, 'mean_precision': 0.8013986378685444, 'mean_recall': 0.8119935701678367, 'macro_f1': 0.804411328159357, 'weighted_precision': 0.8579675339835053, 'weighted_recall': 0.8517699115044248, 'weighted_f1': 0.8534679803361881}\n",
1957 | "EPOCH 32\n",
1958 | "=================================\n",
1959 | "Training Results\n",
1960 | "{'acc': 0.954, 'mean_precision': 0.9512729471583152, 'mean_recall': 0.9522326430845925, 'macro_f1': 0.9517306869942752, 'weighted_precision': 0.9541486554643743, 'weighted_recall': 0.954, 'weighted_f1': 0.9540543038088639}\n",
1961 | "Validation Results\n",
1962 | "{'acc': 0.8584070796460177, 'mean_precision': 0.8247875724959964, 'mean_recall': 0.813382140867244, 'macro_f1': 0.8148679951657951, 'weighted_precision': 0.8728418185739959, 'weighted_recall': 0.8584070796460177, 'weighted_f1': 0.8607256592741758}\n",
1963 | "EPOCH 33\n",
1964 | "=================================\n",
1965 | "Training Results\n",
1966 | "{'acc': 0.957, 'mean_precision': 0.9565681480304719, 'mean_recall': 0.9454327046857133, 'macro_f1': 0.9507466808370677, 'weighted_precision': 0.9571322707135252, 'weighted_recall': 0.957, 'weighted_f1': 0.9570200735546852}\n",
1967 | "Validation Results\n",
1968 | "{'acc': 0.8407079646017699, 'mean_precision': 0.8012295177369267, 'mean_recall': 0.7988123932835594, 'macro_f1': 0.7956541104094256, 'weighted_precision': 0.854940819905837, 'weighted_recall': 0.8407079646017699, 'weighted_f1': 0.843834784432942}\n",
1969 | "EPOCH 34\n",
1970 | "=================================\n",
1971 | "Training Results\n",
1972 | "{'acc': 0.9562, 'mean_precision': 0.9492352170638396, 'mean_recall': 0.9501709445363726, 'macro_f1': 0.9496843838684074, 'weighted_precision': 0.9563024102937812, 'weighted_recall': 0.9562, 'weighted_f1': 0.9562374926155631}\n",
1973 | "Validation Results\n",
1974 | "{'acc': 0.8539823008849557, 'mean_precision': 0.8074247463353991, 'mean_recall': 0.8132105098252905, 'macro_f1': 0.8095076634215594, 'weighted_precision': 0.8571530973672192, 'weighted_recall': 0.8539823008849557, 'weighted_f1': 0.8549969022822371}\n",
1975 | "EPOCH 35\n",
1976 | "=================================\n",
1977 | "Training Results\n",
1978 | "{'acc': 0.958, 'mean_precision': 0.9581863087122168, 'mean_recall': 0.9556184908351323, 'macro_f1': 0.9568856509156166, 'weighted_precision': 0.9581083670874939, 'weighted_recall': 0.958, 'weighted_f1': 0.9580411865238977}\n",
1979 | "Validation Results\n",
1980 | "{'acc': 0.8495575221238938, 'mean_precision': 0.8178061479593234, 'mean_recall': 0.805632013633958, 'macro_f1': 0.807789691127139, 'weighted_precision': 0.8649796275543147, 'weighted_recall': 0.8495575221238938, 'weighted_f1': 0.8526672381640437}\n",
1981 | "EPOCH 36\n",
1982 | "=================================\n",
1983 | "Training Results\n",
1984 | "{'acc': 0.9574, 'mean_precision': 0.9531272685674982, 'mean_recall': 0.9515769491207995, 'macro_f1': 0.9523405767311283, 'weighted_precision': 0.9573866537776834, 'weighted_recall': 0.9574, 'weighted_f1': 0.9573873211612607}\n",
1985 | "Validation Results\n",
1986 | "{'acc': 0.8561946902654868, 'mean_precision': 0.8131151427942895, 'mean_recall': 0.8110658945033787, 'macro_f1': 0.8077795032214256, 'weighted_precision': 0.8693120626269972, 'weighted_recall': 0.8561946902654868, 'weighted_f1': 0.8589532280938094}\n",
1987 | "EPOCH 37\n",
1988 | "=================================\n",
1989 | "Training Results\n",
1990 | "{'acc': 0.9566, 'mean_precision': 0.9589041860848261, 'mean_recall': 0.9547448481946009, 'macro_f1': 0.9567909158741493, 'weighted_precision': 0.9566010555107037, 'weighted_recall': 0.9566, 'weighted_f1': 0.9565890848568762}\n",
1991 | "Validation Results\n",
1992 | "{'acc': 0.8495575221238938, 'mean_precision': 0.8030750892649611, 'mean_recall': 0.8075017730562345, 'macro_f1': 0.8022990697875542, 'weighted_precision': 0.858293669895924, 'weighted_recall': 0.8495575221238938, 'weighted_f1': 0.851768488915397}\n",
1993 | "EPOCH 38\n",
1994 | "=================================\n",
1995 | "Training Results\n",
1996 | "{'acc': 0.957, 'mean_precision': 0.9544796723873579, 'mean_recall': 0.9522358611983669, 'macro_f1': 0.9533352482886838, 'weighted_precision': 0.9571066919837921, 'weighted_recall': 0.957, 'weighted_f1': 0.9570317579959188}\n",
1997 | "Validation Results\n",
1998 | "{'acc': 0.8539823008849557, 'mean_precision': 0.8106297488632482, 'mean_recall': 0.8123418403963019, 'macro_f1': 0.8098548414705528, 'weighted_precision': 0.8585169985936104, 'weighted_recall': 0.8539823008849557, 'weighted_f1': 0.8548092957757126}\n",
1999 | "EPOCH 39\n",
2000 | "=================================\n",
2001 | "Training Results\n",
2002 | "{'acc': 0.9564, 'mean_precision': 0.9532129259082603, 'mean_recall': 0.9506704508751782, 'macro_f1': 0.951926709753589, 'weighted_precision': 0.9564913201918928, 'weighted_recall': 0.9564, 'weighted_f1': 0.9564356136646421}\n",
2003 | "Validation Results\n",
2004 | "{'acc': 0.8539823008849557, 'mean_precision': 0.8017848860212725, 'mean_recall': 0.8119553221526408, 'macro_f1': 0.804904791937732, 'weighted_precision': 0.8588867681742558, 'weighted_recall': 0.8539823008849557, 'weighted_f1': 0.855560943534421}\n",
2005 | "EPOCH 40\n",
2006 | "=================================\n",
2007 | "Training Results\n",
2008 | "{'acc': 0.9644, 'mean_precision': 0.9670062774145092, 'mean_recall': 0.9573277463956331, 'macro_f1': 0.9619863456379548, 'weighted_precision': 0.9645118227152805, 'weighted_recall': 0.9644, 'weighted_f1': 0.9644225052829172}\n",
2009 | "Validation Results\n",
2010 | "{'acc': 0.8495575221238938, 'mean_precision': 0.8119564587793305, 'mean_recall': 0.8080099729691964, 'macro_f1': 0.8081568266024983, 'weighted_precision': 0.8580128676213757, 'weighted_recall': 0.8495575221238938, 'weighted_f1': 0.8519747328540073}\n",
2011 | "EPOCH 41\n",
2012 | "=================================\n",
2013 | "Training Results\n",
2014 | "{'acc': 0.9628, 'mean_precision': 0.9582207266041486, 'mean_recall': 0.9557915435105366, 'macro_f1': 0.956989832500887, 'weighted_precision': 0.9628821971783511, 'weighted_recall': 0.9628, 'weighted_f1': 0.9628296736582989}\n",
2015 | "Validation Results\n",
2016 | "{'acc': 0.8451327433628318, 'mean_precision': 0.7965093061348186, 'mean_recall': 0.8045728236987139, 'macro_f1': 0.7982598158685201, 'weighted_precision': 0.8513942451058316, 'weighted_recall': 0.8451327433628318, 'weighted_f1': 0.8470011392375405}\n",
2017 | "EPOCH 42\n",
2018 | "=================================\n",
2019 | "Training Results\n",
2020 | "{'acc': 0.964, 'mean_precision': 0.9645740109204436, 'mean_recall': 0.9606825654964162, 'macro_f1': 0.9625915166364768, 'weighted_precision': 0.9640606955680181, 'weighted_recall': 0.964, 'weighted_f1': 0.9640155053392376}\n",
2021 | "Validation Results\n",
2022 | "{'acc': 0.8495575221238938, 'mean_precision': 0.8037359987261851, 'mean_recall': 0.8058441622200901, 'macro_f1': 0.8012095203517711, 'weighted_precision': 0.8602910880270568, 'weighted_recall': 0.8495575221238938, 'weighted_f1': 0.8518937207419582}\n",
2023 | "EPOCH 43\n",
2024 | "=================================\n",
2025 | "Training Results\n",
2026 | "{'acc': 0.967, 'mean_precision': 0.9673205256002168, 'mean_recall': 0.9608809490156882, 'macro_f1': 0.9640313629040639, 'weighted_precision': 0.967090742201184, 'weighted_recall': 0.967, 'weighted_f1': 0.9670273824588701}\n",
2027 | "Validation Results\n",
2028 | "{'acc': 0.8584070796460177, 'mean_precision': 0.8071223611395778, 'mean_recall': 0.815180322836991, 'macro_f1': 0.8086659140417566, 'weighted_precision': 0.8654518199826157, 'weighted_recall': 0.8584070796460177, 'weighted_f1': 0.8603401318409323}\n",
2029 | "EPOCH 44\n",
2030 | "=================================\n",
2031 | "Training Results\n",
2032 | "{'acc': 0.9658, 'mean_precision': 0.9609493483829002, 'mean_recall': 0.9637935616364777, 'macro_f1': 0.9623305485433838, 'weighted_precision': 0.9659282069261604, 'weighted_recall': 0.9658, 'weighted_f1': 0.9658413169266654}\n",
2033 | "Validation Results\n",
2034 | "{'acc': 0.8451327433628318, 'mean_precision': 0.811964839602692, 'mean_recall': 0.804398454041185, 'macro_f1': 0.8075118081867685, 'weighted_precision': 0.8495195126057683, 'weighted_recall': 0.8451327433628318, 'weighted_f1': 0.846453823025803}\n",
2035 | "EPOCH 45\n",
2036 | "=================================\n",
2037 | "Training Results\n",
2038 | "{'acc': 0.9658, 'mean_precision': 0.9598439053443757, 'mean_recall': 0.9526976401858064, 'macro_f1': 0.9561667547996912, 'weighted_precision': 0.9657984947117778, 'weighted_recall': 0.9658, 'weighted_f1': 0.9657855836823659}\n",
2039 | "Validation Results\n",
2040 | "{'acc': 0.8539823008849557, 'mean_precision': 0.8027867894992161, 'mean_recall': 0.8119553221526408, 'macro_f1': 0.8053321278501304, 'weighted_precision': 0.8592071325485433, 'weighted_recall': 0.8539823008849557, 'weighted_f1': 0.8556693507865624}\n",
2041 | "EPOCH 46\n",
2042 | "=================================\n",
2043 | "Training Results\n",
2044 | "{'acc': 0.9662, 'mean_precision': 0.9604700540065974, 'mean_recall': 0.9625075307979163, 'macro_f1': 0.9614812705685866, 'weighted_precision': 0.9662056816206401, 'weighted_recall': 0.9662, 'weighted_f1': 0.966200656893312}\n",
2045 | "Validation Results\n",
2046 | "{'acc': 0.8606194690265486, 'mean_precision': 0.8296025254745819, 'mean_recall': 0.8174572008429252, 'macro_f1': 0.8224370751844012, 'weighted_precision': 0.8663240618203951, 'weighted_recall': 0.8606194690265486, 'weighted_f1': 0.8621524978508673}\n",
2047 | "EPOCH 47\n",
2048 | "=================================\n",
2049 | "Training Results\n",
2050 | "{'acc': 0.9658, 'mean_precision': 0.9623643280997959, 'mean_recall': 0.9601608706065794, 'macro_f1': 0.9612521277633985, 'weighted_precision': 0.9657974850075743, 'weighted_recall': 0.9658, 'weighted_f1': 0.9657933462320052}\n",
2051 | "Validation Results\n",
2052 | "{'acc': 0.8606194690265486, 'mean_precision': 0.8289130643629324, 'mean_recall': 0.8162020131702755, 'macro_f1': 0.8212697994296662, 'weighted_precision': 0.8663358703067995, 'weighted_recall': 0.8606194690265486, 'weighted_f1': 0.8619390343788721}\n",
2053 | "EPOCH 48\n",
2054 | "=================================\n",
2055 | "Training Results\n",
2056 | "{'acc': 0.9694, 'mean_precision': 0.965498989334344, 'mean_recall': 0.9691871765140303, 'macro_f1': 0.9673126581580327, 'weighted_precision': 0.969441121137196, 'weighted_recall': 0.9694, 'weighted_f1': 0.969411138333605}\n",
2057 | "Validation Results\n",
2058 | "{'acc': 0.8473451327433629, 'mean_precision': 0.8108888178689653, 'mean_recall': 0.8042876327132501, 'macro_f1': 0.8053138141785771, 'weighted_precision': 0.8569467358619901, 'weighted_recall': 0.8473451327433629, 'weighted_f1': 0.8498193213999172}\n",
2059 | "Final result\n",
2060 | "{'acc': 0.882, 'mean_precision': 0.8966656639557661, 'mean_recall': 0.8599837937999005, 'macro_f1': 0.8737226125229776, 'weighted_precision': 0.8813809951547577, 'weighted_recall': 0.882, 'weighted_f1': 0.878768700658464}\n"
2061 | ],
2062 | "name": "stdout"
2063 | },
2064 | {
2065 | "output_type": "execute_result",
2066 | "data": {
2067 | "text/plain": [
2068 | "0.882"
2069 | ]
2070 | },
2071 | "metadata": {
2072 | "tags": []
2073 | },
2074 | "execution_count": 17
2075 | }
2076 | ]
2077 | },
2078 | {
2079 | "cell_type": "markdown",
2080 | "metadata": {
2081 | "id": "iPWVd-iW1mMG",
2082 | "colab_type": "text"
2083 | },
2084 | "source": [
2085 | "## Conclusion\n",
2086 | "\n",
2087 | "Without even concatenating word features, our ELMo model, with far fewer parameters, surpasses the performance of the randomly initialized baseline, as we would expect. It also significantly outperforms our CNN baseline with pre-trained, fine-tuned word embeddings from the last section; that model peaks at around 93% accuracy. Note that this dataset is tiny and the variance between datasets is large, but this model consistently outperforms both the CNN and LSTM baselines.\n",
2088 | "\n",
2089 | "Contextual embeddings outperform non-contextual embeddings on almost every NLP task, not just text classification. The approach has become so common that some papers now report it as a baseline.\n",
2090 | "\n",
2091 | "### Some more references\n",
2092 | "\n",
2093 | "- The PyTorch examples repository contains a [nice word-level language model](https://github.com/pytorch/examples/tree/master/word_language_model)\n",
2094 | "\n",
2095 | "- There is a [TensorFlow tutorial](https://www.tensorflow.org/tutorials/sequences/recurrent) as well\n",
2096 | "\n",
2097 | "- The original source code for training [ELMo's biLM](https://github.com/allenai/bilm-tf/tree/master/bilm)\n",
2098 | "\n",
2099 | "- [A succinct implementation](https://github.com/dpressel/baseline/blob/master/python/baseline/pytorch/embeddings.py#L63) of character-compositional embeddings in Baseline for PyTorch\n",
2100 | "\n",
2101 | "\n",
2102 | "\n",
2103 | "\n"
2104 | ]
2105 | }
2106 | ]
2107 | }
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # dliss-tutorial
2 | Tutorial for [International Summer School
3 | on Deep Learning, 2019](http://dl-lab.eu/) in Gdansk, Poland
4 |
5 | ## Sections
6 |
7 | ### Overview Talk
8 |
9 | https://docs.google.com/presentation/d/1DJI1yX4U5IgApGwavt0AmOCLWwso7ou1Un93sMuAWmA/
10 |
11 | ### Tutorial
12 | This tutorial currently has three hands-on sections.
13 |
14 | - The [first section](1_pretrained_vectors.ipynb) covers pre-trained word embeddings [(colab)](https://colab.research.google.com/github/dpressel/dlss-tutorial/blob/master/1_pretrained_vectors.ipynb)
15 |
16 | - The [second section](2_context_vectors.ipynb) covers pre-trained contextual embeddings [(colab)](https://colab.research.google.com/github/dpressel/dlss-tutorial/blob/master/2_context_vectors.ipynb)
17 | - The [third section](3_finetuning.ipynb) covers fine-tuning a pre-trained model [(colab)](https://colab.research.google.com/github/dpressel/dlss-tutorial/blob/master/3_finetuning.ipynb)
18 |
19 | ### Updates
20 |
21 | - *April 2022* If you are interested in learning how to build different Transformer architectures from the ground up, I have a [new set of tutorials](https://github.com/dpressel/tfs) with in-depth details and full implementations of several popular Transformer models. They show how to build models step by step, how to pretrain them, and how to use them for downstream tasks. An accompanying Python package puts all of the tutorial pieces together.
22 |
23 |
24 |
25 | - *July 2020* I have posted a set of [Colab tutorials](https://github.com/dpressel/mead-tutorials) using [MEAD](https://github.com/dpressel/mead-baseline), which is referenced in these tutorials. This new set of notebooks covers similar material, including transfer learning for classifiers and taggers, as well as training Transformer-based models from scratch with TPUs using the [MEAD API](https://github.com/dpressel/mead-baseline/tree/master/layers). MEAD makes it easy to train many powerful NLP models from a simple YAML configuration, and to extend the code with new models while comparing against strong baselines!
26 |
--------------------------------------------------------------------------------