├── .github ├── dependabot.yml └── workflows │ ├── lint.yml │ └── test.yml ├── .gitignore ├── .pre-commit-config.yaml ├── LICENSE ├── README.md ├── configs ├── base.yaml ├── callbacks │ ├── early_stopping.yaml │ ├── lr_monitor.yaml │ ├── model_checkpoint.yaml │ ├── model_summary.yaml │ └── progress_bar.yaml ├── datamodule │ ├── base │ │ ├── test.yaml │ │ ├── train.yaml │ │ └── valid.yaml │ ├── jcola.yaml │ ├── jcqa.yaml │ ├── jnli.yaml │ ├── jsquad.yaml │ ├── jsts.yaml │ ├── marc_ja.yaml │ ├── test │ │ ├── jcola.yaml │ │ ├── jcqa.yaml │ │ ├── jnli.yaml │ │ ├── jsquad.yaml │ │ ├── jsts.yaml │ │ └── marc_ja.yaml │ ├── train │ │ ├── jcola.yaml │ │ ├── jcqa.yaml │ │ ├── jnli.yaml │ │ ├── jsquad.yaml │ │ ├── jsts.yaml │ │ └── marc_ja.yaml │ └── valid │ │ ├── jcola.yaml │ │ ├── jcola_ood.yaml │ │ ├── jcola_ood_annotated.yaml │ │ ├── jcqa.yaml │ │ ├── jnli.yaml │ │ ├── jsquad.yaml │ │ ├── jsts.yaml │ │ └── marc_ja.yaml ├── eval.yaml ├── jcola.debug.yaml ├── jcola.yaml ├── jcqa.debug.yaml ├── jcqa.yaml ├── jnli.debug.yaml ├── jnli.yaml ├── jsquad.debug.yaml ├── jsquad.yaml ├── jsts.debug.yaml ├── jsts.yaml ├── logger │ └── wandb.yaml ├── marc_ja.debug.yaml ├── marc_ja.yaml ├── model │ ├── deberta_v2_base.yaml │ ├── deberta_v2_large.yaml │ ├── deberta_v2_tiny.yaml │ ├── deberta_v3_base.yaml │ ├── luke_base.yaml │ ├── luke_large.yaml │ ├── microsoft__mdeberta_v3_base.yaml │ ├── modernbert_130m.yaml │ ├── modernbert_30m.yaml │ ├── modernbert_310m.yaml │ ├── modernbert_70m.yaml │ ├── roberta_base.yaml │ └── roberta_large.yaml ├── module │ ├── jcola.yaml │ ├── jcqa.yaml │ ├── jnli.yaml │ ├── jsquad.yaml │ ├── jsts.yaml │ └── marc_ja.yaml ├── optimizer │ └── adamw.yaml ├── scheduler │ ├── constant_schedule_with_warmup.yaml │ ├── cosine_schedule_with_warmup.yaml │ └── linear_schedule_with_warmup.yaml └── trainer │ ├── cpu.debug.yaml │ ├── cpu.yaml │ ├── debug.yaml │ └── default.yaml ├── pyproject.toml ├── scripts ├── gen_sweeps.sh ├── gen_table.py └── run_sweeps.sh ├── src ├── datamodule │ ├── __init__.py │ ├── datamodule.py │ └── datasets │ │ ├── __init__.py │ │ ├── base.py │ │ ├── jcola.py │ │ ├── jcqa.py │ │ ├── jnli.py │ │ ├── jsquad.py │ │ ├── jsts.py │ │ ├── marc_ja.py │ │ └── util.py ├── metrics │ ├── __init__.py │ └── jsquad.py ├── modules │ ├── __init__.py │ ├── base.py │ ├── jcola.py │ ├── jcqa.py │ ├── jnli.py │ ├── jsquad.py │ ├── jsts.py │ └── marc_ja.py ├── test.py └── train.py ├── sweeps ├── jcola.yaml ├── jcqa.yaml ├── jnli.yaml ├── jsquad.yaml ├── jsts.yaml └── marc_ja.yaml ├── tests ├── datasets │ ├── conftest.py │ └── test_jsquad.py └── metrics │ └── test_squad.py └── uv.lock /.github/dependabot.yml: -------------------------------------------------------------------------------- 1 | # To get started with Dependabot version updates, you'll need to specify which 2 | # package ecosystems to update and where the package manifests are located. 
3 | # Please see the documentation for all configuration options: 4 | # https://docs.github.com/github/administering-a-repository/configuration-options-for-dependency-updates 5 | 6 | version: 2 7 | updates: 8 | - package-ecosystem: "uv" 9 | directory: "/" 10 | schedule: 11 | interval: "monthly" 12 | timezone: "Asia/Tokyo" 13 | groups: 14 | dependencies: 15 | patterns: 16 | - "*" 17 | target-branch: "main" 18 | versioning-strategy: lockfile-only 19 | 20 | - package-ecosystem: "github-actions" 21 | # Workflow files stored in the 22 | # default location of `.github/workflows` 23 | directory: "/" 24 | schedule: 25 | interval: "monthly" 26 | timezone: "Asia/Tokyo" 27 | target-branch: "main" 28 | -------------------------------------------------------------------------------- /.github/workflows/lint.yml: -------------------------------------------------------------------------------- 1 | name: Lint 2 | 3 | on: [ push, pull_request ] 4 | 5 | jobs: 6 | lint: 7 | runs-on: ubuntu-latest 8 | steps: 9 | - name: Checkout repository 10 | uses: actions/checkout@v4 11 | - name: Install pre-commit and run linters 12 | run: | 13 | pipx install pre-commit 14 | pre-commit run --all-files 15 | -------------------------------------------------------------------------------- /.github/workflows/test.yml: -------------------------------------------------------------------------------- 1 | name: Test 2 | 3 | on: [ push, pull_request ] 4 | 5 | jobs: 6 | test: 7 | name: Run tests with pytest 8 | container: kunlp/jumanpp:ubuntu24.04 9 | runs-on: ubuntu-24.04 10 | strategy: 11 | max-parallel: 5 12 | fail-fast: false 13 | matrix: 14 | python-version: [ "3.9", "3.10", "3.11", "3.12", "3.13" ] 15 | steps: 16 | - name: Checkout repository 17 | uses: actions/checkout@v4 18 | - name: Install required apt packages 19 | run: | 20 | export DEBIAN_FRONTEND=noninteractive 21 | apt-get update -yq 22 | apt-get install -yq curl build-essential libsqlite3-dev libffi-dev 23 | - name: Install uv 24 | uses: astral-sh/setup-uv@v6 25 | with: 26 | python-version: ${{ matrix.python-version }} 27 | - name: Install dependencies 28 | run: uv sync --no-cache 29 | - name: Run tests 30 | run: uv run pytest -v ./tests 31 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | 2 | # Created by https://www.toptal.com/developers/gitignore/api/python 3 | # Edit at https://www.toptal.com/developers/gitignore?templates=python 4 | 5 | ### Python ### 6 | # Byte-compiled / optimized / DLL files 7 | __pycache__/ 8 | *.py[cod] 9 | *$py.class 10 | 11 | # C extensions 12 | *.so 13 | 14 | # Distribution / packaging 15 | .Python 16 | build/ 17 | develop-eggs/ 18 | dist/ 19 | downloads/ 20 | eggs/ 21 | .eggs/ 22 | lib/ 23 | lib64/ 24 | parts/ 25 | sdist/ 26 | var/ 27 | wheels/ 28 | share/python-wheels/ 29 | *.egg-info/ 30 | .installed.cfg 31 | *.egg 32 | MANIFEST 33 | 34 | # PyInstaller 35 | # Usually these files are written by a python script from a template 36 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
37 | *.manifest 38 | *.spec 39 | 40 | # Installer logs 41 | pip-log.txt 42 | pip-delete-this-directory.txt 43 | 44 | # Unit test / coverage reports 45 | htmlcov/ 46 | .tox/ 47 | .nox/ 48 | .coverage 49 | .coverage.* 50 | .cache 51 | nosetests.xml 52 | coverage.xml 53 | *.cover 54 | *.py,cover 55 | .hypothesis/ 56 | .pytest_cache/ 57 | cover/ 58 | 59 | # Translations 60 | *.mo 61 | *.pot 62 | 63 | # Django stuff: 64 | *.log 65 | local_settings.py 66 | db.sqlite3 67 | db.sqlite3-journal 68 | 69 | # Flask stuff: 70 | instance/ 71 | .webassets-cache 72 | 73 | # Scrapy stuff: 74 | .scrapy 75 | 76 | # Sphinx documentation 77 | docs/_build/ 78 | 79 | # PyBuilder 80 | .pybuilder/ 81 | target/ 82 | 83 | # Jupyter Notebook 84 | .ipynb_checkpoints 85 | 86 | # IPython 87 | profile_default/ 88 | ipython_config.py 89 | 90 | # pyenv 91 | # For a library or package, you might want to ignore these files since the code is 92 | # intended to run in multiple environments; otherwise, check them in: 93 | # .python-version 94 | 95 | # pipenv 96 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 97 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 98 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 99 | # install all needed dependencies. 100 | #Pipfile.lock 101 | 102 | # poetry 103 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 104 | # This is especially recommended for binary packages to ensure reproducibility, and is more 105 | # commonly ignored for libraries. 106 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 107 | #poetry.lock 108 | 109 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 110 | __pypackages__/ 111 | 112 | # Celery stuff 113 | celerybeat-schedule 114 | celerybeat.pid 115 | 116 | # SageMath parsed files 117 | *.sage.py 118 | 119 | # Environments 120 | .env 121 | .venv 122 | env/ 123 | venv/ 124 | ENV/ 125 | env.bak/ 126 | venv.bak/ 127 | 128 | # Spyder project settings 129 | .spyderproject 130 | .spyproject 131 | 132 | # Rope project settings 133 | .ropeproject 134 | 135 | # mkdocs documentation 136 | /site 137 | 138 | # mypy 139 | .mypy_cache/ 140 | .dmypy.json 141 | dmypy.json 142 | 143 | # Pyre type checker 144 | .pyre/ 145 | 146 | # pytype static type analyzer 147 | .pytype/ 148 | 149 | # Cython debug symbols 150 | cython_debug/ 151 | 152 | # PyCharm 153 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 154 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 155 | # and can be added to the global gitignore or merged into this file. For a more nuclear 156 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
157 | #.idea/ 158 | 159 | # End of https://www.toptal.com/developers/gitignore/api/python 160 | 161 | data 162 | !tests/data 163 | result 164 | outputs 165 | wandb 166 | log 167 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | ci: 2 | autofix_prs: false 3 | autoupdate_schedule: monthly 4 | skip: [mypy] 5 | 6 | repos: 7 | - repo: https://github.com/pre-commit/pre-commit-hooks 8 | rev: v5.0.0 9 | hooks: 10 | - id: end-of-file-fixer 11 | - id: trailing-whitespace 12 | - id: check-yaml 13 | - repo: https://github.com/astral-sh/ruff-pre-commit 14 | rev: v0.11.12 15 | hooks: 16 | - id: ruff 17 | args: [ --fix, --exit-non-zero-on-fix ] 18 | - id: ruff-format 19 | - repo: https://github.com/pre-commit/mirrors-mypy 20 | rev: v1.16.0 21 | hooks: 22 | - id: mypy 23 | additional_dependencies: 24 | - hydra-core 25 | - torch 26 | - torchmetrics 27 | - transformers==4.50.3 28 | - datasets 29 | - tokenizers==0.21.1 30 | - wandb 31 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | 2 | Apache License 3 | Version 2.0, January 2004 4 | http://www.apache.org/licenses/ 5 | 6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 7 | 8 | 1. Definitions. 9 | 10 | "License" shall mean the terms and conditions for use, reproduction, 11 | and distribution as defined by Sections 1 through 9 of this document. 12 | 13 | "Licensor" shall mean the copyright owner or entity authorized by 14 | the copyright owner that is granting the License. 15 | 16 | "Legal Entity" shall mean the union of the acting entity and all 17 | other entities that control, are controlled by, or are under common 18 | control with that entity. For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 
48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. 
You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | 177 | END OF TERMS AND CONDITIONS 178 | 179 | APPENDIX: How to apply the Apache License to your work. 180 | 181 | To apply the Apache License to your work, attach the following 182 | boilerplate notice, with the fields enclosed by brackets "[]" 183 | replaced with your own identifying information. (Don't include 184 | the brackets!) The text should be enclosed in the appropriate 185 | comment syntax for the file format. We also recommend that a 186 | file or class name and description of purpose be included on the 187 | same "printed page" as the copyright notice for easier 188 | identification within third-party archives. 189 | 190 | Copyright [yyyy] [name of copyright owner] 191 | 192 | Licensed under the Apache License, Version 2.0 (the "License"); 193 | you may not use this file except in compliance with the License. 194 | You may obtain a copy of the License at 195 | 196 | http://www.apache.org/licenses/LICENSE-2.0 197 | 198 | Unless required by applicable law or agreed to in writing, software 199 | distributed under the License is distributed on an "AS IS" BASIS, 200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 201 | See the License for the specific language governing permissions and 202 | limitations under the License. 
203 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # JGLUE Evaluation Scripts 2 | 3 | [![test](https://github.com/nobu-g/JGLUE-evaluation-scripts/actions/workflows/test.yml/badge.svg)](https://github.com/nobu-g/JGLUE-evaluation-scripts/actions/workflows/test.yml) 4 | [![lint](https://github.com/nobu-g/JGLUE-evaluation-scripts/actions/workflows/lint.yml/badge.svg)](https://github.com/nobu-g/JGLUE-evaluation-scripts/actions/workflows/lint.yml) 5 | [![pre-commit.ci status](https://results.pre-commit.ci/badge/github/nobu-g/JGLUE-evaluation-scripts/main.svg)](https://results.pre-commit.ci/latest/github/nobu-g/JGLUE-evaluation-scripts/main) 6 | [![Ruff](https://img.shields.io/endpoint?url=https://raw.githubusercontent.com/astral-sh/ruff/main/assets/badge/v2.json)](https://github.com/astral-sh/ruff) 7 | [![uv](https://img.shields.io/endpoint?url=https://raw.githubusercontent.com/astral-sh/uv/main/assets/badge/v0.json)](https://github.com/astral-sh/uv) 8 | [![CodeFactor Grade](https://img.shields.io/codefactor/grade/github/nobu-g/JGLUE-evaluation-scripts)](https://www.codefactor.io/repository/github/nobu-g/JGLUE-evaluation-scripts) 9 | [![license](https://img.shields.io/github/license/nobu-g/JGLUE-evaluation-scripts?color=blue)](https://github.com/nobu-g/JGLUE-evaluation-scripts/blob/main/LICENSE) 10 | 11 | ## Requirements 12 | 13 | - Python: 3.9+ 14 | - Dependencies: See [pyproject.toml](./pyproject.toml). 15 | 16 | ## Getting started 17 | 18 | - Create a virtual environment and install dependencies. 19 | ```shell 20 | $ uv venv -p /path/to/python 21 | $ uv sync 22 | ``` 23 | 24 | - Log in to [wandb](https://wandb.ai/site). 25 | ```shell 26 | $ wandb login 27 | ``` 28 | 29 | ## Training and evaluation 30 | 31 | You can train and test a model with the following command: 32 | 33 | ```shell 34 | # For training and evaluating MARC-ja 35 | uv run python src/train.py -cn marc_ja devices=[0,1] max_batches_per_device=16 36 | ``` 37 | 38 | Here are commonly used options: 39 | 40 | - `-cn`: Task name. Choose from `marc_ja`, `jcola`, `jsts`, `jnli`, `jsquad`, and `jcqa`. 41 | - `devices`: GPUs to use. 42 | - `max_batches_per_device`: Maximum number of batches to process per device (default: `4`). 43 | - `compile`: JIT-compile the model 44 | with [torch.compile](https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html) for faster training 45 | (default: `false`). 46 | - `model`: Pre-trained model name. See the YAML config files under [configs/model](./configs/model). 47 | 48 | To evaluate on the out-of-domain split of the JCoLA dataset, specify `datamodule/valid=jcola_ood` 49 | (or `datamodule/valid=jcola_ood_annotated`). 50 | For more options, see the YAML config files under [configs](./configs). 51 | 52 | 60 |
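As an illustration, these options can be combined freely in a single command; the task and model below are arbitrary examples, and any config names under [configs](./configs) and [configs/model](./configs/model) work the same way: ```shell uv run python src/train.py -cn jsts model=roberta_base devices=[0,1] max_batches_per_device=16 compile=true ```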
61 | ## Debugging 62 | 63 | ```shell 64 | uv run python src/train.py -cn marc_ja.debug 65 | ``` 66 | 67 | You can specify `trainer=cpu.debug` to use CPU. 68 | 69 | ```shell 70 | uv run python src/train.py -cn marc_ja.debug trainer=cpu.debug 71 | ``` 72 | 73 | If you are on a machine with GPUs, you can specify the GPUs to use with the `devices` option. 74 | 75 | ```shell 76 | uv run python src/train.py -cn marc_ja.debug devices=[0] 77 | ``` 78 | 79 | ## Tuning hyper-parameters 80 | 81 | ```shell 82 | $ wandb sweep <(sed 's/MODEL_NAME/deberta_v2_base/' sweeps/jcola.yaml) 83 | wandb: Creating sweep from: /dev/fd/xx 84 | wandb: Created sweep with ID: xxxxxxxx 85 | wandb: View sweep at: https://wandb.ai/<entity>/JGLUE-evaluation-scripts/sweeps/xxxxxxxx 86 | wandb: Run sweep agent with: wandb agent <entity>/JGLUE-evaluation-scripts/xxxxxxxx 87 | $ DEVICES=0,1 MAX_BATCHES_PER_DEVICE=16 COMPILE=true wandb agent <entity>/JGLUE-evaluation-scripts/xxxxxxxx 88 | ``` 89 | 90 | ## Results 91 | 92 | We fine-tuned the following models and evaluated them on the dev set of JGLUE. 93 | We tuned learning rate and training epochs for each model and task 94 | following [the JGLUE paper](https://www.jstage.jst.go.jp/article/jnlp/30/1/30_63/_pdf/-char/ja). 95 | 96 | | Model | MARC-ja/acc | JCoLA/acc | JSTS/pearson | JSTS/spearman | JNLI/acc | JSQuAD/EM | JSQuAD/F1 | JComQA/acc | 97 | |-------------------------------|-------------|-----------|--------------|---------------|----------|-----------|-----------|------------| 98 | | Waseda RoBERTa base | 0.965 | 0.867 | 0.913 | 0.876 | 0.905 | 0.853 | 0.916 | 0.853 | 99 | | Waseda RoBERTa large (seq512) | 0.969 | 0.849 | 0.925 | 0.890 | 0.928 | 0.910 | 0.955 | 0.900 | 100 | | LUKE Japanese base* | 0.965 | - | 0.916 | 0.877 | 0.912 | - | - | 0.842 | 101 | | LUKE Japanese large* | 0.965 | - | 0.932 | 0.902 | 0.927 | - | - | 0.893 | 102 | | DeBERTaV2 base | 0.970 | 0.879 | 0.922 | 0.886 | 0.922 | 0.899 | 0.951 | 0.873 | 103 | | DeBERTaV2 large | 0.968 | 0.882 | 0.925 | 0.892 | 0.924 | 0.912 | 0.959 | 0.890 | 104 | | DeBERTaV3 base | 0.960 | 0.878 | 0.927 | 0.891 | 0.927 | 0.896 | 0.947 | 0.875 | 105 | 106 | *The scores of LUKE are from [the official repository](https://github.com/studio-ousia/luke). 
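For reference, a run corresponding to a single row of the tables below can be launched by overriding the tuned hyper-parameters directly. This is an illustrative sketch using the DeBERTaV2 base values for JSTS; exact scores may still vary with the effective batch size and the random seed: ```shell uv run python src/train.py -cn jsts model=deberta_v2_base lr=5e-05 max_epochs=3 ```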
107 | 108 | ## Tuned hyper-parameters 109 | 110 | - Learning rate: {2e-05, 3e-05, 5e-05} 111 | 112 | | Model | MARC-ja/acc | JCoLA/acc | JSTS/pearson | JSTS/spearman | JNLI/acc | JSQuAD/EM | JSQuAD/F1 | JComQA/acc | 113 | |-------------------------------|-------------|-----------|--------------|---------------|----------|-----------|-----------|------------| 114 | | Waseda RoBERTa base | 3e-05 | 3e-05 | 2e-05 | 2e-05 | 3e-05 | 3e-05 | 3e-05 | 5e-05 | 115 | | Waseda RoBERTa large (seq512) | 2e-05 | 2e-05 | 3e-05 | 3e-05 | 2e-05 | 2e-05 | 2e-05 | 3e-05 | 116 | | DeBERTaV2 base | 2e-05 | 3e-05 | 5e-05 | 5e-05 | 3e-05 | 2e-05 | 2e-05 | 5e-05 | 117 | | DeBERTaV2 large | 5e-05 | 2e-05 | 5e-05 | 5e-05 | 2e-05 | 2e-05 | 2e-05 | 3e-05 | 118 | | DeBERTaV3 base | 5e-05 | 2e-05 | 3e-05 | 3e-05 | 2e-05 | 5e-05 | 5e-05 | 2e-05 | 119 | 120 | - Training epochs: {3, 4} 121 | 122 | | Model | MARC-ja/acc | JCoLA/acc | JSTS/pearson | JSTS/spearman | JNLI/acc | JSQuAD/EM | JSQuAD/F1 | JComQA/acc | 123 | |-------------------------------|-------------|-----------|--------------|---------------|----------|-----------|-----------|------------| 124 | | Waseda RoBERTa base | 4 | 3 | 4 | 4 | 3 | 4 | 4 | 3 | 125 | | Waseda RoBERTa large (seq512) | 4 | 4 | 4 | 4 | 3 | 3 | 3 | 3 | 126 | | DeBERTaV2 base | 3 | 4 | 3 | 3 | 3 | 4 | 4 | 4 | 127 | | DeBERTaV2 large | 3 | 3 | 4 | 4 | 3 | 4 | 4 | 3 | 128 | | DeBERTaV3 base | 4 | 4 | 4 | 4 | 4 | 4 | 4 | 4 | 129 | 130 | ## Huggingface hub links 131 | 132 | - Waseda RoBERTa base: [nlp-waseda/roberta-base-japanese](https://huggingface.co/nlp-waseda/roberta-base-japanese) 133 | - Waseda RoBERTa large (seq512): [nlp-waseda/roberta-large-japanese-seq512](https://huggingface.co/nlp-waseda/roberta-large-japanese-seq512) 134 | - LUKE Japanese base: [studio-ousia/luke-japanese-base-lite](https://huggingface.co/studio-ousia/luke-japanese-base-lite) 135 | - LUKE Japanese large: [studio-ousia/luke-japanese-large-lite](https://huggingface.co/studio-ousia/luke-japanese-large-lite) 136 | - DeBERTaV2 base: [ku-nlp/deberta-v2-base-japanese](https://huggingface.co/ku-nlp/deberta-v2-base-japanese) 137 | - DeBERTaV2 large: [ku-nlp/deberta-v2-large-japanese](https://huggingface.co/ku-nlp/deberta-v2-large-japanese) 138 | - DeBERTaV3 base: [ku-nlp/deberta-v3-base-japanese](https://huggingface.co/ku-nlp/deberta-v3-base-japanese) 139 | 140 | ## Author 141 | 142 | Nobuhiro Ueda (ueda **at** nlp.ist.i.kyoto-u.ac.jp) 143 | 144 | ## Reference 145 | 146 | - [yahoojapan/JGLUE: JGLUE: Japanese General Language Understanding Evaluation](https://github.com/yahoojapan/JGLUE) 147 | - [JGLUE: Japanese General Language Understanding Evaluation](https://aclanthology.org/2022.lrec-1.317) (Kurihara et 148 | al., LREC 2022) 149 | - Kentaro Kurihara, Daisuke Kawahara, and Tomohide Shibata, JGLUE: Japanese Language Understanding Benchmark, Journal of Natural Language Processing, 2023, Vol. 30, No. 1, pp. 
63-87, released 150 | 2023/03/15, Online ISSN 2185-8314, Print ISSN 151 | 1340-7619, https://doi.org/10.5715/jnlp.30.63, https://www.jstage.jst.go.jp/article/jnlp/30/1/30_63/_article/-char/ja 152 | -------------------------------------------------------------------------------- /configs/base.yaml: -------------------------------------------------------------------------------- 1 | # specify here default training configuration 2 | defaults: 3 | - _self_ 4 | 5 | project: JGLUE-evaluation-scripts 6 | 7 | # path to original working directory 8 | # hydra hijacks working directory by changing it to the current log directory, 9 | # so it's useful to have this path as a special variable 10 | # https://hydra.cc/docs/next/tutorials/basic/running_your_app/working_directory 11 | work_dir: ${hydra:runtime.cwd} 12 | 13 | # seed for random number generators in pytorch, numpy and python.random 14 | # "null" means the seed is randomly selected at runtime. 15 | seed: null 16 | 17 | # name of the run is accessed by loggers 18 | # should be used along with experiment mode 19 | name: ${hydra:job.config_name}-${hydra:job.override_dirname} 20 | 21 | exp_dir: ${work_dir}/result/${name} 22 | run_id: ${now:%m%d}_${now:%H%M%S} 23 | run_dir: ${exp_dir}/${run_id} 24 | config_name: ${hydra:job.config_name} 25 | 26 | # environment dependent settings 27 | devices: ${oc.env:DEVICES,1} 28 | max_batches_per_device: ${oc.env:MAX_BATCHES_PER_DEVICE,4} 29 | num_workers: ${oc.env:NUM_WORKERS,4} 30 | compile: ${oc.env:COMPILE,false} # compile model for faster training with pytorch 2.0 31 | 32 | hydra: 33 | run: 34 | dir: ${exp_dir} 35 | sweep: 36 | dir: ${work_dir}/multirun_result 37 | subdir: ${name}-${hydra:job.num} 38 | job: 39 | config: 40 | override_dirname: 41 | kv_sep: '_' 42 | item_sep: '-' 43 | exclude_keys: 44 | - seed 45 | - name 46 | - exp_dir 47 | - run_dir 48 | - devices 49 | - num_workers 50 | - checkpoint_path 51 | - logger 52 | - max_batches_per_device 53 | - compile 54 | -------------------------------------------------------------------------------- /configs/callbacks/early_stopping.yaml: -------------------------------------------------------------------------------- 1 | early_stopping: 2 | _target_: lightning.pytorch.callbacks.EarlyStopping 3 | monitor: ${monitor} 4 | patience: 3 5 | mode: ${mode} 6 | verbose: true 7 | -------------------------------------------------------------------------------- /configs/callbacks/lr_monitor.yaml: -------------------------------------------------------------------------------- 1 | lr_monitor: 2 | _target_: lightning.pytorch.callbacks.LearningRateMonitor 3 | logging_interval: null # "epoch", "step", or "null" 4 | log_momentum: false 5 | -------------------------------------------------------------------------------- /configs/callbacks/model_checkpoint.yaml: -------------------------------------------------------------------------------- 1 | model_checkpoint: 2 | _target_: lightning.pytorch.callbacks.ModelCheckpoint 3 | dirpath: ${run_dir} 4 | filename: "{epoch}-{step}" 5 | auto_insert_metric_name: false 6 | monitor: ${monitor} 7 | mode: ${mode} 8 | save_top_k: 1 9 | every_n_epochs: 1 10 | save_weights_only: true 11 | -------------------------------------------------------------------------------- /configs/callbacks/model_summary.yaml: -------------------------------------------------------------------------------- 1 | model_summary: 2 | _target_: lightning.pytorch.callbacks.RichModelSummary 3 | max_depth: 3 4 | 
-------------------------------------------------------------------------------- /configs/callbacks/progress_bar.yaml: -------------------------------------------------------------------------------- 1 | progress_bar: 2 | _target_: lightning.pytorch.callbacks.RichProgressBar 3 | console_kwargs: 4 | stderr: True 5 | -------------------------------------------------------------------------------- /configs/datamodule/base/test.yaml: -------------------------------------------------------------------------------- 1 | max_seq_length: ${max_seq_length} 2 | split: test 3 | tokenizer: ${model.tokenizer} 4 | segmenter_kwargs: ${model.segmenter_kwargs} 5 | limit_examples: ${limit_examples} 6 | -------------------------------------------------------------------------------- /configs/datamodule/base/train.yaml: -------------------------------------------------------------------------------- 1 | max_seq_length: ${max_seq_length} 2 | split: train 3 | tokenizer: ${model.tokenizer} 4 | segmenter_kwargs: ${model.segmenter_kwargs} 5 | limit_examples: ${limit_examples} 6 | -------------------------------------------------------------------------------- /configs/datamodule/base/valid.yaml: -------------------------------------------------------------------------------- 1 | max_seq_length: ${max_seq_length} 2 | split: validation 3 | tokenizer: ${model.tokenizer} 4 | segmenter_kwargs: ${model.segmenter_kwargs} 5 | limit_examples: ${limit_examples} 6 | -------------------------------------------------------------------------------- /configs/datamodule/jcola.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - train: jcola.yaml 3 | - valid: jcola.yaml # or jcola_ood.yaml, jcola_ood_annotated.yaml 4 | - test: jcola.yaml 5 | 6 | batch_size: ${max_batches_per_device} 7 | num_workers: ${num_workers} 8 | -------------------------------------------------------------------------------- /configs/datamodule/jcqa.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - train: jcqa.yaml 3 | - valid: jcqa.yaml 4 | - test: jcqa.yaml 5 | 6 | batch_size: ${max_batches_per_device} 7 | num_workers: ${num_workers} 8 | -------------------------------------------------------------------------------- /configs/datamodule/jnli.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - train: jnli.yaml 3 | - valid: jnli.yaml 4 | - test: jnli.yaml 5 | 6 | batch_size: ${max_batches_per_device} 7 | num_workers: ${num_workers} 8 | -------------------------------------------------------------------------------- /configs/datamodule/jsquad.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - train: jsquad.yaml 3 | - valid: jsquad.yaml 4 | - test: jsquad.yaml 5 | 6 | batch_size: ${max_batches_per_device} 7 | num_workers: ${num_workers} 8 | -------------------------------------------------------------------------------- /configs/datamodule/jsts.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - train: jsts.yaml 3 | - valid: jsts.yaml 4 | - test: jsts.yaml 5 | 6 | batch_size: ${max_batches_per_device} 7 | num_workers: ${num_workers} 8 | -------------------------------------------------------------------------------- /configs/datamodule/marc_ja.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - train: marc_ja.yaml 3 | - valid: 
marc_ja.yaml 4 | - test: marc_ja.yaml 5 | 6 | batch_size: ${max_batches_per_device} 7 | num_workers: ${num_workers} 8 | -------------------------------------------------------------------------------- /configs/datamodule/test/jcola.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - /datamodule/base/test@ 3 | _target_: datamodule.datasets.JCoLADataset 4 | -------------------------------------------------------------------------------- /configs/datamodule/test/jcqa.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - /datamodule/base/test@ 3 | _target_: datamodule.datasets.JCommonsenseQADataset 4 | -------------------------------------------------------------------------------- /configs/datamodule/test/jnli.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - /datamodule/base/test@ 3 | _target_: datamodule.datasets.JNLIDataset 4 | -------------------------------------------------------------------------------- /configs/datamodule/test/jsquad.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - /datamodule/base/test@ 3 | _target_: datamodule.datasets.JSQuADDataset 4 | -------------------------------------------------------------------------------- /configs/datamodule/test/jsts.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - /datamodule/base/test@ 3 | _target_: datamodule.datasets.JSTSDataset 4 | -------------------------------------------------------------------------------- /configs/datamodule/test/marc_ja.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - /datamodule/base/test@ 3 | _target_: datamodule.datasets.MARCJaDataset 4 | -------------------------------------------------------------------------------- /configs/datamodule/train/jcola.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - /datamodule/base/train@ 3 | _target_: datamodule.datasets.JCoLADataset 4 | -------------------------------------------------------------------------------- /configs/datamodule/train/jcqa.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - /datamodule/base/train@ 3 | _target_: datamodule.datasets.JCommonsenseQADataset 4 | -------------------------------------------------------------------------------- /configs/datamodule/train/jnli.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - /datamodule/base/train@ 3 | _target_: datamodule.datasets.JNLIDataset 4 | -------------------------------------------------------------------------------- /configs/datamodule/train/jsquad.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - /datamodule/base/train@ 3 | _target_: datamodule.datasets.JSQuADDataset 4 | -------------------------------------------------------------------------------- /configs/datamodule/train/jsts.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - /datamodule/base/train@ 3 | _target_: datamodule.datasets.JSTSDataset 4 | -------------------------------------------------------------------------------- /configs/datamodule/train/marc_ja.yaml: 
-------------------------------------------------------------------------------- 1 | defaults: 2 | - /datamodule/base/train@ 3 | _target_: datamodule.datasets.MARCJaDataset 4 | -------------------------------------------------------------------------------- /configs/datamodule/valid/jcola.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - /datamodule/base/valid@ 3 | _target_: datamodule.datasets.JCoLADataset 4 | -------------------------------------------------------------------------------- /configs/datamodule/valid/jcola_ood.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - /datamodule/base/valid@ 3 | _target_: datamodule.datasets.JCoLADataset 4 | split: validation_out_of_domain 5 | -------------------------------------------------------------------------------- /configs/datamodule/valid/jcola_ood_annotated.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - /datamodule/base/valid@ 3 | _target_: datamodule.datasets.JCoLADataset 4 | split: validation_out_of_domain_annotated 5 | -------------------------------------------------------------------------------- /configs/datamodule/valid/jcqa.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - /datamodule/base/valid@ 3 | _target_: datamodule.datasets.JCommonsenseQADataset 4 | -------------------------------------------------------------------------------- /configs/datamodule/valid/jnli.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - /datamodule/base/valid@ 3 | _target_: datamodule.datasets.JNLIDataset 4 | -------------------------------------------------------------------------------- /configs/datamodule/valid/jsquad.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - /datamodule/base/valid@ 3 | _target_: datamodule.datasets.JSQuADDataset 4 | -------------------------------------------------------------------------------- /configs/datamodule/valid/jsts.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - /datamodule/base/valid@ 3 | _target_: datamodule.datasets.JSTSDataset 4 | -------------------------------------------------------------------------------- /configs/datamodule/valid/marc_ja.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - /datamodule/base/valid@ 3 | _target_: datamodule.datasets.MARCJaDataset 4 | -------------------------------------------------------------------------------- /configs/eval.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - callbacks: [progress_bar.yaml] 3 | - logger: null 4 | - module: null 5 | - _self_ 6 | 7 | # required settings 8 | checkpoint_path: null # path to trained checkpoint 9 | eval_set: test # test or valid 10 | 11 | # environment dependent settings 12 | devices: ${oc.env:DEVICES,1} 13 | max_batches_per_device: ${oc.env:MAX_BATCHES_PER_DEVICE,4} 14 | num_workers: ${oc.env:NUM_WORKERS,4} 15 | compile: ${oc.env:COMPILE,false} # compile model for faster training with pytorch 2.0 16 | -------------------------------------------------------------------------------- /configs/jcola.debug.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - base 3 | - 
callbacks: [early_stopping, model_checkpoint, model_summary, progress_bar] 4 | - datamodule: jcola 5 | - logger: null 6 | - model: deberta_v2_tiny 7 | - module: jcola 8 | - optimizer: adamw 9 | - scheduler: constant_schedule_with_warmup 10 | - trainer: debug 11 | - _self_ 12 | 13 | max_seq_length: 128 14 | checkpoint_path: "" 15 | limit_examples: 100 16 | 17 | monitor: valid/accuracy 18 | mode: max 19 | 20 | # hyper-parameters to be tuned 21 | lr: 1e-4 22 | max_epochs: 2 23 | warmup_steps: null 24 | warmup_ratio: 0.1 25 | effective_batch_size: 4 26 | 27 | # environment dependent settings 28 | num_workers: 0 29 | -------------------------------------------------------------------------------- /configs/jcola.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - base 3 | - callbacks: [early_stopping, model_checkpoint, model_summary, progress_bar, lr_monitor] 4 | - datamodule: jcola 5 | - logger: wandb 6 | - model: deberta_v2_large 7 | - module: jcola 8 | - optimizer: adamw 9 | - scheduler: cosine_schedule_with_warmup 10 | - trainer: default 11 | - _self_ 12 | 13 | max_seq_length: 512 14 | checkpoint_path: "" 15 | limit_examples: -1 16 | 17 | # set monitor and mode for early_stopping and model_checkpoint 18 | monitor: valid/accuracy 19 | mode: max 20 | 21 | # hyper-parameters to be tuned 22 | lr: 4e-5 23 | max_epochs: 4 24 | warmup_steps: null 25 | warmup_ratio: 0.1 26 | effective_batch_size: 256 27 | -------------------------------------------------------------------------------- /configs/jcqa.debug.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - base 3 | - callbacks: [early_stopping, model_checkpoint, model_summary, progress_bar] 4 | - datamodule: jcqa 5 | - logger: null 6 | - model: deberta_v2_tiny 7 | - module: jcqa 8 | - optimizer: adamw 9 | - scheduler: constant_schedule_with_warmup 10 | - trainer: debug 11 | - _self_ 12 | 13 | max_seq_length: 64 14 | checkpoint_path: "" 15 | limit_examples: 100 16 | 17 | monitor: valid/accuracy 18 | mode: max 19 | 20 | # hyper-parameters to be tuned 21 | lr: 1e-4 22 | max_epochs: 2 23 | warmup_steps: null 24 | warmup_ratio: 0.1 25 | effective_batch_size: 4 26 | 27 | # environment dependent settings 28 | num_workers: 0 29 | -------------------------------------------------------------------------------- /configs/jcqa.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - base 3 | - callbacks: [early_stopping, model_checkpoint, model_summary, progress_bar, lr_monitor] 4 | - datamodule: jcqa 5 | - logger: wandb 6 | - model: deberta_v2_large 7 | - module: jcqa 8 | - optimizer: adamw 9 | - scheduler: cosine_schedule_with_warmup 10 | - trainer: default 11 | - _self_ 12 | 13 | max_seq_length: 128 14 | checkpoint_path: "" 15 | limit_examples: -1 16 | 17 | # set monitor and mode for early_stopping and model_checkpoint 18 | monitor: valid/accuracy 19 | mode: max 20 | 21 | # hyper-parameters to be tuned 22 | lr: 5e-5 23 | max_epochs: 4 24 | warmup_steps: null 25 | warmup_ratio: 0.1 26 | effective_batch_size: 32 27 | -------------------------------------------------------------------------------- /configs/jnli.debug.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - base 3 | - callbacks: [early_stopping, model_checkpoint, model_summary, progress_bar] 4 | - datamodule: jnli 5 | - logger: null 6 | - model: deberta_v2_tiny 7 | - module: jnli 8 | - 
optimizer: adamw 9 | - scheduler: constant_schedule_with_warmup 10 | - trainer: debug 11 | - _self_ 12 | 13 | max_seq_length: 128 14 | checkpoint_path: "" 15 | limit_examples: 100 16 | 17 | monitor: valid/accuracy 18 | mode: max 19 | 20 | # hyper-parameters to be tuned 21 | lr: 1e-4 22 | max_epochs: 2 23 | warmup_steps: null 24 | warmup_ratio: 0.1 25 | effective_batch_size: 4 26 | 27 | # environment dependent settings 28 | num_workers: 0 29 | -------------------------------------------------------------------------------- /configs/jnli.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - base 3 | - callbacks: [early_stopping, model_checkpoint, model_summary, progress_bar, lr_monitor] 4 | - datamodule: jnli 5 | - logger: wandb 6 | - model: deberta_v2_large 7 | - module: jnli 8 | - optimizer: adamw 9 | - scheduler: cosine_schedule_with_warmup 10 | - trainer: default 11 | - _self_ 12 | 13 | max_seq_length: 128 14 | checkpoint_path: "" 15 | limit_examples: -1 16 | 17 | # set monitor and mode for early_stopping and model_checkpoint 18 | monitor: valid/accuracy 19 | mode: max 20 | 21 | # hyper-parameters to be tuned 22 | lr: 5e-5 23 | max_epochs: 4 24 | warmup_steps: null 25 | warmup_ratio: 0.1 26 | effective_batch_size: 32 27 | -------------------------------------------------------------------------------- /configs/jsquad.debug.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - base 3 | - callbacks: [early_stopping, model_checkpoint, model_summary, progress_bar] 4 | - datamodule: jsquad 5 | - logger: null 6 | - model: deberta_v2_tiny 7 | - module: jsquad 8 | - optimizer: adamw 9 | - scheduler: constant_schedule_with_warmup 10 | - trainer: debug 11 | - _self_ 12 | 13 | max_seq_length: 128 14 | checkpoint_path: "" 15 | limit_examples: 3 16 | 17 | monitor: valid/f1 18 | mode: max 19 | 20 | # hyper-parameters to be tuned 21 | lr: 5e-5 22 | max_epochs: 2 23 | warmup_steps: null 24 | warmup_ratio: 0.1 25 | effective_batch_size: 4 26 | 27 | # environment dependent settings 28 | num_workers: 0 29 | -------------------------------------------------------------------------------- /configs/jsquad.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - base 3 | - callbacks: [early_stopping, model_checkpoint, model_summary, progress_bar, lr_monitor] 4 | - datamodule: jsquad 5 | - logger: wandb 6 | - model: deberta_v2_large 7 | - module: jsquad 8 | - optimizer: adamw 9 | - scheduler: cosine_schedule_with_warmup 10 | - trainer: default 11 | - _self_ 12 | 13 | max_seq_length: 384 # Max sequence length is 606 for ku-nlp/deberta-v3-base-japanese 14 | checkpoint_path: "" 15 | limit_examples: -1 16 | 17 | # set monitor and mode for early_stopping and model_checkpoint 18 | monitor: valid/f1 19 | mode: max 20 | 21 | # hyper-parameters to be tuned 22 | lr: 5e-5 23 | max_epochs: 4 24 | warmup_steps: null 25 | warmup_ratio: 0.1 26 | effective_batch_size: 32 27 | -------------------------------------------------------------------------------- /configs/jsts.debug.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - base 3 | - callbacks: [early_stopping, model_checkpoint, model_summary, progress_bar] 4 | - datamodule: jsts 5 | - logger: null 6 | - model: deberta_v2_tiny 7 | - module: jsts 8 | - optimizer: adamw 9 | - scheduler: constant_schedule_with_warmup 10 | - trainer: debug 11 | - _self_ 12 | 13 | 
max_seq_length: 128 14 | checkpoint_path: "" 15 | limit_examples: 100 16 | 17 | monitor: valid/spearman 18 | mode: max 19 | 20 | # hyper-parameters to be tuned 21 | lr: 1e-4 22 | max_epochs: 2 23 | warmup_steps: null 24 | warmup_ratio: 0.1 25 | effective_batch_size: 4 26 | 27 | # environment dependent settings 28 | num_workers: 0 29 | -------------------------------------------------------------------------------- /configs/jsts.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - base 3 | - callbacks: [early_stopping, model_checkpoint, model_summary, progress_bar, lr_monitor] 4 | - datamodule: jsts 5 | - logger: wandb 6 | - model: deberta_v2_large 7 | - module: jsts 8 | - optimizer: adamw 9 | - scheduler: cosine_schedule_with_warmup 10 | - trainer: default 11 | - _self_ 12 | 13 | max_seq_length: 128 14 | checkpoint_path: "" 15 | limit_examples: -1 16 | 17 | # set monitor and mode for early_stopping and model_checkpoint 18 | monitor: valid/spearman 19 | mode: max 20 | 21 | # hyper-parameters to be tuned 22 | lr: 5e-5 23 | max_epochs: 4 24 | warmup_steps: null 25 | warmup_ratio: 0.1 26 | effective_batch_size: 32 27 | -------------------------------------------------------------------------------- /configs/logger/wandb.yaml: -------------------------------------------------------------------------------- 1 | _target_: lightning.pytorch.loggers.WandbLogger 2 | name: ${name}-${hydra:job.num} 3 | save_dir: ${work_dir} 4 | project: ${project} 5 | group: ${name} 6 | tags: 7 | - ${config_name} 8 | settings: 9 | _target_: wandb.Settings 10 | start_method: fork 11 | -------------------------------------------------------------------------------- /configs/marc_ja.debug.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - base 3 | - callbacks: [early_stopping, model_checkpoint, model_summary, progress_bar] 4 | - datamodule: marc_ja 5 | - logger: null 6 | - model: deberta_v2_tiny 7 | - module: marc_ja 8 | - optimizer: adamw 9 | - scheduler: constant_schedule_with_warmup 10 | - trainer: debug 11 | - _self_ 12 | 13 | max_seq_length: 128 14 | checkpoint_path: "" 15 | limit_examples: 100 16 | 17 | monitor: valid/accuracy 18 | mode: max 19 | 20 | # hyper-parameters to be tuned 21 | lr: 1e-4 22 | max_epochs: 2 23 | warmup_steps: null 24 | warmup_ratio: 0.1 25 | effective_batch_size: 4 26 | 27 | # environment dependent settings 28 | num_workers: 0 29 | -------------------------------------------------------------------------------- /configs/marc_ja.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - base 3 | - callbacks: [early_stopping, model_checkpoint, model_summary, progress_bar, lr_monitor] 4 | - datamodule: marc_ja 5 | - logger: wandb 6 | - model: deberta_v2_large 7 | - module: marc_ja 8 | - optimizer: adamw 9 | - scheduler: cosine_schedule_with_warmup 10 | - trainer: default 11 | - _self_ 12 | 13 | max_seq_length: 512 14 | checkpoint_path: "" 15 | limit_examples: -1 16 | 17 | # set monitor and mode for early_stopping and model_checkpoint 18 | monitor: valid/accuracy 19 | mode: max 20 | 21 | # hyper-parameters to be tuned 22 | lr: 4e-5 23 | max_epochs: 4 24 | warmup_steps: null 25 | warmup_ratio: 0.1 26 | effective_batch_size: 256 27 | -------------------------------------------------------------------------------- /configs/model/deberta_v2_base.yaml: -------------------------------------------------------------------------------- 1 | 
model_name_or_path: ku-nlp/deberta-v2-base-japanese 2 | tokenizer: 3 | _target_: transformers.AutoTokenizer.from_pretrained 4 | pretrained_model_name_or_path: ${..model_name_or_path} 5 | _convert_: all 6 | segmenter_kwargs: 7 | analyzer: jumanpp 8 | h2z: false 9 | -------------------------------------------------------------------------------- /configs/model/deberta_v2_large.yaml: -------------------------------------------------------------------------------- 1 | model_name_or_path: ku-nlp/deberta-v2-large-japanese 2 | tokenizer: 3 | _target_: transformers.AutoTokenizer.from_pretrained 4 | pretrained_model_name_or_path: ${..model_name_or_path} 5 | _convert_: all 6 | segmenter_kwargs: 7 | analyzer: jumanpp 8 | h2z: false 9 | -------------------------------------------------------------------------------- /configs/model/deberta_v2_tiny.yaml: -------------------------------------------------------------------------------- 1 | model_name_or_path: ku-nlp/deberta-v2-tiny-japanese 2 | tokenizer: 3 | _target_: transformers.AutoTokenizer.from_pretrained 4 | pretrained_model_name_or_path: ${..model_name_or_path} 5 | _convert_: all 6 | segmenter_kwargs: 7 | analyzer: jumanpp 8 | h2z: false 9 | -------------------------------------------------------------------------------- /configs/model/deberta_v3_base.yaml: -------------------------------------------------------------------------------- 1 | model_name_or_path: ku-nlp/deberta-v3-base-japanese 2 | tokenizer: 3 | _target_: transformers.AutoTokenizer.from_pretrained 4 | pretrained_model_name_or_path: ${..model_name_or_path} 5 | _convert_: all 6 | segmenter_kwargs: 7 | analyzer: null 8 | h2z: false 9 | -------------------------------------------------------------------------------- /configs/model/luke_base.yaml: -------------------------------------------------------------------------------- 1 | model_name_or_path: studio-ousia/luke-japanese-base-lite 2 | tokenizer: 3 | _target_: transformers.AutoTokenizer.from_pretrained 4 | pretrained_model_name_or_path: ${..model_name_or_path} 5 | _convert_: all 6 | segmenter_kwargs: 7 | analyzer: null 8 | h2z: false 9 | -------------------------------------------------------------------------------- /configs/model/luke_large.yaml: -------------------------------------------------------------------------------- 1 | model_name_or_path: studio-ousia/luke-japanese-large-lite 2 | tokenizer: 3 | _target_: transformers.AutoTokenizer.from_pretrained 4 | pretrained_model_name_or_path: ${..model_name_or_path} 5 | _convert_: all 6 | segmenter_kwargs: 7 | analyzer: null 8 | h2z: false 9 | -------------------------------------------------------------------------------- /configs/model/microsoft__mdeberta_v3_base.yaml: -------------------------------------------------------------------------------- 1 | model_name_or_path: microsoft/mdeberta-v3-base 2 | tokenizer: 3 | _target_: transformers.AutoTokenizer.from_pretrained 4 | pretrained_model_name_or_path: ${..model_name_or_path} 5 | _convert_: all 6 | segmenter_kwargs: 7 | analyzer: null 8 | h2z: false 9 | -------------------------------------------------------------------------------- /configs/model/modernbert_130m.yaml: -------------------------------------------------------------------------------- 1 | model_name_or_path: sbintuitions/modernbert-ja-130m 2 | tokenizer: 3 | _target_: transformers.AutoTokenizer.from_pretrained 4 | pretrained_model_name_or_path: ${..model_name_or_path} 5 | _convert_: all 6 | segmenter_kwargs: 7 | analyzer: null 8 | 
-------------------------------------------------------------------------------- /configs/model/modernbert_30m.yaml: -------------------------------------------------------------------------------- 1 | model_name_or_path: sbintuitions/modernbert-ja-30m 2 | tokenizer: 3 | _target_: transformers.AutoTokenizer.from_pretrained 4 | pretrained_model_name_or_path: ${..model_name_or_path} 5 | _convert_: all 6 | segmenter_kwargs: 7 | analyzer: null 8 | -------------------------------------------------------------------------------- /configs/model/modernbert_310m.yaml: -------------------------------------------------------------------------------- 1 | model_name_or_path: sbintuitions/modernbert-ja-310m 2 | tokenizer: 3 | _target_: transformers.AutoTokenizer.from_pretrained 4 | pretrained_model_name_or_path: ${..model_name_or_path} 5 | _convert_: all 6 | segmenter_kwargs: 7 | analyzer: null 8 | -------------------------------------------------------------------------------- /configs/model/modernbert_70m.yaml: -------------------------------------------------------------------------------- 1 | model_name_or_path: sbintuitions/modernbert-ja-70m 2 | tokenizer: 3 | _target_: transformers.AutoTokenizer.from_pretrained 4 | pretrained_model_name_or_path: ${..model_name_or_path} 5 | _convert_: all 6 | segmenter_kwargs: 7 | analyzer: null 8 | -------------------------------------------------------------------------------- /configs/model/roberta_base.yaml: -------------------------------------------------------------------------------- 1 | model_name_or_path: nlp-waseda/roberta-base-japanese 2 | tokenizer: 3 | _target_: transformers.AutoTokenizer.from_pretrained 4 | pretrained_model_name_or_path: ${..model_name_or_path} 5 | _convert_: all 6 | segmenter_kwargs: 7 | analyzer: jumanpp 8 | h2z: false 9 | -------------------------------------------------------------------------------- /configs/model/roberta_large.yaml: -------------------------------------------------------------------------------- 1 | model_name_or_path: nlp-waseda/roberta-large-japanese-seq512 2 | tokenizer: 3 | _target_: transformers.AutoTokenizer.from_pretrained 4 | pretrained_model_name_or_path: ${..model_name_or_path} 5 | _convert_: all 6 | segmenter_kwargs: 7 | analyzer: jumanpp 8 | h2z: false 9 | -------------------------------------------------------------------------------- /configs/module/jcola.yaml: -------------------------------------------------------------------------------- 1 | cls: 2 | _target_: modules.JCoLAModule 3 | 4 | load_from_checkpoint: 5 | _target_: ${module.cls._target_}.load_from_checkpoint 6 | checkpoint_path: ${checkpoint_path} 7 | -------------------------------------------------------------------------------- /configs/module/jcqa.yaml: -------------------------------------------------------------------------------- 1 | cls: 2 | _target_: modules.JCommonsenseQAModule 3 | 4 | load_from_checkpoint: 5 | _target_: ${module.cls._target_}.load_from_checkpoint 6 | checkpoint_path: ${checkpoint_path} 7 | -------------------------------------------------------------------------------- /configs/module/jnli.yaml: -------------------------------------------------------------------------------- 1 | cls: 2 | _target_: modules.JNLIModule 3 | 4 | load_from_checkpoint: 5 | _target_: ${module.cls._target_}.load_from_checkpoint 6 | checkpoint_path: ${checkpoint_path} 7 | -------------------------------------------------------------------------------- /configs/module/jsquad.yaml: 
-------------------------------------------------------------------------------- 1 | cls: 2 | _target_: modules.JSQuADModule 3 | 4 | load_from_checkpoint: 5 | _target_: ${module.cls._target_}.load_from_checkpoint 6 | checkpoint_path: ${checkpoint_path} 7 | -------------------------------------------------------------------------------- /configs/module/jsts.yaml: -------------------------------------------------------------------------------- 1 | cls: 2 | _target_: modules.JSTSModule 3 | 4 | load_from_checkpoint: 5 | _target_: ${module.cls._target_}.load_from_checkpoint 6 | checkpoint_path: ${checkpoint_path} 7 | -------------------------------------------------------------------------------- /configs/module/marc_ja.yaml: -------------------------------------------------------------------------------- 1 | cls: 2 | _target_: modules.MARCJaModule 3 | 4 | load_from_checkpoint: 5 | _target_: ${module.cls._target_}.load_from_checkpoint 6 | checkpoint_path: ${checkpoint_path} 7 | -------------------------------------------------------------------------------- /configs/optimizer/adamw.yaml: -------------------------------------------------------------------------------- 1 | _target_: torch.optim.AdamW 2 | lr: ${lr} 3 | betas: [0.9, 0.999] 4 | eps: 1e-8 5 | weight_decay: 0.0 6 | -------------------------------------------------------------------------------- /configs/scheduler/constant_schedule_with_warmup.yaml: -------------------------------------------------------------------------------- 1 | _target_: transformers.optimization.get_constant_schedule_with_warmup 2 | num_warmup_steps: ${warmup_steps} 3 | -------------------------------------------------------------------------------- /configs/scheduler/cosine_schedule_with_warmup.yaml: -------------------------------------------------------------------------------- 1 | _target_: transformers.optimization.get_cosine_schedule_with_warmup 2 | num_warmup_steps: ${warmup_steps} 3 | num_training_steps: null 4 | -------------------------------------------------------------------------------- /configs/scheduler/linear_schedule_with_warmup.yaml: -------------------------------------------------------------------------------- 1 | _target_: transformers.optimization.get_linear_schedule_with_warmup 2 | num_warmup_steps: ${warmup_steps} 3 | num_training_steps: null 4 | -------------------------------------------------------------------------------- /configs/trainer/cpu.debug.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - debug 3 | 4 | accelerator: cpu 5 | -------------------------------------------------------------------------------- /configs/trainer/cpu.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - default 3 | 4 | accelerator: cpu 5 | -------------------------------------------------------------------------------- /configs/trainer/debug.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - default 3 | 4 | #fast_dev_run: 3 5 | limit_train_batches: 10 6 | limit_val_batches: 10 7 | limit_test_batches: 10 8 | -------------------------------------------------------------------------------- /configs/trainer/default.yaml: -------------------------------------------------------------------------------- 1 | _target_: lightning.Trainer 2 | gradient_clip_val: 0.5 3 | accumulate_grad_batches: 1 4 | max_epochs: ${max_epochs} 5 | min_epochs: 1 6 | precision: 32 7 | accelerator: auto 8 | strategy: auto 
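# ${max_epochs} above is an OmegaConf interpolation resolved against the composed
# top-level config, so a task config such as configs/jsts.yaml (max_epochs: 4)
# drives the Trainer without duplicating the value here.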
9 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "JGLUE-evaluation-scripts" 3 | version = "0.1.0" 4 | description = "" 5 | authors = [ 6 | {name = "Nobuhiro Ueda", email = "ueda@nlp.ist.i.kyoto-u.ac.jp"}, 7 | ] 8 | maintainers = [ 9 | {name = "Nobuhiro Ueda", email = "ueda@nlp.ist.i.kyoto-u.ac.jp"}, 10 | ] 11 | readme = "README.md" 12 | requires-python = ">=3.9,<3.14" 13 | dependencies = [ 14 | 'torch>=2.4.0; python_version < "3.13"', 15 | 'torch>=2.6.0; python_version >= "3.13"', 16 | "transformers>=4.48.0", 17 | 'sentencepiece>=0.2.0; python_version < "3.13"', 18 | "tokenizers>=0.21.0", 19 | "lightning>=2.4.0", 20 | "torchmetrics>=1.1.0", 21 | "omegaconf>=2.3.0", 22 | "hydra-core>=1.3.2", 23 | "rich>=13.3.0", 24 | "datasets>=3.5.0", 25 | "rhoknp>=1.4.0", 26 | "jaconv>=0.4.0", 27 | "mecab-python3>=1.0.10", 28 | ] 29 | 30 | [project.optional-dependencies] 31 | mecab = [ 32 | "mecab-python3>=1.0.0", 33 | ] 34 | 35 | [dependency-groups] 36 | dev = [ 37 | 'ipython>=8.13.1,<8.19.0; python_version == "3.9"', 38 | 'ipython>=8.19.0; python_version >= "3.10"', 39 | "ipdb>=0.13.13", 40 | "pytest>=8.0.0", 41 | "pip>=25.0", 42 | "types-attrs>=19.1.0", 43 | "wandb>=0.18.0", 44 | "prettytable>=3.16.0", 45 | ] 46 | flash-attn = [ 47 | "flash-attn>=2.6.3,<3", 48 | ] 49 | 50 | [build-system] 51 | requires = ["hatchling"] 52 | build-backend = "hatchling.build" 53 | 54 | [tool.uv] 55 | package = false 56 | no-build-isolation-package = ["flash-attn"] 57 | 58 | [tool.ruff] 59 | line-length = 120 60 | indent-width = 4 61 | target-version = "py39" # The minimum Python version to target 62 | src = ["src"] 63 | 64 | [tool.ruff.lint] 65 | select = ["ALL"] 66 | ignore = [ 67 | "PLR0912", # Too many branches 68 | "PLR0913", # Too many arguments in function definition 69 | "PLR0915", # Too many statements 70 | "E501", # Line too long 71 | "RUF001", # String contains ambiguous `ノ` (KATAKANA LETTER NO). Did you mean `/` (SOLIDUS)? 72 | "RUF002", # Docstring contains ambiguous `,` (FULLWIDTH COMMA). Did you mean `,` (COMMA)? 73 | "COM812", # Trailing comma missing 74 | "ANN002", # Missing type annotation for `*args` 75 | "ANN003", # Missing type annotation for `**kwargs` 76 | "PLR2004", # Magic value used in comparison 77 | "D", # pydocstyle 78 | "FA100", # Missing `from __future__ import annotations`, but uses `...` 79 | "S101", # Use of `assert` detected 80 | "TRY003", # Avoid specifying long messages outside the exception class 81 | "ANN401", # Dynamically typed expressions (typing.Any) are disallowed in `...` 82 | "C408", # Unnecessary `dict` call (rewrite as a literal) 83 | "FBT001", # Boolean-typed positional argument in function definition 84 | "FBT002", # Boolean default positional argument in function definition 85 | "ERA001", # Found commented-out code 86 | "EM102", # Exception must not use an f-string literal, assign to variable first 87 | ] 88 | 89 | [tool.ruff.lint.per-file-ignores] 90 | "tests/*" = [ 91 | "ANN", # flake8-annotations 92 | "INP001", # File `...` is part of an implicit namespace package. Add an `__init__.py`. 93 | ] 94 | "scripts/gen_table.py" = [ 95 | "T201", # `print` found 96 | "INP001", # File `...` is part of an implicit namespace package. Add an `__init__.py`. 
97 | ] 98 | 99 | [tool.ruff.lint.flake8-tidy-imports] 100 | ban-relative-imports = "all" 101 | 102 | [tool.ruff.lint.pydocstyle] 103 | convention = "google" 104 | 105 | [tool.mypy] 106 | python_version = "3.9" 107 | 108 | [tool.pytest.ini_options] 109 | testpaths = ["tests"] 110 | filterwarnings = [ 111 | # note the use of single quote below to denote "raw" strings in TOML 112 | 'ignore::UserWarning', 113 | ] 114 | -------------------------------------------------------------------------------- /scripts/gen_sweeps.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | set -euo pipefail 4 | 5 | for task in marc_ja jcola jsts jnli jsquad jcqa; do 6 | for model in roberta_base roberta_large deberta_base deberta_large; do 7 | sweep_id=$(wandb sweep --name="${task}-${model}" "sweeps/${task}/${model}.yaml" 2>&1 | tail -1 | cut -d' ' -f8) 8 | echo "${task}-${model}" "${sweep_id}" 0 | tee -a sweep_status.txt 9 | done 10 | done 11 | -------------------------------------------------------------------------------- /scripts/gen_table.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from pathlib import Path 3 | from typing import TYPE_CHECKING, Optional 4 | 5 | import wandb 6 | from prettytable import PrettyTable 7 | 8 | if TYPE_CHECKING: 9 | from wandb.apis.public import Run, Sweep 10 | 11 | TASKS = { 12 | "marc_ja/accuracy": "MARC-ja/acc", 13 | "jcola/accuracy": "JCoLA/acc", 14 | "jsts/pearson": "JSTS/pearson", 15 | "jsts/spearman": "JSTS/spearman", 16 | "jnli/accuracy": "JNLI/acc", 17 | "jsquad/exact_match": "JSQuAD/EM", 18 | "jsquad/f1": "JSQuAD/F1", 19 | "jcqa/accuracy": "JComQA/acc", 20 | } 21 | MODELS = { 22 | "roberta_base": "Waseda RoBERTa base", # nlp-waseda/roberta-base-japanese 23 | "roberta_large": "Waseda RoBERTa large (seq512)", # nlp-waseda/roberta-large-japanese-seq512 24 | "deberta_base": "DeBERTaV2 base", # ku-nlp/deberta-v2-base-japanese 25 | "deberta_large": "DeBERTaV2 large", # ku-nlp/deberta-v2-large-japanese 26 | "deberta_v3_base": "DeBERTaV3 base", # ku-nlp/deberta-v3-base-japanese 27 | } 28 | 29 | 30 | @dataclass(frozen=True) 31 | class RunSummary: 32 | metric: float 33 | lr: float 34 | max_epochs: int 35 | batch_size: int 36 | 37 | 38 | def create_table(headers: list[str], align: list[str]) -> PrettyTable: 39 | table = PrettyTable() 40 | table.field_names = headers 41 | for header, a in zip(headers, align): 42 | table.align[header] = a 43 | return table 44 | 45 | 46 | def main() -> None: 47 | api = wandb.Api() 48 | name_to_sweep_path: dict[str, str] = { 49 | line.split()[0]: line.split()[1] for line in Path("sweep_status.txt").read_text().splitlines() 50 | } 51 | results: list[list[Optional[RunSummary]]] = [] 52 | for model in MODELS: 53 | items: list[Optional[RunSummary]] = [] 54 | for task_and_metric in TASKS: 55 | task, metric_name = task_and_metric.split("/") 56 | sweep: Sweep = api.sweep(name_to_sweep_path[f"{task}-{model}"]) 57 | if sweep.state == "FINISHED": 58 | run: Optional[Run] = sweep.best_run() 59 | assert run is not None 60 | metric_name = "valid/" + metric_name 61 | items.append( 62 | RunSummary( 63 | metric=run.summary[metric_name], 64 | lr=run.config["lr"], 65 | max_epochs=run.config["max_epochs"], 66 | batch_size=run.config["effective_batch_size"], 67 | ) 68 | ) 69 | else: 70 | items.append(None) 71 | results.append(items) 72 | 73 | headers = ["Model", *TASKS.values()] 74 | align = ["l"] + ["r"] * len(TASKS) 75 | 76 | # 
Score table 77 | print("Scores of best runs:") 78 | score_table = create_table(headers, align) 79 | for model, items in zip(MODELS.values(), results): 80 | row = [model] + [f"{item.metric:.3f}" if item else "-" for item in items] 81 | score_table.add_row(row) 82 | print(score_table) 83 | print() 84 | 85 | # Learning rate table 86 | print("Learning rates of best runs:") 87 | lr_table = create_table(headers, align) 88 | for model, items in zip(MODELS.values(), results): 89 | row = [model] + [str(item.lr) if item else "-" for item in items] 90 | lr_table.add_row(row) 91 | print(lr_table) 92 | print() 93 | 94 | # Training epoch table 95 | print("Training epochs of best runs:") 96 | epoch_table = create_table(headers, align) 97 | for model, items in zip(MODELS.values(), results): 98 | row = [model] + [str(item.max_epochs) if item else "-" for item in items] 99 | epoch_table.add_row(row) 100 | print(epoch_table) 101 | 102 | 103 | if __name__ == "__main__": 104 | main() 105 | -------------------------------------------------------------------------------- /scripts/run_sweeps.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | set -euo pipefail 4 | 5 | for sweep in "$@"; do 6 | while read -r task_model sweep_id _; do 7 | if [[ "${sweep}" = "${task_model}" ]]; then 8 | wandb agent "${sweep_id}" 9 | break 10 | fi 11 | done < sweep_status.txt 12 | done 13 | -------------------------------------------------------------------------------- /src/datamodule/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nobu-g/JGLUE-evaluation-scripts/996fe58d68be5af2132e38d25e82aadb52e89209/src/datamodule/__init__.py -------------------------------------------------------------------------------- /src/datamodule/datamodule.py: -------------------------------------------------------------------------------- 1 | from dataclasses import fields, is_dataclass 2 | from typing import Any, Optional, Union 3 | 4 | import hydra 5 | import lightning 6 | import torch 7 | from lightning.pytorch.trainer.states import TrainerFn 8 | from omegaconf import DictConfig 9 | from torch import Tensor 10 | from torch.utils.data import DataLoader, Dataset 11 | 12 | 13 | class DataModule(lightning.LightningDataModule): 14 | def __init__(self, cfg: DictConfig) -> None: 15 | super().__init__() 16 | self.cfg: DictConfig = cfg 17 | self.batch_size: int = cfg.batch_size 18 | self.num_workers: int = cfg.num_workers 19 | 20 | self.train_dataset: Optional[Dataset] = None 21 | self.valid_dataset: Optional[Dataset] = None 22 | self.test_dataset: Optional[Dataset] = None 23 | 24 | def prepare_data(self) -> None: 25 | pass 26 | 27 | def setup(self, stage: Optional[str] = None) -> None: 28 | if stage == TrainerFn.FITTING: 29 | self.train_dataset = hydra.utils.instantiate(self.cfg.train) 30 | if stage in (TrainerFn.FITTING, TrainerFn.VALIDATING, TrainerFn.TESTING): 31 | self.valid_dataset = hydra.utils.instantiate(self.cfg.valid) 32 | if stage == TrainerFn.TESTING: 33 | self.test_dataset = hydra.utils.instantiate(self.cfg.test) 34 | 35 | def train_dataloader(self) -> DataLoader: 36 | assert self.train_dataset is not None 37 | return self._get_dataloader(dataset=self.train_dataset, shuffle=True) 38 | 39 | def val_dataloader(self) -> DataLoader: 40 | assert self.valid_dataset is not None 41 | return self._get_dataloader(self.valid_dataset, shuffle=False) 42 | 43 | def test_dataloader(self) -> DataLoader: 44 | assert self.test_dataset is not
None 45 | return self._get_dataloader(self.test_dataset, shuffle=False) 46 | 47 | def _get_dataloader(self, dataset: Dataset, shuffle: bool) -> DataLoader: 48 | return DataLoader( 49 | dataset=dataset, 50 | batch_size=self.batch_size, 51 | shuffle=shuffle, 52 | num_workers=self.num_workers, 53 | collate_fn=dataclass_data_collator, 54 | pin_memory=True, 55 | ) 56 | 57 | 58 | def dataclass_data_collator(features: list[Any]) -> dict[str, Union[Tensor, list[str]]]: 59 | first: Any = features[0] 60 | assert is_dataclass(first), "Data must be a dataclass" 61 | batch: dict[str, Union[Tensor, list[str]]] = {} 62 | for field in fields(first): 63 | feats = [getattr(f, field.name) for f in features] 64 | batch[field.name] = torch.as_tensor(feats) 65 | return batch 66 | -------------------------------------------------------------------------------- /src/datamodule/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from datamodule.datasets.jcola import JCoLADataset 2 | from datamodule.datasets.jcqa import JCommonsenseQADataset 3 | from datamodule.datasets.jnli import JNLIDataset 4 | from datamodule.datasets.jsquad import JSQuADDataset 5 | from datamodule.datasets.jsts import JSTSDataset 6 | from datamodule.datasets.marc_ja import MARCJaDataset 7 | 8 | __all__ = ["JCoLADataset", "JCommonsenseQADataset", "JNLIDataset", "JSQuADDataset", "JSTSDataset", "MARCJaDataset"] 9 | -------------------------------------------------------------------------------- /src/datamodule/datasets/base.py: -------------------------------------------------------------------------------- 1 | from abc import ABC 2 | from typing import Generic, TypeVar 3 | 4 | from datasets import Dataset as HFDataset # type: ignore[attr-defined] 5 | from datasets import load_dataset # type: ignore[attr-defined] 6 | from torch.utils.data import Dataset 7 | from transformers import PreTrainedTokenizerBase 8 | 9 | FeatureType = TypeVar("FeatureType") 10 | 11 | 12 | class BaseDataset(Dataset[FeatureType], Generic[FeatureType], ABC): 13 | def __init__( 14 | self, 15 | dataset_name: str, 16 | split: str, 17 | tokenizer: PreTrainedTokenizerBase, 18 | max_seq_length: int, 19 | limit_examples: int = -1, 20 | ) -> None: 21 | self.split: str = split 22 | self.tokenizer: PreTrainedTokenizerBase = tokenizer 23 | self.max_seq_length: int = max_seq_length 24 | 25 | # NOTE: JGLUE does not provide test set. 
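# Requests for the "test" split therefore fall back to the validation split below.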
26 | if self.split == "test": 27 | self.split = "validation" 28 | # columns: id, title, context, question, answers, is_impossible 29 | self.hf_dataset: HFDataset = load_dataset( 30 | "shunk031/JGLUE", name=dataset_name, split=self.split, trust_remote_code=True 31 | ) 32 | if limit_examples > 0: 33 | self.hf_dataset = self.hf_dataset.select(range(limit_examples)) 34 | 35 | def __getitem__(self, index: int) -> FeatureType: 36 | raise NotImplementedError 37 | 38 | def __len__(self) -> int: 39 | return len(self.hf_dataset) 40 | -------------------------------------------------------------------------------- /src/datamodule/datasets/jcola.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import Any 3 | 4 | from omegaconf import DictConfig 5 | from transformers import PreTrainedTokenizerBase 6 | from transformers.utils import PaddingStrategy 7 | 8 | from datamodule.datasets.base import BaseDataset 9 | from datamodule.datasets.util import SequenceClassificationFeatures, batch_segment 10 | 11 | 12 | class JCoLADataset(BaseDataset[SequenceClassificationFeatures]): 13 | def __init__( 14 | self, 15 | split: str, 16 | tokenizer: PreTrainedTokenizerBase, 17 | max_seq_length: int, 18 | segmenter_kwargs: DictConfig, 19 | limit_examples: int = -1, 20 | ) -> None: 21 | super().__init__("JCoLA", split, tokenizer, max_seq_length, limit_examples) 22 | 23 | self.hf_dataset = self.hf_dataset.map( 24 | lambda x: {"segmented": batch_segment(x["sentence"], **segmenter_kwargs)}, # type: ignore[misc] 25 | batched=True, 26 | batch_size=100, 27 | num_proc=os.cpu_count(), 28 | ).map( 29 | lambda x: self.tokenizer( 30 | x["segmented"], 31 | padding=PaddingStrategy.MAX_LENGTH, 32 | truncation=True, 33 | max_length=self.max_seq_length, 34 | return_token_type_ids=True, 35 | ), 36 | batched=True, 37 | ) 38 | 39 | def __getitem__(self, index: int) -> SequenceClassificationFeatures: 40 | example: dict[str, Any] = self.hf_dataset[index] 41 | return SequenceClassificationFeatures( 42 | input_ids=example["input_ids"], 43 | attention_mask=example["attention_mask"], 44 | token_type_ids=example["token_type_ids"], 45 | labels=example["label"], 46 | ) 47 | -------------------------------------------------------------------------------- /src/datamodule/datasets/jcqa.py: -------------------------------------------------------------------------------- 1 | import os 2 | from itertools import chain 3 | from typing import Any 4 | 5 | from omegaconf import DictConfig 6 | from transformers import PreTrainedTokenizerBase 7 | from transformers.utils import PaddingStrategy 8 | 9 | from datamodule.datasets.base import BaseDataset 10 | from datamodule.datasets.util import MultipleChoiceFeatures, batch_segment 11 | 12 | CHOICE_NAMES = ["choice0", "choice1", "choice2", "choice3", "choice4"] 13 | NUM_CHOICES = len(CHOICE_NAMES) 14 | 15 | 16 | class JCommonsenseQADataset(BaseDataset[MultipleChoiceFeatures]): 17 | def __init__( 18 | self, 19 | split: str, 20 | tokenizer: PreTrainedTokenizerBase, 21 | max_seq_length: int, 22 | segmenter_kwargs: DictConfig, 23 | limit_examples: int = -1, 24 | ) -> None: 25 | super().__init__("JCommonsenseQA", split, tokenizer, max_seq_length, limit_examples) 26 | 27 | def preprocess_function(examples: dict[str, list]) -> dict[str, list[list[Any]]]: 28 | # (example, 5) 29 | first_sentences: list[list[str]] = [[question] * NUM_CHOICES for question in examples["question"]] 30 | second_sentences: list[list[str]] = [ 31 | [examples[name][i] for name in 
CHOICE_NAMES] for i in range(len(examples["question"])) 32 | ] 33 | # Tokenize 34 | tokenized_examples = self.tokenizer( 35 | list(chain(*first_sentences)), 36 | list(chain(*second_sentences)), 37 | truncation=True, 38 | max_length=self.max_seq_length, 39 | padding=PaddingStrategy.MAX_LENGTH, 40 | return_token_type_ids=True, 41 | ) 42 | # Un-flatten 43 | return { 44 | k: [v[i : i + NUM_CHOICES] for i in range(0, len(v), NUM_CHOICES)] 45 | for k, v in tokenized_examples.items() 46 | } 47 | 48 | self.hf_dataset = self.hf_dataset.map( 49 | lambda x: { 50 | key: batch_segment(x[key], **segmenter_kwargs) # type: ignore[misc] 51 | for key in ["question", *CHOICE_NAMES] 52 | }, 53 | batched=True, 54 | batch_size=100, 55 | num_proc=os.cpu_count(), 56 | ).map( 57 | preprocess_function, 58 | batched=True, 59 | ) 60 | 61 | def __getitem__(self, index: int) -> MultipleChoiceFeatures: 62 | example: dict[str, Any] = self.hf_dataset[index] 63 | return MultipleChoiceFeatures( 64 | input_ids=example["input_ids"], 65 | attention_mask=example["attention_mask"], 66 | token_type_ids=example["token_type_ids"], 67 | labels=example["label"], 68 | ) 69 | -------------------------------------------------------------------------------- /src/datamodule/datasets/jnli.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import Any 3 | 4 | from omegaconf import DictConfig 5 | from transformers import PreTrainedTokenizerBase 6 | from transformers.utils import PaddingStrategy 7 | 8 | from datamodule.datasets.base import BaseDataset 9 | from datamodule.datasets.util import SequenceClassificationFeatures, batch_segment 10 | 11 | 12 | class JNLIDataset(BaseDataset[SequenceClassificationFeatures]): 13 | def __init__( 14 | self, 15 | split: str, 16 | tokenizer: PreTrainedTokenizerBase, 17 | max_seq_length: int, 18 | segmenter_kwargs: DictConfig, 19 | limit_examples: int = -1, 20 | ) -> None: 21 | super().__init__("JNLI", split, tokenizer, max_seq_length, limit_examples) 22 | 23 | self.hf_dataset = self.hf_dataset.map( 24 | lambda x: { 25 | "segmented1": batch_segment(x["sentence1"], **segmenter_kwargs), # type: ignore[misc] 26 | "segmented2": batch_segment(x["sentence2"], **segmenter_kwargs), # type: ignore[misc] 27 | }, 28 | batched=True, 29 | batch_size=100, 30 | num_proc=os.cpu_count(), 31 | ).map( 32 | lambda x: self.tokenizer( 33 | x["segmented1"], 34 | x["segmented2"], 35 | padding=PaddingStrategy.MAX_LENGTH, 36 | truncation=True, 37 | max_length=self.max_seq_length, 38 | return_token_type_ids=True, 39 | ), 40 | batched=True, 41 | ) 42 | 43 | def __getitem__(self, index: int) -> SequenceClassificationFeatures: 44 | example: dict[str, Any] = self.hf_dataset[index] 45 | return SequenceClassificationFeatures( 46 | input_ids=example["input_ids"], 47 | attention_mask=example["attention_mask"], 48 | token_type_ids=example["token_type_ids"], 49 | labels=example["label"], # 0: entailment, 1: contradiction, 2: neutral 50 | ) 51 | -------------------------------------------------------------------------------- /src/datamodule/datasets/jsquad.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Optional 2 | 3 | from omegaconf import DictConfig 4 | from transformers import PreTrainedTokenizerBase 5 | from transformers.utils import PaddingStrategy 6 | 7 | from datamodule.datasets.base import BaseDataset 8 | from datamodule.datasets.util import QuestionAnsweringFeatures, batch_segment 9 | 10 | 11 | class 
JSQuADDataset(BaseDataset[QuestionAnsweringFeatures]): 12 | def __init__( 13 | self, 14 | split: str, 15 | tokenizer: PreTrainedTokenizerBase, 16 | max_seq_length: int, 17 | segmenter_kwargs: DictConfig, 18 | limit_examples: int = -1, 19 | ) -> None: 20 | super().__init__("JSQuAD", split, tokenizer, max_seq_length, limit_examples) 21 | 22 | self.hf_dataset = self.hf_dataset.map( 23 | preprocess, 24 | batched=True, 25 | batch_size=100, 26 | fn_kwargs=dict(segmenter_kwargs=segmenter_kwargs), 27 | load_from_cache_file=False, 28 | ).map( 29 | lambda x: self.tokenizer( 30 | x["question"], 31 | x["context"], 32 | padding=PaddingStrategy.MAX_LENGTH, 33 | truncation="only_second", 34 | max_length=self.max_seq_length, 35 | return_offsets_mapping=True, 36 | return_token_type_ids=True, 37 | ), 38 | batched=True, 39 | load_from_cache_file=False, 40 | ) 41 | 42 | # skip invalid examples for training 43 | if self.split == "train": 44 | self.hf_dataset = self.hf_dataset.filter( 45 | lambda example: any(answer["answer_start"] >= 0 for answer in example["answers"]), 46 | load_from_cache_file=False, 47 | ) 48 | 49 | def __getitem__(self, index: int) -> QuestionAnsweringFeatures: 50 | example: dict[str, Any] = self.hf_dataset[index] 51 | start_positions = end_positions = 0 52 | for answer in example["answers"]: 53 | start_positions, end_positions = self._get_token_span(example, answer["text"], answer["answer_start"]) 54 | if start_positions > 0 or end_positions > 0: 55 | break 56 | 57 | return QuestionAnsweringFeatures( 58 | example_ids=index, 59 | input_ids=example["input_ids"], 60 | attention_mask=example["attention_mask"], 61 | token_type_ids=example["token_type_ids"], 62 | start_positions=start_positions, 63 | end_positions=end_positions, 64 | ) 65 | 66 | @staticmethod 67 | def _get_token_span(example: dict[str, Any], answer_text: str, answer_start: int) -> tuple[int, int]: 68 | """Convert the span position from character level to token level""" 69 | # token_type_ids: 70 | # 0: tokens of the first input (= `question`) or padding 71 | # 1: tokens of the second input (= `context`) 72 | token_type_ids: list[int] = example["token_type_ids"] 73 | # Holds the mapping between token indices and character spans. 74 | # e.g., when "京都 大学" is tokenized into ["[CLS]", "▁京都", "▁大学", "[SEP]"], this is 75 | # [(0, 0), (0, 2), (2, 5), (0, 0)] 76 | offset_mapping: list[tuple[int, int]] = example["offset_mapping"] 77 | context: str = example["context"] 78 | assert len(offset_mapping) == len(token_type_ids) 79 | token_to_char_start_index = [x[0] for x in offset_mapping] 80 | token_to_char_end_index = [x[1] for x in offset_mapping] 81 | answer_end = answer_start + len(answer_text) 82 | token_start_index = token_end_index = 0 83 | for token_index, (token_type_id, char_start_index, char_end_index) in enumerate( 84 | zip(token_type_ids, token_to_char_start_index, token_to_char_end_index) 85 | ): 86 | if token_type_id != 1 or char_start_index == char_end_index == 0: 87 | continue 88 | # A half-width space is sometimes not ignored; in that case, shift the mapping by one 89 | char_start_offset = 1 if context[char_start_index] == " " else 0 90 | if answer_start == char_start_index + char_start_offset: 91 | token_start_index = token_index 92 | if answer_end == char_end_index: 93 | token_end_index = token_index 94 | return token_start_index, token_end_index 95 | 96 | 97 | def preprocess(examples: dict[str, list], segmenter_kwargs: dict[str, Any]) -> dict[str, Any]: 98 | if segmenter_kwargs["analyzer"] is None: 99 | return preprocess_no_segmentation(examples) 100 | return preprocess_with_segmentation(examples, segmenter_kwargs) 101 | 102 | 103 | def preprocess_with_segmentation(examples:
dict[str, list], segmenter_kwargs: dict[str, Any]) -> dict[str, Any]: 104 | titles: list[str] 105 | bodies: list[str] 106 | titles, bodies = zip(*[context.split(" [SEP] ") for context in examples["context"]]) # type: ignore[assignment] 107 | segmented_titles = batch_segment(titles, **segmenter_kwargs) 108 | segmented_bodies = batch_segment(bodies, **segmenter_kwargs) 109 | segmented_contexts = [f"{title} [SEP] {body}" for title, body in zip(segmented_titles, segmented_bodies)] 110 | segmented_questions = batch_segment(examples["question"], **segmenter_kwargs) 111 | batch_answers: list[list[dict]] = [] 112 | for answers, segmented_context, title in zip(examples["answers"], segmented_contexts, titles): 113 | processed_answers: list[dict] = [] 114 | for answer_text, answer_start in zip(answers["text"], answers["answer_start"]): 115 | segmented_answer_text, segmented_answer_start = find_segmented_answer( 116 | segmented_context, answer_text, answer_start, len(title) 117 | ) 118 | if segmented_answer_start is None: 119 | processed_answers.append(dict(text=answer_text, answer_start=-1)) 120 | continue 121 | assert segmented_answer_text is not None 122 | processed_answers.append(dict(text=segmented_answer_text, answer_start=segmented_answer_start)) 123 | batch_answers.append(processed_answers) 124 | return {"context": segmented_contexts, "question": segmented_questions, "answers": batch_answers} 125 | 126 | 127 | def preprocess_no_segmentation(examples: dict[str, list]) -> dict[str, Any]: 128 | titles: list[str] 129 | bodies: list[str] 130 | titles, bodies = zip(*[context.split(" [SEP] ") for context in examples["context"]]) # type: ignore[assignment] 131 | contexts = [f"{title}[SEP]{body}" for title, body in zip(titles, bodies)] 132 | batch_answers: list[list[dict]] = [] 133 | assert len(examples["answers"]) == len(examples["context"]) == len(contexts) 134 | for answers, orig_context, context in zip(examples["answers"], examples["context"], contexts): 135 | processed_answers: list[dict] = [] 136 | for answer_text, answer_start in zip(answers["text"], answers["answer_start"]): 137 | # two whitespaces are stripped in the preprocessing 138 | offset = -2 if " [SEP] " in orig_context[:answer_start] else 0 139 | if context[answer_start + offset :].startswith(answer_text): 140 | processed_answers.append(dict(text=answer_text, answer_start=answer_start + offset)) 141 | batch_answers.append(processed_answers) 142 | return {"context": contexts, "question": examples["question"], "answers": batch_answers} 143 | 144 | 145 | def find_segmented_answer( 146 | segmented_context: str, answer_text: str, answer_start: int, sep_index: int 147 | ) -> tuple[Optional[str], Optional[int]]: 148 | """Search the word-segmented context for the span of the word-segmented answer 149 | 150 | Args: 151 | segmented_context: word-segmented context 152 | answer_text: answer string 153 | answer_start: character-level start index of the answer 154 | sep_index: character-level start index of [SEP] 155 | 156 | Returns: 157 | Optional[str]: the word-segmented answer (None if not found) 158 | Optional[int]: the character-level start index of the answer in the word-segmented context (None if not found) 159 | """ 160 | words = segmented_context.split(" ") 161 | char_to_word_index = {} 162 | char_index = 0 163 | for word_index, word in enumerate(words): 164 | char_to_word_index[char_index] = word_index 165 | # Only [SEP] needs +2 to account for the half-width spaces around it 166 | char_length = len(word) + 2 if word == "[SEP]" else len(word) 167 | char_index += char_length 168 | 169 | # Check whether the start position of the answer span aligns with a word boundary 170 | if answer_start in char_to_word_index: 171 | word_index =
char_to_word_index[answer_start] 172 | buf = [] 173 | for word in words[word_index:]: 174 | buf.append(word) 175 | # Case where the answer span is found even after word segmentation 176 | if "".join(buf) == answer_text: 177 | offset = 2 if answer_start >= sep_index else 0 178 | return " ".join(buf), answer_start + word_index - offset 179 | return None, None 180 | -------------------------------------------------------------------------------- /src/datamodule/datasets/jsts.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import Any 3 | 4 | from omegaconf import DictConfig 5 | from transformers import PreTrainedTokenizerBase 6 | from transformers.utils import PaddingStrategy 7 | 8 | from datamodule.datasets.base import BaseDataset 9 | from datamodule.datasets.util import SequenceClassificationFeatures, batch_segment 10 | 11 | 12 | class JSTSDataset(BaseDataset[SequenceClassificationFeatures]): 13 | def __init__( 14 | self, 15 | split: str, 16 | tokenizer: PreTrainedTokenizerBase, 17 | max_seq_length: int, 18 | segmenter_kwargs: DictConfig, 19 | limit_examples: int = -1, 20 | ) -> None: 21 | super().__init__("JSTS", split, tokenizer, max_seq_length, limit_examples) 22 | 23 | self.hf_dataset = self.hf_dataset.map( 24 | lambda x: { 25 | "segmented1": batch_segment(x["sentence1"], **segmenter_kwargs), # type: ignore[misc] 26 | "segmented2": batch_segment(x["sentence2"], **segmenter_kwargs), # type: ignore[misc] 27 | }, 28 | batched=True, 29 | batch_size=100, 30 | num_proc=os.cpu_count(), 31 | ).map( 32 | lambda x: self.tokenizer( 33 | x["segmented1"], 34 | x["segmented2"], 35 | padding=PaddingStrategy.MAX_LENGTH, 36 | truncation=True, 37 | max_length=self.max_seq_length, 38 | return_token_type_ids=True, 39 | ), 40 | batched=True, 41 | ) 42 | 43 | def __getitem__(self, index: int) -> SequenceClassificationFeatures: 44 | example: dict[str, Any] = self.hf_dataset[index] 45 | return SequenceClassificationFeatures( 46 | input_ids=example["input_ids"], 47 | attention_mask=example["attention_mask"], 48 | token_type_ids=example["token_type_ids"], 49 | labels=example["label"], 50 | ) 51 | -------------------------------------------------------------------------------- /src/datamodule/datasets/marc_ja.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import Any 3 | 4 | from omegaconf import DictConfig 5 | from transformers import PreTrainedTokenizerBase 6 | from transformers.utils import PaddingStrategy 7 | 8 | from datamodule.datasets.base import BaseDataset 9 | from datamodule.datasets.util import SequenceClassificationFeatures, batch_segment 10 | 11 | 12 | class MARCJaDataset(BaseDataset[SequenceClassificationFeatures]): 13 | def __init__( 14 | self, 15 | split: str, 16 | tokenizer: PreTrainedTokenizerBase, 17 | max_seq_length: int, 18 | segmenter_kwargs: DictConfig, 19 | limit_examples: int = -1, 20 | ) -> None: 21 | super().__init__("MARC-ja", split, tokenizer, max_seq_length, limit_examples) 22 | 23 | self.hf_dataset = self.hf_dataset.map( 24 | lambda x: {"segmented": batch_segment(x["sentence"], **segmenter_kwargs)}, # type: ignore[misc] 25 | batched=True, 26 | batch_size=100, 27 | num_proc=os.cpu_count(), 28 | ).map( 29 | lambda x: self.tokenizer( 30 | x["segmented"], 31 | padding=PaddingStrategy.MAX_LENGTH, 32 | truncation=True, 33 | max_length=self.max_seq_length, 34 | return_token_type_ids=True, 35 | ), 36 | batched=True, 37 | ) 38 | 39 | def __getitem__(self, index: int) -> SequenceClassificationFeatures:
40 | example: dict[str, Any] = self.hf_dataset[index] 41 | return SequenceClassificationFeatures( 42 | input_ids=example["input_ids"], 43 | attention_mask=example["attention_mask"], 44 | token_type_ids=example["token_type_ids"], 45 | labels=example["label"], 46 | ) 47 | -------------------------------------------------------------------------------- /src/datamodule/datasets/util.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Optional 3 | 4 | import jaconv 5 | from rhoknp import Jumanpp 6 | 7 | 8 | @dataclass(frozen=True) 9 | class SequenceClassificationFeatures: 10 | input_ids: list[int] 11 | attention_mask: list[int] 12 | token_type_ids: list[int] 13 | labels: int 14 | 15 | 16 | @dataclass(frozen=True) 17 | class MultipleChoiceFeatures: 18 | input_ids: list[list[int]] 19 | attention_mask: list[list[int]] 20 | token_type_ids: list[list[int]] 21 | labels: int 22 | 23 | 24 | @dataclass(frozen=True) 25 | class QuestionAnsweringFeatures: 26 | example_ids: int 27 | input_ids: list[int] 28 | attention_mask: list[int] 29 | token_type_ids: list[int] 30 | start_positions: int 31 | end_positions: int 32 | 33 | 34 | def batch_segment( 35 | texts: list[str], analyzer: Optional[str], h2z: bool = True, mecab_dic_dir: Optional[str] = None 36 | ) -> list[str]: 37 | if analyzer is None: 38 | return texts 39 | segmenter = WordSegmenter(analyzer, h2z, mecab_dic_dir) 40 | return [segmenter.get_segmented_string(text) for text in texts] 41 | 42 | 43 | class WordSegmenter: 44 | def __init__(self, analyzer: str, h2z: bool, mecab_dic_dir: Optional[str] = None) -> None: 45 | self._analyzer: str = analyzer 46 | self._h2z: bool = h2z 47 | 48 | if self._analyzer == "jumanpp": 49 | self._jumanpp = Jumanpp() 50 | elif self._analyzer == "mecab": 51 | tagger_options = [] 52 | if mecab_dic_dir is not None: 53 | tagger_options += f"-d {mecab_dic_dir}".split() 54 | import MeCab 55 | 56 | self._mecab = MeCab.Tagger(" ".join(tagger_options)) 57 | 58 | def get_words(self, string: str) -> list[str]: 59 | words: list[str] = [] 60 | 61 | if self._analyzer == "jumanpp": 62 | sentence = self._jumanpp.apply_to_sentence(string) 63 | words += [morpheme.text for morpheme in sentence.morphemes] 64 | elif self._analyzer == "mecab": 65 | self._mecab.parse("") 66 | node = self._mecab.parseToNode(string) 67 | while node: 68 | word = node.surface 69 | if node.feature.split(",")[0] != "BOS/EOS": 70 | words.append(word) 71 | node = node.next 72 | elif self._analyzer == "char": 73 | for char in string: 74 | words.append(char) 75 | else: 76 | raise NotImplementedError(f"unknown analyzer: {self._analyzer}") 77 | 78 | return words 79 | 80 | def get_segmented_string(self, string: str) -> str: 81 | if self._h2z is True: 82 | string = jaconv.h2z(string) 83 | words = self.get_words(string) 84 | return " ".join(word for word in words) 85 | -------------------------------------------------------------------------------- /src/metrics/__init__.py: -------------------------------------------------------------------------------- 1 | from metrics.jsquad import JSQuADMetric 2 | 3 | __all__ = ["JSQuADMetric"] 4 | -------------------------------------------------------------------------------- /src/metrics/jsquad.py: -------------------------------------------------------------------------------- 1 | # This file contains code adapted from transformers 2 | # (https://github.com/huggingface/transformers/blob/main/examples/flax/question-answering/utils_qa.py) 3 | # 
Copyright 2020 The HuggingFace Team All rights reserved. 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | 6 | from typing import Any 7 | 8 | import numpy as np 9 | import torch 10 | from torchmetrics import Metric 11 | from torchmetrics.text import SQuAD 12 | 13 | from datamodule.datasets import JSQuADDataset 14 | 15 | 16 | class JSQuADMetric(Metric): 17 | is_differentiable: bool = False 18 | higher_is_better: bool = True 19 | full_state_update: bool = False 20 | 21 | def __init__(self) -> None: 22 | super().__init__() 23 | self.squad = SQuAD() 24 | 25 | def update( 26 | self, 27 | example_ids: torch.Tensor, # (b) 28 | batch_start_logits: torch.Tensor, # (b, seq) 29 | batch_end_logits: torch.Tensor, # (b, seq) 30 | dataset: JSQuADDataset, 31 | ) -> None: 32 | preds = [] 33 | target = [] 34 | for example_id, start_logits, end_logits in zip( 35 | example_ids.tolist(), batch_start_logits.tolist(), batch_end_logits.tolist() 36 | ): 37 | example = dataset.hf_dataset[example_id] 38 | prediction_text: str = _postprocess_predictions(start_logits, end_logits, example) 39 | preds.append( 40 | { 41 | "prediction_text": self._postprocess_text(prediction_text), 42 | "id": example_id, 43 | } 44 | ) 45 | target.append( 46 | { 47 | "answers": { 48 | "text": [self._postprocess_text(answer["text"]) for answer in example["answers"]], 49 | "answer_start": [answer["answer_start"] for answer in example["answers"]], 50 | }, 51 | "id": example_id, 52 | } 53 | ) 54 | self.squad.update(preds, target) 55 | 56 | def compute(self) -> dict[str, torch.Tensor]: 57 | return {k: v / 100.0 for k, v in self.squad.compute().items()} 58 | 59 | @staticmethod 60 | def _postprocess_text(text: str) -> str: 61 | """Strip the trailing period and split the text into characters""" 62 | return " ".join(text.replace(" ", "").rstrip("。")) 63 | 64 | 65 | def _postprocess_predictions( 66 | start_logits: list[float], 67 | end_logits: list[float], 68 | example: dict[str, Any], 69 | n_best_size: int = 20, 70 | max_answer_length: int = 30, 71 | ) -> str: 72 | """ 73 | Post-processes the predictions of a question-answering model to convert them to answers that are substrings of the 74 | original contexts. This is the base postprocessing function for models that only return start and end logits. 75 | 76 | Args: 77 | start_logits (:obj:`List[float]`): 78 | The logits corresponding to the start of the span for each token. 79 | end_logits (:obj:`List[float]`): 80 | The logits corresponding to the end of the span for each token. 81 | example: The processed example. 82 | n_best_size (:obj:`int`, `optional`, defaults to 20): 83 | The total number of n-best predictions to generate when looking for an answer. 84 | max_answer_length (:obj:`int`, `optional`, defaults to 30): 85 | The maximum length of an answer that can be generated. This is needed because the start and end predictions 86 | are not conditioned on one another. 87 | """ 88 | prelim_predictions = [] 89 | 90 | # This is what will allow us to map some of the positions in our logits to spans of text in the original 91 | # context. 92 | offset_mapping = example["offset_mapping"] 93 | # Optional `token_is_max_context`, if provided we will remove answers that do not have the maximum context 94 | # available in the current feature. 95 | token_is_max_context = example.get("token_is_max_context") 96 | 97 | # Go through all possibilities for the `n_best_size` greatest start and end logits.
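# (np.argsort returns indices in ascending order of logit, so the
# [-1 : -n_best_size - 1 : -1] slice below yields the top n_best_size indices,
# highest logit first.)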
start_indexes = np.argsort(start_logits)[-1 : -n_best_size - 1 : -1].tolist() 99 | end_indexes = np.argsort(end_logits)[-1 : -n_best_size - 1 : -1].tolist() 100 | for start_index in start_indexes: 101 | for end_index in end_indexes: 102 | # Don't consider out-of-scope answers, either because the indices are out of bounds or correspond 103 | # to part of the input_ids that are not in the context. 104 | if ( 105 | start_index >= len(offset_mapping) 106 | or end_index >= len(offset_mapping) 107 | or offset_mapping[start_index] is None 108 | or len(offset_mapping[start_index]) < 2 109 | or offset_mapping[end_index] is None 110 | or len(offset_mapping[end_index]) < 2 111 | ): 112 | continue 113 | # Don't consider answers with a length that is either < 0 or > max_answer_length. 114 | if end_index < start_index or end_index - start_index + 1 > max_answer_length: 115 | continue 116 | # Don't consider answers that don't have the maximum context available (if such information is 117 | # provided). 118 | if token_is_max_context is not None and not token_is_max_context.get(str(start_index), False): 119 | continue 120 | 121 | prelim_predictions.append( 122 | { 123 | "offsets": (offset_mapping[start_index][0], offset_mapping[end_index][1]), 124 | "score": start_logits[start_index] + end_logits[end_index], 125 | "start_logit": start_logits[start_index], 126 | "end_logit": end_logits[end_index], 127 | } 128 | ) 129 | 130 | # Only keep the best `n_best_size` predictions. 131 | predictions: list[dict[str, Any]] = sorted(prelim_predictions, key=lambda x: x["score"], reverse=True)[:n_best_size] 132 | # Use the offsets to gather the answer text in the original context. 133 | context = example["context"] 134 | for pred in predictions: 135 | offsets = pred.pop("offsets") 136 | pred["text"] = context[offsets[0] : offsets[1]] 137 | # In the very rare edge case where we don't have a single non-null prediction, we create a fake prediction to avoid 138 | # failure. 139 | if len(predictions) == 0 or (len(predictions) == 1 and predictions[0]["text"] == ""): 140 | predictions.insert(0, {"text": "empty", "start_logit": 0.0, "end_logit": 0.0, "score": 0.0}) 141 | # Compute the softmax of all scores (we do it with numpy to stay independent of torch/tf in this file, using 142 | # the LogSumExp trick). 143 | scores = np.array([pred.pop("score") for pred in predictions]) 144 | exp_scores = np.exp(scores - np.max(scores)) 145 | probs = exp_scores / exp_scores.sum() 146 | # Include the probabilities in our predictions.
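# (They are attached for inspection only; the returned answer is simply the text
# of the highest-scoring prediction.)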
147 | for prob, pred in zip(probs, predictions): 148 | pred["probability"] = prob 149 | 150 | return predictions[0]["text"] # return top 1 prediction 151 | -------------------------------------------------------------------------------- /src/modules/__init__.py: -------------------------------------------------------------------------------- 1 | from modules.jcola import JCoLAModule 2 | from modules.jcqa import JCommonsenseQAModule 3 | from modules.jnli import JNLIModule 4 | from modules.jsquad import JSQuADModule 5 | from modules.jsts import JSTSModule 6 | from modules.marc_ja import MARCJaModule 7 | 8 | __all__ = ["JCoLAModule", "JCommonsenseQAModule", "JNLIModule", "JSQuADModule", "JSTSModule", "MARCJaModule"] 9 | -------------------------------------------------------------------------------- /src/modules/base.py: -------------------------------------------------------------------------------- 1 | import copy 2 | from typing import Any 3 | 4 | import hydra 5 | from lightning import LightningModule 6 | from lightning.pytorch.utilities.types import OptimizerLRScheduler 7 | from omegaconf import DictConfig, OmegaConf 8 | 9 | 10 | class BaseModule(LightningModule): 11 | def __init__(self, hparams: DictConfig) -> None: 12 | super().__init__() 13 | self.save_hyperparameters(hparams) 14 | 15 | def configure_optimizers(self) -> OptimizerLRScheduler: 16 | # Split weights in two groups, one with weight decay and the other not. 17 | no_decay = ("bias", "LayerNorm.weight") 18 | optimizer_grouped_parameters = [ 19 | { 20 | "params": [ 21 | p for n, p in self.named_parameters() if not any(nd in n for nd in no_decay) and p.requires_grad 22 | ], 23 | "weight_decay": self.hparams.optimizer.weight_decay, 24 | "name": "decay", 25 | }, 26 | { 27 | "params": [ 28 | p for n, p in self.named_parameters() if any(nd in n for nd in no_decay) and p.requires_grad 29 | ], 30 | "weight_decay": 0.0, 31 | "name": "no_decay", 32 | }, 33 | ] 34 | optimizer = hydra.utils.instantiate( 35 | self.hparams.optimizer, params=optimizer_grouped_parameters, _convert_="partial" 36 | ) 37 | total_steps = self.trainer.estimated_stepping_batches 38 | warmup_steps = self.hparams.warmup_steps or total_steps * self.hparams.warmup_ratio 39 | if hasattr(self.hparams.scheduler, "num_warmup_steps"): 40 | self.hparams.scheduler.num_warmup_steps = warmup_steps 41 | if hasattr(self.hparams.scheduler, "num_training_steps"): 42 | self.hparams.scheduler.num_training_steps = total_steps 43 | lr_scheduler = hydra.utils.instantiate(self.hparams.scheduler, optimizer=optimizer) 44 | return {"optimizer": optimizer, "lr_scheduler": {"scheduler": lr_scheduler, "interval": "step", "frequency": 1}} 45 | 46 | def on_save_checkpoint(self, checkpoint: dict[str, Any]) -> None: 47 | hparams: DictConfig = copy.deepcopy(checkpoint["hyper_parameters"]) 48 | OmegaConf.set_struct(hparams, value=False) 49 | checkpoint["hyper_parameters"] = hparams 50 | -------------------------------------------------------------------------------- /src/modules/jcola.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | 3 | import torch 4 | from omegaconf import DictConfig 5 | from torchmetrics.classification import MulticlassAccuracy 6 | from transformers import AutoConfig, AutoModelForSequenceClassification 7 | from transformers.modeling_outputs import SequenceClassifierOutput 8 | 9 | from modules.base import BaseModule 10 | 11 | 12 | class JCoLAModule(BaseModule): 13 | def __init__(self, hparams: DictConfig) -> None: 
14 | super().__init__(hparams) 15 | config = AutoConfig.from_pretrained( 16 | hparams.model.model_name_or_path, 17 | num_labels=2, 18 | finetuning_task="JCoLA", 19 | ) 20 | self.model = AutoModelForSequenceClassification.from_pretrained( 21 | hparams.model.model_name_or_path, 22 | config=config, 23 | ) 24 | self.metric = MulticlassAccuracy(num_classes=2, average="micro") 25 | 26 | def forward(self, batch: dict[str, Any]) -> SequenceClassifierOutput: 27 | return self.model(**batch) 28 | 29 | def training_step(self, batch: Any) -> torch.Tensor: 30 | out: SequenceClassifierOutput = self(batch) 31 | self.log("train/loss", out.loss) 32 | return out.loss 33 | 34 | def validation_step(self, batch: Any) -> None: 35 | out: SequenceClassifierOutput = self(batch) 36 | predictions = torch.argmax(out.logits, dim=1) # (b) 37 | self.metric.update(predictions, batch["labels"]) 38 | 39 | def on_validation_epoch_end(self) -> None: 40 | self.log("valid/accuracy", self.metric.compute()) 41 | self.metric.reset() 42 | 43 | def test_step(self, batch: Any) -> None: 44 | out: SequenceClassifierOutput = self(batch) 45 | predictions = torch.argmax(out.logits, dim=1) # (b) 46 | self.metric.update(predictions, batch["labels"]) 47 | 48 | def on_test_epoch_end(self) -> None: 49 | self.log("test/accuracy", self.metric.compute()) 50 | self.metric.reset() 51 | -------------------------------------------------------------------------------- /src/modules/jcqa.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | 3 | import torch 4 | from omegaconf import DictConfig 5 | from torchmetrics.classification import MulticlassAccuracy 6 | from transformers import AutoConfig, AutoModelForMultipleChoice 7 | from transformers.modeling_outputs import MultipleChoiceModelOutput 8 | 9 | from datamodule.datasets.jcqa import NUM_CHOICES 10 | from modules.base import BaseModule 11 | 12 | 13 | class JCommonsenseQAModule(BaseModule): 14 | def __init__(self, hparams: DictConfig) -> None: 15 | super().__init__(hparams) 16 | config = AutoConfig.from_pretrained( 17 | hparams.model.model_name_or_path, 18 | num_labels=NUM_CHOICES, 19 | finetuning_task="JCommonsenseQA", 20 | ) 21 | self.model = AutoModelForMultipleChoice.from_pretrained( 22 | hparams.model.model_name_or_path, 23 | config=config, 24 | ) 25 | self.metric = MulticlassAccuracy(num_classes=NUM_CHOICES, average="micro") 26 | 27 | def forward(self, batch: dict[str, Any]) -> MultipleChoiceModelOutput: 28 | return self.model(**batch) 29 | 30 | def training_step(self, batch: Any) -> torch.Tensor: 31 | out: MultipleChoiceModelOutput = self(batch) 32 | self.log("train/loss", out.loss) 33 | return out.loss 34 | 35 | def validation_step(self, batch: Any) -> None: 36 | out: MultipleChoiceModelOutput = self(batch) 37 | predictions = torch.argmax(out.logits, dim=1) # (b) 38 | self.metric.update(predictions, batch["labels"]) 39 | 40 | def on_validation_epoch_end(self) -> None: 41 | self.log("valid/accuracy", self.metric.compute()) 42 | self.metric.reset() 43 | 44 | def test_step(self, batch: Any) -> None: 45 | out: MultipleChoiceModelOutput = self(batch) 46 | predictions = torch.argmax(out.logits, dim=1) # (b) 47 | self.metric.update(predictions, batch["labels"]) 48 | 49 | def on_test_epoch_end(self) -> None: 50 | self.log("test/accuracy", self.metric.compute()) 51 | self.metric.reset() 52 | -------------------------------------------------------------------------------- /src/modules/jnli.py: 
-------------------------------------------------------------------------------- 1 | from typing import Any 2 | 3 | import torch 4 | from omegaconf import DictConfig 5 | from torchmetrics.classification import MulticlassAccuracy 6 | from transformers import AutoConfig, AutoModelForSequenceClassification, PretrainedConfig, PreTrainedModel 7 | from transformers.modeling_outputs import SequenceClassifierOutput 8 | 9 | from modules.base import BaseModule 10 | 11 | 12 | class JNLIModule(BaseModule): 13 | def __init__(self, hparams: DictConfig) -> None: 14 | super().__init__(hparams) 15 | config: PretrainedConfig = AutoConfig.from_pretrained( 16 | hparams.model.model_name_or_path, 17 | num_labels=3, 18 | finetuning_task="JNLI", 19 | ) 20 | self.model: PreTrainedModel = AutoModelForSequenceClassification.from_pretrained( 21 | hparams.model.model_name_or_path, 22 | config=config, 23 | ) 24 | self.metric = MulticlassAccuracy(num_classes=3, average="micro") 25 | 26 | def forward(self, batch: dict[str, Any]) -> SequenceClassifierOutput: 27 | return self.model(**batch) 28 | 29 | def training_step(self, batch: Any) -> torch.Tensor: 30 | out: SequenceClassifierOutput = self(batch) 31 | self.log("train/loss", out.loss) 32 | return out.loss 33 | 34 | def validation_step(self, batch: Any) -> None: 35 | out: SequenceClassifierOutput = self(batch) 36 | predictions = torch.argmax(out.logits, dim=1) # (b) 37 | self.metric.update(predictions, batch["labels"]) 38 | 39 | def on_validation_epoch_end(self) -> None: 40 | self.log("valid/accuracy", self.metric.compute()) 41 | self.metric.reset() 42 | 43 | def test_step(self, batch: Any) -> None: 44 | out: SequenceClassifierOutput = self(batch) 45 | predictions = torch.argmax(out.logits, dim=1) # (b) 46 | self.metric.update(predictions, batch["labels"]) 47 | 48 | def on_test_epoch_end(self) -> None: 49 | self.log("test/accuracy", self.metric.compute()) 50 | self.metric.reset() 51 | -------------------------------------------------------------------------------- /src/modules/jsquad.py: -------------------------------------------------------------------------------- 1 | from typing import TYPE_CHECKING, ClassVar 2 | 3 | import torch 4 | from omegaconf import DictConfig 5 | from transformers import AutoConfig, AutoModelForQuestionAnswering, PretrainedConfig, PreTrainedModel 6 | from transformers.modeling_outputs import QuestionAnsweringModelOutput 7 | 8 | from metrics import JSQuADMetric 9 | from modules.base import BaseModule 10 | 11 | if TYPE_CHECKING: 12 | from datamodule.datasets.jsquad import JSQuADDataset 13 | 14 | 15 | class JSQuADModule(BaseModule): 16 | MODEL_ARGS: ClassVar[list[str]] = [ 17 | "input_ids", 18 | "attention_mask", 19 | "token_type_ids", 20 | "start_positions", 21 | "end_positions", 22 | ] 23 | 24 | def __init__(self, hparams: DictConfig) -> None: 25 | super().__init__(hparams) 26 | config: PretrainedConfig = AutoConfig.from_pretrained( 27 | hparams.model.model_name_or_path, 28 | finetuning_task="JSQuAD", 29 | ) 30 | self.model: PreTrainedModel = AutoModelForQuestionAnswering.from_pretrained( 31 | hparams.model.model_name_or_path, 32 | config=config, 33 | ) 34 | self.metric = JSQuADMetric() 35 | 36 | def forward(self, batch: dict[str, torch.Tensor]) -> QuestionAnsweringModelOutput: 37 | return self.model(**{k: v for k, v in batch.items() if k in self.MODEL_ARGS}) 38 | 39 | def training_step(self, batch: dict[str, torch.Tensor]) -> torch.Tensor: 40 | out: QuestionAnsweringModelOutput = self(batch) 41 | self.log("train/loss", out.loss) 42 | return 
out.loss 43 | 44 | def validation_step(self, batch: dict[str, torch.Tensor]) -> None: 45 | out: QuestionAnsweringModelOutput = self(batch) 46 | dataset: JSQuADDataset = self.trainer.val_dataloaders.dataset 47 | self.metric.update(batch["example_ids"], out.start_logits, out.end_logits, dataset) 48 | 49 | def on_validation_epoch_end(self) -> None: 50 | self.log_dict({f"valid/{key}": value for key, value in self.metric.compute().items()}) 51 | self.metric.reset() 52 | 53 | def test_step(self, batch: dict[str, torch.Tensor]) -> None: 54 | out: QuestionAnsweringModelOutput = self(batch) 55 | dataset: JSQuADDataset = self.trainer.test_dataloaders.dataset 56 | self.metric.update(batch["example_ids"], out.start_logits, out.end_logits, dataset) 57 | 58 | def on_test_epoch_end(self) -> None: 59 | self.log_dict({f"test/{key}": value for key, value in self.metric.compute().items()}) 60 | self.metric.reset() 61 | -------------------------------------------------------------------------------- /src/modules/jsts.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | 3 | import torch 4 | from omegaconf import DictConfig 5 | from torchmetrics import MetricCollection, PearsonCorrCoef, SpearmanCorrCoef 6 | from transformers import AutoConfig, AutoModelForSequenceClassification 7 | from transformers.modeling_outputs import SequenceClassifierOutput 8 | 9 | from modules.base import BaseModule 10 | 11 | 12 | class JSTSModule(BaseModule): 13 | def __init__(self, hparams: DictConfig) -> None: 14 | super().__init__(hparams) 15 | config = AutoConfig.from_pretrained( 16 | hparams.model.model_name_or_path, 17 | num_labels=1, 18 | finetuning_task="JSTS", 19 | ) 20 | self.model = AutoModelForSequenceClassification.from_pretrained( 21 | hparams.model.model_name_or_path, 22 | config=config, 23 | ) 24 | self.metric = MetricCollection({"spearman": SpearmanCorrCoef(), "pearson": PearsonCorrCoef()}) 25 | 26 | def forward(self, batch: dict[str, Any]) -> SequenceClassifierOutput: 27 | return self.model(**batch) 28 | 29 | def training_step(self, batch: Any) -> torch.Tensor: 30 | out: SequenceClassifierOutput = self(batch) 31 | self.log("train/loss", out.loss) 32 | return out.loss 33 | 34 | def validation_step(self, batch: Any) -> None: 35 | out: SequenceClassifierOutput = self(batch) 36 | predictions = torch.squeeze(out.logits, dim=-1) # (b) 37 | self.metric.update(predictions, batch["labels"]) 38 | 39 | def on_validation_epoch_end(self) -> None: 40 | self.log_dict({f"valid/{key}": value for key, value in self.metric.compute().items()}) 41 | self.metric.reset() 42 | 43 | def test_step(self, batch: Any) -> None: 44 | out: SequenceClassifierOutput = self(batch) 45 | predictions = torch.squeeze(out.logits, dim=-1) # (b) 46 | self.metric.update(predictions, batch["labels"]) 47 | 48 | def on_test_epoch_end(self) -> None: 49 | self.log_dict({f"test/{key}": value for key, value in self.metric.compute().items()}) 50 | self.metric.reset() 51 | -------------------------------------------------------------------------------- /src/modules/marc_ja.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | 3 | import torch 4 | from omegaconf import DictConfig 5 | from torchmetrics.classification import MulticlassAccuracy 6 | from transformers import AutoConfig, AutoModelForSequenceClassification 7 | from transformers.modeling_outputs import SequenceClassifierOutput 8 | 9 | from modules.base import BaseModule 10 | 11 | 12 | 
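# MARC-ja, as used in JGLUE, casts Japanese Amazon product reviews as a binary
# sentiment classification task (positive/negative), hence num_labels=2 below;
# apart from the task name and label count, this module mirrors the other
# sequence-classification modules (JCoLA, JNLI).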
class MARCJaModule(BaseModule): 13 | def __init__(self, hparams: DictConfig) -> None: 14 | super().__init__(hparams) 15 | config = AutoConfig.from_pretrained( 16 | hparams.model.model_name_or_path, 17 | num_labels=2, 18 | finetuning_task="MARC-ja", 19 | ) 20 | self.model = AutoModelForSequenceClassification.from_pretrained( 21 | hparams.model.model_name_or_path, 22 | config=config, 23 | ) 24 | self.metric = MulticlassAccuracy(num_classes=2, average="micro") 25 | 26 | def forward(self, batch: dict[str, Any]) -> SequenceClassifierOutput: 27 | return self.model(**batch) 28 | 29 | def training_step(self, batch: Any) -> torch.Tensor: 30 | out: SequenceClassifierOutput = self(batch) 31 | self.log("train/loss", out.loss) 32 | return out.loss 33 | 34 | def validation_step(self, batch: Any) -> None: 35 | out: SequenceClassifierOutput = self(batch) 36 | predictions = torch.argmax(out.logits, dim=1) # (b) 37 | self.metric.update(predictions, batch["labels"]) 38 | 39 | def on_validation_epoch_end(self) -> None: 40 | self.log("valid/accuracy", self.metric.compute()) 41 | self.metric.reset() 42 | 43 | def test_step(self, batch: Any) -> None: 44 | out: SequenceClassifierOutput = self(batch) 45 | predictions = torch.argmax(out.logits, dim=1) # (b) 46 | self.metric.update(predictions, batch["labels"]) 47 | 48 | def on_test_epoch_end(self) -> None: 49 | self.log("test/accuracy", self.metric.compute()) 50 | self.metric.reset() 51 | -------------------------------------------------------------------------------- /src/test.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import warnings 3 | from typing import TYPE_CHECKING, Union 4 | 5 | import hydra 6 | import torch 7 | import transformers.utils.logging as hf_logging 8 | from lightning.pytorch.trainer.states import TrainerFn 9 | from lightning.pytorch.utilities.warnings import PossibleUserWarning 10 | from omegaconf import DictConfig, ListConfig, OmegaConf 11 | 12 | from datamodule.datamodule import DataModule 13 | 14 | if TYPE_CHECKING: 15 | from lightning import Callback, LightningModule, Trainer 16 | from lightning.pytorch.loggers import Logger 17 | 18 | hf_logging.set_verbosity(hf_logging.ERROR) 19 | warnings.filterwarnings( 20 | "ignore", 21 | message=r"It is recommended to use .+ when logging on epoch level in distributed setting to accumulate the metric" 22 | r" across devices", 23 | category=PossibleUserWarning, 24 | ) 25 | logging.getLogger("torch").setLevel(logging.WARNING) 26 | 27 | 28 | @hydra.main(version_base=None, config_path="../configs", config_name="eval") 29 | def main(eval_cfg: DictConfig) -> None: 30 | if isinstance(eval_cfg.devices, str): 31 | eval_cfg.devices = ( 32 | list(map(int, eval_cfg.devices.split(","))) if "," in eval_cfg.devices else int(eval_cfg.devices) 33 | ) 34 | if isinstance(eval_cfg.max_batches_per_device, str): 35 | eval_cfg.max_batches_per_device = int(eval_cfg.max_batches_per_device) 36 | if isinstance(eval_cfg.num_workers, str): 37 | eval_cfg.num_workers = int(eval_cfg.num_workers) 38 | 39 | # Load saved model and configs 40 | model: LightningModule = hydra.utils.call(eval_cfg.module.load_from_checkpoint, _recursive_=False) 41 | if eval_cfg.compile is True: 42 | model = torch.compile(model) 43 | 44 | train_cfg: DictConfig = model.hparams 45 | OmegaConf.set_struct(train_cfg, value=False) # enable to add new key-value pairs 46 | cfg = OmegaConf.merge(train_cfg, eval_cfg) 47 | assert isinstance(cfg, DictConfig) 48 | 49 | logger: Union[Logger, bool] = 
cfg.get("logger", False) and hydra.utils.instantiate(cfg.get("logger")) 50 | callbacks: list[Callback] = list(map(hydra.utils.instantiate, cfg.get("callbacks", {}).values())) 51 | 52 | num_devices: int = 1 53 | if isinstance(cfg.devices, (list, ListConfig)): 54 | num_devices = len(cfg.devices) 55 | elif isinstance(cfg.devices, int): 56 | num_devices = cfg.devices 57 | cfg.effective_batch_size = cfg.max_batches_per_device * num_devices 58 | cfg.datamodule.batch_size = cfg.max_batches_per_device 59 | 60 | trainer: Trainer = hydra.utils.instantiate( 61 | cfg.trainer, 62 | logger=logger, 63 | callbacks=callbacks, 64 | devices=cfg.devices, 65 | ) 66 | 67 | datamodule = DataModule(cfg=cfg.datamodule) 68 | datamodule.setup(stage=TrainerFn.TESTING) 69 | if cfg.eval_set == "test": 70 | dataloader = datamodule.test_dataloader() 71 | elif cfg.eval_set == "valid": 72 | dataloader = datamodule.val_dataloader() 73 | else: 74 | raise ValueError(f"invalid eval_set: {cfg.eval_set}") 75 | trainer.test(model=model, dataloaders=dataloader) 76 | 77 | 78 | if __name__ == "__main__": 79 | main() 80 | -------------------------------------------------------------------------------- /src/train.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import math 3 | import warnings 4 | from typing import TYPE_CHECKING, Union 5 | 6 | import hydra 7 | import torch 8 | import transformers.utils.logging as hf_logging 9 | from lightning import Callback, LightningModule, Trainer, seed_everything 10 | from lightning.pytorch.utilities.warnings import PossibleUserWarning 11 | from omegaconf import DictConfig, ListConfig 12 | 13 | from datamodule.datamodule import DataModule 14 | 15 | if TYPE_CHECKING: 16 | from lightning.pytorch.loggers import Logger 17 | 18 | hf_logging.set_verbosity(hf_logging.ERROR) 19 | warnings.filterwarnings( 20 | "ignore", 21 | message=r"It is recommended to use .+ when logging on epoch level in distributed setting to accumulate the metric" 22 | r" across devices", 23 | category=PossibleUserWarning, 24 | ) 25 | logging.getLogger("torch").setLevel(logging.WARNING) 26 | 27 | 28 | @hydra.main(version_base=None, config_path="../configs") 29 | def main(cfg: DictConfig) -> None: 30 | if isinstance(cfg.devices, str): 31 | cfg.devices = list(map(int, cfg.devices.split(","))) if "," in cfg.devices else int(cfg.devices) 32 | if isinstance(cfg.max_batches_per_device, str): 33 | cfg.max_batches_per_device = int(cfg.max_batches_per_device) 34 | if isinstance(cfg.num_workers, str): 35 | cfg.num_workers = int(cfg.num_workers) 36 | cfg.seed = seed_everything(seed=cfg.seed, workers=True) 37 | 38 | logger: Union[Logger, bool] = cfg.get("logger", False) and hydra.utils.instantiate(cfg.get("logger")) 39 | callbacks: list[Callback] = list(map(hydra.utils.instantiate, cfg.get("callbacks", {}).values())) 40 | 41 | # Calculate gradient_accumulation_steps assuming DDP 42 | num_devices: int = 1 43 | if isinstance(cfg.devices, (list, ListConfig)): 44 | num_devices = len(cfg.devices) 45 | elif isinstance(cfg.devices, int): 46 | num_devices = cfg.devices 47 | cfg.trainer.accumulate_grad_batches = math.ceil( 48 | cfg.effective_batch_size / (cfg.max_batches_per_device * num_devices) 49 | ) 50 | batches_per_device = cfg.effective_batch_size // (num_devices * cfg.trainer.accumulate_grad_batches) 51 | # if effective_batch_size % (accumulate_grad_batches * num_devices) != 0, then 52 | # cause an error of at most accumulate_grad_batches * num_devices compared in effective_batch_size 53 
| # otherwise, no error 54 | cfg.effective_batch_size = batches_per_device * num_devices * cfg.trainer.accumulate_grad_batches 55 | cfg.datamodule.batch_size = batches_per_device 56 | 57 | trainer: Trainer = hydra.utils.instantiate( 58 | cfg.trainer, 59 | logger=logger, 60 | callbacks=callbacks, 61 | devices=cfg.devices, 62 | ) 63 | 64 | datamodule = DataModule(cfg=cfg.datamodule) 65 | 66 | model: LightningModule = hydra.utils.instantiate(cfg.module.cls, hparams=cfg, _recursive_=False) 67 | if cfg.compile is True: 68 | model = torch.compile(model) 69 | 70 | trainer.fit(model=model, datamodule=datamodule) 71 | trainer.test(model=model, datamodule=datamodule, ckpt_path="best" if not trainer.fast_dev_run else None) 72 | 73 | 74 | if __name__ == "__main__": 75 | main() 76 | -------------------------------------------------------------------------------- /sweeps/jcola.yaml: -------------------------------------------------------------------------------- 1 | project: JGLUE-evaluation-scripts 2 | name: jcola-MODEL_NAME 3 | program: src/train.py 4 | method: grid 5 | metric: 6 | name: valid/accuracy 7 | goal: maximize 8 | parameters: 9 | lr: 10 | values: [0.00002, 0.00003, 0.00005] 11 | max_epochs: 12 | values: [3, 4] 13 | command: 14 | - ${env} 15 | - ${interpreter} 16 | - ${program} 17 | - "-cn" 18 | - "jcola" 19 | - "model=MODEL_NAME" 20 | - ${args_no_hyphens} 21 | -------------------------------------------------------------------------------- /sweeps/jcqa.yaml: -------------------------------------------------------------------------------- 1 | project: JGLUE-evaluation-scripts 2 | name: jcqa-MODEL_NAME 3 | program: src/train.py 4 | method: grid 5 | metric: 6 | name: valid/accuracy 7 | goal: maximize 8 | parameters: 9 | lr: 10 | values: [0.00002, 0.00003, 0.00005] 11 | max_epochs: 12 | values: [3, 4] 13 | command: 14 | - ${env} 15 | - ${interpreter} 16 | - ${program} 17 | - "-cn" 18 | - "jcqa" 19 | - "model=MODEL_NAME" 20 | - ${args_no_hyphens} 21 | -------------------------------------------------------------------------------- /sweeps/jnli.yaml: -------------------------------------------------------------------------------- 1 | project: JGLUE-evaluation-scripts 2 | name: jnli-MODEL_NAME 3 | program: src/train.py 4 | method: grid 5 | metric: 6 | name: valid/accuracy 7 | goal: maximize 8 | parameters: 9 | lr: 10 | values: [0.00002, 0.00003, 0.00005] 11 | max_epochs: 12 | values: [3, 4] 13 | command: 14 | - ${env} 15 | - ${interpreter} 16 | - ${program} 17 | - "-cn" 18 | - "jnli" 19 | - "model=MODEL_NAME" 20 | - ${args_no_hyphens} 21 | -------------------------------------------------------------------------------- /sweeps/jsquad.yaml: -------------------------------------------------------------------------------- 1 | project: JGLUE-evaluation-scripts 2 | name: jsquad-MODEL_NAME 3 | program: src/train.py 4 | method: grid 5 | metric: 6 | name: valid/f1 7 | goal: maximize 8 | parameters: 9 | lr: 10 | values: [0.00002, 0.00003, 0.00005] 11 | max_epochs: 12 | values: [3, 4] 13 | command: 14 | - ${env} 15 | - ${interpreter} 16 | - ${program} 17 | - "-cn" 18 | - "jsquad" 19 | - "model=MODEL_NAME" 20 | - ${args_no_hyphens} 21 | -------------------------------------------------------------------------------- /sweeps/jsts.yaml: -------------------------------------------------------------------------------- 1 | project: JGLUE-evaluation-scripts 2 | name: jsts-MODEL_NAME 3 | program: src/train.py 4 | method: grid 5 | metric: 6 | name: valid/spearman 7 | goal: maximize 8 | parameters: 9 | lr: 10 | 
values: [0.00002, 0.00003, 0.00005] 11 | max_epochs: 12 | values: [3, 4] 13 | command: 14 | - ${env} 15 | - ${interpreter} 16 | - ${program} 17 | - "-cn" 18 | - "jsts" 19 | - "model=MODEL_NAME" 20 | - ${args_no_hyphens} 21 | -------------------------------------------------------------------------------- /sweeps/marc_ja.yaml: -------------------------------------------------------------------------------- 1 | project: JGLUE-evaluation-scripts 2 | name: marc_ja-MODEL_NAME 3 | program: src/train.py 4 | method: grid 5 | metric: 6 | name: valid/accuracy 7 | goal: maximize 8 | parameters: 9 | lr: 10 | values: [0.00002, 0.00003, 0.00005] 11 | max_epochs: 12 | values: [3, 4] 13 | command: 14 | - ${env} 15 | - ${interpreter} 16 | - ${program} 17 | - "-cn" 18 | - "marc_ja" 19 | - "model=MODEL_NAME" 20 | - ${args_no_hyphens} 21 | -------------------------------------------------------------------------------- /tests/datasets/conftest.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from pathlib import Path 3 | 4 | import pytest 5 | from transformers import AutoTokenizer, PreTrainedTokenizerBase 6 | 7 | sys.path.append(str(Path(__file__).parent.parent.parent / "src")) 8 | 9 | 10 | @pytest.fixture 11 | def tokenizer() -> PreTrainedTokenizerBase: 12 | return AutoTokenizer.from_pretrained("ku-nlp/deberta-v2-tiny-japanese") 13 | 14 | 15 | @pytest.fixture 16 | def deberta_v3_tokenizer() -> PreTrainedTokenizerBase: 17 | return AutoTokenizer.from_pretrained("ku-nlp/deberta-v3-base-japanese") 18 | -------------------------------------------------------------------------------- /tests/datasets/test_jsquad.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | 3 | from datasets import Dataset as HFDataset # type: ignore[attr-defined] 4 | from datasets import load_dataset # type: ignore[attr-defined] 5 | from omegaconf import DictConfig 6 | from transformers import DebertaV2TokenizerFast, PreTrainedTokenizerBase 7 | 8 | from datamodule.datasets.jsquad import JSQuADDataset 9 | 10 | 11 | def test_init(tokenizer: PreTrainedTokenizerBase) -> None: 12 | _ = JSQuADDataset( 13 | "train", tokenizer, max_seq_length=128, segmenter_kwargs=DictConfig({"analyzer": None}), limit_examples=3 14 | ) 15 | 16 | 17 | def test_raw_examples() -> None: 18 | dataset: HFDataset = load_dataset("shunk031/JGLUE", name="JSQuAD", split="validation", trust_remote_code=True) 19 | for example in dataset: 20 | assert isinstance(example["id"], str) 21 | assert isinstance(example["title"], str) 22 | assert isinstance(example["context"], str) 23 | assert isinstance(example["question"], str) 24 | assert isinstance(example["answers"], dict) 25 | texts = example["answers"]["text"] 26 | answer_starts = example["answers"]["answer_start"] 27 | for text, answer_start in zip(texts, answer_starts): 28 | assert example["context"][answer_start:].startswith(text) 29 | assert example["is_impossible"] is False 30 | 31 | 32 | def test_examples(tokenizer: PreTrainedTokenizerBase) -> None: 33 | max_seq_length = 128 34 | dataset = JSQuADDataset( 35 | "validation", tokenizer, max_seq_length, segmenter_kwargs=DictConfig({"analyzer": "jumanpp"}), limit_examples=10 36 | ) 37 | for example in dataset.hf_dataset: 38 | for answer in example["answers"]: 39 | if answer["answer_start"] == -1: 40 | continue 41 | assert example["context"][answer["answer_start"] :].startswith(answer["text"]) 42 | 43 | 44 | def test_getitem(tokenizer: PreTrainedTokenizerBase) -> 
None: 45 | max_seq_length = 128 46 | dataset = JSQuADDataset( 47 | "train", tokenizer, max_seq_length, segmenter_kwargs=DictConfig({"analyzer": "jumanpp"}), limit_examples=3 48 | ) 49 | for i in range(len(dataset)): 50 | feature = dataset[i] 51 | assert len(feature.input_ids) == max_seq_length 52 | assert len(feature.attention_mask) == max_seq_length 53 | assert len(feature.token_type_ids) == max_seq_length 54 | assert isinstance(feature.start_positions, int) 55 | assert isinstance(feature.end_positions, int) 56 | 57 | 58 | def test_features_0_pretokenized(tokenizer: PreTrainedTokenizerBase) -> None: 59 | max_seq_length = 128 60 | dataset = JSQuADDataset( 61 | "validation", tokenizer, max_seq_length, segmenter_kwargs=DictConfig({"analyzer": "jumanpp"}), limit_examples=1 62 | ) 63 | example: dict[str, Any] = dict( 64 | id="a10336p0q0", 65 | title="梅雨", 66 | context="梅雨 [SEP] 梅雨 ( つゆ 、 ばい う ) は 、 北海道 と 小笠原 諸島 を 除く 日本 、 朝鮮 半島 南部 、 中国 の 南部 から 長江 流域 に かけて の 沿海 部 、 および 台湾 など 、 東 アジア の 広範囲に おいて み られる 特有の 気象 現象 で 、 5 月 から 7 月 に かけて 来る 曇り や 雨 の 多い 期間 の こと 。 雨季 の 一種 である 。", 67 | question="日本 で 梅雨 が ない の は 北海道 と どこ か 。", 68 | answers=[ 69 | dict(text="小笠原 諸島", answer_start=35), 70 | dict(text="小笠原 諸島 を 除く 日本", answer_start=35), 71 | dict(text="小笠原 諸島", answer_start=35), 72 | ], 73 | is_impossible=False, 74 | ) 75 | features = dataset[0] 76 | question_tokens: list[str] = tokenizer.tokenize(example["question"]) 77 | context_tokens: list[str] = tokenizer.tokenize(example["context"]) 78 | input_tokens = [tokenizer.cls_token, *question_tokens, tokenizer.sep_token, *context_tokens, tokenizer.sep_token] 79 | padded_input_tokens = input_tokens + [tokenizer.pad_token] * (max_seq_length - len(input_tokens)) 80 | assert features.input_ids == tokenizer.convert_tokens_to_ids(padded_input_tokens) 81 | assert features.attention_mask == [1] * len(input_tokens) + [0] * (max_seq_length - len(input_tokens)) 82 | assert features.token_type_ids == [0] * (len(question_tokens) + 2) + [1] * (len(context_tokens) + 1) + [0] * ( 83 | max_seq_length - len(input_tokens) 84 | ) 85 | 86 | assert 0 <= features.start_positions <= features.end_positions < max_seq_length 87 | answer_span = slice(features.start_positions, features.end_positions + 1) 88 | tokenized_answer_text: str = tokenizer.decode(features.input_ids[answer_span]) 89 | answers: list[dict[str, Any]] = example["answers"] 90 | assert tokenized_answer_text == answers[0]["text"] 91 | 92 | 93 | def test_features_0(deberta_v3_tokenizer: PreTrainedTokenizerBase) -> None: 94 | assert isinstance(deberta_v3_tokenizer, DebertaV2TokenizerFast) 95 | tokenizer: DebertaV2TokenizerFast = deberta_v3_tokenizer 96 | max_seq_length = 128 97 | dataset = JSQuADDataset( 98 | "validation", tokenizer, max_seq_length, segmenter_kwargs=DictConfig({"analyzer": None}), limit_examples=1 99 | ) 100 | example: dict[str, Any] = dict( 101 | id="a10336p0q0", 102 | title="梅雨", 103 | context="梅雨[SEP]梅雨(つゆ、ばいう)は、北海道と小笠原諸島を除く日本、朝鮮半島南部、中国の南部から長江流域にかけての沿海部、および台湾など、東アジアの広範囲においてみられる特有の気象現象で、5月から7月にかけて来る曇りや雨の多い期間のこと。雨季の一種である。", 104 | question="日本で梅雨がないのは北海道とどこか。", 105 | answers=[ 106 | dict(text="小笠原諸島", answer_start=19), 107 | dict(text="小笠原諸島を除く日本", answer_start=19), 108 | dict(text="小笠原諸島", answer_start=19), 109 | ], 110 | is_impossible=False, 111 | ) 112 | features = dataset[0] 113 | question_tokens: list[str] = tokenizer.tokenize(example["question"]) 114 | context_tokens: list[str] = tokenizer.tokenize(example["context"]) 115 | input_tokens = [tokenizer.cls_token, *question_tokens, tokenizer.sep_token, 
*context_tokens, tokenizer.sep_token] 116 | padded_input_tokens = input_tokens + [tokenizer.pad_token] * (max_seq_length - len(input_tokens)) 117 | assert features.input_ids == tokenizer.convert_tokens_to_ids(padded_input_tokens) 118 | assert features.attention_mask == [1] * len(input_tokens) + [0] * (max_seq_length - len(input_tokens)) 119 | assert features.token_type_ids == [0] * (len(question_tokens) + 2) + [1] * (len(context_tokens) + 1) + [0] * ( 120 | max_seq_length - len(input_tokens) 121 | ) 122 | 123 | assert 0 <= features.start_positions <= features.end_positions < max_seq_length 124 | answer_span = slice(features.start_positions, features.end_positions + 1) 125 | tokenized_answer_text: str = tokenizer.decode(features.input_ids[answer_span]) 126 | answers: list[dict[str, Any]] = example["answers"] 127 | assert tokenized_answer_text == answers[0]["text"] 128 | -------------------------------------------------------------------------------- /tests/metrics/test_squad.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | 3 | import pytest 4 | from torchmetrics.text import SQuAD 5 | 6 | CASES = [ 7 | { 8 | "preds": [ 9 | {"prediction_text": "1976", "id": "000"}, 10 | ], 11 | "target": [{"answers": {"answer_start": [97], "text": ["1976"]}, "id": "000"}], 12 | "exact_match": 1.0, 13 | "f1": 1.0, # precision: 1 / 1, recall: 1 / 1 14 | }, 15 | { 16 | "preds": [ 17 | {"prediction_text": "2 時間 21 分", "id": "001"}, 18 | ], 19 | "target": [{"answers": {"answer_start": [10], "text": ["2 時間"]}, "id": "001"}], 20 | "exact_match": 0.0, 21 | "f1": 2 / 3, # precision: 2 / 4, recall: 2 / 2 22 | }, 23 | { 24 | "preds": [ 25 | {"prediction_text": "2 時間 21 分", "id": "001"}, 26 | ], 27 | "target": [{"answers": {"answer_start": [10, 10], "text": ["2 時間", "2 時間 21 分"]}, "id": "001"}], 28 | "exact_match": 1.0, 29 | "f1": 1.0, # precision: 4 / 4, recall: 4 / 4 30 | }, 31 | { 32 | "preds": [ 33 | {"prediction_text": "2 時間 21 分", "id": "001"}, 34 | ], 35 | "target": [{"answers": {"answer_start": [10, 12], "text": ["2 時間", "時間 21 分"]}, "id": "001"}], 36 | "exact_match": 0.0, 37 | "f1": 6 / 7, # precision: 3 / 4, recall: 3 / 3 38 | }, 39 | { 40 | "preds": [ 41 | {"prediction_text": "2 時 間 2 1 分", "id": "001"}, 42 | ], 43 | "target": [{"answers": {"answer_start": [10], "text": ["2 時 間"]}, "id": "001"}], 44 | "exact_match": 0.0, 45 | "f1": 2 / 3, # precision: 3 / 6, recall: 3 / 3 46 | }, 47 | ] 48 | 49 | 50 | @pytest.mark.parametrize("case", CASES) 51 | def test_jsquad(case: dict[str, Any]): 52 | metric = SQuAD() 53 | metrics = metric(case["preds"], case["target"]) 54 | assert metrics["exact_match"].item() / 100.0 == pytest.approx(case["exact_match"]) 55 | assert metrics["f1"].item() / 100.0 == pytest.approx(case["f1"]) 56 | --------------------------------------------------------------------------------
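The precision/recall annotations in tests/metrics/test_squad.py above follow the standard SQuAD token-overlap F1. The sketch below is a minimal re-derivation of that score, not the torchmetrics implementation; it skips SQuAD's answer normalization (lowercasing, punctuation and article stripping), which the pre-tokenized Japanese cases above do not exercise. Note that torchmetrics' SQuAD reports percentages, which is why the tests divide by 100.

import math
from collections import Counter


def token_f1(pred: str, gold: str) -> float:
    """Whitespace-token overlap F1, as in the SQuAD evaluation script."""
    pred_tokens = pred.split()
    gold_tokens = gold.split()
    overlap = sum((Counter(pred_tokens) & Counter(gold_tokens)).values())
    if overlap == 0:
        return 0.0
    precision = overlap / len(pred_tokens)
    recall = overlap / len(gold_tokens)
    return 2 * precision * recall / (precision + recall)


# Second case above: precision 2/4, recall 2/2 -> F1 = 2/3.
assert math.isclose(token_f1("2 時間 21 分", "2 時間"), 2 / 3)
# With multiple reference answers, the best-scoring one counts (fourth case: 6/7).
assert math.isclose(max(token_f1("2 時間 21 分", t) for t in ["2 時間", "時間 21 分"]), 6 / 7)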
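Similarly, the gradient-accumulation arithmetic in src/train.py can be checked with a small worked example; the numbers below are illustrative only and are not taken from the repository's configs.

import math

# Illustrative request: effective batch 32 on 2 devices holding 6 samples each.
effective_batch_size = 32
max_batches_per_device = 6
num_devices = 2

accumulate_grad_batches = math.ceil(effective_batch_size / (max_batches_per_device * num_devices))  # ceil(32 / 12) = 3
batches_per_device = effective_batch_size // (num_devices * accumulate_grad_batches)  # 32 // 6 = 5
realized = batches_per_device * num_devices * accumulate_grad_batches  # 5 * 2 * 3 = 30

# The realized effective batch size undershoots the request by 2, which stays
# below the bound accumulate_grad_batches * num_devices = 6.
assert 0 <= effective_batch_size - realized < accumulate_grad_batches * num_devices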