├── requirements.txt ├── safetensor_meta_dump.sh ├── LICENSE ├── README.md └── transplant_vocab.py /requirements.txt: -------------------------------------------------------------------------------- 1 | torch>=2.0.0 2 | transformers>=4.30.0 3 | tqdm>=4.64.0 4 | numpy>=1.20.0 5 | -------------------------------------------------------------------------------- /safetensor_meta_dump.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Specify the file 4 | FILE="$1" 5 | 6 | # Extract the first 8 bytes and convert them to a decimal integer 7 | HEADER_LENGTH=$(dd "if=$FILE" bs=1 count=8 2>/dev/null | od -An -vtu8) 8 | 9 | # Extract the metadata, starting from the 9th byte 10 | dd "if=$FILE" bs=1 skip=8 "count=$HEADER_LENGTH" 2>/dev/null | jq 11 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 
47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Vocab Transplantation Tool 2 | 3 | Transplants vocabulary between language models, enabling the creation of draft models for efficient speculative decoding **WITHOUT** retraining. 
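A transplanted draft model ends up sharing the target model's tokenizer and vocabulary, so it can be plugged directly into frameworks that require matching vocabularies for speculative decoding. As a minimal, illustrative sketch (not part of this tool), using the Hugging Face `transformers` assisted-generation API with the placeholder paths from the usage examples below:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

# The target model and the transplanted draft model now share one tokenizer/vocabulary
tokenizer = AutoTokenizer.from_pretrained("/path/to/target_model")
target = AutoModelForCausalLM.from_pretrained("/path/to/target_model", device_map="auto")
draft = AutoModelForCausalLM.from_pretrained("/path/to/output_model", device_map="auto")

inputs = tokenizer("Explain speculative decoding in one sentence.", return_tensors="pt").to(target.device)
# assistant_model enables assisted (speculative) decoding using the draft model
outputs = target.generate(**inputs, assistant_model=draft, max_new_tokens=64)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```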
4 | 5 | ## Table of Contents 6 | 7 | - [Features](#features) 8 | - [Installation](#installation) 9 | - [Usage](#usage) 10 | - [Basic Command](#basic-command) 11 | - [Options](#options) 12 | - [Examples](#examples) 13 | - [Token Mapping](#token-mapping) 14 | - [Automatic Special Token Mapping](#automatic-special-token-mapping) 15 | - [Manual Token Mapping Overrides](#manual-token-mapping-overrides) 16 | - [Layer Trimming](#layer-trimming) 17 | - [Hidden and Intermediate Size Trimming](#hidden-and-intermediate-size-trimming) 18 | - [Handling Models Without BOS Tokens](#handling-models-without-bos-tokens) 19 | - [Design Rationale](#design-rationale) 20 | - [Input Embeddings (Final Token Strategy)](#input-embeddings-final-token-strategy) 21 | - [Output Head (First Token Strategy)](#output-head-first-token-strategy) 22 | - [Mathematical Considerations](#mathematical-considerations) 23 | - [Credit](#credit) 24 | - [License](#license) 25 | 26 | This tool allows you to combine the transformer architecture and weights from a donor model with the tokenizer of a target model, creating a hybrid model that can serve as a draft model in speculative decoding pipelines. By matching token-to-token or multi-token mappings between vocabularies, it intelligently transfers embeddings while preserving semantic relationships. This approach eliminates the need for expensive retraining or distillation procedures typically required for creating compatible draft models, making it an efficient solution for accelerating inference through speculative decoding techniques. 27 | 28 | ## Features 29 | 30 | - Preserve the donor model's intelligence/performance. 31 | - Adapt donor model to use the target model's tokenizer. 32 | - Automatic special tokens mapping between models. 33 | - User-specified manual token mapping overrides. 34 | - (**only useful for fine-tuning**) Models can be "trimmed" by removing a range of layers. 35 | - (**only useful for fine-tuning**) Models can be "trimmed" by reducing the hidden state dimension. 36 | - (**only useful for fine-tuning**) Models can be "trimmed" by reducing the MLP's intermediate dimension. 
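Conceptually, the transplant walks the target vocabulary, decodes each token ID back to text, re-encodes that text with the donor tokenizer, and builds the new `embed_tokens` and `lm_head` rows from the donor's rows. The sketch below illustrates the idea only; names such as `front_loaded_mean` and `transplant_sketch` are illustrative, and the real logic (special-token overrides, fallbacks, tied-embedding handling, etc.) lives in `transplant_vocab.py`:

```python
import torch

def front_loaded_mean(rows: torch.Tensor, decay: float = 0.5) -> torch.Tensor:
    """Weighted mean with weights 1, decay, decay^2, ... (0 = first row only, 1 = uniform mean)."""
    if decay == 0.0 or rows.shape[0] == 1:
        return rows[0]
    weights = decay ** torch.arange(rows.shape[0], dtype=rows.dtype, device=rows.device)
    return (weights[:, None] * rows).sum(dim=0) / weights.sum()

def transplant_sketch(target_tok, donor_tok, donor_embed, donor_head, vocab_size, decay=0.5):
    hidden_size = donor_embed.shape[1]
    new_embed = torch.zeros(vocab_size, hidden_size, dtype=donor_embed.dtype)
    new_head = torch.zeros(vocab_size, hidden_size, dtype=donor_head.dtype)
    for idx in range(vocab_size):
        text = target_tok.decode([idx], decode_special_tokens=True)
        ids = donor_tok.encode(text, add_special_tokens=False) or [donor_tok.eos_token_id]
        new_embed[idx] = donor_embed[ids[-1]]                      # input side: final donor token
        new_head[idx] = front_loaded_mean(donor_head[ids], decay)  # output side: front-loaded mean
    return new_embed, new_head
```

See the "Design Rationale" section for why the input embeddings use the final donor token while the output head uses a front-loaded weighted mean.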
37 | 38 | ## Installation 39 | 40 | ```bash 41 | pip install -r requirements.txt 42 | ``` 43 | 44 | **Requirements:** 45 | - Python 3.8+ 46 | - PyTorch 2.0+ 47 | - Transformers 4.30+ 48 | - tqdm 49 | 50 | ## Usage 51 | 52 | ### Basic Command 53 | ```bash 54 | python transplant_vocab.py /path/to/donor_model /path/to/target_model /path/to/output_model 55 | ``` 56 | 57 | ### Options 58 | 59 | | Flag | Description | 60 | |------|-------------| 61 | | `--override TARGET DONOR` | Override target token with donor sequence (can be used multiple times) | 62 | | `--weighting-decay-factor [0-1]` | Decay factor for multi-token mappings: 0=first token only, 0.5=decreasing weights, 1=uniform mean | 63 | | `--trim-layers START-END` | Trim out a range of layers from the model: start-end (inclusive) | 64 | | `--trim-hidden-size SIZE` | Trim the hidden state dimension (and the number of heads as a result) | 65 | | `--trim-intermediate-size SIZE` | Trim the intermediate dimension of the MLP blocks | 66 | | `--patch-missing-bos` | Patch `tokenizer_config.json` for models like `Qwen` which don't use any `` token | 67 | | `--use-cpu-only` | Use CPU instead of GPU (and with `float32` precision) | 68 | | `--trust-remote-code` | Allow custom code execution when loading models with non-standard architectures | 69 | | `--keep-tied-word-embeddings` | Keep tied word embeddings from donor model; default is to untie and create separate lm_head | 70 | | `--overwrite` | Replace existing output directory | 71 | | `--verbose` | Show detailed token mapping output | 72 | 73 | Note on tied vs untied embeddings: 74 | - By default, the tool unties embeddings and creates a separate `lm_head`. Pass `--keep-tied-word-embeddings` to preserve the donor model's tied embeddings. 75 | - If the donor uses untied embeddings, the output remains untied regardless of this flag. 76 | 77 | ### Examples 78 | 79 | #### Transplant `DeepSeek-R1` tokenizer into `Qwen2.5-0.5B-Instruct` model and output as new model called `DeepSeek-R1-DRAFT-0.5B`: 80 | 81 | ```bash 82 | python transplant_vocab.py ./Qwen2.5-0.5B-Instruct ./DeepSeek-R1 ./DeepSeek-R1-DRAFT-0.5B 83 | ``` 84 | 85 | #### With manual token mapping overrides for chat templates (see below for detailed explanation): 86 | 87 | ```bash 88 | python transplant_vocab.py ./Qwen2.5-0.5B-Instruct ./DeepSeek-R1 ./DeepSeek-R1-DRAFT-0.5B \ 89 | --override "<|User|>" "<|im_start|>user\\n" \ 90 | --override "<|Assistant|>" "<|im_start|>assistant\\n" \ 91 | --override ... 92 | ``` 93 | 94 | #### Use only first token for `lm_head` averaging (maximum front-loading): 95 | 96 | ```bash 97 | python transplant_vocab.py ./Qwen2.5-0.5B-Instruct ./DeepSeek-R1 ./DeepSeek-R1-DRAFT-0.5B-first --weighting-decay-factor 0.0 98 | ``` 99 | 100 | #### Use uniform mean for `lm_head` averaging (ie: equal weight to all tokens): 101 | 102 | ```bash 103 | python transplant_vocab.py ./Qwen2.5-0.5B-Instruct ./DeepSeek-R1 ./DeepSeek-R1-DRAFT-0.5B-mean --weighting-decay-factor 1.0 104 | ``` 105 | 106 | #### Use decreasing weights (eg: 1, 0.5, 0.25, etc.) 
for `lm_head` averaging (default behaviour): 107 | 108 | ```bash 109 | python transplant_vocab.py ./Qwen2.5-0.5B-Instruct ./DeepSeek-R1 ./DeepSeek-R1-DRAFT-0.5B-decay --weighting-decay-factor 0.5 110 | ``` 111 | 112 | #### Trim out intermediate layers to create a smaller model that we can use for further fine-tuning: 113 | 114 | ```bash 115 | python transplant_vocab.py ./Qwen2.5-0.5B-Instruct ./DeepSeek-R1 ./DeepSeek-R1-DRAFT-0.5B-trimmed --trim-layers 14-21 116 | ``` 117 | 118 | this leaves a model with 16 layers in total: 14 taken from the start and 2 from the end: 119 | 120 | ``` 121 | Trimming layers 14 through 21 (inclusive): 122 | - Old layer count : 24 (layers 0-23) 123 | - New layer count : 16 (keeping layers 0-13 and 22-23) 124 | - Removed 96 tensors from state_dict 125 | - Renamed 192 layer tensors to new indices 126 | - Updated model configuration: num_hidden_layers = 16 127 | ``` 128 | 129 | #### Reduce the hidden size (and the number of attention heads): 130 | 131 | ```bash 132 | python transplant_vocab.py ./Qwen2.5-0.5B-Instruct ./DeepSeek-R1 ./DeepSeek-R1-DRAFT-0.5B-small --trim-hidden-size 768 133 | ``` 134 | 135 | to create a smaller model with fewer attention heads: 136 | 137 | ``` 138 | Trimming hidden size from 896 to 768: 139 | - Old hidden size : 896 140 | - New hidden size : 768 141 | - Updated model configuration: hidden_size = 768 142 | - Updated model configuration: num_attention_heads = 12 143 | - Trimmed 243 tensors in state_dict 144 | ``` 145 | 146 | #### Reduce the intermediate size: 147 | 148 | ```bash 149 | python transplant_vocab.py ./Qwen2.5-0.5B-Instruct ./DeepSeek-R1 ./DeepSeek-R1-DRAFT-0.5B-small --trim-intermediate-size 3072 150 | ``` 151 | 152 | to create a smaller model: 153 | 154 | ``` 155 | Trimming intermediate size from 4864 to 3072: 156 | - Old intermediate size : 4864 157 | - New intermediate size : 3072 158 | - Updated model configuration: intermediate_size = 3072 159 | - Trimmed 72 tensors in state_dict 160 | ``` 161 | 162 | ### Token Mapping 163 | 164 | #### Automatic Special Token Mapping 165 | 166 | The tool automatically attempts to map three special tokens between models: 167 | - `bos_token_id` (Beginning of Sequence) 168 | - `eos_token_id` (End of Sequence) 169 | - `pad_token_id` (Padding) 170 | 171 | These mappings ensure that the transplanted model correctly handles sequence boundaries and padding, which is critical for proper functioning. 172 | 173 | **NOTE**: Some models reuse `eos_token_id` as `pad_token_id`, so the automatic `pad_token_id` mapping is skipped in these cases, eg: 174 | 175 | ``` 176 | Processing 3 automatic token overrides: 177 | ✔ 'bos_token_id' : 0 '<|begin▁of▁sentence|>' → [151643] '<|endoftext|>' 178 | ✔ 'eos_token_id' : 1 '<|end▁of▁sentence|>' → [151645] '<|im_end|>' 179 | ✘ 'pad_token_id' : 1 is already mapped to [151645] 180 | ``` 181 | 182 | #### Manual Token Mapping Overrides 183 | 184 | For more complex models, especially those with chat templates or special tokens for specific tasks, you can manually map tokens using the `--override` option: 185 | 186 | ```bash 187 | python transplant_vocab.py ./donor_model ./target_model ./output_model --override "" "" 188 | ``` 189 | 190 | You can specify multiple overrides by repeating the `--override` option. This is particularly useful for: 191 | - Chat template tokens (user/assistant markers) 192 | - Special task tokens (FIM, tool calls, etc.)
193 | - Any token that needs specific handling 194 | 195 | #### Example: Mapping Chat and Special Tokens 196 | 197 | Here we manually map (target) `DeepSeek-V3` tokens to (donor) `Qwen2.5` tokens/sequences: 198 | 199 | ```bash 200 | python transplant_vocab.py ./Qwen2.5-0.5B-Instruct ./DeepSeek-V3 ./DeepSeek-V3-DRAFT-0.5B \ 201 | --override "<|▁pad▁|>" "<|endoftext|>" \ 202 | --override "<|fim▁hole|>" "<|fim_middle|>" \ 203 | --override "<|fim▁begin|>" "<|fim_prefix|>" \ 204 | --override "<|fim▁end|>" "<|fim_suffix|>" \ 205 | --override "<|User|>" "<|im_start|>user\n" \ 206 | --override "<|Assistant|>" "<|im_start|>assistant\n" \ 207 | --override "<|EOT|>" "<|endoftext|>" \ 208 | --override "<|tool▁calls▁begin|>" "" \ 209 | --override "<|tool▁call▁begin|>" "" \ 210 | --override "<|tool▁outputs▁begin|>" "" \ 211 | --override "<|tool▁output▁begin|>" "" \ 212 | --override "<|tool▁calls▁end|>" "" \ 213 | --override "<|tool▁call▁end|>" "" \ 214 | --override "<|tool▁outputs▁end|>" "" \ 215 | --override "<|tool▁output▁end|>" "" \ 216 | --override "<|tool▁sep|>" "" 217 | ``` 218 | 219 | which should output something like this: 220 | 221 | ``` 222 | Processing 16 manual token overrides: 223 | ✔ 2 : '<|▁pad▁|>' → [151643] '<|endoftext|>' 224 | ✔ 128800 : '<|fim▁hole|>' → [151660] '<|fim_middle|>' 225 | ✔ 128801 : '<|fim▁begin|>' → [151659] '<|fim_prefix|>' 226 | ✔ 128802 : '<|fim▁end|>' → [151661] '<|fim_suffix|>' 227 | ✔ 128803 : '<|User|>' → [151644, 872, 198] '<|im_start|>user\n' 228 | ✔ 128804 : '<|Assistant|>' → [151644, 77091, 198] '<|im_start|>assistant\n' 229 | ✔ 128805 : '<|EOT|>' → [151643] '<|endoftext|>' 230 | ✔ 128806 : '<|tool▁calls▁begin|>' → [151657] '' 231 | ✔ 128808 : '<|tool▁call▁begin|>' → [151657] '' 232 | ✔ 128810 : '<|tool▁outputs▁begin|>' → [151657] '' 233 | ✔ 128812 : '<|tool▁output▁begin|>' → [151657] '' 234 | ✔ 128807 : '<|tool▁calls▁end|>' → [151658] '' 235 | ✔ 128809 : '<|tool▁call▁end|>' → [151658] '' 236 | ✔ 128811 : '<|tool▁outputs▁end|>' → [151658] '' 237 | ✔ 128813 : '<|tool▁output▁end|>' → [151658] '' 238 | ✔ 128814 : '<|tool▁sep|>' → [151658] '' 239 | ``` 240 | 241 | **NOTE**: I suggest you use the `--verbose` flag to verify your mappings are working as expected, eg: 242 | 243 | ``` 244 | Transplanting tokens: 245 | - 0 : '<|begin▁of▁sentence|>' → [151643] 246 | - 1 : '<|end▁of▁sentence|>' → [151645] 247 | - 2 : '<|▁pad▁|>' → [151643] 248 | - 128800 : '<|fim▁hole|>' → [151660] 249 | - 128801 : '<|fim▁begin|>' → [151659] 250 | - 128802 : '<|fim▁end|>' → [151661] 251 | - 128803 : '<|User|>' → [151644, 872, 198] 252 | - 128804 : '<|Assistant|>' → [151644, 77091, 198] 253 | - 128805 : '<|EOT|>' → [151643] 254 | - 128806 : '<|tool▁calls▁begin|>' → [151657] 255 | - 128807 : '<|tool▁calls▁end|>' → [151658] 256 | - 128808 : '<|tool▁call▁begin|>' → [151657] 257 | - 128809 : '<|tool▁call▁end|>' → [151658] 258 | - 128810 : '<|tool▁outputs▁begin|>' → [151657] 259 | - 128811 : '<|tool▁outputs▁end|>' → [151658] 260 | - 128812 : '<|tool▁output▁begin|>' → [151657] 261 | - 128813 : '<|tool▁output▁end|>' → [151658] 262 | - 128814 : '<|tool▁sep|>' → [151658] 263 | ``` 264 | 265 | and also to explore other possible manual overrides... 266 | 267 | ## Layer Trimming 268 | 269 | The `--trim-layers` option allows you to remove a range of intermediate layers from the model. This can be useful for several reasons: 270 | 271 | ### Benefits of Layer Trimming 272 | 273 | - **Faster Inference**: Smaller models with fewer layers require less computation, resulting in faster inference times. 
This is particularly valuable for speculative decoding where draft model speed is critical. 274 | - **Reduced Memory Usage**: Trimmed models consume less GPU memory, allowing deployment on more modest hardware. 275 | - **More Efficient Fine-tuning**: Smaller models are faster and cheaper to fine-tune. 276 | 277 | ### Important Considerations 278 | 279 | - **Performance Impact**: Unlike vocabulary transplantation (which preserves most of the model's capabilities), layer trimming significantly impacts model performance. The resulting model will require fine-tuning to recover acceptable performance. 280 | - **Layer Selection Strategy**: Research such as ["The Unreasonable Ineffectiveness of the Deeper Layers"](https://arxiv.org/abs/2403.17887) suggests that not all layers contribute equally to model performance. 281 | - **Recommended Approach**: When trimming layers, it's generally advisable to: 282 | - Keep the very early layers (which transform embedding-space to hidden/latent representations) 283 | - Keep the early-intermediate layers (which store/transform useful semantic information) 284 | - Keep the final 1-2 layers (which transform hidden/latent representations to logit-space) 285 | - Remove the later-intermediate layers (which often contain redundant information) 286 | 287 | ### Example Trimming Strategy 288 | 289 | For a 24-layer model like `Qwen2.5-0.5B-Instruct`, you might use `--trim-layers 14-21`: 290 | 291 | This keeps layers 0-13 (the first 14 layers) and layers 22-23 (the final 2 layers), resulting in a 16-layer model that preserves both the input processing and output generation capabilities while removing 8 of the (later) intermediate layers. The resulting model will be approximately 2/3 the size and should run approximately 33% faster for speculative decoding. 292 | 293 | **IMPORTANT**: After layer trimming, you ***must fine-tune*** the model to recover performance. 294 | 295 | ## Hidden and Intermediate Size Trimming 296 | 297 | The `--trim-hidden-size` and `--trim-intermediate-size` options allow you to reduce the hidden state dimension (and, as a result, the number of attention heads) and the intermediate dimension of the MLP blocks throughout the model. This can be useful for the same reasons as layer trimming. 298 | 299 | ### Important Considerations 300 | 301 | - **Performance Impact**: Like layer trimming, reducing hidden/intermediate dimensions will impact model performance. The resulting model will require fine-tuning to recover acceptable performance. 302 | - **Recommended Approach**: When trimming hidden/intermediate size: 303 | - Ensure the new hidden state dimension leaves a whole number of heads (eg: with a head dimension of 896/14 = 64, trimming to 512 leaves 512/64 = 8 heads) 304 | - Consider the ratio between the original and new sizes (eg: reducing from 4864 to 2432 is a 50% reduction in intermediate size) 305 | - Preferably, choose sizes that are a multiple of 128 for compatibility with hardware acceleration (eg: 2432/128 = 19 and 512/128 = 4, etc) 306 | 307 | **IMPORTANT**: After trimming hidden/intermediate size, you ***must fine-tune*** the model to recover performance. 308 | 309 | ## Handling Models Without BOS Tokens 310 | 311 | Some language models, like Qwen, don't use beginning-of-sequence (BOS) tokens in their tokenization strategy. This can cause issues when transplanting vocabularies between models with different tokenization approaches. 312 | 313 | The `--patch-missing-bos` option addresses this by: 314 | 315 | 1. 
Modifying the `tokenizer_config.json` file to set `add_bos_token` to `false` 316 | 2. Removing any references to `bos_token` from the Jinja chat template 317 | 3. Ensuring the model doesn't automatically add BOS tokens where they aren't expected 318 | 319 | ### When to Use This Option 320 | 321 | Use `--patch-missing-bos` when: 322 | - The donor model doesn't use BOS tokens but the target model does 323 | - You notice unexpected tokens at the beginning of generated sequences 324 | - You're working with models like Qwen that have specific tokenization strategies 325 | 326 | ### Example 327 | 328 | ```bash 329 | python transplant_vocab.py ./Qwen2.5-0.5B-Instruct ./DeepSeek-R1 ./DeepSeek-R1-DRAFT-0.5B --patch-missing-bos 330 | ``` 331 | 332 | This will patch the tokenizer configuration in the output model to handle the absence of BOS tokens properly. 333 | 334 | ## Design Rationale 335 | 336 | ### Input Embeddings (Final Token Strategy) 337 | 338 | When a target token maps to multiple donor tokens: 339 | 340 | ```text 341 | Target: [X] → Donor: [A, B, C] 342 | ``` 343 | 344 | We use **C** (**ONLY** the final token) because: 345 | 346 | 1. Transformers process tokens sequentially, with transformer blocks "looking backward". 347 | 2. It's the transformer blocks that integrate context from previous tokens. 348 | 3. Taking the mean of all tokens doesn't align with how transformers process sequences. 349 | 4. Using the final token aligns with how a transformer uses the last token of a sequence to predict the next token. 350 | 351 | ### Output Head (First Token Strategy) 352 | 353 | When a target token maps to multiple donor tokens: 354 | 355 | ```text 356 | Target: [Y] → Donor: [D, E, F] 357 | ``` 358 | 359 | We use **D** (**MOSTLY** the first token) because: 360 | 361 | 1. The model decides on word endings in subsequent autoregressive passes. 362 | 2. When handling multi-token mappings, we have three options: 363 | - Use only the first token (`--weighting-decay-factor 0.0`) 364 | - Use a uniform mean of all tokens (`--weighting-decay-factor 1.0`) 365 | - Use exponentially decreasing weights (`--weighting-decay-factor 0.5`) 366 | 3. We choose `0.5` as the default because: 367 | - Using only the first token inflates the total probability mass when many target tokens share the same leading donor token. 368 | - Using a uniform mean gives too much weight to the trailing tokens of a mapping. 369 | 370 | When preserving tied embeddings (only if `--keep-tied-word-embeddings` is passed and the donor is tied), this averaging strategy is applied to `embed_tokens` directly, and no separate `lm_head.weight` is saved; on load, the head is tied to the embeddings. 371 | 372 | ### Mathematical Considerations 373 | 374 | - Using means or scaling logits isn't mathematically ideal for preserving the output probability distribution. 375 | - Proper token splitting would require subtracting `log(n)` from the logit of each token in an n-token group. 376 | - In the absence of an `lm_head.bias`, our approach provides the most practical solution. 377 | - The `--weighting-decay-factor` parameter controls how we handle cases where one target token maps to multiple donor tokens. The default value of `0.5` balances preserving the importance of the first token against still incorporating information from all tokens in the sequence. Values closer to `0.0` or `1.0` may provide better initialisations for fine-tuning but could produce less reliable outputs if used without any further fine-tuning. 378 | 379 | ## Credit 380 | 381 | Original concept by [turboderp](https://huggingface.co/turboderp). 
Based on [original implementation](https://huggingface.co/turboderp/Qwama-0.5B-Instruct/blob/main/vocab_transplant.py). 382 | 383 | ## License 384 | 385 | Apache 2.0 License - See [LICENSE](LICENSE) for details 386 | -------------------------------------------------------------------------------- /transplant_vocab.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """ 3 | Vocab Transplantation Tool 4 | 5 | All credit to turboderp for the original idea: 6 | 7 | https://huggingface.co/turboderp/Qwama-0.5B-Instruct/blob/main/vocab_transplant.py 8 | """ 9 | 10 | from tqdm import tqdm 11 | from typing import Tuple, Dict 12 | import argparse 13 | import json 14 | import os 15 | import re 16 | import shutil 17 | import sys 18 | import torch 19 | 20 | from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig 21 | 22 | import torch.nn as nn 23 | 24 | def parse_arguments() -> argparse.Namespace: 25 | """Parse and validate command line arguments""" 26 | parser = argparse.ArgumentParser( 27 | description="Transplant token embeddings between language models", 28 | formatter_class=argparse.ArgumentDefaultsHelpFormatter 29 | ) 30 | parser.add_argument("donor_dir", help="Path to donor model directory") 31 | parser.add_argument("target_dir", help="Path to target model directory") 32 | parser.add_argument("output_dir", help="Path to output model directory") 33 | parser.add_argument("--override", nargs=2, action="append", default=[], 34 | help="Override target token with donor sequence (can be used multiple times)") 35 | parser.add_argument("--weighting-decay-factor", type=float, default=0.5, 36 | help="Decay factor [0-1] for multi-token mappings: " 37 | "0=first token only, 0.5=decreasing weights, 1=uniform mean") 38 | parser.add_argument("--trim-layers", type=str, 39 | help="Trim out a range of layers from the model: start-end (inclusive)") 40 | parser.add_argument("--trim-hidden-size", type=int, 41 | help="Trim the hidden state dimension (and the number of heads as a result)") 42 | parser.add_argument("--trim-intermediate-size", type=int, 43 | help="Trim the intermediate dimension of the MLP blocks") 44 | parser.add_argument("--use-cpu-only", action="store_true", 45 | help="Use CPU only for model loading and processing in float32") 46 | parser.add_argument("--trust-remote-code", action="store_true", 47 | help="Allow custom code execution when loading models with non-standard architectures") 48 | parser.add_argument("--patch-missing-bos", action="store_true", 49 | help="Patch `tokenizer_config.json` for models like `Qwen` which don't use any `` token") 50 | parser.add_argument("--keep-tied-word-embeddings", action="store_true", 51 | help="Keep tied word embeddings from donor model; default is to untie and create separate lm_head") 52 | parser.add_argument("--overwrite", action="store_true", 53 | help="Overwrite output directory if it exists") 54 | parser.add_argument("--verbose", action="store_true", 55 | help="Show detailed token mapping output") 56 | 57 | args = parser.parse_args() 58 | 59 | if not (0.0 <= args.weighting_decay_factor <= 1.0): 60 | sys.exit(f"Error: --weighting-decay-factor must be between 0.0 and 1.0 (got {args.weighting_decay_factor})") 61 | 62 | if args.trim_layers: 63 | try: 64 | start, end = map(int, args.trim_layers.split('-')) 65 | if start < 0 or end < start: 66 | sys.exit(f"Error: Invalid layer range: {args.trim_layers}. 
Format should be start-end with start ≥ 0 and end ≥ start") 67 | except ValueError: 68 | sys.exit(f"Error: Invalid layer range format: {args.trim_layers}. Format should be start-end (e.g., 3-8)") 69 | 70 | return args 71 | 72 | def validate_directories(args: argparse.Namespace) -> None: 73 | """Validate input/output directory structure and permissions""" 74 | for dir_type, dir_path in [("donor", args.donor_dir), ("target", args.target_dir)]: 75 | if not os.path.isdir(dir_path): 76 | sys.exit(f"Error: {dir_type} directory does not exist: {dir_path}") 77 | if not os.access(dir_path, os.R_OK): 78 | sys.exit(f"Error: No read permissions for {dir_type} directory: {dir_path}") 79 | 80 | if os.path.exists(args.output_dir): 81 | if args.overwrite: 82 | if not os.access(args.output_dir, os.W_OK): 83 | sys.exit(f"Error: No write permissions for output directory: {args.output_dir}") 84 | shutil.rmtree(args.output_dir) 85 | else: 86 | sys.exit(f"Error: Output directory exists (use --overwrite to replace): {args.output_dir}") 87 | 88 | try: 89 | os.makedirs(args.output_dir, exist_ok=True) 90 | except OSError as e: 91 | sys.exit(f"Error: Failed to create output directory: {e}") 92 | 93 | def load_model_config(path: str) -> dict: 94 | """Load model configuration""" 95 | config_path = os.path.join(path, "config.json") 96 | if not os.path.exists(config_path): 97 | sys.exit(f"Error: Config file not found at {config_path}") 98 | 99 | try: 100 | print(f"Loading config from '{path}'... ", end="") 101 | with open(config_path, "r", encoding="utf-8") as f: 102 | config = json.load(f) 103 | print("Done.") 104 | except Exception as e: 105 | sys.exit(f"Error loading config from {config_path}: {e}") 106 | 107 | return config 108 | 109 | def load_tokenizer(path: str, trust_remote_code=False) -> AutoTokenizer: 110 | """Load tokenizer with error handling""" 111 | try: 112 | print(f"Loading tokenizer from '{path}'... ", end="") 113 | tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=trust_remote_code) 114 | print("Done.") 115 | return tokenizer 116 | except Exception as e: 117 | sys.exit(f"Failed to load tokenizer: {e}") 118 | 119 | def load_model(path: str, trust_remote_code=False, use_cpu_only=False) -> AutoModelForCausalLM: 120 | """Load model with error handling""" 121 | try: 122 | print(f"Loading model from '{path}'... ", end="") 123 | if use_cpu_only: 124 | model = AutoModelForCausalLM.from_pretrained( 125 | path, 126 | trust_remote_code=trust_remote_code, 127 | device_map='cpu', 128 | torch_dtype='float32', 129 | ) 130 | else: 131 | model = AutoModelForCausalLM.from_pretrained( 132 | path, 133 | trust_remote_code=trust_remote_code, 134 | device_map='auto', 135 | torch_dtype='auto', 136 | ) 137 | print("Done.") 138 | return model 139 | except Exception as e: 140 | sys.exit(f"Failed to load model: {e}") 141 | 142 | def count_model_parameters(model) -> Tuple[int, int, int]: 143 | """ 144 | Count the total number of parameters in a model. 
145 | 146 | Args: 147 | model: The model to analyze 148 | 149 | Returns: 150 | Tuple of (total parameters, embedding and LM head parameters only, 151 | parameters excluding embeddings and LM head) 152 | """ 153 | total_params = 0 154 | embedding_params = 0 155 | non_embedding_params = 0 156 | 157 | for name, param in model.named_parameters(): 158 | param_count = param.numel() 159 | total_params += param_count 160 | 161 | # Separate embedding/LM head parameters from the rest 162 | if any(skip_name in name for skip_name in ['embed_tokens', 'lm_head']): 163 | embedding_params += param_count 164 | else: 165 | non_embedding_params += param_count 166 | 167 | return total_params, embedding_params, non_embedding_params 168 | 169 | def has_config_value(config, key: str) -> bool: 170 | """Check if a key exists in a model configuration, checking both flat and nested structures""" 171 | if isinstance(config, dict): 172 | return key in config or ("text_config" in config and key in config["text_config"]) 173 | return hasattr(config, key) or (hasattr(config, "text_config") and hasattr(config.text_config, key)) 174 | 175 | def get_config_value(config, key: str, default=...): 176 | """Get a value from a model configuration, handling both flat and nested structures""" 177 | assert default is not ... or has_config_value(config, key), f"{key} not found in model configuration" 178 | if not has_config_value(config, key): 179 | return default 180 | if isinstance(config, dict): 181 | return config[key] if key in config else config["text_config"][key] 182 | return getattr(config, key) if hasattr(config, key) else getattr(config.text_config, key) 183 | 184 | def set_config_value(config, key: str, value): 185 | """Set a value in a model configuration, updating both flat and nested structures if present""" 186 | assert has_config_value(config, key), f"{key} not found in model configuration" 187 | if isinstance(config, dict): 188 | if key in config: 189 | config[key] = value 190 | if "text_config" in config and key in config["text_config"]: 191 | config["text_config"][key] = value 192 | else: 193 | if hasattr(config, key): 194 | setattr(config, key, value) 195 | if hasattr(config, "text_config") and hasattr(config.text_config, key): 196 | setattr(config.text_config, key, value) 197 | 198 | def process_automatic_token_overrides(target_tokenizer, donor_tokenizer, target_config, donor_config, existing_map=None): 199 | """ 200 | Process automatic token overrides for special tokens. 
201 | 202 | Args: 203 | target_tokenizer: The target tokenizer 204 | donor_tokenizer: The donor tokenizer 205 | target_config: The target model configuration 206 | donor_config: The donor model configuration 207 | 208 | Returns: 209 | Dictionary mapping target token IDs to donor token IDs 210 | """ 211 | override_map = existing_map.copy() if existing_map else {} 212 | 213 | special_tokens = ['bos_token_id', 'eos_token_id', 'pad_token_id'] 214 | print(f"\nProcessing {len(special_tokens)} automatic token overrides:") 215 | 216 | for token_attr in special_tokens: 217 | # First try to get from the tokenizer 218 | target_token_id = getattr(target_tokenizer, token_attr) 219 | donor_token_id = getattr(donor_tokenizer, token_attr) 220 | 221 | # Try to get from config if not found in tokenizer 222 | if target_token_id is None and has_config_value(target_config, token_attr): 223 | target_token_id = get_config_value(target_config, token_attr) 224 | if donor_token_id is None and has_config_value(donor_config, token_attr): 225 | donor_token_id = get_config_value(donor_config, token_attr) 226 | 227 | # Try to perform the automatic match 228 | if target_token_id is not None: 229 | if donor_token_id is not None: 230 | if target_token_id not in override_map: 231 | target_token = target_tokenizer.convert_ids_to_tokens(target_token_id) 232 | donor_token = donor_tokenizer.convert_ids_to_tokens(donor_token_id) 233 | override_map[target_token_id] = torch.tensor([donor_token_id], dtype=torch.long) 234 | print(f"✔ {repr(token_attr)} : {target_token_id} {repr(target_token)} → [{donor_token_id}] {repr(donor_token)}") 235 | else: 236 | print(f"✘ {repr(token_attr)} : {target_token_id} is already mapped to [{override_map[target_token_id].item()}]") 237 | else: 238 | print(f"✘ {repr(token_attr)} : Not found for donor model") 239 | else: 240 | print(f"✘ {repr(token_attr)} : Not found for target model") 241 | 242 | return override_map 243 | 244 | def process_manual_token_overrides(target_tokenizer, donor_tokenizer, manual_overrides, existing_map=None): 245 | """ 246 | Process manual token overrides specified by the user. 247 | 248 | Args: 249 | target_tokenizer: The target tokenizer 250 | donor_tokenizer: The donor tokenizer 251 | manual_overrides: List of (target_token, donor_tokens) pairs 252 | existing_map: Existing override map to update (optional) 253 | 254 | Returns: 255 | Updated dictionary mapping target token IDs to donor token IDs 256 | """ 257 | override_map = existing_map.copy() if existing_map else {} 258 | 259 | if not manual_overrides: 260 | return override_map 261 | 262 | print(f"\nProcessing {len(manual_overrides)} manual token overrides:") 263 | for target_token, donor_tokens in manual_overrides: 264 | # Encode target token and verify it's a single token 265 | target_id = target_tokenizer.encode(target_token, add_special_tokens=False) 266 | assert len(target_id) == 1, f"Target token '{target_token}' maps to {len(target_id)} tokens. Must be a 1 token." 
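        # The target side must be a single token so it maps to exactly one row in the new embedding/head matrices; unwrap the ID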
267 | target_id = target_id[0] 268 | 269 | # Replace newline characters with the actual byte representation of a newline (0x0A) 270 | # NOTE: If you don't do this then it will get wrongly encoded as the "\\n" string literal 271 | if "\\n" in donor_tokens: 272 | donor_tokens = donor_tokens.replace("\\n", chr(10)) 273 | 274 | # Get the IDs from the token string 275 | encoded = donor_tokenizer.encode(donor_tokens, add_special_tokens=False, return_tensors="pt").flatten() 276 | assert encoded.numel() != 0, f"Donor token '{donor_tokens}' for target ID {target_id} encodes to 0 tokens." 277 | 278 | # Store the donor token IDs 279 | override_map[target_id] = encoded 280 | 281 | print(f"✔ {target_id:6d} : {repr(target_token)} → {encoded.tolist()} {repr(donor_tokens)}") 282 | 283 | return override_map 284 | 285 | def compute_front_loaded_mean(v, weighting_decay_factor=0.5): 286 | """ 287 | Computes the "front-loaded" exponentially-weighted mean with a weighting decay factor. 288 | 289 | Parameters: 290 | - v: torch tensor with values 291 | - weighting_decay_factor: parameter in [0, 1] controlling how quickly weights decay for subsequent vectors 292 | 293 | Returns: 294 | - Weighted average tensor 295 | 296 | Special cases: 297 | - weighting_decay_factor=0 : Returns only the first vector (maximum front-loading) 298 | - weighting_decay_factor=0.5 : Applies weights 1, 0.5, 0.25, 0.125, ... (earlier vectors have more influence) 299 | - weighting_decay_factor=1 : Returns the uniform arithmetic mean (no front-loading) 300 | """ 301 | # Assert that weighting_decay_factor is in the valid range [0, 1] 302 | assert 0 <= weighting_decay_factor <= 1, f"weighting_decay_factor must be in range [0, 1], got {weighting_decay_factor}" 303 | 304 | n = v.shape[0] 305 | 306 | if n == 1 or weighting_decay_factor == 0: 307 | return v[0] # First (or only) vector only 308 | elif weighting_decay_factor == 1: 309 | return torch.mean(v, dim=0) # Arithmetic mean 310 | else: 311 | # Compute the weights using geometric progression 312 | decay_powers = torch.tensor([weighting_decay_factor ** i for i in range(n)], device=v.device) 313 | decay_powers = decay_powers.view(-1, *([1] * (v.dim() - 1))) 314 | weighted_sum = torch.sum(decay_powers * v, dim=0) 315 | denominator = torch.sum(decay_powers) 316 | return weighted_sum / denominator 317 | 318 | def transplant_tokens(model, donor_config, target_tokenizer, donor_tokenizer, 319 | override_map, vocab_size, used_vocab_size, 320 | weighting_decay_factor, untie_word_embeddings=False, verbose=False): 321 | """ 322 | Transplant token embeddings from donor model to target vocabulary. 
323 | 324 | Args: 325 | model: The donor model 326 | donor_config: The donor model configuration 327 | target_tokenizer: The target tokenizer 328 | donor_tokenizer: The donor tokenizer 329 | override_map: Dictionary mapping target token IDs to donor token IDs 330 | vocab_size: Total size of the target vocabulary 331 | used_vocab_size: Number of tokens actually used in the target vocabulary 332 | weighting_decay_factor: Factor for weighting multi-token mappings 333 | untie_word_embeddings: If True, force an untied lm_head even if donor used tied embeddings 334 | verbose: Whether to print detailed mapping information 335 | 336 | Returns: 337 | Tuple of (new state dict, embedding statistics) 338 | """ 339 | # Get donor hidden size 340 | donor_hidden_size = get_config_value(donor_config, "hidden_size") 341 | 342 | # Get donor embeddings 343 | donor_embed_tokens = model.model.embed_tokens.weight 344 | donor_tied = get_config_value(donor_config, "tie_word_embeddings", False) 345 | use_separate_head = (not donor_tied) or untie_word_embeddings 346 | 347 | if use_separate_head: 348 | if donor_tied: 349 | print("\nNOTE: Using an \"untied\" copy of 'embed_tokens.weight' as new 'lm_head.weight' tensor...\n") 350 | donor_lm_head = donor_embed_tokens 351 | else: 352 | print("\nNOTE: Using actual 'lm_head.weight' tensor as donor not configured with 'tie_word_embeddings...\n") 353 | donor_lm_head = model.lm_head.weight 354 | else: 355 | print("\nNOTE: Preserving tied word embeddings; applying front-loaded mean to 'embed_tokens'; no separate 'lm_head.weight' will be saved...\n") 356 | donor_lm_head = None 357 | 358 | # Initialize new embeddings 359 | new_embed_tokens = torch.zeros( 360 | (vocab_size, donor_hidden_size), 361 | dtype=donor_embed_tokens.dtype, 362 | device=donor_embed_tokens.device 363 | ) 364 | # Only allocate new_lm_head if we will save a separate head 365 | if use_separate_head: 366 | new_lm_head = torch.zeros( 367 | (vocab_size, donor_hidden_size), 368 | dtype=donor_lm_head.dtype, 369 | device=donor_lm_head.device 370 | ) 371 | 372 | # Track mapping statistics 373 | mapping_counts = {} 374 | lm_head_copy_count = 0 375 | lm_head_mean_count = 0 376 | 377 | # Configure progress display 378 | iterator = range(used_vocab_size) 379 | if verbose: 380 | print("Transplanting tokens:") 381 | else: 382 | iterator = tqdm(iterator, desc="Transplanting tokens", unit="token") 383 | 384 | for idx in iterator: 385 | decoded = target_tokenizer.decode([idx], decode_special_tokens=True) 386 | if idx in override_map: 387 | encoded = override_map[idx] 388 | else: 389 | encoded = donor_tokenizer.encode(decoded, add_special_tokens=False, return_tensors="pt").flatten() 390 | # Fall back to the actual token string (preserves exact representation) 391 | if encoded.numel() == 0 and hasattr(target_tokenizer, 'convert_ids_to_tokens'): 392 | print(f"WARNING: Token {idx} {repr(decoded)} → empty, trying convert_ids_to_tokens()") 393 | token_str = target_tokenizer.convert_ids_to_tokens(idx) 394 | encoded = donor_tokenizer.encode(token_str, add_special_tokens=False, return_tensors="pt").flatten() 395 | # Fall back to just using the EOS token as the last resort 396 | if encoded.numel() == 0: 397 | print(f"WARNING: Token {idx} → empty, using EOS [{donor_tokenizer.eos_token_id}] as fallback") 398 | encoded = torch.tensor([donor_tokenizer.eos_token_id], dtype=torch.long) 399 | 400 | if verbose: 401 | print(f"- {idx:6d} : {repr(decoded)} → {encoded.tolist()}") 402 | 403 | # Track mapping types 404 | if encoded.numel() in 
mapping_counts: 405 | mapping_counts[encoded.numel()] += 1 406 | else: 407 | mapping_counts[encoded.numel()] = 1 408 | 409 | if use_separate_head: 410 | # Use only the final token of encoded sequence for input embeddings 411 | new_embed_tokens[idx] = donor_embed_tokens[encoded[-1]] 412 | 413 | # Use a "front-loaded" exponentially-weighted mean for lm_head embeddings 414 | if encoded.numel() == 1: 415 | new_lm_head[idx] = donor_lm_head[encoded[0].item()] 416 | lm_head_copy_count += 1 417 | else: 418 | head_embeddings = donor_lm_head[encoded.flatten()] 419 | new_lm_head[idx] = compute_front_loaded_mean(head_embeddings, weighting_decay_factor) 420 | lm_head_mean_count += 1 421 | else: 422 | # Preserve tying: apply front-loaded mean directly to embed_tokens; mirror to lm_head 423 | if encoded.numel() == 1: 424 | new_embed_tokens[idx] = donor_embed_tokens[encoded[0].item()] 425 | lm_head_copy_count += 1 426 | else: 427 | emb_stack = donor_embed_tokens[encoded.flatten()] 428 | new_embed_tokens[idx] = compute_front_loaded_mean(emb_stack, weighting_decay_factor) 429 | lm_head_mean_count += 1 430 | 431 | # Print statistics 432 | print("\nTransplant mappings:") 433 | for count, occurrences in sorted(mapping_counts.items()): 434 | mapping_label = f"{count} to 1" 435 | print(f"- {mapping_label:<8}: {occurrences} ({(occurrences/used_vocab_size*100):.2g}%)") 436 | 437 | print("\nHead initialized with:") 438 | lm_head_zeroed_count = vocab_size - (lm_head_copy_count + lm_head_mean_count) 439 | print(f"- Copies : {lm_head_copy_count} ({(lm_head_copy_count/vocab_size*100):.2g}%)") 440 | print(f"- Means : {lm_head_mean_count} ({(lm_head_mean_count/vocab_size*100):.2g}%)") 441 | print(f"- Zeros : {lm_head_zeroed_count} ({(lm_head_zeroed_count/vocab_size*100):.2g}%)") 442 | 443 | # Make a copy of the model's state_dict and get the type 444 | new_state_dict = model.state_dict().copy() 445 | old_dtype = model.model.embed_tokens.weight.dtype 446 | 447 | # Update the state_dict with new embeddings 448 | new_state_dict['model.embed_tokens.weight'] = new_embed_tokens.to(dtype=old_dtype) 449 | # Only include a separate lm_head when we're untying or donor was already untied. 450 | # When preserving tying, ensure lm_head is not saved separately. 451 | if use_separate_head: 452 | new_state_dict['lm_head.weight'] = new_lm_head.to(dtype=old_dtype) 453 | else: 454 | if 'lm_head.weight' in new_state_dict: 455 | del new_state_dict['lm_head.weight'] 456 | 457 | return new_state_dict 458 | 459 | def trim_model_layers(model, state_dict, start_layer, end_layer): 460 | """ 461 | Trim out a range of layers from the model and its state_dict. 
462 | 463 | Args: 464 | model: The model to modify 465 | state_dict: The state dictionary to modify 466 | start_layer: The first layer to remove (inclusive) 467 | end_layer: The last layer to remove (inclusive) 468 | 469 | Returns: 470 | Tuple of (modified model, modified state_dict) 471 | """ 472 | # Get the total number of layers in the model 473 | total_layers = get_config_value(model.config, 'num_hidden_layers') 474 | assert start_layer >= 0 and start_layer < end_layer and end_layer < total_layers, f"Invalid layer range: start={start_layer}, end={end_layer}, total={total_layers}" 475 | 476 | print(f"\nTrimming layers {start_layer} through {end_layer} (inclusive): ") 477 | 478 | # Calculate how many layers to keep 479 | new_layer_count = total_layers - (end_layer - start_layer + 1) 480 | print(f"- Old layer count : {total_layers} (layers 0-{total_layers-1})") 481 | print(f"- New layer count : {new_layer_count} (keeping layers 0-{start_layer-1} and {end_layer+1}-{total_layers-1})") 482 | 483 | # Update the model configuration 484 | set_config_value(model.config, 'num_hidden_layers', new_layer_count) 485 | 486 | # Create a mapping from old layer indices to new layer indices 487 | layer_mapping = {} 488 | new_idx = 0 489 | for old_idx in range(total_layers): 490 | if old_idx < start_layer or old_idx > end_layer: 491 | layer_mapping[old_idx] = new_idx 492 | new_idx += 1 493 | 494 | # Create a new state dict with trimmed layers 495 | new_state_dict = {} 496 | removed_keys = [] 497 | renamed_keys = [] 498 | 499 | # First pass: identify all keys to process 500 | all_keys = list(state_dict.keys()) 501 | 502 | layer_patterns = [r'model\.layers\.(\d+)\.', r'transformer\.h\.(\d+)\.', r'model\.decoder\.layers\.(\d+)\.'] 503 | 504 | for key in all_keys: 505 | # Check if this key corresponds to a layer 506 | layer_match = None 507 | for pattern in layer_patterns: 508 | match = re.search(pattern, key) 509 | if match: 510 | layer_idx = int(match.group(1)) 511 | if start_layer <= layer_idx <= end_layer: 512 | # This layer should be removed 513 | removed_keys.append(key) 514 | else: 515 | # This layer is kept, but we need to renumber it 516 | new_layer_idx = layer_mapping[layer_idx] 517 | prefix = match.group(0) # e.g., "model.layers.22." 
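                        # Renumber the kept layer's key, e.g. "model.layers.22." becomes "model.layers.14." when trimming layers 14-21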
518 | new_prefix = prefix.replace(f"{layer_idx}", f"{new_layer_idx}") 519 | new_key = key.replace(prefix, new_prefix) 520 | 521 | # Add to renamed keys list 522 | renamed_keys.append((key, new_key)) 523 | 524 | # Create a new tensor to avoid shared memory issues 525 | new_state_dict[new_key] = state_dict[key].clone() 526 | 527 | # We found a match, so no need to check other patterns 528 | break 529 | 530 | # If no layer match was found, keep the tensor as is 531 | if layer_match is None and key not in removed_keys and not any(key == old_key for old_key, _ in renamed_keys): 532 | new_state_dict[key] = state_dict[key].clone() 533 | 534 | # For models with specific architectures, we might need to modify the layers list 535 | if hasattr(model, 'model') and hasattr(model.model, 'layers'): 536 | # Create a new ModuleList with only the layers we want to keep 537 | new_layers = nn.ModuleList() 538 | for i, layer in enumerate(model.model.layers): 539 | if i < start_layer or i > end_layer: 540 | new_layers.append(layer) 541 | model.model.layers = new_layers 542 | 543 | print(f"- Removed {len(removed_keys)} tensors from state_dict") 544 | print(f"- Renamed {len(renamed_keys)} layer tensors to new indices") 545 | print(f"- Updated model configuration: num_hidden_layers = {new_layer_count}") 546 | 547 | return model, new_state_dict 548 | 549 | def trim_tensors(state_dict, old_size, new_size): 550 | """ 551 | Trim all tensors in a state dictionary that have dimensions matching old_size. 552 | 553 | Args: 554 | state_dict: The state dictionary to modify 555 | old_size: The original dimension size to look for 556 | new_size: The new dimension size to trim to 557 | 558 | Returns: 559 | Tuple of (new state dict, count of trimmed tensors) 560 | """ 561 | new_state_dict = {} 562 | trimmed_count = 0 563 | 564 | # Process each tensor in the state dict 565 | for key, tensor in state_dict.items(): 566 | # Check if this tensor has a dimension matching the size to trim 567 | if any(dim == old_size for dim in tensor.shape): 568 | # Create a new tensor with the appropriate dimensions 569 | new_shape = list(tensor.shape) 570 | for i, dim in enumerate(new_shape): 571 | if dim == old_size: 572 | new_shape[i] = new_size 573 | 574 | # Create a completely new tensor with the new shape 575 | if len(new_shape) == 1: 576 | new_tensor = torch.zeros( 577 | new_shape[0], 578 | dtype=tensor.dtype, 579 | device=tensor.device 580 | ) 581 | # Copy data from the original tensor 582 | new_tensor[:] = tensor[:new_size] 583 | elif len(new_shape) == 2: 584 | new_tensor = torch.zeros( 585 | new_shape[0], new_shape[1], 586 | dtype=tensor.dtype, 587 | device=tensor.device 588 | ) 589 | # Copy data based on which dimensions need trimming 590 | if tensor.shape[0] == old_size and tensor.shape[1] == old_size: 591 | new_tensor[:,:] = tensor[:new_size,:new_size] 592 | elif tensor.shape[0] == old_size: 593 | new_tensor[:,:] = tensor[:new_size,:] 594 | else: 595 | new_tensor[:,:] = tensor[:,:new_size] 596 | else: 597 | # For higher dimensional tensors 598 | new_tensor = torch.zeros( 599 | new_shape, 600 | dtype=tensor.dtype, 601 | device=tensor.device 602 | ) 603 | # Create slices for copying 604 | src_slices = tuple(slice(0, new_shape[i]) if tensor.shape[i] == old_size else slice(None) 605 | for i in range(len(tensor.shape))) 606 | dst_slices = tuple(slice(None) for _ in range(len(tensor.shape))) 607 | new_tensor[dst_slices] = tensor[src_slices] 608 | 609 | new_state_dict[key] = new_tensor 610 | trimmed_count += 1 611 | else: 612 | # Keep tensors 
that don't have matching dimensions unchanged
613 |             new_state_dict[key] = tensor.clone()
614 | 
615 |     return new_state_dict, trimmed_count
616 | 
617 | def trim_model_hidden_size(model, state_dict, new_size):
618 |     """
619 |     Trim the hidden state dimension of the residual stream.
620 | 
621 |     Args:
622 |         model: The model to modify
623 |         state_dict: The state dictionary to modify
624 |         new_size: The new hidden state dimension to use
625 | 
626 |     Returns:
627 |         Tuple of (modified model, modified state_dict)
628 |     """
629 |     old_size = get_config_value(model.config, 'hidden_size')
630 |     old_num_heads = get_config_value(model.config, 'num_attention_heads')
631 |     old_num_kv_heads = get_config_value(model.config, 'num_key_value_heads')
632 |     assert new_size < old_size, f"New hidden size ({new_size}) must be smaller than old ({old_size})"
633 |     assert old_size % old_num_heads == 0, f"Old hidden size ({old_size}) is not divisible by number of heads ({old_num_heads})"
634 | 
635 |     head_dimension = old_size // old_num_heads
636 |     assert new_size % head_dimension == 0, f"New hidden size ({new_size}) is not divisible by head dimension ({head_dimension})"
637 | 
638 |     new_num_heads = new_size // head_dimension
639 |     assert new_num_heads >= old_num_kv_heads, f"New num heads ({new_num_heads}) is less than KV heads ({old_num_kv_heads})"
640 | 
641 |     print(f"\nTrimming hidden size from {old_size} to {new_size}: ")
642 |     print(f"- Old hidden size : {old_size}")
643 |     print(f"- New hidden size : {new_size}")
644 | 
645 |     set_config_value(model.config, 'hidden_size', new_size)
646 |     set_config_value(model.config, 'num_attention_heads', new_num_heads)
647 |     print(f"- Updated model configuration: hidden_size = {new_size}")
648 |     print(f"- Updated model configuration: num_attention_heads = {new_num_heads}")
649 | 
650 |     new_state_dict, trimmed_count = trim_tensors(state_dict, old_size, new_size)
651 |     print(f"- Trimmed {trimmed_count} tensors in state_dict")
652 | 
653 |     return model, new_state_dict
654 | 
655 | def trim_model_intermediate_size(model, state_dict, new_size):
656 |     """
657 |     Trim the intermediate dimension of the MLP blocks.
658 | 
659 |     Args:
660 |         model: The model to modify
661 |         state_dict: The state dictionary to modify
662 |         new_size: The new intermediate dimension to use
663 | 
664 |     Returns:
665 |         Tuple of (modified model, modified state_dict)
666 |     """
667 |     old_size = get_config_value(model.config, 'intermediate_size')
668 |     assert new_size < old_size, f"New intermediate size ({new_size}) must be smaller than old ({old_size})"
669 | 
670 |     print(f"\nTrimming intermediate size from {old_size} to {new_size}: ")
671 |     print(f"- Old intermediate size : {old_size}")
672 |     print(f"- New intermediate size : {new_size}")
673 | 
674 |     set_config_value(model.config, 'intermediate_size', new_size)
675 |     print(f"- Updated model configuration: intermediate_size = {new_size}")
676 | 
677 |     new_state_dict, trimmed_count = trim_tensors(state_dict, old_size, new_size)
678 |     print(f"- Trimmed {trimmed_count} tensors in state_dict")
679 | 
680 |     return model, new_state_dict
681 | 
682 | def patch_tokenizer_config_bos(output_dir):
683 |     """
684 |     Patch the tokenizer configuration to handle models without BOS tokens.
685 | 686 | Args: 687 | output_dir: Path to the output directory containing the tokenizer_config.json 688 | """ 689 | tokenizer_config_path = os.path.join(output_dir, "tokenizer_config.json") 690 | if os.path.exists(tokenizer_config_path): 691 | print(f"\nPatching BOS handling in '{tokenizer_config_path}'") 692 | try: 693 | # Read the file as text without specifying encoding 694 | with open(tokenizer_config_path, "r") as f: 695 | config_text = f.read() 696 | 697 | # Make sure that add_bos_token is set to false 698 | config_text = config_text.replace('"add_bos_token": true', '"add_bos_token": false') 699 | print("- Updated 'add_bos_token' configuration.") 700 | 701 | # Remove any use of bos_token from chat template 702 | # NOTE: We can't (safely) set '"bos_token": null', but it shouldn't matter with these two patches... 703 | config_text = config_text.replace("{{ bos_token }}", "").replace("{{bos_token}}", "") 704 | print("- Removed all references to 'bos_token' from Jinja chat template.") 705 | 706 | # Write the modified text back without specifying encoding 707 | with open(tokenizer_config_path, "w") as f: 708 | f.write(config_text) 709 | except Exception as e: 710 | print(f"Warning: Failed to patch tokenizer configuration: {e}") 711 | 712 | def patch_config_dtype(output_dir): 713 | """ 714 | Patch the config.json file with the correct dtype based on what was actually saved in the safetensors file. 715 | 716 | Args: 717 | output_dir: Path to the output directory containing the config.json and model files 718 | """ 719 | import json 720 | from safetensors import safe_open 721 | 722 | config_path = os.path.join(output_dir, "config.json") 723 | model_path = os.path.join(output_dir, "model.safetensors") 724 | 725 | if not os.path.exists(config_path) or not os.path.exists(model_path): 726 | print(f"Warning: Could not find config.json or model.safetensors in {output_dir}") 727 | return 728 | 729 | print(f"\nPatching 'torch_dtype' in '{config_path}' based on actual saved tensors") 730 | 731 | try: 732 | # Open the safetensors file and check the dtype of a tensor 733 | with safe_open(model_path, framework="pt", device="cpu") as f: 734 | # Get the first tensor's dtype (embed_tokens is a good choice as it's always present) 735 | for key in f.keys(): 736 | if "embed_tokens" in key: 737 | tensor = f.get_tensor(key) 738 | dtype_str = str(tensor.dtype).split('.')[-1] 739 | break 740 | else: 741 | # Fallback to any tensor if embed_tokens not found 742 | key = list(f.keys())[0] 743 | tensor = f.get_tensor(key) 744 | dtype_str = str(tensor.dtype).split('.')[-1] 745 | 746 | # Read the config file 747 | with open(config_path, "r") as f: 748 | config = json.load(f) 749 | 750 | # Update the dtype 751 | config['torch_dtype'] = dtype_str 752 | print(f"- Updated 'torch_dtype' to '{dtype_str}' based on actual tensor dtype") 753 | 754 | # Write the modified config back 755 | with open(config_path, "w") as f: 756 | json.dump(config, f, indent=2) 757 | 758 | except Exception as e: 759 | print(f"Warning: Failed to patch config file: {e}") 760 | 761 | def debug_model_tensors(model, state_dict): 762 | """ 763 | Print detailed information about model parameters and state dict tensors 764 | to help debug shape mismatches. 
765 | 
766 |     Args:
767 |         model: The model to inspect
768 |         state_dict: The state dictionary to inspect
769 |     """
770 |     print("\n=== MODEL PARAMETERS ===")
771 |     for name, param in model.named_parameters():
772 |         print(f"{name}: shape={param.shape}, dtype={param.dtype}")
773 | 
774 |     print("\n=== STATE DICT TENSORS ===")
775 |     for key, tensor in state_dict.items():
776 |         print(f"{key}: shape={tensor.shape}, dtype={tensor.dtype}")
777 | 
778 |     print("\n=== SHAPE MISMATCHES ===")
779 |     mismatches = []
780 |     for name, param in model.named_parameters():
781 |         if name in state_dict and param.shape != state_dict[name].shape:
782 |             mismatches.append((name, param.shape, state_dict[name].shape))
783 | 
784 |     if mismatches:
785 |         print("Found shape mismatches between model parameters and state dict:")
786 |         for name, model_shape, dict_shape in mismatches:
787 |             print(f"- {name}: model={model_shape}, state_dict={dict_shape}")
788 |     else:
789 |         print("No shape mismatches found between model parameters and state dict.")
790 | 
791 |     # Check for tensors in state_dict that don't exist in model
792 |     model_params = {name for name, _ in model.named_parameters()}
793 |     extra_tensors = {key for key in state_dict if key not in model_params}
794 |     if extra_tensors:
795 |         print("\n=== EXTRA TENSORS IN STATE DICT ===")
796 |         for key in extra_tensors:
797 |             print(f"{key}: shape={state_dict[key].shape}")
798 | 
799 |     # Check for parameters in model that don't exist in state_dict
800 |     missing_tensors = {name for name, _ in model.named_parameters() if name not in state_dict}
801 |     if missing_tensors:
802 |         print("\n=== MISSING TENSORS IN STATE DICT ===")
803 |         for name in missing_tensors:
804 |             print(name)
805 | 
806 | def main():
807 |     args = parse_arguments()
808 |     validate_directories(args)
809 | 
810 |     # Load configurations
811 |     donor_config = load_model_config(args.donor_dir)
812 |     target_config = load_model_config(args.target_dir)
813 | 
814 |     # Get configuration values we will need
815 |     target_vocab_size = get_config_value(target_config, "vocab_size")
816 |     donor_vocab_size = get_config_value(donor_config, "vocab_size")
817 |     donor_num_layers = get_config_value(donor_config, 'num_hidden_layers')
818 |     donor_tied_embeddings = get_config_value(donor_config, "tie_word_embeddings", False)
819 |     donor_hidden_size = get_config_value(donor_config, "hidden_size")
820 |     donor_num_heads = get_config_value(donor_config, "num_attention_heads")
821 |     donor_intermediate_size = get_config_value(donor_config, "intermediate_size")
822 | 
823 |     # Load tokenizers
824 |     donor_tokenizer = load_tokenizer(args.donor_dir, args.trust_remote_code)
825 |     target_tokenizer = load_tokenizer(args.target_dir, args.trust_remote_code)
826 | 
827 |     # Load the donor model
828 |     model = load_model(args.donor_dir, args.trust_remote_code, args.use_cpu_only)
829 | 
830 |     # The config file counts all tokens in the vocabulary, but we also need to know how many are actually used for the transplant loop
831 |     if hasattr(target_tokenizer, 'vocab'):
832 |         used_target_vocab_size = max(target_tokenizer.vocab.values()) + 1
833 |         unused_target_vocab_size = target_vocab_size - used_target_vocab_size
834 |     else:
835 |         # For TikToken tokenizers (e.g. Kimi-K2-Instruct), just use the full vocabulary size
836 |         used_target_vocab_size = target_vocab_size
837 |         unused_target_vocab_size = 0
838 | 
839 |     # Count parameters in donor model
840 |     donor_total_params, donor_embedding_params, donor_non_embedding_params = count_model_parameters(model)
841 |     donor_total_params_b = donor_total_params / 1e9
842 |     donor_embedding_params_b = donor_embedding_params / 1e9
843 |     donor_non_embedding_params_b = donor_non_embedding_params / 1e9
844 | 
845 |     print("\nInput model configuration:")
846 |     print(f"- Target vocabulary size : {target_vocab_size} (used = {used_target_vocab_size}, unused = {unused_target_vocab_size})")
847 |     print(f"- Donor vocabulary size : {donor_vocab_size}")
848 |     print(f"- Donor num layers : {donor_num_layers} (tied embeddings = {donor_tied_embeddings})")
849 |     print(f"- Donor hidden size : {donor_hidden_size}")
850 |     print(f"- Donor attention heads : {donor_num_heads}")
851 |     print(f"- Donor intermediate size : {donor_intermediate_size} (ratio = 1:{donor_intermediate_size/donor_hidden_size:.1f})")
852 |     print(f"- Donor total parameters : {donor_total_params} ({donor_total_params_b:.2f}B)")
853 |     print(f"-- Embedding parameters : {donor_embedding_params} ({donor_embedding_params_b:.2f}B)")
854 |     print(f"-- Non-embedding parameters : {donor_non_embedding_params} ({donor_non_embedding_params_b:.2f}B)")
855 | 
856 |     # Build the token override map: automatic overrides first, then any manual overrides on top
857 |     override_map = {}
858 | 
859 |     # Process automatic token overrides
860 |     override_map = process_automatic_token_overrides(target_tokenizer, donor_tokenizer, target_config, donor_config)
861 | 
862 |     # Process manual token overrides
863 |     override_map = process_manual_token_overrides(target_tokenizer, donor_tokenizer, args.override, override_map)
864 | 
865 |     # Transplant tokens from donor model to target vocabulary
866 |     new_state_dict = transplant_tokens(
867 |         model=model,
868 |         donor_config=donor_config,
869 |         target_tokenizer=target_tokenizer,
870 |         donor_tokenizer=donor_tokenizer,
871 |         override_map=override_map,
872 |         vocab_size=target_vocab_size,
873 |         used_vocab_size=used_target_vocab_size,
874 |         weighting_decay_factor=args.weighting_decay_factor,
875 |         untie_word_embeddings=not args.keep_tied_word_embeddings,
876 |         verbose=args.verbose
877 |     )
878 | 
879 |     # Trim layers if requested
880 |     if args.trim_layers:
881 |         start_layer, end_layer = map(int, args.trim_layers.split('-'))
882 |         model, new_state_dict = trim_model_layers(model, new_state_dict, start_layer, end_layer)
883 | 
884 |     # Trim hidden size if requested
885 |     if args.trim_hidden_size:
886 |         model, new_state_dict = trim_model_hidden_size(model, new_state_dict, args.trim_hidden_size)
887 | 
888 |     # Trim intermediate size if requested
889 |     if args.trim_intermediate_size:
890 |         model, new_state_dict = trim_model_intermediate_size(model, new_state_dict, args.trim_intermediate_size)
891 | 
892 |     # Update model architecture
893 |     model.model.embed_tokens.num_embeddings = target_vocab_size
894 |     model.lm_head.out_features = target_vocab_size
895 | 
896 |     # Update model config
897 |     set_config_value(model.config, 'vocab_size', target_vocab_size)
898 |     set_config_value(model.config, 'bos_token_id', target_tokenizer.bos_token_id)
899 |     set_config_value(model.config, 'eos_token_id', target_tokenizer.eos_token_id)
900 | 
901 |     # Update the config's pad_token_id if it exists
902 |     if has_config_value(model.config, 'pad_token_id'):
903 |         if target_tokenizer.pad_token_id is not None:
904 |             set_config_value(model.config, 'pad_token_id', target_tokenizer.pad_token_id)
905 |         else:
906 |             set_config_value(model.config, 'pad_token_id', target_tokenizer.eos_token_id) # Default to EOS if no PAD to copy
907 | 
908 |     # Set the config's tie_word_embeddings based on requested behavior
909 |     if has_config_value(model.config, 'tie_word_embeddings'):
910 |         if not args.keep_tied_word_embeddings:
911 |             set_config_value(model.config, 'tie_word_embeddings', False)
912 |         else:
913 |             # Preserve donor setting
914 |             set_config_value(model.config, 'tie_word_embeddings', get_config_value(donor_config, 'tie_word_embeddings', False))
915 | 
916 |     # Re-initialize the model with the updated configuration
917 |     # NOTE: This seems to be more robust than just altering the model and state dict parameters in place
918 |     model = type(model)(model.config)
919 | 
920 |     output_num_layers = get_config_value(model.config, 'num_hidden_layers')
921 |     output_tied_embeddings = get_config_value(model.config, "tie_word_embeddings", False)
922 |     output_hidden_size = get_config_value(model.config, "hidden_size")
923 |     output_num_heads = get_config_value(model.config, "num_attention_heads")
924 |     output_intermediate_size = get_config_value(model.config, "intermediate_size")
925 | 
926 |     # Count parameters in output model
927 |     output_total_params, output_embedding_params, output_non_embedding_params = count_model_parameters(model)
928 |     output_total_params_b = output_total_params / 1e9
929 |     output_embedding_params_b = output_embedding_params / 1e9
930 |     output_non_embedding_params_b = output_non_embedding_params / 1e9
931 | 
932 |     # Print output model configuration values
933 |     print("\nOutput model configuration:")
934 |     print(f"- Output vocabulary size : {target_vocab_size}")
935 |     print(f"- Output num layers : {output_num_layers} (tied embeddings = {output_tied_embeddings})")
936 |     print(f"- Output hidden size : {output_hidden_size}")
937 |     print(f"- Output attention heads : {output_num_heads}")
938 |     print(f"- Output intermediate size : {output_intermediate_size} (ratio = 1:{output_intermediate_size/output_hidden_size:.1f})")
939 |     print(f"- Output total parameters : {output_total_params} ({output_total_params_b:.2f}B)")
940 |     print(f"-- Embedding parameters : {output_embedding_params} ({output_embedding_params_b:.2f}B)")
941 |     print(f"-- Non-embedding parameters : {output_non_embedding_params} ({output_non_embedding_params_b:.2f}B)")
942 | 
943 |     # debug_model_tensors(model, new_state_dict)
944 | 
945 |     # Save final model and tokenizer
946 |     print(f"\nSaving model and tokenizer to '{args.output_dir}' folder")
947 |     model.save_pretrained(args.output_dir, state_dict=new_state_dict, safe_serialization=True)
948 |     target_tokenizer.save_pretrained(args.output_dir)
949 | 
950 |     # Patch the stupid `torch_dtype` bug in the config file where it always saves as float32 regardless of the actual type...
951 |     patch_config_dtype(args.output_dir)
952 | 
953 |     # Attempt to patch the BOS handling if the donor tokenizer doesn't use BOS tokens
954 |     if args.patch_missing_bos and (getattr(donor_tokenizer, "add_bos_token", False)
955 |                                    or getattr(donor_tokenizer, "bos_token", None) is None):
956 |         patch_tokenizer_config_bos(args.output_dir)
957 | 
958 |     # TODO: Figure out why it causes a segmentation fault on exit???
959 |     print("\nOperation completed successfully (ignore any 'segmentation fault' that follows!!!)")
960 | 
961 | if __name__ == "__main__":
962 |     main()
963 | 
--------------------------------------------------------------------------------