├── .dockerignore ├── .gitignore ├── .gitmodules ├── LICENSE ├── README.md ├── albert_config.json ├── config.ini ├── data ├── .gitkeep └── wiki │ └── .gitkeep ├── model └── .gitkeep ├── notebook ├── .gitkeep ├── AlbertExample.ipynb ├── finetune-to-livedoor-corpus.ipynb └── pretraining.ipynb ├── pretraining-loss.png ├── requirements.txt └── src ├── convert_tfmodel_to_pytorch.py ├── data-download-and-extract.py ├── file-preprocessing.sh ├── run_classifier.py ├── run_create_pretraining_data.sh ├── train-sentencepiece.py └── utils.py /.dockerignore: -------------------------------------------------------------------------------- 1 | data 2 | model 3 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | data/* 2 | !data/.gitkeep 3 | !data/wiki/ 4 | data/wiki/* 5 | !data/wiki/.gitkeep 6 | model/* 7 | !model/.gitkeep 8 | __pycache__ 9 | .ipynb_checkpoints 10 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "wikiextractor"] 2 | path = wikiextractor 3 | url = https://github.com/attardi/wikiextractor.git 4 | [submodule "ALBERT"] 5 | path = ALBERT 6 | url = https://github.com/alinear-corp/ALBERT.git 7 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ALBERT with SentencePiece for Japanese text. 2 | This is a repository of Japanese ALBERT model with SentencePiece tokenizer. 3 | (Note: this project is a fork of [bert-japanese](https://github.com/yoheikikuta/bert-japanese)) 4 | 5 | To clone this repository together with the required 6 | [ALBERT](https://github.com/alinear-corp/ALBERT) (our fork of original [ALBERT](https://github.com/google-research/ALBERT)) and 7 | [WikiExtractor](https://github.com/attardi/wikiextractor): 8 | 9 | git clone --recurse-submodules https://github.com/alinear-corp/albert-japanese 10 | 11 | ## Pretrained models 12 | We provide pretrained BERT model and trained SentencePiece model for Japanese text. 13 | Training data is the Japanese wikipedia corpus from [`Wikimedia Downloads`](https://dumps.wikimedia.org/). 14 | The latest release is `v2`, you can download the pretrained model at: 15 | - **[`Pretrained BERT model and trained SentencePiece model`](https://drive.google.com/drive/folders/1qvVrG4u8F94694zSExj8gWWoiAqyAlC2?usp=sharing)** 16 | 17 | Loss function during training is as below (after 1M steps the loss function massively changes because `max_seq_length` is changed from `128` to `512`.): 18 | ![pretraining-loss](pretraining-loss.png) 19 | 20 | ## Use pretrained model with transformers 21 | 22 | You can use the pretrained model with transformers. 23 | The model is uploaded to hagging face's repository. 24 | - **[ALINEAR/albert-japanese-v2](https://huggingface.co/ALINEAR/albert-japanese-v2)** 25 | 26 | For example, if you want to train AlbertForSequenceClassification, you can load the pretrained model by: 27 | 28 | from transformers import AlbertTokenizer, AlbertForSequenceClassification 29 | tokenizer = AlbertTokenizer.from_pretrained("ALINEAR/albert-japanese-v2") 30 | model = AlbertForSequenceClassification.from_pretrained("ALINEAR/albert-japanese-v2", num_labels=...) 31 | 32 | ## Finetuning with BERT Japanese 33 | We also provide a simple Japanese text classification problem with [`livedoor ニュースコーパス`](https://www.rondhuit.com/download.html). 34 | Try the following notebook to check the usability of finetuning. 35 | You can run the notebook on CPU (too slow) or GPU/TPU environments. 36 | - **[finetune-to-livedoor-corpus.ipynb](https://github.com/alinear-corp/albert-japanese/blob/master/notebook/finetune-to-livedoor-corpus.ipynb)** 37 | 38 | The results are the following: 39 | - ALBERT with SentencePiece 40 | ``` 41 | precision recall f1-score support 42 | 43 | dokujo-tsushin 1.00 0.93 0.96 178 44 | it-life-hack 0.92 0.96 0.94 172 45 | kaden-channel 0.95 0.98 0.97 176 46 | livedoor-homme 0.90 0.82 0.86 95 47 | movie-enter 0.96 0.98 0.97 158 48 | peachy 0.95 0.97 0.96 174 49 | smax 0.99 0.98 0.98 167 50 | sports-watch 0.96 0.98 0.97 190 51 | topic-news 0.96 0.94 0.95 163 52 | 53 | accuracy 0.96 1473 54 | macro avg 0.95 0.95 0.95 1473 55 | weighted avg 0.96 0.96 0.96 1473 56 | ``` 57 | - BERT with SentencePiece (from [original bert-japanese repository](https://github.com/yoheikikuta/bert-japanese#finetuning-with-bert-japanese)) 58 | ``` 59 | precision recall f1-score support 60 | 61 | dokujo-tsushin 0.98 0.94 0.96 178 62 | it-life-hack 0.96 0.97 0.96 172 63 | kaden-channel 0.99 0.98 0.99 176 64 | livedoor-homme 0.98 0.88 0.93 95 65 | movie-enter 0.96 0.99 0.98 158 66 | peachy 0.94 0.98 0.96 174 67 | smax 0.98 0.99 0.99 167 68 | sports-watch 0.98 1.00 0.99 190 69 | topic-news 0.99 0.98 0.98 163 70 | 71 | micro avg 0.97 0.97 0.97 1473 72 | macro avg 0.97 0.97 0.97 1473 73 | weighted avg 0.97 0.97 0.97 1473 74 | ``` 75 | 76 | ## Pretraining from scratch 77 | All scripts for pretraining from scratch are provided. 78 | Follow the instructions below. 79 | 80 | ### Data preparation 81 | Data downloading and preprocessing. 82 | 83 | ``` 84 | python3 src/data-download-and-extract.py 85 | bash src/file-preprocessing.sh 86 | ``` 87 | 88 | The above scripts use the latest jawiki data and wikiextractor module, which are different from those used for the pretrained model. 89 | If you wanna prepare the same situation, use the following information: 90 | 91 | - albert-japanese: commit `e420eab47a1d6d4775adc07e0b112aac8088d81b` 92 | - ALBERT: commit `08a848f08ec79d85f434b5c2fb6147e89f01bccb` 93 | - dataset: `jawiki-20191201-pages-articles-multistream.xml.bz2` in the [Google Drive](https://drive.google.com/drive/folders/1qvVrG4u8F94694zSExj8gWWoiAqyAlC2?usp=sharing) 94 | - wikiextractor: commit `16186e290d9eb0eb3a3784c6c0635a9ed7e855c3` 95 | 96 | ### Training SentencePiece model 97 | Train a SentencePiece model using the preprocessed data. 98 | 99 | ``` 100 | python3 src/train-sentencepiece.py 101 | ``` 102 | 103 | ### Creating data for pretraining 104 | Create .tfrecord files for pretraining. 105 | 106 | ``` 107 | bash src/run_create_pretraining_data.sh [extract_dir] [max_seq_length] 108 | ``` 109 | 110 | `extract_dir` is a base directory to which the wikipedia texts are extracted, and, 111 | `max_seq_length` need to be 128 and 512 112 | 113 | ### Pretraining 114 | You need GPU/TPU environment to pretrain a BERT model. 115 | The following notebook provides the link to Colab notebook where you can run the scripts with TPUs. 116 | 117 | - **[pretraining.ipynb](https://github.com/alinear-corp/albert-japanese/blob/master/notebook/pretraining.ipynb)** 118 | -------------------------------------------------------------------------------- /albert_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "attention_probs_dropout_prob": 0.1, 3 | "hidden_act": "gelu", 4 | "hidden_dropout_prob": 0.1, 5 | "embedding_size": 128, 6 | "hidden_size": 768, 7 | "initializer_range": 0.02, 8 | "intermediate_size": 3072, 9 | "max_position_embeddings": 512, 10 | "num_attention_heads": 12, 11 | "num_hidden_layers": 12, 12 | "num_hidden_groups": 1, 13 | "net_structure_type": 0, 14 | "gap_size": 0, 15 | "num_memory_blocks": 0, 16 | "inner_group_num": 1, 17 | "down_scale_factor": 1, 18 | "type_vocab_size": 2, 19 | "vocab_size": 32000 20 | } 21 | 22 | -------------------------------------------------------------------------------- /config.ini: -------------------------------------------------------------------------------- 1 | [DATA] 2 | FILEURL = https://dumps.wikimedia.org/jawiki/latest/jawiki-latest-pages-articles-multistream.xml.bz2 3 | FILEPATH = /work/data/jawiki-latest-pages-articles-multistream.xml.bz2 4 | DATADIR = /work/data/ 5 | TEXTDIR = /work/data/wiki/ 6 | 7 | [SENTENCEPIECE] 8 | PREFIX = /work/model/wiki-ja 9 | VOCABSIZE = 32000 10 | CTLSYMBOLS = [CLS],[SEP],[MASK] 11 | 12 | [FINETUNING-DATA] 13 | FILEURL = https://www.rondhuit.com/download/ldcc-20140209.tar.gz 14 | -------------------------------------------------------------------------------- /data/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alinear-corp/albert-japanese/ec054c5647ec2ca13f9c57f83a238b53cc998255/data/.gitkeep -------------------------------------------------------------------------------- /data/wiki/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alinear-corp/albert-japanese/ec054c5647ec2ca13f9c57f83a238b53cc998255/data/wiki/.gitkeep -------------------------------------------------------------------------------- /model/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alinear-corp/albert-japanese/ec054c5647ec2ca13f9c57f83a238b53cc998255/model/.gitkeep -------------------------------------------------------------------------------- /notebook/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alinear-corp/albert-japanese/ec054c5647ec2ca13f9c57f83a238b53cc998255/notebook/.gitkeep -------------------------------------------------------------------------------- /notebook/AlbertExample.ipynb: -------------------------------------------------------------------------------- 1 | {"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"name":"AlbertExample.ipynb","provenance":[],"collapsed_sections":[],"mount_file_id":"1ejG4EIKo-eILREVwU4Y0PkTlaGutCPrF","authorship_tag":"ABX9TyOpmEYrFcL2bpXTwLLgl4XR"},"kernelspec":{"name":"python3","display_name":"Python 3"},"language_info":{"name":"python"},"accelerator":"GPU","widgets":{"application/vnd.jupyter.widget-state+json":{"eaadb0acc47e4bcd881f8ba4864d2b42":{"model_module":"@jupyter-widgets/controls","model_name":"HBoxModel","model_module_version":"1.5.0","state":{"_view_name":"HBoxView","_dom_classes":[],"_model_name":"HBoxModel","_view_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_view_count":null,"_view_module_version":"1.5.0","box_style":"","layout":"IPY_MODEL_22eb56149e584897bdd438ca62669267","_model_module":"@jupyter-widgets/controls","children":["IPY_MODEL_de4e5eec3f9846d897ec1559e50d6cc1","IPY_MODEL_e66b1567e47d4e56aae7681c694d2c18","IPY_MODEL_0dad0162cfa645f28e70a0d561ac7497"]}},"22eb56149e584897bdd438ca62669267":{"model_module":"@jupyter-widgets/base","model_name":"LayoutModel","model_module_version":"1.2.0","state":{"_view_name":"LayoutView","grid_template_rows":null,"right":null,"justify_content":null,"_view_module":"@jupyter-widgets/base","overflow":null,"_model_module_version":"1.2.0","_view_count":null,"flex_flow":null,"width":null,"min_width":null,"border":null,"align_items":null,"bottom":null,"_model_module":"@jupyter-widgets/base","top":null,"grid_column":null,"overflow_y":null,"overflow_x":null,"grid_auto_flow":null,"grid_area":null,"grid_template_columns":null,"flex":null,"_model_name":"LayoutModel","justify_items":null,"grid_row":null,"max_height":null,"align_content":null,"visibility":null,"align_self":null,"height":null,"min_height":null,"padding":null,"grid_auto_rows":null,"grid_gap":null,"max_width":null,"order":null,"_view_module_version":"1.2.0","grid_template_areas":null,"object_position":null,"object_fit":null,"grid_auto_columns":null,"margin":null,"display":null,"left":null}},"de4e5eec3f9846d897ec1559e50d6cc1":{"model_module":"@jupyter-widgets/controls","model_name":"HTMLModel","model_module_version":"1.5.0","state":{"_view_name":"HTMLView","style":"IPY_MODEL_1816c2ac4e1e4eec8e5575311a5e404f","_dom_classes":[],"description":"","_model_name":"HTMLModel","placeholder":"​","_view_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","value":"Downloading: 100%","_view_count":null,"_view_module_version":"1.5.0","description_tooltip":null,"_model_module":"@jupyter-widgets/controls","layout":"IPY_MODEL_4e3e0856c0bf4f51b3dd33403e73aeb3"}},"e66b1567e47d4e56aae7681c694d2c18":{"model_module":"@jupyter-widgets/controls","model_name":"FloatProgressModel","model_module_version":"1.5.0","state":{"_view_name":"ProgressView","style":"IPY_MODEL_17ae15d338e7489384da7510449d6aed","_dom_classes":[],"description":"","_model_name":"FloatProgressModel","bar_style":"success","max":784585,"_view_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","value":784585,"_view_count":null,"_view_module_version":"1.5.0","orientation":"horizontal","min":0,"description_tooltip":null,"_model_module":"@jupyter-widgets/controls","layout":"IPY_MODEL_30ed416a6b1249199a4189306b2b4575"}},"0dad0162cfa645f28e70a0d561ac7497":{"model_module":"@jupyter-widgets/controls","model_name":"HTMLModel","model_module_version":"1.5.0","state":{"_view_name":"HTMLView","style":"IPY_MODEL_c07a4fa96f6b43c59af1f143c7f646b7","_dom_classes":[],"description":"","_model_name":"HTMLModel","placeholder":"​","_view_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","value":" 766k/766k [00:00<00:00, 1.24MB/s]","_view_count":null,"_view_module_version":"1.5.0","description_tooltip":null,"_model_module":"@jupyter-widgets/controls","layout":"IPY_MODEL_6cc83ca859124b8cb9aa9c66fc91939b"}},"1816c2ac4e1e4eec8e5575311a5e404f":{"model_module":"@jupyter-widgets/controls","model_name":"DescriptionStyleModel","model_module_version":"1.5.0","state":{"_view_name":"StyleView","_model_name":"DescriptionStyleModel","description_width":"","_view_module":"@jupyter-widgets/base","_model_module_version":"1.5.0","_view_count":null,"_view_module_version":"1.2.0","_model_module":"@jupyter-widgets/controls"}},"4e3e0856c0bf4f51b3dd33403e73aeb3":{"model_module":"@jupyter-widgets/base","model_name":"LayoutModel","model_module_version":"1.2.0","state":{"_view_name":"LayoutView","grid_template_rows":null,"right":null,"justify_content":null,"_view_module":"@jupyter-widgets/base","overflow":null,"_model_module_version":"1.2.0","_view_count":null,"flex_flow":null,"width":null,"min_width":null,"border":null,"align_items":null,"bottom":null,"_model_module":"@jupyter-widgets/base","top":null,"grid_column":null,"overflow_y":null,"overflow_x":null,"grid_auto_flow":null,"grid_area":null,"grid_template_columns":null,"flex":null,"_model_name":"LayoutModel","justify_items":null,"grid_row":null,"max_height":null,"align_content":null,"visibility":null,"align_self":null,"height":null,"min_height":null,"padding":null,"grid_auto_rows":null,"grid_gap":null,"max_width":null,"order":null,"_view_module_version":"1.2.0","grid_template_areas":null,"object_position":null,"object_fit":null,"grid_auto_columns":null,"margin":null,"display":null,"left":null}},"17ae15d338e7489384da7510449d6aed":{"model_module":"@jupyter-widgets/controls","model_name":"ProgressStyleModel","model_module_version":"1.5.0","state":{"_view_name":"StyleView","_model_name":"ProgressStyleModel","description_width":"","_view_module":"@jupyter-widgets/base","_model_module_version":"1.5.0","_view_count":null,"_view_module_version":"1.2.0","bar_color":null,"_model_module":"@jupyter-widgets/controls"}},"30ed416a6b1249199a4189306b2b4575":{"model_module":"@jupyter-widgets/base","model_name":"LayoutModel","model_module_version":"1.2.0","state":{"_view_name":"LayoutView","grid_template_rows":null,"right":null,"justify_content":null,"_view_module":"@jupyter-widgets/base","overflow":null,"_model_module_version":"1.2.0","_view_count":null,"flex_flow":null,"width":null,"min_width":null,"border":null,"align_items":null,"bottom":null,"_model_module":"@jupyter-widgets/base","top":null,"grid_column":null,"overflow_y":null,"overflow_x":null,"grid_auto_flow":null,"grid_area":null,"grid_template_columns":null,"flex":null,"_model_name":"LayoutModel","justify_items":null,"grid_row":null,"max_height":null,"align_content":null,"visibility":null,"align_self":null,"height":null,"min_height":null,"padding":null,"grid_auto_rows":null,"grid_gap":null,"max_width":null,"order":null,"_view_module_version":"1.2.0","grid_template_areas":null,"object_position":null,"object_fit":null,"grid_auto_columns":null,"margin":null,"display":null,"left":null}},"c07a4fa96f6b43c59af1f143c7f646b7":{"model_module":"@jupyter-widgets/controls","model_name":"DescriptionStyleModel","model_module_version":"1.5.0","state":{"_view_name":"StyleView","_model_name":"DescriptionStyleModel","description_width":"","_view_module":"@jupyter-widgets/base","_model_module_version":"1.5.0","_view_count":null,"_view_module_version":"1.2.0","_model_module":"@jupyter-widgets/controls"}},"6cc83ca859124b8cb9aa9c66fc91939b":{"model_module":"@jupyter-widgets/base","model_name":"LayoutModel","model_module_version":"1.2.0","state":{"_view_name":"LayoutView","grid_template_rows":null,"right":null,"justify_content":null,"_view_module":"@jupyter-widgets/base","overflow":null,"_model_module_version":"1.2.0","_view_count":null,"flex_flow":null,"width":null,"min_width":null,"border":null,"align_items":null,"bottom":null,"_model_module":"@jupyter-widgets/base","top":null,"grid_column":null,"overflow_y":null,"overflow_x":null,"grid_auto_flow":null,"grid_area":null,"grid_template_columns":null,"flex":null,"_model_name":"LayoutModel","justify_items":null,"grid_row":null,"max_height":null,"align_content":null,"visibility":null,"align_self":null,"height":null,"min_height":null,"padding":null,"grid_auto_rows":null,"grid_gap":null,"max_width":null,"order":null,"_view_module_version":"1.2.0","grid_template_areas":null,"object_position":null,"object_fit":null,"grid_auto_columns":null,"margin":null,"display":null,"left":null}},"e041146ae35a49999deecffacd5cdab7":{"model_module":"@jupyter-widgets/controls","model_name":"HBoxModel","model_module_version":"1.5.0","state":{"_view_name":"HBoxView","_dom_classes":[],"_model_name":"HBoxModel","_view_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_view_count":null,"_view_module_version":"1.5.0","box_style":"","layout":"IPY_MODEL_5af56be959294b5a80ebc9321634fcd7","_model_module":"@jupyter-widgets/controls","children":["IPY_MODEL_cc8774f4ec614acc93048ea8d489003f","IPY_MODEL_f1f25d2a879d4fdc87a50f720a5c7c2f","IPY_MODEL_34a5ac426b624121a7abccac0a0a50ee"]}},"5af56be959294b5a80ebc9321634fcd7":{"model_module":"@jupyter-widgets/base","model_name":"LayoutModel","model_module_version":"1.2.0","state":{"_view_name":"LayoutView","grid_template_rows":null,"right":null,"justify_content":null,"_view_module":"@jupyter-widgets/base","overflow":null,"_model_module_version":"1.2.0","_view_count":null,"flex_flow":null,"width":null,"min_width":null,"border":null,"align_items":null,"bottom":null,"_model_module":"@jupyter-widgets/base","top":null,"grid_column":null,"overflow_y":null,"overflow_x":null,"grid_auto_flow":null,"grid_area":null,"grid_template_columns":null,"flex":null,"_model_name":"LayoutModel","justify_items":null,"grid_row":null,"max_height":null,"align_content":null,"visibility":null,"align_self":null,"height":null,"min_height":null,"padding":null,"grid_auto_rows":null,"grid_gap":null,"max_width":null,"order":null,"_view_module_version":"1.2.0","grid_template_areas":null,"object_position":null,"object_fit":null,"grid_auto_columns":null,"margin":null,"display":null,"left":null}},"cc8774f4ec614acc93048ea8d489003f":{"model_module":"@jupyter-widgets/controls","model_name":"HTMLModel","model_module_version":"1.5.0","state":{"_view_name":"HTMLView","style":"IPY_MODEL_e5e2c45b1c8a413aa0d6387a3e13de87","_dom_classes":[],"description":"","_model_name":"HTMLModel","placeholder":"​","_view_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","value":"Downloading: 100%","_view_count":null,"_view_module_version":"1.5.0","description_tooltip":null,"_model_module":"@jupyter-widgets/controls","layout":"IPY_MODEL_5774894d34434f48ac6639389caee424"}},"f1f25d2a879d4fdc87a50f720a5c7c2f":{"model_module":"@jupyter-widgets/controls","model_name":"FloatProgressModel","model_module_version":"1.5.0","state":{"_view_name":"ProgressView","style":"IPY_MODEL_2d5d877e4cbf4ba996bc1f01134c7343","_dom_classes":[],"description":"","_model_name":"FloatProgressModel","bar_style":"success","max":156,"_view_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","value":156,"_view_count":null,"_view_module_version":"1.5.0","orientation":"horizontal","min":0,"description_tooltip":null,"_model_module":"@jupyter-widgets/controls","layout":"IPY_MODEL_e03425cba05742748e55d6f3612c4c19"}},"34a5ac426b624121a7abccac0a0a50ee":{"model_module":"@jupyter-widgets/controls","model_name":"HTMLModel","model_module_version":"1.5.0","state":{"_view_name":"HTMLView","style":"IPY_MODEL_7b95fc0bd04b4f2c9e195f490d1fbb28","_dom_classes":[],"description":"","_model_name":"HTMLModel","placeholder":"​","_view_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","value":" 156/156 [00:00<00:00, 3.34kB/s]","_view_count":null,"_view_module_version":"1.5.0","description_tooltip":null,"_model_module":"@jupyter-widgets/controls","layout":"IPY_MODEL_04809ef19b4b425da061f18b1f4ad5fe"}},"e5e2c45b1c8a413aa0d6387a3e13de87":{"model_module":"@jupyter-widgets/controls","model_name":"DescriptionStyleModel","model_module_version":"1.5.0","state":{"_view_name":"StyleView","_model_name":"DescriptionStyleModel","description_width":"","_view_module":"@jupyter-widgets/base","_model_module_version":"1.5.0","_view_count":null,"_view_module_version":"1.2.0","_model_module":"@jupyter-widgets/controls"}},"5774894d34434f48ac6639389caee424":{"model_module":"@jupyter-widgets/base","model_name":"LayoutModel","model_module_version":"1.2.0","state":{"_view_name":"LayoutView","grid_template_rows":null,"right":null,"justify_content":null,"_view_module":"@jupyter-widgets/base","overflow":null,"_model_module_version":"1.2.0","_view_count":null,"flex_flow":null,"width":null,"min_width":null,"border":null,"align_items":null,"bottom":null,"_model_module":"@jupyter-widgets/base","top":null,"grid_column":null,"overflow_y":null,"overflow_x":null,"grid_auto_flow":null,"grid_area":null,"grid_template_columns":null,"flex":null,"_model_name":"LayoutModel","justify_items":null,"grid_row":null,"max_height":null,"align_content":null,"visibility":null,"align_self":null,"height":null,"min_height":null,"padding":null,"grid_auto_rows":null,"grid_gap":null,"max_width":null,"order":null,"_view_module_version":"1.2.0","grid_template_areas":null,"object_position":null,"object_fit":null,"grid_auto_columns":null,"margin":null,"display":null,"left":null}},"2d5d877e4cbf4ba996bc1f01134c7343":{"model_module":"@jupyter-widgets/controls","model_name":"ProgressStyleModel","model_module_version":"1.5.0","state":{"_view_name":"StyleView","_model_name":"ProgressStyleModel","description_width":"","_view_module":"@jupyter-widgets/base","_model_module_version":"1.5.0","_view_count":null,"_view_module_version":"1.2.0","bar_color":null,"_model_module":"@jupyter-widgets/controls"}},"e03425cba05742748e55d6f3612c4c19":{"model_module":"@jupyter-widgets/base","model_name":"LayoutModel","model_module_version":"1.2.0","state":{"_view_name":"LayoutView","grid_template_rows":null,"right":null,"justify_content":null,"_view_module":"@jupyter-widgets/base","overflow":null,"_model_module_version":"1.2.0","_view_count":null,"flex_flow":null,"width":null,"min_width":null,"border":null,"align_items":null,"bottom":null,"_model_module":"@jupyter-widgets/base","top":null,"grid_column":null,"overflow_y":null,"overflow_x":null,"grid_auto_flow":null,"grid_area":null,"grid_template_columns":null,"flex":null,"_model_name":"LayoutModel","justify_items":null,"grid_row":null,"max_height":null,"align_content":null,"visibility":null,"align_self":null,"height":null,"min_height":null,"padding":null,"grid_auto_rows":null,"grid_gap":null,"max_width":null,"order":null,"_view_module_version":"1.2.0","grid_template_areas":null,"object_position":null,"object_fit":null,"grid_auto_columns":null,"margin":null,"display":null,"left":null}},"7b95fc0bd04b4f2c9e195f490d1fbb28":{"model_module":"@jupyter-widgets/controls","model_name":"DescriptionStyleModel","model_module_version":"1.5.0","state":{"_view_name":"StyleView","_model_name":"DescriptionStyleModel","description_width":"","_view_module":"@jupyter-widgets/base","_model_module_version":"1.5.0","_view_count":null,"_view_module_version":"1.2.0","_model_module":"@jupyter-widgets/controls"}},"04809ef19b4b425da061f18b1f4ad5fe":{"model_module":"@jupyter-widgets/base","model_name":"LayoutModel","model_module_version":"1.2.0","state":{"_view_name":"LayoutView","grid_template_rows":null,"right":null,"justify_content":null,"_view_module":"@jupyter-widgets/base","overflow":null,"_model_module_version":"1.2.0","_view_count":null,"flex_flow":null,"width":null,"min_width":null,"border":null,"align_items":null,"bottom":null,"_model_module":"@jupyter-widgets/base","top":null,"grid_column":null,"overflow_y":null,"overflow_x":null,"grid_auto_flow":null,"grid_area":null,"grid_template_columns":null,"flex":null,"_model_name":"LayoutModel","justify_items":null,"grid_row":null,"max_height":null,"align_content":null,"visibility":null,"align_self":null,"height":null,"min_height":null,"padding":null,"grid_auto_rows":null,"grid_gap":null,"max_width":null,"order":null,"_view_module_version":"1.2.0","grid_template_areas":null,"object_position":null,"object_fit":null,"grid_auto_columns":null,"margin":null,"display":null,"left":null}},"ac16804ce3a8403893b113a9f3161419":{"model_module":"@jupyter-widgets/controls","model_name":"HBoxModel","model_module_version":"1.5.0","state":{"_view_name":"HBoxView","_dom_classes":[],"_model_name":"HBoxModel","_view_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_view_count":null,"_view_module_version":"1.5.0","box_style":"","layout":"IPY_MODEL_b5fdd12f586f4dc0b37847712b60966c","_model_module":"@jupyter-widgets/controls","children":["IPY_MODEL_e93b366efe1e4413880e274cbbf9e8e0","IPY_MODEL_c8a3916a79b84e15ac121e27e6e3babd","IPY_MODEL_07eb68ff9aea4331a995690b42734543"]}},"b5fdd12f586f4dc0b37847712b60966c":{"model_module":"@jupyter-widgets/base","model_name":"LayoutModel","model_module_version":"1.2.0","state":{"_view_name":"LayoutView","grid_template_rows":null,"right":null,"justify_content":null,"_view_module":"@jupyter-widgets/base","overflow":null,"_model_module_version":"1.2.0","_view_count":null,"flex_flow":null,"width":null,"min_width":null,"border":null,"align_items":null,"bottom":null,"_model_module":"@jupyter-widgets/base","top":null,"grid_column":null,"overflow_y":null,"overflow_x":null,"grid_auto_flow":null,"grid_area":null,"grid_template_columns":null,"flex":null,"_model_name":"LayoutModel","justify_items":null,"grid_row":null,"max_height":null,"align_content":null,"visibility":null,"align_self":null,"height":null,"min_height":null,"padding":null,"grid_auto_rows":null,"grid_gap":null,"max_width":null,"order":null,"_view_module_version":"1.2.0","grid_template_areas":null,"object_position":null,"object_fit":null,"grid_auto_columns":null,"margin":null,"display":null,"left":null}},"e93b366efe1e4413880e274cbbf9e8e0":{"model_module":"@jupyter-widgets/controls","model_name":"HTMLModel","model_module_version":"1.5.0","state":{"_view_name":"HTMLView","style":"IPY_MODEL_168d8822412746d19b5b19068fca4d12","_dom_classes":[],"description":"","_model_name":"HTMLModel","placeholder":"​","_view_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","value":"Downloading: 100%","_view_count":null,"_view_module_version":"1.5.0","description_tooltip":null,"_model_module":"@jupyter-widgets/controls","layout":"IPY_MODEL_6311a24e132741e5baaf195128379f0c"}},"c8a3916a79b84e15ac121e27e6e3babd":{"model_module":"@jupyter-widgets/controls","model_name":"FloatProgressModel","model_module_version":"1.5.0","state":{"_view_name":"ProgressView","style":"IPY_MODEL_bd00331257854066a8ee4636c8ff0abf","_dom_classes":[],"description":"","_model_name":"FloatProgressModel","bar_style":"success","max":22,"_view_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","value":22,"_view_count":null,"_view_module_version":"1.5.0","orientation":"horizontal","min":0,"description_tooltip":null,"_model_module":"@jupyter-widgets/controls","layout":"IPY_MODEL_161edabbd96345fc99eaacebd167a280"}},"07eb68ff9aea4331a995690b42734543":{"model_module":"@jupyter-widgets/controls","model_name":"HTMLModel","model_module_version":"1.5.0","state":{"_view_name":"HTMLView","style":"IPY_MODEL_b15e75305b4e4ce4a2eea55f6f3f2d3a","_dom_classes":[],"description":"","_model_name":"HTMLModel","placeholder":"​","_view_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","value":" 22.0/22.0 [00:00<00:00, 452B/s]","_view_count":null,"_view_module_version":"1.5.0","description_tooltip":null,"_model_module":"@jupyter-widgets/controls","layout":"IPY_MODEL_d56e2ac86eba4d90a2a93353d30522e1"}},"168d8822412746d19b5b19068fca4d12":{"model_module":"@jupyter-widgets/controls","model_name":"DescriptionStyleModel","model_module_version":"1.5.0","state":{"_view_name":"StyleView","_model_name":"DescriptionStyleModel","description_width":"","_view_module":"@jupyter-widgets/base","_model_module_version":"1.5.0","_view_count":null,"_view_module_version":"1.2.0","_model_module":"@jupyter-widgets/controls"}},"6311a24e132741e5baaf195128379f0c":{"model_module":"@jupyter-widgets/base","model_name":"LayoutModel","model_module_version":"1.2.0","state":{"_view_name":"LayoutView","grid_template_rows":null,"right":null,"justify_content":null,"_view_module":"@jupyter-widgets/base","overflow":null,"_model_module_version":"1.2.0","_view_count":null,"flex_flow":null,"width":null,"min_width":null,"border":null,"align_items":null,"bottom":null,"_model_module":"@jupyter-widgets/base","top":null,"grid_column":null,"overflow_y":null,"overflow_x":null,"grid_auto_flow":null,"grid_area":null,"grid_template_columns":null,"flex":null,"_model_name":"LayoutModel","justify_items":null,"grid_row":null,"max_height":null,"align_content":null,"visibility":null,"align_self":null,"height":null,"min_height":null,"padding":null,"grid_auto_rows":null,"grid_gap":null,"max_width":null,"order":null,"_view_module_version":"1.2.0","grid_template_areas":null,"object_position":null,"object_fit":null,"grid_auto_columns":null,"margin":null,"display":null,"left":null}},"bd00331257854066a8ee4636c8ff0abf":{"model_module":"@jupyter-widgets/controls","model_name":"ProgressStyleModel","model_module_version":"1.5.0","state":{"_view_name":"StyleView","_model_name":"ProgressStyleModel","description_width":"","_view_module":"@jupyter-widgets/base","_model_module_version":"1.5.0","_view_count":null,"_view_module_version":"1.2.0","bar_color":null,"_model_module":"@jupyter-widgets/controls"}},"161edabbd96345fc99eaacebd167a280":{"model_module":"@jupyter-widgets/base","model_name":"LayoutModel","model_module_version":"1.2.0","state":{"_view_name":"LayoutView","grid_template_rows":null,"right":null,"justify_content":null,"_view_module":"@jupyter-widgets/base","overflow":null,"_model_module_version":"1.2.0","_view_count":null,"flex_flow":null,"width":null,"min_width":null,"border":null,"align_items":null,"bottom":null,"_model_module":"@jupyter-widgets/base","top":null,"grid_column":null,"overflow_y":null,"overflow_x":null,"grid_auto_flow":null,"grid_area":null,"grid_template_columns":null,"flex":null,"_model_name":"LayoutModel","justify_items":null,"grid_row":null,"max_height":null,"align_content":null,"visibility":null,"align_self":null,"height":null,"min_height":null,"padding":null,"grid_auto_rows":null,"grid_gap":null,"max_width":null,"order":null,"_view_module_version":"1.2.0","grid_template_areas":null,"object_position":null,"object_fit":null,"grid_auto_columns":null,"margin":null,"display":null,"left":null}},"b15e75305b4e4ce4a2eea55f6f3f2d3a":{"model_module":"@jupyter-widgets/controls","model_name":"DescriptionStyleModel","model_module_version":"1.5.0","state":{"_view_name":"StyleView","_model_name":"DescriptionStyleModel","description_width":"","_view_module":"@jupyter-widgets/base","_model_module_version":"1.5.0","_view_count":null,"_view_module_version":"1.2.0","_model_module":"@jupyter-widgets/controls"}},"d56e2ac86eba4d90a2a93353d30522e1":{"model_module":"@jupyter-widgets/base","model_name":"LayoutModel","model_module_version":"1.2.0","state":{"_view_name":"LayoutView","grid_template_rows":null,"right":null,"justify_content":null,"_view_module":"@jupyter-widgets/base","overflow":null,"_model_module_version":"1.2.0","_view_count":null,"flex_flow":null,"width":null,"min_width":null,"border":null,"align_items":null,"bottom":null,"_model_module":"@jupyter-widgets/base","top":null,"grid_column":null,"overflow_y":null,"overflow_x":null,"grid_auto_flow":null,"grid_area":null,"grid_template_columns":null,"flex":null,"_model_name":"LayoutModel","justify_items":null,"grid_row":null,"max_height":null,"align_content":null,"visibility":null,"align_self":null,"height":null,"min_height":null,"padding":null,"grid_auto_rows":null,"grid_gap":null,"max_width":null,"order":null,"_view_module_version":"1.2.0","grid_template_areas":null,"object_position":null,"object_fit":null,"grid_auto_columns":null,"margin":null,"display":null,"left":null}},"3083a8f61e9a4371a0524a5ca17078a9":{"model_module":"@jupyter-widgets/controls","model_name":"HBoxModel","model_module_version":"1.5.0","state":{"_view_name":"HBoxView","_dom_classes":[],"_model_name":"HBoxModel","_view_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_view_count":null,"_view_module_version":"1.5.0","box_style":"","layout":"IPY_MODEL_f0dc83610cfb4f0c8282197dd9d989a1","_model_module":"@jupyter-widgets/controls","children":["IPY_MODEL_6486ca33edd54d60840977fbb2b586c7","IPY_MODEL_45f6924f383748eeadbc497a7a6b2935","IPY_MODEL_4a6a8626c556412083b49654e0fad381"]}},"f0dc83610cfb4f0c8282197dd9d989a1":{"model_module":"@jupyter-widgets/base","model_name":"LayoutModel","model_module_version":"1.2.0","state":{"_view_name":"LayoutView","grid_template_rows":null,"right":null,"justify_content":null,"_view_module":"@jupyter-widgets/base","overflow":null,"_model_module_version":"1.2.0","_view_count":null,"flex_flow":null,"width":null,"min_width":null,"border":null,"align_items":null,"bottom":null,"_model_module":"@jupyter-widgets/base","top":null,"grid_column":null,"overflow_y":null,"overflow_x":null,"grid_auto_flow":null,"grid_area":null,"grid_template_columns":null,"flex":null,"_model_name":"LayoutModel","justify_items":null,"grid_row":null,"max_height":null,"align_content":null,"visibility":null,"align_self":null,"height":null,"min_height":null,"padding":null,"grid_auto_rows":null,"grid_gap":null,"max_width":null,"order":null,"_view_module_version":"1.2.0","grid_template_areas":null,"object_position":null,"object_fit":null,"grid_auto_columns":null,"margin":null,"display":null,"left":null}},"6486ca33edd54d60840977fbb2b586c7":{"model_module":"@jupyter-widgets/controls","model_name":"HTMLModel","model_module_version":"1.5.0","state":{"_view_name":"HTMLView","style":"IPY_MODEL_a3913ac763b544afbe89c984ad0f58f4","_dom_classes":[],"description":"","_model_name":"HTMLModel","placeholder":"​","_view_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","value":"Downloading: 100%","_view_count":null,"_view_module_version":"1.5.0","description_tooltip":null,"_model_module":"@jupyter-widgets/controls","layout":"IPY_MODEL_669f172cac364e21bc9b715f60cbf94b"}},"45f6924f383748eeadbc497a7a6b2935":{"model_module":"@jupyter-widgets/controls","model_name":"FloatProgressModel","model_module_version":"1.5.0","state":{"_view_name":"ProgressView","style":"IPY_MODEL_ebe2eacf7f744f5a889ace41cf9b4d91","_dom_classes":[],"description":"","_model_name":"FloatProgressModel","bar_style":"success","max":1232,"_view_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","value":1232,"_view_count":null,"_view_module_version":"1.5.0","orientation":"horizontal","min":0,"description_tooltip":null,"_model_module":"@jupyter-widgets/controls","layout":"IPY_MODEL_63b5910962a84221b36bc88bceebcbb4"}},"4a6a8626c556412083b49654e0fad381":{"model_module":"@jupyter-widgets/controls","model_name":"HTMLModel","model_module_version":"1.5.0","state":{"_view_name":"HTMLView","style":"IPY_MODEL_923d3137698e44a39f6f4aaf3cbc4d21","_dom_classes":[],"description":"","_model_name":"HTMLModel","placeholder":"​","_view_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","value":" 1.20k/1.20k [00:00<00:00, 31.5kB/s]","_view_count":null,"_view_module_version":"1.5.0","description_tooltip":null,"_model_module":"@jupyter-widgets/controls","layout":"IPY_MODEL_7f5c6fdbcdbe430199b76f9decf89443"}},"a3913ac763b544afbe89c984ad0f58f4":{"model_module":"@jupyter-widgets/controls","model_name":"DescriptionStyleModel","model_module_version":"1.5.0","state":{"_view_name":"StyleView","_model_name":"DescriptionStyleModel","description_width":"","_view_module":"@jupyter-widgets/base","_model_module_version":"1.5.0","_view_count":null,"_view_module_version":"1.2.0","_model_module":"@jupyter-widgets/controls"}},"669f172cac364e21bc9b715f60cbf94b":{"model_module":"@jupyter-widgets/base","model_name":"LayoutModel","model_module_version":"1.2.0","state":{"_view_name":"LayoutView","grid_template_rows":null,"right":null,"justify_content":null,"_view_module":"@jupyter-widgets/base","overflow":null,"_model_module_version":"1.2.0","_view_count":null,"flex_flow":null,"width":null,"min_width":null,"border":null,"align_items":null,"bottom":null,"_model_module":"@jupyter-widgets/base","top":null,"grid_column":null,"overflow_y":null,"overflow_x":null,"grid_auto_flow":null,"grid_area":null,"grid_template_columns":null,"flex":null,"_model_name":"LayoutModel","justify_items":null,"grid_row":null,"max_height":null,"align_content":null,"visibility":null,"align_self":null,"height":null,"min_height":null,"padding":null,"grid_auto_rows":null,"grid_gap":null,"max_width":null,"order":null,"_view_module_version":"1.2.0","grid_template_areas":null,"object_position":null,"object_fit":null,"grid_auto_columns":null,"margin":null,"display":null,"left":null}},"ebe2eacf7f744f5a889ace41cf9b4d91":{"model_module":"@jupyter-widgets/controls","model_name":"ProgressStyleModel","model_module_version":"1.5.0","state":{"_view_name":"StyleView","_model_name":"ProgressStyleModel","description_width":"","_view_module":"@jupyter-widgets/base","_model_module_version":"1.5.0","_view_count":null,"_view_module_version":"1.2.0","bar_color":null,"_model_module":"@jupyter-widgets/controls"}},"63b5910962a84221b36bc88bceebcbb4":{"model_module":"@jupyter-widgets/base","model_name":"LayoutModel","model_module_version":"1.2.0","state":{"_view_name":"LayoutView","grid_template_rows":null,"right":null,"justify_content":null,"_view_module":"@jupyter-widgets/base","overflow":null,"_model_module_version":"1.2.0","_view_count":null,"flex_flow":null,"width":null,"min_width":null,"border":null,"align_items":null,"bottom":null,"_model_module":"@jupyter-widgets/base","top":null,"grid_column":null,"overflow_y":null,"overflow_x":null,"grid_auto_flow":null,"grid_area":null,"grid_template_columns":null,"flex":null,"_model_name":"LayoutModel","justify_items":null,"grid_row":null,"max_height":null,"align_content":null,"visibility":null,"align_self":null,"height":null,"min_height":null,"padding":null,"grid_auto_rows":null,"grid_gap":null,"max_width":null,"order":null,"_view_module_version":"1.2.0","grid_template_areas":null,"object_position":null,"object_fit":null,"grid_auto_columns":null,"margin":null,"display":null,"left":null}},"923d3137698e44a39f6f4aaf3cbc4d21":{"model_module":"@jupyter-widgets/controls","model_name":"DescriptionStyleModel","model_module_version":"1.5.0","state":{"_view_name":"StyleView","_model_name":"DescriptionStyleModel","description_width":"","_view_module":"@jupyter-widgets/base","_model_module_version":"1.5.0","_view_count":null,"_view_module_version":"1.2.0","_model_module":"@jupyter-widgets/controls"}},"7f5c6fdbcdbe430199b76f9decf89443":{"model_module":"@jupyter-widgets/base","model_name":"LayoutModel","model_module_version":"1.2.0","state":{"_view_name":"LayoutView","grid_template_rows":null,"right":null,"justify_content":null,"_view_module":"@jupyter-widgets/base","overflow":null,"_model_module_version":"1.2.0","_view_count":null,"flex_flow":null,"width":null,"min_width":null,"border":null,"align_items":null,"bottom":null,"_model_module":"@jupyter-widgets/base","top":null,"grid_column":null,"overflow_y":null,"overflow_x":null,"grid_auto_flow":null,"grid_area":null,"grid_template_columns":null,"flex":null,"_model_name":"LayoutModel","justify_items":null,"grid_row":null,"max_height":null,"align_content":null,"visibility":null,"align_self":null,"height":null,"min_height":null,"padding":null,"grid_auto_rows":null,"grid_gap":null,"max_width":null,"order":null,"_view_module_version":"1.2.0","grid_template_areas":null,"object_position":null,"object_fit":null,"grid_auto_columns":null,"margin":null,"display":null,"left":null}},"110e98d7d9414e07a8324de854e65310":{"model_module":"@jupyter-widgets/controls","model_name":"HBoxModel","model_module_version":"1.5.0","state":{"_view_name":"HBoxView","_dom_classes":[],"_model_name":"HBoxModel","_view_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_view_count":null,"_view_module_version":"1.5.0","box_style":"","layout":"IPY_MODEL_e7efe75c28f04445ba0fcf46857fabac","_model_module":"@jupyter-widgets/controls","children":["IPY_MODEL_72bbdf8757644d64ab54c6527f86fd30","IPY_MODEL_b8153323ab2945e0ae104ce18349955c","IPY_MODEL_1ef7289ade1244ccaf8084f0b6211dd2"]}},"e7efe75c28f04445ba0fcf46857fabac":{"model_module":"@jupyter-widgets/base","model_name":"LayoutModel","model_module_version":"1.2.0","state":{"_view_name":"LayoutView","grid_template_rows":null,"right":null,"justify_content":null,"_view_module":"@jupyter-widgets/base","overflow":null,"_model_module_version":"1.2.0","_view_count":null,"flex_flow":null,"width":null,"min_width":null,"border":null,"align_items":null,"bottom":null,"_model_module":"@jupyter-widgets/base","top":null,"grid_column":null,"overflow_y":null,"overflow_x":null,"grid_auto_flow":null,"grid_area":null,"grid_template_columns":null,"flex":null,"_model_name":"LayoutModel","justify_items":null,"grid_row":null,"max_height":null,"align_content":null,"visibility":null,"align_self":null,"height":null,"min_height":null,"padding":null,"grid_auto_rows":null,"grid_gap":null,"max_width":null,"order":null,"_view_module_version":"1.2.0","grid_template_areas":null,"object_position":null,"object_fit":null,"grid_auto_columns":null,"margin":null,"display":null,"left":null}},"72bbdf8757644d64ab54c6527f86fd30":{"model_module":"@jupyter-widgets/controls","model_name":"HTMLModel","model_module_version":"1.5.0","state":{"_view_name":"HTMLView","style":"IPY_MODEL_ed34f7fb01bf414597534561f6d2b887","_dom_classes":[],"description":"","_model_name":"HTMLModel","placeholder":"​","_view_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","value":"Downloading: 100%","_view_count":null,"_view_module_version":"1.5.0","description_tooltip":null,"_model_module":"@jupyter-widgets/controls","layout":"IPY_MODEL_6b839734aa474e818806ed9fc56dde9a"}},"b8153323ab2945e0ae104ce18349955c":{"model_module":"@jupyter-widgets/controls","model_name":"FloatProgressModel","model_module_version":"1.5.0","state":{"_view_name":"ProgressView","style":"IPY_MODEL_e9494f39659d44e8b5d93d8dac34dd94","_dom_classes":[],"description":"","_model_name":"FloatProgressModel","bar_style":"success","max":48288230,"_view_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","value":48288230,"_view_count":null,"_view_module_version":"1.5.0","orientation":"horizontal","min":0,"description_tooltip":null,"_model_module":"@jupyter-widgets/controls","layout":"IPY_MODEL_02fe1305013742dbaad56717c6301d3a"}},"1ef7289ade1244ccaf8084f0b6211dd2":{"model_module":"@jupyter-widgets/controls","model_name":"HTMLModel","model_module_version":"1.5.0","state":{"_view_name":"HTMLView","style":"IPY_MODEL_7aae80ca599c4a40ba801eae7d296833","_dom_classes":[],"description":"","_model_name":"HTMLModel","placeholder":"​","_view_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","value":" 46.1M/46.1M [00:01<00:00, 29.6MB/s]","_view_count":null,"_view_module_version":"1.5.0","description_tooltip":null,"_model_module":"@jupyter-widgets/controls","layout":"IPY_MODEL_36b314c290984f57ad5b7357e1971576"}},"ed34f7fb01bf414597534561f6d2b887":{"model_module":"@jupyter-widgets/controls","model_name":"DescriptionStyleModel","model_module_version":"1.5.0","state":{"_view_name":"StyleView","_model_name":"DescriptionStyleModel","description_width":"","_view_module":"@jupyter-widgets/base","_model_module_version":"1.5.0","_view_count":null,"_view_module_version":"1.2.0","_model_module":"@jupyter-widgets/controls"}},"6b839734aa474e818806ed9fc56dde9a":{"model_module":"@jupyter-widgets/base","model_name":"LayoutModel","model_module_version":"1.2.0","state":{"_view_name":"LayoutView","grid_template_rows":null,"right":null,"justify_content":null,"_view_module":"@jupyter-widgets/base","overflow":null,"_model_module_version":"1.2.0","_view_count":null,"flex_flow":null,"width":null,"min_width":null,"border":null,"align_items":null,"bottom":null,"_model_module":"@jupyter-widgets/base","top":null,"grid_column":null,"overflow_y":null,"overflow_x":null,"grid_auto_flow":null,"grid_area":null,"grid_template_columns":null,"flex":null,"_model_name":"LayoutModel","justify_items":null,"grid_row":null,"max_height":null,"align_content":null,"visibility":null,"align_self":null,"height":null,"min_height":null,"padding":null,"grid_auto_rows":null,"grid_gap":null,"max_width":null,"order":null,"_view_module_version":"1.2.0","grid_template_areas":null,"object_position":null,"object_fit":null,"grid_auto_columns":null,"margin":null,"display":null,"left":null}},"e9494f39659d44e8b5d93d8dac34dd94":{"model_module":"@jupyter-widgets/controls","model_name":"ProgressStyleModel","model_module_version":"1.5.0","state":{"_view_name":"StyleView","_model_name":"ProgressStyleModel","description_width":"","_view_module":"@jupyter-widgets/base","_model_module_version":"1.5.0","_view_count":null,"_view_module_version":"1.2.0","bar_color":null,"_model_module":"@jupyter-widgets/controls"}},"02fe1305013742dbaad56717c6301d3a":{"model_module":"@jupyter-widgets/base","model_name":"LayoutModel","model_module_version":"1.2.0","state":{"_view_name":"LayoutView","grid_template_rows":null,"right":null,"justify_content":null,"_view_module":"@jupyter-widgets/base","overflow":null,"_model_module_version":"1.2.0","_view_count":null,"flex_flow":null,"width":null,"min_width":null,"border":null,"align_items":null,"bottom":null,"_model_module":"@jupyter-widgets/base","top":null,"grid_column":null,"overflow_y":null,"overflow_x":null,"grid_auto_flow":null,"grid_area":null,"grid_template_columns":null,"flex":null,"_model_name":"LayoutModel","justify_items":null,"grid_row":null,"max_height":null,"align_content":null,"visibility":null,"align_self":null,"height":null,"min_height":null,"padding":null,"grid_auto_rows":null,"grid_gap":null,"max_width":null,"order":null,"_view_module_version":"1.2.0","grid_template_areas":null,"object_position":null,"object_fit":null,"grid_auto_columns":null,"margin":null,"display":null,"left":null}},"7aae80ca599c4a40ba801eae7d296833":{"model_module":"@jupyter-widgets/controls","model_name":"DescriptionStyleModel","model_module_version":"1.5.0","state":{"_view_name":"StyleView","_model_name":"DescriptionStyleModel","description_width":"","_view_module":"@jupyter-widgets/base","_model_module_version":"1.5.0","_view_count":null,"_view_module_version":"1.2.0","_model_module":"@jupyter-widgets/controls"}},"36b314c290984f57ad5b7357e1971576":{"model_module":"@jupyter-widgets/base","model_name":"LayoutModel","model_module_version":"1.2.0","state":{"_view_name":"LayoutView","grid_template_rows":null,"right":null,"justify_content":null,"_view_module":"@jupyter-widgets/base","overflow":null,"_model_module_version":"1.2.0","_view_count":null,"flex_flow":null,"width":null,"min_width":null,"border":null,"align_items":null,"bottom":null,"_model_module":"@jupyter-widgets/base","top":null,"grid_column":null,"overflow_y":null,"overflow_x":null,"grid_auto_flow":null,"grid_area":null,"grid_template_columns":null,"flex":null,"_model_name":"LayoutModel","justify_items":null,"grid_row":null,"max_height":null,"align_content":null,"visibility":null,"align_self":null,"height":null,"min_height":null,"padding":null,"grid_auto_rows":null,"grid_gap":null,"max_width":null,"order":null,"_view_module_version":"1.2.0","grid_template_areas":null,"object_position":null,"object_fit":null,"grid_auto_columns":null,"margin":null,"display":null,"left":null}}}}},"cells":[{"cell_type":"code","metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"WoYfd-f_Be1P","executionInfo":{"status":"ok","timestamp":1635414755294,"user_tz":-540,"elapsed":1614,"user":{"displayName":"乗松潤矢","photoUrl":"https://lh3.googleusercontent.com/a/default-user=s64","userId":"14372592260560654082"}},"outputId":"ecef8ac6-df45-4568-cfcf-1c05a834650e"},"source":["!wget https://github.com/jnory/datasets/raw/master/wikinews/jawikinews.json"],"execution_count":1,"outputs":[{"output_type":"stream","name":"stdout","text":["--2021-10-28 09:52:33-- https://github.com/jnory/datasets/raw/master/wikinews/jawikinews.json\n","Resolving github.com (github.com)... 140.82.114.3\n","Connecting to github.com (github.com)|140.82.114.3|:443... connected.\n","HTTP request sent, awaiting response... 302 Found\n","Location: https://raw.githubusercontent.com/jnory/datasets/master/wikinews/jawikinews.json [following]\n","--2021-10-28 09:52:34-- https://raw.githubusercontent.com/jnory/datasets/master/wikinews/jawikinews.json\n","Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...\n","Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.\n","HTTP request sent, awaiting response... 200 OK\n","Length: 6796960 (6.5M) [text/plain]\n","Saving to: ‘jawikinews.json’\n","\n","jawikinews.json 100%[===================>] 6.48M --.-KB/s in 0.08s \n","\n","2021-10-28 09:52:34 (76.4 MB/s) - ‘jawikinews.json’ saved [6796960/6796960]\n","\n"]}]},{"cell_type":"code","metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"y0-ltct2BHXc","executionInfo":{"status":"ok","timestamp":1635414785336,"user_tz":-540,"elapsed":30051,"user":{"displayName":"乗松潤矢","photoUrl":"https://lh3.googleusercontent.com/a/default-user=s64","userId":"14372592260560654082"}},"outputId":"f99eb5c4-6371-42f7-bb6f-56bb1bd1f7c1"},"source":["!pip install transformers sentencepiece fugashi unidic-lite"],"execution_count":2,"outputs":[{"output_type":"stream","name":"stdout","text":["Collecting transformers\n"," Downloading transformers-4.11.3-py3-none-any.whl (2.9 MB)\n","\u001b[K |████████████████████████████████| 2.9 MB 5.3 MB/s \n","\u001b[?25hCollecting sentencepiece\n"," Downloading sentencepiece-0.1.96-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.2 MB)\n","\u001b[K |████████████████████████████████| 1.2 MB 27.9 MB/s \n","\u001b[?25hCollecting fugashi\n"," Downloading fugashi-1.1.1-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.whl (490 kB)\n","\u001b[K |████████████████████████████████| 490 kB 44.6 MB/s \n","\u001b[?25hCollecting unidic-lite\n"," Downloading unidic-lite-1.0.8.tar.gz (47.4 MB)\n","\u001b[K |████████████████████████████████| 47.4 MB 53 kB/s \n","\u001b[?25hCollecting pyyaml>=5.1\n"," Downloading PyYAML-6.0-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl (596 kB)\n","\u001b[K |████████████████████████████████| 596 kB 34.6 MB/s \n","\u001b[?25hRequirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.7/dist-packages (from transformers) (2019.12.20)\n","Requirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.7/dist-packages (from transformers) (1.19.5)\n","Collecting tokenizers<0.11,>=0.10.1\n"," Downloading tokenizers-0.10.3-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl (3.3 MB)\n","\u001b[K |████████████████████████████████| 3.3 MB 29.3 MB/s \n","\u001b[?25hCollecting huggingface-hub>=0.0.17\n"," Downloading huggingface_hub-0.0.19-py3-none-any.whl (56 kB)\n","\u001b[K |████████████████████████████████| 56 kB 4.2 MB/s \n","\u001b[?25hRequirement already satisfied: tqdm>=4.27 in /usr/local/lib/python3.7/dist-packages (from transformers) (4.62.3)\n","Requirement already satisfied: requests in /usr/local/lib/python3.7/dist-packages (from transformers) (2.23.0)\n","Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.7/dist-packages (from transformers) (21.0)\n","Requirement already satisfied: importlib-metadata in /usr/local/lib/python3.7/dist-packages (from transformers) (4.8.1)\n","Requirement already satisfied: filelock in /usr/local/lib/python3.7/dist-packages (from transformers) (3.3.0)\n","Collecting sacremoses\n"," Downloading sacremoses-0.0.46-py3-none-any.whl (895 kB)\n","\u001b[K |████████████████████████████████| 895 kB 23.5 MB/s \n","\u001b[?25hRequirement already satisfied: typing-extensions in /usr/local/lib/python3.7/dist-packages (from huggingface-hub>=0.0.17->transformers) (3.7.4.3)\n","Requirement already satisfied: pyparsing>=2.0.2 in /usr/local/lib/python3.7/dist-packages (from packaging>=20.0->transformers) (2.4.7)\n","Requirement already satisfied: zipp>=0.5 in /usr/local/lib/python3.7/dist-packages (from importlib-metadata->transformers) (3.6.0)\n","Requirement already satisfied: idna<3,>=2.5 in /usr/local/lib/python3.7/dist-packages (from requests->transformers) (2.10)\n","Requirement already satisfied: chardet<4,>=3.0.2 in /usr/local/lib/python3.7/dist-packages (from requests->transformers) (3.0.4)\n","Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.7/dist-packages (from requests->transformers) (2021.5.30)\n","Requirement already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /usr/local/lib/python3.7/dist-packages (from requests->transformers) (1.24.3)\n","Requirement already satisfied: joblib in /usr/local/lib/python3.7/dist-packages (from sacremoses->transformers) (1.0.1)\n","Requirement already satisfied: click in /usr/local/lib/python3.7/dist-packages (from sacremoses->transformers) (7.1.2)\n","Requirement already satisfied: six in /usr/local/lib/python3.7/dist-packages (from sacremoses->transformers) (1.15.0)\n","Building wheels for collected packages: unidic-lite\n"," Building wheel for unidic-lite (setup.py) ... \u001b[?25l\u001b[?25hdone\n"," Created wheel for unidic-lite: filename=unidic_lite-1.0.8-py3-none-any.whl size=47658836 sha256=9b557d7c5b555fe06d228e910702779e78330db245e029a9e17017f86bc06c0d\n"," Stored in directory: /root/.cache/pip/wheels/de/69/b1/112140b599f2b13f609d485a99e357ba68df194d2079c5b1a2\n","Successfully built unidic-lite\n","Installing collected packages: pyyaml, tokenizers, sacremoses, huggingface-hub, unidic-lite, transformers, sentencepiece, fugashi\n"," Attempting uninstall: pyyaml\n"," Found existing installation: PyYAML 3.13\n"," Uninstalling PyYAML-3.13:\n"," Successfully uninstalled PyYAML-3.13\n","Successfully installed fugashi-1.1.1 huggingface-hub-0.0.19 pyyaml-6.0 sacremoses-0.0.46 sentencepiece-0.1.96 tokenizers-0.10.3 transformers-4.11.3 unidic-lite-1.0.8\n"]}]},{"cell_type":"code","metadata":{"id":"zqb-RkX7A_No","executionInfo":{"status":"ok","timestamp":1635414813758,"user_tz":-540,"elapsed":28430,"user":{"displayName":"乗松潤矢","photoUrl":"https://lh3.googleusercontent.com/a/default-user=s64","userId":"14372592260560654082"}}},"source":["import json\n","import random\n","\n","import torch\n","import numpy as np\n","import seaborn as sns\n","import pandas as pd\n","from sklearn.model_selection import train_test_split\n","from torch.utils.data import Dataset, DataLoader\n","from transformers import PreTrainedTokenizer, AdamW, get_linear_schedule_with_warmup\n","from transformers import AlbertTokenizer, AlbertForSequenceClassification"],"execution_count":3,"outputs":[]},{"cell_type":"code","metadata":{"id":"8uPHO4ptBC7y","executionInfo":{"status":"ok","timestamp":1635414813760,"user_tz":-540,"elapsed":25,"user":{"displayName":"乗松潤矢","photoUrl":"https://lh3.googleusercontent.com/a/default-user=s64","userId":"14372592260560654082"}}},"source":["class WikinewsDataset(Dataset):\n"," def __init__(self, corpus):\n"," self.data = [entry for entry in corpus if len(entry[\"categories\"]) == 1]\n","\n"," def __len__(self):\n"," return len(self.data)\n","\n"," def __getitem__(self, index):\n"," return self.data[index]"],"execution_count":4,"outputs":[]},{"cell_type":"code","metadata":{"id":"AURVmgW_BRE3","executionInfo":{"status":"ok","timestamp":1635414813761,"user_tz":-540,"elapsed":23,"user":{"displayName":"乗松潤矢","photoUrl":"https://lh3.googleusercontent.com/a/default-user=s64","userId":"14372592260560654082"}}},"source":["cat_ids = {\n"," \"スポーツ\": 0,\n"," \"事件\": 1,\n"," \"事故\": 2,\n"," \"学術\": 3,\n"," \"政治\": 4,\n"," \"文化\": 5,\n"," \"気象\": 6,\n"," \"社会\": 7,\n"," \"経済\": 8,\n"," \"野球\": 0, # mix into sports\n","}\n","num_labels = 9\n","\n","class CollateFN:\n"," def __init__(self, tokenizer: PreTrainedTokenizer, device):\n"," self.tokenizer = tokenizer\n"," self.device = device\n","\n"," def __call__(self, entries):\n"," texts = [entry[\"text\"] for entry in entries]\n"," categories = [entry[\"categories\"][0] for entry in entries]\n"," categories = torch.tensor(\n"," [cat_ids[c] for c in categories], dtype=torch.long).to(self.device)\n"," batch = self.tokenizer.batch_encode_plus(\n"," texts, padding=True, truncation=True, max_length=512,\n"," return_tensors=\"pt\", return_attention_mask=True, return_token_type_ids=True,\n"," ).to(self.device)\n"," return batch, categories"],"execution_count":5,"outputs":[]},{"cell_type":"code","metadata":{"id":"D_aH72v5GZx4","executionInfo":{"status":"ok","timestamp":1635414813762,"user_tz":-540,"elapsed":20,"user":{"displayName":"乗松潤矢","photoUrl":"https://lh3.googleusercontent.com/a/default-user=s64","userId":"14372592260560654082"}}},"source":["with open(\"jawikinews.json\", encoding=\"utf-8\") as fp:\n"," corpus = json.loads(fp.read())"],"execution_count":6,"outputs":[]},{"cell_type":"code","metadata":{"id":"F9yc_UJVPAPH","executionInfo":{"status":"ok","timestamp":1635414813763,"user_tz":-540,"elapsed":18,"user":{"displayName":"乗松潤矢","photoUrl":"https://lh3.googleusercontent.com/a/default-user=s64","userId":"14372592260560654082"}}},"source":["# model = None\n","# import gc\n","# gc.collect()"],"execution_count":7,"outputs":[]},{"cell_type":"code","metadata":{"colab":{"base_uri":"https://localhost:8080/","height":1000,"referenced_widgets":["eaadb0acc47e4bcd881f8ba4864d2b42","22eb56149e584897bdd438ca62669267","de4e5eec3f9846d897ec1559e50d6cc1","e66b1567e47d4e56aae7681c694d2c18","0dad0162cfa645f28e70a0d561ac7497","1816c2ac4e1e4eec8e5575311a5e404f","4e3e0856c0bf4f51b3dd33403e73aeb3","17ae15d338e7489384da7510449d6aed","30ed416a6b1249199a4189306b2b4575","c07a4fa96f6b43c59af1f143c7f646b7","6cc83ca859124b8cb9aa9c66fc91939b","e041146ae35a49999deecffacd5cdab7","5af56be959294b5a80ebc9321634fcd7","cc8774f4ec614acc93048ea8d489003f","f1f25d2a879d4fdc87a50f720a5c7c2f","34a5ac426b624121a7abccac0a0a50ee","e5e2c45b1c8a413aa0d6387a3e13de87","5774894d34434f48ac6639389caee424","2d5d877e4cbf4ba996bc1f01134c7343","e03425cba05742748e55d6f3612c4c19","7b95fc0bd04b4f2c9e195f490d1fbb28","04809ef19b4b425da061f18b1f4ad5fe","ac16804ce3a8403893b113a9f3161419","b5fdd12f586f4dc0b37847712b60966c","e93b366efe1e4413880e274cbbf9e8e0","c8a3916a79b84e15ac121e27e6e3babd","07eb68ff9aea4331a995690b42734543","168d8822412746d19b5b19068fca4d12","6311a24e132741e5baaf195128379f0c","bd00331257854066a8ee4636c8ff0abf","161edabbd96345fc99eaacebd167a280","b15e75305b4e4ce4a2eea55f6f3f2d3a","d56e2ac86eba4d90a2a93353d30522e1","3083a8f61e9a4371a0524a5ca17078a9","f0dc83610cfb4f0c8282197dd9d989a1","6486ca33edd54d60840977fbb2b586c7","45f6924f383748eeadbc497a7a6b2935","4a6a8626c556412083b49654e0fad381","a3913ac763b544afbe89c984ad0f58f4","669f172cac364e21bc9b715f60cbf94b","ebe2eacf7f744f5a889ace41cf9b4d91","63b5910962a84221b36bc88bceebcbb4","923d3137698e44a39f6f4aaf3cbc4d21","7f5c6fdbcdbe430199b76f9decf89443","110e98d7d9414e07a8324de854e65310","e7efe75c28f04445ba0fcf46857fabac","72bbdf8757644d64ab54c6527f86fd30","b8153323ab2945e0ae104ce18349955c","1ef7289ade1244ccaf8084f0b6211dd2","ed34f7fb01bf414597534561f6d2b887","6b839734aa474e818806ed9fc56dde9a","e9494f39659d44e8b5d93d8dac34dd94","02fe1305013742dbaad56717c6301d3a","7aae80ca599c4a40ba801eae7d296833","36b314c290984f57ad5b7357e1971576"]},"id":"44m4M7NXBSuY","executionInfo":{"status":"ok","timestamp":1635417817992,"user_tz":-540,"elapsed":3004246,"user":{"displayName":"乗松潤矢","photoUrl":"https://lh3.googleusercontent.com/a/default-user=s64","userId":"14372592260560654082"}},"outputId":"dc572842-db22-4f93-8c1c-82b7757d7e00"},"source":["random.seed(1234)\n","torch.random.manual_seed(1234)\n","np.random.seed(1234)\n","\n","batch_size = 8\n","n_epoch = 10\n","\n","device = \"cuda:0\"\n","\n","tokenizer = AlbertTokenizer.from_pretrained(\"ALINEAR/albert-japanese-v2\")\n","model = AlbertForSequenceClassification.from_pretrained(\"ALINEAR/albert-japanese-v2\", num_labels=num_labels).to(device)\n","\n","train, dev = train_test_split(corpus, test_size=0.05)\n","train = DataLoader(\n"," WikinewsDataset(train), collate_fn=CollateFN(tokenizer, device),\n"," shuffle=True, batch_size=batch_size, drop_last=True)\n","dev = DataLoader(\n"," WikinewsDataset(dev), collate_fn=CollateFN(tokenizer, device),\n"," shuffle=False, batch_size=batch_size)\n","print(\"# of train =\", len(train.dataset.data), \"# of dev=\", len(dev.dataset.data))\n","\n","n_iter = len(train) * n_epoch\n","optimizer = AdamW(model.parameters(), lr=1e-6)\n","scheduler = get_linear_schedule_with_warmup(optimizer, 2 * n_iter // n_epoch, n_iter)\n","print(\"n_iter =\", n_iter)\n","\n","# start training\n","model.train()\n","for epoch in range(n_epoch):\n"," for i, (batch, category) in enumerate(train):\n"," # forward-backward\n"," loss = model(labels=category, **batch).loss\n"," loss.backward()\n"," optimizer.step()\n"," scheduler.step()\n"," if i % 10 == 0:\n"," print(epoch, i, \"loss =\", loss.detach().cpu().numpy())\n","\n"," # evaluate by epoch\n"," with torch.no_grad():\n"," dev_loss = 0.0\n"," model.eval()\n"," for batch, category in dev:\n"," dev_loss += model(labels=category, **batch).loss.detach().cpu().numpy()\n"," print(epoch, \"DEV LOSS =\", dev_loss / len(dev))\n"," model.train()\n","\n","model.save_pretrained(\"news_classifier\")\n","tokenizer.save_pretrained(\"news_classifier\")"],"execution_count":8,"outputs":[{"output_type":"display_data","data":{"application/vnd.jupyter.widget-view+json":{"model_id":"eaadb0acc47e4bcd881f8ba4864d2b42","version_minor":0,"version_major":2},"text/plain":["Downloading: 0%| | 0.00/766k [00:00 ./log\n","\n","# for small data training, use\n","# --train_step=2212 \\\n","# --warmup_step=221 \\"],"execution_count":0,"outputs":[]},{"cell_type":"code","metadata":{"id":"tl6HrfdpPnNx","colab_type":"code","colab":{}},"source":["! tail -n 100 ./log"],"execution_count":0,"outputs":[]},{"cell_type":"code","metadata":{"id":"8JP_sAIk00N1","colab_type":"code","colab":{}},"source":["ls {FINETUNE_OUTPUT_DIR}"],"execution_count":0,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"uGVoXDgNbRYs","colab_type":"text"},"source":["## Predict using the finetuned model\n","\n","Let's predict test data using the finetuned model. "]},{"cell_type":"code","metadata":{"id":"fYp-PZqgbRYs","colab_type":"code","colab":{}},"source":["import sys\n","sys.path.append(\"./src\")\n","\n","from ALBERT import tokenization\n","from run_classifier import LivedoorProcessor\n","from ALBERT.classifier_utils import model_fn_builder\n","from ALBERT.classifier_utils import file_based_input_fn_builder\n","from ALBERT.classifier_utils import file_based_convert_examples_to_features\n","from utils import str_to_value"],"execution_count":0,"outputs":[]},{"cell_type":"code","metadata":{"id":"Ud0AEhvkbRYv","colab_type":"code","colab":{}},"source":["from ALBERT import modeling\n","from ALBERT import optimization\n","import tensorflow as tf"],"execution_count":0,"outputs":[]},{"cell_type":"code","metadata":{"id":"r16NmAbhbRYz","colab_type":"code","colab":{}},"source":["import configparser\n","import json\n","import glob\n","import os\n","import pandas as pd\n","import tempfile\n","\n","albert_config = modeling.AlbertConfig.from_json_file(\"albert_config.json\")"],"execution_count":0,"outputs":[]},{"cell_type":"code","metadata":{"id":"5_E-Nsg_vZCA","colab_type":"code","colab":{}},"source":["!cp -pr {FINETUNE_OUTPUT_DIR} data"],"execution_count":0,"outputs":[]},{"cell_type":"code","metadata":{"id":"Jb_Pexq7bRY1","colab_type":"code","colab":{}},"source":["FINETUNED_MODEL_PATH = os.path.abspath(\"./data/livedoor_output/model.ckpt-best\")\n","# FINETUNED_MODEL_PATH = os.path.abspath(\"./data/livedoor_output_light/model.ckpt-best\")"],"execution_count":0,"outputs":[]},{"cell_type":"code","metadata":{"id":"Ikvo2hE0bRY3","colab_type":"code","colab":{}},"source":["class FLAGS(object):\n"," '''Parameters.'''\n"," def __init__(self):\n"," self.model_file = \"./model/wiki-ja_albert.model\"\n"," self.vocab_file = \"./model/wiki-ja_albert.vocab\"\n"," self.do_lower_case = True\n"," self.use_tpu = False\n"," self.output_dir = \"./data/dummy\"\n"," self.data_dir = EXTRACTDIR\n"," self.max_seq_length = 512\n"," self.init_checkpoint = FINETUNED_MODEL_PATH\n"," self.predict_batch_size = 4\n"," \n"," # The following parameters are not used in predictions.\n"," # Just use to create RunConfig.\n"," self.master = None\n"," self.save_checkpoints_steps = 1\n"," self.iterations_per_loop = 1\n"," self.num_tpu_cores = 1\n"," self.learning_rate = 0\n"," self.num_warmup_steps = 0\n"," self.num_train_steps = 0\n"," self.train_batch_size = 0\n"," self.eval_batch_size = 0"],"execution_count":0,"outputs":[]},{"cell_type":"code","metadata":{"id":"0CZ_5T5ZbRY5","colab_type":"code","colab":{}},"source":["FLAGS = FLAGS()"],"execution_count":0,"outputs":[]},{"cell_type":"code","metadata":{"id":"Kmcrx7hNbRY8","colab_type":"code","colab":{}},"source":["processor = LivedoorProcessor(use_spm=True, do_lower_case=True)\n","label_list = processor.get_labels()"],"execution_count":0,"outputs":[]},{"cell_type":"code","metadata":{"id":"Jcfa6HCsbRY9","colab_type":"code","colab":{}},"source":["tokenizer = tokenization.FullTokenizer(\n"," spm_model_file=FLAGS.model_file, vocab_file=FLAGS.vocab_file,\n"," do_lower_case=FLAGS.do_lower_case)\n","\n","tpu_cluster_resolver = None\n","\n","is_per_host = tf.contrib.tpu.InputPipelineConfig.PER_HOST_V2\n","\n","run_config = tf.contrib.tpu.RunConfig(\n"," cluster=tpu_cluster_resolver,\n"," master=FLAGS.master,\n"," model_dir=FLAGS.output_dir,\n"," save_checkpoints_steps=FLAGS.save_checkpoints_steps,\n"," tpu_config=tf.contrib.tpu.TPUConfig(\n"," iterations_per_loop=FLAGS.iterations_per_loop,\n"," num_shards=FLAGS.num_tpu_cores,\n"," per_host_input_for_training=is_per_host))"],"execution_count":0,"outputs":[]},{"cell_type":"code","metadata":{"id":"7uuY4eWGbRY_","colab_type":"code","colab":{}},"source":["model_fn = model_fn_builder(\n"," albert_config=albert_config,\n"," num_labels=len(label_list),\n"," init_checkpoint=FLAGS.init_checkpoint,\n"," learning_rate=FLAGS.learning_rate,\n"," task_name=\"livedoor\",\n"," num_train_steps=FLAGS.num_train_steps,\n"," num_warmup_steps=FLAGS.num_warmup_steps,\n"," use_tpu=FLAGS.use_tpu,\n"," use_one_hot_embeddings=FLAGS.use_tpu)\n","\n","\n","estimator = tf.contrib.tpu.TPUEstimator(\n"," use_tpu=FLAGS.use_tpu,\n"," model_fn=model_fn,\n"," config=run_config,\n"," train_batch_size=FLAGS.train_batch_size,\n"," eval_batch_size=FLAGS.eval_batch_size,\n"," predict_batch_size=FLAGS.predict_batch_size)"],"execution_count":0,"outputs":[]},{"cell_type":"code","metadata":{"id":"UYm_SPbzbRZB","colab_type":"code","colab":{}},"source":["predict_examples = processor.get_test_examples(FLAGS.data_dir)\n","predict_file = tempfile.NamedTemporaryFile(mode='w+t', encoding='utf-8', suffix='.tf_record')\n","\n","file_based_convert_examples_to_features(predict_examples, label_list,\n"," FLAGS.max_seq_length, tokenizer,\n"," predict_file.name, task_name=\"livedoor\")\n","\n","predict_drop_remainder = True if FLAGS.use_tpu else False\n","\n","predict_input_fn = file_based_input_fn_builder(\n"," input_file=predict_file.name,\n"," seq_length=FLAGS.max_seq_length,\n"," is_training=False,\n"," drop_remainder=predict_drop_remainder, task_name=\"livedoor\", use_tpu=False, bsz=32)"],"execution_count":0,"outputs":[]},{"cell_type":"code","metadata":{"id":"40Cse62RbRZC","colab_type":"code","colab":{}},"source":["result = estimator.predict(input_fn=predict_input_fn)"],"execution_count":0,"outputs":[]},{"cell_type":"code","metadata":{"scrolled":true,"id":"oqZpey6obRZE","colab_type":"code","colab":{}},"source":["%%time\n","# It will take a few hours on CPU environment.\n","\n","result = list(result)"],"execution_count":0,"outputs":[]},{"cell_type":"code","metadata":{"id":"sovDnZz5bRZG","colab_type":"code","colab":{}},"source":["result[:2]"],"execution_count":0,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"7Afahf8EbRZH","colab_type":"text"},"source":["Read test data set and add prediction results."]},{"cell_type":"code","metadata":{"id":"BHzcLZGQbRZI","colab_type":"code","colab":{}},"source":["import pandas as pd"],"execution_count":0,"outputs":[]},{"cell_type":"code","metadata":{"id":"8XwVwsIHbRZJ","colab_type":"code","colab":{}},"source":["test_df = pd.read_csv(os.path.join(EXTRACTDIR, \"test.tsv\"), sep='\\t')"],"execution_count":0,"outputs":[]},{"cell_type":"code","metadata":{"id":"9bLONIAkbRZK","colab_type":"code","colab":{}},"source":["test_df['predict'] = [ label_list[elem['probabilities'].argmax()] for elem in result ]"],"execution_count":0,"outputs":[]},{"cell_type":"code","metadata":{"id":"JsWLi3DwbRZM","colab_type":"code","colab":{}},"source":["test_df.head()"],"execution_count":0,"outputs":[]},{"cell_type":"code","metadata":{"id":"QGuIHWwnbRZN","colab_type":"code","colab":{}},"source":["sum( test_df['label'] == test_df['predict'] ) / len(test_df)"],"execution_count":0,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"FCGoCrx2bRZO","colab_type":"text"},"source":["A littel more detailed check using `sklearn.metrics`."]},{"cell_type":"code","metadata":{"id":"sbf2x_U-bRZO","colab_type":"code","colab":{}},"source":["!pip install scikit-learn"],"execution_count":0,"outputs":[]},{"cell_type":"code","metadata":{"id":"B90lrvULbRZP","colab_type":"code","colab":{}},"source":["from sklearn.metrics import classification_report\n","from sklearn.metrics import confusion_matrix"],"execution_count":0,"outputs":[]},{"cell_type":"code","metadata":{"id":"duYDRM5ebRZR","colab_type":"code","colab":{}},"source":["print(classification_report(test_df['label'], test_df['predict']))"],"execution_count":0,"outputs":[]},{"cell_type":"code","metadata":{"id":"OUgEcLPqbRZR","colab_type":"code","colab":{}},"source":["print(confusion_matrix(test_df['label'], test_df['predict']))"],"execution_count":0,"outputs":[]},{"cell_type":"code","metadata":{"id":"8xekCgB60mWG","colab_type":"code","colab":{}},"source":[""],"execution_count":0,"outputs":[]}]} -------------------------------------------------------------------------------- /pretraining-loss.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alinear-corp/albert-japanese/ec054c5647ec2ca13f9c57f83a238b53cc998255/pretraining-loss.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | sentencepiece 2 | jupyter 3 | tensorflow==1.15.2 4 | pandas 5 | torch 6 | torchvision 7 | transformers -------------------------------------------------------------------------------- /src/convert_tfmodel_to_pytorch.py: -------------------------------------------------------------------------------- 1 | import json 2 | import logging 3 | 4 | from transformers import AlbertConfig, AlbertForMaskedLM, load_tf_weights_in_albert, AlbertTokenizer 5 | 6 | logging.basicConfig(level=logging.DEBUG) 7 | 8 | 9 | def main(args): 10 | with open(args.config) as fp: 11 | data = json.loads(fp.read()) 12 | config = AlbertConfig(**data) 13 | model = AlbertForMaskedLM(config) 14 | model: AlbertForMaskedLM = load_tf_weights_in_albert(model, config, args.checkpoint) 15 | model.save_pretrained(args.output) 16 | 17 | tokenizer = AlbertTokenizer.from_pretrained(args.spiece, keep_accents=True) 18 | tokenizer.save_pretrained(args.output) 19 | 20 | 21 | def get_parser(): 22 | import argparse 23 | parser = argparse.ArgumentParser() 24 | parser.add_argument("config") 25 | parser.add_argument("checkpoint") 26 | parser.add_argument("spiece") 27 | parser.add_argument("output") 28 | return parser 29 | 30 | 31 | if __name__ == "__main__": 32 | main(get_parser().parse_args()) 33 | -------------------------------------------------------------------------------- /src/data-download-and-extract.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import configparser 4 | import os 5 | import subprocess 6 | import sys 7 | from urllib.request import urlretrieve 8 | 9 | CURDIR = os.path.dirname(os.path.abspath(__file__)) 10 | CONFIGPATH = os.path.join(CURDIR, os.pardir, 'config.ini') 11 | config = configparser.ConfigParser() 12 | config.read(CONFIGPATH) 13 | 14 | FILEURL = config['DATA']['FILEURL'] 15 | FILEPATH = config['DATA']['FILEPATH'] 16 | EXTRACTDIR = config['DATA']['TEXTDIR'] 17 | 18 | 19 | def reporthook(blocknum, blocksize, totalsize): 20 | ''' 21 | Callback function to show progress of file downloading. 22 | ''' 23 | readsofar = blocknum * blocksize 24 | if totalsize > 0: 25 | percent = readsofar * 1e2 / totalsize 26 | s = "\r%5.1f%% %*d / %d" % ( 27 | percent, len(str(totalsize)), readsofar, totalsize) 28 | sys.stderr.write(s) 29 | if readsofar >= totalsize: # near the end 30 | sys.stderr.write("\n") 31 | else: # total size is unknown 32 | sys.stderr.write("read %d\n" % (readsofar,)) 33 | 34 | 35 | def download(): 36 | urlretrieve(FILEURL, FILEPATH, reporthook) 37 | 38 | 39 | def extract(): 40 | subprocess.call(['python3', 41 | os.path.join(CURDIR, os.pardir, 42 | 'wikiextractor', 'WikiExtractor.py'), 43 | FILEPATH, "-o={}".format(EXTRACTDIR)]) 44 | 45 | 46 | def main(): 47 | download() 48 | extract() 49 | 50 | 51 | if __name__ == "__main__": 52 | main() 53 | -------------------------------------------------------------------------------- /src/file-preprocessing.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # Read data_text_dir path from a config file. 4 | CURDIR=$(cd $(dirname $0); pwd) 5 | source <(sed -n '/^\[DATA\]/,/^\[/p' ${CURDIR}/../config.ini | grep TEXTDIR | sed 's/ *= */=/g') 6 | 7 | # Text preprocessing. 8 | # 1-1. Remove blank lines. 9 | # 1-2. Remove line with a blank line. 11 | # 2-1. Remove spaces at the end of each line. 12 | # 2-2. Break line at each 。, but not at 。」 or 。), position. 13 | # 2-3. Remove spaces at the head of each line. 14 | # 3. Remove lines with the head 。(these lines are not meaningful). 15 | # 4. Convert upper case characters to lower case ones. 16 | for FILE in $( find ${TEXTDIR} -name "wiki_*" ); do 17 | echo "Processing ${FILE}" 18 | sed -i -e '/^$/d; ///g' ${FILE} 19 | sed -i -e 's/ *$//g; s/。\([^」|)|)|"]\)/。\n\1/g; s/^[ ]*//g' ${FILE} 20 | sed -i -e '/^。/d' ${FILE} 21 | sed -i -e 's/\(.*\)/\L\1/' ${FILE} 22 | done 23 | 24 | # Concat all text files in each text directory. 25 | for DIR in $( find ${TEXTDIR} -mindepth 1 -type d ); do 26 | echo "Processing ${DIR}" 27 | for f in $( find ${DIR} -name "wiki_*" ); do cat $f >> ${DIR}/all.txt; done 28 | done 29 | -------------------------------------------------------------------------------- /src/run_classifier.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # This file is based on https://github.com/google-research/ALBERT/blob/master/run_classifier.py and https://github.com/yoheikikuta/bert-japanese/blob/master/src/run_classifier.py. 3 | 4 | # Copyright 2018 The Google AI Team Authors. 5 | # 6 | # Licensed under the Apache License, Version 2.0 (the "License"); 7 | # you may not use this file except in compliance with the License. 8 | # You may obtain a copy of the License at 9 | # 10 | # http://www.apache.org/licenses/LICENSE-2.0 11 | # 12 | # Unless required by applicable law or agreed to in writing, software 13 | # distributed under the License is distributed on an "AS IS" BASIS, 14 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | # See the License for the specific language governing permissions and 16 | # limitations under the License. 17 | """BERT finetuning on classification tasks.""" 18 | 19 | from __future__ import absolute_import 20 | from __future__ import division 21 | from __future__ import print_function 22 | 23 | import os 24 | import time 25 | from ALBERT import classifier_utils 26 | from ALBERT import fine_tuning_utils 27 | from ALBERT import modeling 28 | from ALBERT import tokenization 29 | import tensorflow.compat.v1 as tf 30 | from tensorflow.contrib import cluster_resolver as contrib_cluster_resolver 31 | from tensorflow.contrib import tpu as contrib_tpu 32 | 33 | flags = tf.flags 34 | 35 | FLAGS = flags.FLAGS 36 | 37 | ## Required parameters 38 | flags.DEFINE_string( 39 | "data_dir", None, 40 | "The input data dir. Should contain the .tsv files (or other data files) " 41 | "for the task.") 42 | 43 | flags.DEFINE_string( 44 | "albert_config_file", None, 45 | "The config json file corresponding to the pre-trained ALBERT model. " 46 | "This specifies the model architecture.") 47 | 48 | flags.DEFINE_string("task_name", None, "The name of the task to train.") 49 | 50 | flags.DEFINE_string( 51 | "vocab_file", None, 52 | "The vocabulary file that the ALBERT model was trained on.") 53 | 54 | flags.DEFINE_string("spm_model_file", None, 55 | "The model file for sentence piece tokenization.") 56 | 57 | flags.DEFINE_string( 58 | "output_dir", None, 59 | "The output directory where the model checkpoints will be written.") 60 | 61 | flags.DEFINE_string("cached_dir", None, 62 | "Path to cached training and dev tfrecord file. " 63 | "The file will be generated if not exist.") 64 | 65 | ## Other parameters 66 | 67 | flags.DEFINE_string( 68 | "init_checkpoint", None, 69 | "Initial checkpoint (usually from a pre-trained BERT model).") 70 | 71 | flags.DEFINE_string( 72 | "albert_hub_module_handle", None, 73 | "If set, the ALBERT hub module to use.") 74 | 75 | flags.DEFINE_bool( 76 | "do_lower_case", True, 77 | "Whether to lower case the input text. Should be True for uncased " 78 | "models and False for cased models.") 79 | 80 | flags.DEFINE_integer( 81 | "max_seq_length", 512, 82 | "The maximum total input sequence length after WordPiece tokenization. " 83 | "Sequences longer than this will be truncated, and sequences shorter " 84 | "than this will be padded.") 85 | 86 | flags.DEFINE_bool("do_train", False, "Whether to run training.") 87 | 88 | flags.DEFINE_bool("do_eval", False, "Whether to run eval on the dev set.") 89 | 90 | flags.DEFINE_bool( 91 | "do_predict", False, 92 | "Whether to run the model in inference mode on the test set.") 93 | 94 | flags.DEFINE_integer("train_batch_size", 32, "Total batch size for training.") 95 | 96 | flags.DEFINE_integer("eval_batch_size", 8, "Total batch size for eval.") 97 | 98 | flags.DEFINE_integer("predict_batch_size", 8, "Total batch size for predict.") 99 | 100 | flags.DEFINE_float("learning_rate", 5e-5, "The initial learning rate for Adam.") 101 | 102 | flags.DEFINE_integer("train_step", 1000, 103 | "Total number of training steps to perform.") 104 | 105 | flags.DEFINE_integer( 106 | "warmup_step", 0, 107 | "number of steps to perform linear learning rate warmup for.") 108 | 109 | flags.DEFINE_integer("save_checkpoints_steps", 1000, 110 | "How often to save the model checkpoint.") 111 | 112 | flags.DEFINE_integer("keep_checkpoint_max", 5, 113 | "How many checkpoints to keep.") 114 | 115 | flags.DEFINE_integer("iterations_per_loop", 1000, 116 | "How many steps to make in each estimator call.") 117 | 118 | flags.DEFINE_bool("use_tpu", False, "Whether to use TPU or GPU/CPU.") 119 | 120 | flags.DEFINE_string("optimizer", "adamw", "Optimizer to use") 121 | 122 | tf.flags.DEFINE_string( 123 | "tpu_name", None, 124 | "The Cloud TPU to use for training. This should be either the name " 125 | "used when creating the Cloud TPU, or a grpc://ip.address.of.tpu:8470 " 126 | "url.") 127 | 128 | tf.flags.DEFINE_string( 129 | "tpu_zone", None, 130 | "[Optional] GCE zone where the Cloud TPU is located in. If not " 131 | "specified, we will attempt to automatically detect the GCE project from " 132 | "metadata.") 133 | 134 | tf.flags.DEFINE_string( 135 | "gcp_project", None, 136 | "[Optional] Project name for the Cloud TPU-enabled project. If not " 137 | "specified, we will attempt to automatically detect the GCE project from " 138 | "metadata.") 139 | 140 | tf.flags.DEFINE_string("master", None, "[Optional] TensorFlow master URL.") 141 | 142 | flags.DEFINE_integer( 143 | "num_tpu_cores", 8, 144 | "Only used if `use_tpu` is True. Total number of TPU cores to use.") 145 | 146 | 147 | class InputExample(object): 148 | """A single training/test example for simple sequence classification.""" 149 | 150 | def __init__(self, guid, text_a, text_b=None, label=None): 151 | """Constructs a InputExample. 152 | 153 | Args: 154 | guid: Unique id for the example. 155 | text_a: string. The untokenized text of the first sequence. For single 156 | sequence tasks, only this sequence must be specified. 157 | text_b: (Optional) string. The untokenized text of the second sequence. 158 | Only must be specified for sequence pair tasks. 159 | label: (Optional) string. The label of the example. This should be 160 | specified for train and dev examples, but not for test examples. 161 | """ 162 | self.guid = guid 163 | self.text_a = text_a 164 | self.text_b = text_b 165 | self.label = label 166 | 167 | 168 | class InputFeatures(object): 169 | """A single set of features of data.""" 170 | 171 | def __init__(self, 172 | input_ids, 173 | input_mask, 174 | segment_ids, 175 | label_id, 176 | is_real_example=True): 177 | self.input_ids = input_ids 178 | self.input_mask = input_mask 179 | self.segment_ids = segment_ids 180 | self.label_id = label_id 181 | self.is_real_example = is_real_example 182 | 183 | 184 | class LivedoorProcessor(classifier_utils.DataProcessor): 185 | """Processor for the livedoor data set (see https://www.rondhuit.com/download.html).""" 186 | 187 | def get_train_examples(self, data_dir): 188 | """See base class.""" 189 | return self._create_examples( 190 | self._read_tsv(os.path.join(data_dir, "train.tsv")), "train") 191 | 192 | def get_dev_examples(self, data_dir): 193 | """See base class.""" 194 | return self._create_examples( 195 | self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev") 196 | 197 | def get_test_examples(self, data_dir): 198 | """See base class.""" 199 | return self._create_examples( 200 | self._read_tsv(os.path.join(data_dir, "test.tsv")), "test") 201 | 202 | def get_labels(self): 203 | """See base class.""" 204 | return ['dokujo-tsushin', 'it-life-hack', 'kaden-channel', 'livedoor-homme', 'movie-enter', 'peachy', 'smax', 'sports-watch', 'topic-news'] 205 | 206 | def _create_examples(self, lines, set_type): 207 | """Creates examples for the training and dev sets.""" 208 | examples = [] 209 | for (i, line) in enumerate(lines): 210 | if i == 0: 211 | idx_text = line.index('text') 212 | idx_label = line.index('label') 213 | else: 214 | guid = "%s-%s" % (set_type, i) 215 | text_a = tokenization.convert_to_unicode(line[idx_text]) 216 | label = tokenization.convert_to_unicode(line[idx_label]) 217 | examples.append( 218 | InputExample(guid=guid, text_a=text_a, text_b=None, label=label)) 219 | return examples 220 | 221 | 222 | def main(_): 223 | tf.logging.set_verbosity(tf.logging.INFO) 224 | 225 | processors = { 226 | "livedoor": LivedoorProcessor, 227 | } 228 | 229 | if not FLAGS.do_train and not FLAGS.do_eval and not FLAGS.do_predict: 230 | raise ValueError( 231 | "At least one of `do_train`, `do_eval` or `do_predict' must be True.") 232 | 233 | if not FLAGS.albert_config_file and not FLAGS.albert_hub_module_handle: 234 | raise ValueError("At least one of `--albert_config_file` and " 235 | "`--albert_hub_module_handle` must be set") 236 | 237 | if FLAGS.albert_config_file: 238 | albert_config = modeling.AlbertConfig.from_json_file( 239 | FLAGS.albert_config_file) 240 | if FLAGS.max_seq_length > albert_config.max_position_embeddings: 241 | raise ValueError( 242 | "Cannot use sequence length %d because the ALBERT model " 243 | "was only trained up to sequence length %d" % 244 | (FLAGS.max_seq_length, albert_config.max_position_embeddings)) 245 | else: 246 | albert_config = None # Get the config from TF-Hub. 247 | 248 | tf.gfile.MakeDirs(FLAGS.output_dir) 249 | 250 | task_name = FLAGS.task_name.lower() 251 | 252 | if task_name not in processors: 253 | raise ValueError("Task not found: %s" % (task_name)) 254 | 255 | processor = processors[task_name]( 256 | use_spm=True if FLAGS.spm_model_file else False, 257 | do_lower_case=FLAGS.do_lower_case) 258 | 259 | label_list = processor.get_labels() 260 | 261 | tokenizer = fine_tuning_utils.create_vocab( 262 | vocab_file=FLAGS.vocab_file, 263 | do_lower_case=FLAGS.do_lower_case, 264 | spm_model_file=FLAGS.spm_model_file, 265 | hub_module=FLAGS.albert_hub_module_handle) 266 | 267 | tpu_cluster_resolver = None 268 | if FLAGS.use_tpu and FLAGS.tpu_name: 269 | tpu_cluster_resolver = contrib_cluster_resolver.TPUClusterResolver( 270 | FLAGS.tpu_name, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project) 271 | 272 | is_per_host = contrib_tpu.InputPipelineConfig.PER_HOST_V2 273 | if FLAGS.do_train: 274 | iterations_per_loop = int(min(FLAGS.iterations_per_loop, 275 | FLAGS.save_checkpoints_steps)) 276 | else: 277 | iterations_per_loop = FLAGS.iterations_per_loop 278 | run_config = contrib_tpu.RunConfig( 279 | cluster=tpu_cluster_resolver, 280 | master=FLAGS.master, 281 | model_dir=FLAGS.output_dir, 282 | save_checkpoints_steps=int(FLAGS.save_checkpoints_steps), 283 | keep_checkpoint_max=0, 284 | tpu_config=contrib_tpu.TPUConfig( 285 | iterations_per_loop=iterations_per_loop, 286 | num_shards=FLAGS.num_tpu_cores, 287 | per_host_input_for_training=is_per_host)) 288 | 289 | train_examples = None 290 | if FLAGS.do_train: 291 | train_examples = processor.get_train_examples(FLAGS.data_dir) 292 | model_fn = classifier_utils.model_fn_builder( 293 | albert_config=albert_config, 294 | num_labels=len(label_list), 295 | init_checkpoint=FLAGS.init_checkpoint, 296 | learning_rate=FLAGS.learning_rate, 297 | num_train_steps=FLAGS.train_step, 298 | num_warmup_steps=FLAGS.warmup_step, 299 | use_tpu=FLAGS.use_tpu, 300 | use_one_hot_embeddings=FLAGS.use_tpu, 301 | task_name=task_name, 302 | hub_module=FLAGS.albert_hub_module_handle, 303 | optimizer=FLAGS.optimizer) 304 | 305 | # If TPU is not available, this will fall back to normal Estimator on CPU 306 | # or GPU. 307 | estimator = contrib_tpu.TPUEstimator( 308 | use_tpu=FLAGS.use_tpu, 309 | model_fn=model_fn, 310 | config=run_config, 311 | train_batch_size=FLAGS.train_batch_size, 312 | eval_batch_size=FLAGS.eval_batch_size, 313 | predict_batch_size=FLAGS.predict_batch_size) 314 | 315 | if FLAGS.do_train: 316 | cached_dir = FLAGS.cached_dir 317 | if not cached_dir: 318 | cached_dir = FLAGS.output_dir 319 | train_file = os.path.join(cached_dir, task_name + "_train.tf_record") 320 | if not tf.gfile.Exists(train_file): 321 | classifier_utils.file_based_convert_examples_to_features( 322 | train_examples, label_list, FLAGS.max_seq_length, tokenizer, 323 | train_file, task_name) 324 | tf.logging.info("***** Running training *****") 325 | tf.logging.info(" Num examples = %d", len(train_examples)) 326 | tf.logging.info(" Batch size = %d", FLAGS.train_batch_size) 327 | tf.logging.info(" Num steps = %d", FLAGS.train_step) 328 | train_input_fn = classifier_utils.file_based_input_fn_builder( 329 | input_file=train_file, 330 | seq_length=FLAGS.max_seq_length, 331 | is_training=True, 332 | drop_remainder=True, 333 | task_name=task_name, 334 | use_tpu=FLAGS.use_tpu, 335 | bsz=FLAGS.train_batch_size) 336 | estimator.train(input_fn=train_input_fn, max_steps=FLAGS.train_step) 337 | 338 | if FLAGS.do_eval: 339 | eval_examples = processor.get_dev_examples(FLAGS.data_dir) 340 | num_actual_eval_examples = len(eval_examples) 341 | if FLAGS.use_tpu: 342 | # TPU requires a fixed batch size for all batches, therefore the number 343 | # of examples must be a multiple of the batch size, or else examples 344 | # will get dropped. So we pad with fake examples which are ignored 345 | # later on. These do NOT count towards the metric (all tf.metrics 346 | # support a per-instance weight, and these get a weight of 0.0). 347 | while len(eval_examples) % FLAGS.eval_batch_size != 0: 348 | eval_examples.append(classifier_utils.PaddingInputExample()) 349 | 350 | cached_dir = FLAGS.cached_dir 351 | if not cached_dir: 352 | cached_dir = FLAGS.output_dir 353 | eval_file = os.path.join(cached_dir, task_name + "_eval.tf_record") 354 | if not tf.gfile.Exists(eval_file): 355 | classifier_utils.file_based_convert_examples_to_features( 356 | eval_examples, label_list, FLAGS.max_seq_length, tokenizer, 357 | eval_file, task_name) 358 | 359 | tf.logging.info("***** Running evaluation *****") 360 | tf.logging.info(" Num examples = %d (%d actual, %d padding)", 361 | len(eval_examples), num_actual_eval_examples, 362 | len(eval_examples) - num_actual_eval_examples) 363 | tf.logging.info(" Batch size = %d", FLAGS.eval_batch_size) 364 | 365 | # This tells the estimator to run through the entire set. 366 | eval_steps = None 367 | # However, if running eval on the TPU, you will need to specify the 368 | # number of steps. 369 | if FLAGS.use_tpu: 370 | assert len(eval_examples) % FLAGS.eval_batch_size == 0 371 | eval_steps = int(len(eval_examples) // FLAGS.eval_batch_size) 372 | 373 | eval_drop_remainder = True if FLAGS.use_tpu else False 374 | eval_input_fn = classifier_utils.file_based_input_fn_builder( 375 | input_file=eval_file, 376 | seq_length=FLAGS.max_seq_length, 377 | is_training=False, 378 | drop_remainder=eval_drop_remainder, 379 | task_name=task_name, 380 | use_tpu=FLAGS.use_tpu, 381 | bsz=FLAGS.eval_batch_size) 382 | 383 | best_trial_info_file = os.path.join(FLAGS.output_dir, "best_trial.txt") 384 | 385 | def _best_trial_info(): 386 | """Returns information about which checkpoints have been evaled so far.""" 387 | if tf.gfile.Exists(best_trial_info_file): 388 | with tf.gfile.GFile(best_trial_info_file, "r") as best_info: 389 | global_step, best_metric_global_step, metric_value = ( 390 | best_info.read().split(":")) 391 | global_step = int(global_step) 392 | best_metric_global_step = int(best_metric_global_step) 393 | metric_value = float(metric_value) 394 | else: 395 | metric_value = -1 396 | best_metric_global_step = -1 397 | global_step = -1 398 | tf.logging.info( 399 | "Best trial info: Step: %s, Best Value Step: %s, " 400 | "Best Value: %s", global_step, best_metric_global_step, metric_value) 401 | return global_step, best_metric_global_step, metric_value 402 | 403 | def _remove_checkpoint(checkpoint_path): 404 | for ext in ["meta", "data-00000-of-00001", "index"]: 405 | src_ckpt = checkpoint_path + ".{}".format(ext) 406 | tf.logging.info("removing {}".format(src_ckpt)) 407 | tf.gfile.Remove(src_ckpt) 408 | 409 | def _find_valid_cands(curr_step): 410 | filenames = tf.gfile.ListDirectory(FLAGS.output_dir) 411 | candidates = [] 412 | for filename in filenames: 413 | if filename.endswith(".index"): 414 | ckpt_name = filename[:-6] 415 | idx = ckpt_name.split("-")[-1] 416 | if int(idx) > curr_step: 417 | candidates.append(filename) 418 | return candidates 419 | 420 | output_eval_file = os.path.join(FLAGS.output_dir, "eval_results.txt") 421 | 422 | if task_name == "sts-b": 423 | key_name = "pearson" 424 | elif task_name == "cola": 425 | key_name = "matthew_corr" 426 | else: 427 | key_name = "eval_accuracy" 428 | 429 | global_step, best_perf_global_step, best_perf = _best_trial_info() 430 | writer = tf.gfile.GFile(output_eval_file, "w") 431 | while global_step < FLAGS.train_step: 432 | steps_and_files = {} 433 | filenames = tf.gfile.ListDirectory(FLAGS.output_dir) 434 | for filename in filenames: 435 | if filename.endswith(".index"): 436 | ckpt_name = filename[:-6] 437 | cur_filename = os.path.join(FLAGS.output_dir, ckpt_name) 438 | if cur_filename.split("-")[-1] == "best": 439 | continue 440 | gstep = int(cur_filename.split("-")[-1]) 441 | if gstep not in steps_and_files: 442 | tf.logging.info("Add {} to eval list.".format(cur_filename)) 443 | steps_and_files[gstep] = cur_filename 444 | tf.logging.info("found {} files.".format(len(steps_and_files))) 445 | if not steps_and_files: 446 | tf.logging.info("found 0 file, global step: {}. Sleeping." 447 | .format(global_step)) 448 | time.sleep(60) 449 | else: 450 | for checkpoint in sorted(steps_and_files.items()): 451 | step, checkpoint_path = checkpoint 452 | if global_step >= step: 453 | if (best_perf_global_step != step and 454 | len(_find_valid_cands(step)) > 1): 455 | _remove_checkpoint(checkpoint_path) 456 | continue 457 | result = estimator.evaluate( 458 | input_fn=eval_input_fn, 459 | steps=eval_steps, 460 | checkpoint_path=checkpoint_path) 461 | global_step = result["global_step"] 462 | tf.logging.info("***** Eval results *****") 463 | for key in sorted(result.keys()): 464 | tf.logging.info(" %s = %s", key, str(result[key])) 465 | writer.write("%s = %s\n" % (key, str(result[key]))) 466 | writer.write("best = {}\n".format(best_perf)) 467 | if result[key_name] > best_perf: 468 | best_perf = result[key_name] 469 | best_perf_global_step = global_step 470 | elif len(_find_valid_cands(global_step)) > 1: 471 | _remove_checkpoint(checkpoint_path) 472 | writer.write("=" * 50 + "\n") 473 | writer.flush() 474 | with tf.gfile.GFile(best_trial_info_file, "w") as best_info: 475 | best_info.write("{}:{}:{}".format( 476 | global_step, best_perf_global_step, best_perf)) 477 | writer.close() 478 | 479 | for ext in ["meta", "data-00000-of-00001", "index"]: 480 | src_ckpt = "model.ckpt-{}.{}".format(best_perf_global_step, ext) 481 | tgt_ckpt = "model.ckpt-best.{}".format(ext) 482 | tf.logging.info("saving {} to {}".format(src_ckpt, tgt_ckpt)) 483 | tf.io.gfile.rename( 484 | os.path.join(FLAGS.output_dir, src_ckpt), 485 | os.path.join(FLAGS.output_dir, tgt_ckpt), 486 | overwrite=True) 487 | 488 | if FLAGS.do_predict: 489 | predict_examples = processor.get_test_examples(FLAGS.data_dir) 490 | num_actual_predict_examples = len(predict_examples) 491 | if FLAGS.use_tpu: 492 | # TPU requires a fixed batch size for all batches, therefore the number 493 | # of examples must be a multiple of the batch size, or else examples 494 | # will get dropped. So we pad with fake examples which are ignored 495 | # later on. 496 | while len(predict_examples) % FLAGS.predict_batch_size != 0: 497 | predict_examples.append(classifier_utils.PaddingInputExample()) 498 | 499 | predict_file = os.path.join(FLAGS.output_dir, "predict.tf_record") 500 | classifier_utils.file_based_convert_examples_to_features( 501 | predict_examples, label_list, 502 | FLAGS.max_seq_length, tokenizer, 503 | predict_file, task_name) 504 | 505 | tf.logging.info("***** Running prediction*****") 506 | tf.logging.info(" Num examples = %d (%d actual, %d padding)", 507 | len(predict_examples), num_actual_predict_examples, 508 | len(predict_examples) - num_actual_predict_examples) 509 | tf.logging.info(" Batch size = %d", FLAGS.predict_batch_size) 510 | 511 | predict_drop_remainder = True if FLAGS.use_tpu else False 512 | predict_input_fn = classifier_utils.file_based_input_fn_builder( 513 | input_file=predict_file, 514 | seq_length=FLAGS.max_seq_length, 515 | is_training=False, 516 | drop_remainder=predict_drop_remainder, 517 | task_name=task_name, 518 | use_tpu=FLAGS.use_tpu, 519 | bsz=FLAGS.predict_batch_size) 520 | 521 | checkpoint_path = os.path.join(FLAGS.output_dir, "model.ckpt-best") 522 | result = estimator.predict( 523 | input_fn=predict_input_fn, 524 | checkpoint_path=checkpoint_path) 525 | 526 | output_predict_file = os.path.join(FLAGS.output_dir, "test_results.tsv") 527 | output_submit_file = os.path.join(FLAGS.output_dir, "submit_results.tsv") 528 | with tf.gfile.GFile(output_predict_file, "w") as pred_writer,\ 529 | tf.gfile.GFile(output_submit_file, "w") as sub_writer: 530 | sub_writer.write("index" + "\t" + "prediction\n") 531 | num_written_lines = 0 532 | tf.logging.info("***** Predict results *****") 533 | for (i, (example, prediction)) in\ 534 | enumerate(zip(predict_examples, result)): 535 | probabilities = prediction["probabilities"] 536 | if i >= num_actual_predict_examples: 537 | break 538 | output_line = "\t".join( 539 | str(class_probability) 540 | for class_probability in probabilities) + "\n" 541 | pred_writer.write(output_line) 542 | 543 | if task_name != "sts-b": 544 | actual_label = label_list[int(prediction["predictions"])] 545 | else: 546 | actual_label = str(prediction["predictions"]) 547 | sub_writer.write(example.guid + "\t" + actual_label + "\n") 548 | num_written_lines += 1 549 | assert num_written_lines == num_actual_predict_examples 550 | 551 | 552 | if __name__ == "__main__": 553 | flags.mark_flag_as_required("data_dir") 554 | flags.mark_flag_as_required("task_name") 555 | flags.mark_flag_as_required("spm_model_file") 556 | flags.mark_flag_as_required("output_dir") 557 | tf.app.run() 558 | -------------------------------------------------------------------------------- /src/run_create_pretraining_data.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | set -eu 4 | 5 | export PYTHONPATH=`dirname $0`/.. 6 | 7 | basedir=$1 8 | max_seq_length=$2 9 | 10 | for DIR in $( find ${basedir}/wiki/ -mindepth 1 -type d ); do 11 | out=${DIR}/all-maxseq${max_seq_length}.tfrecord 12 | if [ -f ${out} ]; then 13 | continue 14 | fi 15 | python ALBERT/create_pretraining_data.py \ 16 | --input_file=${DIR}/all.txt \ 17 | --output_file=${out} \ 18 | --spm_model_file=${basedir}/model/wiki-ja_albert.model \ 19 | --vocab_file=${basedir}/model/wiki-ja_albert.vocab \ 20 | --do_lower_case=True \ 21 | --max_seq_length=${max_seq_length} \ 22 | --max_predictions_per_seq=20 \ 23 | --masked_lm_prob=0.15 \ 24 | --random_seed=12345 \ 25 | --dupe_factor=5 \ 26 | --do_whole_word_mask=False \ 27 | --do_permutation=False \ 28 | --favor_shorter_ngram=False \ 29 | --random_next_sentence=False 30 | done 31 | -------------------------------------------------------------------------------- /src/train-sentencepiece.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import configparser 4 | import glob 5 | import os 6 | import sentencepiece as sp 7 | 8 | CURDIR = os.path.dirname(os.path.abspath(__file__)) 9 | CONFIGPATH = os.path.join(CURDIR, os.pardir, 'config.ini') 10 | config = configparser.ConfigParser() 11 | config.read(CONFIGPATH) 12 | 13 | TEXTDIR = config['DATA']['TEXTDIR'] 14 | PREFIX = config['SENTENCEPIECE']['PREFIX'] 15 | VOCABSIZE = config['SENTENCEPIECE']['VOCABSIZE'] 16 | CTLSYMBOLS = config['SENTENCEPIECE']['CTLSYMBOLS'] 17 | 18 | 19 | def _get_text_file(text_dir=TEXTDIR): 20 | file_list = glob.glob(f'{text_dir}/**/*.txt') 21 | files = ",".join(file_list) 22 | return files 23 | 24 | 25 | def train(prefix=PREFIX, vocab_size=VOCABSIZE, ctl_symbols=CTLSYMBOLS): 26 | files = _get_text_file() 27 | # https://github.com/google-research/albert/blob/a41cf11700c1ed2b7beab0a2649817fa52c8d6e1/README.md#sentencepiece 28 | command = f'--input={files} --model_prefix={prefix} --vocab_size={vocab_size} ' \ 29 | f'--pad_id=0 --unk_id=1 --eos_id=-1 --bos_id=-1 --user_defined_symbols=(,),",-,.,–,£,€ ' \ 30 | f'--control_symbols={ctl_symbols} --input_sentence_size=15000000 ' \ 31 | f'--shuffle_input_sentence=true --character_coverage=0.99995 --model_type=unigram' 32 | sp.SentencePieceTrainer.Train(command) 33 | 34 | 35 | def main(): 36 | train() 37 | 38 | 39 | if __name__ == "__main__": 40 | main() 41 | -------------------------------------------------------------------------------- /src/utils.py: -------------------------------------------------------------------------------- 1 | def str_to_value(input_str): 2 | """ 3 | Convert data type of value of dict to appropriate one. 4 | Assume there are only three types: str, int, float. 5 | """ 6 | if input_str.isalpha(): 7 | return input_str 8 | elif input_str.isdigit(): 9 | return int(input_str) 10 | else: 11 | return float(input_str) 12 | --------------------------------------------------------------------------------