├── .gitignore ├── LICENSE ├── README.md ├── examples ├── convert.py ├── mixtral │ ├── README.md │ ├── __init__.py │ ├── configuration_mixtral.py │ └── modeling_mixtral.py └── molora.py ├── scattermoe ├── __init__.py ├── kernels │ ├── __init__.py │ ├── ops.py │ └── single.py ├── mlp.py └── parallel_experts.py ├── setup.py └── tests └── test_mlp.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 160 | .idea/ 161 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. 
For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # scattermoe 2 | Triton-based implementation of Sparse Mixture-of-Experts (SMoE) on GPUs. 3 | ScatterMoE builds upon existing implementations, overcoming some of their limitations to improve inference and training speed and to reduce the memory footprint. 4 | It achieves this by avoiding padding and avoiding excessive copies of the input.
5 | We also fuse expert linear transforms and reordering operations with `ParallelLinear`, a module that can be used to extend the concept of SMoEs. 6 | 7 | This implementation is lightweight (~700 lines). 8 | It will work within an FSDP or pipeline parallel framework, but does not include any additional multi-node training infrastructure code. 9 | You can find the report [here](https://arxiv.org/abs/2403.08245) 10 | 11 | ## Installation 12 | ```sh 13 | # Check all is working well. 14 | PYTHONPATH=. pytest tests 15 | # Install editable. This will allow you to modify scattermoe in this directory. 16 | pip install -e . 17 | ``` 18 | 19 | ## Usage 20 | ```python 21 | from scattermoe.mlp import MLP 22 | 23 | # Initialise module... 24 | mlp = MLP( 25 | input_size=x_dim, hidden_size=h_dim, 26 | activation=nn.GELU(), 27 | num_experts=E, top_k=k 28 | ) 29 | 30 | # Calling module... 31 | Y = mlp( 32 | X, # input tensor 33 | k_weights, # top-k weights from router 34 | k_idxs # top-k indices from router 35 | ) 36 | ``` 37 | 38 | ## Bibtex 39 | If you use ScatterMoE in your project, cite us! 40 | ```bibtex 41 | @article{tan2024scattered, 42 | title={Scattered Mixture-of-Experts Implementation}, 43 | author={Tan, Shawn and Shen, Yikang and Panda, Rameswar and Courville, Aaron}, 44 | journal={arXiv preprint arXiv:2403.08245}, 45 | year={2024} 46 | } 47 | ``` 48 | 49 | Enjoy! 50 | ---- 51 | ### Version 0.2.0 52 | 53 | - Made compileable. 54 | 55 | ----- 56 | ### More examples 57 | 1. [Integration into HuggingFace Mixtral](https://github.com/shawntan/scattermoe/tree/main/examples/mixtral) 58 | -------------------------------------------------------------------------------- /examples/convert.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import transformers 3 | import gc 4 | from mixtral.modeling_mixtral import MixtralModel, MixtralForCausalLM 5 | from mixtral.configuration_mixtral import MixtralConfig 6 | MODEL_NAME = "mistralai/Mixtral-8x7B-v0.1" 7 | import sys 8 | 9 | if __name__ == "__main__": 10 | target_directory = sys.argv[1] 11 | dtype = torch.bfloat16 12 | config = MixtralConfig.from_pretrained(MODEL_NAME, low_cpu_mem_usage=True, torch_dtype=dtype) 13 | num_experts = config.num_local_experts 14 | print("Loading original...") 15 | model_orig = transformers.MixtralForCausalLM.from_pretrained(MODEL_NAME, torch_dtype=dtype, low_cpu_mem_usage=True) 16 | print("Initialising ScatterMoE") 17 | model = MixtralForCausalLM(config).to(dtype) 18 | state_dict_orig = model_orig.state_dict() 19 | for n, p in model.named_parameters(): 20 | assert p.dtype == torch.bfloat16 21 | if n in state_dict_orig: 22 | p.data[:] = state_dict_orig.pop(n) 23 | else: 24 | prefix, suffix = n.split('moe_mlp') 25 | for i in range(num_experts): 26 | if suffix == ".output_experts.weight": 27 | w2_param_name = prefix + "experts.%d.w2.weight" % i 28 | assert state_dict_orig[w2_param_name].dtype == torch.bfloat16 29 | p.data[i, :, :] = state_dict_orig.pop(w2_param_name) 30 | else: 31 | w1_param_name = prefix + "experts.%d.w1.weight" % i 32 | w3_param_name = prefix + "experts.%d.w3.weight" % i 33 | out_dim, in_dim = state_dict_orig[w1_param_name].size() 34 | p.data[i, :out_dim, :] = state_dict_orig.pop(w3_param_name) 35 | p.data[i, out_dim:, :] = state_dict_orig.pop(w1_param_name) 36 | assert len(state_dict_orig) == 0 37 | print("Saving to file.") 38 | model.to(dtype=torch.bfloat16).save_pretrained(target_directory, save_config=True) 39 | 40 | 
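The conversion loop in `convert.py` above packs each Mixtral expert's `w3` and `w1` matrices into a single fused input-projection weight and copies `w2` into the output projection, matching the layout that `scattermoe.mlp.GLUMLP` consumes. The snippet below is a standalone sketch of that packing convention, not part of the repository; the tensor sizes are small placeholders and the names `experts_weight` / `output_experts_weight` are only illustrative of the fused per-expert slices.

```python
import torch

h_dim, ffn_dim, num_experts = 16, 32, 4  # toy sizes, not the real Mixtral dimensions

# Per-expert Mixtral weights: w1/w3 map h_dim -> ffn_dim, w2 maps ffn_dim -> h_dim.
w1 = [torch.randn(ffn_dim, h_dim) for _ in range(num_experts)]
w3 = [torch.randn(ffn_dim, h_dim) for _ in range(num_experts)]
w2 = [torch.randn(h_dim, ffn_dim) for _ in range(num_experts)]

# Fused input projection: w3 fills the top half of each expert slice and w1 the
# bottom half, mirroring `p.data[i, :out_dim] = w3` / `p.data[i, out_dim:] = w1`.
experts_weight = torch.stack(
    [torch.cat([w3[i], w1[i]], dim=0) for i in range(num_experts)]
)

# Output projection: each expert's w2 is copied into its own slice unchanged.
output_experts_weight = torch.stack(w2)

assert experts_weight.shape == (num_experts, 2 * ffn_dim, h_dim)
assert output_experts_weight.shape == (num_experts, h_dim, ffn_dim)
```

With this layout a single grouped matmul produces both GLU branches for every expert, which is why the converted checkpoint carries one fused parameter per MoE layer instead of separate `w1`/`w3` modules.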
-------------------------------------------------------------------------------- /examples/mixtral/README.md: -------------------------------------------------------------------------------- 1 | # ScatterMoE Mixtral 2 | 3 | Example integration of ScatterMoE into HuggingFace's implementation of Mixtral. 4 | We replace `MixtralSparseMoeBlock` ([original source](https://github.com/huggingface/transformers/blob/v4.38.2/src/transformers/models/mixtral/modeling_mixtral.py#L816)) with a ScatterMoE implementation ([source](https://github.com/shawntan/scattermoe/blob/main/examples/mixtral/modeling_mixtral.py#L667)). 5 | 6 | We do not support loading the existing Mixtral checkpoint for now, but a model can be initialised from scratch: 7 | ```python 8 | config = MixtralConfig.from_pretrained( 9 | "mistralai/Mixtral-8x7B-v0.1", 10 | torch_dtype=torch.bfloat16, 11 | low_cpu_mem_usage=True, 12 | attn_implementation='flash_attention_2' 13 | ) 14 | ``` 15 | For training, also set: 16 | ```python 17 | config.output_router_logits = True 18 | ``` 19 | This will ensure that the auxiliary loss is added to the loss computed for training. The MoE auxiliary losses are 20 | load balancing losses that discourage over-reliance on only a few experts during training. Then, 21 | ```python 22 | model = MixtralForCausalLM(config) 23 | ``` 24 | 25 | 26 | 27 | -------------------------------------------------------------------------------- /examples/mixtral/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Mixtral AI and The HuggingFace Inc. team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License.
14 | from typing import TYPE_CHECKING 15 | 16 | from transformers.utils import ( 17 | OptionalDependencyNotAvailable, 18 | _LazyModule, 19 | is_torch_available, 20 | ) 21 | 22 | 23 | _import_structure = { 24 | "configuration_mixtral": ["MIXTRAL_PRETRAINED_CONFIG_ARCHIVE_MAP", "MixtralConfig"], 25 | } 26 | 27 | 28 | try: 29 | if not is_torch_available(): 30 | raise OptionalDependencyNotAvailable() 31 | except OptionalDependencyNotAvailable: 32 | pass 33 | else: 34 | _import_structure["modeling_mixtral"] = [ 35 | "MixtralForCausalLM", 36 | "MixtralModel", 37 | "MixtralPreTrainedModel", 38 | "MixtralForSequenceClassification", 39 | ] 40 | 41 | 42 | if TYPE_CHECKING: 43 | from .configuration_mixtral import MIXTRAL_PRETRAINED_CONFIG_ARCHIVE_MAP, MixtralConfig 44 | 45 | try: 46 | if not is_torch_available(): 47 | raise OptionalDependencyNotAvailable() 48 | except OptionalDependencyNotAvailable: 49 | pass 50 | else: 51 | from .modeling_mixtral import ( 52 | MixtralForCausalLM, 53 | MixtralForSequenceClassification, 54 | MixtralModel, 55 | MixtralPreTrainedModel, 56 | ) 57 | 58 | 59 | else: 60 | import sys 61 | 62 | sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__) 63 | -------------------------------------------------------------------------------- /examples/mixtral/configuration_mixtral.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2023 Mixtral AI and the HuggingFace Inc. team. All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """ Mixtral model configuration""" 16 | 17 | from transformers.configuration_utils import PretrainedConfig 18 | from transformers.utils import logging 19 | 20 | 21 | logger = logging.get_logger(__name__) 22 | 23 | MIXTRAL_PRETRAINED_CONFIG_ARCHIVE_MAP = { 24 | "mistral-ai/Mixtral-8x7B": "https://huggingface.co/mistral-ai/Mixtral-8x7B/resolve/main/config.json", 25 | } 26 | 27 | 28 | class MixtralConfig(PretrainedConfig): 29 | r""" 30 | This is the configuration class to store the configuration of a [`MixtralModel`]. It is used to instantiate an 31 | Mixtral model according to the specified arguments, defining the model architecture. Instantiating a configuration 32 | with the defaults will yield a similar configuration to that of the Mixtral-7B-v0.1 or Mixtral-7B-Instruct-v0.1. 33 | 34 | [mixtralai/Mixtral-8x7B](https://huggingface.co/mixtralai/Mixtral-8x7B) 35 | [mixtralai/Mixtral-7B-Instruct-v0.1](https://huggingface.co/mixtralai/Mixtral-7B-Instruct-v0.1) 36 | 37 | Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the 38 | documentation from [`PretrainedConfig`] for more information. 39 | 40 | 41 | Args: 42 | vocab_size (`int`, *optional*, defaults to 32000): 43 | Vocabulary size of the Mixtral model. 
Defines the number of different tokens that can be represented by the 44 | `inputs_ids` passed when calling [`MixtralModel`] 45 | hidden_size (`int`, *optional*, defaults to 4096): 46 | Dimension of the hidden representations. 47 | intermediate_size (`int`, *optional*, defaults to 14336): 48 | Dimension of the MLP representations. 49 | num_hidden_layers (`int`, *optional*, defaults to 32): 50 | Number of hidden layers in the Transformer encoder. 51 | num_attention_heads (`int`, *optional*, defaults to 32): 52 | Number of attention heads for each attention layer in the Transformer encoder. 53 | num_key_value_heads (`int`, *optional*, defaults to 8): 54 | This is the number of key_value heads that should be used to implement Grouped Query Attention. If 55 | `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if 56 | `num_key_value_heads=1 the model will use Multi Query Attention (MQA) otherwise GQA is used. When 57 | converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed 58 | by meanpooling all the original heads within that group. For more details checkout [this 59 | paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `8`. 60 | hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): 61 | The non-linear activation function (function or string) in the decoder. 62 | max_position_embeddings (`int`, *optional*, defaults to `4096*32`): 63 | The maximum sequence length that this model might ever be used with. Mixtral's sliding window attention 64 | allows sequence of up to 4096*32 tokens. 65 | initializer_range (`float`, *optional*, defaults to 0.02): 66 | The standard deviation of the truncated_normal_initializer for initializing all weight matrices. 67 | rms_norm_eps (`float`, *optional*, defaults to 1e-05): 68 | The epsilon used by the rms normalization layers. 69 | use_cache (`bool`, *optional*, defaults to `True`): 70 | Whether or not the model should return the last key/values attentions (not used by all models). Only 71 | relevant if `config.is_decoder=True`. 72 | pad_token_id (`int`, *optional*): 73 | The id of the padding token. 74 | bos_token_id (`int`, *optional*, defaults to 1): 75 | The id of the "beginning-of-sequence" token. 76 | eos_token_id (`int`, *optional*, defaults to 2): 77 | The id of the "end-of-sequence" token. 78 | tie_word_embeddings (`bool`, *optional*, defaults to `False`): 79 | Whether the model's input and output word embeddings should be tied. 80 | rope_theta (`float`, *optional*, defaults to 1000000.0): 81 | The base period of the RoPE embeddings. 82 | sliding_window (`int`, *optional*, defaults to 4096): 83 | Sliding window attention window size. If not specified, will default to `4096`. 84 | attention_dropout (`float`, *optional*, defaults to 0.0): 85 | The dropout ratio for the attention probabilities. 86 | num_experts_per_tok (`int`, *optional*, defaults to 2): 87 | The number of experts to root per-token, can be also interpreted as the `top-p` routing 88 | parameter 89 | num_local_experts (`int`, *optional*, defaults to 8): 90 | Number of experts per Sparse MLP layer. 91 | output_router_logits (`bool`, *optional*, defaults to `False`): 92 | Whether or not the router logits should be returned by the model. Enabeling this will also 93 | allow the model to output the auxiliary loss. See [here]() for more details 94 | router_aux_loss_coef (`float`, *optional*, defaults to 0.001): 95 | The aux loss factor for the total loss. 
96 | 97 | ```python 98 | >>> from transformers import MixtralModel, MixtralConfig 99 | 100 | >>> # Initializing a Mixtral 7B style configuration 101 | >>> configuration = MixtralConfig() 102 | 103 | >>> # Initializing a model from the Mixtral 7B style configuration 104 | >>> model = MixtralModel(configuration) 105 | 106 | >>> # Accessing the model configuration 107 | >>> configuration = model.config 108 | ```""" 109 | 110 | model_type = "mixtral" 111 | keys_to_ignore_at_inference = ["past_key_values"] 112 | 113 | def __init__( 114 | self, 115 | vocab_size=32000, 116 | hidden_size=4096, 117 | intermediate_size=14336, 118 | num_hidden_layers=32, 119 | num_attention_heads=32, 120 | num_key_value_heads=8, 121 | hidden_act="silu", 122 | max_position_embeddings=4096 * 32, 123 | initializer_range=0.02, 124 | rms_norm_eps=1e-5, 125 | use_cache=True, 126 | pad_token_id=None, 127 | bos_token_id=1, 128 | eos_token_id=2, 129 | tie_word_embeddings=False, 130 | rope_theta=1e6, 131 | sliding_window=4096, 132 | attention_dropout=0.0, 133 | num_experts_per_tok=2, 134 | num_local_experts=8, 135 | output_router_logits=False, 136 | router_aux_loss_coef=0.001, 137 | **kwargs, 138 | ): 139 | self.vocab_size = vocab_size 140 | self.max_position_embeddings = max_position_embeddings 141 | self.hidden_size = hidden_size 142 | self.intermediate_size = intermediate_size 143 | self.num_hidden_layers = num_hidden_layers 144 | self.num_attention_heads = num_attention_heads 145 | self.sliding_window = sliding_window 146 | 147 | # for backward compatibility 148 | if num_key_value_heads is None: 149 | num_key_value_heads = num_attention_heads 150 | 151 | self.num_key_value_heads = num_key_value_heads 152 | self.hidden_act = hidden_act 153 | self.initializer_range = initializer_range 154 | self.rms_norm_eps = rms_norm_eps 155 | self.use_cache = use_cache 156 | self.rope_theta = rope_theta 157 | self.attention_dropout = attention_dropout 158 | 159 | self.num_experts_per_tok = num_experts_per_tok 160 | self.num_local_experts = num_local_experts 161 | self.output_router_logits = output_router_logits 162 | self.router_aux_loss_coef = router_aux_loss_coef 163 | super().__init__( 164 | pad_token_id=pad_token_id, 165 | bos_token_id=bos_token_id, 166 | eos_token_id=eos_token_id, 167 | tie_word_embeddings=tie_word_embeddings, 168 | **kwargs, 169 | ) 170 | -------------------------------------------------------------------------------- /examples/mixtral/modeling_mixtral.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2023 Mistral AI and the HuggingFace Inc. team. All rights reserved. 3 | # 4 | # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX 5 | # and OPT implementations in this library. It has been modified from its 6 | # original forms to accommodate minor architectural differences compared 7 | # to GPT-NeoX and OPT used by the Meta AI team that trained the model. 8 | # 9 | # Licensed under the Apache License, Version 2.0 (the "License"); 10 | # you may not use this file except in compliance with the License. 11 | # You may obtain a copy of the License at 12 | # 13 | # http://www.apache.org/licenses/LICENSE-2.0 14 | # 15 | # Unless required by applicable law or agreed to in writing, software 16 | # distributed under the License is distributed on an "AS IS" BASIS, 17 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
18 | # See the License for the specific language governing permissions and 19 | # limitations under the License. 20 | """ PyTorch Mixtral model.""" 21 | import inspect 22 | import math 23 | import warnings 24 | from typing import List, Optional, Tuple, Union 25 | 26 | import torch 27 | import torch.nn.functional as F 28 | import torch.utils.checkpoint 29 | from torch import nn 30 | from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss 31 | 32 | from transformers.activations import ACT2FN 33 | from transformers.cache_utils import Cache, DynamicCache 34 | from transformers.modeling_attn_mask_utils import ( 35 | _prepare_4d_causal_attention_mask, 36 | ) 37 | from transformers.modeling_outputs import ( 38 | MoeCausalLMOutputWithPast, 39 | MoeModelOutputWithPast, 40 | SequenceClassifierOutputWithPast, 41 | ) 42 | 43 | from transformers.modeling_utils import PreTrainedModel 44 | from transformers.pytorch_utils import is_torch_greater_or_equal_than_1_13 45 | from transformers.utils import ( 46 | add_start_docstrings, 47 | add_start_docstrings_to_model_forward, 48 | is_flash_attn_2_available, 49 | is_flash_attn_greater_or_equal_2_10, 50 | logging, 51 | replace_return_docstrings, 52 | ) 53 | from transformers.utils.import_utils import is_torch_fx_available 54 | from .configuration_mixtral import MixtralConfig 55 | 56 | from megablocks.layers.dmoe import ParallelDroplessMLP 57 | from megablocks.layers.arguments import Arguments 58 | 59 | import scattermoe 60 | 61 | if is_flash_attn_2_available(): 62 | print("USING FLASH ATTN") 63 | from flash_attn import flash_attn_func, flash_attn_varlen_func 64 | from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa 65 | 66 | _flash_supports_window_size = "window_size" in list(inspect.signature(flash_attn_func).parameters) 67 | else: 68 | print("USing nive attention") 69 | 70 | # This makes `_prepare_4d_causal_attention_mask` a leaf function in the FX graph. 71 | # It means that the function will not be traced through and simply appear as a node in the graph. 72 | if is_torch_fx_available(): 73 | if not is_torch_greater_or_equal_than_1_13: 74 | import torch.fx 75 | 76 | _prepare_4d_causal_attention_mask = torch.fx.wrap(_prepare_4d_causal_attention_mask) 77 | 78 | 79 | logger = logging.get_logger(__name__) 80 | 81 | _CONFIG_FOR_DOC = "MixtralConfig" 82 | 83 | 84 | def load_balancing_loss_func(gate_logits: torch.Tensor, num_experts: torch.Tensor = None, top_k=2) -> float: 85 | r""" 86 | Computes auxiliary load balancing loss as in Switch Transformer - implemented in Pytorch. 87 | 88 | See Switch Transformer (https://arxiv.org/abs/2101.03961) for more details. This function implements the loss 89 | function presented in equations (4) - (6) of the paper. It aims at penalizing cases where the routing between 90 | experts is too unbalanced. 91 | 92 | Args: 93 | gate_logits (Union[`torch.Tensor`, Tuple[torch.Tensor]): 94 | Logits from the `gate`, should be a tuple of tensors. Shape: [batch_size, seqeunce_length, num_experts]. 95 | num_experts (`int`, *optional*): 96 | Number of experts 97 | 98 | Returns: 99 | The auxiliary loss. 100 | """ 101 | if gate_logits is None: 102 | return 0 103 | 104 | if isinstance(gate_logits, tuple): 105 | # cat along the layers? 
106 | gate_logits = torch.cat(gate_logits, dim=0) 107 | 108 | routing_weights, selected_experts = torch.topk(gate_logits, top_k, dim=-1) 109 | routing_weights = routing_weights.softmax(dim=-1) 110 | 111 | # cast the expert indices to int64, otherwise one-hot encoding will fail 112 | if selected_experts.dtype != torch.int64: 113 | selected_experts = selected_experts.to(torch.int64) 114 | 115 | if len(selected_experts.shape) == 2: 116 | selected_experts = selected_experts.unsqueeze(2) 117 | 118 | expert_mask = torch.nn.functional.one_hot(selected_experts, num_experts) 119 | 120 | # For a given token, determine if it was routed to a given expert. 121 | expert_mask = torch.max(expert_mask, axis=-2).values 122 | 123 | # cast to float32 otherwise mean will fail 124 | expert_mask = expert_mask.to(torch.float32) 125 | tokens_per_group_and_expert = torch.mean(expert_mask, axis=-2) 126 | 127 | router_prob_per_group_and_expert = torch.mean(routing_weights, axis=-1) 128 | return torch.mean(tokens_per_group_and_expert * router_prob_per_group_and_expert.unsqueeze(-1)) * (num_experts**2) 129 | 130 | 131 | # Copied from transformers.models.llama.modeling_llama._get_unpad_data 132 | def _get_unpad_data(attention_mask): 133 | seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) 134 | indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() 135 | max_seqlen_in_batch = seqlens_in_batch.max().item() 136 | cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0)) 137 | return ( 138 | indices, 139 | cu_seqlens, 140 | max_seqlen_in_batch, 141 | ) 142 | 143 | 144 | # Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->Mixtral 145 | class MixtralRMSNorm(nn.Module): 146 | def __init__(self, hidden_size, eps=1e-6): 147 | """ 148 | MixtralRMSNorm is equivalent to T5LayerNorm 149 | """ 150 | super().__init__() 151 | self.weight = nn.Parameter(torch.ones(hidden_size)) 152 | self.variance_epsilon = eps 153 | 154 | def forward(self, hidden_states): 155 | input_dtype = hidden_states.dtype 156 | hidden_states = hidden_states.to(torch.float32) 157 | variance = hidden_states.pow(2).mean(-1, keepdim=True) 158 | hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) 159 | return self.weight * hidden_states.to(input_dtype) 160 | 161 | 162 | # Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->Mixtral 163 | class MixtralRotaryEmbedding(nn.Module): 164 | def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): 165 | super().__init__() 166 | 167 | self.dim = dim 168 | self.max_position_embeddings = max_position_embeddings 169 | self.base = base 170 | inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)) 171 | self.register_buffer("inv_freq", inv_freq, persistent=False) 172 | 173 | # Build here to make `torch.jit.trace` work. 
174 | self._set_cos_sin_cache( 175 | seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype() 176 | ) 177 | 178 | def _set_cos_sin_cache(self, seq_len, device, dtype): 179 | self.max_seq_len_cached = seq_len 180 | t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) 181 | 182 | freqs = torch.outer(t, self.inv_freq) 183 | # Different from paper, but it uses a different permutation in order to obtain the same calculation 184 | emb = torch.cat((freqs, freqs), dim=-1) 185 | self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) 186 | self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) 187 | 188 | def forward(self, x, seq_len=None): 189 | # x: [bs, num_attention_heads, seq_len, head_size] 190 | if seq_len > self.max_seq_len_cached: 191 | self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype) 192 | 193 | return ( 194 | self.cos_cached[:seq_len].to(dtype=x.dtype), 195 | self.sin_cached[:seq_len].to(dtype=x.dtype), 196 | ) 197 | 198 | 199 | # Copied from transformers.models.llama.modeling_llama.rotate_half 200 | def rotate_half(x): 201 | """Rotates half the hidden dims of the input.""" 202 | x1 = x[..., : x.shape[-1] // 2] 203 | x2 = x[..., x.shape[-1] // 2 :] 204 | return torch.cat((-x2, x1), dim=-1) 205 | 206 | 207 | # Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb 208 | def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): 209 | """Applies Rotary Position Embedding to the query and key tensors. 210 | 211 | Args: 212 | q (`torch.Tensor`): The query tensor. 213 | k (`torch.Tensor`): The key tensor. 214 | cos (`torch.Tensor`): The cosine part of the rotary embedding. 215 | sin (`torch.Tensor`): The sine part of the rotary embedding. 216 | position_ids (`torch.Tensor`): 217 | The position indices of the tokens corresponding to the query and key tensors. For example, this can be 218 | used to pass offsetted position ids when working with a KV-cache. 219 | unsqueeze_dim (`int`, *optional*, defaults to 1): 220 | The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and 221 | sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note 222 | that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and 223 | k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes 224 | cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have 225 | the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. 226 | Returns: 227 | `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. 228 | """ 229 | cos = cos[position_ids].unsqueeze(unsqueeze_dim) 230 | sin = sin[position_ids].unsqueeze(unsqueeze_dim) 231 | q_embed = (q * cos) + (rotate_half(q) * sin) 232 | k_embed = (k * cos) + (rotate_half(k) * sin) 233 | return q_embed, k_embed 234 | 235 | 236 | # Copied from transformers.models.llama.modeling_llama.repeat_kv 237 | def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: 238 | """ 239 | This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). 
The hidden states go from (batch, 240 | num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) 241 | """ 242 | batch, num_key_value_heads, slen, head_dim = hidden_states.shape 243 | if n_rep == 1: 244 | return hidden_states 245 | hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) 246 | return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) 247 | 248 | 249 | # Copied from transformers.models.mistral.modeling_mistral.MistralAttention with Mistral->Mixtral 250 | class MixtralAttention(nn.Module): 251 | """ 252 | Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformer 253 | and "Generating Long Sequences with Sparse Transformers". 254 | """ 255 | 256 | def __init__(self, config: MixtralConfig, layer_idx: Optional[int] = None): 257 | super().__init__() 258 | self.config = config 259 | self.layer_idx = layer_idx 260 | if layer_idx is None: 261 | logger.warning_once( 262 | f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will " 263 | "to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` " 264 | "when creating this class." 265 | ) 266 | 267 | self.hidden_size = config.hidden_size 268 | self.num_heads = config.num_attention_heads 269 | self.head_dim = self.hidden_size // self.num_heads 270 | self.num_key_value_heads = config.num_key_value_heads 271 | self.num_key_value_groups = self.num_heads // self.num_key_value_heads 272 | self.max_position_embeddings = config.max_position_embeddings 273 | self.rope_theta = config.rope_theta 274 | self.is_causal = True 275 | self.attention_dropout = config.attention_dropout 276 | 277 | if (self.head_dim * self.num_heads) != self.hidden_size: 278 | raise ValueError( 279 | f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" 280 | f" and `num_heads`: {self.num_heads})." 281 | ) 282 | self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False) 283 | self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False) 284 | self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False) 285 | self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False) 286 | 287 | self.rotary_emb = MixtralRotaryEmbedding( 288 | self.head_dim, 289 | max_position_embeddings=self.max_position_embeddings, 290 | base=self.rope_theta, 291 | ) 292 | 293 | def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): 294 | return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() 295 | 296 | def forward( 297 | self, 298 | hidden_states: torch.Tensor, 299 | attention_mask: Optional[torch.Tensor] = None, 300 | position_ids: Optional[torch.LongTensor] = None, 301 | past_key_value: Optional[Cache] = None, 302 | output_attentions: bool = False, 303 | use_cache: bool = False, 304 | **kwargs, 305 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: 306 | if "padding_mask" in kwargs: 307 | warnings.warn( 308 | "Passing `padding_mask` is deprecated and will be removed in v4.37. 
Please make sure use `attention_mask` instead.`" 309 | ) 310 | bsz, q_len, _ = hidden_states.size() 311 | 312 | query_states = self.q_proj(hidden_states) 313 | key_states = self.k_proj(hidden_states) 314 | value_states = self.v_proj(hidden_states) 315 | 316 | query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) 317 | key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) 318 | value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) 319 | 320 | kv_seq_len = key_states.shape[-2] 321 | if past_key_value is not None: 322 | if self.layer_idx is None: 323 | raise ValueError( 324 | f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} " 325 | "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " 326 | "with a layer index." 327 | ) 328 | kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) 329 | cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) 330 | query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) 331 | 332 | if past_key_value is not None: 333 | cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models 334 | key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) 335 | 336 | # repeat k/v heads if n_kv_heads < n_heads 337 | key_states = repeat_kv(key_states, self.num_key_value_groups) 338 | value_states = repeat_kv(value_states, self.num_key_value_groups) 339 | 340 | attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) 341 | 342 | if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): 343 | raise ValueError( 344 | f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is" 345 | f" {attn_weights.size()}" 346 | ) 347 | 348 | if attention_mask is not None: 349 | if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): 350 | raise ValueError( 351 | f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" 352 | ) 353 | 354 | attn_weights = attn_weights + attention_mask 355 | 356 | # upcast attention to fp32 357 | attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) 358 | attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) 359 | attn_output = torch.matmul(attn_weights, value_states) 360 | 361 | if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): 362 | raise ValueError( 363 | f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" 364 | f" {attn_output.size()}" 365 | ) 366 | 367 | attn_output = attn_output.transpose(1, 2).contiguous() 368 | attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) 369 | 370 | attn_output = self.o_proj(attn_output) 371 | 372 | if not output_attentions: 373 | attn_weights = None 374 | 375 | return attn_output, attn_weights, past_key_value 376 | 377 | 378 | # Copied from transformers.models.mistral.modeling_mistral.MistralFlashAttention2 with Mistral->Mixtral 379 | class MixtralFlashAttention2(MixtralAttention): 380 | """ 381 | Mixtral flash attention module. This module inherits from `MixtralAttention` as the weights of the module stays 382 | untouched. 
The only required change would be on the forward pass where it needs to correctly call the public API of 383 | flash attention and deal with padding tokens in case the input contains any of them. 384 | """ 385 | 386 | # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__ 387 | def __init__(self, *args, **kwargs): 388 | super().__init__(*args, **kwargs) 389 | 390 | # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. 391 | # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. 392 | # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). 393 | self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() 394 | 395 | def forward( 396 | self, 397 | hidden_states: torch.Tensor, 398 | attention_mask: Optional[torch.Tensor] = None, 399 | position_ids: Optional[torch.LongTensor] = None, 400 | past_key_value: Optional[Cache] = None, 401 | output_attentions: bool = False, 402 | use_cache: bool = False, 403 | **kwargs, 404 | ): 405 | if "padding_mask" in kwargs: 406 | warnings.warn( 407 | "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" 408 | ) 409 | 410 | # overwrite attention_mask with padding_mask 411 | attention_mask = kwargs.pop("padding_mask") 412 | bsz, q_len, _ = hidden_states.size() 413 | 414 | query_states = self.q_proj(hidden_states) 415 | key_states = self.k_proj(hidden_states) 416 | value_states = self.v_proj(hidden_states) 417 | 418 | query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) 419 | key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) 420 | value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) 421 | 422 | kv_seq_len = key_states.shape[-2] 423 | if past_key_value is not None: 424 | kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) 425 | 426 | # Because the input can be padded, the absolute sequence length depends on the max position id. 427 | rotary_seq_len = max(kv_seq_len, position_ids[:, -1].max().item()) + 1 428 | cos, sin = self.rotary_emb(value_states, seq_len=rotary_seq_len) 429 | 430 | query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) 431 | 432 | use_sliding_windows = ( 433 | _flash_supports_window_size 434 | and getattr(self.config, "sliding_window", None) is not None 435 | and kv_seq_len > self.config.sliding_window 436 | ) 437 | 438 | if not _flash_supports_window_size: 439 | logger.warning_once( 440 | "The current flash attention version does not support sliding window attention, for a more memory efficient implementation" 441 | " make sure to upgrade flash-attn library." 
442 | ) 443 | 444 | if past_key_value is not None: 445 | # Activate slicing cache only if the config has a value `sliding_windows` attribute 446 | if getattr(self.config, "sliding_window", None) is not None and kv_seq_len > self.config.sliding_window: 447 | slicing_tokens = 1 - self.config.sliding_window 448 | 449 | past_key = past_key_value[0] 450 | past_value = past_key_value[1] 451 | 452 | past_key = past_key[:, :, slicing_tokens:, :].contiguous() 453 | past_value = past_value[:, :, slicing_tokens:, :].contiguous() 454 | 455 | if past_key.shape[-2] != self.config.sliding_window - 1: 456 | raise ValueError( 457 | f"past key must have a shape of (`batch_size, num_heads, self.config.sliding_window-1, head_dim`), got" 458 | f" {past_key.shape}" 459 | ) 460 | 461 | past_key_value = (past_key, past_value) 462 | 463 | if attention_mask is not None: 464 | attention_mask = attention_mask[:, slicing_tokens:] 465 | attention_mask = torch.cat([attention_mask, torch.ones_like(attention_mask[:, -1:])], dim=-1) 466 | 467 | cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models 468 | key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) 469 | 470 | # repeat k/v heads if n_kv_heads < n_heads 471 | key_states = repeat_kv(key_states, self.num_key_value_groups) 472 | value_states = repeat_kv(value_states, self.num_key_value_groups) 473 | dropout_rate = 0.0 if not self.training else self.attention_dropout 474 | 475 | # In PEFT, usually we cast the layer norms in float32 for training stability reasons 476 | # therefore the input hidden states gets silently casted in float32. Hence, we need 477 | # cast them back in float16 just to be sure everything works as expected. 478 | input_dtype = query_states.dtype 479 | if input_dtype == torch.float32: 480 | # Handle the case where the model is quantized 481 | if hasattr(self.config, "_pre_quantization_dtype"): 482 | target_dtype = self.config._pre_quantization_dtype 483 | else: 484 | target_dtype = self.q_proj.weight.dtype 485 | 486 | logger.warning_once( 487 | f"The input hidden states seems to be silently casted in float32, this might be related to" 488 | f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" 489 | f" {target_dtype}." 
490 | ) 491 | 492 | query_states = query_states.to(target_dtype) 493 | key_states = key_states.to(target_dtype) 494 | value_states = value_states.to(target_dtype) 495 | 496 | # Reashape to the expected shape for Flash Attention 497 | query_states = query_states.transpose(1, 2) 498 | key_states = key_states.transpose(1, 2) 499 | value_states = value_states.transpose(1, 2) 500 | 501 | attn_output = self._flash_attention_forward( 502 | query_states, 503 | key_states, 504 | value_states, 505 | attention_mask, 506 | q_len, 507 | dropout=dropout_rate, 508 | use_sliding_windows=use_sliding_windows, 509 | ) 510 | 511 | attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous() 512 | attn_output = self.o_proj(attn_output) 513 | 514 | if not output_attentions: 515 | attn_weights = None 516 | 517 | return attn_output, attn_weights, past_key_value 518 | 519 | def _flash_attention_forward( 520 | self, 521 | query_states, 522 | key_states, 523 | value_states, 524 | attention_mask, 525 | query_length, 526 | dropout=0.0, 527 | softmax_scale=None, 528 | use_sliding_windows=False, 529 | ): 530 | """ 531 | Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token 532 | first unpad the input, then computes the attention scores and pad the final attention scores. 533 | 534 | Args: 535 | query_states (`torch.Tensor`): 536 | Input query states to be passed to Flash Attention API 537 | key_states (`torch.Tensor`): 538 | Input key states to be passed to Flash Attention API 539 | value_states (`torch.Tensor`): 540 | Input value states to be passed to Flash Attention API 541 | attention_mask (`torch.Tensor`): 542 | The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the 543 | position of padding tokens and 1 for the position of non-padding tokens. 544 | dropout (`int`, *optional*): 545 | Attention dropout 546 | softmax_scale (`float`, *optional*): 547 | The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim) 548 | use_sliding_windows (`bool`, *optional*): 549 | Whether to activate sliding window attention. 550 | """ 551 | if not self._flash_attn_uses_top_left_mask: 552 | causal = self.is_causal 553 | else: 554 | # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__. 
555 | causal = self.is_causal and query_length != 1 556 | 557 | # Contains at least one padding token in the sequence 558 | if attention_mask is not None: 559 | batch_size = query_states.shape[0] 560 | query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input( 561 | query_states, key_states, value_states, attention_mask, query_length 562 | ) 563 | 564 | cu_seqlens_q, cu_seqlens_k = cu_seq_lens 565 | max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens 566 | 567 | if not use_sliding_windows: 568 | attn_output_unpad = flash_attn_varlen_func( 569 | query_states, 570 | key_states, 571 | value_states, 572 | cu_seqlens_q=cu_seqlens_q, 573 | cu_seqlens_k=cu_seqlens_k, 574 | max_seqlen_q=max_seqlen_in_batch_q, 575 | max_seqlen_k=max_seqlen_in_batch_k, 576 | dropout_p=dropout, 577 | softmax_scale=softmax_scale, 578 | causal=causal, 579 | ) 580 | else: 581 | attn_output_unpad = flash_attn_varlen_func( 582 | query_states, 583 | key_states, 584 | value_states, 585 | cu_seqlens_q=cu_seqlens_q, 586 | cu_seqlens_k=cu_seqlens_k, 587 | max_seqlen_q=max_seqlen_in_batch_q, 588 | max_seqlen_k=max_seqlen_in_batch_k, 589 | dropout_p=dropout, 590 | softmax_scale=softmax_scale, 591 | causal=causal, 592 | window_size=(self.config.sliding_window, self.config.sliding_window), 593 | ) 594 | 595 | attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length) 596 | else: 597 | if not use_sliding_windows: 598 | attn_output = flash_attn_func( 599 | query_states, 600 | key_states, 601 | value_states, 602 | dropout, 603 | softmax_scale=softmax_scale, 604 | causal=causal, 605 | ) 606 | else: 607 | attn_output = flash_attn_func( 608 | query_states, 609 | key_states, 610 | value_states, 611 | dropout, 612 | softmax_scale=softmax_scale, 613 | causal=causal, 614 | window_size=(self.config.sliding_window, self.config.sliding_window), 615 | ) 616 | 617 | return attn_output 618 | 619 | def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length): 620 | batch_size, kv_seq_len, num_heads, head_dim = key_layer.shape 621 | 622 | # On the first iteration we need to properly re-create the padding mask 623 | # by slicing it on the proper place 624 | if kv_seq_len != attention_mask.shape[-1]: 625 | attention_mask_num_tokens = attention_mask.shape[-1] 626 | attention_mask = attention_mask[:, attention_mask_num_tokens - kv_seq_len :] 627 | 628 | indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask) 629 | 630 | key_layer = index_first_axis(key_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k) 631 | value_layer = index_first_axis(value_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k) 632 | 633 | if query_length == kv_seq_len: 634 | query_layer = index_first_axis( 635 | query_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k 636 | ) 637 | cu_seqlens_q = cu_seqlens_k 638 | max_seqlen_in_batch_q = max_seqlen_in_batch_k 639 | indices_q = indices_k 640 | elif query_length == 1: 641 | max_seqlen_in_batch_q = 1 642 | cu_seqlens_q = torch.arange( 643 | batch_size + 1, dtype=torch.int32, device=query_layer.device 644 | ) # There is a memcpy here, that is very bad. 645 | indices_q = cu_seqlens_q[:-1] 646 | query_layer = query_layer.squeeze(1) 647 | else: 648 | # The -q_len: slice assumes left padding. 
649 | attention_mask = attention_mask[:, -query_length:] 650 | query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask) 651 | 652 | return ( 653 | query_layer, 654 | key_layer, 655 | value_layer, 656 | indices_q, 657 | (cu_seqlens_q, cu_seqlens_k), 658 | (max_seqlen_in_batch_q, max_seqlen_in_batch_k), 659 | ) 660 | 661 | 662 | MISTRAL_ATTENTION_CLASSES = { 663 | "eager": MixtralAttention, 664 | "flash_attention_2": MixtralFlashAttention2, 665 | } 666 | 667 | class MixtralSparseMoeBlock(nn.Module): 668 | def __init__(self, config): 669 | super().__init__() 670 | self.hidden_dim = config.hidden_size 671 | self.ffn_dim = config.intermediate_size 672 | self.num_experts = config.num_local_experts 673 | self.top_k = config.num_experts_per_tok 674 | 675 | # gating 676 | self.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=False) 677 | 678 | self.moe_mlp = scattermoe.mlp.GLUMLP( 679 | input_size=self.hidden_dim, 680 | hidden_size=self.ffn_dim, 681 | num_experts=self.num_experts, 682 | top_k=self.top_k, 683 | activation=ACT2FN[config.hidden_act] 684 | ) 685 | 686 | def forward(self, hidden_states: torch.Tensor): 687 | """ """ 688 | batch_size, sequence_length, hidden_dim = hidden_states.shape 689 | hidden_states = hidden_states.view(-1, hidden_dim) 690 | # router_logits: (batch * sequence_length, n_experts) 691 | router_logits = self.gate(hidden_states) 692 | 693 | routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float) 694 | routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1) 695 | routing_weights /= routing_weights.sum(dim=-1, keepdim=True) 696 | # we cast back to the input dtype 697 | routing_weights = routing_weights.to(hidden_states.dtype) 698 | final_hidden_states = self.moe_mlp(hidden_states, routing_weights, selected_experts) 699 | final_hidden_states = final_hidden_states.view(batch_size, sequence_length, hidden_dim) 700 | return final_hidden_states, router_logits 701 | 702 | class MegablocksMixtralSparseMoeBlock(nn.Module): 703 | """ 704 | This implementation is 705 | strictly equivalent to standard MoE with full capacity (no 706 | dropped tokens). It's faster since it formulates MoE operations 707 | in terms of block-sparse operations to accomodate imbalanced 708 | assignments of tokens to experts, whereas standard MoE either 709 | (1) drop tokens at the cost of reduced performance or (2) set 710 | capacity factor to number of experts and thus waste computation 711 | and memory on padding. 
712 | """ 713 | 714 | def __init__(self, config): 715 | super().__init__() 716 | self.hidden_dim = config.hidden_size 717 | self.ffn_dim = config.intermediate_size 718 | self.num_experts = config.num_local_experts 719 | self.top_k = config.num_experts_per_tok 720 | 721 | # gating 722 | self.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=False) 723 | 724 | 725 | self.moe_mlp = ParallelDroplessMLP(Arguments( 726 | hidden_size=self.hidden_dim, 727 | ffn_hidden_size=self.ffn_dim, 728 | moe_num_experts=self.num_experts, 729 | moe_top_k=self.top_k, 730 | moe_capacity_factor=1, 731 | # init_method=partial(torch.nn.init.normal_, mean=0.0, std=0.1), 732 | mlp_type='glu', 733 | mlp_impl='sparse', 734 | fp16=False, 735 | bf16=False, 736 | bias=False 737 | )) 738 | 739 | def forward(self, hidden_states: torch.Tensor): 740 | """ """ 741 | batch_size, sequence_length, hidden_dim = hidden_states.shape 742 | hidden_states = hidden_states.view(-1, hidden_dim) 743 | # router_logits: (batch * sequence_length, n_experts) 744 | router_logits = self.gate(hidden_states) 745 | 746 | routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float) 747 | routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1) 748 | routing_weights /= routing_weights.sum(dim=-1, keepdim=True) 749 | # we cast back to the input dtype 750 | routing_weights = routing_weights.to(hidden_states.dtype) 751 | final_hidden_states = self.moe_mlp(hidden_states, router_logits, routing_weights, selected_experts) 752 | final_hidden_states = final_hidden_states.view(batch_size, sequence_length, hidden_dim) 753 | return final_hidden_states, router_logits 754 | 755 | 756 | class MixtralDecoderLayer(nn.Module): 757 | def __init__(self, config: MixtralConfig, layer_idx: int): 758 | super().__init__() 759 | self.hidden_size = config.hidden_size 760 | 761 | self.self_attn = MISTRAL_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx) 762 | 763 | self.block_sparse_moe = MixtralSparseMoeBlock(config) 764 | self.input_layernorm = MixtralRMSNorm(config.hidden_size, eps=config.rms_norm_eps) 765 | self.post_attention_layernorm = MixtralRMSNorm(config.hidden_size, eps=config.rms_norm_eps) 766 | 767 | def forward( 768 | self, 769 | hidden_states: torch.Tensor, 770 | attention_mask: Optional[torch.Tensor] = None, 771 | position_ids: Optional[torch.LongTensor] = None, 772 | past_key_value: Optional[Tuple[torch.Tensor]] = None, 773 | output_attentions: Optional[bool] = False, 774 | output_router_logits: Optional[bool] = False, 775 | use_cache: Optional[bool] = False, 776 | **kwargs, 777 | ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: 778 | if "padding_mask" in kwargs: 779 | warnings.warn( 780 | "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" 781 | ) 782 | """ 783 | Args: 784 | hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` 785 | attention_mask (`torch.FloatTensor`, *optional*): attention mask of size 786 | `(batch, sequence_length)` where padding elements are indicated by 0. 787 | past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states 788 | output_attentions (`bool`, *optional*): 789 | Whether or not to return the attentions tensors of all attention layers. See `attentions` under 790 | returned tensors for more detail. 
791 | output_router_logits (`bool`, *optional*): 792 | Whether or not to return the logits of all the routers. They are useful for computing the router loss, and 793 | should not be returned during inference. 794 | use_cache (`bool`, *optional*): 795 | If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding 796 | (see `past_key_values`). 797 | """ 798 | 799 | residual = hidden_states 800 | 801 | hidden_states = self.input_layernorm(hidden_states) 802 | 803 | # Self Attention 804 | hidden_states, self_attn_weights, present_key_value = self.self_attn( 805 | hidden_states=hidden_states, 806 | attention_mask=attention_mask, 807 | position_ids=position_ids, 808 | past_key_value=past_key_value, 809 | output_attentions=output_attentions, 810 | use_cache=use_cache, 811 | ) 812 | hidden_states = residual + hidden_states 813 | 814 | # Fully Connected 815 | residual = hidden_states 816 | hidden_states = self.post_attention_layernorm(hidden_states) 817 | hidden_states, router_logits = self.block_sparse_moe(hidden_states) 818 | hidden_states = residual + hidden_states 819 | 820 | outputs = (hidden_states,) 821 | 822 | if output_attentions: 823 | outputs += (self_attn_weights,) 824 | 825 | if use_cache: 826 | outputs += (present_key_value,) 827 | 828 | if output_router_logits: 829 | outputs += (router_logits,) 830 | 831 | return outputs 832 | 833 | 834 | MIXTRAL_START_DOCSTRING = r""" 835 | This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the 836 | library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads 837 | etc.) 838 | 839 | This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. 840 | Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage 841 | and behavior. 842 | 843 | Parameters: 844 | config ([`MixtralConfig`]): 845 | Model configuration class with all the parameters of the model. Initializing with a config file does not 846 | load the weights associated with the model, only the configuration. Check out the 847 | [`~PreTrainedModel.from_pretrained`] method to load the model weights. 
848 | """ 849 | 850 | 851 | @add_start_docstrings( 852 | "The bare Mixtral Model outputting raw hidden-states without any specific head on top.", 853 | MIXTRAL_START_DOCSTRING, 854 | ) 855 | # Copied from transformers.models.mistral.modeling_mistral.MistralPreTrainedModel with Mistral->Mixtral 856 | class MixtralPreTrainedModel(PreTrainedModel): 857 | config_class = MixtralConfig 858 | base_model_prefix = "model" 859 | supports_gradient_checkpointing = True 860 | _no_split_modules = ["MixtralDecoderLayer"] 861 | _skip_keys_device_placement = "past_key_values" 862 | _supports_flash_attn_2 = True 863 | _supports_cache_class = True 864 | 865 | def _init_weights(self, module): 866 | std = self.config.initializer_range 867 | if isinstance(module, nn.Linear): 868 | module.weight.data.normal_(mean=0.0, std=std) 869 | if module.bias is not None: 870 | module.bias.data.zero_() 871 | elif isinstance(module, nn.Embedding): 872 | module.weight.data.normal_(mean=0.0, std=std) 873 | if module.padding_idx is not None: 874 | module.weight.data[module.padding_idx].zero_() 875 | 876 | 877 | MIXTRAL_INPUTS_DOCSTRING = r""" 878 | Args: 879 | input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): 880 | Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide 881 | it. 882 | 883 | Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and 884 | [`PreTrainedTokenizer.__call__`] for details. 885 | 886 | [What are input IDs?](../glossary#input-ids) 887 | attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): 888 | Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: 889 | 890 | - 1 for tokens that are **not masked**, 891 | - 0 for tokens that are **masked**. 892 | 893 | [What are attention masks?](../glossary#attention-mask) 894 | 895 | Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and 896 | [`PreTrainedTokenizer.__call__`] for details. 897 | 898 | If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see 899 | `past_key_values`). 900 | 901 | If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] 902 | and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more 903 | information on the default strategy. 904 | 905 | - 1 indicates the head is **not masked**, 906 | - 0 indicates the head is **masked**. 907 | position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): 908 | Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, 909 | config.n_positions - 1]`. 910 | 911 | [What are position IDs?](../glossary#position-ids) 912 | past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): 913 | Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape 914 | `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape 915 | `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. 916 | 917 | Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention 918 | blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. 
919 | 920 | If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that 921 | don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all 922 | `decoder_input_ids` of shape `(batch_size, sequence_length)`. 923 | inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): 924 | Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This 925 | is useful if you want more control over how to convert `input_ids` indices into associated vectors than the 926 | model's internal embedding lookup matrix. 927 | use_cache (`bool`, *optional*): 928 | If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see 929 | `past_key_values`). 930 | output_attentions (`bool`, *optional*): 931 | Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned 932 | tensors for more detail. 933 | output_hidden_states (`bool`, *optional*): 934 | Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for 935 | more detail. 936 | output_router_logits (`bool`, *optional*): 937 | Whether or not to return the logits of all the routers. They are useful for computing the router loss, and 938 | should not be returned during inference. 939 | return_dict (`bool`, *optional*): 940 | Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. 941 | """ 942 | 943 | 944 | @add_start_docstrings( 945 | "The bare Mixtral Model outputting raw hidden-states without any specific head on top.", 946 | MIXTRAL_START_DOCSTRING, 947 | ) 948 | # Copied from transformers.models.mistral.modeling_mistral.MistralModel with MISTRAL->MIXTRAL,Mistral->Mixtral 949 | class MixtralModel(MixtralPreTrainedModel): 950 | """ 951 | Transformer decoder consisting of *config.num_hidden_layers* layers. 
Each layer is a [`MixtralDecoderLayer`] 952 | 953 | Args: 954 | config: MixtralConfig 955 | """ 956 | 957 | def __init__(self, config: MixtralConfig): 958 | super().__init__(config) 959 | self.padding_idx = config.pad_token_id 960 | self.vocab_size = config.vocab_size 961 | 962 | self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) 963 | self.layers = nn.ModuleList( 964 | [MixtralDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] 965 | ) 966 | self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" 967 | self.norm = MixtralRMSNorm(config.hidden_size, eps=config.rms_norm_eps) 968 | 969 | self.gradient_checkpointing = False 970 | # Initialize weights and apply final processing 971 | self.post_init() 972 | 973 | def get_input_embeddings(self): 974 | return self.embed_tokens 975 | 976 | def set_input_embeddings(self, value): 977 | self.embed_tokens = value 978 | 979 | # Ignore copy 980 | @add_start_docstrings_to_model_forward(MIXTRAL_INPUTS_DOCSTRING) 981 | def forward( 982 | self, 983 | input_ids: torch.LongTensor = None, 984 | attention_mask: Optional[torch.Tensor] = None, 985 | position_ids: Optional[torch.LongTensor] = None, 986 | past_key_values: Optional[List[torch.FloatTensor]] = None, 987 | inputs_embeds: Optional[torch.FloatTensor] = None, 988 | use_cache: Optional[bool] = None, 989 | output_attentions: Optional[bool] = None, 990 | output_hidden_states: Optional[bool] = None, 991 | output_router_logits: Optional[bool] = None, 992 | return_dict: Optional[bool] = None, 993 | ) -> Union[Tuple, MoeModelOutputWithPast]: 994 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions 995 | output_router_logits = ( 996 | output_router_logits if output_router_logits is not None else self.config.output_router_logits 997 | ) 998 | output_hidden_states = ( 999 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states 1000 | ) 1001 | use_cache = use_cache if use_cache is not None else self.config.use_cache 1002 | 1003 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 1004 | 1005 | # retrieve input_ids and inputs_embeds 1006 | if input_ids is not None and inputs_embeds is not None: 1007 | raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") 1008 | elif input_ids is not None: 1009 | batch_size, seq_length = input_ids.shape 1010 | elif inputs_embeds is not None: 1011 | batch_size, seq_length, _ = inputs_embeds.shape 1012 | else: 1013 | raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") 1014 | 1015 | past_key_values_length = 0 1016 | 1017 | if use_cache: 1018 | use_legacy_cache = not isinstance(past_key_values, Cache) 1019 | if use_legacy_cache: 1020 | past_key_values = DynamicCache.from_legacy_cache(past_key_values) 1021 | past_key_values_length = past_key_values.get_usable_length(seq_length) 1022 | 1023 | if position_ids is None: 1024 | device = input_ids.device if input_ids is not None else inputs_embeds.device 1025 | position_ids = torch.arange( 1026 | past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device 1027 | ) 1028 | position_ids = position_ids.unsqueeze(0).view(-1, seq_length) 1029 | else: 1030 | position_ids = position_ids.view(-1, seq_length).long() 1031 | 1032 | if inputs_embeds is None: 1033 | inputs_embeds = self.embed_tokens(input_ids) 1034 | 
1035 | if attention_mask is not None and self._use_flash_attention_2 and use_cache: 1036 | is_padding_right = attention_mask[:, -1].sum().item() != batch_size 1037 | if is_padding_right: 1038 | raise ValueError( 1039 | "You are attempting to perform batched generation with padding_side='right'" 1040 | " this may lead to unexpected behaviour for Flash Attention version of Mixtral. Make sure to " 1041 | " call `tokenizer.padding_side = 'left'` before tokenizing the input. " 1042 | ) 1043 | 1044 | if self._use_flash_attention_2: 1045 | # 2d mask is passed through the layers 1046 | attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None 1047 | else: 1048 | # 4d mask is passed through the layers 1049 | attention_mask = _prepare_4d_causal_attention_mask( 1050 | attention_mask, 1051 | (batch_size, seq_length), 1052 | inputs_embeds, 1053 | past_key_values_length, 1054 | sliding_window=self.config.sliding_window, 1055 | ) 1056 | 1057 | hidden_states = inputs_embeds 1058 | 1059 | if self.gradient_checkpointing and self.training: 1060 | if use_cache: 1061 | logger.warning_once( 1062 | "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 1063 | ) 1064 | use_cache = False 1065 | 1066 | # decoder layers 1067 | all_hidden_states = () if output_hidden_states else None 1068 | all_self_attns = () if output_attentions else None 1069 | all_router_logits = () if output_router_logits else None 1070 | next_decoder_cache = None 1071 | 1072 | for decoder_layer in self.layers: 1073 | if output_hidden_states: 1074 | all_hidden_states += (hidden_states,) 1075 | 1076 | if self.gradient_checkpointing and self.training: 1077 | layer_outputs = self._gradient_checkpointing_func( 1078 | decoder_layer.__call__, 1079 | hidden_states, 1080 | attention_mask, 1081 | position_ids, 1082 | past_key_values, 1083 | output_attentions, 1084 | output_router_logits, 1085 | use_cache, 1086 | ) 1087 | else: 1088 | layer_outputs = decoder_layer( 1089 | hidden_states, 1090 | attention_mask=attention_mask, 1091 | position_ids=position_ids, 1092 | past_key_value=past_key_values, 1093 | output_attentions=output_attentions, 1094 | output_router_logits=output_router_logits, 1095 | use_cache=use_cache, 1096 | ) 1097 | 1098 | hidden_states = layer_outputs[0] 1099 | 1100 | if use_cache: 1101 | next_decoder_cache = layer_outputs[2 if output_attentions else 1] 1102 | 1103 | if output_attentions: 1104 | all_self_attns += (layer_outputs[1],) 1105 | 1106 | if output_router_logits: 1107 | all_router_logits += (layer_outputs[-1],) 1108 | 1109 | hidden_states = self.norm(hidden_states) 1110 | 1111 | # add hidden states from the last decoder layer 1112 | if output_hidden_states: 1113 | all_hidden_states += (hidden_states,) 1114 | 1115 | next_cache = None 1116 | if use_cache: 1117 | next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache 1118 | 1119 | if not return_dict: 1120 | return tuple( 1121 | v 1122 | for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_router_logits] 1123 | if v is not None 1124 | ) 1125 | return MoeModelOutputWithPast( 1126 | last_hidden_state=hidden_states, 1127 | past_key_values=next_cache, 1128 | hidden_states=all_hidden_states, 1129 | attentions=all_self_attns, 1130 | router_logits=all_router_logits, 1131 | ) 1132 | 1133 | 1134 | class MixtralForCausalLM(MixtralPreTrainedModel): 1135 | _tied_weights_keys = ["lm_head.weight"] 1136 | 1137 | def __init__(self, config): 1138 | 
super().__init__(config) 1139 | self.model = MixtralModel(config) 1140 | self.vocab_size = config.vocab_size 1141 | self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) 1142 | self.router_aux_loss_coef = config.router_aux_loss_coef 1143 | self.num_experts = config.num_local_experts 1144 | self.num_experts_per_tok = config.num_experts_per_tok 1145 | # Initialize weights and apply final processing 1146 | self.post_init() 1147 | 1148 | def get_input_embeddings(self): 1149 | return self.model.embed_tokens 1150 | 1151 | def set_input_embeddings(self, value): 1152 | self.model.embed_tokens = value 1153 | 1154 | def get_output_embeddings(self): 1155 | return self.lm_head 1156 | 1157 | def set_output_embeddings(self, new_embeddings): 1158 | self.lm_head = new_embeddings 1159 | 1160 | def set_decoder(self, decoder): 1161 | self.model = decoder 1162 | 1163 | def get_decoder(self): 1164 | return self.model 1165 | 1166 | @add_start_docstrings_to_model_forward(MIXTRAL_INPUTS_DOCSTRING) 1167 | @replace_return_docstrings(output_type=MoeCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) 1168 | # Ignore copy 1169 | def forward( 1170 | self, 1171 | input_ids: torch.LongTensor = None, 1172 | attention_mask: Optional[torch.Tensor] = None, 1173 | position_ids: Optional[torch.LongTensor] = None, 1174 | past_key_values: Optional[List[torch.FloatTensor]] = None, 1175 | inputs_embeds: Optional[torch.FloatTensor] = None, 1176 | labels: Optional[torch.LongTensor] = None, 1177 | use_cache: Optional[bool] = None, 1178 | output_attentions: Optional[bool] = None, 1179 | output_hidden_states: Optional[bool] = None, 1180 | output_router_logits: Optional[bool] = None, 1181 | return_dict: Optional[bool] = None, 1182 | ) -> Union[Tuple, MoeCausalLMOutputWithPast]: 1183 | r""" 1184 | Args: 1185 | labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): 1186 | Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., 1187 | config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored 1188 | (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. 1189 | 1190 | Returns: 1191 | 1192 | Example: 1193 | 1194 | ```python 1195 | >>> from transformers import AutoTokenizer, MixtralForCausalLM 1196 | 1197 | >>> model = MixtralForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS) 1198 | >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER) 1199 | 1200 | >>> prompt = "Hey, are you conscious? Can you talk to me?" 1201 | >>> inputs = tokenizer(prompt, return_tensors="pt") 1202 | 1203 | >>> # Generate 1204 | >>> generate_ids = model.generate(inputs.input_ids, max_length=30) 1205 | >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] 1206 | "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." 
1207 | ```""" 1208 | 1209 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions 1210 | output_router_logits = ( 1211 | output_router_logits if output_router_logits is not None else self.config.output_router_logits 1212 | ) 1213 | 1214 | output_hidden_states = ( 1215 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states 1216 | ) 1217 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 1218 | 1219 | # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) 1220 | outputs = self.model( 1221 | input_ids=input_ids, 1222 | attention_mask=attention_mask, 1223 | position_ids=position_ids, 1224 | past_key_values=past_key_values, 1225 | inputs_embeds=inputs_embeds, 1226 | use_cache=use_cache, 1227 | output_attentions=output_attentions, 1228 | output_hidden_states=output_hidden_states, 1229 | output_router_logits=output_router_logits, 1230 | return_dict=return_dict, 1231 | ) 1232 | 1233 | hidden_states = outputs[0] 1234 | logits = self.lm_head(hidden_states) 1235 | logits = logits.float() 1236 | 1237 | loss = None 1238 | if labels is not None: 1239 | # Shift so that tokens < n predict n 1240 | shift_logits = logits[..., :-1, :].contiguous() 1241 | shift_labels = labels[..., 1:].contiguous() 1242 | # Flatten the tokens 1243 | loss_fct = CrossEntropyLoss() 1244 | shift_logits = shift_logits.view(-1, self.config.vocab_size) 1245 | shift_labels = shift_labels.view(-1) 1246 | # Enable model parallelism 1247 | shift_labels = shift_labels.to(shift_logits.device) 1248 | loss = loss_fct(shift_logits, shift_labels) 1249 | 1250 | aux_loss = None 1251 | if output_router_logits: 1252 | aux_loss = load_balancing_loss_func( 1253 | outputs.router_logits if return_dict else outputs[-1], self.num_experts, self.num_experts_per_tok 1254 | ) 1255 | if labels is not None: 1256 | loss += self.router_aux_loss_coef * aux_loss 1257 | 1258 | if not return_dict: 1259 | output = (logits,) + outputs[1:] 1260 | if output_router_logits: 1261 | output = (aux_loss,) + output 1262 | return (loss,) + output if loss is not None else output 1263 | 1264 | return MoeCausalLMOutputWithPast( 1265 | loss=loss, 1266 | aux_loss=aux_loss, 1267 | logits=logits, 1268 | past_key_values=outputs.past_key_values, 1269 | hidden_states=outputs.hidden_states, 1270 | attentions=outputs.attentions, 1271 | router_logits=outputs.router_logits, 1272 | ) 1273 | 1274 | def prepare_inputs_for_generation( 1275 | self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs 1276 | ): 1277 | # Omit tokens covered by past_key_values 1278 | if past_key_values is not None: 1279 | if isinstance(past_key_values, Cache): 1280 | cache_length = past_key_values.get_seq_length() 1281 | past_length = past_key_values.seen_tokens 1282 | max_cache_length = past_key_values.get_max_length() 1283 | else: 1284 | cache_length = past_length = past_key_values[0][0].shape[2] 1285 | max_cache_length = None 1286 | 1287 | # Keep only the unprocessed tokens: 1288 | # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where 1289 | # some of the inputs are exclusivelly passed as part of the cache (e.g. 
when passing input_embeds as 1290 | # input) 1291 | if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]: 1292 | input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :] 1293 | # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard 1294 | # input_ids based on the past_length. 1295 | elif past_length < input_ids.shape[1]: 1296 | input_ids = input_ids[:, past_length:] 1297 | # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens. 1298 | 1299 | # If we are about to go beyond the maximum cache length, we need to crop the input attention mask. 1300 | if ( 1301 | max_cache_length is not None 1302 | and attention_mask is not None 1303 | and cache_length + input_ids.shape[1] > max_cache_length 1304 | ): 1305 | attention_mask = attention_mask[:, -max_cache_length:] 1306 | 1307 | position_ids = kwargs.get("position_ids", None) 1308 | if attention_mask is not None and position_ids is None: 1309 | # create position_ids on the fly for batch generation 1310 | position_ids = attention_mask.long().cumsum(-1) - 1 1311 | position_ids.masked_fill_(attention_mask == 0, 1) 1312 | if past_key_values: 1313 | position_ids = position_ids[:, -input_ids.shape[1] :] 1314 | 1315 | # if `inputs_embeds` are passed, we only want to use them in the 1st generation step 1316 | if inputs_embeds is not None and past_key_values is None: 1317 | model_inputs = {"inputs_embeds": inputs_embeds} 1318 | else: 1319 | model_inputs = {"input_ids": input_ids} 1320 | 1321 | model_inputs.update( 1322 | { 1323 | "position_ids": position_ids, 1324 | "past_key_values": past_key_values, 1325 | "use_cache": kwargs.get("use_cache"), 1326 | "attention_mask": attention_mask, 1327 | } 1328 | ) 1329 | return model_inputs 1330 | 1331 | @staticmethod 1332 | def _reorder_cache(past_key_values, beam_idx): 1333 | reordered_past = () 1334 | for layer_past in past_key_values: 1335 | reordered_past += ( 1336 | tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past), 1337 | ) 1338 | return reordered_past 1339 | 1340 | 1341 | @add_start_docstrings( 1342 | """ 1343 | The Mixtral Model transformer with a sequence classification head on top (linear layer). 1344 | 1345 | [`MixtralForSequenceClassification`] uses the last token in order to do the classification, as other causal models 1346 | (e.g. GPT-2) do. 1347 | 1348 | Since it does classification on the last token, it requires to know the position of the last token. If a 1349 | `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If 1350 | no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the 1351 | padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in 1352 | each row of the batch). 
1353 | """, 1354 | MIXTRAL_START_DOCSTRING, 1355 | ) 1356 | # Copied from transformers.models.llama.modeling_llama.LlamaForSequenceClassification with Llama->Mixtral, LLAMA->MIXTRAL 1357 | class MixtralForSequenceClassification(MixtralPreTrainedModel): 1358 | def __init__(self, config): 1359 | super().__init__(config) 1360 | self.num_labels = config.num_labels 1361 | self.model = MixtralModel(config) 1362 | self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False) 1363 | 1364 | # Initialize weights and apply final processing 1365 | self.post_init() 1366 | 1367 | def get_input_embeddings(self): 1368 | return self.model.embed_tokens 1369 | 1370 | def set_input_embeddings(self, value): 1371 | self.model.embed_tokens = value 1372 | 1373 | @add_start_docstrings_to_model_forward(MIXTRAL_INPUTS_DOCSTRING) 1374 | def forward( 1375 | self, 1376 | input_ids: torch.LongTensor = None, 1377 | attention_mask: Optional[torch.Tensor] = None, 1378 | position_ids: Optional[torch.LongTensor] = None, 1379 | past_key_values: Optional[List[torch.FloatTensor]] = None, 1380 | inputs_embeds: Optional[torch.FloatTensor] = None, 1381 | labels: Optional[torch.LongTensor] = None, 1382 | use_cache: Optional[bool] = None, 1383 | output_attentions: Optional[bool] = None, 1384 | output_hidden_states: Optional[bool] = None, 1385 | return_dict: Optional[bool] = None, 1386 | ) -> Union[Tuple, SequenceClassifierOutputWithPast]: 1387 | r""" 1388 | labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): 1389 | Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., 1390 | config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If 1391 | `config.num_labels > 1` a classification loss is computed (Cross-Entropy). 
1392 | """ 1393 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 1394 | 1395 | transformer_outputs = self.model( 1396 | input_ids, 1397 | attention_mask=attention_mask, 1398 | position_ids=position_ids, 1399 | past_key_values=past_key_values, 1400 | inputs_embeds=inputs_embeds, 1401 | use_cache=use_cache, 1402 | output_attentions=output_attentions, 1403 | output_hidden_states=output_hidden_states, 1404 | return_dict=return_dict, 1405 | ) 1406 | hidden_states = transformer_outputs[0] 1407 | logits = self.score(hidden_states) 1408 | 1409 | if input_ids is not None: 1410 | batch_size = input_ids.shape[0] 1411 | else: 1412 | batch_size = inputs_embeds.shape[0] 1413 | 1414 | if self.config.pad_token_id is None and batch_size != 1: 1415 | raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.") 1416 | if self.config.pad_token_id is None: 1417 | sequence_lengths = -1 1418 | else: 1419 | if input_ids is not None: 1420 | sequence_lengths = (torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1).to( 1421 | logits.device 1422 | ) 1423 | else: 1424 | sequence_lengths = -1 1425 | 1426 | pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths] 1427 | 1428 | loss = None 1429 | if labels is not None: 1430 | labels = labels.to(logits.device) 1431 | if self.config.problem_type is None: 1432 | if self.num_labels == 1: 1433 | self.config.problem_type = "regression" 1434 | elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): 1435 | self.config.problem_type = "single_label_classification" 1436 | else: 1437 | self.config.problem_type = "multi_label_classification" 1438 | 1439 | if self.config.problem_type == "regression": 1440 | loss_fct = MSELoss() 1441 | if self.num_labels == 1: 1442 | loss = loss_fct(pooled_logits.squeeze(), labels.squeeze()) 1443 | else: 1444 | loss = loss_fct(pooled_logits, labels) 1445 | elif self.config.problem_type == "single_label_classification": 1446 | loss_fct = CrossEntropyLoss() 1447 | loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1)) 1448 | elif self.config.problem_type == "multi_label_classification": 1449 | loss_fct = BCEWithLogitsLoss() 1450 | loss = loss_fct(pooled_logits, labels) 1451 | if not return_dict: 1452 | output = (pooled_logits,) + transformer_outputs[1:] 1453 | return ((loss,) + output) if loss is not None else output 1454 | 1455 | return SequenceClassifierOutputWithPast( 1456 | loss=loss, 1457 | logits=pooled_logits, 1458 | past_key_values=transformer_outputs.past_key_values, 1459 | hidden_states=transformer_outputs.hidden_states, 1460 | attentions=transformer_outputs.attentions, 1461 | ) 1462 | -------------------------------------------------------------------------------- /examples/molora.py: -------------------------------------------------------------------------------- 1 | from scattermoe.mlp import MLP 2 | from torch import nn 3 | 4 | 5 | if __name__ == "__main__": 6 | d = 1024 7 | rank = 128 8 | N = 16 9 | top_k = 2 10 | mixture_of_lora = MLP( 11 | input_size=d, 12 | hidden_size=rank, 13 | num_experts=N, 14 | top_k=top_k, 15 | activation=nn.Identity(), 16 | ) 17 | print(mixture_of_lora) 18 | # MLP( 19 | # k=2 20 | # (experts): ParallelExperts(num_experts=16, input_size=1024, output_size=128) 21 | # (output_experts): ParallelExperts(num_experts=16, input_size=128, output_size=1024) 22 | # (activation): Identity() 23 | # ) 24 | 25 | 
-------------------------------------------------------------------------------- /scattermoe/__init__.py: -------------------------------------------------------------------------------- 1 | from . import kernels 2 | from . import parallel_experts 3 | from . import mlp 4 | -------------------------------------------------------------------------------- /scattermoe/kernels/__init__.py: -------------------------------------------------------------------------------- 1 | from . import ops 2 | from . import single -------------------------------------------------------------------------------- /scattermoe/kernels/ops.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import triton 3 | import triton.language as tl 4 | from torch.nn import functional as F 5 | 6 | BLOCK_M = 128 7 | ALLOW_TF32 = False 8 | 9 | @torch.library.custom_op("scattermoe::bincount", mutates_args={}) 10 | def compileable_bincount(x: torch.Tensor, minlength: int) -> torch.Tensor: 11 | return x.bincount(minlength=minlength) 12 | 13 | @compileable_bincount.register_fake 14 | def _(x: torch.Tensor, minlength: int) -> torch.Tensor: 15 | return torch.empty(minlength, dtype=torch.long, device=x.device) 16 | 17 | @torch.compile 18 | def flatten_and_sort(expert_idxs:torch.Tensor): 19 | flattened_expert_idxs = expert_idxs.flatten() 20 | sorted_expert_idxs, sorted_scattered_idxs = torch.sort(flattened_expert_idxs) 21 | return sorted_expert_idxs, sorted_scattered_idxs 22 | 23 | @torch.compile 24 | def padded_block_indices(sorted_experts_idxs: torch.Tensor, k: int, N_BLOCK_SIZE: int=BLOCK_M) : 25 | expert_counts = compileable_bincount(sorted_experts_idxs, minlength=k) 26 | padded_block_counts = ((expert_counts - 1) // N_BLOCK_SIZE) + 1 27 | padded_expert_block_end = padded_block_counts.cumsum(-1) 28 | expert_boundaries_end = expert_counts.cumsum(-1) 29 | expert_boundaries_start = expert_boundaries_end - expert_counts 30 | padded_expert_block_start = padded_expert_block_end - padded_block_counts 31 | block_idxs = torch.arange(padded_expert_block_end[-1], 32 | dtype=sorted_experts_idxs.dtype, 33 | device=sorted_experts_idxs.device) 34 | block_mask = ( 35 | (block_idxs[:, None] < padded_expert_block_start) | 36 | (block_idxs[:, None] >= padded_expert_block_end) 37 | ) 38 | expanded_block_idxs = ( 39 | N_BLOCK_SIZE * (block_idxs[:, None] - padded_expert_block_start) + 40 | expert_boundaries_start 41 | ) 42 | expanded_block_idxs = expanded_block_idxs.masked_fill(block_mask, 0).sum(-1) 43 | return expanded_block_idxs, expert_boundaries_end 44 | 45 | 46 | 47 | def _scatter2scatter_configs(): 48 | return [ 49 | triton.Config({'BLOCK_N': 128, 'BLOCK_K': 32}, num_stages=4, num_warps=4), 50 | ] 51 | 52 | @triton.autotune(configs=_scatter2scatter_configs(), key=['M', 'N', 'K'], ) 53 | @triton.heuristics({ 54 | "NO_K_MASK": lambda args: (args['K'] % args['BLOCK_K']) == 0, 55 | "NO_N_MASK": lambda args: (args['N'] % args['BLOCK_N']) == 0, 56 | }) 57 | @triton.jit 58 | def _scatter2scatter( 59 | X_ptr, stride_xm, stride_xk, 60 | W_ptr, stride_we, stride_wk, stride_wn, 61 | Y_ptr, stride_ym, stride_yn, 62 | grouped_idx_ptr, expert_idxs_ptr, block_start_idx_ptr, 63 | FAN_OUT: tl.constexpr, 64 | M, K: tl.constexpr, N: tl.constexpr, E: tl.constexpr, 65 | BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, 66 | ACC_TYPE: tl.constexpr, 67 | OUT_M, 68 | allow_tf32: tl.constexpr, 69 | x_grouped: tl.constexpr, y_grouped: tl.constexpr, 70 | NO_K_MASK: tl.constexpr, NO_N_MASK: tl.constexpr 
71 | ): 72 | pid = tl.program_id(axis=0) 73 | 74 | N_BLOCK_COUNT = tl.cdiv(N, BLOCK_N) 75 | M_block_id = pid // N_BLOCK_COUNT 76 | N_block_id = pid % N_BLOCK_COUNT 77 | M_range = tl.arange(0, BLOCK_M) 78 | block_start_idx = tl.load(block_start_idx_ptr + M_block_id) 79 | # M_block = tl.max_contiguous((block_start_idx + M_range) % OUT_M, BLOCK_M) 80 | M_block = tl.max_contiguous(block_start_idx + M_range, BLOCK_M) 81 | E_idxs = tl.load(expert_idxs_ptr + M_block, mask=M_block < (FAN_OUT * M), other=E) 82 | E_idx = tl.min(E_idxs) 83 | E_mask = E_idxs == E_idx 84 | M_idx = tl.load(grouped_idx_ptr + M_block, mask=E_mask, other=0) 85 | if x_grouped: 86 | M_in_idx = M_block 87 | else: 88 | M_in_idx = M_idx // FAN_OUT 89 | 90 | if y_grouped: 91 | M_out_idx = M_block 92 | else: 93 | M_out_idx = M_idx 94 | 95 | K_block = tl.arange(0, BLOCK_K) 96 | 97 | N_block = N_block_id * BLOCK_N + tl.arange(0, BLOCK_N) 98 | N_mask = N_block < N 99 | # N_block = tl.max_contiguous(tl.multiple_of(N_block % N, BLOCK_N), BLOCK_N) 100 | # N_block = N_block_id * BLOCK_N + tl.arange(0, BLOCK_N) 101 | 102 | X_blk_ptrs = X_ptr + M_in_idx[:, None] * stride_xm + K_block[None, :] * stride_xk 103 | W_blk_ptrs = W_ptr + K_block[:, None] * stride_wk + N_block[None, :] * stride_wn + E_idx * stride_we 104 | 105 | acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE) 106 | iters = tl.cdiv(K, BLOCK_K) 107 | for K_block_id in range(0, iters): 108 | if NO_K_MASK: 109 | x = tl.load(X_blk_ptrs, mask=E_mask[:, None]) 110 | if NO_N_MASK or K_block_id < (iters - 1): 111 | w = tl.load(W_blk_ptrs) 112 | else: 113 | w = tl.load(W_blk_ptrs, mask=N_mask[None, :]) 114 | else: 115 | K_mask = (K_block_id * BLOCK_K + K_block) < K 116 | x = tl.load(X_blk_ptrs, mask=E_mask[:, None] & K_mask[None, :]) 117 | w = tl.load(W_blk_ptrs, mask=K_mask[:, None] & N_mask[None, :]) 118 | X_blk_ptrs += BLOCK_K * stride_xk 119 | W_blk_ptrs += BLOCK_K * stride_wk 120 | acc += tl.dot(x, w, allow_tf32=allow_tf32, out_dtype=ACC_TYPE) 121 | 122 | Y_blk_ptrs = Y_ptr + (M_out_idx[:, None] * stride_ym + N_block[None, :] * stride_yn) 123 | tl.store(Y_blk_ptrs, acc, mask=E_mask[:, None] & N_mask[None, :]) 124 | 125 | def scatter2scatter(X, W, sorted_expert_idxs, sorted_scattered_idxs, k, 126 | padded_block_idxs, x_grouped=False, y_grouped=False, 127 | out=None): 128 | assert sorted_scattered_idxs.size(0) == sorted_expert_idxs.size(0) 129 | assert sorted_scattered_idxs.size(0) == X.size(0) * k 130 | # Pre-kernel setup 131 | x_dim = X.size(-1) 132 | y_dim = W.size(-1) 133 | L_scattered = sorted_expert_idxs.size(0) 134 | if out is None: 135 | O = torch.empty((L_scattered, y_dim), device=X.device, dtype=X.dtype) 136 | else: 137 | assert out.size(0) == L_scattered and out.size(1) == y_dim 138 | O = out 139 | 140 | # with torch.cuda.device(X.device): 141 | scatter2scatter_compileable(O, W, X, k, padded_block_idxs, sorted_expert_idxs, sorted_scattered_idxs, 142 | x_grouped, y_grouped) 143 | return O 144 | 145 | 146 | @torch.library.custom_op("scattermoe::scatter2scatter", mutates_args={"O"}) 147 | def scatter2scatter_compileable( 148 | O: torch.Tensor, 149 | W: torch.Tensor, 150 | X: torch.Tensor, 151 | k: int, 152 | padded_block_idxs: torch.Tensor, 153 | sorted_expert_idxs: torch.Tensor, 154 | sorted_scattered_idxs: torch.Tensor, 155 | x_grouped: bool, y_grouped: bool) -> None: 156 | def grid(META): 157 | grid_num = ( 158 | padded_block_idxs.size(0) * 159 | triton.cdiv(META['N'], META['BLOCK_N']), 160 | ) 161 | return grid_num 162 | 163 | _scatter2scatter[grid]( 164 | # X_ptr, 
stride_xm, stride_xk, 165 | X, X.stride(0), X.stride(1), 166 | # W_ptr, stride_we, stride_wk, stride_wn, 167 | W, W.stride(0), W.stride(1), W.stride(2), 168 | # Y_ptr, stride_ym, stride_yn, 169 | O, O.stride(0), O.stride(1), 170 | grouped_idx_ptr=sorted_scattered_idxs, 171 | expert_idxs_ptr=sorted_expert_idxs, 172 | block_start_idx_ptr=padded_block_idxs, 173 | FAN_OUT=k, 174 | M=X.size(0), 175 | K=X.size(1), 176 | N=O.size(1), E=W.size(0), 177 | BLOCK_M=BLOCK_M, 178 | ACC_TYPE=tl.float32, 179 | OUT_M=O.size(0), 180 | allow_tf32=ALLOW_TF32, 181 | x_grouped=x_grouped, y_grouped=y_grouped, 182 | ) 183 | 184 | 185 | def _config_XtY(): 186 | return [ 187 | triton.Config({'BLOCK_N': 128, 'BLOCK_K': 128, 'BLOCK_M': 32}, num_stages=4, num_warps=4), 188 | ] 189 | 190 | def group_bwd_W(DY, X, expert_offsets, E): 191 | DWt = torch.zeros((E, DY.size(-1), X.size(-1)), device=DY.device, dtype=DY.dtype) 192 | DW = DWt.permute(0, 2, 1) 193 | groupXtY_compileable(E, DW, DY, X, expert_offsets) 194 | return DW 195 | 196 | 197 | @torch.library.custom_op("scattermoe::groupXtY", mutates_args={"DW"}) 198 | def groupXtY_compileable( 199 | E: int, 200 | DW: torch.Tensor, 201 | DY: torch.Tensor, 202 | X: torch.Tensor, 203 | expert_offsets: torch.Tensor) -> None: 204 | def grid(META): 205 | grid = ( 206 | E * triton.cdiv(META['K'], META['BLOCK_K']), 207 | triton.cdiv(META['N'], META['BLOCK_N']), 208 | ) 209 | return grid 210 | 211 | _groupXtY[grid]( 212 | # DY_ptr, stride_dym, stride_dyk, 213 | DY, DY.stride(0), DY.stride(1), 214 | # X_ptr, stride_xm, stride_xn, 215 | X, X.stride(0), X.stride(1), 216 | # DW_ptr, stride_dwe, stride_dwk, stride_dwn, 217 | DW, DW.stride(0), DW.stride(1), DW.stride(2), 218 | # expert_offsets_ptr, 219 | expert_offsets, 220 | # K: tl.constexpr, N: tl.constexpr, 221 | M=DY.size(0), N=DY.size(-1), K=X.size(-1), 222 | # ACC_TYPE: tl.constexpr, 223 | ACC_TYPE=tl.float32, 224 | allow_tf32=True 225 | ) 226 | 227 | 228 | @triton.autotune(configs=_config_XtY(), key=['M', 'N', 'K'], ) 229 | @triton.heuristics({ 230 | "NO_K_MASK": lambda args: (args['K'] % args['BLOCK_K']) == 0, 231 | "NO_N_MASK": lambda args: (args['N'] % args['BLOCK_N']) == 0, 232 | }) 233 | @triton.jit 234 | def _groupXtY( 235 | DY_ptr, stride_dym, stride_dyk, 236 | X_ptr, stride_xm, stride_xn, 237 | DW_ptr, stride_dwe, stride_dwk, stride_dwn, 238 | expert_offsets_ptr, 239 | M, K: tl.constexpr, N: tl.constexpr, 240 | BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, 241 | ACC_TYPE: tl.constexpr, 242 | allow_tf32: tl.constexpr, 243 | NO_K_MASK: tl.constexpr, NO_N_MASK: tl.constexpr 244 | ): 245 | pid0 = tl.program_id(axis=0) 246 | pid1 = tl.program_id(axis=1) 247 | num0 = tl.num_programs(0) 248 | num1 = tl.num_programs(1) 249 | # pid1, pid0 = tl.swizzle2d(pid1, pid0, num1, num0, 128) 250 | pid0, pid1 = tl.swizzle2d(pid0, pid1, num0, num1, 4) 251 | 252 | K_BLOCK_COUNT = tl.cdiv(K, BLOCK_K) 253 | E_idx = pid0 // K_BLOCK_COUNT 254 | K_block_id = pid0 % K_BLOCK_COUNT 255 | N_block_id = pid1 256 | 257 | if E_idx == 0: 258 | start_idx = 0 259 | else: 260 | start_idx = tl.load(expert_offsets_ptr + E_idx - 1).to(tl.int32) 261 | end_idx = tl.load(expert_offsets_ptr + E_idx).to(tl.int32) 262 | 263 | if end_idx > start_idx: 264 | M_block = tl.max_contiguous(start_idx + tl.arange(0, BLOCK_M), BLOCK_M) 265 | 266 | K_block = K_block_id * BLOCK_K + tl.arange(0, BLOCK_K) 267 | K_mask = K_block < K 268 | K_block = tl.max_contiguous(tl.multiple_of(K_block % K, BLOCK_K), BLOCK_K) 269 | 270 | N_block = N_block_id * BLOCK_N + 
tl.arange(0, BLOCK_N) 271 | N_mask = N_block < N 272 | N_block = tl.max_contiguous(tl.multiple_of(N_block % N, BLOCK_N), BLOCK_N) 273 | 274 | M_idxs = M_block 275 | xt_blk_ptrs = X_ptr + K_block[:, None] * stride_xn + M_idxs[None, :] * stride_xm 276 | dy_blk_ptrs = DY_ptr + M_idxs[:, None] * stride_dym + N_block[None, :] * stride_dyk 277 | 278 | acc = tl.zeros((BLOCK_K, BLOCK_N), dtype=ACC_TYPE) 279 | iters = tl.cdiv(end_idx - start_idx, BLOCK_M) 280 | for i in range(0, iters): 281 | M_mask = (i * BLOCK_M + M_block) < end_idx 282 | if NO_K_MASK: 283 | xt = tl.load(xt_blk_ptrs, mask=M_mask[None, :]) 284 | else: 285 | xt = tl.load(xt_blk_ptrs, mask=K_mask[:, None] & M_mask[None, :]) 286 | if NO_N_MASK: 287 | dy = tl.load(dy_blk_ptrs, mask=M_mask[:, None]) 288 | else: 289 | dy = tl.load(dy_blk_ptrs, mask=M_mask[:, None] & N_mask[None, :]) 290 | # acc += tl.dot(xt, dy, out_dtype=ACC_TYPE, allow_tf32=allow_tf32) 291 | xt_blk_ptrs += BLOCK_M * stride_xm 292 | dy_blk_ptrs += BLOCK_M * stride_dym 293 | acc += tl.dot(xt, dy, out_dtype=ACC_TYPE, allow_tf32=allow_tf32) 294 | 295 | 296 | 297 | DW_blk_ptrs = DW_ptr + E_idx * stride_dwe + K_block[:, None] * stride_dwk + N_block[None, :] * stride_dwn 298 | acc = acc.to(DW_blk_ptrs.dtype.element_ty) 299 | tl.store(DW_blk_ptrs, acc, mask=K_mask[:, None] & N_mask[None, :]) 300 | 301 | 302 | def _config_grouping(): 303 | return [ 304 | triton.Config({'BLOCK_N': 256, 'BLOCK_K': 128}, num_stages=4, num_warps=4), 305 | # triton.Config({'BLOCK_N': 128, 'BLOCK_K': 64}, num_stages=4, num_warps=4), 306 | # triton.Config({'BLOCK_N': 64, 'BLOCK_K': 32}, num_stages=4, num_warps=4), 307 | ] 308 | 309 | def group(A, sorted_expert_idxs, coeff=None, fan_out=1, out=None): 310 | N = sorted_expert_idxs.size(0) 311 | K = A.size(1) 312 | assert A.size(0) * fan_out == N 313 | if out is not None: 314 | Y = out 315 | else: 316 | Y = torch.empty((N, K), dtype=A.dtype, device=A.device) 317 | group_compileable(A, K, N, Y, coeff, coeff is not None, fan_out, sorted_expert_idxs) 318 | return Y 319 | 320 | 321 | @torch.library.custom_op("scattermoe::group", mutates_args={"Y"}) 322 | def group_compileable( 323 | A: torch.Tensor, 324 | K: int, 325 | N: int, 326 | Y: torch.Tensor, 327 | coeff: torch.Tensor, has_coeff: bool, 328 | fan_out: int, 329 | sorted_expert_idxs: torch.Tensor) -> None: 330 | def grid(META): 331 | grid_num = (triton.cdiv(META['N'], META['BLOCK_N']),) 332 | return grid_num 333 | _group[grid]( 334 | # A_ptr, stride_an, stride_ai, 335 | A, A.stride(0), A.stride(1), has_coeff, coeff, fan_out, 336 | # Y_ptr, stride_yn, stride_yk, 337 | Y, Y.stride(0), Y.stride(1), 338 | # grouped_idx_ptr, 339 | sorted_expert_idxs, 340 | # N: tl.constexpr, K: tl.constexpr, 341 | N, K 342 | ) 343 | 344 | 345 | @triton.autotune(configs=_config_grouping(), key=['K']) 346 | @triton.heuristics({ 347 | "NO_K_MASK": lambda args: (args['K'] % args['BLOCK_K']) == 0 348 | }) 349 | @triton.jit 350 | def _group( 351 | src_ptr, stride_sn, stride_sk, has_coeff: tl.constexpr, coeff_ptr, FAN_OUT: tl.constexpr, 352 | tgt_ptr, stride_tn, stride_ti, 353 | grouped_idx_ptr, 354 | N, K: tl.constexpr, 355 | BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, 356 | NO_K_MASK: tl.constexpr 357 | ): 358 | pid = tl.program_id(axis=0) 359 | 360 | N_block_id = pid 361 | N_blk = N_block_id * BLOCK_N + tl.arange(0, BLOCK_N) 362 | N_mask = N_blk < N 363 | N_blk = tl.max_contiguous(tl.multiple_of(N_blk % N, BLOCK_N), BLOCK_N) 364 | N_idx = tl.load(grouped_idx_ptr + N_blk, mask=N_mask, other=0) 365 | 366 | K_blk = tl.arange(0, 
BLOCK_K) 367 | src_blk_ptrs = src_ptr + (N_idx // FAN_OUT)[:, None] * stride_sn + K_blk[None, :] * stride_sk 368 | tgt_blk_ptrs = tgt_ptr + N_blk[:, None] * stride_tn + K_blk[None, :] * stride_ti 369 | 370 | if has_coeff: 371 | c = tl.load(coeff_ptr + N_idx, mask=N_mask)[:, None] 372 | 373 | iters = tl.cdiv(K, BLOCK_K) 374 | for i in range(0, iters): 375 | if NO_K_MASK or i < iters - 1: 376 | block = tl.load(src_blk_ptrs, mask=N_mask[:, None]) 377 | if has_coeff: 378 | block *= c 379 | tl.store(tgt_blk_ptrs, block, mask=N_mask[:, None]) 380 | 381 | else: 382 | K_mask = (i * BLOCK_K + K_blk) < K 383 | mask = N_mask[:, None] & K_mask[None, :] 384 | block = tl.load(src_blk_ptrs, mask=mask) 385 | if has_coeff: 386 | block *= c 387 | tl.store(tgt_blk_ptrs, block, mask=mask) 388 | src_blk_ptrs += BLOCK_K * stride_sk 389 | tgt_blk_ptrs += BLOCK_K * stride_ti 390 | -------------------------------------------------------------------------------- /scattermoe/kernels/single.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import triton 3 | import triton.language as tl 4 | from torch.nn import functional as F 5 | 6 | @triton.jit 7 | def _single2scatter( 8 | X_ptr, stride_xm, stride_xk, 9 | W_ptr, stride_we, stride_wk, stride_wn, 10 | Y_ptr, stride_ym, stride_yn, 11 | expert_idxs_ptr, 12 | FAN_OUT: tl.constexpr, 13 | K: tl.constexpr, N: tl.constexpr, E: tl.constexpr, 14 | BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, 15 | ACC_TYPE: tl.constexpr, 16 | ): 17 | pid0 = tl.program_id(axis=0) 18 | pid1 = tl.program_id(axis=1) 19 | 20 | N_block_id = pid0 21 | if FAN_OUT == 1: 22 | in_idx = pid1 23 | else: 24 | in_idx = 0 25 | out_idx = pid1 26 | 27 | K_block = tl.arange(0, BLOCK_K) 28 | N_block = tl.max_contiguous(tl.multiple_of((N_block_id * BLOCK_N + tl.arange(0, BLOCK_N)) % N, BLOCK_N), BLOCK_N) 29 | E_idx = tl.load(expert_idxs_ptr + pid1) 30 | X_blk_ptrs = X_ptr + in_idx * stride_xm + K_block[:, None] * stride_xk 31 | W_blk_ptrs = W_ptr + E_idx * stride_we + K_block[:, None] * stride_wk + N_block[None, :] * stride_wn 32 | acc = tl.zeros((1, BLOCK_N), dtype=ACC_TYPE) 33 | for K_block_id in range(0, tl.cdiv(K, BLOCK_K)): 34 | x = tl.load(X_blk_ptrs) 35 | w = tl.load(W_blk_ptrs) 36 | acc += tl.sum(x * w, axis=0)[None, :] 37 | X_blk_ptrs += BLOCK_K * stride_xk 38 | W_blk_ptrs += BLOCK_K * stride_wk 39 | Y_blk_ptrs = Y_ptr + out_idx * stride_ym + N_block[None, :] * stride_yn 40 | tl.store(Y_blk_ptrs, acc) 41 | 42 | def single2scatter(X, W, expert_idxs): 43 | E, xdim, ydim = W.size() 44 | k = expert_idxs.size(1) 45 | assert X.size(0) == k or X.size(0) == 1 46 | Y = torch.empty((k, ydim), device=X.device, dtype=X.dtype) 47 | BLOCK_N = 128 48 | BLOCK_K = 128 49 | grid = ydim // BLOCK_N, k 50 | _single2scatter[grid]( 51 | X, X.stride(0), X.stride(1), 52 | W, W.stride(0), W.stride(1), W.stride(2), 53 | Y, Y.stride(0), Y.stride(1), 54 | expert_idxs, 55 | FAN_OUT=Y.size(0) // X.size(0), 56 | K=xdim, N=ydim, E=E, 57 | BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, 58 | ACC_TYPE=tl.float32 59 | ) 60 | return Y 61 | -------------------------------------------------------------------------------- /scattermoe/mlp.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import functional as F 4 | 5 | from . 
import kernels 6 | from .parallel_experts import ParallelExperts 7 | 8 | class GLUMLP(nn.Module): 9 | def __init__( 10 | self, 11 | input_size, 12 | hidden_size, 13 | num_experts, 14 | top_k, 15 | activation=nn.SiLU(), 16 | ): 17 | super(GLUMLP, self).__init__() 18 | 19 | self.num_experts = num_experts 20 | self.input_size = input_size 21 | self.hidden_size = hidden_size 22 | self.experts = ParallelExperts(num_experts, input_size, 2 * hidden_size) 23 | self.output_experts = ParallelExperts(num_experts, hidden_size, input_size) 24 | self.top_k = min(top_k, self.num_experts) 25 | self.activation = activation 26 | 27 | def extra_repr(self): 28 | return 'k={}'.format(self.top_k) 29 | 30 | def forward(self, x: torch.Tensor, expert_p: torch.Tensor, expert_idxs: torch.Tensor): 31 | x_shape = x.size() 32 | x = x.view(-1, x_shape[-1]) 33 | with torch.no_grad(): 34 | sorted_expert_idxs, sorted_scattered_idxs = kernels.ops.flatten_and_sort(expert_idxs) 35 | padded_block_idxs, expert_offsets = kernels.ops.padded_block_indices(sorted_expert_idxs, self.num_experts) 36 | 37 | h, gates = self.experts( 38 | x, self.top_k, 39 | sorted_expert_idxs, sorted_scattered_idxs, 40 | padded_block_idxs, expert_offsets, 41 | grouped_out=True 42 | ).chunk(2, dim=-1) 43 | h = self.activation(gates) * h 44 | y = self.output_experts( 45 | h, 1, sorted_expert_idxs, sorted_scattered_idxs, 46 | padded_block_idxs, expert_offsets, 47 | grouped_in=True, 48 | gates=expert_p, 49 | ) 50 | y = y.view(*x_shape[:-1], y.size(-1)) 51 | return y 52 | 53 | 54 | class MLP(nn.Module): 55 | def __init__( 56 | self, 57 | input_size, 58 | hidden_size, 59 | num_experts, 60 | top_k, 61 | activation=None, 62 | ): 63 | super(MLP, self).__init__() 64 | 65 | self.num_experts = num_experts 66 | self.input_size = input_size 67 | self.hidden_size = hidden_size 68 | self.experts = ParallelExperts(num_experts, input_size, hidden_size) 69 | self.output_experts = ParallelExperts(num_experts, hidden_size, input_size) 70 | self.top_k = min(top_k, self.num_experts) 71 | self.activation = activation 72 | 73 | def extra_repr(self): 74 | return 'k={}'.format(self.top_k) 75 | 76 | def forward(self, x: torch.Tensor, expert_p: torch.Tensor, expert_idxs: torch.Tensor): 77 | x_shape = x.size() 78 | x = x.view(-1, x_shape[-1]) 79 | with torch.no_grad(): 80 | sorted_expert_idxs, sorted_scattered_idxs = kernels.ops.flatten_and_sort(expert_idxs) 81 | padded_block_idxs, expert_offsets = kernels.ops.padded_block_indices(sorted_expert_idxs, self.num_experts) 82 | 83 | h = self.experts( 84 | x, self.top_k, 85 | sorted_expert_idxs, sorted_scattered_idxs, 86 | padded_block_idxs, expert_offsets, 87 | grouped_out=True 88 | ) 89 | h = self.activation(h) 90 | y = self.output_experts( 91 | h, 1, sorted_expert_idxs, sorted_scattered_idxs, 92 | padded_block_idxs, expert_offsets, 93 | grouped_in=True, 94 | gates=expert_p, 95 | ) 96 | y = y.view(*x_shape[:-1], y.size(-1)) 97 | return y 98 | -------------------------------------------------------------------------------- /scattermoe/parallel_experts.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from . 
import kernels 4 | 5 | class ParallelLinear(torch.autograd.Function): 6 | @staticmethod 7 | def forward( 8 | ctx, x, expert_weights, k, 9 | sorted_expert_idxs, sorted_scattered_idxs, 10 | padded_block_idxs, expert_offsets, 11 | gates=None, grouped_in=False, grouped_out=False, 12 | ): 13 | with torch.device(x.device): 14 | output = kernels.ops.scatter2scatter( 15 | X=x, W=expert_weights, 16 | sorted_expert_idxs=sorted_expert_idxs, 17 | sorted_scattered_idxs=sorted_scattered_idxs, 18 | padded_block_idxs=padded_block_idxs, 19 | k=k, x_grouped=grouped_in, y_grouped=grouped_out 20 | ) 21 | if gates is not None: 22 | output_expanded = output.view(gates.size(0), gates.size(1), output.size(-1)) 23 | output = torch.bmm( 24 | gates[:, None, :], 25 | output_expanded 26 | ).squeeze(1) 27 | else: 28 | output_expanded = None 29 | 30 | ctx.save_for_backward( 31 | x, expert_weights, 32 | sorted_expert_idxs, 33 | sorted_scattered_idxs, 34 | padded_block_idxs, expert_offsets, 35 | gates, 36 | output_expanded 37 | ) 38 | ctx.grouped_in = grouped_in 39 | ctx.grouped_out = grouped_out 40 | ctx.k = k 41 | return output 42 | @staticmethod 43 | def backward(ctx, grad_out): 44 | (x, expert_weights, 45 | sorted_expert_idxs, 46 | sorted_scattered_idxs, 47 | padded_block_idxs, expert_offsets, 48 | gates, output_expanded) = ctx.saved_tensors 49 | k = ctx.k 50 | grouped_in = ctx.grouped_in 51 | grouped_out = ctx.grouped_out 52 | # print("backward") 53 | with torch.device(grad_out.device): 54 | if gates is not None: 55 | # calculate gates gradient 56 | d_gates = torch.bmm(output_expanded, grad_out[:, :, None]).squeeze(-1) 57 | gates_flat = gates.flatten() 58 | gate_fan = gates.size(1) 59 | # print("expanded and grouping") 60 | grouped_grad_out = output_expanded.flatten(0, 1) # reuse expanded buffer later 61 | else: 62 | d_gates = None 63 | gates_flat = None 64 | gate_fan = 1 65 | grouped_grad_out = None 66 | 67 | if grouped_out: 68 | grouped_grad_out = grad_out 69 | else: 70 | grouped_grad_out = kernels.ops.group(grad_out, sorted_scattered_idxs, 71 | fan_out=gate_fan, coeff=gates_flat, 72 | out=grouped_grad_out) 73 | if grouped_in: 74 | grouped_x = x 75 | d_expanded_input = None 76 | else: 77 | grouped_x = kernels.ops.group(x, sorted_scattered_idxs, fan_out=k) 78 | d_expanded_input = grouped_x 79 | d_weights = kernels.ops.group_bwd_W( 80 | DY=grouped_grad_out, X=grouped_x, 81 | expert_offsets=expert_offsets, 82 | E=expert_weights.size(0) 83 | ) 84 | d_expanded_input = kernels.ops.scatter2scatter( 85 | X=grouped_grad_out, x_grouped=True, 86 | W=expert_weights.permute(0, 2, 1), 87 | padded_block_idxs=padded_block_idxs, 88 | sorted_expert_idxs=sorted_expert_idxs, 89 | sorted_scattered_idxs=sorted_scattered_idxs, 90 | k=1, 91 | y_grouped=grouped_in, 92 | out=d_expanded_input # Reuse grouped_x buffer 93 | ) 94 | 95 | if k == 1: 96 | d_input = d_expanded_input 97 | else: 98 | d_input = d_expanded_input.view(x.size(0), k, d_expanded_input.size(-1)).sum(-2) 99 | # print("backward end.") 100 | return ( 101 | # x, expert_weights, k, 102 | d_input, d_weights, None, 103 | # sorted_expert_idxs, sorted_scattered_idxs, 104 | None, None, 105 | # padded_block_idxs, expert_offsets, 106 | None, None, 107 | # gates 108 | d_gates, None, None 109 | ) 110 | 111 | def parallel_linear(inputs, expert_weights, k, 112 | sorted_expert_idxs, sorted_scattered_idxs, 113 | padded_block_idxs, expert_offsets, 114 | gates=None, grouped_in=False, grouped_out=False): 115 | results = ParallelLinear.apply(inputs, expert_weights, k, 116 | sorted_expert_idxs, 
sorted_scattered_idxs, 117 | padded_block_idxs, expert_offsets, gates, 118 | grouped_in, grouped_out) 119 | return results 120 | 121 | class ParallelExperts(nn.Module): 122 | def __init__(self, num_experts, input_size, output_size) -> None: 123 | super().__init__() 124 | self.weight = nn.Parameter(torch.empty(num_experts, output_size, input_size)) 125 | self.reset_parameters() 126 | self.num_experts = num_experts 127 | self.input_size = input_size 128 | self.output_size = output_size 129 | 130 | def extra_repr(self): 131 | return 'num_experts={}, input_size={}, output_size={}'.format( 132 | self.num_experts, self.input_size, self.output_size) 133 | 134 | def reset_parameters(self) -> None: 135 | nn.init.normal_(self.weight, std=0.02) 136 | 137 | def forward(self, inputs, k, sorted_expert_idxs, sorted_scattered_idxs, 138 | padded_block_idxs, expert_offsets, 139 | gates=None, grouped_in=False, grouped_out=False): 140 | 141 | results = parallel_linear( 142 | inputs, self.weight.permute(0, 2, 1), k, 143 | sorted_expert_idxs, sorted_scattered_idxs, 144 | padded_block_idxs, expert_offsets, 145 | gates=gates, grouped_in=grouped_in, grouped_out=grouped_out 146 | ) 147 | return results 148 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import os 2 | from setuptools import setup, find_packages 3 | 4 | def read(fname): 5 | return open(os.path.join(os.path.dirname(__file__), fname)).read() 6 | 7 | setup( 8 | name = "scattermoe", 9 | version = "0.2.0", 10 | author = "Shawn Tan", 11 | author_email = "shawn@wtf.sg", 12 | description = "Triton-based implementation of Sparse Mixture of Experts.", 13 | license = "Apache License", 14 | keywords = "triton pytorch llm", 15 | url = "https://github.com/shawntan/scattermoe", 16 | packages=find_packages(), 17 | long_description=read('README.md'), 18 | python_requires='>=3.10.10', 19 | install_requires=['torch', 'triton'], 20 | tests_require=['pytest'], 21 | classifiers=[ 22 | "Development Status :: 1 - Planning", 23 | "License :: OSI Approved :: Apache Software License", 24 | ], 25 | ) 26 | 27 | -------------------------------------------------------------------------------- /tests/test_mlp.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | from torch import nn 4 | from torch.nn import functional as F 5 | from scattermoe.mlp import MLP 6 | 7 | 8 | def dumb_forward(m, x, expert_p, expert_idxs): 9 | output = torch.stack([ 10 | sum( 11 | expert_p[i, j] * F.linear( 12 | m.activation(F.linear(x[i], m.experts.weight[expert_idxs[i, j]])), 13 | m.output_experts.weight[expert_idxs[i, j]] 14 | ) 15 | for j in range(expert_idxs.size(1)) 16 | ) for i in range(expert_idxs.size(0)) 17 | ], dim=0) 18 | return output 19 | 20 | class TestClass: 21 | @pytest.mark.parametrize('length, x_dim, h_dim, E, k, dtype', [ 22 | (L, xd, (4 * xd) // k, 8, k, dt) 23 | for L in [1, 256, 512] 24 | for dt in [torch.float32] 25 | for xd in [128, 256, 512, 600, 100] 26 | for k in [2, 3, 4] 27 | ]) 28 | def test_mlp_correctness(self, length, x_dim, h_dim, E, k, dtype): 29 | logits = torch.randn(length, E, dtype=dtype) 30 | weights = torch.softmax(logits.float(), axis=-1).cuda().to(dtype) 31 | X = torch.randn(length, x_dim, dtype=dtype, requires_grad=True).cuda() 32 | DY = torch.randn(length, x_dim, dtype=dtype).cuda() 33 | k_weights, k_idxs = torch.topk(weights, k) 34 | k_weights.requires_grad_() 35 
| 36 | mlp = MLP( 37 | input_size=x_dim, hidden_size=h_dim, 38 | activation=nn.GELU(), 39 | num_experts=E, top_k=k 40 | ).cuda().to(dtype) 41 | 42 | 43 | Y = mlp(X, k_weights, k_idxs) 44 | dX, dg, dW1, dW2 = torch.autograd.grad( 45 | outputs=(Y,), 46 | inputs=(X, k_weights, mlp.experts.weight, mlp.output_experts.weight), 47 | grad_outputs=(DY,) 48 | ) 49 | Y_ = dumb_forward(mlp, X, k_weights, k_idxs) 50 | dX_, dg_, dW1_, dW2_ = torch.autograd.grad( 51 | outputs=(Y_,), 52 | inputs=(X, k_weights, mlp.experts.weight, mlp.output_experts.weight), 53 | grad_outputs=(DY,) 54 | ) 55 | err_Y = torch.abs(Y_ - Y) 56 | err_dX = torch.abs(dX_ - dX) 57 | err_dg = torch.abs(dg_ - dg) 58 | err_dW1 = torch.abs(dW1_ - dW1) 59 | err_dW2 = torch.abs(dW2_ - dW2) 60 | tolerance = 1e-2 61 | assert err_Y.max() < tolerance, "Y error too large: max %0.05f" % err_Y.max() 62 | assert err_dX.max() < tolerance, "dX error too large: max %0.05f" % err_dX.max() 63 | assert err_dg.max() < tolerance, "dg error too large: max %0.05f" % err_dg.max() 64 | assert err_dW1.max() < tolerance, "dW1 error too large: max %0.05f" % err_dW1.max() 65 | assert err_dW2.max() < tolerance, "dW2 error too large: max %0.05f" % err_dW2.max() 66 | --------------------------------------------------------------------------------
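For reference, a minimal usage sketch of the MLP module defined in scattermoe/mlp.py, mirroring the call pattern exercised in tests/test_mlp.py. The nn.Linear router, the dimension values, and the tensor names below are illustrative assumptions only; scattermoe does not ship a router and simply expects the caller to supply per-token top-k gate weights and expert indices. A CUDA device is assumed, since the kernels are Triton-based.

import torch
from torch import nn
from scattermoe.mlp import MLP

# Illustrative sizes: model dim, expert hidden dim, number of experts,
# experts per token, and number of tokens (all hypothetical).
x_dim, h_dim, E, k, T = 512, 1024, 8, 2, 256

mlp = MLP(input_size=x_dim, hidden_size=h_dim,
          num_experts=E, top_k=k, activation=nn.GELU()).cuda()
router = nn.Linear(x_dim, E).cuda()   # stand-in gate, not part of scattermoe

X = torch.randn(T, x_dim, device="cuda")
probs = torch.softmax(router(X), dim=-1)     # (T, E) routing probabilities
k_weights, k_idxs = torch.topk(probs, k)     # (T, k) gate weights and expert ids
Y = mlp(X, k_weights, k_idxs)                # (T, x_dim), same shape as X

GLUMLP accepts the same (x, expert_p, expert_idxs) call signature; it differs only in that its first expert projection produces 2 * hidden_size features, which are chunked into a gate and a value so that activation(gate) * value feeds the output experts.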