├── .flake8
├── .github
│   └── workflows
│       ├── ci.yml
│       └── release.yml
├── .gitignore
├── LICENSE
├── Makefile
├── README.md
├── allennlp_shiba
│   ├── __init__.py
│   ├── common
│   │   ├── __init__.py
│   │   └── testing
│   │       ├── __init__.py
│   │       └── test_case.py
│   ├── data
│   │   ├── __init__.py
│   │   ├── token_indexers
│   │   │   ├── __init__.py
│   │   │   └── pretrained_shiba_indexer.py
│   │   └── tokenizers
│   │       ├── __init__.py
│   │       └── codepoint_tokenizer.py
│   └── modules
│       ├── __init__.py
│       └── token_embedders
│           ├── __init__.py
│           └── pretrained_shiba_embedder.py
├── poetry.lock
├── pyproject.toml
├── test_fixtures
│   ├── data
│   │   └── tokenizers
│   │       └── codepoint_tokenizer.jsonnet
│   └── modules
│       └── token_embedders
│           └── pretrained_shiba_embedder.jsonnet
└── tests
    ├── __init__.py
    ├── data
    │   ├── __init__.py
    │   ├── token_indexers
    │   │   ├── __init__.py
    │   │   └── pretrained_shiba_indexer_test.py
    │   └── tokenizers
    │       ├── __init__.py
    │       └── code_point_tokenizer_test.py
    └── modules
        ├── __init__.py
        └── token_embedders
            ├── __init__.py
            └── pretrained_shiba_embedder_test.py

/.flake8:
--------------------------------------------------------------------------------
1 | [flake8]
2 | max-line-length = 115
3 | 
4 | ignore = 
5 |     # these rules don't play well with black
6 |     E203  # whitespace before :
7 |     W503  # line break before binary operator
8 | 
9 | per-file-ignores = 
10 |     # __init__.py files are allowed to have unused imports and lines-too-long
11 |     */__init__.py:F401
12 |     */**/**/__init__.py:F401,E501
13 | 
14 |     # tests don't have to respect
15 |     #  E731: do not assign a lambda expression, use a def
16 |     tests/**:E731
17 | 
18 |     # scripts don't have to respect
19 |     #  E402: imports not at top of file (because we mess with sys.path)
20 |     scripts/**:E402
21 | 
--------------------------------------------------------------------------------
/.github/workflows/ci.yml:
--------------------------------------------------------------------------------
1 | name: CI
2 | 
3 | on:
4 |   push:
5 |     branches: [master]
6 |   pull_request:
7 |     branches: [master]
8 | 
9 | jobs:
10 |   build:
11 |     runs-on: ubuntu-latest
12 |     strategy:
13 |       matrix:
14 |         python-version: [3.7, 3.8]
15 | 
16 |     steps:
17 |       - uses: actions/checkout@v2
18 |       - name: Set up Python ${{ matrix.python-version }}
19 |         uses: actions/setup-python@v2
20 |         with:
21 |           python-version: ${{ matrix.python-version }}
22 | 
23 |       - name: Install dependencies
24 |         run: |
25 |           python -m pip install --upgrade pip
26 |           pip install .
27 | 
28 |       - name: Format
29 |         run: |
30 |           pip install black
31 |           make format
32 | 
33 |       - name: Lint
34 |         run: |
35 |           pip install flake8
36 |           make lint
37 | 
38 |       - name: Type check
39 |         run: |
40 |           pip install mypy
41 |           make typecheck
42 | 
43 |       - name: Run tests
44 |         run: |
45 |           pip install pytest
46 |           make test
47 | 
--------------------------------------------------------------------------------
/.github/workflows/release.yml:
--------------------------------------------------------------------------------
1 | name: Release
2 | 
3 | on:
4 |   push:
5 |     tags:
6 |       - "v[0-9]+.[0-9]+.[0-9]+"
7 | 
8 | jobs:
9 |   deploy:
10 |     runs-on: ubuntu-latest
11 | 
12 |     steps:
13 |       - uses: actions/checkout@v2
14 |       - name: Set up Python
15 |         uses: actions/setup-python@v2
16 |         with:
17 |           python-version: "3.x"
18 | 
19 |       - name: Install dependencies
20 |         run: |
21 |           python -m pip install --upgrade pip
22 |           pip install poetry poetry-dynamic-versioning twine
23 |       - name: Build and publish
24 |         env:
25 |           TWINE_USERNAME: ${{ secrets.PYPI_USERNAME }}
26 |           TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }}
27 |         run: |
28 |           poetry publish --build --username $TWINE_USERNAME --password $TWINE_PASSWORD
29 |   release:
30 |     runs-on: ubuntu-latest
31 | 
32 |     steps:
33 |       - name: Checkout code
34 |         uses: actions/checkout@v2
35 | 
36 |       - name: Create Release
37 |         id: create_release
38 |         uses: actions/create-release@v1
39 |         env:
40 |           GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
41 |         with:
42 |           tag_name: ${{ github.ref }}
43 |           release_name: Release ${{ github.ref }}
44 |           draft: false
45 |           prerelease: false
46 | 
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Created by https://www.toptal.com/developers/gitignore/api/python
2 | # Edit at https://www.toptal.com/developers/gitignore?templates=python
3 | 
4 | ### Python ###
5 | # Byte-compiled / optimized / DLL files
6 | __pycache__/
7 | *.py[cod]
8 | *$py.class
9 | 
10 | # C extensions
11 | *.so
12 | 
13 | # Distribution / packaging
14 | .Python
15 | build/
16 | develop-eggs/
17 | dist/
18 | downloads/
19 | eggs/
20 | .eggs/
21 | lib/
22 | lib64/
23 | parts/
24 | sdist/
25 | var/
26 | wheels/
27 | share/python-wheels/
28 | *.egg-info/
29 | .installed.cfg
30 | *.egg
31 | MANIFEST
32 | 
33 | # PyInstaller
34 | #  Usually these files are written by a python script from a template
35 | #  before PyInstaller builds the exe, so as to inject date/other infos into it.
36 | *.manifest 37 | *.spec 38 | 39 | # Installer logs 40 | pip-log.txt 41 | pip-delete-this-directory.txt 42 | 43 | # Unit test / coverage reports 44 | htmlcov/ 45 | .tox/ 46 | .nox/ 47 | .coverage 48 | .coverage.* 49 | .cache 50 | nosetests.xml 51 | coverage.xml 52 | *.cover 53 | *.py,cover 54 | .hypothesis/ 55 | .pytest_cache/ 56 | cover/ 57 | 58 | # Translations 59 | *.mo 60 | *.pot 61 | 62 | # Django stuff: 63 | *.log 64 | local_settings.py 65 | db.sqlite3 66 | db.sqlite3-journal 67 | 68 | # Flask stuff: 69 | instance/ 70 | .webassets-cache 71 | 72 | # Scrapy stuff: 73 | .scrapy 74 | 75 | # Sphinx documentation 76 | docs/_build/ 77 | 78 | # PyBuilder 79 | .pybuilder/ 80 | target/ 81 | 82 | # Jupyter Notebook 83 | .ipynb_checkpoints 84 | 85 | # IPython 86 | profile_default/ 87 | ipython_config.py 88 | 89 | # pyenv 90 | # For a library or package, you might want to ignore these files since the code is 91 | # intended to run in multiple environments; otherwise, check them in: 92 | .python-version 93 | 94 | # pipenv 95 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 96 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 97 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 98 | # install all needed dependencies. 99 | #Pipfile.lock 100 | 101 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 102 | __pypackages__/ 103 | 104 | # Celery stuff 105 | celerybeat-schedule 106 | celerybeat.pid 107 | 108 | # SageMath parsed files 109 | *.sage.py 110 | 111 | # Environments 112 | .env 113 | .venv 114 | env/ 115 | venv/ 116 | ENV/ 117 | env.bak/ 118 | venv.bak/ 119 | 120 | # Spyder project settings 121 | .spyderproject 122 | .spyproject 123 | 124 | # Rope project settings 125 | .ropeproject 126 | 127 | # mkdocs documentation 128 | /site 129 | 130 | # mypy 131 | .mypy_cache/ 132 | .dmypy.json 133 | dmypy.json 134 | 135 | # Pyre type checker 136 | .pyre/ 137 | 138 | # pytype static type analyzer 139 | .pytype/ 140 | 141 | # Cython debug symbols 142 | cython_debug/ 143 | 144 | # End of https://www.toptal.com/developers/gitignore/api/python 145 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 
25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. 
If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. 
Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | 
--------------------------------------------------------------------------------
/Makefile:
--------------------------------------------------------------------------------
1 | .PHONY : lint
2 | lint :
3 | 	flake8 .
4 | 
5 | .PHONY : format
6 | format :
7 | 	black --check .
8 | 
9 | .PHONY : typecheck
10 | typecheck :
11 | 	mypy . \
12 | 		--ignore-missing-imports \
13 | 		--no-strict-optional \
14 | 		--no-site-packages \
15 | 		--cache-dir=/dev/null
16 | 
17 | .PHONY : test
18 | test :
19 | 	pytest --color=yes -rf --durations=40
20 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # AllenNLP Integration for [Shiba](https://github.com/octanove/shiba)
2 | 
3 | [![CI](https://github.com/shunk031/allennlp-shiba-model/actions/workflows/ci.yml/badge.svg)](https://github.com/shunk031/allennlp-shiba-model/actions/workflows/ci.yml)
4 | [![Release](https://github.com/shunk031/allennlp-shiba-model/actions/workflows/release.yml/badge.svg)](https://github.com/shunk031/allennlp-shiba-model/actions/workflows/release.yml)
5 | ![Python](https://img.shields.io/badge/python-3.7%20%7C%203.8-blue?logo=python)
6 | [![PyPI](https://img.shields.io/pypi/v/allennlp-shiba.svg)](https://pypi.org/project/allennlp-shiba/)
7 | 
8 | `allennlp-shiba-model` is a Python library that provides AllenNLP integration for [shiba-model](https://pypi.org/project/shiba-model/).
9 | 
10 | > SHIBA is an approximate reimplementation of CANINE [[1]](https://github.com/octanove/shiba#1) in raw Pytorch, pretrained on the Japanese wikipedia corpus using random span masking. If you are unfamiliar with CANINE, you can think of it as a very efficient (approximately 4x as efficient) character-level BERT model. Of course, the name SHIBA comes from the identically named Japanese canine.
11 | 
12 | ## Installation
13 | 
14 | Installing the library and dependencies is simple using `pip`.
15 | 
16 | ```shell
17 | pip install allennlp-shiba
18 | ```
19 | 
20 | ## Example
21 | 
22 | This library enables users to specify the SHIBA tokenizer, token indexer, and token embedder in a jsonnet config file. Here is an example of these components in a jsonnet config file:
23 | 
24 | ```jsonnet
25 | {
26 |   "dataset_reader": {
27 |     "tokenizer": {
28 |       "type": "shiba",
29 |     },
30 |     "token_indexers": {
31 |       "tokens": {
32 |         "type": "shiba",
33 |       }
34 |     },
35 |   },
36 |   "model": {
37 |     "shiba_embedder": {
38 |       "type": "basic",
39 |       "token_embedders": {
40 |         "shiba": {
41 |           "type": "shiba",
42 |           "eval_mode": true,
43 |         }
44 |       }
45 |     }
46 |   }
47 | }
48 | ```
49 | 
50 | ## Reference
51 | 
52 | - Joshua Tanner and Masato Hagiwara (2021). [SHIBA: Japanese CANINE model](https://github.com/octanove/shiba). GitHub repository, GitHub.
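
## Usage from Python

Beyond jsonnet configs, the registered components can be used directly. The sketch below mirrors this repository's tests; it assumes `allennlp-shiba` and `shiba-model` are installed, and note that constructing the embedder downloads the pretrained SHIBA weights on first use.

```python
from allennlp.data.batch import Batch
from allennlp.data.fields import TextField
from allennlp.data.instance import Instance
from allennlp.data.vocabulary import Vocabulary
from allennlp.modules.text_field_embedders import BasicTextFieldEmbedder

from allennlp_shiba.data.token_indexers import PretrainedShibaIndexer
from allennlp_shiba.data.tokenizers import ShibaCodepointTokenizer
from allennlp_shiba.modules.token_embedders import PretrainedShibaEmbedder

tokenizer = ShibaCodepointTokenizer()
indexer = PretrainedShibaIndexer()

# Tokenize into characters; the tokenizer prepends a [CLS] token.
tokens = tokenizer.tokenize("自然言語処理")

instance = Instance({"tokens": TextField(tokens, {"shiba": indexer})})
batch = Batch([instance])
batch.index_instances(Vocabulary())
tensor_dict = batch.as_tensor_dict(batch.get_padding_lengths())

embedder = BasicTextFieldEmbedder({"shiba": PretrainedShibaEmbedder(eval_mode=True)})
embeddings = embedder(tensor_dict["tokens"])  # shape: (1, 7, 768)
```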
--------------------------------------------------------------------------------
/allennlp_shiba/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shunk031/allennlp-shiba-model/0741bcfd5e118c6f2d0b83b49840d04e2278a26e/allennlp_shiba/__init__.py
--------------------------------------------------------------------------------
/allennlp_shiba/common/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shunk031/allennlp-shiba-model/0741bcfd5e118c6f2d0b83b49840d04e2278a26e/allennlp_shiba/common/__init__.py
--------------------------------------------------------------------------------
/allennlp_shiba/common/testing/__init__.py:
--------------------------------------------------------------------------------
1 | from allennlp_shiba.common.testing.test_case import AllennlpShibaTestCase  # NOQA
2 | 
--------------------------------------------------------------------------------
/allennlp_shiba/common/testing/test_case.py:
--------------------------------------------------------------------------------
1 | import pathlib
2 | 
3 | from allennlp.common.testing import AllenNlpTestCase
4 | 
5 | 
6 | class AllennlpShibaTestCase(AllenNlpTestCase):
7 |     PROJECT_ROOT = (pathlib.Path(__file__).parent / ".." / ".." / "..").resolve()
8 |     MODULE_ROOT = PROJECT_ROOT / "allennlp_shiba"
9 |     TOOLS_ROOT = MODULE_ROOT / "tools"
10 |     TESTS_ROOT = PROJECT_ROOT / "tests"
11 |     FIXTURES_ROOT = PROJECT_ROOT / "test_fixtures"
12 | 
--------------------------------------------------------------------------------
/allennlp_shiba/data/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shunk031/allennlp-shiba-model/0741bcfd5e118c6f2d0b83b49840d04e2278a26e/allennlp_shiba/data/__init__.py
--------------------------------------------------------------------------------
/allennlp_shiba/data/token_indexers/__init__.py:
--------------------------------------------------------------------------------
1 | from allennlp_shiba.data.token_indexers.pretrained_shiba_indexer import (  # NOQA
2 |     PretrainedShibaIndexer,
3 | )
4 | 
--------------------------------------------------------------------------------
/allennlp_shiba/data/token_indexers/pretrained_shiba_indexer.py:
--------------------------------------------------------------------------------
1 | from typing import Dict, List, Optional
2 | 
3 | from allennlp.data.token_indexers.token_indexer import IndexedTokenList, TokenIndexer
4 | from allennlp.data.tokenizers.token_class import Token
5 | from allennlp.data.vocabulary import Vocabulary
6 | from allennlp_shiba.data.tokenizers import ShibaCodepointTokenizer
7 | from overrides import overrides
8 | 
9 | 
10 | @TokenIndexer.register("shiba")
11 | class PretrainedShibaIndexer(TokenIndexer):
12 |     def __init__(
13 |         self,
14 |         namespace: str = "tags",
15 |         max_length: Optional[int] = None,
16 |         token_min_padding_length: int = 0,
17 |     ) -> None:
18 | 
19 |         super().__init__(token_min_padding_length=token_min_padding_length)
20 | 
21 |         self._namespace = namespace
22 |         self._allennlp_tokenizer = ShibaCodepointTokenizer()
23 |         self._tokenizer = self._allennlp_tokenizer._tokenizer
24 |         self._added_to_vocabulary = False
25 | 
26 |         self._max_length = max_length
27 |         if self._max_length is not None:
28 |             num_added_tokens = len(self._allennlp_tokenizer.tokenize("a")) - 1
29 |             self._effective_max_length = self._max_length - num_added_tokens
30 | 
31 |             if self._effective_max_length <= 0:
32 |                 raise ValueError(
33 |                     "max_length needs to be greater than the number of special tokens inserted."
34 |                 )
35 | 
36 |     def _add_encoding_to_vocabulary_if_needed(self, vocab: Vocabulary) -> None:
37 |         if self._added_to_vocabulary:
38 |             return
39 | 
40 |         vocab.add_transformer_vocab(self._tokenizer, self._namespace)
41 | 
42 |         self._added_to_vocabulary = True
43 | 
44 |     def _extract_token(self, tokens: List[Token]) -> List[int]:
45 |         indices: List[int] = []
46 |         for token in tokens:
47 |             if token.text_id is not None:
48 |                 indices.append(token.text_id)
49 |             else:
50 |                 # `CodepointTokenizer.encode` returns a dict of tensors, not an
51 |                 # int, so fall back to the Unicode codepoint of the character,
52 |                 # which is what the SHIBA tokenizer uses as ids for ordinary characters.
53 |                 indices.append(ord(token.text))
54 |         return indices
55 | 
56 |     def _postprocess_output(self, output: IndexedTokenList) -> IndexedTokenList:
57 |         if self._max_length is not None:
58 |             # TODO (shunk031): need to support max_length
59 |             raise NotImplementedError
60 | 
61 |         return output
62 | 
63 |     @overrides
64 |     def count_vocab_items(
65 |         self, token: Token, counter: Dict[str, Dict[str, int]]
66 |     ) -> None:
67 |         # If we only use pretrained models, we don't need to do anything here.
68 |         pass
69 | 
70 |     @overrides
71 |     def tokens_to_indices(
72 |         self, tokens: List[Token], vocabulary: Vocabulary
73 |     ) -> IndexedTokenList:
74 | 
75 |         indices = self._extract_token(tokens)
76 | 
77 |         output: IndexedTokenList = {
78 |             "token_ids": indices,
79 |             "mask": [True] * len(indices),
80 |         }
81 | 
82 |         return self._postprocess_output(output)
83 | 
84 |     @overrides
85 |     def indices_to_tokens(
86 |         self, indexed_tokens: IndexedTokenList, vocabulary: Vocabulary
87 |     ) -> List[Token]:
88 |         self._add_encoding_to_vocabulary_if_needed(vocabulary)
89 | 
90 |         token_ids = indexed_tokens["token_ids"]
91 |         type_ids = indexed_tokens.get("type_ids")
92 | 
93 |         return [
94 |             Token(
95 |                 text=vocabulary.get_token_from_index(token_ids[i], self._namespace),
96 |                 text_id=token_ids[i],
97 |                 type_id=type_ids[i] if type_ids is not None else None,
98 |             )
99 |             for i in range(len(token_ids))
100 |         ]
101 | 
102 |     @overrides
103 |     def get_empty_token_list(self) -> IndexedTokenList:
104 |         output: IndexedTokenList = {"token_ids": [], "mask": []}
105 |         return output
106 | 
--------------------------------------------------------------------------------
/allennlp_shiba/data/tokenizers/__init__.py:
--------------------------------------------------------------------------------
1 | from allennlp_shiba.data.tokenizers.codepoint_tokenizer import (  # NOQA
2 |     ShibaCodepointTokenizer,
3 | )
4 | 
--------------------------------------------------------------------------------
/allennlp_shiba/data/tokenizers/codepoint_tokenizer.py:
--------------------------------------------------------------------------------
1 | from typing import List
2 | 
3 | from allennlp.data.tokenizers.token_class import Token
4 | from allennlp.data.tokenizers.tokenizer import Tokenizer
5 | from overrides import overrides
6 | from shiba import CodepointTokenizer
7 | 
8 | 
9 | @Tokenizer.register("shiba")
10 | class ShibaCodepointTokenizer(Tokenizer):
11 |     def __init__(self) -> None:
12 |         super().__init__()
13 |         self._tokenizer = CodepointTokenizer()
14 | 
15 |     @overrides
16 |     def tokenize(self, text: str) -> List[Token]:
17 | 
18 |         encoded_tokens = self._tokenizer.encode(text)
19 |         token_ids = encoded_tokens["input_ids"]
20 | 
21 |         tokens = []
22 |         for tensor_token_id in token_ids:
23 |             token_id = tensor_token_id.item()
24 |             tokens.append(
25 |                 Token(text=self._tokenizer.decode([token_id]), text_id=token_id)
26 |             )
27 | 
28 |         return tokens
29 | 
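A quick sketch of what this tokenizer produces, mirroring the expectations in this repository's tests (each id is decoded back to a character, so the `Token` carries both the text and its codepoint-based id):

```python
from allennlp_shiba.data.tokenizers import ShibaCodepointTokenizer

tokenizer = ShibaCodepointTokenizer()
tokens = tokenizer.tokenize("柴ドリル")

print([t.text for t in tokens])     # ["[CLS]", "柴", "ド", "リ", "ル"]
print([t.text_id for t in tokens])  # the matching ids from CodepointTokenizer.encode
```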
--------------------------------------------------------------------------------
/allennlp_shiba/modules/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shunk031/allennlp-shiba-model/0741bcfd5e118c6f2d0b83b49840d04e2278a26e/allennlp_shiba/modules/__init__.py
--------------------------------------------------------------------------------
/allennlp_shiba/modules/token_embedders/__init__.py:
--------------------------------------------------------------------------------
1 | from allennlp_shiba.modules.token_embedders.pretrained_shiba_embedder import (  # NOQA
2 |     PretrainedShibaEmbedder,
3 | )
4 | 
--------------------------------------------------------------------------------
/allennlp_shiba/modules/token_embedders/pretrained_shiba_embedder.py:
--------------------------------------------------------------------------------
1 | from typing import Optional
2 | 
3 | import torch
4 | import torch.nn as nn
5 | from allennlp.modules.token_embedders.token_embedder import TokenEmbedder
6 | from overrides import overrides
7 | from shiba import Shiba
8 | from shiba.model import get_pretrained_state_dict
9 | 
10 | 
11 | @TokenEmbedder.register("shiba")
12 | class PretrainedShibaEmbedder(TokenEmbedder):
13 |     def __init__(
14 |         self,
15 |         downsampling_rate: int = 4,
16 |         upsampling_kernel_size: int = 4,
17 |         embedder_slice_count: int = 8,
18 |         embedder_bucket_count: int = 16000,
19 |         hidden_size: int = 768,
20 |         local_attention_window: int = 128,
21 |         deep_transformer_stack: Optional[nn.Module] = None,
22 |         deep_transformer_requires_transpose: bool = True,
23 |         attention_heads: int = 12,
24 |         transformer_ff_size: int = 3072,
25 |         dropout: float = 0.1,
26 |         activation: str = "gelu",
27 |         padding_id: int = 0,
28 |         max_length: int = 2048,
29 |         shiba_specific_code: bool = False,
30 |         deep_transformer_stack_layers: Optional[int] = None,
31 |         train_parameters: bool = True,
32 |         eval_mode: bool = False,
33 |     ) -> None:
34 |         super().__init__()
35 | 
36 |         shiba_model = Shiba(
37 |             downsampling_rate=downsampling_rate,
38 |             upsampling_kernel_size=upsampling_kernel_size,
39 |             embedder_slice_count=embedder_slice_count,
40 |             embedder_bucket_count=embedder_bucket_count,
41 |             hidden_size=hidden_size,
42 |             local_attention_window=local_attention_window,
43 |             deep_transformer_stack=deep_transformer_stack,
44 |             deep_transformer_requires_transpose=deep_transformer_requires_transpose,
45 |             attention_heads=attention_heads,
46 |             transformer_ff_size=transformer_ff_size,
47 |             dropout=dropout,
48 |             activation=activation,
49 |             padding_id=padding_id,
50 |             max_length=max_length,
51 |             shiba_specific_code=shiba_specific_code,
52 |             deep_transformer_stack_layers=deep_transformer_stack_layers,
53 |         )
54 |         shiba_model.load_state_dict(get_pretrained_state_dict())
55 | 
56 |         self.shiba_model = shiba_model
57 |         self.config = self.shiba_model.config
58 | 
59 |         # I'm not sure if this works for all models; open an issue on github if you find a case
60 |         # where it doesn't work.
61 |         self.output_dim = self.config.hidden_size
62 | 
63 |         self.train_parameters = train_parameters
64 |         if not train_parameters:
65 |             for param in self.shiba_model.parameters():
66 |                 param.requires_grad = False
67 | 
68 |         self.eval_mode = eval_mode
69 |         if eval_mode:
70 |             self.shiba_model.eval()
71 | 
72 |     @overrides
73 |     def get_output_dim(self) -> int:
74 |         return self.output_dim
75 | 
76 |     @overrides
77 |     def train(self, mode: bool = True):
78 |         self.training = mode
79 |         for name, module in self.named_children():
80 |             # Keep the pretrained SHIBA stack in eval mode (dropout disabled)
81 |             # even when the surrounding model switches to training mode.
82 |             if self.eval_mode and name == "shiba_model":
83 |                 module.eval()
84 |             else:
85 |                 module.train(mode)
86 |         return self
87 | 
88 |     @overrides
89 |     def forward(
90 |         self, token_ids: torch.LongTensor, mask: torch.BoolTensor
91 |     ) -> torch.Tensor:
92 |         # SHIBA marks *padding* positions with True in its attention mask,
93 |         # the inverse of AllenNLP's mask convention, hence the `~mask`.
94 |         output = self.shiba_model(input_ids=token_ids, attention_mask=~mask)
95 |         return output["embeddings"]
96 | 
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [tool.poetry]
2 | name = "allennlp-shiba"
3 | version = "0.0.0" # using poetry-dynamic-versioning
4 | license = "Apache-2.0"
5 | description = "AllenNLP integration for Shiba: Japanese CANINE model"
6 | readme = "README.md"
7 | homepage = "https://github.com/shunk031/allennlp-shiba-model"
8 | repository = "https://github.com/shunk031/allennlp-shiba-model"
9 | authors = ["Shunsuke KITADA "]
10 | keywords = [
11 |     "natural language processing",
12 |     "deep learning",
13 |     "transformers",
14 |     "allennlp",
15 | ]
16 | classifiers = [
17 |     "Programming Language :: Python :: 3",
18 |     "License :: OSI Approved :: Apache Software License",
19 |     "Operating System :: OS Independent",
20 | ]
21 | 
22 | [tool.poetry.dependencies]
23 | python = "^3.7"
24 | allennlp = "^2.5.0"
25 | shiba-model = "^0.1.0"
26 | 
27 | [tool.poetry.dev-dependencies]
28 | black = "^21.6b0"
29 | isort = "^5.9.1"
30 | flake8 = "^3.9.2"
31 | mypy = "^0.910"
32 | pytest = "^6.2.4"
33 | 
34 | [build-system]
35 | requires = ["poetry-core>=1.0.0"]
36 | build-backend = "poetry.core.masonry.api"
37 | 
38 | [tool.poetry-dynamic-versioning]
39 | enable = true
40 | 
--------------------------------------------------------------------------------
/test_fixtures/data/tokenizers/codepoint_tokenizer.jsonnet:
--------------------------------------------------------------------------------
1 | {
2 |   type: "shiba",
3 | }
4 | 
--------------------------------------------------------------------------------
/test_fixtures/modules/token_embedders/pretrained_shiba_embedder.jsonnet:
--------------------------------------------------------------------------------
1 | {
2 |   type: "shiba",
3 | }
4 | 
--------------------------------------------------------------------------------
/tests/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shunk031/allennlp-shiba-model/0741bcfd5e118c6f2d0b83b49840d04e2278a26e/tests/__init__.py
--------------------------------------------------------------------------------
/tests/data/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shunk031/allennlp-shiba-model/0741bcfd5e118c6f2d0b83b49840d04e2278a26e/tests/data/__init__.py
--------------------------------------------------------------------------------
/tests/data/token_indexers/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shunk031/allennlp-shiba-model/0741bcfd5e118c6f2d0b83b49840d04e2278a26e/tests/data/token_indexers/__init__.py
--------------------------------------------------------------------------------
/tests/data/token_indexers/pretrained_shiba_indexer_test.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | from allennlp.data.vocabulary import Vocabulary
3 | from allennlp_shiba.common.testing import AllennlpShibaTestCase
4 | from allennlp_shiba.data.token_indexers import PretrainedShibaIndexer
5 | from allennlp_shiba.data.tokenizers import ShibaCodepointTokenizer
6 | from shiba import CodepointTokenizer
7 | 
8 | 
9 | class TestPretrainedShibaIndexer(AllennlpShibaTestCase):
10 |     @pytest.mark.parametrize(
11 |         "input_str",
12 |         ("自然言語処理", "柴ドリル"),
13 |     )
14 |     def test_pretrained_shiba_indexer(self, input_str: str):
15 |         tokenizer = CodepointTokenizer()
16 |         allennlp_tokenizer = ShibaCodepointTokenizer()
17 | 
18 |         tokens = tokenizer.encode(input_str)
19 |         allennlp_tokens = allennlp_tokenizer.tokenize(input_str)
20 | 
21 |         vocab = Vocabulary()
22 |         indexer = PretrainedShibaIndexer()
23 |         indexed = indexer.tokens_to_indices(allennlp_tokens, vocab)
24 | 
25 |         output_token_ids = indexed["token_ids"]
26 |         expect_token_ids = tokens["input_ids"].cpu().numpy().tolist()
27 | 
28 |         output_mask = indexed["mask"]
29 |         expect_mask = (~tokens["attention_mask"].cpu().numpy()).tolist()
30 | 
31 |         assert output_token_ids == expect_token_ids
32 |         assert output_mask == expect_mask
33 | 
--------------------------------------------------------------------------------
/tests/data/tokenizers/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shunk031/allennlp-shiba-model/0741bcfd5e118c6f2d0b83b49840d04e2278a26e/tests/data/tokenizers/__init__.py
--------------------------------------------------------------------------------
/tests/data/tokenizers/code_point_tokenizer_test.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | from allennlp.common import Params
3 | from allennlp.data.tokenizers import Tokenizer
4 | from allennlp_shiba.common.testing import AllennlpShibaTestCase
5 | from allennlp_shiba.data.tokenizers import ShibaCodepointTokenizer
6 | from shiba import CodepointTokenizer
7 | 
8 | 
9 | class TestCodepointTokenizer(AllennlpShibaTestCase):
10 |     @pytest.mark.parametrize(
11 |         "input_str",
12 |         (
13 |             "自然言語処理",
14 |             "柴ドリル",
15 |         ),
16 |     )
17 |     def test_tokenize(self, input_str: str) -> None:
18 |         allennlp_tokenizer = ShibaCodepointTokenizer()
19 |         original_tokenizer = CodepointTokenizer()
20 | 
21 |         allennlp_output = allennlp_tokenizer.tokenize(input_str)
22 |         original_output = original_tokenizer.encode(input_str)
23 | 
24 |         allennlp_token_ids = list(map(lambda x: x.text_id, allennlp_output))
25 |         original_token_ids = original_output["input_ids"].cpu().numpy().tolist()
26 | 
27 |         assert allennlp_token_ids == original_token_ids
28 | 
29 |     def test_from_params(self) -> None:
30 |         params = Params.from_file(
31 |             self.FIXTURES_ROOT / "data" / "tokenizers" / "codepoint_tokenizer.jsonnet"
32 |         )
33 |         tokenizer = Tokenizer.from_params(params)
34 | 
35 |         assert isinstance(tokenizer, ShibaCodepointTokenizer)
36 | 
--------------------------------------------------------------------------------
/tests/modules/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shunk031/allennlp-shiba-model/0741bcfd5e118c6f2d0b83b49840d04e2278a26e/tests/modules/__init__.py
--------------------------------------------------------------------------------
/tests/modules/token_embedders/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shunk031/allennlp-shiba-model/0741bcfd5e118c6f2d0b83b49840d04e2278a26e/tests/modules/token_embedders/__init__.py
--------------------------------------------------------------------------------
/tests/modules/token_embedders/pretrained_shiba_embedder_test.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from allennlp.common import Params
3 | from allennlp.data.batch import Batch
4 | from allennlp.data.fields import TextField
5 | from allennlp.data.instance import Instance
6 | from allennlp.data.vocabulary import Vocabulary
7 | from allennlp.modules.text_field_embedders import BasicTextFieldEmbedder
8 | from allennlp.modules.token_embedders.token_embedder import TokenEmbedder
9 | from allennlp_shiba.common.testing import AllennlpShibaTestCase
10 | from allennlp_shiba.data.token_indexers import PretrainedShibaIndexer
11 | from allennlp_shiba.data.tokenizers import ShibaCodepointTokenizer
12 | from allennlp_shiba.modules.token_embedders import PretrainedShibaEmbedder
13 | from shiba import CodepointTokenizer, Shiba, get_pretrained_state_dict
14 | 
15 | 
16 | class TestPretrainedShibaEmbedder(AllennlpShibaTestCase):
17 |     def plain_shiba_output(self, sentence1: str, sentence2: str) -> torch.Tensor:
18 |         # Run the reference (non-AllenNLP) SHIBA model on the two sentences.
19 |         shiba_model = Shiba()
20 |         shiba_model.load_state_dict(get_pretrained_state_dict())
21 |         shiba_model.eval()  # disable dropout
22 |         tokenizer = CodepointTokenizer()
23 | 
24 |         inputs = tokenizer.encode_batch([sentence1, sentence2])
25 |         outputs = shiba_model(**inputs)
26 | 
27 |         return outputs["embeddings"]
28 | 
29 |     def test_pretrained_shiba_embedder(self):
30 | 
31 |         tokenizer = ShibaCodepointTokenizer()
32 |         token_indexer = PretrainedShibaIndexer()
33 | 
34 |         sentence1 = "自然言語処理"
35 |         tokens1 = tokenizer.tokenize(sentence1)
36 |         expected_tokens1 = ["[CLS]", "自", "然", "言", "語", "処", "理"]
37 |         assert [t.text for t in tokens1] == expected_tokens1
38 | 
39 |         sentence2 = "柴ドリル"
40 |         tokens2 = tokenizer.tokenize(sentence2)
41 |         expected_tokens2 = ["[CLS]", "柴", "ド", "リ", "ル"]
42 |         assert [t.text for t in tokens2] == expected_tokens2
43 | 
44 |         vocab = Vocabulary()
45 | 
46 |         params_dict = {
47 |             "token_embedders": {
48 |                 "shiba": {
49 |                     "type": "shiba",
50 |                     "eval_mode": True,
51 |                 }
52 |             }
53 |         }
54 |         params = Params(params_dict)
55 |         token_embedder = BasicTextFieldEmbedder.from_params(vocab=vocab, params=params)
56 | 
57 |         instance1 = Instance({"tokens": TextField(tokens1, {"shiba": token_indexer})})
58 |         instance2 = Instance({"tokens": TextField(tokens2, {"shiba": token_indexer})})
59 | 
60 |         batch = Batch([instance1, instance2])
61 |         batch.index_instances(vocab)
62 | 
63 |         padding_lengths = batch.get_padding_lengths()
64 |         tensor_dict = batch.as_tensor_dict(padding_lengths)
65 |         tokens = tensor_dict["tokens"]
66 | 
67 |         shiba_vectors = token_embedder(tokens)
68 |         plain_shiba_vectors = self.plain_shiba_output(sentence1, sentence2)
69 | 
70 |         assert torch.all(torch.eq(shiba_vectors, plain_shiba_vectors))
71 | 
72 |     def test_from_params(self) -> None:
73 |         params = Params.from_file(
74 |             self.FIXTURES_ROOT
75 |             / "modules"
76 |             / "token_embedders"
77 |             / "pretrained_shiba_embedder.jsonnet"
78 |         )
79 |         token_embedder = TokenEmbedder.from_params(params)
80 |         assert isinstance(token_embedder, PretrainedShibaEmbedder)
81 | 
--------------------------------------------------------------------------------