├── README.md
├── LICENSE
└── Roberta_NER_WithText.ipynb
/README.md:
--------------------------------------------------------------------------------
1 | # RoBERTa-NER
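2 | 
3 | Named Entity Recognition with a pretrained RoBERTa backbone using KerasNLP. See `Roberta_NER_WithText.ipynb` for the full pipeline (data loading, preprocessing, training, and inference). Licensed under Apache 2.0.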
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/Roberta_NER_WithText.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "nbformat": 4,
3 | "nbformat_minor": 0,
4 | "metadata": {
5 | "colab": {
6 | "provenance": [],
7 | "gpuType": "T4"
8 | },
9 | "kernelspec": {
10 | "name": "python3",
11 | "display_name": "Python 3"
12 | },
13 | "language_info": {
14 | "name": "python"
15 | },
16 | "accelerator": "GPU",
17 | "widgets": {
18 | "application/vnd.jupyter.widget-state+json": {
19 | "167150fc901e444ebf8bcf7143a257b3": {
20 | "model_module": "@jupyter-widgets/controls",
21 | "model_name": "HBoxModel",
22 | "model_module_version": "1.5.0",
23 | "state": {
24 | "_dom_classes": [],
25 | "_model_module": "@jupyter-widgets/controls",
26 | "_model_module_version": "1.5.0",
27 | "_model_name": "HBoxModel",
28 | "_view_count": null,
29 | "_view_module": "@jupyter-widgets/controls",
30 | "_view_module_version": "1.5.0",
31 | "_view_name": "HBoxView",
32 | "box_style": "",
33 | "children": [
34 | "IPY_MODEL_8be485f189ce4b7e950c57334b989601",
35 | "IPY_MODEL_fc5cb6b2c47247e8a59769b6e583549b",
36 | "IPY_MODEL_b397c4608f98450fbcb81ad28e427c04"
37 | ],
38 | "layout": "IPY_MODEL_477973dccac1415baff1c06e0d3dc836"
39 | }
40 | },
41 | "8be485f189ce4b7e950c57334b989601": {
42 | "model_module": "@jupyter-widgets/controls",
43 | "model_name": "HTMLModel",
44 | "model_module_version": "1.5.0",
45 | "state": {
46 | "_dom_classes": [],
47 | "_model_module": "@jupyter-widgets/controls",
48 | "_model_module_version": "1.5.0",
49 | "_model_name": "HTMLModel",
50 | "_view_count": null,
51 | "_view_module": "@jupyter-widgets/controls",
52 | "_view_module_version": "1.5.0",
53 | "_view_name": "HTMLView",
54 | "description": "",
55 | "description_tooltip": null,
56 | "layout": "IPY_MODEL_62ec567181fa473b9267cbfa34e5bb8b",
57 | "placeholder": "",
58 | "style": "IPY_MODEL_d4d71ee95b0847bf95f040c640457f2f",
59 | "value": "100%"
60 | }
61 | },
62 | "fc5cb6b2c47247e8a59769b6e583549b": {
63 | "model_module": "@jupyter-widgets/controls",
64 | "model_name": "FloatProgressModel",
65 | "model_module_version": "1.5.0",
66 | "state": {
67 | "_dom_classes": [],
68 | "_model_module": "@jupyter-widgets/controls",
69 | "_model_module_version": "1.5.0",
70 | "_model_name": "FloatProgressModel",
71 | "_view_count": null,
72 | "_view_module": "@jupyter-widgets/controls",
73 | "_view_module_version": "1.5.0",
74 | "_view_name": "ProgressView",
75 | "bar_style": "success",
76 | "description": "",
77 | "description_tooltip": null,
78 | "layout": "IPY_MODEL_10c6b510aea040939d357e09b82ed8dd",
79 | "max": 3,
80 | "min": 0,
81 | "orientation": "horizontal",
82 | "style": "IPY_MODEL_7554a5b794844f5482fab12edb702ee2",
83 | "value": 3
84 | }
85 | },
86 | "b397c4608f98450fbcb81ad28e427c04": {
87 | "model_module": "@jupyter-widgets/controls",
88 | "model_name": "HTMLModel",
89 | "model_module_version": "1.5.0",
90 | "state": {
91 | "_dom_classes": [],
92 | "_model_module": "@jupyter-widgets/controls",
93 | "_model_module_version": "1.5.0",
94 | "_model_name": "HTMLModel",
95 | "_view_count": null,
96 | "_view_module": "@jupyter-widgets/controls",
97 | "_view_module_version": "1.5.0",
98 | "_view_name": "HTMLView",
99 | "description": "",
100 | "description_tooltip": null,
101 | "layout": "IPY_MODEL_31f1dab6c44a4c80b48b462b436a80e0",
102 | "placeholder": "",
103 | "style": "IPY_MODEL_3f90e5333a8e45428725fdc2c0a56f70",
104 | "value": " 3/3 [00:00<00:00, 114.10it/s]"
105 | }
106 | },
107 | "477973dccac1415baff1c06e0d3dc836": {
108 | "model_module": "@jupyter-widgets/base",
109 | "model_name": "LayoutModel",
110 | "model_module_version": "1.2.0",
111 | "state": {
112 | "_model_module": "@jupyter-widgets/base",
113 | "_model_module_version": "1.2.0",
114 | "_model_name": "LayoutModel",
115 | "_view_count": null,
116 | "_view_module": "@jupyter-widgets/base",
117 | "_view_module_version": "1.2.0",
118 | "_view_name": "LayoutView",
119 | "align_content": null,
120 | "align_items": null,
121 | "align_self": null,
122 | "border": null,
123 | "bottom": null,
124 | "display": null,
125 | "flex": null,
126 | "flex_flow": null,
127 | "grid_area": null,
128 | "grid_auto_columns": null,
129 | "grid_auto_flow": null,
130 | "grid_auto_rows": null,
131 | "grid_column": null,
132 | "grid_gap": null,
133 | "grid_row": null,
134 | "grid_template_areas": null,
135 | "grid_template_columns": null,
136 | "grid_template_rows": null,
137 | "height": null,
138 | "justify_content": null,
139 | "justify_items": null,
140 | "left": null,
141 | "margin": null,
142 | "max_height": null,
143 | "max_width": null,
144 | "min_height": null,
145 | "min_width": null,
146 | "object_fit": null,
147 | "object_position": null,
148 | "order": null,
149 | "overflow": null,
150 | "overflow_x": null,
151 | "overflow_y": null,
152 | "padding": null,
153 | "right": null,
154 | "top": null,
155 | "visibility": null,
156 | "width": null
157 | }
158 | },
159 | "62ec567181fa473b9267cbfa34e5bb8b": {
160 | "model_module": "@jupyter-widgets/base",
161 | "model_name": "LayoutModel",
162 | "model_module_version": "1.2.0",
163 | "state": {
164 | "_model_module": "@jupyter-widgets/base",
165 | "_model_module_version": "1.2.0",
166 | "_model_name": "LayoutModel",
167 | "_view_count": null,
168 | "_view_module": "@jupyter-widgets/base",
169 | "_view_module_version": "1.2.0",
170 | "_view_name": "LayoutView",
171 | "align_content": null,
172 | "align_items": null,
173 | "align_self": null,
174 | "border": null,
175 | "bottom": null,
176 | "display": null,
177 | "flex": null,
178 | "flex_flow": null,
179 | "grid_area": null,
180 | "grid_auto_columns": null,
181 | "grid_auto_flow": null,
182 | "grid_auto_rows": null,
183 | "grid_column": null,
184 | "grid_gap": null,
185 | "grid_row": null,
186 | "grid_template_areas": null,
187 | "grid_template_columns": null,
188 | "grid_template_rows": null,
189 | "height": null,
190 | "justify_content": null,
191 | "justify_items": null,
192 | "left": null,
193 | "margin": null,
194 | "max_height": null,
195 | "max_width": null,
196 | "min_height": null,
197 | "min_width": null,
198 | "object_fit": null,
199 | "object_position": null,
200 | "order": null,
201 | "overflow": null,
202 | "overflow_x": null,
203 | "overflow_y": null,
204 | "padding": null,
205 | "right": null,
206 | "top": null,
207 | "visibility": null,
208 | "width": null
209 | }
210 | },
211 | "d4d71ee95b0847bf95f040c640457f2f": {
212 | "model_module": "@jupyter-widgets/controls",
213 | "model_name": "DescriptionStyleModel",
214 | "model_module_version": "1.5.0",
215 | "state": {
216 | "_model_module": "@jupyter-widgets/controls",
217 | "_model_module_version": "1.5.0",
218 | "_model_name": "DescriptionStyleModel",
219 | "_view_count": null,
220 | "_view_module": "@jupyter-widgets/base",
221 | "_view_module_version": "1.2.0",
222 | "_view_name": "StyleView",
223 | "description_width": ""
224 | }
225 | },
226 | "10c6b510aea040939d357e09b82ed8dd": {
227 | "model_module": "@jupyter-widgets/base",
228 | "model_name": "LayoutModel",
229 | "model_module_version": "1.2.0",
230 | "state": {
231 | "_model_module": "@jupyter-widgets/base",
232 | "_model_module_version": "1.2.0",
233 | "_model_name": "LayoutModel",
234 | "_view_count": null,
235 | "_view_module": "@jupyter-widgets/base",
236 | "_view_module_version": "1.2.0",
237 | "_view_name": "LayoutView",
238 | "align_content": null,
239 | "align_items": null,
240 | "align_self": null,
241 | "border": null,
242 | "bottom": null,
243 | "display": null,
244 | "flex": null,
245 | "flex_flow": null,
246 | "grid_area": null,
247 | "grid_auto_columns": null,
248 | "grid_auto_flow": null,
249 | "grid_auto_rows": null,
250 | "grid_column": null,
251 | "grid_gap": null,
252 | "grid_row": null,
253 | "grid_template_areas": null,
254 | "grid_template_columns": null,
255 | "grid_template_rows": null,
256 | "height": null,
257 | "justify_content": null,
258 | "justify_items": null,
259 | "left": null,
260 | "margin": null,
261 | "max_height": null,
262 | "max_width": null,
263 | "min_height": null,
264 | "min_width": null,
265 | "object_fit": null,
266 | "object_position": null,
267 | "order": null,
268 | "overflow": null,
269 | "overflow_x": null,
270 | "overflow_y": null,
271 | "padding": null,
272 | "right": null,
273 | "top": null,
274 | "visibility": null,
275 | "width": null
276 | }
277 | },
278 | "7554a5b794844f5482fab12edb702ee2": {
279 | "model_module": "@jupyter-widgets/controls",
280 | "model_name": "ProgressStyleModel",
281 | "model_module_version": "1.5.0",
282 | "state": {
283 | "_model_module": "@jupyter-widgets/controls",
284 | "_model_module_version": "1.5.0",
285 | "_model_name": "ProgressStyleModel",
286 | "_view_count": null,
287 | "_view_module": "@jupyter-widgets/base",
288 | "_view_module_version": "1.2.0",
289 | "_view_name": "StyleView",
290 | "bar_color": null,
291 | "description_width": ""
292 | }
293 | },
294 | "31f1dab6c44a4c80b48b462b436a80e0": {
295 | "model_module": "@jupyter-widgets/base",
296 | "model_name": "LayoutModel",
297 | "model_module_version": "1.2.0",
298 | "state": {
299 | "_model_module": "@jupyter-widgets/base",
300 | "_model_module_version": "1.2.0",
301 | "_model_name": "LayoutModel",
302 | "_view_count": null,
303 | "_view_module": "@jupyter-widgets/base",
304 | "_view_module_version": "1.2.0",
305 | "_view_name": "LayoutView",
306 | "align_content": null,
307 | "align_items": null,
308 | "align_self": null,
309 | "border": null,
310 | "bottom": null,
311 | "display": null,
312 | "flex": null,
313 | "flex_flow": null,
314 | "grid_area": null,
315 | "grid_auto_columns": null,
316 | "grid_auto_flow": null,
317 | "grid_auto_rows": null,
318 | "grid_column": null,
319 | "grid_gap": null,
320 | "grid_row": null,
321 | "grid_template_areas": null,
322 | "grid_template_columns": null,
323 | "grid_template_rows": null,
324 | "height": null,
325 | "justify_content": null,
326 | "justify_items": null,
327 | "left": null,
328 | "margin": null,
329 | "max_height": null,
330 | "max_width": null,
331 | "min_height": null,
332 | "min_width": null,
333 | "object_fit": null,
334 | "object_position": null,
335 | "order": null,
336 | "overflow": null,
337 | "overflow_x": null,
338 | "overflow_y": null,
339 | "padding": null,
340 | "right": null,
341 | "top": null,
342 | "visibility": null,
343 | "width": null
344 | }
345 | },
346 | "3f90e5333a8e45428725fdc2c0a56f70": {
347 | "model_module": "@jupyter-widgets/controls",
348 | "model_name": "DescriptionStyleModel",
349 | "model_module_version": "1.5.0",
350 | "state": {
351 | "_model_module": "@jupyter-widgets/controls",
352 | "_model_module_version": "1.5.0",
353 | "_model_name": "DescriptionStyleModel",
354 | "_view_count": null,
355 | "_view_module": "@jupyter-widgets/base",
356 | "_view_module_version": "1.2.0",
357 | "_view_name": "StyleView",
358 | "description_width": ""
359 | }
360 | }
361 | }
362 | }
363 | },
364 | "cells": [
365 | {
366 | "cell_type": "markdown",
367 | "source": [
368 | "# [KerasNLP] Named Entity Recognition using RoBERTa\n",
369 | "\n",
370 | "**Author:** [Usha Rengaraju](https://www.linkedin.com/in/usha-rengaraju-b570b7a2/)<br>\n",
371 | "**Date created:** 2023/07/10<br>\n",
372 | "**Last modified:** 2023/07/10<br>\n",
373 | "**Description:** Named Entity Recognition using pretrained RoBERTa\n"
374 | ],
375 | "metadata": {
376 | "id": "EKRr1Vkvvcar"
377 | }
378 | },
379 | {
380 | "cell_type": "markdown",
381 | "source": [
382 | "## Overview\n",
383 | "\n",
384 | "Named entity recognition (NER) is an NLP task that detects and categorizes key pieces of information in text, known as named entities (for example, people, organizations, and locations).\n",
385 | "\n",
386 | "KerasNLP provides a variety of pretrained models. In this guide we build the whole NER pipeline on top of the pretrained RoBERTa backbone.\n"
387 | ],
388 | "metadata": {
389 | "id": "pcanbuwJ7PUX"
390 | }
391 | },
392 | {
393 | "cell_type": "markdown",
394 | "source": [
395 | "## Imports & setup\n",
396 | "\n",
397 | "This tutorial requires you to have KerasNLP installed:\n",
398 | "\n",
399 | "```shell\n",
400 | "pip install keras-nlp\n",
401 | "```\n",
402 | "\n",
403 | "We begin by importing all required packages:"
404 | ],
405 | "metadata": {
406 | "id": "DmC_kCnI7VPq"
407 | }
408 | },
409 | {
410 | "cell_type": "code",
411 | "execution_count": null,
412 | "metadata": {
413 | "id": "4kbJNQkGYVUU",
414 | "colab": {
415 | "base_uri": "https://localhost:8080/"
416 | },
417 | "outputId": "72e1e296-a133-4095-bf28-55e5d8e27a52"
418 | },
419 | "outputs": [
420 | {
421 | "output_type": "stream",
422 | "name": "stdout",
423 | "text": [
424 | "\u001b[?25l \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m0.0/486.2 kB\u001b[0m \u001b[31m?\u001b[0m eta \u001b[36m-:--:--\u001b[0m\r\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m486.2/486.2 kB\u001b[0m \u001b[31m29.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
425 | "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m110.5/110.5 kB\u001b[0m \u001b[31m13.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
426 | "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m212.5/212.5 kB\u001b[0m \u001b[31m20.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
427 | "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m134.3/134.3 kB\u001b[0m \u001b[31m14.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
428 | "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m268.8/268.8 kB\u001b[0m \u001b[31m30.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
429 | "\u001b[?25h--2023-07-08 13:24:16-- https://raw.githubusercontent.com/sighsmile/conlleval/master/conlleval.py\n",
430 | "Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...\n",
431 | "Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.\n",
432 | "HTTP request sent, awaiting response... 200 OK\n",
433 | "Length: 7502 (7.3K) [text/plain]\n",
434 | "Saving to: ‘conlleval.py’\n",
435 | "\n",
436 | "conlleval.py 100%[===================>] 7.33K --.-KB/s in 0s \n",
437 | "\n",
438 | "2023-07-08 13:24:17 (99.1 MB/s) - ‘conlleval.py’ saved [7502/7502]\n",
439 | "\n"
440 | ]
441 | }
442 | ],
443 | "source": [
444 | "!pip3 install -q keras-nlp datasets\n",
445 | "!wget https://raw.githubusercontent.com/sighsmile/conlleval/master/conlleval.py"
446 | ]
447 | },
448 | {
449 | "cell_type": "code",
450 | "execution_count": null,
451 | "metadata": {
452 | "id": "LiNd4e7hYVUV"
453 | },
454 | "outputs": [],
455 | "source": [
456 | "import os\n",
457 | "import numpy as np\n",
458 | "import tensorflow as tf\n",
459 | "from tensorflow import keras\n",
460 | "from tensorflow.keras import layers\n",
461 | "from datasets import load_dataset\n",
462 | "from collections import Counter\n",
463 | "from conlleval import evaluate\n",
464 | "import keras_nlp"
465 | ]
466 | },
467 | {
468 | "cell_type": "markdown",
469 | "source": [
470 | "## Data loading\n",
471 | "\n",
472 | "This guide uses the\n",
473 | "[CoNLL-2003 dataset](https://huggingface.co/datasets/conll2003)\n",
474 | "for demonstration purposes.\n",
475 | "\n",
476 | "To get started, we first load the dataset from the Hugging Face Hub:"
477 | ],
478 | "metadata": {
479 | "id": "r8CwaHyl_8oX"
480 | }
481 | },
482 | {
483 | "cell_type": "code",
484 | "source": [
485 | "conll_data = load_dataset(\"conll2003\")"
486 | ],
487 | "metadata": {
488 | "colab": {
489 | "base_uri": "https://localhost:8080/",
490 | "height": 86,
491 | "referenced_widgets": [
492 | "167150fc901e444ebf8bcf7143a257b3",
493 | "8be485f189ce4b7e950c57334b989601",
494 | "fc5cb6b2c47247e8a59769b6e583549b",
495 | "b397c4608f98450fbcb81ad28e427c04",
496 | "477973dccac1415baff1c06e0d3dc836",
497 | "62ec567181fa473b9267cbfa34e5bb8b",
498 | "d4d71ee95b0847bf95f040c640457f2f",
499 | "10c6b510aea040939d357e09b82ed8dd",
500 | "7554a5b794844f5482fab12edb702ee2",
501 | "31f1dab6c44a4c80b48b462b436a80e0",
502 | "3f90e5333a8e45428725fdc2c0a56f70"
503 | ]
504 | },
505 | "id": "7V_FhoYPe-EN",
506 | "outputId": "3cac9420-1ae0-4a9a-e365-ff90fd671384"
507 | },
508 | "execution_count": null,
509 | "outputs": [
510 | {
511 | "output_type": "stream",
512 | "name": "stderr",
513 | "text": [
514 | "WARNING:datasets.builder:Found cached dataset conll2003 (/root/.cache/huggingface/datasets/conll2003/conll2003/1.0.0/9a4d16a94f8674ba3466315300359b0acd891b68b6c8743ddf60b9c702adce98)\n"
515 | ]
516 | },
517 | {
518 | "output_type": "display_data",
519 | "data": {
520 | "text/plain": [
521 | " 0%| | 0/3 [00:00, ?it/s]"
522 | ],
523 | "application/vnd.jupyter.widget-view+json": {
524 | "version_major": 2,
525 | "version_minor": 0,
526 | "model_id": "167150fc901e444ebf8bcf7143a257b3"
527 | }
528 | },
529 | "metadata": {}
530 | }
531 | ]
532 | },
533 | {
534 | "cell_type": "code",
535 | "source": [
536 | "def export_to_file(export_file_path, data):\n",
537 | " with open(export_file_path, \"w\") as f:\n",
538 | " for record in data:\n",
539 | " ner_tags = record[\"ner_tags\"]\n",
540 | " tokens = record[\"tokens\"]\n",
541 | " if len(tokens) > 0:\n",
542 | " f.write(\n",
543 | " str(len(tokens))\n",
544 | " + \"\\t\"\n",
545 | " + \"\\t\".join(tokens)\n",
546 | " + \"\\t\"\n",
547 | " + \"\\t\".join(map(str, ner_tags))\n",
548 | " + \"\\n\"\n",
549 | " )\n",
550 | "\n",
551 | "\n",
552 | "os.mkdir(\"data\")\n",
553 | "export_to_file(\"./data/conll_train.txt\", conll_data[\"train\"])\n",
554 | "export_to_file(\"./data/conll_val.txt\", conll_data[\"validation\"])"
555 | ],
556 | "metadata": {
557 | "id": "B4WjrivLfAxc"
558 | },
559 | "execution_count": null,
560 | "outputs": []
561 | },
562 | {
563 | "cell_type": "markdown",
564 | "source": [
565 | "Next, we generate the mapping between NER tag ids and their string labels:"
566 | ],
567 | "metadata": {
568 | "id": "Xo85Q67fh1b7"
569 | }
570 | },
571 | {
572 | "cell_type": "code",
573 | "source": [
574 | "def make_tag_lookup_table():\n",
575 | " iob_labels = [\"B\", \"I\"]\n",
576 | " ner_labels = [\"PER\", \"ORG\", \"LOC\", \"MISC\"]\n",
577 | " all_labels = [(label1, label2) for label2 in ner_labels for label1 in iob_labels]\n",
578 | " all_labels = [\"-\".join([a, b]) for a, b in all_labels]\n",
579 | " all_labels = [\"[PAD]\", \"O\"] + all_labels\n",
580 | "    return dict(enumerate(all_labels))\n",
581 | "\n",
582 | "\n",
583 | "mapping = make_tag_lookup_table()\n",
584 | "print(mapping)"
585 | ],
586 | "metadata": {
587 | "colab": {
588 | "base_uri": "https://localhost:8080/"
589 | },
590 | "id": "k668ZOSrfDFU",
591 | "outputId": "1d8d15ed-3cd6-49e5-b2f4-1815795466d9"
592 | },
593 | "execution_count": null,
594 | "outputs": [
595 | {
596 | "output_type": "stream",
597 | "name": "stdout",
598 | "text": [
599 | "{0: '[PAD]', 1: 'O', 2: 'B-PER', 3: 'I-PER', 4: 'B-ORG', 5: 'I-ORG', 6: 'B-LOC', 7: 'I-LOC', 8: 'B-MISC', 9: 'I-MISC'}\n"
600 | ]
601 | }
602 | ]
603 | },
604 | {
605 | "cell_type": "code",
606 | "source": [
607 | "all_tokens = sum(conll_data[\"train\"][\"tokens\"], [])\n",
608 | "all_tokens_array = np.array(list(map(str.lower, all_tokens)))\n",
609 | "\n",
610 | "counter = Counter(all_tokens_array)\n",
611 | "print(len(counter))\n",
612 | "\n",
613 | "num_tags = len(mapping)\n",
614 | "vocab_size = 20000\n",
615 | "vocabulary = [token for token, count in counter.most_common(vocab_size - 2)]\n",
616 | "\n",
617 | "lookup_layer = keras.layers.StringLookup(\n",
618 | " vocabulary=vocabulary\n",
619 | ")"
620 | ],
621 | "metadata": {
622 | "colab": {
623 | "base_uri": "https://localhost:8080/"
624 | },
625 | "id": "o11ECz4kfFPk",
626 | "outputId": "46b66db1-da41-46a4-99b3-ab356f3865ca"
627 | },
628 | "execution_count": null,
629 | "outputs": [
630 | {
631 | "output_type": "stream",
632 | "name": "stdout",
633 | "text": [
634 | "21009\n"
635 | ]
636 | }
637 | ]
638 | },
639 | {
640 | "cell_type": "code",
641 | "source": [
642 | "train_data = tf.data.TextLineDataset(\"./data/conll_train.txt\")\n",
643 | "val_data = tf.data.TextLineDataset(\"./data/conll_val.txt\")"
644 | ],
645 | "metadata": {
646 | "id": "62ohE_h4fIJt"
647 | },
648 | "execution_count": null,
649 | "outputs": []
650 | },
651 | {
652 | "cell_type": "code",
653 | "source": [
654 | "print(list(train_data.take(1).as_numpy_iterator()))"
655 | ],
656 | "metadata": {
657 | "colab": {
658 | "base_uri": "https://localhost:8080/"
659 | },
660 | "id": "8C714wpJfMhW",
661 | "outputId": "abf9c329-92b6-4dbd-b5d9-cb847255f0a4"
662 | },
663 | "execution_count": null,
664 | "outputs": [
665 | {
666 | "output_type": "stream",
667 | "name": "stdout",
668 | "text": [
669 | "[b'9\\tEU\\trejects\\tGerman\\tcall\\tto\\tboycott\\tBritish\\tlamb\\t.\\t3\\t0\\t7\\t0\\t0\\t0\\t7\\t0\\t0']\n"
670 | ]
671 | }
672 | ]
673 | },
674 | {
675 | "cell_type": "markdown",
676 | "source": [
677 | "## Preprocessing the Dataset\n",
678 | "\n",
679 | "To tokenize the text we use the TensorFlow Text `FastWordpieceTokenizer`, built from the vocabulary of an uncased BERT preset, and create the data pipelines for training the model.\n"
680 | ],
681 | "metadata": {
682 | "id": "gZAE1MFliivB"
683 | }
684 | },
685 | {
686 | "cell_type": "code",
687 | "source": [
688 | "import tensorflow_text as tf_text\n",
689 | "\n",
690 | "# Build a fast WordPiece tokenizer from the vocabulary of a pretrained\n",
691 | "# uncased BERT preset; it supports tokenizing with character offsets.\n",
692 | "tok = keras_nlp.models.BertTokenizer.from_preset(\"bert_base_en_uncased\", lowercase=True)\n",
693 | "tokenizer = tf_text.FastWordpieceTokenizer(tok.vocabulary)"
691 | ],
692 | "metadata": {
693 | "id": "Kltg-IYby3fV"
694 | },
695 | "execution_count": null,
696 | "outputs": []
697 | },
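698 | {
699 | "cell_type": "markdown",
700 | "source": [
701 | "As a quick, illustrative sanity check (the sample sentence below is arbitrary): `tokenize_with_offsets` returns a tuple of token ids plus character start/end offsets, and the pipeline below only uses the ids."
702 | ],
703 | "metadata": {}
704 | },
705 | {
706 | "cell_type": "code",
707 | "source": [
708 | "# Illustrative check: token ids plus character start/end offsets.\n",
709 | "# The vocabulary is uncased, so we pass lowercased text.\n",
710 | "ids, starts, ends = tokenizer.tokenize_with_offsets(\"eu rejects german call\")\n",
711 | "print(ids)\n",
712 | "print(starts, ends)"
713 | ],
714 | "metadata": {},
715 | "execution_count": null,
716 | "outputs": []
717 | },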
698 | {
699 | "cell_type": "code",
700 | "def map_record_to_training_data(record):\n",
701 | "    record = tf.strings.split(record, sep=\"\\t\")\n",
702 | "    length = tf.strings.to_number(record[0], out_type=tf.int32)\n",
703 | "    # Re-join the words of the sentence and tokenize them into\n",
704 | "    # WordPiece ids.\n",
705 | "    tokens = tf.strings.reduce_join(record[1 : length + 1], separator=\" \")\n",
706 | "    tokens = tokenizer.tokenize_with_offsets(tokens)[0]\n",
707 | "    # Shift the tag ids up by one so that 0 is free for [PAD].\n",
708 | "    tags = tf.strings.to_number(record[length + 1 :], out_type=tf.int64)\n",
709 | "    tags += 1\n",
710 | "    return (tokens, tags)\n",
711 | "\n",
712 | "\n",
713 | "train_dataset = train_data.map(map_record_to_training_data)\n",
714 | "val_dataset = val_data.map(map_record_to_training_data)\n"
730 | ],
731 | "metadata": {
732 | "id": "Rp7HnnSPfOE1"
733 | },
734 | "execution_count": null,
735 | "outputs": []
736 | },
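737 | {
738 | "cell_type": "markdown",
739 | "source": [
740 | "To see what the mapping produces, we can peek at one preprocessed example below (an illustrative check; the tag ids have been shifted up by one so that 0 is reserved for `[PAD]`):"
741 | ],
742 | "metadata": {}
743 | },
744 | {
745 | "cell_type": "code",
746 | "source": [
747 | "# Peek at one preprocessed example (illustrative): WordPiece token ids\n",
748 | "# and the corresponding shifted tag ids.\n",
749 | "for tokens, tags in train_dataset.take(1):\n",
750 | "    print(tokens.numpy())\n",
751 | "    print(tags.numpy())"
752 | ],
753 | "metadata": {},
754 | "execution_count": null,
755 | "outputs": []
756 | },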
737 | {
738 | "cell_type": "code",
739 | "source": [
740 | "# Keep only examples where the number of WordPiece tokens matches the\n",
741 | "# number of word-level tags, and one-hot encode the tags.\n",
742 | "x_train = []\n",
743 | "y_train = []\n",
744 | "for x, y in train_dataset:\n",
745 | "    if x.shape == y.shape:\n",
746 | "        x_train.append(x)\n",
747 | "        y_oh = []\n",
748 | "        for tag in y:\n",
749 | "            t = [0] * num_tags\n",
750 | "            t[tag] = 1\n",
751 | "            y_oh.append(t)\n",
752 | "        y_train.append(y_oh)\n",
753 | "len(x_train)"
754 | ],
755 | "metadata": {
756 | "colab": {
757 | "base_uri": "https://localhost:8080/"
758 | },
759 | "id": "Gw_I4pqZ-I_M",
760 | "outputId": "38f6d37e-abc0-4217-8c50-a4e16fc01e66"
761 | },
762 | "execution_count": null,
763 | "outputs": [
764 | {
765 | "output_type": "execute_result",
766 | "data": {
767 | "text/plain": [
768 | "5416"
769 | ]
770 | },
771 | "metadata": {},
772 | "execution_count": 32
773 | }
774 | ]
775 | },
776 | {
777 | "cell_type": "code",
778 | "source": [
779 | "# Apply the same filtering and one-hot encoding to the validation split.\n",
780 | "x_val = []\n",
781 | "y_val = []\n",
782 | "for x, y in val_dataset:\n",
783 | "    if x.shape == y.shape:\n",
784 | "        x_val.append(x)\n",
785 | "        y_oh = []\n",
786 | "        for tag in y:\n",
787 | "            t = [0] * num_tags\n",
788 | "            t[tag] = 1\n",
789 | "            y_oh.append(t)\n",
790 | "        y_val.append(y_oh)\n",
791 | "len(x_val)"
793 | ],
794 | "metadata": {
795 | "colab": {
796 | "base_uri": "https://localhost:8080/"
797 | },
798 | "id": "gq260EwR8f6K",
799 | "outputId": "621b44f7-fc66-47db-8d34-d35f29155403"
800 | },
801 | "execution_count": null,
802 | "outputs": [
803 | {
804 | "output_type": "execute_result",
805 | "data": {
806 | "text/plain": [
807 | "1205"
808 | ]
809 | },
810 | "metadata": {},
811 | "execution_count": 34
812 | }
813 | ]
814 | },
815 | {
816 | "cell_type": "markdown",
817 | "source": [
818 | "## Model Building\n",
819 | "\n",
820 | "For this pipeline we define a `CustomNonPaddingTokenLoss` that ignores `[PAD]` positions, and then build the NER model. The backbone is KerasNLP's pretrained RoBERTa model in the base configuration, followed by a small `Dense` head for per-token entity classification."
821 | ],
822 | "metadata": {
823 | "id": "GnNEySU_jD7H"
824 | }
825 | },
826 | {
827 | "cell_type": "code",
828 | "source": [
829 | "class CustomNonPaddingTokenLoss(keras.losses.Loss):\n",
830 | "    def __init__(self, name=\"custom_ner_loss\"):\n",
831 | "        super().__init__(name=name)\n",
832 | "\n",
833 | "    def call(self, y_true, y_pred):\n",
834 | "        # Per-token cross-entropy (no reduction) so that it can be masked.\n",
835 | "        loss_fn = keras.losses.CategoricalCrossentropy(\n",
836 | "            reduction=keras.losses.Reduction.NONE\n",
837 | "        )\n",
838 | "        loss = loss_fn(y_true, y_pred)\n",
839 | "        # Mask out [PAD] positions (class index 0 of the one-hot labels).\n",
840 | "        mask = tf.cast(tf.argmax(y_true, axis=-1) > 0, dtype=tf.float32)\n",
841 | "        loss = loss * mask\n",
842 | "        return tf.reduce_sum(loss) / tf.reduce_sum(mask)\n",
843 | "\n",
844 | "\n",
845 | "loss = CustomNonPaddingTokenLoss()"
842 | ],
843 | "metadata": {
844 | "id": "1xZqKuoTfa_g"
845 | },
846 | "execution_count": null,
847 | "outputs": []
848 | },
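849 | {
850 | "cell_type": "markdown",
851 | "source": [
852 | "A tiny, illustrative check of the masking behaviour: a `[PAD]` position (class index 0) contributes nothing to the loss, so with uniform predictions the result is the mean cross-entropy over the real tokens only."
853 | ],
854 | "metadata": {}
855 | },
856 | {
857 | "cell_type": "code",
858 | "source": [
859 | "# Toy example (illustrative): one [PAD] label plus two real labels, with\n",
860 | "# uniform predictions; the [PAD] position is excluded from the average.\n",
861 | "y_true_demo = tf.one_hot([0, 1, 2], depth=num_tags)   # [PAD], O, B-PER\n",
862 | "y_pred_demo = tf.fill([3, num_tags], 1.0 / num_tags)  # uniform predictions\n",
863 | "print(float(loss(y_true_demo, y_pred_demo)))  # ~ln(num_tags), i.e. about 2.30"
864 | ],
865 | "metadata": {},
866 | "execution_count": null,
867 | "outputs": []
868 | },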
849 | {
850 | "cell_type": "code",
851 | "source": [
852 | "class NERModel(keras.Model):\n",
853 | "    def __init__(self, num_tags, ff_dim=32):\n",
854 | "        super().__init__()\n",
855 | "        # Pretrained RoBERTa encoder from KerasNLP (base configuration).\n",
856 | "        self.transformer_block = keras_nlp.models.RobertaBackbone.from_preset(\"roberta_base_en\")\n",
857 | "        self.dropout1 = layers.Dropout(0.1)\n",
858 | "        self.ff = layers.Dense(ff_dim, activation=\"relu\")\n",
859 | "        self.dropout2 = layers.Dropout(0.1)\n",
860 | "        self.ff_final = layers.Dense(num_tags, activation=\"softmax\")\n",
861 | "\n",
862 | "    def call(self, inputs, training=False):\n",
863 | "        # `inputs` is a single unbatched sequence of token ids; add the\n",
864 | "        # batch dimension and an all-ones padding mask before the backbone.\n",
865 | "        mask = tf.ones_like(inputs)\n",
866 | "        x = self.transformer_block(\n",
867 | "            [tf.expand_dims(inputs, axis=0), tf.expand_dims(mask, 0)]\n",
868 | "        )\n",
869 | "        x = self.dropout1(x, training=training)\n",
870 | "        x = self.ff(x)\n",
871 | "        x = self.dropout2(x, training=training)\n",
872 | "        x = self.ff_final(x)\n",
873 | "        return x\n",
874 | "\n",
875 | "\n",
876 | "ner_model = NERModel(num_tags, ff_dim=64)"
883 | ],
884 | "metadata": {
885 | "id": "2o9E7RW-bv90"
886 | },
887 | "execution_count": null,
888 | "outputs": []
889 | },
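890 | {
891 | "cell_type": "markdown",
892 | "source": [
893 | "A quick, illustrative shape check: since `call` adds the batch dimension itself, the model takes a single unbatched sequence of token ids and returns per-token class probabilities of shape `(1, sequence_length, num_tags)`. The dummy ids below are arbitrary."
894 | ],
895 | "metadata": {}
896 | },
897 | {
898 | "cell_type": "code",
899 | "source": [
900 | "# Smoke test (illustrative): a short dummy id sequence should produce an\n",
901 | "# output of shape (1, sequence_length, num_tags).\n",
902 | "dummy_ids = tf.constant([100, 200, 300], dtype=tf.int64)\n",
903 | "print(ner_model(dummy_ids).shape)"
904 | ],
905 | "metadata": {},
906 | "execution_count": null,
907 | "outputs": []
908 | },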
890 | {
891 | "cell_type": "code",
892 | "source": [
893 | "optimizer = keras.optimizers.Adam(1e-4)\n",
894 | "# Reuse the custom non-padding loss defined above.\n",
895 | "loss_fn = loss\n",
896 | "train_acc_metric = keras.metrics.CategoricalAccuracy()\n",
897 | "val_acc_metric = keras.metrics.CategoricalAccuracy()"
898 | ],
899 | "metadata": {
900 | "id": "7HPPJvGPMjmB"
901 | },
902 | "execution_count": null,
903 | "outputs": []
904 | },
905 | {
906 | "cell_type": "code",
907 | "source": [
908 | "import time\n",
909 | "\n",
910 | "from tqdm import tqdm\n",
911 | "\n",
912 | "\n",
913 | "@tf.function\n",
914 | "def train_step(x, y):\n",
915 | "    with tf.GradientTape() as tape:\n",
916 | "        logits = ner_model(x, training=True)\n",
917 | "        loss_value = loss_fn(y, logits)\n",
918 | "    grads = tape.gradient(loss_value, ner_model.trainable_weights)\n",
919 | "    optimizer.apply_gradients(zip(grads, ner_model.trainable_weights))\n",
920 | "    train_acc_metric.update_state(y, logits)\n",
921 | "    return loss_value\n",
922 | "\n",
923 | "\n",
924 | "@tf.function\n",
925 | "def test_step(x, y):\n",
926 | "    val_logits = ner_model(x, training=False)\n",
927 | "    val_acc_metric.update_state(y, val_logits)\n",
928 | "\n",
929 | "\n",
930 | "train_acc_list = []\n",
931 | "train_loss_list = []\n",
932 | "epochs = 2\n",
933 | "for epoch in range(epochs):\n",
934 | "    print(\"\\nStart of epoch %d\" % (epoch,))\n",
935 | "    start_time = time.time()\n",
936 | "    train_loss = []\n",
937 | "    train_loss_batch = []\n",
938 | "    # Train one sentence at a time; the model adds the batch dimension.\n",
939 | "    for step, (x_batch_train, y_batch_train) in tqdm(enumerate(zip(x_train, y_train))):\n",
940 | "        loss_value = train_step(x_batch_train, tf.expand_dims(y_batch_train, axis=0))\n",
941 | "        train_loss.append(float(loss_value))\n",
942 | "        train_loss_batch.append(float(loss_value))\n",
943 | "        if step % 1000 == 0:\n",
944 | "            print(\n",
945 | "                \"Training loss (for one batch) at step %d: %.4f\"\n",
946 | "                % (step, np.mean(train_loss_batch))\n",
947 | "            )\n",
948 | "            train_loss_batch = []\n",
949 | "            print(\"Seen so far: %d samples\" % (step + 1))\n",
950 | "    train_loss_list.append(np.mean(train_loss))\n",
951 | "    train_acc = train_acc_metric.result()\n",
952 | "    print(\"Training acc over epoch: %.4f\" % (float(train_acc),))\n",
953 | "    train_acc_list.append(float(train_acc))\n",
954 | "    train_acc_metric.reset_states()\n",
955 | "    print(\"Time taken: %.2fs\" % (time.time() - start_time))"
950 | ],
951 | "metadata": {
952 | "colab": {
953 | "base_uri": "https://localhost:8080/"
954 | },
955 | "id": "TycS5BAkXmGq",
956 | "outputId": "6610cbf8-d23f-48c6-bd31-818f4a6bd047"
957 | },
958 | "execution_count": null,
959 | "outputs": [
960 | {
961 | "output_type": "stream",
962 | "name": "stdout",
963 | "text": [
964 | "\n",
965 | "Start of epoch 0\n"
966 | ]
967 | },
968 | {
969 | "output_type": "stream",
970 | "name": "stderr",
971 | "text": [
972 | "1it [00:48, 48.86s/it]"
973 | ]
974 | },
975 | {
976 | "output_type": "stream",
977 | "name": "stdout",
978 | "text": [
979 | "Training loss (for one batch) at step 0: 1.3125\n",
980 | "Seen so far: 1 samples\n"
981 | ]
982 | },
983 | {
984 | "output_type": "stream",
985 | "name": "stderr",
986 | "text": [
987 | "4it [01:23, 16.33s/it]WARNING:tensorflow:5 out of the last 5 calls to triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has reduce_retracing=True option that can avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/guide/function#controlling_retracing and https://www.tensorflow.org/api_docs/python/tf/function for more details.\n",
988 | "5it [01:31, 13.49s/it]WARNING:tensorflow:6 out of the last 6 calls to triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has reduce_retracing=True option that can avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/guide/function#controlling_retracing and https://www.tensorflow.org/api_docs/python/tf/function for more details.\n",
989 | "1004it [08:09, 21.06it/s]"
990 | ]
991 | },
992 | {
993 | "output_type": "stream",
994 | "name": "stdout",
995 | "text": [
996 | "Training loss (for one batch) at step 1000: 1.0708\n",
997 | "Seen so far: 1001 samples\n"
998 | ]
999 | },
1000 | {
1001 | "output_type": "stream",
1002 | "name": "stderr",
1003 | "text": [
1004 | "2003it [09:04, 21.12it/s]"
1005 | ]
1006 | },
1007 | {
1008 | "output_type": "stream",
1009 | "name": "stdout",
1010 | "text": [
1011 | "Training loss (for one batch) at step 2000: 1.0809\n",
1012 | "Seen so far: 2001 samples\n"
1013 | ]
1014 | },
1015 | {
1016 | "output_type": "stream",
1017 | "name": "stderr",
1018 | "text": [
1019 | "3003it [10:05, 19.56it/s]"
1020 | ]
1021 | },
1022 | {
1023 | "output_type": "stream",
1024 | "name": "stdout",
1025 | "text": [
1026 | "Training loss (for one batch) at step 3000: 1.0715\n",
1027 | "Seen so far: 3001 samples\n"
1028 | ]
1029 | },
1030 | {
1031 | "output_type": "stream",
1032 | "name": "stderr",
1033 | "text": [
1034 | "4003it [11:07, 20.41it/s]"
1035 | ]
1036 | },
1037 | {
1038 | "output_type": "stream",
1039 | "name": "stdout",
1040 | "text": [
1041 | "Training loss (for one batch) at step 4000: 0.9544\n",
1042 | "Seen so far: 4001 samples\n"
1043 | ]
1044 | },
1045 | {
1046 | "output_type": "stream",
1047 | "name": "stderr",
1048 | "text": [
1049 | "5005it [12:19, 20.50it/s]"
1050 | ]
1051 | },
1052 | {
1053 | "output_type": "stream",
1054 | "name": "stdout",
1055 | "text": [
1056 | "Training loss (for one batch) at step 5000: 1.0672\n",
1057 | "Seen so far: 5001 samples\n"
1058 | ]
1059 | },
1060 | {
1061 | "output_type": "stream",
1062 | "name": "stderr",
1063 | "text": [
1064 | "5416it [12:41, 7.12it/s]\n"
1065 | ]
1066 | },
1067 | {
1068 | "output_type": "stream",
1069 | "name": "stdout",
1070 | "text": [
1071 | "Training acc over epoch: 0.8199\n",
1072 | "Time taken: 761.06s\n",
1073 | "\n",
1074 | "Start of epoch 1\n"
1075 | ]
1076 | },
1077 | {
1078 | "output_type": "stream",
1079 | "name": "stderr",
1080 | "text": [
1081 | "3it [00:00, 20.47it/s]"
1082 | ]
1083 | },
1084 | {
1085 | "output_type": "stream",
1086 | "name": "stdout",
1087 | "text": [
1088 | "Training loss (for one batch) at step 0: 1.3068\n",
1089 | "Seen so far: 1 samples\n"
1090 | ]
1091 | },
1092 | {
1093 | "output_type": "stream",
1094 | "name": "stderr",
1095 | "text": [
1096 | "1003it [00:55, 21.15it/s]"
1097 | ]
1098 | },
1099 | {
1100 | "output_type": "stream",
1101 | "name": "stdout",
1102 | "text": [
1103 | "Training loss (for one batch) at step 1000: 1.0671\n",
1104 | "Seen so far: 1001 samples\n"
1105 | ]
1106 | },
1107 | {
1108 | "output_type": "stream",
1109 | "name": "stderr",
1110 | "text": [
1111 | "2004it [01:50, 19.74it/s]"
1112 | ]
1113 | },
1114 | {
1115 | "output_type": "stream",
1116 | "name": "stdout",
1117 | "text": [
1118 | "Training loss (for one batch) at step 2000: 1.0768\n",
1119 | "Seen so far: 2001 samples\n"
1120 | ]
1121 | },
1122 | {
1123 | "output_type": "stream",
1124 | "name": "stderr",
1125 | "text": [
1126 | "3003it [02:43, 17.31it/s]"
1127 | ]
1128 | },
1129 | {
1130 | "output_type": "stream",
1131 | "name": "stdout",
1132 | "text": [
1133 | "Training loss (for one batch) at step 3000: 1.0724\n",
1134 | "Seen so far: 3001 samples\n"
1135 | ]
1136 | },
1137 | {
1138 | "output_type": "stream",
1139 | "name": "stderr",
1140 | "text": [
1141 | "4005it [03:39, 20.32it/s]"
1142 | ]
1143 | },
1144 | {
1145 | "output_type": "stream",
1146 | "name": "stdout",
1147 | "text": [
1148 | "Training loss (for one batch) at step 4000: 0.9605\n",
1149 | "Seen so far: 4001 samples\n"
1150 | ]
1151 | },
1152 | {
1153 | "output_type": "stream",
1154 | "name": "stderr",
1155 | "text": [
1156 | "5004it [04:33, 19.27it/s]"
1157 | ]
1158 | },
1159 | {
1160 | "output_type": "stream",
1161 | "name": "stdout",
1162 | "text": [
1163 | "Training loss (for one batch) at step 5000: 1.0558\n",
1164 | "Seen so far: 5001 samples\n"
1165 | ]
1166 | },
1167 | {
1168 | "output_type": "stream",
1169 | "name": "stderr",
1170 | "text": [
1171 | "5416it [04:55, 18.35it/s]"
1172 | ]
1173 | },
1174 | {
1175 | "output_type": "stream",
1176 | "name": "stdout",
1177 | "text": [
1178 | "Training acc over epoch: 0.8199\n",
1179 | "Time taken: 295.22s\n"
1180 | ]
1181 | },
1182 | {
1183 | "output_type": "stream",
1184 | "name": "stderr",
1185 | "text": [
1186 | "\n"
1187 | ]
1188 | }
1189 | ]
1190 | },
1191 | {
1192 | "cell_type": "code",
1193 | "source": [
1194 | "# Sample inference using the trained model.\n",
1195 | "txt = \"eu rejects german call to boycott british lamb\"\n",
1196 | "sample_input = tokenizer.tokenize_with_offsets(txt)[0]\n",
1197 | "\n",
1198 | "output = ner_model.predict(sample_input)\n",
1199 | "prediction = np.argmax(output, axis=-1)[0]\n",
1200 | "prediction = [mapping[i] for i in prediction]\n",
1201 | "\n",
1202 | "# Gold labels for this sentence: eu -> B-ORG, german -> B-MISC, british -> B-MISC\n",
1203 | "print(sample_input)\n",
1204 | "print(prediction)\n",
1205 | "for tok, pred in zip(txt.split(), prediction):\n",
1206 | "    print(tok, pred)"
1208 | ],
1209 | "metadata": {
1210 | "id": "mvUDhYbgfjYx",
1211 | "colab": {
1212 | "base_uri": "https://localhost:8080/"
1213 | },
1214 | "outputId": "7b744823-e14f-4bdd-d45b-a93269ea945e"
1215 | },
1216 | "execution_count": null,
1217 | "outputs": [
1218 | {
1219 | "output_type": "stream",
1220 | "name": "stdout",
1221 | "text": [
1222 | "1/1 [==============================] - 4s 4s/step\n",
1223 | "tf.Tensor([ 7327 19164 2446 2655 2000 17757 2329 12559], shape=(8,), dtype=int64)\n",
1224 | "['O', 'O', 'O', 'O', 'O', 'O', 'O', 'O']\n",
1225 | "eu O\n",
1226 | "rejects O\n",
1227 | "german O\n",
1228 | "call O\n",
1229 | "to O\n",
1230 | "boycott O\n",
1231 | "british O\n",
1232 | "lamb O\n"
1233 | ]
1234 | }
1235 | ]
1236 | },
1237 | {
1238 | "cell_type": "code",
1239 | "source": [],
1240 | "metadata": {
1241 | "id": "lAUNQQCVGAuI"
1242 | },
1243 | "execution_count": null,
1244 | "outputs": []
1245 | }
1246 | ]
1247 | }
--------------------------------------------------------------------------------