├── .gitignore
├── LICENSE
├── README.md
├── jetmoe
│   ├── __init__.py
│   ├── configuration_jetmoe.py
│   ├── modeling_jetmoe.py
│   └── utils
│       ├── __init__.py
│       ├── gate.py
│       ├── moe.py
│       └── parallel_experts.py
├── resources
│   ├── 1-myshell-mit.png
│   ├── 2-performance.png
│   ├── 3-architecture.png
│   ├── 4-phase1-data.png
│   └── 5-phase2-data.png
└── setup.py

/.gitignore:
--------------------------------------------------------------------------------
1 | *.pyc
2 | *.egg-info
3 | build/
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner.
For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # JetMoE: Reaching LLaMA2 Performance with 0.1M Dollars 2 | 3 |
8 |
9 | ## Key Messages
10 |
11 | 1. JetMoE-8B is **trained at a cost of under $0.1 million**<sup>1</sup> **but outperforms LLaMA2-7B from Meta AI**, which has multi-billion-dollar training resources. LLM training can be **much cheaper than people previously thought**.
12 |
13 | 2. JetMoE-8B is **fully open-sourced and academia-friendly** because:
14 |     - It **only uses public datasets** for training, and the code is open-sourced. No proprietary resource is needed.
15 |     - It **can be finetuned with a very limited compute budget** (e.g., a consumer-grade GPU) that most labs can afford.
16 |
17 | 3. JetMoE-8B **only has 2.2B active parameters** during inference, which drastically lowers the computational cost. Compared to a model with similar inference computation, like Gemma-2B, JetMoE-8B achieves consistently better performance.
18 |
19 | <sup>1</sup> We used a 96×H100 GPU cluster for 2 weeks, which cost ~$0.08 million.
20 |
21 | Website: [https://research.myshell.ai/jetmoe](https://research.myshell.ai/jetmoe)
22 |
23 | HuggingFace: [https://huggingface.co/jetmoe/jetmoe-8b](https://huggingface.co/jetmoe/jetmoe-8b)
24 |
25 | Online Demo on Lepton AI: [https://www.lepton.ai/playground/chat?model=jetmoe-8b-chat](https://www.lepton.ai/playground/chat?model=jetmoe-8b-chat)
26 |
27 | Technical Report: [https://arxiv.org/pdf/2404.07413.pdf](https://arxiv.org/pdf/2404.07413.pdf)
28 |
29 | ## Authors
30 |
31 | The project is contributed by [Yikang Shen](https://scholar.google.com.hk/citations?user=qff5rRYAAAAJ), [Zhen Guo](https://zguo0525.github.io/), [Tianle Cai](https://www.tianle.website/#/) and [Zengyi Qin](https://www.qinzy.tech/). For technical inquiries, please contact [Yikang Shen](https://scholar.google.com.hk/citations?user=qff5rRYAAAAJ). For media and collaboration inquiries, please contact [Zengyi Qin](https://www.qinzy.tech/).
32 |
33 | ## Collaboration
34 | **If you have great ideas but need more resources (GPU, data, funding, etc.)**, you are welcome to contact **MyShell.ai** via [Zengyi Qin](https://www.qinzy.tech/). **MyShell.ai** is open to collaborations and actively supports high-quality open-source projects.
35 |
36 | ## Benchmarks
37 | We use the same evaluation methodology as in the Open LLM Leaderboard. For the MBPP code benchmark, we use the same evaluation methodology as in the LLaMA2 and DeepseekMoE papers. The results are shown below:
38 |
39 | |Model|Active Params|Training Tokens|Open LLM Leaderboard Avg|ARC|Hellaswag|MMLU|TruthfulQA|WinoGrande|GSM8k|MBPP|HumanEval|
40 | |---|---|---|---|---|---|---|---|---|---|---|---|
41 | |Shot||||25|10|5|0|5|5|3|0|
42 | |Metric||||acc_norm|acc_norm|acc|mc2|acc|acc|Pass@1|Pass@1|
43 | |LLaMA2-7B|7B|2T|51.0|53.1|78.6|46.9|38.8|74|14.5|20.8|12.8|
44 | |LLaMA-13B|13B|1T|51.4|**56.2**|**80.9**|47.7|39.5|**76.2**|7.6|22.0|15.8|
45 | |DeepseekMoE-16B|2.8B|2T|51.1|53.2|79.8|46.3|36.1|73.7|17.3|34.0|**25.0**|
46 | |Gemma-2B|2B|2T|46.4|48.4|71.8|41.8|33.1|66.3|16.9|28.0|24.4|
47 | |JetMoE-8B|2.2B|1.25T|**53.0**|48.7|80.5|**49.2**|**41.7**|70.2|**27.8**|**34.2**|14.6|
48 |
49 | | Model | MT-Bench Score |
50 | |---------------------|-----------|
51 | | GPT-4 | 9.014 |
52 | | GPT-3.5-turbo | 7.995 |
53 | | Claude-v1 | 7.923 |
54 | | **JetMoE-8B-chat** | **6.681** |
55 | | Llama-2-13b-chat | 6.650 |
56 | | Vicuna-13b-v1.3 | 6.413 |
57 | | Wizardlm-13b | 6.353 |
58 | | Llama-2-7b-chat | 6.269 |
59 |
60 | To our surprise, despite the lower training cost and computation, JetMoE-8B performs even better than LLaMA2-7B, LLaMA-13B, and DeepseekMoE-16B. Compared to a model with similar training and inference computation, like Gemma-2B, JetMoE-8B achieves better performance.
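The "Active Params" column above counts only the parameters that are actually executed for each token. As an illustrative sanity check (a sketch that mirrors the loading code in the Model Usage section below; downloading the checkpoint is required), the total parameter count of JetMoE-8B is roughly 8B, while only about 2.2B are active per token because each layer routes a token to `moe_top_k=2` of its `moe_num_experts=8` experts:

```python
from transformers import AutoConfig, AutoModelForCausalLM
from jetmoe import JetMoEConfig, JetMoEForCausalLM

AutoConfig.register("jetmoe", JetMoEConfig)
AutoModelForCausalLM.register(JetMoEConfig, JetMoEForCausalLM)

# Counts all parameters stored in the checkpoint (~8B). Only ~2.2B of them are
# active for any given token because of the top-2-of-8 expert routing used in
# both the attention and MLP mixture-of-experts layers.
model = AutoModelForCausalLM.from_pretrained("jetmoe/jetmoe-8b")
total_params = sum(p.numel() for p in model.parameters())
print(f"Total parameters: {total_params / 1e9:.2f}B")
```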
61 |
62 | ## Model Usage
63 | To load the models, you need to install this package:
64 | ```
65 | pip install -e .
66 | ```
67 |
68 | Then you can load the model with the following code:
69 | ```python
70 | from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig, AutoModelForSequenceClassification
71 | from jetmoe import JetMoEForCausalLM, JetMoEConfig, JetMoEForSequenceClassification
72 |
73 | AutoConfig.register("jetmoe", JetMoEConfig)
74 | AutoModelForCausalLM.register(JetMoEConfig, JetMoEForCausalLM)
75 | AutoModelForSequenceClassification.register(JetMoEConfig, JetMoEForSequenceClassification)
76 |
77 | tokenizer = AutoTokenizer.from_pretrained('jetmoe/jetmoe-8b')
78 | model = AutoModelForCausalLM.from_pretrained('jetmoe/jetmoe-8b')
79 | ```
80 |
81 | ## Model Details
82 | Please refer to the technical report [https://arxiv.org/pdf/2404.07413.pdf](https://arxiv.org/pdf/2404.07413.pdf) for model and training details.
83 |
84 | ## Acknowledgement
85 | We express our gratitude to [Shengding Hu](https://shengdinghu.github.io/) for his valuable advice on the Phase 2 data mixture. We also express our gratitude to [Exabits](https://www.exabits.ai/) for their assistance in setting up the GPU clusters.
86 |
--------------------------------------------------------------------------------
/jetmoe/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 JetMoE AI and The HuggingFace Inc. team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | #     http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
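# Usage sketch (illustrative, mirroring the README): the classes exposed by this package are
# intended to be registered with the `transformers` Auto* factories, after which the standard
# generation API can be used. The prompt and generation arguments below are examples only.
#
#     from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
#     from jetmoe import JetMoEConfig, JetMoEForCausalLM
#
#     AutoConfig.register("jetmoe", JetMoEConfig)
#     AutoModelForCausalLM.register(JetMoEConfig, JetMoEForCausalLM)
#
#     tokenizer = AutoTokenizer.from_pretrained("jetmoe/jetmoe-8b")
#     model = AutoModelForCausalLM.from_pretrained("jetmoe/jetmoe-8b")
#     inputs = tokenizer("JetMoE is", return_tensors="pt")
#     outputs = model.generate(**inputs, max_new_tokens=32)
#     print(tokenizer.decode(outputs[0], skip_special_tokens=True))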
14 | from typing import TYPE_CHECKING 15 | 16 | from transformers.utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available 17 | 18 | 19 | _import_structure = { 20 | "configuration_jetmoe": ["JETMOE_PRETRAINED_CONFIG_ARCHIVE_MAP", "JetMoEConfig"], 21 | } 22 | 23 | 24 | try: 25 | if not is_torch_available(): 26 | raise OptionalDependencyNotAvailable() 27 | except OptionalDependencyNotAvailable: 28 | pass 29 | else: 30 | _import_structure["modeling_jetmoe"] = [ 31 | "JetMoEForCausalLM", 32 | "JetMoEModel", 33 | "JetMoEPreTrainedModel", 34 | "JetMoEForSequenceClassification", 35 | ] 36 | 37 | if TYPE_CHECKING: 38 | from .configuration_jetmoe import JETMOE_PRETRAINED_CONFIG_ARCHIVE_MAP, JetMoEConfig 39 | 40 | try: 41 | if not is_torch_available(): 42 | raise OptionalDependencyNotAvailable() 43 | except OptionalDependencyNotAvailable: 44 | pass 45 | else: 46 | from .modeling_jetmoe import ( 47 | JetMoEForCausalLM, 48 | JetMoEForSequenceClassification, 49 | JetMoEModel, 50 | JetMoEPreTrainedModel, 51 | ) 52 | 53 | else: 54 | import sys 55 | 56 | sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__) 57 | -------------------------------------------------------------------------------- /jetmoe/configuration_jetmoe.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 JetMoE AI and the HuggingFace Inc. team. All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """JetMoE model configuration""" 16 | 17 | from transformers.configuration_utils import PretrainedConfig 18 | from transformers.utils import logging 19 | 20 | 21 | logger = logging.get_logger(__name__) 22 | 23 | 24 | class JetMoEConfig(PretrainedConfig): 25 | r""" 26 | This is the configuration class to store the configuration of a [`JetMoEModel`]. It is used to instantiate an 27 | JetMoE model according to the specified arguments, defining the model architecture. Instantiating a configuration 28 | with the defaults will yield a configuration of the JetMoE-4B. 29 | 30 | [jetmoe/jetmoe-8b](https://huggingface.co/jetmoe/jetmoe-8b) 31 | 32 | Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the 33 | documentation from [`PretrainedConfig`] for more information. 34 | 35 | 36 | Args: 37 | vocab_size (`int`, *optional*, defaults to 32000): 38 | Vocabulary size of the JetMoE model. Defines the number of different tokens that can be represented by the 39 | `inputs_ids` passed when calling [`JetMoEModel`] 40 | hidden_size (`int`, *optional*, defaults to 2048): 41 | Dimension of the hidden representations. 42 | num_hidden_layers (`int`, *optional*, defaults to 12): Defines the number of blocks. 43 | num_attention_heads (`int`, *optional*, defaults to 32): 44 | Number of attention heads for each attention layer in the Transformer encoder. 
45 |         num_key_value_heads (`int`, *optional*, defaults to 16):
46 |             Number of attention heads for each key and value in the Transformer encoder.
47 |         kv_channels (`int`, *optional*, defaults to 128): Defines the number of channels for the key and value tensors.
48 |         ffn_hidden_size (`int`, *optional*, defaults to 5632): Defines the hidden size of the feed-forward layer.
49 |         max_position_embeddings (`int`, *optional*, defaults to 4096):
50 |             The maximum sequence length that this model might ever be used with. The default JetMoE-8B
51 |             configuration uses a context length of 4096 tokens.
52 |         activation_function (`string`, *optional*, defaults to `"silu"`): Defines the activation function for MLP experts.
53 |         glu (`bool`, *optional*, defaults to `True`): Whether to use Gated Linear Units in the MLP experts.
54 |         moe_num_experts (`int`, *optional*, defaults to 8): Defines the number of experts in the mixture of experts.
55 |         moe_top_k (`int`, *optional*, defaults to 2): Defines the number of experts to use for each token.
56 |         use_cache (`bool`, *optional*, defaults to `True`):
57 |             Whether or not the model should return the last key/values attentions (not used by all models). Only
58 |             relevant if `config.is_decoder=True`.
59 |         bos_token_id (`int`, *optional*, defaults to 1):
60 |             The id of the "beginning-of-sequence" token.
61 |         eos_token_id (`int`, *optional*, defaults to 2):
62 |             The id of the "end-of-sequence" token.
63 |         tie_word_embeddings (`bool`, *optional*, defaults to `True`):
64 |             Whether the model's input and output word embeddings should be tied.
65 |         bias (`bool`, *optional*, defaults to `True`): Whether to use bias in the feed-forward and attention layer.
66 |         rope_theta (`float`, *optional*, defaults to 10000.0):
67 |             The base period of the RoPE embeddings.
68 |         rms_norm_eps (`float`, *optional*, defaults to 1e-06):
69 |             The epsilon used by the rms normalization layers.
70 |         initializer_range (`float`, *optional*, defaults to 0.01):
71 |             The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
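        Note:
            The attention implementation in `modeling_jetmoe.py` asserts that `num_attention_heads ==
            num_key_value_heads * moe_top_k` (32 == 16 * 2 with the defaults), so these three values
            should be changed together.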
72 | 73 | ```python 74 | >>> from transformers import JetMoEModel, JetMoEConfig 75 | 76 | >>> # Initializing a JetMoE 4B style configuration 77 | >>> configuration = JetMoEConfig() 78 | 79 | >>> # Initializing a model from the JetMoE 4B style configuration 80 | >>> model = JetMoEModel(configuration) 81 | 82 | >>> # Accessing the model configuration 83 | >>> configuration = model.config 84 | ```""" 85 | 86 | model_type = "jetmoe" 87 | keys_to_ignore_at_inference = ["past_key_values"] 88 | 89 | def __init__( 90 | self, 91 | vocab_size=32000, 92 | hidden_size=2048, 93 | num_hidden_layers=12, 94 | num_attention_heads=32, 95 | num_key_value_heads=16, 96 | kv_channels=128, 97 | ffn_hidden_size=5632, 98 | max_position_embeddings=4096, 99 | activation_function="silu", 100 | glu=True, 101 | moe_num_experts=8, 102 | moe_top_k=2, 103 | use_cache=True, 104 | bos_token_id=1, 105 | eos_token_id=2, 106 | tie_word_embeddings=True, 107 | bias=True, 108 | rope_theta=10000.0, 109 | rms_norm_eps=1e-6, 110 | initializer_range=0.01, 111 | **kwargs, 112 | ): 113 | self.vocab_size = vocab_size 114 | self.hidden_size = hidden_size 115 | self.num_hidden_layers = num_hidden_layers 116 | self.num_attention_heads = num_attention_heads 117 | self.num_key_value_heads = num_key_value_heads 118 | self.kv_channels = kv_channels 119 | self.ffn_hidden_size = ffn_hidden_size 120 | self.max_position_embeddings = max_position_embeddings 121 | self.activation_function = activation_function 122 | self.glu = glu 123 | self.moe_num_experts = moe_num_experts 124 | self.moe_top_k = moe_top_k 125 | self.use_cache = use_cache 126 | self.initializer_range = initializer_range 127 | 128 | self.bos_token_id = bos_token_id 129 | self.eos_token_id = eos_token_id 130 | 131 | self.bias = bias 132 | self.rope_theta = rope_theta 133 | self.rms_norm_eps = rms_norm_eps 134 | 135 | super().__init__( 136 | bos_token_id=bos_token_id, eos_token_id=eos_token_id, tie_word_embeddings=tie_word_embeddings, **kwargs 137 | ) 138 | -------------------------------------------------------------------------------- /jetmoe/modeling_jetmoe.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 JetMoE AI and the HuggingFace Inc. team. All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
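# Architecture overview (summary of the modules defined in this file): every JetMoEBlock combines
# (1) JetMoEAttention, where the query and output projections are mixture-of-experts layers
#     (`self.experts.map(...)` / `self.experts.reduce(...)`) while a single shared `kv_proj`
#     linear layer produces the keys and values, and
# (2) a mixture-of-experts feed-forward network (`self.mlp = MoE(...)`).
# Both MoE layers also return an auxiliary loss, which is accumulated and exposed as `aux_loss`
# on the output dataclasses defined below.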
15 | """PyTorch JetMoE model.""" 16 | 17 | import math 18 | import warnings 19 | from typing import List, Optional, Tuple, Union 20 | 21 | import torch 22 | import torch.utils.checkpoint 23 | from torch import nn 24 | from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss 25 | from torch.nn import functional as F 26 | 27 | from transformers.activations import ACT2FN 28 | from transformers.cache_utils import Cache, DynamicCache 29 | from transformers.modeling_attn_mask_utils import ( 30 | _prepare_4d_causal_attention_mask, 31 | _prepare_4d_causal_attention_mask_for_sdpa, 32 | ) 33 | from transformers.modeling_outputs import ( 34 | BaseModelOutputWithPast, 35 | CausalLMOutputWithPast, 36 | SequenceClassifierOutputWithPast, 37 | dataclass, 38 | ) 39 | from transformers.modeling_utils import PreTrainedModel 40 | from transformers.utils import ( 41 | add_start_docstrings, 42 | add_start_docstrings_to_model_forward, 43 | is_flash_attn_2_available, 44 | is_flash_attn_greater_or_equal_2_10, 45 | logging, 46 | replace_return_docstrings, 47 | ) 48 | from .configuration_jetmoe import JetMoEConfig 49 | from .utils import MoE, ParallelExperts 50 | 51 | 52 | if is_flash_attn_2_available(): 53 | from flash_attn import flash_attn_func, flash_attn_varlen_func 54 | from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa 55 | 56 | logger = logging.get_logger(__name__) 57 | 58 | _CHECKPOINT_FOR_DOC = "jetmoe" 59 | _CONFIG_FOR_DOC = "JetMoEConfig" 60 | 61 | 62 | @dataclass 63 | class JetMoEBaseModelOutputWithPast(BaseModelOutputWithPast): 64 | """ 65 | Base class for model's outputs that may also contain a past key/values (to speed up sequential decoding). 66 | 67 | Args: 68 | last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): 69 | Sequence of hidden-states at the output of the last layer of the model. 70 | 71 | If `past_key_values` is used only the last hidden-state of the sequences of shape `(batch_size, 1, 72 | hidden_size)` is output. 73 | past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): 74 | Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape 75 | `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and optionally if 76 | `config.is_encoder_decoder=True` 2 additional tensors of shape `(batch_size, num_heads, 77 | encoder_sequence_length, embed_size_per_head)`. 78 | 79 | Contains pre-computed hidden-states (key and values in the self-attention blocks and optionally if 80 | `config.is_encoder_decoder=True` in the cross-attention blocks) that can be used (see `past_key_values` 81 | input) to speed up sequential decoding. 82 | hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): 83 | Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + 84 | one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. 85 | 86 | Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. 87 | attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): 88 | Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, 89 | sequence_length)`. 
90 | 91 | Attentions weights after the attention softmax, used to compute the weighted average in the self-attention 92 | heads. 93 | """ 94 | 95 | last_hidden_state: torch.FloatTensor = None 96 | past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None 97 | hidden_states: Optional[Tuple[torch.FloatTensor]] = None 98 | attentions: Optional[Tuple[torch.FloatTensor]] = None 99 | aux_loss: Optional[torch.FloatTensor] = None 100 | 101 | 102 | @dataclass 103 | class JetMoECausalLMOutputWithPast(CausalLMOutputWithPast): 104 | """ 105 | Base class for causal language model (or autoregressive) outputs. 106 | 107 | Args: 108 | loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): 109 | Language modeling loss (for next-token prediction). 110 | logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): 111 | Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). 112 | past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): 113 | Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape 114 | `(batch_size, num_heads, sequence_length, embed_size_per_head)`) 115 | 116 | Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see 117 | `past_key_values` input) to speed up sequential decoding. 118 | hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): 119 | Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + 120 | one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. 121 | 122 | Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. 123 | attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): 124 | Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, 125 | sequence_length)`. 126 | 127 | Attentions weights after the attention softmax, used to compute the weighted average in the self-attention 128 | heads. 129 | """ 130 | 131 | loss: Optional[torch.FloatTensor] = None 132 | logits: torch.FloatTensor = None 133 | past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None 134 | hidden_states: Optional[Tuple[torch.FloatTensor]] = None 135 | attentions: Optional[Tuple[torch.FloatTensor]] = None 136 | aux_loss: Optional[torch.FloatTensor] = None 137 | 138 | 139 | @dataclass 140 | class JetMoESequenceClassifierOutputWithPast(SequenceClassifierOutputWithPast): 141 | """ 142 | Base class for outputs of sentence classification models. 143 | 144 | Args: 145 | loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): 146 | Classification (or regression if config.num_labels==1) loss. 147 | logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels)`): 148 | Classification (or regression if config.num_labels==1) scores (before SoftMax). 
149 | past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): 150 | Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape 151 | `(batch_size, num_heads, sequence_length, embed_size_per_head)`) 152 | 153 | Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see 154 | `past_key_values` input) to speed up sequential decoding. 155 | hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): 156 | Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + 157 | one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. 158 | 159 | Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. 160 | attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): 161 | Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, 162 | sequence_length)`. 163 | 164 | Attentions weights after the attention softmax, used to compute the weighted average in the self-attention 165 | heads. 166 | """ 167 | 168 | loss: Optional[torch.FloatTensor] = None 169 | logits: torch.FloatTensor = None 170 | past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None 171 | hidden_states: Optional[Tuple[torch.FloatTensor]] = None 172 | attentions: Optional[Tuple[torch.FloatTensor]] = None 173 | aux_loss: Optional[torch.FloatTensor] = None 174 | 175 | 176 | # Copied from transformers.models.llama.modeling_llama._get_unpad_data 177 | def _get_unpad_data(attention_mask): 178 | seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) 179 | indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() 180 | max_seqlen_in_batch = seqlens_in_batch.max().item() 181 | cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)) 182 | return ( 183 | indices, 184 | cu_seqlens, 185 | max_seqlen_in_batch, 186 | ) 187 | 188 | 189 | class JetMoERMSNorm(nn.Module): 190 | def __init__(self, hidden_size, eps=1e-6): 191 | """ 192 | JetMoERMSNorm module 193 | """ 194 | super().__init__() 195 | self.weight = nn.Parameter(torch.ones(hidden_size)) 196 | self.variance_epsilon = eps 197 | 198 | def forward(self, hidden_states): 199 | input_dtype = hidden_states.dtype 200 | hidden_states = hidden_states.to(torch.float32) 201 | variance = hidden_states.pow(2).mean(-1, keepdim=True) 202 | hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) 203 | return self.weight * hidden_states.to(input_dtype) 204 | 205 | 206 | # copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding 207 | class JetMoERotaryEmbedding(nn.Module): 208 | def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): 209 | super().__init__() 210 | 211 | self.dim = dim 212 | self.max_position_embeddings = max_position_embeddings 213 | self.base = base 214 | inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim)) 215 | self.register_buffer("inv_freq", inv_freq, persistent=False) 216 | 217 | # Build here to make `torch.jit.trace` work. 
218 | self._set_cos_sin_cache( 219 | seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype() 220 | ) 221 | 222 | def _set_cos_sin_cache(self, seq_len, device, dtype): 223 | self.max_seq_len_cached = seq_len 224 | t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq) 225 | 226 | freqs = torch.outer(t, self.inv_freq) 227 | # Different from paper, but it uses a different permutation in order to obtain the same calculation 228 | emb = torch.cat((freqs, freqs), dim=-1) 229 | self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) 230 | self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) 231 | 232 | def forward(self, x, seq_len=None): 233 | # x: [bs, num_attention_heads, seq_len, head_size] 234 | if seq_len > self.max_seq_len_cached: 235 | self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype) 236 | 237 | return ( 238 | self.cos_cached[:seq_len].to(dtype=x.dtype), 239 | self.sin_cached[:seq_len].to(dtype=x.dtype), 240 | ) 241 | 242 | 243 | # Copied from transformers.models.llama.modeling_llama.rotate_half 244 | def rotate_half(x): 245 | """Rotates half the hidden dims of the input.""" 246 | x1 = x[..., : x.shape[-1] // 2] 247 | x2 = x[..., x.shape[-1] // 2 :] 248 | return torch.cat((-x2, x1), dim=-1) 249 | 250 | 251 | # copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb 252 | def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=2): 253 | """Applies Rotary Position Embedding to the query and key tensors. 254 | 255 | Args: 256 | q (`torch.Tensor`): The query tensor. 257 | k (`torch.Tensor`): The key tensor. 258 | cos (`torch.Tensor`): The cosine part of the rotary embedding. 259 | sin (`torch.Tensor`): The sine part of the rotary embedding. 260 | position_ids (`torch.Tensor`): 261 | The position indices of the tokens corresponding to the query and key tensors. For example, this can be 262 | used to pass offsetted position ids when working with a KV-cache. 263 | unsqueeze_dim (`int`, *optional*, defaults to 1): 264 | The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and 265 | sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note 266 | that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and 267 | k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes 268 | cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have 269 | the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. 270 | Returns: 271 | `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. 272 | """ 273 | cos = cos[position_ids].unsqueeze(unsqueeze_dim) 274 | sin = sin[position_ids].unsqueeze(unsqueeze_dim) 275 | q_embed = (q * cos) + (rotate_half(q) * sin) 276 | k_embed = (k * cos) + (rotate_half(k) * sin) 277 | return q_embed, k_embed 278 | 279 | 280 | class JetMoEAttention(nn.Module): 281 | """ 282 | Multi-headed attention from 'Attention Is All You Need' paper. 283 | """ 284 | 285 | def __init__(self, config: JetMoEConfig, layer_idx: Optional[int] = None): 286 | """ 287 | Initialize the JetMoEAttention module. 288 | 289 | Args: 290 | config: Configuration object with model hyperparameters. 
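            layer_idx: Index of this layer in the model. It is used by the key/value cache during
                autoregressive decoding; a warning is emitted below if it is not provided.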
291 | """ 292 | super().__init__() 293 | self.config = config 294 | self.layer_idx = layer_idx 295 | self.is_causal = True 296 | if layer_idx is None: 297 | logger.warning_once( 298 | f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will " 299 | "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` " 300 | "when creating this class." 301 | ) 302 | 303 | self.top_k = config.moe_top_k 304 | 305 | self.kv_projection_size = config.kv_channels * config.num_key_value_heads 306 | self.num_key_value_heads = config.num_key_value_heads 307 | self.num_heads = config.num_attention_heads 308 | assert self.num_heads == self.num_key_value_heads * config.moe_top_k 309 | self.hidden_size_per_attention_head = config.kv_channels 310 | 311 | self.experts = MoE( 312 | input_size=config.hidden_size, 313 | hidden_size=self.kv_projection_size, 314 | num_experts=config.moe_num_experts, 315 | top_k=config.moe_top_k, 316 | glu=False, 317 | ) 318 | 319 | self.kv_proj = torch.nn.Linear(config.hidden_size, self.kv_projection_size * 2, bias=False) 320 | 321 | self.rotary_emb = JetMoERotaryEmbedding( 322 | config.kv_channels, 323 | max_position_embeddings=config.max_position_embeddings, 324 | base=config.rope_theta, 325 | ) 326 | 327 | # def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): 328 | # return tensor.view(bsz, seq_len, self.num_attention_heads, self.hidden_size_per_attention_head).transpose(1, 2).contiguous() 329 | 330 | def forward( 331 | self, 332 | hidden_states: torch.Tensor, 333 | attention_mask: Optional[torch.Tensor] = None, 334 | position_ids: Optional[torch.LongTensor] = None, 335 | past_key_value: Optional[Cache] = None, 336 | output_attentions: bool = False, 337 | use_cache: bool = False, 338 | **kwargs, 339 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: 340 | if "padding_mask" in kwargs: 341 | warnings.warn( 342 | "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" 343 | ) 344 | bsz, q_len, _ = hidden_states.size() 345 | 346 | query_states, aux_loss = self.experts.map(hidden_states) 347 | key_states, value_states = self.kv_proj(hidden_states).chunk(2, dim=-1) 348 | 349 | query_states = query_states.view(bsz, q_len, self.num_heads, self.hidden_size_per_attention_head).transpose( 350 | 1, 2 351 | ) 352 | key_states = key_states.view( 353 | bsz, q_len, self.num_key_value_heads, self.hidden_size_per_attention_head 354 | ).transpose(1, 2) 355 | value_states = value_states.view( 356 | bsz, q_len, self.num_key_value_heads, self.hidden_size_per_attention_head 357 | ).transpose(1, 2) 358 | 359 | kv_seq_len = key_states.shape[2] 360 | if past_key_value is not None: 361 | if self.layer_idx is None: 362 | raise ValueError( 363 | f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} " 364 | "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " 365 | "with a layer index." 
366 | ) 367 | kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) 368 | cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) 369 | query_states, key_states = apply_rotary_pos_emb( 370 | query_states, key_states, cos, sin, position_ids, unsqueeze_dim=1 371 | ) 372 | 373 | if past_key_value is not None: 374 | cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models 375 | key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) 376 | 377 | # repeat k/v heads if n_kv_heads < n_heads 378 | key_states = key_states.repeat(1, self.top_k, 1, 1) 379 | value_states = value_states.repeat(1, self.top_k, 1, 1) 380 | 381 | attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt( 382 | self.hidden_size_per_attention_head 383 | ) 384 | 385 | if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): 386 | raise ValueError( 387 | f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is" 388 | f" {attn_weights.size()}" 389 | ) 390 | 391 | if attention_mask is not None: 392 | if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): 393 | raise ValueError( 394 | f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" 395 | ) 396 | 397 | attn_weights = attn_weights + attention_mask 398 | 399 | # upcast attention to fp32 400 | attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) 401 | # attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) 402 | attn_output = torch.matmul(attn_weights, value_states) 403 | 404 | if attn_output.size() != (bsz, self.num_heads, q_len, self.hidden_size_per_attention_head): 405 | raise ValueError( 406 | f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.hidden_size_per_attention_head)}, but is" 407 | f" {attn_output.size()}" 408 | ) 409 | 410 | attn_output = attn_output.transpose(1, 2).contiguous() 411 | attn_output = attn_output.reshape(bsz, q_len, self.top_k, self.kv_projection_size) 412 | 413 | attn_output = self.experts.reduce(attn_output) 414 | attn_output = attn_output.view(bsz, q_len, -1) 415 | 416 | if not output_attentions: 417 | attn_weights = None 418 | 419 | return attn_output, attn_weights, past_key_value, aux_loss 420 | 421 | 422 | # copied from transformers.models.llama.modeling_llama.LlamaSdpaAttention with Llama->JetMoE 423 | class JetMoESdpaAttention(JetMoEAttention): 424 | """ 425 | JetMoE attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from 426 | `JetMoEAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to 427 | SDPA API. 428 | """ 429 | 430 | # Adapted from JetMoEAttention.forward 431 | def forward( 432 | self, 433 | hidden_states: torch.Tensor, 434 | attention_mask: Optional[torch.Tensor] = None, 435 | position_ids: Optional[torch.LongTensor] = None, 436 | past_key_value: Optional[Cache] = None, 437 | output_attentions: bool = False, 438 | use_cache: bool = False, 439 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: 440 | if output_attentions: 441 | # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. 
442 | logger.warning_once( 443 | "JetMoEModel is using JetMoESdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, " 444 | 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' 445 | ) 446 | return super().forward( 447 | hidden_states=hidden_states, 448 | attention_mask=attention_mask, 449 | position_ids=position_ids, 450 | past_key_value=past_key_value, 451 | output_attentions=output_attentions, 452 | use_cache=use_cache, 453 | ) 454 | 455 | bsz, q_len, _ = hidden_states.size() 456 | 457 | query_states, aux_loss = self.experts.map(hidden_states) 458 | key_states, value_states = self.kv_proj(hidden_states).chunk(2, dim=-1) 459 | 460 | query_states = query_states.view(bsz, q_len, self.num_heads, self.hidden_size_per_attention_head).transpose( 461 | 1, 2 462 | ) 463 | key_states = key_states.view( 464 | bsz, q_len, self.num_key_value_heads, self.hidden_size_per_attention_head 465 | ).transpose(1, 2) 466 | value_states = value_states.view( 467 | bsz, q_len, self.num_key_value_heads, self.hidden_size_per_attention_head 468 | ).transpose(1, 2) 469 | 470 | kv_seq_len = key_states.shape[2] 471 | if past_key_value is not None: 472 | kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) 473 | cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) 474 | 475 | query_states, key_states = apply_rotary_pos_emb( 476 | query_states, key_states, cos, sin, position_ids, unsqueeze_dim=1 477 | ) 478 | 479 | if past_key_value is not None: 480 | cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models 481 | key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) 482 | 483 | key_states = key_states.repeat(1, self.top_k, 1, 1) 484 | value_states = value_states.repeat(1, self.top_k, 1, 1) 485 | 486 | if attention_mask is not None: 487 | if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): 488 | raise ValueError( 489 | f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" 490 | ) 491 | 492 | # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, 493 | # Reference: https://github.com/pytorch/pytorch/issues/112577. 494 | if query_states.device.type == "cuda" and attention_mask is not None: 495 | query_states = query_states.contiguous() 496 | key_states = key_states.contiguous() 497 | value_states = value_states.contiguous() 498 | 499 | attn_output = torch.nn.functional.scaled_dot_product_attention( 500 | query_states, 501 | key_states, 502 | value_states, 503 | attn_mask=attention_mask, 504 | dropout_p=0.0, 505 | # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1. 
506 | is_causal=self.is_causal and attention_mask is None and q_len > 1, 507 | ) 508 | 509 | attn_output = attn_output.transpose(1, 2).contiguous() 510 | attn_output = attn_output.reshape(bsz, q_len, self.top_k, self.kv_projection_size) 511 | 512 | attn_output = self.experts.reduce(attn_output) 513 | attn_output = attn_output.view(bsz, q_len, -1) 514 | 515 | return attn_output, None, past_key_value, aux_loss 516 | 517 | 518 | class JetMoEFlashAttention2(JetMoEAttention): 519 | def __init__(self, *args, **kwargs): 520 | super().__init__(*args, **kwargs) 521 | 522 | # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. 523 | # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. 524 | # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). 525 | self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() 526 | 527 | def forward( 528 | self, 529 | hidden_states: Optional[torch.FloatTensor], 530 | attention_mask: Optional[torch.FloatTensor] = None, 531 | position_ids: Optional[torch.LongTensor] = None, 532 | past_key_value: Optional[Cache] = None, 533 | use_cache: Optional[bool] = False, 534 | output_attentions: Optional[bool] = False, 535 | **kwargs, 536 | ) -> Union[ 537 | Tuple[torch.Tensor, Tuple[torch.Tensor]], 538 | Optional[Tuple[torch.Tensor, Tuple[torch.Tensor], Tuple[torch.Tensor, ...]]], 539 | ]: 540 | """ 541 | Forward pass of the JetMoEAttention module. 542 | 543 | Args: 544 | hidden_states (Optional[torch.FloatTensor]): Input hidden states. 545 | attention_mask (Optional[torch.FloatTensor]): Attention mask. 546 | layer_past (Optional[Tuple[torch.Tensor]]): Past layer state. 547 | use_cache (Optional[bool]): Whether to use cached states. 548 | output_attentions (Optional[bool]): Whether to output attention weights. 549 | 550 | Returns: 551 | Union[Tuple[torch.Tensor, Tuple[torch.Tensor]], Optional[Tuple[...]]]: Tuple containing outputs. 552 | """ 553 | # assert attention_mask is None, "attention_mask is not supported" 554 | assert output_attentions is False, "output_attentions is not supported" 555 | 556 | B, T, C = hidden_states.size() # batch size, sequence length, embedding dimensionality (hidden_size) 557 | 558 | # calculate query, key, values 559 | query_layer, aux_loss = self.experts.map(hidden_states) 560 | key_layer, value_layer = self.kv_proj(hidden_states).chunk(2, dim=-1) 561 | 562 | query_layer = query_layer.view(B, T, self.num_heads, self.hidden_size_per_attention_head) # (B, T, k * nh, hs) 563 | key_layer = key_layer.view( 564 | B, T, self.num_key_value_heads, self.hidden_size_per_attention_head 565 | ) # (B, T, nh, hs) 566 | value_layer = value_layer.view( 567 | B, T, self.num_key_value_heads, self.hidden_size_per_attention_head 568 | ) # (B, T, nh, hs) 569 | 570 | kv_seq_len = key_layer.shape[1] 571 | if past_key_value is not None: 572 | if self.layer_idx is None: 573 | raise ValueError( 574 | f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} " 575 | "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " 576 | "with a layer index." 
577 | ) 578 | kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) 579 | cos, sin = self.rotary_emb(value_layer, seq_len=kv_seq_len) 580 | query_layer, key_layer = apply_rotary_pos_emb(query_layer, key_layer, cos, sin, position_ids) 581 | 582 | # query_layer = query_layer.contiguous() 583 | # expand the key_layer and value_layer [sk, b, ng, hn] -> [sk, b, np, hn] 584 | key_layer = key_layer.repeat(1, 1, self.top_k, 1) 585 | value_layer = value_layer.repeat(1, 1, self.top_k, 1) 586 | 587 | if past_key_value is not None: 588 | cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models 589 | # print(self.layer_idx, key_layer.size()) 590 | key_layer = key_layer.transpose(1, 2) 591 | value_layer = value_layer.transpose(1, 2) 592 | key_layer, value_layer = past_key_value.update(key_layer, value_layer, self.layer_idx, cache_kwargs) 593 | key_layer = key_layer.transpose(1, 2) 594 | value_layer = value_layer.transpose(1, 2) 595 | 596 | context_layer = self._flash_attention_forward( 597 | query_layer, 598 | key_layer, 599 | value_layer, 600 | attention_mask, 601 | T, 602 | ) 603 | 604 | # output projection 605 | y = self.experts.reduce(context_layer.reshape(T, B, self.top_k, self.kv_projection_size)) 606 | y = y.view(B, T, C) # re-assemble all head outputs side by side 607 | 608 | if not output_attentions: 609 | attn_weights = None 610 | 611 | return y, attn_weights, past_key_value, aux_loss 612 | 613 | def _flash_attention_forward( 614 | self, 615 | query_states, 616 | key_states, 617 | value_states, 618 | attention_mask, 619 | query_length, 620 | dropout=0.0, 621 | softmax_scale=None, 622 | ): 623 | """ 624 | Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token 625 | first unpad the input, then computes the attention scores and pad the final attention scores. 626 | 627 | Args: 628 | query_states (`torch.Tensor`): 629 | Input query states to be passed to Flash Attention API 630 | key_states (`torch.Tensor`): 631 | Input key states to be passed to Flash Attention API 632 | value_states (`torch.Tensor`): 633 | Input value states to be passed to Flash Attention API 634 | attention_mask (`torch.Tensor`): 635 | The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the 636 | position of padding tokens and 1 for the position of non-padding tokens. 637 | dropout (`float`): 638 | Attention dropout 639 | softmax_scale (`float`, *optional*): 640 | The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim) 641 | """ 642 | if not self._flash_attn_uses_top_left_mask: 643 | causal = self.is_causal 644 | else: 645 | # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__. 
646 | causal = self.is_causal and query_length != 1 647 | 648 | # Contains at least one padding token in the sequence 649 | if attention_mask is not None: 650 | batch_size = query_states.shape[0] 651 | query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input( 652 | query_states, key_states, value_states, attention_mask, query_length 653 | ) 654 | 655 | cu_seqlens_q, cu_seqlens_k = cu_seq_lens 656 | max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens 657 | 658 | attn_output_unpad = flash_attn_varlen_func( 659 | query_states, 660 | key_states, 661 | value_states, 662 | cu_seqlens_q=cu_seqlens_q, 663 | cu_seqlens_k=cu_seqlens_k, 664 | max_seqlen_q=max_seqlen_in_batch_q, 665 | max_seqlen_k=max_seqlen_in_batch_k, 666 | dropout_p=dropout, 667 | softmax_scale=softmax_scale, 668 | causal=causal, 669 | ) 670 | 671 | attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length) 672 | else: 673 | attn_output = flash_attn_func( 674 | query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal 675 | ) 676 | 677 | return attn_output 678 | 679 | def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length): 680 | indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask) 681 | batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape 682 | 683 | key_layer = index_first_axis( 684 | key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k 685 | ) 686 | value_layer = index_first_axis( 687 | value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k 688 | ) 689 | if query_length == kv_seq_len: 690 | query_layer = index_first_axis( 691 | query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k 692 | ) 693 | cu_seqlens_q = cu_seqlens_k 694 | max_seqlen_in_batch_q = max_seqlen_in_batch_k 695 | indices_q = indices_k 696 | elif query_length == 1: 697 | max_seqlen_in_batch_q = 1 698 | cu_seqlens_q = torch.arange( 699 | batch_size + 1, dtype=torch.int32, device=query_layer.device 700 | ) # There is a memcpy here, that is very bad. 701 | indices_q = cu_seqlens_q[:-1] 702 | query_layer = query_layer.squeeze(1) 703 | else: 704 | # The -q_len: slice assumes left padding. 705 | attention_mask = attention_mask[:, -query_length:] 706 | query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask) 707 | 708 | return ( 709 | query_layer, 710 | key_layer, 711 | value_layer, 712 | indices_q, 713 | (cu_seqlens_q, cu_seqlens_k), 714 | (max_seqlen_in_batch_q, max_seqlen_in_batch_k), 715 | ) 716 | 717 | 718 | JETMOE_ATTENTION_CLASSES = { 719 | "eager": JetMoEAttention, 720 | "flash_attention_2": JetMoEFlashAttention2, 721 | "sdpa": JetMoESdpaAttention, 722 | } 723 | 724 | 725 | class JetMoEBlock(nn.Module): 726 | def __init__(self, config: JetMoEConfig, layer_idx: Optional[int] = None): 727 | """ 728 | Initialize the JetMoEBlock module. 729 | 730 | Args: 731 | config: Configuration object with model hyperparameters. 
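            layer_idx: Index of this block in the model, forwarded to the attention layer for
                key/value cache bookkeeping.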
732 |         """
733 |         super().__init__()
734 |         self.input_layernorm = JetMoERMSNorm(config.hidden_size)
735 |         self.self_attention = JETMOE_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx)
736 |         self.post_attention_layernorm = JetMoERMSNorm(config.hidden_size)
737 | 
738 |         self.mlp = MoE(
739 |             input_size=config.hidden_size,
740 |             hidden_size=config.ffn_hidden_size,
741 |             num_experts=config.moe_num_experts,
742 |             activation=ACT2FN[config.activation_function],
743 |             top_k=config.moe_top_k,
744 |             bias=config.bias,
745 |             glu=config.glu,
746 |         )
747 | 
748 |     def forward(
749 |         self,
750 |         hidden_states: Optional[torch.FloatTensor],
751 |         position_ids: Optional[torch.LongTensor] = None,
752 |         past_key_value: Optional[Tuple[torch.Tensor]] = None,
753 |         attention_mask: Optional[torch.FloatTensor] = None,
754 |         output_attentions: Optional[bool] = False,
755 |         use_cache: Optional[bool] = False,
756 |         **kwargs,
757 |     ) -> Union[Tuple[torch.Tensor], Optional[Tuple[torch.Tensor, Tuple[torch.FloatTensor, ...]]]]:
758 |         """
759 |         Forward pass of the JetMoEBlock module.
760 | 
761 |         Args:
762 |             hidden_states (Optional[torch.FloatTensor]): Input hidden states.
763 |             past_key_value (Optional[Tuple[torch.Tensor]]): Cached key/value states for this layer.
764 |             attention_mask (Optional[torch.FloatTensor]): Attention mask.
765 |             position_ids (Optional[torch.LongTensor]): Position indices used for rotary embeddings.
766 |             use_cache (Optional[bool]): Whether to use cached states.
767 |             output_attentions (Optional[bool]): Whether to output attention weights.
768 | 
769 |         Returns:
770 |             Union[Tuple[torch.Tensor], Optional[Tuple[torch.Tensor, Tuple[torch.FloatTensor, ...]]]]:
771 |                 Tuple of (hidden_states, [attention weights], [present key/value], aux loss); the optional entries depend on `output_attentions` and `use_cache`.
772 |         """
773 |         # Self Attention
774 |         attn_output, self_attn_weights, present_key_value, att_aux_loss = self.self_attention(
775 |             hidden_states=self.input_layernorm(hidden_states),
776 |             attention_mask=attention_mask,
777 |             position_ids=position_ids,
778 |             past_key_value=past_key_value,
779 |             output_attentions=output_attentions,
780 |             use_cache=use_cache,
781 |         )
782 | 
783 |         hidden_states = hidden_states + attn_output
784 |         x_mlp, mlp_aux_loss = self.mlp(self.post_attention_layernorm(hidden_states))
785 |         hidden_states = hidden_states + x_mlp
786 | 
787 |         outputs = (hidden_states,)
788 | 
789 |         if output_attentions:
790 |             outputs += (self_attn_weights,)
791 | 
792 |         if use_cache:
793 |             outputs += (present_key_value,)
794 | 
795 |         outputs += (att_aux_loss + mlp_aux_loss,)
796 | 
797 |         return outputs
798 | 
799 | 
800 | class JetMoEPreTrainedModel(PreTrainedModel):
801 |     """
802 |     An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
803 |     models.
804 |     """
805 | 
806 |     config_class = JetMoEConfig
807 |     base_model_prefix = "transformer"
808 |     supports_gradient_checkpointing = False
809 |     _no_split_modules = ["JetMoEBlock"]
810 |     _skip_keys_device_placement = "past_key_values"
811 |     _supports_flash_attn_2 = True
812 |     _supports_sdpa = True
813 |     _supports_cache_class = True
814 | 
815 |     def __init__(self, *inputs, **kwargs):
816 |         """
817 |         Initialize the JetMoEPreTrainedModel.
818 | 
819 |         Args:
820 |             *inputs: Variable length input arguments.
821 |             **kwargs: Keyword arguments.
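        Note (illustrative, assuming a recent version of transformers and a locally available
        checkpoint; the path below is a placeholder): the concrete attention backend is chosen
        from JETMOE_ATTENTION_CLASSES via `config._attn_implementation`, which `from_pretrained`
        sets from its `attn_implementation` argument, e.g.

            model = JetMoEForCausalLM.from_pretrained(
                "path/to/jetmoe-checkpoint",
                attn_implementation="flash_attention_2",
                torch_dtype=torch.bfloat16,
            )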
822 |         """
823 |         super().__init__(*inputs, **kwargs)
824 | 
825 |         self.gradient_checkpointing = False
826 | 
827 |     def _init_weights(self, module):
828 |         """Initialize the weights."""
829 |         if isinstance(module, (nn.Linear,)):
830 |             # Slightly different from Mesh Transformer JAX which uses truncated_normal for initialization
831 |             # cf https://github.com/pytorch/pytorch/pull/5617
832 |             module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
833 |             if module.bias is not None:
834 |                 module.bias.data.zero_()
835 |         elif isinstance(module, nn.Embedding):
836 |             module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
837 |             if module.padding_idx is not None:
838 |                 module.weight.data[module.padding_idx].zero_()
839 |         elif isinstance(module, nn.LayerNorm):
840 |             module.bias.data.zero_()
841 |             module.weight.data.fill_(1.0)
842 |         elif isinstance(module, ParallelExperts):
843 |             module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
844 | 
845 |     # def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs={}):
846 |     #     for module in self.modules():
847 |     #         if hasattr(module, "gradient_checkpointing"):
848 |     #             self._set_gradient_checkpointing(
849 |     #                 module, True, gradient_checkpointing_kwargs
850 |     #             )
851 | 
852 |     # def gradient_checkpointing_disable(self):
853 |     #     for module in self.modules():
854 |     #         if hasattr(module, "gradient_checkpointing"):
855 |     #             self._set_gradient_checkpointing(
856 |     #                 module, False
857 |     #             )
858 | 
859 |     # def _set_gradient_checkpointing(
860 |     #     self,
861 |     #     module,
862 |     #     value=False,
863 |     #     gradient_checkpointing_kwargs={"use_reentrant": False},
864 |     # ):
865 |     #     """
866 |     #     Set gradient checkpointing for the JetMoEModel.
867 | 
868 |     #     Args:
869 |     #         module: The module for which gradient checkpointing is set.
870 |     #         value (bool): Whether to enable gradient checkpointing.
871 |     #     """
872 |     #     self._gradient_checkpointing_func = checkpoint
873 |     #     self.gradient_checkpointing = True
874 |     #     if isinstance(module, JetMoEModel):
875 |     #         module.gradient_checkpointing = value
876 |     #         module.gradient_checkpointing_kwargs = gradient_checkpointing_kwargs
877 |     #         module._gradient_checkpointing_func = checkpoint
878 | 
879 | 
880 | JETMOE_START_DOCSTRING = r"""
881 |     This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. Use
882 |     it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
883 |     behavior.
884 | 
885 |     Parameters:
886 |         config ([`JetMoEConfig`]): Model configuration class with all the parameters of the model.
887 |             Initializing with a config file does not load the weights associated with the model, only the
888 |             configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
889 | """
890 | 
891 | JETMOE_INPUTS_DOCSTRING = r"""
892 |     Args:
893 |         input_ids (`torch.LongTensor` of shape `({0})`):
894 |             Indices of input sequence tokens in the vocabulary.
895 | 
896 |             Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
897 |             [`PreTrainedTokenizer.__call__`] for details.
898 | 
899 |             [What are input IDs?](../glossary#input-ids)
900 |         attention_mask (`torch.FloatTensor` of shape `({0})`, *optional*):
901 |             Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
902 | 
903 |             - 1 for tokens that are **not masked**,
904 |             - 0 for tokens that are **masked**.
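            For example (illustrative values), a left-padded batch of two prompts could be
            passed as:

                input_ids      = [[pad_id, pad_id, 17, 42],
                                  [     5,      9, 11,  3]]
                attention_mask = [[     0,      0,  1,  1],
                                  [     1,      1,  1,  1]]

            Left padding is what the flash-attention path expects for batched generation, as
            checked in `JetMoEModel.forward`.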
905 | 906 | [What are attention masks?](../glossary#attention-mask) 907 | token_type_ids (`torch.LongTensor` of shape `({0})`, *optional*): 908 | Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0, 909 | 1]`: 910 | 911 | - 0 corresponds to a *sentence A* token, 912 | - 1 corresponds to a *sentence B* token. 913 | 914 | [What are token type IDs?](../glossary#token-type-ids) 915 | position_ids (`torch.LongTensor` of shape `({0})`, *optional*): 916 | Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, 917 | config.n_positions - 1]`. 918 | 919 | [What are position IDs?](../glossary#position-ids) 920 | head_mask (`torch.FloatTensor` of shape `(num_attention_heads,)` or `(n_layer, num_attention_heads)`, *optional*): 921 | Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`: 922 | 923 | - 1 indicates the head is **not masked**, 924 | - 0 indicates the head is **masked**. 925 | 926 | inputs_embeds (`torch.FloatTensor` of shape `({0}, hidden_dim)`, *optional*): 927 | Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This 928 | is useful if you want more control over how to convert *input_ids* indices into associated vectors than the 929 | model's internal embedding lookup matrix. 930 | output_attentions (`bool`, *optional*): 931 | Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned 932 | tensors for more detail. 933 | output_hidden_states (`bool`, *optional*): 934 | Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for 935 | more detail. 936 | return_dict (`bool`, *optional*): 937 | Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. 938 | """ 939 | 940 | 941 | @add_start_docstrings( 942 | "The bare JetMoE Model outputting raw hidden-states without any specific head on top.", 943 | JETMOE_START_DOCSTRING, 944 | ) 945 | class JetMoEModel(JetMoEPreTrainedModel): 946 | """ 947 | Transformer decoder consisting of *config.num_hidden_layers* layers. 
Each layer is a [`JetMoEBlock`] 948 | 949 | Args: 950 | config: JetMoEConfig 951 | """ 952 | 953 | def __init__(self, config: JetMoEConfig): 954 | super().__init__(config) 955 | self.padding_idx = config.pad_token_id 956 | self.vocab_size = config.vocab_size 957 | 958 | self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) 959 | self.layers = nn.ModuleList([JetMoEBlock(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]) 960 | self._attn_implementation = config._attn_implementation 961 | self.norm = JetMoERMSNorm(config.hidden_size, eps=config.rms_norm_eps) 962 | 963 | self.gradient_checkpointing = False 964 | # Initialize weights and apply final processing 965 | self.post_init() 966 | 967 | def get_input_embeddings(self): 968 | return self.embed_tokens 969 | 970 | def set_input_embeddings(self, value): 971 | self.embed_tokens = value 972 | 973 | @add_start_docstrings_to_model_forward(JETMOE_INPUTS_DOCSTRING) 974 | def forward( 975 | self, 976 | input_ids: torch.LongTensor = None, 977 | attention_mask: Optional[torch.Tensor] = None, 978 | position_ids: Optional[torch.LongTensor] = None, 979 | past_key_values: Optional[List[torch.FloatTensor]] = None, 980 | inputs_embeds: Optional[torch.FloatTensor] = None, 981 | use_cache: Optional[bool] = None, 982 | output_attentions: Optional[bool] = None, 983 | output_hidden_states: Optional[bool] = None, 984 | return_dict: Optional[bool] = None, 985 | ) -> Union[Tuple, BaseModelOutputWithPast]: 986 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions 987 | output_hidden_states = ( 988 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states 989 | ) 990 | use_cache = use_cache if use_cache is not None else self.config.use_cache 991 | 992 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 993 | 994 | # retrieve input_ids and inputs_embeds 995 | if input_ids is not None and inputs_embeds is not None: 996 | raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") 997 | elif input_ids is not None: 998 | batch_size, seq_length = input_ids.shape 999 | elif inputs_embeds is not None: 1000 | batch_size, seq_length, _ = inputs_embeds.shape 1001 | else: 1002 | raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") 1003 | 1004 | if self.gradient_checkpointing and self.training: 1005 | if use_cache: 1006 | logger.warning_once( 1007 | "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 
1008 | ) 1009 | use_cache = False 1010 | 1011 | past_key_values_length = 0 1012 | 1013 | if use_cache: 1014 | use_legacy_cache = not isinstance(past_key_values, Cache) 1015 | if use_legacy_cache: 1016 | past_key_values = DynamicCache.from_legacy_cache(past_key_values) 1017 | past_key_values_length = past_key_values.get_usable_length(seq_length) 1018 | 1019 | if position_ids is None: 1020 | device = input_ids.device if input_ids is not None else inputs_embeds.device 1021 | position_ids = torch.arange( 1022 | past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device 1023 | ) 1024 | position_ids = position_ids.unsqueeze(0).view(-1, seq_length) 1025 | else: 1026 | position_ids = position_ids.view(-1, seq_length).long() 1027 | 1028 | if inputs_embeds is None: 1029 | inputs_embeds = self.embed_tokens(input_ids) 1030 | 1031 | if attention_mask is not None and self._attn_implementation == "flash_attention_2" and use_cache: 1032 | is_padding_right = attention_mask[:, -1].sum().item() != batch_size 1033 | if is_padding_right: 1034 | raise ValueError( 1035 | "You are attempting to perform batched generation with padding_side='right'" 1036 | " this may lead to unexpected behaviour for Flash Attention version of JetMoE. Make sure to " 1037 | " call `tokenizer.padding_side = 'left'` before tokenizing the input. " 1038 | ) 1039 | 1040 | if self._attn_implementation == "flash_attention_2": 1041 | # 2d mask is passed through the layers 1042 | attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None 1043 | elif self._attn_implementation == "sdpa" and not output_attentions: 1044 | # output_attentions=True can not be supported when using SDPA, and we fall back on 1045 | # the manual implementation that requires a 4D causal mask in all cases. 
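# Illustrative note (assumed example, not from the original code): the 4D additive causal
# mask built by these helpers has shape (batch_size, 1, q_len, kv_len); for q_len = kv_len = 3
# with no padding it looks roughly like
#     [[0, min, min],
#      [0,   0, min],
#      [0,   0,   0]]
# where `min` is the dtype's most negative value. The SDPA helper may also return None when it
# can rely on `is_causal` instead, while the flash-attention branch above keeps the 2D padding
# mask and applies causality internally.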
1046 | attention_mask = _prepare_4d_causal_attention_mask_for_sdpa( 1047 | attention_mask, 1048 | (batch_size, seq_length), 1049 | inputs_embeds, 1050 | past_key_values_length, 1051 | ) 1052 | else: 1053 | # 4d mask is passed through the layers 1054 | attention_mask = _prepare_4d_causal_attention_mask( 1055 | attention_mask, 1056 | (batch_size, seq_length), 1057 | inputs_embeds, 1058 | past_key_values_length, 1059 | ) 1060 | 1061 | hidden_states = inputs_embeds 1062 | 1063 | # decoder layers 1064 | all_hidden_states = () if output_hidden_states else None 1065 | all_self_attns = () if output_attentions else None 1066 | next_decoder_cache = None 1067 | 1068 | aux_loss = 0 1069 | for decoder_layer in self.layers: 1070 | if output_hidden_states: 1071 | all_hidden_states += (hidden_states,) 1072 | 1073 | # hidden_states: Optional[torch.FloatTensor], 1074 | # position_ids: Optional[torch.LongTensor] = None, 1075 | # past_key_value: Optional[Tuple[torch.Tensor]] = None, 1076 | # attention_mask: Optional[torch.FloatTensor] = None, 1077 | # output_attentions: Optional[bool] = False, 1078 | # use_cache: Optional[bool] = False, 1079 | 1080 | if self.gradient_checkpointing and self.training: 1081 | layer_outputs = self._gradient_checkpointing_func( 1082 | # decoder_layer.__call__, 1083 | decoder_layer, 1084 | hidden_states, 1085 | position_ids, 1086 | past_key_values, 1087 | attention_mask, 1088 | output_attentions, 1089 | use_cache, 1090 | use_reentrant=False, 1091 | ) 1092 | else: 1093 | layer_outputs = decoder_layer( 1094 | hidden_states, 1095 | attention_mask=attention_mask, 1096 | position_ids=position_ids, 1097 | past_key_value=past_key_values, 1098 | output_attentions=output_attentions, 1099 | use_cache=use_cache, 1100 | ) 1101 | 1102 | hidden_states = layer_outputs[0] 1103 | 1104 | if use_cache: 1105 | next_decoder_cache = layer_outputs[2 if output_attentions else 1] 1106 | 1107 | if output_attentions: 1108 | all_self_attns += (layer_outputs[1],) 1109 | 1110 | aux_loss += layer_outputs[-1] 1111 | 1112 | hidden_states = self.norm(hidden_states) 1113 | 1114 | # add hidden states from the last decoder layer 1115 | if output_hidden_states: 1116 | all_hidden_states += (hidden_states,) 1117 | 1118 | next_cache = None 1119 | if use_cache: 1120 | next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache 1121 | 1122 | if not return_dict: 1123 | return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) 1124 | return JetMoEBaseModelOutputWithPast( 1125 | last_hidden_state=hidden_states, 1126 | past_key_values=next_cache, 1127 | hidden_states=all_hidden_states, 1128 | attentions=all_self_attns, 1129 | aux_loss=aux_loss, 1130 | ) 1131 | 1132 | 1133 | class JetMoEForCausalLM(JetMoEPreTrainedModel): 1134 | _tied_weights_keys = ["lm_head.weight"] 1135 | 1136 | def __init__(self, config): 1137 | super().__init__(config) 1138 | self.model = JetMoEModel(config) 1139 | self.vocab_size = config.vocab_size 1140 | self.aux_loss_coef = getattr(config, "aux_loss_coef", 0.01) 1141 | self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) 1142 | self.tie_word_embeddings = config.tie_word_embeddings 1143 | 1144 | # Initialize weights and apply final processing 1145 | self.post_init() 1146 | 1147 | def get_input_embeddings(self): 1148 | return self.model.embed_tokens 1149 | 1150 | def set_input_embeddings(self, value): 1151 | self.model.embed_tokens = value 1152 | 1153 | def get_output_embeddings(self): 1154 | return 
self.lm_head 1155 | 1156 | def set_output_embeddings(self, new_embeddings): 1157 | self.lm_head = new_embeddings 1158 | 1159 | def set_decoder(self, decoder): 1160 | self.model = decoder 1161 | 1162 | def get_decoder(self): 1163 | return self.model 1164 | 1165 | @add_start_docstrings_to_model_forward(JETMOE_INPUTS_DOCSTRING) 1166 | @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) 1167 | def forward( 1168 | self, 1169 | input_ids: torch.LongTensor = None, 1170 | attention_mask: Optional[torch.Tensor] = None, 1171 | position_ids: Optional[torch.LongTensor] = None, 1172 | past_key_values: Optional[List[torch.FloatTensor]] = None, 1173 | inputs_embeds: Optional[torch.FloatTensor] = None, 1174 | labels: Optional[torch.LongTensor] = None, 1175 | use_cache: Optional[bool] = None, 1176 | output_attentions: Optional[bool] = None, 1177 | output_hidden_states: Optional[bool] = None, 1178 | return_dict: Optional[bool] = None, 1179 | ) -> Union[Tuple, CausalLMOutputWithPast]: 1180 | r""" 1181 | Args: 1182 | labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): 1183 | Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., 1184 | config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored 1185 | (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. 1186 | 1187 | Returns: 1188 | """ 1189 | 1190 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions 1191 | output_hidden_states = ( 1192 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states 1193 | ) 1194 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 1195 | 1196 | # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) 1197 | outputs = self.model( 1198 | input_ids=input_ids, 1199 | attention_mask=attention_mask, 1200 | position_ids=position_ids, 1201 | past_key_values=past_key_values, 1202 | inputs_embeds=inputs_embeds, 1203 | use_cache=use_cache, 1204 | output_attentions=output_attentions, 1205 | output_hidden_states=output_hidden_states, 1206 | return_dict=return_dict, 1207 | ) 1208 | 1209 | hidden_states = outputs[0] 1210 | logits = self.lm_head(hidden_states) 1211 | logits = logits.float() 1212 | 1213 | loss = None 1214 | if labels is not None: 1215 | # Shift so that tokens < n predict n 1216 | shift_logits = logits[..., :-1, :].contiguous() 1217 | shift_labels = labels[..., 1:].contiguous() 1218 | # Flatten the tokens 1219 | shift_logits = shift_logits.view(-1, self.config.vocab_size) 1220 | shift_labels = shift_labels.view(-1) 1221 | # Ensure tensors are on the same device 1222 | shift_labels = shift_labels.to(shift_logits.device) 1223 | loss_fct = CrossEntropyLoss() 1224 | loss = loss_fct(shift_logits, shift_labels) 1225 | 1226 | if not return_dict: 1227 | output = (logits,) + outputs[1:] 1228 | return (loss,) + output if loss is not None else output 1229 | 1230 | if labels is not None and self.model.training: 1231 | loss += self.aux_loss_coef * outputs.aux_loss.to(loss.device) 1232 | 1233 | return JetMoECausalLMOutputWithPast( 1234 | loss=loss, 1235 | logits=logits, 1236 | past_key_values=outputs.past_key_values, 1237 | hidden_states=outputs.hidden_states, 1238 | attentions=outputs.attentions, 1239 | aux_loss=outputs.aux_loss, 1240 | ) 1241 | 1242 | def prepare_inputs_for_generation( 1243 | self, 
input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs 1244 | ): 1245 | # Omit tokens covered by past_key_values 1246 | if past_key_values is not None: 1247 | if isinstance(past_key_values, Cache): 1248 | cache_length = past_key_values.get_seq_length() 1249 | past_length = past_key_values.seen_tokens 1250 | max_cache_length = past_key_values.get_max_length() 1251 | else: 1252 | cache_length = past_length = past_key_values[0][0].shape[2] 1253 | max_cache_length = None 1254 | 1255 | # Keep only the unprocessed tokens: 1256 | # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where 1257 | # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as 1258 | # input) 1259 | if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]: 1260 | input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :] 1261 | # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard 1262 | # input_ids based on the past_length. 1263 | elif past_length < input_ids.shape[1]: 1264 | input_ids = input_ids[:, past_length:] 1265 | # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens. 1266 | 1267 | # If we are about to go beyond the maximum cache length, we need to crop the input attention mask. 1268 | if ( 1269 | max_cache_length is not None 1270 | and attention_mask is not None 1271 | and cache_length + input_ids.shape[1] > max_cache_length 1272 | ): 1273 | attention_mask = attention_mask[:, -max_cache_length:] 1274 | 1275 | position_ids = kwargs.get("position_ids", None) 1276 | if attention_mask is not None and position_ids is None: 1277 | # create position_ids on the fly for batch generation 1278 | position_ids = attention_mask.long().cumsum(-1) - 1 1279 | position_ids.masked_fill_(attention_mask == 0, 1) 1280 | if past_key_values: 1281 | position_ids = position_ids[:, -input_ids.shape[1] :] 1282 | 1283 | # if `inputs_embeds` are passed, we only want to use them in the 1st generation step 1284 | if inputs_embeds is not None and past_key_values is None: 1285 | model_inputs = {"inputs_embeds": inputs_embeds} 1286 | else: 1287 | model_inputs = {"input_ids": input_ids} 1288 | 1289 | model_inputs.update( 1290 | { 1291 | "position_ids": position_ids, 1292 | "past_key_values": past_key_values, 1293 | "use_cache": kwargs.get("use_cache"), 1294 | "attention_mask": attention_mask, 1295 | } 1296 | ) 1297 | return model_inputs 1298 | 1299 | @staticmethod 1300 | def _reorder_cache(past_key_values, beam_idx): 1301 | reordered_past = () 1302 | for layer_past in past_key_values: 1303 | reordered_past += ( 1304 | tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past), 1305 | ) 1306 | return reordered_past 1307 | 1308 | 1309 | @add_start_docstrings( 1310 | """ 1311 | The JetMoE Model transformer with a sequence classification head on top (linear layer). 1312 | 1313 | [`JetMoEForSequenceClassification`] uses the last token in order to do the classification, as other causal models 1314 | (e.g. GPT-2) do. 1315 | 1316 | Since it does classification on the last token, it requires to know the position of the last token. If a 1317 | `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If 1318 | no `pad_token_id` is defined, it simply takes the last value in each row of the batch. 
Since it cannot guess the 1319 | padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in 1320 | each row of the batch). 1321 | """, 1322 | JETMOE_START_DOCSTRING, 1323 | ) 1324 | # Copied from transformers.models.llama.modeling_llama.LlamaForSequenceClassification with Llama->JetMoE, LLAMA->JETMOE 1325 | class JetMoEForSequenceClassification(JetMoEPreTrainedModel): 1326 | def __init__(self, config): 1327 | super().__init__(config) 1328 | self.num_labels = config.num_labels 1329 | self.model = JetMoEModel(config) 1330 | self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False) 1331 | 1332 | # Initialize weights and apply final processing 1333 | self.post_init() 1334 | 1335 | def get_input_embeddings(self): 1336 | return self.model.embed_tokens 1337 | 1338 | def set_input_embeddings(self, value): 1339 | self.model.embed_tokens = value 1340 | 1341 | @add_start_docstrings_to_model_forward(JETMOE_INPUTS_DOCSTRING) 1342 | def forward( 1343 | self, 1344 | input_ids: torch.LongTensor = None, 1345 | attention_mask: Optional[torch.Tensor] = None, 1346 | position_ids: Optional[torch.LongTensor] = None, 1347 | past_key_values: Optional[List[torch.FloatTensor]] = None, 1348 | inputs_embeds: Optional[torch.FloatTensor] = None, 1349 | labels: Optional[torch.LongTensor] = None, 1350 | use_cache: Optional[bool] = None, 1351 | output_attentions: Optional[bool] = None, 1352 | output_hidden_states: Optional[bool] = None, 1353 | return_dict: Optional[bool] = None, 1354 | ) -> Union[Tuple, SequenceClassifierOutputWithPast]: 1355 | r""" 1356 | labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): 1357 | Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., 1358 | config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If 1359 | `config.num_labels > 1` a classification loss is computed (Cross-Entropy). 
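        Note (illustrative example with made-up ids): classification pools the score of the last
        non-padding token of each row. With `pad_token_id = 0` and

            input_ids = [[11, 12, 13, 0, 0]]

        the pooled position is index 2, so `logits[:, 2, :]` supplies the sequence-level logits on
        which the classification or regression loss is computed.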
1360 | """ 1361 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 1362 | 1363 | transformer_outputs = self.model( 1364 | input_ids, 1365 | attention_mask=attention_mask, 1366 | position_ids=position_ids, 1367 | past_key_values=past_key_values, 1368 | inputs_embeds=inputs_embeds, 1369 | use_cache=use_cache, 1370 | output_attentions=output_attentions, 1371 | output_hidden_states=output_hidden_states, 1372 | return_dict=return_dict, 1373 | ) 1374 | hidden_states = transformer_outputs[0] 1375 | logits = self.score(hidden_states) 1376 | 1377 | if input_ids is not None: 1378 | batch_size = input_ids.shape[0] 1379 | else: 1380 | batch_size = inputs_embeds.shape[0] 1381 | 1382 | if self.config.pad_token_id is None and batch_size != 1: 1383 | raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.") 1384 | if self.config.pad_token_id is None: 1385 | sequence_lengths = -1 1386 | else: 1387 | if input_ids is not None: 1388 | # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility 1389 | sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1 1390 | sequence_lengths = sequence_lengths % input_ids.shape[-1] 1391 | sequence_lengths = sequence_lengths.to(logits.device) 1392 | else: 1393 | sequence_lengths = -1 1394 | 1395 | pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths] 1396 | 1397 | loss = None 1398 | if labels is not None: 1399 | labels = labels.to(logits.device) 1400 | if self.config.problem_type is None: 1401 | if self.num_labels == 1: 1402 | self.config.problem_type = "regression" 1403 | elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): 1404 | self.config.problem_type = "single_label_classification" 1405 | else: 1406 | self.config.problem_type = "multi_label_classification" 1407 | 1408 | if self.config.problem_type == "regression": 1409 | loss_fct = MSELoss() 1410 | if self.num_labels == 1: 1411 | loss = loss_fct(pooled_logits.squeeze(), labels.squeeze()) 1412 | else: 1413 | loss = loss_fct(pooled_logits, labels) 1414 | elif self.config.problem_type == "single_label_classification": 1415 | loss_fct = CrossEntropyLoss() 1416 | loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1)) 1417 | elif self.config.problem_type == "multi_label_classification": 1418 | loss_fct = BCEWithLogitsLoss() 1419 | loss = loss_fct(pooled_logits, labels) 1420 | if not return_dict: 1421 | output = (pooled_logits,) + transformer_outputs[1:] 1422 | return ((loss,) + output) if loss is not None else output 1423 | 1424 | return SequenceClassifierOutputWithPast( 1425 | loss=loss, 1426 | logits=pooled_logits, 1427 | past_key_values=transformer_outputs.past_key_values, 1428 | hidden_states=transformer_outputs.hidden_states, 1429 | attentions=transformer_outputs.attentions, 1430 | ) 1431 | -------------------------------------------------------------------------------- /jetmoe/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .moe import MoE 2 | from .parallel_experts import ParallelExperts -------------------------------------------------------------------------------- /jetmoe/utils/gate.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | class top_k_gating(nn.Module): 6 | def __init__( 7 | self, 8 | input_size, 9 | num_experts, 10 | 
top_k, 11 | ): 12 | """ 13 | Initialize the top-k gating mechanism. 14 | 15 | Args: 16 | input_size (int): Size of the input. 17 | num_experts (int): Number of experts. 18 | top_k (int): Number of top experts to select. 19 | acc_aux_loss (bool): Whether to accumulate auxiliary loss statistics. 20 | dropout (float): Dropout rate for gating network. 21 | hidden_size (int): Hidden size of the gating network. 22 | sample_topk (int): Number of top-k experts to sample during training. 23 | aux_loss (str): Type of auxiliary loss ('mi' or 'switch'). 24 | gate_type (str): Type of gating mechanism ('mlp', 'linear', or 'gmm'). 25 | """ 26 | super().__init__() 27 | 28 | self.num_experts = num_experts 29 | self.input_size = input_size 30 | assert top_k <= num_experts 31 | self.top_k = top_k 32 | 33 | self.layer = nn.Linear(input_size, num_experts, bias=False) 34 | 35 | def extra_repr(self): 36 | """ 37 | Return extra representation string for the module. 38 | """ 39 | return "k={}, num_experts={}".format(self.top_k, self.num_experts) 40 | 41 | def compute_aux_loss(self, probs, logits, gates): 42 | """ 43 | Calculate and return the auxiliary loss based on the accumulated statistics. 44 | 45 | Args: 46 | eps (float): Small epsilon value for numerical stability. 47 | 48 | Returns: 49 | torch.Tensor: The calculated auxiliary loss. 50 | """ 51 | count = logits.size(0) 52 | probs = probs.sum(0) 53 | freq = (gates > 0).float().sum(0) 54 | lsesq = (torch.log(torch.exp(logits).sum(dim=-1)) ** 2).sum() 55 | 56 | switchloss = self.num_experts * (F.normalize(probs, p=1, dim=0) * F.normalize(freq, p=1, dim=0)).sum() 57 | zloss = lsesq / count 58 | loss = switchloss + 0.1 * zloss 59 | 60 | return loss 61 | 62 | def forward(self, x): 63 | """ 64 | Compute the top-k gating for the input. 65 | 66 | See paper: https://arxiv.org/abs/1701.06538. 67 | 68 | Args: 69 | x (torch.Tensor): Input tensor with shape [batch_size, input_size]. 70 | skip_mask (torch.Tensor): Skip mask tensor (binary) with the same shape as `x`. 71 | x: input Tensor with shape [batch_size, input_size] 72 | train: a boolean - we only add noise at training time. 73 | noise_epsilon: a float 74 | 75 | Returns: 76 | torch.Tensor: Top-k indices. 77 | torch.Tensor: Top-k gating values. 78 | torch.Tensor: Probability values for each expert. 79 | gates: a Tensor with shape [batch_size, num_experts] 80 | load: a Tensor with shape [num_experts] 81 | """ 82 | 83 | logits = self.layer(x).float() 84 | top_k_logits, top_k_indices = logits.topk(self.top_k, dim=1) 85 | top_k_gates = torch.softmax(top_k_logits, dim=1).type_as(x) 86 | 87 | if self.training: 88 | probs = torch.softmax(logits, dim=1) 89 | zeros = torch.zeros_like(probs) 90 | zeros = zeros.to(top_k_gates.dtype) # Convert zeros to match top_k_gates dtype 91 | gates = zeros.scatter(1, top_k_indices, top_k_gates) 92 | self.loss = self.compute_aux_loss(probs, logits, gates) 93 | else: 94 | self.loss = 0 95 | 96 | return top_k_indices, top_k_gates -------------------------------------------------------------------------------- /jetmoe/utils/moe.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | from .parallel_experts import ParallelExperts, compute_gating 8 | 9 | from .gate import top_k_gating 10 | 11 | 12 | class MoE(nn.Module): 13 | """ 14 | A Sparsely gated mixture of experts layer with 1-layer Feed-Forward networks as experts. 
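    A rough usage sketch (illustrative; the sizes are made up and the auxiliary loss is only
    non-zero in training mode):

        moe = MoE(input_size=1024, hidden_size=2048, num_experts=8, top_k=2,
                  activation=torch.nn.functional.silu, glu=True)
        x = torch.randn(4, 16, 1024)   # (batch_size, seq_len, input_size)
        y, aux_loss = moe(x)           # y: (4, 16, 1024), aux_loss: load-balancing loss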
15 | 
16 |     Args:
17 |         input_size: integer - size of the input
18 |         hidden_size: integer - size of the expert's hidden layer
19 |         num_experts: an integer - number of experts
20 |         top_k: an integer - how many experts to use for each batch element
21 |         bias: a boolean - whether to include bias in linear layers
22 |         activation: an activation function to apply to expert's outputs
23 |         glu: a boolean - whether to use GLU activation
24 |     """
25 | 
26 |     def __init__(
27 |         self,
28 |         input_size,
29 |         hidden_size,
30 |         num_experts,
31 |         top_k,
32 |         bias=True,
33 |         activation=None,
34 |         glu=True,
35 |     ):
36 |         super(MoE, self).__init__()
37 | 
38 |         self.num_experts = num_experts
39 |         self.input_size = input_size
40 |         self.hidden_size = hidden_size
41 |         self.glu = glu
42 |         if bias:
43 |             self.bias = torch.nn.Parameter(torch.empty(input_size))
44 |             torch.nn.init.zeros_(self.bias)
45 |         else:
46 |             self.bias = None
47 | 
48 |         self.input_linear = ParallelExperts(num_experts, input_size, hidden_size * 2 if glu else hidden_size)
49 |         self.output_linear = ParallelExperts(num_experts, hidden_size, input_size)
50 | 
51 |         self.top_k = min(top_k, self.num_experts)
52 |         self.activation = activation
53 | 
54 |         self.router = top_k_gating(
55 |             input_size=input_size,
56 |             num_experts=num_experts,
57 |             top_k=top_k,
58 |         )
59 | 
60 |     def extra_repr(self):
61 |         return "k={}, e={}".format(self.top_k, self.num_experts)
62 | 
63 |     def get_aux_loss_and_clear(self):
64 |         """
65 |         Get the accumulated auxiliary loss and clear it.
66 | 
67 |         Returns:
68 |             float: Accumulated auxiliary loss.
69 |         """
70 | 
71 |         return self.gate.get_aux_loss_and_clear()
72 | 
73 |     def compute_gate(self, x):
74 |         top_k_indices, self.top_k_gates = self.router(x)
75 | 
76 |         self.batch_gates, self.batch_index, expert_size, self.index_sorted_experts = compute_gating(
77 |             self.top_k, self.num_experts, self.top_k_gates, top_k_indices
78 |         )
79 |         self.expert_size = expert_size.tolist()
80 | 
81 |         return self.router.loss
82 | 
83 |     def batch_forward(self, x):
84 |         """
85 |         Forward pass of the mixture of experts layer.
86 | 
87 |         Args:
88 |             x (Tensor): Input tensor of shape (batch_size, seq_len, input_size); it is
89 |                 flattened to (batch_size * seq_len, input_size) before routing, and each
90 |                 token is dispatched to its top-k experts before the gated expert outputs
91 |                 are summed back per token.
92 | 
93 |         Returns:
94 |             Tensor: Output tensor.
95 |             float: Gating loss.
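        Dispatch sketch (illustrative routing, not from the original code): with 3 tokens,
        4 experts and top_k = 2, the router might select experts [[0, 2], [2, 3], [0, 1]].
        `compute_gating` then yields expert_size = [2, 1, 2, 1] and a batch_index that groups
        the token copies per expert, so `input_linear` can process each expert's slice in one
        matmul before the gated results are scattered back with `index_add`.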
96 | """ 97 | bsz, length, emb_size = x.size() 98 | x = x.reshape(-1, emb_size) 99 | loss = self.compute_gate(x) 100 | 101 | expert_inputs = x[self.batch_index] 102 | h = self.input_linear(expert_inputs, self.expert_size) 103 | if self.glu: 104 | h, g = h.chunk(2, dim=-1) 105 | h = self.activation(h) * g 106 | else: 107 | h = self.activation(h) 108 | expert_outputs = self.output_linear(h, self.expert_size) 109 | 110 | expert_outputs = expert_outputs * self.batch_gates[:, None] 111 | 112 | zeros = torch.zeros((bsz * length, self.input_size), dtype=expert_outputs.dtype, device=expert_outputs.device) 113 | y = zeros.index_add(0, self.batch_index, expert_outputs) 114 | y = y.view(bsz, length, self.input_size) 115 | if self.bias is not None: 116 | y = y + self.bias 117 | return y, loss 118 | 119 | def single_forward(self, x): 120 | bsz, length, emb_size = x.size() 121 | 122 | x = x.reshape(1, self.input_size) 123 | top_k_indices, top_k_gates = self.router(x) 124 | loss = self.router.loss 125 | 126 | y_list = [] 127 | for i in range(self.top_k): 128 | expert_idx = top_k_indices[0, i] 129 | 130 | h = F.linear(x, self.input_linear.weight[expert_idx]) 131 | if self.glu: 132 | h, g = h.chunk(2, dim=-1) 133 | h = self.activation(h) * g 134 | else: 135 | h = self.activation(h) 136 | y = F.linear(h, self.output_linear.weight[expert_idx]) * top_k_gates[0, i] 137 | 138 | y_list.append(y) 139 | 140 | y = sum(y_list) 141 | y = y.view(bsz, length, self.input_size) 142 | if self.bias is not None: 143 | y = y + self.bias 144 | return y, loss 145 | 146 | def forward(self, x): 147 | """ 148 | Forward pass of the mixture of experts layer. 149 | 150 | Args: 151 | x (Tensor): Input tensor. 152 | 153 | Returns: 154 | Tensor: Output tensor. 155 | """ 156 | bsz, length, emb_size = x.size() 157 | if bsz * length == 1: 158 | return self.single_forward(x) 159 | else: 160 | return self.batch_forward(x) 161 | 162 | def single_map(self, x): 163 | bsz, length, emb_size = x.size() 164 | 165 | x = x.reshape(1, self.input_size) 166 | self.top_k_indices, self.top_k_gates = self.router(x) 167 | loss = self.router.loss 168 | 169 | y_list = [] 170 | for i in range(self.top_k): 171 | expert_idx = self.top_k_indices[0, i] 172 | y = F.linear(x, self.input_linear.weight[expert_idx]) 173 | y_list.append(y) 174 | y = torch.cat(y_list, dim=0) 175 | y = y.view(bsz, length, self.top_k, -1) 176 | return y, loss 177 | 178 | def batch_map(self, x): 179 | """ 180 | 181 | Args: 182 | x: tensor shape [batch_size, input_size] 183 | train: a boolean scalar. 184 | loss_coef: a scalar - multiplier on load-balancing losses 185 | 186 | Returns: 187 | y: a tensor with shape [batch_size, output_size]. 188 | extra_training_loss: a scalar. This should be added into the overall 189 | training loss of the model. The backpropagation of this loss 190 | encourages all experts to be approximately equally used across a batch. 191 | """ 192 | """ 193 | Map input through the mixture of experts layer. 194 | 195 | Args: 196 | x (Tensor): Input tensor. 197 | skip_mask (Tensor): Skip mask tensor. 198 | sample_topk (int): Number of experts to sample during training. 199 | return_indices (bool): Whether to return expert indices. 200 | 201 | Returns: 202 | Tensor: Output tensor. 203 | float: Gating loss. 
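        Note (illustrative): `map` and `reduce` split the usual expert forward in two so the
        attention module can run flash-attention in between. `map` returns one projection per
        selected expert with shape (batch_size, seq_len, top_k, projection_dim), and `reduce`
        later combines per-expert outputs of that layout back to (batch_size, seq_len,
        input_size) using the routing decisions cached by the preceding `map` call.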
204 | """ 205 | bsz, length, emb_size = x.size() 206 | x = x.reshape(-1, emb_size) 207 | loss = self.compute_gate(x) 208 | 209 | expert_inputs = x[self.batch_index] 210 | expert_outputs = self.input_linear(expert_inputs, self.expert_size) 211 | 212 | zeros = torch.zeros( 213 | (bsz * length * self.top_k, self.hidden_size), dtype=expert_outputs.dtype, device=expert_outputs.device 214 | ) 215 | y = zeros.index_add(0, self.index_sorted_experts, expert_outputs) 216 | y = y.view(bsz, length, self.top_k, -1) 217 | return y, loss 218 | 219 | def map(self, x): 220 | """ 221 | Map input through the mixture of experts layer. 222 | 223 | Args: 224 | x (Tensor): Input tensor. 225 | 226 | Returns: 227 | Tensor: Output tensor. 228 | """ 229 | bsz, length, emb_size = x.size() 230 | if bsz * length == 1: 231 | return self.single_map(x) 232 | else: 233 | return self.batch_map(x) 234 | 235 | def single_reduce(self, x): 236 | bsz, length, k, emb_size = x.size() 237 | 238 | x = x.reshape(k, emb_size) 239 | 240 | y_list = [] 241 | for i in range(self.top_k): 242 | expert_idx = self.top_k_indices[0, i] 243 | y = F.linear(x[i], self.output_linear.weight[expert_idx]) * self.top_k_gates[0, i] 244 | y_list.append(y) 245 | y = sum(y_list) 246 | y = y.view(bsz, length, self.input_size) 247 | if self.bias is not None: 248 | y = y + self.bias 249 | return y 250 | 251 | def batch_reduce(self, x): 252 | """ 253 | Reduce the mapped output. 254 | 255 | Args: 256 | x (Tensor): Mapped output tensor. 257 | multiply_by_gates (bool): Whether to multiply outputs by gating values. 258 | 259 | Returns: 260 | Tensor: Reduced output tensor. 261 | """ 262 | 263 | bsz, length, k, emb_size = x.size() 264 | x = x.reshape(-1, emb_size) 265 | 266 | expert_inputs = x[self.index_sorted_experts] 267 | expert_outputs = self.output_linear(expert_inputs, self.expert_size) 268 | 269 | expert_outputs = expert_outputs * self.batch_gates[:, None] 270 | 271 | zeros = torch.zeros((bsz * length, self.input_size), dtype=expert_outputs.dtype, device=expert_outputs.device) 272 | y = zeros.index_add(0, self.batch_index, expert_outputs) 273 | y = y.view(bsz, length, self.input_size) 274 | if self.bias is not None: 275 | y = y + self.bias 276 | return y 277 | 278 | def reduce(self, x): 279 | """ 280 | Reduce the mapped output. 281 | 282 | Args: 283 | x (Tensor): Mapped output tensor. 284 | 285 | Returns: 286 | Tensor: Reduced output tensor. 287 | """ 288 | bsz, length, k, emb_size = x.size() 289 | if bsz * length == 1: 290 | return self.single_reduce(x) 291 | else: 292 | return self.batch_reduce(x) -------------------------------------------------------------------------------- /jetmoe/utils/parallel_experts.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | @torch.jit.script 7 | def compute_gating(k: int, num_experts: int, top_k_gates: torch.Tensor, top_k_indices: torch.Tensor): 8 | """ 9 | Compute gating values for the mixture of experts based on probabilities and top-k indices. 10 | 11 | Args: 12 | k (int): Number of experts to select. 13 | num_experts (int): Total number of experts. 14 | top_k_gates (torch.Tensor): Gating values for top-k experts (batch_size x k). 15 | top_k_indices (torch.Tensor): Indices of top-k experts (batch_size x k). 16 | 17 | Returns: 18 | torch.Tensor: Batch-level gating values. 19 | torch.Tensor: Batch-level expert indices. 20 | torch.Tensor: Expert size for each expert. 
21 | torch.Tensor: Sorted indices of top-k experts. 22 | """ 23 | zeros = torch.zeros([top_k_gates.size(0), num_experts], dtype=top_k_gates.dtype, device=top_k_gates.device) 24 | gates = zeros.scatter(1, top_k_indices, 1) 25 | expert_size = gates.long().sum(0) 26 | top_k_gates = top_k_gates.flatten() 27 | top_k_experts = top_k_indices.flatten() 28 | _, index_sorted_experts = top_k_experts.sort(0) 29 | batch_index = index_sorted_experts.div(k, rounding_mode="trunc") 30 | batch_gates = top_k_gates[index_sorted_experts] 31 | return batch_gates, batch_index, expert_size, index_sorted_experts 32 | 33 | 34 | class ParallelExperts(nn.Module): 35 | def __init__(self, num_experts, input_size, output_size) -> None: 36 | """ 37 | Initialize the ParallelExperts module. 38 | 39 | Args: 40 | num_experts (int): Number of experts. 41 | input_size (int): Size of the input. 42 | output_size (int): Size of the output. 43 | bias (bool): Whether to include bias terms. 44 | """ 45 | super().__init__() 46 | self.weight = nn.Parameter(torch.empty(num_experts, output_size, input_size)) 47 | self.reset_parameters() 48 | self.num_experts = num_experts 49 | self.input_size = input_size 50 | self.output_size = output_size 51 | 52 | def extra_repr(self): 53 | return "num_experts={}, input_size={}, output_size={}".format( 54 | self.num_experts, self.input_size, self.output_size 55 | ) 56 | 57 | def reset_parameters(self) -> None: 58 | """ 59 | Reset the parameters of the model. 60 | """ 61 | nn.init.uniform_(self.weight, -1.0 / self.weight.size(1), 1.0 / self.weight.size(1)) 62 | 63 | def forward(self, inputs, expert_size): 64 | """ 65 | Forward pass of the ParallelExperts module. 66 | 67 | Args: 68 | inputs (Tensor): Input tensor. 69 | expert_size: Expert size information. 70 | 71 | Returns: 72 | Tensor: Output tensor. 
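        Example (illustrative sizes): with num_experts = 3 and expert_size = [4, 0, 2], `inputs`
        holds 6 rows already sorted by expert; rows 0-3 are multiplied by `self.weight[0]`, no
        rows go to expert 1, rows 4-5 use `self.weight[2]`, and the per-expert outputs are
        concatenated back in the same order.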
73 | """ 74 | input_list = inputs.split(expert_size, dim=0) 75 | output_list = [] 76 | for i in range(self.num_experts): 77 | output_list.append(F.linear(input_list[i], self.weight[i])) 78 | results = torch.cat(output_list, dim=0) 79 | return results -------------------------------------------------------------------------------- /resources/1-myshell-mit.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/myshell-ai/JetMoE/0a9bc5a32386af9ea7abe9e32298976a94da908a/resources/1-myshell-mit.png -------------------------------------------------------------------------------- /resources/2-performance.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/myshell-ai/JetMoE/0a9bc5a32386af9ea7abe9e32298976a94da908a/resources/2-performance.png -------------------------------------------------------------------------------- /resources/3-architecture.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/myshell-ai/JetMoE/0a9bc5a32386af9ea7abe9e32298976a94da908a/resources/3-architecture.png -------------------------------------------------------------------------------- /resources/4-phase1-data.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/myshell-ai/JetMoE/0a9bc5a32386af9ea7abe9e32298976a94da908a/resources/4-phase1-data.png -------------------------------------------------------------------------------- /resources/5-phase2-data.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/myshell-ai/JetMoE/0a9bc5a32386af9ea7abe9e32298976a94da908a/resources/5-phase2-data.png -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | setup(name='jetmoe', 4 | packages=find_packages(), 5 | install_requires=[ 6 | 'torch', 7 | 'transformers', 8 | 'scattermoe @ git+https://github.com/shawntan/scattermoe@main#egg=scattermoe' 9 | ]) --------------------------------------------------------------------------------