├── .gitignore
├── LICENSE
├── README.md
├── assets
│   └── figure1.png
├── modeling_monet.py
└── modeling_monet_vllm.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # File created using '.gitignore Generator' for Visual Studio Code: https://bit.ly/vscode-gig
2 | # Created by https://www.toptal.com/developers/gitignore/api/visualstudiocode,linux,python
3 | # Edit at https://www.toptal.com/developers/gitignore?templates=visualstudiocode,linux,python
4 |
5 | ### Linux ###
6 | *~
7 |
8 | # temporary files which can be created if a process still has a handle open of a deleted file
9 | .fuse_hidden*
10 |
11 | # KDE directory preferences
12 | .directory
13 |
14 | # Linux trash folder which might appear on any partition or disk
15 | .Trash-*
16 |
17 | # .nfs files are created when an open file is removed but is still being accessed
18 | .nfs*
19 |
20 | ### Python ###
21 | # Byte-compiled / optimized / DLL files
22 | __pycache__/
23 | *.py[cod]
24 | *$py.class
25 |
26 | # C extensions
27 | *.so
28 |
29 | # Distribution / packaging
30 | .Python
31 | build/
32 | develop-eggs/
33 | dist/
34 | downloads/
35 | eggs/
36 | .eggs/
37 | lib/
38 | lib64/
39 | parts/
40 | sdist/
41 | var/
42 | wheels/
43 | share/python-wheels/
44 | *.egg-info/
45 | .installed.cfg
46 | *.egg
47 | MANIFEST
48 |
49 | # PyInstaller
50 | # Usually these files are written by a python script from a template
51 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
52 | *.manifest
53 | *.spec
54 |
55 | # Installer logs
56 | pip-log.txt
57 | pip-delete-this-directory.txt
58 |
59 | # Unit test / coverage reports
60 | htmlcov/
61 | .tox/
62 | .nox/
63 | .coverage
64 | .coverage.*
65 | .cache
66 | nosetests.xml
67 | coverage.xml
68 | *.cover
69 | *.py,cover
70 | .hypothesis/
71 | .pytest_cache/
72 | cover/
73 |
74 | # Translations
75 | *.mo
76 | *.pot
77 |
78 | # Django stuff:
79 | *.log
80 | local_settings.py
81 | db.sqlite3
82 | db.sqlite3-journal
83 |
84 | # Flask stuff:
85 | instance/
86 | .webassets-cache
87 |
88 | # Scrapy stuff:
89 | .scrapy
90 |
91 | # Sphinx documentation
92 | docs/_build/
93 |
94 | # PyBuilder
95 | .pybuilder/
96 | target/
97 |
98 | # Jupyter Notebook
99 | .ipynb_checkpoints
100 |
101 | # IPython
102 | profile_default/
103 | ipython_config.py
104 |
105 | # pyenv
106 | # For a library or package, you might want to ignore these files since the code is
107 | # intended to run in multiple environments; otherwise, check them in:
108 | # .python-version
109 |
110 | # pipenv
111 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
112 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
113 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
114 | # install all needed dependencies.
115 | #Pipfile.lock
116 |
117 | # poetry
118 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
119 | # This is especially recommended for binary packages to ensure reproducibility, and is more
120 | # commonly ignored for libraries.
121 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
122 | #poetry.lock
123 |
124 | # pdm
125 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
126 | #pdm.lock
127 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
128 | # in version control.
129 | # https://pdm.fming.dev/#use-with-ide
130 | .pdm.toml
131 |
132 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
133 | __pypackages__/
134 |
135 | # Celery stuff
136 | celerybeat-schedule
137 | celerybeat.pid
138 |
139 | # SageMath parsed files
140 | *.sage.py
141 |
142 | # Environments
143 | .env
144 | .venv
145 | env/
146 | venv/
147 | ENV/
148 | env.bak/
149 | venv.bak/
150 |
151 | # Spyder project settings
152 | .spyderproject
153 | .spyproject
154 |
155 | # Rope project settings
156 | .ropeproject
157 |
158 | # mkdocs documentation
159 | /site
160 |
161 | # mypy
162 | .mypy_cache/
163 | .dmypy.json
164 | dmypy.json
165 |
166 | # Pyre type checker
167 | .pyre/
168 |
169 | # pytype static type analyzer
170 | .pytype/
171 |
172 | # Cython debug symbols
173 | cython_debug/
174 |
175 | # PyCharm
176 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
177 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
178 | # and can be added to the global gitignore or merged into this file. For a more nuclear
179 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
180 | #.idea/
181 |
182 | ### Python Patch ###
183 | # Poetry local configuration file - https://python-poetry.org/docs/configuration/#local-configuration
184 | poetry.toml
185 |
186 | # ruff
187 | .ruff_cache/
188 |
189 | # LSP config files
190 | pyrightconfig.json
191 |
192 | ### VisualStudioCode ###
193 | .vscode/*
194 |
195 | # Local History for Visual Studio Code
196 | .history/
197 |
198 | # Built Visual Studio Code Extensions
199 | *.vsix
200 |
201 | ### VisualStudioCode Patch ###
202 | # Ignore all local history of files
203 | .history
204 | .ionide
205 |
206 | # End of https://www.toptal.com/developers/gitignore/api/visualstudiocode,linux,python
207 |
208 | # Custom rules (everything added below won't be overridden by 'Generate .gitignore File' if you use 'Update' option)
209 |
210 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright 2024 DMIS Laboratory, Korea University
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Monet: Mixture of Monosemantic Experts for Transformers
2 |
3 | [](https://arxiv.org/abs/2412.04139)
4 | [](https://huggingface.co/MonetLLM)
5 | [](https://huggingface.co/spaces/MonetLLM/monet-vd-1.4B-100BT-hf-viewer)
6 | [](https://github.com/dmis-lab/Monet)
7 | [](./LICENSE)
8 |
9 | 
10 |
11 | ## Introduction
12 |
13 | **Monet** presents a novel approach to enhancing mechanistic interpretability in large language models (LLMs) through an innovative Sparse Mixture-of-Experts (SMoE) architecture. By directly incorporating sparse dictionary learning into end-to-end pretraining, **Monet** addresses the fundamental challenge of polysemanticity, where individual neurons respond to multiple unrelated concepts, while maintaining model performance.
14 |
15 | #### ✨Key Highlights
16 |
17 | - 📈 **Scalable Expert Architecture**: **Monet** introduces parameter-efficient expert decomposition methods that enable scaling to 262,144 experts per layer while ensuring total parameters scale proportionally to the square root of expert count (a rough parameter-count sketch follows this list).
18 | - 📊 **Monosemantic Experts**: Through fine-grained expert specialization, **Monet** achieves monosemantic experts that demonstrate mutual exclusivity of knowledge, allowing transparent observation of model behavior and parametric knowledge.
19 | - 🛠️ **Robust Knowledge Control**: The architecture enables precise manipulation of domain-specific knowledge, language capabilities, and toxicity mitigation without compromising general performance.
20 |
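To make the square-root scaling concrete, here is a rough back-of-the-envelope sketch (ours, not from the paper) that counts the weights of one expert layer plus its router using the default values of `MonetConfig` in [modeling_monet.py](./modeling_monet.py); released checkpoints may use different sizes, and biases and attention weights are omitted.

```python
# Approximate parameter count of one MonetMoVDE expert layer and its MonetRouter,
# using MonetConfig defaults (hidden_size=4096, moe_dim=8, moe_heads=8, moe_experts=512).
# Every term grows linearly in moe_experts, while the number of combined experts per
# layer is moe_experts**2 = 262,144, which is where the square-root scaling comes from.
def approx_layer_params(hidden_size=4096, moe_dim=8, moe_heads=8, moe_experts=512):
    router = 2 * hidden_size * moe_heads * moe_experts           # w1, w2
    u = 2 * hidden_size * (moe_experts * moe_dim // 2)           # u1, u2
    v = 4 * (moe_experts * moe_dim // 2) * (hidden_size // 2)    # v11, v12, v21, v22
    b = 2 * moe_experts * (hidden_size // 2)                     # b1, b2
    return router + u + v + b

print(512 ** 2)               # 262144 combined experts per layer
print(approx_layer_params())  # about 69M weights, growing with 512 rather than 262144
```
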
21 | ### Why Monet?
22 |
23 | Unlike traditional approaches using post-hoc reconstruction (like Sparse Autoencoders), **Monet** integrates interpretability directly into its architecture. This enables both transparent understanding of model internals and fundamental behavior control. By scaling monosemantic experts, Monet paves the way for more transparent and controllable language models.
24 |
25 | ## News
26 |
27 | - **2025-01-23**: Our paper has been accepted to **ICLR 2025**! 🎉
28 | - **2024-12-06**: Released **Monet: Mixture of Monosemantic Experts for Transformers** on [arXiv](https://arxiv.org/abs/2412.04139), with [GitHub](https://github.com/dmis-lab/Monet), [models](https://huggingface.co/MonetLLM), and [demo](https://huggingface.co/spaces/MonetLLM/monet-vd-1.4B-100BT-hf-viewer).
29 |
30 | ## Model Checkpoints
31 |
32 | #### Base Models
33 |
34 |
91 |
92 | #### Instruction-Tuned Models
93 |
94 |
116 |
117 | ## Quickstart
118 |
119 | You can explore the core implementation of **Monet** in [modeling_monet.py](./modeling_monet.py). We've made it easy to use Monet by including our custom code in the 🤗[Hugging Face model zoo](https://huggingface.co/MonetLLM). Simply set `trust_remote_code=True` when loading the models through the Transformers library.
120 |
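For example, a minimal loading sketch (using the `MonetLLM/monet-vd-1.4B-100BT-hf` checkpoint shown throughout this README):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# trust_remote_code=True lets Transformers import the Monet classes bundled with the checkpoint.
model = AutoModelForCausalLM.from_pretrained(
    "MonetLLM/monet-vd-1.4B-100BT-hf",
    torch_dtype=torch.bfloat16,
    device_map="auto",
    trust_remote_code=True,
)
tokenizer = AutoTokenizer.from_pretrained("MonetLLM/monet-vd-1.4B-100BT-hf")
```
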
121 | ### Text Generation
122 |
123 | ```python
124 | import torch
125 | from transformers import AutoTokenizer, pipeline
126 | model_name = "MonetLLM/monet-vd-1.4B-100BT-hf"
127 | pipe = pipeline(
128 | "text-generation",
129 | model_name,
130 | tokenizer=AutoTokenizer.from_pretrained(model_name),
131 | torch_dtype=torch.bfloat16,
132 | device_map="auto",
133 | trust_remote_code=True,
134 | )
135 | print(pipe("The key to life is", max_new_tokens=20, do_sample=True)[0]["generated_text"])
136 | ```
137 |
138 | Output:
139 |
140 | ```
141 | The key to life is learning how to live creatively. The question is: how do we do that, and what will
142 | ```
143 |
144 | ### Code Generation
145 |
146 | ```python
147 | import torch
148 | from transformers import AutoTokenizer, pipeline
149 | model_name = "MonetLLM/codemonet-vd-1.4B-100BT-hf"
150 | pipe = pipeline(
151 | "text-generation",
152 | model_name,
153 | tokenizer=AutoTokenizer.from_pretrained(model_name),
154 | torch_dtype=torch.bfloat16,
155 | device_map="auto",
156 | trust_remote_code=True,
157 | )
158 |
159 | text = '''
160 | def print_len(x: str):
161 | """For a given string x, print the length of x."""
162 | '''
163 | print(pipe(text, max_new_tokens=10)[0]["generated_text"].split("\n\n")[0])
164 | ```
165 |
166 | Output:
167 |
168 | ```
169 |
170 | def print_len(x: str):
171 | """For a given string x, print the length of x."""
172 | print(len(x))
173 | ```
174 |
175 | ### Chat Completion
176 |
177 | ```python
178 | import torch
179 | from transformers import AutoTokenizer, pipeline
180 | model_name = "MonetLLM/codemonet-vd-1.4B-100BT-chat-hf"
181 | pipe = pipeline(
182 | "text-generation",
183 | model_name,
184 | tokenizer=AutoTokenizer.from_pretrained(model_name),
185 | torch_dtype=torch.bfloat16,
186 | device_map="auto",
187 | trust_remote_code=True,
188 | )
189 |
190 | text = pipe.tokenizer.apply_chat_template(
191 | [{"role": "user", "content": "Hi! How are you?"}],
192 | add_generation_prompt=True,
193 | tokenize=False,
194 | )
195 | print(pipe(text, max_new_tokens=30, do_sample=True)[0]["generated_text"])
196 | ```
197 |
198 | Output:
199 |
200 | ```
201 | [INST] Hi! How are you? [/INST] I'm good, thanks! How can I help you today?
202 | ```
203 |
204 | ### Using vLLM
205 |
206 | For enhanced inference performance, **Monet** can be integrated with the vLLM engine. Note that **Monet** requires manual registration with vLLM's `ModelRegistry` before initialization. The custom implementation is provided in [modeling_monet_vllm.py](./modeling_monet_vllm.py).
207 |
208 | ```python
209 | from vllm import LLM, ModelRegistry, SamplingParams
210 | from modeling_monet_vllm import MonetForCausalLM
211 |
212 | # Register Monet architecture with vLLM
213 | ModelRegistry.register_model("MonetForCausalLM", MonetForCausalLM)
214 |
215 | model = LLM(
216 | "MonetLLM/monet-vd-1.4B-100BT-hf",
217 | trust_remote_code=True,
218 | dtype="bfloat16",
219 | gpu_memory_utilization=0.8
220 | )
221 | sampling_params = SamplingParams(max_tokens=20, temperature=1.0)
222 | print(model.generate("The key to life is", sampling_params)[0].outputs[0].text)
223 | ```
224 | Output:
225 | ```
226 | what you’re born with. If you think that you don’t have the same control and
227 | ```
228 |
229 | ### Get Expert Routing Probabilities
230 |
231 | Based on expert routing probabilities, **Monet** enables mechanistic interpretability by revealing which sparse features are activated for each token. Following the standard MoE approach, you can obtain expert routing probabilities for all layers by setting `output_router_probs=True`. The example below demonstrates how to compute and analyze the expert activation patterns:
232 |
233 | ```python
234 | import torch
235 | from transformers import AutoModelForCausalLM, AutoTokenizer
236 |
237 | model = AutoModelForCausalLM.from_pretrained(
238 | "MonetLLM/monet-vd-1.4B-100BT-hf",
239 | torch_dtype=torch.bfloat16,
240 | device_map="auto",
241 | trust_remote_code=True,
242 | )
243 | tokenizer = AutoTokenizer.from_pretrained("MonetLLM/monet-vd-1.4B-100BT-hf")
244 |
245 | inputs = tokenizer("City and County of San Francisco", return_tensors="pt")
246 | outputs = model(**inputs.to(model.device), output_router_probs=True)
247 |
248 | # Get full expert routing probabilities: [batch_size, seq_len, moe_heads, moe_experts**2]
249 | g1, g2 = outputs.router_probs[0][0], outputs.router_probs[0][1]
250 | g = torch.einsum("bthi,bthj->bthij", g1, g2).flatten(-2)
251 | print(g.shape)
252 |
253 | # Print number of activated experts per token.
254 | for token, routing in zip(inputs.input_ids.squeeze(0), g.squeeze(0)):
255 | token = tokenizer.decode(token).ljust(16, " ")
256 | expert_indices = (routing.sum(0) > 1e-2).argwhere().squeeze(-1)
257 | print(f"Token: {token} Activated Experts: {len(expert_indices)}")
258 | ```
259 |
260 | Output:
261 |
262 | ```
263 | torch.Size([1, 7, 8, 262144])
264 | Token: Activated Experts: 62
265 | Token: City Activated Experts: 60
266 | Token: and Activated Experts: 16
267 | Token: County Activated Experts: 102
268 | Token: of Activated Experts: 11
269 | Token: San Activated Experts: 70
270 | Token: Francisco Activated Experts: 67
271 | ```
272 |
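Continuing the example above, each flattened expert index encodes a pair of routing indices, because `g` is built with `torch.einsum("bthi,bthj->bthij", g1, g2)` and then flattened over the last two dimensions. The sketch below (ours, not part of the repository) recovers that pair for the strongest experts of the final token:

```python
# Recover (i, j) routing pairs from the flattened expert index k = i * moe_experts + j.
moe_experts = model.config.moe_experts     # 512 for this checkpoint
routing = g.squeeze(0)[-1].sum(0)          # aggregate heads for the last token ("Francisco")
top_values, top_indices = routing.topk(5)
for value, k in zip(top_values.tolist(), top_indices.tolist()):
    i, j = divmod(k, moe_experts)
    print(f"expert pair ({i}, {j}) with routing weight {value:.4f}")
```
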
273 | ## Citation
274 | If you find this work useful for your research or applications, please cite our paper using the BibTeX entry below.
275 | ```bibtex
276 | @article{park2024monet,
277 | title={{Monet: Mixture of Monosemantic Experts for Transformers}},
278 | author={Jungwoo Park and Young Jin Ahn and Kee-Eung Kim and Jaewoo Kang},
279 |   journal={arXiv preprint arXiv:2412.04139},
280 | year={2024}
281 | }
282 | ```
283 |
--------------------------------------------------------------------------------
/assets/figure1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dmis-lab/Monet/b5cb786e938c4f84583eaca4e211e693e6c9d3cd/assets/figure1.png
--------------------------------------------------------------------------------
/modeling_monet.py:
--------------------------------------------------------------------------------
1 | # fmt: off
2 | from __future__ import annotations
3 |
4 | from dataclasses import dataclass
5 |
6 | import torch
7 | import torch.utils.checkpoint
8 | from scipy.stats import norm
9 | from torch import nn
10 | from torch.nn import CrossEntropyLoss
11 | from transformers.activations import ACT2FN
12 | from transformers.cache_utils import Cache, DynamicCache, StaticCache
13 | from transformers.modeling_attn_mask_utils import AttentionMaskConverter
14 | from transformers.modeling_utils import PreTrainedModel
15 | from transformers.models.llama.configuration_llama import LlamaConfig
16 | from transformers.models.llama.modeling_llama import (
17 | LLAMA_ATTENTION_CLASSES,
18 | LlamaRMSNorm,
19 | )
20 | from transformers.utils import ModelOutput, logging
21 |
22 | logger = logging.get_logger(__name__)
23 |
24 |
25 | @dataclass
26 | class MonetModelOutputWithPast(ModelOutput):
27 | last_hidden_state: torch.FloatTensor = None
28 | past_key_values: tuple[tuple[torch.FloatTensor]] | None = None
29 | hidden_states: tuple[torch.FloatTensor, ...] | None = None
30 | attentions: tuple[torch.FloatTensor, ...] | None = None
31 | router_probs: tuple[tuple[torch.FloatTensor, ...], ...] | None = None
32 |
33 |
34 | @dataclass
35 | class MonetCausalLMOutputWithPast(ModelOutput):
36 | loss: torch.FloatTensor | None = None
37 | aux_loss: torch.FloatTensor | None = None
38 | logits: torch.FloatTensor = None
39 | past_key_values: tuple[tuple[torch.FloatTensor]] | None = None
40 | hidden_states: tuple[torch.FloatTensor, ...] | None = None
41 | attentions: tuple[torch.FloatTensor, ...] | None = None
42 | router_probs: tuple[tuple[torch.FloatTensor, ...], ...] | None = None
43 |
44 |
45 | class MonetConfig(LlamaConfig):
46 | model_type = "monet"
47 | keys_to_ignore_at_inference = ["past_key_values"]
48 |
49 | def __init__(
50 | self,
51 | vocab_size=32000,
52 | hidden_size=4096,
53 | intermediate_size=None,
54 | num_hidden_layers=32,
55 | num_attention_heads=32,
56 | num_key_value_heads=None,
57 | hidden_act="relu2",
58 | max_position_embeddings=2048,
59 | initializer_range=0.02,
60 | rms_norm_eps=1e-6,
61 | use_cache=True,
62 | pad_token_id=None,
63 | bos_token_id=1,
64 | eos_token_id=2,
65 | pretraining_tp=1,
66 | tie_word_embeddings=False,
67 | rope_theta=10000.0,
68 | rope_scaling=None,
69 | attention_bias=False,
70 | attention_dropout=0.0,
71 | mlp_bias=None,
72 | moe_dim=8,
73 | moe_heads=8,
74 | moe_experts=512,
75 | moe_topk=32,
76 | moe_groups=4,
77 | moe_decompose="vertical",
78 | output_router_probs=False,
79 | **kwargs,
80 | ):
81 | self.moe_dim = moe_dim
82 | self.moe_heads = moe_heads
83 | self.moe_experts = moe_experts
84 | self.moe_topk = moe_topk
85 | self.moe_groups = moe_groups
86 | self.moe_decompose = moe_decompose
87 | self.output_router_probs = output_router_probs
88 |
89 | super().__init__(
90 | vocab_size=vocab_size,
91 | hidden_size=hidden_size,
92 | intermediate_size=intermediate_size,
93 | num_hidden_layers=num_hidden_layers,
94 | num_attention_heads=num_attention_heads,
95 | num_key_value_heads=num_key_value_heads,
96 | hidden_act=hidden_act,
97 | max_position_embeddings=max_position_embeddings,
98 | initializer_range=initializer_range,
99 | rms_norm_eps=rms_norm_eps,
100 | use_cache=use_cache,
101 | pad_token_id=pad_token_id,
102 | bos_token_id=bos_token_id,
103 | eos_token_id=eos_token_id,
104 | pretraining_tp=pretraining_tp,
105 | tie_word_embeddings=tie_word_embeddings,
106 | rope_theta=rope_theta,
107 | rope_scaling=rope_scaling,
108 | attention_bias=attention_bias,
109 | attention_dropout=attention_dropout,
110 | mlp_bias=mlp_bias,
111 | **kwargs,
112 | )
113 |
114 |
115 | class MonetRouter(nn.Module):
116 | def __init__(self, config: MonetConfig):
117 | super().__init__()
118 | self.config = config
119 | flatten_shape = config.moe_heads * config.moe_experts
120 |
121 | self.w1 = nn.Linear(config.hidden_size, flatten_shape, bias=False)
122 | self.w2 = nn.Linear(config.hidden_size, flatten_shape, bias=False)
123 | self.norm1 = nn.BatchNorm1d(config.moe_heads, affine=False)
124 | self.norm2 = nn.BatchNorm1d(config.moe_heads, affine=False)
125 |
126 | def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
127 | g1z = self.w1(x).unflatten(-1, (self.config.moe_heads, -1)).float()
128 | g2z = self.w2(x).unflatten(-1, (self.config.moe_heads, -1)).float()
129 |
130 | g1n = self.norm1(g1z.transpose(2, 3).flatten(0, -2))
131 | g2n = self.norm2(g2z.transpose(2, 3).flatten(0, -2))
132 | g1n = g1n.view(g1z.size(0), g1z.size(1), g1z.size(3), -1).transpose(2, 3)
133 | g2n = g2n.view(g2z.size(0), g2z.size(1), g2z.size(3), -1).transpose(2, 3)
134 |
135 | sigma = float(norm.ppf(1 - self.config.moe_topk / self.config.moe_experts))
136 | g1s = g1n.amax(-1, keepdim=True).clamp_max_(sigma)
137 | g2s = g2n.amax(-1, keepdim=True).clamp_max_(sigma)
138 |
139 | g1 = nn.functional.softmax(torch.where(g1n >= g1s, g1z, -1e10), dim=-1)
140 | g2 = nn.functional.softmax(torch.where(g2n >= g2s, g2z, -1e10), dim=-1)
141 | return g1, g2
142 |
143 |
144 | class MonetMoVDE(nn.Module):
145 | def __init__(self, config: MonetConfig):
146 | super().__init__()
147 | self.config = config
148 | self.act_fn = ACT2FN[config.hidden_act]
149 | flatten_shape = config.moe_experts * config.moe_dim // 2
150 |
151 | self.u1 = nn.Linear(config.hidden_size, flatten_shape)
152 | self.u2 = nn.Linear(config.hidden_size, flatten_shape)
153 |
154 | self.v11 = nn.Linear(flatten_shape, config.hidden_size // 2, bias=False)
155 | self.v12 = nn.Linear(flatten_shape, config.hidden_size // 2, bias=False)
156 | self.v21 = nn.Linear(flatten_shape, config.hidden_size // 2, bias=False)
157 | self.v22 = nn.Linear(flatten_shape, config.hidden_size // 2, bias=False)
158 |
159 | self.b1 = nn.Parameter(torch.zeros(config.moe_experts, config.hidden_size // 2))
160 | self.b2 = nn.Parameter(torch.zeros(config.moe_experts, config.hidden_size // 2))
161 |
162 | def forward(
163 | self, x: torch.Tensor, g1: torch.Tensor, g2: torch.Tensor
164 | ) -> torch.Tensor:
165 | g1, g2 = g1.type_as(x), g2.type_as(x)
166 | x1 = self.act_fn(self.u1(x).unflatten(-1, (self.config.moe_experts, -1)))
167 | x2 = self.act_fn(self.u2(x).unflatten(-1, (self.config.moe_experts, -1)))
168 |
169 | x11 = self.v11(torch.einsum("btim,bthi->btim", x1, g1).flatten(-2))
170 | x12 = self.v12(torch.einsum("btjm,bthj,bthi->btim", x2, g2, g1).flatten(-2))
171 | x13 = torch.einsum("bthi,id->btd", g1, self.b1.type_as(x))
172 |
173 | x21 = self.v21(torch.einsum("btim,bthi,bthj->btjm", x1, g1, g2).flatten(-2))
174 | x22 = self.v22(torch.einsum("btjm,bthj->btjm", x2, g2).flatten(-2))
175 | x23 = torch.einsum("bthj,jd->btd", g2, self.b2.type_as(x))
176 |
177 | return torch.cat((x11 + x12 + x13, x21 + x22 + x23), dim=-1)
178 |
179 |
180 | class MonetMoHDE(nn.Module):
181 | def __init__(self, config: MonetConfig):
182 | super().__init__()
183 | self.config = config
184 | self.act_fn = ACT2FN[config.hidden_act]
185 | flatten_shape = config.moe_experts * config.moe_dim
186 |
187 | self.u = nn.Linear(config.hidden_size, flatten_shape)
188 | self.v = nn.Linear(flatten_shape, config.hidden_size, bias=False)
189 | self.b = nn.Parameter(torch.zeros(config.moe_experts, config.hidden_size))
190 |
191 | def forward(
192 | self, x: torch.Tensor, g1: torch.Tensor, g2: torch.Tensor
193 | ) -> torch.Tensor:
194 | g1, g2 = g1.type_as(x), g2.type_as(x)
195 | x = self.act_fn(self.u(x).unflatten(-1, (self.config.moe_experts, -1)))
196 | x = self.v(torch.einsum("btim,bthi,bthj->btjm", x, g1, g2).flatten(-2))
197 | return x + torch.einsum("bthj,jd->btd", g2, self.b)
198 |
199 |
200 | class MonetDecoderLayer(nn.Module):
201 | def __init__(self, config: MonetConfig, layer_idx: int):
202 | super().__init__()
203 | self.hidden_size = config.hidden_size
204 | self.self_attn = LLAMA_ATTENTION_CLASSES[config._attn_implementation](
205 | config=config, layer_idx=layer_idx
206 | )
207 | self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
208 | self.post_attention_layernorm = LlamaRMSNorm(
209 | config.hidden_size, eps=config.rms_norm_eps
210 | )
211 |
212 | if config.moe_decompose == "vertical":
213 | self.moe = MonetMoVDE(config)
214 | elif config.moe_decompose == "horizontal":
215 | self.moe = MonetMoHDE(config)
216 | if layer_idx % config.moe_groups == 0:
217 | self.router = MonetRouter(config).requires_grad_(False)
218 |
219 | def forward(
220 | self,
221 | hidden_states: torch.Tensor,
222 | attention_mask: torch.Tensor | None = None,
223 | position_ids: torch.LongTensor | None = None,
224 | past_key_value: Cache | None = None,
225 | previous_router_probs: tuple[torch.Tensor, torch.Tensor] | None = None,
226 | output_attentions: bool | None = False,
227 | use_cache: bool | None = False,
228 | cache_position: torch.LongTensor | None = None,
229 | **kwargs,
230 | ) -> tuple[torch.FloatTensor, ...]:
231 | residual = hidden_states
232 |
233 | hidden_states = self.input_layernorm(hidden_states)
234 |
235 | # Self Attention
236 | hidden_states, self_attn_weights, present_key_value = self.self_attn(
237 | hidden_states=hidden_states,
238 | attention_mask=attention_mask,
239 | position_ids=position_ids,
240 | past_key_value=past_key_value,
241 | output_attentions=output_attentions,
242 | use_cache=use_cache,
243 | cache_position=cache_position,
244 | )
245 | hidden_states = residual + hidden_states
246 |
247 | # Fully Connected
248 | residual = hidden_states
249 | hidden_states = self.post_attention_layernorm(hidden_states)
250 | g1, g2 = (
251 | self.router(hidden_states)
252 | if hasattr(self, "router")
253 | else previous_router_probs
254 | )
255 | hidden_states = self.moe(hidden_states, g1, g2)
256 | hidden_states = residual + hidden_states
257 |
258 | outputs = (hidden_states,)
259 |
260 | if output_attentions:
261 | outputs += (self_attn_weights,)
262 |
263 | if use_cache:
264 | outputs += (present_key_value,)
265 |
266 | return outputs + ((g1, g2) if hasattr(self, "router") else None,)
267 |
268 |
269 | class MonetPreTrainedModel(PreTrainedModel):
270 | config_class = MonetConfig
271 | base_model_prefix = "model"
272 | supports_gradient_checkpointing = True
273 | _no_split_modules = ["MonetDecoderLayer"]
274 | _skip_keys_device_placement = ["past_key_values"]
275 | _supports_flash_attn_2 = True
276 | _supports_sdpa = True
277 | _supports_cache_class = True
278 | _supports_quantized_cache = True
279 | _supports_static_cache = True
280 |
281 | def _init_weights(self, module):
282 | std = self.config.initializer_range
283 | if isinstance(module, nn.Linear):
284 | module.weight.data.normal_(mean=0.0, std=std)
285 | if module.bias is not None:
286 | module.bias.data.zero_()
287 | elif isinstance(module, nn.Embedding):
288 | module.weight.data.normal_(mean=0.0, std=std)
289 | if module.padding_idx is not None:
290 | module.weight.data[module.padding_idx].zero_()
291 |
292 |
293 | class MonetModel(MonetPreTrainedModel):
294 | def __init__(self, config: MonetConfig):
295 | super().__init__(config)
296 | self.padding_idx = config.pad_token_id
297 | self.vocab_size = config.vocab_size
298 |
299 | self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) # noqa
300 | self.layers = nn.ModuleList([MonetDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]) # noqa
301 | self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
302 | self.gradient_checkpointing = False
303 |
304 | # Initialize weights and apply final processing
305 | self.post_init()
306 |
307 | def get_input_embeddings(self):
308 | return self.embed_tokens
309 |
310 | def set_input_embeddings(self, value):
311 | self.embed_tokens = value
312 |
313 | def forward(
314 | self,
315 | input_ids: torch.LongTensor = None,
316 | attention_mask: torch.Tensor | None = None,
317 | position_ids: torch.LongTensor | None = None,
318 | past_key_values: Cache | list[torch.FloatTensor] | None = None,
319 | inputs_embeds: torch.FloatTensor | None = None,
320 | use_cache: bool | None = None,
321 | output_attentions: bool | None = None,
322 | output_hidden_states: bool | None = None,
323 | output_router_probs: bool | None = None,
324 | return_dict: bool | None = None,
325 | cache_position: torch.LongTensor | None = None,
326 | ) -> tuple[torch.Tensor, ...] | MonetModelOutputWithPast:
327 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions # noqa
328 | output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states # noqa
329 | output_router_probs = output_router_probs if output_router_probs is not None else self.config.output_router_probs # noqa
330 | use_cache = use_cache if use_cache is not None else self.config.use_cache
331 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict # noqa
332 |
333 | if (input_ids is None) ^ (inputs_embeds is not None):
334 | raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one") # noqa
335 |
336 | if self.gradient_checkpointing and self.training and use_cache:
337 | logger.warning_once("`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`.") # noqa
338 | use_cache = False
339 |
340 | if inputs_embeds is None:
341 | inputs_embeds = self.embed_tokens(input_ids)
342 |
343 | return_legacy_cache = False
344 | if use_cache and not isinstance(past_key_values, Cache): # kept for BC (non `Cache` `past_key_values` inputs) # noqa
345 | return_legacy_cache = True
346 | past_key_values = DynamicCache.from_legacy_cache(past_key_values)
347 | logger.warning_once(
348 | "We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.43. " # noqa
349 | "Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/v4.41.3/en/internal/generation_utils#transformers.Cache)" # noqa
350 | )
351 |
352 | if cache_position is None:
353 | past_seen_tokens = (
354 | past_key_values.get_seq_length() if past_key_values is not None else 0
355 | )
356 | cache_position = torch.arange(past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device) # noqa
357 | if position_ids is None:
358 | position_ids = cache_position.unsqueeze(0)
359 | causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions) # noqa
360 |
361 | # embed positions
362 | hidden_states = inputs_embeds
363 |
364 | # decoder layers
365 | all_hidden_states = () if output_hidden_states else None
366 | all_self_attns = () if output_attentions else None
367 | all_router_probs = () if output_router_probs else None
368 | previous_router_probs, next_decoder_cache = None, None
369 |
370 | for decoder_layer in self.layers:
371 | if output_hidden_states:
372 | all_hidden_states += (hidden_states,)
373 |
374 | if self.gradient_checkpointing and self.training:
375 | layer_outputs = self._gradient_checkpointing_func(
376 | decoder_layer.__call__,
377 | hidden_states,
378 | causal_mask,
379 | position_ids,
380 | past_key_values,
381 | previous_router_probs,
382 | output_attentions,
383 | use_cache,
384 | cache_position,
385 | )
386 | else:
387 | layer_outputs = decoder_layer(
388 | hidden_states,
389 | attention_mask=causal_mask,
390 | position_ids=position_ids,
391 | past_key_value=past_key_values,
392 | previous_router_probs=previous_router_probs,
393 | output_attentions=output_attentions,
394 | use_cache=use_cache,
395 | cache_position=cache_position,
396 | )
397 |
398 | hidden_states = layer_outputs[0]
399 | if use_cache:
400 | next_decoder_cache = layer_outputs[2 if output_attentions else 1]
401 | if output_attentions:
402 | all_self_attns += (layer_outputs[1],)
403 | if output_router_probs:
404 | all_router_probs += (layer_outputs[-1],)
405 | previous_router_probs = (
406 | layer_outputs[-1]
407 | if layer_outputs[-1] is not None
408 | else previous_router_probs
409 | )
410 |
411 | hidden_states = self.norm(hidden_states)
412 |
413 | # add hidden states from the last decoder layer
414 | if output_hidden_states:
415 | all_hidden_states += (hidden_states,)
416 |
417 | next_cache = next_decoder_cache if use_cache else None
418 | if return_legacy_cache:
419 | next_cache = next_cache.to_legacy_cache()
420 |
421 | if not return_dict:
422 | return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_router_probs] if v is not None) # noqa
423 | return MonetModelOutputWithPast(
424 | last_hidden_state=hidden_states,
425 | past_key_values=next_cache,
426 | hidden_states=all_hidden_states,
427 | attentions=all_self_attns,
428 | router_probs=all_router_probs,
429 | )
430 |
431 | def _update_causal_mask(
432 | self,
433 | attention_mask: torch.Tensor,
434 | input_tensor: torch.Tensor,
435 | cache_position: torch.Tensor,
436 | past_key_values: Cache,
437 | output_attentions: bool,
438 | ):
439 | if self.config._attn_implementation == "flash_attention_2":
440 | if attention_mask is not None and 0.0 in attention_mask:
441 | return attention_mask
442 | return None
443 |
444 | past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 # noqa
445 | using_static_cache = isinstance(past_key_values, StaticCache)
446 |
447 | if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions: # noqa
448 | if AttentionMaskConverter._ignore_causal_mask_sdpa(
449 | attention_mask,
450 | inputs_embeds=input_tensor,
451 | past_key_values_length=past_seen_tokens,
452 | is_training=self.training,
453 | ):
454 | return None
455 |
456 | dtype, device = input_tensor.dtype, input_tensor.device
457 | min_dtype = torch.finfo(dtype).min
458 | sequence_length = input_tensor.shape[1]
459 | if using_static_cache:
460 | target_length = past_key_values.get_max_length()
461 | else:
462 | target_length = (
463 | attention_mask.shape[-1]
464 | if isinstance(attention_mask, torch.Tensor)
465 | else past_seen_tokens + sequence_length + 1
466 | )
467 |
468 | if attention_mask is not None and attention_mask.dim() == 4:
469 | if attention_mask.max() != 0:
470 | raise ValueError("Custom 4D attention mask should be passed in inverted form with max==0`") # noqa
471 | causal_mask = attention_mask
472 | else:
473 | causal_mask = torch.full(
474 | (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device # noqa
475 | )
476 | if sequence_length != 1:
477 | causal_mask = torch.triu(causal_mask, diagonal=1)
478 | causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) # noqa
479 | causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1) # noqa
480 | if attention_mask is not None:
481 | causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit # noqa
482 | mask_length = attention_mask.shape[-1]
483 | padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] # noqa
484 | padding_mask = padding_mask == 0
485 | causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(padding_mask, min_dtype) # noqa
486 | if (
487 | self.config._attn_implementation == "sdpa"
488 | and attention_mask is not None
489 | and attention_mask.device.type == "cuda"
490 | and not output_attentions
491 | ):
492 | causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) # noqa
493 |
494 | return causal_mask
495 |
496 |
497 | class MonetForCausalLM(MonetPreTrainedModel):
498 | _tied_weights_keys = ["lm_head.weight"]
499 |
500 | def __init__(self, config):
501 | super().__init__(config)
502 | self.model = MonetModel(config)
503 | self.vocab_size = config.vocab_size
504 | self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
505 |
506 | # Initialize weights and apply final processing
507 | self.post_init()
508 |
509 | def get_input_embeddings(self):
510 | return self.model.embed_tokens
511 |
512 | def set_input_embeddings(self, value):
513 | self.model.embed_tokens = value
514 |
515 | def get_output_embeddings(self):
516 | return self.lm_head
517 |
518 | def set_output_embeddings(self, new_embeddings):
519 | self.lm_head = new_embeddings
520 |
521 | def set_decoder(self, decoder):
522 | self.model = decoder
523 |
524 | def get_decoder(self):
525 | return self.model
526 |
527 | def forward(
528 | self,
529 | input_ids: torch.LongTensor = None,
530 | attention_mask: torch.Tensor | None = None,
531 | position_ids: torch.LongTensor | None = None,
532 | past_key_values: Cache | list[torch.FloatTensor] | None = None,
533 | inputs_embeds: torch.FloatTensor | None = None,
534 | labels: torch.LongTensor | None = None,
535 | use_cache: bool | None = None,
536 | output_attentions: bool | None = None,
537 | output_hidden_states: bool | None = None,
538 | output_router_probs: bool | None = None,
539 | return_dict: bool | None = None,
540 | cache_position: torch.LongTensor | None = None,
541 | ) -> tuple[torch.Tensor, ...] | MonetCausalLMOutputWithPast:
542 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions # noqa
543 | output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states # noqa
544 | output_router_probs = output_router_probs if output_router_probs is not None else self.config.output_router_probs # noqa
545 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict # noqa
546 |
547 | # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
548 | outputs = self.model(
549 | input_ids=input_ids,
550 | attention_mask=attention_mask,
551 | position_ids=position_ids,
552 | past_key_values=past_key_values,
553 | inputs_embeds=inputs_embeds,
554 | use_cache=use_cache,
555 | output_attentions=output_attentions,
556 | output_hidden_states=output_hidden_states,
557 | output_router_probs=output_router_probs,
558 | return_dict=return_dict,
559 | cache_position=cache_position,
560 | )
561 |
562 | hidden_states = outputs[0]
563 | logits = self.lm_head(hidden_states)
564 | logits = logits.float()
565 |
566 | loss = None
567 | if labels is not None:
568 | # Shift so that tokens < n predict n
569 | shift_logits = logits[..., :-1, :].contiguous()
570 | shift_labels = labels[..., 1:].contiguous()
571 | # Flatten the tokens
572 | loss_fct = CrossEntropyLoss()
573 | shift_logits = shift_logits.view(-1, self.config.vocab_size)
574 | shift_labels = shift_labels.view(-1)
575 | # Enable model parallelism
576 | shift_labels = shift_labels.to(shift_logits.device)
577 | loss = loss_fct(shift_logits, shift_labels)
578 |
579 | if not return_dict:
580 | output = (logits,) + outputs[1:]
581 | return (loss,) + output if loss is not None else output
582 |
583 | return MonetCausalLMOutputWithPast(
584 | loss=loss,
585 | logits=logits,
586 | past_key_values=outputs.past_key_values,
587 | hidden_states=outputs.hidden_states,
588 | attentions=outputs.attentions,
589 | router_probs=outputs.router_probs,
590 | )
591 |
592 | def prepare_inputs_for_generation(
593 | self,
594 | input_ids,
595 | past_key_values=None,
596 | attention_mask=None,
597 | inputs_embeds=None,
598 | cache_position=None,
599 | use_cache=True,
600 | **kwargs,
601 | ):
602 | past_length = 0
603 | if past_key_values is not None:
604 | past_length = cache_position[0] if cache_position is not None else past_key_values.get_seq_length() # noqa
605 | max_cache_length = (
606 | torch.tensor(past_key_values.get_max_length(), device=input_ids.device)
607 | if past_key_values.get_max_length() is not None
608 | else None
609 | )
610 | cache_length = past_length if max_cache_length is None else torch.min(max_cache_length, past_length) # noqa
611 |
612 | # Keep only the unprocessed tokens:
613 | if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]: # noqa
614 | input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
615 | # input_ids based on the past_length.
616 | elif past_length < input_ids.shape[1]:
617 | input_ids = input_ids[:, past_length:]
618 |
619 | if (
620 | max_cache_length is not None
621 | and attention_mask is not None
622 | and cache_length + input_ids.shape[1] > max_cache_length
623 | ):
624 | attention_mask = attention_mask[:, -max_cache_length:]
625 |
626 | position_ids = kwargs.get("position_ids", None)
627 | if attention_mask is not None and position_ids is None:
628 | # create position_ids on the fly for batch generation
629 | position_ids = attention_mask.long().cumsum(-1) - 1
630 | position_ids.masked_fill_(attention_mask == 0, 1)
631 | if past_key_values:
632 | position_ids = position_ids[:, -input_ids.shape[1] :]
633 |
634 | if inputs_embeds is not None and past_length == 0:
635 | model_inputs = {"inputs_embeds": inputs_embeds}
636 | else:
637 | model_inputs = {"input_ids": input_ids.contiguous()}
638 |
639 | input_length = position_ids.shape[-1] if position_ids is not None else input_ids.shape[-1] # noqa
640 | if cache_position is None:
641 | cache_position = torch.arange(past_length, past_length + input_length, device=input_ids.device) # noqa
642 | elif use_cache:
643 | cache_position = cache_position[-input_length:]
644 |
645 | model_inputs.update(
646 | {
647 | "position_ids": position_ids,
648 | "cache_position": cache_position,
649 | "past_key_values": past_key_values,
650 | "use_cache": use_cache,
651 | "attention_mask": attention_mask,
652 | }
653 | )
654 | return model_inputs
655 |
656 | @staticmethod
657 | def _reorder_cache(past_key_values, beam_idx):
658 | reordered_past = ()
659 | for layer_past in past_key_values:
660 | reordered_past += (
661 | tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past), # noqa
662 | )
663 | return reordered_past
664 |
--------------------------------------------------------------------------------
/modeling_monet_vllm.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Adapted from
3 | # https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
4 | # Copyright 2023 The vLLM team.
5 | # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
6 | #
7 | # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
8 | # and OPT implementations in this library. It has been modified from its
9 | # original forms to accommodate minor architectural differences compared
10 | # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
11 | #
12 | # Licensed under the Apache License, Version 2.0 (the "License");
13 | # you may not use this file except in compliance with the License.
14 | # You may obtain a copy of the License at
15 | #
16 | # http://www.apache.org/licenses/LICENSE-2.0
17 | #
18 | # Unless required by applicable law or agreed to in writing, software
19 | # distributed under the License is distributed on an "AS IS" BASIS,
20 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
21 | # See the License for the specific language governing permissions and
22 | # limitations under the License.
23 | """Inference-only Monet model compatible with HuggingFace weights."""
24 | from __future__ import annotations
25 |
26 | from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
27 |
28 | import torch
29 | from scipy.stats import norm
30 | from torch import nn
31 | from transformers import PretrainedConfig
32 | from transformers.activations import ACT2FN
33 | from vllm.attention import Attention, AttentionMetadata
34 | from vllm.config import CacheConfig, LoRAConfig
35 | from vllm.distributed import (
36 | get_pp_group,
37 | get_tensor_model_parallel_rank,
38 | get_tensor_model_parallel_world_size,
39 | )
40 | from vllm.model_executor.layers.layernorm import RMSNorm
41 | from vllm.model_executor.layers.linear import QKVParallelLinear, RowParallelLinear
42 | from vllm.model_executor.layers.logits_processor import LogitsProcessor
43 | from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
44 | from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
45 | get_compressed_tensors_cache_scale,
46 | )
47 | from vllm.model_executor.layers.rotary_embedding import get_rope
48 | from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
49 | from vllm.model_executor.layers.vocab_parallel_embedding import (
50 | DEFAULT_VOCAB_PADDING_SIZE,
51 | ParallelLMHead,
52 | VocabParallelEmbedding,
53 | )
54 | from vllm.model_executor.model_loader.weight_utils import (
55 | default_weight_loader,
56 | kv_cache_scales_loader,
57 | maybe_remap_kv_scale_name,
58 | )
59 | from vllm.model_executor.models.interfaces import SupportsLoRA
60 | from vllm.model_executor.models.utils import (
61 | PPMissingLayer,
62 | is_pp_missing_parameter,
63 | make_layers,
64 | )
65 | from vllm.model_executor.sampling_metadata import SamplingMetadata
66 | from vllm.sequence import IntermediateTensors
67 | from vllm.utils import is_hip
68 |
69 |
70 | class MonetRouter(nn.Module):
71 | def __init__(self, config: PretrainedConfig):
72 | super().__init__()
73 | self.config = config
74 | flatten_shape = config.moe_heads * config.moe_experts
75 |
76 | self.w1 = nn.Linear(config.hidden_size, flatten_shape, bias=False)
77 | self.w2 = nn.Linear(config.hidden_size, flatten_shape, bias=False)
78 | self.norm1 = nn.BatchNorm1d(config.moe_heads, affine=False)
79 | self.norm2 = nn.BatchNorm1d(config.moe_heads, affine=False)
80 |
81 | def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
82 | g1z = self.w1(x).unflatten(-1, (self.config.moe_heads, -1)).float()
83 | g2z = self.w2(x).unflatten(-1, (self.config.moe_heads, -1)).float()
84 |
85 | g1n = self.norm1(g1z.transpose(-2, -1).flatten(0, -2))
86 | g2n = self.norm2(g2z.transpose(-2, -1).flatten(0, -2))
87 | g1n = g1n.view(*g1z.shape[:-2], g1z.size(-1), -1).transpose(-2, -1)
88 | g2n = g2n.view(*g2z.shape[:-2], g2z.size(-1), -1).transpose(-2, -1)
89 |
90 | sigma = float(norm.ppf(1 - self.config.moe_topk / self.config.moe_experts))
91 | g1s = g1n.amax(-1, keepdim=True).clamp_max_(sigma)
92 | g2s = g2n.amax(-1, keepdim=True).clamp_max_(sigma)
93 |
94 | g1 = nn.functional.softmax(torch.where(g1n >= g1s, g1z, -1e10), dim=-1)
95 | g2 = nn.functional.softmax(torch.where(g2n >= g2s, g2z, -1e10), dim=-1)
96 | return g1, g2
97 |
98 |
99 | class MonetMoVDE(nn.Module):
100 | def __init__(self, config: PretrainedConfig):
101 | super().__init__()
102 | self.config = config
103 | self.act_fn = ACT2FN[config.hidden_act]
104 | flatten_shape = config.moe_experts * config.moe_dim // 2
105 |
106 | self.u1 = nn.Linear(config.hidden_size, flatten_shape)
107 | self.u2 = nn.Linear(config.hidden_size, flatten_shape)
108 |
109 | self.v11 = nn.Linear(flatten_shape, config.hidden_size // 2, bias=False)
110 | self.v12 = nn.Linear(flatten_shape, config.hidden_size // 2, bias=False)
111 | self.v21 = nn.Linear(flatten_shape, config.hidden_size // 2, bias=False)
112 | self.v22 = nn.Linear(flatten_shape, config.hidden_size // 2, bias=False)
113 |
114 | self.b1 = nn.Parameter(torch.zeros(config.moe_experts, config.hidden_size // 2))
115 | self.b2 = nn.Parameter(torch.zeros(config.moe_experts, config.hidden_size // 2))
116 |
117 | def forward(
118 | self, x: torch.Tensor, g1: torch.Tensor, g2: torch.Tensor
119 | ) -> torch.Tensor:
120 | g1, g2 = g1.type_as(x), g2.type_as(x)
121 | x1 = self.act_fn(self.u1(x).unflatten(-1, (self.config.moe_experts, -1)))
122 | x2 = self.act_fn(self.u2(x).unflatten(-1, (self.config.moe_experts, -1)))
123 |
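        # First output half: x11 is the directly routed group-1 expert output,
        # x12 mixes group-2 activations back through both gates, and x13 is a
        # g1-routed per-expert bias. The second half (x21/x22/x23) mirrors this
        # with the roles of the two expert groups swapped.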
124 | x11 = self.v11(torch.einsum("bim,bhi->bim", x1, g1).flatten(-2))
125 | x12 = self.v12(torch.einsum("bjm,bhj,bhi->bim", x2, g2, g1).flatten(-2))
126 | x13 = torch.einsum("bhi,id->bd", g1, self.b1.type_as(x))
127 |
128 | x21 = self.v21(torch.einsum("bim,bhi,bhj->bjm", x1, g1, g2).flatten(-2))
129 | x22 = self.v22(torch.einsum("bjm,bhj->bjm", x2, g2).flatten(-2))
130 | x23 = torch.einsum("bhj,jd->bd", g2, self.b2.type_as(x))
131 |
132 | return torch.cat((x11 + x12 + x13, x21 + x22 + x23), dim=-1)
133 |
134 |
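# Horizontal expert decomposition (MoHDE): a single expert layer that appears to
# route its inputs with g1 and redistribute its outputs over expert slots with
# g2, plus a g2-routed per-expert bias.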
135 | class MonetMoHDE(nn.Module):
136 | def __init__(self, config: PretrainedConfig):
137 | super().__init__()
138 | self.config = config
139 | self.act_fn = ACT2FN[config.hidden_act]
140 | flatten_shape = config.moe_experts * config.moe_dim
141 |
142 | self.u = nn.Linear(config.hidden_size, flatten_shape)
143 | self.v = nn.Linear(flatten_shape, config.hidden_size, bias=False)
144 | self.b = nn.Parameter(torch.zeros(config.moe_experts, config.hidden_size))
145 |
146 | def forward(
147 | self, x: torch.Tensor, g1: torch.Tensor, g2: torch.Tensor
148 | ) -> torch.Tensor:
149 | g1, g2 = g1.type_as(x), g2.type_as(x)
150 | x = self.act_fn(self.u(x).unflatten(-1, (self.config.moe_experts, -1)))
151 | x = self.v(torch.einsum("bim,bhi,bhj->bjm", x, g1, g2).flatten(-2))
152 | return x + torch.einsum("bhj,jd->bd", g2, self.b)
153 |
154 |
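# Standard Llama-style attention (fused QKV projection, RoPE, paged attention);
# this block appears to be adapted from vLLM's LlamaAttention with no
# Monet-specific changes.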
155 | class MonetAttention(nn.Module):
156 | def __init__(
157 | self,
158 | config: PretrainedConfig,
159 | hidden_size: int,
160 | num_heads: int,
161 | num_kv_heads: int,
162 | rope_theta: float = 10000,
163 | rope_scaling: Optional[Dict[str, Any]] = None,
164 | max_position_embeddings: int = 8192,
165 | quant_config: Optional[QuantizationConfig] = None,
166 | bias: bool = False,
167 | cache_config: Optional[CacheConfig] = None,
168 | prefix: str = "",
169 | ):
170 | super().__init__()
171 | self.hidden_size = hidden_size
172 | tp_size = get_tensor_model_parallel_world_size()
173 | self.total_num_heads = num_heads
174 | assert self.total_num_heads % tp_size == 0
175 | self.num_heads = self.total_num_heads // tp_size
176 | self.total_num_kv_heads = num_kv_heads
177 | if self.total_num_kv_heads >= tp_size:
178 |             # Number of KV heads is at least the TP size, so we partition
179 |             # the KV heads across the tensor-parallel GPUs.
180 | assert self.total_num_kv_heads % tp_size == 0
181 | else:
182 | # Number of KV heads is less than TP size, so we replicate
183 | # the KV heads across multiple tensor parallel GPUs.
184 | assert tp_size % self.total_num_kv_heads == 0
185 | self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
186 | # MistralConfig has an optional head_dim introduced by Mistral-Nemo
187 | self.head_dim = getattr(
188 | config, "head_dim", self.hidden_size // self.total_num_heads
189 | )
190 | self.q_size = self.num_heads * self.head_dim
191 | self.kv_size = self.num_kv_heads * self.head_dim
192 | self.scaling = self.head_dim**-0.5
193 | self.rope_theta = rope_theta
194 | self.max_position_embeddings = max_position_embeddings
195 |
196 | self.qkv_proj = QKVParallelLinear(
197 | hidden_size=hidden_size,
198 | head_size=self.head_dim,
199 | total_num_heads=self.total_num_heads,
200 | total_num_kv_heads=self.total_num_kv_heads,
201 | bias=bias,
202 | quant_config=quant_config,
203 | prefix=f"{prefix}.qkv_proj",
204 | )
205 |
206 | self.o_proj = RowParallelLinear(
207 | input_size=self.total_num_heads * self.head_dim,
208 | output_size=hidden_size,
209 | bias=bias,
210 | quant_config=quant_config,
211 | prefix=f"{prefix}.o_proj",
212 | )
213 |
214 | is_neox_style = True
215 | if quant_config is not None and quant_config.get_name() == "gguf":
216 | is_neox_style = False
217 |
218 | self.rotary_emb = get_rope(
219 | self.head_dim,
220 | rotary_dim=self.head_dim,
221 | max_position=max_position_embeddings,
222 | base=rope_theta,
223 | rope_scaling=rope_scaling,
224 | is_neox_style=is_neox_style,
225 | )
226 | self.attn = Attention(
227 | self.num_heads,
228 | self.head_dim,
229 | self.scaling,
230 | num_kv_heads=self.num_kv_heads,
231 | cache_config=cache_config,
232 | quant_config=quant_config,
233 | )
234 |
235 | def forward(
236 | self,
237 | positions: torch.Tensor,
238 | hidden_states: torch.Tensor,
239 | kv_cache: torch.Tensor,
240 | attn_metadata: AttentionMetadata,
241 | ) -> torch.Tensor:
242 | qkv, _ = self.qkv_proj(hidden_states)
243 | q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
244 | q, k = self.rotary_emb(positions, q, k)
245 | attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
246 | output, _ = self.o_proj(attn_output)
247 | return output
248 |
249 |
250 | class MonetDecoderLayer(nn.Module):
251 | def __init__(
252 | self,
253 | config: PretrainedConfig,
254 | layer_idx: int,
255 | cache_config: Optional[CacheConfig] = None,
256 | quant_config: Optional[QuantizationConfig] = None,
257 | prefix: str = "",
258 | ):
259 | super().__init__()
260 | self.hidden_size = config.hidden_size
261 | rope_theta = getattr(config, "rope_theta", 10000)
262 | rope_scaling = getattr(config, "rope_scaling", None)
263 | if rope_scaling is not None and getattr(
264 | config, "original_max_position_embeddings", None
265 | ):
266 | rope_scaling["original_max_position_embeddings"] = (
267 | config.original_max_position_embeddings
268 | )
269 | max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
270 | # Support abacusai/Smaug-72B-v0.1 with attention_bias
271 | # Support internlm/internlm-7b with bias
272 | attention_bias = getattr(config, "attention_bias", False) or getattr(
273 | config, "bias", False
274 | )
275 | self.self_attn = MonetAttention(
276 | config=config,
277 | hidden_size=self.hidden_size,
278 | num_heads=config.num_attention_heads,
279 | num_kv_heads=getattr(
280 | config, "num_key_value_heads", config.num_attention_heads
281 | ),
282 | rope_theta=rope_theta,
283 | rope_scaling=rope_scaling,
284 | max_position_embeddings=max_position_embeddings,
285 | quant_config=quant_config,
286 | bias=attention_bias,
287 | cache_config=cache_config,
288 | prefix=f"{prefix}.self_attn",
289 | )
290 | self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
291 | self.post_attention_layernorm = RMSNorm(
292 | config.hidden_size, eps=config.rms_norm_eps
293 | )
294 |
295 | if config.moe_decompose == "vertical":
296 | self.moe = MonetMoVDE(config)
297 | elif config.moe_decompose == "horizontal":
298 | self.moe = MonetMoHDE(config)
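        # A router is only instantiated on every `moe_groups`-th layer; the
        # layers in between reuse the probabilities of their group leader via
        # `previous_router_probs` in forward(). requires_grad_(False) is a
        # no-op at inference time.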
299 | if layer_idx % config.moe_groups == 0:
300 | self.router = MonetRouter(config).requires_grad_(False)
301 |
302 | def forward(
303 | self,
304 | positions: torch.Tensor,
305 | hidden_states: torch.Tensor,
306 | kv_cache: torch.Tensor,
307 | attn_metadata: AttentionMetadata,
308 | residual: Optional[torch.Tensor],
309 | previous_router_probs: tuple[torch.Tensor, torch.Tensor] | None = None,
310 | ) -> Tuple[torch.Tensor, torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
311 | # Self Attention
312 | if residual is None:
313 | residual = hidden_states
314 | hidden_states = self.input_layernorm(hidden_states)
315 | else:
316 | hidden_states, residual = self.input_layernorm(hidden_states, residual)
317 | hidden_states = self.self_attn(
318 | positions=positions,
319 | hidden_states=hidden_states,
320 | kv_cache=kv_cache,
321 | attn_metadata=attn_metadata,
322 | )
323 |
324 |         # Mixture-of-experts feed-forward
325 | hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
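        # Group-leading layers compute fresh routing probabilities; all other
        # layers fall back to the probabilities passed down from the previous
        # layer.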
326 | g1, g2 = (
327 | self.router(hidden_states)
328 | if hasattr(self, "router")
329 | else previous_router_probs
330 | )
331 | hidden_states = self.moe(hidden_states, g1, g2)
332 | return hidden_states, residual, (g1, g2)
333 |
334 |
335 | class MonetModel(nn.Module):
336 | def __init__(
337 | self,
338 | config: PretrainedConfig,
339 | cache_config: Optional[CacheConfig] = None,
340 | quant_config: Optional[QuantizationConfig] = None,
341 | lora_config: Optional[LoRAConfig] = None,
342 | prefix: str = "",
343 | ):
344 | super().__init__()
345 | self.config = config
346 | self.padding_idx = config.pad_token_id
347 | lora_vocab = (
348 | (lora_config.lora_extra_vocab_size * (lora_config.max_loras or 1))
349 | if lora_config
350 | else 0
351 | )
352 | self.vocab_size = config.vocab_size + lora_vocab
353 | self.org_vocab_size = config.vocab_size
354 | if get_pp_group().is_first_rank or (
355 | config.tie_word_embeddings and get_pp_group().is_last_rank
356 | ):
357 | self.embed_tokens = VocabParallelEmbedding(
358 | self.vocab_size,
359 | config.hidden_size,
360 | org_num_embeddings=config.vocab_size,
361 | quant_config=quant_config,
362 | )
363 | else:
364 | self.embed_tokens = PPMissingLayer()
365 |
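        # make_layers() only hands the factory a prefix string, so a local
        # counter recovers the layer index that decides which layers own a
        # router. The counter only advances for layers built on this pipeline
        # rank.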
366 | layer_idx = 0
367 |
368 | def layer_fn(prefix: str) -> MonetDecoderLayer:
369 | nonlocal layer_idx
370 | layer_idx += 1
371 | return MonetDecoderLayer(
372 | config=config,
373 | layer_idx=layer_idx - 1,
374 | cache_config=cache_config,
375 | quant_config=quant_config,
376 | prefix=prefix,
377 | )
378 |
379 | self.start_layer, self.end_layer, self.layers = make_layers(
380 | config.num_hidden_layers, layer_fn, prefix=f"{prefix}.layers"
381 | )
382 | if get_pp_group().is_last_rank:
383 | self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
384 | else:
385 | self.norm = PPMissingLayer()
386 |
387 | def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
388 | return self.embed_tokens(input_ids)
389 |
390 | def forward(
391 | self,
392 | input_ids: Optional[torch.Tensor],
393 | positions: torch.Tensor,
394 | kv_caches: List[torch.Tensor],
395 | attn_metadata: AttentionMetadata,
396 | intermediate_tensors: Optional[IntermediateTensors],
397 | inputs_embeds: Optional[torch.Tensor] = None,
398 | ) -> Union[torch.Tensor, IntermediateTensors]:
399 | if get_pp_group().is_first_rank:
400 | if inputs_embeds is not None:
401 | hidden_states = inputs_embeds
402 | else:
403 | hidden_states = self.get_input_embeddings(input_ids)
404 | residual = None
405 | else:
406 | assert intermediate_tensors is not None
407 | hidden_states = intermediate_tensors["hidden_states"]
408 | residual = intermediate_tensors["residual"]
409 |
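        # Thread the routing probabilities through the layer loop so that
        # layers without their own router reuse their group leader's output.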
410 | previous_router_probs = None
411 | for i in range(self.start_layer, self.end_layer):
412 | layer = self.layers[i]
413 | hidden_states, residual, previous_router_probs = layer(
414 | positions,
415 | hidden_states,
416 | kv_caches[i - self.start_layer],
417 | attn_metadata,
418 | residual,
419 | previous_router_probs,
420 | )
421 |
422 | if not get_pp_group().is_last_rank:
423 | return IntermediateTensors(
424 | {"hidden_states": hidden_states, "residual": residual}
425 | )
426 |
427 | hidden_states, _ = self.norm(hidden_states, residual)
428 | return hidden_states
429 |
430 |
431 | class MonetForCausalLM(nn.Module, SupportsLoRA):
432 | packed_modules_mapping = {
433 | "qkv_proj": [
434 | "q_proj",
435 | "k_proj",
436 | "v_proj",
437 | ],
438 | # "gate_up_proj": [
439 | # "gate_proj",
440 | # "up_proj",
441 | # ],
442 | }
443 |
444 | # LoRA specific attributes
445 | supported_lora_modules = [
446 | "qkv_proj",
447 | "o_proj",
448 | # "gate_up_proj",
449 | # "down_proj",
450 | "embed_tokens",
451 | "lm_head",
452 | ]
453 | embedding_modules = {
454 | "embed_tokens": "input_embeddings",
455 | "lm_head": "output_embeddings",
456 | }
457 | embedding_padding_modules = ["lm_head"]
458 | bitsandbytes_stacked_params_mapping = {
459 | # shard_name, weight_name, index
460 | "q_proj": ("qkv_proj", 0),
461 | "k_proj": ("qkv_proj", 1),
462 | "v_proj": ("qkv_proj", 2),
463 | # "gate_proj": ("gate_up_proj", 0),
464 | # "up_proj": ("gate_up_proj", 1),
465 | }
466 | # Mistral/Llama models can also be loaded with --load-format mistral
467 | # from consolidated.safetensors checkpoints
468 | mistral_mapping = {
469 | "layers": "model.layers",
470 | "attention": "self_attn",
471 | "wq": "q_proj",
472 | "wk": "k_proj",
473 | "wv": "v_proj",
474 | "wo": "o_proj",
475 | "attention_norm": "input_layernorm",
476 | "feed_forward": "mlp",
477 | # "w1": "gate_proj",
478 | # "w2": "down_proj",
479 | # "w3": "up_proj",
480 | "ffn_norm": "post_attention_layernorm",
481 | "tok_embeddings": "model.embed_tokens",
482 | "output": "lm_head",
483 | "norm": "model.norm",
484 | }
485 |
486 | def __init__(
487 | self,
488 | config: PretrainedConfig,
489 | cache_config: Optional[CacheConfig] = None,
490 | quant_config: Optional[QuantizationConfig] = None,
491 | lora_config: Optional[LoRAConfig] = None,
492 | ) -> None:
493 | super().__init__()
494 |
495 | self.config = config
496 | self.lora_config = lora_config
497 |
498 | self.model = MonetModel(
499 | config, cache_config, quant_config, lora_config=lora_config, prefix="model"
500 | )
501 | if get_pp_group().is_last_rank:
502 | self.unpadded_vocab_size = config.vocab_size
503 | if lora_config:
504 | self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
505 | self.lm_head = ParallelLMHead(
506 | self.unpadded_vocab_size,
507 | config.hidden_size,
508 | org_num_embeddings=config.vocab_size,
509 | padding_size=(
510 | DEFAULT_VOCAB_PADDING_SIZE
511 | # We need bigger padding if using lora for kernel
512 | # compatibility
513 | if not lora_config
514 | else lora_config.lora_vocab_padding_size
515 | ),
516 | quant_config=quant_config,
517 | )
518 | if config.tie_word_embeddings:
519 | self.lm_head.weight = self.model.embed_tokens.weight
520 |
521 | logit_scale = getattr(config, "logit_scale", 1.0)
522 | self.logits_processor = LogitsProcessor(
523 | self.unpadded_vocab_size, config.vocab_size, logit_scale
524 | )
525 | self.sampler = Sampler()
526 | else:
527 | self.lm_head = PPMissingLayer()
528 |
529 | def forward(
530 | self,
531 | input_ids: torch.Tensor,
532 | positions: torch.Tensor,
533 | kv_caches: List[torch.Tensor],
534 | attn_metadata: AttentionMetadata,
535 | intermediate_tensors: Optional[IntermediateTensors] = None,
536 | ) -> Union[torch.Tensor, IntermediateTensors]:
537 | model_output = self.model(
538 | input_ids, positions, kv_caches, attn_metadata, intermediate_tensors
539 | )
540 | return model_output
541 |
542 | def compute_logits(
543 | self,
544 | hidden_states: torch.Tensor,
545 | sampling_metadata: SamplingMetadata,
546 | ) -> Optional[torch.Tensor]:
547 | logits = self.logits_processor(self.lm_head, hidden_states, sampling_metadata)
548 | return logits
549 |
550 | def sample(
551 | self,
552 | logits: torch.Tensor,
553 | sampling_metadata: SamplingMetadata,
554 | ) -> Optional[SamplerOutput]:
555 | next_tokens = self.sampler(logits, sampling_metadata)
556 | return next_tokens
557 |
558 | def make_empty_intermediate_tensors(
559 | self, batch_size: int, dtype: torch.dtype, device: torch.device
560 | ) -> IntermediateTensors:
561 | return IntermediateTensors(
562 | {
563 | "hidden_states": torch.zeros(
564 | (batch_size, self.config.hidden_size), dtype=dtype, device=device
565 | ),
566 | "residual": torch.zeros(
567 | (batch_size, self.config.hidden_size), dtype=dtype, device=device
568 | ),
569 | }
570 | )
571 |
572 | def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
573 | stacked_params_mapping = [
574 | # (param_name, shard_name, shard_id)
575 | (".qkv_proj", ".q_proj", "q"),
576 | (".qkv_proj", ".k_proj", "k"),
577 | (".qkv_proj", ".v_proj", "v"),
578 | # (".gate_up_proj", ".gate_proj", 0),
579 | # (".gate_up_proj", ".up_proj", 1),
580 | ]
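        # Buffers are included alongside parameters because the router's
        # BatchNorm layers store their running statistics as buffers, and those
        # are loaded from the checkpoint as well.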
581 | params_dict = dict(self.named_parameters()) | dict(self.named_buffers())
582 | for name, loaded_weight in weights:
583 | name, loaded_weight = self.maybe_remap_mistral(name, loaded_weight)
584 |
585 | if "rotary_emb.inv_freq" in name:
586 | continue
587 | if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
588 | # Models trained using ColossalAI may include these tensors in
589 | # the checkpoint. Skip them.
590 | continue
591 | # With tie_word_embeddings, we can skip lm_head.weight
592 | # The weight might appear unnecessarily in the files if the model is
593 | # processed with quantization, LoRA, fine-tuning, etc.
594 | if self.config.tie_word_embeddings and "lm_head.weight" in name:
595 | continue
596 | if scale_name := get_compressed_tensors_cache_scale(name):
597 | # Loading kv cache scales for compressed-tensors quantization
598 | param = params_dict[scale_name]
599 | weight_loader = getattr(param, "weight_loader", default_weight_loader)
600 | loaded_weight = loaded_weight[0]
601 | weight_loader(param, loaded_weight)
602 | continue
603 | for param_name, weight_name, shard_id in stacked_params_mapping:
604 | if weight_name not in name:
605 | continue
606 | name = name.replace(weight_name, param_name)
607 | # Skip loading extra bias for GPTQ models.
608 | if name.endswith(".bias") and name not in params_dict:
609 | continue
610 |
611 | if is_pp_missing_parameter(name, self):
612 | continue
613 |
614 | param = params_dict[name]
615 | weight_loader = param.weight_loader
616 | weight_loader(param, loaded_weight, shard_id)
617 |
618 | break
619 | else:
620 | # Skip loading extra bias for GPTQ models.
621 | if name.endswith(".bias") and name not in params_dict:
622 | continue
623 | # Remapping the name of FP8 kv-scale.
624 | name = maybe_remap_kv_scale_name(name, params_dict)
625 | if name is None:
626 | continue
627 |
628 | if is_pp_missing_parameter(name, self):
629 | continue
630 |
631 | param = params_dict[name]
632 | weight_loader = getattr(param, "weight_loader", default_weight_loader)
633 | weight_loader(param, loaded_weight)
634 |
635 | # If this function is called, it should always initialize KV cache scale
636 | # factors (or else raise an exception). Thus, handled exceptions should
637 | # make sure to leave KV cache scale factors in a known good (dummy) state
638 | def load_kv_cache_scales(self, quantization_param_path: str) -> None:
639 | tp_size = get_tensor_model_parallel_world_size()
640 | tp_rank = get_tensor_model_parallel_rank()
641 | for layer_idx, scaling_factor in kv_cache_scales_loader(
642 | quantization_param_path,
643 | tp_rank,
644 | tp_size,
645 | self.config.num_hidden_layers,
646 | self.config.__class__.model_type,
647 | ):
648 | if not isinstance(self.model.layers[layer_idx], nn.Identity):
649 | layer_self_attn = self.model.layers[layer_idx].self_attn
650 |
651 | if is_hip():
652 | # The scaling factor convention we are assuming is
653 | # quantized_value * scaling_factor ~= true_value
654 | # which is consistent with the practice of setting
655 | # scaling_factor = tensor_amax / FPtype_max
656 | scaling_factor *= 2
657 | if hasattr(layer_self_attn, "kv_scale"):
658 | layer_self_attn.attn._kv_scale = scaling_factor
659 | else:
660 | raise RuntimeError(
661 | "Self attention has no KV cache scaling " "factor attribute!"
662 | )
663 |
664 |     # Remap weight names from the consolidated (Mistral-style) checkpoint
665 |     # format, also used by Llama 1/2, to the HF naming convention.
666 | def maybe_remap_mistral(
667 | self, name: str, loaded_weight: torch.Tensor
668 | ) -> Tuple[str, torch.Tensor]:
669 |
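        # permute() appears to reorder q/k weight rows from the interleaved
        # (real, imag, real, imag, ...) rotary layout of consolidated
        # checkpoints into the half-split layout expected by the Neox-style
        # rotary embedding used here.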
670 | def permute(w, n_heads):
671 | attn_in = self.config.head_dim * n_heads
672 | attn_out = self.config.hidden_size
673 |
674 | return (
675 | w.view(n_heads, attn_in // n_heads // 2, 2, attn_out)
676 | .transpose(1, 2)
677 | .reshape(attn_in, attn_out)
678 | )
679 |
680 | mapping = self.mistral_mapping
681 | modules = name.split(".")
682 |
683 |         # Q/K weights from consolidated checkpoints use a different rotary
684 |         # layout and must be permuted before loading.
684 | if "wk" in modules:
685 | loaded_weight = permute(loaded_weight, self.config.num_key_value_heads)
686 | elif "wq" in modules:
687 | loaded_weight = permute(loaded_weight, self.config.num_attention_heads)
688 |
689 | for item in modules:
690 | if item in mapping and mapping[item] not in name:
691 | name = name.replace(item, mapping[item])
692 |
693 | return name, loaded_weight
694 |
--------------------------------------------------------------------------------