├── LICENSE
├── README.md
├── datasets.py
├── dynamic_rnn.py
├── img
│   ├── backchar.jpg
│   ├── backtrace.jpg
│   ├── checks1.jpg
│   ├── checks2.jpg
│   ├── configs.jpg
│   ├── crfscores.jpg
│   ├── dci.jpg
│   ├── dunston.jpg
│   ├── end1.jpg
│   ├── end2.jpg
│   ├── forwchar.jpg
│   ├── framework.png
│   ├── highway.png
│   ├── ill.jpg
│   ├── in1.jpg
│   ├── in2.jpg
│   ├── loss.png
│   ├── model.jpg
│   ├── nothighway.png
│   ├── sorted.jpg
│   ├── tagscores.jpg
│   ├── tagscores0.jpg
│   ├── update.png
│   ├── vloss1.png
│   ├── vloss3.png
│   ├── vscore.png
│   └── word.jpg
├── inference.py
├── models.py
├── train.py
└── utils.py
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2018 Sagar Vinodababu
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | This is a **[PyTorch](https://pytorch.org) Tutorial to Sequence Labeling**.
2 |
3 | This is the second in [a series of tutorials](https://github.com/sgrvinod/Deep-Tutorials-for-PyTorch) I'm writing about _implementing_ cool models on your own with the amazing PyTorch library.
4 |
5 | Basic knowledge of PyTorch and recurrent neural networks is assumed.
6 |
7 | If you're new to PyTorch, first read [Deep Learning with PyTorch: A 60 Minute Blitz](https://pytorch.org/tutorials/beginner/deep_learning_60min_blitz.html) and [Learning PyTorch with Examples](https://pytorch.org/tutorials/beginner/pytorch_with_examples.html).
8 |
9 | Questions, suggestions, or corrections can be posted as issues.
10 |
11 | I'm using `PyTorch 0.4` in `Python 3.6`.
12 |
13 | ---
14 |
15 | **27 Jan 2020**: Working code for two new tutorials has been added — [Super-Resolution](https://github.com/sgrvinod/a-PyTorch-Tutorial-to-Super-Resolution) and [Machine Translation](https://github.com/sgrvinod/a-PyTorch-Tutorial-to-Machine-Translation)
16 |
17 | ---
18 |
19 | # Contents
20 |
21 | [***Objective***](https://github.com/sgrvinod/a-PyTorch-Tutorial-to-Sequence-Labeling#objective)
22 |
23 | [***Concepts***](https://github.com/sgrvinod/a-PyTorch-Tutorial-to-Sequence-Labeling#concepts)
24 |
25 | [***Overview***](https://github.com/sgrvinod/a-PyTorch-Tutorial-to-Sequence-Labeling#overview)
26 |
27 | [***Implementation***](https://github.com/sgrvinod/a-PyTorch-Tutorial-to-Sequence-Labeling#implementation)
28 |
29 | [***Training***](https://github.com/sgrvinod/a-PyTorch-Tutorial-to-Sequence-Labeling#training)
30 |
31 | [***Frequently Asked Questions***](https://github.com/sgrvinod/a-PyTorch-Tutorial-to-Sequence-Labeling#faqs)
32 |
33 | # Objective
34 |
35 | **To build a model that can tag each word in a sentence with entities, parts of speech, etc.**
36 |
37 | 
38 |
39 | We will be implementing the [_Empower Sequence Labeling with Task-Aware Neural Language Model_](https://arxiv.org/abs/1709.04109) paper. This is more advanced than most sequence tagging models, but you will learn many useful concepts – and it works extremely well. The authors' original implementation can be found [here](https://github.com/LiyuanLucasLiu/LM-LSTM-CRF).
40 |
41 | This model is special because it augments the sequence labeling task by training it _concurrently_ with language models.
42 |
43 | # Concepts
44 |
45 | * **Sequence Labeling**. duh.
46 |
47 | * **Language Models**. Language modeling is the task of predicting the next word or character in a sequence of words or characters. Neural language models achieve impressive results across a wide variety of NLP tasks like text generation, machine translation, image captioning, optical character recognition, and what have you.
48 |
49 | * **Character RNNs**. RNNs operating on individual characters in a text [are known](http://karpathy.github.io/2015/05/21/rnn-effectiveness/) to capture the underlying style and structure. In a sequence labeling task, they are especially useful since sub-word information can often yield important clues to an entity or tag.
50 |
51 | * **Multi-Task Learning**. Datasets available to train a model are often small. Creating annotations or handcrafted features to help your model along is not only cumbersome, but also frequently not adaptable to the diverse domains or settings in which your model may be useful. Sequence labeling, unfortunately, is a prime example. There is a way to mitigate this problem: jointly training multiple models that are joined at the hip increases the information available to each model, which can improve performance.
52 |
53 | * **Conditional Random Fields**. Discrete classifiers predict a class or label at a word. Conditional Random Fields (CRFs) can do you one better – they predict labels based on not just the word, but also the neighborhood. Which makes sense, because there _are_ patterns in a sequence of entities or labels. CRFs are widely used to model ordered information, be it for sequence labeling, gene sequencing, or even object detection and image segmentation in computer vision.
54 |
55 | * **Viterbi Decoding**. Since we're using CRFs, we're not so much predicting the right label at each word as we are predicting the right label _sequence_ for a word sequence. Viterbi Decoding is a way to do exactly this: find the optimal (highest-scoring) tag sequence from the scores computed by a Conditional Random Field.
56 |
57 | * **Highway Networks**. Fully connected layers are a staple in any neural network to transform or extract features at different locations. Highway Networks accomplish this, but also allow information to flow unimpeded across transformations. This makes deep networks easier to train, and very deep ones feasible.
58 |
59 | # Overview
60 |
61 | In this section, I will present an overview of this model. If you're already familiar with it, you can skip straight to the [Implementation](https://github.com/sgrvinod/a-PyTorch-Tutorial-to-Sequence-Labeling#implementation) section or the commented code.
62 |
63 | ### LM-LSTM-CRF
64 |
65 | The authors refer to the model as the _Language Model - Long Short-Term Memory - Conditional Random Field_ since it involves **co-training language models with an LSTM + CRF combination**.
66 |
67 | 
68 |
69 | This image from the paper thoroughly represents the entire model, but don't worry if it seems too complex at this time. We'll break it down to take a closer look at the components.
70 |
71 | ### Multi-Task Learning
72 |
73 | **Multi-task learning is when you simultaneously train a model on two or more tasks.**
74 |
75 | Usually we're only interested in _one_ of these tasks – in this case, the sequence labeling.
76 |
77 | But when layers in a neural network contribute towards performing multiple functions, they learn more than they would have if they had trained only on the primary task. This is because the information extracted at each layer is expanded to accommodate all tasks. When there is more information to work with, **performance on the primary task is enhanced**.
78 |
79 | Enriching existing features in this manner removes the need for using handcrafted features for sequence labeling.
80 |
81 | The **total loss** during multi-task learning is usually a linear combination of the losses on the individual tasks. The parameters of the combination can be fixed or learned as updateable weights.
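In symbols, and using our own notation rather than that of the original figure, a minimal way to write this is the following, where each $\mathcal{L}_i$ is the loss on one task and $\beta_i$ is its fixed or learned weight:

$$
\mathcal{L}_{\text{total}} = \sum_{i} \beta_i \, \mathcal{L}_i
$$

Setting every $\beta_i = 1$ reduces this to a plain sum of the task losses.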
82 |
83 |
84 |
85 |
86 |
87 | Since we're aggregating individual losses, you can see how upstream layers shared by multiple tasks would receive updates from all of them during backpropagation.
88 |
89 |
90 |
91 |
92 |
93 | The authors of the paper **simply add the losses** (`β=1`), and we will do the same.
94 |
95 | Let's take a look at the tasks that make up our model.
96 |
97 | **There are _three_**.
98 |
99 | 
100 |
101 | The first is a **forward character-level language model**, which leverages sub-word information to predict the next word.
102 |
103 | We do the same in the backward direction.
104 | 
105 |
106 | We _also_ use the outputs of these two **character-RNNs** as inputs to our **word-RNN** and **Conditional Random Field (CRF)** to perform our primary task of sequence labeling.
107 | 
108 |
109 | We're using sub-word information in our tagging task because it can be a powerful indicator of the tags, whether they're parts of speech or entities. For example, the model may learn that adjectives commonly end with "-y" or "-ful", or that places often end with "-land" or "-burg".
110 |
111 | But our sub-word features, viz. the outputs of the character-RNNs, are also enriched with _additional_ information: the knowledge needed to predict the next word in both the forward and backward directions, thanks to models 1 and 2.
112 |
113 | Therefore, our sequence tagging model uses both
114 | - **word-level information** in the form of word embeddings.
115 | - **character-level information** up to and including each word in both directions, enriched with the know-how required to be able to predict the next word in both directions.
116 |
117 | The Bidirectional LSTM/RNN encodes these features into new features at each word containing information about the word and its neighborhood, at both the word-level and the character-level. This forms the input to the Conditional Random Field.
118 |
119 | ### Conditional Random Field (CRF)
120 |
121 | Without a CRF, we would have simply used a single linear layer to transform the output of the Bidirectional LSTM into scores for each tag. These are known as **emission scores**, which are a representation of the likelihood of the word being a certain tag.
122 |
123 | A CRF calculates not only the emission scores but also the **transition scores**, which are the likelihood of a word being a certain tag _considering_ the previous word was a certain tag. Therefore the transition scores measure how likely it is to transition from one tag to another.
124 |
125 | If there are `m` tags, transition scores are stored in a matrix of dimensions `m, m`, where the rows represent the tag of the previous word and the columns represent the tag of the current word. A value in this matrix at position `i, j` is the **likelihood of transitioning from the `i`th tag at the previous word to the `j`th tag at the current word**. Unlike emission scores, transition scores are not defined for each word in the sentence. They are global.
126 |
127 | In our model, the CRF layer outputs the **aggregate of the emission and transition scores at each word**.
128 |
129 | For a sentence of length `L`, emission scores would be an `L, m` tensor. Since the emission scores at each word do not depend on the tag of the previous word, we create a new dimension like `L, _, m` and broadcast (copy) the tensor along this direction to get an `L, m, m` tensor.
130 |
131 | The transition scores are an `m, m` tensor. Since the transition scores are global and do not depend on the word, we create a new dimension like `_, m, m` and broadcast (copy) the tensor along this direction to get an `L, m, m` tensor.
132 |
133 | We can now **add them to get the total scores which are an `L, m, m` tensor**. A value at position `k, i, j` is the _aggregate_ of the emission score of the `j`th tag at the `k`th word and the transition score of the `j`th tag at the `k`th word considering the previous word was the `i`th tag.
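As a quick sanity check of the broadcasting described above, here is a pure-Python sketch (plain nested lists, no PyTorch, toy numbers of our own) that builds the `L, m, m` total scores from hypothetical emission and transition scores. In PyTorch, the same broadcast can be written as `emissions.unsqueeze(1) + transitions.unsqueeze(0)` for an `L, m` emission tensor and an `m, m` transition tensor.

```python
from typing import List

Matrix = List[List[float]]


def crf_scores(emissions: Matrix, transitions: Matrix) -> List[Matrix]:
    """Combine emission scores (L x m) and transition scores (m x m) into L x m x m.

    scores[k][i][j] = emissions[k][j] + transitions[i][j], i.e. the score of
    tag j at word k given that the previous word had tag i.
    """
    m = len(transitions)
    return [
        [[emissions[k][j] + transitions[i][j] for j in range(m)] for i in range(m)]
        for k in range(len(emissions))
    ]


# Toy example: L = 3 words, m = 2 tags.
emissions = [[1.0, 2.0], [0.5, 0.0], [3.0, 1.0]]
transitions = [[0.1, 0.2], [0.3, 0.4]]
scores = crf_scores(emissions, transitions)
print(len(scores), len(scores[0]), len(scores[0][0]))  # 3 2 2
```

The emission score is copied across every "previous tag" row, and the transition matrix is copied across every word, exactly as in the two broadcasting steps above.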
134 |
135 | For our example sentence `dunston checks in <end>`, if we assume there are 5 tags in total, the total scores would look like this –
136 |
137 | 
138 |
139 | But wait a minute, why are there `<start>` and `<end>` tags? While we're at it, why are we using an `<end>` token?
140 |
141 | ### About `<start>` and `<end>` tags, `<start>` and `<end>` tokens
142 |
143 | Since we're modeling the likelihood of transitioning between tags, we also include a `<start>` tag and an `<end>` tag in our tag-set.
144 |
145 | The transition score of a certain tag given that the previous tag was a `<start>` tag represents the **likelihood of this tag being the _first_ tag in a sentence**. For example, sentences usually start with articles (a, an, the), nouns, or pronouns.
146 |
147 | The transition score of the `<end>` tag considering a certain previous tag indicates the **likelihood of this previous tag being the _last_ tag in a sentence**.
148 |
149 | We will use an `<end>` token in all sentences and not a `<start>` token because the total CRF scores at each word are defined with respect to the _previous_ word's tag, which would make no sense at a `<start>` token.
150 |
151 | The correct tag of the `<end>` token is always the `<end>` tag. The "previous tag" of the first word is always the `<start>` tag.
152 |
153 | To illustrate, if our example sentence `dunston checks in <end>` had the tags `tag_2, tag_3, tag_3, <end>`, the values in red indicate the scores of these tags.
154 |
155 | 
156 |
157 | ### Highway Networks
158 |
159 | We generally use activated linear layers to transform and process outputs of an RNN/LSTM.
160 |
161 | If you're familiar with residual connections, you'll know that we can add the input before the transformation to the transformed output, creating a path for data to flow around the transformation.
162 |
163 | $$
164 | y = x + f(x), \quad \text{where } f \text{ is the activated linear transformation}
165 | $$
166 |
167 | This path is a shortcut for the flow of gradients during backpropagation, and aids in the convergence of deep networks.
168 |
169 | A **Highway Network** is similar to a residual network, but we use a **sigmoid-activated gate to determine the ratio in which the input and transformed output are combined**.
170 |
171 | $$
172 | g = \sigma(W_g x + b_g), \qquad y = g \odot f(x) + (1 - g) \odot x
173 | $$
174 |
175 | Since the character-RNNs contribute towards multiple tasks, **Highway Networks are used for extracting task-specific information** from their outputs.
176 |
177 | Therefore, we will use Highway Networks at **three locations** in our combined model –
178 |
179 | - to transform the output of the forward character-RNN to predict the next word.
180 | - to transform the output of the backward character-RNN to predict the next word (in the backward direction).
181 | - to transform the concatenated output of the forward and backward character-RNNs for use in the word-level RNN along with the word embedding.
182 |
183 | In a naive co-training setting, where we use the outputs of the character-RNNs directly for multiple tasks, i.e. without transformation, the discordance between the nature of the tasks could hurt performance.
184 |
185 | ### Putting it all together
186 |
187 | It might be clear by now what our combined network looks like.
188 |
189 | 
190 |
191 | ### Other configurations
192 |
193 | Progressively removing parts of our network results in progressively simpler networks that are used widely for sequence labeling.
194 |
195 | 
196 |
197 | #### (a) a Bi-LSTM + CRF sequence tagger that leverages sub-word information.
198 |
199 | There is no multi-task learning.
200 |
201 | Using character-level information without co-training still improves performance.
202 |
203 | #### (b) a Bi-LSTM + CRF sequence tagger.
204 |
205 | There is no multi-task learning or character-level processing.
206 |
207 | This configuration is used quite commonly in the industry and works well.
208 |
209 | #### (c) a Bi-LSTM sequence tagger.
210 |
211 | There is no multi-task learning, character-level processing, or CRFing. Note that a linear or Highway layer would replace the latter.
212 |
213 | This could work reasonably well, but a Conditional Random Field provides a sizeable performance boost.
214 |
215 | ### Viterbi Loss
216 |
217 | Remember, we're not using a linear layer that computes only the emission scores, so Cross Entropy over the tags at each word is not a suitable loss metric.
218 |
219 | Instead we will use the **Viterbi Loss** which, like Cross Entropy, is a "negative log likelihood". But here we will measure the likelihood of the gold (true) tag sequence, instead of the likelihood of the true tag at each word in the sequence. To find the likelihood, we consider the softmax over the scores of all tag sequences.
220 |
221 | The score of a tag sequence `t` is defined as the sum of the scores of the individual tags.
222 |
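223 | $$
224 | S(t) = \sum_{k=1}^{L} \text{scores}[k, t_{k-1}, t_k], \qquad t_0 = \langle\text{start}\rangle
225 | $$
226 |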
227 | For example, consider the CRF scores we looked at earlier –
228 |
229 | 
230 |
231 | The score of the tag sequence `tag_2, tag_3, tag_3, <end>` is the sum of the values in red, `4.85 + 6.79 + 3.85 + 3.52 = 19.01`.
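The same sum can be computed for any tag sequence by walking along it and picking, at each word, the score at (previous tag, current tag). Here is a minimal pure-Python sketch, assuming the `L, m, m` layout above and, hypothetically, that `<start>` has index `0`; the toy scores are our own.

```python
from typing import List, Sequence

Scores = List[List[List[float]]]  # scores[k][i][j]: word k, previous tag i, current tag j


def sequence_score(scores: Scores, tags: Sequence[int], start_tag: int = 0) -> float:
    """S(t): sum of the CRF scores along one tag sequence.

    `tags` has one tag index per word, the last word being the <end> token.
    The "previous tag" of the first word is the <start> tag.
    """
    total, prev = 0.0, start_tag
    for k, cur in enumerate(tags):
        total += scores[k][prev][cur]
        prev = cur
    return total


# Toy example: L = 2 words, m = 2 tags.
toy = [[[1.0, 2.0], [3.0, 4.0]], [[5.0, 6.0], [7.0, 8.0]]]
print(sequence_score(toy, [1, 0]))  # scores[0][0][1] + scores[1][1][0] = 2.0 + 7.0 = 9.0
```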
232 |
233 | **The Viterbi Loss is then defined as**
234 |
235 | $$
236 | \text{Viterbi Loss} = -\log \frac{e^{S(t_G)}}{\sum_{t \in T} e^{S(t)}}
237 | $$
238 |
239 | where `t_G` is the gold tag sequence and `T` represents the space of all possible tag sequences.
240 |
241 | This simplifies to –
242 |
243 | $$
244 | \text{Viterbi Loss} = \log \sum_{t \in T} e^{S(t)} - S(t_G)
245 | $$
246 |
247 | Therefore, the Viterbi Loss is the **difference between the log-sum-exp of the scores of all possible tag sequences and the score of the gold tag sequence**, i.e. `log-sum-exp(all scores) - gold score`.
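Computing the first term by listing every sequence in `T` is exponential in the sentence length. A standard remedy, which the text above does not spell out, is the forward algorithm: keep, for each tag, the log-sum-exp of the scores of all partial sequences ending in that tag, and extend it one word at a time. Below is a pure-Python sketch under our own assumptions (toy scores, hypothetical `START`/`END` indices, every intermediate tag allowed); it is not the repository's implementation. The toy check compares it against brute-force enumeration.

```python
import itertools
import math
from typing import List, Sequence

Scores = List[List[List[float]]]  # scores[k][i][j]: word k, previous tag i, current tag j


def log_sum_exp(values: Sequence[float]) -> float:
    """Numerically stable log(sum(exp(v) for v in values))."""
    peak = max(values)
    return peak + math.log(sum(math.exp(v - peak) for v in values))


def path_score(scores: Scores, tags: Sequence[int], start_tag: int) -> float:
    """S(t): sum of the CRF scores along one tag sequence."""
    prev, total = start_tag, 0.0
    for k, cur in enumerate(tags):
        total += scores[k][prev][cur]
        prev = cur
    return total


def log_partition(scores: Scores, start_tag: int, end_tag: int) -> float:
    """log-sum-exp of S(t) over all tag sequences t (forward algorithm).

    alpha[j] is the log-sum-exp of the scores of every partial sequence
    whose tag at the current word is j.
    """
    m = len(scores[0])
    alpha = [scores[0][start_tag][j] for j in range(m)]  # previous tag of word 1 is <start>
    for k in range(1, len(scores)):
        alpha = [log_sum_exp([alpha[i] + scores[k][i][j] for i in range(m)])
                 for j in range(m)]
    return alpha[end_tag]  # the final <end> token always carries the <end> tag


def viterbi_loss(scores: Scores, gold: Sequence[int], start_tag: int, end_tag: int) -> float:
    """log-sum-exp(all scores) - gold score."""
    return log_partition(scores, start_tag, end_tag) - path_score(scores, gold, start_tag)


# Toy check with L = 3 words (the last being <end>) and m = 3 tags.
START, END = 0, 2
toy = [[[0.5 * (k + i) - 0.3 * j for j in range(3)] for i in range(3)] for k in range(3)]
brute = log_sum_exp([path_score(toy, list(p) + [END], START)
                     for p in itertools.product(range(3), repeat=2)])
print(abs(brute - log_partition(toy, START, END)) < 1e-9)  # True
print(viterbi_loss(toy, [1, 1, END], START, END) > 0)  # True
```

Because the gold sequence is one of the terms inside the log-sum-exp, the loss is always positive and shrinks as the gold score dominates the others.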
248 |
249 | ### Viterbi Decoding
250 |
251 | **Viterbi Decoding** is a way to construct the optimal tag sequence, considering not only the likelihood of a tag at a certain word (emission scores), but also the likelihood of a tag considering the previous and next tags (transition scores).
252 |
253 | Once you generate CRF scores in an `L, m, m` tensor for a sequence of length `L`, you can start decoding.
254 |
255 | Viterbi Decoding is best understood with an example. Consider again –
256 |
257 | 
258 |
259 | For the first word in the sequence, the `previous_tag` can only be `<start>`. Therefore, only consider that one row.
260 |
261 | These are also the cumulative scores for each `current_tag` at the first word.
262 |
263 | 
264 |
265 | We will also keep track of the `previous_tag` that corresponds to each score. These are known as **backpointers**. At the first word, they are obviously all `<start>` tags.
266 |
267 | At the second word, **add the previous cumulative scores to the CRF scores of this word to generate new cumulative scores**.
268 |
269 | Note that the first word's `current_tag`s are the second word's `previous_tag`s. Therefore, broadcast the first word's cumulative score along the `current_tag` dimension.
270 |
271 | 
272 |
273 | For each `current_tag`, consider only the maximum of the scores from all `previous_tag`s.
274 |
275 | Store backpointers, i.e. the previous tags that correspond to these maximum scores.
276 |
277 | 
278 |
279 | Repeat this process at the third word.
280 |
281 | 
282 | 
283 |
284 | ...and the last word, which is the `<end>` token.
285 |
286 | Here, the only difference is you _already know_ the correct tag. You need the maximum score and backpointer **only for the `<end>` tag**.
287 |
288 | 
289 | 
290 |
291 | Now that you accumulated CRF scores across the entire sequence, **you trace _backwards_ to reveal the tag sequence with the highest possible score**.
292 |
293 | 
294 |
295 | We find that the optimal tag sequence for `dunston checks in <end>` is `tag_2 tag_3 tag_3 <end>`.
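The walkthrough above can be condensed into a short pure-Python sketch: accumulate scores word by word, keep the best previous tag for each current tag as a backpointer, and trace back from the `<end>` tag. The tag indices `START`/`END` and the toy scores are hypothetical, and for simplicity the sketch computes every tag at the last word before reading off `<end>`. The brute-force comparison at the end is only a sanity check.

```python
import itertools
from typing import List, Sequence, Tuple

Scores = List[List[List[float]]]  # scores[k][i][j]: word k, previous tag i, current tag j


def viterbi_decode(scores: Scores, start_tag: int, end_tag: int) -> Tuple[List[int], float]:
    """Return the highest-scoring tag sequence (ending in <end>) and its score."""
    m = len(scores[0])
    # First word: the previous tag can only be <start>, so only that row is used.
    cumulative = [scores[0][start_tag][j] for j in range(m)]
    backpointers = [[start_tag] * m]  # at the first word, every backpointer is <start>
    for k in range(1, len(scores)):
        new_cumulative, pointers = [], []
        for j in range(m):  # j is the current tag
            candidates = [cumulative[i] + scores[k][i][j] for i in range(m)]
            best_prev = max(range(m), key=candidates.__getitem__)
            new_cumulative.append(candidates[best_prev])
            pointers.append(best_prev)
        cumulative = new_cumulative
        backpointers.append(pointers)
    # The last word is the <end> token, so trace backwards starting from the <end> tag.
    tags = [end_tag]
    for k in range(len(scores) - 1, 0, -1):
        tags.append(backpointers[k][tags[-1]])
    tags.reverse()
    return tags, cumulative[end_tag]


START, END = 0, 2
toy = [[[((3 * k + 5 * i + 7 * j) % 4) - 1.5 for j in range(3)] for i in range(3)] for k in range(3)]
best_tags, best_score = viterbi_decode(toy, START, END)


def path_score(tags: Sequence[int]) -> float:
    """Score of one complete tag sequence under the toy scores."""
    prev, total = START, 0.0
    for k, cur in enumerate(tags):
        total += toy[k][prev][cur]
        prev = cur
    return total


brute = max(path_score(list(p) + [END]) for p in itertools.product(range(3), repeat=2))
print(best_tags[-1] == END, best_score == brute == path_score(best_tags))  # True True
```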
296 |
297 | # Implementation
298 |
299 | The sections below briefly describe the implementation.
300 |
301 | They are meant to provide some context, but **details are best understood directly from the code**, which is quite heavily commented.
302 |
303 | ### Dataset
304 |
305 | I use the CoNLL 2003 NER dataset to compare my results with the paper.
306 |
307 | Here's a snippet –
308 |
309 | ```
310 | -DOCSTART- -X- O O
311 |
312 | EU NNP I-NP I-ORG
313 | rejects VBZ I-VP O
314 | German JJ I-NP I-MISC
315 | call NN I-NP O
316 | to TO I-VP O
317 | boycott VB I-VP O
318 | British JJ I-NP I-MISC
319 | lamb NN I-NP O
320 | . . O O
321 | ```
322 |
323 | This dataset is not meant to be publicly distributed, although you may find it somewhere online.
324 |
325 | There are several public datasets online that you can use to train the model. These may not all be 100% human annotated, but they are sufficient.
326 |
327 | For NER tagging, you can use the [Groningen Meaning Bank](http://gmb.let.rug.nl/data.php).
328 |
329 | For POS tagging, NLTK has a small dataset available you can access with `nltk.corpus.treebank.tagged_sents()`.
330 |
331 | You would either have to convert it to the CoNLL 2003 NER data format, or modify the code referenced in the [Data Pipeline](https://github.com/sgrvinod/a-PyTorch-Tutorial-to-Sequence-Labeling#data-pipeline) section.
332 |
333 | ### Inputs to model
334 |
335 | We will need eight inputs.
336 |
337 | #### Words
338 |
339 | These are the word sequences that must be tagged.
340 |
341 | `dunston checks in`
342 |
343 | As discussed earlier, we will not use `<start>` tokens but we *will* need to use `<end>` tokens.
344 |
345 | `dunston, checks, in, <end>`
346 |
347 | Since we pass the sentences around as fixed-size tensors, we need to pad sentences (which are naturally of varying length) to the same length with `<pad>` tokens.
348 |
349 | `dunston, checks, in, <end>, <pad>, <pad>, <pad>, ...`
350 |
351 | Furthermore, we create a `word_map`, which is an index mapping for each word in the corpus, including the `<pad>` and `<end>` tokens. PyTorch, like other libraries, needs words encoded as indices to look up embeddings for them, or to identify their place in the predicted word scores.
352 |
353 | `4381, 448, 185, 4669, 0, 0, 0, ...`
354 |
355 | Therefore, **word sequences fed to the model must be an `Int` tensor of dimensions `N, L_w`** where `N` is the batch_size and `L_w` is the padded length of the word sequences (usually the length of the longest word sequence).
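A minimal sketch of this encoding, using a hypothetical `word_map` that reuses the indices from the example above; the `<unk>` index is made up for illustration (rare words are binned as `<unk>`, as described in the Data Pipeline section). Words missing from the map fall back to `<unk>`.

```python
from typing import Dict, List


def encode_words(sentences: List[List[str]], word_map: Dict[str, int]) -> List[List[int]]:
    """Append <end>, map words to indices (unknown words -> <unk>), pad with <pad> to L_w."""
    with_end = [s + ['<end>'] for s in sentences]
    L_w = max(len(s) for s in with_end)  # padded length = longest sequence
    return [
        [word_map.get(w, word_map['<unk>']) for w in s] + [word_map['<pad>']] * (L_w - len(s))
        for s in with_end
    ]


# Hypothetical map: indices taken from the example above, except '<unk>'.
word_map = {'<pad>': 0, '<unk>': 1, 'in': 185, 'checks': 448, 'dunston': 4381, '<end>': 4669}
print(encode_words([['dunston', 'checks', 'in'], ['checks', 'out']], word_map))
# [[4381, 448, 185, 4669], [448, 1, 4669, 0]]
```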
356 |
357 | #### Characters (Forward)
358 |
359 | These are the character sequences in the forward direction.
360 |
361 | `'d', 'u', 'n', 's', 't', 'o', 'n', ' ', 'c', 'h', 'e', 'c', 'k', 's', ' ', 'i', 'n', ' '`
362 |
363 | We need `<end>` tokens in the character sequences to match the `<end>` token in the word sequences. Since we're going to use character-level features at each word in the word sequence, we need character-level features at `<end>` in the word sequence.
364 |
365 | `'d', 'u', 'n', 's', 't', 'o', 'n', ' ', 'c', 'h', 'e', 'c', 'k', 's', ' ', 'i', 'n', ' ', <end>`
366 |
367 | We also need to pad them.
368 |
369 | `'d', 'u', 'n', 's', 't', 'o', 'n', ' ', 'c', 'h', 'e', 'c', 'k', 's', ' ', 'i', 'n', ' ', <end>, <pad>, <pad>, <pad>, ...`
370 |
371 | And encode them with a `char_map`.
372 |
373 | `29, 2, 12, 8, 7, 14, 12, 3, 6, 18, 1, 6, 21, 8, 3, 17, 12, 3, 60, 0, 0, 0, ...`
374 |
375 | Therefore, **forward character sequences fed to the model must be an `Int` tensor of dimensions `N, L_c`**, where `L_c` is the padded length of the character sequences (usually the length of the longest character sequence).
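A sketch of the forward construction, assuming sentences arrive as lists of words and a hypothetical `char_map` that reproduces the indices above. A real map would also need `<unk>` for rare characters; this sketch simply raises `KeyError` on unknown ones.

```python
from typing import Dict, List


def forward_chars(words: List[str]) -> List[str]:
    """Characters of the sentence with a space after every word, then the <end> token."""
    chars: List[str] = []
    for w in words:
        chars.extend(w)
        chars.append(' ')
    chars.append('<end>')
    return chars


def encode_chars(chars: List[str], char_map: Dict[str, int], L_c: int) -> List[int]:
    """Map characters to indices and pad with <pad> up to length L_c."""
    return [char_map[c] for c in chars] + [char_map['<pad>']] * (L_c - len(chars))


# Hypothetical char_map reusing the indices from the example above.
char_map = {'<pad>': 0, 'e': 1, 'u': 2, ' ': 3, 'c': 6, 't': 7, 's': 8, 'n': 12,
            'o': 14, 'i': 17, 'h': 18, 'k': 21, 'd': 29, '<end>': 60}
print(encode_chars(forward_chars(['dunston', 'checks', 'in']), char_map, L_c=22))
# [29, 2, 12, 8, 7, 14, 12, 3, 6, 18, 1, 6, 21, 8, 3, 17, 12, 3, 60, 0, 0, 0]
```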
376 |
377 | #### Characters (Backward)
378 |
379 | This would be processed the same as the forward sequence, but backward. (The `<end>` tokens would still be at the end, naturally.)
380 |
381 | `'n', 'i', ' ', 's', 'k', 'c', 'e', 'h', 'c', ' ', 'n', 'o', 't', 's', 'n', 'u', 'd', ' ', <end>, <pad>, <pad>, <pad>, ...`
382 |
383 | `12, 17, 3, 8, 21, 6, 1, 18, 6, 3, 12, 14, 7, 8, 12, 2, 29, 3, 60, 0, 0, 0, ...`
384 |
385 | Therefore, **backward character sequences fed to the model must be an `Int` tensor of dimensions `N, L_c`**.
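The backward sequence visits the words last-to-first and spells each one in reverse, still following every word with a space and keeping `<end>` last. A sketch under the same assumptions as before (hypothetical `char_map`, padding omitted since it works exactly as in the forward case):

```python
from typing import List


def backward_chars(words: List[str]) -> List[str]:
    """Reverse the sentence character by character, a space after every reversed word,
    keeping the <end> token at the end."""
    chars: List[str] = []
    for w in reversed(words):
        chars.extend(reversed(w))
        chars.append(' ')
    chars.append('<end>')
    return chars


# Hypothetical char_map reusing the indices from the example above.
char_map = {'<pad>': 0, 'e': 1, 'u': 2, ' ': 3, 'c': 6, 't': 7, 's': 8, 'n': 12,
            'o': 14, 'i': 17, 'h': 18, 'k': 21, 'd': 29, '<end>': 60}
encoded = [char_map[c] for c in backward_chars(['dunston', 'checks', 'in'])]
print(encoded)  # [12, 17, 3, 8, 21, 6, 1, 18, 6, 3, 12, 14, 7, 8, 12, 2, 29, 3, 60]
```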
386 |
387 | #### Character Markers (Forward)
388 |
389 | These markers are **positions in the character sequences** where we extract features to –
390 | - generate the next word in the language models, and
391 | - use as character-level features in the word-level RNN in the sequence labeler
392 |
393 | We will extract features at every space `' '` in the character sequence (i.e. right after each word), and at the `<end>` token.
394 |
395 | For the forward character sequence, we extract at –
396 |
397 | `7, 14, 17, 18`
398 |
399 | These are the positions of the spaces after `dunston`, `checks`, and `in`, and of the `<end>` token, respectively. Thus, we have **a marker for each word in the word sequence**, which makes sense. (In the language models, however, since we're predicting the _next_ word, we won't predict at the marker which corresponds to `<end>`.)
400 |
401 | We pad these with `0`s. It doesn't matter what we pad with as long as they're valid indices. (We will extract features at the pads, but we will not use them.)
402 |
403 | `7, 14, 17, 18, 0, 0, 0, ...`
404 |
405 | They are padded to the padded length of the word sequences, `L_w`.
406 |
407 | Therefore, **forward character markers fed to the model must be an `Int` tensor of dimensions `N, L_w`**.
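The marker positions follow directly from the word lengths, since each word contributes its characters plus one space. A sketch that reproduces the example above (`L_w = 7` is an arbitrary padded length chosen for illustration):

```python
from typing import List


def forward_markers(words: List[str], L_w: int) -> List[int]:
    """Positions of the space after each word, plus the <end> token, padded with 0s."""
    markers: List[int] = []
    pos = -1
    for w in words:
        pos += len(w) + 1  # skip over the word and land on the space that follows it
        markers.append(pos)
    markers.append(pos + 1)  # <end> comes right after the last space
    return markers + [0] * (L_w - len(markers))


print(forward_markers(['dunston', 'checks', 'in'], L_w=7))  # [7, 14, 17, 18, 0, 0, 0]
```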
408 |
409 | #### Character Markers (Backward)
410 |
411 | For the markers in the backward character sequences, we similarly find the positions of every space `' '` and the `<end>` token.
412 |
413 | We also ensure that these **positions are in the same _word_ order as in the forward markers**. This alignment makes it easier to concatenate features extracted from the forward and backward character sequences, and also prevents having to re-order the targets in the language models.
414 |
415 | `17, 9, 2, 18`
416 |
417 | These are the positions of the spaces after `notsnud`, `skcehc`, and `ni`, and of the `<end>` token, respectively.
418 |
419 | We pad with `0`s.
420 |
421 | `17, 9, 2, 18, 0, 0, 0, ...`
422 |
423 | Therefore, **backward character markers fed to the model must be an `Int` tensor of dimensions `N, L_w`**.
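The backward markers can be computed the same way while walking the words last-to-first, then reversed so they line up with the forward word order. A sketch reproducing the example above (`L_w = 7` is again arbitrary):

```python
from typing import List


def backward_markers(words: List[str], L_w: int) -> List[int]:
    """Space positions in the backward character sequence, listed in forward word order."""
    markers: List[int] = []
    pos = -1
    for w in reversed(words):  # the backward sequence visits words last-to-first
        pos += len(w) + 1
        markers.append(pos)
    end_pos = pos + 1  # <end> comes right after the last space
    markers.reverse()  # re-align with the forward word order
    markers.append(end_pos)
    return markers + [0] * (L_w - len(markers))


print(backward_markers(['dunston', 'checks', 'in'], L_w=7))  # [17, 9, 2, 18, 0, 0, 0]
```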
424 |
425 | #### Tags
426 |
427 | Let's assume the correct tags for `dunston, checks, in, <end>` are –
428 |
429 | `tag_2, tag_3, tag_3, <end>`
430 |
431 | We have a `tag_map` (containing the tags `<start>`, `tag_1`, `tag_2`, `tag_3`, `<end>`).
432 |
433 | Normally, we would just encode them directly (before padding) –
434 |
435 | `2, 3, 3, 4`
436 |
437 | These are `1D` encodings, i.e., tag positions in a `1D` tag map.
438 |
439 | But the **outputs of the CRF layer are `2D` `m, m` tensors** at each word. We would need to encode tag positions in these `2D` outputs.
440 |
441 | 
442 |
443 | The correct tag positions are marked in red.
444 |
445 | `(0, 2), (2, 3), (3, 3), (3, 4)`
446 |
447 | If we unroll these scores into a `1D` `m*m` tensor, then the tag positions in the unrolled tensor would be
448 |
449 | ```python
450 | tag_map[previous_tag] * len(tag_map) + tag_map[current_tag]
451 | ```
452 |
453 | Therefore, we encode `tag_2, tag_3, tag_3, <end>` as
454 |
455 | `2, 13, 18, 19`
456 |
457 | Note that you can retrieve the original `tag_map` indices by taking the modulus
458 |
459 | ```python
460 | t % len(tag_map)
461 | ```
462 |
463 | They will be padded to the padded length of the word sequences, `L_w`.
464 |
465 | Therefore, **tags fed to the model must be an `Int` tensor of dimensions `N, L_w`**.
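A sketch of this encoding with the five-tag `tag_map` from the example. Padding with `0` is our assumption: the text only says tags are padded to `L_w`, and padded positions are not computed over anyway because of the word lengths.

```python
from typing import Dict, List


def encode_tags(tags: List[str], tag_map: Dict[str, int], L_w: int) -> List[int]:
    """Encode each (previous tag, current tag) pair as one index into the unrolled m*m scores."""
    m = len(tag_map)
    encoded: List[int] = []
    prev = tag_map['<start>']  # the "previous tag" of the first word
    for tag in tags + ['<end>']:  # the <end> token always carries the <end> tag
        cur = tag_map[tag]
        encoded.append(prev * m + cur)
        prev = cur
    return encoded + [0] * (L_w - len(encoded))  # pad value 0 is an assumption


tag_map = {'<start>': 0, 'tag_1': 1, 'tag_2': 2, 'tag_3': 3, '<end>': 4}
encoded = encode_tags(['tag_2', 'tag_3', 'tag_3'], tag_map, L_w=6)
print(encoded)  # [2, 13, 18, 19, 0, 0]
print([t % len(tag_map) for t in encoded[:4]])  # current tags: [2, 3, 3, 4]
```

Integer division (`t // len(tag_map)`) recovers the previous tags in the same way.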
466 |
467 | #### Word Lengths
468 |
469 | These are the actual lengths of the word sequences including the `<end>` tokens. Since PyTorch supports dynamic graphs, we will compute only over these lengths and not over the `<pad>` tokens.
470 |
471 | Therefore, **word lengths fed to the model must be an `Int` tensor of dimensions `N`**.
472 |
473 | #### Character Lengths
474 |
475 | These are the actual lengths of the character sequences including the `<end>` tokens. Since PyTorch supports dynamic graphs, we will compute only over these lengths and not over the `<pad>` tokens.
476 |
477 | Therefore, **character lengths fed to the model must be an `Int` tensor of dimensions `N`**.
478 |
479 | ### Data Pipeline
480 |
481 | See `read_words_tags()` in [`utils.py`](https://github.com/sgrvinod/a-PyTorch-Tutorial-to-Sequence-Labeling/blob/master/utils.py).
482 |
483 | This reads the input files in the CoNLL 2003 format, and extracts the word and tag sequences.
484 |
485 | See `create_maps()` in [`utils.py`](https://github.com/sgrvinod/a-PyTorch-Tutorial-to-Sequence-Labeling/blob/master/utils.py).
486 |
487 | Here, we create encoding maps for words, characters, and tags. We bin rare words and characters as `<unk>`s (unknowns).
488 |
489 | See `create_input_tensors()` in [`utils.py`](https://github.com/sgrvinod/a-PyTorch-Tutorial-to-Sequence-Labeling/blob/master/utils.py).
490 |
491 | We generate the eight inputs detailed in the [Inputs to Model](https://github.com/sgrvinod/a-PyTorch-Tutorial-to-Sequence-Labeling#inputs-to-model) section.
492 |
493 | See `load_embeddings()` in [`utils.py`](https://github.com/sgrvinod/a-PyTorch-Tutorial-to-Sequence-Labeling/blob/master/utils.py).
494 |
495 | We load pre-trained embeddings, with the option to expand the `word_map` to include out-of-corpus words present in the embedding vocabulary. Note that this may also include rare in-corpus words that were binned as `<unk>`s earlier.
496 |
497 | See `WCDataset` in [`datasets.py`](https://github.com/sgrvinod/a-PyTorch-Tutorial-to-Sequence-Labeling/blob/master/datasets.py).
498 |
499 | This is a subclass of PyTorch [`Dataset`](https://pytorch.org/docs/master/data.html#torch.utils.data.Dataset). It needs a `__len__` method defined, which returns the size of the dataset, and a `__getitem__` method which returns the `i`th set of the eight inputs to the model.
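The contract is small enough to show without PyTorch. Below is a hypothetical sketch (`WordCharDataset` is our name, not the repository's) that stores eight parallel inputs and implements the two methods; a real version would subclass `torch.utils.data.Dataset` and return tensors.

```python
from typing import Sequence, Tuple


class WordCharDataset:
    """Hypothetical, PyTorch-free stand-in for a map-style dataset like WCDataset.

    Holds the eight per-sentence inputs as parallel sequences and returns
    the i-th set of them.
    """

    def __init__(self, *inputs: Sequence) -> None:
        if len({len(x) for x in inputs}) != 1:
            raise ValueError('every input must have one entry per sentence')
        self.inputs = inputs

    def __len__(self) -> int:
        return len(self.inputs[0])

    def __getitem__(self, i: int) -> Tuple:
        return tuple(x[i] for x in self.inputs)


# Eight parallel toy inputs for two sentences.
dataset = WordCharDataset(*[[f'input{k}_sent{n}' for n in range(2)] for k in range(8)])
print(len(dataset), len(dataset[1]))  # 2 8
```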
500 |
501 | The `Dataset` will be used by a PyTorch [`DataLoader`](https://pytorch.org/docs/master/data.html#torch.utils.data.DataLoader) in `train.py` to create and feed batches of data to the model for training or validation.
502 |
503 | ### Highway Networks
504 |
505 | See `Highway` in [`models.py`](https://github.com/sgrvinod/a-PyTorch-Tutorial-to-Sequence-Labeling/blob/master/models.py).
506 |
507 | A **transform** is a ReLU-activated linear transformation of the input. A **gate** is a sigmoid-activated linear transformation of the input. Note that **both transformations must be the same size as the input**, to allow for adding the input in a residual connection.
508 |
509 | The `num_layers` attribute specifies how many transform-gate-residual-connection operations we perform in series. Usually just one is sufficient.
510 |
511 | We store the requisite number of transform and gate layers in separate `ModuleList()`s, and use a `for` loop to perform successive operations.
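The loop described above can be sketched in plain Python, using lists instead of tensors and explicit weights instead of `nn.Linear` layers in a `ModuleList` (all names are ours). Forcing the gate bias very negative or very positive shows the two extremes of the highway combination.

```python
import math
from typing import List, Tuple

Vector = List[float]
Matrix = List[List[float]]
Layer = Tuple[Matrix, Vector, Matrix, Vector]  # (W_transform, b_transform, W_gate, b_gate)


def linear(W: Matrix, b: Vector, x: Vector) -> Vector:
    """Affine map W x + b."""
    return [sum(w * xj for w, xj in zip(row, x)) + bi for row, bi in zip(W, b)]


def highway(x: Vector, layers: List[Layer]) -> Vector:
    """Apply num_layers = len(layers) transform-gate-residual operations in series.

    Every weight matrix is d x d, so each output has the same size as its input.
    """
    for W_t, b_t, W_g, b_g in layers:
        transformed = [max(0.0, v) for v in linear(W_t, b_t, x)]  # ReLU transform
        gate = [1.0 / (1.0 + math.exp(-v)) for v in linear(W_g, b_g, x)]  # sigmoid gate
        x = [g * t + (1.0 - g) * xi for g, t, xi in zip(gate, transformed, x)]
    return x


W_t, b_t = [[1.0, -1.0], [0.5, 2.0]], [0.0, 0.1]
W_zero = [[0.0, 0.0], [0.0, 0.0]]
x = [0.3, -0.2]
closed = highway(x, [(W_t, b_t, W_zero, [-50.0, -50.0])])  # gate ~ 0: input passes through
opened = highway(x, [(W_t, b_t, W_zero, [50.0, 50.0])])    # gate ~ 1: transform dominates
```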
512 |
513 | ### Language Models
514 |
515 | See `LM_LSTM_CRF` in [`models.py`](https://github.com/sgrvinod/a-PyTorch-Tutorial-to-Sequence-Labeling/blob/master/models.py).
516 |
517 | At the very outset, we **sort the forward and backward character sequences by decreasing lengths**. This is required to use [`pack_padded_sequence()`](https://pytorch.org/docs/master/nn.html#torch.nn.utils.rnn.pack_padded_sequence) in order for the LSTM to compute over only the valid timesteps, i.e. the true length of the sequences.
518 |
519 | Remember to also sort all other tensors in the same order.
520 |
521 | See [`dynamic_rnn.py`](https://github.com/sgrvinod/a-PyTorch-Tutorial-to-Sequence-Labeling/blob/master/dynamic_rnn.py) for an illustration of how `pack_padded_sequence()` can be used to take advantage of PyTorch's dynamic graphing and batching capabilities so that we don't process the pads. It flattens the sorted sequences by timestep while ignoring the pads, and the **LSTM computes over only the effective batch size `N_t` at each timestep**.
522 |
523 | 
524 |
525 | The **sorting allows the top `N_t` at any timestep to align with the outputs from the previous step**. At the third timestep, for example, we process only the top 5 sequences, using the top 5 outputs from the previous step. Except for the sorting, all of this is handled internally by PyTorch, but it's still very useful to understand what `pack_padded_sequence()` does so we can use it in other scenarios to achieve similar ends. (See the related question about handling variable length sequences in the [FAQs](https://github.com/sgrvinod/a-PyTorch-Tutorial-to-Sequence-Labeling#faqs) section.)
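To make the sorting and the effective batch sizes concrete, here is a pure-Python sketch (names are ours): sort indices by decreasing length, keep the inverse permutation to restore the original order afterwards, and count how many sequences are still active at each timestep. These counts correspond to what PyTorch stores in a `PackedSequence`'s `batch_sizes`.

```python
from typing import List, Tuple


def sort_by_length(lengths: List[int]) -> Tuple[List[int], List[int]]:
    """Return (order, inverse): order sorts by decreasing length, inverse undoes it."""
    order = sorted(range(len(lengths)), key=lambda i: lengths[i], reverse=True)
    inverse = [0] * len(order)
    for position, i in enumerate(order):
        inverse[i] = position
    return order, inverse


def batch_sizes(sorted_lengths: List[int]) -> List[int]:
    """Effective batch size N_t at each timestep t for decreasing-length sequences."""
    return [sum(1 for n in sorted_lengths if n > t) for t in range(sorted_lengths[0])]


lengths = [2, 5, 4]
order, inverse = sort_by_length(lengths)
sorted_lengths = [lengths[i] for i in order]  # [5, 4, 2]
print(batch_sizes(sorted_lengths))  # [3, 3, 2, 2, 1]
print([sorted_lengths[inverse[i]] for i in range(3)])  # [2, 5, 4]
```

The same `order` must be applied to every other per-sentence tensor, and `inverse` restores the original order when needed.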
526 |
527 | Upon sorting, we **apply the forward and backward LSTMs on the forward and backward `packed_sequences`** respectively. We use [`pad_packed_sequence()`](https://pytorch.org/docs/master/nn.html#torch.nn.utils.rnn.pad_packed_sequence) to unflatten and re-pad the outputs.
528 |
529 | We **extract only the outputs at the forward and backward character markers** with [`gather`](https://pytorch.org/docs/master/torch.html#torch.gather). This function is very useful for extracting only certain indices from a tensor that are specified in a separate tensor.
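Here is a sketch of the marker extraction with `torch.gather`. The index must have as many dimensions as the input, so the marker positions are expanded along the hidden dimension. That way each marker selects a whole output vector. Shapes and marker values are hypothetical.

```python
import torch

# Hypothetical character-LSTM outputs: N=2 sentences, T=7 character timesteps, hidden size H=3
N, T, H = 2, 7, 3
char_out = torch.arange(N * T * H, dtype=torch.float).view(N, T, H)

# Character positions of the word markers, padded to the max number of words
markers = torch.tensor([[2, 4, 6],
                        [1, 3, 3]])  # 2nd sentence has 2 words; its last marker is padding

# Expand markers to (N, num_words, H) so gather picks full hidden vectors along dim 1
index = markers.unsqueeze(2).expand(N, markers.size(1), H)
at_markers = torch.gather(char_out, 1, index)  # (N, num_words, H)
```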
530 |
531 | These **extracted outputs are processed by the forward and backward Highway layers** before applying a **linear layer to compute scores over the vocabulary** for predicting the next word at each marker. We do this only during training, since it makes no sense to perform language modeling for multi-task learning during validation or inference. The `training` attribute of any model is set with `model.train()` or `model.eval()` in `train.py`. (Note that this flag is primarily used to switch dropout and batch-norm layers in a PyTorch model between their training and inference behavior.)
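The training-only branch can be sketched with a toy module that checks `self.training` in its `forward()`. `TinyLMHead` is a hypothetical stand-in for the language-model part of `LM_LSTM_CRF`, not the tutorial's actual class.

```python
import torch
from torch import nn


class TinyLMHead(nn.Module):
    """Hypothetical module that produces LM scores only in training mode."""

    def __init__(self, hidden_dim, lm_vocab_size):
        super().__init__()
        self.lm_fc = nn.Linear(hidden_dim, lm_vocab_size)
        self.dropout = nn.Dropout(p=0.5)

    def forward(self, feats):
        feats = self.dropout(feats)  # active after train(), identity after eval()
        if self.training:
            return self.lm_fc(feats)  # scores over the LM vocabulary
        return None  # no language modeling at validation/inference


model = TinyLMHead(hidden_dim=4, lm_vocab_size=10)
x = torch.randn(2, 5, 4)

model.train()
train_scores = model(x)
model.eval()
eval_scores = model(x)
```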
532 |
533 | ### Sequence Labeling Model
534 |
535 | See `LM_LSTM_CRF` in [`models.py`](https://github.com/sgrvinod/a-PyTorch-Tutorial-to-Sequence-Labeling/blob/master/models.py) (continued).
536 |
537 | We also **sort the word sequences by decreasing lengths**, because there may not always be a correlation between the lengths of the word sequences and the character sequences.
538 |
539 | Remember to also sort all other tensors in the same order.
540 |
541 | We **concatenate the forward and backward character LSTM outputs at the markers, and run it through the third Highway layer**. This will extract the sub-word information at each word which we will use for sequence labeling.
542 |
543 | We **concatenate this result with the word embeddings, and compute BLSTM outputs** over the `packed_sequence`.
544 |
545 | Upon re-padding with `pad_packed_sequence()`, we have the features we need to feed to the CRF layer.
546 |
547 | ### Conditional Random Field (CRF)
548 |
549 | See `CRF` in [`models.py`](https://github.com/sgrvinod/a-PyTorch-Tutorial-to-Sequence-Labeling/blob/master/models.py).
550 |
551 | You may find this layer is surprisingly straightforward considering the value it adds to our model.
552 |
553 | A linear layer is used to transform the outputs from the BLSTM to scores for each tag, which are the **emission scores**.
554 |
555 | A single tensor is used to hold the **transition scores**. This tensor is a [`Parameter`](https://pytorch.org/docs/master/nn.html#torch.nn.Parameter) of the model, which means it is updateable during backpropagation, just like the weights of the other layers.
556 |
557 | To find the CRF scores, **compute the emission scores at each word and add them to the transition scores**, after broadcasting both as described in the [CRF Overview](https://github.com/sgrvinod/a-PyTorch-Tutorial-to-Sequence-Labeling#conditional-random-field-crf).
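The broadcasting step fits in one line. Emission scores of shape `(N, L, m)` get a singleton `previous_tag` dimension. The `(m, m)` transition matrix gets singleton batch and time dimensions. Each entry then becomes `crf[n, t, i, j] = emission[n, t, j] + transition[i, j]`. The sizes and random values below are arbitrary.

```python
import torch

N, L, m = 2, 3, 4  # batch size, word timesteps, number of tags (arbitrary)
emission = torch.randn(N, L, m)  # score of each current_tag at each word
transition = torch.randn(m, m)  # transition[i, j]: previous_tag i -> current_tag j

# (N, L, 1, m) + (1, 1, m, m) -> (N, L, m, m), rows = previous_tag, columns = current_tag
crf_scores = emission.unsqueeze(2) + transition.unsqueeze(0).unsqueeze(0)
```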
558 |
559 | ### Viterbi Loss
560 |
561 | See `ViterbiLoss` in [`models.py`](https://github.com/sgrvinod/a-PyTorch-Tutorial-to-Sequence-Labeling/blob/master/models.py).
562 |
563 | We established in the [Viterbi Loss Overview](https://github.com/sgrvinod/a-PyTorch-Tutorial-to-Sequence-Labeling#viterbi-loss) that we want to minimize the **difference between the log-sum-exp of the scores of all possible valid tag sequences and the score of the gold tag sequence**, i.e. `log-sum-exp(all scores) - gold score`.
564 |
565 | We sum the CRF scores of each true tag as described earlier to calculate the **gold score**.
566 |
567 | Remember how we encoded tag sequences with their positions in the unrolled CRF scores? We extract the scores at these positions with `gather()` and eliminate the pads with `pack_padded_sequence()` before summing.
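A sketch of the gold score follows. Each `(previous_tag, current_tag)` pair is encoded as `previous_tag * m + current_tag`, with `<start>` as the previous tag of the first word. The sketch then gathers from the `(N, L, m*m)` view of the CRF scores and drops pads by packing. The tag count, the `<start>` index, and the tag sequences are hypothetical.

```python
import torch
from torch.nn.utils.rnn import pack_padded_sequence

torch.manual_seed(0)
m, START = 4, 2  # hypothetical number of tags and index of the <start> tag
N, L = 2, 3
crf_scores = torch.randn(N, L, m, m)  # (N, L, previous_tag, current_tag)
tags = torch.tensor([[0, 1, 3],
                     [1, 3, 0]])  # gold current tags; row 1 has length 2, last entry is padding
lengths = torch.tensor([3, 2])  # sorted in decreasing order

# Position of each (previous_tag, current_tag) pair in the unrolled m*m scores
prev = torch.cat([torch.full((N, 1), START, dtype=torch.long), tags[:, :-1]], dim=1)
targets = prev * m + tags  # (N, L)

# Pick the score at each target position, then drop pads before summing
unrolled = crf_scores.view(N, L, m * m)
at_targets = torch.gather(unrolled, 2, targets.unsqueeze(2)).squeeze(2)  # (N, L)
gold_score = pack_padded_sequence(at_targets, lengths, batch_first=True).data.sum()
```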
568 |
569 | Finding the **log-sum-exp of the scores of all possible sequences** is slightly trickier. We use a `for` loop to iterate over the timesteps. At each timestep, we **accumulate scores for each `current_tag`** by –
570 |
571 | - **adding the CRF scores at this timestep to the accumulated scores from the previous timestep** to find the accumulated score for each `current_tag` for each `previous_tag`. We do this at only the effective batch size, i.e. for sequences that haven't completed yet. (Our sequences are still sorted by decreasing word lengths, from the `LM-LSTM-CRF` model.)
572 | - **for each `current_tag`, compute the log-sum-exp over the `previous_tag`s** to find the new accumulated scores at each `current_tag`.
573 |
574 | After computing over the variable lengths of all sequences, we are left with a tensor of dimensions `N, m`, where `m` is the number of (current) tags. These are the log-sum-exp accumulated scores over all possible sequences ending in each of the `m` tags. However, since valid sequences can only end with the `<end>` tag, **sum over only the `<end>` column to find the log-sum-exp of the scores of all possible valid sequences**.
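Here is a standalone sketch of that accumulation loop, the forward algorithm in log space. At each timestep it updates only the first `N_t` rows. At the end it keeps the `<end>` column per sentence. The tag count and the `<start>`/`<end>` indices are hypothetical, and this is not the exact `ViterbiLoss` code. The accompanying test checks it against brute-force enumeration of every tag sequence.

```python
import torch

m, START, END = 4, 2, 3  # hypothetical number of tags and <start>/<end> indices
lengths = [3, 2]  # sorted in decreasing order
N, L = len(lengths), max(lengths)
torch.manual_seed(0)
crf_scores = torch.randn(N, L, m, m)  # (N, L, previous_tag, current_tag)

scores_upto_t = torch.zeros(N, m)  # accumulated scores at each current_tag
for t in range(L):
    n_t = sum(l > t for l in lengths)  # effective batch size at this timestep
    if t == 0:
        # The only valid previous_tag for the first word is <start>
        scores_upto_t[:n_t] = crf_scores[:n_t, 0, START, :]
    else:
        # (n_t, prev, 1) + (n_t, prev, cur), then log-sum-exp over prev -> (n_t, cur)
        scores_upto_t[:n_t] = torch.logsumexp(
            scores_upto_t[:n_t].unsqueeze(2) + crf_scores[:n_t, t], dim=1)

# Valid sequences end in <end>, so keep only that column for each sentence
all_paths = scores_upto_t[:, END]  # (N,)
```

Summing `all_paths` over the batch and subtracting the summed gold score gives the loss described next.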
575 |
576 | We find the difference, `log-sum-exp(all scores) - gold score`.
577 |
578 | ### Viterbi Decoding
579 |
580 | See `ViterbiDecoder` in [`inference.py`](https://github.com/sgrvinod/a-PyTorch-Tutorial-to-Sequence-Labeling/blob/master/inference.py).
581 |
582 | This implements the process described in the [Viterbi Decoding Overview](https://github.com/sgrvinod/a-PyTorch-Tutorial-to-Sequence-Labeling#viterbi-decoding).
583 |
584 | We accumulate scores in a `for` loop in a manner similar to what we did in `ViterbiLoss`, except here we **find the maximum of the `previous_tag` scores for each `current_tag`**, instead of computing the log-sum-exp. We also **keep track of the `previous_tag` that corresponds to this maximum score** in a backpointer tensor.
585 |
586 | We **pad the backpointer tensor with `<end>` tags** because this allows us to trace backwards over the pads, eventually arriving at the _actual_ `<end>` tag, whereupon the _actual_ **backtracing** begins.
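The sketch below is a self-contained version of the same decoding idea. It takes a max over `previous_tag` instead of a log-sum-exp, stores the argmax as a backpointer, and pads with `<end>`. It then traces back from `<end>` at the last padded timestep. It mirrors the approach of `ViterbiDecoder` but is not a copy of it. The sizes and tag indices are hypothetical. In the result, `decoded[n, 0]` is `<start>`, and the best tags for sentence `n` are `decoded[n, 1:l]` followed by `<end>`.

```python
import torch

m, START, END = 4, 2, 3  # hypothetical number of tags and <start>/<end> indices
lengths = [3, 2]  # sorted in decreasing order
N, L = len(lengths), max(lengths)
torch.manual_seed(0)
crf_scores = torch.randn(N, L, m, m)  # (N, L, previous_tag, current_tag)

scores_upto_t = torch.zeros(N, m)
# Past a sentence's end, every backpointer leads to <end>
backpointers = torch.full((N, L, m), END, dtype=torch.long)
for t in range(L):
    n_t = sum(l > t for l in lengths)  # effective batch size at this timestep
    if t == 0:
        scores_upto_t[:n_t] = crf_scores[:n_t, 0, START, :]
        backpointers[:n_t, 0, :] = START
    else:
        # Best previous_tag (and its accumulated score) for each current_tag
        scores_upto_t[:n_t], backpointers[:n_t, t, :] = torch.max(
            scores_upto_t[:n_t].unsqueeze(2) + crf_scores[:n_t, t], dim=1)

# Trace back, starting from <end> at the last padded timestep
pointer = torch.full((N, 1), END, dtype=torch.long)
decoded = torch.zeros(N, L, dtype=torch.long)
for t in reversed(range(L)):
    decoded[:, t] = torch.gather(backpointers[:, t, :], 1, pointer).squeeze(1)
    pointer = decoded[:, t].unsqueeze(1)
```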
587 |
588 | # Training
589 |
590 | See [`train.py`](https://github.com/sgrvinod/a-PyTorch-Tutorial-to-Sequence-Labeling/blob/master/train.py).
591 |
592 | The parameters for the model (and training it) are at the beginning of the file, so you can easily check or modify them should you wish to.
593 |
594 | To **train your model from scratch**, simply run this file –
595 |
596 | `python train.py`
597 |
598 | To **resume training at a checkpoint**, point to the corresponding file with the `checkpoint` parameter at the beginning of the code.
599 |
600 | Note that we perform validation at the end of every training epoch.
601 |
602 | ### Trimming Batch Inputs
603 |
604 | You will notice we **trim the inputs at each batch to the maximum sequence lengths in that batch**. This is so we don't have more pads in each batch than we actually need.
605 |
606 | But why? Although the RNNs in our model don't compute over the pads, **the linear layers still do**. It's pretty straightforward to change this – see the related question about handling variable length sequences in the [FAQs](https://github.com/sgrvinod/a-PyTorch-Tutorial-to-Sequence-Labeling#faqs) section.
607 |
608 | For this tutorial, I figured a little extra computation over some pads was worth the straightforwardness of not having to perform a slew of operations – Highway, CRF, other linear layers, concatenations – on a `packed_sequence`.
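Trimming amounts to slicing the padded time dimension down to the longest sequence in the batch. The names `wmaps` and `wmap_lengths` are borrowed from `datasets.py`, but the values are a hypothetical toy batch. The character maps and markers would be trimmed the same way with their own lengths.

```python
import torch

# Hypothetical batch padded to a dataset-wide maximum of 7 words (pad index 0)
wmaps = torch.tensor([[4, 8, 2, 0, 0, 0, 0],
                      [5, 3, 0, 0, 0, 0, 0]])
wmap_lengths = torch.tensor([3, 2])

# Keep only as many timesteps as the longest sequence in this batch
max_word_len = int(wmap_lengths.max())
wmaps = wmaps[:, :max_word_len]
```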
609 |
610 | ### Loss
611 |
612 | In the multi-task scenario, we have chosen to sum the Cross Entropy losses from the two language modeling tasks and the Viterbi Loss from the sequence labeling task.
613 |
614 | Even though we are **minimizing the sum of these losses**, we are actually only interested in minimizing the Viterbi Loss _by virtue of minimizing the sum of these losses_. It is the Viterbi Loss which reflects performance on the primary task.
615 |
616 | We use `pack_padded_sequence()` to eliminate pads wherever necessary.
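For a language-model loss, packing both the scores and the targets flattens them to the valid timesteps only. Cross Entropy therefore never sees a pad. The shapes and random values below are hypothetical. In the multi-task setup, the forward and backward losses computed this way would be added to the Viterbi Loss.

```python
import torch
from torch import nn
from torch.nn.utils.rnn import pack_padded_sequence

torch.manual_seed(0)
N, L, V = 3, 4, 6  # batch, timesteps, LM vocabulary size (hypothetical)
lm_scores = torch.randn(N, L, V)  # language-model scores at each word
lm_targets = torch.randint(0, V, (N, L))  # next-word targets (pads hold arbitrary values)
lengths = torch.tensor([4, 3, 1])  # sorted in decreasing order

# Packing keeps only the sum(lengths) valid positions, flattened by timestep
packed_scores = pack_padded_sequence(lm_scores, lengths, batch_first=True).data  # (8, V)
packed_targets = pack_padded_sequence(lm_targets, lengths, batch_first=True).data  # (8,)

ce_loss = nn.CrossEntropyLoss()(packed_scores, packed_targets)
```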
617 |
618 | ### F1 Score
619 |
620 | Like in the paper, we use the **macro-averaged F1 score as the criterion for early-stopping**. Naturally, computing the F1 score requires Viterbi Decoding the CRF scores to generate our optimal tag sequences.
621 |
622 | We use `pack_padded_sequence()` to eliminate pads wherever necessary.
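Macro averaging takes the unweighted mean of per-tag F1 scores, so rare tags count as much as frequent ones. The plain-Python sketch below computes it over flattened, pad-free tag lists. scikit-learn's `f1_score(..., average='macro')` computes the same quantity. The tag strings are a hypothetical example.

```python
def macro_f1(y_true, y_pred):
    """Unweighted mean of per-tag F1 over every tag seen in y_true or y_pred."""
    labels = sorted(set(y_true) | set(y_pred))
    f1s = []
    for label in labels:
        tp = sum(t == label and p == label for t, p in zip(y_true, y_pred))
        fp = sum(t != label and p == label for t, p in zip(y_true, y_pred))
        fn = sum(t == label and p != label for t, p in zip(y_true, y_pred))
        precision = tp / (tp + fp) if tp + fp else 0.0
        recall = tp / (tp + fn) if tp + fn else 0.0
        f1s.append(2 * precision * recall / (precision + recall) if precision + recall else 0.0)
    return sum(f1s) / len(f1s)


y_true = ['B-PER', 'I-PER', 'O', 'O', 'B-LOC']
y_pred = ['B-PER', 'O', 'O', 'O', 'B-LOC']
score = macro_f1(y_true, y_pred)  # tags B-LOC, B-PER, I-PER, O -> (1 + 1 + 0 + 0.8) / 4
```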
623 |
624 | ### Remarks
625 |
626 | I have followed the parameters in the authors' implementation as closely as possible.
627 |
628 | I used a batch size of `10` sentences. I employed Stochastic Gradient Descent with momentum. The learning rate was decayed every epoch. I used 100D [GloVe](https://nlp.stanford.edu/projects/glove/) pretrained embeddings without fine-tuning.
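Here is a sketch of that optimization setup under stated assumptions. Random vectors stand in for 100D GloVe and are loaded frozen via `nn.Embedding.from_pretrained(..., freeze=True)`. A single linear layer stands in for the rest of the model. SGD with momentum uses a per-epoch decay implemented with `LambdaLR`. The learning rate, momentum, decay rate and inverse-time decay form are hypothetical placeholders, not the values or schedule used in `train.py`.

```python
import torch
from torch import nn

torch.manual_seed(0)
glove = torch.randn(50, 100)  # stand-in for 100D GloVe vectors (random here)
word_embeds = nn.Embedding.from_pretrained(glove, freeze=True)  # no fine-tuning
tagger = nn.Linear(100, 5)  # stand-in for the rest of the model

# Only trainable parameters go to the optimizer; the frozen embeddings are skipped
params = [p for p in list(word_embeds.parameters()) + list(tagger.parameters()) if p.requires_grad]
optimizer = torch.optim.SGD(params, lr=0.015, momentum=0.9)  # hypothetical values

# Hypothetical inverse-time decay, stepped once per epoch: lr_e = lr_0 / (1 + decay * e)
decay = 0.05
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda epoch: 1.0 / (1.0 + decay * epoch))

for epoch in range(3):
    x = torch.randint(0, 50, (4,))  # a toy "batch" of word indices
    loss = tagger(word_embeds(x)).sum()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    scheduler.step()  # decay after each epoch
```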
629 |
630 | It took about 80s to train one epoch on a Titan X (Pascal).
631 |
632 | The F1 score on the validation set hit `91%` around epoch 50, and peaked at `91.6%` on epoch 171. I ran it for a total of 200 epochs. This is pretty close to the results in the paper.
633 |
634 | ### Model Checkpoint
635 |
636 | You can download this pretrained model [here](https://drive.google.com/open?id=1P-w-s6QbsixcGnm3UjPMkgGpuz684kiY).
637 |
638 | # FAQs
639 |
640 | __How do we decide if we need `<start>` and `<end>` tokens for a model that uses sequences?__
641 |
642 | If this seems confusing at first, it will easily resolve itself when you think about the requirements of the model you are planning to train.
643 |
644 | For sequence labeling with a CRF, you need the `<end>` token (_or_ the `<start>` token; see next question) because of how the CRF scores are structured.
645 |
646 | In my other tutorial on image captioning, I used _both_ `<start>` and `<end>` tokens. The model needed to start decoding _somewhere_, and learn to recognize _when_ to stop decoding during inference.
647 |
648 | If you're performing text classification, you would need neither.
649 |
650 | ---
651 |
652 | __Can we have the CRF generate `current_word -> next_word` scores instead of `previous_word -> current_word` scores?__
653 |
654 | Yes. In this case you would broadcast the emission scores like `L, m, _`, and you would have a `<start>` token in every sentence instead of an `<end>` token. The correct tag of the `<start>` token would always be the `<start>` tag. The "next tag" of the last word would always be the `<end>` tag.
655 |
656 | I think the `previous word -> current word` convention is slightly better because there are language models in the mix. It fits in quite nicely to be able to predict the `<end>` token at the last real word, and therefore learn to recognize when a sentence is complete.
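For the alternative convention, the only change in the broadcasting is which dimension receives the emission scores. The singleton dimension goes last, giving `crf[n, t, i, j] = emission[n, t, i] + transition[i, j]`, with `i` the current tag and `j` the next tag. This is a sketch with arbitrary sizes.

```python
import torch

N, L, m = 2, 3, 4  # batch size, word timesteps, number of tags (arbitrary)
emission = torch.randn(N, L, m)  # score of each current_tag at each word
transition = torch.randn(m, m)  # transition[i, j]: current_tag i -> next_tag j

# "L, m, _": emission broadcast along the last (next_tag) dimension -> (N, L, m, m)
crf_scores = emission.unsqueeze(3) + transition
```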
657 |
658 | ---
659 |
660 | __Why are we using different vocabularies for the sequence tagger's inputs and language models' outputs?__
661 |
662 | The language models will learn to predict only those words they have seen during training. It's really unnecessary, and a huge waste of computation and memory, to use a linear-softmax layer with the extra ~400,000 out-of-corpus words from the embedding file that they will never learn to predict.
663 |
664 | But we _can_ add these words to the input layer even if the model never sees them during training. This is because we're using pre-trained embeddings at the input. It doesn't _need_ to see them because the meanings of words are encoded in these vectors. If it's encountered a `chimpanzee` before, it very likely knows what to do with an `orangutan`.
665 |
666 | ---
667 |
668 | __Is it a good idea to fine-tune the pre-trained word embeddings we use in this model?__
669 |
670 | I refrain from fine-tuning because most of the input vocabulary is not in-corpus. Most embeddings will remain the same while a few are fine-tuned. If fine-tuning changes these embeddings sufficiently, the model may not work well with the words that weren't fine-tuned. In the real world, we're bound to encounter many words that weren't present in a newspaper corpus from 2003.
671 |
672 | ---
673 |
674 | __What are some ways we can construct dynamic graphs in PyTorch to compute over only the true lengths of sequences?__
675 |
676 | If you're using an RNN, simply use [`pack_padded_sequence()`](https://pytorch.org/docs/master/nn.html#torch.nn.utils.rnn.pack_padded_sequence). PyTorch will internally compute over only the true lengths. See [`dynamic_rnn.py`](https://github.com/sgrvinod/a-PyTorch-Tutorial-to-Sequence-Labeling/blob/master/dynamic_rnn.py) for an example.
677 |
678 | If you want to execute an operation (like a linear transformation) only on the true timesteps, `pack_padded_sequence()` is still the way to go. This flattens the tensor by timestep while removing the pads. You can perform your operation on this flattened tensor, and then use [`pad_packed_sequence()`](https://pytorch.org/docs/master/nn.html#torch.nn.utils.rnn.pad_packed_sequence) to unflatten it and re-pad it with `0`s.
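A sketch of that pattern follows. The linear layer is applied to `packed.data`, which holds only the valid timesteps. The result is re-wrapped with the original `batch_sizes` and index bookkeeping and then re-padded. Note that the PyTorch docs discourage constructing `PackedSequence` by hand. This is a pragmatic sketch rather than an officially sanctioned recipe, and the shapes are hypothetical.

```python
import torch
from torch import nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence, PackedSequence

torch.manual_seed(0)
N, L, H, m = 3, 5, 8, 4
feats = torch.randn(N, L, H)  # padded features, e.g. BLSTM outputs
lengths = torch.tensor([5, 3, 2])  # sorted in decreasing order
linear = nn.Linear(H, m)

packed = pack_padded_sequence(feats, lengths, batch_first=True)
out_data = linear(packed.data)  # computed over sum(lengths) = 10 rows only

# Re-wrap with the original bookkeeping, then unflatten and re-pad with 0s
packed_out = PackedSequence(out_data, packed.batch_sizes, packed.sorted_indices, packed.unsorted_indices)
out, out_lengths = pad_packed_sequence(packed_out, batch_first=True)  # (N, L, m)
```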
679 |
680 | Similarly, if you want to perform an aggregation operation, like computing the loss, use `pack_padded_sequence()` to eliminate the pads.
681 |
682 | If you want to perform timestep-wise operations, you can take a leaf out of how `pack_padded_sequence()` works, and compute only on the effective batch size at each timestep with a `for` loop to iterate over the timesteps. We did this in the `ViterbiLoss` and `ViterbiDecoder`. I also used an `LSTMCell()` in this fashion in my image captioning tutorial.
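Here is a sketch of such a loop with an `nn.LSTMCell`. At each timestep only the first `N_t` rows of the sorted batch are updated. Finished sequences keep their last hidden state, and the pads in `outputs` stay zero. The sizes are hypothetical. The test confirms that the final states match an `nn.LSTM` with copied weights run on a packed batch.

```python
import torch
from torch import nn

torch.manual_seed(0)
N, L, E, H = 3, 4, 5, 6  # batch, timesteps, input size, hidden size (hypothetical)
x = torch.randn(N, L, E)
lengths = [4, 3, 1]  # sorted in decreasing order

cell = nn.LSTMCell(E, H)
h = torch.zeros(N, H)
c = torch.zeros(N, H)
outputs = torch.zeros(N, L, H)  # pads stay 0
for t in range(max(lengths)):
    n_t = sum(l > t for l in lengths)  # effective batch size at timestep t
    # Sorting guarantees the first n_t rows are exactly the sequences still running
    h_t, c_t = cell(x[:n_t, t], (h[:n_t], c[:n_t]))
    h = torch.cat([h_t, h[n_t:]])  # finished sequences keep their last state
    c = torch.cat([c_t, c[n_t:]])
    outputs[:n_t, t] = h_t
```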
683 |
684 | ---
685 |
686 | __*Dunston Checks In*? Really?__
687 |
688 | I had no memory of this movie for twenty years. I was trying to think of a short sentence that would be easier to visualize in this tutorial and it just popped into my mind riding a wave of 90s nostalgia.
689 |
690 |
691 |
692 |
693 |
694 | I wish I hadn't googled it though. Damn, the critics were harsh, weren't they? This gem was overwhelmingly and universally panned. I'm not sure I'd disagree if I watched it now, but that just goes to show the world is so much more fun when you're a kid.
695 |
696 | Didn't have to worry about LM-LSTM-CRFs or nuthin...
697 |
--------------------------------------------------------------------------------
/datasets.py:
--------------------------------------------------------------------------------
1 | from torch.utils.data import Dataset
2 | import torch
3 |
4 |
5 | class WCDataset(Dataset):
6 | """
7 | PyTorch Dataset for the LM-LSTM-CRF model. To be used by a PyTorch DataLoader to feed batches to the model.
8 | """
9 |
10 | def __init__(self, wmaps, cmaps_f, cmaps_b, cmarkers_f, cmarkers_b, tmaps, wmap_lengths, cmap_lengths):
11 | """
12 | :param wmaps: padded encoded word sequences
13 | :param cmaps_f: padded encoded forward character sequences
14 | :param cmaps_b: padded encoded backward character sequences
15 | :param cmarkers_f: padded forward character markers
16 | :param cmarkers_b: padded backward character markers
17 | :param tmaps: padded encoded tag sequences (indices in unrolled CRF scores)
18 | :param wmap_lengths: word sequence lengths
19 | :param cmap_lengths: character sequence lengths
20 | """
21 | self.wmaps = wmaps
22 | self.cmaps_f = cmaps_f
23 | self.cmaps_b = cmaps_b
24 | self.cmarkers_f = cmarkers_f
25 | self.cmarkers_b = cmarkers_b
26 | self.tmaps = tmaps
27 | self.wmap_lengths = wmap_lengths
28 | self.cmap_lengths = cmap_lengths
29 |
30 | self.data_size = self.wmaps.size(0)
31 |
32 | def __getitem__(self, i):
33 | return self.wmaps[i], self.cmaps_f[i], self.cmaps_b[i], self.cmarkers_f[i], self.cmarkers_b[i], self.tmaps[i], \
34 | self.wmap_lengths[i], self.cmap_lengths[i]
35 |
36 | def __len__(self):
37 | return self.data_size
38 |
--------------------------------------------------------------------------------
/dynamic_rnn.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
4 |
5 | # Create a tensor with variable length sequences and pads (25)
6 | seqs = torch.LongTensor([[0, 1, 2, 3, 25, 25, 25],
7 | [4, 5, 25, 25, 25, 25, 25],
8 | [6, 7, 8, 9, 10, 11, 25]])
9 |
10 | # Store lengths of the actual sequences, ignoring padding
11 | # These are the points up to which we want the RNN to process the sequence
12 | seq_lens = torch.LongTensor([4, 2, 6])
13 |
14 | # Sort by decreasing lengths
15 | seq_lens, sort_ind = seq_lens.sort(dim=0, descending=True)
16 | seqs = seqs[sort_ind]
17 |
18 | # Create an embedding layer, with 0 vectors for the pads
19 | embeds = nn.Embedding(26, 10, padding_idx=25)
20 |
21 | # Create an LSTM layer
22 | lstm = nn.LSTM(10, 50, bidirectional=False, batch_first=True)
23 |
24 | # WITHOUT DYNAMIC BATCHING
25 |
26 | embeddings = embeds(seqs)
27 | out_static, _ = lstm(embeddings)
28 |
29 | # The number of timesteps in the output will be the same as the total padded timesteps in the input,
30 | # since the LSTM computed over the pads
31 | assert out_static.size(1) == embeddings.size(1)
32 |
33 | # Look at the output at a timestep that we know is a pad
34 | print(out_static[1, -1])
35 |
36 | # WITH DYNAMIC BATCHING
37 |
38 | # Pack the sequence
39 | packed_seqs = pack_padded_sequence(embeddings, seq_lens.tolist(), batch_first=True)
40 |
41 | # To execute the LSTM over only the valid timesteps
42 | out_dynamic, _ = lstm(packed_seqs)
43 |
44 | # Use the inverse function to re-pad it
45 | out_dynamic, lens = pad_packed_sequence(out_dynamic, batch_first=True)
46 |
47 | # Note that since we re-padded it, the total padded timesteps will be the length of the longest sequence (6)
48 | assert out_dynamic.size(1) != embeddings.size(1)
49 | print(out_dynamic.shape)
50 |
51 | # Look at the output at a timestep that we know is a pad
52 | print(out_dynamic[1, -1])
53 |
54 | # It's all zeros!
55 |
56 | #########################################################
57 |
58 | # So, what does pack_padded_sequence do?
59 | # It removes pads, flattens by timestep, and keeps track of effective batch_size at each timestep
60 |
61 | # The RNN computes only on the effective batch size "b_t" at each timestep
62 | # This is why we sort - so the top "b_t" rows at timestep "t" are aligned with the top "b_t" outputs from timestep "t-1"
63 |
64 | # Consider the original encoded sequences (sorted)
65 | print(seqs)
66 |
67 | # Let's pack it
68 | packed_seqs = pack_padded_sequence(seqs, seq_lens, batch_first=True)
69 |
70 | # The result of pack_padded_sequence() is a tuple containing the flattened tensor and the effective batch size at each timestep
71 | # Here's the flattened tensor with pads removed
72 | print(packed_seqs[0])
73 | # You can see it's flattened timestep-wise
74 | # Since pads are removed, the total datapoints are equal to the number of valid timesteps
75 | assert packed_seqs[0].size(0) == sum(seq_lens.tolist())
76 |
77 | # Here's the effective batch size at each timestep
78 | print(packed_seqs[1])
79 | # If you look at the original encoded sequences, you can see this is true
80 | print(seqs)
81 |
82 |
--------------------------------------------------------------------------------
/img/backchar.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sgrvinod/a-PyTorch-Tutorial-to-Sequence-Labeling/ee3f34b45a6e24dd748a144bfc25b1adf9e1f077/img/backchar.jpg
--------------------------------------------------------------------------------
/img/backtrace.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sgrvinod/a-PyTorch-Tutorial-to-Sequence-Labeling/ee3f34b45a6e24dd748a144bfc25b1adf9e1f077/img/backtrace.jpg
--------------------------------------------------------------------------------
/img/checks1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sgrvinod/a-PyTorch-Tutorial-to-Sequence-Labeling/ee3f34b45a6e24dd748a144bfc25b1adf9e1f077/img/checks1.jpg
--------------------------------------------------------------------------------
/img/checks2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sgrvinod/a-PyTorch-Tutorial-to-Sequence-Labeling/ee3f34b45a6e24dd748a144bfc25b1adf9e1f077/img/checks2.jpg
--------------------------------------------------------------------------------
/img/configs.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sgrvinod/a-PyTorch-Tutorial-to-Sequence-Labeling/ee3f34b45a6e24dd748a144bfc25b1adf9e1f077/img/configs.jpg
--------------------------------------------------------------------------------
/img/crfscores.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sgrvinod/a-PyTorch-Tutorial-to-Sequence-Labeling/ee3f34b45a6e24dd748a144bfc25b1adf9e1f077/img/crfscores.jpg
--------------------------------------------------------------------------------
/img/dci.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sgrvinod/a-PyTorch-Tutorial-to-Sequence-Labeling/ee3f34b45a6e24dd748a144bfc25b1adf9e1f077/img/dci.jpg
--------------------------------------------------------------------------------
/img/dunston.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sgrvinod/a-PyTorch-Tutorial-to-Sequence-Labeling/ee3f34b45a6e24dd748a144bfc25b1adf9e1f077/img/dunston.jpg
--------------------------------------------------------------------------------
/img/end1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sgrvinod/a-PyTorch-Tutorial-to-Sequence-Labeling/ee3f34b45a6e24dd748a144bfc25b1adf9e1f077/img/end1.jpg
--------------------------------------------------------------------------------
/img/end2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sgrvinod/a-PyTorch-Tutorial-to-Sequence-Labeling/ee3f34b45a6e24dd748a144bfc25b1adf9e1f077/img/end2.jpg
--------------------------------------------------------------------------------
/img/forwchar.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sgrvinod/a-PyTorch-Tutorial-to-Sequence-Labeling/ee3f34b45a6e24dd748a144bfc25b1adf9e1f077/img/forwchar.jpg
--------------------------------------------------------------------------------
/img/framework.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sgrvinod/a-PyTorch-Tutorial-to-Sequence-Labeling/ee3f34b45a6e24dd748a144bfc25b1adf9e1f077/img/framework.png
--------------------------------------------------------------------------------
/img/highway.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sgrvinod/a-PyTorch-Tutorial-to-Sequence-Labeling/ee3f34b45a6e24dd748a144bfc25b1adf9e1f077/img/highway.png
--------------------------------------------------------------------------------
/img/ill.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sgrvinod/a-PyTorch-Tutorial-to-Sequence-Labeling/ee3f34b45a6e24dd748a144bfc25b1adf9e1f077/img/ill.jpg
--------------------------------------------------------------------------------
/img/in1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sgrvinod/a-PyTorch-Tutorial-to-Sequence-Labeling/ee3f34b45a6e24dd748a144bfc25b1adf9e1f077/img/in1.jpg
--------------------------------------------------------------------------------
/img/in2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sgrvinod/a-PyTorch-Tutorial-to-Sequence-Labeling/ee3f34b45a6e24dd748a144bfc25b1adf9e1f077/img/in2.jpg
--------------------------------------------------------------------------------
/img/loss.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sgrvinod/a-PyTorch-Tutorial-to-Sequence-Labeling/ee3f34b45a6e24dd748a144bfc25b1adf9e1f077/img/loss.png
--------------------------------------------------------------------------------
/img/model.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sgrvinod/a-PyTorch-Tutorial-to-Sequence-Labeling/ee3f34b45a6e24dd748a144bfc25b1adf9e1f077/img/model.jpg
--------------------------------------------------------------------------------
/img/nothighway.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sgrvinod/a-PyTorch-Tutorial-to-Sequence-Labeling/ee3f34b45a6e24dd748a144bfc25b1adf9e1f077/img/nothighway.png
--------------------------------------------------------------------------------
/img/sorted.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sgrvinod/a-PyTorch-Tutorial-to-Sequence-Labeling/ee3f34b45a6e24dd748a144bfc25b1adf9e1f077/img/sorted.jpg
--------------------------------------------------------------------------------
/img/tagscores.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sgrvinod/a-PyTorch-Tutorial-to-Sequence-Labeling/ee3f34b45a6e24dd748a144bfc25b1adf9e1f077/img/tagscores.jpg
--------------------------------------------------------------------------------
/img/tagscores0.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sgrvinod/a-PyTorch-Tutorial-to-Sequence-Labeling/ee3f34b45a6e24dd748a144bfc25b1adf9e1f077/img/tagscores0.jpg
--------------------------------------------------------------------------------
/img/update.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sgrvinod/a-PyTorch-Tutorial-to-Sequence-Labeling/ee3f34b45a6e24dd748a144bfc25b1adf9e1f077/img/update.png
--------------------------------------------------------------------------------
/img/vloss1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sgrvinod/a-PyTorch-Tutorial-to-Sequence-Labeling/ee3f34b45a6e24dd748a144bfc25b1adf9e1f077/img/vloss1.png
--------------------------------------------------------------------------------
/img/vloss3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sgrvinod/a-PyTorch-Tutorial-to-Sequence-Labeling/ee3f34b45a6e24dd748a144bfc25b1adf9e1f077/img/vloss3.png
--------------------------------------------------------------------------------
/img/vscore.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sgrvinod/a-PyTorch-Tutorial-to-Sequence-Labeling/ee3f34b45a6e24dd748a144bfc25b1adf9e1f077/img/vscore.png
--------------------------------------------------------------------------------
/img/word.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sgrvinod/a-PyTorch-Tutorial-to-Sequence-Labeling/ee3f34b45a6e24dd748a144bfc25b1adf9e1f077/img/word.jpg
--------------------------------------------------------------------------------
/inference.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | class ViterbiDecoder():
5 | """
6 | Viterbi Decoder.
7 | """
8 |
9 | def __init__(self, tag_map):
10 | """
11 | :param tag_map: tag map
12 | """
13 | self.tagset_size = len(tag_map)
14 | self.start_tag = tag_map['<start>']
15 | self.end_tag = tag_map['<end>']
16 |
17 | def decode(self, scores, lengths):
18 | """
19 | :param scores: CRF scores
20 | :param lengths: word sequence lengths
21 | :return: decoded sequences
22 | """
23 | batch_size = scores.size(0)
24 | word_pad_len = scores.size(1)
25 |
26 | # Create a tensor to hold accumulated sequence scores at each current tag
27 | scores_upto_t = torch.zeros(batch_size, self.tagset_size)
28 |
29 | # Create a tensor to hold back-pointers
30 | # i.e., indices of the previous_tag that corresponds to maximum accumulated score at current tag
31 | # Let pads be the <end> tag index, since that was the last tag in the decoded sequence
32 | backpointers = torch.ones((batch_size, max(lengths), self.tagset_size), dtype=torch.long) * self.end_tag
33 |
34 | for t in range(max(lengths)):
35 | batch_size_t = sum([l > t for l in lengths]) # effective batch size (sans pads) at this timestep
36 | if t == 0:
37 | scores_upto_t[:batch_size_t] = scores[:batch_size_t, t, self.start_tag, :] # (batch_size, tagset_size)
38 | backpointers[:batch_size_t, t, :] = torch.ones((batch_size_t, self.tagset_size),
39 | dtype=torch.long) * self.start_tag
40 | else:
41 | # We add scores at current timestep to scores accumulated up to previous timestep, and
42 | # choose the previous_tag that corresponds to the max. accumulated score for each current_tag
43 | scores_upto_t[:batch_size_t], backpointers[:batch_size_t, t, :] = torch.max(
44 | scores[:batch_size_t, t, :, :] + scores_upto_t[:batch_size_t].unsqueeze(2),
45 | dim=1) # (batch_size, tagset_size)
46 |
47 | # Decode/trace best path backwards
48 | decoded = torch.zeros((batch_size, backpointers.size(1)), dtype=torch.long)
49 | pointer = torch.ones((batch_size, 1),
50 | dtype=torch.long) * self.end_tag # the pointers at the ends are all <end> tags
51 |
52 | for t in list(reversed(range(backpointers.size(1)))):
53 | decoded[:, t] = torch.gather(backpointers[:, t, :], 1, pointer).squeeze(1)
54 | pointer = decoded[:, t].unsqueeze(1) # (batch_size, 1)
55 |
56 | # Sanity check
57 | assert torch.equal(decoded[:, 0], torch.ones((batch_size), dtype=torch.long) * self.start_tag)
58 |
59 | # Remove the <start> at the beginning, and append with <end> (to compare to targets, if any)
60 | decoded = torch.cat([decoded[:, 1:], torch.ones((batch_size, 1), dtype=torch.long) * self.end_tag],
61 | dim=1)
62 |
63 | return decoded
64 |
--------------------------------------------------------------------------------
/models.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
4 | from utils import *
5 |
6 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
7 |
8 |
9 | class Highway(nn.Module):
10 | """
11 | Highway Network.
12 | """
13 |
14 | def __init__(self, size, num_layers=1, dropout=0.5):
15 | """
16 | :param size: size of linear layer (matches input size)
17 | :param num_layers: number of transform and gate layers
18 | :param dropout: dropout
19 | """
20 | super(Highway, self).__init__()
21 | self.size = size
22 | self.num_layers = num_layers
23 | self.transform = nn.ModuleList() # list of transform layers
24 | self.gate = nn.ModuleList() # list of gate layers
25 | self.dropout = nn.Dropout(p=dropout)
26 |
27 | for i in range(num_layers):
28 | transform = nn.Linear(size, size)
29 | gate = nn.Linear(size, size)
30 | self.transform.append(transform)
31 | self.gate.append(gate)
32 |
33 | def forward(self, x):
34 | """
35 | Forward propagation.
36 |
37 | :param x: input tensor
38 | :return: output tensor, with same dimensions as input tensor
39 | """
40 | transformed = nn.functional.relu(self.transform[0](x)) # transform input
41 | g = torch.sigmoid(self.gate[0](x)) # calculate how much of the transformed input to keep
42 |
43 | out = g * transformed + (1 - g) * x # combine input and transformed input in this ratio
44 |
45 | # If there are additional layers
46 | for i in range(1, self.num_layers):
47 | out = self.dropout(out)
48 | transformed = nn.functional.relu(self.transform[i](out))
49 | g = torch.sigmoid(self.gate[i](out))
50 |
51 | out = g * transformed + (1 - g) * out
52 |
53 | return out
54 |
55 |
56 | class CRF(nn.Module):
57 | """
58 | Conditional Random Field.
59 | """
60 |
61 | def __init__(self, hidden_dim, tagset_size):
62 | """
63 | :param hidden_dim: size of word RNN/BLSTM's output
64 | :param tagset_size: number of tags
65 | """
66 | super(CRF, self).__init__()
67 | self.tagset_size = tagset_size
68 | self.emission = nn.Linear(hidden_dim, self.tagset_size)
69 | self.transition = nn.Parameter(torch.Tensor(self.tagset_size, self.tagset_size))
70 | self.transition.data.zero_()
71 |
72 | def forward(self, feats):
73 | """
74 | Forward propagation.
75 |
76 | :param feats: output of word RNN/BLSTM, a tensor of dimensions (batch_size, timesteps, hidden_dim)
77 | :return: CRF scores, a tensor of dimensions (batch_size, timesteps, tagset_size, tagset_size)
78 | """
79 | self.batch_size = feats.size(0)
80 | self.timesteps = feats.size(1)
81 |
82 | emission_scores = self.emission(feats) # (batch_size, timesteps, tagset_size)
83 | emission_scores = emission_scores.unsqueeze(2).expand(self.batch_size, self.timesteps, self.tagset_size,
84 | self.tagset_size) # (batch_size, timesteps, tagset_size, tagset_size)
85 |
86 | crf_scores = emission_scores + self.transition.unsqueeze(0).unsqueeze(
87 | 0) # (batch_size, timesteps, tagset_size, tagset_size)
88 | return crf_scores
89 |
90 |
91 | class LM_LSTM_CRF(nn.Module):
92 | """
93 | The encompassing LM-LSTM-CRF model.
94 | """
95 |
96 | def __init__(self, tagset_size, charset_size, char_emb_dim, char_rnn_dim, char_rnn_layers, vocab_size,
97 | lm_vocab_size, word_emb_dim, word_rnn_dim, word_rnn_layers, dropout, highway_layers=1):
98 | """
99 | :param tagset_size: number of tags
100 | :param charset_size: size of character vocabulary
101 | :param char_emb_dim: size of character embeddings
102 | :param char_rnn_dim: size of character RNNs/LSTMs
103 | :param char_rnn_layers: number of layers in character RNNs/LSTMs
104 | :param vocab_size: input vocabulary size
105 | :param lm_vocab_size: vocabulary size of language models (in-corpus words subject to word frequency threshold)
106 | :param word_emb_dim: size of word embeddings
107 | :param word_rnn_dim: size of word RNN/BLSTM
108 | :param word_rnn_layers: number of layers in word RNNs/LSTMs
109 | :param dropout: dropout
110 | :param highway_layers: number of transform and gate layers
111 | """
112 |
113 | super(LM_LSTM_CRF, self).__init__()
114 |
115 | self.tagset_size = tagset_size # this is the size of the output vocab of the tagging model
116 |
117 | self.charset_size = charset_size
118 | self.char_emb_dim = char_emb_dim
119 | self.char_rnn_dim = char_rnn_dim
120 | self.char_rnn_layers = char_rnn_layers
121 |
122 | self.wordset_size = vocab_size # this is the size of the input vocab (embedding layer) of the tagging model
123 | self.lm_vocab_size = lm_vocab_size # this is the size of the output vocab of the language model
124 | self.word_emb_dim = word_emb_dim
125 | self.word_rnn_dim = word_rnn_dim
126 | self.word_rnn_layers = word_rnn_layers
127 |
128 | self.highway_layers = highway_layers
129 |
130 | self.dropout = nn.Dropout(p=dropout)
131 |
132 | self.char_embeds = nn.Embedding(self.charset_size, self.char_emb_dim) # character embedding layer
133 | self.forw_char_lstm = nn.LSTM(self.char_emb_dim, self.char_rnn_dim, num_layers=self.char_rnn_layers,
134 | bidirectional=False, dropout=dropout) # forward character LSTM
135 | self.back_char_lstm = nn.LSTM(self.char_emb_dim, self.char_rnn_dim, num_layers=self.char_rnn_layers,
136 | bidirectional=False, dropout=dropout) # backward character LSTM
137 |
138 | self.word_embeds = nn.Embedding(self.wordset_size, self.word_emb_dim) # word embedding layer
139 | self.word_blstm = nn.LSTM(self.word_emb_dim + self.char_rnn_dim * 2, self.word_rnn_dim // 2,
140 | num_layers=self.word_rnn_layers, bidirectional=True, dropout=dropout) # word BLSTM
141 |
142 | self.crf = CRF((self.word_rnn_dim // 2) * 2, self.tagset_size) # conditional random field
143 |
144 | self.forw_lm_hw = Highway(self.char_rnn_dim, num_layers=self.highway_layers,
145 | dropout=dropout) # highway to transform forward char LSTM output for the forward language model
146 | self.back_lm_hw = Highway(self.char_rnn_dim, num_layers=self.highway_layers,
147 | dropout=dropout) # highway to transform backward char LSTM output for the backward language model
148 | self.subword_hw = Highway(2 * self.char_rnn_dim, num_layers=self.highway_layers,
149 | dropout=dropout) # highway to transform combined forward and backward char LSTM outputs for use in the word BLSTM
150 |
151 | self.forw_lm_out = nn.Linear(self.char_rnn_dim,
152 | self.lm_vocab_size) # linear layer to find vocabulary scores for the forward language model
153 | self.back_lm_out = nn.Linear(self.char_rnn_dim,
154 | self.lm_vocab_size) # linear layer to find vocabulary scores for the backward language model
155 |
156 | def init_word_embeddings(self, embeddings):
157 | """
158 | Initialize embeddings with pre-trained embeddings.
159 |
160 | :param embeddings: pre-trained embeddings
161 | """
162 | self.word_embeds.weight = nn.Parameter(embeddings)
163 |
164 | def fine_tune_word_embeddings(self, fine_tune=False):
165 | """
166 | Fine-tune embedding layer? (Not fine-tuning only makes sense if using pre-trained embeddings).
167 |
168 | :param fine_tune: Fine-tune?
169 | """
170 | for p in self.word_embeds.parameters():
171 | p.requires_grad = fine_tune
172 |
173 | def forward(self, cmaps_f, cmaps_b, cmarkers_f, cmarkers_b, wmaps, tmaps, wmap_lengths, cmap_lengths):
174 | """
175 | Forward propagation.
176 |
177 | :param cmaps_f: padded encoded forward character sequences, a tensor of dimensions (batch_size, char_pad_len)
178 | :param cmaps_b: padded encoded backward character sequences, a tensor of dimensions (batch_size, char_pad_len)
179 | :param cmarkers_f: padded forward character markers, a tensor of dimensions (batch_size, word_pad_len)
180 | :param cmarkers_b: padded backward character markers, a tensor of dimensions (batch_size, word_pad_len)
181 | :param wmaps: padded encoded word sequences, a tensor of dimensions (batch_size, word_pad_len)
182 | :param tmaps: padded tag sequences, a tensor of dimensions (batch_size, word_pad_len)
183 | :param wmap_lengths: word sequence lengths, a tensor of dimensions (batch_size)
184 | :param cmap_lengths: character sequence lengths, a tensor of dimensions (batch_size)
185 | """
186 | self.batch_size = cmaps_f.size(0)
187 | self.word_pad_len = wmaps.size(1)
188 |
189 | # Sort by decreasing true char. sequence length
190 | cmap_lengths, char_sort_ind = cmap_lengths.sort(dim=0, descending=True)
191 | cmaps_f = cmaps_f[char_sort_ind]
192 | cmaps_b = cmaps_b[char_sort_ind]
193 | cmarkers_f = cmarkers_f[char_sort_ind]
194 | cmarkers_b = cmarkers_b[char_sort_ind]
195 | wmaps = wmaps[char_sort_ind]
196 | tmaps = tmaps[char_sort_ind]
197 | wmap_lengths = wmap_lengths[char_sort_ind]
198 |
199 | # Embedding look-up for characters
200 | cf = self.char_embeds(cmaps_f) # (batch_size, char_pad_len, char_emb_dim)
201 | cb = self.char_embeds(cmaps_b)
202 |
203 | # Dropout
204 | cf = self.dropout(cf) # (batch_size, char_pad_len, char_emb_dim)
205 | cb = self.dropout(cb)
206 |
207 | # Pack padded sequence
208 | cf = pack_padded_sequence(cf, cmap_lengths.tolist(),
209 | batch_first=True) # packed sequence of char_emb_dim, with real sequence lengths
210 | cb = pack_padded_sequence(cb, cmap_lengths.tolist(), batch_first=True)
211 |
212 | # LSTM
213 | cf, _ = self.forw_char_lstm(cf) # packed sequence of char_rnn_dim, with real sequence lengths
214 | cb, _ = self.back_char_lstm(cb)
215 |
216 | # Unpack packed sequence
217 | cf, _ = pad_packed_sequence(cf, batch_first=True) # (batch_size, max_char_len_in_batch, char_rnn_dim)
218 | cb, _ = pad_packed_sequence(cb, batch_first=True)
219 |
220 | # Sanity check
221 | assert cf.size(1) == max(cmap_lengths.tolist()) == list(cmap_lengths)[0]
222 |
223 | # Select RNN outputs only at marker points (spaces in the character sequence)
224 | cmarkers_f = cmarkers_f.unsqueeze(2).expand(self.batch_size, self.word_pad_len, self.char_rnn_dim)
225 | cmarkers_b = cmarkers_b.unsqueeze(2).expand(self.batch_size, self.word_pad_len, self.char_rnn_dim)
226 | cf_selected = torch.gather(cf, 1, cmarkers_f) # (batch_size, word_pad_len, char_rnn_dim)
227 | cb_selected = torch.gather(cb, 1, cmarkers_b)
228 |
229 | # Only for co-training, not useful for tagging after model is trained
230 | if self.training:
231 | lm_f = self.forw_lm_hw(self.dropout(cf_selected)) # (batch_size, word_pad_len, char_rnn_dim)
232 | lm_b = self.back_lm_hw(self.dropout(cb_selected))
233 | lm_f_scores = self.forw_lm_out(self.dropout(lm_f)) # (batch_size, word_pad_len, lm_vocab_size)
234 | lm_b_scores = self.back_lm_out(self.dropout(lm_b))
235 |
236 | # Sort by decreasing true word sequence length
237 | wmap_lengths, word_sort_ind = wmap_lengths.sort(dim=0, descending=True)
238 | wmaps = wmaps[word_sort_ind]
239 | tmaps = tmaps[word_sort_ind]
240 | cf_selected = cf_selected[word_sort_ind] # for language model
241 | cb_selected = cb_selected[word_sort_ind]
242 | if self.training:
243 | lm_f_scores = lm_f_scores[word_sort_ind]
244 | lm_b_scores = lm_b_scores[word_sort_ind]
245 |
246 | # Embedding look-up for words
247 | w = self.word_embeds(wmaps) # (batch_size, word_pad_len, word_emb_dim)
248 | w = self.dropout(w)
249 |
250 | # Sub-word information at each word
251 | subword = self.subword_hw(self.dropout(
252 | torch.cat((cf_selected, cb_selected), dim=2))) # (batch_size, word_pad_len, 2 * char_rnn_dim)
253 | subword = self.dropout(subword)
254 |
255 | # Concatenate word embeddings and sub-word features
256 | w = torch.cat((w, subword), dim=2) # (batch_size, word_pad_len, word_emb_dim + 2 * char_rnn_dim)
257 |
258 | # Pack padded sequence
259 | w = pack_padded_sequence(w, wmap_lengths.tolist(),
260 | batch_first=True) # packed sequence of word_emb_dim + 2 * char_rnn_dim, with real sequence lengths
261 |
262 | # LSTM
263 | w, _ = self.word_blstm(w) # packed sequence of word_rnn_dim, with real sequence lengths
264 |
265 | # Unpack packed sequence
266 | w, _ = pad_packed_sequence(w, batch_first=True) # (batch_size, max_word_len_in_batch, word_rnn_dim)
267 | w = self.dropout(w)
268 |
269 | crf_scores = self.crf(w) # (batch_size, max_word_len_in_batch, tagset_size, tagset_size)
270 |
271 | if self.training:
272 | return crf_scores, lm_f_scores, lm_b_scores, wmaps, tmaps, wmap_lengths, word_sort_ind, char_sort_ind
273 | else:
274 | return crf_scores, wmaps, tmaps, wmap_lengths, word_sort_ind, char_sort_ind # sort inds to reorder, if req.
275 |
276 |
277 | class ViterbiLoss(nn.Module):
278 | """
279 | Viterbi Loss.
280 | """
281 |
282 | def __init__(self, tag_map):
283 | """
284 | :param tag_map: tag map
285 | """
286 | super(ViterbiLoss, self).__init__()
287 | self.tagset_size = len(tag_map)
288 | self.start_tag = tag_map['<start>']
289 | self.end_tag = tag_map['<end>']
290 |
291 | def forward(self, scores, targets, lengths):
292 | """
293 | Forward propagation.
294 |
295 | :param scores: CRF scores
296 | :param targets: true tags indices in unrolled CRF scores
297 | :param lengths: word sequence lengths
298 | :return: viterbi loss
299 | """
300 |
301 | batch_size = scores.size(0)
302 | word_pad_len = scores.size(1)
303 |
304 | # Gold score
305 |
306 | targets = targets.unsqueeze(2)
307 | scores_at_targets = torch.gather(scores.view(batch_size, word_pad_len, -1), 2, targets).squeeze(
308 | 2) # (batch_size, word_pad_len)
309 |
310 | # Everything is already sorted by lengths
311 | scores_at_targets = pack_padded_sequence(scores_at_targets, lengths.tolist(), batch_first=True).data
312 | gold_score = scores_at_targets.sum()
313 |
314 | # All paths' scores
315 |
316 | # Create a tensor to hold accumulated sequence scores at each current tag
317 | scores_upto_t = torch.zeros(batch_size, self.tagset_size).to(device)
318 |
319 | for t in range(max(lengths)):
320 | batch_size_t = sum([l > t for l in lengths]) # effective batch size (sans pads) at this timestep
321 | if t == 0:
322 | scores_upto_t[:batch_size_t] = scores[:batch_size_t, t, self.start_tag, :] # (batch_size, tagset_size)
323 | else:
324 | # We add scores at current timestep to scores accumulated up to previous timestep, and log-sum-exp
325 | # Remember, the cur_tag of the previous timestep is the prev_tag of this timestep
326 | # So, broadcast prev. timestep's cur_tag scores along cur. timestep's cur_tag dimension
327 | scores_upto_t[:batch_size_t] = log_sum_exp(
328 | scores[:batch_size_t, t, :, :] + scores_upto_t[:batch_size_t].unsqueeze(2),
329 | dim=1) # (batch_size, tagset_size)
330 |
331 | # We only need the final accumulated scores at the <end> tag
332 | all_paths_scores = scores_upto_t[:, self.end_tag].sum()
333 |
334 | viterbi_loss = all_paths_scores - gold_score
335 | viterbi_loss = viterbi_loss / batch_size
336 |
337 | return viterbi_loss
338 |
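The split in `ViterbiLoss` between the gold score and the all-paths score can be checked on a toy sentence. The sketch below is a pure-Python re-implementation for a single sentence (no batching, no PyTorch), assuming a hypothetical tag map with `<pad>`=0, two real tags 1 and 2, `<start>`=3 and `<end>`=4; the helper names are illustrative and not part of this repository. It reads the gold score through the same unrolled `prev_tag * tagset_size + cur_tag` indices that `create_input_tensors()` builds, runs the log-sum-exp recursion from the loop above, and compares that result with brute-force enumeration of every tag path that ends in `<end>`.

```python
import itertools
import math
import random


def log_sum_exp(values):
    """Numerically stable log(sum(exp(v) for v in values))."""
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values))


def unrolled_targets(tags, start_tag, n_tags):
    """Index of (prev_tag, cur_tag) in a row-major flattened n_tags x n_tags grid,
    following the encoding used in create_input_tensors()."""
    prev = [start_tag] + tags[:-1]
    return [p * n_tags + c for p, c in zip(prev, tags)]


def gold_score(scores, targets):
    """Sum of scores[t][prev][cur] along the gold path, read through unrolled indices
    (a pure-Python analogue of torch.gather on scores.view(B, T, -1))."""
    n = len(scores[0])
    return sum(scores[t][i // n][i % n] for t, i in enumerate(targets))


def log_partition(scores, start_tag, end_tag):
    """Forward recursion for one sentence, mirroring the loop in ViterbiLoss.forward."""
    n = len(scores[0])
    alpha = list(scores[0][start_tag])  # t == 0: every path leaves the start tag
    for t in range(1, len(scores)):
        # alpha[p]: log-sum-exp of the scores of all paths ending in tag p at step t - 1
        alpha = [log_sum_exp([alpha[p] + scores[t][p][c] for p in range(n)]) for c in range(n)]
    return alpha[end_tag]  # keep only paths whose final tag is <end>


def brute_force_log_partition(scores, start_tag, end_tag):
    """Enumerate every path that ends in end_tag; exponential, for checking only."""
    n, T = len(scores[0]), len(scores)
    totals = []
    for middle in itertools.product(range(n), repeat=T - 1):
        path = list(middle) + [end_tag]
        totals.append(gold_score(scores, unrolled_targets(path, start_tag, n)))
    return log_sum_exp(totals)


if __name__ == "__main__":
    random.seed(0)
    # Hypothetical tag map: <pad>=0, real tags 1 and 2, <start>=3, <end>=4
    n_tags, start_tag, end_tag = 5, 3, 4
    tags = [1, 2, 1, end_tag]  # three words plus the appended <end>
    scores = [[[random.gauss(0, 1) for _ in range(n_tags)] for _ in range(n_tags)]
              for _ in tags]  # scores[t][prev][cur]
    targets = unrolled_targets(tags, start_tag, n_tags)
    loss = log_partition(scores, start_tag, end_tag) - gold_score(scores, targets)
    print(targets, [i % n_tags for i in targets], round(loss, 4))
```

Because the gold path is one of the enumerated paths, this loss is never negative. The recursion does O(T · tagset_size²) work, whereas the brute force grows exponentially with sentence length, which is why only the recursion is used in training.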
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import time
2 | import torch
3 | import torch.optim as optim
4 | import os
5 | import sys
6 | from models import LM_LSTM_CRF, ViterbiLoss
7 | from utils import *
8 | from torch.nn.utils.rnn import pack_padded_sequence
9 | from datasets import WCDataset
10 | from inference import ViterbiDecoder
11 | from sklearn.metrics import f1_score
12 |
13 | # Data parameters
14 | task = 'ner' # tagging task, to choose column in CoNLL 2003 dataset
15 | train_file = './datasets/eng.train' # path to training data
16 | val_file = './datasets/eng.testa' # path to validation data
17 | test_file = './datasets/eng.testb' # path to test data
18 | emb_file = './embeddings/glove.6B.100d.txt' # path to pre-trained word embeddings
19 | min_word_freq = 5 # threshold for word frequency
20 | min_char_freq = 1 # threshold for character frequency
21 | caseless = True # lowercase everything?
22 | expand_vocab = True # expand model's input vocabulary to the pre-trained embeddings' vocabulary?
23 |
24 | # Model parameters
25 | char_emb_dim = 30 # character embedding size
26 | with open(emb_file, 'r') as f:
27 | word_emb_dim = len(f.readline().split(' ')) - 1 # word embedding size
28 | word_rnn_dim = 300 # word RNN size
29 | char_rnn_dim = 300 # character RNN size
30 | char_rnn_layers = 1 # number of layers in character RNN
31 | word_rnn_layers = 1 # number of layers in word RNN
32 | highway_layers = 1 # number of layers in highway network
33 | dropout = 0.5 # dropout
34 | fine_tune_word_embeddings = False # fine-tune pre-trained word embeddings?
35 |
36 | # Training parameters
37 | start_epoch = 0 # start at this epoch
38 | batch_size = 10 # batch size
39 | lr = 0.015 # learning rate
40 | lr_decay = 0.05 # after each epoch, the learning rate is set to lr / (1 + (epoch + 1) * lr_decay)
41 | momentum = 0.9 # momentum
42 | workers = 1 # number of workers for loading data in the DataLoader
43 | epochs = 200 # number of epochs to run without early-stopping
44 | grad_clip = 5. # clip gradients at this value
45 | print_freq = 100 # print training or validation status every print_freq batches
46 | best_f1 = 0. # F1 score to start with
47 | checkpoint = None # path to model checkpoint, None if none
48 |
49 | tag_ind = 1 if task == 'pos' else 3 # choose column in CoNLL 2003 dataset
50 |
51 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
52 |
53 |
54 | def main():
55 | """
56 | Training and validation.
57 | """
58 | global best_f1, epochs_since_improvement, checkpoint, start_epoch, word_map, char_map, tag_map
59 |
60 | # Read training and validation data
61 | train_words, train_tags = read_words_tags(train_file, tag_ind, caseless)
62 | val_words, val_tags = read_words_tags(val_file, tag_ind, caseless)
63 |
64 | # Initialize model or load checkpoint
65 | if checkpoint is not None:
66 | checkpoint = torch.load(checkpoint)
67 | model = checkpoint['model']
68 | optimizer = checkpoint['optimizer']
69 | word_map = checkpoint['word_map']
70 | lm_vocab_size = checkpoint['lm_vocab_size']
71 | tag_map = checkpoint['tag_map']
72 | char_map = checkpoint['char_map']
73 | start_epoch = checkpoint['epoch'] + 1
74 | best_f1 = checkpoint['f1']
75 | else:
76 | word_map, char_map, tag_map = create_maps(train_words + val_words, train_tags + val_tags, min_word_freq,
77 | min_char_freq) # create word, char, tag maps
78 | embeddings, word_map, lm_vocab_size = load_embeddings(emb_file, word_map,
79 | expand_vocab) # load pre-trained embeddings
80 |
81 | model = LM_LSTM_CRF(tagset_size=len(tag_map),
82 | charset_size=len(char_map),
83 | char_emb_dim=char_emb_dim,
84 | char_rnn_dim=char_rnn_dim,
85 | char_rnn_layers=char_rnn_layers,
86 | vocab_size=len(word_map),
87 | lm_vocab_size=lm_vocab_size,
88 | word_emb_dim=word_emb_dim,
89 | word_rnn_dim=word_rnn_dim,
90 | word_rnn_layers=word_rnn_layers,
91 | dropout=dropout,
92 | highway_layers=highway_layers).to(device)
93 | model.init_word_embeddings(embeddings.to(device)) # initialize embedding layer with pre-trained embeddings
94 | model.fine_tune_word_embeddings(fine_tune_word_embeddings) # fine-tune
95 | optimizer = optim.SGD(params=filter(lambda p: p.requires_grad, model.parameters()), lr=lr, momentum=momentum)
96 |
97 | # Loss functions
98 | lm_criterion = nn.CrossEntropyLoss().to(device)
99 | crf_criterion = ViterbiLoss(tag_map).to(device)
100 |
101 | # Since the language model's vocab is restricted to in-corpus indices, encode training/val with only these!
102 | # word_map might have been expanded, and in-corpus words eliminated due to low frequency might still be added because
103 | # they were in the pre-trained embeddings
104 | temp_word_map = {k: v for k, v in word_map.items() if v <= word_map['<unk>']}
105 | train_inputs = create_input_tensors(train_words, train_tags, temp_word_map, char_map,
106 | tag_map)
107 | val_inputs = create_input_tensors(val_words, val_tags, temp_word_map, char_map, tag_map)
108 |
109 | # DataLoaders
110 | train_loader = torch.utils.data.DataLoader(WCDataset(*train_inputs), batch_size=batch_size, shuffle=True,
111 | num_workers=workers, pin_memory=False)
112 | val_loader = torch.utils.data.DataLoader(WCDataset(*val_inputs), batch_size=batch_size, shuffle=True,
113 | num_workers=workers, pin_memory=False)
114 |
115 | # Viterbi decoder (to find accuracy during validation)
116 | vb_decoder = ViterbiDecoder(tag_map)
117 |
118 | # Epochs
119 | for epoch in range(start_epoch, epochs):
120 |
121 | # One epoch's training
122 | train(train_loader=train_loader,
123 | model=model,
124 | lm_criterion=lm_criterion,
125 | crf_criterion=crf_criterion,
126 | optimizer=optimizer,
127 | epoch=epoch,
128 | vb_decoder=vb_decoder)
129 |
130 | # One epoch's validation
131 | val_f1 = validate(val_loader=val_loader,
132 | model=model,
133 | crf_criterion=crf_criterion,
134 | vb_decoder=vb_decoder)
135 |
136 | # Did validation F1 score improve?
137 | is_best = val_f1 > best_f1
138 | best_f1 = max(val_f1, best_f1)
139 | if not is_best:
140 | epochs_since_improvement += 1
141 | print("\nEpochs since improvement: %d\n" % (epochs_since_improvement,))
142 | else:
143 | epochs_since_improvement = 0
144 |
145 | # Save checkpoint
146 | save_checkpoint(epoch, model, optimizer, val_f1, word_map, char_map, tag_map, lm_vocab_size, is_best)
147 |
148 | # Decay learning rate every epoch
149 | adjust_learning_rate(optimizer, lr / (1 + (epoch + 1) * lr_decay))
150 |
151 |
152 | def train(train_loader, model, lm_criterion, crf_criterion, optimizer, epoch, vb_decoder):
153 | """
154 | Performs one epoch's training.
155 |
156 | :param train_loader: DataLoader for training data
157 | :param model: model
158 | :param lm_criterion: cross entropy loss layer
159 | :param crf_criterion: viterbi loss layer
160 | :param optimizer: optimizer
161 | :param epoch: epoch number
162 | :param vb_decoder: viterbi decoder (to decode and find F1 score)
163 | """
164 |
165 | model.train() # training mode enables dropout
166 |
167 | batch_time = AverageMeter() # forward prop. + back prop. time per batch
168 | data_time = AverageMeter() # data loading time per batch
169 | ce_losses = AverageMeter() # cross entropy loss
170 | vb_losses = AverageMeter() # viterbi loss
171 | f1s = AverageMeter() # f1 score
172 |
173 | start = time.time()
174 |
175 | # Batches
176 | for i, (wmaps, cmaps_f, cmaps_b, cmarkers_f, cmarkers_b, tmaps, wmap_lengths, cmap_lengths) in enumerate(
177 | train_loader):
178 |
179 | data_time.update(time.time() - start)
180 |
181 | max_word_len = max(wmap_lengths.tolist())
182 | max_char_len = max(cmap_lengths.tolist())
183 |
184 | # Reduce batch's padded length to maximum in-batch sequence
185 | # This saves some compute on nn.Linear layers (RNNs are unaffected, since they don't compute over the pads)
186 | wmaps = wmaps[:, :max_word_len].to(device)
187 | cmaps_f = cmaps_f[:, :max_char_len].to(device)
188 | cmaps_b = cmaps_b[:, :max_char_len].to(device)
189 | cmarkers_f = cmarkers_f[:, :max_word_len].to(device)
190 | cmarkers_b = cmarkers_b[:, :max_word_len].to(device)
191 | tmaps = tmaps[:, :max_word_len].to(device)
192 | wmap_lengths = wmap_lengths.to(device)
193 | cmap_lengths = cmap_lengths.to(device)
194 |
195 | # Forward prop.
196 | crf_scores, lm_f_scores, lm_b_scores, wmaps_sorted, tmaps_sorted, wmap_lengths_sorted, _, __ = model(cmaps_f,
197 | cmaps_b,
198 | cmarkers_f,
199 | cmarkers_b,
200 | wmaps,
201 | tmaps,
202 | wmap_lengths,
203 | cmap_lengths)
204 |
205 | # LM loss
206 |
207 | # We don't predict the next word at the pads or <end> tokens
208 | # We will only predict at [dunston, checks, in] among [dunston, checks, in, <end>, <pad>, <pad>, ...]
209 | # So, prediction lengths are word sequence lengths - 1
210 | lm_lengths = wmap_lengths_sorted - 1
211 | lm_lengths = lm_lengths.tolist()
212 |
213 | # Remove scores at timesteps we won't predict at
214 | # pack_padded_sequence is a good trick to do this (see dynamic_rnn.py, where we explore this)
215 | lm_f_scores = pack_padded_sequence(lm_f_scores, lm_lengths, batch_first=True).data
216 | lm_b_scores = pack_padded_sequence(lm_b_scores, lm_lengths, batch_first=True).data
217 |
218 | # For the forward sequence, targets are from the second word onwards, up to <end>
219 | # (timestep -> target) ...dunston -> checks, ...checks -> in, ...in -> <end>
220 | lm_f_targets = wmaps_sorted[:, 1:]
221 | lm_f_targets = pack_padded_sequence(lm_f_targets, lm_lengths, batch_first=True).data
222 |
223 | # For the backward sequence, targets are <end> followed by all words except the last word
224 | # ...notsnud -> <end>, ...skcehc -> dunston, ...ni -> checks
225 | lm_b_targets = torch.cat(
226 | [torch.LongTensor([word_map['<end>']] * wmaps_sorted.size(0)).unsqueeze(1).to(device), wmaps_sorted], dim=1)
227 | lm_b_targets = pack_padded_sequence(lm_b_targets, lm_lengths, batch_first=True).data
228 |
229 | # Calculate loss
230 | ce_loss = lm_criterion(lm_f_scores, lm_f_targets) + lm_criterion(lm_b_scores, lm_b_targets)
231 | vb_loss = crf_criterion(crf_scores, tmaps_sorted, wmap_lengths_sorted)
232 | loss = ce_loss + vb_loss
233 |
234 | # Back prop.
235 | optimizer.zero_grad()
236 | loss.backward()
237 |
238 | if grad_clip is not None:
239 | clip_gradient(optimizer, grad_clip)
240 |
241 | optimizer.step()
242 |
243 | # Viterbi decode to find accuracy / f1
244 | decoded = vb_decoder.decode(crf_scores.to("cpu"), wmap_lengths_sorted.to("cpu"))
245 |
246 | # Remove timesteps we won't predict at, and also <end> tags, because to predict them would be cheating
247 | decoded = pack_padded_sequence(decoded, lm_lengths, batch_first=True).data
248 | tmaps_sorted = tmaps_sorted % vb_decoder.tagset_size # actual target indices (see create_input_tensors())
249 | tmaps_sorted = pack_padded_sequence(tmaps_sorted, lm_lengths, batch_first=True).data
250 |
251 | # F1
252 | f1 = f1_score(tmaps_sorted.to("cpu").numpy(), decoded.numpy(), average='macro')
253 |
254 | # Keep track of metrics
255 | ce_losses.update(ce_loss.item(), sum(lm_lengths))
256 | vb_losses.update(vb_loss.item(), crf_scores.size(0))
257 | batch_time.update(time.time() - start)
258 | f1s.update(f1, sum(lm_lengths))
259 |
260 | start = time.time()
261 |
262 | # Print training status
263 | if i % print_freq == 0:
264 | print('Epoch: [{0}][{1}/{2}]\t'
265 | 'Batch Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
266 | 'Data Load Time {data_time.val:.3f} ({data_time.avg:.3f})\t'
267 | 'CE Loss {ce_loss.val:.4f} ({ce_loss.avg:.4f})\t'
268 | 'VB Loss {vb_loss.val:.4f} ({vb_loss.avg:.4f})\t'
269 | 'F1 {f1.val:.3f} ({f1.avg:.3f})'.format(epoch, i, len(train_loader),
270 | batch_time=batch_time,
271 | data_time=data_time, ce_loss=ce_losses,
272 | vb_loss=vb_losses, f1=f1s))
273 |
274 |
275 | def validate(val_loader, model, crf_criterion, vb_decoder):
276 | """
277 | Performs one epoch's validation.
278 |
279 | :param val_loader: DataLoader for validation data
280 | :param model: model
281 | :param crf_criterion: viterbi loss layer
282 | :param vb_decoder: viterbi decoder
283 | :return: validation F1 score
284 | """
285 | model.eval()
286 |
287 | batch_time = AverageMeter()
288 | vb_losses = AverageMeter()
289 | f1s = AverageMeter()
290 |
291 | start = time.time()
292 |
293 | for i, (wmaps, cmaps_f, cmaps_b, cmarkers_f, cmarkers_b, tmaps, wmap_lengths, cmap_lengths) in enumerate(
294 | val_loader):
295 |
296 | max_word_len = max(wmap_lengths.tolist())
297 | max_char_len = max(cmap_lengths.tolist())
298 |
299 | # Reduce batch's padded length to maximum in-batch sequence
300 | # This saves some compute on nn.Linear layers (RNNs are unaffected, since they don't compute over the pads)
301 | wmaps = wmaps[:, :max_word_len].to(device)
302 | cmaps_f = cmaps_f[:, :max_char_len].to(device)
303 | cmaps_b = cmaps_b[:, :max_char_len].to(device)
304 | cmarkers_f = cmarkers_f[:, :max_word_len].to(device)
305 | cmarkers_b = cmarkers_b[:, :max_word_len].to(device)
306 | tmaps = tmaps[:, :max_word_len].to(device)
307 | wmap_lengths = wmap_lengths.to(device)
308 | cmap_lengths = cmap_lengths.to(device)
309 |
310 | # Forward prop.
311 | crf_scores, wmaps_sorted, tmaps_sorted, wmap_lengths_sorted, _, __ = model(cmaps_f,
312 | cmaps_b,
313 | cmarkers_f,
314 | cmarkers_b,
315 | wmaps,
316 | tmaps,
317 | wmap_lengths,
318 | cmap_lengths)
319 |
320 | # Viterbi / CRF layer loss
321 | vb_loss = crf_criterion(crf_scores, tmaps_sorted, wmap_lengths_sorted)
322 |
323 | # Viterbi decode to find accuracy / f1
324 | decoded = vb_decoder.decode(crf_scores.to("cpu"), wmap_lengths_sorted.to("cpu"))
325 |
326 | # Remove timesteps we won't predict at, and also <end> tags, because to predict them would be cheating
327 | decoded = pack_padded_sequence(decoded, (wmap_lengths_sorted - 1).tolist(), batch_first=True).data
328 | tmaps_sorted = tmaps_sorted % vb_decoder.tagset_size # actual target indices (see create_input_tensors())
329 | tmaps_sorted = pack_padded_sequence(tmaps_sorted, (wmap_lengths_sorted - 1).tolist(), batch_first=True).data
330 |
331 | # f1
332 | f1 = f1_score(tmaps_sorted.to("cpu").numpy(), decoded.numpy(), average='macro')
333 |
334 | # Keep track of metrics
335 | vb_losses.update(vb_loss.item(), crf_scores.size(0))
336 | f1s.update(f1, sum((wmap_lengths_sorted - 1).tolist()))
337 | batch_time.update(time.time() - start)
338 |
339 | start = time.time()
340 |
341 | if i % print_freq == 0:
342 | print('Validation: [{0}/{1}]\t'
343 | 'Batch Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
344 | 'VB Loss {vb_loss.val:.4f} ({vb_loss.avg:.4f})\t'
345 | 'F1 Score {f1.val:.3f} ({f1.avg:.3f})\t'.format(i, len(val_loader), batch_time=batch_time,
346 | vb_loss=vb_losses, f1=f1s))
347 |
348 | print(
349 | '\n * LOSS - {vb_loss.avg:.3f}, F1 SCORE - {f1.avg:.3f}\n'.format(vb_loss=vb_losses,
350 | f1=f1s))
351 |
352 | return f1s.avg
353 |
354 |
355 | if __name__ == '__main__':
356 | main()
357 |
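The language-model targets in `train()` are easy to misalign, so here is a pure-Python sketch of the same bookkeeping for a sorted two-sentence batch, assuming hypothetical ids `<pad>`=0 and `<end>`=2 (the real ids come from `create_maps()`); the function names are illustrative only. `packed_order()` reproduces the time-major order in which `pack_padded_sequence(..., batch_first=True).data` lays out a batch sorted by decreasing length, which is why scores and targets packed with the same `lm_lengths` line up element for element.

```python
PAD, END = 0, 2  # hypothetical ids; create_maps() assigns the real ones


def packed_order(rows, lengths):
    """Time-major flattening used by PackedSequence.data for rows sorted by decreasing length."""
    return [rows[b][t] for t in range(max(lengths)) for b in range(len(rows)) if lengths[b] > t]


def lm_targets(wmaps_sorted, wmap_lengths_sorted, end_id):
    """Forward and backward LM targets as built in train(), in packed order."""
    lm_lengths = [n - 1 for n in wmap_lengths_sorted]  # no prediction at <end> or at pads
    forward_rows = [row[1:] for row in wmaps_sorted]  # next word, up to <end>
    backward_rows = [[end_id] + row for row in wmaps_sorted]  # previous word, <end> before the first word
    return packed_order(forward_rows, lm_lengths), packed_order(backward_rows, lm_lengths)


if __name__ == "__main__":
    wmaps = [[10, 11, 12, END], [20, 21, END, PAD]]  # lengths include the appended <end>
    fwd, bwd = lm_targets(wmaps, [4, 3], END)
    print(fwd)  # [11, 21, 12, 2, 2]
    print(bwd)  # [2, 2, 10, 20, 11]
```

In the packed forward targets, the first sentence's three predictions (11, 12, `<end>`) interleave with the second sentence's two (21, `<end>`) timestep by timestep, exactly as the packed LM scores do.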
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | from collections import Counter
2 | import codecs
3 | import itertools
4 | from functools import reduce
5 | import numpy as np
6 | import torch
7 | import torch.nn as nn
8 | import torch.nn.init
9 | from torch.nn.utils.rnn import pack_padded_sequence
10 |
11 |
12 | def read_words_tags(file, tag_ind, caseless=True):
13 | """
14 | Reads raw data in the CoNLL 2003 format and returns word and tag sequences.
15 |
16 | :param file: file with raw data in the CoNLL 2003 format
17 | :param tag_ind: column index of tag
18 | :param caseless: lowercase words?
19 | :return: word, tag sequences
20 | """
21 | with codecs.open(file, 'r', 'utf-8') as f:
22 | lines = f.readlines()
23 | words = []
24 | tags = []
25 | temp_w = []
26 | temp_t = []
27 | for line in lines:
28 | if not (line.isspace() or (len(line) > 10 and line[0:10] == '-DOCSTART-')):
29 | feats = line.rstrip('\n').split()
30 | temp_w.append(feats[0].lower() if caseless else feats[0])
31 | temp_t.append(feats[tag_ind])
32 | elif len(temp_w) > 0:
33 | assert len(temp_w) == len(temp_t)
34 | words.append(temp_w)
35 | tags.append(temp_t)
36 | temp_w = []
37 | temp_t = []
38 | # last sentence
39 | if len(temp_w) > 0:
40 | assert len(temp_w) == len(temp_t)
41 | words.append(temp_w)
42 | tags.append(temp_t)
43 |
44 | # Sanity check
45 | assert len(words) == len(tags)
46 |
47 | return words, tags
48 |
49 |
50 | def create_maps(words, tags, min_word_freq=5, min_char_freq=1):
51 | """
52 | Creates word, char, tag maps.
53 |
54 | :param words: word sequences
55 | :param tags: tag sequences
56 | :param min_word_freq: words that occur this many times or fewer are binned as <unk>s
57 | :param min_char_freq: characters that occur this many times or fewer are binned as <unk>s
58 | :return: word, char, tag maps
59 | """
60 | word_freq = Counter()
61 | char_freq = Counter()
62 | tag_map = set()
63 | for w, t in zip(words, tags):
64 | word_freq.update(w)
65 | char_freq.update(list(reduce(lambda x, y: list(x) + [' '] + list(y), w)))
66 | tag_map.update(t)
67 |
68 | word_map = {k: v + 1 for v, k in enumerate([w for w in word_freq.keys() if word_freq[w] > min_word_freq])}
69 | char_map = {k: v + 1 for v, k in enumerate([c for c in char_freq.keys() if char_freq[c] > min_char_freq])}
70 | tag_map = {k: v + 1 for v, k in enumerate(tag_map)}
71 |
72 | word_map['<pad>'] = 0
73 | word_map['<end>'] = len(word_map)
74 | word_map['<unk>'] = len(word_map)
75 | char_map['<pad>'] = 0
76 | char_map['<end>'] = len(char_map)
77 | char_map['<unk>'] = len(char_map)
78 | tag_map['<pad>'] = 0
79 | tag_map['<start>'] = len(tag_map)
80 | tag_map['<end>'] = len(tag_map)
81 |
82 | return word_map, char_map, tag_map
83 |
84 |
85 | def create_input_tensors(words, tags, word_map, char_map, tag_map):
86 | """
87 | Creates input tensors that will be used to create a PyTorch Dataset.
88 |
89 | :param words: word sequences
90 | :param tags: tag sequences
91 | :param word_map: word map
92 | :param char_map: character map
93 | :param tag_map: tag map
94 | :return: padded encoded words, padded encoded forward chars, padded encoded backward chars,
95 | padded forward character markers, padded backward character markers, padded encoded tags,
96 | word sequence lengths, char sequence lengths
97 | """
98 | # Encode sentences into word maps with <end> at the end
99 | # [['dunston', 'checks', 'in', '<end>']] -> [[4670, 4670, 185, 4669]]
100 | wmaps = list(map(lambda s: list(map(lambda w: word_map.get(w, word_map['<unk>']), s)) + [word_map['<end>']], words))
101 |
102 | # Forward and backward character streams
103 | # [['d', 'u', 'n', 's', 't', 'o', 'n', ' ', 'c', 'h', 'e', 'c', 'k', 's', ' ', 'i', 'n', ' ']]
104 | chars_f = list(map(lambda s: list(reduce(lambda x, y: list(x) + [' '] + list(y), s)) + [' '], words))
105 | # [['n', 'i', ' ', 's', 'k', 'c', 'e', 'h', 'c', ' ', 'n', 'o', 't', 's', 'n', 'u', 'd', ' ']]
106 | chars_b = list(
107 | map(lambda s: list(reversed([' '] + list(reduce(lambda x, y: list(x) + [' '] + list(y), s)))), words))
108 |
109 | # Encode streams into forward and backward character maps with <end> at the end
110 | # [[29, 2, 12, 8, 7, 14, 12, 3, 6, 18, 1, 6, 21, 8, 3, 17, 12, 3, 60]]
111 | cmaps_f = list(
112 | map(lambda s: list(map(lambda c: char_map.get(c, char_map['<unk>']), s)) + [char_map['<end>']], chars_f))
113 | # [[12, 17, 3, 8, 21, 6, 1, 18, 6, 3, 12, 14, 7, 8, 12, 2, 29, 3, 60]]
114 | cmaps_b = list(
115 | map(lambda s: list(map(lambda c: char_map.get(c, char_map['<unk>']), s)) + [char_map['<end>']], chars_b))
116 |
117 | # Positions of spaces and <end> character
118 | # Words are predicted or encoded at these places in the language and tagging models respectively
119 | # [[7, 14, 17, 18]] are points after '...dunston', '...checks', '...in', '...<end>' respectively
120 | cmarkers_f = list(map(lambda s: [ind for ind in range(len(s)) if s[ind] == char_map[' ']] + [len(s) - 1], cmaps_f))
121 | # Reverse the markers for the backward stream before adding <end>, so the words of the f and b markers coincide
122 | # i.e., [[17, 9, 2, 18]] are points after '...notsnud', '...skcehc', '...ni', '...<end>' respectively
123 | cmarkers_b = list(
124 | map(lambda s: list(reversed([ind for ind in range(len(s)) if s[ind] == char_map[' ']])) + [len(s) - 1],
125 | cmaps_b))
126 |
127 | # Encode tags into tag maps with <end> at the end
128 | tmaps = list(map(lambda s: list(map(lambda t: tag_map[t], s)) + [tag_map['<end>']], tags))
129 | # Since we're using CRF scores of size (prev_tags, cur_tags), find indices of target sequence in the unrolled scores
130 | # This will be row_index (i.e. prev_tag) * n_columns (i.e. tagset_size) + column_index (i.e. cur_tag)
131 | tmaps = list(map(lambda s: [tag_map['<start>'] * len(tag_map) + s[0]] + [s[i - 1] * len(tag_map) + s[i] for i in
132 | range(1, len(s))], tmaps))
133 | # Note - the actual tag indices can be recovered with tmaps % len(tag_map)
134 |
135 | # Pad, because we need a fixed length to be passed around by DataLoaders and other layers
136 | word_pad_len = max(list(map(lambda s: len(s), wmaps)))
137 | char_pad_len = max(list(map(lambda s: len(s), cmaps_f)))
138 |
139 | # Sanity check
140 | assert word_pad_len == max(list(map(lambda s: len(s), tmaps)))
141 |
142 | padded_wmaps = []
143 | padded_cmaps_f = []
144 | padded_cmaps_b = []
145 | padded_cmarkers_f = []
146 | padded_cmarkers_b = []
147 | padded_tmaps = []
148 | wmap_lengths = []
149 | cmap_lengths = []
150 |
151 | for w, cf, cb, cmf, cmb, t in zip(wmaps, cmaps_f, cmaps_b, cmarkers_f, cmarkers_b, tmaps):
152 | # Sanity checks
153 | assert len(w) == len(cmf) == len(cmb) == len(t)
154 | assert len(cf) == len(cb)
155 |
156 | # Pad
157 | # A note - it doesn't really matter what we pad with, as long as it's a valid index
158 | # i.e., we'll extract output at those pad points (to extract equal lengths), but never use them
159 |
160 | padded_wmaps.append(w + [word_map['<pad>']] * (word_pad_len - len(w)))
161 | padded_cmaps_f.append(cf + [char_map['<pad>']] * (char_pad_len - len(cf)))
162 | padded_cmaps_b.append(cb + [char_map['<pad>']] * (char_pad_len - len(cb)))
163 |
164 | # 0 is always a valid index to pad markers with (-1 is too but torch.gather has some issues with it)
165 | padded_cmarkers_f.append(cmf + [0] * (word_pad_len - len(w)))
166 | padded_cmarkers_b.append(cmb + [0] * (word_pad_len - len(w)))
167 |
168 | padded_tmaps.append(t + [tag_map['<pad>']] * (word_pad_len - len(t)))
169 |
170 | wmap_lengths.append(len(w))
171 | cmap_lengths.append(len(cf))
172 |
173 | # Sanity check
174 | assert len(padded_wmaps[-1]) == len(padded_tmaps[-1]) == len(padded_cmarkers_f[-1]) == len(
175 | padded_cmarkers_b[-1]) == word_pad_len
176 | assert len(padded_cmaps_f[-1]) == len(padded_cmaps_b[-1]) == char_pad_len
177 |
178 | padded_wmaps = torch.LongTensor(padded_wmaps)
179 | padded_cmaps_f = torch.LongTensor(padded_cmaps_f)
180 | padded_cmaps_b = torch.LongTensor(padded_cmaps_b)
181 | padded_cmarkers_f = torch.LongTensor(padded_cmarkers_f)
182 | padded_cmarkers_b = torch.LongTensor(padded_cmarkers_b)
183 | padded_tmaps = torch.LongTensor(padded_tmaps)
184 | wmap_lengths = torch.LongTensor(wmap_lengths)
185 | cmap_lengths = torch.LongTensor(cmap_lengths)
186 |
187 | return padded_wmaps, padded_cmaps_f, padded_cmaps_b, padded_cmarkers_f, padded_cmarkers_b, padded_tmaps, \
188 | wmap_lengths, cmap_lengths
189 |
190 |
191 | def init_embedding(input_embedding):
192 | """
193 | Initialize embedding tensor with values from the uniform distribution.
194 |
195 | :param input_embedding: embedding tensor
196 | :return:
197 | """
198 | bias = np.sqrt(3.0 / input_embedding.size(1))
199 | nn.init.uniform_(input_embedding, -bias, bias)
200 |
201 |
202 | def load_embeddings(emb_file, word_map, expand_vocab=True):
203 | """
204 | Load pre-trained embeddings for words in the word map.
205 |
206 | :param emb_file: file with pre-trained embeddings (in the GloVe format)
207 | :param word_map: word map
208 | :param expand_vocab: expand vocabulary of word map to vocabulary of pre-trained embeddings?
209 | :return: embeddings for words in word map, (possibly expanded) word map,
210 | number of words in word map that are in-corpus (subject to word frequency threshold)
211 | """
212 | with open(emb_file, 'r') as f:
213 | emb_len = len(f.readline().split(' ')) - 1
214 |
215 | print("Embedding length is %d." % emb_len)
216 |
217 | # Create tensor to hold embeddings for words that are in-corpus
218 | ic_embs = torch.FloatTensor(len(word_map), emb_len)
219 | init_embedding(ic_embs)
220 |
221 | if expand_vocab:
222 | print("You have elected to include embeddings that are out-of-corpus.")
223 | ooc_words = []
224 | ooc_embs = []
225 | else:
226 | print("You have elected NOT to include embeddings that are out-of-corpus.")
227 |
228 | # Read embedding file
229 | print("\nLoading embeddings...")
230 | for line in open(emb_file, 'r'):
231 | line = line.split(' ')
232 |
233 | emb_word = line[0]
234 | embedding = list(map(lambda t: float(t), filter(lambda n: n and not n.isspace(), line[1:])))
235 |
236 | if not expand_vocab and emb_word not in word_map:
237 | continue
238 |
239 | # If word is in train_vocab, store at the correct index (as in the word_map)
240 | if emb_word in word_map:
241 | ic_embs[word_map[emb_word]] = torch.FloatTensor(embedding)
242 |
243 | # If word is not in the word map (i.e., it is out-of-corpus), store it and its embedding into lists
244 | elif expand_vocab:
245 | ooc_words.append(emb_word)
246 | ooc_embs.append(embedding)
247 |
248 | lm_vocab_size = len(word_map) # keep track of lang. model's output vocab size (no out-of-corpus words)
249 |
250 | if expand_vocab:
251 | print("'word_map' is being updated accordingly.")
252 | for word in ooc_words:
253 | word_map[word] = len(word_map)
254 | ooc_embs = torch.FloatTensor(np.asarray(ooc_embs))
255 | embeddings = torch.cat([ic_embs, ooc_embs], 0)
256 |
257 | else:
258 | embeddings = ic_embs
259 |
260 | # Sanity check
261 | assert embeddings.size(0) == len(word_map)
262 |
263 | print("\nDone.\n Embedding vocabulary: %d\n Language Model vocabulary: %d.\n" % (len(word_map), lm_vocab_size))
264 |
265 | return embeddings, word_map, lm_vocab_size
266 |
267 |
268 | def clip_gradient(optimizer, grad_clip):
269 | """
270 | Clip gradients computed during backpropagation to prevent gradient explosion.
271 |
272 | :param optimizer: optimizer with the gradients to be clipped
273 | :param grad_clip: gradient clip value
274 | """
275 | for group in optimizer.param_groups:
276 | for param in group['params']:
277 | if param.grad is not None:
278 | param.grad.data.clamp_(-grad_clip, grad_clip)
279 |
280 |
281 | def save_checkpoint(epoch, model, optimizer, val_f1, word_map, char_map, tag_map, lm_vocab_size, is_best):
282 | """
283 | Save model checkpoint.
284 |
285 | :param epoch: epoch number
286 | :param model: model
287 | :param optimizer: optimizer
288 | :param val_f1: validation F1 score
289 | :param word_map: word map
290 | :param char_map: char map
291 | :param tag_map: tag map
292 | :param lm_vocab_size: number of words in-corpus, i.e. size of the output vocabulary of the language model
293 | :param is_best: is this checkpoint the best so far?
294 | :return:
295 | """
296 | state = {'epoch': epoch,
297 | 'f1': val_f1,
298 | 'model': model,
299 | 'optimizer': optimizer,
300 | 'word_map': word_map,
301 | 'tag_map': tag_map,
302 | 'char_map': char_map,
303 | 'lm_vocab_size': lm_vocab_size}
304 | filename = 'checkpoint_lm_lstm_crf.pth.tar'
305 | torch.save(state, filename)
306 | # If checkpoint is the best so far, create a copy to avoid being overwritten by a subsequent worse checkpoint
307 | if is_best:
308 | torch.save(state, 'BEST_' + filename)
309 |
310 |
311 | class AverageMeter(object):
312 | """
313 | Keeps track of most recent, average, sum, and count of a metric.
314 | """
315 |
316 | def __init__(self):
317 | self.reset()
318 |
319 | def reset(self):
320 | self.val = 0
321 | self.avg = 0
322 | self.sum = 0
323 | self.count = 0
324 |
325 | def update(self, val, n=1):
326 | self.val = val
327 | self.sum += val * n
328 | self.count += n
329 | self.avg = self.sum / self.count
330 |
331 |
332 | def adjust_learning_rate(optimizer, new_lr):
333 | """
334 | Sets the learning rate of every parameter group to a new (typically decayed) value.
335 |
336 | :param optimizer: optimizer whose learning rates must be decayed
337 | :param new_lr: new learning rate
338 | """
339 |
340 | print("\nDECAYING learning rate.")
341 | for param_group in optimizer.param_groups:
342 | param_group['lr'] = new_lr
343 | print("The new learning rate is %f\n" % (optimizer.param_groups[0]['lr'],))
344 |
345 |
346 | def log_sum_exp(tensor, dim):
347 | """
348 | Calculates the log-sum-exponent of a tensor's dimension in a numerically stable way.
349 |
350 | :param tensor: tensor
351 | :param dim: dimension to calculate log-sum-exp of
352 | :return: log-sum-exp
353 | """
354 | m, _ = torch.max(tensor, dim)
355 | m_expanded = m.unsqueeze(dim).expand_as(tensor)
356 | return m + torch.log(torch.sum(torch.exp(tensor - m_expanded), dim))
357 |
--------------------------------------------------------------------------------
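A minimal sketch of the special-token layout built at the top of this excerpt and of the word encoding in `create_input_tensors`, using a hypothetical three-word vocabulary. `build_word_map` and `encode_sentence` are illustrative names, not functions from this repository. The sketch assumes the order `<pad>` = 0, then regular words, then `<end>` and `<unk>`, which is consistent with the comment example `[[4670, 4670, 185, 4669]]`: both unknown words map to `<unk>`, and every sentence is terminated by `<end>`.

```python
def build_word_map(vocab):
    # Illustrative only: regular words start at 1, '<pad>' takes 0,
    # then '<end>' and '<unk>' are appended after the regular words.
    word_map = {w: i + 1 for i, w in enumerate(vocab)}
    word_map['<pad>'] = 0
    word_map['<end>'] = len(word_map)
    word_map['<unk>'] = len(word_map)
    return word_map


def encode_sentence(sentence, word_map):
    # Unknown words fall back to '<unk>'; every sentence gets a trailing '<end>'.
    return [word_map.get(w, word_map['<unk>']) for w in sentence] + [word_map['<end>']]


word_map = build_word_map(['in', 'the', 'house'])
print(encode_sentence(['dunston', 'checks', 'in'], word_map))  # [5, 5, 1, 4]
```

As in the real example, `<unk>` ends up one index above `<end>` because it is added last.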
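The character streams and markers in `create_input_tensors` are easiest to follow on the tutorial's own example sentence. The sketch below works on raw characters rather than encoded indices. It also appends a literal `'<end>'` string where the real code appends `char_map['<end>']`. `streams_and_markers` is a hypothetical helper name. It reproduces the marker lists quoted in the comments.

```python
from functools import reduce

END = '<end>'  # stands in for char_map['<end>'] in this character-level sketch


def streams_and_markers(sentence):
    # ['dunston', 'checks'] -> ['d', 'u', ..., 'n', ' ', 'c', ..., 's']
    chars = list(reduce(lambda x, y: list(x) + [' '] + list(y), sentence))
    # Forward stream: characters, a trailing space, then <end>
    forward = chars + [' '] + [END]
    # Backward stream: leading space + characters, reversed, then <end>
    backward = list(reversed([' '] + chars)) + [END]
    # Forward markers: every space position, then the <end> position
    fwd_markers = [i for i, c in enumerate(forward) if c == ' '] + [len(forward) - 1]
    # Backward markers: spaces in reverse order, so the k-th marker in both
    # lists refers to the same word, then the <end> position
    bwd_markers = list(reversed([i for i, c in enumerate(backward) if c == ' '])) + [len(backward) - 1]
    return forward, backward, fwd_markers, bwd_markers


f, b, mf, mb = streams_and_markers(['dunston', 'checks', 'in'])
print(''.join(b[:-1]))  # ni skcehc notsnud
print(mf, mb)           # [7, 14, 17, 18] [17, 9, 2, 18]
```

Reversing the backward marker list is what aligns the two streams: index 17 in the backward stream comes right after `notsnud`, the reversed form of the first word, just as index 7 in the forward stream comes right after `dunston`.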
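The unrolled CRF targets in `create_input_tensors` pack each (previous tag, current tag) pair into a single index, `prev * n_tags + cur`. That index addresses a row-major flattened `(n_tags, n_tags)` score matrix. The small sketch below uses a made-up three-tag set, and `build_tag_map` and `unrolled_targets` are illustrative names. It shows the encoding and how `% n_tags` and `// n_tags` recover the current and previous tags.

```python
def build_tag_map(tags):
    # Same layout as tag_map above: '<pad>' = 0, real tags from 1, then '<start>', '<end>'
    tag_map = {t: i + 1 for i, t in enumerate(tags)}
    tag_map['<pad>'] = 0
    tag_map['<start>'] = len(tag_map)
    tag_map['<end>'] = len(tag_map)
    return tag_map


def unrolled_targets(tag_seq, tag_map):
    n = len(tag_map)
    t = [tag_map[tag] for tag in tag_seq] + [tag_map['<end>']]
    # The first transition starts from '<start>', later ones from the previous tag
    return [tag_map['<start>'] * n + t[0]] + [t[i - 1] * n + t[i] for i in range(1, len(t))]


tag_map = build_tag_map(['O', 'B-PER', 'I-PER'])
targets = unrolled_targets(['B-PER', 'O', 'O'], tag_map)
n_tags = len(tag_map)
print(targets)                         # [26, 13, 7, 11]
print([x % n_tags for x in targets])   # current tags:  [2, 1, 1, 5]
print([x // n_tags for x in targets])  # previous tags: [4, 2, 1, 1]
```

Here `n_tags = 6`, counting three real tags plus `<pad>`, `<start>` and `<end>`. The first target, `26 = 4 * 6 + 2`, therefore encodes the transition from `<start>` (4) to `B-PER` (2).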
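One way to read the bound in `init_embedding` is as a variance target. Take $b = \sqrt{3/d}$, where $d$ is `input_embedding.size(1)` (the embedding dimension), and let $x \sim U(-b, b)$ be one entry of an embedding vector $\mathbf{e}$. The variance of a uniform distribution then gives each entry a variance of $1/d$. Since the entries have mean zero, the expected squared norm of the whole vector is 1, whatever the value of $d$:

$$
\operatorname{Var}[x] = \frac{\big(b - (-b)\big)^2}{12} = \frac{b^2}{3} = \frac{3/d}{3} = \frac{1}{d},
\qquad
\mathbb{E}\big[\lVert \mathbf{e} \rVert^2\big] = d \cdot \frac{1}{d} = 1
$$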
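A pure-Python sketch of the vocabulary-expansion bookkeeping in `load_embeddings` with `expand_vocab=True`. It uses lists instead of tensors and an in-memory list of GloVe-format lines instead of a file. `expand_word_map` and the toy data are hypothetical.

The sketch illustrates why `lm_vocab_size` is recorded before the word map is expanded. Out-of-corpus words are appended after every in-corpus index, so the first `lm_vocab_size` rows still line up with the language model's output vocabulary.

```python
def expand_word_map(glove_lines, word_map):
    """Returns (rows, expanded_word_map, lm_vocab_size).

    rows[i] is the pre-trained vector for index i, or None where the real code
    keeps the random initialisation from init_embedding.
    """
    word_map = dict(word_map)          # work on a copy in this sketch
    lm_vocab_size = len(word_map)      # in-corpus vocabulary size, fixed before expansion
    rows = [None] * lm_vocab_size
    ooc_words, ooc_rows = [], []
    for line in glove_lines:
        parts = line.split(' ')
        word = parts[0]
        # Same filtering as load_embeddings: drop empty / whitespace-only tokens
        vector = [float(t) for t in parts[1:] if t and not t.isspace()]
        if word in word_map:
            rows[word_map[word]] = vector  # in-corpus: keep the word's existing index
        else:
            ooc_words.append(word)         # out-of-corpus: appended after all in-corpus rows
            ooc_rows.append(vector)
    for word in ooc_words:
        word_map[word] = len(word_map)
    return rows + ooc_rows, word_map, lm_vocab_size


toy_map = {'<pad>': 0, 'in': 1, '<end>': 2, '<unk>': 3}
glove = ["in 0.1 0.2\n", "house 0.3 0.4\n", "the 0.5 0.6\n"]
rows, expanded, lm_vocab_size = expand_word_map(glove, toy_map)
print(lm_vocab_size, expanded['house'], expanded['the'])  # 4 4 5
```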
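`log_sum_exp` subtracts the maximum before exponentiating, so the largest term becomes exp(0) = 1 and nothing overflows. The sketch below applies the same identity to plain Python floats: log Σ exp(x_i) = m + log Σ exp(x_i − m), with m = max x_i. `stable_log_sum_exp` and `naive_log_sum_exp` are illustrative names, and the repository's version reduces along a tensor dimension rather than over a list.

```python
import math


def stable_log_sum_exp(values):
    # Shift by the max so every exponent is <= 0 and exp() cannot overflow
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values))


def naive_log_sum_exp(values):
    # Direct formula: math.exp(1000.0) raises OverflowError
    return math.log(sum(math.exp(v) for v in values))


print(stable_log_sum_exp([1000.0, 1000.0]))  # 1000 + log(2)
try:
    naive_log_sum_exp([1000.0, 1000.0])
except OverflowError:
    print("naive version overflows")
```

Recent PyTorch releases also provide `torch.logsumexp(input, dim)`, which performs the same reduction.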