├── README.md └── Sentiment_Analysis_Series_part_1.ipynb /README.md: -------------------------------------------------------------------------------- 1 | # Sentiment-classification-using-XLNet 2 | Fine-tuning the XLNet model for sentiment classification. 3 | Find my Medium article on sentiment analysis using XLNet here 4 | 5 | About XLNet: 6 | 11 | 12 | The notebook is divided into the following parts: 13 |
    14 |
  1. Install and import all the dependencies required to get the code working.
    15 |
  2. Prepare the data.
    16 |
  3. Write functions to perform the training step and evaluation.
    17 |
  4. Fine-tune the XLNet model.
    18 |
  5. Evaluate the performance of the model.
    19 |
  6. Make predictions on raw text.
    20 |
21 | 22 | Results: 23 | 26 | -------------------------------------------------------------------------------- /Sentiment_Analysis_Series_part_1.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "colab": { 6 | "name": "Sentiment Analysis Series part-1.ipynb", 7 | "provenance": [], 8 | "collapsed_sections": [], 9 | "authorship_tag": "ABX9TyNxtXs+TX1Ira4fPz7vjXW1", 10 | "include_colab_link": true 11 | }, 12 | "kernelspec": { 13 | "name": "python3", 14 | "display_name": "Python 3" 15 | }, 16 | "accelerator": "GPU", 17 | "widgets": { 18 | "application/vnd.jupyter.widget-state+json": { 19 | "bd7d4d34bd6d46ae985741702b456aaf": { 20 | "model_module": "@jupyter-widgets/controls", 21 | "model_name": "HBoxModel", 22 | "state": { 23 | "_view_name": "HBoxView", 24 | "_dom_classes": [], 25 | "_model_name": "HBoxModel", 26 | "_view_module": "@jupyter-widgets/controls", 27 | "_model_module_version": "1.5.0", 28 | "_view_count": null, 29 | "_view_module_version": "1.5.0", 30 | "box_style": "", 31 | "layout": "IPY_MODEL_e8135d40c9c4460d9ce1eb0885c562e4", 32 | "_model_module": "@jupyter-widgets/controls", 33 | "children": [ 34 | "IPY_MODEL_0954f506d85d4723a3b32b6606db0015", 35 | "IPY_MODEL_297ea592ec1b44558448159860cd1aa0" 36 | ] 37 | } 38 | }, 39 | "e8135d40c9c4460d9ce1eb0885c562e4": { 40 | "model_module": "@jupyter-widgets/base", 41 | "model_name": "LayoutModel", 42 | "state": { 43 | "_view_name": "LayoutView", 44 | "grid_template_rows": null, 45 | "right": null, 46 | "justify_content": null, 47 | "_view_module": "@jupyter-widgets/base", 48 | "overflow": null, 49 | "_model_module_version": "1.2.0", 50 | "_view_count": null, 51 | "flex_flow": null, 52 | "width": null, 53 | "min_width": null, 54 | "border": null, 55 | "align_items": null, 56 | "bottom": null, 57 | "_model_module": "@jupyter-widgets/base", 58 | "top": null, 59 | "grid_column": null, 60 | "overflow_y": null, 61 
| "overflow_x": null, 62 | "grid_auto_flow": null, 63 | "grid_area": null, 64 | "grid_template_columns": null, 65 | "flex": null, 66 | "_model_name": "LayoutModel", 67 | "justify_items": null, 68 | "grid_row": null, 69 | "max_height": null, 70 | "align_content": null, 71 | "visibility": null, 72 | "align_self": null, 73 | "height": null, 74 | "min_height": null, 75 | "padding": null, 76 | "grid_auto_rows": null, 77 | "grid_gap": null, 78 | "max_width": null, 79 | "order": null, 80 | "_view_module_version": "1.2.0", 81 | "grid_template_areas": null, 82 | "object_position": null, 83 | "object_fit": null, 84 | "grid_auto_columns": null, 85 | "margin": null, 86 | "display": null, 87 | "left": null 88 | } 89 | }, 90 | "0954f506d85d4723a3b32b6606db0015": { 91 | "model_module": "@jupyter-widgets/controls", 92 | "model_name": "FloatProgressModel", 93 | "state": { 94 | "_view_name": "ProgressView", 95 | "style": "IPY_MODEL_ed27989647d140eca7d3bf745145a3a0", 96 | "_dom_classes": [], 97 | "description": "Downloading: 100%", 98 | "_model_name": "FloatProgressModel", 99 | "bar_style": "success", 100 | "max": 798011, 101 | "_view_module": "@jupyter-widgets/controls", 102 | "_model_module_version": "1.5.0", 103 | "value": 798011, 104 | "_view_count": null, 105 | "_view_module_version": "1.5.0", 106 | "orientation": "horizontal", 107 | "min": 0, 108 | "description_tooltip": null, 109 | "_model_module": "@jupyter-widgets/controls", 110 | "layout": "IPY_MODEL_f97490f8dee84ccc8029e401abe1042e" 111 | } 112 | }, 113 | "297ea592ec1b44558448159860cd1aa0": { 114 | "model_module": "@jupyter-widgets/controls", 115 | "model_name": "HTMLModel", 116 | "state": { 117 | "_view_name": "HTMLView", 118 | "style": "IPY_MODEL_8d146a11180d4d918cbfc302fc7d2971", 119 | "_dom_classes": [], 120 | "description": "", 121 | "_model_name": "HTMLModel", 122 | "placeholder": "​", 123 | "_view_module": "@jupyter-widgets/controls", 124 | "_model_module_version": "1.5.0", 125 | "value": " 798k/798k [00:37<00:00, 
21.2kB/s]", 126 | "_view_count": null, 127 | "_view_module_version": "1.5.0", 128 | "description_tooltip": null, 129 | "_model_module": "@jupyter-widgets/controls", 130 | "layout": "IPY_MODEL_73e86ec287c541d3a189620dd50c33b5" 131 | } 132 | }, 133 | "ed27989647d140eca7d3bf745145a3a0": { 134 | "model_module": "@jupyter-widgets/controls", 135 | "model_name": "ProgressStyleModel", 136 | "state": { 137 | "_view_name": "StyleView", 138 | "_model_name": "ProgressStyleModel", 139 | "description_width": "initial", 140 | "_view_module": "@jupyter-widgets/base", 141 | "_model_module_version": "1.5.0", 142 | "_view_count": null, 143 | "_view_module_version": "1.2.0", 144 | "bar_color": null, 145 | "_model_module": "@jupyter-widgets/controls" 146 | } 147 | }, 148 | "f97490f8dee84ccc8029e401abe1042e": { 149 | "model_module": "@jupyter-widgets/base", 150 | "model_name": "LayoutModel", 151 | "state": { 152 | "_view_name": "LayoutView", 153 | "grid_template_rows": null, 154 | "right": null, 155 | "justify_content": null, 156 | "_view_module": "@jupyter-widgets/base", 157 | "overflow": null, 158 | "_model_module_version": "1.2.0", 159 | "_view_count": null, 160 | "flex_flow": null, 161 | "width": null, 162 | "min_width": null, 163 | "border": null, 164 | "align_items": null, 165 | "bottom": null, 166 | "_model_module": "@jupyter-widgets/base", 167 | "top": null, 168 | "grid_column": null, 169 | "overflow_y": null, 170 | "overflow_x": null, 171 | "grid_auto_flow": null, 172 | "grid_area": null, 173 | "grid_template_columns": null, 174 | "flex": null, 175 | "_model_name": "LayoutModel", 176 | "justify_items": null, 177 | "grid_row": null, 178 | "max_height": null, 179 | "align_content": null, 180 | "visibility": null, 181 | "align_self": null, 182 | "height": null, 183 | "min_height": null, 184 | "padding": null, 185 | "grid_auto_rows": null, 186 | "grid_gap": null, 187 | "max_width": null, 188 | "order": null, 189 | "_view_module_version": "1.2.0", 190 | "grid_template_areas": null, 
191 | "object_position": null, 192 | "object_fit": null, 193 | "grid_auto_columns": null, 194 | "margin": null, 195 | "display": null, 196 | "left": null 197 | } 198 | }, 199 | "8d146a11180d4d918cbfc302fc7d2971": { 200 | "model_module": "@jupyter-widgets/controls", 201 | "model_name": "DescriptionStyleModel", 202 | "state": { 203 | "_view_name": "StyleView", 204 | "_model_name": "DescriptionStyleModel", 205 | "description_width": "", 206 | "_view_module": "@jupyter-widgets/base", 207 | "_model_module_version": "1.5.0", 208 | "_view_count": null, 209 | "_view_module_version": "1.2.0", 210 | "_model_module": "@jupyter-widgets/controls" 211 | } 212 | }, 213 | "73e86ec287c541d3a189620dd50c33b5": { 214 | "model_module": "@jupyter-widgets/base", 215 | "model_name": "LayoutModel", 216 | "state": { 217 | "_view_name": "LayoutView", 218 | "grid_template_rows": null, 219 | "right": null, 220 | "justify_content": null, 221 | "_view_module": "@jupyter-widgets/base", 222 | "overflow": null, 223 | "_model_module_version": "1.2.0", 224 | "_view_count": null, 225 | "flex_flow": null, 226 | "width": null, 227 | "min_width": null, 228 | "border": null, 229 | "align_items": null, 230 | "bottom": null, 231 | "_model_module": "@jupyter-widgets/base", 232 | "top": null, 233 | "grid_column": null, 234 | "overflow_y": null, 235 | "overflow_x": null, 236 | "grid_auto_flow": null, 237 | "grid_area": null, 238 | "grid_template_columns": null, 239 | "flex": null, 240 | "_model_name": "LayoutModel", 241 | "justify_items": null, 242 | "grid_row": null, 243 | "max_height": null, 244 | "align_content": null, 245 | "visibility": null, 246 | "align_self": null, 247 | "height": null, 248 | "min_height": null, 249 | "padding": null, 250 | "grid_auto_rows": null, 251 | "grid_gap": null, 252 | "max_width": null, 253 | "order": null, 254 | "_view_module_version": "1.2.0", 255 | "grid_template_areas": null, 256 | "object_position": null, 257 | "object_fit": null, 258 | "grid_auto_columns": null, 259 | 
"margin": null, 260 | "display": null, 261 | "left": null 262 | } 263 | }, 264 | "c2b26508fa464328b2b30c851a0366fe": { 265 | "model_module": "@jupyter-widgets/controls", 266 | "model_name": "HBoxModel", 267 | "state": { 268 | "_view_name": "HBoxView", 269 | "_dom_classes": [], 270 | "_model_name": "HBoxModel", 271 | "_view_module": "@jupyter-widgets/controls", 272 | "_model_module_version": "1.5.0", 273 | "_view_count": null, 274 | "_view_module_version": "1.5.0", 275 | "box_style": "", 276 | "layout": "IPY_MODEL_18e0ec87d07d4c0bb3930b5cef204963", 277 | "_model_module": "@jupyter-widgets/controls", 278 | "children": [ 279 | "IPY_MODEL_bac82e5dc1974ef4ac18f743ecfb72b5", 280 | "IPY_MODEL_35d3d0c2b7004a6b87420e55a94ddcaf" 281 | ] 282 | } 283 | }, 284 | "18e0ec87d07d4c0bb3930b5cef204963": { 285 | "model_module": "@jupyter-widgets/base", 286 | "model_name": "LayoutModel", 287 | "state": { 288 | "_view_name": "LayoutView", 289 | "grid_template_rows": null, 290 | "right": null, 291 | "justify_content": null, 292 | "_view_module": "@jupyter-widgets/base", 293 | "overflow": null, 294 | "_model_module_version": "1.2.0", 295 | "_view_count": null, 296 | "flex_flow": null, 297 | "width": null, 298 | "min_width": null, 299 | "border": null, 300 | "align_items": null, 301 | "bottom": null, 302 | "_model_module": "@jupyter-widgets/base", 303 | "top": null, 304 | "grid_column": null, 305 | "overflow_y": null, 306 | "overflow_x": null, 307 | "grid_auto_flow": null, 308 | "grid_area": null, 309 | "grid_template_columns": null, 310 | "flex": null, 311 | "_model_name": "LayoutModel", 312 | "justify_items": null, 313 | "grid_row": null, 314 | "max_height": null, 315 | "align_content": null, 316 | "visibility": null, 317 | "align_self": null, 318 | "height": null, 319 | "min_height": null, 320 | "padding": null, 321 | "grid_auto_rows": null, 322 | "grid_gap": null, 323 | "max_width": null, 324 | "order": null, 325 | "_view_module_version": "1.2.0", 326 | "grid_template_areas": null, 327 
| "object_position": null, 328 | "object_fit": null, 329 | "grid_auto_columns": null, 330 | "margin": null, 331 | "display": null, 332 | "left": null 333 | } 334 | }, 335 | "bac82e5dc1974ef4ac18f743ecfb72b5": { 336 | "model_module": "@jupyter-widgets/controls", 337 | "model_name": "FloatProgressModel", 338 | "state": { 339 | "_view_name": "ProgressView", 340 | "style": "IPY_MODEL_d8893bea6edb4154b051b344cae1636c", 341 | "_dom_classes": [], 342 | "description": "Downloading: 100%", 343 | "_model_name": "FloatProgressModel", 344 | "bar_style": "success", 345 | "max": 760, 346 | "_view_module": "@jupyter-widgets/controls", 347 | "_model_module_version": "1.5.0", 348 | "value": 760, 349 | "_view_count": null, 350 | "_view_module_version": "1.5.0", 351 | "orientation": "horizontal", 352 | "min": 0, 353 | "description_tooltip": null, 354 | "_model_module": "@jupyter-widgets/controls", 355 | "layout": "IPY_MODEL_d8c49d3421e843a6a2aadd20565ea5d2" 356 | } 357 | }, 358 | "35d3d0c2b7004a6b87420e55a94ddcaf": { 359 | "model_module": "@jupyter-widgets/controls", 360 | "model_name": "HTMLModel", 361 | "state": { 362 | "_view_name": "HTMLView", 363 | "style": "IPY_MODEL_3524b098b1944e5ba752f1fbfcb50aef", 364 | "_dom_classes": [], 365 | "description": "", 366 | "_model_name": "HTMLModel", 367 | "placeholder": "​", 368 | "_view_module": "@jupyter-widgets/controls", 369 | "_model_module_version": "1.5.0", 370 | "value": " 760/760 [00:42<00:00, 17.9B/s]", 371 | "_view_count": null, 372 | "_view_module_version": "1.5.0", 373 | "description_tooltip": null, 374 | "_model_module": "@jupyter-widgets/controls", 375 | "layout": "IPY_MODEL_0cba445897514d4b98192e28890df304" 376 | } 377 | }, 378 | "d8893bea6edb4154b051b344cae1636c": { 379 | "model_module": "@jupyter-widgets/controls", 380 | "model_name": "ProgressStyleModel", 381 | "state": { 382 | "_view_name": "StyleView", 383 | "_model_name": "ProgressStyleModel", 384 | "description_width": "initial", 385 | "_view_module": 
"@jupyter-widgets/base", 386 | "_model_module_version": "1.5.0", 387 | "_view_count": null, 388 | "_view_module_version": "1.2.0", 389 | "bar_color": null, 390 | "_model_module": "@jupyter-widgets/controls" 391 | } 392 | }, 393 | "d8c49d3421e843a6a2aadd20565ea5d2": { 394 | "model_module": "@jupyter-widgets/base", 395 | "model_name": "LayoutModel", 396 | "state": { 397 | "_view_name": "LayoutView", 398 | "grid_template_rows": null, 399 | "right": null, 400 | "justify_content": null, 401 | "_view_module": "@jupyter-widgets/base", 402 | "overflow": null, 403 | "_model_module_version": "1.2.0", 404 | "_view_count": null, 405 | "flex_flow": null, 406 | "width": null, 407 | "min_width": null, 408 | "border": null, 409 | "align_items": null, 410 | "bottom": null, 411 | "_model_module": "@jupyter-widgets/base", 412 | "top": null, 413 | "grid_column": null, 414 | "overflow_y": null, 415 | "overflow_x": null, 416 | "grid_auto_flow": null, 417 | "grid_area": null, 418 | "grid_template_columns": null, 419 | "flex": null, 420 | "_model_name": "LayoutModel", 421 | "justify_items": null, 422 | "grid_row": null, 423 | "max_height": null, 424 | "align_content": null, 425 | "visibility": null, 426 | "align_self": null, 427 | "height": null, 428 | "min_height": null, 429 | "padding": null, 430 | "grid_auto_rows": null, 431 | "grid_gap": null, 432 | "max_width": null, 433 | "order": null, 434 | "_view_module_version": "1.2.0", 435 | "grid_template_areas": null, 436 | "object_position": null, 437 | "object_fit": null, 438 | "grid_auto_columns": null, 439 | "margin": null, 440 | "display": null, 441 | "left": null 442 | } 443 | }, 444 | "3524b098b1944e5ba752f1fbfcb50aef": { 445 | "model_module": "@jupyter-widgets/controls", 446 | "model_name": "DescriptionStyleModel", 447 | "state": { 448 | "_view_name": "StyleView", 449 | "_model_name": "DescriptionStyleModel", 450 | "description_width": "", 451 | "_view_module": "@jupyter-widgets/base", 452 | "_model_module_version": "1.5.0", 453 | 
"_view_count": null, 454 | "_view_module_version": "1.2.0", 455 | "_model_module": "@jupyter-widgets/controls" 456 | } 457 | }, 458 | "0cba445897514d4b98192e28890df304": { 459 | "model_module": "@jupyter-widgets/base", 460 | "model_name": "LayoutModel", 461 | "state": { 462 | "_view_name": "LayoutView", 463 | "grid_template_rows": null, 464 | "right": null, 465 | "justify_content": null, 466 | "_view_module": "@jupyter-widgets/base", 467 | "overflow": null, 468 | "_model_module_version": "1.2.0", 469 | "_view_count": null, 470 | "flex_flow": null, 471 | "width": null, 472 | "min_width": null, 473 | "border": null, 474 | "align_items": null, 475 | "bottom": null, 476 | "_model_module": "@jupyter-widgets/base", 477 | "top": null, 478 | "grid_column": null, 479 | "overflow_y": null, 480 | "overflow_x": null, 481 | "grid_auto_flow": null, 482 | "grid_area": null, 483 | "grid_template_columns": null, 484 | "flex": null, 485 | "_model_name": "LayoutModel", 486 | "justify_items": null, 487 | "grid_row": null, 488 | "max_height": null, 489 | "align_content": null, 490 | "visibility": null, 491 | "align_self": null, 492 | "height": null, 493 | "min_height": null, 494 | "padding": null, 495 | "grid_auto_rows": null, 496 | "grid_gap": null, 497 | "max_width": null, 498 | "order": null, 499 | "_view_module_version": "1.2.0", 500 | "grid_template_areas": null, 501 | "object_position": null, 502 | "object_fit": null, 503 | "grid_auto_columns": null, 504 | "margin": null, 505 | "display": null, 506 | "left": null 507 | } 508 | }, 509 | "62d0e9673815468ba407019771a0b2ba": { 510 | "model_module": "@jupyter-widgets/controls", 511 | "model_name": "HBoxModel", 512 | "state": { 513 | "_view_name": "HBoxView", 514 | "_dom_classes": [], 515 | "_model_name": "HBoxModel", 516 | "_view_module": "@jupyter-widgets/controls", 517 | "_model_module_version": "1.5.0", 518 | "_view_count": null, 519 | "_view_module_version": "1.5.0", 520 | "box_style": "", 521 | "layout": 
"IPY_MODEL_6af011b92bd3462b937cea74a2094579", 522 | "_model_module": "@jupyter-widgets/controls", 523 | "children": [ 524 | "IPY_MODEL_47a41f0a294f42cbb2c9ec3916ca100e", 525 | "IPY_MODEL_3d9a5c47f9fb44ffa41f91cc37b8ee6e" 526 | ] 527 | } 528 | }, 529 | "6af011b92bd3462b937cea74a2094579": { 530 | "model_module": "@jupyter-widgets/base", 531 | "model_name": "LayoutModel", 532 | "state": { 533 | "_view_name": "LayoutView", 534 | "grid_template_rows": null, 535 | "right": null, 536 | "justify_content": null, 537 | "_view_module": "@jupyter-widgets/base", 538 | "overflow": null, 539 | "_model_module_version": "1.2.0", 540 | "_view_count": null, 541 | "flex_flow": null, 542 | "width": null, 543 | "min_width": null, 544 | "border": null, 545 | "align_items": null, 546 | "bottom": null, 547 | "_model_module": "@jupyter-widgets/base", 548 | "top": null, 549 | "grid_column": null, 550 | "overflow_y": null, 551 | "overflow_x": null, 552 | "grid_auto_flow": null, 553 | "grid_area": null, 554 | "grid_template_columns": null, 555 | "flex": null, 556 | "_model_name": "LayoutModel", 557 | "justify_items": null, 558 | "grid_row": null, 559 | "max_height": null, 560 | "align_content": null, 561 | "visibility": null, 562 | "align_self": null, 563 | "height": null, 564 | "min_height": null, 565 | "padding": null, 566 | "grid_auto_rows": null, 567 | "grid_gap": null, 568 | "max_width": null, 569 | "order": null, 570 | "_view_module_version": "1.2.0", 571 | "grid_template_areas": null, 572 | "object_position": null, 573 | "object_fit": null, 574 | "grid_auto_columns": null, 575 | "margin": null, 576 | "display": null, 577 | "left": null 578 | } 579 | }, 580 | "47a41f0a294f42cbb2c9ec3916ca100e": { 581 | "model_module": "@jupyter-widgets/controls", 582 | "model_name": "FloatProgressModel", 583 | "state": { 584 | "_view_name": "ProgressView", 585 | "style": "IPY_MODEL_767ef584a8ed49b4ae8ee7835208fb44", 586 | "_dom_classes": [], 587 | "description": "Downloading: 100%", 588 | "_model_name": 
"FloatProgressModel", 589 | "bar_style": "success", 590 | "max": 467042463, 591 | "_view_module": "@jupyter-widgets/controls", 592 | "_model_module_version": "1.5.0", 593 | "value": 467042463, 594 | "_view_count": null, 595 | "_view_module_version": "1.5.0", 596 | "orientation": "horizontal", 597 | "min": 0, 598 | "description_tooltip": null, 599 | "_model_module": "@jupyter-widgets/controls", 600 | "layout": "IPY_MODEL_769eba56a67d4d71b76f1240fd154a4e" 601 | } 602 | }, 603 | "3d9a5c47f9fb44ffa41f91cc37b8ee6e": { 604 | "model_module": "@jupyter-widgets/controls", 605 | "model_name": "HTMLModel", 606 | "state": { 607 | "_view_name": "HTMLView", 608 | "style": "IPY_MODEL_f6cc72303ee14b548a2c4c1282a1784b", 609 | "_dom_classes": [], 610 | "description": "", 611 | "_model_name": "HTMLModel", 612 | "placeholder": "​", 613 | "_view_module": "@jupyter-widgets/controls", 614 | "_model_module_version": "1.5.0", 615 | "value": " 467M/467M [00:06<00:00, 72.4MB/s]", 616 | "_view_count": null, 617 | "_view_module_version": "1.5.0", 618 | "description_tooltip": null, 619 | "_model_module": "@jupyter-widgets/controls", 620 | "layout": "IPY_MODEL_b9235fdb071742628781944bb3b6608d" 621 | } 622 | }, 623 | "767ef584a8ed49b4ae8ee7835208fb44": { 624 | "model_module": "@jupyter-widgets/controls", 625 | "model_name": "ProgressStyleModel", 626 | "state": { 627 | "_view_name": "StyleView", 628 | "_model_name": "ProgressStyleModel", 629 | "description_width": "initial", 630 | "_view_module": "@jupyter-widgets/base", 631 | "_model_module_version": "1.5.0", 632 | "_view_count": null, 633 | "_view_module_version": "1.2.0", 634 | "bar_color": null, 635 | "_model_module": "@jupyter-widgets/controls" 636 | } 637 | }, 638 | "769eba56a67d4d71b76f1240fd154a4e": { 639 | "model_module": "@jupyter-widgets/base", 640 | "model_name": "LayoutModel", 641 | "state": { 642 | "_view_name": "LayoutView", 643 | "grid_template_rows": null, 644 | "right": null, 645 | "justify_content": null, 646 | "_view_module": 
"@jupyter-widgets/base", 647 | "overflow": null, 648 | "_model_module_version": "1.2.0", 649 | "_view_count": null, 650 | "flex_flow": null, 651 | "width": null, 652 | "min_width": null, 653 | "border": null, 654 | "align_items": null, 655 | "bottom": null, 656 | "_model_module": "@jupyter-widgets/base", 657 | "top": null, 658 | "grid_column": null, 659 | "overflow_y": null, 660 | "overflow_x": null, 661 | "grid_auto_flow": null, 662 | "grid_area": null, 663 | "grid_template_columns": null, 664 | "flex": null, 665 | "_model_name": "LayoutModel", 666 | "justify_items": null, 667 | "grid_row": null, 668 | "max_height": null, 669 | "align_content": null, 670 | "visibility": null, 671 | "align_self": null, 672 | "height": null, 673 | "min_height": null, 674 | "padding": null, 675 | "grid_auto_rows": null, 676 | "grid_gap": null, 677 | "max_width": null, 678 | "order": null, 679 | "_view_module_version": "1.2.0", 680 | "grid_template_areas": null, 681 | "object_position": null, 682 | "object_fit": null, 683 | "grid_auto_columns": null, 684 | "margin": null, 685 | "display": null, 686 | "left": null 687 | } 688 | }, 689 | "f6cc72303ee14b548a2c4c1282a1784b": { 690 | "model_module": "@jupyter-widgets/controls", 691 | "model_name": "DescriptionStyleModel", 692 | "state": { 693 | "_view_name": "StyleView", 694 | "_model_name": "DescriptionStyleModel", 695 | "description_width": "", 696 | "_view_module": "@jupyter-widgets/base", 697 | "_model_module_version": "1.5.0", 698 | "_view_count": null, 699 | "_view_module_version": "1.2.0", 700 | "_model_module": "@jupyter-widgets/controls" 701 | } 702 | }, 703 | "b9235fdb071742628781944bb3b6608d": { 704 | "model_module": "@jupyter-widgets/base", 705 | "model_name": "LayoutModel", 706 | "state": { 707 | "_view_name": "LayoutView", 708 | "grid_template_rows": null, 709 | "right": null, 710 | "justify_content": null, 711 | "_view_module": "@jupyter-widgets/base", 712 | "overflow": null, 713 | "_model_module_version": "1.2.0", 714 | 
"_view_count": null, 715 | "flex_flow": null, 716 | "width": null, 717 | "min_width": null, 718 | "border": null, 719 | "align_items": null, 720 | "bottom": null, 721 | "_model_module": "@jupyter-widgets/base", 722 | "top": null, 723 | "grid_column": null, 724 | "overflow_y": null, 725 | "overflow_x": null, 726 | "grid_auto_flow": null, 727 | "grid_area": null, 728 | "grid_template_columns": null, 729 | "flex": null, 730 | "_model_name": "LayoutModel", 731 | "justify_items": null, 732 | "grid_row": null, 733 | "max_height": null, 734 | "align_content": null, 735 | "visibility": null, 736 | "align_self": null, 737 | "height": null, 738 | "min_height": null, 739 | "padding": null, 740 | "grid_auto_rows": null, 741 | "grid_gap": null, 742 | "max_width": null, 743 | "order": null, 744 | "_view_module_version": "1.2.0", 745 | "grid_template_areas": null, 746 | "object_position": null, 747 | "object_fit": null, 748 | "grid_auto_columns": null, 749 | "margin": null, 750 | "display": null, 751 | "left": null 752 | } 753 | } 754 | } 755 | } 756 | }, 757 | "cells": [ 758 | { 759 | "cell_type": "markdown", 760 | "metadata": { 761 | "id": "view-in-github", 762 | "colab_type": "text" 763 | }, 764 | "source": [ 765 | "\"Open" 766 | ] 767 | }, 768 | { 769 | "cell_type": "code", 770 | "metadata": { 771 | "id": "txqNzM1xIPIR", 772 | "colab_type": "code", 773 | "outputId": "efffa915-6104-4a1c-c751-90e516fad837", 774 | "colab": { 775 | "base_uri": "https://localhost:8080/", 776 | "height": 120 777 | } 778 | }, 779 | "source": [ 780 | "from google.colab import drive\n", 781 | "drive.mount('/content/drive')" 782 | ], 783 | "execution_count": 0, 784 | "outputs": [ 785 | { 786 | "output_type": "stream", 787 | "text": [ 788 | "Go to this URL in a browser: 
https://accounts.google.com/o/oauth2/auth?client_id=947318989803-6bn6qk8qdgf4n4g3pfee6491hc0brc4i.apps.googleusercontent.com&redirect_uri=urn%3aietf%3awg%3aoauth%3a2.0%3aoob&response_type=code&scope=email%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdocs.test%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdrive%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdrive.photos.readonly%20https%3a%2f%2fwww.googleapis.com%2fauth%2fpeopleapi.readonly\n", 789 | "\n", 790 | "Enter your authorization code:\n", 791 | "··········\n", 792 | "Mounted at /content/drive\n" 793 | ], 794 | "name": "stdout" 795 | } 796 | ] 797 | }, 798 | { 799 | "cell_type": "code", 800 | "metadata": { 801 | "id": "3m6sGbnGeXHk", 802 | "colab_type": "code", 803 | "outputId": "21365dc2-00df-4fcd-8b66-f1aee796f89f", 804 | "colab": { 805 | "base_uri": "https://localhost:8080/", 806 | "height": 351 807 | } 808 | }, 809 | "source": [ 810 | "!nvidia-smi" 811 | ], 812 | "execution_count": 0, 813 | "outputs": [ 814 | { 815 | "output_type": "stream", 816 | "text": [ 817 | "Sun Jun 14 03:02:22 2020 \n", 818 | "+-----------------------------------------------------------------------------+\n", 819 | "| NVIDIA-SMI 450.36.06 Driver Version: 418.67 CUDA Version: 10.1 |\n", 820 | "|-------------------------------+----------------------+----------------------+\n", 821 | "| GPU Name Persistence-M| Bus-Id Disp.A | Volatile Uncorr. ECC |\n", 822 | "| Fan Temp Perf Pwr:Usage/Cap| Memory-Usage | GPU-Util Compute M. |\n", 823 | "| | | MIG M. |\n", 824 | "|===============================+======================+======================|\n", 825 | "| 0 Tesla P100-PCIE... Off | 00000000:00:04.0 Off | 0 |\n", 826 | "| N/A 34C P0 27W / 250W | 0MiB / 16280MiB | 0% Default |\n", 827 | "| | | ERR! 
|\n", 828 | "+-------------------------------+----------------------+----------------------+\n", 829 | " \n", 830 | "+-----------------------------------------------------------------------------+\n", 831 | "| Processes: |\n", 832 | "| GPU GI CI PID Type Process name GPU Memory |\n", 833 | "| ID ID Usage |\n", 834 | "|=============================================================================|\n", 835 | "| No running processes found |\n", 836 | "+-----------------------------------------------------------------------------+\n" 837 | ], 838 | "name": "stdout" 839 | } 840 | ] 841 | }, 842 | { 843 | "cell_type": "code", 844 | "metadata": { 845 | "id": "HqQKcukJAicL", 846 | "colab_type": "code", 847 | "outputId": "548f009b-9187-47f5-fc78-8275907438b5", 848 | "colab": { 849 | "base_uri": "https://localhost:8080/", 850 | "height": 605 851 | } 852 | }, 853 | "source": [ 854 | "!pip install transformers" 855 | ], 856 | "execution_count": 0, 857 | "outputs": [ 858 | { 859 | "output_type": "stream", 860 | "text": [ 861 | "Collecting transformers\n", 862 | "\u001b[?25l Downloading https://files.pythonhosted.org/packages/48/35/ad2c5b1b8f99feaaf9d7cdadaeef261f098c6e1a6a2935d4d07662a6b780/transformers-2.11.0-py3-none-any.whl (674kB)\n", 863 | "\u001b[K |████████████████████████████████| 675kB 2.8MB/s \n", 864 | "\u001b[?25hRequirement already satisfied: tqdm>=4.27 in /usr/local/lib/python3.6/dist-packages (from transformers) (4.41.1)\n", 865 | "Requirement already satisfied: requests in /usr/local/lib/python3.6/dist-packages (from transformers) (2.23.0)\n", 866 | "Collecting sacremoses\n", 867 | "\u001b[?25l Downloading https://files.pythonhosted.org/packages/7d/34/09d19aff26edcc8eb2a01bed8e98f13a1537005d31e95233fd48216eed10/sacremoses-0.0.43.tar.gz (883kB)\n", 868 | "\u001b[K |████████████████████████████████| 890kB 15.5MB/s \n", 869 | "\u001b[?25hCollecting tokenizers==0.7.0\n", 870 | "\u001b[?25l Downloading 
https://files.pythonhosted.org/packages/14/e5/a26eb4716523808bb0a799fcfdceb6ebf77a18169d9591b2f46a9adb87d9/tokenizers-0.7.0-cp36-cp36m-manylinux1_x86_64.whl (3.8MB)\n", 871 | "\u001b[K |████████████████████████████████| 3.8MB 12.9MB/s \n", 872 | "\u001b[?25hRequirement already satisfied: dataclasses; python_version < \"3.7\" in /usr/local/lib/python3.6/dist-packages (from transformers) (0.7)\n", 873 | "Collecting sentencepiece\n", 874 | "\u001b[?25l Downloading https://files.pythonhosted.org/packages/d4/a4/d0a884c4300004a78cca907a6ff9a5e9fe4f090f5d95ab341c53d28cbc58/sentencepiece-0.1.91-cp36-cp36m-manylinux1_x86_64.whl (1.1MB)\n", 875 | "\u001b[K |████████████████████████████████| 1.1MB 26.3MB/s \n", 876 | "\u001b[?25hRequirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.6/dist-packages (from transformers) (2019.12.20)\n", 877 | "Requirement already satisfied: numpy in /usr/local/lib/python3.6/dist-packages (from transformers) (1.18.5)\n", 878 | "Requirement already satisfied: packaging in /usr/local/lib/python3.6/dist-packages (from transformers) (20.4)\n", 879 | "Requirement already satisfied: filelock in /usr/local/lib/python3.6/dist-packages (from transformers) (3.0.12)\n", 880 | "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.6/dist-packages (from requests->transformers) (2020.4.5.1)\n", 881 | "Requirement already satisfied: idna<3,>=2.5 in /usr/local/lib/python3.6/dist-packages (from requests->transformers) (2.9)\n", 882 | "Requirement already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /usr/local/lib/python3.6/dist-packages (from requests->transformers) (1.24.3)\n", 883 | "Requirement already satisfied: chardet<4,>=3.0.2 in /usr/local/lib/python3.6/dist-packages (from requests->transformers) (3.0.4)\n", 884 | "Requirement already satisfied: six in /usr/local/lib/python3.6/dist-packages (from sacremoses->transformers) (1.12.0)\n", 885 | "Requirement already satisfied: click in 
/usr/local/lib/python3.6/dist-packages (from sacremoses->transformers) (7.1.2)\n", 886 | "Requirement already satisfied: joblib in /usr/local/lib/python3.6/dist-packages (from sacremoses->transformers) (0.15.1)\n", 887 | "Requirement already satisfied: pyparsing>=2.0.2 in /usr/local/lib/python3.6/dist-packages (from packaging->transformers) (2.4.7)\n", 888 | "Building wheels for collected packages: sacremoses\n", 889 | " Building wheel for sacremoses (setup.py) ... \u001b[?25l\u001b[?25hdone\n", 890 | " Created wheel for sacremoses: filename=sacremoses-0.0.43-cp36-none-any.whl size=893260 sha256=97710e3620e53f263dc38f0e295f2e16da22030a818999315a002e1088be2cf9\n", 891 | " Stored in directory: /root/.cache/pip/wheels/29/3c/fd/7ce5c3f0666dab31a50123635e6fb5e19ceb42ce38d4e58f45\n", 892 | "Successfully built sacremoses\n", 893 | "Installing collected packages: sacremoses, tokenizers, sentencepiece, transformers\n", 894 | "Successfully installed sacremoses-0.0.43 sentencepiece-0.1.91 tokenizers-0.7.0 transformers-2.11.0\n" 895 | ], 896 | "name": "stdout" 897 | } 898 | ] 899 | }, 900 | { 901 | "cell_type": "code", 902 | "metadata": { 903 | "id": "U8FRjwBo15uT", 904 | "colab_type": "code", 905 | "colab": {} 906 | }, 907 | "source": [ 908 | "!pip install -q -U watermark" 909 | ], 910 | "execution_count": 0, 911 | "outputs": [] 912 | }, 913 | { 914 | "cell_type": "code", 915 | "metadata": { 916 | "id": "3DNdmQhn2D43", 917 | "colab_type": "code", 918 | "outputId": "ab9b076b-e566-4104-ccb7-95d16d360c85", 919 | "colab": { 920 | "base_uri": "https://localhost:8080/", 921 | "height": 134 922 | } 923 | }, 924 | "source": [ 925 | "%reload_ext watermark\n", 926 | "%watermark -v -p numpy,pandas,torch,transformers" 927 | ], 928 | "execution_count": 0, 929 | "outputs": [ 930 | { 931 | "output_type": "stream", 932 | "text": [ 933 | "CPython 3.6.9\n", 934 | "IPython 5.5.0\n", 935 | "\n", 936 | "numpy 1.18.5\n", 937 | "pandas 1.0.4\n", 938 | "torch 1.5.0+cu101\n", 939 | "transformers 
2.11.0\n" 940 | ], 941 | "name": "stdout" 942 | } 943 | ] 944 | }, 945 | { 946 | "cell_type": "markdown", 947 | "metadata": { 948 | "id": "Mk0GnvUQIdxc", 949 | "colab_type": "text" 950 | }, 951 | "source": [ 952 | "### Making the necessary imports" 953 | ] 954 | }, 955 | { 956 | "cell_type": "code", 957 | "metadata": { 958 | "id": "dOpc8w_12D1f", 959 | "colab_type": "code", 960 | "outputId": "2163b72f-d590-48e3-ab75-d956198150a8", 961 | "colab": { 962 | "base_uri": "https://localhost:8080/", 963 | "height": 87 964 | } 965 | }, 966 | "source": [ 967 | "import transformers\n", 968 | "from transformers import XLNetTokenizer, XLNetModel, AdamW, get_linear_schedule_with_warmup\n", 969 | "import torch\n", 970 | "\n", 971 | "import numpy as np\n", 972 | "import pandas as pd\n", 973 | "import seaborn as sns\n", 974 | "import matplotlib.pyplot as plt\n", 975 | "from matplotlib import rc\n", 976 | "from sklearn.model_selection import train_test_split\n", 977 | "from sklearn.metrics import confusion_matrix, classification_report, accuracy\n", 978 | "from collections import defaultdict\n", 979 | "from textwrap import wrap\n", 980 | "from pylab import rcParams\n", 981 | "\n", 982 | "from torch import nn, optim\n", 983 | "from keras.preprocessing.sequence import pad_sequences\n", 984 | "from torch.utils.data import TensorDataset,RandomSampler,SequentialSampler\n", 985 | "from torch.utils.data import Dataset, DataLoader\n", 986 | "import torch.nn.functional as F" 987 | ], 988 | "execution_count": 0, 989 | "outputs": [ 990 | { 991 | "output_type": "stream", 992 | "text": [ 993 | "/usr/local/lib/python3.6/dist-packages/statsmodels/tools/_testing.py:19: FutureWarning: pandas.util.testing is deprecated. 
Use the functions in the public API at pandas.testing instead.\n", 994 | " import pandas.util.testing as tm\n", 995 | "Using TensorFlow backend.\n" 996 | ], 997 | "name": "stderr" 998 | } 999 | ] 1000 | }, 1001 | { 1002 | "cell_type": "code", 1003 | "metadata": { 1004 | "id": "qAICzZNo2Dyw", 1005 | "colab_type": "code", 1006 | "outputId": "dd0b17dd-a7d5-4f20-9244-38cd860de614", 1007 | "colab": { 1008 | "base_uri": "https://localhost:8080/", 1009 | "height": 33 1010 | } 1011 | }, 1012 | "source": [ 1013 | "%matplotlib inline\n", 1014 | "%config InlineBackend.figure_format='retina'\n", 1015 | "\n", 1016 | "sns.set(style='whitegrid', palette='muted', font_scale=1.2)\n", 1017 | "\n", 1018 | "HAPPY_COLORS_PALETTE = [\"#01BEFE\", \"#FFDD00\", \"#FF7D00\", \"#FF006D\", \"#ADFF02\", \"#8F00FF\"]\n", 1019 | "\n", 1020 | "sns.set_palette(sns.color_palette(HAPPY_COLORS_PALETTE))\n", 1021 | "\n", 1022 | "rcParams['figure.figsize'] = 12, 8\n", 1023 | "\n", 1024 | "RANDOM_SEED = 42\n", 1025 | "np.random.seed(RANDOM_SEED)\n", 1026 | "torch.manual_seed(RANDOM_SEED)\n", 1027 | "\n", 1028 | "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n", 1029 | "device" 1030 | ], 1031 | "execution_count": 0, 1032 | "outputs": [ 1033 | { 1034 | "output_type": "execute_result", 1035 | "data": { 1036 | "text/plain": [ 1037 | "device(type='cuda', index=0)" 1038 | ] 1039 | }, 1040 | "metadata": { 1041 | "tags": [] 1042 | }, 1043 | "execution_count": 7 1044 | } 1045 | ] 1046 | }, 1047 | { 1048 | "cell_type": "markdown", 1049 | "metadata": { 1050 | "id": "gC-WS1hYIxkr", 1051 | "colab_type": "text" 1052 | }, 1053 | "source": [ 1054 | "### Data Preprocessing" 1055 | ] 1056 | }, 1057 | { 1058 | "cell_type": "code", 1059 | "metadata": { 1060 | "id": "jJ83RaRG2Dv8", 1061 | "colab_type": "code", 1062 | "outputId": "69fd3bb8-b85f-42f7-9185-e26fddbf2d11", 1063 | "colab": { 1064 | "base_uri": "https://localhost:8080/", 1065 | "height": 196 1066 | } 1067 | }, 1068 | "source": [ 1069 
| "df = pd.read_csv('/content/drive/My Drive/NLP/Sentiment Analysis Series/imdb.csv')\n", 1070 | "df.head()" 1071 | ], 1072 | "execution_count": 0, 1073 | "outputs": [ 1074 | { 1075 | "output_type": "execute_result", 1076 | "data": { 1077 | "text/html": [ 1078 | "
\n", 1079 | "\n", 1092 | "\n", 1093 | " \n", 1094 | " \n", 1095 | " \n", 1096 | " \n", 1097 | " \n", 1098 | " \n", 1099 | " \n", 1100 | " \n", 1101 | " \n", 1102 | " \n", 1103 | " \n", 1104 | " \n", 1105 | " \n", 1106 | " \n", 1107 | " \n", 1108 | " \n", 1109 | " \n", 1110 | " \n", 1111 | " \n", 1112 | " \n", 1113 | " \n", 1114 | " \n", 1115 | " \n", 1116 | " \n", 1117 | " \n", 1118 | " \n", 1119 | " \n", 1120 | " \n", 1121 | " \n", 1122 | " \n", 1123 | " \n", 1124 | " \n", 1125 | " \n", 1126 | " \n", 1127 | "
reviewsentiment
0One of the other reviewers has mentioned that ...positive
1A wonderful little production. <br /><br />The...positive
2I thought this was a wonderful way to spend ti...positive
3Basically there's a family where a little boy ...negative
4Petter Mattei's \"Love in the Time of Money\" is...positive
\n", 1128 | "
" 1129 | ], 1130 | "text/plain": [ 1131 | " review sentiment\n", 1132 | "0 One of the other reviewers has mentioned that ... positive\n", 1133 | "1 A wonderful little production.

The... positive\n", 1134 | "2 I thought this was a wonderful way to spend ti... positive\n", 1135 | "3 Basically there's a family where a little boy ... negative\n", 1136 | "4 Petter Mattei's \"Love in the Time of Money\" is... positive" 1137 | ] 1138 | }, 1139 | "metadata": { 1140 | "tags": [] 1141 | }, 1142 | "execution_count": 8 1143 | } 1144 | ] 1145 | }, 1146 | { 1147 | "cell_type": "code", 1148 | "metadata": { 1149 | "id": "4J5-ddIc2DtI", 1150 | "colab_type": "code", 1151 | "outputId": "fbe01f11-3cfd-405d-c00a-a945cde3a8c7", 1152 | "colab": { 1153 | "base_uri": "https://localhost:8080/", 1154 | "height": 644 1155 | } 1156 | }, 1157 | "source": [ 1158 | "from sklearn.utils import shuffle\n", 1159 | "df = shuffle(df)\n", 1160 | "df.head(20)" 1161 | ], 1162 | "execution_count": 0, 1163 | "outputs": [ 1164 | { 1165 | "output_type": "execute_result", 1166 | "data": { 1167 | "text/html": [ 1168 | "
\n", 1169 | "\n", 1182 | "\n", 1183 | " \n", 1184 | " \n", 1185 | " \n", 1186 | " \n", 1187 | " \n", 1188 | " \n", 1189 | " \n", 1190 | " \n", 1191 | " \n", 1192 | " \n", 1193 | " \n", 1194 | " \n", 1195 | " \n", 1196 | " \n", 1197 | " \n", 1198 | " \n", 1199 | " \n", 1200 | " \n", 1201 | " \n", 1202 | " \n", 1203 | " \n", 1204 | " \n", 1205 | " \n", 1206 | " \n", 1207 | " \n", 1208 | " \n", 1209 | " \n", 1210 | " \n", 1211 | " \n", 1212 | " \n", 1213 | " \n", 1214 | " \n", 1215 | " \n", 1216 | " \n", 1217 | " \n", 1218 | " \n", 1219 | " \n", 1220 | " \n", 1221 | " \n", 1222 | " \n", 1223 | " \n", 1224 | " \n", 1225 | " \n", 1226 | " \n", 1227 | " \n", 1228 | " \n", 1229 | " \n", 1230 | " \n", 1231 | " \n", 1232 | " \n", 1233 | " \n", 1234 | " \n", 1235 | " \n", 1236 | " \n", 1237 | " \n", 1238 | " \n", 1239 | " \n", 1240 | " \n", 1241 | " \n", 1242 | " \n", 1243 | " \n", 1244 | " \n", 1245 | " \n", 1246 | " \n", 1247 | " \n", 1248 | " \n", 1249 | " \n", 1250 | " \n", 1251 | " \n", 1252 | " \n", 1253 | " \n", 1254 | " \n", 1255 | " \n", 1256 | " \n", 1257 | " \n", 1258 | " \n", 1259 | " \n", 1260 | " \n", 1261 | " \n", 1262 | " \n", 1263 | " \n", 1264 | " \n", 1265 | " \n", 1266 | " \n", 1267 | " \n", 1268 | " \n", 1269 | " \n", 1270 | " \n", 1271 | " \n", 1272 | " \n", 1273 | " \n", 1274 | " \n", 1275 | " \n", 1276 | " \n", 1277 | " \n", 1278 | " \n", 1279 | " \n", 1280 | " \n", 1281 | " \n", 1282 | " \n", 1283 | " \n", 1284 | " \n", 1285 | " \n", 1286 | " \n", 1287 | " \n", 1288 | " \n", 1289 | " \n", 1290 | " \n", 1291 | " \n", 1292 | "
reviewsentiment
33553I really liked this Summerslam due to the look...positive
9427Not many television shows appeal to quite as m...positive
199The film quickly gets to a major chase scene w...negative
12447Jane Austen would definitely approve of this o...positive
39489Expectations were somewhat high for me when I ...negative
42724I've watched this movie on a fairly regular ba...positive
10822For once a story of hope highlighted over the ...positive
49498Okay, I didn't get the Purgatory thing the fir...positive
4144I was very disappointed with this series. It h...negative
36958The first 30 minutes of Tinseltown had my fing...negative
43106jeez, this was immensely boring. the leading m...negative
38695Great just great! The West Coast got \"Dirty\" H...positive
6188It's made in 2007 and the CG is bad for a movi...negative
1414This movie stinks majorly. The only reason I g...negative
18471We can start with the wooden acting but this f...negative
29282This movie starts off somewhat slowly and gets...positive
15177This is a slightly uneven entry with one stand...positive
34304I was first introduced to John Waters films by...positive
12609This movie has very good acting by virtually a...positive
12144I can't help but notice the negative reviews t...positive
\n", 1293 | "
" 1294 | ], 1295 | "text/plain": [ 1296 | " review sentiment\n", 1297 | "33553 I really liked this Summerslam due to the look... positive\n", 1298 | "9427 Not many television shows appeal to quite as m... positive\n", 1299 | "199 The film quickly gets to a major chase scene w... negative\n", 1300 | "12447 Jane Austen would definitely approve of this o... positive\n", 1301 | "39489 Expectations were somewhat high for me when I ... negative\n", 1302 | "42724 I've watched this movie on a fairly regular ba... positive\n", 1303 | "10822 For once a story of hope highlighted over the ... positive\n", 1304 | "49498 Okay, I didn't get the Purgatory thing the fir... positive\n", 1305 | "4144 I was very disappointed with this series. It h... negative\n", 1306 | "36958 The first 30 minutes of Tinseltown had my fing... negative\n", 1307 | "43106 jeez, this was immensely boring. the leading m... negative\n", 1308 | "38695 Great just great! The West Coast got \"Dirty\" H... positive\n", 1309 | "6188 It's made in 2007 and the CG is bad for a movi... negative\n", 1310 | "1414 This movie stinks majorly. The only reason I g... negative\n", 1311 | "18471 We can start with the wooden acting but this f... negative\n", 1312 | "29282 This movie starts off somewhat slowly and gets... positive\n", 1313 | "15177 This is a slightly uneven entry with one stand... positive\n", 1314 | "34304 I was first introduced to John Waters films by... positive\n", 1315 | "12609 This movie has very good acting by virtually a... positive\n", 1316 | "12144 I can't help but notice the negative reviews t... 
positive" 1317 | ] 1318 | }, 1319 | "metadata": { 1320 | "tags": [] 1321 | }, 1322 | "execution_count": 9 1323 | } 1324 | ] 1325 | }, 1326 | { 1327 | "cell_type": "code", 1328 | "metadata": { 1329 | "id": "uMSFcIqsdZyH", 1330 | "colab_type": "code", 1331 | "outputId": "47589740-52a5-4744-d34f-2e3e36e7b1ec", 1332 | "colab": { 1333 | "base_uri": "https://localhost:8080/", 1334 | "height": 33 1335 | } 1336 | }, 1337 | "source": [ 1338 | "df = df[:24000]\n", 1339 | "len(df)" 1340 | ], 1341 | "execution_count": 0, 1342 | "outputs": [ 1343 | { 1344 | "output_type": "execute_result", 1345 | "data": { 1346 | "text/plain": [ 1347 | "24000" 1348 | ] 1349 | }, 1350 | "metadata": { 1351 | "tags": [] 1352 | }, 1353 | "execution_count": 10 1354 | } 1355 | ] 1356 | }, 1357 | { 1358 | "cell_type": "code", 1359 | "metadata": { 1360 | "id": "GluMm1Nj2DqK", 1361 | "colab_type": "code", 1362 | "colab": {} 1363 | }, 1364 | "source": [ 1365 | "import re\n", 1366 | "def clean_text(text):\n", 1367 | " text = re.sub(r\"@[A-Za-z0-9]+\", ' ', text)\n", 1368 | " text = re.sub(r\"https?://[A-Za-z0-9./]+\", ' ', text)\n", 1369 | " text = re.sub(r\"[^a-zA-z.!?'0-9]\", ' ', text)\n", 1370 | " text = re.sub('\\t', ' ', text)\n", 1371 | " text = re.sub(r\" +\", ' ', text)\n", 1372 | " return text" 1373 | ], 1374 | "execution_count": 0, 1375 | "outputs": [] 1376 | }, 1377 | { 1378 | "cell_type": "code", 1379 | "metadata": { 1380 | "id": "8HpgvTb72wtm", 1381 | "colab_type": "code", 1382 | "colab": {} 1383 | }, 1384 | "source": [ 1385 | "df['review'] = df['review'].apply(clean_text)" 1386 | ], 1387 | "execution_count": 0, 1388 | "outputs": [] 1389 | }, 1390 | { 1391 | "cell_type": "code", 1392 | "metadata": { 1393 | "id": "ptC13l5r25qH", 1394 | "colab_type": "code", 1395 | "outputId": "531b8ac0-a6e2-4750-eadb-2ac2aa3ee628", 1396 | "colab": { 1397 | "base_uri": "https://localhost:8080/", 1398 | "height": 398 1399 | } 1400 | }, 1401 | "source": [ 1402 | "rcParams['figure.figsize'] = 8, 6\n", 1403 | 
"sns.countplot(df.sentiment)\n", 1404 | "plt.xlabel('review score');" 1405 | ], 1406 | "execution_count": 0, 1407 | "outputs": [ 1408 | { 1409 | "output_type": "display_data", 1410 | "data": { 1411 | "image/png": "\n", 1412 | "text/plain": [ 1413 | "
" 1414 | ] 1415 | }, 1416 | "metadata": { 1417 | "tags": [], 1418 | "image/png": { 1419 | "width": 530, 1420 | "height": 381 1421 | } 1422 | } 1423 | } 1424 | ] 1425 | }, 1426 | { 1427 | "cell_type": "code", 1428 | "metadata": { 1429 | "id": "2I-Du1vR3EZN", 1430 | "colab_type": "code", 1431 | "colab": {} 1432 | }, 1433 | "source": [ 1434 | "def sentiment2label(sentiment):\n", 1435 | " if sentiment == \"positive\":\n", 1436 | " return 1\n", 1437 | " else :\n", 1438 | " return 0\n", 1439 | "\n", 1440 | "df['sentiment'] = df['sentiment'].apply(sentiment2label)" 1441 | ], 1442 | "execution_count": 0, 1443 | "outputs": [] 1444 | }, 1445 | { 1446 | "cell_type": "code", 1447 | "metadata": { 1448 | "id": "YK3jPP3i3lRw", 1449 | "colab_type": "code", 1450 | "outputId": "257a6439-6471-411c-8e48-1762ba3796b5", 1451 | "colab": { 1452 | "base_uri": "https://localhost:8080/", 1453 | "height": 67 1454 | } 1455 | }, 1456 | "source": [ 1457 | "df['sentiment'].value_counts()" 1458 | ], 1459 | "execution_count": 0, 1460 | "outputs": [ 1461 | { 1462 | "output_type": "execute_result", 1463 | "data": { 1464 | "text/plain": [ 1465 | "0 12006\n", 1466 | "1 11994\n", 1467 | "Name: sentiment, dtype: int64" 1468 | ] 1469 | }, 1470 | "metadata": { 1471 | "tags": [] 1472 | }, 1473 | "execution_count": 15 1474 | } 1475 | ] 1476 | }, 1477 | { 1478 | "cell_type": "code", 1479 | "metadata": { 1480 | "id": "StOb4mAa3rgo", 1481 | "colab_type": "code", 1482 | "colab": {} 1483 | }, 1484 | "source": [ 1485 | "class_names = ['negative', 'positive']" 1486 | ], 1487 | "execution_count": 0, 1488 | "outputs": [] 1489 | }, 1490 | { 1491 | "cell_type": "markdown", 1492 | "metadata": { 1493 | "id": "mmBGGgUVJBpT", 1494 | "colab_type": "text" 1495 | }, 1496 | "source": [ 1497 | "### Playing with XLNetTokenizer" 1498 | ] 1499 | }, 1500 | { 1501 | "cell_type": "code", 1502 | "metadata": { 1503 | "id": "hVWO_38_32jL", 1504 | "colab_type": "code", 1505 | "outputId": "81c0d90d-c134-4b43-815a-1eee3d29a947", 1506 | 
"colab": { 1507 | "base_uri": "https://localhost:8080/", 1508 | "height": 65, 1509 | "referenced_widgets": [ 1510 | "bd7d4d34bd6d46ae985741702b456aaf", 1511 | "e8135d40c9c4460d9ce1eb0885c562e4", 1512 | "0954f506d85d4723a3b32b6606db0015", 1513 | "297ea592ec1b44558448159860cd1aa0", 1514 | "ed27989647d140eca7d3bf745145a3a0", 1515 | "f97490f8dee84ccc8029e401abe1042e", 1516 | "8d146a11180d4d918cbfc302fc7d2971", 1517 | "73e86ec287c541d3a189620dd50c33b5" 1518 | ] 1519 | } 1520 | }, 1521 | "source": [ 1522 | "from transformers import XLNetTokenizer, XLNetModel\n", 1523 | "PRE_TRAINED_MODEL_NAME = 'xlnet-base-cased'\n", 1524 | "tokenizer = XLNetTokenizer.from_pretrained(PRE_TRAINED_MODEL_NAME)" 1525 | ], 1526 | "execution_count": 0, 1527 | "outputs": [ 1528 | { 1529 | "output_type": "display_data", 1530 | "data": { 1531 | "application/vnd.jupyter.widget-view+json": { 1532 | "model_id": "bd7d4d34bd6d46ae985741702b456aaf", 1533 | "version_minor": 0, 1534 | "version_major": 2 1535 | }, 1536 | "text/plain": [ 1537 | "HBox(children=(FloatProgress(value=0.0, description='Downloading', max=798011.0, style=ProgressStyle(descripti…" 1538 | ] 1539 | }, 1540 | "metadata": { 1541 | "tags": [] 1542 | } 1543 | }, 1544 | { 1545 | "output_type": "stream", 1546 | "text": [ 1547 | "\n" 1548 | ], 1549 | "name": "stdout" 1550 | } 1551 | ] 1552 | }, 1553 | { 1554 | "cell_type": "code", 1555 | "metadata": { 1556 | "id": "1ckpqWd_YWgZ", 1557 | "colab_type": "code", 1558 | "colab": {} 1559 | }, 1560 | "source": [ 1561 | "input_txt = \"India is my country. 
All Indians are my brothers and sisters\"\n", 1562 | "encodings = tokenizer.encode_plus(input_txt, add_special_tokens=True, max_length=16, return_tensors='pt', return_token_type_ids=False, return_attention_mask=True, pad_to_max_length=False)" 1563 | ], 1564 | "execution_count": 0, 1565 | "outputs": [] 1566 | }, 1567 | { 1568 | "cell_type": "code", 1569 | "metadata": { 1570 | "id": "Y_ew5njtYWc1", 1571 | "colab_type": "code", 1572 | "outputId": "dc94277d-2372-48bd-c9f7-2d367c8a3341", 1573 | "colab": { 1574 | "base_uri": "https://localhost:8080/", 1575 | "height": 50 1576 | } 1577 | }, 1578 | "source": [ 1579 | "print('input_ids : ',encodings['input_ids'])" 1580 | ], 1581 | "execution_count": 0, 1582 | "outputs": [ 1583 | { 1584 | "output_type": "stream", 1585 | "text": [ 1586 | "input_ids : tensor([[ 837, 27, 94, 234, 9, 394, 7056, 41, 94, 4194, 21, 8301,\n", 1587 | " 4, 3]])\n" 1588 | ], 1589 | "name": "stdout" 1590 | } 1591 | ] 1592 | }, 1593 | { 1594 | "cell_type": "code", 1595 | "metadata": { 1596 | "id": "eEXfe39mYWK7", 1597 | "colab_type": "code", 1598 | "outputId": "8953d12c-68ad-453c-8fd8-0cff386cf503", 1599 | "colab": { 1600 | "base_uri": "https://localhost:8080/", 1601 | "height": 251 1602 | } 1603 | }, 1604 | "source": [ 1605 | "tokenizer.convert_ids_to_tokens(encodings['input_ids'][0])" 1606 | ], 1607 | "execution_count": 0, 1608 | "outputs": [ 1609 | { 1610 | "output_type": "execute_result", 1611 | "data": { 1612 | "text/plain": [ 1613 | "['▁India',\n", 1614 | " '▁is',\n", 1615 | " '▁my',\n", 1616 | " '▁country',\n", 1617 | " '.',\n", 1618 | " '▁All',\n", 1619 | " '▁Indians',\n", 1620 | " '▁are',\n", 1621 | " '▁my',\n", 1622 | " '▁brothers',\n", 1623 | " '▁and',\n", 1624 | " '▁sisters',\n", 1625 | " '',\n", 1626 | " '']" 1627 | ] 1628 | }, 1629 | "metadata": { 1630 | "tags": [] 1631 | }, 1632 | "execution_count": 20 1633 | } 1634 | ] 1635 | }, 1636 | { 1637 | "cell_type": "code", 1638 | "metadata": { 1639 | "id": "XR1xkLYtTurv", 1640 | "colab_type": 
"code", 1641 | "outputId": "c86ef945-3963-453d-86f2-753cd7a45097", 1642 | "colab": { 1643 | "base_uri": "https://localhost:8080/", 1644 | "height": 33 1645 | } 1646 | }, 1647 | "source": [ 1648 | "type(encodings['attention_mask'])" 1649 | ], 1650 | "execution_count": 0, 1651 | "outputs": [ 1652 | { 1653 | "output_type": "execute_result", 1654 | "data": { 1655 | "text/plain": [ 1656 | "torch.Tensor" 1657 | ] 1658 | }, 1659 | "metadata": { 1660 | "tags": [] 1661 | }, 1662 | "execution_count": 21 1663 | } 1664 | ] 1665 | }, 1666 | { 1667 | "cell_type": "code", 1668 | "metadata": { 1669 | "id": "Pp97KX-9UtMK", 1670 | "colab_type": "code", 1671 | "colab": {} 1672 | }, 1673 | "source": [ 1674 | "attention_mask = pad_sequences(encodings['attention_mask'], maxlen=512, dtype=torch.Tensor ,truncating=\"post\",padding=\"post\")" 1675 | ], 1676 | "execution_count": 0, 1677 | "outputs": [] 1678 | }, 1679 | { 1680 | "cell_type": "code", 1681 | "metadata": { 1682 | "id": "6Pzeyv1_UvMT", 1683 | "colab_type": "code", 1684 | "outputId": "2acb94b0-943a-4273-bd71-32e62195a137", 1685 | "colab": { 1686 | "base_uri": "https://localhost:8080/", 1687 | "height": 385 1688 | } 1689 | }, 1690 | "source": [ 1691 | "attention_mask = attention_mask.astype(dtype = 'int64')\n", 1692 | "attention_mask = torch.tensor(attention_mask) \n", 1693 | "attention_mask.flatten()" 1694 | ], 1695 | "execution_count": 0, 1696 | "outputs": [ 1697 | { 1698 | "output_type": "execute_result", 1699 | "data": { 1700 | "text/plain": [ 1701 | "tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", 1702 | " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", 1703 | " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", 1704 | " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", 1705 | " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", 1706 | " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 
0, 0, 0, 0, 0, 0, 0,\n", 1707 | " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", 1708 | " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", 1709 | " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", 1710 | " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", 1711 | " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", 1712 | " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", 1713 | " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", 1714 | " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", 1715 | " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", 1716 | " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", 1717 | " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", 1718 | " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", 1719 | " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", 1720 | " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", 1721 | " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", 1722 | " 0, 0, 0, 0, 0, 0, 0, 0])" 1723 | ] 1724 | }, 1725 | "metadata": { 1726 | "tags": [] 1727 | }, 1728 | "execution_count": 23 1729 | } 1730 | ] 1731 | }, 1732 | { 1733 | "cell_type": "code", 1734 | "metadata": { 1735 | "id": "kM7cxiA9RlnA", 1736 | "colab_type": "code", 1737 | "outputId": "78b39189-c841-4174-ec97-364f2327905e", 1738 | "colab": { 1739 | "base_uri": "https://localhost:8080/", 1740 | "height": 50 1741 | } 1742 | }, 1743 | "source": [ 1744 | "encodings['input_ids']" 1745 | ], 1746 | "execution_count": 0, 1747 | "outputs": [ 1748 | { 1749 | "output_type": "execute_result", 1750 | "data": { 1751 | "text/plain": [ 1752 | "tensor([[ 837, 27, 94, 234, 9, 394, 7056, 41, 94, 
4194, 21, 8301,\n", 1753 | " 4, 3]])" 1754 | ] 1755 | }, 1756 | "metadata": { 1757 | "tags": [] 1758 | }, 1759 | "execution_count": 24 1760 | } 1761 | ] 1762 | }, 1763 | { 1764 | "cell_type": "markdown", 1765 | "metadata": { 1766 | "id": "M12PPghlJUbg", 1767 | "colab_type": "text" 1768 | }, 1769 | "source": [ 1770 | "### Checking the distribution of token lengths" 1771 | ] 1772 | }, 1773 | { 1774 | "cell_type": "code", 1775 | "metadata": { 1776 | "id": "zW4_SrOw4BDy", 1777 | "colab_type": "code", 1778 | "colab": {} 1779 | }, 1780 | "source": [ 1781 | "token_lens = []\n", 1782 | "\n", 1783 | "for txt in df['review']:\n", 1784 | " tokens = tokenizer.encode(txt, max_length=512)\n", 1785 | " token_lens.append(len(tokens))" 1786 | ], 1787 | "execution_count": 0, 1788 | "outputs": [] 1789 | }, 1790 | { 1791 | "cell_type": "code", 1792 | "metadata": { 1793 | "id": "aSxO20TU4dz5", 1794 | "colab_type": "code", 1795 | "outputId": "f8379514-5ae2-4e43-bb19-12e29d370386", 1796 | "colab": { 1797 | "base_uri": "https://localhost:8080/", 1798 | "height": 398 1799 | } 1800 | }, 1801 | "source": [ 1802 | "sns.distplot(token_lens)\n", 1803 | "plt.xlim([0, 1024]);\n", 1804 | "plt.xlabel('Token count');" 1805 | ], 1806 | "execution_count": 0, 1807 | "outputs": [ 1808 | { 1809 | "output_type": "display_data", 1810 | "data": { 1811 | "image/png": "\n", 1812 | "text/plain": [ 1813 | "
" 1814 | ] 1815 | }, 1816 | "metadata": { 1817 | "tags": [], 1818 | "image/png": { 1819 | "width": 514, 1820 | "height": 381 1821 | } 1822 | } 1823 | } 1824 | ] 1825 | }, 1826 | { 1827 | "cell_type": "code", 1828 | "metadata": { 1829 | "id": "CR3rHUQR4pDE", 1830 | "colab_type": "code", 1831 | "colab": {} 1832 | }, 1833 | "source": [ 1834 | "MAX_LEN = 512" 1835 | ], 1836 | "execution_count": 0, 1837 | "outputs": [] 1838 | }, 1839 | { 1840 | "cell_type": "markdown", 1841 | "metadata": { 1842 | "id": "e6Kutw4dJyUG", 1843 | "colab_type": "text" 1844 | }, 1845 | "source": [ 1846 | "### Custom Dataset class" 1847 | ] 1848 | }, 1849 | { 1850 | "cell_type": "code", 1851 | "metadata": { 1852 | "id": "q2NOYXcjPK4z", 1853 | "colab_type": "code", 1854 | "colab": {} 1855 | }, 1856 | "source": [ 1857 | "class ImdbDataset(Dataset):\n", 1858 | "\n", 1859 | " def __init__(self, reviews, targets, tokenizer, max_len):\n", 1860 | " self.reviews = reviews\n", 1861 | " self.targets = targets\n", 1862 | " self.tokenizer = tokenizer\n", 1863 | " self.max_len = max_len\n", 1864 | " \n", 1865 | " def __len__(self):\n", 1866 | " return len(self.reviews)\n", 1867 | " \n", 1868 | " def __getitem__(self, item):\n", 1869 | " review = str(self.reviews[item])\n", 1870 | " target = self.targets[item]\n", 1871 | "\n", 1872 | " encoding = self.tokenizer.encode_plus(\n", 1873 | " review,\n", 1874 | " add_special_tokens=True,\n", 1875 | " max_length=self.max_len,\n", 1876 | " return_token_type_ids=False,\n", 1877 | " pad_to_max_length=False,\n", 1878 | " return_attention_mask=True,\n", 1879 | " return_tensors='pt',\n", 1880 | " )\n", 1881 | "\n", 1882 | " input_ids = pad_sequences(encoding['input_ids'], maxlen=MAX_LEN, dtype=torch.Tensor ,truncating=\"post\",padding=\"post\")\n", 1883 | " input_ids = input_ids.astype(dtype = 'int64')\n", 1884 | " input_ids = torch.tensor(input_ids) \n", 1885 | "\n", 1886 | " attention_mask = pad_sequences(encoding['attention_mask'], maxlen=MAX_LEN, dtype=torch.Tensor 
,truncating=\"post\",padding=\"post\")\n", 1887 | " attention_mask = attention_mask.astype(dtype = 'int64')\n", 1888 | " attention_mask = torch.tensor(attention_mask) \n", 1889 | "\n", 1890 | " return {\n", 1891 | " 'review_text': review,\n", 1892 | " 'input_ids': input_ids,\n", 1893 | " 'attention_mask': attention_mask.flatten(),\n", 1894 | " 'targets': torch.tensor(target, dtype=torch.long)\n", 1895 | " }" 1896 | ], 1897 | "execution_count": 0, 1898 | "outputs": [] 1899 | }, 1900 | { 1901 | "cell_type": "code", 1902 | "metadata": { 1903 | "id": "wYXt2AtW6iaT", 1904 | "colab_type": "code", 1905 | "colab": {} 1906 | }, 1907 | "source": [ 1908 | "df_train, df_test = train_test_split(df, test_size=0.5, random_state=101)\n", 1909 | "df_val, df_test = train_test_split(df_test, test_size=0.5, random_state=101)" 1910 | ], 1911 | "execution_count": 0, 1912 | "outputs": [] 1913 | }, 1914 | { 1915 | "cell_type": "code", 1916 | "metadata": { 1917 | "id": "VmAsa9pg6oac", 1918 | "colab_type": "code", 1919 | "outputId": "7733a944-55b3-4181-f84a-be27f5eb3119", 1920 | "colab": { 1921 | "base_uri": "https://localhost:8080/", 1922 | "height": 33 1923 | } 1924 | }, 1925 | "source": [ 1926 | "df_train.shape, df_val.shape, df_test.shape" 1927 | ], 1928 | "execution_count": 0, 1929 | "outputs": [ 1930 | { 1931 | "output_type": "execute_result", 1932 | "data": { 1933 | "text/plain": [ 1934 | "((12000, 2), (6000, 2), (6000, 2))" 1935 | ] 1936 | }, 1937 | "metadata": { 1938 | "tags": [] 1939 | }, 1940 | "execution_count": 30 1941 | } 1942 | ] 1943 | }, 1944 | { 1945 | "cell_type": "markdown", 1946 | "metadata": { 1947 | "id": "iFw2z6ElMZMX", 1948 | "colab_type": "text" 1949 | }, 1950 | "source": [ 1951 | "### Custom Dataloader" 1952 | ] 1953 | }, 1954 | { 1955 | "cell_type": "code", 1956 | "metadata": { 1957 | "id": "3rd7890Z6zLr", 1958 | "colab_type": "code", 1959 | "colab": {} 1960 | }, 1961 | "source": [ 1962 | "def create_data_loader(df, tokenizer, max_len, batch_size):\n", 1963 | " 
ds = ImdbDataset(\n", 1964 | " reviews=df.review.to_numpy(),\n", 1965 | " targets=df.sentiment.to_numpy(),\n", 1966 | " tokenizer=tokenizer,\n", 1967 | " max_len=max_len\n", 1968 | " )\n", 1969 | "\n", 1970 | " return DataLoader(\n", 1971 | " ds,\n", 1972 | " batch_size=batch_size,\n", 1973 | " num_workers=4\n", 1974 | " )" 1975 | ], 1976 | "execution_count": 0, 1977 | "outputs": [] 1978 | }, 1979 | { 1980 | "cell_type": "code", 1981 | "metadata": { 1982 | "id": "tVU8o6i569ly", 1983 | "colab_type": "code", 1984 | "colab": {} 1985 | }, 1986 | "source": [ 1987 | "BATCH_SIZE = 4\n", 1988 | "\n", 1989 | "train_data_loader = create_data_loader(df_train, tokenizer, MAX_LEN, BATCH_SIZE)\n", 1990 | "val_data_loader = create_data_loader(df_val, tokenizer, MAX_LEN, BATCH_SIZE)\n", 1991 | "test_data_loader = create_data_loader(df_test, tokenizer, MAX_LEN, BATCH_SIZE)" 1992 | ], 1993 | "execution_count": 0, 1994 | "outputs": [] 1995 | }, 1996 | { 1997 | "cell_type": "markdown", 1998 | "metadata": { 1999 | "id": "aC5D5Dh8J5w9", 2000 | "colab_type": "text" 2001 | }, 2002 | "source": [ 2003 | "### Loading the Pre-trained XLNet model for sequence classification from huggingface transformers" 2004 | ] 2005 | }, 2006 | { 2007 | "cell_type": "code", 2008 | "metadata": { 2009 | "id": "H5mC8v6i7VH1", 2010 | "colab_type": "code", 2011 | "outputId": "0bc5d6fe-d14b-455b-cba5-00ce32812140", 2012 | "colab": { 2013 | "base_uri": "https://localhost:8080/", 2014 | "height": 114, 2015 | "referenced_widgets": [ 2016 | "c2b26508fa464328b2b30c851a0366fe", 2017 | "18e0ec87d07d4c0bb3930b5cef204963", 2018 | "bac82e5dc1974ef4ac18f743ecfb72b5", 2019 | "35d3d0c2b7004a6b87420e55a94ddcaf", 2020 | "d8893bea6edb4154b051b344cae1636c", 2021 | "d8c49d3421e843a6a2aadd20565ea5d2", 2022 | "3524b098b1944e5ba752f1fbfcb50aef", 2023 | "0cba445897514d4b98192e28890df304", 2024 | "62d0e9673815468ba407019771a0b2ba", 2025 | "6af011b92bd3462b937cea74a2094579", 2026 | "47a41f0a294f42cbb2c9ec3916ca100e", 2027 | 
"3d9a5c47f9fb44ffa41f91cc37b8ee6e", 2028 | "767ef584a8ed49b4ae8ee7835208fb44", 2029 | "769eba56a67d4d71b76f1240fd154a4e", 2030 | "f6cc72303ee14b548a2c4c1282a1784b", 2031 | "b9235fdb071742628781944bb3b6608d" 2032 | ] 2033 | } 2034 | }, 2035 | "source": [ 2036 | "from transformers import XLNetForSequenceClassification\n", 2037 | "model = XLNetForSequenceClassification.from_pretrained('xlnet-base-cased', num_labels = 2)\n", 2038 | "model = model.to(device)" 2039 | ], 2040 | "execution_count": 0, 2041 | "outputs": [ 2042 | { 2043 | "output_type": "display_data", 2044 | "data": { 2045 | "application/vnd.jupyter.widget-view+json": { 2046 | "model_id": "c2b26508fa464328b2b30c851a0366fe", 2047 | "version_minor": 0, 2048 | "version_major": 2 2049 | }, 2050 | "text/plain": [ 2051 | "HBox(children=(FloatProgress(value=0.0, description='Downloading', max=760.0, style=ProgressStyle(description_…" 2052 | ] 2053 | }, 2054 | "metadata": { 2055 | "tags": [] 2056 | } 2057 | }, 2058 | { 2059 | "output_type": "stream", 2060 | "text": [ 2061 | "\n" 2062 | ], 2063 | "name": "stdout" 2064 | }, 2065 | { 2066 | "output_type": "display_data", 2067 | "data": { 2068 | "application/vnd.jupyter.widget-view+json": { 2069 | "model_id": "62d0e9673815468ba407019771a0b2ba", 2070 | "version_minor": 0, 2071 | "version_major": 2 2072 | }, 2073 | "text/plain": [ 2074 | "HBox(children=(FloatProgress(value=0.0, description='Downloading', max=467042463.0, style=ProgressStyle(descri…" 2075 | ] 2076 | }, 2077 | "metadata": { 2078 | "tags": [] 2079 | } 2080 | }, 2081 | { 2082 | "output_type": "stream", 2083 | "text": [ 2084 | "\n" 2085 | ], 2086 | "name": "stdout" 2087 | } 2088 | ] 2089 | }, 2090 | { 2091 | "cell_type": "code", 2092 | "metadata": { 2093 | "id": "KYsVoULvfmvD", 2094 | "colab_type": "code", 2095 | "outputId": "f744743c-6554-4bc2-b3b0-1236164352fc", 2096 | "colab": { 2097 | "base_uri": "https://localhost:8080/", 2098 | "height": 1000 2099 | } 2100 | }, 2101 | "source": [ 2102 | "model" 2103 | ], 
2104 | "execution_count": 0, 2105 | "outputs": [ 2106 | { 2107 | "output_type": "execute_result", 2108 | "data": { 2109 | "text/plain": [ 2110 | "XLNetForSequenceClassification(\n", 2111 | " (transformer): XLNetModel(\n", 2112 | " (word_embedding): Embedding(32000, 768)\n", 2113 | " (layer): ModuleList(\n", 2114 | " (0): XLNetLayer(\n", 2115 | " (rel_attn): XLNetRelativeAttention(\n", 2116 | " (layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", 2117 | " (dropout): Dropout(p=0.1, inplace=False)\n", 2118 | " )\n", 2119 | " (ff): XLNetFeedForward(\n", 2120 | " (layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", 2121 | " (layer_1): Linear(in_features=768, out_features=3072, bias=True)\n", 2122 | " (layer_2): Linear(in_features=3072, out_features=768, bias=True)\n", 2123 | " (dropout): Dropout(p=0.1, inplace=False)\n", 2124 | " )\n", 2125 | " (dropout): Dropout(p=0.1, inplace=False)\n", 2126 | " )\n", 2127 | " (1): XLNetLayer(\n", 2128 | " (rel_attn): XLNetRelativeAttention(\n", 2129 | " (layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", 2130 | " (dropout): Dropout(p=0.1, inplace=False)\n", 2131 | " )\n", 2132 | " (ff): XLNetFeedForward(\n", 2133 | " (layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", 2134 | " (layer_1): Linear(in_features=768, out_features=3072, bias=True)\n", 2135 | " (layer_2): Linear(in_features=3072, out_features=768, bias=True)\n", 2136 | " (dropout): Dropout(p=0.1, inplace=False)\n", 2137 | " )\n", 2138 | " (dropout): Dropout(p=0.1, inplace=False)\n", 2139 | " )\n", 2140 | " (2): XLNetLayer(\n", 2141 | " (rel_attn): XLNetRelativeAttention(\n", 2142 | " (layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", 2143 | " (dropout): Dropout(p=0.1, inplace=False)\n", 2144 | " )\n", 2145 | " (ff): XLNetFeedForward(\n", 2146 | " (layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", 2147 | " (layer_1): Linear(in_features=768, 
out_features=3072, bias=True)\n", 2148 | " (layer_2): Linear(in_features=3072, out_features=768, bias=True)\n", 2149 | " (dropout): Dropout(p=0.1, inplace=False)\n", 2150 | " )\n", 2151 | " (dropout): Dropout(p=0.1, inplace=False)\n", 2152 | " )\n", 2153 | " (3): XLNetLayer(\n", 2154 | " (rel_attn): XLNetRelativeAttention(\n", 2155 | " (layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", 2156 | " (dropout): Dropout(p=0.1, inplace=False)\n", 2157 | " )\n", 2158 | " (ff): XLNetFeedForward(\n", 2159 | " (layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", 2160 | " (layer_1): Linear(in_features=768, out_features=3072, bias=True)\n", 2161 | " (layer_2): Linear(in_features=3072, out_features=768, bias=True)\n", 2162 | " (dropout): Dropout(p=0.1, inplace=False)\n", 2163 | " )\n", 2164 | " (dropout): Dropout(p=0.1, inplace=False)\n", 2165 | " )\n", 2166 | " (4): XLNetLayer(\n", 2167 | " (rel_attn): XLNetRelativeAttention(\n", 2168 | " (layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", 2169 | " (dropout): Dropout(p=0.1, inplace=False)\n", 2170 | " )\n", 2171 | " (ff): XLNetFeedForward(\n", 2172 | " (layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", 2173 | " (layer_1): Linear(in_features=768, out_features=3072, bias=True)\n", 2174 | " (layer_2): Linear(in_features=3072, out_features=768, bias=True)\n", 2175 | " (dropout): Dropout(p=0.1, inplace=False)\n", 2176 | " )\n", 2177 | " (dropout): Dropout(p=0.1, inplace=False)\n", 2178 | " )\n", 2179 | " (5): XLNetLayer(\n", 2180 | " (rel_attn): XLNetRelativeAttention(\n", 2181 | " (layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", 2182 | " (dropout): Dropout(p=0.1, inplace=False)\n", 2183 | " )\n", 2184 | " (ff): XLNetFeedForward(\n", 2185 | " (layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", 2186 | " (layer_1): Linear(in_features=768, out_features=3072, bias=True)\n", 2187 | " (layer_2): 
Linear(in_features=3072, out_features=768, bias=True)\n", 2188 | " (dropout): Dropout(p=0.1, inplace=False)\n", 2189 | " )\n", 2190 | " (dropout): Dropout(p=0.1, inplace=False)\n", 2191 | " )\n", 2192 | " (6): XLNetLayer(\n", 2193 | " (rel_attn): XLNetRelativeAttention(\n", 2194 | " (layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", 2195 | " (dropout): Dropout(p=0.1, inplace=False)\n", 2196 | " )\n", 2197 | " (ff): XLNetFeedForward(\n", 2198 | " (layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", 2199 | " (layer_1): Linear(in_features=768, out_features=3072, bias=True)\n", 2200 | " (layer_2): Linear(in_features=3072, out_features=768, bias=True)\n", 2201 | " (dropout): Dropout(p=0.1, inplace=False)\n", 2202 | " )\n", 2203 | " (dropout): Dropout(p=0.1, inplace=False)\n", 2204 | " )\n", 2205 | " (7): XLNetLayer(\n", 2206 | " (rel_attn): XLNetRelativeAttention(\n", 2207 | " (layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", 2208 | " (dropout): Dropout(p=0.1, inplace=False)\n", 2209 | " )\n", 2210 | " (ff): XLNetFeedForward(\n", 2211 | " (layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", 2212 | " (layer_1): Linear(in_features=768, out_features=3072, bias=True)\n", 2213 | " (layer_2): Linear(in_features=3072, out_features=768, bias=True)\n", 2214 | " (dropout): Dropout(p=0.1, inplace=False)\n", 2215 | " )\n", 2216 | " (dropout): Dropout(p=0.1, inplace=False)\n", 2217 | " )\n", 2218 | " (8): XLNetLayer(\n", 2219 | " (rel_attn): XLNetRelativeAttention(\n", 2220 | " (layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", 2221 | " (dropout): Dropout(p=0.1, inplace=False)\n", 2222 | " )\n", 2223 | " (ff): XLNetFeedForward(\n", 2224 | " (layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", 2225 | " (layer_1): Linear(in_features=768, out_features=3072, bias=True)\n", 2226 | " (layer_2): Linear(in_features=3072, out_features=768, bias=True)\n", 2227 | " 
(dropout): Dropout(p=0.1, inplace=False)\n", 2228 | " )\n", 2229 | " (dropout): Dropout(p=0.1, inplace=False)\n", 2230 | " )\n", 2231 | " (9): XLNetLayer(\n", 2232 | " (rel_attn): XLNetRelativeAttention(\n", 2233 | " (layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", 2234 | " (dropout): Dropout(p=0.1, inplace=False)\n", 2235 | " )\n", 2236 | " (ff): XLNetFeedForward(\n", 2237 | " (layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", 2238 | " (layer_1): Linear(in_features=768, out_features=3072, bias=True)\n", 2239 | " (layer_2): Linear(in_features=3072, out_features=768, bias=True)\n", 2240 | " (dropout): Dropout(p=0.1, inplace=False)\n", 2241 | " )\n", 2242 | " (dropout): Dropout(p=0.1, inplace=False)\n", 2243 | " )\n", 2244 | " (10): XLNetLayer(\n", 2245 | " (rel_attn): XLNetRelativeAttention(\n", 2246 | " (layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", 2247 | " (dropout): Dropout(p=0.1, inplace=False)\n", 2248 | " )\n", 2249 | " (ff): XLNetFeedForward(\n", 2250 | " (layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", 2251 | " (layer_1): Linear(in_features=768, out_features=3072, bias=True)\n", 2252 | " (layer_2): Linear(in_features=3072, out_features=768, bias=True)\n", 2253 | " (dropout): Dropout(p=0.1, inplace=False)\n", 2254 | " )\n", 2255 | " (dropout): Dropout(p=0.1, inplace=False)\n", 2256 | " )\n", 2257 | " (11): XLNetLayer(\n", 2258 | " (rel_attn): XLNetRelativeAttention(\n", 2259 | " (layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", 2260 | " (dropout): Dropout(p=0.1, inplace=False)\n", 2261 | " )\n", 2262 | " (ff): XLNetFeedForward(\n", 2263 | " (layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", 2264 | " (layer_1): Linear(in_features=768, out_features=3072, bias=True)\n", 2265 | " (layer_2): Linear(in_features=3072, out_features=768, bias=True)\n", 2266 | " (dropout): Dropout(p=0.1, inplace=False)\n", 2267 | " )\n", 2268 | " 
# Number of full passes over the training data.
EPOCHS = 3

# Split the model parameters into two optimizer groups: one that receives
# weight decay and one (biases and layer-normalisation parameters) that does not.
param_optimizer = list(model.named_parameters())
# The printed model structure shows that XLNet names its normalisation modules
# `layer_norm`, so the BERT-style 'LayerNorm.*' fragments never match a
# parameter name of this model. 'layer_norm.weight' is required for the
# LayerNorm weights to be exempt from decay ('layer_norm.bias' is already
# caught by 'bias'). The old fragments are kept so the list keeps working for
# models that do use the 'LayerNorm' naming.
no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight', 'layer_norm.weight']
optimizer_grouped_parameters = [
    {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
    {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
]
optimizer = AdamW(optimizer_grouped_parameters, lr=3e-5)

# The schedule spans every batch of every epoch.
total_steps = len(train_data_loader) * EPOCHS

# Linear learning-rate schedule over `total_steps` with no warm-up steps.
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=0,
    num_training_steps=total_steps
)
# Pull a single batch from the validation loader to inspect what the model receives.
data = next(iter(val_data_loader))
data.keys()

input_ids = data['input_ids'].to(device)
attention_mask = data['attention_mask'].to(device)
targets = data['targets'].to(device)
# `input_ids[0]` below displays as a nested (1 x seq) tensor, i.e. every example
# carries an extra singleton dimension. Flatten it using the batch's own size
# instead of assuming a 4 x 512 batch.
print(input_ids.reshape(input_ids.shape[0], -1).shape) # batch size x seq length
print(attention_mask.shape) # batch size x seq length

input_ids[0]

# With labels supplied, the displayed result is a (loss, logits) tuple.
outputs = model(input_ids.reshape(input_ids.shape[0], -1), token_type_ids=None, attention_mask=attention_mask, labels=targets)
outputs
def _forward_batch(model, batch, device):
    """Run one forward pass on a batch dict and return ``(loss, logits, targets)``.

    ``batch["input_ids"]`` is flattened to ``(batch_size, -1)`` using the
    batch's own leading dimension, so a short final batch or a different
    batch size / sequence length no longer breaks the reshape.
    """
    input_ids = batch["input_ids"]
    input_ids = input_ids.reshape(input_ids.shape[0], -1).to(device)
    attention_mask = batch["attention_mask"].to(device)
    targets = batch["targets"].to(device)

    outputs = model(input_ids=input_ids, token_type_ids=None, attention_mask=attention_mask, labels=targets)
    # With labels supplied the model returns the loss first and the logits second.
    return outputs[0], outputs[1], targets


def _count_correct(logits, targets):
    """Return how many rows of ``logits`` have their arg-max equal to ``targets``."""
    return int((logits.argmax(dim=1) == targets).sum().item())


def train_epoch(model, data_loader, optimizer, device, scheduler, n_examples):
    """Train ``model`` for one pass over ``data_loader``.

    Args:
        model: callable returning ``(loss, logits)`` when given ``labels``.
        data_loader: iterable of dicts with ``input_ids``, ``attention_mask``
            and ``targets`` tensors.
        optimizer: optimizer stepped once per batch.
        device: device the batch tensors are moved to.
        scheduler: learning-rate scheduler stepped once per batch.
        n_examples: accepted for interface compatibility; not used.

    Returns:
        ``(accuracy, mean_loss)``: accuracy is the fraction of correctly
        classified examples over the whole epoch (weighted by batch size),
        mean_loss is the mean of the per-batch losses. An empty loader yields
        ``(0.0, nan)`` instead of raising.
    """
    model = model.train()
    losses = []
    correct = 0
    seen = 0

    for d in data_loader:
        loss, logits, targets = _forward_batch(model, d, device)

        correct += _count_correct(logits, targets)
        seen += targets.shape[0]
        losses.append(loss.item())

        loss.backward()

        # Clip the global gradient norm before the update.
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        scheduler.step()
        optimizer.zero_grad()

    if seen == 0:
        return 0.0, float("nan")
    return correct / seen, sum(losses) / len(losses)


def eval_model(model, data_loader, device, n_examples):
    """Evaluate ``model`` on ``data_loader`` without tracking gradients.

    Args:
        model: callable returning ``(loss, logits)`` when given ``labels``.
        data_loader: iterable of dicts with ``input_ids``, ``attention_mask``
            and ``targets`` tensors.
        device: device the batch tensors are moved to.
        n_examples: accepted for interface compatibility; not used.

    Returns:
        ``(accuracy, mean_loss)`` with the same definitions as ``train_epoch``;
        an empty loader yields ``(0.0, nan)``.
    """
    model = model.eval()
    losses = []
    correct = 0
    seen = 0

    with torch.no_grad():
        for d in data_loader:
            loss, logits, targets = _forward_batch(model, d, device)

            correct += _count_correct(logits, targets)
            seen += targets.shape[0]
            losses.append(loss.item())

    if seen == 0:
        return 0.0, float("nan")
    return correct / seen, sum(losses) / len(losses)
"1970b1cb-c420-406d-d5e0-a7f490e45500", 2665 | "colab": { 2666 | "base_uri": "https://localhost:8080/", 2667 | "height": 301 2668 | } 2669 | }, 2670 | "source": [ 2671 | "%%time\n", 2672 | "history = defaultdict(list)\n", 2673 | "best_accuracy = 0\n", 2674 | "\n", 2675 | "for epoch in range(EPOCHS):\n", 2676 | " print(f'Epoch {epoch + 1}/{EPOCHS}')\n", 2677 | " print('-' * 10)\n", 2678 | "\n", 2679 | " train_acc, train_loss = train_epoch(\n", 2680 | " model,\n", 2681 | " train_data_loader, \n", 2682 | " optimizer, \n", 2683 | " device, \n", 2684 | " scheduler, \n", 2685 | " len(df_train)\n", 2686 | " )\n", 2687 | "\n", 2688 | " print(f'Train loss {train_loss} Train accuracy {train_acc}')\n", 2689 | "\n", 2690 | " val_acc, val_loss = eval_model(\n", 2691 | " model,\n", 2692 | " val_data_loader, \n", 2693 | " device, \n", 2694 | " len(df_val)\n", 2695 | " )\n", 2696 | "\n", 2697 | " print(f'Val loss {val_loss} Val accuracy {val_acc}')\n", 2698 | " print()\n", 2699 | "\n", 2700 | " history['train_acc'].append(train_acc)\n", 2701 | " history['train_loss'].append(train_loss)\n", 2702 | " history['val_acc'].append(val_acc)\n", 2703 | " history['val_loss'].append(val_loss)\n", 2704 | "\n", 2705 | " if val_acc > best_accuracy:\n", 2706 | " torch.save(model.state_dict(), '/content/drive/My Drive/NLP/Sentiment Analysis Series/models/xlnet_model.bin')\n", 2707 | " best_accuracy = val_acc" 2708 | ], 2709 | "execution_count": 0, 2710 | "outputs": [ 2711 | { 2712 | "output_type": "stream", 2713 | "text": [ 2714 | "Epoch 1/3\n", 2715 | "----------\n", 2716 | "Train loss 0.40229895541320243 Train accuracy 0.90525\n", 2717 | "Val loss 0.3111661048134168 Val accuracy 0.9308333333333333\n", 2718 | "\n", 2719 | "Epoch 2/3\n", 2720 | "----------\n", 2721 | "Train loss 0.2054168249045809 Train accuracy 0.9594166666666667\n", 2722 | "Val loss 0.3556234954794248 Val accuracy 0.9366666666666666\n", 2723 | "\n", 2724 | "Epoch 3/3\n", 2725 | "----------\n", 2726 | "Train loss 
0.08638643393417199 Train accuracy 0.985\n", 2727 | "Val loss 0.3777355106075605 Val accuracy 0.9403333333333334\n", 2728 | "\n", 2729 | "CPU times: user 2h 31min 50s, sys: 1h 36min 22s, total: 4h 8min 12s\n", 2730 | "Wall time: 4h 8min 44s\n" 2731 | ], 2732 | "name": "stdout" 2733 | } 2734 | ] 2735 | }, 2736 | { 2737 | "cell_type": "markdown", 2738 | "metadata": { 2739 | "id": "qe08GjtNL-Sh", 2740 | "colab_type": "text" 2741 | }, 2742 | "source": [ 2743 | "### Evaluation of the fine-tuned model" 2744 | ] 2745 | }, 2746 | { 2747 | "cell_type": "code", 2748 | "metadata": { 2749 | "id": "XqtQkz2yrZE3", 2750 | "colab_type": "code", 2751 | "outputId": "b57d0042-0a00-4ae0-e177-3ebc889e29f8", 2752 | "colab": { 2753 | "base_uri": "https://localhost:8080/", 2754 | "height": 33 2755 | } 2756 | }, 2757 | "source": [ 2758 | "model.load_state_dict(torch.load('/content/drive/My Drive/NLP/Sentiment Analysis Series/models/xlnet_model.bin'))" 2759 | ], 2760 | "execution_count": 0, 2761 | "outputs": [ 2762 | { 2763 | "output_type": "execute_result", 2764 | "data": { 2765 | "text/plain": [ 2766 | "" 2767 | ] 2768 | }, 2769 | "metadata": { 2770 | "tags": [] 2771 | }, 2772 | "execution_count": 46 2773 | } 2774 | ] 2775 | }, 2776 | { 2777 | "cell_type": "code", 2778 | "metadata": { 2779 | "id": "26k0PMpdy1QT", 2780 | "colab_type": "code", 2781 | "colab": {} 2782 | }, 2783 | "source": [ 2784 | "model = model.to(device)" 2785 | ], 2786 | "execution_count": 0, 2787 | "outputs": [] 2788 | }, 2789 | { 2790 | "cell_type": "code", 2791 | "metadata": { 2792 | "id": "QHINzx6ezSD0", 2793 | "colab_type": "code", 2794 | "outputId": "48dd67b6-d26d-45cf-aeff-4c7eda62e5e5", 2795 | "colab": { 2796 | "base_uri": "https://localhost:8080/", 2797 | "height": 50 2798 | } 2799 | }, 2800 | "source": [ 2801 | "test_acc, test_loss = eval_model(\n", 2802 | " model,\n", 2803 | " test_data_loader,\n", 2804 | " device,\n", 2805 | " len(df_test)\n", 2806 | ")\n", 2807 | "\n", 2808 | "print('Test Accuracy :', 
def get_predictions(model, data_loader, device=None):
    """Collect predictions, class probabilities and true labels for a data loader.

    Args:
        model: callable returning ``(loss, logits)`` when given ``labels``.
        data_loader: iterable of dicts with ``review_text``, ``input_ids``,
            ``attention_mask`` and ``targets`` entries.
        device: device the batch tensors are moved to. When omitted, the
            device of the model's first parameter is used (CPU if the model
            exposes no parameters), so the function no longer depends on a
            global ``device`` variable.

    Returns:
        ``(review_texts, predictions, prediction_probs, real_values)``:
        the list of raw texts and three CPU tensors holding the predicted
        class index, the softmax probabilities and the true label of every
        example, in loader order. An empty loader yields an empty list and
        empty tensors.
    """
    model = model.eval()

    if device is None:
        parameters = getattr(model, "parameters", None)
        first_param = next(iter(parameters()), None) if callable(parameters) else None
        device = first_param.device if first_param is not None else torch.device("cpu")

    review_texts = []
    predictions = []
    prediction_probs = []
    real_values = []

    with torch.no_grad():
        for d in data_loader:
            # Flatten to (batch, seq) from the batch's own size so a short
            # final batch or another batch size / max length still works.
            input_ids = d["input_ids"]
            input_ids = input_ids.reshape(input_ids.shape[0], -1).to(device)
            attention_mask = d["attention_mask"].to(device)
            targets = d["targets"].to(device)

            outputs = model(input_ids=input_ids, token_type_ids=None, attention_mask=attention_mask, labels=targets)
            # Labels are supplied, so index 0 is the loss and index 1 the logits.
            logits = outputs[1]

            review_texts.extend(d["review_text"])
            predictions.append(logits.argmax(dim=1).cpu())
            prediction_probs.append(torch.softmax(logits, dim=1).cpu())
            real_values.append(targets.cpu())

    if not predictions:
        empty_labels = torch.empty(0, dtype=torch.long)
        return review_texts, empty_labels, torch.empty(0, 0), empty_labels.clone()

    return review_texts, torch.cat(predictions), torch.cat(prediction_probs), torch.cat(real_values)
# Run the fine-tuned model over the whole test loader.
y_review_texts, y_pred, y_pred_probs, y_test = get_predictions(
    model,
    test_data_loader
)

# Per-class precision / recall / F1 on the test set.
print(classification_report(y_test, y_pred, target_names=class_names))


def _pad_or_truncate(sequences, length):
    """Return ``sequences`` as an int64 tensor of shape ``(rows, length)``.

    Rows longer than ``length`` are cut on the right and shorter rows are
    zero-padded on the right, matching the previous
    ``pad_sequences(..., truncating="post", padding="post")`` call.
    """
    sequences = torch.as_tensor(sequences, dtype=torch.int64)
    sequences = sequences.reshape(sequences.shape[0], -1)[:, :length]
    missing = length - sequences.shape[1]
    if missing > 0:
        sequences = torch.nn.functional.pad(sequences, (0, missing), value=0)
    return sequences


def predict_sentiment(text):
    """Print the class scores and predicted sentiment for one raw review string.

    Uses the notebook-level ``tokenizer``, ``model``, ``device``, ``MAX_LEN``
    and ``class_names``. The model is switched to evaluation mode and the
    forward pass runs without gradient tracking.

    Args:
        text: the raw review text to classify.

    Returns:
        None; the scores, the text and the predicted label are printed.
    """
    review_text = text

    encoded_review = tokenizer.encode_plus(
        review_text,
        max_length=MAX_LEN,
        add_special_tokens=True,
        return_token_type_ids=False,
        pad_to_max_length=False,
        return_attention_mask=True,
        return_tensors='pt',
    )

    # Pad on the right to MAX_LEN (the tokenizer is told not to pad itself).
    input_ids = _pad_or_truncate(encoded_review['input_ids'], MAX_LEN).to(device)
    attention_mask = _pad_or_truncate(encoded_review['attention_mask'], MAX_LEN).to(device)

    # Inference only: evaluation mode and no autograd graph.
    model.eval()
    with torch.no_grad():
        outputs = model(input_ids=input_ids, attention_mask=attention_mask)

    # No labels were passed, so index 0 holds the logits; take the single row.
    logits = outputs[0][0].cpu()

    probs = torch.softmax(logits, dim=-1).numpy().tolist()
    prediction = int(torch.argmax(logits, dim=-1))

    print("Positive score:", probs[1])
    print("Negative score:", probs[0])
    print(f'Review text: {review_text}')
    print(f'Sentiment  : {class_names[prediction]}')


text = "Movie is the worst one I have ever seen!! The story has no meaning at all"
predict_sentiment(text)
The story has no meaning at all\n", 2998 | "Sentiment : negative\n" 2999 | ], 3000 | "name": "stdout" 3001 | } 3002 | ] 3003 | }, 3004 | { 3005 | "cell_type": "code", 3006 | "metadata": { 3007 | "id": "2oO0OhNoF4wo", 3008 | "colab_type": "code", 3009 | "outputId": "bd7bd3c0-28c1-4df2-b86d-92327a7bac6d", 3010 | "colab": { 3011 | "base_uri": "https://localhost:8080/", 3012 | "height": 84 3013 | } 3014 | }, 3015 | "source": [ 3016 | "text = \"This is the best movie I have ever seen!! The story is such a motivation\"\n", 3017 | "predict_sentiment(text)" 3018 | ], 3019 | "execution_count": 0, 3020 | "outputs": [ 3021 | { 3022 | "output_type": "stream", 3023 | "text": [ 3024 | "Positive score: 0.9998512268066406\n", 3025 | "Negative score: 0.00014876725617796183\n", 3026 | "Review text: This is the best movie I have ever seen!! The story is such a motivation\n", 3027 | "Sentiment : positive\n" 3028 | ], 3029 | "name": "stdout" 3030 | } 3031 | ] 3032 | } 3033 | ] 3034 | } --------------------------------------------------------------------------------