├── CONTRIBUTING.md ├── LICENSE-APACHE ├── README.md └── Tutorial_on_SimCSE.ipynb /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # How to contribute 2 | 3 | We'd love to accept your patches and contributions to this project. 4 | There are just a few small guidelines you need to follow. 5 | 6 | - Unless you explicitly state otherwise, any contribution intentionally 7 | submitted for inclusion in work by you shall be licensed under 8 | Apache 2.0 without any additional terms or conditions. 9 | - All submissions, including submissions by project members, require review. 10 | We use GitHub pull requests for this purpose. 11 | Consult GitHub Help for more information on using pull requests. 12 | -------------------------------------------------------------------------------- /LICENSE-APACHE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 
47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Tutorial on SimCSE 2 | 3 | This repository provides a tutorial notebook on how to implement SimCSE: 4 | 5 | > Tianyu Gao, Xingcheng Yao, and Danqi Chen. 6 | > [SimCSE: Simple Contrastive Learning of Sentence Embeddings](https://aclanthology.org/2021.emnlp-main.552/). 7 | > EMNLP 2021. 8 | 9 | The notebook is described in the following blog post: 10 | 11 | > [We have published the SimCSE tutorial materials used at an internal study session](https://tech.legalforce.co.jp/entry/2023/11/09/110057). 12 | > LegalOn Technologies Engineering Blog. 13 | > 2023-11-09. 14 | 15 | ## Disclaimer 16 | 17 | This software is developed by LegalOn Technologies, Inc., 18 | but is not an officially supported LegalOn Technologies product. 19 | 20 | ## License 21 | 22 | Licensed under the Apache License, Version 2.0 23 | ([LICENSE-APACHE](./LICENSE-APACHE) or http://www.apache.org/licenses/LICENSE-2.0). 24 | 25 | ## Credits 26 | 27 | Portions of this software are ported from https://github.com/hppRC/simple-simcse. 
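## Background

For quick orientation before opening the notebook, here is a minimal sketch of the unsupervised SimCSE training objective from the paper cited above. It assumes PyTorch and Hugging Face `transformers`; the model name, pooling choice, and hyperparameters are illustrative defaults, not values taken from the notebook.

```python
import torch
import torch.nn.functional as F
from transformers import AutoModel, AutoTokenizer

model_name = "bert-base-uncased"  # illustrative; any BERT-style encoder works
tokenizer = AutoTokenizer.from_pretrained(model_name)
encoder = AutoModel.from_pretrained(model_name)
encoder.train()  # keep dropout active: it serves as the data augmentation

sentences = ["A man is playing guitar.", "The weather is nice today."]
batch = tokenizer(sentences, padding=True, truncation=True, return_tensors="pt")

# Encode the same batch twice; different dropout masks give two "views" per sentence.
z1 = encoder(**batch).last_hidden_state[:, 0]  # [CLS] pooling (simplified)
z2 = encoder(**batch).last_hidden_state[:, 0]

# InfoNCE over in-batch negatives: cosine similarity with temperature 0.05.
sim = F.cosine_similarity(z1.unsqueeze(1), z2.unsqueeze(0), dim=-1) / 0.05
labels = torch.arange(sim.size(0))  # the positive for sentence i is its own second view
loss = F.cross_entropy(sim, labels)
loss.backward()
```

The key idea is that encoding a batch twice with dropout enabled yields two slightly different embeddings per sentence; each pair is treated as the positive, while the other sentences in the batch act as negatives.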
28 | -------------------------------------------------------------------------------- /Tutorial_on_SimCSE.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "colab": { 6 | "provenance": [], 7 | "gpuType": "T4" 8 | }, 9 | "kernelspec": { 10 | "name": "python3", 11 | "display_name": "Python 3" 12 | }, 13 | "language_info": { 14 | "name": "python" 15 | }, 16 | "widgets": { ... } (the remainder of the notebook file is auto-saved ipywidgets progress-bar state recording Hugging Face Hub downloads: config.json, model.safetensors (~440 MB), tokenizer_config.json, vocab.txt, and tokenizer.json)
"top": null, 1709 | "visibility": null, 1710 | "width": null 1711 | } 1712 | }, 1713 | "e7363446238744aabb6ce02cfe202a9e": { 1714 | "model_module": "@jupyter-widgets/controls", 1715 | "model_name": "DescriptionStyleModel", 1716 | "model_module_version": "1.5.0", 1717 | "state": { 1718 | "_model_module": "@jupyter-widgets/controls", 1719 | "_model_module_version": "1.5.0", 1720 | "_model_name": "DescriptionStyleModel", 1721 | "_view_count": null, 1722 | "_view_module": "@jupyter-widgets/base", 1723 | "_view_module_version": "1.2.0", 1724 | "_view_name": "StyleView", 1725 | "description_width": "" 1726 | } 1727 | } 1728 | } 1729 | }, 1730 | "accelerator": "GPU" 1731 | }, 1732 | "cells": [ 1733 | { 1734 | "cell_type": "markdown", 1735 | "source": [ 1736 | "# Tutorial on SimCSE\n", 1737 | "\n", 1738 | "- 作成者: Shunsuke Kanda ([@kampersanda](https://github.com/kampersanda))\n", 1739 | "- 作成日: 2023-10-29\n", 1740 | "\n", 1741 | "## SimCSEについて\n", 1742 | "\n", 1743 | "SimCSEは、対照学習を用いた文埋め込み技術です。\n", 1744 | "\n", 1745 | "> Tianyu Gao, Xingcheng Yao, and Danqi Chen. [SimCSE: Simple Contrastive Learning of Sentence Embeddings](https://aclanthology.org/2021.emnlp-main.552/). EMNLP 2021.\n", 1746 | "\n", 1747 | "SimCSEは、簡単なアルゴリズムでラベルの無い文集合から文埋め込みモデルを学習することができます。\n", 1748 | "その文埋め込みは、Semantic Textual Similarity (STS) 評価タスクにおいて、教師ありのSentence-BERTと同程度の性能を示します。\n", 1749 | "\n", 1750 | "また、正例ペアから成る訓練セットを用いて教師あり学習することで、更にその性能を向上することができます。\n", 1751 | "訓練セットの作り方次第では、それぞれの目的に応じた文埋め込みモデルを獲得することも可能です。\n", 1752 | "\n", 1753 | "SimCSEは、実装の容易さ・応用の容易さ・高い性能などから、研究でも実用でも非常に有用な文埋め込み技術のひとつです。その実装方法を習得することは、自然言語処理や情報検索エンジニアにとって有益だと考えます。\n", 1754 | "\n", 1755 | "## この資料について\n", 1756 | "\n", 1757 | "### 動機\n", 1758 | "\n", 1759 | "SimCSEのアイデアはシンプルなので、そのアルゴリズムを理解することはあまり難しくないです。良い教材もネットに揃っています。\n", 1760 | "\n", 1761 | "しかし、SimCSEを実装し応用できるようになるには、深層学習フレームワークや自然言語処理について一定の知識と経験が必要になります。例えば、ある深層学習アルゴリズムの実装を眺めてみて、ライブラリの使用方法や、当たり前に記述されているヒューリスティックの意味が分からず、一行一行調べながらコードを読んだ経験のある方も多いと思います。\n", 1762 | "\n", 1763 | "また、初学者の方にとってはGPUなどの環境構築も1つのハードルであり、Google Colaboratoryなどの環境で試せることも重要でしょう。\n", 1764 | "\n", 1765 | "### 目的\n", 1766 | "\n", 1767 | "この資料は、SimCSEについて上記のような問題を解決することを目的とし、SimCSEの学習から評価まで一連の実装とその解説をNotebookで提供します。\n", 1768 | "\n", 1769 | "### 特徴\n", 1770 | "\n", 1771 | "- 以下の一連の処理を、上から順に実行することで簡単に試すことが可能です\n", 1772 | " - データセットの準備\n", 1773 | " - モデルの定義\n", 1774 | " - モデルの学習\n", 1775 | " - モデルの評価\n", 1776 | "- PyTorchやTransformersの基本的な使用方法や、自然言語処理でよく知られるヒューリスティックも解説します\n", 1777 | "- 教師なしと教師ありの両方の実装を提供します\n", 1778 | "\n", 1779 | "### 想定する利用者\n", 1780 | "\n", 1781 | "- SimCSEのアイデアは理解できるが、深層学習フレームワークなどの経験の少なさから実際にモデルを実装するのにはハードルを感じるという方\n", 1782 | "- 業務などでSimCSEのアルゴリズムを実装する必要がある方\n", 1783 | "\n", 1784 | "### 読むのに必要なこと\n", 1785 | "\n", 1786 | "SimCSEの目的と基本的なアイデアを理解していることを前提とします。これらの知識習得のために、以下のスライド資料をオススメします。\n", 1787 | "\n", 1788 | "> [[輪講資料] SimCSE: Simple Contrastive Learning of Sentence Embeddings](https://speakerdeck.com/hpprc/lun-jiang-zi-liao-simcse-simple-contrastive-learning-of-sentence-embeddings-823255cd-bd1f-40ec-a65c-0eced7a9191d)\n", 1789 | "\n", 1790 | "また、深層学習モデルの基本的な学習方法(ミニバッチ学習など)も既知を前提とします。\n", 1791 | "\n", 1792 | "### 作成方法\n", 1793 | "\n", 1794 | "SimCSEのシンプルな再実装 [hppRC/simple-simcse](https://github.com/hppRC/simple-simcse) が存在します。こちらのレポジトリは、SimCSEの学習と評価アルゴリズムの簡潔な実装を提供しており、コードの各パートにも丁寧な解説コメントを記述しています。深層学習の経験があり、論文を読んでその内容が理解できる方にとってはsimple-simcseが必要十分な資料だと思います。\n", 1795 | "\n", 1796 | "このNotebookでは更に基礎的な部分からの解説と簡単な利用を試み、作成者が解説コメントを追記しつつ、simple-simcseの内容をNotebookで再実装しました。また、教師あり学習のパートも追加しました。\n", 1797 | "\n", 1798 
| "## 参考資料\n", 1799 | "\n", 1800 | "以下を引用しつつ解説します。\n", 1801 | "\n", 1802 | "- Tianyu Gao, Xingcheng Yao, and Danqi Chen. [SimCSE: Simple Contrastive Learning of Sentence Embeddings](https://aclanthology.org/2021.emnlp-main.552/). EMNLP 2021. (\"論文\"として引用)\n", 1803 | "- https://github.com/princeton-nlp/SimCSE. (\"オリジナルの実装\"として引用)\n", 1804 | "- [[輪講資料] SimCSE: Simple Contrastive Learning of Sentence Embeddings](https://speakerdeck.com/hpprc/lun-jiang-zi-liao-simcse-simple-contrastive-learning-of-sentence-embeddings-823255cd-bd1f-40ec-a65c-0eced7a9191d) by Hayato Tsukagoshi (\"スライド\"として引用)\n", 1805 | "- 岡﨑, 荒瀬, 鈴木, 鶴岡, and 宮尾. [IT Text 自然言語処理の基礎](https://www.ohmsha.co.jp/book/9784274229008/), 2022. (\"岡﨑ら本\"として引用)\n", 1806 | "\n", 1807 | "## Notebookの構成\n", 1808 | "\n", 1809 | "Notebookは以下の4章で構成されます。\n", 1810 | "\n", 1811 | "1. 共通の設定\n", 1812 | "2. 教師なし学習(unsup-SimCSE)\n", 1813 | "3. 教師あり学習(sup-SimCSE)\n", 1814 | "4. 評価\n", 1815 | "\n", 1816 | "上から順に実行する想定ですが、2と3は任意で実行をスキップしても機能します。Google Colaboratoryで実行されることを想定します。\n", 1817 | "\n", 1818 | "## クレジット\n", 1819 | "\n", 1820 | "このNotebookは、LegalOn Technologiesの社内勉強会で使用した資料です。検索チームが主催するセマンティック検索とベクトル検索に関する勉強会の発表資料として作成されました。\n", 1821 | "\n", 1822 | "このNotebookの実装とコメントの大部分は、[hppRC/simple-simcse](https://github.com/hppRC/simple-simcse)からの移植です。\n", 1823 | "\n", 1824 | "このNotebookは、[Apache License Version 2.0](https://www.apache.org/licenses/LICENSE-2.0)に準拠します。\n", 1825 | "\n", 1826 | "## 謝辞\n", 1827 | "\n", 1828 | "このNotebookは、[hppRC/simple-simcse](https://github.com/hppRC/simple-simcse)と上記スライド無しでは作成できませんでした。これらの制作者である塚越駿さんに感謝致します。\n", 1829 | "\n", 1830 | "同僚の小林さんと藤田さんにも資料作成にあたって有益なコメントを頂きました。感謝致します。" 1831 | ], 1832 | "metadata": { 1833 | "id": "_tpRwoEmOgFY" 1834 | } 1835 | }, 1836 | { 1837 | "cell_type": "markdown", 1838 | "source": [ 1839 | "# 1. 
共通の設定" 1840 | ], 1841 | "metadata": { 1842 | "id": "O4xNBGPWn_1b" 1843 | } 1844 | }, 1845 | { 1846 | "cell_type": "code", 1847 | "execution_count": 1, 1848 | "metadata": { 1849 | "colab": { 1850 | "base_uri": "https://localhost:8080/" 1851 | }, 1852 | "id": "uVv0QKUIJ9nw", 1853 | "outputId": "4f1e07d3-895a-43c5-d648-363cefa6c14a" 1854 | }, 1855 | "outputs": [ 1856 | { 1857 | "output_type": "stream", 1858 | "name": "stdout", 1859 | "text": [ 1860 | "Collecting transformers==4.34.0\n", 1861 | " Downloading transformers-4.34.0-py3-none-any.whl (7.7 MB)\n", 1862 | "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m7.7/7.7 MB\u001b[0m \u001b[31m20.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", 1863 | "\u001b[?25hRequirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from transformers==4.34.0) (3.12.4)\n", 1864 | "Collecting huggingface-hub<1.0,>=0.16.4 (from transformers==4.34.0)\n", 1865 | " Downloading huggingface_hub-0.18.0-py3-none-any.whl (301 kB)\n", 1866 | "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m302.0/302.0 kB\u001b[0m \u001b[31m38.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", 1867 | "\u001b[?25hRequirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.10/dist-packages (from transformers==4.34.0) (1.23.5)\n", 1868 | "Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.10/dist-packages (from transformers==4.34.0) (23.2)\n", 1869 | "Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.10/dist-packages (from transformers==4.34.0) (6.0.1)\n", 1870 | "Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.10/dist-packages (from transformers==4.34.0) (2023.6.3)\n", 1871 | "Requirement already satisfied: requests in /usr/local/lib/python3.10/dist-packages (from transformers==4.34.0) (2.31.0)\n", 1872 | "Collecting tokenizers<0.15,>=0.14 (from transformers==4.34.0)\n", 1873 | " Downloading tokenizers-0.14.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (3.8 MB)\n", 1874 | "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m3.8/3.8 MB\u001b[0m \u001b[31m57.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", 1875 | "\u001b[?25hCollecting safetensors>=0.3.1 (from transformers==4.34.0)\n", 1876 | " Downloading safetensors-0.4.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.3 MB)\n", 1877 | "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.3/1.3 MB\u001b[0m \u001b[31m57.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", 1878 | "\u001b[?25hRequirement already satisfied: tqdm>=4.27 in /usr/local/lib/python3.10/dist-packages (from transformers==4.34.0) (4.66.1)\n", 1879 | "Requirement already satisfied: fsspec>=2023.5.0 in /usr/local/lib/python3.10/dist-packages (from huggingface-hub<1.0,>=0.16.4->transformers==4.34.0) (2023.6.0)\n", 1880 | "Requirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.10/dist-packages (from huggingface-hub<1.0,>=0.16.4->transformers==4.34.0) (4.5.0)\n", 1881 | "Collecting huggingface-hub<1.0,>=0.16.4 (from transformers==4.34.0)\n", 1882 | " Downloading huggingface_hub-0.17.3-py3-none-any.whl (295 kB)\n", 1883 | "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m295.0/295.0 kB\u001b[0m \u001b[31m33.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", 1884 | "\u001b[?25hRequirement already satisfied: charset-normalizer<4,>=2 in 
/usr/local/lib/python3.10/dist-packages (from requests->transformers==4.34.0) (3.3.1)\n", 1885 | "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests->transformers==4.34.0) (3.4)\n", 1886 | "Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests->transformers==4.34.0) (2.0.7)\n", 1887 | "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests->transformers==4.34.0) (2023.7.22)\n", 1888 | "Installing collected packages: safetensors, huggingface-hub, tokenizers, transformers\n", 1889 | "Successfully installed huggingface-hub-0.17.3 safetensors-0.4.0 tokenizers-0.14.1 transformers-4.34.0\n" 1890 | ] 1891 | } 1892 | ], 1893 | "source": [ 1894 | "!pip install transformers==4.34.0" 1895 | ] 1896 | }, 1897 | { 1898 | "cell_type": "code", 1899 | "source": [ 1900 | "import csv\n", 1901 | "import os\n", 1902 | "from typing import Callable\n", 1903 | "\n", 1904 | "import more_itertools\n", 1905 | "import pandas as pd\n", 1906 | "import scipy\n", 1907 | "from sklearn import metrics as sklearn_metrics\n", 1908 | "import torch\n", 1909 | "import tqdm\n", 1910 | "import transformers" 1911 | ], 1912 | "metadata": { 1913 | "id": "9_pIr2GIWRpX" 1914 | }, 1915 | "execution_count": 2, 1916 | "outputs": [] 1917 | }, 1918 | { 1919 | "cell_type": "code", 1920 | "source": [ 1921 | "# モデルを保存するためにGoogle Driveをマウントしておく\n", 1922 | "from google.colab import drive\n", 1923 | "drive.mount('/content/drive')\n", 1924 | "\n", 1925 | "# モデルの保存先パス (ファイル名の衝突に注意)\n", 1926 | "unsup_model_path = './drive/MyDrive/unsup-simcse-model.pth'\n", 1927 | "sup_model_path = './drive/MyDrive/sup-simcse-model.pth'" 1928 | ], 1929 | "metadata": { 1930 | "colab": { 1931 | "base_uri": "https://localhost:8080/" 1932 | }, 1933 | "id": "tC-rs-fO2iuf", 1934 | "outputId": "8e383ad9-4ed7-4cec-e6ca-72354f6f0860" 1935 | }, 1936 | "execution_count": 3, 1937 | "outputs": [ 1938 | { 1939 | "output_type": "stream", 1940 | "name": "stdout", 1941 | "text": [ 1942 | "Mounted at /content/drive\n" 1943 | ] 1944 | } 1945 | ] 1946 | }, 1947 | { 1948 | "cell_type": "code", 1949 | "source": [ 1950 | "# SimCSEモデルクラスの定義\n", 1951 | "\n", 1952 | "# torch.nn.Moduleのサブクラスとして、SimCSEモデルを定義\n", 1953 | "# https://pytorch.org/docs/stable/generated/torch.nn.Module.html\n", 1954 | "class SimCSEModel(torch.nn.Module):\n", 1955 | "\n", 1956 | " # SimCSEに使用する事前学習済みモデル名をリストで管理\n", 1957 | " # 論文と同じくBERTとRoBERTaモデルを想定\n", 1958 | " SUPPORTED_MODELS = ['bert-base-uncased', 'bert-large-uncased', 'roberta-base', 'roberta-large']\n", 1959 | "\n", 1960 | " # 内部で使用するTransformersのモデル名を受け取る\n", 1961 | " def __init__(self, model_name: str) -> None:\n", 1962 | " if not model_name in self.SUPPORTED_MODELS:\n", 1963 | " raise ValueError(f'{model_name} is not supported.')\n", 1964 | "\n", 1965 | " # 親クラスの__init__()を最初に呼び出す仕様\n", 1966 | " super().__init__()\n", 1967 | "\n", 1968 | " # SimCSEに使用する事前学習済みTransformersモデルをインスタンス化\n", 1969 | " #\n", 1970 | " # Automodelでモデル名からLookupしてモデルをダウンロードし読み込んでくれる\n", 1971 | " # https://huggingface.co/docs/transformers/model_doc/auto#transformers.AutoModel\n", 1972 | " self.backbone: transformers.modeling_utils.PreTrainedModel = transformers.AutoModel.from_pretrained(model_name)\n", 1973 | "\n", 1974 | " # 追加で多層パーセプトロン(MLP)層を定義\n", 1975 | " #\n", 1976 | " # backboneから得られる埋め込みを更に変換する\n", 1977 | " # 性能改善のためのオプショナルなコンポネントなので、無くても機能する\n", 1978 | " #\n", 1979 | " # 論文6.3節を参照\n", 1980 | " # オリジナルの実装は以下を参照\n", 1981 | " 
# https://github.com/princeton-nlp/SimCSE/blob/0.4/simcse/models.py#L19\n", 1982 | " self.hidden_size: int = self.backbone.config.hidden_size\n", 1983 | " self.dense = torch.nn.Linear(self.hidden_size, self.hidden_size)\n", 1984 | " self.activation = torch.nn.Tanh()\n", 1985 | "\n", 1986 | " # 入力文のトークン列を受け取り、その文埋め込みを返す\n", 1987 | " #\n", 1988 | " # 引数はそのままBERT/RoBERTaモデルに引き渡される\n", 1989 | " # 引数の意味はBertModelのforwardを参照\n", 1990 | " # https://huggingface.co/docs/transformers/model_doc/bert#transformers.BertModel\n", 1991 | " def forward(\n", 1992 | " self,\n", 1993 | " input_ids: torch.Tensor,\n", 1994 | " attention_mask: torch.Tensor = None,\n", 1995 | " # RoBERTa variants don't have token_type_ids, so this argument is optional\n", 1996 | " token_type_ids: torch.Tensor = None,\n", 1997 | " ) -> torch.Tensor:\n", 1998 | " # input_ids.size() = (batch_size, seq_len)\n", 1999 | " # attention_mask.size() = (batch_size, seq_len)\n", 2000 | " # token_type_ids.size() = (batch_size, seq_len)\n", 2001 | "\n", 2002 | " # BERT/RoBERTaモデルで推論\n", 2003 | " outputs = self.backbone.forward(\n", 2004 | " input_ids=input_ids,\n", 2005 | " attention_mask=attention_mask,\n", 2006 | " token_type_ids=token_type_ids,\n", 2007 | " )\n", 2008 | "\n", 2009 | " # 推論結果から文埋め込みを抽出\n", 2010 | " #\n", 2011 | " # BERTモデルをファインチューニングする場合は、Transformerの[CLS]トークンに対応した最終層の隠れ状態ベクトルを\n", 2012 | " # 文の埋め込み表現として用いることが多い (岡崎ら本7.3節参照)\n", 2013 | " # [CLS]トークン: 入力文の先頭に付随する特殊トークンで、文全体を表現する役割として分類問題などで使用される\n", 2014 | " #\n", 2015 | " # outputs.last_hidden_state.size() = (batch_size, seq_len, hidden_size)\n", 2016 | " # emb.size() = (batch_size, hidden_size)\n", 2017 | " emb = outputs.last_hidden_state[:, 0]\n", 2018 | "\n", 2019 | " # 上の代わりに、全サブワードに対応する最終層の隠れ状態ベクトルの平均プーリングや最大プーリングなどを用いても良い\n", 2020 | " # オリジナルの実装では4種類を試している\n", 2021 | " # https://github.com/princeton-nlp/SimCSE/blob/0.4/simcse/models.py#L63\n", 2022 | "\n", 2023 | " # unsup-SimCSEの場合、訓練時のみMLP層を使用するのが最も性能が良いという報告\n", 2024 | " # sup-SimCSEの場合は、推論時でもMLP層を使用するか、もしくはMLP層自体を使用しない方が良い性能\n", 2025 | " # 論文6.3節を参照\n", 2026 | " #\n", 2027 | " # (コメント) unsup-SimCSEでは学習データに適合し過ぎないで欲しいお気持ちがある?\n", 2028 | " #\n", 2029 | " # self.trainingのON/OFFはtorch.nn.Module.train()/.eval()で制御可能\n", 2030 | " if self.training:\n", 2031 | " emb = self.dense(emb)\n", 2032 | " emb = self.activation(emb)\n", 2033 | "\n", 2034 | " # emb.size() = (batch_size, hidden_size)\n", 2035 | " return emb" 2036 | ], 2037 | "metadata": { 2038 | "id": "sStcE1WGNThG" 2039 | }, 2040 | "execution_count": 4, 2041 | "outputs": [] 2042 | }, 2043 | { 2044 | "cell_type": "code", 2045 | "source": [ 2046 | "# 使用する計算デバイスの設定\n", 2047 | "# cpuでは低速すぎるので、基本的にはcudaを使用する\n", 2048 | "\n", 2049 | "# device = 'cpu'\n", 2050 | "device = 'cuda'" 2051 | ], 2052 | "metadata": { 2053 | "id": "Pvbpoxcb7Dz5" 2054 | }, 2055 | "execution_count": 5, 2056 | "outputs": [] 2057 | }, 2058 | { 2059 | "cell_type": "markdown", 2060 | "source": [ 2061 | "# 2. 
教師なし学習(unsup-SimCSE)" 2062 | ], 2063 | "metadata": { 2064 | "id": "xkqF-xcxPoFh" 2065 | } 2066 | }, 2067 | { 2068 | "cell_type": "markdown", 2069 | "source": [ 2070 | "## 2.1 モデルインスタンスの生成" 2071 | ], 2072 | "metadata": { 2073 | "id": "UGnQ82C3-x0n" 2074 | } 2075 | }, 2076 | { 2077 | "cell_type": "code", 2078 | "source": [ 2079 | "# Transformersの事前学習済みモデルからインスタンスを生成\n", 2080 | "\n", 2081 | "# この例ではベーシックなBERTモデルのbert-base-uncasedを使ってみる\n", 2082 | "model_name = 'bert-base-uncased'\n", 2083 | "\n", 2084 | "# SimCSEModelをインスタンス化し、指定したデバイスに載せる\n", 2085 | "model = SimCSEModel(model_name).to(device)\n", 2086 | "\n", 2087 | "# テキストをトークンに分割しモデルに入力できる形式に変換するためのトークナイザを生成する\n", 2088 | "tokenizer: transformers.tokenization_utils.PreTrainedTokenizer = transformers.AutoTokenizer.from_pretrained(model_name)" 2089 | ], 2090 | "metadata": { 2091 | "id": "Uq0SGUuvPjAB", 2092 | "colab": { 2093 | "base_uri": "https://localhost:8080/", 2094 | "height": 177, 2095 | "referenced_widgets": [ 2096 | "f56b8f398ee247dfa87ab14a4bf49cb0", 2097 | "0d89a0be7ccd4e688486fe195284851a", 2098 | "b27907ebe6f04850b34dd61707ffb01e", 2099 | "8298baf4dba046559b7b13f08b3dad90", 2100 | "ed86772d77ea4d4a97c57cd74ae6f258", 2101 | "e1ee5a33a31d4991a81cae0426dacf44", 2102 | "2621a2b022e24af0801b3e4b584f2c76", 2103 | "e0af523f1c304388a027b4fffd470010", 2104 | "bfe851cf720346d0b63347283acc99ac", 2105 | "e77f3ecf903848afae1db2dc3a21f565", 2106 | "178c637e246d4486bf78d6f76d4bcc67", 2107 | "d19466cf70da42e69ec97a0e078677f6", 2108 | "c1077acfc06740899abed6b993784313", 2109 | "d94598b0db264defbbf1ed575baf3130", 2110 | "c3c41882d2994fa8a4214afe2a95cf69", 2111 | "4de477ae5de048ccbfc048d95e03eb76", 2112 | "3020a5ea349e47998bf6a7749c1f124c", 2113 | "fbe178463818478aa40159e8285a529d", 2114 | "a6396e902ac248c7877494a9e248df22", 2115 | "edb212d4e7674c3797e5184d3a093b89", 2116 | "ad61d035232544aea3e1f3cc666b4ae7", 2117 | "c61ac74e2a2d44fd9505686b476b9f3a", 2118 | "a4891855e50a452f9671e49b1c1936bc", 2119 | "c11ebafc26c449178b5e61c437584009", 2120 | "dccbb5c014b5481a84b77ff57001e600", 2121 | "bcdf9ed42db94c4ea8532fe561c65cbc", 2122 | "54c0a73db1994c8e9fd2d605241c1253", 2123 | "9caba4686e8a4db4adc32f6e2cffbd4c", 2124 | "918a52cbc7f346da81ae77b76f33a713", 2125 | "3ceecade00b9474684e6e079b0bf507f", 2126 | "34e1e8af04b24a499141cc4fa5bd0693", 2127 | "663ac313a6e04906bcfdaf22477843a9", 2128 | "d846c8467dd24a7e9499834a13cd02f1", 2129 | "cefd61371ec247a2b7720894a719a42d", 2130 | "3e5d1f5547a14a7f92c4aaaa99fdc360", 2131 | "6f6145876e894419a6f2271361b58c05", 2132 | "7bb7e60954e744eea9ac066ef62ac979", 2133 | "af03e157d43d45e18d54a98a642baf3d", 2134 | "13b393738e404659925ca7d8e975ca54", 2135 | "2360e8a6d5cd44dabf3f6ddd76fb4c83", 2136 | "d56dd16aae604b85ad65aabfd7f16401", 2137 | "d3425e7e522b4f3cb654df99ce9fdde5", 2138 | "f606507c96514fed9f2758c7cb7afe67", 2139 | "da8eda9a9a844776823b206eeba8d9f8", 2140 | "c7177b2f3c0d48a4af0457aacd2c9a3c", 2141 | "4b58a7f58b6342298d584dc24f5b4d31", 2142 | "bb88d58e78d54355b26cf7573ac2c8dd", 2143 | "9ab47692e09145028bd678753632b4bb", 2144 | "678af9641c10452c95ca84e3c6d957a5", 2145 | "b2ba15544c4d4ef8962630ba613cba7f", 2146 | "fdde69cca8b2404cb9f2409f57c678d5", 2147 | "71500c0871ff497bbe6f1c18398dcf76", 2148 | "d61d353498054abbab18722bab879c7d", 2149 | "6f87eb2140e24227afeba107f6269b6f", 2150 | "e7363446238744aabb6ce02cfe202a9e" 2151 | ] 2152 | }, 2153 | "outputId": "db6c2d0c-67d2-4872-b587-05ca072d77e4" 2154 | }, 2155 | "execution_count": 6, 2156 | "outputs": [ 2157 | { 2158 | "output_type": "display_data", 2159 | "data": { 2160 | 
"text/plain": [ 2161 | "Downloading (…)lve/main/config.json: 0%| | 0.00/570 [00:00] 114.48M 273MB/s in 0.4s \n", 2275 | "\n", 2276 | "2023-11-01 09:02:23 (273 MB/s) - ‘wiki1m_for_simcse.txt’ saved [120038621/120038621]\n", 2277 | "\n" 2278 | ] 2279 | } 2280 | ] 2281 | }, 2282 | { 2283 | "cell_type": "code", 2284 | "source": [ 2285 | "# データセットは単純な行区切りのテキストデータ\n", 2286 | "with open('./datasets/unsup-simcse/train.txt') as f:\n", 2287 | " sentences = [line.rstrip('\\n') for line in f]\n", 2288 | "\n", 2289 | "# 表示のためにDataFrameに変換\n", 2290 | "train_examples = pd.DataFrame(sentences, columns=['sentences'])\n", 2291 | "train_examples" 2292 | ], 2293 | "metadata": { 2294 | "colab": { 2295 | "base_uri": "https://localhost:8080/", 2296 | "height": 424 2297 | }, 2298 | "id": "bCvzpdQlnvS0", 2299 | "outputId": "eeb0225c-6404-4036-9d83-65a343f1cb94" 2300 | }, 2301 | "execution_count": 8, 2302 | "outputs": [ 2303 | { 2304 | "output_type": "execute_result", 2305 | "data": { 2306 | "text/plain": [ 2307 | " sentences\n", 2308 | "0 YMCA in South Australia\n", 2309 | "1 South Australia (SA)  has a unique position in...\n", 2310 | "2 The compound of philosophical radicalism, evan...\n", 2311 | "3 It was into this social setting that in Februa...\n", 2312 | "4 for apprentices and others, after their day's ...\n", 2313 | "... ...\n", 2314 | "999995 Rubaschow: Roman.\n", 2315 | "999996 Typoskript, März 1940, 326 pages.\"\n", 2316 | "999997 He deemed the discovery important because \"\"Da...\n", 2317 | "999998 In 2018, he reported that Elsinor Verlag (publ...\n", 2318 | "999999 He also reported a new English translation to ...\n", 2319 | "\n", 2320 | "[1000000 rows x 1 columns]" 2321 | ], 2322 | "text/html": [ 2323 | "\n", 2324 | "
\n", 2325 | "
\n", 2326 | "\n", 2339 | "\n", 2340 | " \n", 2341 | " \n", 2342 | " \n", 2343 | " \n", 2344 | " \n", 2345 | " \n", 2346 | " \n", 2347 | " \n", 2348 | " \n", 2349 | " \n", 2350 | " \n", 2351 | " \n", 2352 | " \n", 2353 | " \n", 2354 | " \n", 2355 | " \n", 2356 | " \n", 2357 | " \n", 2358 | " \n", 2359 | " \n", 2360 | " \n", 2361 | " \n", 2362 | " \n", 2363 | " \n", 2364 | " \n", 2365 | " \n", 2366 | " \n", 2367 | " \n", 2368 | " \n", 2369 | " \n", 2370 | " \n", 2371 | " \n", 2372 | " \n", 2373 | " \n", 2374 | " \n", 2375 | " \n", 2376 | " \n", 2377 | " \n", 2378 | " \n", 2379 | " \n", 2380 | " \n", 2381 | " \n", 2382 | " \n", 2383 | " \n", 2384 | " \n", 2385 | " \n", 2386 | " \n", 2387 | " \n", 2388 | " \n", 2389 | " \n", 2390 | " \n", 2391 | " \n", 2392 | "
sentences
0YMCA in South Australia
1South Australia (SA)  has a unique position in...
2The compound of philosophical radicalism, evan...
3It was into this social setting that in Februa...
4for apprentices and others, after their day's ...
......
999995Rubaschow: Roman.
999996Typoskript, März 1940, 326 pages.\"
999997He deemed the discovery important because \"\"Da...
999998In 2018, he reported that Elsinor Verlag (publ...
999999He also reported a new English translation to ...
\n", 2393 | "

1000000 rows × 1 columns

\n", 2394 | "
\n", 2395 | "
\n", 2396 | "\n", 2397 | "
\n", 2398 | " \n", 2406 | "\n", 2407 | " \n", 2447 | "\n", 2448 | " \n", 2472 | "
\n", 2473 | "\n", 2474 | "\n", 2475 | "
\n", 2476 | " \n", 2487 | "\n", 2488 | "\n", 2577 | "\n", 2578 | " \n", 2600 | "
\n", 2601 | "
\n", 2602 | "
\n" 2603 | ] 2604 | }, 2605 | "metadata": {}, 2606 | "execution_count": 8 2607 | } 2608 | ] 2609 | }, 2610 | { 2611 | "cell_type": "code", 2612 | "source": [ 2613 | "# PyTorchのDatasetとDataLoaderを使ってデータセットを処理する\n", 2614 | "#\n", 2615 | "# - torch.utils.data.Dataset: 訓練データを格納しアクセスするためのクラス\n", 2616 | "# - torch.utils.data.DataLoader: 訓練データをイテレートするためのクラス(ミニバッチ化や再シャッフルなどを提供)\n", 2617 | "#\n", 2618 | "# https://pytorch.org/tutorials/beginner/basics/data_tutorial.html\n", 2619 | "\n", 2620 | "# Datasetクラスを定義する\n", 2621 | "#\n", 2622 | "# __init__, __len__, __getitem__関数を実装すれば良い\n", 2623 | "# https://pytorch.org/tutorials/beginner/basics/data_tutorial.html#creating-a-custom-dataset-for-your-files\n", 2624 | "class UnsupSimCSEDataset(torch.utils.data.Dataset):\n", 2625 | " def __init__(self, sentences: list[str]) -> None:\n", 2626 | " self.sentences = sentences\n", 2627 | "\n", 2628 | " # 事例の数を返す\n", 2629 | " def __len__(self) -> int:\n", 2630 | " return len(self.sentences)\n", 2631 | "\n", 2632 | " # idx番目の事例を返す\n", 2633 | " def __getitem__(self, idx: int) -> str:\n", 2634 | " return self.sentences[idx]\n", 2635 | "\n", 2636 | "\n", 2637 | "# 時間が掛かりすぎるので訓練事例を10万件に抑える\n", 2638 | "train_examples = train_examples[:100000]\n", 2639 | "\n", 2640 | "# Datasetインスタンスを生成\n", 2641 | "train_dataset = UnsupSimCSEDataset(train_examples['sentences'].tolist())" 2642 | ], 2643 | "metadata": { 2644 | "id": "GnRjbw_VSmtc" 2645 | }, 2646 | "execution_count": 9, 2647 | "outputs": [] 2648 | }, 2649 | { 2650 | "cell_type": "code", 2651 | "source": [ 2652 | "# 上で作ったDatasetについてDataLoaderを作成\n", 2653 | "\n", 2654 | "# ミニバッチのサイズ\n", 2655 | "#\n", 2656 | "# 論文で実際に使用されたモデルごとのバッチサイズの一覧は以下で提供されている\n", 2657 | "# https://github.com/princeton-nlp/SimCSE/tree/0.4#training\n", 2658 | "batch_size = 64\n", 2659 | "\n", 2660 | "# DataLoaderでDatasetからフェッチされた部分データについて、ミニバッチを形成するための前処理を記述できる\n", 2661 | "# トークナイザを用いて、文をTransformersモデルに入力できる形式に変換する\n", 2662 | "#\n", 2663 | "# パラメータはオリジナルの実装に由来\n", 2664 | "# https://github.com/princeton-nlp/SimCSE/blob/0.4/run_unsup_example.sh\n", 2665 | "def collate_fn(batch: list[str]) -> transformers.tokenization_utils.BatchEncoding:\n", 2666 | " return tokenizer(\n", 2667 | " batch,\n", 2668 | " # トークン列の長さをミニバッチ内の最大長に揃える\n", 2669 | " padding='longest',\n", 2670 | " # トークン列長がmax_lengthを超える場合は、末尾トークンを取り除きmax_lengthに揃える\n", 2671 | " truncation='longest_first',\n", 2672 | " # トークン列の最大長を指定\n", 2673 | " max_length=32,\n", 2674 | " # 結果をtorch.Tensor型で受け取る\n", 2675 | " return_tensors='pt',\n", 2676 | " )\n", 2677 | "\n", 2678 | "# DataLoaderインスタンスを生成\n", 2679 | "# https://pytorch.org/docs/stable/data.html\n", 2680 | "train_dataloader = torch.utils.data.DataLoader(\n", 2681 | " train_dataset,\n", 2682 | " collate_fn=collate_fn,\n", 2683 | " batch_size=batch_size,\n", 2684 | " shuffle=True,\n", 2685 | " num_workers=2,\n", 2686 | " # GPUへのメモリコピーを高速化する設定\n", 2687 | " # https://stackoverflow.com/questions/55563376/pytorch-how-does-pin-memory-work-in-dataloader\n", 2688 | " pin_memory=True,\n", 2689 | " # 最後のバッチのサイズがbatch_sizeで割り切れない場合は、異なるサイズのバッチが生成されないように切り捨てる\n", 2690 | " drop_last=True,\n", 2691 | ")" 2692 | ], 2693 | "metadata": { 2694 | "id": "WcUnVnhesR4T" 2695 | }, 2696 | "execution_count": 10, 2697 | "outputs": [] 2698 | }, 2699 | { 2700 | "cell_type": "markdown", 2701 | "source": [ 2702 | "## 2.3 ファインチューニング" 2703 | ], 2704 | "metadata": { 2705 | "id": "UrtOGul_9P-O" 2706 | } 2707 | }, 2708 | { 2709 | "cell_type": "code", 2710 | "source": [ 2711 | "# 主な学習パラメータ\n", 2712 | "\n", 2713 | "# エポック数 i.e., 
訓練データを何回繰り返して学習するか\n", 2714 | "# 論文でのunsup-SimCSEのエポック数は1 (付録A参照)\n", 2715 | "epochs = 1\n", 2716 | "\n", 2717 | "# 学習率: 各学習ステップにおけるパラメータ更新の幅で、小さいほど細かい調整となる\n", 2718 | "# 論文で実際に使用されたモデルごとの学習率の一覧は以下で提供されている\n", 2719 | "# https://github.com/princeton-nlp/SimCSE/tree/0.4#training\n", 2720 | "learning_rate = 3e-5\n", 2721 | "\n", 2722 | "# 出力の確率分布の形状を制御するための温度パラメータ\n", 2723 | "# 論文の式(1)のτに相当する (解説は実際に使用されている箇所で後ほど)\n", 2724 | "temperature = 0.05" 2725 | ], 2726 | "metadata": { 2727 | "id": "8oG8h7fn-_JD" 2728 | }, 2729 | "execution_count": 11, 2730 | "outputs": [] 2731 | }, 2732 | { 2733 | "cell_type": "code", 2734 | "source": [ 2735 | "# 学習プロセス\n", 2736 | "\n", 2737 | "# torch.optimを通してモデルのパラメータを更新する\n", 2738 | "#\n", 2739 | "# パラメータはtorch.nn.Module.parameters()で受け渡せば、後はoptimizerが管理を請け負ってくれる\n", 2740 | "# https://pytorch.org/docs/stable/optim.html\n", 2741 | "#\n", 2742 | "# TransformersのBERT/RoBERTaモデルはtorch.nn.Moduleのサブクラスなので、\n", 2743 | "# SimCSEModel.backboneのパラメータもparameters()で再帰的にイテレートされる\n", 2744 | "# https://huggingface.co/docs/transformers/model_doc/bert#transformers.BertModel\n", 2745 | "# https://huggingface.co/docs/transformers/model_doc/roberta#transformers.RobertaModel\n", 2746 | "#\n", 2747 | "# オリジナルの実装では、transformers.Trainerのデフォルト値であるAdamWを使用している\n", 2748 | "# https://huggingface.co/docs/transformers/main_classes/trainer#transformers.Trainer.optimizers\n", 2749 | "#\n", 2750 | "# TransformersにもAdamWの実装があるが、現在は非推奨\n", 2751 | "# https://github.com/huggingface/transformers/issues/3407\n", 2752 | "# https://github.com/huggingface/transformers/issues/18757\n", 2753 | "optimizer = torch.optim.AdamW(\n", 2754 | " params=model.parameters(),\n", 2755 | " lr=learning_rate\n", 2756 | ")\n", 2757 | "\n", 2758 | "# 学習ステップに応じて学習率を変動させるためのスケジューラを設定する\n", 2759 | "#\n", 2760 | "# オリジナルの実装では、transformers.Trainerのデフォルト値であるLinearSchedulerを使用している\n", 2761 | "# https://huggingface.co/docs/transformers/main_classes/trainer#transformers.Trainer.optimizers\n", 2762 | "#\n", 2763 | "# これはステップ数に応じて線形に学習率がゼロに近づくスケジューラで、序盤は大胆に、終盤はきめ細かくパラメータを更新する\n", 2764 | "# https://huggingface.co/docs/transformers/main_classes/optimizer_schedules#transformers.get_linear_schedule_with_warmup\n", 2765 | "#\n", 2766 | "# ただし、学習初期は予測がランダムで勾配も大きくなりやすいから、学習率は小さく抑えた方が良いというヒューリスティックも存在する\n", 2767 | "# そのためのアイデアが学習率のウォームアップで、序盤から中盤にかけて学習率を徐々に大きくしていく\n", 2768 | "# 参考:岡崎ら本6.4.2「学習率のウォームアップ」\n", 2769 | "#\n", 2770 | "# num_warmup_stepsでウォームアップのためのステップ数を指定できる\n", 2771 | "# ただし、オリジナルの実装ではtransformers.Trainerのデフォルト値を使用しているのでnum_warmup_steps=0\n", 2772 | "# https://huggingface.co/docs/transformers/v4.34.1/en/main_classes/trainer#transformers.TrainingArguments\n", 2773 | "lr_scheduler = transformers.optimization.get_linear_schedule_with_warmup(\n", 2774 | " optimizer=optimizer,\n", 2775 | " # とりあえずここではオリジナルの実装に合わせてウォームアップ無し\n", 2776 | " num_warmup_steps=0,\n", 2777 | " # len(train_dataloader) is the number of steps in one epoch\n", 2778 | " num_training_steps=len(train_dataloader) * epochs,\n", 2779 | ")\n", 2780 | "\n", 2781 | "# 訓練モードに設定\n", 2782 | "# https://pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module.train\n", 2783 | "model.train()\n", 2784 | "\n", 2785 | "# ファインチューニング\n", 2786 | "for epoch in range(epochs):\n", 2787 | " for batch in tqdm.tqdm(train_dataloader):\n", 2788 | " # ミニバッチをデバイスに載せる\n", 2789 | " batch = batch.to(device)\n", 2790 | "\n", 2791 | " # unsup-SimCSEのメインの学習処理\n", 2792 | " #\n", 2793 | " # 同じバッチの埋め込みを2回計算しているだけ\n", 2794 | " # ただし内部では異なるドロップアウトが適用されているため、異なるデータ拡張による正例ペアが得られている\n", 2795 | 
" # https://pytorch.org/docs/stable/generated/torch.nn.Dropout.html\n", 2796 | " emb1 = model.forward(**batch)\n", 2797 | " emb2 = model.forward(**batch)\n", 2798 | "\n", 2799 | " # (余談) 例えばtransformersのBERTモデルでは、実際にドロップアウト層が組み込まれていることが以下から確認できる。\n", 2800 | " #\n", 2801 | " # https://github.com/huggingface/transformers/blob/v4.34.0/src/transformers/models/bert/modeling_bert.py#L192\n", 2802 | " # https://github.com/huggingface/transformers/blob/v4.34.0/src/transformers/models/bert/modeling_bert.py#L261\n", 2803 | " #\n", 2804 | " # ドロップアウトする確率はconfig.{hidden_dropout_prob, attention_probs_dropout_prob}で指定されており、以下の手順で変更できる\n", 2805 | " # https://stackoverflow.com/questions/64947064/transformers-pretrained-model-with-dropout-setting\n", 2806 | "\n", 2807 | " # 全対コサイン類似度計算 such that\n", 2808 | " #\n", 2809 | " # - sim_matrix.size() = (batch_size, batch_size)\n", 2810 | " # - sim_matrix[i][j] = cosine_sim(emb1[i], emb2[j])\n", 2811 | " #\n", 2812 | " # つまり、スライドP25の左の行列を作っている\n", 2813 | " #\n", 2814 | " # なぜこのコードで全対が計算できているかは、以下の解説が参考になる\n", 2815 | " # https://medium.com/@dhruvbird/all-pairs-cosine-similarity-in-pytorch-867e722c8572\n", 2816 | " sim_matrix = torch.nn.functional.cosine_similarity(emb1.unsqueeze(1), emb2.unsqueeze(0), dim=-1)\n", 2817 | "\n", 2818 | " # 温度パラメータによる確率分布の形状の調整\n", 2819 | " #\n", 2820 | " # sim_matrixは、この後のcross_entropyにてsoftmaxにより確率分布に変換される\n", 2821 | " # その際に、温度パラメータが1より小さいと高い類似度を強調するように調整できる\n", 2822 | " # https://qiita.com/nkriskeeic/items/db3b4b5e835e63a7f243\n", 2823 | " #\n", 2824 | " # SimCSEの性能は温度パラメータに敏感なので、適切な値を設定する必要がある\n", 2825 | " # 論文では0.05が良かったと報告している (付録D参照)\n", 2826 | " sim_matrix = sim_matrix / temperature\n", 2827 | "\n", 2828 | " # sim_matrixについて交差エントロピー損失を計算\n", 2829 | " # https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html#torch.nn.CrossEntropyLoss\n", 2830 | " #\n", 2831 | " # SimCSEの目的関数は正例間のコサイン類似度を最大化すること\n", 2832 | " #\n", 2833 | " # ここでは、sim_matrixをクラス数batch_sizeな分類問題の推論結果と見なして、\n", 2834 | " # 各行sim_matrix[i,:]が正解クラスに高い類似度、不正解クラスに低い類似度を予測できているかを評価している\n", 2835 | " #\n", 2836 | " # sim_matrixは対角成分に正例同士の類似度を格納しているので、行sim_matrix[i,:]の正解クラスとはsim_matrix[i,i]\n", 2837 | " # labels=[0,1,2,...,batch_size-1]で各行の正解クラスを指定している\n", 2838 | " labels = torch.arange(batch_size).long().to(device)\n", 2839 | " loss = torch.nn.functional.cross_entropy(sim_matrix, labels)\n", 2840 | "\n", 2841 | " # 全ての勾配をゼロに初期化\n", 2842 | " # https://stackoverflow.com/questions/48001598/why-do-we-need-to-call-zero-grad-in-pytorch\n", 2843 | " optimizer.zero_grad()\n", 2844 | "\n", 2845 | " # 誤差逆伝播により勾配を計算\n", 2846 | " loss.backward()\n", 2847 | "\n", 2848 | " # パラメータを更新\n", 2849 | " optimizer.step()\n", 2850 | "\n", 2851 | " # 学習率の更新\n", 2852 | " lr_scheduler.step()\n", 2853 | "\n", 2854 | " # もし開発用の評価セットが手元にある場合は、定期的に評価を実行しベストパフォーマンス時点のモデルを保持しておくと良い\n", 2855 | " # https://github.com/hppRC/simple-simcse/blob/main/train.py#L301-L331" 2856 | ], 2857 | "metadata": { 2858 | "colab": { 2859 | "base_uri": "https://localhost:8080/" 2860 | }, 2861 | "id": "8BqfcYx9X0Mx", 2862 | "outputId": "8f8d6feb-cd32-4d99-83fd-23aabe66c2cc" 2863 | }, 2864 | "execution_count": 12, 2865 | "outputs": [ 2866 | { 2867 | "output_type": "stream", 2868 | "name": "stderr", 2869 | "text": [ 2870 | "100%|██████████| 1562/1562 [16:35<00:00, 1.57it/s]\n" 2871 | ] 2872 | } 2873 | ] 2874 | }, 2875 | { 2876 | "cell_type": "code", 2877 | "source": [ 2878 | "# 学習結果をDriveに保存しておく\n", 2879 | "torch.save(model, unsup_model_path)" 2880 | ], 2881 | "metadata": { 2882 | "id": "zHrERtkuJKGf" 2883 | }, 
2884 | "execution_count": 13, 2885 | "outputs": [] 2886 | }, 2887 | { 2888 | "cell_type": "markdown", 2889 | "source": [ 2890 | "# 3. 教師あり学習(sup-SimCSE)\n", 2891 | "\n", 2892 | "NLI(自然言語推論)データセットの含意ペアを正例、矛盾ペアを負例として対照学習する。\n", 2893 | "\n", 2894 | "注記: unsup-SimCSEと実装は大体一緒なので、異なる点のみコメントしてます。" 2895 | ], 2896 | "metadata": { 2897 | "id": "0q0Us9PCcrX3" 2898 | } 2899 | }, 2900 | { 2901 | "cell_type": "markdown", 2902 | "source": [ 2903 | "## 3.1 モデルインスタンスの生成" 2904 | ], 2905 | "metadata": { 2906 | "id": "cLmzRLo5uUTR" 2907 | } 2908 | }, 2909 | { 2910 | "cell_type": "code", 2911 | "source": [ 2912 | "model_name = 'bert-base-uncased'\n", 2913 | "model = SimCSEModel(model_name).to(device)\n", 2914 | "tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)" 2915 | ], 2916 | "metadata": { 2917 | "id": "ZMKHHb3HODOz" 2918 | }, 2919 | "execution_count": 14, 2920 | "outputs": [] 2921 | }, 2922 | { 2923 | "cell_type": "markdown", 2924 | "source": [ 2925 | "## 3.2 データセットの準備" 2926 | ], 2927 | "metadata": { 2928 | "id": "GCtW-jWHuO6o" 2929 | } 2930 | }, 2931 | { 2932 | "cell_type": "code", 2933 | "source": [ 2934 | "# sup-SimCSE訓練用のデータセットをダウンロード\n", 2935 | "\n", 2936 | "# 論文で実際に使用されたNLI(自然言語推論)データセットが使用できる\n", 2937 | "!mkdir -p ./datasets/sup-simcse\n", 2938 | "!wget https://huggingface.co/datasets/princeton-nlp/datasets-for-simcse/resolve/main/nli_for_simcse.csv\n", 2939 | "!mv ./nli_for_simcse.csv ./datasets/sup-simcse/train.csv" 2940 | ], 2941 | "metadata": { 2942 | "colab": { 2943 | "base_uri": "https://localhost:8080/" 2944 | }, 2945 | "id": "Z3dyCVYNctzy", 2946 | "outputId": "28ed2cce-6849-48b9-f281-93a051289d4c" 2947 | }, 2948 | "execution_count": 15, 2949 | "outputs": [ 2950 | { 2951 | "output_type": "stream", 2952 | "name": "stdout", 2953 | "text": [ 2954 | "--2023-11-01 09:19:09-- https://huggingface.co/datasets/princeton-nlp/datasets-for-simcse/resolve/main/nli_for_simcse.csv\n", 2955 | "Resolving huggingface.co (huggingface.co)... 3.163.189.114, 3.163.189.74, 3.163.189.90, ...\n", 2956 | "Connecting to huggingface.co (huggingface.co)|3.163.189.114|:443... connected.\n", 2957 | "HTTP request sent, awaiting response... 
302 Found\n", 2958 | "Location: https://cdn-lfs.huggingface.co/datasets/princeton-nlp/datasets-for-simcse/0747687ec3594fa449d2004fd3757a56c24bf5f7428976fb5b67176775a68d48?response-content-disposition=attachment%3B+filename*%3DUTF-8%27%27nli_for_simcse.csv%3B+filename%3D%22nli_for_simcse.csv%22%3B&response-content-type=text%2Fcsv&Expires=1699089549&Policy=eyJTdGF0ZW1lbnQiOlt7IkNvbmRpdGlvbiI6eyJEYXRlTGVzc1RoYW4iOnsiQVdTOkVwb2NoVGltZSI6MTY5OTA4OTU0OX19LCJSZXNvdXJjZSI6Imh0dHBzOi8vY2RuLWxmcy5odWdnaW5nZmFjZS5jby9kYXRhc2V0cy9wcmluY2V0b24tbmxwL2RhdGFzZXRzLWZvci1zaW1jc2UvMDc0NzY4N2VjMzU5NGZhNDQ5ZDIwMDRmZDM3NTdhNTZjMjRiZjVmNzQyODk3NmZiNWI2NzE3Njc3NWE2OGQ0OD9yZXNwb25zZS1jb250ZW50LWRpc3Bvc2l0aW9uPSomcmVzcG9uc2UtY29udGVudC10eXBlPSoifV19&Signature=DeNs78CHc7funEJrvd9p701jcLZTgPZok38qR2g4UOtc-nudMhlXcPwhmaP6RzzsQAXNdaMe%7E7-CdBCxsYTbRMUTkA7fvJ%7E70KNwjrrm-bmTql4OdkD0anGNZOUAloMQ26FyuwFspt0nu8GFt6V1KN9-eEuiSLp8tWfSkGLpYoXzEd6JXvvltnt75HGYlfSviMJvsuWbRBuSh60XUvW-YP-rMUAPbQwmbEfF%7ESmBESy1CszRMKU1lopVWaQrVcWih-pNmDaRVBE2nrw7BmiKtnOEchbtZQulMPMXrGDC2YVHqI7nkMnWLnh-mGFssBR36xim3RDBNQmfhEi1rhCS1A__&Key-Pair-Id=KVTP0A1DKRTAX [following]\n", 2959 | "--2023-11-01 09:19:09-- https://cdn-lfs.huggingface.co/datasets/princeton-nlp/datasets-for-simcse/0747687ec3594fa449d2004fd3757a56c24bf5f7428976fb5b67176775a68d48?response-content-disposition=attachment%3B+filename*%3DUTF-8%27%27nli_for_simcse.csv%3B+filename%3D%22nli_for_simcse.csv%22%3B&response-content-type=text%2Fcsv&Expires=1699089549&Policy=eyJTdGF0ZW1lbnQiOlt7IkNvbmRpdGlvbiI6eyJEYXRlTGVzc1RoYW4iOnsiQVdTOkVwb2NoVGltZSI6MTY5OTA4OTU0OX19LCJSZXNvdXJjZSI6Imh0dHBzOi8vY2RuLWxmcy5odWdnaW5nZmFjZS5jby9kYXRhc2V0cy9wcmluY2V0b24tbmxwL2RhdGFzZXRzLWZvci1zaW1jc2UvMDc0NzY4N2VjMzU5NGZhNDQ5ZDIwMDRmZDM3NTdhNTZjMjRiZjVmNzQyODk3NmZiNWI2NzE3Njc3NWE2OGQ0OD9yZXNwb25zZS1jb250ZW50LWRpc3Bvc2l0aW9uPSomcmVzcG9uc2UtY29udGVudC10eXBlPSoifV19&Signature=DeNs78CHc7funEJrvd9p701jcLZTgPZok38qR2g4UOtc-nudMhlXcPwhmaP6RzzsQAXNdaMe%7E7-CdBCxsYTbRMUTkA7fvJ%7E70KNwjrrm-bmTql4OdkD0anGNZOUAloMQ26FyuwFspt0nu8GFt6V1KN9-eEuiSLp8tWfSkGLpYoXzEd6JXvvltnt75HGYlfSviMJvsuWbRBuSh60XUvW-YP-rMUAPbQwmbEfF%7ESmBESy1CszRMKU1lopVWaQrVcWih-pNmDaRVBE2nrw7BmiKtnOEchbtZQulMPMXrGDC2YVHqI7nkMnWLnh-mGFssBR36xim3RDBNQmfhEi1rhCS1A__&Key-Pair-Id=KVTP0A1DKRTAX\n", 2960 | "Resolving cdn-lfs.huggingface.co (cdn-lfs.huggingface.co)... 108.138.94.25, 108.138.94.23, 108.138.94.14, ...\n", 2961 | "Connecting to cdn-lfs.huggingface.co (cdn-lfs.huggingface.co)|108.138.94.25|:443... connected.\n", 2962 | "HTTP request sent, awaiting response... 
200 OK\n", 2963 | "Length: 48978197 (47M) [text/csv]\n", 2964 | "Saving to: ‘nli_for_simcse.csv’\n", 2965 | "\n", 2966 | "nli_for_simcse.csv 100%[===================>] 46.71M 193MB/s in 0.2s \n", 2967 | "\n", 2968 | "2023-11-01 09:19:09 (193 MB/s) - ‘nli_for_simcse.csv’ saved [48978197/48978197]\n", 2969 | "\n" 2970 | ] 2971 | } 2972 | ] 2973 | }, 2974 | { 2975 | "cell_type": "code", 2976 | "source": [ 2977 | "# NLIデータセットは各エントリが以下の3つ組\n", 2978 | "#\n", 2979 | "# - sent0: premise (前提文)\n", 2980 | "# - sent1: entailment (含意)\n", 2981 | "# - hard_neg: contradiction (矛盾)\n", 2982 | "#\n", 2983 | "# (premise, entailment)を正例、(premise, contradiction)を負例として学習する\n", 2984 | "#\n", 2985 | "# unsup-SimCSEと同様、ミニバッチ内の別の事例同士も負例として学習する\n", 2986 | "# (スライド25Pの右図が分かりやすい)\n", 2987 | "\n", 2988 | "train_examples = pd.read_csv('./datasets/sup-simcse/train.csv')\n", 2989 | "train_examples" 2990 | ], 2991 | "metadata": { 2992 | "colab": { 2993 | "base_uri": "https://localhost:8080/", 2994 | "height": 424 2995 | }, 2996 | "id": "Q6t_QlYGfeez", 2997 | "outputId": "93b8fd47-eba3-4975-bfbd-724cd8710be8" 2998 | }, 2999 | "execution_count": 16, 3000 | "outputs": [ 3001 | { 3002 | "output_type": "execute_result", 3003 | "data": { 3004 | "text/plain": [ 3005 | " sent0 \\\n", 3006 | "0 you know during the season and i guess at at y... \n", 3007 | "1 One of our number will carry out your instruct... \n", 3008 | "2 How do you know? All this is their information... \n", 3009 | "3 yeah i tell you what though if you go price so... \n", 3010 | "4 my walkman broke so i'm upset now i just have ... \n", 3011 | "... ... \n", 3012 | "275596 A group of four kids stand in front of a statu... \n", 3013 | "275597 a kid doing tricks on a skateboard on a bridge \n", 3014 | "275598 A dog with a blue collar plays ball outside. \n", 3015 | "275599 Four dirty and barefooted children. \n", 3016 | "275600 A man is surfing in a bodysuit in beautiful bl... \n", 3017 | "\n", 3018 | " sent1 \\\n", 3019 | "0 You lose the things to the following level if ... \n", 3020 | "1 A member of my team will execute your orders w... \n", 3021 | "2 This information belongs to them. \n", 3022 | "3 The tennis shoes can be in the hundred dollar ... \n", 3023 | "4 I'm upset that my walkman broke and now I have... \n", 3024 | "... ... \n", 3025 | "275596 four kids standing \n", 3026 | "275597 a kid is skateboarding \n", 3027 | "275598 a dog is outside \n", 3028 | "275599 four children have dirty feet. \n", 3029 | "275600 On the beautiful blue water there is a man in ... \n", 3030 | "\n", 3031 | " hard_neg \n", 3032 | "0 They never perform recalls on anything. \n", 3033 | "1 We have no one free at the moment so you have ... \n", 3034 | "2 They have no information at all. \n", 3035 | "3 The tennis shoes are not over hundred dollars. \n", 3036 | "4 My walkman still works as well as it always did. \n", 3037 | "... ... \n", 3038 | "275596 the kids are seated \n", 3039 | "275597 a kid is inside \n", 3040 | "275598 a dog is on the couch \n", 3041 | "275599 four kids won awards for 'cleanest feet' \n", 3042 | "275600 A man in a business suit is heading to a board... \n", 3043 | "\n", 3044 | "[275601 rows x 3 columns]" 3045 | ], 3046 | "text/html": [ 3047 | "\n", 3048 | "
\n", 3049 | "
\n", 3050 | "\n", 3063 | "\n", 3064 | " \n", 3065 | " \n", 3066 | " \n", 3067 | " \n", 3068 | " \n", 3069 | " \n", 3070 | " \n", 3071 | " \n", 3072 | " \n", 3073 | " \n", 3074 | " \n", 3075 | " \n", 3076 | " \n", 3077 | " \n", 3078 | " \n", 3079 | " \n", 3080 | " \n", 3081 | " \n", 3082 | " \n", 3083 | " \n", 3084 | " \n", 3085 | " \n", 3086 | " \n", 3087 | " \n", 3088 | " \n", 3089 | " \n", 3090 | " \n", 3091 | " \n", 3092 | " \n", 3093 | " \n", 3094 | " \n", 3095 | " \n", 3096 | " \n", 3097 | " \n", 3098 | " \n", 3099 | " \n", 3100 | " \n", 3101 | " \n", 3102 | " \n", 3103 | " \n", 3104 | " \n", 3105 | " \n", 3106 | " \n", 3107 | " \n", 3108 | " \n", 3109 | " \n", 3110 | " \n", 3111 | " \n", 3112 | " \n", 3113 | " \n", 3114 | " \n", 3115 | " \n", 3116 | " \n", 3117 | " \n", 3118 | " \n", 3119 | " \n", 3120 | " \n", 3121 | " \n", 3122 | " \n", 3123 | " \n", 3124 | " \n", 3125 | " \n", 3126 | " \n", 3127 | " \n", 3128 | " \n", 3129 | " \n", 3130 | " \n", 3131 | " \n", 3132 | " \n", 3133 | " \n", 3134 | " \n", 3135 | " \n", 3136 | " \n", 3137 | " \n", 3138 | " \n", 3139 | " \n", 3140 | "
sent0sent1hard_neg
0you know during the season and i guess at at y...You lose the things to the following level if ...They never perform recalls on anything.
1One of our number will carry out your instruct...A member of my team will execute your orders w...We have no one free at the moment so you have ...
2How do you know? All this is their information...This information belongs to them.They have no information at all.
3yeah i tell you what though if you go price so...The tennis shoes can be in the hundred dollar ...The tennis shoes are not over hundred dollars.
4my walkman broke so i'm upset now i just have ...I'm upset that my walkman broke and now I have...My walkman still works as well as it always did.
............
275596A group of four kids stand in front of a statu...four kids standingthe kids are seated
275597a kid doing tricks on a skateboard on a bridgea kid is skateboardinga kid is inside
275598A dog with a blue collar plays ball outside.a dog is outsidea dog is on the couch
275599Four dirty and barefooted children.four children have dirty feet.four kids won awards for 'cleanest feet'
275600A man is surfing in a bodysuit in beautiful bl...On the beautiful blue water there is a man in ...A man in a business suit is heading to a board...
\n", 3141 | "

275601 rows × 3 columns

\n", 3142 | "
\n", 3143 | "
\n", 3144 | "\n", 3145 | "
\n", 3146 | " \n", 3154 | "\n", 3155 | " \n", 3195 | "\n", 3196 | " \n", 3220 | "
\n", 3221 | "\n", 3222 | "\n", 3223 | "
\n", 3224 | " \n", 3235 | "\n", 3236 | "\n", 3325 | "\n", 3326 | " \n", 3348 | "
\n", 3349 | "
\n", 3350 | "
\n" 3351 | ] 3352 | }, 3353 | "metadata": {}, 3354 | "execution_count": 16 3355 | } 3356 | ] 3357 | }, 3358 | { 3359 | "cell_type": "code", 3360 | "source": [ 3361 | "# sup-SimCSEでは3つ組を事例として返す\n", 3362 | "class SupSimCSEDataset(torch.utils.data.Dataset):\n", 3363 | " def __init__(\n", 3364 | " self,\n", 3365 | " premise: list[str],\n", 3366 | " entailment: list[str],\n", 3367 | " contradiction: list[str],\n", 3368 | " ):\n", 3369 | " assert len(premise) == len(entailment) == len(contradiction)\n", 3370 | " self.premise = premise\n", 3371 | " self.entailment = entailment\n", 3372 | " self.contradiction = contradiction\n", 3373 | "\n", 3374 | " def __getitem__(self, index: int) -> tuple[str, str, str]:\n", 3375 | " return self.premise[index], self.entailment[index], self.contradiction[index]\n", 3376 | "\n", 3377 | " def __len__(self) -> int:\n", 3378 | " return len(self.premise)\n", 3379 | "\n", 3380 | "\n", 3381 | "# 学習に時間が掛かりすぎるので10万件に抑える\n", 3382 | "train_examples = train_examples[:100000]\n", 3383 | "\n", 3384 | "train_dataset = SupSimCSEDataset(\n", 3385 | " premise=train_examples['sent0'].tolist(),\n", 3386 | " entailment=train_examples['sent1'].tolist(),\n", 3387 | " contradiction=train_examples['hard_neg'].tolist(),\n", 3388 | ")" 3389 | ], 3390 | "metadata": { 3391 | "id": "RGlWy99djrfr" 3392 | }, 3393 | "execution_count": 17, 3394 | "outputs": [] 3395 | }, 3396 | { 3397 | "cell_type": "code", 3398 | "source": [ 3399 | "def tokenize(batch: list[str]) -> transformers.tokenization_utils.BatchEncoding:\n", 3400 | " return tokenizer(\n", 3401 | " batch,\n", 3402 | " padding=True,\n", 3403 | " truncation=True,\n", 3404 | " return_tensors=\"pt\",\n", 3405 | " max_length=32,\n", 3406 | " )\n", 3407 | "\n", 3408 | "\n", 3409 | "def collate_fn(batch: list[tuple[str, str, str]]) -> transformers.tokenization_utils.BatchEncoding:\n", 3410 | " premise, entailment, contradiction = zip(*batch)\n", 3411 | " return transformers.tokenization_utils.BatchEncoding(\n", 3412 | " {\n", 3413 | " 'premise': tokenize(premise),\n", 3414 | " 'entailment': tokenize(entailment),\n", 3415 | " 'contradiction': tokenize(contradiction),\n", 3416 | " }\n", 3417 | " )\n", 3418 | "\n", 3419 | "# 論文では、sup-SimCSEのミニバッチのサイズは512\n", 3420 | "# しかしここでは、メモリ使用量の関係で64\n", 3421 | "# https://github.com/princeton-nlp/SimCSE/tree/0.4#training\n", 3422 | "batch_size = 64\n", 3423 | "\n", 3424 | "train_dataloader = torch.utils.data.DataLoader(\n", 3425 | " train_dataset,\n", 3426 | " collate_fn=collate_fn,\n", 3427 | " batch_size=batch_size,\n", 3428 | " shuffle=True,\n", 3429 | " num_workers=2,\n", 3430 | " pin_memory=True,\n", 3431 | " drop_last=True,\n", 3432 | ")" 3433 | ], 3434 | "metadata": { 3435 | "id": "FxhrNLNkuIFf" 3436 | }, 3437 | "execution_count": 18, 3438 | "outputs": [] 3439 | }, 3440 | { 3441 | "cell_type": "markdown", 3442 | "source": [ 3443 | "## 3.3 ファインチューニング" 3444 | ], 3445 | "metadata": { 3446 | "id": "NHxJdfc2uKmi" 3447 | } 3448 | }, 3449 | { 3450 | "cell_type": "code", 3451 | "source": [ 3452 | "# 論文ではsup-SimCSEのエポック数は3だが、ここでは時間の都合上1 (付録A参照)\n", 3453 | "epochs = 1\n", 3454 | "\n", 3455 | "# sup-SimCSEのbert-base-uncasedでの学習率は5e-5\n", 3456 | "# https://github.com/princeton-nlp/SimCSE/tree/0.4#training\n", 3457 | "learning_rate = 5e-5\n", 3458 | "\n", 3459 | "temperature = 0.05\n", 3460 | "\n", 3461 | "optimizer = torch.optim.AdamW(\n", 3462 | " params=model.parameters(),\n", 3463 | " lr=learning_rate\n", 3464 | ")\n", 3465 | "\n", 3466 | "lr_scheduler = 
transformers.optimization.get_linear_schedule_with_warmup(\n", 3467 | " optimizer=optimizer,\n", 3468 | " num_warmup_steps=0,\n", 3469 | " num_training_steps=len(train_dataloader) * epochs,\n", 3470 | ")\n", 3471 | "\n", 3472 | "for epoch in range(epochs):\n", 3473 | " model.train()\n", 3474 | "\n", 3475 | " for batch in tqdm.tqdm(train_dataloader):\n", 3476 | " batch = batch.to(device)\n", 3477 | "\n", 3478 | " # それぞれの文について埋め込みを計算\n", 3479 | " emb_pre = model.forward(**batch['premise'])\n", 3480 | " emb_ent = model.forward(**batch['entailment'])\n", 3481 | " emb_cnt = model.forward(**batch['contradiction'])\n", 3482 | "\n", 3483 | " # (emb_pre, emb_ent)と(emb_pre, emb_cnt)のそれぞれについて全対で類似度を計算し、最後に連結させる\n", 3484 | " # スライドP25の右を作ってると思えば分かりやすい\n", 3485 | " sim_matrix_pe = torch.nn.functional.cosine_similarity(emb_pre.unsqueeze(1), emb_ent.unsqueeze(0), dim=-1)\n", 3486 | " sim_matrix_pc = torch.nn.functional.cosine_similarity(emb_pre.unsqueeze(1), emb_cnt.unsqueeze(0), dim=-1)\n", 3487 | " sim_matrix = torch.cat([sim_matrix_pe, sim_matrix_pc], dim=1)\n", 3488 | "\n", 3489 | " sim_matrix = sim_matrix / temperature\n", 3490 | "\n", 3491 | " labels = torch.arange(batch_size).long().to(device)\n", 3492 | " loss = torch.nn.functional.cross_entropy(sim_matrix, labels)\n", 3493 | "\n", 3494 | " optimizer.zero_grad()\n", 3495 | " loss.backward()\n", 3496 | "\n", 3497 | " optimizer.step()\n", 3498 | " lr_scheduler.step()" 3499 | ], 3500 | "metadata": { 3501 | "colab": { 3502 | "base_uri": "https://localhost:8080/" 3503 | }, 3504 | "id": "Yam78QTKquun", 3505 | "outputId": "5b895b75-91dc-41d6-d06a-32651da143b9" 3506 | }, 3507 | "execution_count": 19, 3508 | "outputs": [ 3509 | { 3510 | "output_type": "stream", 3511 | "name": "stderr", 3512 | "text": [ 3513 | "100%|██████████| 1562/1562 [22:40<00:00, 1.15it/s]\n" 3514 | ] 3515 | } 3516 | ] 3517 | }, 3518 | { 3519 | "cell_type": "code", 3520 | "source": [ 3521 | "# 学習結果をDriveに保存しておく\n", 3522 | "torch.save(model, sup_model_path)" 3523 | ], 3524 | "metadata": { 3525 | "id": "-FIUaPA7mOF0" 3526 | }, 3527 | "execution_count": 20, 3528 | "outputs": [] 3529 | }, 3530 | { 3531 | "cell_type": "markdown", 3532 | "source": [ 3533 | "# 4. 
評価\n", 3534 | "\n", 3535 | "STS (semantic textual similarity) Taskで埋め込みモデルの性能を評価する。" 3536 | ], 3537 | "metadata": { 3538 | "id": "I0od4-7Hg4v7" 3539 | } 3540 | }, 3541 | { 3542 | "cell_type": "code", 3543 | "source": [ 3544 | "# 訓練済みSimCSEモデルを読み込む\n", 3545 | "\n", 3546 | "unsup_model = None\n", 3547 | "if os.path.exists(unsup_model_path):\n", 3548 | " unsup_model = torch.load(unsup_model_path)\n", 3549 | "else:\n", 3550 | " print(f'{unsup_model_path} does not exist.')\n", 3551 | "\n", 3552 | "sup_model = None\n", 3553 | "if os.path.exists(sup_model_path):\n", 3554 | " sup_model = torch.load(sup_model_path)\n", 3555 | "else:\n", 3556 | " print(f'{sup_model_path} does not exist.')" 3557 | ], 3558 | "metadata": { 3559 | "id": "lafv58Cox0oj" 3560 | }, 3561 | "execution_count": 21, 3562 | "outputs": [] 3563 | }, 3564 | { 3565 | "cell_type": "code", 3566 | "source": [ 3567 | "# ファインチューニングして無いモデルも比較用に作成\n", 3568 | "\n", 3569 | "untuned_model = SimCSEModel('bert-base-uncased').to(device)" 3570 | ], 3571 | "metadata": { 3572 | "id": "VRLmJpIbioal" 3573 | }, 3574 | "execution_count": 22, 3575 | "outputs": [] 3576 | }, 3577 | { 3578 | "cell_type": "code", 3579 | "source": [ 3580 | "# STS Benchmarkデータセットを使用する。\n", 3581 | "# https://ixa2.si.ehu.eus/stswiki/index.php/STSbenchmark\n", 3582 | "\n", 3583 | "!wget http://ixa2.si.ehu.es/stswiki/images/4/48/Stsbenchmark.tar.gz\n", 3584 | "!tar -zxvf Stsbenchmark.tar.gz\n", 3585 | "!mkdir -p ./datasets/sts\n", 3586 | "!mv stsbenchmark ./datasets/sts/stsb\n", 3587 | "!rm Stsbenchmark.tar.gz" 3588 | ], 3589 | "metadata": { 3590 | "colab": { 3591 | "base_uri": "https://localhost:8080/" 3592 | }, 3593 | "id": "dtAWx0vMfZJy", 3594 | "outputId": "86628bb4-c1f3-456f-ff47-b60ddef91c44" 3595 | }, 3596 | "execution_count": 23, 3597 | "outputs": [ 3598 | { 3599 | "output_type": "stream", 3600 | "name": "stdout", 3601 | "text": [ 3602 | "--2023-11-01 09:42:00-- http://ixa2.si.ehu.es/stswiki/images/4/48/Stsbenchmark.tar.gz\n", 3603 | "Resolving ixa2.si.ehu.es (ixa2.si.ehu.es)... 158.227.106.100\n", 3604 | "Connecting to ixa2.si.ehu.es (ixa2.si.ehu.es)|158.227.106.100|:80... connected.\n", 3605 | "HTTP request sent, awaiting response... 302 Found\n", 3606 | "Location: http://ixa2.si.ehu.eus/stswiki/images/4/48/Stsbenchmark.tar.gz [following]\n", 3607 | "--2023-11-01 09:42:01-- http://ixa2.si.ehu.eus/stswiki/images/4/48/Stsbenchmark.tar.gz\n", 3608 | "Resolving ixa2.si.ehu.eus (ixa2.si.ehu.eus)... 158.227.106.100\n", 3609 | "Connecting to ixa2.si.ehu.eus (ixa2.si.ehu.eus)|158.227.106.100|:80... connected.\n", 3610 | "HTTP request sent, awaiting response... 
200 OK\n", 3611 | "Length: 409630 (400K) [application/x-gzip]\n", 3612 | "Saving to: ‘Stsbenchmark.tar.gz’\n", 3613 | "\n", 3614 | "Stsbenchmark.tar.gz 100%[===================>] 400.03K 394KB/s in 1.0s \n", 3615 | "\n", 3616 | "2023-11-01 09:42:02 (394 KB/s) - ‘Stsbenchmark.tar.gz’ saved [409630/409630]\n", 3617 | "\n", 3618 | "stsbenchmark/\n", 3619 | "stsbenchmark/readme.txt\n", 3620 | "stsbenchmark/sts-test.csv\n", 3621 | "stsbenchmark/correlation.pl\n", 3622 | "stsbenchmark/LICENSE.txt\n", 3623 | "stsbenchmark/sts-dev.csv\n", 3624 | "stsbenchmark/sts-train.csv\n" 3625 | ] 3626 | } 3627 | ] 3628 | }, 3629 | { 3630 | "cell_type": "code", 3631 | "source": [ 3632 | "# STS Benchmarkデータセットをパース\n", 3633 | "#\n", 3634 | "# 色々と列が含まれているが使用するのは sentence1, sentence2, score のみ\n", 3635 | "# scoreには、人手評価により決めたsentence1とsentence2の意味的な類似度がアノテーションされている\n", 3636 | "\n", 3637 | "names = ['genre', 'file', 'year', 'sid', 'score', 'sentence1', 'sentence2']\n", 3638 | "sts_test_df = pd.read_csv(\n", 3639 | " 'datasets/sts/stsb/sts-test.csv',\n", 3640 | " sep='\\t',\n", 3641 | " header=None,\n", 3642 | " names=names,\n", 3643 | " # オプショナルで追加列が存在するので、パースする列数を指定する必要あり\n", 3644 | " usecols=range(len(names)),\n", 3645 | " # エラー「ParserError: Error tokenizing data. C error: EOF inside string starting at row 1118.」に対処\n", 3646 | " # https://stackoverflow.com/questions/18016037/pandas-parsererror-eof-character-when-reading-multiple-csv-files-to-hdf5\n", 3647 | " quoting=csv.QUOTE_NONE,\n", 3648 | ")\n", 3649 | "sts_test_df" 3650 | ], 3651 | "metadata": { 3652 | "colab": { 3653 | "base_uri": "https://localhost:8080/", 3654 | "height": 424 3655 | }, 3656 | "id": "QD6MUAGfte6z", 3657 | "outputId": "36d58896-133f-4282-9010-e895680ca178" 3658 | }, 3659 | "execution_count": 24, 3660 | "outputs": [ 3661 | { 3662 | "output_type": "execute_result", 3663 | "data": { 3664 | "text/plain": [ 3665 | " genre file year sid score \\\n", 3666 | "0 main-captions MSRvid 2012test 24 2.5 \n", 3667 | "1 main-captions MSRvid 2012test 33 3.6 \n", 3668 | "2 main-captions MSRvid 2012test 45 5.0 \n", 3669 | "3 main-captions MSRvid 2012test 63 4.2 \n", 3670 | "4 main-captions MSRvid 2012test 66 1.5 \n", 3671 | "... ... ... ... ... ... \n", 3672 | "1374 main-news headlines 2016 1354 0.0 \n", 3673 | "1375 main-news headlines 2016 1360 1.0 \n", 3674 | "1376 main-news headlines 2016 1368 1.0 \n", 3675 | "1377 main-news headlines 2016 1420 0.0 \n", 3676 | "1378 main-news headlines 2016 1432 0.0 \n", 3677 | "\n", 3678 | " sentence1 \\\n", 3679 | "0 A girl is styling her hair. \n", 3680 | "1 A group of men play soccer on the beach. \n", 3681 | "2 One woman is measuring another woman's ankle. \n", 3682 | "3 A man is cutting up a cucumber. \n", 3683 | "4 A man is playing a harp. \n", 3684 | "... ... \n", 3685 | "1374 Philippines, Canada pledge to further boost re... \n", 3686 | "1375 Israel bars Palestinians from Jerusalem's Old ... \n", 3687 | "1376 How much do you know about Secret Service? \n", 3688 | "1377 Obama Struggles to Soothe Saudi Fears As Iran ... \n", 3689 | "1378 South Korea declares end to MERS outbreak \n", 3690 | "\n", 3691 | " sentence2 \n", 3692 | "0 A girl is brushing her hair. \n", 3693 | "1 A group of boys are playing soccer on the beach. \n", 3694 | "2 A woman measures another woman's ankle. \n", 3695 | "3 A man is slicing a cucumber. \n", 3696 | "4 A man is playing a keyboard. \n", 3697 | "... ... \n", 3698 | "1374 Philippines saves 100 after ferry sinks \n", 3699 | "1375 Two-state solution between Palestinians, Israe... 
\n", 3700 | "1376 Lawmakers from both sides express outrage at S... \n", 3701 | "1377 Myanmar Struggles to Finalize Voter Lists for ... \n", 3702 | "1378 North Korea Delegation Meets With South Korean... \n", 3703 | "\n", 3704 | "[1379 rows x 7 columns]" 3705 | ], 3706 | "text/html": [ 3707 | "\n", 3708 | "
\n", 3709 | "
\n", 3710 | "\n", 3723 | "\n", 3724 | " \n", 3725 | " \n", 3726 | " \n", 3727 | " \n", 3728 | " \n", 3729 | " \n", 3730 | " \n", 3731 | " \n", 3732 | " \n", 3733 | " \n", 3734 | " \n", 3735 | " \n", 3736 | " \n", 3737 | " \n", 3738 | " \n", 3739 | " \n", 3740 | " \n", 3741 | " \n", 3742 | " \n", 3743 | " \n", 3744 | " \n", 3745 | " \n", 3746 | " \n", 3747 | " \n", 3748 | " \n", 3749 | " \n", 3750 | " \n", 3751 | " \n", 3752 | " \n", 3753 | " \n", 3754 | " \n", 3755 | " \n", 3756 | " \n", 3757 | " \n", 3758 | " \n", 3759 | " \n", 3760 | " \n", 3761 | " \n", 3762 | " \n", 3763 | " \n", 3764 | " \n", 3765 | " \n", 3766 | " \n", 3767 | " \n", 3768 | " \n", 3769 | " \n", 3770 | " \n", 3771 | " \n", 3772 | " \n", 3773 | " \n", 3774 | " \n", 3775 | " \n", 3776 | " \n", 3777 | " \n", 3778 | " \n", 3779 | " \n", 3780 | " \n", 3781 | " \n", 3782 | " \n", 3783 | " \n", 3784 | " \n", 3785 | " \n", 3786 | " \n", 3787 | " \n", 3788 | " \n", 3789 | " \n", 3790 | " \n", 3791 | " \n", 3792 | " \n", 3793 | " \n", 3794 | " \n", 3795 | " \n", 3796 | " \n", 3797 | " \n", 3798 | " \n", 3799 | " \n", 3800 | " \n", 3801 | " \n", 3802 | " \n", 3803 | " \n", 3804 | " \n", 3805 | " \n", 3806 | " \n", 3807 | " \n", 3808 | " \n", 3809 | " \n", 3810 | " \n", 3811 | " \n", 3812 | " \n", 3813 | " \n", 3814 | " \n", 3815 | " \n", 3816 | " \n", 3817 | " \n", 3818 | " \n", 3819 | " \n", 3820 | " \n", 3821 | " \n", 3822 | " \n", 3823 | " \n", 3824 | " \n", 3825 | " \n", 3826 | " \n", 3827 | " \n", 3828 | " \n", 3829 | " \n", 3830 | " \n", 3831 | " \n", 3832 | " \n", 3833 | " \n", 3834 | " \n", 3835 | " \n", 3836 | " \n", 3837 | " \n", 3838 | " \n", 3839 | " \n", 3840 | " \n", 3841 | " \n", 3842 | " \n", 3843 | " \n", 3844 | " \n", 3845 | " \n", 3846 | " \n", 3847 | " \n", 3848 | "
genrefileyearsidscoresentence1sentence2
0main-captionsMSRvid2012test242.5A girl is styling her hair.A girl is brushing her hair.
1main-captionsMSRvid2012test333.6A group of men play soccer on the beach.A group of boys are playing soccer on the beach.
2main-captionsMSRvid2012test455.0One woman is measuring another woman's ankle.A woman measures another woman's ankle.
3main-captionsMSRvid2012test634.2A man is cutting up a cucumber.A man is slicing a cucumber.
4main-captionsMSRvid2012test661.5A man is playing a harp.A man is playing a keyboard.
........................
1374main-newsheadlines201613540.0Philippines, Canada pledge to further boost re...Philippines saves 100 after ferry sinks
1375main-newsheadlines201613601.0Israel bars Palestinians from Jerusalem's Old ...Two-state solution between Palestinians, Israe...
1376main-newsheadlines201613681.0How much do you know about Secret Service?Lawmakers from both sides express outrage at S...
1377main-newsheadlines201614200.0Obama Struggles to Soothe Saudi Fears As Iran ...Myanmar Struggles to Finalize Voter Lists for ...
1378main-newsheadlines201614320.0South Korea declares end to MERS outbreakNorth Korea Delegation Meets With South Korean...
\n", 3849 | "

1379 rows × 7 columns

\n", 3850 | "
\n", 3851 | "
\n", 3852 | "\n", 3853 | "
\n", 3854 | " \n", 3862 | "\n", 3863 | " \n", 3903 | "\n", 3904 | " \n", 3928 | "
\n", 3929 | "\n", 3930 | "\n", 3931 | "
\n", 3932 | " \n", 3943 | "\n", 3944 | "\n", 4033 | "\n", 4034 | " \n", 4056 | "
\n", 4057 | "
\n", 4058 | "
\n" 4059 | ] 4060 | }, 4061 | "metadata": {}, 4062 | "execution_count": 24 4063 | } 4064 | ] 4065 | }, 4066 | { 4067 | "cell_type": "code", 4068 | "source": [ 4069 | "# ミニバッチのサイズ\n", 4070 | "batch_size = 512\n", 4071 | "\n", 4072 | "# 受け取ったSimCSEモデルを使って、入力文を埋め込みに変換する\n", 4073 | "#\n", 4074 | "# inference_modeを指定することで、評価時には余分な勾配の計算ための処理をスキップできる\n", 4075 | "# https://pytorch.org/docs/stable/generated/torch.inference_mode.html\n", 4076 | "@torch.inference_mode()\n", 4077 | "def encode(model: SimCSEModel, texts: list[str]) -> torch.Tensor:\n", 4078 | " # 評価モードに切り替え\n", 4079 | " # https://pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module.eval\n", 4080 | " model.eval()\n", 4081 | "\n", 4082 | " embs = []\n", 4083 | " for batch in more_itertools.chunked(texts, batch_size):\n", 4084 | " batch = tokenizer(\n", 4085 | " batch,\n", 4086 | " padding=True,\n", 4087 | " truncation=True,\n", 4088 | " return_tensors='pt',\n", 4089 | " )\n", 4090 | " batch = batch.to(device)\n", 4091 | " emb = model(**batch)\n", 4092 | " embs.append(emb.cpu())\n", 4093 | "\n", 4094 | " # shape of output: (len(texts), hidden_size)\n", 4095 | " return torch.cat(embs, dim=0)" 4096 | ], 4097 | "metadata": { 4098 | "id": "VSrtrzsiUSmA" 4099 | }, 4100 | "execution_count": 25, 4101 | "outputs": [] 4102 | }, 4103 | { 4104 | "cell_type": "code", 4105 | "source": [ 4106 | "# モデルから得られた文埋め込み間のコサイン類似度と正解スコアを比較して、\n", 4107 | "# Spearmanの順位相関係数(×100)を性能スコアとして返す。\n", 4108 | "#\n", 4109 | "# Pearsonの相関係数も伝統的に使用されてきたが、Spearmanの方が文埋め込みの評価に適しているという議論がある\n", 4110 | "# (論文付録B参照)\n", 4111 | "def evaluate(model: SimCSEModel) -> float:\n", 4112 | " sentences1 = sts_test_df['sentence1']\n", 4113 | " sentences2 = sts_test_df['sentence2']\n", 4114 | " scores = sts_test_df['score']\n", 4115 | "\n", 4116 | " embeddings1 = encode(model, sentences1)\n", 4117 | " embeddings2 = encode(model, sentences2)\n", 4118 | "\n", 4119 | " cosine_scores = 1 - sklearn_metrics.pairwise.paired_cosine_distances(embeddings1, embeddings2)\n", 4120 | " spearman = float(scipy.stats.spearmanr(scores, cosine_scores)[0]) * 100\n", 4121 | "\n", 4122 | " return spearman" 4123 | ], 4124 | "metadata": { 4125 | "id": "poUgAed05LEQ" 4126 | }, 4127 | "execution_count": 26, 4128 | "outputs": [] 4129 | }, 4130 | { 4131 | "cell_type": "code", 4132 | "source": [ 4133 | "print(f'notrain-simcse: {evaluate(untuned_model):g}')\n", 4134 | "if unsup_model is not None:\n", 4135 | " print(f'unsup-simcse: {evaluate(unsup_model):g}')\n", 4136 | "if sup_model is not None:\n", 4137 | " print(f'sup-simcse: {evaluate(sup_model):g}')" 4138 | ], 4139 | "metadata": { 4140 | "colab": { 4141 | "base_uri": "https://localhost:8080/" 4142 | }, 4143 | "id": "b_aQC616lyjd", 4144 | "outputId": "64527852-46f7-4227-eb79-6e76dfadaf55" 4145 | }, 4146 | "execution_count": 27, 4147 | "outputs": [ 4148 | { 4149 | "output_type": "stream", 4150 | "name": "stdout", 4151 | "text": [ 4152 | "notrain-simcse: 20.2978\n", 4153 | "unsup-simcse: 61.5961\n", 4154 | "sup-simcse: 80.6809\n" 4155 | ] 4156 | } 4157 | ] 4158 | } 4159 | ] 4160 | } --------------------------------------------------------------------------------