├── .gitignore
├── .pre-commit-config.yaml
├── LICENSE
├── README.md
├── configs
│   ├── compute_mined_mapping.yaml
│   ├── compute_tokenizer_info.yaml
│   ├── cross_tokenizer_distill.yaml
│   ├── data
│   │   └── tulu3.yaml
│   ├── eval.yaml
│   ├── eval_lockstep.yaml
│   ├── eval_lockstep_mined.yaml
│   ├── math_cross_tokenizer_distill.yaml
│   ├── math_same_tokenizer_distill.yaml
│   ├── models
│   │   ├── gemma_llama_qwen.yaml
│   │   └── llama_qwen.yaml
│   ├── optimizer
│   │   └── adamw.yaml
│   ├── train_zett_hn.yaml
│   ├── train_zett_hn_gemma2.yaml
│   └── zett.yaml
├── docs
│   ├── byteification.md
│   ├── pytorch_alm_from_scratch.ipynb
│   └── tokenizer_transfer.md
├── examples
│   ├── gemma2_distill_from_openmath2-llama_gpu.sh
│   ├── llama3_to_byte_tokenizer_gpu.sh
│   ├── llama3_to_qwen2_tokenizer_gpu.sh
│   └── llama3_to_qwen2_tokenizer_tpu.sh
├── pyproject.toml
├── requirements.txt
├── rust_utils
│   ├── .gitignore
│   ├── Cargo.lock
│   ├── Cargo.toml
│   ├── pyproject.toml
│   ├── rust-toolchain
│   ├── src
│   │   └── lib.rs
│   └── tpu_build.sh
├── scripts
│   ├── compute_mined_mapping.py
│   ├── compute_tokenizer_info.py
│   ├── cross_tokenizer_distill.py
│   ├── eval.py
│   ├── eval_lockstep.py
│   ├── export_checkpoint.py
│   ├── push_flax_version_to_hub.py
│   ├── train_zett_hn.py
│   └── zett.py
├── setup.py
└── tokenkit
    ├── __init__.py
    ├── align.py
    ├── baseline_utils.py
    ├── byteify.py
    ├── compat
    │   ├── hyper_roberta.py
    │   └── hypernet.py
    ├── constants.py
    ├── data
    │   └── __init__.py
    ├── eval
    │   ├── __init__.py
    │   └── generate.py
    ├── gcs_utils.py
    ├── hf
    │   ├── __init__.py
    │   ├── configuration_tpu_gemma2.py
    │   ├── configuration_tpu_llama.py
    │   ├── modelling_flax_tpu_gemma2.py
    │   ├── modelling_flax_tpu_llama.py
    │   ├── modelling_tpu_gemma2.py
    │   └── modelling_tpu_llama.py
    ├── model_kinds.py
    ├── models
    │   ├── __init__.py
    │   ├── hypernet
    │   │   └── __init__.py
    │   ├── lora.py
    │   ├── param.py
    │   └── sharding.py
    ├── parse_args.py
    ├── training
    │   ├── __init__.py
    │   ├── checkpoint.py
    │   ├── collators
    │   │   ├── __init__.py
    │   │   ├── tokenizer_aligner.py
    │   │   └── tokenizer_sampler.py
    │   ├── losses.py
    │   ├── lr.py
    │   ├── multitask.py
    │   └── opt.py
    └── utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | outputs
2 | tokenkit_env
3 | wandb
4 | artifacts
5 | logs
6 |
7 | ### Python ###
8 | # Byte-compiled / optimized / DLL files
9 | __pycache__/
10 | *.py[cod]
11 | *$py.class
12 |
13 | # C extensions
14 | *.so
15 |
16 | # Distribution / packaging
17 | .Python
18 | build/
19 | develop-eggs/
20 | dist/
21 | downloads/
22 | eggs/
23 | .eggs/
24 | lib/
25 | lib64/
26 | parts/
27 | sdist/
28 | var/
29 | wheels/
30 | share/python-wheels/
31 | *.egg-info/
32 | .installed.cfg
33 | *.egg
34 | MANIFEST
35 |
36 | # PyInstaller
37 | # Usually these files are written by a python script from a template
38 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
39 | *.manifest
40 | *.spec
41 |
42 | # Installer logs
43 | pip-log.txt
44 | pip-delete-this-directory.txt
45 |
46 | # Unit test / coverage reports
47 | htmlcov/
48 | .tox/
49 | .nox/
50 | .coverage
51 | .coverage.*
52 | .cache
53 | nosetests.xml
54 | coverage.xml
55 | *.cover
56 | *.py,cover
57 | .hypothesis/
58 | .pytest_cache/
59 | cover/
60 |
61 | # Translations
62 | *.mo
63 | *.pot
64 |
65 | # Django stuff:
66 | *.log
67 | local_settings.py
68 | db.sqlite3
69 | db.sqlite3-journal
70 |
71 | # Flask stuff:
72 | instance/
73 | .webassets-cache
74 |
75 | # Scrapy stuff:
76 | .scrapy
77 |
78 | # Sphinx documentation
79 | docs/_build/
80 |
81 | # PyBuilder
82 | .pybuilder/
83 | target/
84 |
85 | # Jupyter Notebook
86 | .ipynb_checkpoints
87 |
88 | # IPython
89 | profile_default/
90 | ipython_config.py
91 |
92 | # pyenv
93 | # For a library or package, you might want to ignore these files since the code is
94 | # intended to run in multiple environments; otherwise, check them in:
95 | # .python-version
96 |
97 | # pipenv
98 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
99 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
100 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
101 | # install all needed dependencies.
102 | #Pipfile.lock
103 |
104 | # poetry
105 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
106 | # This is especially recommended for binary packages to ensure reproducibility, and is more
107 | # commonly ignored for libraries.
108 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
109 | #poetry.lock
110 |
111 | # pdm
112 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
113 | #pdm.lock
114 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
115 | # in version control.
116 | # https://pdm.fming.dev/#use-with-ide
117 | .pdm.toml
118 |
119 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
120 | __pypackages__/
121 |
122 | # Celery stuff
123 | celerybeat-schedule
124 | celerybeat.pid
125 |
126 | # SageMath parsed files
127 | *.sage.py
128 |
129 | # Environments
130 | .env
131 | .venv
132 | env/
133 | venv/
134 | ENV/
135 | env.bak/
136 | venv.bak/
137 |
138 | # Spyder project settings
139 | .spyderproject
140 | .spyproject
141 |
142 | # Rope project settings
143 | .ropeproject
144 |
145 | # mkdocs documentation
146 | /site
147 |
148 | # mypy
149 | .mypy_cache/
150 | .dmypy.json
151 | dmypy.json
152 |
153 | # Pyre type checker
154 | .pyre/
155 |
156 | # pytype static type analyzer
157 | .pytype/
158 |
159 | # Cython debug symbols
160 | cython_debug/
161 |
162 | # PyCharm
163 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
164 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
165 | # and can be added to the global gitignore or merged into this file. For a more nuclear
166 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
167 | #.idea/
168 |
169 | ### Python Patch ###
170 | # Poetry local configuration file - https://python-poetry.org/docs/configuration/#local-configuration
171 | poetry.toml
172 |
173 | # ruff
174 | .ruff_cache/
175 |
176 | # LSP config files
177 | pyrightconfig.json
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | repos:
2 | - repo: https://github.com/psf/black
3 | rev: 24.2.0
4 | hooks:
5 | - id: black
6 | language_version: python3
7 |
8 | - repo: https://github.com/pycqa/isort
9 | rev: 5.13.2
10 | hooks:
11 | - id: isort
12 | name: isort (python)
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # tokenkit 🔁
2 | **Tokenization Transfer for LLMs**
3 |
4 |
5 |
6 |
7 |
8 | `tokenkit` is a toolkit implementing advanced methods to transfer *models* and *model knowledge* across tokenizers.
9 |
10 | ## News
11 |
12 | - __2025-04-23__: A new guide on [implementing cross-tokenizer distillation via ALM from scratch in PyTorch](./docs/pytorch_alm_from_scratch.ipynb)! 🔥
13 | - __2025-04-22__: New [Llama3-2-3B-IT-Byte](https://huggingface.co/benjamin/Llama3-2-3B-IT-Byte) and [Gemma2-2B-IT-Byte](https://huggingface.co/benjamin/Gemma2-2B-IT-Byte) checkpoints with native `transformers` support (plus, documentation on how to train them). Also, new guides for [running tokenizer transfer](./docs/tokenizer_transfer.md) and [byteification](./docs/byteification.md)!
14 | - __2025-04-02__: The initial release of `tokenkit` with support for cross-tokenizer distillation via ALM and Zero-Shot Tokenizer Transfer via FVT!
15 |
16 | ## Contents
17 | - [Why Transfer Across Tokenizers?](#why-transfer-across-tokenizers)
18 | - [Installation](#installation)
19 | - [Quickstart](#quickstart)
20 | - [Guides](#guides)
21 | - [Tokenizer Transfer via tokenkit](./docs/tokenizer_transfer.md)
22 | - [Byteification: A Unified Interface to Tokenizers](./docs/byteification.md)
23 | - [Implementing ALM From Scratch in PyTorch](./docs/pytorch_alm_from_scratch.ipynb) (new! 🔥)
24 | - [Features](#features)
25 | - [Cross-Tokenizer Distillation](#cross-tokenizer-distillation)
26 | - [Zero-Shot Tokenizer Transfer](#zero-shot-tokenizer-transfer)
27 | - [Token-Level Ensembling & Evaluating Transferred Models](#token-level-ensembling--evaluating-transferred-models)
28 | - [Citation](#citation)
29 | - [Acknowledgments](#acknowledgments)
30 |
31 | ## Why Transfer Across Tokenizers?
32 |
33 | LLMs are bound to the tokenizer they were pretrained with. This limits their adaptability, reusability and modularity. Tokenizer transfer can lift this limitation. For example:
34 | - If we want to reuse an LLM trained primarily on English in another language, we might want to update its tokenizer to one that is more suitable for the new language.
35 | - If we want to combine (e.g., token-level ensemble) two LLMs, we need to transfer them to a common tokenizer.
36 | - If we want to experiment with better tokenization schemes (e.g., byte-level tokenization), we might want to transfer an existing LLM to this tokenizer instead of training a new one expensively from scratch.
37 | - If we want to transfer knowledge from a large teacher model to a smaller student model (which uses another tokenizer), we might want to use *cross-tokenizer distillation* to directly transfer the teacher's knowledge to the student without the need to first transfer the teacher to the student's tokenizer.
38 |
39 | This library aims to let you accomplish all of this.
40 |
41 | ## Installation
42 |
43 | `tokenkit` is primarily implemented in Jax, using PyTorch for data loading (so your PyTorch installation does not need to support an accelerator). Recommended installation:
44 |
45 |
46 | **TPU**
47 |
48 | ```bash
49 | # Clone the repository & install the library
50 | git clone https://github.com/bminixhofer/tokenkit
51 |
52 | # Create a new virtual environment
53 | # Currently, requires Python <=3.10, but we are working on this: https://github.com/bminixhofer/tokenkit/issues/4
54 | python -m venv tokenkit_env
55 | . tokenkit_env/bin/activate
56 |
57 | # Install torch & jax 0.5.0
58 | pip install torch jax[tpu]==0.5.0 -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
59 |
60 | # Currently, tokenkit relies on a fork of `lm_eval`
61 | pip install git+https://github.com/bminixhofer/lm-evaluation-harness
62 |
63 | # Install the library and the remaining dependencies
64 | pip install -r requirements.txt
65 | pip install -e .
66 | # You can ignore warnings from the command below, see https://github.com/bminixhofer/tokenkit/issues/4
67 | pip install paxml==1.4.0 praxis==1.4.0 --no-deps
68 | ```
69 |
70 |
71 |
72 | **GPU**
73 |
74 | ```bash
75 | # Clone the repository & install the library
76 | git clone https://github.com/bminixhofer/tokenkit
77 |
78 | # Create a new virtual environment
79 | # Currently, requires Python <=3.10, but we are working on this: https://github.com/bminixhofer/tokenkit/issues/4
80 | python -m venv tokenkit_env
81 | . tokenkit_env/bin/activate
82 |
83 | # Install torch & jax 0.5.0
84 | # you may need to substitute cuda12 with the version of CUDA you are using:
85 | pip install torch jax[cuda12]==0.5.0
86 |
87 | # Currently, tokenkit relies on a fork of `lm_eval`
88 | pip install git+https://github.com/bminixhofer/lm-evaluation-harness
89 |
90 | # Install the library and the remaining dependencies
91 | pip install -r requirements.txt
92 | pip install -e .
93 | # You can ignore warnings from the command below, see https://github.com/bminixhofer/tokenkit/issues/4
94 | pip install paxml==1.4.0 praxis==1.4.0 --no-deps
95 | ```
96 |
97 |
98 | ## Quickstart
99 |
100 | After installing the library, you can play around with the scripts in `examples/` to get started immediately. For example:
101 |
102 | ```
103 | bash examples/llama3_to_byte_tokenizer_gpu.sh
104 | ```
105 |
106 | If you're interested in reproducing or improving on a public model which has been trained via ALM, you can also take a look at the `tokenkit` command used to train that model, for example [in the Training section of the Llama3-2-3B-IT-Byte model card](https://huggingface.co/benjamin/Llama3-2-3B-IT-Byte#training).
107 |
108 | ## Guides
109 |
110 | - [Tokenizer Transfer via tokenkit](./docs/tokenizer_transfer.md) (start here!)
111 | - [Byteification: A Unified Interface to Tokenizers](./docs/byteification.md)
112 | - [Implementing ALM From Scratch in PyTorch](./docs/pytorch_alm_from_scratch.ipynb) (interactive notebook)
113 |
114 | ## Features
115 |
116 | ### Cross-Tokenizer Distillation
117 |
118 | `tokenkit` supports [Approximate Likelihood Matching (ALM)](https://arxiv.org/abs/2503.20083) for cross-tokenizer distillation. ALM usually performs best, but we have also implemented the following baselines:
119 |
120 | - [Dual Space Knowledge Distillation (DSKD)](https://arxiv.org/abs/2406.17328)
121 | - [Universal Logit Distillation (ULD)](https://arxiv.org/abs/2402.12030)
122 | - [Minimum Edit Distance Logit Alignment (MinED)](https://arxiv.org/abs/2401.10491)
123 |
124 | You can run cross-tokenizer distillation using the [`scripts/cross_tokenizer_distill.py`](scripts/cross_tokenizer_distill.py) script. See [`examples`](examples) for examples on transferring to different subword tokenizers and to byte-level tokenization.
125 |
126 | ### Zero-Shot Tokenizer Transfer
127 |
128 | `tokenkit` supports Zero-Shot Tokenizer Transfer (ZeTT) via [Fast Vocabulary Transfer (FVT)](https://aclanthology.org/2022.emnlp-industry.41). Zero-Shot Tokenizer Transfer is usually used to obtain a good initialization for additional training, but can in some cases also be useful on its own. See our [ZeTT paper](https://arxiv.org/abs/2405.07883) for more details.
129 |
130 | You can run Zero-Shot Tokenizer Transfer using the [`scripts/zett.py`](scripts/zett.py) script.
131 |
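To make the idea concrete, here is a rough, self-contained sketch of FVT using plain `transformers`/`numpy` (this is *not* `tokenkit`'s implementation, and the model names are placeholders): each token of the target vocabulary is decoded back to text, re-tokenized with the source tokenizer, and initialized as the mean of the corresponding source embeddings. In practice the two tokenizers' surface forms also need to be reconciled, which is what [byteification](./docs/byteification.md) handles.

```python
import numpy as np
from transformers import AutoModelForCausalLM, AutoTokenizer

source_tok = AutoTokenizer.from_pretrained("google/gemma-2-2b-it")  # placeholder source model
target_tok = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-1.5B")     # placeholder target tokenizer
source_emb = (
    AutoModelForCausalLM.from_pretrained("google/gemma-2-2b-it")
    .get_input_embeddings().weight.detach().numpy()
)

# FVT: embed each target token as the mean of the source embeddings of its pieces.
target_emb = np.zeros((len(target_tok), source_emb.shape[1]), dtype=source_emb.dtype)
for token, token_id in target_tok.get_vocab().items():
    text = target_tok.convert_tokens_to_string([token])
    piece_ids = source_tok(text, add_special_tokens=False)["input_ids"]
    if piece_ids:
        target_emb[token_id] = source_emb[piece_ids].mean(axis=0)
```

Roughly speaking, this is what `scripts/zett.py` does for you, on top of the byteified tokenizer interface described in the guides above.
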
132 | **🚧 We are working on implementing more ZeTT methods (including hypernetwork training introduced [here](https://arxiv.org/abs/2405.07883)).**
133 |
134 | ### Token-Level Ensembling & Evaluating Transferred Models
135 |
136 | `tokenkit` supports autoregressive generation & loglikelihood-scoring evaluation by implementing a Jax backend for the [LM Evaluation Harness](https://github.com/EleutherAI/lm-evaluation-harness). Alongside generating from single models, you can also generate from *token-level ensembles* of models. There are some predefined ensembles in [`configs/models`](configs/models). For example, this evaluates a token-level ensemble of Llama and Qwen on MMLU:
137 |
138 | ```bash
139 | python3 scripts/eval_lockstep.py \
140 | models=llama_qwen \
141 | eval.tasks=[mmlu]
142 | ```
143 |
144 | To evaluate pretrained byte-level models, you'll need to pass embeddings with which to expand the input ids (i.e., embeddings to use as n-gram embeddings). For example:
145 |
146 | ```bash
147 | python3 scripts/eval.py \
148 | model.pretrained_model_name_or_path=\'benjamin/Gemma2-2B-IT-Byte\' \
149 | model.tokenizer_name=\'google/gemma-2-2b-it:source=Gemma2:conversion=byte\' \
150 | expand_model.pretrained_model_name_or_path=\'benjamin/gemma-2-2b-it-flax\' \
151 | expand_model.tokenizer_name=\'google/gemma-2-2b-it:source=Gemma2\' \
152 | eval.tasks=[mmlu]
153 | ```
154 |
155 | To evaluate any other model (e.g., subword-to-subword transferred models), use something like the following:
156 |
157 | ```bash
158 | python3 scripts/eval.py \
159 | model.pretrained_model_name_or_path=\'benjamin/Gemma2-2B-IT-with-Qwen2-Tokenizer\' \
160 | model.tokenizer_name=\'benjamin/Gemma2-2B-IT-with-Qwen2-Tokenizer:source=Gemma2:conversion=prebyteified\' \
161 | eval.tasks=[mmlu] \
162 | ```
163 |
164 | ## Citation
165 |
166 | To refer to this repository or to cite Approximate Likelihood Matching, please use this citation:
167 |
168 | ```
169 | @article{alm,
170 | title={Cross-Tokenizer Distillation via Approximate Likelihood Matching},
171 | author={Minixhofer, Benjamin and Vuli{\'c}, Ivan and Ponti, Edoardo Maria},
172 | journal={arXiv preprint arXiv:2503.20083},
173 | year={2025}
174 | }
175 | ```
176 |
177 | Please use this citation for Zero-Shot Tokenizer Transfer:
178 |
179 | ```
180 | @inproceedings{zett,
181 | title={Zero-Shot Tokenizer Transfer},
182 | author={Benjamin Minixhofer and Edoardo Ponti and Ivan Vuli{\'c}},
183 | booktitle={The Thirty-eighth Annual Conference on Neural Information Processing Systems},
184 | year={2024},
185 | url={https://openreview.net/forum?id=RwBObRsIzC}
186 | }
187 | ```
188 |
189 | ## Acknowledgments
190 |
191 | Constituent projects (ALM, ZeTT) were supported by a Royal Society University Research Fellowship ‘Inclusive and Sustainable Language Technology for a Truly Multilingual World’ (no 221137; 2022-) awarded to Ivan Vulić, by the Google Cloud Research Credits program with the award GCP329647813, and by Cloud TPUs from Google’s TPU Research Cloud (TRC). The name `tokenkit` and the README layout were inspired by [mergekit](https://github.com/arcee-ai/mergekit). [big_vision](https://github.com/google-research/big_vision) was extremely useful as a high-quality reference JAX training codebase.
192 |
--------------------------------------------------------------------------------
/configs/compute_mined_mapping.yaml:
--------------------------------------------------------------------------------
1 | teacher_tokenizer_name: google/gemma-2-2b-it:source=Gemma2
2 | target_tokenizer_name: Qwen/Qwen2.5-1.5B:source=Qwen2:target=Gemma2
3 | output: outputs/tokenizer_data/gemma2_to_qwen2
4 | num_workers: 64
--------------------------------------------------------------------------------
/configs/compute_tokenizer_info.yaml:
--------------------------------------------------------------------------------
1 | teacher_tokenizer_name: google/gemma-2-2b-it:source=Gemma2
2 | target_tokenizer_name: Qwen/Qwen2.5-1.5B:source=Qwen2:target=Gemma2
3 | output: outputs/tokenizer_data/gemma2_to_qwen2
4 | seed: 1234
5 | additive_smoothing_constant: 1e-9
6 | teacher_subsample_percent: null
7 | student_subsample_percent: null
8 |
9 | data:
10 | batch_size: 1024
11 | num_workers: 64
12 |
13 | defaults:
14 | - _self_
15 | - data: tulu3
--------------------------------------------------------------------------------
/configs/cross_tokenizer_distill.yaml:
--------------------------------------------------------------------------------
1 | steps: 5_000
2 | warmup_steps: 2_000
3 | name: "unnamed"
4 | output: "outputs/cross_tokenizer_distill"
5 | num_workers: 16
6 | log_interval: 10
7 | sync_interval: 100
8 | eval_interval: 5000
9 | save_interval: 5000
10 | losses: [alm_unbiased]
11 | target_tokenizer_name: "Qwen/Qwen2.5-1.5B:source=Qwen2:target=Gemma2"
12 |
13 | train_model_mode: "lora"
14 | model_lora_rank: 64
15 | model_lora_alpha: 64
16 | train_embeddings: true
17 |
18 | bce_temp: 100.0
19 | alm_diff_fn: "binary_ce"
20 | alm_mode: "append_space"
21 | tokenizer_pair_data_path: "artifacts/tokenizer_data/gemma2_to_qwen2"
22 | tokenizer_pair_bias_threshold: 1.e-4
23 | tokenizer_pair_bias_threshold_side_path: null
24 |
25 | student:
26 | pretrained_model_name_or_path: "benjamin/gemma-2-2b-it-flax"
27 | tokenizer_name: "google/gemma-2-2b-it:source=Gemma2"
28 |
29 | data:
30 | batch_size: 16
31 | num_workers: 16
32 | kind: "hf"
33 | mix_languages: false
34 | streaming: false
35 | shuffle_buffer_size: "inf"
36 | dataset_configs:
37 | - lang_code: en
38 | kwargs:
39 | path: allenai/tulu-3-sft-mixture
40 | split: train
41 |
42 | hypernet:
43 | architecture: transformer
44 | num_layers: 1
45 | residual: true
46 | residual_alpha: 1
47 | use_attention: false
48 |
49 | optimizer:
50 | type: adamw
51 | weight_decay: 0.01
52 | b1: 0.9
53 | b2: 0.95
54 | eps: 1.e-8
55 | grad_acc_steps: null
56 | learning_rate: 1.e-5
57 | max_grad_norm: 1.0
58 | param_groups:
59 | - pattern: .*(projector_query|projector_s2t|projector_t2s|projector_latents|loss_weights).*
60 | lr_scale: 2
61 |
62 | eval:
63 | tasks: [arc_easy,arc_challenge,piqa,hellaswag,boolq,arithmetic,mmlu]
64 | lengths: [128, 256, 512, 1024, 2048]
65 | tokens_per_batch: 8192
66 | add_bos: true
67 | chat_template_mode: surround_instruct
68 | confirm_run_unsafe_code: true
--------------------------------------------------------------------------------
/configs/data/tulu3.yaml:
--------------------------------------------------------------------------------
1 | kind: "hf"
2 | mix_languages: false
3 | streaming: false
4 | shuffle_buffer_size: "inf"
5 | dataset_configs:
6 | - lang_code: en
7 | kwargs:
8 | path: allenai/tulu-3-sft-mixture
9 | split: train
--------------------------------------------------------------------------------
/configs/eval.yaml:
--------------------------------------------------------------------------------
1 | use_cpu: false
2 | pad_to_multiple_of: 128
3 | output: outputs/eval
4 |
5 | model:
6 | pretrained_model_name_or_path: "benjamin/gemma-2-2b-it-flax"
7 | tokenizer_name: "google/gemma-2-2b-it:source=Gemma2"
8 |
9 | expand_model:
10 | pretrained_model_name_or_path: null
11 | tokenizer_name: null
12 |
13 | eval:
14 | tasks: ["piqa", "boolq", "arc_easy", "hellaswag"]
15 | lengths: [128, 256, 512, 1024, 2048]
16 | tokens_per_batch: 4096
17 | add_bos: true
18 | chat_template_mode: surround_instruct
19 | confirm_run_unsafe_code: true
20 |
--------------------------------------------------------------------------------
/configs/eval_lockstep.yaml:
--------------------------------------------------------------------------------
1 | combine_strategy: "mean_prob"
2 | pad_to_multiple_of: 128
3 | output: outputs/eval_lockstep
4 |
5 | eval:
6 | tasks: [arc_easy,arc_challenge,piqa,boolq,arithmetic,mmlu,ifeval,agieval_en,agieval_cn]
7 | lengths: [128, 256, 512, 1024, 2048]
8 | tokens_per_batch: 4096
9 | add_bos: true
10 | chat_template_mode: surround_instruct
11 | confirm_run_unsafe_code: true
12 |
13 | models:
14 | - pretrained_model_name_or_path: "/mnt/disks/persist/exports/20250424160833_gemma2_to_qwen2_ours_agg_approx_gradmag_preserve_mag_20k"
15 | tokenizer_name: "Qwen/Qwen2.5-1.5B:source=Qwen2:target=Gemma2"
16 | add_bos: true
17 | - pretrained_model_name_or_path: "/mnt/disks/persist/exports/20250424174156_llama3_to_qwen2_ours_agg_approx_gradmag_preserve_mag_20k"
18 | tokenizer_name: "Qwen/Qwen2.5-1.5B:source=Qwen2:target=Llama3"
19 | add_bos: true
20 | - pretrained_model_name_or_path: "benjamin/Qwen2.5-1.5B-Instruct-flax"
21 | tokenizer_name: "Qwen/Qwen2-1.5B-Instruct:source=Qwen2"
22 | add_bos: false
--------------------------------------------------------------------------------
/configs/eval_lockstep_mined.yaml:
--------------------------------------------------------------------------------
1 | combine_strategy: "mean_prob"
2 | pad_to_multiple_of: 128
3 | output: outputs/eval_lockstep
4 |
5 | eval:
6 | tasks: [arc_easy,arc_challenge,piqa,boolq,arithmetic,mmlu,ifeval,agieval_en,agieval_cn]
7 | lengths: [128, 256, 512, 1024, 2048]
8 | tokens_per_batch: 4096
9 | add_bos: true
10 | chat_template_mode: surround_instruct
11 | confirm_run_unsafe_code: true
12 |
13 | models:
14 | - pretrained_model_name_or_path: "benjamin/Qwen2.5-1.5B-Instruct-flax" # pivot first
15 | tokenizer_name: "Qwen/Qwen2-1.5B-Instruct:source=Qwen2"
16 | add_bos: false
17 | - pretrained_model_name_or_path: "benjamin/gemma-2-2b-it-flax"
18 | tokenizer_name: "google/gemma-2-2b-it:source=Gemma2"
19 | add_bos: true
20 | - pretrained_model_name_or_path: "benjamin/Llama-3.2-3B-Instruct-flax"
21 | tokenizer_name: "meta-llama/Llama-3.2-3B-Instruct:source=Llama3"
22 | add_bos: true
23 |
24 | baseline_mined_mapping_paths: [null, "artifacts/tokenizer_data/gemma2_to_qwen2", "artifacts/tokenizer_data/llama3_to_qwen2"]
--------------------------------------------------------------------------------
/configs/math_cross_tokenizer_distill.yaml:
--------------------------------------------------------------------------------
1 | steps: 5_000
2 | warmup_steps: 2_000
3 | name: "unnamed"
4 | output: "outputs/cross_tokenizer_distill"
5 | num_workers: 16
6 | log_interval: 10
7 | sync_interval: 100
8 | eval_interval: 5000
9 | save_interval: 5000
10 | losses: [alm_unbiased]
11 | target_tokenizer_name: google/gemma-2-2b-it:source=Gemma2
12 | tokens_to_add: [<|start_header_id|>,<|end_header_id|>,<|eot_id|>,<|eom_id|>,<|python_tag|>]
13 |
14 | train_model_mode: "lora"
15 | model_lora_rank: 64
16 | model_lora_alpha: 64
17 | train_embeddings: true
18 |
19 | bce_temp: 100.0
20 | alm_diff_fn: "binary_ce"
21 | alm_mode: "space_merge+append_space"
22 | tokenizer_pair_data_path: "artifacts/tokenizer_data/math_llama3_to_gemma2"
23 | tokenizer_pair_bias_threshold: 0.1
24 | tokenizer_pair_bias_threshold_side_path: null
25 |
26 | student:
27 | pretrained_model_name_or_path: "benjamin/gemma-2-2b-it-flax"
28 | tokenizer_name: "google/gemma-2-2b-it:source=Gemma2"
29 |
30 | teacher:
31 | pretrained_model_name_or_path: "benjamin/OpenMath2-Llama3.1-8B-flax"
32 | tokenizer_name: "nvidia/OpenMath2-Llama3.1-8B:source=Llama3"
33 |
34 | data:
35 | batch_size: 16
36 | num_workers: 16
37 | kind: "hf"
38 | mix_languages: false
39 | streaming: false
40 | dataset_configs:
41 | - lang_code: en
42 | kwargs:
43 | path: benjamin/OpenMathInstruct-2-2M-formatted
44 | split: train
45 |
46 | hypernet:
47 | architecture: transformer
48 | num_layers: 1
49 | residual: true
50 | residual_alpha: 1
51 | use_attention: false
52 |
53 | optimizer:
54 | type: adamw
55 | weight_decay: 0.01
56 | b1: 0.9
57 | b2: 0.95
58 | eps: 1.e-8
59 | grad_acc_steps: null
60 | learning_rate: 1.e-5
61 | max_grad_norm: 1.0
62 | param_groups:
63 | - pattern: .*(projector_query|projector_s2t|projector_t2s|projector_latents|loss_weights).*
64 | lr_scale: 2
65 |
66 | eval:
67 | tasks: [arc_easy,arc_challenge,piqa,hellaswag,boolq,arithmetic,mmlu]
68 | lengths: [128, 256, 512, 1024, 2048]
69 | tokens_per_batch: 8192
70 | add_bos: true
71 | chat_template_mode: surround_instruct
72 | confirm_run_unsafe_code: true
--------------------------------------------------------------------------------
/configs/math_same_tokenizer_distill.yaml:
--------------------------------------------------------------------------------
1 | steps: 5_000
2 | warmup_steps: 2_000
3 | name: "unnamed"
4 | output: "outputs/same_tokenizer_distill"
5 | num_workers: 16
6 | log_interval: 10
7 | sync_interval: 100
8 | eval_interval: 5000
9 | save_interval: 5000
10 | losses: [alm_unbiased]
11 | target_tokenizer_name: meta-llama/Llama-3.2-3B-Instruct:source=Llama3
12 |
13 | train_model_mode: "lora"
14 | model_lora_rank: 64
15 | model_lora_alpha: 64
16 | train_embeddings: true
17 |
18 | bce_temp: 100.0
19 | alm_diff_fn: "binary_ce"
20 | alm_mode: "space_merge+append_space"
21 | tokenizer_pair_data_path: "artifacts/tokenizer_data/math_llama3_to_llama3"
22 | tokenizer_pair_bias_threshold: 0.1
23 | tokenizer_pair_bias_threshold_side_path: null
24 |
25 | student:
26 | pretrained_model_name_or_path: "benjamin/Llama-3.2-3B-Instruct-flax"
27 | tokenizer_name: "meta-llama/Llama-3.2-3B-Instruct:source=Llama3"
28 |
29 | teacher:
30 | pretrained_model_name_or_path: "benjamin/OpenMath2-Llama3.1-8B-flax"
31 | tokenizer_name: "nvidia/OpenMath2-Llama3.1-8B:source=Llama3"
32 |
33 | data:
34 | batch_size: 16
35 | num_workers: 16
36 | kind: "hf"
37 | mix_languages: false
38 | streaming: false
39 | dataset_configs:
40 | - lang_code: en
41 | kwargs:
42 | path: benjamin/OpenMathInstruct-2-2M-formatted
43 | split: train
44 |
45 | hypernet:
46 | architecture: transformer
47 | num_layers: 1
48 | residual: true
49 | residual_alpha: 1
50 | use_attention: false
51 |
52 | optimizer:
53 | type: adamw
54 | weight_decay: 0.01
55 | b1: 0.9
56 | b2: 0.95
57 | eps: 1.e-8
58 | grad_acc_steps: null
59 | learning_rate: 1.e-5
60 | max_grad_norm: 1.0
61 | param_groups:
62 | - pattern: .*(projector_query|projector_s2t|projector_t2s|projector_latents|loss_weights).*
63 | lr_scale: 2
64 |
65 | eval:
66 | tasks: [arc_easy,arc_challenge,piqa,hellaswag,boolq,arithmetic,mmlu]
67 | lengths: [128, 256, 512, 1024, 2048]
68 | tokens_per_batch: 8192
69 | add_bos: true
70 | chat_template_mode: surround_instruct
71 | confirm_run_unsafe_code: true
--------------------------------------------------------------------------------
/configs/models/gemma_llama_qwen.yaml:
--------------------------------------------------------------------------------
1 | - pretrained_model_name_or_path: "benjamin/Gemma2-2B-IT-with-Qwen2-Tokenizer"
2 | tokenizer_name: "benjamin/Gemma2-2B-IT-with-Qwen2-Tokenizer:source=Gemma2:conversion=prebyteified"
3 | add_bos: true
4 | - pretrained_model_name_or_path: "benjamin/Llama3.2-3B-IT-with-Qwen2-Tokenizer"
5 | tokenizer_name: "benjamin/Llama3.2-3B-IT-with-Qwen2-Tokenizer:source=Llama3:conversion=prebyteified"
6 | add_bos: true
7 | - pretrained_model_name_or_path: "benjamin/Qwen2.5-1.5B-Instruct-flax"
8 | tokenizer_name: "Qwen/Qwen2-1.5B-Instruct:source=Qwen2"
9 | add_bos: false
--------------------------------------------------------------------------------
/configs/models/llama_qwen.yaml:
--------------------------------------------------------------------------------
1 | - pretrained_model_name_or_path: "benjamin/Gemma2-2B-IT-with-Qwen2-Tokenizer"
2 | tokenizer_name: "benjamin/Gemma2-2B-IT-with-Qwen2-Tokenizer:source=Qwen2"
3 | add_bos: true
4 | - pretrained_model_name_or_path: "benjamin/Qwen2.5-1.5B-Instruct-flax"
5 | tokenizer_name: "Qwen/Qwen2-1.5B-Instruct:source=Qwen2"
6 | add_bos: false
--------------------------------------------------------------------------------
/configs/optimizer/adamw.yaml:
--------------------------------------------------------------------------------
1 | type: adamw
2 | weight_decay: 0.01
3 | b1: 0.9
4 | b2: 0.95
5 | eps: 1e-8
6 | grad_acc_steps: null
--------------------------------------------------------------------------------
/configs/train_zett_hn.yaml:
--------------------------------------------------------------------------------
1 | losses: ["sft"]
2 | output: "outputs/hn"
3 | seed: 1234
4 | dtype: bfloat16
5 | pad_to_multiple_of: 128
6 | identity_steps: 0
7 | identity_lr: 3.e-4
8 | warmup_steps: 10_000
9 | steps: 200_000
10 | train_embeddings: false
11 | train_model_mode: "no"
12 | num_workers: 64
13 | name: "unnamed"
14 | compat: false
15 | eval_at_step_zero: false
16 | save_at_step_zero: false
17 |
18 | n_data_parallel: 1
19 | n_model_parallel: 8
20 |
21 | log_interval: 50
22 | sync_interval: 100
23 | eval_interval: 10_000
24 | save_interval: 10_000
25 |
26 | ppl_eval_data: null
27 |
28 | optimizer:
29 | learning_rate: 6e-5
30 |
31 | eval:
32 | tasks: [piqa,hellaswag,arc_easy]
33 | lengths: [128, 256, 512, 1024, 2048]
34 | tokens_per_batch: 8192
35 | add_bos: true
36 | chat_template_mode: direct_encode
37 | confirm_run_unsafe_code: true
38 | tokenizers:
39 | - tokenizer: openai-community/gpt2:source=GPT2:target=TinyLlama:conversion=manual_add_prefix_space
40 | name: gpt2
41 | - tokenizer: mistralai/Mistral-Small-3.1-24B-Base-2503:source=Mistral:target=TinyLlama:conversion=manual_add_prefix_space
42 | name: mistral
43 | - tokenizer: meta-llama/Llama-3.2-3B:source=Llama3:target=TinyLlama:conversion=manual_add_prefix_space
44 | name: llama3
45 |
46 | data:
47 | batch_size: 128
48 | num_workers: 16
49 | kind: hf
50 | mix_languages: false
51 | streaming: true
52 | dataset_configs:
53 | - lang_code: en
54 | kwargs:
55 | path: "allenai/madlad-400"
56 | name: "en"
57 | split: "clean"
58 |
59 | # TODO: disentangle data/collator args
60 | collator:
61 | do_tokenizer_sampling: true
62 | sample_text_span: true
63 | n_pools: 1
64 | add_prefix_space: true
65 | hn_surface_maxlen: 8
66 | n_token_subsample: null
67 | identity_n_token_subsample: 16384
68 | pad_to_multiple_of: 128
69 | tokenizer_sample_max: 32768
70 | tokenizer_sample_mean: 32768
71 | tokenizer_sample_min: 32768
72 | tokenizer_sample_std: 0
73 | tokenizer_batch_size: 2048
74 | tokenizer_noise_std: 4
75 | tokenizer_noise_mean: 1.e-5
76 | block_size: 128
77 |
78 | hypernet:
79 | architecture: transformer
80 | residual_alpha: 1
81 | residual: true
82 | use_attention: true
83 | num_layers: 3
84 | shared: true
85 | num_heads: 16
86 | use_attention_mask: false
87 | multiply_hidden_dim_by_num_embeddings: false
88 |
89 | optimizer:
90 | type: adamw
91 | weight_decay: 0.01
92 | b1: 0.9
93 | b2: 0.95
94 | eps: 1.e-8
95 | grad_acc_steps: null
96 | learning_rate: 1.e-5
97 | max_grad_norm: 1.0
98 | param_groups:
99 | - pattern: .*(projector_query|projector_s2t|projector_t2s|projector_latents|loss_weights).*
100 | lr_scale: 2
101 |
102 | model:
103 | pretrained_model_name_or_path: "TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T"
104 | tokenizer_name: "TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T:source=TinyLlama"
105 | revision: "refs/pr/8"
--------------------------------------------------------------------------------
/configs/train_zett_hn_gemma2.yaml:
--------------------------------------------------------------------------------
1 | losses: ["sft"]
2 | output: "outputs/hn"
3 | seed: 1234
4 | dtype: bfloat16
5 | pad_to_multiple_of: 128
6 | identity_steps: 0
7 | identity_lr: 3.e-4
8 | warmup_steps: 10_000
9 | steps: 200_000
10 | train_embeddings: false
11 | train_model_mode: "no"
12 | num_workers: 64
13 | name: "unnamed"
14 | compat: false
15 | eval_at_step_zero: false
16 | save_at_step_zero: false
17 |
18 | n_data_parallel: 1
19 | n_model_parallel: 8
20 |
21 | log_interval: 50
22 | sync_interval: 100
23 | eval_interval: 10_000
24 | save_interval: 10_000
25 |
26 | ppl_eval_data: null
27 |
28 | optimizer:
29 | learning_rate: 6e-5
30 |
31 | eval:
32 | tasks: [piqa,hellaswag,arc_easy]
33 | lengths: [128, 256, 512, 1024, 2048]
34 | tokens_per_batch: 8192
35 | add_bos: true
36 | chat_template_mode: direct_encode
37 | confirm_run_unsafe_code: true
38 | tokenizers:
39 | - tokenizer: openai-community/gpt2:source=GPT2:target=Gemma2
40 | name: gpt2
41 | - tokenizer: mistralai/Mistral-Small-3.1-24B-Base-2503:source=Mistral:target=Gemma2
42 | name: mistral
43 | - tokenizer: meta-llama/Llama-3.2-3B:source=Llama3:target=Gemma2
44 | name: llama3
45 |
46 | data:
47 | batch_size: 128
48 | kind: "hf_saved"
49 | lang_code: en
50 | dataset_configs:
51 | - path: "/lfs/dolmino_50B_medium"
52 |
53 | # TODO: disentangle data/collator args
54 | collator:
55 | do_tokenizer_sampling: true
56 | sample_text_span: true
57 | n_pools: 1
58 | add_prefix_space: false
59 | hn_surface_maxlen: 8
60 | n_token_subsample: null
61 | identity_n_token_subsample: 16384
62 | pad_to_multiple_of: 128
63 | tokenizer_sample_max: 32768
64 | tokenizer_sample_mean: 32768
65 | tokenizer_sample_min: 32768
66 | tokenizer_sample_std: 0
67 | tokenizer_batch_size: 2048
68 | tokenizer_noise_std: 4
69 | tokenizer_noise_mean: 1.e-5
70 | block_size: 128
71 |
72 | hypernet:
73 | architecture: transformer
74 | residual_alpha: 1
75 | residual: true
76 | use_attention: true
77 | num_layers: 3
78 | shared: true
79 | num_heads: 16
80 | use_attention_mask: false
81 | multiply_hidden_dim_by_num_embeddings: false
82 |
83 | optimizer:
84 | type: adamw
85 | weight_decay: 0.01
86 | b1: 0.9
87 | b2: 0.95
88 | eps: 1.e-8
89 | grad_acc_steps: null
90 | learning_rate: 1.e-5
91 | max_grad_norm: 1.0
92 | param_groups:
93 | - pattern: .*(projector_query|projector_s2t|projector_t2s|projector_latents|loss_weights).*
94 | lr_scale: 2
95 |
96 | model:
97 | pretrained_model_name_or_path: "benjamin/gemma-2-2b-flax"
98 | tokenizer_name: "google/gemma-2-2b:source=Gemma2"
--------------------------------------------------------------------------------
/configs/zett.yaml:
--------------------------------------------------------------------------------
1 | source_model:
2 | pretrained_model_name_or_path: benjamin/gemma-2-2b-it-flax
3 | tokenizer_name: google/gemma-2-2b-it:source=Gemma2
4 |
5 | target_tokenizer_name: Qwen/Qwen2.5-1.5B:source=Qwen2:target=Gemma2
6 | output: outputs/zett/gemma-2-2b-it-flax-fvt-qwen
--------------------------------------------------------------------------------
/docs/byteification.md:
--------------------------------------------------------------------------------
1 | # Byteification: A Unified Interface to Tokenizers
2 |
3 | In this guide, we'll take a look at how `tokenkit` interacts with tokenizers: `tokenkit` uses a unified byte-level interface to tokenizers to prevent issues stemming from tokenizers using different encoding schemes. For example, let's say we want to compute the number of overlapping tokens between the Gemma2 and Llama3 tokenizers. Here is the naive approach:
4 |
5 | ```python
6 | from transformers import AutoTokenizer
7 |
8 | tok1 = AutoTokenizer.from_pretrained("google/gemma-2-2b-it")
9 | tok2 = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-3B-Instruct")
10 |
11 | n_overlap = len(set(tok1.get_vocab().keys()) & set(tok2.get_vocab().keys()))
12 | # 25632 - this is suspiciously low!
13 | ```
14 |
15 | The two tokenizers use a different encoding, so even if two tokens encode the same UTF-8 bytes, they might look different!
16 |
17 | ```python
18 | tok1.tokenize(" Café") # ['▁Café']
19 | tok2.tokenize(" Café") # ['ĠCafé']
20 | ```
21 |
22 | We can fix this by instead using the `tokenkit.byteify.ByteifyTokenizer` interface. 'Byteification' preserves the tokenizer's functionality while providing a unified (byte-level) encoding:
23 |
24 | ```python
25 | from tokenkit.byteify import load_byteify_tokenizer
26 |
27 | tok1 = load_byteify_tokenizer("google/gemma-2-2b-it:source=Gemma2")
28 | tok2 = load_byteify_tokenizer("meta-llama/Llama-3.2-3B-Instruct:source=Llama3")
29 |
30 | n_overlap = len(set(tok1.get_vocab().keys()) & set(tok2.get_vocab().keys()))
31 | # 85699 - this is much more reasonable!
32 |
33 | tok1.tokenize(" Café") # ['ĠCafé']
34 | tok2.tokenize(" Café") # ['ĠCafé']
35 | ```
36 |
37 | This always 100% preserves the tokenizer functionality (e.g., which tokens any text is encoded as). The API mostly matches the HuggingFace tokenizers API (e.g., `convert_ids_to_tokens`, `convert_tokens_to_ids`, `get_vocab`, `tokenize`, `add_tokens`) but is not exactly the same.
38 |
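Continuing the example above, here is a quick round trip through some of these methods (the outputs in the comments are what we'd expect given the byte-level surface forms shown earlier):

```python
ids = tok1.convert_tokens_to_ids(tok1.tokenize(" Café"))
tok1.convert_ids_to_tokens(ids)  # ['ĠCafé'] (the byte-level surface form from above)
"ĠCafé" in tok1.get_vocab()      # True: vocab keys use the same byte-level encoding
```
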
39 | This allows us to compute things like lexical overlap and token sequence alignments accurately. `tokenkit` also implements an exact alignment algorithm between tokenizers, including tokenizers with different special tokens (e.g., different chat templates).
40 |
41 | ```python
42 | from tokenkit.byteify import load_byteify_tokenizer
43 | from tokenkit import align
44 |
45 | tok1 = load_byteify_tokenizer("google/gemma-2-2b-it:source=Gemma2")
46 | tok2 = load_byteify_tokenizer("meta-llama/Llama-3.2-3B-Instruct:source=Llama3")
47 |
48 | # Gemma2 chat template
49 | tokens1 = tok1.tokenize("user\nWhat's ultracrepidarianism?\n")
50 | # Llama3 chat template
51 | tokens2 = tok2.tokenize("<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\nWhat's ultracrepidarianism?<|eot_id|>")
52 |
53 | alignment_indices = align.get_alignment_indices(tokens1, tokens2, tok1, tok2)[0]
54 |
55 | for (start1, end1, start2, end2) in alignment_indices:
56 | print(tokens1[start1:end1], tokens2[start2:end2])
57 |
58 | # [''] ['<|begin_of_text|>']
59 | # [''] ['<|start_header_id|>']
60 | # ['user'] ['user']
61 | # ['Ċ'] ['<|end_header_id|>', 'ĊĊ']
62 | # ['What'] ['What']
63 | # ["'", 's'] ["'s"]
64 | # ['Ġultra', 'cre', 'pid'] ['Ġultr', 'ac', 'repid']
65 | # ['arian'] ['arian']
66 | # ['ism'] ['ism']
67 | # ['?'] ['?']
68 | # ['', 'Ċ'] ['<|eot_id|>']
69 | ```
70 |
71 | ## Tokenizer Specs
72 |
73 | As you've seen above, `tokenkit` uses a colon-separated *tokenizer spec* to load tokenizers. This gives us a simple way to specify additional arguments and modifications to the tokenizer.
74 |
75 | - The `source=` argument (the only required argument) makes sure that special tokens and the chat template are set correctly for the given model family. See [tokenkit/model_kinds.py](../tokenkit/model_kinds.py) for supported model families or to add new model families.
76 | - The optional `target=` argument updates the tokenizer's special tokens / chat template to a different model family. E.g. `google/gemma-2-2b-it:source=Gemma2:target=Qwen2` would tokenize all regular text equivalently to the Gemma2 tokenizer, but use the Qwen2 chat template and special tokens. Since Qwen2 does not use a `<bos>` token, it would thus also not use a `<bos>` token (a short example follows the `conversion=byte` snippet below).
77 | - The optional `conversion=` argument enables conversion to a different encoding scheme. `conversion=byte` is the one you are most likely to encounter. It converts the tokenizer so that all regular (non-special-token) text is tokenized into individual bytes, i.e., byte-level tokenization (*this is different from, and unrelated to, byteification!*). Special tokens are kept as-is. For example:
78 |
79 | ```python
80 | from tokenkit.byteify import load_byteify_tokenizer
81 |
82 | tok = load_byteify_tokenizer("google/gemma-2-2b-it:source=Gemma2:conversion=byte")
83 |
84 | tok.tokenize("Hello, world!") # ['', 'H', 'e', 'l', 'l', 'o', ',', 'Ġ', 'w', 'o', 'r', 'l', 'd', '!']
85 | print(len(tok)) # 256 + some special tokens
86 | ```
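
The `target=` argument described above can be illustrated in the same way (a minimal sketch: regular text should tokenize identically, while the special tokens and chat template follow the target family):

```python
from tokenkit.byteify import load_byteify_tokenizer

tok_gemma = load_byteify_tokenizer("google/gemma-2-2b-it:source=Gemma2")
tok_gemma_as_qwen = load_byteify_tokenizer("google/gemma-2-2b-it:source=Gemma2:target=Qwen2")

# Regular text is unchanged...
assert tok_gemma_as_qwen.tokenize(" Café") == tok_gemma.tokenize(" Café")
# ...but the special tokens now follow the Qwen2 convention (e.g., no <bos> token).
```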
87 |
88 | ---
89 |
90 |
--------------------------------------------------------------------------------
/docs/tokenizer_transfer.md:
--------------------------------------------------------------------------------
1 | # Tokenizer Transfer via tokenkit
2 |
3 | This guide will walk you through the process of transferring a pretrained model to a new tokenizer using tokenkit.
4 |
5 | First, follow the installation instructions in the [README](../README.md).
6 |
7 | Then, the scripts in `examples/` provide a starting point for transferring a model to a new tokenizer. For example:
8 |
9 | ```bash
10 | bash examples/llama3_to_qwen2_tokenizer_gpu.sh
11 | # or on TPU: examples/llama3_to_qwen2_tokenizer_tpu.sh
12 | ```
13 |
14 | This will distill the [Llama3.2-3B-Instruct](https://huggingface.co/meta-llama/Llama-3.2-3B-Instruct) model to the [Qwen2.5-1.5B](https://huggingface.co/Qwen/Qwen2.5-1.5B-Instruct) tokenizer. Let's have a look at what it runs:
15 |
16 | ```bash
17 | # examples/llama3_to_qwen2_tokenizer_gpu.sh
18 | NAME=llama3_to_qwen2_tokenizer
19 | python3 scripts/cross_tokenizer_distill.py \
20 | --config=configs/cross_tokenizer_distill.yaml \
21 | --overrides \
22 | losses=[sft,alm_unconstrained] \
23 | alm_mode=merge_by_space_prob+append_space \
24 | tokenizer_pair_bias_threshold=0.1 \
25 | train_model_mode=lora \
26 | model_lora_rank=64 \
27 | model_lora_alpha=64 \
28 | n_data_parallel=1 \
29 | n_model_parallel=1 \
30 | steps=5000 \
31 | eval_interval=1000 \
32 | save_interval=1000 \
33 | data.batch_size=64 \
34 | optimizer.grad_acc_steps=4 \
35 | data.num_workers=16 \
36 | student.pretrained_model_name_or_path=benjamin/Llama-3.2-3B-Instruct-flax \
37 | student.tokenizer_name=\'meta-llama/Llama-3.2-3B-Instruct:source=Llama3\' \
38 | target_tokenizer_name=\'Qwen/Qwen2.5-1.5B:source=Qwen2:target=Llama3\' \
39 | name=$NAME
40 | ```
41 |
42 | Default arguments are taken from [`configs/cross_tokenizer_distill.yaml`](../configs/cross_tokenizer_distill.yaml). You can keep many of these as-is. A notable parameter which we don't override here is the dataset: we use the [Tulu3 instruction-tuning dataset](https://huggingface.co/datasets/allenai/tulu-3-sft-mixture). This is a good choice for transferring chat / instruction-following models. You can update this to fit your use case by modifying the `data` section in [`configs/cross_tokenizer_distill.yaml`](../configs/cross_tokenizer_distill.yaml).
43 |
44 | Let's go over the overridden arguments in more detail:
45 |
46 | ```
47 | losses=[sft,alm_unconstrained] \
48 | alm_mode=merge_by_space_prob+append_space \
49 | tokenizer_pair_bias_threshold=0.1 \
50 | ```
51 |
52 | These arguments configure the losses to optimize and the ALM mode to use. The above configuration should be the best choice for many cases. Importantly, *it is different from what is described in the [ALM paper](https://arxiv.org/abs/2503.20083)*. In particular, it achieves equivalent or better results without precomputation. A more detailed description is forthcoming in an updated version of our paper.
53 |
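For intuition, here is a heavily simplified sketch of the chunk-level idea behind the ALM losses (this is *not* `tokenkit`'s implementation; see the [ALM paper](https://arxiv.org/abs/2503.20083) and the [PyTorch notebook](./pytorch_alm_from_scratch.ipynb) for the real thing): per-token log-probabilities of student and teacher are grouped into aligned chunks (such as those produced by `tokenkit.align`), each chunk's log-likelihood is the sum over its tokens, and the two chunk likelihoods are compared with a difference function, e.g. a binary cross-entropy as selected by `alm_diff_fn=binary_ce`.

```python
import jax.numpy as jnp

def alm_chunk_loss(student_logprobs, teacher_logprobs, alignment_indices):
    """Illustrative only: match chunk-level likelihoods of aligned student/teacher spans."""
    losses = []
    for s_start, s_end, t_start, t_end in alignment_indices:
        # Chunk log-likelihood = sum of per-token log-probabilities inside the aligned span.
        p_student = jnp.clip(jnp.exp(jnp.sum(student_logprobs[s_start:s_end])), 1e-6, 1 - 1e-6)
        p_teacher = jnp.clip(jnp.exp(jnp.sum(teacher_logprobs[t_start:t_end])), 1e-6, 1 - 1e-6)
        # Binary cross-entropy between the two chunk probabilities.
        losses.append(-(p_teacher * jnp.log(p_student) + (1 - p_teacher) * jnp.log(1 - p_student)))
    return jnp.mean(jnp.stack(losses))
```
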
54 | ```
55 | hypernet.architecture=identity \
56 | ```
57 |
58 | By default, `tokenkit` uses a one-layer embedding projector network (`hypernet.architecture=transformer`). This improves performance but can be memory-intensive, so we disable it here.
59 |
60 | ```
61 | multitask_aggregation_fn=approx_gradmag_preserve_mag \
62 | ```
63 |
64 | Use the last-layer gradient magnitudes to approximately reweight the multiple objectives (in this case, SFT and ALM) so that they contribute equally to the final loss gradients. This adds a little extra overhead since we need to backpropagate through the last layer separately for every objective, but it removes the requirement to manually tune loss weights. If you observe that this adds too much overhead, you can skip it and manually tune the loss weights using e.g. `loss_weights=[1.,0.5]`, or leave it out completely to use uniform loss weights (instead of uniform loss *gradient* weights).
65 |
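Conceptually, the aggregation works roughly like the following sketch (simplified and under our own assumptions, not the actual code, which presumably lives in `tokenkit/training/multitask.py`): measure each objective's gradient w.r.t. a cheap probe (the last layer), then rescale the losses so that their probe gradients contribute with comparable magnitude while keeping the overall scale roughly unchanged.

```python
import jax
import jax.numpy as jnp

def tree_norm(tree):
    return jnp.sqrt(sum(jnp.sum(x * x) for x in jax.tree_util.tree_leaves(tree)))

def balanced_loss(last_layer_params, batch, loss_fns):
    # One value-and-grad pass per objective, w.r.t. the last layer only.
    vals_and_grads = [jax.value_and_grad(fn)(last_layer_params, batch) for fn in loss_fns]
    losses = jnp.stack([v for v, _ in vals_and_grads])
    grad_norms = jnp.stack([tree_norm(g) for _, g in vals_and_grads])
    # Scale each loss inversely to its probe-gradient magnitude; keep the mean scale at 1.
    weights = jax.lax.stop_gradient(grad_norms.mean() / (grad_norms + 1e-8))
    return jnp.sum(weights * losses)
```
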
66 | ```
67 | train_model_mode=lora \
68 | model_lora_rank=64 \
69 | model_lora_alpha=64 \
70 | ```
71 |
72 | Train the model using LoRA with rank = alpha = 64. `tokenkit` applies LoRA to the QKV projections, the attention output projection, and the MLP up-, down- and gate-projections (see [lora.py](../tokenkit/models/lora.py)).
73 |
74 | You can use `train_model_mode=full` to train the full model instead. However, in this case, we need to store a separate copy of the model parameters for the student and the teacher, whereas with LoRA, we can use a single model parameter copy and materialize / dematerialize the LoRA parameters as needed. Storing a separate teacher model copy makes training substantially more memory intensive. A rule of thumb is: for transfer to a similar kind of tokenizer (e.g., another subword tokenizer), LoRA is sufficient. For transfer to a very different tokenizer (e.g., a byte-level tokenizer), full fine-tuning helps.
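
To make the memory argument concrete, here is a toy sketch (not tokenkit's actual [`lora.py`](../tokenkit/models/lora.py)) of what materializing a LoRA update means: the frozen base weight can be shared between teacher and student, and the student only stores two low-rank factors per adapted projection.

```python
# Toy LoRA sketch: the student's effective weight is the shared base weight plus a
# low-rank update scaled by alpha / rank. Materializing / dematerializing just means
# adding / removing this delta, so only one full copy of W needs to be kept in memory.
import numpy as np

def materialize_lora(W, A, B, rank=64, alpha=64):
    # W: [d_out, d_in], A: [rank, d_in], B: [d_out, rank]
    return W + (alpha / rank) * (B @ A)

d_out, d_in, rank = 128, 128, 64
W = np.random.randn(d_out, d_in).astype(np.float32)
A = np.random.randn(rank, d_in).astype(np.float32) * 0.01
B = np.zeros((d_out, rank), dtype=np.float32)  # zero-initialized => no change at step 0
assert np.allclose(materialize_lora(W, A, B), W)
```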
75 |
76 | ```
77 | n_data_parallel=1 \
78 | n_model_parallel=1 \
79 | ```
80 |
81 | Data and model parallelism. Set these such that their product equals the number of GPUs or TPU cores you have available. Often (especially for larger models) you will want to increase model parallelism and keep data parallelism at 1.
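
A quick sanity check you can run on the target machine (a sketch, not part of tokenkit):

```python
# The product of the parallelism settings should match the number of accelerator
# devices visible to JAX on this machine.
import jax

n_data_parallel, n_model_parallel = 1, 1
assert n_data_parallel * n_model_parallel == jax.device_count(), (
    f"found {jax.device_count()} devices, "
    f"but n_data_parallel * n_model_parallel = {n_data_parallel * n_model_parallel}"
)
```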
82 |
83 | ```
84 | steps=5000 \
85 | eval_interval=1000 \
86 | save_interval=1000 \
87 | data.batch_size=64 \
88 | optimizer.grad_acc_steps=4 \
89 | data.num_workers=16 \
90 | ```
91 |
92 | Train for 5000 steps, evaluating every 1000 steps and saving the model every 1000 steps, at a global batch size of 64 with 4 gradient accumulation steps (i.e., a local batch size of 16 per forward/backward pass). Evaluation is done via (a fork of) [`lm-evaluation-harness`](https://github.com/bminixhofer/lm-evaluation-harness) and runs the tasks configured via `eval.tasks`.
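
In other words, the arithmetic behind these settings (a small worked example, assuming `steps` counts optimizer steps):

```python
batch_size = 64        # global batch size per optimizer step (data.batch_size)
grad_acc_steps = 4     # optimizer.grad_acc_steps
local_batch_size = batch_size // grad_acc_steps   # 16 sequences per forward/backward pass
steps = 5000
total_examples = steps * batch_size               # roughly 320,000 training sequences overall
```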
93 |
94 | ```
95 | student.pretrained_model_name_or_path=benjamin/Llama-3.2-3B-Instruct-flax \
96 | student.tokenizer_name=\'meta-llama/Llama-3.2-3B-Instruct:source=Llama3\' \
97 | target_tokenizer_name=\'Qwen/Qwen2.5-1.5B:source=Qwen2:target=Llama3\' \
98 | ```
99 |
100 | The (local or HF hub) paths to the model to transfer. If we do not specify a separate teacher, the teacher will be the student with the original tokenizer (this is what we want for tokenizer transfer). Notably:
101 |
102 | - The model is `benjamin/Llama-3.2-3B-Instruct-flax` since the original `meta-llama/Llama-3.2-3B-Instruct` model is not in Flax format. You can convert supported models to Flax using the `scripts/push_flax_version_to_hub.py` script.
103 | - The tokenizer is specified using a tokenizer spec which differs from the HuggingFace `AutoTokenizer` format by including additional colon-separated tags (a small illustrative parsing sketch follows below). For example, `Qwen/Qwen2.5-1.5B:source=Qwen2:target=Llama3` specifies the Qwen2.5-1.5B tokenizer, originally stemming from the Qwen2 model family (`source=Qwen2`), updated to use the special tokens of the Llama3 family instead (`target=Llama3`). See the [byteification](./byteification.md) guide for more details on the interface tokenkit provides to use HuggingFace tokenizers. For our purposes in this guide, it is important that when you transfer across tokenizers, you can choose to either (i) preserve the original special tokens (safer but potentially inconvenient) or (ii) use the special tokens from the new tokenizer (less safe but potentially more convenient). More on this below in [To Keep or to Change the Special Tokens?](#to-keep-or-to-change-the-special-tokens).
104 |
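To make the spec format concrete, here is a tiny hypothetical parser for the colon-separated tags (illustrative only; the real handling lives in [`tokenkit/byteify.py`](../tokenkit/byteify.py)):

```python
# Hypothetical parser for the tokenizer spec format "name:key=value:key=value".
def parse_tokenizer_spec(spec: str):
    name, *tags = spec.split(":")
    return name, dict(tag.split("=", 1) for tag in tags)

name, tags = parse_tokenizer_spec("Qwen/Qwen2.5-1.5B:source=Qwen2:target=Llama3")
assert name == "Qwen/Qwen2.5-1.5B"
assert tags == {"source": "Qwen2", "target": "Llama3"}
```
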
105 | ```
106 | name=$NAME
107 | ```
108 |
109 | The name to track the experiment with. By default, `tokenkit` uses [Weights & Biases](https://www.wandb.ai/) to track experiments.
110 |
111 | That's all! You can now transfer your first model.
112 |
113 | ## Transfer to Bytes
114 |
115 | We need to make a couple of adjustments to enable effective transfer to byte-level tokenizers. Let's compare the example config in [`examples/llama3_to_byte_tokenizer_gpu.sh`](../examples/llama3_to_byte_tokenizer_gpu.sh) to the config we used above:
116 |
117 | ```diff
118 | - losses=[sft,alm_unconstrained] \
119 | + losses=[sft,alm_unconstrained,alm_latents] \
120 | ```
121 |
122 | For transfer to bytes, the ALM latent (hidden-state alignment) objective substantially improves performance.
123 |
124 | ```diff
125 | - train_model_mode=lora \
126 | - model_lora_rank=64 \
127 | - model_lora_alpha=64 \
128 | + train_model_mode=full \
129 | + expand_input_ids=true \
130 | + output_embeddings_mode=untie \
131 | ```
132 |
133 | We train the full model to give it more capacity to adapt to the fundamental change in tokenization. We also *expand* the input IDs to inject some extra parameters while preserving total FLOPs. Input ID expansion works as follows: for every byte position, add to the byte embedding the subword embedding of the longest matching subword ending at that byte position (where the subwords and subword embeddings are taken from the original tokenizer and embedding matrix). Finally, we untie the byte input and output embeddings, since there is no reason to tie them in the byte-level case (tying would not save a considerable number of parameters). Untying may also marginally improve performance.
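
As a toy illustration of this lookup (a simplified sketch that ignores special tokens; the actual logic lives in the `expand_input_ids` utilities in [`tokenkit/utils.py`](../tokenkit/utils.py)):

```python
# Toy sketch: for each byte position, find the longest original-vocabulary subword that
# ends at this position and add its embedding to the byte embedding.
import numpy as np

def expand_byte_embeddings(byte_ids, byte_emb, subword_vocab, subword_emb, max_len=16):
    text = bytes(byte_ids)
    out = byte_emb[np.array(byte_ids)].copy()             # [seq_len, hidden]
    for i in range(len(byte_ids)):
        for start in range(max(0, i + 1 - max_len), i + 1):
            piece = text[start : i + 1].decode("utf-8", errors="ignore")
            if piece in subword_vocab:                     # first hit is the longest match
                out[i] += subword_emb[subword_vocab[piece]]
                break
    return out
```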
134 |
135 | ```diff
136 | - target_tokenizer_name=\'Qwen/Qwen2.5-1.5B:source=Qwen2:target=Llama3\' \
137 | + target_tokenizer_name=\'meta-llama/Llama-3.2-3B-Instruct:source=Llama3:conversion=byte\'
138 | ```
139 |
140 | The target tokenizer is now specified using a tokenizer spec which includes the conversion to bytes. See the [byteification](./byteification.md#tokenizer-spec) guide for details on this spec.
141 |
142 | ## Exporting the Model
143 |
144 | `tokenkit` uses a custom internal format to checkpoint model fine-tuning parameter diffs. To export a checkpoint to the HuggingFace format, run e.g.:
145 |
146 | ```bash
147 | # --with_pt exports the model in PyTorch format (in addition to Flax)
148 | python3 scripts/export_checkpoint.py \
149 |     --checkpoint=outputs/cross_tokenizer_distill/step_5000 \
150 | --output=checkpoints/llama3_to_qwen2_tokenizer_hf \
151 | --with_pt
152 | ```
153 |
154 | If you are exporting a model which has been trained with input ID expansion, you need to also specify which embeddings and tokenizer to use for expansion, e.g.:
155 |
156 | ```bash
157 | python3 scripts/export_checkpoint.py \
158 |     --checkpoint=outputs/cross_tokenizer_distill/step_5000 \
159 | --output=checkpoints/llama3_to_bytes \
160 | --with_pt \
161 | --expand_input_ids_model=benjamin/Llama-3.2-3B-Instruct-flax \
162 | --expand_input_ids_tokenizer=meta-llama/Llama-3.2-3B-Instruct:source=Llama3
163 | ```
164 |
165 | Afterwards, you can load the model as usual using HuggingFace transformers:
166 |
167 | ```python
168 | from transformers import AutoModelForCausalLM
169 | from tokenkit.byteify import load_byteify_tokenizer
170 |
171 | model = AutoModelForCausalLM.from_pretrained("checkpoints/llama3_to_bytes", trust_remote_code=True)
172 | tokenizer = load_byteify_tokenizer("meta-llama/Llama-3.2-3B-Instruct:source=Llama3:conversion=byte")
173 |
174 | tokens = tokenizer.tokenizer.apply_chat_template([{"role": "user", "content": "Hello, how are you?"}], return_tensors="pt")
175 | output = model.generate(tokens)
176 | print(tokenizer.decode(output[0]))
177 | ```
178 |
179 | ## To Keep or to Change the Special Tokens?
180 |
181 | In the above example where we transferred to the Qwen2 tokenizer, by using the target tokenizer spec `Qwen/Qwen2.5-1.5B:source=Qwen2:target=Llama3`, we transferred to a tokenizer using all the *regular* tokens from the Qwen2 tokenizer, but keeping the special tokens (and the chat template) from the Llama3 tokenizer. We can instead transfer to a tokenizer which is completely equivalent to the Qwen2 tokenizer (regular and special tokens) by specifying it as `Qwen/Qwen2.5-1.5B:source=Qwen2:target=Qwen2`. What to choose depends on your use case:
182 |
183 | - *Keeping the special tokens:* This is the safer choice, since the model will not have to learn to use a new chat template format with new special tokens. If you just want to, for example, transfer to a new tokenizer which encodes some domain more efficiently, this is the better choice.
184 | - *Changing the special tokens:* If you are using tokenizer transfer to combine (e.g., ensemble) multiple models, this is more convenient since we don't need to worry about aligning the different special tokens and chat templates to each other (which is quite easy to do, but still inconvenient). However, there are some things to be careful about: for example, transferring Gemma2 to the Llama3 chat template is quite easy since both use similar formats and both use a BOS token. However, transferring Gemma2 to the Qwen2 chat template is not as straightforward *since Gemma2 uses a BOS token, but Qwen2 doesn't*. The model thus has to learn to re-distribute the original attention sink behavior of the BOS token across other tokens. This may or may not work well, depending on the training budget, dataset and so on.
185 |
186 | ---
187 |
--------------------------------------------------------------------------------
/examples/gemma2_distill_from_openmath2-llama_gpu.sh:
--------------------------------------------------------------------------------
1 | NAME=gemma2_distill_from_openmath2-llama
2 | python3 scripts/cross_tokenizer_distill.py \
3 | --config=configs/math_cross_tokenizer_distill.yaml \
4 | --overrides \
5 | losses=[sft,alm_unconstrained] \
6 | alm_mode=merge_by_space_prob+append_space \
7 | tokenizer_pair_bias_threshold=0.1 \
8 | max_teacher_length=1024 \
9 | max_student_length=1024 \
10 | n_data_parallel=1 \
11 | n_model_parallel=1 \
12 | steps=5000 \
13 | eval_interval=5000 \
14 | save_interval=5000 \
15 | optimizer.learning_rate=5.e-6 \
16 | optimizer.weight_decay=0.0 \
17 | optimizer.max_grad_norm=null \
18 | optimizer.grad_acc_steps=4 \
19 | eval.tasks=[math_500_openmath2,gsm8k_openmath2] \
20 | eval.lengths=[2048] \
21 | eval.tokens_per_batch=16384 \
22 | eval.chat_template_mode=direct_encode_no_force_eos \
23 | data.batch_size=64 \
24 | log_interval=10 \
25 | sync_interval=100 \
26 | use_chat_template=true \
27 | chat_template_mode=direct_encode \
28 | hypernet.architecture=identity \
29 | train_embeddings=true \
30 | train_model_mode=full \
31 | eval_at_step_zero=false \
32 | save_at_step_zero=false \
33 | skip_lm_eval=false \
34 | num_workers=24 \
35 | name=$NAME
36 |
--------------------------------------------------------------------------------
/examples/llama3_to_byte_tokenizer_gpu.sh:
--------------------------------------------------------------------------------
1 | NAME=llama3_to_byte
2 | python3 scripts/cross_tokenizer_distill.py \
3 | --config=configs/cross_tokenizer_distill.yaml \
4 | --overrides \
5 | losses=[sft,alm_unconstrained,alm_latents] \
6 | alm_mode=merge_by_space_prob+append_space \
7 | tokenizer_pair_bias_threshold=0.1 \
8 | hypernet.architecture=identity \
9 | multitask_aggregation_fn=approx_gradmag_preserve_mag \
10 | train_model_mode=full \
11 | expand_input_ids=true \
12 | output_embeddings_mode=untie \
13 | n_data_parallel=1 \
14 | n_model_parallel=1 \
15 | steps=5000 \
16 | eval_interval=1000 \
17 | save_interval=1000 \
18 | data.batch_size=64 \
19 | optimizer.grad_acc_steps=4 \
20 | data.num_workers=16 \
21 | data.batch_size=64 \
22 | student.pretrained_model_name_or_path="benjamin/Llama-3.2-3B-Instruct-flax" \
23 | student.tokenizer_name=\'meta-llama/Llama-3.2-3B-Instruct:source=Llama3\' \
24 | target_tokenizer_name=\'meta-llama/Llama-3.2-3B-Instruct:source=Llama3:conversion=byte\' \
25 | num_workers=16 \
26 | name=$NAME
--------------------------------------------------------------------------------
/examples/llama3_to_qwen2_tokenizer_gpu.sh:
--------------------------------------------------------------------------------
1 | NAME=llama3_to_qwen2_tokenizer
2 | python3 scripts/cross_tokenizer_distill.py \
3 | --config=configs/cross_tokenizer_distill.yaml \
4 | --overrides \
5 | losses=[sft,alm_unconstrained] \
6 | alm_mode=merge_by_space_prob+append_space \
7 | tokenizer_pair_bias_threshold=0.1 \
8 | hypernet.architecture=identity \
9 | multitask_aggregation_fn=approx_gradmag_preserve_mag \
10 | train_model_mode=lora \
11 | model_lora_rank=64 \
12 | model_lora_alpha=64 \
13 | n_data_parallel=1 \
14 | n_model_parallel=1 \
15 | steps=5000 \
16 | eval_interval=1000 \
17 | save_interval=1000 \
18 | data.batch_size=64 \
19 | optimizer.grad_acc_steps=4 \
20 | data.num_workers=16 \
21 | data.batch_size=64 \
22 | student.pretrained_model_name_or_path="benjamin/Llama-3.2-3B-Instruct-flax" \
23 | student.tokenizer_name=\'meta-llama/Llama-3.2-3B-Instruct:source=Llama3\' \
24 | target_tokenizer_name=\'Qwen/Qwen2.5-1.5B:source=Qwen2:target=Llama3\' \
25 | num_workers=16 \
26 | name=$NAME
--------------------------------------------------------------------------------
/examples/llama3_to_qwen2_tokenizer_tpu.sh:
--------------------------------------------------------------------------------
1 | # tpuv3-8
2 | N_DATA_PARALLEL=1
3 | N_MODEL_PARALLEL=8
4 |
5 | # tpuv4-8
6 | # N_DATA_PARALLEL=1
7 | # N_MODEL_PARALLEL=4
8 |
9 | # tpuv4-32
10 | # N_DATA_PARALLEL=4
11 | # N_MODEL_PARALLEL=4
12 |
13 |
14 | NAME=llama3_to_qwen2_tokenizer
15 | python3 scripts/cross_tokenizer_distill.py \
16 | --config=configs/cross_tokenizer_distill.yaml \
17 | --overrides \
18 | losses=[sft,alm_unconstrained] \
19 | alm_mode=merge_by_space_prob+append_space \
20 | tokenizer_pair_bias_threshold=0.1 \
21 | hypernet.architecture=identity \
22 | multitask_aggregation_fn=approx_gradmag_preserve_mag \
23 | train_model_mode=lora \
24 | model_lora_rank=64 \
25 | model_lora_alpha=64 \
26 | n_data_parallel=$N_DATA_PARALLEL \
27 | n_model_parallel=$N_MODEL_PARALLEL \
28 | steps=5000 \
29 | eval_interval=1000 \
30 | save_interval=1000 \
31 | data.batch_size=64 \
32 | optimizer.grad_acc_steps=4 \
33 | data.num_workers=16 \
34 | data.batch_size=64 \
35 | student.pretrained_model_name_or_path="benjamin/Llama-3.2-3B-Instruct-flax" \
36 | student.tokenizer_name=\'meta-llama/Llama-3.2-3B-Instruct:source=Llama3\' \
37 | target_tokenizer_name=\'Qwen/Qwen2.5-1.5B:source=Qwen2:target=Llama3\' \
38 | num_workers=16 \
39 | name=$NAME
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [tool.black]
2 | line-length = 88
3 | target-version = ['py37']
4 | include = '\.pyi?$'
5 | extend-exclude = '''
6 | # A regex preceded with ^/ will apply only to files and directories
7 | # in the root of the project.
8 | ^/build/
9 | '''
10 |
11 | [tool.isort]
12 | profile = "black"
13 | multi_line_output = 3
14 | include_trailing_comma = true
15 | force_grid_wrap = 0
16 | use_parentheses = true
17 | ensure_newline_before_comments = true
18 | line_length = 88
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | hydra-core==1.3.2
2 | omegaconf==2.3.0
3 | tokenizers==0.20.3
4 | datasets==3.2.0
5 | scipy==1.14.1
6 | google-cloud-storage==2.19.0
7 | ipython==8.28.0
8 | ipdb==0.13.13
9 | flax==0.10.2
10 | torchdata==0.8.0
11 | wandb==0.19.1
12 | pytest==8.3.5
13 | lingvo==0.12.7
14 | optax==0.2.4
15 | numpy==1.26.4
16 | fiddle==0.3.0
17 | jaxtyping==0.2.36
18 | typeguard==4.4.1
19 | jax-bitempered-loss==0.0.2
20 | editdistance==0.8.1
21 | langdetect==1.0.9
22 | immutabledict==4.2.1
23 | maturin==1.8.3
24 | transformers==4.46.0
25 | git+https://github.com/google/CommonLoopUtils@307b0bc65ae2e2d801b6c19df2431c060a0aa4ec
--------------------------------------------------------------------------------
/rust_utils/.gitignore:
--------------------------------------------------------------------------------
1 | target
--------------------------------------------------------------------------------
/rust_utils/Cargo.toml:
--------------------------------------------------------------------------------
1 | [package]
2 | name = "rust_utils"
3 | version = "0.14.1-dev.0"
4 | edition = "2021"
5 |
6 | [lib]
7 | name = "rust_utils"
8 | crate-type = ["cdylib"]
9 |
10 | [dependencies]
11 | pyo3 = { version = "0.19", features = ["serde"]}
12 | onig = { version = "6.0", default-features = false }
13 | rand = "0.8"
14 | rand_distr = "0.4.3"
15 |
16 | [dependencies.tokenizers]
17 | version = "0.20.3"
18 | default-features = false
19 | features = ["onig"]
20 |
21 | [dev-dependencies]
22 | tempfile = "3.1"
23 | pyo3 = { version = "0.19", features = ["auto-initialize"] }
24 |
25 | [features]
26 | default = ["pyo3/extension-module"]
27 |
--------------------------------------------------------------------------------
/rust_utils/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | requires = ["maturin>=1.4,<2.0"]
3 | build-backend = "maturin"
4 |
5 | [project]
6 | name = "rust_utils"
7 | requires-python = ">=3.8"
8 | classifiers = [
9 | "Programming Language :: Rust",
10 | "Programming Language :: Python :: Implementation :: CPython",
11 | "Programming Language :: Python :: Implementation :: PyPy",
12 | ]
13 | dynamic = ["version"]
14 |
15 | [tool.maturin]
16 | features = ["pyo3/extension-module"]
17 |
--------------------------------------------------------------------------------
/rust_utils/rust-toolchain:
--------------------------------------------------------------------------------
1 | stable
2 |
--------------------------------------------------------------------------------
/rust_utils/tpu_build.sh:
--------------------------------------------------------------------------------
1 | # otherwise use 'maturin develop --release'
2 | set -e
3 |
4 | . ../tokenkit_env/bin/activate
5 | python3 -m maturin build --release
6 | mv $PWD/target/wheels/rust_utils-0.14.1.dev0-cp310-cp310-manylinux_2_34_x86_64.whl $PWD/target/wheels/rust_utils-0.14.1.dev0-cp310-none-any.whl
7 | pip install --force-reinstall $PWD/target/wheels/rust_utils-0.14.1.dev0-cp310-none-any.whl
8 | #pip install fsspec==2023.9.2
9 | #pip install --upgrade huggingface_hub datasets
10 | # TODO: do we need the above? if yes, pin versions / check why
--------------------------------------------------------------------------------
/scripts/compute_mined_mapping.py:
--------------------------------------------------------------------------------
1 | import json
2 | from pathlib import Path
3 |
4 | import numpy as np
5 | from dataclasses import dataclass
6 |
7 | from tokenkit import baseline_utils
8 | from tokenkit.byteify import load_byteify_tokenizer
9 | from tokenkit import parse_args
10 |
11 | @dataclass
12 | class ComputedMinedMappingArgs:
13 | teacher_tokenizer_name: str
14 | target_tokenizer_name: str
15 | output: str
16 | num_workers: int
17 |
18 | def main(args: ComputedMinedMappingArgs) -> None:
19 | output_dir = Path(args.output)
20 | output_dir.mkdir(exist_ok=True, parents=True)
21 |
22 | tokenizer_teacher = load_byteify_tokenizer(args.teacher_tokenizer_name)
23 | target_tokenizer = load_byteify_tokenizer(args.target_tokenizer_name)
24 |
25 | mined_mapping, mined_distances = baseline_utils.compute_mined_mapping(
26 | tokenizer_teacher, target_tokenizer, num_workers=args.num_workers
27 | )
28 |
29 | np.save(output_dir / "mined_mapping.npy", mined_mapping)
30 | json.dump(
31 | mined_distances,
32 | open(output_dir / "mined_distances.json", "w"),
33 | indent=4,
34 | )
35 |
36 |
37 | if __name__ == "__main__":
38 | main(parse_args.parse_args(ComputedMinedMappingArgs))
39 |
--------------------------------------------------------------------------------
/scripts/compute_tokenizer_info.py:
--------------------------------------------------------------------------------
1 | """
2 | Example Usage:
3 |
4 | ipython --pdb scripts/compute_tokenizer_info.py -- \
5 | teacher_tokenizer_name=google/gemma-2-2b-it:source=Gemma2 \
6 | target_tokenizer_name=Qwen/Qwen2.5-1.5B:source=Qwen2:target=Gemma2 \
7 | output='outputs/tokenizer_data/gemma2_to_qwen2_new'
8 | """
9 |
10 | import json
11 | import os
12 | import pickle
13 | from collections import Counter
14 | from functools import partial
15 | from pathlib import Path
16 |
17 | import hydra
18 | import numpy as np
19 | from scipy import sparse
20 | from torch.utils.data import DataLoader
21 | from tqdm.auto import tqdm
22 |
23 | from tokenkit import data
24 | from tokenkit.byteify import load_byteify_tokenizer
25 |
26 |
27 | def compute_prefix_map(tokenizer):
28 | prefix_map = {}
29 |
30 | for token in tokenizer.get_vocab().keys():
31 | for i in range(1, len(token) + 1):
32 | if token[:i] in prefix_map:
33 | prefix_map[token[:i]].append(token)
34 | else:
35 | prefix_map[token[:i]] = [token]
36 |
37 | return prefix_map
38 |
39 |
40 | def is_valid(tokens, tokenizer):
41 | try:
42 | return tokenizer.backend_tokenize("".join(tokens)) == tokens
43 | except UnicodeDecodeError:
44 | return False
45 |
46 |
47 | def compute_cover_set(pretoken, tokenizer, prefix_map):
48 | cover_set = []
49 | for i in range(len(pretoken) - 1, -1, -1):
50 | B = prefix_map.get(pretoken[i:], [])
51 | try:
52 | tcur = tokenizer.backend_tokenize(pretoken[:i])
53 | except UnicodeDecodeError:
54 | continue
55 |
56 | for b in B:
57 | if is_valid(tcur + [b], tokenizer):
58 | cover_set.append(tcur + [b])
59 |
60 | return cover_set
61 |
62 |
63 | def compute_cover_dict(pretoken, tokenizer, prefix_map):
64 | cover_set = compute_cover_set(pretoken, tokenizer, prefix_map)
65 | cover_dict = {}
66 |
67 | for seq in cover_set:
68 | joined_seq = "".join(seq)[len(pretoken) :]
69 | if len(joined_seq) == 0:
70 | continue
71 |
72 | cover_dict[joined_seq] = tokenizer.convert_tokens_to_ids(seq)
73 |
74 | return cover_dict
75 |
76 |
77 | def compute_pair_bias(
78 | pretoken1,
79 | pretoken2,
80 | tokenizer1,
81 | tokenizer2,
82 | prefix_map1,
83 | prefix_map2,
84 | probs1,
85 | probs2,
86 | return_diff_cover_dicts=False,
87 | ):
88 | cover_dict1 = compute_cover_dict(pretoken1, tokenizer1, prefix_map1)
89 | cover_dict2 = compute_cover_dict(pretoken2, tokenizer2, prefix_map2)
90 |
91 | diff_keys1 = set(cover_dict1.keys()) - set(cover_dict2.keys())
92 | diff_keys2 = set(cover_dict2.keys()) - set(cover_dict1.keys())
93 |
94 | bias1 = 0.0
95 | for key in diff_keys1:
96 | bias1 += probs1[cover_dict1[key][-1]]
97 |
98 | bias2 = 0.0
99 | for key in diff_keys2:
100 | bias2 += probs2[cover_dict2[key][-1]]
101 |
102 | if return_diff_cover_dicts:
103 | diff_cover_set1 = {key: probs1[cover_dict1[key][-1]] for key in diff_keys1}
104 | diff_cover_set2 = {key: probs2[cover_dict2[key][-1]] for key in diff_keys2}
105 | return bias1, bias2, diff_cover_set1, diff_cover_set2
106 | else:
107 | return bias1, bias2
108 |
109 |
110 | def count_tokens_map(examples, tokenizer):
111 | flat_input_ids = [
112 | input_id
113 | for input_ids in tokenizer(examples["text"], add_special_tokens=False)[
114 | "input_ids"
115 | ]
116 | for input_id in input_ids
117 | ]
118 | return {
119 | "counter": pickle.dumps(Counter(flat_input_ids)),
120 | }
121 |
122 |
123 | def count_tokens(dset, tokenizer, num_workers):
124 | token_counters_dset = dset.map(
125 | partial(count_tokens_map, tokenizer=tokenizer),
126 | batched=False, # already batched
127 | num_proc=num_workers if num_workers > 0 else None,
128 | remove_columns=dset.column_names,
129 | desc="Counting tokens",
130 | )
131 |
132 | global_token_counter = Counter()
133 | for i in tqdm(range(len(token_counters_dset)), desc="Merging token counters"):
134 | global_token_counter.update(pickle.loads(token_counters_dset[i]["counter"]))
135 |
136 | return global_token_counter
137 |
138 |
139 | if __name__ == "__main__":
140 | args = None
141 |
142 | os.environ["TOKENIZERS_PARALLELISM"] = "false"
143 |
144 | def _parse_args_fn(local_args):
145 | global args
146 | args = local_args
147 |
148 | hydra.main(
149 | version_base=None,
150 | config_path="../configs",
151 | config_name="compute_tokenizer_info",
152 | )(_parse_args_fn)()
153 |
154 | output_dir = Path(args.output)
155 | output_dir.mkdir(exist_ok=True, parents=True)
156 |
157 | tokenizer_teacher = load_byteify_tokenizer(args.teacher_tokenizer_name)
158 | target_tokenizer = load_byteify_tokenizer(args.target_tokenizer_name)
159 |
160 | dset = data.get_dataset(**args.data, seed=args.seed)
161 |
162 | if not (output_dir / "teacher_counts.json").exists():
163 | assert isinstance(dset, data.HFDataset)
164 | dset_to_use = dset.stream
165 |
166 | if args.teacher_subsample_percent is not None:
167 | n_subsample = int(len(dset_to_use) * args.teacher_subsample_percent)
168 | dset_to_use = dset_to_use.select(np.arange(n_subsample))
169 |
170 | teacher_token_counts = count_tokens(
171 | dset_to_use, tokenizer_teacher, args.data.num_workers
172 | )
173 | json.dump(
174 | teacher_token_counts,
175 | open(output_dir / "teacher_counts.json", "w"),
176 | indent=4,
177 | )
178 | else:
179 | teacher_token_counts = Counter(
180 | json.load(open(output_dir / "teacher_counts.json"))
181 | )
182 | if not (output_dir / "student_counts.json").exists():
183 | assert isinstance(dset, data.HFDataset)
184 | dset_to_use = dset.stream
185 |
186 | if args.student_subsample_percent is not None:
187 | n_subsample = int(len(dset_to_use) * args.student_subsample_percent)
188 | dset_to_use = dset_to_use.select(np.arange(n_subsample))
189 |
190 | student_token_counts = count_tokens(
191 | dset_to_use, target_tokenizer, args.data.num_workers
192 | )
193 | json.dump(
194 | student_token_counts,
195 | open(output_dir / "student_counts.json", "w"),
196 | indent=4,
197 | )
198 | else:
199 | student_token_counts = Counter(
200 | json.load(open(output_dir / "student_counts.json"))
201 | )
202 |
203 | if not (output_dir / "pairs.json").exists():
204 | teacher_tokens_dict = {}
205 | for token in sorted(
206 | tokenizer_teacher.get_vocab().keys(), key=lambda x: x[::-1]
207 | ):
208 | if token[-1] not in teacher_tokens_dict:
209 | teacher_tokens_dict[token[-1]] = []
210 |
211 | teacher_tokens_dict[token[-1]].append(token)
212 |
213 | student_tokens_dict = {}
214 | for token in sorted(target_tokenizer.get_vocab().keys(), key=lambda x: x[::-1]):
215 | if token[-1] not in student_tokens_dict:
216 | student_tokens_dict[token[-1]] = []
217 |
218 | student_tokens_dict[token[-1]].append(token)
219 |
220 | pairs = []
221 | for last_byte in tqdm(
222 | set(teacher_tokens_dict.keys()) & set(student_tokens_dict.keys())
223 | ):
224 | for teacher_token in teacher_tokens_dict[last_byte]:
225 | for student_token in student_tokens_dict[last_byte]:
226 | if teacher_token.endswith(student_token) or student_token.endswith(
227 | teacher_token
228 | ):
229 | pairs.append((teacher_token, student_token))
230 |
231 | json.dump(pairs, open(output_dir / "pairs.json", "w"), indent=4)
232 | else:
233 | pairs = json.load(open(output_dir / "pairs.json"))
234 |
235 | print(f"Found {len(pairs)} pairs")
236 |
237 | prefix_map_teacher = compute_prefix_map(tokenizer_teacher)
238 | prefix_map_student = compute_prefix_map(target_tokenizer)
239 |
240 | teacher_counts_sum = sum(teacher_token_counts.values())
241 | teacher_token_probs = np.array(
242 | [
243 | teacher_token_counts[token_id]
244 | + args.additive_smoothing_constant * teacher_counts_sum
245 | for token_id in range(len(tokenizer_teacher))
246 | ],
247 | dtype=np.float32,
248 | )
249 | teacher_token_probs /= teacher_token_probs.sum()
250 |
251 | student_counts_sum = sum(student_token_counts.values())
252 | student_token_probs = np.array(
253 | [
254 | student_token_counts[token_id]
255 | + args.additive_smoothing_constant * student_counts_sum
256 | for token_id in range(len(target_tokenizer))
257 | ],
258 | dtype=np.float32,
259 | )
260 | student_token_probs /= student_token_probs.sum()
261 |
262 | is_space_only_teacher = {
263 | tokenizer_teacher.convert_ids_to_tokens(i): len(
264 | tokenizer_teacher.decode(i).strip()
265 | )
266 | == 0
267 | for i in range(len(tokenizer_teacher))
268 | }
269 | is_space_only_student = {
270 | target_tokenizer.convert_ids_to_tokens(i): len(
271 | target_tokenizer.decode(i).strip()
272 | )
273 | == 0
274 | for i in range(len(target_tokenizer))
275 | }
276 |
277 | def pair_collate(pairs):
278 | biases1 = []
279 | biases2 = []
280 | for pair in pairs:
281 | if is_space_only_teacher[pair[0]] or is_space_only_student[pair[1]]:
282 | biases1.append(1.0) # can take long to compute and likely high
283 | biases2.append(1.0)
284 | continue
285 | bias1, bias2 = compute_pair_bias(
286 | *pair,
287 | tokenizer_teacher,
288 | target_tokenizer,
289 | prefix_map_teacher,
290 | prefix_map_student,
291 | teacher_token_probs,
292 | student_token_probs,
293 | )
294 | biases1.append(bias1)
295 | biases2.append(bias2)
296 |
297 | return {
298 | "biases1": biases1,
299 | "biases2": biases2,
300 | }
301 |
302 | pair_permutation = np.random.permutation(len(pairs))
303 | inv_pair_permutation = np.argsort(pair_permutation)
304 |
305 | biases1 = []
306 | biases2 = []
307 | pair_data_loader = DataLoader(
308 | [pairs[i] for i in pair_permutation],
309 | batch_size=args.data.batch_size,
310 | num_workers=args.data.num_workers,
311 | collate_fn=pair_collate,
312 | )
313 |
314 | for batch in tqdm(pair_data_loader, desc="Computing pair biases"):
315 | biases1.extend(batch["biases1"])
316 | biases2.extend(batch["biases2"])
317 |
318 | biases1 = np.array(biases1)[inv_pair_permutation]
319 | biases2 = np.array(biases2)[inv_pair_permutation]
320 |
321 | bias1_matrix = sparse.coo_matrix(
322 | (
323 | biases1,
324 | (
325 | np.array(tokenizer_teacher.convert_tokens_to_ids(x[0] for x in pairs)),
326 | np.array(target_tokenizer.convert_tokens_to_ids(x[1] for x in pairs)),
327 | ),
328 | ),
329 | shape=(len(tokenizer_teacher), len(target_tokenizer)),
330 | )
331 | bias2_matrix = sparse.coo_matrix(
332 | (
333 | biases2,
334 | (
335 | np.array(tokenizer_teacher.convert_tokens_to_ids(x[0] for x in pairs)),
336 | np.array(target_tokenizer.convert_tokens_to_ids(x[1] for x in pairs)),
337 | ),
338 | ),
339 | shape=(len(tokenizer_teacher), len(target_tokenizer)),
340 | )
341 |
342 | sparse.save_npz(output_dir / "bias1_matrix.npz", bias1_matrix)
343 | sparse.save_npz(output_dir / "bias2_matrix.npz", bias2_matrix)
344 |
--------------------------------------------------------------------------------
/scripts/eval.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import os
3 | from functools import partial
4 | from pathlib import Path
5 | from pprint import pformat, pprint
6 | import yaml
7 |
8 | import datasets
9 | import jax
10 | import jax.numpy as jnp
11 | import numpy as np
12 | from jax.experimental import multihost_utils
13 | from jax.sharding import NamedSharding
14 | from jax.sharding import PartitionSpec as P
15 | from transformers import FlaxAutoModelForCausalLM
16 | from dataclasses import dataclass, asdict
17 |
18 | from tokenkit.hf import get_config
19 | from tokenkit import utils, parse_args
20 | from tokenkit.byteify import load_byteify_tokenizer
21 | from tokenkit.eval import ATOL, evaluate, score
22 | from tokenkit.models import param, sharding
23 |
24 | logger = logging.getLogger(__name__)
25 |
26 | datasets.config.HF_DATASETS_TRUST_REMOTE_CODE = (
27 | True # careful about this, required for lm_eval
28 | )
29 |
30 | @dataclass
31 | class EvalScriptArgs:
32 | model: parse_args.ModelArgs
33 | expand_model: parse_args.ModelArgs
34 | eval: parse_args.EvalArgs
35 | output: str | None = None
36 | pad_to_multiple_of: int = 128
37 | use_cpu: bool = False
38 |
39 |
40 | def pad_embeddings(embeddings, tokenizer):
41 | n_embed_diff = len(tokenizer) - len(embeddings)
42 |
43 | embeddings_mean = embeddings.mean(0)
44 | embeddings_std = embeddings.std(0)
45 |
46 | return np.concatenate(
47 | [
48 | embeddings,
49 | np.random.normal(
50 | size=(n_embed_diff, *embeddings.shape[1:]),
51 | )
52 | * embeddings_std[None]
53 | + embeddings_mean[None],
54 | ]
55 | )
56 |
57 |
58 | def main(args: EvalScriptArgs) -> None:
59 | logger.info(pformat(args))
60 |
61 | model_kwargs = asdict(args.model)
62 | eval_kwargs = asdict(args.eval)
63 |
64 | if args.use_cpu:
65 | jax.config.update("jax_default_device", jax.devices("cpu")[0])
66 | mesh = sharding.get_mesh(devices=jax.devices("cpu"))
67 | else:
68 | mesh = sharding.get_mesh()
69 |
70 | if args.output is not None:
71 | output_dir = Path(args.output)
72 | output_dir.mkdir(parents=True, exist_ok=True)
73 |
74 | with open(output_dir / "args.yaml", "w") as f:
75 | yaml.dump(asdict(args), f)
76 | else:
77 | output_dir = None
78 |
79 | tokenizer = load_byteify_tokenizer(model_kwargs.pop("tokenizer_name"))
80 |
81 | config = get_config(**model_kwargs)
82 | config.max_length = eval_kwargs["lengths"][-1]
83 | config.mesh = mesh
84 |
85 | model = FlaxAutoModelForCausalLM.from_config(config, _do_init=False)
86 |
87 | params = param.load_params(**model_kwargs)
88 |
89 | if args.expand_model.pretrained_model_name_or_path is not None:
90 | expand_model_kwargs = asdict(args.expand_model)
91 |
92 | expand_tokenizer = load_byteify_tokenizer(
93 | expand_model_kwargs.pop("tokenizer_name")
94 | )
95 | expand_config = get_config(**expand_model_kwargs)
96 | expand_vocab = expand_tokenizer.get_vocab()
97 |
98 | expand_input_ids_model_params = param.load_params(**expand_model_kwargs)
99 | expand_input_ids_embeddings = param.get(
100 | expand_input_ids_model_params,
101 | param.get_input_embedding_path(expand_config.model_type),
102 | )
103 |
104 | n_overflow = expand_input_ids_embeddings.shape[0] % args.pad_to_multiple_of
105 | if n_overflow > 0:
106 | n_pad = args.pad_to_multiple_of - n_overflow
107 | else:
108 | n_pad = 0
109 |
110 | expand_input_ids_embeddings = np.pad(
111 | expand_input_ids_embeddings,
112 | ((0, n_pad), (0, 0)),
113 | mode="constant",
114 | constant_values=0,
115 | )
116 | else:
117 | expand_tokenizer = None
118 | expand_vocab = None
119 | expand_input_ids_embeddings = None
120 |
121 | input_embeddings = param.get(
122 | params, param.get_input_embedding_path(config.model_type)
123 | )
124 | input_embeddings = input_embeddings[: len(tokenizer)]
125 |
126 | if len(input_embeddings) < len(tokenizer):
127 | print("Padding input embeddings...")
128 | input_embeddings = pad_embeddings(input_embeddings, tokenizer)
129 |
130 | if not config.tie_word_embeddings:
131 | output_embeddings = param.get(
132 | params, param.get_output_embedding_path(config.model_type)
133 | )
134 | output_embeddings = output_embeddings[:, : len(tokenizer)]
135 | print("Padding output embeddings...")
136 | output_embeddings = pad_embeddings(output_embeddings.T, tokenizer).T
137 | else:
138 | output_embeddings = None
139 |
140 | n_overflow = input_embeddings.shape[0] % args.pad_to_multiple_of
141 | if n_overflow > 0:
142 | n_pad = args.pad_to_multiple_of - n_overflow
143 | else:
144 | n_pad = 0
145 |
146 | input_embeddings = np.pad(
147 | input_embeddings,
148 | ((0, n_pad), (0, 0)),
149 | mode="constant",
150 | constant_values=0,
151 | )
152 | if output_embeddings is not None:
153 | output_embeddings = np.pad(
154 | output_embeddings,
155 | ((0, 0), (0, n_pad)),
156 | mode="constant",
157 | constant_values=0,
158 | )
159 | logit_mask = np.zeros((input_embeddings.shape[0],), dtype=bool)
160 | logit_mask[: model.config.vocab_size] = True
161 |
162 | model.config.vocab_size = input_embeddings.shape[0]
163 |
164 | params = param.put(
165 | params, param.get_input_embedding_path(config.model_type), input_embeddings
166 | )
167 | if output_embeddings is not None:
168 | params = param.put(
169 | params,
170 | param.get_output_embedding_path(config.model_type),
171 | output_embeddings,
172 | )
173 |
174 | if expand_input_ids_embeddings is not None:
175 | # expects stacked embedding format
176 | params["original_embeddings"] = expand_input_ids_embeddings[:, None, :]
177 |
178 | shard_patterns = sharding.get_shard_patterns(config.model_type)
179 | param_shardings = sharding.get_sharding_fn(shard_patterns, mesh)(
180 | {"params": params}
181 | )["params"]
182 | params = sharding.to_devices(params, param_shardings, dtype=jnp.float32)
183 |
184 | multihost_utils.sync_global_devices("loaded weights")
185 |
186 | jaxlm_kwargs = {"precompile": not args.use_cpu}
187 |
188 | if args.expand_model.pretrained_model_name_or_path is not None:
189 | # TODO: move elsewhere, probably into jaxlm
190 | expand_input_ids_dict = utils.get_expand_input_ids_dict(
191 | tokenizer,
192 | expand_vocab,
193 | )
194 |
195 | def compute_inputs_embeds(model_params, input_ids, expanded_input_ids):
196 | input_embeddings = param.get(
197 | model_params, param.get_input_embedding_path(config.model_type)
198 | )
199 |
200 | standard_inputs_embeds = jnp.take(
201 | input_embeddings,
202 | input_ids,
203 | axis=0,
204 | )
205 | expanded_inputs_embeds = jnp.take(
206 | expand_input_ids_embeddings,
207 | expanded_input_ids,
208 | axis=0,
209 | )
210 |
211 | inputs_embeds = standard_inputs_embeds + expanded_inputs_embeds
212 |
213 | return inputs_embeds
214 |
215 | @partial(
216 | jax.jit,
217 | static_argnames=("model_fn", "atol"),
218 | in_shardings=(
219 | param_shardings,
220 | NamedSharding(mesh, P()),
221 | NamedSharding(mesh, P()),
222 | NamedSharding(mesh, P()),
223 | NamedSharding(mesh, P()),
224 | NamedSharding(mesh, P()),
225 | NamedSharding(mesh, P()),
226 | ),
227 | out_shardings=(NamedSharding(mesh, P()), NamedSharding(mesh, P())),
228 | )
229 | def jaxlm_inner_score_fn(
230 | model_fn,
231 | params,
232 | input_ids,
233 | expanded_input_ids,
234 | labels,
235 | suffix_mask,
236 | space_mask,
237 | logit_mask,
238 | atol=ATOL,
239 | ):
240 | inputs_embeds = compute_inputs_embeds(
241 | params,
242 | input_ids,
243 | expanded_input_ids,
244 | )
245 | return score(
246 | model_fn,
247 | params,
248 | (None, inputs_embeds),
249 | labels=labels,
250 | suffix_mask=suffix_mask,
251 | space_mask=space_mask,
252 | logit_mask=logit_mask,
253 | atol=atol,
254 | )
255 |
256 | def jaxlm_score_fn(model_fn, params, model_args, *pargs):
257 | (input_ids,) = model_args
258 |
259 | expanded_input_ids = utils.np_expand_input_ids(
260 | input_ids,
261 | expand_input_ids_dict,
262 | )
263 |
264 | return jaxlm_inner_score_fn(
265 | model_fn,
266 | params,
267 | input_ids,
268 | expanded_input_ids,
269 | *pargs,
270 | )
271 |
272 | jaxlm_kwargs["expand_input_ids"] = True
273 | jaxlm_kwargs["expand_input_ids_vocab"] = expand_vocab
274 | jaxlm_kwargs["score_fn"] = jaxlm_score_fn
275 |
276 | results, _ = evaluate(
277 | model=model,
278 | config=config,
279 | params=params,
280 | tokenizer=tokenizer,
281 | logit_mask=logit_mask,
282 | output=output_dir,
283 | **eval_kwargs,
284 | jaxlm_kwargs=jaxlm_kwargs,
285 | )
286 |
287 | if jax.process_index() == 0:
288 | pprint(results)
289 |
290 |
291 | if __name__ == "__main__":
292 | os.environ["HF_ALLOW_CODE_EVAL"] = "1"
293 | main(parse_args.parse_args(EvalScriptArgs))
--------------------------------------------------------------------------------
/scripts/eval_lockstep.py:
--------------------------------------------------------------------------------
1 | """
2 | Example Usage:
3 |
4 | python3 scripts/eval_lockstep.py models=llama_qwen +eval.limit=100
5 | """
6 |
7 | import logging
8 | from pathlib import Path
9 | from pprint import pformat, pprint
10 | from dataclasses import dataclass, asdict
11 | import os
12 | import yaml
13 | import datasets
14 | import jax
15 | import jax.numpy as jnp
16 | import numpy as np
17 | from jax.experimental import multihost_utils
18 | from transformers import FlaxAutoModelForCausalLM
19 |
20 | from tokenkit import parse_args
21 | from tokenkit.hf import get_config
22 | from tokenkit.byteify import load_byteify_tokenizer
23 | from tokenkit.eval import evaluate_lockstep
24 | from tokenkit.models import param, sharding
25 |
26 | logger = logging.getLogger(__name__)
27 |
28 | datasets.config.HF_DATASETS_TRUST_REMOTE_CODE = (
29 | True # careful about this, required for lm_eval
30 | )
31 |
32 | @dataclass
33 | class EvalLockstepScriptArgs:
34 | combine_strategy: str
35 | models: list[parse_args.ModelArgs]
36 | eval: parse_args.EvalArgs
37 | baseline_mined_mapping_paths: list[str] | None = None
38 | output: str | None = None
39 | pad_to_multiple_of: int = 128
40 | use_cpu: bool = False
41 |
42 |
43 | def pad_embeddings(embeddings, tokenizer):
44 | n_embed_diff = len(tokenizer) - len(embeddings)
45 |
46 | embeddings_mean = embeddings.mean(0)
47 | embeddings_std = embeddings.std(0)
48 |
49 | return np.concatenate(
50 | [
51 | embeddings,
52 | np.random.normal(
53 | size=(n_embed_diff, *embeddings.shape[1:]),
54 | )
55 | * embeddings_std[None]
56 | + embeddings_mean[None],
57 | ]
58 | )
59 |
60 |
61 | def main(args: EvalLockstepScriptArgs) -> None:
62 | logger.info(pformat(args))
63 |
64 | eval_kwargs = asdict(args.eval)
65 |
66 | if args.output is not None:
67 | output_dir = Path(args.output)
68 | output_dir.mkdir(parents=True, exist_ok=True)
69 |
70 | with open(output_dir / "args.yaml", "w") as f:
71 | yaml.dump(asdict(args), f)
72 | else:
73 | output_dir = None
74 |
75 | if args.use_cpu:
76 | jax.config.update("jax_default_device", jax.devices("cpu")[0])
77 | mesh = sharding.get_mesh(devices=jax.devices("cpu"))
78 | else:
79 | mesh = sharding.get_mesh()
80 |
81 | all_models = []
82 | all_configs = []
83 | all_params = []
84 | all_tokenizers = []
85 | all_logit_masks = []
86 |
87 | eval_kwargs.pop("add_bos")
88 | all_add_bos = []
89 |
90 | for model_idx, model_kwargs in enumerate(args.models):
91 | print("Loading model...")
92 |
93 | config = get_config(model_kwargs["pretrained_model_name_or_path"])
94 |
95 | config.max_length = eval_kwargs["lengths"][-1]
96 | config.mesh = mesh
97 |
98 | tokenizer = load_byteify_tokenizer(model_kwargs.pop("tokenizer_name"))
99 |
100 | model = FlaxAutoModelForCausalLM.from_config(config, _do_init=False)
101 | params = param.load_params(
102 | pretrained_model_name_or_path=model_kwargs["pretrained_model_name_or_path"]
103 | )
104 |
105 | input_embeddings = param.get(
106 | params, param.get_input_embedding_path(config.model_type)
107 | )
108 |
109 | if len(input_embeddings) < len(tokenizer):
110 | print("Padding input embeddings...")
111 | input_embeddings = pad_embeddings(input_embeddings, tokenizer)
112 |
113 | if not config.tie_word_embeddings:
114 | output_embeddings = param.get(
115 | params, param.get_output_embedding_path(config.model_type)
116 | )
117 | print("Padding output embeddings...")
118 | output_embeddings = pad_embeddings(output_embeddings.T, tokenizer).T
119 | else:
120 | output_embeddings = None
121 |
122 | n_overflow = input_embeddings.shape[0] % args.pad_to_multiple_of
123 | if n_overflow > 0:
124 | n_pad = args.pad_to_multiple_of - n_overflow
125 | else:
126 | n_pad = 0
127 |
128 | input_embeddings = np.pad(
129 | input_embeddings,
130 | ((0, n_pad), (0, 0)),
131 | mode="constant",
132 | constant_values=0,
133 | )
134 | if output_embeddings is not None:
135 | output_embeddings = np.pad(
136 | output_embeddings,
137 | ((0, 0), (0, n_pad)),
138 | mode="constant",
139 | constant_values=0,
140 | )
141 | logit_mask = np.zeros((input_embeddings.shape[0],), dtype=bool)
142 | logit_mask[: model.config.vocab_size] = True
143 | model.config.vocab_size = input_embeddings.shape[0]
144 |
145 | params = param.put(
146 | params, param.get_input_embedding_path(config.model_type), input_embeddings
147 | )
148 | if output_embeddings is not None:
149 | params = param.put(
150 | params,
151 | param.get_output_embedding_path(config.model_type),
152 | output_embeddings,
153 | )
154 |
155 | shard_patterns = sharding.get_shard_patterns(config.model_type)
156 | param_shardings = sharding.get_sharding_fn(shard_patterns, mesh)(
157 | {"params": params}
158 | )["params"]
159 | params = sharding.to_devices(params, param_shardings, dtype=jnp.float32)
160 |
161 | multihost_utils.sync_global_devices("loaded weights")
162 |
163 | if args.baseline_mined_mapping_paths is not None:
164 | if args.baseline_mined_mapping_paths[model_idx] is not None:
165 | config.mined_mapping = np.load(
166 | Path(args.baseline_mined_mapping_paths[model_idx]) / "mined_mapping.npy"
167 | )
168 | else:
169 | config.mined_mapping = None
170 |
171 | all_models.append(model)
172 | all_configs.append(config)
173 | all_params.append(params)
174 | all_tokenizers.append(tokenizer)
175 | all_logit_masks.append(logit_mask)
176 | all_add_bos.append(model_kwargs["add_bos"])
177 |
178 | # static combine fn for the moment
179 | def combine_fn(hidden_states, logits, combine_params, output_embeddings):
180 | if args.combine_strategy == "mean_prob":
181 | aggregated_probs = None
182 | for model_logits in logits:
183 | model_probs = jax.nn.softmax(model_logits, axis=-1)
184 | if aggregated_probs is None:
185 | aggregated_probs = model_probs
186 | else:
187 | aggregated_probs += model_probs
188 |
189 | aggregated_probs /= len(logits)
190 | return jnp.log(aggregated_probs)
191 | elif args.combine_strategy == "mean_logits":
192 | aggregated_logits = None
193 | for model_logits in logits:
194 | if aggregated_logits is None:
195 | aggregated_logits = model_logits
196 | else:
197 | aggregated_logits += model_logits
198 |
199 | aggregated_logits /= len(logits)
200 | return aggregated_logits
201 | else:
202 | raise ValueError(f"Unknown combine strategy: {args.combine_strategy}")
203 |
204 | results = evaluate_lockstep(
205 | models=all_models,
206 | configs=all_configs,
207 | params=all_params,
208 | tokenizers=all_tokenizers,
209 | logit_masks=all_logit_masks,
210 | add_bos=all_add_bos,
211 | combine_fn=combine_fn,
212 | combine_params={},
213 | jaxlm_kwargs={"precompile": not args.use_cpu},
214 | output=output_dir,
215 | **eval_kwargs,
216 | )
217 |
218 | if jax.process_index() == 0:
219 | pprint(results[0])
220 |
221 |
222 | if __name__ == "__main__":
223 | os.environ["HF_ALLOW_CODE_EVAL"] = "1"
224 | main(parse_args.parse_args(EvalLockstepScriptArgs))
--------------------------------------------------------------------------------
/scripts/export_checkpoint.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 | from transformers import (
3 | HfArgumentParser,
4 | AutoTokenizer,
5 | FlaxAutoModelForCausalLM,
6 | AutoModelForCausalLM,
7 | )
8 | from transformers.modeling_flax_pytorch_utils import load_flax_weights_in_pytorch_model
9 | from omegaconf import OmegaConf
10 | from flax import serialization, traverse_util
11 | from pathlib import Path
12 | from pickle import UnpicklingError
13 | from flax.serialization import from_bytes
14 | from flax.traverse_util import flatten_dict, unflatten_dict
15 | import jax
16 | import jax.numpy as jnp
17 | from tokenkit.models.hypernet import Hypernet
18 | from tokenkit.models import param, lora, sharding
19 | from tokenkit.byteify import load_byteify_tokenizer
20 | from tokenkit.hf import get_config
21 | from tokenkit import gcs_utils, utils, constants
22 | import json
23 | import os
24 | import torch
25 | from pprint import pformat
26 | import logging
27 |
28 | logger = logging.getLogger(__name__)
29 |
30 |
31 | # transformers `load_flax_checkpoint_in_pytorch_model` does not support custom models
32 | # so we patch it here
33 | def load_flax_checkpoint_in_pytorch_model(model, flax_checkpoint_path, flax_cls):
34 | """Load flax checkpoints in a PyTorch model"""
35 | flax_checkpoint_path = os.path.abspath(flax_checkpoint_path)
36 | logger.info(f"Loading Flax weights from {flax_checkpoint_path}")
37 |
38 | # load flax weight dict
39 | if flax_checkpoint_path.endswith(".safetensors"):
40 | from safetensors.flax import load_file as safe_load_file
41 |
42 | flax_state_dict = safe_load_file(flax_checkpoint_path)
43 | flax_state_dict = unflatten_dict(flax_state_dict, sep=".")
44 | else:
45 | with open(flax_checkpoint_path, "rb") as state_f:
46 | try:
47 | flax_state_dict = from_bytes(flax_cls, state_f.read())
48 | except UnpicklingError:
49 | raise EnvironmentError(f"Unable to convert {flax_checkpoint_path} to Flax deserializable object. ")
50 |
51 | return load_flax_weights_in_pytorch_model(model, flax_state_dict)
52 |
53 |
54 |
55 | @dataclass
56 | class Args:
57 | checkpoint: str = "outputs/patch"
58 | output: str = "outputs/export"
59 | use_cpu: bool = True
60 | tmp_save_dir: str = "/tmp/tokenkit/"
61 | with_pt: bool = False
62 | expand_input_ids_model: str | None = None
63 | expand_input_ids_tokenizer: str | None = None
64 | overwrite_args: str | None = None
65 |
66 |
67 | if __name__ == "__main__":
68 | (args,) = HfArgumentParser([Args]).parse_args_into_dataclasses()
69 |
70 | tmp_checkpoint_dir = Path(args.tmp_save_dir) / "checkpoint"
71 | tmp_output_dir = Path(args.tmp_save_dir) / "output"
72 |
73 | tmp_checkpoint_dir.mkdir(parents=True, exist_ok=True)
74 | tmp_output_dir.mkdir(parents=True, exist_ok=True)
75 |
76 | if args.use_cpu:
77 | jax.config.update("jax_default_device", jax.devices("cpu")[0])
78 | mesh = sharding.get_mesh(devices=jax.devices("cpu"))
79 | else:
80 | mesh = sharding.get_mesh()
81 |
82 | if gcs_utils.is_gcs_path(args.checkpoint):
83 | checkpoint_bucket, checkpoint_blob = gcs_utils.parse_gcs_path(args.checkpoint)
84 | checkpoint_dir = tmp_checkpoint_dir
85 |
86 | for filename in ["args.yaml", "params.msgpack", "config.json", "tokenizer.json", "tokenizer_config.json"]:
87 | gcs_utils.download_from_gcs(checkpoint_bucket, f"{checkpoint_blob}/{filename}", checkpoint_dir / filename)
88 | else:
89 | checkpoint_dir = Path(args.checkpoint)
90 |
91 | ckpt_args = OmegaConf.load(checkpoint_dir / "args.yaml")
92 | if args.overwrite_args is not None:
93 | ckpt_args = OmegaConf.merge(
94 | ckpt_args, OmegaConf.create(json.loads(args.overwrite_args))
95 | )
96 |
97 | logger.info("Using checkpoint args:")
98 | logger.info(pformat(ckpt_args))
99 |
100 | params = serialization.msgpack_restore(
101 | open(checkpoint_dir / "params.msgpack", "rb").read()
102 | )
103 |
104 | config = get_config(checkpoint_dir)
105 | config.mesh = mesh
106 | tokenizer = AutoTokenizer.from_pretrained(checkpoint_dir)
107 | dtype = getattr(jnp, ckpt_args.dtype)
108 |
109 | n_embd = params["new_embeddings"].shape[-1]
110 |
111 | hypernet = Hypernet(
112 | dtype=dtype,
113 | hidden_size=n_embd,
114 | num_embeddings=1 if config.tie_word_embeddings else 2,
115 | max_seq_length=1,
116 | vocab_size=config.vocab_size,
117 | **ckpt_args.hypernet,
118 | )
119 | model_kwargs = OmegaConf.to_object(ckpt_args.student)
120 |
121 | if "model" in params:
122 | model_params = params["model"]
123 | original_model_params = param.load_params(**model_kwargs)
124 | else:
125 | model_params = original_model_params = param.load_params(**model_kwargs)
126 |
127 | # model params may be partial at this point e.g. if trained with LoRA, merge them
128 | flat_merged_model_params = traverse_util.flatten_dict(original_model_params)
129 | flat_model_params = traverse_util.flatten_dict(model_params)
130 |
131 | for key in flat_model_params.keys():
132 | flat_merged_model_params[key] = flat_model_params[key]
133 |
134 | merged_model_params = traverse_util.unflatten_dict(flat_merged_model_params)
135 | # assigned later
136 | merged_model_params = param.unassign_embeddings(merged_model_params, config=config)
137 |
138 | if "model_lora" in params:
139 | logger.info("Materializing LoRA parameters...")
140 | merged_model_params = lora.materialize_lora(
141 | merged_model_params,
142 | params["model_lora"],
143 | ckpt_args.model_lora_alpha,
144 | )
145 |
146 | hypernet_fn = hypernet.apply
147 |
148 | def predict_embeddings(params): # TODO: add indices for subsampling
149 | embeddings = params["new_embeddings"]
150 |
151 | predicted_embeddings = hypernet_fn(
152 | params["hypernet"],
153 | embeddings[:, None, :, :],
154 | jnp.ones((embeddings.shape[0], 1), dtype=bool),
155 | jnp.arange(embeddings.shape[0], dtype=jnp.int32),
156 | )
157 |
158 | return predicted_embeddings
159 |
160 | embeddings = jax.device_get(predict_embeddings(params))
161 | embeddings = embeddings.copy() # not writeable otherwise
162 |
163 | # remove padding
164 | config.vocab_size = len(tokenizer)
165 | embeddings = embeddings[: len(tokenizer)] # remove padding
166 |
167 | merged_model_params = param.assign_embeddings(merged_model_params, embeddings, config=config)
168 |
169 | model_to_save = FlaxAutoModelForCausalLM.from_config(config)
170 | if gcs_utils.is_gcs_path(args.output):
171 | output_dir = tmp_output_dir
172 | else:
173 | output_dir = Path(args.output)
174 |
175 | del config.mesh
176 |
177 | # from_flax does not work with multiple shards so it is more convenient to save the model as a single shard
178 | model_to_save.save_pretrained(
179 | output_dir, params=merged_model_params, max_shard_size="100GB"
180 | )
181 |
182 | if args.with_pt:
183 | if args.expand_input_ids_model is not None:
184 | byteify_tokenizer = load_byteify_tokenizer(ckpt_args.target_tokenizer_name)
185 | expand_tokenizer = load_byteify_tokenizer(args.expand_input_ids_tokenizer)
186 |
187 | expand_input_ids_dict = utils.get_expand_input_ids_dict(
188 | byteify_tokenizer,
189 | expand_tokenizer.get_vocab(),
190 | max_length=constants.EXPAND_INPUT_IDS_MAX_LENGTH,
191 | )
192 |
193 | config.expand_input_ids = True
194 | config.expand_input_ids_maxlen = constants.EXPAND_INPUT_IDS_MAX_LENGTH
195 | config.expand_input_ids_vocab_size = len(expand_tokenizer)
196 | # make json serializable - will be deserialized in PT model init
197 | config.expand_input_ids_dict = (
198 | {",".join([str(n) for n in k]): int(v) for k, v in expand_input_ids_dict[0].items()},
199 | [int(n) for n in expand_input_ids_dict[1]],
200 | )
201 |
202 | pt_model = AutoModelForCausalLM.from_config(config)
203 | pt_model = load_flax_checkpoint_in_pytorch_model(pt_model, output_dir / "flax_model.msgpack", type(model_to_save))
204 |
205 | # set expansion embedding data
206 | if args.expand_input_ids_model is not None:
207 | expand_input_ids_model_config = get_config(args.expand_input_ids_model)
208 | expand_input_ids_model_params = param.load_params(pretrained_model_name_or_path=args.expand_input_ids_model)
209 | expand_input_ids_embeddings = param.get(
210 | expand_input_ids_model_params,
211 | param.get_input_embedding_path(expand_input_ids_model_config.model_type),
212 | )
213 |
214 | pt_model.model.expand_embed_tokens.weight.data[:] = torch.from_numpy(expand_input_ids_embeddings)
215 |
216 | pt_model.save_pretrained(output_dir)
217 | else:
218 | pt_model = None
219 |
220 | if args.expand_input_ids_model is not None:
221 | raise ValueError("expand_input_ids_model is not supported when with_pt is False")
222 |
223 | config.auto_map = {
224 | "AutoConfig": f"configuration_{config.model_type}.{type(config).__name__}",
225 | "FlaxAutoModelForCausalLM": f"modelling_flax_{config.model_type}.{type(model_to_save).__name__}"
226 | }
227 |
228 | if pt_model is not None:
229 | config.auto_map["AutoModelForCausalLM"] = f"modelling_{config.model_type}.{type(pt_model).__name__}"
230 |
231 | tokenizer.save_pretrained(output_dir)
232 | config.save_pretrained(output_dir)
233 |
234 | if gcs_utils.is_gcs_path(args.output):
235 | output_bucket, output_blob = gcs_utils.parse_gcs_path(args.output)
236 | for filename in ["config.json", "flax_model.msgpack", "tokenizer.json", "tokenizer_config.json"]:
237 | gcs_utils.upload_to_gcs(output_bucket, output_dir / filename, f"{output_blob}/{filename}")
--------------------------------------------------------------------------------
/scripts/push_flax_version_to_hub.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 | from transformers import HfArgumentParser, AutoModelForCausalLM, AutoTokenizer
3 | import json
4 | import jax
5 | import transformers
6 | import shutil
7 |
8 | from tokenkit.models import sharding
9 |
10 | TMP_PATH = "/mnt/disks/persist/tmp/model"
11 |
12 | @dataclass
13 | class Args:
14 | model_name_or_path: str = "Qwen/Qwen2-0.5B"
15 | hub_user: str = "benjamin"
16 | model_class: str = "Llama"
17 | extra_args: str | None = None # for Qwen2: "{\"attention_bias\": true, \"max_length\": 8192}", for Llama3: "{\"max_length\": 8192}"
18 | use_cpu: bool = False
19 |
20 |
21 | if __name__ == "__main__":
22 | (args,) = HfArgumentParser([Args]).parse_args_into_dataclasses()
23 | print(args)
24 |
25 | if args.use_cpu:
26 | jax.config.update('jax_default_device', jax.devices('cpu')[0])
27 | mesh = sharding.get_mesh(devices=jax.devices("cpu"))
28 | else:
29 | mesh = sharding.get_mesh()
30 |
31 | shutil.rmtree(TMP_PATH, ignore_errors=True)
32 | AutoModelForCausalLM.from_pretrained(args.model_name_or_path).save_pretrained(
33 | TMP_PATH, max_shard_size="100GB"
34 | )
35 | tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path)
36 | config_class = getattr(transformers, args.model_class + "Config")
37 | if hasattr(transformers, "Flax" + args.model_class + "ForCausalLM"):
38 | model_class = getattr(transformers, "Flax" + args.model_class + "ForCausalLM")
39 | elif hasattr(transformers, "Flax" + args.model_class + "LMHeadModel"):
40 | model_class = getattr(transformers, "Flax" + args.model_class + "LMHeadModel")
41 | else:
42 | raise ValueError(f"Model class '{args.model_class}' not found")
43 |
44 | config = config_class.from_pretrained(TMP_PATH, args.model_name_or_path)
45 | for key, value in json.loads(args.extra_args or "{}").items():
46 | setattr(config, key, value)
47 |
48 | config.mesh = mesh
49 |
50 | flax_model = model_class.from_pretrained(TMP_PATH, config=config)
51 | model_name = args.hub_user + "/" + args.model_name_or_path.split("/")[-1] + "-flax"
52 |
53 | del config.mesh
54 |
55 | flax_model.push_to_hub(model_name, private=True, safe_serialization=False)
56 | tokenizer.push_to_hub(model_name, private=True)
--------------------------------------------------------------------------------
/scripts/zett.py:
--------------------------------------------------------------------------------
1 | import logging
2 | from pprint import pformat
3 |
4 | import jax
5 | from dataclasses import dataclass, asdict
6 | from transformers import FlaxAutoModelForCausalLM
7 |
8 | from tokenkit.hf import get_config
9 | from tokenkit import utils
10 | from tokenkit.byteify import load_byteify_tokenizer
11 | from tokenkit.models import param, sharding
12 | from tokenkit import parse_args
13 |
14 | logger = logging.getLogger(__name__)
15 |
16 | @dataclass
17 | class ZettArgs:
18 | source_model: parse_args.ModelArgs
19 | target_tokenizer_name: str
20 | output: str
21 |
22 | def main(args: ZettArgs) -> None:
23 | logger.info(pformat(args))
24 |
25 | # Load the source & target tokenizers
26 | source_tokenizer = load_byteify_tokenizer(args.source_model.tokenizer_name)
27 | target_tokenizer = load_byteify_tokenizer(args.target_tokenizer_name)
28 |
29 | mesh = sharding.get_mesh(devices=jax.devices("cpu"))
30 | config = get_config(args.source_model.pretrained_model_name_or_path)
31 | config.mesh = mesh
32 |
33 | model = FlaxAutoModelForCausalLM.from_config(
34 | config,
35 | _do_init=False,
36 | input_shape=(1, 128),
37 | )
38 | del model.config.mesh
39 |
40 | model_params = param.load_params(**asdict(args.source_model))
41 | embeddings, model_params = param.stack_embeddings(
42 | model_params,
43 | config,
44 | pop_embeddings=True,
45 | )
46 |
47 | diff_embeddings, original_to_new_indices, diff_indices = utils.fvt(
48 | source_tokenizer,
49 | target_tokenizer,
50 | embeddings,
51 | )
52 | new_embeddings = embeddings[original_to_new_indices]
53 | if len(diff_indices) > 0:
54 | new_embeddings[diff_indices] = diff_embeddings
55 |
56 | model_params = param.assign_embeddings(model_params, new_embeddings, config)
57 |
58 | model.save_pretrained(args.output, params=model_params)
59 | config.save_pretrained(args.output)
60 | target_tokenizer.save_pretrained(args.output)
61 |
62 |
63 | if __name__ == "__main__":
64 | main(parse_args.parse_args(ZettArgs))
--------------------------------------------------------------------------------
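
The embedding re-assembly in `main` is an index copy followed by an override for tokens whose vectors were newly estimated by `utils.fvt`. A self-contained sketch of that pattern, with toy shapes standing in for the real vocabularies:

    import numpy as np

    # Toy stand-ins: 5 source-vocab embeddings of dim 8, mapped onto a 4-token target vocab.
    embeddings = np.random.randn(5, 8)
    original_to_new_indices = np.array([0, 2, 4, 1])  # source row to copy for each target token
    diff_indices = np.array([3])                      # target rows that get freshly estimated vectors
    diff_embeddings = np.zeros((1, 8))                # stand-in for the FVT output

    new_embeddings = embeddings[original_to_new_indices]
    if len(diff_indices) > 0:
        new_embeddings[diff_indices] = diff_embeddings

    assert new_embeddings.shape == (4, 8)
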
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import find_packages, setup
2 |
3 | setup(
4 | name="tokenkit",
5 | version="0.0.1",
6 | packages=find_packages(),
7 | install_requires=[],
8 | entry_points={
9 | "console_scripts": [],
10 | },
11 | )
12 |
--------------------------------------------------------------------------------
/tokenkit/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bminixhofer/tokenkit/7abac3578f6b3fa38f985e9f03bff7a47d5ab3b1/tokenkit/__init__.py
--------------------------------------------------------------------------------
/tokenkit/baseline_utils.py:
--------------------------------------------------------------------------------
1 | import math
2 | import multiprocessing
3 | from functools import partial
4 |
5 | import editdistance
6 | import jax
7 | import jax.numpy as jnp
8 | import numpy as np
9 | from tqdm.auto import tqdm
10 |
11 |
12 | def _compute_edit_distance(token, sorted_original_vocab):
13 | min_edit_distance = math.inf
14 | best_match = None
15 |
16 | closer_to_start = len(token) < len(
17 | sorted_original_vocab[int(len(sorted_original_vocab) / 2)]
18 | )
19 |
20 | if closer_to_start:
21 | candidates = sorted_original_vocab
22 | else:
23 | candidates = reversed(sorted_original_vocab)
24 |
25 | for original_token in candidates:
26 | if closer_to_start:
27 | # tokens only get longer
28 | if len(original_token) - len(token) >= min_edit_distance:
29 | break
30 | if len(token) - len(original_token) >= min_edit_distance:
31 | continue
32 | else:
33 | # tokens only get shorter
34 | if len(token) - len(original_token) >= min_edit_distance:
35 | break
36 | if len(original_token) - len(token) >= min_edit_distance:
37 | continue
38 |
39 | edit_distance = editdistance.eval(token, original_token)
40 | if edit_distance < min_edit_distance:
41 | min_edit_distance = edit_distance
42 | best_match = original_token
43 |
44 | return token, best_match, min_edit_distance
45 |
46 |
47 | def compute_mined_mapping(
48 | tokenizer_original, tokenizer_new, num_workers=1, chunksize=500
49 | ):
50 | original_vocab = tokenizer_original.get_vocab()
51 | new_vocab = tokenizer_new.get_vocab()
52 |
53 | mapping = np.zeros(len(tokenizer_new), dtype=np.int32)
54 | edit_distances = {}
55 |
56 | intersection = [token for token in new_vocab.keys() if token in original_vocab]
57 | completion = [token for token in new_vocab.keys() if token not in original_vocab]
58 | sorted_completion = sorted(completion, key=lambda x: len(x))
59 | sorted_original_vocab = sorted(original_vocab.keys(), key=lambda x: len(x))
60 |
61 | for token in intersection:
62 | mapping[new_vocab[token]] = original_vocab[token]
63 | edit_distances[token] = 0
64 |
65 | with multiprocessing.Pool(max(num_workers, 1)) as pool:
66 | results = list(
67 | tqdm(
68 | pool.imap_unordered(
69 | partial(
70 | _compute_edit_distance,
71 | sorted_original_vocab=sorted_original_vocab,
72 | ),
73 | sorted_completion,
74 | chunksize=chunksize,
75 | ),
76 | desc="Computing MinED mapping",
77 | total=len(sorted_completion),
78 | )
79 | )
80 |
81 | for token, best_match, min_edit_distance in results:
82 | mapping[new_vocab[token]] = original_vocab[best_match]
83 | edit_distances[token] = min_edit_distance
84 |
85 | return mapping, edit_distances
86 |
87 |
88 | def compute_forward_kl_divergence(
89 | logits,
90 | teacher_logits,
91 | target,
92 | kd_temp,
93 | padding_id,
94 | tea_temp=None,
95 | reduction="sum",
96 | log=None,
97 | use_tea_temp=False,
98 | ):
99 | logits = logits / kd_temp
100 | teacher_logits = teacher_logits / kd_temp
101 | teacher_logits = teacher_logits / tea_temp if use_tea_temp else teacher_logits
102 |
103 | lprobs = jax.nn.log_softmax(logits, -1)
104 | teacher_probs = jax.nn.softmax(teacher_logits, -1)
105 | teacher_lprobs = jax.nn.log_softmax(teacher_logits, -1)
106 | kld = teacher_probs * (teacher_lprobs - lprobs)
107 | inf_mask = jnp.isinf(logits)
108 | kld = jnp.where(inf_mask, 0.0, kld).sum(-1)
109 |
110 | if reduction == "sum":
111 | pad_mask = target == padding_id
112 | kld = jnp.where(pad_mask, 0.0, kld)
113 | kld = kld.sum()
114 |
115 | if log is not None:
116 | log["forward_kl"] = kld
117 |
118 | return kld
119 |
120 |
121 | def compute_reverse_kl_divergence(
122 | logits,
123 | teacher_logits,
124 | target,
125 | kd_temp,
126 | padding_id,
127 | tea_temp=None,
128 | reduction="sum",
129 | log=None,
130 | use_tea_temp=False,
131 | ):
132 | logits = logits / kd_temp
133 | teacher_logits = teacher_logits / kd_temp
134 | teacher_logits = teacher_logits / tea_temp if use_tea_temp else teacher_logits
135 |
136 | probs = jax.nn.softmax(logits, -1)
137 | lprobs = jax.nn.log_softmax(logits, -1)
138 | teacher_lprobs = jax.nn.log_softmax(teacher_logits, -1)
139 | kld = probs * (lprobs - teacher_lprobs)
140 | inf_mask = jnp.isinf(logits) | jnp.isinf(teacher_logits)
141 | kld = jnp.where(inf_mask, 0.0, kld).sum(-1)
142 |
143 | if reduction == "sum":
144 | pad_mask = target == padding_id
145 | kld = jnp.where(pad_mask, 0.0, kld)
146 | kld = kld.sum()
147 |
148 | if log is not None:
149 | log["reverse_kl"] = kld
150 |
151 | return kld
152 |
153 |
154 | def compute_adaptive_kl_divergence(
155 | logits,
156 | teacher_logits,
157 | target,
158 | kd_temp,
159 | padding_id,
160 | alpha,
161 | tea_temp=None,
162 | reduction="sum",
163 | log=None,
164 | use_tea_temp=False,
165 | ):
166 | probs = jax.nn.softmax(logits / kd_temp, axis=-1).astype(jnp.float32)
167 | if use_tea_temp:
168 | teacher_probs = jax.nn.softmax(
169 | teacher_logits / tea_temp / kd_temp, axis=-1
170 | ).astype(jnp.float32)
171 | else:
172 | teacher_probs = jax.nn.softmax(teacher_logits / kd_temp, axis=-1).astype(
173 | jnp.float32
174 | )
175 |
176 | sorted_teacher_probs = jnp.sort(teacher_probs, axis=-1)
177 | sorted_idx = jnp.argsort(teacher_probs, axis=-1)
178 | sorted_probs = jnp.take_along_axis(
179 | probs, sorted_idx, axis=-1
180 | ) # TODO: check if we need [..., None]?
181 | gap = jnp.abs(sorted_teacher_probs - sorted_probs)
182 | cum_teacher_probs = jnp.cumsum(sorted_teacher_probs, axis=-1)
183 | tail_mask = (cum_teacher_probs < alpha).astype(jnp.float32)
184 | g_head = jax.lax.stop_gradient(jnp.sum(gap * (1 - tail_mask), axis=-1))
185 | g_tail = jax.lax.stop_gradient(jnp.sum(gap * tail_mask, axis=-1))
186 |
187 | fkl = compute_forward_kl_divergence(
188 | logits,
189 | teacher_logits,
190 | target,
191 | kd_temp,
192 | padding_id,
193 | tea_temp=tea_temp,
194 | reduction="none",
195 | use_tea_temp=use_tea_temp,
196 | )
197 | rkl = compute_reverse_kl_divergence(
198 | logits,
199 | teacher_logits,
200 | target,
201 | kd_temp,
202 | padding_id,
203 | tea_temp=tea_temp,
204 | reduction="none",
205 | use_tea_temp=use_tea_temp,
206 | )
207 |
208 | akl = (g_head / (g_head + g_tail)) * fkl + (g_tail / (g_head + g_tail)) * rkl
209 |
210 | if reduction == "sum":
211 | pad_mask = target == padding_id
212 | akl = jnp.where(pad_mask, 0.0, akl)
213 | akl = akl.sum()
214 |
215 | if log is not None:
216 | log["adaptive_kl"] = akl
217 |
218 | return akl
219 |
220 |
221 | def compute_skewed_forward_kl_divergence(
222 | logits,
223 | teacher_logits,
224 | target,
225 | kd_temp,
226 | padding_id,
227 | skew_lambda,
228 | tea_temp=None,
229 | reduction="sum",
230 | log=None,
231 | use_tea_temp=False,
232 | epsilon=1e-9,
233 | ):
234 | logits = logits / kd_temp
235 | teacher_logits = teacher_logits / kd_temp
236 | teacher_logits = teacher_logits / tea_temp if use_tea_temp else teacher_logits
237 |
238 | student_probs = jax.nn.softmax(logits, -1).astype(jnp.float32)
239 | teacher_probs = jax.nn.softmax(teacher_logits, -1).astype(jnp.float32)
240 | mixed_probs = skew_lambda * teacher_probs + (1 - skew_lambda) * student_probs
241 | mixed_lprobs = jnp.log(mixed_probs + epsilon)
242 | teacher_lprobs = jax.nn.log_softmax(teacher_logits, -1).astype(jnp.float32)
243 | kld = teacher_probs * (teacher_lprobs - mixed_lprobs)
244 | inf_mask = jnp.isinf(logits) | jnp.isinf(teacher_logits)
245 | kld = jnp.where(inf_mask, 0.0, kld).sum(-1)
246 |
247 | if reduction == "sum":
248 | pad_mask = target == padding_id
249 | kld = jnp.where(pad_mask, 0.0, kld)
250 | kld = kld.sum()
251 |
252 | if log is not None:
253 | log["skewed_forward_kl"] = kld
254 |
255 | return kld
256 |
257 |
258 | def compute_skewed_reverse_kl_divergence(
259 | logits,
260 | teacher_logits,
261 | target,
262 | kd_temp,
263 | padding_id,
264 | skew_lambda,
265 | tea_temp=None,
266 | reduction="sum",
267 | log=None,
268 | use_tea_temp=False,
269 | epsilon=1e-9,
270 | ):
271 | logits = logits / kd_temp
272 | teacher_logits = teacher_logits / kd_temp
273 | teacher_logits = teacher_logits / tea_temp if use_tea_temp else teacher_logits
274 |
275 | student_probs = jax.nn.softmax(logits, -1).astype(jnp.float32)
276 | teacher_probs = jax.nn.softmax(teacher_logits, -1).astype(jnp.float32)
277 | mixed_probs = (1 - skew_lambda) * teacher_probs + skew_lambda * student_probs
278 | mixed_lprobs = jnp.log(mixed_probs + epsilon)
279 | student_lprobs = jax.nn.log_softmax(logits, -1).astype(jnp.float32)
280 | kld = student_probs * (student_lprobs - mixed_lprobs)
281 | inf_mask = jnp.isinf(logits) | jnp.isinf(teacher_logits)
282 | kld = jnp.where(inf_mask, 0.0, kld).sum(-1)
283 |
284 | if reduction == "sum":
285 | pad_mask = target == padding_id
286 | kld = jnp.where(pad_mask, 0.0, kld)
287 | kld = kld.sum()
288 |
289 | if log is not None:
290 | log["skewed_reverse_kl"] = kld
291 |
292 | return kld
293 |
294 |
295 | def compute_js_divergence(
296 | logits,
297 | teacher_logits,
298 | target,
299 | kd_temp,
300 | tea_temp,
301 | padding_id,
302 | reduction="sum",
303 | log=None,
304 | use_tea_temp=False,
305 | epsilon=1e-9,
306 | ):
307 | logits = logits / kd_temp
308 | teacher_logits = teacher_logits / kd_temp
309 | teacher_logits = teacher_logits / tea_temp if use_tea_temp else teacher_logits
310 |
311 | probs = jax.nn.softmax(logits, -1).astype(jnp.float32)
312 | teacher_probs = jax.nn.softmax(teacher_logits, -1).astype(jnp.float32)
313 | m_probs = (probs + teacher_probs) / 2
314 |
315 | lprobs = jnp.log(probs + epsilon)
316 | teacher_lprobs = jnp.log(teacher_probs + epsilon)
317 | m_lprobs = jnp.log(m_probs + epsilon)
318 |
319 | kld1 = teacher_probs * (teacher_lprobs - m_lprobs)
320 | kld2 = probs * (lprobs - m_lprobs)
321 | kld = ((kld1 + kld2) / 2).sum(-1)  # sum over the vocabulary axis, matching the other divergences
322 |
323 | if reduction == "sum":
324 | pad_mask = target == padding_id
325 | kld = jnp.where(pad_mask, 0.0, kld)
326 | kld = kld.sum()
327 |
328 | if log is not None:
329 | log["js_div"] = kld
330 |
331 | return kld
332 |
--------------------------------------------------------------------------------
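
A small sanity-check sketch for the divergence helpers above, using toy `[batch, seq, vocab]` logits; the shapes, temperature, and padding id are arbitrary choices for illustration:

    import jax
    import jax.numpy as jnp
    from tokenkit.baseline_utils import compute_forward_kl_divergence, compute_js_divergence

    student_logits = jax.random.normal(jax.random.PRNGKey(0), (2, 4, 16))
    teacher_logits = jax.random.normal(jax.random.PRNGKey(1), (2, 4, 16))
    target = jnp.array([[1, 2, 3, 0], [4, 5, 0, 0]])  # 0 is treated as padding here

    log = {}
    fkl = compute_forward_kl_divergence(
        student_logits, teacher_logits, target, kd_temp=2.0, padding_id=0, log=log
    )
    jsd = compute_js_divergence(
        student_logits, teacher_logits, target, kd_temp=2.0, tea_temp=None, padding_id=0
    )
    print(float(fkl), float(jsd), log)
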
/tokenkit/constants.py:
--------------------------------------------------------------------------------
1 | # Character mapping dictionaries
2 | EXPAND_INPUT_IDS_MAX_LENGTH = 16
3 | CHARS_TO_BYTES = {
4 | "Ā": 0,
5 | "ā": 1,
6 | "Ă": 2,
7 | "ă": 3,
8 | "Ą": 4,
9 | "ą": 5,
10 | "Ć": 6,
11 | "ć": 7,
12 | "Ĉ": 8,
13 | "ĉ": 9,
14 | "Ċ": 10,
15 | "ċ": 11,
16 | "Č": 12,
17 | "č": 13,
18 | "Ď": 14,
19 | "ď": 15,
20 | "Đ": 16,
21 | "đ": 17,
22 | "Ē": 18,
23 | "ē": 19,
24 | "Ĕ": 20,
25 | "ĕ": 21,
26 | "Ė": 22,
27 | "ė": 23,
28 | "Ę": 24,
29 | "ę": 25,
30 | "Ě": 26,
31 | "ě": 27,
32 | "Ĝ": 28,
33 | "ĝ": 29,
34 | "Ğ": 30,
35 | "ğ": 31,
36 | "Ġ": 32,
37 | "!": 33,
38 | '"': 34,
39 | "#": 35,
40 | "$": 36,
41 | "%": 37,
42 | "&": 38,
43 | "'": 39,
44 | "(": 40,
45 | ")": 41,
46 | "*": 42,
47 | "+": 43,
48 | ",": 44,
49 | "-": 45,
50 | ".": 46,
51 | "/": 47,
52 | "0": 48,
53 | "1": 49,
54 | "2": 50,
55 | "3": 51,
56 | "4": 52,
57 | "5": 53,
58 | "6": 54,
59 | "7": 55,
60 | "8": 56,
61 | "9": 57,
62 | ":": 58,
63 | ";": 59,
64 | "<": 60,
65 | "=": 61,
66 | ">": 62,
67 | "?": 63,
68 | "@": 64,
69 | "A": 65,
70 | "B": 66,
71 | "C": 67,
72 | "D": 68,
73 | "E": 69,
74 | "F": 70,
75 | "G": 71,
76 | "H": 72,
77 | "I": 73,
78 | "J": 74,
79 | "K": 75,
80 | "L": 76,
81 | "M": 77,
82 | "N": 78,
83 | "O": 79,
84 | "P": 80,
85 | "Q": 81,
86 | "R": 82,
87 | "S": 83,
88 | "T": 84,
89 | "U": 85,
90 | "V": 86,
91 | "W": 87,
92 | "X": 88,
93 | "Y": 89,
94 | "Z": 90,
95 | "[": 91,
96 | "\\": 92,
97 | "]": 93,
98 | "^": 94,
99 | "_": 95,
100 | "`": 96,
101 | "a": 97,
102 | "b": 98,
103 | "c": 99,
104 | "d": 100,
105 | "e": 101,
106 | "f": 102,
107 | "g": 103,
108 | "h": 104,
109 | "i": 105,
110 | "j": 106,
111 | "k": 107,
112 | "l": 108,
113 | "m": 109,
114 | "n": 110,
115 | "o": 111,
116 | "p": 112,
117 | "q": 113,
118 | "r": 114,
119 | "s": 115,
120 | "t": 116,
121 | "u": 117,
122 | "v": 118,
123 | "w": 119,
124 | "x": 120,
125 | "y": 121,
126 | "z": 122,
127 | "{": 123,
128 | "|": 124,
129 | "}": 125,
130 | "~": 126,
131 | "ġ": 127,
132 | "Ģ": 128,
133 | "ģ": 129,
134 | "Ĥ": 130,
135 | "ĥ": 131,
136 | "Ħ": 132,
137 | "ħ": 133,
138 | "Ĩ": 134,
139 | "ĩ": 135,
140 | "Ī": 136,
141 | "ī": 137,
142 | "Ĭ": 138,
143 | "ĭ": 139,
144 | "Į": 140,
145 | "į": 141,
146 | "İ": 142,
147 | "ı": 143,
148 | "IJ": 144,
149 | "ij": 145,
150 | "Ĵ": 146,
151 | "ĵ": 147,
152 | "Ķ": 148,
153 | "ķ": 149,
154 | "ĸ": 150,
155 | "Ĺ": 151,
156 | "ĺ": 152,
157 | "Ļ": 153,
158 | "ļ": 154,
159 | "Ľ": 155,
160 | "ľ": 156,
161 | "Ŀ": 157,
162 | "ŀ": 158,
163 | "Ł": 159,
164 | "ł": 160,
165 | "¡": 161,
166 | "¢": 162,
167 | "£": 163,
168 | "¤": 164,
169 | "¥": 165,
170 | "¦": 166,
171 | "§": 167,
172 | "¨": 168,
173 | "©": 169,
174 | "ª": 170,
175 | "«": 171,
176 | "¬": 172,
177 | "Ń": 173,
178 | "®": 174,
179 | "¯": 175,
180 | "°": 176,
181 | "±": 177,
182 | "²": 178,
183 | "³": 179,
184 | "´": 180,
185 | "µ": 181,
186 | "¶": 182,
187 | "·": 183,
188 | "¸": 184,
189 | "¹": 185,
190 | "º": 186,
191 | "»": 187,
192 | "¼": 188,
193 | "½": 189,
194 | "¾": 190,
195 | "¿": 191,
196 | "À": 192,
197 | "Á": 193,
198 | "Â": 194,
199 | "Ã": 195,
200 | "Ä": 196,
201 | "Å": 197,
202 | "Æ": 198,
203 | "Ç": 199,
204 | "È": 200,
205 | "É": 201,
206 | "Ê": 202,
207 | "Ë": 203,
208 | "Ì": 204,
209 | "Í": 205,
210 | "Î": 206,
211 | "Ï": 207,
212 | "Ð": 208,
213 | "Ñ": 209,
214 | "Ò": 210,
215 | "Ó": 211,
216 | "Ô": 212,
217 | "Õ": 213,
218 | "Ö": 214,
219 | "×": 215,
220 | "Ø": 216,
221 | "Ù": 217,
222 | "Ú": 218,
223 | "Û": 219,
224 | "Ü": 220,
225 | "Ý": 221,
226 | "Þ": 222,
227 | "ß": 223,
228 | "à": 224,
229 | "á": 225,
230 | "â": 226,
231 | "ã": 227,
232 | "ä": 228,
233 | "å": 229,
234 | "æ": 230,
235 | "ç": 231,
236 | "è": 232,
237 | "é": 233,
238 | "ê": 234,
239 | "ë": 235,
240 | "ì": 236,
241 | "í": 237,
242 | "î": 238,
243 | "ï": 239,
244 | "ð": 240,
245 | "ñ": 241,
246 | "ò": 242,
247 | "ó": 243,
248 | "ô": 244,
249 | "õ": 245,
250 | "ö": 246,
251 | "÷": 247,
252 | "ø": 248,
253 | "ù": 249,
254 | "ú": 250,
255 | "û": 251,
256 | "ü": 252,
257 | "ý": 253,
258 | "þ": 254,
259 | "ÿ": 255,
260 | }
261 | BYTES_TO_CHARS = {v: k for k, v in CHARS_TO_BYTES.items()}
262 | MAX_CHARS_PER_TOKEN = 16
263 | # for hn training
264 | DEFAULT_SPLIT_REGEX = r"'s|'t|'re|'ve|'m|'ll|'d| ?[\p{L}\p{M}]+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"
--------------------------------------------------------------------------------
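
`CHARS_TO_BYTES` / `BYTES_TO_CHARS` are the GPT-2-style byte-to-unicode tables used for byteification; every byte maps to a single printable character, so arbitrary UTF-8 round-trips losslessly. A quick sketch:

    from tokenkit.constants import BYTES_TO_CHARS, CHARS_TO_BYTES

    text = "Hello, 世界"
    encoded = "".join(BYTES_TO_CHARS[b] for b in text.encode("utf-8"))
    decoded = bytes(CHARS_TO_BYTES[c] for c in encoded).decode("utf-8")

    print(encoded)          # Hello,Ġä¸ĸçķĮ
    assert decoded == text
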
/tokenkit/data/__init__.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 | import datasets
4 | from datasets import interleave_datasets, load_dataset
5 | from torch.utils.data import Dataset
6 | from functools import partial
7 |
8 | from tokenkit.utils import preprocess_messages
9 |
10 |
11 | class JSONLDataset(Dataset):
12 | def __init__(
13 | self, path, lang_code, batch_size, n_subsample=None, num_workers=0, seed=None
14 | ):
15 | self.path = path
16 | self.lang_code = lang_code
17 | self.batch_size = batch_size
18 |
19 | self.raw_dset = load_dataset(
20 | "json",
21 | data_files=path,
22 | split="train" if n_subsample is None else f"train[:{n_subsample}]",
23 | )
24 | self.dset = self.raw_dset.map(
25 | lambda x, idx: {
26 | "text": [x["text"]],
27 | "lang_code": [[lang_code] * len(x["text"])],
28 | "index": [idx],
29 | },
30 | batched=True,
31 | batch_size=batch_size,
32 | drop_last_batch=True,
33 | remove_columns=self.raw_dset.column_names,
34 | with_indices=True,
35 | )
36 |
37 | # TODO: check consistency between this and iterable data - iterable cycles through the data on consecutive calls
38 | # (but this does NOT influence the main stream)
39 | # TODO: return texts directly instead of dicts?
40 | def get_texts(self, n: int, lang_code: None | str = None):
41 | if lang_code is not None and lang_code != self.lang_code:
42 | raise ValueError("Invalid lang_code")
43 |
44 | texts = self.raw_dset[:n]["text"]
45 | return [{"text": t} for t in texts]
46 |
47 | def get_torch_dataset(self):
48 | return self.dset
49 |
50 |
51 | def process_example(example, lang_code):
52 | if "messages" in example:
53 | text = preprocess_messages(example["messages"])
54 | else:
55 | text = example["text"]
56 |
57 | return {
58 | "text": text,
59 | "lang_code": lang_code,
60 | }
61 |
62 |
63 | class HFDataset:
64 | def __init__(
65 | self,
66 | dataset_configs,
67 | mix_languages,
68 | batch_size,
69 | streaming=True,
70 | n_subsample=None,
71 | shuffle_buffer_size=None,
72 | num_workers=0,
73 | seed=1234,
74 | ):
75 | self.dataset_configs = dataset_configs
76 | self.mix_languages = mix_languages
77 | self.seed = seed
78 | self.batch_size = batch_size
79 | self.shuffle_buffer_size = shuffle_buffer_size
80 |
81 | if n_subsample is not None:
82 | raise ValueError("Subsampling not supported for HF datasets")
83 |
84 | self.dset_streams = {}
85 | self.probs = {}
86 |
87 | if streaming:
88 | process_kwargs = {}
89 | else:
90 | process_kwargs = {"num_proc": num_workers if num_workers > 0 else None}
91 |
92 | for config in dataset_configs:
93 | stream = load_dataset(
94 | **config["kwargs"],
95 | streaming=streaming,
96 | trust_remote_code=True,
97 | )
98 |
99 | if self.shuffle_buffer_size is not None:
100 | if streaming:
101 | stream = stream.shuffle(
102 | buffer_size=self.shuffle_buffer_size, seed=seed
103 | )
104 | else:
105 | stream = stream.shuffle(seed=seed)
106 |
107 | self.dset_streams[config["lang_code"]] = stream.map(
108 | partial(process_example, lang_code=config["lang_code"]), **process_kwargs, remove_columns=stream.column_names if not streaming else None
109 | )
110 |
111 | if "p" in config:
112 | self.probs[config["lang_code"]] = config["p"]
113 |
114 | if 0 < len(self.probs) < len(self.dset_streams):
115 | raise ValueError(
116 | "If you provide probabilities, you must provide them for all datasets"
117 | )
118 |
119 | if len(self.probs) == 0:
120 | self.probs = {k: 1.0 for k in self.dset_streams.keys()}
121 |
122 | # normalize probabilities
123 | total = sum(self.probs.values())
124 | for k in self.probs:
125 | self.probs[k] /= total
126 |
127 | if self.mix_languages:
128 | self.stream = interleave_datasets(
129 | list(self.dset_streams.values()),
130 | probabilities=list(self.probs.values()),
131 | seed=seed,
132 | ).batch(batch_size, drop_last_batch=True, **process_kwargs)
133 | else:
134 | self.stream = interleave_datasets(
135 | [
136 | s.batch(batch_size, drop_last_batch=True, **process_kwargs)
137 | for s in self.dset_streams.values()
138 | ],
139 | probabilities=list(self.probs.values()),
140 | seed=seed,
141 | )
142 |
143 | def get_texts(self, n: int, lang_code: None | str = None):
144 | if n == 0:
145 | return []
146 |
147 | if lang_code is None:
148 | # unbatch
149 | batches_to_take = math.ceil(n / self.batch_size)
150 | batches = list(self.stream.take(batches_to_take))
151 | out = []
152 | keys = list(batches[0].keys())
153 |
154 | for batch in batches:
155 | for i in range(len(batch[keys[0]])):
156 | out.append({k: batch[k][i] for k in keys})
157 |
158 | return out
159 |
160 | return list(self.dset_streams[lang_code].take(n))
161 |
162 | def get_torch_dataset(self):
163 | return self.stream
164 |
165 |
166 | class HFSavedDataset(Dataset):
167 | def __init__(
168 | self,
169 | dataset_configs,
170 | lang_code,
171 | batch_size,
172 | n_subsample=None,
173 | num_workers=0,
174 | seed=None,
175 | ):
176 | self.dataset_configs = dataset_configs
177 | self.lang_code = lang_code
178 | self.seed = seed
179 | self.batch_size = batch_size
180 |
181 | if n_subsample is not None:
182 | raise ValueError("Subsampling not supported for HF datasets")
183 |
184 | self.dsets = []
185 | self.probs = []
186 |
187 | for config in dataset_configs:
188 | self.dsets.append(datasets.load_from_disk(config["path"])["train"])
189 |
190 | if "p" in config:
191 | self.probs.append(config["p"])
192 |
193 | if 0 < len(self.probs) < len(self.dsets):
194 | raise ValueError(
195 | "If you provide probabilities, you must provide them for all datasets"
196 | )
197 |
198 | if len(self.probs) == 0:
199 | self.probs = None
200 |
201 | self.dset = interleave_datasets(
202 | self.dsets,
203 | probabilities=self.probs,
204 | seed=seed,
205 | )
206 |
207 | def __len__(self):
208 | return len(self.dset) // self.batch_size
209 |
210 | def get_texts(self, n: int, lang_code: None | str = None):
211 | if lang_code is not None and lang_code != self.lang_code:
212 | raise ValueError("Invalid lang_code")
213 |
214 | texts = self.dset[:n]["text"]
215 | return [{"text": t} for t in texts]
216 |
217 | def __getitem__(self, batch_idx):
218 | start, end = batch_idx * self.batch_size, (batch_idx + 1) * self.batch_size
219 |
220 | texts = self.dset[start:end]["text"]
221 |
222 | return {
223 | "text": texts,
224 | "lang_code": [self.lang_code] * len(texts),
225 | "index": list(range(start, start + len(texts))),
226 | }
227 |
228 | def get_torch_dataset(self):
229 | return self
230 |
231 |
232 | def get_dataset(kind, **kwargs):
233 | if kind == "jsonl":
234 | return JSONLDataset(**kwargs)
235 | elif kind == "hf":
236 | return HFDataset(**kwargs)
237 | elif kind == "hf_saved":
238 | return HFSavedDataset(**kwargs)
239 | else:
240 | raise ValueError("Invalid dataset kind")
241 |
242 |
243 | def test_load_tulu3():
244 | dset = get_dataset(
245 | "hf",
246 | dataset_configs=[
247 | {
248 | "lang_code": "en",
249 | "kwargs": {"path": "allenai/tulu-3-sft-mixture", "split": "train"},
250 | }
251 | ],
252 | batch_size=16,
253 | num_workers=16,
254 | streaming=False,
255 | mix_languages=False,
256 | )
257 | assert dset.get_texts(1)[0]["text"].startswith(
258 | "<||><||><||><||>"
259 | )
260 |
--------------------------------------------------------------------------------
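
A minimal sketch of building a dataset through `get_dataset`; the JSONL path is a placeholder and the file is expected to contain one JSON object with a "text" field per line:

    from tokenkit.data import get_dataset

    dset = get_dataset(
        "jsonl",
        path="data/train.jsonl",  # hypothetical path
        lang_code="en",
        batch_size=16,
    )
    torch_dset = dset.get_torch_dataset()  # pre-batched examples with "text", "lang_code", "index"
    probe_texts = dset.get_texts(4)        # a few raw {"text": ...} dicts
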
/tokenkit/gcs_utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | from concurrent.futures import ThreadPoolExecutor
3 |
4 | from google.cloud import storage
5 |
6 |
7 | def download_from_gcs(bucket_name, source_blob_name, destination_file_name):
8 | storage_client = storage.Client()
9 | bucket = storage_client.bucket(bucket_name)
10 | blob = bucket.blob(source_blob_name)
11 | blob.download_to_filename(destination_file_name)
12 | print(
13 | f"Downloaded {source_blob_name} from bucket {bucket_name} to {destination_file_name}"
14 | )
15 |
16 |
17 | def upload_to_gcs(bucket_name, source_file_name, destination_blob_name):
18 | storage_client = storage.Client()
19 | bucket = storage_client.bucket(bucket_name)
20 | blob = bucket.blob(destination_blob_name)
21 | blob.upload_from_filename(source_file_name)
22 | print(
23 | f"Uploaded {source_file_name} to bucket {bucket_name} as {destination_blob_name}"
24 | )
25 |
26 |
27 | def is_gcs_path(path):
28 | return path.startswith("gs://")
29 |
30 |
31 | def parse_gcs_path(gcs_path):
32 | path_parts = gcs_path[len("gs://") :].split("/", 1)
33 | bucket_name = path_parts[0]
34 | blob_name = path_parts[1].rstrip("/") if len(path_parts) > 1 else ""
35 | return bucket_name, blob_name
36 |
37 |
38 | def upload_file(bucket_name, source_file_path, destination_blob_name):
39 | """
40 | Uploads a single file to the specified GCS bucket.
41 | """
42 | storage_client = storage.Client()
43 | bucket = storage_client.bucket(bucket_name)
44 | blob = bucket.blob(destination_blob_name)
45 |
46 | blob.upload_from_filename(source_file_path)
47 | print(f"Uploaded {source_file_path} to gs://{bucket_name}/{destination_blob_name}")
48 |
49 |
50 | def upload_directory_to_gcs(bucket_name, source_directory, target_directory=""):
51 | """
52 | Uploads all files from a local directory to the specified GCS bucket,
53 | optionally under a target directory.
54 |
55 | Args:
56 | bucket_name (str): The name of the GCS bucket.
57 | source_directory (str): Path to the local directory to upload.
58 | target_directory (str): Optional target directory in the GCS bucket.
59 | """
60 | executor = ThreadPoolExecutor()
61 |
62 | for root, _, files in os.walk(source_directory):
63 | for file in files:
64 | local_path = os.path.join(root, file)
65 | # Define the destination path in the bucket
66 | relative_path = os.path.relpath(local_path, source_directory)
67 | destination_path = os.path.join(target_directory, relative_path).replace(
68 | "\\", "/"
69 | ) # Ensure GCS path format
70 | # upload_file(bucket_name, local_path, destination_path)
71 | executor.submit(upload_file, bucket_name, local_path, destination_path)
72 |
73 | return executor
74 |
--------------------------------------------------------------------------------
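
A usage sketch for the helpers above with placeholder bucket/path values. `upload_directory_to_gcs` returns its `ThreadPoolExecutor`, so the caller decides when to block on the queued uploads:

    from tokenkit import gcs_utils

    output = "gs://my-bucket/checkpoints/run1"  # hypothetical
    if gcs_utils.is_gcs_path(output):
        bucket, blob_prefix = gcs_utils.parse_gcs_path(output)
        executor = gcs_utils.upload_directory_to_gcs(bucket, "outputs/run1", blob_prefix)
        executor.shutdown(wait=True)  # wait for all pending uploads to finish
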
/tokenkit/hf/__init__.py:
--------------------------------------------------------------------------------
1 | from transformers import AutoConfig, AutoModel, AutoModelForCausalLM, FlaxAutoModelForCausalLM
2 | from tokenkit.hf.configuration_tpu_llama import TPULlamaConfig
3 | from tokenkit.hf.modelling_tpu_llama import TPULlamaForCausalLM, TPULlamaModel
4 | from tokenkit.hf.modelling_flax_tpu_llama import FlaxTPULlamaForCausalLM, FlaxTPULlamaModel
5 |
6 | from tokenkit.hf.configuration_tpu_gemma2 import TPUGemma2Config
7 | from tokenkit.hf.modelling_tpu_gemma2 import TPUGemma2ForCausalLM, TPUGemma2Model
8 | from tokenkit.hf.modelling_flax_tpu_gemma2 import FlaxTPUGemma2ForCausalLM, FlaxTPUGemma2Model
9 |
10 | AutoConfig.register("tpu_llama", TPULlamaConfig)
11 | AutoModel.register(TPULlamaConfig, TPULlamaModel)
12 | AutoModelForCausalLM.register(TPULlamaConfig, TPULlamaForCausalLM)
13 | TPULlamaForCausalLM.register_for_auto_class("AutoModelForCausalLM")
14 | TPULlamaModel.register_for_auto_class("AutoModel")
15 | FlaxAutoModelForCausalLM.register(TPULlamaConfig, FlaxTPULlamaForCausalLM)
16 | FlaxTPULlamaForCausalLM.register_for_auto_class("FlaxAutoModelForCausalLM")
17 | FlaxTPULlamaModel.register_for_auto_class("FlaxAutoModel")
18 |
19 | AutoConfig.register("tpu_gemma2", TPUGemma2Config)
20 | AutoModel.register(TPUGemma2Config, TPUGemma2Model)
21 | AutoModelForCausalLM.register(TPUGemma2Config, TPUGemma2ForCausalLM)
22 | TPUGemma2ForCausalLM.register_for_auto_class("AutoModelForCausalLM")
23 | TPUGemma2Model.register_for_auto_class("AutoModel")
24 | FlaxAutoModelForCausalLM.register(TPUGemma2Config, FlaxTPUGemma2ForCausalLM)
25 | FlaxTPUGemma2ForCausalLM.register_for_auto_class("FlaxAutoModelForCausalLM")
26 | FlaxTPUGemma2Model.register_for_auto_class("FlaxAutoModel")
27 |
28 | __all__ = ["TPULlamaConfig", "TPULlamaModel", "TPULlamaForCausalLM", "FlaxTPULlamaForCausalLM", "FlaxTPULlamaModel", "TPUGemma2Config", "TPUGemma2Model", "TPUGemma2ForCausalLM", "FlaxTPUGemma2ForCausalLM", "FlaxTPUGemma2Model"]
29 |
30 |
31 | def get_config(pretrained_model_name_or_path: str, **kwargs):
32 | config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
33 | # compatibility with outside jax checkpoints
34 | if config.model_type in {"llama", "tpu_llama"}:
35 | config = TPULlamaConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
36 | config.model_type = "tpu_llama"
37 | return config
38 | elif config.model_type in {"gemma2", "tpu_gemma2"}:
39 | config = TPUGemma2Config.from_pretrained(pretrained_model_name_or_path, **kwargs)
40 | config.model_type = "tpu_gemma2"
41 | return config
42 | else:
43 | return config
44 |
--------------------------------------------------------------------------------
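
Importing `tokenkit.hf` registers the TPU Llama/Gemma2 classes with the `Auto*` factories as a side effect, and `get_config` remaps stock `llama`/`gemma2` checkpoints onto the TPU variants. A small sketch (the checkpoint name is a placeholder):

    from tokenkit.hf import get_config

    config = get_config("meta-llama/Llama-3.2-1B")  # hypothetical checkpoint
    print(config.model_type)  # "tpu_llama", even though the checkpoint ships a plain "llama" config
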
/tokenkit/hf/configuration_tpu_gemma2.py:
--------------------------------------------------------------------------------
1 | """TPU Gemma2 model configuration"""
2 |
3 | from transformers.configuration_utils import PretrainedConfig
4 |
5 |
6 | class TPUGemma2Config(PretrainedConfig):
7 | r"""
8 | This is the configuration class to store the configuration of a [`Gemma2Model`]. It is used to instantiate a Gemma2
9 | model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
10 | defaults will yield a similar configuration to that of the Gemma2-7B.
11 | e.g. [google/gemma2-7b](https://huggingface.co/google/gemma2-7b)
12 | Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
13 | documentation from [`PretrainedConfig`] for more information.
14 | Args:
15 | vocab_size (`int`, *optional*, defaults to 256000):
16 | Vocabulary size of the Gemma2 model. Defines the number of different tokens that can be represented by the
17 | `inputs_ids` passed when calling [`Gemma2Model`]
18 | hidden_size (`int`, *optional*, defaults to 3072):
19 | Dimension of the hidden representations.
20 | intermediate_size (`int`, *optional*, defaults to 24576):
21 | Dimension of the MLP representations.
22 | num_hidden_layers (`int`, *optional*, defaults to 28):
23 | Number of hidden layers in the Transformer decoder.
24 | num_attention_heads (`int`, *optional*, defaults to 16):
25 | Number of attention heads for each attention layer in the Transformer decoder.
26 | num_key_value_heads (`int`, *optional*, defaults to 16):
27 | This is the number of key_value heads that should be used to implement Grouped Query Attention. If
28 | `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
29 | `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
30 | converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
31 | by meanpooling all the original heads within that group. For more details checkout [this
32 | paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to
33 | `num_attention_heads`.
34 | head_dim (`int`, *optional*, defaults to 256):
35 | The attention head dimension.
36 | hidden_activation (`str` or `function`, *optional*, defaults to `"gelu_pytorch_tanh"`):
37 | The non-linear activation function (function or string) in the decoder. Will default to `"gelu_pytorch_tanh"`
38 | if not specified. `"gelu_pytorch_tanh"` uses an approximation of the `"gelu"` activation function.
39 | max_position_embeddings (`int`, *optional*, defaults to 8192):
40 | The maximum sequence length that this model might ever be used with.
41 | initializer_range (`float`, *optional*, defaults to 0.02):
42 | The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
43 | rms_norm_eps (`float`, *optional*, defaults to 1e-06):
44 | The epsilon used by the rms normalization layers.
45 | use_cache (`bool`, *optional*, defaults to `True`):
46 | Whether or not the model should return the last key/values attentions (not used by all models). Only
47 | relevant if `config.is_decoder=True`.
48 | pad_token_id (`int`, *optional*, defaults to 0):
49 | Padding token id.
50 | eos_token_id (`int`, *optional*, defaults to 1):
51 | End of stream token id.
52 | bos_token_id (`int`, *optional*, defaults to 2):
53 | Beginning of stream token id.
54 | tie_word_embeddings (`bool`, *optional*, defaults to `True`):
55 | Whether to tie weight embeddings
56 | rope_theta (`float`, *optional*, defaults to 10000.0):
57 | The base period of the RoPE embeddings.
58 | attention_bias (`bool`, *optional*, defaults to `False`):
59 | Whether to use a bias in the query, key, value and output projection layers during self-attention.
60 | attention_dropout (`float`, *optional*, defaults to 0.0):
61 | The dropout ratio for the attention probabilities.
62 | query_pre_attn_scalar (`float`, *optional*, defaults to 224): scaling factor used on the attention scores
63 | sliding_window (`int`, *optional*, defaults to 4096): in Gemma2, every other layer uses sliding window attention. This is the
64 | size of the sliding window.
65 | final_logit_softcapping (`float`, *optional*, defaults to 30.0): scaling factor when applying tanh softcapping on the logits.
66 | attn_logit_softcapping (`float`, *optional*, defaults to 50.0): scaling factor when applying tanh softcapping on the attention scores.
67 | cache_implementation (`str`, *optional*, defaults to `"hybrid"`): the cache type to be used with `generate`.
68 |
69 | ```python
70 | >>> from transformers import Gemma2Model, Gemma2Config
71 | >>> # Initializing a Gemma2 gemma2-7b style configuration
72 | >>> configuration = Gemma2Config()
73 | >>> # Initializing a model from the gemma2-7b style configuration
74 | >>> model = Gemma2Model(configuration)
75 | >>> # Accessing the model configuration
76 | >>> configuration = model.config
77 | ```"""
78 |
79 | model_type = "tpu_gemma2"
80 | keys_to_ignore_at_inference = ["past_key_values"]
81 |
82 | def __init__(
83 | self,
84 | vocab_size=256000,
85 | hidden_size=3072,
86 | intermediate_size=24576,
87 | num_hidden_layers=28,
88 | num_attention_heads=16,
89 | num_key_value_heads=16,
90 | head_dim=256,
91 | hidden_activation="gelu_pytorch_tanh",
92 | max_position_embeddings=8192,
93 | initializer_range=0.02,
94 | rms_norm_eps=1e-6,
95 | use_cache=True,
96 | pad_token_id=0,
97 | eos_token_id=1,
98 | bos_token_id=2,
99 | tie_word_embeddings=True,
100 | rope_theta=10000.0,
101 | attention_bias=False,
102 | attention_dropout=0.0,
103 | query_pre_attn_scalar=224,
104 | sliding_window=4096,
105 | final_logit_softcapping=30.0,
106 | attn_logit_softcapping=50.0,
107 | cache_implementation="hybrid",
108 | expand_input_ids=False, # Transformers-native PyTorch generation support
109 | expand_input_ids_maxlen=None,
110 | expand_input_ids_vocab_size=None,
111 | expand_input_ids_dict=None,
112 | **kwargs,
113 | ):
114 | super().__init__(
115 | pad_token_id=pad_token_id,
116 | bos_token_id=bos_token_id,
117 | eos_token_id=eos_token_id,
118 | tie_word_embeddings=tie_word_embeddings,
119 | **kwargs,
120 | )
121 | self.vocab_size = vocab_size
122 | self.max_position_embeddings = max_position_embeddings
123 | self.hidden_size = hidden_size
124 | self.intermediate_size = intermediate_size
125 | self.num_hidden_layers = num_hidden_layers
126 | self.num_attention_heads = num_attention_heads
127 | self.head_dim = head_dim
128 | self.num_key_value_heads = num_key_value_heads
129 | self.initializer_range = initializer_range
130 | self.rms_norm_eps = rms_norm_eps
131 | self.use_cache = use_cache
132 | self.rope_theta = rope_theta
133 | self.attention_bias = attention_bias
134 | self.attention_dropout = attention_dropout
135 | self.hidden_activation = hidden_activation
136 | self.query_pre_attn_scalar = query_pre_attn_scalar
137 | self.sliding_window = sliding_window
138 | self.final_logit_softcapping = final_logit_softcapping
139 | self.attn_logit_softcapping = attn_logit_softcapping
140 | self.cache_implementation = cache_implementation
141 |
142 | self.expand_input_ids = expand_input_ids
143 | self.expand_input_ids_maxlen = expand_input_ids_maxlen
144 | self.expand_input_ids_vocab_size = expand_input_ids_vocab_size
145 | self.expand_input_ids_dict = expand_input_ids_dict
146 |
--------------------------------------------------------------------------------
/tokenkit/hf/configuration_tpu_llama.py:
--------------------------------------------------------------------------------
1 | """TPU LLaMA model configuration"""
2 |
3 | from transformers.configuration_utils import PretrainedConfig
4 | from transformers.modeling_rope_utils import rope_config_validation
5 |
6 |
7 | class TPULlamaConfig(PretrainedConfig):
8 | r"""
9 | This is the configuration class to store the configuration of a [`LlamaModel`]. It is used to instantiate an LLaMA
10 | model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
11 | defaults will yield a similar configuration to that of the LLaMA-7B.
12 |
13 | Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
14 | documentation from [`PretrainedConfig`] for more information.
15 |
16 |
17 | Args:
18 | vocab_size (`int`, *optional*, defaults to 32000):
19 | Vocabulary size of the LLaMA model. Defines the number of different tokens that can be represented by the
20 | `inputs_ids` passed when calling [`TPULlamaModel`]
21 | hidden_size (`int`, *optional*, defaults to 4096):
22 | Dimension of the hidden representations.
23 | intermediate_size (`int`, *optional*, defaults to 11008):
24 | Dimension of the MLP representations.
25 | num_hidden_layers (`int`, *optional*, defaults to 32):
26 | Number of hidden layers in the Transformer decoder.
27 | num_attention_heads (`int`, *optional*, defaults to 32):
28 | Number of attention heads for each attention layer in the Transformer decoder.
29 | num_key_value_heads (`int`, *optional*):
30 | This is the number of key_value heads that should be used to implement Grouped Query Attention. If
31 | `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
32 | `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
33 | converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
34 | by meanpooling all the original heads within that group. For more details checkout [this
35 | paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to
36 | `num_attention_heads`.
37 | hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
38 | The non-linear activation function (function or string) in the decoder.
39 | max_position_embeddings (`int`, *optional*, defaults to 2048):
40 | The maximum sequence length that this model might ever be used with. Llama 1 supports up to 2048 tokens,
41 | Llama 2 up to 4096, CodeLlama up to 16384.
42 | initializer_range (`float`, *optional*, defaults to 0.02):
43 | The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
44 | rms_norm_eps (`float`, *optional*, defaults to 1e-06):
45 | The epsilon used by the rms normalization layers.
46 | use_cache (`bool`, *optional*, defaults to `True`):
47 | Whether or not the model should return the last key/values attentions (not used by all models). Only
48 | relevant if `config.is_decoder=True`.
49 | pad_token_id (`int`, *optional*):
50 | Padding token id.
51 | bos_token_id (`int`, *optional*, defaults to 1):
52 | Beginning of stream token id.
53 | eos_token_id (`int`, *optional*, defaults to 2):
54 | End of stream token id.
55 | pretraining_tp (`int`, *optional*, defaults to 1):
56 | Experimental feature. Tensor parallelism rank used during pretraining. Please refer to [this
57 | document](https://huggingface.co/docs/transformers/main/perf_train_gpu_many#tensor-parallelism) to
58 | understand more about it. This value is necessary to ensure exact reproducibility of the pretraining
59 | results. Please refer to [this issue](https://github.com/pytorch/pytorch/issues/76232).
60 | tie_word_embeddings (`bool`, *optional*, defaults to `False`):
61 | Whether to tie weight embeddings
62 | rope_theta (`float`, *optional*, defaults to 10000.0):
63 | The base period of the RoPE embeddings.
64 | rope_scaling (`Dict`, *optional*):
65 | Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type
66 | and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value
67 | accordingly.
68 | Expected contents:
69 | `rope_type` (`str`):
70 | The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
71 | 'llama3'], with 'default' being the original RoPE implementation.
72 | `factor` (`float`, *optional*):
73 | Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
74 | most scaling types, a `factor` of x will enable the model to handle sequences of length x *
75 | original maximum pre-trained length.
76 | `original_max_position_embeddings` (`int`, *optional*):
77 | Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
78 | pretraining.
79 | `attention_factor` (`float`, *optional*):
80 | Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
81 | computation. If unspecified, it defaults to value recommended by the implementation, using the
82 | `factor` field to infer the suggested value.
83 | `beta_fast` (`float`, *optional*):
84 | Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
85 | ramp function. If unspecified, it defaults to 32.
86 | `beta_slow` (`float`, *optional*):
87 | Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
88 | ramp function. If unspecified, it defaults to 1.
89 | `short_factor` (`List[float]`, *optional*):
90 | Only used with 'longrope'. The scaling factor to be applied to short contexts (<
91 | `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
92 | size divided by the number of attention heads divided by 2
93 | `long_factor` (`List[float]`, *optional*):
94 | Only used with 'longrope'. The scaling factor to be applied to long contexts (>
95 | `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
96 | size divided by the number of attention heads divided by 2
97 | `low_freq_factor` (`float`, *optional*):
98 | Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE
99 | `high_freq_factor` (`float`, *optional*):
100 | Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
101 | attention_bias (`bool`, *optional*, defaults to `False`):
102 | Whether to use a bias in the query, key, value and output projection layers during self-attention.
103 | attention_dropout (`float`, *optional*, defaults to 0.0):
104 | The dropout ratio for the attention probabilities.
105 | mlp_bias (`bool`, *optional*, defaults to `False`):
106 | Whether to use a bias in up_proj, down_proj and gate_proj layers in the MLP layers.
107 | head_dim (`int`, *optional*):
108 | The attention head dimension. If None, it will default to hidden_size // num_heads
109 |
110 | ```python
111 | >>> from transformers import LlamaModel, LlamaConfig
112 |
113 | >>> # Initializing a LLaMA llama-7b style configuration
114 | >>> configuration = LlamaConfig()
115 |
116 | >>> # Initializing a model from the llama-7b style configuration
117 | >>> model = LlamaModel(configuration)
118 |
119 | >>> # Accessing the model configuration
120 | >>> configuration = model.config
121 | ```"""
122 |
123 | model_type = "tpu_llama"
124 | keys_to_ignore_at_inference = ["past_key_values"]
125 |
126 | def __init__(
127 | self,
128 | vocab_size=32000,
129 | hidden_size=4096,
130 | intermediate_size=11008,
131 | num_hidden_layers=32,
132 | num_attention_heads=32,
133 | num_key_value_heads=None,
134 | hidden_act="silu",
135 | max_position_embeddings=2048,
136 | initializer_range=0.02,
137 | rms_norm_eps=1e-6,
138 | use_cache=True,
139 | pad_token_id=None,
140 | bos_token_id=1,
141 | eos_token_id=2,
142 | pretraining_tp=1,
143 | tie_word_embeddings=False,
144 | rope_theta=10000.0,
145 | rope_scaling=None,
146 | attention_bias=False,
147 | attention_dropout=0.0,
148 | mlp_bias=False,
149 | head_dim=None,
150 | expand_input_ids=False, # Transformers-native PyTorch generation support
151 | expand_input_ids_maxlen=None,
152 | expand_input_ids_vocab_size=None,
153 | expand_input_ids_dict=None,
154 | **kwargs,
155 | ):
156 | self.vocab_size = vocab_size
157 | self.max_position_embeddings = max_position_embeddings
158 | self.hidden_size = hidden_size
159 | self.intermediate_size = intermediate_size
160 | self.num_hidden_layers = num_hidden_layers
161 | self.num_attention_heads = num_attention_heads
162 |
163 | # for backward compatibility
164 | if num_key_value_heads is None:
165 | num_key_value_heads = num_attention_heads
166 |
167 | self.num_key_value_heads = num_key_value_heads
168 | self.hidden_act = hidden_act
169 | self.initializer_range = initializer_range
170 | self.rms_norm_eps = rms_norm_eps
171 | self.pretraining_tp = pretraining_tp
172 | self.use_cache = use_cache
173 | self.rope_theta = rope_theta
174 | self.rope_scaling = rope_scaling
175 | self.attention_bias = attention_bias
176 | self.attention_dropout = attention_dropout
177 | self.mlp_bias = mlp_bias
178 | self.head_dim = head_dim if head_dim is not None else self.hidden_size // self.num_attention_heads
179 | # Validate the correctness of rotary position embeddings parameters
180 | # BC: if there is a 'type' field, copy it to 'rope_type'.
181 | if self.rope_scaling is not None and "type" in self.rope_scaling:
182 | self.rope_scaling["rope_type"] = self.rope_scaling["type"]
183 | rope_config_validation(self)
184 |
185 | self.expand_input_ids = expand_input_ids
186 | self.expand_input_ids_maxlen = expand_input_ids_maxlen
187 | self.expand_input_ids_vocab_size = expand_input_ids_vocab_size
188 | self.expand_input_ids_dict = expand_input_ids_dict
189 |
190 | super().__init__(
191 | pad_token_id=pad_token_id,
192 | bos_token_id=bos_token_id,
193 | eos_token_id=eos_token_id,
194 | tie_word_embeddings=tie_word_embeddings,
195 | **kwargs,
196 | )
--------------------------------------------------------------------------------
/tokenkit/model_kinds.py:
--------------------------------------------------------------------------------
1 | from abc import ABC, abstractmethod
2 | from typing import Callable, Dict, List, Optional
3 |
4 | from tokenkit.constants import BYTES_TO_CHARS, CHARS_TO_BYTES
5 |
6 | BYTE_FALLBACK_MAP = {f"<0x{num:02X}>": num for num in range(256)}
7 | INV_BYTE_FALLBACK_MAP = {v: k for k, v in BYTE_FALLBACK_MAP.items()}
8 |
9 |
10 | def sentencepiece_byte_fallback_byte_fn(token: str) -> str:
11 | if token in BYTE_FALLBACK_MAP:
12 | return BYTES_TO_CHARS[BYTE_FALLBACK_MAP[token]]
13 | else:
14 | return "".join(
15 | BYTES_TO_CHARS[b] for b in token.replace("▁", " ").encode("utf-8")
16 | )
17 |
18 |
19 | def sentencepiece_byte_fallback_precedence_fn(token: str) -> int:
20 | if token in BYTE_FALLBACK_MAP:
21 | return 0
22 | else:
23 | return 1
24 |
25 |
26 | def identity_byte_fn(token: str) -> str:
27 | return token
28 |
29 |
30 | class BaseModelKind(ABC):
31 | SPECIAL_KEYS = [
32 | "<||>",
33 | "<||>",
34 | "<||>",
35 | "<||>",
36 | "<||>",
37 | "<||>",
38 | "<||>",
39 | "<||>",
40 | ]
41 |
42 | def __init__(self):
43 | self._byte_fallback_fn = identity_byte_fn
44 | self._byte_fallback_precedence_fn = lambda x: 0
45 |
46 | @property
47 | @abstractmethod
48 | def special_tokens(self) -> List[str]:
49 | pass
50 |
51 | @property
52 | @abstractmethod
53 | def replacements(self) -> Dict[str, Optional[List[str]]]:
54 | pass
55 |
56 | @property
57 | def byte_fallback_fn(self) -> Callable[[str], str]:
58 | return self._byte_fallback_fn
59 |
60 | @byte_fallback_fn.setter
61 | def byte_fallback_fn(self, value: Callable[[str], str]):
62 | self._byte_fallback_fn = value
63 |
64 | @property
65 | def byte_fallback_precedence_fn(self) -> Callable[[str], int]:
66 | return self._byte_fallback_precedence_fn
67 |
68 | @byte_fallback_precedence_fn.setter
69 | def byte_fallback_precedence_fn(self, value: Callable[[str], int]):
70 | self._byte_fallback_precedence_fn = value
71 |
72 |
73 | class Qwen2ModelKind(BaseModelKind):
74 | @property
75 | def special_tokens(self) -> List[str]:
76 | return ["<|im_start|>", "<|im_end|>", "<|endoftext|>"]
77 |
78 | @property
79 | def replacements(self) -> Dict[str, Optional[List[str]]]:
80 | return {
81 | "<||>": None,
82 | "<||>": ["<|endoftext|>"],
83 | "<||>": ["<|im_start|>"],
84 | "<||>": ["Ċ"],
85 | "<||>": ["<|endoftext|>"],
86 | "<||>": ["<|im_end|>", "Ċ"],
87 | "<||>": ["system"],
88 | "<||>": ["user"],
89 | "<||>": ["assistant"],
90 | }
91 |
92 |
93 | class Llama3ModelKind(BaseModelKind):
94 | @property
95 | def special_tokens(self) -> List[str]:
96 | return [
97 | "<|begin_of_text|>",
98 | "<|start_header_id|>",
99 | "<|end_header_id|>",
100 | "<|eot_id|>",
101 | "<|end_of_text|>",
102 | ]
103 |
104 | @property
105 | def replacements(self) -> Dict[str, Optional[List[str]]]:
106 | return {
107 | "<||>": ["<|begin_of_text|>"],
108 | "<||>": ["<|end_of_text|>"],
109 | "<||>": ["<|start_header_id|>"],
110 | "<||>": ["<|end_header_id|>", "ĊĊ"],
111 | # give eot precedence over eos - not ideal but should work for chat templates
112 | "<||>": ["<|eot_id|>"],
113 | "<||>": ["<|eot_id|>"],
114 | "<||>": ["system"],
115 | "<||>": ["user"],
116 | "<||>": ["assistant"],
117 | }
118 |
119 |
120 | class Gemma2ModelKind(BaseModelKind):
121 | def __init__(self):
122 | super().__init__()
123 | self._byte_fallback_fn = sentencepiece_byte_fallback_byte_fn
124 | self._byte_fallback_precedence_fn = sentencepiece_byte_fallback_precedence_fn
125 |
126 | @property
127 | def special_tokens(self) -> List[str]:
128 | return ["", "", "", "", ""]
129 |
130 | @property
131 | def replacements(self) -> Dict[str, Optional[List[str]]]:
132 | return {
133 | "<||>": [""],
134 | "<||>": [""],
135 | "<||>": [""],
136 | "<||>": ["Ċ"],
137 | "<||>": [""],
138 | "<||>": ["", "Ċ"],
139 | "<||>": ["user"],
140 | "<||>": ["user"],
141 | "<||>": ["model"],
142 | }
143 |
144 |
145 | class Phi3ModelKind(BaseModelKind):
146 | @property
147 | def special_tokens(self) -> List[str]:
148 | return ["<|user|>", "<|assistant|>", "<|end|>", "<|endoftext|>"]
149 |
150 | @property
151 | def replacements(self) -> Dict[str, Optional[List[str]]]:
152 | return {
153 | "<||>": None,
154 | "<||>": ["<|endoftext|>"],
155 | "<||>": None,
156 | "<||>": ["Ċ"],
157 | "<||>": ["<|endoftext|>"],
158 | "<||>": ["<|end|>", "Ċ"],
159 | "<||>": ["<|user|>"],
160 | "<||>": ["<|user|>"],
161 | "<||>": ["<|assistant|>"],
162 | }
163 |
164 |
165 | class GPT2ModelKind(BaseModelKind):
166 | @property
167 | def special_tokens(self) -> List[str]:
168 | return ["<|endoftext|>"]
169 |
170 | @property
171 | def replacements(self) -> Dict[str, Optional[List[str]]]:
172 | return {
173 | "<||>": None,
174 | "<||>": ["<|endoftext|>"],
175 | "<||>": None,
176 | "<||>": None,
177 | "<||>": ["<|endoftext|>"],
178 | "<||>": ["<|endoftext|>"],
179 | "<||>": None,
180 | "<||>": None,
181 | "<||>": None,
182 | }
183 |
184 |
185 | class TinyLlamaModelKind(BaseModelKind):
186 | def __init__(self):
187 | super().__init__()
188 | self._byte_fallback_fn = sentencepiece_byte_fallback_byte_fn
189 | self._byte_fallback_precedence_fn = sentencepiece_byte_fallback_precedence_fn
190 |
191 | @property
192 | def special_tokens(self) -> List[str]:
193 | return ["", "", ""]
194 |
195 | @property
196 | def replacements(self) -> Dict[str, Optional[List[str]]]:
197 | return {
198 | "<||>": [""],
199 | "<||>": [""],
200 | "<||>": None, # chat template exists but not supported for TinyLlama
201 | "<||>": None,
202 | "<||>": [""],
203 | "<||>": [""],
204 | "<||>": None,
205 | "<||>": None,
206 | "<||>": None,
207 | }
208 |
209 |
210 | class MistralModelKind(BaseModelKind):
211 | def __init__(self):
212 | super().__init__()
213 |
214 | @property
215 | def special_tokens(self) -> List[str]:
216 | return ["", "", ""]
217 |
218 | @property
219 | def replacements(self) -> Dict[str, Optional[List[str]]]:
220 | return {
221 | "<||>": [""],
222 | "<||>": [""],
223 | "<||>": None, # chat template exists but not supported for Mistral
224 | "<||>": None,
225 | "<||>": [""],
226 | "<||>": [""],
227 | "<||>": None,
228 | "<||>": None,
229 | "<||>": None,
230 | }
231 |
232 | # Model kind registry
233 | def get_model_kind_cls(model_kind: str) -> BaseModelKind:
234 | return {
235 | "Qwen2": Qwen2ModelKind(),
236 | "Llama3": Llama3ModelKind(),
237 | "Gemma2": Gemma2ModelKind(),
238 | "Phi3": Phi3ModelKind(),
239 | "GPT2": GPT2ModelKind(),
240 | "TinyLlama": TinyLlamaModelKind(),
241 | "Mistral": MistralModelKind(),
242 | }[model_kind]
243 |
--------------------------------------------------------------------------------
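
A short sketch of the registry and the byte-fallback hooks; the outputs follow directly from the classes and tables above:

    from tokenkit.model_kinds import get_model_kind_cls

    qwen_kind = get_model_kind_cls("Qwen2")
    print(qwen_kind.special_tokens)  # ['<|im_start|>', '<|im_end|>', '<|endoftext|>']

    gemma_kind = get_model_kind_cls("Gemma2")
    # Gemma2 uses SentencePiece byte-fallback pieces; "<0x41>" is byte 0x41, i.e. "A".
    print(gemma_kind.byte_fallback_fn("<0x41>"))             # 'A'
    print(gemma_kind.byte_fallback_precedence_fn("<0x41>"))  # 0 for byte-fallback pieces, 1 otherwise
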
/tokenkit/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bminixhofer/tokenkit/7abac3578f6b3fa38f985e9f03bff7a47d5ab3b1/tokenkit/models/__init__.py
--------------------------------------------------------------------------------
/tokenkit/models/lora.py:
--------------------------------------------------------------------------------
1 | import jax
2 | import jax.numpy as jnp
3 | import numpy as np
4 | import regex as re
5 |
6 | from tokenkit import utils
7 |
8 | LORA_PATTERNS = {
9 | "llama": [
10 | ".*self_attn.(q_proj|k_proj|v_proj).kernel",
11 | ".*self_attn.o_proj.kernel",
12 | ".*mlp.down_proj.kernel",
13 | ".*mlp.up_proj.kernel",
14 | ".*mlp.gate_proj.kernel",
15 | ],
16 | "gemma2": [
17 | ".*self_attn.(q_proj|k_proj|v_proj).kernel",
18 | ".*self_attn.o_proj.kernel",
19 | ".*mlp.down_proj.kernel",
20 | ".*mlp.up_proj.kernel",
21 | ".*mlp.gate_proj.kernel",
22 | ],
23 | }
24 | LORA_PATTERNS["tpu_llama"] = LORA_PATTERNS["llama"]
25 | LORA_PATTERNS["tpu_gemma2"] = LORA_PATTERNS["gemma2"]
26 |
27 |
28 | def init_lora_params(args, params, model_type, seed, dtype=jnp.float32):
29 | def iter_keys(key):
30 | while True:
31 | key, out_key = jax.random.split(key)
32 | yield out_key
33 |
34 | key_it = iter_keys(jax.random.PRNGKey(seed))
35 |
36 | lora_patterns = LORA_PATTERNS[model_type]
37 | lora_rank = args.model_lora_rank
38 | stddev = 1.0 / lora_rank
39 |
40 | def init_lora(path, param):
41 | path_tuple = tuple(str(utils.keystr(x)) for x in path)
42 | path = ".".join(path_tuple)
43 |
44 | lora_params = np.array([]) # indicates no lora params
45 |
46 | for key in lora_patterns:
47 | if re.match(key, path):
48 | assert len(param.shape) == 2
49 | b_dim, a_dim = param.shape
50 |
51 | b = np.zeros((b_dim, lora_rank), dtype=dtype)
52 | a = jax.device_get(
53 | jax.random.normal(next(key_it), (lora_rank, a_dim), dtype=dtype)
54 | * stddev
55 | )
56 | lora_params = {"a": a, "b": b}
57 |
58 | return lora_params
59 |
60 | return jax.tree_util.tree_map_with_path(init_lora, params)
61 |
62 |
63 | def materialize_lora(param_tree, lora_param_tree, alpha):
64 | def materialize(param, lora_params):
65 | if not isinstance(lora_params, dict):
66 | assert lora_params.shape[0] == 0
67 | return param
68 |
69 | a, b = lora_params["a"], lora_params["b"]
70 | scale = alpha / b.shape[-1]
71 |
72 | return (param + scale * b @ a).astype(param.dtype)
73 |
74 | return jax.tree.map(materialize, param_tree, lora_param_tree)
75 |
76 |
77 | # NOTE: not clear if this is safe w.r.t. rounding errors. probably not? dangerous.
78 | # NOTE: update: no instability so far, seems safe in fp32. but still dangerous.
79 | def dematerialize_lora(param_tree, lora_param_tree, alpha):
80 | def dematerialize(param, lora_params):
81 | if not isinstance(lora_params, dict):
82 | return param
83 |
84 | a, b = lora_params["a"], lora_params["b"]
85 | scale = alpha / b.shape[-1]
86 |
87 | return (param - scale * b @ a).astype(param.dtype)
88 |
89 | return jax.tree.map(dematerialize, param_tree, lora_param_tree)
90 |
--------------------------------------------------------------------------------
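A small end-to-end sketch of the LoRA helpers above; `args` only needs a `model_lora_rank` attribute here (a stand-in SimpleNamespace), and the toy parameter path is chosen so that it matches the "llama" patterns:

from types import SimpleNamespace

import numpy as np

from tokenkit.models import lora

args = SimpleNamespace(model_lora_rank=4)
params = {
    "model": {
        "layers": {"0": {"self_attn": {"q_proj": {"kernel": np.ones((8, 8), dtype=np.float32)}}}}
    }
}

# zero-initialized `b`, Gaussian `a`; leaves that match no pattern get an empty placeholder array
lora_params = lora.init_lora_params(args, params, model_type="llama", seed=0)

merged = lora.materialize_lora(params, lora_params, alpha=8.0)      # param + (alpha / rank) * b @ a
restored = lora.dematerialize_lora(merged, lora_params, alpha=8.0)  # subtracts the same delta again

--------------------------------------------------------------------------------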
/tokenkit/models/param.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import json
3 |
4 | import jax.numpy as jnp
5 | import numpy as np
6 | from flax import serialization, traverse_util
7 | from transformers import AutoConfig
8 | from transformers.utils.hub import cached_file
9 |
10 |
11 | def get_input_embedding_path(model_type):
12 | return {
13 | "gpt2": "transformer.wte.embedding",
14 | "roberta": "roberta.embeddings.word_embeddings.embedding",
15 | "xlm-roberta": "roberta.embeddings.word_embeddings.embedding",
16 | "xglm": "model.embed_tokens.embedding",
17 | "mistral": "model.embed_tokens.embedding",
18 | "llama": "model.embed_tokens.embedding",
19 | "tpu_llama": "model.embed_tokens.embedding",
20 | "gemma": "model.embed_tokens.embedding",
21 | "gemma2": "model.embed_tokens.embedding",
22 | "tpu_gemma2": "model.embed_tokens.embedding",
23 | }[model_type]
24 |
25 |
26 | def get_output_embedding_path(model_type):
27 | return {
28 | "gpt2": "lm_head.kernel",
29 | "roberta": None,
30 | "xlm-roberta": None,
31 | "xglm": None,
32 | "mistral": "lm_head.kernel",
33 | "llama": "lm_head.kernel",
34 | "tpu_llama": "lm_head.kernel",
35 | "gemma": "lm_head.kernel",
36 | "gemma2": "lm_head.kernel",
37 | "tpu_gemma2": "lm_head.kernel",
38 | }[model_type]
39 |
40 |
41 | def get_layer_path(model_type):
42 | return {
43 | "gemma2": "model.layers",
44 | "gpt2": "transformer.h",
45 | "llama": "model.layers",
46 | "tpu_llama": "model.layers",
47 | "tpu_gemma2": "model.layers",
48 | }[model_type]
49 |
50 |
51 | def load_params(**kwargs):
52 | kwargs = copy.copy(kwargs)
53 | config = AutoConfig.from_pretrained(**kwargs)
54 | path = kwargs.pop("pretrained_model_name_or_path")
55 | embedding_path = kwargs.pop("embedding_path", None)
56 |
57 | try:
58 | index = cached_file(path, "flax_model.msgpack.index.json", **kwargs)
59 | except OSError:
60 | index = None
61 |
62 | if index is not None:
63 | index = json.load(open(index))
64 | files = [
65 | cached_file(path, x, **kwargs) for x in set(index["weight_map"].values())
66 | ]
67 | else:
68 | files = [cached_file(path, "flax_model.msgpack", **kwargs)]
69 |
70 | flat_params = {}
71 | for x in files:
72 | flat_params.update(
73 | traverse_util.flatten_dict(
74 | serialization.msgpack_restore(open(x, "rb").read())
75 | )
76 | )
77 |
78 | params = traverse_util.unflatten_dict(flat_params)
79 |
80 | if embedding_path is not None:
81 | embeddings = np.load(embedding_path)
82 | params = put(
83 | params, get_input_embedding_path(config.model_type), embeddings[:, 0]
84 | )
85 | if embeddings.shape[1] > 1:
86 | params = put(
87 | params, get_output_embedding_path(config.model_type), embeddings[:, 1].T
88 | )
89 |
90 | return params
91 |
92 |
93 | def put(pytree, path, value):
94 | path = tuple(path.split("."))
95 |
96 | flat_pytree = traverse_util.flatten_dict(pytree)
97 | # this is potentially safer than simply overwriting, preserves dtype etc.
98 | if path in flat_pytree and isinstance(flat_pytree[path], jnp.ndarray):
99 | flat_pytree[path] = flat_pytree[path].at[:].set(value)
100 | else:
101 | flat_pytree[path] = value
102 |
103 | return traverse_util.unflatten_dict(flat_pytree)
104 |
105 |
106 | def pop(pytree, path):
107 | path = tuple(path.split("."))
108 | flat_pytree = traverse_util.flatten_dict(pytree)
109 | if path in flat_pytree:
110 | value = flat_pytree.pop(path)
111 | else:
112 | value = None
113 |
114 | return traverse_util.unflatten_dict(flat_pytree), value
115 |
116 |
117 | def get(pytree, path):
118 | path = tuple(path.split("."))
119 | out = traverse_util.flatten_dict(pytree)[path]
120 |
121 | if isinstance(out, dict):
122 | return traverse_util.unflatten_dict(out)
123 | else:
124 | return out
125 |
126 |
127 | def keys(pytree):
128 | return [".".join(x) for x in traverse_util.flatten_dict(pytree).keys()]
129 |
130 |
131 | def assign_embeddings(model_params, embeddings, config):
132 | model_params = put(
133 | model_params,
134 | get_input_embedding_path(config.model_type),
135 | embeddings[:, 0],
136 | )
137 | if not config.tie_word_embeddings:
138 | model_params = put(
139 | model_params,
140 | get_output_embedding_path(config.model_type),
141 | embeddings[:, 1].T,
142 | )
143 |
144 | return model_params
145 |
146 |
147 | def unassign_embeddings(model_params, config):
148 | model_params, x = pop(model_params, get_input_embedding_path(config.model_type))
149 | if isinstance(x, jnp.ndarray):
150 | x.delete()
151 | if get_output_embedding_path(config.model_type):
152 | model_params, x = pop(
153 | model_params, get_output_embedding_path(config.model_type)
154 | )
155 | if isinstance(x, jnp.ndarray):
156 | x.delete()
157 |
158 | return model_params
159 |
160 |
161 | def stack_embeddings(model_params, config, pop_embeddings=False):
162 | if config.tie_word_embeddings:
163 | input_embeddings = get(
164 | model_params, get_input_embedding_path(config.model_type)
165 | )
166 |
167 | embeddings = input_embeddings[:, None, :]
168 | else:
169 | input_embeddings = get(
170 | model_params, get_input_embedding_path(config.model_type)
171 | )
172 | output_embeddings = get(
173 | model_params, get_output_embedding_path(config.model_type)
174 | )
175 |
176 | embeddings = np.stack([input_embeddings, output_embeddings.T], axis=1)
177 |
178 | if pop_embeddings:
179 | model_params = unassign_embeddings(model_params, config)
180 |
181 | return embeddings, model_params
182 |
183 |
184 | def get_num_layers(config):
185 | if hasattr(config, "num_hidden_layers"):
186 | return config.num_hidden_layers
187 | elif hasattr(config, "n_layer"): # gpt2
188 | return config.n_layer
189 | else:
190 | raise ValueError("Could not determine number of layers from config")
191 |
192 |
193 | def set_num_layers(config, num_layers):
194 | if hasattr(config, "num_hidden_layers"):
195 | config.num_hidden_layers = num_layers
196 | elif hasattr(config, "n_layer"): # gpt2
197 | config.n_layer = num_layers
198 | else:
199 | raise ValueError("Could not determine number of layers from config")
200 |
201 |
202 | def get_layer_n_mask(model_params, config, layer_idx):
203 | if layer_idx < 0:
204 | layer_idx = get_num_layers(config) + layer_idx
205 |
206 | flat_params = traverse_util.flatten_dict(model_params)
207 | mask = {}
208 | subpath = f"{get_layer_path(config.model_type)}.{layer_idx}"
209 |
210 | for key in flat_params.keys():
211 | if subpath in ".".join(key):
212 | mask[key] = True
213 | else:
214 | mask[key] = False
215 |
216 | return traverse_util.unflatten_dict(mask)
217 |
218 |
219 | def strip_layers(model_params, config, n_keep=1):
220 | for layer_idx in range(n_keep, get_num_layers(config)):
221 | model_params, _ = pop(
222 | model_params, f"{get_layer_path(config.model_type)}.{layer_idx}"
223 | )
224 |
225 | set_num_layers(config, n_keep)
226 |
227 | return model_params
228 |
--------------------------------------------------------------------------------
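The dotted-path helpers above can be exercised on a toy pytree (paths and shapes here are made up):

import numpy as np

from tokenkit.models import param as param_utils

tree = {"model": {"embed_tokens": {"embedding": np.zeros((4, 2))}}}

print(param_utils.keys(tree))                                # ['model.embed_tokens.embedding']
emb = param_utils.get(tree, "model.embed_tokens.embedding")  # fetch a leaf by dotted path
tree = param_utils.put(tree, "model.embed_tokens.embedding", np.ones((4, 2)))
tree, popped = param_utils.pop(tree, "model.embed_tokens.embedding")

--------------------------------------------------------------------------------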
/tokenkit/models/sharding.py:
--------------------------------------------------------------------------------
1 | import logging
2 |
3 | import jax
4 | import jax.experimental
5 | import jax.experimental.mesh_utils
6 | import regex as re
7 | from jax.experimental.multihost_utils import process_allgather
8 | from jax.sharding import PartitionSpec as P
9 | import numpy as np
10 |
11 | from tokenkit import utils
12 |
13 | logger = logging.getLogger(__name__)
14 |
15 |
16 | SHARD_PATTERNS = {
17 | "hypernet": {
18 | "(opt_state|params).*ffn_layer1.linear": P(None, "model"),
19 | "(opt_state|params).*ffn_layer2.linear": P("model", None),
20 | "(opt_state|params).*self_attention.(query|key|value).w": P(None, "model"),
21 | "(opt_state|params).*self_attention.post.w": P("model", None),
22 | "(opt_state|params).*embeddings": P("model", None),
23 | },
24 | "llama": {
25 | "(opt_state|params).*embed_tokens.*embedding": P("model", "data"),
26 | "(opt_state|params).*self_attn.(q_proj|k_proj|v_proj).kernel.a": P(
27 | "model", "data"
28 | ),
29 | "(opt_state|params).*self_attn.(q_proj|k_proj|v_proj).kernel.b": P(
30 | "data", "model"
31 | ),
32 | "(opt_state|params).*self_attn.(q_proj|k_proj|v_proj).kernel.w": P(
33 | "data", "model"
34 | ),
35 | "(opt_state|params).*norm.weight": P("model"),
36 | "(opt_state|params).*self_attn.(q_proj|k_proj|v_proj).kernel": P(
37 | "data", "model"
38 | ),
39 | "(opt_state|params).*self_attn.o_proj.kernel": P("model", "data"),
40 | "(opt_state|params).*lm_head.kernel": P("data", "model"),
41 | "(opt_state|params).*mlp.down_proj.kernel": P("model", "data"),
42 | "(opt_state|params).*mlp.up_proj.kernel": P("data", "model"),
43 | "(opt_state|params).*mlp.gate_proj.kernel": P("data", "model"),
44 | "(opt_state|params).*norm.kernel": P("model"),
45 | ".*(cached_value|cached_key)": P("data", None, "model", None),
46 | },
47 | "mistral": {
48 | "(opt_state|params).*embed_tokens.*embedding": P("model", None),
49 | "(opt_state|params).*self_attn.(q_proj|k_proj|v_proj).kernel": P(None, "model"),
50 | "(opt_state|params).*self_attn.o_proj.kernel": P("model", None),
51 | "(opt_state|params).*lm_head.kernel": P(None, "model"),
52 | "(opt_state|params).*mlp.down_proj.kernel": P("model", None),
53 | "(opt_state|params).*mlp.up_proj.kernel": P(None, "model"),
54 | "(opt_state|params).*mlp.gate_proj.kernel": P(None, "model"),
55 | },
56 | "gemma": {
57 | "(opt_state|params).*embed_tokens.*embedding": P("model", "data"),
58 | "(opt_state|params).*self_attn.(q_proj|k_proj|v_proj).kernel.a": P(
59 | "model", "data"
60 | ),
61 | "(opt_state|params).*self_attn.(q_proj|k_proj|v_proj).kernel.b": P(
62 | "data", "model"
63 | ),
64 | "(opt_state|params).*self_attn.(q_proj|k_proj|v_proj).kernel.w": P(
65 | "data", "model"
66 | ),
67 | "(opt_state|params).*self_attn.(q_proj|k_proj|v_proj).kernel": P(
68 | "data", "model"
69 | ),
70 | "(opt_state|params).*self_attn.o_proj.kernel": P("model", "data"),
71 | "(opt_state|params).*lm_head.kernel": P("data", "model"),
72 | "(opt_state|params).*mlp.down_proj.kernel": P("model", "data"),
73 | "(opt_state|params).*mlp.up_proj.kernel": P("data", "model"),
74 | "(opt_state|params).*mlp.gate_proj.kernel": P("data", "model"),
75 | "(opt_state|params).*norm.kernel": P("model"),
76 | },
77 | "gemma2": {
78 | "(opt_state|params).*embed_tokens.*embedding": P("model", "data"),
79 | "(opt_state|params).*self_attn.(q_proj|k_proj|v_proj).kernel.a": P(
80 | "model", "data"
81 | ),
82 | "(opt_state|params).*self_attn.(q_proj|k_proj|v_proj).kernel.b": P(
83 | "data", "model"
84 | ),
85 | "(opt_state|params).*self_attn.(q_proj|k_proj|v_proj).kernel.w": P(
86 | "data", "model"
87 | ),
88 | "(opt_state|params).*self_attn.(q_proj|k_proj|v_proj).kernel": P(
89 | "data", "model"
90 | ),
91 | "(opt_state|params).*self_attn.o_proj.kernel": P("model", "data"),
92 | "(opt_state|params).*lm_head.kernel": P("data", "model"),
93 | "(opt_state|params).*mlp.down_proj.kernel": P("model", "data"),
94 | "(opt_state|params).*mlp.up_proj.kernel": P("data", "model"),
95 | "(opt_state|params).*mlp.gate_proj.kernel": P("data", "model"),
96 | "(opt_state|params).*norm.kernel": P("model"),
97 | },
98 | "gpt2": {
99 | "(opt_state|params).*c_attn.kernel": P(None, "model"),
100 | "(opt_state|params).*c_proj.kernel": P("model", None),
101 | "(opt_state|params).*c_fc.kernel": P(None, "model"),
102 | },
103 | "xlm-roberta": {
104 | "(opt_state|params).*self.(query|key|value).kernel": P(None, "model"),
105 | "(opt_state|params).*output.dense.kernel": P("model", None),
106 | "(opt_state|params).*intermediate.dense.kernel": P(None, "model"),
107 | },
108 | }
109 | SHARD_PATTERNS["tpu_llama"] = SHARD_PATTERNS["llama"]
110 | SHARD_PATTERNS["tpu_gemma2"] = SHARD_PATTERNS["gemma2"]
111 |
112 | def get_shard_patterns(kind):
113 | return SHARD_PATTERNS.get(kind, {})
114 |
115 |
116 | def get_sharding_fn(shard_patterns, mesh):
117 | name_to_size = {name: size for name, size in mesh.shape_tuple}
118 |
119 | def get_pspec(path, v):
120 | # this is a dummy parameter for e.g. PEFT, so no need to shard
121 | if np.prod(v.shape) == 0:
122 | return P()
123 |
124 | path_tuple = tuple(str(utils.keystr(x)) for x in path)
125 | path = ".".join(path_tuple)
126 |
127 | for key, value in shard_patterns.items():
128 | if re.match(key, path):
129 | pspec = value
130 | for dim, name in enumerate(pspec):
131 | if name is None:
132 | continue
133 |
134 | if name not in name_to_size:
135 | raise ValueError(
136 | f"Unknown sharding name {name} in {pspec} for {path}"
137 | )
138 |
139 | if v.shape[dim] % name_to_size[name] != 0:
140 | logger.warning(
141 | "Want to shard %s with %s, but shape %s is not divisible by %s.",
142 | path,
143 | pspec,
144 | v.shape,
145 | name_to_size[name],
146 | )
147 | return P()
148 |
149 | logger.debug("Sharding %s with %s.", path, pspec)
150 | return P(*pspec)
151 |
152 | return P()
153 |
154 | def get_tree_shardings(tree):
155 | pspecs = jax.tree_util.tree_map_with_path(get_pspec, tree)
156 | return jax.tree.map(
157 | lambda pspec: jax.sharding.NamedSharding(mesh, pspec), pspecs
158 | )
159 |
160 | return get_tree_shardings
161 |
162 |
163 | def to_global_array(pytree, pytree_sharding=None):
164 | if pytree_sharding is None:
165 | pytree_sharding = jax.tree.map(lambda _: None, pytree)
166 |
167 | def to_global_array_fn(array, sharding):
168 | if array is None:
169 | return None
170 |
171 | if sharding is None:
172 | return array
173 |
174 | def cb(index):
175 | return array[index]
176 |
177 | return jax.make_array_from_callback(array.shape, sharding, cb)
178 |
179 | return jax.tree.map(to_global_array_fn, pytree, pytree_sharding)
180 |
181 |
182 | def sync_across_devices(pytree):
183 | if jax.process_count() == 1:
184 | return pytree
185 |
186 | return jax.tree.map(lambda x: x[0], process_allgather(pytree))
187 |
188 |
189 | def to_devices(pytree, pytree_sharding=None, dtype=None):
190 | # TODO: handle non-numpy inputs?
191 | pytree = to_global_array(pytree, pytree_sharding)
192 |
193 | return jax.jit(
194 | lambda x: x if dtype is None else jax.tree.map(lambda x: x.astype(dtype), x),
195 | in_shardings=(pytree_sharding,) if pytree_sharding is not None else None,
196 | out_shardings=pytree_sharding,
197 | )(pytree)
198 |
199 |
200 | def get_mesh(n_data_parallel=1, n_model_parallel=-1, devices=None):
201 | if devices is None:
202 | devices = jax.devices()
203 |
204 | device_count = len(devices)
205 |
206 | if n_data_parallel == -1:
207 | n_data_parallel = device_count
208 |
209 | if n_model_parallel == -1:
210 | n_model_parallel = device_count
211 |
212 | devices = jax.experimental.mesh_utils.create_device_mesh(
213 | mesh_shape=(n_data_parallel, n_model_parallel),
214 | devices=devices,
215 | )
216 | return jax.sharding.Mesh(devices, ["data", "model"])
217 |
--------------------------------------------------------------------------------
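A single-host sketch of how these pieces fit together; the toy parameter tree and its shapes are invented, and any axis that is not evenly divisible by the mesh falls back to replication, as the warning in `get_pspec` describes:

import numpy as np

from tokenkit.models import sharding

mesh = sharding.get_mesh(n_data_parallel=1, n_model_parallel=-1)  # -1 = use all available devices
patterns = sharding.get_shard_patterns("llama")
get_tree_shardings = sharding.get_sharding_fn(patterns, mesh)

params = {
    "params": {"model": {"layers": {"0": {"mlp": {"up_proj": {"kernel": np.zeros((16, 64))}}}}}}
}
shardings = get_tree_shardings(params)  # NamedSharding per leaf, P() where no pattern matches
sharded = sharding.to_devices(params, shardings)

--------------------------------------------------------------------------------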
/tokenkit/parse_args.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass, fields, is_dataclass
2 | from transformers import HfArgumentParser
3 | from argparse import ArgumentParser
4 | import yaml
5 |
6 |
7 | @dataclass
8 | class HypernetArgs:
9 | architecture: str
10 | num_layers: int
11 | residual: bool
12 | residual_alpha: float
13 | use_attention: bool
14 | use_attention_mask: bool = False
15 | num_heads: int = 16
16 | shared: bool = True
17 | multiply_hidden_dim_by_num_embeddings: bool = True
18 |
19 | @dataclass
20 | class EvalArgs:
21 | tasks: list[str]
22 | lengths: list[int]
23 | tokens_per_batch: int
24 | add_bos: bool
25 | chat_template_mode: str
26 | confirm_run_unsafe_code: bool
27 |
28 | @dataclass
29 | class ModelArgs:
30 | pretrained_model_name_or_path: str
31 | tokenizer_name: str
32 | revision: str | None = None
33 |
34 | def restore_dataclasses(args, cls):
35 | for field in fields(cls):
36 | if is_dataclass(field.type):
37 | setattr(
38 | args,
39 | field.name,
40 | restore_dataclasses(getattr(args, field.name), field.type),
41 | )
42 |         elif getattr(field.type, "__origin__", None) is list and is_dataclass(field.type.__args__[0]):
43 | setattr(
44 | args,
45 | field.name,
46 | [
47 | restore_dataclasses(item, field.type.__args__[0])
48 | for item in getattr(args, field.name)
49 | ],
50 | )
51 |         elif getattr(field.type, "__origin__", None) is dict:
52 | setattr(
53 | args,
54 | field.name,
55 | {
56 | k: restore_dataclasses(v, field.type.__args__[1])
57 | for k, v in getattr(args, field.name).items()
58 | },
59 | )
60 |
61 | if not isinstance(args, cls):
62 | return cls(**args) if args is not None else None
63 |
64 | return args
65 |
66 |
67 | def parse_args(cls):
68 | parser = ArgumentParser()
69 | parser.add_argument("--config", type=str, required=True)
70 | parser.add_argument("--overrides", type=str, nargs="*")
71 | meta_args = parser.parse_args()
72 |
73 | (args,) = HfArgumentParser([cls]).parse_yaml_file(meta_args.config)
74 |
75 | for overrides in (meta_args.overrides or []):
76 | for override in overrides.split():
77 | first_equals = override.find("=")
78 | key = override[:first_equals].split(".")
79 | try:
80 | value = yaml.safe_load(override[first_equals + 1 :])
81 | except yaml.YAMLError:
82 | raise ValueError(f"Invalid YAML: {override[first_equals + 1 :]}")
83 |
84 | current = args
85 | for k in key[:-1]:
86 | if isinstance(current, list):
87 | current = current[int(k)]
88 | elif isinstance(current, dict):
89 | current = current[k]
90 | else:
91 | current = getattr(current, k)
92 |
93 | if isinstance(current, list):
94 | if int(key[-1]) >= len(current):
95 | raise ValueError(f"Invalid key: {key[-1]}")
96 | current[int(key[-1])] = value
97 | elif isinstance(current, dict):
98 | if key[-1] not in current:
99 | raise ValueError(f"Invalid key: {key[-1]}")
100 | current[key[-1]] = value
101 | else:
102 | if not hasattr(current, key[-1]):
103 | raise ValueError(f"Invalid key: {key[-1]}")
104 | setattr(current, key[-1], value)
105 |
106 | return restore_dataclasses(args, cls)
107 |
--------------------------------------------------------------------------------
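A sketch of the `--config`/`--overrides` convention implemented above, with a made-up dataclass and config file name ("args.yaml" is assumed to define at least `name: ...`):

import sys
from dataclasses import dataclass

from tokenkit.parse_args import parse_args


@dataclass
class Args:
    name: str
    steps: int = 100


# equivalent to: python script.py --config args.yaml --overrides "steps=200"
sys.argv = ["script.py", "--config", "args.yaml", "--overrides", "steps=200"]
args = parse_args(Args)  # args.steps == 200 after the dotted-key override

--------------------------------------------------------------------------------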
/tokenkit/training/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bminixhofer/tokenkit/7abac3578f6b3fa38f985e9f03bff7a47d5ab3b1/tokenkit/training/__init__.py
--------------------------------------------------------------------------------
/tokenkit/training/checkpoint.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 | import jax
4 | from flax import serialization, traverse_util
5 | from jax.experimental import multihost_utils
6 | from jax.sharding import NamedSharding
7 | from jax.sharding import PartitionSpec as P
8 |
9 |
10 | def save(
11 | path,
12 | params,
13 | param_shardings,
14 | mesh,
15 | train_mask,
16 | keys_to_keep={
17 | "hypernet",
18 | },
19 | batch_size=16,
20 | ):
21 | flat_keys_to_save = [
22 | k
23 | for k, trainable in traverse_util.flatten_dict(train_mask).items()
24 | if trainable or k[0] in keys_to_keep
25 | ]
26 | flat_params = traverse_util.flatten_dict(params)
27 | flat_shardings = traverse_util.flatten_dict(param_shardings)
28 |
29 | flat_params_to_save = {k: flat_params[k] for k in flat_keys_to_save}
30 | shardings_to_save = {k: flat_shardings[k] for k in flat_keys_to_save}
31 |
32 | none_shardings_to_save = jax.tree.map(
33 | lambda _: NamedSharding(mesh, P()), shardings_to_save
34 | )
35 |
36 | keys = list(flat_params_to_save.keys())
37 | n_batches = math.ceil(len(keys) / batch_size)
38 |
39 | all_flat_out_params = {}
40 |
41 | for i in range(n_batches):
42 | batch_keys = keys[i * batch_size : (i + 1) * batch_size]
43 |
44 | flat_device_params = jax.jit(
45 | lambda x: x,
46 | in_shardings=([shardings_to_save[k] for k in batch_keys],),
47 | out_shardings=[none_shardings_to_save[k] for k in batch_keys],
48 | )([flat_params_to_save[k] for k in batch_keys])
49 |
50 | for key, value in zip(batch_keys, flat_device_params):
51 | all_flat_out_params[key] = jax.device_get(value)
52 | value.delete()
53 |
54 | if jax.process_index() == 0:
55 | open(path, "wb").write(
56 | serialization.msgpack_serialize(
57 | traverse_util.unflatten_dict(all_flat_out_params), in_place=True
58 | )
59 | )
60 |
61 | multihost_utils.sync_global_devices("saved checkpoint")
62 |
--------------------------------------------------------------------------------
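A single-process sketch of `save`: with the mask below only the `hypernet` leaf is written, since `model.*` is neither trainable nor listed in `keys_to_keep` (the output path and toy trees are placeholders):

import jax
import numpy as np
from jax.sharding import NamedSharding, PartitionSpec as P

from tokenkit.models import sharding
from tokenkit.training import checkpoint

mesh = sharding.get_mesh()  # (1, n_devices) mesh
params = {"hypernet": {"kernel": np.zeros((4, 4))}, "model": {"kernel": np.zeros((4, 4))}}
param_shardings = jax.tree.map(lambda _: NamedSharding(mesh, P()), params)  # fully replicated
train_mask = {"hypernet": {"kernel": True}, "model": {"kernel": False}}

checkpoint.save("params.msgpack", params, param_shardings, mesh, train_mask)

--------------------------------------------------------------------------------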
/tokenkit/training/collators/__init__.py:
--------------------------------------------------------------------------------
1 | from tokenkit.training.collators.tokenizer_aligner import TokenizerAlignerCollator
2 | from tokenkit.training.collators.tokenizer_sampler import TokenizerSamplerCollator
3 |
4 | __all__ = [
5 | "TokenizerAlignerCollator",
6 | "TokenizerSamplerCollator",
7 | ]
8 |
--------------------------------------------------------------------------------
/tokenkit/training/collators/tokenizer_aligner.py:
--------------------------------------------------------------------------------
1 | import json
2 | from pathlib import Path
3 |
4 | import numpy as np
5 | from jax.sharding import PartitionSpec as P
6 | from scipy import sparse
7 |
8 | from tokenkit import align, utils
9 |
10 |
11 | class TokenizerAlignerCollator:
12 | def __init__(
13 | self,
14 | tokenizer_original,
15 | tokenizer_new,
16 | max_teacher_length,
17 | max_student_length,
18 | use_chat_template=False,
19 | chat_template_mode="direct_encode",
20 | expand_input_ids_dict=None,
21 | loss_mask_mode=None,
22 | tokenizer_pair_data_path=None,
23 | tokenizer_pair_bias_threshold=0.0,
24 | require_bias_matrices=False,
25 | ):
26 | self.tokenizer_original = tokenizer_original
27 | self.tokenizer_original_vocab = tokenizer_original.get_vocab()
28 | self.tokenizer_new = tokenizer_new
29 | self.max_teacher_length = max_teacher_length
30 | self.max_student_length = max_student_length
31 | self.use_chat_template = use_chat_template
32 | self.chat_template_mode = chat_template_mode
33 | self.expand_input_ids_dict = expand_input_ids_dict
34 |
35 | if loss_mask_mode is None:
36 | loss_mask_string = None
37 | elif loss_mask_mode == "dolly":
38 | loss_mask_string = "### Response:\n"
39 | elif loss_mask_mode == "openmath2":
40 | loss_mask_string = "<|start_header_id|>assistant<|end_header_id|>\n\n"
41 | else:
42 | raise ValueError(f"Unknown loss mask mode: {loss_mask_mode}")
43 |
44 | self.loss_mask_tokens_original = (
45 | self.tokenizer_original.encode(loss_mask_string, add_special_tokens=False)
46 | if loss_mask_string is not None
47 | else None
48 | )
49 | self.loss_mask_tokens_new = (
50 | self.tokenizer_new.encode(loss_mask_string, add_special_tokens=False)
51 | if loss_mask_string is not None
52 | else None
53 | )
54 |
55 | bias1_matrix_path = Path(tokenizer_pair_data_path) / "bias1_matrix.npz"
56 | bias2_matrix_path = Path(tokenizer_pair_data_path) / "bias2_matrix.npz"
57 | teacher_token_counts_path = (
58 | Path(tokenizer_pair_data_path) / "teacher_counts.json"
59 | )
60 | student_token_counts_path = (
61 | Path(tokenizer_pair_data_path) / "student_counts.json"
62 | )
63 |
64 | if bias1_matrix_path.exists():
65 | self.tokenizer_pair_bias1_matrix = sparse.load_npz(
66 | bias1_matrix_path
67 | ).todok()
68 | else:
69 | self.tokenizer_pair_bias1_matrix = None
70 | if bias2_matrix_path.exists():
71 | self.tokenizer_pair_bias2_matrix = sparse.load_npz(
72 | bias2_matrix_path
73 | ).todok()
74 | else:
75 | self.tokenizer_pair_bias2_matrix = None
76 | if teacher_token_counts_path.exists():
77 | self.teacher_token_probs = utils.compute_unigram_probabilities(
78 | tokenizer_original, json.load(open(teacher_token_counts_path))
79 | )
80 | else:
81 | self.teacher_token_probs = None
82 | if student_token_counts_path.exists():
83 | self.student_token_probs = utils.compute_unigram_probabilities(
84 | tokenizer_new, json.load(open(student_token_counts_path))
85 | )
86 | else:
87 | self.student_token_probs = None
88 |
89 | if require_bias_matrices and (
90 | self.tokenizer_pair_bias1_matrix is None
91 | or self.tokenizer_pair_bias2_matrix is None
92 | ):
93 | raise ValueError(
94 | "Bias matrices are required but not found in the given path."
95 | )
96 |
97 | self.tokenizer_pair_bias_threshold = tokenizer_pair_bias_threshold
98 |
99 | self.prefix_map_original = self._compute_prefix_map(tokenizer_original)
100 | self.prefix_map_new = self._compute_prefix_map(tokenizer_new)
101 |
102 | def _compute_loss_mask(self, input_ids, attention_mask, loss_mask_tokens):
103 | loss_mask = attention_mask.astype(bool)
104 | if loss_mask_tokens is not None:
105 | for i in range(len(input_ids)):
106 | for j in range(len(input_ids[i])):
107 | if input_ids[i][j] != loss_mask_tokens[0]:
108 | continue
109 |
110 | if (
111 | input_ids[i][j : j + len(loss_mask_tokens)].tolist()
112 | == loss_mask_tokens
113 | ):
114 | loss_mask[i, : j + len(loss_mask_tokens)] = False
115 |
116 | return loss_mask
117 |
118 | def _compute_prefix_map(self, tokenizer):
119 | prefix_map = {}
120 |
121 | for token in tokenizer.get_vocab().keys():
122 | for i in range(1, len(token) + 1):
123 | if token[:i] in prefix_map:
124 | prefix_map[token[:i]].append(token)
125 | else:
126 | prefix_map[token[:i]] = [token]
127 |
128 | return prefix_map
129 |
130 | def _encode_with_chat_template(self, texts, tokenizer, max_length):
131 | input_ids = np.full(
132 | (len(texts), max_length), fill_value=tokenizer.pad_token_id, dtype=np.int32
133 | )
134 | attention_mask = np.zeros((len(texts), max_length), dtype=np.int32)
135 |
136 | for i in range(len(texts)):
137 | current_input_ids, _ = utils.encode_prompt(
138 | utils.preprocess_prompt(texts[i], self.chat_template_mode),
139 | tokenizer,
140 | max_length=max_length,
141 | )
142 | input_ids[i, : len(current_input_ids)] = current_input_ids
143 | attention_mask[i, : len(current_input_ids)] = 1
144 |
145 | return {
146 | "input_ids": input_ids,
147 | "attention_mask": attention_mask,
148 | }
149 |
150 | def __call__(self, examples):
151 | # batched internally
152 | examples = examples[0]
153 |
154 | texts = examples["text"]
155 |
156 | if self.use_chat_template:
157 | encoding_original = self._encode_with_chat_template(
158 | texts,
159 | tokenizer=self.tokenizer_original,
160 | max_length=self.max_teacher_length,
161 | )
162 | encoding_new = self._encode_with_chat_template(
163 | texts,
164 | tokenizer=self.tokenizer_new,
165 | max_length=self.max_student_length,
166 | )
167 | else:
168 | encoding_original = self.tokenizer_original(
169 | texts,
170 | max_length=self.max_teacher_length,
171 | padding="max_length",
172 | truncation=True,
173 | return_tensors="np",
174 | )
175 | encoding_new = self.tokenizer_new(
176 | texts,
177 | max_length=self.max_student_length,
178 | padding="max_length",
179 | truncation=True,
180 | return_tensors="np",
181 | )
182 |
183 | input_ids_original = encoding_original["input_ids"]
184 | attention_mask_original = encoding_original["attention_mask"]
185 | input_ids_new = encoding_new["input_ids"]
186 | attention_mask_new = encoding_new["attention_mask"]
187 |
188 | (
189 | alignment_matrix_a,
190 | alignment_matrix_b,
191 | ) = align.get_unconstrained_alignments(
192 | input_ids_original,
193 | input_ids_new,
194 | attention_mask_original,
195 | attention_mask_new,
196 | tokenizer_teacher=self.tokenizer_original,
197 | tokenizer_student=self.tokenizer_new,
198 | )
199 |
200 | (
201 | alignment_matrix_a_space,
202 | alignment_matrix_b_space,
203 | ) = align.get_space_alignments(
204 | input_ids_original,
205 | input_ids_new,
206 | attention_mask_original,
207 | attention_mask_new,
208 | tokenizer_teacher=self.tokenizer_original,
209 | tokenizer_student=self.tokenizer_new,
210 | )
211 |
212 | if (
213 | self.tokenizer_pair_bias1_matrix is not None
214 | and self.tokenizer_pair_bias2_matrix is not None
215 | ):
216 | (
217 | alignment_matrix_a_unbiased,
218 | alignment_matrix_b_unbiased,
219 | ) = align.get_unbiased_alignments(
220 | input_ids_original,
221 | input_ids_new,
222 | attention_mask_original,
223 | attention_mask_new,
224 | tokenizer_teacher=self.tokenizer_original,
225 | tokenizer_student=self.tokenizer_new,
226 | pair_data=(
227 | self.tokenizer_pair_bias1_matrix,
228 | self.tokenizer_pair_bias2_matrix,
229 | self.teacher_token_probs,
230 | self.student_token_probs,
231 | ),
232 | bias_threshold=self.tokenizer_pair_bias_threshold,
233 | )
234 | else:
235 | alignment_matrix_a_unbiased = np.full_like(
236 | alignment_matrix_a, fill_value=np.nan
237 | )
238 | alignment_matrix_b_unbiased = np.full_like(
239 | alignment_matrix_b, fill_value=np.nan
240 | )
241 |
242 | occuring_tokens_mask_original = np.zeros(
243 | len(self.tokenizer_original), dtype=bool
244 | )
245 | occuring_tokens_mask_new = np.zeros(len(self.tokenizer_new), dtype=bool)
246 |
247 | occuring_tokens_mask_original[input_ids_original.flatten()] = True
248 | occuring_tokens_mask_new[input_ids_new.flatten()] = True
249 |
250 | loss_mask_original = self._compute_loss_mask(
251 | input_ids_original, attention_mask_original, self.loss_mask_tokens_original
252 | )
253 | loss_mask_new = self._compute_loss_mask(
254 | input_ids_new, attention_mask_new, self.loss_mask_tokens_new
255 | )
256 |
257 | batch = {
258 | "input_ids_new": input_ids_new,
259 | "attention_mask_new": attention_mask_new,
260 | "occuring_tokens_mask_new": occuring_tokens_mask_new,
261 | "input_ids_original": input_ids_original,
262 | "attention_mask_original": attention_mask_original,
263 | "occuring_tokens_mask_original": occuring_tokens_mask_original,
264 | "alignment_matrix_a_unconstrained": alignment_matrix_a,
265 | "alignment_matrix_b_unconstrained": alignment_matrix_b,
266 | "alignment_matrix_a_space": alignment_matrix_a_space,
267 | "alignment_matrix_b_space": alignment_matrix_b_space,
268 | "alignment_matrix_a_unbiased": alignment_matrix_a_unbiased,
269 | "alignment_matrix_b_unbiased": alignment_matrix_b_unbiased,
270 | "loss_mask_original": loss_mask_original,
271 | "loss_mask_new": loss_mask_new,
272 | }
273 |
274 | if self.expand_input_ids_dict is not None:
275 | batch["expanded_input_ids_new"] = utils.np_expand_input_ids(
276 | input_ids_new,
277 | self.expand_input_ids_dict,
278 | )
279 |
280 | return batch
281 |
282 | def get_batch_pspecs(self):
283 | batch_specs = {
284 | "input_ids_new": P("data", None),
285 | "attention_mask_new": P("data", None),
286 | "occuring_tokens_mask_new": P(),
287 | "input_ids_original": P("data", None),
288 | "attention_mask_original": P("data", None),
289 | "occuring_tokens_mask_original": P(),
290 | "alignment_matrix_a_unconstrained": P("data", None),
291 | "alignment_matrix_b_unconstrained": P("data", None),
292 | "alignment_matrix_a_space": P("data", None),
293 | "alignment_matrix_b_space": P("data", None),
294 | "alignment_matrix_a_unbiased": P("data", None),
295 | "alignment_matrix_b_unbiased": P("data", None),
296 | "loss_mask_original": P("data", None),
297 | "loss_mask_new": P("data", None),
298 | }
299 |
300 | if self.expand_input_ids_dict is not None:
301 | batch_specs["expanded_input_ids_new"] = P("data", None)
302 |
303 | return batch_specs
304 |
--------------------------------------------------------------------------------
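A rough wiring sketch for the collator, illustrating only the expected arguments and batch format; the tokenizer names and the pair-data directory are placeholders (in the training scripts the tokenizers come from the project's own loading path), and pad tokens are set explicitly because the collator pads to fixed lengths:

from transformers import AutoTokenizer

from tokenkit.training.collators import TokenizerAlignerCollator

teacher_tok = AutoTokenizer.from_pretrained("teacher-org/teacher-model")  # placeholder name
student_tok = AutoTokenizer.from_pretrained("student-org/student-model")  # placeholder name
for tok in (teacher_tok, student_tok):
    if tok.pad_token is None:
        tok.pad_token = tok.eos_token

collator = TokenizerAlignerCollator(
    teacher_tok,
    student_tok,
    max_teacher_length=512,
    max_student_length=512,
    tokenizer_pair_data_path="outputs/tokenizer_pair_data",  # placeholder; bias matrices are optional
)

batch = collator([{"text": ["An example document."]}])
# batch holds input ids/masks for both tokenizers, the alignment matrices, and loss masks

--------------------------------------------------------------------------------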
/tokenkit/training/lr.py:
--------------------------------------------------------------------------------
1 | import optax
2 |
3 |
4 | def linear_warmup_linear_decay_with_linear_prefix(
5 | lr, steps, warmup_steps, prefix_steps=0, prefix_lr=0.0
6 | ):
7 | """Returns a linear warmup, linear decay learning rate function."""
8 |
9 | prefix_fn = optax.linear_schedule(
10 | init_value=0.0, end_value=prefix_lr, transition_steps=prefix_steps
11 | )
12 |
13 | warmup_fn = optax.linear_schedule(
14 | init_value=0.0,
15 | end_value=lr,
16 | transition_steps=warmup_steps,
17 | )
18 |
19 | decay_fn = optax.linear_schedule(
20 | init_value=lr,
21 | end_value=0.0,
22 | transition_steps=steps - warmup_steps - prefix_steps,
23 | )
24 |
25 | fn = optax.join_schedules(
26 | schedules=[prefix_fn, warmup_fn, decay_fn],
27 | boundaries=[
28 | prefix_steps,
29 | prefix_steps + warmup_steps,
30 | ],
31 | )
32 |
33 | return fn
34 |
35 |
36 | def linear_warmup_cosine_decay_with_linear_prefix(
37 | lr, steps, warmup_steps, alpha=0.0, prefix_steps=0, prefix_lr=0.0
38 | ):
39 | """Returns a linear warmup, cosine decay learning rate function."""
40 |
41 | prefix_fn = optax.linear_schedule(
42 | init_value=0.0, end_value=prefix_lr, transition_steps=prefix_steps
43 | )
44 |
45 | warmup_fn = optax.linear_schedule(
46 | init_value=0.0,
47 | end_value=lr,
48 | transition_steps=warmup_steps,
49 | )
50 |
51 | decay_fn = optax.cosine_decay_schedule(
52 | init_value=lr,
53 | decay_steps=steps - warmup_steps - prefix_steps,
54 | alpha=alpha,
55 | )
56 |
57 | fn = optax.join_schedules(
58 | schedules=[prefix_fn, warmup_fn, decay_fn],
59 | boundaries=[
60 | prefix_steps,
61 | prefix_steps + warmup_steps,
62 | ],
63 | )
64 |
65 | return fn
66 |
--------------------------------------------------------------------------------
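A quick check of the prefix -> warmup -> decay shape of the schedules above, with made-up hyperparameters:

from tokenkit.training import lr

schedule = lr.linear_warmup_cosine_decay_with_linear_prefix(
    lr=3e-5, steps=1000, warmup_steps=100, prefix_steps=50, prefix_lr=1e-6
)

for step in (0, 25, 50, 100, 150, 500, 1000):
    print(step, float(schedule(step)))  # prefix ramp, then warmup to lr, then cosine decay to 0

--------------------------------------------------------------------------------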
/tokenkit/training/multitask.py:
--------------------------------------------------------------------------------
1 | import jax
2 | import jax.numpy as jnp
3 | from typing import Any
4 |
5 |
6 | def pcgrad(task_grads: Any) -> Any:
7 | """
8 | Implements PCGrad (Project Conflicting Gradients) in JAX.
9 |
10 | Args:
11 | task_grads: A pytree containing gradients where the first dimension of each
12 | array represents different tasks. Shape: (n_tasks, ...)
13 |
14 | Returns:
15 | Modified gradients after applying PCGrad, with the same structure as the input.
16 | """
17 | # Get the structure of the input pytree and convert to flat representation
18 | flat_task_grads, treedef = jax.tree.flatten(task_grads)
19 |
20 | # Check if all elements are arrays and get number of tasks
21 | n_tasks = None
22 | for grad in flat_task_grads:
23 | if grad is not None:
24 | n_tasks = grad.shape[0]
25 | break
26 |
27 | if n_tasks is None:
28 | # No valid gradients found
29 | return task_grads
30 |
31 | # Create a new flat list to store the modified gradients
32 | modified_flat_grads = []
33 |
34 | # Process each element in the flat list
35 | for grad in flat_task_grads:
36 | if grad is None:
37 | modified_flat_grads.append(None)
38 | continue
39 |
40 | # Ensure the first dimension matches the number of tasks
41 | assert (
42 | grad.shape[0] == n_tasks
43 | ), f"Expected first dimension to be {n_tasks}, got {grad.shape[0]}"
44 |
45 | # Extract shape for reshaping later
46 | original_shape = grad.shape
47 | # Reshape to (n_tasks, -1) for easier processing
48 | reshaped_grad = grad.reshape(n_tasks, -1)
49 |
50 | # Initialize modified gradients with copies of the original
51 | modified_grad = reshaped_grad.copy()
52 |
53 | # Apply PCGrad for each task
54 | for i in range(n_tasks):
55 | # Project task i's gradient onto normal plane of other tasks' gradients if they conflict
56 | grad_i = modified_grad[i]
57 |
58 | for j in range(n_tasks):
59 | if i == j:
60 | continue
61 |
62 | grad_j = reshaped_grad[j]
63 |
64 | # Calculate dot product to check for conflict
65 | dot_product = jnp.sum(grad_i * grad_j)
66 |
67 | # If dot product is negative, project gradient
68 | def project(g_i, g_j, dot):
69 | g_j_norm_squared = jnp.sum(g_j * g_j)
70 | # Avoid division by zero
71 | safe_norm_squared = jnp.maximum(g_j_norm_squared, 1e-8)
72 | # Project g_i onto normal plane of g_j
73 | return g_i - jnp.minimum(0.0, dot) * g_j / safe_norm_squared
74 |
75 | modified_grad = modified_grad.at[i].set(
76 | project(modified_grad[i], grad_j, dot_product)
77 | )
78 |
79 | # Reshape back to original shape and add to the modified list
80 | modified_flat_grads.append(modified_grad.reshape(original_shape))
81 |
82 | # Reconstruct the pytree with modified gradients
83 | return jax.tree.unflatten(treedef, modified_flat_grads)
84 |
85 |
86 | def gradmag(task_grads: Any, epsilon: float = 1e-8) -> Any:
87 | """
88 | Normalizes gradients of all tasks to have the same magnitude.
89 |
90 | Args:
91 | task_grads: A pytree containing gradients where the first dimension of each
92 | array represents different tasks. Shape: (n_tasks, ...)
93 | epsilon: Small constant to avoid division by zero
94 |
95 | Returns:
96 | Modified gradients after normalization, with the same structure as the input.
97 | """
98 | global_grad_norms = compute_global_grad_norm(task_grads) + epsilon
99 | return jax.tree.map(
100 | lambda grad: grad
101 | / jnp.reshape(global_grad_norms, (-1,) + (1,) * (grad.ndim - 1)),
102 | task_grads,
103 | )
104 |
105 |
106 | def gradclip(task_grads: Any, max_norm: float) -> Any:
107 | """
108 |     Clips each task's gradient so that its global norm is at most max_norm.
109 |
110 | Args:
111 | task_grads: A pytree containing gradients where the first dimension of each
112 | array represents different tasks. Shape: (n_tasks, ...)
113 | max_norm: Maximum allowed norm for gradients
114 |
115 | Returns:
116 | Modified gradients after clipping, with the same structure as the input.
117 | """
118 | global_grad_norms = compute_global_grad_norm(task_grads)
119 | denominators = jnp.maximum(global_grad_norms, max_norm)
120 | return jax.tree.map(lambda grad: grad / jnp.reshape(denominators, (-1,) + (1,) * (grad.ndim - 1)) * max_norm, task_grads)
121 |
122 |
123 | def compute_global_grad_norm(task_grads: Any) -> jnp.ndarray:
124 | """
125 | Computes the global gradient norm for a pytree of task gradients.
126 |
127 | Args:
128 | task_grads: A pytree containing gradients where the first dimension of each
129 | array represents different tasks. Shape: (n_tasks, ...)
130 |
131 | Returns:
132 |         Per-task global gradient norms, a vector of shape (n_tasks,).
133 | """
134 | global_grad_norms = jnp.sqrt(
135 | jax.tree.reduce(
136 | lambda x, y: x + y,
137 | jax.tree.map(
138 | lambda x: jnp.square(x).reshape(x.shape[0], -1).sum(axis=1),
139 | task_grads,
140 | ),
141 | )
142 | )
143 | return global_grad_norms
144 |
145 |
146 | def compute_inv_global_grad_norm(task_grads: Any, epsilon: float = 1e-8) -> jnp.ndarray:
147 | """
148 | Computes the inverse of the global gradient norm for a pytree of task gradients.
149 |
150 | Args:
151 |         task_grads: A pytree containing gradients where the first dimension of each array represents different tasks. Shape: (n_tasks, ...)
152 | """
153 |
154 | return 1 / (compute_global_grad_norm(task_grads) + epsilon)
--------------------------------------------------------------------------------
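A tiny sanity check of `pcgrad` on two deliberately conflicting task gradients, stacked along the leading task dimension as every function in this module expects:

import jax.numpy as jnp

from tokenkit.training import multitask

# gradients of tasks 0 and 1 for a single parameter, shape (n_tasks, ...) = (2, 2)
task_grads = {"w": jnp.array([[1.0, 0.0], [-1.0, 1.0]])}

projected = multitask.pcgrad(task_grads)                         # conflicting components projected out
per_task_norms = multitask.compute_global_grad_norm(task_grads)  # shape (2,), one norm per task
print(projected["w"])
print(per_task_norms)

--------------------------------------------------------------------------------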
/tokenkit/training/opt.py:
--------------------------------------------------------------------------------
1 | from pprint import pprint
2 |
3 | import optax
4 | import regex as re
5 | from flax import traverse_util
6 |
7 |
8 | def decay_mask_fn(params):
9 | flat_params = traverse_util.flatten_dict(params)
10 |
11 | # TODO: this is somewhat hacky but (almost) always accurate
12 | flat_mask = {
13 | path: not (
14 | path[-1] in {"bias", "b"}
15 | or any(
16 | ln_name in ".".join(path[-2:])
17 | for ln_name in {"layernorm", "layer_norm", "ln"}
18 | )
19 | )
20 | for path in flat_params
21 | }
22 | return traverse_util.unflatten_dict(flat_mask)
23 |
24 |
25 | def get_optimizer(train_mask, learning_rate_fn, **optimizer_kwargs):
26 | transforms = []
27 |
28 | opt_type = optimizer_kwargs.pop("type")
29 | grad_acc_steps = optimizer_kwargs.pop("grad_acc_steps", None)
30 | max_grad_norm = optimizer_kwargs.pop("max_grad_norm", None)
31 |
32 | if opt_type == "adamw":
33 | opt_fn = optax.adamw
34 | else:
35 | raise ValueError(f"Unknown optimizer type: {opt_type}")
36 |
37 | flat_param_group_labels = {}
38 | flat_train_mask = traverse_util.flatten_dict(train_mask)
39 | param_groups = optimizer_kwargs.pop("param_groups", [])
40 | optimizers = {
41 | "_default": opt_fn(
42 | mask=decay_mask_fn, learning_rate=learning_rate_fn, **optimizer_kwargs
43 | ),
44 | "_do_not_train": optax.set_to_zero(),
45 | }
46 |
47 | for group in param_groups:
48 | for key, trainable in flat_train_mask.items():
49 | if not trainable:
50 | flat_param_group_labels[key] = "_do_not_train"
51 | elif re.match(group["pattern"], ".".join(key)):
52 | flat_param_group_labels[key] = group["pattern"]
53 | if group["pattern"] not in optimizers:
54 | optimizers[group["pattern"]] = opt_fn(
55 | mask=decay_mask_fn,
56 |                         # bind lr_scale now; a plain closure over `group` would use the last group's value
57 |                         learning_rate=lambda count, scale=group["lr_scale"]: learning_rate_fn(count) * scale,
58 | **optimizer_kwargs,
59 | )
60 |
61 | for key in optimizers.keys():
62 | if key == "_do_not_train":
63 | continue
64 |
65 | if max_grad_norm is not None:
66 | optimizers[key] = optax.chain(
67 | optax.clip_by_global_norm(max_grad_norm),
68 | optimizers[key],
69 | )
70 |
71 | if grad_acc_steps is not None and grad_acc_steps > 1:
72 | optimizers[key] = optax.MultiSteps(opt=optimizers[key], every_k_schedule=grad_acc_steps)
73 |
74 | for key, trainable in flat_train_mask.items():
75 | if key not in flat_param_group_labels:
76 | if trainable:
77 | flat_param_group_labels[key] = "_default"
78 | else:
79 | flat_param_group_labels[key] = "_do_not_train"
80 |
81 | print("Special parameter groups:")
82 | pprint(
83 | {
84 | k: v
85 | for k, v in flat_param_group_labels.items()
86 | if v not in {"_default", "_do_not_train"}
87 | }
88 | )
89 |
90 | return optax.multi_transform(
91 | optimizers,
92 | traverse_util.unflatten_dict(flat_param_group_labels),
93 | )
94 |
--------------------------------------------------------------------------------
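A sketch that wires the schedule and the optimizer factory together; the parameter tree, the train mask, and the single `param_groups` entry are invented for illustration:

import jax.numpy as jnp

from tokenkit.training import lr, opt

params = {
    "hypernet": {"dense": {"kernel": jnp.zeros((4, 4)), "bias": jnp.zeros((4,))}},
    "model": {"embed_tokens": {"embedding": jnp.zeros((8, 4))}},
}
train_mask = {
    "hypernet": {"dense": {"kernel": True, "bias": True}},
    "model": {"embed_tokens": {"embedding": False}},  # frozen via the "_do_not_train" group
}

schedule = lr.linear_warmup_linear_decay_with_linear_prefix(lr=1e-4, steps=1000, warmup_steps=100)
optimizer = opt.get_optimizer(
    train_mask,
    schedule,
    type="adamw",
    weight_decay=0.01,
    max_grad_norm=1.0,
    param_groups=[{"pattern": ".*hypernet.*bias", "lr_scale": 0.1}],  # scaled-down LR for these params
)
opt_state = optimizer.init(params)

--------------------------------------------------------------------------------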