├── models
│   └── .keep
├── images
│   ├── ner-image.png
│   └── bilstm-crf.png
├── floyd.yml
├── README.md
└── ner.ipynb

/models/.keep:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/images/ner-image.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/floydhub/named-entity-recognition-template/HEAD/images/ner-image.png
--------------------------------------------------------------------------------
/images/bilstm-crf.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/floydhub/named-entity-recognition-template/HEAD/images/bilstm-crf.png
--------------------------------------------------------------------------------
/floyd.yml:
--------------------------------------------------------------------------------
1 | env: tensorflow-1.7
2 | machine: cpu
3 | data:
4 |   - source: floydhub/datasets/ner/1
5 |     destination: ner
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Named Entity Recognition
2 | 
3 | [Named Entity Recognition](https://en.wikipedia.org/wiki/Named-entity_recognition) is one of the most common [NLP](https://en.wikipedia.org/wiki/Natural-language_processing) problems. The goal is to classify named entities in text into pre-defined categories such as the names of persons, organizations, locations, expressions of time, quantities, monetary values, percentages, etc.
4 | *What can you use it for?* Here are a few ideas - social media, chatbots, customer support tickets, survey responses, and data mining!
5 | 
6 | ### Try it now
7 | 
8 | [![Run on FloydHub](https://static.floydhub.com/button/button.svg)](https://floydhub.com/run?template=https://github.com/floydhub/named-entity-recognition-template)
9 | 
10 | Click this button to open a Workspace on FloydHub that will train this model.
11 | 
12 | ### Predicting named entities of the GMB (Groningen Meaning Bank) corpus
13 | 
14 | In this notebook we will perform [sequence tagging with an LSTM-CRF model](https://www.depends-on-the-definition.com/sequence-tagging-lstm-crf/) to extract the named entities from the annotated corpus.
15 | 
16 | ![ner-image](images/ner-image.png)
17 | 
18 | Entity tags are encoded using a BIO annotation scheme, where each entity label is prefixed with either the letter B or I. B- denotes the beginning of an entity and I- a token inside it. The prefixes are used to detect multiword entities, e.g. the sentence "World War II" gets the tags (B-eve, I-eve, I-eve). All other words, which don’t refer to entities of interest, are labeled with the O tag (a short worked example appears at the end of this README).
19 | 
20 | Tag | Label meaning | Example Given
21 | --- | ------------- | -------------
22 | geo | Geographical Entity | London
23 | org | Organization | ONU
24 | per | Person | Bush
25 | gpe | Geopolitical Entity | British
26 | tim | Time indicator | Wednesday
27 | art | Artifact | Chrysler
28 | eve | Event | Christmas
29 | nat | Natural Phenomenon | Hurricane
30 | O | No-Label | the
31 | 
32 | We will:
33 | - Preprocess text data for NLP
34 | - Build and train a Bi-directional LSTM-CRF model using Keras and TensorFlow
35 | - Evaluate our model on the test set
36 | - Run the model on your own sentences!
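
#### A quick look at BIO tags

To make the annotation scheme concrete, here is a minimal, self-contained sketch (illustrative only - it is not part of the notebook, and the variable names are ours) of how BIO tags pair with tokens and how entity spans are recovered from them:

```python
# A BIO-tagged fragment: "World War II" is a single multiword event (eve) entity.
tokens = ["The", "end", "of", "World", "War", "II"]
tags   = ["O",   "O",   "O",  "B-eve", "I-eve", "I-eve"]

# Reconstruct entity spans from the BIO tags.
entities, current = [], []
for token, tag in zip(tokens, tags):
    if tag.startswith("B-"):                # a new entity begins
        current = [token]
        entities.append((tag[2:], current))
    elif tag.startswith("I-") and current:  # continue the open entity
        current.append(token)
    else:                                   # "O" closes any open entity
        current = []

print([(label, " ".join(words)) for label, words in entities])
# -> [('eve', 'World War II')]
```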
37 | 
--------------------------------------------------------------------------------
/ner.ipynb:
--------------------------------------------------------------------------------
1 | {
2 |  "cells": [
3 |   {
4 |    "cell_type": "markdown",
5 |    "metadata": {},
6 |    "source": [
7 |     "# Named Entity Recognition\n",
8 |     "\n",
9 |     "Hi 🙂, if you are seeing this notebook, you have successfully started your first project on FloydHub 🚀, hooray!!\n",
10 |     "\n",
11 |     "[Named Entity Recognition](https://en.wikipedia.org/wiki/Named-entity_recognition) is one of the most common [NLP](https://en.wikipedia.org/wiki/Natural-language_processing) problems. The goal is to classify named entities in text into pre-defined categories such as the names of persons, organizations, locations, expressions of time, quantities, monetary values, percentages, etc.\n",
12 |     "*What can you use it for?* Here are a few ideas - social media, chatbots, customer support tickets, survey responses, and data mining!\n",
13 |     "\n",
14 |     "### Predicting named entities of the GMB (Groningen Meaning Bank) corpus\n",
15 |     "\n",
16 |     "In this notebook we will perform [sequence tagging with an LSTM-CRF model](https://www.depends-on-the-definition.com/sequence-tagging-lstm-crf/) to extract the named entities from the annotated corpus.\n",
17 |     "\n",
18 |     "![ner-image](images/ner-image.png)\n",
19 |     "\n",
20 |     "Entity tags are encoded using a BIO annotation scheme, where each entity label is prefixed with either the letter B or I. B- denotes the beginning of an entity and I- a token inside it. The prefixes are used to detect multiword entities, e.g. the sentence \"World War II\" gets the tags (B-eve, I-eve, I-eve). All other words, which don’t refer to entities of interest, are labeled with the O tag.\n",
21 |     "\n",
22 |     "Tag | Label meaning | Example Given\n",
23 |     "--- | ------------- | -------------\n",
24 |     "geo | Geographical Entity | London\n",
25 |     "org | Organization | ONU\n",
26 |     "per | Person | Bush\n",
27 |     "gpe | Geopolitical Entity | British\n",
28 |     "tim | Time indicator | Wednesday\n",
29 |     "art | Artifact | Chrysler\n",
30 |     "eve | Event | Christmas\n",
31 |     "nat | Natural Phenomenon | Hurricane\n",
32 |     "O | No-Label | the\n",
33 |     "\n",
34 |     "We will:\n",
35 |     "- Preprocess text data for NLP\n",
36 |     "- Build and train a Bi-directional LSTM-CRF model using Keras and TensorFlow\n",
37 |     "- Evaluate our model on the test set\n",
38 |     "- Run the model on your own sentences!\n",
39 |     "\n",
40 |     "#### Instructions\n",
41 |     "- To execute a code cell, click on the cell and press `Shift + Enter` (shortcut for Run).\n",
42 |     "- To learn more about Workspaces, check out the [Getting Started Notebook](get_started_workspace.ipynb).\n",
43 |     "- **Tip**: *Feel free to try this notebook with your own data and on your own super awesome named entity recognition task.*\n",
44 |     "\n",
45 |     "Now, let's get started! 🚀\n",
46 |     "\n",
47 |     "### Initial Setup\n",
48 |     "Let's start by importing some packages."
49 |    ]
50 |   },
51 |   {
52 |    "cell_type": "code",
53 |    "execution_count": 1,
54 |    "metadata": {},
55 |    "outputs": [],
56 |    "source": [
57 |     "# Install extra dependencies\n",
58 |     "! pip -q install git+https://github.com/keras-team/keras-contrib.git sklearn-crfsuite"
59 |    ]
60 |   },
61 |   {
62 |    "cell_type": "code",
63 |    "execution_count": 2,
64 |    "metadata": {},
65 |    "outputs": [
66 |     {
67 |      "name": "stderr",
68 |      "output_type": "stream",
69 |      "text": [
70 |       "Using TensorFlow backend.\n"
71 |      ]
72 |     }
73 |    ],
74 |    "source": [
75 |     "import tensorflow as tf\n",
76 |     "import keras\n",
77 |     "\n",
78 |     "import pandas as pd\n",
79 |     "import numpy as np\n",
80 |     "import matplotlib.pyplot as plt"
81 |    ]
82 |   },
83 |   {
84 |    "cell_type": "markdown",
85 |    "metadata": {},
86 |    "source": [
87 |     "## Training Parameters\n",
88 |     "We'll set the hyperparameters for training our model. If you understand what they mean, feel free to play around - otherwise, we recommend keeping the defaults for your first run 🙂"
89 |    ]
90 |   },
91 |   {
92 |    "cell_type": "code",
93 |    "execution_count": 3,
94 |    "metadata": {},
95 |    "outputs": [],
96 |    "source": [
97 |     "# Hyperparams if GPU is available\n",
98 |     "if tf.test.is_gpu_available():\n",
99 |     "    BATCH_SIZE = 512  # Number of examples used in each iteration\n",
100 |     "    EPOCHS = 5  # Number of passes through the entire dataset\n",
101 |     "    MAX_LEN = 75  # Max length of a sentence (in words)\n",
102 |     "    EMBEDDING = 40  # Dimension of the word embedding vector\n",
103 |     "\n",
104 |     "\n",
105 |     "# Hyperparams for CPU training\n",
106 |     "else:\n",
107 |     "    BATCH_SIZE = 32\n",
108 |     "    EPOCHS = 5\n",
109 |     "    MAX_LEN = 75\n",
110 |     "    EMBEDDING = 20"
111 |    ]
112 |   },
113 |   {
114 |    "cell_type": "markdown",
115 |    "metadata": {},
116 |    "source": [
117 |     "## Data\n",
118 |     "\n",
119 |     "The NER dataset is already attached to your workspace (if you want to attach your own data, [check out our docs](https://docs.floydhub.com/guides/workspace/#attaching-floydhub-datasets)).\n",
120 |     "\n",
121 |     "Let's take a look at the data."
122 |    ]
123 |   },
124 |   {
125 |    "cell_type": "code",
126 |    "execution_count": 4,
127 |    "metadata": {},
128 |    "outputs": [
129 |     {
130 |      "name": "stdout",
131 |      "output_type": "stream",
132 |      "text": [
133 |       "Number of sentences:  47959\n",
134 |       "Number of words in the dataset:  35178\n",
135 |       "Tags: ['O', 'I-nat', 'I-eve', 'B-nat', 'I-art', 'B-gpe', 'I-org', 'I-gpe', 'B-per', 'I-per', 'B-eve', 'I-tim', 'B-geo', 'I-geo', 'B-org', 'B-art', 'B-tim']\n",
136 |       "Number of Labels:  17\n",
137 |       "What the dataset looks like:\n"
138 |      ]
139 |     },
140 |     {
141 |      "data": {
\n", 144 | "\n", 157 | "\n", 158 | " \n", 159 | " \n", 160 | " \n", 161 | " \n", 162 | " \n", 163 | " \n", 164 | " \n", 165 | " \n", 166 | " \n", 167 | " \n", 168 | " \n", 169 | " \n", 170 | " \n", 171 | " \n", 172 | " \n", 173 | " \n", 174 | " \n", 175 | " \n", 176 | " \n", 177 | " \n", 178 | " \n", 179 | " \n", 180 | " \n", 181 | " \n", 182 | " \n", 183 | " \n", 184 | " \n", 185 | " \n", 186 | " \n", 187 | " \n", 188 | " \n", 189 | " \n", 190 | " \n", 191 | " \n", 192 | " \n", 193 | " \n", 194 | " \n", 195 | " \n", 196 | " \n", 197 | " \n", 198 | " \n", 199 | " \n", 200 | " \n", 201 | " \n", 202 | " \n", 203 | " \n", 204 | " \n", 205 | " \n", 206 | " \n", 207 | " \n", 208 | " \n", 209 | " \n", 210 | " \n", 211 | " \n", 212 | " \n", 213 | " \n", 214 | " \n", 215 | " \n", 216 | " \n", 217 | " \n", 218 | " \n", 219 | " \n", 220 | " \n", 221 | " \n", 222 | " \n", 223 | " \n", 224 | " \n", 225 | " \n", 226 | " \n", 227 | " \n", 228 | " \n", 229 | " \n", 230 | " \n", 231 | " \n", 232 | " \n", 233 | " \n", 234 | " \n", 235 | " \n", 236 | " \n", 237 | " \n", 238 | " \n", 239 | "
Sentence #WordPOSTag
0Sentence: 1ThousandsNNSO
1Sentence: 1ofINO
2Sentence: 1demonstratorsNNSO
3Sentence: 1haveVBPO
4Sentence: 1marchedVBNO
5Sentence: 1throughINO
6Sentence: 1LondonNNPB-geo
7Sentence: 1toTOO
8Sentence: 1protestVBO
9Sentence: 1theDTO
\n", 240 | "
" 241 | ], 242 | "text/plain": [ 243 | " Sentence # Word POS Tag\n", 244 | "0 Sentence: 1 Thousands NNS O\n", 245 | "1 Sentence: 1 of IN O\n", 246 | "2 Sentence: 1 demonstrators NNS O\n", 247 | "3 Sentence: 1 have VBP O\n", 248 | "4 Sentence: 1 marched VBN O\n", 249 | "5 Sentence: 1 through IN O\n", 250 | "6 Sentence: 1 London NNP B-geo\n", 251 | "7 Sentence: 1 to TO O\n", 252 | "8 Sentence: 1 protest VB O\n", 253 | "9 Sentence: 1 the DT O" 254 | ] 255 | }, 256 | "execution_count": 4, 257 | "metadata": {}, 258 | "output_type": "execute_result" 259 | } 260 | ], 261 | "source": [ 262 | "data = pd.read_csv(\"/floyd/input/ner/ner_dataset.csv\", encoding=\"latin1\")\n", 263 | "data = data.fillna(method=\"ffill\")\n", 264 | "\n", 265 | "print(\"Number of sentences: \", len(data.groupby(['Sentence #'])))\n", 266 | "\n", 267 | "words = list(set(data[\"Word\"].values))\n", 268 | "n_words = len(words)\n", 269 | "print(\"Number of words in the dataset: \", n_words)\n", 270 | "\n", 271 | "tags = list(set(data[\"Tag\"].values))\n", 272 | "print(\"Tags:\", tags)\n", 273 | "n_tags = len(tags)\n", 274 | "print(\"Number of Labels: \", n_tags)\n", 275 | "\n", 276 | "print(\"What the dataset looks like:\")\n", 277 | "# Show the first 10 rows\n", 278 | "data.head(n=10) " 279 | ] 280 | }, 281 | { 282 | "cell_type": "code", 283 | "execution_count": 5, 284 | "metadata": {}, 285 | "outputs": [ 286 | { 287 | "name": "stdout", 288 | "output_type": "stream", 289 | "text": [ 290 | "This is what a sentence looks like:\n", 291 | "[('Thousands', 'NNS', 'O'), ('of', 'IN', 'O'), ('demonstrators', 'NNS', 'O'), ('have', 'VBP', 'O'), ('marched', 'VBN', 'O'), ('through', 'IN', 'O'), ('London', 'NNP', 'B-geo'), ('to', 'TO', 'O'), ('protest', 'VB', 'O'), ('the', 'DT', 'O'), ('war', 'NN', 'O'), ('in', 'IN', 'O'), ('Iraq', 'NNP', 'B-geo'), ('and', 'CC', 'O'), ('demand', 'VB', 'O'), ('the', 'DT', 'O'), ('withdrawal', 'NN', 'O'), ('of', 'IN', 'O'), ('British', 'JJ', 'B-gpe'), ('troops', 'NNS', 'O'), ('from', 'IN', 'O'), ('that', 'DT', 'O'), ('country', 'NN', 'O'), ('.', '.', 'O')]\n" 292 | ] 293 | } 294 | ], 295 | "source": [ 296 | "class SentenceGetter(object):\n", 297 | " \"\"\"Class to Get the sentence in this format:\n", 298 | " [(Token_1, Part_of_Speech_1, Tag_1), ..., (Token_n, Part_of_Speech_1, Tag_1)]\"\"\"\n", 299 | " def __init__(self, data):\n", 300 | " \"\"\"Args:\n", 301 | " data is the pandas.DataFrame which contains the above dataset\"\"\"\n", 302 | " self.n_sent = 1\n", 303 | " self.data = data\n", 304 | " self.empty = False\n", 305 | " agg_func = lambda s: [(w, p, t) for w, p, t in zip(s[\"Word\"].values.tolist(),\n", 306 | " s[\"POS\"].values.tolist(),\n", 307 | " s[\"Tag\"].values.tolist())]\n", 308 | " self.grouped = self.data.groupby(\"Sentence #\").apply(agg_func)\n", 309 | " self.sentences = [s for s in self.grouped]\n", 310 | " \n", 311 | " def get_next(self):\n", 312 | " \"\"\"Return one sentence\"\"\"\n", 313 | " try:\n", 314 | " s = self.grouped[\"Sentence: {}\".format(self.n_sent)]\n", 315 | " self.n_sent += 1\n", 316 | " return s\n", 317 | " except:\n", 318 | " return None\n", 319 | " \n", 320 | "getter = SentenceGetter(data)\n", 321 | "sent = getter.get_next()\n", 322 | "print('This is what a sentence looks like:')\n", 323 | "print(sent)" 324 | ] 325 | }, 326 | { 327 | "cell_type": "markdown", 328 | "metadata": {}, 329 | "source": [ 330 | "As you can see from the output Cell above, each sentence in the dataset is represented as a list of tuple: [`(Token_1, PoS_1, Tag_1)`, ..., `(Token_n, PoS_n, 
331 |    ]
332 |   },
333 |   {
334 |    "cell_type": "code",
335 |    "execution_count": 6,
336 |    "metadata": {},
337 |    "outputs": [
338 |     {
339 |      "data": {
340 |       "image/png": "<base64-encoded PNG omitted: the tokens-per-sentence histogram produced by the cell below>",
341 |       "text/plain": [
342 |        "<Figure size 432x288 with 1 Axes>"
" 343 | ] 344 | }, 345 | "metadata": {}, 346 | "output_type": "display_data" 347 | } 348 | ], 349 | "source": [ 350 | "# Get all the sentences\n", 351 | "sentences = getter.sentences\n", 352 | "\n", 353 | "# Plot sentence by lenght\n", 354 | "plt.hist([len(s) for s in sentences], bins=50)\n", 355 | "plt.title('Token per sentence')\n", 356 | "plt.xlabel('Len (number of token)')\n", 357 | "plt.ylabel('# samples')\n", 358 | "plt.show()" 359 | ] 360 | }, 361 | { 362 | "cell_type": "markdown", 363 | "metadata": {}, 364 | "source": [ 365 | "## Data Preprocessing\n", 366 | "\n", 367 | "Before feeding the data into the model, we have to preprocess the text.\n", 368 | "\n", 369 | "- We will use the `word2idx` dictionary to convert each word to a corresponding integer ID and the `tag2idx` to do the same for the labels. Representing words as integers saves a lot of memory!\n", 370 | "- In order to feed the text into our Bi-LSTM-CRF, all texts should be the same length. We ensure this using the `sequence.pad_sequences()` method and `MAX_LEN` variable. All texts longer than `MAX_LEN` are truncated and shorter texts are padded to get them to the same length.\n", 371 | "\n", 372 | "The *Tokens per sentence* plot (see above) is useful for setting the `MAX_LEN` training hyperparameter." 373 | ] 374 | }, 375 | { 376 | "cell_type": "code", 377 | "execution_count": 7, 378 | "metadata": {}, 379 | "outputs": [ 380 | { 381 | "name": "stdout", 382 | "output_type": "stream", 383 | "text": [ 384 | "The word Obama is identified by the index: 9881\n", 385 | "The labels B-geo(which defines Geopraphical Enitities) is identified by the index: 13\n", 386 | "Raw Sample: Thousands of demonstrators have marched through London to protest the war in Iraq and demand the withdrawal of British troops from that country .\n", 387 | "Raw Label: O O O O O O B-geo O O O O O B-geo O O O O O B-gpe O O O O O\n", 388 | "After processing, sample: [16817 7825 10253 17489 66 10783 7144 32555 6507 8582 7721 25544\n", 389 | " 28446 2656 25382 8582 32363 7825 23884 21607 15364 29850 18029 8610\n", 390 | " 0 0 0 0 0 0 0 0 0 0 0 0\n", 391 | " 0 0 0 0 0 0 0 0 0 0 0 0\n", 392 | " 0 0 0 0 0 0 0 0 0 0 0 0\n", 393 | " 0 0 0 0 0 0 0 0 0 0 0 0\n", 394 | " 0 0 0]\n", 395 | "After processing, labels: [[0. 1. 0. ... 0. 0. 0.]\n", 396 | " [0. 1. 0. ... 0. 0. 0.]\n", 397 | " [0. 1. 0. ... 0. 0. 0.]\n", 398 | " ...\n", 399 | " [1. 0. 0. ... 0. 0. 0.]\n", 400 | " [1. 0. 0. ... 0. 0. 0.]\n", 401 | " [1. 0. 0. ... 0. 0. 
402 |      ]
403 |     }
404 |    ],
405 |    "source": [
406 |     "# Vocabulary Key:word -> Value:token_index\n",
407 |     "# The first 2 entries are reserved for PAD and UNK\n",
408 |     "word2idx = {w: i + 2 for i, w in enumerate(words)}\n",
409 |     "word2idx[\"UNK\"] = 1  # Unknown words\n",
410 |     "word2idx[\"PAD\"] = 0  # Padding\n",
411 |     "\n",
412 |     "# Vocabulary Key:token_index -> Value:word\n",
413 |     "idx2word = {i: w for w, i in word2idx.items()}\n",
414 |     "\n",
415 |     "# Vocabulary Key:Label/Tag -> Value:tag_index\n",
416 |     "# The first entry is reserved for PAD\n",
417 |     "tag2idx = {t: i+1 for i, t in enumerate(tags)}\n",
418 |     "tag2idx[\"PAD\"] = 0\n",
419 |     "\n",
420 |     "# Vocabulary Key:tag_index -> Value:Label/Tag\n",
421 |     "idx2tag = {i: w for w, i in tag2idx.items()}\n",
422 |     "\n",
423 |     "print(\"The word Obama is identified by the index: {}\".format(word2idx[\"Obama\"]))\n",
424 |     "print(\"The label B-geo (which marks Geographical Entities) is identified by the index: {}\".format(tag2idx[\"B-geo\"]))\n",
425 |     "\n",
426 |     "\n",
427 |     "from keras.preprocessing.sequence import pad_sequences\n",
428 |     "# Convert each sentence from a list of tokens to a list of word_index\n",
429 |     "X = [[word2idx[w[0]] for w in s] for s in sentences]\n",
430 |     "# Pad each sentence to the same length\n",
431 |     "X = pad_sequences(maxlen=MAX_LEN, sequences=X, padding=\"post\", value=word2idx[\"PAD\"])\n",
432 |     "\n",
433 |     "# Convert Tag/Label to tag_index\n",
434 |     "y = [[tag2idx[w[2]] for w in s] for s in sentences]\n",
435 |     "# Pad each sentence to the same length\n",
436 |     "y = pad_sequences(maxlen=MAX_LEN, sequences=y, padding=\"post\", value=tag2idx[\"PAD\"])\n",
437 |     "\n",
438 |     "from keras.utils import to_categorical\n",
439 |     "# One-Hot encode\n",
440 |     "y = [to_categorical(i, num_classes=n_tags+1) for i in y]  # n_tags+1 (PAD)\n",
441 |     "\n",
442 |     "from sklearn.model_selection import train_test_split\n",
443 |     "X_tr, X_te, y_tr, y_te = train_test_split(X, y, test_size=0.1)\n",
444 |     "X_tr.shape, X_te.shape, np.array(y_tr).shape, np.array(y_te).shape  # shapes of the splits (not displayed, since it isn't the cell's last expression)\n",
445 |     "\n",
446 |     "print('Raw Sample: ', ' '.join([w[0] for w in sentences[0]]))\n",
447 |     "print('Raw Label: ', ' '.join([w[2] for w in sentences[0]]))\n",
448 |     "print('After processing, sample:', X[0])\n",
449 |     "print('After processing, labels:', y[0])"
450 |    ]
451 |   },
452 |   {
453 |    "cell_type": "markdown",
454 |    "metadata": {},
455 |    "source": [
456 |     "## Model\n",
457 |     "\n",
458 |     "We will implement a model similar to Zhiheng Huang’s [Bidirectional LSTM-CRF Models for Sequence Tagging](https://arxiv.org/pdf/1508.01991v1.pdf).\n",
459 |     "\n",
460 |     "![bilstm-crf](images/bilstm-crf.png)\n",
461 |     "\n",
462 |     "*Image from [the paper](https://arxiv.org/pdf/1508.01991v1.pdf)*"
463 |    ]
464 |   },
465 |   {
466 |    "cell_type": "code",
467 |    "execution_count": 8,
468 |    "metadata": {},
469 |    "outputs": [
470 |     {
471 |      "name": "stdout",
472 |      "output_type": "stream",
473 |      "text": [
474 |       "_________________________________________________________________\n",
475 |       "Layer (type)                 Output Shape              Param #   \n",
476 |       "=================================================================\n",
477 |       "input_1 (InputLayer)         (None, 75)                0         \n",
478 |       "_________________________________________________________________\n",
479 |       "embedding_1 (Embedding)      (None, 75, 20)            703600    \n",
480 |       "_________________________________________________________________\n",
481 |       "bidirectional_1 (Bidirection (None, 75, 100)           28400     \n",
482 |       "_________________________________________________________________\n",
"time_distributed_1 (TimeDist (None, 75, 50) 5050 \n", 484 | "_________________________________________________________________\n", 485 | "crf_1 (CRF) (None, 75, 18) 1278 \n", 486 | "=================================================================\n", 487 | "Total params: 738,328\n", 488 | "Trainable params: 738,328\n", 489 | "Non-trainable params: 0\n", 490 | "_________________________________________________________________\n" 491 | ] 492 | } 493 | ], 494 | "source": [ 495 | "from keras.models import Model, Input\n", 496 | "from keras.layers import LSTM, Embedding, Dense, TimeDistributed, Dropout, Bidirectional\n", 497 | "from keras_contrib.layers import CRF\n", 498 | "\n", 499 | "# Model definition\n", 500 | "input = Input(shape=(MAX_LEN,))\n", 501 | "model = Embedding(input_dim=n_words+2, output_dim=EMBEDDING, # n_words + 2 (PAD & UNK)\n", 502 | " input_length=MAX_LEN, mask_zero=True)(input) # default: 20-dim embedding\n", 503 | "model = Bidirectional(LSTM(units=50, return_sequences=True,\n", 504 | " recurrent_dropout=0.1))(model) # variational biLSTM\n", 505 | "model = TimeDistributed(Dense(50, activation=\"relu\"))(model) # a dense layer as suggested by neuralNer\n", 506 | "crf = CRF(n_tags+1) # CRF layer, n_tags+1(PAD)\n", 507 | "out = crf(model) # output\n", 508 | "\n", 509 | "model = Model(input, out)\n", 510 | "model.compile(optimizer=\"rmsprop\", loss=crf.loss_function, metrics=[crf.accuracy])\n", 511 | "\n", 512 | "model.summary()" 513 | ] 514 | }, 515 | { 516 | "cell_type": "markdown", 517 | "metadata": {}, 518 | "source": [ 519 | "## Training & Evaluate\n", 520 | "\n", 521 | "The Training is defined at the beginning by the type of instance on which runs:\n", 522 | "\n", 523 | "- On CPU machine: 25 minutes for 5 epochs.\n", 524 | "- On GPU machine: 5 minute for 5 epochs.\n", 525 | "\n", 526 | "*Note*: Accuracy isn't the best metric to choose for evaluating this type of task because most of the time it will correctly predict '**O**' or '**PAD**' without identifing the important Tags, which are the ones we are interested in. So after training for some epochs, we can monitor the **precision**, **recall** and **f1-score** for each of the Tags." 
527 |    ]
528 |   },
529 |   {
530 |    "cell_type": "code",
531 |    "execution_count": 9,
532 |    "metadata": {},
533 |    "outputs": [
534 |     {
535 |      "name": "stdout",
536 |      "output_type": "stream",
537 |      "text": [
538 |       "Train on 38846 samples, validate on 4317 samples\n",
539 |       "Epoch 1/5\n",
540 |       "38846/38846 [==============================] - 254s 7ms/step - loss: 9.1008 - acc: 0.9034 - val_loss: 9.0202 - val_acc: 0.9517\n",
541 |       "Epoch 2/5\n",
542 |       "38846/38846 [==============================] - 247s 6ms/step - loss: 8.8708 - acc: 0.9589 - val_loss: 8.9733 - val_acc: 0.9625\n",
543 |       "Epoch 3/5\n",
544 |       "38846/38846 [==============================] - 246s 6ms/step - loss: 8.8402 - acc: 0.9668 - val_loss: 8.9584 - val_acc: 0.9648\n",
545 |       "Epoch 4/5\n",
546 |       "38846/38846 [==============================] - 248s 6ms/step - loss: 8.8287 - acc: 0.9703 - val_loss: 8.9542 - val_acc: 0.9654\n",
547 |       "Epoch 5/5\n",
548 |       "38846/38846 [==============================] - 249s 6ms/step - loss: 8.8227 - acc: 0.9721 - val_loss: 8.9500 - val_acc: 0.9681\n"
549 |      ]
550 |     }
551 |    ],
552 |    "source": [
553 |     "history = model.fit(X_tr, np.array(y_tr), batch_size=BATCH_SIZE, epochs=EPOCHS,\n",
554 |     "                    validation_split=0.1, verbose=2)"
555 |    ]
556 |   },
557 |   {
558 |    "cell_type": "code",
559 |    "execution_count": 10,
560 |    "metadata": {},
561 |    "outputs": [],
562 |    "source": [
563 |     "# Eval\n",
564 |     "pred_cat = model.predict(X_te)\n",
565 |     "pred = np.argmax(pred_cat, axis=-1)\n",
566 |     "y_te_true = np.argmax(y_te, -1)"
567 |    ]
568 |   },
569 |   {
570 |    "cell_type": "code",
571 |    "execution_count": 11,
572 |    "metadata": {},
573 |    "outputs": [
574 |     {
575 |      "name": "stdout",
576 |      "output_type": "stream",
577 |      "text": [
578 |       "             precision    recall  f1-score   support\n",
579 |       "\n",
580 |       "      B-art       0.00      0.00      0.00        52\n",
581 |       "      B-eve       1.00      0.33      0.50        15\n",
582 |       "      B-geo       0.86      0.90      0.88      3677\n",
583 |       "      B-gpe       0.97      0.93      0.95      1570\n",
584 |       "      B-nat       0.00      0.00      0.00        22\n",
585 |       "      B-org       0.74      0.73      0.74      2012\n",
586 |       "      B-per       0.82      0.82      0.82      1726\n",
587 |       "      B-tim       0.93      0.87      0.90      2063\n",
588 |       "      I-art       0.00      0.00      0.00        36\n",
589 |       "      I-eve       0.00      0.00      0.00        11\n",
590 |       "      I-geo       0.83      0.77      0.80       696\n",
591 |       "      I-gpe       0.92      0.57      0.71        21\n",
592 |       "      I-nat       0.00      0.00      0.00         8\n",
593 |       "      I-org       0.74      0.81      0.77      1657\n",
594 |       "      I-per       0.87      0.90      0.88      1835\n",
595 |       "      I-tim       0.84      0.73      0.78       642\n",
596 |       "          O       0.99      0.99      0.99     88839\n",
597 |       "        PAD       1.00      1.00      1.00    254818\n",
598 |       "\n",
599 |       "avg / total       0.99      0.99      0.99    359700\n",
600 |       "\n"
601 |      ]
602 |     },
603 |     {
604 |      "name": "stderr",
605 |      "output_type": "stream",
606 |      "text": [
607 |       "/usr/local/lib/python3.6/site-packages/sklearn/metrics/classification.py:1135: UndefinedMetricWarning: Precision and F-score are ill-defined and being set to 0.0 in labels with no predicted samples.\n",
608 |       "  'precision', 'predicted', average, warn_for)\n"
609 |      ]
610 |     }
611 |    ],
612 |    "source": [
613 |     "from sklearn_crfsuite.metrics import flat_classification_report\n",
614 |     "\n",
615 |     "# Convert the index to tag\n",
616 |     "pred_tag = [[idx2tag[i] for i in row] for row in pred]\n",
617 |     "y_te_true_tag = [[idx2tag[i] for i in row] for row in y_te_true]\n",
618 |     "\n",
619 |     "report = flat_classification_report(y_pred=pred_tag, y_true=y_te_true_tag)\n",
620 |     "print(report)"
621 |    ]
622 |   },
623 |   {
624 |    "cell_type": "markdown",
625 |    "metadata": {},
626 |    "source": [
627 |     "Evaluate some samples in the test set. (At each execution it will test on a different sample.)"
628 |    ]
629 |   },
630 |   {
631 |    "cell_type": "code",
632 |    "execution_count": 12,
633 |    "metadata": {},
634 |    "outputs": [
635 |     {
636 |      "name": "stdout",
637 |      "output_type": "stream",
638 |      "text": [
639 |       "Sample number 2106 of 4796 (Test Set)\n",
640 |       "Word           ||True ||Pred\n",
641 |       "==============================\n",
642 |       "He             : O     O\n",
643 |       "said           : O     O\n",
644 |       ",              : O     O\n",
645 |       "however        : O     O\n",
646 |       ",              : O     O\n",
647 |       "the            : O     O\n",
648 |       "world          : O     O\n",
649 |       "body           : O     O\n",
650 |       "remains        : O     O\n",
651 |       "committed      : O     O\n",
652 |       "to             : O     O\n",
653 |       "work           : O     O\n",
654 |       "for            : O     O\n",
655 |       "peace          : O     O\n",
656 |       "in             : O     O\n",
657 |       "such           : O     O\n",
658 |       "places         : O     O\n",
659 |       "as             : O     O\n",
660 |       "Lebanon        : B-geo B-geo\n",
661 |       ",              : O     O\n",
662 |       "Darfur         : B-geo B-geo\n",
663 |       ",              : O     O\n",
664 |       "Haiti          : B-geo B-geo\n",
665 |       "and            : O     O\n",
666 |       "Iraq           : B-geo B-geo\n",
667 |       ".              : O     O\n"
668 |      ]
669 |     }
670 |    ],
671 |    "source": [
672 |     "i = np.random.randint(0, X_te.shape[0])  # choose a random sample index from the test set\n",
673 |     "p = model.predict(np.array([X_te[i]]))\n",
674 |     "p = np.argmax(p, axis=-1)\n",
675 |     "true = np.argmax(y_te[i], -1)\n",
676 |     "\n",
677 |     "print(\"Sample number {} of {} (Test Set)\".format(i, X_te.shape[0]))\n",
678 |     "# Visualization\n",
679 |     "print(\"{:15}||{:5}||{}\".format(\"Word\", \"True\", \"Pred\"))\n",
680 |     "print(30 * \"=\")\n",
681 |     "for w, t, pred in zip(X_te[i], true, p[0]):\n",
682 |     "    if w != 0:\n",
683 |     "        print(\"{:15}: {:5} {}\".format(words[w-2], idx2tag[t], idx2tag[pred]))"
684 |    ]
685 |   },
686 |   {
687 |    "cell_type": "markdown",
688 |    "metadata": {},
689 |    "source": [
690 |     "## It's your turn\n",
691 |     "\n",
692 |     "Test out the model you just trained. Run the code cell below and type your own sentences in the widget. Have fun! 🎉\n",
693 |     "\n",
694 |     "Here are a few examples for inspiration:\n",
695 |     "\n",
696 |     "- Obama was the president of USA.\n",
697 |     "- The 1906 San Francisco earthquake, on April 18, 1906, was the biggest earthquake that ever hit San Francisco.\n",
698 |     "- Next Monday is Christmas!\n",
699 |     "\n",
700 |     "Can you do better? Play around with the model hyperparameters!"
701 |    ]
702 |   },
703 |   {
704 |    "cell_type": "code",
705 |    "execution_count": 13,
706 |    "metadata": {},
707 |    "outputs": [
708 |     {
709 |      "data": {
710 |       "application/vnd.jupyter.widget-view+json": {
711 |        "model_id": "c20be87443af4311bc10820fb1e51639",
712 |        "version_major": 2,
713 |        "version_minor": 0
714 |       },
715 |       "text/plain": [
716 |        "interactive(children=(Textarea(value='', description='sentence', placeholder='Type your sentence here'), Butto…"
717 |       ]
718 |      },
719 |      "metadata": {},
720 |      "output_type": "display_data"
721 |     }
722 |    ],
723 |    "source": [
724 |     "from ipywidgets import interact_manual\n",
725 |     "from ipywidgets import widgets\n",
726 |     "\n",
727 |     "import re\n",
728 |     "import string\n",
729 |     "\n",
730 |     "# Custom tokenizer: put spaces around punctuation, then split on whitespace\n",
731 |     "re_tok = re.compile(f'([{string.punctuation}“”¨«»®´·º½¾¿¡§£₤‘’])')\n",
732 |     "def tokenize(s): return re_tok.sub(r' \\1 ', s).split()\n",
733 |     "\n",
734 |     "def get_prediction(sentence):\n",
735 |     "    test_sentence = tokenize(sentence)  # Tokenization\n",
736 |     "    # Preprocessing: map unknown words to UNK (index 0 is PAD and would be masked out)\n",
737 |     "    x_test_sent = pad_sequences(sequences=[[word2idx.get(w, word2idx[\"UNK\"]) for w in test_sentence]],\n",
738 |     "                                padding=\"post\", value=word2idx[\"PAD\"], maxlen=MAX_LEN)\n",
739 |     "    # Evaluation\n",
740 |     "    p = model.predict(np.array([x_test_sent[0]]))\n",
741 |     "    p = np.argmax(p, axis=-1)\n",
742 |     "    # Visualization\n",
743 |     "    print(\"{:15}||{}\".format(\"Word\", \"Prediction\"))\n",
744 |     "    print(30 * \"=\")\n",
745 |     "    for w, pred in zip(test_sentence, p[0]):\n",
746 |     "        print(\"{:15}: {:5}\".format(w, idx2tag[pred]))\n",
747 |     "\n",
748 |     "interact_manual(get_prediction, sentence=widgets.Textarea(placeholder='Type your sentence here'));"
749 |    ]
750 |   },
751 |   {
752 |    "cell_type": "markdown",
753 |    "metadata": {},
754 |    "source": [
755 |     "## Save the result"
756 |    ]
757 |   },
758 |   {
759 |    "cell_type": "code",
760 |    "execution_count": 14,
761 |    "metadata": {},
762 |    "outputs": [],
763 |    "source": [
764 |     "import pickle\n",
765 |     "\n",
766 |     "# Saving the word vocabulary\n",
767 |     "with open('models/word_to_index.pickle', 'wb') as handle:\n",
768 |     "    pickle.dump(word2idx, handle, protocol=pickle.HIGHEST_PROTOCOL)\n",
769 |     "\n",
770 |     "# Saving the tag vocabulary\n",
771 |     "with open('models/tag_to_index.pickle', 'wb') as handle:\n",
772 |     "    pickle.dump(tag2idx, handle, protocol=pickle.HIGHEST_PROTOCOL)\n",
773 |     "\n",
774 |     "# Saving the model weights\n",
775 |     "model.save_weights('models/lstm_crf_weights.h5')"
776 |    ]
777 |   },
778 |   {
779 |    "cell_type": "markdown",
780 |    "metadata": {},
781 |    "source": [
782 |     "##### That's all folks - don't forget to shut down your workspace once you're done 🙂"
783 |    ]
784 |   }
785 |  ],
786 |  "metadata": {
787 |   "kernelspec": {
788 |    "display_name": "Python 3",
789 |    "language": "python",
790 |    "name": "python3"
791 |   },
792 |   "language_info": {
793 |    "codemirror_mode": {
794 |     "name": "ipython",
795 |     "version": 3
796 |    },
797 |    "file_extension": ".py",
798 |    "mimetype": "text/x-python",
799 |    "name": "python",
800 |    "nbconvert_exporter": "python",
801 |    "pygments_lexer": "ipython3",
802 |    "version": "3.6"
803 |   }
804 |  },
805 |  "nbformat": 4,
806 |  "nbformat_minor": 2
807 | }
--------------------------------------------------------------------------------