├── .gitignore ├── LICENSE ├── README.md ├── assets └── figure1.png ├── modeling_monet.py └── modeling_monet_vllm.py /.gitignore: -------------------------------------------------------------------------------- 1 | # File created using '.gitignore Generator' for Visual Studio Code: https://bit.ly/vscode-gig 2 | # Created by https://www.toptal.com/developers/gitignore/api/visualstudiocode,linux,python 3 | # Edit at https://www.toptal.com/developers/gitignore?templates=visualstudiocode,linux,python 4 | 5 | ### Linux ### 6 | *~ 7 | 8 | # temporary files which can be created if a process still has a handle open of a deleted file 9 | .fuse_hidden* 10 | 11 | # KDE directory preferences 12 | .directory 13 | 14 | # Linux trash folder which might appear on any partition or disk 15 | .Trash-* 16 | 17 | # .nfs files are created when an open file is removed but is still being accessed 18 | .nfs* 19 | 20 | ### Python ### 21 | # Byte-compiled / optimized / DLL files 22 | __pycache__/ 23 | *.py[cod] 24 | *$py.class 25 | 26 | # C extensions 27 | *.so 28 | 29 | # Distribution / packaging 30 | .Python 31 | build/ 32 | develop-eggs/ 33 | dist/ 34 | downloads/ 35 | eggs/ 36 | .eggs/ 37 | lib/ 38 | lib64/ 39 | parts/ 40 | sdist/ 41 | var/ 42 | wheels/ 43 | share/python-wheels/ 44 | *.egg-info/ 45 | .installed.cfg 46 | *.egg 47 | MANIFEST 48 | 49 | # PyInstaller 50 | # Usually these files are written by a python script from a template 51 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 52 | *.manifest 53 | *.spec 54 | 55 | # Installer logs 56 | pip-log.txt 57 | pip-delete-this-directory.txt 58 | 59 | # Unit test / coverage reports 60 | htmlcov/ 61 | .tox/ 62 | .nox/ 63 | .coverage 64 | .coverage.* 65 | .cache 66 | nosetests.xml 67 | coverage.xml 68 | *.cover 69 | *.py,cover 70 | .hypothesis/ 71 | .pytest_cache/ 72 | cover/ 73 | 74 | # Translations 75 | *.mo 76 | *.pot 77 | 78 | # Django stuff: 79 | *.log 80 | local_settings.py 81 | db.sqlite3 82 | db.sqlite3-journal 83 | 84 | # Flask stuff: 85 | instance/ 86 | .webassets-cache 87 | 88 | # Scrapy stuff: 89 | .scrapy 90 | 91 | # Sphinx documentation 92 | docs/_build/ 93 | 94 | # PyBuilder 95 | .pybuilder/ 96 | target/ 97 | 98 | # Jupyter Notebook 99 | .ipynb_checkpoints 100 | 101 | # IPython 102 | profile_default/ 103 | ipython_config.py 104 | 105 | # pyenv 106 | # For a library or package, you might want to ignore these files since the code is 107 | # intended to run in multiple environments; otherwise, check them in: 108 | # .python-version 109 | 110 | # pipenv 111 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 112 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 113 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 114 | # install all needed dependencies. 115 | #Pipfile.lock 116 | 117 | # poetry 118 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 119 | # This is especially recommended for binary packages to ensure reproducibility, and is more 120 | # commonly ignored for libraries. 121 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 122 | #poetry.lock 123 | 124 | # pdm 125 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 
126 | #pdm.lock 127 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 128 | # in version control. 129 | # https://pdm.fming.dev/#use-with-ide 130 | .pdm.toml 131 | 132 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 133 | __pypackages__/ 134 | 135 | # Celery stuff 136 | celerybeat-schedule 137 | celerybeat.pid 138 | 139 | # SageMath parsed files 140 | *.sage.py 141 | 142 | # Environments 143 | .env 144 | .venv 145 | env/ 146 | venv/ 147 | ENV/ 148 | env.bak/ 149 | venv.bak/ 150 | 151 | # Spyder project settings 152 | .spyderproject 153 | .spyproject 154 | 155 | # Rope project settings 156 | .ropeproject 157 | 158 | # mkdocs documentation 159 | /site 160 | 161 | # mypy 162 | .mypy_cache/ 163 | .dmypy.json 164 | dmypy.json 165 | 166 | # Pyre type checker 167 | .pyre/ 168 | 169 | # pytype static type analyzer 170 | .pytype/ 171 | 172 | # Cython debug symbols 173 | cython_debug/ 174 | 175 | # PyCharm 176 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 177 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 178 | # and can be added to the global gitignore or merged into this file. For a more nuclear 179 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 180 | #.idea/ 181 | 182 | ### Python Patch ### 183 | # Poetry local configuration file - https://python-poetry.org/docs/configuration/#local-configuration 184 | poetry.toml 185 | 186 | # ruff 187 | .ruff_cache/ 188 | 189 | # LSP config files 190 | pyrightconfig.json 191 | 192 | ### VisualStudioCode ### 193 | .vscode/* 194 | 195 | # Local History for Visual Studio Code 196 | .history/ 197 | 198 | # Built Visual Studio Code Extensions 199 | *.vsix 200 | 201 | ### VisualStudioCode Patch ### 202 | # Ignore all local history of files 203 | .history 204 | .ionide 205 | 206 | # End of https://www.toptal.com/developers/gitignore/api/visualstudiocode,linux,python 207 | 208 | # Custom rules (everything added below won't be overriden by 'Generate .gitignore File' if you use 'Update' option) 209 | 210 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 
25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. 
If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. 
Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright 2024 DMIS Laboratory, Korea University 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
-------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Monet: Mixture of Monosemantic Experts for Transformers 2 | 3 | [![arXiv](https://img.shields.io/badge/arXiv-2412.04139-b31b1b?style=flat-square)](https://arxiv.org/abs/2412.04139) 4 | [![Models](https://img.shields.io/badge/%F0%9F%A4%97Hugging_Face-Model_Zoo-ffd200?style=flat-square)](https://huggingface.co/MonetLLM) 5 | [![Demo](https://img.shields.io/badge/%F0%9F%A4%97Hugging_Face-Demo-ffd200?style=flat-square)](https://huggingface.co/spaces/MonetLLM/monet-vd-1.4B-100BT-hf-viewer) 6 | [![code](https://img.shields.io/badge/Github-Code-keygen.svg?logo=github&style=flat-square)](https://github.com/dmis-lab/Monet) 7 | [![License](https://img.shields.io/badge/License-Apache_2.0-blue?style=flat-square)](./LICENSE) 8 | 9 | ![](./assets/figure1.png) 10 | 11 | ## Introduction 12 | 13 | **Monet** presents a novel approach to enhancing mechanistic interpretability in large language models (LLMs) through an innovative Sparse Mixture-of-Experts (SMoE) architecture. By directly incorporating sparse dictionary learning into end-to-end pretraining, **Monet** addresses the fundamental challenge of polysemanticity - where individual neurons respond to multiple unrelated concepts - while maintaining model performance. 14 | 15 | #### ✨Key Highlights 16 | 17 | - 📈 **Scalable Expert Architecture**: **Monet** introduces parameter-efficient expert decomposition methods that enable scaling to 262,144 experts per layer while ensuring total parameters scale proportionally to the square root of expert count. 18 | - 📊 **Monosemantic Experts**: Through fine-grained expert specialization, **Monet** achieves monosemantic experts that demonstrate mutual exclusivity of knowledge, allowing transparent observation of model behavior and parametric knowledge. 19 | - 🛠️ **Robust Knowledge Control**: The architecture enables precise manipulation of domain-specific knowledge, language capabilities, and toxicity mitigation without compromising general performance. 20 | 21 | ### Why Monet? 22 | 23 | Unlike traditional approaches using post-hoc reconstruction (like Sparse Autoencoders), **Monet** integrates interpretability directly into its architecture. This enables both transparent understanding of model internals and fundamental behavior control. By scaling monosemantic experts, Monet paves the way for more transparent and controllable language models. 24 | 25 | ## News 26 | 27 | - **2025-01-23**: Our paper has been accepted to **ICLR 2025**! 🎉 28 | - **2024-12-06**: Released **Monet: Mixture of Monosemantic Experts for Transformers** on [arXiv](https://arxiv.org/abs/2412.04139), with [GitHub](https://github.com/dmis-lab/Monet), [models](https://huggingface.co/MonetLLM), and [demo](https://huggingface.co/spaces/MonetLLM/monet-vd-1.4B-100BT-hf-viewer). 29 | 30 | ## Model Checkpoints 31 | 32 | #### Base Models 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 | 56 | 57 | 58 | 59 | 60 | 61 | 62 | 63 | 64 | 65 | 66 | 67 | 68 | 69 | 70 | 71 | 72 | 73 | 74 | 75 | 76 | 77 | 78 | 79 | 80 | 81 | 82 | 83 | 84 | 85 | 86 | 87 | 88 | 89 | 90 |
| Model | Dataset | #Params | #Tokens | Checkpoint | Demo |
|---|---|---|---|---|---|
| Monet-VD | FineWeb-Edu | 850M | 100BT | 🤗 monet-vd-850M-100BT-hf | |
| | | 1.4B | 100BT | 🤗 monet-vd-1.4B-100BT-hf | 🔍 Viewer |
| | | 4.1B | 100BT | 🤗 monet-vd-4.1B-100BT-hf | |
| | StarCoderData | 1.4B | 100BT | 🤗 codemonet-vd-1.4B-100BT-hf | 🔍 Viewer |
| Monet-HD | FineWeb-Edu | 850M | 100BT | 🤗 monet-hd-850M-100BT-hf | |
| | | 1.4B | 100BT | 🤗 monet-hd-1.4B-100BT-hf | |
| | | 4.1B | 100BT | 🤗 monet-hd-4.1B-100BT-hf | |
91 | 92 | #### Instruction-Tuned Models 93 |
| Model | Purpose | Recipe | #Params | Checkpoint |
|---|---|---|---|---|
| Monet-VD | Chat Completion | SmolLM | 1.4B | 🤗 monet-vd-1.4B-100BT-chat-hf |
| | Vision-Language Model | LLaVA | 1.6B | 🤗 visionmonet-vd-1.4B-100BT-hf |
116 | 117 | ## Quickstart 118 | 119 | You can explore the core implementation of **Monet** in [modeling_monet.py](./modeling_monet.py). We've made it easy to use Monet by including our custom code in the 🤗[Hugging Face model zoo](https://huggingface.co/MonetLLM). Simply set `trust_remote_code=True` when loading the models through the Transformers library. 120 | 121 | ### Text Generation 122 | 123 | ```python 124 | import torch 125 | from transformers import AutoTokenizer, pipeline 126 | model_name = "MonetLLM/monet-vd-1.4B-100BT-hf" 127 | pipe = pipeline( 128 | "text-generation", 129 | model_name, 130 | tokenizer=AutoTokenizer.from_pretrained(model_name), 131 | torch_dtype=torch.bfloat16, 132 | device_map="auto", 133 | trust_remote_code=True, 134 | ) 135 | print(pipe("The key to life is", max_new_tokens=20, do_sample=True)[0]["generated_text"]) 136 | ``` 137 | 138 | Output: 139 | 140 | ``` 141 | The key to life is learning how to live creatively. The question is: how do we do that, and what will 142 | ``` 143 | 144 | ### Code Generation 145 | 146 | ```python 147 | import torch 148 | from transformers import AutoTokenizer, pipeline 149 | model_name = "MonetLLM/codemonet-vd-1.4B-100BT-hf" 150 | pipe = pipeline( 151 | "text-generation", 152 | model_name, 153 | tokenizer=AutoTokenizer.from_pretrained(model_name), 154 | torch_dtype=torch.bfloat16, 155 | device_map="auto", 156 | trust_remote_code=True, 157 | ) 158 | 159 | text = ''' 160 | def print_len(x: str): 161 | """For a given string x, print the length of x.""" 162 | ''' 163 | print(pipe(text, max_new_tokens=10)[0]["generated_text"].split("\n\n")[0]) 164 | ``` 165 | 166 | Output: 167 | 168 | ``` 169 | 170 | def print_len(x: str): 171 | """For a given string x, print the length of x.""" 172 | print(len(x)) 173 | ``` 174 | 175 | ### Chat Completion 176 | 177 | ```python 178 | import torch 179 | from transformers import AutoTokenizer, pipeline 180 | model_name = "MonetLLM/monet-vd-1.4B-100BT-chat-hf" 181 | pipe = pipeline( 182 | "text-generation", 183 | model_name, 184 | tokenizer=AutoTokenizer.from_pretrained(model_name), 185 | torch_dtype=torch.bfloat16, 186 | device_map="auto", 187 | trust_remote_code=True, 188 | ) 189 | tokenizer = pipe.tokenizer 190 | text = tokenizer.apply_chat_template( 191 | [{"role": "user", "content": "Hi! How are you?"}], 192 | add_generation_prompt=True, 193 | tokenize=False, 194 | ) 195 | print(pipe(text, max_new_tokens=30, do_sample=True)[0]["generated_text"]) 196 | ``` 197 | 198 | Output: 199 | 200 | ``` 201 | [INST] Hi! How are you? [/INST] I'm good, thanks! How can I help you today? 202 | ``` 203 | 204 | ### Using vLLM 205 | 206 | For enhanced inference performance, **Monet** can be integrated with the vLLM engine. Note that **Monet** requires manual registration with vLLM's `ModelRegistry` before initialization. The custom implementation is provided in [modeling_monet_vllm.py](./modeling_monet_vllm.py). 207 | 208 | ```python 209 | from vllm import LLM, ModelRegistry, SamplingParams 210 | from modeling_monet_vllm import MonetForCausalLM 211 | 212 | # Register Monet architecture with vLLM 213 | ModelRegistry.register_model("MonetForCausalLM", MonetForCausalLM) 214 | 215 | model = LLM( 216 | "MonetLLM/monet-vd-1.4B-100BT-hf", 217 | trust_remote_code=True, 218 | dtype="bfloat16", 219 | gpu_memory_utilization=0.8 220 | ) 221 | sampling_params = SamplingParams(max_tokens=20, temperature=1.0) 222 | print(model.generate("The key to life is", sampling_params)[0].outputs[0].text) 223 | ``` 224 | Output: 225 | ``` 226 | what you’re born with.
If you think that you don’t have the same control and 227 | ``` 228 | 229 | ### Get Expert Routing Probabilities 230 | 231 | Based on its expert routing probabilities, **Monet** enables mechanistic interpretability by revealing which sparse features are activated for each token. As in standard MoE models, you can obtain the routing probabilities of all layers by setting `output_router_probs=True`. The example below demonstrates how to compute and analyze the expert activation patterns: 232 | 233 | ```python 234 | import torch 235 | from transformers import AutoModelForCausalLM, AutoTokenizer 236 | 237 | model = AutoModelForCausalLM.from_pretrained( 238 | "MonetLLM/monet-vd-1.4B-100BT-hf", 239 | torch_dtype=torch.bfloat16, 240 | device_map="auto", 241 | trust_remote_code=True, 242 | ) 243 | tokenizer = AutoTokenizer.from_pretrained("MonetLLM/monet-vd-1.4B-100BT-hf") 244 | 245 | inputs = tokenizer("City and County of San Francisco", return_tensors="pt") 246 | outputs = model(**inputs.to(model.device), output_router_probs=True) 247 | 248 | # Get full expert routing probabilities: [batch_size, seq_len, moe_heads, moe_experts**2] 249 | g1, g2 = outputs.router_probs[0][0], outputs.router_probs[0][1] 250 | g = torch.einsum("bthi,bthj->bthij", g1, g2).flatten(-2) 251 | print(g.shape) 252 | 253 | # Print number of activated experts per token. 254 | for token, routing in zip(inputs.input_ids.squeeze(0), g.squeeze(0)): 255 | token = tokenizer.decode(token).ljust(16, " ") 256 | expert_indices = (routing.sum(0) > 1e-2).argwhere().squeeze(-1) 257 | print(f"Token: {token} Activated Experts: {len(expert_indices)}") 258 | ``` 259 | 260 | Output: 261 | 262 | ``` 263 | torch.Size([1, 7, 8, 262144]) 264 | Token: Activated Experts: 62 265 | Token: City Activated Experts: 60 266 | Token: and Activated Experts: 16 267 | Token: County Activated Experts: 102 268 | Token: of Activated Experts: 11 269 | Token: San Activated Experts: 70 270 | Token: Francisco Activated Experts: 67 271 | ``` 272 | 273 | ## Citation 274 | If you find this work useful for your research or applications, please cite our paper using the following BibTeX entry.
275 | ```bibtex 276 | @article{park2024monet, 277 | title={{Monet: Mixture of Monosemantic Experts for Transformers}}, 278 | author={Jungwoo Park and Young Jin Ahn and Kee-Eung Kim and Jaewoo Kang}, 279 | journal={arXiv preprint arXiv:2412.04139}, 280 | year={2024} 281 | } 282 | ``` 283 | -------------------------------------------------------------------------------- /assets/figure1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dmis-lab/Monet/b5cb786e938c4f84583eaca4e211e693e6c9d3cd/assets/figure1.png -------------------------------------------------------------------------------- /modeling_monet.py: -------------------------------------------------------------------------------- 1 | # fmt: off 2 | from __future__ import annotations 3 | 4 | from dataclasses import dataclass 5 | 6 | import torch 7 | import torch.utils.checkpoint 8 | from scipy.stats import norm 9 | from torch import nn 10 | from torch.nn import CrossEntropyLoss 11 | from transformers.activations import ACT2FN 12 | from transformers.cache_utils import Cache, DynamicCache, StaticCache 13 | from transformers.modeling_attn_mask_utils import AttentionMaskConverter 14 | from transformers.modeling_utils import PreTrainedModel 15 | from transformers.models.llama.configuration_llama import LlamaConfig 16 | from transformers.models.llama.modeling_llama import ( 17 | LLAMA_ATTENTION_CLASSES, 18 | LlamaRMSNorm, 19 | ) 20 | from transformers.utils import ModelOutput, logging 21 | 22 | logger = logging.get_logger(__name__) 23 | 24 | 25 | @dataclass 26 | class MonetModelOutputWithPast(ModelOutput): 27 | last_hidden_state: torch.FloatTensor = None 28 | past_key_values: tuple[tuple[torch.FloatTensor]] | None = None 29 | hidden_states: tuple[torch.FloatTensor, ...] | None = None 30 | attentions: tuple[torch.FloatTensor, ...] | None = None 31 | router_probs: tuple[tuple[torch.FloatTensor, ...], ...] | None = None 32 | 33 | 34 | @dataclass 35 | class MonetCausalLMOutputWithPast(ModelOutput): 36 | loss: torch.FloatTensor | None = None 37 | aux_loss: torch.FloatTensor | None = None 38 | logits: torch.FloatTensor = None 39 | past_key_values: tuple[tuple[torch.FloatTensor]] | None = None 40 | hidden_states: tuple[torch.FloatTensor, ...] | None = None 41 | attentions: tuple[torch.FloatTensor, ...] | None = None 42 | router_probs: tuple[tuple[torch.FloatTensor, ...], ...]
| None = None 43 | 44 | 45 | class MonetConfig(LlamaConfig): 46 | model_type = "monet" 47 | keys_to_ignore_at_inference = ["past_key_values"] 48 | 49 | def __init__( 50 | self, 51 | vocab_size=32000, 52 | hidden_size=4096, 53 | intermediate_size=None, 54 | num_hidden_layers=32, 55 | num_attention_heads=32, 56 | num_key_value_heads=None, 57 | hidden_act="relu2", 58 | max_position_embeddings=2048, 59 | initializer_range=0.02, 60 | rms_norm_eps=1e-6, 61 | use_cache=True, 62 | pad_token_id=None, 63 | bos_token_id=1, 64 | eos_token_id=2, 65 | pretraining_tp=1, 66 | tie_word_embeddings=False, 67 | rope_theta=10000.0, 68 | rope_scaling=None, 69 | attention_bias=False, 70 | attention_dropout=0.0, 71 | mlp_bias=None, 72 | moe_dim=8, 73 | moe_heads=8, 74 | moe_experts=512, 75 | moe_topk=32, 76 | moe_groups=4, 77 | moe_decompose="vertical", 78 | output_router_probs=False, 79 | **kwargs, 80 | ): 81 | self.moe_dim = moe_dim 82 | self.moe_heads = moe_heads 83 | self.moe_experts = moe_experts 84 | self.moe_topk = moe_topk 85 | self.moe_groups = moe_groups 86 | self.moe_decompose = moe_decompose 87 | self.output_router_probs = output_router_probs 88 | 89 | super().__init__( 90 | vocab_size=vocab_size, 91 | hidden_size=hidden_size, 92 | intermediate_size=intermediate_size, 93 | num_hidden_layers=num_hidden_layers, 94 | num_attention_heads=num_attention_heads, 95 | num_key_value_heads=num_key_value_heads, 96 | hidden_act=hidden_act, 97 | max_position_embeddings=max_position_embeddings, 98 | initializer_range=initializer_range, 99 | rms_norm_eps=rms_norm_eps, 100 | use_cache=use_cache, 101 | pad_token_id=pad_token_id, 102 | bos_token_id=bos_token_id, 103 | eos_token_id=eos_token_id, 104 | pretraining_tp=pretraining_tp, 105 | tie_word_embeddings=tie_word_embeddings, 106 | rope_theta=rope_theta, 107 | rope_scaling=rope_scaling, 108 | attention_bias=attention_bias, 109 | attention_dropout=attention_dropout, 110 | mlp_bias=mlp_bias, 111 | **kwargs, 112 | ) 113 | 114 | 115 | class MonetRouter(nn.Module): 116 | def __init__(self, config: MonetConfig): 117 | super().__init__() 118 | self.config = config 119 | flatten_shape = config.moe_heads * config.moe_experts 120 | 121 | self.w1 = nn.Linear(config.hidden_size, flatten_shape, bias=False) 122 | self.w2 = nn.Linear(config.hidden_size, flatten_shape, bias=False) 123 | self.norm1 = nn.BatchNorm1d(config.moe_heads, affine=False) 124 | self.norm2 = nn.BatchNorm1d(config.moe_heads, affine=False) 125 | 126 | def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: 127 | g1z = self.w1(x).unflatten(-1, (self.config.moe_heads, -1)).float() 128 | g2z = self.w2(x).unflatten(-1, (self.config.moe_heads, -1)).float() 129 | 130 | g1n = self.norm1(g1z.transpose(2, 3).flatten(0, -2)) 131 | g2n = self.norm2(g2z.transpose(2, 3).flatten(0, -2)) 132 | g1n = g1n.view(g1z.size(0), g1z.size(1), g1z.size(3), -1).transpose(2, 3) 133 | g2n = g2n.view(g2z.size(0), g2z.size(1), g2z.size(3), -1).transpose(2, 3) 134 | 135 | sigma = float(norm.ppf(1 - self.config.moe_topk / self.config.moe_experts)) 136 | g1s = g1n.amax(-1, keepdim=True).clamp_max_(sigma) 137 | g2s = g2n.amax(-1, keepdim=True).clamp_max_(sigma) 138 | 139 | g1 = nn.functional.softmax(torch.where(g1n >= g1s, g1z, -1e10), dim=-1) 140 | g2 = nn.functional.softmax(torch.where(g2n >= g2s, g2z, -1e10), dim=-1) 141 | return g1, g2 142 | 143 | 144 | class MonetMoVDE(nn.Module): 145 | def __init__(self, config: MonetConfig): 146 | super().__init__() 147 | self.config = config 148 | self.act_fn = ACT2FN[config.hidden_act] 
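# Vertical expert decomposition (Monet-VD): the layer keeps two groups of `moe_experts` sub-experts whose routing probabilities g1 and g2 are combined in `forward`, emulating moe_experts**2 expert pairs while the parameter count grows only with moe_experts.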
149 | flatten_shape = config.moe_experts * config.moe_dim // 2 150 | 151 | self.u1 = nn.Linear(config.hidden_size, flatten_shape) 152 | self.u2 = nn.Linear(config.hidden_size, flatten_shape) 153 | 154 | self.v11 = nn.Linear(flatten_shape, config.hidden_size // 2, bias=False) 155 | self.v12 = nn.Linear(flatten_shape, config.hidden_size // 2, bias=False) 156 | self.v21 = nn.Linear(flatten_shape, config.hidden_size // 2, bias=False) 157 | self.v22 = nn.Linear(flatten_shape, config.hidden_size // 2, bias=False) 158 | 159 | self.b1 = nn.Parameter(torch.zeros(config.moe_experts, config.hidden_size // 2)) 160 | self.b2 = nn.Parameter(torch.zeros(config.moe_experts, config.hidden_size // 2)) 161 | 162 | def forward( 163 | self, x: torch.Tensor, g1: torch.Tensor, g2: torch.Tensor 164 | ) -> torch.Tensor: 165 | g1, g2 = g1.type_as(x), g2.type_as(x) 166 | x1 = self.act_fn(self.u1(x).unflatten(-1, (self.config.moe_experts, -1))) 167 | x2 = self.act_fn(self.u2(x).unflatten(-1, (self.config.moe_experts, -1))) 168 | 169 | x11 = self.v11(torch.einsum("btim,bthi->btim", x1, g1).flatten(-2)) 170 | x12 = self.v12(torch.einsum("btjm,bthj,bthi->btim", x2, g2, g1).flatten(-2)) 171 | x13 = torch.einsum("bthi,id->btd", g1, self.b1.type_as(x)) 172 | 173 | x21 = self.v21(torch.einsum("btim,bthi,bthj->btjm", x1, g1, g2).flatten(-2)) 174 | x22 = self.v22(torch.einsum("btjm,bthj->btjm", x2, g2).flatten(-2)) 175 | x23 = torch.einsum("bthj,jd->btd", g2, self.b2.type_as(x)) 176 | 177 | return torch.cat((x11 + x12 + x13, x21 + x22 + x23), dim=-1) 178 | 179 | 180 | class MonetMoHDE(nn.Module): 181 | def __init__(self, config: MonetConfig): 182 | super().__init__() 183 | self.config = config 184 | self.act_fn = ACT2FN[config.hidden_act] 185 | flatten_shape = config.moe_experts * config.moe_dim 186 | 187 | self.u = nn.Linear(config.hidden_size, flatten_shape) 188 | self.v = nn.Linear(flatten_shape, config.hidden_size, bias=False) 189 | self.b = nn.Parameter(torch.zeros(config.moe_experts, config.hidden_size)) 190 | 191 | def forward( 192 | self, x: torch.Tensor, g1: torch.Tensor, g2: torch.Tensor 193 | ) -> torch.Tensor: 194 | g1, g2 = g1.type_as(x), g2.type_as(x) 195 | x = self.act_fn(self.u(x).unflatten(-1, (self.config.moe_experts, -1))) 196 | x = self.v(torch.einsum("btim,bthi,bthj->btjm", x, g1, g2).flatten(-2)) 197 | return x + torch.einsum("bthj,jd->btd", g2, self.b) 198 | 199 | 200 | class MonetDecoderLayer(nn.Module): 201 | def __init__(self, config: MonetConfig, layer_idx: int): 202 | super().__init__() 203 | self.hidden_size = config.hidden_size 204 | self.self_attn = LLAMA_ATTENTION_CLASSES[config._attn_implementation]( 205 | config=config, layer_idx=layer_idx 206 | ) 207 | self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) 208 | self.post_attention_layernorm = LlamaRMSNorm( 209 | config.hidden_size, eps=config.rms_norm_eps 210 | ) 211 | 212 | if config.moe_decompose == "vertical": 213 | self.moe = MonetMoVDE(config) 214 | elif config.moe_decompose == "horizontal": 215 | self.moe = MonetMoHDE(config) 216 | if layer_idx % config.moe_groups == 0: 217 | self.router = MonetRouter(config).requires_grad_(False) 218 | 219 | def forward( 220 | self, 221 | hidden_states: torch.Tensor, 222 | attention_mask: torch.Tensor | None = None, 223 | position_ids: torch.LongTensor | None = None, 224 | past_key_value: Cache | None = None, 225 | previous_router_probs: tuple[torch.Tensor, torch.Tensor] | None = None, 226 | output_attentions: bool | None = False, 227 | use_cache: bool | None = False, 228 | 
cache_position: torch.LongTensor | None = None, 229 | **kwargs, 230 | ) -> tuple[torch.FloatTensor, ...]: 231 | residual = hidden_states 232 | 233 | hidden_states = self.input_layernorm(hidden_states) 234 | 235 | # Self Attention 236 | hidden_states, self_attn_weights, present_key_value = self.self_attn( 237 | hidden_states=hidden_states, 238 | attention_mask=attention_mask, 239 | position_ids=position_ids, 240 | past_key_value=past_key_value, 241 | output_attentions=output_attentions, 242 | use_cache=use_cache, 243 | cache_position=cache_position, 244 | ) 245 | hidden_states = residual + hidden_states 246 | 247 | # Fully Connected 248 | residual = hidden_states 249 | hidden_states = self.post_attention_layernorm(hidden_states) 250 | g1, g2 = ( 251 | self.router(hidden_states) 252 | if hasattr(self, "router") 253 | else previous_router_probs 254 | ) 255 | hidden_states = self.moe(hidden_states, g1, g2) 256 | hidden_states = residual + hidden_states 257 | 258 | outputs = (hidden_states,) 259 | 260 | if output_attentions: 261 | outputs += (self_attn_weights,) 262 | 263 | if use_cache: 264 | outputs += (present_key_value,) 265 | 266 | return outputs + ((g1, g2) if hasattr(self, "router") else None,) 267 | 268 | 269 | class MonetPreTrainedModel(PreTrainedModel): 270 | config_class = MonetConfig 271 | base_model_prefix = "model" 272 | supports_gradient_checkpointing = True 273 | _no_split_modules = ["MonetDecoderLayer"] 274 | _skip_keys_device_placement = ["past_key_values"] 275 | _supports_flash_attn_2 = True 276 | _supports_sdpa = True 277 | _supports_cache_class = True 278 | _supports_quantized_cache = True 279 | _supports_static_cache = True 280 | 281 | def _init_weights(self, module): 282 | std = self.config.initializer_range 283 | if isinstance(module, nn.Linear): 284 | module.weight.data.normal_(mean=0.0, std=std) 285 | if module.bias is not None: 286 | module.bias.data.zero_() 287 | elif isinstance(module, nn.Embedding): 288 | module.weight.data.normal_(mean=0.0, std=std) 289 | if module.padding_idx is not None: 290 | module.weight.data[module.padding_idx].zero_() 291 | 292 | 293 | class MonetModel(MonetPreTrainedModel): 294 | def __init__(self, config: MonetConfig): 295 | super().__init__(config) 296 | self.padding_idx = config.pad_token_id 297 | self.vocab_size = config.vocab_size 298 | 299 | self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) # noqa 300 | self.layers = nn.ModuleList([MonetDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]) # noqa 301 | self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) 302 | self.gradient_checkpointing = False 303 | 304 | # Initialize weights and apply final processing 305 | self.post_init() 306 | 307 | def get_input_embeddings(self): 308 | return self.embed_tokens 309 | 310 | def set_input_embeddings(self, value): 311 | self.embed_tokens = value 312 | 313 | def forward( 314 | self, 315 | input_ids: torch.LongTensor = None, 316 | attention_mask: torch.Tensor | None = None, 317 | position_ids: torch.LongTensor | None = None, 318 | past_key_values: Cache | list[torch.FloatTensor] | None = None, 319 | inputs_embeds: torch.FloatTensor | None = None, 320 | use_cache: bool | None = None, 321 | output_attentions: bool | None = None, 322 | output_hidden_states: bool | None = None, 323 | output_router_probs: bool | None = None, 324 | return_dict: bool | None = None, 325 | cache_position: torch.LongTensor | None = None, 326 | ) -> tuple[torch.Tensor, ...] 
| MonetModelOutputWithPast: 327 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions # noqa 328 | output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states # noqa 329 | output_router_probs = output_router_probs if output_router_probs is not None else self.config.output_router_probs # noqa 330 | use_cache = use_cache if use_cache is not None else self.config.use_cache 331 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict # noqa 332 | 333 | if (input_ids is None) ^ (inputs_embeds is not None): 334 | raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one") # noqa 335 | 336 | if self.gradient_checkpointing and self.training and use_cache: 337 | logger.warning_once("`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`.") # noqa 338 | use_cache = False 339 | 340 | if inputs_embeds is None: 341 | inputs_embeds = self.embed_tokens(input_ids) 342 | 343 | return_legacy_cache = False 344 | if use_cache and not isinstance(past_key_values, Cache): # kept for BC (non `Cache` `past_key_values` inputs) # noqa 345 | return_legacy_cache = True 346 | past_key_values = DynamicCache.from_legacy_cache(past_key_values) 347 | logger.warning_once( 348 | "We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.43. " # noqa 349 | "Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/v4.41.3/en/internal/generation_utils#transformers.Cache)" # noqa 350 | ) 351 | 352 | if cache_position is None: 353 | past_seen_tokens = ( 354 | past_key_values.get_seq_length() if past_key_values is not None else 0 355 | ) 356 | cache_position = torch.arange(past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device) # noqa 357 | if position_ids is None: 358 | position_ids = cache_position.unsqueeze(0) 359 | causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions) # noqa 360 | 361 | # embed positions 362 | hidden_states = inputs_embeds 363 | 364 | # decoder layers 365 | all_hidden_states = () if output_hidden_states else None 366 | all_self_attns = () if output_attentions else None 367 | all_router_probs = () if output_router_probs else None 368 | previous_router_probs, next_decoder_cache = None, None 369 | 370 | for decoder_layer in self.layers: 371 | if output_hidden_states: 372 | all_hidden_states += (hidden_states,) 373 | 374 | if self.gradient_checkpointing and self.training: 375 | layer_outputs = self._gradient_checkpointing_func( 376 | decoder_layer.__call__, 377 | hidden_states, 378 | causal_mask, 379 | position_ids, 380 | past_key_values, 381 | previous_router_probs, 382 | output_attentions, 383 | use_cache, 384 | cache_position, 385 | ) 386 | else: 387 | layer_outputs = decoder_layer( 388 | hidden_states, 389 | attention_mask=causal_mask, 390 | position_ids=position_ids, 391 | past_key_value=past_key_values, 392 | previous_router_probs=previous_router_probs, 393 | output_attentions=output_attentions, 394 | use_cache=use_cache, 395 | cache_position=cache_position, 396 | ) 397 | 398 | hidden_states = layer_outputs[0] 399 | if use_cache: 400 | next_decoder_cache = layer_outputs[2 if output_attentions else 1] 401 | if output_attentions: 402 | all_self_attns += (layer_outputs[1],) 403 | if output_router_probs: 
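# layer_outputs[-1] is the (g1, g2) pair for layers that own a router and None otherwise; None entries leave previous_router_probs unchanged below.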
404 | all_router_probs += (layer_outputs[-1],) 405 | previous_router_probs = ( 406 | layer_outputs[-1] 407 | if layer_outputs[-1] is not None 408 | else previous_router_probs 409 | ) 410 | 411 | hidden_states = self.norm(hidden_states) 412 | 413 | # add hidden states from the last decoder layer 414 | if output_hidden_states: 415 | all_hidden_states += (hidden_states,) 416 | 417 | next_cache = next_decoder_cache if use_cache else None 418 | if return_legacy_cache: 419 | next_cache = next_cache.to_legacy_cache() 420 | 421 | if not return_dict: 422 | return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_router_probs] if v is not None) # noqa 423 | return MonetModelOutputWithPast( 424 | last_hidden_state=hidden_states, 425 | past_key_values=next_cache, 426 | hidden_states=all_hidden_states, 427 | attentions=all_self_attns, 428 | router_probs=all_router_probs, 429 | ) 430 | 431 | def _update_causal_mask( 432 | self, 433 | attention_mask: torch.Tensor, 434 | input_tensor: torch.Tensor, 435 | cache_position: torch.Tensor, 436 | past_key_values: Cache, 437 | output_attentions: bool, 438 | ): 439 | if self.config._attn_implementation == "flash_attention_2": 440 | if attention_mask is not None and 0.0 in attention_mask: 441 | return attention_mask 442 | return None 443 | 444 | past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 # noqa 445 | using_static_cache = isinstance(past_key_values, StaticCache) 446 | 447 | if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions: # noqa 448 | if AttentionMaskConverter._ignore_causal_mask_sdpa( 449 | attention_mask, 450 | inputs_embeds=input_tensor, 451 | past_key_values_length=past_seen_tokens, 452 | is_training=self.training, 453 | ): 454 | return None 455 | 456 | dtype, device = input_tensor.dtype, input_tensor.device 457 | min_dtype = torch.finfo(dtype).min 458 | sequence_length = input_tensor.shape[1] 459 | if using_static_cache: 460 | target_length = past_key_values.get_max_length() 461 | else: 462 | target_length = ( 463 | attention_mask.shape[-1] 464 | if isinstance(attention_mask, torch.Tensor) 465 | else past_seen_tokens + sequence_length + 1 466 | ) 467 | 468 | if attention_mask is not None and attention_mask.dim() == 4: 469 | if attention_mask.max() != 0: 470 | raise ValueError("Custom 4D attention mask should be passed in inverted form with max==0`") # noqa 471 | causal_mask = attention_mask 472 | else: 473 | causal_mask = torch.full( 474 | (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device # noqa 475 | ) 476 | if sequence_length != 1: 477 | causal_mask = torch.triu(causal_mask, diagonal=1) 478 | causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) # noqa 479 | causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1) # noqa 480 | if attention_mask is not None: 481 | causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit # noqa 482 | mask_length = attention_mask.shape[-1] 483 | padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] # noqa 484 | padding_mask = padding_mask == 0 485 | causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(padding_mask, min_dtype) # noqa 486 | if ( 487 | self.config._attn_implementation == "sdpa" 488 | and attention_mask is not None 489 | and attention_mask.device.type == "cuda" 490 | and not output_attentions 491 | ): 492 | 
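# Unmask rows that are fully masked (e.g. caused by left padding) so that SDPA does not return NaN for those query positions.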
causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) # noqa 493 | 494 | return causal_mask 495 | 496 | 497 | class MonetForCausalLM(MonetPreTrainedModel): 498 | _tied_weights_keys = ["lm_head.weight"] 499 | 500 | def __init__(self, config): 501 | super().__init__(config) 502 | self.model = MonetModel(config) 503 | self.vocab_size = config.vocab_size 504 | self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) 505 | 506 | # Initialize weights and apply final processing 507 | self.post_init() 508 | 509 | def get_input_embeddings(self): 510 | return self.model.embed_tokens 511 | 512 | def set_input_embeddings(self, value): 513 | self.model.embed_tokens = value 514 | 515 | def get_output_embeddings(self): 516 | return self.lm_head 517 | 518 | def set_output_embeddings(self, new_embeddings): 519 | self.lm_head = new_embeddings 520 | 521 | def set_decoder(self, decoder): 522 | self.model = decoder 523 | 524 | def get_decoder(self): 525 | return self.model 526 | 527 | def forward( 528 | self, 529 | input_ids: torch.LongTensor = None, 530 | attention_mask: torch.Tensor | None = None, 531 | position_ids: torch.LongTensor | None = None, 532 | past_key_values: Cache | list[torch.FloatTensor] | None = None, 533 | inputs_embeds: torch.FloatTensor | None = None, 534 | labels: torch.LongTensor | None = None, 535 | use_cache: bool | None = None, 536 | output_attentions: bool | None = None, 537 | output_hidden_states: bool | None = None, 538 | output_router_probs: bool | None = None, 539 | return_dict: bool | None = None, 540 | cache_position: torch.LongTensor | None = None, 541 | ) -> tuple[torch.Tensor, ...] | MonetCausalLMOutputWithPast: 542 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions # noqa 543 | output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states # noqa 544 | output_router_probs = output_router_probs if output_router_probs is not None else self.config.output_router_probs # noqa 545 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict # noqa 546 | 547 | # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) 548 | outputs = self.model( 549 | input_ids=input_ids, 550 | attention_mask=attention_mask, 551 | position_ids=position_ids, 552 | past_key_values=past_key_values, 553 | inputs_embeds=inputs_embeds, 554 | use_cache=use_cache, 555 | output_attentions=output_attentions, 556 | output_hidden_states=output_hidden_states, 557 | output_router_probs=output_router_probs, 558 | return_dict=return_dict, 559 | cache_position=cache_position, 560 | ) 561 | 562 | hidden_states = outputs[0] 563 | logits = self.lm_head(hidden_states) 564 | logits = logits.float() 565 | 566 | loss = None 567 | if labels is not None: 568 | # Shift so that tokens < n predict n 569 | shift_logits = logits[..., :-1, :].contiguous() 570 | shift_labels = labels[..., 1:].contiguous() 571 | # Flatten the tokens 572 | loss_fct = CrossEntropyLoss() 573 | shift_logits = shift_logits.view(-1, self.config.vocab_size) 574 | shift_labels = shift_labels.view(-1) 575 | # Enable model parallelism 576 | shift_labels = shift_labels.to(shift_logits.device) 577 | loss = loss_fct(shift_logits, shift_labels) 578 | 579 | if not return_dict: 580 | output = (logits,) + outputs[1:] 581 | return (loss,) + output if loss is not None else output 582 | 583 | return MonetCausalLMOutputWithPast( 584 | loss=loss, 585 | logits=logits, 
586 | past_key_values=outputs.past_key_values, 587 | hidden_states=outputs.hidden_states, 588 | attentions=outputs.attentions, 589 | router_probs=outputs.router_probs, 590 | ) 591 | 592 | def prepare_inputs_for_generation( 593 | self, 594 | input_ids, 595 | past_key_values=None, 596 | attention_mask=None, 597 | inputs_embeds=None, 598 | cache_position=None, 599 | use_cache=True, 600 | **kwargs, 601 | ): 602 | past_length = 0 603 | if past_key_values is not None: 604 | past_length = cache_position[0] if cache_position is not None else past_key_values.get_seq_length() # noqa 605 | max_cache_length = ( 606 | torch.tensor(past_key_values.get_max_length(), device=input_ids.device) 607 | if past_key_values.get_max_length() is not None 608 | else None 609 | ) 610 | cache_length = past_length if max_cache_length is None else torch.min(max_cache_length, past_length) # noqa 611 | 612 | # Keep only the unprocessed tokens: 613 | if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]: # noqa 614 | input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :] 615 | # input_ids based on the past_length. 616 | elif past_length < input_ids.shape[1]: 617 | input_ids = input_ids[:, past_length:] 618 | 619 | if ( 620 | max_cache_length is not None 621 | and attention_mask is not None 622 | and cache_length + input_ids.shape[1] > max_cache_length 623 | ): 624 | attention_mask = attention_mask[:, -max_cache_length:] 625 | 626 | position_ids = kwargs.get("position_ids", None) 627 | if attention_mask is not None and position_ids is None: 628 | # create position_ids on the fly for batch generation 629 | position_ids = attention_mask.long().cumsum(-1) - 1 630 | position_ids.masked_fill_(attention_mask == 0, 1) 631 | if past_key_values: 632 | position_ids = position_ids[:, -input_ids.shape[1] :] 633 | 634 | if inputs_embeds is not None and past_length == 0: 635 | model_inputs = {"inputs_embeds": inputs_embeds} 636 | else: 637 | model_inputs = {"input_ids": input_ids.contiguous()} 638 | 639 | input_length = position_ids.shape[-1] if position_ids is not None else input_ids.shape[-1] # noqa 640 | if cache_position is None: 641 | cache_position = torch.arange(past_length, past_length + input_length, device=input_ids.device) # noqa 642 | elif use_cache: 643 | cache_position = cache_position[-input_length:] 644 | 645 | model_inputs.update( 646 | { 647 | "position_ids": position_ids, 648 | "cache_position": cache_position, 649 | "past_key_values": past_key_values, 650 | "use_cache": use_cache, 651 | "attention_mask": attention_mask, 652 | } 653 | ) 654 | return model_inputs 655 | 656 | @staticmethod 657 | def _reorder_cache(past_key_values, beam_idx): 658 | reordered_past = () 659 | for layer_past in past_key_values: 660 | reordered_past += ( 661 | tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past), # noqa 662 | ) 663 | return reordered_past 664 | -------------------------------------------------------------------------------- /modeling_monet_vllm.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Adapted from 3 | # https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py 4 | # Copyright 2023 The vLLM team. 5 | # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. 6 | # 7 | # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX 8 | # and OPT implementations in this library. 
It has been modified from its 9 | # original forms to accommodate minor architectural differences compared 10 | # to GPT-NeoX and OPT used by the Meta AI team that trained the model. 11 | # 12 | # Licensed under the Apache License, Version 2.0 (the "License"); 13 | # you may not use this file except in compliance with the License. 14 | # You may obtain a copy of the License at 15 | # 16 | # http://www.apache.org/licenses/LICENSE-2.0 17 | # 18 | # Unless required by applicable law or agreed to in writing, software 19 | # distributed under the License is distributed on an "AS IS" BASIS, 20 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 21 | # See the License for the specific language governing permissions and 22 | # limitations under the License. 23 | """Inference-only Monet model compatible with HuggingFace weights.""" 24 | from __future__ import annotations 25 | 26 | from typing import Any, Dict, Iterable, List, Optional, Tuple, Union 27 | 28 | import torch 29 | from scipy.stats import norm 30 | from torch import nn 31 | from transformers import PretrainedConfig 32 | from transformers.activations import ACT2FN 33 | from vllm.attention import Attention, AttentionMetadata 34 | from vllm.config import CacheConfig, LoRAConfig 35 | from vllm.distributed import ( 36 | get_pp_group, 37 | get_tensor_model_parallel_rank, 38 | get_tensor_model_parallel_world_size, 39 | ) 40 | from vllm.model_executor.layers.layernorm import RMSNorm 41 | from vllm.model_executor.layers.linear import QKVParallelLinear, RowParallelLinear 42 | from vllm.model_executor.layers.logits_processor import LogitsProcessor 43 | from vllm.model_executor.layers.quantization.base_config import QuantizationConfig 44 | from vllm.model_executor.layers.quantization.compressed_tensors.utils import ( 45 | get_compressed_tensors_cache_scale, 46 | ) 47 | from vllm.model_executor.layers.rotary_embedding import get_rope 48 | from vllm.model_executor.layers.sampler import Sampler, SamplerOutput 49 | from vllm.model_executor.layers.vocab_parallel_embedding import ( 50 | DEFAULT_VOCAB_PADDING_SIZE, 51 | ParallelLMHead, 52 | VocabParallelEmbedding, 53 | ) 54 | from vllm.model_executor.model_loader.weight_utils import ( 55 | default_weight_loader, 56 | kv_cache_scales_loader, 57 | maybe_remap_kv_scale_name, 58 | ) 59 | from vllm.model_executor.models.interfaces import SupportsLoRA 60 | from vllm.model_executor.models.utils import ( 61 | PPMissingLayer, 62 | is_pp_missing_parameter, 63 | make_layers, 64 | ) 65 | from vllm.model_executor.sampling_metadata import SamplingMetadata 66 | from vllm.sequence import IntermediateTensors 67 | from vllm.utils import is_hip 68 | 69 | 70 | class MonetRouter(nn.Module): 71 | def __init__(self, config: PretrainedConfig): 72 | super().__init__() 73 | self.config = config 74 | flatten_shape = config.moe_heads * config.moe_experts 75 | 76 | self.w1 = nn.Linear(config.hidden_size, flatten_shape, bias=False) 77 | self.w2 = nn.Linear(config.hidden_size, flatten_shape, bias=False) 78 | self.norm1 = nn.BatchNorm1d(config.moe_heads, affine=False) 79 | self.norm2 = nn.BatchNorm1d(config.moe_heads, affine=False) 80 | 81 | def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: 82 | g1z = self.w1(x).unflatten(-1, (self.config.moe_heads, -1)).float() 83 | g2z = self.w2(x).unflatten(-1, (self.config.moe_heads, -1)).float() 84 | 85 | g1n = self.norm1(g1z.transpose(-2, -1).flatten(0, -2)) 86 | g2n = self.norm2(g2z.transpose(-2, -1).flatten(0, -2)) 87 | g1n = 
88 |         g2n = g2n.view(*g2z.shape[:-2], g2z.size(-1), -1).transpose(-2, -1)
89 | 
90 |         sigma = float(norm.ppf(1 - self.config.moe_topk / self.config.moe_experts))
91 |         g1s = g1n.amax(-1, keepdim=True).clamp_max_(sigma)
92 |         g2s = g2n.amax(-1, keepdim=True).clamp_max_(sigma)
93 | 
94 |         g1 = nn.functional.softmax(torch.where(g1n >= g1s, g1z, -1e10), dim=-1)
95 |         g2 = nn.functional.softmax(torch.where(g2n >= g2s, g2z, -1e10), dim=-1)
96 |         return g1, g2
97 | 
98 | 
99 | class MonetMoVDE(nn.Module):
100 |     def __init__(self, config: PretrainedConfig):
101 |         super().__init__()
102 |         self.config = config
103 |         self.act_fn = ACT2FN[config.hidden_act]
104 |         flatten_shape = config.moe_experts * config.moe_dim // 2
105 | 
106 |         self.u1 = nn.Linear(config.hidden_size, flatten_shape)
107 |         self.u2 = nn.Linear(config.hidden_size, flatten_shape)
108 | 
109 |         self.v11 = nn.Linear(flatten_shape, config.hidden_size // 2, bias=False)
110 |         self.v12 = nn.Linear(flatten_shape, config.hidden_size // 2, bias=False)
111 |         self.v21 = nn.Linear(flatten_shape, config.hidden_size // 2, bias=False)
112 |         self.v22 = nn.Linear(flatten_shape, config.hidden_size // 2, bias=False)
113 | 
114 |         self.b1 = nn.Parameter(torch.zeros(config.moe_experts, config.hidden_size // 2))
115 |         self.b2 = nn.Parameter(torch.zeros(config.moe_experts, config.hidden_size // 2))
116 | 
117 |     def forward(
118 |         self, x: torch.Tensor, g1: torch.Tensor, g2: torch.Tensor
119 |     ) -> torch.Tensor:
120 |         g1, g2 = g1.type_as(x), g2.type_as(x)
121 |         x1 = self.act_fn(self.u1(x).unflatten(-1, (self.config.moe_experts, -1)))
122 |         x2 = self.act_fn(self.u2(x).unflatten(-1, (self.config.moe_experts, -1)))
123 | 
124 |         x11 = self.v11(torch.einsum("bim,bhi->bim", x1, g1).flatten(-2))
125 |         x12 = self.v12(torch.einsum("bjm,bhj,bhi->bim", x2, g2, g1).flatten(-2))
126 |         x13 = torch.einsum("bhi,id->bd", g1, self.b1.type_as(x))
127 | 
128 |         x21 = self.v21(torch.einsum("bim,bhi,bhj->bjm", x1, g1, g2).flatten(-2))
129 |         x22 = self.v22(torch.einsum("bjm,bhj->bjm", x2, g2).flatten(-2))
130 |         x23 = torch.einsum("bhj,jd->bd", g2, self.b2.type_as(x))
131 | 
132 |         return torch.cat((x11 + x12 + x13, x21 + x22 + x23), dim=-1)
133 | 
134 | 
135 | class MonetMoHDE(nn.Module):
136 |     def __init__(self, config: PretrainedConfig):
137 |         super().__init__()
138 |         self.config = config
139 |         self.act_fn = ACT2FN[config.hidden_act]
140 |         flatten_shape = config.moe_experts * config.moe_dim
141 | 
142 |         self.u = nn.Linear(config.hidden_size, flatten_shape)
143 |         self.v = nn.Linear(flatten_shape, config.hidden_size, bias=False)
144 |         self.b = nn.Parameter(torch.zeros(config.moe_experts, config.hidden_size))
145 | 
146 |     def forward(
147 |         self, x: torch.Tensor, g1: torch.Tensor, g2: torch.Tensor
148 |     ) -> torch.Tensor:
149 |         g1, g2 = g1.type_as(x), g2.type_as(x)
150 |         x = self.act_fn(self.u(x).unflatten(-1, (self.config.moe_experts, -1)))
151 |         x = self.v(torch.einsum("bim,bhi,bhj->bjm", x, g1, g2).flatten(-2))
152 |         return x + torch.einsum("bhj,jd->bd", g2, self.b)
153 | 
154 | 
155 | class MonetAttention(nn.Module):
156 |     def __init__(
157 |         self,
158 |         config: PretrainedConfig,
159 |         hidden_size: int,
160 |         num_heads: int,
161 |         num_kv_heads: int,
162 |         rope_theta: float = 10000,
163 |         rope_scaling: Optional[Dict[str, Any]] = None,
164 |         max_position_embeddings: int = 8192,
165 |         quant_config: Optional[QuantizationConfig] = None,
166 |         bias: bool = False,
167 |         cache_config: Optional[CacheConfig] = None,
168 |         prefix: str = "",
169 |     ):
170 |         super().__init__()
171 |         self.hidden_size = hidden_size
172 |         tp_size = get_tensor_model_parallel_world_size()
173 |         self.total_num_heads = num_heads
174 |         assert self.total_num_heads % tp_size == 0
175 |         self.num_heads = self.total_num_heads // tp_size
176 |         self.total_num_kv_heads = num_kv_heads
177 |         if self.total_num_kv_heads >= tp_size:
178 |             # Number of KV heads is greater than TP size, so we partition
179 |             # the KV heads across multiple tensor parallel GPUs.
180 |             assert self.total_num_kv_heads % tp_size == 0
181 |         else:
182 |             # Number of KV heads is less than TP size, so we replicate
183 |             # the KV heads across multiple tensor parallel GPUs.
184 |             assert tp_size % self.total_num_kv_heads == 0
185 |         self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
186 |         # MistralConfig has an optional head_dim introduced by Mistral-Nemo
187 |         self.head_dim = getattr(
188 |             config, "head_dim", self.hidden_size // self.total_num_heads
189 |         )
190 |         self.q_size = self.num_heads * self.head_dim
191 |         self.kv_size = self.num_kv_heads * self.head_dim
192 |         self.scaling = self.head_dim**-0.5
193 |         self.rope_theta = rope_theta
194 |         self.max_position_embeddings = max_position_embeddings
195 | 
196 |         self.qkv_proj = QKVParallelLinear(
197 |             hidden_size=hidden_size,
198 |             head_size=self.head_dim,
199 |             total_num_heads=self.total_num_heads,
200 |             total_num_kv_heads=self.total_num_kv_heads,
201 |             bias=bias,
202 |             quant_config=quant_config,
203 |             prefix=f"{prefix}.qkv_proj",
204 |         )
205 | 
206 |         self.o_proj = RowParallelLinear(
207 |             input_size=self.total_num_heads * self.head_dim,
208 |             output_size=hidden_size,
209 |             bias=bias,
210 |             quant_config=quant_config,
211 |             prefix=f"{prefix}.o_proj",
212 |         )
213 | 
214 |         is_neox_style = True
215 |         if quant_config is not None and quant_config.get_name() == "gguf":
216 |             is_neox_style = False
217 | 
218 |         self.rotary_emb = get_rope(
219 |             self.head_dim,
220 |             rotary_dim=self.head_dim,
221 |             max_position=max_position_embeddings,
222 |             base=rope_theta,
223 |             rope_scaling=rope_scaling,
224 |             is_neox_style=is_neox_style,
225 |         )
226 |         self.attn = Attention(
227 |             self.num_heads,
228 |             self.head_dim,
229 |             self.scaling,
230 |             num_kv_heads=self.num_kv_heads,
231 |             cache_config=cache_config,
232 |             quant_config=quant_config,
233 |         )
234 | 
235 |     def forward(
236 |         self,
237 |         positions: torch.Tensor,
238 |         hidden_states: torch.Tensor,
239 |         kv_cache: torch.Tensor,
240 |         attn_metadata: AttentionMetadata,
241 |     ) -> torch.Tensor:
242 |         qkv, _ = self.qkv_proj(hidden_states)
243 |         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
244 |         q, k = self.rotary_emb(positions, q, k)
245 |         attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
246 |         output, _ = self.o_proj(attn_output)
247 |         return output
248 | 
249 | 
250 | class MonetDecoderLayer(nn.Module):
251 |     def __init__(
252 |         self,
253 |         config: PretrainedConfig,
254 |         layer_idx: int,
255 |         cache_config: Optional[CacheConfig] = None,
256 |         quant_config: Optional[QuantizationConfig] = None,
257 |         prefix: str = "",
258 |     ):
259 |         super().__init__()
260 |         self.hidden_size = config.hidden_size
261 |         rope_theta = getattr(config, "rope_theta", 10000)
262 |         rope_scaling = getattr(config, "rope_scaling", None)
263 |         if rope_scaling is not None and getattr(
264 |             config, "original_max_position_embeddings", None
265 |         ):
266 |             rope_scaling["original_max_position_embeddings"] = (
267 |                 config.original_max_position_embeddings
268 |             )
269 |         max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
270 |         # Support abacusai/Smaug-72B-v0.1 with attention_bias
271 |         # Support internlm/internlm-7b with bias
272 |         attention_bias = getattr(config, "attention_bias", False) or getattr(
273 |             config, "bias", False
274 |         )
275 |         self.self_attn = MonetAttention(
276 |             config=config,
277 |             hidden_size=self.hidden_size,
278 |             num_heads=config.num_attention_heads,
279 |             num_kv_heads=getattr(
280 |                 config, "num_key_value_heads", config.num_attention_heads
281 |             ),
282 |             rope_theta=rope_theta,
283 |             rope_scaling=rope_scaling,
284 |             max_position_embeddings=max_position_embeddings,
285 |             quant_config=quant_config,
286 |             bias=attention_bias,
287 |             cache_config=cache_config,
288 |             prefix=f"{prefix}.self_attn",
289 |         )
290 |         self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
291 |         self.post_attention_layernorm = RMSNorm(
292 |             config.hidden_size, eps=config.rms_norm_eps
293 |         )
294 | 
295 |         if config.moe_decompose == "vertical":
296 |             self.moe = MonetMoVDE(config)
297 |         elif config.moe_decompose == "horizontal":
298 |             self.moe = MonetMoHDE(config)
299 |         if layer_idx % config.moe_groups == 0:
300 |             self.router = MonetRouter(config).requires_grad_(False)
301 | 
302 |     def forward(
303 |         self,
304 |         positions: torch.Tensor,
305 |         hidden_states: torch.Tensor,
306 |         kv_cache: torch.Tensor,
307 |         attn_metadata: AttentionMetadata,
308 |         residual: Optional[torch.Tensor],
309 |         previous_router_probs: tuple[torch.Tensor, torch.Tensor] | None = None,
310 |     ) -> Tuple[torch.Tensor, torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
311 |         # Self Attention
312 |         if residual is None:
313 |             residual = hidden_states
314 |             hidden_states = self.input_layernorm(hidden_states)
315 |         else:
316 |             hidden_states, residual = self.input_layernorm(hidden_states, residual)
317 |         hidden_states = self.self_attn(
318 |             positions=positions,
319 |             hidden_states=hidden_states,
320 |             kv_cache=kv_cache,
321 |             attn_metadata=attn_metadata,
322 |         )
323 | 
324 |         # Fully Connected
325 |         hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
326 |         g1, g2 = (
327 |             self.router(hidden_states)
328 |             if hasattr(self, "router")
329 |             else previous_router_probs
330 |         )
331 |         hidden_states = self.moe(hidden_states, g1, g2)
332 |         return hidden_states, residual, (g1, g2)
333 | 
334 | 
335 | class MonetModel(nn.Module):
336 |     def __init__(
337 |         self,
338 |         config: PretrainedConfig,
339 |         cache_config: Optional[CacheConfig] = None,
340 |         quant_config: Optional[QuantizationConfig] = None,
341 |         lora_config: Optional[LoRAConfig] = None,
342 |         prefix: str = "",
343 |     ):
344 |         super().__init__()
345 |         self.config = config
346 |         self.padding_idx = config.pad_token_id
347 |         lora_vocab = (
348 |             (lora_config.lora_extra_vocab_size * (lora_config.max_loras or 1))
349 |             if lora_config
350 |             else 0
351 |         )
352 |         self.vocab_size = config.vocab_size + lora_vocab
353 |         self.org_vocab_size = config.vocab_size
354 |         if get_pp_group().is_first_rank or (
355 |             config.tie_word_embeddings and get_pp_group().is_last_rank
356 |         ):
357 |             self.embed_tokens = VocabParallelEmbedding(
358 |                 self.vocab_size,
359 |                 config.hidden_size,
360 |                 org_num_embeddings=config.vocab_size,
361 |                 quant_config=quant_config,
362 |             )
363 |         else:
364 |             self.embed_tokens = PPMissingLayer()
365 | 
366 |         layer_idx = 0
367 | 
368 |         def layer_fn(prefix: str) -> MonetDecoderLayer:
369 |             nonlocal layer_idx
370 |             layer_idx += 1
371 |             return MonetDecoderLayer(
372 |                 config=config,
373 |                 layer_idx=layer_idx - 1,
374 |                 cache_config=cache_config,
375 |                 quant_config=quant_config,
376 |                 prefix=prefix,
377 |             )
378 | 
379 |         self.start_layer, self.end_layer, self.layers = make_layers(
380 |             config.num_hidden_layers, layer_fn, prefix=f"{prefix}.layers"
381 |         )
382 |         if get_pp_group().is_last_rank:
383 |             self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
384 |         else:
385 |             self.norm = PPMissingLayer()
386 | 
387 |     def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
388 |         return self.embed_tokens(input_ids)
389 | 
390 |     def forward(
391 |         self,
392 |         input_ids: Optional[torch.Tensor],
393 |         positions: torch.Tensor,
394 |         kv_caches: List[torch.Tensor],
395 |         attn_metadata: AttentionMetadata,
396 |         intermediate_tensors: Optional[IntermediateTensors],
397 |         inputs_embeds: Optional[torch.Tensor] = None,
398 |     ) -> Union[torch.Tensor, IntermediateTensors]:
399 |         if get_pp_group().is_first_rank:
400 |             if inputs_embeds is not None:
401 |                 hidden_states = inputs_embeds
402 |             else:
403 |                 hidden_states = self.get_input_embeddings(input_ids)
404 |             residual = None
405 |         else:
406 |             assert intermediate_tensors is not None
407 |             hidden_states = intermediate_tensors["hidden_states"]
408 |             residual = intermediate_tensors["residual"]
409 | 
410 |         previous_router_probs = None
411 |         for i in range(self.start_layer, self.end_layer):
412 |             layer = self.layers[i]
413 |             hidden_states, residual, previous_router_probs = layer(
414 |                 positions,
415 |                 hidden_states,
416 |                 kv_caches[i - self.start_layer],
417 |                 attn_metadata,
418 |                 residual,
419 |                 previous_router_probs,
420 |             )
421 | 
422 |         if not get_pp_group().is_last_rank:
423 |             return IntermediateTensors(
424 |                 {"hidden_states": hidden_states, "residual": residual}
425 |             )
426 | 
427 |         hidden_states, _ = self.norm(hidden_states, residual)
428 |         return hidden_states
429 | 
430 | 
431 | class MonetForCausalLM(nn.Module, SupportsLoRA):
432 |     packed_modules_mapping = {
433 |         "qkv_proj": [
434 |             "q_proj",
435 |             "k_proj",
436 |             "v_proj",
437 |         ],
438 |         # "gate_up_proj": [
439 |         #     "gate_proj",
440 |         #     "up_proj",
441 |         # ],
442 |     }
443 | 
444 |     # LoRA specific attributes
445 |     supported_lora_modules = [
446 |         "qkv_proj",
447 |         "o_proj",
448 |         # "gate_up_proj",
449 |         # "down_proj",
450 |         "embed_tokens",
451 |         "lm_head",
452 |     ]
453 |     embedding_modules = {
454 |         "embed_tokens": "input_embeddings",
455 |         "lm_head": "output_embeddings",
456 |     }
457 |     embedding_padding_modules = ["lm_head"]
458 |     bitsandbytes_stacked_params_mapping = {
459 |         # shard_name, weight_name, index
460 |         "q_proj": ("qkv_proj", 0),
461 |         "k_proj": ("qkv_proj", 1),
462 |         "v_proj": ("qkv_proj", 2),
463 |         # "gate_proj": ("gate_up_proj", 0),
464 |         # "up_proj": ("gate_up_proj", 1),
465 |     }
466 |     # Mistral/Llama models can also be loaded with --load-format mistral
467 |     # from consolidated.safetensors checkpoints
468 |     mistral_mapping = {
469 |         "layers": "model.layers",
470 |         "attention": "self_attn",
471 |         "wq": "q_proj",
472 |         "wk": "k_proj",
473 |         "wv": "v_proj",
474 |         "wo": "o_proj",
475 |         "attention_norm": "input_layernorm",
476 |         "feed_forward": "mlp",
477 |         # "w1": "gate_proj",
478 |         # "w2": "down_proj",
479 |         # "w3": "up_proj",
480 |         "ffn_norm": "post_attention_layernorm",
481 |         "tok_embeddings": "model.embed_tokens",
482 |         "output": "lm_head",
483 |         "norm": "model.norm",
484 |     }
485 | 
486 |     def __init__(
487 |         self,
488 |         config: PretrainedConfig,
489 |         cache_config: Optional[CacheConfig] = None,
490 |         quant_config: Optional[QuantizationConfig] = None,
491 |         lora_config: Optional[LoRAConfig] = None,
492 |     ) -> None:
493 |         super().__init__()
494 | 
495 |         self.config = config
496 |         self.lora_config = lora_config
497 | 
498 |         self.model = MonetModel(
499 |             config, cache_config, quant_config, lora_config=lora_config, prefix="model"
500 |         )
501 |         if get_pp_group().is_last_rank:
502 |             self.unpadded_vocab_size = config.vocab_size
503 |             if lora_config:
504 |                 self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
505 |             self.lm_head = ParallelLMHead(
506 |                 self.unpadded_vocab_size,
507 |                 config.hidden_size,
508 |                 org_num_embeddings=config.vocab_size,
509 |                 padding_size=(
510 |                     DEFAULT_VOCAB_PADDING_SIZE
511 |                     # We need bigger padding if using lora for kernel
512 |                     # compatibility
513 |                     if not lora_config
514 |                     else lora_config.lora_vocab_padding_size
515 |                 ),
516 |                 quant_config=quant_config,
517 |             )
518 |             if config.tie_word_embeddings:
519 |                 self.lm_head.weight = self.model.embed_tokens.weight
520 | 
521 |             logit_scale = getattr(config, "logit_scale", 1.0)
522 |             self.logits_processor = LogitsProcessor(
523 |                 self.unpadded_vocab_size, config.vocab_size, logit_scale
524 |             )
525 |             self.sampler = Sampler()
526 |         else:
527 |             self.lm_head = PPMissingLayer()
528 | 
529 |     def forward(
530 |         self,
531 |         input_ids: torch.Tensor,
532 |         positions: torch.Tensor,
533 |         kv_caches: List[torch.Tensor],
534 |         attn_metadata: AttentionMetadata,
535 |         intermediate_tensors: Optional[IntermediateTensors] = None,
536 |     ) -> Union[torch.Tensor, IntermediateTensors]:
537 |         model_output = self.model(
538 |             input_ids, positions, kv_caches, attn_metadata, intermediate_tensors
539 |         )
540 |         return model_output
541 | 
542 |     def compute_logits(
543 |         self,
544 |         hidden_states: torch.Tensor,
545 |         sampling_metadata: SamplingMetadata,
546 |     ) -> Optional[torch.Tensor]:
547 |         logits = self.logits_processor(self.lm_head, hidden_states, sampling_metadata)
548 |         return logits
549 | 
550 |     def sample(
551 |         self,
552 |         logits: torch.Tensor,
553 |         sampling_metadata: SamplingMetadata,
554 |     ) -> Optional[SamplerOutput]:
555 |         next_tokens = self.sampler(logits, sampling_metadata)
556 |         return next_tokens
557 | 
558 |     def make_empty_intermediate_tensors(
559 |         self, batch_size: int, dtype: torch.dtype, device: torch.device
560 |     ) -> IntermediateTensors:
561 |         return IntermediateTensors(
562 |             {
563 |                 "hidden_states": torch.zeros(
564 |                     (batch_size, self.config.hidden_size), dtype=dtype, device=device
565 |                 ),
566 |                 "residual": torch.zeros(
567 |                     (batch_size, self.config.hidden_size), dtype=dtype, device=device
568 |                 ),
569 |             }
570 |         )
571 | 
572 |     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
573 |         stacked_params_mapping = [
574 |             # (param_name, shard_name, shard_id)
575 |             (".qkv_proj", ".q_proj", "q"),
576 |             (".qkv_proj", ".k_proj", "k"),
577 |             (".qkv_proj", ".v_proj", "v"),
578 |             # (".gate_up_proj", ".gate_proj", 0),
579 |             # (".gate_up_proj", ".up_proj", 1),
580 |         ]
581 |         params_dict = dict(self.named_parameters()) | dict(self.named_buffers())
582 |         for name, loaded_weight in weights:
583 |             name, loaded_weight = self.maybe_remap_mistral(name, loaded_weight)
584 | 
585 |             if "rotary_emb.inv_freq" in name:
586 |                 continue
587 |             if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
588 |                 # Models trained using ColossalAI may include these tensors in
589 |                 # the checkpoint. Skip them.
590 |                 continue
591 |             # With tie_word_embeddings, we can skip lm_head.weight
592 |             # The weight might appear unnecessarily in the files if the model is
593 |             # processed with quantization, LoRA, fine-tuning, etc.
594 |             if self.config.tie_word_embeddings and "lm_head.weight" in name:
595 |                 continue
596 |             if scale_name := get_compressed_tensors_cache_scale(name):
597 |                 # Loading kv cache scales for compressed-tensors quantization
598 |                 param = params_dict[scale_name]
599 |                 weight_loader = getattr(param, "weight_loader", default_weight_loader)
600 |                 loaded_weight = loaded_weight[0]
601 |                 weight_loader(param, loaded_weight)
602 |                 continue
603 |             for param_name, weight_name, shard_id in stacked_params_mapping:
604 |                 if weight_name not in name:
605 |                     continue
606 |                 name = name.replace(weight_name, param_name)
607 |                 # Skip loading extra bias for GPTQ models.
608 |                 if name.endswith(".bias") and name not in params_dict:
609 |                     continue
610 | 
611 |                 if is_pp_missing_parameter(name, self):
612 |                     continue
613 | 
614 |                 param = params_dict[name]
615 |                 weight_loader = param.weight_loader
616 |                 weight_loader(param, loaded_weight, shard_id)
617 | 
618 |                 break
619 |             else:
620 |                 # Skip loading extra bias for GPTQ models.
621 |                 if name.endswith(".bias") and name not in params_dict:
622 |                     continue
623 |                 # Remapping the name of FP8 kv-scale.
624 |                 name = maybe_remap_kv_scale_name(name, params_dict)
625 |                 if name is None:
626 |                     continue
627 | 
628 |                 if is_pp_missing_parameter(name, self):
629 |                     continue
630 | 
631 |                 param = params_dict[name]
632 |                 weight_loader = getattr(param, "weight_loader", default_weight_loader)
633 |                 weight_loader(param, loaded_weight)
634 | 
635 |     # If this function is called, it should always initialize KV cache scale
636 |     # factors (or else raise an exception). Thus, handled exceptions should
637 |     # make sure to leave KV cache scale factors in a known good (dummy) state
638 |     def load_kv_cache_scales(self, quantization_param_path: str) -> None:
639 |         tp_size = get_tensor_model_parallel_world_size()
640 |         tp_rank = get_tensor_model_parallel_rank()
641 |         for layer_idx, scaling_factor in kv_cache_scales_loader(
642 |             quantization_param_path,
643 |             tp_rank,
644 |             tp_size,
645 |             self.config.num_hidden_layers,
646 |             self.config.__class__.model_type,
647 |         ):
648 |             if not isinstance(self.model.layers[layer_idx], nn.Identity):
649 |                 layer_self_attn = self.model.layers[layer_idx].self_attn
650 | 
651 |             if is_hip():
652 |                 # The scaling factor convention we are assuming is
653 |                 # quantized_value * scaling_factor ~= true_value
654 |                 # which is consistent with the practice of setting
655 |                 # scaling_factor = tensor_amax / FPtype_max
656 |                 scaling_factor *= 2
657 |             if hasattr(layer_self_attn, "kv_scale"):
658 |                 layer_self_attn.attn._kv_scale = scaling_factor
659 |             else:
660 |                 raise RuntimeError(
661 |                     "Self attention has no KV cache scaling " "factor attribute!"
662 |                 )
663 | 
664 |     # This function is used to remap the mistral format as
665 |     # used by Mistral and Llama <=2
666 |     def maybe_remap_mistral(
667 |         self, name: str, loaded_weight: torch.Tensor
668 |     ) -> Tuple[str, torch.Tensor]:
669 | 
670 |         def permute(w, n_heads):
671 |             attn_in = self.config.head_dim * n_heads
672 |             attn_out = self.config.hidden_size
673 | 
674 |             return (
675 |                 w.view(n_heads, attn_in // n_heads // 2, 2, attn_out)
676 |                 .transpose(1, 2)
677 |                 .reshape(attn_in, attn_out)
678 |             )
679 | 
680 |         mapping = self.mistral_mapping
681 |         modules = name.split(".")
682 | 
683 |         # rotary embeds should be sliced
684 |         if "wk" in modules:
685 |             loaded_weight = permute(loaded_weight, self.config.num_key_value_heads)
686 |         elif "wq" in modules:
687 |             loaded_weight = permute(loaded_weight, self.config.num_attention_heads)
688 | 
689 |         for item in modules:
690 |             if item in mapping and mapping[item] not in name:
691 |                 name = name.replace(item, mapping[item])
692 | 
693 |         return name, loaded_weight
694 | 
--------------------------------------------------------------------------------
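Because vLLM does not ship a built-in `MonetForCausalLM` architecture, modeling_monet_vllm.py is intended to be registered as an out-of-tree model before the engine is built. A minimal sketch under that assumption; the checkpoint id is a placeholder and the `ModelRegistry` import path reflects the vLLM generation this file targets (the one exposing `AttentionMetadata`, `SamplerOutput`, and `kv_caches` as used above), so it may differ in other versions:

```python
# Minimal sketch: register the Monet architecture with vLLM, then decode.
from vllm import LLM, ModelRegistry, SamplingParams

from modeling_monet_vllm import MonetForCausalLM

# Map the "MonetForCausalLM" architecture string in the checkpoint's
# config.json to the class defined in this file.
ModelRegistry.register_model("MonetForCausalLM", MonetForCausalLM)

llm = LLM(
    model="MonetLLM/monet-vd-1.4B-100BT-hf",  # assumed checkpoint id
    trust_remote_code=True,
    dtype="bfloat16",
)
params = SamplingParams(max_tokens=32, temperature=0.8)
outputs = llm.generate(["The capital of France is"], params)
print(outputs[0].outputs[0].text)
```

Note that the router relies on `scipy.stats.norm` to turn `moe_topk` into an expert-score threshold, so `scipy` must be installed alongside vLLM for this module to import.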