├── .gitattributes ├── .github ├── FUNDING.yml └── workflows │ ├── python-publish.yaml │ └── test.yaml ├── .gitignore ├── .gitmodules ├── CHANGELOG.md ├── LICENSE ├── MANIFEST.in ├── README.md ├── evaluation ├── README.md ├── install_jumanpp.sh ├── jglue │ ├── README.md │ ├── requirements.txt │ └── transformers-4.26.1_jglue-1.1.0_chitra-0.1.8.patch ├── pytorch │ ├── classification_utils.py │ ├── convert_dataset.py │ ├── multiple_choice_utils.py │ ├── qa_utils.py │ ├── requirements.txt │ ├── run_all.sh │ ├── run_evaluation.py │ └── tokenizer_utils.py ├── summary_results.py └── tensorflow │ ├── classification_utils.py │ ├── convert_dataset.py │ ├── multiple_choice_utils.py │ ├── qa_utils.py │ ├── requirements.txt │ ├── run_all.sh │ ├── run_evaluation.py │ └── tokenizer_utils.py ├── misc └── license-header.txt ├── pretraining └── bert │ ├── README.md │ ├── __init__.py │ ├── convert_original_tf2_checkpoint_to_pytorch_nvidia.py │ ├── corpus_preprocessing │ ├── __init__.py │ ├── filter │ │ ├── __init__.py │ │ ├── document_filter │ │ │ ├── __init__.py │ │ │ ├── document_filter.py │ │ │ └── document_filter_name.py │ │ └── sentence_filter │ │ │ ├── __init__.py │ │ │ ├── sentence_filter.py │ │ │ └── sentence_filter_name.py │ └── normalizer │ │ ├── __init__.py │ │ ├── document_normalizer │ │ ├── __init__.py │ │ ├── document_normalizer.py │ │ └── document_normalizer_name.py │ │ └── sentence_normalizer │ │ ├── __init__.py │ │ ├── sentence_normalizer.py │ │ └── sentence_normalizer_name.py │ ├── prepare_dataset.py │ ├── preprocess_dataset.py │ ├── requirements.txt │ ├── resources │ └── ng_words.txt │ ├── run_create_pretraining_data.sh │ ├── run_prepare_dataset.sh │ ├── split_dataset.py │ ├── train_pos_substitution_tokenizer.py │ └── train_wordpiece_tokenizer.py ├── requirements.txt ├── setup.py ├── sudachitra ├── __init__.py ├── conjugation_preserving_normalizer.py ├── input_string_normalizer.py ├── pretokenizer │ ├── __init__.py │ ├── japanese_bert_wordpiece_tokenizer.py │ ├── pos_substitution_tokenizer.py │ └── sudachipy_pretokenizer.py ├── resources │ ├── conjugation_type_table.json │ └── inflection_table.json ├── sudachipy_word_tokenizer.py ├── tokenization_bert_sudachipy.py ├── tokenization_electra_sudachipy.py └── word_formatter.py └── tests ├── __init__.py ├── test_japanese_bert_wordpiece_tokenizer.py └── test_tokenization_bert_sudachipy.py /.gitattributes: -------------------------------------------------------------------------------- 1 | * text=auto 2 | 3 | *.in text 4 | *.md text 5 | *.py text 6 | *.txt text 7 | 8 | *.pyc binary 9 | *.pyd binary 10 | *.pyo binary 11 | *.pyw binary 12 | *.dic binary -------------------------------------------------------------------------------- /.github/FUNDING.yml: -------------------------------------------------------------------------------- 1 | # These are supported funding model platforms 2 | 3 | github: WorksApplications 4 | -------------------------------------------------------------------------------- /.github/workflows/python-publish.yaml: -------------------------------------------------------------------------------- 1 | # This workflows will upload a Python Package using Twine when a release is created 2 | # For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries 3 | 4 | name: Upload Python Package 5 | 6 | on: 7 | release: 8 | types: [published] 9 | 10 | jobs: 11 | deploy: 12 | runs-on: ubuntu-latest 13 | 14 | steps: 15 | - uses: 
actions/checkout@v3 16 | - name: Set up Python 17 | uses: actions/setup-python@v4 18 | with: 19 | python-version: "3.x" 20 | - name: Install dependencies 21 | run: | 22 | python -m pip install --upgrade pip 23 | pip install -r requirements.txt 24 | pip install setuptools wheel 25 | - name: Build 26 | run: | 27 | python setup.py sdist 28 | - name: Publish a Python distribution to PyPI 29 | uses: pypa/gh-action-pypi-publish@release/v1 30 | with: 31 | user: ${{ secrets.PYPI_USERNAME }} 32 | password: ${{ secrets.PYPI_PASSWORD }} 33 | verbose: true 34 | -------------------------------------------------------------------------------- /.github/workflows/test.yaml: -------------------------------------------------------------------------------- 1 | name: test 2 | 3 | on: 4 | workflow_dispatch: 5 | push: 6 | branches: [main] 7 | pull_request: 8 | types: [opened, synchronize, reopened] 9 | branches: [main] 10 | 11 | jobs: 12 | test: 13 | name: Test package 14 | runs-on: ubuntu-latest 15 | 16 | steps: 17 | - uses: actions/checkout@v2 18 | 19 | - name: Set up Python 20 | uses: actions/setup-python@v2 21 | with: 22 | python-version: '3.x' 23 | 24 | - name: Display Python version 25 | run: | 26 | python -c "import sys; print(sys.version)" 27 | 28 | - name: Install Dependencies 29 | run: | 30 | python -m pip install --upgrade pip 31 | pip install -r requirements.txt 32 | 33 | - name: Run Test 34 | run: | 35 | python -m unittest discover tests 36 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 98 | __pypackages__/ 99 | 100 | # Celery stuff 101 | celerybeat-schedule 102 | celerybeat.pid 103 | 104 | # SageMath parsed files 105 | *.sage.py 106 | 107 | # Environments 108 | .env 109 | .venv 110 | env/ 111 | venv/ 112 | ENV/ 113 | env.bak/ 114 | venv.bak/ 115 | 116 | # Spyder project settings 117 | .spyderproject 118 | .spyproject 119 | 120 | # Rope project settings 121 | .ropeproject 122 | 123 | # mkdocs documentation 124 | /site 125 | 126 | # mypy 127 | .mypy_cache/ 128 | .dmypy.json 129 | dmypy.json 130 | 131 | # Pyre type checker 132 | .pyre/ 133 | 134 | # pytype static type analyzer 135 | .pytype/ 136 | 137 | # Cython debug symbols 138 | cython_debug/ 139 | 140 | # pycharm 141 | .idea/ 142 | 143 | # notebooks for dev 144 | notebooks_for_dev/ 145 | 146 | # pretraining data 147 | pretraining/bert/datasets/ 148 | pretraining/bert/tokenizers/ -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "pretraining/bert/models"] 2 | path = pretraining/bert/models 3 | url = https://github.com/t-yamamura/models.git 4 | [submodule "pretraining/bert/DeepLearningExamples"] 5 | path = pretraining/bert/DeepLearningExamples 6 | url = https://github.com/NVIDIA/DeepLearningExamples.git 7 | -------------------------------------------------------------------------------- /CHANGELOG.md: -------------------------------------------------------------------------------- 1 | # chiTra-1.1 model (2023-03-17) 2 | 3 | - A pretrained Japanese BERT base model, trained using chiTra tokenizer. 4 | 5 | ## Updates / Changes 6 | 7 | - Cleaning processes of the NWJC corpus are added. 8 | - Total size after cleaning is 79 GB. 9 | - Vocabulary is rebuilt in the same way. 10 | - Total vocab size is `32597`. 11 | - Sudachi libraries are updated to: 12 | - SudachiPy: `0.6.6`. 13 | - SudachiDict: `20220729-core`. 14 | - SudachiTra: `0.1.8`. 15 | - `word_form_type` is changed to `normalized_nouns`. 16 | - Total training steps is increased to `20472`. 17 | 18 | # [0.1.8](https://github.com/WorksApplications/SudachiTra/releases/tag/v0.1.8) (2023-03-10) 19 | 20 | ## Highlights 21 | 22 | - Add new `word_format_type`: `normalized_nouns`. (#48, #50) 23 | - Normalizes morphemes that do not have conjugation form. 24 | 25 | ## Other 26 | 27 | - Faster part-of-speech matching (#36) 28 | - Use HuggingFace compatible pretokenizer (#38) 29 | - Fix/Update pretraining scripts and documents (#39, #40, #45, #46) 30 | - Fix github test workflow (#49) 31 | - Enable to save vocab file with duplicated items (#54) 32 | 33 | # chiTra-1.0 (2022-02-25) 34 | 35 | - A pretrained Japanese BERT base model, trained using chiTra tokenizer. 36 | 37 | ## Details 38 | 39 | - Model 40 | - chiTra-1.0 is a BERT base model. 41 | - Corpus 42 | - We used NINJAL Web Japanese Corpus (NWJC) from National Institute for Japanese Language and Linguistics. 43 | - Cleaning process is explained [here](https://github.com/WorksApplications/SudachiTra/tree/main/pretraining/bert#2-preprocessing-corpus-cleaning). 44 | - Total size after cleaning is 109 GB. 
45 | - Vocabulary 46 | - Vocabulary is built on the above corpus, [using WordPiece](https://github.com/WorksApplications/SudachiTra/tree/main/pretraining/bert#wordpiece) and vocab size 32000. 47 | - We added 常用漢字 and 人名用漢字 to cover usual Japanese text. 48 | - Total vocab size is `32615`. 49 | - Sudachi libraries 50 | - SudachiPy: `0.6.2` 51 | - SudachiDict: `20211220-core` 52 | - chiTra: `0.1.7` 53 | - We used `word_form_type`: `normalized_and_surface`. 54 | - Training Parameters 55 | - See [our paper](https://github.com/WorksApplications/SudachiTra#chitra%E3%81%AE%E5%BC%95%E7%94%A8--citing-chitra)) or [pretraining page](https://github.com/WorksApplications/SudachiTra/tree/main/pretraining/bert#5training). 56 | - Total training step is `10236`. 57 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 
47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "{}" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright {yyyy} {name of copyright owner} 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
-------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include README.md LICENSE requirements.txt 2 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Sudachi Transformers (chiTra) 2 | 3 | [![](https://img.shields.io/badge/python-3.6+-blue.svg)](https://www.python.org/downloads/release/python-360/) 4 | [![test](https://github.com/WorksApplications/SudachiTra/actions/workflows/test.yaml/badge.svg)](https://github.com/WorksApplications/SudachiTra/actions/workflows/test.yaml) 5 | [![](https://img.shields.io/github/license/WorksApplications/SudachiTra.svg)](https://github.com/WorksApplications/SudachiTra/blob/main/LICENSE) 6 | 7 | chiTraは事前学習済みの大規模な言語モデルと [Transformers](https://github.com/huggingface/transformers) 向けの日本語形態素解析器を提供します。 / chiTra provides pre-trained large language models and a Japanese tokenizer for [Transformers](https://github.com/huggingface/transformers). 8 | 9 | chiTraはSuda**chi Tra**nsformersの略称です。 / chiTra stands for Suda**chi Tra**nsformers. 10 | 11 | ## 事前学習済みモデル / Pretrained Model 12 | 公開データは [Open Data Sponsorship Program](https://registry.opendata.aws/sudachi/) を使用してAWSでホストされています。 / The data is generously hosted by AWS through their [Open Data Sponsorship Program](https://registry.opendata.aws/sudachi/). 13 | 14 | | Version | Normalized | SudachiTra | Sudachi | SudachiDict | Text | Pretrained Model | 15 | | ------- | ---------------------- | ---------- | ------- | ------------- | ------------ | ------------------------------------------------------------------------------------------- | 16 | | v1.0 | normalized_and_surface | v0.1.7 | 0.6.2 | 20211220-core | NWJC (109GB) | 395 MB ([tar.gz](https://sudachi.s3.ap-northeast-1.amazonaws.com/chitra/chiTra-1.0.tar.gz)) | 17 | | v1.1 | normalized_nouns | v0.1.8 | 0.6.6 | 20220729-core | NWJC with additional cleaning (79GB) | 396 MB ([tar.gz](https://sudachi.s3.ap-northeast-1.amazonaws.com/chitra/chiTra-1.1.tar.gz)) | 18 | 19 | ### 特長 / Features 20 | - 大規模テキストによる学習 / Training on large texts 21 | - 国語研日本語ウェブコーパス (NWJC) をつかってモデルを学習することで多様な表現とさまざまなドメインに対応しています / Models are trained on the NINJAL Web Japanese Corpus (NWJC) to support a wide variety of expressions and domains. 22 | - Sudachi の利用 / Using Sudachi 23 | - 形態素解析器 Sudachi を利用することで表記ゆれによる弊害を抑えています / Using the morphological analyzer Sudachi reduces the negative effects of orthographic variation.
24 | 25 | # chiTraの使い方 / How to use chiTra 26 | 27 | ## クイックツアー / Quick Tour 28 | 事前準備 / Requirements 29 | ```bash 30 | $ pip install sudachitra 31 | $ wget https://sudachi.s3.ap-northeast-1.amazonaws.com/chitra/chiTra-1.1.tar.gz 32 | $ tar -zxvf chiTra-1.1.tar.gz 33 | ``` 34 | 35 | モデルの読み込み / Load the model 36 | ```python 37 | >>> from sudachitra.tokenization_bert_sudachipy import BertSudachipyTokenizer 38 | >>> from transformers import BertModel 39 | 40 | >>> tokenizer = BertSudachipyTokenizer.from_pretrained('chiTra-1.1') 41 | >>> tokenizer.tokenize("選挙管理委員会とすだち") 42 | ['選挙', '##管理', '##委員会', 'と', '酢', '##橘'] 43 | 44 | >>> model = BertModel.from_pretrained('chiTra-1.1') 45 | >>> model(**tokenizer("まさにオールマイティーな商品だ。", return_tensors="pt")).last_hidden_state 46 | tensor([[[ 0.8583, -1.1752, -0.7987, ..., -1.1691, -0.8355, 3.4678], 47 | [ 0.0220, 1.1702, -2.3334, ..., 0.6673, -2.0774, 2.7731], 48 | [ 0.0894, -1.3009, 3.4650, ..., -0.1140, 0.1767, 1.9859], 49 | ..., 50 | [-0.4429, -1.6267, -2.1493, ..., -1.7801, -1.8009, 2.5343], 51 | [ 1.7204, -1.0540, -0.4362, ..., -0.0228, 0.5622, 2.5800], 52 | [ 1.1125, -0.3986, 1.8532, ..., -0.8021, -1.5888, 2.9520]]], 53 | grad_fn=) 54 | ``` 55 | 56 | ## インストール / Installation 57 | 58 | ```shell script 59 | $ pip install sudachitra 60 | ``` 61 | 62 | デフォルトの [Sudachi dictionary](https://github.com/WorksApplications/SudachiDict) は [SudachiDict-core](https://pypi.org/project/SudachiDict-core/) を使用します。 / The default [Sudachi dictionary](https://github.com/WorksApplications/SudachiDict) is [SudachiDict-core](https://pypi.org/project/SudachiDict-core/). 63 | 64 | [SudachiDict-small](https://pypi.org/project/SudachiDict-small/) や [SudachiDict-full](https://pypi.org/project/SudachiDict-full/) など他の辞書をインストールして使用することもできます。 / You can use other dictionaries, such as [SudachiDict-small](https://pypi.org/project/SudachiDict-small/) and [SudachiDict-full](https://pypi.org/project/SudachiDict-full/) .
65 | その場合は以下のように使いたい辞書をインストールしてください。 / In that case, install the dictionary you want to use as follows.
66 | 事前学習済みモデルを使いたい場合はcore辞書を使用して学習されていることに注意してください。 / If you want to use a pre-trained model, note that it is trained with SudachiDict-core. 67 | 68 | ```shell script 69 | $ pip install sudachidict_small sudachidict_full 70 | ``` 71 | 72 | ## 事前学習 / Pretraining 73 | 74 | 事前学習方法の詳細は [pretraining/bert/README.md](https://github.com/WorksApplications/SudachiTra/tree/main/pretraining/bert) を参照ください。 / Please refer to [pretraining/bert/README.md](https://github.com/WorksApplications/SudachiTra/tree/main/pretraining/bert). 75 | 76 | 77 | ## 開発者向け / For Developers 78 | TBD 79 | 80 | ## ライセンス / License 81 | 82 | Copyright (c) 2022 National Institute for Japanese Language and Linguistics and Works Applications Co., Ltd. All rights reserved. 83 | 84 | "chiTra"は [Apache License, Version 2.0](https://www.apache.org/licenses/LICENSE-2.0) で [国立国語研究所](https://www.ninjal.ac.jp/) 及び [株式会社ワークスアプリケーションズ](https://www.worksap.co.jp/) によって提供されています。 / "chiTra" is distributed by [National Institute for Japanese Language and Linguistics](https://www.ninjal.ac.jp/) and [Works Applications Co.,Ltd.](https://www.worksap.co.jp/) under the [Apache License, Version 2.0](https://www.apache.org/licenses/LICENSE-2.0). 85 | 86 | 87 | ## 連絡先 / Contact us 88 | 質問があれば、issueやslackをご利用ください。 / Open an issue, or come to our Slack workspace for questions and discussion. 89 | 90 | 開発者やユーザーの方々が質問したり議論するためのSlackワークスペースを用意しています。 / We have a Slack workspace for developers and users to ask questions and have discussions. 91 | https://sudachi-dev.slack.com/ ( [こちら](https://join.slack.com/t/sudachi-dev/shared_invite/enQtMzg2NTI2NjYxNTUyLTMyYmNkZWQ0Y2E5NmQxMTI3ZGM3NDU0NzU4NGE1Y2UwYTVmNTViYjJmNDI0MWZiYTg4ODNmMzgxYTQ3ZmI2OWU) から招待を受けてください) / https://sudachi-dev.slack.com/ (Get an invitation [here](https://join.slack.com/t/sudachi-dev/shared_invite/enQtMzg2NTI2NjYxNTUyLTMyYmNkZWQ0Y2E5NmQxMTI3ZGM3NDU0NzU4NGE1Y2UwYTVmNTViYjJmNDI0MWZiYTg4ODNmMzgxYTQ3ZmI2OWU) ) 92 | 93 | 94 | 95 | ## chiTraの引用 / Citing chiTra 96 | chiTraについての論文を発表しています。 / We have published the following paper about chiTra: 97 | - 勝田哲弘, 林政義, 山村崇, Tolmachev Arseny, 高岡一馬, 内田佳孝, 浅原正幸, 単語正規化による表記ゆれに頑健な BERT モデルの構築. 言語処理学会第28回年次大会, 2022. 98 | 99 | chiTraを論文や書籍、サービスなどで引用される際には、以下のBibTexをご利用ください。 / When citing chiTra in papers, books, or services, please use the following BibTeX entry: 100 | ``` 101 | @INPROCEEDINGS{katsuta2022chitra, 102 | author = {勝田哲弘, 林政義, 山村崇, Tolmachev Arseny, 高岡一馬, 内田佳孝, 浅原正幸}, 103 | title = {単語正規化による表記ゆれに頑健な BERT モデルの構築}, 104 | booktitle = "言語処理学会第28回年次大会(NLP2022)", 105 | year = "2022", 106 | pages = "", 107 | publisher = "言語処理学会", 108 | } 109 | ``` 110 | 111 | ### 実験に使用したモデル / Models used for the experiments 112 | 「単語正規化による表記ゆれに頑健なBERTモデルの構築」の実験において使用したモデルを以下で公開しています。/ The models used in the experiments of "単語正規化による表記ゆれに頑健なBERTモデルの構築" are published below.
113 | 114 | |   Normalized | Text | Pretrained Model | 115 | | ---------------------- | -------- | ---------------------------------------------------------------------------------------------------------------- | 116 | | surface | Wiki-40B | [tar.gz](https://sudachi.s3.ap-northeast-1.amazonaws.com/chitra/nlp2022/Wikipedia_surface.tar.gz) | 117 | | normalized_and_surface | Wiki-40B | [tar.gz](https://sudachi.s3.ap-northeast-1.amazonaws.com/chitra/nlp2022/Wikipedia_normalized_and_surface.tar.gz) | 118 | | normalized_conjugation | Wiki-40B | [tar.gz](https://sudachi.s3.ap-northeast-1.amazonaws.com/chitra/nlp2022/Wikipedia_normalized_conjugation.tar.gz) | 119 | | normalized | Wiki-40B | [tar.gz](https://sudachi.s3.ap-northeast-1.amazonaws.com/chitra/nlp2022/Wikipedia_normalized.tar.gz) | 120 | 121 | Enjoy chiTra! 122 | -------------------------------------------------------------------------------- /evaluation/README.md: -------------------------------------------------------------------------------- 1 | # Evaluation 2 | 3 | This folder contains scripts to evaluate models. 4 | 5 | # Evaluation methods 6 | 7 | ## Performance on downstream tasks 8 | 9 | We evaluated the chiTra models on the 3 tasks below. 10 | 11 | - [The Multilingual Amazon Reviews Corpus (Amazon)](https://registry.opendata.aws/amazon-reviews-ml/) 12 | - Text classification task / 文章分類 13 | - We used the 'ja' subset only. 14 | - We used the `review_body` column as the input text and the `stars` column (1 to 5) as the target class. 15 | - [京都大学常識推論データセット (KUCI)](https://nlp.ist.i.kyoto-u.ac.jp/?KUCI) 16 | - Multiple choice task / 常識推論 17 | - [解答可能性付き読解データセット (RCQA)](http://www.cl.ecei.tohoku.ac.jp/rcqa/) 18 | - Question answering task (SQuAD 2.0 format) / 読解 19 | 20 | ### Steps 21 | 22 | 0. Framework 23 | 24 | We provide evaluation scripts for both `pytorch` and `tensorflow`. 25 | The scripts in those directories work equivalently. 26 | 27 | We prepared those scripts with reference to the example code of transformers from before [v4.16.0](https://github.com/huggingface/transformers/tree/v4.16.0). 28 | We confirmed they work with transformers-v4.26.1, but you may need updates. 29 | 30 | 1. Prepare datasets 31 | 32 | Use `convert_dataset.py` to convert the datasets into a suitable format. 33 | For the KUCI and RCQA tasks, you need to download the original data beforehand. 34 | 35 | ```bash 36 | python convert_dataset.py amazon --output /datasets/amazon 37 | python convert_dataset.py kuci --input /datasets/KUCI --output /datasets/kuci 38 | python convert_dataset.py rcqa --input /datasets/all-v1.0.json.gz --output /datasets/rcqa 39 | ``` 40 | 41 | By default, we assume all 3 datasets are located in the same directory. 42 | We also assume the directory names of the datasets are: `amazon`, `kuci`, `rcqa` (case sensitive). 43 | 44 | 2. Modify script 45 | 46 | You need to modify the `run_all.sh` script to set the paths to the datasets and models. 47 | 48 | - IO 49 | - `SCRIPT_DIR`: Path to `SudachiTra/evaluation/pytorch` or `tensorflow`. 50 | - `OUTPUT_ROOT`: Path to the directory where experiment results will be written. 51 | - Datasets 52 | - `DATASET_ROOT`: Path to the directory where you prepared the datasets. 53 | - `DATASETS`: List of tasks to evaluate. Used for the output directory name. 54 | - Models 55 | - `MODEL_ROOT`: Path to the directory where you put the models to evaluate. 56 | - Not necessary if you set `MODEL_DIRS` yourself. 57 | - `MODEL_NAMES`: List of models to evaluate. Used for the output directory name. 
58 | - `MODEL_DIRS`: Mapping from model name to the model directory or huggingface model name. 59 | 60 | Note: You need to include `bert` in the model path to automatically load BERT models. 61 | 62 | 3. Run and collect results. 63 | 64 | Use the modified `run_all.sh` to run the evaluation for each model, task, and hyperparameter setting. 65 | 66 | ```bash 67 | # Install lib 68 | python -m pip install -U -r /path/to/SudachiTra/evaluation/pytorch/requirements.txt 69 | # You need additional libraries to use tohoku-bert: 70 | # python -m pip install fugashi ipadic unidic_lite 71 | 72 | # Run experiment 73 | /path/to/SudachiTra/evaluation/pytorch/run_all.sh 74 | ``` 75 | 76 | `run_all.sh` will write outputs to directories named `[model_name]_[task_name]/[learning_rate]_[batch_size]_[num_epoch]` under the `OUTPUT_ROOT`. 77 | Use `summary_results.py` to gather the result files of each task. 78 | 79 | ```bash 80 | # Collect test result files. 81 | python summary_results.py amazon /output/chitra_amazon/ --output /summary/amazon.csv 82 | python summary_results.py kuci /output/chitra_kuci/ --output /summary/kuci.csv 83 | python summary_results.py rcqa /output/chitra_rcqa/ --output /summary/rcqa.csv 84 | ``` 85 | 86 | ## Robustness to the text normalization 87 | 88 | Run the evaluation with test data whose texts are normalized. 89 | 90 | An ideal model should be robust to this change (outputs remain the same after normalization). 91 | 92 | ### Steps 93 | 94 | Prepare text-normalized datasets using `convert_dataset.py` with the `--word-form` option. 95 | Provide a chiTra tokenizer `word_form_type` to specify the normalization type. 96 | We used `normalized_and_surface` for our experiment. 97 | 98 | ```bash 99 | python convert_dataset.py amazon --output /datasets_normalized/amazon \ 100 | --word-form normalized_and_surface 101 | ``` 102 | 103 | Run the experiments in the same way with the modified datasets (see the section above). 104 | 105 | # Results 106 | 107 | TBA 108 | 109 | Also check [our paper](https://github.com/WorksApplications/SudachiTra#chitra%E3%81%AE%E5%BC%95%E7%94%A8--citing-chitra). 110 | 111 | # Script Usage 112 | 113 | This section shows the list of scripts and their usage. 114 | Also check the help of each script. 115 | 116 | ## install_jumanpp.sh 117 | 118 | `install_jumanpp.sh` is a helper script to install Juman++. 119 | 120 | Juman++ is necessary to use `tokenizer_utils.Juman`, which is used to tokenize data for Kyoto-U BERT. 121 | 122 | The default install location is `$HOME/.local/usr`. 123 | Modify the script as you want and set PATH. 124 | 125 | ## convert_dataset.py 126 | 127 | `convert_dataset.py` is a script to preprocess datasets. 128 | `run_evaluation.py` requires the dataset format produced by this script. 129 | 130 | This script has `--seed` and `--split-rate` options to randomize the train/dev/test split; 131 | however, in our experiments we used the default split (no options). 132 | 133 | ### Example 134 | 135 | ```bash 136 | # Amazon Review 137 | # Raw data will be loaded from the huggingface datasets hub. 138 | python convert_dataset.py amazon --output ./amazon 139 | 140 | # RCQA dataset 141 | # Download raw data from http://www.cl.ecei.tohoku.ac.jp/rcqa/ beforehand. 142 | python convert_dataset.py rcqa --input ./all-v1.0.json.gz --output ./rcqa 143 | 144 | # KUCI dataset 145 | # Download raw data from https://nlp.ist.i.kyoto-u.ac.jp/?KUCI and untar beforehand. 
146 | python convert_dataset.py kuci --input ./KUCI/ --output ./kuci 147 | ``` 148 | 149 | ### Tokenize texts (wakati) 150 | 151 | Some BERT models need the texts in the dataset to be tokenized (wakati-gaki): 152 | 153 | - NICT BERT: by MeCab (juman dic) 154 | - Kyoto-U BERT: by Juman++ 155 | 156 | Use the `--tokenize` option to tokenize the text columns. 157 | You need to install the corresponding tokenizers to use this option. 158 | 159 | Note that `run_evaluation.py` also has an option to tokenize text. 160 | 161 | ```bash 162 | # tokenize with Juman++ 163 | python convert_dataset.py rcqa \ 164 | --input ./all-v1.0.json.gz --output ./rcqa_juman \ 165 | --tokenize juman 166 | 167 | # tokenize with MeCab (juman dic) 168 | python convert_dataset.py rcqa \ 169 | --input ./all-v1.0.json.gz --output ./rcqa_mecab \ 170 | --tokenize mecab --dicdir /var/lib/mecab/dic/juman-utf-8 --mecabrc /etc/mecabrc 171 | ``` 172 | 173 | ### Normalize texts 174 | 175 | You can use `convert_dataset.py` to generate datasets for [testing model robustness](#robustness-to-the-text-normalization). 176 | 177 | Provide a `word_form_type` via the `--word-form` option to apply sudachitra normalization to the texts in the datasets. 178 | We used `normalized_and_surface` for our experiment. 179 | 180 | ```bash 181 | python convert_dataset.py amazon --output ./amazon_normalized \ 182 | --word-form normalized_and_surface 183 | ``` 184 | 185 | ## run_evaluation.py 186 | 187 | `run_evaluation.py` is a script to run a single evaluation (with a single model, dataset, and set of hyperparameters). 188 | 189 | Note: 190 | 191 | - The model path for `--model_name_or_path` must contain `bert` to let `transformers.AutoModel` work correctly. 192 | - To use the Sudachi tokenizer, set `sudachi` for `tokenizer_name`. 193 | - The script assumes that `vocab.txt` and `tokenizer_config.json` are in the model path. 194 | - You may need to clear the huggingface datasets cache files before running this script: 195 | - Dataset preprocessing will generate a cache file with a random hash due to our non-picklable conversion. 196 | - The random hash becomes the same if you use the same seed, due to `set_seed`. 197 | 198 | ### Example 199 | 200 | Template 201 | 202 | ```bash 203 | python ./run_evaluation.py \ 204 | --model_name_or_path [./path/to/model or name in huggingface-hub] \ 205 | --from_pt [set true if load pytorch model] \ 206 | --pretokenizer_name [set "juman" or "mecab-juman" to tokenize text before using HF-tokenizer] \ 207 | --tokenizer_name [set "sudachi" to use SudachiTokenizer] \ 208 | --dataset_name ["amazon" or "rcqa" or "kuci"] \ 209 | --dataset_dir [./path/to/dataset/dir] \ 210 | --output_dir [./path/to/output] \ 211 | --do_train [set to finetune model] \ 212 | --do_eval [set to evaluate model with dev set] \ 213 | --do_predict [set to evaluate model with test data] \ 214 | --per_device_eval_batch_size [evaluation batch size] \ 215 | --per_device_train_batch_size [training batch size] \ 216 | --learning_rate [learning rate] \ 217 | --num_train_epochs [epochs to finetune] \ 218 | --overwrite_cache [set to overwrite data preprocess cache] \ 219 | --max_train_samples [limit number of train samples (for test run)] \ 220 | --max_val_samples [limit number of val samples (for test run)] \ 221 | --max_test_samples [limit number of test samples (for test run)] \ 222 | ``` 223 | 224 | Run finetuning with Tohoku BERT and the Amazon dataset, 225 | assuming the dataset files (generated by `convert_dataset.py`) are located under `datasets/amazon/`. 
226 | 227 | ```bash 228 | python ./run_evaluation.py \ 229 | --model_name_or_path "cl-tohoku/bert-base-japanese-whole-word-masking" \ 230 | --dataset_name "amazon" \ 231 | --dataset_dir ./datasets/amazon \ 232 | --output_dir ./output/tohoku_amazon \ 233 | --do_train \ 234 | --per_device_eval_batch_size 64 \ 235 | --per_device_train_batch_size 16 \ 236 | --learning_rate 5e-5 \ 237 | --num_train_epochs 4 \ 238 | # --max_train_samples 100 \ 239 | # --max_val_samples 100 \ 240 | # --max_test_samples 100 \ 241 | ``` 242 | 243 | Run prediction with NICT BERT and the KUCI dataset, 244 | assuming the dataset is not tokenized. 245 | 246 | ```bash 247 | python ./run_evaluation.py \ 248 | --model_name_or_path ./path/to/nict_bert/model \ 249 | --pretokenizer_name "mecab-juman" \ 250 | --dataset_name "kuci" \ 251 | --dataset_dir ./datasets/kuci \ 252 | --output_dir ./output/nict_kuci \ 253 | --do_eval \ 254 | --do_predict \ 255 | --per_device_eval_batch_size 64 \ 256 | --per_device_train_batch_size 16 \ 257 | --learning_rate 5e-5 \ 258 | --num_train_epochs 4 \ 259 | ``` 260 | 261 | Run all steps with chiTra (normalized_and_surface) and the RCQA dataset. 262 | 263 | ```bash 264 | python ./run_evaluation.py \ 265 | --model_name_or_path ./path/to/chitra/model \ 266 | --tokenizer_name "sudachi" \ 267 | --dataset_name "rcqa" \ 268 | --dataset_dir ./datasets/rcqa \ 269 | --output_dir ./output/chitra_rcqa \ 270 | --do_train \ 271 | --do_eval \ 272 | --do_predict \ 273 | --per_device_eval_batch_size 64 \ 274 | --per_device_train_batch_size 16 \ 275 | --learning_rate 5e-5 \ 276 | --num_train_epochs 4 \ 277 | ``` 278 | 279 | ## run_all.sh 280 | 281 | `run_all.sh` is a script to run `run_evaluation.py` with different models, tasks, and hyperparameters. 282 | 283 | This assumes all model files are placed in the same directory (with `bert` in their paths so that `run_evaluation.py` loads them as BERT models), and all 3 datasets are placed in another directory. 284 | You will need to set those directories in the script for your environment. 285 | 286 | ```bash 287 | ./run_all.sh 288 | ``` 289 | 290 | ## summary_results.py 291 | 292 | `summary_results.py` is a script to collect and summarize the metrics of 293 | the test results of each model for each hyperparameter setting. 294 | 295 | It requires that the input directory have the structure generated by `run_all.sh`, i.e.: 296 | 297 | ``` 298 | [model and task dir (e.g. "chitra_amazon")] 299 | ├── [hyper-parameter dirs (e.g. "5e-5_32_3")] 300 | │ ├── [validation result file] 301 | │ └── [test result file] 302 | ... 303 | ``` 304 | 305 | ```bash 306 | python ./summary_results.py amazon -i ./out/chitra_amazon 307 | ``` 308 | -------------------------------------------------------------------------------- /evaluation/install_jumanpp.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | set -eu 4 | 5 | # install jumanpp v2.0.0-rc3 6 | # ref: https://qiita.com/Gushi_maru/items/ee434b5bc9f020c8feb6 7 | 8 | if [ ! -d ./jumanpp-2.0.0-rc3 ]; then 9 | wget https://github.com/ku-nlp/jumanpp/releases/download/v2.0.0-rc3/jumanpp-2.0.0-rc3.tar.xz 10 | tar -xvf jumanpp-2.0.0-rc3.tar.xz 11 | fi 12 | if [ ! -d jumanpp-2.0.0-rc3/build ]; then 13 | mkdir jumanpp-2.0.0-rc3/build 14 | fi 15 | 16 | cd jumanpp-2.0.0-rc3/build/ 17 | 18 | apt install cmake -y 19 | 20 | cmake .. 
-DCMAKE_BUILD_TYPE=Release -DCMAKE_INSTALL_PREFIX=$HOME/.local/usr 21 | make 22 | make install 23 | 24 | -------------------------------------------------------------------------------- /evaluation/jglue/README.md: -------------------------------------------------------------------------------- 1 | # Evaluation with JGLUE 2 | 3 | [JGLUE](https://github.com/yahoojapan/JGLUE) is a Japanese NLP task set. 4 | 5 | This folder contains the resources for the evaluation of chiTra models on JGLUE. 6 | 7 | ## Results 8 | 9 | The performance on the JGLUE dev set is shown in the table below. 10 | Results for other models are taken from [JGLUE - Baseline Score](https://github.com/yahoojapan/JGLUE#baseline-scores). 11 | 12 | | Model | MARC-ja | JSTS | JNLI | JSQuAD | JCommonsenseQA | 13 | | --------------------------- | ------- | ---------------- | ----- | ----------- | -------------- | 14 | | | acc | Pearson/Spearman | acc | EM/F1 | acc | 15 | | chiTra-1.0 | 0.956 | 0.903/0.861 | 0.882 | 0.839/0.919 | 0.788 | 16 | | chiTra-1.1 | 0.960 | 0.916/0.876 | 0.900 | 0.860/0.937 | 0.840 | 17 | | | | 18 | | Tohoku BERT base | 0.958 | 0.909/0.868 | 0.899 | 0.871/0.941 | 0.808 | 19 | | Tohoku BERT base (char) | 0.956 | 0.893/0.851 | 0.892 | 0.864/0.937 | 0.718 | 20 | | Tohoku BERT large | 0.955 | 0.913/0.872 | 0.900 | 0.880/0.946 | 0.816 | 21 | | NICT BERT base | 0.958 | 0.910/0.871 | 0.902 | 0.897/0.947 | 0.823 | 22 | | Waseda RoBERTa base | 0.962 | 0.913/0.873 | 0.895 | 0.864/0.927 | 0.840 | 23 | | Waseda RoBERTa large (s128) | 0.954 | 0.930/0.896 | 0.924 | 0.884/0.940 | 0.907 | 24 | | Waseda RoBERTa large (s512) | 0.961 | 0.926/0.892 | 0.926 | 0.918/0.963 | 0.891 | 25 | | XLM RoBERTa base | 0.961 | 0.877/0.831 | 0.893 | -/- | 0.687 | 26 | | XLM RoBERTa large | 0.964 | 0.918/0.884 | 0.919 | -/- | 0.840 | 27 | 28 | Note that chiTra-1.0 and 1.1 are base-size BERT models. 29 | Compared to the other base models, chiTra-1.1 achieves comparable results. 30 | 31 | On the JSQuAD task, chiTra performs slightly worse. 32 | Due to the normalization feature of the chiTra tokenizer and the difficulty of 33 | aligning the original and normalized texts, the chiTra model outputs normalized text as its answer. 34 | In some cases this causes a failure in answer matching. 35 | 36 | ## Reproduction 37 | 38 | We provide a patch file to modify [the transformers library](https://github.com/huggingface/transformers). 39 | You can follow the instructions in the [JGLUE fine-tuning page](https://github.com/yahoojapan/JGLUE/tree/main/fine-tuning), replacing the patch file with ours. 40 | You will also need to install some python modules (use `requirements.txt`). 41 | We used `transformers-v4.26.1` to generate our patch, but other versions may work. 42 | 43 | The main additions of our patch compared to JGLUE's are the following: 44 | 45 | - Assign `sudachitra.BertSudachipyTokenizer` to `transformers.AutoTokenizer`, so that we can auto-load the chiTra tokenizer. 46 | - Modifications for the JSQuAD task: 47 | - Pretokenize datasets using the Sudachi tokenizer, instead of whitespace separation. 48 | - Manage the alignment gap caused by the normalization feature of the chiTra tokenizer (an illustrative sketch of this alignment problem follows the list). 49 | - Remove whitespaces from the final output of the model in the evaluation. 
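The bullets above describe the patch at a high level. As a rough, self-contained illustration of the alignment problem (not code taken from the patch), the sketch below uses `pytextspan` (listed in `requirements.txt`) to map chiTra word pieces back to character spans in the original text. The model directory name `chiTra-1.1` and the `##`-stripping heuristic are assumptions made for this example only.

```python
import textspan
from sudachitra.tokenization_bert_sudachipy import BertSudachipyTokenizer

# Assumes a chiTra model directory named "chiTra-1.1" is available locally.
tokenizer = BertSudachipyTokenizer.from_pretrained("chiTra-1.1")

text = "選挙管理委員会とすだち"
pieces = tokenizer.tokenize(text)  # e.g. ['選挙', '##管理', '##委員会', 'と', '酢', '##橘']

# Approximate each piece's surface form by dropping the WordPiece marker.
surfaces = [p[2:] if p.startswith("##") else p for p in pieces]

# For each piece, look up candidate character spans in the original text.
# Normalized pieces such as '酢' or '橘' may get no span at all ([]),
# which is the alignment gap the JSQuAD modification has to handle.
spans = textspan.get_original_spans(surfaces, text)
for piece, piece_spans in zip(pieces, spans):
    print(piece, piece_spans)
```

When spans can be recovered, a predicted answer can be mapped back to the original passage; when they cannot, the answer stays in its normalized form, which is the failure mode described in the Results section above.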
50 | -------------------------------------------------------------------------------- /evaluation/jglue/requirements.txt: -------------------------------------------------------------------------------- 1 | sudachitra 2 | pytextspan 3 | -------------------------------------------------------------------------------- /evaluation/pytorch/classification_utils.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import sys 3 | 4 | import numpy as np 5 | from transformers import ( 6 | AutoConfig, 7 | AutoModelForSequenceClassification, 8 | EvalPrediction, 9 | Trainer, 10 | default_data_collator, 11 | ) 12 | 13 | 14 | logger = logging.getLogger(__name__) 15 | logging.basicConfig( 16 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 17 | datefmt="%m/%d/%Y %H:%M:%S", 18 | handlers=[logging.StreamHandler(sys.stderr)], 19 | ) 20 | logger.setLevel(logging.INFO) 21 | 22 | 23 | def setup_args(data_args, raw_datadict): 24 | # num_label to initialize model 25 | dataset_key = list(raw_datadict.keys())[0] # at least one data file exists 26 | label_list = raw_datadict[dataset_key].unique("label") 27 | logger.info(f"classification task with {len(label_list)} labels.") 28 | 29 | data_args.label_list = sorted(label_list) # sort for determinism 30 | data_args.label2id = {l: i for i, l in enumerate(label_list)} 31 | 32 | # columns of input text (2 columns maximum) 33 | data_columns = [ 34 | c for c in raw_datadict[dataset_key].column_names if c != "label"] 35 | if "sentence1" in data_columns: 36 | if "sentence2" in data_columns: 37 | text_columns = ["sentence1", "sentence2"] 38 | else: 39 | text_columns = ["sentence1"] 40 | else: 41 | text_columns = data_columns[:2] 42 | data_args.data_columns = data_columns 43 | data_args.text_columns = text_columns 44 | return data_args 45 | 46 | 47 | def pretokenize_texts(raw_datadict, pretok, data_args): 48 | def subfunc(examples): 49 | for c in data_args.text_columns: 50 | examples[c] = [pretok(s) for s in examples[c]] 51 | return examples 52 | 53 | raw_datadict = raw_datadict.map(subfunc, batched=True) 54 | return raw_datadict 55 | 56 | 57 | def preprocess_dataset(raw_datadict, data_args, tokenizer, max_length): 58 | padding = "max_length" if data_args.pad_to_max_length else False 59 | 60 | # Truncate text before tokenization for sudachi, which has a input bytes limit. 61 | # This may affect the result with a large max_length (tokens). 
62 | MAX_CHAR_LENGTH = 2**14 63 | 64 | def subfunc(examples): 65 | # Tokenize texts 66 | texts = ([s[:MAX_CHAR_LENGTH] for s in examples[c]] 67 | for c in data_args.text_columns) 68 | result = tokenizer(*texts, padding=padding, 69 | max_length=max_length, truncation=True) 70 | 71 | # Map labels to ids 72 | if "label" in examples: 73 | result["label"] = [ 74 | (data_args.label2id[l] if l != -1 else -1) for l in examples["label"]] 75 | return result 76 | 77 | datadict = raw_datadict.map( 78 | subfunc, 79 | batched=True, 80 | load_from_cache_file=not data_args.overwrite_cache, 81 | remove_columns=data_args.data_columns, 82 | ) 83 | return datadict 84 | 85 | 86 | def setup_config(config_name_or_path, data_args): 87 | config = AutoConfig.from_pretrained( 88 | config_name_or_path, 89 | finetuning_task=data_args.dataset_name, 90 | num_labels=len(data_args.label2id), 91 | ) 92 | # add label <-> id mapping 93 | config.label2id = data_args.label2id 94 | config.id2label = {i: l for l, i in config.label2id.items()} 95 | return config 96 | 97 | 98 | def setup_trainer(model_name_or_path, config_name, datadict, data_args, training_args, tokenizer, from_tf=False): 99 | config = setup_config(config_name or model_name_or_path, data_args) 100 | model = AutoModelForSequenceClassification.from_pretrained( 101 | model_name_or_path, 102 | config=config, 103 | from_tf=from_tf, 104 | ) 105 | 106 | def compute_metrics(p: EvalPrediction): 107 | preds = p.predictions[0] if isinstance( 108 | p.predictions, tuple) else p.predictions 109 | preds = np.argmax(preds, axis=1) 110 | return {"accuracy": (preds == p.label_ids).astype(np.float32).mean().item()} 111 | 112 | data_collator = ( 113 | default_data_collator 114 | if data_args.pad_to_max_length 115 | else None 116 | ) 117 | 118 | trainer = Trainer( 119 | model=model, 120 | args=training_args, 121 | train_dataset=datadict["train"] if training_args.do_train else None, 122 | eval_dataset=datadict["validation"] if training_args.do_eval else None, 123 | compute_metrics=compute_metrics, 124 | tokenizer=tokenizer, 125 | data_collator=data_collator, 126 | ) 127 | return trainer 128 | 129 | 130 | def evaluate_model(trainer, dataset, label2id=lambda x: x, output_dir=None, stage="eval"): 131 | p = trainer.predict(dataset, metric_key_prefix=stage) 132 | predictions = np.argmax(p.predictions, axis=1) 133 | labels = p.label_ids if p.label_ids is not None else dataset["label"] 134 | metrics = p.metrics if p.metrics is not None else {} 135 | 136 | if output_dir is not None: 137 | i2l = {i: l for l, i in label2id.items()} 138 | output_file = output_dir / f"{stage}_predictions.tsv" 139 | if trainer.is_world_process_zero(): 140 | with open(output_file, "w") as w: 141 | w.write("index\tlabel\tprediction\n") 142 | for i, (l, p) in enumerate(zip(labels, predictions)): 143 | w.write(f"{i}\t{i2l[l]}\t{i2l[p]}\n") 144 | 145 | trainer.log_metrics(stage, metrics) 146 | trainer.save_metrics(stage, metrics) 147 | return metrics 148 | -------------------------------------------------------------------------------- /evaluation/pytorch/multiple_choice_utils.py: -------------------------------------------------------------------------------- 1 | import itertools as it 2 | import logging 3 | import sys 4 | from dataclasses import dataclass 5 | from typing import Optional, Union 6 | 7 | import numpy as np 8 | import torch 9 | from transformers import ( 10 | AutoConfig, 11 | AutoModelForMultipleChoice, 12 | EvalPrediction, 13 | Trainer, 14 | default_data_collator, 15 | ) 16 | from 
transformers.file_utils import PaddingStrategy 17 | from transformers.tokenization_utils_base import PreTrainedTokenizerBase 18 | 19 | 20 | logger = logging.getLogger(__name__) 21 | logging.basicConfig( 22 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 23 | datefmt="%m/%d/%Y %H:%M:%S", 24 | handlers=[logging.StreamHandler(sys.stderr)], 25 | ) 26 | logger.setLevel(logging.INFO) 27 | 28 | 29 | def setup_args(data_args, raw_datadict): 30 | # set dataset column name, assuming to use convert_dataset.py 31 | dataset_key = list(raw_datadict.keys())[0] # at least one data file exists 32 | column_names = raw_datadict[dataset_key].column_names 33 | 34 | data_args.context_column = "context" 35 | data_args.choice_columns = [ 36 | c for c in column_names if c.startswith("choice")] 37 | data_args.label_column = "label" 38 | return data_args 39 | 40 | 41 | def pretokenize_texts(raw_datadict, pretok, data_args): 42 | text_columns = data_args.choice_columns + [data_args.context_column] 43 | 44 | def subfunc(examples): 45 | for c in text_columns: 46 | examples[c] = [pretok(s) for s in examples[c]] 47 | return examples 48 | 49 | raw_datadict = raw_datadict.map(subfunc, batched=True) 50 | return raw_datadict 51 | 52 | 53 | def preprocess_dataset(raw_datadict, data_args, tokenizer, max_length): 54 | context_column = data_args.context_column 55 | choice_columns = data_args.choice_columns 56 | n_choices = len(choice_columns) 57 | 58 | padding = "max_length" if data_args.pad_to_max_length else False 59 | 60 | def subfunc(examples): 61 | first_sentences = ([c] * n_choices for c in examples[context_column]) 62 | second_sentences = (examples[clm] for clm in choice_columns) 63 | 64 | # flatten 65 | first_sentences = list(it.chain(*first_sentences)) 66 | second_sentences = list(it.chain(*zip(*second_sentences))) 67 | 68 | # tokenize 69 | tokenized = tokenizer( 70 | first_sentences, 71 | second_sentences, 72 | truncation=True, 73 | max_length=max_length, 74 | padding=padding, 75 | ) 76 | 77 | # un-flatten 78 | result = {k: [v[i:i+n_choices] for i in range(0, len(v), n_choices)] 79 | for k, v in tokenized.items()} 80 | 81 | # keep label column as it is, assuming it contains 0-indexed integer 82 | return result 83 | 84 | datadict = raw_datadict.map( 85 | subfunc, 86 | batched=True, 87 | load_from_cache_file=not data_args.overwrite_cache, 88 | ) 89 | return datadict 90 | 91 | 92 | def setup_trainer(model_name_or_path, config_name, datadict, data_args, training_args, tokenizer, from_tf=False): 93 | config = AutoConfig.from_pretrained(config_name or model_name_or_path) 94 | model = AutoModelForMultipleChoice.from_pretrained( 95 | model_name_or_path, 96 | config=config, 97 | from_tf=from_tf, 98 | ) 99 | 100 | def compute_metrics(p: EvalPrediction): 101 | predictions, label_ids = p 102 | preds = np.argmax(predictions, axis=1) 103 | return {"accuracy": (preds == label_ids).astype(np.float32).mean().item()} 104 | 105 | data_collator = ( 106 | default_data_collator 107 | if data_args.pad_to_max_length 108 | else DataCollatorForMultipleChoice(tokenizer=tokenizer, pad_to_multiple_of=8 if training_args.fp16 else None) 109 | ) 110 | 111 | trainer = Trainer( 112 | model=model, 113 | args=training_args, 114 | train_dataset=datadict["train"] if training_args.do_train else None, 115 | eval_dataset=datadict["validation"] if training_args.do_eval else None, 116 | compute_metrics=compute_metrics, 117 | data_collator=data_collator, 118 | tokenizer=tokenizer, 119 | ) 120 | return trainer 121 | 122 | 123 | @dataclass 124 | 
class DataCollatorForMultipleChoice: 125 | """ 126 | Data collator that will dynamically pad the inputs for multiple choice received. 127 | 128 | Args: 129 | tokenizer (:class:`~transformers.PreTrainedTokenizer` or :class:`~transformers.PreTrainedTokenizerFast`): 130 | The tokenizer used for encoding the data. 131 | padding (:obj:`bool`, :obj:`str` or :class:`~transformers.file_utils.PaddingStrategy`, `optional`, defaults to :obj:`True`): 132 | Select a strategy to pad the returned sequences (according to the model's padding side and padding index) 133 | among: 134 | 135 | * :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a single 136 | sequence if provided). 137 | * :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the 138 | maximum acceptable input length for the model if that argument is not provided. 139 | * :obj:`False` or :obj:`'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of 140 | different lengths). 141 | max_length (:obj:`int`, `optional`): 142 | Maximum length of the returned list and optionally padding length (see above). 143 | pad_to_multiple_of (:obj:`int`, `optional`): 144 | If set will pad the sequence to a multiple of the provided value. 145 | 146 | This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >= 147 | 7.5 (Volta). 148 | """ 149 | 150 | tokenizer: PreTrainedTokenizerBase 151 | padding: Union[bool, str, PaddingStrategy] = True 152 | max_length: Optional[int] = None 153 | pad_to_multiple_of: Optional[int] = None 154 | 155 | def __call__(self, features): 156 | label_name = "label" if "label" in features[0].keys() else "labels" 157 | labels = [feature.pop(label_name) for feature in features] 158 | batch_size = len(features) 159 | num_choices = len(features[0]["input_ids"]) 160 | flattened_features = [ 161 | [{k: v[i] for k, v in feature.items()} for i in range(num_choices)] for feature in features 162 | ] 163 | flattened_features = list(it.chain(*flattened_features)) 164 | 165 | batch = self.tokenizer.pad( 166 | flattened_features, 167 | padding=self.padding, 168 | max_length=self.max_length, 169 | pad_to_multiple_of=self.pad_to_multiple_of, 170 | return_tensors="pt", 171 | ) 172 | 173 | # Un-flatten 174 | batch = {k: v.view(batch_size, num_choices, -1) 175 | for k, v in batch.items()} 176 | # Add back labels 177 | batch["labels"] = torch.tensor(labels, dtype=torch.int64) 178 | return batch 179 | 180 | 181 | def evaluate_model(trainer, dataset, output_dir=None, stage="eval"): 182 | p = trainer.predict(dataset, metric_key_prefix=stage) 183 | predictions = np.argmax(p.predictions, axis=1) 184 | labels = p.label_ids if p.label_ids is not None else dataset["label"] 185 | metrics = p.metrics if p.metrics is not None else {} 186 | 187 | if output_dir is not None: 188 | output_file = output_dir / f"{stage}_predictions.tsv" 189 | if trainer.is_world_process_zero(): 190 | with open(output_file, "w") as w: 191 | w.write("index\tlabel\tprediction\n") 192 | for i, (l, p) in enumerate(zip(labels, predictions)): 193 | w.write(f"{i}\t{l}\t{p}\n") 194 | 195 | trainer.log_metrics(stage, metrics) 196 | trainer.save_metrics(stage, metrics) 197 | return metrics 198 | -------------------------------------------------------------------------------- /evaluation/pytorch/requirements.txt: -------------------------------------------------------------------------------- 1 | sudachitra 2 | 3 | # frameworks 4 | torch 5 | 
transformers 6 | tokenizers 7 | datasets 8 | 9 | # QA task 10 | pytextspan 11 | 12 | # Japanese tokenizers 13 | mecab-python3 14 | mojimoji 15 | pyknp 16 | -------------------------------------------------------------------------------- /evaluation/pytorch/run_all.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -eu 3 | 4 | # set your own dir 5 | SCRIPT_DIR="./scripts" 6 | MODEL_ROOT="./bert" 7 | DATASET_ROOT="./datasets" 8 | OUTPUT_ROOT="./out" 9 | 10 | # model to search 11 | MODEL_NAMES=( 12 | "tohoku" 13 | "kyoto" 14 | "nict" 15 | "chitra_surface" 16 | "chitra_normalized_and_surface" 17 | "chitra_normalized_conjugation" 18 | "chitra_normalized" 19 | ) 20 | 21 | DATASETS=("amazon" "rcqa" "kuci") 22 | 23 | # Hyperparameters from Appendix A.3, Devlin et al., 2019 24 | BATCHES=(16 32) 25 | LRS=(5e-5 3e-5 2e-5) 26 | EPOCHS=(2 3 4) 27 | 28 | # set path to the model files 29 | declare -A MODEL_DIRS=( 30 | ["tohoku"]="cl-tohoku/bert-base-japanese-whole-word-masking" 31 | ["kyoto"]="${MODEL_ROOT}/Japanese_L-12_H-768_A-12_E-30_BPE_WWM_transformers" 32 | ["nict"]="${MODEL_ROOT}/NICT_BERT-base_JapaneseWikipedia_32K_BPE" 33 | ["chitra_surface"]="${MODEL_ROOT}/Wikipedia_surface/phase_2" 34 | ["chitra_normalized_and_surface"]="${MODEL_ROOT}/Wikipedia_normalized_and_surface/phase_2" 35 | ["chitra_normalized_conjugation"]="${MODEL_ROOT}/Wikipedia_normalized_conjugation/phase_2" 36 | ["chitra_normalized"]="${MODEL_ROOT}/Wikipedia_normalized/phase_2" 37 | ) 38 | 39 | function set_model_args() { 40 | MODEL=$1 41 | DATASET=$2 42 | MODEL_DIR="${MODEL_DIRS[$1]}" 43 | DATASET_DIR="${DATASET_ROOT}/${DATASET}" 44 | OUTPUT_DIR="${OUTPUT_ROOT}/${MODEL}_${DATASET}/${LR}_${BATCH}_${EPOCH}/" 45 | export MODEL DATASET MODEL_DIR DATASET_DIR OUTPUT_DIR 46 | 47 | # pretokenizer 48 | PRETOKENIZER="identity" 49 | if [ ${MODEL} = "kyoto" ] ; then 50 | PRETOKENIZER="juman" 51 | elif [ ${MODEL} = "nict" ] ; then 52 | PRETOKENIZER="mecab-juman" 53 | fi 54 | export PRETOKENIZER 55 | 56 | # tokenizer (sudachi) 57 | TOKENIZER=${MODEL_DIR} 58 | if [ ${MODEL:0:6} = "chitra" ] ; then 59 | TOKENIZER="sudachi" 60 | fi 61 | export TOKENIZER 62 | } 63 | 64 | command_echo='( echo \ 65 | "${MODEL}, ${DATASET}, ${MODEL_DIR}, ${DATASET_DIR}, ${OUTPUT_DIR}, " \ 66 | "${PRETOKENIZER}, ${TOKENIZER}, ${BATCH}, ${LR}, ${EPOCH}, " \ 67 | )' 68 | 69 | export SCRIPT_PATH="${SCRIPT_DIR}/run_evaluation.py" 70 | command_run='( \ 71 | python ${SCRIPT_PATH} \ 72 | --model_name_or_path ${MODEL_DIR} \ 73 | --pretokenizer_name ${PRETOKENIZER} \ 74 | --tokenizer_name ${TOKENIZER} \ 75 | --dataset_name ${DATASET} \ 76 | --dataset_dir ${DATASET_DIR} \ 77 | --output_dir ${OUTPUT_DIR} \ 78 | --do_train \ 79 | --do_eval \ 80 | --do_predict \ 81 | --gradient_accumulation_steps $((BATCH / 8)) \ 82 | --per_device_eval_batch_size 64 \ 83 | --per_device_train_batch_size 8 \ 84 | --learning_rate ${LR} \ 85 | --num_train_epochs ${EPOCH} \ 86 | --overwrite_cache \ 87 | # --max_train_samples 100 \ 88 | # --max_val_samples 100 \ 89 | # --max_test_samples 100 \ 90 | )' 91 | 92 | # mkdir for log 93 | mkdir -p logs 94 | /bin/true > logs/jobs.txt 95 | 96 | for DATASET in ${DATASETS[@]}; do 97 | for MODEL in ${MODEL_NAMES[@]}; do 98 | for BATCH in ${BATCHES[@]}; do 99 | for LR in ${LRS[@]}; do 100 | for EPOCH in ${EPOCHS[@]}; do 101 | export BATCH LR EPOCH 102 | set_model_args ${MODEL} ${DATASET} 103 | 104 | script -c "${command_echo}" logs/echo.log 105 | script -c "${command_run}" 
logs/${MODEL}_${DATASET}_batch${BATCH}_lr${LR}_epochs${EPOCH}.log 106 | done 107 | done 108 | done 109 | done 110 | done 111 | -------------------------------------------------------------------------------- /evaluation/pytorch/tokenizer_utils.py: -------------------------------------------------------------------------------- 1 | import MeCab 2 | import mojimoji 3 | import pyknp 4 | import unicodedata as ud 5 | 6 | 7 | class Identity(): 8 | is_identity = False 9 | 10 | def __init__(self): 11 | self.is_identity = True 12 | return 13 | 14 | def tokenize(self, line: str) -> str: 15 | return line 16 | 17 | def __call__(self, line: str) -> str: 18 | return self.tokenize(line) 19 | 20 | 21 | class MecabJuman(Identity): 22 | # tokenization for NICT BERT 23 | def __init__(self, dicdir: str = None, mecabrc: str = None): 24 | # assume existance of followings (installed by `apt install mecab`) 25 | dicdir = dicdir or "/var/lib/mecab/dic/juman-utf8" 26 | mecabrc = mecabrc or "/etc/mecabrc" 27 | assert dicdir and mecabrc 28 | 29 | tagger = MeCab.Tagger(f"-r {mecabrc} -d {dicdir} -Owakati") 30 | charset = tagger.dictionary_info().charset 31 | assert charset in ["utf-8", "utf8"] 32 | 33 | self.tagger = tagger 34 | return 35 | 36 | def tokenize(self, line: str) -> str: 37 | # tokenize text and 38 | normalized = mojimoji.han_to_zen(line).replace("\u3000", " ") 39 | separated = self.tagger.parse(normalized).rstrip() 40 | # rm surrogate char 41 | result = "".join(ch for ch in separated if ud.category(ch) != "Cs") 42 | return result 43 | 44 | 45 | class Juman(Identity): 46 | # tokenization for Kyoto-U BERT 47 | def __init__(self): 48 | # assume Juman++ is installed (see install_jumanpp.sh) 49 | self.tok = pyknp.Juman() 50 | return 51 | 52 | def tokenize(self, line: str) -> str: 53 | normalized = mojimoji.han_to_zen(line) 54 | 55 | # truncate input according to the jumanpp input limit 56 | truncated = _utf8_byte_truncate(normalized, 4096) 57 | morphs = self.tok.analysis(truncated) 58 | separated = " ".join(m.midasi for m in morphs) 59 | return separated 60 | 61 | 62 | def _utf8_lead_byte(b): 63 | '''A UTF-8 intermediate byte starts with the bits 10xxxxxx.''' 64 | return (b & 0xC0) != 0x80 65 | 66 | 67 | def _utf8_byte_truncate(text: str, max_bytes: int): 68 | utf8 = text.encode('utf8') 69 | if len(utf8) <= max_bytes: 70 | return text 71 | # separate before lead byte 72 | i = max_bytes 73 | while i > 0 and not _utf8_lead_byte(utf8[i]): 74 | i -= 1 75 | return utf8[:i].decode("utf8") 76 | -------------------------------------------------------------------------------- /evaluation/summary_results.py: -------------------------------------------------------------------------------- 1 | import argparse as ap 2 | import json 3 | import logging 4 | import sys 5 | from collections import defaultdict as ddict 6 | from enum import Enum 7 | from pathlib import Path 8 | 9 | import pandas as pd 10 | 11 | 12 | logger = logging.getLogger(__name__) 13 | logging.basicConfig( 14 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 15 | datefmt="%m/%d/%Y %H:%M:%S", 16 | handlers=[logging.StreamHandler(sys.stderr)], 17 | ) 18 | logger.setLevel(logging.INFO) 19 | 20 | 21 | class Stages (Enum): 22 | VALIDATION = "validation" 23 | TEST = "test" 24 | 25 | 26 | def summary_amazon(args): 27 | results = {} 28 | for subdir in args.input_dir.glob("*"): 29 | if not subdir.is_dir(): 30 | continue 31 | results[subdir.name] = {} 32 | for stage in Stages: 33 | results[subdir.name][stage] = {} 34 | with (subdir / 
f"{stage.value}_predictions.tsv").open() as f: 35 | f.readline() # skip headerline 36 | ids, labels, preds = zip(*(line.strip().split("\t") 37 | for line in f.readlines())) 38 | 39 | num_samples = len(ids) 40 | labels = [int(v[6:] if v.startswith("LABEL_") else v) 41 | for v in labels] 42 | preds = [int(v[6:] if v.startswith("LABEL_") else v) 43 | for v in preds] 44 | results[subdir.name][stage]["acc"] = sum( 45 | l == p for l, p in zip(labels, preds)) / num_samples 46 | results[subdir.name][stage]["mse"] = sum( 47 | (l-p)**2 for l, p in zip(labels, preds)) / num_samples 48 | results[subdir.name][stage]["mae"] = sum( 49 | abs(l-p) for l, p in zip(labels, preds)) / num_samples 50 | 51 | log_best_model(results, key=lambda k: results[k][Stages.VALIDATION]["acc"]) 52 | 53 | df = pd.DataFrame( 54 | data=((hp, 55 | ret[Stages.VALIDATION]["acc"], ret[Stages.VALIDATION]["mse"], ret[Stages.VALIDATION]["mae"], 56 | ret[Stages.TEST]["acc"], ret[Stages.TEST]["mse"], ret[Stages.TEST]["mae"],) 57 | for hp, ret in results.items()), 58 | columns=["Parameter", "Dev Acc", "Dev MSE", "Dev MAE", "Test Acc", "Test MSE", "Test MAE"]) 59 | df.to_csv(args.output_file) 60 | return 61 | 62 | 63 | def summary_kuci(args): 64 | results = ddict(ddict) 65 | for subdir in args.input_dir.glob("*"): 66 | if not subdir.is_dir(): 67 | continue 68 | results[subdir.name] = {} 69 | for stage in Stages: 70 | results[subdir.name][stage] = {} 71 | with (subdir / f"{stage.value}_predictions.tsv").open() as f: 72 | f.readline() # skip headerline 73 | ids, labels, preds = zip(*(line.strip().split("\t") 74 | for line in f.readlines())) 75 | 76 | results[subdir.name][stage]["acc"] = sum( 77 | l == p for l, p in zip(labels, preds)) / len(labels) 78 | 79 | log_best_model(results, key=lambda k: results[k][Stages.VALIDATION]["acc"]) 80 | 81 | df = pd.DataFrame( 82 | data=((hp, ret[Stages.VALIDATION]["acc"], ret[Stages.TEST]["acc"]) 83 | for hp, ret in results.items()), 84 | columns=["Parameter", "Dev Acc", "Test Acc"]) 85 | df.to_csv(args.output_file) 86 | return 87 | 88 | 89 | def summary_rcqa(args): 90 | results = {} 91 | for subdir in args.input_dir.glob("*"): 92 | if not subdir.is_dir(): 93 | continue 94 | results[subdir.name] = {} 95 | for stage in Stages: 96 | with (subdir / f"{stage.value}_metrics.json").open() as f: 97 | metrics = json.load(f) 98 | results[subdir.name][stage] = metrics 99 | 100 | log_best_model( 101 | results, key=lambda k: results[k][Stages.VALIDATION]["exact"]) 102 | 103 | df = pd.DataFrame( 104 | data=((hp, 105 | ret[Stages.VALIDATION]["exact"], ret[Stages.VALIDATION]["f1"], 106 | ret[Stages.TEST]["exact"], ret[Stages.TEST]["f1"],) 107 | for hp, ret in results.items()), 108 | columns=["Parameter", "Dev EM", "Dev F1", "Test EM", "Test F1"]) 109 | df.to_csv(args.output_file) 110 | return 111 | 112 | 113 | def log_best_model(results, key): 114 | best_model = max(results, key=key) 115 | logger.info(f"best model: {best_model}") 116 | logger.info(f"result: {results[best_model]}") 117 | return best_model 118 | 119 | 120 | SUMMARY_FUNCS = { 121 | "amazon": summary_amazon, 122 | "kuci": summary_kuci, 123 | "rcqa": summary_rcqa, 124 | } 125 | 126 | 127 | def parse_args(): 128 | parser = ap.ArgumentParser() 129 | parser.add_argument(dest="dataset_name", type=str, 130 | help="Target dataset name. Set \"list\" to list available datasets.") 131 | parser.add_argument(dest="input_dir", type=str, 132 | help="Input directory. 
output_dir of run_evaluation.py.") 133 | 134 | parser.add_argument("-o", "--output", dest="output_file", type=str, default="./output.csv", 135 | help="File to output summary. `output.csv` by default.") 136 | parser.add_argument("--overwrite", action="store_true", 137 | help="Overwrite output files when they already exist.") 138 | 139 | args = parser.parse_args() 140 | args.dataset_name = args.dataset_name.lower() 141 | args.input_dir = Path(args.input_dir) 142 | args.output_file = Path(args.output_file) 143 | return args 144 | 145 | 146 | def validate_args(args): 147 | if args.dataset_name not in SUMMARY_FUNCS and args.dataset_name != "list": 148 | logger.error(f"Unknown dataset name ({args.dataset_name}). " 149 | f"It must be one of {list(SUMMARY_FUNCS.keys())} or \"list\".") 150 | raise ValueError 151 | 152 | if not args.input_dir.is_dir(): 153 | raise ValueError("input should be directory.") 154 | 155 | if not args.overwrite: 156 | if args.output_file.exists(): 157 | raise ValueError( 158 | f"File {args.output_file} already exists. Set --overwrite to continue anyway.") 159 | return 160 | 161 | 162 | def main(): 163 | args = parse_args() 164 | validate_args(args) 165 | 166 | if args.dataset_name == "list": 167 | logger.info(f"Available datasets: {list(SUMMARY_FUNCS.keys())}") 168 | return 169 | 170 | logger.info(f"input_dir: {args.input_dir}") 171 | 172 | summary_func = SUMMARY_FUNCS[args.dataset_name] 173 | summary_func(args) 174 | return 175 | 176 | 177 | if __name__ == "__main__": 178 | main() 179 | -------------------------------------------------------------------------------- /evaluation/tensorflow/classification_utils.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import sys 3 | 4 | import numpy as np 5 | import tensorflow as tf 6 | from transformers import ( 7 | TFAutoModelForSequenceClassification, 8 | ) 9 | 10 | 11 | logger = logging.getLogger(__name__) 12 | logging.basicConfig( 13 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 14 | datefmt="%m/%d/%Y %H:%M:%S", 15 | handlers=[logging.StreamHandler(sys.stderr)], 16 | ) 17 | logger.setLevel(logging.INFO) 18 | 19 | 20 | def setup_args(data_args, raw_datadict): 21 | # num_label to initialize model 22 | dataset_key = list(raw_datadict.keys())[0] # at least one data file exists 23 | label_list = raw_datadict[dataset_key].unique("label") 24 | logger.info(f"classification task with {len(label_list)} labels.") 25 | 26 | data_args.label_list = sorted(label_list) 27 | data_args.label2id = {l: i for i, l in enumerate(label_list)} 28 | 29 | # columns of input text 30 | data_columns = [ 31 | c for c in raw_datadict[dataset_key].column_names if c != "label"] 32 | if "sentence1" in data_columns: 33 | if "sentence2" in data_columns: 34 | text_columns = ["sentence1", "sentence2"] 35 | else: 36 | text_columns = ["sentence1"] 37 | else: 38 | text_columns = data_columns[:2] 39 | 40 | data_args.data_columns = data_columns 41 | data_args.text_columns = text_columns 42 | return data_args 43 | 44 | 45 | def pretokenize_texts(raw_datadict, pretok, data_args): 46 | def subfunc(examples): 47 | for c in data_args.text_columns: 48 | examples[c] = [pretok(s) for s in examples[c]] 49 | return examples 50 | 51 | raw_datadict = raw_datadict.map(subfunc, batched=True) 52 | return raw_datadict 53 | 54 | 55 | def preprocess_dataset(raw_datadict, data_args, tokenizer, max_length): 56 | # Truncate text before tokenization for sudachi, which has a input bytes limit. 
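Long inputs are clipped to MAX_CHAR_LENGTH characters below before being passed to the tokenizer.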
57 | # This may affect the result with a large max_length (tokens). 58 | MAX_CHAR_LENGTH = 2**14 59 | 60 | def subfunc(examples): 61 | # Tokenize texts 62 | texts = ([s[:MAX_CHAR_LENGTH] for s in examples[c]] 63 | for c in data_args.text_columns) 64 | result = tokenizer(*texts, max_length=max_length, truncation=True) 65 | 66 | # Map labels to ids 67 | if "label" in examples: 68 | result["label"] = [ 69 | (data_args.label2id[l] if l != -1 else -1) for l in examples["label"]] 70 | return result 71 | 72 | datadict = raw_datadict.map( 73 | subfunc, 74 | batched=True, 75 | load_from_cache_file=not data_args.overwrite_cache, 76 | remove_columns=data_args.data_columns 77 | ) 78 | return datadict 79 | 80 | 81 | def convert_dataset_for_tensorflow( 82 | dataset, batch_size, dataset_mode="variable_batch", shuffle=True, drop_remainder=False 83 | ): 84 | def densify_ragged_batch(features, label=None): 85 | features = { 86 | feature: ragged_tensor.to_tensor(shape=batch_shape[feature]) for feature, ragged_tensor in features.items() 87 | } 88 | if label is None: 89 | return features 90 | else: 91 | return features, label 92 | 93 | # convert all columns except "label". 94 | # dataset should not have unneccessary columns. 95 | feature_keys = list(set(dataset.features.keys()) - {"label"}) 96 | 97 | # trim input length for each batch 98 | if dataset_mode == "variable_batch": 99 | batch_shape = {key: None for key in feature_keys} 100 | data = {key: tf.ragged.constant(dataset[key]) for key in feature_keys} 101 | elif dataset_mode == "constant_batch": 102 | data = {key: tf.ragged.constant(dataset[key]) for key in feature_keys} 103 | batch_shape = { 104 | key: tf.concat( 105 | ([batch_size], ragged_tensor.bounding_shape()[1:]), axis=0) 106 | for key, ragged_tensor in data.items() 107 | } 108 | else: 109 | raise ValueError(f"Unknown dataset_mode: {dataset_mode}") 110 | 111 | if "label" in dataset.features: 112 | labels = tf.convert_to_tensor(np.array(dataset["label"])) 113 | tf_dataset = tf.data.Dataset.from_tensor_slices((data, labels)) 114 | else: 115 | tf_dataset = tf.data.Dataset.from_tensor_slices(data) 116 | 117 | if shuffle: 118 | tf_dataset = tf_dataset.shuffle(buffer_size=len(dataset)) 119 | 120 | # ref: https://github.com/tensorflow/tensorflow/issues/42146 121 | options = tf.data.Options() 122 | options.experimental_distribute.auto_shard_policy = tf.data.experimental.AutoShardPolicy.OFF 123 | 124 | tf_dataset = ( 125 | tf_dataset.with_options(options) 126 | .batch(batch_size=batch_size, drop_remainder=drop_remainder) 127 | .map(densify_ragged_batch) 128 | ) 129 | return tf_dataset 130 | 131 | 132 | def setup_model(model_name_or_path, config, training_args, from_pt=False): 133 | model = TFAutoModelForSequenceClassification.from_pretrained( 134 | model_name_or_path, 135 | config=config, 136 | from_pt=from_pt, 137 | ) 138 | 139 | optimizer = tf.keras.optimizers.Adam( 140 | learning_rate=training_args.learning_rate, 141 | beta_1=training_args.adam_beta1, 142 | beta_2=training_args.adam_beta2, 143 | epsilon=training_args.adam_epsilon, 144 | clipnorm=training_args.max_grad_norm, 145 | ) 146 | loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True) 147 | metrics = ["accuracy"] 148 | 149 | model.compile(optimizer=optimizer, loss=loss_fn, metrics=metrics) 150 | return model 151 | 152 | 153 | def evaluate_model(model, dataset, tf_dataset, label2id, output_dir=None, stage="eval"): 154 | predictions = model.predict(tf_dataset)["logits"] 155 | predicted_class = np.argmax(predictions, axis=1) 156 | 
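    # accuracy is computed directly from the argmax of the logits against the gold label ids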
157 | labels = dataset["label"] 158 | acc = sum(predicted_class == labels) / len(labels) 159 | metrics = {"accuracy": acc} 160 | 161 | if output_dir is not None: 162 | id2label = {i: l for l, i in label2id.items()} 163 | output_file = output_dir / f"{stage}_predictions.tsv" 164 | with open(output_file, "w") as writer: 165 | writer.write("index\tlabel\tprediction\n") 166 | for index, (label, item) in enumerate(zip(labels, predicted_class)): 167 | label = id2label[label] 168 | item = id2label[item] 169 | writer.write(f"{index}\t{label}\t{item}\n") 170 | 171 | return metrics 172 | -------------------------------------------------------------------------------- /evaluation/tensorflow/multiple_choice_utils.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import sys 3 | import itertools as it 4 | 5 | import numpy as np 6 | import tensorflow as tf 7 | from transformers import ( 8 | TFAutoModelForMultipleChoice, 9 | ) 10 | 11 | 12 | logger = logging.getLogger(__name__) 13 | logging.basicConfig( 14 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 15 | datefmt="%m/%d/%Y %H:%M:%S", 16 | handlers=[logging.StreamHandler(sys.stderr)], 17 | ) 18 | logger.setLevel(logging.INFO) 19 | 20 | 21 | def setup_args(data_args, raw_datadict): 22 | # set dataset column name, assuming to use convert_dataset.py 23 | dataset_key = list(raw_datadict.keys())[0] # at least one data file exists 24 | column_names = raw_datadict[dataset_key].column_names 25 | 26 | data_args.context_column = "context" 27 | data_args.choice_columns = [ 28 | c for c in column_names if c.startswith("choice")] 29 | data_args.label_column = "label" 30 | data_args.column_names = column_names 31 | return data_args 32 | 33 | 34 | def pretokenize_texts(raw_datadict, pretok, data_args): 35 | text_columns = data_args.choice_columns + [data_args.context_column] 36 | 37 | def subfunc(examples): 38 | for c in text_columns: 39 | examples[c] = [pretok(s) for s in examples[c]] 40 | return examples 41 | 42 | raw_datadict = raw_datadict.map(subfunc, batched=True) 43 | return raw_datadict 44 | 45 | 46 | def preprocess_dataset(raw_datadict, data_args, tokenizer, max_length): 47 | context_column = data_args.context_column 48 | choice_columns = data_args.choice_columns 49 | label_column = data_args.label_column 50 | n_choices = len(choice_columns) 51 | 52 | def subfunc(examples): 53 | first_sentences = ([c] * n_choices for c in examples[context_column]) 54 | second_sentences = (examples[clm] for clm in choice_columns) 55 | 56 | # flatten 57 | first_sentences = list(it.chain(*first_sentences)) 58 | second_sentences = list(it.chain(*zip(*second_sentences))) 59 | 60 | # tokenize 61 | tokenized = tokenizer( 62 | first_sentences, 63 | second_sentences, 64 | max_length=max_length, 65 | truncation=True, 66 | ) 67 | 68 | # un-flatten 69 | result = {k: [v[i:i+n_choices] for i in range(0, len(v), n_choices)] 70 | for k, v in tokenized.items()} 71 | 72 | # keep label column as it is, assuming it contains 0-indexed integer 73 | return result 74 | 75 | data_columns = [c for c in data_args.column_names if c != label_column] 76 | 77 | datadict = raw_datadict.map( 78 | subfunc, 79 | batched=True, 80 | load_from_cache_file=not data_args.overwrite_cache, 81 | remove_columns=data_columns, 82 | ) 83 | return datadict 84 | 85 | 86 | def convert_dataset_for_tensorflow( 87 | dataset, batch_size, dataset_mode="variable_batch", shuffle=True, drop_remainder=False 88 | ): 89 | def densify_ragged_batch(features, 
label=None): 90 | features = { 91 | feature: ragged_tensor.to_tensor(shape=batch_shape[feature]) for feature, ragged_tensor in features.items() 92 | } 93 | if label is None: 94 | return features 95 | else: 96 | return features, label 97 | 98 | # convert all columns except "label". 99 | # dataset should not have unneccessary columns. 100 | feature_keys = list(set(dataset.features.keys()) - {"label"}) 101 | 102 | # trim input length for each batch 103 | if dataset_mode == "variable_batch": 104 | batch_shape = {key: None for key in feature_keys} 105 | data = {key: tf.ragged.constant(dataset[key]) for key in feature_keys} 106 | elif dataset_mode == "constant_batch": 107 | data = {key: tf.ragged.constant(dataset[key]) for key in feature_keys} 108 | batch_shape = { 109 | key: tf.concat( 110 | ([batch_size], ragged_tensor.bounding_shape()[1:]), axis=0) 111 | for key, ragged_tensor in data.items() 112 | } 113 | else: 114 | raise ValueError(f"Unknown dataset_mode: {dataset_mode}") 115 | 116 | if "label" in dataset.features: 117 | labels = tf.convert_to_tensor(np.array(dataset["label"])) 118 | tf_dataset = tf.data.Dataset.from_tensor_slices((data, labels)) 119 | else: 120 | tf_dataset = tf.data.Dataset.from_tensor_slices(data) 121 | 122 | if shuffle: 123 | tf_dataset = tf_dataset.shuffle(buffer_size=len(dataset)) 124 | 125 | # ref: https://github.com/tensorflow/tensorflow/issues/42146 126 | options = tf.data.Options() 127 | options.experimental_distribute.auto_shard_policy = tf.data.experimental.AutoShardPolicy.OFF 128 | 129 | tf_dataset = ( 130 | tf_dataset.with_options(options) 131 | .batch(batch_size=batch_size, drop_remainder=drop_remainder) 132 | .map(densify_ragged_batch) 133 | ) 134 | return tf_dataset 135 | 136 | 137 | def setup_model(model_name_or_path, config, training_args, from_pt=False): 138 | model = TFAutoModelForMultipleChoice.from_pretrained( 139 | model_name_or_path, 140 | config=config, 141 | from_pt=from_pt, 142 | ) 143 | 144 | optimizer = tf.keras.optimizers.Adam( 145 | learning_rate=training_args.learning_rate, 146 | beta_1=training_args.adam_beta1, 147 | beta_2=training_args.adam_beta2, 148 | epsilon=training_args.adam_epsilon, 149 | clipnorm=training_args.max_grad_norm, 150 | ) 151 | loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True) 152 | metrics = [tf.keras.metrics.SparseCategoricalAccuracy(name="accuracy")] 153 | 154 | model.compile(optimizer=optimizer, loss=loss_fn, metrics=metrics) 155 | return model 156 | 157 | 158 | def evaluate_model(model, dataset, tf_dataset, output_dir=None, stage="eval"): 159 | metrics = model.evaluate(tf_dataset, return_dict=True) 160 | labels = dataset["label"] 161 | 162 | if output_dir is not None: 163 | predictions = model.predict(tf_dataset)["logits"] 164 | predicted_class = np.argmax(predictions, axis=1) 165 | 166 | output_file = output_dir / f"{stage}_predictions.tsv" 167 | with open(output_file, "w") as writer: 168 | writer.write("index\tlabel\tprediction\n") 169 | for index, (label, item) in enumerate(zip(labels, predicted_class)): 170 | writer.write(f"{index}\t{label}\t{item}\n") 171 | 172 | return metrics 173 | -------------------------------------------------------------------------------- /evaluation/tensorflow/requirements.txt: -------------------------------------------------------------------------------- 1 | sudachitra 2 | 3 | # frameworks 4 | tensorflow 5 | torch 6 | transformers 7 | tokenizers 8 | datasets 9 | 10 | # QA task 11 | pytextspan 12 | 13 | # Japanese tokenizers 14 | mecab-python3 15 | mojimoji 
16 | pyknp 17 | -------------------------------------------------------------------------------- /evaluation/tensorflow/run_all.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -eu 3 | 4 | # set your own dir 5 | SCRIPT_DIR="./scripts" 6 | MODEL_ROOT="./bert" 7 | DATASET_ROOT="./datasets" 8 | OUTPUT_ROOT="./out" 9 | 10 | # model to search 11 | MODEL_NAMES=( 12 | "tohoku" 13 | "kyoto" 14 | "nict" 15 | "chitra_surface" 16 | "chitra_normalized_and_surface" 17 | "chitra_normalized_conjugation" 18 | "chitra_normalized" 19 | ) 20 | 21 | DATASETS=("amazon" "rcqa" "kuci") 22 | 23 | # Hyperparameters from Appendix A.3, Devlin et al., 2019 24 | BATCHES=(16 32) 25 | LRS=(5e-5 3e-5 2e-5) 26 | EPOCHS=(4) 27 | 28 | # set path to the model files 29 | declare -A MODEL_DIRS=( 30 | ["tohoku"]="cl-tohoku/bert-base-japanese-whole-word-masking" 31 | ["kyoto"]="${MODEL_ROOT}/Japanese_L-12_H-768_A-12_E-30_BPE_WWM_transformers" 32 | ["nict"]="${MODEL_ROOT}/NICT_BERT-base_JapaneseWikipedia_32K_BPE" 33 | ["chitra_surface"]="${MODEL_ROOT}/Wikipedia_surface/phase_2" 34 | ["chitra_normalized_and_surface"]="${MODEL_ROOT}/Wikipedia_normalized_and_surface/phase_2" 35 | ["chitra_normalized_conjugation"]="${MODEL_ROOT}/Wikipedia_normalized_conjugation/phase_2" 36 | ["chitra_normalized"]="${MODEL_ROOT}/Wikipedia_normalized/phase_2" 37 | ) 38 | 39 | function set_model_args() { 40 | MODEL=$1 41 | DATASET=$2 42 | MODEL_DIR="${MODEL_DIRS[$1]}" 43 | DATASET_DIR="${DATASET_ROOT}/${DATASET}" 44 | OUTPUT_DIR="${OUTPUT_ROOT}/${MODEL}_${DATASET}" 45 | export MODEL DATASET MODEL_DIR DATASET_DIR OUTPUT_DIR 46 | 47 | # whether if we load the model from pytorch param 48 | FROM_PT=true 49 | if [ ${MODEL} = "tohoku" ] ; then 50 | FROM_PT=false 51 | fi 52 | export FROM_PT 53 | 54 | # pretokenizer 55 | PRETOKENIZER="identity" 56 | if [ ${MODEL} = "kyoto" ] ; then 57 | PRETOKENIZER="juman" 58 | elif [ ${MODEL} = "nict" ] ; then 59 | PRETOKENIZER="mecab-juman" 60 | fi 61 | export PRETOKENIZER 62 | 63 | # tokenizer (sudachi) 64 | TOKENIZER=${MODEL_DIR} 65 | if [ ${MODEL:0:6} = "chitra" ] ; then 66 | TOKENIZER="sudachi" 67 | fi 68 | export TOKENIZER 69 | } 70 | 71 | command_echo='( echo \ 72 | "${MODEL}, ${DATASET}, ${MODEL_DIR}, ${DATASET_DIR}, ${OUTPUT_DIR}, " \ 73 | "${FROM_PT}, ${PRETOKENIZER}, "${TOKENIZER}, ${BATCH}, ${LR}, ${EPOCH}, " \ 74 | )' 75 | 76 | export SCRIPT_PATH="${SCRIPT_DIR}/run_evaluation.py" 77 | command_run='( \ 78 | python ${SCRIPT_PATH} \ 79 | --model_name_or_path ${MODEL_DIR} \ 80 | --from_pt ${FROM_PT} \ 81 | --pretokenizer_name ${PRETOKENIZER} \ 82 | --tokenizer_name ${TOKENIZER} \ 83 | --dataset_name ${DATASET} \ 84 | --dataset_dir ${DATASET_DIR} \ 85 | --output_dir ${OUTPUT_DIR} \ 86 | --do_train \ 87 | --do_eval \ 88 | --do_predict \ 89 | --per_device_eval_batch_size 64 \ 90 | --per_device_train_batch_size ${BATCH} \ 91 | --learning_rate ${LR} \ 92 | --num_train_epochs ${EPOCH} \ 93 | --overwrite_cache \ 94 | # --max_train_samples 100 \ 95 | # --max_val_samples 100 \ 96 | # --max_test_samples 100 \ 97 | )' 98 | 99 | # mkdir for log 100 | mkdir -p logs 101 | /bin/true > logs/jobs.txt 102 | 103 | for DATASET in ${DATASETS[@]}; do 104 | for MODEL in ${MODEL_NAMES[@]}; do 105 | for BATCH in ${BATCHES[@]}; do 106 | for LR in ${LRS[@]}; do 107 | for EPOCH in ${EPOCHS[@]}; do 108 | export BATCH LR EPOCH 109 | set_model_args ${MODEL} ${DATASET} 110 | 111 | script -c "${command_echo}" logs/echo.log 112 | script -c "${command_run}" 
logs/${MODEL}_${DATASET}_batch${BATCH}_lr${LR}_epochs${EPOCH}.log 113 | done 114 | done 115 | done 116 | done 117 | done 118 | -------------------------------------------------------------------------------- /evaluation/tensorflow/tokenizer_utils.py: -------------------------------------------------------------------------------- 1 | import MeCab 2 | import mojimoji 3 | import pyknp 4 | import unicodedata as ud 5 | 6 | 7 | class Identity(): 8 | is_identity = False 9 | 10 | def __init__(self): 11 | self.is_identity = True 12 | return 13 | 14 | def tokenize(self, line: str) -> str: 15 | return line 16 | 17 | def __call__(self, line: str) -> str: 18 | return self.tokenize(line) 19 | 20 | 21 | class MecabJuman(Identity): 22 | # tokenization for NICT BERT 23 | def __init__(self, dicdir: str = None, mecabrc: str = None): 24 | # assume existance of followings (installed by `apt install mecab`) 25 | dicdir = dicdir or "/var/lib/mecab/dic/juman-utf8" 26 | mecabrc = mecabrc or "/etc/mecabrc" 27 | assert dicdir and mecabrc 28 | 29 | tagger = MeCab.Tagger(f"-r {mecabrc} -d {dicdir} -Owakati") 30 | charset = tagger.dictionary_info().charset 31 | assert charset in ["utf-8", "utf8"] 32 | 33 | self.tagger = tagger 34 | return 35 | 36 | def tokenize(self, line: str) -> str: 37 | # tokenize text and 38 | normalized = mojimoji.han_to_zen(line).replace("\u3000", " ") 39 | separated = self.tagger.parse(normalized).rstrip() 40 | # rm surrogate char 41 | result = "".join(ch for ch in separated if ud.category(ch) != "Cs") 42 | return result 43 | 44 | 45 | class Juman(Identity): 46 | # tokenization for Kyoto-U BERT 47 | def __init__(self): 48 | # assume Juman++ is installed (see install_jumanpp.sh) 49 | self.tok = pyknp.Juman() 50 | return 51 | 52 | def tokenize(self, line: str) -> str: 53 | normalized = mojimoji.han_to_zen(line) 54 | 55 | # truncate input according to the jumanpp input limit 56 | truncated = _utf8_byte_truncate(normalized, 4096) 57 | morphs = self.tok.analysis(truncated) 58 | separated = " ".join(m.midasi for m in morphs) 59 | return separated 60 | 61 | 62 | def _utf8_lead_byte(b): 63 | '''A UTF-8 intermediate byte starts with the bits 10xxxxxx.''' 64 | return (b & 0xC0) != 0x80 65 | 66 | 67 | def _utf8_byte_truncate(text: str, max_bytes: int): 68 | utf8 = text.encode('utf8') 69 | if len(utf8) <= max_bytes: 70 | return text 71 | # separate before lead byte 72 | i = max_bytes 73 | while i > 0 and not _utf8_lead_byte(utf8[i]): 74 | i -= 1 75 | return utf8[:i].decode("utf8") 76 | -------------------------------------------------------------------------------- /misc/license-header.txt: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021 Works Applications Co., Ltd. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
-------------------------------------------------------------------------------- /pretraining/bert/README.md: -------------------------------------------------------------------------------- 1 | # Training Sudachi BERT Models 2 | 3 | This repository also provides a script and recipe to train the [BERT](https://arxiv.org/abs/1810.04805) model. 4 | 5 | ## Pretrained models 6 | 7 | You can download the pretrained models from [README](../../README.md). 8 | 9 | ## Set up 10 | 11 | In order to pretrain models, you need to download this repository, including its submodules. 12 | 13 | ```shell script 14 | $ git clone --recursive https://github.com/WorksApplications/SudachiTra/ 15 | ``` 16 | 17 | In addition, you need to install the required packages to pretrain models. 18 | 19 | ```shell script 20 | $ pip install -U sudachitra 21 | $ cd SudachiTra/ 22 | $ pip install -r requirements.txt 23 | $ pip install -r pretraining/bert/requirements.txt 24 | $ pip install -r pretraining/bert/models/official/requirements.txt 25 | ``` 26 | 27 | ## Quick Start Guide 28 | 29 | In the following guide, we use the [wiki40b](https://www.tensorflow.org/datasets/catalog/wiki40b) dataset. 30 | 31 | ### 1. Download Wiki40b 32 | 33 | For pre-training BERT, you need to prepare the data split into document units. 34 | 35 | The [`run_prepare_dataset.sh`](run_prepare_dataset.sh) script launches download and processing of wiki40b. 36 | The component steps in the script to prepare the datasets are as follows: 37 | 38 | * Data download - wiki40b is downloaded in the `datasets/corpus` directory. 39 | * Sentence segmentation - the corpus text is processed into separate sentences. 40 | * Document segmentation - the corpus text is divided into documents. 41 | 42 | The processed data is saved in `./datasets/corpus_splitted_by_paragraph`. 43 | The corpus files are approximately 4.0GB in total. 44 | 45 | ```shell script 46 | $ cd pretraining/bert/ 47 | # It may take several hours. 48 | $ ./run_prepare_dataset.sh 49 | ``` 50 | 51 | ### 2. Preprocessing: Corpus Cleaning 52 | 53 | Some sentences in the downloaded corpus are too short or too long to be useful for learning. 54 | There are also documents that are too short or contain inappropriate words. 55 | To filter out such sentences and documents, use the [`preprocess_dataset.py`](preprocess_dataset.py) script. 56 | In addition to the cleaning process, this script also performs sentence-level and document-level normalization. 57 | 58 | An example command that applies all cleaning and normalization processes: 59 | 60 | ```shell 61 | $ python3 preprocess_dataset.py \ 62 | -i ./datasets/corpus_splitted_by_paragraph/ja_wiki40b_train.paragraph.txt \ 63 | -o ./datasets/corpus_splitted_by_paragraph/ja_wiki40b_train.preprocessed.paragraph.txt \ 64 | --sentence_filter_names email url sequence_length \ 65 | --document_filter_names short_document script ng_words \ 66 | --sentence_normalizer_names citation whitespace \ 67 | --document_normalizer_names concat_short_sentence 68 | ``` 69 | 70 | 71 | ### 3. Building vocabulary 72 | 73 | You can specify tokenizer options of SudachiPy, such as sudachi dictionaries, split modes, and word forms.
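These options correspond to the `--dict_type`, `--split_mode`, and `--word_form_type` options of the tokenizer training scripts shown below.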
74 | 75 | The following word forms are available: 76 | 77 | * `surface` 78 | * `dictionary` 79 | * `normalized` 80 | * `dictionary_and_surface` 81 | * `normalized_and_surface` 82 | 83 | This repository implements three kinds of subword tokenizers: 84 | 85 | * `WordPiece` 86 | * `Character` 87 | * `POS Substitution (part-of-speech substitution)` 88 | 89 | #### WordPiece 90 | 91 | We used WordPiece to obtain subwords. 92 | We used an implementation of WordPiece in [Tokenizers](https://github.com/huggingface/tokenizers). 93 | 94 | ```shell script 95 | $ python3 train_wordpiece_tokenizer.py \ 96 | --input_file datasets/corpus_splitted_by_paragraph/ja_wiki40b_train.preprocessed.paragraph.txt \ 97 | --do_nfkc \ 98 | --vocab_size 32000 \ 99 | --limit_alphabet 5000 \ 100 | --dict_type core \ 101 | --split_mode C \ 102 | --word_form_type normalized \ 103 | --output_dir _tokenizers/ja_wiki40b/wordpiece/train_CoreDic_normalized_unit-C \ 104 | --config_name config.json \ 105 | --vocab_prefix wordpiece 106 | ``` 107 | 108 | #### Character 109 | 110 | You can get a vocabulary for `Character` by extracting only the characters from the vocabulary created by `WordPiece` tokenization. 111 | 112 | ```shell script 113 | # e.g. #characters(5,000) + #special_tokens(5) = 5,005 114 | $ OUTPUT_DIR="tokenizers/ja_wiki40b/character/train_CoreDic_normalized_unit-C" 115 | $ mkdir -p $OUTPUT_DIR 116 | $ head -n 5005 _tokenizers/ja_wiki40b/wordpiece/train_CoreDic_normalized_unit-C/wordpiece-vocab.txt > $OUTPUT_DIR/vocab.txt 117 | ``` 118 | 119 | #### POS Substitution (part-of-speech substitution) 120 | 121 | `POS Substitution` is a method that uses part-of-speech tags to reduce the vocabulary size. 122 | In `POS Substitution`, instead of using a subword tokenizer, low frequency words are replaced by part-of-speech tags. 123 | Finally, only part-of-speech tags that do not appear in a training corpus are treated as unknown words. 124 | 125 | 126 | ```shell script 127 | $ python3 train_pos_substitution_tokenizer.py \ 128 | --input_file datasets/corpus_splitted_by_paragraph/ja_wiki40b_train.preprocessed.paragraph.txt \ 129 | --token_size 32000 \ 130 | --limit_character 5000 \ 131 | --dict_type core \ 132 | --split_mode C \ 133 | --word_form_type normalized \ 134 | --output_file _tokenizers/ja_wiki40b/pos_substitution/train_CoreDic_normalized_unit-C/vocab.txt 135 | ``` 136 | 137 | ### 4. Creating data for pretraining 138 | 139 | To create the data for pre-training, we utilize code based on [TensorFlow Model Garden](https://github.com/tensorflow/models). 140 | The code to create the pre-training data with the tokenizer modified for SudachiPy is `pretraining/models/official/nlp/data/create_pretraining_data.py`. 141 | 142 | This code will consume a lot of memory. 143 | We can handle this by splitting the training corpus into multiple files and processing them in parallel. 144 | Therefore, we recommend splitting the training data into multiple files. 145 | 146 | In the following example, the number of sentences per file (`--line_per_file`) is set to 700,000. 147 | It consumes about 10 GB or more of memory to create the data for pre-training from this one file.
148 | 149 | 150 | ```shell script 151 | # splits wiki40b into multiple files 152 | $ python3 split_dataset.py \ 153 | --input_file datasets/corpus_splitted_by_paragraph/ja_wiki40b_train.preprocessed.paragraph.txt \ 154 | --line_per_file 700000 155 | $ TRAIN_FILE_NUM=`find datasets/corpus_splitted_by_paragraph -type f | grep -E "ja_wiki40b_train.preprocessed.paragraph[0-9]+.txt" | wc -l` 156 | ``` 157 | 158 | ```shell script 159 | # Change the value according to the execution environment. 160 | $ MAX_PROCS=8 161 | 162 | $ mkdir datasets_for_pretraining 163 | $ export PYTHONPATH="$PYTHONPATH:./models" 164 | $ seq 1 ${TRAIN_FILE_NUM} | xargs -L 1 -I {} -P ${MAX_PROCS} python3 models/official/nlp/data/create_pretraining_data.py \ 165 | --input_file datasets/corpus_splitted_by_paragraph/ja_wiki40b_train.preprocessed.paragraph{}.txt \ 166 | --output_file datasets_for_pretraining/pretraining_train_{}.tf_record \ 167 | --do_nfkc \ 168 | --vocab_file _tokenizers/ja_wiki40b/wordpiece/train_CoreDic_normalized_unit-C/wordpiece-vocab.txt \ 169 | --tokenizer_type wordpiece \ 170 | --word_form_type normalized \ 171 | --split_mode C \ 172 | --sudachi_dic_type core \ 173 | --do_whole_word_mask \ 174 | --max_seq_length 512 \ 175 | --max_predictions_per_seq 80 \ 176 | --dupe_factor 10 177 | ``` 178 | 179 | ### 5. Training 180 | 181 | #### NVIDIA DeepLearningExamples 182 | 183 | To pretrain a model, we utilize code based on [NVIDIA Deep Learning Examples](https://github.com/NVIDIA/DeepLearningExamples). 184 | 185 | nvidia-docker is used. 186 | Put the training data in this directory (`SudachiTra/pretraining/bert/DeepLearningExamples/TensorFlow2/LanguageModeling/BERT/data`). 187 | 188 | ```shell script 189 | $ docker pull nvcr.io/nvidia/tensorflow:21.10-tf2-py3 190 | $ cd DeepLearningExamples/TensorFlow2/LanguageModeling/BERT 191 | $ bash scripts/docker/build.sh 192 | $ bash scripts/docker/launch.sh 193 | 194 | $ python3 data/bertPrep.py --action download --dataset google_pretrained_weights # Change the config if necessary. ex. vocab_size 195 | $ bash scripts/run_pretraining_lamb.sh 176 22 8 7.5e-4 5e-4 tf32 true 4 2000 200 11374 100 64 192 base # Change the path in run_pretraining_lamb if necessary. 196 | ``` 197 | 198 | ### 6. Converting a model to PyTorch format 199 | 200 | #### NVIDIA DeepLearningExamples 201 | 202 | To convert the generated model checkpoints to PyTorch, you can use `convert_original_tf2_checkpoint_to_pytorch_nvidia.py`. 203 | 204 | ```shell script 205 | $ cd SudachiTra/pretraining/bert/ 206 | $ python3 convert_original_tf2_checkpoint_to_pytorch_nvidia.py \ 207 | --tf_checkpoint_path /path/to/checkpoint \ 208 | --config_file /path/to/bert_config.json \ 209 | --pytorch_dump_path /path/to/pytorch_model.bin 210 | ``` -------------------------------------------------------------------------------- /pretraining/bert/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021 Works Applications Co., Ltd. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | -------------------------------------------------------------------------------- /pretraining/bert/corpus_preprocessing/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021 Works Applications Co., Ltd. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | -------------------------------------------------------------------------------- /pretraining/bert/corpus_preprocessing/filter/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021 Works Applications Co., Ltd. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | -------------------------------------------------------------------------------- /pretraining/bert/corpus_preprocessing/filter/document_filter/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021 Works Applications Co., Ltd. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from .document_filter import ( 16 | DF, SequenceDocumentFilter, 17 | ShortDocumentFilter, 18 | NGWordsFilter, 19 | ScriptFilter 20 | ) 21 | from .document_filter_name import DocumentFilterName 22 | -------------------------------------------------------------------------------- /pretraining/bert/corpus_preprocessing/filter/document_filter/document_filter.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021 Works Applications Co., Ltd. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import re 16 | from re import Match 17 | from sudachipy import Dictionary, SplitMode 18 | from typing import List, TypeVar 19 | 20 | 21 | class DocumentFilter(object): 22 | 23 | def is_filtered(self, document: List[str]) -> bool: 24 | """ 25 | Determine if the input document should be filtered or not. 26 | 27 | Args: 28 | document (List[str]): Sentences in a training corpus to pretrain model. 29 | 30 | Returns: 31 | bool: `True` if the document should be filtered, otherwise `False`. 32 | """ 33 | raise NotImplementedError() 34 | 35 | 36 | DF = TypeVar('DF', bound=DocumentFilter) 37 | 38 | 39 | class SequenceDocumentFilter(DocumentFilter): 40 | def __init__(self, document_filters: List[DF]): 41 | """ 42 | Constructs a `SequenceDocumentFilter`. 43 | 44 | Args: 45 | document_filters (List[DF]): A list of DocumentFilters. 46 | """ 47 | self._document_filters: List[DF] = document_filters 48 | 49 | def is_filtered(self, document: List[str]) -> bool: 50 | """ 51 | Applies filters to the document in sequence. 52 | 53 | Args: 54 | document (List[str]): A list of sentences. 55 | 56 | Returns: 57 | bool: `True` if the document should be filtered, otherwise `False`. 58 | """ 59 | return any([df.is_filtered(document) for df in self._document_filters]) 60 | 61 | 62 | class ScriptFilter(DocumentFilter): 63 | curly_brackets_ptn = re.compile(r'[\{|\}]') 64 | 65 | def is_filtered(self, document: List[str]) -> bool: 66 | """ 67 | Determines if the document contains the curly bracket ```{``` to if it contains code. 68 | 69 | Args: 70 | document (List[str]): A list of sentences. 71 | 72 | Returns: 73 | bool: `True` if the document contains the curly bracket, otherwise `False`. 74 | """ 75 | return any([self.curly_brackets_ptn.search(sentence) for sentence in document]) 76 | 77 | 78 | class ShortDocumentFilter(DocumentFilter): 79 | 80 | def __init__(self, min_sentence_num: int = 5): 81 | """ 82 | Constructs a ShortDocumentFilter. 83 | 84 | Args: 85 | min_sentence_num (int): The minimum number of sentences a document should contain. 86 | """ 87 | self.min_sentence_num = min_sentence_num 88 | 89 | def is_filtered(self, document: List[str]) -> bool: 90 | """ 91 | Determines if the number of sentences in the document is suitable. 92 | 93 | Args: 94 | document (List[str]): A list of sentences. 95 | 96 | Returns: 97 | bool: `True` if the document is short, otherwise `False`. 98 | """ 99 | return len(document) < self.min_sentence_num 100 | 101 | 102 | class NGWordsFilter(DocumentFilter): 103 | DICT_TYPE = 'core' 104 | SPLIT_MODE = SplitMode.A 105 | 106 | def __init__(self, ng_words_file_path: str): 107 | """ 108 | Constructs a NGWordsFilter. 109 | 110 | Args: 111 | ng_words_file_path (str): A file path of NG word list. 
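The file is expected to contain one NG word per line; blank lines are ignored.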
112 | """ 113 | self.ng_words_file_path = ng_words_file_path 114 | with open(self.ng_words_file_path, 'r', encoding='utf-8') as f: 115 | ng_words = [line.rstrip() for line in f if line.strip() != ''] 116 | self.ng_words_ptn = re.compile(r'({})'.format('|'.join(ng_words))) 117 | 118 | self.sudachi = Dictionary(dict=self.DICT_TYPE).create(self.SPLIT_MODE) 119 | 120 | def is_matched_by_morpheme(self, match: Match, sentence: str) -> bool: 121 | """ 122 | Determines if a substring in the sentence matches at the morphological level. 123 | 124 | Args: 125 | match (Match): A Match object. 126 | sentence (str): A sentence. 127 | 128 | Returns: 129 | bool: `True` if a substring is included in the sentence as a word, otherwise `False`. 130 | """ 131 | matched_begin_id, matched_end_id = match.span() 132 | 133 | morph_begin_ids = set() 134 | morph_end_ids = set() 135 | for m in self.sudachi.tokenize(sentence): 136 | morph_begin_id, morph_end_id = m.begin(), m.end() 137 | if morph_begin_id <= matched_begin_id: 138 | morph_begin_ids.add(morph_begin_id) 139 | morph_end_ids.add(morph_end_id) 140 | else: 141 | break 142 | 143 | return matched_begin_id in morph_begin_ids and matched_end_id in morph_end_ids 144 | 145 | def contain_ng_words(self, sentence: str) -> bool: 146 | """ 147 | Determines if the sentence contains NG words. 148 | 149 | Args: 150 | sentence (str): A sentence. 151 | 152 | Returns: 153 | bool: `True` if the sentence contains NG words, otherwise `False`. 154 | """ 155 | matches = [match for match in self.ng_words_ptn.finditer(sentence)] 156 | if matches: 157 | return any([self.is_matched_by_morpheme(match, sentence) for match in matches]) 158 | else: 159 | return False 160 | 161 | def is_filtered(self, document: List[str]) -> bool: 162 | """ 163 | Determines if the document contains even a single NG word. 164 | 165 | Args: 166 | document (List[str]): A list of sentences. 167 | 168 | Returns: 169 | bool: `True` if the document contains NG words, otherwise `False`. 170 | """ 171 | return any([self.contain_ng_words(sentence) for sentence in document]) 172 | -------------------------------------------------------------------------------- /pretraining/bert/corpus_preprocessing/filter/document_filter/document_filter_name.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021 Works Applications Co., Ltd. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from enum import Enum 16 | 17 | 18 | class DocumentFilterName(str, Enum): 19 | NG_WORDS = 'ng_words' 20 | SHORT_DOCUMENT = 'short_document' 21 | SCRIPT = 'script' 22 | 23 | def __str__(self): 24 | return self.value 25 | -------------------------------------------------------------------------------- /pretraining/bert/corpus_preprocessing/filter/sentence_filter/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021 Works Applications Co., Ltd. 
2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from .sentence_filter import ( 16 | SF, SequenceSentenceFilter, 17 | SentenceFilter, 18 | UrlFilter, 19 | EmailFilter, 20 | SequenceLengthFilter 21 | ) 22 | from .sentence_filter_name import SentenceFilterName 23 | -------------------------------------------------------------------------------- /pretraining/bert/corpus_preprocessing/filter/sentence_filter/sentence_filter.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021 Works Applications Co., Ltd. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import re 16 | from typing import List, TypeVar 17 | 18 | 19 | class SentenceFilter(object): 20 | 21 | def is_filtered(self, sentence: str) -> bool: 22 | """ 23 | Determine if the input sentence should be filtered or not. 24 | 25 | Args: 26 | sentence (str): A sentence of a document in a training corpus to pretrain model. 27 | 28 | Returns: 29 | bool: `True` if the sentence should be filtered, otherwise `False`. 30 | """ 31 | raise NotImplementedError() 32 | 33 | 34 | SF = TypeVar('SF', bound=SentenceFilter) 35 | 36 | 37 | class SequenceSentenceFilter(SentenceFilter): 38 | def __init__(self, sentence_filters: List[SF]): 39 | """ 40 | Constructs a SequenceSentenceFilter. 41 | 42 | Args: 43 | sentence_filters (List[SF]): A list of SentenceFilters. 44 | """ 45 | self._sentence_filters: List[SF] = sentence_filters 46 | 47 | def is_filtered(self, sentence: str) -> bool: 48 | """ 49 | Applies filters to the sentence in sequence. 50 | 51 | Args: 52 | sentence (str): A sentence. 53 | 54 | Returns: 55 | bool:`True` if the sentence should be filtered, otherwise `False`. 56 | """ 57 | return any([sf.is_filtered(sentence) for sf in self._sentence_filters]) 58 | 59 | 60 | class UrlFilter(SentenceFilter): 61 | url_pattern = re.compile(r'(https?|sftp?)://[\w/:%#\$&\?\(\)~\.=\+\-]+') 62 | 63 | def is_filtered(self, sentence: str) -> bool: 64 | """ 65 | Determines if the sentence contains URL. 66 | 67 | Args: 68 | sentence (str): A sentence. 69 | 70 | Returns: 71 | bool: `True` if the sentence contains URL, otherwise `False`. 72 | """ 73 | return bool(self.url_pattern.search(sentence)) 74 | 75 | 76 | class EmailFilter(SentenceFilter): 77 | mail_pattern = re.compile(r'[\w\d_-]+@[\w\d_-]+\.[\w\d._-]+') 78 | 79 | def is_filtered(self, sentence: str) -> bool: 80 | """ 81 | Determines if the sentence contains email address. 
82 | 83 | Args: 84 | sentence (str): A sentence. 85 | 86 | Returns: 87 | bool: `True` if the sentence contains email address, otherwise `False`. 88 | """ 89 | return bool(self.mail_pattern.search(sentence)) 90 | 91 | 92 | class SequenceLengthFilter(SentenceFilter): 93 | 94 | def __init__(self, min_seq_len: int = 10, max_seq_len: int = 200): 95 | """ 96 | Constructs a SequenceLengthFilter. 97 | 98 | Args: 99 | min_seq_len (int): The minimum number of characters a sentence should contain. 100 | max_seq_len (int): The maximum number of characters a sentence should contain. 101 | """ 102 | self.min_seq_len = min_seq_len 103 | self.max_seq_len = max_seq_len 104 | 105 | def is_filtered(self, sentence: str) -> bool: 106 | """ 107 | Determines if the number of characters in the sentence is suitable. 108 | 109 | Args: 110 | sentence (str): A sentence. 111 | 112 | Returns: 113 | bool: `True` if the sentence length is either too short or too long, otherwise `False`. 114 | """ 115 | return len(sentence) < self.min_seq_len or self.max_seq_len < len(sentence) 116 | -------------------------------------------------------------------------------- /pretraining/bert/corpus_preprocessing/filter/sentence_filter/sentence_filter_name.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021 Works Applications Co., Ltd. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from enum import Enum 16 | 17 | 18 | class SentenceFilterName(str, Enum): 19 | EMAIL = 'email' 20 | URL = 'url' 21 | SEQUENCE_LENGTH = 'sequence_length' 22 | 23 | def __str__(self): 24 | return self.value 25 | -------------------------------------------------------------------------------- /pretraining/bert/corpus_preprocessing/normalizer/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021 Works Applications Co., Ltd. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | -------------------------------------------------------------------------------- /pretraining/bert/corpus_preprocessing/normalizer/document_normalizer/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021 Works Applications Co., Ltd. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from .document_normalizer import ( 16 | DN, SequenceDocumentNormalizer, 17 | ConcatShortSentenceNormalizer 18 | ) 19 | from .document_normalizer_name import DocumentNormalizerName 20 | -------------------------------------------------------------------------------- /pretraining/bert/corpus_preprocessing/normalizer/document_normalizer/document_normalizer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021 Works Applications Co., Ltd. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from typing import List, TypeVar 16 | 17 | 18 | class DocumentNormalizer(object): 19 | 20 | def normalize(self, document: List[str]) -> List[str]: 21 | """ 22 | Normalizes the document. 23 | 24 | Args: 25 | document (List[str]): A list of sentences. 26 | 27 | Returns: 28 | List[str]: A normalized document. 29 | """ 30 | raise NotImplementedError() 31 | 32 | 33 | DN = TypeVar('DN', bound=DocumentNormalizer) 34 | 35 | 36 | class SequenceDocumentNormalizer(DocumentNormalizer): 37 | def __init__(self, document_normalizers: List[DN]): 38 | """ 39 | Constructs a SequenceDocumentNormalizer. 40 | 41 | Args: 42 | document_normalizers (List[DN]): A list of DocumentNormalizers. 43 | """ 44 | self._document_normalizers: List[DN] = document_normalizers 45 | 46 | def normalize(self, document: List[str]) -> List[str]: 47 | """ 48 | Applies normalizers to the document in sequence. 49 | 50 | Args: 51 | document (List[str]): A list of sentences. 52 | 53 | Returns: 54 | List[str]: A normalized document. 55 | """ 56 | for document_normalizer in self._document_normalizers: 57 | document = document_normalizer.normalize(document) 58 | 59 | return document 60 | 61 | 62 | class ConcatShortSentenceNormalizer(DocumentNormalizer): 63 | 64 | def __init__(self, concat_char_num: int = 2): 65 | """ 66 | Constructs a ConcatShortSentenceNormalizer. 67 | 68 | Args: 69 | concat_char_num (int): The maximum number of characters to be concatenated with the previous sentence. 70 | """ 71 | self.concat_char_num = concat_char_num 72 | 73 | def normalize(self, document: List[str]) -> List[str]: 74 | """ 75 | Joins a short sentence that are only a few characters to the previous sentence. 76 | 77 | Args: 78 | document (List[str]): A list of sentences. 79 | 80 | Returns: 81 | List[str]: A document with short sentences concatenated. 
82 | """ 83 | 84 | if len(document) == 1: 85 | return document 86 | else: 87 | concat_ids = [] 88 | for i, sentence in enumerate(document): 89 | if 0 < i and len(sentence) <= self.concat_char_num: 90 | concat_ids.append(i) 91 | 92 | for concat_id in concat_ids[::-1]: 93 | document[concat_id - 1] += document.pop(concat_id) 94 | return document 95 | -------------------------------------------------------------------------------- /pretraining/bert/corpus_preprocessing/normalizer/document_normalizer/document_normalizer_name.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021 Works Applications Co., Ltd. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from enum import Enum 16 | 17 | 18 | class DocumentNormalizerName(str, Enum): 19 | CONCAT_SHORT_SENTENCE = 'concat_short_sentence' 20 | 21 | def __str__(self): 22 | return self.value 23 | -------------------------------------------------------------------------------- /pretraining/bert/corpus_preprocessing/normalizer/sentence_normalizer/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021 Works Applications Co., Ltd. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from .sentence_normalizer import ( 16 | SN, SequenceSentenceNormalizer, 17 | CitationNormalizer, 18 | WhitespaceNormalizer 19 | ) 20 | from .sentence_normalizer_name import SentenceNormalizerName 21 | -------------------------------------------------------------------------------- /pretraining/bert/corpus_preprocessing/normalizer/sentence_normalizer/sentence_normalizer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021 Works Applications Co., Ltd. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | import re 16 | from typing import List, TypeVar 17 | 18 | 19 | class SentenceNormalizer(object): 20 | 21 | def normalize(self, sentence: str) -> str: 22 | """ 23 | Normalizes the sentence. 24 | 25 | Args: 26 | sentence (str): A sentence. 27 | 28 | Returns: 29 | str: A normalized sentence. 30 | """ 31 | raise NotImplementedError() 32 | 33 | 34 | SN = TypeVar('SN', bound=SentenceNormalizer) 35 | 36 | 37 | class SequenceSentenceNormalizer(SentenceNormalizer): 38 | def __init__(self, sentence_normalizers: List[SN]): 39 | """ 40 | Constructs a SequenceSentenceNormalizer. 41 | 42 | Args: 43 | sentence_normalizers (List[SN]): A list of SentenceNormalizers. 44 | """ 45 | self._sentence_normalizers: List[SN] = sentence_normalizers 46 | 47 | def normalize(self, sentence: str) -> str: 48 | """ 49 | Applies normalizers to the sentence in sequence. 50 | 51 | Args: 52 | sentence (str): A sentence. 53 | 54 | Returns: 55 | str: A normalized sentence. 56 | """ 57 | for sentence_normalizer in self._sentence_normalizers: 58 | sentence = sentence_normalizer.normalize(sentence) 59 | 60 | return sentence 61 | 62 | 63 | class WhitespaceNormalizer(SentenceNormalizer): 64 | continuous_whitespace_pattern = re.compile(r'\s+') 65 | 66 | def normalize(self, sentence: str) -> str: 67 | """ 68 | Removes invisible characters and replaces consecutive whitespace with a single whitespace. 69 | 70 | Args: 71 | sentence (str): A sentence. 72 | 73 | Returns: 74 | str: A sentence with consecutive whitespace and invisible characters removed. 75 | """ 76 | sentence = "".join(c for c in sentence if c.isprintable()) 77 | sentence = self.continuous_whitespace_pattern.sub(' ', sentence) 78 | 79 | return sentence 80 | 81 | 82 | class CitationNormalizer(SentenceNormalizer): 83 | citation_pattern = re.compile(r'\[\d+?\]|\[要.+?\]|\{\{+[^{}]+?\}\}+|\[(要出典|リンク切れ|.+?\?)\]') 84 | 85 | def normalize(self, sentence: str) -> str: 86 | """ 87 | Removes citation markers. 88 | 89 | Args: 90 | sentence (str): A sentence. 91 | 92 | Returns: 93 | str: A sentence with citation markers removed. 94 | """ 95 | return self.citation_pattern.sub('', sentence) 96 | -------------------------------------------------------------------------------- /pretraining/bert/corpus_preprocessing/normalizer/sentence_normalizer/sentence_normalizer_name.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021 Works Applications Co., Ltd. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from enum import Enum 16 | 17 | 18 | class SentenceNormalizerName(str, Enum): 19 | CITATION = 'citation' 20 | WHITESPACE = 'whitespace' 21 | 22 | def __str__(self): 23 | return self.value 24 | -------------------------------------------------------------------------------- /pretraining/bert/prepare_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021 Works Applications Co., Ltd. 
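# A short sketch of how the sentence- and document-level normalizers in this sub-package fit
# together, mirroring the way preprocess_dataset.py (further below) chains them. Import paths
# assume execution from `pretraining/bert/`, and the sample document is illustrative only.
from corpus_preprocessing.normalizer.sentence_normalizer import (
    SequenceSentenceNormalizer, CitationNormalizer, WhitespaceNormalizer
)
from corpus_preprocessing.normalizer.document_normalizer import (
    SequenceDocumentNormalizer, ConcatShortSentenceNormalizer
)

sentence_normalizer = SequenceSentenceNormalizer([CitationNormalizer(), WhitespaceNormalizer()])
document_normalizer = SequenceDocumentNormalizer([ConcatShortSentenceNormalizer(concat_char_num=2)])

document = ['地球は太陽系の惑星である[1]。', '表面積は約5億  km2', 'だ。']
document = [sentence_normalizer.normalize(s) for s in document]
# -> ['地球は太陽系の惑星である。', '表面積は約5億 km2', 'だ。']   (citation marker removed, whitespace collapsed)
document = document_normalizer.normalize(document)
# -> ['地球は太陽系の惑星である。', '表面積は約5億 km2だ。']        (the 2-character fragment is re-attached)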
2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import argparse 16 | import os 17 | import tensorflow_datasets as tfds 18 | from bunkai import Bunkai 19 | from progressbar import progressbar as tqdm 20 | from tensorflow_datasets.core import DatasetInfo 21 | from typing import List, Tuple 22 | 23 | 24 | TARGET_DATASETS = ['train', 'validation', 'test'] 25 | 26 | START_ARTICLE_DELIMITER = '_START_ARTICLE_' 27 | START_PARAGRAPH_DELIMITER = '_START_PARAGRAPH_' 28 | NEW_LINE_DELIMITER = '_NEWLINE_' 29 | 30 | 31 | def split_article(article_text: str) -> List[List[str]]: 32 | """ 33 | Splits an article into paragraphs. 34 | 35 | Args: article_text (str): An article in wikipedia. 36 | 37 | Returns: 38 | List[List[str]]: List of paragraphs containing sentences 39 | """ 40 | paragraphs = [] 41 | lines = article_text.split('\n') 42 | for i in range(2, len(lines), 2): 43 | if lines[i-1] == START_PARAGRAPH_DELIMITER: 44 | paragraphs.append(lines[i].split(NEW_LINE_DELIMITER)) 45 | 46 | return paragraphs 47 | 48 | 49 | def download_wiki40b_corpus(target: str) -> Tuple[DatasetInfo, List[str]]: 50 | """ 51 | Downloads the target corpus and disambiguates sentence boundaries for sentences in the corpus. 52 | 53 | Args: 54 | target (str): Target dataset name. 55 | 56 | Returns: 57 | (tuple): tuple containing: 58 | ds_info (DatasetInfo): Dataset information for target corpus. 59 | all_sentences (List[str]): Sentences in the target corpus. 
60 | """ 61 | 62 | ds, ds_info = tfds.load(name='wiki40b/ja', split=TARGET_DATASETS, with_info=True) 63 | 64 | bunkai = Bunkai() 65 | 66 | all_sentences = [] 67 | for line in tqdm(tfds.as_dataframe(ds[TARGET_DATASETS.index(target)], ds_info).itertuples()): 68 | paragraphs = split_article(line.text.decode('utf-8')) 69 | all_sentences.append(START_ARTICLE_DELIMITER) 70 | for paragraph in paragraphs: 71 | all_sentences.append(START_PARAGRAPH_DELIMITER) 72 | for sentences in paragraph: 73 | for sentence in bunkai(sentences): 74 | if sentence: 75 | all_sentences.append(sentence) 76 | 77 | return ds_info, all_sentences 78 | 79 | 80 | def main(): 81 | args = get_args() 82 | 83 | dataset_info, corpus_sentences = download_wiki40b_corpus(args.target) 84 | 85 | os.makedirs(args.output_dir, exist_ok=True) 86 | 87 | with open(os.path.join(args.output_dir, 'dataset_info_{}.json'.format(args.target)), 'w') as f: 88 | f.write(dataset_info.as_json) 89 | 90 | with open(os.path.join(args.output_dir, 'ja_wiki40b_{}.txt'.format(args.target)), 'w') as f: 91 | for sentence in corpus_sentences: 92 | f.write(sentence + '\n') 93 | 94 | 95 | def get_args(): 96 | parser = argparse.ArgumentParser(description='Download and parse target dataset.') 97 | parser.add_argument('-t', '--target', choices=['train', 'validation', 'test'], help='Target dataset.') 98 | parser.add_argument('-o', '--output_dir', required=True, 99 | help='The path to the target directory in which to save a corpus file and a config file.') 100 | 101 | args = parser.parse_args() 102 | 103 | return args 104 | 105 | 106 | if __name__ == '__main__': 107 | main() 108 | -------------------------------------------------------------------------------- /pretraining/bert/preprocess_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021 Works Applications Co., Ltd. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | import argparse 16 | from logzero import logger 17 | from progressbar import progressbar 18 | from typing import List 19 | 20 | from corpus_preprocessing.filter.sentence_filter import ( 21 | SF, SentenceFilterName, 22 | SequenceSentenceFilter, 23 | EmailFilter, UrlFilter, SequenceLengthFilter 24 | ) 25 | from corpus_preprocessing.filter.document_filter import ( 26 | DF, DocumentFilterName, 27 | SequenceDocumentFilter, 28 | NGWordsFilter, ShortDocumentFilter, ScriptFilter 29 | ) 30 | from corpus_preprocessing.normalizer.sentence_normalizer import ( 31 | SN, SentenceNormalizerName, 32 | SequenceSentenceNormalizer, 33 | CitationNormalizer, WhitespaceNormalizer 34 | ) 35 | 36 | from corpus_preprocessing.normalizer.document_normalizer import ( 37 | DN, DocumentNormalizerName, 38 | SequenceDocumentNormalizer, 39 | ConcatShortSentenceNormalizer 40 | ) 41 | 42 | 43 | def load_dataset(input_dataset_path: str) -> List[List[str]]: 44 | documents = [] 45 | with open(input_dataset_path, 'r', encoding='utf-8') as f: 46 | document = [] 47 | for sentence in f: 48 | sentence = sentence.strip() 49 | if sentence != '': 50 | document.append(sentence) 51 | else: 52 | documents.append(document) 53 | document = [] 54 | 55 | return documents 56 | 57 | 58 | def load_sentence_filters(sentence_filter_names: List[str], **kwargs) -> List[SF]: 59 | sentence_filters = [] 60 | for sentence_filter_name in sentence_filter_names: 61 | if sentence_filter_name == SentenceFilterName.EMAIL: 62 | sentence_filters.append(EmailFilter()) 63 | elif sentence_filter_name == SentenceFilterName.URL: 64 | sentence_filters.append(UrlFilter()) 65 | elif sentence_filter_name == SentenceFilterName.SEQUENCE_LENGTH: 66 | sentence_filters.append(SequenceLengthFilter(min_seq_len=kwargs['min_seq_len'], 67 | max_seq_len=kwargs['max_seq_len'])) 68 | else: 69 | raise ValueError('Invalid sentence filter name `{}`: {}'.format( 70 | sentence_filter_name, ','.join(map(str, SentenceFilterName))) 71 | ) 72 | 73 | return sentence_filters 74 | 75 | 76 | def load_document_filters(document_filter_names: List[str], **kwargs) -> List[DF]: 77 | document_filters = [] 78 | for document_filter_name in document_filter_names: 79 | if document_filter_name == DocumentFilterName.SHORT_DOCUMENT: 80 | document_filters.append(ShortDocumentFilter(min_sentence_num=kwargs['min_sentence_num'])) 81 | elif document_filter_name == DocumentFilterName.SCRIPT: 82 | document_filters.append(ScriptFilter()) 83 | elif document_filter_name == DocumentFilterName.NG_WORDS: 84 | document_filters.append(NGWordsFilter(ng_words_file_path=kwargs['ng_words_file_path'])) 85 | else: 86 | raise ValueError('Invalid document filter name `{}`: {}'.format( 87 | document_filter_name, ','.join(map(str, DocumentFilterName))) 88 | ) 89 | 90 | return document_filters 91 | 92 | 93 | def load_sentence_normalizers(sentence_normalizer_names: List[str], **kwargs) -> List[SN]: 94 | sentence_normalizers = [] 95 | for sentence_normalizer_name in sentence_normalizer_names: 96 | if sentence_normalizer_name == SentenceNormalizerName.CITATION: 97 | sentence_normalizers.append(CitationNormalizer()) 98 | elif sentence_normalizer_name == SentenceNormalizerName.WHITESPACE: 99 | sentence_normalizers.append(WhitespaceNormalizer()) 100 | else: 101 | raise ValueError('Invalid sentence normalizer name `{}`: {}'.format( 102 | sentence_normalizer_name, ','.join(map(str, SentenceNormalizerName))) 103 | ) 104 | 105 | return sentence_normalizers 106 | 107 | 108 | def load_document_normalizers(document_normalizer_names: 
List[str], **kwargs) -> List[DN]: 109 | document_normalizers = [] 110 | for document_normalizer_name in document_normalizer_names: 111 | if document_normalizer_name == DocumentNormalizerName.CONCAT_SHORT_SENTENCE: 112 | document_normalizers.append(ConcatShortSentenceNormalizer(concat_char_num=kwargs['concat_char_num'])) 113 | else: 114 | raise ValueError('Invalid document normalizer name `{}`: {}'.format( 115 | document_normalizer_name, ','.join(map(str, DocumentNormalizerName))) 116 | ) 117 | 118 | return document_normalizers 119 | 120 | 121 | def main(): 122 | args = get_args() 123 | for k, v in vars(args).items(): 124 | logger.info('{}: {}'.format(k, v)) 125 | 126 | sequence_sentence_filter = SequenceSentenceFilter( 127 | load_sentence_filters(args.sentence_filter_names, 128 | min_seq_len=args.min_seq_len, 129 | max_seq_len=args.max_seq_len) 130 | ) 131 | 132 | sequence_document_filter = SequenceDocumentFilter( 133 | load_document_filters(args.document_filter_names, 134 | min_sentence_num=args.min_sentence_num, 135 | ng_words_file_path=args.ng_words_file_path) 136 | ) 137 | 138 | sequence_sentence_normalizer = SequenceSentenceNormalizer( 139 | load_sentence_normalizers(args.sentence_normalizer_names) 140 | ) 141 | 142 | sequence_document_normalizer = SequenceDocumentNormalizer( 143 | load_document_normalizers(args.document_normalizer_names, 144 | concat_char_num=args.concat_char_num) 145 | ) 146 | 147 | documents = load_dataset(args.input_dataset_path) 148 | preprocessed_documents = [] 149 | for document in progressbar(documents): 150 | # normalize 151 | document = [sequence_sentence_normalizer.normalize(s) for s in document] 152 | document = sequence_document_normalizer.normalize(document) 153 | 154 | # filter 155 | document = [s for s in document if not sequence_sentence_filter.is_filtered(s)] 156 | if not sequence_document_filter.is_filtered(document): 157 | preprocessed_documents.append(document) 158 | 159 | logger.info('#Document w/o filtering:\t{}'.format(len(documents))) 160 | logger.info('#Document w/ filtering:\t{}'.format(len(preprocessed_documents))) 161 | 162 | with open(args.output_dataset_path, 'w', encoding='utf-8') as f: 163 | f.write('\n\n'.join(['\n'.join(document) for document in preprocessed_documents])) 164 | 165 | 166 | def get_args(): 167 | parser = argparse.ArgumentParser(description='Cleaning and corpus_preprocessing dataset.') 168 | parser.add_argument('-i', '--input_dataset_path', required=True, 169 | help='Input dataset.') 170 | parser.add_argument('-o', '--output_dataset_path', required=True, 171 | help='Output dataset.') 172 | 173 | # Sentence filters 174 | parser.add_argument('-sf', '--sentence_filter_names', nargs='*', default=list(), 175 | choices=SentenceFilterName, type=SentenceFilterName, 176 | help='A list of filter names to remove unnecessary sentences from the training corpus.') 177 | parser.add_argument('--min_seq_len', type=int, default=10, 178 | help='The minimum number of characters a sentence should contain (for SequenceLengthFilter).') 179 | parser.add_argument('--max_seq_len', type=int, default=200, 180 | help='The maximum number of characters a sentence should contain (for SequenceLengthFilter).') 181 | 182 | # Document filters 183 | parser.add_argument('-df', '--document_filter_names', nargs='*', default=list(), 184 | choices=DocumentFilterName, type=DocumentFilterName, 185 | help='A list of filter names to remove unnecessary documents from the training corpus.') 186 | parser.add_argument('--min_sentence_num', type=int, default=5, 187 | 
help='The minimum number of sentences a document should contain (for ShortDocumentFilter).') 188 | parser.add_argument('--ng_words_file_path', 189 | default='./resources/ng_words.txt', 190 | help='A file path of NG word list (for NGWordsFilter).') 191 | 192 | # Sentence normalizers 193 | parser.add_argument('-sn', '--sentence_normalizer_names', nargs='*', default=list(), 194 | choices=SentenceNormalizerName, type=SentenceNormalizerName, 195 | help='A list of filter names to normalize sentences from the training corpus.') 196 | 197 | # Document normalizers 198 | parser.add_argument('-dn', '--document_normalizer_names', nargs='*', default=list(), 199 | choices=DocumentNormalizerName, type=DocumentNormalizerName, 200 | help='A list of filter names to normalize documents from the training corpus.') 201 | parser.add_argument('--concat_char_num', type=int, default=2, 202 | help='The maximum number of characters to be concatenated with the previous sentence ' 203 | '(for ConcatShortSentenceNormalizer).') 204 | 205 | args = parser.parse_args() 206 | 207 | return args 208 | 209 | 210 | if __name__ == '__main__': 211 | main() 212 | -------------------------------------------------------------------------------- /pretraining/bert/requirements.txt: -------------------------------------------------------------------------------- 1 | # for pretraining bert 2 | bunkai~=1.4.3 3 | tensorflow>=2.5.0 4 | tensorflow_datasets>=4.3.0 5 | -------------------------------------------------------------------------------- /pretraining/bert/resources/ng_words.txt: -------------------------------------------------------------------------------- 1 | fuck 2 | g スポット 3 | sm女王 4 | tenga 5 | あばずれ 6 | あぱずれ 7 | あほ 8 | うざ 9 | うんこ 10 | え〇 11 | えっち 12 | おしっこ 13 | おしりのあな 14 | おっぱい 15 | おもらし 16 | かたわ 17 | きちがい 18 | きめぇ 19 | きめえ 20 | くそ 21 | せんずり 22 | ち〇 23 | ちんぐり 24 | ちんこ 25 | つるぺた 26 | つんぼ 27 | ふたなり 28 | ぶさいく 29 | ぶす 30 | ま〇 31 | まんぐり 32 | まんこ 33 | めくら 34 | やりまん 35 | アスペ 36 | アスホール 37 | アナリングス 38 | アナル 39 | アヌス 40 | アバズレ 41 | アパズレ 42 | アホ 43 | イマラチオ 44 | イメクラ 45 | イラマチオ 46 | ウザ 47 | ウンコ 48 | エ〇 49 | エッチ 50 | エロ 51 | オカマ 52 | オッパイ 53 | オナ 54 | オナニー 55 | オフパコ 56 | オマンコ 57 | オルガズム 58 | オーガズム 59 | カス 60 | ガイジ 61 | キチガイ 62 | キモ 63 | クズ 64 | クソ 65 | クリトリス 66 | クンニ 67 | クンニリングス 68 | グループ・セックス 69 | グロ 70 | ゲイボーイ 71 | ゲイ・セックス 72 | ゲロ 73 | コカイン 74 | コキ 75 | コンドーム 76 | ザーメン 77 | シコ 78 | ショタ 79 | スカトロ 80 | スケベ 81 | ストリップ劇場 82 | スマタ 83 | セクロス 84 | セックス 85 | セフレ 86 | センズリ 87 | ダッチワイフ 88 | チ〇 89 | テレフォンセックス 90 | ディルド 91 | ディープ・スロート 92 | デブ 93 | デリヘル 94 | デートレイプ 95 | ドキュン 96 | ナマポ 97 | ニガー 98 | ヌい 99 | ヌく 100 | ヌけ 101 | ネオ・ナチ 102 | ハメ撮り 103 | パイズリ 104 | パイパン 105 | パンチラ 106 | パンティー 107 | ビッチ 108 | ピロートーク 109 | ファック 110 | フェラ 111 | フェラチオ 112 | ブサイク 113 | ブス 114 | プリンス アルバート ピアス 115 | ペッティング 116 | ペニス 117 | ペニスバンド 118 | ホモ 119 | ボンテージ 120 | ボールギャグ 121 | ポルノグラフィー 122 | マ〇 123 | マザー・ファッカー 124 | マスターベーション 125 | マラ 126 | マンコ 127 | ヤラせ 128 | ラブホ 129 | リスカ 130 | リストカット 131 | リョナ 132 | リンチ 133 | レイプ 134 | レズ 135 | 不細工 136 | 中出し 137 | 乱交 138 | 二穴 139 | 人妻 140 | 側位 141 | 児童性虐待 142 | 前戯 143 | 勃起する 144 | 合いの子 145 | 四十八手 146 | 売り専 147 | 売国 148 | 売女 149 | 売春婦 150 | 外人 151 | 夢精 152 | 大人のおもちゃ 153 | 大人のオモチャ 154 | 大人の玩具 155 | 大陰唇 156 | 射精 157 | 尻軽 158 | 尿道プレイ 159 | 巨乳 160 | 巨根 161 | 強姦犯 162 | 後戯 163 | 後背位 164 | 手コキ 165 | 手マン 166 | 援交 167 | 援助交際 168 | 支那 169 | 新しいポルノ 170 | 正常位 171 | 殺し方 172 | 殺人方法 173 | 氏ね 174 | 氏んだ 175 | 氏んで 176 | 気違い 177 | 池沼 178 | 淫乱 179 | 潮吹き女 180 | 潮吹き男性 181 | 熟女 182 | 獣姦 183 | 玉なめ 184 | 玉舐め 185 | 男根 186 | 痴呆 187 | 穴兄弟 188 | 竿姉妹 189 | 筆おろし 190 
| 精液 191 | 糞便 192 | 糞尿愛好症 193 | 素股 194 | 緊縛 195 | 老害 196 | 肉便器 197 | 自慰 198 | 裸の女性 199 | 貞操帯 200 | 賢者タイム 201 | 足フェチ 202 | 輪姦 203 | 近親相姦 204 | 阿呆 205 | 陰毛 206 | 電マ 207 | 顔射 208 | 顔面騎乗 209 | 騎上位 210 | 騎乗位 -------------------------------------------------------------------------------- /pretraining/bert/run_create_pretraining_data.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | cd $(dirname $0) && cd ../../ 4 | 5 | WORK_DIR="pretraining/bert" 6 | TARGET="small" 7 | 8 | seq -f %02g 1 8|xargs -L 1 -I {} -P 8 python3 pretraining/bert/create_pretraining_data.py \ 9 | --input_file $WORK_DIR/datasets/corpus_splitted_by_paragraph/ja_wiki40b_${TARGET}.paragraph{}.txt \ 10 | --output_file $WORK_DIR/models/pretraining_${TARGET}_{}.tf_record \ 11 | --vocab_file $WORK_DIR/models/vocab.txt \ 12 | --tokenizer_type wordpiece \ 13 | --word_form_type normalized \ 14 | --split_mode C \ 15 | --sudachi_dic_type core \ 16 | --do_whole_word_mask \ 17 | --max_seq_length 512 \ 18 | --max_predictions_per_seq 80 \ 19 | --dupe_factor 10 -------------------------------------------------------------------------------- /pretraining/bert/run_prepare_dataset.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | cd $(dirname $0) 4 | 5 | DATASET_DIR="./datasets" 6 | CORPUS_DIR="${DATASET_DIR}/corpus" 7 | SPLITTED_CORPUS_DIR="${DATASET_DIR}/corpus_splitted_by_paragraph" 8 | 9 | # download dataset 10 | mkdir -p ${CORPUS_DIR} 11 | for target in "train" "validation" "test"; do 12 | time python3 prepare_dataset.py --target ${target} --output_dir ${CORPUS_DIR} 13 | done 14 | 15 | ### split dataset for each paragraph 16 | 17 | mkdir -p ${SPLITTED_CORPUS_DIR} 18 | for target in "train" "validation" "test"; do 19 | cat ${CORPUS_DIR}/ja_wiki40b_${target}.txt | sed -e "s/_START_ARTICLE_//g" -e "s/_START_PARAGRAPH_//g" | cat -s > ${SPLITTED_CORPUS_DIR}/ja_wiki40b_${target}.paragraph.txt 20 | done 21 | 22 | -------------------------------------------------------------------------------- /pretraining/bert/split_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021 Works Applications Co., Ltd. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
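# A minimal sketch of the document-level filters as preprocess_dataset.py above wires them up;
# the import path and constructor arguments are the ones it uses, and the NG-word list is the
# resources/ng_words.txt file shown above (paths assume execution from `pretraining/bert/`).
# See document_filter.py for the exact filtering criteria.
from corpus_preprocessing.filter.document_filter import (
    SequenceDocumentFilter, ShortDocumentFilter, ScriptFilter, NGWordsFilter
)

document_filter = SequenceDocumentFilter([
    ShortDocumentFilter(min_sentence_num=5),                       # documents with too few sentences
    ScriptFilter(),                                                # script-like documents
    NGWordsFilter(ng_words_file_path='./resources/ng_words.txt'),  # documents containing NG words
])

document = ['一文目。', '二文目。']  # a toy document: a list of already-normalized sentences
kept = not document_filter.is_filtered(document)  # False here: the document has fewer than 5 sentences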
14 | 15 | import argparse 16 | import os 17 | from typing import List 18 | 19 | 20 | def write_lines(tmp_lines: List[str], file_dir: str, file_name: str, file_id: int, file_ext: str): 21 | with open(os.path.join(file_dir, f"{file_name}{file_id}{file_ext}"), 'w') as f: 22 | for line in tmp_lines: 23 | print(line, file=f) 24 | 25 | 26 | def main(): 27 | args = get_args() 28 | 29 | file_dir = os.path.dirname(args.input_file) 30 | file_name, file_ext = os.path.splitext(os.path.basename(args.input_file)) 31 | 32 | with open(args.input_file, 'r') as f: 33 | 34 | tmp_lines = [] 35 | file_id = 1 36 | for line in f: 37 | tmp_lines.append(line.strip()) 38 | if len(tmp_lines) > args.line_per_file and line == '\n': 39 | write_lines(tmp_lines, file_dir, file_name, file_id, file_ext) 40 | file_id += 1 41 | tmp_lines = [] 42 | 43 | if len(tmp_lines) > 0: 44 | write_lines(tmp_lines, file_dir, file_name, file_id, file_ext) 45 | 46 | 47 | def get_args(): 48 | parser = argparse.ArgumentParser(description='Split dataset.') 49 | 50 | parser.add_argument('-i', '--input_file', help='Input file to be splitted (corpus splitted by paragraph).') 51 | parser.add_argument('--line_per_file', type=int, help='Max number of lines per file.') 52 | 53 | args = parser.parse_args() 54 | 55 | return args 56 | 57 | 58 | if __name__ == '__main__': 59 | main() 60 | -------------------------------------------------------------------------------- /pretraining/bert/train_pos_substitution_tokenizer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021 Works Applications Co., Ltd. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
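# A tiny, hypothetical end-to-end example of split_dataset.py above (run from `pretraining/bert/`):
# the input is a blank-line separated corpus such as the one produced by run_prepare_dataset.sh,
# and a new output file is only started once `--line_per_file` lines have been written *and* a
# document boundary is reached, so no document is split across files.
import subprocess

corpus = 'ja_wiki40b_toy.paragraph.txt'  # hypothetical file name
with open(corpus, 'w', encoding='utf-8') as f:
    # three tiny documents, each terminated by the blank line that marks a document boundary
    f.write('一文目。\n二文目。\n\n三文目。\n四文目。\n\n五文目。\n六文目。\n\n')

subprocess.run(['python3', 'split_dataset.py', '-i', corpus, '--line_per_file', '3'], check=True)
# -> ja_wiki40b_toy.paragraph1.txt (first two documents) and ja_wiki40b_toy.paragraph2.txt (the rest)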
14 | 15 | import argparse 16 | import os 17 | from glob import glob 18 | 19 | from sudachitra.pretokenizer import PartOfSpeechSubstitutionTokenizer 20 | 21 | 22 | def main(): 23 | args = get_args() 24 | 25 | if args.input_file: 26 | files = [args.input_file] 27 | elif args.input_dir: 28 | files = glob(os.path.join(args.input_dir, '*.txt')) 29 | else: 30 | raise ValueError("`input_file` or `input_dir` must be specified.") 31 | 32 | pos_tokenizer = PartOfSpeechSubstitutionTokenizer( 33 | split_mode=args.split_mode, 34 | dict_type=args.dict_type, 35 | word_form_type=args.word_form_type 36 | ) 37 | 38 | pos_tokenizer.train( 39 | files, 40 | token_size=args.token_size, 41 | min_frequency=args.min_frequency, 42 | limit_character=args.limit_character, 43 | special_tokens=args.special_tokens 44 | ) 45 | 46 | os.makedirs(os.path.dirname(args.output_file), exist_ok=True) 47 | pos_tokenizer.save_vocab(args.output_file) 48 | 49 | 50 | def get_args(): 51 | parser = argparse.ArgumentParser(description='Trainer of part-of-speech substitution tokenizer.') 52 | 53 | # Input 54 | parser.add_argument('-f', '--input_file', default='', 55 | help='Input file to train tokenizer.') 56 | parser.add_argument('-d', '--input_dir', default='', 57 | help='Input directory containing files to train tokenizer.') 58 | 59 | # Parameters 60 | parser.add_argument('--token_size', type=int, default=32000, 61 | help='The size of the vocabulary, excluding special tokens and pos tags.') 62 | parser.add_argument('--min_frequency', type=int, default=1, 63 | help='Ignores all words (tokens and characters) with total frequency lower than this.') 64 | parser.add_argument('--limit_character', type=int, default=1000, 65 | help='The maximum different characters to keep in the vocabulary.') 66 | parser.add_argument('--special_tokens', nargs='*', default=["[PAD]", "[UNK]", "[CLS]", "[SEP]", "[MASK]"], 67 | help='A list of special tokens the model should know of.') 68 | 69 | # Tokenization 70 | parser.add_argument('--dict_type', default='core', choices=['small', 'core', 'full'], 71 | help='Sudachi dictionary type to be used for tokenization.') 72 | parser.add_argument('--split_mode', default='C', choices=['A', 'B', 'C', 'a', 'b', 'c'], 73 | help='The mode of splitting.') 74 | parser.add_argument('--word_form_type', default='surface', 75 | choices=['surface', 'dictionary', 'normalized', 'dictionary_and_surface', 'normalized_and_surface'], 76 | help='Word form type for each morpheme.') 77 | 78 | # output 79 | parser.add_argument('-o', '--output_file', 80 | help='The output path where the vocabulary will be stored.') 81 | 82 | args = parser.parse_args() 83 | 84 | return args 85 | 86 | 87 | if __name__ == '__main__': 88 | main() 89 | -------------------------------------------------------------------------------- /pretraining/bert/train_wordpiece_tokenizer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021 Works Applications Co., Ltd. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import argparse 16 | import os 17 | from glob import glob 18 | 19 | from sudachitra import get_split_mode 20 | from sudachitra.pretokenizer import JapaneseBertWordPieceTokenizer 21 | from sudachitra.word_formatter import WordFormTypes 22 | from sudachitra.pretokenizer import pretokenizer_handler 23 | from sudachipy import Dictionary 24 | 25 | 26 | def main(): 27 | args = get_args() 28 | 29 | if args.input_file: 30 | files = [args.input_file] 31 | elif args.input_dir: 32 | files = glob(os.path.join(args.input_dir, '*.txt')) 33 | else: 34 | raise ValueError("`input_file` or `input_dir` must be specified.") 35 | 36 | settings = dict( 37 | vocab_size=args.vocab_size, 38 | min_frequency=args.min_frequency, 39 | limit_alphabet=args.limit_alphabet 40 | ) 41 | 42 | wp_tokenizer = JapaneseBertWordPieceTokenizer(do_strip=args.do_strip, 43 | do_lower_case=args.do_lower_case, 44 | do_nfkc=args.do_nfkc) 45 | 46 | sudachi_dict = Dictionary(dict=args.dict_type) 47 | sudachi_pre_tokenizer = sudachi_dict.pre_tokenizer( 48 | mode=get_split_mode(args.split_mode), 49 | handler=pretokenizer_handler(sudachi_dict, word_form_type=args.word_form_type) 50 | ) 51 | wp_tokenizer.pre_tokenizer = sudachi_pre_tokenizer 52 | 53 | if args.disable_parallelism: 54 | os.environ["TOKENIZERS_PARALLELISM"] = "false" 55 | wp_tokenizer.train(files, **settings) 56 | os.environ["TOKENIZERS_PARALLELISM"] = "true" 57 | 58 | os.makedirs(args.output_dir, exist_ok=True) 59 | wp_tokenizer.save(os.path.join(args.output_dir, args.config_name)) 60 | wp_tokenizer.save_vocab(args.output_dir, args.vocab_prefix) 61 | 62 | 63 | def get_args(): 64 | parser = argparse.ArgumentParser(description='Trainer of wordpiece tokenizer.') 65 | 66 | # input 67 | parser.add_argument('-f', '--input_file', default='', 68 | help='Input file to train tokenizer.') 69 | parser.add_argument('-d', '--input_dir', default='', 70 | help='Input directory containing files to train tokenizer.') 71 | 72 | # Normalizers 73 | parser.add_argument('--do_strip', action='store_true', default=False, 74 | help='Removes all whitespace characters on both sides of the input.') 75 | parser.add_argument('--do_lower_case', action='store_true', default=False, 76 | help='Replaces all uppercase to lowercase.') 77 | parser.add_argument('--do_nfkc', action='store_true', default=False, 78 | help='NFKC unicode normalization.') 79 | 80 | # Parameters 81 | parser.add_argument('--vocab_size', type=int, default=32000, 82 | help='The size of the final vocabulary, including all tokens and alphabet.') 83 | parser.add_argument('--min_frequency', type=int, default=1, 84 | help='The minimum frequency a pair should have in order to be merged.') 85 | parser.add_argument('--limit_alphabet', type=int, default=5000, 86 | help='The maximum different characters to keep in the alphabet.') 87 | 88 | # Tokenization 89 | parser.add_argument('--dict_type', default='core', choices=['small', 'core', 'full'], 90 | help='Sudachi dictionary type to be used for tokenization.') 91 | parser.add_argument('--split_mode', default='C', choices=['A', 'B', 'C', 'a', 'b', 'c'], 92 | help='The mode of splitting.') 93 | parser.add_argument('--word_form_type', default='surface', 94 | choices=WordFormTypes, type=WordFormTypes, 95 | help='Word form type for each morpheme.') 96 | 97 | # Wordpiece 98 | parser.add_argument('--disable_parallelism', action='store_true', default=False, 99 | help='This flag argument 
disables parallel processing only for wordpiece training. ' 100 | 'Note that this flag rewrites the value of a global environment variable ' 101 | '(TOKENIZERS_PARALLELISM), so it may affect other programs as well.') 102 | 103 | # Output 104 | parser.add_argument('-o', '--output_dir', 105 | help='The output dir to be saved vocabulary and config file.') 106 | parser.add_argument('-c', '--config_name', default='config.json', 107 | help='Output json file name.') 108 | parser.add_argument('-v', '--vocab_prefix', default='', 109 | help='Prefix of vocab file.') 110 | 111 | args = parser.parse_args() 112 | 113 | return args 114 | 115 | 116 | if __name__ == '__main__': 117 | main() 118 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | logzero~=1.7.0 2 | progressbar2~=3.53.1 3 | tokenizers>=0.10.3 4 | transformers>=4.6.1 5 | sudachipy>=0.6.2 6 | sudachidict_core==20210802.* -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021 Works Applications Co., Ltd. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from setuptools import find_packages, setup 16 | 17 | setup( 18 | name="SudachiTra", 19 | use_scm_version=True, 20 | setup_requires=["setuptools_scm"], 21 | description="Japanese tokenizer for Transformers.", 22 | long_description=open("README.md", encoding="utf-8").read(), 23 | long_description_content_type="text/markdown", 24 | url="https://github.com/WorksApplications/SudachiTra", 25 | license="Apache-2.0", 26 | author="Works Applications", 27 | author_email="sudachi@worksap.co.jp", 28 | packages=find_packages(include=['sudachitra']), 29 | install_requires=[ 30 | "logzero~=1.7.0", 31 | "tokenizers>=0.10.3", 32 | "transformers>=4.6.1", 33 | "sudachipy>=0.6.2", 34 | "sudachidict_core>=20210802" 35 | ], 36 | extras_require={ 37 | "pretrain":[ 38 | "progressbar2~=3.53.1", 39 | "bunkai~=1.4.3", 40 | "pytextspan>=0.5.4", 41 | "tensorflow>=2.5.0", 42 | "tensorflow_datasets>=4.3.0" 43 | ] 44 | }, 45 | include_package_data=True 46 | ) 47 | -------------------------------------------------------------------------------- /sudachitra/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021 Works Applications Co., Ltd. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from .sudachipy_word_tokenizer import SudachipyWordTokenizer, get_split_mode 16 | from .tokenization_bert_sudachipy import BertSudachipyTokenizer 17 | from .tokenization_electra_sudachipy import ElectraSudachipyTokenizer 18 | -------------------------------------------------------------------------------- /sudachitra/conjugation_preserving_normalizer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021 Works Applications Co., Ltd. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import json 16 | from sudachipy.morpheme import Morpheme 17 | from sudachipy.sudachipy import Dictionary 18 | 19 | from .word_formatter import CONJUGATIVE_POS 20 | 21 | 22 | class ConjugationPreservingNormalizer: 23 | def __init__(self, inflection_table_path, conjugation_type_table_path, sudachi_dict: Dictionary) -> None: 24 | """ 25 | Constructs a ConjugationPreservingNormalizer. 26 | 27 | Args: 28 | inflection_table_path (:obj:`str`): 29 | In this table, get the difference between the inflected form to the end form each conjugation. 30 | conjugation_type_table_path (:obj:`str`): 31 | In this table, the conjugation types such as potential verb 32 | that change due to the normalization of the Sudachi dictionary are stored in blacklist format. 33 | sudachi_dict (:obj:`Dictionary`): 34 | dictionary instance. 35 | """ 36 | self.sudachi_dict = sudachi_dict 37 | self.id2pos = list(enumerate(self.sudachi_dict.pos_matcher([()]))) 38 | self.pos2id = {pos: pos_id for (pos_id, pos) in self.id2pos} 39 | self.is_necessary_inflection = sudachi_dict.pos_matcher( 40 | lambda p: p[0] in CONJUGATIVE_POS and (p[4] == "サ行変格" or p[5] != "終止形-一般") 41 | ) # 「為る->する」のため、サ行変格のみ 終止形-一般も変化 42 | 43 | with open(inflection_table_path) as jf: 44 | self.inflection_table = self._load_json(jf, "inflection") 45 | with open(conjugation_type_table_path) as jf: 46 | self.conjugation_type_table = self._load_json(jf, "conjugation_type") 47 | 48 | def _load_json(self, json_file, table_type) -> dict: 49 | """ 50 | Convert json to python file when loading. 51 | 52 | Args: 53 | json_file (:obj:`TextIOWrapper`): 54 | json file object. 55 | table_type (:obj:`str`): 56 | Select "inflection" or "conjugation_type" and load in each format. 57 | 58 | Returns: 59 | Dict: 60 | inflection: Key of int and Inflection table of list. 61 | conjugation_type: Key of tuple and Conjugation type conversion destination of str. 
62 | """ 63 | if table_type not in ['inflection', 'conjugation_type']: 64 | raise ValueError('Invalid table_type error : {}'.format(table_type)) 65 | 66 | data = json.load(json_file) 67 | table = {} 68 | for pos_0, convert_tables in data.items(): 69 | for key, convert_table in convert_tables.items(): 70 | if table_type == "inflection": 71 | pos_4, pos_5 = key.split("|") 72 | for pos in self.sudachi_dict.pos_matcher(lambda p: p[0] == pos_0 and p[4] == pos_4 and p[5] == pos_5): 73 | pos_id = self.pos2id[pos] 74 | table[pos_id] = convert_table 75 | elif table_type == "conjugation_type": 76 | surface_token, reading, normalized_token, pos_4 = key.split("|") 77 | for pos in self.sudachi_dict.pos_matcher(lambda p: p[0] == pos_0 and p[4] == pos_4): 78 | pos_id = self.pos2id[pos] 79 | table[(pos_id, surface_token, reading, normalized_token)] = convert_table 80 | 81 | return table 82 | 83 | def _change_pos(self, key: tuple, pos: tuple) -> tuple: 84 | """ 85 | Make conjugation-type changes. 86 | 87 | If the part of speech does not exist after the change, it will be grouped into "一般". (Example: 撥音便, etc.) 88 | 89 | Args: 90 | key (:obj:`tuple`): (pos_id, surface, reading_form, normalized_form) 91 | pos (:obj:`tuple`): Get the part of speech. 92 | 93 | Returns: 94 | tuple: Changed part of speech 95 | """ 96 | conj_type = self.conjugation_type_table[key] 97 | res = (*pos[:4], conj_type, pos[5]) 98 | if res not in self.pos2id: 99 | conj_form = pos[5].split("-")[0] + "-一般" 100 | res = (*pos[:4], conj_type, conj_form) 101 | if res not in self.pos2id: 102 | res = (pos[0], "一般", pos[2], pos[3], conj_type, conj_form) 103 | 104 | assert res in self.pos2id 105 | return res 106 | 107 | def _is_changed_conjugation_type(self, key: tuple) -> bool: 108 | """ 109 | Check if the conjugation type changes after being normalized. 110 | 111 | Args: 112 | key (:obj:`tuple`): (pos_id, surface, reading_form, normalized_form) 113 | 114 | Returns: bool 115 | """ 116 | return key in self.conjugation_type_table 117 | 118 | def normalized(self, morpheme: Morpheme) -> str: 119 | """ 120 | The output token retain conjugation information in word normalization by Sudachi tokenizer 121 | 122 | Args: 123 | morpheme (:obj:`Morpheme`): A Morpheme obtained from the analysis by sudachipy. 124 | 125 | Returns: 126 | str: Normalized token with conjugate information retained. 127 | """ 128 | normalized_token = morpheme.normalized_form() 129 | if not self.is_necessary_inflection(morpheme): 130 | return normalized_token 131 | 132 | pos_id = morpheme.part_of_speech_id() 133 | conj_type_table_key = (pos_id, morpheme.surface(), morpheme.reading_form(), normalized_token) 134 | if self._is_changed_conjugation_type(conj_type_table_key): 135 | pos = self._change_pos(conj_type_table_key, morpheme.part_of_speech()) 136 | if pos[5] == "終止形-一般": 137 | return normalized_token 138 | pos_id = self.pos2id[pos] 139 | 140 | for convert_table in self.inflection_table[pos_id]: 141 | if convert_table == ['', '']: 142 | return normalized_token 143 | if normalized_token.endswith(convert_table[0]): 144 | src = convert_table[0][::-1] 145 | tgt = convert_table[1][::-1] 146 | return normalized_token[::-1].replace(src, tgt, 1)[::-1] 147 | 148 | return normalized_token 149 | -------------------------------------------------------------------------------- /sudachitra/input_string_normalizer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021 Works Applications Co., Ltd. 
2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from tokenizers.normalizers import Lowercase, NFKC, Sequence, Strip 16 | 17 | 18 | class InputStringNormalizer(object): 19 | def __init__(self, do_strip=False, do_lower_case=False, do_nfkc=False): 20 | self.do_strip: bool = do_strip 21 | self.do_lower_case: bool = do_lower_case 22 | self.do_nfkc: bool = do_nfkc 23 | self._normalizer: Sequence = self._init_normalizer() 24 | 25 | def _init_normalizer(self) -> Sequence: 26 | normalizers = [] 27 | if self.do_strip: 28 | normalizers.append(Strip()) 29 | if self.do_lower_case: 30 | normalizers.append(Lowercase()) 31 | if self.do_nfkc: 32 | normalizers.append(NFKC()) 33 | return Sequence(normalizers) 34 | 35 | @property 36 | def normalizer(self) -> Sequence: 37 | return self._normalizer 38 | 39 | def normalize_str(self, text: str) -> str: 40 | return self.normalizer.normalize_str(text) 41 | -------------------------------------------------------------------------------- /sudachitra/pretokenizer/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021 Works Applications Co., Ltd. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from .japanese_bert_wordpiece_tokenizer import JapaneseBertWordPieceTokenizer 16 | from .pos_substitution_tokenizer import PartOfSpeechSubstitutionTokenizer 17 | from .sudachipy_pretokenizer import pretokenizer_handler 18 | -------------------------------------------------------------------------------- /sudachitra/pretokenizer/japanese_bert_wordpiece_tokenizer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021 Works Applications Co., Ltd. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
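# A short usage sketch of InputStringNormalizer (defined in input_string_normalizer.py above).
# The flags correspond to --do_strip / --do_lower_case / --do_nfkc in train_wordpiece_tokenizer.py;
# the sample string and its expected result are illustrative only.
from sudachitra.input_string_normalizer import InputStringNormalizer

normalizer = InputStringNormalizer(do_strip=True, do_lower_case=True, do_nfkc=True)
print(normalizer.normalize_str('ＳｕｄａｃｈｉＴｒａ用のＢＥＲＴ'))
# Expected: 'sudachitra用のbert' (full-width alphanumerics folded by NFKC and lowercased)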
14 | 15 | import os 16 | from typing import Callable, Dict, Iterator, List, Optional, Union 17 | 18 | from logzero import logger 19 | from tokenizers import Tokenizer, AddedToken, decoders, trainers 20 | from tokenizers.models import WordPiece 21 | from tokenizers.pre_tokenizers import BertPreTokenizer, PreTokenizer 22 | from tokenizers.processors import BertProcessing 23 | from tokenizers.implementations import BertWordPieceTokenizer 24 | from tokenizers.implementations.base_tokenizer import BaseTokenizer 25 | 26 | from ..input_string_normalizer import InputStringNormalizer 27 | 28 | 29 | class JapaneseBertWordPieceTokenizer(BaseTokenizer): 30 | def __init__( 31 | self, 32 | vocab: Optional[Union[str, Dict[str, int]]] = None, 33 | unk_token: Union[str, AddedToken] = "[UNK]", 34 | sep_token: Union[str, AddedToken] = "[SEP]", 35 | cls_token: Union[str, AddedToken] = "[CLS]", 36 | pad_token: Union[str, AddedToken] = "[PAD]", 37 | mask_token: Union[str, AddedToken] = "[MASK]", 38 | do_lower_case: bool = False, 39 | do_nfkc: bool = False, 40 | do_strip: bool = False, 41 | wordpieces_prefix: str = "##", 42 | ): 43 | if vocab is not None: 44 | tokenizer = Tokenizer(WordPiece(vocab, unk_token=str(unk_token))) 45 | else: 46 | tokenizer = Tokenizer(WordPiece(unk_token=str(unk_token))) 47 | 48 | # Let the tokenizer know about special tokens if they are part of the vocab 49 | if tokenizer.token_to_id(str(unk_token)) is not None: 50 | tokenizer.add_special_tokens([str(unk_token)]) 51 | if tokenizer.token_to_id(str(sep_token)) is not None: 52 | tokenizer.add_special_tokens([str(sep_token)]) 53 | if tokenizer.token_to_id(str(cls_token)) is not None: 54 | tokenizer.add_special_tokens([str(cls_token)]) 55 | if tokenizer.token_to_id(str(pad_token)) is not None: 56 | tokenizer.add_special_tokens([str(pad_token)]) 57 | if tokenizer.token_to_id(str(mask_token)) is not None: 58 | tokenizer.add_special_tokens([str(mask_token)]) 59 | 60 | _normalizer = InputStringNormalizer(do_strip=do_strip, do_lower_case=do_lower_case, do_nfkc=do_nfkc) 61 | tokenizer.normalizer = _normalizer.normalizer 62 | tokenizer.pre_tokenizer = BertPreTokenizer() 63 | 64 | if vocab is not None: 65 | sep_token_id = tokenizer.token_to_id(str(sep_token)) 66 | if sep_token_id is None: 67 | raise TypeError("sep_token not found in the vocabulary") 68 | cls_token_id = tokenizer.token_to_id(str(cls_token)) 69 | if cls_token_id is None: 70 | raise TypeError("cls_token not found in the vocabulary") 71 | 72 | tokenizer.post_processor = BertProcessing( 73 | (str(sep_token), sep_token_id), (str(cls_token), cls_token_id) 74 | ) 75 | tokenizer.decoder = decoders.WordPiece(prefix=wordpieces_prefix) 76 | 77 | parameters = { 78 | "model": "BertSudachiWordPiece", 79 | "unk_token": unk_token, 80 | "sep_token": sep_token, 81 | "cls_token": cls_token, 82 | "pad_token": pad_token, 83 | "mask_token": mask_token, 84 | "do_strip": do_strip, 85 | "do_lower_case": do_lower_case, 86 | "do_nfkc": do_nfkc, 87 | "wordpieces_prefix": wordpieces_prefix, 88 | } 89 | 90 | super().__init__(tokenizer, parameters) 91 | 92 | @staticmethod 93 | def from_file(vocab: str, **kwargs): 94 | vocab = WordPiece.read_file(vocab) 95 | return BertWordPieceTokenizer(vocab, **kwargs) 96 | 97 | def train( 98 | self, 99 | files: Union[str, List[str]], 100 | vocab_size: int = 30000, 101 | min_frequency: int = 2, 102 | limit_alphabet: int = 1000, 103 | initial_alphabet: List[str] = [], 104 | special_tokens: List[Union[str, AddedToken]] = [ 105 | "[PAD]", 106 | "[UNK]", 107 | "[CLS]", 108 | 
"[SEP]", 109 | "[MASK]", 110 | ], 111 | show_progress: bool = True, 112 | wordpieces_prefix: str = "##", 113 | ): 114 | """ Train the model using the given files """ 115 | logger.info("Parameters for training") 116 | logger.info("\tvocab_size: {}".format(vocab_size)) 117 | logger.info("\tdo_strip: {}".format(self._parameters["do_strip"])) 118 | logger.info("\tdo_lower_case: {}".format(self._parameters["do_lower_case"])) 119 | logger.info("\tdo_nfkc: {}".format(self._parameters["do_nfkc"])) 120 | logger.info("\tmin_frequency: {}".format(min_frequency)) 121 | logger.info("\tlimit_alphabet: {}".format(limit_alphabet)) 122 | logger.info("\tinitial_alphabet: {}".format(",".join(initial_alphabet))) 123 | logger.info("\tspecial_tokens: {}".format(",".join(special_tokens))) 124 | logger.info("\twordpieces_prefix: {}".format(wordpieces_prefix)) 125 | 126 | trainer = trainers.WordPieceTrainer( 127 | vocab_size=vocab_size, 128 | min_frequency=min_frequency, 129 | limit_alphabet=limit_alphabet, 130 | initial_alphabet=initial_alphabet, 131 | special_tokens=special_tokens, 132 | show_progress=show_progress, 133 | continuing_subword_prefix=wordpieces_prefix, 134 | ) 135 | if isinstance(files, str): 136 | files = [files] 137 | 138 | logger.info("Input files") 139 | logger.info("\n".join(map(lambda x: "\t{}".format(x), files))) 140 | 141 | self._tokenizer.train(files, trainer=trainer) 142 | logger.info("#Vocab: {}".format(self.get_vocab_size())) 143 | 144 | def train_from_iterator( 145 | self, 146 | iterator: Union[Iterator[str], Iterator[Iterator[str]]], 147 | vocab_size: int = 30000, 148 | min_frequency: int = 2, 149 | limit_alphabet: int = 1000, 150 | initial_alphabet: List[str] = [], 151 | special_tokens: List[Union[str, AddedToken]] = [ 152 | "[PAD]", 153 | "[UNK]", 154 | "[CLS]", 155 | "[SEP]", 156 | "[MASK]", 157 | ], 158 | show_progress: bool = True, 159 | wordpieces_prefix: str = "##", 160 | ): 161 | """ Train the model using the given iterator """ 162 | logger.info("Parameters for training") 163 | logger.info("\tvocab_size: {}".format(vocab_size)) 164 | logger.info("\tdo_strip: {}".format(self._parameters["do_strip"])) 165 | logger.info("\tdo_lower_case: {}".format(self._parameters["do_lower_case"])) 166 | logger.info("\tdo_nfkc: {}".format(self._parameters["do_nfkc"])) 167 | logger.info("\tmin_frequency: {}".format(min_frequency)) 168 | logger.info("\tlimit_alphabet: {}".format(limit_alphabet)) 169 | logger.info("\tinitial_alphabet: {}".format(",".join(initial_alphabet))) 170 | logger.info("\tspecial_tokens: {}".format(",".join(special_tokens))) 171 | logger.info("\twordpieces_prefix: {}".format(wordpieces_prefix)) 172 | 173 | trainer = trainers.WordPieceTrainer( 174 | vocab_size=vocab_size, 175 | min_frequency=min_frequency, 176 | limit_alphabet=limit_alphabet, 177 | initial_alphabet=initial_alphabet, 178 | special_tokens=special_tokens, 179 | show_progress=show_progress, 180 | continuing_subword_prefix=wordpieces_prefix, 181 | ) 182 | self._tokenizer.train_from_iterator(iterator, trainer=trainer) 183 | 184 | logger.info("#Vocab: {}".format(self.get_vocab_size())) 185 | 186 | def save(self, output_tokenizer_path: str, pretty: bool = False): 187 | """ 188 | Saves a config file of the tokenizer. 189 | 190 | Notes: To serialize tokenizer, puts dummy pre-tokenizer. 191 | 192 | Args: 193 | output_tokenizer_path (str): Output file path to be saved a config file of tokenizer. 194 | pretty (bool): Json format type. 
195 | """ 196 | logger.info("Saving config to `{}`".format(output_tokenizer_path)) 197 | 198 | self.pre_tokenizer = BertPreTokenizer() # dummy 199 | super().save(output_tokenizer_path, pretty=pretty) 200 | 201 | def save_vocab(self, output_dir: str, prefix: str): 202 | """ 203 | Save the vocabulary. 204 | 205 | Args: 206 | output_dir (str): The path to the target directory in which to save the various files. 207 | prefix (str): An optional prefix, used to prefix each file name 208 | """ 209 | logger.info("Saving vocab to `{}`".format(os.path.join(output_dir, "{}-vocab.txt".format(prefix)))) 210 | 211 | self.model.save(output_dir, prefix) 212 | -------------------------------------------------------------------------------- /sudachitra/pretokenizer/pos_substitution_tokenizer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021 Works Applications Co., Ltd. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from collections import defaultdict 16 | from typing import Dict, List, Optional, Tuple, Union 17 | 18 | from logzero import logger 19 | from progressbar import progressbar as tqdm 20 | 21 | from .. import SudachipyWordTokenizer 22 | from ..tokenization_bert_sudachipy import pos_substitution_format 23 | from ..word_formatter import word_formatter 24 | 25 | 26 | class PartOfSpeechSubstitutionTokenizer(SudachipyWordTokenizer): 27 | def __init__( 28 | self, 29 | split_mode: Optional[str] = "C", 30 | dict_type: Optional[str] = "core", 31 | word_form_type: Optional[str] = "surface", 32 | **kwargs 33 | ): 34 | """ 35 | Constructs a PartOfSpeechSubstitutionTokenizer. 36 | 37 | Args: 38 | split_mode (:obj:`str`, `optional`, defaults to :obj:`"C"`): 39 | The mode of splitting. 40 | "A", "B", or "C" can be specified. 41 | dict_type (:obj:`str`, `optional`, defaults to :obj:`"core"`): 42 | Sudachi dictionary type to be used for tokenization. 43 | "small", "core", or "full" can be specified. 44 | word_form_type (:obj:`str`, `optional`, defaults to :obj:`"surface"`): 45 | Word form type for each morpheme. 46 | The values defined in WordFormTypes can be specified. 47 | **kwargs: 48 | Sudachi dictionary parameters. 49 | """ 50 | super().__init__(split_mode=split_mode, dict_type=dict_type, **kwargs) 51 | self.word_form_type = word_form_type 52 | self.word_formatter = word_formatter(self.word_form_type, self.sudachi_dict) 53 | self._vocab = None 54 | 55 | def get_word2freq_and_pos(self, files: List[str]) -> Tuple[Dict[str, int], List[str]]: 56 | """ 57 | Tokenizes sentences in the specified files and returns tokenized data. 58 | 59 | Args: 60 | files (List[str]): List of paths of input files. 61 | 62 | Returns: 63 | Tuple[Dict[str, int], List[str]]: 64 | 1. Dictionary of tokens and its frequency. 65 | 2. List of part-of-speech tags. 
66 | 67 | """ 68 | word2freq = defaultdict(int) 69 | pos_list = [] 70 | 71 | logger.info("Tokenization") 72 | for file in files: 73 | logger.info("\tReading file: {}".format(file)) 74 | with open(file, 'r') as f: 75 | for line in tqdm(f): 76 | line = line.strip() 77 | if line != "": 78 | for m in self.tokenize(line): 79 | word2freq[self.word_formatter(m)] += 1 80 | pos_list.append(pos_substitution_format(m)) 81 | 82 | return word2freq, list(set(pos_list)) 83 | 84 | def train( 85 | self, 86 | files: Union[str, List[str]], 87 | token_size: int = 32000, 88 | min_frequency: int = 1, 89 | limit_character: int = 1000, 90 | special_tokens: List[str] = ["[PAD]", "[UNK]", "[CLS]", "[SEP]", "[MASK]"] 91 | ): 92 | """ 93 | Reads the input files and builds vocabulary. 94 | 95 | Args: 96 | files (str | List[str]): 97 | List of paths of input files to build vocabulary. 98 | token_size (int): 99 | The size of the vocabulary, excluding special tokens and pos tags. 100 | min_frequency (int): 101 | Ignores all words (tokens and characters) with total frequency lower than this. 102 | limit_character (int): 103 | The maximum different characters to keep in the vocabulary. 104 | special_tokens (List[str]): 105 | A list of special tokens the model should know of. 106 | """ 107 | logger.info("Parameters for training") 108 | logger.info("\ttoken_size: {}".format(token_size)) 109 | logger.info("\tmin_frequency: {}".format(min_frequency)) 110 | logger.info("\tlimit_character: {}".format(limit_character)) 111 | logger.info("\tspecial_tokens: {}".format(",".join(special_tokens))) 112 | 113 | if isinstance(files, str): 114 | files = [files] 115 | word2freq, pos_list = self.get_word2freq_and_pos(files) 116 | 117 | word2freq = {word: freq for word, freq in word2freq.items() if min_frequency <= freq} 118 | 119 | char2freq = {word: freq for word, freq in word2freq.items() if len(word) == 1} 120 | if limit_character < len(char2freq): 121 | sorted_char2freq = sorted(char2freq.items(), key=lambda x: x[1], reverse=True) 122 | for char, _ in sorted_char2freq[limit_character:]: 123 | del word2freq[char] 124 | 125 | word2freq = sorted(word2freq.items(), key=lambda x: x[1], reverse=True) 126 | if token_size < len(word2freq): 127 | word2freq = word2freq[:token_size] 128 | 129 | self._vocab = special_tokens + pos_list + list(map(lambda x: x[0], word2freq)) 130 | logger.info("#Vocab, including POS, all (special) tokens and characters\n{}".format(len(self.vocab))) 131 | 132 | def save_vocab(self, output_path: str): 133 | """ 134 | Saves Vocabulary into the specified path. 135 | 136 | Args: 137 | output_path (str): The output path where the vocabulary will be stored. 138 | """ 139 | logger.info("Saving vocab to `{}`".format(output_path)) 140 | with open(output_path, 'w') as f: 141 | f.write("\n".join(self.vocab)) 142 | 143 | @property 144 | def vocab(self): 145 | return self._vocab 146 | -------------------------------------------------------------------------------- /sudachitra/pretokenizer/sudachipy_pretokenizer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021 Works Applications Co., Ltd. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from sudachipy import Dictionary, MorphemeList 16 | from tokenizers import NormalizedString 17 | from typing import Callable, List, Optional 18 | 19 | from ..word_formatter import word_formatter, WordFormTypes 20 | 21 | 22 | def pretokenizer_handler(sudachi_dict: Dictionary, word_form_type: Optional[str] = 'surface')\ 23 | -> Callable[[int, NormalizedString, MorphemeList], List[NormalizedString]]: 24 | """ 25 | A handler for Dictionary.pre_tokenizer that transforms a MorphemeList into a list of tokens. 26 | 27 | Returns a handler to convert a morpheme to the specified word form. 28 | 29 | Args: 30 | sudachi_dict (Dictionary): 31 | Sudachi dictionary. 32 | word_form_type (:obj:`str`, `optional`, defaults to :obj:`"surface"`): 33 | Word form type for each morpheme. 34 | The values defined in WordFormTypes can be specified. 35 | 36 | Returns: 37 | Callable[[int, NormalizedString, MorphemeList], List[NormalizedString]]: 38 | A handler for Dictionary.pre_tokenizer that transforms a MorphemeList into a list of tokens. 39 | https://worksapplications.github.io/sudachi.rs/python/api/sudachipy.html#sudachipy.Dictionary.pre_tokenizer 40 | """ 41 | _word_formatter = word_formatter(word_form_type, sudachi_dict) if word_form_type != WordFormTypes.SURFACE else None 42 | 43 | def _handler(index: int, original: NormalizedString, morphemes: MorphemeList) -> List[NormalizedString]: 44 | normalized_strings = [] 45 | 46 | for m in morphemes: 47 | begin_index = m.begin() 48 | end_index = m.end() 49 | if begin_index == end_index: # empty token 50 | continue 51 | 52 | normalized_string = original[begin_index:end_index] 53 | 54 | if _word_formatter is not None: 55 | # replace the word form of the `original` string in place via `NormalizedString.replace()` (side effect on `original`). 56 | normalized_string.replace(normalized_string.normalized, _word_formatter(m)) 57 | 58 | normalized_strings.append(normalized_string) 59 | 60 | return normalized_strings 61 | 62 | return _handler 63 | -------------------------------------------------------------------------------- /sudachitra/sudachipy_word_tokenizer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021 Works Applications Co., Ltd. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from typing import Optional 16 | 17 | from sudachipy import Dictionary 18 | from sudachipy import MorphemeList 19 | from sudachipy import SplitMode 20 | 21 | 22 | def get_split_mode(split_mode: str) -> SplitMode: 23 | """ 24 | Returns the specified SplitMode. 
25 | "A", "B", or "C" can be specified. 26 | 27 | Args: 28 | split_mode (str): The mode of splitting. 29 | 30 | Returns: 31 | SplitMode: Unit to split text. 32 | 33 | Raises: 34 | ValueError: If `split_mode` is not defined in SudachiPy. 35 | """ 36 | split_mode = split_mode.upper() 37 | if split_mode == "C": 38 | return SplitMode.C 39 | elif split_mode == "B": 40 | return SplitMode.B 41 | elif split_mode == "A": 42 | return SplitMode.A 43 | else: 44 | raise ValueError("Invalid `split_mode`: " + split_mode) 45 | 46 | 47 | class SudachipyWordTokenizer: 48 | """Runs tokenization with SudachiPy.""" 49 | 50 | def __init__( 51 | self, 52 | split_mode: Optional[str] = "A", 53 | config_path: Optional[str] = None, 54 | resource_dir: Optional[str] = None, 55 | dict_type: Optional[str] = "core", 56 | ): 57 | """ 58 | Constructs a SudachipyTokenizer. 59 | 60 | Args: 61 | split_mode (:obj:`str`, `optional`, defaults to :obj:`"C"`): 62 | The mode of splitting. 63 | "A", "B", or "C" can be specified. 64 | config_path (:obj:`str`, `optional`, defaults to :obj:`None`): 65 | Path to a config file of SudachiPy to be used for the sudachi dictionary initialization. 66 | resource_dir (:obj:`str`, `optional`, defaults to :obj:`None`): 67 | Path to a resource dir containing resource files, such as "sudachi.json". 68 | dict_type (:obj:`str`, `optional`, defaults to :obj:`"core"`): 69 | Sudachi dictionary type to be used for tokenization. 70 | "small", "core", or "full" can be specified. 71 | """ 72 | self.split_mode = get_split_mode(split_mode) 73 | 74 | self.sudachi_dict = Dictionary(config_path=config_path, resource_dir=resource_dir, dict=dict_type) 75 | self.sudachi = self.sudachi_dict.create(self.split_mode) 76 | 77 | def tokenize(self, text: str) -> MorphemeList: 78 | """ 79 | Tokenizes the specified text and returns its morphemes. 80 | 81 | Args: 82 | text (str): Input string. 83 | 84 | Returns: 85 | MorphemeList: List of morphemes. 86 | """ 87 | return self.sudachi.tokenize(text) 88 | -------------------------------------------------------------------------------- /sudachitra/tokenization_electra_sudachipy.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021 Works Applications Co., Ltd. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from . 
import BertSudachipyTokenizer 16 | 17 | 18 | VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt"} 19 | 20 | # TODO: set official URL 21 | PRETRAINED_VOCAB_FILES_MAP = { 22 | "vocab_file": { 23 | "megagonlabs/electra-base-ud-japanese": "https://.../vocab.txt", 24 | } 25 | } 26 | 27 | PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = { 28 | "megagonlabs/electra-base-ud-japanese": 512, 29 | } 30 | 31 | PRETRAINED_INIT_CONFIGURATION = { 32 | "megagonlabs/electra-base-ud-japanese": { 33 | "do_lower_case": False, 34 | "word_tokenizer_type": "sudachipy", 35 | "subword_tokenizer_type": "pos_substitution", 36 | }, 37 | } 38 | 39 | 40 | class ElectraSudachipyTokenizer(BertSudachipyTokenizer): 41 | 42 | vocab_files_names = VOCAB_FILES_NAMES 43 | pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP 44 | max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES 45 | pretrained_init_configuration = PRETRAINED_INIT_CONFIGURATION 46 | -------------------------------------------------------------------------------- /sudachitra/word_formatter.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021 Works Applications Co., Ltd. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import os 16 | from enum import Enum 17 | from typing import Callable 18 | 19 | from sudachipy import Dictionary, Morpheme 20 | 21 | 22 | HALF_ASCII_TRANSLATE_TABLE = str.maketrans({chr(0xFF01 + _): chr(0x21 + _) for _ in range(94)}) 23 | 24 | CONJUGATIVE_POS = {'動詞', '形容詞', '助動詞'} 25 | 26 | 27 | class WordFormTypes(str, Enum): 28 | SURFACE = 'surface' 29 | DICTIONARY = 'dictionary' 30 | NORMALIZED = 'normalized' 31 | DICTIONARY_AND_SURFACE = 'dictionary_and_surface' 32 | NORMALIZED_AND_SURFACE = 'normalized_and_surface' 33 | SURFACE_HALF_ASCII = 'surface_half_ascii' 34 | DICTIONARY_HALF_ASCII = 'dictionary_half_ascii' 35 | DICTIONARY_AND_SURFACE_HALF_ASCII = 'dictionary_and_surface_half_ascii' 36 | NORMALIZED_CONJUGATION = 'normalized_conjugation' 37 | NORMALIZED_NOUNS = "normalized_nouns" 38 | 39 | def __str__(self): 40 | return self.value 41 | 42 | 43 | def word_formatter(word_form_type, sudachi_dict: Dictionary) -> Callable[[Morpheme], str]: 44 | """ 45 | Returns the function that converts a morpheme to the specified word form. 46 | 47 | Args: 48 | word_form_type (str): Word form. 49 | sudachi_dict (Dictionary): Sudachi dictionary. 50 | 51 | Returns: 52 | Callable[[Morpheme], str]: The function that converts a morpheme to the specified word form. 
53 | """ 54 | 55 | if word_form_type not in list(WordFormTypes): 56 | raise ValueError('Invalid word_form_type error `{}`: {}'.format(word_form_type, 57 | list(map(str, WordFormTypes)))) 58 | 59 | if word_form_type == WordFormTypes.NORMALIZED_CONJUGATION: 60 | from sudachitra.conjugation_preserving_normalizer import ConjugationPreservingNormalizer 61 | conjugation_preserving_normalizer = ConjugationPreservingNormalizer( 62 | os.path.join(os.path.dirname(__file__), "resources/inflection_table.json"), 63 | os.path.join(os.path.dirname(__file__), "resources/conjugation_type_table.json"), 64 | sudachi_dict) 65 | 66 | conjugation_matcher = sudachi_dict.pos_matcher(lambda p: p[0] in CONJUGATIVE_POS) 67 | nouns_matcher = sudachi_dict.pos_matcher(lambda x: x[5] == "*") 68 | 69 | word_formatters = { 70 | WordFormTypes.SURFACE: ( 71 | lambda m: m.surface() 72 | ), 73 | WordFormTypes.DICTIONARY: ( 74 | lambda m: m.dictionary_form() 75 | ), 76 | WordFormTypes.NORMALIZED: ( 77 | lambda m: m.normalized_form() 78 | ), 79 | WordFormTypes.DICTIONARY_AND_SURFACE: ( 80 | lambda m: m.surface() if conjugation_matcher(m) else m.dictionary_form() 81 | ), 82 | WordFormTypes.NORMALIZED_AND_SURFACE: ( 83 | lambda m: m.surface() if conjugation_matcher(m) else m.normalized_form() 84 | ), 85 | WordFormTypes.NORMALIZED_NOUNS: ( 86 | lambda m: m.normalized_form() if nouns_matcher(m) else m.surface() 87 | ), 88 | WordFormTypes.SURFACE_HALF_ASCII: ( 89 | lambda m: m.surface().translate(HALF_ASCII_TRANSLATE_TABLE) 90 | ), 91 | WordFormTypes.DICTIONARY_HALF_ASCII: ( 92 | lambda m: m.dictionary_form().translate(HALF_ASCII_TRANSLATE_TABLE) 93 | ), 94 | WordFormTypes.DICTIONARY_AND_SURFACE_HALF_ASCII: ( 95 | lambda m: m.surface().translate(HALF_ASCII_TRANSLATE_TABLE) if conjugation_matcher(m) 96 | else m.dictionary_form().translate(HALF_ASCII_TRANSLATE_TABLE) 97 | ), 98 | WordFormTypes.NORMALIZED_CONJUGATION: ( 99 | lambda m: conjugation_preserving_normalizer.normalized(m) 100 | ) 101 | } 102 | 103 | return word_formatters[word_form_type] 104 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021 Works Applications Co., Ltd. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | -------------------------------------------------------------------------------- /tests/test_japanese_bert_wordpiece_tokenizer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021 Works Applications Co., Ltd. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import re 16 | import unittest 17 | from typing import List 18 | 19 | from sudachipy import Dictionary, SplitMode 20 | from sudachitra.pretokenizer import JapaneseBertWordPieceTokenizer 21 | from sudachitra.pretokenizer import pretokenizer_handler 22 | from sudachitra.word_formatter import WordFormTypes 23 | from tokenizers import Encoding 24 | 25 | 26 | class JapaneseBertWordPieceTokenizerTest(unittest.TestCase): 27 | 28 | def setUp(self) -> None: 29 | self.vocab = { 30 | "[UNK]": 0, 31 | "[SEP]": 1, 32 | "[CLS]": 2, 33 | "[PAD]": 3, 34 | "[MASK]": 4, 35 | "引越": 5, 36 | "引っ越し": 6, 37 | "し": 7, 38 | "て": 8, 39 | "為る": 9, 40 | "する": 10, 41 | "から": 11, 42 | "すだち": 12, 43 | "酢橘": 13, 44 | "-": 14, 45 | "-": 15, 46 | "Sudachi": 16, 47 | "Sudachi": 17, 48 | "sudachi": 18, 49 | "を": 19, 50 | "とど": 20, 51 | "届": 21, 52 | "##け": 22, 53 | "##る": 23, 54 | # "届け": 22, 55 | # "届ける": 23, 56 | "ます": 24, 57 | "…": 25, 58 | ".": 26, 59 | "。": 27, 60 | "\n": 28 61 | } 62 | self.wp_tokenizer = JapaneseBertWordPieceTokenizer(vocab=self.vocab, 63 | do_lower_case=False, 64 | do_nfkc=False, 65 | do_strip=False) 66 | self.wordpieces_prefix = '##' 67 | self.unk_token = '[UNK]' 68 | self.prefix_pattern = re.compile(f'^{self.wordpieces_prefix}') 69 | self.delete_prefix = lambda x: self.prefix_pattern.sub('', x) 70 | 71 | self.sudachi_dict = Dictionary(dict='core') 72 | self.test_sentence = '引越してからすだちSudachiをとどけます。' 73 | self.sudachi = self.sudachi_dict.create(mode=SplitMode.C) 74 | 75 | def set_pretokenizer(self, word_form_type: WordFormTypes): 76 | pretok = self.sudachi_dict.pre_tokenizer(mode=SplitMode.C, 77 | handler=pretokenizer_handler(self.sudachi_dict, 78 | word_form_type=word_form_type)) 79 | self.wp_tokenizer.pre_tokenizer = pretok 80 | 81 | def validate_encoding(self, tokens: List[str], encoding: Encoding): 82 | """ 83 | Validates properties of `Encoding`. 84 | (https://github.com/huggingface/tokenizers/blob/master/bindings/python/py_src/tokenizers/__init__.pyi#L69) 85 | 86 | Args: 87 | tokens (List[str]): Expected tokens. 88 | encoding (Encoding): Encoded tokens. 
89 | """ 90 | self.assertListEqual(tokens, encoding.tokens) 91 | self.assertListEqual(list(map(self.wp_tokenizer.token_to_id, tokens)), encoding.ids) 92 | self.assertListEqual([None, *[0 for _ in range(len(tokens) - 2)], None], encoding.sequence_ids) 93 | self.assertListEqual([1, *[0 for _ in range(len(tokens) - 2)], 1], encoding.special_tokens_mask) 94 | self.assertListEqual([0 for _ in range(len(tokens))], encoding.type_ids) 95 | self.assertListEqual([1 for _ in range(len(tokens))], encoding.attention_mask) 96 | # ToDo: add test for encoding.offsets and encoding.word_ids (https://github.com/WorksApplications/SudachiTra/issues/42) 97 | 98 | def test_surface(self): 99 | word_form_type = WordFormTypes.SURFACE 100 | tokens = ['[CLS]', '引越', 'し', 'て', 'から', 'すだち', 'Sudachi', 'を', 'とど', '##け', 'ます', '。', '[SEP]'] 101 | 102 | self.set_pretokenizer(word_form_type) 103 | encoding = self.wp_tokenizer.encode(self.test_sentence) 104 | self.validate_encoding(tokens, encoding) 105 | 106 | def test_normalized_and_surface(self): 107 | word_form_type = WordFormTypes.NORMALIZED_AND_SURFACE 108 | tokens = ['[CLS]', '引っ越し', 'し', 'て', 'から', '酢橘', 'Sudachi', 'を', 'とど', '##け', 'ます', '。', '[SEP]'] 109 | 110 | self.set_pretokenizer(word_form_type) 111 | encoding = self.wp_tokenizer.encode(self.test_sentence) 112 | self.validate_encoding(tokens, encoding) 113 | 114 | def test_normalized_nouns(self): 115 | word_form_type = WordFormTypes.NORMALIZED_NOUNS 116 | tokens = ['[CLS]', '引っ越し', 'し', 'て', 'から', '酢橘', 'Sudachi', 'を', 'とど', '##け', 'ます', '。', '[SEP]'] 117 | 118 | self.set_pretokenizer(word_form_type) 119 | encoding = self.wp_tokenizer.encode(self.test_sentence) 120 | self.validate_encoding(tokens, encoding) 121 | 122 | 123 | def test_normalized_conjugation(self): 124 | word_form_type = WordFormTypes.NORMALIZED_CONJUGATION 125 | tokens = ['[CLS]', '引っ越し', 'し', 'て', 'から', '酢橘', 'Sudachi', 'を', '届', '##け', 'ます', '。', '[SEP]'] 126 | 127 | self.set_pretokenizer(word_form_type) 128 | encoding = self.wp_tokenizer.encode(self.test_sentence) 129 | self.validate_encoding(tokens, encoding) 130 | 131 | def test_normalized_form(self): 132 | word_form_type = WordFormTypes.NORMALIZED 133 | tokens = ['[CLS]', '引っ越し', '為る', 'て', 'から', '酢橘', 'Sudachi', 'を', '届', '##け', '##る', 'ます', '。', '[SEP]'] 134 | 135 | self.set_pretokenizer(word_form_type) 136 | encoding = self.wp_tokenizer.encode(self.test_sentence) 137 | self.validate_encoding(tokens, encoding) 138 | 139 | def test_dictionary_form(self): 140 | word_form_type = WordFormTypes.DICTIONARY 141 | tokens = ['[CLS]', '引越', 'する', 'て', 'から', 'すだち', 'Sudachi', 'を', 'とど', '##け', '##る', 'ます', '。', '[SEP]'] 142 | 143 | self.set_pretokenizer(word_form_type) 144 | encoding = self.wp_tokenizer.encode(self.test_sentence) 145 | self.validate_encoding(tokens, encoding) 146 | 147 | def test_dictionary_and_surface(self): 148 | word_form_type = WordFormTypes.DICTIONARY_AND_SURFACE 149 | tokens = ['[CLS]', '引越', 'し', 'て', 'から', 'すだち', 'Sudachi', 'を', 'とど', '##け', 'ます', '。', '[SEP]'] 150 | 151 | self.set_pretokenizer(word_form_type) 152 | encoding = self.wp_tokenizer.encode(self.test_sentence) 153 | self.validate_encoding(tokens, encoding) 154 | 155 | def test_normalizers(self): 156 | self.wp_tokenizer = JapaneseBertWordPieceTokenizer(vocab=self.vocab, 157 | do_lower_case=True, 158 | do_nfkc=True, 159 | do_strip=True) 160 | word_form_type = WordFormTypes.SURFACE 161 | self.set_pretokenizer(word_form_type) 162 | 163 | # lowercase 164 | sentence = 'SUDACHI' 165 | encoding = 
self.wp_tokenizer.encode(sentence) 166 | tokens = ['[CLS]', 'sudachi', '[SEP]'] 167 | self.validate_encoding(tokens, encoding) 168 | 169 | # # strip 170 | sentence = ' sudachi\n' 171 | encoding = self.wp_tokenizer.encode(sentence) 172 | tokens = ['[CLS]', 'sudachi', '[SEP]'] 173 | self.validate_encoding(tokens, encoding) 174 | self.validate_encoding(tokens, encoding) 175 | 176 | # nfkc 177 | sentence = '…' 178 | encoding = self.wp_tokenizer.encode(sentence) 179 | tokens = ['[CLS]', '.', '.', '.', '[SEP]'] 180 | self.validate_encoding(tokens, encoding) 181 | self.validate_encoding(tokens, encoding) 182 | 183 | def test_oov(self): 184 | word_form_type = WordFormTypes.SURFACE 185 | self.set_pretokenizer(word_form_type) 186 | 187 | sentence = 'OOV酢橘OOV酢橘OOV' 188 | encoding = self.wp_tokenizer.encode(sentence) 189 | tokens = ['[CLS]', '[UNK]', '酢橘', '[UNK]', '酢橘', '[UNK]', '[SEP]'] 190 | self.validate_encoding(tokens, encoding) 191 | -------------------------------------------------------------------------------- /tests/test_tokenization_bert_sudachipy.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021 Works Applications Co., Ltd. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | import os 16 | import pickle 17 | import shutil 18 | import tempfile 19 | import unittest 20 | 21 | from transformers.models.bert.tokenization_bert import WordpieceTokenizer 22 | 23 | from sudachitra import BertSudachipyTokenizer, SudachipyWordTokenizer 24 | from sudachitra.tokenization_bert_sudachipy import VOCAB_FILES_NAMES 25 | from sudachitra.word_formatter import word_formatter 26 | 27 | 28 | class BertSudachipyTokenizationTest(unittest.TestCase): 29 | 30 | tokenizer_class = BertSudachipyTokenizer 31 | 32 | def setUp(self): 33 | super().setUp() 34 | self.tmpdirname = tempfile.mkdtemp() 35 | 36 | vocab_tokens = [ 37 | "[UNK]", 38 | "[CLS]", 39 | "[SEP]", 40 | "こんにちは", 41 | "こん", 42 | "にちは", 43 | "ばんは", 44 | "##こん", 45 | "##にちは", 46 | "##ばんは", 47 | "世界", 48 | "##世界", 49 | "、", 50 | "##、", 51 | "。", 52 | "##。", 53 | ] 54 | 55 | self.vocab_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES["vocab_file"]) 56 | with open(self.vocab_file, "w", encoding="utf-8") as vocab_writer: 57 | vocab_writer.write("".join([x + "\n" for x in vocab_tokens])) 58 | 59 | def tearDown(self): 60 | shutil.rmtree(self.tmpdirname) 61 | 62 | def test_pickle_sudachipy_tokenizer(self): 63 | tokenizer = self.tokenizer_class( 64 | self.vocab_file, 65 | do_lower_case=False, 66 | do_word_tokenize=True, 67 | subword_tokenizer_type='wordpiece' 68 | ) 69 | self.assertIsNotNone(tokenizer) 70 | 71 | text = "こんにちは、世界。\nこんばんは、世界。" 72 | tokens = tokenizer.tokenize(text) 73 | self.assertListEqual(tokens, ["こんにちは", "、", "世界", "。", "こん", "##ばんは", "、", "世界", "。"]) 74 | self.assertListEqual(tokenizer.convert_tokens_to_ids(tokens), [3, 12, 10, 14, 4, 9, 12, 10, 14]) 75 | 76 | filename = os.path.join(self.tmpdirname, "tokenizer.bin") 77 | with open(filename, "wb") as handle: 78 | pickle.dump(tokenizer, handle) 79 | 80 | with open(filename, "rb") as handle: 81 | tokenizer_new = pickle.load(handle) 82 | 83 | tokens_loaded = tokenizer_new.tokenize(text) 84 | 85 | self.assertListEqual(tokens, tokens_loaded) 86 | 87 | def test_sudachipy_tokenizer_small(self): 88 | try: 89 | tokenizer = SudachipyWordTokenizer(dict_type="small") 90 | except ModuleNotFoundError: 91 | return 92 | 93 | self.assertListEqual( 94 | list(map(word_formatter('surface', tokenizer.sudachi_dict), 95 | tokenizer.tokenize("appleはsmall辞書に、apple pieはcore辞書に、apple storeはfull辞書に収録されている。"))), 96 | ["apple", "は", "small", "辞書", "に", "、", "apple", " ", "pie", "は", "core", "辞書", "に", "、", 97 | "apple", " ", "store", "は", "full", "辞書", "に", "収録", "さ", "れ", "て", "いる", "。"] 98 | ) 99 | 100 | def test_sudachipy_tokenizer_core(self): 101 | try: 102 | tokenizer = SudachipyWordTokenizer(dict_type="core") 103 | except ModuleNotFoundError: 104 | return 105 | 106 | self.assertListEqual( 107 | list(map(word_formatter('surface', tokenizer.sudachi_dict), 108 | tokenizer.tokenize("appleはsmall辞書に、apple pieはcore辞書に、apple storeはfull辞書に収録されている。"))), 109 | ["apple", "は", "small", "辞書", "に", "、", "apple pie", "は", "core", "辞書", "に", "、", 110 | "apple", " ", "store", "は", "full", "辞書", "に", "収録", "さ", "れ", "て", "いる", "。"] 111 | ) 112 | 113 | def test_sudachipy_tokenizer_full(self): 114 | try: 115 | tokenizer = SudachipyWordTokenizer(dict_type="full") 116 | except ModuleNotFoundError: 117 | return 118 | 119 | self.assertListEqual( 120 | list(map(word_formatter('surface', tokenizer.sudachi_dict), 121 | tokenizer.tokenize("appleはsmall辞書に、apple pieはcore辞書に、apple storeはfull辞書に収録されている。"))), 122 | ["apple", "は", "small", "辞書", "に", "、", "apple pie", "は", "core", "辞書", "に", "、", 123 | "apple 
store", "は", "full", "辞書", "に", "収録", "さ", "れ", "て", "いる", "。"] 124 | ) 125 | 126 | def test_sudachipy_tokenizer_surface(self): 127 | tokenizer = self.tokenizer_class(self.vocab_file, do_subword_tokenize=False, 128 | word_form_type='surface') 129 | 130 | self.assertListEqual( 131 | tokenizer.tokenize("appleの辞書形はAppleで正規形はアップルである。"), 132 | ["apple", "の", "辞書", "形", "は", "Apple", "で", "正規", "形", "は", "アップル", "で", "ある", "。"] 133 | 134 | ) 135 | 136 | def test_sudachipy_tokenizer_dictionary_form(self): 137 | tokenizer = self.tokenizer_class(self.vocab_file, do_subword_tokenize=False, 138 | word_form_type='dictionary') 139 | 140 | self.assertListEqual( 141 | tokenizer.tokenize("appleの辞書形はAppleで正規形はアップルである。"), 142 | ["Apple", "の", "辞書", "形", "は", "Apple", "で", "正規", "形", "は", "アップル", "だ", "ある", "。"] 143 | ) 144 | 145 | def test_sudachipy_tokenizer_normalized_form(self): 146 | tokenizer = self.tokenizer_class(self.vocab_file, do_subword_tokenize=False, 147 | word_form_type='normalized') 148 | 149 | self.assertListEqual( 150 | tokenizer.tokenize("appleの辞書形はAppleで正規形はアップルである。"), 151 | ["アップル", "の", "辞書", "形", "は", "アップル", "で", "正規", "形", "は", "アップル", "だ", "有る", "。"] 152 | ) 153 | 154 | def test_sudachipy_tokenizer_dictionary_form_and_surface(self): 155 | tokenizer = self.tokenizer_class(self.vocab_file, do_subword_tokenize=False, 156 | word_form_type='dictionary_and_surface') 157 | 158 | self.assertListEqual( 159 | tokenizer.tokenize("appleの辞書形はAppleで正規形はアップルである。"), 160 | ["Apple", "の", "辞書", "形", "は", "Apple", "で", "正規", "形", "は", "アップル", "で", "ある", "。"] 161 | ) 162 | 163 | def test_sudachipy_tokenizer_normalized_form_and_surface(self): 164 | tokenizer = self.tokenizer_class(self.vocab_file, do_subword_tokenize=False, 165 | word_form_type='normalized_and_surface') 166 | 167 | self.assertListEqual( 168 | tokenizer.tokenize("appleの辞書形はAppleで正規形はアップルである。"), 169 | ["アップル", "の", "辞書", "形", "は", "アップル", "で", "正規", "形", "は", "アップル", "で", "ある", "。"] 170 | ) 171 | self.assertListEqual( 172 | tokenizer.tokenize("強がっている"), 173 | ["強", "がる", "て", "いる"] 174 | ) 175 | 176 | def test_sudachipy_tokenizer_normalized_nouns(self): 177 | tokenizer = self.tokenizer_class(self.vocab_file, do_subword_tokenize=False, 178 | word_form_type='normalized_nouns') 179 | 180 | self.assertListEqual( 181 | tokenizer.tokenize("appleの辞書形はAppleで正規形はアップルである。"), 182 | ["アップル", "の", "辞書", "形", "は", "アップル", "で", "正規", "形", "は", "アップル", "で", "ある", "。"] 183 | ) 184 | self.assertListEqual( 185 | tokenizer.tokenize("強がっている"), 186 | ["強", "がっ", "て", "いる"] 187 | ) 188 | 189 | def test_sudachipy_tokenizer_surface_half_ascii(self): 190 | tokenizer = self.tokenizer_class(self.vocab_file, do_subword_tokenize=False, 191 | word_form_type='surface_half_ascii') 192 | 193 | self.assertListEqual( 194 | tokenizer.tokenize("appleやappleの辞書形はAppleで正規形はアップルである。"), 195 | ["apple", "や", "apple", "の", "辞書", "形", "は", "Apple", "で", "正規", "形", "は", "アップル", "で", "ある", "。"] 196 | 197 | ) 198 | 199 | def test_sudachipy_tokenizer_dictionary_form_half_ascii(self): 200 | tokenizer = self.tokenizer_class(self.vocab_file, do_subword_tokenize=False, 201 | word_form_type='dictionary_half_ascii') 202 | 203 | self.assertListEqual( 204 | tokenizer.tokenize("appleやappleの辞書形はAppleで正規形はアップルである。"), 205 | ["Apple", "や", "Apple", "の", "辞書", "形", "は", "Apple", "で", "正規", "形", "は", "アップル", "だ", "ある", "。"] 206 | ) 207 | 208 | def test_sudachipy_tokenizer_dictionary_form_and_surface_half_ascii(self): 209 | tokenizer = self.tokenizer_class(self.vocab_file, do_subword_tokenize=False, 210 
| word_form_type='dictionary_and_surface_half_ascii') 211 | 212 | self.assertListEqual( 213 | tokenizer.tokenize("appleやappleの辞書形はAppleで正規形はアップルである。"), 214 | ["Apple", "や", "Apple", "の", "辞書", "形", "は", "Apple", "で", "正規", "形", "は", "アップル", "で", "ある", "。"] 215 | ) 216 | 217 | def test_sudachipy_tokenizer_normalized_form_leaved_conjugation(self): 218 | tokenizer = self.tokenizer_class(self.vocab_file, do_subword_tokenize=False, 219 | word_form_type='normalized_conjugation') 220 | 221 | self.assertListEqual( 222 | tokenizer.tokenize("appleの辞書形はAppleで正規形はアップルである。"), 223 | ["アップル", "の", "辞書", "形", "は", "アップル", "で", "正規", "形", "は", "アップル", "で", "有る", "。"] 224 | ) 225 | 226 | def test_sudachipy_tokenizer_normalized_form_leaved_conjugation_can_do(self): 227 | tokenizer = self.tokenizer_class(self.vocab_file, do_subword_tokenize=False, 228 | word_form_type='normalized_conjugation', sudachipy_kwargs={"split_mode": "C"}) 229 | 230 | self.assertListEqual( 231 | tokenizer.tokenize("畳み込めたので大丈夫です。"), 232 | ["畳み込み", "た", "の", "で", "大丈夫", "です", "。"] 233 | ) 234 | 235 | self.assertListEqual( 236 | tokenizer.tokenize("泳げます。"), 237 | ["泳ぎ", "ます", "。"] 238 | ) 239 | 240 | def test_sudachipy_tokenizer_unit_a(self): 241 | try: 242 | tokenizer = SudachipyWordTokenizer(split_mode="A") 243 | except ModuleNotFoundError: 244 | return 245 | 246 | self.assertListEqual( 247 | list(map(word_formatter('surface', tokenizer.sudachi_dict), tokenizer.tokenize("徳島阿波おどり空港"))), 248 | ["徳島", "阿波", "おどり", "空港"] 249 | ) 250 | 251 | def test_sudachipy_tokenizer_unit_b(self): 252 | try: 253 | tokenizer = SudachipyWordTokenizer(split_mode="B") 254 | except ModuleNotFoundError: 255 | return 256 | 257 | self.assertListEqual( 258 | list(map(word_formatter('surface', tokenizer.sudachi_dict), tokenizer.tokenize("徳島阿波おどり空港"))), 259 | ["徳島", "阿波おどり", "空港"] 260 | ) 261 | 262 | def test_sudachipy_tokenizer_unit_c(self): 263 | try: 264 | tokenizer = SudachipyWordTokenizer(split_mode="C") 265 | except ModuleNotFoundError: 266 | return 267 | 268 | self.assertListEqual( 269 | list(map(word_formatter('surface', tokenizer.sudachi_dict), tokenizer.tokenize("徳島阿波おどり空港"))), 270 | ["徳島阿波おどり空港"] 271 | ) 272 | 273 | def test_wordpiece_tokenizer(self): 274 | vocab_tokens = ["[UNK]", "[CLS]", "[SEP]", "こんにちは", "こん", "にちは" "ばんは", "##こん", "##にちは", "##ばんは"] 275 | 276 | vocab = {} 277 | for (i, token) in enumerate(vocab_tokens): 278 | vocab[token] = i 279 | tokenizer = WordpieceTokenizer(vocab=vocab, unk_token="[UNK]") 280 | 281 | self.assertListEqual(tokenizer.tokenize(""), []) 282 | 283 | self.assertListEqual(tokenizer.tokenize("こんにちは"), ["こんにちは"]) 284 | 285 | self.assertListEqual(tokenizer.tokenize("こんばんは"), ["こん", "##ばんは"]) 286 | 287 | self.assertListEqual(tokenizer.tokenize("こんばんは こんばんにちは こんにちは"), ["こん", "##ばんは", "[UNK]", "こんにちは"]) 288 | 289 | def test_sequence_builders(self): 290 | pass 291 | 292 | 293 | class BertSudachipyCharacterTokenizationTest(unittest.TestCase): 294 | 295 | tokenizer_class = BertSudachipyTokenizer 296 | 297 | def setUp(self): 298 | super().setUp() 299 | self.tmpdirname = tempfile.mkdtemp() 300 | 301 | vocab_tokens = ["[UNK]", "[CLS]", "[SEP]", "こ", "ん", "に", "ち", "は", "ば", "世", "界", "、", "。"] 302 | 303 | self.vocab_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES["vocab_file"]) 304 | with open(self.vocab_file, "w", encoding="utf-8") as vocab_writer: 305 | vocab_writer.write("".join([x + "\n" for x in vocab_tokens])) 306 | 307 | def test_full_tokenizer(self): 308 | tokenizer = self.tokenizer_class(self.vocab_file, 
subword_tokenizer_type="character") 309 | 310 | tokens = tokenizer.tokenize("こんにちは、世界。こんばんは、世界。") 311 | self.assertListEqual( 312 | tokens, ["こ", "ん", "に", "ち", "は", "、", "世", "界", "。", "こ", "ん", "ば", "ん", "は", "、", "世", "界", "。"] 313 | ) 314 | self.assertListEqual( 315 | tokenizer.convert_tokens_to_ids(tokens), [3, 4, 5, 6, 7, 11, 9, 10, 12, 3, 4, 8, 4, 7, 11, 9, 10, 12] 316 | ) 317 | 318 | def test_sequence_builders(self): 319 | pass 320 | --------------------------------------------------------------------------------