├── LICENSE ├── README.md ├── assets ├── last_layer.png ├── pearson.png └── rank_count.png ├── deepseekmoe ├── config.json ├── configuration_deepseek.py └── modeling_deepseek.py ├── dynamic_analysis.ipynb ├── env.txt ├── grok ├── config.json ├── configuration_grok1.py ├── modeling_grok1.py └── modeling_grok1_outputs.py ├── mistral ├── config.json ├── configuration_mistral.py └── modeling_mistral.py ├── mixtral_base ├── config.json ├── configuration_moe_mistral.py └── modeling_moe_mistral.py ├── mixtral_base22 └── config.json ├── mixtral_instruct ├── config.json ├── configuration_moe_mistral.py └── modeling_mixtral_instruct.py ├── static_analysis.ipynb └── wikitext103_test.csv /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Lo Ka Man 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # A Closer Look into MoEs in LLMs 2 | 3 | This repository contains the code of [A Closer Look into Mixture-of-Experts in Large Language Models](https://arxiv.org/abs/2406.18219). 4 | 5 | ## Overview :eyes: 6 | 7 | We make an initial attempt to understand the inner workings of MoE-based large language models. 8 | Concretely, we comprehensively study the parametric and behavioral features of four recent MoE-based models ([Mixtral 8x7B](https://arxiv.org/pdf/2401.04088), Mixtral 8x22B, [DeepSeekMoE](https://arxiv.org/pdf/2401.06066), [Grok-1](https://github.com/xai-org/grok-1)) and reveal some intriguing observations, including: 9 | 10 | - **Neurons act like fine-grained experts.** \ 11 | Intuitively, the gate embedding determines the expert selection, while the gate projection matrix of each expert is responsible for choosing the neurons to activate. Interestingly, their similarity values show an association, as described in the table below (X and Y denote the similarity values of the gate embedding and the three projection matrices, respectively). 12 | Therefore, they may learn similar knowledge in order to perform their respective selection operations reasonably; in other words, the expert neurons act like fine-grained experts.
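Below is a minimal sketch (not the exact procedure used in the paper) of how such an association could be measured for a single MoE layer. The module path `model.model.layers[i].block_sparse_moe` and the attribute names `gate` / `experts[k].w1` follow the HuggingFace Mixtral layout and are assumptions here:

```python
# Hypothetical sketch: correlate pairwise similarities of the router's gate
# embeddings with pairwise similarities of the experts' gate projection (w1) matrices.
import itertools
import torch
import torch.nn.functional as F

def pairwise_cosine(tensors):
    """Cosine similarity between every pair of (flattened) parameter tensors."""
    flat = [t.detach().flatten().float() for t in tensors]
    return torch.stack([
        F.cosine_similarity(a, b, dim=0)
        for a, b in itertools.combinations(flat, 2)
    ])

def squared_pearson(x, y):
    """Squared Pearson coefficient between two 1-D tensors."""
    return torch.corrcoef(torch.stack([x, y]))[0, 1].item() ** 2

# moe = model.model.layers[layer_idx].block_sparse_moe    # assumed Mixtral-style MoE layer
# x = pairwise_cosine(list(moe.gate.weight))               # one gate embedding per expert
# y = pairwise_cosine([e.w1.weight for e in moe.experts])  # experts' gate projection matrices
# print(squared_pearson(x, y))
```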
13 | 14 | Squared Pearson coefficient 15 | 16 | - **The router of MoE usually selects experts with larger output norms.** \ 17 | Using Mixtral as an example here, we find that the expert that outputs the feature vector with the *i*-th largest norm is most likely to be assigned the *i*-th highest score by the gate. 18 | (In the figure, a larger ranking index means a larger norm/score, so rank 8 is the largest.) 19 | 20 | Norm-score rank counting 21 | 22 | - **The expert diversity increases as the layer increases, while the last layer is an outlier.** \ 23 | In several experiments, we observe that the similarities between experts are generally lower in deep layers, whereas the similarities increase in the last layer(s). 24 | For instance, the figure below shows the similarity heat maps of the Mixtral experts' outputs at different layers, where the similarity values compare as: Layer 31 > Layer 6 > Layer 27. 25 | 26 | Deep and last layers 27 | 28 | Based on these observations, we also provide suggestions for a broad spectrum of MoE practitioners, such as router design and expert allocation. 29 | **Check out our paper for more inspiring observations and suggestions!** 30 | 31 | ## Setup :wrench: 32 | 33 | 1. Download the model checkpoints \ 34 | By default, our code loads the pre-downloaded models from the `ckpt` directory. 35 | You can also modify it to download directly from HuggingFace. The download links of the models we used are listed below (a minimal loading sketch is also provided at the end of this README): 36 | - [Mixtral 8x7B Base](https://huggingface.co/mistralai/Mixtral-8x7B-v0.1) 37 | - [Mixtral 8x22B Base](https://huggingface.co/mistralai/Mixtral-8x22B-v0.1) 38 | - [Mixtral 8x7B Instruct](https://huggingface.co/mistralai/Mixtral-8x7B-Instruct-v0.1) 39 | - [Mistral](https://huggingface.co/mistralai/Mistral-7B-v0.1) 40 | - [DeepSeekMoE](https://huggingface.co/deepseek-ai/deepseek-moe-16b-base) 41 | - [Grok-1](https://huggingface.co/hpcai-tech/grok-1) 42 | 43 | 2. Create the conda environment 44 | ```bash 45 | git clone https://github.com/kamanphoebe/Look-into-MoEs.git 46 | cd Look-into-MoEs 47 | conda create -n analyze --file env.txt 48 | ``` 49 | After creating the conda environment, you have to select it as the Jupyter kernel. 50 | 51 | ## Usage :memo: 52 | 53 | The two Jupyter notebooks `static_analysis.ipynb` and `dynamic_analysis.ipynb` contain the code for the experiments on the static parameters and dynamic behaviours, respectively. 54 | You can simply run the corresponding code blocks for each experiment, each of which is titled the same as in the paper. 55 | Note that some experiments employ part of the [Wikitext 103 test set](https://huggingface.co/datasets/Salesforce/wikitext), which we have already provided as `wikitext103_test.csv`. 56 | 57 | ## Citation :star2: 58 | 59 | Please cite our work if you find it useful! 60 | ```bibtex 61 | @article{lo2024closer, 62 | title={A Closer Look into Mixture-of-Experts in Large Language Models}, 63 | author={Lo, Ka Man and Huang, Zeyu and Qiu, Zihan and Wang, Zili and Fu, Jie}, 64 | journal={arXiv preprint arXiv:2406.18219}, 65 | year={2024} 66 | } 67 | ``` 68 | 69 | ## Acknowledgement :tada: 70 | 71 | Our configuration and modeling files of the models are modified from the corresponding HuggingFace repositories listed in the [Setup](https://github.com/kamanphoebe/Look-into-MoEs/tree/main?tab=readme-ov-file#setup-wrench) section. 72 | Thanks to the authors for their great work!
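As referenced in the Setup section, here is a minimal, hypothetical sketch of loading one of the pre-downloaded checkpoints; the subdirectory name under `ckpt`, the dtype, and the device settings are assumptions and may need adjusting for your setup:

```python
# Hypothetical loading example for a checkpoint pre-downloaded into the `ckpt` directory.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_path = "ckpt/Mixtral-8x7B-v0.1"  # any of the checkpoints listed in Setup
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    torch_dtype=torch.bfloat16,
    device_map="auto",        # requires `accelerate`
    trust_remote_code=True,   # allow customized configuration/modeling files
)

inputs = tokenizer("Hello", return_tensors="pt").to(model.device)
with torch.no_grad():
    logits = model(**inputs).logits
```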
-------------------------------------------------------------------------------- /assets/last_layer.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kamanphoebe/Look-into-MoEs/f53d24b88e30343309aa1c85df3314cea75601d2/assets/last_layer.png -------------------------------------------------------------------------------- /assets/pearson.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kamanphoebe/Look-into-MoEs/f53d24b88e30343309aa1c85df3314cea75601d2/assets/pearson.png -------------------------------------------------------------------------------- /assets/rank_count.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kamanphoebe/Look-into-MoEs/f53d24b88e30343309aa1c85df3314cea75601d2/assets/rank_count.png -------------------------------------------------------------------------------- /deepseekmoe/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "architectures": [ 3 | "DeepseekForCausalLM" 4 | ], 5 | "attention_bias": false, 6 | "attention_dropout": 0.0, 7 | "auto_map": { 8 | "AutoConfig": "configuration_deepseek.DeepseekConfig", 9 | "AutoModel": "modeling_deepseek.DeepseekModel", 10 | "AutoModelForCausalLM": "modeling_deepseek.DeepseekForCausalLM" 11 | }, 12 | "bos_token_id": 100000, 13 | "eos_token_id": 100001, 14 | "first_k_dense_replace": 1, 15 | "hidden_act": "silu", 16 | "hidden_size": 2048, 17 | "initializer_range": 0.02, 18 | "intermediate_size": 10944, 19 | "max_position_embeddings": 4096, 20 | "model_type": "deepseek", 21 | "moe_intermediate_size": 1408, 22 | "moe_layer_freq": 1, 23 | "n_routed_experts": 64, 24 | "n_shared_experts": 2, 25 | "norm_topk_prob": false, 26 | "num_attention_heads": 16, 27 | "num_experts_per_tok": 6, 28 | "num_hidden_layers": 28, 29 | "num_key_value_heads": 16, 30 | "pretraining_tp": 1, 31 | "rms_norm_eps": 1e-06, 32 | "rope_scaling": null, 33 | "rope_theta": 10000, 34 | "scoring_func": "softmax", 35 | "tie_word_embeddings": false, 36 | "torch_dtype": "bfloat16", 37 | "transformers_version": "4.36.0", 38 | "use_cache": true, 39 | "vocab_size": 102400 40 | } 41 | -------------------------------------------------------------------------------- /deepseekmoe/configuration_deepseek.py: -------------------------------------------------------------------------------- 1 | from transformers.configuration_utils import PretrainedConfig 2 | from transformers.utils import logging 3 | 4 | logger = logging.get_logger(__name__) 5 | 6 | DEEPSEEK_PRETRAINED_CONFIG_ARCHIVE_MAP = {} 7 | class DeepseekConfig(PretrainedConfig): 8 | r""" 9 | This is the configuration class to store the configuration of a [`DeepseekModel`]. It is used to instantiate an DeepSeek 10 | model according to the specified arguments, defining the model architecture. Instantiating a configuration with the 11 | defaults will yield a similar configuration to that of the DeepSeek-7B. 12 | 13 | Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the 14 | documentation from [`PretrainedConfig`] for more information. 15 | 16 | 17 | Args: 18 | vocab_size (`int`, *optional*, defaults to 102400): 19 | Vocabulary size of the Deep model. 
Defines the number of different tokens that can be represented by the 20 | `inputs_ids` passed when calling [`DeepseekModel`] 21 | hidden_size (`int`, *optional*, defaults to 4096): 22 | Dimension of the hidden representations. 23 | intermediate_size (`int`, *optional*, defaults to 11008): 24 | Dimension of the MLP representations. 25 | moe_intermediate_size (`int`, *optional*, defaults to 1407): 26 | Dimension of the MoE representations. 27 | num_hidden_layers (`int`, *optional*, defaults to 32): 28 | Number of hidden layers in the Transformer decoder. 29 | num_attention_heads (`int`, *optional*, defaults to 32): 30 | Number of attention heads for each attention layer in the Transformer decoder. 31 | n_shared_experts (`int`, *optional*, defaults to None): 32 | Number of shared experts, None means dense model. 33 | n_routed_experts (`int`, *optional*, defaults to None): 34 | Number of routed experts, None means dense model. 35 | num_experts_per_tok (`int`, *optional*, defaults to None): 36 | Number of selected experts, None means dense model. 37 | moe_layer_freq (`int`, *optional*, defaults to 1): 38 | The frequency of the MoE layer: one expert layer for every `moe_layer_freq - 1` dense layers. 39 | first_k_dense_replace (`int`, *optional*, defaults to 0): 40 | Number of dense layers in shallow layers(embed->dense->dense->...->dense->moe->moe...->lm_head). 41 | \--k dense layers--/ 42 | norm_topk_prob (`bool`, *optional*, defaults to False): 43 | Whether to normalize the weights of the routed experts. 44 | scoring_func (`str`, *optional*, defaults to 'softmax'): 45 | Method of computing expert weights. 46 | aux_loss_alpha (`float`, *optional*, defaults to 0.001): 47 | Auxiliary loss weight coefficient. 48 | seq_aux = (`bool`, *optional*, defaults to True): 49 | Whether to compute the auxiliary loss for each individual sample. 50 | num_key_value_heads (`int`, *optional*): 51 | This is the number of key_value heads that should be used to implement Grouped Query Attention. If 52 | `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if 53 | `num_key_value_heads=1 the model will use Multi Query Attention (MQA) otherwise GQA is used. When 54 | converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed 55 | by meanpooling all the original heads within that group. For more details checkout [this 56 | paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to 57 | `num_attention_heads`. 58 | hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): 59 | The non-linear activation function (function or string) in the decoder. 60 | max_position_embeddings (`int`, *optional*, defaults to 2048): 61 | The maximum sequence length that this model might ever be used with. 62 | initializer_range (`float`, *optional*, defaults to 0.02): 63 | The standard deviation of the truncated_normal_initializer for initializing all weight matrices. 64 | rms_norm_eps (`float`, *optional*, defaults to 1e-06): 65 | The epsilon used by the rms normalization layers. 66 | use_cache (`bool`, *optional*, defaults to `True`): 67 | Whether or not the model should return the last key/values attentions (not used by all models). Only 68 | relevant if `config.is_decoder=True`. 69 | pad_token_id (`int`, *optional*): 70 | Padding token id. 71 | bos_token_id (`int`, *optional*, defaults to 1): 72 | Beginning of stream token id. 73 | eos_token_id (`int`, *optional*, defaults to 2): 74 | End of stream token id. 
75 | pretraining_tp (`int`, *optional*, defaults to 1): 76 | Experimental feature. Tensor parallelism rank used during pretraining. Please refer to [this 77 | document](https://huggingface.co/docs/transformers/parallelism) to understand more about it. This value is 78 | necessary to ensure exact reproducibility of the pretraining results. Please refer to [this 79 | issue](https://github.com/pytorch/pytorch/issues/76232). 80 | tie_word_embeddings (`bool`, *optional*, defaults to `False`): 81 | Whether to tie weight embeddings 82 | rope_theta (`float`, *optional*, defaults to 10000.0): 83 | The base period of the RoPE embeddings. 84 | rope_scaling (`Dict`, *optional*): 85 | Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling 86 | strategies: linear and dynamic. Their scaling factor must be a float greater than 1. The expected format is 87 | `{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update 88 | `max_position_embeddings` to the expected new maximum. 89 | attention_bias (`bool`, defaults to `False`, *optional*, defaults to `False`): 90 | Whether to use a bias in the query, key, value and output projection layers during self-attention. 91 | attention_dropout (`float`, *optional*, defaults to 0.0): 92 | The dropout ratio for the attention probabilities. 93 | 94 | ```python 95 | >>> from transformers import DeepseekModel, DeepseekConfig 96 | 97 | >>> # Initializing a Deepseek deepseek-7b style configuration 98 | >>> configuration = DeepseekConfig() 99 | 100 | >>> # Accessing the model configuration 101 | >>> configuration = model.config 102 | ```""" 103 | 104 | model_type = "deepseek" 105 | keys_to_ignore_at_inference = ["past_key_values"] 106 | 107 | def __init__( 108 | self, 109 | vocab_size=102400, 110 | hidden_size=4096, 111 | intermediate_size=11008, 112 | moe_intermediate_size = 1407, 113 | num_hidden_layers=30, 114 | num_attention_heads=32, 115 | num_key_value_heads=32, 116 | n_shared_experts = None, 117 | n_routed_experts = None, 118 | num_experts_per_tok = None, 119 | moe_layer_freq = 1, 120 | first_k_dense_replace = 0, 121 | norm_topk_prob = False, 122 | scoring_func = 'softmax', 123 | aux_loss_alpha = 0.001, 124 | seq_aux = True, 125 | hidden_act="silu", 126 | max_position_embeddings=2048, 127 | initializer_range=0.02, 128 | rms_norm_eps=1e-6, 129 | use_cache=True, 130 | pad_token_id=None, 131 | bos_token_id=100000, 132 | eos_token_id=100001, 133 | pretraining_tp=1, 134 | tie_word_embeddings=False, 135 | rope_theta=10000.0, 136 | rope_scaling=None, 137 | attention_bias=False, 138 | attention_dropout=0.0, 139 | **kwargs, 140 | ): 141 | self.vocab_size = vocab_size 142 | self.max_position_embeddings = max_position_embeddings 143 | self.hidden_size = hidden_size 144 | self.intermediate_size = intermediate_size 145 | self.moe_intermediate_size = moe_intermediate_size 146 | self.num_hidden_layers = num_hidden_layers 147 | self.num_attention_heads = num_attention_heads 148 | self.n_shared_experts = n_shared_experts 149 | self.n_routed_experts = n_routed_experts 150 | self.num_experts_per_tok = num_experts_per_tok 151 | self.moe_layer_freq = moe_layer_freq 152 | self.first_k_dense_replace = first_k_dense_replace 153 | self.norm_topk_prob = norm_topk_prob 154 | self.scoring_func = scoring_func 155 | self.aux_loss_alpha = aux_loss_alpha 156 | self.seq_aux = seq_aux 157 | # for backward compatibility 158 | if num_key_value_heads is None: 159 | num_key_value_heads = num_attention_heads 160 | 161 | 
self.num_key_value_heads = num_key_value_heads 162 | self.hidden_act = hidden_act 163 | self.initializer_range = initializer_range 164 | self.rms_norm_eps = rms_norm_eps 165 | self.pretraining_tp = pretraining_tp 166 | self.use_cache = use_cache 167 | self.rope_theta = rope_theta 168 | self.rope_scaling = rope_scaling 169 | self._rope_scaling_validation() 170 | self.attention_bias = attention_bias 171 | self.attention_dropout = attention_dropout 172 | 173 | super().__init__( 174 | pad_token_id=pad_token_id, 175 | bos_token_id=bos_token_id, 176 | eos_token_id=eos_token_id, 177 | tie_word_embeddings=tie_word_embeddings, 178 | **kwargs, 179 | ) 180 | 181 | def _rope_scaling_validation(self): 182 | """ 183 | Validate the `rope_scaling` configuration. 184 | """ 185 | if self.rope_scaling is None: 186 | return 187 | 188 | if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2: 189 | raise ValueError( 190 | "`rope_scaling` must be a dictionary with with two fields, `type` and `factor`, " 191 | f"got {self.rope_scaling}" 192 | ) 193 | rope_scaling_type = self.rope_scaling.get("type", None) 194 | rope_scaling_factor = self.rope_scaling.get("factor", None) 195 | if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]: 196 | raise ValueError( 197 | f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}" 198 | ) 199 | if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0: 200 | raise ValueError(f"`rope_scaling`'s factor field must be a float > 1, got {rope_scaling_factor}") -------------------------------------------------------------------------------- /env.txt: -------------------------------------------------------------------------------- 1 | # This file may be used to create an environment using: 2 | # $ conda create --name --file 3 | # platform: linux-64 4 | @EXPLICIT 5 | https://conda.anaconda.org/conda-forge/linux-64/_libgcc_mutex-0.1-conda_forge.tar.bz2 6 | https://repo.anaconda.com/pkgs/main/linux-64/blas-1.0-mkl.conda 7 | https://repo.anaconda.com/pkgs/main/linux-64/ca-certificates-2024.3.11-h06a4308_0.conda 8 | https://conda.anaconda.org/nvidia/linux-64/cuda-cudart-11.8.89-0.tar.bz2 9 | https://conda.anaconda.org/nvidia/linux-64/cuda-cupti-11.8.87-0.tar.bz2 10 | https://conda.anaconda.org/nvidia/linux-64/cuda-nvrtc-11.8.89-0.tar.bz2 11 | https://conda.anaconda.org/nvidia/linux-64/cuda-nvtx-11.8.86-0.tar.bz2 12 | https://conda.anaconda.org/conda-forge/linux-64/ld_impl_linux-64-2.40-h41732ed_0.conda 13 | https://conda.anaconda.org/nvidia/linux-64/libcublas-11.11.3.6-0.tar.bz2 14 | https://conda.anaconda.org/nvidia/linux-64/libcufft-10.9.0.58-0.tar.bz2 15 | https://conda.anaconda.org/nvidia/linux-64/libcufile-1.8.1.2-0.tar.bz2 16 | https://conda.anaconda.org/nvidia/linux-64/libcurand-10.3.4.101-0.tar.bz2 17 | https://conda.anaconda.org/nvidia/linux-64/libcusolver-11.4.1.48-0.tar.bz2 18 | https://conda.anaconda.org/nvidia/linux-64/libcusparse-11.7.5.86-0.tar.bz2 19 | https://repo.anaconda.com/pkgs/main/linux-64/libgfortran5-11.2.0-h1234567_1.conda 20 | https://conda.anaconda.org/nvidia/linux-64/libnpp-11.8.0.86-0.tar.bz2 21 | https://conda.anaconda.org/nvidia/linux-64/libnvjpeg-11.9.0.86-0.tar.bz2 22 | https://conda.anaconda.org/conda-forge/linux-64/libstdcxx-ng-13.2.0-h7e041cc_3.conda 23 | https://conda.anaconda.org/conda-forge/linux-64/python_abi-3.10-4_cp310.conda 24 | https://conda.anaconda.org/pytorch/noarch/pytorch-mutex-1.0-cuda.tar.bz2 25 | 
https://conda.anaconda.org/conda-forge/noarch/tzdata-2023c-h71feb2d_0.conda 26 | https://conda.anaconda.org/nvidia/linux-64/cuda-libraries-11.8.0-0.tar.bz2 27 | https://repo.anaconda.com/pkgs/main/linux-64/libgfortran-ng-11.2.0-h00389a5_1.conda 28 | https://conda.anaconda.org/conda-forge/linux-64/libgomp-13.2.0-h807b86a_3.conda 29 | https://conda.anaconda.org/conda-forge/linux-64/_openmp_mutex-4.5-2_gnu.tar.bz2 30 | https://conda.anaconda.org/nvidia/linux-64/cuda-runtime-11.8.0-0.tar.bz2 31 | https://conda.anaconda.org/conda-forge/linux-64/libgcc-ng-13.2.0-h807b86a_3.conda 32 | https://conda.anaconda.org/pytorch/linux-64/pytorch-cuda-11.8-h7e8668a_5.tar.bz2 33 | https://repo.anaconda.com/pkgs/main/linux-64/abseil-cpp-20211102.0-hd4dd3e8_0.conda 34 | https://repo.anaconda.com/pkgs/main/linux-64/aws-c-common-0.8.5-h5eee18b_0.conda 35 | https://conda.anaconda.org/conda-forge/linux-64/bzip2-1.0.8-hd590300_5.conda 36 | https://repo.anaconda.com/pkgs/main/linux-64/c-ares-1.19.1-h5eee18b_0.conda 37 | https://repo.anaconda.com/pkgs/main/linux-64/expat-2.5.0-h6a678d5_0.conda 38 | https://repo.anaconda.com/pkgs/main/linux-64/gflags-2.2.2-h6a678d5_1.conda 39 | https://repo.anaconda.com/pkgs/main/linux-64/giflib-5.2.1-h5eee18b_3.conda 40 | https://repo.anaconda.com/pkgs/main/linux-64/gmp-6.2.1-h295c915_3.conda 41 | https://repo.anaconda.com/pkgs/main/linux-64/icu-73.1-h6a678d5_0.conda 42 | https://repo.anaconda.com/pkgs/main/linux-64/jpeg-9e-h5eee18b_1.conda 43 | https://repo.anaconda.com/pkgs/main/linux-64/lame-3.100-h7b6447c_0.conda 44 | https://repo.anaconda.com/pkgs/main/linux-64/lerc-3.0-h295c915_0.conda 45 | https://repo.anaconda.com/pkgs/main/linux-64/libbrotlicommon-1.0.9-h5eee18b_7.conda 46 | https://repo.anaconda.com/pkgs/main/linux-64/libdeflate-1.17-h5eee18b_1.conda 47 | https://repo.anaconda.com/pkgs/main/linux-64/libev-4.33-h7f8727e_1.conda 48 | https://repo.anaconda.com/pkgs/main/linux-64/libffi-3.4.4-h6a678d5_0.conda 49 | https://repo.anaconda.com/pkgs/main/linux-64/libiconv-1.16-h7f8727e_2.conda 50 | https://conda.anaconda.org/pytorch/linux-64/libjpeg-turbo-2.0.0-h9bf148f_0.tar.bz2 51 | https://conda.anaconda.org/conda-forge/linux-64/libnsl-2.0.1-hd590300_0.conda 52 | https://conda.anaconda.org/conda-forge/linux-64/libsodium-1.0.18-h36c2ea0_1.tar.bz2 53 | https://repo.anaconda.com/pkgs/main/linux-64/libtasn1-4.19.0-h5eee18b_0.conda 54 | https://repo.anaconda.com/pkgs/main/linux-64/libunistring-0.9.10-h27cfd23_0.conda 55 | https://conda.anaconda.org/conda-forge/linux-64/libuuid-2.38.1-h0b41bf4_0.conda 56 | https://repo.anaconda.com/pkgs/main/linux-64/libwebp-base-1.3.2-h5eee18b_0.conda 57 | https://repo.anaconda.com/pkgs/main/linux-64/libxcb-1.15-h7f8727e_0.conda 58 | https://conda.anaconda.org/conda-forge/linux-64/libzlib-1.2.13-hd590300_5.conda 59 | https://repo.anaconda.com/pkgs/main/linux-64/lz4-c-1.9.4-h6a678d5_0.conda 60 | https://conda.anaconda.org/conda-forge/linux-64/ncurses-6.4-h59595ed_2.conda 61 | https://repo.anaconda.com/pkgs/main/linux-64/openh264-2.1.1-h4ff587b_0.conda 62 | https://conda.anaconda.org/conda-forge/linux-64/openssl-3.2.1-hd590300_1.conda 63 | https://repo.anaconda.com/pkgs/main/linux-64/re2-2022.04.01-h295c915_0.conda 64 | https://repo.anaconda.com/pkgs/main/linux-64/snappy-1.1.10-h6a678d5_1.conda 65 | https://repo.anaconda.com/pkgs/main/linux-64/tbb-2021.8.0-hdb19cb5_0.conda 66 | https://repo.anaconda.com/pkgs/main/linux-64/utf8proc-2.6.1-h5eee18b_1.conda 67 | https://repo.anaconda.com/pkgs/main/linux-64/xxhash-0.8.0-h7f8727e_3.conda 68 | 
https://repo.anaconda.com/pkgs/main/linux-64/xz-5.4.2-h5eee18b_0.conda 69 | https://repo.anaconda.com/pkgs/main/linux-64/yaml-0.2.5-h7b6447c_0.conda 70 | https://repo.anaconda.com/pkgs/main/linux-64/aws-c-cal-0.5.20-hdbd6064_0.conda 71 | https://repo.anaconda.com/pkgs/main/linux-64/aws-c-compression-0.2.16-h5eee18b_0.conda 72 | https://repo.anaconda.com/pkgs/main/linux-64/aws-c-sdkutils-0.1.6-h5eee18b_0.conda 73 | https://repo.anaconda.com/pkgs/main/linux-64/aws-checksums-0.1.13-h5eee18b_0.conda 74 | https://repo.anaconda.com/pkgs/main/linux-64/glog-0.5.0-h6a678d5_1.conda 75 | https://repo.anaconda.com/pkgs/main/linux-64/libbrotlidec-1.0.9-h5eee18b_7.conda 76 | https://repo.anaconda.com/pkgs/main/linux-64/libbrotlienc-1.0.9-h5eee18b_7.conda 77 | https://repo.anaconda.com/pkgs/main/linux-64/libedit-3.1.20230828-h5eee18b_0.conda 78 | https://repo.anaconda.com/pkgs/main/linux-64/libevent-2.1.12-hdbd6064_1.conda 79 | https://repo.anaconda.com/pkgs/main/linux-64/libidn2-2.3.4-h5eee18b_0.conda 80 | https://conda.anaconda.org/conda-forge/linux-64/libsqlite-3.44.2-h2797004_0.conda 81 | https://repo.anaconda.com/pkgs/main/linux-64/libssh2-1.10.0-hdbd6064_2.conda 82 | https://repo.anaconda.com/pkgs/main/linux-64/mpfr-4.0.2-hb69a4c5_1.conda 83 | https://repo.anaconda.com/pkgs/main/linux-64/nettle-3.7.3-hbbd107a_1.conda 84 | https://conda.anaconda.org/conda-forge/linux-64/readline-8.2-h8228510_1.conda 85 | https://repo.anaconda.com/pkgs/main/linux-64/s2n-1.3.27-hdbd6064_0.conda 86 | https://conda.anaconda.org/conda-forge/linux-64/tk-8.6.13-noxft_h4845f30_101.conda 87 | https://conda.anaconda.org/conda-forge/linux-64/zeromq-4.3.5-h59595ed_0.conda 88 | https://conda.anaconda.org/conda-forge/linux-64/zlib-1.2.13-hd590300_5.conda 89 | https://repo.anaconda.com/pkgs/main/linux-64/aws-c-io-0.13.10-h5eee18b_0.conda 90 | https://repo.anaconda.com/pkgs/main/linux-64/brotli-bin-1.0.9-h5eee18b_7.conda 91 | https://repo.anaconda.com/pkgs/main/linux-64/gnutls-3.6.15-he1e5248_0.conda 92 | https://repo.anaconda.com/pkgs/main/linux-64/intel-openmp-2023.1.0-hdb19cb5_46306.conda 93 | https://repo.anaconda.com/pkgs/main/linux-64/krb5-1.20.1-h143b758_1.conda 94 | https://repo.anaconda.com/pkgs/main/linux-64/libcups-2.4.2-h2d74bed_1.conda 95 | https://repo.anaconda.com/pkgs/main/linux-64/libllvm14-14.0.6-hdb19cb5_3.conda 96 | https://repo.anaconda.com/pkgs/main/linux-64/libnghttp2-1.57.0-h2d74bed_0.conda 97 | https://repo.anaconda.com/pkgs/main/linux-64/libpng-1.6.39-h5eee18b_0.conda 98 | https://repo.anaconda.com/pkgs/main/linux-64/libprotobuf-3.20.3-he621ea3_0.conda 99 | https://repo.anaconda.com/pkgs/main/linux-64/libxml2-2.10.4-hf1b16e4_1.conda 100 | https://repo.anaconda.com/pkgs/main/linux-64/llvm-openmp-14.0.6-h9e868ea_0.conda 101 | https://repo.anaconda.com/pkgs/main/linux-64/mpc-1.1.0-h10f8cd9_1.conda 102 | https://repo.anaconda.com/pkgs/main/linux-64/pcre2-10.42-hebb0a14_0.conda 103 | https://conda.anaconda.org/conda-forge/linux-64/python-3.10.13-hd12c33a_0_cpython.conda 104 | https://conda.anaconda.org/conda-forge/linux-64/sqlite-3.44.2-h2c6b66d_0.conda 105 | https://repo.anaconda.com/pkgs/main/linux-64/zstd-1.5.5-hc292b87_0.conda 106 | https://conda.anaconda.org/conda-forge/noarch/antlr-python-runtime-4.9.3-pyhd8ed1ab_1.tar.bz2 107 | https://repo.anaconda.com/pkgs/main/linux-64/async-timeout-4.0.3-py310h06a4308_0.conda 108 | https://conda.anaconda.org/conda-forge/noarch/attrs-23.2.0-pyh71513ae_0.conda 109 | https://repo.anaconda.com/pkgs/main/linux-64/aws-c-event-stream-0.2.15-h6a678d5_0.conda 110 | 
https://repo.anaconda.com/pkgs/main/linux-64/aws-c-http-0.6.25-h5eee18b_0.conda 111 | https://repo.anaconda.com/pkgs/main/linux-64/brotli-1.0.9-h5eee18b_7.conda 112 | https://repo.anaconda.com/pkgs/main/linux-64/brotli-python-1.0.9-py310h6a678d5_7.conda 113 | https://conda.anaconda.org/conda-forge/noarch/cached_property-1.5.2-pyha770c72_1.tar.bz2 114 | https://repo.anaconda.com/pkgs/main/linux-64/certifi-2024.2.2-py310h06a4308_0.conda 115 | https://repo.anaconda.com/pkgs/main/noarch/charset-normalizer-2.0.4-pyhd3eb1b0_0.conda 116 | https://repo.anaconda.com/pkgs/main/noarch/cycler-0.11.0-pyhd3eb1b0_0.conda 117 | https://repo.anaconda.com/pkgs/main/linux-64/cyrus-sasl-2.1.28-h52b45da_1.conda 118 | https://conda.anaconda.org/conda-forge/linux-64/debugpy-1.8.0-py310hc6cd4ac_1.conda 119 | https://conda.anaconda.org/conda-forge/noarch/decorator-5.1.1-pyhd8ed1ab_0.tar.bz2 120 | https://conda.anaconda.org/conda-forge/noarch/defusedxml-0.7.1-pyhd8ed1ab_0.tar.bz2 121 | https://repo.anaconda.com/pkgs/main/linux-64/dill-0.3.6-py310h06a4308_0.conda 122 | https://conda.anaconda.org/conda-forge/noarch/entrypoints-0.4-pyhd8ed1ab_0.tar.bz2 123 | https://conda.anaconda.org/conda-forge/noarch/exceptiongroup-1.2.0-pyhd8ed1ab_0.conda 124 | https://conda.anaconda.org/conda-forge/noarch/executing-2.0.1-pyhd8ed1ab_0.conda 125 | https://repo.anaconda.com/pkgs/main/linux-64/filelock-3.13.1-py310h06a4308_0.conda 126 | https://repo.anaconda.com/pkgs/main/linux-64/freetype-2.12.1-h4a9f257_0.conda 127 | https://repo.anaconda.com/pkgs/main/linux-64/frozenlist-1.4.0-py310h5eee18b_0.conda 128 | https://repo.anaconda.com/pkgs/main/linux-64/fsspec-2023.10.0-py310h06a4308_0.conda 129 | https://repo.anaconda.com/pkgs/main/linux-64/gmpy2-2.1.2-py310heeb90bb_0.conda 130 | https://repo.anaconda.com/pkgs/main/linux-64/grpc-cpp-1.48.2-he1ff14a_1.conda 131 | https://conda.anaconda.org/conda-forge/noarch/hpack-4.0.0-pyh9f0ad1d_0.tar.bz2 132 | https://conda.anaconda.org/conda-forge/noarch/hyperframe-6.0.1-pyhd8ed1ab_0.tar.bz2 133 | https://repo.anaconda.com/pkgs/main/linux-64/idna-3.4-py310h06a4308_0.conda 134 | https://repo.anaconda.com/pkgs/main/linux-64/importlib_resources-6.1.1-py310h06a4308_1.conda 135 | https://conda.anaconda.org/anaconda/linux-64/joblib-1.2.0-py310h06a4308_0.tar.bz2 136 | https://conda.anaconda.org/conda-forge/noarch/json5-0.9.24-pyhd8ed1ab_0.conda 137 | https://conda.anaconda.org/conda-forge/linux-64/jsonpointer-2.4-py310hff52083_3.conda 138 | https://repo.anaconda.com/pkgs/main/linux-64/kiwisolver-1.4.4-py310h6a678d5_0.conda 139 | https://repo.anaconda.com/pkgs/main/linux-64/libboost-1.82.0-h109eef0_2.conda 140 | https://repo.anaconda.com/pkgs/main/linux-64/libclang13-14.0.6-default_he11475f_1.conda 141 | https://repo.anaconda.com/pkgs/main/linux-64/libcurl-8.5.0-h251f7ec_0.conda 142 | https://repo.anaconda.com/pkgs/main/linux-64/libglib-2.78.4-hdc74915_0.conda 143 | https://repo.anaconda.com/pkgs/main/linux-64/libpq-12.17-hdbd6064_0.conda 144 | https://repo.anaconda.com/pkgs/main/linux-64/libtiff-4.5.1-h6a678d5_0.conda 145 | https://repo.anaconda.com/pkgs/main/linux-64/libxkbcommon-1.0.1-h5eee18b_1.conda 146 | https://repo.anaconda.com/pkgs/main/linux-64/markupsafe-2.1.1-py310h7f8727e_0.conda 147 | https://conda.anaconda.org/conda-forge/noarch/mistune-3.0.2-pyhd8ed1ab_0.conda 148 | https://repo.anaconda.com/pkgs/main/linux-64/mkl-2023.1.0-h213fc3f_46344.conda 149 | https://repo.anaconda.com/pkgs/main/linux-64/mpmath-1.3.0-py310h06a4308_0.conda 150 | 
https://repo.anaconda.com/pkgs/main/linux-64/multidict-6.0.4-py310h5eee18b_0.conda 151 | https://repo.anaconda.com/pkgs/main/noarch/munkres-1.1.4-py_0.conda 152 | https://conda.anaconda.org/conda-forge/noarch/nest-asyncio-1.5.8-pyhd8ed1ab_0.conda 153 | https://repo.anaconda.com/pkgs/main/linux-64/networkx-3.1-py310h06a4308_0.conda 154 | https://repo.anaconda.com/pkgs/main/linux-64/orc-1.7.4-hb3bc3d3_1.conda 155 | https://conda.anaconda.org/conda-forge/noarch/packaging-23.2-pyhd8ed1ab_0.conda 156 | https://conda.anaconda.org/conda-forge/noarch/pandocfilters-1.5.0-pyhd8ed1ab_0.tar.bz2 157 | https://conda.anaconda.org/conda-forge/noarch/parso-0.8.3-pyhd8ed1ab_0.tar.bz2 158 | https://conda.anaconda.org/conda-forge/noarch/pickleshare-0.7.5-py_1003.tar.bz2 159 | https://conda.anaconda.org/conda-forge/noarch/pkgutil-resolve-name-1.3.10-pyhd8ed1ab_1.conda 160 | https://repo.anaconda.com/pkgs/main/linux-64/ply-3.11-py310h06a4308_0.conda 161 | https://conda.anaconda.org/conda-forge/noarch/prometheus_client-0.20.0-pyhd8ed1ab_0.conda 162 | https://conda.anaconda.org/conda-forge/linux-64/psutil-5.9.5-py310h2372a71_1.conda 163 | https://conda.anaconda.org/conda-forge/noarch/ptyprocess-0.7.0-pyhd3deb0d_0.tar.bz2 164 | https://conda.anaconda.org/conda-forge/noarch/pure_eval-0.2.2-pyhd8ed1ab_0.tar.bz2 165 | https://repo.anaconda.com/pkgs/main/noarch/pycparser-2.21-pyhd3eb1b0_0.conda 166 | https://conda.anaconda.org/conda-forge/noarch/pygments-2.17.2-pyhd8ed1ab_0.conda 167 | https://repo.anaconda.com/pkgs/main/linux-64/pyparsing-3.0.9-py310h06a4308_0.conda 168 | https://repo.anaconda.com/pkgs/main/linux-64/pyqt5-sip-12.13.0-py310h5eee18b_0.conda 169 | https://repo.anaconda.com/pkgs/main/linux-64/pysocks-1.7.1-py310h06a4308_0.conda 170 | https://conda.anaconda.org/conda-forge/noarch/python-fastjsonschema-2.19.1-pyhd8ed1ab_0.conda 171 | https://conda.anaconda.org/conda-forge/noarch/python-json-logger-2.0.7-pyhd8ed1ab_0.conda 172 | https://repo.anaconda.com/pkgs/main/noarch/python-tzdata-2023.3-pyhd3eb1b0_0.conda 173 | https://repo.anaconda.com/pkgs/main/linux-64/python-xxhash-2.0.2-py310h5eee18b_1.conda 174 | https://conda.anaconda.org/conda-forge/noarch/pytz-2024.1-pyhd8ed1ab_0.conda 175 | https://repo.anaconda.com/pkgs/main/linux-64/pyyaml-6.0.1-py310h5eee18b_0.conda 176 | https://conda.anaconda.org/conda-forge/linux-64/pyzmq-24.0.1-py310h330234f_1.tar.bz2 177 | https://repo.anaconda.com/pkgs/main/linux-64/regex-2023.10.3-py310h5eee18b_0.conda 178 | https://conda.anaconda.org/conda-forge/noarch/rfc3986-validator-0.1.1-pyh9f0ad1d_0.tar.bz2 179 | https://conda.anaconda.org/conda-forge/linux-64/rpds-py-0.18.0-py310hcb5633a_0.conda 180 | https://repo.anaconda.com/pkgs/main/linux-64/safetensors-0.4.2-py310ha89cbab_0.conda 181 | https://conda.anaconda.org/conda-forge/noarch/send2trash-1.8.2-pyh41d4057_0.conda 182 | https://conda.anaconda.org/conda-forge/noarch/setuptools-68.2.2-pyhd8ed1ab_0.conda 183 | https://conda.anaconda.org/conda-forge/noarch/six-1.16.0-pyh6c4a22f_0.tar.bz2 184 | https://conda.anaconda.org/conda-forge/noarch/sniffio-1.3.1-pyhd8ed1ab_0.conda 185 | https://conda.anaconda.org/conda-forge/noarch/soupsieve-2.5-pyhd8ed1ab_1.conda 186 | https://conda.anaconda.org/anaconda/noarch/threadpoolctl-2.2.0-pyh0d69192_0.tar.bz2 187 | https://repo.anaconda.com/pkgs/main/linux-64/tomli-2.0.1-py310h06a4308_0.conda 188 | https://conda.anaconda.org/conda-forge/linux-64/tornado-6.3.3-py310h2372a71_1.conda 189 | https://repo.anaconda.com/pkgs/main/linux-64/tqdm-4.65.0-py310h2f386ee_0.conda 190 | 
https://conda.anaconda.org/conda-forge/noarch/traitlets-5.14.0-pyhd8ed1ab_0.conda 191 | https://conda.anaconda.org/conda-forge/noarch/types-python-dateutil-2.9.0.20240316-pyhd8ed1ab_0.conda 192 | https://conda.anaconda.org/conda-forge/noarch/typing_extensions-4.8.0-pyha770c72_0.conda 193 | https://conda.anaconda.org/conda-forge/noarch/typing_utils-0.1.0-pyhd8ed1ab_0.tar.bz2 194 | https://conda.anaconda.org/conda-forge/noarch/uri-template-1.3.0-pyhd8ed1ab_0.conda 195 | https://conda.anaconda.org/conda-forge/noarch/wcwidth-0.2.12-pyhd8ed1ab_0.conda 196 | https://conda.anaconda.org/conda-forge/noarch/webcolors-1.13-pyhd8ed1ab_0.conda 197 | https://conda.anaconda.org/conda-forge/noarch/webencodings-0.5.1-pyhd8ed1ab_2.conda 198 | https://conda.anaconda.org/conda-forge/noarch/websocket-client-1.7.0-pyhd8ed1ab_0.conda 199 | https://conda.anaconda.org/conda-forge/noarch/wheel-0.42.0-pyhd8ed1ab_0.conda 200 | https://conda.anaconda.org/conda-forge/noarch/zipp-3.17.0-pyhd8ed1ab_0.conda 201 | https://repo.anaconda.com/pkgs/main/noarch/aiosignal-1.2.0-pyhd3eb1b0_0.conda 202 | https://conda.anaconda.org/conda-forge/noarch/anyio-4.3.0-pyhd8ed1ab_0.conda 203 | https://conda.anaconda.org/conda-forge/noarch/asttokens-2.4.1-pyhd8ed1ab_0.conda 204 | https://conda.anaconda.org/conda-forge/noarch/async-lru-2.0.4-pyhd8ed1ab_0.conda 205 | https://repo.anaconda.com/pkgs/main/linux-64/aws-c-auth-0.6.19-h5eee18b_0.conda 206 | https://repo.anaconda.com/pkgs/main/linux-64/aws-c-mqtt-0.7.13-h5eee18b_0.conda 207 | https://conda.anaconda.org/conda-forge/noarch/babel-2.14.0-pyhd8ed1ab_0.conda 208 | https://conda.anaconda.org/conda-forge/noarch/beautifulsoup4-4.12.3-pyha770c72_0.conda 209 | https://conda.anaconda.org/conda-forge/noarch/bleach-6.1.0-pyhd8ed1ab_0.conda 210 | https://repo.anaconda.com/pkgs/main/linux-64/boost-cpp-1.82.0-hdb19cb5_2.conda 211 | https://conda.anaconda.org/conda-forge/noarch/cached-property-1.5.2-hd8ed1ab_1.tar.bz2 212 | https://repo.anaconda.com/pkgs/main/linux-64/cffi-1.16.0-py310h5eee18b_0.conda 213 | https://conda.anaconda.org/conda-forge/noarch/comm-0.1.4-pyhd8ed1ab_0.conda 214 | https://conda.anaconda.org/pytorch/linux-64/ffmpeg-4.3-hf484d3e_0.tar.bz2 215 | https://conda.anaconda.org/conda-forge/linux-64/fontconfig-2.14.2-h14ed4e7_0.conda 216 | https://repo.anaconda.com/pkgs/main/noarch/fonttools-4.25.0-pyhd3eb1b0_0.conda 217 | https://repo.anaconda.com/pkgs/main/linux-64/glib-tools-2.78.4-h6a678d5_0.conda 218 | https://conda.anaconda.org/conda-forge/noarch/h11-0.14.0-pyhd8ed1ab_0.tar.bz2 219 | https://conda.anaconda.org/conda-forge/noarch/h2-4.1.0-pyhd8ed1ab_0.tar.bz2 220 | https://conda.anaconda.org/conda-forge/noarch/importlib-metadata-6.9.0-pyha770c72_0.conda 221 | https://conda.anaconda.org/conda-forge/noarch/jedi-0.19.1-pyhd8ed1ab_0.conda 222 | https://repo.anaconda.com/pkgs/main/linux-64/jinja2-3.1.2-py310h06a4308_0.conda 223 | https://conda.anaconda.org/conda-forge/noarch/jupyterlab_pygments-0.3.0-pyhd8ed1ab_1.conda 224 | https://repo.anaconda.com/pkgs/main/linux-64/lcms2-2.12-h3be6417_0.conda 225 | https://repo.anaconda.com/pkgs/main/linux-64/libclang-14.0.6-default_hc6dbbc7_1.conda 226 | https://repo.anaconda.com/pkgs/main/linux-64/libwebp-1.3.2-h11a3e52_0.conda 227 | https://conda.anaconda.org/conda-forge/noarch/matplotlib-inline-0.1.6-pyhd8ed1ab_0.tar.bz2 228 | https://repo.anaconda.com/pkgs/main/linux-64/mkl-service-2.4.0-py310h5eee18b_1.conda 229 | https://repo.anaconda.com/pkgs/main/linux-64/multiprocess-0.70.14-py310h06a4308_0.conda 230 | 
https://repo.anaconda.com/pkgs/main/linux-64/mysql-5.7.24-h721c034_2.conda 231 | https://conda.anaconda.org/conda-forge/noarch/omegaconf-2.3.0-pyhd8ed1ab_0.conda 232 | https://repo.anaconda.com/pkgs/main/linux-64/openjpeg-2.4.0-h3ad879b_0.conda 233 | https://conda.anaconda.org/conda-forge/noarch/overrides-7.7.0-pyhd8ed1ab_0.conda 234 | https://conda.anaconda.org/conda-forge/noarch/pexpect-4.8.0-pyh1a96a4e_2.tar.bz2 235 | https://conda.anaconda.org/conda-forge/noarch/pip-23.3.1-pyhd8ed1ab_0.conda 236 | https://conda.anaconda.org/conda-forge/noarch/platformdirs-4.0.0-pyhd8ed1ab_0.conda 237 | https://conda.anaconda.org/conda-forge/noarch/prompt-toolkit-3.0.41-pyha770c72_0.conda 238 | https://conda.anaconda.org/conda-forge/noarch/python-dateutil-2.8.2-pyhd8ed1ab_0.tar.bz2 239 | https://conda.anaconda.org/conda-forge/noarch/referencing-0.34.0-pyhd8ed1ab_0.conda 240 | https://conda.anaconda.org/conda-forge/noarch/rfc3339-validator-0.1.4-pyhd8ed1ab_0.tar.bz2 241 | https://repo.anaconda.com/pkgs/main/linux-64/sip-6.7.12-py310h6a678d5_0.conda 242 | https://conda.anaconda.org/conda-forge/noarch/sympy-1.12-pypyh9d50eac_103.conda 243 | https://conda.anaconda.org/conda-forge/noarch/terminado-0.18.1-pyh0d859eb_0.conda 244 | https://conda.anaconda.org/conda-forge/noarch/tinycss2-1.2.1-pyhd8ed1ab_0.tar.bz2 245 | https://conda.anaconda.org/conda-forge/noarch/typing-extensions-4.8.0-hd8ed1ab_0.conda 246 | https://repo.anaconda.com/pkgs/main/linux-64/yarl-1.9.3-py310h5eee18b_0.conda 247 | https://repo.anaconda.com/pkgs/main/linux-64/aiohttp-3.9.3-py310h5eee18b_0.conda 248 | https://conda.anaconda.org/conda-forge/linux-64/argon2-cffi-bindings-21.2.0-py310h2372a71_4.conda 249 | https://conda.anaconda.org/conda-forge/noarch/arrow-1.3.0-pyhd8ed1ab_0.conda 250 | https://repo.anaconda.com/pkgs/main/linux-64/aws-c-s3-0.1.51-hdbd6064_0.conda 251 | https://repo.anaconda.com/pkgs/main/linux-64/cryptography-41.0.3-py310hdda0065_0.conda 252 | https://conda.anaconda.org/conda-forge/noarch/fqdn-1.5.1-pyhd8ed1ab_0.tar.bz2 253 | https://repo.anaconda.com/pkgs/main/linux-64/glib-2.78.4-h6a678d5_0.conda 254 | https://conda.anaconda.org/conda-forge/noarch/httpcore-1.0.5-pyhd8ed1ab_0.conda 255 | https://conda.anaconda.org/conda-forge/noarch/hydra-core-1.3.2-pyhd8ed1ab_0.conda 256 | https://conda.anaconda.org/conda-forge/noarch/importlib_metadata-6.9.0-hd8ed1ab_0.conda 257 | https://conda.anaconda.org/conda-forge/noarch/jsonschema-specifications-2023.12.1-pyhd8ed1ab_0.conda 258 | https://conda.anaconda.org/conda-forge/linux-64/jupyter_core-5.5.0-py310hff52083_0.conda 259 | https://conda.anaconda.org/conda-forge/noarch/jupyter_server_terminals-0.5.3-pyhd8ed1ab_0.conda 260 | https://repo.anaconda.com/pkgs/main/linux-64/libthrift-0.15.0-h1795dd8_2.conda 261 | https://repo.anaconda.com/pkgs/main/linux-64/numpy-base-1.26.4-py310hb5e798b_0.conda 262 | https://repo.anaconda.com/pkgs/main/linux-64/pillow-10.0.1-py310ha6cbd5a_0.conda 263 | https://conda.anaconda.org/conda-forge/noarch/stack_data-0.6.2-pyhd8ed1ab_0.conda 264 | https://conda.anaconda.org/conda-forge/noarch/argon2-cffi-23.1.0-pyhd8ed1ab_0.conda 265 | https://repo.anaconda.com/pkgs/main/linux-64/aws-crt-cpp-0.18.16-h6a678d5_0.conda 266 | https://repo.anaconda.com/pkgs/main/linux-64/dbus-1.13.18-hb2f20db_0.conda 267 | https://repo.anaconda.com/pkgs/main/linux-64/gstreamer-1.14.1-h5eee18b_1.conda 268 | https://conda.anaconda.org/conda-forge/noarch/httpx-0.27.0-pyhd8ed1ab_0.conda 269 | https://conda.anaconda.org/conda-forge/noarch/ipython-8.18.1-pyh31011fe_1.conda 270 | 
https://conda.anaconda.org/conda-forge/noarch/isoduration-20.11.0-pyhd8ed1ab_0.tar.bz2 271 | https://conda.anaconda.org/conda-forge/noarch/jsonschema-4.21.1-pyhd8ed1ab_0.conda 272 | https://conda.anaconda.org/conda-forge/noarch/jupyter_client-7.4.9-pyhd8ed1ab_0.conda 273 | https://repo.anaconda.com/pkgs/main/linux-64/pyopenssl-23.2.0-py310h06a4308_0.conda 274 | https://repo.anaconda.com/pkgs/main/linux-64/aws-sdk-cpp-1.10.55-h721c034_0.conda 275 | https://repo.anaconda.com/pkgs/main/linux-64/gst-plugins-base-1.14.1-h6a678d5_1.conda 276 | https://conda.anaconda.org/conda-forge/noarch/ipykernel-6.26.0-pyhf8b6a83_0.conda 277 | https://conda.anaconda.org/conda-forge/noarch/jsonschema-with-format-nongpl-4.21.1-pyhd8ed1ab_0.conda 278 | https://conda.anaconda.org/conda-forge/noarch/nbformat-5.10.4-pyhd8ed1ab_0.conda 279 | https://repo.anaconda.com/pkgs/main/linux-64/urllib3-1.26.18-py310h06a4308_0.conda 280 | https://repo.anaconda.com/pkgs/main/linux-64/arrow-cpp-14.0.2-h374c478_1.conda 281 | https://conda.anaconda.org/conda-forge/noarch/jupyter_events-0.10.0-pyhd8ed1ab_0.conda 282 | https://conda.anaconda.org/conda-forge/noarch/nbclient-0.10.0-pyhd8ed1ab_0.conda 283 | https://repo.anaconda.com/pkgs/main/linux-64/qt-main-5.15.2-h53bd1ea_10.conda 284 | https://repo.anaconda.com/pkgs/main/linux-64/requests-2.31.0-py310h06a4308_0.conda 285 | https://repo.anaconda.com/pkgs/main/linux-64/huggingface_hub-0.20.3-py310h06a4308_0.conda 286 | https://conda.anaconda.org/conda-forge/noarch/nbconvert-core-7.16.3-pyhd8ed1ab_0.conda 287 | https://repo.anaconda.com/pkgs/main/linux-64/pyqt-5.15.10-py310h6a678d5_0.conda 288 | https://repo.anaconda.com/pkgs/main/noarch/responses-0.13.3-pyhd3eb1b0_0.conda 289 | https://conda.anaconda.org/conda-forge/noarch/jupyter_server-2.13.0-pyhd8ed1ab_0.conda 290 | https://repo.anaconda.com/pkgs/main/linux-64/tokenizers-0.15.1-py310h22610ee_0.conda 291 | https://conda.anaconda.org/conda-forge/noarch/jupyter-lsp-2.2.4-pyhd8ed1ab_0.conda 292 | https://conda.anaconda.org/conda-forge/noarch/jupyterlab_server-2.25.4-pyhd8ed1ab_0.conda 293 | https://conda.anaconda.org/conda-forge/noarch/notebook-shim-0.2.4-pyhd8ed1ab_0.conda 294 | https://conda.anaconda.org/conda-forge/noarch/jupyterlab-4.1.5-pyhd8ed1ab_0.conda 295 | https://conda.anaconda.org/conda-forge/noarch/notebook-7.1.2-pyhd8ed1ab_0.conda 296 | https://conda.anaconda.org/conda-forge/linux-64/ml_dtypes-0.4.0-py310hcc13569_0.conda 297 | https://repo.anaconda.com/pkgs/main/linux-64/bottleneck-1.3.7-py310ha9d4c09_0.conda 298 | https://repo.anaconda.com/pkgs/main/linux-64/contourpy-1.2.0-py310hdb19cb5_0.conda 299 | https://repo.anaconda.com/pkgs/main/linux-64/matplotlib-3.8.0-py310h06a4308_0.conda 300 | https://repo.anaconda.com/pkgs/main/linux-64/matplotlib-base-3.8.0-py310h1128e8f_0.conda 301 | https://repo.anaconda.com/pkgs/main/linux-64/mkl_fft-1.3.8-py310h5eee18b_0.conda 302 | https://repo.anaconda.com/pkgs/main/linux-64/mkl_random-1.2.4-py310hdb19cb5_0.conda 303 | https://repo.anaconda.com/pkgs/main/linux-64/numpy-1.26.4-py310h5f9d8c6_0.conda 304 | https://repo.anaconda.com/pkgs/main/linux-64/numexpr-2.8.7-py310h85018f9_0.conda 305 | https://repo.anaconda.com/pkgs/main/linux-64/pyarrow-14.0.2-py310h1eedbd7_0.conda 306 | https://repo.anaconda.com/pkgs/main/linux-64/scipy-1.11.4-py310h5f9d8c6_0.conda 307 | https://repo.anaconda.com/pkgs/main/linux-64/transformers-4.37.2-py310h06a4308_0.conda 308 | https://repo.anaconda.com/pkgs/main/linux-64/pandas-2.2.1-py310h6a678d5_0.conda 309 | 
https://conda.anaconda.org/anaconda/linux-64/scikit-learn-1.3.0-py310h1128e8f_0.tar.bz2 310 | https://repo.anaconda.com/pkgs/main/linux-64/datasets-2.12.0-py310h06a4308_0.conda 311 | https://conda.anaconda.org/conda-forge/noarch/accelerate-0.27.0-pyhd8ed1ab_0.conda 312 | https://conda.anaconda.org/pytorch/linux-64/pytorch-2.1.1-py3.10_cuda11.8_cudnn8.7.0_0.tar.bz2 313 | https://conda.anaconda.org/pytorch/linux-64/torchaudio-2.1.1-py310_cu118.tar.bz2 314 | https://conda.anaconda.org/pytorch/linux-64/torchtriton-2.1.0-py310.tar.bz2 315 | https://conda.anaconda.org/pytorch/linux-64/torchvision-0.16.1-py310_cu118.tar.bz2 316 | -------------------------------------------------------------------------------- /grok/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "architectures": [ 3 | "Grok1ModelForCausalLM" 4 | ], 5 | "auto_map": { 6 | "AutoConfig": "configuration_grok1.Grok1Config", 7 | "AutoModel": "modeling_grok1.Grok1Model", 8 | "AutoModelForCausalLM": "modeling_grok1.Grok1ModelForCausalLM" 9 | }, 10 | "vocab_size": 131072, 11 | "hidden_size": 6144, 12 | "intermediate_size": 32768, 13 | "num_hidden_layers": 64, 14 | "num_attention_heads": 48, 15 | "num_key_value_heads": 8, 16 | "attn_output_multiplier": 0.08838834764831845, 17 | "embedding_multiplier_scale": 78.38367176906169, 18 | "output_multiplier_scale": 0.5773502691896257, 19 | "max_attn_value": 30.0, 20 | "max_position_embeddings": 8192, 21 | "rms_norm_eps": 1e-5, 22 | "use_cache": true, 23 | "pad_token_id": 0, 24 | "bos_token_id": 1, 25 | "eos_token_id": 2, 26 | "tie_word_embeddings": true, 27 | "num_experts_per_tok": 2, 28 | "num_experts": 8, 29 | "output_router_logits": false, 30 | "router_aux_loss_coef": 0.001, 31 | "torch_dtype": "bfloat16", 32 | "transformers_version": "4.35.0" 33 | } -------------------------------------------------------------------------------- /grok/configuration_grok1.py: -------------------------------------------------------------------------------- 1 | from transformers.configuration_utils import PretrainedConfig 2 | 3 | 4 | class Grok1Config(PretrainedConfig): 5 | model_type = "grok-1" 6 | keys_to_ignore_at_inference = ["past_key_values"] 7 | 8 | def __init__( 9 | self, 10 | vocab_size=32000, 11 | hidden_size=4096, 12 | intermediate_size=32768, 13 | num_hidden_layers=32, 14 | num_attention_heads=32, 15 | num_key_value_heads=32, 16 | attn_output_multiplier=1.0, 17 | max_attn_value=1.0, 18 | max_position_embeddings=4096, 19 | embedding_multiplier_scale: float = 1.0, 20 | output_multiplier_scale: float = 1.0, 21 | rms_norm_eps=1e-5, 22 | use_cache=True, 23 | pad_token_id=None, 24 | bos_token_id=1, 25 | eos_token_id=2, 26 | tie_word_embeddings=True, 27 | num_experts_per_tok=2, 28 | num_experts=8, 29 | output_router_logits=False, 30 | router_aux_loss_coef=0.001, 31 | **kwargs 32 | ): 33 | self.vocab_size = vocab_size 34 | self.attn_output_multiplier = attn_output_multiplier 35 | self.max_attn_value = max_attn_value 36 | self.max_position_embeddings = max_position_embeddings 37 | self.embedding_multiplier_scale = embedding_multiplier_scale 38 | self.output_multiplier_scale = output_multiplier_scale 39 | self.hidden_size = hidden_size 40 | self.intermediate_size = intermediate_size 41 | self.num_hidden_layers = num_hidden_layers 42 | self.num_attention_heads = num_attention_heads 43 | 44 | # for backward compatibility 45 | if num_key_value_heads is None: 46 | num_key_value_heads = num_attention_heads 47 | 48 | self.num_key_value_heads = num_key_value_heads 
49 | self.rms_norm_eps = rms_norm_eps 50 | self.use_cache = use_cache 51 | 52 | self.num_experts_per_tok = num_experts_per_tok 53 | self.num_experts = num_experts 54 | self.output_router_logits = output_router_logits 55 | self.router_aux_loss_coef = router_aux_loss_coef 56 | super().__init__( 57 | pad_token_id=pad_token_id, 58 | bos_token_id=bos_token_id, 59 | eos_token_id=eos_token_id, 60 | tie_word_embeddings=tie_word_embeddings, 61 | **kwargs, 62 | ) 63 | -------------------------------------------------------------------------------- /grok/modeling_grok1.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional, Tuple, Union 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from transformers.modeling_utils import PreTrainedModel 7 | from transformers.utils import logging 8 | 9 | try: 10 | from transformers.modeling_attn_mask_utils import \ 11 | _prepare_4d_causal_attention_mask 12 | 13 | HAS_MASK_UTILS = True 14 | except ImportError: 15 | HAS_MASK_UTILS = False 16 | 17 | from .configuration_grok1 import Grok1Config 18 | from .modeling_grok1_outputs import (MoeCausalLMOutputWithPast, 19 | MoeModelOutputWithPast) 20 | 21 | logger = logging.get_logger(__name__) 22 | 23 | 24 | # copied from https://github.com/huggingface/transformers/blob/v4.36.1/src/transformers/models/mixtral/modeling_mixtral.py 25 | def load_balancing_loss_func( 26 | gate_logits: torch.Tensor, num_experts: torch.Tensor = None, top_k=2 27 | ) -> float: 28 | r""" 29 | Computes auxiliary load balancing loss as in Switch Transformer - implemented in Pytorch. 30 | 31 | See Switch Transformer (https://arxiv.org/abs/2101.03961) for more details. This function implements the loss 32 | function presented in equations (4) - (6) of the paper. It aims at penalizing cases where the routing between 33 | experts is too unbalanced. 34 | 35 | Args: 36 | gate_logits (Union[`torch.Tensor`, Tuple[torch.Tensor]): 37 | Logits from the `gate`, should be a tuple of tensors. Shape: [batch_size, seqeunce_length, num_experts]. 38 | num_experts (`int`, *optional*): 39 | Number of experts 40 | 41 | Returns: 42 | The auxiliary loss. 43 | """ 44 | if gate_logits is None: 45 | return 0 46 | 47 | if isinstance(gate_logits, tuple): 48 | # cat along the layers? 49 | compute_device = gate_logits[0].device 50 | gate_logits = torch.cat( 51 | [gate.to(compute_device) for gate in gate_logits], dim=0 52 | ) 53 | 54 | routing_weights, selected_experts = torch.topk(gate_logits, top_k, dim=-1) 55 | routing_weights = routing_weights.softmax(dim=-1) 56 | 57 | # cast the expert indices to int64, otherwise one-hot encoding will fail 58 | if selected_experts.dtype != torch.int64: 59 | selected_experts = selected_experts.to(torch.int64) 60 | 61 | if len(selected_experts.shape) == 2: 62 | selected_experts = selected_experts.unsqueeze(2) 63 | 64 | expert_mask = torch.nn.functional.one_hot(selected_experts, num_experts) 65 | 66 | # For a given token, determine if it was routed to a given expert. 
67 | expert_mask = torch.max(expert_mask, axis=-2).values 68 | 69 | # cast to float32 otherwise mean will fail 70 | expert_mask = expert_mask.to(torch.float32) 71 | tokens_per_group_and_expert = torch.mean(expert_mask, axis=-2) 72 | 73 | router_prob_per_group_and_expert = torch.mean(routing_weights, axis=-1) 74 | return torch.mean( 75 | tokens_per_group_and_expert * router_prob_per_group_and_expert.unsqueeze(-1) 76 | ) * (num_experts**2) 77 | 78 | 79 | # Copied from transformers.models.llama.modeling_llama.repeat_kv 80 | def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: 81 | """ 82 | This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, 83 | num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) 84 | """ 85 | batch, num_key_value_heads, slen, head_dim = hidden_states.shape 86 | if n_rep == 1: 87 | return hidden_states 88 | hidden_states = hidden_states[:, :, None, :, :].expand( 89 | batch, num_key_value_heads, n_rep, slen, head_dim 90 | ) 91 | return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) 92 | 93 | 94 | class RMSNorm(nn.Module): 95 | def __init__( 96 | self, 97 | hidden_size: int, 98 | eps: float = 1e-5, 99 | create_scale: bool = True, 100 | ) -> None: 101 | super().__init__() 102 | self.variance_epsilon = eps 103 | if create_scale: 104 | self.scale = nn.Parameter(torch.zeros(hidden_size)) 105 | else: 106 | self.scale = 1.0 107 | 108 | def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: 109 | input_dtype = hidden_states.dtype 110 | hidden_states = hidden_states.to(torch.float32) 111 | variance = hidden_states.pow(2).mean(-1, keepdim=True) 112 | hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) 113 | hidden_states = self.scale * hidden_states 114 | return hidden_states.to(input_dtype) 115 | 116 | 117 | class RotaryEmbedding(nn.Module): 118 | def __init__( 119 | self, dim: int, max_position_embeddings: int = 2048, base: int = 10000 120 | ) -> None: 121 | super().__init__() 122 | assert dim % 2 == 0 123 | self.dim = dim 124 | self.max_position_embeddings = max_position_embeddings 125 | self.base = base 126 | inv_freq = 1.0 / ( 127 | self.base ** (torch.arange(0, self.dim, 2).float() / self.dim) 128 | ) 129 | self.register_buffer("inv_freq", inv_freq, persistent=False) 130 | 131 | self._set_cos_sin_cache( 132 | seq_len=max_position_embeddings, 133 | device=self.inv_freq.device, 134 | dtype=torch.get_default_dtype(), 135 | ) 136 | 137 | def _set_cos_sin_cache(self, seq_len, device, dtype): 138 | self.max_seq_len_cached = seq_len 139 | t = torch.arange( 140 | self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype 141 | ) 142 | 143 | freqs = torch.outer(t, self.inv_freq) 144 | # Different from paper, but it uses a different permutation in order to obtain the same calculation 145 | emb = torch.cat((freqs, freqs), dim=-1) 146 | self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) 147 | self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) 148 | 149 | def forward(self, x, seq_len=None): 150 | # x: [bs, num_attention_heads, seq_len, head_size] 151 | if seq_len > self.max_seq_len_cached: 152 | self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype) 153 | 154 | return ( 155 | self.cos_cached[:seq_len].to(dtype=x.dtype), 156 | self.sin_cached[:seq_len].to(dtype=x.dtype), 157 | ) 158 | 159 | 160 | # Copied from 
transformers.models.llama.modeling_llama.rotate_half 161 | def rotate_half(x): 162 | """Rotates half the hidden dims of the input.""" 163 | x1 = x[..., : x.shape[-1] // 2] 164 | x2 = x[..., x.shape[-1] // 2 :] 165 | return torch.cat((-x2, x1), dim=-1) 166 | 167 | 168 | # Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb 169 | def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): 170 | """Applies Rotary Position Embedding to the query and key tensors. 171 | 172 | Args: 173 | q (`torch.Tensor`): The query tensor. 174 | k (`torch.Tensor`): The key tensor. 175 | cos (`torch.Tensor`): The cosine part of the rotary embedding. 176 | sin (`torch.Tensor`): The sine part of the rotary embedding. 177 | position_ids (`torch.Tensor`): 178 | The position indices of the tokens corresponding to the query and key tensors. For example, this can be 179 | used to pass offsetted position ids when working with a KV-cache. 180 | unsqueeze_dim (`int`, *optional*, defaults to 1): 181 | The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and 182 | sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note 183 | that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and 184 | k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes 185 | cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have 186 | the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. 187 | Returns: 188 | `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. 189 | """ 190 | cos = cos[position_ids].unsqueeze(unsqueeze_dim) 191 | sin = sin[position_ids].unsqueeze(unsqueeze_dim) 192 | q_embed = (q * cos) + (rotate_half(q) * sin) 193 | k_embed = (k * cos) + (rotate_half(k) * sin) 194 | return q_embed, k_embed 195 | 196 | 197 | class MultiHeadAttention(nn.Module): 198 | def __init__( 199 | self, 200 | hidden_size: int, 201 | num_heads: int, 202 | num_key_value_heads: Optional[int] = None, 203 | max_position_embeddings: int = 2048, 204 | attn_output_multiplier: float = 1.0, 205 | max_attn_val: float = 30.0, 206 | ): 207 | super().__init__() 208 | self.hidden_size = hidden_size 209 | self.num_heads = num_heads 210 | self.head_dim = hidden_size // num_heads 211 | if num_key_value_heads is None: 212 | num_key_value_heads = num_heads 213 | self.num_key_value_heads = num_key_value_heads 214 | self.num_key_value_groups = self.num_heads // self.num_key_value_heads 215 | self.attn_output_multiplier = attn_output_multiplier 216 | self.max_attn_val = max_attn_val 217 | 218 | if (self.head_dim * self.num_heads) != self.hidden_size: 219 | raise ValueError( 220 | f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" 221 | f" and `num_heads`: {self.num_heads})." 
222 | ) 223 | 224 | self.q_proj = nn.Linear(hidden_size, self.num_heads * self.head_dim, bias=False) 225 | self.k_proj = nn.Linear( 226 | hidden_size, self.num_key_value_heads * self.head_dim, bias=False 227 | ) 228 | self.v_proj = nn.Linear( 229 | hidden_size, self.num_key_value_heads * self.head_dim, bias=False 230 | ) 231 | self.o_proj = nn.Linear(self.num_heads * self.head_dim, hidden_size, bias=False) 232 | 233 | self.rotary_emb = RotaryEmbedding( 234 | self.head_dim, 235 | max_position_embeddings=max_position_embeddings, 236 | ) 237 | 238 | def forward( 239 | self, 240 | hidden_states: torch.Tensor, 241 | attention_mask: Optional[torch.Tensor] = None, 242 | position_ids: Optional[torch.LongTensor] = None, 243 | past_key_value: Optional[Tuple[torch.Tensor]] = None, 244 | output_attentions: bool = False, 245 | use_cache: bool = False, 246 | **kwargs, 247 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: 248 | bsz, q_len, _ = hidden_states.size() 249 | 250 | query_states = self.q_proj(hidden_states) 251 | key_states = self.k_proj(hidden_states) 252 | value_states = self.v_proj(hidden_states) 253 | 254 | query_states = query_states.view( 255 | bsz, q_len, self.num_heads, self.head_dim 256 | ).transpose(1, 2) 257 | key_states = key_states.view( 258 | bsz, q_len, self.num_key_value_heads, self.head_dim 259 | ).transpose(1, 2) 260 | value_states = value_states.view( 261 | bsz, q_len, self.num_key_value_heads, self.head_dim 262 | ).transpose(1, 2) 263 | 264 | kv_seq_len = key_states.shape[-2] 265 | if past_key_value is not None: 266 | kv_seq_len += past_key_value[0].shape[-2] 267 | 268 | cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) 269 | query_states, key_states = apply_rotary_pos_emb( 270 | query_states, key_states, cos, sin, position_ids 271 | ) 272 | 273 | if past_key_value is not None: 274 | # reuse k, v, self_attention 275 | key_states = torch.cat([past_key_value[0], key_states], dim=2) 276 | value_states = torch.cat([past_key_value[1], value_states], dim=2) 277 | 278 | past_key_value = (key_states, value_states) if use_cache else None 279 | 280 | # repeat k/v heads if n_kv_heads < n_heads 281 | key_states = repeat_kv(key_states, self.num_key_value_groups) 282 | value_states = repeat_kv(value_states, self.num_key_value_groups) 283 | 284 | attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)).to( 285 | torch.float 286 | ) 287 | attn_weights = attn_weights * self.attn_output_multiplier 288 | attn_weights = self.max_attn_val * F.tanh(attn_weights / self.max_attn_val) 289 | 290 | if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): 291 | raise ValueError( 292 | f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is" 293 | f" {attn_weights.size()}" 294 | ) 295 | 296 | if attention_mask is not None: 297 | if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): 298 | raise ValueError( 299 | f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" 300 | ) 301 | 302 | attn_weights = attn_weights + attention_mask 303 | 304 | attn_weights = F.softmax(attn_weights, dim=-1).to(query_states.dtype) 305 | attn_output = torch.matmul(attn_weights, value_states) 306 | 307 | if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): 308 | raise ValueError( 309 | f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" 310 | f" {attn_output.size()}" 311 | ) 312 | 313 | attn_output = attn_output.transpose(1, 
2).contiguous() 314 | attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) 315 | 316 | attn_output = self.o_proj(attn_output) 317 | 318 | if not output_attentions: 319 | attn_weights = None 320 | 321 | return attn_output, attn_weights, past_key_value 322 | 323 | 324 | class MoeMLP(nn.Module): 325 | def __init__( 326 | self, 327 | hidden_dim: int, 328 | ffn_dim: int, 329 | ) -> None: 330 | super().__init__() 331 | self.linear_v = nn.Linear(hidden_dim, ffn_dim, bias=False) 332 | self.linear_1 = nn.Linear(ffn_dim, hidden_dim, bias=False) 333 | self.linear = nn.Linear(hidden_dim, ffn_dim, bias=False) 334 | self.act_fn = nn.GELU() 335 | 336 | def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: 337 | current_hidden_states = self.act_fn(self.linear(hidden_states)) * self.linear_v( 338 | hidden_states 339 | ) 340 | current_hidden_states = self.linear_1(current_hidden_states) 341 | return current_hidden_states 342 | 343 | 344 | class MoeBlock(nn.Module): 345 | def __init__( 346 | self, 347 | hidden_dim: int, 348 | ffn_dim: int, 349 | num_experts: int, 350 | top_k: int, 351 | ) -> None: 352 | super().__init__() 353 | self.num_experts = num_experts 354 | self.top_k = top_k 355 | self.gate = nn.Linear(hidden_dim, num_experts, bias=False) 356 | self.experts = nn.ModuleList( 357 | [MoeMLP(hidden_dim, ffn_dim) for _ in range(num_experts)] 358 | ) 359 | 360 | def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor]: 361 | batch_size, sequence_length, hidden_dim = hidden_states.shape 362 | hidden_states = hidden_states.view(-1, hidden_dim) 363 | # router_logits: (batch * sequence_length, n_experts) 364 | router_logits = self.gate(hidden_states) 365 | 366 | routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float) 367 | routing_weights, selected_experts = torch.topk( 368 | routing_weights, self.top_k, dim=-1 369 | ) 370 | # we cast back to the input dtype 371 | routing_weights = routing_weights.to(hidden_states.dtype) 372 | 373 | final_hidden_states = torch.zeros( 374 | (batch_size * sequence_length, hidden_dim), 375 | dtype=hidden_states.dtype, 376 | device=hidden_states.device, 377 | ) 378 | # One hot encode the selected experts to create an expert mask 379 | # this will be used to easily index which expert is going to be sollicitated 380 | expert_mask = torch.nn.functional.one_hot( 381 | selected_experts, num_classes=self.num_experts 382 | ).permute(2, 1, 0) 383 | 384 | # Loop over all available experts in the model and perform the computation on each expert 385 | for expert_idx in range(self.num_experts): 386 | expert_layer = self.experts[expert_idx] 387 | idx, top_x = torch.where(expert_mask[expert_idx]) 388 | 389 | if top_x.shape[0] == 0: 390 | continue 391 | 392 | # in torch it is faster to index using lists than torch tensors 393 | top_x_list = top_x.tolist() 394 | idx_list = idx.tolist() 395 | 396 | # Index the correct hidden states and compute the expert hidden state for 397 | # the current expert. We need to make sure to multiply the output hidden 398 | # states by `routing_weights` on the corresponding tokens (top-1 and top-2) 399 | current_state = hidden_states[None, top_x_list].reshape(-1, hidden_dim) 400 | current_hidden_states = ( 401 | expert_layer(current_state) 402 | * routing_weights[top_x_list, idx_list, None] 403 | ) 404 | 405 | # However `index_add_` only support torch tensors for indexing so we'll use 406 | # the `top_x` tensor here. 
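# At this point `current_hidden_states` has shape (num_selected_tokens, hidden_dim): each row is the
# output of this expert for one routed token, already scaled by that token's gate score
# (`routing_weights[top_x_list, idx_list]` holds one top-k probability per selected token).
# `index_add_` then scatters these rows back into `final_hidden_states` at the token positions
# given by `top_x`, accumulating contributions whenever a token is processed by more than one of
# its top-k experts.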
407 | final_hidden_states.index_add_( 408 | 0, top_x, current_hidden_states.to(hidden_states.dtype) 409 | ) 410 | final_hidden_states = final_hidden_states.reshape( 411 | batch_size, sequence_length, hidden_dim 412 | ) 413 | return final_hidden_states, router_logits 414 | 415 | 416 | class DecoderLayer(nn.Module): 417 | def __init__( 418 | self, 419 | hidden_size: int, 420 | intermediate_size: int, 421 | num_heads: int, 422 | num_key_value_heads: int, 423 | num_experts: int, 424 | top_k: int, 425 | max_position_embeddings: int = 2048, 426 | attn_output_multiplier: float = 1.0, 427 | max_attn_val: float = 30.0, 428 | rms_norm_eps: float = 1e-5, 429 | ) -> None: 430 | super().__init__() 431 | self.attn = MultiHeadAttention( 432 | hidden_size, 433 | num_heads, 434 | num_key_value_heads, 435 | max_position_embeddings=max_position_embeddings, 436 | attn_output_multiplier=attn_output_multiplier, 437 | max_attn_val=max_attn_val, 438 | ) 439 | self.moe_block = MoeBlock(hidden_size, intermediate_size, num_experts, top_k) 440 | self.pre_attn_norm = RMSNorm(hidden_size, eps=rms_norm_eps) 441 | self.post_attn_norm = RMSNorm(hidden_size, eps=rms_norm_eps) 442 | self.pre_moe_norm = RMSNorm(hidden_size, eps=rms_norm_eps) 443 | self.post_moe_norm = RMSNorm(hidden_size, eps=rms_norm_eps) 444 | 445 | def forward( 446 | self, 447 | hidden_states: torch.Tensor, 448 | attention_mask: Optional[torch.Tensor] = None, 449 | position_ids: Optional[torch.LongTensor] = None, 450 | past_key_value: Optional[Tuple[torch.Tensor]] = None, 451 | output_attentions: Optional[bool] = False, 452 | output_router_logits: Optional[bool] = False, 453 | use_cache: Optional[bool] = False, 454 | **kwargs, 455 | ) -> Tuple[ 456 | torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]] 457 | ]: 458 | residual = hidden_states 459 | hidden_states = self.pre_attn_norm(hidden_states) 460 | hidden_states, attention_weights, present_key_value = self.attn( 461 | hidden_states, 462 | attention_mask=attention_mask, 463 | position_ids=position_ids, 464 | past_key_value=past_key_value, 465 | output_attentions=output_attentions, 466 | use_cache=use_cache, 467 | ) 468 | hidden_states = self.post_attn_norm(hidden_states) 469 | hidden_states = residual + hidden_states 470 | 471 | residual = hidden_states 472 | hidden_states = self.pre_moe_norm(hidden_states) 473 | hidden_states, router_logits = self.moe_block(hidden_states) 474 | hidden_states = self.post_moe_norm(hidden_states) 475 | hidden_states = residual + hidden_states 476 | 477 | outputs = (hidden_states,) 478 | if output_attentions: 479 | outputs += (attention_weights,) 480 | if use_cache: 481 | outputs += (present_key_value,) 482 | if output_router_logits: 483 | outputs += (router_logits,) 484 | return outputs 485 | 486 | 487 | class Grok1PretrainedModel(PreTrainedModel): 488 | config_class = Grok1Config 489 | base_model_prefix = "model" 490 | supports_gradient_checkpointing = True 491 | _no_split_modules = ["DecoderLayer"] 492 | _skip_keys_device_placement = "past_key_values" 493 | _supports_flash_attn_2 = False 494 | _supports_cache_class = False 495 | 496 | def _init_weights(self, module) -> None: 497 | if isinstance(module, nn.Linear): 498 | module.weight.data.zero_() 499 | if module.bias is not None: 500 | module.bias.data.zero_() 501 | elif isinstance(module, nn.Embedding): 502 | module.weight.data.zero_() 503 | 504 | 505 | # Copied from transformers.models.bart.modeling_bart._make_causal_mask 506 | def _make_causal_mask( 507 | input_ids_shape: torch.Size, 508 | dtype: 
torch.dtype, 509 | device: torch.device, 510 | past_key_values_length: int = 0, 511 | ): 512 | """ 513 | Make causal mask used for bi-directional self-attention. 514 | """ 515 | bsz, tgt_len = input_ids_shape 516 | mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device) 517 | mask_cond = torch.arange(mask.size(-1), device=device) 518 | mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0) 519 | mask = mask.to(dtype) 520 | 521 | if past_key_values_length > 0: 522 | mask = torch.cat( 523 | [ 524 | torch.zeros( 525 | tgt_len, past_key_values_length, dtype=dtype, device=device 526 | ), 527 | mask, 528 | ], 529 | dim=-1, 530 | ) 531 | return mask[None, None, :, :].expand( 532 | bsz, 1, tgt_len, tgt_len + past_key_values_length 533 | ) 534 | 535 | 536 | # Copied from transformers.models.bart.modeling_bart._expand_mask 537 | def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): 538 | """ 539 | Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. 540 | """ 541 | bsz, src_len = mask.size() 542 | tgt_len = tgt_len if tgt_len is not None else src_len 543 | 544 | expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype) 545 | 546 | inverted_mask = 1.0 - expanded_mask 547 | 548 | return inverted_mask.masked_fill( 549 | inverted_mask.to(torch.bool), torch.finfo(dtype).min 550 | ) 551 | 552 | 553 | class Grok1Model(Grok1PretrainedModel): 554 | def __init__(self, config: Grok1Config, **kwargs) -> None: 555 | super().__init__(config) 556 | self.padding_idx = config.pad_token_id 557 | self.vocab_size = config.vocab_size 558 | self.embedding_multiplier_scale = config.embedding_multiplier_scale 559 | 560 | self.embed_tokens = nn.Embedding( 561 | config.vocab_size, config.hidden_size, self.padding_idx 562 | ) 563 | self.layers = nn.ModuleList( 564 | [ 565 | DecoderLayer( 566 | hidden_size=config.hidden_size, 567 | intermediate_size=config.intermediate_size, 568 | num_heads=config.num_attention_heads, 569 | num_key_value_heads=config.num_key_value_heads, 570 | num_experts=config.num_experts, 571 | top_k=config.num_experts_per_tok, 572 | max_position_embeddings=config.max_position_embeddings, 573 | attn_output_multiplier=config.attn_output_multiplier, 574 | max_attn_val=config.max_attn_value, 575 | rms_norm_eps=config.rms_norm_eps, 576 | ) 577 | for layer_idx in range(config.num_hidden_layers) 578 | # for layer_idx in range(1, config.num_hidden_layers) 579 | ] 580 | ) 581 | # self.layers.insert(0, None) 582 | self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) 583 | self.gradient_checkpointing = False 584 | self.post_init() 585 | 586 | def get_input_embeddings(self): 587 | return self.embed_tokens 588 | 589 | def set_input_embeddings(self, value): 590 | self.embed_tokens = value 591 | 592 | # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask 593 | def _prepare_decoder_attention_mask( 594 | self, attention_mask, input_shape, inputs_embeds, past_key_values_length 595 | ): 596 | # create causal mask 597 | # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] 598 | combined_attention_mask = None 599 | if input_shape[-1] > 1: 600 | combined_attention_mask = _make_causal_mask( 601 | input_shape, 602 | inputs_embeds.dtype, 603 | device=inputs_embeds.device, 604 | past_key_values_length=past_key_values_length, 605 | ) 606 | 607 | if attention_mask is not None: 608 | # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] 609 | 
expanded_attn_mask = _expand_mask( 610 | attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1] 611 | ).to(inputs_embeds.device) 612 | combined_attention_mask = ( 613 | expanded_attn_mask 614 | if combined_attention_mask is None 615 | else expanded_attn_mask + combined_attention_mask 616 | ) 617 | 618 | return combined_attention_mask 619 | 620 | def forward( 621 | self, 622 | input_ids: torch.LongTensor = None, 623 | attention_mask: Optional[torch.Tensor] = None, 624 | position_ids: Optional[torch.LongTensor] = None, 625 | past_key_values: Optional[List[torch.FloatTensor]] = None, 626 | inputs_embeds: Optional[torch.FloatTensor] = None, 627 | use_cache: Optional[bool] = None, 628 | output_attentions: Optional[bool] = None, 629 | output_hidden_states: Optional[bool] = None, 630 | output_router_logits: Optional[bool] = None, 631 | return_dict: Optional[bool] = None, 632 | decoder_layer_idx: Optional[int] = None, 633 | ) -> Union[Tuple, MoeModelOutputWithPast]: 634 | output_attentions = ( 635 | output_attentions 636 | if output_attentions is not None 637 | else self.config.output_attentions 638 | ) 639 | output_hidden_states = ( 640 | output_hidden_states 641 | if output_hidden_states is not None 642 | else self.config.output_hidden_states 643 | ) 644 | use_cache = use_cache if use_cache is not None else self.config.use_cache 645 | 646 | return_dict = ( 647 | return_dict if return_dict is not None else self.config.use_return_dict 648 | ) 649 | 650 | # retrieve input_ids and inputs_embeds 651 | if input_ids is not None and inputs_embeds is not None: 652 | raise ValueError( 653 | "You cannot specify both input_ids and inputs_embeds at the same time" 654 | ) 655 | elif input_ids is not None: 656 | batch_size, seq_length = input_ids.shape[:2] 657 | elif inputs_embeds is not None: 658 | batch_size, seq_length = inputs_embeds.shape[:2] 659 | else: 660 | raise ValueError("You have to specify either input_ids or inputs_embeds") 661 | 662 | seq_length_with_past = seq_length 663 | past_key_values_length = 0 664 | if past_key_values is not None: 665 | past_key_values_length = past_key_values[0][0].shape[2] 666 | seq_length_with_past = seq_length_with_past + past_key_values_length 667 | 668 | if position_ids is None: 669 | device = input_ids.device if input_ids is not None else inputs_embeds.device 670 | position_ids = torch.arange( 671 | past_key_values_length, 672 | seq_length + past_key_values_length, 673 | dtype=torch.long, 674 | device=device, 675 | ) 676 | position_ids = position_ids.unsqueeze(0) 677 | 678 | if inputs_embeds is None: 679 | inputs_embeds = self.embed_tokens(input_ids) 680 | inputs_embeds = inputs_embeds * self.embedding_multiplier_scale 681 | 682 | if HAS_MASK_UTILS: 683 | # 4d mask is passed through the layers 684 | attention_mask = _prepare_4d_causal_attention_mask( 685 | attention_mask, 686 | (batch_size, seq_length), 687 | inputs_embeds, 688 | past_key_values_length, 689 | ) 690 | else: 691 | if attention_mask is None: 692 | attention_mask = torch.ones( 693 | (batch_size, seq_length_with_past), 694 | dtype=torch.bool, 695 | device=inputs_embeds.device, 696 | ) 697 | attention_mask = self._prepare_decoder_attention_mask( 698 | attention_mask, 699 | (batch_size, seq_length), 700 | inputs_embeds, 701 | past_key_values_length, 702 | ) 703 | 704 | # embed positions 705 | hidden_states = inputs_embeds 706 | 707 | if self.gradient_checkpointing and self.training: 708 | if use_cache: 709 | logger.warning_once( 710 | "`use_cache=True` is incompatible with gradient checkpointing. 
Setting `use_cache=False`..." 711 | ) 712 | use_cache = False 713 | 714 | # decoder layers 715 | all_hidden_states = () if output_hidden_states else None 716 | all_self_attns = () if output_attentions else None 717 | all_router_logits = () if output_router_logits else None 718 | next_decoder_cache = () if use_cache else None 719 | 720 | for idx, decoder_layer in enumerate(self.layers): 721 | if decoder_layer_idx is not None and idx != decoder_layer_idx: 722 | continue 723 | if output_hidden_states: 724 | all_hidden_states += (hidden_states,) 725 | 726 | past_key_value = ( 727 | past_key_values[idx] if past_key_values is not None else None 728 | ) 729 | 730 | if self.gradient_checkpointing and self.training: 731 | 732 | def create_custom_forward(module): 733 | def custom_forward(*inputs): 734 | # None for past_key_value 735 | return module(*inputs, past_key_value, output_attentions) 736 | 737 | return custom_forward 738 | 739 | layer_outputs = torch.utils.checkpoint.checkpoint( 740 | create_custom_forward(decoder_layer), 741 | hidden_states, 742 | attention_mask, 743 | position_ids, 744 | ) 745 | else: 746 | layer_outputs = decoder_layer( 747 | hidden_states, 748 | attention_mask=attention_mask, 749 | position_ids=position_ids, 750 | past_key_value=past_key_value, 751 | output_attentions=output_attentions, 752 | use_cache=use_cache, 753 | ) 754 | 755 | hidden_states = layer_outputs[0] 756 | 757 | if use_cache: 758 | next_decoder_cache += (layer_outputs[2 if output_attentions else 1],) 759 | 760 | if output_attentions: 761 | all_self_attns += (layer_outputs[1],) 762 | 763 | if output_router_logits: 764 | all_router_logits += (layer_outputs[-1],) 765 | 766 | hidden_states = self.norm(hidden_states) 767 | 768 | # add hidden states from the last decoder layer 769 | if output_hidden_states: 770 | all_hidden_states += (hidden_states,) 771 | next_cache = next_decoder_cache if use_cache else None 772 | 773 | if not return_dict: 774 | return tuple( 775 | v 776 | for v in [ 777 | hidden_states, 778 | next_cache, 779 | all_hidden_states, 780 | all_self_attns, 781 | all_router_logits, 782 | ] 783 | if v is not None 784 | ) 785 | return MoeModelOutputWithPast( 786 | last_hidden_state=hidden_states, 787 | past_key_values=next_cache, 788 | hidden_states=all_hidden_states, 789 | attentions=all_self_attns, 790 | router_logits=all_router_logits, 791 | ) 792 | 793 | 794 | class Grok1ModelForCausalLM(Grok1PretrainedModel): 795 | _tied_weights_keys = ["lm_head.weight"] 796 | 797 | def __init__(self, config: Grok1Config, **kwargs): 798 | super().__init__(config) 799 | self.model = Grok1Model(config) 800 | self.vocab_size = config.vocab_size 801 | self.output_multiplier_scale = config.output_multiplier_scale 802 | self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) 803 | self.router_aux_loss_coef = config.router_aux_loss_coef 804 | self.num_experts = config.num_experts 805 | self.num_experts_per_tok = config.num_experts_per_tok 806 | self.post_init() 807 | 808 | def get_input_embeddings(self): 809 | return self.model.embed_tokens 810 | 811 | def set_input_embeddings(self, value): 812 | self.model.embed_tokens = value 813 | 814 | def get_output_embeddings(self): 815 | return self.lm_head 816 | 817 | def set_output_embeddings(self, new_embeddings): 818 | self.lm_head = new_embeddings 819 | 820 | def set_decoder(self, decoder): 821 | self.model = decoder 822 | 823 | def get_decoder(self): 824 | return self.model 825 | 826 | def forward( 827 | self, 828 | input_ids: torch.LongTensor = None, 
829 | attention_mask: Optional[torch.Tensor] = None, 830 | position_ids: Optional[torch.LongTensor] = None, 831 | past_key_values: Optional[List[torch.FloatTensor]] = None, 832 | inputs_embeds: Optional[torch.FloatTensor] = None, 833 | labels: Optional[torch.LongTensor] = None, 834 | use_cache: Optional[bool] = None, 835 | output_attentions: Optional[bool] = None, 836 | output_hidden_states: Optional[bool] = None, 837 | output_router_logits: Optional[bool] = None, 838 | return_dict: Optional[bool] = None, 839 | decoder_layer_idx: Optional[int] = None, 840 | ) -> Union[Tuple, MoeCausalLMOutputWithPast]: 841 | output_attentions = ( 842 | output_attentions 843 | if output_attentions is not None 844 | else self.config.output_attentions 845 | ) 846 | output_router_logits = ( 847 | output_router_logits 848 | if output_router_logits is not None 849 | else self.config.output_router_logits 850 | ) 851 | 852 | output_hidden_states = ( 853 | output_hidden_states 854 | if output_hidden_states is not None 855 | else self.config.output_hidden_states 856 | ) 857 | return_dict = ( 858 | return_dict if return_dict is not None else self.config.use_return_dict 859 | ) 860 | 861 | # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) 862 | outputs = self.model( 863 | input_ids=input_ids, 864 | attention_mask=attention_mask, 865 | position_ids=position_ids, 866 | past_key_values=past_key_values, 867 | inputs_embeds=inputs_embeds, 868 | use_cache=use_cache, 869 | output_attentions=output_attentions, 870 | output_hidden_states=output_hidden_states, 871 | output_router_logits=output_router_logits, 872 | return_dict=return_dict, 873 | decoder_layer_idx=decoder_layer_idx, 874 | ) 875 | 876 | hidden_states = outputs[0] 877 | logits = self.lm_head(hidden_states) 878 | logits = logits * self.output_multiplier_scale 879 | logits = logits.float() 880 | 881 | loss = None 882 | if labels is not None: 883 | # Shift so that tokens < n predict n 884 | shift_logits = logits[..., :-1, :].contiguous() 885 | shift_labels = labels[..., 1:].contiguous() 886 | # Flatten the tokens 887 | loss_fct = nn.CrossEntropyLoss() 888 | shift_logits = shift_logits.view(-1, self.config.vocab_size) 889 | shift_labels = shift_labels.view(-1) 890 | # Enable model parallelism 891 | shift_labels = shift_labels.to(shift_logits.device) 892 | loss = loss_fct(shift_logits, shift_labels) 893 | 894 | aux_loss = None 895 | if output_router_logits: 896 | aux_loss = load_balancing_loss_func( 897 | outputs.router_logits if return_dict else outputs[-1], 898 | self.num_experts, 899 | self.num_experts_per_tok, 900 | ) 901 | if labels is not None: 902 | loss += self.router_aux_loss_coef * aux_loss 903 | 904 | if not return_dict: 905 | output = (logits,) + outputs[1:] 906 | if output_router_logits: 907 | output = (aux_loss,) + output 908 | return (loss,) + output if loss is not None else output 909 | 910 | return MoeCausalLMOutputWithPast( 911 | loss=loss, 912 | aux_loss=aux_loss, 913 | logits=logits, 914 | past_key_values=outputs.past_key_values, 915 | hidden_states=outputs.hidden_states, 916 | attentions=outputs.attentions, 917 | router_logits=outputs.router_logits, 918 | ) 919 | 920 | def prepare_inputs_for_generation( 921 | self, 922 | input_ids, 923 | past_key_values=None, 924 | attention_mask=None, 925 | inputs_embeds=None, 926 | **kwargs, 927 | ): 928 | if past_key_values: 929 | input_ids = input_ids[:, -1:] 930 | 931 | position_ids = kwargs.get("position_ids", None) 932 | if attention_mask is not None and position_ids is None: 933 
| # create position_ids on the fly for batch generation 934 | position_ids = attention_mask.long().cumsum(-1) - 1 935 | position_ids.masked_fill_(attention_mask == 0, 1) 936 | if past_key_values: 937 | position_ids = position_ids[:, -1].unsqueeze(-1) 938 | 939 | # if `inputs_embeds` are passed, we only want to use them in the 1st generation step 940 | if inputs_embeds is not None and past_key_values is None: 941 | model_inputs = {"inputs_embeds": inputs_embeds} 942 | else: 943 | model_inputs = {"input_ids": input_ids} 944 | 945 | model_inputs.update( 946 | { 947 | "position_ids": position_ids, 948 | "past_key_values": past_key_values, 949 | "use_cache": kwargs.get("use_cache"), 950 | "attention_mask": attention_mask, 951 | } 952 | ) 953 | return model_inputs 954 | -------------------------------------------------------------------------------- /grok/modeling_grok1_outputs.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Optional, Tuple 3 | 4 | import torch 5 | from transformers.modeling_outputs import ModelOutput 6 | 7 | __all__ = [ 8 | "MoeModelOutputWithPast", 9 | "MoeCausalLMOutputWithPast", 10 | ] 11 | 12 | try: 13 | from transformers.modeling_outputs import ( 14 | MoeCausalLMOutputWithPast, 15 | MoeModelOutputWithPast, 16 | ) 17 | except: 18 | 19 | @dataclass 20 | class MoeModelOutputWithPast(ModelOutput): 21 | """ 22 | Base class for model's outputs, with potential hidden states and attentions. 23 | 24 | Args: 25 | last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): 26 | Sequence of hidden-states at the output of the last layer of the model. 27 | past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): 28 | Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape 29 | `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and optionally if 30 | `config.is_encoder_decoder=True` 2 additional tensors of shape `(batch_size, num_heads, 31 | encoder_sequence_length, embed_size_per_head)`. 32 | 33 | Contains pre-computed hidden-states (key and values in the self-attention blocks and optionally if 34 | `config.is_encoder_decoder=True` in the cross-attention blocks) that can be used (see `past_key_values` 35 | input) to speed up sequential decoding. 36 | hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): 37 | Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + 38 | one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. 39 | 40 | Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. 41 | attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): 42 | Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, 43 | sequence_length)`. 44 | 45 | Attentions weights after the attention softmax, used to compute the weighted average in the self-attention 46 | heads. 
47 | router_logits (`tuple(torch.FloatTensor)`, *optional*, returned when `output_router_probs=True` and `config.add_router_probs=True` is passed or when `config.output_router_probs=True`): 48 | Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, sequence_length, num_experts)`. 49 | 50 | Raw router logits (pre-softmax) that are computed by MoE routers; these terms are used to compute the auxiliary 51 | loss for Mixture of Experts models. 52 | """ 53 | 54 | last_hidden_state: torch.FloatTensor = None 55 | past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None 56 | hidden_states: Optional[Tuple[torch.FloatTensor]] = None 57 | attentions: Optional[Tuple[torch.FloatTensor]] = None 58 | router_logits: Optional[Tuple[torch.FloatTensor]] = None 59 | 60 | @dataclass 61 | class MoeCausalLMOutputWithPast(ModelOutput): 62 | """ 63 | Base class for causal language model (or autoregressive) outputs with mixture of experts. 64 | 65 | Args: 66 | loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): 67 | Language modeling loss (for next-token prediction). 68 | 69 | logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): 70 | Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). 71 | 72 | aux_loss (`torch.FloatTensor`, *optional*, returned when `labels` is provided): 73 | Auxiliary (load-balancing) loss for the sparse modules. 74 | 75 | router_logits (`tuple(torch.FloatTensor)`, *optional*, returned when `output_router_probs=True` and `config.add_router_probs=True` is passed or when `config.output_router_probs=True`): 76 | Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, sequence_length, num_experts)`. 77 | 78 | Raw router logits (pre-softmax) that are computed by MoE routers; these terms are used to compute the auxiliary 79 | loss for Mixture of Experts models. 80 | 81 | past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): 82 | Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape 83 | `(batch_size, num_heads, sequence_length, embed_size_per_head)` 84 | 85 | Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see 86 | `past_key_values` input) to speed up sequential decoding. 87 | hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): 88 | Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + 89 | one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. 90 | 91 | Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. 92 | attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): 93 | Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, 94 | sequence_length)`. 95 | 96 | Attention weights after the attention softmax, used to compute the weighted average in the self-attention 97 | heads.
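Example (an illustrative sketch only; it assumes `model` is a `Grok1ModelForCausalLM` built from this
repository's `grok` files and that `input_ids` is a `torch.LongTensor` of token ids):

```python
>>> outputs = model(input_ids, labels=input_ids, output_router_logits=True, return_dict=True)
>>> outputs.loss      # LM loss; already includes router_aux_loss_coef * aux_loss when labels are given
>>> outputs.aux_loss  # load-balancing loss computed from the per-layer router logits
>>> len(outputs.router_logits)  # one tensor per decoder layer
```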
98 | """ 99 | 100 | loss: Optional[torch.FloatTensor] = None 101 | aux_loss: Optional[torch.FloatTensor] = None 102 | logits: torch.FloatTensor = None 103 | past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None 104 | hidden_states: Optional[Tuple[torch.FloatTensor]] = None 105 | attentions: Optional[Tuple[torch.FloatTensor]] = None 106 | router_logits: Optional[Tuple[torch.FloatTensor]] = None 107 | -------------------------------------------------------------------------------- /mistral/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "architectures": [ 3 | "MistralForCausalLM" 4 | ], 5 | "bos_token_id": 1, 6 | "eos_token_id": 2, 7 | "hidden_act": "silu", 8 | "hidden_size": 4096, 9 | "initializer_range": 0.02, 10 | "intermediate_size": 14336, 11 | "max_position_embeddings": 32768, 12 | "model_type": "mistral", 13 | "num_attention_heads": 32, 14 | "num_hidden_layers": 32, 15 | "num_key_value_heads": 8, 16 | "rms_norm_eps": 1e-05, 17 | "rope_theta": 10000.0, 18 | "sliding_window": 4096, 19 | "tie_word_embeddings": false, 20 | "torch_dtype": "bfloat16", 21 | "transformers_version": "4.34.0.dev0", 22 | "use_cache": true, 23 | "vocab_size": 32000 24 | } 25 | -------------------------------------------------------------------------------- /mistral/configuration_mistral.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2023 Mistral AI and the HuggingFace Inc. team. All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """ Mistral model configuration""" 16 | 17 | from transformers.configuration_utils import PretrainedConfig 18 | from transformers.utils import logging 19 | 20 | 21 | logger = logging.get_logger(__name__) 22 | 23 | MISTRAL_PRETRAINED_CONFIG_ARCHIVE_MAP = { 24 | "mistralai/Mistral-7B-v0.1": "https://huggingface.co/mistralai/Mistral-7B-v0.1/resolve/main/config.json", 25 | "mistralai/Mistral-7B-Instruct-v0.1": "https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1/resolve/main/config.json", 26 | } 27 | 28 | 29 | class MistralConfig(PretrainedConfig): 30 | r""" 31 | This is the configuration class to store the configuration of a [`MistralModel`]. It is used to instantiate an 32 | Mistral model according to the specified arguments, defining the model architecture. Instantiating a configuration 33 | with the defaults will yield a similar configuration to that of the Mistral-7B-v0.1 or Mistral-7B-Instruct-v0.1. 34 | 35 | [mistralai/Mistral-7B-v0.1](https://huggingface.co/mistralai/Mistral-7B-v0.1) 36 | [mistralai/Mistral-7B-Instruct-v0.1](https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1) 37 | 38 | Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the 39 | documentation from [`PretrainedConfig`] for more information. 
40 | 41 | 42 | Args: 43 | vocab_size (`int`, *optional*, defaults to 32000): 44 | Vocabulary size of the Mistral model. Defines the number of different tokens that can be represented by the 45 | `inputs_ids` passed when calling [`MistralModel`] 46 | hidden_size (`int`, *optional*, defaults to 4096): 47 | Dimension of the hidden representations. 48 | intermediate_size (`int`, *optional*, defaults to 14336): 49 | Dimension of the MLP representations. 50 | num_hidden_layers (`int`, *optional*, defaults to 32): 51 | Number of hidden layers in the Transformer encoder. 52 | num_attention_heads (`int`, *optional*, defaults to 32): 53 | Number of attention heads for each attention layer in the Transformer encoder. 54 | num_key_value_heads (`int`, *optional*, defaults to 8): 55 | This is the number of key_value heads that should be used to implement Grouped Query Attention. If 56 | `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if 57 | `num_key_value_heads=1 the model will use Multi Query Attention (MQA) otherwise GQA is used. When 58 | converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed 59 | by meanpooling all the original heads within that group. For more details checkout [this 60 | paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `8`. 61 | hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): 62 | The non-linear activation function (function or string) in the decoder. 63 | max_position_embeddings (`int`, *optional*, defaults to `4096*32`): 64 | The maximum sequence length that this model might ever be used with. Mistral's sliding window attention 65 | allows sequence of up to 4096*32 tokens. 66 | initializer_range (`float`, *optional*, defaults to 0.02): 67 | The standard deviation of the truncated_normal_initializer for initializing all weight matrices. 68 | rms_norm_eps (`float`, *optional*, defaults to 1e-06): 69 | The epsilon used by the rms normalization layers. 70 | use_cache (`bool`, *optional*, defaults to `True`): 71 | Whether or not the model should return the last key/values attentions (not used by all models). Only 72 | relevant if `config.is_decoder=True`. 73 | pad_token_id (`int`, *optional*): 74 | The id of the padding token. 75 | bos_token_id (`int`, *optional*, defaults to 1): 76 | The id of the "beginning-of-sequence" token. 77 | eos_token_id (`int`, *optional*, defaults to 2): 78 | The id of the "end-of-sequence" token. 79 | tie_word_embeddings (`bool`, *optional*, defaults to `False`): 80 | Whether the model's input and output word embeddings should be tied. 81 | rope_theta (`float`, *optional*, defaults to 10000.0): 82 | The base period of the RoPE embeddings. 83 | sliding_window (`int`, *optional*, defaults to 4096): 84 | Sliding window attention window size. If not specified, will default to `4096`. 85 | attention_dropout (`float`, *optional*, defaults to 0.0): 86 | The dropout ratio for the attention probabilities. 
87 | 88 | ```python 89 | >>> from transformers import MistralModel, MistralConfig 90 | 91 | >>> # Initializing a Mistral 7B style configuration 92 | >>> configuration = MistralConfig() 93 | 94 | >>> # Initializing a model from the Mistral 7B style configuration 95 | >>> model = MistralModel(configuration) 96 | 97 | >>> # Accessing the model configuration 98 | >>> configuration = model.config 99 | ```""" 100 | 101 | model_type = "mistral" 102 | keys_to_ignore_at_inference = ["past_key_values"] 103 | 104 | def __init__( 105 | self, 106 | vocab_size=32000, 107 | hidden_size=4096, 108 | intermediate_size=14336, 109 | num_hidden_layers=32, 110 | num_attention_heads=32, 111 | num_key_value_heads=8, 112 | hidden_act="silu", 113 | max_position_embeddings=4096 * 32, 114 | initializer_range=0.02, 115 | rms_norm_eps=1e-6, 116 | use_cache=True, 117 | pad_token_id=None, 118 | bos_token_id=1, 119 | eos_token_id=2, 120 | tie_word_embeddings=False, 121 | rope_theta=10000.0, 122 | sliding_window=4096, 123 | attention_dropout=0.0, 124 | **kwargs, 125 | ): 126 | self.vocab_size = vocab_size 127 | self.max_position_embeddings = max_position_embeddings 128 | self.hidden_size = hidden_size 129 | self.intermediate_size = intermediate_size 130 | self.num_hidden_layers = num_hidden_layers 131 | self.num_attention_heads = num_attention_heads 132 | self.sliding_window = sliding_window 133 | 134 | # for backward compatibility 135 | if num_key_value_heads is None: 136 | num_key_value_heads = num_attention_heads 137 | 138 | self.num_key_value_heads = num_key_value_heads 139 | self.hidden_act = hidden_act 140 | self.initializer_range = initializer_range 141 | self.rms_norm_eps = rms_norm_eps 142 | self.use_cache = use_cache 143 | self.rope_theta = rope_theta 144 | self.attention_dropout = attention_dropout 145 | 146 | super().__init__( 147 | pad_token_id=pad_token_id, 148 | bos_token_id=bos_token_id, 149 | eos_token_id=eos_token_id, 150 | tie_word_embeddings=tie_word_embeddings, 151 | **kwargs, 152 | ) -------------------------------------------------------------------------------- /mixtral_base/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "architectures": [ 3 | "MixtralForCausalLM" 4 | ], 5 | "auto_map": { 6 | "AutoConfig": "configuration_moe_mistral.MixtralConfig", 7 | "AutoModelForCausalLM": "modeling_moe_mistral.MixtralForCausalLM" 8 | }, 9 | "attention_dropout": 0.0, 10 | "bos_token_id": 1, 11 | "eos_token_id": 2, 12 | "hidden_act": "silu", 13 | "hidden_size": 4096, 14 | "initializer_range": 0.02, 15 | "intermediate_size": 14336, 16 | "max_position_embeddings": 32768, 17 | "model_type": "mistral", 18 | "num_attention_heads": 32, 19 | "num_experts": 8, 20 | "num_experts_per_token": 1, 21 | "num_hidden_layers": 32, 22 | "num_key_value_heads": 8, 23 | "rms_norm_eps": 1e-05, 24 | "rope_theta": 1000000.0, 25 | "tie_word_embeddings": false, 26 | "torch_dtype": "float16", 27 | "transformers_version": "4.36.0.dev0", 28 | "use_cache": true, 29 | "vocab_size": 32000, 30 | "single_expert": false 31 | } 32 | -------------------------------------------------------------------------------- /mixtral_base/configuration_moe_mistral.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2023 Mistral AI and the HuggingFace Inc. team. All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """ Mistral model configuration""" 16 | 17 | from transformers.configuration_utils import PretrainedConfig 18 | from transformers.utils import logging 19 | 20 | 21 | logger = logging.get_logger(__name__) 22 | 23 | MISTRAL_PRETRAINED_CONFIG_ARCHIVE_MAP = { 24 | "mistralai/Mistral-7B-v0.1": "https://huggingface.co/mistralai/Mistral-7B-v0.1/resolve/main/config.json", 25 | "mistralai/Mistral-7B-Instruct-v0.1": "https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1/resolve/main/config.json", 26 | } 27 | 28 | 29 | class MixtralConfig(PretrainedConfig): 30 | r""" 31 | This is the configuration class to store the configuration of a [`MistralModel`]. It is used to instantiate an 32 | Mistral model according to the specified arguments, defining the model architecture. Instantiating a configuration 33 | with the defaults will yield a similar configuration to that of the Mistral-7B-v0.1 or Mistral-7B-Instruct-v0.1. 34 | 35 | [mistralai/Mistral-7B-v0.1](https://huggingface.co/mistralai/Mistral-7B-v0.1) 36 | [mistralai/Mistral-7B-Instruct-v0.1](https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1) 37 | 38 | Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the 39 | documentation from [`PretrainedConfig`] for more information. 40 | 41 | 42 | Args: 43 | vocab_size (`int`, *optional*, defaults to 32000): 44 | Vocabulary size of the Mistral model. Defines the number of different tokens that can be represented by the 45 | `inputs_ids` passed when calling [`MistralModel`] 46 | hidden_size (`int`, *optional*, defaults to 4096): 47 | Dimension of the hidden representations. 48 | intermediate_size (`int`, *optional*, defaults to 14336): 49 | Dimension of the MLP representations. 50 | num_hidden_layers (`int`, *optional*, defaults to 32): 51 | Number of hidden layers in the Transformer encoder. 52 | num_attention_heads (`int`, *optional*, defaults to 32): 53 | Number of attention heads for each attention layer in the Transformer encoder. 54 | num_key_value_heads (`int`, *optional*, defaults to 8): 55 | This is the number of key_value heads that should be used to implement Grouped Query Attention. If 56 | `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if 57 | `num_key_value_heads=1 the model will use Multi Query Attention (MQA) otherwise GQA is used. When 58 | converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed 59 | by meanpooling all the original heads within that group. For more details checkout [this 60 | paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `8`. 61 | hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): 62 | The non-linear activation function (function or string) in the decoder. 63 | max_position_embeddings (`int`, *optional*, defaults to `4096*32`): 64 | The maximum sequence length that this model might ever be used with. Mistral's sliding window attention 65 | allows sequence of up to 4096*32 tokens. 
66 | initializer_range (`float`, *optional*, defaults to 0.02): 67 | The standard deviation of the truncated_normal_initializer for initializing all weight matrices. 68 | rms_norm_eps (`float`, *optional*, defaults to 1e-06): 69 | The epsilon used by the rms normalization layers. 70 | use_cache (`bool`, *optional*, defaults to `True`): 71 | Whether or not the model should return the last key/values attentions (not used by all models). Only 72 | relevant if `config.is_decoder=True`. 73 | pad_token_id (`int`, *optional*): 74 | The id of the padding token. 75 | bos_token_id (`int`, *optional*, defaults to 1): 76 | The id of the "beginning-of-sequence" token. 77 | eos_token_id (`int`, *optional*, defaults to 2): 78 | The id of the "end-of-sequence" token. 79 | tie_word_embeddings (`bool`, *optional*, defaults to `False`): 80 | Whether the model's input and output word embeddings should be tied. 81 | rope_theta (`float`, *optional*, defaults to 10000.0): 82 | The base period of the RoPE embeddings. 83 | sliding_window (`int`, *optional*, defaults to 4096): 84 | Sliding window attention window size. If not specified, will default to `4096`. 85 | attention_dropout (`float`, *optional*, defaults to 0.0): 86 | The dropout ratio for the attention probabilities. 87 | 88 | ```python 89 | >>> from transformers import MistralModel, MistralConfig 90 | 91 | >>> # Initializing a Mistral 7B style configuration 92 | >>> configuration = MistralConfig() 93 | 94 | >>> # Initializing a model from the Mistral 7B style configuration 95 | >>> model = MistralModel(configuration) 96 | 97 | >>> # Accessing the model configuration 98 | >>> configuration = model.config 99 | ```""" 100 | 101 | model_type = "mistral" 102 | keys_to_ignore_at_inference = ["past_key_values"] 103 | 104 | def __init__( 105 | self, 106 | vocab_size=32000, 107 | hidden_size=4096, 108 | intermediate_size=14336, 109 | num_hidden_layers=32, 110 | num_attention_heads=32, 111 | num_key_value_heads=8, 112 | hidden_act="silu", 113 | max_position_embeddings=4096 * 32, 114 | initializer_range=0.02, 115 | rms_norm_eps=1e-6, 116 | use_cache=True, 117 | pad_token_id=None, 118 | bos_token_id=1, 119 | eos_token_id=2, 120 | tie_word_embeddings=False, 121 | rope_theta=10000.0, 122 | attention_dropout=0.0, 123 | num_experts_per_token=2, 124 | num_experts=8, 125 | single_expert=False, 126 | **kwargs, 127 | ): 128 | self.vocab_size = vocab_size 129 | self.max_position_embeddings = max_position_embeddings 130 | self.hidden_size = hidden_size 131 | self.intermediate_size = intermediate_size 132 | self.num_hidden_layers = num_hidden_layers 133 | self.num_attention_heads = num_attention_heads 134 | 135 | # for backward compatibility 136 | if num_key_value_heads is None: 137 | num_key_value_heads = num_attention_heads 138 | 139 | self.num_key_value_heads = num_key_value_heads 140 | self.hidden_act = hidden_act 141 | self.initializer_range = initializer_range 142 | self.rms_norm_eps = rms_norm_eps 143 | self.use_cache = use_cache 144 | self.rope_theta = rope_theta 145 | self.attention_dropout = attention_dropout 146 | self.num_experts = num_experts 147 | self.num_experts_per_token = num_experts_per_token 148 | 149 | self.single_expert = single_expert 150 | 151 | super().__init__( 152 | pad_token_id=pad_token_id, 153 | bos_token_id=bos_token_id, 154 | eos_token_id=eos_token_id, 155 | tie_word_embeddings=tie_word_embeddings, 156 | **kwargs, 157 | ) 158 | -------------------------------------------------------------------------------- /mixtral_base/modeling_moe_mistral.py: 
-------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2023 Mistral AI and the HuggingFace Inc. team. All rights reserved. 3 | # 4 | # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX 5 | # and OPT implementations in this library. It has been modified from its 6 | # original forms to accommodate minor architectural differences compared 7 | # to GPT-NeoX and OPT used by the Meta AI team that trained the model. 8 | # 9 | # Licensed under the Apache License, Version 2.0 (the "License"); 10 | # you may not use this file except in compliance with the License. 11 | # You may obtain a copy of the License at 12 | # 13 | # http://www.apache.org/licenses/LICENSE-2.0 14 | # 15 | # Unless required by applicable law or agreed to in writing, software 16 | # distributed under the License is distributed on an "AS IS" BASIS, 17 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 18 | # See the License for the specific language governing permissions and 19 | # limitations under the License. 20 | """ PyTorch Mistral model.""" 21 | import inspect 22 | import math 23 | import warnings 24 | from typing import List, Optional, Tuple, Union 25 | 26 | import torch 27 | import torch.nn.functional as F 28 | import torch.utils.checkpoint 29 | from torch import nn 30 | from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss 31 | 32 | from transformers.activations import ACT2FN 33 | from transformers.cache_utils import Cache, DynamicCache 34 | from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask 35 | from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast 36 | from transformers.modeling_utils import PreTrainedModel 37 | from transformers.utils import ( 38 | add_start_docstrings, 39 | add_start_docstrings_to_model_forward, 40 | is_flash_attn_2_available, 41 | is_flash_attn_greater_or_equal_2_10, 42 | logging, 43 | replace_return_docstrings, 44 | ) 45 | from .configuration_moe_mistral import MixtralConfig 46 | 47 | 48 | 49 | if is_flash_attn_2_available(): 50 | from flash_attn import flash_attn_func, flash_attn_varlen_func 51 | from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa 52 | 53 | _flash_supports_window_size = "window_size" in list(inspect.signature(flash_attn_func).parameters) 54 | 55 | 56 | logger = logging.get_logger(__name__) 57 | 58 | _CONFIG_FOR_DOC = "MixtralConfig" 59 | 60 | 61 | # Copied from transformers.models.llama.modeling_llama._get_unpad_data 62 | def _get_unpad_data(attention_mask): 63 | seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) 64 | indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() 65 | max_seqlen_in_batch = seqlens_in_batch.max().item() 66 | cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0)) 67 | return ( 68 | indices, 69 | cu_seqlens, 70 | max_seqlen_in_batch, 71 | ) 72 | 73 | 74 | # Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->Mistral 75 | class MistralRMSNorm(nn.Module): 76 | def __init__(self, hidden_size, eps=1e-6): 77 | """ 78 | MistralRMSNorm is equivalent to T5LayerNorm 79 | """ 80 | super().__init__() 81 | self.weight = nn.Parameter(torch.ones(hidden_size)) 82 | self.variance_epsilon = eps 83 | 84 | def forward(self, hidden_states): 85 | input_dtype = hidden_states.dtype 86 | hidden_states = hidden_states.to(torch.float32) 87 
| variance = hidden_states.pow(2).mean(-1, keepdim=True) 88 | hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) 89 | return self.weight * hidden_states.to(input_dtype) 90 | 91 | 92 | # Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->Mistral 93 | class MistralRotaryEmbedding(nn.Module): 94 | def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): 95 | super().__init__() 96 | 97 | self.dim = dim 98 | self.max_position_embeddings = max_position_embeddings 99 | self.base = base 100 | inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)) 101 | self.register_buffer("inv_freq", inv_freq, persistent=False) 102 | 103 | # Build here to make `torch.jit.trace` work. 104 | self._set_cos_sin_cache( 105 | seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype() 106 | ) 107 | 108 | def _set_cos_sin_cache(self, seq_len, device, dtype): 109 | self.max_seq_len_cached = seq_len 110 | t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) 111 | 112 | freqs = torch.outer(t, self.inv_freq) 113 | # Different from paper, but it uses a different permutation in order to obtain the same calculation 114 | emb = torch.cat((freqs, freqs), dim=-1) 115 | self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) 116 | self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) 117 | 118 | def forward(self, x, seq_len=None): 119 | # x: [bs, num_attention_heads, seq_len, head_size] 120 | if seq_len > self.max_seq_len_cached: 121 | self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype) 122 | 123 | return ( 124 | self.cos_cached[:seq_len].to(dtype=x.dtype), 125 | self.sin_cached[:seq_len].to(dtype=x.dtype), 126 | ) 127 | 128 | 129 | # Copied from transformers.models.llama.modeling_llama.rotate_half 130 | def rotate_half(x): 131 | """Rotates half the hidden dims of the input.""" 132 | x1 = x[..., : x.shape[-1] // 2] 133 | x2 = x[..., x.shape[-1] // 2 :] 134 | return torch.cat((-x2, x1), dim=-1) 135 | 136 | 137 | # Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb 138 | def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): 139 | """Applies Rotary Position Embedding to the query and key tensors. 140 | 141 | Args: 142 | q (`torch.Tensor`): The query tensor. 143 | k (`torch.Tensor`): The key tensor. 144 | cos (`torch.Tensor`): The cosine part of the rotary embedding. 145 | sin (`torch.Tensor`): The sine part of the rotary embedding. 146 | position_ids (`torch.Tensor`): 147 | The position indices of the tokens corresponding to the query and key tensors. For example, this can be 148 | used to pass offsetted position ids when working with a KV-cache. 149 | unsqueeze_dim (`int`, *optional*, defaults to 1): 150 | The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and 151 | sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note 152 | that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and 153 | k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes 154 | cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have 155 | the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. 
156 | Returns: 157 | `tuple(torch.Tensor)` comprising the query and key tensors rotated using the Rotary Position Embedding. 158 | """ 159 | cos = cos[position_ids].unsqueeze(unsqueeze_dim) 160 | sin = sin[position_ids].unsqueeze(unsqueeze_dim) 161 | q_embed = (q * cos) + (rotate_half(q) * sin) 162 | k_embed = (k * cos) + (rotate_half(k) * sin) 163 | return q_embed, k_embed 164 | 165 | 166 | class FeedForward(nn.Module): 167 | def __init__( 168 | self, 169 | config 170 | ): 171 | """ 172 | Initialize the FeedForward (SwiGLU) module used by each expert. 173 | 174 | Args: 175 | config (MixtralConfig): Model configuration; `hidden_size` gives the input/output dimension and 176 | `intermediate_size` gives the hidden dimension of the feed-forward layer. 177 | 178 | Attributes: 179 | w1 (nn.Linear): Gate projection from `hidden_size` to `intermediate_size`. 180 | w2 (nn.Linear): Down projection from `intermediate_size` back to `hidden_size`. 181 | w3 (nn.Linear): Up projection from `hidden_size` to `intermediate_size`. 182 | 183 | The forward pass computes `w2(silu(w1(x)) * w3(x))`. 184 | 185 | """ 186 | super().__init__() 187 | 188 | self.w1 = nn.Linear( 189 | config.hidden_size, config.intermediate_size, bias=False 190 | ) 191 | self.w2 = nn.Linear( 192 | config.intermediate_size, config.hidden_size, bias=False 193 | ) 194 | self.w3 = nn.Linear( 195 | config.hidden_size, config.intermediate_size, bias=False 196 | ) 197 | 198 | def forward(self, x): 199 | return self.w2(F.silu(self.w1(x)) * self.w3(x)) 200 | 201 | 202 | class MoE(nn.Module): 203 | def __init__( 204 | self, 205 | config, 206 | ): 207 | super().__init__() 208 | self.config = config 209 | if self.config.single_expert: 210 | print("using single expert!") 211 | num_experts = config.num_experts 212 | self.experts = nn.ModuleList([FeedForward(config) for i in range(num_experts)]) 213 | self.gate = nn.Linear(config.hidden_size, num_experts, bias=False) 214 | self.num_experts_per_token = config.num_experts_per_token 215 | 216 | def forward(self, x): 217 | if self.config.single_expert: 218 | return self.experts[0](x) 219 | orig_shape = x.shape 220 | x = x.view(-1, x.shape[-1]) 221 | 222 | scores = self.gate(x)  # router logits of shape (num_tokens, num_experts) 223 | expert_weights, expert_indices = torch.topk(scores, self.num_experts_per_token, dim=-1) 224 | expert_weights = expert_weights.softmax(dim=-1) 225 | if hasattr(self.config, 'one_gate') and (self.config.one_gate): 226 | # print("using one gate!") 227 | expert_weights = torch.ones_like(expert_weights) 228 | if hasattr(self.config, 'norm_one_gate') and (self.config.norm_one_gate): 229 | # print("using normed one gate!") 230 | expert_weights = expert_weights / self.num_experts_per_token 231 | flat_expert_indices = expert_indices.view(-1) 232 | 233 | x = x.repeat_interleave(self.num_experts_per_token, dim=0)  # duplicate each token once per selected expert 234 | y = torch.empty_like(x) 235 | for i, expert in enumerate(self.experts): 236 | y[flat_expert_indices == i] = expert(x[flat_expert_indices == i])  # run expert i on the tokens routed to it 237 | y = (y.view(*expert_weights.shape, -1) * expert_weights.unsqueeze(-1)).sum(dim=1)  # weight each expert output by its gate score and sum over the top-k experts 238 | return y.view(*orig_shape) 239 | 240 | 241 | # Copied from transformers.models.llama.modeling_llama.repeat_kv 242 | def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: 243 | """ 244 | This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep).
The hidden states go from (batch, 245 | num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) 246 | """ 247 | batch, num_key_value_heads, slen, head_dim = hidden_states.shape 248 | if n_rep == 1: 249 | return hidden_states 250 | hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) 251 | return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) 252 | 253 | 254 | class MistralAttention(nn.Module): 255 | """ 256 | Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformer 257 | and "Generating Long Sequences with Sparse Transformers". 258 | """ 259 | 260 | def __init__(self, config: MixtralConfig, layer_idx: Optional[int] = None): 261 | super().__init__() 262 | self.config = config 263 | self.layer_idx = layer_idx 264 | if layer_idx is None: 265 | logger.warning_once( 266 | f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will " 267 | "to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` " 268 | "when creating this class." 269 | ) 270 | 271 | self.hidden_size = config.hidden_size 272 | self.num_heads = config.num_attention_heads 273 | self.head_dim = self.hidden_size // self.num_heads 274 | self.num_key_value_heads = config.num_key_value_heads 275 | self.num_key_value_groups = self.num_heads // self.num_key_value_heads 276 | self.max_position_embeddings = config.max_position_embeddings 277 | self.rope_theta = config.rope_theta 278 | self.is_causal = True 279 | self.attention_dropout = config.attention_dropout 280 | 281 | if (self.head_dim * self.num_heads) != self.hidden_size: 282 | raise ValueError( 283 | f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" 284 | f" and `num_heads`: {self.num_heads})." 285 | ) 286 | self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False) 287 | self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False) 288 | self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False) 289 | self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False) 290 | 291 | self.rotary_emb = MistralRotaryEmbedding( 292 | self.head_dim, 293 | max_position_embeddings=self.max_position_embeddings, 294 | base=self.rope_theta, 295 | ) 296 | 297 | def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): 298 | return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() 299 | 300 | def forward( 301 | self, 302 | hidden_states: torch.Tensor, 303 | attention_mask: Optional[torch.Tensor] = None, 304 | position_ids: Optional[torch.LongTensor] = None, 305 | past_key_value: Optional[Cache] = None, 306 | output_attentions: bool = False, 307 | use_cache: bool = False, 308 | **kwargs, 309 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: 310 | if "padding_mask" in kwargs: 311 | warnings.warn( 312 | "Passing `padding_mask` is deprecated and will be removed in v4.37. 
Please make sure use `attention_mask` instead.`" 313 | ) 314 | bsz, q_len, _ = hidden_states.size() 315 | 316 | query_states = self.q_proj(hidden_states) 317 | key_states = self.k_proj(hidden_states) 318 | value_states = self.v_proj(hidden_states) 319 | 320 | query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) 321 | key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) 322 | value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) 323 | 324 | kv_seq_len = key_states.shape[-2] 325 | if past_key_value is not None: 326 | if self.layer_idx is None: 327 | raise ValueError( 328 | f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} " 329 | "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " 330 | "with a layer index." 331 | ) 332 | kv_seq_len += past_key_value.get_seq_length(self.layer_idx) 333 | cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) 334 | query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) 335 | 336 | if past_key_value is not None: 337 | cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models 338 | key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) 339 | 340 | # repeat k/v heads if n_kv_heads < n_heads 341 | key_states = repeat_kv(key_states, self.num_key_value_groups) 342 | value_states = repeat_kv(value_states, self.num_key_value_groups) 343 | 344 | attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) 345 | 346 | if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): 347 | raise ValueError( 348 | f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is" 349 | f" {attn_weights.size()}" 350 | ) 351 | 352 | if attention_mask is not None: 353 | if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): 354 | raise ValueError( 355 | f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" 356 | ) 357 | 358 | attn_weights = attn_weights + attention_mask 359 | 360 | # upcast attention to fp32 361 | attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) 362 | attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) 363 | attn_output = torch.matmul(attn_weights, value_states) 364 | 365 | if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): 366 | raise ValueError( 367 | f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" 368 | f" {attn_output.size()}" 369 | ) 370 | 371 | attn_output = attn_output.transpose(1, 2).contiguous() 372 | attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) 373 | 374 | attn_output = self.o_proj(attn_output) 375 | 376 | if not output_attentions: 377 | attn_weights = None 378 | 379 | return attn_output, attn_weights, past_key_value 380 | 381 | 382 | class MistralFlashAttention2(MistralAttention): 383 | """ 384 | Mistral flash attention module. This module inherits from `MistralAttention` as the weights of the module stays 385 | untouched. The only required change would be on the forward pass where it needs to correctly call the public API of 386 | flash attention and deal with padding tokens in case the input contains any of them. 
387 | """ 388 | 389 | # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__ 390 | def __init__(self, *args, **kwargs): 391 | super().__init__(*args, **kwargs) 392 | 393 | # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. 394 | # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. 395 | # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). 396 | self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() 397 | 398 | def forward( 399 | self, 400 | hidden_states: torch.Tensor, 401 | attention_mask: Optional[torch.Tensor] = None, 402 | position_ids: Optional[torch.LongTensor] = None, 403 | past_key_value: Optional[Cache] = None, 404 | output_attentions: bool = False, 405 | use_cache: bool = False, 406 | **kwargs, 407 | ): 408 | if "padding_mask" in kwargs: 409 | warnings.warn( 410 | "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" 411 | ) 412 | 413 | # overwrite attention_mask with padding_mask 414 | attention_mask = kwargs.pop("padding_mask") 415 | bsz, q_len, _ = hidden_states.size() 416 | 417 | query_states = self.q_proj(hidden_states) 418 | key_states = self.k_proj(hidden_states) 419 | value_states = self.v_proj(hidden_states) 420 | 421 | query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) 422 | key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) 423 | value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) 424 | 425 | kv_seq_len = key_states.shape[-2] 426 | if past_key_value is not None: 427 | kv_seq_len += past_key_value.get_seq_length(self.layer_idx) 428 | 429 | # Because the input can be padded, the absolute sequence length depends on the max position id. 430 | rotary_seq_len = max(kv_seq_len, position_ids[:, -1].max().item()) + 1 431 | cos, sin = self.rotary_emb(value_states, seq_len=rotary_seq_len) 432 | 433 | query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) 434 | 435 | use_sliding_windows = ( 436 | _flash_supports_window_size 437 | and getattr(self.config, "sliding_window", None) is not None 438 | and kv_seq_len > self.config.sliding_window 439 | ) 440 | 441 | if not _flash_supports_window_size: 442 | logger.warning_once( 443 | "The current flash attention version does not support sliding window attention, for a more memory efficient implementation" 444 | " make sure to upgrade flash-attn library." 
445 | ) 446 | 447 | if past_key_value is not None: 448 | # Activate slicing cache only if the config has a value `sliding_windows` attribute 449 | if getattr(self.config, "sliding_window", None) is not None and kv_seq_len > self.config.sliding_window: 450 | slicing_tokens = 1 - self.config.sliding_window 451 | 452 | past_key = past_key_value[0] 453 | past_value = past_key_value[1] 454 | 455 | past_key = past_key[:, :, slicing_tokens:, :].contiguous() 456 | past_value = past_value[:, :, slicing_tokens:, :].contiguous() 457 | 458 | if past_key.shape[-2] != self.config.sliding_window - 1: 459 | raise ValueError( 460 | f"past key must have a shape of (`batch_size, num_heads, self.config.sliding_window-1, head_dim`), got" 461 | f" {past_key.shape}" 462 | ) 463 | 464 | past_key_value = (past_key, past_value) 465 | 466 | if attention_mask is not None: 467 | attention_mask = attention_mask[:, slicing_tokens:] 468 | attention_mask = torch.cat([attention_mask, torch.ones_like(attention_mask[:, -1:])], dim=-1) 469 | 470 | cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models 471 | key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) 472 | 473 | # repeat k/v heads if n_kv_heads < n_heads 474 | key_states = repeat_kv(key_states, self.num_key_value_groups) 475 | value_states = repeat_kv(value_states, self.num_key_value_groups) 476 | dropout_rate = 0.0 if not self.training else self.attention_dropout 477 | 478 | # In PEFT, usually we cast the layer norms in float32 for training stability reasons 479 | # therefore the input hidden states gets silently casted in float32. Hence, we need 480 | # cast them back in float16 just to be sure everything works as expected. 481 | input_dtype = query_states.dtype 482 | if input_dtype == torch.float32: 483 | # Handle the case where the model is quantized 484 | if hasattr(self.config, "_pre_quantization_dtype"): 485 | target_dtype = self.config._pre_quantization_dtype 486 | else: 487 | target_dtype = self.q_proj.weight.dtype 488 | 489 | logger.warning_once( 490 | f"The input hidden states seems to be silently casted in float32, this might be related to" 491 | f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" 492 | f" {target_dtype}." 
493 | ) 494 | 495 | query_states = query_states.to(target_dtype) 496 | key_states = key_states.to(target_dtype) 497 | value_states = value_states.to(target_dtype) 498 | 499 | # Reashape to the expected shape for Flash Attention 500 | query_states = query_states.transpose(1, 2) 501 | key_states = key_states.transpose(1, 2) 502 | value_states = value_states.transpose(1, 2) 503 | 504 | attn_output = self._flash_attention_forward( 505 | query_states, 506 | key_states, 507 | value_states, 508 | attention_mask, 509 | q_len, 510 | dropout=dropout_rate, 511 | use_sliding_windows=use_sliding_windows, 512 | ) 513 | 514 | attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous() 515 | attn_output = self.o_proj(attn_output) 516 | 517 | if not output_attentions: 518 | attn_weights = None 519 | 520 | return attn_output, attn_weights, past_key_value 521 | 522 | def _flash_attention_forward( 523 | self, 524 | query_states, 525 | key_states, 526 | value_states, 527 | attention_mask, 528 | query_length, 529 | dropout=0.0, 530 | softmax_scale=None, 531 | use_sliding_windows=False, 532 | ): 533 | """ 534 | Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token 535 | first unpad the input, then computes the attention scores and pad the final attention scores. 536 | 537 | Args: 538 | query_states (`torch.Tensor`): 539 | Input query states to be passed to Flash Attention API 540 | key_states (`torch.Tensor`): 541 | Input key states to be passed to Flash Attention API 542 | value_states (`torch.Tensor`): 543 | Input value states to be passed to Flash Attention API 544 | attention_mask (`torch.Tensor`): 545 | The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the 546 | position of padding tokens and 1 for the position of non-padding tokens. 547 | dropout (`int`, *optional*): 548 | Attention dropout 549 | softmax_scale (`float`, *optional*): 550 | The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim) 551 | use_sliding_windows (`bool`, *optional*): 552 | Whether to activate sliding window attention. 553 | """ 554 | if not self._flash_attn_uses_top_left_mask: 555 | causal = self.is_causal 556 | else: 557 | # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__. 
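# A top-left aligned causal mask is only correct when q_len == kv_len. For the single-token
# decoding step (query_length == 1) it would keep just the first key position, and since one
# query attending to the whole prefix needs no causal masking anyway, the flag is dropped there.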
558 | causal = self.is_causal and query_length != 1 559 | 560 | # Contains at least one padding token in the sequence 561 | if attention_mask is not None: 562 | batch_size = query_states.shape[0] 563 | query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input( 564 | query_states, key_states, value_states, attention_mask, query_length 565 | ) 566 | 567 | cu_seqlens_q, cu_seqlens_k = cu_seq_lens 568 | max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens 569 | 570 | if not use_sliding_windows: 571 | attn_output_unpad = flash_attn_varlen_func( 572 | query_states, 573 | key_states, 574 | value_states, 575 | cu_seqlens_q=cu_seqlens_q, 576 | cu_seqlens_k=cu_seqlens_k, 577 | max_seqlen_q=max_seqlen_in_batch_q, 578 | max_seqlen_k=max_seqlen_in_batch_k, 579 | dropout_p=dropout, 580 | softmax_scale=softmax_scale, 581 | causal=causal, 582 | ) 583 | else: 584 | attn_output_unpad = flash_attn_varlen_func( 585 | query_states, 586 | key_states, 587 | value_states, 588 | cu_seqlens_q=cu_seqlens_q, 589 | cu_seqlens_k=cu_seqlens_k, 590 | max_seqlen_q=max_seqlen_in_batch_q, 591 | max_seqlen_k=max_seqlen_in_batch_k, 592 | dropout_p=dropout, 593 | softmax_scale=softmax_scale, 594 | causal=causal, 595 | window_size=(self.config.sliding_window, self.config.sliding_window), 596 | ) 597 | 598 | attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length) 599 | else: 600 | if not use_sliding_windows: 601 | attn_output = flash_attn_func( 602 | query_states, 603 | key_states, 604 | value_states, 605 | dropout, 606 | softmax_scale=softmax_scale, 607 | causal=causal, 608 | ) 609 | else: 610 | attn_output = flash_attn_func( 611 | query_states, 612 | key_states, 613 | value_states, 614 | dropout, 615 | softmax_scale=softmax_scale, 616 | causal=causal, 617 | window_size=(self.config.sliding_window, self.config.sliding_window), 618 | ) 619 | 620 | return attn_output 621 | 622 | def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length): 623 | batch_size, kv_seq_len, num_heads, head_dim = key_layer.shape 624 | 625 | # On the first iteration we need to properly re-create the padding mask 626 | # by slicing it on the proper place 627 | if kv_seq_len != attention_mask.shape[-1]: 628 | attention_mask_num_tokens = attention_mask.shape[-1] 629 | attention_mask = attention_mask[:, attention_mask_num_tokens - kv_seq_len :] 630 | 631 | indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask) 632 | 633 | key_layer = index_first_axis(key_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k) 634 | value_layer = index_first_axis(value_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k) 635 | 636 | if query_length == kv_seq_len: 637 | query_layer = index_first_axis( 638 | query_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k 639 | ) 640 | cu_seqlens_q = cu_seqlens_k 641 | max_seqlen_in_batch_q = max_seqlen_in_batch_k 642 | indices_q = indices_k 643 | elif query_length == 1: 644 | max_seqlen_in_batch_q = 1 645 | cu_seqlens_q = torch.arange( 646 | batch_size + 1, dtype=torch.int32, device=query_layer.device 647 | ) # There is a memcpy here, that is very bad. 648 | indices_q = cu_seqlens_q[:-1] 649 | query_layer = query_layer.squeeze(1) 650 | else: 651 | # The -q_len: slice assumes left padding. 
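# General case (query_length is neither kv_seq_len nor 1, e.g. a chunked prefill on top of an
# existing cache): only the last `query_length` mask columns belong to the current queries, so
# the queries are unpadded with that slice of the mask.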
652 | attention_mask = attention_mask[:, -query_length:] 653 | query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask) 654 | 655 | return ( 656 | query_layer, 657 | key_layer, 658 | value_layer, 659 | indices_q, 660 | (cu_seqlens_q, cu_seqlens_k), 661 | (max_seqlen_in_batch_q, max_seqlen_in_batch_k), 662 | ) 663 | 664 | 665 | class MistralDecoderLayer(nn.Module): 666 | def __init__(self, config: MixtralConfig, layer_idx: int): 667 | super().__init__() 668 | self.hidden_size = config.hidden_size 669 | self.self_attn = ( 670 | MistralAttention(config=config, layer_idx=layer_idx) 671 | if not getattr(config, "_flash_attn_2_enabled", False) 672 | else MistralFlashAttention2(config, layer_idx=layer_idx) 673 | ) 674 | self.mlp = MoE(config) 675 | self.input_layernorm = MistralRMSNorm(config.hidden_size, eps=config.rms_norm_eps) 676 | self.post_attention_layernorm = MistralRMSNorm(config.hidden_size, eps=config.rms_norm_eps) 677 | 678 | def forward( 679 | self, 680 | hidden_states: torch.Tensor, 681 | attention_mask: Optional[torch.Tensor] = None, 682 | position_ids: Optional[torch.LongTensor] = None, 683 | past_key_value: Optional[Tuple[torch.Tensor]] = None, 684 | output_attentions: Optional[bool] = False, 685 | use_cache: Optional[bool] = False, 686 | **kwargs, 687 | ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: 688 | if "padding_mask" in kwargs: 689 | warnings.warn( 690 | "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" 691 | ) 692 | """ 693 | Args: 694 | hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` 695 | attention_mask (`torch.FloatTensor`, *optional*): attention mask of size 696 | `(batch, sequence_length)` where padding elements are indicated by 0. 697 | output_attentions (`bool`, *optional*): 698 | Whether or not to return the attentions tensors of all attention layers. See `attentions` under 699 | returned tensors for more detail. 700 | use_cache (`bool`, *optional*): 701 | If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding 702 | (see `past_key_values`). 703 | past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states 704 | """ 705 | 706 | residual = hidden_states 707 | 708 | hidden_states = self.input_layernorm(hidden_states) 709 | 710 | # Self Attention 711 | hidden_states, self_attn_weights, present_key_value = self.self_attn( 712 | hidden_states=hidden_states, 713 | attention_mask=attention_mask, 714 | position_ids=position_ids, 715 | past_key_value=past_key_value, 716 | output_attentions=output_attentions, 717 | use_cache=use_cache, 718 | ) 719 | hidden_states = residual + hidden_states 720 | 721 | # Fully Connected 722 | residual = hidden_states 723 | hidden_states = self.post_attention_layernorm(hidden_states) 724 | hidden_states = self.mlp(hidden_states) 725 | hidden_states = residual + hidden_states 726 | 727 | outputs = (hidden_states,) 728 | 729 | if output_attentions: 730 | outputs += (self_attn_weights,) 731 | 732 | if use_cache: 733 | outputs += (present_key_value,) 734 | 735 | return outputs 736 | 737 | 738 | MISTRAL_START_DOCSTRING = r""" 739 | This model inherits from [`PreTrainedModel`]. 
Check the superclass documentation for the generic methods the 740 | library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads 741 | etc.) 742 | 743 | This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. 744 | Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage 745 | and behavior. 746 | 747 | Parameters: 748 | config ([`MixtralConfig`]): 749 | Model configuration class with all the parameters of the model. Initializing with a config file does not 750 | load the weights associated with the model, only the configuration. Check out the 751 | [`~PreTrainedModel.from_pretrained`] method to load the model weights. 752 | """ 753 | 754 | 755 | @add_start_docstrings( 756 | "The bare Mistral Model outputting raw hidden-states without any specific head on top.", 757 | MISTRAL_START_DOCSTRING, 758 | ) 759 | class MistralPreTrainedModel(PreTrainedModel): 760 | config_class = MixtralConfig 761 | base_model_prefix = "model" 762 | supports_gradient_checkpointing = True 763 | _no_split_modules = ["MistralDecoderLayer"] 764 | _skip_keys_device_placement = "past_key_values" 765 | _supports_flash_attn_2 = True 766 | _supports_cache_class = True 767 | 768 | def _init_weights(self, module): 769 | std = self.config.initializer_range 770 | if isinstance(module, nn.Linear): 771 | module.weight.data.normal_(mean=0.0, std=std) 772 | if module.bias is not None: 773 | module.bias.data.zero_() 774 | elif isinstance(module, nn.Embedding): 775 | module.weight.data.normal_(mean=0.0, std=std) 776 | if module.padding_idx is not None: 777 | module.weight.data[module.padding_idx].zero_() 778 | 779 | 780 | MISTRAL_INPUTS_DOCSTRING = r""" 781 | Args: 782 | input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): 783 | Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide 784 | it. 785 | 786 | Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and 787 | [`PreTrainedTokenizer.__call__`] for details. 788 | 789 | [What are input IDs?](../glossary#input-ids) 790 | attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): 791 | Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: 792 | 793 | - 1 for tokens that are **not masked**, 794 | - 0 for tokens that are **masked**. 795 | 796 | [What are attention masks?](../glossary#attention-mask) 797 | 798 | Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and 799 | [`PreTrainedTokenizer.__call__`] for details. 800 | 801 | If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see 802 | `past_key_values`). 803 | 804 | If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] 805 | and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more 806 | information on the default strategy. 807 | 808 | - 1 indicates the head is **not masked**, 809 | - 0 indicates the head is **masked**. 810 | position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): 811 | Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, 812 | config.n_positions - 1]`. 
813 | 814 | [What are position IDs?](../glossary#position-ids) 815 | past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*): 816 | Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention 817 | blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values` 818 | returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`. 819 | 820 | Two formats are allowed: 821 | - a [`~cache_utils.Cache`] instance; 822 | - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of 823 | shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy 824 | cache format. 825 | 826 | The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the 827 | legacy cache format will be returned. 828 | 829 | If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't 830 | have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids` 831 | of shape `(batch_size, sequence_length)`. 832 | inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): 833 | Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This 834 | is useful if you want more control over how to convert `input_ids` indices into associated vectors than the 835 | model's internal embedding lookup matrix. 836 | use_cache (`bool`, *optional*): 837 | If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see 838 | `past_key_values`). 839 | output_attentions (`bool`, *optional*): 840 | Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned 841 | tensors for more detail. 842 | output_hidden_states (`bool`, *optional*): 843 | Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for 844 | more detail. 845 | return_dict (`bool`, *optional*): 846 | Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. 847 | """ 848 | 849 | 850 | @add_start_docstrings( 851 | "The bare Mistral Model outputting raw hidden-states without any specific head on top.", 852 | MISTRAL_START_DOCSTRING, 853 | ) 854 | class MistralModel(MistralPreTrainedModel): 855 | """ 856 | Transformer decoder consisting of *config.num_hidden_layers* layers. 
Each layer is a [`MistralDecoderLayer`] 857 | 858 | Args: 859 | config: MixtralConfig 860 | """ 861 | 862 | def __init__(self, config: MixtralConfig): 863 | super().__init__(config) 864 | self.padding_idx = config.pad_token_id 865 | self.vocab_size = config.vocab_size 866 | 867 | self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) 868 | self.layers = nn.ModuleList( 869 | [MistralDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] 870 | ) 871 | self.norm = MistralRMSNorm(config.hidden_size, eps=config.rms_norm_eps) 872 | 873 | self.gradient_checkpointing = False 874 | # Initialize weights and apply final processing 875 | self.post_init() 876 | 877 | def get_input_embeddings(self): 878 | return self.embed_tokens 879 | 880 | def set_input_embeddings(self, value): 881 | self.embed_tokens = value 882 | 883 | @add_start_docstrings_to_model_forward(MISTRAL_INPUTS_DOCSTRING) 884 | def forward( 885 | self, 886 | input_ids: torch.LongTensor = None, 887 | attention_mask: Optional[torch.Tensor] = None, 888 | position_ids: Optional[torch.LongTensor] = None, 889 | past_key_values: Optional[List[torch.FloatTensor]] = None, 890 | inputs_embeds: Optional[torch.FloatTensor] = None, 891 | use_cache: Optional[bool] = None, 892 | output_attentions: Optional[bool] = None, 893 | output_hidden_states: Optional[bool] = None, 894 | return_dict: Optional[bool] = None, 895 | decoder_layer_idx: Optional[int] = None, 896 | ) -> Union[Tuple, BaseModelOutputWithPast]: 897 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions 898 | output_hidden_states = ( 899 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states 900 | ) 901 | use_cache = use_cache if use_cache is not None else self.config.use_cache 902 | 903 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 904 | 905 | # retrieve input_ids and inputs_embeds 906 | if input_ids is not None and inputs_embeds is not None: 907 | raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") 908 | elif input_ids is not None: 909 | batch_size, seq_length = input_ids.shape 910 | elif inputs_embeds is not None: 911 | batch_size, seq_length, _ = inputs_embeds.shape 912 | else: 913 | raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") 914 | 915 | seq_length_with_past = seq_length 916 | past_key_values_length = 0 917 | 918 | if use_cache: 919 | use_legacy_cache = not isinstance(past_key_values, Cache) 920 | if use_legacy_cache: 921 | past_key_values = DynamicCache.from_legacy_cache(past_key_values) 922 | past_key_values_length = past_key_values.get_seq_length() 923 | seq_length_with_past = seq_length_with_past + past_key_values_length 924 | 925 | if position_ids is None: 926 | device = input_ids.device if input_ids is not None else inputs_embeds.device 927 | position_ids = torch.arange( 928 | past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device 929 | ) 930 | position_ids = position_ids.unsqueeze(0).view(-1, seq_length) 931 | else: 932 | position_ids = position_ids.view(-1, seq_length).long() 933 | 934 | if inputs_embeds is None: 935 | inputs_embeds = self.embed_tokens(input_ids) 936 | 937 | if ( 938 | attention_mask is not None 939 | and hasattr(self.config, "_flash_attn_2_enabled") 940 | and self.config._flash_attn_2_enabled 941 | and use_cache 942 | ): 943 | 
is_padding_right = attention_mask[:, -1].sum().item() != batch_size 944 | if is_padding_right: 945 | raise ValueError( 946 | "You are attempting to perform batched generation with padding_side='right'" 947 | " this may lead to unexpected behaviour for Flash Attention version of Mistral. Make sure to " 948 | " call `tokenizer.padding_side = 'left'` before tokenizing the input. " 949 | ) 950 | 951 | if getattr(self.config, "_flash_attn_2_enabled", False): 952 | # 2d mask is passed through the layers 953 | attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None 954 | else: 955 | # 4d mask is passed through the layers 956 | attention_mask = _prepare_4d_causal_attention_mask( 957 | attention_mask, 958 | (batch_size, seq_length), 959 | inputs_embeds, 960 | past_key_values_length 961 | ) 962 | 963 | hidden_states = inputs_embeds 964 | 965 | if self.gradient_checkpointing and self.training: 966 | if use_cache: 967 | logger.warning_once( 968 | "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 969 | ) 970 | use_cache = False 971 | 972 | # decoder layers 973 | all_hidden_states = () if output_hidden_states else None 974 | all_self_attns = () if output_attentions else None 975 | next_decoder_cache = None 976 | 977 | for i, decoder_layer in enumerate(self.layers): 978 | if decoder_layer_idx is not None and i != decoder_layer_idx: 979 | continue 980 | if output_hidden_states: 981 | all_hidden_states += (hidden_states,) 982 | 983 | if self.gradient_checkpointing and self.training: 984 | layer_outputs = self._gradient_checkpointing_func( 985 | decoder_layer.__call__, 986 | hidden_states, 987 | attention_mask, 988 | position_ids, 989 | past_key_values, 990 | output_attentions, 991 | use_cache, 992 | ) 993 | else: 994 | layer_outputs = decoder_layer( 995 | hidden_states, 996 | attention_mask=attention_mask, 997 | position_ids=position_ids, 998 | past_key_value=past_key_values, 999 | output_attentions=output_attentions, 1000 | use_cache=use_cache, 1001 | ) 1002 | 1003 | hidden_states = layer_outputs[0] 1004 | 1005 | if use_cache: 1006 | next_decoder_cache = layer_outputs[2 if output_attentions else 1] 1007 | 1008 | if output_attentions: 1009 | all_self_attns += (layer_outputs[1],) 1010 | 1011 | hidden_states = self.norm(hidden_states) 1012 | 1013 | # add hidden states from the last decoder layer 1014 | if output_hidden_states: 1015 | all_hidden_states += (hidden_states,) 1016 | 1017 | next_cache = None 1018 | if use_cache: 1019 | next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache 1020 | 1021 | if not return_dict: 1022 | return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) 1023 | return BaseModelOutputWithPast( 1024 | last_hidden_state=hidden_states, 1025 | past_key_values=next_cache, 1026 | hidden_states=all_hidden_states, 1027 | attentions=all_self_attns, 1028 | ) 1029 | 1030 | 1031 | class MixtralForCausalLM(MistralPreTrainedModel): 1032 | _tied_weights_keys = ["lm_head.weight"] 1033 | 1034 | def __init__(self, config): 1035 | super().__init__(config) 1036 | self.model = MistralModel(config) 1037 | self.vocab_size = config.vocab_size 1038 | self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) 1039 | 1040 | # Initialize weights and apply final processing 1041 | self.post_init() 1042 | 1043 | def get_input_embeddings(self): 1044 | return self.model.embed_tokens 1045 | 1046 | def set_input_embeddings(self, 
value): 1047 | self.model.embed_tokens = value 1048 | 1049 | def get_output_embeddings(self): 1050 | return self.lm_head 1051 | 1052 | def set_output_embeddings(self, new_embeddings): 1053 | self.lm_head = new_embeddings 1054 | 1055 | def set_decoder(self, decoder): 1056 | self.model = decoder 1057 | 1058 | def get_decoder(self): 1059 | return self.model 1060 | 1061 | def _init_weights(self, module): 1062 | return 1063 | 1064 | @add_start_docstrings_to_model_forward(MISTRAL_INPUTS_DOCSTRING) 1065 | @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) 1066 | def forward( 1067 | self, 1068 | input_ids: torch.LongTensor = None, 1069 | attention_mask: Optional[torch.Tensor] = None, 1070 | position_ids: Optional[torch.LongTensor] = None, 1071 | past_key_values: Optional[List[torch.FloatTensor]] = None, 1072 | inputs_embeds: Optional[torch.FloatTensor] = None, 1073 | labels: Optional[torch.LongTensor] = None, 1074 | use_cache: Optional[bool] = None, 1075 | output_attentions: Optional[bool] = None, 1076 | output_hidden_states: Optional[bool] = None, 1077 | return_dict: Optional[bool] = None, 1078 | decoder_layer_idx: Optional[int] = None, 1079 | ) -> Union[Tuple, CausalLMOutputWithPast]: 1080 | r""" 1081 | Args: 1082 | labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): 1083 | Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., 1084 | config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored 1085 | (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. 1086 | 1087 | Returns: 1088 | 1089 | Example: 1090 | 1091 | ```python 1092 | >>> from transformers import AutoTokenizer, MistralForCausalLM 1093 | 1094 | >>> model = MistralForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS) 1095 | >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER) 1096 | 1097 | >>> prompt = "Hey, are you conscious? Can you talk to me?" 1098 | >>> inputs = tokenizer(prompt, return_tensors="pt") 1099 | 1100 | >>> # Generate 1101 | >>> generate_ids = model.generate(inputs.input_ids, max_length=30) 1102 | >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] 1103 | "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." 
1104 | ```""" 1105 | 1106 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions 1107 | output_hidden_states = ( 1108 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states 1109 | ) 1110 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 1111 | 1112 | # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) 1113 | outputs = self.model( 1114 | input_ids=input_ids, 1115 | attention_mask=attention_mask, 1116 | position_ids=position_ids, 1117 | past_key_values=past_key_values, 1118 | inputs_embeds=inputs_embeds, 1119 | use_cache=use_cache, 1120 | output_attentions=output_attentions, 1121 | output_hidden_states=output_hidden_states, 1122 | return_dict=return_dict, 1123 | decoder_layer_idx=decoder_layer_idx, 1124 | ) 1125 | 1126 | hidden_states = outputs[0] 1127 | logits = self.lm_head(hidden_states) 1128 | logits = logits.float() 1129 | 1130 | loss = None 1131 | if labels is not None: 1132 | # Shift so that tokens < n predict n 1133 | shift_logits = logits[..., :-1, :].contiguous() 1134 | shift_labels = labels[..., 1:].contiguous() 1135 | # Flatten the tokens 1136 | loss_fct = CrossEntropyLoss() 1137 | shift_logits = shift_logits.view(-1, self.config.vocab_size) 1138 | shift_labels = shift_labels.view(-1) 1139 | # Enable model parallelism 1140 | shift_labels = shift_labels.to(shift_logits.device) 1141 | loss = loss_fct(shift_logits, shift_labels) 1142 | 1143 | if not return_dict: 1144 | output = (logits,) + outputs[1:] 1145 | return (loss,) + output if loss is not None else output 1146 | 1147 | return CausalLMOutputWithPast( 1148 | loss=loss, 1149 | logits=logits, 1150 | past_key_values=outputs.past_key_values, 1151 | hidden_states=outputs.hidden_states, 1152 | attentions=outputs.attentions, 1153 | ) 1154 | 1155 | def prepare_inputs_for_generation( 1156 | self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs 1157 | ): 1158 | # Omit tokens covered by past_key_values 1159 | if past_key_values is not None: 1160 | if isinstance(past_key_values, Cache): 1161 | cache_length = past_key_values.get_seq_length() 1162 | past_length = past_key_values.seen_tokens 1163 | else: 1164 | cache_length = past_length = past_key_values[0][0].shape[2] 1165 | 1166 | # Keep only the unprocessed tokens: 1167 | # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where 1168 | # some of the inputs are exclusivelly passed as part of the cache (e.g. when passing input_embeds as 1169 | # input) 1170 | if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]: 1171 | input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :] 1172 | # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard 1173 | # input_ids based on the past_length. 1174 | elif past_length < input_ids.shape[1]: 1175 | input_ids = input_ids[:, past_length:] 1176 | # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens. 1177 | 1178 | # If the cache has seen more tokens than it can hold, then the cache has a size limit. Let's discard the 1179 | # older attention values, as their corresponding values are not part of the input. 
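# (`cache_length` is what the cache currently stores, `past_length` counts every token it has
# seen; the two only diverge for size-limited caches that evict old entries.)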
1180 | if cache_length < past_length and attention_mask is not None: 1181 | attention_mask = attention_mask[:, -(cache_length + input_ids.shape[1]) :] 1182 | 1183 | position_ids = kwargs.get("position_ids", None) 1184 | if attention_mask is not None and position_ids is None: 1185 | # create position_ids on the fly for batch generation 1186 | position_ids = attention_mask.long().cumsum(-1) - 1 1187 | position_ids.masked_fill_(attention_mask == 0, 1) 1188 | if past_key_values: 1189 | position_ids = position_ids[:, -input_ids.shape[1] :] 1190 | 1191 | # if `inputs_embeds` are passed, we only want to use them in the 1st generation step 1192 | if inputs_embeds is not None and past_key_values is None: 1193 | model_inputs = {"inputs_embeds": inputs_embeds} 1194 | else: 1195 | model_inputs = {"input_ids": input_ids} 1196 | 1197 | model_inputs.update( 1198 | { 1199 | "position_ids": position_ids, 1200 | "past_key_values": past_key_values, 1201 | "use_cache": kwargs.get("use_cache"), 1202 | "attention_mask": attention_mask, 1203 | } 1204 | ) 1205 | return model_inputs 1206 | 1207 | @staticmethod 1208 | def _reorder_cache(past_key_values, beam_idx): 1209 | reordered_past = () 1210 | for layer_past in past_key_values: 1211 | reordered_past += ( 1212 | tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past), 1213 | ) 1214 | return reordered_past 1215 | 1216 | 1217 | @add_start_docstrings( 1218 | """ 1219 | The Mistral Model transformer with a sequence classification head on top (linear layer). 1220 | 1221 | [`MistralForSequenceClassification`] uses the last token in order to do the classification, as other causal models 1222 | (e.g. GPT-2) do. 1223 | 1224 | Since it does classification on the last token, it requires to know the position of the last token. If a 1225 | `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If 1226 | no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the 1227 | padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in 1228 | each row of the batch). 
1229 | """, 1230 | MISTRAL_START_DOCSTRING, 1231 | ) 1232 | # Copied from transformers.models.llama.modeling_llama.LlamaForSequenceClassification with Llama->Mistral, LLAMA->MISTRAL 1233 | class MistralForSequenceClassification(MistralPreTrainedModel): 1234 | def __init__(self, config): 1235 | super().__init__(config) 1236 | self.num_labels = config.num_labels 1237 | self.model = MistralModel(config) 1238 | self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False) 1239 | 1240 | # Initialize weights and apply final processing 1241 | self.post_init() 1242 | 1243 | def get_input_embeddings(self): 1244 | return self.model.embed_tokens 1245 | 1246 | def set_input_embeddings(self, value): 1247 | self.model.embed_tokens = value 1248 | 1249 | @add_start_docstrings_to_model_forward(MISTRAL_INPUTS_DOCSTRING) 1250 | def forward( 1251 | self, 1252 | input_ids: torch.LongTensor = None, 1253 | attention_mask: Optional[torch.Tensor] = None, 1254 | position_ids: Optional[torch.LongTensor] = None, 1255 | past_key_values: Optional[List[torch.FloatTensor]] = None, 1256 | inputs_embeds: Optional[torch.FloatTensor] = None, 1257 | labels: Optional[torch.LongTensor] = None, 1258 | use_cache: Optional[bool] = None, 1259 | output_attentions: Optional[bool] = None, 1260 | output_hidden_states: Optional[bool] = None, 1261 | return_dict: Optional[bool] = None, 1262 | ) -> Union[Tuple, SequenceClassifierOutputWithPast]: 1263 | r""" 1264 | labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): 1265 | Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., 1266 | config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If 1267 | `config.num_labels > 1` a classification loss is computed (Cross-Entropy). 
1268 | """ 1269 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 1270 | 1271 | transformer_outputs = self.model( 1272 | input_ids, 1273 | attention_mask=attention_mask, 1274 | position_ids=position_ids, 1275 | past_key_values=past_key_values, 1276 | inputs_embeds=inputs_embeds, 1277 | use_cache=use_cache, 1278 | output_attentions=output_attentions, 1279 | output_hidden_states=output_hidden_states, 1280 | return_dict=return_dict, 1281 | ) 1282 | hidden_states = transformer_outputs[0] 1283 | logits = self.score(hidden_states) 1284 | 1285 | if input_ids is not None: 1286 | batch_size = input_ids.shape[0] 1287 | else: 1288 | batch_size = inputs_embeds.shape[0] 1289 | 1290 | if self.config.pad_token_id is None and batch_size != 1: 1291 | raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.") 1292 | if self.config.pad_token_id is None: 1293 | sequence_lengths = -1 1294 | else: 1295 | if input_ids is not None: 1296 | sequence_lengths = (torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1).to( 1297 | logits.device 1298 | ) 1299 | else: 1300 | sequence_lengths = -1 1301 | 1302 | pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths] 1303 | 1304 | loss = None 1305 | if labels is not None: 1306 | labels = labels.to(logits.device) 1307 | if self.config.problem_type is None: 1308 | if self.num_labels == 1: 1309 | self.config.problem_type = "regression" 1310 | elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): 1311 | self.config.problem_type = "single_label_classification" 1312 | else: 1313 | self.config.problem_type = "multi_label_classification" 1314 | 1315 | if self.config.problem_type == "regression": 1316 | loss_fct = MSELoss() 1317 | if self.num_labels == 1: 1318 | loss = loss_fct(pooled_logits.squeeze(), labels.squeeze()) 1319 | else: 1320 | loss = loss_fct(pooled_logits, labels) 1321 | elif self.config.problem_type == "single_label_classification": 1322 | loss_fct = CrossEntropyLoss() 1323 | loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1)) 1324 | elif self.config.problem_type == "multi_label_classification": 1325 | loss_fct = BCEWithLogitsLoss() 1326 | loss = loss_fct(pooled_logits, labels) 1327 | if not return_dict: 1328 | output = (pooled_logits,) + transformer_outputs[1:] 1329 | return ((loss,) + output) if loss is not None else output 1330 | 1331 | return SequenceClassifierOutputWithPast( 1332 | loss=loss, 1333 | logits=pooled_logits, 1334 | past_key_values=transformer_outputs.past_key_values, 1335 | hidden_states=transformer_outputs.hidden_states, 1336 | attentions=transformer_outputs.attentions, 1337 | ) -------------------------------------------------------------------------------- /mixtral_base22/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "architectures": [ 3 | "MixtralForCausalLM" 4 | ], 5 | "attention_dropout": 0.0, 6 | "bos_token_id": 1, 7 | "eos_token_id": 2, 8 | "hidden_act": "silu", 9 | "hidden_size": 6144, 10 | "initializer_range": 0.02, 11 | "intermediate_size": 16384, 12 | "max_position_embeddings": 65536, 13 | "model_type": "mixtral", 14 | "num_attention_heads": 48, 15 | "num_experts_per_tok": 2, 16 | "num_hidden_layers": 56, 17 | "num_key_value_heads": 8, 18 | "num_experts": 8, 19 | "rms_norm_eps": 1e-05, 20 | "rope_theta": 1000000, 21 | "tie_word_embeddings": false, 22 | "torch_dtype": "bfloat16", 23 | "transformers_version": "4.38.0", 24 | 
"use_cache": true, 25 | "vocab_size": 32000 26 | } 27 | -------------------------------------------------------------------------------- /mixtral_instruct/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "architectures": [ 3 | "MixtralForCausalLM" 4 | ], 5 | "attention_dropout": 0.0, 6 | "bos_token_id": 1, 7 | "eos_token_id": 2, 8 | "hidden_act": "silu", 9 | "hidden_size": 4096, 10 | "initializer_range": 0.02, 11 | "intermediate_size": 14336, 12 | "max_position_embeddings": 32768, 13 | "model_type": "mixtral", 14 | "num_attention_heads": 32, 15 | "num_experts_per_tok": 2, 16 | "num_hidden_layers": 32, 17 | "num_key_value_heads": 8, 18 | "num_local_experts": 8, 19 | "output_router_logits": false, 20 | "rms_norm_eps": 1e-05, 21 | "rope_theta": 1000000.0, 22 | "router_aux_loss_coef": 0.02, 23 | "sliding_window": null, 24 | "tie_word_embeddings": false, 25 | "torch_dtype": "bfloat16", 26 | "transformers_version": "4.36.0.dev0", 27 | "use_cache": true, 28 | "vocab_size": 32000 29 | } 30 | -------------------------------------------------------------------------------- /mixtral_instruct/configuration_moe_mistral.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2023 Mistral AI and the HuggingFace Inc. team. All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """ Mistral model configuration""" 16 | 17 | from transformers.configuration_utils import PretrainedConfig 18 | from transformers.utils import logging 19 | 20 | 21 | logger = logging.get_logger(__name__) 22 | 23 | MISTRAL_PRETRAINED_CONFIG_ARCHIVE_MAP = { 24 | "mistralai/Mistral-7B-v0.1": "https://huggingface.co/mistralai/Mistral-7B-v0.1/resolve/main/config.json", 25 | "mistralai/Mistral-7B-Instruct-v0.1": "https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1/resolve/main/config.json", 26 | } 27 | 28 | 29 | class MixtralConfig(PretrainedConfig): 30 | r""" 31 | This is the configuration class to store the configuration of a [`MistralModel`]. It is used to instantiate an 32 | Mistral model according to the specified arguments, defining the model architecture. Instantiating a configuration 33 | with the defaults will yield a similar configuration to that of the Mistral-7B-v0.1 or Mistral-7B-Instruct-v0.1. 34 | 35 | [mistralai/Mistral-7B-v0.1](https://huggingface.co/mistralai/Mistral-7B-v0.1) 36 | [mistralai/Mistral-7B-Instruct-v0.1](https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1) 37 | 38 | Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the 39 | documentation from [`PretrainedConfig`] for more information. 40 | 41 | 42 | Args: 43 | vocab_size (`int`, *optional*, defaults to 32000): 44 | Vocabulary size of the Mistral model. 
Defines the number of different tokens that can be represented by the 45 | `inputs_ids` passed when calling [`MistralModel`] 46 | hidden_size (`int`, *optional*, defaults to 4096): 47 | Dimension of the hidden representations. 48 | intermediate_size (`int`, *optional*, defaults to 14336): 49 | Dimension of the MLP representations. 50 | num_hidden_layers (`int`, *optional*, defaults to 32): 51 | Number of hidden layers in the Transformer encoder. 52 | num_attention_heads (`int`, *optional*, defaults to 32): 53 | Number of attention heads for each attention layer in the Transformer encoder. 54 | num_key_value_heads (`int`, *optional*, defaults to 8): 55 | This is the number of key_value heads that should be used to implement Grouped Query Attention. If 56 | `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if 57 | `num_key_value_heads=1 the model will use Multi Query Attention (MQA) otherwise GQA is used. When 58 | converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed 59 | by meanpooling all the original heads within that group. For more details checkout [this 60 | paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `8`. 61 | hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): 62 | The non-linear activation function (function or string) in the decoder. 63 | max_position_embeddings (`int`, *optional*, defaults to `4096*32`): 64 | The maximum sequence length that this model might ever be used with. Mistral's sliding window attention 65 | allows sequence of up to 4096*32 tokens. 66 | initializer_range (`float`, *optional*, defaults to 0.02): 67 | The standard deviation of the truncated_normal_initializer for initializing all weight matrices. 68 | rms_norm_eps (`float`, *optional*, defaults to 1e-06): 69 | The epsilon used by the rms normalization layers. 70 | use_cache (`bool`, *optional*, defaults to `True`): 71 | Whether or not the model should return the last key/values attentions (not used by all models). Only 72 | relevant if `config.is_decoder=True`. 73 | pad_token_id (`int`, *optional*): 74 | The id of the padding token. 75 | bos_token_id (`int`, *optional*, defaults to 1): 76 | The id of the "beginning-of-sequence" token. 77 | eos_token_id (`int`, *optional*, defaults to 2): 78 | The id of the "end-of-sequence" token. 79 | tie_word_embeddings (`bool`, *optional*, defaults to `False`): 80 | Whether the model's input and output word embeddings should be tied. 81 | rope_theta (`float`, *optional*, defaults to 10000.0): 82 | The base period of the RoPE embeddings. 83 | sliding_window (`int`, *optional*, defaults to 4096): 84 | Sliding window attention window size. If not specified, will default to `4096`. 85 | attention_dropout (`float`, *optional*, defaults to 0.0): 86 | The dropout ratio for the attention probabilities. 
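num_experts (`int`, *optional*, defaults to 8):
Number of expert `FeedForward` blocks in each MoE layer.
num_experts_per_token (`int`, *optional*, defaults to 2):
Number of experts (top-k) selected by the gate for every token.
single_expert (`bool`, *optional*, defaults to `False`):
If `True`, each MoE layer bypasses the gate and routes all tokens through its first expert only.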
87 | 88 | ```python 89 | >>> from transformers import MistralModel, MistralConfig 90 | 91 | >>> # Initializing a Mistral 7B style configuration 92 | >>> configuration = MistralConfig() 93 | 94 | >>> # Initializing a model from the Mistral 7B style configuration 95 | >>> model = MistralModel(configuration) 96 | 97 | >>> # Accessing the model configuration 98 | >>> configuration = model.config 99 | ```""" 100 | 101 | model_type = "mistral" 102 | keys_to_ignore_at_inference = ["past_key_values"] 103 | 104 | def __init__( 105 | self, 106 | vocab_size=32000, 107 | hidden_size=4096, 108 | intermediate_size=14336, 109 | num_hidden_layers=32, 110 | num_attention_heads=32, 111 | num_key_value_heads=8, 112 | hidden_act="silu", 113 | max_position_embeddings=4096 * 32, 114 | initializer_range=0.02, 115 | rms_norm_eps=1e-6, 116 | use_cache=True, 117 | pad_token_id=None, 118 | bos_token_id=1, 119 | eos_token_id=2, 120 | tie_word_embeddings=False, 121 | rope_theta=10000.0, 122 | attention_dropout=0.0, 123 | num_experts_per_token=2, 124 | num_experts=8, 125 | single_expert=False, 126 | **kwargs, 127 | ): 128 | self.vocab_size = vocab_size 129 | self.max_position_embeddings = max_position_embeddings 130 | self.hidden_size = hidden_size 131 | self.intermediate_size = intermediate_size 132 | self.num_hidden_layers = num_hidden_layers 133 | self.num_attention_heads = num_attention_heads 134 | 135 | # for backward compatibility 136 | if num_key_value_heads is None: 137 | num_key_value_heads = num_attention_heads 138 | 139 | self.num_key_value_heads = num_key_value_heads 140 | self.hidden_act = hidden_act 141 | self.initializer_range = initializer_range 142 | self.rms_norm_eps = rms_norm_eps 143 | self.use_cache = use_cache 144 | self.rope_theta = rope_theta 145 | self.attention_dropout = attention_dropout 146 | self.num_experts = num_experts 147 | self.num_experts_per_token = num_experts_per_token 148 | 149 | self.single_expert = single_expert 150 | 151 | super().__init__( 152 | pad_token_id=pad_token_id, 153 | bos_token_id=bos_token_id, 154 | eos_token_id=eos_token_id, 155 | tie_word_embeddings=tie_word_embeddings, 156 | **kwargs, 157 | ) 158 | --------------------------------------------------------------------------------
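The configuration above exposes the MoE-specific switches (`num_experts`, `num_experts_per_token`, `single_expert`, plus the `one_gate`/`norm_one_gate` attributes read by `MoE.forward`). Below is a rough, hypothetical sketch of how they fit together; the import paths and the tiny config sizes are assumptions, not taken from the repository.

```python
# Hypothetical smoke test; module paths and config sizes are assumptions.
# Requires a transformers version providing the Cache/DynamicCache API used by this file.
import torch

from configuration_moe_mistral import MixtralConfig   # e.g. mixtral_base/configuration_moe_mistral.py
from modeling_moe_mistral import MixtralForCausalLM   # e.g. mixtral_base/modeling_moe_mistral.py

config = MixtralConfig(
    hidden_size=256, intermediate_size=512,
    num_hidden_layers=2, num_attention_heads=8, num_key_value_heads=8,
    num_experts=8, num_experts_per_token=2, single_expert=False,
)
# `one_gate` / `norm_one_gate` are not constructor arguments; MoE.forward() only looks them
# up with hasattr(), so they can simply be attached to the config object when needed:
config.one_gate = False       # True -> replace the softmax gate weights with ones
config.norm_one_gate = False  # True -> additionally divide the weights by num_experts_per_token

model = MixtralForCausalLM(config).eval()
input_ids = torch.randint(0, config.vocab_size, (1, 16))

with torch.no_grad():
    # `decoder_layer_idx` is accepted by MistralModel.forward to run a single decoder layer.
    out = model(input_ids, decoder_layer_idx=0)
print(out.logits.shape)  # torch.Size([1, 16, 32000])
```

Note that `one_gate` and `norm_one_gate` are deliberately absent from `MixtralConfig.__init__`; the MoE layer only picks them up when they are attached to the config object, which keeps the default forward pass identical to the standard routing behaviour.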