├── .gitignore ├── .pre-commit-config.yaml ├── LICENSE ├── README.md ├── configs ├── compute_mined_mapping.yaml ├── compute_tokenizer_info.yaml ├── cross_tokenizer_distill.yaml ├── data │ └── tulu3.yaml ├── eval.yaml ├── eval_lockstep.yaml ├── eval_lockstep_mined.yaml ├── math_cross_tokenizer_distill.yaml ├── math_same_tokenizer_distill.yaml ├── models │ ├── gemma_llama_qwen.yaml │ └── llama_qwen.yaml ├── optimizer │ └── adamw.yaml ├── train_zett_hn.yaml ├── train_zett_hn_gemma2.yaml └── zett.yaml ├── docs ├── byteification.md ├── pytorch_alm_from_scratch.ipynb └── tokenizer_transfer.md ├── examples ├── gemma2_distill_from_openmath2-llama_gpu.sh ├── llama3_to_byte_tokenizer_gpu.sh ├── llama3_to_qwen2_tokenizer_gpu.sh └── llama3_to_qwen2_tokenizer_tpu.sh ├── pyproject.toml ├── requirements.txt ├── rust_utils ├── .gitignore ├── Cargo.lock ├── Cargo.toml ├── pyproject.toml ├── rust-toolchain ├── src │ └── lib.rs └── tpu_build.sh ├── scripts ├── compute_mined_mapping.py ├── compute_tokenizer_info.py ├── cross_tokenizer_distill.py ├── eval.py ├── eval_lockstep.py ├── export_checkpoint.py ├── push_flax_version_to_hub.py ├── train_zett_hn.py └── zett.py ├── setup.py └── tokenkit ├── __init__.py ├── align.py ├── baseline_utils.py ├── byteify.py ├── compat ├── hyper_roberta.py └── hypernet.py ├── constants.py ├── data └── __init__.py ├── eval ├── __init__.py └── generate.py ├── gcs_utils.py ├── hf ├── __init__.py ├── configuration_tpu_gemma2.py ├── configuration_tpu_llama.py ├── modelling_flax_tpu_gemma2.py ├── modelling_flax_tpu_llama.py ├── modelling_tpu_gemma2.py └── modelling_tpu_llama.py ├── model_kinds.py ├── models ├── __init__.py ├── hypernet │ └── __init__.py ├── lora.py ├── param.py └── sharding.py ├── parse_args.py ├── training ├── __init__.py ├── checkpoint.py ├── collators │ ├── __init__.py │ ├── tokenizer_aligner.py │ └── tokenizer_sampler.py ├── losses.py ├── lr.py ├── multitask.py └── opt.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | outputs 2 | tokenkit_env 3 | wandb 4 | artifacts 5 | logs 6 | 7 | ### Python ### 8 | # Byte-compiled / optimized / DLL files 9 | __pycache__/ 10 | *.py[cod] 11 | *$py.class 12 | 13 | # C extensions 14 | *.so 15 | 16 | # Distribution / packaging 17 | .Python 18 | build/ 19 | develop-eggs/ 20 | dist/ 21 | downloads/ 22 | eggs/ 23 | .eggs/ 24 | lib/ 25 | lib64/ 26 | parts/ 27 | sdist/ 28 | var/ 29 | wheels/ 30 | share/python-wheels/ 31 | *.egg-info/ 32 | .installed.cfg 33 | *.egg 34 | MANIFEST 35 | 36 | # PyInstaller 37 | # Usually these files are written by a python script from a template 38 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
39 | *.manifest 40 | *.spec 41 | 42 | # Installer logs 43 | pip-log.txt 44 | pip-delete-this-directory.txt 45 | 46 | # Unit test / coverage reports 47 | htmlcov/ 48 | .tox/ 49 | .nox/ 50 | .coverage 51 | .coverage.* 52 | .cache 53 | nosetests.xml 54 | coverage.xml 55 | *.cover 56 | *.py,cover 57 | .hypothesis/ 58 | .pytest_cache/ 59 | cover/ 60 | 61 | # Translations 62 | *.mo 63 | *.pot 64 | 65 | # Django stuff: 66 | *.log 67 | local_settings.py 68 | db.sqlite3 69 | db.sqlite3-journal 70 | 71 | # Flask stuff: 72 | instance/ 73 | .webassets-cache 74 | 75 | # Scrapy stuff: 76 | .scrapy 77 | 78 | # Sphinx documentation 79 | docs/_build/ 80 | 81 | # PyBuilder 82 | .pybuilder/ 83 | target/ 84 | 85 | # Jupyter Notebook 86 | .ipynb_checkpoints 87 | 88 | # IPython 89 | profile_default/ 90 | ipython_config.py 91 | 92 | # pyenv 93 | # For a library or package, you might want to ignore these files since the code is 94 | # intended to run in multiple environments; otherwise, check them in: 95 | # .python-version 96 | 97 | # pipenv 98 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 99 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 100 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 101 | # install all needed dependencies. 102 | #Pipfile.lock 103 | 104 | # poetry 105 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 106 | # This is especially recommended for binary packages to ensure reproducibility, and is more 107 | # commonly ignored for libraries. 108 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 109 | #poetry.lock 110 | 111 | # pdm 112 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 113 | #pdm.lock 114 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 115 | # in version control. 116 | # https://pdm.fming.dev/#use-with-ide 117 | .pdm.toml 118 | 119 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 120 | __pypackages__/ 121 | 122 | # Celery stuff 123 | celerybeat-schedule 124 | celerybeat.pid 125 | 126 | # SageMath parsed files 127 | *.sage.py 128 | 129 | # Environments 130 | .env 131 | .venv 132 | env/ 133 | venv/ 134 | ENV/ 135 | env.bak/ 136 | venv.bak/ 137 | 138 | # Spyder project settings 139 | .spyderproject 140 | .spyproject 141 | 142 | # Rope project settings 143 | .ropeproject 144 | 145 | # mkdocs documentation 146 | /site 147 | 148 | # mypy 149 | .mypy_cache/ 150 | .dmypy.json 151 | dmypy.json 152 | 153 | # Pyre type checker 154 | .pyre/ 155 | 156 | # pytype static type analyzer 157 | .pytype/ 158 | 159 | # Cython debug symbols 160 | cython_debug/ 161 | 162 | # PyCharm 163 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 164 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 165 | # and can be added to the global gitignore or merged into this file. For a more nuclear 166 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
167 | #.idea/ 168 | 169 | ### Python Patch ### 170 | # Poetry local configuration file - https://python-poetry.org/docs/configuration/#local-configuration 171 | poetry.toml 172 | 173 | # ruff 174 | .ruff_cache/ 175 | 176 | # LSP config files 177 | pyrightconfig.json -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/psf/black 3 | rev: 24.2.0 4 | hooks: 5 | - id: black 6 | language_version: python3 7 | 8 | - repo: https://github.com/pycqa/isort 9 | rev: 5.13.2 10 | hooks: 11 | - id: isort 12 | name: isort (python) -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. 
For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |

# tokenkit 🔁

Tokenization Transfer for LLMs
7 | 8 | `tokenkit` is a toolkit implementing advanced methods to transfer *models* and *model knowledge* across tokenizers. 9 | 10 | ## News 11 | 12 | - __2025-04-23__: A new guide on [implementing cross-tokenizer distillation via ALM from scratch in PyTorch](./docs/pytorch_alm_from_scratch.ipynb)! 🔥 13 | - __2025-04-22__: New [Llama3-2-3B-IT-Byte](https://huggingface.co/benjamin/Llama3-2-3B-IT-Byte) and [Gemma2-2B-IT-Byte](https://huggingface.co/benjamin/Gemma2-2B-IT-Byte) checkpoints with native `transformers` support (plus, documentation on how to train them). Also, new guides for [running tokenizer transfer](./docs/tokenizer_transfer.md) and [byteification](./docs/byteification.md)! 14 | - __2025-04-02__: The initial release of `tokenkit` with support for cross-tokenizer distillation via ALM and Zero-Shot Tokenizer Transfer via FVT! 15 | 16 | ## Contents 17 | - [Why Transfer Across Tokenizers?](#why-transfer-across-tokenizers) 18 | - [Installation](#installation) 19 | - [Quickstart](#quickstart) 20 | - [Guides](#guides) 21 | - [Tokenizer Transfer via tokenkit](./docs/tokenizer_transfer.md) 22 | - [Byteification: A Unified Interface to Tokenizers](./docs/byteification.md) 23 | - [Implementing ALM From Scratch in PyTorch](./docs/pytorch_alm_from_scratch.ipynb) (new! 🔥) 24 | - [Features](#features) 25 | - [Cross-Tokenizer Distillation](#cross-tokenizer-distillation) 26 | - [Zero-Shot Tokenizer Transfer](#zero-shot-tokenizer-transfer) 27 | - [Token-Level Ensembling & Evaluating Transferred Models](#token-level-ensembling--evaluating-transferred-models) 28 | - [Citation](#citation) 29 | - [Acknowledgments](#acknowledgments) 30 | 31 | ## Why Transfer Across Tokenizers? 32 | 33 | LLMs are bound to the tokenizer they were pretrained with. This limits their adaptability, reusability and modularity. Tokenizer transfer can lift this limitation. For example: 34 | - If we want to reuse an LLM trained primarily on English in another language, we might want to update its tokenizer to one that is more suitable for the new language. 35 | - If we want to combine (e.g., token-level ensemble) two LLMs, we need to transfer them to a common tokenizer. 36 | - If we want to experiment with better tokenization schemes (e.g., byte-level tokenization), we might want to transfer an existing LLM to this tokenizer instead of training a new one expensively from scratch. 37 | - If we want to transfer knowledge from a large teacher model to a smaller student model (which uses another tokenizer), we might want to use *cross-tokenizer distillation* to directly transfer the teacher's knowledge to the student without the need to first transfer the teacher to the student's tokenizer. 38 | 39 | This library aims to let you accomplish all of this. 40 | 41 | ## Installation 42 | 43 | `tokenkit` is primarily implemented in Jax, using PyTorch for data loading (so your PyTorch installation does not need to support an accelerator). Recommended installation: 44 | 45 |
46 | TPU 47 | 48 | ```bash 49 | # Clone the repository & install the library 50 | git clone https://github.com/bminixhofer/tokenkit 51 | 52 | # Create a new virtual environment 53 | # Currently, requires Python <=3.10, but we are working on this: https://github.com/bminixhofer/tokenkit/issues/4 54 | python -m venv tokenkit_env 55 | . tokenkit_env/bin/activate 56 | 57 | # Install torch & jax 0.5.0 58 | pip install torch jax[tpu]==0.5.0 -f https://storage.googleapis.com/jax-releases/libtpu_releases.html 59 | 60 | # Currently, tokenkit relies on a fork of `lm_eval` 61 | pip install git+https://github.com/bminixhofer/lm-evaluation-harness 62 | 63 | # Install the library and the remaining dependencies 64 | pip install -r requirements.txt 65 | pip install -e . 66 | # You can ignore warnings from the command below, see https://github.com/bminixhofer/tokenkit/issues/4 67 | pip install paxml==1.4.0 praxis==1.4.0 --no-deps 68 | ``` 69 |
70 | 71 |
72 | GPU 73 | 74 | ```bash 75 | # Clone the repository & install the library 76 | git clone https://github.com/bminixhofer/tokenkit 77 | 78 | # Create a new virtual environment 79 | # Currently, requires Python <=3.10, but we are working on this: https://github.com/bminixhofer/tokenkit/issues/4 80 | python -m venv tokenkit_env 81 | . tokenkit_env/bin/activate 82 | 83 | # Install torch & jax 0.5.0 84 | # you may need to substitute cuda12 with the version of CUDA you are using: 85 | pip install torch jax[cuda12]==0.5.0 86 | 87 | # Currently, tokenkit relies on a fork of `lm_eval` 88 | pip install git+https://github.com/bminixhofer/lm-evaluation-harness 89 | 90 | # Install the library and the remaining dependencies 91 | pip install -r requirements.txt 92 | pip install -e . 93 | # You can ignore warnings from the command below, see https://github.com/bminixhofer/tokenkit/issues/4 94 | pip install paxml==1.4.0 praxis==1.4.0 --no-deps 95 | ``` 96 |
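Whichever variant you installed, a quick sanity check (a minimal sketch, not part of the official setup) is to confirm that JAX sees your accelerator before moving on:

```python
# Minimal post-install check; run inside the activated tokenkit_env environment.
import jax

print(jax.__version__)  # expected: 0.5.0, matching the pinned version above
print(jax.devices())    # should list TPU cores or CUDA devices, not just CPU
```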
97 | 98 | ## Quickstart 99 | 100 | After installing the library, you can play around with the scripts in `examples/` to get started immediately. For example: 101 | 102 | ``` 103 | bash examples/llama3_to_byte_tokenizer_gpu.sh 104 | ``` 105 | 106 | If you're interested in reproducing or improving on a public model which has been trained via ALM, you can also take a look at the `tokenkit` command used to train that model, for example [in the Training section of the Llama3-2-3B-IT-Byte model card](https://huggingface.co/benjamin/Llama3-2-3B-IT-Byte#training). 107 | 108 | ## Guides 109 | 110 | - [Tokenizer Transfer via tokenkit](./docs/tokenizer_transfer.md) (start here!) 111 | - [Byteification: A Unified Interface to Tokenizers](./docs/byteification.md) 112 | - [Implementing ALM From Scratch in PyTorch](./docs/pytorch_alm_from_scratch.ipynb) (interactive notebook) 113 | 114 | ## Features 115 | 116 | ### Cross-Tokenizer Distillation 117 | 118 | `tokenkit` supports [Approximate Likelihood Matching (ALM)](https://arxiv.org/abs/2503.20083) for cross-tokenizer distillation. ALM usually performs best, but we have also implemented the following baselines: 119 | 120 | - [Dual Space Knowledge Distillation (DSKD)](https://arxiv.org/abs/2406.17328) 121 | - [Universal Logit Distillation (ULD)](https://arxiv.org/abs/2402.12030) 122 | - [Minimum Edit Distance Logit Alignment (MinED)](https://arxiv.org/abs/2401.10491) 123 | 124 | You can run cross-tokenizer distillation using the [`scripts/cross_tokenizer_distill.py`](scripts/cross_tokenizer_distill.py) script. See [`examples`](examples) for examples of transferring to different subword tokenizers and to byte-level tokenization. 125 | 126 | ### Zero-Shot Tokenizer Transfer 127 | 128 | `tokenkit` supports Zero-Shot Tokenizer Transfer (ZeTT) via [Fast Vocabulary Transfer (FVT)](https://aclanthology.org/2022.emnlp-industry.41). Zero-Shot Tokenizer Transfer is usually used to obtain a good initialization for additional training, but can in some cases also be useful on its own. See our [ZeTT paper](https://arxiv.org/abs/2405.07883) for more details. 129 | 130 | You can run Zero-Shot Tokenizer Transfer using the [`scripts/zett.py`](scripts/zett.py) script. 131 | 132 | **🚧 We are working on implementing more ZeTT methods (including hypernetwork training introduced [here](https://arxiv.org/abs/2405.07883)).** 133 | 134 | ### Token-Level Ensembling & Evaluating Transferred Models 135 | 136 | `tokenkit` supports autoregressive generation & loglikelihood scoring evaluation by implementing a Jax backend to the [LM Evaluation Harness](https://github.com/EleutherAI/lm-evaluation-harness). Alongside generating from single models, you can also generate from *token-level ensembles* of models. There are some predefined ensembles in [`configs/models`](configs/models). For example, this evaluates a token-level ensemble of Llama and Qwen on MMLU: 137 | 138 | ```bash 139 | python3 scripts/eval_lockstep.py \ 140 | models=llama_qwen \ 141 | eval.tasks=[mmlu] 142 | ``` 143 | 144 | To evaluate pretrained byte-level models, you'll need to pass embeddings to expand the input ids with (i.e., to use as n-gram embeddings).
For example: 145 | 146 | ```bash 147 | python3 scripts/eval.py \ 148 | model.pretrained_model_name_or_path=\'benjamin/Gemma2-2B-IT-Byte\' \ 149 | model.tokenizer_name=\'google/gemma-2-2b-it:source=Gemma2:conversion=byte\' \ 150 | expand_model.pretrained_model_name_or_path=\'benjamin/gemma-2-2b-it-flax\' \ 151 | expand_model.tokenizer_name=\'google/gemma-2-2b-it:source=Gemma2\' \ 152 | eval.tasks=[mmlu] 153 | ``` 154 | 155 | To evaluate any other model (e.g., subword-to-subword transferred models), use something like the following: 156 | 157 | ```bash 158 | python3 scripts/eval.py \ 159 | model.pretrained_model_name_or_path=\'benjamin/Gemma2-2B-IT-with-Qwen2-Tokenizer\' \ 160 | model.tokenizer_name=\'benjamin/Gemma2-2B-IT-with-Qwen2-Tokenizer:source=Gemma2:conversion=prebyteified\' \ 161 | eval.tasks=[mmlu] \ 162 | ``` 163 | 164 | ## Citation 165 | 166 | To refer to this repository or to cite Approximate Likelihood Matching, please use this citation: 167 | 168 | ``` 169 | @article{alm, 170 | title={Cross-Tokenizer Distillation via Approximate Likelihood Matching}, 171 | author={Minixhofer, Benjamin and Vuli{\'c}, Ivan and Ponti, Edoardo Maria}, 172 | journal={arXiv preprint arXiv:2503.20083}, 173 | year={2025} 174 | } 175 | ``` 176 | 177 | Please use this citation for Zero-Shot Tokenizer Transfer: 178 | 179 | ``` 180 | @inproceedings{zett, 181 | title={Zero-Shot Tokenizer Transfer}, 182 | author={Benjamin Minixhofer and Edoardo Ponti and Ivan Vuli{\'c}}, 183 | booktitle={The Thirty-eighth Annual Conference on Neural Information Processing Systems}, 184 | year={2024}, 185 | url={https://openreview.net/forum?id=RwBObRsIzC} 186 | } 187 | ``` 188 | 189 | ## Acknowledgments 190 | 191 | Constituent projects (ALM, ZeTT) were supported by a Royal Society University Research Fellowship ‘Inclusive and Sustainable Language Technology for a Truly Multilingual World’ (no 221137; 2022-) awarded to Ivan Vulić, by the Google Cloud Research Credits program with the award GCP329647813, and by Cloud TPUs from Google’s TPU Research Cloud (TRC). The name `tokenkit` and the README layout were inspired by [mergekit](https://github.com/arcee-ai/mergekit). [big_vision](https://github.com/google-research/big_vision) was extremely useful as a high-quality reference JAX training codebase. 
192 | -------------------------------------------------------------------------------- /configs/compute_mined_mapping.yaml: -------------------------------------------------------------------------------- 1 | teacher_tokenizer_name: google/gemma-2-2b-it:source=Gemma2 2 | target_tokenizer_name: Qwen/Qwen2.5-1.5B:source=Qwen2:target=Gemma2 3 | output: outputs/tokenizer_data/gemma2_to_qwen2 4 | num_workers: 64 -------------------------------------------------------------------------------- /configs/compute_tokenizer_info.yaml: -------------------------------------------------------------------------------- 1 | teacher_tokenizer_name: google/gemma-2-2b-it:source=Gemma2 2 | target_tokenizer_name: Qwen/Qwen2.5-1.5B:source=Qwen2:target=Gemma2 3 | output: outputs/tokenizer_data/gemma2_to_qwen2 4 | seed: 1234 5 | additive_smoothing_constant: 1e-9 6 | teacher_subsample_percent: null 7 | student_subsample_percent: null 8 | 9 | data: 10 | batch_size: 1024 11 | num_workers: 64 12 | 13 | defaults: 14 | - _self_ 15 | - data: tulu3 -------------------------------------------------------------------------------- /configs/cross_tokenizer_distill.yaml: -------------------------------------------------------------------------------- 1 | steps: 5_000 2 | warmup_steps: 2_000 3 | name: "unnamed" 4 | output: "outputs/cross_tokenizer_distill" 5 | num_workers: 16 6 | log_interval: 10 7 | sync_interval: 100 8 | eval_interval: 5000 9 | save_interval: 5000 10 | losses: [alm_unbiased] 11 | target_tokenizer_name: "Qwen/Qwen2.5-1.5B:source=Qwen2:target=Gemma2" 12 | 13 | train_model_mode: "lora" 14 | model_lora_rank: 64 15 | model_lora_alpha: 64 16 | train_embeddings: true 17 | 18 | bce_temp: 100.0 19 | alm_diff_fn: "binary_ce" 20 | alm_mode: "append_space" 21 | tokenizer_pair_data_path: "artifacts/tokenizer_data/gemma2_to_qwen2" 22 | tokenizer_pair_bias_threshold: 1.e-4 23 | tokenizer_pair_bias_threshold_side_path: null 24 | 25 | student: 26 | pretrained_model_name_or_path: "benjamin/gemma-2-2b-it-flax" 27 | tokenizer_name: "google/gemma-2-2b-it:source=Gemma2" 28 | 29 | data: 30 | batch_size: 16 31 | num_workers: 16 32 | kind: "hf" 33 | mix_languages: false 34 | streaming: false 35 | shuffle_buffer_size: "inf" 36 | dataset_configs: 37 | - lang_code: en 38 | kwargs: 39 | path: allenai/tulu-3-sft-mixture 40 | split: train 41 | 42 | hypernet: 43 | architecture: transformer 44 | num_layers: 1 45 | residual: true 46 | residual_alpha: 1 47 | use_attention: false 48 | 49 | optimizer: 50 | type: adamw 51 | weight_decay: 0.01 52 | b1: 0.9 53 | b2: 0.95 54 | eps: 1.e-8 55 | grad_acc_steps: null 56 | learning_rate: 1.e-5 57 | max_grad_norm: 1.0 58 | param_groups: 59 | - pattern: .*(projector_query|projector_s2t|projector_t2s|projector_latents|loss_weights).* 60 | lr_scale: 2 61 | 62 | eval: 63 | tasks: [arc_easy,arc_challenge,piqa,hellaswag,boolq,arithmetic,mmlu] 64 | lengths: [128, 256, 512, 1024, 2048] 65 | tokens_per_batch: 8192 66 | add_bos: true 67 | chat_template_mode: surround_instruct 68 | confirm_run_unsafe_code: true -------------------------------------------------------------------------------- /configs/data/tulu3.yaml: -------------------------------------------------------------------------------- 1 | kind: "hf" 2 | mix_languages: false 3 | streaming: false 4 | shuffle_buffer_size: "inf" 5 | dataset_configs: 6 | - lang_code: en 7 | kwargs: 8 | path: allenai/tulu-3-sft-mixture 9 | split: train -------------------------------------------------------------------------------- /configs/eval.yaml: 
-------------------------------------------------------------------------------- 1 | use_cpu: false 2 | pad_to_multiple_of: 128 3 | output: outputs/eval 4 | 5 | model: 6 | pretrained_model_name_or_path: "benjamin/gemma-2-2b-it-flax" 7 | tokenizer_name: "google/gemma-2-2b-it:source=Gemma2" 8 | 9 | expand_model: 10 | pretrained_model_name_or_path: null 11 | tokenizer_name: null 12 | 13 | eval: 14 | tasks: ["piqa", "boolq", "arc_easy", "hellaswag"] 15 | lengths: [128, 256, 512, 1024, 2048] 16 | tokens_per_batch: 4096 17 | add_bos: true 18 | chat_template_mode: surround_instruct 19 | confirm_run_unsafe_code: true 20 | -------------------------------------------------------------------------------- /configs/eval_lockstep.yaml: -------------------------------------------------------------------------------- 1 | combine_strategy: "mean_prob" 2 | pad_to_multiple_of: 128 3 | output: outputs/eval_lockstep 4 | 5 | eval: 6 | tasks: [arc_easy,arc_challenge,piqa,boolq,arithmetic,mmlu,ifeval,agieval_en,agieval_cn] 7 | lengths: [128, 256, 512, 1024, 2048] 8 | tokens_per_batch: 4096 9 | add_bos: true 10 | chat_template_mode: surround_instruct 11 | confirm_run_unsafe_code: true 12 | 13 | models: 14 | - pretrained_model_name_or_path: "/mnt/disks/persist/exports/20250424160833_gemma2_to_qwen2_ours_agg_approx_gradmag_preserve_mag_20k" 15 | tokenizer_name: "Qwen/Qwen2.5-1.5B:source=Qwen2:target=Gemma2" 16 | add_bos: true 17 | - pretrained_model_name_or_path: "/mnt/disks/persist/exports/20250424174156_llama3_to_qwen2_ours_agg_approx_gradmag_preserve_mag_20k" 18 | tokenizer_name: "Qwen/Qwen2.5-1.5B:source=Qwen2:target=Llama3" 19 | add_bos: true 20 | - pretrained_model_name_or_path: "benjamin/Qwen2.5-1.5B-Instruct-flax" 21 | tokenizer_name: "Qwen/Qwen2-1.5B-Instruct:source=Qwen2" 22 | add_bos: false -------------------------------------------------------------------------------- /configs/eval_lockstep_mined.yaml: -------------------------------------------------------------------------------- 1 | combine_strategy: "mean_prob" 2 | pad_to_multiple_of: 128 3 | output: outputs/eval_lockstep 4 | 5 | eval: 6 | tasks: [arc_easy,arc_challenge,piqa,boolq,arithmetic,mmlu,ifeval,agieval_en,agieval_cn] 7 | lengths: [128, 256, 512, 1024, 2048] 8 | tokens_per_batch: 4096 9 | add_bos: true 10 | chat_template_mode: surround_instruct 11 | confirm_run_unsafe_code: true 12 | 13 | models: 14 | - pretrained_model_name_or_path: "benjamin/Qwen2.5-1.5B-Instruct-flax" # pivot first 15 | tokenizer_name: "Qwen/Qwen2-1.5B-Instruct:source=Qwen2" 16 | add_bos: false 17 | - pretrained_model_name_or_path: "benjamin/gemma-2-2b-it-flax" 18 | tokenizer_name: "google/gemma-2-2b-it:source=Gemma2" 19 | add_bos: true 20 | - pretrained_model_name_or_path: "benjamin/Llama-3.2-3B-Instruct-flax" 21 | tokenizer_name: "meta-llama/Llama-3.2-3B-Instruct:source=Llama3" 22 | add_bos: true 23 | 24 | baseline_mined_mapping_paths: [null, "artifacts/tokenizer_data/gemma2_to_qwen2", "artifacts/tokenizer_data/llama3_to_qwen2"] -------------------------------------------------------------------------------- /configs/math_cross_tokenizer_distill.yaml: -------------------------------------------------------------------------------- 1 | steps: 5_000 2 | warmup_steps: 2_000 3 | name: "unnamed" 4 | output: "outputs/cross_tokenizer_distill" 5 | num_workers: 16 6 | log_interval: 10 7 | sync_interval: 100 8 | eval_interval: 5000 9 | save_interval: 5000 10 | losses: [alm_unbiased] 11 | target_tokenizer_name: google/gemma-2-2b-it:source=Gemma2 12 | tokens_to_add: 
[<|start_header_id|>,<|end_header_id|>,<|eot_id|>,<|eom_id|>,<|python_tag|>] 13 | 14 | train_model_mode: "lora" 15 | model_lora_rank: 64 16 | model_lora_alpha: 64 17 | train_embeddings: true 18 | 19 | bce_temp: 100.0 20 | alm_diff_fn: "binary_ce" 21 | alm_mode: "space_merge+append_space" 22 | tokenizer_pair_data_path: "artifacts/tokenizer_data/math_llama3_to_gemma2" 23 | tokenizer_pair_bias_threshold: 0.1 24 | tokenizer_pair_bias_threshold_side_path: null 25 | 26 | student: 27 | pretrained_model_name_or_path: "benjamin/gemma-2-2b-it-flax" 28 | tokenizer_name: "google/gemma-2-2b-it:source=Gemma2" 29 | 30 | teacher: 31 | pretrained_model_name_or_path: "benjamin/OpenMath2-Llama3.1-8B-flax" 32 | tokenizer_name: "nvidia/OpenMath2-Llama3.1-8B:source=Llama3" 33 | 34 | data: 35 | batch_size: 16 36 | num_workers: 16 37 | kind: "hf" 38 | mix_languages: false 39 | streaming: false 40 | dataset_configs: 41 | - lang_code: en 42 | kwargs: 43 | path: benjamin/OpenMathInstruct-2-2M-formatted 44 | split: train 45 | 46 | hypernet: 47 | architecture: transformer 48 | num_layers: 1 49 | residual: true 50 | residual_alpha: 1 51 | use_attention: false 52 | 53 | optimizer: 54 | type: adamw 55 | weight_decay: 0.01 56 | b1: 0.9 57 | b2: 0.95 58 | eps: 1.e-8 59 | grad_acc_steps: null 60 | learning_rate: 1.e-5 61 | max_grad_norm: 1.0 62 | param_groups: 63 | - pattern: .*(projector_query|projector_s2t|projector_t2s|projector_latents|loss_weights).* 64 | lr_scale: 2 65 | 66 | eval: 67 | tasks: [arc_easy,arc_challenge,piqa,hellaswag,boolq,arithmetic,mmlu] 68 | lengths: [128, 256, 512, 1024, 2048] 69 | tokens_per_batch: 8192 70 | add_bos: true 71 | chat_template_mode: surround_instruct 72 | confirm_run_unsafe_code: true -------------------------------------------------------------------------------- /configs/math_same_tokenizer_distill.yaml: -------------------------------------------------------------------------------- 1 | steps: 5_000 2 | warmup_steps: 2_000 3 | name: "unnamed" 4 | output: "outputs/same_tokenizer_distill" 5 | num_workers: 16 6 | log_interval: 10 7 | sync_interval: 100 8 | eval_interval: 5000 9 | save_interval: 5000 10 | losses: [alm_unbiased] 11 | target_tokenizer_name: meta-llama/Llama-3.2-3B-Instruct:source=Llama3 12 | 13 | train_model_mode: "lora" 14 | model_lora_rank: 64 15 | model_lora_alpha: 64 16 | train_embeddings: true 17 | 18 | bce_temp: 100.0 19 | alm_diff_fn: "binary_ce" 20 | alm_mode: "space_merge+append_space" 21 | tokenizer_pair_data_path: "artifacts/tokenizer_data/math_llama3_to_llama3" 22 | tokenizer_pair_bias_threshold: 0.1 23 | tokenizer_pair_bias_threshold_side_path: null 24 | 25 | student: 26 | pretrained_model_name_or_path: "benjamin/Llama-3.2-3B-Instruct-flax" 27 | tokenizer_name: "meta-llama/Llama-3.2-3B-Instruct:source=Llama3" 28 | 29 | teacher: 30 | pretrained_model_name_or_path: "benjamin/OpenMath2-Llama3.1-8B-flax" 31 | tokenizer_name: "nvidia/OpenMath2-Llama3.1-8B:source=Llama3" 32 | 33 | data: 34 | batch_size: 16 35 | num_workers: 16 36 | kind: "hf" 37 | mix_languages: false 38 | streaming: false 39 | dataset_configs: 40 | - lang_code: en 41 | kwargs: 42 | path: benjamin/OpenMathInstruct-2-2M-formatted 43 | split: train 44 | 45 | hypernet: 46 | architecture: transformer 47 | num_layers: 1 48 | residual: true 49 | residual_alpha: 1 50 | use_attention: false 51 | 52 | optimizer: 53 | type: adamw 54 | weight_decay: 0.01 55 | b1: 0.9 56 | b2: 0.95 57 | eps: 1.e-8 58 | grad_acc_steps: null 59 | learning_rate: 1.e-5 60 | max_grad_norm: 1.0 61 | param_groups: 62 | - pattern: 
.*(projector_query|projector_s2t|projector_t2s|projector_latents|loss_weights).* 63 | lr_scale: 2 64 | 65 | eval: 66 | tasks: [arc_easy,arc_challenge,piqa,hellaswag,boolq,arithmetic,mmlu] 67 | lengths: [128, 256, 512, 1024, 2048] 68 | tokens_per_batch: 8192 69 | add_bos: true 70 | chat_template_mode: surround_instruct 71 | confirm_run_unsafe_code: true -------------------------------------------------------------------------------- /configs/models/gemma_llama_qwen.yaml: -------------------------------------------------------------------------------- 1 | - pretrained_model_name_or_path: "benjamin/Gemma2-2B-IT-with-Qwen2-Tokenizer" 2 | tokenizer_name: "benjamin/Gemma2-2B-IT-with-Qwen2-Tokenizer:source=Gemma2:conversion=prebyteified" 3 | add_bos: true 4 | - pretrained_model_name_or_path: "benjamin/Llama3.2-3B-IT-with-Qwen2-Tokenizer" 5 | tokenizer_name: "benjamin/Llama3.2-3B-IT-with-Qwen2-Tokenizer:source=Llama3:conversion=prebyteified" 6 | add_bos: true 7 | - pretrained_model_name_or_path: "benjamin/Qwen2.5-1.5B-Instruct-flax" 8 | tokenizer_name: "Qwen/Qwen2-1.5B-Instruct:source=Qwen2" 9 | add_bos: false -------------------------------------------------------------------------------- /configs/models/llama_qwen.yaml: -------------------------------------------------------------------------------- 1 | - pretrained_model_name_or_path: "benjamin/Gemma2-2B-IT-with-Qwen2-Tokenizer" 2 | tokenizer_name: "benjamin/Gemma2-2B-IT-with-Qwen2-Tokenizer:source=Qwen2" 3 | add_bos: true 4 | - pretrained_model_name_or_path: "benjamin/Qwen2.5-1.5B-Instruct-flax" 5 | tokenizer_name: "Qwen/Qwen2-1.5B-Instruct:source=Qwen2" 6 | add_bos: false -------------------------------------------------------------------------------- /configs/optimizer/adamw.yaml: -------------------------------------------------------------------------------- 1 | type: adamw 2 | weight_decay: 0.01 3 | b1: 0.9 4 | b2: 0.95 5 | eps: 1e-8 6 | grad_acc_steps: null -------------------------------------------------------------------------------- /configs/train_zett_hn.yaml: -------------------------------------------------------------------------------- 1 | losses: ["sft"] 2 | output: "outputs/hn" 3 | seed: 1234 4 | dtype: bfloat16 5 | pad_to_multiple_of: 128 6 | identity_steps: 0 7 | identity_lr: 3.e-4 8 | warmup_steps: 10_000 9 | steps: 200_000 10 | train_embeddings: false 11 | train_model_mode: "no" 12 | num_workers: 64 13 | name: "unnamed" 14 | compat: false 15 | eval_at_step_zero: false 16 | save_at_step_zero: false 17 | 18 | n_data_parallel: 1 19 | n_model_parallel: 8 20 | 21 | log_interval: 50 22 | sync_interval: 100 23 | eval_interval: 10_000 24 | save_interval: 10_000 25 | 26 | ppl_eval_data: null 27 | 28 | optimizer: 29 | learning_rate: 6e-5 30 | 31 | eval: 32 | tasks: [piqa,hellaswag,arc_easy] 33 | lengths: [128, 256, 512, 1024, 2048] 34 | tokens_per_batch: 8192 35 | add_bos: true 36 | chat_template_mode: direct_encode 37 | confirm_run_unsafe_code: true 38 | tokenizers: 39 | - tokenizer: openai-community/gpt2:source=GPT2:target=TinyLlama:conversion=manual_add_prefix_space 40 | name: gpt2 41 | - tokenizer: mistralai/Mistral-Small-3.1-24B-Base-2503:source=Mistral:target=TinyLlama:conversion=manual_add_prefix_space 42 | name: mistral 43 | - tokenizer: meta-llama/Llama-3.2-3B:source=Llama3:target=TinyLlama:conversion=manual_add_prefix_space 44 | name: llama3 45 | 46 | data: 47 | batch_size: 128 48 | num_workers: 16 49 | kind: hf 50 | mix_languages: false 51 | streaming: true 52 | dataset_configs: 53 | - lang_code: en 54 | kwargs: 55 | 
path: "allenai/madlad-400" 56 | name: "en" 57 | split: "clean" 58 | 59 | # TODO: disentangle data/collator args 60 | collator: 61 | do_tokenizer_sampling: true 62 | sample_text_span: true 63 | n_pools: 1 64 | add_prefix_space: true 65 | hn_surface_maxlen: 8 66 | n_token_subsample: null 67 | identity_n_token_subsample: 16384 68 | pad_to_multiple_of: 128 69 | tokenizer_sample_max: 32768 70 | tokenizer_sample_mean: 32768 71 | tokenizer_sample_min: 32768 72 | tokenizer_sample_std: 0 73 | tokenizer_batch_size: 2048 74 | tokenizer_noise_std: 4 75 | tokenizer_noise_mean: 1.e-5 76 | block_size: 128 77 | 78 | hypernet: 79 | architecture: transformer 80 | residual_alpha: 1 81 | residual: true 82 | use_attention: true 83 | num_layers: 3 84 | shared: true 85 | num_heads: 16 86 | use_attention_mask: false 87 | multiply_hidden_dim_by_num_embeddings: false 88 | 89 | optimizer: 90 | type: adamw 91 | weight_decay: 0.01 92 | b1: 0.9 93 | b2: 0.95 94 | eps: 1.e-8 95 | grad_acc_steps: null 96 | learning_rate: 1.e-5 97 | max_grad_norm: 1.0 98 | param_groups: 99 | - pattern: .*(projector_query|projector_s2t|projector_t2s|projector_latents|loss_weights).* 100 | lr_scale: 2 101 | 102 | model: 103 | pretrained_model_name_or_path: "TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T" 104 | tokenizer_name: "TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T:source=TinyLlama" 105 | revision: "refs/pr/8" -------------------------------------------------------------------------------- /configs/train_zett_hn_gemma2.yaml: -------------------------------------------------------------------------------- 1 | losses: ["sft"] 2 | output: "outputs/hn" 3 | seed: 1234 4 | dtype: bfloat16 5 | pad_to_multiple_of: 128 6 | identity_steps: 0 7 | identity_lr: 3.e-4 8 | warmup_steps: 10_000 9 | steps: 200_000 10 | train_embeddings: false 11 | train_model_mode: "no" 12 | num_workers: 64 13 | name: "unnamed" 14 | compat: false 15 | eval_at_step_zero: false 16 | save_at_step_zero: false 17 | 18 | n_data_parallel: 1 19 | n_model_parallel: 8 20 | 21 | log_interval: 50 22 | sync_interval: 100 23 | eval_interval: 10_000 24 | save_interval: 10_000 25 | 26 | ppl_eval_data: null 27 | 28 | optimizer: 29 | learning_rate: 6e-5 30 | 31 | eval: 32 | tasks: [piqa,hellaswag,arc_easy] 33 | lengths: [128, 256, 512, 1024, 2048] 34 | tokens_per_batch: 8192 35 | add_bos: true 36 | chat_template_mode: direct_encode 37 | confirm_run_unsafe_code: true 38 | tokenizers: 39 | - tokenizer: openai-community/gpt2:source=GPT2:target=Gemma2 40 | name: gpt2 41 | - tokenizer: mistralai/Mistral-Small-3.1-24B-Base-2503:source=Mistral:target=Gemma2 42 | name: mistral 43 | - tokenizer: meta-llama/Llama-3.2-3B:source=Llama3:target=Gemma2 44 | name: llama3 45 | 46 | data: 47 | batch_size: 128 48 | kind: "hf_saved" 49 | lang_code: en 50 | dataset_configs: 51 | - path: "/lfs/dolmino_50B_medium" 52 | 53 | # TODO: disentangle data/collator args 54 | collator: 55 | do_tokenizer_sampling: true 56 | sample_text_span: true 57 | n_pools: 1 58 | add_prefix_space: false 59 | hn_surface_maxlen: 8 60 | n_token_subsample: null 61 | identity_n_token_subsample: 16384 62 | pad_to_multiple_of: 128 63 | tokenizer_sample_max: 32768 64 | tokenizer_sample_mean: 32768 65 | tokenizer_sample_min: 32768 66 | tokenizer_sample_std: 0 67 | tokenizer_batch_size: 2048 68 | tokenizer_noise_std: 4 69 | tokenizer_noise_mean: 1.e-5 70 | block_size: 128 71 | 72 | hypernet: 73 | architecture: transformer 74 | residual_alpha: 1 75 | residual: true 76 | use_attention: true 77 | num_layers: 3 78 | shared: true 79 
| num_heads: 16 80 | use_attention_mask: false 81 | multiply_hidden_dim_by_num_embeddings: false 82 | 83 | optimizer: 84 | type: adamw 85 | weight_decay: 0.01 86 | b1: 0.9 87 | b2: 0.95 88 | eps: 1.e-8 89 | grad_acc_steps: null 90 | learning_rate: 1.e-5 91 | max_grad_norm: 1.0 92 | param_groups: 93 | - pattern: .*(projector_query|projector_s2t|projector_t2s|projector_latents|loss_weights).* 94 | lr_scale: 2 95 | 96 | model: 97 | pretrained_model_name_or_path: "benjamin/gemma-2-2b-flax" 98 | tokenizer_name: "google/gemma-2-2b:source=Gemma2" -------------------------------------------------------------------------------- /configs/zett.yaml: -------------------------------------------------------------------------------- 1 | source_model: 2 | pretrained_model_name_or_path: benjamin/gemma-2-2b-it-flax 3 | tokenizer_name: google/gemma-2-2b-it:source=Gemma2 4 | 5 | target_tokenizer_name: Qwen/Qwen2.5-1.5B:source=Qwen2:target=Gemma2 6 | output: outputs/zett/gemma-2-2b-it-flax-fvt-qwen -------------------------------------------------------------------------------- /docs/byteification.md: -------------------------------------------------------------------------------- 1 | # Byteification: A Unified Interface to Tokenizers 2 | 3 | In this guide, we'll take a look at how `tokenkit` interacts with tokenizers: `tokenkit` uses a unified byte-level interface to tokenizers to prevent issues stemming from tokenizers using different encoding schemes. For example, let's say we want to compute the number of overlapping tokens between the Gemma2 and Llama3 tokenizers. Here is the naive approach: 4 | 5 | ```python 6 | from transformers import AutoTokenizer 7 | 8 | tok1 = AutoTokenizer.from_pretrained("google/gemma-2-2b-it") 9 | tok2 = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-3B-Instruct") 10 | 11 | n_overlap = len(set(tok1.get_vocab().keys()) & set(tok2.get_vocab().keys())) 12 | # 25632 - this is suspiciously low! 13 | ``` 14 | 15 | The two tokenizers use a different encoding, so even if two tokens encode the same UTF-8 bytes, they might look different! 16 | 17 | ```python 18 | tok1.tokenize(" Café") # ['▁Café'] 19 | tok2.tokenize(" Café") # ['ĠCafé'] 20 | ``` 21 | 22 | We can fix this by instead using the `tokenkit.byteify.ByteifyTokenizer` interface. 'Byteification' preserves the tokenizer's functionality, while providing a unified (byte-level) encoding: 23 | 24 | ```python 25 | from tokenkit.byteify import load_byteify_tokenizer 26 | 27 | tok1 = load_byteify_tokenizer("google/gemma-2-2b-it:source=Gemma2") 28 | tok2 = load_byteify_tokenizer("meta-llama/Llama-3.2-3B-Instruct:source=Llama3") 29 | 30 | n_overlap = len(set(tok1.get_vocab().keys()) & set(tok2.get_vocab().keys())) 31 | # 85699 - this is much more reasonable! 32 | 33 | tok1.tokenize(" Café") # ['ĠCafé'] 34 | tok2.tokenize(" Café") # ['ĠCafé'] 35 | ``` 36 | 37 | This always 100% preserves the tokenizer functionality (e.g., which tokens any text is encoded as). The API mostly matches the HuggingFace tokenizers API (e.g., `convert_ids_to_tokens`, `convert_tokens_to_ids`, `get_vocab`, `tokenize`, `add_tokens`) but is not exactly the same. 38 | 39 | This allows us to compute things like lexical overlap and token sequence alignments accurately. `tokenkit` also implements an exact alignment algorithm between tokenizers, including tokenizers with different special tokens (e.g., different chat templates).
40 | 41 | ```python 42 | from tokenkit.byteify import load_byteify_tokenizer 43 | from tokenkit import align 44 | 45 | tok1 = load_byteify_tokenizer("google/gemma-2-2b-it:source=Gemma2") 46 | tok2 = load_byteify_tokenizer("meta-llama/Llama-3.2-3B-Instruct:source=Llama3") 47 | 48 | # Gemma2 chat template 49 | tokens1 = tok1.tokenize("<bos><start_of_turn>user\nWhat's ultracrepidarianism?<end_of_turn>\n") 50 | # Llama3 chat template 51 | tokens2 = tok2.tokenize("<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\nWhat's ultracrepidarianism?<|eot_id|>") 52 | 53 | alignment_indices = align.get_alignment_indices(tokens1, tokens2, tok1, tok2)[0] 54 | 55 | for (start1, end1, start2, end2) in alignment_indices: 56 | print(tokens1[start1:end1], tokens2[start2:end2]) 57 | 58 | # ['<bos>'] ['<|begin_of_text|>'] 59 | # ['<start_of_turn>'] ['<|start_header_id|>'] 60 | # ['user'] ['user'] 61 | # ['Ċ'] ['<|end_header_id|>', 'ĊĊ'] 62 | # ['What'] ['What'] 63 | # ["'", 's'] ["'s"] 64 | # ['Ġultra', 'cre', 'pid'] ['Ġultr', 'ac', 'repid'] 65 | # ['arian'] ['arian'] 66 | # ['ism'] ['ism'] 67 | # ['?'] ['?'] 68 | # ['<end_of_turn>', 'Ċ'] ['<|eot_id|>'] 69 | ``` 70 | 71 | ## Tokenizer Specs 72 | 73 | As you've seen above, `tokenkit` uses a colon-separated *tokenizer spec* to load tokenizers. This gives us a simple way to specify additional arguments and modifications to the tokenizer. 74 | 75 | - The `source=` argument (the only required argument) ensures we set special tokens and the chat template correctly for the given model family. See [tokenkit/model_kinds.py](../tokenkit/model_kinds.py) for supported model families or to add new model families. 76 | - The optional `target=` argument enables updating the tokenizer special tokens / chat template to a different model family. E.g. `google/gemma-2-2b-it:source=Gemma2:target=Qwen2` would tokenize all regular text equivalent to the Gemma2 tokenizer, but use the Qwen2 chat template and special tokens. Since Qwen2 does not use a `<bos>` token, it would thus also not use a `<bos>` token. 77 | - The optional `conversion=` argument enables conversion to a different encoding scheme. `conversion=byte` is the one you are most likely to encounter. This converts the tokenizer to tokenize all regular (non-special-token) bytes as individual tokens, i.e., to byte-level tokenization (*this is different and unrelated to byteification!*). Special tokens are kept as-is. For example: 78 | 79 | ```python 80 | from tokenkit.byteify import load_byteify_tokenizer 81 | 82 | tok = load_byteify_tokenizer("google/gemma-2-2b-it:source=Gemma2:conversion=byte") 83 | 84 | tok.tokenize("Hello, world!") # ['<bos>', 'H', 'e', 'l', 'l', 'o', ',', 'Ġ', 'w', 'o', 'r', 'l', 'd', '!'] 85 | print(len(tok)) # 256 + some special tokens 86 | ``` 87 | 88 | --- 89 |

Next: [Implementing ALM From Scratch in PyTorch](./pytorch_alm_from_scratch.ipynb)

90 | -------------------------------------------------------------------------------- /docs/tokenizer_transfer.md: -------------------------------------------------------------------------------- 1 | # Tokenizer Transfer via tokenkit 2 | 3 | This guide will walk you through the process of transferring a pretrained model to a new tokenizer using tokenkit. 4 | 5 | First, follow the installation instructions in the [README](../README.md). 6 | 7 | Then, the scripts in `examples/` provide a starting point for transferring a model to a new tokenizer. For example: 8 | 9 | ```bash 10 | bash examples/llama3_to_qwen2_tokenizer_gpu.sh 11 | # or on TPU: examples/llama3_to_qwen2_tokenizer_tpu.sh 12 | ``` 13 | 14 | This will distill the [Llama3.2-3B-Instruct](https://huggingface.co/meta-llama/Llama-3.2-3B-Instruct) model to the [Qwen2.5-1.5B](https://huggingface.co/Qwen/Qwen2.5-1.5B-Instruct) tokenizer. Let's have a look at what it runs: 15 | 16 | ```bash 17 | # examples/llama3_to_qwen2_tokenizer_gpu.sh 18 | NAME=llama3_to_qwen2_tokenizer 19 | python3 scripts/cross_tokenizer_distill.py \ 20 | --config=configs/cross_tokenizer_distill.yaml \ 21 | --overrides \ 22 | losses=[sft,alm_unconstrained] \ 23 | alm_mode=merge_by_space_prob+append_space \ 24 | tokenizer_pair_bias_threshold=0.1 \ 25 | train_model_mode=lora \ 26 | model_lora_rank=64 \ 27 | model_lora_alpha=64 \ 28 | n_data_parallel=1 \ 29 | n_model_parallel=1 \ 30 | steps=5000 \ 31 | eval_interval=1000 \ 32 | save_interval=1000 \ 33 | data.batch_size=64 \ 34 | optimizer.grad_acc_steps=4 \ 35 | data.num_workers=16 \ 36 | student.pretrained_model_name_or_path=benjamin/Llama-3.2-3B-Instruct-flax \ 37 | student.tokenizer_name=\'meta-llama/Llama-3.2-3B-Instruct:source=Llama3\' \ 38 | target_tokenizer_name=\'Qwen/Qwen2.5-1.5B:source=Qwen2:target=Llama3\' \ 39 | name=$NAME 40 | ``` 41 | 42 | Default arguments are taken from [`configs/cross_tokenizer_distill.yaml`](../configs/cross_tokenizer_distill.yaml). You can keep many of these as-is. A notable parameter which we don't override here is the dataset: we use the [Tulu3 instruction-tuning dataset](https://huggingface.co/datasets/allenai/tulu-3-sft-mixture). This is a good choice for transfer of chat / instruction-following models. You can update this to fit your use case by modifying the `data` section in [`configs/cross_tokenizer_distill.yaml`](../configs/cross_tokenizer_distill.yaml). 43 | 44 | Let's go over the overridden arguments in more detail: 45 | 46 | ``` 47 | losses=[sft,alm_unconstrained] \ 48 | alm_mode=merge_by_space_prob+append_space \ 49 | tokenizer_pair_bias_threshold=0.1 \ 50 | ``` 51 | 52 | These arguments configure the losses to optimize and the ALM mode to use. The above configuration should be the best for many cases. Importantly, *it is different to what is described in the [ALM paper](https://arxiv.org/abs/2503.20083)*. In particular, it achieves equivalent or better results without precomputation. A more detailed description is forthcoming in an updated version of our paper. 53 | 54 | ``` 55 | hypernet.architecture=identity \ 56 | ``` 57 | 58 | By default, `tokenkit` uses a one-layer embedding projector network (`hypernet.architecture=transformer`). This improves performance but can be memory-intensive, so we disable it here. 59 | 60 | ``` 61 | multitask_aggregation_fn=approx_gradmag_preserve_mag \ 62 | ``` 63 | 64 | Use the last-layer gradient magnitudes to approximately reweigh the multiple objectives (in this case, SFT and ALM) to contribute equally to the final loss gradients.
This adds a little extra overhead since we need to backpropagate through the last layer separately for every objective, but it removes the requirement to manually tune loss weights. If you observe that this adds too much overhead, you can skip it and manually tune the loss weights using e.g. `loss_weights=[1.,0.5]`, or leave it out completely to use uniform loss weights (instead of uniform loss *gradient* weights). 65 | 66 | ``` 67 | train_model_mode=lora \ 68 | model_lora_rank=64 \ 69 | model_lora_alpha=64 \ 70 | ``` 71 | 72 | Train the model using LoRA with rank = alpha = 64. `tokenkit` applies LoRA to the QKV projections, attention output projection, as well as the MLP up-, down- and gate projections (see [lora.py](../tokenkit/models/lora.py)). 73 | 74 | You can use `train_model_mode=full` to train the full model instead. However, in this case, we need to store a separate copy of the model parameters for the student and the teacher, whereas with LoRA, we can use a single model parameter copy and materialize / dematerialize the LoRA parameters as needed. Storing a separate teacher model copy makes training substantially more memory intensive. A rule of thumb is: for transfer to a similar kind of tokenizer (e.g., another subword tokenizer), LoRA is sufficient. For transfer to a very different tokenizer (e.g., a byte-level tokenizer), full-finetuning helps. 75 | 76 | ``` 77 | n_data_parallel=1 \ 78 | n_model_parallel=1 \ 79 | ``` 80 | 81 | Data and model parallelism. Set this such that the product of the two is the number of GPUs or TPU cores you have available. Often (especially for larger models) you will want to increase model parallelism and keep data parallelism at 1. 82 | 83 | ``` 84 | steps=5000 \ 85 | eval_interval=1000 \ 86 | save_interval=1000 \ 87 | data.batch_size=64 \ 88 | optimizer.grad_acc_steps=4 \ 89 | data.num_workers=16 \ 90 | ``` 91 | 92 | Train for 5000 steps, evaluate every 1000 steps, save the model every 1000 steps at a global batch size of 64 with 4 gradient accumulation steps (i.e., a local batch size of 16). Evaluation is done via (a fork of) [`lm-evaluation-harness`](https://github.com/bminixhofer/lm-evaluation-harness) and runs the tasks configured via `eval.tasks`. 93 | 94 | ``` 95 | student.pretrained_model_name_or_path=benjamin/Llama-3.2-3B-Instruct-flax \ 96 | student.tokenizer_name=\'meta-llama/Llama-3.2-3B-Instruct:source=Llama3\' \ 97 | target_tokenizer_name=\'Qwen/Qwen2.5-1.5B:source=Qwen2:target=Llama3\' \ 98 | ``` 99 | 100 | The (local or HF hub) paths to the model to transfer. If we do not specify a separate teacher, the teacher will be the student with the original tokenizer (this is what we want for tokenizer transfer). Notably: 101 | 102 | - The model is `benjamin/Llama-3.2-3B-Instruct-flax` since the original `meta-llama/Llama-3.2-3B-Instruct` model is not in Flax format. You can convert supported models to Flax using the `scripts/push_flax_version_to_hub.py` script. 103 | - The tokenizer is specified using a tokenizer spec which differs from the HuggingFace `AutoTokenizer` format by including additional colon-separated tags. For example: `Qwen/Qwen2.5-1.5B:source=Qwen2:target=Llama3` specifies the Qwen2.5-1B-Instruct tokenizer initially stemming from the Qwen2 model family (`source=Qwen2`) updated to use the special tokens of the Llama3 family instead (`target=Llama3`). See the [byteification](./byteification.md) guide for more details on the interface tokenkit provides to use HuggingFace tokenizers. 
For our purposes in this guide, the important point is that when you transfer across tokenizers, you can choose to either (i) preserve the original special tokens (safer but potentially inconvenient) or (ii) use the special tokens from the new tokenizer (less safe but potentially more convenient). More on this below in [To Keep or to Change the Special Tokens?](#to-keep-or-to-change-the-special-tokens). 104 | 105 | ``` 106 | name=$NAME 107 | ``` 108 | 109 | The name under which to track the experiment. By default, `tokenkit` uses [Weights & Biases](https://www.wandb.ai/) to track experiments. 110 | 111 | That's all! You can now transfer your first model. 112 | 113 | ## Transfer to Bytes 114 | 115 | We need to make a couple of adjustments to enable effective transfer to byte-level tokenizers. Let's compare the example config in [`examples/llama3_to_byte_tokenizer_gpu.sh`](../examples/llama3_to_byte_tokenizer_gpu.sh) to the config we used above: 116 | 117 | ```diff 118 | - losses=[sft,alm_unconstrained] \ 119 | + losses=[sft,alm_unconstrained,alm_latents] \ 120 | ``` 121 | 122 | For transfer to bytes, the ALM latent (hidden-state alignment) objective substantially improves performance. 123 | 124 | ```diff 125 | - train_model_mode=lora \ 126 | - model_lora_rank=64 \ 127 | - model_lora_alpha=64 \ 128 | + train_model_mode=full \ 129 | + expand_input_ids=true \ 130 | + output_embeddings_mode=untie \ 131 | ``` 132 | 133 | We train the full model to give it more capacity to adapt to the fundamental change in tokenization. We also *expand* the input IDs to inject some extra parameters while preserving total FLOPs. Input ID expansion works as follows: for every byte embedding, we add the subword embedding of the longest matching subword ending at this byte position (where the subwords and subword embeddings are taken from the original tokenizer and embedding matrix). Finally, we untie the byte input and output embeddings, since there is no reason to tie them in the byte-level case (tying does not save a considerable number of parameters). This may also marginally improve performance. 134 | 135 | ```diff 136 | - target_tokenizer_name=\'Qwen/Qwen2.5-1.5B:source=Qwen2:target=Llama3\' \ 137 | + target_tokenizer_name=\'meta-llama/Llama-3.2-3B-Instruct:source=Llama3:conversion=byte\' 138 | ``` 139 | 140 | The target tokenizer is now specified using a tokenizer spec which includes the conversion to bytes. See the [byteification](./byteification.md#tokenizer-spec) guide for details on this spec. 141 | 142 | ## Exporting the Model 143 | 144 | `tokenkit` uses a custom internal format to checkpoint model fine-tuning parameter diffs.
To export a checkpoint to the HuggingFace format, run e.g.: 145 | 146 | ```bash 147 | # --with_pt exports the model in PyTorch format (in addition to Flax) 148 | python3 scripts/export_checkpoint.py \ 149 | --checkpoint=outputs/cross_tokenizer_distill/step_5000 \ 150 | --output=checkpoints/llama3_to_qwen2_tokenizer_hf \ 151 | --with_pt 152 | ``` 153 | 154 | If you are exporting a model that has been trained with input ID expansion, you also need to specify which embeddings and tokenizer to use for expansion, e.g.: 155 | 156 | ```bash 157 | python3 scripts/export_checkpoint.py \ 158 | --checkpoint=outputs/cross_tokenizer_distill/step_5000 \ 159 | --output=checkpoints/llama3_to_bytes \ 160 | --with_pt \ 161 | --expand_input_ids_model=benjamin/Llama-3.2-3B-Instruct-flax \ 162 | --expand_input_ids_tokenizer=meta-llama/Llama-3.2-3B-Instruct:source=Llama3 163 | ``` 164 | 165 | Afterwards, you can load the model as usual using HuggingFace transformers: 166 | 167 | ```python 168 | from transformers import AutoModelForCausalLM 169 | from tokenkit.byteify import load_byteify_tokenizer 170 | 171 | model = AutoModelForCausalLM.from_pretrained("checkpoints/llama3_to_bytes", trust_remote_code=True) 172 | tokenizer = load_byteify_tokenizer("meta-llama/Llama-3.2-3B-Instruct:source=Llama3:conversion=byte") 173 | 174 | tokens = tokenizer.tokenizer.apply_chat_template([{"role": "user", "content": "Hello, how are you?"}], return_tensors="pt") 175 | output = model.generate(tokens) 176 | print(tokenizer.decode(output[0])) 177 | ``` 178 | 179 | ## To Keep or to Change the Special Tokens? 180 | 181 | In the example above, by using the target tokenizer spec `Qwen/Qwen2.5-1.5B:source=Qwen2:target=Llama3`, we transferred to a tokenizer that uses all the *regular* tokens of the Qwen2 tokenizer but keeps the special tokens (and the chat template) of the Llama3 tokenizer. We can instead transfer to a tokenizer which is completely equivalent to the Qwen2 tokenizer (regular and special tokens) by specifying it as `Qwen/Qwen2.5-1.5B:source=Qwen2:target=Qwen2`. What to choose depends on your use case: 182 | 183 | - *Keeping the special tokens:* This is the safer choice, since the model will not have to learn to use a new chat template format with new special tokens. If you just want to, for example, transfer to a new tokenizer which encodes some domain more efficiently, this is the better choice. 184 | - *Changing the special tokens:* If you are using tokenizer transfer to combine (e.g., ensemble) multiple models, this is more convenient since we don't need to worry about aligning the different special tokens and chat templates to each other (which is quite easy to do, but still inconvenient). However, there are some things to be careful about: for example, transferring Gemma2 to the Llama3 chat template is quite easy since both use similar formats and both use a BOS token. However, transferring Gemma2 to the Qwen2 chat template is not as straightforward *since Gemma2 uses a BOS token, but Qwen2 doesn't*. The model thus has to learn to re-distribute the original attention sink behavior of the BOS token across other tokens. This may or may not work well, depending on the training budget, dataset, and so on. The short sketch at the end of this guide shows how to compare what the two tokenizer specs produce for the same chat conversation. 185 | 186 | --- 187 |
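As a quick way to see the difference between the two target tokenizer specs discussed above, the minimal sketch below loads both byteified tokenizers and prints how each one formats the same chat turn. It is an illustrative sketch, not part of the example scripts: it assumes `tokenkit` is installed and that the Qwen2.5-1.5B tokenizer can be downloaded from the Hugging Face Hub, and the exact strings printed depend on the respective chat templates.

```python
from tokenkit.byteify import load_byteify_tokenizer

# Qwen2 regular tokens, but Llama3 special tokens and chat template (used in this guide)
tok_keep = load_byteify_tokenizer("Qwen/Qwen2.5-1.5B:source=Qwen2:target=Llama3")
# Fully equivalent to the Qwen2 tokenizer (regular and special tokens)
tok_change = load_byteify_tokenizer("Qwen/Qwen2.5-1.5B:source=Qwen2:target=Qwen2")

messages = [{"role": "user", "content": "Hello, how are you?"}]

# The underlying HuggingFace tokenizer is exposed as `.tokenizer` (as in the loading
# example earlier in this guide), so we can compare the formatted prompts directly.
print(tok_keep.tokenizer.apply_chat_template(messages, tokenize=False))
print(tok_change.tokenizer.apply_chat_template(messages, tokenize=False))
```

The first call should produce a Llama3-style prompt, the second a Qwen2-style prompt; this formatting difference is exactly what is at stake when deciding whether to keep or change the special tokens.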

Next: [Byteification: A Unified Interface to Tokenizers](./byteification.md)

-------------------------------------------------------------------------------- /examples/gemma2_distill_from_openmath2-llama_gpu.sh: -------------------------------------------------------------------------------- 1 | NAME=gemma2_distill_from_openmath2-llama 2 | python3 scripts/cross_tokenizer_distill.py \ 3 | --config=configs/math_cross_tokenizer_distill.yaml \ 4 | --overrides \ 5 | losses=[sft,alm_unconstrained] \ 6 | alm_mode=merge_by_space_prob+append_space \ 7 | tokenizer_pair_bias_threshold=0.1 \ 8 | max_teacher_length=1024 \ 9 | max_student_length=1024 \ 10 | n_data_parallel=1 \ 11 | n_model_parallel=1 \ 12 | steps=5000 \ 13 | eval_interval=5000 \ 14 | save_interval=5000 \ 15 | optimizer.learning_rate=5.e-6 \ 16 | optimizer.weight_decay=0.0 \ 17 | optimizer.max_grad_norm=null \ 18 | optimizer.grad_acc_steps=4 \ 19 | eval.tasks=[math_500_openmath2,gsm8k_openmath2] \ 20 | eval.lengths=[2048] \ 21 | eval.tokens_per_batch=16384 \ 22 | eval.chat_template_mode=direct_encode_no_force_eos \ 23 | data.batch_size=64 \ 24 | log_interval=10 \ 25 | sync_interval=100 \ 26 | use_chat_template=true \ 27 | chat_template_mode=direct_encode \ 28 | hypernet.architecture=identity \ 29 | train_embeddings=true \ 30 | train_model_mode=full \ 31 | eval_at_step_zero=false \ 32 | save_at_step_zero=false \ 33 | skip_lm_eval=false \ 34 | num_workers=24 \ 35 | name=$NAME 36 | -------------------------------------------------------------------------------- /examples/llama3_to_byte_tokenizer_gpu.sh: -------------------------------------------------------------------------------- 1 | NAME=llama3_to_byte 2 | python3 scripts/cross_tokenizer_distill.py \ 3 | --config=configs/cross_tokenizer_distill.yaml \ 4 | --overrides \ 5 | losses=[sft,alm_unconstrained,alm_latents] \ 6 | alm_mode=merge_by_space_prob+append_space \ 7 | tokenizer_pair_bias_threshold=0.1 \ 8 | hypernet.architecture=identity \ 9 | multitask_aggregation_fn=approx_gradmag_preserve_mag \ 10 | train_model_mode=full \ 11 | expand_input_ids=true \ 12 | output_embeddings_mode=untie \ 13 | n_data_parallel=1 \ 14 | n_model_parallel=1 \ 15 | steps=5000 \ 16 | eval_interval=1000 \ 17 | save_interval=1000 \ 18 | data.batch_size=64 \ 19 | optimizer.grad_acc_steps=4 \ 20 | data.num_workers=16 \ 21 | data.batch_size=64 \ 22 | student.pretrained_model_name_or_path="benjamin/Llama-3.2-3B-Instruct-flax" \ 23 | student.tokenizer_name=\'meta-llama/Llama-3.2-3B-Instruct:source=Llama3\' \ 24 | target_tokenizer_name=\'meta-llama/Llama-3.2-3B-Instruct:source=Llama3:conversion=byte\' \ 25 | num_workers=16 \ 26 | name=$NAME -------------------------------------------------------------------------------- /examples/llama3_to_qwen2_tokenizer_gpu.sh: -------------------------------------------------------------------------------- 1 | NAME=llama3_to_qwen2_tokenizer 2 | python3 scripts/cross_tokenizer_distill.py \ 3 | --config=configs/cross_tokenizer_distill.yaml \ 4 | --overrides \ 5 | losses=[sft,alm_unconstrained] \ 6 | alm_mode=merge_by_space_prob+append_space \ 7 | tokenizer_pair_bias_threshold=0.1 \ 8 | hypernet.architecture=identity \ 9 | multitask_aggregation_fn=approx_gradmag_preserve_mag \ 10 | train_model_mode=lora \ 11 | model_lora_rank=64 \ 12 | model_lora_alpha=64 \ 13 | n_data_parallel=1 \ 14 | n_model_parallel=1 \ 15 | steps=5000 \ 16 | eval_interval=1000 \ 17 | save_interval=1000 \ 18 | data.batch_size=64 \ 19 | optimizer.grad_acc_steps=4 \ 20 | data.num_workers=16 \ 21 | data.batch_size=64 \ 22 | 
student.pretrained_model_name_or_path="benjamin/Llama-3.2-3B-Instruct-flax" \ 23 | student.tokenizer_name=\'meta-llama/Llama-3.2-3B-Instruct:source=Llama3\' \ 24 | target_tokenizer_name=\'Qwen/Qwen2.5-1.5B:source=Qwen2:target=Llama3\' \ 25 | num_workers=16 \ 26 | name=$NAME -------------------------------------------------------------------------------- /examples/llama3_to_qwen2_tokenizer_tpu.sh: -------------------------------------------------------------------------------- 1 | # tpuv3-8 2 | N_DATA_PARALLEL=1 3 | N_MODEL_PARALLEL=8 4 | 5 | # tpuv4-8 6 | # N_DATA_PARALLEL=1 7 | # N_MODEL_PARALLEL=4 8 | 9 | # tpuv4-32 10 | # N_DATA_PARALLEL=4 11 | # N_MODEL_PARALLEL=4 12 | 13 | 14 | NAME=llama3_to_qwen2_tokenizer 15 | python3 scripts/cross_tokenizer_distill.py \ 16 | --config=configs/cross_tokenizer_distill.yaml \ 17 | --overrides \ 18 | losses=[sft,alm_unconstrained] \ 19 | alm_mode=merge_by_space_prob+append_space \ 20 | tokenizer_pair_bias_threshold=0.1 \ 21 | hypernet.architecture=identity \ 22 | multitask_aggregation_fn=approx_gradmag_preserve_mag \ 23 | train_model_mode=lora \ 24 | model_lora_rank=64 \ 25 | model_lora_alpha=64 \ 26 | n_data_parallel=$N_DATA_PARALLEL \ 27 | n_model_parallel=$N_MODEL_PARALLEL \ 28 | steps=5000 \ 29 | eval_interval=1000 \ 30 | save_interval=1000 \ 31 | data.batch_size=64 \ 32 | optimizer.grad_acc_steps=4 \ 33 | data.num_workers=16 \ 34 | data.batch_size=64 \ 35 | student.pretrained_model_name_or_path="benjamin/Llama-3.2-3B-Instruct-flax" \ 36 | student.tokenizer_name=\'meta-llama/Llama-3.2-3B-Instruct:source=Llama3\' \ 37 | target_tokenizer_name=\'Qwen/Qwen2.5-1.5B:source=Qwen2:target=Llama3\' \ 38 | num_workers=16 \ 39 | name=$NAME -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.black] 2 | line-length = 88 3 | target-version = ['py37'] 4 | include = '\.pyi?$' 5 | extend-exclude = ''' 6 | # A regex preceded with ^/ will apply only to files and directories 7 | # in the root of the project. 
8 | ^/build/ 9 | ''' 10 | 11 | [tool.isort] 12 | profile = "black" 13 | multi_line_output = 3 14 | include_trailing_comma = true 15 | force_grid_wrap = 0 16 | use_parentheses = true 17 | ensure_newline_before_comments = true 18 | line_length = 88 -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | hydra-core==1.3.2 2 | omegaconf==2.3.0 3 | tokenizers==0.20.3 4 | datasets==3.2.0 5 | scipy==1.14.1 6 | google-cloud-storage==2.19.0 7 | ipython==8.28.0 8 | ipdb==0.13.13 9 | flax==0.10.2 10 | torchdata==0.8.0 11 | wandb==0.19.1 12 | pytest==8.3.5 13 | lingvo==0.12.7 14 | optax==0.2.4 15 | numpy==1.26.4 16 | fiddle==0.3.0 17 | jaxtyping==0.2.36 18 | typeguard==4.4.1 19 | jax-bitempered-loss==0.0.2 20 | editdistance==0.8.1 21 | langdetect==1.0.9 22 | immutabledict==4.2.1 23 | maturin==1.8.3 24 | transformers==4.46.0 25 | git+https://github.com/google/CommonLoopUtils@307b0bc65ae2e2d801b6c19df2431c060a0aa4ec -------------------------------------------------------------------------------- /rust_utils/.gitignore: -------------------------------------------------------------------------------- 1 | target -------------------------------------------------------------------------------- /rust_utils/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "rust_utils" 3 | version = "0.14.1-dev.0" 4 | edition = "2021" 5 | 6 | [lib] 7 | name = "rust_utils" 8 | crate-type = ["cdylib"] 9 | 10 | [dependencies] 11 | pyo3 = { version = "0.19", features = ["serde"]} 12 | onig = { version = "6.0", default-features = false } 13 | rand = "0.8" 14 | rand_distr = "0.4.3" 15 | 16 | [dependencies.tokenizers] 17 | version = "0.20.3" 18 | default-features = false 19 | features = ["onig"] 20 | 21 | [dev-dependencies] 22 | tempfile = "3.1" 23 | pyo3 = { version = "0.19", features = ["auto-initialize"] } 24 | 25 | [features] 26 | defaut = ["pyo3/extension-module"] 27 | -------------------------------------------------------------------------------- /rust_utils/pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["maturin>=1.4,<2.0"] 3 | build-backend = "maturin" 4 | 5 | [project] 6 | name = "rust_utils" 7 | requires-python = ">=3.8" 8 | classifiers = [ 9 | "Programming Language :: Rust", 10 | "Programming Language :: Python :: Implementation :: CPython", 11 | "Programming Language :: Python :: Implementation :: PyPy", 12 | ] 13 | dynamic = ["version"] 14 | 15 | [tool.maturin] 16 | features = ["pyo3/extension-module"] 17 | -------------------------------------------------------------------------------- /rust_utils/rust-toolchain: -------------------------------------------------------------------------------- 1 | stable 2 | -------------------------------------------------------------------------------- /rust_utils/tpu_build.sh: -------------------------------------------------------------------------------- 1 | # otherwise use 'maturin develop --release' 2 | set -e 3 | 4 | . 
../tokenkit_env/bin/activate 5 | python3 -m maturin build --release 6 | mv $PWD/target/wheels/rust_utils-0.14.1.dev0-cp310-cp310-manylinux_2_34_x86_64.whl $PWD/target/wheels/rust_utils-0.14.1.dev0-cp310-none-any.whl 7 | pip install --force-reinstall $PWD/target/wheels/rust_utils-0.14.1.dev0-cp310-none-any.whl 8 | #pip install fsspec==2023.9.2 9 | #pip install --upgrade huggingface_hub datasets 10 | # TODO: do we need the above? if yes, pin versions / check why -------------------------------------------------------------------------------- /scripts/compute_mined_mapping.py: -------------------------------------------------------------------------------- 1 | import json 2 | from pathlib import Path 3 | 4 | import numpy as np 5 | from dataclasses import dataclass 6 | 7 | from tokenkit import baseline_utils 8 | from tokenkit.byteify import load_byteify_tokenizer 9 | from tokenkit import parse_args 10 | 11 | @dataclass 12 | class ComputedMinedMappingArgs: 13 | teacher_tokenizer_name: str 14 | target_tokenizer_name: str 15 | output: str 16 | num_workers: int 17 | 18 | def main(args: ComputedMinedMappingArgs) -> None: 19 | output_dir = Path(args.output) 20 | output_dir.mkdir(exist_ok=True, parents=True) 21 | 22 | tokenizer_teacher = load_byteify_tokenizer(args.teacher_tokenizer_name) 23 | target_tokenizer = load_byteify_tokenizer(args.target_tokenizer_name) 24 | 25 | mined_mapping, mined_distances = baseline_utils.compute_mined_mapping( 26 | tokenizer_teacher, target_tokenizer, num_workers=args.num_workers 27 | ) 28 | 29 | np.save(output_dir / "mined_mapping.npy", mined_mapping) 30 | json.dump( 31 | mined_distances, 32 | open(output_dir / "mined_distances.json", "w"), 33 | indent=4, 34 | ) 35 | 36 | 37 | if __name__ == "__main__": 38 | main(parse_args.parse_args(ComputedMinedMappingArgs)) 39 | -------------------------------------------------------------------------------- /scripts/compute_tokenizer_info.py: -------------------------------------------------------------------------------- 1 | """ 2 | Example Usage: 3 | 4 | ipython --pdb scripts/compute_tokenizer_info.py -- \ 5 | teacher_tokenizer_name=google/gemma-2-2b-it:source=Gemma2 \ 6 | target_tokenizer_name=Qwen/Qwen2.5-1.5B:source=Qwen2:target=Gemma2 \ 7 | output='outputs/tokenizer_data/gemma2_to_qwen2_new' 8 | """ 9 | 10 | import json 11 | import os 12 | import pickle 13 | from collections import Counter 14 | from functools import partial 15 | from pathlib import Path 16 | 17 | import hydra 18 | import numpy as np 19 | from scipy import sparse 20 | from torch.utils.data import DataLoader 21 | from tqdm.auto import tqdm 22 | 23 | from tokenkit import data 24 | from tokenkit.byteify import load_byteify_tokenizer 25 | 26 | 27 | def compute_prefix_map(tokenizer): 28 | prefix_map = {} 29 | 30 | for token in tokenizer.get_vocab().keys(): 31 | for i in range(1, len(token) + 1): 32 | if token[:i] in prefix_map: 33 | prefix_map[token[:i]].append(token) 34 | else: 35 | prefix_map[token[:i]] = [token] 36 | 37 | return prefix_map 38 | 39 | 40 | def is_valid(tokens, tokenizer): 41 | try: 42 | return tokenizer.backend_tokenize("".join(tokens)) == tokens 43 | except UnicodeDecodeError: 44 | return False 45 | 46 | 47 | def compute_cover_set(pretoken, tokenizer, prefix_map): 48 | cover_set = [] 49 | for i in range(len(pretoken) - 1, -1, -1): 50 | B = prefix_map.get(pretoken[i:], []) 51 | try: 52 | tcur = tokenizer.backend_tokenize(pretoken[:i]) 53 | except UnicodeDecodeError: 54 | continue 55 | 56 | for b in B: 57 | if is_valid(tcur + [b], tokenizer): 58 
| cover_set.append(tcur + [b]) 59 | 60 | return cover_set 61 | 62 | 63 | def compute_cover_dict(pretoken, tokenizer, prefix_map): 64 | cover_set = compute_cover_set(pretoken, tokenizer, prefix_map) 65 | cover_dict = {} 66 | 67 | for seq in cover_set: 68 | joined_seq = "".join(seq)[len(pretoken) :] 69 | if len(joined_seq) == 0: 70 | continue 71 | 72 | cover_dict[joined_seq] = tokenizer.convert_tokens_to_ids(seq) 73 | 74 | return cover_dict 75 | 76 | 77 | def compute_pair_bias( 78 | pretoken1, 79 | pretoken2, 80 | tokenizer1, 81 | tokenizer2, 82 | prefix_map1, 83 | prefix_map2, 84 | probs1, 85 | probs2, 86 | return_diff_cover_dicts=False, 87 | ): 88 | cover_dict1 = compute_cover_dict(pretoken1, tokenizer1, prefix_map1) 89 | cover_dict2 = compute_cover_dict(pretoken2, tokenizer2, prefix_map2) 90 | 91 | diff_keys1 = set(cover_dict1.keys()) - set(cover_dict2.keys()) 92 | diff_keys2 = set(cover_dict2.keys()) - set(cover_dict1.keys()) 93 | 94 | bias1 = 0.0 95 | for key in diff_keys1: 96 | bias1 += probs1[cover_dict1[key][-1]] 97 | 98 | bias2 = 0.0 99 | for key in diff_keys2: 100 | bias2 += probs2[cover_dict2[key][-1]] 101 | 102 | if return_diff_cover_dicts: 103 | diff_cover_set1 = {key: probs1[cover_dict1[key][-1]] for key in diff_keys1} 104 | diff_cover_set2 = {key: probs2[cover_dict2[key][-1]] for key in diff_keys2} 105 | return bias1, bias2, diff_cover_set1, diff_cover_set2 106 | else: 107 | return bias1, bias2 108 | 109 | 110 | def count_tokens_map(examples, tokenizer): 111 | flat_input_ids = [ 112 | input_id 113 | for input_ids in tokenizer(examples["text"], add_special_tokens=False)[ 114 | "input_ids" 115 | ] 116 | for input_id in input_ids 117 | ] 118 | return { 119 | "counter": pickle.dumps(Counter(flat_input_ids)), 120 | } 121 | 122 | 123 | def count_tokens(dset, tokenizer, num_workers): 124 | token_counters_dset = dset.map( 125 | partial(count_tokens_map, tokenizer=tokenizer), 126 | batched=False, # already batched 127 | num_proc=num_workers if num_workers > 0 else None, 128 | remove_columns=dset.column_names, 129 | desc="Counting tokens", 130 | ) 131 | 132 | global_token_counter = Counter() 133 | for i in tqdm(range(len(token_counters_dset)), desc="Merging token counters"): 134 | global_token_counter.update(pickle.loads(token_counters_dset[i]["counter"])) 135 | 136 | return global_token_counter 137 | 138 | 139 | if __name__ == "__main__": 140 | args = None 141 | 142 | os.environ["TOKENIZERS_PARALLELISM"] = "false" 143 | 144 | def _parse_args_fn(local_args): 145 | global args 146 | args = local_args 147 | 148 | hydra.main( 149 | version_base=None, 150 | config_path="../configs", 151 | config_name="compute_tokenizer_info", 152 | )(_parse_args_fn)() 153 | 154 | output_dir = Path(args.output) 155 | output_dir.mkdir(exist_ok=True, parents=True) 156 | 157 | tokenizer_teacher = load_byteify_tokenizer(args.teacher_tokenizer_name) 158 | target_tokenizer = load_byteify_tokenizer(args.target_tokenizer_name) 159 | 160 | dset = data.get_dataset(**args.data, seed=args.seed) 161 | 162 | if not (output_dir / "teacher_counts.json").exists(): 163 | assert isinstance(dset, data.HFDataset) 164 | dset_to_use = dset.stream 165 | 166 | if args.teacher_subsample_percent is not None: 167 | n_subsample = int(len(dset_to_use) * args.teacher_subsample_percent) 168 | dset_to_use = dset_to_use.select(np.arange(n_subsample)) 169 | 170 | teacher_token_counts = count_tokens( 171 | dset_to_use, tokenizer_teacher, args.data.num_workers 172 | ) 173 | json.dump( 174 | teacher_token_counts, 175 | open(output_dir / 
"teacher_counts.json", "w"), 176 | indent=4, 177 | ) 178 | else: 179 | teacher_token_counts = Counter( 180 | json.load(open(output_dir / "teacher_counts.json")) 181 | ) 182 | if not (output_dir / "student_counts.json").exists(): 183 | assert isinstance(dset, data.HFDataset) 184 | dset_to_use = dset.stream 185 | 186 | if args.student_subsample_percent is not None: 187 | n_subsample = int(len(dset_to_use) * args.student_subsample_percent) 188 | dset_to_use = dset_to_use.select(np.arange(n_subsample)) 189 | 190 | student_token_counts = count_tokens( 191 | dset_to_use, target_tokenizer, args.data.num_workers 192 | ) 193 | json.dump( 194 | student_token_counts, 195 | open(output_dir / "student_counts.json", "w"), 196 | indent=4, 197 | ) 198 | else: 199 | student_token_counts = Counter( 200 | json.load(open(output_dir / "student_counts.json")) 201 | ) 202 | 203 | if not (output_dir / "pairs.json").exists(): 204 | teacher_tokens_dict = {} 205 | for token in sorted( 206 | tokenizer_teacher.get_vocab().keys(), key=lambda x: x[::-1] 207 | ): 208 | if token[-1] not in teacher_tokens_dict: 209 | teacher_tokens_dict[token[-1]] = [] 210 | 211 | teacher_tokens_dict[token[-1]].append(token) 212 | 213 | student_tokens_dict = {} 214 | for token in sorted(target_tokenizer.get_vocab().keys(), key=lambda x: x[::-1]): 215 | if token[-1] not in student_tokens_dict: 216 | student_tokens_dict[token[-1]] = [] 217 | 218 | student_tokens_dict[token[-1]].append(token) 219 | 220 | pairs = [] 221 | for last_byte in tqdm( 222 | set(teacher_tokens_dict.keys()) & set(student_tokens_dict.keys()) 223 | ): 224 | for teacher_token in teacher_tokens_dict[last_byte]: 225 | for student_token in student_tokens_dict[last_byte]: 226 | if teacher_token.endswith(student_token) or student_token.endswith( 227 | teacher_token 228 | ): 229 | pairs.append((teacher_token, student_token)) 230 | 231 | json.dump(pairs, open(output_dir / "pairs.json", "w"), indent=4) 232 | else: 233 | pairs = json.load(open(output_dir / "pairs.json")) 234 | 235 | print(f"Found {len(pairs)} pairs") 236 | 237 | prefix_map_teacher = compute_prefix_map(tokenizer_teacher) 238 | prefix_map_student = compute_prefix_map(target_tokenizer) 239 | 240 | teacher_counts_sum = sum(teacher_token_counts.values()) 241 | teacher_token_probs = np.array( 242 | [ 243 | teacher_token_counts[token_id] 244 | + args.additive_smoothing_constant * teacher_counts_sum 245 | for token_id in range(len(tokenizer_teacher)) 246 | ], 247 | dtype=np.float32, 248 | ) 249 | teacher_token_probs /= teacher_token_probs.sum() 250 | 251 | student_counts_sum = sum(student_token_counts.values()) 252 | student_token_probs = np.array( 253 | [ 254 | student_token_counts[token_id] 255 | + args.additive_smoothing_constant * student_counts_sum 256 | for token_id in range(len(target_tokenizer)) 257 | ], 258 | dtype=np.float32, 259 | ) 260 | student_token_probs /= student_token_probs.sum() 261 | 262 | is_space_only_teacher = { 263 | tokenizer_teacher.convert_ids_to_tokens(i): len( 264 | tokenizer_teacher.decode(i).strip() 265 | ) 266 | == 0 267 | for i in range(len(tokenizer_teacher)) 268 | } 269 | is_space_only_student = { 270 | target_tokenizer.convert_ids_to_tokens(i): len( 271 | target_tokenizer.decode(i).strip() 272 | ) 273 | == 0 274 | for i in range(len(target_tokenizer)) 275 | } 276 | 277 | def pair_collate(pairs): 278 | biases1 = [] 279 | biases2 = [] 280 | for pair in pairs: 281 | if is_space_only_teacher[pair[0]] or is_space_only_student[pair[1]]: 282 | biases1.append(1.0) # can take long to compute and 
likely high 283 | biases2.append(1.0) 284 | continue 285 | bias1, bias2 = compute_pair_bias( 286 | *pair, 287 | tokenizer_teacher, 288 | target_tokenizer, 289 | prefix_map_teacher, 290 | prefix_map_student, 291 | teacher_token_probs, 292 | student_token_probs, 293 | ) 294 | biases1.append(bias1) 295 | biases2.append(bias2) 296 | 297 | return { 298 | "biases1": biases1, 299 | "biases2": biases2, 300 | } 301 | 302 | pair_permutation = np.random.permutation(len(pairs)) 303 | inv_pair_permutation = np.argsort(pair_permutation) 304 | 305 | biases1 = [] 306 | biases2 = [] 307 | pair_data_loader = DataLoader( 308 | [pairs[i] for i in pair_permutation], 309 | batch_size=args.data.batch_size, 310 | num_workers=args.data.num_workers, 311 | collate_fn=pair_collate, 312 | ) 313 | 314 | for batch in tqdm(pair_data_loader, desc="Computing pair biases"): 315 | biases1.extend(batch["biases1"]) 316 | biases2.extend(batch["biases2"]) 317 | 318 | biases1 = np.array(biases1)[inv_pair_permutation] 319 | biases2 = np.array(biases2)[inv_pair_permutation] 320 | 321 | bias1_matrix = sparse.coo_matrix( 322 | ( 323 | biases1, 324 | ( 325 | np.array(tokenizer_teacher.convert_tokens_to_ids(x[0] for x in pairs)), 326 | np.array(target_tokenizer.convert_tokens_to_ids(x[1] for x in pairs)), 327 | ), 328 | ), 329 | shape=(len(tokenizer_teacher), len(target_tokenizer)), 330 | ) 331 | bias2_matrix = sparse.coo_matrix( 332 | ( 333 | biases2, 334 | ( 335 | np.array(tokenizer_teacher.convert_tokens_to_ids(x[0] for x in pairs)), 336 | np.array(target_tokenizer.convert_tokens_to_ids(x[1] for x in pairs)), 337 | ), 338 | ), 339 | shape=(len(tokenizer_teacher), len(target_tokenizer)), 340 | ) 341 | 342 | sparse.save_npz(output_dir / "bias1_matrix.npz", bias1_matrix) 343 | sparse.save_npz(output_dir / "bias2_matrix.npz", bias2_matrix) 344 | -------------------------------------------------------------------------------- /scripts/eval.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | from functools import partial 4 | from pathlib import Path 5 | from pprint import pformat, pprint 6 | import yaml 7 | 8 | import datasets 9 | import jax 10 | import jax.numpy as jnp 11 | import numpy as np 12 | from jax.experimental import multihost_utils 13 | from jax.sharding import NamedSharding 14 | from jax.sharding import PartitionSpec as P 15 | from transformers import FlaxAutoModelForCausalLM 16 | from dataclasses import dataclass, asdict 17 | 18 | from tokenkit.hf import get_config 19 | from tokenkit import utils, parse_args 20 | from tokenkit.byteify import load_byteify_tokenizer 21 | from tokenkit.eval import ATOL, evaluate, score 22 | from tokenkit.models import param, sharding 23 | 24 | logger = logging.getLogger(__name__) 25 | 26 | datasets.config.HF_DATASETS_TRUST_REMOTE_CODE = ( 27 | True # careful about this, required for lm_eval 28 | ) 29 | 30 | @dataclass 31 | class EvalScriptArgs: 32 | model: parse_args.ModelArgs 33 | expand_model: parse_args.ModelArgs 34 | eval: parse_args.EvalArgs 35 | output: str | None = None 36 | pad_to_multiple_of: int = 128 37 | use_cpu: bool = False 38 | 39 | 40 | def pad_embeddings(embeddings, tokenizer): 41 | n_embed_diff = len(tokenizer) - len(embeddings) 42 | 43 | embeddings_mean = embeddings.mean(0) 44 | embeddings_std = embeddings.std(0) 45 | 46 | return np.concatenate( 47 | [ 48 | embeddings, 49 | np.random.normal( 50 | size=(n_embed_diff, *embeddings.shape[1:]), 51 | ) 52 | * embeddings_std[None] 53 | + embeddings_mean[None], 54 | 
] 55 | ) 56 | 57 | 58 | def main(args: EvalScriptArgs) -> None: 59 | logger.info(pformat(args)) 60 | 61 | model_kwargs = asdict(args.model) 62 | eval_kwargs = asdict(args.eval) 63 | 64 | if args.use_cpu: 65 | jax.config.update("jax_default_device", jax.devices("cpu")[0]) 66 | mesh = sharding.get_mesh(devices=jax.devices("cpu")) 67 | else: 68 | mesh = sharding.get_mesh() 69 | 70 | if args.output is not None: 71 | output_dir = Path(args.output) 72 | output_dir.mkdir(parents=True, exist_ok=True) 73 | 74 | with open(output_dir / "args.yaml", "w") as f: 75 | yaml.dump(asdict(args), f) 76 | else: 77 | output_dir = None 78 | 79 | tokenizer = load_byteify_tokenizer(model_kwargs.pop("tokenizer_name")) 80 | 81 | config = get_config(**model_kwargs) 82 | config.max_length = eval_kwargs["lengths"][-1] 83 | config.mesh = mesh 84 | 85 | model = FlaxAutoModelForCausalLM.from_config(config, _do_init=False) 86 | 87 | params = param.load_params(**model_kwargs) 88 | 89 | if args.expand_model.pretrained_model_name_or_path is not None: 90 | expand_model_kwargs = asdict(args.expand_model) 91 | 92 | expand_tokenizer = load_byteify_tokenizer( 93 | expand_model_kwargs.pop("tokenizer_name") 94 | ) 95 | expand_config = get_config(**expand_model_kwargs) 96 | expand_vocab = expand_tokenizer.get_vocab() 97 | 98 | expand_input_ids_model_params = param.load_params(**expand_model_kwargs) 99 | expand_input_ids_embeddings = param.get( 100 | expand_input_ids_model_params, 101 | param.get_input_embedding_path(expand_config.model_type), 102 | ) 103 | 104 | n_overflow = expand_input_ids_embeddings.shape[0] % args.pad_to_multiple_of 105 | if n_overflow > 0: 106 | n_pad = args.pad_to_multiple_of - n_overflow 107 | else: 108 | n_pad = 0 109 | 110 | expand_input_ids_embeddings = np.pad( 111 | expand_input_ids_embeddings, 112 | ((0, n_pad), (0, 0)), 113 | mode="constant", 114 | constant_values=0, 115 | ) 116 | else: 117 | expand_tokenizer = None 118 | expand_vocab = None 119 | expand_input_ids_embeddings = None 120 | 121 | input_embeddings = param.get( 122 | params, param.get_input_embedding_path(config.model_type) 123 | ) 124 | input_embeddings = input_embeddings[: len(tokenizer)] 125 | 126 | if len(input_embeddings) < len(tokenizer): 127 | print("Padding input embeddings...") 128 | input_embeddings = pad_embeddings(input_embeddings, tokenizer) 129 | 130 | if not config.tie_word_embeddings: 131 | output_embeddings = param.get( 132 | params, param.get_output_embedding_path(config.model_type) 133 | ) 134 | output_embeddings = output_embeddings[:, : len(tokenizer)] 135 | print("Padding output embeddings...") 136 | output_embeddings = pad_embeddings(output_embeddings.T, tokenizer).T 137 | else: 138 | output_embeddings = None 139 | 140 | n_overflow = input_embeddings.shape[0] % args.pad_to_multiple_of 141 | if n_overflow > 0: 142 | n_pad = args.pad_to_multiple_of - n_overflow 143 | else: 144 | n_pad = 0 145 | 146 | input_embeddings = np.pad( 147 | input_embeddings, 148 | ((0, n_pad), (0, 0)), 149 | mode="constant", 150 | constant_values=0, 151 | ) 152 | if output_embeddings is not None: 153 | output_embeddings = np.pad( 154 | output_embeddings, 155 | ((0, 0), (0, n_pad)), 156 | mode="constant", 157 | constant_values=0, 158 | ) 159 | logit_mask = np.zeros((input_embeddings.shape[0],), dtype=bool) 160 | logit_mask[: model.config.vocab_size] = True 161 | 162 | model.config.vocab_size = input_embeddings.shape[0] 163 | 164 | params = param.put( 165 | params, param.get_input_embedding_path(config.model_type), input_embeddings 166 | ) 167 | if 
output_embeddings is not None: 168 | params = param.put( 169 | params, 170 | param.get_output_embedding_path(config.model_type), 171 | output_embeddings, 172 | ) 173 | 174 | if expand_input_ids_embeddings is not None: 175 | # expects stacked embedding format 176 | params["original_embeddings"] = expand_input_ids_embeddings[:, None, :] 177 | 178 | shard_patterns = sharding.get_shard_patterns(config.model_type) 179 | param_shardings = sharding.get_sharding_fn(shard_patterns, mesh)( 180 | {"params": params} 181 | )["params"] 182 | params = sharding.to_devices(params, param_shardings, dtype=jnp.float32) 183 | 184 | multihost_utils.sync_global_devices("loaded weights") 185 | 186 | jaxlm_kwargs = {"precompile": not args.use_cpu} 187 | 188 | if args.expand_model.pretrained_model_name_or_path is not None: 189 | # TODO: move elsewhere, probably into jaxlm 190 | expand_input_ids_dict = utils.get_expand_input_ids_dict( 191 | tokenizer, 192 | expand_vocab, 193 | ) 194 | 195 | def compute_inputs_embeds(model_params, input_ids, expanded_input_ids): 196 | input_embeddings = param.get( 197 | model_params, param.get_input_embedding_path(config.model_type) 198 | ) 199 | 200 | standard_inputs_embeds = jnp.take( 201 | input_embeddings, 202 | input_ids, 203 | axis=0, 204 | ) 205 | expanded_inputs_embeds = jnp.take( 206 | expand_input_ids_embeddings, 207 | expanded_input_ids, 208 | axis=0, 209 | ) 210 | 211 | inputs_embeds = standard_inputs_embeds + expanded_inputs_embeds 212 | 213 | return inputs_embeds 214 | 215 | @partial( 216 | jax.jit, 217 | static_argnames=("model_fn", "atol"), 218 | in_shardings=( 219 | param_shardings, 220 | NamedSharding(mesh, P()), 221 | NamedSharding(mesh, P()), 222 | NamedSharding(mesh, P()), 223 | NamedSharding(mesh, P()), 224 | NamedSharding(mesh, P()), 225 | NamedSharding(mesh, P()), 226 | ), 227 | out_shardings=(NamedSharding(mesh, P()), NamedSharding(mesh, P())), 228 | ) 229 | def jaxlm_inner_score_fn( 230 | model_fn, 231 | params, 232 | input_ids, 233 | expanded_input_ids, 234 | labels, 235 | suffix_mask, 236 | space_mask, 237 | logit_mask, 238 | atol=ATOL, 239 | ): 240 | inputs_embeds = compute_inputs_embeds( 241 | params, 242 | input_ids, 243 | expanded_input_ids, 244 | ) 245 | return score( 246 | model_fn, 247 | params, 248 | (None, inputs_embeds), 249 | labels=labels, 250 | suffix_mask=suffix_mask, 251 | space_mask=space_mask, 252 | logit_mask=logit_mask, 253 | atol=atol, 254 | ) 255 | 256 | def jaxlm_score_fn(model_fn, params, model_args, *pargs): 257 | (input_ids,) = model_args 258 | 259 | expanded_input_ids = utils.np_expand_input_ids( 260 | input_ids, 261 | expand_input_ids_dict, 262 | ) 263 | 264 | return jaxlm_inner_score_fn( 265 | model_fn, 266 | params, 267 | input_ids, 268 | expanded_input_ids, 269 | *pargs, 270 | ) 271 | 272 | jaxlm_kwargs["expand_input_ids"] = True 273 | jaxlm_kwargs["expand_input_ids_vocab"] = expand_vocab 274 | jaxlm_kwargs["score_fn"] = jaxlm_score_fn 275 | 276 | results, _ = evaluate( 277 | model=model, 278 | config=config, 279 | params=params, 280 | tokenizer=tokenizer, 281 | logit_mask=logit_mask, 282 | output=output_dir, 283 | **eval_kwargs, 284 | jaxlm_kwargs=jaxlm_kwargs, 285 | ) 286 | 287 | if jax.process_index() == 0: 288 | pprint(results) 289 | 290 | 291 | if __name__ == "__main__": 292 | os.environ["HF_ALLOW_CODE_EVAL"] = "1" 293 | main(parse_args.parse_args(EvalScriptArgs)) -------------------------------------------------------------------------------- /scripts/eval_lockstep.py: 
-------------------------------------------------------------------------------- 1 | """ 2 | Example Usage: 3 | 4 | python3 scripts/eval_lockstep.py models=llama_qwen +eval.limit=100 5 | """ 6 | 7 | import logging 8 | from pathlib import Path 9 | from pprint import pformat, pprint 10 | from dataclasses import dataclass, asdict 11 | import os 12 | import yaml 13 | import datasets 14 | import jax 15 | import jax.numpy as jnp 16 | import numpy as np 17 | from jax.experimental import multihost_utils 18 | from transformers import FlaxAutoModelForCausalLM 19 | 20 | from tokenkit import parse_args 21 | from tokenkit.hf import get_config 22 | from tokenkit.byteify import load_byteify_tokenizer 23 | from tokenkit.eval import evaluate_lockstep 24 | from tokenkit.models import param, sharding 25 | 26 | logger = logging.getLogger(__name__) 27 | 28 | datasets.config.HF_DATASETS_TRUST_REMOTE_CODE = ( 29 | True # careful about this, required for lm_eval 30 | ) 31 | 32 | @dataclass 33 | class EvalLockstepScriptArgs: 34 | combine_strategy: str 35 | models: list[parse_args.ModelArgs] 36 | eval: parse_args.EvalArgs 37 | baseline_mined_mapping_paths: list[str] | None = None 38 | output: str | None = None 39 | pad_to_multiple_of: int = 128 40 | use_cpu: bool = False 41 | 42 | 43 | def pad_embeddings(embeddings, tokenizer): 44 | n_embed_diff = len(tokenizer) - len(embeddings) 45 | 46 | embeddings_mean = embeddings.mean(0) 47 | embeddings_std = embeddings.std(0) 48 | 49 | return np.concatenate( 50 | [ 51 | embeddings, 52 | np.random.normal( 53 | size=(n_embed_diff, *embeddings.shape[1:]), 54 | ) 55 | * embeddings_std[None] 56 | + embeddings_mean[None], 57 | ] 58 | ) 59 | 60 | 61 | def main(args: EvalLockstepScriptArgs) -> None: 62 | logger.info(pformat(args)) 63 | 64 | eval_kwargs = asdict(args.eval) 65 | 66 | if args.output is not None: 67 | output_dir = Path(args.output) 68 | output_dir.mkdir(parents=True, exist_ok=True) 69 | 70 | with open(output_dir / "args.yaml", "w") as f: 71 | yaml.dump(asdict(args), f) 72 | else: 73 | output_dir = None 74 | 75 | if args.use_cpu: 76 | jax.config.update("jax_default_device", jax.devices("cpu")[0]) 77 | mesh = sharding.get_mesh(devices=jax.devices("cpu")) 78 | else: 79 | mesh = sharding.get_mesh() 80 | 81 | all_models = [] 82 | all_configs = [] 83 | all_params = [] 84 | all_tokenizers = [] 85 | all_logit_masks = [] 86 | 87 | eval_kwargs.pop("add_bos") 88 | all_add_bos = [] 89 | 90 | for model_idx, model_kwargs in enumerate(args.models): 91 | print("Loading model...") 92 | 93 | config = get_config(model_kwargs["pretrained_model_name_or_path"]) 94 | 95 | config.max_length = eval_kwargs["lengths"][-1] 96 | config.mesh = mesh 97 | 98 | tokenizer = load_byteify_tokenizer(model_kwargs.pop("tokenizer_name")) 99 | 100 | model = FlaxAutoModelForCausalLM.from_config(config, _do_init=False) 101 | params = param.load_params( 102 | pretrained_model_name_or_path=model_kwargs["pretrained_model_name_or_path"] 103 | ) 104 | 105 | input_embeddings = param.get( 106 | params, param.get_input_embedding_path(config.model_type) 107 | ) 108 | 109 | if len(input_embeddings) < len(tokenizer): 110 | print("Padding input embeddings...") 111 | input_embeddings = pad_embeddings(input_embeddings, tokenizer) 112 | 113 | if not config.tie_word_embeddings: 114 | output_embeddings = param.get( 115 | params, param.get_output_embedding_path(config.model_type) 116 | ) 117 | print("Padding output embeddings...") 118 | output_embeddings = pad_embeddings(output_embeddings.T, tokenizer).T 119 | else: 120 | 
output_embeddings = None 121 | 122 | n_overflow = input_embeddings.shape[0] % args.pad_to_multiple_of 123 | if n_overflow > 0: 124 | n_pad = args.pad_to_multiple_of - n_overflow 125 | else: 126 | n_pad = 0 127 | 128 | input_embeddings = np.pad( 129 | input_embeddings, 130 | ((0, n_pad), (0, 0)), 131 | mode="constant", 132 | constant_values=0, 133 | ) 134 | if output_embeddings is not None: 135 | output_embeddings = np.pad( 136 | output_embeddings, 137 | ((0, 0), (0, n_pad)), 138 | mode="constant", 139 | constant_values=0, 140 | ) 141 | logit_mask = np.zeros((input_embeddings.shape[0],), dtype=bool) 142 | logit_mask[: model.config.vocab_size] = True 143 | model.config.vocab_size = input_embeddings.shape[0] 144 | 145 | params = param.put( 146 | params, param.get_input_embedding_path(config.model_type), input_embeddings 147 | ) 148 | if output_embeddings is not None: 149 | params = param.put( 150 | params, 151 | param.get_output_embedding_path(config.model_type), 152 | output_embeddings, 153 | ) 154 | 155 | shard_patterns = sharding.get_shard_patterns(config.model_type) 156 | param_shardings = sharding.get_sharding_fn(shard_patterns, mesh)( 157 | {"params": params} 158 | )["params"] 159 | params = sharding.to_devices(params, param_shardings, dtype=jnp.float32) 160 | 161 | multihost_utils.sync_global_devices("loaded weights") 162 | 163 | if args.baseline_mined_mapping_paths is not None: 164 | if args.baseline_mined_mapping_paths[model_idx] is not None: 165 | config.mined_mapping = np.load( 166 | Path(args.baseline_mined_mapping_paths[model_idx]) / "mined_mapping.npy" 167 | ) 168 | else: 169 | config.mined_mapping = None 170 | 171 | all_models.append(model) 172 | all_configs.append(config) 173 | all_params.append(params) 174 | all_tokenizers.append(tokenizer) 175 | all_logit_masks.append(logit_mask) 176 | all_add_bos.append(model_kwargs["add_bos"]) 177 | 178 | # static combine fn for the moment 179 | def combine_fn(hidden_states, logits, combine_params, output_embeddings): 180 | if args.combine_strategy == "mean_prob": 181 | aggregated_probs = None 182 | for model_logits in logits: 183 | model_probs = jax.nn.softmax(model_logits, axis=-1) 184 | if aggregated_probs is None: 185 | aggregated_probs = model_probs 186 | else: 187 | aggregated_probs += model_probs 188 | 189 | aggregated_probs /= len(logits) 190 | return jnp.log(aggregated_probs) 191 | elif args.combine_strategy == "mean_logits": 192 | aggregated_logits = None 193 | for model_logits in logits: 194 | if aggregated_logits is None: 195 | aggregated_logits = model_logits 196 | else: 197 | aggregated_logits += model_logits 198 | 199 | aggregated_logits /= len(logits) 200 | return aggregated_logits 201 | else: 202 | raise ValueError(f"Unknown combine strategy: {args.combine_strategy}") 203 | 204 | results = evaluate_lockstep( 205 | models=all_models, 206 | configs=all_configs, 207 | params=all_params, 208 | tokenizers=all_tokenizers, 209 | logit_masks=all_logit_masks, 210 | add_bos=all_add_bos, 211 | combine_fn=combine_fn, 212 | combine_params={}, 213 | jaxlm_kwargs={"precompile": not args.use_cpu}, 214 | output=output_dir, 215 | **eval_kwargs, 216 | ) 217 | 218 | if jax.process_index() == 0: 219 | pprint(results[0]) 220 | 221 | 222 | if __name__ == "__main__": 223 | os.environ["HF_ALLOW_CODE_EVAL"] = "1" 224 | main(parse_args.parse_args(EvalLockstepScriptArgs)) -------------------------------------------------------------------------------- /scripts/export_checkpoint.py: 
-------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from transformers import ( 3 | HfArgumentParser, 4 | AutoTokenizer, 5 | FlaxAutoModelForCausalLM, 6 | AutoModelForCausalLM, 7 | ) 8 | from transformers.modeling_flax_pytorch_utils import load_flax_weights_in_pytorch_model 9 | from omegaconf import OmegaConf 10 | from flax import serialization, traverse_util 11 | from pathlib import Path 12 | from pickle import UnpicklingError 13 | from flax.serialization import from_bytes 14 | from flax.traverse_util import flatten_dict, unflatten_dict 15 | import jax 16 | import jax.numpy as jnp 17 | from tokenkit.models.hypernet import Hypernet 18 | from tokenkit.models import param, lora, sharding 19 | from tokenkit.byteify import load_byteify_tokenizer 20 | from tokenkit.hf import get_config 21 | from tokenkit import gcs_utils, utils, constants 22 | import json 23 | import os 24 | import torch 25 | from pprint import pformat 26 | import logging 27 | 28 | logger = logging.getLogger(__name__) 29 | 30 | 31 | # transformers `load_flax_checkpoint_in_pytorch_model` does not support custom models 32 | # so we patch it here 33 | def load_flax_checkpoint_in_pytorch_model(model, flax_checkpoint_path, flax_cls): 34 | """Load flax checkpoints in a PyTorch model""" 35 | flax_checkpoint_path = os.path.abspath(flax_checkpoint_path) 36 | logger.info(f"Loading Flax weights from {flax_checkpoint_path}") 37 | 38 | # load flax weight dict 39 | if flax_checkpoint_path.endswith(".safetensors"): 40 | from safetensors.flax import load_file as safe_load_file 41 | 42 | flax_state_dict = safe_load_file(flax_checkpoint_path) 43 | flax_state_dict = unflatten_dict(flax_state_dict, sep=".") 44 | else: 45 | with open(flax_checkpoint_path, "rb") as state_f: 46 | try: 47 | flax_state_dict = from_bytes(flax_cls, state_f.read()) 48 | except UnpicklingError: 49 | raise EnvironmentError(f"Unable to convert {flax_checkpoint_path} to Flax deserializable object. 
") 50 | 51 | return load_flax_weights_in_pytorch_model(model, flax_state_dict) 52 | 53 | 54 | 55 | @dataclass 56 | class Args: 57 | checkpoint: str = "outputs/patch" 58 | output: str = "outputs/export" 59 | use_cpu: bool = True 60 | tmp_save_dir: str = "/tmp/tokenkit/" 61 | with_pt: bool = False 62 | expand_input_ids_model: str | None = None 63 | expand_input_ids_tokenizer: str | None = None 64 | overwrite_args: str | None = None 65 | 66 | 67 | if __name__ == "__main__": 68 | (args,) = HfArgumentParser([Args]).parse_args_into_dataclasses() 69 | 70 | tmp_checkpoint_dir = Path(args.tmp_save_dir) / "checkpoint" 71 | tmp_output_dir = Path(args.tmp_save_dir) / "output" 72 | 73 | tmp_checkpoint_dir.mkdir(parents=True, exist_ok=True) 74 | tmp_output_dir.mkdir(parents=True, exist_ok=True) 75 | 76 | if args.use_cpu: 77 | jax.config.update("jax_default_device", jax.devices("cpu")[0]) 78 | mesh = sharding.get_mesh(devices=jax.devices("cpu")) 79 | else: 80 | mesh = sharding.get_mesh() 81 | 82 | if gcs_utils.is_gcs_path(args.checkpoint): 83 | checkpoint_bucket, checkpoint_blob = gcs_utils.parse_gcs_path(args.checkpoint) 84 | checkpoint_dir = tmp_checkpoint_dir 85 | 86 | for filename in ["args.yaml", "params.msgpack", "config.json", "tokenizer.json", "tokenizer_config.json"]: 87 | gcs_utils.download_from_gcs(checkpoint_bucket, f"{checkpoint_blob}/{filename}", checkpoint_dir / filename) 88 | else: 89 | checkpoint_dir = Path(args.checkpoint) 90 | 91 | ckpt_args = OmegaConf.load(checkpoint_dir / "args.yaml") 92 | if args.overwrite_args is not None: 93 | ckpt_args = OmegaConf.merge( 94 | ckpt_args, OmegaConf.create(json.loads(args.overwrite_args)) 95 | ) 96 | 97 | logger.info("Using checkpoint args:") 98 | logger.info(pformat(ckpt_args)) 99 | 100 | params = serialization.msgpack_restore( 101 | open(checkpoint_dir / "params.msgpack", "rb").read() 102 | ) 103 | 104 | config = get_config(checkpoint_dir) 105 | config.mesh = mesh 106 | tokenizer = AutoTokenizer.from_pretrained(checkpoint_dir) 107 | dtype = getattr(jnp, ckpt_args.dtype) 108 | 109 | n_embd = params["new_embeddings"].shape[-1] 110 | 111 | hypernet = Hypernet( 112 | dtype=dtype, 113 | hidden_size=n_embd, 114 | num_embeddings=1 if config.tie_word_embeddings else 2, 115 | max_seq_length=1, 116 | vocab_size=config.vocab_size, 117 | **ckpt_args.hypernet, 118 | ) 119 | model_kwargs = OmegaConf.to_object(ckpt_args.student) 120 | 121 | if "model" in params: 122 | model_params = params["model"] 123 | original_model_params = param.load_params(**model_kwargs) 124 | else: 125 | model_params = original_model_params = param.load_params(**model_kwargs) 126 | 127 | # model params may be partial at this point e.g. 
if trained with LoRA, merge them 128 | flat_merged_model_params = traverse_util.flatten_dict(original_model_params) 129 | flat_model_params = traverse_util.flatten_dict(model_params) 130 | 131 | for key in flat_model_params.keys(): 132 | flat_merged_model_params[key] = flat_model_params[key] 133 | 134 | merged_model_params = traverse_util.unflatten_dict(flat_merged_model_params) 135 | # assigned later 136 | merged_model_params = param.unassign_embeddings(merged_model_params, config=config) 137 | 138 | if "model_lora" in params: 139 | logger.info("Materializing LoRA parameters...") 140 | merged_model_params = lora.materialize_lora( 141 | merged_model_params, 142 | params["model_lora"], 143 | ckpt_args.model_lora_alpha, 144 | ) 145 | 146 | hypernet_fn = hypernet.apply 147 | 148 | def predict_embeddings(params): # TODO: add indices for subsampling 149 | embeddings = params["new_embeddings"] 150 | 151 | predicted_embeddings = hypernet_fn( 152 | params["hypernet"], 153 | embeddings[:, None, :, :], 154 | jnp.ones((embeddings.shape[0], 1), dtype=bool), 155 | jnp.arange(embeddings.shape[0], dtype=jnp.int32), 156 | ) 157 | 158 | return predicted_embeddings 159 | 160 | embeddings = jax.device_get(predict_embeddings(params)) 161 | embeddings = embeddings.copy() # not writeable otherwise 162 | 163 | # remove padding 164 | config.vocab_size = len(tokenizer) 165 | embeddings = embeddings[: len(tokenizer)] # remove padding 166 | 167 | merged_model_params = param.assign_embeddings(merged_model_params, embeddings, config=config) 168 | 169 | model_to_save = FlaxAutoModelForCausalLM.from_config(config) 170 | if gcs_utils.is_gcs_path(args.output): 171 | output_dir = tmp_output_dir 172 | else: 173 | output_dir = Path(args.output) 174 | 175 | del config.mesh 176 | 177 | # from_flax does not work with multiple shards so it is more convenient to save the model as a single shard 178 | model_to_save.save_pretrained( 179 | output_dir, params=merged_model_params, max_shard_size="100GB" 180 | ) 181 | 182 | if args.with_pt: 183 | if args.expand_input_ids_model is not None: 184 | byteify_tokenizer = load_byteify_tokenizer(ckpt_args.target_tokenizer_name) 185 | expand_tokenizer = load_byteify_tokenizer(args.expand_input_ids_tokenizer) 186 | 187 | expand_input_ids_dict = utils.get_expand_input_ids_dict( 188 | byteify_tokenizer, 189 | expand_tokenizer.get_vocab(), 190 | max_length=constants.EXPAND_INPUT_IDS_MAX_LENGTH, 191 | ) 192 | 193 | config.expand_input_ids = True 194 | config.expand_input_ids_maxlen = constants.EXPAND_INPUT_IDS_MAX_LENGTH 195 | config.expand_input_ids_vocab_size = len(expand_tokenizer) 196 | # make json serializable - will be deserialized in PT model init 197 | config.expand_input_ids_dict = ( 198 | {",".join([str(n) for n in k]): int(v) for k, v in expand_input_ids_dict[0].items()}, 199 | [int(n) for n in expand_input_ids_dict[1]], 200 | ) 201 | 202 | pt_model = AutoModelForCausalLM.from_config(config) 203 | pt_model = load_flax_checkpoint_in_pytorch_model(pt_model, output_dir / "flax_model.msgpack", type(model_to_save)) 204 | 205 | # set expansion embedding data 206 | if args.expand_input_ids_model is not None: 207 | expand_input_ids_model_config = get_config(args.expand_input_ids_model) 208 | expand_input_ids_model_params = param.load_params(pretrained_model_name_or_path=args.expand_input_ids_model) 209 | expand_input_ids_embeddings = param.get( 210 | expand_input_ids_model_params, 211 | param.get_input_embedding_path(expand_input_ids_model_config.model_type), 212 | ) 213 | 214 | 
pt_model.model.expand_embed_tokens.weight.data[:] = torch.from_numpy(expand_input_ids_embeddings) 215 | 216 | pt_model.save_pretrained(output_dir) 217 | else: 218 | pt_model = None 219 | 220 | if args.expand_input_ids_model is not None: 221 | raise ValueError("expand_input_ids_model is not supported when with_pt is False") 222 | 223 | config.auto_map = { 224 | "AutoConfig": f"configuration_{config.model_type}.{type(config).__name__}", 225 | "FlaxAutoModelForCausalLM": f"modelling_flax_{config.model_type}.{type(model_to_save).__name__}" 226 | } 227 | 228 | if pt_model is not None: 229 | config.auto_map["AutoModelForCausalLM"] = f"modelling_{config.model_type}.{type(pt_model).__name__}" 230 | 231 | tokenizer.save_pretrained(output_dir) 232 | config.save_pretrained(output_dir) 233 | 234 | if gcs_utils.is_gcs_path(args.output): 235 | output_bucket, output_blob = gcs_utils.parse_gcs_path(args.output) 236 | for filename in ["config.json", "flax_model.msgpack", "tokenizer.json", "tokenizer_config.json"]: 237 | gcs_utils.upload_to_gcs(output_bucket, output_dir / filename, f"{output_blob}/{filename}") -------------------------------------------------------------------------------- /scripts/push_flax_version_to_hub.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from transformers import HfArgumentParser, AutoModelForCausalLM, AutoTokenizer 3 | import json 4 | import jax 5 | import transformers 6 | import shutil 7 | 8 | from tokenkit.models import sharding 9 | 10 | TMP_PATH = "/mnt/disks/persist/tmp/model" 11 | 12 | @dataclass 13 | class Args: 14 | model_name_or_path: str = "Qwen/Qwen2-0.5B" 15 | hub_user: str = "benjamin" 16 | model_class: str = "Llama" 17 | extra_args: str | None = None # for Qwen2: "{\"attention_bias\": true, \"max_length\": 8192}", for Llama3: "{\"max_length\": 8192}" 18 | use_cpu: bool = False 19 | 20 | 21 | if __name__ == "__main__": 22 | (args,) = HfArgumentParser([Args]).parse_args_into_dataclasses() 23 | print(args) 24 | 25 | if args.use_cpu: 26 | jax.config.update('jax_default_device', jax.devices('cpu')[0]) 27 | mesh = sharding.get_mesh(devices=jax.devices("cpu")) 28 | else: 29 | mesh = sharding.get_mesh() 30 | 31 | shutil.rmtree(TMP_PATH, ignore_errors=True) 32 | AutoModelForCausalLM.from_pretrained(args.model_name_or_path).save_pretrained( 33 | TMP_PATH, max_shard_size="100GB" 34 | ) 35 | tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path) 36 | config_class = getattr(transformers, args.model_class + "Config") 37 | if hasattr(transformers, "Flax" + args.model_class + "ForCausalLM"): 38 | model_class = getattr(transformers, "Flax" + args.model_class + "ForCausalLM") 39 | elif hasattr(transformers, "Flax" + args.model_class + "LMHeadModel"): 40 | model_class = getattr(transformers, "Flax" + args.model_class + "LMHeadModel") 41 | else: 42 | raise ValueError(f"Model class '{args.model_class}' not found") 43 | 44 | config = config_class.from_pretrained(TMP_PATH, args.model_name_or_path) 45 | for key, value in json.loads(args.extra_args or "{}").items(): 46 | setattr(config, key, value) 47 | 48 | config.mesh = mesh 49 | 50 | flax_model = model_class.from_pretrained(TMP_PATH, config=config) 51 | model_name = args.hub_user + "/" + args.model_name_or_path.split("/")[-1] + "-flax" 52 | 53 | del config.mesh 54 | 55 | flax_model.push_to_hub(model_name, private=True, safe_serialization=False) 56 | tokenizer.push_to_hub(model_name, private=True) 
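# Example invocation (a sketch, not from the repo docs): convert Llama-3.2-3B-Instruct to Flax
# and push it under your own account. Assumes you are logged in to the Hugging Face Hub
# (e.g. via `huggingface-cli login`); the flags mirror the dataclass fields parsed above.
#
#   python3 scripts/push_flax_version_to_hub.py \
#       --model_name_or_path=meta-llama/Llama-3.2-3B-Instruct \
#       --model_class=Llama \
#       --hub_user=<your-hf-username> \
#       --extra_args='{"max_length": 8192}'
#
# This saves the PyTorch checkpoint to TMP_PATH, reloads it with the corresponding Flax
# model class, and pushes the result to <your-hf-username>/Llama-3.2-3B-Instruct-flax.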
-------------------------------------------------------------------------------- /scripts/zett.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from pprint import pformat 3 | 4 | import jax 5 | from dataclasses import dataclass, asdict 6 | from transformers import FlaxAutoModelForCausalLM 7 | 8 | from tokenkit.hf import get_config 9 | from tokenkit import utils 10 | from tokenkit.byteify import load_byteify_tokenizer 11 | from tokenkit.models import param, sharding 12 | from tokenkit import parse_args 13 | 14 | logger = logging.getLogger(__name__) 15 | 16 | @dataclass 17 | class ZettArgs: 18 | source_model: parse_args.ModelArgs 19 | target_tokenizer_name: str 20 | output: str 21 | 22 | def main(args: ZettArgs) -> None: 23 | logger.info(pformat(args)) 24 | 25 | # Load the model & tokenizer 26 | source_tokenizer = load_byteify_tokenizer(args.source_model.tokenizer_name) 27 | target_tokenizer = load_byteify_tokenizer(args.target_tokenizer_name) 28 | 29 | mesh = sharding.get_mesh(devices=jax.devices("cpu")) 30 | config = get_config(args.source_model.pretrained_model_name_or_path) 31 | config.mesh = mesh 32 | 33 | model = FlaxAutoModelForCausalLM.from_config( 34 | config, 35 | _do_init=False, 36 | input_shape=(1, 128), 37 | ) 38 | del model.config.mesh 39 | 40 | model_params = param.load_params(**asdict(args.source_model)) 41 | embeddings, model_params = param.stack_embeddings( 42 | model_params, 43 | config, 44 | pop_embeddings=True, 45 | ) 46 | 47 | diff_embeddings, original_to_new_indices, diff_indices = utils.fvt( 48 | source_tokenizer, 49 | target_tokenizer, 50 | embeddings, 51 | ) 52 | new_embeddings = embeddings[original_to_new_indices] 53 | if len(diff_indices) > 0: 54 | new_embeddings[diff_indices] = diff_embeddings 55 | 56 | model_params = param.assign_embeddings(model_params, new_embeddings, config) 57 | 58 | model.save_pretrained(args.output, params=model_params) 59 | config.save_pretrained(args.output) 60 | target_tokenizer.save_pretrained(args.output) 61 | 62 | 63 | if __name__ == "__main__": 64 | main(parse_args.parse_args(ZettArgs)) -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import find_packages, setup 2 | 3 | setup( 4 | name="tokenkit", 5 | version="0.0.1", 6 | packages=find_packages(), 7 | install_requires=[], 8 | entry_points={ 9 | "console_scripts": [], 10 | }, 11 | ) 12 | -------------------------------------------------------------------------------- /tokenkit/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bminixhofer/tokenkit/7abac3578f6b3fa38f985e9f03bff7a47d5ab3b1/tokenkit/__init__.py -------------------------------------------------------------------------------- /tokenkit/baseline_utils.py: -------------------------------------------------------------------------------- 1 | import math 2 | import multiprocessing 3 | from functools import partial 4 | 5 | import editdistance 6 | import jax 7 | import jax.numpy as jnp 8 | import numpy as np 9 | from tqdm.auto import tqdm 10 | 11 | 12 | def _compute_edit_distance(token, sorted_original_vocab): 13 | min_edit_distance = math.inf 14 | best_match = None 15 | 16 | closer_to_start = len(token) < len( 17 | sorted_original_vocab[int(len(sorted_original_vocab) / 2)] 18 | ) 19 | 20 | if closer_to_start: 21 | candidates = sorted_original_vocab 22 | 
else: 23 | candidates = reversed(sorted_original_vocab) 24 | 25 | for original_token in candidates: 26 | if closer_to_start: 27 | # tokens only get longer 28 | if len(original_token) - len(token) >= min_edit_distance: 29 | break 30 | if len(token) - len(original_token) >= min_edit_distance: 31 | continue 32 | else: 33 | # tokens only get shorter 34 | if len(token) - len(original_token) >= min_edit_distance: 35 | break 36 | if len(original_token) - len(token) >= min_edit_distance: 37 | continue 38 | 39 | edit_distance = editdistance.eval(token, original_token) 40 | if edit_distance < min_edit_distance: 41 | min_edit_distance = edit_distance 42 | best_match = original_token 43 | 44 | return token, best_match, min_edit_distance 45 | 46 | 47 | def compute_mined_mapping( 48 | tokenizer_original, tokenizer_new, num_workers=1, chunksize=500 49 | ): 50 | original_vocab = tokenizer_original.get_vocab() 51 | new_vocab = tokenizer_new.get_vocab() 52 | 53 | mapping = np.zeros(len(tokenizer_new), dtype=np.int32) 54 | edit_distances = {} 55 | 56 | intersection = [token for token in new_vocab.keys() if token in original_vocab] 57 | completion = [token for token in new_vocab.keys() if token not in original_vocab] 58 | sorted_completion = sorted(completion, key=lambda x: len(x)) 59 | sorted_original_vocab = sorted(original_vocab.keys(), key=lambda x: len(x)) 60 | 61 | for token in intersection: 62 | mapping[new_vocab[token]] = original_vocab[token] 63 | edit_distances[token] = 0 64 | 65 | with multiprocessing.Pool(max(num_workers, 1)) as pool: 66 | results = list( 67 | tqdm( 68 | pool.imap_unordered( 69 | partial( 70 | _compute_edit_distance, 71 | sorted_original_vocab=sorted_original_vocab, 72 | ), 73 | sorted_completion, 74 | chunksize=chunksize, 75 | ), 76 | desc="Computing MinED mapping", 77 | total=len(sorted_completion), 78 | ) 79 | ) 80 | 81 | for token, best_match, min_edit_distance in results: 82 | mapping[new_vocab[token]] = original_vocab[best_match] 83 | edit_distances[token] = min_edit_distance 84 | 85 | return mapping, edit_distances 86 | 87 | 88 | def compute_forward_kl_divergence( 89 | logits, 90 | teacher_logits, 91 | target, 92 | kd_temp, 93 | padding_id, 94 | tea_temp=None, 95 | reduction="sum", 96 | log=None, 97 | use_tea_temp=False, 98 | ): 99 | logits = logits / kd_temp 100 | teacher_logits = teacher_logits / kd_temp 101 | teacher_logits = teacher_logits / tea_temp if use_tea_temp else teacher_logits 102 | 103 | lprobs = jax.nn.log_softmax(logits, -1) 104 | teacher_probs = jax.nn.softmax(teacher_logits, -1) 105 | teacher_lprobs = jax.nn.log_softmax(teacher_logits, -1) 106 | kld = teacher_probs * (teacher_lprobs - lprobs) 107 | inf_mask = jnp.isinf(logits) 108 | kld = jnp.where(inf_mask, 0.0, kld).sum(-1) 109 | 110 | if reduction == "sum": 111 | pad_mask = target == padding_id 112 | kld = jnp.where(pad_mask, 0.0, kld) 113 | kld = kld.sum() 114 | 115 | if log is not None: 116 | log["forward_kl"] = kld 117 | 118 | return kld 119 | 120 | 121 | def compute_reverse_kl_divergence( 122 | logits, 123 | teacher_logits, 124 | target, 125 | kd_temp, 126 | padding_id, 127 | tea_temp=None, 128 | reduction="sum", 129 | log=None, 130 | use_tea_temp=False, 131 | ): 132 | logits = logits / kd_temp 133 | teacher_logits = teacher_logits / kd_temp 134 | teacher_logits = teacher_logits / tea_temp if use_tea_temp else teacher_logits 135 | 136 | probs = jax.nn.softmax(logits, -1) 137 | lprobs = jax.nn.log_softmax(logits, -1) 138 | teacher_lprobs = jax.nn.log_softmax(teacher_logits, -1) 139 | kld = probs * 
(lprobs - teacher_lprobs) 140 | inf_mask = jnp.isinf(logits) | jnp.isinf(teacher_logits) 141 | kld = jnp.where(inf_mask, 0.0, kld).sum(-1) 142 | 143 | if reduction == "sum": 144 | pad_mask = target == padding_id 145 | kld = jnp.where(pad_mask, 0.0, kld) 146 | kld = kld.sum() 147 | 148 | if log is not None: 149 | log["reverse_kl"] = kld 150 | 151 | return kld 152 | 153 | 154 | def compute_adaptive_kl_divergence( 155 | logits, 156 | teacher_logits, 157 | target, 158 | kd_temp, 159 | padding_id, 160 | alpha, 161 | tea_temp=None, 162 | reduction="sum", 163 | log=None, 164 | use_tea_temp=False, 165 | ): 166 | probs = jax.nn.softmax(logits / kd_temp, axis=-1).astype(jnp.float32) 167 | if use_tea_temp: 168 | teacher_probs = jax.nn.softmax( 169 | teacher_logits / tea_temp / kd_temp, axis=-1 170 | ).astype(jnp.float32) 171 | else: 172 | teacher_probs = jax.nn.softmax(teacher_logits / kd_temp, axis=-1).astype( 173 | jnp.float32 174 | ) 175 | 176 | sorted_teacher_probs = jnp.sort(teacher_probs, axis=-1) 177 | sorted_idx = jnp.argsort(teacher_probs, axis=-1) 178 | sorted_probs = jnp.take_along_axis( 179 | probs, sorted_idx, axis=-1 180 | ) # TODO: check if we need [..., None]? 181 | gap = jnp.abs(sorted_teacher_probs - sorted_probs) 182 | cum_teacher_probs = jnp.cumsum(sorted_teacher_probs, axis=-1) 183 | tail_mask = (cum_teacher_probs < alpha).astype(jnp.float32) 184 | g_head = jax.lax.stop_gradient(jnp.sum(gap * (1 - tail_mask), axis=-1)) 185 | g_tail = jax.lax.stop_gradient(jnp.sum(gap * tail_mask, axis=-1)) 186 | 187 | fkl = compute_forward_kl_divergence( 188 | logits, 189 | teacher_logits, 190 | target, 191 | kd_temp, 192 | padding_id, 193 | tea_temp=tea_temp, 194 | reduction="none", 195 | use_tea_temp=use_tea_temp, 196 | ) 197 | rkl = compute_reverse_kl_divergence( 198 | logits, 199 | teacher_logits, 200 | target, 201 | kd_temp, 202 | padding_id, 203 | tea_temp=tea_temp, 204 | reduction="none", 205 | use_tea_temp=use_tea_temp, 206 | ) 207 | 208 | akl = (g_head / (g_head + g_tail)) * fkl + (g_tail / (g_head + g_tail)) * rkl 209 | 210 | if reduction == "sum": 211 | pad_mask = target == padding_id 212 | akl = jnp.where(pad_mask, 0.0, akl) 213 | akl = akl.sum() 214 | 215 | if log is not None: 216 | log["adaptive_kl"] = akl 217 | 218 | return akl 219 | 220 | 221 | def compute_skewed_forward_kl_divergence( 222 | logits, 223 | teacher_logits, 224 | target, 225 | kd_temp, 226 | padding_id, 227 | skew_lambda, 228 | tea_temp=None, 229 | reduction="sum", 230 | log=None, 231 | use_tea_temp=False, 232 | epsilon=1e-9, 233 | ): 234 | logits = logits / kd_temp 235 | teacher_logits = teacher_logits / kd_temp 236 | teacher_logits = teacher_logits / tea_temp if use_tea_temp else teacher_logits 237 | 238 | student_probs = jax.nn.softmax(logits, -1).astype(jnp.float32) 239 | teacher_probs = jax.nn.softmax(teacher_logits, -1).astype(jnp.float32) 240 | mixed_probs = skew_lambda * teacher_probs + (1 - skew_lambda) * student_probs 241 | mixed_lprobs = jnp.log(mixed_probs + epsilon) 242 | teacher_lprobs = jax.nn.log_softmax(teacher_logits, -1).astype(jnp.float32) 243 | kld = teacher_probs * (teacher_lprobs - mixed_lprobs) 244 | inf_mask = jnp.isinf(logits) | jnp.isinf(teacher_logits) 245 | kld = jnp.where(inf_mask, 0.0, kld).sum(-1) 246 | 247 | if reduction == "sum": 248 | pad_mask = target == padding_id 249 | kld = jnp.where(pad_mask, 0.0, kld) 250 | kld = kld.sum() 251 | 252 | if log is not None: 253 | log["skewed_forward_kl"] = kld 254 | 255 | return kld 256 | 257 | 258 | def compute_skewed_reverse_kl_divergence( 259 
| logits, 260 | teacher_logits, 261 | target, 262 | kd_temp, 263 | padding_id, 264 | skew_lambda, 265 | tea_temp=None, 266 | reduction="sum", 267 | log=None, 268 | use_tea_temp=False, 269 | epsilon=1e-9, 270 | ): 271 | logits = logits / kd_temp 272 | teacher_logits = teacher_logits / kd_temp 273 | teacher_logits = teacher_logits / tea_temp if use_tea_temp else teacher_logits 274 | 275 | student_probs = jax.nn.softmax(logits, -1).astype(jnp.float32) 276 | teacher_probs = jax.nn.softmax(teacher_logits, -1).astype(jnp.float32) 277 | mixed_probs = (1 - skew_lambda) * teacher_probs + skew_lambda * student_probs 278 | mixed_lprobs = jnp.log(mixed_probs + epsilon) 279 | student_lprobs = jax.nn.log_softmax(logits, -1).astype(jnp.float32) 280 | kld = student_probs * (student_lprobs - mixed_lprobs) 281 | inf_mask = jnp.isinf(logits) | jnp.isinf(teacher_logits) 282 | kld = jnp.where(inf_mask, 0.0, kld).sum(-1) 283 | 284 | if reduction == "sum": 285 | pad_mask = target == padding_id 286 | kld = jnp.where(pad_mask, 0.0, kld) 287 | kld = kld.sum() 288 | 289 | if log is not None: 290 | log["skewed_reverse_kl"] = kld 291 | 292 | return kld 293 | 294 | 295 | def compute_js_divergence( 296 | logits, 297 | teacher_logits, 298 | target, 299 | kd_temp, 300 | tea_temp, 301 | padding_id, 302 | reduction="sum", 303 | log=None, 304 | use_tea_temp=False, 305 | epsilon=1e-9, 306 | ): 307 | logits = logits / kd_temp 308 | teacher_logits = teacher_logits / kd_temp 309 | teacher_logits = teacher_logits / tea_temp if use_tea_temp else teacher_logits 310 | 311 | probs = jax.nn.softmax(logits, -1).astype(jnp.float32) 312 | teacher_probs = jax.nn.softmax(teacher_logits, -1).astype(jnp.float32) 313 | m_probs = (probs + teacher_probs) / 2 314 | 315 | lprobs = jnp.log(probs + epsilon) 316 | teacher_lprobs = jnp.log(teacher_probs + epsilon) 317 | m_lprobs = jnp.log(m_probs + epsilon) 318 | 319 | kld1 = teacher_probs * (teacher_lprobs - m_lprobs) 320 | kld2 = probs * (lprobs - m_lprobs) 321 | kld = (kld1 + kld2) / 2 322 | 323 | if reduction == "sum": 324 | pad_mask = target == padding_id 325 | kld = jnp.where(pad_mask, 0.0, kld) 326 | kld = kld.sum() 327 | 328 | if log is not None: 329 | log["js_div"] = kld 330 | 331 | return kld 332 | -------------------------------------------------------------------------------- /tokenkit/constants.py: -------------------------------------------------------------------------------- 1 | # Character mapping dictionaries 2 | EXPAND_INPUT_IDS_MAX_LENGTH = 16 3 | CHARS_TO_BYTES = { 4 | "Ā": 0, 5 | "ā": 1, 6 | "Ă": 2, 7 | "ă": 3, 8 | "Ą": 4, 9 | "ą": 5, 10 | "Ć": 6, 11 | "ć": 7, 12 | "Ĉ": 8, 13 | "ĉ": 9, 14 | "Ċ": 10, 15 | "ċ": 11, 16 | "Č": 12, 17 | "č": 13, 18 | "Ď": 14, 19 | "ď": 15, 20 | "Đ": 16, 21 | "đ": 17, 22 | "Ē": 18, 23 | "ē": 19, 24 | "Ĕ": 20, 25 | "ĕ": 21, 26 | "Ė": 22, 27 | "ė": 23, 28 | "Ę": 24, 29 | "ę": 25, 30 | "Ě": 26, 31 | "ě": 27, 32 | "Ĝ": 28, 33 | "ĝ": 29, 34 | "Ğ": 30, 35 | "ğ": 31, 36 | "Ġ": 32, 37 | "!": 33, 38 | '"': 34, 39 | "#": 35, 40 | "$": 36, 41 | "%": 37, 42 | "&": 38, 43 | "'": 39, 44 | "(": 40, 45 | ")": 41, 46 | "*": 42, 47 | "+": 43, 48 | ",": 44, 49 | "-": 45, 50 | ".": 46, 51 | "/": 47, 52 | "0": 48, 53 | "1": 49, 54 | "2": 50, 55 | "3": 51, 56 | "4": 52, 57 | "5": 53, 58 | "6": 54, 59 | "7": 55, 60 | "8": 56, 61 | "9": 57, 62 | ":": 58, 63 | ";": 59, 64 | "<": 60, 65 | "=": 61, 66 | ">": 62, 67 | "?": 63, 68 | "@": 64, 69 | "A": 65, 70 | "B": 66, 71 | "C": 67, 72 | "D": 68, 73 | "E": 69, 74 | "F": 70, 75 | "G": 71, 76 | "H": 72, 77 | "I": 73, 78 | "J": 74, 
79 | "K": 75, 80 | "L": 76, 81 | "M": 77, 82 | "N": 78, 83 | "O": 79, 84 | "P": 80, 85 | "Q": 81, 86 | "R": 82, 87 | "S": 83, 88 | "T": 84, 89 | "U": 85, 90 | "V": 86, 91 | "W": 87, 92 | "X": 88, 93 | "Y": 89, 94 | "Z": 90, 95 | "[": 91, 96 | "\\": 92, 97 | "]": 93, 98 | "^": 94, 99 | "_": 95, 100 | "`": 96, 101 | "a": 97, 102 | "b": 98, 103 | "c": 99, 104 | "d": 100, 105 | "e": 101, 106 | "f": 102, 107 | "g": 103, 108 | "h": 104, 109 | "i": 105, 110 | "j": 106, 111 | "k": 107, 112 | "l": 108, 113 | "m": 109, 114 | "n": 110, 115 | "o": 111, 116 | "p": 112, 117 | "q": 113, 118 | "r": 114, 119 | "s": 115, 120 | "t": 116, 121 | "u": 117, 122 | "v": 118, 123 | "w": 119, 124 | "x": 120, 125 | "y": 121, 126 | "z": 122, 127 | "{": 123, 128 | "|": 124, 129 | "}": 125, 130 | "~": 126, 131 | "ġ": 127, 132 | "Ģ": 128, 133 | "ģ": 129, 134 | "Ĥ": 130, 135 | "ĥ": 131, 136 | "Ħ": 132, 137 | "ħ": 133, 138 | "Ĩ": 134, 139 | "ĩ": 135, 140 | "Ī": 136, 141 | "ī": 137, 142 | "Ĭ": 138, 143 | "ĭ": 139, 144 | "Į": 140, 145 | "į": 141, 146 | "İ": 142, 147 | "ı": 143, 148 | "IJ": 144, 149 | "ij": 145, 150 | "Ĵ": 146, 151 | "ĵ": 147, 152 | "Ķ": 148, 153 | "ķ": 149, 154 | "ĸ": 150, 155 | "Ĺ": 151, 156 | "ĺ": 152, 157 | "Ļ": 153, 158 | "ļ": 154, 159 | "Ľ": 155, 160 | "ľ": 156, 161 | "Ŀ": 157, 162 | "ŀ": 158, 163 | "Ł": 159, 164 | "ł": 160, 165 | "¡": 161, 166 | "¢": 162, 167 | "£": 163, 168 | "¤": 164, 169 | "¥": 165, 170 | "¦": 166, 171 | "§": 167, 172 | "¨": 168, 173 | "©": 169, 174 | "ª": 170, 175 | "«": 171, 176 | "¬": 172, 177 | "Ń": 173, 178 | "®": 174, 179 | "¯": 175, 180 | "°": 176, 181 | "±": 177, 182 | "²": 178, 183 | "³": 179, 184 | "´": 180, 185 | "µ": 181, 186 | "¶": 182, 187 | "·": 183, 188 | "¸": 184, 189 | "¹": 185, 190 | "º": 186, 191 | "»": 187, 192 | "¼": 188, 193 | "½": 189, 194 | "¾": 190, 195 | "¿": 191, 196 | "À": 192, 197 | "Á": 193, 198 | "Â": 194, 199 | "Ã": 195, 200 | "Ä": 196, 201 | "Å": 197, 202 | "Æ": 198, 203 | "Ç": 199, 204 | "È": 200, 205 | "É": 201, 206 | "Ê": 202, 207 | "Ë": 203, 208 | "Ì": 204, 209 | "Í": 205, 210 | "Î": 206, 211 | "Ï": 207, 212 | "Ð": 208, 213 | "Ñ": 209, 214 | "Ò": 210, 215 | "Ó": 211, 216 | "Ô": 212, 217 | "Õ": 213, 218 | "Ö": 214, 219 | "×": 215, 220 | "Ø": 216, 221 | "Ù": 217, 222 | "Ú": 218, 223 | "Û": 219, 224 | "Ü": 220, 225 | "Ý": 221, 226 | "Þ": 222, 227 | "ß": 223, 228 | "à": 224, 229 | "á": 225, 230 | "â": 226, 231 | "ã": 227, 232 | "ä": 228, 233 | "å": 229, 234 | "æ": 230, 235 | "ç": 231, 236 | "è": 232, 237 | "é": 233, 238 | "ê": 234, 239 | "ë": 235, 240 | "ì": 236, 241 | "í": 237, 242 | "î": 238, 243 | "ï": 239, 244 | "ð": 240, 245 | "ñ": 241, 246 | "ò": 242, 247 | "ó": 243, 248 | "ô": 244, 249 | "õ": 245, 250 | "ö": 246, 251 | "÷": 247, 252 | "ø": 248, 253 | "ù": 249, 254 | "ú": 250, 255 | "û": 251, 256 | "ü": 252, 257 | "ý": 253, 258 | "þ": 254, 259 | "ÿ": 255, 260 | } 261 | BYTES_TO_CHARS = {v: k for k, v in CHARS_TO_BYTES.items()} 262 | MAX_CHARS_PER_TOKEN = 16 263 | # for hn training 264 | DEFAULT_SPLIT_REGEX = r"'s|'t|'re|'ve|'m|'ll|'d| ?[\p{L}\p{M}]+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+" -------------------------------------------------------------------------------- /tokenkit/data/__init__.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import datasets 4 | from datasets import interleave_datasets, load_dataset 5 | from torch.utils.data import Dataset 6 | from functools import partial 7 | 8 | from tokenkit.utils import preprocess_messages 9 | 10 | 11 | class 
JSONLDataset(Dataset): 12 | def __init__( 13 | self, path, lang_code, batch_size, n_subsample=None, num_workers=0, seed=None 14 | ): 15 | self.path = path 16 | self.lang_code = lang_code 17 | self.batch_size = batch_size 18 | 19 | self.raw_dset = load_dataset( 20 | "json", 21 | data_files=path, 22 | split="train" if n_subsample is None else f"train[:{n_subsample}]", 23 | ) 24 | self.dset = self.raw_dset.map( 25 | lambda x, idx: { 26 | "text": [x["text"]], 27 | "lang_code": [[lang_code] * len(x["text"])], 28 | "index": [idx], 29 | }, 30 | batched=True, 31 | batch_size=batch_size, 32 | drop_last_batch=True, 33 | remove_columns=self.raw_dset.column_names, 34 | with_indices=True, 35 | ) 36 | 37 | # TODO: check consistency between this and iterable data - iterable cycles through the data on consecutive calls 38 | # (but this does NOT influence the main stream) 39 | # TODO: return texts directly instead of dicts? 40 | def get_texts(self, n: int, lang_code: None | str = None): 41 | if lang_code is not None and lang_code != self.lang_code: 42 | raise ValueError("Invalid lang_code") 43 | 44 | texts = self.raw_dset[:n]["text"] 45 | return [{"text": t} for t in texts] 46 | 47 | def get_torch_dataset(self): 48 | return self.dset 49 | 50 | 51 | def process_example(example, lang_code): 52 | if "messages" in example: 53 | text = preprocess_messages(example["messages"]) 54 | else: 55 | text = example["text"] 56 | 57 | return { 58 | "text": text, 59 | "lang_code": lang_code, 60 | } 61 | 62 | 63 | class HFDataset: 64 | def __init__( 65 | self, 66 | dataset_configs, 67 | mix_languages, 68 | batch_size, 69 | streaming=True, 70 | n_subsample=None, 71 | shuffle_buffer_size=None, 72 | num_workers=0, 73 | seed=1234, 74 | ): 75 | self.dataset_configs = dataset_configs 76 | self.mix_languages = mix_languages 77 | self.seed = seed 78 | self.batch_size = batch_size 79 | self.shuffle_buffer_size = shuffle_buffer_size 80 | 81 | if n_subsample is not None: 82 | raise ValueError("Subsampling not supported for HF datasets") 83 | 84 | self.dset_streams = {} 85 | self.probs = {} 86 | 87 | if streaming: 88 | process_kwargs = {} 89 | else: 90 | process_kwargs = {"num_proc": num_workers if num_workers > 0 else None} 91 | 92 | for config in dataset_configs: 93 | stream = load_dataset( 94 | **config["kwargs"], 95 | streaming=streaming, 96 | trust_remote_code=True, 97 | ) 98 | 99 | if self.shuffle_buffer_size is not None: 100 | if streaming: 101 | stream = stream.shuffle( 102 | buffer_size=self.shuffle_buffer_size, seed=seed 103 | ) 104 | else: 105 | stream = stream.shuffle(seed=seed) 106 | 107 | self.dset_streams[config["lang_code"]] = stream.map( 108 | partial(process_example, lang_code=config["lang_code"]), **process_kwargs, remove_columns=stream.column_names if not streaming else None 109 | ) 110 | 111 | if "p" in config: 112 | self.probs[config["lang_code"]] = config["p"] 113 | 114 | if 0 < len(self.probs) < len(self.dset_streams): 115 | raise ValueError( 116 | "If you provide probabilities, you must provide them for all datasets" 117 | ) 118 | 119 | if len(self.probs) == 0: 120 | self.probs = {k: 1.0 for k in self.dset_streams.keys()} 121 | 122 | # normalize probabilities 123 | total = sum(self.probs.values()) 124 | for k in self.probs: 125 | self.probs[k] /= total 126 | 127 | if self.mix_languages: 128 | self.stream = interleave_datasets( 129 | list(self.dset_streams.values()), 130 | probabilities=list(self.probs.values()), 131 | seed=seed, 132 | ).batch(batch_size, drop_last_batch=True, **process_kwargs) 133 | else: 134 | 
self.stream = interleave_datasets( 135 | [ 136 | s.batch(batch_size, drop_last_batch=True, **process_kwargs) 137 | for s in self.dset_streams.values() 138 | ], 139 | probabilities=list(self.probs.values()), 140 | seed=seed, 141 | ) 142 | 143 | def get_texts(self, n: int, lang_code: None | str = None): 144 | if n == 0: 145 | return [] 146 | 147 | if lang_code is None: 148 | # unbatch 149 | batches_to_take = math.ceil(n / self.batch_size) 150 | batches = list(self.stream.take(batches_to_take)) 151 | out = [] 152 | keys = list(batches[0].keys()) 153 | 154 | for batch in batches: 155 | for i in range(len(batch[keys[0]])): 156 | out.append({k: batch[k][i] for k in keys}) 157 | 158 | return out 159 | 160 | return list(self.dset_streams[lang_code].take(n)) 161 | 162 | def get_torch_dataset(self): 163 | return self.stream 164 | 165 | 166 | class HFSavedDataset(Dataset): 167 | def __init__( 168 | self, 169 | dataset_configs, 170 | lang_code, 171 | batch_size, 172 | n_subsample=None, 173 | num_workers=0, 174 | seed=None, 175 | ): 176 | self.dataset_configs = dataset_configs 177 | self.lang_code = lang_code 178 | self.seed = seed 179 | self.batch_size = batch_size 180 | 181 | if n_subsample is not None: 182 | raise ValueError("Subsampling not supported for HF datasets") 183 | 184 | self.dsets = [] 185 | self.probs = [] 186 | 187 | for config in dataset_configs: 188 | self.dsets.append(datasets.load_from_disk(config["path"])["train"]) 189 | 190 | if "p" in config: 191 | self.probs.append(config["p"]) 192 | 193 | if 0 < len(self.probs) < len(self.dsets): 194 | raise ValueError( 195 | "If you provide probabilities, you must provide them for all datasets" 196 | ) 197 | 198 | if len(self.probs) == 0: 199 | self.probs = None 200 | 201 | self.dset = interleave_datasets( 202 | self.dsets, 203 | probabilities=self.probs, 204 | seed=seed, 205 | ) 206 | 207 | def __len__(self): 208 | return len(self.dset) // self.batch_size 209 | 210 | def get_texts(self, n: int, lang_code: None | str = None): 211 | if lang_code is not None and lang_code != self.lang_code: 212 | raise ValueError("Invalid lang_code") 213 | 214 | texts = self.dset[:n]["text"] 215 | return [{"text": t} for t in texts] 216 | 217 | def __getitem__(self, batch_idx): 218 | start, end = batch_idx * self.batch_size, (batch_idx + 1) * self.batch_size 219 | 220 | texts = self.dset[start:end]["text"] 221 | 222 | return { 223 | "text": texts, 224 | "lang_code": [self.lang_code] * len(texts), 225 | "index": list(range(start, start + len(texts))), 226 | } 227 | 228 | def get_torch_dataset(self): 229 | return self 230 | 231 | 232 | def get_dataset(kind, **kwargs): 233 | if kind == "jsonl": 234 | return JSONLDataset(**kwargs) 235 | elif kind == "hf": 236 | return HFDataset(**kwargs) 237 | elif kind == "hf_saved": 238 | return HFSavedDataset(**kwargs) 239 | else: 240 | raise ValueError("Invalid dataset kind") 241 | 242 | 243 | def test_load_tulu3(): 244 | dset = get_dataset( 245 | "hf", 246 | dataset_configs=[ 247 | { 248 | "lang_code": "en", 249 | "kwargs": {"path": "allenai/tulu-3-sft-mixture", "split": "train"}, 250 | } 251 | ], 252 | batch_size=16, 253 | num_workers=16, 254 | streaming=False, 255 | mix_languages=False, 256 | ) 257 | assert dset.get_texts(1)[0]["text"].startswith( 258 | "<||><||><||><||>" 259 | ) 260 | -------------------------------------------------------------------------------- /tokenkit/gcs_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | from concurrent.futures import 
ThreadPoolExecutor 3 | 4 | from google.cloud import storage 5 | 6 | 7 | def download_from_gcs(bucket_name, source_blob_name, destination_file_name): 8 | storage_client = storage.Client() 9 | bucket = storage_client.bucket(bucket_name) 10 | blob = bucket.blob(source_blob_name) 11 | blob.download_to_filename(destination_file_name) 12 | print( 13 | f"Downloaded {source_blob_name} from bucket {bucket_name} to {destination_file_name}" 14 | ) 15 | 16 | 17 | def upload_to_gcs(bucket_name, source_file_name, destination_blob_name): 18 | storage_client = storage.Client() 19 | bucket = storage_client.bucket(bucket_name) 20 | blob = bucket.blob(destination_blob_name) 21 | blob.upload_from_filename(source_file_name) 22 | print( 23 | f"Uploaded {source_file_name} to bucket {bucket_name} as {destination_blob_name}" 24 | ) 25 | 26 | 27 | def is_gcs_path(path): 28 | return path.startswith("gs://") 29 | 30 | 31 | def parse_gcs_path(gcs_path): 32 | path_parts = gcs_path[len("gs://") :].split("/", 1) 33 | bucket_name = path_parts[0] 34 | blob_name = path_parts[1].rstrip("/") if len(path_parts) > 1 else "" 35 | return bucket_name, blob_name 36 | 37 | 38 | def upload_file(bucket_name, source_file_path, destination_blob_name): 39 | """ 40 | Uploads a single file to the specified GCS bucket. 41 | """ 42 | storage_client = storage.Client() 43 | bucket = storage_client.bucket(bucket_name) 44 | blob = bucket.blob(destination_blob_name) 45 | 46 | blob.upload_from_filename(source_file_path) 47 | print(f"Uploaded {source_file_path} to gs://{bucket_name}/{destination_blob_name}") 48 | 49 | 50 | def upload_directory_to_gcs(bucket_name, source_directory, target_directory=""): 51 | """ 52 | Uploads all files from a local directory to the specified GCS bucket, 53 | optionally under a target directory. 54 | 55 | Args: 56 | bucket_name (str): The name of the GCS bucket. 57 | source_directory (str): Path to the local directory to upload. 58 | target_directory (str): Optional target directory in the GCS bucket. 
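    Returns:
        concurrent.futures.ThreadPoolExecutor: the executor running the per-file
        uploads. Call ``shutdown(wait=True)`` on it to block until all files have
        been uploaded.

    Example (illustrative only; the bucket name and paths are placeholders):
        executor = upload_directory_to_gcs("my-bucket", "outputs/run1", "runs/run1")
        executor.shutdown(wait=True)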
59 | """ 60 | executor = ThreadPoolExecutor() 61 | 62 | for root, _, files in os.walk(source_directory): 63 | for file in files: 64 | local_path = os.path.join(root, file) 65 | # Define the destination path in the bucket 66 | relative_path = os.path.relpath(local_path, source_directory) 67 | destination_path = os.path.join(target_directory, relative_path).replace( 68 | "\\", "/" 69 | ) # Ensure GCS path format 70 | # upload_file(bucket_name, local_path, destination_path) 71 | executor.submit(upload_file, bucket_name, local_path, destination_path) 72 | 73 | return executor 74 | -------------------------------------------------------------------------------- /tokenkit/hf/__init__.py: -------------------------------------------------------------------------------- 1 | from transformers import AutoConfig, AutoModel, AutoModelForCausalLM, FlaxAutoModelForCausalLM 2 | from tokenkit.hf.configuration_tpu_llama import TPULlamaConfig 3 | from tokenkit.hf.modelling_tpu_llama import TPULlamaForCausalLM, TPULlamaModel 4 | from tokenkit.hf.modelling_flax_tpu_llama import FlaxTPULlamaForCausalLM, FlaxTPULlamaModel 5 | 6 | from tokenkit.hf.configuration_tpu_gemma2 import TPUGemma2Config 7 | from tokenkit.hf.modelling_tpu_gemma2 import TPUGemma2ForCausalLM, TPUGemma2Model 8 | from tokenkit.hf.modelling_flax_tpu_gemma2 import FlaxTPUGemma2ForCausalLM, FlaxTPUGemma2Model 9 | 10 | AutoConfig.register("tpu_llama", TPULlamaConfig) 11 | AutoModel.register(TPULlamaConfig, TPULlamaModel) 12 | AutoModelForCausalLM.register(TPULlamaConfig, TPULlamaForCausalLM) 13 | TPULlamaForCausalLM.register_for_auto_class("AutoModelForCausalLM") 14 | TPULlamaModel.register_for_auto_class("AutoModel") 15 | FlaxAutoModelForCausalLM.register(TPULlamaConfig, FlaxTPULlamaForCausalLM) 16 | FlaxTPULlamaForCausalLM.register_for_auto_class("FlaxAutoModelForCausalLM") 17 | FlaxTPULlamaModel.register_for_auto_class("FlaxAutoModel") 18 | 19 | AutoConfig.register("tpu_gemma2", TPUGemma2Config) 20 | AutoModel.register(TPUGemma2Config, TPUGemma2Model) 21 | AutoModelForCausalLM.register(TPUGemma2Config, TPUGemma2ForCausalLM) 22 | TPUGemma2ForCausalLM.register_for_auto_class("AutoModelForCausalLM") 23 | TPUGemma2Model.register_for_auto_class("AutoModel") 24 | FlaxAutoModelForCausalLM.register(TPUGemma2Config, FlaxTPUGemma2ForCausalLM) 25 | FlaxTPUGemma2ForCausalLM.register_for_auto_class("FlaxAutoModelForCausalLM") 26 | FlaxTPUGemma2Model.register_for_auto_class("FlaxAutoModel") 27 | 28 | __all__ = ["TPULlamaConfig", "TPULlamaModel", "TPULlamaForCausalLM", "FlaxTPULlamaForCausalLM", "FlaxTPULlamaModel", "TPUGemma2Config", "TPUGemma2Model", "TPUGemma2ForCausalLM", "FlaxTPUGemma2ForCausalLM", "FlaxTPUGemma2Model"] 29 | 30 | 31 | def get_config(pretrained_model_name_or_path: str, **kwargs): 32 | config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs) 33 | # compatibility with outside jax checkpoints 34 | if config.model_type in {"llama", "tpu_llama"}: 35 | config = TPULlamaConfig.from_pretrained(pretrained_model_name_or_path, **kwargs) 36 | config.model_type = "tpu_llama" 37 | return config 38 | elif config.model_type in {"gemma2", "tpu_gemma2"}: 39 | config = TPUGemma2Config.from_pretrained(pretrained_model_name_or_path, **kwargs) 40 | config.model_type = "tpu_gemma2" 41 | return config 42 | else: 43 | return config 44 | -------------------------------------------------------------------------------- /tokenkit/hf/configuration_tpu_gemma2.py: -------------------------------------------------------------------------------- 1 | 
"""TPU Gemma2 model configuration""" 2 | 3 | from transformers.configuration_utils import PretrainedConfig 4 | 5 | 6 | class TPUGemma2Config(PretrainedConfig): 7 | r""" 8 | This is the configuration class to store the configuration of a [`Gemma2Model`]. It is used to instantiate an Gemma2 9 | model according to the specified arguments, defining the model architecture. Instantiating a configuration with the 10 | defaults will yield a similar configuration to that of the Gemma2-7B. 11 | e.g. [google/gemma2-7b](https://huggingface.co/google/gemma2-7b) 12 | Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the 13 | documentation from [`PretrainedConfig`] for more information. 14 | Args: 15 | vocab_size (`int`, *optional*, defaults to 256000): 16 | Vocabulary size of the Gemma2 model. Defines the number of different tokens that can be represented by the 17 | `inputs_ids` passed when calling [`Gemma2Model`] 18 | hidden_size (`int`, *optional*, defaults to 3072): 19 | Dimension of the hidden representations. 20 | intermediate_size (`int`, *optional*, defaults to 24576): 21 | Dimension of the MLP representations. 22 | num_hidden_layers (`int`, *optional*, defaults to 28): 23 | Number of hidden layers in the Transformer decoder. 24 | num_attention_heads (`int`, *optional*, defaults to 16): 25 | Number of attention heads for each attention layer in the Transformer decoder. 26 | num_key_value_heads (`int`, *optional*, defaults to 16): 27 | This is the number of key_value heads that should be used to implement Grouped Query Attention. If 28 | `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if 29 | `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When 30 | converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed 31 | by meanpooling all the original heads within that group. For more details checkout [this 32 | paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to 33 | `num_attention_heads`. 34 | head_dim (`int`, *optional*, defaults to 256): 35 | The attention head dimension. 36 | hidden_activation (`str` or `function`, *optional*, defaults to `"gelu_pytorch_tanh"`): 37 | The non-linear activation function (function or string) in the decoder. Will default to `"gelu_pytorch_tanh"` 38 | if not specified. `"gelu_pytorch_tanh"` uses an approximation of the `"gelu"` activation function. 39 | max_position_embeddings (`int`, *optional*, defaults to 8192): 40 | The maximum sequence length that this model might ever be used with. 41 | initializer_range (`float`, *optional*, defaults to 0.02): 42 | The standard deviation of the truncated_normal_initializer for initializing all weight matrices. 43 | rms_norm_eps (`float`, *optional*, defaults to 1e-06): 44 | The epsilon used by the rms normalization layers. 45 | use_cache (`bool`, *optional*, defaults to `True`): 46 | Whether or not the model should return the last key/values attentions (not used by all models). Only 47 | relevant if `config.is_decoder=True`. 48 | pad_token_id (`int`, *optional*, defaults to 0): 49 | Padding token id. 50 | eos_token_id (`int`, *optional*, defaults to 1): 51 | End of stream token id. 52 | bos_token_id (`int`, *optional*, defaults to 2): 53 | Beginning of stream token id. 
54 | tie_word_embeddings (`bool`, *optional*, defaults to `True`): 55 | Whether to tie weight embeddings 56 | rope_theta (`float`, *optional*, defaults to 10000.0): 57 | The base period of the RoPE embeddings. 58 | attention_bias (`bool`, defaults to `False`, *optional*, defaults to `False`): 59 | Whether to use a bias in the query, key, value and output projection layers during self-attention. 60 | attention_dropout (`float`, *optional*, defaults to 0.0): 61 | The dropout ratio for the attention probabilities. 62 | query_pre_attn_scalar (`float`, *optional*, defaults to 224): scaling factor used on the attention scores 63 | sliding_window (`int`, *optional*, defaults to 4096): in Gemma2, every other layer uses sliding window attention. This is the 64 | size of the sliding window. 65 | final_logit_softcapping (`float`, *optional*, defaults to 30.0): scaling factor when applying tanh softcapping on the logits. 66 | attn_logit_softcapping (`float`, *optional*, defaults to 50.0): scaling factor when applying tanh softcapping on the attention scores. 67 | cache_implementation (`str`, *optional*, defaults to `"hybrid"`): the cache type to be used with `generate`. 68 | 69 | ```python 70 | >>> from transformers import Gemma2Model, Gemma2Config 71 | >>> # Initializing a Gemma2 gemma2-7b style configuration 72 | >>> configuration = Gemma2Config() 73 | >>> # Initializing a model from the gemma2-7b style configuration 74 | >>> model = Gemma2Model(configuration) 75 | >>> # Accessing the model configuration 76 | >>> configuration = model.config 77 | ```""" 78 | 79 | model_type = "tpu_gemma2" 80 | keys_to_ignore_at_inference = ["past_key_values"] 81 | 82 | def __init__( 83 | self, 84 | vocab_size=256000, 85 | hidden_size=3072, 86 | intermediate_size=24576, 87 | num_hidden_layers=28, 88 | num_attention_heads=16, 89 | num_key_value_heads=16, 90 | head_dim=256, 91 | hidden_activation="gelu_pytorch_tanh", 92 | max_position_embeddings=8192, 93 | initializer_range=0.02, 94 | rms_norm_eps=1e-6, 95 | use_cache=True, 96 | pad_token_id=0, 97 | eos_token_id=1, 98 | bos_token_id=2, 99 | tie_word_embeddings=True, 100 | rope_theta=10000.0, 101 | attention_bias=False, 102 | attention_dropout=0.0, 103 | query_pre_attn_scalar=224, 104 | sliding_window=4096, 105 | final_logit_softcapping=30.0, 106 | attn_logit_softcapping=50.0, 107 | cache_implementation="hybrid", 108 | expand_input_ids=False, # Transformers-native PyTorch generation support 109 | expand_input_ids_maxlen=None, 110 | expand_input_ids_vocab_size=None, 111 | expand_input_ids_dict=None, 112 | **kwargs, 113 | ): 114 | super().__init__( 115 | pad_token_id=pad_token_id, 116 | bos_token_id=bos_token_id, 117 | eos_token_id=eos_token_id, 118 | tie_word_embeddings=tie_word_embeddings, 119 | **kwargs, 120 | ) 121 | self.vocab_size = vocab_size 122 | self.max_position_embeddings = max_position_embeddings 123 | self.hidden_size = hidden_size 124 | self.intermediate_size = intermediate_size 125 | self.num_hidden_layers = num_hidden_layers 126 | self.num_attention_heads = num_attention_heads 127 | self.head_dim = head_dim 128 | self.num_key_value_heads = num_key_value_heads 129 | self.initializer_range = initializer_range 130 | self.rms_norm_eps = rms_norm_eps 131 | self.use_cache = use_cache 132 | self.rope_theta = rope_theta 133 | self.attention_bias = attention_bias 134 | self.attention_dropout = attention_dropout 135 | self.hidden_activation = hidden_activation 136 | self.query_pre_attn_scalar = query_pre_attn_scalar 137 | self.sliding_window = sliding_window 138 | 
self.final_logit_softcapping = final_logit_softcapping 139 | self.attn_logit_softcapping = attn_logit_softcapping 140 | self.cache_implementation = cache_implementation 141 | 142 | self.expand_input_ids = expand_input_ids 143 | self.expand_input_ids_maxlen = expand_input_ids_maxlen 144 | self.expand_input_ids_vocab_size = expand_input_ids_vocab_size 145 | self.expand_input_ids_dict = expand_input_ids_dict 146 | -------------------------------------------------------------------------------- /tokenkit/hf/configuration_tpu_llama.py: -------------------------------------------------------------------------------- 1 | """TPU LLaMA model configuration""" 2 | 3 | from transformers.configuration_utils import PretrainedConfig 4 | from transformers.modeling_rope_utils import rope_config_validation 5 | 6 | 7 | class TPULlamaConfig(PretrainedConfig): 8 | r""" 9 | This is the configuration class to store the configuration of a [`LlamaModel`]. It is used to instantiate an LLaMA 10 | model according to the specified arguments, defining the model architecture. Instantiating a configuration with the 11 | defaults will yield a similar configuration to that of the LLaMA-7B. 12 | 13 | Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the 14 | documentation from [`PretrainedConfig`] for more information. 15 | 16 | 17 | Args: 18 | vocab_size (`int`, *optional*, defaults to 32000): 19 | Vocabulary size of the LLaMA model. Defines the number of different tokens that can be represented by the 20 | `inputs_ids` passed when calling [`TPULlamaModel`] 21 | hidden_size (`int`, *optional*, defaults to 4096): 22 | Dimension of the hidden representations. 23 | intermediate_size (`int`, *optional*, defaults to 11008): 24 | Dimension of the MLP representations. 25 | num_hidden_layers (`int`, *optional*, defaults to 32): 26 | Number of hidden layers in the Transformer decoder. 27 | num_attention_heads (`int`, *optional*, defaults to 32): 28 | Number of attention heads for each attention layer in the Transformer decoder. 29 | num_key_value_heads (`int`, *optional*): 30 | This is the number of key_value heads that should be used to implement Grouped Query Attention. If 31 | `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if 32 | `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When 33 | converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed 34 | by meanpooling all the original heads within that group. For more details checkout [this 35 | paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to 36 | `num_attention_heads`. 37 | hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): 38 | The non-linear activation function (function or string) in the decoder. 39 | max_position_embeddings (`int`, *optional*, defaults to 2048): 40 | The maximum sequence length that this model might ever be used with. Llama 1 supports up to 2048 tokens, 41 | Llama 2 up to 4096, CodeLlama up to 16384. 42 | initializer_range (`float`, *optional*, defaults to 0.02): 43 | The standard deviation of the truncated_normal_initializer for initializing all weight matrices. 44 | rms_norm_eps (`float`, *optional*, defaults to 1e-06): 45 | The epsilon used by the rms normalization layers. 
46 | use_cache (`bool`, *optional*, defaults to `True`): 47 | Whether or not the model should return the last key/values attentions (not used by all models). Only 48 | relevant if `config.is_decoder=True`. 49 | pad_token_id (`int`, *optional*): 50 | Padding token id. 51 | bos_token_id (`int`, *optional*, defaults to 1): 52 | Beginning of stream token id. 53 | eos_token_id (`int`, *optional*, defaults to 2): 54 | End of stream token id. 55 | pretraining_tp (`int`, *optional*, defaults to 1): 56 | Experimental feature. Tensor parallelism rank used during pretraining. Please refer to [this 57 | document](https://huggingface.co/docs/transformers/main/perf_train_gpu_many#tensor-parallelism) to 58 | understand more about it. This value is necessary to ensure exact reproducibility of the pretraining 59 | results. Please refer to [this issue](https://github.com/pytorch/pytorch/issues/76232). 60 | tie_word_embeddings (`bool`, *optional*, defaults to `False`): 61 | Whether to tie weight embeddings 62 | rope_theta (`float`, *optional*, defaults to 10000.0): 63 | The base period of the RoPE embeddings. 64 | rope_scaling (`Dict`, *optional*): 65 | Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type 66 | and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value 67 | accordingly. 68 | Expected contents: 69 | `rope_type` (`str`): 70 | The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope', 71 | 'llama3'], with 'default' being the original RoPE implementation. 72 | `factor` (`float`, *optional*): 73 | Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In 74 | most scaling types, a `factor` of x will enable the model to handle sequences of length x * 75 | original maximum pre-trained length. 76 | `original_max_position_embeddings` (`int`, *optional*): 77 | Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during 78 | pretraining. 79 | `attention_factor` (`float`, *optional*): 80 | Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention 81 | computation. If unspecified, it defaults to value recommended by the implementation, using the 82 | `factor` field to infer the suggested value. 83 | `beta_fast` (`float`, *optional*): 84 | Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear 85 | ramp function. If unspecified, it defaults to 32. 86 | `beta_slow` (`float`, *optional*): 87 | Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear 88 | ramp function. If unspecified, it defaults to 1. 89 | `short_factor` (`List[float]`, *optional*): 90 | Only used with 'longrope'. The scaling factor to be applied to short contexts (< 91 | `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden 92 | size divided by the number of attention heads divided by 2 93 | `long_factor` (`List[float]`, *optional*): 94 | Only used with 'longrope'. The scaling factor to be applied to long contexts (< 95 | `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden 96 | size divided by the number of attention heads divided by 2 97 | `low_freq_factor` (`float`, *optional*): 98 | Only used with 'llama3'. 
Scaling factor applied to low frequency components of the RoPE 99 | `high_freq_factor` (`float`, *optional*): 100 | Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE 101 | attention_bias (`bool`, *optional*, defaults to `False`): 102 | Whether to use a bias in the query, key, value and output projection layers during self-attention. 103 | attention_dropout (`float`, *optional*, defaults to 0.0): 104 | The dropout ratio for the attention probabilities. 105 | mlp_bias (`bool`, *optional*, defaults to `False`): 106 | Whether to use a bias in up_proj, down_proj and gate_proj layers in the MLP layers. 107 | head_dim (`int`, *optional*): 108 | The attention head dimension. If None, it will default to hidden_size // num_heads 109 | 110 | ```python 111 | >>> from transformers import LlamaModel, LlamaConfig 112 | 113 | >>> # Initializing a LLaMA llama-7b style configuration 114 | >>> configuration = LlamaConfig() 115 | 116 | >>> # Initializing a model from the llama-7b style configuration 117 | >>> model = LlamaModel(configuration) 118 | 119 | >>> # Accessing the model configuration 120 | >>> configuration = model.config 121 | ```""" 122 | 123 | model_type = "tpu_llama" 124 | keys_to_ignore_at_inference = ["past_key_values"] 125 | 126 | def __init__( 127 | self, 128 | vocab_size=32000, 129 | hidden_size=4096, 130 | intermediate_size=11008, 131 | num_hidden_layers=32, 132 | num_attention_heads=32, 133 | num_key_value_heads=None, 134 | hidden_act="silu", 135 | max_position_embeddings=2048, 136 | initializer_range=0.02, 137 | rms_norm_eps=1e-6, 138 | use_cache=True, 139 | pad_token_id=None, 140 | bos_token_id=1, 141 | eos_token_id=2, 142 | pretraining_tp=1, 143 | tie_word_embeddings=False, 144 | rope_theta=10000.0, 145 | rope_scaling=None, 146 | attention_bias=False, 147 | attention_dropout=0.0, 148 | mlp_bias=False, 149 | head_dim=None, 150 | expand_input_ids=False, # Transformers-native PyTorch generation support 151 | expand_input_ids_maxlen=None, 152 | expand_input_ids_vocab_size=None, 153 | expand_input_ids_dict=None, 154 | **kwargs, 155 | ): 156 | self.vocab_size = vocab_size 157 | self.max_position_embeddings = max_position_embeddings 158 | self.hidden_size = hidden_size 159 | self.intermediate_size = intermediate_size 160 | self.num_hidden_layers = num_hidden_layers 161 | self.num_attention_heads = num_attention_heads 162 | 163 | # for backward compatibility 164 | if num_key_value_heads is None: 165 | num_key_value_heads = num_attention_heads 166 | 167 | self.num_key_value_heads = num_key_value_heads 168 | self.hidden_act = hidden_act 169 | self.initializer_range = initializer_range 170 | self.rms_norm_eps = rms_norm_eps 171 | self.pretraining_tp = pretraining_tp 172 | self.use_cache = use_cache 173 | self.rope_theta = rope_theta 174 | self.rope_scaling = rope_scaling 175 | self.attention_bias = attention_bias 176 | self.attention_dropout = attention_dropout 177 | self.mlp_bias = mlp_bias 178 | self.head_dim = head_dim if head_dim is not None else self.hidden_size // self.num_attention_heads 179 | # Validate the correctness of rotary position embeddings parameters 180 | # BC: if there is a 'type' field, copy it it to 'rope_type'. 
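        # (Older Hugging Face configs store the RoPE variant under the legacy
        #  "type" key, while newer `transformers` versions read "rope_type";
        #  copying it here keeps both spellings populated before
        #  `rope_config_validation` runs.)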
181 | if self.rope_scaling is not None and "type" in self.rope_scaling: 182 | self.rope_scaling["rope_type"] = self.rope_scaling["type"] 183 | rope_config_validation(self) 184 | 185 | self.expand_input_ids = expand_input_ids 186 | self.expand_input_ids_maxlen = expand_input_ids_maxlen 187 | self.expand_input_ids_vocab_size = expand_input_ids_vocab_size 188 | self.expand_input_ids_dict = expand_input_ids_dict 189 | 190 | super().__init__( 191 | pad_token_id=pad_token_id, 192 | bos_token_id=bos_token_id, 193 | eos_token_id=eos_token_id, 194 | tie_word_embeddings=tie_word_embeddings, 195 | **kwargs, 196 | ) -------------------------------------------------------------------------------- /tokenkit/model_kinds.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from typing import Callable, Dict, List, Optional 3 | 4 | from tokenkit.constants import BYTES_TO_CHARS, CHARS_TO_BYTES 5 | 6 | BYTE_FALLBACK_MAP = {f"<0x{num:02X}>": num for num in range(256)} 7 | INV_BYTE_FALLBACK_MAP = {v: k for k, v in BYTE_FALLBACK_MAP.items()} 8 | 9 | 10 | def sentencepiece_byte_fallback_byte_fn(token: str) -> str: 11 | if token in BYTE_FALLBACK_MAP: 12 | return BYTES_TO_CHARS[BYTE_FALLBACK_MAP[token]] 13 | else: 14 | return "".join( 15 | BYTES_TO_CHARS[b] for b in token.replace("▁", " ").encode("utf-8") 16 | ) 17 | 18 | 19 | def sentencepiece_byte_fallback_precedence_fn(token: str) -> int: 20 | if token in BYTE_FALLBACK_MAP: 21 | return 0 22 | else: 23 | return 1 24 | 25 | 26 | def identity_byte_fn(token: str) -> str: 27 | return token 28 | 29 | 30 | class BaseModelKind(ABC): 31 | SPECIAL_KEYS = [ 32 | "<||>", 33 | "<||>", 34 | "<||>", 35 | "<||>", 36 | "<||>", 37 | "<||>", 38 | "<||>", 39 | "<||>", 40 | ] 41 | 42 | def __init__(self): 43 | self._byte_fallback_fn = identity_byte_fn 44 | self._byte_fallback_precedence_fn = lambda x: 0 45 | 46 | @property 47 | @abstractmethod 48 | def special_tokens(self) -> List[str]: 49 | pass 50 | 51 | @property 52 | @abstractmethod 53 | def replacements(self) -> Dict[str, Optional[List[str]]]: 54 | pass 55 | 56 | @property 57 | def byte_fallback_fn(self) -> Callable[[str], str]: 58 | return self._byte_fallback_fn 59 | 60 | @byte_fallback_fn.setter 61 | def byte_fallback_fn(self, value: Callable[[str], str]): 62 | self._byte_fallback_fn = value 63 | 64 | @property 65 | def byte_fallback_precedence_fn(self) -> Callable[[str], int]: 66 | return self._byte_fallback_precedence_fn 67 | 68 | @byte_fallback_precedence_fn.setter 69 | def byte_fallback_precedence_fn(self, value: Callable[[str], int]): 70 | self._byte_fallback_precedence_fn = value 71 | 72 | 73 | class Qwen2ModelKind(BaseModelKind): 74 | @property 75 | def special_tokens(self) -> List[str]: 76 | return ["<|im_start|>", "<|im_end|>", "<|endoftext|>"] 77 | 78 | @property 79 | def replacements(self) -> Dict[str, Optional[List[str]]]: 80 | return { 81 | "<||>": None, 82 | "<||>": ["<|endoftext|>"], 83 | "<||>": ["<|im_start|>"], 84 | "<||>": ["Ċ"], 85 | "<||>": ["<|endoftext|>"], 86 | "<||>": ["<|im_end|>", "Ċ"], 87 | "<||>": ["system"], 88 | "<||>": ["user"], 89 | "<||>": ["assistant"], 90 | } 91 | 92 | 93 | class Llama3ModelKind(BaseModelKind): 94 | @property 95 | def special_tokens(self) -> List[str]: 96 | return [ 97 | "<|begin_of_text|>", 98 | "<|start_header_id|>", 99 | "<|end_header_id|>", 100 | "<|eot_id|>", 101 | "<|end_of_text|>", 102 | ] 103 | 104 | @property 105 | def replacements(self) -> Dict[str, Optional[List[str]]]: 106 | return { 107 
| "<||>": ["<|begin_of_text|>"], 108 | "<||>": ["<|end_of_text|>"], 109 | "<||>": ["<|start_header_id|>"], 110 | "<||>": ["<|end_header_id|>", "ĊĊ"], 111 | # give eot precedence over eos - not ideal but should work for chat templates 112 | "<||>": ["<|eot_id|>"], 113 | "<||>": ["<|eot_id|>"], 114 | "<||>": ["system"], 115 | "<||>": ["user"], 116 | "<||>": ["assistant"], 117 | } 118 | 119 | 120 | class Gemma2ModelKind(BaseModelKind): 121 | def __init__(self): 122 | super().__init__() 123 | self._byte_fallback_fn = sentencepiece_byte_fallback_byte_fn 124 | self._byte_fallback_precedence_fn = sentencepiece_byte_fallback_precedence_fn 125 | 126 | @property 127 | def special_tokens(self) -> List[str]: 128 | return ["", "", "", "", ""] 129 | 130 | @property 131 | def replacements(self) -> Dict[str, Optional[List[str]]]: 132 | return { 133 | "<||>": [""], 134 | "<||>": [""], 135 | "<||>": [""], 136 | "<||>": ["Ċ"], 137 | "<||>": [""], 138 | "<||>": ["", "Ċ"], 139 | "<||>": ["user"], 140 | "<||>": ["user"], 141 | "<||>": ["model"], 142 | } 143 | 144 | 145 | class Phi3ModelKind(BaseModelKind): 146 | @property 147 | def special_tokens(self) -> List[str]: 148 | return ["<|user|>", "<|assistant|>", "<|end|>", "<|endoftext|>"] 149 | 150 | @property 151 | def replacements(self) -> Dict[str, Optional[List[str]]]: 152 | return { 153 | "<||>": None, 154 | "<||>": ["<|endoftext|>"], 155 | "<||>": None, 156 | "<||>": ["Ċ"], 157 | "<||>": ["<|endoftext|>"], 158 | "<||>": ["<|end|>", "Ċ"], 159 | "<||>": ["<|user|>"], 160 | "<||>": ["<|user|>"], 161 | "<||>": ["<|assistant|>"], 162 | } 163 | 164 | 165 | class GPT2ModelKind(BaseModelKind): 166 | @property 167 | def special_tokens(self) -> List[str]: 168 | return ["<|endoftext|>"] 169 | 170 | @property 171 | def replacements(self) -> Dict[str, Optional[List[str]]]: 172 | return { 173 | "<||>": None, 174 | "<||>": ["<|endoftext|>"], 175 | "<||>": None, 176 | "<||>": None, 177 | "<||>": ["<|endoftext|>"], 178 | "<||>": ["<|endoftext|>"], 179 | "<||>": None, 180 | "<||>": None, 181 | "<||>": None, 182 | } 183 | 184 | 185 | class TinyLlamaModelKind(BaseModelKind): 186 | def __init__(self): 187 | super().__init__() 188 | self._byte_fallback_fn = sentencepiece_byte_fallback_byte_fn 189 | self._byte_fallback_precedence_fn = sentencepiece_byte_fallback_precedence_fn 190 | 191 | @property 192 | def special_tokens(self) -> List[str]: 193 | return ["", "", ""] 194 | 195 | @property 196 | def replacements(self) -> Dict[str, Optional[List[str]]]: 197 | return { 198 | "<||>": [""], 199 | "<||>": [""], 200 | "<||>": None, # chat template exists but not supported for TinyLlama 201 | "<||>": None, 202 | "<||>": [""], 203 | "<||>": [""], 204 | "<||>": None, 205 | "<||>": None, 206 | "<||>": None, 207 | } 208 | 209 | 210 | class MistralModelKind(BaseModelKind): 211 | def __init__(self): 212 | super().__init__() 213 | 214 | @property 215 | def special_tokens(self) -> List[str]: 216 | return ["", "", ""] 217 | 218 | @property 219 | def replacements(self) -> Dict[str, Optional[List[str]]]: 220 | return { 221 | "<||>": [""], 222 | "<||>": [""], 223 | "<||>": None, # chat template exists but not supported for Mistral 224 | "<||>": None, 225 | "<||>": [""], 226 | "<||>": [""], 227 | "<||>": None, 228 | "<||>": None, 229 | "<||>": None, 230 | } 231 | 232 | # Model kind registry 233 | def get_model_kind_cls(model_kind: str) -> BaseModelKind: 234 | return { 235 | "Qwen2": Qwen2ModelKind(), 236 | "Llama3": Llama3ModelKind(), 237 | "Gemma2": Gemma2ModelKind(), 238 | "Phi3": Phi3ModelKind(), 
239 | "GPT2": GPT2ModelKind(), 240 | "TinyLlama": TinyLlamaModelKind(), 241 | "Mistral": MistralModelKind(), 242 | }[model_kind] 243 | -------------------------------------------------------------------------------- /tokenkit/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bminixhofer/tokenkit/7abac3578f6b3fa38f985e9f03bff7a47d5ab3b1/tokenkit/models/__init__.py -------------------------------------------------------------------------------- /tokenkit/models/lora.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import jax.numpy as jnp 3 | import numpy as np 4 | import regex as re 5 | 6 | from tokenkit import utils 7 | 8 | LORA_PATTERNS = { 9 | "llama": [ 10 | ".*self_attn.(q_proj|k_proj|v_proj).kernel", 11 | ".*self_attn.o_proj.kernel", 12 | ".*mlp.down_proj.kernel", 13 | ".*mlp.up_proj.kernel", 14 | ".*mlp.gate_proj.kernel", 15 | ], 16 | "gemma2": [ 17 | ".*self_attn.(q_proj|k_proj|v_proj).kernel", 18 | ".*self_attn.o_proj.kernel", 19 | ".*mlp.down_proj.kernel", 20 | ".*mlp.up_proj.kernel", 21 | ".*mlp.gate_proj.kernel", 22 | ], 23 | } 24 | LORA_PATTERNS["tpu_llama"] = LORA_PATTERNS["llama"] 25 | LORA_PATTERNS["tpu_gemma2"] = LORA_PATTERNS["gemma2"] 26 | 27 | 28 | def init_lora_params(args, params, model_type, seed, dtype=jnp.float32): 29 | def iter_keys(key): 30 | while True: 31 | key, out_key = jax.random.split(key) 32 | yield out_key 33 | 34 | key_it = iter_keys(jax.random.PRNGKey(seed)) 35 | 36 | lora_patterns = LORA_PATTERNS[model_type] 37 | lora_rank = args.model_lora_rank 38 | stddev = 1.0 / lora_rank 39 | 40 | def init_lora(path, param): 41 | path_tuple = tuple(str(utils.keystr(x)) for x in path) 42 | path = ".".join(path_tuple) 43 | 44 | lora_params = np.array([]) # indicates no lora params 45 | 46 | for key in lora_patterns: 47 | if re.match(key, path): 48 | assert len(param.shape) == 2 49 | b_dim, a_dim = param.shape 50 | 51 | b = np.zeros((b_dim, lora_rank), dtype=dtype) 52 | a = jax.device_get( 53 | jax.random.normal(next(key_it), (lora_rank, a_dim), dtype=dtype) 54 | * stddev 55 | ) 56 | lora_params = {"a": a, "b": b} 57 | 58 | return lora_params 59 | 60 | return jax.tree_util.tree_map_with_path(init_lora, params) 61 | 62 | 63 | def materialize_lora(param_tree, lora_param_tree, alpha): 64 | def materialize(param, lora_params): 65 | if not isinstance(lora_params, dict): 66 | assert lora_params.shape[0] == 0 67 | return param 68 | 69 | a, b = lora_params["a"], lora_params["b"] 70 | scale = alpha / b.shape[-1] 71 | 72 | return (param + scale * b @ a).astype(param.dtype) 73 | 74 | return jax.tree.map(materialize, param_tree, lora_param_tree) 75 | 76 | 77 | # NOTE: not clear if this is save w.r.t. rounding errors. probably not? dangerous. 78 | # NOTE: update: no instability so far, seems safe in fp32. but still dangerous. 
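# --- Illustrative usage sketch (comments only; not part of the original file) ---
# Assuming `params` is the model parameter pytree, `lora_params` comes from
# `init_lora_params`, and `lora_alpha` holds the LoRA alpha hyperparameter (the
# exact config field name is an assumption), the merge/unmerge round trip looks
# roughly like:
#
#     merged_params = materialize_lora(params, lora_params, lora_alpha)
#     outputs = model(input_ids, params=merged_params)   # run with merged weights
#     params = dematerialize_lora(merged_params, lora_params, lora_alpha)
#
# `dematerialize_lora` below subtracts the same scaled low-rank update again,
# which is why the rounding-error notes above apply to it.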
79 | def dematerialize_lora(param_tree, lora_param_tree, alpha): 80 | def dematerialize(param, lora_params): 81 | if not isinstance(lora_params, dict): 82 | return param 83 | 84 | a, b = lora_params["a"], lora_params["b"] 85 | scale = alpha / b.shape[-1] 86 | 87 | return (param - scale * b @ a).astype(param.dtype) 88 | 89 | return jax.tree.map(dematerialize, param_tree, lora_param_tree) 90 | -------------------------------------------------------------------------------- /tokenkit/models/param.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import json 3 | 4 | import jax.numpy as jnp 5 | import numpy as np 6 | from flax import serialization, traverse_util 7 | from transformers import AutoConfig 8 | from transformers.utils.hub import cached_file 9 | 10 | 11 | def get_input_embedding_path(model_type): 12 | return { 13 | "gpt2": "transformer.wte.embedding", 14 | "roberta": "roberta.embeddings.word_embeddings.embedding", 15 | "xlm-roberta": "roberta.embeddings.word_embeddings.embedding", 16 | "xglm": "model.embed_tokens.embedding", 17 | "mistral": "model.embed_tokens.embedding", 18 | "llama": "model.embed_tokens.embedding", 19 | "tpu_llama": "model.embed_tokens.embedding", 20 | "gemma": "model.embed_tokens.embedding", 21 | "gemma2": "model.embed_tokens.embedding", 22 | "tpu_gemma2": "model.embed_tokens.embedding", 23 | }[model_type] 24 | 25 | 26 | def get_output_embedding_path(model_type): 27 | return { 28 | "gpt2": "lm_head.kernel", 29 | "roberta": None, 30 | "xlm-roberta": None, 31 | "xglm": None, 32 | "mistral": "lm_head.kernel", 33 | "llama": "lm_head.kernel", 34 | "tpu_llama": "lm_head.kernel", 35 | "gemma": "lm_head.kernel", 36 | "gemma2": "lm_head.kernel", 37 | "tpu_gemma2": "lm_head.kernel", 38 | }[model_type] 39 | 40 | 41 | def get_layer_path(model_type): 42 | return { 43 | "gemma2": "model.layers", 44 | "gpt2": "transformer.h", 45 | "llama": "model.layers", 46 | "tpu_llama": "model.layers", 47 | "tpu_gemma2": "model.layers", 48 | }[model_type] 49 | 50 | 51 | def load_params(**kwargs): 52 | kwargs = copy.copy(kwargs) 53 | config = AutoConfig.from_pretrained(**kwargs) 54 | path = kwargs.pop("pretrained_model_name_or_path") 55 | embedding_path = kwargs.pop("embedding_path", None) 56 | 57 | try: 58 | index = cached_file(path, "flax_model.msgpack.index.json", **kwargs) 59 | except OSError: 60 | index = None 61 | 62 | if index is not None: 63 | index = json.load(open(index)) 64 | files = [ 65 | cached_file(path, x, **kwargs) for x in set(index["weight_map"].values()) 66 | ] 67 | else: 68 | files = [cached_file(path, "flax_model.msgpack", **kwargs)] 69 | 70 | flat_params = {} 71 | for x in files: 72 | flat_params.update( 73 | traverse_util.flatten_dict( 74 | serialization.msgpack_restore(open(x, "rb").read()) 75 | ) 76 | ) 77 | 78 | params = traverse_util.unflatten_dict(flat_params) 79 | 80 | if embedding_path is not None: 81 | embeddings = np.load(embedding_path) 82 | params = put( 83 | params, get_input_embedding_path(config.model_type), embeddings[:, 0] 84 | ) 85 | if embeddings.shape[1] > 1: 86 | params = put( 87 | params, get_output_embedding_path(config.model_type), embeddings[:, 1].T 88 | ) 89 | 90 | return params 91 | 92 | 93 | def put(pytree, path, value): 94 | path = tuple(path.split(".")) 95 | 96 | flat_pytree = traverse_util.flatten_dict(pytree) 97 | # this is potentially safer than simply overwriting, preserves dtype etc. 
98 | if path in flat_pytree and isinstance(flat_pytree[path], jnp.ndarray): 99 | flat_pytree[path] = flat_pytree[path].at[:].set(value) 100 | else: 101 | flat_pytree[path] = value 102 | 103 | return traverse_util.unflatten_dict(flat_pytree) 104 | 105 | 106 | def pop(pytree, path): 107 | path = tuple(path.split(".")) 108 | flat_pytree = traverse_util.flatten_dict(pytree) 109 | if path in flat_pytree: 110 | value = flat_pytree.pop(path) 111 | else: 112 | value = None 113 | 114 | return traverse_util.unflatten_dict(flat_pytree), value 115 | 116 | 117 | def get(pytree, path): 118 | path = tuple(path.split(".")) 119 | out = traverse_util.flatten_dict(pytree)[path] 120 | 121 | if isinstance(out, dict): 122 | return traverse_util.unflatten_dict(out) 123 | else: 124 | return out 125 | 126 | 127 | def keys(pytree): 128 | return [".".join(x) for x in traverse_util.flatten_dict(pytree).keys()] 129 | 130 | 131 | def assign_embeddings(model_params, embeddings, config): 132 | model_params = put( 133 | model_params, 134 | get_input_embedding_path(config.model_type), 135 | embeddings[:, 0], 136 | ) 137 | if not config.tie_word_embeddings: 138 | model_params = put( 139 | model_params, 140 | get_output_embedding_path(config.model_type), 141 | embeddings[:, 1].T, 142 | ) 143 | 144 | return model_params 145 | 146 | 147 | def unassign_embeddings(model_params, config): 148 | model_params, x = pop(model_params, get_input_embedding_path(config.model_type)) 149 | if isinstance(x, jnp.ndarray): 150 | x.delete() 151 | if get_output_embedding_path(config.model_type): 152 | model_params, x = pop( 153 | model_params, get_output_embedding_path(config.model_type) 154 | ) 155 | if isinstance(x, jnp.ndarray): 156 | x.delete() 157 | 158 | return model_params 159 | 160 | 161 | def stack_embeddings(model_params, config, pop_embeddings=False): 162 | if config.tie_word_embeddings: 163 | input_embeddings = get( 164 | model_params, get_input_embedding_path(config.model_type) 165 | ) 166 | 167 | embeddings = input_embeddings[:, None, :] 168 | else: 169 | input_embeddings = get( 170 | model_params, get_input_embedding_path(config.model_type) 171 | ) 172 | output_embeddings = get( 173 | model_params, get_output_embedding_path(config.model_type) 174 | ) 175 | 176 | embeddings = np.stack([input_embeddings, output_embeddings.T], axis=1) 177 | 178 | if pop_embeddings: 179 | model_params = unassign_embeddings(model_params, config) 180 | 181 | return embeddings, model_params 182 | 183 | 184 | def get_num_layers(config): 185 | if hasattr(config, "num_hidden_layers"): 186 | return config.num_hidden_layers 187 | elif hasattr(config, "n_layer"): # gpt2 188 | return config.n_layer 189 | else: 190 | raise ValueError("Could not determine number of layers from config") 191 | 192 | 193 | def set_num_layers(config, num_layers): 194 | if hasattr(config, "num_hidden_layers"): 195 | config.num_hidden_layers = num_layers 196 | elif hasattr(config, "n_layer"): # gpt2 197 | config.n_layer = num_layers 198 | else: 199 | raise ValueError("Could not determine number of layers from config") 200 | 201 | 202 | def get_layer_n_mask(model_params, config, layer_idx): 203 | if layer_idx < 0: 204 | layer_idx = get_num_layers(config) + layer_idx 205 | 206 | flat_params = traverse_util.flatten_dict(model_params) 207 | mask = {} 208 | subpath = f"{get_layer_path(config.model_type)}.{layer_idx}" 209 | 210 | for key in flat_params.keys(): 211 | if subpath in ".".join(key): 212 | mask[key] = True 213 | else: 214 | mask[key] = False 215 | 216 | return 
traverse_util.unflatten_dict(mask) 217 | 218 | 219 | def strip_layers(model_params, config, n_keep=1): 220 | for layer_idx in range(n_keep, get_num_layers(config)): 221 | model_params, _ = pop( 222 | model_params, f"{get_layer_path(config.model_type)}.{layer_idx}" 223 | ) 224 | 225 | set_num_layers(config, n_keep) 226 | 227 | return model_params 228 | -------------------------------------------------------------------------------- /tokenkit/models/sharding.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | import jax 4 | import jax.experimental 5 | import jax.experimental.mesh_utils 6 | import regex as re 7 | from jax.experimental.multihost_utils import process_allgather 8 | from jax.sharding import PartitionSpec as P 9 | import numpy as np 10 | 11 | from tokenkit import utils 12 | 13 | logger = logging.getLogger(__name__) 14 | 15 | 16 | SHARD_PATTERNS = { 17 | "hypernet": { 18 | "(opt_state|params).*ffn_layer1.linear": P(None, "model"), 19 | "(opt_state|params).*ffn_layer2.linear": P("model", None), 20 | "(opt_state|params).*self_attention.(query|key|value).w": P(None, "model"), 21 | "(opt_state|params).*self_attention.post.w": P("model", None), 22 | "(opt_state|params).*embeddings": P("model", None), 23 | }, 24 | "llama": { 25 | "(opt_state|params).*embed_tokens.*embedding": P("model", "data"), 26 | "(opt_state|params).*self_attn.(q_proj|k_proj|v_proj).kernel.a": P( 27 | "model", "data" 28 | ), 29 | "(opt_state|params).*self_attn.(q_proj|k_proj|v_proj).kernel.b": P( 30 | "data", "model" 31 | ), 32 | "(opt_state|params).*self_attn.(q_proj|k_proj|v_proj).kernel.w": P( 33 | "data", "model" 34 | ), 35 | "(opt_state|params).*norm.weight": P("model"), 36 | "(opt_state|params).*self_attn.(q_proj|k_proj|v_proj).kernel": P( 37 | "data", "model" 38 | ), 39 | "(opt_state|params).*self_attn.o_proj.kernel": P("model", "data"), 40 | "(opt_state|params).*lm_head.kernel": P("data", "model"), 41 | "(opt_state|params).*mlp.down_proj.kernel": P("model", "data"), 42 | "(opt_state|params).*mlp.up_proj.kernel": P("data", "model"), 43 | "(opt_state|params).*mlp.gate_proj.kernel": P("data", "model"), 44 | "(opt_state|params).*norm.kernel": P("model"), 45 | ".*(cached_value|cached_key)": P("data", None, "model", None), 46 | }, 47 | "mistral": { 48 | "(opt_state|params).*embed_tokens.*embedding": P("model", None), 49 | "(opt_state|params).*self_attn.(q_proj|k_proj|v_proj).kernel": P(None, "model"), 50 | "(opt_state|params).*self_attn.o_proj.kernel": P("model", None), 51 | "(opt_state|params).*lm_head.kernel": P(None, "model"), 52 | "(opt_state|params).*mlp.down_proj.kernel": P("model", None), 53 | "(opt_state|params).*mlp.up_proj.kernel": P(None, "model"), 54 | "(opt_state|params).*mlp.gate_proj.kernel": P(None, "model"), 55 | }, 56 | "gemma": { 57 | "(opt_state|params).*embed_tokens.*embedding": P("model", "data"), 58 | "(opt_state|params).*self_attn.(q_proj|k_proj|v_proj).kernel.a": P( 59 | "model", "data" 60 | ), 61 | "(opt_state|params).*self_attn.(q_proj|k_proj|v_proj).kernel.b": P( 62 | "data", "model" 63 | ), 64 | "(opt_state|params).*self_attn.(q_proj|k_proj|v_proj).kernel.w": P( 65 | "data", "model" 66 | ), 67 | "(opt_state|params).*self_attn.(q_proj|k_proj|v_proj).kernel": P( 68 | "data", "model" 69 | ), 70 | "(opt_state|params).*self_attn.o_proj.kernel": P("model", "data"), 71 | "(opt_state|params).*lm_head.kernel": P("data", "model"), 72 | "(opt_state|params).*mlp.down_proj.kernel": P("model", "data"), 73 | 
"(opt_state|params).*mlp.up_proj.kernel": P("data", "model"), 74 | "(opt_state|params).*mlp.gate_proj.kernel": P("data", "model"), 75 | "(opt_state|params).*norm.kernel": P("model"), 76 | }, 77 | "gemma2": { 78 | "(opt_state|params).*embed_tokens.*embedding": P("model", "data"), 79 | "(opt_state|params).*self_attn.(q_proj|k_proj|v_proj).kernel.a": P( 80 | "model", "data" 81 | ), 82 | "(opt_state|params).*self_attn.(q_proj|k_proj|v_proj).kernel.b": P( 83 | "data", "model" 84 | ), 85 | "(opt_state|params).*self_attn.(q_proj|k_proj|v_proj).kernel.w": P( 86 | "data", "model" 87 | ), 88 | "(opt_state|params).*self_attn.(q_proj|k_proj|v_proj).kernel": P( 89 | "data", "model" 90 | ), 91 | "(opt_state|params).*self_attn.o_proj.kernel": P("model", "data"), 92 | "(opt_state|params).*lm_head.kernel": P("data", "model"), 93 | "(opt_state|params).*mlp.down_proj.kernel": P("model", "data"), 94 | "(opt_state|params).*mlp.up_proj.kernel": P("data", "model"), 95 | "(opt_state|params).*mlp.gate_proj.kernel": P("data", "model"), 96 | "(opt_state|params).*norm.kernel": P("model"), 97 | }, 98 | "gpt2": { 99 | "(opt_state|params).*c_attn.kernel": P(None, "model"), 100 | "(opt_state|params).*c_proj.kernel": P("model", None), 101 | "(opt_state|params).*c_fc.kernel": P(None, "model"), 102 | }, 103 | "xlm-roberta": { 104 | "(opt_state|params).*self.(query|key|value).kernel": P(None, "model"), 105 | "(opt_state|params).*output.dense.kernel": P("model", None), 106 | "(opt_state|params).*intermediate.dense.kernel": P(None, "model"), 107 | }, 108 | } 109 | SHARD_PATTERNS["tpu_llama"] = SHARD_PATTERNS["llama"] 110 | SHARD_PATTERNS["tpu_gemma2"] = SHARD_PATTERNS["gemma2"] 111 | 112 | def get_shard_patterns(kind): 113 | return SHARD_PATTERNS.get(kind, {}) 114 | 115 | 116 | def get_sharding_fn(shard_patterns, mesh): 117 | name_to_size = {name: size for name, size in mesh.shape_tuple} 118 | 119 | def get_pspec(path, v): 120 | # this is a dummy parameter for e.g. 
PEFT, so no need to shard 121 | if np.prod(v.shape) == 0: 122 | return P() 123 | 124 | path_tuple = tuple(str(utils.keystr(x)) for x in path) 125 | path = ".".join(path_tuple) 126 | 127 | for key, value in shard_patterns.items(): 128 | if re.match(key, path): 129 | pspec = value 130 | for dim, name in enumerate(pspec): 131 | if name is None: 132 | continue 133 | 134 | if name not in name_to_size: 135 | raise ValueError( 136 | f"Unknown sharding name {name} in {pspec} for {path}" 137 | ) 138 | 139 | if v.shape[dim] % name_to_size[name] != 0: 140 | logger.warning( 141 | "Want to shard %s with %s, but shape %s is not divisible by %s.", 142 | path, 143 | pspec, 144 | v.shape, 145 | name_to_size[name], 146 | ) 147 | return P() 148 | 149 | logger.debug("Sharding %s with %s.", path, pspec) 150 | return P(*pspec) 151 | 152 | return P() 153 | 154 | def get_tree_shardings(tree): 155 | pspecs = jax.tree_util.tree_map_with_path(get_pspec, tree) 156 | return jax.tree.map( 157 | lambda pspec: jax.sharding.NamedSharding(mesh, pspec), pspecs 158 | ) 159 | 160 | return get_tree_shardings 161 | 162 | 163 | def to_global_array(pytree, pytree_sharding=None): 164 | if pytree_sharding is None: 165 | pytree_sharding = jax.tree.map(lambda _: None, pytree) 166 | 167 | def to_global_array_fn(array, sharding): 168 | if array is None: 169 | return None 170 | 171 | if sharding is None: 172 | return array 173 | 174 | def cb(index): 175 | return array[index] 176 | 177 | return jax.make_array_from_callback(array.shape, sharding, cb) 178 | 179 | return jax.tree.map(to_global_array_fn, pytree, pytree_sharding) 180 | 181 | 182 | def sync_across_devices(pytree): 183 | if jax.process_count() == 1: 184 | return pytree 185 | 186 | return jax.tree.map(lambda x: x[0], process_allgather(pytree)) 187 | 188 | 189 | def to_devices(pytree, pytree_sharding=None, dtype=None): 190 | # TODO: handle non-numpy inputs? 
191 | pytree = to_global_array(pytree, pytree_sharding) 192 | 193 | return jax.jit( 194 | lambda x: x if dtype is None else jax.tree.map(lambda x: x.astype(dtype), x), 195 | in_shardings=(pytree_sharding,) if pytree_sharding is not None else None, 196 | out_shardings=pytree_sharding, 197 | )(pytree) 198 | 199 | 200 | def get_mesh(n_data_parallel=1, n_model_parallel=-1, devices=None): 201 | if devices is None: 202 | devices = jax.devices() 203 | 204 | device_count = len(devices) 205 | 206 | if n_data_parallel == -1: 207 | n_data_parallel = device_count 208 | 209 | if n_model_parallel == -1: 210 | n_model_parallel = device_count 211 | 212 | devices = jax.experimental.mesh_utils.create_device_mesh( 213 | mesh_shape=(n_data_parallel, n_model_parallel), 214 | devices=devices, 215 | ) 216 | return jax.sharding.Mesh(devices, ["data", "model"]) 217 | -------------------------------------------------------------------------------- /tokenkit/parse_args.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, fields, is_dataclass 2 | from transformers import HfArgumentParser 3 | from argparse import ArgumentParser 4 | import yaml 5 | 6 | 7 | @dataclass 8 | class HypernetArgs: 9 | architecture: str 10 | num_layers: int 11 | residual: bool 12 | residual_alpha: float 13 | use_attention: bool 14 | use_attention_mask: bool = False 15 | num_heads: int = 16 16 | shared: bool = True 17 | multiply_hidden_dim_by_num_embeddings: bool = True 18 | 19 | @dataclass 20 | class EvalArgs: 21 | tasks: list[str] 22 | lengths: list[int] 23 | tokens_per_batch: int 24 | add_bos: bool 25 | chat_template_mode: str 26 | confirm_run_unsafe_code: bool 27 | 28 | @dataclass 29 | class ModelArgs: 30 | pretrained_model_name_or_path: str 31 | tokenizer_name: str 32 | revision: str | None = None 33 | 34 | def restore_dataclasses(args, cls): 35 | for field in fields(cls): 36 | if is_dataclass(field.type): 37 | setattr( 38 | args, 39 | field.name, 40 | restore_dataclasses(getattr(args, field.name), field.type), 41 | ) 42 | elif isinstance(field.type, list) and is_dataclass(field.type.__args__[0]): 43 | setattr( 44 | args, 45 | field.name, 46 | [ 47 | restore_dataclasses(item, field.type.__args__[0]) 48 | for item in getattr(args, field.name) 49 | ], 50 | ) 51 | elif isinstance(field.type, dict): 52 | setattr( 53 | args, 54 | field.name, 55 | { 56 | k: restore_dataclasses(v, field.type.__args__[1]) 57 | for k, v in getattr(args, field.name).items() 58 | }, 59 | ) 60 | 61 | if not isinstance(args, cls): 62 | return cls(**args) if args is not None else None 63 | 64 | return args 65 | 66 | 67 | def parse_args(cls): 68 | parser = ArgumentParser() 69 | parser.add_argument("--config", type=str, required=True) 70 | parser.add_argument("--overrides", type=str, nargs="*") 71 | meta_args = parser.parse_args() 72 | 73 | (args,) = HfArgumentParser([cls]).parse_yaml_file(meta_args.config) 74 | 75 | for overrides in (meta_args.overrides or []): 76 | for override in overrides.split(): 77 | first_equals = override.find("=") 78 | key = override[:first_equals].split(".") 79 | try: 80 | value = yaml.safe_load(override[first_equals + 1 :]) 81 | except yaml.YAMLError: 82 | raise ValueError(f"Invalid YAML: {override[first_equals + 1 :]}") 83 | 84 | current = args 85 | for k in key[:-1]: 86 | if isinstance(current, list): 87 | current = current[int(k)] 88 | elif isinstance(current, dict): 89 | current = current[k] 90 | else: 91 | current = getattr(current, k) 92 | 93 | if isinstance(current, 
list): 94 | if int(key[-1]) >= len(current): 95 | raise ValueError(f"Invalid key: {key[-1]}") 96 | current[int(key[-1])] = value 97 | elif isinstance(current, dict): 98 | if key[-1] not in current: 99 | raise ValueError(f"Invalid key: {key[-1]}") 100 | current[key[-1]] = value 101 | else: 102 | if not hasattr(current, key[-1]): 103 | raise ValueError(f"Invalid key: {key[-1]}") 104 | setattr(current, key[-1], value) 105 | 106 | return restore_dataclasses(args, cls) 107 | -------------------------------------------------------------------------------- /tokenkit/training/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bminixhofer/tokenkit/7abac3578f6b3fa38f985e9f03bff7a47d5ab3b1/tokenkit/training/__init__.py -------------------------------------------------------------------------------- /tokenkit/training/checkpoint.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import jax 4 | from flax import serialization, traverse_util 5 | from jax.experimental import multihost_utils 6 | from jax.sharding import NamedSharding 7 | from jax.sharding import PartitionSpec as P 8 | 9 | 10 | def save( 11 | path, 12 | params, 13 | param_shardings, 14 | mesh, 15 | train_mask, 16 | keys_to_keep={ 17 | "hypernet", 18 | }, 19 | batch_size=16, 20 | ): 21 | flat_keys_to_save = [ 22 | k 23 | for k, trainable in traverse_util.flatten_dict(train_mask).items() 24 | if trainable or k[0] in keys_to_keep 25 | ] 26 | flat_params = traverse_util.flatten_dict(params) 27 | flat_shardings = traverse_util.flatten_dict(param_shardings) 28 | 29 | flat_params_to_save = {k: flat_params[k] for k in flat_keys_to_save} 30 | shardings_to_save = {k: flat_shardings[k] for k in flat_keys_to_save} 31 | 32 | none_shardings_to_save = jax.tree.map( 33 | lambda _: NamedSharding(mesh, P()), shardings_to_save 34 | ) 35 | 36 | keys = list(flat_params_to_save.keys()) 37 | n_batches = math.ceil(len(keys) / batch_size) 38 | 39 | all_flat_out_params = {} 40 | 41 | for i in range(n_batches): 42 | batch_keys = keys[i * batch_size : (i + 1) * batch_size] 43 | 44 | flat_device_params = jax.jit( 45 | lambda x: x, 46 | in_shardings=([shardings_to_save[k] for k in batch_keys],), 47 | out_shardings=[none_shardings_to_save[k] for k in batch_keys], 48 | )([flat_params_to_save[k] for k in batch_keys]) 49 | 50 | for key, value in zip(batch_keys, flat_device_params): 51 | all_flat_out_params[key] = jax.device_get(value) 52 | value.delete() 53 | 54 | if jax.process_index() == 0: 55 | open(path, "wb").write( 56 | serialization.msgpack_serialize( 57 | traverse_util.unflatten_dict(all_flat_out_params), in_place=True 58 | ) 59 | ) 60 | 61 | multihost_utils.sync_global_devices("saved checkpoint") 62 | -------------------------------------------------------------------------------- /tokenkit/training/collators/__init__.py: -------------------------------------------------------------------------------- 1 | from tokenkit.training.collators.tokenizer_aligner import TokenizerAlignerCollator 2 | from tokenkit.training.collators.tokenizer_sampler import TokenizerSamplerCollator 3 | 4 | __all__ = [ 5 | "TokenizerAlignerCollator", 6 | "TokenizerSamplerCollator", 7 | ] 8 | -------------------------------------------------------------------------------- /tokenkit/training/collators/tokenizer_aligner.py: -------------------------------------------------------------------------------- 1 | import json 2 | from pathlib import Path 3 | 4 | import 
numpy as np 5 | from jax.sharding import PartitionSpec as P 6 | from scipy import sparse 7 | 8 | from tokenkit import align, utils 9 | 10 | 11 | class TokenizerAlignerCollator: 12 | def __init__( 13 | self, 14 | tokenizer_original, 15 | tokenizer_new, 16 | max_teacher_length, 17 | max_student_length, 18 | use_chat_template=False, 19 | chat_template_mode="direct_encode", 20 | expand_input_ids_dict=None, 21 | loss_mask_mode=None, 22 | tokenizer_pair_data_path=None, 23 | tokenizer_pair_bias_threshold=0.0, 24 | require_bias_matrices=False, 25 | ): 26 | self.tokenizer_original = tokenizer_original 27 | self.tokenizer_original_vocab = tokenizer_original.get_vocab() 28 | self.tokenizer_new = tokenizer_new 29 | self.max_teacher_length = max_teacher_length 30 | self.max_student_length = max_student_length 31 | self.use_chat_template = use_chat_template 32 | self.chat_template_mode = chat_template_mode 33 | self.expand_input_ids_dict = expand_input_ids_dict 34 | 35 | if loss_mask_mode is None: 36 | loss_mask_string = None 37 | elif loss_mask_mode == "dolly": 38 | loss_mask_string = "### Response:\n" 39 | elif loss_mask_mode == "openmath2": 40 | loss_mask_string = "<|start_header_id|>assistant<|end_header_id|>\n\n" 41 | else: 42 | raise ValueError(f"Unknown loss mask mode: {loss_mask_mode}") 43 | 44 | self.loss_mask_tokens_original = ( 45 | self.tokenizer_original.encode(loss_mask_string, add_special_tokens=False) 46 | if loss_mask_string is not None 47 | else None 48 | ) 49 | self.loss_mask_tokens_new = ( 50 | self.tokenizer_new.encode(loss_mask_string, add_special_tokens=False) 51 | if loss_mask_string is not None 52 | else None 53 | ) 54 | 55 | bias1_matrix_path = Path(tokenizer_pair_data_path) / "bias1_matrix.npz" 56 | bias2_matrix_path = Path(tokenizer_pair_data_path) / "bias2_matrix.npz" 57 | teacher_token_counts_path = ( 58 | Path(tokenizer_pair_data_path) / "teacher_counts.json" 59 | ) 60 | student_token_counts_path = ( 61 | Path(tokenizer_pair_data_path) / "student_counts.json" 62 | ) 63 | 64 | if bias1_matrix_path.exists(): 65 | self.tokenizer_pair_bias1_matrix = sparse.load_npz( 66 | bias1_matrix_path 67 | ).todok() 68 | else: 69 | self.tokenizer_pair_bias1_matrix = None 70 | if bias2_matrix_path.exists(): 71 | self.tokenizer_pair_bias2_matrix = sparse.load_npz( 72 | bias2_matrix_path 73 | ).todok() 74 | else: 75 | self.tokenizer_pair_bias2_matrix = None 76 | if teacher_token_counts_path.exists(): 77 | self.teacher_token_probs = utils.compute_unigram_probabilities( 78 | tokenizer_original, json.load(open(teacher_token_counts_path)) 79 | ) 80 | else: 81 | self.teacher_token_probs = None 82 | if student_token_counts_path.exists(): 83 | self.student_token_probs = utils.compute_unigram_probabilities( 84 | tokenizer_new, json.load(open(student_token_counts_path)) 85 | ) 86 | else: 87 | self.student_token_probs = None 88 | 89 | if require_bias_matrices and ( 90 | self.tokenizer_pair_bias1_matrix is None 91 | or self.tokenizer_pair_bias2_matrix is None 92 | ): 93 | raise ValueError( 94 | "Bias matrices are required but not found in the given path." 
95 | ) 96 | 97 | self.tokenizer_pair_bias_threshold = tokenizer_pair_bias_threshold 98 | 99 | self.prefix_map_original = self._compute_prefix_map(tokenizer_original) 100 | self.prefix_map_new = self._compute_prefix_map(tokenizer_new) 101 | 102 | def _compute_loss_mask(self, input_ids, attention_mask, loss_mask_tokens): 103 | loss_mask = attention_mask.astype(bool) 104 | if loss_mask_tokens is not None: 105 | for i in range(len(input_ids)): 106 | for j in range(len(input_ids[i])): 107 | if input_ids[i][j] != loss_mask_tokens[0]: 108 | continue 109 | 110 | if ( 111 | input_ids[i][j : j + len(loss_mask_tokens)].tolist() 112 | == loss_mask_tokens 113 | ): 114 | loss_mask[i, : j + len(loss_mask_tokens)] = False 115 | 116 | return loss_mask 117 | 118 | def _compute_prefix_map(self, tokenizer): 119 | prefix_map = {} 120 | 121 | for token in tokenizer.get_vocab().keys(): 122 | for i in range(1, len(token) + 1): 123 | if token[:i] in prefix_map: 124 | prefix_map[token[:i]].append(token) 125 | else: 126 | prefix_map[token[:i]] = [token] 127 | 128 | return prefix_map 129 | 130 | def _encode_with_chat_template(self, texts, tokenizer, max_length): 131 | input_ids = np.full( 132 | (len(texts), max_length), fill_value=tokenizer.pad_token_id, dtype=np.int32 133 | ) 134 | attention_mask = np.zeros((len(texts), max_length), dtype=np.int32) 135 | 136 | for i in range(len(texts)): 137 | current_input_ids, _ = utils.encode_prompt( 138 | utils.preprocess_prompt(texts[i], self.chat_template_mode), 139 | tokenizer, 140 | max_length=max_length, 141 | ) 142 | input_ids[i, : len(current_input_ids)] = current_input_ids 143 | attention_mask[i, : len(current_input_ids)] = 1 144 | 145 | return { 146 | "input_ids": input_ids, 147 | "attention_mask": attention_mask, 148 | } 149 | 150 | def __call__(self, examples): 151 | # batched internally 152 | examples = examples[0] 153 | 154 | texts = examples["text"] 155 | 156 | if self.use_chat_template: 157 | encoding_original = self._encode_with_chat_template( 158 | texts, 159 | tokenizer=self.tokenizer_original, 160 | max_length=self.max_teacher_length, 161 | ) 162 | encoding_new = self._encode_with_chat_template( 163 | texts, 164 | tokenizer=self.tokenizer_new, 165 | max_length=self.max_student_length, 166 | ) 167 | else: 168 | encoding_original = self.tokenizer_original( 169 | texts, 170 | max_length=self.max_teacher_length, 171 | padding="max_length", 172 | truncation=True, 173 | return_tensors="np", 174 | ) 175 | encoding_new = self.tokenizer_new( 176 | texts, 177 | max_length=self.max_student_length, 178 | padding="max_length", 179 | truncation=True, 180 | return_tensors="np", 181 | ) 182 | 183 | input_ids_original = encoding_original["input_ids"] 184 | attention_mask_original = encoding_original["attention_mask"] 185 | input_ids_new = encoding_new["input_ids"] 186 | attention_mask_new = encoding_new["attention_mask"] 187 | 188 | ( 189 | alignment_matrix_a, 190 | alignment_matrix_b, 191 | ) = align.get_unconstrained_alignments( 192 | input_ids_original, 193 | input_ids_new, 194 | attention_mask_original, 195 | attention_mask_new, 196 | tokenizer_teacher=self.tokenizer_original, 197 | tokenizer_student=self.tokenizer_new, 198 | ) 199 | 200 | ( 201 | alignment_matrix_a_space, 202 | alignment_matrix_b_space, 203 | ) = align.get_space_alignments( 204 | input_ids_original, 205 | input_ids_new, 206 | attention_mask_original, 207 | attention_mask_new, 208 | tokenizer_teacher=self.tokenizer_original, 209 | tokenizer_student=self.tokenizer_new, 210 | ) 211 | 212 | if ( 213 | 
self.tokenizer_pair_bias1_matrix is not None 214 | and self.tokenizer_pair_bias2_matrix is not None 215 | ): 216 | ( 217 | alignment_matrix_a_unbiased, 218 | alignment_matrix_b_unbiased, 219 | ) = align.get_unbiased_alignments( 220 | input_ids_original, 221 | input_ids_new, 222 | attention_mask_original, 223 | attention_mask_new, 224 | tokenizer_teacher=self.tokenizer_original, 225 | tokenizer_student=self.tokenizer_new, 226 | pair_data=( 227 | self.tokenizer_pair_bias1_matrix, 228 | self.tokenizer_pair_bias2_matrix, 229 | self.teacher_token_probs, 230 | self.student_token_probs, 231 | ), 232 | bias_threshold=self.tokenizer_pair_bias_threshold, 233 | ) 234 | else: 235 | alignment_matrix_a_unbiased = np.full_like( 236 | alignment_matrix_a, fill_value=np.nan 237 | ) 238 | alignment_matrix_b_unbiased = np.full_like( 239 | alignment_matrix_b, fill_value=np.nan 240 | ) 241 | 242 | occuring_tokens_mask_original = np.zeros( 243 | len(self.tokenizer_original), dtype=bool 244 | ) 245 | occuring_tokens_mask_new = np.zeros(len(self.tokenizer_new), dtype=bool) 246 | 247 | occuring_tokens_mask_original[input_ids_original.flatten()] = True 248 | occuring_tokens_mask_new[input_ids_new.flatten()] = True 249 | 250 | loss_mask_original = self._compute_loss_mask( 251 | input_ids_original, attention_mask_original, self.loss_mask_tokens_original 252 | ) 253 | loss_mask_new = self._compute_loss_mask( 254 | input_ids_new, attention_mask_new, self.loss_mask_tokens_new 255 | ) 256 | 257 | batch = { 258 | "input_ids_new": input_ids_new, 259 | "attention_mask_new": attention_mask_new, 260 | "occuring_tokens_mask_new": occuring_tokens_mask_new, 261 | "input_ids_original": input_ids_original, 262 | "attention_mask_original": attention_mask_original, 263 | "occuring_tokens_mask_original": occuring_tokens_mask_original, 264 | "alignment_matrix_a_unconstrained": alignment_matrix_a, 265 | "alignment_matrix_b_unconstrained": alignment_matrix_b, 266 | "alignment_matrix_a_space": alignment_matrix_a_space, 267 | "alignment_matrix_b_space": alignment_matrix_b_space, 268 | "alignment_matrix_a_unbiased": alignment_matrix_a_unbiased, 269 | "alignment_matrix_b_unbiased": alignment_matrix_b_unbiased, 270 | "loss_mask_original": loss_mask_original, 271 | "loss_mask_new": loss_mask_new, 272 | } 273 | 274 | if self.expand_input_ids_dict is not None: 275 | batch["expanded_input_ids_new"] = utils.np_expand_input_ids( 276 | input_ids_new, 277 | self.expand_input_ids_dict, 278 | ) 279 | 280 | return batch 281 | 282 | def get_batch_pspecs(self): 283 | batch_specs = { 284 | "input_ids_new": P("data", None), 285 | "attention_mask_new": P("data", None), 286 | "occuring_tokens_mask_new": P(), 287 | "input_ids_original": P("data", None), 288 | "attention_mask_original": P("data", None), 289 | "occuring_tokens_mask_original": P(), 290 | "alignment_matrix_a_unconstrained": P("data", None), 291 | "alignment_matrix_b_unconstrained": P("data", None), 292 | "alignment_matrix_a_space": P("data", None), 293 | "alignment_matrix_b_space": P("data", None), 294 | "alignment_matrix_a_unbiased": P("data", None), 295 | "alignment_matrix_b_unbiased": P("data", None), 296 | "loss_mask_original": P("data", None), 297 | "loss_mask_new": P("data", None), 298 | } 299 | 300 | if self.expand_input_ids_dict is not None: 301 | batch_specs["expanded_input_ids_new"] = P("data", None) 302 | 303 | return batch_specs 304 | -------------------------------------------------------------------------------- /tokenkit/training/lr.py: 
-------------------------------------------------------------------------------- 1 | import optax 2 | 3 | 4 | def linear_warmup_linear_decay_with_linear_prefix( 5 | lr, steps, warmup_steps, prefix_steps=0, prefix_lr=0.0 6 | ): 7 | """Returns a linear warmup, linear decay learning rate function.""" 8 | 9 | prefix_fn = optax.linear_schedule( 10 | init_value=0.0, end_value=prefix_lr, transition_steps=prefix_steps 11 | ) 12 | 13 | warmup_fn = optax.linear_schedule( 14 | init_value=0.0, 15 | end_value=lr, 16 | transition_steps=warmup_steps, 17 | ) 18 | 19 | decay_fn = optax.linear_schedule( 20 | init_value=lr, 21 | end_value=0.0, 22 | transition_steps=steps - warmup_steps - prefix_steps, 23 | ) 24 | 25 | fn = optax.join_schedules( 26 | schedules=[prefix_fn, warmup_fn, decay_fn], 27 | boundaries=[ 28 | prefix_steps, 29 | prefix_steps + warmup_steps, 30 | ], 31 | ) 32 | 33 | return fn 34 | 35 | 36 | def linear_warmup_cosine_decay_with_linear_prefix( 37 | lr, steps, warmup_steps, alpha=0.0, prefix_steps=0, prefix_lr=0.0 38 | ): 39 | """Returns a linear warmup, cosine decay learning rate function.""" 40 | 41 | prefix_fn = optax.linear_schedule( 42 | init_value=0.0, end_value=prefix_lr, transition_steps=prefix_steps 43 | ) 44 | 45 | warmup_fn = optax.linear_schedule( 46 | init_value=0.0, 47 | end_value=lr, 48 | transition_steps=warmup_steps, 49 | ) 50 | 51 | decay_fn = optax.cosine_decay_schedule( 52 | init_value=lr, 53 | decay_steps=steps - warmup_steps - prefix_steps, 54 | alpha=alpha, 55 | ) 56 | 57 | fn = optax.join_schedules( 58 | schedules=[prefix_fn, warmup_fn, decay_fn], 59 | boundaries=[ 60 | prefix_steps, 61 | prefix_steps + warmup_steps, 62 | ], 63 | ) 64 | 65 | return fn 66 | -------------------------------------------------------------------------------- /tokenkit/training/multitask.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import jax.numpy as jnp 3 | from typing import Any 4 | 5 | 6 | def pcgrad(task_grads: Any) -> Any: 7 | """ 8 | Implements PCGrad (Project Conflicting Gradients) in JAX. 9 | 10 | Args: 11 | task_grads: A pytree containing gradients where the first dimension of each 12 | array represents different tasks. Shape: (n_tasks, ...) 13 | 14 | Returns: 15 | Modified gradients after applying PCGrad, with the same structure as the input. 
16 | """ 17 | # Get the structure of the input pytree and convert to flat representation 18 | flat_task_grads, treedef = jax.tree.flatten(task_grads) 19 | 20 | # Check if all elements are arrays and get number of tasks 21 | n_tasks = None 22 | for grad in flat_task_grads: 23 | if grad is not None: 24 | n_tasks = grad.shape[0] 25 | break 26 | 27 | if n_tasks is None: 28 | # No valid gradients found 29 | return task_grads 30 | 31 | # Create a new flat list to store the modified gradients 32 | modified_flat_grads = [] 33 | 34 | # Process each element in the flat list 35 | for grad in flat_task_grads: 36 | if grad is None: 37 | modified_flat_grads.append(None) 38 | continue 39 | 40 | # Ensure the first dimension matches the number of tasks 41 | assert ( 42 | grad.shape[0] == n_tasks 43 | ), f"Expected first dimension to be {n_tasks}, got {grad.shape[0]}" 44 | 45 | # Extract shape for reshaping later 46 | original_shape = grad.shape 47 | # Reshape to (n_tasks, -1) for easier processing 48 | reshaped_grad = grad.reshape(n_tasks, -1) 49 | 50 | # Initialize modified gradients with copies of the original 51 | modified_grad = reshaped_grad.copy() 52 | 53 | # Apply PCGrad for each task 54 | for i in range(n_tasks): 55 | # Project task i's gradient onto normal plane of other tasks' gradients if they conflict 56 | grad_i = modified_grad[i] 57 | 58 | for j in range(n_tasks): 59 | if i == j: 60 | continue 61 | 62 | grad_j = reshaped_grad[j] 63 | 64 | # Calculate dot product to check for conflict 65 | dot_product = jnp.sum(grad_i * grad_j) 66 | 67 | # If dot product is negative, project gradient 68 | def project(g_i, g_j, dot): 69 | g_j_norm_squared = jnp.sum(g_j * g_j) 70 | # Avoid division by zero 71 | safe_norm_squared = jnp.maximum(g_j_norm_squared, 1e-8) 72 | # Project g_i onto normal plane of g_j 73 | return g_i - jnp.minimum(0.0, dot) * g_j / safe_norm_squared 74 | 75 | modified_grad = modified_grad.at[i].set( 76 | project(modified_grad[i], grad_j, dot_product) 77 | ) 78 | 79 | # Reshape back to original shape and add to the modified list 80 | modified_flat_grads.append(modified_grad.reshape(original_shape)) 81 | 82 | # Reconstruct the pytree with modified gradients 83 | return jax.tree.unflatten(treedef, modified_flat_grads) 84 | 85 | 86 | def gradmag(task_grads: Any, epsilon: float = 1e-8) -> Any: 87 | """ 88 | Normalizes gradients of all tasks to have the same magnitude. 89 | 90 | Args: 91 | task_grads: A pytree containing gradients where the first dimension of each 92 | array represents different tasks. Shape: (n_tasks, ...) 93 | epsilon: Small constant to avoid division by zero 94 | 95 | Returns: 96 | Modified gradients after normalization, with the same structure as the input. 97 | """ 98 | global_grad_norms = compute_global_grad_norm(task_grads) + epsilon 99 | return jax.tree.map( 100 | lambda grad: grad 101 | / jnp.reshape(global_grad_norms, (-1,) + (1,) * (grad.ndim - 1)), 102 | task_grads, 103 | ) 104 | 105 | 106 | def gradclip(task_grads: Any, max_norm: float) -> Any: 107 | """ 108 | Clips gradients of all tasks to have the same magnitude. 109 | 110 | Args: 111 | task_grads: A pytree containing gradients where the first dimension of each 112 | array represents different tasks. Shape: (n_tasks, ...) 113 | max_norm: Maximum allowed norm for gradients 114 | 115 | Returns: 116 | Modified gradients after clipping, with the same structure as the input. 
117 | """ 118 | global_grad_norms = compute_global_grad_norm(task_grads) 119 | denominators = jnp.maximum(global_grad_norms, max_norm) 120 | return jax.tree.map(lambda grad: grad / jnp.reshape(denominators, (-1,) + (1,) * (grad.ndim - 1)) * max_norm, task_grads) 121 | 122 | 123 | def compute_global_grad_norm(task_grads: Any) -> jnp.ndarray: 124 | """ 125 | Computes the global gradient norm for a pytree of task gradients. 126 | 127 | Args: 128 | task_grads: A pytree containing gradients where the first dimension of each 129 | array represents different tasks. Shape: (n_tasks, ...) 130 | 131 | Returns: 132 | Global gradient norm, a scalar. 133 | """ 134 | global_grad_norms = jnp.sqrt( 135 | jax.tree.reduce( 136 | lambda x, y: x + y, 137 | jax.tree.map( 138 | lambda x: jnp.square(x).reshape(x.shape[0], -1).sum(axis=1), 139 | task_grads, 140 | ), 141 | ) 142 | ) 143 | return global_grad_norms 144 | 145 | 146 | def compute_inv_global_grad_norm(task_grads: Any, epsilon: float = 1e-8) -> jnp.ndarray: 147 | """ 148 | Computes the inverse of the global gradient norm for a pytree of task gradients. 149 | 150 | Args: 151 | task_grads: A pytree containing gradients where the first dimension of each 152 | """ 153 | 154 | return 1 / (compute_global_grad_norm(task_grads) + epsilon) -------------------------------------------------------------------------------- /tokenkit/training/opt.py: -------------------------------------------------------------------------------- 1 | from pprint import pprint 2 | 3 | import optax 4 | import regex as re 5 | from flax import traverse_util 6 | 7 | 8 | def decay_mask_fn(params): 9 | flat_params = traverse_util.flatten_dict(params) 10 | 11 | # TODO: this is somewhat hacky but (almost) always accurate 12 | flat_mask = { 13 | path: not ( 14 | path[-1] in {"bias", "b"} 15 | or any( 16 | ln_name in ".".join(path[-2:]) 17 | for ln_name in {"layernorm", "layer_norm", "ln"} 18 | ) 19 | ) 20 | for path in flat_params 21 | } 22 | return traverse_util.unflatten_dict(flat_mask) 23 | 24 | 25 | def get_optimizer(train_mask, learning_rate_fn, **optimizer_kwargs): 26 | transforms = [] 27 | 28 | opt_type = optimizer_kwargs.pop("type") 29 | grad_acc_steps = optimizer_kwargs.pop("grad_acc_steps", None) 30 | max_grad_norm = optimizer_kwargs.pop("max_grad_norm", None) 31 | 32 | if opt_type == "adamw": 33 | opt_fn = optax.adamw 34 | else: 35 | raise ValueError(f"Unknown optimizer type: {opt_type}") 36 | 37 | flat_param_group_labels = {} 38 | flat_train_mask = traverse_util.flatten_dict(train_mask) 39 | param_groups = optimizer_kwargs.pop("param_groups", []) 40 | optimizers = { 41 | "_default": opt_fn( 42 | mask=decay_mask_fn, learning_rate=learning_rate_fn, **optimizer_kwargs 43 | ), 44 | "_do_not_train": optax.set_to_zero(), 45 | } 46 | 47 | for group in param_groups: 48 | for key, trainable in flat_train_mask.items(): 49 | if not trainable: 50 | flat_param_group_labels[key] = "_do_not_train" 51 | elif re.match(group["pattern"], ".".join(key)): 52 | flat_param_group_labels[key] = group["pattern"] 53 | if group["pattern"] not in optimizers: 54 | optimizers[group["pattern"]] = opt_fn( 55 | mask=decay_mask_fn, 56 | learning_rate=lambda count: learning_rate_fn(count) 57 | * group["lr_scale"], 58 | **optimizer_kwargs, 59 | ) 60 | 61 | for key in optimizers.keys(): 62 | if key == "_do_not_train": 63 | continue 64 | 65 | if max_grad_norm is not None: 66 | optimizers[key] = optax.chain( 67 | optax.clip_by_global_norm(max_grad_norm), 68 | optimizers[key], 69 | ) 70 | 71 | if grad_acc_steps is not 
None and grad_acc_steps > 1: 72 | optimizers[key] = optax.MultiSteps(opt=optimizers[key], every_k_schedule=grad_acc_steps) 73 | 74 | for key, trainable in flat_train_mask.items(): 75 | if key not in flat_param_group_labels: 76 | if trainable: 77 | flat_param_group_labels[key] = "_default" 78 | else: 79 | flat_param_group_labels[key] = "_do_not_train" 80 | 81 | print("Special parameter groups:") 82 | pprint( 83 | { 84 | k: v 85 | for k, v in flat_param_group_labels.items() 86 | if v not in {"_default", "_do_not_train"} 87 | } 88 | ) 89 | 90 | return optax.multi_transform( 91 | optimizers, 92 | traverse_util.unflatten_dict(flat_param_group_labels), 93 | ) 94 | --------------------------------------------------------------------------------
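Usage note for tokenkit/models/lora.py: the NOTE above dematerialize_lora flags that merging and later un-merging the low-rank delta may not be safe with respect to rounding error. Below is a minimal sketch of the round-trip check that comment implies; the shapes, rank, and alpha are illustrative assumptions, not values taken from the repo's configs.

import jax
import jax.numpy as jnp

# Toy weight plus a rank-8 LoRA pair, mirroring the {"a", "b"} layout produced by
# init_lora_params (b: (out_dim, rank), a: (rank, in_dim)).
w = jax.random.normal(jax.random.PRNGKey(0), (256, 256), dtype=jnp.float32)
b = jax.random.normal(jax.random.PRNGKey(1), (256, 8), dtype=jnp.float32)
a = jax.random.normal(jax.random.PRNGKey(2), (8, 256), dtype=jnp.float32) / 8
alpha = 16.0
scale = alpha / b.shape[-1]  # same scaling rule as materialize_lora

merged = (w + scale * b @ a).astype(w.dtype)         # what materialize_lora does per leaf
restored = (merged - scale * b @ a).astype(w.dtype)  # what dematerialize_lora undoes

# In float32 the reconstruction error is tiny relative to the weights; in lower
# precision it can be much larger, which is what the "dangerous" note cautions about.
print(jnp.max(jnp.abs(restored - w)))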
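Usage note for tokenkit/models/param.py: put, get, pop, and keys address leaves of a Flax parameter pytree via dot-joined paths, the same convention used by get_input_embedding_path and related helpers. A minimal sketch with a toy tree; the array contents are placeholders.

import numpy as np
from tokenkit.models import param

pytree = {"model": {"embed_tokens": {"embedding": np.zeros((8, 4))}}}

# Overwrite a leaf (for existing jax.Array leaves, put writes into the buffer to preserve dtype).
pytree = param.put(pytree, "model.embed_tokens.embedding", np.ones((8, 4)))

# Read a leaf back by the same dotted path.
emb = param.get(pytree, "model.embed_tokens.embedding")

# Remove a leaf; pop returns the stripped tree plus the removed value (or None if absent).
pytree, removed = param.pop(pytree, "model.embed_tokens.embedding")

# List all dotted paths in a tree.
print(param.keys({"lm_head": {"kernel": np.zeros((4, 8))}}))  # ['lm_head.kernel']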
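Usage note for tokenkit/models/sharding.py: get_mesh builds a ("data", "model") device mesh, and get_sharding_fn turns the per-architecture regex patterns into NamedShardings for an arbitrary parameter tree, falling back to replication when a dimension does not divide the corresponding mesh axis. A hedged sketch; the toy parameter path and shape are invented for illustration.

import jax.numpy as jnp
from tokenkit.models import sharding

# All devices on the model axis (n_model_parallel=-1), no data parallelism.
mesh = sharding.get_mesh(n_data_parallel=1, n_model_parallel=-1)

# Toy tree whose single leaf matches the llama pattern ".*mlp.up_proj.kernel".
params = {
    "params": {
        "model": {"layers": {"0": {"mlp": {"up_proj": {"kernel": jnp.zeros((16, 64))}}}}}
    }
}

get_tree_shardings = sharding.get_sharding_fn(sharding.get_shard_patterns("llama"), mesh)
shardings = get_tree_shardings(params)

# Place the host arrays onto the mesh according to the computed shardings.
params = sharding.to_devices(params, shardings)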
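Usage note for tokenkit/training/multitask.py: pcgrad expects a pytree whose leaves stack per-task gradients along the first axis, shape (n_tasks, ...), and projects each task's gradient onto the normal plane of any other task's gradient it conflicts with. A tiny worked example with two invented, deliberately conflicting gradients:

import jax.numpy as jnp
from tokenkit.training.multitask import pcgrad

# Two tasks, one 2-d parameter; the gradients conflict (dot product -1 < 0).
task_grads = {"w": jnp.array([[1.0, 0.0],
                              [-1.0, 1.0]])}

out = pcgrad(task_grads)
# Each gradient loses its component along the other, conflicting gradient:
#   task 0: [ 1, 0] -> [0.5, 0.5]  (orthogonal to [-1, 1])
#   task 1: [-1, 1] -> [0.0, 1.0]  (orthogonal to [ 1, 0])
print(out["w"])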
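Usage note for tokenkit/training/opt.py and tokenkit/training/lr.py: get_optimizer combines a train mask (frozen parameters get optax.set_to_zero), optional param_groups with per-group learning-rate scales, global-norm clipping, and gradient accumulation into a single optax.multi_transform. A hedged sketch; the parameter paths, regex pattern, and hyperparameter values are illustrative assumptions rather than the repo's actual configs.

from tokenkit.training import lr, opt

# Toy mask: train the embeddings, freeze the LM head.
train_mask = {
    "model": {"embed_tokens": {"embedding": True}},
    "lm_head": {"kernel": False},
}

schedule = lr.linear_warmup_cosine_decay_with_linear_prefix(
    lr=1e-4, steps=10_000, warmup_steps=500
)

optimizer = opt.get_optimizer(
    train_mask,
    schedule,
    type="adamw",        # selects optax.adamw
    weight_decay=0.01,   # remaining kwargs are forwarded to optax.adamw
    max_grad_norm=1.0,   # each trainable group is wrapped in optax.clip_by_global_norm
    grad_acc_steps=4,    # each trainable group is wrapped in optax.MultiSteps
    param_groups=[
        # Parameters matching this pattern train with a 10x scaled learning rate.
        {"pattern": ".*embed_tokens.*", "lr_scale": 10.0},
    ],
)
# The result is a standard optax GradientTransformation:
#   opt_state = optimizer.init(params)
#   updates, opt_state = optimizer.update(grads, opt_state, params)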