├── .gitignore
├── Attention Tutorial.ipynb
├── README.MD
├── data
│   ├── Time Dataset.json
│   └── Time Vocabs.json
└── requirements.txt
/.gitignore: -------------------------------------------------------------------------------- 1 | Time\ Dataset.ipynb 2 | .ipynb_checkpoints/ 3 | -------------------------------------------------------------------------------- /Attention Tutorial.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Attention Tutorial\n", 8 | "\n", 9 | "One of the most influential and interesting new neural network types is the attention network. It's been used successfully in translation services, [medical diagnosis](https://arxiv.org/pdf/1710.08312.pdf), and other tasks.\n", 10 | "\n", 11 | "Below we'll be walking through how to write your very own attention network. Our goal is to make a network that can translate human written times ('quarter after 3 pm') to military time ('15:15').\n", 12 | "\n", 13 | "The attention mechanism is defined in section **Model**.\n", 14 | "\n", 15 | "For a tutorial on how Attention Networks work, please visit [MuffinTech](http://muffintech.org/blog/id/12).\n", 16 | "\n", 17 | "Credit to Andrew Ng for reference model and inspiration." 18 | ] 19 | }, 20 | { 21 | "cell_type": "markdown", 22 | "metadata": {}, 23 | "source": [ 24 | "## Imports" 25 | ] 26 | }, 27 | { 28 | "cell_type": "code", 29 | "execution_count": 1, 30 | "metadata": {}, 31 | "outputs": [ 32 | { 33 | "name": "stderr", 34 | "output_type": "stream", 35 | "text": [ 36 | "/Library/Frameworks/Python.framework/Versions/3.6/lib/python3.6/site-packages/h5py/__init__.py:36: FutureWarning: Conversion of the second argument of issubdtype from `float` to `np.floating` is deprecated. In future, it will be treated as `np.float64 == np.dtype(float).type`.\n", 37 | " from ._conv import register_converters as _register_converters\n", 38 | "Using TensorFlow backend.\n" 39 | ] 40 | } 41 | ], 42 | "source": [ 43 | "# Imports\n", 44 | "from keras.layers import Bidirectional, Concatenate, Permute, Dot, Input, LSTM, Multiply, Reshape\n", 45 | "from keras.layers import RepeatVector, Dense, Activation, Lambda\n", 46 | "from keras.optimizers import Adam\n", 47 | "from keras.utils import to_categorical\n", 48 | "from keras.models import load_model, Model\n", 49 | "from keras.callbacks import LearningRateScheduler\n", 50 | "import keras.backend as K\n", 51 | "\n", 52 | "import matplotlib.pyplot as plt\n", 53 | "%matplotlib inline\n", 54 | "\n", 55 | "import numpy as np\n", 56 | "\n", 57 | "import random\n", 58 | "import math\n", 59 | "import json\n", 60 | "\n", 61 | "# Pinkie Pie was here" 62 | ] 63 | }, 64 | { 65 | "cell_type": "markdown", 66 | "metadata": {}, 67 | "source": [ 68 | "## Dataset\n", 69 | "\n", 70 | "The dataset was created using some simple rules. 
It is not exhaustive, but provides some very nice challenges.\n", 71 | "\n", 72 | "The dataset is included in the Github repo.\n", 73 | "\n", 74 | "Some example data pairs are listed below:\n", 75 | "\n", 76 | "['48 min before 10 a.m', '09:12'] \n", 77 | "['t11:36', '11:36'] \n", 78 | "[\"nine o'clock forty six p.m\", '21:46'] \n", 79 | "['2:59p.m.', '14:59'] \n", 80 | "['23 min after 20 p.m.', '20:23'] \n", 81 | "['46 min after seven p.m.', '19:46'] \n", 82 | "['10 before nine pm', '20:50'] \n", 83 | "['3.20', '03:20'] \n", 84 | "['7.57', '07:57'] \n", 85 | "['six hours and fifty five am', '06:55'] " 86 | ] 87 | }, 88 | { 89 | "cell_type": "code", 90 | "execution_count": 2, 91 | "metadata": {}, 92 | "outputs": [], 93 | "source": [ 94 | "with open('data/Time Dataset.json','r') as f:\n", 95 | " dataset = json.loads(f.read())\n", 96 | "with open('data/Time Vocabs.json','r') as f:\n", 97 | " human_vocab, machine_vocab = json.loads(f.read())\n", 98 | " \n", 99 | "human_vocab_size = len(human_vocab)\n", 100 | "machine_vocab_size = len(machine_vocab)\n", 101 | "\n", 102 | "# Number of training examples\n", 103 | "m = len(dataset)" 104 | ] 105 | }, 106 | { 107 | "cell_type": "markdown", 108 | "metadata": {}, 109 | "source": [ 110 | "Next let's define some general helper methods. They are used to help tokenize data." 111 | ] 112 | }, 113 | { 114 | "cell_type": "code", 115 | "execution_count": 3, 116 | "metadata": { 117 | "scrolled": false 118 | }, 119 | "outputs": [], 120 | "source": [ 121 | "def preprocess_data(dataset, human_vocab, machine_vocab, Tx, Ty):\n", 122 | " \"\"\"\n", 123 | " A method for tokenizing data.\n", 124 | " \n", 125 | " Inputs:\n", 126 | " dataset - A list of sentence data pairs.\n", 127 | " human_vocab - A dictionary of tokens (char) to id's.\n", 128 | " machine_vocab - A dictionary of tokens (char) to id's.\n", 129 | " Tx - X data size\n", 130 | " Ty - Y data size\n", 131 | " \n", 132 | " Outputs:\n", 133 | " X - Sparse tokens for X data\n", 134 | " Y - Sparse tokens for Y data\n", 135 | " Xoh - One hot tokens for X data\n", 136 | " Yoh - One hot tokens for Y data\n", 137 | " \"\"\"\n", 138 | " \n", 139 | " # Metadata\n", 140 | " m = len(dataset)\n", 141 | " \n", 142 | " # Initialize\n", 143 | " X = np.zeros([m, Tx], dtype='int32')\n", 144 | " Y = np.zeros([m, Ty], dtype='int32')\n", 145 | " \n", 146 | " # Process data\n", 147 | " for i in range(m):\n", 148 | " data = dataset[i]\n", 149 | " X[i] = np.array(tokenize(data[0], human_vocab, Tx))\n", 150 | " Y[i] = np.array(tokenize(data[1], machine_vocab, Ty))\n", 151 | " \n", 152 | " # Expand one hots\n", 153 | " Xoh = oh_2d(X, len(human_vocab))\n", 154 | " Yoh = oh_2d(Y, len(machine_vocab))\n", 155 | " \n", 156 | " return (X, Y, Xoh, Yoh)\n", 157 | " \n", 158 | "def tokenize(sentence, vocab, length):\n", 159 | " \"\"\"\n", 160 | " Returns a series of id's for a given input token sequence.\n", 161 | " \n", 162 | " It is advised that the vocab supports and .\n", 163 | " \n", 164 | " Inputs:\n", 165 | " sentence - Series of tokens\n", 166 | " vocab - A dictionary from token to id\n", 167 | " length - Max number of tokens to consider\n", 168 | " \n", 169 | " Outputs:\n", 170 | " tokens - \n", 171 | " \"\"\"\n", 172 | " tokens = [0]*length\n", 173 | " for i in range(length):\n", 174 | " char = sentence[i] if i < len(sentence) else \"\"\n", 175 | " char = char if (char in vocab) else \"\"\n", 176 | " tokens[i] = vocab[char]\n", 177 | " \n", 178 | " return tokens\n", 179 | "\n", 180 | "def ids_to_keys(sentence, vocab):\n", 181 | " 
\"\"\"\n", 182 | " Converts a series of id's into the keys of a dictionary.\n", 183 | " \"\"\"\n", 184 | " reverse_vocab = {v: k for k, v in vocab.items()}\n", 185 | " \n", 186 | " return [reverse_vocab[id] for id in sentence]\n", 187 | "\n", 188 | "def oh_2d(dense, max_value):\n", 189 | " \"\"\"\n", 190 | " Create a one hot array for the 2D input dense array.\n", 191 | " \"\"\"\n", 192 | " # Initialize\n", 193 | " oh = np.zeros(np.append(dense.shape, [max_value]))\n", 194 | " \n", 195 | " # Set correct indices\n", 196 | " ids1, ids2 = np.meshgrid(np.arange(dense.shape[0]), np.arange(dense.shape[1]))\n", 197 | " \n", 198 | " oh[ids1.flatten(), ids2.flatten(), dense.flatten('F').astype(int)] = 1\n", 199 | " \n", 200 | " return oh" 201 | ] 202 | }, 203 | { 204 | "cell_type": "markdown", 205 | "metadata": {}, 206 | "source": [ 207 | "Our next goal is to tokenize the data using our vocabularies." 208 | ] 209 | }, 210 | { 211 | "cell_type": "code", 212 | "execution_count": 4, 213 | "metadata": { 214 | "scrolled": true 215 | }, 216 | "outputs": [], 217 | "source": [ 218 | "Tx = 41 # Max x sequence length\n", 219 | "Ty = 5 # y sequence length\n", 220 | "X, Y, Xoh, Yoh = preprocess_data(dataset, human_vocab, machine_vocab, Tx, Ty)\n", 221 | "\n", 222 | "# Split data 80-20 between training and test\n", 223 | "train_size = int(0.8*m)\n", 224 | "Xoh_train = Xoh[:train_size]\n", 225 | "Yoh_train = Yoh[:train_size]\n", 226 | "Xoh_test = Xoh[train_size:]\n", 227 | "Yoh_test = Yoh[train_size:]" 228 | ] 229 | }, 230 | { 231 | "cell_type": "markdown", 232 | "metadata": {}, 233 | "source": [ 234 | "To be careful, let's check that the code works:" 235 | ] 236 | }, 237 | { 238 | "cell_type": "code", 239 | "execution_count": 5, 240 | "metadata": {}, 241 | "outputs": [ 242 | { 243 | "name": "stdout", 244 | "output_type": "stream", 245 | "text": [ 246 | "Input data point 4.\n", 247 | "\n", 248 | "The data input is: 8:25\n", 249 | "The data output is: 08:25\n", 250 | "\n", 251 | "The tokenized input is:[11 13 5 8 40 40 40 40 40 40 40 40 40 40 40 40 40 40 40 40 40 40 40 40\n", 252 | " 40 40 40 40 40 40 40 40 40 40 40 40 40 40 40 40 40]\n", 253 | "The tokenized output is: [ 0 8 10 2 5]\n", 254 | "\n", 255 | "The one-hot input is: [[0. 0. 0. ... 0. 0. 0.]\n", 256 | " [0. 0. 0. ... 0. 0. 0.]\n", 257 | " [0. 0. 0. ... 0. 0. 0.]\n", 258 | " ...\n", 259 | " [0. 0. 0. ... 0. 0. 1.]\n", 260 | " [0. 0. 0. ... 0. 0. 1.]\n", 261 | " [0. 0. 0. ... 0. 0. 1.]]\n", 262 | "The one-hot output is: [[1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n", 263 | " [0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0.]\n", 264 | " [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1.]\n", 265 | " [0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0.]\n", 266 | " [0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0.]]\n" 267 | ] 268 | } 269 | ], 270 | "source": [ 271 | "i = 4\n", 272 | "print(\"Input data point \" + str(i) + \".\")\n", 273 | "print(\"\")\n", 274 | "print(\"The data input is: \" + str(dataset[i][0]))\n", 275 | "print(\"The data output is: \" + str(dataset[i][1]))\n", 276 | "print(\"\")\n", 277 | "print(\"The tokenized input is:\" + str(X[i]))\n", 278 | "print(\"The tokenized output is: \" + str(Y[i]))\n", 279 | "print(\"\")\n", 280 | "print(\"The one-hot input is:\", Xoh[i])\n", 281 | "print(\"The one-hot output is:\", Yoh[i])" 282 | ] 283 | }, 284 | { 285 | "cell_type": "markdown", 286 | "metadata": {}, 287 | "source": [ 288 | "## Model\n", 289 | "\n", 290 | "Our next goal is to define our model. 
The important part will be defining the attention mechanism and then making sure to apply that correctly.\n", 291 | "\n", 292 | "Define some model metadata:" 293 | ] 294 | }, 295 | { 296 | "cell_type": "code", 297 | "execution_count": 6, 298 | "metadata": {}, 299 | "outputs": [], 300 | "source": [ 301 | "layer1_size = 32\n", 302 | "layer2_size = 64 # Attention layer" 303 | ] 304 | }, 305 | { 306 | "cell_type": "markdown", 307 | "metadata": {}, 308 | "source": [ 309 | "The next two code snippets define the attention mechanism. This is split into two parts:\n", 310 | "\n", 311 | "* Calculating context\n", 312 | "* Creating an attention layer\n", 313 | "\n", 314 | "As a refresher, an attention network pays attention to certain parts of the input at each output time step. _attention_ denotes which inputs are most relevant to the current output step. An input step will have attention weight ~1 if it is relevant, and ~0 otherwise. The _context_ is the \"summary of the input\".\n", 315 | "\n", 316 | "The requirements are as follows. The attention vector should have shape $(T_x)$ and sum to 1. Additionally, the context should be calculated in the same manner for each time step. Beyond that, there is some flexibility. This notebook calculates both this way:\n", 317 | "\n", 318 | "$$\n", 319 | "attention = Softmax(Dense(Dense(x, y_{t-1})))\n", 320 | "$$\n", 321 | "
\n", 322 | "$$\n", 323 | "context = \\sum_{i=1}^{T_x} ( attention_i * x_i )\n", 324 | "$$\n", 325 | "\n", 326 | "For safety, $y_0$ is defined as $\\vec{0}$.\n", 327 | "\n" 328 | ] 329 | }, 330 | { 331 | "cell_type": "code", 332 | "execution_count": 7, 333 | "metadata": {}, 334 | "outputs": [], 335 | "source": [ 336 | "# Define part of the attention layer globally so as to\n", 337 | "# share the same layers for each attention step.\n", 338 | "def softmax(x):\n", 339 | " return K.softmax(x, axis=1)\n", 340 | "\n", 341 | "at_repeat = RepeatVector(Tx)\n", 342 | "at_concatenate = Concatenate(axis=-1)\n", 343 | "at_dense1 = Dense(8, activation=\"tanh\")\n", 344 | "at_dense2 = Dense(1, activation=\"relu\")\n", 345 | "at_softmax = Activation(softmax, name='attention_weights')\n", 346 | "at_dot = Dot(axes=1)\n", 347 | "\n", 348 | "def one_step_of_attention(h_prev, a):\n", 349 | " \"\"\"\n", 350 | " Get the context.\n", 351 | " \n", 352 | " Input:\n", 353 | " h_prev - Previous hidden state of a RNN layer (m, n_h)\n", 354 | " a - Input data, possibly processed (m, Tx, n_a)\n", 355 | " \n", 356 | " Output:\n", 357 | " context - Current context (m, 1, n_a)\n", 358 | " \"\"\"\n", 359 | " # Repeat vector to match a's dimensions\n", 360 | " h_repeat = at_repeat(h_prev)\n", 361 | " # Calculate attention weights\n", 362 | " i = at_concatenate([a, h_repeat])\n", 363 | " i = at_dense1(i)\n", 364 | " i = at_dense2(i)\n", 365 | " attention = at_softmax(i)\n", 366 | " # Calculate the context\n", 367 | " context = at_dot([attention, a])\n", 368 | " \n", 369 | " return context" 370 | ] 371 | }, 372 | { 373 | "cell_type": "code", 374 | "execution_count": 8, 375 | "metadata": {}, 376 | "outputs": [], 377 | "source": [ 378 | "def attention_layer(X, n_h, Ty):\n", 379 | " \"\"\"\n", 380 | " Creates an attention layer.\n", 381 | " \n", 382 | " Input:\n", 383 | " X - Layer input (m, Tx, x_vocab_size)\n", 384 | " n_h - Size of LSTM hidden layer\n", 385 | " Ty - Timesteps in output sequence\n", 386 | " \n", 387 | " Output:\n", 388 | " output - A list of Ty outputs of the attention layer, each of shape (m, n_h)\n", 389 | " \"\"\" \n", 390 | " # Define the default state for the LSTM layer\n", 391 | " h = Lambda(lambda X: K.zeros(shape=(K.shape(X)[0], n_h)))(X)\n", 392 | " c = Lambda(lambda X: K.zeros(shape=(K.shape(X)[0], n_h)))(X)\n", 393 | " # Messy, but the alternative is using more Input()\n", 394 | " \n", 395 | " at_LSTM = LSTM(n_h, return_state=True)\n", 396 | " \n", 397 | " output = []\n", 398 | " \n", 399 | " # Run attention step and RNN for each output time step\n", 400 | " for _ in range(Ty):\n", 401 | " context = one_step_of_attention(h, X)\n", 402 | " \n", 403 | " h, _, c = at_LSTM(context, initial_state=[h, c])\n", 404 | " \n", 405 | " output.append(h)\n", 406 | " \n", 407 | " return output" 408 | ] 409 | }, 410 | { 411 | "cell_type": "markdown", 412 | "metadata": {}, 413 | "source": [ 414 | "The sample model is organized as follows:\n", 415 | "\n", 416 | "1. BiLSTM\n", 417 | "2. Attention Layer\n", 418 | " * Outputs a list of Ty activations.\n", 419 | "3. 
Dense\n", 420 | " * Necessary to convert attention layer's output to the correct y dimensions" 421 | ] 422 | }, 423 | { 424 | "cell_type": "code", 425 | "execution_count": 9, 426 | "metadata": {}, 427 | "outputs": [], 428 | "source": [ 429 | "layer3 = Dense(machine_vocab_size, activation=softmax)\n", 430 | "\n", 431 | "def get_model(Tx, Ty, layer1_size, layer2_size, x_vocab_size, y_vocab_size):\n", 432 | " \"\"\"\n", 433 | " Creates a model.\n", 434 | " \n", 435 | " input:\n", 436 | " Tx - Number of x timesteps\n", 437 | " Ty - Number of y timesteps\n", 438 | " size_layer1 - Number of neurons in BiLSTM\n", 439 | " size_layer2 - Number of neurons in attention LSTM hidden layer\n", 440 | " x_vocab_size - Number of possible token types for x\n", 441 | " y_vocab_size - Number of possible token types for y\n", 442 | " \n", 443 | " Output:\n", 444 | " model - A Keras Model.\n", 445 | " \"\"\"\n", 446 | " \n", 447 | " # Create layers one by one\n", 448 | " X = Input(shape=(Tx, x_vocab_size))\n", 449 | " \n", 450 | " a1 = Bidirectional(LSTM(layer1_size, return_sequences=True), merge_mode='concat')(X)\n", 451 | "\n", 452 | " a2 = attention_layer(a1, layer2_size, Ty)\n", 453 | " \n", 454 | " a3 = [layer3(timestep) for timestep in a2]\n", 455 | " \n", 456 | " # Create Keras model\n", 457 | " model = Model(inputs=[X], outputs=a3)\n", 458 | " \n", 459 | " return model" 460 | ] 461 | }, 462 | { 463 | "cell_type": "markdown", 464 | "metadata": {}, 465 | "source": [ 466 | "The steps from here on out are for creating the model and training it. Simple as that." 467 | ] 468 | }, 469 | { 470 | "cell_type": "code", 471 | "execution_count": 10, 472 | "metadata": { 473 | "scrolled": false 474 | }, 475 | "outputs": [], 476 | "source": [ 477 | "# Obtain a model instance\n", 478 | "model = get_model(Tx, Ty, layer1_size, layer2_size, human_vocab_size, machine_vocab_size)" 479 | ] 480 | }, 481 | { 482 | "cell_type": "code", 483 | "execution_count": 11, 484 | "metadata": {}, 485 | "outputs": [], 486 | "source": [ 487 | "# Create optimizer\n", 488 | "opt = Adam(lr=0.05, decay=0.04, clipnorm=1.0)\n", 489 | "model.compile(optimizer=opt, loss='categorical_crossentropy', metrics=['accuracy'])" 490 | ] 491 | }, 492 | { 493 | "cell_type": "code", 494 | "execution_count": 12, 495 | "metadata": {}, 496 | "outputs": [], 497 | "source": [ 498 | "# Group the output by timestep, not example\n", 499 | "outputs_train = list(Yoh_train.swapaxes(0,1))" 500 | ] 501 | }, 502 | { 503 | "cell_type": "code", 504 | "execution_count": 13, 505 | "metadata": { 506 | "scrolled": true 507 | }, 508 | "outputs": [ 509 | { 510 | "name": "stdout", 511 | "output_type": "stream", 512 | "text": [ 513 | "WARNING:tensorflow:Variable *= will be deprecated. 
Use variable.assign_mul if you want assignment to the variable value or 'x = x * y' if you want a new python Tensor object.\n", 514 | "Epoch 1/30\n", 515 | "8000/8000 [==============================] - 13s 2ms/step - loss: 7.1344 - dense_3_loss: 2.1335 - dense_3_acc: 0.5379 - dense_3_acc_1: 0.2478 - dense_3_acc_2: 0.9748 - dense_3_acc_3: 0.2354 - dense_3_acc_4: 0.2124\n", 516 | "Epoch 2/30\n", 517 | "8000/8000 [==============================] - 5s 606us/step - loss: 3.3504 - dense_3_loss: 0.7929 - dense_3_acc: 0.7854 - dense_3_acc_1: 0.6706 - dense_3_acc_2: 1.0000 - dense_3_acc_3: 0.5629 - dense_3_acc_4: 0.7645\n", 518 | "Epoch 3/30\n", 519 | "8000/8000 [==============================] - 5s 645us/step - loss: 1.1928 - dense_3_loss: 0.1806 - dense_3_acc: 0.9091 - dense_3_acc_1: 0.9138 - dense_3_acc_2: 1.0000 - dense_3_acc_3: 0.8390 - dense_3_acc_4: 0.9433\n", 520 | "Epoch 4/30\n", 521 | "8000/8000 [==============================] - 5s 617us/step - loss: 0.5314 - dense_3_loss: 0.0494 - dense_3_acc: 0.9733 - dense_3_acc_1: 0.9754 - dense_3_acc_2: 1.0000 - dense_3_acc_3: 0.9169 - dense_3_acc_4: 0.9936\n", 522 | "Epoch 5/30\n", 523 | "8000/8000 [==============================] - 5s 614us/step - loss: 0.2888 - dense_3_loss: 0.0249 - dense_3_acc: 0.9894 - dense_3_acc_1: 0.9863 - dense_3_acc_2: 1.0000 - dense_3_acc_3: 0.9693 - dense_3_acc_4: 0.9978\n", 524 | "Epoch 6/30\n", 525 | "8000/8000 [==============================] - 5s 602us/step - loss: 0.1833 - dense_3_loss: 0.0142 - dense_3_acc: 0.9914 - dense_3_acc_1: 0.9913 - dense_3_acc_2: 1.0000 - dense_3_acc_3: 0.9864 - dense_3_acc_4: 0.9994\n", 526 | "Epoch 7/30\n", 527 | "8000/8000 [==============================] - 5s 599us/step - loss: 0.1370 - dense_3_loss: 0.0100 - dense_3_acc: 0.9921 - dense_3_acc_1: 0.9928 - dense_3_acc_2: 1.0000 - dense_3_acc_3: 0.9890 - dense_3_acc_4: 0.9996\n", 528 | "Epoch 8/30\n", 529 | "8000/8000 [==============================] - 5s 602us/step - loss: 0.1070 - dense_3_loss: 0.0070 - dense_3_acc: 0.9938 - dense_3_acc_1: 0.9935 - dense_3_acc_2: 1.0000 - dense_3_acc_3: 0.9901 - dense_3_acc_4: 0.9999\n", 530 | "Epoch 9/30\n", 531 | "8000/8000 [==============================] - 5s 596us/step - loss: 0.0927 - dense_3_loss: 0.0062 - dense_3_acc: 0.9944 - dense_3_acc_1: 0.9938 - dense_3_acc_2: 1.0000 - dense_3_acc_3: 0.9908 - dense_3_acc_4: 0.9998\n", 532 | "Epoch 10/30\n", 533 | "8000/8000 [==============================] - 5s 599us/step - loss: 0.0786 - dense_3_loss: 0.0046 - dense_3_acc: 0.9950 - dense_3_acc_1: 0.9946 - dense_3_acc_2: 1.0000 - dense_3_acc_3: 0.9919 - dense_3_acc_4: 1.0000\n", 534 | "Epoch 11/30\n", 535 | "8000/8000 [==============================] - 5s 605us/step - loss: 0.0706 - dense_3_loss: 0.0040 - dense_3_acc: 0.9946 - dense_3_acc_1: 0.9949 - dense_3_acc_2: 1.0000 - dense_3_acc_3: 0.9921 - dense_3_acc_4: 1.0000\n", 536 | "Epoch 12/30\n", 537 | "8000/8000 [==============================] - 5s 602us/step - loss: 0.0640 - dense_3_loss: 0.0034 - dense_3_acc: 0.9944 - dense_3_acc_1: 0.9941 - dense_3_acc_2: 1.0000 - dense_3_acc_3: 0.9936 - dense_3_acc_4: 1.0000\n", 538 | "Epoch 13/30\n", 539 | "8000/8000 [==============================] - 5s 653us/step - loss: 0.0560 - dense_3_loss: 0.0032 - dense_3_acc: 0.9954 - dense_3_acc_1: 0.9951 - dense_3_acc_2: 1.0000 - dense_3_acc_3: 0.9943 - dense_3_acc_4: 1.0000\n", 540 | "Epoch 14/30\n", 541 | "8000/8000 [==============================] - 5s 682us/step - loss: 0.0539 - dense_3_loss: 0.0031 - dense_3_acc: 0.9953 - dense_3_acc_1: 0.9951 - dense_3_acc_2: 1.0000 - 
dense_3_acc_3: 0.9951 - dense_3_acc_4: 1.0000\n", 542 | "Epoch 15/30\n", 543 | "8000/8000 [==============================] - 5s 654us/step - loss: 0.0475 - dense_3_loss: 0.0027 - dense_3_acc: 0.9951 - dense_3_acc_1: 0.9956 - dense_3_acc_2: 1.0000 - dense_3_acc_3: 0.9959 - dense_3_acc_4: 1.0000\n", 544 | "Epoch 16/30\n", 545 | "8000/8000 [==============================] - 5s 635us/step - loss: 0.0444 - dense_3_loss: 0.0024 - dense_3_acc: 0.9948 - dense_3_acc_1: 0.9946 - dense_3_acc_2: 1.0000 - dense_3_acc_3: 0.9960 - dense_3_acc_4: 1.0000\n", 546 | "Epoch 17/30\n", 547 | "8000/8000 [==============================] - 5s 662us/step - loss: 0.0407 - dense_3_loss: 0.0023 - dense_3_acc: 0.9956 - dense_3_acc_1: 0.9955 - dense_3_acc_2: 1.0000 - dense_3_acc_3: 0.9970 - dense_3_acc_4: 1.0000\n", 548 | "Epoch 18/30\n", 549 | "8000/8000 [==============================] - 5s 658us/step - loss: 0.0389 - dense_3_loss: 0.0021 - dense_3_acc: 0.9956 - dense_3_acc_1: 0.9956 - dense_3_acc_2: 1.0000 - dense_3_acc_3: 0.9971 - dense_3_acc_4: 1.0000\n", 550 | "Epoch 19/30\n", 551 | "8000/8000 [==============================] - 5s 643us/step - loss: 0.0361 - dense_3_loss: 0.0020 - dense_3_acc: 0.9961 - dense_3_acc_1: 0.9956 - dense_3_acc_2: 1.0000 - dense_3_acc_3: 0.9978 - dense_3_acc_4: 1.0000\n", 552 | "Epoch 20/30\n", 553 | "8000/8000 [==============================] - 5s 654us/step - loss: 0.0337 - dense_3_loss: 0.0018 - dense_3_acc: 0.9955 - dense_3_acc_1: 0.9961 - dense_3_acc_2: 1.0000 - dense_3_acc_3: 0.9979 - dense_3_acc_4: 1.0000\n", 554 | "Epoch 21/30\n", 555 | "8000/8000 [==============================] - 5s 641us/step - loss: 0.0320 - dense_3_loss: 0.0017 - dense_3_acc: 0.9964 - dense_3_acc_1: 0.9966 - dense_3_acc_2: 1.0000 - dense_3_acc_3: 0.9980 - dense_3_acc_4: 1.0000\n", 556 | "Epoch 22/30\n", 557 | "8000/8000 [==============================] - 5s 649us/step - loss: 0.0321 - dense_3_loss: 0.0017 - dense_3_acc: 0.9961 - dense_3_acc_1: 0.9963 - dense_3_acc_2: 1.0000 - dense_3_acc_3: 0.9981 - dense_3_acc_4: 1.0000\n", 558 | "Epoch 23/30\n", 559 | "8000/8000 [==============================] - 5s 649us/step - loss: 0.0296 - dense_3_loss: 0.0016 - dense_3_acc: 0.9966 - dense_3_acc_1: 0.9964 - dense_3_acc_2: 1.0000 - dense_3_acc_3: 0.9981 - dense_3_acc_4: 1.0000\n", 560 | "Epoch 24/30\n", 561 | "8000/8000 [==============================] - 5s 636us/step - loss: 0.0290 - dense_3_loss: 0.0015 - dense_3_acc: 0.9960 - dense_3_acc_1: 0.9958 - dense_3_acc_2: 1.0000 - dense_3_acc_3: 0.9983 - dense_3_acc_4: 1.0000\n", 562 | "Epoch 25/30\n", 563 | "8000/8000 [==============================] - 5s 657us/step - loss: 0.0270 - dense_3_loss: 0.0015 - dense_3_acc: 0.9970 - dense_3_acc_1: 0.9966 - dense_3_acc_2: 1.0000 - dense_3_acc_3: 0.9989 - dense_3_acc_4: 1.0000\n", 564 | "Epoch 26/30\n", 565 | "8000/8000 [==============================] - 5s 653us/step - loss: 0.0267 - dense_3_loss: 0.0014 - dense_3_acc: 0.9963 - dense_3_acc_1: 0.9965 - dense_3_acc_2: 1.0000 - dense_3_acc_3: 0.9989 - dense_3_acc_4: 1.0000\n", 566 | "Epoch 27/30\n", 567 | "8000/8000 [==============================] - 5s 637us/step - loss: 0.0248 - dense_3_loss: 0.0013 - dense_3_acc: 0.9964 - dense_3_acc_1: 0.9969 - dense_3_acc_2: 1.0000 - dense_3_acc_3: 0.9991 - dense_3_acc_4: 1.0000\n", 568 | "Epoch 28/30\n", 569 | "8000/8000 [==============================] - 5s 655us/step - loss: 0.0248 - dense_3_loss: 0.0013 - dense_3_acc: 0.9968 - dense_3_acc_1: 0.9961 - dense_3_acc_2: 1.0000 - dense_3_acc_3: 0.9994 - dense_3_acc_4: 1.0000\n", 570 | "Epoch 
29/30\n", 571 | "8000/8000 [==============================] - 5s 647us/step - loss: 0.0234 - dense_3_loss: 0.0012 - dense_3_acc: 0.9971 - dense_3_acc_1: 0.9973 - dense_3_acc_2: 1.0000 - dense_3_acc_3: 0.9994 - dense_3_acc_4: 1.0000\n", 572 | "Epoch 30/30\n", 573 | "8000/8000 [==============================] - 5s 656us/step - loss: 0.0237 - dense_3_loss: 0.0012 - dense_3_acc: 0.9971 - dense_3_acc_1: 0.9968 - dense_3_acc_2: 1.0000 - dense_3_acc_3: 0.9994 - dense_3_acc_4: 1.0000\n" 574 | ] 575 | }, 576 | { 577 | "data": { 578 | "text/plain": [ 579 | "" 580 | ] 581 | }, 582 | "execution_count": 13, 583 | "metadata": {}, 584 | "output_type": "execute_result" 585 | } 586 | ], 587 | "source": [ 588 | "# Time to train\n", 589 | "# It takes a few minutes on an quad-core CPU\n", 590 | "model.fit([Xoh_train], outputs_train, epochs=30, batch_size=100)" 591 | ] 592 | }, 593 | { 594 | "cell_type": "markdown", 595 | "metadata": {}, 596 | "source": [ 597 | "## Evaluation\n", 598 | "\n", 599 | "The final training loss should be in the range of 0.02 to 0.5\n", 600 | "\n", 601 | "The test loss should be at a similar level." 602 | ] 603 | }, 604 | { 605 | "cell_type": "code", 606 | "execution_count": 14, 607 | "metadata": {}, 608 | "outputs": [ 609 | { 610 | "name": "stdout", 611 | "output_type": "stream", 612 | "text": [ 613 | "2000/2000 [==============================] - 1s 669us/step\n", 614 | "Test loss: 0.0888983543291688\n" 615 | ] 616 | } 617 | ], 618 | "source": [ 619 | "# Evaluate the test performance\n", 620 | "outputs_test = list(Yoh_test.swapaxes(0,1))\n", 621 | "score = model.evaluate(Xoh_test, outputs_test) \n", 622 | "print('Test loss: ', score[0])" 623 | ] 624 | }, 625 | { 626 | "cell_type": "markdown", 627 | "metadata": {}, 628 | "source": [ 629 | "Now that we've created this beautiful model, let's see how it does in action.\n", 630 | "\n", 631 | "The below code finds a random example and runs it through our model." 632 | ] 633 | }, 634 | { 635 | "cell_type": "code", 636 | "execution_count": 15, 637 | "metadata": {}, 638 | "outputs": [ 639 | { 640 | "name": "stdout", 641 | "output_type": "stream", 642 | "text": [ 643 | "Input: t6:27 a.m.\n", 644 | "Tokenized: [32 9 13 5 10 0 14 2 25 2 40 40 40 40 40 40 40 40 40 40 40 40 40 40\n", 645 | " 40 40 40 40 40 40 40 40 40 40 40 40 40 40 40 40 40]\n", 646 | "Prediction: [0, 6, 10, 2, 7]\n", 647 | "Prediction text: 06:27\n" 648 | ] 649 | } 650 | ], 651 | "source": [ 652 | "# Let's visually check model output.\n", 653 | "import random as random\n", 654 | "\n", 655 | "i = random.randint(0, m)\n", 656 | "\n", 657 | "def get_prediction(model, x):\n", 658 | " prediction = model.predict(x)\n", 659 | " max_prediction = [y.argmax() for y in prediction]\n", 660 | " str_prediction = \"\".join(ids_to_keys(max_prediction, machine_vocab))\n", 661 | " return (max_prediction, str_prediction)\n", 662 | "\n", 663 | "max_prediction, str_prediction = get_prediction(model, Xoh[i:i+1])\n", 664 | "\n", 665 | "print(\"Input: \" + str(dataset[i][0]))\n", 666 | "print(\"Tokenized: \" + str(X[i]))\n", 667 | "print(\"Prediction: \" + str(max_prediction))\n", 668 | "print(\"Prediction text: \" + str(str_prediction))" 669 | ] 670 | }, 671 | { 672 | "cell_type": "markdown", 673 | "metadata": {}, 674 | "source": [ 675 | "Last but not least, all introductions to Attention networks require a little tour.\n", 676 | "\n", 677 | "The below graph shows what inputs the model was focusing on when writing each individual letter." 
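For reference, the heatmap is built by reading the output of the shared `attention_weights` Activation layer at each of the Ty decoding steps. A minimal sketch of that readout is below; it reuses `model`, `tokenize`, `oh_2d`, `human_vocab`, `Tx` and `Ty` from the cells above, the example input string is only an illustration, and the full plotting helper in the next cell does the same thing before drawing the graph.

```python
# Locate the shared attention softmax layer by the name given in the Model section.
att_layer = next(l for l in model.layers if l.name == 'attention_weights')

# The layer is called once per output step, so it has Ty output nodes.
get_attention = K.function(model.inputs,
                           [att_layer.get_output_at(t) for t in range(Ty)])

# Attention weights for one (illustrative) input string.
x_oh = oh_2d(np.array([tokenize('ten past six pm', human_vocab, Tx)]), len(human_vocab))
weights = get_attention([x_oh])  # list of Ty arrays, each of shape (1, Tx, 1)
```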
678 | ] 679 | }, 680 | { 681 | "cell_type": "code", 682 | "execution_count": 16, 683 | "metadata": {}, 684 | "outputs": [ 685 | { 686 | "data": { 687 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAA58AAACfCAYAAAB++W3hAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAG55JREFUeJzt3Xu8ZXVd//HXe2aAAcHrWMl1MMkkKi8TFy9Jikmo0O+RFiBeiuRnSRdT+WkampdfeemiSdmoiKKCRlZTToJ5wxCRwQsKqE2IMGAhyGVEgTlzPv2x1rF9juecvffMXrP3HF7Px2M9zlrf9V2f9dn7nAf6me93fVeqCkmSJEmSurRs3AlIkiRJkpY+i09JkiRJUucsPiVJkiRJnbP4lCRJkiR1zuJTkiRJktQ5i09JkiRJUucsPiXpHibJM5NcMO485kryqiTvHXceO0qStyX5o3HnIUnSjmLxKUk7QJJPJrklyW5z2s9K8to5bdckOWpE912dpJKsmGmrqvdV1S+OIn7PffZJMpXkx+c59w9J3jTK+026JPsn+W7PVknu6Dl+XFU9v6pes4Pz+mSS39yR95QkaYbFpyR1LMlq4HFAAceONZmOVNX1wMeAZ/W2J7k/cAzw7nHktaP0FvcAVXVtVe05s7XNP9vT9ukxpClJ0lhZfEpS954NfBY4C3jOTGOSU4BnAqe1o2H/nORsYH/gn9u209q+hyf5TJJbk3wpyZE9cT6Z5DVJLkqyOckFSVa1py9sf97axjsiyXOT/HvP9Y9OcmmS29qfjx4w9lzvZk7xCRwPXFlVX27jvTnJdUluT3JZksfNFyjJkUk2zWn7wYhwkmVJXprkP5PcnOSDbaFLkpVJ3tu239p+ph9d4D7XJHlZkivbkel3JVnZc/6pSb7YxvlMkp+Zc+3/S3I5cMfcArSf3lHvmc+b5LQkNyb5VpJfTnJMkq8n+U6SP+y5dujPn+R1NP8I8tb2b+Gtbf+fTPLR9h5fS/Krc3J8W3t+c5JPJTlgmM8pSdIMi09J6t6zgfe125NnCqGqWtu2vaEdDXtaVT0LuBZ4Wtv2hiT7AB8GXgvcH3gx8PdJHthzjxOBXwd+BNi17QPw8+3P+7bxLu5NrC1YPgy8BXgA8OfAh5M8YIDYc/0DsCrJY3vansXsUc9LgYe3n+P9wN/1FntD+B3gl4HHA3sDtwBntOeeA9wH2K/9TM8Hvr9IrGcCTwZ+HPgJ4BUASR4BnAn83zbO3wLrMnvq9AnAU2i+36lt+By9fgxYCewDnA68HTgJeBRN0fhHSQ5s+w79+avq5cCngVPbv4VTk9wL+CjN7+JHaP6x4K+THDzn+3kNsAr4Is3frCRJQ7P4lKQOtYXYAcAHq+oy4D9pirlhnASsr6r1VTVdVR8FNtBMZ53xrqr6elV9H/ggTYE3iKcA/1FVZ1fVVFWdA3wVeNqwsdvzf0dTbJPkIJrC6f09fd5bVTe39/ozYDfgoQPm2uv5wMuralNV3QW8Cnh6O/q4haboekhVba2qy6rq9kVivbWqrquq7wCvoykoAU4B/raqLmnjvBu4Czi859q3tNcuVtwOagvwuqraApxLU+y9uao2V9UVwJXAz4748z8VuKaq3tX+Tr4A/D3wjJ4+H66qC9v7vBw4Isl+I/i8kqR7GItPSerWc4ALquqm9vj99Ey9HdABwDPaKZS3JrkVeCzwoJ4+/9Wz/z1gTwazN/DNOW3fpBl925bY725zXUkz6nl+Vd04czLJi5Nc1U7xvZVmhG6habyLOQD4h57v4ypgK/CjwNnA+cC5SW5I8oYkuywS67qe/W/SfCcz93jRnO99v57zc6/dXjdX1dZ2f6aY/e+e89/nf7/7UX3+A4DD5nzGZ9KMws74wWesqu8C32H2dyBJ0kCGej5FkjS4JLsDvwosTzJTwO0G3DfJz1bVl2gWIZprbtt1wNlV9bxtSGO++L1uoClAeu0PfGQb7gXw7zTFyXE0I7anzZxon+88DXgicEVVTSe5Bcg8ce4A9ui5djnQO834OuA3quqiBfL4Y+CP0yz2tB74GvDOBfr2juLtT/OdzNzjdVX1ugWug/7fb1e29fPP97f1qap60iL3+sH3k2RPminTNyzcXZKk+TnyKUnd+WWa0aiDaaaqPhx4GM1zd89u+/w38OA5181tey/wtCRPTrK8XVDmyCT7DpDDt4Hpee4xYz3wE0lOTLIiya+1+f7LALF/SFUV8B7g9cB9gX/uOb0XMNXmtCLJ6cC9Fwj1dWBlkqe0o3avoCncZ7wNeN3M4jdJHpjkuHb/F5L8dFuw3k4zDXV6kbRfkGTf9vnXlwMfaNvfDjw/yWFp3KvNZ68Bv44ubevnn/u39S80v/9nJdml3X4uycN6+hyT5LFJdqV59vOzVTXKEV9J0j2Exackdec5NM9LXltV/zWzAW8Fntk+n/dO4OB2yuM/ttf9CfCKtu3F7f/RPw74Q5rC7TrgJQzw3/Cq+h7Nc4wXtfEOn3P+Zprn/l4E3EwzMvnUnmnC2+I9NCOIH2ifE5xxPs2I6tdpprfeyQLTVqvqNuC3gXcA19OMhPaufvtmYB1wQZLNNKsJH9ae+zHgPJrC6yrgUzRTURfyfuAC4GqaZ3Jf2+awAXgeze/rFmAj8Nw+n31H2dbP/2aaZ0NvSfKWqtoM/CLNQkM30Eyxfj2zC/33A6+kGdF+FM2ItiRJQ0vzj9SSJN3zJLkG+M2q+rdx5zKJkpwFbKqqV4w7F0nSzs+RT0mSJElS5yw+JUmSJEmzJDkzyY1JvrLA+SR5S5KNSS5P8si+MZ12K0mSJEnqleTnge8C76mqQ+Y5fwzwOzTvHT+M5t3Uh83t18uRT0mSJEnSLFV1Ic1icws5jqYwrar6LM2r5B60SH+LT0mSJEnS0PZh9qr1m9q2Ba3oNJ0h7bpsZe2+fHSvT1vxkJGFYstXt44u2KTbc/fRxfru90cXS5IkSerAndzB3XVXxp1HV578C/eqm78zu5657PK7rqB57dmMtVW1tss8Jqr43H35Xhxxv18ZWbz7v2t0BeO3H7t5ZLEAmJ7cYnb6kY8YWaxln/7CyGJJkiRJXbikPjbuFDp103em+MxHZg9Krtz7G3dW1ZrtCHs9sF/P8b5t24KcditJkiRJS9g0xV01NWsbgXXAs9tVbw8Hbquqby12wUSNfEqSJEmSRquALUwPdU2Sc4AjgVVJNgGvBHYBqKq3AetpVrrdCHwP+PV+MS0+JUmSJGkJK+CuGq74rKoT+pwv4AXDxLT4lCRJkqQlbLqKO6vGnUZ3z3wmOTPJjUm+0tU9JEmSJEmLK8KWmr2NQ5cLDp0FHN1hfEmSJElSHwXcWctnbePQ2bTbqrowyequ4kuSJEmS+psm3Fnjf+Jy/BlIkiRJkjpThLvHNNrZa+zFZ5JTgFMAVi7bc8zZSJIkSdLSMl3hztpl3GmMv/isq
rXAWoD77PLA8S/BJEmSJElLSPOeT0c+JUmSJEkdKsKd0+Mf+ezyVSvnABcDD02yKcnJXd1LkiRJkjS/ZsGhXWdt49DlarcndBVbkiRJkjSY5j2fTruVJEmSJHVoupb4tFtJkiRJ0vjNjHz2bv0kOTrJ15JsTPLSec7vn+QTSb6Q5PIkx/SL6cinJEmSJC1hxXCvWkmyHDgDeBKwCbg0ybqqurKn2yuAD1bV3yQ5GFgPrF4srsWnJEmSJC1h0xXuGm7a7aHAxqq6GiDJucBxQG/xWcC92/37ADf0C2rxKUmSJElLWMGwCw7tA1zXc7wJOGxOn1cBFyT5HeBewFH9gk5W8fnjy1i2dnTL/n7rj/YeWaxddvnKyGIBZMXKkcWavuOOkcUCWPbpL4w0niRJkqTxWeA9n6uSbOg5XltVa4cIewJwVlX9WZIjgLOTHFJV0wtdMFnFpyRJkiRppJpptz9U+t1UVWsWuOR6YL+e433btl4nA0cDVNXFSVYCq4AbF8rD1W4lSZIkaQkrwtT08llbH5cCByU5MMmuwPHAujl9rgWeCJDkYcBK4NuLBXXkU5IkSZKWsGLekc+F+1dNJTkVOB9YDpxZVVckeTWwoarWAS8C3p7khTSPlT63qmqxuBafkiRJkrSEVcGWGm7Sa1Wtp3l9Sm/b6T37VwKPGSamxackSZIkLWFFuHuIkc+udJ5B+4LSDcD1VfXUru8nSZIkSfpfVeHurfeA4hP4PeAq/vcFpJIkSZKkHaSAqSGn3Xah0wyS7As8BXhHl/eRJEmSJM1vmnD31uWztnHoeuTzL4HTgL0W6pDkFOAUgJU/umA3SZIkSdI2qIK7+79epXOdjXwmeSpwY1Vdtli/qlpbVWuqas2u9929q3QkSZIk6R6pec/nslnbOHQ58vkY4Ngkx9C8cPTeSd5bVSd1eE9JkiRJUo8q2LKURz6r6mVVtW9VrQaOBz5u4SlJkiRJO1rYOr1s1jYOA901yQ+9PHS+NkmSJEnSZKmCqa3LZm3jMOhd/2rAtnlV1Sd9x6ckSZIk7XhF2DK9fNY2Dos+85nkCODRwAOT/EHPqXsD4580LEmSJEnqa3o6Q/VPcjTwZpq67x1V9afz9PlV4FU0rxL9UlWduFjMfgsO7Qrs2fbrfQ/K7cDTB85ckiRJkjQWM9NuB5VkOXAG8CRgE3BpknVVdWVPn4OAlwGPqapbkvxIv7iLFp9V9SngU0nOqqpvDpytJEmSJGkiFGF6uEWGDgU2VtXVAEnOBY4Druzp8zzgjKq6BaCqbuwXdNBXrZyVpOY2VtUTBrxekiRJkjQONfS0232A63qONwGHzenzEwBJLqKZmvuqqvrIYkEHLT5f3LO/EvgVYGrAawf3zWVM/9aeIwu32/RtI4tVP/ngkcUatRW33THagHfeNbJQUzfeNLJYkiRJUie2jjuBbhWw9Yen3a5KsqHneG1VrR0i7ArgIOBIYF/gwiQ/XVW3LnZB/2SrLpvTdFGSzw2RmCRJkiRpHArqh0c+b6qqNQtccT2wX8/xvm1br03AJVW1BfhGkq/TFKOXLpTGoO/5vH/PtirJk4H7DHKtJEmSJGmcQm2dvfVxKXBQkgOT7AocD6yb0+cfaUY9SbKKZhru1YsFHXTa7WU0o7WhmW77DeDkAa+VJEmSJI1LQQ2x4FBVTSU5FTif5nnOM6vqiiSvBjZU1br23C8muZJm4vJLqurmxeIOOu32wIEzlSRJkiRNlunhulfVemD9nLbTe/YL+IN2G8hAxWeSlcBvA4+lGQH9NPC2qrpz0BtJkiRJksagGGSqbecGnXb7HmAz8Fft8YnA2cAzukhKkiRJkjQ6Ge5VK50YtPg8pKoO7jn+RDu3V5IkSZI0ySowASOfgz51+vkkh88cJDkM2LBIf0mSJEnSJCia4rN3G4NBi89HAZ9Jck2Sa4CLgZ9L8uUkl/e7OMn6JHtvR56SJEmSpG2U6dnbOAw67fbo7blJVR2z0LkkpwCnAKzc5d7bcxtJkiRJ0jx2pmc+X1tVz+ptSHL23LZtUVVrgbUA99n9QbW98SRJkiRJPYqhX7XShUGLz5/qPUiygmYqriRJkiRpwmXruDPo88xnkpcl2Qz8TJLbk2xuj/8b+KdBb+Izn5IkSZI0HinI1szaxmHR4rOq/qSq9gLeWFX3rqq92u0BVfWyQW9SVcdU1Q3bna0kSZIkaWg704JD/5rk5+c2VtWFI85HkiRJkjRKNRnTbgctPl/Ss78SOBS4DHjCyDOSJEmSJI3UsKOdSY4G3gwsB95RVX+6QL9fAc4Dfq6qNiwWc6Dis6qeNucG+wF/Oci1kiRJkqQxquGKzyTLgTOAJwGbgEuTrKuqK+f02wv4PeCSQeIu+sznIjYBD9vGayVJkiRJO0hopt32bn0cCmysqqur6m7gXOC4efq9Bng9cOcgeQw08pnkr2jeDgNNwfoI4PODXCtJkiRJGqMhRz6BfYDreo43AYf1dkjySGC/qvpwkt7HNBc06DOfV9LM9QW4FTinqi4a8NqB1V13Mf0f3xhdvOnq32lAb7x6tB/3tIc+fmSxprZMjSwWANMjfBo5I17GuUb3O5UkSZLuKeYZ7VyVpPcZzbVVtXagWMky4M+B5w6Tw6LFZ5IVwP8HfgO4tm3eHzgzyeeqasswN5MkSZIk7WDzr3Z7U1WtWeCK64H9eo73bdtm7AUcAnwyzWDTjwHrkhy72KJD/Z75fCNwf+DAqnpkVT0SeDBwX+BNfa6VJEmSJE2AId/zeSlwUJIDk+wKHA+smzlZVbdV1aqqWl1Vq4HPAosWntC/+Hwq8Lyq2txzo9uB3wKO6ZuyJEmSJGmsUsMVn1U1BZwKnA9cBXywqq5I8uokx25rHv2e+ayqH37Irqq2JvHhO0mSJEnaCQz7ns+qWg+sn9N2+gJ9jxwkZr+RzyuTPHtuY5KTgK8OcgNJkiRJ0hjV0K9a6US/kc8XAB9K8hvAZW3bGmB34P8sdmGSM2mm7d5YVYdsb6KSJEmSpG2zbEwFZ69Fi8+quh44LMkTgJ9qm9dX1ccGiH0W8FbgPduVoSRJkiRp2xUw5LTbLgz0ns+q+jjw8WECV9WFSVZvQ06SJEmSpBEJsGzr+JfsGaj47FKSU4BTAFayx5izkSRJkqQlZv73fO5w/RYc6lxVra2qNVW1ZpfsNu50JEmSJGnJGfI9n50Y+8inJEmSJKk7qZ1gwSFJkiRJ0s4v0+N/5rOzabdJzgEuBh6aZFOSk7u6lyRJkiRpAQWZmr2NQ2cjn1V1QlexJUmSJEkDqslY7XbsCw5JkiRJkroThl9wKMnRSb6WZGOSl85z/g+SXJnk8iQfS3JAv5gWn5IkSZK0lFWRrbO3xSRZDpwB/BJwMHBCkoPndPsCsKaqfgY4D3hDvzQsPiVJkiRpKSuGKj6BQ4GNVXV1Vd0NnAscNytk1Seq6nvt4WeBffsFtfiUJEmSpCVuyOJzH+C6nuNNbdtCTgb+tV/QyXrVSkFNjWnppT5esvrwkcY7/4ZLRhbryXs/fGSxRq7G/2CzJEmSdE+W
+RccWpVkQ8/x2qpaO3Ts5CRgDfD4fn0nq/iUJEmSJI3cPKOdN1XVmgW6Xw/s13O8b9s2O2ZyFPBy4PFVdVe/HJx2K0mSJElL2ZALDgGXAgclOTDJrsDxwLreDkkeAfwtcGxV3ThIGo58SpIkSdJSVpCpAd6vMtO9airJqcD5wHLgzKq6IsmrgQ1VtQ54I7An8HdJAK6tqmMXi2vxKUmSJElLXKaHW4ulqtYD6+e0nd6zf9SwOVh8SpIkSdISlqqhRj670ukzn0lemOSKJF9Jck6SlV3eT5IkSZI0j+np2dsYdFZ8JtkH+F1gTVUdQjNX+Piu7idJkiRJmkf7zGfvNg5dT7tdAeyeZAuwB3BDx/eTJEmSJPWqgqU87baqrgfeBFwLfAu4raou6Op+kiRJkqT5ZXp61jYOXU67vR9wHHAgsDdwryQnzdPvlCQbkmzYQt/3kkqSJEmShlEFU1tnb2PQ5YJDRwHfqKpvV9UW4EPAo+d2qqq1VbWmqtbswm4dpiNJkiRJ90AFbJ2evY1Bl898XgscnmQP4PvAE4ENHd5PkiRJkjRXFUxNjTuL7orPqrokyXnA54Ep4AvA2q7uJ0mSJEmaT8HW8Uy17dXpardV9UrglV3eQ5IkSZK0iGLpF5+SJEmSpDGroiZg2m2XCw5JkiRJksatCrZMzd76SHJ0kq8l2ZjkpfOc3y3JB9rzlyRZ3S+mxackSZIkLXG1deusbTFJlgNnAL8EHAyckOTgOd1OBm6pqocAfwG8vl8OFp+SJEmStJTNrHbbuy3uUGBjVV1dVXcD5wLHzelzHPDudv884IlJslhQi09JkiRJWsKqaqiRT2Af4Lqe401t27x9qmoKuA14wGJBJ2rBoc3cctO/1Xnf7NNtFXDTCG87yngDx1r+oFHG2zhQsMHj7fBYo45nbpMRz9zGH2vS45nb+GNNejxzG3+sSY9nbuOPNenxBo11wIjuN5E2c8v5H536wKo5zSuTbOg5XltVnb4ac6KKz6p6YL8+STZU1ZpR3XOU8SY5t1HHM7fxx5r0eOY2/liTHs/cxh9r0uOZ2/hjTXo8cxt/rEmPN+rcdlZVdfSQl1wP7NdzvG/bNl+fTUlWAPcBbl4sqNNuJUmSJEm9LgUOSnJgkl2B44F1c/qsA57T7j8d+HhV1WJBJ2rkU5IkSZI0XlU1leRU4HxgOXBmVV2R5NXAhqpaB7wTODvJRuA7NAXqonbG4nPU85BHGW+Scxt1PHMbf6xJj2du44816fHMbfyxJj2euY0/1qTHM7fxx5r0eJ0+w7iUVdV6YP2cttN79u8EnjFMzPQZGZUkSZIkabv5zKckSZIkqXM7TfGZZL8kn0hyZZIrkvzeuHOakeTMJDcm+coIY76w/ZxfSXJOkpWjii1JkiRJO9pOU3wCU8CLqupg4HDgBUkOHnNOM84Chl2+eEFJ9gF+F1hTVYfQPOTb9wFeSZIkSZpUO03xWVXfqqrPt/ubgauAfcabVaOqLqRZ4WmUVgC7t+/M2QO4YcTxt0mS1Um+muR9Sa5Kcl6SPSYkp7OSfL3N7agkFyX5jySHjjM/SZIkSTtR8dkryWrgEcAl482kG1V1PfAm4FrgW8BtVXXBeLOa5aHAX1fVw4Dbgd8ecz4ADwH+DPjJdjsReCzwYuAPx5iXJEmSJHbC4jPJnsDfA79fVbePO58uJLkfcBxwILA3cK8kJ403q1muq6qL2v330hR54/aNqvpyVU0DVwAfa19y+2Vg9VgzkyRJkrRzFZ9JdqEpPN9XVR8adz4dOoqmmPp2VW0BPgQ8esw59Zr7fp5JeF/PXT370z3H0+yc77OVJEmSlpSdpvhMEuCdwFVV9efjzqdj1wKHJ9mj/dxPpHnGdVLsn+SIdv9E4N/HmYwkSZKkybfTFJ/AY4BnAU9I8sV2O2Z7AiZZn2Tv7U0syTnAxcBDk2xKcvL2xKuqS4DzgM/TTBtdBqzd3jxH6Gs0qw1fBdwP+JvtCTaq30MXJjk3SZIkaWeS5rE4aTDtYk//0r4CRpIkSZIGsjONfEqSJEmSdlKOfEqSJEmSOufIpyRJkiSpcxafkiRJkqTOWXxKkiRJkjpn8SlJmkhJvttBzNVJThx1XEmS1J/FpyTpnmQ1YPEpSdIYWHxKkiZakiOTfDLJeUm+muR9SdKeuybJG5J8OcnnkjykbT8rydN7YsyMov4p8LgkX0zywh3/aSRJuuey+JQk7QweAfw+cDDwYOAxPeduq6qfBt4K/GWfOC8FPl1VD6+qv+gkU0mSNC+LT0nSzuBzVbWpqqaBL9JMn51xTs/PI3Z0YpIkaTAWn5KkncFdPftbgRU9xzXP/hTt/8YlWQbs2ml2kiSpL4tPSdLO7td6fl7c7l8DPKrdPxbYpd3fDOy1wzKTJEk/sKJ/F0mSJtr9klxOMzp6Qtv2duCfknwJ+AhwR9t+ObC1bT/L5z4lSdpxUlX9e0mSNIGSXAOsqaqbxp2LJElanNNuJUmSJEmdc+RTkiRJktQ5Rz4lSZIkSZ2z+JQkSZIkdc7iU5IkSZLUOYtPSZIkSVLnLD4lSZIkSZ2z+JQkSZIkde5/AO+TGxMoUnl3AAAAAElFTkSuQmCC\n", 688 | "text/plain": [ 689 | "" 690 | ] 691 | }, 692 | "metadata": {}, 693 | "output_type": "display_data" 694 | } 695 | ], 696 | "source": [ 697 | "i = random.randint(0, m)\n", 698 | "\n", 699 | "def plot_attention_graph(model, x, Tx, Ty, human_vocab, layer=7):\n", 700 | " # Process input\n", 701 | " tokens = np.array([tokenize(x, human_vocab, Tx)])\n", 702 | " tokens_oh = oh_2d(tokens, len(human_vocab))\n", 703 | " \n", 704 | " # Monitor model layer\n", 705 | " layer = model.layers[layer]\n", 706 | " \n", 707 | " layer_over_time = K.function(model.inputs, [layer.get_output_at(t) for t in range(Ty)])\n", 708 | " layer_output = layer_over_time([tokens_oh])\n", 709 | " layer_output = [row.flatten().tolist() for row in layer_output]\n", 710 | " \n", 711 | " # Get model output\n", 712 | " prediction = get_prediction(model, tokens_oh)[1]\n", 713 | " \n", 714 | " # Graph the data\n", 715 | " fig = plt.figure()\n", 716 | " 
fig.set_figwidth(20)\n", 717 | " fig.set_figheight(1.8)\n", 718 | " ax = fig.add_subplot(111)\n", 719 | " \n", 720 | " plt.title(\"Attention Values per Timestep\")\n", 721 | " \n", 722 | " plt.rc('figure')\n", 723 | " cax = plt.imshow(layer_output, vmin=0, vmax=1)\n", 724 | " fig.colorbar(cax)\n", 725 | " \n", 726 | " plt.xlabel(\"Input\")\n", 727 | " ax.set_xticks(range(Tx))\n", 728 | " ax.set_xticklabels(x)\n", 729 | " \n", 730 | " plt.ylabel(\"Output\")\n", 731 | " ax.set_yticks(range(Ty))\n", 732 | " ax.set_yticklabels(prediction)\n", 733 | " \n", 734 | " plt.show()\n", 735 | " \n", 736 | "plot_attention_graph(model, dataset[i][0], Tx, Ty, human_vocab)" 737 | ] 738 | }, 739 | { 740 | "cell_type": "code", 741 | "execution_count": null, 742 | "metadata": {}, 743 | "outputs": [], 744 | "source": [] 745 | } 746 | ], 747 | "metadata": { 748 | "kernelspec": { 749 | "display_name": "Python 3", 750 | "language": "python", 751 | "name": "python3" 752 | }, 753 | "language_info": { 754 | "codemirror_mode": { 755 | "name": "ipython", 756 | "version": 3 757 | }, 758 | "file_extension": ".py", 759 | "mimetype": "text/x-python", 760 | "name": "python", 761 | "nbconvert_exporter": "python", 762 | "pygments_lexer": "ipython3", 763 | "version": "3.6.1" 764 | } 765 | }, 766 | "nbformat": 4, 767 | "nbformat_minor": 2 768 | } 769 | -------------------------------------------------------------------------------- /README.MD: -------------------------------------------------------------------------------- 1 | ## Description 2 | This is a ready-to-run example of an attention network. All attention network code is contained underneath Section "Model" of the Attention Network ipython notebook. 3 | 4 | For a tutorial on Attention Networks, see [MuffinTech](http://muffintech.org/blog/id/12). 5 | 6 | ## Requirements 7 | 8 | Run `pip install -r requirements.txt` to install requirements. The project requires python 3. 9 | 10 | ## Contributing 11 | If there are any improvements or broken links, feel free to create a PR. (Email for quick reponses.) 12 | -------------------------------------------------------------------------------- /data/Time Vocabs.json: -------------------------------------------------------------------------------- 1 | [{" ": 0, "'": 1, ".": 2, "0": 3, "1": 4, "2": 5, "3": 6, "4": 7, "5": 8, "6": 9, "7": 10, "8": 11, "9": 12, ":": 13, "a": 14, "b": 15, "c": 16, "d": 17, "e": 18, "f": 19, "g": 20, "h": 21, "i": 22, "k": 23, "l": 24, "m": 25, "n": 26, "o": 27, "p": 28, "q": 29, "r": 30, "s": 31, "t": 32, "u": 33, "v": 34, "w": 35, "x": 36, "y": 37, "z": 38, "": 39, "": 40}, {"0": 0, "1": 1, "2": 2, "3": 3, "4": 4, "5": 5, "6": 6, "7": 7, "8": 8, "9": 9, ":": 10}] -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | Keras>=2.1.6 2 | numpy 3 | matplotlib 4 | jupyter 5 | --------------------------------------------------------------------------------
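The README's install step can be sanity-checked with a few lines of Python — a small illustrative snippet, assuming it is run with the repository root as the working directory (the same relative paths the notebook uses):

```python
import json

# Load the dataset and the two vocabularies the same way the notebook does.
with open('data/Time Dataset.json', 'r') as f:
    dataset = json.loads(f.read())
with open('data/Time Vocabs.json', 'r') as f:
    human_vocab, machine_vocab = json.loads(f.read())

# Print the number of pairs, one example pair, and the vocabulary sizes.
print(len(dataset), 'pairs, e.g.', dataset[0])
print(len(human_vocab), 'input tokens,', len(machine_vocab), 'output tokens')
```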