├── GANBERT_pytorch.ipynb ├── LICENSE └── README.md /GANBERT_pytorch.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "accelerator": "GPU", 6 | "colab": { 7 | "name": "GANBERT_pytorch croce.ipynb", 8 | "provenance": [], 9 | "collapsed_sections": [], 10 | "include_colab_link": true 11 | }, 12 | "kernelspec": { 13 | "display_name": "Python 3", 14 | "name": "python3" 15 | }, 16 | "widgets": { 17 | "application/vnd.jupyter.widget-state+json": { 18 | "22cde8fbae4e49af993e33f3f2d9a28e": { 19 | "model_module": "@jupyter-widgets/controls", 20 | "model_name": "HBoxModel", 21 | "model_module_version": "1.5.0", 22 | "state": { 23 | "_view_name": "HBoxView", 24 | "_dom_classes": [], 25 | "_model_name": "HBoxModel", 26 | "_view_module": "@jupyter-widgets/controls", 27 | "_model_module_version": "1.5.0", 28 | "_view_count": null, 29 | "_view_module_version": "1.5.0", 30 | "box_style": "", 31 | "layout": "IPY_MODEL_2c671428c4df4355a7a53c71e9bd14ee", 32 | "_model_module": "@jupyter-widgets/controls", 33 | "children": [ 34 | "IPY_MODEL_44660941ffcc44beae72a1974d458583", 35 | "IPY_MODEL_afdbd837001c4c74b5ac332ca061ef3a", 36 | "IPY_MODEL_b489404bf53b4fab9bdb1c0c79a33008" 37 | ] 38 | } 39 | }, 40 | "2c671428c4df4355a7a53c71e9bd14ee": { 41 | "model_module": "@jupyter-widgets/base", 42 | "model_name": "LayoutModel", 43 | "model_module_version": "1.2.0", 44 | "state": { 45 | "_view_name": "LayoutView", 46 | "grid_template_rows": null, 47 | "right": null, 48 | "justify_content": null, 49 | "_view_module": "@jupyter-widgets/base", 50 | "overflow": null, 51 | "_model_module_version": "1.2.0", 52 | "_view_count": null, 53 | "flex_flow": null, 54 | "width": null, 55 | "min_width": null, 56 | "border": null, 57 | "align_items": null, 58 | "bottom": null, 59 | "_model_module": "@jupyter-widgets/base", 60 | "top": null, 61 | "grid_column": null, 62 | "overflow_y": null, 63 | 
"overflow_x": null, 64 | "grid_auto_flow": null, 65 | "grid_area": null, 66 | "grid_template_columns": null, 67 | "flex": null, 68 | "_model_name": "LayoutModel", 69 | "justify_items": null, 70 | "grid_row": null, 71 | "max_height": null, 72 | "align_content": null, 73 | "visibility": null, 74 | "align_self": null, 75 | "height": null, 76 | "min_height": null, 77 | "padding": null, 78 | "grid_auto_rows": null, 79 | "grid_gap": null, 80 | "max_width": null, 81 | "order": null, 82 | "_view_module_version": "1.2.0", 83 | "grid_template_areas": null, 84 | "object_position": null, 85 | "object_fit": null, 86 | "grid_auto_columns": null, 87 | "margin": null, 88 | "display": null, 89 | "left": null 90 | } 91 | }, 92 | "44660941ffcc44beae72a1974d458583": { 93 | "model_module": "@jupyter-widgets/controls", 94 | "model_name": "HTMLModel", 95 | "model_module_version": "1.5.0", 96 | "state": { 97 | "_view_name": "HTMLView", 98 | "style": "IPY_MODEL_2bfb43b9e8604cbf8ebfd162267f1b8a", 99 | "_dom_classes": [], 100 | "description": "", 101 | "_model_name": "HTMLModel", 102 | "placeholder": "​", 103 | "_view_module": "@jupyter-widgets/controls", 104 | "_model_module_version": "1.5.0", 105 | "value": "Downloading: 100%", 106 | "_view_count": null, 107 | "_view_module_version": "1.5.0", 108 | "description_tooltip": null, 109 | "_model_module": "@jupyter-widgets/controls", 110 | "layout": "IPY_MODEL_4f556917850542da8508f9f839cae9bc" 111 | } 112 | }, 113 | "afdbd837001c4c74b5ac332ca061ef3a": { 114 | "model_module": "@jupyter-widgets/controls", 115 | "model_name": "FloatProgressModel", 116 | "model_module_version": "1.5.0", 117 | "state": { 118 | "_view_name": "ProgressView", 119 | "style": "IPY_MODEL_28f045edfa48462fa96a94cffd3b143f", 120 | "_dom_classes": [], 121 | "description": "", 122 | "_model_name": "FloatProgressModel", 123 | "bar_style": "success", 124 | "max": 570, 125 | "_view_module": "@jupyter-widgets/controls", 126 | "_model_module_version": "1.5.0", 127 | "value": 570, 
128 | "_view_count": null, 129 | "_view_module_version": "1.5.0", 130 | "orientation": "horizontal", 131 | "min": 0, 132 | "description_tooltip": null, 133 | "_model_module": "@jupyter-widgets/controls", 134 | "layout": "IPY_MODEL_14a6abbb244b41f89626a37640c63118" 135 | } 136 | }, 137 | "b489404bf53b4fab9bdb1c0c79a33008": { 138 | "model_module": "@jupyter-widgets/controls", 139 | "model_name": "HTMLModel", 140 | "model_module_version": "1.5.0", 141 | "state": { 142 | "_view_name": "HTMLView", 143 | "style": "IPY_MODEL_0dccd0ab28c34880be92b35aa05e6ffb", 144 | "_dom_classes": [], 145 | "description": "", 146 | "_model_name": "HTMLModel", 147 | "placeholder": "​", 148 | "_view_module": "@jupyter-widgets/controls", 149 | "_model_module_version": "1.5.0", 150 | "value": " 570/570 [00:00<00:00, 12.0kB/s]", 151 | "_view_count": null, 152 | "_view_module_version": "1.5.0", 153 | "description_tooltip": null, 154 | "_model_module": "@jupyter-widgets/controls", 155 | "layout": "IPY_MODEL_bf127a02cf6647aba4035c5cbfadc378" 156 | } 157 | }, 158 | "2bfb43b9e8604cbf8ebfd162267f1b8a": { 159 | "model_module": "@jupyter-widgets/controls", 160 | "model_name": "DescriptionStyleModel", 161 | "model_module_version": "1.5.0", 162 | "state": { 163 | "_view_name": "StyleView", 164 | "_model_name": "DescriptionStyleModel", 165 | "description_width": "", 166 | "_view_module": "@jupyter-widgets/base", 167 | "_model_module_version": "1.5.0", 168 | "_view_count": null, 169 | "_view_module_version": "1.2.0", 170 | "_model_module": "@jupyter-widgets/controls" 171 | } 172 | }, 173 | "4f556917850542da8508f9f839cae9bc": { 174 | "model_module": "@jupyter-widgets/base", 175 | "model_name": "LayoutModel", 176 | "model_module_version": "1.2.0", 177 | "state": { 178 | "_view_name": "LayoutView", 179 | "grid_template_rows": null, 180 | "right": null, 181 | "justify_content": null, 182 | "_view_module": "@jupyter-widgets/base", 183 | "overflow": null, 184 | "_model_module_version": "1.2.0", 185 | 
"_view_count": null, 186 | "flex_flow": null, 187 | "width": null, 188 | "min_width": null, 189 | "border": null, 190 | "align_items": null, 191 | "bottom": null, 192 | "_model_module": "@jupyter-widgets/base", 193 | "top": null, 194 | "grid_column": null, 195 | "overflow_y": null, 196 | "overflow_x": null, 197 | "grid_auto_flow": null, 198 | "grid_area": null, 199 | "grid_template_columns": null, 200 | "flex": null, 201 | "_model_name": "LayoutModel", 202 | "justify_items": null, 203 | "grid_row": null, 204 | "max_height": null, 205 | "align_content": null, 206 | "visibility": null, 207 | "align_self": null, 208 | "height": null, 209 | "min_height": null, 210 | "padding": null, 211 | "grid_auto_rows": null, 212 | "grid_gap": null, 213 | "max_width": null, 214 | "order": null, 215 | "_view_module_version": "1.2.0", 216 | "grid_template_areas": null, 217 | "object_position": null, 218 | "object_fit": null, 219 | "grid_auto_columns": null, 220 | "margin": null, 221 | "display": null, 222 | "left": null 223 | } 224 | }, 225 | "28f045edfa48462fa96a94cffd3b143f": { 226 | "model_module": "@jupyter-widgets/controls", 227 | "model_name": "ProgressStyleModel", 228 | "model_module_version": "1.5.0", 229 | "state": { 230 | "_view_name": "StyleView", 231 | "_model_name": "ProgressStyleModel", 232 | "description_width": "", 233 | "_view_module": "@jupyter-widgets/base", 234 | "_model_module_version": "1.5.0", 235 | "_view_count": null, 236 | "_view_module_version": "1.2.0", 237 | "bar_color": null, 238 | "_model_module": "@jupyter-widgets/controls" 239 | } 240 | }, 241 | "14a6abbb244b41f89626a37640c63118": { 242 | "model_module": "@jupyter-widgets/base", 243 | "model_name": "LayoutModel", 244 | "model_module_version": "1.2.0", 245 | "state": { 246 | "_view_name": "LayoutView", 247 | "grid_template_rows": null, 248 | "right": null, 249 | "justify_content": null, 250 | "_view_module": "@jupyter-widgets/base", 251 | "overflow": null, 252 | "_model_module_version": "1.2.0", 253 | 
"_view_count": null, 254 | "flex_flow": null, 255 | "width": null, 256 | "min_width": null, 257 | "border": null, 258 | "align_items": null, 259 | "bottom": null, 260 | "_model_module": "@jupyter-widgets/base", 261 | "top": null, 262 | "grid_column": null, 263 | "overflow_y": null, 264 | "overflow_x": null, 265 | "grid_auto_flow": null, 266 | "grid_area": null, 267 | "grid_template_columns": null, 268 | "flex": null, 269 | "_model_name": "LayoutModel", 270 | "justify_items": null, 271 | "grid_row": null, 272 | "max_height": null, 273 | "align_content": null, 274 | "visibility": null, 275 | "align_self": null, 276 | "height": null, 277 | "min_height": null, 278 | "padding": null, 279 | "grid_auto_rows": null, 280 | "grid_gap": null, 281 | "max_width": null, 282 | "order": null, 283 | "_view_module_version": "1.2.0", 284 | "grid_template_areas": null, 285 | "object_position": null, 286 | "object_fit": null, 287 | "grid_auto_columns": null, 288 | "margin": null, 289 | "display": null, 290 | "left": null 291 | } 292 | }, 293 | "0dccd0ab28c34880be92b35aa05e6ffb": { 294 | "model_module": "@jupyter-widgets/controls", 295 | "model_name": "DescriptionStyleModel", 296 | "model_module_version": "1.5.0", 297 | "state": { 298 | "_view_name": "StyleView", 299 | "_model_name": "DescriptionStyleModel", 300 | "description_width": "", 301 | "_view_module": "@jupyter-widgets/base", 302 | "_model_module_version": "1.5.0", 303 | "_view_count": null, 304 | "_view_module_version": "1.2.0", 305 | "_model_module": "@jupyter-widgets/controls" 306 | } 307 | }, 308 | "bf127a02cf6647aba4035c5cbfadc378": { 309 | "model_module": "@jupyter-widgets/base", 310 | "model_name": "LayoutModel", 311 | "model_module_version": "1.2.0", 312 | "state": { 313 | "_view_name": "LayoutView", 314 | "grid_template_rows": null, 315 | "right": null, 316 | "justify_content": null, 317 | "_view_module": "@jupyter-widgets/base", 318 | "overflow": null, 319 | "_model_module_version": "1.2.0", 320 | "_view_count": null, 
321 | "flex_flow": null, 322 | "width": null, 323 | "min_width": null, 324 | "border": null, 325 | "align_items": null, 326 | "bottom": null, 327 | "_model_module": "@jupyter-widgets/base", 328 | "top": null, 329 | "grid_column": null, 330 | "overflow_y": null, 331 | "overflow_x": null, 332 | "grid_auto_flow": null, 333 | "grid_area": null, 334 | "grid_template_columns": null, 335 | "flex": null, 336 | "_model_name": "LayoutModel", 337 | "justify_items": null, 338 | "grid_row": null, 339 | "max_height": null, 340 | "align_content": null, 341 | "visibility": null, 342 | "align_self": null, 343 | "height": null, 344 | "min_height": null, 345 | "padding": null, 346 | "grid_auto_rows": null, 347 | "grid_gap": null, 348 | "max_width": null, 349 | "order": null, 350 | "_view_module_version": "1.2.0", 351 | "grid_template_areas": null, 352 | "object_position": null, 353 | "object_fit": null, 354 | "grid_auto_columns": null, 355 | "margin": null, 356 | "display": null, 357 | "left": null 358 | } 359 | }, 360 | "098f203f1209452a9d4192af92da7057": { 361 | "model_module": "@jupyter-widgets/controls", 362 | "model_name": "HBoxModel", 363 | "model_module_version": "1.5.0", 364 | "state": { 365 | "_view_name": "HBoxView", 366 | "_dom_classes": [], 367 | "_model_name": "HBoxModel", 368 | "_view_module": "@jupyter-widgets/controls", 369 | "_model_module_version": "1.5.0", 370 | "_view_count": null, 371 | "_view_module_version": "1.5.0", 372 | "box_style": "", 373 | "layout": "IPY_MODEL_ec76cfd2d1da498e9b66885ff1be46b3", 374 | "_model_module": "@jupyter-widgets/controls", 375 | "children": [ 376 | "IPY_MODEL_b94243bfea7e4441b809b38c7cecd875", 377 | "IPY_MODEL_0097aa33393342cd99be3dcc30edc5a0", 378 | "IPY_MODEL_293c4d8660f546f79b43e0dc63250c2f" 379 | ] 380 | } 381 | }, 382 | "ec76cfd2d1da498e9b66885ff1be46b3": { 383 | "model_module": "@jupyter-widgets/base", 384 | "model_name": "LayoutModel", 385 | "model_module_version": "1.2.0", 386 | "state": { 387 | "_view_name": "LayoutView", 
388 | "grid_template_rows": null, 389 | "right": null, 390 | "justify_content": null, 391 | "_view_module": "@jupyter-widgets/base", 392 | "overflow": null, 393 | "_model_module_version": "1.2.0", 394 | "_view_count": null, 395 | "flex_flow": null, 396 | "width": null, 397 | "min_width": null, 398 | "border": null, 399 | "align_items": null, 400 | "bottom": null, 401 | "_model_module": "@jupyter-widgets/base", 402 | "top": null, 403 | "grid_column": null, 404 | "overflow_y": null, 405 | "overflow_x": null, 406 | "grid_auto_flow": null, 407 | "grid_area": null, 408 | "grid_template_columns": null, 409 | "flex": null, 410 | "_model_name": "LayoutModel", 411 | "justify_items": null, 412 | "grid_row": null, 413 | "max_height": null, 414 | "align_content": null, 415 | "visibility": null, 416 | "align_self": null, 417 | "height": null, 418 | "min_height": null, 419 | "padding": null, 420 | "grid_auto_rows": null, 421 | "grid_gap": null, 422 | "max_width": null, 423 | "order": null, 424 | "_view_module_version": "1.2.0", 425 | "grid_template_areas": null, 426 | "object_position": null, 427 | "object_fit": null, 428 | "grid_auto_columns": null, 429 | "margin": null, 430 | "display": null, 431 | "left": null 432 | } 433 | }, 434 | "b94243bfea7e4441b809b38c7cecd875": { 435 | "model_module": "@jupyter-widgets/controls", 436 | "model_name": "HTMLModel", 437 | "model_module_version": "1.5.0", 438 | "state": { 439 | "_view_name": "HTMLView", 440 | "style": "IPY_MODEL_9fc9479d78e442db91360de7f45b6d7f", 441 | "_dom_classes": [], 442 | "description": "", 443 | "_model_name": "HTMLModel", 444 | "placeholder": "​", 445 | "_view_module": "@jupyter-widgets/controls", 446 | "_model_module_version": "1.5.0", 447 | "value": "Downloading: 100%", 448 | "_view_count": null, 449 | "_view_module_version": "1.5.0", 450 | "description_tooltip": null, 451 | "_model_module": "@jupyter-widgets/controls", 452 | "layout": "IPY_MODEL_14bf2ab82bfc45bd8a3b93f2ab9aa656" 453 | } 454 | }, 455 | 
"0097aa33393342cd99be3dcc30edc5a0": { 456 | "model_module": "@jupyter-widgets/controls", 457 | "model_name": "FloatProgressModel", 458 | "model_module_version": "1.5.0", 459 | "state": { 460 | "_view_name": "ProgressView", 461 | "style": "IPY_MODEL_a56008dc6c6b41c3850611cc15fb6ea8", 462 | "_dom_classes": [], 463 | "description": "", 464 | "_model_name": "FloatProgressModel", 465 | "bar_style": "success", 466 | "max": 435779157, 467 | "_view_module": "@jupyter-widgets/controls", 468 | "_model_module_version": "1.5.0", 469 | "value": 435779157, 470 | "_view_count": null, 471 | "_view_module_version": "1.5.0", 472 | "orientation": "horizontal", 473 | "min": 0, 474 | "description_tooltip": null, 475 | "_model_module": "@jupyter-widgets/controls", 476 | "layout": "IPY_MODEL_504e8c3a107e4e04b51df39e1b3c584e" 477 | } 478 | }, 479 | "293c4d8660f546f79b43e0dc63250c2f": { 480 | "model_module": "@jupyter-widgets/controls", 481 | "model_name": "HTMLModel", 482 | "model_module_version": "1.5.0", 483 | "state": { 484 | "_view_name": "HTMLView", 485 | "style": "IPY_MODEL_1f25240f5457445795af70e20e5903f9", 486 | "_dom_classes": [], 487 | "description": "", 488 | "_model_name": "HTMLModel", 489 | "placeholder": "​", 490 | "_view_module": "@jupyter-widgets/controls", 491 | "_model_module_version": "1.5.0", 492 | "value": " 436M/436M [00:29<00:00, 14.8MB/s]", 493 | "_view_count": null, 494 | "_view_module_version": "1.5.0", 495 | "description_tooltip": null, 496 | "_model_module": "@jupyter-widgets/controls", 497 | "layout": "IPY_MODEL_5110d1cb9a7547f49244113dc5dd8321" 498 | } 499 | }, 500 | "9fc9479d78e442db91360de7f45b6d7f": { 501 | "model_module": "@jupyter-widgets/controls", 502 | "model_name": "DescriptionStyleModel", 503 | "model_module_version": "1.5.0", 504 | "state": { 505 | "_view_name": "StyleView", 506 | "_model_name": "DescriptionStyleModel", 507 | "description_width": "", 508 | "_view_module": "@jupyter-widgets/base", 509 | "_model_module_version": "1.5.0", 510 | 
"_view_count": null, 511 | "_view_module_version": "1.2.0", 512 | "_model_module": "@jupyter-widgets/controls" 513 | } 514 | }, 515 | "14bf2ab82bfc45bd8a3b93f2ab9aa656": { 516 | "model_module": "@jupyter-widgets/base", 517 | "model_name": "LayoutModel", 518 | "model_module_version": "1.2.0", 519 | "state": { 520 | "_view_name": "LayoutView", 521 | "grid_template_rows": null, 522 | "right": null, 523 | "justify_content": null, 524 | "_view_module": "@jupyter-widgets/base", 525 | "overflow": null, 526 | "_model_module_version": "1.2.0", 527 | "_view_count": null, 528 | "flex_flow": null, 529 | "width": null, 530 | "min_width": null, 531 | "border": null, 532 | "align_items": null, 533 | "bottom": null, 534 | "_model_module": "@jupyter-widgets/base", 535 | "top": null, 536 | "grid_column": null, 537 | "overflow_y": null, 538 | "overflow_x": null, 539 | "grid_auto_flow": null, 540 | "grid_area": null, 541 | "grid_template_columns": null, 542 | "flex": null, 543 | "_model_name": "LayoutModel", 544 | "justify_items": null, 545 | "grid_row": null, 546 | "max_height": null, 547 | "align_content": null, 548 | "visibility": null, 549 | "align_self": null, 550 | "height": null, 551 | "min_height": null, 552 | "padding": null, 553 | "grid_auto_rows": null, 554 | "grid_gap": null, 555 | "max_width": null, 556 | "order": null, 557 | "_view_module_version": "1.2.0", 558 | "grid_template_areas": null, 559 | "object_position": null, 560 | "object_fit": null, 561 | "grid_auto_columns": null, 562 | "margin": null, 563 | "display": null, 564 | "left": null 565 | } 566 | }, 567 | "a56008dc6c6b41c3850611cc15fb6ea8": { 568 | "model_module": "@jupyter-widgets/controls", 569 | "model_name": "ProgressStyleModel", 570 | "model_module_version": "1.5.0", 571 | "state": { 572 | "_view_name": "StyleView", 573 | "_model_name": "ProgressStyleModel", 574 | "description_width": "", 575 | "_view_module": "@jupyter-widgets/base", 576 | "_model_module_version": "1.5.0", 577 | "_view_count": null, 578 | 
"_view_module_version": "1.2.0", 579 | "bar_color": null, 580 | "_model_module": "@jupyter-widgets/controls" 581 | } 582 | }, 583 | "504e8c3a107e4e04b51df39e1b3c584e": { 584 | "model_module": "@jupyter-widgets/base", 585 | "model_name": "LayoutModel", 586 | "model_module_version": "1.2.0", 587 | "state": { 588 | "_view_name": "LayoutView", 589 | "grid_template_rows": null, 590 | "right": null, 591 | "justify_content": null, 592 | "_view_module": "@jupyter-widgets/base", 593 | "overflow": null, 594 | "_model_module_version": "1.2.0", 595 | "_view_count": null, 596 | "flex_flow": null, 597 | "width": null, 598 | "min_width": null, 599 | "border": null, 600 | "align_items": null, 601 | "bottom": null, 602 | "_model_module": "@jupyter-widgets/base", 603 | "top": null, 604 | "grid_column": null, 605 | "overflow_y": null, 606 | "overflow_x": null, 607 | "grid_auto_flow": null, 608 | "grid_area": null, 609 | "grid_template_columns": null, 610 | "flex": null, 611 | "_model_name": "LayoutModel", 612 | "justify_items": null, 613 | "grid_row": null, 614 | "max_height": null, 615 | "align_content": null, 616 | "visibility": null, 617 | "align_self": null, 618 | "height": null, 619 | "min_height": null, 620 | "padding": null, 621 | "grid_auto_rows": null, 622 | "grid_gap": null, 623 | "max_width": null, 624 | "order": null, 625 | "_view_module_version": "1.2.0", 626 | "grid_template_areas": null, 627 | "object_position": null, 628 | "object_fit": null, 629 | "grid_auto_columns": null, 630 | "margin": null, 631 | "display": null, 632 | "left": null 633 | } 634 | }, 635 | "1f25240f5457445795af70e20e5903f9": { 636 | "model_module": "@jupyter-widgets/controls", 637 | "model_name": "DescriptionStyleModel", 638 | "model_module_version": "1.5.0", 639 | "state": { 640 | "_view_name": "StyleView", 641 | "_model_name": "DescriptionStyleModel", 642 | "description_width": "", 643 | "_view_module": "@jupyter-widgets/base", 644 | "_model_module_version": "1.5.0", 645 | "_view_count": null, 
646 | "_view_module_version": "1.2.0", 647 | "_model_module": "@jupyter-widgets/controls" 648 | } 649 | }, 650 | "5110d1cb9a7547f49244113dc5dd8321": { 651 | "model_module": "@jupyter-widgets/base", 652 | "model_name": "LayoutModel", 653 | "model_module_version": "1.2.0", 654 | "state": { 655 | "_view_name": "LayoutView", 656 | "grid_template_rows": null, 657 | "right": null, 658 | "justify_content": null, 659 | "_view_module": "@jupyter-widgets/base", 660 | "overflow": null, 661 | "_model_module_version": "1.2.0", 662 | "_view_count": null, 663 | "flex_flow": null, 664 | "width": null, 665 | "min_width": null, 666 | "border": null, 667 | "align_items": null, 668 | "bottom": null, 669 | "_model_module": "@jupyter-widgets/base", 670 | "top": null, 671 | "grid_column": null, 672 | "overflow_y": null, 673 | "overflow_x": null, 674 | "grid_auto_flow": null, 675 | "grid_area": null, 676 | "grid_template_columns": null, 677 | "flex": null, 678 | "_model_name": "LayoutModel", 679 | "justify_items": null, 680 | "grid_row": null, 681 | "max_height": null, 682 | "align_content": null, 683 | "visibility": null, 684 | "align_self": null, 685 | "height": null, 686 | "min_height": null, 687 | "padding": null, 688 | "grid_auto_rows": null, 689 | "grid_gap": null, 690 | "max_width": null, 691 | "order": null, 692 | "_view_module_version": "1.2.0", 693 | "grid_template_areas": null, 694 | "object_position": null, 695 | "object_fit": null, 696 | "grid_auto_columns": null, 697 | "margin": null, 698 | "display": null, 699 | "left": null 700 | } 701 | }, 702 | "d3a13b7869354881ac7b7887e05d56a7": { 703 | "model_module": "@jupyter-widgets/controls", 704 | "model_name": "HBoxModel", 705 | "model_module_version": "1.5.0", 706 | "state": { 707 | "_view_name": "HBoxView", 708 | "_dom_classes": [], 709 | "_model_name": "HBoxModel", 710 | "_view_module": "@jupyter-widgets/controls", 711 | "_model_module_version": "1.5.0", 712 | "_view_count": null, 713 | "_view_module_version": "1.5.0", 714 | 
"box_style": "", 715 | "layout": "IPY_MODEL_4d22913664fb4cdbb38e217d4197601d", 716 | "_model_module": "@jupyter-widgets/controls", 717 | "children": [ 718 | "IPY_MODEL_870fe8f58bb24a47b6a98fa0eed0ebf5", 719 | "IPY_MODEL_193a5a054d6f4a319842820c4c308322", 720 | "IPY_MODEL_5a478b81997c4263a60d938447232fd9" 721 | ] 722 | } 723 | }, 724 | "4d22913664fb4cdbb38e217d4197601d": { 725 | "model_module": "@jupyter-widgets/base", 726 | "model_name": "LayoutModel", 727 | "model_module_version": "1.2.0", 728 | "state": { 729 | "_view_name": "LayoutView", 730 | "grid_template_rows": null, 731 | "right": null, 732 | "justify_content": null, 733 | "_view_module": "@jupyter-widgets/base", 734 | "overflow": null, 735 | "_model_module_version": "1.2.0", 736 | "_view_count": null, 737 | "flex_flow": null, 738 | "width": null, 739 | "min_width": null, 740 | "border": null, 741 | "align_items": null, 742 | "bottom": null, 743 | "_model_module": "@jupyter-widgets/base", 744 | "top": null, 745 | "grid_column": null, 746 | "overflow_y": null, 747 | "overflow_x": null, 748 | "grid_auto_flow": null, 749 | "grid_area": null, 750 | "grid_template_columns": null, 751 | "flex": null, 752 | "_model_name": "LayoutModel", 753 | "justify_items": null, 754 | "grid_row": null, 755 | "max_height": null, 756 | "align_content": null, 757 | "visibility": null, 758 | "align_self": null, 759 | "height": null, 760 | "min_height": null, 761 | "padding": null, 762 | "grid_auto_rows": null, 763 | "grid_gap": null, 764 | "max_width": null, 765 | "order": null, 766 | "_view_module_version": "1.2.0", 767 | "grid_template_areas": null, 768 | "object_position": null, 769 | "object_fit": null, 770 | "grid_auto_columns": null, 771 | "margin": null, 772 | "display": null, 773 | "left": null 774 | } 775 | }, 776 | "870fe8f58bb24a47b6a98fa0eed0ebf5": { 777 | "model_module": "@jupyter-widgets/controls", 778 | "model_name": "HTMLModel", 779 | "model_module_version": "1.5.0", 780 | "state": { 781 | "_view_name": "HTMLView", 
782 | "style": "IPY_MODEL_5fbebd67d40e470e8a75a4b5b540bbdf", 783 | "_dom_classes": [], 784 | "description": "", 785 | "_model_name": "HTMLModel", 786 | "placeholder": "​", 787 | "_view_module": "@jupyter-widgets/controls", 788 | "_model_module_version": "1.5.0", 789 | "value": "Downloading: 100%", 790 | "_view_count": null, 791 | "_view_module_version": "1.5.0", 792 | "description_tooltip": null, 793 | "_model_module": "@jupyter-widgets/controls", 794 | "layout": "IPY_MODEL_d8b619dff35e4f8cabf586221ad8c962" 795 | } 796 | }, 797 | "193a5a054d6f4a319842820c4c308322": { 798 | "model_module": "@jupyter-widgets/controls", 799 | "model_name": "FloatProgressModel", 800 | "model_module_version": "1.5.0", 801 | "state": { 802 | "_view_name": "ProgressView", 803 | "style": "IPY_MODEL_d7bef2816920414f99a5c9cc676ec254", 804 | "_dom_classes": [], 805 | "description": "", 806 | "_model_name": "FloatProgressModel", 807 | "bar_style": "success", 808 | "max": 213450, 809 | "_view_module": "@jupyter-widgets/controls", 810 | "_model_module_version": "1.5.0", 811 | "value": 213450, 812 | "_view_count": null, 813 | "_view_module_version": "1.5.0", 814 | "orientation": "horizontal", 815 | "min": 0, 816 | "description_tooltip": null, 817 | "_model_module": "@jupyter-widgets/controls", 818 | "layout": "IPY_MODEL_b465cf9004f549ffa85444bfa16ee4c2" 819 | } 820 | }, 821 | "5a478b81997c4263a60d938447232fd9": { 822 | "model_module": "@jupyter-widgets/controls", 823 | "model_name": "HTMLModel", 824 | "model_module_version": "1.5.0", 825 | "state": { 826 | "_view_name": "HTMLView", 827 | "style": "IPY_MODEL_748d5c1fb00e4d91809003527319a9f6", 828 | "_dom_classes": [], 829 | "description": "", 830 | "_model_name": "HTMLModel", 831 | "placeholder": "​", 832 | "_view_module": "@jupyter-widgets/controls", 833 | "_model_module_version": "1.5.0", 834 | "value": " 213k/213k [00:00<00:00, 102kB/s]", 835 | "_view_count": null, 836 | "_view_module_version": "1.5.0", 837 | "description_tooltip": null, 838 | 
"_model_module": "@jupyter-widgets/controls", 839 | "layout": "IPY_MODEL_7d501eb3d36e48128a5232879141b281" 840 | } 841 | }, 842 | "5fbebd67d40e470e8a75a4b5b540bbdf": { 843 | "model_module": "@jupyter-widgets/controls", 844 | "model_name": "DescriptionStyleModel", 845 | "model_module_version": "1.5.0", 846 | "state": { 847 | "_view_name": "StyleView", 848 | "_model_name": "DescriptionStyleModel", 849 | "description_width": "", 850 | "_view_module": "@jupyter-widgets/base", 851 | "_model_module_version": "1.5.0", 852 | "_view_count": null, 853 | "_view_module_version": "1.2.0", 854 | "_model_module": "@jupyter-widgets/controls" 855 | } 856 | }, 857 | "d8b619dff35e4f8cabf586221ad8c962": { 858 | "model_module": "@jupyter-widgets/base", 859 | "model_name": "LayoutModel", 860 | "model_module_version": "1.2.0", 861 | "state": { 862 | "_view_name": "LayoutView", 863 | "grid_template_rows": null, 864 | "right": null, 865 | "justify_content": null, 866 | "_view_module": "@jupyter-widgets/base", 867 | "overflow": null, 868 | "_model_module_version": "1.2.0", 869 | "_view_count": null, 870 | "flex_flow": null, 871 | "width": null, 872 | "min_width": null, 873 | "border": null, 874 | "align_items": null, 875 | "bottom": null, 876 | "_model_module": "@jupyter-widgets/base", 877 | "top": null, 878 | "grid_column": null, 879 | "overflow_y": null, 880 | "overflow_x": null, 881 | "grid_auto_flow": null, 882 | "grid_area": null, 883 | "grid_template_columns": null, 884 | "flex": null, 885 | "_model_name": "LayoutModel", 886 | "justify_items": null, 887 | "grid_row": null, 888 | "max_height": null, 889 | "align_content": null, 890 | "visibility": null, 891 | "align_self": null, 892 | "height": null, 893 | "min_height": null, 894 | "padding": null, 895 | "grid_auto_rows": null, 896 | "grid_gap": null, 897 | "max_width": null, 898 | "order": null, 899 | "_view_module_version": "1.2.0", 900 | "grid_template_areas": null, 901 | "object_position": null, 902 | "object_fit": null, 903 | 
"grid_auto_columns": null, 904 | "margin": null, 905 | "display": null, 906 | "left": null 907 | } 908 | }, 909 | "d7bef2816920414f99a5c9cc676ec254": { 910 | "model_module": "@jupyter-widgets/controls", 911 | "model_name": "ProgressStyleModel", 912 | "model_module_version": "1.5.0", 913 | "state": { 914 | "_view_name": "StyleView", 915 | "_model_name": "ProgressStyleModel", 916 | "description_width": "", 917 | "_view_module": "@jupyter-widgets/base", 918 | "_model_module_version": "1.5.0", 919 | "_view_count": null, 920 | "_view_module_version": "1.2.0", 921 | "bar_color": null, 922 | "_model_module": "@jupyter-widgets/controls" 923 | } 924 | }, 925 | "b465cf9004f549ffa85444bfa16ee4c2": { 926 | "model_module": "@jupyter-widgets/base", 927 | "model_name": "LayoutModel", 928 | "model_module_version": "1.2.0", 929 | "state": { 930 | "_view_name": "LayoutView", 931 | "grid_template_rows": null, 932 | "right": null, 933 | "justify_content": null, 934 | "_view_module": "@jupyter-widgets/base", 935 | "overflow": null, 936 | "_model_module_version": "1.2.0", 937 | "_view_count": null, 938 | "flex_flow": null, 939 | "width": null, 940 | "min_width": null, 941 | "border": null, 942 | "align_items": null, 943 | "bottom": null, 944 | "_model_module": "@jupyter-widgets/base", 945 | "top": null, 946 | "grid_column": null, 947 | "overflow_y": null, 948 | "overflow_x": null, 949 | "grid_auto_flow": null, 950 | "grid_area": null, 951 | "grid_template_columns": null, 952 | "flex": null, 953 | "_model_name": "LayoutModel", 954 | "justify_items": null, 955 | "grid_row": null, 956 | "max_height": null, 957 | "align_content": null, 958 | "visibility": null, 959 | "align_self": null, 960 | "height": null, 961 | "min_height": null, 962 | "padding": null, 963 | "grid_auto_rows": null, 964 | "grid_gap": null, 965 | "max_width": null, 966 | "order": null, 967 | "_view_module_version": "1.2.0", 968 | "grid_template_areas": null, 969 | "object_position": null, 970 | "object_fit": null, 971 | 
"grid_auto_columns": null, 972 | "margin": null, 973 | "display": null, 974 | "left": null 975 | } 976 | }, 977 | "748d5c1fb00e4d91809003527319a9f6": { 978 | "model_module": "@jupyter-widgets/controls", 979 | "model_name": "DescriptionStyleModel", 980 | "model_module_version": "1.5.0", 981 | "state": { 982 | "_view_name": "StyleView", 983 | "_model_name": "DescriptionStyleModel", 984 | "description_width": "", 985 | "_view_module": "@jupyter-widgets/base", 986 | "_model_module_version": "1.5.0", 987 | "_view_count": null, 988 | "_view_module_version": "1.2.0", 989 | "_model_module": "@jupyter-widgets/controls" 990 | } 991 | }, 992 | "7d501eb3d36e48128a5232879141b281": { 993 | "model_module": "@jupyter-widgets/base", 994 | "model_name": "LayoutModel", 995 | "model_module_version": "1.2.0", 996 | "state": { 997 | "_view_name": "LayoutView", 998 | "grid_template_rows": null, 999 | "right": null, 1000 | "justify_content": null, 1001 | "_view_module": "@jupyter-widgets/base", 1002 | "overflow": null, 1003 | "_model_module_version": "1.2.0", 1004 | "_view_count": null, 1005 | "flex_flow": null, 1006 | "width": null, 1007 | "min_width": null, 1008 | "border": null, 1009 | "align_items": null, 1010 | "bottom": null, 1011 | "_model_module": "@jupyter-widgets/base", 1012 | "top": null, 1013 | "grid_column": null, 1014 | "overflow_y": null, 1015 | "overflow_x": null, 1016 | "grid_auto_flow": null, 1017 | "grid_area": null, 1018 | "grid_template_columns": null, 1019 | "flex": null, 1020 | "_model_name": "LayoutModel", 1021 | "justify_items": null, 1022 | "grid_row": null, 1023 | "max_height": null, 1024 | "align_content": null, 1025 | "visibility": null, 1026 | "align_self": null, 1027 | "height": null, 1028 | "min_height": null, 1029 | "padding": null, 1030 | "grid_auto_rows": null, 1031 | "grid_gap": null, 1032 | "max_width": null, 1033 | "order": null, 1034 | "_view_module_version": "1.2.0", 1035 | "grid_template_areas": null, 1036 | "object_position": null, 1037 | 
"object_fit": null, 1038 | "grid_auto_columns": null, 1039 | "margin": null, 1040 | "display": null, 1041 | "left": null 1042 | } 1043 | }, 1044 | "547d7329b8554b7bb8ae51d61a5e41f8": { 1045 | "model_module": "@jupyter-widgets/controls", 1046 | "model_name": "HBoxModel", 1047 | "model_module_version": "1.5.0", 1048 | "state": { 1049 | "_view_name": "HBoxView", 1050 | "_dom_classes": [], 1051 | "_model_name": "HBoxModel", 1052 | "_view_module": "@jupyter-widgets/controls", 1053 | "_model_module_version": "1.5.0", 1054 | "_view_count": null, 1055 | "_view_module_version": "1.5.0", 1056 | "box_style": "", 1057 | "layout": "IPY_MODEL_81b6bafd3fc248afa76cf463f2cb8ab8", 1058 | "_model_module": "@jupyter-widgets/controls", 1059 | "children": [ 1060 | "IPY_MODEL_bac8f28c84144450b24bbc97fa4860db", 1061 | "IPY_MODEL_30e0829529874db1802f411f72f3d76a", 1062 | "IPY_MODEL_83dd72fbb4204fefb951b3800d73a8d2" 1063 | ] 1064 | } 1065 | }, 1066 | "81b6bafd3fc248afa76cf463f2cb8ab8": { 1067 | "model_module": "@jupyter-widgets/base", 1068 | "model_name": "LayoutModel", 1069 | "model_module_version": "1.2.0", 1070 | "state": { 1071 | "_view_name": "LayoutView", 1072 | "grid_template_rows": null, 1073 | "right": null, 1074 | "justify_content": null, 1075 | "_view_module": "@jupyter-widgets/base", 1076 | "overflow": null, 1077 | "_model_module_version": "1.2.0", 1078 | "_view_count": null, 1079 | "flex_flow": null, 1080 | "width": null, 1081 | "min_width": null, 1082 | "border": null, 1083 | "align_items": null, 1084 | "bottom": null, 1085 | "_model_module": "@jupyter-widgets/base", 1086 | "top": null, 1087 | "grid_column": null, 1088 | "overflow_y": null, 1089 | "overflow_x": null, 1090 | "grid_auto_flow": null, 1091 | "grid_area": null, 1092 | "grid_template_columns": null, 1093 | "flex": null, 1094 | "_model_name": "LayoutModel", 1095 | "justify_items": null, 1096 | "grid_row": null, 1097 | "max_height": null, 1098 | "align_content": null, 1099 | "visibility": null, 1100 | "align_self": 
null, 1101 | "height": null, 1102 | "min_height": null, 1103 | "padding": null, 1104 | "grid_auto_rows": null, 1105 | "grid_gap": null, 1106 | "max_width": null, 1107 | "order": null, 1108 | "_view_module_version": "1.2.0", 1109 | "grid_template_areas": null, 1110 | "object_position": null, 1111 | "object_fit": null, 1112 | "grid_auto_columns": null, 1113 | "margin": null, 1114 | "display": null, 1115 | "left": null 1116 | } 1117 | }, 1118 | "bac8f28c84144450b24bbc97fa4860db": { 1119 | "model_module": "@jupyter-widgets/controls", 1120 | "model_name": "HTMLModel", 1121 | "model_module_version": "1.5.0", 1122 | "state": { 1123 | "_view_name": "HTMLView", 1124 | "style": "IPY_MODEL_f9390d4fc9b147729f1ef5ea85d03774", 1125 | "_dom_classes": [], 1126 | "description": "", 1127 | "_model_name": "HTMLModel", 1128 | "placeholder": "​", 1129 | "_view_module": "@jupyter-widgets/controls", 1130 | "_model_module_version": "1.5.0", 1131 | "value": "Downloading: 100%", 1132 | "_view_count": null, 1133 | "_view_module_version": "1.5.0", 1134 | "description_tooltip": null, 1135 | "_model_module": "@jupyter-widgets/controls", 1136 | "layout": "IPY_MODEL_ddf20490dc874352aa7df35f07a60bcc" 1137 | } 1138 | }, 1139 | "30e0829529874db1802f411f72f3d76a": { 1140 | "model_module": "@jupyter-widgets/controls", 1141 | "model_name": "FloatProgressModel", 1142 | "model_module_version": "1.5.0", 1143 | "state": { 1144 | "_view_name": "ProgressView", 1145 | "style": "IPY_MODEL_c0f0d9a0d4f4440e81e5a6d5a2a3b4ab", 1146 | "_dom_classes": [], 1147 | "description": "", 1148 | "_model_name": "FloatProgressModel", 1149 | "bar_style": "success", 1150 | "max": 435797, 1151 | "_view_module": "@jupyter-widgets/controls", 1152 | "_model_module_version": "1.5.0", 1153 | "value": 435797, 1154 | "_view_count": null, 1155 | "_view_module_version": "1.5.0", 1156 | "orientation": "horizontal", 1157 | "min": 0, 1158 | "description_tooltip": null, 1159 | "_model_module": "@jupyter-widgets/controls", 1160 | "layout": 
"IPY_MODEL_f41b83ccf26c40ed88ca6eaf49e09de5" 1161 | } 1162 | }, 1163 | "83dd72fbb4204fefb951b3800d73a8d2": { 1164 | "model_module": "@jupyter-widgets/controls", 1165 | "model_name": "HTMLModel", 1166 | "model_module_version": "1.5.0", 1167 | "state": { 1168 | "_view_name": "HTMLView", 1169 | "style": "IPY_MODEL_be6339f6256d4b579c53ef8b430ecb25", 1170 | "_dom_classes": [], 1171 | "description": "", 1172 | "_model_name": "HTMLModel", 1173 | "placeholder": "​", 1174 | "_view_module": "@jupyter-widgets/controls", 1175 | "_model_module_version": "1.5.0", 1176 | "value": " 436k/436k [00:00<00:00, 1.26MB/s]", 1177 | "_view_count": null, 1178 | "_view_module_version": "1.5.0", 1179 | "description_tooltip": null, 1180 | "_model_module": "@jupyter-widgets/controls", 1181 | "layout": "IPY_MODEL_b79fef6d585843c0a4d885edb2d01ebd" 1182 | } 1183 | }, 1184 | "f9390d4fc9b147729f1ef5ea85d03774": { 1185 | "model_module": "@jupyter-widgets/controls", 1186 | "model_name": "DescriptionStyleModel", 1187 | "model_module_version": "1.5.0", 1188 | "state": { 1189 | "_view_name": "StyleView", 1190 | "_model_name": "DescriptionStyleModel", 1191 | "description_width": "", 1192 | "_view_module": "@jupyter-widgets/base", 1193 | "_model_module_version": "1.5.0", 1194 | "_view_count": null, 1195 | "_view_module_version": "1.2.0", 1196 | "_model_module": "@jupyter-widgets/controls" 1197 | } 1198 | }, 1199 | "ddf20490dc874352aa7df35f07a60bcc": { 1200 | "model_module": "@jupyter-widgets/base", 1201 | "model_name": "LayoutModel", 1202 | "model_module_version": "1.2.0", 1203 | "state": { 1204 | "_view_name": "LayoutView", 1205 | "grid_template_rows": null, 1206 | "right": null, 1207 | "justify_content": null, 1208 | "_view_module": "@jupyter-widgets/base", 1209 | "overflow": null, 1210 | "_model_module_version": "1.2.0", 1211 | "_view_count": null, 1212 | "flex_flow": null, 1213 | "width": null, 1214 | "min_width": null, 1215 | "border": null, 1216 | "align_items": null, 1217 | "bottom": null, 1218 | 
"_model_module": "@jupyter-widgets/base", 1219 | "top": null, 1220 | "grid_column": null, 1221 | "overflow_y": null, 1222 | "overflow_x": null, 1223 | "grid_auto_flow": null, 1224 | "grid_area": null, 1225 | "grid_template_columns": null, 1226 | "flex": null, 1227 | "_model_name": "LayoutModel", 1228 | "justify_items": null, 1229 | "grid_row": null, 1230 | "max_height": null, 1231 | "align_content": null, 1232 | "visibility": null, 1233 | "align_self": null, 1234 | "height": null, 1235 | "min_height": null, 1236 | "padding": null, 1237 | "grid_auto_rows": null, 1238 | "grid_gap": null, 1239 | "max_width": null, 1240 | "order": null, 1241 | "_view_module_version": "1.2.0", 1242 | "grid_template_areas": null, 1243 | "object_position": null, 1244 | "object_fit": null, 1245 | "grid_auto_columns": null, 1246 | "margin": null, 1247 | "display": null, 1248 | "left": null 1249 | } 1250 | }, 1251 | "c0f0d9a0d4f4440e81e5a6d5a2a3b4ab": { 1252 | "model_module": "@jupyter-widgets/controls", 1253 | "model_name": "ProgressStyleModel", 1254 | "model_module_version": "1.5.0", 1255 | "state": { 1256 | "_view_name": "StyleView", 1257 | "_model_name": "ProgressStyleModel", 1258 | "description_width": "", 1259 | "_view_module": "@jupyter-widgets/base", 1260 | "_model_module_version": "1.5.0", 1261 | "_view_count": null, 1262 | "_view_module_version": "1.2.0", 1263 | "bar_color": null, 1264 | "_model_module": "@jupyter-widgets/controls" 1265 | } 1266 | }, 1267 | "f41b83ccf26c40ed88ca6eaf49e09de5": { 1268 | "model_module": "@jupyter-widgets/base", 1269 | "model_name": "LayoutModel", 1270 | "model_module_version": "1.2.0", 1271 | "state": { 1272 | "_view_name": "LayoutView", 1273 | "grid_template_rows": null, 1274 | "right": null, 1275 | "justify_content": null, 1276 | "_view_module": "@jupyter-widgets/base", 1277 | "overflow": null, 1278 | "_model_module_version": "1.2.0", 1279 | "_view_count": null, 1280 | "flex_flow": null, 1281 | "width": null, 1282 | "min_width": null, 1283 | 
"border": null, 1284 | "align_items": null, 1285 | "bottom": null, 1286 | "_model_module": "@jupyter-widgets/base", 1287 | "top": null, 1288 | "grid_column": null, 1289 | "overflow_y": null, 1290 | "overflow_x": null, 1291 | "grid_auto_flow": null, 1292 | "grid_area": null, 1293 | "grid_template_columns": null, 1294 | "flex": null, 1295 | "_model_name": "LayoutModel", 1296 | "justify_items": null, 1297 | "grid_row": null, 1298 | "max_height": null, 1299 | "align_content": null, 1300 | "visibility": null, 1301 | "align_self": null, 1302 | "height": null, 1303 | "min_height": null, 1304 | "padding": null, 1305 | "grid_auto_rows": null, 1306 | "grid_gap": null, 1307 | "max_width": null, 1308 | "order": null, 1309 | "_view_module_version": "1.2.0", 1310 | "grid_template_areas": null, 1311 | "object_position": null, 1312 | "object_fit": null, 1313 | "grid_auto_columns": null, 1314 | "margin": null, 1315 | "display": null, 1316 | "left": null 1317 | } 1318 | }, 1319 | "be6339f6256d4b579c53ef8b430ecb25": { 1320 | "model_module": "@jupyter-widgets/controls", 1321 | "model_name": "DescriptionStyleModel", 1322 | "model_module_version": "1.5.0", 1323 | "state": { 1324 | "_view_name": "StyleView", 1325 | "_model_name": "DescriptionStyleModel", 1326 | "description_width": "", 1327 | "_view_module": "@jupyter-widgets/base", 1328 | "_model_module_version": "1.5.0", 1329 | "_view_count": null, 1330 | "_view_module_version": "1.2.0", 1331 | "_model_module": "@jupyter-widgets/controls" 1332 | } 1333 | }, 1334 | "b79fef6d585843c0a4d885edb2d01ebd": { 1335 | "model_module": "@jupyter-widgets/base", 1336 | "model_name": "LayoutModel", 1337 | "model_module_version": "1.2.0", 1338 | "state": { 1339 | "_view_name": "LayoutView", 1340 | "grid_template_rows": null, 1341 | "right": null, 1342 | "justify_content": null, 1343 | "_view_module": "@jupyter-widgets/base", 1344 | "overflow": null, 1345 | "_model_module_version": "1.2.0", 1346 | "_view_count": null, 1347 | "flex_flow": null, 1348 | 
"width": null, 1349 | "min_width": null, 1350 | "border": null, 1351 | "align_items": null, 1352 | "bottom": null, 1353 | "_model_module": "@jupyter-widgets/base", 1354 | "top": null, 1355 | "grid_column": null, 1356 | "overflow_y": null, 1357 | "overflow_x": null, 1358 | "grid_auto_flow": null, 1359 | "grid_area": null, 1360 | "grid_template_columns": null, 1361 | "flex": null, 1362 | "_model_name": "LayoutModel", 1363 | "justify_items": null, 1364 | "grid_row": null, 1365 | "max_height": null, 1366 | "align_content": null, 1367 | "visibility": null, 1368 | "align_self": null, 1369 | "height": null, 1370 | "min_height": null, 1371 | "padding": null, 1372 | "grid_auto_rows": null, 1373 | "grid_gap": null, 1374 | "max_width": null, 1375 | "order": null, 1376 | "_view_module_version": "1.2.0", 1377 | "grid_template_areas": null, 1378 | "object_position": null, 1379 | "object_fit": null, 1380 | "grid_auto_columns": null, 1381 | "margin": null, 1382 | "display": null, 1383 | "left": null 1384 | } 1385 | } 1386 | } 1387 | } 1388 | }, 1389 | "cells": [ 1390 | { 1391 | "cell_type": "markdown", 1392 | "metadata": { 1393 | "id": "view-in-github", 1394 | "colab_type": "text" 1395 | }, 1396 | "source": [ 1397 | "\"Open" 1398 | ] 1399 | }, 1400 | { 1401 | "cell_type": "markdown", 1402 | "metadata": { 1403 | "id": "fUpqAwtN8rTA" 1404 | }, 1405 | "source": [ 1406 | "# GAN-BERT (in Pytorch and compatible with HuggingFace)\n", 1407 | "\n", 1408 | "This is a Pytorch (+ **Huggingface** transformers) implementation of the GAN-BERT model from https://github.com/crux82/ganbert. 
While the original GAN-BERT was an extension of BERT, this implementation can be adapted to several architectures, ranging from RoBERTa to ALBERT!\n", 1409 | "\n", 1410 | "**NOTE**: given that this implementation is different from the original one in TensorFlow, some results can be slightly different.\n" 1411 | ] 1412 | }, 1413 | { 1414 | "cell_type": "markdown", 1415 | "metadata": { 1416 | "id": "Q0m5KR34gmRH" 1417 | }, 1418 | "source": [ 1419 | "Let's GO!\n", 1420 | "\n", 1421 | "Required Imports." 1422 | ] 1423 | }, 1424 | { 1425 | "cell_type": "code", 1426 | "metadata": { 1427 | "colab": { 1428 | "base_uri": "https://localhost:8080/" 1429 | }, 1430 | "id": "UIqpm34x2rms", 1431 | "outputId": "b0205d19-dff1-4967-d003-990c3c5c8164" 1432 | }, 1433 | "source": [ 1434 | "!pip install transformers==4.3.2\n", 1435 | "import torch\n", 1436 | "import io\n", 1437 | "import torch.nn.functional as F\n", 1438 | "import random\n", 1439 | "import numpy as np\n", 1440 | "import time\n", 1441 | "import math\n", 1442 | "import datetime\n", 1443 | "import torch.nn as nn\n", 1444 | "from transformers import *\n", 1445 | "from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler\n", 1446 | "#!pip install torch==1.7.1+cu101 torchvision==0.8.2+cu101 -f https://download.pytorch.org/whl/torch_stable.html\n", 1447 | "#!pip install sentencepiece\n", 1448 | "\n", 1449 | "##Set random values\n", 1450 | "seed_val = 42\n", 1451 | "random.seed(seed_val)\n", 1452 | "np.random.seed(seed_val)\n", 1453 | "torch.manual_seed(seed_val)\n", 1454 | "if torch.cuda.is_available():\n", 1455 | " torch.cuda.manual_seed_all(seed_val)" 1456 | ], 1457 | "execution_count": 1, 1458 | "outputs": [ 1459 | { 1460 | "output_type": "stream", 1461 | "name": "stdout", 1462 | "text": [ 1463 | "Collecting transformers==4.3.2\n", 1464 | " Downloading transformers-4.3.2-py3-none-any.whl (1.8 MB)\n", 1465 | "\u001b[K |████████████████████████████████| 1.8 MB 5.2 MB/s \n", 1466 | 
"\u001b[?25hRequirement already satisfied: tqdm>=4.27 in /usr/local/lib/python3.7/dist-packages (from transformers==4.3.2) (4.62.3)\n", 1467 | "Requirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.7/dist-packages (from transformers==4.3.2) (1.19.5)\n", 1468 | "Requirement already satisfied: importlib-metadata in /usr/local/lib/python3.7/dist-packages (from transformers==4.3.2) (4.8.2)\n", 1469 | "Requirement already satisfied: requests in /usr/local/lib/python3.7/dist-packages (from transformers==4.3.2) (2.23.0)\n", 1470 | "Requirement already satisfied: filelock in /usr/local/lib/python3.7/dist-packages (from transformers==4.3.2) (3.4.0)\n", 1471 | "Collecting tokenizers<0.11,>=0.10.1\n", 1472 | " Downloading tokenizers-0.10.3-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl (3.3 MB)\n", 1473 | "\u001b[K |████████████████████████████████| 3.3 MB 33.1 MB/s \n", 1474 | "\u001b[?25hRequirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.7/dist-packages (from transformers==4.3.2) (2019.12.20)\n", 1475 | "Collecting sacremoses\n", 1476 | " Downloading sacremoses-0.0.46-py3-none-any.whl (895 kB)\n", 1477 | "\u001b[K |████████████████████████████████| 895 kB 36.6 MB/s \n", 1478 | "\u001b[?25hRequirement already satisfied: packaging in /usr/local/lib/python3.7/dist-packages (from transformers==4.3.2) (21.3)\n", 1479 | "Requirement already satisfied: zipp>=0.5 in /usr/local/lib/python3.7/dist-packages (from importlib-metadata->transformers==4.3.2) (3.6.0)\n", 1480 | "Requirement already satisfied: typing-extensions>=3.6.4 in /usr/local/lib/python3.7/dist-packages (from importlib-metadata->transformers==4.3.2) (3.10.0.2)\n", 1481 | "Requirement already satisfied: pyparsing!=3.0.5,>=2.0.2 in /usr/local/lib/python3.7/dist-packages (from packaging->transformers==4.3.2) (3.0.6)\n", 1482 | "Requirement already satisfied: idna<3,>=2.5 in /usr/local/lib/python3.7/dist-packages (from 
requests->transformers==4.3.2) (2.10)\n", 1483 | "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.7/dist-packages (from requests->transformers==4.3.2) (2021.10.8)\n", 1484 | "Requirement already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /usr/local/lib/python3.7/dist-packages (from requests->transformers==4.3.2) (1.24.3)\n", 1485 | "Requirement already satisfied: chardet<4,>=3.0.2 in /usr/local/lib/python3.7/dist-packages (from requests->transformers==4.3.2) (3.0.4)\n", 1486 | "Requirement already satisfied: six in /usr/local/lib/python3.7/dist-packages (from sacremoses->transformers==4.3.2) (1.15.0)\n", 1487 | "Requirement already satisfied: joblib in /usr/local/lib/python3.7/dist-packages (from sacremoses->transformers==4.3.2) (1.1.0)\n", 1488 | "Requirement already satisfied: click in /usr/local/lib/python3.7/dist-packages (from sacremoses->transformers==4.3.2) (7.1.2)\n", 1489 | "Installing collected packages: tokenizers, sacremoses, transformers\n", 1490 | "Successfully installed sacremoses-0.0.46 tokenizers-0.10.3 transformers-4.3.2\n" 1491 | ] 1492 | } 1493 | ] 1494 | }, 1495 | { 1496 | "cell_type": "code", 1497 | "metadata": { 1498 | "id": "LeZgRup520II", 1499 | "colab": { 1500 | "base_uri": "https://localhost:8080/" 1501 | }, 1502 | "outputId": "5b8d1039-e1e6-4712-e77b-e1e9f1b9fbcb" 1503 | }, 1504 | "source": [ 1505 | "# If there's a GPU available...\n", 1506 | "if torch.cuda.is_available(): \n", 1507 | " # Tell PyTorch to use the GPU. \n", 1508 | " device = torch.device(\"cuda\")\n", 1509 | " print('There are %d GPU(s) available.' 
% torch.cuda.device_count())\n", 1510 | " print('We will use the GPU:', torch.cuda.get_device_name(0))\n", 1511 | "# If not...\n", 1512 | "else:\n", 1513 | " print('No GPU available, using the CPU instead.')\n", 1514 | " device = torch.device(\"cpu\")" 1515 | ], 1516 | "execution_count": 2, 1517 | "outputs": [ 1518 | { 1519 | "output_type": "stream", 1520 | "name": "stdout", 1521 | "text": [ 1522 | "There are 1 GPU(s) available.\n", 1523 | "We will use the GPU: Tesla K80\n" 1524 | ] 1525 | } 1526 | ] 1527 | }, 1528 | { 1529 | "cell_type": "markdown", 1530 | "metadata": { 1531 | "id": "AU3ns8Ic7I-h" 1532 | }, 1533 | "source": [ 1534 | "### Input Parameters\n" 1535 | ] 1536 | }, 1537 | { 1538 | "cell_type": "code", 1539 | "metadata": { 1540 | "id": "jw0HC_hU3FUy", 1541 | "colab": { 1542 | "base_uri": "https://localhost:8080/" 1543 | }, 1544 | "outputId": "6ae87fcf-ed0b-4c78-b9aa-7d86d80cb933" 1545 | }, 1546 | "source": [ 1547 | "#--------------------------------\n", 1548 | "# Transformer parameters\n", 1549 | "#--------------------------------\n", 1550 | "max_seq_length = 64\n", 1551 | "batch_size = 64\n", 1552 | "\n", 1553 | "#--------------------------------\n", 1554 | "# GAN-BERT specific parameters\n", 1555 | "#--------------------------------\n", 1556 | "# number of hidden layers in the generator, \n", 1557 | "# each of the size of the output space\n", 1558 | "num_hidden_layers_g = 1; \n", 1559 | "# number of hidden layers in the discriminator, \n", 1560 | "# each of the size of the input space\n", 1561 | "num_hidden_layers_d = 1; \n", 1562 | "# size of the generator's input noisy vectors\n", 1563 | "noise_size = 100\n", 1564 | "# dropout to be applied to discriminator's input vectors\n", 1565 | "out_dropout_rate = 0.2\n", 1566 | "\n", 1567 | "# Replicate labeled data to balance poorly represented datasets, \n", 1568 | "# e.g., less than 1% of labeled material\n", 1569 | "apply_balance = True\n", 1570 | "\n", 1571 | "#--------------------------------\n", 1572 | 
"# Optimization parameters\n", 1573 | "#--------------------------------\n", 1574 | "learning_rate_discriminator = 5e-5\n", 1575 | "learning_rate_generator = 5e-5\n", 1576 | "epsilon = 1e-8\n", 1577 | "num_train_epochs = 10\n", 1578 | "multi_gpu = True\n", 1579 | "# Scheduler\n", 1580 | "apply_scheduler = False\n", 1581 | "warmup_proportion = 0.1\n", 1582 | "# Print\n", 1583 | "print_each_n_step = 10\n", 1584 | "\n", 1585 | "#--------------------------------\n", 1586 | "# Adopted Tranformer model\n", 1587 | "#--------------------------------\n", 1588 | "# Since this version is compatible with Huggingface transformers, you can uncomment\n", 1589 | "# (or add) transformer models compatible with GAN\n", 1590 | "\n", 1591 | "model_name = \"bert-base-cased\"\n", 1592 | "#model_name = \"bert-base-uncased\"\n", 1593 | "#model_name = \"roberta-base\"\n", 1594 | "#model_name = \"albert-base-v2\"\n", 1595 | "#model_name = \"xlm-roberta-base\"\n", 1596 | "#model_name = \"amazon/bort\"\n", 1597 | "\n", 1598 | "#--------------------------------\n", 1599 | "# Retrieve the TREC QC Dataset\n", 1600 | "#--------------------------------\n", 1601 | "! 
git clone https://github.com/crux82/ganbert\n", 1602 | "\n", 1603 | "# NOTE: in this setting 50 classes are involved\n", 1604 | "labeled_file = \"./ganbert/data/labeled.tsv\"\n", 1605 | "unlabeled_file = \"./ganbert/data/unlabeled.tsv\"\n", 1606 | "test_filename = \"./ganbert/data/test.tsv\"\n", 1607 | "\n", 1608 | "label_list = [\"UNK_UNK\",\"ABBR_abb\", \"ABBR_exp\", \"DESC_def\", \"DESC_desc\", \n", 1609 | " \"DESC_manner\", \"DESC_reason\", \"ENTY_animal\", \"ENTY_body\", \n", 1610 | " \"ENTY_color\", \"ENTY_cremat\", \"ENTY_currency\", \"ENTY_dismed\", \n", 1611 | " \"ENTY_event\", \"ENTY_food\", \"ENTY_instru\", \"ENTY_lang\", \n", 1612 | " \"ENTY_letter\", \"ENTY_other\", \"ENTY_plant\", \"ENTY_product\", \n", 1613 | " \"ENTY_religion\", \"ENTY_sport\", \"ENTY_substance\", \"ENTY_symbol\", \n", 1614 | " \"ENTY_techmeth\", \"ENTY_termeq\", \"ENTY_veh\", \"ENTY_word\", \"HUM_desc\", \n", 1615 | " \"HUM_gr\", \"HUM_ind\", \"HUM_title\", \"LOC_city\", \"LOC_country\", \n", 1616 | " \"LOC_mount\", \"LOC_other\", \"LOC_state\", \"NUM_code\", \"NUM_count\", \n", 1617 | " \"NUM_date\", \"NUM_dist\", \"NUM_money\", \"NUM_ord\", \"NUM_other\", \n", 1618 | " \"NUM_perc\", \"NUM_period\", \"NUM_speed\", \"NUM_temp\", \"NUM_volsize\", \n", 1619 | " \"NUM_weight\"]" 1620 | ], 1621 | "execution_count": 3, 1622 | "outputs": [ 1623 | { 1624 | "output_type": "stream", 1625 | "name": "stdout", 1626 | "text": [ 1627 | "Cloning into 'ganbert'...\n", 1628 | "remote: Enumerating objects: 77, done.\u001b[K\n", 1629 | "remote: Counting objects: 100% (77/77), done.\u001b[K\n", 1630 | "remote: Compressing objects: 100% (59/59), done.\u001b[K\n", 1631 | "remote: Total 77 (delta 33), reused 54 (delta 18), pack-reused 0\u001b[K\n", 1632 | "Unpacking objects: 100% (77/77), done.\n" 1633 | ] 1634 | } 1635 | ] 1636 | }, 1637 | { 1638 | "cell_type": "markdown", 1639 | "metadata": { 1640 | "id": "R6Q5jzVioTHb" 1641 | }, 1642 | "source": [ 1643 | "Load the Transformer Model" 1644 | ] 1645 | },
1646 | { 1647 | "cell_type": "code", 1648 | "metadata": { 1649 | "id": "gxghkkZq3Gbn", 1650 | "outputId": "a4a5afd0-1b6c-4c2d-e3eb-7ced49df4e33", 1651 | "colab": { 1652 | "base_uri": "https://localhost:8080/", 1653 | "height": 145, 1654 | "referenced_widgets": [ 1655 | "22cde8fbae4e49af993e33f3f2d9a28e", 1656 | "2c671428c4df4355a7a53c71e9bd14ee", 1657 | "44660941ffcc44beae72a1974d458583", 1658 | "afdbd837001c4c74b5ac332ca061ef3a", 1659 | "b489404bf53b4fab9bdb1c0c79a33008", 1660 | "2bfb43b9e8604cbf8ebfd162267f1b8a", 1661 | "4f556917850542da8508f9f839cae9bc", 1662 | "28f045edfa48462fa96a94cffd3b143f", 1663 | "14a6abbb244b41f89626a37640c63118", 1664 | "0dccd0ab28c34880be92b35aa05e6ffb", 1665 | "bf127a02cf6647aba4035c5cbfadc378", 1666 | "098f203f1209452a9d4192af92da7057", 1667 | "ec76cfd2d1da498e9b66885ff1be46b3", 1668 | "b94243bfea7e4441b809b38c7cecd875", 1669 | "0097aa33393342cd99be3dcc30edc5a0", 1670 | "293c4d8660f546f79b43e0dc63250c2f", 1671 | "9fc9479d78e442db91360de7f45b6d7f", 1672 | "14bf2ab82bfc45bd8a3b93f2ab9aa656", 1673 | "a56008dc6c6b41c3850611cc15fb6ea8", 1674 | "504e8c3a107e4e04b51df39e1b3c584e", 1675 | "1f25240f5457445795af70e20e5903f9", 1676 | "5110d1cb9a7547f49244113dc5dd8321", 1677 | "d3a13b7869354881ac7b7887e05d56a7", 1678 | "4d22913664fb4cdbb38e217d4197601d", 1679 | "870fe8f58bb24a47b6a98fa0eed0ebf5", 1680 | "193a5a054d6f4a319842820c4c308322", 1681 | "5a478b81997c4263a60d938447232fd9", 1682 | "5fbebd67d40e470e8a75a4b5b540bbdf", 1683 | "d8b619dff35e4f8cabf586221ad8c962", 1684 | "d7bef2816920414f99a5c9cc676ec254", 1685 | "b465cf9004f549ffa85444bfa16ee4c2", 1686 | "748d5c1fb00e4d91809003527319a9f6", 1687 | "7d501eb3d36e48128a5232879141b281", 1688 | "547d7329b8554b7bb8ae51d61a5e41f8", 1689 | "81b6bafd3fc248afa76cf463f2cb8ab8", 1690 | "bac8f28c84144450b24bbc97fa4860db", 1691 | "30e0829529874db1802f411f72f3d76a", 1692 | "83dd72fbb4204fefb951b3800d73a8d2", 1693 | "f9390d4fc9b147729f1ef5ea85d03774", 1694 | "ddf20490dc874352aa7df35f07a60bcc", 1695 | 
"c0f0d9a0d4f4440e81e5a6d5a2a3b4ab", 1696 | "f41b83ccf26c40ed88ca6eaf49e09de5", 1697 | "be6339f6256d4b579c53ef8b430ecb25", 1698 | "b79fef6d585843c0a4d885edb2d01ebd" 1699 | ] 1700 | } 1701 | }, 1702 | "source": [ 1703 | "transformer = AutoModel.from_pretrained(model_name)\n", 1704 | "tokenizer = AutoTokenizer.from_pretrained(model_name)" 1705 | ], 1706 | "execution_count": 4, 1707 | "outputs": [ 1708 | { 1709 | "output_type": "display_data", 1710 | "data": { 1711 | "application/vnd.jupyter.widget-view+json": { 1712 | "model_id": "22cde8fbae4e49af993e33f3f2d9a28e", 1713 | "version_minor": 0, 1714 | "version_major": 2 1715 | }, 1716 | "text/plain": [ 1717 | "Downloading: 0%| | 0.00/570 [00:00 0) for token_id in sent] \n", 1889 | " input_mask_array.append(att_mask)\n", 1890 | " # Convertion to Tensor\n", 1891 | " input_ids = torch.tensor(input_ids) \n", 1892 | " input_mask_array = torch.tensor(input_mask_array)\n", 1893 | " label_id_array = torch.tensor(label_id_array, dtype=torch.long)\n", 1894 | " label_mask_array = torch.tensor(label_mask_array)\n", 1895 | "\n", 1896 | " # Building the TensorDataset\n", 1897 | " dataset = TensorDataset(input_ids, input_mask_array, label_id_array, label_mask_array)\n", 1898 | "\n", 1899 | " if do_shuffle:\n", 1900 | " sampler = RandomSampler\n", 1901 | " else:\n", 1902 | " sampler = SequentialSampler\n", 1903 | "\n", 1904 | " # Building the DataLoader\n", 1905 | " return DataLoader(\n", 1906 | " dataset, # The training samples.\n", 1907 | " sampler = sampler(dataset), \n", 1908 | " batch_size = batch_size) # Trains with this batch size.\n", 1909 | "\n", 1910 | "def format_time(elapsed):\n", 1911 | " '''\n", 1912 | " Takes a time in seconds and returns a string hh:mm:ss\n", 1913 | " '''\n", 1914 | " # Round to the nearest second.\n", 1915 | " elapsed_rounded = int(round((elapsed)))\n", 1916 | " # Format as hh:mm:ss\n", 1917 | " return str(datetime.timedelta(seconds=elapsed_rounded))" 1918 | ], 1919 | "execution_count": 7, 1920 | 
"outputs": [] 1921 | }, 1922 | { 1923 | "cell_type": "markdown", 1924 | "metadata": { 1925 | "id": "Do3O-VeefT3g" 1926 | }, 1927 | "source": [ 1928 | "Convert the input examples into DataLoader" 1929 | ] 1930 | }, 1931 | { 1932 | "cell_type": "code", 1933 | "metadata": { 1934 | "id": "4c-nsMXlKX-D" 1935 | }, 1936 | "source": [ 1937 | "label_map = {}\n", 1938 | "for (i, label) in enumerate(label_list):\n", 1939 | " label_map[label] = i\n", 1940 | "#------------------------------\n", 1941 | "# Load the train dataset\n", 1942 | "#------------------------------\n", 1943 | "train_examples = labeled_examples\n", 1944 | "#The labeled (train) dataset is assigned with a mask set to True\n", 1945 | "train_label_masks = np.ones(len(labeled_examples), dtype=bool)\n", 1946 | "#If unlabel examples are available\n", 1947 | "if unlabeled_examples:\n", 1948 | " train_examples = train_examples + unlabeled_examples\n", 1949 | " #The unlabeled (train) dataset is assigned with a mask set to False\n", 1950 | " tmp_masks = np.zeros(len(unlabeled_examples), dtype=bool)\n", 1951 | " train_label_masks = np.concatenate([train_label_masks,tmp_masks])\n", 1952 | "\n", 1953 | "train_dataloader = generate_data_loader(train_examples, train_label_masks, label_map, do_shuffle = True, balance_label_examples = apply_balance)\n", 1954 | "\n", 1955 | "#------------------------------\n", 1956 | "# Load the test dataset\n", 1957 | "#------------------------------\n", 1958 | "#The labeled (test) dataset is assigned with a mask set to True\n", 1959 | "test_label_masks = np.ones(len(test_examples), dtype=bool)\n", 1960 | "\n", 1961 | "test_dataloader = generate_data_loader(test_examples, test_label_masks, label_map, do_shuffle = False, balance_label_examples = False)" 1962 | ], 1963 | "execution_count": 8, 1964 | "outputs": [] 1965 | }, 1966 | { 1967 | "cell_type": "markdown", 1968 | "metadata": { 1969 | "id": "6Ihcw3vquaQm" 1970 | }, 1971 | "source": [ 1972 | "We define the Generator and Discriminator as 
discussed in https://www.aclweb.org/anthology/2020.acl-main.191/" 1973 | ] 1974 | }, 1975 | { 1976 | "cell_type": "code", 1977 | "metadata": { 1978 | "id": "18kY64-n3I6y" 1979 | }, 1980 | "source": [ 1981 | "#------------------------------\n", 1982 | "# The Generator as in \n", 1983 | "# https://www.aclweb.org/anthology/2020.acl-main.191/\n", 1984 | "# https://github.com/crux82/ganbert\n", 1985 | "#------------------------------\n", 1986 | "class Generator(nn.Module):\n", 1987 | " def __init__(self, noise_size=100, output_size=512, hidden_sizes=[512], dropout_rate=0.1):\n", 1988 | " super(Generator, self).__init__()\n", 1989 | " layers = []\n", 1990 | " hidden_sizes = [noise_size] + hidden_sizes\n", 1991 | " for i in range(len(hidden_sizes)-1):\n", 1992 | " layers.extend([nn.Linear(hidden_sizes[i], hidden_sizes[i+1]), nn.LeakyReLU(0.2, inplace=True), nn.Dropout(dropout_rate)])\n", 1993 | "\n", 1994 | " layers.append(nn.Linear(hidden_sizes[-1],output_size))\n", 1995 | " self.layers = nn.Sequential(*layers)\n", 1996 | "\n", 1997 | " def forward(self, noise):\n", 1998 | " output_rep = self.layers(noise)\n", 1999 | " return output_rep\n", 2000 | "\n", 2001 | "#------------------------------\n", 2002 | "# The Discriminator\n", 2003 | "# https://www.aclweb.org/anthology/2020.acl-main.191/\n", 2004 | "# https://github.com/crux82/ganbert\n", 2005 | "#------------------------------\n", 2006 | "class Discriminator(nn.Module):\n", 2007 | " def __init__(self, input_size=512, hidden_sizes=[512], num_labels=2, dropout_rate=0.1):\n", 2008 | " super(Discriminator, self).__init__()\n", 2009 | " self.input_dropout = nn.Dropout(p=dropout_rate)\n", 2010 | " layers = []\n", 2011 | " hidden_sizes = [input_size] + hidden_sizes\n", 2012 | " for i in range(len(hidden_sizes)-1):\n", 2013 | " layers.extend([nn.Linear(hidden_sizes[i], hidden_sizes[i+1]), nn.LeakyReLU(0.2, inplace=True), nn.Dropout(dropout_rate)])\n", 2014 | "\n", 2015 | " self.layers = nn.Sequential(*layers) #per il 
flatten\n", 2016 | " self.logit = nn.Linear(hidden_sizes[-1],num_labels+1) # +1 for the probability of this sample being fake/real.\n", 2017 | " self.softmax = nn.Softmax(dim=-1)\n", 2018 | "\n", 2019 | " def forward(self, input_rep):\n", 2020 | " input_rep = self.input_dropout(input_rep)\n", 2021 | " last_rep = self.layers(input_rep)\n", 2022 | " logits = self.logit(last_rep)\n", 2023 | " probs = self.softmax(logits)\n", 2024 | " return last_rep, logits, probs" 2025 | ], 2026 | "execution_count": 9, 2027 | "outputs": [] 2028 | }, 2029 | { 2030 | "cell_type": "markdown", 2031 | "metadata": { 2032 | "id": "Uje9s2zQunFc" 2033 | }, 2034 | "source": [ 2035 | "We instantiate the Discriminator and Generator" 2036 | ] 2037 | }, 2038 | { 2039 | "cell_type": "code", 2040 | "metadata": { 2041 | "id": "Ylz5rvqE3U2S" 2042 | }, 2043 | "source": [ 2044 | "# The config file is required to get the dimension of the vector produced by \n", 2045 | "# the underlying transformer\n", 2046 | "config = AutoConfig.from_pretrained(model_name)\n", 2047 | "hidden_size = int(config.hidden_size)\n", 2048 | "# Define the number and width of hidden layers\n", 2049 | "hidden_levels_g = [hidden_size for i in range(0, num_hidden_layers_g)]\n", 2050 | "hidden_levels_d = [hidden_size for i in range(0, num_hidden_layers_d)]\n", 2051 | "\n", 2052 | "#-------------------------------------------------\n", 2053 | "# Instantiate the Generator and Discriminator\n", 2054 | "#-------------------------------------------------\n", 2055 | "generator = Generator(noise_size=noise_size, output_size=hidden_size, hidden_sizes=hidden_levels_g, dropout_rate=out_dropout_rate)\n", 2056 | "discriminator = Discriminator(input_size=hidden_size, hidden_sizes=hidden_levels_d, num_labels=len(label_list), dropout_rate=out_dropout_rate)\n", 2057 | "\n", 2058 | "# Put everything in the GPU if available\n", 2059 | "if torch.cuda.is_available(): \n", 2060 | " generator.cuda()\n", 2061 | " discriminator.cuda()\n", 2062 | " 
transformer.cuda()\n", 2063 | " if multi_gpu:\n", 2064 | " transformer = torch.nn.DataParallel(transformer)\n", 2065 | "\n", 2066 | "# print(config)" 2067 | ], 2068 | "execution_count": 10, 2069 | "outputs": [] 2070 | }, 2071 | { 2072 | "cell_type": "markdown", 2073 | "metadata": { 2074 | "id": "VG3qzp2-usZE" 2075 | }, 2076 | "source": [ 2077 | "Let's go with the training procedure" 2078 | ] 2079 | }, 2080 | { 2081 | "cell_type": "code", 2082 | "metadata": { 2083 | "id": "NhqylHGK3Va4", 2084 | "colab": { 2085 | "base_uri": "https://localhost:8080/" 2086 | }, 2087 | "outputId": "726efd06-d8de-4a45-a7bb-f186994a6b2a" 2088 | }, 2089 | "source": [ 2090 | "training_stats = []\n", 2091 | "\n", 2092 | "# Measure the total training time for the whole run.\n", 2093 | "total_t0 = time.time()\n", 2094 | "\n", 2095 | "#models parameters\n", 2096 | "transformer_vars = [i for i in transformer.parameters()]\n", 2097 | "d_vars = transformer_vars + [v for v in discriminator.parameters()]\n", 2098 | "g_vars = [v for v in generator.parameters()]\n", 2099 | "\n", 2100 | "#optimizer\n", 2101 | "dis_optimizer = torch.optim.AdamW(d_vars, lr=learning_rate_discriminator)\n", 2102 | "gen_optimizer = torch.optim.AdamW(g_vars, lr=learning_rate_generator) \n", 2103 | "\n", 2104 | "#scheduler\n", 2105 | "if apply_scheduler:\n", 2106 | " num_train_examples = len(train_examples)\n", 2107 | " num_train_steps = int(num_train_examples / batch_size * num_train_epochs)\n", 2108 | " num_warmup_steps = int(num_train_steps * warmup_proportion)\n", 2109 | "\n", 2110 | " scheduler_d = get_constant_schedule_with_warmup(dis_optimizer, \n", 2111 | " num_warmup_steps = num_warmup_steps)\n", 2112 | " scheduler_g = get_constant_schedule_with_warmup(gen_optimizer, \n", 2113 | " num_warmup_steps = num_warmup_steps)\n", 2114 | "\n", 2115 | "# For each epoch...\n", 2116 | "for epoch_i in range(0, num_train_epochs):\n", 2117 | " # ========================================\n", 2118 | " # Training\n", 2119 | " # 
========================================\n", 2120 | " # Perform one full pass over the training set.\n", 2121 | " print(\"\")\n", 2122 | " print('======== Epoch {:} / {:} ========'.format(epoch_i + 1, num_train_epochs))\n", 2123 | " print('Training...')\n", 2124 | "\n", 2125 | " # Measure how long the training epoch takes.\n", 2126 | " t0 = time.time()\n", 2127 | "\n", 2128 | " # Reset the total loss for this epoch.\n", 2129 | " tr_g_loss = 0\n", 2130 | " tr_d_loss = 0\n", 2131 | "\n", 2132 | " # Put the model into training mode.\n", 2133 | " transformer.train() \n", 2134 | " generator.train()\n", 2135 | " discriminator.train()\n", 2136 | "\n", 2137 | " # For each batch of training data...\n", 2138 | " for step, batch in enumerate(train_dataloader):\n", 2139 | "\n", 2140 | " # Progress update every print_each_n_step batches.\n", 2141 | " if step % print_each_n_step == 0 and not step == 0:\n", 2142 | " # Calculate elapsed time in minutes.\n", 2143 | " elapsed = format_time(time.time() - t0)\n", 2144 | " \n", 2145 | " # Report progress.\n", 2146 | " print(' Batch {:>5,} of {:>5,}. Elapsed: {:}.'.format(step, len(train_dataloader), elapsed))\n", 2147 | "\n", 2148 | " # Unpack this training batch from our dataloader. \n", 2149 | " b_input_ids = batch[0].to(device)\n", 2150 | " b_input_mask = batch[1].to(device)\n", 2151 | " b_labels = batch[2].to(device)\n", 2152 | " b_label_mask = batch[3].to(device)\n", 2153 | "\n", 2154 | " real_batch_size = b_input_ids.shape[0]\n", 2155 | " \n", 2156 | " # Encode real data in the Transformer\n", 2157 | " model_outputs = transformer(b_input_ids, attention_mask=b_input_mask)\n", 2158 | " hidden_states = model_outputs[-1]\n", 2159 | " \n", 2160 | " # Generate fake data that should have the same distribution of the ones\n", 2161 | " # encoded by the transformer. 
\n", 2162 | " # First noisy input are used in input to the Generator\n", 2163 | " noise = torch.zeros(real_batch_size, noise_size, device=device).uniform_(0, 1)\n", 2164 | " # Gnerate Fake data\n", 2165 | " gen_rep = generator(noise)\n", 2166 | "\n", 2167 | " # Generate the output of the Discriminator for real and fake data.\n", 2168 | " # First, we put together the output of the tranformer and the generator\n", 2169 | " disciminator_input = torch.cat([hidden_states, gen_rep], dim=0)\n", 2170 | " # Then, we select the output of the disciminator\n", 2171 | " features, logits, probs = discriminator(disciminator_input)\n", 2172 | "\n", 2173 | " # Finally, we separate the discriminator's output for the real and fake\n", 2174 | " # data\n", 2175 | " features_list = torch.split(features, real_batch_size)\n", 2176 | " D_real_features = features_list[0]\n", 2177 | " D_fake_features = features_list[1]\n", 2178 | " \n", 2179 | " logits_list = torch.split(logits, real_batch_size)\n", 2180 | " D_real_logits = logits_list[0]\n", 2181 | " D_fake_logits = logits_list[1]\n", 2182 | " \n", 2183 | " probs_list = torch.split(probs, real_batch_size)\n", 2184 | " D_real_probs = probs_list[0]\n", 2185 | " D_fake_probs = probs_list[1]\n", 2186 | "\n", 2187 | " #---------------------------------\n", 2188 | " # LOSS evaluation\n", 2189 | " #---------------------------------\n", 2190 | " # Generator's LOSS estimation\n", 2191 | " g_loss_d = -1 * torch.mean(torch.log(1 - D_fake_probs[:,-1] + epsilon))\n", 2192 | " g_feat_reg = torch.mean(torch.pow(torch.mean(D_real_features, dim=0) - torch.mean(D_fake_features, dim=0), 2))\n", 2193 | " g_loss = g_loss_d + g_feat_reg\n", 2194 | " \n", 2195 | " # Disciminator's LOSS estimation\n", 2196 | " logits = D_real_logits[:,0:-1]\n", 2197 | " log_probs = F.log_softmax(logits, dim=-1)\n", 2198 | " # The discriminator provides an output for labeled and unlabeled real data\n", 2199 | " # so the loss evaluated for unlabeled data is ignored (masked)\n", 2200 
| " label2one_hot = torch.nn.functional.one_hot(b_labels, len(label_list))\n", 2201 | " per_example_loss = -torch.sum(label2one_hot * log_probs, dim=-1)\n", 2202 | " per_example_loss = torch.masked_select(per_example_loss, b_label_mask.to(device))\n", 2203 | " labeled_example_count = per_example_loss.type(torch.float32).numel()\n", 2204 | "\n", 2205 | " # It may be the case that a batch does not contain labeled examples, \n", 2206 | " # so the \"supervised loss\" in this case is not evaluated\n", 2207 | " if labeled_example_count == 0:\n", 2208 | " D_L_Supervised = 0\n", 2209 | " else:\n", 2210 | " D_L_Supervised = torch.div(torch.sum(per_example_loss.to(device)), labeled_example_count)\n", 2211 | " \n", 2212 | " D_L_unsupervised1U = -1 * torch.mean(torch.log(1 - D_real_probs[:, -1] + epsilon))\n", 2213 | " D_L_unsupervised2U = -1 * torch.mean(torch.log(D_fake_probs[:, -1] + epsilon))\n", 2214 | " d_loss = D_L_Supervised + D_L_unsupervised1U + D_L_unsupervised2U\n", 2215 | "\n", 2216 | " #---------------------------------\n", 2217 | " # OPTIMIZATION\n", 2218 | " #---------------------------------\n", 2219 | " # Avoid gradient accumulation\n", 2220 | " gen_optimizer.zero_grad()\n", 2221 | " dis_optimizer.zero_grad()\n", 2222 | "\n", 2223 | " # Calculate weigth updates\n", 2224 | " # retain_graph=True is required since the underlying graph will be deleted after backward\n", 2225 | " g_loss.backward(retain_graph=True)\n", 2226 | " d_loss.backward() \n", 2227 | " \n", 2228 | " # Apply modifications\n", 2229 | " gen_optimizer.step()\n", 2230 | " dis_optimizer.step()\n", 2231 | "\n", 2232 | " # A detail log of the individual losses\n", 2233 | " #print(\"{0:.4f}\\t{1:.4f}\\t{2:.4f}\\t{3:.4f}\\t{4:.4f}\".\n", 2234 | " # format(D_L_Supervised, D_L_unsupervised1U, D_L_unsupervised2U,\n", 2235 | " # g_loss_d, g_feat_reg))\n", 2236 | "\n", 2237 | " # Save the losses to print them later\n", 2238 | " tr_g_loss += g_loss.item()\n", 2239 | " tr_d_loss += d_loss.item()\n", 2240 | 
"\n", 2241 | " # Update the learning rate with the scheduler\n", 2242 | " if apply_scheduler:\n", 2243 | " scheduler_d.step()\n", 2244 | " scheduler_g.step()\n", 2245 | "\n", 2246 | " # Calculate the average loss over all of the batches.\n", 2247 | " avg_train_loss_g = tr_g_loss / len(train_dataloader)\n", 2248 | " avg_train_loss_d = tr_d_loss / len(train_dataloader) \n", 2249 | " \n", 2250 | " # Measure how long this epoch took.\n", 2251 | " training_time = format_time(time.time() - t0)\n", 2252 | "\n", 2253 | " print(\"\")\n", 2254 | " print(\" Average training loss generetor: {0:.3f}\".format(avg_train_loss_g))\n", 2255 | " print(\" Average training loss discriminator: {0:.3f}\".format(avg_train_loss_d))\n", 2256 | " print(\" Training epcoh took: {:}\".format(training_time))\n", 2257 | " \n", 2258 | " # ========================================\n", 2259 | " # TEST ON THE EVALUATION DATASET\n", 2260 | " # ========================================\n", 2261 | " # After the completion of each training epoch, measure our performance on\n", 2262 | " # our test set.\n", 2263 | " print(\"\")\n", 2264 | " print(\"Running Test...\")\n", 2265 | "\n", 2266 | " t0 = time.time()\n", 2267 | "\n", 2268 | " # Put the model in evaluation mode--the dropout layers behave differently\n", 2269 | " # during evaluation.\n", 2270 | " transformer.eval() #maybe redundant\n", 2271 | " discriminator.eval()\n", 2272 | " generator.eval()\n", 2273 | "\n", 2274 | " # Tracking variables \n", 2275 | " total_test_accuracy = 0\n", 2276 | " \n", 2277 | " total_test_loss = 0\n", 2278 | " nb_test_steps = 0\n", 2279 | "\n", 2280 | " all_preds = []\n", 2281 | " all_labels_ids = []\n", 2282 | "\n", 2283 | " #loss\n", 2284 | " nll_loss = torch.nn.CrossEntropyLoss(ignore_index=-1)\n", 2285 | "\n", 2286 | " # Evaluate data for one epoch\n", 2287 | " for batch in test_dataloader:\n", 2288 | " \n", 2289 | " # Unpack this training batch from our dataloader. 
\n", 2290 | " b_input_ids = batch[0].to(device)\n", 2291 | " b_input_mask = batch[1].to(device)\n", 2292 | " b_labels = batch[2].to(device)\n", 2293 | " \n", 2294 | " # Tell pytorch not to bother with constructing the compute graph during\n", 2295 | " # the forward pass, since this is only needed for backprop (training).\n", 2296 | " with torch.no_grad(): \n", 2297 | " model_outputs = transformer(b_input_ids, attention_mask=b_input_mask)\n", 2298 | " hidden_states = model_outputs[-1]\n", 2299 | " _, logits, probs = discriminator(hidden_states)\n", 2300 | " ###log_probs = F.log_softmax(probs[:,1:], dim=-1)\n", 2301 | " filtered_logits = logits[:,0:-1]\n", 2302 | " # Accumulate the test loss.\n", 2303 | " total_test_loss += nll_loss(filtered_logits, b_labels)\n", 2304 | " \n", 2305 | " # Accumulate the predictions and the input labels\n", 2306 | " _, preds = torch.max(filtered_logits, 1)\n", 2307 | " all_preds += preds.detach().cpu()\n", 2308 | " all_labels_ids += b_labels.detach().cpu()\n", 2309 | "\n", 2310 | " # Report the final accuracy for this validation run.\n", 2311 | " all_preds = torch.stack(all_preds).numpy()\n", 2312 | " all_labels_ids = torch.stack(all_labels_ids).numpy()\n", 2313 | " test_accuracy = np.sum(all_preds == all_labels_ids) / len(all_preds)\n", 2314 | " print(\" Accuracy: {0:.3f}\".format(test_accuracy))\n", 2315 | "\n", 2316 | " # Calculate the average loss over all of the batches.\n", 2317 | " avg_test_loss = total_test_loss / len(test_dataloader)\n", 2318 | " avg_test_loss = avg_test_loss.item()\n", 2319 | " \n", 2320 | " # Measure how long the validation run took.\n", 2321 | " test_time = format_time(time.time() - t0)\n", 2322 | " \n", 2323 | " print(\" Test Loss: {0:.3f}\".format(avg_test_loss))\n", 2324 | " print(\" Test took: {:}\".format(test_time))\n", 2325 | "\n", 2326 | " # Record all statistics from this epoch.\n", 2327 | " training_stats.append(\n", 2328 | " {\n", 2329 | " 'epoch': epoch_i + 1,\n", 2330 | " 'Training Loss 
generator': avg_train_loss_g,\n", 2331 | " 'Training Loss discriminator': avg_train_loss_d,\n", 2332 | " 'Valid. Loss': avg_test_loss,\n", 2333 | " 'Valid. Accur.': test_accuracy,\n", 2334 | " 'Training Time': training_time,\n", 2335 | " 'Test Time': test_time\n", 2336 | " }\n", 2337 | " )" 2338 | ], 2339 | "execution_count": 11, 2340 | "outputs": [ 2341 | { 2342 | "output_type": "stream", 2343 | "name": "stdout", 2344 | "text": [ 2345 | "\n", 2346 | "======== Epoch 1 / 10 ========\n", 2347 | "Training...\n", 2348 | " Batch 10 of 92. Elapsed: 0:00:20.\n", 2349 | " Batch 20 of 92. Elapsed: 0:00:41.\n", 2350 | " Batch 30 of 92. Elapsed: 0:01:01.\n", 2351 | " Batch 40 of 92. Elapsed: 0:01:22.\n", 2352 | " Batch 50 of 92. Elapsed: 0:01:42.\n", 2353 | " Batch 60 of 92. Elapsed: 0:02:02.\n", 2354 | " Batch 70 of 92. Elapsed: 0:02:22.\n", 2355 | " Batch 80 of 92. Elapsed: 0:02:43.\n", 2356 | " Batch 90 of 92. Elapsed: 0:03:03.\n", 2357 | "\n", 2358 | " Average training loss generetor: 0.603\n", 2359 | " Average training loss discriminator: 4.578\n", 2360 | " Training epcoh took: 0:03:07\n", 2361 | "\n", 2362 | "Running Test...\n", 2363 | " Accuracy: 0.130\n", 2364 | " Test Loss: 3.086\n", 2365 | " Test took: 0:00:04\n", 2366 | "\n", 2367 | "======== Epoch 2 / 10 ========\n", 2368 | "Training...\n", 2369 | " Batch 10 of 92. Elapsed: 0:00:20.\n", 2370 | " Batch 20 of 92. Elapsed: 0:00:40.\n", 2371 | " Batch 30 of 92. Elapsed: 0:01:01.\n", 2372 | " Batch 40 of 92. Elapsed: 0:01:21.\n", 2373 | " Batch 50 of 92. Elapsed: 0:01:41.\n", 2374 | " Batch 60 of 92. Elapsed: 0:02:01.\n", 2375 | " Batch 70 of 92. Elapsed: 0:02:22.\n", 2376 | " Batch 80 of 92. Elapsed: 0:02:42.\n", 2377 | " Batch 90 of 92. 
Elapsed: 0:03:02.\n", 2378 | "\n", 2379 | " Average training loss generetor: 0.763\n", 2380 | " Average training loss discriminator: 2.759\n", 2381 | " Training epcoh took: 0:03:06\n", 2382 | "\n", 2383 | "Running Test...\n", 2384 | " Accuracy: 0.526\n", 2385 | " Test Loss: 2.510\n", 2386 | " Test took: 0:00:04\n", 2387 | "\n", 2388 | "======== Epoch 3 / 10 ========\n", 2389 | "Training...\n", 2390 | " Batch 10 of 92. Elapsed: 0:00:20.\n", 2391 | " Batch 20 of 92. Elapsed: 0:00:41.\n", 2392 | " Batch 30 of 92. Elapsed: 0:01:01.\n", 2393 | " Batch 40 of 92. Elapsed: 0:01:21.\n", 2394 | " Batch 50 of 92. Elapsed: 0:01:41.\n", 2395 | " Batch 60 of 92. Elapsed: 0:02:01.\n", 2396 | " Batch 70 of 92. Elapsed: 0:02:22.\n", 2397 | " Batch 80 of 92. Elapsed: 0:02:42.\n", 2398 | " Batch 90 of 92. Elapsed: 0:03:02.\n", 2399 | "\n", 2400 | " Average training loss generetor: 0.743\n", 2401 | " Average training loss discriminator: 1.991\n", 2402 | " Training epcoh took: 0:03:06\n", 2403 | "\n", 2404 | "Running Test...\n", 2405 | " Accuracy: 0.578\n", 2406 | " Test Loss: 2.066\n", 2407 | " Test took: 0:00:04\n", 2408 | "\n", 2409 | "======== Epoch 4 / 10 ========\n", 2410 | "Training...\n", 2411 | " Batch 10 of 92. Elapsed: 0:00:20.\n", 2412 | " Batch 20 of 92. Elapsed: 0:00:40.\n", 2413 | " Batch 30 of 92. Elapsed: 0:01:00.\n", 2414 | " Batch 40 of 92. Elapsed: 0:01:21.\n", 2415 | " Batch 50 of 92. Elapsed: 0:01:41.\n", 2416 | " Batch 60 of 92. Elapsed: 0:02:01.\n", 2417 | " Batch 70 of 92. Elapsed: 0:02:21.\n", 2418 | " Batch 80 of 92. Elapsed: 0:02:41.\n", 2419 | " Batch 90 of 92. 
Elapsed: 0:03:01.\n", 2420 | "\n", 2421 | " Average training loss generetor: 0.736\n", 2422 | " Average training loss discriminator: 1.399\n", 2423 | " Training epcoh took: 0:03:05\n", 2424 | "\n", 2425 | "Running Test...\n", 2426 | " Accuracy: 0.602\n", 2427 | " Test Loss: 2.016\n", 2428 | " Test took: 0:00:04\n", 2429 | "\n", 2430 | "======== Epoch 5 / 10 ========\n", 2431 | "Training...\n", 2432 | " Batch 10 of 92. Elapsed: 0:00:20.\n", 2433 | " Batch 20 of 92. Elapsed: 0:00:40.\n", 2434 | " Batch 30 of 92. Elapsed: 0:01:00.\n", 2435 | " Batch 40 of 92. Elapsed: 0:01:20.\n", 2436 | " Batch 50 of 92. Elapsed: 0:01:40.\n", 2437 | " Batch 60 of 92. Elapsed: 0:02:00.\n", 2438 | " Batch 70 of 92. Elapsed: 0:02:21.\n", 2439 | " Batch 80 of 92. Elapsed: 0:02:41.\n", 2440 | " Batch 90 of 92. Elapsed: 0:03:01.\n", 2441 | "\n", 2442 | " Average training loss generetor: 0.729\n", 2443 | " Average training loss discriminator: 1.140\n", 2444 | " Training epcoh took: 0:03:05\n", 2445 | "\n", 2446 | "Running Test...\n", 2447 | " Accuracy: 0.638\n", 2448 | " Test Loss: 1.960\n", 2449 | " Test took: 0:00:04\n", 2450 | "\n", 2451 | "======== Epoch 6 / 10 ========\n", 2452 | "Training...\n", 2453 | " Batch 10 of 92. Elapsed: 0:00:20.\n", 2454 | " Batch 20 of 92. Elapsed: 0:00:40.\n", 2455 | " Batch 30 of 92. Elapsed: 0:01:00.\n", 2456 | " Batch 40 of 92. Elapsed: 0:01:20.\n", 2457 | " Batch 50 of 92. Elapsed: 0:01:40.\n", 2458 | " Batch 60 of 92. Elapsed: 0:02:00.\n", 2459 | " Batch 70 of 92. Elapsed: 0:02:20.\n", 2460 | " Batch 80 of 92. Elapsed: 0:02:40.\n", 2461 | " Batch 90 of 92. 
Elapsed: 0:03:01.\n", 2462 | "\n", 2463 | " Average training loss generetor: 0.725\n", 2464 | " Average training loss discriminator: 1.020\n", 2465 | " Training epcoh took: 0:03:05\n", 2466 | "\n", 2467 | "Running Test...\n", 2468 | " Accuracy: 0.588\n", 2469 | " Test Loss: 2.236\n", 2470 | " Test took: 0:00:04\n", 2471 | "\n", 2472 | "======== Epoch 7 / 10 ========\n", 2473 | "Training...\n", 2474 | " Batch 10 of 92. Elapsed: 0:00:20.\n", 2475 | " Batch 20 of 92. Elapsed: 0:00:40.\n", 2476 | " Batch 30 of 92. Elapsed: 0:01:00.\n", 2477 | " Batch 40 of 92. Elapsed: 0:01:20.\n", 2478 | " Batch 50 of 92. Elapsed: 0:01:40.\n", 2479 | " Batch 60 of 92. Elapsed: 0:02:00.\n", 2480 | " Batch 70 of 92. Elapsed: 0:02:21.\n", 2481 | " Batch 80 of 92. Elapsed: 0:02:41.\n", 2482 | " Batch 90 of 92. Elapsed: 0:03:01.\n", 2483 | "\n", 2484 | " Average training loss generetor: 0.721\n", 2485 | " Average training loss discriminator: 0.925\n", 2486 | " Training epcoh took: 0:03:05\n", 2487 | "\n", 2488 | "Running Test...\n", 2489 | " Accuracy: 0.632\n", 2490 | " Test Loss: 2.116\n", 2491 | " Test took: 0:00:04\n", 2492 | "\n", 2493 | "======== Epoch 8 / 10 ========\n", 2494 | "Training...\n", 2495 | " Batch 10 of 92. Elapsed: 0:00:20.\n", 2496 | " Batch 20 of 92. Elapsed: 0:00:40.\n", 2497 | " Batch 30 of 92. Elapsed: 0:01:00.\n", 2498 | " Batch 40 of 92. Elapsed: 0:01:20.\n", 2499 | " Batch 50 of 92. Elapsed: 0:01:40.\n", 2500 | " Batch 60 of 92. Elapsed: 0:02:00.\n", 2501 | " Batch 70 of 92. Elapsed: 0:02:21.\n", 2502 | " Batch 80 of 92. Elapsed: 0:02:41.\n", 2503 | " Batch 90 of 92. 
Elapsed: 0:03:01.\n", 2504 | "\n", 2505 | " Average training loss generetor: 0.718\n", 2506 | " Average training loss discriminator: 0.861\n", 2507 | " Training epcoh took: 0:03:05\n", 2508 | "\n", 2509 | "Running Test...\n", 2510 | " Accuracy: 0.644\n", 2511 | " Test Loss: 2.150\n", 2512 | " Test took: 0:00:04\n", 2513 | "\n", 2514 | "======== Epoch 9 / 10 ========\n", 2515 | "Training...\n", 2516 | " Batch 10 of 92. Elapsed: 0:00:20.\n", 2517 | " Batch 20 of 92. Elapsed: 0:00:40.\n", 2518 | " Batch 30 of 92. Elapsed: 0:01:00.\n", 2519 | " Batch 40 of 92. Elapsed: 0:01:20.\n", 2520 | " Batch 50 of 92. Elapsed: 0:01:40.\n", 2521 | " Batch 60 of 92. Elapsed: 0:02:00.\n", 2522 | " Batch 70 of 92. Elapsed: 0:02:20.\n", 2523 | " Batch 80 of 92. Elapsed: 0:02:41.\n", 2524 | " Batch 90 of 92. Elapsed: 0:03:01.\n", 2525 | "\n", 2526 | " Average training loss generetor: 0.719\n", 2527 | " Average training loss discriminator: 0.824\n", 2528 | " Training epcoh took: 0:03:05\n", 2529 | "\n", 2530 | "Running Test...\n", 2531 | " Accuracy: 0.632\n", 2532 | " Test Loss: 2.214\n", 2533 | " Test took: 0:00:04\n", 2534 | "\n", 2535 | "======== Epoch 10 / 10 ========\n", 2536 | "Training...\n", 2537 | " Batch 10 of 92. Elapsed: 0:00:20.\n", 2538 | " Batch 20 of 92. Elapsed: 0:00:40.\n", 2539 | " Batch 30 of 92. Elapsed: 0:01:00.\n", 2540 | " Batch 40 of 92. Elapsed: 0:01:20.\n", 2541 | " Batch 50 of 92. Elapsed: 0:01:40.\n", 2542 | " Batch 60 of 92. Elapsed: 0:02:00.\n", 2543 | " Batch 70 of 92. Elapsed: 0:02:20.\n", 2544 | " Batch 80 of 92. Elapsed: 0:02:41.\n", 2545 | " Batch 90 of 92. 
Elapsed: 0:03:01.\n", 2546 | "\n", 2547 | " Average training loss generetor: 0.714\n", 2548 | " Average training loss discriminator: 0.791\n", 2549 | " Training epcoh took: 0:03:05\n", 2550 | "\n", 2551 | "Running Test...\n", 2552 | " Accuracy: 0.626\n", 2553 | " Test Loss: 2.339\n", 2554 | " Test took: 0:00:04\n" 2555 | ] 2556 | } 2557 | ] 2558 | }, 2559 | { 2560 | "cell_type": "code", 2561 | "metadata": { 2562 | "id": "dDm9NProRB4c", 2563 | "colab": { 2564 | "base_uri": "https://localhost:8080/" 2565 | }, 2566 | "outputId": "2ffebcf1-6b39-4442-88c6-5f72a58a3722" 2567 | }, 2568 | "source": [ 2569 | "for stat in training_stats:\n", 2570 | " print(stat)\n", 2571 | "\n", 2572 | "print(\"\\nTraining complete!\")\n", 2573 | "\n", 2574 | "print(\"Total training took {:} (h:mm:ss)\".format(format_time(time.time()-total_t0)))" 2575 | ], 2576 | "execution_count": 12, 2577 | "outputs": [ 2578 | { 2579 | "output_type": "stream", 2580 | "name": "stdout", 2581 | "text": [ 2582 | "{'epoch': 1, 'Training Loss generator': 0.6030529494518819, 'Training Loss discriminator': 4.57757222911586, 'Valid. Loss': 3.0862255096435547, 'Valid. Accur.': 0.13, 'Training Time': '0:03:07', 'Test Time': '0:00:04'}\n", 2583 | "{'epoch': 2, 'Training Loss generator': 0.7626331766014514, 'Training Loss discriminator': 2.759067593709282, 'Valid. Loss': 2.510432720184326, 'Valid. Accur.': 0.526, 'Training Time': '0:03:06', 'Test Time': '0:00:04'}\n", 2584 | "{'epoch': 3, 'Training Loss generator': 0.7427908233974291, 'Training Loss discriminator': 1.9905100030743557, 'Valid. Loss': 2.065537691116333, 'Valid. Accur.': 0.578, 'Training Time': '0:03:06', 'Test Time': '0:00:04'}\n", 2585 | "{'epoch': 4, 'Training Loss generator': 0.735916405916214, 'Training Loss discriminator': 1.3987060305864916, 'Valid. Loss': 2.016447067260742, 'Valid. 
Accur.': 0.602, 'Training Time': '0:03:05', 'Test Time': '0:00:04'}\n", 2586 | "{'epoch': 5, 'Training Loss generator': 0.7287991940975189, 'Training Loss discriminator': 1.1400610927654349, 'Valid. Loss': 1.9602296352386475, 'Valid. Accur.': 0.638, 'Training Time': '0:03:05', 'Test Time': '0:00:04'}\n", 2587 | "{'epoch': 6, 'Training Loss generator': 0.7248763740062714, 'Training Loss discriminator': 1.0197444175896437, 'Valid. Loss': 2.236262321472168, 'Valid. Accur.': 0.588, 'Training Time': '0:03:05', 'Test Time': '0:00:04'}\n", 2588 | "{'epoch': 7, 'Training Loss generator': 0.7208150255939235, 'Training Loss discriminator': 0.9249581824178281, 'Valid. Loss': 2.116415500640869, 'Valid. Accur.': 0.632, 'Training Time': '0:03:05', 'Test Time': '0:00:04'}\n", 2589 | "{'epoch': 8, 'Training Loss generator': 0.7179992360913235, 'Training Loss discriminator': 0.8613665920236836, 'Valid. Loss': 2.150423765182495, 'Valid. Accur.': 0.644, 'Training Time': '0:03:05', 'Test Time': '0:00:04'}\n", 2590 | "{'epoch': 9, 'Training Loss generator': 0.7185413422791854, 'Training Loss discriminator': 0.8235321206891019, 'Valid. Loss': 2.213574171066284, 'Valid. Accur.': 0.632, 'Training Time': '0:03:05', 'Test Time': '0:00:04'}\n", 2591 | "{'epoch': 10, 'Training Loss generator': 0.7144109431816184, 'Training Loss discriminator': 0.7913454485976178, 'Valid. Loss': 2.33937931060791, 'Valid. Accur.': 0.626, 'Training Time': '0:03:05', 'Test Time': '0:00:04'}\n", 2592 | "\n", 2593 | "Training complete!\n", 2594 | "Total training took 0:31:30 (h:mm:ss)\n" 2595 | ] 2596 | } 2597 | ] 2598 | } 2599 | ] 2600 | } -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 
8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. 
For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. 
Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. 
In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. 
We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # GAN-BERT (in Pytorch and compatible with HuggingFace) 2 | 3 | This is an implementation in Pytorch (and **HuggingFace**) of the GAN-BERT method from https://github.com/crux82/ganbert which is available in Tensorflow. While the original GAN-BERT was an extension of BERT, this implementation can be adapted to several architectures, ranging from Roberta to Albert! 4 | 5 | **IMPORTANT**: Since this implementation is slightly different from the original Tensorflow one, some results may vary. Any feedback or suggestions for improving this first version would be appreciated. 6 | 7 | ## GANBERT 8 | 9 | This is the code for the paper **"GAN-BERT: Generative Adversarial Learning for Robust Text Classification with a Bunch of Labeled Examples"** published in the **ACL 2020 - short paper** by *Danilo Croce* (Tor Vergata, University of Rome), *Giuseppe Castellucci* (Amazon) and *Roberto Basili* (Tor Vergata, University of Rome). 
10 | 11 | GAN-BERT is an extension of BERT which uses a Generative Adversarial setting to implement an effective semi-supervised learning scheme. It allows training BERT with datasets composed of a limited number of labeled examples and larger subsets of unlabeled material. 12 | GAN-BERT can be used in sequence classification tasks (also involving text pairs). 13 | 14 | As in the original implementation in Tensorflow, this code runs the GAN-BERT experiment over the TREC dataset for the fine-grained Question Classification task. We provide in this package the code as well as the data for running an experiment by using 2% of the labeled material (109 examples) and 5343 unlabeled examples. 15 | The test set is composed of 500 annotated examples. 16 | 17 | ## The Model 18 | 19 | GAN-BERT is an extension of the BERT model within the Generative Adversarial Network (GAN) framework (Goodfellow et al., 2014). In particular, the Semi-Supervised GAN (Salimans et al., 2016) is used to make the BERT fine-tuning robust in such training scenarios where obtaining annotated material is problematic. When fine-tuned with very few labeled examples, the BERT model does not achieve sufficient performance. With GAN-BERT we extend the fine-tuning stage by introducing a Discriminator-Generator setting, where: 20 | 21 | - the Generator G is devoted to producing "fake" vector representations of sentences; 22 | - the Discriminator D is a BERT-based classifier over k+1 categories. 23 | 24 | ![GAN-BERT model](https://github.com/crux82/ganbert/raw/master/ganbert.jpg) 25 | 26 | D has the role of classifying an example with respect to the k categories of the task of interest, and it should recognize the examples that are generated by G (the k+1 category). 27 | G, instead, must produce representations as similar as possible to the ones produced by the model for the "real" examples. G is penalized when D correctly classifies an example as fake. 
28 | 29 | In this context, the model is trained on both labeled and unlabeled examples. The labeled examples contribute to the computation of the loss function over the k categories of the task. The unlabeled examples contribute to the computation of the loss functions as they should not be incorrectly classified as belonging to the k+1 category (i.e., the fake category). 30 | 31 | The resulting model is demonstrated to learn text classification tasks starting from very few labeled examples (50-60 examples) and to outperform the classical BERT fine-tuned models by a large margin in this setting. 32 | 33 | More details are available at [https://github.com/crux82/ganbert](https://github.com/crux82/ganbert) 34 | 35 | ## Citation 36 | 37 | If this software is useful for your research, please cite the following paper: 38 | 39 | ```bibtex 40 | @inproceedings{croce-etal-2020-gan, 41 | title = "{GAN}-{BERT}: Generative Adversarial Learning for Robust Text Classification with a Bunch of Labeled Examples", 42 | author = "Croce, Danilo and 43 | Castellucci, Giuseppe and 44 | Basili, Roberto", 45 | booktitle = "Proceedings of the 58th Annual Meeting of the Association for Computational Linguistics", 46 | month = jul, 47 | year = "2020", 48 | address = "Online", 49 | publisher = "Association for Computational Linguistics", 50 | url = "https://www.aclweb.org/anthology/2020.acl-main.191", 51 | pages = "2114--2119" 52 | } 53 | ``` 54 | 55 | ## Acknowledgments 56 | 57 | We would like to thank *Osman Mutlu* and *Ali Hürriyetoğlu* for their implementation of GAN-BERT in Pytorch, which inspired our port. 58 | You can find their initial repository at this [link](https://github.com/OsmanMutlu/Pytorch-GANBERT). 59 | We would like to thank *Claudia Breazzano* (Tor Vergata, University of Rome), who supported this porting. 60 | --------------------------------------------------------------------------------