├── README.md ├── LICENSE └── Roberta_NER_WithText.ipynb /README.md: -------------------------------------------------------------------------------- 1 | # RoBERTa-NER -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 
34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. 
Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. 
In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. 
We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /Roberta_NER_WithText.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "colab": { 6 | "provenance": [], 7 | "gpuType": "T4" 8 | }, 9 | "kernelspec": { 10 | "name": "python3", 11 | "display_name": "Python 3" 12 | }, 13 | "language_info": { 14 | "name": "python" 15 | }, 16 | "accelerator": "GPU", 17 | "widgets": { 18 | "application/vnd.jupyter.widget-state+json": { 19 | "167150fc901e444ebf8bcf7143a257b3": { 20 | "model_module": "@jupyter-widgets/controls", 21 | "model_name": "HBoxModel", 22 | "model_module_version": "1.5.0", 23 | "state": { 24 | "_dom_classes": [], 25 | "_model_module": "@jupyter-widgets/controls", 26 | "_model_module_version": "1.5.0", 27 | "_model_name": "HBoxModel", 28 | "_view_count": null, 29 | "_view_module": "@jupyter-widgets/controls", 30 | "_view_module_version": "1.5.0", 31 | "_view_name": "HBoxView", 32 | "box_style": "", 33 | "children": [ 34 | "IPY_MODEL_8be485f189ce4b7e950c57334b989601", 35 | 
"IPY_MODEL_fc5cb6b2c47247e8a59769b6e583549b", 36 | "IPY_MODEL_b397c4608f98450fbcb81ad28e427c04" 37 | ], 38 | "layout": "IPY_MODEL_477973dccac1415baff1c06e0d3dc836" 39 | } 40 | }, 41 | "8be485f189ce4b7e950c57334b989601": { 42 | "model_module": "@jupyter-widgets/controls", 43 | "model_name": "HTMLModel", 44 | "model_module_version": "1.5.0", 45 | "state": { 46 | "_dom_classes": [], 47 | "_model_module": "@jupyter-widgets/controls", 48 | "_model_module_version": "1.5.0", 49 | "_model_name": "HTMLModel", 50 | "_view_count": null, 51 | "_view_module": "@jupyter-widgets/controls", 52 | "_view_module_version": "1.5.0", 53 | "_view_name": "HTMLView", 54 | "description": "", 55 | "description_tooltip": null, 56 | "layout": "IPY_MODEL_62ec567181fa473b9267cbfa34e5bb8b", 57 | "placeholder": "​", 58 | "style": "IPY_MODEL_d4d71ee95b0847bf95f040c640457f2f", 59 | "value": "100%" 60 | } 61 | }, 62 | "fc5cb6b2c47247e8a59769b6e583549b": { 63 | "model_module": "@jupyter-widgets/controls", 64 | "model_name": "FloatProgressModel", 65 | "model_module_version": "1.5.0", 66 | "state": { 67 | "_dom_classes": [], 68 | "_model_module": "@jupyter-widgets/controls", 69 | "_model_module_version": "1.5.0", 70 | "_model_name": "FloatProgressModel", 71 | "_view_count": null, 72 | "_view_module": "@jupyter-widgets/controls", 73 | "_view_module_version": "1.5.0", 74 | "_view_name": "ProgressView", 75 | "bar_style": "success", 76 | "description": "", 77 | "description_tooltip": null, 78 | "layout": "IPY_MODEL_10c6b510aea040939d357e09b82ed8dd", 79 | "max": 3, 80 | "min": 0, 81 | "orientation": "horizontal", 82 | "style": "IPY_MODEL_7554a5b794844f5482fab12edb702ee2", 83 | "value": 3 84 | } 85 | }, 86 | "b397c4608f98450fbcb81ad28e427c04": { 87 | "model_module": "@jupyter-widgets/controls", 88 | "model_name": "HTMLModel", 89 | "model_module_version": "1.5.0", 90 | "state": { 91 | "_dom_classes": [], 92 | "_model_module": "@jupyter-widgets/controls", 93 | "_model_module_version": "1.5.0", 94 | 
"_model_name": "HTMLModel", 95 | "_view_count": null, 96 | "_view_module": "@jupyter-widgets/controls", 97 | "_view_module_version": "1.5.0", 98 | "_view_name": "HTMLView", 99 | "description": "", 100 | "description_tooltip": null, 101 | "layout": "IPY_MODEL_31f1dab6c44a4c80b48b462b436a80e0", 102 | "placeholder": "​", 103 | "style": "IPY_MODEL_3f90e5333a8e45428725fdc2c0a56f70", 104 | "value": " 3/3 [00:00<00:00, 114.10it/s]" 105 | } 106 | }, 107 | "477973dccac1415baff1c06e0d3dc836": { 108 | "model_module": "@jupyter-widgets/base", 109 | "model_name": "LayoutModel", 110 | "model_module_version": "1.2.0", 111 | "state": { 112 | "_model_module": "@jupyter-widgets/base", 113 | "_model_module_version": "1.2.0", 114 | "_model_name": "LayoutModel", 115 | "_view_count": null, 116 | "_view_module": "@jupyter-widgets/base", 117 | "_view_module_version": "1.2.0", 118 | "_view_name": "LayoutView", 119 | "align_content": null, 120 | "align_items": null, 121 | "align_self": null, 122 | "border": null, 123 | "bottom": null, 124 | "display": null, 125 | "flex": null, 126 | "flex_flow": null, 127 | "grid_area": null, 128 | "grid_auto_columns": null, 129 | "grid_auto_flow": null, 130 | "grid_auto_rows": null, 131 | "grid_column": null, 132 | "grid_gap": null, 133 | "grid_row": null, 134 | "grid_template_areas": null, 135 | "grid_template_columns": null, 136 | "grid_template_rows": null, 137 | "height": null, 138 | "justify_content": null, 139 | "justify_items": null, 140 | "left": null, 141 | "margin": null, 142 | "max_height": null, 143 | "max_width": null, 144 | "min_height": null, 145 | "min_width": null, 146 | "object_fit": null, 147 | "object_position": null, 148 | "order": null, 149 | "overflow": null, 150 | "overflow_x": null, 151 | "overflow_y": null, 152 | "padding": null, 153 | "right": null, 154 | "top": null, 155 | "visibility": null, 156 | "width": null 157 | } 158 | }, 159 | "62ec567181fa473b9267cbfa34e5bb8b": { 160 | "model_module": "@jupyter-widgets/base", 161 | 
"model_name": "LayoutModel", 162 | "model_module_version": "1.2.0", 163 | "state": { 164 | "_model_module": "@jupyter-widgets/base", 165 | "_model_module_version": "1.2.0", 166 | "_model_name": "LayoutModel", 167 | "_view_count": null, 168 | "_view_module": "@jupyter-widgets/base", 169 | "_view_module_version": "1.2.0", 170 | "_view_name": "LayoutView", 171 | "align_content": null, 172 | "align_items": null, 173 | "align_self": null, 174 | "border": null, 175 | "bottom": null, 176 | "display": null, 177 | "flex": null, 178 | "flex_flow": null, 179 | "grid_area": null, 180 | "grid_auto_columns": null, 181 | "grid_auto_flow": null, 182 | "grid_auto_rows": null, 183 | "grid_column": null, 184 | "grid_gap": null, 185 | "grid_row": null, 186 | "grid_template_areas": null, 187 | "grid_template_columns": null, 188 | "grid_template_rows": null, 189 | "height": null, 190 | "justify_content": null, 191 | "justify_items": null, 192 | "left": null, 193 | "margin": null, 194 | "max_height": null, 195 | "max_width": null, 196 | "min_height": null, 197 | "min_width": null, 198 | "object_fit": null, 199 | "object_position": null, 200 | "order": null, 201 | "overflow": null, 202 | "overflow_x": null, 203 | "overflow_y": null, 204 | "padding": null, 205 | "right": null, 206 | "top": null, 207 | "visibility": null, 208 | "width": null 209 | } 210 | }, 211 | "d4d71ee95b0847bf95f040c640457f2f": { 212 | "model_module": "@jupyter-widgets/controls", 213 | "model_name": "DescriptionStyleModel", 214 | "model_module_version": "1.5.0", 215 | "state": { 216 | "_model_module": "@jupyter-widgets/controls", 217 | "_model_module_version": "1.5.0", 218 | "_model_name": "DescriptionStyleModel", 219 | "_view_count": null, 220 | "_view_module": "@jupyter-widgets/base", 221 | "_view_module_version": "1.2.0", 222 | "_view_name": "StyleView", 223 | "description_width": "" 224 | } 225 | }, 226 | "10c6b510aea040939d357e09b82ed8dd": { 227 | "model_module": "@jupyter-widgets/base", 228 | "model_name": 
"LayoutModel", 229 | "model_module_version": "1.2.0", 230 | "state": { 231 | "_model_module": "@jupyter-widgets/base", 232 | "_model_module_version": "1.2.0", 233 | "_model_name": "LayoutModel", 234 | "_view_count": null, 235 | "_view_module": "@jupyter-widgets/base", 236 | "_view_module_version": "1.2.0", 237 | "_view_name": "LayoutView", 238 | "align_content": null, 239 | "align_items": null, 240 | "align_self": null, 241 | "border": null, 242 | "bottom": null, 243 | "display": null, 244 | "flex": null, 245 | "flex_flow": null, 246 | "grid_area": null, 247 | "grid_auto_columns": null, 248 | "grid_auto_flow": null, 249 | "grid_auto_rows": null, 250 | "grid_column": null, 251 | "grid_gap": null, 252 | "grid_row": null, 253 | "grid_template_areas": null, 254 | "grid_template_columns": null, 255 | "grid_template_rows": null, 256 | "height": null, 257 | "justify_content": null, 258 | "justify_items": null, 259 | "left": null, 260 | "margin": null, 261 | "max_height": null, 262 | "max_width": null, 263 | "min_height": null, 264 | "min_width": null, 265 | "object_fit": null, 266 | "object_position": null, 267 | "order": null, 268 | "overflow": null, 269 | "overflow_x": null, 270 | "overflow_y": null, 271 | "padding": null, 272 | "right": null, 273 | "top": null, 274 | "visibility": null, 275 | "width": null 276 | } 277 | }, 278 | "7554a5b794844f5482fab12edb702ee2": { 279 | "model_module": "@jupyter-widgets/controls", 280 | "model_name": "ProgressStyleModel", 281 | "model_module_version": "1.5.0", 282 | "state": { 283 | "_model_module": "@jupyter-widgets/controls", 284 | "_model_module_version": "1.5.0", 285 | "_model_name": "ProgressStyleModel", 286 | "_view_count": null, 287 | "_view_module": "@jupyter-widgets/base", 288 | "_view_module_version": "1.2.0", 289 | "_view_name": "StyleView", 290 | "bar_color": null, 291 | "description_width": "" 292 | } 293 | }, 294 | "31f1dab6c44a4c80b48b462b436a80e0": { 295 | "model_module": "@jupyter-widgets/base", 296 | "model_name": 
"LayoutModel", 297 | "model_module_version": "1.2.0", 298 | "state": { 299 | "_model_module": "@jupyter-widgets/base", 300 | "_model_module_version": "1.2.0", 301 | "_model_name": "LayoutModel", 302 | "_view_count": null, 303 | "_view_module": "@jupyter-widgets/base", 304 | "_view_module_version": "1.2.0", 305 | "_view_name": "LayoutView", 306 | "align_content": null, 307 | "align_items": null, 308 | "align_self": null, 309 | "border": null, 310 | "bottom": null, 311 | "display": null, 312 | "flex": null, 313 | "flex_flow": null, 314 | "grid_area": null, 315 | "grid_auto_columns": null, 316 | "grid_auto_flow": null, 317 | "grid_auto_rows": null, 318 | "grid_column": null, 319 | "grid_gap": null, 320 | "grid_row": null, 321 | "grid_template_areas": null, 322 | "grid_template_columns": null, 323 | "grid_template_rows": null, 324 | "height": null, 325 | "justify_content": null, 326 | "justify_items": null, 327 | "left": null, 328 | "margin": null, 329 | "max_height": null, 330 | "max_width": null, 331 | "min_height": null, 332 | "min_width": null, 333 | "object_fit": null, 334 | "object_position": null, 335 | "order": null, 336 | "overflow": null, 337 | "overflow_x": null, 338 | "overflow_y": null, 339 | "padding": null, 340 | "right": null, 341 | "top": null, 342 | "visibility": null, 343 | "width": null 344 | } 345 | }, 346 | "3f90e5333a8e45428725fdc2c0a56f70": { 347 | "model_module": "@jupyter-widgets/controls", 348 | "model_name": "DescriptionStyleModel", 349 | "model_module_version": "1.5.0", 350 | "state": { 351 | "_model_module": "@jupyter-widgets/controls", 352 | "_model_module_version": "1.5.0", 353 | "_model_name": "DescriptionStyleModel", 354 | "_view_count": null, 355 | "_view_module": "@jupyter-widgets/base", 356 | "_view_module_version": "1.2.0", 357 | "_view_name": "StyleView", 358 | "description_width": "" 359 | } 360 | } 361 | } 362 | } 363 | }, 364 | "cells": [ 365 | { 366 | "cell_type": "markdown", 367 | "source": [ 368 | "# [KerasNLP] Named Entity 
Recognition using RoBERTa\n", 369 | "\n", 370 | "**Author:** [Usha Rengaraju](https://www.linkedin.com/in/usha-rengaraju-b570b7a2/)<br>\n", 371 | "**Date created:** 2023/07/10<br>\n", 372 | "**Last modified:** 2023/07/10<br>
\n", 373 | "**Description:** Named Entity Recognition using pretrained RoBERTa\n" 374 | ], 375 | "metadata": { 376 | "id": "EKRr1Vkvvcar" 377 | } 378 | }, 379 | { 380 | "cell_type": "markdown", 381 | "source": [ 382 | "## Overview\n", 383 | "\n", 384 | "Named entity recognition (NER) is an NLP task that extracts information from text. NER detects and categorizes important information in text known as named entities.\n", 385 | "\n", 386 | "KerasNLP has a variety of pretrained models available. In this guide we create the whole NER pipeline using the pretrained Roberta Backbone.\n" 387 | ], 388 | "metadata": { 389 | "id": "pcanbuwJ7PUX" 390 | } 391 | }, 392 | { 393 | "cell_type": "markdown", 394 | "source": [ 395 | "## Imports & setup\n", 396 | "\n", 397 | "This tutorial requires you to have KeraNLP installed:\n", 398 | "\n", 399 | "```shell\n", 400 | "pip install keras-nlp\n", 401 | "```\n", 402 | "\n", 403 | "We begin by importing all required packages:" 404 | ], 405 | "metadata": { 406 | "id": "DmC_kCnI7VPq" 407 | } 408 | }, 409 | { 410 | "cell_type": "code", 411 | "execution_count": null, 412 | "metadata": { 413 | "id": "4kbJNQkGYVUU", 414 | "colab": { 415 | "base_uri": "https://localhost:8080/" 416 | }, 417 | "outputId": "72e1e296-a133-4095-bf28-55e5d8e27a52" 418 | }, 419 | "outputs": [ 420 | { 421 | "output_type": "stream", 422 | "name": "stdout", 423 | "text": [ 424 | "\u001b[?25l \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m0.0/486.2 kB\u001b[0m \u001b[31m?\u001b[0m eta \u001b[36m-:--:--\u001b[0m\r\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m486.2/486.2 kB\u001b[0m \u001b[31m29.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", 425 | "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m110.5/110.5 kB\u001b[0m \u001b[31m13.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", 426 | "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m212.5/212.5 
kB\u001b[0m \u001b[31m20.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", 427 | "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m134.3/134.3 kB\u001b[0m \u001b[31m14.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", 428 | "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m268.8/268.8 kB\u001b[0m \u001b[31m30.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", 429 | "\u001b[?25h--2023-07-08 13:24:16-- https://raw.githubusercontent.com/sighsmile/conlleval/master/conlleval.py\n", 430 | "Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...\n", 431 | "Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.\n", 432 | "HTTP request sent, awaiting response... 200 OK\n", 433 | "Length: 7502 (7.3K) [text/plain]\n", 434 | "Saving to: ‘conlleval.py’\n", 435 | "\n", 436 | "conlleval.py 100%[===================>] 7.33K --.-KB/s in 0s \n", 437 | "\n", 438 | "2023-07-08 13:24:17 (99.1 MB/s) - ‘conlleval.py’ saved [7502/7502]\n", 439 | "\n" 440 | ] 441 | } 442 | ], 443 | "source": [ 444 | "!pip3 install -q datasets\n", 445 | "!wget https://raw.githubusercontent.com/sighsmile/conlleval/master/conlleval.py" 446 | ] 447 | }, 448 | { 449 | "cell_type": "code", 450 | "execution_count": null, 451 | "metadata": { 452 | "id": "LiNd4e7hYVUV" 453 | }, 454 | "outputs": [], 455 | "source": [ 456 | "import os\n", 457 | "import numpy as np\n", 458 | "import tensorflow as tf\n", 459 | "from tensorflow import keras\n", 460 | "from tensorflow.keras import layers\n", 461 | "from datasets import load_dataset\n", 462 | "from collections import Counter\n", 463 | "from conlleval import evaluate\n", 464 | "import keras_nlp" 465 | ] 466 | }, 467 | { 468 | "cell_type": "markdown", 469 | "source": [ 470 | "## Data loading\n", 471 | "\n", 472 | "This guide uses the\n", 473 | "[Conll 2003 
dataset](https://huggingface.co/datasets/conll2003)\n", 474 | "for demonstration purposes.\n", 475 | "\n", 476 | "To get started, we load the dataset with the Hugging Face `datasets` library:" 477 | ], 478 | "metadata": { 479 | "id": "r8CwaHyl_8oX" 480 | } 481 | }, 482 | { 483 | "cell_type": "code", 484 | "source": [ 485 | "conll_data = load_dataset(\"conll2003\")" 486 | ], 487 | "metadata": { 488 | "colab": { 489 | "base_uri": "https://localhost:8080/", 490 | "height": 86, 491 | "referenced_widgets": [ 492 | "167150fc901e444ebf8bcf7143a257b3", 493 | "8be485f189ce4b7e950c57334b989601", 494 | "fc5cb6b2c47247e8a59769b6e583549b", 495 | "b397c4608f98450fbcb81ad28e427c04", 496 | "477973dccac1415baff1c06e0d3dc836", 497 | "62ec567181fa473b9267cbfa34e5bb8b", 498 | "d4d71ee95b0847bf95f040c640457f2f", 499 | "10c6b510aea040939d357e09b82ed8dd", 500 | "7554a5b794844f5482fab12edb702ee2", 501 | "31f1dab6c44a4c80b48b462b436a80e0", 502 | "3f90e5333a8e45428725fdc2c0a56f70" 503 | ] 504 | }, 505 | "id": "7V_FhoYPe-EN", 506 | "outputId": "3cac9420-1ae0-4a9a-e365-ff90fd671384" 507 | }, 508 | "execution_count": null, 509 | "outputs": [ 510 | { 511 | "output_type": "stream", 512 | "name": "stderr", 513 | "text": [ 514 | "WARNING:datasets.builder:Found cached dataset conll2003 (/root/.cache/huggingface/datasets/conll2003/conll2003/1.0.0/9a4d16a94f8674ba3466315300359b0acd891b68b6c8743ddf60b9c702adce98)\n" 515 | ] 516 | }, 517 | { 518 | "output_type": "display_data", 519 | "data": { 520 | "text/plain": [ 521 | " 0%| | 0/3 [00:00<?, ?it/s]" 522 | ], 523 | "application/vnd.jupyter.widget-view+json": { 524 | "version_major": 2, 525 | "version_minor": 0, 526 | "model_id": "167150fc901e444ebf8bcf7143a257b3" 527 | } 528 | }, 529 | "metadata": {} 530 | } 531 | ] 532 | }, 533 | { 534 | "cell_type": "code", 535 | "source": [ 536 | "def export_to_file(export_file_path, data):\n", 537 | " with open(export_file_path, \"w\") as f:\n", 538 | " for record in data:\n", 539 | " ner_tags = record[\"ner_tags\"]\n", 540 | " tokens = record[\"tokens\"]\n", 541 | " if len(tokens) > 0:\n", 542 | " f.write(\n", 543 | " str(len(tokens))\n", 544 | " + \"\\t\"\n", 545 | " + \"\\t\".join(tokens)\n", 546 | " + \"\\t\"\n", 547 | " + \"\\t\".join(map(str, ner_tags))\n", 548 | " + \"\\n\"\n", 549 | " )\n", 550 | "\n", 551 | "\n", 552 | "os.mkdir(\"data\")\n", 553 | "export_to_file(\"./data/conll_train.txt\", conll_data[\"train\"])\n", 554 | "export_to_file(\"./data/conll_val.txt\", conll_data[\"validation\"])" 555 | ], 556 | "metadata": { 557 | "id": "B4WjrivLfAxc" 558 | }, 
559 | "execution_count": null, 560 | "outputs": [] 561 | }, 562 | { 563 | "cell_type": "markdown", 564 | "source": [ 565 | "Generating the entities and tags mapping" 566 | ], 567 | "metadata": { 568 | "id": "Xo85Q67fh1b7" 569 | } 570 | }, 571 | { 572 | "cell_type": "code", 573 | "source": [ 574 | "def make_tag_lookup_table():\n", 575 | " iob_labels = [\"B\", \"I\"]\n", 576 | " ner_labels = [\"PER\", \"ORG\", \"LOC\", \"MISC\"]\n", 577 | " all_labels = [(label1, label2) for label2 in ner_labels for label1 in iob_labels]\n", 578 | " all_labels = [\"-\".join([a, b]) for a, b in all_labels]\n", 579 | " all_labels = [\"[PAD]\", \"O\"] + all_labels\n", 580 | " return dict(zip(range(0, len(all_labels) + 1), all_labels))\n", 581 | "\n", 582 | "\n", 583 | "mapping = make_tag_lookup_table()\n", 584 | "print(mapping)" 585 | ], 586 | "metadata": { 587 | "colab": { 588 | "base_uri": "https://localhost:8080/" 589 | }, 590 | "id": "k668ZOSrfDFU", 591 | "outputId": "1d8d15ed-3cd6-49e5-b2f4-1815795466d9" 592 | }, 593 | "execution_count": null, 594 | "outputs": [ 595 | { 596 | "output_type": "stream", 597 | "name": "stdout", 598 | "text": [ 599 | "{0: '[PAD]', 1: 'O', 2: 'B-PER', 3: 'I-PER', 4: 'B-ORG', 5: 'I-ORG', 6: 'B-LOC', 7: 'I-LOC', 8: 'B-MISC', 9: 'I-MISC'}\n" 600 | ] 601 | } 602 | ] 603 | }, 604 | { 605 | "cell_type": "code", 606 | "source": [ 607 | "all_tokens = sum(conll_data[\"train\"][\"tokens\"], [])\n", 608 | "all_tokens_array = np.array(list(map(str.lower, all_tokens)))\n", 609 | "\n", 610 | "counter = Counter(all_tokens_array)\n", 611 | "print(len(counter))\n", 612 | "\n", 613 | "num_tags = len(mapping)\n", 614 | "vocab_size = 20000\n", 615 | "vocabulary = [token for token, count in counter.most_common(vocab_size - 2)]\n", 616 | "\n", 617 | "lookup_layer = keras.layers.StringLookup(\n", 618 | " vocabulary=vocabulary\n", 619 | ")" 620 | ], 621 | "metadata": { 622 | "colab": { 623 | "base_uri": "https://localhost:8080/" 624 | }, 625 | "id": "o11ECz4kfFPk", 626 | 
"outputId": "46b66db1-da41-46a4-99b3-ab356f3865ca" 627 | }, 628 | "execution_count": null, 629 | "outputs": [ 630 | { 631 | "output_type": "stream", 632 | "name": "stdout", 633 | "text": [ 634 | "21009\n" 635 | ] 636 | } 637 | ] 638 | }, 639 | { 640 | "cell_type": "code", 641 | "source": [ 642 | "train_data = tf.data.TextLineDataset(\"./data/conll_train.txt\")\n", 643 | "val_data = tf.data.TextLineDataset(\"./data/conll_val.txt\")" 644 | ], 645 | "metadata": { 646 | "id": "62ohE_h4fIJt" 647 | }, 648 | "execution_count": null, 649 | "outputs": [] 650 | }, 651 | { 652 | "cell_type": "code", 653 | "source": [ 654 | "print(list(train_data.take(1).as_numpy_iterator()))" 655 | ], 656 | "metadata": { 657 | "colab": { 658 | "base_uri": "https://localhost:8080/" 659 | }, 660 | "id": "8C714wpJfMhW", 661 | "outputId": "abf9c329-92b6-4dbd-b5d9-cb847255f0a4" 662 | }, 663 | "execution_count": null, 664 | "outputs": [ 665 | { 666 | "output_type": "stream", 667 | "name": "stdout", 668 | "text": [ 669 | "[b'9\\tEU\\trejects\\tGerman\\tcall\\tto\\tboycott\\tBritish\\tlamb\\t.\\t3\\t0\\t7\\t0\\t0\\t0\\t7\\t0\\t0']\n" 670 | ] 671 | } 672 | ] 673 | }, 674 | { 675 | "cell_type": "markdown", 676 | "source": [ 677 | "## Preprocessing Dataset\n", 678 | "\n", 679 | "For tokenizing the text we use the tensorflow text `Fastwordpiecetokenizer` and create the data generator for training the model.\n" 680 | ], 681 | "metadata": { 682 | "id": "gZAE1MFliivB" 683 | } 684 | }, 685 | { 686 | "cell_type": "code", 687 | "source": [ 688 | "import tensorflow_text as tf_text\n", 689 | "tok = keras_nlp.models.BertTokenizer.from_preset(\"bert_base_en_uncased\", lowercase=True)\n", 690 | "tokenizer = tf_text.FastWordpieceTokenizer(tok.vocabulary)" 691 | ], 692 | "metadata": { 693 | "id": "Kltg-IYby3fV" 694 | }, 695 | "execution_count": null, 696 | "outputs": [] 697 | }, 698 | { 699 | "cell_type": "code", 700 | "source": [ 701 | "\n", 702 | "def map_record_to_training_data(record):\n", 703 | " record = 
tf.strings.split(record, sep=\"\\t\")\n", 704 | " length = tf.strings.to_number(record[0], out_type=tf.int32)\n", 705 | " # Re-join the words into one sentence and wordpiece-tokenize it.\n", 710 | " tokens = tf.strings.reduce_join(record[1 : length + 1], separator=' ')\n", 711 | " tokens = tokenizer.tokenize_with_offsets(tokens)[0]\n", 712 | " tags = record[length + 1 :]\n", 713 | " tags = tf.strings.to_number(tags, out_type=tf.int64)\n", 714 | " tags += 1 # shift labels so that index 0 is reserved for [PAD]\n", 715 | " return (tokens, tags)\n", 716 | "\n", 722 | "train_dataset = train_data.map(map_record_to_training_data)\n", 726 | "val_dataset = val_data.map(map_record_to_training_data)\n" 730 | ], 731 | "metadata": { 732 | "id": "Rp7HnnSPfOE1" 733 | }, 734 | "execution_count": null, 735 | "outputs": [] 736 | }, 737 | { 738 | "cell_type": "code", 739 | "source": [ 740 | "x_train = []\n", 741 | "y_train = []\n", 742 | "# Keep only examples whose wordpiece count matches the tag count,\n", 743 | "# and one-hot encode the tags.\n", 744 | "for x, y in train_dataset:\n", 745 | " if x.shape == y.shape:\n", 746 | " x_train.append(x)\n", 747 | " y_oh = []\n", 748 | " for tag in y:\n", 749 | " t = [0] * num_tags\n", 750 | " t[tag] = 1\n", 751 | " y_oh.append(t)\n", 752 | " y_train.append(y_oh)\n", 753 | "len(x_train)" 754 | ], 755 | "metadata": { 756 | "colab": { 757 | "base_uri": "https://localhost:8080/" 758 | }, 759 | "id": "Gw_I4pqZ-I_M", 760 | "outputId": "38f6d37e-abc0-4217-8c50-a4e16fc01e66" 761 | }, 762 | "execution_count": null, 763 | 
"outputs": [ 764 | { 765 | "output_type": "execute_result", 766 | "data": { 767 | "text/plain": [ 768 | "5416" 769 | ] 770 | }, 771 | "metadata": {}, 772 | "execution_count": 32 773 | } 774 | ] 775 | }, 776 | { 777 | "cell_type": "code", 778 | "source": [ 779 | "x_val = []\n", 780 | "y_val = []\n", 781 | "# Same filtering and one-hot encoding for the validation split.\n", 783 | "for x, y in val_dataset:\n", 784 | " if x.shape == y.shape:\n", 785 | " x_val.append(x)\n", 786 | " y_oh = []\n", 787 | " for tag in y:\n", 788 | " t = [0] * num_tags\n", 789 | " t[tag] = 1\n", 790 | " y_oh.append(t)\n", 791 | " y_val.append(y_oh)\n", 792 | "len(x_val)" 793 | ], 794 | "metadata": { 795 | "colab": { 796 | "base_uri": "https://localhost:8080/" 797 | }, 798 | "id": "gq260EwR8f6K", 799 | "outputId": "621b44f7-fc66-47db-8d34-d35f29155403" 800 | }, 801 | "execution_count": null, 802 | "outputs": [ 803 | { 804 | "output_type": "execute_result", 805 | "data": { 806 | "text/plain": [ 807 | "1205" 808 | ] 809 | }, 810 | "metadata": {}, 811 | "execution_count": 34 812 | } 813 | ] 814 | }, 815 | { 816 | "cell_type": "markdown", 817 | "source": [ 818 | "## Model Building\n", 819 | "\n", 820 | "For this pipeline we define a `CustomNonPaddingTokenLoss` that excludes `[PAD]` positions from the loss, and then build the NER model. The backbone is the pretrained KerasNLP RoBERTa model (`roberta_base_en`, the base configuration), topped with a Dense classification head that predicts an entity tag for each token."
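, "\n", "As a toy sketch of the masking idea (illustrative only, not executed in this notebook; it assumes one-hot labels where index 0 is the `[PAD]` tag, as in the mapping above):\n", "\n", "```python\n", "# One sequence of three tokens over three tags; the last token is [PAD].\n", "y_true = tf.constant([[[0., 1., 0.], [0., 0., 1.], [1., 0., 0.]]])\n", "y_pred = tf.constant([[[.1, .8, .1], [.2, .3, .5], [.6, .2, .2]]])\n", "per_token = keras.losses.CategoricalCrossentropy(reduction='none')(y_true, y_pred)\n", "mask = tf.cast(tf.argmax(y_true, axis=-1) > 0, tf.float32)  # 0 at [PAD] positions\n", "masked_mean = tf.reduce_sum(per_token * mask) / tf.reduce_sum(mask)\n", "```\n"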
821 | ], 822 | "metadata": { 823 | "id": "GnNEySU_jD7H" 824 | } 825 | }, 826 | { 827 | "cell_type": "code", 828 | "source": [ 829 | "class CustomNonPaddingTokenLoss(keras.losses.Loss):\n", 830 | " def __init__(self, name=\"custom_ner_loss\"):\n", 831 | " super().__init__(name=name)\n", 832 | "\n", 833 | " def call(self, y_true, y_pred):\n", 834 | " # Per-token losses, so [PAD] positions can be masked out below.\n", 835 | " loss_fn = keras.losses.CategoricalCrossentropy(reduction=\"none\")\n", 836 | " loss = loss_fn(y_true, y_pred)\n", 837 | " mask = tf.cast(tf.argmax(y_true, axis=-1) > 0, dtype=tf.float32)\n", 838 | " loss = loss * mask\n", 839 | " return tf.reduce_sum(loss) / tf.reduce_sum(mask)\n", 840 | "\n", 841 | "loss = CustomNonPaddingTokenLoss()" 842 | ], 843 | "metadata": { 844 | "id": "1xZqKuoTfa_g" 845 | }, 846 | "execution_count": null, 847 | "outputs": [] 848 | }, 849 | { 850 | "cell_type": "code", 851 | "source": [ 852 | "class NERModel(keras.Model):\n", 853 | " def __init__(self, num_tags, ff_dim=32):\n", 856 | " super().__init__()\n", 858 | " # Pretrained RoBERTa encoder from KerasNLP as the backbone.\n", 859 | " self.transformer_block = keras_nlp.models.RobertaBackbone.from_preset(\"roberta_base_en\")\n", 861 | " self.dropout1 = layers.Dropout(0.1)\n", 863 | " self.ff = layers.Dense(ff_dim, activation=\"relu\")\n", 864 | " self.dropout2 = layers.Dropout(0.1)\n", 865 | " self.ff_final = layers.Dense(num_tags, activation=\"softmax\")\n", 866 | "\n", 867 | " def call(self, inputs, training=False):\n", 868 | " # `inputs` is an unbatched vector of token ids: add a batch axis\n", 869 | " # and an all-ones padding mask of the same shape.\n", 872 | " mask = tf.ones_like(inputs)\n", 875 | " x = self.transformer_block([tf.expand_dims(inputs, axis=0), tf.expand_dims(mask, 0)])\n", 876 | " x = self.dropout1(x, training=training)\n", 877 | " x = self.ff(x)\n", 878 | " x = self.dropout2(x, training=training)\n", 879 | " x = self.ff_final(x)\n", 880 | " return x\n", 881 | "ner_model = NERModel(num_tags, ff_dim=64)" 883 | ], 884 | "metadata": { 885 | "id": "2o9E7RW-bv90" 886 | }, 887 | "execution_count": null, 888 | "outputs": [] 889 | }, 890 | { 891 | "cell_type": "code", 892 | "source": [ 893 | "optimizer = keras.optimizers.Adam(1e-4)\n", 894 | "# Reuse the custom non-padding loss defined above.\n", 895 | "loss_fn = loss\n", 896 | "train_acc_metric = keras.metrics.CategoricalAccuracy()\n", 897 | "val_acc_metric = keras.metrics.CategoricalAccuracy()" 898 | ], 899 | "metadata": { 900 | "id": "7HPPJvGPMjmB" 901 | }, 902 | "execution_count": null, 903 | "outputs": [] 904 | }, 905 | { 906 | "cell_type": "code", 907 | "source": [ 908 | "import numpy as np\n", 909 | "\n", 910 | "@tf.function\n", 911 | "def train_step(x, y):\n", 912 | " with tf.GradientTape() as tape:\n", 913 | " logits = ner_model(x, training=True)\n", 914 | " loss_value = loss_fn(y, logits)\n", 915 | " grads = tape.gradient(loss_value, ner_model.trainable_weights)\n", 916 | " optimizer.apply_gradients(zip(grads, ner_model.trainable_weights))\n", 917 | " train_acc_metric.update_state(y, logits)\n", 918 | " return loss_value\n", 919 | "@tf.function\n", 920 | "def test_step(x, y):\n", 921 | " val_logits = ner_model(x, training=False)\n", 922 | " val_acc_metric.update_state(y, val_logits)\n", 923 | "import time\n", 924 | "from tqdm import tqdm\n", 925 | "train_acc_list = []\n", 926 | "train_loss_list = []\n", 927 | "epochs = 2\n", 928 | "for epoch in range(epochs):\n", 929 | " print(\"\\nStart of epoch %d\" % (epoch,))\n", 930 | " start_time = time.time()\n", 931 | " train_loss = []\n", 932 | " train_loss_batch = []\n", 933 | " for step, (x_batch_train, y_batch_train) 
in tqdm(enumerate(zip(x_train,y_train))):\n", 934 | " loss_value = train_step(x_batch_train, tf.expand_dims(y_batch_train,axis=0))\n", 935 | " train_loss.append(float(loss_value))\n", 936 | " train_loss_batch.append(float(loss_value))\n", 937 | " if step % 1000 == 0:\n", 938 | " print(\n", 939 | " \"Training loss (for one batch) at step %d: %.4f\"\n", 940 | " % (step, np.mean(train_loss_batch))\n", 941 | " )\n", 942 | " train_loss_batch=[]\n", 943 | " print(\"Seen so far: %d samples\" % ((step + 1) ))\n", 944 | " train_loss_list.append(np.mean(train_loss))\n", 945 | " train_acc = train_acc_metric.result()\n", 946 | " print(\"Training acc over epoch: %.4f\" % (float(train_acc),))\n", 947 | " train_acc_list.append(float(train_acc))\n", 948 | " train_acc_metric.reset_states()\n", 949 | " print(\"Time taken: %.2fs\" % (time.time() - start_time))" 950 | ], 951 | "metadata": { 952 | "colab": { 953 | "base_uri": "https://localhost:8080/" 954 | }, 955 | "id": "TycS5BAkXmGq", 956 | "outputId": "6610cbf8-d23f-48c6-bd31-818f4a6bd047" 957 | }, 958 | "execution_count": null, 959 | "outputs": [ 960 | { 961 | "output_type": "stream", 962 | "name": "stdout", 963 | "text": [ 964 | "\n", 965 | "Start of epoch 0\n" 966 | ] 967 | }, 968 | { 969 | "output_type": "stream", 970 | "name": "stderr", 971 | "text": [ 972 | "1it [00:48, 48.86s/it]" 973 | ] 974 | }, 975 | { 976 | "output_type": "stream", 977 | "name": "stdout", 978 | "text": [ 979 | "Training loss (for one batch) at step 0: 1.3125\n", 980 | "Seen so far: 1 samples\n" 981 | ] 982 | }, 983 | { 984 | "output_type": "stream", 985 | "name": "stderr", 986 | "text": [ 987 | "4it [01:23, 16.33s/it]WARNING:tensorflow:5 out of the last 5 calls to triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. 
For (1), please define your @tf.function outside of the loop. For (2), @tf.function has reduce_retracing=True option that can avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/guide/function#controlling_retracing and https://www.tensorflow.org/api_docs/python/tf/function for more details.\n", 988 | "5it [01:31, 13.49s/it]WARNING:tensorflow:6 out of the last 6 calls to triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has reduce_retracing=True option that can avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/guide/function#controlling_retracing and https://www.tensorflow.org/api_docs/python/tf/function for more details.\n", 989 | "1004it [08:09, 21.06it/s]" 990 | ] 991 | }, 992 | { 993 | "output_type": "stream", 994 | "name": "stdout", 995 | "text": [ 996 | "Training loss (for one batch) at step 1000: 1.0708\n", 997 | "Seen so far: 1001 samples\n" 998 | ] 999 | }, 1000 | { 1001 | "output_type": "stream", 1002 | "name": "stderr", 1003 | "text": [ 1004 | "2003it [09:04, 21.12it/s]" 1005 | ] 1006 | }, 1007 | { 1008 | "output_type": "stream", 1009 | "name": "stdout", 1010 | "text": [ 1011 | "Training loss (for one batch) at step 2000: 1.0809\n", 1012 | "Seen so far: 2001 samples\n" 1013 | ] 1014 | }, 1015 | { 1016 | "output_type": "stream", 1017 | "name": "stderr", 1018 | "text": [ 1019 | "3003it [10:05, 19.56it/s]" 1020 | ] 1021 | }, 1022 | { 1023 | "output_type": "stream", 1024 | "name": "stdout", 1025 | "text": [ 1026 | "Training loss (for one batch) at step 3000: 1.0715\n", 1027 | "Seen so far: 3001 samples\n" 1028 | ] 1029 | }, 1030 | { 1031 | "output_type": "stream", 1032 | "name": "stderr", 1033 | "text": [ 1034 | "4003it 
[11:07, 20.41it/s]" 1035 | ] 1036 | }, 1037 | { 1038 | "output_type": "stream", 1039 | "name": "stdout", 1040 | "text": [ 1041 | "Training loss (for one batch) at step 4000: 0.9544\n", 1042 | "Seen so far: 4001 samples\n" 1043 | ] 1044 | }, 1045 | { 1046 | "output_type": "stream", 1047 | "name": "stderr", 1048 | "text": [ 1049 | "5005it [12:19, 20.50it/s]" 1050 | ] 1051 | }, 1052 | { 1053 | "output_type": "stream", 1054 | "name": "stdout", 1055 | "text": [ 1056 | "Training loss (for one batch) at step 5000: 1.0672\n", 1057 | "Seen so far: 5001 samples\n" 1058 | ] 1059 | }, 1060 | { 1061 | "output_type": "stream", 1062 | "name": "stderr", 1063 | "text": [ 1064 | "5416it [12:41, 7.12it/s]\n" 1065 | ] 1066 | }, 1067 | { 1068 | "output_type": "stream", 1069 | "name": "stdout", 1070 | "text": [ 1071 | "Training acc over epoch: 0.8199\n", 1072 | "Time taken: 761.06s\n", 1073 | "\n", 1074 | "Start of epoch 1\n" 1075 | ] 1076 | }, 1077 | { 1078 | "output_type": "stream", 1079 | "name": "stderr", 1080 | "text": [ 1081 | "3it [00:00, 20.47it/s]" 1082 | ] 1083 | }, 1084 | { 1085 | "output_type": "stream", 1086 | "name": "stdout", 1087 | "text": [ 1088 | "Training loss (for one batch) at step 0: 1.3068\n", 1089 | "Seen so far: 1 samples\n" 1090 | ] 1091 | }, 1092 | { 1093 | "output_type": "stream", 1094 | "name": "stderr", 1095 | "text": [ 1096 | "1003it [00:55, 21.15it/s]" 1097 | ] 1098 | }, 1099 | { 1100 | "output_type": "stream", 1101 | "name": "stdout", 1102 | "text": [ 1103 | "Training loss (for one batch) at step 1000: 1.0671\n", 1104 | "Seen so far: 1001 samples\n" 1105 | ] 1106 | }, 1107 | { 1108 | "output_type": "stream", 1109 | "name": "stderr", 1110 | "text": [ 1111 | "2004it [01:50, 19.74it/s]" 1112 | ] 1113 | }, 1114 | { 1115 | "output_type": "stream", 1116 | "name": "stdout", 1117 | "text": [ 1118 | "Training loss (for one batch) at step 2000: 1.0768\n", 1119 | "Seen so far: 2001 samples\n" 1120 | ] 1121 | }, 1122 | { 1123 | "output_type": "stream", 1124 | 
"name": "stderr", 1125 | "text": [ 1126 | "3003it [02:43, 17.31it/s]" 1127 | ] 1128 | }, 1129 | { 1130 | "output_type": "stream", 1131 | "name": "stdout", 1132 | "text": [ 1133 | "Training loss (for one batch) at step 3000: 1.0724\n", 1134 | "Seen so far: 3001 samples\n" 1135 | ] 1136 | }, 1137 | { 1138 | "output_type": "stream", 1139 | "name": "stderr", 1140 | "text": [ 1141 | "4005it [03:39, 20.32it/s]" 1142 | ] 1143 | }, 1144 | { 1145 | "output_type": "stream", 1146 | "name": "stdout", 1147 | "text": [ 1148 | "Training loss (for one batch) at step 4000: 0.9605\n", 1149 | "Seen so far: 4001 samples\n" 1150 | ] 1151 | }, 1152 | { 1153 | "output_type": "stream", 1154 | "name": "stderr", 1155 | "text": [ 1156 | "5004it [04:33, 19.27it/s]" 1157 | ] 1158 | }, 1159 | { 1160 | "output_type": "stream", 1161 | "name": "stdout", 1162 | "text": [ 1163 | "Training loss (for one batch) at step 5000: 1.0558\n", 1164 | "Seen so far: 5001 samples\n" 1165 | ] 1166 | }, 1167 | { 1168 | "output_type": "stream", 1169 | "name": "stderr", 1170 | "text": [ 1171 | "5416it [04:55, 18.35it/s]" 1172 | ] 1173 | }, 1174 | { 1175 | "output_type": "stream", 1176 | "name": "stdout", 1177 | "text": [ 1178 | "Training acc over epoch: 0.8199\n", 1179 | "Time taken: 295.22s\n" 1180 | ] 1181 | }, 1182 | { 1183 | "output_type": "stream", 1184 | "name": "stderr", 1185 | "text": [ 1186 | "\n" 1187 | ] 1188 | } 1189 | ] 1190 | }, 1191 | { 1192 | "cell_type": "code", 1193 | "source": [ 1194 | "\n", 1195 | "txt= \"eu rejects german call to boycott british lamb\"\n", 1196 | "# Sample inference using the trained model\n", 1197 | "sample_input = tokenizer.tokenize_with_offsets(txt)[0]\n", 1198 | "\n", 1199 | "output = ner_model.predict(sample_input)\n", 1200 | "prediction = np.argmax(output, axis=-1)[0]\n", 1201 | "prediction = [mapping[i] for i in prediction]\n", 1202 | "\n", 1203 | "# eu -> B-ORG, german -> B-MISC, british -> B-MISC\n", 1204 | "print(sample_input)\n", 1205 | "print(prediction)\n", 1206 | 
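"# Note: zipping with txt.split() below lines up only because each word in\n", "# this sentence maps to a single wordpiece. In general, sub-token predictions\n", "# can be mapped back to the source text with the start/end character offsets\n", "# returned by tokenize_with_offsets, e.g. (sketch):\n", "#   _, starts, ends = tokenizer.tokenize_with_offsets(txt)\n", "#   for s, e, p in zip(starts.numpy(), ends.numpy(), prediction):\n", "#       print(txt[int(s):int(e)], p)\n",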
"for tok, pred in zip(txt.split(), prediction):\n", 1207 | " print(tok, pred)" 1208 | ], 1209 | "metadata": { 1210 | "id": "mvUDhYbgfjYx", 1211 | "colab": { 1212 | "base_uri": "https://localhost:8080/" 1213 | }, 1214 | "outputId": "7b744823-e14f-4bdd-d45b-a93269ea945e" 1215 | }, 1216 | "execution_count": null, 1217 | "outputs": [ 1218 | { 1219 | "output_type": "stream", 1220 | "name": "stdout", 1221 | "text": [ 1222 | "1/1 [==============================] - 4s 4s/step\n", 1223 | "tf.Tensor([ 7327 19164 2446 2655 2000 17757 2329 12559], shape=(8,), dtype=int64)\n", 1224 | "['O', 'O', 'O', 'O', 'O', 'O', 'O', 'O']\n", 1225 | "eu O\n", 1226 | "rejects O\n", 1227 | "german O\n", 1228 | "call O\n", 1229 | "to O\n", 1230 | "boycott O\n", 1231 | "british O\n", 1232 | "lamb O\n" 1233 | ] 1234 | } 1235 | ] 1236 | }, 1237 | { 1238 | "cell_type": "code", 1239 | "source": [], 1240 | "metadata": { 1241 | "id": "lAUNQQCVGAuI" 1242 | }, 1243 | "execution_count": null, 1244 | "outputs": [] 1245 | } 1246 | ] 1247 | } --------------------------------------------------------------------------------