├── .github └── images │ └── bert-classification-tutorial.001.jpeg ├── .gitignore ├── .python-version ├── LICENSE ├── README.md ├── config ├── 1.json └── 4-ds.json ├── pyproject.toml ├── requirements-dev.lock ├── requirements.lock └── src ├── aggregate.py ├── prepare.py ├── train.py └── utils.py /.github/images/bert-classification-tutorial.001.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hppRC/bert-classification-tutorial-2024/0044a25d70238e74ec9d41df3c450e5ba939c4c0/.github/images/bert-classification-tutorial.001.jpeg -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | !.envrc 2 | /prev 3 | /data 4 | /datasets 5 | /outputs 6 | /scripts 7 | 8 | # Byte-compiled / optimized / DLL files 9 | __pycache__/ 10 | *.py[cod] 11 | *$py.class 12 | 13 | # C extensions 14 | *.so 15 | 16 | # Distribution / packaging 17 | .Python 18 | build/ 19 | develop-eggs/ 20 | dist/ 21 | downloads/ 22 | eggs/ 23 | .eggs/ 24 | lib/ 25 | lib64/ 26 | parts/ 27 | sdist/ 28 | var/ 29 | wheels/ 30 | share/python-wheels/ 31 | *.egg-info/ 32 | .installed.cfg 33 | *.egg 34 | MANIFEST 35 | 36 | # PyInstaller 37 | # Usually these files are written by a python script from a template 38 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 39 | *.manifest 40 | *.spec 41 | 42 | # Installer logs 43 | pip-log.txt 44 | pip-delete-this-directory.txt 45 | 46 | # Unit test / coverage reports 47 | htmlcov/ 48 | .tox/ 49 | .nox/ 50 | .coverage 51 | .coverage.* 52 | .cache 53 | nosetests.xml 54 | coverage.xml 55 | *.cover 56 | *.py,cover 57 | .hypothesis/ 58 | .pytest_cache/ 59 | cover/ 60 | 61 | # Translations 62 | *.mo 63 | *.pot 64 | 65 | # Django stuff: 66 | *.log 67 | local_settings.py 68 | db.sqlite3 69 | db.sqlite3-journal 70 | 71 | # Flask stuff: 72 | instance/ 73 | .webassets-cache 74 | 75 | # Scrapy stuff: 76 | .scrapy 77 | 78 | # Sphinx documentation 79 | docs/_build/ 80 | 81 | # PyBuilder 82 | .pybuilder/ 83 | target/ 84 | 85 | # Jupyter Notebook 86 | .ipynb_checkpoints 87 | 88 | # IPython 89 | profile_default/ 90 | ipython_config.py 91 | 92 | # pyenv 93 | # For a library or package, you might want to ignore these files since the code is 94 | # intended to run in multiple environments; otherwise, check them in: 95 | # .python-version 96 | 97 | # pipenv 98 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 99 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 100 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 101 | # install all needed dependencies. 102 | #Pipfile.lock 103 | 104 | # poetry 105 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 106 | # This is especially recommended for binary packages to ensure reproducibility, and is more 107 | # commonly ignored for libraries. 108 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 109 | #poetry.lock 110 | 111 | # pdm 112 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 113 | #pdm.lock 114 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 115 | # in version control. 
116 | # https://pdm.fming.dev/#use-with-ide 117 | .pdm.toml 118 | 119 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 120 | __pypackages__/ 121 | 122 | # Celery stuff 123 | celerybeat-schedule 124 | celerybeat.pid 125 | 126 | # SageMath parsed files 127 | *.sage.py 128 | 129 | # Environments 130 | .env 131 | .venv 132 | env/ 133 | venv/ 134 | ENV/ 135 | env.bak/ 136 | venv.bak/ 137 | 138 | # Spyder project settings 139 | .spyderproject 140 | .spyproject 141 | 142 | # Rope project settings 143 | .ropeproject 144 | 145 | # mkdocs documentation 146 | /site 147 | 148 | # mypy 149 | .mypy_cache/ 150 | .dmypy.json 151 | dmypy.json 152 | 153 | # Pyre type checker 154 | .pyre/ 155 | 156 | # pytype static type analyzer 157 | .pytype/ 158 | 159 | # Cython debug symbols 160 | cython_debug/ 161 | 162 | # PyCharm 163 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 164 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 165 | # and can be added to the global gitignore or merged into this file. For a more nuclear 166 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 167 | #.idea/ 168 | -------------------------------------------------------------------------------- /.python-version: -------------------------------------------------------------------------------- 1 | 3.10.13 2 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Hayato Tsukagoshi 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # BERT Classification Tutorial 2024
2 |
3 | This is the 2024 edition of [hppRC/bert-classification-tutorial](https://github.com/hppRC/bert-classification-tutorial).
4 | Please refer to that repository for the background and details of the implementation.
5 |
6 | The main changes from the previous implementation are as follows.
7 |
8 | - The whole implementation now relies on the HuggingFace ecosystem
9 | - Dataset construction now uses HuggingFace Datasets
10 | - Training now uses the HuggingFace Trainer together with Accelerate
11 | - The virtual environment is now managed with rye
12 |
13 |
14 | ## How to Run
15 |
16 | ```bash
17 | # set up the environment
18 | rye sync -f
19 | source .venv/bin/activate
20 |
21 | # build the dataset
22 | python src/prepare.py
23 |
24 | # train
25 | accelerate launch --config_file config/4-ds.json src/train.py --model_name tohoku-nlp/bert-base-japanese-v3 --experiment_name 4-ds
26 | ```
27 |
28 | ## Notes
29 |
30 | - The `config` directory contains configuration files for `accelerate`
31 | - `4-ds.json` is the configuration for training on 4 GPUs with DeepSpeed
32 | - `1.json` is the configuration for training on a single GPU
33 | - Running `accelerate config --config_file config/hoge.json` lets you create your own configuration file interactively
34 | - Running `tensorboard --logdir ./outputs` lets you monitor training progress with TensorBoard
35 |
36 | ## Closing Remarks
37 |
38 | I hope this implementation will be useful to a wide range of users, whether in research, industrial applications, or personal projects.
39 |
40 | If you have any questions, bug reports, or anything else, please write them in an [Issue](https://github.com/hppRC/bert-classification-tutorial-2024/issues).
41 |
42 |
43 | ## Author / Citation
44 |
45 | Author: [Hayato Tsukagoshi](https://hpprc.dev) \
46 | email: [research.tsukagoshi.hayato@gmail.com](mailto:research.tsukagoshi.hayato@gmail.com) \
47 | Related journal article: [BERTによるテキスト分類チュートリアル](https://www.jstage.jst.go.jp/article/jnlp/30/2/30_867/_article/-char/ja)
48 |
49 | When referring to this implementation in a paper or similar, please cite the following.
50 |
51 |
52 | ```bibtex
53 | @article{
54 | hayato-tsukagoshi-2023-bert-classification-tutorial,
55 | title={{BERT によるテキスト分類チュートリアル}},
56 | author={塚越 駿 and 平子 潤},
57 | journal={自然言語処理},
58 | volume={30},
59 | number={2},
60 | pages={867-873},
61 | year={2023},
62 | doi={10.5715/jnlp.30.867},
63 | url = {https://www.jstage.jst.go.jp/article/jnlp/30/2/30_867/_article/-char/ja},
64 | }
65 | ```
66 |
-------------------------------------------------------------------------------- /config/1.json: -------------------------------------------------------------------------------- 1 | { 2 | "compute_environment": "LOCAL_MACHINE", 3 | "debug": false, 4 | "distributed_type": "NO", 5 | "downcast_bf16": "no", 6 | "enable_cpu_affinity": false, 7 | "gpu_ids": "0", 8 | "machine_rank": 0, 9 | "main_training_function": "main", 10 | "mixed_precision": "bf16", 11 | "num_machines": 1, 12 | "num_processes": 1, 13 | "rdzv_backend": "static", 14 | "same_network": true, 15 | "tpu_env": [], 16 | "tpu_use_cluster": false, 17 | "tpu_use_sudo": false, 18 | "use_cpu": false 19 | } -------------------------------------------------------------------------------- /config/4-ds.json: -------------------------------------------------------------------------------- 1 | { 2 | "compute_environment": "LOCAL_MACHINE", 3 | "debug": false, 4 | "deepspeed_config": { 5 | "gradient_accumulation_steps": 1, 6 | "gradient_clipping": 1.0, 7 | "offload_optimizer_device": "none", 8 | "offload_param_device": "none", 9 | "zero3_init_flag": false, 10 | "zero_stage": 2 11 | }, 12 | "distributed_type": "DEEPSPEED", 13 | "downcast_bf16": "no", 14 | "enable_cpu_affinity": false, 15 | "machine_rank": 0, 16 | "main_training_function": "main", 17 | "mixed_precision": "bf16", 18 | "num_machines": 1, 19 | "num_processes": 4, 20 | "rdzv_backend": "static", 21 |
"same_network": true, 22 | "tpu_env": [], 23 | "tpu_use_cluster": false, 24 | "tpu_use_sudo": false, 25 | "use_cpu": false 26 | } 27 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "bert-classification-tutorial" 3 | version = "0.3.0" 4 | description = "Add your description here" 5 | authors = [{ name = "hppRC", email = "hpp.ricecake@gmail.com" }] 6 | dependencies = [ 7 | "torch==2.3.0+cu121", 8 | "transformers[ja,sentencepiece]>=4.41.1", 9 | "deepspeed>=0.14.2", 10 | "tensorboard>=2.14.0", 11 | "more-itertools>=10.2.0", 12 | "scikit-learn>=1.3.2", 13 | "datasets>=2.19.1", 14 | "accelerate>=0.30.1", 15 | "tokenizers>=0.19.1", 16 | "pandas>=2.0.3", 17 | "numpy>=1.26.3", 18 | "tqdm>=4.64.1", 19 | ] 20 | readme = "README.md" 21 | requires-python = ">= 3.8" 22 | 23 | [build-system] 24 | requires = ["hatchling"] 25 | build-backend = "hatchling.build" 26 | 27 | [tool.rye] 28 | managed = true 29 | dev-dependencies = [ 30 | "pip>=23.3.2", 31 | "setuptools>=69.0.3", 32 | "wheel>=0.42.0", 33 | "ruff>=0.4.5", 34 | ] 35 | 36 | [[tool.rye.sources]] 37 | name = "torch" 38 | url = "https://download.pytorch.org/whl/cu121" 39 | type = "index" 40 | 41 | [tool.hatch.metadata] 42 | allow-direct-references = true 43 | 44 | [tool.hatch.build.targets.wheel] 45 | packages = ["src"] 46 | 47 | [tool.ruff] 48 | exclude = [ 49 | ".git", 50 | ".mypy_cache", 51 | ".ruff_cache", 52 | ".venv", 53 | "outputs", 54 | "datasets", 55 | "prev", 56 | ] 57 | line-length = 120 58 | target-version = "py310" 59 | 60 | [tool.ruff.lint] 61 | fixable = ["ALL"] 62 | unfixable = [] 63 | # https://qiita.com/yuji38kwmt/items/63e82126076204923520 64 | select = ["F", "E", "W", "I", "B", "PL", "UP", "N"] 65 | ignore = [ 66 | "PLR0913", # Too many arguments in function definition 67 | "PLR2004", # Magic value used in comparison 68 | "N812", # Lowercase imported as non-lowercase 69 | "N806", # Lowercase imported as non-lowercase 70 | "F403", # unable to detect undefined names 71 | "E501", # Line too long 72 | "N999", # Invalid module name 73 | "PLR0912", # too many branches 74 | "B905", # zip strict 75 | "UP007", # Use `X | Y` for type annotations 76 | ] 77 | 78 | [tool.ruff.format] 79 | quote-style = "double" 80 | line-ending = "auto" 81 | -------------------------------------------------------------------------------- /requirements-dev.lock: -------------------------------------------------------------------------------- 1 | # generated by rye 2 | # use `rye lock` or `rye sync` to update this lockfile 3 | # 4 | # last locked with the following flags: 5 | # pre: false 6 | # features: [] 7 | # all-features: false 8 | # with-sources: false 9 | # generate-hashes: false 10 | 11 | -e file:. 
12 | absl-py==2.1.0 13 | # via tensorboard 14 | accelerate==0.30.1 15 | # via bert-classification-tutorial 16 | aiohttp==3.9.5 17 | # via datasets 18 | # via fsspec 19 | aiosignal==1.3.1 20 | # via aiohttp 21 | annotated-types==0.7.0 22 | # via pydantic 23 | async-timeout==4.0.3 24 | # via aiohttp 25 | attrs==23.2.0 26 | # via aiohttp 27 | certifi==2022.12.7 28 | # via requests 29 | charset-normalizer==2.1.1 30 | # via requests 31 | datasets==2.19.1 32 | # via bert-classification-tutorial 33 | deepspeed==0.14.2 34 | # via bert-classification-tutorial 35 | dill==0.3.8 36 | # via datasets 37 | # via multiprocess 38 | filelock==3.13.1 39 | # via datasets 40 | # via huggingface-hub 41 | # via torch 42 | # via transformers 43 | # via triton 44 | frozenlist==1.4.1 45 | # via aiohttp 46 | # via aiosignal 47 | fsspec==2024.2.0 48 | # via datasets 49 | # via huggingface-hub 50 | # via torch 51 | fugashi==1.3.2 52 | # via transformers 53 | grpcio==1.64.0 54 | # via tensorboard 55 | hjson==3.1.0 56 | # via deepspeed 57 | huggingface-hub==0.23.2 58 | # via accelerate 59 | # via datasets 60 | # via tokenizers 61 | # via transformers 62 | idna==3.4 63 | # via requests 64 | # via yarl 65 | ipadic==1.0.0 66 | # via transformers 67 | jinja2==3.1.3 68 | # via torch 69 | joblib==1.4.2 70 | # via scikit-learn 71 | markdown==3.6 72 | # via tensorboard 73 | markupsafe==2.1.5 74 | # via jinja2 75 | # via werkzeug 76 | more-itertools==10.2.0 77 | # via bert-classification-tutorial 78 | mpmath==1.3.0 79 | # via sympy 80 | multidict==6.0.5 81 | # via aiohttp 82 | # via yarl 83 | multiprocess==0.70.16 84 | # via datasets 85 | networkx==3.2.1 86 | # via torch 87 | ninja==1.11.1.1 88 | # via deepspeed 89 | numpy==1.26.3 90 | # via accelerate 91 | # via bert-classification-tutorial 92 | # via datasets 93 | # via deepspeed 94 | # via pandas 95 | # via pyarrow 96 | # via scikit-learn 97 | # via scipy 98 | # via tensorboard 99 | # via transformers 100 | nvidia-cublas-cu12==12.1.3.1 101 | # via nvidia-cudnn-cu12 102 | # via nvidia-cusolver-cu12 103 | # via torch 104 | nvidia-cuda-cupti-cu12==12.1.105 105 | # via torch 106 | nvidia-cuda-nvrtc-cu12==12.1.105 107 | # via torch 108 | nvidia-cuda-runtime-cu12==12.1.105 109 | # via torch 110 | nvidia-cudnn-cu12==8.9.2.26 111 | # via torch 112 | nvidia-cufft-cu12==11.0.2.54 113 | # via torch 114 | nvidia-curand-cu12==10.3.2.106 115 | # via torch 116 | nvidia-cusolver-cu12==11.4.5.107 117 | # via torch 118 | nvidia-cusparse-cu12==12.1.0.106 119 | # via nvidia-cusolver-cu12 120 | # via torch 121 | nvidia-nccl-cu12==2.20.5 122 | # via torch 123 | nvidia-nvjitlink-cu12==12.1.105 124 | # via nvidia-cusolver-cu12 125 | # via nvidia-cusparse-cu12 126 | nvidia-nvtx-cu12==12.1.105 127 | # via torch 128 | packaging==22.0 129 | # via accelerate 130 | # via datasets 131 | # via deepspeed 132 | # via huggingface-hub 133 | # via transformers 134 | pandas==2.2.2 135 | # via bert-classification-tutorial 136 | # via datasets 137 | pip==24.0 138 | plac==1.4.3 139 | # via unidic 140 | protobuf==5.27.0 141 | # via tensorboard 142 | # via transformers 143 | psutil==5.9.8 144 | # via accelerate 145 | # via deepspeed 146 | py-cpuinfo==9.0.0 147 | # via deepspeed 148 | pyarrow==16.1.0 149 | # via datasets 150 | pyarrow-hotfix==0.6 151 | # via datasets 152 | pydantic==2.7.1 153 | # via deepspeed 154 | pydantic-core==2.18.2 155 | # via pydantic 156 | pynvml==11.5.0 157 | # via deepspeed 158 | python-dateutil==2.9.0.post0 159 | # via pandas 160 | pytz==2024.1 161 | # via pandas 162 | pyyaml==6.0.1 163 | # 
via accelerate 164 | # via datasets 165 | # via huggingface-hub 166 | # via transformers 167 | regex==2024.5.15 168 | # via transformers 169 | requests==2.28.1 170 | # via datasets 171 | # via huggingface-hub 172 | # via transformers 173 | # via unidic 174 | rhoknp==1.3.0 175 | # via transformers 176 | ruff==0.4.5 177 | safetensors==0.4.3 178 | # via accelerate 179 | # via transformers 180 | scikit-learn==1.5.0 181 | # via bert-classification-tutorial 182 | scipy==1.13.1 183 | # via scikit-learn 184 | sentencepiece==0.2.0 185 | # via transformers 186 | setuptools==70.0.0 187 | # via tensorboard 188 | six==1.16.0 189 | # via python-dateutil 190 | # via tensorboard 191 | sudachidict-core==20240409 192 | # via transformers 193 | sudachipy==0.6.8 194 | # via sudachidict-core 195 | # via transformers 196 | sympy==1.12 197 | # via torch 198 | tensorboard==2.16.2 199 | # via bert-classification-tutorial 200 | tensorboard-data-server==0.7.2 201 | # via tensorboard 202 | threadpoolctl==3.5.0 203 | # via scikit-learn 204 | tokenizers==0.19.1 205 | # via bert-classification-tutorial 206 | # via transformers 207 | torch==2.3.0+cu121 208 | # via accelerate 209 | # via bert-classification-tutorial 210 | # via deepspeed 211 | tqdm==4.64.1 212 | # via bert-classification-tutorial 213 | # via datasets 214 | # via deepspeed 215 | # via huggingface-hub 216 | # via transformers 217 | # via unidic 218 | transformers==4.41.1 219 | # via bert-classification-tutorial 220 | triton==2.3.0 221 | # via torch 222 | typing-extensions==4.9.0 223 | # via huggingface-hub 224 | # via pydantic 225 | # via pydantic-core 226 | # via torch 227 | tzdata==2024.1 228 | # via pandas 229 | unidic==1.1.0 230 | # via transformers 231 | unidic-lite==1.0.8 232 | # via transformers 233 | urllib3==1.26.13 234 | # via requests 235 | wasabi==0.10.1 236 | # via unidic 237 | werkzeug==3.0.3 238 | # via tensorboard 239 | wheel==0.43.0 240 | xxhash==3.4.1 241 | # via datasets 242 | yarl==1.9.4 243 | # via aiohttp 244 | -------------------------------------------------------------------------------- /requirements.lock: -------------------------------------------------------------------------------- 1 | # generated by rye 2 | # use `rye lock` or `rye sync` to update this lockfile 3 | # 4 | # last locked with the following flags: 5 | # pre: false 6 | # features: [] 7 | # all-features: false 8 | # with-sources: false 9 | # generate-hashes: false 10 | 11 | -e file:. 
12 | absl-py==2.1.0 13 | # via tensorboard 14 | accelerate==0.30.1 15 | # via bert-classification-tutorial 16 | aiohttp==3.9.5 17 | # via datasets 18 | # via fsspec 19 | aiosignal==1.3.1 20 | # via aiohttp 21 | annotated-types==0.7.0 22 | # via pydantic 23 | async-timeout==4.0.3 24 | # via aiohttp 25 | attrs==23.2.0 26 | # via aiohttp 27 | certifi==2022.12.7 28 | # via requests 29 | charset-normalizer==2.1.1 30 | # via requests 31 | datasets==2.19.1 32 | # via bert-classification-tutorial 33 | deepspeed==0.14.2 34 | # via bert-classification-tutorial 35 | dill==0.3.8 36 | # via datasets 37 | # via multiprocess 38 | filelock==3.13.1 39 | # via datasets 40 | # via huggingface-hub 41 | # via torch 42 | # via transformers 43 | # via triton 44 | frozenlist==1.4.1 45 | # via aiohttp 46 | # via aiosignal 47 | fsspec==2024.2.0 48 | # via datasets 49 | # via huggingface-hub 50 | # via torch 51 | fugashi==1.3.2 52 | # via transformers 53 | grpcio==1.64.0 54 | # via tensorboard 55 | hjson==3.1.0 56 | # via deepspeed 57 | huggingface-hub==0.23.2 58 | # via accelerate 59 | # via datasets 60 | # via tokenizers 61 | # via transformers 62 | idna==3.4 63 | # via requests 64 | # via yarl 65 | ipadic==1.0.0 66 | # via transformers 67 | jinja2==3.1.3 68 | # via torch 69 | joblib==1.4.2 70 | # via scikit-learn 71 | markdown==3.6 72 | # via tensorboard 73 | markupsafe==2.1.5 74 | # via jinja2 75 | # via werkzeug 76 | more-itertools==10.2.0 77 | # via bert-classification-tutorial 78 | mpmath==1.3.0 79 | # via sympy 80 | multidict==6.0.5 81 | # via aiohttp 82 | # via yarl 83 | multiprocess==0.70.16 84 | # via datasets 85 | networkx==3.2.1 86 | # via torch 87 | ninja==1.11.1.1 88 | # via deepspeed 89 | numpy==1.26.3 90 | # via accelerate 91 | # via bert-classification-tutorial 92 | # via datasets 93 | # via deepspeed 94 | # via pandas 95 | # via pyarrow 96 | # via scikit-learn 97 | # via scipy 98 | # via tensorboard 99 | # via transformers 100 | nvidia-cublas-cu12==12.1.3.1 101 | # via nvidia-cudnn-cu12 102 | # via nvidia-cusolver-cu12 103 | # via torch 104 | nvidia-cuda-cupti-cu12==12.1.105 105 | # via torch 106 | nvidia-cuda-nvrtc-cu12==12.1.105 107 | # via torch 108 | nvidia-cuda-runtime-cu12==12.1.105 109 | # via torch 110 | nvidia-cudnn-cu12==8.9.2.26 111 | # via torch 112 | nvidia-cufft-cu12==11.0.2.54 113 | # via torch 114 | nvidia-curand-cu12==10.3.2.106 115 | # via torch 116 | nvidia-cusolver-cu12==11.4.5.107 117 | # via torch 118 | nvidia-cusparse-cu12==12.1.0.106 119 | # via nvidia-cusolver-cu12 120 | # via torch 121 | nvidia-nccl-cu12==2.20.5 122 | # via torch 123 | nvidia-nvjitlink-cu12==12.1.105 124 | # via nvidia-cusolver-cu12 125 | # via nvidia-cusparse-cu12 126 | nvidia-nvtx-cu12==12.1.105 127 | # via torch 128 | packaging==22.0 129 | # via accelerate 130 | # via datasets 131 | # via deepspeed 132 | # via huggingface-hub 133 | # via transformers 134 | pandas==2.2.2 135 | # via bert-classification-tutorial 136 | # via datasets 137 | plac==1.4.3 138 | # via unidic 139 | protobuf==5.27.0 140 | # via tensorboard 141 | # via transformers 142 | psutil==5.9.8 143 | # via accelerate 144 | # via deepspeed 145 | py-cpuinfo==9.0.0 146 | # via deepspeed 147 | pyarrow==16.1.0 148 | # via datasets 149 | pyarrow-hotfix==0.6 150 | # via datasets 151 | pydantic==2.7.1 152 | # via deepspeed 153 | pydantic-core==2.18.2 154 | # via pydantic 155 | pynvml==11.5.0 156 | # via deepspeed 157 | python-dateutil==2.9.0.post0 158 | # via pandas 159 | pytz==2024.1 160 | # via pandas 161 | pyyaml==6.0.1 162 | # via accelerate 
163 | # via datasets 164 | # via huggingface-hub 165 | # via transformers 166 | regex==2024.5.15 167 | # via transformers 168 | requests==2.28.1 169 | # via datasets 170 | # via huggingface-hub 171 | # via transformers 172 | # via unidic 173 | rhoknp==1.3.0 174 | # via transformers 175 | safetensors==0.4.3 176 | # via accelerate 177 | # via transformers 178 | scikit-learn==1.5.0 179 | # via bert-classification-tutorial 180 | scipy==1.13.1 181 | # via scikit-learn 182 | sentencepiece==0.2.0 183 | # via transformers 184 | setuptools==70.0.0 185 | # via tensorboard 186 | six==1.16.0 187 | # via python-dateutil 188 | # via tensorboard 189 | sudachidict-core==20240409 190 | # via transformers 191 | sudachipy==0.6.8 192 | # via sudachidict-core 193 | # via transformers 194 | sympy==1.12 195 | # via torch 196 | tensorboard==2.16.2 197 | # via bert-classification-tutorial 198 | tensorboard-data-server==0.7.2 199 | # via tensorboard 200 | threadpoolctl==3.5.0 201 | # via scikit-learn 202 | tokenizers==0.19.1 203 | # via bert-classification-tutorial 204 | # via transformers 205 | torch==2.3.0+cu121 206 | # via accelerate 207 | # via bert-classification-tutorial 208 | # via deepspeed 209 | tqdm==4.64.1 210 | # via bert-classification-tutorial 211 | # via datasets 212 | # via deepspeed 213 | # via huggingface-hub 214 | # via transformers 215 | # via unidic 216 | transformers==4.41.1 217 | # via bert-classification-tutorial 218 | triton==2.3.0 219 | # via torch 220 | typing-extensions==4.9.0 221 | # via huggingface-hub 222 | # via pydantic 223 | # via pydantic-core 224 | # via torch 225 | tzdata==2024.1 226 | # via pandas 227 | unidic==1.1.0 228 | # via transformers 229 | unidic-lite==1.0.8 230 | # via transformers 231 | urllib3==1.26.13 232 | # via requests 233 | wasabi==0.10.1 234 | # via unidic 235 | werkzeug==3.0.3 236 | # via tensorboard 237 | xxhash==3.4.1 238 | # via datasets 239 | yarl==1.9.4 240 | # via aiohttp 241 |
--------------------------------------------------------------------------------
/src/aggregate.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 | from pathlib import Path
3 |
4 | import pandas as pd
5 | from transformers import HfArgumentParser
6 |
7 | from src import utils
8 |
9 |
10 | @dataclass
11 | class Args:
12 |     output_dir: Path = "./outputs"
13 |     result_dir: Path = "./results"
14 |
15 |
16 | def main(args: Args):
17 |     data = []
18 |     for path in args.output_dir.glob("**/metrics.json"):
19 |         metrics = utils.load_json(path)
20 |         config = utils.load_json(path.parent / "config.json")
21 |
22 |         data.append(
23 |             {
24 |                 "model_name": config["_name_or_path"],
25 |                 "best-val-f1": metrics["best-val"]["f1"],
26 |                 "best-val-acc": metrics["best-val"]["accuracy"],
27 |                 "f1": metrics["test"]["f1"],
28 |                 "accuracy": metrics["test"]["accuracy"],
29 |                 "precision": metrics["test"]["precision"],
30 |                 "recall": metrics["test"]["recall"],
31 |             }
32 |         )
33 |
34 |     args.output_dir.mkdir(parents=True, exist_ok=True)
35 |     df = pd.DataFrame(data).sort_values("f1", ascending=False)
36 |     df.to_csv(str(args.output_dir / "all.csv"), index=False)
37 |
38 |     best_df = (
39 |         df.groupby("model_name")
40 |         .apply(lambda x: x.nlargest(1, "best-val-f1").reset_index(drop=True), include_groups=False)
41 |         .reset_index(level=0)
42 |     ).sort_values("f1", ascending=False)
43 |
44 |     best_df.to_csv(str(args.output_dir / "best.csv"), index=False)
45 |
46 |     print("|Model|Accuracy|Precision|Recall|F1|")
47 |     print("|:-|:-:|:-:|:-:|:-:|")
48 |     for row in best_df.to_dict("records"):
49 |         print(
50 |             f'|[{row["model_name"]}](https://huggingface.co/{row["model_name"]})|{row["accuracy"]*100:.2f}|{row["precision"]*100:.2f}|{row["recall"]*100:.2f}|{row["f1"]*100:.2f}|'
51 |         )
52 |
53 |
54 | if __name__ == "__main__":
55 |     parser = HfArgumentParser((Args,))
56 |     [args] = parser.parse_args_into_dataclasses()
57 |     main(args)
58 |
--------------------------------------------------------------------------------
/src/prepare.py:
--------------------------------------------------------------------------------
1 |
2 | import unicodedata
3 | from dataclasses import dataclass
4 | from pathlib import Path
5 |
6 | from transformers import HfArgumentParser
7 |
8 | import datasets as ds
9 | from src import utils
10 |
11 |
12 | @dataclass
13 | class Args:
14 |     output_dir: Path = "./datasets/livedoor"
15 |     seed: int = 42
16 |
17 |
18 | def process_title(title: str) -> str:
19 |     title = unicodedata.normalize("NFKC", title)
20 |     title = title.strip(" ").strip()
21 |     return title
22 |
23 |
24 | # Preprocessing for article bodies:
25 | # drop duplicated line breaks and leading full-width spaces, and apply NFKC normalization
26 | def process_body(body: list[str]) -> str:
27 |     body = [unicodedata.normalize("NFKC", line) for line in body]
28 |     body = [line.strip(" ").strip() for line in body]
29 |     body = [line for line in body if line]
30 |     body = "\n".join(body)
31 |     return body
32 |
33 |
34 | DATASET_URL = "https://www.rondhuit.com/download/ldcc-20140209.tar.gz"
35 |
36 |
37 | def main(args: Args):
38 |     # Using the download manager from datasets lets us fetch the data without bash commands
39 |     # It also extracts the tar.gz automatically, which is convenient
40 |     dl_manager: ds.DownloadManager = ds.DownloadManager(
41 |         download_config=ds.DownloadConfig(num_proc=16),
42 |     )
43 |     data_dir: str = dl_manager.download_and_extract(DATASET_URL)
44 |
45 |     # Path to the directory that holds the actual livedoor news corpus files
46 |     input_dir = Path(data_dir, "text")
47 |
48 |     # Arguments for the generator function are passed via `gen_kwargs` of `.from_generator`
49 |     # When a list is passed, it is split and distributed across `num_proc` processes
50 |     # e.g. with a list of length 4 and num_proc=2, each of the two processes handles two elements
51 |     def generator(paths: list[Path]):
52 |         for path in paths:
53 |             category = path.parent.name
54 |
55 |             # Data format:
56 |             # line 1: article URL
57 |             # line 2: article date
58 |             # line 3: article title
59 |             # line 4 onwards: article body
60 |             lines: list[str] = path.read_text().splitlines()
61 |             url, date, title, *body = lines
62 |
63 |             yield {
64 |                 "category": category,
65 |                 "category-id": path.stem,
66 |                 "url": url.strip(),
67 |                 "date": date.strip(),
68 |                 "title": process_title(title.strip()),
69 |                 "body": process_body(body),
70 |             }
71 |
72 |     # Collect paths to the text files, excluding the license file
73 |     paths = [path for path in input_dir.glob("*/*.txt") if path.name != "LICENSE.txt"]
74 |
75 |     # Build the dataset directly from the generator function
76 |     dataset = ds.Dataset.from_generator(
77 |         generator,
78 |         gen_kwargs={"paths": paths},  # a list is split across num_proc processes
79 |         num_proc=16,
80 |     )
81 |
82 |     dataset = dataset.shuffle(seed=args.seed)
83 |
84 |     # Build the label set
85 |     labels = set(dataset["category"])
86 |     label2id = {label: i for i, label in enumerate(sorted(labels))}
87 |
88 |     # Convert labels to integer ids
89 |     # datasets' map applies the function to each example (when batched=False)
90 |     # Defining small local functions like this saves us from inventing names and keeps dict access simple
91 |     def process(x: dict):
92 |         return {
93 |             "label": label2id[x["category"]],
94 |         }
95 |
96 |     dataset = dataset.map(process, num_proc=16)
97 |
98 |     # There is no single call that splits into train/validation/test, so split in two steps (4:(1:1))
99 |     datasets = dataset.train_test_split(test_size=0.2)
100 |     train_dataset = datasets["train"]
101 |     val_test_datasets = datasets["test"].train_test_split(test_size=0.5)
102 |
103 |     datasets = ds.DatasetDict(
104 |         {
105 |             "train": train_dataset,
106 |             "validation": val_test_datasets["train"],
107 |             "test": val_test_datasets["test"],
108 |         }
109 |     )
110 |
111 |     datasets.save_to_disk(str(args.output_dir))
112 |     utils.save_json(label2id, args.output_dir / "label2id.json")
113 |
114 |
115 |
116 |
117 | if __name__ == "__main__":
118 |     parser = HfArgumentParser((Args,))
119 |     [args] = parser.parse_args_into_dataclasses()
120 |     main(args)
121 |
--------------------------------------------------------------------------------
/src/train.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass, field
2 | from pathlib import Path
3 | from typing import Any
4 |
5 | import torch
6 | import torch.utils.checkpoint
7 | from accelerate import Accelerator
8 | from sklearn.metrics import accuracy_score, precision_recall_fscore_support
9 | from transformers import (
10 |     AutoModelForSequenceClassification,
11 |     AutoTokenizer,
12 |     BatchEncoding,
13 |     EvalPrediction,
14 |     HfArgumentParser,
15 |     PreTrainedTokenizer,
16 |     TrainingArguments,
17 | )
18 | from transformers import Trainer as HFTrainer
19 | from transformers.trainer_utils import PredictionOutput
20 |
21 | import datasets as ds
22 | from src import utils
23 |
24 |
25 | @dataclass
26 | class TrainingArgs(TrainingArguments):
27 |     output_dir: str = None
28 |
29 |     num_train_epochs: int = 20
30 |     learning_rate: float = 3e-5
31 |     per_device_train_batch_size: int = 32
32 |     weight_decay: float = 0.01
33 |     warmup_ratio: float = 0.1
34 |
35 |     dataloader_num_workers: int = 4
36 |     lr_scheduler_type: str = "cosine"
37 |
38 |     # Data type used for training; BF16 makes training fast and memory-efficient
39 |     # BF16 is generally more stable during training than FP16
40 |     bf16: bool = True
41 |
42 |     # Gradient checkpointing: save memory by recomputing intermediate activations during the backward pass instead of keeping them all
43 |     gradient_checkpointing: bool = True
44 |     gradient_checkpointing_kwargs: dict = field(default_factory=lambda: {"use_reentrant": True})
45 |
46 |     # Keeping experiment logs in TensorBoard makes it easy to see how training is progressing
47 |     report_to: str = "tensorboard"
48 |     logging_steps: int = 10
49 |     logging_dir: str = None
50 |
51 |     # Metric used to select the best model
52 |     metric_for_best_model: str = "loss"
53 |     greater_is_better: bool = False
54 |     # To select based on validation accuracy instead, use the following:
55 |     # metric_for_best_model: str = "accuracy"
56 |     # greater_is_better: bool = True
57 |
58 |     eval_strategy: str = "epoch"
59 |     per_device_eval_batch_size: int = 32
60 |
61 |     save_strategy: str = "epoch"
62 |     save_total_limit: int = 1
63 |
64 |     ddp_find_unused_parameters: bool = False
65 |     load_best_model_at_end: bool = False
66 |     remove_unused_columns: bool = False
67 |
68 |
69 | @dataclass
70 | class ExperimentConfig:
71 |     model_name: str = "cl-tohoku/bert-base-japanese-v3"
72 |     dataset_dir: Path = "./datasets/livedoor"
73 |     experiment_name: str = "default"
74 |     max_seq_len: int = 512
75 |
76 |     def __post_init__(self):
77 |         self.label2id = utils.load_json(self.dataset_dir / "label2id.json")
78 |
79 |
80 | @dataclass
81 | class DataCollator:
82 |     tokenizer: PreTrainedTokenizer
83 |     max_seq_len: int
84 |
85 |     def __call__(self, data_list: list[dict[str, Any]]) -> BatchEncoding:
86 |         title = [d["title"] for d in data_list]
87 |         body = [d["body"] for d in data_list]
88 |         inputs: BatchEncoding = self.tokenizer(
89 |             title,
90 |             body,
91 |             padding=True,
92 |             truncation="only_second",
93 |             return_tensors="pt",
94 |             max_length=self.max_seq_len,
95 |         )
96 |         inputs["labels"] = torch.LongTensor([d["label"] for d in data_list])
97 |         return inputs
98 |
99 |
100 | class ComputeMetrics:
101 |     def __init__(self, labels: list[str]):
102 |         self.labels = labels
103 |
104 |     def __call__(self, eval_pred: EvalPrediction):
105 |         pred_labels = torch.Tensor(eval_pred.predictions.argmax(axis=1).reshape(-1))
106 |         gold_labels = torch.Tensor(eval_pred.label_ids.reshape(-1))
107 |
108 |         accuracy: float = accuracy_score(gold_labels, pred_labels)
109 |         precision, recall, f1, _ = precision_recall_fscore_support(
110 |             gold_labels,
111 |             pred_labels,
112 |             average="macro",
113 |             zero_division=0,
114 |             labels=self.labels,
115 |         )
116 |
117 |         return {
118 |             "accuracy": accuracy,
119 |             "precision": precision,
120 |             "recall": recall,
121 |             "f1": f1,
122 |         }
123 |
124 |
125 | def main(training_args: TrainingArgs, config: ExperimentConfig):
126 |     tokenizer: PreTrainedTokenizer = AutoTokenizer.from_pretrained(config.model_name)
127 |     if tokenizer.pad_token_id is None:
128 |         tokenizer.pad_token_id = tokenizer.unk_token_id
129 |         tokenizer.add_eos_token = True
130 |
131 |     model = AutoModelForSequenceClassification.from_pretrained(
132 |         config.model_name,
133 |         num_labels=len(config.label2id),
134 |         label2id=config.label2id,
135 |         id2label={v: k for k, v in config.label2id.items()},
136 |         pad_token_id=tokenizer.pad_token_id,
137 |         use_cache=False,
138 |     )
139 |
140 |     datasets: ds.DatasetDict = ds.load_from_disk(str(config.dataset_dir))
141 |
142 |     data_collator = DataCollator(
143 |         tokenizer=tokenizer,
144 |         max_seq_len=config.max_seq_len,
145 |     )
146 |
147 |     compute_metrics = ComputeMetrics(labels=list(config.label2id.values()))
148 |
149 |     trainer = HFTrainer(
150 |         args=training_args,
151 |         model=model,
152 |         tokenizer=tokenizer,
153 |         train_dataset=datasets["train"],
154 |         eval_dataset=datasets["validation"],
155 |         data_collator=data_collator,
156 |         compute_metrics=compute_metrics,
157 |     )
158 |
159 |     trainer.train()
160 |     trainer._load_best_model()
161 |
162 |     trainer.save_model()
163 |     trainer.save_state()
164 |     trainer.tokenizer.save_pretrained(training_args.output_dir)
165 |
166 |     # Evaluate the best model on the validation and test sets
167 |     val_prediction_output: PredictionOutput = trainer.predict(test_dataset=datasets["validation"])
168 |     test_prediction_output: PredictionOutput = trainer.predict(test_dataset=datasets["test"])
169 |
170 |     if training_args.process_index == 0:
171 |         val_metrics: dict[str, float] = val_prediction_output.metrics
172 |         val_metrics = {k.replace("test_", ""): v for k, v in val_metrics.items()}
173 |
174 |         test_metrics: dict[str, float] = test_prediction_output.metrics
175 |         test_metrics = {k.replace("test_", ""): v for k, v in test_metrics.items()}
176 |
177 |         metrics = {
178 |             "best-val": val_metrics,
179 |             "test": test_metrics,
180 |         }
181 |
182 |         utils.save_json(metrics, Path(training_args.output_dir, "metrics.json"))
183 |
184 |         with Path(training_args.output_dir, "training_args.json").open("w") as f:
185 |             f.write(trainer.args.to_json_string())
186 |
187 |
188 | def summarize_config(training_args: TrainingArgs, config: ExperimentConfig) -> str:
189 |     accelerator = Accelerator()
190 |     batch_size = training_args.per_device_train_batch_size * accelerator.num_processes
191 |     config_summary = {
192 |         "B": batch_size,
193 |         "E": training_args.num_train_epochs,
194 |         "LR": training_args.learning_rate,
195 |         "L": config.max_seq_len,
196 |     }
197 |     config_summary = "".join(f"{k}{v}" for k, v in config_summary.items())
198 |     return config_summary
199 |
200 |
201 | if __name__ == "__main__":
202 |     parser = HfArgumentParser((TrainingArgs, ExperimentConfig))
203 |     training_args, config = parser.parse_args_into_dataclasses()
204 |     config_summary = summarize_config(training_args, config)
205 |     model_name = config.model_name.replace("/", "__")
206 |
207 |     training_args.output_dir = f"outputs/{model_name}/{config_summary}/{config.experiment_name}"
208 |     training_args.logging_dir = training_args.output_dir
209 |     training_args.run_name = f"{config_summary}/{config.experiment_name}"
210 |
211 |     main(training_args, config)
212 |
--------------------------------------------------------------------------------
/src/utils.py:
--------------------------------------------------------------------------------
1 | import json
2 | import random
3 | from collections.abc import Iterable
4 | from pathlib import Path
5 | from typing import Any
6 |
7 | import numpy as np
8 | import pandas as pd
9 | import torch
10 |
11 |
12 | def save_jsonl(data: Iterable | pd.DataFrame, path: Path | str) -> None:
13 |     path = Path(path)
14 |
15 |     if not isinstance(data, pd.DataFrame):
16 |         data = pd.DataFrame(data)
17 |
18 |     data.to_json(
19 |         path,
20 |         orient="records",
21 |         lines=True,
22 |         force_ascii=False,
23 |     )
24 |
25 |
26 | def save_json(data: dict[Any, Any], path: Path | str) -> None:
27 |     path = Path(path)
28 |     with path.open("w") as f:
29 |         json.dump(data, f, indent=2, ensure_ascii=False)
30 |
31 |
32 | def load_jsonl(path: Path | str) -> list[dict]:
33 |     path = Path(path)
34 |     df = pd.read_json(path, lines=True)
35 |     return df.to_dict(orient="records")
36 |
37 |
38 | def load_jsonl_df(path: Path | str) -> pd.DataFrame:
39 |     path = Path(path)
40 |     df = pd.read_json(path, lines=True)
41 |     return df
42 |
43 |
44 | def load_json(path: Path | str) -> dict:
45 |     path = Path(path)
46 |     with path.open() as f:
47 |         data = json.load(f)
48 |     return data
49 |
50 |
51 | def set_seed(seed: int):
52 |     random.seed(seed)
53 |     np.random.seed(seed)
54 |     torch.manual_seed(seed)
55 |     torch.cuda.manual_seed_all(seed)
56 |
57 |
58 | def log(data: dict, path: Path | str) -> None:
59 |     path = Path(path)
60 |
61 |     if path.exists():
62 |         df: pd.DataFrame = pd.read_csv(path)
63 |         df = pd.concat([df, pd.DataFrame([data])], ignore_index=True)
64 |         df.to_csv(path, index=False)
65 |     else:
66 |         pd.DataFrame([data]).to_csv(path, index=False)
67 |
68 |
69 | def get_current_timestamp() -> str:
70 |     return pd.Timestamp.now().strftime("%Y-%m-%d/%H:%M:%S")
71 |
--------------------------------------------------------------------------------
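A training run writes the fine-tuned checkpoint, tokenizer, and label mapping to `training_args.output_dir`. As a closing illustration of how to consume that output, the snippet below is a minimal sketch (not a file in this repository) that loads such a checkpoint and classifies a single article; the `model_dir` value is a hypothetical example of the `outputs/{model}/{config summary}/{experiment}` path pattern built in `src/train.py`, and the tokenization mirrors `DataCollator` (title as the first segment, body as the second, truncating only the body).

```python
# Illustrative inference sketch (not part of the repository).
# Assumes a finished training run; adjust model_dir to your own output directory.
from pathlib import Path

import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

# Hypothetical example of the output path pattern constructed in src/train.py
model_dir = Path("outputs/tohoku-nlp__bert-base-japanese-v3/B128E20LR3e-05L512/4-ds")

tokenizer = AutoTokenizer.from_pretrained(model_dir)
model = AutoModelForSequenceClassification.from_pretrained(model_dir)
model.eval()

title = "記事のタイトル"
body = "記事の本文"

# Same input format as DataCollator: title first, body second, truncate only the body
inputs = tokenizer(
    title,
    body,
    truncation="only_second",
    max_length=512,
    return_tensors="pt",
)

with torch.inference_mode():
    logits = model(**inputs).logits

predicted_id = int(logits.argmax(dim=-1))
print(model.config.id2label[predicted_id])  # one of the livedoor category names
```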