├── .gitignore
├── LICENSE
├── README.md
├── jetmoe
│   ├── __init__.py
│   ├── configuration_jetmoe.py
│   ├── modeling_jetmoe.py
│   └── utils
│       ├── __init__.py
│       ├── gate.py
│       ├── moe.py
│       └── parallel_experts.py
├── resources
│   ├── 1-myshell-mit.png
│   ├── 2-performance.png
│   ├── 3-architecture.png
│   ├── 4-phase1-data.png
│   └── 5-phase2-data.png
└── setup.py
/.gitignore:
--------------------------------------------------------------------------------
1 | *.pyc
2 | *.egg-info
3 | build/
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # JetMoE: Reaching LLaMA2 Performance with 0.1M Dollars
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 | ## Key Messages
10 |
11 | 1. JetMoE-8B is **trained at a cost of less than $0.1 million**<sup>1</sup> **but outperforms LLaMA2-7B from Meta AI**, which has multi-billion-dollar training resources. LLM training can be **much cheaper than people previously thought**.
12 |
13 | 2. JetMoE-8B is **fully open-sourced and academia-friendly** because:
14 | - It **only uses public datasets** for training, and the code is open-sourced. No proprietary resource is needed.
15 | - It **can be fine-tuned with a very limited compute budget** (e.g., a consumer-grade GPU) that most labs can afford.
16 |
17 | 3. JetMoE-8B **only has 2.2B active parameters** during inference, which drastically lowers the computational cost. Compared to models with similar inference computation, such as Gemma-2B, JetMoE-8B achieves consistently better performance.
18 |
19 | <sup>1</sup> We used a 96×H100 GPU cluster for 2 weeks, which cost ~$0.08 million.
20 |
21 | Website: [https://research.myshell.ai/jetmoe](https://research.myshell.ai/jetmoe)
22 |
23 | HuggingFace: [https://huggingface.co/jetmoe/jetmoe-8b](https://huggingface.co/jetmoe/jetmoe-8b)
24 |
25 | Online Demo on Lepton AI: [https://www.lepton.ai/playground/chat?model=jetmoe-8b-chat](https://www.lepton.ai/playground/chat?model=jetmoe-8b-chat)
26 |
27 | Technical Report: [https://arxiv.org/pdf/2404.07413.pdf](https://arxiv.org/pdf/2404.07413.pdf)
28 |
29 | ## Authors
30 |
31 | The project is contributed by [Yikang Shen](https://scholar.google.com.hk/citations?user=qff5rRYAAAAJ), [Zhen Guo](https://zguo0525.github.io/), [Tianle Cai](https://www.tianle.website/#/) and [Zengyi Qin](https://www.qinzy.tech/). For technical inquiries, please contact [Yikang Shen](https://scholar.google.com.hk/citations?user=qff5rRYAAAAJ). For media and collaboration inquiries, please contact [Zengyi Qin](https://www.qinzy.tech/).
32 |
33 | ## Collaboration
34 | **If you have great ideas but need more resources (GPU, data, funding, etc.)**, you are welcome to contact **MyShell.ai** via [Zengyi Qin](https://www.qinzy.tech/). **MyShell.ai** is open to collaboration and actively supports high-quality open-source projects.
35 |
36 | ## Benchmarks
37 | We use the same evaluation methodology as the Open LLM Leaderboard. For the MBPP code benchmark, we use the same evaluation methodology as the LLaMA2 and Deepseek-MoE papers. The results are shown below:
38 |
39 | |Model|Active Params|Training Tokens|Open LLM Leaderboard Avg|ARC|Hellaswag|MMLU|TruthfulQA|WinoGrande|GSM8k|MBPP|HumanEval|
40 | |---|---|---|---|---|---|---|---|---|---|---|---|
41 | |Shot||||25|10|5|0|5|5|3|0|
42 | |Metric||||acc_norm|acc_norm|acc|mc2|acc|acc|Pass@1|Pass@1|
43 | |LLaMA2-7B|7B|2T|51.0|53.1|78.6|46.9|38.8|74|14.5|20.8|12.8|
44 | |LLaMA-13B|13B|1T|51.4|**56.2**|**80.9**|47.7|39.5|**76.2**|7.6|22.0|15.8|
45 | |DeepseekMoE-16B|2.8B|2T|51.1|53.2|79.8|46.3|36.1|73.7|17.3|34.0|**25.0**|
46 | |Gemma-2B|2B|2T|46.4|48.4|71.8|41.8|33.1|66.3|16.9|28.0|24.4|
47 | |JetMoE-8B|2.2B|1.25T|**53.0**|48.7|80.5|**49.2**|**41.7**|70.2|**27.8**|**34.2**|14.6|
48 |
49 | | Model | MT-Bench Score |
50 | |---------------------|-----------|
51 | | GPT-4 | 9.014 |
52 | | GPT-3.5-turbo | 7.995 |
53 | | Claude-v1 | 7.923 |
54 | | **JetMoE-8B-chat** | **6.681** |
55 | | Llama-2-13b-chat | 6.650 |
56 | | Vicuna-13b-v1.3 | 6.413 |
57 | | Wizardlm-13b | 6.353 |
58 | | Llama-2-7b-chat | 6.269 |
59 |
60 | To our surprise, despite the lower training cost and computation, JetMoE-8B performs even better than LLaMA2-7B, LLaMA-13B, and DeepseekMoE-16B. Compared to a model with similar training and inference computation, like Gemma-2B, JetMoE-8B achieves better performance.
61 |
62 | ## Model Usage
63 | To load the models, you need to install this package first:
64 | ```bash
65 | pip install -e .
66 | ```
67 |
68 | Then you can load the model with the following code:
69 | ```python
70 | from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig, AutoModelForSequenceClassification
71 | from jetmoe import JetMoEForCausalLM, JetMoEConfig, JetMoEForSequenceClassification
72 |
73 | AutoConfig.register("jetmoe", JetMoEConfig)
74 | AutoModelForCausalLM.register(JetMoEConfig, JetMoEForCausalLM)
75 | AutoModelForSequenceClassification.register(JetMoEConfig, JetMoEForSequenceClassification)
76 |
77 | tokenizer = AutoTokenizer.from_pretrained('jetmoe/jetmoe-8b')
78 | model = AutoModelForCausalLM.from_pretrained('jetmoe/jetmoe-8b')
79 | ```
80 |
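Once the model classes are registered as above, text generation goes through the standard `transformers` generation API. The following is a minimal, untested sketch that reuses the `tokenizer` and `model` objects from the loading snippet; the prompt, `max_new_tokens`, and device placement are illustrative choices rather than values prescribed by the project.

```python
import torch

# Reuses `tokenizer` and `model` from the loading snippet above.
device = "cuda" if torch.cuda.is_available() else "cpu"
model = model.to(device).eval()

prompt = "The capital of France is"
inputs = tokenizer(prompt, return_tensors="pt").to(device)

with torch.no_grad():
    outputs = model.generate(**inputs, max_new_tokens=32, do_sample=False)

print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```
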
81 | ## Model Details
82 | Please refer to the technical report [https://arxiv.org/pdf/2404.07413.pdf](https://arxiv.org/pdf/2404.07413.pdf) for model and training details.
83 |
84 | ## Acknowledgement
85 | We thank [Shengding Hu](https://shengdinghu.github.io/) for his valuable advice on the Phase 2 data mixture, and [Exabits](https://www.exabits.ai/) for their assistance in setting up the GPU clusters.
86 |
--------------------------------------------------------------------------------
/jetmoe/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 JetMoE AI and The HuggingFace Inc. team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | from typing import TYPE_CHECKING
15 |
16 | from transformers.utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available
17 |
18 |
19 | _import_structure = {
20 | "configuration_jetmoe": ["JETMOE_PRETRAINED_CONFIG_ARCHIVE_MAP", "JetMoEConfig"],
21 | }
22 |
23 |
24 | try:
25 | if not is_torch_available():
26 | raise OptionalDependencyNotAvailable()
27 | except OptionalDependencyNotAvailable:
28 | pass
29 | else:
30 | _import_structure["modeling_jetmoe"] = [
31 | "JetMoEForCausalLM",
32 | "JetMoEModel",
33 | "JetMoEPreTrainedModel",
34 | "JetMoEForSequenceClassification",
35 | ]
36 |
37 | if TYPE_CHECKING:
38 | from .configuration_jetmoe import JETMOE_PRETRAINED_CONFIG_ARCHIVE_MAP, JetMoEConfig
39 |
40 | try:
41 | if not is_torch_available():
42 | raise OptionalDependencyNotAvailable()
43 | except OptionalDependencyNotAvailable:
44 | pass
45 | else:
46 | from .modeling_jetmoe import (
47 | JetMoEForCausalLM,
48 | JetMoEForSequenceClassification,
49 | JetMoEModel,
50 | JetMoEPreTrainedModel,
51 | )
52 |
53 | else:
54 | import sys
55 |
56 | sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
57 |
--------------------------------------------------------------------------------
/jetmoe/configuration_jetmoe.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2024 JetMoE AI and the HuggingFace Inc. team. All rights reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """JetMoE model configuration"""
16 |
17 | from transformers.configuration_utils import PretrainedConfig
18 | from transformers.utils import logging
19 |
20 |
21 | logger = logging.get_logger(__name__)
22 |
23 |
24 | class JetMoEConfig(PretrainedConfig):
25 | r"""
26 | This is the configuration class to store the configuration of a [`JetMoEModel`]. It is used to instantiate a
27 | JetMoE model according to the specified arguments, defining the model architecture. Instantiating a configuration
28 | with the defaults will yield a configuration similar to that of the JetMoE-4B.
29 |
30 | [jetmoe/jetmoe-8b](https://huggingface.co/jetmoe/jetmoe-8b)
31 |
32 | Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
33 | documentation from [`PretrainedConfig`] for more information.
34 |
35 |
36 | Args:
37 | vocab_size (`int`, *optional*, defaults to 32000):
38 | Vocabulary size of the JetMoE model. Defines the number of different tokens that can be represented by the
39 | `inputs_ids` passed when calling [`JetMoEModel`]
40 | hidden_size (`int`, *optional*, defaults to 2048):
41 | Dimension of the hidden representations.
42 | num_hidden_layers (`int`, *optional*, defaults to 12): Defines the number of blocks.
43 | num_attention_heads (`int`, *optional*, defaults to 32):
44 | Number of attention heads for each attention layer in the Transformer encoder.
45 | num_key_value_heads (`int`, *optional*, defaults to 16):
46 | Number of key and value heads for each attention layer.
47 | kv_channels (`int`, *optional*, defaults to 128): Defines the number of channels for the key and value tensors.
48 | ffn_hidden_size (`int`, *optional*, defaults to 5632): Defines the hidden size of the feed-forward layer.
49 | max_position_embeddings (`int`, *optional*, defaults to 4096):
50 | The maximum sequence length that this model might ever be used with. JetMoE's sliding window attention
51 | allows sequences of up to 4096*32 tokens.
52 | activation_function (`str`, *optional*, defaults to `"silu"`): Defines the activation function for the MLP experts.
53 | glu (`bool`, *optional*, defaults to `True`): Whether to use Gated Linear Units in the MLP experts.
54 | moe_num_experts (`int`, *optional*, defaults to 8): Defines the number of experts in the mixture of experts.
55 | moe_top_k (`int`, *optional*, defaults to 2): Defines the number of experts to use for each token.
56 | use_cache (`bool`, *optional*, defaults to `True`):
57 | Whether or not the model should return the last key/values attentions (not used by all models). Only
58 | relevant if `config.is_decoder=True`.
59 | bos_token_id (`int`, *optional*, defaults to 1):
60 | The id of the "beginning-of-sequence" token.
61 | eos_token_id (`int`, *optional*, defaults to 2):
62 | The id of the "end-of-sequence" token.
63 | tie_word_embeddings (`bool`, *optional*, defaults to `True`):
64 | Whether the model's input and output word embeddings should be tied.
65 | bias (`bool`, *optional*, defaults to `True`): Whether to use bias in the feed-forward and attention layers.
66 | rope_theta (`float`, *optional*, defaults to 10000.0):
67 | The base period of the RoPE embeddings.
68 | rms_norm_eps (`float`, *optional*, defaults to 1e-06):
69 | The epsilon used by the rms normalization layers.
70 | initializer_range (`float`, *optional*, defaults to 0.01):
71 | The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
72 |
73 | ```python
74 | >>> from transformers import JetMoEModel, JetMoEConfig
75 |
76 | >>> # Initializing a JetMoE 4B style configuration
77 | >>> configuration = JetMoEConfig()
78 |
79 | >>> # Initializing a model from the JetMoE 4B style configuration
80 | >>> model = JetMoEModel(configuration)
81 |
82 | >>> # Accessing the model configuration
83 | >>> configuration = model.config
84 | ```"""
85 |
86 | model_type = "jetmoe"
87 | keys_to_ignore_at_inference = ["past_key_values"]
88 |
89 | def __init__(
90 | self,
91 | vocab_size=32000,
92 | hidden_size=2048,
93 | num_hidden_layers=12,
94 | num_attention_heads=32,
95 | num_key_value_heads=16,
96 | kv_channels=128,
97 | ffn_hidden_size=5632,
98 | max_position_embeddings=4096,
99 | activation_function="silu",
100 | glu=True,
101 | moe_num_experts=8,
102 | moe_top_k=2,
103 | use_cache=True,
104 | bos_token_id=1,
105 | eos_token_id=2,
106 | tie_word_embeddings=True,
107 | bias=True,
108 | rope_theta=10000.0,
109 | rms_norm_eps=1e-6,
110 | initializer_range=0.01,
111 | **kwargs,
112 | ):
113 | self.vocab_size = vocab_size
114 | self.hidden_size = hidden_size
115 | self.num_hidden_layers = num_hidden_layers
116 | self.num_attention_heads = num_attention_heads
117 | self.num_key_value_heads = num_key_value_heads
118 | self.kv_channels = kv_channels
119 | self.ffn_hidden_size = ffn_hidden_size
120 | self.max_position_embeddings = max_position_embeddings
121 | self.activation_function = activation_function
122 | self.glu = glu
123 | self.moe_num_experts = moe_num_experts
124 | self.moe_top_k = moe_top_k
125 | self.use_cache = use_cache
126 | self.initializer_range = initializer_range
127 |
128 | self.bos_token_id = bos_token_id
129 | self.eos_token_id = eos_token_id
130 |
131 | self.bias = bias
132 | self.rope_theta = rope_theta
133 | self.rms_norm_eps = rms_norm_eps
134 |
135 | super().__init__(
136 | bos_token_id=bos_token_id, eos_token_id=eos_token_id, tie_word_embeddings=tie_word_embeddings, **kwargs
137 | )
138 |
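# Illustrative example (the values below are placeholders, not the hyperparameters of a
# released checkpoint): the attention implementation asserts
# `num_attention_heads == num_key_value_heads * moe_top_k`, so a custom configuration
# should preserve that relationship, e.g.
#
#     config = JetMoEConfig(
#         hidden_size=1024,
#         num_hidden_layers=8,
#         num_key_value_heads=8,
#         moe_top_k=2,
#         num_attention_heads=16,  # 8 key/value heads x top-2 routing
#         moe_num_experts=4,
#     )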
--------------------------------------------------------------------------------
/jetmoe/modeling_jetmoe.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2024 JetMoE AI and the HuggingFace Inc. team. All rights reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """PyTorch JetMoE model."""
16 |
17 | import math
18 | import warnings
19 | from typing import List, Optional, Tuple, Union
20 |
21 | import torch
22 | import torch.utils.checkpoint
23 | from torch import nn
24 | from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
25 | from torch.nn import functional as F
26 |
27 | from transformers.activations import ACT2FN
28 | from transformers.cache_utils import Cache, DynamicCache
29 | from transformers.modeling_attn_mask_utils import (
30 | _prepare_4d_causal_attention_mask,
31 | _prepare_4d_causal_attention_mask_for_sdpa,
32 | )
33 | from dataclasses import dataclass
34 | from transformers.modeling_outputs import (
35 | BaseModelOutputWithPast,
36 | CausalLMOutputWithPast,
37 | SequenceClassifierOutputWithPast,
38 | )
39 | from transformers.modeling_utils import PreTrainedModel
40 | from transformers.utils import (
41 | add_start_docstrings,
42 | add_start_docstrings_to_model_forward,
43 | is_flash_attn_2_available,
44 | is_flash_attn_greater_or_equal_2_10,
45 | logging,
46 | replace_return_docstrings,
47 | )
48 | from .configuration_jetmoe import JetMoEConfig
49 | from .utils import MoE, ParallelExperts
50 |
51 |
52 | if is_flash_attn_2_available():
53 | from flash_attn import flash_attn_func, flash_attn_varlen_func
54 | from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
55 |
56 | logger = logging.get_logger(__name__)
57 |
58 | _CHECKPOINT_FOR_DOC = "jetmoe/jetmoe-8b"
59 | _CONFIG_FOR_DOC = "JetMoEConfig"
60 |
61 |
62 | @dataclass
63 | class JetMoEBaseModelOutputWithPast(BaseModelOutputWithPast):
64 | """
65 | Base class for model's outputs that may also contain a past key/values (to speed up sequential decoding).
66 |
67 | Args:
68 | last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
69 | Sequence of hidden-states at the output of the last layer of the model.
70 |
71 | If `past_key_values` is used only the last hidden-state of the sequences of shape `(batch_size, 1,
72 | hidden_size)` is output.
73 | past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
74 | Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
75 | `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and optionally if
76 | `config.is_encoder_decoder=True` 2 additional tensors of shape `(batch_size, num_heads,
77 | encoder_sequence_length, embed_size_per_head)`.
78 |
79 | Contains pre-computed hidden-states (key and values in the self-attention blocks and optionally if
80 | `config.is_encoder_decoder=True` in the cross-attention blocks) that can be used (see `past_key_values`
81 | input) to speed up sequential decoding.
82 | hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
83 | Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
84 | one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
85 |
86 | Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
87 | attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
88 | Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
89 | sequence_length)`.
90 |
91 | Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
92 | heads.
93 | """
94 |
95 | last_hidden_state: torch.FloatTensor = None
96 | past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
97 | hidden_states: Optional[Tuple[torch.FloatTensor]] = None
98 | attentions: Optional[Tuple[torch.FloatTensor]] = None
99 | aux_loss: Optional[torch.FloatTensor] = None
100 |
101 |
102 | @dataclass
103 | class JetMoECausalLMOutputWithPast(CausalLMOutputWithPast):
104 | """
105 | Base class for causal language model (or autoregressive) outputs.
106 |
107 | Args:
108 | loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
109 | Language modeling loss (for next-token prediction).
110 | logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
111 | Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
112 | past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
113 | Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
114 | `(batch_size, num_heads, sequence_length, embed_size_per_head)`)
115 |
116 | Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
117 | `past_key_values` input) to speed up sequential decoding.
118 | hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
119 | Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
120 | one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
121 |
122 | Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
123 | attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
124 | Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
125 | sequence_length)`.
126 |
127 | Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
128 | heads.
129 | """
130 |
131 | loss: Optional[torch.FloatTensor] = None
132 | logits: torch.FloatTensor = None
133 | past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
134 | hidden_states: Optional[Tuple[torch.FloatTensor]] = None
135 | attentions: Optional[Tuple[torch.FloatTensor]] = None
136 | aux_loss: Optional[torch.FloatTensor] = None
137 |
138 |
139 | @dataclass
140 | class JetMoESequenceClassifierOutputWithPast(SequenceClassifierOutputWithPast):
141 | """
142 | Base class for outputs of sentence classification models.
143 |
144 | Args:
145 | loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
146 | Classification (or regression if config.num_labels==1) loss.
147 | logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels)`):
148 | Classification (or regression if config.num_labels==1) scores (before SoftMax).
149 | past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
150 | Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
151 | `(batch_size, num_heads, sequence_length, embed_size_per_head)`)
152 |
153 | Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
154 | `past_key_values` input) to speed up sequential decoding.
155 | hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
156 | Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
157 | one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
158 |
159 | Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
160 | attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
161 | Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
162 | sequence_length)`.
163 |
164 | Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
165 | heads.
166 | """
167 |
168 | loss: Optional[torch.FloatTensor] = None
169 | logits: torch.FloatTensor = None
170 | past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
171 | hidden_states: Optional[Tuple[torch.FloatTensor]] = None
172 | attentions: Optional[Tuple[torch.FloatTensor]] = None
173 | aux_loss: Optional[torch.FloatTensor] = None
174 |
175 |
176 | # Copied from transformers.models.llama.modeling_llama._get_unpad_data
177 | def _get_unpad_data(attention_mask):
178 | seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
179 | indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
180 | max_seqlen_in_batch = seqlens_in_batch.max().item()
181 | cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
182 | return (
183 | indices,
184 | cu_seqlens,
185 | max_seqlen_in_batch,
186 | )
187 |
188 |
189 | class JetMoERMSNorm(nn.Module):
190 | def __init__(self, hidden_size, eps=1e-6):
191 | """
192 | JetMoERMSNorm module
193 | """
194 | super().__init__()
195 | self.weight = nn.Parameter(torch.ones(hidden_size))
196 | self.variance_epsilon = eps
197 |
198 | def forward(self, hidden_states):
199 | input_dtype = hidden_states.dtype
200 | hidden_states = hidden_states.to(torch.float32)
201 | variance = hidden_states.pow(2).mean(-1, keepdim=True)
202 | hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
203 | return self.weight * hidden_states.to(input_dtype)
204 |
205 |
206 | # copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding
207 | class JetMoERotaryEmbedding(nn.Module):
208 | def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
209 | super().__init__()
210 |
211 | self.dim = dim
212 | self.max_position_embeddings = max_position_embeddings
213 | self.base = base
214 | inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim))
215 | self.register_buffer("inv_freq", inv_freq, persistent=False)
216 |
217 | # Build here to make `torch.jit.trace` work.
218 | self._set_cos_sin_cache(
219 | seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
220 | )
221 |
222 | def _set_cos_sin_cache(self, seq_len, device, dtype):
223 | self.max_seq_len_cached = seq_len
224 | t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq)
225 |
226 | freqs = torch.outer(t, self.inv_freq)
227 | # Different from paper, but it uses a different permutation in order to obtain the same calculation
228 | emb = torch.cat((freqs, freqs), dim=-1)
229 | self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
230 | self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
231 |
232 | def forward(self, x, seq_len=None):
233 | # x: [bs, num_attention_heads, seq_len, head_size]
234 | if seq_len > self.max_seq_len_cached:
235 | self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
236 |
237 | return (
238 | self.cos_cached[:seq_len].to(dtype=x.dtype),
239 | self.sin_cached[:seq_len].to(dtype=x.dtype),
240 | )
241 |
242 |
243 | # Copied from transformers.models.llama.modeling_llama.rotate_half
244 | def rotate_half(x):
245 | """Rotates half the hidden dims of the input."""
246 | x1 = x[..., : x.shape[-1] // 2]
247 | x2 = x[..., x.shape[-1] // 2 :]
248 | return torch.cat((-x2, x1), dim=-1)
249 |
250 |
251 | # copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
252 | def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=2):
253 | """Applies Rotary Position Embedding to the query and key tensors.
254 |
255 | Args:
256 | q (`torch.Tensor`): The query tensor.
257 | k (`torch.Tensor`): The key tensor.
258 | cos (`torch.Tensor`): The cosine part of the rotary embedding.
259 | sin (`torch.Tensor`): The sine part of the rotary embedding.
260 | position_ids (`torch.Tensor`):
261 | The position indices of the tokens corresponding to the query and key tensors. For example, this can be
262 | used to pass offsetted position ids when working with a KV-cache.
263 | unsqueeze_dim (`int`, *optional*, defaults to 2):
264 | The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
265 | sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
266 | that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
267 | k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
268 | cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
269 | the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
270 | Returns:
271 | `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
272 | """
273 | cos = cos[position_ids].unsqueeze(unsqueeze_dim)
274 | sin = sin[position_ids].unsqueeze(unsqueeze_dim)
275 | q_embed = (q * cos) + (rotate_half(q) * sin)
276 | k_embed = (k * cos) + (rotate_half(k) * sin)
277 | return q_embed, k_embed
278 |
279 |
280 | class JetMoEAttention(nn.Module):
281 | """
282 | Multi-headed attention from the 'Attention Is All You Need' paper, modified so that the query and output projections are computed by a mixture of experts (`self.experts`), while a single shared projection (`self.kv_proj`) produces the keys and values.
283 | """
284 |
285 | def __init__(self, config: JetMoEConfig, layer_idx: Optional[int] = None):
286 | """
287 | Initialize the JetMoEAttention module.
288 |
289 | Args:
290 | config: Configuration object with model hyperparameters.
291 | """
292 | super().__init__()
293 | self.config = config
294 | self.layer_idx = layer_idx
295 | self.is_causal = True
296 | if layer_idx is None:
297 | logger.warning_once(
298 | f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
299 | "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
300 | "when creating this class."
301 | )
302 |
303 | self.top_k = config.moe_top_k
304 |
305 | self.kv_projection_size = config.kv_channels * config.num_key_value_heads
306 | self.num_key_value_heads = config.num_key_value_heads
307 | self.num_heads = config.num_attention_heads
308 | assert self.num_heads == self.num_key_value_heads * config.moe_top_k
309 | self.hidden_size_per_attention_head = config.kv_channels
310 |
311 | self.experts = MoE(
312 | input_size=config.hidden_size,
313 | hidden_size=self.kv_projection_size,
314 | num_experts=config.moe_num_experts,
315 | top_k=config.moe_top_k,
316 | glu=False,
317 | )
318 |
319 | self.kv_proj = torch.nn.Linear(config.hidden_size, self.kv_projection_size * 2, bias=False)
320 |
321 | self.rotary_emb = JetMoERotaryEmbedding(
322 | config.kv_channels,
323 | max_position_embeddings=config.max_position_embeddings,
324 | base=config.rope_theta,
325 | )
326 |
327 | # def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
328 | # return tensor.view(bsz, seq_len, self.num_attention_heads, self.hidden_size_per_attention_head).transpose(1, 2).contiguous()
329 |
330 | def forward(
331 | self,
332 | hidden_states: torch.Tensor,
333 | attention_mask: Optional[torch.Tensor] = None,
334 | position_ids: Optional[torch.LongTensor] = None,
335 | past_key_value: Optional[Cache] = None,
336 | output_attentions: bool = False,
337 | use_cache: bool = False,
338 | **kwargs,
339 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]], torch.Tensor]:
340 | if "padding_mask" in kwargs:
341 | warnings.warn(
342 | "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
343 | )
344 | bsz, q_len, _ = hidden_states.size()
345 |
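# Mixture-of-experts attention: `experts.map` routes each token to `top_k` experts and returns
# their concatenated query projections together with the gating network's auxiliary loss, while
# a single shared `kv_proj` produces the keys and values for all experts.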
346 | query_states, aux_loss = self.experts.map(hidden_states)
347 | key_states, value_states = self.kv_proj(hidden_states).chunk(2, dim=-1)
348 |
349 | query_states = query_states.view(bsz, q_len, self.num_heads, self.hidden_size_per_attention_head).transpose(
350 | 1, 2
351 | )
352 | key_states = key_states.view(
353 | bsz, q_len, self.num_key_value_heads, self.hidden_size_per_attention_head
354 | ).transpose(1, 2)
355 | value_states = value_states.view(
356 | bsz, q_len, self.num_key_value_heads, self.hidden_size_per_attention_head
357 | ).transpose(1, 2)
358 |
359 | kv_seq_len = key_states.shape[2]
360 | if past_key_value is not None:
361 | if self.layer_idx is None:
362 | raise ValueError(
363 | f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
364 | "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
365 | "with a layer index."
366 | )
367 | kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
368 | cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
369 | query_states, key_states = apply_rotary_pos_emb(
370 | query_states, key_states, cos, sin, position_ids, unsqueeze_dim=1
371 | )
372 |
373 | if past_key_value is not None:
374 | cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
375 | key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
376 |
377 | # repeat k/v heads if n_kv_heads < n_heads
378 | key_states = key_states.repeat(1, self.top_k, 1, 1)
379 | value_states = value_states.repeat(1, self.top_k, 1, 1)
380 |
381 | attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(
382 | self.hidden_size_per_attention_head
383 | )
384 |
385 | if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
386 | raise ValueError(
387 | f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
388 | f" {attn_weights.size()}"
389 | )
390 |
391 | if attention_mask is not None:
392 | if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
393 | raise ValueError(
394 | f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
395 | )
396 |
397 | attn_weights = attn_weights + attention_mask
398 |
399 | # upcast attention to fp32
400 | attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
401 | # attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
402 | attn_output = torch.matmul(attn_weights, value_states)
403 |
404 | if attn_output.size() != (bsz, self.num_heads, q_len, self.hidden_size_per_attention_head):
405 | raise ValueError(
406 | f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.hidden_size_per_attention_head)}, but is"
407 | f" {attn_output.size()}"
408 | )
409 |
410 | attn_output = attn_output.transpose(1, 2).contiguous()
411 | attn_output = attn_output.reshape(bsz, q_len, self.top_k, self.kv_projection_size)
412 |
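# `experts.reduce` maps each expert's slice of the attention output back to the model dimension
# through that expert's output projection and combines the results according to the routing
# computed in `experts.map` above.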
413 | attn_output = self.experts.reduce(attn_output)
414 | attn_output = attn_output.view(bsz, q_len, -1)
415 |
416 | if not output_attentions:
417 | attn_weights = None
418 |
419 | return attn_output, attn_weights, past_key_value, aux_loss
420 |
421 |
422 | # copied from transformers.models.llama.modeling_llama.LlamaSdpaAttention with Llama->JetMoE
423 | class JetMoESdpaAttention(JetMoEAttention):
424 | """
425 | JetMoE attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
426 | `JetMoEAttention`, as the weights of the module stay untouched. The only changes are on the forward pass, to adapt to the
427 | SDPA API.
428 | """
429 |
430 | # Adapted from JetMoEAttention.forward
431 | def forward(
432 | self,
433 | hidden_states: torch.Tensor,
434 | attention_mask: Optional[torch.Tensor] = None,
435 | position_ids: Optional[torch.LongTensor] = None,
436 | past_key_value: Optional[Cache] = None,
437 | output_attentions: bool = False,
438 | use_cache: bool = False,
439 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]], torch.Tensor]:
440 | if output_attentions:
441 | # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
442 | logger.warning_once(
443 | "JetMoEModel is using JetMoESdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
444 | 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
445 | )
446 | return super().forward(
447 | hidden_states=hidden_states,
448 | attention_mask=attention_mask,
449 | position_ids=position_ids,
450 | past_key_value=past_key_value,
451 | output_attentions=output_attentions,
452 | use_cache=use_cache,
453 | )
454 |
455 | bsz, q_len, _ = hidden_states.size()
456 |
457 | query_states, aux_loss = self.experts.map(hidden_states)
458 | key_states, value_states = self.kv_proj(hidden_states).chunk(2, dim=-1)
459 |
460 | query_states = query_states.view(bsz, q_len, self.num_heads, self.hidden_size_per_attention_head).transpose(
461 | 1, 2
462 | )
463 | key_states = key_states.view(
464 | bsz, q_len, self.num_key_value_heads, self.hidden_size_per_attention_head
465 | ).transpose(1, 2)
466 | value_states = value_states.view(
467 | bsz, q_len, self.num_key_value_heads, self.hidden_size_per_attention_head
468 | ).transpose(1, 2)
469 |
470 | kv_seq_len = key_states.shape[2]
471 | if past_key_value is not None:
472 | kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
473 | cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
474 |
475 | query_states, key_states = apply_rotary_pos_emb(
476 | query_states, key_states, cos, sin, position_ids, unsqueeze_dim=1
477 | )
478 |
479 | if past_key_value is not None:
480 | cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
481 | key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
482 |
483 | key_states = key_states.repeat(1, self.top_k, 1, 1)
484 | value_states = value_states.repeat(1, self.top_k, 1, 1)
485 |
486 | if attention_mask is not None:
487 | if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
488 | raise ValueError(
489 | f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
490 | )
491 |
492 | # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
493 | # Reference: https://github.com/pytorch/pytorch/issues/112577.
494 | if query_states.device.type == "cuda" and attention_mask is not None:
495 | query_states = query_states.contiguous()
496 | key_states = key_states.contiguous()
497 | value_states = value_states.contiguous()
498 |
499 | attn_output = torch.nn.functional.scaled_dot_product_attention(
500 | query_states,
501 | key_states,
502 | value_states,
503 | attn_mask=attention_mask,
504 | dropout_p=0.0,
505 | # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
506 | is_causal=self.is_causal and attention_mask is None and q_len > 1,
507 | )
508 |
509 | attn_output = attn_output.transpose(1, 2).contiguous()
510 | attn_output = attn_output.reshape(bsz, q_len, self.top_k, self.kv_projection_size)
511 |
512 | attn_output = self.experts.reduce(attn_output)
513 | attn_output = attn_output.view(bsz, q_len, -1)
514 |
515 | return attn_output, None, past_key_value, aux_loss
516 |
517 |
518 | class JetMoEFlashAttention2(JetMoEAttention):
519 | def __init__(self, *args, **kwargs):
520 | super().__init__(*args, **kwargs)
521 |
522 | # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
523 | # flash_attn<2.1 generates a top-left aligned causal mask, while what is needed here is bottom-right alignment, which was made the default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
524 | # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
525 | self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
526 |
527 | def forward(
528 | self,
529 | hidden_states: Optional[torch.FloatTensor],
530 | attention_mask: Optional[torch.FloatTensor] = None,
531 | position_ids: Optional[torch.LongTensor] = None,
532 | past_key_value: Optional[Cache] = None,
533 | use_cache: Optional[bool] = False,
534 | output_attentions: Optional[bool] = False,
535 | **kwargs,
536 | ) -> Union[
537 | Tuple[torch.Tensor, Tuple[torch.Tensor]],
538 | Optional[Tuple[torch.Tensor, Tuple[torch.Tensor], Tuple[torch.Tensor, ...]]],
539 | ]:
540 | """
541 | Forward pass of the JetMoEFlashAttention2 module.
542 |
543 | Args:
544 | hidden_states (Optional[torch.FloatTensor]): Input hidden states.
545 | attention_mask (Optional[torch.FloatTensor]): Attention mask.
546 | past_key_value (Optional[Cache]): Cached key/value states from previous decoding steps.
547 | use_cache (Optional[bool]): Whether to use cached states.
548 | output_attentions (Optional[bool]): Whether to output attention weights.
549 |
550 | Returns:
551 | Union[Tuple[torch.Tensor, Tuple[torch.Tensor]], Optional[Tuple[...]]]: Tuple containing outputs.
552 | """
553 | # assert attention_mask is None, "attention_mask is not supported"
554 | assert output_attentions is False, "output_attentions is not supported"
555 |
556 | B, T, C = hidden_states.size() # batch size, sequence length, embedding dimensionality (hidden_size)
557 |
558 | # calculate query, key, values
559 | query_layer, aux_loss = self.experts.map(hidden_states)
560 | key_layer, value_layer = self.kv_proj(hidden_states).chunk(2, dim=-1)
561 |
562 | query_layer = query_layer.view(B, T, self.num_heads, self.hidden_size_per_attention_head) # (B, T, k * nh, hs)
563 | key_layer = key_layer.view(
564 | B, T, self.num_key_value_heads, self.hidden_size_per_attention_head
565 | ) # (B, T, nh, hs)
566 | value_layer = value_layer.view(
567 | B, T, self.num_key_value_heads, self.hidden_size_per_attention_head
568 | ) # (B, T, nh, hs)
569 |
570 | kv_seq_len = key_layer.shape[1]
571 | if past_key_value is not None:
572 | if self.layer_idx is None:
573 | raise ValueError(
574 | f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
575 | "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
576 | "with a layer index."
577 | )
578 | kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
579 | cos, sin = self.rotary_emb(value_layer, seq_len=kv_seq_len)
580 | query_layer, key_layer = apply_rotary_pos_emb(query_layer, key_layer, cos, sin, position_ids)
581 |
582 | # query_layer = query_layer.contiguous()
583 | # expand the key_layer and value_layer [sk, b, ng, hn] -> [sk, b, np, hn]
584 | key_layer = key_layer.repeat(1, 1, self.top_k, 1)
585 | value_layer = value_layer.repeat(1, 1, self.top_k, 1)
586 |
587 | if past_key_value is not None:
588 | cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
589 | # print(self.layer_idx, key_layer.size())
590 | key_layer = key_layer.transpose(1, 2)
591 | value_layer = value_layer.transpose(1, 2)
592 | key_layer, value_layer = past_key_value.update(key_layer, value_layer, self.layer_idx, cache_kwargs)
593 | key_layer = key_layer.transpose(1, 2)
594 | value_layer = value_layer.transpose(1, 2)
595 |
596 | context_layer = self._flash_attention_forward(
597 | query_layer,
598 | key_layer,
599 | value_layer,
600 | attention_mask,
601 | T,
602 | )
603 |
604 | # output projection
605 | y = self.experts.reduce(context_layer.reshape(B, T, self.top_k, self.kv_projection_size))  # (B, T, top_k, kv_projection_size), matching the eager and SDPA paths
606 | y = y.view(B, T, C) # re-assemble all head outputs side by side
607 |
608 | if not output_attentions:
609 | attn_weights = None
610 |
611 | return y, attn_weights, past_key_value, aux_loss
612 |
613 | def _flash_attention_forward(
614 | self,
615 | query_states,
616 | key_states,
617 | value_states,
618 | attention_mask,
619 | query_length,
620 | dropout=0.0,
621 | softmax_scale=None,
622 | ):
623 | """
624 | Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
625 | first unpad the input, then computes the attention scores and pad the final attention scores.
626 |
627 | Args:
628 | query_states (`torch.Tensor`):
629 | Input query states to be passed to Flash Attention API
630 | key_states (`torch.Tensor`):
631 | Input key states to be passed to Flash Attention API
632 | value_states (`torch.Tensor`):
633 | Input value states to be passed to Flash Attention API
634 | attention_mask (`torch.Tensor`):
635 | The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
636 | position of padding tokens and 1 for the position of non-padding tokens.
637 | dropout (`float`):
638 | Attention dropout
639 | softmax_scale (`float`, *optional*):
640 | The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
641 | """
642 | if not self._flash_attn_uses_top_left_mask:
643 | causal = self.is_causal
644 | else:
645 | # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__.
646 | causal = self.is_causal and query_length != 1
647 |
648 | # Contains at least one padding token in the sequence
649 | if attention_mask is not None:
650 | batch_size = query_states.shape[0]
651 | query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
652 | query_states, key_states, value_states, attention_mask, query_length
653 | )
654 |
655 | cu_seqlens_q, cu_seqlens_k = cu_seq_lens
656 | max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
657 |
658 | attn_output_unpad = flash_attn_varlen_func(
659 | query_states,
660 | key_states,
661 | value_states,
662 | cu_seqlens_q=cu_seqlens_q,
663 | cu_seqlens_k=cu_seqlens_k,
664 | max_seqlen_q=max_seqlen_in_batch_q,
665 | max_seqlen_k=max_seqlen_in_batch_k,
666 | dropout_p=dropout,
667 | softmax_scale=softmax_scale,
668 | causal=causal,
669 | )
670 |
671 | attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
672 | else:
673 | attn_output = flash_attn_func(
674 | query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal
675 | )
676 |
677 | return attn_output
678 |
679 | def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
680 | indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
681 | batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
682 |
683 | key_layer = index_first_axis(
684 | key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
685 | )
686 | value_layer = index_first_axis(
687 | value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
688 | )
689 | if query_length == kv_seq_len:
690 | query_layer = index_first_axis(
691 | query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k
692 | )
693 | cu_seqlens_q = cu_seqlens_k
694 | max_seqlen_in_batch_q = max_seqlen_in_batch_k
695 | indices_q = indices_k
696 | elif query_length == 1:
697 | max_seqlen_in_batch_q = 1
698 | cu_seqlens_q = torch.arange(
699 | batch_size + 1, dtype=torch.int32, device=query_layer.device
700 | ) # There is a memcpy here, that is very bad.
701 | indices_q = cu_seqlens_q[:-1]
702 | query_layer = query_layer.squeeze(1)
703 | else:
704 | # The -q_len: slice assumes left padding.
705 | attention_mask = attention_mask[:, -query_length:]
706 | query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)
707 |
708 | return (
709 | query_layer,
710 | key_layer,
711 | value_layer,
712 | indices_q,
713 | (cu_seqlens_q, cu_seqlens_k),
714 | (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
715 | )
716 |
717 |
718 | JETMOE_ATTENTION_CLASSES = {
719 | "eager": JetMoEAttention,
720 | "flash_attention_2": JetMoEFlashAttention2,
721 | "sdpa": JetMoESdpaAttention,
722 | }
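# JetMoEBlock.__init__ below indexes this mapping with `config._attn_implementation`;
# "flash_attention_2" requires the flash-attn kernels, while "eager" and "sdpa" run on
# stock recent PyTorch.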
723 |
724 |
725 | class JetMoEBlock(nn.Module):
726 | def __init__(self, config: JetMoEConfig, layer_idx: Optional[int] = None):
727 | """
728 | Initialize the JetMoEBlock module.
729 |
730 | Args:
731 | config: Configuration object with model hyperparameters.
732 | """
733 | super().__init__()
734 | self.input_layernorm = JetMoERMSNorm(config.hidden_size)
735 | self.self_attention = JETMOE_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx)
736 | self.post_attention_layernorm = JetMoERMSNorm(config.hidden_size)
737 |
738 | self.mlp = MoE(
739 | input_size=config.hidden_size,
740 | hidden_size=config.ffn_hidden_size,
741 | num_experts=config.moe_num_experts,
742 | activation=ACT2FN[config.activation_function],
743 | top_k=config.moe_top_k,
744 | bias=config.bias,
745 | glu=config.glu,
746 | )
747 |
748 | def forward(
749 | self,
750 | hidden_states: Optional[torch.FloatTensor],
751 | position_ids: Optional[torch.LongTensor] = None,
752 | past_key_value: Optional[Tuple[torch.Tensor]] = None,
753 | attention_mask: Optional[torch.FloatTensor] = None,
754 | output_attentions: Optional[bool] = False,
755 | use_cache: Optional[bool] = False,
756 | **kwargs,
757 | ) -> Union[Tuple[torch.Tensor], Optional[Tuple[torch.Tensor, Tuple[torch.FloatTensor, ...]]]]:
758 | """
759 | Forward pass of the JetMoEBlock module.
760 |
761 | Args:
762 | hidden_states (Optional[torch.FloatTensor]): Input hidden states.
763 | position_ids (Optional[torch.LongTensor]): Position indices used for the rotary embeddings.
764 | past_key_value (Optional[Tuple[torch.Tensor]]): Cached key/value states from previous steps.
765 | attention_mask (Optional[torch.FloatTensor]): Attention mask.
766 | output_attentions (Optional[bool]): Whether to output attention weights.
767 | use_cache (Optional[bool]): Whether to return updated cached key/value states.
768 |
769 | Returns:
770 | Union[Tuple[torch.Tensor], Optional[Tuple[torch.Tensor, Tuple[torch.FloatTensor, ...]]]]:
771 | Tuple of (hidden_states, optional attention weights, optional present key/value, aux loss).
772 | """
773 | # Self Attention
774 | attn_output, self_attn_weights, present_key_value, att_aux_loss = self.self_attention(
775 | hidden_states=self.input_layernorm(hidden_states),
776 | attention_mask=attention_mask,
777 | position_ids=position_ids,
778 | past_key_value=past_key_value,
779 | output_attentions=output_attentions,
780 | use_cache=use_cache,
781 | )
782 |
783 | hidden_states = hidden_states + attn_output
784 | x_mlp, mlp_aux_loss = self.mlp(self.post_attention_layernorm(hidden_states))
785 | hidden_states = hidden_states + x_mlp
786 |
787 | outputs = (hidden_states,)
788 |
789 | if output_attentions:
790 | outputs += (self_attn_weights,)
791 |
792 | if use_cache:
793 | outputs += (present_key_value,)
794 |
795 | outputs += (att_aux_loss + mlp_aux_loss,)
796 |
797 | return outputs
798 |
799 |
800 | class JetMoEPreTrainedModel(PreTrainedModel):
801 | """
802 | An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
803 | models.
804 | """
805 |
806 | config_class = JetMoEConfig
807 | base_model_prefix = "transformer"
808 | supports_gradient_checkpointing = False
809 | _no_split_modules = ["JetMoEBlock"]
810 | _skip_keys_device_placement = "past_key_values"
811 | _supports_flash_attn_2 = True
812 | _supports_sdpa = True
813 | _supports_cache_class = True
814 |
815 | def __init__(self, *inputs, **kwargs):
816 | """
817 | Initialize the JetMoEPreTrainedModel.
818 |
819 | Args:
820 | *inputs: Variable length input arguments.
821 | **kwargs: Keyword arguments.
822 | """
823 | super().__init__(*inputs, **kwargs)
824 |
825 | self.gradient_checkpointing = False
826 |
827 | def _init_weights(self, module):
828 | """Initialize the weights."""
829 | if isinstance(module, (nn.Linear,)):
830 | # Slightly different from Mesh Transformer JAX which uses truncated_normal for initialization
831 | # cf https://github.com/pytorch/pytorch/pull/5617
832 | module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
833 | if module.bias is not None:
834 | module.bias.data.zero_()
835 | elif isinstance(module, nn.Embedding):
836 | module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
837 | if module.padding_idx is not None:
838 | module.weight.data[module.padding_idx].zero_()
839 | elif isinstance(module, nn.LayerNorm):
840 | module.bias.data.zero_()
841 | module.weight.data.fill_(1.0)
842 | elif isinstance(module, ParallelExperts):
843 | module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
844 |
845 | # def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs={}):
846 | # for module in self.modules():
847 | # if hasattr(module, "gradient_checkpointing"):
848 | # self._set_gradient_checkpointing(
849 | # module, True, gradient_checkpointing_kwargs
850 | # )
851 |
852 | # def gradient_checkpointing_disable(self):
853 | # for module in self.modules():
854 | # if hasattr(module, "gradient_checkpointing"):
855 | # self._set_gradient_checkpointing(
856 | # module, False
857 | # )
858 |
859 | # def _set_gradient_checkpointing(
860 | # self,
861 | # module,
862 | # value=False,
863 | # gradient_checkpointing_kwargs={"use_reentrant": False},
864 | # ):
865 | # """
866 | # Set gradient checkpointing for the JetMoEModel.
867 |
868 | # Args:
869 | # module: The module for which gradient checkpointing is set.
870 | # value (bool): Whether to enable gradient checkpointing.
871 | # """
872 | # self._gradient_checkpointing_func = checkpoint
873 | # self.gradient_checkpointing = True
874 | # if isinstance(module, JetMoEModel):
875 | # module.gradient_checkpointing = value
876 | # module.gradient_checkpointing_kwargs = gradient_checkpointing_kwargs
877 | # module._gradient_checkpointing_func = checkpoint
878 |
879 |
880 | JETMOE_START_DOCSTRING = r"""
881 | This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. Use
882 | it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage and
883 | behavior.
884 |
885 | Parameters:
886 | config ([`JetMoEConfig`]): Model configuration class with all the parameters of the model.
887 | Initializing with a config file does not load the weights associated with the model, only the
888 | configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
889 | """
890 |
891 | JETMOE_INPUTS_DOCSTRING = r"""
892 | Args:
893 | input_ids (`torch.LongTensor` of shape `({0})`):
894 | Indices of input sequence tokens in the vocabulary.
895 |
896 | Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
897 | [`PreTrainedTokenizer.__call__`] for details.
898 |
899 | [What are input IDs?](../glossary#input-ids)
900 | attention_mask (`torch.FloatTensor` of shape `({0})`, *optional*):
901 | Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
902 |
903 | - 1 for tokens that are **not masked**,
904 | - 0 for tokens that are **masked**.
905 |
906 | [What are attention masks?](../glossary#attention-mask)
907 | token_type_ids (`torch.LongTensor` of shape `({0})`, *optional*):
908 | Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
909 | 1]`:
910 |
911 | - 0 corresponds to a *sentence A* token,
912 | - 1 corresponds to a *sentence B* token.
913 |
914 | [What are token type IDs?](../glossary#token-type-ids)
915 | position_ids (`torch.LongTensor` of shape `({0})`, *optional*):
916 | Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
917 | config.n_positions - 1]`.
918 |
919 | [What are position IDs?](../glossary#position-ids)
920 | head_mask (`torch.FloatTensor` of shape `(num_attention_heads,)` or `(n_layer, num_attention_heads)`, *optional*):
921 | Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
922 |
923 | - 1 indicates the head is **not masked**,
924 | - 0 indicates the head is **masked**.
925 |
926 | inputs_embeds (`torch.FloatTensor` of shape `({0}, hidden_dim)`, *optional*):
927 | Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
928 | is useful if you want more control over how to convert *input_ids* indices into associated vectors than the
929 | model's internal embedding lookup matrix.
930 | output_attentions (`bool`, *optional*):
931 | Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
932 | tensors for more detail.
933 | output_hidden_states (`bool`, *optional*):
934 | Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
935 | more detail.
936 | return_dict (`bool`, *optional*):
937 | Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
938 | """
939 |
940 |
941 | @add_start_docstrings(
942 | "The bare JetMoE Model outputting raw hidden-states without any specific head on top.",
943 | JETMOE_START_DOCSTRING,
944 | )
945 | class JetMoEModel(JetMoEPreTrainedModel):
946 | """
947 | Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`JetMoEBlock`]
948 |
949 | Args:
950 | config: JetMoEConfig
951 | """
952 |
953 | def __init__(self, config: JetMoEConfig):
954 | super().__init__(config)
955 | self.padding_idx = config.pad_token_id
956 | self.vocab_size = config.vocab_size
957 |
958 | self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
959 | self.layers = nn.ModuleList([JetMoEBlock(config, layer_idx) for layer_idx in range(config.num_hidden_layers)])
960 | self._attn_implementation = config._attn_implementation
961 | self.norm = JetMoERMSNorm(config.hidden_size, eps=config.rms_norm_eps)
962 |
963 | self.gradient_checkpointing = False
964 | # Initialize weights and apply final processing
965 | self.post_init()
966 |
967 | def get_input_embeddings(self):
968 | return self.embed_tokens
969 |
970 | def set_input_embeddings(self, value):
971 | self.embed_tokens = value
972 |
973 | @add_start_docstrings_to_model_forward(JETMOE_INPUTS_DOCSTRING)
974 | def forward(
975 | self,
976 | input_ids: torch.LongTensor = None,
977 | attention_mask: Optional[torch.Tensor] = None,
978 | position_ids: Optional[torch.LongTensor] = None,
979 | past_key_values: Optional[List[torch.FloatTensor]] = None,
980 | inputs_embeds: Optional[torch.FloatTensor] = None,
981 | use_cache: Optional[bool] = None,
982 | output_attentions: Optional[bool] = None,
983 | output_hidden_states: Optional[bool] = None,
984 | return_dict: Optional[bool] = None,
985 | ) -> Union[Tuple, BaseModelOutputWithPast]:
986 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
987 | output_hidden_states = (
988 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
989 | )
990 | use_cache = use_cache if use_cache is not None else self.config.use_cache
991 |
992 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict
993 |
994 | # retrieve input_ids and inputs_embeds
995 | if input_ids is not None and inputs_embeds is not None:
996 | raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
997 | elif input_ids is not None:
998 | batch_size, seq_length = input_ids.shape
999 | elif inputs_embeds is not None:
1000 | batch_size, seq_length, _ = inputs_embeds.shape
1001 | else:
1002 | raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
1003 |
1004 | if self.gradient_checkpointing and self.training:
1005 | if use_cache:
1006 | logger.warning_once(
1007 | "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
1008 | )
1009 | use_cache = False
1010 |
1011 | past_key_values_length = 0
1012 |
1013 | if use_cache:
1014 | use_legacy_cache = not isinstance(past_key_values, Cache)
1015 | if use_legacy_cache:
1016 | past_key_values = DynamicCache.from_legacy_cache(past_key_values)
1017 | past_key_values_length = past_key_values.get_usable_length(seq_length)
1018 |
1019 | if position_ids is None:
1020 | device = input_ids.device if input_ids is not None else inputs_embeds.device
1021 | position_ids = torch.arange(
1022 | past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
1023 | )
1024 | position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
1025 | else:
1026 | position_ids = position_ids.view(-1, seq_length).long()
1027 |
1028 | if inputs_embeds is None:
1029 | inputs_embeds = self.embed_tokens(input_ids)
1030 |
1031 | if attention_mask is not None and self._attn_implementation == "flash_attention_2" and use_cache:
1032 | is_padding_right = attention_mask[:, -1].sum().item() != batch_size
1033 | if is_padding_right:
1034 | raise ValueError(
1035 | "You are attempting to perform batched generation with padding_side='right'"
1036 | " this may lead to unexpected behaviour for Flash Attention version of JetMoE. Make sure to "
1037 | " call `tokenizer.padding_side = 'left'` before tokenizing the input. "
1038 | )
1039 |
1040 | if self._attn_implementation == "flash_attention_2":
1041 | # 2d mask is passed through the layers
1042 | attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
1043 | elif self._attn_implementation == "sdpa" and not output_attentions:
1044 | # output_attentions=True can not be supported when using SDPA, and we fall back on
1045 | # the manual implementation that requires a 4D causal mask in all cases.
1046 | attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
1047 | attention_mask,
1048 | (batch_size, seq_length),
1049 | inputs_embeds,
1050 | past_key_values_length,
1051 | )
1052 | else:
1053 | # 4d mask is passed through the layers
1054 | attention_mask = _prepare_4d_causal_attention_mask(
1055 | attention_mask,
1056 | (batch_size, seq_length),
1057 | inputs_embeds,
1058 | past_key_values_length,
1059 | )
1060 |
1061 | hidden_states = inputs_embeds
1062 |
1063 | # decoder layers
1064 | all_hidden_states = () if output_hidden_states else None
1065 | all_self_attns = () if output_attentions else None
1066 | next_decoder_cache = None
1067 |
1068 | aux_loss = 0
1069 | for decoder_layer in self.layers:
1070 | if output_hidden_states:
1071 | all_hidden_states += (hidden_states,)
1072 |
1073 | # hidden_states: Optional[torch.FloatTensor],
1074 | # position_ids: Optional[torch.LongTensor] = None,
1075 | # past_key_value: Optional[Tuple[torch.Tensor]] = None,
1076 | # attention_mask: Optional[torch.FloatTensor] = None,
1077 | # output_attentions: Optional[bool] = False,
1078 | # use_cache: Optional[bool] = False,
1079 |
1080 | if self.gradient_checkpointing and self.training:
1081 | layer_outputs = self._gradient_checkpointing_func(
1082 | # decoder_layer.__call__,
1083 | decoder_layer,
1084 | hidden_states,
1085 | position_ids,
1086 | past_key_values,
1087 | attention_mask,
1088 | output_attentions,
1089 | use_cache,
1090 | use_reentrant=False,
1091 | )
1092 | else:
1093 | layer_outputs = decoder_layer(
1094 | hidden_states,
1095 | attention_mask=attention_mask,
1096 | position_ids=position_ids,
1097 | past_key_value=past_key_values,
1098 | output_attentions=output_attentions,
1099 | use_cache=use_cache,
1100 | )
1101 |
1102 | hidden_states = layer_outputs[0]
1103 |
1104 | if use_cache:
1105 | next_decoder_cache = layer_outputs[2 if output_attentions else 1]
1106 |
1107 | if output_attentions:
1108 | all_self_attns += (layer_outputs[1],)
1109 |
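# each JetMoEBlock appends its combined attention + MLP load-balancing loss
# as the last element of its output tuple (see JetMoEBlock.forward above)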
1110 | aux_loss += layer_outputs[-1]
1111 |
1112 | hidden_states = self.norm(hidden_states)
1113 |
1114 | # add hidden states from the last decoder layer
1115 | if output_hidden_states:
1116 | all_hidden_states += (hidden_states,)
1117 |
1118 | next_cache = None
1119 | if use_cache:
1120 | next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache
1121 |
1122 | if not return_dict:
1123 | return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
1124 | return JetMoEBaseModelOutputWithPast(
1125 | last_hidden_state=hidden_states,
1126 | past_key_values=next_cache,
1127 | hidden_states=all_hidden_states,
1128 | attentions=all_self_attns,
1129 | aux_loss=aux_loss,
1130 | )
1131 |
1132 |
1133 | class JetMoEForCausalLM(JetMoEPreTrainedModel):
1134 | _tied_weights_keys = ["lm_head.weight"]
1135 |
1136 | def __init__(self, config):
1137 | super().__init__(config)
1138 | self.model = JetMoEModel(config)
1139 | self.vocab_size = config.vocab_size
1140 | self.aux_loss_coef = getattr(config, "aux_loss_coef", 0.01)
1141 | self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
1142 | self.tie_word_embeddings = config.tie_word_embeddings
1143 |
1144 | # Initialize weights and apply final processing
1145 | self.post_init()
1146 |
1147 | def get_input_embeddings(self):
1148 | return self.model.embed_tokens
1149 |
1150 | def set_input_embeddings(self, value):
1151 | self.model.embed_tokens = value
1152 |
1153 | def get_output_embeddings(self):
1154 | return self.lm_head
1155 |
1156 | def set_output_embeddings(self, new_embeddings):
1157 | self.lm_head = new_embeddings
1158 |
1159 | def set_decoder(self, decoder):
1160 | self.model = decoder
1161 |
1162 | def get_decoder(self):
1163 | return self.model
1164 |
1165 | @add_start_docstrings_to_model_forward(JETMOE_INPUTS_DOCSTRING)
1166 | @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
1167 | def forward(
1168 | self,
1169 | input_ids: torch.LongTensor = None,
1170 | attention_mask: Optional[torch.Tensor] = None,
1171 | position_ids: Optional[torch.LongTensor] = None,
1172 | past_key_values: Optional[List[torch.FloatTensor]] = None,
1173 | inputs_embeds: Optional[torch.FloatTensor] = None,
1174 | labels: Optional[torch.LongTensor] = None,
1175 | use_cache: Optional[bool] = None,
1176 | output_attentions: Optional[bool] = None,
1177 | output_hidden_states: Optional[bool] = None,
1178 | return_dict: Optional[bool] = None,
1179 | ) -> Union[Tuple, CausalLMOutputWithPast]:
1180 | r"""
1181 | Args:
1182 | labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1183 | Labels for computing the language modeling loss. Indices should either be in `[0, ...,
1184 | config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
1185 | (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
1186 |
1187 | Returns:
1188 | """
1189 |
1190 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1191 | output_hidden_states = (
1192 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1193 | )
1194 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1195 |
1196 | # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
1197 | outputs = self.model(
1198 | input_ids=input_ids,
1199 | attention_mask=attention_mask,
1200 | position_ids=position_ids,
1201 | past_key_values=past_key_values,
1202 | inputs_embeds=inputs_embeds,
1203 | use_cache=use_cache,
1204 | output_attentions=output_attentions,
1205 | output_hidden_states=output_hidden_states,
1206 | return_dict=return_dict,
1207 | )
1208 |
1209 | hidden_states = outputs[0]
1210 | logits = self.lm_head(hidden_states)
1211 | logits = logits.float()
1212 |
1213 | loss = None
1214 | if labels is not None:
1215 | # Shift so that tokens < n predict n
1216 | shift_logits = logits[..., :-1, :].contiguous()
1217 | shift_labels = labels[..., 1:].contiguous()
1218 | # Flatten the tokens
1219 | shift_logits = shift_logits.view(-1, self.config.vocab_size)
1220 | shift_labels = shift_labels.view(-1)
1221 | # Ensure tensors are on the same device
1222 | shift_labels = shift_labels.to(shift_logits.device)
1223 | loss_fct = CrossEntropyLoss()
1224 | loss = loss_fct(shift_logits, shift_labels)
1225 |
1226 | if not return_dict:
1227 | output = (logits,) + outputs[1:]
1228 | return (loss,) + output if loss is not None else output
1229 |
1230 | if labels is not None and self.model.training:
1231 | loss += self.aux_loss_coef * outputs.aux_loss.to(loss.device)
1232 |
1233 | return JetMoECausalLMOutputWithPast(
1234 | loss=loss,
1235 | logits=logits,
1236 | past_key_values=outputs.past_key_values,
1237 | hidden_states=outputs.hidden_states,
1238 | attentions=outputs.attentions,
1239 | aux_loss=outputs.aux_loss,
1240 | )
1241 |
1242 | def prepare_inputs_for_generation(
1243 | self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
1244 | ):
1245 | # Omit tokens covered by past_key_values
1246 | if past_key_values is not None:
1247 | if isinstance(past_key_values, Cache):
1248 | cache_length = past_key_values.get_seq_length()
1249 | past_length = past_key_values.seen_tokens
1250 | max_cache_length = past_key_values.get_max_length()
1251 | else:
1252 | cache_length = past_length = past_key_values[0][0].shape[2]
1253 | max_cache_length = None
1254 |
1255 | # Keep only the unprocessed tokens:
1256 | # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
1257 | # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
1258 | # input)
1259 | if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
1260 | input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
1261 | # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
1262 | # input_ids based on the past_length.
1263 | elif past_length < input_ids.shape[1]:
1264 | input_ids = input_ids[:, past_length:]
1265 | # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
1266 |
1267 | # If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
1268 | if (
1269 | max_cache_length is not None
1270 | and attention_mask is not None
1271 | and cache_length + input_ids.shape[1] > max_cache_length
1272 | ):
1273 | attention_mask = attention_mask[:, -max_cache_length:]
1274 |
1275 | position_ids = kwargs.get("position_ids", None)
1276 | if attention_mask is not None and position_ids is None:
1277 | # create position_ids on the fly for batch generation
1278 | position_ids = attention_mask.long().cumsum(-1) - 1
1279 | position_ids.masked_fill_(attention_mask == 0, 1)
1280 | if past_key_values:
1281 | position_ids = position_ids[:, -input_ids.shape[1] :]
1282 |
1283 | # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
1284 | if inputs_embeds is not None and past_key_values is None:
1285 | model_inputs = {"inputs_embeds": inputs_embeds}
1286 | else:
1287 | model_inputs = {"input_ids": input_ids}
1288 |
1289 | model_inputs.update(
1290 | {
1291 | "position_ids": position_ids,
1292 | "past_key_values": past_key_values,
1293 | "use_cache": kwargs.get("use_cache"),
1294 | "attention_mask": attention_mask,
1295 | }
1296 | )
1297 | return model_inputs
1298 |
1299 | @staticmethod
1300 | def _reorder_cache(past_key_values, beam_idx):
1301 | reordered_past = ()
1302 | for layer_past in past_key_values:
1303 | reordered_past += (
1304 | tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
1305 | )
1306 | return reordered_past
1307 |
1308 |
1309 | @add_start_docstrings(
1310 | """
1311 | The JetMoE Model transformer with a sequence classification head on top (linear layer).
1312 |
1313 | [`JetMoEForSequenceClassification`] uses the last token in order to do the classification, as other causal models
1314 | (e.g. GPT-2) do.
1315 |
1316 | Since it does classification on the last token, it needs to know the position of the last token. If a
1317 | `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
1318 | no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
1319 | padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
1320 | each row of the batch).
1321 | """,
1322 | JETMOE_START_DOCSTRING,
1323 | )
1324 | # Copied from transformers.models.llama.modeling_llama.LlamaForSequenceClassification with Llama->JetMoE, LLAMA->JETMOE
1325 | class JetMoEForSequenceClassification(JetMoEPreTrainedModel):
1326 | def __init__(self, config):
1327 | super().__init__(config)
1328 | self.num_labels = config.num_labels
1329 | self.model = JetMoEModel(config)
1330 | self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
1331 |
1332 | # Initialize weights and apply final processing
1333 | self.post_init()
1334 |
1335 | def get_input_embeddings(self):
1336 | return self.model.embed_tokens
1337 |
1338 | def set_input_embeddings(self, value):
1339 | self.model.embed_tokens = value
1340 |
1341 | @add_start_docstrings_to_model_forward(JETMOE_INPUTS_DOCSTRING)
1342 | def forward(
1343 | self,
1344 | input_ids: torch.LongTensor = None,
1345 | attention_mask: Optional[torch.Tensor] = None,
1346 | position_ids: Optional[torch.LongTensor] = None,
1347 | past_key_values: Optional[List[torch.FloatTensor]] = None,
1348 | inputs_embeds: Optional[torch.FloatTensor] = None,
1349 | labels: Optional[torch.LongTensor] = None,
1350 | use_cache: Optional[bool] = None,
1351 | output_attentions: Optional[bool] = None,
1352 | output_hidden_states: Optional[bool] = None,
1353 | return_dict: Optional[bool] = None,
1354 | ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
1355 | r"""
1356 | labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1357 | Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
1358 | config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
1359 | `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1360 | """
1361 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1362 |
1363 | transformer_outputs = self.model(
1364 | input_ids,
1365 | attention_mask=attention_mask,
1366 | position_ids=position_ids,
1367 | past_key_values=past_key_values,
1368 | inputs_embeds=inputs_embeds,
1369 | use_cache=use_cache,
1370 | output_attentions=output_attentions,
1371 | output_hidden_states=output_hidden_states,
1372 | return_dict=return_dict,
1373 | )
1374 | hidden_states = transformer_outputs[0]
1375 | logits = self.score(hidden_states)
1376 |
1377 | if input_ids is not None:
1378 | batch_size = input_ids.shape[0]
1379 | else:
1380 | batch_size = inputs_embeds.shape[0]
1381 |
1382 | if self.config.pad_token_id is None and batch_size != 1:
1383 | raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
1384 | if self.config.pad_token_id is None:
1385 | sequence_lengths = -1
1386 | else:
1387 | if input_ids is not None:
1388 | # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
1389 | sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
1390 | sequence_lengths = sequence_lengths % input_ids.shape[-1]
1391 | sequence_lengths = sequence_lengths.to(logits.device)
1392 | else:
1393 | sequence_lengths = -1
1394 |
1395 | pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
1396 |
1397 | loss = None
1398 | if labels is not None:
1399 | labels = labels.to(logits.device)
1400 | if self.config.problem_type is None:
1401 | if self.num_labels == 1:
1402 | self.config.problem_type = "regression"
1403 | elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
1404 | self.config.problem_type = "single_label_classification"
1405 | else:
1406 | self.config.problem_type = "multi_label_classification"
1407 |
1408 | if self.config.problem_type == "regression":
1409 | loss_fct = MSELoss()
1410 | if self.num_labels == 1:
1411 | loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
1412 | else:
1413 | loss = loss_fct(pooled_logits, labels)
1414 | elif self.config.problem_type == "single_label_classification":
1415 | loss_fct = CrossEntropyLoss()
1416 | loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
1417 | elif self.config.problem_type == "multi_label_classification":
1418 | loss_fct = BCEWithLogitsLoss()
1419 | loss = loss_fct(pooled_logits, labels)
1420 | if not return_dict:
1421 | output = (pooled_logits,) + transformer_outputs[1:]
1422 | return ((loss,) + output) if loss is not None else output
1423 |
1424 | return SequenceClassifierOutputWithPast(
1425 | loss=loss,
1426 | logits=pooled_logits,
1427 | past_key_values=transformer_outputs.past_key_values,
1428 | hidden_states=transformer_outputs.hidden_states,
1429 | attentions=transformer_outputs.attentions,
1430 | )
1431 |
--------------------------------------------------------------------------------
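A minimal usage sketch for modeling_jetmoe.py above (not part of the repository): loading the causal-LM wrapper through the Hugging Face auto classes with trust_remote_code=True so that the JetMoE model and config classes are picked up. The checkpoint id, dtype and prompt below are illustrative assumptions.

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    checkpoint = "jetmoe/jetmoe-8b"  # assumed model id; substitute your own checkpoint
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    model = AutoModelForCausalLM.from_pretrained(
        checkpoint,
        torch_dtype=torch.bfloat16,   # assumed; use float32 on CPU
        trust_remote_code=True,       # load the JetMoE classes shipped with the checkpoint
    )

    inputs = tokenizer("JetMoE is", return_tensors="pt")
    with torch.no_grad():
        output_ids = model.generate(**inputs, max_new_tokens=32)
    print(tokenizer.decode(output_ids[0], skip_special_tokens=True))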
/jetmoe/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .moe import MoE
2 | from .parallel_experts import ParallelExperts
--------------------------------------------------------------------------------
/jetmoe/utils/gate.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | class top_k_gating(nn.Module):
6 | def __init__(
7 | self,
8 | input_size,
9 | num_experts,
10 | top_k,
11 | ):
12 | """
13 | Initialize the top-k gating mechanism.
14 |
15 | Args:
16 | input_size (int): Size of the input.
17 | num_experts (int): Number of experts.
18 | top_k (int): Number of top experts to select.
19 |
20 | The gating network is a single linear layer (without bias) that maps each
21 | token to per-expert logits. The top-k logits are renormalized with a
22 | softmax to produce the gate values, and an auxiliary load-balancing loss
23 | is stored in `self.loss` while the module is in training mode.
24 |
25 | """
26 | super().__init__()
27 |
28 | self.num_experts = num_experts
29 | self.input_size = input_size
30 | assert top_k <= num_experts
31 | self.top_k = top_k
32 |
33 | self.layer = nn.Linear(input_size, num_experts, bias=False)
34 |
35 | def extra_repr(self):
36 | """
37 | Return extra representation string for the module.
38 | """
39 | return "k={}, num_experts={}".format(self.top_k, self.num_experts)
40 |
41 | def compute_aux_loss(self, probs, logits, gates):
42 | """
43 | Calculate the auxiliary load-balancing loss (switch loss plus 0.1 * z-loss) for the batch.
44 |
45 | Args:
46 | probs (torch.Tensor): Softmax probabilities over all experts, shape [batch_size, num_experts].
47 | logits (torch.Tensor): Raw gating logits of the same shape.
48 | gates (torch.Tensor): Sparse gate values (zero for unselected experts), same shape.
49 | Returns: torch.Tensor, the calculated auxiliary loss.
50 | """
51 | count = logits.size(0)
52 | probs = probs.sum(0)
53 | freq = (gates > 0).float().sum(0)
54 | lsesq = (torch.log(torch.exp(logits).sum(dim=-1)) ** 2).sum()
55 |
56 | switchloss = self.num_experts * (F.normalize(probs, p=1, dim=0) * F.normalize(freq, p=1, dim=0)).sum()
57 | zloss = lsesq / count
58 | loss = switchloss + 0.1 * zloss
59 |
60 | return loss
61 |
62 | def forward(self, x):
63 | """
64 | Compute the top-k gating for the input.
65 |
66 | See paper: https://arxiv.org/abs/1701.06538.
67 |
68 | Args:
69 | x (torch.Tensor): Input tensor with shape [batch_size, input_size].
70 |
71 | Returns:
72 | torch.Tensor: Indices of the selected top-k experts, shape [batch_size, top_k].
73 | torch.Tensor: Softmax-normalized gating values for the selected experts,
74 | shape [batch_size, top_k].
75 |
76 | Note:
77 | During training, the auxiliary load-balancing loss for this batch is
78 | computed and stored in `self.loss`; at inference time `self.loss` is
79 | set to 0.
80 |
81 | """
82 |
83 | logits = self.layer(x).float()
84 | top_k_logits, top_k_indices = logits.topk(self.top_k, dim=1)
85 | top_k_gates = torch.softmax(top_k_logits, dim=1).type_as(x)
86 |
87 | if self.training:
88 | probs = torch.softmax(logits, dim=1)
89 | zeros = torch.zeros_like(probs)
90 | zeros = zeros.to(top_k_gates.dtype) # Convert zeros to match top_k_gates dtype
91 | gates = zeros.scatter(1, top_k_indices, top_k_gates)
92 | self.loss = self.compute_aux_loss(probs, logits, gates)
93 | else:
94 | self.loss = 0
95 |
96 | return top_k_indices, top_k_gates
--------------------------------------------------------------------------------
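A small sketch (random weights, CPU, arbitrary sizes) of what top_k_gating returns: per-token indices of the selected experts and their softmax-normalized gate values, with the auxiliary load-balancing loss stored on self.loss while the module is in training mode.

    import torch
    from jetmoe.utils.gate import top_k_gating

    gate = top_k_gating(input_size=16, num_experts=8, top_k=2)
    x = torch.randn(4, 16)              # 4 tokens with hidden size 16

    gate.train()
    indices, gates = gate(x)
    print(indices.shape, gates.shape)   # torch.Size([4, 2]) torch.Size([4, 2])
    print(gates.sum(dim=1))             # the gate values of each token sum to 1
    print(gate.loss)                    # switch loss + 0.1 * z-loss (scalar tensor)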
/jetmoe/utils/moe.py:
--------------------------------------------------------------------------------
1 | from typing import List
2 |
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 |
7 | from .parallel_experts import ParallelExperts, compute_gating
8 |
9 | from .gate import top_k_gating
10 |
11 |
12 | class MoE(nn.Module):
13 | """
14 | A Sparsely gated mixture of experts layer with 1-layer Feed-Forward networks as experts.
15 |
16 | Args:
17 | input_size: integer - size of the input
18 | hidden_size: integer - size of the expert's hidden layer
19 | num_experts: an integer - number of experts
20 | top_k: an integer - how many experts to use for each batch element
21 | bias: a boolean - whether to include bias in linear layers
22 | activation: an activation function applied to each expert's hidden layer
23 | glu: a boolean - whether to use a GLU (gated linear unit) activation
24 | """
25 |
26 | def __init__(
27 | self,
28 | input_size,
29 | hidden_size,
30 | num_experts,
31 | top_k,
32 | bias=True,
33 | activation=None,
34 | glu=True,
35 | ):
36 | super(MoE, self).__init__()
37 |
38 | self.num_experts = num_experts
39 | self.input_size = input_size
40 | self.hidden_size = hidden_size
41 | self.glu = glu
42 | if bias:
43 | self.bias = torch.nn.Parameter(torch.empty(input_size))
44 | torch.nn.init.zeros_(self.bias)
45 | else:
46 | self.bias = None
47 |
48 | self.input_linear = ParallelExperts(num_experts, input_size, hidden_size * 2 if glu else hidden_size)
49 | self.output_linear = ParallelExperts(num_experts, hidden_size, input_size)
50 |
51 | self.top_k = min(top_k, self.num_experts)
52 | self.activation = activation
53 |
54 | self.router = top_k_gating(
55 | input_size=input_size,
56 | num_experts=num_experts,
57 | top_k=top_k,
58 | )
59 |
60 | def extra_repr(self):
61 | return "k={}, e={}".format(self.top_k, self.num_experts)
62 |
63 | def get_aux_loss_and_clear(self):
64 | """
65 | Get the auxiliary load-balancing loss from the router. The router recomputes
66 | (and thereby effectively clears) this value on every forward pass.
67 |
68 | Returns:
69 | torch.Tensor: Auxiliary loss of the most recent forward pass.
70 | """
71 | return self.router.loss
72 |
73 | def compute_gate(self, x):
74 | top_k_indices, self.top_k_gates = self.router(x)
75 |
76 | self.batch_gates, self.batch_index, expert_size, self.index_sorted_experts = compute_gating(
77 | self.top_k, self.num_experts, self.top_k_gates, top_k_indices
78 | )
79 | self.expert_size = expert_size.tolist()
80 |
81 | return self.router.loss
82 |
83 | def batch_forward(self, x):
84 | """
85 | Forward pass of the mixture of experts layer.
86 |
87 | Args:
88 | x (Tensor): Input tensor of shape [batch_size, seq_len, input_size].
89 |
90 | Returns:
91 | Tensor: Output tensor of shape [batch_size, seq_len, input_size].
92 | Tensor: Auxiliary gating (load-balancing) loss.
93 |
94 | Tokens are routed to their top-k experts, processed in per-expert batches,
95 | weighted by their gate values, and scattered back to their original positions.
96 | """
97 | bsz, length, emb_size = x.size()
98 | x = x.reshape(-1, emb_size)
99 | loss = self.compute_gate(x)
100 |
101 | expert_inputs = x[self.batch_index]
102 | h = self.input_linear(expert_inputs, self.expert_size)
103 | if self.glu:
104 | h, g = h.chunk(2, dim=-1)
105 | h = self.activation(h) * g
106 | else:
107 | h = self.activation(h)
108 | expert_outputs = self.output_linear(h, self.expert_size)
109 |
110 | expert_outputs = expert_outputs * self.batch_gates[:, None]
111 |
112 | zeros = torch.zeros((bsz * length, self.input_size), dtype=expert_outputs.dtype, device=expert_outputs.device)
113 | y = zeros.index_add(0, self.batch_index, expert_outputs)
114 | y = y.view(bsz, length, self.input_size)
115 | if self.bias is not None:
116 | y = y + self.bias
117 | return y, loss
118 |
119 | def single_forward(self, x):
120 | bsz, length, emb_size = x.size()
121 |
122 | x = x.reshape(1, self.input_size)
123 | top_k_indices, top_k_gates = self.router(x)
124 | loss = self.router.loss
125 |
126 | y_list = []
127 | for i in range(self.top_k):
128 | expert_idx = top_k_indices[0, i]
129 |
130 | h = F.linear(x, self.input_linear.weight[expert_idx])
131 | if self.glu:
132 | h, g = h.chunk(2, dim=-1)
133 | h = self.activation(h) * g
134 | else:
135 | h = self.activation(h)
136 | y = F.linear(h, self.output_linear.weight[expert_idx]) * top_k_gates[0, i]
137 |
138 | y_list.append(y)
139 |
140 | y = sum(y_list)
141 | y = y.view(bsz, length, self.input_size)
142 | if self.bias is not None:
143 | y = y + self.bias
144 | return y, loss
145 |
146 | def forward(self, x):
147 | """
148 | Forward pass of the mixture of experts layer.
149 |
150 | Args:
151 | x (Tensor): Input tensor.
152 |
153 | Returns:
154 | Tuple[Tensor, Tensor]: Output tensor with the same shape as `x`, and the auxiliary gating loss.
155 | """
156 | bsz, length, emb_size = x.size()
157 | if bsz * length == 1:
158 | return self.single_forward(x)
159 | else:
160 | return self.batch_forward(x)
161 |
162 | def single_map(self, x):
163 | bsz, length, emb_size = x.size()
164 |
165 | x = x.reshape(1, self.input_size)
166 | self.top_k_indices, self.top_k_gates = self.router(x)
167 | loss = self.router.loss
168 |
169 | y_list = []
170 | for i in range(self.top_k):
171 | expert_idx = self.top_k_indices[0, i]
172 | y = F.linear(x, self.input_linear.weight[expert_idx])
173 | y_list.append(y)
174 | y = torch.cat(y_list, dim=0)
175 | y = y.view(bsz, length, self.top_k, -1)
176 | return y, loss
177 |
178 | def batch_map(self, x):
179 | """
180 |
181 | Args:
182 | x: tensor shape [batch_size, input_size]
183 | train: a boolean scalar.
184 | loss_coef: a scalar - multiplier on load-balancing losses
185 |
186 | Returns:
187 | y: a tensor with shape [batch_size, output_size].
188 | extra_training_loss: a scalar. This should be added into the overall
189 | training loss of the model. The backpropagation of this loss
190 | encourages all experts to be approximately equally used across a batch.
191 | """
192 | """
193 | Map input through the mixture of experts layer.
194 |
195 | Args:
196 | x (Tensor): Input tensor.
197 | skip_mask (Tensor): Skip mask tensor.
198 | sample_topk (int): Number of experts to sample during training.
199 | return_indices (bool): Whether to return expert indices.
200 |
201 | Returns:
202 | Tensor: Output tensor.
203 | float: Gating loss.
204 | """
205 | bsz, length, emb_size = x.size()
206 | x = x.reshape(-1, emb_size)
207 | loss = self.compute_gate(x)
208 |
209 | expert_inputs = x[self.batch_index]
210 | expert_outputs = self.input_linear(expert_inputs, self.expert_size)
211 |
212 | zeros = torch.zeros(
213 | (bsz * length * self.top_k, self.hidden_size), dtype=expert_outputs.dtype, device=expert_outputs.device
214 | )
215 | y = zeros.index_add(0, self.index_sorted_experts, expert_outputs)
216 | y = y.view(bsz, length, self.top_k, -1)
217 | return y, loss
218 |
219 | def map(self, x):
220 | """
221 | Map input through the mixture of experts layer.
222 |
223 | Args:
224 | x (Tensor): Input tensor.
225 |
226 | Returns:
227 | Tuple[Tensor, Tensor]: Mapped tensor of shape [batch_size, seq_len, top_k, hidden_size], and the gating loss.
228 | """
229 | bsz, length, emb_size = x.size()
230 | if bsz * length == 1:
231 | return self.single_map(x)
232 | else:
233 | return self.batch_map(x)
234 |
235 | def single_reduce(self, x):
236 | bsz, length, k, emb_size = x.size()
237 |
238 | x = x.reshape(k, emb_size)
239 |
240 | y_list = []
241 | for i in range(self.top_k):
242 | expert_idx = self.top_k_indices[0, i]
243 | y = F.linear(x[i], self.output_linear.weight[expert_idx]) * self.top_k_gates[0, i]
244 | y_list.append(y)
245 | y = sum(y_list)
246 | y = y.view(bsz, length, self.input_size)
247 | if self.bias is not None:
248 | y = y + self.bias
249 | return y
250 |
251 | def batch_reduce(self, x):
252 | """
253 | Reduce the mapped output.
254 |
255 | Args:
256 | x (Tensor): Mapped output tensor of shape [batch_size, seq_len, top_k, hidden_size].
257 | Gate values from the preceding `map`/`compute_gate` call are applied to the expert outputs.
258 |
259 | Returns:
260 | Tensor: Reduced output tensor.
261 | """
262 |
263 | bsz, length, k, emb_size = x.size()
264 | x = x.reshape(-1, emb_size)
265 |
266 | expert_inputs = x[self.index_sorted_experts]
267 | expert_outputs = self.output_linear(expert_inputs, self.expert_size)
268 |
269 | expert_outputs = expert_outputs * self.batch_gates[:, None]
270 |
271 | zeros = torch.zeros((bsz * length, self.input_size), dtype=expert_outputs.dtype, device=expert_outputs.device)
272 | y = zeros.index_add(0, self.batch_index, expert_outputs)
273 | y = y.view(bsz, length, self.input_size)
274 | if self.bias is not None:
275 | y = y + self.bias
276 | return y
277 |
278 | def reduce(self, x):
279 | """
280 | Reduce the mapped output.
281 |
282 | Args:
283 | x (Tensor): Mapped output tensor.
284 |
285 | Returns:
286 | Tensor: Reduced output tensor.
287 | """
288 | bsz, length, k, emb_size = x.size()
289 | if bsz * length == 1:
290 | return self.single_reduce(x)
291 | else:
292 | return self.batch_reduce(x)
--------------------------------------------------------------------------------
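A small sketch (random weights, CPU, arbitrary sizes) of the two ways the MoE layer is used: forward for the sparse feed-forward path, and map/reduce for projecting into per-expert space and back, as the attention module in modeling_jetmoe.py does with self.experts.reduce(...). glu=False is chosen here only so that the same instance can demonstrate all three calls; the feed-forward experts in the model take glu from the config.

    import torch
    import torch.nn.functional as F
    from jetmoe.utils.moe import MoE

    moe = MoE(input_size=32, hidden_size=64, num_experts=8, top_k=2,
              bias=True, activation=F.silu, glu=False)

    x = torch.randn(2, 5, 32)        # [batch, seq_len, hidden]

    y, aux_loss = moe(x)             # sparse feed-forward over the top-2 experts per token
    print(y.shape)                   # torch.Size([2, 5, 32])

    mapped, aux_loss = moe.map(x)    # per-expert input projection
    print(mapped.shape)              # torch.Size([2, 5, 2, 64])

    reduced = moe.reduce(mapped)     # gated output projection back to the input space
    print(reduced.shape)             # torch.Size([2, 5, 32])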
/jetmoe/utils/parallel_experts.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 |
6 | @torch.jit.script
7 | def compute_gating(k: int, num_experts: int, top_k_gates: torch.Tensor, top_k_indices: torch.Tensor):
8 | """
9 | Compute gating values for the mixture of experts based on probabilities and top-k indices.
10 |
11 | Args:
12 | k (int): Number of experts to select.
13 | num_experts (int): Total number of experts.
14 | top_k_gates (torch.Tensor): Gating values for top-k experts (batch_size x k).
15 | top_k_indices (torch.Tensor): Indices of top-k experts (batch_size x k).
16 |
17 | Returns:
18 | torch.Tensor: Batch-level gating values.
19 | torch.Tensor: Batch-level expert indices.
20 | torch.Tensor: Expert size for each expert.
21 | torch.Tensor: Sorted indices of top-k experts.
22 | """
23 | zeros = torch.zeros([top_k_gates.size(0), num_experts], dtype=top_k_gates.dtype, device=top_k_gates.device)
24 | gates = zeros.scatter(1, top_k_indices, 1)
25 | expert_size = gates.long().sum(0)
26 | top_k_gates = top_k_gates.flatten()
27 | top_k_experts = top_k_indices.flatten()
28 | _, index_sorted_experts = top_k_experts.sort(0)
29 | batch_index = index_sorted_experts.div(k, rounding_mode="trunc")
30 | batch_gates = top_k_gates[index_sorted_experts]
31 | return batch_gates, batch_index, expert_size, index_sorted_experts
32 |
33 |
34 | class ParallelExperts(nn.Module):
35 | def __init__(self, num_experts, input_size, output_size) -> None:
36 | """
37 | Initialize the ParallelExperts module.
38 |
39 | Args:
40 | num_experts (int): Number of experts.
41 | input_size (int): Size of the input.
42 | output_size (int): Size of the output.
43 | Note: experts are bias-free; each expert owns a separate [output_size, input_size] weight.
44 | """
45 | super().__init__()
46 | self.weight = nn.Parameter(torch.empty(num_experts, output_size, input_size))
47 | self.reset_parameters()
48 | self.num_experts = num_experts
49 | self.input_size = input_size
50 | self.output_size = output_size
51 |
52 | def extra_repr(self):
53 | return "num_experts={}, input_size={}, output_size={}".format(
54 | self.num_experts, self.input_size, self.output_size
55 | )
56 |
57 | def reset_parameters(self) -> None:
58 | """
59 | Reset the parameters of the model.
60 | """
61 | nn.init.uniform_(self.weight, -1.0 / self.weight.size(1), 1.0 / self.weight.size(1))
62 |
63 | def forward(self, inputs, expert_size):
64 | """
65 | Forward pass of the ParallelExperts module.
66 |
67 | Args:
68 | inputs (Tensor): Input tensor.
69 | expert_size (List[int]): Number of tokens routed to each expert; must sum to inputs.size(0).
70 |
71 | Returns:
72 | Tensor: Output tensor.
73 | """
74 | input_list = inputs.split(expert_size, dim=0)
75 | output_list = []
76 | for i in range(self.num_experts):
77 | output_list.append(F.linear(input_list[i], self.weight[i]))
78 | results = torch.cat(output_list, dim=0)
79 | return results
--------------------------------------------------------------------------------
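A toy sketch (arbitrary sizes) of the routing bookkeeping above: compute_gating flattens the per-token (expert, gate) assignments into an expert-sorted layout so that ParallelExperts can apply one F.linear per expert to a contiguous slice of the gathered inputs.

    import torch
    from jetmoe.utils.parallel_experts import ParallelExperts, compute_gating

    num_experts, top_k, tokens = 4, 2, 3
    top_k_indices = torch.tensor([[0, 2], [2, 3], [0, 1]])   # experts chosen for each token
    top_k_gates = torch.full((tokens, top_k), 0.5)           # their gate values

    batch_gates, batch_index, expert_size, sorted_idx = compute_gating(
        top_k, num_experts, top_k_gates, top_k_indices
    )
    print(expert_size)      # tokens per expert: tensor([2, 1, 2, 1])
    print(batch_index)      # source token of each routed slot, grouped by expert

    experts = ParallelExperts(num_experts, input_size=8, output_size=6)
    x = torch.randn(tokens, 8)
    expert_inputs = x[batch_index]                           # gather token copies, expert by expert
    expert_outputs = experts(expert_inputs, expert_size.tolist())
    print(expert_outputs.shape)                              # torch.Size([6, 6])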
/resources/1-myshell-mit.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/myshell-ai/JetMoE/0a9bc5a32386af9ea7abe9e32298976a94da908a/resources/1-myshell-mit.png
--------------------------------------------------------------------------------
/resources/2-performance.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/myshell-ai/JetMoE/0a9bc5a32386af9ea7abe9e32298976a94da908a/resources/2-performance.png
--------------------------------------------------------------------------------
/resources/3-architecture.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/myshell-ai/JetMoE/0a9bc5a32386af9ea7abe9e32298976a94da908a/resources/3-architecture.png
--------------------------------------------------------------------------------
/resources/4-phase1-data.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/myshell-ai/JetMoE/0a9bc5a32386af9ea7abe9e32298976a94da908a/resources/4-phase1-data.png
--------------------------------------------------------------------------------
/resources/5-phase2-data.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/myshell-ai/JetMoE/0a9bc5a32386af9ea7abe9e32298976a94da908a/resources/5-phase2-data.png
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup, find_packages
2 |
3 | setup(name='jetmoe',
4 | packages=find_packages(),
5 | install_requires=[
6 | 'torch',
7 | 'transformers',
8 | 'scattermoe @ git+https://github.com/shawntan/scattermoe@main#egg=scattermoe'
9 | ])
--------------------------------------------------------------------------------
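A minimal post-install check (a sketch; assumes `pip install -e .` completed and the package's top-level imports resolve in your environment), exercising the MoE layer re-exported by jetmoe/utils/__init__.py:

    import torch
    import torch.nn.functional as F
    from jetmoe.utils import MoE

    layer = MoE(input_size=8, hidden_size=16, num_experts=4, top_k=2, activation=F.silu)
    out, aux_loss = layer(torch.randn(1, 3, 8))
    print(out.shape, float(aux_loss))   # torch.Size([1, 3, 8]) plus the load-balancing loss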