├── LICENSE ├── PAIGCP_Bag_of_Words.ipynb ├── PAIGCP_Basic_MNIST.ipynb ├── PAIGCP_Basic_XOR.ipynb ├── PAIGCP_CNN_layers.ipynb ├── PAIGCP_Embeddings.ipynb ├── PAIGCP_Generation.ipynb ├── PAIGCP_Translation.ipynb ├── PAIGCP_XOR_perceptron.ipynb ├── PAIGCP_autoencoder.ipynb ├── PAIGCP_fashion_CNN.ipynb ├── PAIGCP_image_captioning.ipynb ├── PAIGCP_text_cleaning.ipynb └── README.md /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 
34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. 
Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. 
In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. 
We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /PAIGCP_Bag_of_Words.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "colab": { 6 | "name": "PAIGCP_Bag_of_Words", 7 | "provenance": [], 8 | "collapsed_sections": [], 9 | "authorship_tag": "ABX9TyMaC0WHSnWsQoGaeTtdU7A3", 10 | "include_colab_link": true 11 | }, 12 | "kernelspec": { 13 | "name": "python3", 14 | "display_name": "Python 3" 15 | } 16 | }, 17 | "cells": [ 18 | { 19 | "cell_type": "markdown", 20 | "metadata": { 21 | "id": "view-in-github", 22 | "colab_type": "text" 23 | }, 24 | "source": [ 25 | "\"Open" 26 | ] 27 | }, 28 | { 29 | "cell_type": "code", 30 | "metadata": { 31 | "id": "DbkMepZnYVHs", 32 | "colab_type": "code", 33 | "colab": {} 34 | }, 35 | "source": [ 36 | "import urllib.request" 37 | ], 38 | "execution_count": 10, 39 | "outputs": [] 40 | }, 41 | { 42 | "cell_type": "markdown", 43 | "metadata": { 44 | "id": "SLUJ5DYpY_yH", 45 | "colab_type": "text" 46 | }, 47 | "source": [ 48 | "Download Moby Dick from Project Gutenberg" 49 | ] 50 | 
}, 51 | { 52 | "cell_type": "code", 53 | "metadata": { 54 | "id": "dPyg06asYkY_", 55 | "colab_type": "code", 56 | "colab": {} 57 | }, 58 | "source": [ 59 | "url = \"https://www.gutenberg.org/files/2701/2701-0.txt\"\n", 60 | "file = urllib.request.urlopen(url)" 61 | ], 62 | "execution_count": 11, 63 | "outputs": [] 64 | }, 65 | { 66 | "cell_type": "markdown", 67 | "metadata": { 68 | "id": "XFw0-dPOZDCZ", 69 | "colab_type": "text" 70 | }, 71 | "source": [ 72 | "Load the text" 73 | ] 74 | }, 75 | { 76 | "cell_type": "code", 77 | "metadata": { 78 | "id": "ER0LYVxLYo2k", 79 | "colab_type": "code", 80 | "colab": { 81 | "base_uri": "https://localhost:8080/", 82 | "height": 52 83 | }, 84 | "outputId": "f0934ca1-dff6-4f83-893a-8b3daedc9760" 85 | }, 86 | "source": [ 87 | "text = [line.decode('utf-8') for line in file]\n", 88 | "text = ''.join(text)\n", 89 | "text[7600:8000]" 90 | ], 91 | "execution_count": 12, 92 | "outputs": [ 93 | { 94 | "output_type": "execute_result", 95 | "data": { 96 | "application/vnd.google.colaboratory.intrinsic+json": { 97 | "type": "string" 98 | }, 99 | "text/plain": [ 100 | "'ok whatsoever,\\r\\n sacred or profane. Therefore you must not, in every case at least,\\r\\n take the higgledy-piggledy whale statements, however authentic, in\\r\\n these extracts, for veritable gospel cetology. Far from it. 
As\\r\\n touching the ancient authors generally, as well as the poets here\\r\\n appearing, these extracts are solely valuable or entertaining, as\\r\\n affording a glancing bird’s eye view o'" 101 | ] 102 | }, 103 | "metadata": { 104 | "tags": [] 105 | }, 106 | "execution_count": 12 107 | } 108 | ] 109 | }, 110 | { 111 | "cell_type": "markdown", 112 | "metadata": { 113 | "id": "yN2YwlWcY72h", 114 | "colab_type": "text" 115 | }, 116 | "source": [ 117 | "Tokenize\n" 118 | ] 119 | }, 120 | { 121 | "cell_type": "code", 122 | "metadata": { 123 | "id": "WmXGjk_YY9Wc", 124 | "colab_type": "code", 125 | "colab": { 126 | "base_uri": "https://localhost:8080/", 127 | "height": 425 128 | }, 129 | "outputId": "b841031b-4127-438a-f396-9bed2d2bd687" 130 | }, 131 | "source": [ 132 | "import nltk\n", 133 | "nltk.download('punkt')\n", 134 | "from nltk import word_tokenize\n", 135 | "tokens = word_tokenize(text)\n", 136 | "tokens[200:222]" 137 | ], 138 | "execution_count": 13, 139 | "outputs": [ 140 | { 141 | "output_type": "stream", 142 | "text": [ 143 | "[nltk_data] Downloading package punkt to /root/nltk_data...\n", 144 | "[nltk_data] Package punkt is already up-to-date!\n" 145 | ], 146 | "name": "stdout" 147 | }, 148 | { 149 | "output_type": "execute_result", 150 | "data": { 151 | "text/plain": [ 152 | "['.',\n", 153 | " 'A',\n", 154 | " 'Bosom',\n", 155 | " 'Friend',\n", 156 | " '.',\n", 157 | " 'CHAPTER',\n", 158 | " '11',\n", 159 | " '.',\n", 160 | " 'Nightgown',\n", 161 | " '.',\n", 162 | " 'CHAPTER',\n", 163 | " '12',\n", 164 | " '.',\n", 165 | " 'Biographical',\n", 166 | " '.',\n", 167 | " 'CHAPTER',\n", 168 | " '13',\n", 169 | " '.',\n", 170 | " 'Wheelbarrow',\n", 171 | " '.',\n", 172 | " 'CHAPTER',\n", 173 | " '14']" 174 | ] 175 | }, 176 | "metadata": { 177 | "tags": [] 178 | }, 179 | "execution_count": 13 180 | } 181 | ] 182 | }, 183 | { 184 | "cell_type": "markdown", 185 | "metadata": { 186 | "id": "GHvzrkUsZUs8", 187 | "colab_type": "text" 188 | }, 189 | "source": [ 
190 | "Then clean" 191 | ] 192 | }, 193 | { 194 | "cell_type": "code", 195 | "metadata": { 196 | "id": "bb2pv0mLZLxV", 197 | "colab_type": "code", 198 | "colab": { 199 | "base_uri": "https://localhost:8080/", 200 | "height": 391 201 | }, 202 | "outputId": "08e76702-62f0-460f-b175-aa43a9a7952f" 203 | }, 204 | "source": [ 205 | "import string\n", 206 | "tokens = [word for word in tokens if word.isalpha()]\n", 207 | "table = str.maketrans('', '', string.punctuation)\n", 208 | "tokens = [w.translate(table) for w in tokens]\n", 209 | "tokens = [word.lower() for word in tokens]\n", 210 | "tokens[200:222]" 211 | ], 212 | "execution_count": 14, 213 | "outputs": [ 214 | { 215 | "output_type": "execute_result", 216 | "data": { 217 | "text/plain": [ 218 | "['specksnyder',\n", 219 | " 'chapter',\n", 220 | " 'the',\n", 221 | " 'chapter',\n", 222 | " 'the',\n", 223 | " 'chapter',\n", 224 | " 'the',\n", 225 | " 'chapter',\n", 226 | " 'sunset',\n", 227 | " 'chapter',\n", 228 | " 'dusk',\n", 229 | " 'chapter',\n", 230 | " 'first',\n", 231 | " 'chapter',\n", 232 | " 'midnight',\n", 233 | " 'forecastle',\n", 234 | " 'chapter',\n", 235 | " 'moby',\n", 236 | " 'dick',\n", 237 | " 'chapter',\n", 238 | " 'the',\n", 239 | " 'whiteness']" 240 | ] 241 | }, 242 | "metadata": { 243 | "tags": [] 244 | }, 245 | "execution_count": 14 246 | } 247 | ] 248 | }, 249 | { 250 | "cell_type": "markdown", 251 | "metadata": { 252 | "id": "zA8tdJU0ZXZC", 253 | "colab_type": "text" 254 | }, 255 | "source": [ 256 | "Stop Words and Stemming" 257 | ] 258 | }, 259 | { 260 | "cell_type": "code", 261 | "metadata": { 262 | "id": "_qQo2cT7ZZg5", 263 | "colab_type": "code", 264 | "colab": { 265 | "base_uri": "https://localhost:8080/", 266 | "height": 425 267 | }, 268 | "outputId": "f7f2880e-7d28-4542-cc24-5a63166b9827" 269 | }, 270 | "source": [ 271 | "from nltk.corpus import stopwords\n", 272 | "nltk.download('stopwords')\n", 273 | "stop_words = set(stopwords.words('english'))\n", 274 | "tokens = [w for w in tokens 
if not w in stop_words]\n", 275 | "\n", 276 | "from nltk.stem.porter import PorterStemmer\n", 277 | "porter = PorterStemmer()\n", 278 | "tokens = [porter.stem(word) for word in tokens]\n", 279 | "tokens[200:222]" 280 | ], 281 | "execution_count": 15, 282 | "outputs": [ 283 | { 284 | "output_type": "stream", 285 | "text": [ 286 | "[nltk_data] Downloading package stopwords to /root/nltk_data...\n", 287 | "[nltk_data] Unzipping corpora/stopwords.zip.\n" 288 | ], 289 | "name": "stdout" 290 | }, 291 | { 292 | "output_type": "execute_result", 293 | "data": { 294 | "text/plain": [ 295 | "['wood',\n", 296 | " 'stone',\n", 297 | " 'mountain',\n", 298 | " 'star',\n", 299 | " 'chapter',\n", 300 | " 'brit',\n", 301 | " 'chapter',\n", 302 | " 'squid',\n", 303 | " 'chapter',\n", 304 | " 'line',\n", 305 | " 'chapter',\n", 306 | " 'stubb',\n", 307 | " 'kill',\n", 308 | " 'whale',\n", 309 | " 'chapter',\n", 310 | " 'dart',\n", 311 | " 'chapter',\n", 312 | " 'crotch',\n", 313 | " 'chapter',\n", 314 | " 'stubb',\n", 315 | " 'supper',\n", 316 | " 'chapter']" 317 | ] 318 | }, 319 | "metadata": { 320 | "tags": [] 321 | }, 322 | "execution_count": 15 323 | } 324 | ] 325 | }, 326 | { 327 | "cell_type": "markdown", 328 | "metadata": { 329 | "id": "2EHfQLMOZoOt", 330 | "colab_type": "text" 331 | }, 332 | "source": [ 333 | "Vocabulary - word counts" 334 | ] 335 | }, 336 | { 337 | "cell_type": "code", 338 | "metadata": { 339 | "id": "Npw5JfpeZp-2", 340 | "colab_type": "code", 341 | "colab": { 342 | "base_uri": "https://localhost:8080/", 343 | "height": 187 344 | }, 345 | "outputId": "e295aa34-995b-4615-ef3f-538eedbbe44d" 346 | }, 347 | "source": [ 348 | "from nltk.probability import FreqDist\n", 349 | "\n", 350 | "word_counts = FreqDist(tokens)\n", 351 | "\n", 352 | "top = 500\n", 353 | "vocabulary = word_counts.most_common(top)\n", 354 | " \n", 355 | "vocabulary[:10]" 356 | ], 357 | "execution_count": 17, 358 | "outputs": [ 359 | { 360 | "output_type": "execute_result", 361 | "data": { 362 | 
"text/plain": [ 363 | "[('whale', 1455),\n", 364 | " ('one', 920),\n", 365 | " ('like', 590),\n", 366 | " ('upon', 567),\n", 367 | " ('ship', 553),\n", 368 | " ('ye', 521),\n", 369 | " ('man', 496),\n", 370 | " ('ahab', 495),\n", 371 | " ('sea', 461),\n", 372 | " ('seem', 460)]" 373 | ] 374 | }, 375 | "metadata": { 376 | "tags": [] 377 | }, 378 | "execution_count": 17 379 | } 380 | ] 381 | }, 382 | { 383 | "cell_type": "markdown", 384 | "metadata": { 385 | "id": "mXj4YLGYbccP", 386 | "colab_type": "text" 387 | }, 388 | "source": [ 389 | "Count Vector - word vector" 390 | ] 391 | }, 392 | { 393 | "cell_type": "code", 394 | "metadata": { 395 | "id": "i6txpap3bZmj", 396 | "colab_type": "code", 397 | "colab": { 398 | "base_uri": "https://localhost:8080/", 399 | "height": 34 400 | }, 401 | "outputId": "c001cdd5-972b-486f-d677-327132f2b1b4" 402 | }, 403 | "source": [ 404 | "import numpy as np\n", 405 | "\n", 406 | "voc_size = len(vocabulary)\n", 407 | "doc_vector = np.zeros(voc_size)\n", 408 | " \n", 409 | "word_vector = [(idx,word_counts[word[0]]) for idx, word in enumerate(vocabulary) if word[0] in word_counts.keys()] \n", 410 | "word_vector[10]" 411 | ], 412 | "execution_count": 18, 413 | "outputs": [ 414 | { 415 | "output_type": "execute_result", 416 | "data": { 417 | "text/plain": [ 418 | "(10, 443)" 419 | ] 420 | }, 421 | "metadata": { 422 | "tags": [] 423 | }, 424 | "execution_count": 18 425 | } 426 | ] 427 | }, 428 | { 429 | "cell_type": "code", 430 | "metadata": { 431 | "id": "rrUAuUyhb0s8", 432 | "colab_type": "code", 433 | "colab": { 434 | "base_uri": "https://localhost:8080/", 435 | "height": 969 436 | }, 437 | "outputId": "764493bb-1161-44d6-8a1e-7d817960a975" 438 | }, 439 | "source": [ 440 | "for idx, count in word_vector:\n", 441 | " doc_vector[idx] = count\n", 442 | "\n", 443 | "doc_vector" 444 | ], 445 | "execution_count": 19, 446 | "outputs": [ 447 | { 448 | "output_type": "execute_result", 449 | "data": { 450 | "text/plain": [ 451 | "array([1455., 
920., 590., 567., 553., 521., 496., 495., 461.,\n", 452 | " 460., 443., 434., 429., 424., 364., 342., 338., 337.,\n", 453 | " 331., 322., 317., 315., 312., 312., 311., 307., 307.,\n", 454 | " 298., 292., 284., 280., 278., 277., 277., 268., 268.,\n", 455 | " 266., 256., 255., 251., 249., 247., 243., 241., 240.,\n", 456 | " 238., 236., 231., 230., 228., 224., 222., 217., 217.,\n", 457 | " 215., 211., 211., 205., 204., 204., 203., 201., 201.,\n", 458 | " 196., 196., 193., 192., 191., 190., 189., 184., 182.,\n", 459 | " 182., 180., 179., 176., 176., 175., 171., 171., 168.,\n", 460 | " 168., 168., 167., 167., 164., 161., 159., 159., 159.,\n", 461 | " 153., 153., 153., 152., 148., 143., 142., 140., 139.,\n", 462 | " 138., 137., 134., 132., 132., 130., 129., 129., 128.,\n", 463 | " 128., 127., 126., 126., 125., 125., 125., 124., 123.,\n", 464 | " 122., 122., 122., 121., 121., 121., 120., 119., 119.,\n", 465 | " 119., 119., 118., 118., 115., 115., 114., 113., 113.,\n", 466 | " 113., 112., 112., 110., 110., 107., 107., 106., 106.,\n", 467 | " 105., 105., 105., 105., 104., 104., 103., 103., 102.,\n", 468 | " 102., 100., 99., 98., 98., 97., 97., 97., 96.,\n", 469 | " 96., 94., 94., 94., 93., 93., 93., 93., 92.,\n", 470 | " 91., 91., 91., 91., 90., 90., 90., 90., 90.,\n", 471 | " 89., 89., 88., 88., 88., 87., 87., 87., 87.,\n", 472 | " 87., 87., 86., 86., 86., 84., 84., 84., 83.,\n", 473 | " 83., 83., 82., 82., 82., 81., 81., 80., 80.,\n", 474 | " 80., 80., 80., 80., 80., 80., 80., 79., 79.,\n", 475 | " 78., 78., 78., 78., 78., 78., 77., 77., 76.,\n", 476 | " 76., 76., 76., 76., 76., 76., 76., 75., 75.,\n", 477 | " 75., 75., 75., 75., 75., 75., 75., 74., 74.,\n", 478 | " 74., 74., 73., 72., 72., 72., 72., 71., 71.,\n", 479 | " 71., 70., 70., 69., 69., 69., 68., 68., 68.,\n", 480 | " 68., 68., 68., 68., 67., 67., 67., 67., 67.,\n", 481 | " 66., 66., 65., 65., 65., 64., 64., 64., 64.,\n", 482 | " 64., 64., 63., 63., 63., 63., 62., 62., 62.,\n", 483 | " 61., 61., 61., 60., 60., 
60., 60., 60., 60.,\n", 484 | " 60., 60., 59., 59., 59., 59., 59., 59., 57.,\n", 485 | " 57., 57., 57., 57., 56., 56., 56., 56., 56.,\n", 486 | " 56., 56., 56., 56., 56., 56., 56., 56., 55.,\n", 487 | " 55., 55., 55., 55., 55., 54., 54., 54., 54.,\n", 488 | " 54., 54., 54., 54., 54., 54., 53., 53., 53.,\n", 489 | " 53., 53., 53., 53., 53., 52., 52., 52., 52.,\n", 490 | " 52., 52., 52., 51., 51., 51., 51., 51., 51.,\n", 491 | " 51., 51., 51., 51., 50., 50., 50., 50., 50.,\n", 492 | " 50., 50., 50., 50., 50., 49., 49., 49., 49.,\n", 493 | " 49., 49., 48., 48., 48., 48., 48., 48., 48.,\n", 494 | " 48., 47., 47., 47., 47., 47., 47., 47., 47.,\n", 495 | " 47., 47., 47., 47., 46., 46., 46., 46., 46.,\n", 496 | " 46., 46., 45., 45., 45., 45., 45., 45., 45.,\n", 497 | " 45., 45., 45., 45., 45., 45., 45., 44., 44.,\n", 498 | " 44., 44., 44., 44., 44., 44., 44., 44., 44.,\n", 499 | " 43., 43., 43., 43., 43., 43., 43., 43., 43.,\n", 500 | " 43., 43., 43., 43., 43., 43., 42., 42., 42.,\n", 501 | " 42., 42., 42., 42., 42., 42., 42., 42., 42.,\n", 502 | " 42., 42., 42., 42., 42., 42., 41., 41., 41.,\n", 503 | " 41., 41., 41., 41., 41., 41., 41., 41., 41.,\n", 504 | " 41., 41., 41., 41., 41., 40., 40., 40., 40.,\n", 505 | " 40., 40., 40., 40., 40., 40., 40., 40., 40.,\n", 506 | " 40., 39., 39., 39., 39.])" 507 | ] 508 | }, 509 | "metadata": { 510 | "tags": [] 511 | }, 512 | "execution_count": 19 513 | } 514 | ] 515 | }, 516 | { 517 | "cell_type": "markdown", 518 | "metadata": { 519 | "id": "9NTcOHHQcBxI", 520 | "colab_type": "text" 521 | }, 522 | "source": [ 523 | "Bag of Words\n" 524 | ] 525 | }, 526 | { 527 | "cell_type": "code", 528 | "metadata": { 529 | "id": "1zw-9VhdcAsM", 530 | "colab_type": "code", 531 | "colab": { 532 | "base_uri": "https://localhost:8080/", 533 | "height": 68 534 | }, 535 | "outputId": "49fac49d-0f63-4129-a561-b72f8eb6fdae" 536 | }, 537 | "source": [ 538 | "from nltk import sent_tokenize\n", 539 | "\n", 540 | "docs = sent_tokenize(text)[703:706]\n", 541 
| "docs" 542 | ], 543 | "execution_count": 20, 544 | "outputs": [ 545 | { 546 | "output_type": "execute_result", 547 | "data": { 548 | "text/plain": [ 549 | "['I began to twitch all over.',\n", 550 | " 'Besides, it was getting late, and my decent harpooneer ought to be home\\r\\nand going bedwards.',\n", 551 | " 'Suppose now, he should tumble in upon me at\\r\\nmidnight—how could I tell from what vile hole he had been coming?']" 552 | ] 553 | }, 554 | "metadata": { 555 | "tags": [] 556 | }, 557 | "execution_count": 20 558 | } 559 | ] 560 | }, 561 | { 562 | "cell_type": "markdown", 563 | "metadata": { 564 | "id": "VXuVQoHjcv7M", 565 | "colab_type": "text" 566 | }, 567 | "source": [ 568 | "Import helpers" 569 | ] 570 | }, 571 | { 572 | "cell_type": "code", 573 | "metadata": { 574 | "id": "yZirOVoEc1AK", 575 | "colab_type": "code", 576 | "colab": { 577 | "base_uri": "https://localhost:8080/", 578 | "height": 34 579 | }, 580 | "outputId": "ddce57c6-1b04-4e54-8b6b-1ebcaaa23cc4" 581 | }, 582 | "source": [ 583 | "from sklearn.feature_extraction.text import CountVectorizer\n", 584 | "\n", 585 | "count_vectorizer=CountVectorizer(stop_words='english')\n", 586 | "\n", 587 | "word_count_vector=count_vectorizer.fit_transform(docs)\n", 588 | "word_count_vector.shape" 589 | ], 590 | "execution_count": 21, 591 | "outputs": [ 592 | { 593 | "output_type": "execute_result", 594 | "data": { 595 | "text/plain": [ 596 | "(3, 17)" 597 | ] 598 | }, 599 | "metadata": { 600 | "tags": [] 601 | }, 602 | "execution_count": 21 603 | } 604 | ] 605 | }, 606 | { 607 | "cell_type": "code", 608 | "metadata": { 609 | "id": "AAFg_lV8c_3H", 610 | "colab_type": "code", 611 | "colab": { 612 | "base_uri": "https://localhost:8080/", 613 | "height": 68 614 | }, 615 | "outputId": "93ee33b0-dcf9-4c23-bb2e-c289b7e63296" 616 | }, 617 | "source": [ 618 | "word_count_vector.toarray()" 619 | ], 620 | "execution_count": 22, 621 | "outputs": [ 622 | { 623 | "output_type": "execute_result", 624 | "data": { 625 | 
"text/plain": [ 626 | "array([[0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0],\n", 627 | " [1, 0, 0, 1, 1, 1, 1, 0, 1, 1, 0, 1, 0, 0, 0, 0, 0],\n", 628 | " [0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 1, 0, 1, 1, 1, 0, 1]])" 629 | ] 630 | }, 631 | "metadata": { 632 | "tags": [] 633 | }, 634 | "execution_count": 22 635 | } 636 | ] 637 | }, 638 | { 639 | "cell_type": "code", 640 | "metadata": { 641 | "id": "-6-tLvWEdEzX", 642 | "colab_type": "code", 643 | "colab": { 644 | "base_uri": "https://localhost:8080/", 645 | "height": 306 646 | }, 647 | "outputId": "33b223ca-d27d-4bd7-98e3-49efbd2b785f" 648 | }, 649 | "source": [ 650 | "count_vectorizer.get_feature_names()" 651 | ], 652 | "execution_count": 23, 653 | "outputs": [ 654 | { 655 | "output_type": "execute_result", 656 | "data": { 657 | "text/plain": [ 658 | "['bedwards',\n", 659 | " 'began',\n", 660 | " 'coming',\n", 661 | " 'decent',\n", 662 | " 'getting',\n", 663 | " 'going',\n", 664 | " 'harpooneer',\n", 665 | " 'hole',\n", 666 | " 'home',\n", 667 | " 'late',\n", 668 | " 'midnight',\n", 669 | " 'ought',\n", 670 | " 'suppose',\n", 671 | " 'tell',\n", 672 | " 'tumble',\n", 673 | " 'twitch',\n", 674 | " 'vile']" 675 | ] 676 | }, 677 | "metadata": { 678 | "tags": [] 679 | }, 680 | "execution_count": 23 681 | } 682 | ] 683 | } 684 | ] 685 | } -------------------------------------------------------------------------------- /PAIGCP_Basic_MNIST.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "colab": { 6 | "name": "PAIGCP_Basic_MNIST.ipynb", 7 | "provenance": [], 8 | "authorship_tag": "ABX9TyPBneCG3ZzhL1NgmTDJEMRy", 9 | "include_colab_link": true 10 | }, 11 | "kernelspec": { 12 | "name": "python3", 13 | "display_name": "Python 3" 14 | } 15 | }, 16 | "cells": [ 17 | { 18 | "cell_type": "markdown", 19 | "metadata": { 20 | "id": "view-in-github", 21 | "colab_type": "text" 22 | }, 23 | "source": [ 24 | "\"Open" 25 | ] 26 | 
}, 27 | { 28 | "cell_type": "markdown", 29 | "metadata": { 30 | "id": "KGZmnlgHrjB9", 31 | "colab_type": "text" 32 | }, 33 | "source": [ 34 | "A Basic MNIST example" 35 | ] 36 | }, 37 | { 38 | "cell_type": "code", 39 | "metadata": { 40 | "id": "P5UbgbKBrtnr", 41 | "colab_type": "code", 42 | "colab": {} 43 | }, 44 | "source": [ 45 | "import tensorflow as tf\n", 46 | "import numpy as np\n", 47 | "import matplotlib.pyplot as plt" 48 | ], 49 | "execution_count": 1, 50 | "outputs": [] 51 | }, 52 | { 53 | "cell_type": "markdown", 54 | "metadata": { 55 | "id": "UZYqeAh_s3NI", 56 | "colab_type": "text" 57 | }, 58 | "source": [ 59 | "Loading and normalizing data" 60 | ] 61 | }, 62 | { 63 | "cell_type": "code", 64 | "metadata": { 65 | "id": "WXkDlySQs9p4", 66 | "colab_type": "code", 67 | "colab": { 68 | "base_uri": "https://localhost:8080/", 69 | "height": 51 70 | }, 71 | "outputId": "a62dadbf-853f-4a61-e21c-bdeaa7d2a6d5" 72 | }, 73 | "source": [ 74 | "mnist = tf.keras.datasets.mnist\n", 75 | "\n", 76 | "(x_train, y_train), (x_test, y_test) = mnist.load_data()\n", 77 | "x_train, x_test = x_train / 255.0, x_test / 255.0" 78 | ], 79 | "execution_count": 2, 80 | "outputs": [ 81 | { 82 | "output_type": "stream", 83 | "text": [ 84 | "Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz\n", 85 | "11493376/11490434 [==============================] - 0s 0us/step\n" 86 | ], 87 | "name": "stdout" 88 | } 89 | ] 90 | }, 91 | { 92 | "cell_type": "code", 93 | "metadata": { 94 | "id": "xvTjRrhLaKj6", 95 | "colab_type": "code", 96 | "colab": { 97 | "base_uri": "https://localhost:8080/", 98 | "height": 51 99 | }, 100 | "outputId": "98cdf407-c180-4c6c-de79-a2789ceafdbf" 101 | }, 102 | "source": [ 103 | "print(x_train.shape)\n", 104 | "print(y_train.shape)" 105 | ], 106 | "execution_count": 3, 107 | "outputs": [ 108 | { 109 | "output_type": "stream", 110 | "text": [ 111 | "(60000, 28, 28)\n", 112 | "(60000,)\n" 113 | ], 114 | "name": "stdout" 115 | } 116 
| ] 117 | }, 118 | { 119 | "cell_type": "code", 120 | "metadata": { 121 | "id": "JGMIErJFZ4ed", 122 | "colab_type": "code", 123 | "colab": { 124 | "base_uri": "https://localhost:8080/", 125 | "height": 282 126 | }, 127 | "outputId": "b9a8c9cd-b7f2-457c-c458-c88418c05ca0" 128 | }, 129 | "source": [ 130 | "plt.imshow(x_test[0])\n", 131 | "print(y_test[0])" 132 | ], 133 | "execution_count": 4, 134 | "outputs": [ 135 | { 136 | "output_type": "stream", 137 | "text": [ 138 | "7\n" 139 | ], 140 | "name": "stdout" 141 | }, 142 | { 143 | "output_type": "display_data", 144 | "data": { 145 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAPsAAAD4CAYAAAAq5pAIAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAANiklEQVR4nO3df4wc9XnH8c8n/kV8QGtDcF3j4ISQqE4aSHWBRNDKESUFImSiJBRLtVyJ5lALElRRW0QVBalVSlEIok0aySluHESgaQBhJTSNa6W1UKljg4yxgdaEmsau8QFOaxPAP/DTP24cHXD7vWNndmft5/2SVrs7z87Oo/F9PLMzO/t1RAjA8e9tbTcAoD8IO5AEYQeSIOxAEoQdSGJ6Pxc207PiBA31c5FAKq/qZzoYBzxRrVbYbV8s6XZJ0yT9bUTcXHr9CRrSeb6wziIBFGyIdR1rXe/G254m6auSLpG0WNIy24u7fT8AvVXnM/u5kp6OiGci4qCkeyQtbaYtAE2rE/YFkn4y7vnOatrr2B6xvcn2pkM6UGNxAOro+dH4iFgZEcMRMTxDs3q9OAAd1An7LkkLxz0/vZoGYADVCftGSWfZfpftmZKulLSmmbYANK3rU28Rcdj2tZL+SWOn3lZFxLbGOgPQqFrn2SPiQUkPNtQLgB7i67JAEoQdSIKwA0kQdiAJwg4kQdiBJAg7kARhB5Ig7EAShB1IgrADSRB2IAnCDiRB2IEkCDuQBGEHkiDsQBKEHUiCsANJEHYgCcIOJEHYgSQIO5AEYQeSIOxAEoQdSIKwA0kQdiAJwg4kQdiBJGoN2Wx7h6T9kl6TdDgihptoCkDzaoW98rGIeKGB9wHQQ+zGA0nUDXtI+oHtR2yPTPQC2yO2N9nedEgHai4OQLfq7sZfEBG7bJ8maa3tpyJi/fgXRMRKSSsl6WTPjZrLA9ClWlv2iNhV3Y9Kul/SuU00BaB5XYfd9pDtk44+lvRxSVubagxAs+rsxs+TdL/to+/zrYj4fiNdAWhc12GPiGcknd1gLwB6iFNvQBKEHUiCsANJEHYgCcIOJNHEhTApvPjZj3asvXP508V5nxqdV6wfPDCjWF9wd7k+e+dLHWtHNj9RnBd5sGUHkiDsQBKEHUiCsANJEHYgCcIOJEHYgSQ4zz5Ff/xH3+pY+9TQT8szn1lz4UvK5R2HX+5Yu/35j9Vc+LHrR6NndKwN3foLxXmnr3uk6XZax5YdSIKwA0kQdiAJwg4kQdiBJAg7kARhB5JwRP8GaTnZc+M8X9i35TXpZ58+r2PthQ+W/8+c82R5Hf/0V1ysz/zg/xbrt3zgvo61i97+SnHe7718YrH+idmdr5Wv65U4WKxvODBUrC854VDXy3
7P964u1t87srHr927ThlinfbF3wj8otuxAEoQdSIKwA0kQdiAJwg4kQdiBJAg7kATXs0/R0Hc2FGr13vvkerPrr39pScfan5+/qLzsfy3/5v0tS97TRUdTM/2VI8X60Jbdxfop6+8t1n91Zuff25+9o/xb/MejSbfstlfZHrW9ddy0ubbX2t5e3c/pbZsA6prKbvw3JF38hmk3SFoXEWdJWlc9BzDAJg17RKyXtPcNk5dKWl09Xi3p8ob7AtCwbj+zz4uIox+onpPUcTAz2yOSRiTpBM3ucnEA6qp9ND7GrqTpeKVHRKyMiOGIGJ6hWXUXB6BL3YZ9j+35klTdjzbXEoBe6DbsayStqB6vkPRAM+0A6JVJP7Pbvltjv1x+qu2dkr4g6WZJ37Z9laRnJV3RyyZRdvi5PR1rQ/d2rknSa5O899B3Xuyio2bs+b2PFuvvn1n+8/3S3vd1rC36u2eK8x4uVo9Nk4Y9IpZ1KB2bv0IBJMXXZYEkCDuQBGEHkiDsQBKEHUiCS1zRmulnLCzWv3LjV4r1GZ5WrP/D7b/ZsXbK7oeL8x6P2LIDSRB2IAnCDiRB2IEkCDuQBGEHkiDsQBKcZ0drnvrDBcX6h2eVh7LedrA8HPXcJ15+yz0dz9iyA0kQdiAJwg4kQdiBJAg7kARhB5Ig7EASnGdHTx34xIc71h799G2TzF0eQej3r7uuWH/7v/1okvfPhS07kARhB5Ig7EAShB1IgrADSRB2IAnCDiTBeXb01H9f0nl7cqLL59GX/ddFxfrs7z9WrEexms+kW3bbq2yP2t46btpNtnfZ3lzdLu1tmwDqmspu/DckXTzB9Nsi4pzq9mCzbQFo2qRhj4j1kvb2oRcAPVTnAN21trdUu/lzOr3I9ojtTbY3HdKBGosDUEe3Yf+apDMlnSNpt6RbO70wIlZGxHBEDM+Y5MIGAL3TVdgjYk9EvBYRRyR9XdK5zbYFoGldhd32/HFPPylpa6fXAhgMk55nt323pCWSTrW9U9IXJC2xfY7GTmXukHR1D3vEAHvbSScV68t//aGOtX1HXi3OO/rFdxfrsw5sLNbxepOGPSKWTTD5jh70AqCH+LoskARhB5Ig7EAShB1IgrADSXCJK2rZftP7i/Xvnvo3HWtLt3+qOO+sBzm11iS27EAShB1IgrADSRB2IAnCDiRB2IEkCDuQBOfZUfR/v/ORYn3Lb/9Vsf7jw4c61l76y9OL887S7mIdbw1bdiAJwg4kQdiBJAg7kARhB5Ig7EAShB1IgvPsyU1f8MvF+vWf//tifZbLf0JXPra8Y+0d/8j16v3Elh1IgrADSRB2IAnCDiRB2IEkCDuQBGEHkuA8+3HO08v/xGd/d2ex/pkTXyzW79p/WrE+7/OdtydHinOiaZNu2W0vtP1D20/Y3mb7umr6XNtrbW+v7uf0vl0A3ZrKbvxhSZ+LiMWSPiLpGtuLJd0gaV1EnCVpXfUcwICaNOwRsTsiHq0e75f0pKQFkpZKWl29bLWky3vVJID63tJndtuLJH1I0gZJ8yLi6I+EPSdpXod5RiSNSNIJmt1tnwBqmvLReNsnSrpX0vURsW98LSJCUkw0X0SsjIjhiBieoVm1mgXQvSmF3fYMjQX9roi4r5q8x/b8qj5f0mhvWgTQhEl3421b0h2SnoyIL48rrZG0QtLN1f0DPekQ9Zz9vmL5z067s9bbf/WLnynWf/Gxh2u9P5ozlc/s50taLulx25uraTdqLOTftn2VpGclXdGbFgE0YdKwR8RDktyhfGGz7QDoFb4uCyRB2IEkCDuQBGEHkiDsQBJc4nocmLb4vR1rI/fU+/rD4lXXFOuL7vz3Wu+P/mHLDiRB2IEkCDuQBGEHkiDsQBKEHUiCsANJcJ79OPDUH3T+Yd/LZu/rWJuK0//lYPkFMeEPFGEAsWUHkiDsQBKEHUiCsANJEHYgCcIOJEHYgSQ4z34MePWyc4v1dZfdWqgy5BbGsGUHkiDsQBKEHU
iCsANJEHYgCcIOJEHYgSSmMj77QknflDRPUkhaGRG3275J0mclPV+99MaIeLBXjWb2P+dPK9bfOb37c+l37T+tWJ+xr3w9O1ezHzum8qWaw5I+FxGP2j5J0iO211a12yLiS71rD0BTpjI++25Ju6vH+20/KWlBrxsD0Ky39Jnd9iJJH5K0oZp0re0ttlfZnvC3kWyP2N5ke9MhHajVLIDuTTnstk+UdK+k6yNin6SvSTpT0jka2/JP+AXtiFgZEcMRMTxDsxpoGUA3phR22zM0FvS7IuI+SYqIPRHxWkQckfR1SeWrNQC0atKw27akOyQ9GRFfHjd9/riXfVLS1ubbA9CUqRyNP1/SckmP295cTbtR0jLb52js7MsOSVf3pEPU8hcvLi7WH/6tRcV67H68wW7QpqkcjX9IkicocU4dOIbwDTogCcIOJEHYgSQIO5AEYQeSIOxAEo4+Drl7sufGeb6wb8sDstkQ67Qv9k50qpwtO5AFYQeSIOxAEoQdSIKwA0kQdiAJwg4k0dfz7Lafl/TsuEmnSnqhbw28NYPa26D2JdFbt5rs7YyIeMdEhb6G/U0LtzdFxHBrDRQMam+D2pdEb93qV2/sxgNJEHYgibbDvrLl5ZcMam+D2pdEb93qS2+tfmYH0D9tb9kB9AlhB5JoJey2L7b9H7aftn1DGz10YnuH7cdtb7a9qeVeVtketb113LS5ttfa3l7dTzjGXku93WR7V7XuNtu+tKXeFtr+oe0nbG+zfV01vdV1V+irL+ut75/ZbU+T9J+SLpK0U9JGScsi4om+NtKB7R2ShiOi9S9g2P4NSS9J+mZEfKCadoukvRFxc/Uf5ZyI+JMB6e0mSS+1PYx3NVrR/PHDjEu6XNLvqsV1V+jrCvVhvbWxZT9X0tMR8UxEHJR0j6SlLfQx8CJivaS9b5i8VNLq6vFqjf2x9F2H3gZCROyOiEerx/slHR1mvNV1V+irL9oI+wJJPxn3fKcGa7z3kPQD24/YHmm7mQnMi4jd1ePnJM1rs5kJTDqMdz+9YZjxgVl33Qx/XhcH6N7sgoj4NUmXSLqm2l0dSDH2GWyQzp1OaRjvfplgmPGfa3PddTv8eV1thH2XpIXjnp9eTRsIEbGruh+VdL8GbyjqPUdH0K3uR1vu5+cGaRjviYYZ1wCsuzaHP28j7BslnWX7XbZnSrpS0poW+ngT20PVgRPZHpL0cQ3eUNRrJK2oHq+Q9ECLvbzOoAzj3WmYcbW87lof/jwi+n6TdKnGjsj/WNKfttFDh77eLemx6rat7d4k3a2x3bpDGju2cZWkUyStk7Rd0j9LmjtAvd0p6XFJWzQWrPkt9XaBxnbRt0jaXN0ubXvdFfrqy3rj67JAEhygA5Ig7EAShB1IgrADSRB2IAnCDiRB2IEk/h9BCfQTVPflJQAAAABJRU5ErkJggg==\n", 146 | "text/plain": [ 147 | "
" 148 | ] 149 | }, 150 | "metadata": { 151 | "tags": [], 152 | "needs_background": "light" 153 | } 154 | } 155 | ] 156 | }, 157 | { 158 | "cell_type": "markdown", 159 | "metadata": { 160 | "id": "9J2c19WYtYAl", 161 | "colab_type": "text" 162 | }, 163 | "source": [ 164 | "Build the Model" 165 | ] 166 | }, 167 | { 168 | "cell_type": "code", 169 | "metadata": { 170 | "id": "kg2V13vzte4Z", 171 | "colab_type": "code", 172 | "colab": { 173 | "base_uri": "https://localhost:8080/", 174 | "height": 255 175 | }, 176 | "outputId": "47fda1f4-a479-4ba6-ae1d-34ad1b35b4e4" 177 | }, 178 | "source": [ 179 | "model = tf.keras.models.Sequential([\n", 180 | " tf.keras.layers.Flatten(input_shape=(28, 28)),\n", 181 | " tf.keras.layers.Dense(128, activation='relu'), \n", 182 | " tf.keras.layers.Dense(10, activation='softmax')\n", 183 | "])\n", 184 | "\n", 185 | "optimizer = tf.keras.optimizers.Adam(learning_rate=.001)\n", 186 | "\n", 187 | "model.compile(optimizer=optimizer,\n", 188 | " loss='sparse_categorical_crossentropy', \n", 189 | " metrics=['accuracy'])\n", 190 | "\n", 191 | "model.summary()\n" 192 | ], 193 | "execution_count": 10, 194 | "outputs": [ 195 | { 196 | "output_type": "stream", 197 | "text": [ 198 | "Model: \"sequential_1\"\n", 199 | "_________________________________________________________________\n", 200 | "Layer (type) Output Shape Param # \n", 201 | "=================================================================\n", 202 | "flatten_1 (Flatten) (None, 784) 0 \n", 203 | "_________________________________________________________________\n", 204 | "dense_2 (Dense) (None, 128) 100480 \n", 205 | "_________________________________________________________________\n", 206 | "dense_3 (Dense) (None, 10) 1290 \n", 207 | "=================================================================\n", 208 | "Total params: 101,770\n", 209 | "Trainable params: 101,770\n", 210 | "Non-trainable params: 0\n", 211 | "_________________________________________________________________\n" 212 | 
], 213 | "name": "stdout" 214 | } 215 | ] 216 | }, 217 | { 218 | "cell_type": "markdown", 219 | "metadata": { 220 | "id": "e7Q_wHCs-t1B", 221 | "colab_type": "text" 222 | }, 223 | "source": [ 224 | "Train and evaluate" 225 | ] 226 | }, 227 | { 228 | "cell_type": "code", 229 | "metadata": { 230 | "id": "jctrnd5_-tBf", 231 | "colab_type": "code", 232 | "colab": { 233 | "base_uri": "https://localhost:8080/", 234 | "height": 187 235 | }, 236 | "outputId": "9e347c7d-f6d7-4ce8-e900-e8c1a28e3f21" 237 | }, 238 | "source": [ 239 | "history = model.fit(x_train, y_train, epochs=5, validation_data=(x_test, y_test))" 240 | ], 241 | "execution_count": 11, 242 | "outputs": [ 243 | { 244 | "output_type": "stream", 245 | "text": [ 246 | "Epoch 1/5\n", 247 | "1875/1875 [==============================] - 4s 2ms/step - loss: 0.2634 - accuracy: 0.9246 - val_loss: 0.1386 - val_accuracy: 0.9576\n", 248 | "Epoch 2/5\n", 249 | "1875/1875 [==============================] - 4s 2ms/step - loss: 0.1151 - accuracy: 0.9660 - val_loss: 0.0975 - val_accuracy: 0.9714\n", 250 | "Epoch 3/5\n", 251 | "1875/1875 [==============================] - 4s 2ms/step - loss: 0.0785 - accuracy: 0.9760 - val_loss: 0.0835 - val_accuracy: 0.9741\n", 252 | "Epoch 4/5\n", 253 | "1875/1875 [==============================] - 4s 2ms/step - loss: 0.0583 - accuracy: 0.9822 - val_loss: 0.0793 - val_accuracy: 0.9761\n", 254 | "Epoch 5/5\n", 255 | "1875/1875 [==============================] - 4s 2ms/step - loss: 0.0444 - accuracy: 0.9867 - val_loss: 0.0751 - val_accuracy: 0.9765\n" 256 | ], 257 | "name": "stdout" 258 | } 259 | ] 260 | }, 261 | { 262 | "cell_type": "code", 263 | "metadata": { 264 | "id": "bMGawmzGeUi6", 265 | "colab_type": "code", 266 | "colab": { 267 | "base_uri": "https://localhost:8080/", 268 | "height": 68 269 | }, 270 | "outputId": "e245f553-2b08-4f44-9c82-3357cc8b13a8" 271 | }, 272 | "source": [ 273 | "model.predict(x_test[:1])" 274 | ], 275 | "execution_count": 7, 276 | "outputs": [ 277 | { 278 | 
"output_type": "execute_result", 279 | "data": { 280 | "text/plain": [ 281 | "array([[2.2199970e-06, 1.8599103e-07, 8.6192616e-05, 2.5271380e-04,\n", 282 | " 6.8717615e-11, 3.2470743e-07, 1.5319234e-11, 9.9964976e-01,\n", 283 | " 1.2159430e-06, 7.4156669e-06]], dtype=float32)" 284 | ] 285 | }, 286 | "metadata": { 287 | "tags": [] 288 | }, 289 | "execution_count": 7 290 | } 291 | ] 292 | }, 293 | { 294 | "cell_type": "code", 295 | "metadata": { 296 | "id": "pnJldOQqdIqF", 297 | "colab_type": "code", 298 | "colab": { 299 | "base_uri": "https://localhost:8080/", 300 | "height": 34 301 | }, 302 | "outputId": "844e12d0-fe36-41ff-9a70-a1783045fdea" 303 | }, 304 | "source": [ 305 | "np.argmax(model.predict(x_test[:1]))" 306 | ], 307 | "execution_count": 8, 308 | "outputs": [ 309 | { 310 | "output_type": "execute_result", 311 | "data": { 312 | "text/plain": [ 313 | "7" 314 | ] 315 | }, 316 | "metadata": { 317 | "tags": [] 318 | }, 319 | "execution_count": 8 320 | } 321 | ] 322 | }, 323 | { 324 | "cell_type": "code", 325 | "metadata": { 326 | "id": "ht_fAclqEv_o", 327 | "colab_type": "code", 328 | "colab": { 329 | "base_uri": "https://localhost:8080/", 330 | "height": 296 331 | }, 332 | "outputId": "0c263a3d-91e5-4a2f-f23a-3eb6ea03fabf" 333 | }, 334 | "source": [ 335 | "plt.plot(history.history['loss'], label='train loss')\n", 336 | "plt.plot(history.history['val_loss'], label = 'test loss') \n", 337 | "plt.xlabel('Epoch')\n", 338 | "plt.ylabel('Loss')\n", 339 | "plt.legend(loc='lower right')" 340 | ], 341 | "execution_count": 9, 342 | "outputs": [ 343 | { 344 | "output_type": "execute_result", 345 | "data": { 346 | "text/plain": [ 347 | "" 348 | ] 349 | }, 350 | "metadata": { 351 | "tags": [] 352 | }, 353 | "execution_count": 9 354 | }, 355 | { 356 | "output_type": "display_data", 357 | "data": { 358 | "image/png": "\n", 359 | "text/plain": [ 360 | "
" 361 | ] 362 | }, 363 | "metadata": { 364 | "tags": [], 365 | "needs_background": "light" 366 | } 367 | } 368 | ] 369 | } 370 | ] 371 | } -------------------------------------------------------------------------------- /PAIGCP_Translation.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "colab": { 6 | "name": "PAIGCP_Translation.ipynb", 7 | "provenance": [], 8 | "collapsed_sections": [], 9 | "authorship_tag": "ABX9TyOon4bJkMZBRt5LS/ijtFKA", 10 | "include_colab_link": true 11 | }, 12 | "kernelspec": { 13 | "name": "python3", 14 | "display_name": "Python 3" 15 | }, 16 | "accelerator": "GPU" 17 | }, 18 | "cells": [ 19 | { 20 | "cell_type": "markdown", 21 | "metadata": { 22 | "id": "view-in-github", 23 | "colab_type": "text" 24 | }, 25 | "source": [ 26 | "\"Open" 27 | ] 28 | }, 29 | { 30 | "cell_type": "code", 31 | "metadata": { 32 | "id": "rhUrDc_zOA-H", 33 | "colab_type": "code", 34 | "colab": {} 35 | }, 36 | "source": [ 37 | "import tensorflow as tf\n", 38 | "\n", 39 | "import numpy as np\n", 40 | "import os\n", 41 | "import time\n" 42 | ], 43 | "execution_count": null, 44 | "outputs": [] 45 | }, 46 | { 47 | "cell_type": "code", 48 | "metadata": { 49 | "id": "oG0JNbJsRfwT", 50 | "colab_type": "code", 51 | "colab": { 52 | "base_uri": "https://localhost:8080/", 53 | "height": 67 54 | }, 55 | "outputId": "090686f1-6f70-4c1f-d1df-c7848cdc2557" 56 | }, 57 | "source": [ 58 | "path_to_zip = tf.keras.utils.get_file(\n", 59 | " 'spa-eng.zip', origin='http://storage.googleapis.com/download.tensorflow.org/data/spa-eng.zip',\n", 60 | " extract=True)\n", 61 | "\n", 62 | "path_to_file = os.path.dirname(path_to_zip)+\"/spa-eng/spa.txt\"\n", 63 | "path_to_file" 64 | ], 65 | "execution_count": null, 66 | "outputs": [ 67 | { 68 | "output_type": "stream", 69 | "text": [ 70 | "Downloading data from http://storage.googleapis.com/download.tensorflow.org/data/spa-eng.zip\n", 71 
| "2646016/2638744 [==============================] - 0s 0us/step\n" 72 | ], 73 | "name": "stdout" 74 | }, 75 | { 76 | "output_type": "execute_result", 77 | "data": { 78 | "text/plain": [ 79 | "'/root/.keras/datasets/spa-eng/spa.txt'" 80 | ] 81 | }, 82 | "metadata": { 83 | "tags": [] 84 | }, 85 | "execution_count": 2 86 | } 87 | ] 88 | }, 89 | { 90 | "cell_type": "code", 91 | "metadata": { 92 | "id": "EswXJREBNYY-", 93 | "colab_type": "code", 94 | "colab": { 95 | "base_uri": "https://localhost:8080/", 96 | "height": 185 97 | }, 98 | "outputId": "bdd86aaf-37de-4fee-a859-321609de9e7c" 99 | }, 100 | "source": [ 101 | "# Vectorize the data.\n", 102 | "input_texts = []\n", 103 | "target_texts = []\n", 104 | "input_characters = set()\n", 105 | "target_characters = set()\n", 106 | "with open(path_to_file, 'r', encoding='utf-8') as f:\n", 107 | " lines = f.read().split('\\n')\n", 108 | "\n", 109 | "lines[:10]" 110 | ], 111 | "execution_count": null, 112 | "outputs": [ 113 | { 114 | "output_type": "execute_result", 115 | "data": { 116 | "text/plain": [ 117 | "['Go.\\tVe.',\n", 118 | " 'Go.\\tVete.',\n", 119 | " 'Go.\\tVaya.',\n", 120 | " 'Go.\\tVáyase.',\n", 121 | " 'Hi.\\tHola.',\n", 122 | " 'Run!\\t¡Corre!',\n", 123 | " 'Run.\\tCorred.',\n", 124 | " 'Who?\\t¿Quién?',\n", 125 | " 'Fire!\\t¡Fuego!',\n", 126 | " 'Fire!\\t¡Incendio!']" 127 | ] 128 | }, 129 | "metadata": { 130 | "tags": [] 131 | }, 132 | "execution_count": 3 133 | } 134 | ] 135 | }, 136 | { 137 | "cell_type": "code", 138 | "metadata": { 139 | "id": "vFmWesFKR6pr", 140 | "colab_type": "code", 141 | "colab": {} 142 | }, 143 | "source": [ 144 | "num_samples = 1000\n", 145 | "for line in lines[: min(num_samples, len(lines) - 1)]:\n", 146 | " input_text, target_text = line.split('\\t')\n", 147 | " # We use \"tab\" as the \"start sequence\" character\n", 148 | " # for the targets, and \"\\n\" as \"end sequence\" character.\n", 149 | " target_text = '\\t' + target_text + '\\n'\n", 150 | " 
input_texts.append(input_text)\n", 151 | " target_texts.append(target_text)\n", 152 | " for char in input_text:\n", 153 | " if char not in input_characters:\n", 154 | " input_characters.add(char)\n", 155 | " for char in target_text:\n", 156 | " if char not in target_characters:\n", 157 | " target_characters.add(char)" 158 | ], 159 | "execution_count": null, 160 | "outputs": [] 161 | }, 162 | { 163 | "cell_type": "code", 164 | "metadata": { 165 | "id": "FeeqI8GCSO8e", 166 | "colab_type": "code", 167 | "colab": { 168 | "base_uri": "https://localhost:8080/", 169 | "height": 50 170 | }, 171 | "outputId": "51f9dcb8-8f6d-41e7-81cc-0d5feca621fe" 172 | }, 173 | "source": [ 174 | "input_characters = sorted(list(input_characters))\n", 175 | "num_encoder_tokens = len(input_characters)\n", 176 | "num_encoder_tokens, \",\".join(input_characters)" 177 | ], 178 | "execution_count": null, 179 | "outputs": [ 180 | { 181 | "output_type": "execute_result", 182 | "data": { 183 | "text/plain": [ 184 | "(61,\n", 185 | " \" ,!,',,,.,0,1,3,8,9,:,?,A,B,C,D,E,F,G,H,I,J,K,L,M,N,O,P,R,S,T,U,V,W,Y,a,b,c,d,e,f,g,h,i,j,k,l,m,n,o,p,q,r,s,t,u,v,w,x,y,z\")" 186 | ] 187 | }, 188 | "metadata": { 189 | "tags": [] 190 | }, 191 | "execution_count": 5 192 | } 193 | ] 194 | }, 195 | { 196 | "cell_type": "code", 197 | "metadata": { 198 | "id": "JKsMQLeaSajS", 199 | "colab_type": "code", 200 | "colab": { 201 | "base_uri": "https://localhost:8080/", 202 | "height": 50 203 | }, 204 | "outputId": "11922cd1-a0ca-452c-9ac7-7e40ed604d42" 205 | }, 206 | "source": [ 207 | "target_characters = sorted(list(target_characters))\n", 208 | "num_decoder_tokens = len(target_characters)\n", 209 | "num_decoder_tokens, \",\".join(target_characters)" 210 | ], 211 | "execution_count": null, 212 | "outputs": [ 213 | { 214 | "output_type": "execute_result", 215 | "data": { 216 | "text/plain": [ 217 | "(70,\n", 218 | " '\\t,\\n, 
,!,,,.,0,3,8,:,?,A,B,C,D,E,F,G,H,I,J,L,M,N,O,P,Q,R,S,T,U,V,Y,a,b,c,d,e,f,g,h,i,j,l,m,n,o,p,q,r,s,t,u,v,x,y,z,¡,¿,Á,É,Ó,Ú,á,é,í,ñ,ó,ú,ü')" 219 | ] 220 | }, 221 | "metadata": { 222 | "tags": [] 223 | }, 224 | "execution_count": 6 225 | } 226 | ] 227 | }, 228 | { 229 | "cell_type": "code", 230 | "metadata": { 231 | "id": "zXfuLaoVUD2b", 232 | "colab_type": "code", 233 | "colab": { 234 | "base_uri": "https://localhost:8080/", 235 | "height": 34 236 | }, 237 | "outputId": "746f953a-c606-4ff2-82a0-e6369fe23a40" 238 | }, 239 | "source": [ 240 | "max_encoder_seq_length = max([len(txt) for txt in input_texts])\n", 241 | "print('Max sequence length for inputs:', max_encoder_seq_length)" 242 | ], 243 | "execution_count": null, 244 | "outputs": [ 245 | { 246 | "output_type": "stream", 247 | "text": [ 248 | "Max sequence length for inputs: 11\n" 249 | ], 250 | "name": "stdout" 251 | } 252 | ] 253 | }, 254 | { 255 | "cell_type": "code", 256 | "metadata": { 257 | "id": "GwdHoAo41sx1", 258 | "colab_type": "code", 259 | "colab": { 260 | "base_uri": "https://localhost:8080/", 261 | "height": 34 262 | }, 263 | "outputId": "49084ed2-5d75-456f-c39b-1085398217a2" 264 | }, 265 | "source": [ 266 | "max_decoder_seq_length = max([len(txt) for txt in target_texts])\n", 267 | "print('Max sequence length for outputs:', max_decoder_seq_length)" 268 | ], 269 | "execution_count": null, 270 | "outputs": [ 271 | { 272 | "output_type": "stream", 273 | "text": [ 274 | "Max sequence length for outputs: 31\n" 275 | ], 276 | "name": "stdout" 277 | } 278 | ] 279 | }, 280 | { 281 | "cell_type": "code", 282 | "metadata": { 283 | "id": "AHc1qMwjVCQR", 284 | "colab_type": "code", 285 | "colab": {} 286 | }, 287 | "source": [ 288 | "input_token_index = dict(\n", 289 | " [(char, i) for i, char in enumerate(input_characters)])\n", 290 | "target_token_index = dict(\n", 291 | " [(char, i) for i, char in enumerate(target_characters)])\n", 292 | "\n", 293 | "encoder_input_data = np.zeros(\n", 294 | " 
(len(input_texts), max_encoder_seq_length, num_encoder_tokens),\n", 295 | " dtype='float32')\n", 296 | "decoder_input_data = np.zeros(\n", 297 | " (len(input_texts), max_decoder_seq_length, num_decoder_tokens),\n", 298 | " dtype='float32')\n", 299 | "decoder_target_data = np.zeros(\n", 300 | " (len(input_texts), max_decoder_seq_length, num_decoder_tokens),\n", 301 | " dtype='float32')" 302 | ], 303 | "execution_count": null, 304 | "outputs": [] 305 | }, 306 | { 307 | "cell_type": "code", 308 | "metadata": { 309 | "id": "4UgjBMmhVZU6", 310 | "colab_type": "code", 311 | "colab": {} 312 | }, 313 | "source": [ 314 | "for i, (input_text, target_text) in enumerate(zip(input_texts, target_texts)):\n", 315 | " for t, char in enumerate(input_text):\n", 316 | " encoder_input_data[i, t, input_token_index[char]] = 1.\n", 317 | " encoder_input_data[i, t + 1:, input_token_index[' ']] = 1.\n", 318 | " for t, char in enumerate(target_text):\n", 319 | " # decoder_target_data is ahead of decoder_input_data by one timestep\n", 320 | " decoder_input_data[i, t, target_token_index[char]] = 1.\n", 321 | " if t > 0:\n", 322 | " # decoder_target_data will be ahead by one timestep\n", 323 | " # and will not include the start character.\n", 324 | " decoder_target_data[i, t - 1, target_token_index[char]] = 1.\n", 325 | " decoder_input_data[i, t + 1:, target_token_index[' ']] = 1.\n", 326 | " decoder_target_data[i, t:, target_token_index[' ']] = 1." 
327 | ], 328 | "execution_count": null, 329 | "outputs": [] 330 | }, 331 | { 332 | "cell_type": "code", 333 | "metadata": { 334 | "id": "vNJTVSHO3NQF", 335 | "colab_type": "code", 336 | "colab": {} 337 | }, 338 | "source": [ 339 | "from tensorflow.keras.models import Model\n", 340 | "from tensorflow.keras.layers import Input, LSTM, Dense, Embedding, Reshape\n", 341 | "\n", 342 | "batch_size = 64 # Batch size for training.\n", 343 | "epochs = 100 # Number of epochs to train for.\n", 344 | "latent_dim = 64 # Latent dimensionality of the encoding space.\n", 345 | "\n", 346 | "# Define an input sequence and process it.\n", 347 | "encoder_inputs = Input(shape=(None, num_encoder_tokens))\n", 348 | "#encoder_embedded = Embedding(input_dim=num_encoder_tokens, output_dim=64)(encoder_inputs)\n", 349 | "#encoder_reshape = Reshape((-1, 64))(encoder_embedded)\n", 350 | "encoder = LSTM(latent_dim, return_state=True)\n", 351 | "encoder_outputs, state_h, state_c = encoder(encoder_inputs)\n", 352 | "# We discard `encoder_outputs` and only keep the states.\n", 353 | "encoder_states = [state_h, state_c]" 354 | ], 355 | "execution_count": null, 356 | "outputs": [] 357 | }, 358 | { 359 | "cell_type": "code", 360 | "metadata": { 361 | "id": "Za7_kDtc4Gax", 362 | "colab_type": "code", 363 | "colab": {} 364 | }, 365 | "source": [ 366 | "# Set up the decoder, using `encoder_states` as initial state.\n", 367 | "decoder_inputs = Input(shape=(None, num_decoder_tokens))\n", 368 | "# We set up our decoder to return full output sequences,\n", 369 | "# and to return internal states as well. 
We don't use the\n", 370 | "# return states in the training model, but we will use them in inference.\n", 371 | "decoder_lstm = LSTM(latent_dim, return_sequences=True, return_state=True)\n", 372 | "decoder_outputs, _, _ = decoder_lstm(decoder_inputs,\n", 373 | " initial_state=encoder_states)\n", 374 | "decoder_dense = Dense(num_decoder_tokens, activation='softmax')\n", 375 | "decoder_outputs = decoder_dense(decoder_outputs)" 376 | ], 377 | "execution_count": null, 378 | "outputs": [] 379 | }, 380 | { 381 | "cell_type": "code", 382 | "metadata": { 383 | "id": "BkuLHwekRO2h", 384 | "colab_type": "code", 385 | "colab": { 386 | "base_uri": "https://localhost:8080/", 387 | "height": 34 388 | }, 389 | "outputId": "e4ceab26-609c-4286-c015-e6d03733e8fa" 390 | }, 391 | "source": [ 392 | "batch_size = 64\n", 393 | "epochs = 100\n", 394 | "encoder_input_data.shape,decoder_input_data.shape" 395 | ], 396 | "execution_count": null, 397 | "outputs": [ 398 | { 399 | "output_type": "execute_result", 400 | "data": { 401 | "text/plain": [ 402 | "((1000, 11, 61), (1000, 31, 70))" 403 | ] 404 | }, 405 | "metadata": { 406 | "tags": [] 407 | }, 408 | "execution_count": 76 409 | } 410 | ] 411 | }, 412 | { 413 | "cell_type": "code", 414 | "metadata": { 415 | "id": "u-iU3cxY4qiQ", 416 | "colab_type": "code", 417 | "colab": { 418 | "base_uri": "https://localhost:8080/", 419 | "height": 1000 420 | }, 421 | "outputId": "2e7dabe1-b899-4951-eba5-3b402f8f54f9" 422 | }, 423 | "source": [ 424 | "# Define the model that will turn\n", 425 | "# `encoder_input_data` & `decoder_input_data` into `decoder_target_data`\n", 426 | "model = Model([encoder_inputs, decoder_inputs], decoder_outputs)\n", 427 | "\n", 428 | "# Run training\n", 429 | "model.compile(optimizer='rmsprop', loss='categorical_crossentropy',\n", 430 | " metrics=['accuracy'])\n", 431 | "model.fit([encoder_input_data, decoder_input_data], decoder_target_data,\n", 432 | " batch_size=batch_size,\n", 433 | " epochs=epochs,\n", 434 | " 
validation_split=0.2)\n", 435 | "# Save model\n", 436 | "model.save('s2s.h5')" 437 | ], 438 | "execution_count": null, 439 | "outputs": [ 440 | { 441 | "output_type": "stream", 442 | "text": [ 443 | "Epoch 1/100\n", 444 | "13/13 [==============================] - 1s 76ms/step - loss: 3.3758 - accuracy: 0.5017 - val_loss: 2.1734 - val_accuracy: 0.5777\n", 445 | "Epoch 2/100\n", 446 | "13/13 [==============================] - 0s 19ms/step - loss: 1.7953 - accuracy: 0.6275 - val_loss: 1.9087 - val_accuracy: 0.5744\n", 447 | "Epoch 3/100\n", 448 | "13/13 [==============================] - 0s 19ms/step - loss: 1.6027 - accuracy: 0.6265 - val_loss: 1.7623 - val_accuracy: 0.5731\n", 449 | "Epoch 4/100\n", 450 | "13/13 [==============================] - 0s 20ms/step - loss: 1.5276 - accuracy: 0.6279 - val_loss: 1.7184 - val_accuracy: 0.5808\n", 451 | "Epoch 5/100\n", 452 | "13/13 [==============================] - 0s 19ms/step - loss: 1.4907 - accuracy: 0.6279 - val_loss: 1.6831 - val_accuracy: 0.5787\n", 453 | "Epoch 6/100\n", 454 | "13/13 [==============================] - 0s 19ms/step - loss: 1.4550 - accuracy: 0.6304 - val_loss: 1.7033 - val_accuracy: 0.5885\n", 455 | "Epoch 7/100\n", 456 | "13/13 [==============================] - 0s 19ms/step - loss: 1.4257 - accuracy: 0.6348 - val_loss: 1.6383 - val_accuracy: 0.5852\n", 457 | "Epoch 8/100\n", 458 | "13/13 [==============================] - 0s 20ms/step - loss: 1.4009 - accuracy: 0.6354 - val_loss: 1.6056 - val_accuracy: 0.5852\n", 459 | "Epoch 9/100\n", 460 | "13/13 [==============================] - 0s 19ms/step - loss: 1.3657 - accuracy: 0.6364 - val_loss: 1.5636 - val_accuracy: 0.5818\n", 461 | "Epoch 10/100\n", 462 | "13/13 [==============================] - 0s 19ms/step - loss: 1.3541 - accuracy: 0.6381 - val_loss: 1.5429 - val_accuracy: 0.5874\n", 463 | "Epoch 11/100\n", 464 | "13/13 [==============================] - 0s 20ms/step - loss: 1.3300 - accuracy: 0.6407 - val_loss: 1.5173 - val_accuracy: 0.5894\n", 
465 | "Epoch 12/100\n", 466 | "13/13 [==============================] - 0s 20ms/step - loss: 1.3094 - accuracy: 0.6447 - val_loss: 1.5152 - val_accuracy: 0.5902\n", 467 | "Epoch 13/100\n", 468 | "13/13 [==============================] - 0s 22ms/step - loss: 1.2850 - accuracy: 0.6465 - val_loss: 1.5090 - val_accuracy: 0.5900\n", 469 | "Epoch 14/100\n", 470 | "13/13 [==============================] - 0s 19ms/step - loss: 1.2668 - accuracy: 0.6517 - val_loss: 1.4856 - val_accuracy: 0.5927\n", 471 | "Epoch 15/100\n", 472 | "13/13 [==============================] - 0s 19ms/step - loss: 1.2545 - accuracy: 0.6569 - val_loss: 1.4439 - val_accuracy: 0.5956\n", 473 | "Epoch 16/100\n", 474 | "13/13 [==============================] - 0s 19ms/step - loss: 1.2277 - accuracy: 0.6612 - val_loss: 1.4186 - val_accuracy: 0.6074\n", 475 | "Epoch 17/100\n", 476 | "13/13 [==============================] - 0s 20ms/step - loss: 1.2127 - accuracy: 0.6692 - val_loss: 1.4098 - val_accuracy: 0.6161\n", 477 | "Epoch 18/100\n", 478 | "13/13 [==============================] - 0s 19ms/step - loss: 1.2023 - accuracy: 0.6785 - val_loss: 1.3878 - val_accuracy: 0.6355\n", 479 | "Epoch 19/100\n", 480 | "13/13 [==============================] - 0s 20ms/step - loss: 1.1801 - accuracy: 0.6871 - val_loss: 1.3800 - val_accuracy: 0.6265\n", 481 | "Epoch 20/100\n", 482 | "13/13 [==============================] - 0s 19ms/step - loss: 1.1642 - accuracy: 0.6931 - val_loss: 1.4433 - val_accuracy: 0.6026\n", 483 | "Epoch 21/100\n", 484 | "13/13 [==============================] - 0s 19ms/step - loss: 1.1519 - accuracy: 0.6966 - val_loss: 1.3476 - val_accuracy: 0.6387\n", 485 | "Epoch 22/100\n", 486 | "13/13 [==============================] - 0s 20ms/step - loss: 1.1384 - accuracy: 0.7000 - val_loss: 1.3278 - val_accuracy: 0.6463\n", 487 | "Epoch 23/100\n", 488 | "13/13 [==============================] - 0s 19ms/step - loss: 1.1191 - accuracy: 0.7078 - val_loss: 1.3109 - val_accuracy: 0.6473\n", 489 | "Epoch 
24/100\n", 490 | "13/13 [==============================] - 0s 20ms/step - loss: 1.1111 - accuracy: 0.7101 - val_loss: 1.3061 - val_accuracy: 0.6489\n", 491 | "Epoch 25/100\n", 492 | "13/13 [==============================] - 0s 19ms/step - loss: 1.0920 - accuracy: 0.7144 - val_loss: 1.3037 - val_accuracy: 0.6453\n", 493 | "Epoch 26/100\n", 494 | "13/13 [==============================] - 0s 19ms/step - loss: 1.0803 - accuracy: 0.7178 - val_loss: 1.2765 - val_accuracy: 0.6565\n", 495 | "Epoch 27/100\n", 496 | "13/13 [==============================] - 0s 20ms/step - loss: 1.0676 - accuracy: 0.7210 - val_loss: 1.2781 - val_accuracy: 0.6492\n", 497 | "Epoch 28/100\n", 498 | "13/13 [==============================] - 0s 20ms/step - loss: 1.0515 - accuracy: 0.7256 - val_loss: 1.2460 - val_accuracy: 0.6752\n", 499 | "Epoch 29/100\n", 500 | "13/13 [==============================] - 0s 19ms/step - loss: 1.0453 - accuracy: 0.7257 - val_loss: 1.2347 - val_accuracy: 0.6669\n", 501 | "Epoch 30/100\n", 502 | "13/13 [==============================] - 0s 20ms/step - loss: 1.0265 - accuracy: 0.7293 - val_loss: 1.2283 - val_accuracy: 0.6694\n", 503 | "Epoch 31/100\n", 504 | "13/13 [==============================] - 0s 19ms/step - loss: 1.0154 - accuracy: 0.7310 - val_loss: 1.2207 - val_accuracy: 0.6703\n", 505 | "Epoch 32/100\n", 506 | "13/13 [==============================] - 0s 19ms/step - loss: 1.0018 - accuracy: 0.7335 - val_loss: 1.2072 - val_accuracy: 0.6668\n", 507 | "Epoch 33/100\n", 508 | "13/13 [==============================] - 0s 19ms/step - loss: 0.9917 - accuracy: 0.7342 - val_loss: 1.1838 - val_accuracy: 0.6798\n", 509 | "Epoch 34/100\n", 510 | "13/13 [==============================] - 0s 19ms/step - loss: 0.9774 - accuracy: 0.7361 - val_loss: 1.1854 - val_accuracy: 0.6868\n", 511 | "Epoch 35/100\n", 512 | "13/13 [==============================] - 0s 19ms/step - loss: 0.9643 - accuracy: 0.7400 - val_loss: 1.1625 - val_accuracy: 0.6831\n", 513 | "Epoch 36/100\n", 514 | 
"13/13 [==============================] - 0s 19ms/step - loss: 0.9541 - accuracy: 0.7401 - val_loss: 1.1549 - val_accuracy: 0.6840\n", 515 | "Epoch 37/100\n", 516 | "13/13 [==============================] - 0s 19ms/step - loss: 0.9413 - accuracy: 0.7429 - val_loss: 1.1386 - val_accuracy: 0.6834\n", 517 | "Epoch 38/100\n", 518 | "13/13 [==============================] - 0s 19ms/step - loss: 0.9296 - accuracy: 0.7438 - val_loss: 1.1314 - val_accuracy: 0.6853\n", 519 | "Epoch 39/100\n", 520 | "13/13 [==============================] - 0s 20ms/step - loss: 0.9170 - accuracy: 0.7471 - val_loss: 1.1208 - val_accuracy: 0.6916\n", 521 | "Epoch 40/100\n", 522 | "13/13 [==============================] - 0s 20ms/step - loss: 0.9081 - accuracy: 0.7498 - val_loss: 1.1234 - val_accuracy: 0.6781\n", 523 | "Epoch 41/100\n", 524 | "13/13 [==============================] - 0s 20ms/step - loss: 0.8963 - accuracy: 0.7528 - val_loss: 1.1042 - val_accuracy: 0.6890\n", 525 | "Epoch 42/100\n", 526 | "13/13 [==============================] - 0s 19ms/step - loss: 0.8853 - accuracy: 0.7545 - val_loss: 1.0868 - val_accuracy: 0.6902\n", 527 | "Epoch 43/100\n", 528 | "13/13 [==============================] - 0s 19ms/step - loss: 0.8759 - accuracy: 0.7552 - val_loss: 1.0849 - val_accuracy: 0.6877\n", 529 | "Epoch 44/100\n", 530 | "13/13 [==============================] - 0s 19ms/step - loss: 0.8669 - accuracy: 0.7563 - val_loss: 1.0749 - val_accuracy: 0.6848\n", 531 | "Epoch 45/100\n", 532 | "13/13 [==============================] - 0s 20ms/step - loss: 0.8584 - accuracy: 0.7600 - val_loss: 1.0624 - val_accuracy: 0.6918\n", 533 | "Epoch 46/100\n", 534 | "13/13 [==============================] - 0s 20ms/step - loss: 0.8450 - accuracy: 0.7616 - val_loss: 1.0523 - val_accuracy: 0.6911\n", 535 | "Epoch 47/100\n", 536 | "13/13 [==============================] - 0s 19ms/step - loss: 0.8404 - accuracy: 0.7620 - val_loss: 1.0539 - val_accuracy: 0.6927\n", 537 | "Epoch 48/100\n", 538 | "13/13 
[==============================] - 0s 20ms/step - loss: 0.8294 - accuracy: 0.7647 - val_loss: 1.0419 - val_accuracy: 0.6927\n", 539 | "Epoch 49/100\n", 540 | "13/13 [==============================] - 0s 20ms/step - loss: 0.8254 - accuracy: 0.7623 - val_loss: 1.0392 - val_accuracy: 0.6990\n", 541 | "Epoch 50/100\n", 542 | "13/13 [==============================] - 0s 20ms/step - loss: 0.8145 - accuracy: 0.7666 - val_loss: 1.0302 - val_accuracy: 0.6989\n", 543 | "Epoch 51/100\n", 544 | "13/13 [==============================] - 0s 20ms/step - loss: 0.8102 - accuracy: 0.7665 - val_loss: 1.0235 - val_accuracy: 0.6977\n", 545 | "Epoch 52/100\n", 546 | "13/13 [==============================] - 0s 19ms/step - loss: 0.8019 - accuracy: 0.7703 - val_loss: 1.0170 - val_accuracy: 0.7010\n", 547 | "Epoch 53/100\n", 548 | "13/13 [==============================] - 0s 20ms/step - loss: 0.7937 - accuracy: 0.7730 - val_loss: 1.0248 - val_accuracy: 0.6944\n", 549 | "Epoch 54/100\n", 550 | "13/13 [==============================] - 0s 20ms/step - loss: 0.7881 - accuracy: 0.7750 - val_loss: 1.0192 - val_accuracy: 0.6953\n", 551 | "Epoch 55/100\n", 552 | "13/13 [==============================] - 0s 19ms/step - loss: 0.7839 - accuracy: 0.7730 - val_loss: 0.9980 - val_accuracy: 0.7066\n", 553 | "Epoch 56/100\n", 554 | "13/13 [==============================] - 0s 20ms/step - loss: 0.7752 - accuracy: 0.7784 - val_loss: 1.0022 - val_accuracy: 0.7092\n", 555 | "Epoch 57/100\n", 556 | "13/13 [==============================] - 0s 21ms/step - loss: 0.7712 - accuracy: 0.7788 - val_loss: 0.9954 - val_accuracy: 0.7037\n", 557 | "Epoch 58/100\n", 558 | "13/13 [==============================] - 0s 19ms/step - loss: 0.7624 - accuracy: 0.7803 - val_loss: 0.9922 - val_accuracy: 0.7098\n", 559 | "Epoch 59/100\n", 560 | "13/13 [==============================] - 0s 20ms/step - loss: 0.7579 - accuracy: 0.7817 - val_loss: 1.0125 - val_accuracy: 0.6976\n", 561 | "Epoch 60/100\n", 562 | "13/13 
[==============================] - 0s 20ms/step - loss: 0.7543 - accuracy: 0.7820 - val_loss: 0.9926 - val_accuracy: 0.7069\n", 563 | "Epoch 61/100\n", 564 | "13/13 [==============================] - 0s 19ms/step - loss: 0.7481 - accuracy: 0.7839 - val_loss: 0.9898 - val_accuracy: 0.7121\n", 565 | "Epoch 62/100\n", 566 | "13/13 [==============================] - 0s 19ms/step - loss: 0.7456 - accuracy: 0.7844 - val_loss: 0.9801 - val_accuracy: 0.7089\n", 567 | "Epoch 63/100\n", 568 | "13/13 [==============================] - 0s 19ms/step - loss: 0.7366 - accuracy: 0.7869 - val_loss: 0.9874 - val_accuracy: 0.7119\n", 569 | "Epoch 64/100\n", 570 | "13/13 [==============================] - 0s 19ms/step - loss: 0.7323 - accuracy: 0.7882 - val_loss: 0.9717 - val_accuracy: 0.7110\n", 571 | "Epoch 65/100\n", 572 | "13/13 [==============================] - 0s 20ms/step - loss: 0.7276 - accuracy: 0.7895 - val_loss: 0.9655 - val_accuracy: 0.7148\n", 573 | "Epoch 66/100\n", 574 | "13/13 [==============================] - 0s 19ms/step - loss: 0.7220 - accuracy: 0.7906 - val_loss: 0.9847 - val_accuracy: 0.7079\n", 575 | "Epoch 67/100\n", 576 | "13/13 [==============================] - 0s 20ms/step - loss: 0.7175 - accuracy: 0.7924 - val_loss: 0.9762 - val_accuracy: 0.7134\n", 577 | "Epoch 68/100\n", 578 | "13/13 [==============================] - 0s 19ms/step - loss: 0.7146 - accuracy: 0.7911 - val_loss: 0.9653 - val_accuracy: 0.7134\n", 579 | "Epoch 69/100\n", 580 | "13/13 [==============================] - 0s 20ms/step - loss: 0.7086 - accuracy: 0.7933 - val_loss: 0.9729 - val_accuracy: 0.7118\n", 581 | "Epoch 70/100\n", 582 | "13/13 [==============================] - 0s 19ms/step - loss: 0.7077 - accuracy: 0.7929 - val_loss: 0.9548 - val_accuracy: 0.7165\n", 583 | "Epoch 71/100\n", 584 | "13/13 [==============================] - 0s 20ms/step - loss: 0.7013 - accuracy: 0.7953 - val_loss: 0.9525 - val_accuracy: 0.7192\n", 585 | "Epoch 72/100\n", 586 | "13/13 
[==============================] - 0s 20ms/step - loss: 0.6964 - accuracy: 0.7962 - val_loss: 0.9481 - val_accuracy: 0.7206\n", 587 | "Epoch 73/100\n", 588 | "13/13 [==============================] - 0s 20ms/step - loss: 0.6940 - accuracy: 0.7962 - val_loss: 0.9493 - val_accuracy: 0.7158\n", 589 | "Epoch 74/100\n", 590 | "13/13 [==============================] - 0s 19ms/step - loss: 0.6893 - accuracy: 0.7981 - val_loss: 0.9526 - val_accuracy: 0.7131\n", 591 | "Epoch 75/100\n", 592 | "13/13 [==============================] - 0s 20ms/step - loss: 0.6846 - accuracy: 0.7977 - val_loss: 0.9522 - val_accuracy: 0.7169\n", 593 | "Epoch 76/100\n", 594 | "13/13 [==============================] - 0s 20ms/step - loss: 0.6810 - accuracy: 0.7996 - val_loss: 0.9414 - val_accuracy: 0.7261\n", 595 | "Epoch 77/100\n", 596 | "13/13 [==============================] - 0s 19ms/step - loss: 0.6782 - accuracy: 0.8002 - val_loss: 0.9425 - val_accuracy: 0.7198\n", 597 | "Epoch 78/100\n", 598 | "13/13 [==============================] - 0s 20ms/step - loss: 0.6724 - accuracy: 0.8013 - val_loss: 0.9387 - val_accuracy: 0.7218\n", 599 | "Epoch 79/100\n", 600 | "13/13 [==============================] - 0s 20ms/step - loss: 0.6714 - accuracy: 0.8014 - val_loss: 0.9341 - val_accuracy: 0.7219\n", 601 | "Epoch 80/100\n", 602 | "13/13 [==============================] - 0s 19ms/step - loss: 0.6648 - accuracy: 0.8038 - val_loss: 0.9447 - val_accuracy: 0.7189\n", 603 | "Epoch 81/100\n", 604 | "13/13 [==============================] - 0s 19ms/step - loss: 0.6637 - accuracy: 0.8056 - val_loss: 0.9519 - val_accuracy: 0.7156\n", 605 | "Epoch 82/100\n", 606 | "13/13 [==============================] - 0s 21ms/step - loss: 0.6585 - accuracy: 0.8044 - val_loss: 0.9419 - val_accuracy: 0.7194\n", 607 | "Epoch 83/100\n", 608 | "13/13 [==============================] - 0s 20ms/step - loss: 0.6546 - accuracy: 0.8057 - val_loss: 0.9433 - val_accuracy: 0.7173\n", 609 | "Epoch 84/100\n", 610 | "13/13 
[==============================] - 0s 19ms/step - loss: 0.6522 - accuracy: 0.8080 - val_loss: 0.9382 - val_accuracy: 0.7203\n", 611 | "Epoch 85/100\n", 612 | "13/13 [==============================] - 0s 19ms/step - loss: 0.6474 - accuracy: 0.8094 - val_loss: 0.9362 - val_accuracy: 0.7219\n", 613 | "Epoch 86/100\n", 614 | "13/13 [==============================] - 0s 20ms/step - loss: 0.6433 - accuracy: 0.8104 - val_loss: 0.9290 - val_accuracy: 0.7260\n", 615 | "Epoch 87/100\n", 616 | "13/13 [==============================] - 0s 20ms/step - loss: 0.6397 - accuracy: 0.8106 - val_loss: 0.9324 - val_accuracy: 0.7268\n", 617 | "Epoch 88/100\n", 618 | "13/13 [==============================] - 0s 19ms/step - loss: 0.6378 - accuracy: 0.8114 - val_loss: 0.9267 - val_accuracy: 0.7256\n", 619 | "Epoch 89/100\n", 620 | "13/13 [==============================] - 0s 19ms/step - loss: 0.6336 - accuracy: 0.8124 - val_loss: 0.9598 - val_accuracy: 0.7163\n", 621 | "Epoch 90/100\n", 622 | "13/13 [==============================] - 0s 20ms/step - loss: 0.6322 - accuracy: 0.8130 - val_loss: 0.9288 - val_accuracy: 0.7266\n", 623 | "Epoch 91/100\n", 624 | "13/13 [==============================] - 0s 20ms/step - loss: 0.6272 - accuracy: 0.8146 - val_loss: 0.9322 - val_accuracy: 0.7206\n", 625 | "Epoch 92/100\n", 626 | "13/13 [==============================] - 0s 20ms/step - loss: 0.6214 - accuracy: 0.8170 - val_loss: 0.9392 - val_accuracy: 0.7247\n", 627 | "Epoch 93/100\n", 628 | "13/13 [==============================] - 0s 20ms/step - loss: 0.6214 - accuracy: 0.8163 - val_loss: 0.9352 - val_accuracy: 0.7250\n", 629 | "Epoch 94/100\n", 630 | "13/13 [==============================] - 0s 20ms/step - loss: 0.6184 - accuracy: 0.8172 - val_loss: 0.9275 - val_accuracy: 0.7321\n", 631 | "Epoch 95/100\n", 632 | "13/13 [==============================] - 0s 20ms/step - loss: 0.6141 - accuracy: 0.8190 - val_loss: 0.9285 - val_accuracy: 0.7313\n", 633 | "Epoch 96/100\n", 634 | "13/13 
[==============================] - 0s 19ms/step - loss: 0.6100 - accuracy: 0.8195 - val_loss: 0.9233 - val_accuracy: 0.7303\n", 635 | "Epoch 97/100\n", 636 | "13/13 [==============================] - 0s 20ms/step - loss: 0.6084 - accuracy: 0.8209 - val_loss: 0.9337 - val_accuracy: 0.7289\n", 637 | "Epoch 98/100\n", 638 | "13/13 [==============================] - 0s 19ms/step - loss: 0.6041 - accuracy: 0.8223 - val_loss: 0.9185 - val_accuracy: 0.7300\n", 639 | "Epoch 99/100\n", 640 | "13/13 [==============================] - 0s 19ms/step - loss: 0.5998 - accuracy: 0.8225 - val_loss: 0.9341 - val_accuracy: 0.7229\n", 641 | "Epoch 100/100\n", 642 | "13/13 [==============================] - 0s 20ms/step - loss: 0.5978 - accuracy: 0.8235 - val_loss: 0.9374 - val_accuracy: 0.7284\n" 643 | ], 644 | "name": "stdout" 645 | } 646 | ] 647 | }, 648 | { 649 | "cell_type": "code", 650 | "metadata": { 651 | "id": "iCzu6qk70FlV", 652 | "colab_type": "code", 653 | "colab": {} 654 | }, 655 | "source": [ 656 | "encoder_model = Model(encoder_inputs, encoder_states)\n", 657 | "\n", 658 | "decoder_state_input_h = Input(shape=(latent_dim,))\n", 659 | "decoder_state_input_c = Input(shape=(latent_dim,))\n", 660 | "decoder_states_inputs = [decoder_state_input_h, decoder_state_input_c]\n", 661 | "decoder_outputs, state_h, state_c = decoder_lstm(\n", 662 | " decoder_inputs, initial_state=decoder_states_inputs)\n", 663 | "decoder_states = [state_h, state_c]\n", 664 | "decoder_outputs = decoder_dense(decoder_outputs)\n", 665 | "decoder_model = Model(\n", 666 | " [decoder_inputs] + decoder_states_inputs,\n", 667 | " [decoder_outputs] + decoder_states)\n", 668 | "\n", 669 | "# Reverse-lookup token index to decode sequences back to\n", 670 | "# something readable.\n", 671 | "reverse_input_char_index = dict(\n", 672 | " (i, char) for char, i in input_token_index.items())\n", 673 | "reverse_target_char_index = dict(\n", 674 | " (i, char) for char, i in target_token_index.items())" 675 | ], 676 | 
"execution_count": null, 677 | "outputs": [] 678 | }, 679 | { 680 | "cell_type": "code", 681 | "metadata": { 682 | "id": "ppivXYwF0PC-", 683 | "colab_type": "code", 684 | "colab": {} 685 | }, 686 | "source": [ 687 | "def decode_sequence(input_seq):\n", 688 | " # Encode the input as state vectors.\n", 689 | " states_value = encoder_model.predict(input_seq)\n", 690 | "\n", 691 | " # Generate empty target sequence of length 1.\n", 692 | " target_seq = np.zeros((1, 1, num_decoder_tokens))\n", 693 | " # Populate the first character of target sequence with the start character.\n", 694 | " target_seq[0, 0, target_token_index['\\t']] = 1.\n", 695 | "\n", 696 | " # Sampling loop for a batch of sequences\n", 697 | " # (to simplify, here we assume a batch of size 1).\n", 698 | " stop_condition = False\n", 699 | " decoded_sentence = ''\n", 700 | " while not stop_condition:\n", 701 | " output_tokens, h, c = decoder_model.predict(\n", 702 | " [target_seq] + states_value)\n", 703 | "\n", 704 | " # Sample a token\n", 705 | " sampled_token_index = np.argmax(output_tokens[0, -1, :])\n", 706 | " sampled_char = reverse_target_char_index[sampled_token_index]\n", 707 | " decoded_sentence += sampled_char\n", 708 | "\n", 709 | " # Exit condition: either hit max length\n", 710 | " # or find stop character.\n", 711 | " if (sampled_char == '\\n' or\n", 712 | " len(decoded_sentence) > max_decoder_seq_length):\n", 713 | " stop_condition = True\n", 714 | "\n", 715 | " # Update the target sequence (of length 1).\n", 716 | " target_seq = np.zeros((1, 1, num_decoder_tokens))\n", 717 | " target_seq[0, 0, sampled_token_index] = 1.\n", 718 | "\n", 719 | " # Update states\n", 720 | " states_value = [h, c]\n", 721 | "\n", 722 | " return decoded_sentence" 723 | ], 724 | "execution_count": null, 725 | "outputs": [] 726 | }, 727 | { 728 | "cell_type": "code", 729 | "metadata": { 730 | "id": "SFJRvGN80X3d", 731 | "colab_type": "code", 732 | "colab": { 733 | "base_uri": "https://localhost:8080/", 734 | 
"height": 1000 735 | }, 736 | "outputId": "a29869d6-eb73-4df2-fa34-2fbd1b34c49a" 737 | }, 738 | "source": [ 739 | "for seq_index in range(100):\n", 740 | " # Take one sequence (part of the training set)\n", 741 | " # for trying out decoding.\n", 742 | " input_seq = encoder_input_data[seq_index: seq_index + 1]\n", 743 | " decoded_sentence = decode_sequence(input_seq)\n", 744 | " print('-')\n", 745 | " print('Input sentence:', input_texts[seq_index])\n", 746 | " print('Decoded sentence:', decoded_sentence)" 747 | ], 748 | "execution_count": null, 749 | "outputs": [ 750 | { 751 | "output_type": "stream", 752 | "text": [ 753 | "-\n", 754 | "Input sentence: Go.\n", 755 | "Decoded sentence: Vete.\n", 756 | "\n", 757 | "-\n", 758 | "Input sentence: Go.\n", 759 | "Decoded sentence: Vete.\n", 760 | "\n", 761 | "-\n", 762 | "Input sentence: Go.\n", 763 | "Decoded sentence: Vete.\n", 764 | "\n", 765 | "-\n", 766 | "Input sentence: Go.\n", 767 | "Decoded sentence: Vete.\n", 768 | "\n", 769 | "-\n", 770 | "Input sentence: Hi.\n", 771 | "Decoded sentence: Vente.\n", 772 | "\n", 773 | "-\n", 774 | "Input sentence: Run!\n", 775 | "Decoded sentence: ¡Ve elo!\n", 776 | "\n", 777 | "-\n", 778 | "Input sentence: Run.\n", 779 | "Decoded sentence: Vete.\n", 780 | "\n", 781 | "-\n", 782 | "Input sentence: Who?\n", 783 | "Decoded sentence: ¿Quién anor?\n", 784 | "\n", 785 | "-\n", 786 | "Input sentence: Fire!\n", 787 | "Decoded sentence: ¡Despara!\n", 788 | "\n", 789 | "-\n", 790 | "Input sentence: Fire!\n", 791 | "Decoded sentence: ¡Despara!\n", 792 | "\n", 793 | "-\n", 794 | "Input sentence: Fire!\n", 795 | "Decoded sentence: ¡Despara!\n", 796 | "\n", 797 | "-\n", 798 | "Input sentence: Help!\n", 799 | "Decoded sentence: ¡Vete!\n", 800 | "\n", 801 | "-\n", 802 | "Input sentence: Help!\n", 803 | "Decoded sentence: ¡Vete!\n", 804 | "\n", 805 | "-\n", 806 | "Input sentence: Help!\n", 807 | "Decoded sentence: ¡Vete!\n", 808 | "\n", 809 | "-\n", 810 | "Input sentence: Jump!\n", 811 | 
"Decoded sentence: ¡Dete alo!\n", 812 | "\n", 813 | "-\n", 814 | "Input sentence: Jump.\n", 815 | "Decoded sentence: Satente.\n", 816 | "\n", 817 | "-\n", 818 | "Input sentence: Stop!\n", 819 | "Decoded sentence: ¡Vete a!\n", 820 | "\n", 821 | "-\n", 822 | "Input sentence: Stop!\n", 823 | "Decoded sentence: ¡Vete a!\n", 824 | "\n", 825 | "-\n", 826 | "Input sentence: Stop!\n", 827 | "Decoded sentence: ¡Vete a!\n", 828 | "\n", 829 | "-\n", 830 | "Input sentence: Wait!\n", 831 | "Decoded sentence: ¡Desesta!\n", 832 | "\n", 833 | "-\n", 834 | "Input sentence: Wait.\n", 835 | "Decoded sentence: Allante alo.\n", 836 | "\n", 837 | "-\n", 838 | "Input sentence: Go on.\n", 839 | "Decoded sentence: Vente.\n", 840 | "\n", 841 | "-\n", 842 | "Input sentence: Go on.\n", 843 | "Decoded sentence: Vente.\n", 844 | "\n", 845 | "-\n", 846 | "Input sentence: Hello!\n", 847 | "Decoded sentence: ¡Lorado!\n", 848 | "\n", 849 | "-\n", 850 | "Input sentence: I ran.\n", 851 | "Decoded sentence: Corrí.\n", 852 | "\n", 853 | "-\n", 854 | "Input sentence: I ran.\n", 855 | "Decoded sentence: Corrí.\n", 856 | "\n", 857 | "-\n", 858 | "Input sentence: I try.\n", 859 | "Decoded sentence: Es trerda.\n", 860 | "\n", 861 | "-\n", 862 | "Input sentence: I won!\n", 863 | "Decoded sentence: ¡Prado!\n", 864 | "\n", 865 | "-\n", 866 | "Input sentence: Oh no!\n", 867 | "Decoded sentence: ¡Para!\n", 868 | "\n", 869 | "-\n", 870 | "Input sentence: Relax.\n", 871 | "Decoded sentence: Abrala a Tom.\n", 872 | "\n", 873 | "-\n", 874 | "Input sentence: Smile.\n", 875 | "Decoded sentence: Menten ararra.\n", 876 | "\n", 877 | "-\n", 878 | "Input sentence: Attack!\n", 879 | "Decoded sentence: ¡De pira!\n", 880 | "\n", 881 | "-\n", 882 | "Input sentence: Attack!\n", 883 | "Decoded sentence: ¡De pira!\n", 884 | "\n", 885 | "-\n", 886 | "Input sentence: Get up.\n", 887 | "Decoded sentence: Ayada.\n", 888 | "\n", 889 | "-\n", 890 | "Input sentence: Go now.\n", 891 | "Decoded sentence: Antrade.\n", 892 | "\n", 893 | 
"-\n", 894 | "Input sentence: Got it!\n", 895 | "Decoded sentence: ¡Lo te alo!\n", 896 | "\n", 897 | "-\n", 898 | "Input sentence: Got it?\n", 899 | "Decoded sentence: ¿Qué porro?\n", 900 | "\n", 901 | "-\n", 902 | "Input sentence: Got it?\n", 903 | "Decoded sentence: ¿Qué porro?\n", 904 | "\n", 905 | "-\n", 906 | "Input sentence: He ran.\n", 907 | "Decoded sentence: Mente ciente.\n", 908 | "\n", 909 | "-\n", 910 | "Input sentence: Hop in.\n", 911 | "Decoded sentence: Menten astrada.\n", 912 | "\n", 913 | "-\n", 914 | "Input sentence: Hug me.\n", 915 | "Decoded sentence: Ayrada esto.\n", 916 | "\n", 917 | "-\n", 918 | "Input sentence: I fell.\n", 919 | "Decoded sentence: Me cuerda.\n", 920 | "\n", 921 | "-\n", 922 | "Input sentence: I know.\n", 923 | "Decoded sentence: Me carrió.\n", 924 | "\n", 925 | "-\n", 926 | "Input sentence: I left.\n", 927 | "Decoded sentence: Lo vira.\n", 928 | "\n", 929 | "-\n", 930 | "Input sentence: I lied.\n", 931 | "Decoded sentence: Mentente.\n", 932 | "\n", 933 | "-\n", 934 | "Input sentence: I lost.\n", 935 | "Decoded sentence: Ve arí.\n", 936 | "\n", 937 | "-\n", 938 | "Input sentence: I quit.\n", 939 | "Decoded sentence: Mentre.\n", 940 | "\n", 941 | "-\n", 942 | "Input sentence: I quit.\n", 943 | "Decoded sentence: Mentre.\n", 944 | "\n", 945 | "-\n", 946 | "Input sentence: I work.\n", 947 | "Decoded sentence: Estrada.\n", 948 | "\n", 949 | "-\n", 950 | "Input sentence: I'm 19.\n", 951 | "Decoded sentence: Estoy alo.\n", 952 | "\n", 953 | "-\n", 954 | "Input sentence: I'm up.\n", 955 | "Decoded sentence: Estoy lorada.\n", 956 | "\n", 957 | "-\n", 958 | "Input sentence: Listen.\n", 959 | "Decoded sentence: Mintente en arira.\n", 960 | "\n", 961 | "-\n", 962 | "Input sentence: Listen.\n", 963 | "Decoded sentence: Mintente en arira.\n", 964 | "\n", 965 | "-\n", 966 | "Input sentence: Listen.\n", 967 | "Decoded sentence: Mintente en arira.\n", 968 | "\n", 969 | "-\n", 970 | "Input sentence: No way!\n", 971 | "Decoded sentence: ¡De 
pirra!\n", 972 | "\n", 973 | "-\n", 974 | "Input sentence: No way!\n", 975 | "Decoded sentence: ¡De pirra!\n", 976 | "\n", 977 | "-\n", 978 | "Input sentence: No way!\n", 979 | "Decoded sentence: ¡De pirra!\n", 980 | "\n", 981 | "-\n", 982 | "Input sentence: No way!\n", 983 | "Decoded sentence: ¡De pirra!\n", 984 | "\n", 985 | "-\n", 986 | "Input sentence: No way!\n", 987 | "Decoded sentence: ¡De pirra!\n", 988 | "\n", 989 | "-\n", 990 | "Input sentence: No way!\n", 991 | "Decoded sentence: ¡De pirra!\n", 992 | "\n", 993 | "-\n", 994 | "Input sentence: No way!\n", 995 | "Decoded sentence: ¡De pirra!\n", 996 | "\n", 997 | "-\n", 998 | "Input sentence: No way!\n", 999 | "Decoded sentence: ¡De pirra!\n", 1000 | "\n", 1001 | "-\n", 1002 | "Input sentence: No way!\n", 1003 | "Decoded sentence: ¡De pirra!\n", 1004 | "\n", 1005 | "-\n", 1006 | "Input sentence: No way!\n", 1007 | "Decoded sentence: ¡De pirra!\n", 1008 | "\n", 1009 | "-\n", 1010 | "Input sentence: Really?\n", 1011 | "Decoded sentence: ¿Pien iro?\n", 1012 | "\n", 1013 | "-\n", 1014 | "Input sentence: Really?\n", 1015 | "Decoded sentence: ¿Pien iro?\n", 1016 | "\n", 1017 | "-\n", 1018 | "Input sentence: Thanks.\n", 1019 | "Decoded sentence: Alala asto.\n", 1020 | "\n", 1021 | "-\n", 1022 | "Input sentence: Thanks.\n", 1023 | "Decoded sentence: Alala asto.\n", 1024 | "\n", 1025 | "-\n", 1026 | "Input sentence: Try it.\n", 1027 | "Decoded sentence: Pruente alo.\n", 1028 | "\n", 1029 | "-\n", 1030 | "Input sentence: We try.\n", 1031 | "Decoded sentence: Pruén astora.\n", 1032 | "\n", 1033 | "-\n", 1034 | "Input sentence: We won.\n", 1035 | "Decoded sentence: Minte alo.\n", 1036 | "\n", 1037 | "-\n", 1038 | "Input sentence: Why me?\n", 1039 | "Decoded sentence: ¿Quién an carado?\n", 1040 | "\n", 1041 | "-\n", 1042 | "Input sentence: Ask Tom.\n", 1043 | "Decoded sentence: Ayude a Tom.\n", 1044 | "\n", 1045 | "-\n", 1046 | "Input sentence: Awesome!\n", 1047 | "Decoded sentence: ¡Desperra!\n", 1048 | "\n", 1049 | 
"-\n", 1050 | "Input sentence: Be calm.\n", 1051 | "Decoded sentence: Séntate carra.\n", 1052 | "\n", 1053 | "-\n", 1054 | "Input sentence: Be cool.\n", 1055 | "Decoded sentence: Sénta carri.\n", 1056 | "\n", 1057 | "-\n", 1058 | "Input sentence: Be fair.\n", 1059 | "Decoded sentence: Sé antrente.\n", 1060 | "\n", 1061 | "-\n", 1062 | "Input sentence: Be kind.\n", 1063 | "Decoded sentence: Sén antre.\n", 1064 | "\n", 1065 | "-\n", 1066 | "Input sentence: Be nice.\n", 1067 | "Decoded sentence: Sénta cire.\n", 1068 | "\n", 1069 | "-\n", 1070 | "Input sentence: Beat it.\n", 1071 | "Decoded sentence: Séla asto.\n", 1072 | "\n", 1073 | "-\n", 1074 | "Input sentence: Call me.\n", 1075 | "Decoded sentence: Allama a Tom.\n", 1076 | "\n", 1077 | "-\n", 1078 | "Input sentence: Call me.\n", 1079 | "Decoded sentence: Allama a Tom.\n", 1080 | "\n", 1081 | "-\n", 1082 | "Input sentence: Call me.\n", 1083 | "Decoded sentence: Allama a Tom.\n", 1084 | "\n", 1085 | "-\n", 1086 | "Input sentence: Call us.\n", 1087 | "Decoded sentence: ¡Llamame a Tomásto.\n", 1088 | "\n", 1089 | "-\n", 1090 | "Input sentence: Come in.\n", 1091 | "Decoded sentence: Vente a Tom.\n", 1092 | "\n", 1093 | "-\n", 1094 | "Input sentence: Come in.\n", 1095 | "Decoded sentence: Vente a Tom.\n", 1096 | "\n", 1097 | "-\n", 1098 | "Input sentence: Come in.\n", 1099 | "Decoded sentence: Vente a Tom.\n", 1100 | "\n", 1101 | "-\n", 1102 | "Input sentence: Come on!\n", 1103 | "Decoded sentence: ¡Vete alo!\n", 1104 | "\n", 1105 | "-\n", 1106 | "Input sentence: Come on.\n", 1107 | "Decoded sentence: Vente aquí.\n", 1108 | "\n", 1109 | "-\n", 1110 | "Input sentence: Come on.\n", 1111 | "Decoded sentence: Vente aquí.\n", 1112 | "\n", 1113 | "-\n", 1114 | "Input sentence: Drop it!\n", 1115 | "Decoded sentence: ¡De prirdo!\n", 1116 | "\n", 1117 | "-\n", 1118 | "Input sentence: Get Tom.\n", 1119 | "Decoded sentence: Ayada a Tomás.\n", 1120 | "\n", 1121 | "-\n", 1122 | "Input sentence: Get out!\n", 1123 | "Decoded sentence: 
¡La te alo!\n", 1124 | "\n", 1125 | "-\n", 1126 | "Input sentence: Get out.\n", 1127 | "Decoded sentence: Alata alo.\n", 1128 | "\n", 1129 | "-\n", 1130 | "Input sentence: Get out.\n", 1131 | "Decoded sentence: Alata alo.\n", 1132 | "\n", 1133 | "-\n", 1134 | "Input sentence: Get out.\n", 1135 | "Decoded sentence: Alata alo.\n", 1136 | "\n", 1137 | "-\n", 1138 | "Input sentence: Get out.\n", 1139 | "Decoded sentence: Alata alo.\n", 1140 | "\n", 1141 | "-\n", 1142 | "Input sentence: Get out.\n", 1143 | "Decoded sentence: Alata alo.\n", 1144 | "\n", 1145 | "-\n", 1146 | "Input sentence: Go away!\n", 1147 | "Decoded sentence: ¡La te a aquí!\n", 1148 | "\n", 1149 | "-\n", 1150 | "Input sentence: Go away!\n", 1151 | "Decoded sentence: ¡La te a aquí!\n", 1152 | "\n" 1153 | ], 1154 | "name": "stdout" 1155 | } 1156 | ] 1157 | } 1158 | ] 1159 | } -------------------------------------------------------------------------------- /PAIGCP_image_captioning.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "accelerator": "GPU", 6 | "colab": { 7 | "name": "PAIGCP_image_captioning.ipynb", 8 | "provenance": [], 9 | "private_outputs": true, 10 | "collapsed_sections": [], 11 | "include_colab_link": true 12 | }, 13 | "kernelspec": { 14 | "display_name": "Python 3", 15 | "name": "python3" 16 | } 17 | }, 18 | "cells": [ 19 | { 20 | "cell_type": "markdown", 21 | "metadata": { 22 | "id": "view-in-github", 23 | "colab_type": "text" 24 | }, 25 | "source": [ 26 | "\"Open" 27 | ] 28 | }, 29 | { 30 | "cell_type": "markdown", 31 | "metadata": { 32 | "colab_type": "text", 33 | "id": "K2s1A9eLRPEj" 34 | }, 35 | "source": [ 36 | "##### Copyright 2018 The TensorFlow Authors.\n" 37 | ] 38 | }, 39 | { 40 | "cell_type": "code", 41 | "metadata": { 42 | "cellView": "form", 43 | "colab_type": "code", 44 | "id": "VRLVEKiTEn04", 45 | "colab": {} 46 | }, 47 | "source": [ 48 | "#@title Licensed under 
the Apache License, Version 2.0 (the \"License\");\n", 49 | "# you may not use this file except in compliance with the License.\n", 50 | "# You may obtain a copy of the License at\n", 51 | "#\n", 52 | "# https://www.apache.org/licenses/LICENSE-2.0\n", 53 | "#\n", 54 | "# Unless required by applicable law or agreed to in writing, software\n", 55 | "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", 56 | "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", 57 | "# See the License for the specific language governing permissions and\n", 58 | "# limitations under the License." 59 | ], 60 | "execution_count": null, 61 | "outputs": [] 62 | }, 63 | { 64 | "cell_type": "markdown", 65 | "metadata": { 66 | "colab_type": "text", 67 | "id": "Cffg2i257iMS" 68 | }, 69 | "source": [ 70 | "# Image captioning with visual attention\n", 71 | "\n", 72 | "\n", 73 | " \n", 78 | " \n", 83 | " \n", 88 | " \n", 91 | "
\n", 74 | " \n", 75 | " \n", 76 | " View on TensorFlow.org\n", 77 | " \n", 79 | " \n", 80 | " \n", 81 | " Run in Google Colab\n", 82 | " \n", 84 | " \n", 85 | " \n", 86 | " View source on GitHub\n", 87 | " \n", 89 | " Download notebook\n", 90 | "
" 92 | ] 93 | }, 94 | { 95 | "cell_type": "markdown", 96 | "metadata": { 97 | "colab_type": "text", 98 | "id": "QASbY_HGo4Lq" 99 | }, 100 | "source": [ 101 | "Given an image like the example below, our goal is to generate a caption such as \"a surfer riding on a wave\".\n", 102 | "\n", 103 | "![Man Surfing](https://tensorflow.org/images/surf.jpg)\n", 104 | "\n", 105 | "*[Image Source](https://commons.wikimedia.org/wiki/Surfing#/media/File:Surfing_in_Hawaii.jpg); License: Public Domain*\n", 106 | "\n", 107 | "To accomplish this, you'll use an attention-based model, which enables us to see what parts of the image the model focuses on as it generates a caption.\n", 108 | "\n", 109 | "![Prediction](https://tensorflow.org/images/imcap_prediction.png)\n", 110 | "\n", 111 | "The model architecture is similar to [Show, Attend and Tell: Neural Image Caption Generation with Visual Attention](https://arxiv.org/abs/1502.03044).\n", 112 | "\n", 113 | "This notebook is an end-to-end example. When you run the notebook, it downloads the [MS-COCO](http://cocodataset.org/#home) dataset, preprocesses and caches a subset of images using Inception V3, trains an encoder-decoder model, and generates captions on new images using the trained model.\n", 114 | "\n", 115 | "In this example, you will train a model on a relatively small amount of data—the first 30,000 captions for about 20,000 images (because there are multiple captions per image in the dataset)." 
116 | ] 117 | }, 118 | { 119 | "cell_type": "code", 120 | "metadata": { 121 | "colab_type": "code", 122 | "id": "U8l4RJ0XRPEm", 123 | "colab": {} 124 | }, 125 | "source": [ 126 | "import tensorflow as tf\n", 127 | "\n", 128 | "# You'll generate plots of attention in order to see which parts of an image\n", 129 | "# our model focuses on during captioning\n", 130 | "import matplotlib.pyplot as plt\n", 131 | "\n", 132 | "# Scikit-learn includes many helpful utilities\n", 133 | "from sklearn.model_selection import train_test_split\n", 134 | "from sklearn.utils import shuffle\n", 135 | "\n", 136 | "import re\n", 137 | "import numpy as np\n", 138 | "import os\n", 139 | "import time\n", 140 | "import json\n", 141 | "from glob import glob\n", 142 | "from PIL import Image\n", 143 | "import pickle" 144 | ], 145 | "execution_count": null, 146 | "outputs": [] 147 | }, 148 | { 149 | "cell_type": "markdown", 150 | "metadata": { 151 | "colab_type": "text", 152 | "id": "b6qbGw8MRPE5" 153 | }, 154 | "source": [ 155 | "## Download and prepare the MS-COCO dataset\n", 156 | "\n", 157 | "You will use the [MS-COCO dataset](http://cocodataset.org/#home) to train your model. The dataset contains over 82,000 images, each of which has at least 5 different caption annotations. The code below downloads and extracts the dataset automatically.\n", 158 | "\n", 159 | "**Caution: large download ahead**. You'll use the training set, which is a 13GB file."
160 | ] 161 | }, 162 | { 163 | "cell_type": "code", 164 | "metadata": { 165 | "colab_type": "code", 166 | "id": "krQuPYTtRPE7", 167 | "colab": {} 168 | }, 169 | "source": [ 170 | "# Download caption annotation files\n", 171 | "annotation_folder = '/annotations/'\n", 172 | "if not os.path.exists(os.path.abspath('.') + annotation_folder):\n", 173 | " annotation_zip = tf.keras.utils.get_file('captions.zip',\n", 174 | " cache_subdir=os.path.abspath('.'),\n", 175 | " origin = 'http://images.cocodataset.org/annotations/annotations_trainval2014.zip',\n", 176 | " extract = True)\n", 177 | " annotation_file = os.path.dirname(annotation_zip)+'/annotations/captions_train2014.json'\n", 178 | " os.remove(annotation_zip)\n", 179 | "\n", 180 | "# Download image files\n", 181 | "image_folder = '/train2014/'\n", 182 | "if not os.path.exists(os.path.abspath('.') + image_folder):\n", 183 | " image_zip = tf.keras.utils.get_file('train2014.zip',\n", 184 | " cache_subdir=os.path.abspath('.'),\n", 185 | " origin = 'http://images.cocodataset.org/zips/train2014.zip',\n", 186 | " extract = True)\n", 187 | " PATH = os.path.dirname(image_zip) + image_folder\n", 188 | " os.remove(image_zip)\n", 189 | "else:\n", 190 | " PATH = os.path.abspath('.') + image_folder" 191 | ], 192 | "execution_count": null, 193 | "outputs": [] 194 | }, 195 | { 196 | "cell_type": "markdown", 197 | "metadata": { 198 | "colab_type": "text", 199 | "id": "aANEzb5WwSzg" 200 | }, 201 | "source": [ 202 | "## Optional: limit the size of the training set \n", 203 | "To speed up training for this tutorial, you'll use a subset of 30,000 captions and their corresponding images to train your model. Choosing to use more data would result in improved captioning quality."
204 | ] 205 | }, 206 | { 207 | "cell_type": "code", 208 | "metadata": { 209 | "colab_type": "code", 210 | "id": "4G3b8x8_RPFD", 211 | "colab": {} 212 | }, 213 | "source": [ 214 | "# Read the json file\n", 215 | "with open(annotation_file, 'r') as f:\n", 216 | " annotations = json.load(f)\n", 217 | "\n", 218 | "# Store captions and image names in vectors\n", 219 | "all_captions = []\n", 220 | "all_img_name_vector = []\n", 221 | "\n", 222 | "for annot in annotations['annotations']:\n", 223 | " caption = ' ' + annot['caption'] + ' '\n", 224 | " image_id = annot['image_id']\n", 225 | " full_coco_image_path = PATH + 'COCO_train2014_' + '%012d.jpg' % (image_id)\n", 226 | "\n", 227 | " all_img_name_vector.append(full_coco_image_path)\n", 228 | " all_captions.append(caption)\n", 229 | "\n", 230 | "# Shuffle captions and image_names together\n", 231 | "# Set a random state\n", 232 | "train_captions, img_name_vector = shuffle(all_captions,\n", 233 | " all_img_name_vector,\n", 234 | " random_state=1)\n", 235 | "\n", 236 | "# Select the first 30000 captions from the shuffled set\n", 237 | "num_examples = 30000\n", 238 | "train_captions = train_captions[:num_examples]\n", 239 | "img_name_vector = img_name_vector[:num_examples]" 240 | ], 241 | "execution_count": null, 242 | "outputs": [] 243 | }, 244 | { 245 | "cell_type": "code", 246 | "metadata": { 247 | "colab_type": "code", 248 | "id": "mPBMgK34RPFL", 249 | "colab": {} 250 | }, 251 | "source": [ 252 | "len(train_captions), len(all_captions)" 253 | ], 254 | "execution_count": null, 255 | "outputs": [] 256 | }, 257 | { 258 | "cell_type": "markdown", 259 | "metadata": { 260 | "colab_type": "text", 261 | "id": "8cSW4u-ORPFQ" 262 | }, 263 | "source": [ 264 | "## Preprocess the images using InceptionV3\n", 265 | "Next, you will use InceptionV3 (which is pretrained on Imagenet) to encode each image, rather than to classify it.
You will extract features from the last convolutional layer.\n", 266 | "\n", 267 | "First, you will convert the images into InceptionV3's expected format by:\n", 268 | "* Resizing the image to 299px by 299px\n", 269 | "* [Preprocess the images](https://cloud.google.com/tpu/docs/inception-v3-advanced#preprocessing_stage) using the [preprocess_input](https://www.tensorflow.org/api_docs/python/tf/keras/applications/inception_v3/preprocess_input) method to normalize the image so that it contains pixels in the range of -1 to 1, which matches the format of the images used to train InceptionV3." 270 | ] 271 | }, 272 | { 273 | "cell_type": "code", 274 | "metadata": { 275 | "colab_type": "code", 276 | "id": "zXR0217aRPFR", 277 | "colab": {} 278 | }, 279 | "source": [ 280 | "def load_image(image_path):\n", 281 | " img = tf.io.read_file(image_path)\n", 282 | " img = tf.image.decode_jpeg(img, channels=3)\n", 283 | " img = tf.image.resize(img, (299, 299))\n", 284 | " img = tf.keras.applications.inception_v3.preprocess_input(img)\n", 285 | " return img, image_path" 286 | ], 287 | "execution_count": null, 288 | "outputs": [] 289 | }, 290 | { 291 | "cell_type": "markdown", 292 | "metadata": { 293 | "colab_type": "text", 294 | "id": "MDvIu4sXRPFV" 295 | }, 296 | "source": [ 297 | "## Initialize InceptionV3 and load the pretrained Imagenet weights\n", 298 | "\n", 299 | "Now you'll create a tf.keras model where the output layer is the last convolutional layer in the InceptionV3 architecture. The shape of the output of this layer is ```8x8x2048```. You use the last convolutional layer because you are using attention in this example. 
You run this feature extraction once, up front, instead of during training, because repeating it on every training step could become a bottleneck.\n", 300 | "\n", 301 | "* You forward each image through the network and save the resulting feature vector to disk as a NumPy `.npy` file named after the image (image_name --> feature_vector).\n", 302 | "* Because every image gets its own feature file, the full set of features never has to be held in memory or pickled as a single object.\n" 303 | ] 304 | }, 305 | { 306 | "cell_type": "code", 307 | "metadata": { 308 | "colab_type": "code", 309 | "id": "RD3vW4SsRPFW", 310 | "colab": {} 311 | }, 312 | "source": [ 313 | "image_model = tf.keras.applications.InceptionV3(include_top=False,\n", 314 | " weights='imagenet')\n", 315 | "new_input = image_model.input\n", 316 | "hidden_layer = image_model.layers[-1].output\n", 317 | "\n", 318 | "image_features_extract_model = tf.keras.Model(new_input, hidden_layer)" 319 | ], 320 | "execution_count": null, 321 | "outputs": [] 322 | }, 323 | { 324 | "cell_type": "markdown", 325 | "metadata": { 326 | "colab_type": "text", 327 | "id": "rERqlR3WRPGO" 328 | }, 329 | "source": [ 330 | "## Caching the features extracted from InceptionV3\n", 331 | "\n", 332 | "You will pre-process each image with InceptionV3 and cache the output to disk. Caching the output in RAM would be faster but also memory intensive, requiring 8 \\* 8 \\* 2048 floats per image. At the time of writing, this exceeds the memory limitations of Colab (currently 12GB of memory).\n", 333 | "\n", 334 | "Performance could be improved with a more sophisticated caching strategy (for example, by sharding the images to reduce random access disk I/O), but that would require more code.\n", 335 | "\n", 336 | "The caching will take about 10 minutes to run in Colab with a GPU. If you'd like to see a progress bar, you can: \n", 337 | "\n", 338 | "1. Install [tqdm](https://github.com/tqdm/tqdm):\n", 339 | "\n", 340 | " `!pip install tqdm`\n", 341 | "\n", 342 | "2. Import tqdm:\n", 343 | "\n", 344 | " `from tqdm import tqdm`\n", 345 | "\n", 346 | "3.
Change the following line:\n", 347 | "\n", 348 | " `for img, path in image_dataset:`\n", 349 | "\n", 350 | " to:\n", 351 | "\n", 352 | " `for img, path in tqdm(image_dataset):`\n" 353 | ] 354 | }, 355 | { 356 | "cell_type": "code", 357 | "metadata": { 358 | "colab_type": "code", 359 | "id": "Dx_fvbVgRPGQ", 360 | "colab": {} 361 | }, 362 | "source": [ 363 | "# Get unique images\n", 364 | "encode_train = sorted(set(img_name_vector))\n", 365 | "\n", 366 | "# Feel free to change batch_size according to your system configuration\n", 367 | "image_dataset = tf.data.Dataset.from_tensor_slices(encode_train)\n", 368 | "image_dataset = image_dataset.map(\n", 369 | " load_image, num_parallel_calls=tf.data.experimental.AUTOTUNE).batch(16)\n", 370 | "\n", 371 | "for img, path in image_dataset:\n", 372 | " batch_features = image_features_extract_model(img)\n", 373 | " batch_features = tf.reshape(batch_features,\n", 374 | " (batch_features.shape[0], -1, batch_features.shape[3]))\n", 375 | "\n", 376 | " for bf, p in zip(batch_features, path):\n", 377 | " path_of_feature = p.numpy().decode(\"utf-8\")\n", 378 | " np.save(path_of_feature, bf.numpy())" 379 | ], 380 | "execution_count": null, 381 | "outputs": [] 382 | }, 383 | { 384 | "cell_type": "markdown", 385 | "metadata": { 386 | "colab_type": "text", 387 | "id": "nyqH3zFwRPFi" 388 | }, 389 | "source": [ 390 | "## Preprocess and tokenize the captions\n", 391 | "\n", 392 | "* First, you'll tokenize the captions (for example, by splitting on spaces). This gives us a vocabulary of all of the unique words in the data (for example, \"surfing\", \"football\", and so on).\n", 393 | "* Next, you'll limit the vocabulary size to the top 5,000 words (to save memory). You'll replace all other words with the token \"UNK\" (unknown).\n", 394 | "* You then create word-to-index and index-to-word mappings.\n", 395 | "* Finally, you pad all sequences to be the same length as the longest one." 
396 | ] 397 | }, 398 | { 399 | "cell_type": "code", 400 | "metadata": { 401 | "colab_type": "code", 402 | "id": "HZfK8RhQRPFj", 403 | "colab": {} 404 | }, 405 | "source": [ 406 | "# Find the maximum length of any caption in our dataset\n", 407 | "def calc_max_length(tensor):\n", 408 | " return max(len(t) for t in tensor)" 409 | ], 410 | "execution_count": null, 411 | "outputs": [] 412 | }, 413 | { 414 | "cell_type": "code", 415 | "metadata": { 416 | "colab_type": "code", 417 | "id": "oJGE34aiRPFo", 418 | "colab": {} 419 | }, 420 | "source": [ 421 | "# Choose the top 5000 words from the vocabulary\n", 422 | "top_k = 5000\n", 423 | "tokenizer = tf.keras.preprocessing.text.Tokenizer(num_words=top_k,\n", 424 | " oov_token=\"\",\n", 425 | " filters='!\"#$%&()*+.,-/:;=?@[\\]^_`{|}~ ')\n", 426 | "tokenizer.fit_on_texts(train_captions)\n", 427 | "train_seqs = tokenizer.texts_to_sequences(train_captions)" 428 | ], 429 | "execution_count": null, 430 | "outputs": [] 431 | }, 432 | { 433 | "cell_type": "code", 434 | "metadata": { 435 | "colab_type": "code", 436 | "id": "8Q44tNQVRPFt", 437 | "colab": {} 438 | }, 439 | "source": [ 440 | "tokenizer.word_index[''] = 0\n", 441 | "tokenizer.index_word[0] = ''" 442 | ], 443 | "execution_count": null, 444 | "outputs": [] 445 | }, 446 | { 447 | "cell_type": "code", 448 | "metadata": { 449 | "colab_type": "code", 450 | "id": "0fpJb5ojRPFv", 451 | "colab": {} 452 | }, 453 | "source": [ 454 | "# Create the tokenized vectors\n", 455 | "train_seqs = tokenizer.texts_to_sequences(train_captions)" 456 | ], 457 | "execution_count": null, 458 | "outputs": [] 459 | }, 460 | { 461 | "cell_type": "code", 462 | "metadata": { 463 | "colab_type": "code", 464 | "id": "AidglIZVRPF4", 465 | "colab": {} 466 | }, 467 | "source": [ 468 | "# Pad each vector to the max_length of the captions\n", 469 | "# If you do not provide a max_length value, pad_sequences calculates it automatically\n", 470 | "cap_vector = 
tf.keras.preprocessing.sequence.pad_sequences(train_seqs, padding='post')" 471 | ], 472 | "execution_count": null, 473 | "outputs": [] 474 | }, 475 | { 476 | "cell_type": "code", 477 | "metadata": { 478 | "colab_type": "code", 479 | "id": "gL0wkttkRPGA", 480 | "colab": {} 481 | }, 482 | "source": [ 483 | "# Calculates the max_length, which is used to store the attention weights\n", 484 | "max_length = calc_max_length(train_seqs)" 485 | ], 486 | "execution_count": null, 487 | "outputs": [] 488 | }, 489 | { 490 | "cell_type": "markdown", 491 | "metadata": { 492 | "colab_type": "text", 493 | "id": "M3CD75nDpvTI" 494 | }, 495 | "source": [ 496 | "## Split the data into training and testing" 497 | ] 498 | }, 499 | { 500 | "cell_type": "code", 501 | "metadata": { 502 | "colab_type": "code", 503 | "id": "iS7DDMszRPGF", 504 | "colab": {} 505 | }, 506 | "source": [ 507 | "# Create training and validation sets using an 80-20 split\n", 508 | "img_name_train, img_name_val, cap_train, cap_val = train_test_split(img_name_vector,\n", 509 | " cap_vector,\n", 510 | " test_size=0.2,\n", 511 | " random_state=0)" 512 | ], 513 | "execution_count": null, 514 | "outputs": [] 515 | }, 516 | { 517 | "cell_type": "code", 518 | "metadata": { 519 | "colab_type": "code", 520 | "id": "XmViPkRFRPGH", 521 | "colab": {} 522 | }, 523 | "source": [ 524 | "len(img_name_train), len(cap_train), len(img_name_val), len(cap_val)" 525 | ], 526 | "execution_count": null, 527 | "outputs": [] 528 | }, 529 | { 530 | "cell_type": "markdown", 531 | "metadata": { 532 | "colab_type": "text", 533 | "id": "uEWM9xrYcg45" 534 | }, 535 | "source": [ 536 | "## Create a tf.data dataset for training\n" 537 | ] 538 | }, 539 | { 540 | "cell_type": "markdown", 541 | "metadata": { 542 | "colab_type": "text", 543 | "id": "horagNvhhZiy" 544 | }, 545 | "source": [ 546 | " Our images and captions are ready! Next, let's create a tf.data dataset to use for training our model." 
547 | ] 548 | }, 549 | { 550 | "cell_type": "code", 551 | "metadata": { 552 | "colab_type": "code", 553 | "id": "Q3TnZ1ToRPGV", 554 | "colab": {} 555 | }, 556 | "source": [ 557 | "# Feel free to change these parameters according to your system's configuration\n", 558 | "\n", 559 | "BATCH_SIZE = 64\n", 560 | "BUFFER_SIZE = 1000\n", 561 | "embedding_dim = 256\n", 562 | "units = 512\n", 563 | "vocab_size = top_k + 1\n", 564 | "num_steps = len(img_name_train) // BATCH_SIZE\n", 565 | "# Shape of the vector extracted from InceptionV3 is (64, 2048)\n", 566 | "# These two variables represent that vector shape\n", 567 | "features_shape = 2048\n", 568 | "attention_features_shape = 64" 569 | ], 570 | "execution_count": null, 571 | "outputs": [] 572 | }, 573 | { 574 | "cell_type": "code", 575 | "metadata": { 576 | "colab_type": "code", 577 | "id": "SmZS2N0bXG3T", 578 | "colab": {} 579 | }, 580 | "source": [ 581 | "# Load the numpy files\n", 582 | "def map_func(img_name, cap):\n", 583 | " img_tensor = np.load(img_name.decode('utf-8')+'.npy')\n", 584 | " return img_tensor, cap" 585 | ], 586 | "execution_count": null, 587 | "outputs": [] 588 | }, 589 | { 590 | "cell_type": "code", 591 | "metadata": { 592 | "colab_type": "code", 593 | "id": "FDF_Nm3tRPGZ", 594 | "colab": {} 595 | }, 596 | "source": [ 597 | "dataset = tf.data.Dataset.from_tensor_slices((img_name_train, cap_train))\n", 598 | "\n", 599 | "# Use map to load the numpy files in parallel\n", 600 | "dataset = dataset.map(lambda item1, item2: tf.numpy_function(\n", 601 | " map_func, [item1, item2], [tf.float32, tf.int32]),\n", 602 | " num_parallel_calls=tf.data.experimental.AUTOTUNE)\n", 603 | "\n", 604 | "# Shuffle and batch\n", 605 | "dataset = dataset.shuffle(BUFFER_SIZE).batch(BATCH_SIZE)\n", 606 | "dataset = dataset.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)" 607 | ], 608 | "execution_count": null, 609 | "outputs": [] 610 | }, 611 | { 612 | "cell_type": "markdown", 613 | "metadata": { 614 | "colab_type":
"text", 615 | "id": "nrvoDphgRPGd" 616 | }, 617 | "source": [ 618 | "## Model\n", 619 | "\n", 620 | "Fun fact: the decoder below is identical to the one in the example for [Neural Machine Translation with Attention](../sequences/nmt_with_attention.ipynb).\n", 621 | "\n", 622 | "The model architecture is inspired by the [Show, Attend and Tell](https://arxiv.org/pdf/1502.03044.pdf) paper.\n", 623 | "\n", 624 | "* In this example, you extract the features from the last convolutional layer of InceptionV3, giving you a vector of shape (8, 8, 2048).\n", 625 | "* You reshape that to (64, 2048) by flattening the 8x8 spatial grid.\n", 626 | "* This vector is then passed through the CNN Encoder (a single fully connected layer followed by a ReLU).\n", 627 | "* The RNN (here a GRU) attends over the image features to predict the next word." 628 | ] 629 | }, 630 | { 631 | "cell_type": "code", 632 | "metadata": { 633 | "colab_type": "code", 634 | "id": "ja2LFTMSdeV3", 635 | "colab": {} 636 | }, 637 | "source": [ 638 | "class BahdanauAttention(tf.keras.Model):\n", 639 | " def __init__(self, units):\n", 640 | " super(BahdanauAttention, self).__init__()\n", 641 | " self.W1 = tf.keras.layers.Dense(units)\n", 642 | " self.W2 = tf.keras.layers.Dense(units)\n", 643 | " self.V = tf.keras.layers.Dense(1)\n", 644 | "\n", 645 | " def call(self, features, hidden):\n", 646 | " # features(CNN_encoder output) shape == (batch_size, 64, embedding_dim)\n", 647 | "\n", 648 | " # hidden shape == (batch_size, hidden_size)\n", 649 | " # hidden_with_time_axis shape == (batch_size, 1, hidden_size)\n", 650 | " hidden_with_time_axis = tf.expand_dims(hidden, 1)\n", 651 | "\n", 652 | " # score shape == (batch_size, 64, hidden_size)\n", 653 | " score = tf.nn.tanh(self.W1(features) + self.W2(hidden_with_time_axis))\n", 654 | "\n", 655 | " # attention_weights shape == (batch_size, 64, 1)\n", 656 | " # you get 1 at the last axis because you are applying score to self.V\n", 657 | " attention_weights = tf.nn.softmax(self.V(score), axis=1)\n", 658 |
"\n", 659 | " # context_vector shape after sum == (batch_size, hidden_size)\n", 660 | " context_vector = attention_weights * features\n", 661 | " context_vector = tf.reduce_sum(context_vector, axis=1)\n", 662 | "\n", 663 | " return context_vector, attention_weights" 664 | ], 665 | "execution_count": null, 666 | "outputs": [] 667 | }, 668 | { 669 | "cell_type": "code", 670 | "metadata": { 671 | "colab_type": "code", 672 | "id": "AZ7R1RxHRPGf", 673 | "colab": {} 674 | }, 675 | "source": [ 676 | "class CNN_Encoder(tf.keras.Model):\n", 677 | " # Since you have already extracted the features and dumped it using pickle\n", 678 | " # This encoder passes those features through a Fully connected layer\n", 679 | " def __init__(self, embedding_dim):\n", 680 | " super(CNN_Encoder, self).__init__()\n", 681 | " # shape after fc == (batch_size, 64, embedding_dim)\n", 682 | " self.fc = tf.keras.layers.Dense(embedding_dim)\n", 683 | "\n", 684 | " def call(self, x):\n", 685 | " x = self.fc(x)\n", 686 | " x = tf.nn.relu(x)\n", 687 | " return x" 688 | ], 689 | "execution_count": null, 690 | "outputs": [] 691 | }, 692 | { 693 | "cell_type": "code", 694 | "metadata": { 695 | "colab_type": "code", 696 | "id": "V9UbGQmERPGi", 697 | "colab": {} 698 | }, 699 | "source": [ 700 | "class RNN_Decoder(tf.keras.Model):\n", 701 | " def __init__(self, embedding_dim, units, vocab_size):\n", 702 | " super(RNN_Decoder, self).__init__()\n", 703 | " self.units = units\n", 704 | "\n", 705 | " self.embedding = tf.keras.layers.Embedding(vocab_size, embedding_dim)\n", 706 | " self.gru = tf.keras.layers.GRU(self.units,\n", 707 | " return_sequences=True,\n", 708 | " return_state=True,\n", 709 | " recurrent_initializer='glorot_uniform')\n", 710 | " self.fc1 = tf.keras.layers.Dense(self.units)\n", 711 | " self.fc2 = tf.keras.layers.Dense(vocab_size)\n", 712 | "\n", 713 | " self.attention = BahdanauAttention(self.units)\n", 714 | "\n", 715 | " def call(self, x, features, hidden):\n", 716 | " # defining attention 
as a separate model\n", 717 | " context_vector, attention_weights = self.attention(features, hidden)\n", 718 | "\n", 719 | " # x shape after passing through embedding == (batch_size, 1, embedding_dim)\n", 720 | " x = self.embedding(x)\n", 721 | "\n", 722 | " # x shape after concatenation == (batch_size, 1, embedding_dim + hidden_size)\n", 723 | " x = tf.concat([tf.expand_dims(context_vector, 1), x], axis=-1)\n", 724 | "\n", 725 | " # passing the concatenated vector to the GRU\n", 726 | " output, state = self.gru(x)\n", 727 | "\n", 728 | " # shape == (batch_size, max_length, hidden_size)\n", 729 | " x = self.fc1(output)\n", 730 | "\n", 731 | " # x shape == (batch_size * max_length, hidden_size)\n", 732 | " x = tf.reshape(x, (-1, x.shape[2]))\n", 733 | "\n", 734 | " # output shape == (batch_size * max_length, vocab)\n", 735 | " x = self.fc2(x)\n", 736 | "\n", 737 | " return x, state, attention_weights\n", 738 | "\n", 739 | " def reset_state(self, batch_size):\n", 740 | " return tf.zeros((batch_size, self.units))" 741 | ], 742 | "execution_count": null, 743 | "outputs": [] 744 | }, 745 | { 746 | "cell_type": "code", 747 | "metadata": { 748 | "colab_type": "code", 749 | "id": "Qs_Sr03wRPGk", 750 | "colab": {} 751 | }, 752 | "source": [ 753 | "encoder = CNN_Encoder(embedding_dim)\n", 754 | "decoder = RNN_Decoder(embedding_dim, units, vocab_size)" 755 | ], 756 | "execution_count": null, 757 | "outputs": [] 758 | }, 759 | { 760 | "cell_type": "code", 761 | "metadata": { 762 | "colab_type": "code", 763 | "id": "-bYN7xA0RPGl", 764 | "colab": {} 765 | }, 766 | "source": [ 767 | "optimizer = tf.keras.optimizers.Adam()\n", 768 | "loss_object = tf.keras.losses.SparseCategoricalCrossentropy(\n", 769 | " from_logits=True, reduction='none')\n", 770 | "\n", 771 | "def loss_function(real, pred):\n", 772 | " mask = tf.math.logical_not(tf.math.equal(real, 0))\n", 773 | " loss_ = loss_object(real, pred)\n", 774 | "\n", 775 | " mask = tf.cast(mask, dtype=loss_.dtype)\n", 776 | " loss_ 
*= mask\n", 777 | "\n", 778 | " return tf.reduce_mean(loss_)" 779 | ], 780 | "execution_count": null, 781 | "outputs": [] 782 | }, 783 | { 784 | "cell_type": "markdown", 785 | "metadata": { 786 | "colab_type": "text", 787 | "id": "6A3Ni64joyab" 788 | }, 789 | "source": [ 790 | "## Checkpoint" 791 | ] 792 | }, 793 | { 794 | "cell_type": "code", 795 | "metadata": { 796 | "colab_type": "code", 797 | "id": "PpJAqPMWo0uE", 798 | "colab": {} 799 | }, 800 | "source": [ 801 | "checkpoint_path = \"./checkpoints/train\"\n", 802 | "ckpt = tf.train.Checkpoint(encoder=encoder,\n", 803 | " decoder=decoder,\n", 804 | " optimizer = optimizer)\n", 805 | "ckpt_manager = tf.train.CheckpointManager(ckpt, checkpoint_path, max_to_keep=5)" 806 | ], 807 | "execution_count": null, 808 | "outputs": [] 809 | }, 810 | { 811 | "cell_type": "code", 812 | "metadata": { 813 | "colab_type": "code", 814 | "id": "fUkbqhc_uObw", 815 | "colab": {} 816 | }, 817 | "source": [ 818 | "start_epoch = 0\n", 819 | "if ckpt_manager.latest_checkpoint:\n", 820 | " start_epoch = int(ckpt_manager.latest_checkpoint.split('-')[-1])\n", 821 | " # restoring the latest checkpoint in checkpoint_path\n", 822 | " ckpt.restore(ckpt_manager.latest_checkpoint)" 823 | ], 824 | "execution_count": null, 825 | "outputs": [] 826 | }, 827 | { 828 | "cell_type": "markdown", 829 | "metadata": { 830 | "colab_type": "text", 831 | "id": "PHod7t72RPGn" 832 | }, 833 | "source": [ 834 | "## Training\n", 835 | "\n", 836 | "* You extract the features stored in the respective `.npy` files and then pass those features through the encoder.\n", 837 | "* The encoder output, the hidden state (initialized to 0), and the decoder input (which is the start token) are passed to the decoder.\n", 838 | "* The decoder returns the predictions and the decoder hidden state.\n", 839 | "* The decoder hidden state is then passed back into the model and the predictions are used to calculate the loss.\n", 840 | "* Use teacher forcing to decide the next input to the
decoder.\n", 841 | "* Teacher forcing is the technique where the target word is passed as the next input to the decoder.\n", 842 | "* The final step is to calculate the gradients and apply it to the optimizer and backpropagate.\n" 843 | ] 844 | }, 845 | { 846 | "cell_type": "code", 847 | "metadata": { 848 | "colab_type": "code", 849 | "id": "Vt4WZ5mhJE-E", 850 | "colab": {} 851 | }, 852 | "source": [ 853 | "# adding this in a separate cell because if you run the training cell\n", 854 | "# many times, the loss_plot array will be reset\n", 855 | "loss_plot = []" 856 | ], 857 | "execution_count": null, 858 | "outputs": [] 859 | }, 860 | { 861 | "cell_type": "code", 862 | "metadata": { 863 | "colab_type": "code", 864 | "id": "sqgyz2ANKlpU", 865 | "colab": {} 866 | }, 867 | "source": [ 868 | "@tf.function\n", 869 | "def train_step(img_tensor, target):\n", 870 | " loss = 0\n", 871 | "\n", 872 | " # initializing the hidden state for each batch\n", 873 | " # because the captions are not related from image to image\n", 874 | " hidden = decoder.reset_state(batch_size=target.shape[0])\n", 875 | "\n", 876 | " dec_input = tf.expand_dims([tokenizer.word_index['']] * target.shape[0], 1)\n", 877 | "\n", 878 | " with tf.GradientTape() as tape:\n", 879 | " features = encoder(img_tensor)\n", 880 | "\n", 881 | " for i in range(1, target.shape[1]):\n", 882 | " # passing the features through the decoder\n", 883 | " predictions, hidden, _ = decoder(dec_input, features, hidden)\n", 884 | "\n", 885 | " loss += loss_function(target[:, i], predictions)\n", 886 | "\n", 887 | " # using teacher forcing\n", 888 | " dec_input = tf.expand_dims(target[:, i], 1)\n", 889 | "\n", 890 | " total_loss = (loss / int(target.shape[1]))\n", 891 | "\n", 892 | " trainable_variables = encoder.trainable_variables + decoder.trainable_variables\n", 893 | "\n", 894 | " gradients = tape.gradient(loss, trainable_variables)\n", 895 | "\n", 896 | " optimizer.apply_gradients(zip(gradients, trainable_variables))\n", 897 
| "\n", 898 | " return loss, total_loss" 899 | ], 900 | "execution_count": null, 901 | "outputs": [] 902 | }, 903 | { 904 | "cell_type": "code", 905 | "metadata": { 906 | "colab_type": "code", 907 | "id": "UlA4VIQpRPGo", 908 | "colab": {} 909 | }, 910 | "source": [ 911 | "EPOCHS = 20\n", 912 | "\n", 913 | "for epoch in range(start_epoch, EPOCHS):\n", 914 | " start = time.time()\n", 915 | " total_loss = 0\n", 916 | "\n", 917 | " for (batch, (img_tensor, target)) in enumerate(dataset):\n", 918 | " batch_loss, t_loss = train_step(img_tensor, target)\n", 919 | " total_loss += t_loss\n", 920 | "\n", 921 | " if batch % 100 == 0:\n", 922 | " print ('Epoch {} Batch {} Loss {:.4f}'.format(\n", 923 | " epoch + 1, batch, batch_loss.numpy() / int(target.shape[1])))\n", 924 | " # storing the epoch end loss value to plot later\n", 925 | " loss_plot.append(total_loss / num_steps)\n", 926 | "\n", 927 | " if epoch % 5 == 0:\n", 928 | " ckpt_manager.save()\n", 929 | "\n", 930 | " print ('Epoch {} Loss {:.6f}'.format(epoch + 1,\n", 931 | " total_loss/num_steps))\n", 932 | " print ('Time taken for 1 epoch {} sec\\n'.format(time.time() - start))" 933 | ], 934 | "execution_count": null, 935 | "outputs": [] 936 | }, 937 | { 938 | "cell_type": "code", 939 | "metadata": { 940 | "colab_type": "code", 941 | "id": "1Wm83G-ZBPcC", 942 | "colab": {} 943 | }, 944 | "source": [ 945 | "plt.plot(loss_plot)\n", 946 | "plt.xlabel('Epochs')\n", 947 | "plt.ylabel('Loss')\n", 948 | "plt.title('Loss Plot')\n", 949 | "plt.show()" 950 | ], 951 | "execution_count": null, 952 | "outputs": [] 953 | }, 954 | { 955 | "cell_type": "markdown", 956 | "metadata": { 957 | "colab_type": "text", 958 | "id": "xGvOcLQKghXN" 959 | }, 960 | "source": [ 961 | "## Caption!\n", 962 | "\n", 963 | "* The evaluate function is similar to the training loop, except you don't use teacher forcing here. 
The input to the decoder at each time step is its previous predictions along with the hidden state and the encoder output.\n", 964 | "* Stop predicting when the model predicts the end token.\n", 965 | "* And store the attention weights for every time step." 966 | ] 967 | }, 968 | { 969 | "cell_type": "code", 970 | "metadata": { 971 | "colab_type": "code", 972 | "id": "RCWpDtyNRPGs", 973 | "colab": {} 974 | }, 975 | "source": [ 976 | "def evaluate(image):\n", 977 | " attention_plot = np.zeros((max_length, attention_features_shape))\n", 978 | "\n", 979 | " hidden = decoder.reset_state(batch_size=1)\n", 980 | "\n", 981 | " temp_input = tf.expand_dims(load_image(image)[0], 0)\n", 982 | " img_tensor_val = image_features_extract_model(temp_input)\n", 983 | " img_tensor_val = tf.reshape(img_tensor_val, (img_tensor_val.shape[0], -1, img_tensor_val.shape[3]))\n", 984 | "\n", 985 | " features = encoder(img_tensor_val)\n", 986 | "\n", 987 | " dec_input = tf.expand_dims([tokenizer.word_index['']], 0)\n", 988 | " result = []\n", 989 | "\n", 990 | " for i in range(max_length):\n", 991 | " predictions, hidden, attention_weights = decoder(dec_input, features, hidden)\n", 992 | "\n", 993 | " attention_plot[i] = tf.reshape(attention_weights, (-1, )).numpy()\n", 994 | "\n", 995 | " predicted_id = tf.random.categorical(predictions, 1)[0][0].numpy()\n", 996 | " result.append(tokenizer.index_word[predicted_id])\n", 997 | "\n", 998 | " if tokenizer.index_word[predicted_id] == '':\n", 999 | " return result, attention_plot\n", 1000 | "\n", 1001 | " dec_input = tf.expand_dims([predicted_id], 0)\n", 1002 | "\n", 1003 | " attention_plot = attention_plot[:len(result), :]\n", 1004 | " return result, attention_plot" 1005 | ], 1006 | "execution_count": null, 1007 | "outputs": [] 1008 | }, 1009 | { 1010 | "cell_type": "code", 1011 | "metadata": { 1012 | "colab_type": "code", 1013 | "id": "fD_y7PD6RPGt", 1014 | "colab": {} 1015 | }, 1016 | "source": [ 1017 | "def plot_attention(image, result, 
attention_plot):\n", 1018 | " temp_image = np.array(Image.open(image))\n", 1019 | "\n", 1020 | " fig = plt.figure(figsize=(10, 10))\n", 1021 | "\n", 1022 | " len_result = len(result)\n", 1023 | " for l in range(len_result):\n", 1024 | " temp_att = np.resize(attention_plot[l], (8, 8))\n", 1025 | " ax = fig.add_subplot(len_result//2, len_result//2, l+1)\n", 1026 | " ax.set_title(result[l])\n", 1027 | " img = ax.imshow(temp_image)\n", 1028 | " ax.imshow(temp_att, cmap='gray', alpha=0.6, extent=img.get_extent())\n", 1029 | "\n", 1030 | " plt.tight_layout()\n", 1031 | " plt.show()" 1032 | ], 1033 | "execution_count": null, 1034 | "outputs": [] 1035 | }, 1036 | { 1037 | "cell_type": "code", 1038 | "metadata": { 1039 | "colab_type": "code", 1040 | "id": "7x8RiPHe_4qI", 1041 | "colab": {} 1042 | }, 1043 | "source": [ 1044 | "# captions on the validation set\n", 1045 | "rid = np.random.randint(0, len(img_name_val))\n", 1046 | "image = img_name_val[rid]\n", 1047 | "real_caption = ' '.join([tokenizer.index_word[i] for i in cap_val[rid] if i not in [0]])\n", 1048 | "result, attention_plot = evaluate(image)\n", 1049 | "\n", 1050 | "print ('Real Caption:', real_caption)\n", 1051 | "print ('Prediction Caption:', ' '.join(result))\n", 1052 | "plot_attention(image, result, attention_plot)\n" 1053 | ], 1054 | "execution_count": null, 1055 | "outputs": [] 1056 | }, 1057 | { 1058 | "cell_type": "markdown", 1059 | "metadata": { 1060 | "colab_type": "text", 1061 | "id": "Rprk3HEvZuxb" 1062 | }, 1063 | "source": [ 1064 | "## Try it on your own images\n", 1065 | "For fun, below we've provided a method you can use to caption your own images with the model we've just trained. 
Keep in mind, it was trained on a relatively small amount of data, and your images may be different from the training data (so be prepared for weird results!)\n" 1066 | ] 1067 | }, 1068 | { 1069 | "cell_type": "code", 1070 | "metadata": { 1071 | "colab_type": "code", 1072 | "id": "9Psd1quzaAWg", 1073 | "colab": {} 1074 | }, 1075 | "source": [ 1076 | "image_url = 'https://tensorflow.org/images/surf.jpg'\n", 1077 | "image_extension = image_url[-4:]\n", 1078 | "image_path = tf.keras.utils.get_file('image'+image_extension,\n", 1079 | " origin=image_url)\n", 1080 | "\n", 1081 | "result, attention_plot = evaluate(image_path)\n", 1082 | "print ('Prediction Caption:', ' '.join(result))\n", 1083 | "plot_attention(image_path, result, attention_plot)\n", 1084 | "# opening the image\n", 1085 | "Image.open(image_path)" 1086 | ], 1087 | "execution_count": null, 1088 | "outputs": [] 1089 | }, 1090 | { 1091 | "cell_type": "markdown", 1092 | "metadata": { 1093 | "colab_type": "text", 1094 | "id": "VJZXyJco6uLO" 1095 | }, 1096 | "source": [ 1097 | "# Next steps\n", 1098 | "\n", 1099 | "Congrats! You've just trained an image captioning model with attention. Next, take a look at this example [Neural Machine Translation with Attention](../sequences/nmt_with_attention.ipynb). It uses a similar architecture to translate between Spanish and English sentences. You can also experiment with training the code in this notebook on a different dataset." 
1100 | ] 1101 | } 1102 | ] 1103 | } -------------------------------------------------------------------------------- /PAIGCP_text_cleaning.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "colab": { 6 | "name": "PAIGCP_text_cleaning.ipynb", 7 | "provenance": [], 8 | "authorship_tag": "ABX9TyMSP4WgeV5o3jxbRp687qIZ", 9 | "include_colab_link": true 10 | }, 11 | "kernelspec": { 12 | "name": "python3", 13 | "display_name": "Python 3" 14 | } 15 | }, 16 | "cells": [ 17 | { 18 | "cell_type": "markdown", 19 | "metadata": { 20 | "id": "view-in-github", 21 | "colab_type": "text" 22 | }, 23 | "source": [ 24 | "\"Open" 25 | ] 26 | }, 27 | { 28 | "cell_type": "code", 29 | "metadata": { 30 | "id": "DbkMepZnYVHs", 31 | "colab_type": "code", 32 | "colab": {} 33 | }, 34 | "source": [ 35 | "import urllib.request" 36 | ], 37 | "execution_count": 1, 38 | "outputs": [] 39 | }, 40 | { 41 | "cell_type": "markdown", 42 | "metadata": { 43 | "id": "SLUJ5DYpY_yH", 44 | "colab_type": "text" 45 | }, 46 | "source": [ 47 | "Download Moby Dick from Project Gutenberg" 48 | ] 49 | }, 50 | { 51 | "cell_type": "code", 52 | "metadata": { 53 | "id": "dPyg06asYkY_", 54 | "colab_type": "code", 55 | "colab": {} 56 | }, 57 | "source": [ 58 | "url = \"https://www.gutenberg.org/files/2701/2701-0.txt\"\n", 59 | "file = urllib.request.urlopen(url)" 60 | ], 61 | "execution_count": 2, 62 | "outputs": [] 63 | }, 64 | { 65 | "cell_type": "markdown", 66 | "metadata": { 67 | "id": "XFw0-dPOZDCZ", 68 | "colab_type": "text" 69 | }, 70 | "source": [ 71 | "Load the text" 72 | ] 73 | }, 74 | { 75 | "cell_type": "code", 76 | "metadata": { 77 | "id": "ER0LYVxLYo2k", 78 | "colab_type": "code", 79 | "colab": { 80 | "base_uri": "https://localhost:8080/", 81 | "height": 52 82 | }, 83 | "outputId": "755b6782-95fc-4e1e-cccd-9e0b360f4cfb" 84 | }, 85 | "source": [ 86 | "text = [line.decode('utf-8') for line in file]\n", 87 | "text =
''.join(text)\n", 88 | "text[7600:8000]" 89 | ], 90 | "execution_count": 3, 91 | "outputs": [ 92 | { 93 | "output_type": "execute_result", 94 | "data": { 95 | "application/vnd.google.colaboratory.intrinsic+json": { 96 | "type": "string" 97 | }, 98 | "text/plain": [ 99 | "'ok whatsoever,\\r\\n sacred or profane. Therefore you must not, in every case at least,\\r\\n take the higgledy-piggledy whale statements, however authentic, in\\r\\n these extracts, for veritable gospel cetology. Far from it. As\\r\\n touching the ancient authors generally, as well as the poets here\\r\\n appearing, these extracts are solely valuable or entertaining, as\\r\\n affording a glancing bird’s eye view o'" 100 | ] 101 | }, 102 | "metadata": { 103 | "tags": [] 104 | }, 105 | "execution_count": 3 106 | } 107 | ] 108 | }, 109 | { 110 | "cell_type": "markdown", 111 | "metadata": { 112 | "id": "yN2YwlWcY72h", 113 | "colab_type": "text" 114 | }, 115 | "source": [ 116 | "Tokenize\n" 117 | ] 118 | }, 119 | { 120 | "cell_type": "code", 121 | "metadata": { 122 | "id": "WmXGjk_YY9Wc", 123 | "colab_type": "code", 124 | "colab": { 125 | "base_uri": "https://localhost:8080/", 126 | "height": 391 127 | }, 128 | "outputId": "b0b21983-ad30-49d5-c573-cf29a8412344" 129 | }, 130 | "source": [ 131 | "tokens = text.split()\n", 132 | "tokens[200:222]" 133 | ], 134 | "execution_count": 5, 135 | "outputs": [ 136 | { 137 | "output_type": "execute_result", 138 | "data": { 139 | "text/plain": [ 140 | "['Merry',\n", 141 | " 'Christmas.',\n", 142 | " 'CHAPTER',\n", 143 | " '23.',\n", 144 | " 'The',\n", 145 | " 'Lee',\n", 146 | " 'Shore.',\n", 147 | " 'CHAPTER',\n", 148 | " '24.',\n", 149 | " 'The',\n", 150 | " 'Advocate.',\n", 151 | " 'CHAPTER',\n", 152 | " '25.',\n", 153 | " 'Postscript.',\n", 154 | " 'CHAPTER',\n", 155 | " '26.',\n", 156 | " 'Knights',\n", 157 | " 'and',\n", 158 | " 'Squires.',\n", 159 | " 'CHAPTER',\n", 160 | " '27.',\n", 161 | " 'Knights']" 162 | ] 163 | }, 164 | "metadata": { 165 | "tags": [] 
166 | }, 167 | "execution_count": 5 168 | } 169 | ] 170 | }, 171 | { 172 | "cell_type": "markdown", 173 | "metadata": { 174 | "id": "GHvzrkUsZUs8", 175 | "colab_type": "text" 176 | }, 177 | "source": [ 178 | "Lowercase" 179 | ] 180 | }, 181 | { 182 | "cell_type": "code", 183 | "metadata": { 184 | "id": "bb2pv0mLZLxV", 185 | "colab_type": "code", 186 | "colab": { 187 | "base_uri": "https://localhost:8080/", 188 | "height": 391 189 | }, 190 | "outputId": "b8d5ddd8-06af-4abe-9d8f-1823752c145d" 191 | }, 192 | "source": [ 193 | "tokens = text.lower().split()\n", 194 | "tokens[200:222]" 195 | ], 196 | "execution_count": 6, 197 | "outputs": [ 198 | { 199 | "output_type": "execute_result", 200 | "data": { 201 | "text/plain": [ 202 | "['merry',\n", 203 | " 'christmas.',\n", 204 | " 'chapter',\n", 205 | " '23.',\n", 206 | " 'the',\n", 207 | " 'lee',\n", 208 | " 'shore.',\n", 209 | " 'chapter',\n", 210 | " '24.',\n", 211 | " 'the',\n", 212 | " 'advocate.',\n", 213 | " 'chapter',\n", 214 | " '25.',\n", 215 | " 'postscript.',\n", 216 | " 'chapter',\n", 217 | " '26.',\n", 218 | " 'knights',\n", 219 | " 'and',\n", 220 | " 'squires.',\n", 221 | " 'chapter',\n", 222 | " '27.',\n", 223 | " 'knights']" 224 | ] 225 | }, 226 | "metadata": { 227 | "tags": [] 228 | }, 229 | "execution_count": 6 230 | } 231 | ] 232 | }, 233 | { 234 | "cell_type": "markdown", 235 | "metadata": { 236 | "id": "zA8tdJU0ZXZC", 237 | "colab_type": "text" 238 | }, 239 | "source": [ 240 | "Punctuation" 241 | ] 242 | }, 243 | { 244 | "cell_type": "code", 245 | "metadata": { 246 | "id": "_qQo2cT7ZZg5", 247 | "colab_type": "code", 248 | "colab": { 249 | "base_uri": "https://localhost:8080/", 250 | "height": 391 251 | }, 252 | "outputId": "ee052edb-ccd8-4f12-f0c3-96380698d779" 253 | }, 254 | "source": [ 255 | "import string\n", 256 | "table = str.maketrans('', '', string.punctuation)\n", 257 | "\n", 258 | "tokens = [w.translate(table) for w in tokens]\n", 259 | "\n", 260 | "tokens[200:222]" 261 | ], 262 | 
"execution_count": 7, 263 | "outputs": [ 264 | { 265 | "output_type": "execute_result", 266 | "data": { 267 | "text/plain": [ 268 | "['merry',\n", 269 | " 'christmas',\n", 270 | " 'chapter',\n", 271 | " '23',\n", 272 | " 'the',\n", 273 | " 'lee',\n", 274 | " 'shore',\n", 275 | " 'chapter',\n", 276 | " '24',\n", 277 | " 'the',\n", 278 | " 'advocate',\n", 279 | " 'chapter',\n", 280 | " '25',\n", 281 | " 'postscript',\n", 282 | " 'chapter',\n", 283 | " '26',\n", 284 | " 'knights',\n", 285 | " 'and',\n", 286 | " 'squires',\n", 287 | " 'chapter',\n", 288 | " '27',\n", 289 | " 'knights']" 290 | ] 291 | }, 292 | "metadata": { 293 | "tags": [] 294 | }, 295 | "execution_count": 7 296 | } 297 | ] 298 | }, 299 | { 300 | "cell_type": "markdown", 301 | "metadata": { 302 | "id": "2EHfQLMOZoOt", 303 | "colab_type": "text" 304 | }, 305 | "source": [ 306 | "Alpha only" 307 | ] 308 | }, 309 | { 310 | "cell_type": "code", 311 | "metadata": { 312 | "id": "Npw5JfpeZp-2", 313 | "colab_type": "code", 314 | "colab": { 315 | "base_uri": "https://localhost:8080/", 316 | "height": 391 317 | }, 318 | "outputId": "b6d7c335-ccea-4f07-dbc1-8d08486ac027" 319 | }, 320 | "source": [ 321 | "tokens = [word for word in tokens if word.isalpha()]\n", 322 | "tokens[200:222]" 323 | ], 324 | "execution_count": 8, 325 | "outputs": [ 326 | { 327 | "output_type": "execute_result", 328 | "data": { 329 | "text/plain": [ 330 | "['queen',\n", 331 | " 'mab',\n", 332 | " 'chapter',\n", 333 | " 'cetology',\n", 334 | " 'chapter',\n", 335 | " 'the',\n", 336 | " 'specksnyder',\n", 337 | " 'chapter',\n", 338 | " 'the',\n", 339 | " 'cabintable',\n", 340 | " 'chapter',\n", 341 | " 'the',\n", 342 | " 'masthead',\n", 343 | " 'chapter',\n", 344 | " 'the',\n", 345 | " 'quarterdeck',\n", 346 | " 'chapter',\n", 347 | " 'sunset',\n", 348 | " 'chapter',\n", 349 | " 'dusk',\n", 350 | " 'chapter',\n", 351 | " 'first']" 352 | ] 353 | }, 354 | "metadata": { 355 | "tags": [] 356 | }, 357 | "execution_count": 8 358 | } 359 | ] 360 | } 
361 | ] 362 | } -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # PAIGCP 2 | Course content for Practical AI on the Google Cloud Platform 3 | --------------------------------------------------------------------------------