├── README.md ├── Fine Tune BERT for Text Classification with TensorFlow.pdf └── Copy_of_Fine_Tune_BERT_for_Text_Classification_with_TensorFlow.ipynb /README.md: -------------------------------------------------------------------------------- 1 | # Fine Tune BERT for Text Classification with TensorFlow 2 | 3 | Fine Tune BERT for Text Classification with TensorFlow Coursera Course 4 | -------------------------------------------------------------------------------- /Fine Tune BERT for Text Classification with TensorFlow.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tamanna18/Fine_Tune_BERT_for_Text_Classification_with_TensorFlow/HEAD/Fine Tune BERT for Text Classification with TensorFlow.pdf -------------------------------------------------------------------------------- /Copy_of_Fine_Tune_BERT_for_Text_Classification_with_TensorFlow.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": { 6 | "id": "view-in-github", 7 | "colab_type": "text" 8 | }, 9 | "source": [ 10 | "\"Open" 11 | ] 12 | }, 13 | { 14 | "cell_type": "markdown", 15 | "metadata": { 16 | "id": "zGCJYkQj_Uu2" 17 | }, 18 | "source": [ 19 | "

Fine-Tune BERT for Text Classification with TensorFlow

" 20 | ] 21 | }, 22 | { 23 | "cell_type": "markdown", 24 | "metadata": { 25 | "id": "4y2m1S6e12il" 26 | }, 27 | "source": [ 28 | "
\n", 29 | " \n", 30 | "

Figure 1: BERT Classification Model

\n", 31 | "
" 32 | ] 33 | }, 34 | { 35 | "cell_type": "markdown", 36 | "metadata": { 37 | "id": "eYYYWqWr_WCC" 38 | }, 39 | "source": [ 40 | "In this [project](https://www.coursera.org/projects/fine-tune-bert-tensorflow/), you will learn how to fine-tune a BERT model for text classification using TensorFlow and TF-Hub." 41 | ] 42 | }, 43 | { 44 | "cell_type": "markdown", 45 | "metadata": { 46 | "id": "5yQG5PCO_WFx" 47 | }, 48 | "source": [ 49 | "The pretrained BERT model used in this project is [available](https://tfhub.dev/tensorflow/bert_en_uncased_L-12_H-768_A-12/2) on [TensorFlow Hub](https://tfhub.dev/)." 50 | ] 51 | }, 52 | { 53 | "cell_type": "markdown", 54 | "metadata": { 55 | "id": "7pKNS21u_WJo" 56 | }, 57 | "source": [ 58 | "### Learning Objectives" 59 | ] 60 | }, 61 | { 62 | "cell_type": "markdown", 63 | "metadata": { 64 | "id": "_3NHSMXv_WMv" 65 | }, 66 | "source": [ 67 | "By the time you complete this project, you will be able to:\n", 68 | "\n", 69 | "- Build TensorFlow Input Pipelines for Text Data with the [`tf.data`](https://www.tensorflow.org/api_docs/python/tf/data) API\n", 70 | "- Tokenize and Preprocess Text for BERT\n", 71 | "- Fine-tune BERT for text classification with TensorFlow 2 and [TF Hub](https://tfhub.dev)" 72 | ] 73 | }, 74 | { 75 | "cell_type": "markdown", 76 | "metadata": { 77 | "id": "o6BEe-3-AVRQ" 78 | }, 79 | "source": [ 80 | "### Prerequisites" 81 | ] 82 | }, 83 | { 84 | "cell_type": "markdown", 85 | "metadata": { 86 | "id": "Sc9f-8rLAVUS" 87 | }, 88 | "source": [ 89 | "In order to be successful with this project, it is assumed you are:\n", 90 | "\n", 91 | "- Competent in the Python programming language\n", 92 | "- Familiar with deep learning for Natural Language Processing (NLP)\n", 93 | "- Familiar with TensorFlow, and its Keras API" 94 | ] 95 | }, 96 | { 97 | "cell_type": "markdown", 98 | "metadata": { 99 | "id": "MYXXV5n3Ab-4" 100 | }, 101 | "source": [ 102 | "### Contents" 103 | ] 104 | }, 105 | { 106 | "cell_type": "markdown", 107 | 
"metadata": { 108 | "id": "XhK-SYGyAjxe" 109 | }, 110 | "source": [ 111 | "This project/notebook consists of several Tasks.\n", 112 | "\n", 113 | "- **[Task 1]()**: Introduction to the Project.\n", 114 | "- **[Task 2]()**: Setup your TensorFlow and Colab Runtime\n", 115 | "- **[Task 3]()**: Download and Import the Quora Insincere Questions Dataset\n", 116 | "- **[Task 4]()**: Create tf.data.Datasets for Training and Evaluation\n", 117 | "- **[Task 5]()**: Download a Pre-trained BERT Model from TensorFlow Hub\n", 118 | "- **[Task 6]()**: Tokenize and Preprocess Text for BERT\n", 119 | "- **[Task 7]()**: Wrap a Python Function into a TensorFlow op for Eager Execution\n", 120 | "- **[Task 8]()**: Create a TensorFlow Input Pipeline with `tf.data`\n", 121 | "- **[Task 9]()**: Add a Classification Head to the BERT `hub.KerasLayer`\n", 122 | "- **[Task 10]()**: Fine-Tune BERT for Text Classification\n", 123 | "- **[Task 11]()**: Evaluate the BERT Text Classification Model" 124 | ] 125 | }, 126 | { 127 | "cell_type": "markdown", 128 | "metadata": { 129 | "id": "IaArqXjRAcBa" 130 | }, 131 | "source": [ 132 | "## Task 2: Setup your TensorFlow and Colab Runtime." 133 | ] 134 | }, 135 | { 136 | "cell_type": "markdown", 137 | "metadata": { 138 | "id": "GDDhjzZ5A4Q_" 139 | }, 140 | "source": [ 141 | "You will only be able to use the Colab Notebook after you save it to your Google Drive folder. Click on the File menu and select “Save a copy in Drive…\n", 142 | "\n", 143 | "![Copy to Drive](https://drive.google.com/uc?id=1CH3eDmuJL8WR0AP1r3UE6sOPuqq8_Wl7)\n" 144 | ] 145 | }, 146 | { 147 | "cell_type": "markdown", 148 | "metadata": { 149 | "id": "mpe6GhLuBJWB" 150 | }, 151 | "source": [ 152 | "### Check GPU Availability\n", 153 | "\n", 154 | "Check if your Colab notebook is configured to use Graphical Processing Units (GPUs). 
If zero GPUs are available, check if the Colab notebook is configured to use GPUs (Menu > Runtime > Change Runtime Type).\n", 155 | "\n", 156 | "![Hardware Accelerator Settings](https://drive.google.com/uc?id=1qrihuuMtvzXJHiRV8M7RngbxFYipXKQx)\n" 157 | ] 158 | }, 159 | { 160 | "cell_type": "code", 161 | "execution_count": null, 162 | "metadata": { 163 | "colab": { 164 | "base_uri": "https://localhost:8080/" 165 | }, 166 | "id": "8V9c8vzSL3aj", 167 | "outputId": "cd2c5f92-c098-4ede-81f2-fbeaaa08cef9" 168 | }, 169 | "outputs": [ 170 | { 171 | "name": "stdout", 172 | "output_type": "stream", 173 | "text": [ 174 | "Mon Jan 31 08:37:13 2022 \n", 175 | "+-----------------------------------------------------------------------------+\n", 176 | "| NVIDIA-SMI 495.46 Driver Version: 460.32.03 CUDA Version: 11.2 |\n", 177 | "|-------------------------------+----------------------+----------------------+\n", 178 | "| GPU Name Persistence-M| Bus-Id Disp.A | Volatile Uncorr. ECC |\n", 179 | "| Fan Temp Perf Pwr:Usage/Cap| Memory-Usage | GPU-Util Compute M. |\n", 180 | "| | | MIG M. 
|\n", 181 | "|===============================+======================+======================|\n", 182 | "| 0 Tesla K80 Off | 00000000:00:04.0 Off | 0 |\n", 183 | "| N/A 28C P8 28W / 149W | 0MiB / 11441MiB | 0% Default |\n", 184 | "| | | N/A |\n", 185 | "+-------------------------------+----------------------+----------------------+\n", 186 | " \n", 187 | "+-----------------------------------------------------------------------------+\n", 188 | "| Processes: |\n", 189 | "| GPU GI CI PID Type Process name GPU Memory |\n", 190 | "| ID ID Usage |\n", 191 | "|=============================================================================|\n", 192 | "| No running processes found |\n", 193 | "+-----------------------------------------------------------------------------+\n" 194 | ] 195 | } 196 | ], 197 | "source": [ 198 | "!nvidia-smi" 199 | ] 200 | }, 201 | { 202 | "cell_type": "markdown", 203 | "metadata": { 204 | "id": "Obch3rAuBVf0" 205 | }, 206 | "source": [ 207 | "### Install TensorFlow and TensorFlow Model Garden" 208 | ] 209 | }, 210 | { 211 | "cell_type": "code", 212 | "execution_count": null, 213 | "metadata": { 214 | "colab": { 215 | "base_uri": "https://localhost:8080/" 216 | }, 217 | "id": "bUQEY3dFB0jX", 218 | "outputId": "57e4e20c-9daa-4447-a9bd-c75a3942593f" 219 | }, 220 | "outputs": [ 221 | { 222 | "name": "stdout", 223 | "output_type": "stream", 224 | "text": [ 225 | "2.7.0\n" 226 | ] 227 | } 228 | ], 229 | "source": [ 230 | "import tensorflow as tf\n", 231 | "print(tf.version.VERSION)" 232 | ] 233 | }, 234 | { 235 | "cell_type": "code", 236 | "execution_count": null, 237 | "metadata": { 238 | "id": "aU3YLZ1TYKUt" 239 | }, 240 | "outputs": [], 241 | "source": [ 242 | "!pip install -q tensorflow==2.7.0" 243 | ] 244 | }, 245 | { 246 | "cell_type": "code", 247 | "execution_count": null, 248 | "metadata": { 249 | "colab": { 250 | "base_uri": "https://localhost:8080/" 251 | }, 252 | "id": "AFRTC-zwUy6D", 253 | "outputId": "5876b05d-cddc-458e-f04e-b87b8d913321" 
254 | }, 255 | "outputs": [ 256 | { 257 | "name": "stdout", 258 | "output_type": "stream", 259 | "text": [ 260 | "Cloning into 'models'...\n", 261 | "remote: Enumerating objects: 2650, done.\u001b[K\n", 262 | "remote: Counting objects: 100% (2650/2650), done.\u001b[K\n", 263 | "remote: Compressing objects: 100% (2311/2311), done.\u001b[K\n", 264 | "remote: Total 2650 (delta 506), reused 1388 (delta 306), pack-reused 0\u001b[K\n", 265 | "Receiving objects: 100% (2650/2650), 34.02 MiB | 18.55 MiB/s, done.\n", 266 | "Resolving deltas: 100% (506/506), done.\n", 267 | "Note: checking out '400d68abbccda2f0f6609e3a924467718b144233'.\n", 268 | "\n", 269 | "You are in 'detached HEAD' state. You can look around, make experimental\n", 270 | "changes and commit them, and you can discard any commits you make in this\n", 271 | "state without impacting any branches by performing another checkout.\n", 272 | "\n", 273 | "If you want to create a new branch to retain commits you create, you may\n", 274 | "do so (now or later) by using -b with the checkout command again. 
Example:\n", 275 | "\n", 276 | " git checkout -b \n", 277 | "\n" 278 | ] 279 | } 280 | ], 281 | "source": [ 282 | "!git clone --depth 1 -b v2.3.0 https://github.com/tensorflow/models.git" 283 | ] 284 | }, 285 | { 286 | "cell_type": "code", 287 | "execution_count": null, 288 | "metadata": { 289 | "colab": { 290 | "base_uri": "https://localhost:8080/" 291 | }, 292 | "id": "3H2G0571zLLs", 293 | "outputId": "86486174-5d18-463e-982e-13a6711c08a0" 294 | }, 295 | "outputs": [ 296 | { 297 | "name": "stdout", 298 | "output_type": "stream", 299 | "text": [ 300 | "\u001b[K |████████████████████████████████| 8.0 MB 11.6 MB/s \n", 301 | "\u001b[K |████████████████████████████████| 205 kB 53.5 MB/s \n", 302 | "\u001b[K |████████████████████████████████| 15.7 MB 48.8 MB/s \n", 303 | "\u001b[K |████████████████████████████████| 11.3 MB 40.3 MB/s \n", 304 | "\u001b[K |████████████████████████████████| 280 kB 44.4 MB/s \n", 305 | "\u001b[K |████████████████████████████████| 99 kB 9.2 MB/s \n", 306 | "\u001b[K |████████████████████████████████| 38.1 MB 1.5 MB/s \n", 307 | "\u001b[K |████████████████████████████████| 213 kB 54.1 MB/s \n", 308 | "\u001b[K |████████████████████████████████| 4.2 MB 44.7 MB/s \n", 309 | "\u001b[K |████████████████████████████████| 1.1 MB 42.1 MB/s \n", 310 | "\u001b[K |████████████████████████████████| 352 kB 51.0 MB/s \n", 311 | "\u001b[K |████████████████████████████████| 1.2 MB 44.0 MB/s \n", 312 | "\u001b[K |████████████████████████████████| 1.9 MB 40.6 MB/s \n", 313 | "\u001b[K |████████████████████████████████| 11.2 MB 38.8 MB/s \n", 314 | "\u001b[K |████████████████████████████████| 47.7 MB 88 kB/s \n", 315 | "\u001b[K |████████████████████████████████| 596 kB 41.3 MB/s \n", 316 | "\u001b[K |████████████████████████████████| 4.3 MB 41.4 MB/s \n", 317 | "\u001b[K |████████████████████████████████| 75 kB 4.2 MB/s \n", 318 | "\u001b[K |████████████████████████████████| 111 kB 52.7 MB/s \n", 319 | "\u001b[K |████████████████████████████████| 45 kB 3.2 
MB/s \n", 320 | "\u001b[K |████████████████████████████████| 895 kB 37.4 MB/s \n", 321 | "\u001b[K |████████████████████████████████| 1.1 MB 46.9 MB/s \n", 322 | "\u001b[?25h Building wheel for py-cpuinfo (setup.py) ... \u001b[?25l\u001b[?25hdone\n", 323 | "\u001b[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.\n", 324 | "yellowbrick 1.3.post1 requires numpy<1.20,>=1.16.0, but you have numpy 1.21.5 which is incompatible.\n", 325 | "pandas-gbq 0.13.3 requires google-cloud-bigquery[bqstorage,pandas]<2.0.0dev,>=1.11.1, but you have google-cloud-bigquery 2.32.0 which is incompatible.\n", 326 | "google-colab 1.0.0 requires pandas~=1.1.0; python_version >= \"3.0\", but you have pandas 1.3.5 which is incompatible.\n", 327 | "google-colab 1.0.0 requires six~=1.15.0, but you have six 1.16.0 which is incompatible.\n", 328 | "google-cloud-translate 1.5.0 requires google-api-core[grpc]<2.0.0dev,>=1.6.0, but you have google-api-core 2.4.0 which is incompatible.\n", 329 | "google-cloud-translate 1.5.0 requires google-cloud-core<2.0dev,>=1.0.0, but you have google-cloud-core 2.2.2 which is incompatible.\n", 330 | "google-cloud-storage 1.18.1 requires google-cloud-core<2.0dev,>=1.0.0, but you have google-cloud-core 2.2.2 which is incompatible.\n", 331 | "google-cloud-storage 1.18.1 requires google-resumable-media<0.5.0dev,>=0.3.1, but you have google-resumable-media 2.1.0 which is incompatible.\n", 332 | "google-cloud-language 1.2.0 requires google-api-core[grpc]<2.0.0dev,>=1.6.0, but you have google-api-core 2.4.0 which is incompatible.\n", 333 | "google-cloud-firestore 1.7.0 requires google-api-core[grpc]<2.0.0dev,>=1.14.0, but you have google-api-core 2.4.0 which is incompatible.\n", 334 | "google-cloud-firestore 1.7.0 requires google-cloud-core<2.0dev,>=1.0.3, but you have google-cloud-core 2.2.2 which is incompatible.\n", 335 | 
"google-cloud-datastore 1.8.0 requires google-api-core[grpc]<2.0.0dev,>=1.6.0, but you have google-api-core 2.4.0 which is incompatible.\n", 336 | "google-cloud-datastore 1.8.0 requires google-cloud-core<2.0dev,>=1.0.0, but you have google-cloud-core 2.2.2 which is incompatible.\n", 337 | "google-cloud-bigquery-storage 1.1.0 requires google-api-core[grpc]<2.0.0dev,>=1.14.0, but you have google-api-core 2.4.0 which is incompatible.\n", 338 | "firebase-admin 4.4.0 requires google-api-core[grpc]<2.0.0dev,>=1.14.0; platform_python_implementation != \"PyPy\", but you have google-api-core 2.4.0 which is incompatible.\n", 339 | "earthengine-api 0.1.295 requires google-api-python-client<2,>=1.12.1, but you have google-api-python-client 2.36.0 which is incompatible.\n", 340 | "datascience 0.10.6 requires folium==0.2.1, but you have folium 0.8.3 which is incompatible.\n", 341 | "albumentations 0.1.12 requires imgaug<0.2.7,>=0.2.5, but you have imgaug 0.2.9 which is incompatible.\u001b[0m\n" 342 | ] 343 | } 344 | ], 345 | "source": [ 346 | "# install requirements to use tensorflow/models repository\n", 347 | "!pip install -Uqr models/official/requirements.txt\n", 348 | "# you may have to restart the runtime afterwards" 349 | ] 350 | }, 351 | { 352 | "cell_type": "markdown", 353 | "metadata": { 354 | "id": "GVjksk4yCXur" 355 | }, 356 | "source": [ 357 | "## Restart the Runtime\n", 358 | "\n", 359 | "**Note** \n", 360 | "After installing the required Python packages, you'll need to restart the Colab Runtime Engine (Menu > Runtime > Restart runtime...)\n", 361 | "\n", 362 | "![Restart of the Colab Runtime Engine](https://drive.google.com/uc?id=1xnjAy2sxIymKhydkqb0RKzgVK9rh3teH)" 363 | ] 364 | }, 365 | { 366 | "cell_type": "markdown", 367 | "metadata": { 368 | "id": "IMsEoT3Fg4Wg" 369 | }, 370 | "source": [ 371 | "## Task 3: Download and Import the Quora Insincere Questions Dataset" 372 | ] 373 | }, 374 | { 375 | "cell_type": "code", 376 | "execution_count": null, 377 | 
"metadata": { 378 | "id": "GmqEylyFYTdP" 379 | }, 380 | "outputs": [], 381 | "source": [ 382 | "import numpy as np\n", 383 | "import tensorflow as tf\n", 384 | "import tensorflow_hub as hub\n", 385 | "import sys\n", 386 | "sys.path.append('models')\n", 387 | "from official.nlp.data import classifier_data_lib\n", 388 | "from official.nlp.bert import tokenization\n", 389 | "from official.nlp import optimization" 390 | ] 391 | }, 392 | { 393 | "cell_type": "code", 394 | "execution_count": null, 395 | "metadata": { 396 | "colab": { 397 | "base_uri": "https://localhost:8080/" 398 | }, 399 | "id": "ZuX1lB8pPJ-W", 400 | "outputId": "cc1902be-27e8-449e-9c3f-36b2328743c4" 401 | }, 402 | "outputs": [ 403 | { 404 | "name": "stdout", 405 | "output_type": "stream", 406 | "text": [ 407 | "TF Version: 2.7.0\n", 408 | "Eager mode: True\n", 409 | "Hub version: 0.12.0\n", 410 | "GPU is available\n" 411 | ] 412 | } 413 | ], 414 | "source": [ 415 | "print(\"TF Version: \", tf.__version__)\n", 416 | "print(\"Eager mode: \", tf.executing_eagerly())\n", 417 | "print(\"Hub version: \", hub.__version__)\n", 418 | "print(\"GPU is\", \"available\" if tf.config.experimental.list_physical_devices(\"GPU\") else \"NOT AVAILABLE\")" 419 | ] 420 | }, 421 | { 422 | "cell_type": "markdown", 423 | "metadata": { 424 | "id": "QtbwpWgyEZg7" 425 | }, 426 | "source": [ 427 | "A downloadable copy of the [Quora Insincere Questions Classification data](https://www.kaggle.com/c/quora-insincere-questions-classification/data) can be found [https://archive.org/download/fine-tune-bert-tensorflow-train.csv/train.csv.zip](https://archive.org/download/fine-tune-bert-tensorflow-train.csv/train.csv.zip). Decompress and read the data into a pandas DataFrame." 
428 | ] 429 | }, 430 | { 431 | "cell_type": "code", 432 | "execution_count": null, 433 | "metadata": { 434 | "colab": { 435 | "background_save": true 436 | }, 437 | "id": "0nI-9itVwCCQ", 438 | "outputId": "6ec77447-e72c-4529-9733-5e7342845940" 439 | }, 440 | "outputs": [ 441 | { 442 | "data": { 443 | "text/plain": [ 444 | "(1306122, 3)" 445 | ] 446 | }, 447 | "execution_count": null, 448 | "metadata": {}, 449 | "output_type": "execute_result" 450 | } 451 | ], 452 | "source": [ 453 | "import numpy as np\n", 454 | "import pandas as pd\n", 455 | "from sklearn.model_selection import train_test_split\n", 456 | "\n", 457 | "df = pd.read_csv ('https://archive.org/download/fine-tune-bert-tensorflow-train.csv/train.csv.zip',\n", 458 | " compression='zip', low_memory=False)\n", 459 | "df.shape" 460 | ] 461 | }, 462 | { 463 | "cell_type": "code", 464 | "execution_count": null, 465 | "metadata": { 466 | "colab": { 467 | "background_save": true 468 | }, 469 | "id": "yeHE98KiMvDd", 470 | "outputId": "b81bdd1e-ff0f-48c5-ad5b-0e1b8578921a" 471 | }, 472 | "outputs": [ 473 | { 474 | "data": { 475 | "text/html": [ 476 | "\n", 477 | "
\n", 478 | "
\n", 479 | "
\n", 480 | "\n", 493 | "\n", 494 | " \n", 495 | " \n", 496 | " \n", 497 | " \n", 498 | " \n", 499 | " \n", 500 | " \n", 501 | " \n", 502 | " \n", 503 | " \n", 504 | " \n", 505 | " \n", 506 | " \n", 507 | " \n", 508 | " \n", 509 | " \n", 510 | " \n", 511 | " \n", 512 | " \n", 513 | " \n", 514 | " \n", 515 | " \n", 516 | " \n", 517 | " \n", 518 | " \n", 519 | " \n", 520 | " \n", 521 | " \n", 522 | " \n", 523 | " \n", 524 | " \n", 525 | " \n", 526 | " \n", 527 | " \n", 528 | " \n", 529 | " \n", 530 | " \n", 531 | " \n", 532 | " \n", 533 | " \n", 534 | " \n", 535 | " \n", 536 | " \n", 537 | " \n", 538 | " \n", 539 | " \n", 540 | " \n", 541 | " \n", 542 | " \n", 543 | " \n", 544 | " \n", 545 | " \n", 546 | " \n", 547 | " \n", 548 | " \n", 549 | " \n", 550 | " \n", 551 | " \n", 552 | " \n", 553 | " \n", 554 | " \n", 555 | " \n", 556 | " \n", 557 | " \n", 558 | " \n", 559 | " \n", 560 | " \n", 561 | " \n", 562 | " \n", 563 | " \n", 564 | " \n", 565 | " \n", 566 | " \n", 567 | " \n", 568 | " \n", 569 | " \n", 570 | " \n", 571 | " \n", 572 | " \n", 573 | " \n", 574 | " \n", 575 | " \n", 576 | " \n", 577 | " \n", 578 | " \n", 579 | " \n", 580 | " \n", 581 | " \n", 582 | " \n", 583 | " \n", 584 | " \n", 585 | " \n", 586 | " \n", 587 | " \n", 588 | " \n", 589 | " \n", 590 | " \n", 591 | " \n", 592 | " \n", 593 | " \n", 594 | " \n", 595 | " \n", 596 | " \n", 597 | " \n", 598 | " \n", 599 | " \n", 600 | " \n", 601 | " \n", 602 | " \n", 603 | " \n", 604 | " \n", 605 | " \n", 606 | " \n", 607 | " \n", 608 | " \n", 609 | " \n", 610 | " \n", 611 | " \n", 612 | " \n", 613 | " \n", 614 | " \n", 615 | " \n", 616 | " \n", 617 | " \n", 618 | " \n", 619 | " \n", 620 | " \n", 621 | " \n", 622 | " \n", 623 | " \n", 624 | "
qidquestion_texttarget
1306102ffff3778790af9baae76What steps can I take to live a normal life if...0
1306103ffff3f0a2449ffe4b9ffIsn't Trump right after all? Why should the US...1
1306104ffff41393389d4206066Is 33 too late for a career in creative advert...0
1306105ffff42493fc203cd9532What is difference between the filteration wor...0
1306106ffff48dd47bee89fff79If the universe \"popped\" into existence from n...0
1306107ffff5fd051a032f32a39How does a shared service technology team meas...0
1306108ffff6d528040d3888b93How is DSATM civil engineering?0
1306109ffff8776cd30cdc8d7f8Do you know any problem that depends solely on...0
1306110ffff94d427ade3716cd1What are some comic ideas for you Tube videos ...0
1306111ffffa382c58368071dc9If you had $10 million of Bitcoin, could you s...0
1306112ffffa5b0fa76431c063fAre you ashamed of being an Indian?1
1306113ffffae5dbda3dc9e9771What are the methods to determine fossil ages ...0
1306114ffffba7c4888798571c1What is your story today?0
1306115ffffc0c7158658a06fd9How do I consume 150 gms protein daily both ve...0
1306116ffffc404da586ac5a08fWhat are the good career options for a msc che...0
1306117ffffcc4e2331aaf1e41eWhat other technical skills do you need as a c...0
1306118ffffd431801e5a2f4861Does MS in ECE have good job prospects in USA ...0
1306119ffffd48fb36b63db010cIs foam insulation toxic?0
1306120ffffec519fa37cf60c78How can one start a research project based on ...0
1306121ffffed09fedb5088744aWho wins in a battle between a Wolverine and a...0
\n", 625 | "
\n", 626 | " \n", 636 | " \n", 637 | " \n", 674 | "\n", 675 | " \n", 699 | "
\n", 700 | "
\n", 701 | " " 702 | ], 703 | "text/plain": [ 704 | " qid ... target\n", 705 | "1306102 ffff3778790af9baae76 ... 0\n", 706 | "1306103 ffff3f0a2449ffe4b9ff ... 1\n", 707 | "1306104 ffff41393389d4206066 ... 0\n", 708 | "1306105 ffff42493fc203cd9532 ... 0\n", 709 | "1306106 ffff48dd47bee89fff79 ... 0\n", 710 | "1306107 ffff5fd051a032f32a39 ... 0\n", 711 | "1306108 ffff6d528040d3888b93 ... 0\n", 712 | "1306109 ffff8776cd30cdc8d7f8 ... 0\n", 713 | "1306110 ffff94d427ade3716cd1 ... 0\n", 714 | "1306111 ffffa382c58368071dc9 ... 0\n", 715 | "1306112 ffffa5b0fa76431c063f ... 1\n", 716 | "1306113 ffffae5dbda3dc9e9771 ... 0\n", 717 | "1306114 ffffba7c4888798571c1 ... 0\n", 718 | "1306115 ffffc0c7158658a06fd9 ... 0\n", 719 | "1306116 ffffc404da586ac5a08f ... 0\n", 720 | "1306117 ffffcc4e2331aaf1e41e ... 0\n", 721 | "1306118 ffffd431801e5a2f4861 ... 0\n", 722 | "1306119 ffffd48fb36b63db010c ... 0\n", 723 | "1306120 ffffec519fa37cf60c78 ... 0\n", 724 | "1306121 ffffed09fedb5088744a ... 0\n", 725 | "\n", 726 | "[20 rows x 3 columns]" 727 | ] 728 | }, 729 | "execution_count": null, 730 | "metadata": {}, 731 | "output_type": "execute_result" 732 | } 733 | ], 734 | "source": [ 735 | "df.tail(20)" 736 | ] 737 | }, 738 | { 739 | "cell_type": "code", 740 | "execution_count": null, 741 | "metadata": { 742 | "colab": { 743 | "background_save": true 744 | }, 745 | "id": "leRFRWJMocVa", 746 | "outputId": "ff9da9d1-05d4-4d34-bcdf-3194a3ff0153" 747 | }, 748 | "outputs": [ 749 | { 750 | "data": { 751 | "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAYIAAAEICAYAAABS0fM3AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/YYfK9AAAACXBIWXMAAAsTAAALEwEAmpwYAAAWdklEQVR4nO3dfbRddX3n8fdHEBF5qiaOmIDxIShR6UivoOPqiJW2gMvQ1pYhBa0WiWPFNRWxUGuB0RkfaqsztliMVlGsPE7LiiWIg6UyVYKEikiiaIoRAlQiIKigGPjOH2fHOXO5N/eEe/c5nLvfr7XOYj/8zt7fX27I5+7fb599UlVIkrrrMaMuQJI0WgaBJHWcQSBJHWcQSFLHGQSS1HEGgSR1nEEgDSjJa5P8c9/6j5I8Y46O/fYkH2uWlySpJDvP0bH3a2rdaS6Op/nHINDQNf8obXs9lOT+vvVjh1TDoUk2z+YYVbV7Vd00F+epqndX1etnU0/fOTclOazv2Dc3tT44F8fX/DMnv3FIO6Kqdt+2nGQT8PqqunxHjpFk56raOte1jcJ86ovGk1cEetRIcnCSq5L8IMntSf4qyS59+yvJm5J8G/h2s+2Pmra3JXl90+ZZzb7HJfnzJDcn+V6Ss5I8PskTgEuBp/ZdiTx1inqelGR1knuTfAV45qT9/ec6MsmGJD9McmuSk6c7T5IzklyU5NNJ7gVe22z79KQSfr/p1+1JTu4779lJ/lvf+s+vOpKcA+wHfLY53x9NHmpqalid5K4kG5Oc0HesM5JckORTTV/WJ5nY8Z+mxolBoEeTB4G3AAuAFwMvB/5gUpvfAA4BliU5HDgJOAx4FnDopLbvBfYH/n2zfxFwWlX9GDgCuK0ZMtm9qm6bop4zgZ8A+wC/37ym8zfAG6pqD+B5wD/OcJ6jgIuAvYG/neaYLwOWAr8GnNI/3DOdqno1cDPwyuZ8fzZFs/OAzcBTgd8G3p3kV/r2L2/a7A2sBv5qpvNqvI1lECT5eJI7ktwwYPujm9/W1if5TNv16ZGpqmuram1Vba2qTcBHgJdOavaeqrqrqu4HjgY+UVXrq+o+4IxtjZIEWAm8pWn/Q+DdwDGD1NJMrL6KJjiq6gbgk9t5y8/ohdOeVXV3Vf3LDKe4qqourqqHmr5M5b825/468AlgxSC1b0+SfYGXAKdU1U+q6jrgY8Br+pr9c1WtaeYUzgF+cbbn1aPbWAYBcDZw+CANkywF/hh4SVU9F/jD9srSbCTZP8k/JPm3Zsjk3fSuDvrd0rf81Enr/csLgd2Aa5uhph8An2u2D2IhvTm0/mN+dzvtXwUcCXw3yReTvHiG498yw/7Jbb5Lr7+z9VRgWzD2H3tR3/q/9S3fB+w6V3cw6dFpLIOgqq4E7urfluSZST6X5Nok/yfJc5pdJwBnVtXdzXvvGHK5GtxfA98EllbVnsDbgUxq0/+43NuBxX3r+/Ytfx+4H3huVe3dvPbqm6ie6bG7W4Ctk46533SNq+qaqjoKeDJwMXDBDOcZ5LG/k8+9bVjpx/RCbpun7MCxbwOemGSPSce+dYB6NE+NZRBMYxXw5qr6JeBk4MPN9v2B/ZN8KcnaZlxZj057APcCP2qC/I0ztL8AeF2SA5LsBvzpth1V9RDwUeCDSZ4MkGRRkl9vmnwPeFKSvaY6cDMs8nfAGUl2S7IM+L2p2ibZJcmxSfaqqp81fXhokPPM4E+bcz8XeB1wfrP9OuDIJE9M8hQefpX7PWDKzzdU1S3Al4H3JNk1yYHA8cDkiWp1yLwIgiS7A/8BuDDJdfTGlvdpdu9Mb8LtUHpjrB9Nsvfwq9QATgZ+F/ghvX/Ez99e46q6FPgQcAWwEVjb7Ppp899Ttm1vhpouB57dvPebwLnATc3Q0VTDLicCu9MbKjmb3jj9dF4NbGrO85+BY3fgPNP5YlP/F4A/r6rPN9vPAb4GbAI+z8P/nN4DvKM538k83ApgCb2rg78HTt/R23c1v2R
cv5gmyRLgH6rqeUn2BG6sqn2maHcWcHVVfaJZ/wJwalVdM9SC1bokBwA3AI/zvnxpcPPiiqCq7gW+k+R3oHfHSJJtdzpcTHNbYZIF9IaKtvtpUI2PJL+Z3ucFfgF4H/BZQ0DaMWMZBEnOBa4Cnp1kc5Lj6V2KH5/ka8B6evdpA1wG3JlkA70hhLdV1Z2jqFuteANwB/Cv9D6HMNO8gqRJxnZoSJI0N8byikCSNHfG7kMiCxYsqCVLloy6DEkaK9dee+33q2rKD1SOXRAsWbKEdevWjboMSRorSab9ZLxDQ5LUcQaBJHWcQSBJHWcQSFLHGQSS1HEGgSR1nEEgSR1nEEhSxxkEktRxY/fJ4tlYcuolIzv3pve+YmTnlqTtae2KIMnHk9yR5IZp9h+b5PokX0/y5b7vD5AkDVGbQ0NnA9v7fuDvAC+tqucD76L3ncOSpCFrbWioqq5svk5yuv1f7ltdCyxuqxZJ0vQeLZPFxwOXTrczycok65Ks27JlyxDLkqT5b+RBkORl9ILglOnaVNWqqpqoqomFC6d8nLYk6REa6V1DSQ4EPgYc4fcIS9JojOyKIMl+wN8Br66qb42qDknqutauCJKcCxwKLEiyGTgdeCxAVZ0FnAY8CfhwEoCtVTXRVj2SpKm1edfQihn2vx54fVvnlyQNZuSTxZKk0TIIJKnjDAJJ6jiDQJI6ziCQpI4zCCSp4wwCSeo4g0CSOs4gkKSOMwgkqeMMAknqOINAkjrOIJCkjjMIJKnjDAJJ6jiDQJI6ziCQpI4zCCSp4wwCSeo4g0CSOs4gkKSOMwgkqeMMAknqOINAkjrOIJCkjmstCJJ8PMkdSW6YZn+SfCjJxiTXJzmorVokSdNr84rgbODw7ew/AljavFYCf91iLZKkabQWBFV1JXDXdpocBXyqetYCeyfZp616JElTG+UcwSLglr71zc22h0myMsm6JOu2bNkylOIkqSvGYrK4qlZV1URVTSxcuHDU5UjSvDLKILgV2LdvfXGzTZI0RKMMgtXAa5q7h14E3FNVt4+wHknqpJ3bOnCSc4FDgQVJNgOnA48FqKqzgDXAkcBG4D7gdW3VIkmaXmtBUFUrZthfwJvaOr8kaTBjMVksSWqPQSBJHWcQSFLHGQSS1HEGgSR1nEEgSR1nEEhSxxkEktRxBoEkdZxBIEkdZxBIUscZBJLUcQaBJHWcQSBJHWcQSFLHGQSS1HEGgSR1nEEgSR1nEEhSxxkEktRxBoEkdZxBIEkdZxBIUscZBJLUca0GQZLDk9yYZGOSU6fYv1+SK5J8Ncn1SY5ssx5J0sO1FgRJdgLOBI4AlgErkiyb1OwdwAVV9QLgGODDbdUjSZpam1cEBwMbq+qmqnoAOA84alKbAvZslvcCbmuxHknSFHZu8diLgFv61jcDh0xqcwbw+SRvBp4AHNZiPZKkKYx6sngFcHZVLQaOBM5J8rCakqxMsi7Jui1btgy9SEmaz9oMgluBffvWFzfb+h0PXABQVVcBuwILJh+oqlZV1URVTSxcuLClciWpm9oMgmuApUmenmQXepPBqye1uRl4OUCSA+gFgb/yS9IQDRQESZ6/oweuqq3AicBlwDfo3R20Psk7kyxvmr0VOCHJ14BzgddWVe3ouSRJj9ygk8UfTvI44Gzgb6vqnkHeVFVrgDWTtp3Wt7wBeMmANUiSWjDQFUFV/TJwLL0x/2uTfCbJr7ZamSRpKAaeI6iqb9P7ANgpwEuBDyX5ZpLfaqs4SVL7Bp0jODDJB+mN9f8K8MqqOqBZ/mCL9UmSWjboHMFfAh8D3l5V92/bWFW3JXlHK5VJkoZi0CB4BXB/VT0I0Hzoa9equq+qzmmtOklS6wadI7gceHzf+m7NNknSmBs0CHatqh9tW2mWd2unJEnSMA0aBD9OctC2lSS/BNy/nfaSpDEx6BzBHwIXJrkNCPAU4D+1VZQkaXgGCoKquibJc4BnN5turKqftVeWJGlYduT7CF4ILGn
ec1ASqupTrVQlSRqagYIgyTnAM4HrgAebzQUYBJI05ga9IpgAlvlkUEmafwa9a+gGehPEkqR5ZtArggXAhiRfAX66bWNVLZ/+LZKkcTBoEJzRZhGSpNEZ9PbRLyZ5GrC0qi5PshuwU7ulSZKGYdDHUJ8AXAR8pNm0CLi4pZokSUM06GTxm+h9peS98PMvqXlyW0VJkoZn0CD4aVU9sG0lyc70PkcgSRpzgwbBF5O8HXh8813FFwKfba8sSdKwDBoEpwJbgK8DbwDW0Pv+YknSmBv0rqGHgI82L0nSPDLos4a+wxRzAlX1jDmvSJI0VDvyrKFtdgV+B3ji3JcjSRq2geYIqurOvtetVfU/6H2h/XYlOTzJjUk2Jjl1mjZHJ9mQZH2Sz+xY+ZKk2Rp0aOigvtXH0LtC2O57k+wEnAn8KrAZuCbJ6qra0NdmKfDHwEuq6u4kfjZBkoZs0KGhv+hb3gpsAo6e4T0HAxur6iaAJOcBRwEb+tqcAJxZVXcDVNUdA9YjSZojg9419LJHcOxFwC1965uBQya12R8gyZfoPbvojKr63OQDJVkJrATYb7/9HkEpkqTpDDo0dNL29lfVB2Zx/qXAocBi4Mokz6+qH0w6/ipgFcDExISfaJakObQjdw29EFjdrL8S+Arw7e2851Zg3771xc22fpuBq6vqZ8B3knyLXjBcM2BdkqRZGjQIFgMHVdUPAZKcAVxSVcdt5z3XAEuTPJ1eABwD/O6kNhcDK4BPJFlAb6jopoGrlyTN2qCPmPh3wAN96w8026ZVVVuBE4HLgG8AF1TV+iTvTLLtm80uA+5MsgG4AnhbVd25Ix2QJM3OoFcEnwK+kuTvm/XfAD4505uqag295xL1bzutb7mAk5qXJGkEBr1r6L8nuRT45WbT66rqq+2VJUkalkGHhgB2A+6tqv8JbG7G/iVJY27Qr6o8HTiF3qeAAR4LfLqtoiRJwzPoFcFvAsuBHwNU1W3AHm0VJUkankGD4IFmYrcAkjyhvZIkScM0aBBckOQjwN5JTgAuxy+pkaR5Yca7hpIEOB94DnAv8GzgtKr63y3XJkkaghmDoKoqyZqqej7gP/6SNM8MOjT0L0le2GolkqSRGPSTxYcAxyXZRO/OodC7WDiwrcIkScMx07eM7VdVNwO/PqR6JElDNtMVwcX0njr63ST/q6peNYSaJElDNNMcQfqWn9FmIZKk0ZgpCGqaZUnSPDHT0NAvJrmX3pXB45tl+H+TxXu2Wp0kqXXbDYKq2mlYhUiSRmNHHkMtSZqHDAJJ6jiDQJI6ziCQpI4zCCSp4wwCSeo4g0CSOs4gkKSOMwgkqeNaDYIkhye5McnGJKdup92rklSSiTbrkSQ9XGtBkGQn4EzgCGAZsCLJsina7QH8F+DqtmqRJE2vzSuCg4GNVXVTVT0AnAccNUW7dwHvA37SYi2SpGm0GQSLgFv61jc3234uyUHAvlV1yfYOlGRlknVJ1m3ZsmXuK5WkDhvZZHGSxwAfAN46U9uqWlVVE1U1sXDhwvaLk6QOaTMIbgX27Vtf3GzbZg/gecA/JdkEvAhY7YSxJA1Xm0FwDbA0ydOT7AIcA6zetrOq7qmqBVW1pKqWAGuB5VW1rsWaJEmTtBYEVbUVOBG4DPgGcEFVrU/yziTL2zqvJGnHzPRVlbNSVWuANZO2nTZN20PbrEWSNDU/WSxJHWcQSFLHGQSS1HEGgSR1nEEgSR1nEEhSxxkEktRxBoEkdZxBIEkdZxBIUscZBJLUcQaBJHWcQSBJHWcQSFLHGQSS1HEGgSR1nEEgSR1nEEhSxxkEktRxBoEkdZxBIEkdZxBIUscZBJLUcQaBJHWcQSBJHddqECQ5PMmNSTYmOXWK/Scl2ZDk+iRfSPK0NuuRJD1ca0GQZCfgTOAIYBmwIsmySc2+CkxU1YHARcCftVWPJGlqbV4RHAxsrKqbquoB4DzgqP4GVXVFVd3XrK4FFrdYjyRpCm0GwSLglr71zc2
26RwPXDrVjiQrk6xLsm7Lli1zWKIk6VExWZzkOGACeP9U+6tqVVVNVNXEwoULh1ucJM1zO7d47FuBffvWFzfb/j9JDgP+BHhpVf20xXokSVNo84rgGmBpkqcn2QU4Bljd3yDJC4CPAMur6o4Wa5EkTaO1IKiqrcCJwGXAN4ALqmp9kncmWd40ez+wO3BhkuuSrJ7mcJKklrQ5NERVrQHWTNp2Wt/yYW2eX5I0s0fFZLEkaXQMAknqOINAkjrOIJCkjjMIJKnjDAJJ6jiDQJI6ziCQpI4zCCSp4wwCSeo4g0CSOs4gkKSOMwgkqeMMAknquFYfQy1J882SUy8Z2bk3vfcVrRzXKwJJ6jiDQJI6ziCQpI4zCCSp4wwCSeo4g0CSOs4gkKSOMwgkqeMMAknqOINAkjqu1SBIcniSG5NsTHLqFPsfl+T8Zv/VSZa0WY8k6eFaC4IkOwFnAkcAy4AVSZZNanY8cHdVPQv4IPC+tuqRJE2tzSuCg4GNVXVTVT0AnAccNanNUcAnm+WLgJcnSYs1SZImafPpo4uAW/rWNwOHTNemqrYmuQd4EvD9/kZJVgIrm9UfJbnxEda0YPKxhyWju9YZWZ9HyD53Q+f6nPfNqs9Pm27HWDyGuqpWAatme5wk66pqYg5KGhv2uRvscze01ec2h4ZuBfbtW1/cbJuyTZKdgb2AO1usSZI0SZtBcA2wNMnTk+wCHAOsntRmNfB7zfJvA/9YVdViTZKkSVobGmrG/E8ELgN2Aj5eVeuTvBNYV1Wrgb8BzkmyEbiLXli0adbDS2PIPneDfe6GVvocfwGXpG7zk8WS1HEGgSR13LwMgi4+2mKAPp+UZEOS65N8Icm09xSPi5n63NfuVUkqydjfajhIn5Mc3fys1yf5zLBrnGsD/N3eL8kVSb7a/P0+chR1zpUkH09yR5IbptmfJB9q/jyuT3LQrE9aVfPqRW9i+l+BZwC7AF8Dlk1q8wfAWc3yMcD5o657CH1+GbBbs/zGLvS5abcHcCWwFpgYdd1D+DkvBb4K/EKz/uRR1z2EPq8C3tgsLwM2jbruWfb5PwIHATdMs/9I4FIgwIuAq2d7zvl4RdDFR1vM2OequqKq7mtW19L7XMc4G+TnDPAues+w+skwi2vJIH0+ATizqu4GqKo7hlzjXBukzwXs2SzvBdw2xPrmXFVdSe8uyukcBXyqetYCeyfZZzbnnI9BMNWjLRZN16aqtgLbHm0xrgbpc7/j6f1GMc5m7HNzybxvVV0yzMJaNMjPeX9g/yRfSrI2yeFDq64dg/T5DOC4JJuBNcCbh1PayOzo/+8zGotHTGjuJDkOmABeOupa2pTkMcAHgNeOuJRh25ne8NCh9K76rkzy/Kr6wSiLatkK4Oyq+oskL6b32aTnVdVDoy5sXMzHK4IuPtpikD6T5DDgT4DlVfXTIdXWlpn6vAfwPOCfkmyiN5a6eswnjAf5OW8GVlfVz6rqO8C36AXDuBqkz8cDFwBU1VXArvQeSDdfDfT/+46Yj0HQxUdbzNjnJC8APkIvBMZ93Bhm6HNV3VNVC6pqSVUtoTcvsryq1o2m3DkxyN/ti+ldDZBkAb2hopuGWONcG6TPNwMvB0hyAL0g2DLUKodrNfCa5u6hFwH3VNXtszngvBsaqkfnoy1aNWCf3w/sDlzYzIvfXFXLR1b0LA3Y53llwD5fBvxakg3Ag8Dbqmpsr3YH7PNbgY8meQu9iePXjvMvdknOpRfmC5p5j9OBxwJU1Vn05kGOBDYC9wGvm/U5x/jPS5I0B+bj0JAkaQcYBJLUcQaBJHWcQSBJHWcQSFLHGQSS1HEGgSR13P8FdJXGehzQddcAAAAASUVORK5CYII=\n", 752 | "text/plain": [ 753 | "
" 754 | ] 755 | }, 756 | "metadata": {}, 757 | "output_type": "display_data" 758 | } 759 | ], 760 | "source": [ 761 | "df.target.plot(kind='hist' , title='Target distribution');" 762 | ] 763 | }, 764 | { 765 | "cell_type": "markdown", 766 | "metadata": { 767 | "id": "ELjswHcFHfp3" 768 | }, 769 | "source": [ 770 | "## Task 4: Create tf.data.Datasets for Training and Evaluation" 771 | ] 772 | }, 773 | { 774 | "cell_type": "code", 775 | "execution_count": null, 776 | "metadata": { 777 | "colab": { 778 | "background_save": true 779 | }, 780 | "id": "fScULIGPwuWk", 781 | "outputId": "ee4db1e9-b6a4-43f9-f861-d2fbf0e7cc94" 782 | }, 783 | "outputs": [ 784 | { 785 | "data": { 786 | "text/plain": [ 787 | "((9795, 3), (972, 3))" 788 | ] 789 | }, 790 | "execution_count": null, 791 | "metadata": {}, 792 | "output_type": "execute_result" 793 | } 794 | ], 795 | "source": [ 796 | "train_df, remaining = train_test_split(df, random_state=42, train_size=0.0075, stratify=df.target.values)\n", 797 | "valid_df, _ = train_test_split(remaining, random_state=42, train_size=0.00075, stratify=remaining.target.values)\n", 798 | "train_df.shape, valid_df.shape" 799 | ] 800 | }, 801 | { 802 | "cell_type": "code", 803 | "execution_count": null, 804 | "metadata": { 805 | "colab": { 806 | "background_save": true 807 | }, 808 | "id": "qQYMGT5_qLPX", 809 | "outputId": "f8ae5afe-b422-4b2b-a41f-3a8e73104022" 810 | }, 811 | "outputs": [ 812 | { 813 | "name": "stdout", 814 | "output_type": "stream", 815 | "text": [ 816 | "tf.Tensor(b'Why are unhealthy relationships so desirable?', shape=(), dtype=string)\n", 817 | "tf.Tensor(0, shape=(), dtype=int64)\n" 818 | ] 819 | } 820 | ], 821 | "source": [ 822 | "with tf.device('/cpu:0'):\n", 823 | " train_data = tf.data.Dataset.from_tensor_slices((train_df['question_text'].values, train_df['target'].values))\n", 824 | " valid_data = tf.data.Dataset.from_tensor_slices((valid_df.question_text.values, valid_df.target.values))\n", 825 | "\n", 826 | "for text, label 
in train_data.take(1):\n", 827 | " print(text)\n", 828 | " print(label)" 829 | ] 830 | }, 831 | { 832 | "cell_type": "markdown", 833 | "metadata": { 834 | "id": "e2-ReN88Hvy_" 835 | }, 836 | "source": [ 837 | "## Task 5: Download a Pre-trained BERT Model from TensorFlow Hub" 838 | ] 839 | }, 840 | { 841 | "cell_type": "code", 842 | "execution_count": null, 843 | "metadata": { 844 | "colab": { 845 | "background_save": true 846 | }, 847 | "id": "EMb5M86b4-BU" 848 | }, 849 | "outputs": [], 850 | "source": [ 851 | "\"\"\"\n", 852 | "Each line of the dataset is composed of the review text and its label\n", 853 | "- Data preprocessing consists of transforming text to BERT input features:\n", 854 | "input_word_ids, input_mask, segment_ids\n", 855 | "- In the process, tokenizing the text is done with the provided BERT model tokenizer\n", 856 | "\"\"\"\n", 857 | "\n", 858 | "label_list = [0,1]\n", 859 | " # Label categories\n", 860 | "max_seq_length = 128 # maximum length of (token) input sequences\n", 861 | "train_batch_size = 32\n", 862 | "\n", 863 | "\n", 864 | "\n", 865 | "# Get BERT layer and tokenizer:\n", 866 | "# More details here: https://tfhub.dev/tensorflow/bert_en_uncased_L-12_H-768_A-12/2\n", 867 | "bert_layer = hub.KerasLayer(\"https://tfhub.dev/tensorflow/bert_en_uncased_L-12_H-768_A-12/2\", trainable=True)\n", 868 | "\n", 869 | "vocab_file = bert_layer.resolved_object.vocab_file.asset_path.numpy()\n", 870 | "do_lower_case = bert_layer.resolved_object.do_lower_case.numpy()\n", 871 | "tokenizer = tokenization.FullTokenizer(vocab_file, do_lower_case)\n" 872 | ] 873 | }, 874 | { 875 | "cell_type": "code", 876 | "execution_count": null, 877 | "metadata": { 878 | "colab": { 879 | "background_save": true 880 | }, 881 | "id": "wEUezMK-zkkI", 882 | "outputId": "33d3e568-0785-4fe2-e5af-9813f0bb3a90" 883 | }, 884 | "outputs": [ 885 | { 886 | "data": { 887 | "text/plain": [ 888 | "['hi', '##,', 'how', 'are', 'you', '##?']" 889 | ] 890 | }, 891 | "execution_count": null, 
892 | "metadata": {}, 893 | "output_type": "execute_result" 894 | } 895 | ], 896 | "source": [ 897 | "tokenizer.wordpiece_tokenizer.tokenize('hi, how are you?')" 898 | ] 899 | }, 900 | { 901 | "cell_type": "code", 902 | "execution_count": null, 903 | "metadata": { 904 | "colab": { 905 | "background_save": true 906 | }, 907 | "id": "5AFsmTO5JSmc", 908 | "outputId": "8056be9c-7510-4c90-b000-9b31f14631d5" 909 | }, 910 | "outputs": [ 911 | { 912 | "data": { 913 | "text/plain": [ 914 | "[7632, 29623, 2129, 2024, 2017, 29632]" 915 | ] 916 | }, 917 | "execution_count": null, 918 | "metadata": {}, 919 | "output_type": "execute_result" 920 | } 921 | ], 922 | "source": [ 923 | "tokenizer.convert_tokens_to_ids(tokenizer.wordpiece_tokenizer.tokenize('hi, how are you?'))" 924 | ] 925 | }, 926 | { 927 | "cell_type": "markdown", 928 | "metadata": { 929 | "id": "9QinzNq6OsP1" 930 | }, 931 | "source": [ 932 | "## Task 6: Tokenize and Preprocess Text for BERT" 933 | ] 934 | }, 935 | { 936 | "cell_type": "markdown", 937 | "metadata": { 938 | "id": "3FTqJ698zZ1e" 939 | }, 940 | "source": [ 941 | "
\n", 942 | " \n", 943 | "

Figure 2: BERT Tokenizer

\n", 944 | "
" 945 | ] 946 | }, 947 | { 948 | "cell_type": "markdown", 949 | "metadata": { 950 | "id": "cWYkggYe6HZc" 951 | }, 952 | "source": [ 953 | "We'll need to transform our data into a format BERT understands. This involves two steps. First, we create InputExamples using `classifier_data_lib`'s constructor `InputExample` provided in the BERT library." 954 | ] 955 | }, 956 | { 957 | "cell_type": "code", 958 | "execution_count": null, 959 | "metadata": { 960 | "id": "m-21A5aNJM0W" 961 | }, 962 | "outputs": [], 963 | "source": [ 964 | "# This provides a function to convert row to input features and label\n", 965 | "\n", 966 | "def to_feature(text, label, label_list=label_list, max_seq_length=max_seq_length, tokenizer=tokenizer):\n", 967 | " example = classifier_data_lib.InputExample(guid=None,\n", 968 | " text_a = text.numpy(),\n", 969 | " text_b = None,\n", 970 | " label =label.numpy())\n", 971 | " feature = classifier_data_lib.convert_single_example(0, example, label_list, max_seq_length, tokenizer)\n", 972 | "\n", 973 | " return (feature.input_word_ids, feature.input_mask, feature.segment_ids, feature.label_id)\n", 974 | " \n", 975 | "\n", 976 | "\n", 977 | "\n" 978 | ] 979 | }, 980 | { 981 | "cell_type": "markdown", 982 | "metadata": { 983 | "id": "A_HQSsHwWCsK" 984 | }, 985 | "source": [ 986 | "You want to use [`Dataset.map`](https://www.tensorflow.org/api_docs/python/tf/data/Dataset#map) to apply this function to each element of the dataset. [`Dataset.map`](https://www.tensorflow.org/api_docs/python/tf/data/Dataset#map) runs in graph mode.\n", 987 | "\n", 988 | "- Graph tensors do not have a value.\n", 989 | "- In graph mode you can only use TensorFlow Ops and functions.\n", 990 | "\n", 991 | "So you can't `.map` this function directly: You need to wrap it in a [`tf.py_function`](https://www.tensorflow.org/api_docs/python/tf/py_function). 
The [`tf.py_function`](https://www.tensorflow.org/api_docs/python/tf/py_function) will pass regular tensors (with a value and a `.numpy()` method to access it), to the wrapped python function." 992 | ] 993 | }, 994 | { 995 | "cell_type": "markdown", 996 | "metadata": { 997 | "id": "zaNlkKVfWX0Q" 998 | }, 999 | "source": [ 1000 | "## Task 7: Wrap a Python Function into a TensorFlow op for Eager Execution" 1001 | ] 1002 | }, 1003 | { 1004 | "cell_type": "code", 1005 | "execution_count": null, 1006 | "metadata": { 1007 | "id": "AGACBcfCWC2O" 1008 | }, 1009 | "outputs": [], 1010 | "source": [ 1011 | "def to_feature_map(text, label):\n", 1012 | " input_ids, input_mask, segment_ids, label_id =tf.py_function(to_feature, inp=[text,label],\n", 1013 | " Tout=[tf.int32, tf.int32, tf.int32, tf.int32])\n", 1014 | " input_ids.set_shape([max_seq_length])\n", 1015 | " input_mask.set_shape([max_seq_length])\n", 1016 | " segment_ids.set_shape([max_seq_length])\n", 1017 | " label_id.set_shape([])\n", 1018 | "\n", 1019 | " x={\n", 1020 | " 'input_word_ids': input_ids,\n", 1021 | " 'input_mask': input_mask,\n", 1022 | " 'input_type_ids': segment_ids\n", 1023 | " }\n", 1024 | " \n", 1025 | " return(x, label_id)\n", 1026 | " " 1027 | ] 1028 | }, 1029 | { 1030 | "cell_type": "markdown", 1031 | "metadata": { 1032 | "id": "dhdO6MjTbtn1" 1033 | }, 1034 | "source": [ 1035 | "## Task 8: Create a TensorFlow Input Pipeline with `tf.data`" 1036 | ] 1037 | }, 1038 | { 1039 | "cell_type": "code", 1040 | "execution_count": null, 1041 | "metadata": { 1042 | "colab": { 1043 | "base_uri": "https://localhost:8080/", 1044 | "height": 503 1045 | }, 1046 | "id": "LHRdiO3dnPNr", 1047 | "outputId": "0b99b546-1352-4ea4-94ff-3adb447e6ee5" 1048 | }, 1049 | "outputs": [ 1050 | { 1051 | "ename": "TypeError", 1052 | "evalue": "ignored", 1053 | "output_type": "error", 1054 | "traceback": [ 1055 | "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", 1056 | 
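Stripped of the TensorFlow plumbing, `classifier_data_lib.convert_single_example` produces three fixed-length vectors per sentence: token ids for `[CLS] ... [SEP]` padded with zeros to `max_seq_length`, a mask that is 1 over real tokens, and all-zero segment ids (since `text_b` is `None`, there is only one segment). A pure-Python sketch under those assumptions follows; `toy_ids` is a hypothetical lookup table, with 7632/29623 being the ids printed above for `'hi'`/`'##,'` and 101/102 BERT's standard `[CLS]`/`[SEP]` ids.

```python
# Pure-Python sketch of the features convert_single_example builds for BERT.
# toy_ids is a hypothetical id lookup, not the real 30k-entry BERT vocab.
def to_bert_features(tokens, toy_ids, max_seq_length=128):
    tokens = ["[CLS]"] + tokens[: max_seq_length - 2] + ["[SEP]"]
    input_word_ids = [toy_ids[t] for t in tokens]
    input_mask = [1] * len(input_word_ids)       # 1 = real token, 0 = padding
    pad = max_seq_length - len(input_word_ids)
    input_word_ids += [0] * pad
    input_mask += [0] * pad
    segment_ids = [0] * max_seq_length           # single segment: all zeros
    return input_word_ids, input_mask, segment_ids

toy_ids = {"[CLS]": 101, "[SEP]": 102, "hi": 7632, "##,": 29623}
ids, mask, seg = to_bert_features(["hi", "##,"], toy_ids)
print(ids[:4], sum(mask))
# [101, 7632, 29623, 102] 4
```

These three vectors are exactly the `input_word_ids`, `input_mask`, and `input_type_ids` keys that `to_feature_map` packs into the feature dict.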
"\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)", 1069 | "\u001b[0;31mTypeError\u001b[0m: in user code:\n\n File \"\", line 2, in to_feature_map *\n input_ids, input_mask, segment_ids, label_id =tf.py_function(to_feature, inp=[text,label],\n\n TypeError: Tensors in list passed to 'input' of 'EagerPyFunc' Op have types [, int32] that are invalid. 
Tensors: [{'input_word_ids': , 'input_mask': , 'input_type_ids': }, ]\n" 1070 | ] 1071 | } 1072 | ], 1073 | "source": [ 1074 | "from tensorflow.python.data.ops.dataset_ops import AUTOTUNE\n", 1075 | "with tf.device('/cpu:0'):\n", 1076 | " # train\n", 1077 | " train_data = (train_data.map(to_feature_map,\n", 1078 | " num_parallel_calls=tf.data.experimental.AUTOTUNE)\n", 1079 | " .shuffle(1000)\n", 1080 | " .batch(32, drop_remainder=True)\n", 1081 | " .prefetch(tf.data.experimental.AUTOTUNE))\n", 1082 | "\n", 1083 | " # valid\n", 1084 | " valid_data = (valid_data.map(to_feature_map,\n", 1085 | " num_parallel_calls=tf.data.experimental.AUTOTUNE)\n", 1086 | " .batch(32, drop_remainder=True)\n", 1087 | " .prefetch(tf.data.experimental.AUTOTUNE))\n", 1088 | " " 1089 | ] 1090 | }, 1091 | { 1092 | "cell_type": "markdown", 1093 | "metadata": { 1094 | "id": "KLUWnfx-YDi2" 1095 | }, 1096 | "source": [ 1097 | "The resulting `tf.data.Datasets` return `(features, labels)` pairs, as expected by [`keras.Model.fit`](https://www.tensorflow.org/api_docs/python/tf/keras/Model#fit):" 1098 | ] 1099 | }, 1100 | { 1101 | "cell_type": "code", 1102 | "execution_count": null, 1103 | "metadata": { 1104 | "id": "B0Z2cy9GHQ8x" 1105 | }, 1106 | "outputs": [], 1107 | "source": [ 1108 | "# train data spec\n", 1109 | "train_data.element_spec\n" 1110 | ] 1111 | }, 1112 | { 1113 | "cell_type": "code", 1114 | "execution_count": null, 1115 | "metadata": { 1116 | "id": "DGAH-ycYOmao" 1117 | }, 1118 | "outputs": [], 1119 | "source": [ 1120 | "# valid data spec\n", 1121 | "valid_data.element_spec\n" 1122 | ] 1123 | }, 1124 | { 1125 | "cell_type": "markdown", 1126 | "metadata": { 1127 | "id": "GZxe-7yhPyQe" 1128 | }, 1129 | "source": [ 1130 | "## Task 9: Add a Classification Head to the BERT Layer" 1131 | ] 1132 | }, 1133 | { 1134 | "cell_type": "markdown", 1135 | "metadata": { 1136 | "id": "9THH5V0Dw2HO" 1137 | }, 1138 | "source": [ 1139 | "
\n", 1140 | " \n", 1141 | "

Figure 3: BERT Layer

\n", 1142 | "
" 1143 | ] 1144 | }, 1145 | { 1146 | "cell_type": "code", 1147 | "execution_count": null, 1148 | "metadata": { 1149 | "id": "G9il4gtlADcp" 1150 | }, 1151 | "outputs": [], 1152 | "source": [ 1153 | "# Building the model\n", 1154 | "def create_model():\n", 1155 | " input_word_ids = tf.keras.layers.Input(shape=(max_seq_length,), dtype=tf.int32,\n", 1156 | " name=\"input_word_ids\")\n", 1157 | " input_mask = tf.keras.layers.Input(shape=(max_seq_length,), dtype=tf.int32,\n", 1158 | " name=\"input_mask\")\n", 1159 | " input_type_ids = tf.keras.layers.Input(shape=(max_seq_length,), dtype=tf.int32,\n", 1160 | " name=\"input_type_ids\")\n", 1161 | " pooled_output, sequence_output = bert_layer([input_word_ids, input_mask, input_type_ids])\n", 1162 | "\n", 1163 | " drop = tf.keras.layers.Dropout(0.4)(pooled_output)\n", 1164 | " output = tf.keras.layers.Dense(1, activation='sigmoid', name=\"output\")(drop)\n", 1165 | "\n", 1166 | " model = tf.keras.Model(\n", 1167 | " inputs={\n", 1168 | " 'input_word_ids': input_word_ids,\n", 1169 | " 'input_mask': input_mask,\n", 1170 | " 'input_type_ids': input_type_ids\n", 1171 | " },\n", 1172 | " outputs=output)\n", 1173 | " return model" 1174 | ] 1175 | }, 1176 | { 1177 | "cell_type": "markdown", 1178 | "metadata": { 1179 | "id": "S6maM-vr7YaJ" 1180 | }, 1181 | "source": [ 1182 | "## Task 10: Fine-Tune BERT for Text Classification" 1183 | ] 1184 | }, 1185 | { 1186 | "cell_type": "code", 1187 | "execution_count": null, 1188 | "metadata": { 1189 | "id": "ptCtiiONsBgo" 1190 | }, 1191 | "outputs": [], 1192 | "source": [ 1193 | "model = create_model()\n", 1194 | "model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=2e-5),\n", 1195 | " loss=tf.keras.losses.BinaryCrossentropy(),\n", 1196 | " metrics=[tf.keras.metrics.BinaryAccuracy()])\n", 1197 | "model.summary()\n" 1198 | ] 1199 | }, 1200 | { 1201 | "cell_type": "code", 1202 | "execution_count": null, 1203 | "metadata": { 1204 | "id": "6GJaFnkbMtPL" 1205 | }, 1206 | "outputs": [], 
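With a single sigmoid unit on `pooled_output`, the `BinaryCrossentropy` loss wired up in `model.compile` evaluates `-(y*log(p) + (1-y)*log(1-p))` per example. The following is a stand-alone sketch of that computation in plain Python, not the Keras implementation; the logit value 2.0 is an arbitrary illustration.

```python
import math

# One example of the loss model.compile configures: a sigmoid squashes the
# logit into a probability, and binary cross-entropy scores it against the
# 0/1 target (1 = insincere question in this dataset).
def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def binary_crossentropy(y_true, p):
    eps = 1e-7                      # clip like Keras does to avoid log(0)
    p = min(max(p, eps), 1.0 - eps)
    return -(y_true * math.log(p) + (1 - y_true) * math.log(1.0 - p))

p = sigmoid(2.0)                    # a confident "insincere" prediction, ~0.88
loss_if_right = binary_crossentropy(1, p)
loss_if_wrong = binary_crossentropy(0, p)
```

The asymmetry is the point: the same confident prediction costs little when correct and a lot when wrong, which is what pushes the fine-tuned head toward calibrated probabilities.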
1207 | "source": [ 1208 | "tf.keras.utils.plot_model(model=model, show_shapes=True, dpi=76)" 1209 | ] 1210 | }, 1211 | { 1212 | "cell_type": "code", 1213 | "execution_count": null, 1214 | "metadata": { 1215 | "id": "OcREcgPUHr9O" 1216 | }, 1217 | "outputs": [], 1218 | "source": [ 1219 | "# Train model\n", 1220 | "epochs = 4\n", 1221 | "history = model.fit(train_data,\n", 1222 | " validation_data=valid_data,\n", 1223 | " epochs=epochs,\n", 1224 | " verbose=1)\n" 1225 | ] 1226 | }, 1227 | { 1228 | "cell_type": "markdown", 1229 | "metadata": { 1230 | "id": "kNZl1lx_cA5Y" 1231 | }, 1232 | "source": [ 1233 | "## Task 11: Evaluate the BERT Text Classification Model" 1234 | ] 1235 | }, 1236 | { 1237 | "cell_type": "code", 1238 | "execution_count": null, 1239 | "metadata": { 1240 | "id": "dCjgrUYH_IsE" 1241 | }, 1242 | "outputs": [], 1243 | "source": [ 1244 | "import matplotlib.pyplot as plt\n", 1245 | "\n", 1246 | "def plot_graphs(history, metric):\n", 1247 | " plt.plot(history.history[metric])\n", 1248 | " plt.plot(history.history['val_'+metric], '')\n", 1249 | " plt.xlabel(\"Epochs\")\n", 1250 | " plt.ylabel(metric)\n", 1251 | " plt.legend([metric, 'val_'+metric])\n", 1252 | " plt.show()" 1253 | ] 1254 | }, 1255 | { 1256 | "cell_type": "code", 1257 | "execution_count": null, 1258 | "metadata": { 1259 | "id": "v6lrFRra_KmA" 1260 | }, 1261 | "outputs": [], 1262 | "source": [ 1263 | "plot_graphs(history, 'loss')" 1264 | ] 1265 | }, 1266 | { 1267 | "cell_type": "code", 1268 | "execution_count": null, 1269 | "metadata": { 1270 | "id": "opu9neBA_98R" 1271 | }, 1272 | "outputs": [], 1273 | "source": [ 1274 | "plot_graphs(history, 'binary_accuracy')" 1275 | ] 1276 | }, 1277 | { 1278 | "cell_type": "code", 1279 | "execution_count": null, 1280 | "metadata": { 1281 | "id": "hkhtCCgnUbY6" 1282 | }, 1283 | "outputs": [], 1284 | "source": [ 1285 | "" 1286 | ] 1287 | }, 1288 | { 1289 | "cell_type": "code", 1290 | "execution_count": null, 1291 | "metadata": { 1292 | "id": 
"K4B8NQBLd9rN" 1293 | }, 1294 | "outputs": [], 1295 | "source": [ 1296 | "" 1297 | ] 1298 | } 1365 | ], 1366 | "metadata": { 1367 | "accelerator": "GPU", 1368 | "colab": { 1369 | "collapsed_sections": [], 1370 | "machine_shape": "hm", 1371 | "name": "Copy of Fine-Tune-BERT-for-Text-Classification-with-TensorFlow.ipynb", 1372 | "provenance": [], 1373 | "include_colab_link": true 1374 | }, 1375 | "kernelspec": { 1376 | "display_name": "Python 3", 1377 | "name": "python3" 1378 | } 1379 | }, 1380 | "nbformat": 4, 1381 | "nbformat_minor": 0 1382 | } --------------------------------------------------------------------------------