├── LICENSE
├── README.md
├── assets
│   ├── last_layer.png
│   ├── pearson.png
│   └── rank_count.png
├── deepseekmoe
│   ├── config.json
│   ├── configuration_deepseek.py
│   └── modeling_deepseek.py
├── dynamic_analysis.ipynb
├── env.txt
├── grok
│   ├── config.json
│   ├── configuration_grok1.py
│   ├── modeling_grok1.py
│   └── modeling_grok1_outputs.py
├── mistral
│   ├── config.json
│   ├── configuration_mistral.py
│   └── modeling_mistral.py
├── mixtral_base
│   ├── config.json
│   ├── configuration_moe_mistral.py
│   └── modeling_moe_mistral.py
├── mixtral_base22
│   └── config.json
├── mixtral_instruct
│   ├── config.json
│   ├── configuration_moe_mistral.py
│   └── modeling_mixtral_instruct.py
├── static_analysis.ipynb
└── wikitext103_test.csv
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2024 Lo Ka Man
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # A Closer Look into MoEs in LLMs
2 |
3 | This repository contains the code of [A Closer Look into Mixture-of-Experts in Large Language Models](https://arxiv.org/abs/2406.18219).
4 |
5 | ## Overview :eyes:
6 |
7 | We make an initial attempt to understand the inner workings of MoE-based large language models.
8 | Concretely, we comprehensively study the parametric and behavioral features of four recent MoE-based models ([Mixtral 8x7B](https://arxiv.org/pdf/2401.04088), Mixtral 8x22B, [DeepSeekMoE](https://arxiv.org/pdf/2401.06066), [Grok-1](https://github.com/xai-org/grok-1)) and reveal some intriguing observations, including:
9 |
10 | - **Neurons act like fine-grained experts.** \
11 | Intuitively, the gate embedding determines the expert selection, while each expert's gate projection matrix is responsible for choosing which neurons to activate. Interestingly, their similarity values are associated, as described in the table below (X and Y denote the similarity values of the gate embedding and of the three projection matrices, respectively).
12 | Therefore, they may learn similar knowledge to perform their respective selection operations; in other words, the expert neurons act as fine-grained experts.
13 |
14 |
15 |
16 | - **The router of MoE usually selects experts with larger output norms.** \
17 | Taking Mixtral as an example, we find that the expert whose output feature vector has the *i*-th largest norm is the most likely to be assigned the *i*-th highest score by the gate (see the measurement sketch after this list).
18 | (In the figure, a larger ranking index means a larger norm/score, so rank 8 is the largest.)
19 |
20 |
21 |
22 | - **Expert diversity increases with layer depth, while the last layer is an outlier.** \
23 | In several experiments, we observe that the similarities between experts are generally lower in deeper layers, whereas they increase again in the last layer(s).
24 | For instance, the figure below shows the similarity heat maps of the Mixtral experts' outputs at different layers, where the values compare as: Layer 31 > Layer 6 > Layer 27.
25 |
26 |
27 |
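The rank agreement in the second observation comes down to comparing, per token, the router's score ranking with the ranking of the experts' output norms. Below is a minimal, hypothetical sketch of that measurement on a randomly initialized Mixtral-style MoE layer; the toy sizes and the `gate`/`experts` modules are assumptions for illustration, not the repository's code.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
hidden_size, ffn_size, num_experts, num_tokens = 64, 128, 8, 512

# Toy Mixtral-style MoE layer: a linear router plus num_experts small MLP experts.
gate = nn.Linear(hidden_size, num_experts, bias=False)
experts = nn.ModuleList(
    nn.Sequential(nn.Linear(hidden_size, ffn_size), nn.SiLU(), nn.Linear(ffn_size, hidden_size))
    for _ in range(num_experts)
)

x = torch.randn(num_tokens, hidden_size)            # stand-in for token hidden states
scores = gate(x).softmax(dim=-1)                    # (tokens, experts) routing scores
norms = torch.stack([expert(x).norm(dim=-1) for expert in experts], dim=-1)  # output norms

# Rank experts per token by gate score and by output norm (rank 0 = smallest).
score_rank = scores.argsort(dim=-1).argsort(dim=-1)
norm_rank = norms.argsort(dim=-1).argsort(dim=-1)
agreement = (score_rank == norm_rank).float().mean().item()
print(f"fraction of (token, expert) pairs with matching ranks: {agreement:.3f}")
```

On an untrained layer the agreement naturally stays near chance; the observation above concerns the trained checkpoints studied in the paper.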
28 | Based on these observations, we also provide suggestions for a broad spectrum of MoE practitioners, covering aspects such as router design and expert allocation.
29 | **Check out our paper for more inspiring observations and suggestions!**
30 |
31 | ## Setup :wrench:
32 |
33 | 1. Download the model checkpoints \
34 | By default, our code loads the pre-downloaded models from the `ckpt` directory (a minimal loading sketch is given after this list).
35 | You can also modify it to download directly from HuggingFace. The download links of the models we used are listed below:
36 | - [Mixtral 8x7B Base](https://huggingface.co/mistralai/Mixtral-8x7B-v0.1)
37 | - [Mixtral 8x22B Base](https://huggingface.co/mistralai/Mixtral-8x22B-v0.1)
38 | - [Mixtral 8x7B Instruct](https://huggingface.co/mistralai/Mixtral-8x7B-Instruct-v0.1)
39 | - [Mistral](https://huggingface.co/mistralai/Mistral-7B-v0.1)
40 | - [DeepSeekMoE](https://huggingface.co/deepseek-ai/deepseek-moe-16b-base)
41 | - [Grok-1](https://huggingface.co/hpcai-tech/grok-1)
42 |
43 | 2. Create the conda environment
44 | ```bash
45 | git clone https://github.com/kamanphoebe/Look-into-MoEs.git
46 | cd Look-into-MoEs
47 | conda create -n analyze --file env.txt
48 | ```
49 | After creating the conda environment, you have to select it as the Jupyter kernel.
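For reference, here is a hedged sketch of how a pre-downloaded checkpoint can be loaded from the `ckpt` directory with HuggingFace `transformers`. The directory name is an assumption; `trust_remote_code=True` is only needed for models that ship custom configuration/modeling files, such as DeepSeekMoE and Grok-1.

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Hypothetical local path to one of the downloaded checkpoints.
ckpt_dir = "ckpt/Mixtral-8x7B-v0.1"

tokenizer = AutoTokenizer.from_pretrained(ckpt_dir)
model = AutoModelForCausalLM.from_pretrained(
    ckpt_dir,
    torch_dtype=torch.bfloat16,   # the released checkpoints are stored in bfloat16
    device_map="auto",            # requires `accelerate` (included in env.txt)
    trust_remote_code=True,       # needed for custom configuration/modeling code
)
```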
50 |
51 | ## Usage :memo:
52 |
53 | The two Jupyter notebooks `static_analysis.ipynb` and `dynamic_analysis.ipynb` contain the code for the experiments on static parameters and dynamic behaviours, respectively.
54 | You can simply run the corresponding code blocks for each experiment, each of which is titled the same as in the paper.
55 | Note that some experiments employ part of the [Wikitext 103 test set](https://huggingface.co/datasets/Salesforce/wikitext), which we have already provided in `wikitext103_test.csv` (see the snippet below).
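For instance, the provided subset can be inspected with pandas. This is a minimal sketch and assumes you run it from the repository root; the column layout is whatever the CSV ships with.

```python
import pandas as pd

wikitext = pd.read_csv("wikitext103_test.csv")
print(wikitext.shape)
print(wikitext.head())
```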
56 |
57 | ## Citation :star2:
58 |
59 | Please cite our work if you find it useful!
60 | ```bibtex
61 | @article{lo2024closer,
62 | title={A Closer Look into Mixture-of-Experts in Large Language Models},
63 | author={Lo, Ka Man and Huang, Zeyu and Qiu, Zihan and Wang, Zili and Fu, Jie},
64 | journal={arXiv preprint arXiv:2406.18219},
65 | year={2024}
66 | }
67 | ```
68 |
69 | ## Acknowledgement :tada:
70 |
71 | Our configuration and modeling files are modified from the corresponding HuggingFace repositories listed in the [Setup](https://github.com/kamanphoebe/Look-into-MoEs/tree/main?tab=readme-ov-file#setup-wrench) section.
72 | Thanks to the authors for their great work!
--------------------------------------------------------------------------------
/assets/last_layer.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kamanphoebe/Look-into-MoEs/f53d24b88e30343309aa1c85df3314cea75601d2/assets/last_layer.png
--------------------------------------------------------------------------------
/assets/pearson.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kamanphoebe/Look-into-MoEs/f53d24b88e30343309aa1c85df3314cea75601d2/assets/pearson.png
--------------------------------------------------------------------------------
/assets/rank_count.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kamanphoebe/Look-into-MoEs/f53d24b88e30343309aa1c85df3314cea75601d2/assets/rank_count.png
--------------------------------------------------------------------------------
/deepseekmoe/config.json:
--------------------------------------------------------------------------------
1 | {
2 | "architectures": [
3 | "DeepseekForCausalLM"
4 | ],
5 | "attention_bias": false,
6 | "attention_dropout": 0.0,
7 | "auto_map": {
8 | "AutoConfig": "configuration_deepseek.DeepseekConfig",
9 | "AutoModel": "modeling_deepseek.DeepseekModel",
10 | "AutoModelForCausalLM": "modeling_deepseek.DeepseekForCausalLM"
11 | },
12 | "bos_token_id": 100000,
13 | "eos_token_id": 100001,
14 | "first_k_dense_replace": 1,
15 | "hidden_act": "silu",
16 | "hidden_size": 2048,
17 | "initializer_range": 0.02,
18 | "intermediate_size": 10944,
19 | "max_position_embeddings": 4096,
20 | "model_type": "deepseek",
21 | "moe_intermediate_size": 1408,
22 | "moe_layer_freq": 1,
23 | "n_routed_experts": 64,
24 | "n_shared_experts": 2,
25 | "norm_topk_prob": false,
26 | "num_attention_heads": 16,
27 | "num_experts_per_tok": 6,
28 | "num_hidden_layers": 28,
29 | "num_key_value_heads": 16,
30 | "pretraining_tp": 1,
31 | "rms_norm_eps": 1e-06,
32 | "rope_scaling": null,
33 | "rope_theta": 10000,
34 | "scoring_func": "softmax",
35 | "tie_word_embeddings": false,
36 | "torch_dtype": "bfloat16",
37 | "transformers_version": "4.36.0",
38 | "use_cache": true,
39 | "vocab_size": 102400
40 | }
41 |
--------------------------------------------------------------------------------
/deepseekmoe/configuration_deepseek.py:
--------------------------------------------------------------------------------
1 | from transformers.configuration_utils import PretrainedConfig
2 | from transformers.utils import logging
3 |
4 | logger = logging.get_logger(__name__)
5 |
6 | DEEPSEEK_PRETRAINED_CONFIG_ARCHIVE_MAP = {}
7 | class DeepseekConfig(PretrainedConfig):
8 | r"""
9 | This is the configuration class to store the configuration of a [`DeepseekModel`]. It is used to instantiate an DeepSeek
10 | model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
11 | defaults will yield a similar configuration to that of the DeepSeek-7B.
12 |
13 | Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
14 | documentation from [`PretrainedConfig`] for more information.
15 |
16 |
17 | Args:
18 | vocab_size (`int`, *optional*, defaults to 102400):
19 | Vocabulary size of the Deepseek model. Defines the number of different tokens that can be represented by the
20 | `inputs_ids` passed when calling [`DeepseekModel`]
21 | hidden_size (`int`, *optional*, defaults to 4096):
22 | Dimension of the hidden representations.
23 | intermediate_size (`int`, *optional*, defaults to 11008):
24 | Dimension of the MLP representations.
25 | moe_intermediate_size (`int`, *optional*, defaults to 1407):
26 | Dimension of the MoE representations.
27 | num_hidden_layers (`int`, *optional*, defaults to 32):
28 | Number of hidden layers in the Transformer decoder.
29 | num_attention_heads (`int`, *optional*, defaults to 32):
30 | Number of attention heads for each attention layer in the Transformer decoder.
31 | n_shared_experts (`int`, *optional*, defaults to None):
32 | Number of shared experts, None means dense model.
33 | n_routed_experts (`int`, *optional*, defaults to None):
34 | Number of routed experts, None means dense model.
35 | num_experts_per_tok (`int`, *optional*, defaults to None):
36 | Number of selected experts, None means dense model.
37 | moe_layer_freq (`int`, *optional*, defaults to 1):
38 | The frequency of the MoE layer: one expert layer for every `moe_layer_freq - 1` dense layers.
39 | first_k_dense_replace (`int`, *optional*, defaults to 0):
40 | Number of dense layers in shallow layers(embed->dense->dense->...->dense->moe->moe...->lm_head).
41 | \--k dense layers--/
42 | norm_topk_prob (`bool`, *optional*, defaults to False):
43 | Whether to normalize the weights of the routed experts.
44 | scoring_func (`str`, *optional*, defaults to 'softmax'):
45 | Method of computing expert weights.
46 | aux_loss_alpha (`float`, *optional*, defaults to 0.001):
47 | Auxiliary loss weight coefficient.
48 | seq_aux (`bool`, *optional*, defaults to True):
49 | Whether to compute the auxiliary loss for each individual sample.
50 | num_key_value_heads (`int`, *optional*):
51 | This is the number of key_value heads that should be used to implement Grouped Query Attention. If
52 | `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
53 | `num_key_value_heads=1`, the model will use Multi Query Attention (MQA); otherwise GQA is used. When
54 | converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
55 | by meanpooling all the original heads within that group. For more details checkout [this
56 | paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to
57 | `num_attention_heads`.
58 | hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
59 | The non-linear activation function (function or string) in the decoder.
60 | max_position_embeddings (`int`, *optional*, defaults to 2048):
61 | The maximum sequence length that this model might ever be used with.
62 | initializer_range (`float`, *optional*, defaults to 0.02):
63 | The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
64 | rms_norm_eps (`float`, *optional*, defaults to 1e-06):
65 | The epsilon used by the rms normalization layers.
66 | use_cache (`bool`, *optional*, defaults to `True`):
67 | Whether or not the model should return the last key/values attentions (not used by all models). Only
68 | relevant if `config.is_decoder=True`.
69 | pad_token_id (`int`, *optional*):
70 | Padding token id.
71 | bos_token_id (`int`, *optional*, defaults to 1):
72 | Beginning of stream token id.
73 | eos_token_id (`int`, *optional*, defaults to 2):
74 | End of stream token id.
75 | pretraining_tp (`int`, *optional*, defaults to 1):
76 | Experimental feature. Tensor parallelism rank used during pretraining. Please refer to [this
77 | document](https://huggingface.co/docs/transformers/parallelism) to understand more about it. This value is
78 | necessary to ensure exact reproducibility of the pretraining results. Please refer to [this
79 | issue](https://github.com/pytorch/pytorch/issues/76232).
80 | tie_word_embeddings (`bool`, *optional*, defaults to `False`):
81 | Whether to tie weight embeddings
82 | rope_theta (`float`, *optional*, defaults to 10000.0):
83 | The base period of the RoPE embeddings.
84 | rope_scaling (`Dict`, *optional*):
85 | Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling
86 | strategies: linear and dynamic. Their scaling factor must be a float greater than 1. The expected format is
87 | `{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update
88 | `max_position_embeddings` to the expected new maximum.
89 | attention_bias (`bool`, *optional*, defaults to `False`):
90 | Whether to use a bias in the query, key, value and output projection layers during self-attention.
91 | attention_dropout (`float`, *optional*, defaults to 0.0):
92 | The dropout ratio for the attention probabilities.
93 |
94 | ```python
95 | >>> from transformers import DeepseekModel, DeepseekConfig
96 |
97 | >>> # Initializing a Deepseek deepseek-7b style configuration
98 | >>> configuration = DeepseekConfig()
99 |
100 | >>> # Initializing a model from the configuration and accessing its config
101 | >>> configuration = DeepseekModel(configuration).config
102 | ```"""
103 |
104 | model_type = "deepseek"
105 | keys_to_ignore_at_inference = ["past_key_values"]
106 |
107 | def __init__(
108 | self,
109 | vocab_size=102400,
110 | hidden_size=4096,
111 | intermediate_size=11008,
112 | moe_intermediate_size = 1407,
113 | num_hidden_layers=30,
114 | num_attention_heads=32,
115 | num_key_value_heads=32,
116 | n_shared_experts = None,
117 | n_routed_experts = None,
118 | num_experts_per_tok = None,
119 | moe_layer_freq = 1,
120 | first_k_dense_replace = 0,
121 | norm_topk_prob = False,
122 | scoring_func = 'softmax',
123 | aux_loss_alpha = 0.001,
124 | seq_aux = True,
125 | hidden_act="silu",
126 | max_position_embeddings=2048,
127 | initializer_range=0.02,
128 | rms_norm_eps=1e-6,
129 | use_cache=True,
130 | pad_token_id=None,
131 | bos_token_id=100000,
132 | eos_token_id=100001,
133 | pretraining_tp=1,
134 | tie_word_embeddings=False,
135 | rope_theta=10000.0,
136 | rope_scaling=None,
137 | attention_bias=False,
138 | attention_dropout=0.0,
139 | **kwargs,
140 | ):
141 | self.vocab_size = vocab_size
142 | self.max_position_embeddings = max_position_embeddings
143 | self.hidden_size = hidden_size
144 | self.intermediate_size = intermediate_size
145 | self.moe_intermediate_size = moe_intermediate_size
146 | self.num_hidden_layers = num_hidden_layers
147 | self.num_attention_heads = num_attention_heads
148 | self.n_shared_experts = n_shared_experts
149 | self.n_routed_experts = n_routed_experts
150 | self.num_experts_per_tok = num_experts_per_tok
151 | self.moe_layer_freq = moe_layer_freq
152 | self.first_k_dense_replace = first_k_dense_replace
153 | self.norm_topk_prob = norm_topk_prob
154 | self.scoring_func = scoring_func
155 | self.aux_loss_alpha = aux_loss_alpha
156 | self.seq_aux = seq_aux
157 | # for backward compatibility
158 | if num_key_value_heads is None:
159 | num_key_value_heads = num_attention_heads
160 |
161 | self.num_key_value_heads = num_key_value_heads
162 | self.hidden_act = hidden_act
163 | self.initializer_range = initializer_range
164 | self.rms_norm_eps = rms_norm_eps
165 | self.pretraining_tp = pretraining_tp
166 | self.use_cache = use_cache
167 | self.rope_theta = rope_theta
168 | self.rope_scaling = rope_scaling
169 | self._rope_scaling_validation()
170 | self.attention_bias = attention_bias
171 | self.attention_dropout = attention_dropout
172 |
173 | super().__init__(
174 | pad_token_id=pad_token_id,
175 | bos_token_id=bos_token_id,
176 | eos_token_id=eos_token_id,
177 | tie_word_embeddings=tie_word_embeddings,
178 | **kwargs,
179 | )
180 |
181 | def _rope_scaling_validation(self):
182 | """
183 | Validate the `rope_scaling` configuration.
184 | """
185 | if self.rope_scaling is None:
186 | return
187 |
188 | if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2:
189 | raise ValueError(
190 | "`rope_scaling` must be a dictionary with with two fields, `type` and `factor`, "
191 | f"got {self.rope_scaling}"
192 | )
193 | rope_scaling_type = self.rope_scaling.get("type", None)
194 | rope_scaling_factor = self.rope_scaling.get("factor", None)
195 | if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]:
196 | raise ValueError(
197 | f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}"
198 | )
199 | if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0:
200 | raise ValueError(f"`rope_scaling`'s factor field must be a float > 1, got {rope_scaling_factor}")
--------------------------------------------------------------------------------
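As a quick illustration of the `rope_scaling` format that `_rope_scaling_validation` accepts, consider the hedged sketch below. The direct import is an assumption made for brevity; in practice the class is usually resolved through the `auto_map` entries in `config.json` with `trust_remote_code=True`.

```python
from configuration_deepseek import DeepseekConfig

# Accepted: `type` in {"linear", "dynamic"} and a float `factor` > 1.
cfg = DeepseekConfig(rope_scaling={"type": "linear", "factor": 2.0})

# Rejected: an unknown strategy name raises a ValueError in _rope_scaling_validation.
try:
    DeepseekConfig(rope_scaling={"type": "ntk", "factor": 2.0})
except ValueError as err:
    print(err)
```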
/env.txt:
--------------------------------------------------------------------------------
1 | # This file may be used to create an environment using:
2 | # $ conda create --name <env> --file <this file>
3 | # platform: linux-64
4 | @EXPLICIT
5 | https://conda.anaconda.org/conda-forge/linux-64/_libgcc_mutex-0.1-conda_forge.tar.bz2
6 | https://repo.anaconda.com/pkgs/main/linux-64/blas-1.0-mkl.conda
7 | https://repo.anaconda.com/pkgs/main/linux-64/ca-certificates-2024.3.11-h06a4308_0.conda
8 | https://conda.anaconda.org/nvidia/linux-64/cuda-cudart-11.8.89-0.tar.bz2
9 | https://conda.anaconda.org/nvidia/linux-64/cuda-cupti-11.8.87-0.tar.bz2
10 | https://conda.anaconda.org/nvidia/linux-64/cuda-nvrtc-11.8.89-0.tar.bz2
11 | https://conda.anaconda.org/nvidia/linux-64/cuda-nvtx-11.8.86-0.tar.bz2
12 | https://conda.anaconda.org/conda-forge/linux-64/ld_impl_linux-64-2.40-h41732ed_0.conda
13 | https://conda.anaconda.org/nvidia/linux-64/libcublas-11.11.3.6-0.tar.bz2
14 | https://conda.anaconda.org/nvidia/linux-64/libcufft-10.9.0.58-0.tar.bz2
15 | https://conda.anaconda.org/nvidia/linux-64/libcufile-1.8.1.2-0.tar.bz2
16 | https://conda.anaconda.org/nvidia/linux-64/libcurand-10.3.4.101-0.tar.bz2
17 | https://conda.anaconda.org/nvidia/linux-64/libcusolver-11.4.1.48-0.tar.bz2
18 | https://conda.anaconda.org/nvidia/linux-64/libcusparse-11.7.5.86-0.tar.bz2
19 | https://repo.anaconda.com/pkgs/main/linux-64/libgfortran5-11.2.0-h1234567_1.conda
20 | https://conda.anaconda.org/nvidia/linux-64/libnpp-11.8.0.86-0.tar.bz2
21 | https://conda.anaconda.org/nvidia/linux-64/libnvjpeg-11.9.0.86-0.tar.bz2
22 | https://conda.anaconda.org/conda-forge/linux-64/libstdcxx-ng-13.2.0-h7e041cc_3.conda
23 | https://conda.anaconda.org/conda-forge/linux-64/python_abi-3.10-4_cp310.conda
24 | https://conda.anaconda.org/pytorch/noarch/pytorch-mutex-1.0-cuda.tar.bz2
25 | https://conda.anaconda.org/conda-forge/noarch/tzdata-2023c-h71feb2d_0.conda
26 | https://conda.anaconda.org/nvidia/linux-64/cuda-libraries-11.8.0-0.tar.bz2
27 | https://repo.anaconda.com/pkgs/main/linux-64/libgfortran-ng-11.2.0-h00389a5_1.conda
28 | https://conda.anaconda.org/conda-forge/linux-64/libgomp-13.2.0-h807b86a_3.conda
29 | https://conda.anaconda.org/conda-forge/linux-64/_openmp_mutex-4.5-2_gnu.tar.bz2
30 | https://conda.anaconda.org/nvidia/linux-64/cuda-runtime-11.8.0-0.tar.bz2
31 | https://conda.anaconda.org/conda-forge/linux-64/libgcc-ng-13.2.0-h807b86a_3.conda
32 | https://conda.anaconda.org/pytorch/linux-64/pytorch-cuda-11.8-h7e8668a_5.tar.bz2
33 | https://repo.anaconda.com/pkgs/main/linux-64/abseil-cpp-20211102.0-hd4dd3e8_0.conda
34 | https://repo.anaconda.com/pkgs/main/linux-64/aws-c-common-0.8.5-h5eee18b_0.conda
35 | https://conda.anaconda.org/conda-forge/linux-64/bzip2-1.0.8-hd590300_5.conda
36 | https://repo.anaconda.com/pkgs/main/linux-64/c-ares-1.19.1-h5eee18b_0.conda
37 | https://repo.anaconda.com/pkgs/main/linux-64/expat-2.5.0-h6a678d5_0.conda
38 | https://repo.anaconda.com/pkgs/main/linux-64/gflags-2.2.2-h6a678d5_1.conda
39 | https://repo.anaconda.com/pkgs/main/linux-64/giflib-5.2.1-h5eee18b_3.conda
40 | https://repo.anaconda.com/pkgs/main/linux-64/gmp-6.2.1-h295c915_3.conda
41 | https://repo.anaconda.com/pkgs/main/linux-64/icu-73.1-h6a678d5_0.conda
42 | https://repo.anaconda.com/pkgs/main/linux-64/jpeg-9e-h5eee18b_1.conda
43 | https://repo.anaconda.com/pkgs/main/linux-64/lame-3.100-h7b6447c_0.conda
44 | https://repo.anaconda.com/pkgs/main/linux-64/lerc-3.0-h295c915_0.conda
45 | https://repo.anaconda.com/pkgs/main/linux-64/libbrotlicommon-1.0.9-h5eee18b_7.conda
46 | https://repo.anaconda.com/pkgs/main/linux-64/libdeflate-1.17-h5eee18b_1.conda
47 | https://repo.anaconda.com/pkgs/main/linux-64/libev-4.33-h7f8727e_1.conda
48 | https://repo.anaconda.com/pkgs/main/linux-64/libffi-3.4.4-h6a678d5_0.conda
49 | https://repo.anaconda.com/pkgs/main/linux-64/libiconv-1.16-h7f8727e_2.conda
50 | https://conda.anaconda.org/pytorch/linux-64/libjpeg-turbo-2.0.0-h9bf148f_0.tar.bz2
51 | https://conda.anaconda.org/conda-forge/linux-64/libnsl-2.0.1-hd590300_0.conda
52 | https://conda.anaconda.org/conda-forge/linux-64/libsodium-1.0.18-h36c2ea0_1.tar.bz2
53 | https://repo.anaconda.com/pkgs/main/linux-64/libtasn1-4.19.0-h5eee18b_0.conda
54 | https://repo.anaconda.com/pkgs/main/linux-64/libunistring-0.9.10-h27cfd23_0.conda
55 | https://conda.anaconda.org/conda-forge/linux-64/libuuid-2.38.1-h0b41bf4_0.conda
56 | https://repo.anaconda.com/pkgs/main/linux-64/libwebp-base-1.3.2-h5eee18b_0.conda
57 | https://repo.anaconda.com/pkgs/main/linux-64/libxcb-1.15-h7f8727e_0.conda
58 | https://conda.anaconda.org/conda-forge/linux-64/libzlib-1.2.13-hd590300_5.conda
59 | https://repo.anaconda.com/pkgs/main/linux-64/lz4-c-1.9.4-h6a678d5_0.conda
60 | https://conda.anaconda.org/conda-forge/linux-64/ncurses-6.4-h59595ed_2.conda
61 | https://repo.anaconda.com/pkgs/main/linux-64/openh264-2.1.1-h4ff587b_0.conda
62 | https://conda.anaconda.org/conda-forge/linux-64/openssl-3.2.1-hd590300_1.conda
63 | https://repo.anaconda.com/pkgs/main/linux-64/re2-2022.04.01-h295c915_0.conda
64 | https://repo.anaconda.com/pkgs/main/linux-64/snappy-1.1.10-h6a678d5_1.conda
65 | https://repo.anaconda.com/pkgs/main/linux-64/tbb-2021.8.0-hdb19cb5_0.conda
66 | https://repo.anaconda.com/pkgs/main/linux-64/utf8proc-2.6.1-h5eee18b_1.conda
67 | https://repo.anaconda.com/pkgs/main/linux-64/xxhash-0.8.0-h7f8727e_3.conda
68 | https://repo.anaconda.com/pkgs/main/linux-64/xz-5.4.2-h5eee18b_0.conda
69 | https://repo.anaconda.com/pkgs/main/linux-64/yaml-0.2.5-h7b6447c_0.conda
70 | https://repo.anaconda.com/pkgs/main/linux-64/aws-c-cal-0.5.20-hdbd6064_0.conda
71 | https://repo.anaconda.com/pkgs/main/linux-64/aws-c-compression-0.2.16-h5eee18b_0.conda
72 | https://repo.anaconda.com/pkgs/main/linux-64/aws-c-sdkutils-0.1.6-h5eee18b_0.conda
73 | https://repo.anaconda.com/pkgs/main/linux-64/aws-checksums-0.1.13-h5eee18b_0.conda
74 | https://repo.anaconda.com/pkgs/main/linux-64/glog-0.5.0-h6a678d5_1.conda
75 | https://repo.anaconda.com/pkgs/main/linux-64/libbrotlidec-1.0.9-h5eee18b_7.conda
76 | https://repo.anaconda.com/pkgs/main/linux-64/libbrotlienc-1.0.9-h5eee18b_7.conda
77 | https://repo.anaconda.com/pkgs/main/linux-64/libedit-3.1.20230828-h5eee18b_0.conda
78 | https://repo.anaconda.com/pkgs/main/linux-64/libevent-2.1.12-hdbd6064_1.conda
79 | https://repo.anaconda.com/pkgs/main/linux-64/libidn2-2.3.4-h5eee18b_0.conda
80 | https://conda.anaconda.org/conda-forge/linux-64/libsqlite-3.44.2-h2797004_0.conda
81 | https://repo.anaconda.com/pkgs/main/linux-64/libssh2-1.10.0-hdbd6064_2.conda
82 | https://repo.anaconda.com/pkgs/main/linux-64/mpfr-4.0.2-hb69a4c5_1.conda
83 | https://repo.anaconda.com/pkgs/main/linux-64/nettle-3.7.3-hbbd107a_1.conda
84 | https://conda.anaconda.org/conda-forge/linux-64/readline-8.2-h8228510_1.conda
85 | https://repo.anaconda.com/pkgs/main/linux-64/s2n-1.3.27-hdbd6064_0.conda
86 | https://conda.anaconda.org/conda-forge/linux-64/tk-8.6.13-noxft_h4845f30_101.conda
87 | https://conda.anaconda.org/conda-forge/linux-64/zeromq-4.3.5-h59595ed_0.conda
88 | https://conda.anaconda.org/conda-forge/linux-64/zlib-1.2.13-hd590300_5.conda
89 | https://repo.anaconda.com/pkgs/main/linux-64/aws-c-io-0.13.10-h5eee18b_0.conda
90 | https://repo.anaconda.com/pkgs/main/linux-64/brotli-bin-1.0.9-h5eee18b_7.conda
91 | https://repo.anaconda.com/pkgs/main/linux-64/gnutls-3.6.15-he1e5248_0.conda
92 | https://repo.anaconda.com/pkgs/main/linux-64/intel-openmp-2023.1.0-hdb19cb5_46306.conda
93 | https://repo.anaconda.com/pkgs/main/linux-64/krb5-1.20.1-h143b758_1.conda
94 | https://repo.anaconda.com/pkgs/main/linux-64/libcups-2.4.2-h2d74bed_1.conda
95 | https://repo.anaconda.com/pkgs/main/linux-64/libllvm14-14.0.6-hdb19cb5_3.conda
96 | https://repo.anaconda.com/pkgs/main/linux-64/libnghttp2-1.57.0-h2d74bed_0.conda
97 | https://repo.anaconda.com/pkgs/main/linux-64/libpng-1.6.39-h5eee18b_0.conda
98 | https://repo.anaconda.com/pkgs/main/linux-64/libprotobuf-3.20.3-he621ea3_0.conda
99 | https://repo.anaconda.com/pkgs/main/linux-64/libxml2-2.10.4-hf1b16e4_1.conda
100 | https://repo.anaconda.com/pkgs/main/linux-64/llvm-openmp-14.0.6-h9e868ea_0.conda
101 | https://repo.anaconda.com/pkgs/main/linux-64/mpc-1.1.0-h10f8cd9_1.conda
102 | https://repo.anaconda.com/pkgs/main/linux-64/pcre2-10.42-hebb0a14_0.conda
103 | https://conda.anaconda.org/conda-forge/linux-64/python-3.10.13-hd12c33a_0_cpython.conda
104 | https://conda.anaconda.org/conda-forge/linux-64/sqlite-3.44.2-h2c6b66d_0.conda
105 | https://repo.anaconda.com/pkgs/main/linux-64/zstd-1.5.5-hc292b87_0.conda
106 | https://conda.anaconda.org/conda-forge/noarch/antlr-python-runtime-4.9.3-pyhd8ed1ab_1.tar.bz2
107 | https://repo.anaconda.com/pkgs/main/linux-64/async-timeout-4.0.3-py310h06a4308_0.conda
108 | https://conda.anaconda.org/conda-forge/noarch/attrs-23.2.0-pyh71513ae_0.conda
109 | https://repo.anaconda.com/pkgs/main/linux-64/aws-c-event-stream-0.2.15-h6a678d5_0.conda
110 | https://repo.anaconda.com/pkgs/main/linux-64/aws-c-http-0.6.25-h5eee18b_0.conda
111 | https://repo.anaconda.com/pkgs/main/linux-64/brotli-1.0.9-h5eee18b_7.conda
112 | https://repo.anaconda.com/pkgs/main/linux-64/brotli-python-1.0.9-py310h6a678d5_7.conda
113 | https://conda.anaconda.org/conda-forge/noarch/cached_property-1.5.2-pyha770c72_1.tar.bz2
114 | https://repo.anaconda.com/pkgs/main/linux-64/certifi-2024.2.2-py310h06a4308_0.conda
115 | https://repo.anaconda.com/pkgs/main/noarch/charset-normalizer-2.0.4-pyhd3eb1b0_0.conda
116 | https://repo.anaconda.com/pkgs/main/noarch/cycler-0.11.0-pyhd3eb1b0_0.conda
117 | https://repo.anaconda.com/pkgs/main/linux-64/cyrus-sasl-2.1.28-h52b45da_1.conda
118 | https://conda.anaconda.org/conda-forge/linux-64/debugpy-1.8.0-py310hc6cd4ac_1.conda
119 | https://conda.anaconda.org/conda-forge/noarch/decorator-5.1.1-pyhd8ed1ab_0.tar.bz2
120 | https://conda.anaconda.org/conda-forge/noarch/defusedxml-0.7.1-pyhd8ed1ab_0.tar.bz2
121 | https://repo.anaconda.com/pkgs/main/linux-64/dill-0.3.6-py310h06a4308_0.conda
122 | https://conda.anaconda.org/conda-forge/noarch/entrypoints-0.4-pyhd8ed1ab_0.tar.bz2
123 | https://conda.anaconda.org/conda-forge/noarch/exceptiongroup-1.2.0-pyhd8ed1ab_0.conda
124 | https://conda.anaconda.org/conda-forge/noarch/executing-2.0.1-pyhd8ed1ab_0.conda
125 | https://repo.anaconda.com/pkgs/main/linux-64/filelock-3.13.1-py310h06a4308_0.conda
126 | https://repo.anaconda.com/pkgs/main/linux-64/freetype-2.12.1-h4a9f257_0.conda
127 | https://repo.anaconda.com/pkgs/main/linux-64/frozenlist-1.4.0-py310h5eee18b_0.conda
128 | https://repo.anaconda.com/pkgs/main/linux-64/fsspec-2023.10.0-py310h06a4308_0.conda
129 | https://repo.anaconda.com/pkgs/main/linux-64/gmpy2-2.1.2-py310heeb90bb_0.conda
130 | https://repo.anaconda.com/pkgs/main/linux-64/grpc-cpp-1.48.2-he1ff14a_1.conda
131 | https://conda.anaconda.org/conda-forge/noarch/hpack-4.0.0-pyh9f0ad1d_0.tar.bz2
132 | https://conda.anaconda.org/conda-forge/noarch/hyperframe-6.0.1-pyhd8ed1ab_0.tar.bz2
133 | https://repo.anaconda.com/pkgs/main/linux-64/idna-3.4-py310h06a4308_0.conda
134 | https://repo.anaconda.com/pkgs/main/linux-64/importlib_resources-6.1.1-py310h06a4308_1.conda
135 | https://conda.anaconda.org/anaconda/linux-64/joblib-1.2.0-py310h06a4308_0.tar.bz2
136 | https://conda.anaconda.org/conda-forge/noarch/json5-0.9.24-pyhd8ed1ab_0.conda
137 | https://conda.anaconda.org/conda-forge/linux-64/jsonpointer-2.4-py310hff52083_3.conda
138 | https://repo.anaconda.com/pkgs/main/linux-64/kiwisolver-1.4.4-py310h6a678d5_0.conda
139 | https://repo.anaconda.com/pkgs/main/linux-64/libboost-1.82.0-h109eef0_2.conda
140 | https://repo.anaconda.com/pkgs/main/linux-64/libclang13-14.0.6-default_he11475f_1.conda
141 | https://repo.anaconda.com/pkgs/main/linux-64/libcurl-8.5.0-h251f7ec_0.conda
142 | https://repo.anaconda.com/pkgs/main/linux-64/libglib-2.78.4-hdc74915_0.conda
143 | https://repo.anaconda.com/pkgs/main/linux-64/libpq-12.17-hdbd6064_0.conda
144 | https://repo.anaconda.com/pkgs/main/linux-64/libtiff-4.5.1-h6a678d5_0.conda
145 | https://repo.anaconda.com/pkgs/main/linux-64/libxkbcommon-1.0.1-h5eee18b_1.conda
146 | https://repo.anaconda.com/pkgs/main/linux-64/markupsafe-2.1.1-py310h7f8727e_0.conda
147 | https://conda.anaconda.org/conda-forge/noarch/mistune-3.0.2-pyhd8ed1ab_0.conda
148 | https://repo.anaconda.com/pkgs/main/linux-64/mkl-2023.1.0-h213fc3f_46344.conda
149 | https://repo.anaconda.com/pkgs/main/linux-64/mpmath-1.3.0-py310h06a4308_0.conda
150 | https://repo.anaconda.com/pkgs/main/linux-64/multidict-6.0.4-py310h5eee18b_0.conda
151 | https://repo.anaconda.com/pkgs/main/noarch/munkres-1.1.4-py_0.conda
152 | https://conda.anaconda.org/conda-forge/noarch/nest-asyncio-1.5.8-pyhd8ed1ab_0.conda
153 | https://repo.anaconda.com/pkgs/main/linux-64/networkx-3.1-py310h06a4308_0.conda
154 | https://repo.anaconda.com/pkgs/main/linux-64/orc-1.7.4-hb3bc3d3_1.conda
155 | https://conda.anaconda.org/conda-forge/noarch/packaging-23.2-pyhd8ed1ab_0.conda
156 | https://conda.anaconda.org/conda-forge/noarch/pandocfilters-1.5.0-pyhd8ed1ab_0.tar.bz2
157 | https://conda.anaconda.org/conda-forge/noarch/parso-0.8.3-pyhd8ed1ab_0.tar.bz2
158 | https://conda.anaconda.org/conda-forge/noarch/pickleshare-0.7.5-py_1003.tar.bz2
159 | https://conda.anaconda.org/conda-forge/noarch/pkgutil-resolve-name-1.3.10-pyhd8ed1ab_1.conda
160 | https://repo.anaconda.com/pkgs/main/linux-64/ply-3.11-py310h06a4308_0.conda
161 | https://conda.anaconda.org/conda-forge/noarch/prometheus_client-0.20.0-pyhd8ed1ab_0.conda
162 | https://conda.anaconda.org/conda-forge/linux-64/psutil-5.9.5-py310h2372a71_1.conda
163 | https://conda.anaconda.org/conda-forge/noarch/ptyprocess-0.7.0-pyhd3deb0d_0.tar.bz2
164 | https://conda.anaconda.org/conda-forge/noarch/pure_eval-0.2.2-pyhd8ed1ab_0.tar.bz2
165 | https://repo.anaconda.com/pkgs/main/noarch/pycparser-2.21-pyhd3eb1b0_0.conda
166 | https://conda.anaconda.org/conda-forge/noarch/pygments-2.17.2-pyhd8ed1ab_0.conda
167 | https://repo.anaconda.com/pkgs/main/linux-64/pyparsing-3.0.9-py310h06a4308_0.conda
168 | https://repo.anaconda.com/pkgs/main/linux-64/pyqt5-sip-12.13.0-py310h5eee18b_0.conda
169 | https://repo.anaconda.com/pkgs/main/linux-64/pysocks-1.7.1-py310h06a4308_0.conda
170 | https://conda.anaconda.org/conda-forge/noarch/python-fastjsonschema-2.19.1-pyhd8ed1ab_0.conda
171 | https://conda.anaconda.org/conda-forge/noarch/python-json-logger-2.0.7-pyhd8ed1ab_0.conda
172 | https://repo.anaconda.com/pkgs/main/noarch/python-tzdata-2023.3-pyhd3eb1b0_0.conda
173 | https://repo.anaconda.com/pkgs/main/linux-64/python-xxhash-2.0.2-py310h5eee18b_1.conda
174 | https://conda.anaconda.org/conda-forge/noarch/pytz-2024.1-pyhd8ed1ab_0.conda
175 | https://repo.anaconda.com/pkgs/main/linux-64/pyyaml-6.0.1-py310h5eee18b_0.conda
176 | https://conda.anaconda.org/conda-forge/linux-64/pyzmq-24.0.1-py310h330234f_1.tar.bz2
177 | https://repo.anaconda.com/pkgs/main/linux-64/regex-2023.10.3-py310h5eee18b_0.conda
178 | https://conda.anaconda.org/conda-forge/noarch/rfc3986-validator-0.1.1-pyh9f0ad1d_0.tar.bz2
179 | https://conda.anaconda.org/conda-forge/linux-64/rpds-py-0.18.0-py310hcb5633a_0.conda
180 | https://repo.anaconda.com/pkgs/main/linux-64/safetensors-0.4.2-py310ha89cbab_0.conda
181 | https://conda.anaconda.org/conda-forge/noarch/send2trash-1.8.2-pyh41d4057_0.conda
182 | https://conda.anaconda.org/conda-forge/noarch/setuptools-68.2.2-pyhd8ed1ab_0.conda
183 | https://conda.anaconda.org/conda-forge/noarch/six-1.16.0-pyh6c4a22f_0.tar.bz2
184 | https://conda.anaconda.org/conda-forge/noarch/sniffio-1.3.1-pyhd8ed1ab_0.conda
185 | https://conda.anaconda.org/conda-forge/noarch/soupsieve-2.5-pyhd8ed1ab_1.conda
186 | https://conda.anaconda.org/anaconda/noarch/threadpoolctl-2.2.0-pyh0d69192_0.tar.bz2
187 | https://repo.anaconda.com/pkgs/main/linux-64/tomli-2.0.1-py310h06a4308_0.conda
188 | https://conda.anaconda.org/conda-forge/linux-64/tornado-6.3.3-py310h2372a71_1.conda
189 | https://repo.anaconda.com/pkgs/main/linux-64/tqdm-4.65.0-py310h2f386ee_0.conda
190 | https://conda.anaconda.org/conda-forge/noarch/traitlets-5.14.0-pyhd8ed1ab_0.conda
191 | https://conda.anaconda.org/conda-forge/noarch/types-python-dateutil-2.9.0.20240316-pyhd8ed1ab_0.conda
192 | https://conda.anaconda.org/conda-forge/noarch/typing_extensions-4.8.0-pyha770c72_0.conda
193 | https://conda.anaconda.org/conda-forge/noarch/typing_utils-0.1.0-pyhd8ed1ab_0.tar.bz2
194 | https://conda.anaconda.org/conda-forge/noarch/uri-template-1.3.0-pyhd8ed1ab_0.conda
195 | https://conda.anaconda.org/conda-forge/noarch/wcwidth-0.2.12-pyhd8ed1ab_0.conda
196 | https://conda.anaconda.org/conda-forge/noarch/webcolors-1.13-pyhd8ed1ab_0.conda
197 | https://conda.anaconda.org/conda-forge/noarch/webencodings-0.5.1-pyhd8ed1ab_2.conda
198 | https://conda.anaconda.org/conda-forge/noarch/websocket-client-1.7.0-pyhd8ed1ab_0.conda
199 | https://conda.anaconda.org/conda-forge/noarch/wheel-0.42.0-pyhd8ed1ab_0.conda
200 | https://conda.anaconda.org/conda-forge/noarch/zipp-3.17.0-pyhd8ed1ab_0.conda
201 | https://repo.anaconda.com/pkgs/main/noarch/aiosignal-1.2.0-pyhd3eb1b0_0.conda
202 | https://conda.anaconda.org/conda-forge/noarch/anyio-4.3.0-pyhd8ed1ab_0.conda
203 | https://conda.anaconda.org/conda-forge/noarch/asttokens-2.4.1-pyhd8ed1ab_0.conda
204 | https://conda.anaconda.org/conda-forge/noarch/async-lru-2.0.4-pyhd8ed1ab_0.conda
205 | https://repo.anaconda.com/pkgs/main/linux-64/aws-c-auth-0.6.19-h5eee18b_0.conda
206 | https://repo.anaconda.com/pkgs/main/linux-64/aws-c-mqtt-0.7.13-h5eee18b_0.conda
207 | https://conda.anaconda.org/conda-forge/noarch/babel-2.14.0-pyhd8ed1ab_0.conda
208 | https://conda.anaconda.org/conda-forge/noarch/beautifulsoup4-4.12.3-pyha770c72_0.conda
209 | https://conda.anaconda.org/conda-forge/noarch/bleach-6.1.0-pyhd8ed1ab_0.conda
210 | https://repo.anaconda.com/pkgs/main/linux-64/boost-cpp-1.82.0-hdb19cb5_2.conda
211 | https://conda.anaconda.org/conda-forge/noarch/cached-property-1.5.2-hd8ed1ab_1.tar.bz2
212 | https://repo.anaconda.com/pkgs/main/linux-64/cffi-1.16.0-py310h5eee18b_0.conda
213 | https://conda.anaconda.org/conda-forge/noarch/comm-0.1.4-pyhd8ed1ab_0.conda
214 | https://conda.anaconda.org/pytorch/linux-64/ffmpeg-4.3-hf484d3e_0.tar.bz2
215 | https://conda.anaconda.org/conda-forge/linux-64/fontconfig-2.14.2-h14ed4e7_0.conda
216 | https://repo.anaconda.com/pkgs/main/noarch/fonttools-4.25.0-pyhd3eb1b0_0.conda
217 | https://repo.anaconda.com/pkgs/main/linux-64/glib-tools-2.78.4-h6a678d5_0.conda
218 | https://conda.anaconda.org/conda-forge/noarch/h11-0.14.0-pyhd8ed1ab_0.tar.bz2
219 | https://conda.anaconda.org/conda-forge/noarch/h2-4.1.0-pyhd8ed1ab_0.tar.bz2
220 | https://conda.anaconda.org/conda-forge/noarch/importlib-metadata-6.9.0-pyha770c72_0.conda
221 | https://conda.anaconda.org/conda-forge/noarch/jedi-0.19.1-pyhd8ed1ab_0.conda
222 | https://repo.anaconda.com/pkgs/main/linux-64/jinja2-3.1.2-py310h06a4308_0.conda
223 | https://conda.anaconda.org/conda-forge/noarch/jupyterlab_pygments-0.3.0-pyhd8ed1ab_1.conda
224 | https://repo.anaconda.com/pkgs/main/linux-64/lcms2-2.12-h3be6417_0.conda
225 | https://repo.anaconda.com/pkgs/main/linux-64/libclang-14.0.6-default_hc6dbbc7_1.conda
226 | https://repo.anaconda.com/pkgs/main/linux-64/libwebp-1.3.2-h11a3e52_0.conda
227 | https://conda.anaconda.org/conda-forge/noarch/matplotlib-inline-0.1.6-pyhd8ed1ab_0.tar.bz2
228 | https://repo.anaconda.com/pkgs/main/linux-64/mkl-service-2.4.0-py310h5eee18b_1.conda
229 | https://repo.anaconda.com/pkgs/main/linux-64/multiprocess-0.70.14-py310h06a4308_0.conda
230 | https://repo.anaconda.com/pkgs/main/linux-64/mysql-5.7.24-h721c034_2.conda
231 | https://conda.anaconda.org/conda-forge/noarch/omegaconf-2.3.0-pyhd8ed1ab_0.conda
232 | https://repo.anaconda.com/pkgs/main/linux-64/openjpeg-2.4.0-h3ad879b_0.conda
233 | https://conda.anaconda.org/conda-forge/noarch/overrides-7.7.0-pyhd8ed1ab_0.conda
234 | https://conda.anaconda.org/conda-forge/noarch/pexpect-4.8.0-pyh1a96a4e_2.tar.bz2
235 | https://conda.anaconda.org/conda-forge/noarch/pip-23.3.1-pyhd8ed1ab_0.conda
236 | https://conda.anaconda.org/conda-forge/noarch/platformdirs-4.0.0-pyhd8ed1ab_0.conda
237 | https://conda.anaconda.org/conda-forge/noarch/prompt-toolkit-3.0.41-pyha770c72_0.conda
238 | https://conda.anaconda.org/conda-forge/noarch/python-dateutil-2.8.2-pyhd8ed1ab_0.tar.bz2
239 | https://conda.anaconda.org/conda-forge/noarch/referencing-0.34.0-pyhd8ed1ab_0.conda
240 | https://conda.anaconda.org/conda-forge/noarch/rfc3339-validator-0.1.4-pyhd8ed1ab_0.tar.bz2
241 | https://repo.anaconda.com/pkgs/main/linux-64/sip-6.7.12-py310h6a678d5_0.conda
242 | https://conda.anaconda.org/conda-forge/noarch/sympy-1.12-pypyh9d50eac_103.conda
243 | https://conda.anaconda.org/conda-forge/noarch/terminado-0.18.1-pyh0d859eb_0.conda
244 | https://conda.anaconda.org/conda-forge/noarch/tinycss2-1.2.1-pyhd8ed1ab_0.tar.bz2
245 | https://conda.anaconda.org/conda-forge/noarch/typing-extensions-4.8.0-hd8ed1ab_0.conda
246 | https://repo.anaconda.com/pkgs/main/linux-64/yarl-1.9.3-py310h5eee18b_0.conda
247 | https://repo.anaconda.com/pkgs/main/linux-64/aiohttp-3.9.3-py310h5eee18b_0.conda
248 | https://conda.anaconda.org/conda-forge/linux-64/argon2-cffi-bindings-21.2.0-py310h2372a71_4.conda
249 | https://conda.anaconda.org/conda-forge/noarch/arrow-1.3.0-pyhd8ed1ab_0.conda
250 | https://repo.anaconda.com/pkgs/main/linux-64/aws-c-s3-0.1.51-hdbd6064_0.conda
251 | https://repo.anaconda.com/pkgs/main/linux-64/cryptography-41.0.3-py310hdda0065_0.conda
252 | https://conda.anaconda.org/conda-forge/noarch/fqdn-1.5.1-pyhd8ed1ab_0.tar.bz2
253 | https://repo.anaconda.com/pkgs/main/linux-64/glib-2.78.4-h6a678d5_0.conda
254 | https://conda.anaconda.org/conda-forge/noarch/httpcore-1.0.5-pyhd8ed1ab_0.conda
255 | https://conda.anaconda.org/conda-forge/noarch/hydra-core-1.3.2-pyhd8ed1ab_0.conda
256 | https://conda.anaconda.org/conda-forge/noarch/importlib_metadata-6.9.0-hd8ed1ab_0.conda
257 | https://conda.anaconda.org/conda-forge/noarch/jsonschema-specifications-2023.12.1-pyhd8ed1ab_0.conda
258 | https://conda.anaconda.org/conda-forge/linux-64/jupyter_core-5.5.0-py310hff52083_0.conda
259 | https://conda.anaconda.org/conda-forge/noarch/jupyter_server_terminals-0.5.3-pyhd8ed1ab_0.conda
260 | https://repo.anaconda.com/pkgs/main/linux-64/libthrift-0.15.0-h1795dd8_2.conda
261 | https://repo.anaconda.com/pkgs/main/linux-64/numpy-base-1.26.4-py310hb5e798b_0.conda
262 | https://repo.anaconda.com/pkgs/main/linux-64/pillow-10.0.1-py310ha6cbd5a_0.conda
263 | https://conda.anaconda.org/conda-forge/noarch/stack_data-0.6.2-pyhd8ed1ab_0.conda
264 | https://conda.anaconda.org/conda-forge/noarch/argon2-cffi-23.1.0-pyhd8ed1ab_0.conda
265 | https://repo.anaconda.com/pkgs/main/linux-64/aws-crt-cpp-0.18.16-h6a678d5_0.conda
266 | https://repo.anaconda.com/pkgs/main/linux-64/dbus-1.13.18-hb2f20db_0.conda
267 | https://repo.anaconda.com/pkgs/main/linux-64/gstreamer-1.14.1-h5eee18b_1.conda
268 | https://conda.anaconda.org/conda-forge/noarch/httpx-0.27.0-pyhd8ed1ab_0.conda
269 | https://conda.anaconda.org/conda-forge/noarch/ipython-8.18.1-pyh31011fe_1.conda
270 | https://conda.anaconda.org/conda-forge/noarch/isoduration-20.11.0-pyhd8ed1ab_0.tar.bz2
271 | https://conda.anaconda.org/conda-forge/noarch/jsonschema-4.21.1-pyhd8ed1ab_0.conda
272 | https://conda.anaconda.org/conda-forge/noarch/jupyter_client-7.4.9-pyhd8ed1ab_0.conda
273 | https://repo.anaconda.com/pkgs/main/linux-64/pyopenssl-23.2.0-py310h06a4308_0.conda
274 | https://repo.anaconda.com/pkgs/main/linux-64/aws-sdk-cpp-1.10.55-h721c034_0.conda
275 | https://repo.anaconda.com/pkgs/main/linux-64/gst-plugins-base-1.14.1-h6a678d5_1.conda
276 | https://conda.anaconda.org/conda-forge/noarch/ipykernel-6.26.0-pyhf8b6a83_0.conda
277 | https://conda.anaconda.org/conda-forge/noarch/jsonschema-with-format-nongpl-4.21.1-pyhd8ed1ab_0.conda
278 | https://conda.anaconda.org/conda-forge/noarch/nbformat-5.10.4-pyhd8ed1ab_0.conda
279 | https://repo.anaconda.com/pkgs/main/linux-64/urllib3-1.26.18-py310h06a4308_0.conda
280 | https://repo.anaconda.com/pkgs/main/linux-64/arrow-cpp-14.0.2-h374c478_1.conda
281 | https://conda.anaconda.org/conda-forge/noarch/jupyter_events-0.10.0-pyhd8ed1ab_0.conda
282 | https://conda.anaconda.org/conda-forge/noarch/nbclient-0.10.0-pyhd8ed1ab_0.conda
283 | https://repo.anaconda.com/pkgs/main/linux-64/qt-main-5.15.2-h53bd1ea_10.conda
284 | https://repo.anaconda.com/pkgs/main/linux-64/requests-2.31.0-py310h06a4308_0.conda
285 | https://repo.anaconda.com/pkgs/main/linux-64/huggingface_hub-0.20.3-py310h06a4308_0.conda
286 | https://conda.anaconda.org/conda-forge/noarch/nbconvert-core-7.16.3-pyhd8ed1ab_0.conda
287 | https://repo.anaconda.com/pkgs/main/linux-64/pyqt-5.15.10-py310h6a678d5_0.conda
288 | https://repo.anaconda.com/pkgs/main/noarch/responses-0.13.3-pyhd3eb1b0_0.conda
289 | https://conda.anaconda.org/conda-forge/noarch/jupyter_server-2.13.0-pyhd8ed1ab_0.conda
290 | https://repo.anaconda.com/pkgs/main/linux-64/tokenizers-0.15.1-py310h22610ee_0.conda
291 | https://conda.anaconda.org/conda-forge/noarch/jupyter-lsp-2.2.4-pyhd8ed1ab_0.conda
292 | https://conda.anaconda.org/conda-forge/noarch/jupyterlab_server-2.25.4-pyhd8ed1ab_0.conda
293 | https://conda.anaconda.org/conda-forge/noarch/notebook-shim-0.2.4-pyhd8ed1ab_0.conda
294 | https://conda.anaconda.org/conda-forge/noarch/jupyterlab-4.1.5-pyhd8ed1ab_0.conda
295 | https://conda.anaconda.org/conda-forge/noarch/notebook-7.1.2-pyhd8ed1ab_0.conda
296 | https://conda.anaconda.org/conda-forge/linux-64/ml_dtypes-0.4.0-py310hcc13569_0.conda
297 | https://repo.anaconda.com/pkgs/main/linux-64/bottleneck-1.3.7-py310ha9d4c09_0.conda
298 | https://repo.anaconda.com/pkgs/main/linux-64/contourpy-1.2.0-py310hdb19cb5_0.conda
299 | https://repo.anaconda.com/pkgs/main/linux-64/matplotlib-3.8.0-py310h06a4308_0.conda
300 | https://repo.anaconda.com/pkgs/main/linux-64/matplotlib-base-3.8.0-py310h1128e8f_0.conda
301 | https://repo.anaconda.com/pkgs/main/linux-64/mkl_fft-1.3.8-py310h5eee18b_0.conda
302 | https://repo.anaconda.com/pkgs/main/linux-64/mkl_random-1.2.4-py310hdb19cb5_0.conda
303 | https://repo.anaconda.com/pkgs/main/linux-64/numpy-1.26.4-py310h5f9d8c6_0.conda
304 | https://repo.anaconda.com/pkgs/main/linux-64/numexpr-2.8.7-py310h85018f9_0.conda
305 | https://repo.anaconda.com/pkgs/main/linux-64/pyarrow-14.0.2-py310h1eedbd7_0.conda
306 | https://repo.anaconda.com/pkgs/main/linux-64/scipy-1.11.4-py310h5f9d8c6_0.conda
307 | https://repo.anaconda.com/pkgs/main/linux-64/transformers-4.37.2-py310h06a4308_0.conda
308 | https://repo.anaconda.com/pkgs/main/linux-64/pandas-2.2.1-py310h6a678d5_0.conda
309 | https://conda.anaconda.org/anaconda/linux-64/scikit-learn-1.3.0-py310h1128e8f_0.tar.bz2
310 | https://repo.anaconda.com/pkgs/main/linux-64/datasets-2.12.0-py310h06a4308_0.conda
311 | https://conda.anaconda.org/conda-forge/noarch/accelerate-0.27.0-pyhd8ed1ab_0.conda
312 | https://conda.anaconda.org/pytorch/linux-64/pytorch-2.1.1-py3.10_cuda11.8_cudnn8.7.0_0.tar.bz2
313 | https://conda.anaconda.org/pytorch/linux-64/torchaudio-2.1.1-py310_cu118.tar.bz2
314 | https://conda.anaconda.org/pytorch/linux-64/torchtriton-2.1.0-py310.tar.bz2
315 | https://conda.anaconda.org/pytorch/linux-64/torchvision-0.16.1-py310_cu118.tar.bz2
316 |
--------------------------------------------------------------------------------
/grok/config.json:
--------------------------------------------------------------------------------
1 | {
2 | "architectures": [
3 | "Grok1ModelForCausalLM"
4 | ],
5 | "auto_map": {
6 | "AutoConfig": "configuration_grok1.Grok1Config",
7 | "AutoModel": "modeling_grok1.Grok1Model",
8 | "AutoModelForCausalLM": "modeling_grok1.Grok1ModelForCausalLM"
9 | },
10 | "vocab_size": 131072,
11 | "hidden_size": 6144,
12 | "intermediate_size": 32768,
13 | "num_hidden_layers": 64,
14 | "num_attention_heads": 48,
15 | "num_key_value_heads": 8,
16 | "attn_output_multiplier": 0.08838834764831845,
17 | "embedding_multiplier_scale": 78.38367176906169,
18 | "output_multiplier_scale": 0.5773502691896257,
19 | "max_attn_value": 30.0,
20 | "max_position_embeddings": 8192,
21 | "rms_norm_eps": 1e-5,
22 | "use_cache": true,
23 | "pad_token_id": 0,
24 | "bos_token_id": 1,
25 | "eos_token_id": 2,
26 | "tie_word_embeddings": true,
27 | "num_experts_per_tok": 2,
28 | "num_experts": 8,
29 | "output_router_logits": false,
30 | "router_aux_loss_coef": 0.001,
31 | "torch_dtype": "bfloat16",
32 | "transformers_version": "4.35.0"
33 | }
--------------------------------------------------------------------------------
/grok/configuration_grok1.py:
--------------------------------------------------------------------------------
1 | from transformers.configuration_utils import PretrainedConfig
2 |
3 |
4 | class Grok1Config(PretrainedConfig):
5 | model_type = "grok-1"
6 | keys_to_ignore_at_inference = ["past_key_values"]
7 |
8 | def __init__(
9 | self,
10 | vocab_size=32000,
11 | hidden_size=4096,
12 | intermediate_size=32768,
13 | num_hidden_layers=32,
14 | num_attention_heads=32,
15 | num_key_value_heads=32,
16 | attn_output_multiplier=1.0,
17 | max_attn_value=1.0,
18 | max_position_embeddings=4096,
19 | embedding_multiplier_scale: float = 1.0,
20 | output_multiplier_scale: float = 1.0,
21 | rms_norm_eps=1e-5,
22 | use_cache=True,
23 | pad_token_id=None,
24 | bos_token_id=1,
25 | eos_token_id=2,
26 | tie_word_embeddings=True,
27 | num_experts_per_tok=2,
28 | num_experts=8,
29 | output_router_logits=False,
30 | router_aux_loss_coef=0.001,
31 | **kwargs
32 | ):
33 | self.vocab_size = vocab_size
34 | self.attn_output_multiplier = attn_output_multiplier
35 | self.max_attn_value = max_attn_value
36 | self.max_position_embeddings = max_position_embeddings
37 | self.embedding_multiplier_scale = embedding_multiplier_scale
38 | self.output_multiplier_scale = output_multiplier_scale
39 | self.hidden_size = hidden_size
40 | self.intermediate_size = intermediate_size
41 | self.num_hidden_layers = num_hidden_layers
42 | self.num_attention_heads = num_attention_heads
43 |
44 | # for backward compatibility
45 | if num_key_value_heads is None:
46 | num_key_value_heads = num_attention_heads
47 |
48 | self.num_key_value_heads = num_key_value_heads
49 | self.rms_norm_eps = rms_norm_eps
50 | self.use_cache = use_cache
51 |
52 | self.num_experts_per_tok = num_experts_per_tok
53 | self.num_experts = num_experts
54 | self.output_router_logits = output_router_logits
55 | self.router_aux_loss_coef = router_aux_loss_coef
56 | super().__init__(
57 | pad_token_id=pad_token_id,
58 | bos_token_id=bos_token_id,
59 | eos_token_id=eos_token_id,
60 | tie_word_embeddings=tie_word_embeddings,
61 | **kwargs,
62 | )
63 |
--------------------------------------------------------------------------------
/grok/modeling_grok1.py:
--------------------------------------------------------------------------------
1 | from typing import List, Optional, Tuple, Union
2 |
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | from transformers.modeling_utils import PreTrainedModel
7 | from transformers.utils import logging
8 |
9 | try:
10 | from transformers.modeling_attn_mask_utils import \
11 | _prepare_4d_causal_attention_mask
12 |
13 | HAS_MASK_UTILS = True
14 | except ImportError:
15 | HAS_MASK_UTILS = False
16 |
17 | from .configuration_grok1 import Grok1Config
18 | from .modeling_grok1_outputs import (MoeCausalLMOutputWithPast,
19 | MoeModelOutputWithPast)
20 |
21 | logger = logging.get_logger(__name__)
22 |
23 |
24 | # copied from https://github.com/huggingface/transformers/blob/v4.36.1/src/transformers/models/mixtral/modeling_mixtral.py
25 | def load_balancing_loss_func(
26 | gate_logits: torch.Tensor, num_experts: int = None, top_k=2
27 | ) -> float:
28 | r"""
29 | Computes auxiliary load balancing loss as in Switch Transformer - implemented in Pytorch.
30 |
31 | See Switch Transformer (https://arxiv.org/abs/2101.03961) for more details. This function implements the loss
32 | function presented in equations (4) - (6) of the paper. It aims at penalizing cases where the routing between
33 | experts is too unbalanced.
34 |
35 | Args:
36 | gate_logits (Union[`torch.Tensor`, Tuple[torch.Tensor]]):
37 | Logits from the `gate`, should be a tuple of tensors. Shape: [batch_size, sequence_length, num_experts].
38 | num_experts (`int`, *optional*):
39 | Number of experts
40 |
41 | Returns:
42 | The auxiliary loss.
43 | """
44 | if gate_logits is None:
45 | return 0
46 |
47 | if isinstance(gate_logits, tuple):
48 | # cat along the layers?
49 | compute_device = gate_logits[0].device
50 | gate_logits = torch.cat(
51 | [gate.to(compute_device) for gate in gate_logits], dim=0
52 | )
53 |
54 | routing_weights, selected_experts = torch.topk(gate_logits, top_k, dim=-1)
55 | routing_weights = routing_weights.softmax(dim=-1)
56 |
57 | # cast the expert indices to int64, otherwise one-hot encoding will fail
58 | if selected_experts.dtype != torch.int64:
59 | selected_experts = selected_experts.to(torch.int64)
60 |
61 | if len(selected_experts.shape) == 2:
62 | selected_experts = selected_experts.unsqueeze(2)
63 |
64 | expert_mask = torch.nn.functional.one_hot(selected_experts, num_experts)
65 |
66 | # For a given token, determine if it was routed to a given expert.
67 | expert_mask = torch.max(expert_mask, axis=-2).values
68 |
69 | # cast to float32 otherwise mean will fail
70 | expert_mask = expert_mask.to(torch.float32)
71 | tokens_per_group_and_expert = torch.mean(expert_mask, axis=-2)
72 |
73 | router_prob_per_group_and_expert = torch.mean(routing_weights, axis=-1)
74 | return torch.mean(
75 | tokens_per_group_and_expert * router_prob_per_group_and_expert.unsqueeze(-1)
76 | ) * (num_experts**2)
77 |
78 |
79 | # Copied from transformers.models.llama.modeling_llama.repeat_kv
80 | def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
81 | """
82 | This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
83 | num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
84 | """
85 | batch, num_key_value_heads, slen, head_dim = hidden_states.shape
86 | if n_rep == 1:
87 | return hidden_states
88 | hidden_states = hidden_states[:, :, None, :, :].expand(
89 | batch, num_key_value_heads, n_rep, slen, head_dim
90 | )
91 | return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
92 |
93 |
94 | class RMSNorm(nn.Module):
95 | def __init__(
96 | self,
97 | hidden_size: int,
98 | eps: float = 1e-5,
99 | create_scale: bool = True,
100 | ) -> None:
101 | super().__init__()
102 | self.variance_epsilon = eps
103 | if create_scale:
104 | self.scale = nn.Parameter(torch.zeros(hidden_size))
105 | else:
106 | self.scale = 1.0
107 |
108 | def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
109 | input_dtype = hidden_states.dtype
110 | hidden_states = hidden_states.to(torch.float32)
111 | variance = hidden_states.pow(2).mean(-1, keepdim=True)
112 | hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
113 | hidden_states = self.scale * hidden_states
114 | return hidden_states.to(input_dtype)
115 |
116 |
117 | class RotaryEmbedding(nn.Module):
118 | def __init__(
119 | self, dim: int, max_position_embeddings: int = 2048, base: int = 10000
120 | ) -> None:
121 | super().__init__()
122 | assert dim % 2 == 0
123 | self.dim = dim
124 | self.max_position_embeddings = max_position_embeddings
125 | self.base = base
126 | inv_freq = 1.0 / (
127 | self.base ** (torch.arange(0, self.dim, 2).float() / self.dim)
128 | )
129 | self.register_buffer("inv_freq", inv_freq, persistent=False)
130 |
131 | self._set_cos_sin_cache(
132 | seq_len=max_position_embeddings,
133 | device=self.inv_freq.device,
134 | dtype=torch.get_default_dtype(),
135 | )
136 |
137 | def _set_cos_sin_cache(self, seq_len, device, dtype):
138 | self.max_seq_len_cached = seq_len
139 | t = torch.arange(
140 | self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype
141 | )
142 |
143 | freqs = torch.outer(t, self.inv_freq)
144 | # Different from paper, but it uses a different permutation in order to obtain the same calculation
145 | emb = torch.cat((freqs, freqs), dim=-1)
146 | self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
147 | self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
148 |
149 | def forward(self, x, seq_len=None):
150 | # x: [bs, num_attention_heads, seq_len, head_size]
151 | if seq_len > self.max_seq_len_cached:
152 | self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
153 |
154 | return (
155 | self.cos_cached[:seq_len].to(dtype=x.dtype),
156 | self.sin_cached[:seq_len].to(dtype=x.dtype),
157 | )
158 |
159 |
160 | # Copied from transformers.models.llama.modeling_llama.rotate_half
161 | def rotate_half(x):
162 | """Rotates half the hidden dims of the input."""
163 | x1 = x[..., : x.shape[-1] // 2]
164 | x2 = x[..., x.shape[-1] // 2 :]
165 | return torch.cat((-x2, x1), dim=-1)
166 |
167 |
168 | # Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
169 | def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
170 | """Applies Rotary Position Embedding to the query and key tensors.
171 |
172 | Args:
173 | q (`torch.Tensor`): The query tensor.
174 | k (`torch.Tensor`): The key tensor.
175 | cos (`torch.Tensor`): The cosine part of the rotary embedding.
176 | sin (`torch.Tensor`): The sine part of the rotary embedding.
177 | position_ids (`torch.Tensor`):
178 | The position indices of the tokens corresponding to the query and key tensors. For example, this can be
179 | used to pass offsetted position ids when working with a KV-cache.
180 | unsqueeze_dim (`int`, *optional*, defaults to 1):
181 | The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
182 | sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
183 | that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
184 | k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
185 | cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
186 | the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
187 | Returns:
188 | `tuple(torch.Tensor)` comprising the query and key tensors rotated using the Rotary Position Embedding.
189 | """
190 | cos = cos[position_ids].unsqueeze(unsqueeze_dim)
191 | sin = sin[position_ids].unsqueeze(unsqueeze_dim)
192 | q_embed = (q * cos) + (rotate_half(q) * sin)
193 | k_embed = (k * cos) + (rotate_half(k) * sin)
194 | return q_embed, k_embed
195 |
196 |
197 | class MultiHeadAttention(nn.Module):
198 | def __init__(
199 | self,
200 | hidden_size: int,
201 | num_heads: int,
202 | num_key_value_heads: Optional[int] = None,
203 | max_position_embeddings: int = 2048,
204 | attn_output_multiplier: float = 1.0,
205 | max_attn_val: float = 30.0,
206 | ):
207 | super().__init__()
208 | self.hidden_size = hidden_size
209 | self.num_heads = num_heads
210 | self.head_dim = hidden_size // num_heads
211 | if num_key_value_heads is None:
212 | num_key_value_heads = num_heads
213 | self.num_key_value_heads = num_key_value_heads
214 | self.num_key_value_groups = self.num_heads // self.num_key_value_heads
215 | self.attn_output_multiplier = attn_output_multiplier
216 | self.max_attn_val = max_attn_val
217 |
218 | if (self.head_dim * self.num_heads) != self.hidden_size:
219 | raise ValueError(
220 | f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
221 | f" and `num_heads`: {self.num_heads})."
222 | )
223 |
224 | self.q_proj = nn.Linear(hidden_size, self.num_heads * self.head_dim, bias=False)
225 | self.k_proj = nn.Linear(
226 | hidden_size, self.num_key_value_heads * self.head_dim, bias=False
227 | )
228 | self.v_proj = nn.Linear(
229 | hidden_size, self.num_key_value_heads * self.head_dim, bias=False
230 | )
231 | self.o_proj = nn.Linear(self.num_heads * self.head_dim, hidden_size, bias=False)
232 |
233 | self.rotary_emb = RotaryEmbedding(
234 | self.head_dim,
235 | max_position_embeddings=max_position_embeddings,
236 | )
237 |
238 | def forward(
239 | self,
240 | hidden_states: torch.Tensor,
241 | attention_mask: Optional[torch.Tensor] = None,
242 | position_ids: Optional[torch.LongTensor] = None,
243 | past_key_value: Optional[Tuple[torch.Tensor]] = None,
244 | output_attentions: bool = False,
245 | use_cache: bool = False,
246 | **kwargs,
247 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
248 | bsz, q_len, _ = hidden_states.size()
249 |
250 | query_states = self.q_proj(hidden_states)
251 | key_states = self.k_proj(hidden_states)
252 | value_states = self.v_proj(hidden_states)
253 |
254 | query_states = query_states.view(
255 | bsz, q_len, self.num_heads, self.head_dim
256 | ).transpose(1, 2)
257 | key_states = key_states.view(
258 | bsz, q_len, self.num_key_value_heads, self.head_dim
259 | ).transpose(1, 2)
260 | value_states = value_states.view(
261 | bsz, q_len, self.num_key_value_heads, self.head_dim
262 | ).transpose(1, 2)
263 |
264 | kv_seq_len = key_states.shape[-2]
265 | if past_key_value is not None:
266 | kv_seq_len += past_key_value[0].shape[-2]
267 |
268 | cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
269 | query_states, key_states = apply_rotary_pos_emb(
270 | query_states, key_states, cos, sin, position_ids
271 | )
272 |
273 | if past_key_value is not None:
274 | # reuse k, v, self_attention
275 | key_states = torch.cat([past_key_value[0], key_states], dim=2)
276 | value_states = torch.cat([past_key_value[1], value_states], dim=2)
277 |
278 | past_key_value = (key_states, value_states) if use_cache else None
279 |
280 | # repeat k/v heads if n_kv_heads < n_heads
281 | key_states = repeat_kv(key_states, self.num_key_value_groups)
282 | value_states = repeat_kv(value_states, self.num_key_value_groups)
283 |
284 | attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)).to(
285 | torch.float
286 | )
287 | attn_weights = attn_weights * self.attn_output_multiplier
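# Grok-1 soft-caps the attention logits: max_attn_val * tanh(x / max_attn_val) bounds every logit to (-max_attn_val, max_attn_val).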
288 | attn_weights = self.max_attn_val * torch.tanh(attn_weights / self.max_attn_val)
289 |
290 | if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
291 | raise ValueError(
292 | f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
293 | f" {attn_weights.size()}"
294 | )
295 |
296 | if attention_mask is not None:
297 | if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
298 | raise ValueError(
299 | f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
300 | )
301 |
302 | attn_weights = attn_weights + attention_mask
303 |
304 | attn_weights = F.softmax(attn_weights, dim=-1).to(query_states.dtype)
305 | attn_output = torch.matmul(attn_weights, value_states)
306 |
307 | if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
308 | raise ValueError(
309 | f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
310 | f" {attn_output.size()}"
311 | )
312 |
313 | attn_output = attn_output.transpose(1, 2).contiguous()
314 | attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
315 |
316 | attn_output = self.o_proj(attn_output)
317 |
318 | if not output_attentions:
319 | attn_weights = None
320 |
321 | return attn_output, attn_weights, past_key_value
322 |
323 |
324 | class MoeMLP(nn.Module):
325 | def __init__(
326 | self,
327 | hidden_dim: int,
328 | ffn_dim: int,
329 | ) -> None:
330 | super().__init__()
331 | self.linear_v = nn.Linear(hidden_dim, ffn_dim, bias=False)
332 | self.linear_1 = nn.Linear(ffn_dim, hidden_dim, bias=False)
333 | self.linear = nn.Linear(hidden_dim, ffn_dim, bias=False)
334 | self.act_fn = nn.GELU()
335 |
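# GELU-gated linear unit: the forward pass computes linear_1(GELU(linear(x)) * linear_v(x)).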
336 | def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
337 | current_hidden_states = self.act_fn(self.linear(hidden_states)) * self.linear_v(
338 | hidden_states
339 | )
340 | current_hidden_states = self.linear_1(current_hidden_states)
341 | return current_hidden_states
342 |
343 |
344 | class MoeBlock(nn.Module):
345 | def __init__(
346 | self,
347 | hidden_dim: int,
348 | ffn_dim: int,
349 | num_experts: int,
350 | top_k: int,
351 | ) -> None:
352 | super().__init__()
353 | self.num_experts = num_experts
354 | self.top_k = top_k
355 | self.gate = nn.Linear(hidden_dim, num_experts, bias=False)
356 | self.experts = nn.ModuleList(
357 | [MoeMLP(hidden_dim, ffn_dim) for _ in range(num_experts)]
358 | )
359 |
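# Each token is routed to its top-k experts; the block output is the gate-weighted sum of the
# selected experts' outputs, y = sum_i w_i * expert_i(x) over the top-k experts.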
360 | def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor]:
361 | batch_size, sequence_length, hidden_dim = hidden_states.shape
362 | hidden_states = hidden_states.view(-1, hidden_dim)
363 | # router_logits: (batch * sequence_length, n_experts)
364 | router_logits = self.gate(hidden_states)
365 |
366 | routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
367 | routing_weights, selected_experts = torch.topk(
368 | routing_weights, self.top_k, dim=-1
369 | )
370 | # we cast back to the input dtype
371 | routing_weights = routing_weights.to(hidden_states.dtype)
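# Note: the softmax is taken over all experts before the top-k selection and the selected
# weights are not renormalized, so they need not sum to 1 (the reference Hugging Face Mixtral
# block renormalizes the top-k weights).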
372 |
373 | final_hidden_states = torch.zeros(
374 | (batch_size * sequence_length, hidden_dim),
375 | dtype=hidden_states.dtype,
376 | device=hidden_states.device,
377 | )
378 | # One-hot encode the selected experts to create an expert mask;
379 | # this will be used to easily index which expert is going to be solicited
380 | expert_mask = torch.nn.functional.one_hot(
381 | selected_experts, num_classes=self.num_experts
382 | ).permute(2, 1, 0)
383 |
384 | # Loop over all available experts in the model and perform the computation on each expert
385 | for expert_idx in range(self.num_experts):
386 | expert_layer = self.experts[expert_idx]
387 | idx, top_x = torch.where(expert_mask[expert_idx])
388 |
389 | if top_x.shape[0] == 0:
390 | continue
391 |
392 | # in torch it is faster to index using lists than torch tensors
393 | top_x_list = top_x.tolist()
394 | idx_list = idx.tolist()
395 |
396 | # Index the correct hidden states and compute the expert hidden state for
397 | # the current expert. We need to make sure to multiply the output hidden
398 | # states by `routing_weights` on the corresponding tokens (top-1 and top-2)
399 | current_state = hidden_states[None, top_x_list].reshape(-1, hidden_dim)
400 | current_hidden_states = (
401 | expert_layer(current_state)
402 | * routing_weights[top_x_list, idx_list, None]
403 | )
404 |
405 | # However `index_add_` only supports torch tensors for indexing, so we'll use
406 | # the `top_x` tensor here.
407 | final_hidden_states.index_add_(
408 | 0, top_x, current_hidden_states.to(hidden_states.dtype)
409 | )
410 | final_hidden_states = final_hidden_states.reshape(
411 | batch_size, sequence_length, hidden_dim
412 | )
413 | return final_hidden_states, router_logits
414 |
415 |
416 | class DecoderLayer(nn.Module):
417 | def __init__(
418 | self,
419 | hidden_size: int,
420 | intermediate_size: int,
421 | num_heads: int,
422 | num_key_value_heads: int,
423 | num_experts: int,
424 | top_k: int,
425 | max_position_embeddings: int = 2048,
426 | attn_output_multiplier: float = 1.0,
427 | max_attn_val: float = 30.0,
428 | rms_norm_eps: float = 1e-5,
429 | ) -> None:
430 | super().__init__()
431 | self.attn = MultiHeadAttention(
432 | hidden_size,
433 | num_heads,
434 | num_key_value_heads,
435 | max_position_embeddings=max_position_embeddings,
436 | attn_output_multiplier=attn_output_multiplier,
437 | max_attn_val=max_attn_val,
438 | )
439 | self.moe_block = MoeBlock(hidden_size, intermediate_size, num_experts, top_k)
440 | self.pre_attn_norm = RMSNorm(hidden_size, eps=rms_norm_eps)
441 | self.post_attn_norm = RMSNorm(hidden_size, eps=rms_norm_eps)
442 | self.pre_moe_norm = RMSNorm(hidden_size, eps=rms_norm_eps)
443 | self.post_moe_norm = RMSNorm(hidden_size, eps=rms_norm_eps)
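# Grok-1 uses "sandwich" normalization: RMSNorm is applied both before and after the attention
# and MoE sub-layers, and the post-norms are applied to each sub-layer output before the
# residual addition (see forward below).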
444 |
445 | def forward(
446 | self,
447 | hidden_states: torch.Tensor,
448 | attention_mask: Optional[torch.Tensor] = None,
449 | position_ids: Optional[torch.LongTensor] = None,
450 | past_key_value: Optional[Tuple[torch.Tensor]] = None,
451 | output_attentions: Optional[bool] = False,
452 | output_router_logits: Optional[bool] = False,
453 | use_cache: Optional[bool] = False,
454 | **kwargs,
455 | ) -> Tuple[
456 | torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]
457 | ]:
458 | residual = hidden_states
459 | hidden_states = self.pre_attn_norm(hidden_states)
460 | hidden_states, attention_weights, present_key_value = self.attn(
461 | hidden_states,
462 | attention_mask=attention_mask,
463 | position_ids=position_ids,
464 | past_key_value=past_key_value,
465 | output_attentions=output_attentions,
466 | use_cache=use_cache,
467 | )
468 | hidden_states = self.post_attn_norm(hidden_states)
469 | hidden_states = residual + hidden_states
470 |
471 | residual = hidden_states
472 | hidden_states = self.pre_moe_norm(hidden_states)
473 | hidden_states, router_logits = self.moe_block(hidden_states)
474 | hidden_states = self.post_moe_norm(hidden_states)
475 | hidden_states = residual + hidden_states
476 |
477 | outputs = (hidden_states,)
478 | if output_attentions:
479 | outputs += (attention_weights,)
480 | if use_cache:
481 | outputs += (present_key_value,)
482 | if output_router_logits:
483 | outputs += (router_logits,)
484 | return outputs
485 |
486 |
487 | class Grok1PretrainedModel(PreTrainedModel):
488 | config_class = Grok1Config
489 | base_model_prefix = "model"
490 | supports_gradient_checkpointing = True
491 | _no_split_modules = ["DecoderLayer"]
492 | _skip_keys_device_placement = "past_key_values"
493 | _supports_flash_attn_2 = False
494 | _supports_cache_class = False
495 |
496 | def _init_weights(self, module) -> None:
497 | if isinstance(module, nn.Linear):
498 | module.weight.data.zero_()
499 | if module.bias is not None:
500 | module.bias.data.zero_()
501 | elif isinstance(module, nn.Embedding):
502 | module.weight.data.zero_()
503 |
504 |
505 | # Copied from transformers.models.bart.modeling_bart._make_causal_mask
506 | def _make_causal_mask(
507 | input_ids_shape: torch.Size,
508 | dtype: torch.dtype,
509 | device: torch.device,
510 | past_key_values_length: int = 0,
511 | ):
512 | """
513 | Make causal mask used for bi-directional self-attention.
514 | """
515 | bsz, tgt_len = input_ids_shape
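# Fill with the dtype's most negative value, then zero out the lower triangle (j <= i) so each
# position attends only to itself and earlier positions; the remaining entries act as -inf under softmax.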
516 | mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)
517 | mask_cond = torch.arange(mask.size(-1), device=device)
518 | mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
519 | mask = mask.to(dtype)
520 |
521 | if past_key_values_length > 0:
522 | mask = torch.cat(
523 | [
524 | torch.zeros(
525 | tgt_len, past_key_values_length, dtype=dtype, device=device
526 | ),
527 | mask,
528 | ],
529 | dim=-1,
530 | )
531 | return mask[None, None, :, :].expand(
532 | bsz, 1, tgt_len, tgt_len + past_key_values_length
533 | )
534 |
535 |
536 | # Copied from transformers.models.bart.modeling_bart._expand_mask
537 | def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
538 | """
539 | Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
540 | """
541 | bsz, src_len = mask.size()
542 | tgt_len = tgt_len if tgt_len is not None else src_len
543 |
544 | expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
545 |
546 | inverted_mask = 1.0 - expanded_mask
547 |
548 | return inverted_mask.masked_fill(
549 | inverted_mask.to(torch.bool), torch.finfo(dtype).min
550 | )
551 |
552 |
553 | class Grok1Model(Grok1PretrainedModel):
554 | def __init__(self, config: Grok1Config, **kwargs) -> None:
555 | super().__init__(config)
556 | self.padding_idx = config.pad_token_id
557 | self.vocab_size = config.vocab_size
558 | self.embedding_multiplier_scale = config.embedding_multiplier_scale
559 |
560 | self.embed_tokens = nn.Embedding(
561 | config.vocab_size, config.hidden_size, self.padding_idx
562 | )
563 | self.layers = nn.ModuleList(
564 | [
565 | DecoderLayer(
566 | hidden_size=config.hidden_size,
567 | intermediate_size=config.intermediate_size,
568 | num_heads=config.num_attention_heads,
569 | num_key_value_heads=config.num_key_value_heads,
570 | num_experts=config.num_experts,
571 | top_k=config.num_experts_per_tok,
572 | max_position_embeddings=config.max_position_embeddings,
573 | attn_output_multiplier=config.attn_output_multiplier,
574 | max_attn_val=config.max_attn_value,
575 | rms_norm_eps=config.rms_norm_eps,
576 | )
577 | for layer_idx in range(config.num_hidden_layers)
578 | # for layer_idx in range(1, config.num_hidden_layers)
579 | ]
580 | )
581 | # self.layers.insert(0, None)
582 | self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
583 | self.gradient_checkpointing = False
584 | self.post_init()
585 |
586 | def get_input_embeddings(self):
587 | return self.embed_tokens
588 |
589 | def set_input_embeddings(self, value):
590 | self.embed_tokens = value
591 |
592 | # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask
593 | def _prepare_decoder_attention_mask(
594 | self, attention_mask, input_shape, inputs_embeds, past_key_values_length
595 | ):
596 | # create causal mask
597 | # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
598 | combined_attention_mask = None
599 | if input_shape[-1] > 1:
600 | combined_attention_mask = _make_causal_mask(
601 | input_shape,
602 | inputs_embeds.dtype,
603 | device=inputs_embeds.device,
604 | past_key_values_length=past_key_values_length,
605 | )
606 |
607 | if attention_mask is not None:
608 | # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
609 | expanded_attn_mask = _expand_mask(
610 | attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
611 | ).to(inputs_embeds.device)
612 | combined_attention_mask = (
613 | expanded_attn_mask
614 | if combined_attention_mask is None
615 | else expanded_attn_mask + combined_attention_mask
616 | )
617 |
618 | return combined_attention_mask
619 |
620 | def forward(
621 | self,
622 | input_ids: torch.LongTensor = None,
623 | attention_mask: Optional[torch.Tensor] = None,
624 | position_ids: Optional[torch.LongTensor] = None,
625 | past_key_values: Optional[List[torch.FloatTensor]] = None,
626 | inputs_embeds: Optional[torch.FloatTensor] = None,
627 | use_cache: Optional[bool] = None,
628 | output_attentions: Optional[bool] = None,
629 | output_hidden_states: Optional[bool] = None,
630 | output_router_logits: Optional[bool] = None,
631 | return_dict: Optional[bool] = None,
632 | decoder_layer_idx: Optional[int] = None,
633 | ) -> Union[Tuple, MoeModelOutputWithPast]:
634 | output_attentions = (
635 | output_attentions
636 | if output_attentions is not None
637 | else self.config.output_attentions
638 | )
639 | output_hidden_states = (
640 | output_hidden_states
641 | if output_hidden_states is not None
642 | else self.config.output_hidden_states
643 | )
644 | use_cache = use_cache if use_cache is not None else self.config.use_cache
645 |
646 | return_dict = (
647 | return_dict if return_dict is not None else self.config.use_return_dict
648 | )
649 |
650 | # retrieve input_ids and inputs_embeds
651 | if input_ids is not None and inputs_embeds is not None:
652 | raise ValueError(
653 | "You cannot specify both input_ids and inputs_embeds at the same time"
654 | )
655 | elif input_ids is not None:
656 | batch_size, seq_length = input_ids.shape[:2]
657 | elif inputs_embeds is not None:
658 | batch_size, seq_length = inputs_embeds.shape[:2]
659 | else:
660 | raise ValueError("You have to specify either input_ids or inputs_embeds")
661 |
662 | seq_length_with_past = seq_length
663 | past_key_values_length = 0
664 | if past_key_values is not None:
665 | past_key_values_length = past_key_values[0][0].shape[2]
666 | seq_length_with_past = seq_length_with_past + past_key_values_length
667 |
668 | if position_ids is None:
669 | device = input_ids.device if input_ids is not None else inputs_embeds.device
670 | position_ids = torch.arange(
671 | past_key_values_length,
672 | seq_length + past_key_values_length,
673 | dtype=torch.long,
674 | device=device,
675 | )
676 | position_ids = position_ids.unsqueeze(0)
677 |
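# Grok-1 multiplies the token embeddings by `embedding_multiplier_scale` (from the config)
# before they enter the decoder stack.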
678 | if inputs_embeds is None:
679 | inputs_embeds = self.embed_tokens(input_ids)
680 | inputs_embeds = inputs_embeds * self.embedding_multiplier_scale
681 |
682 | if HAS_MASK_UTILS:
683 | # 4d mask is passed through the layers
684 | attention_mask = _prepare_4d_causal_attention_mask(
685 | attention_mask,
686 | (batch_size, seq_length),
687 | inputs_embeds,
688 | past_key_values_length,
689 | )
690 | else:
691 | if attention_mask is None:
692 | attention_mask = torch.ones(
693 | (batch_size, seq_length_with_past),
694 | dtype=torch.bool,
695 | device=inputs_embeds.device,
696 | )
697 | attention_mask = self._prepare_decoder_attention_mask(
698 | attention_mask,
699 | (batch_size, seq_length),
700 | inputs_embeds,
701 | past_key_values_length,
702 | )
703 |
704 | # embed positions
705 | hidden_states = inputs_embeds
706 |
707 | if self.gradient_checkpointing and self.training:
708 | if use_cache:
709 | logger.warning_once(
710 | "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
711 | )
712 | use_cache = False
713 |
714 | # decoder layers
715 | all_hidden_states = () if output_hidden_states else None
716 | all_self_attns = () if output_attentions else None
717 | all_router_logits = () if output_router_logits else None
718 | next_decoder_cache = () if use_cache else None
719 |
720 | for idx, decoder_layer in enumerate(self.layers):
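# `decoder_layer_idx` is an addition specific to this repository (presumably used by the analysis
# notebooks): when it is set, only that single decoder layer is run and all others are skipped.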
721 | if decoder_layer_idx is not None and idx != decoder_layer_idx:
722 | continue
723 | if output_hidden_states:
724 | all_hidden_states += (hidden_states,)
725 |
726 | past_key_value = (
727 | past_key_values[idx] if past_key_values is not None else None
728 | )
729 |
730 | if self.gradient_checkpointing and self.training:
731 |
732 | def create_custom_forward(module):
733 | def custom_forward(*inputs):
734 | # None for past_key_value
735 | return module(*inputs, past_key_value, output_attentions)
736 |
737 | return custom_forward
738 |
739 | layer_outputs = torch.utils.checkpoint.checkpoint(
740 | create_custom_forward(decoder_layer),
741 | hidden_states,
742 | attention_mask,
743 | position_ids,
744 | )
745 | else:
746 | layer_outputs = decoder_layer(
747 | hidden_states,
748 | attention_mask=attention_mask,
749 | position_ids=position_ids,
750 | past_key_value=past_key_value,
751 | output_attentions=output_attentions,
752 | use_cache=use_cache,
753 | )
754 |
755 | hidden_states = layer_outputs[0]
756 |
757 | if use_cache:
758 | next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
759 |
760 | if output_attentions:
761 | all_self_attns += (layer_outputs[1],)
762 |
763 | if output_router_logits:
764 | all_router_logits += (layer_outputs[-1],)
765 |
766 | hidden_states = self.norm(hidden_states)
767 |
768 | # add hidden states from the last decoder layer
769 | if output_hidden_states:
770 | all_hidden_states += (hidden_states,)
771 | next_cache = next_decoder_cache if use_cache else None
772 |
773 | if not return_dict:
774 | return tuple(
775 | v
776 | for v in [
777 | hidden_states,
778 | next_cache,
779 | all_hidden_states,
780 | all_self_attns,
781 | all_router_logits,
782 | ]
783 | if v is not None
784 | )
785 | return MoeModelOutputWithPast(
786 | last_hidden_state=hidden_states,
787 | past_key_values=next_cache,
788 | hidden_states=all_hidden_states,
789 | attentions=all_self_attns,
790 | router_logits=all_router_logits,
791 | )
792 |
793 |
794 | class Grok1ModelForCausalLM(Grok1PretrainedModel):
795 | _tied_weights_keys = ["lm_head.weight"]
796 |
797 | def __init__(self, config: Grok1Config, **kwargs):
798 | super().__init__(config)
799 | self.model = Grok1Model(config)
800 | self.vocab_size = config.vocab_size
801 | self.output_multiplier_scale = config.output_multiplier_scale
802 | self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
803 | self.router_aux_loss_coef = config.router_aux_loss_coef
804 | self.num_experts = config.num_experts
805 | self.num_experts_per_tok = config.num_experts_per_tok
806 | self.post_init()
807 |
808 | def get_input_embeddings(self):
809 | return self.model.embed_tokens
810 |
811 | def set_input_embeddings(self, value):
812 | self.model.embed_tokens = value
813 |
814 | def get_output_embeddings(self):
815 | return self.lm_head
816 |
817 | def set_output_embeddings(self, new_embeddings):
818 | self.lm_head = new_embeddings
819 |
820 | def set_decoder(self, decoder):
821 | self.model = decoder
822 |
823 | def get_decoder(self):
824 | return self.model
825 |
826 | def forward(
827 | self,
828 | input_ids: torch.LongTensor = None,
829 | attention_mask: Optional[torch.Tensor] = None,
830 | position_ids: Optional[torch.LongTensor] = None,
831 | past_key_values: Optional[List[torch.FloatTensor]] = None,
832 | inputs_embeds: Optional[torch.FloatTensor] = None,
833 | labels: Optional[torch.LongTensor] = None,
834 | use_cache: Optional[bool] = None,
835 | output_attentions: Optional[bool] = None,
836 | output_hidden_states: Optional[bool] = None,
837 | output_router_logits: Optional[bool] = None,
838 | return_dict: Optional[bool] = None,
839 | decoder_layer_idx: Optional[int] = None,
840 | ) -> Union[Tuple, MoeCausalLMOutputWithPast]:
841 | output_attentions = (
842 | output_attentions
843 | if output_attentions is not None
844 | else self.config.output_attentions
845 | )
846 | output_router_logits = (
847 | output_router_logits
848 | if output_router_logits is not None
849 | else self.config.output_router_logits
850 | )
851 |
852 | output_hidden_states = (
853 | output_hidden_states
854 | if output_hidden_states is not None
855 | else self.config.output_hidden_states
856 | )
857 | return_dict = (
858 | return_dict if return_dict is not None else self.config.use_return_dict
859 | )
860 |
861 | # decoder outputs consist of (dec_features, layer_state, dec_hidden, dec_attn)
862 | outputs = self.model(
863 | input_ids=input_ids,
864 | attention_mask=attention_mask,
865 | position_ids=position_ids,
866 | past_key_values=past_key_values,
867 | inputs_embeds=inputs_embeds,
868 | use_cache=use_cache,
869 | output_attentions=output_attentions,
870 | output_hidden_states=output_hidden_states,
871 | output_router_logits=output_router_logits,
872 | return_dict=return_dict,
873 | decoder_layer_idx=decoder_layer_idx,
874 | )
875 |
876 | hidden_states = outputs[0]
877 | logits = self.lm_head(hidden_states)
878 | logits = logits * self.output_multiplier_scale
879 | logits = logits.float()
880 |
881 | loss = None
882 | if labels is not None:
883 | # Shift so that tokens < n predict n
884 | shift_logits = logits[..., :-1, :].contiguous()
885 | shift_labels = labels[..., 1:].contiguous()
886 | # Flatten the tokens
887 | loss_fct = nn.CrossEntropyLoss()
888 | shift_logits = shift_logits.view(-1, self.config.vocab_size)
889 | shift_labels = shift_labels.view(-1)
890 | # Enable model parallelism
891 | shift_labels = shift_labels.to(shift_logits.device)
892 | loss = loss_fct(shift_logits, shift_labels)
893 |
894 | aux_loss = None
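# Auxiliary load-balancing loss computed from the router logits; when labels are provided it is
# added to the language-modeling loss scaled by `router_aux_loss_coef`.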
895 | if output_router_logits:
896 | aux_loss = load_balancing_loss_func(
897 | outputs.router_logits if return_dict else outputs[-1],
898 | self.num_experts,
899 | self.num_experts_per_tok,
900 | )
901 | if labels is not None:
902 | loss += self.router_aux_loss_coef * aux_loss
903 |
904 | if not return_dict:
905 | output = (logits,) + outputs[1:]
906 | if output_router_logits:
907 | output = (aux_loss,) + output
908 | return (loss,) + output if loss is not None else output
909 |
910 | return MoeCausalLMOutputWithPast(
911 | loss=loss,
912 | aux_loss=aux_loss,
913 | logits=logits,
914 | past_key_values=outputs.past_key_values,
915 | hidden_states=outputs.hidden_states,
916 | attentions=outputs.attentions,
917 | router_logits=outputs.router_logits,
918 | )
919 |
920 | def prepare_inputs_for_generation(
921 | self,
922 | input_ids,
923 | past_key_values=None,
924 | attention_mask=None,
925 | inputs_embeds=None,
926 | **kwargs,
927 | ):
928 | if past_key_values:
929 | input_ids = input_ids[:, -1:]
930 |
931 | position_ids = kwargs.get("position_ids", None)
932 | if attention_mask is not None and position_ids is None:
933 | # create position_ids on the fly for batch generation
934 | position_ids = attention_mask.long().cumsum(-1) - 1
935 | position_ids.masked_fill_(attention_mask == 0, 1)
936 | if past_key_values:
937 | position_ids = position_ids[:, -1].unsqueeze(-1)
938 |
939 | # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
940 | if inputs_embeds is not None and past_key_values is None:
941 | model_inputs = {"inputs_embeds": inputs_embeds}
942 | else:
943 | model_inputs = {"input_ids": input_ids}
944 |
945 | model_inputs.update(
946 | {
947 | "position_ids": position_ids,
948 | "past_key_values": past_key_values,
949 | "use_cache": kwargs.get("use_cache"),
950 | "attention_mask": attention_mask,
951 | }
952 | )
953 | return model_inputs
954 |
--------------------------------------------------------------------------------
/grok/modeling_grok1_outputs.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 | from typing import Optional, Tuple
3 |
4 | import torch
5 | from transformers.modeling_outputs import ModelOutput
6 |
7 | __all__ = [
8 | "MoeModelOutputWithPast",
9 | "MoeCausalLMOutputWithPast",
10 | ]
11 |
12 | try:
13 | from transformers.modeling_outputs import (
14 | MoeCausalLMOutputWithPast,
15 | MoeModelOutputWithPast,
16 | )
17 | except ImportError:
18 |
19 | @dataclass
20 | class MoeModelOutputWithPast(ModelOutput):
21 | """
22 | Base class for model's outputs, with potential hidden states and attentions.
23 |
24 | Args:
25 | last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
26 | Sequence of hidden-states at the output of the last layer of the model.
27 | past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
28 | Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
29 | `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and optionally if
30 | `config.is_encoder_decoder=True` 2 additional tensors of shape `(batch_size, num_heads,
31 | encoder_sequence_length, embed_size_per_head)`.
32 |
33 | Contains pre-computed hidden-states (key and values in the self-attention blocks and optionally if
34 | `config.is_encoder_decoder=True` in the cross-attention blocks) that can be used (see `past_key_values`
35 | input) to speed up sequential decoding.
36 | hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
37 | Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
38 | one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
39 |
40 | Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
41 | attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
42 | Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
43 | sequence_length)`.
44 |
45 |             Attention weights after the attention softmax, used to compute the weighted average in the self-attention
46 | heads.
47 | router_logits (`tuple(torch.FloatTensor)`, *optional*, returned when `output_router_probs=True` and `config.add_router_probs=True` is passed or when `config.output_router_probs=True`):
48 | Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, sequence_length, num_experts)`.
49 |
50 |             Raw router logits (post-softmax) that are computed by MoE routers; these terms are used to compute the auxiliary
51 | loss for Mixture of Experts models.
52 | """
53 |
54 | last_hidden_state: torch.FloatTensor = None
55 | past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
56 | hidden_states: Optional[Tuple[torch.FloatTensor]] = None
57 | attentions: Optional[Tuple[torch.FloatTensor]] = None
58 | router_logits: Optional[Tuple[torch.FloatTensor]] = None
59 |
60 | @dataclass
61 | class MoeCausalLMOutputWithPast(ModelOutput):
62 | """
63 | Base class for causal language model (or autoregressive) with mixture of experts outputs.
64 |
65 | Args:
66 | loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
67 | Language modeling loss (for next-token prediction).
68 |
69 | logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
70 | Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
71 |
72 | aux_loss (`torch.FloatTensor`, *optional*, returned when `labels` is provided):
73 | aux_loss for the sparse modules.
74 |
75 | router_logits (`tuple(torch.FloatTensor)`, *optional*, returned when `output_router_probs=True` and `config.add_router_probs=True` is passed or when `config.output_router_probs=True`):
76 | Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, sequence_length, num_experts)`.
77 |
78 |             Raw router logits (post-softmax) that are computed by MoE routers; these terms are used to compute the auxiliary
79 | loss for Mixture of Experts models.
80 |
81 | past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
82 | Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
83 | `(batch_size, num_heads, sequence_length, embed_size_per_head)`)
84 |
85 | Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
86 | `past_key_values` input) to speed up sequential decoding.
87 | hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
88 | Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
89 | one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
90 |
91 | Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
92 | attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
93 | Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
94 | sequence_length)`.
95 |
96 |             Attention weights after the attention softmax, used to compute the weighted average in the self-attention
97 | heads.
98 | """
99 |
100 | loss: Optional[torch.FloatTensor] = None
101 | aux_loss: Optional[torch.FloatTensor] = None
102 | logits: torch.FloatTensor = None
103 | past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
104 | hidden_states: Optional[Tuple[torch.FloatTensor]] = None
105 | attentions: Optional[Tuple[torch.FloatTensor]] = None
106 | router_logits: Optional[Tuple[torch.FloatTensor]] = None
107 |
--------------------------------------------------------------------------------
/mistral/config.json:
--------------------------------------------------------------------------------
1 | {
2 | "architectures": [
3 | "MistralForCausalLM"
4 | ],
5 | "bos_token_id": 1,
6 | "eos_token_id": 2,
7 | "hidden_act": "silu",
8 | "hidden_size": 4096,
9 | "initializer_range": 0.02,
10 | "intermediate_size": 14336,
11 | "max_position_embeddings": 32768,
12 | "model_type": "mistral",
13 | "num_attention_heads": 32,
14 | "num_hidden_layers": 32,
15 | "num_key_value_heads": 8,
16 | "rms_norm_eps": 1e-05,
17 | "rope_theta": 10000.0,
18 | "sliding_window": 4096,
19 | "tie_word_embeddings": false,
20 | "torch_dtype": "bfloat16",
21 | "transformers_version": "4.34.0.dev0",
22 | "use_cache": true,
23 | "vocab_size": 32000
24 | }
25 |
--------------------------------------------------------------------------------
/mistral/configuration_mistral.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2023 Mistral AI and the HuggingFace Inc. team. All rights reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """ Mistral model configuration"""
16 |
17 | from transformers.configuration_utils import PretrainedConfig
18 | from transformers.utils import logging
19 |
20 |
21 | logger = logging.get_logger(__name__)
22 |
23 | MISTRAL_PRETRAINED_CONFIG_ARCHIVE_MAP = {
24 | "mistralai/Mistral-7B-v0.1": "https://huggingface.co/mistralai/Mistral-7B-v0.1/resolve/main/config.json",
25 | "mistralai/Mistral-7B-Instruct-v0.1": "https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1/resolve/main/config.json",
26 | }
27 |
28 |
29 | class MistralConfig(PretrainedConfig):
30 | r"""
31 |     This is the configuration class to store the configuration of a [`MistralModel`]. It is used to instantiate a
32 | Mistral model according to the specified arguments, defining the model architecture. Instantiating a configuration
33 | with the defaults will yield a similar configuration to that of the Mistral-7B-v0.1 or Mistral-7B-Instruct-v0.1.
34 |
35 | [mistralai/Mistral-7B-v0.1](https://huggingface.co/mistralai/Mistral-7B-v0.1)
36 | [mistralai/Mistral-7B-Instruct-v0.1](https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1)
37 |
38 | Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
39 | documentation from [`PretrainedConfig`] for more information.
40 |
41 |
42 | Args:
43 | vocab_size (`int`, *optional*, defaults to 32000):
44 | Vocabulary size of the Mistral model. Defines the number of different tokens that can be represented by the
45 | `inputs_ids` passed when calling [`MistralModel`]
46 | hidden_size (`int`, *optional*, defaults to 4096):
47 | Dimension of the hidden representations.
48 | intermediate_size (`int`, *optional*, defaults to 14336):
49 | Dimension of the MLP representations.
50 | num_hidden_layers (`int`, *optional*, defaults to 32):
51 | Number of hidden layers in the Transformer encoder.
52 | num_attention_heads (`int`, *optional*, defaults to 32):
53 | Number of attention heads for each attention layer in the Transformer encoder.
54 | num_key_value_heads (`int`, *optional*, defaults to 8):
55 | This is the number of key_value heads that should be used to implement Grouped Query Attention. If
56 | `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
57 |             `num_key_value_heads=1` the model will use Multi Query Attention (MQA), otherwise GQA is used. When
58 | converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
59 | by meanpooling all the original heads within that group. For more details checkout [this
60 | paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `8`.
61 | hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
62 | The non-linear activation function (function or string) in the decoder.
63 | max_position_embeddings (`int`, *optional*, defaults to `4096*32`):
64 | The maximum sequence length that this model might ever be used with. Mistral's sliding window attention
65 |             allows sequences of up to 4096*32 tokens.
66 | initializer_range (`float`, *optional*, defaults to 0.02):
67 | The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
68 | rms_norm_eps (`float`, *optional*, defaults to 1e-06):
69 | The epsilon used by the rms normalization layers.
70 | use_cache (`bool`, *optional*, defaults to `True`):
71 | Whether or not the model should return the last key/values attentions (not used by all models). Only
72 | relevant if `config.is_decoder=True`.
73 | pad_token_id (`int`, *optional*):
74 | The id of the padding token.
75 | bos_token_id (`int`, *optional*, defaults to 1):
76 | The id of the "beginning-of-sequence" token.
77 | eos_token_id (`int`, *optional*, defaults to 2):
78 | The id of the "end-of-sequence" token.
79 | tie_word_embeddings (`bool`, *optional*, defaults to `False`):
80 | Whether the model's input and output word embeddings should be tied.
81 | rope_theta (`float`, *optional*, defaults to 10000.0):
82 | The base period of the RoPE embeddings.
83 | sliding_window (`int`, *optional*, defaults to 4096):
84 | Sliding window attention window size. If not specified, will default to `4096`.
85 | attention_dropout (`float`, *optional*, defaults to 0.0):
86 | The dropout ratio for the attention probabilities.
87 |
88 | ```python
89 | >>> from transformers import MistralModel, MistralConfig
90 |
91 | >>> # Initializing a Mistral 7B style configuration
92 | >>> configuration = MistralConfig()
93 |
94 | >>> # Initializing a model from the Mistral 7B style configuration
95 | >>> model = MistralModel(configuration)
96 |
97 | >>> # Accessing the model configuration
98 | >>> configuration = model.config
99 | ```"""
100 |
101 | model_type = "mistral"
102 | keys_to_ignore_at_inference = ["past_key_values"]
103 |
104 | def __init__(
105 | self,
106 | vocab_size=32000,
107 | hidden_size=4096,
108 | intermediate_size=14336,
109 | num_hidden_layers=32,
110 | num_attention_heads=32,
111 | num_key_value_heads=8,
112 | hidden_act="silu",
113 | max_position_embeddings=4096 * 32,
114 | initializer_range=0.02,
115 | rms_norm_eps=1e-6,
116 | use_cache=True,
117 | pad_token_id=None,
118 | bos_token_id=1,
119 | eos_token_id=2,
120 | tie_word_embeddings=False,
121 | rope_theta=10000.0,
122 | sliding_window=4096,
123 | attention_dropout=0.0,
124 | **kwargs,
125 | ):
126 | self.vocab_size = vocab_size
127 | self.max_position_embeddings = max_position_embeddings
128 | self.hidden_size = hidden_size
129 | self.intermediate_size = intermediate_size
130 | self.num_hidden_layers = num_hidden_layers
131 | self.num_attention_heads = num_attention_heads
132 | self.sliding_window = sliding_window
133 |
134 | # for backward compatibility
135 | if num_key_value_heads is None:
136 | num_key_value_heads = num_attention_heads
137 |
138 | self.num_key_value_heads = num_key_value_heads
139 | self.hidden_act = hidden_act
140 | self.initializer_range = initializer_range
141 | self.rms_norm_eps = rms_norm_eps
142 | self.use_cache = use_cache
143 | self.rope_theta = rope_theta
144 | self.attention_dropout = attention_dropout
145 |
146 | super().__init__(
147 | pad_token_id=pad_token_id,
148 | bos_token_id=bos_token_id,
149 | eos_token_id=eos_token_id,
150 | tie_word_embeddings=tie_word_embeddings,
151 | **kwargs,
152 | )
--------------------------------------------------------------------------------
/mixtral_base/config.json:
--------------------------------------------------------------------------------
1 | {
2 | "architectures": [
3 | "MixtralForCausalLM"
4 | ],
5 | "auto_map": {
6 | "AutoConfig": "configuration_moe_mistral.MixtralConfig",
7 | "AutoModelForCausalLM": "modeling_moe_mistral.MixtralForCausalLM"
8 | },
9 | "attention_dropout": 0.0,
10 | "bos_token_id": 1,
11 | "eos_token_id": 2,
12 | "hidden_act": "silu",
13 | "hidden_size": 4096,
14 | "initializer_range": 0.02,
15 | "intermediate_size": 14336,
16 | "max_position_embeddings": 32768,
17 | "model_type": "mistral",
18 | "num_attention_heads": 32,
19 | "num_experts": 8,
20 | "num_experts_per_token": 1,
21 | "num_hidden_layers": 32,
22 | "num_key_value_heads": 8,
23 | "rms_norm_eps": 1e-05,
24 | "rope_theta": 1000000.0,
25 | "tie_word_embeddings": false,
26 | "torch_dtype": "float16",
27 | "transformers_version": "4.36.0.dev0",
28 | "use_cache": true,
29 | "vocab_size": 32000,
30 | "single_expert": false
31 | }
32 |
--------------------------------------------------------------------------------
/mixtral_base/configuration_moe_mistral.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2023 Mistral AI and the HuggingFace Inc. team. All rights reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """ Mistral model configuration"""
16 |
17 | from transformers.configuration_utils import PretrainedConfig
18 | from transformers.utils import logging
19 |
20 |
21 | logger = logging.get_logger(__name__)
22 |
23 | MISTRAL_PRETRAINED_CONFIG_ARCHIVE_MAP = {
24 | "mistralai/Mistral-7B-v0.1": "https://huggingface.co/mistralai/Mistral-7B-v0.1/resolve/main/config.json",
25 | "mistralai/Mistral-7B-Instruct-v0.1": "https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1/resolve/main/config.json",
26 | }
27 |
28 |
29 | class MixtralConfig(PretrainedConfig):
30 | r"""
31 |     This is the configuration class to store the configuration of a [`MistralModel`]. It is used to instantiate a
32 | Mistral model according to the specified arguments, defining the model architecture. Instantiating a configuration
33 | with the defaults will yield a similar configuration to that of the Mistral-7B-v0.1 or Mistral-7B-Instruct-v0.1.
34 |
35 | [mistralai/Mistral-7B-v0.1](https://huggingface.co/mistralai/Mistral-7B-v0.1)
36 | [mistralai/Mistral-7B-Instruct-v0.1](https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1)
37 |
38 | Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
39 | documentation from [`PretrainedConfig`] for more information.
40 |
41 |
42 | Args:
43 | vocab_size (`int`, *optional*, defaults to 32000):
44 | Vocabulary size of the Mistral model. Defines the number of different tokens that can be represented by the
45 | `inputs_ids` passed when calling [`MistralModel`]
46 | hidden_size (`int`, *optional*, defaults to 4096):
47 | Dimension of the hidden representations.
48 | intermediate_size (`int`, *optional*, defaults to 14336):
49 | Dimension of the MLP representations.
50 | num_hidden_layers (`int`, *optional*, defaults to 32):
51 | Number of hidden layers in the Transformer encoder.
52 | num_attention_heads (`int`, *optional*, defaults to 32):
53 | Number of attention heads for each attention layer in the Transformer encoder.
54 | num_key_value_heads (`int`, *optional*, defaults to 8):
55 | This is the number of key_value heads that should be used to implement Grouped Query Attention. If
56 | `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
57 |             `num_key_value_heads=1` the model will use Multi Query Attention (MQA), otherwise GQA is used. When
58 | converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
59 | by meanpooling all the original heads within that group. For more details checkout [this
60 | paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `8`.
61 | hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
62 | The non-linear activation function (function or string) in the decoder.
63 | max_position_embeddings (`int`, *optional*, defaults to `4096*32`):
64 | The maximum sequence length that this model might ever be used with. Mistral's sliding window attention
65 |             allows sequences of up to 4096*32 tokens.
66 | initializer_range (`float`, *optional*, defaults to 0.02):
67 | The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
68 | rms_norm_eps (`float`, *optional*, defaults to 1e-06):
69 | The epsilon used by the rms normalization layers.
70 | use_cache (`bool`, *optional*, defaults to `True`):
71 | Whether or not the model should return the last key/values attentions (not used by all models). Only
72 | relevant if `config.is_decoder=True`.
73 | pad_token_id (`int`, *optional*):
74 | The id of the padding token.
75 | bos_token_id (`int`, *optional*, defaults to 1):
76 | The id of the "beginning-of-sequence" token.
77 | eos_token_id (`int`, *optional*, defaults to 2):
78 | The id of the "end-of-sequence" token.
79 | tie_word_embeddings (`bool`, *optional*, defaults to `False`):
80 | Whether the model's input and output word embeddings should be tied.
81 | rope_theta (`float`, *optional*, defaults to 10000.0):
82 | The base period of the RoPE embeddings.
83 | sliding_window (`int`, *optional*, defaults to 4096):
84 | Sliding window attention window size. If not specified, will default to `4096`.
85 | attention_dropout (`float`, *optional*, defaults to 0.0):
86 | The dropout ratio for the attention probabilities.
87 |
88 | ```python
89 | >>> from transformers import MistralModel, MistralConfig
90 |
91 | >>> # Initializing a Mistral 7B style configuration
92 | >>> configuration = MistralConfig()
93 |
94 | >>> # Initializing a model from the Mistral 7B style configuration
95 | >>> model = MistralModel(configuration)
96 |
97 | >>> # Accessing the model configuration
98 | >>> configuration = model.config
99 | ```"""
100 |
101 | model_type = "mistral"
102 | keys_to_ignore_at_inference = ["past_key_values"]
103 |
104 | def __init__(
105 | self,
106 | vocab_size=32000,
107 | hidden_size=4096,
108 | intermediate_size=14336,
109 | num_hidden_layers=32,
110 | num_attention_heads=32,
111 | num_key_value_heads=8,
112 | hidden_act="silu",
113 | max_position_embeddings=4096 * 32,
114 | initializer_range=0.02,
115 | rms_norm_eps=1e-6,
116 | use_cache=True,
117 | pad_token_id=None,
118 | bos_token_id=1,
119 | eos_token_id=2,
120 | tie_word_embeddings=False,
121 | rope_theta=10000.0,
122 | attention_dropout=0.0,
123 | num_experts_per_token=2,
124 | num_experts=8,
125 | single_expert=False,
126 | **kwargs,
127 | ):
128 | self.vocab_size = vocab_size
129 | self.max_position_embeddings = max_position_embeddings
130 | self.hidden_size = hidden_size
131 | self.intermediate_size = intermediate_size
132 | self.num_hidden_layers = num_hidden_layers
133 | self.num_attention_heads = num_attention_heads
134 |
135 | # for backward compatibility
136 | if num_key_value_heads is None:
137 | num_key_value_heads = num_attention_heads
138 |
139 | self.num_key_value_heads = num_key_value_heads
140 | self.hidden_act = hidden_act
141 | self.initializer_range = initializer_range
142 | self.rms_norm_eps = rms_norm_eps
143 | self.use_cache = use_cache
144 | self.rope_theta = rope_theta
145 | self.attention_dropout = attention_dropout
146 | self.num_experts = num_experts
147 | self.num_experts_per_token = num_experts_per_token
148 |
149 | self.single_expert = single_expert
150 |
151 | super().__init__(
152 | pad_token_id=pad_token_id,
153 | bos_token_id=bos_token_id,
154 | eos_token_id=eos_token_id,
155 | tie_word_embeddings=tie_word_embeddings,
156 | **kwargs,
157 | )
158 |
--------------------------------------------------------------------------------
/mixtral_base/modeling_moe_mistral.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2023 Mistral AI and the HuggingFace Inc. team. All rights reserved.
3 | #
4 | # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
5 | # and OPT implementations in this library. It has been modified from its
6 | # original forms to accommodate minor architectural differences compared
7 | # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
8 | #
9 | # Licensed under the Apache License, Version 2.0 (the "License");
10 | # you may not use this file except in compliance with the License.
11 | # You may obtain a copy of the License at
12 | #
13 | # http://www.apache.org/licenses/LICENSE-2.0
14 | #
15 | # Unless required by applicable law or agreed to in writing, software
16 | # distributed under the License is distributed on an "AS IS" BASIS,
17 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18 | # See the License for the specific language governing permissions and
19 | # limitations under the License.
20 | """ PyTorch Mistral model."""
21 | import inspect
22 | import math
23 | import warnings
24 | from typing import List, Optional, Tuple, Union
25 |
26 | import torch
27 | import torch.nn.functional as F
28 | import torch.utils.checkpoint
29 | from torch import nn
30 | from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
31 |
32 | from transformers.activations import ACT2FN
33 | from transformers.cache_utils import Cache, DynamicCache
34 | from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
35 | from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast
36 | from transformers.modeling_utils import PreTrainedModel
37 | from transformers.utils import (
38 | add_start_docstrings,
39 | add_start_docstrings_to_model_forward,
40 | is_flash_attn_2_available,
41 | is_flash_attn_greater_or_equal_2_10,
42 | logging,
43 | replace_return_docstrings,
44 | )
45 | from .configuration_moe_mistral import MixtralConfig
46 |
47 |
48 |
49 | if is_flash_attn_2_available():
50 | from flash_attn import flash_attn_func, flash_attn_varlen_func
51 | from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
52 |
53 | _flash_supports_window_size = "window_size" in list(inspect.signature(flash_attn_func).parameters)
54 |
55 |
56 | logger = logging.get_logger(__name__)
57 |
58 | _CONFIG_FOR_DOC = "MixtralConfig"
59 |
60 |
61 | # Copied from transformers.models.llama.modeling_llama._get_unpad_data
62 | def _get_unpad_data(attention_mask):
63 | seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
64 | indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
65 | max_seqlen_in_batch = seqlens_in_batch.max().item()
66 |     cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
67 | return (
68 | indices,
69 | cu_seqlens,
70 | max_seqlen_in_batch,
71 | )
72 |
73 |
74 | # Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->Mistral
75 | class MistralRMSNorm(nn.Module):
76 | def __init__(self, hidden_size, eps=1e-6):
77 | """
78 | MistralRMSNorm is equivalent to T5LayerNorm
79 | """
80 | super().__init__()
81 | self.weight = nn.Parameter(torch.ones(hidden_size))
82 | self.variance_epsilon = eps
83 |
84 | def forward(self, hidden_states):
85 | input_dtype = hidden_states.dtype
86 | hidden_states = hidden_states.to(torch.float32)
87 | variance = hidden_states.pow(2).mean(-1, keepdim=True)
88 | hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
89 | return self.weight * hidden_states.to(input_dtype)
90 |
91 |
92 | # Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->Mistral
93 | class MistralRotaryEmbedding(nn.Module):
94 | def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
95 | super().__init__()
96 |
97 | self.dim = dim
98 | self.max_position_embeddings = max_position_embeddings
99 | self.base = base
100 | inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
101 | self.register_buffer("inv_freq", inv_freq, persistent=False)
102 |
103 | # Build here to make `torch.jit.trace` work.
104 | self._set_cos_sin_cache(
105 | seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
106 | )
107 |
108 | def _set_cos_sin_cache(self, seq_len, device, dtype):
109 | self.max_seq_len_cached = seq_len
110 | t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
111 |
112 | freqs = torch.outer(t, self.inv_freq)
113 |         # Different from the paper, but it uses an equivalent permutation that yields the same calculation
114 | emb = torch.cat((freqs, freqs), dim=-1)
115 | self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
116 | self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
117 |
118 | def forward(self, x, seq_len=None):
119 | # x: [bs, num_attention_heads, seq_len, head_size]
120 | if seq_len > self.max_seq_len_cached:
121 | self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
122 |
123 | return (
124 | self.cos_cached[:seq_len].to(dtype=x.dtype),
125 | self.sin_cached[:seq_len].to(dtype=x.dtype),
126 | )
127 |
128 |
129 | # Copied from transformers.models.llama.modeling_llama.rotate_half
130 | def rotate_half(x):
131 | """Rotates half the hidden dims of the input."""
132 | x1 = x[..., : x.shape[-1] // 2]
133 | x2 = x[..., x.shape[-1] // 2 :]
134 | return torch.cat((-x2, x1), dim=-1)
135 |
136 |
137 | # Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
138 | def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
139 | """Applies Rotary Position Embedding to the query and key tensors.
140 |
141 | Args:
142 | q (`torch.Tensor`): The query tensor.
143 | k (`torch.Tensor`): The key tensor.
144 | cos (`torch.Tensor`): The cosine part of the rotary embedding.
145 | sin (`torch.Tensor`): The sine part of the rotary embedding.
146 | position_ids (`torch.Tensor`):
147 | The position indices of the tokens corresponding to the query and key tensors. For example, this can be
148 | used to pass offsetted position ids when working with a KV-cache.
149 | unsqueeze_dim (`int`, *optional*, defaults to 1):
150 | The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
151 | sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
152 | that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
153 | k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
154 | cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
155 | the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
156 | Returns:
157 |         `tuple(torch.Tensor)` comprising the query and key tensors rotated using the Rotary Position Embedding.
158 | """
159 | cos = cos[position_ids].unsqueeze(unsqueeze_dim)
160 | sin = sin[position_ids].unsqueeze(unsqueeze_dim)
161 | q_embed = (q * cos) + (rotate_half(q) * sin)
162 | k_embed = (k * cos) + (rotate_half(k) * sin)
163 | return q_embed, k_embed
164 |
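# A minimal usage sketch of the RoPE helpers above (the shapes and the 4096 context
# length are assumptions for illustration, not values taken from a specific config):
#
#   rotary_emb = MistralRotaryEmbedding(dim=head_dim, max_position_embeddings=4096)
#   cos, sin = rotary_emb(value_states, seq_len=q.shape[-2])  # each: [seq_len, head_dim]
#   position_ids = torch.arange(q.shape[-2]).unsqueeze(0)     # [1, seq_len]
#   q_rot, k_rot = apply_rotary_pos_emb(q, k, cos, sin, position_ids)
#
# q and k are expected as [batch, num_heads, seq_len, head_dim]; only the last
# dimension is rotated, pairing the first and second halves of each head.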
165 |
166 | class FeedForward(nn.Module):
167 | def __init__(
168 | self,
169 | config
170 | ):
171 |         """
172 |         Initialize the FeedForward module, a SwiGLU MLP that serves as a single
173 |         expert in the MoE layer.
174 | 
175 |         Args:
176 |             config: Model configuration; `config.hidden_size` gives the input and
177 |                 output dimension and `config.intermediate_size` the hidden
178 |                 dimension of the feed-forward layer.
179 | 
180 |         Attributes:
181 |             w1 (nn.Linear): Gate projection (`hidden_size` -> `intermediate_size`).
182 |             w2 (nn.Linear): Down projection (`intermediate_size` -> `hidden_size`).
183 |             w3 (nn.Linear): Up projection (`hidden_size` -> `intermediate_size`).
184 | 
185 |         """
186 | super().__init__()
187 |
188 | self.w1 = nn.Linear(
189 | config.hidden_size, config.intermediate_size, bias=False
190 | )
191 | self.w2 = nn.Linear(
192 | config.intermediate_size, config.hidden_size, bias=False
193 | )
194 | self.w3 = nn.Linear(
195 | config.hidden_size, config.intermediate_size, bias=False
196 | )
197 |
198 | def forward(self, x):
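        # SwiGLU: w1 is the gating projection and w3 the up projection; neurons are
        # gated by silu(w1(x)) before w2 projects back down to the hidden size.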
199 | return self.w2(F.silu(self.w1(x)) * self.w3(x))
200 |
201 |
202 | class MoE(nn.Module):
203 | def __init__(
204 | self,
205 | config,
206 | ):
207 | super().__init__()
208 | self.config = config
209 | if self.config.single_expert:
210 | print("using single expert!")
211 | num_experts = config.num_experts
212 | self.experts = nn.ModuleList([FeedForward(config) for i in range(num_experts)])
213 | self.gate = nn.Linear(config.hidden_size, num_experts, bias=False)
214 | self.num_experts_per_token = config.num_experts_per_token
215 |
216 | def forward(self, x):
217 | if self.config.single_expert:
218 | return self.experts[0](x)
219 | orig_shape = x.shape
220 | x = x.view(-1, x.shape[-1])
221 |
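        # Token-choice routing: the gate scores every expert for every token, the top-k
        # (num_experts_per_token) experts are kept, and their scores are renormalised
        # with a softmax over just the selected experts.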
222 | scores = self.gate(x)
223 | expert_weights, expert_indices = torch.topk(scores, self.num_experts_per_token, dim=-1)
224 | expert_weights = expert_weights.softmax(dim=-1)
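        # Ablation switches used in this repo (both optional config attributes):
        # `one_gate` replaces the routing weights with ones, and `norm_one_gate`
        # additionally rescales them to 1/k so they sum to one.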
225 | if hasattr(self.config, 'one_gate') and (self.config.one_gate):
226 | # print("using one gate!")
227 | expert_weights = torch.ones_like(expert_weights)
228 | if hasattr(self.config, 'norm_one_gate') and (self.config.norm_one_gate):
229 | # print("using normed one gate!")
230 | expert_weights = expert_weights / self.num_experts_per_token
231 | flat_expert_indices = expert_indices.view(-1)
232 |
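        # Dispatch: every token row is duplicated once per selected expert, each expert
        # processes the rows routed to it, and the outputs are recombined as a weighted
        # sum over the k selections.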
233 | x = x.repeat_interleave(self.num_experts_per_token, dim=0)
234 | y = torch.empty_like(x)
235 | for i, expert in enumerate(self.experts):
236 | y[flat_expert_indices == i] = expert(x[flat_expert_indices == i])
237 | y = (y.view(*expert_weights.shape, -1) * expert_weights.unsqueeze(-1)).sum(dim=1)
238 | return y.view(*orig_shape)
239 |
240 |
241 | # Copied from transformers.models.llama.modeling_llama.repeat_kv
242 | def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
243 | """
244 | This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
245 | num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
246 | """
247 | batch, num_key_value_heads, slen, head_dim = hidden_states.shape
248 | if n_rep == 1:
249 | return hidden_states
250 | hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
251 | return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
252 |
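# For example, with 8 key/value heads and n_rep=4, a [batch, 8, seqlen, head_dim] tensor
# becomes [batch, 32, seqlen, head_dim], matching the number of query heads used by
# grouped-query attention.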
253 |
254 | class MistralAttention(nn.Module):
255 | """
256 |     Multi-headed attention from the 'Attention Is All You Need' paper, modified to use sliding window attention as in
257 |     Longformer and "Generating Long Sequences with Sparse Transformers".
258 | """
259 |
260 | def __init__(self, config: MixtralConfig, layer_idx: Optional[int] = None):
261 | super().__init__()
262 | self.config = config
263 | self.layer_idx = layer_idx
264 | if layer_idx is None:
265 | logger.warning_once(
266 | f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will "
267 |                 "lead to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` "
268 | "when creating this class."
269 | )
270 |
271 | self.hidden_size = config.hidden_size
272 | self.num_heads = config.num_attention_heads
273 | self.head_dim = self.hidden_size // self.num_heads
274 | self.num_key_value_heads = config.num_key_value_heads
275 | self.num_key_value_groups = self.num_heads // self.num_key_value_heads
276 | self.max_position_embeddings = config.max_position_embeddings
277 | self.rope_theta = config.rope_theta
278 | self.is_causal = True
279 | self.attention_dropout = config.attention_dropout
280 |
281 | if (self.head_dim * self.num_heads) != self.hidden_size:
282 | raise ValueError(
283 | f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
284 | f" and `num_heads`: {self.num_heads})."
285 | )
286 | self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
287 | self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
288 | self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
289 | self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
290 |
291 | self.rotary_emb = MistralRotaryEmbedding(
292 | self.head_dim,
293 | max_position_embeddings=self.max_position_embeddings,
294 | base=self.rope_theta,
295 | )
296 |
297 | def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
298 | return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
299 |
300 | def forward(
301 | self,
302 | hidden_states: torch.Tensor,
303 | attention_mask: Optional[torch.Tensor] = None,
304 | position_ids: Optional[torch.LongTensor] = None,
305 | past_key_value: Optional[Cache] = None,
306 | output_attentions: bool = False,
307 | use_cache: bool = False,
308 | **kwargs,
309 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
310 | if "padding_mask" in kwargs:
311 | warnings.warn(
312 |                 "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure to use `attention_mask` instead."
313 | )
314 | bsz, q_len, _ = hidden_states.size()
315 |
316 | query_states = self.q_proj(hidden_states)
317 | key_states = self.k_proj(hidden_states)
318 | value_states = self.v_proj(hidden_states)
319 |
320 | query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
321 | key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
322 | value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
323 |
324 | kv_seq_len = key_states.shape[-2]
325 | if past_key_value is not None:
326 | if self.layer_idx is None:
327 | raise ValueError(
328 | f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
329 | "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
330 | "with a layer index."
331 | )
332 | kv_seq_len += past_key_value.get_seq_length(self.layer_idx)
333 | cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
334 | query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
335 |
336 | if past_key_value is not None:
337 | cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
338 | key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
339 |
340 | # repeat k/v heads if n_kv_heads < n_heads
341 | key_states = repeat_kv(key_states, self.num_key_value_groups)
342 | value_states = repeat_kv(value_states, self.num_key_value_groups)
343 |
344 | attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
345 |
346 | if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
347 | raise ValueError(
348 | f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
349 | f" {attn_weights.size()}"
350 | )
351 |
352 | if attention_mask is not None:
353 | if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
354 | raise ValueError(
355 | f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
356 | )
357 |
358 | attn_weights = attn_weights + attention_mask
359 |
360 | # upcast attention to fp32
361 | attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
362 | attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
363 | attn_output = torch.matmul(attn_weights, value_states)
364 |
365 | if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
366 | raise ValueError(
367 | f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
368 | f" {attn_output.size()}"
369 | )
370 |
371 | attn_output = attn_output.transpose(1, 2).contiguous()
372 | attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
373 |
374 | attn_output = self.o_proj(attn_output)
375 |
376 | if not output_attentions:
377 | attn_weights = None
378 |
379 | return attn_output, attn_weights, past_key_value
380 |
381 |
382 | class MistralFlashAttention2(MistralAttention):
383 | """
384 |     Mistral flash attention module. This module inherits from `MistralAttention` as the weights of the module stay
385 | untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
386 | flash attention and deal with padding tokens in case the input contains any of them.
387 | """
388 |
389 | # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
390 | def __init__(self, *args, **kwargs):
391 | super().__init__(*args, **kwargs)
392 |
393 | # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
394 | # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
395 | # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
396 | self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
397 |
398 | def forward(
399 | self,
400 | hidden_states: torch.Tensor,
401 | attention_mask: Optional[torch.Tensor] = None,
402 | position_ids: Optional[torch.LongTensor] = None,
403 | past_key_value: Optional[Cache] = None,
404 | output_attentions: bool = False,
405 | use_cache: bool = False,
406 | **kwargs,
407 | ):
408 | if "padding_mask" in kwargs:
409 | warnings.warn(
410 |                 "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure to use `attention_mask` instead."
411 | )
412 |
413 | # overwrite attention_mask with padding_mask
414 | attention_mask = kwargs.pop("padding_mask")
415 | bsz, q_len, _ = hidden_states.size()
416 |
417 | query_states = self.q_proj(hidden_states)
418 | key_states = self.k_proj(hidden_states)
419 | value_states = self.v_proj(hidden_states)
420 |
421 | query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
422 | key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
423 | value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
424 |
425 | kv_seq_len = key_states.shape[-2]
426 | if past_key_value is not None:
427 | kv_seq_len += past_key_value.get_seq_length(self.layer_idx)
428 |
429 | # Because the input can be padded, the absolute sequence length depends on the max position id.
430 | rotary_seq_len = max(kv_seq_len, position_ids[:, -1].max().item()) + 1
431 | cos, sin = self.rotary_emb(value_states, seq_len=rotary_seq_len)
432 |
433 | query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
434 |
435 | use_sliding_windows = (
436 | _flash_supports_window_size
437 | and getattr(self.config, "sliding_window", None) is not None
438 | and kv_seq_len > self.config.sliding_window
439 | )
440 |
441 | if not _flash_supports_window_size:
442 | logger.warning_once(
443 |                 "The current flash attention version does not support sliding window attention; for a more memory-efficient"
444 |                 " implementation, make sure to upgrade the flash-attn library."
445 | )
446 |
447 | if past_key_value is not None:
448 |             # Activate cache slicing only if the config has a `sliding_window` attribute
449 | if getattr(self.config, "sliding_window", None) is not None and kv_seq_len > self.config.sliding_window:
450 | slicing_tokens = 1 - self.config.sliding_window
451 |
452 |                 past_key = past_key_value[self.layer_idx][0]
453 |                 past_value = past_key_value[self.layer_idx][1]
454 |
455 | past_key = past_key[:, :, slicing_tokens:, :].contiguous()
456 | past_value = past_value[:, :, slicing_tokens:, :].contiguous()
457 |
458 | if past_key.shape[-2] != self.config.sliding_window - 1:
459 | raise ValueError(
460 | f"past key must have a shape of (`batch_size, num_heads, self.config.sliding_window-1, head_dim`), got"
461 | f" {past_key.shape}"
462 | )
463 | 
466 | if attention_mask is not None:
467 | attention_mask = attention_mask[:, slicing_tokens:]
468 | attention_mask = torch.cat([attention_mask, torch.ones_like(attention_mask[:, -1:])], dim=-1)
469 |
470 | cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
471 | key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
472 |
473 | # repeat k/v heads if n_kv_heads < n_heads
474 | key_states = repeat_kv(key_states, self.num_key_value_groups)
475 | value_states = repeat_kv(value_states, self.num_key_value_groups)
476 | dropout_rate = 0.0 if not self.training else self.attention_dropout
477 |
478 |         # In PEFT, the layer norms are usually cast to float32 for training stability,
479 |         # so the input hidden states may have been silently upcast to float32. Here we
480 |         # cast them back to the expected half-precision dtype so everything works as intended.
481 | input_dtype = query_states.dtype
482 | if input_dtype == torch.float32:
483 | # Handle the case where the model is quantized
484 | if hasattr(self.config, "_pre_quantization_dtype"):
485 | target_dtype = self.config._pre_quantization_dtype
486 | else:
487 | target_dtype = self.q_proj.weight.dtype
488 |
489 | logger.warning_once(
490 |                 f"The input hidden states seem to have been silently cast to float32; this might be because you have"
491 |                 f" upcast the embedding or layer norm layers to float32. We will cast the input back to"
492 | f" {target_dtype}."
493 | )
494 |
495 | query_states = query_states.to(target_dtype)
496 | key_states = key_states.to(target_dtype)
497 | value_states = value_states.to(target_dtype)
498 |
499 |         # Reshape to the expected shape for Flash Attention
500 | query_states = query_states.transpose(1, 2)
501 | key_states = key_states.transpose(1, 2)
502 | value_states = value_states.transpose(1, 2)
503 |
504 | attn_output = self._flash_attention_forward(
505 | query_states,
506 | key_states,
507 | value_states,
508 | attention_mask,
509 | q_len,
510 | dropout=dropout_rate,
511 | use_sliding_windows=use_sliding_windows,
512 | )
513 |
514 | attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
515 | attn_output = self.o_proj(attn_output)
516 |
517 | if not output_attentions:
518 | attn_weights = None
519 |
520 | return attn_output, attn_weights, past_key_value
521 |
522 | def _flash_attention_forward(
523 | self,
524 | query_states,
525 | key_states,
526 | value_states,
527 | attention_mask,
528 | query_length,
529 | dropout=0.0,
530 | softmax_scale=None,
531 | use_sliding_windows=False,
532 | ):
533 | """
534 |         Calls the forward method of Flash Attention. If the input hidden states contain at least one padding token, the
535 |         input is first unpadded, then the attention scores are computed, and finally the output is re-padded.
536 |
537 | Args:
538 | query_states (`torch.Tensor`):
539 | Input query states to be passed to Flash Attention API
540 | key_states (`torch.Tensor`):
541 | Input key states to be passed to Flash Attention API
542 | value_states (`torch.Tensor`):
543 | Input value states to be passed to Flash Attention API
544 | attention_mask (`torch.Tensor`):
545 | The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
546 | position of padding tokens and 1 for the position of non-padding tokens.
547 | dropout (`int`, *optional*):
548 | Attention dropout
549 | softmax_scale (`float`, *optional*):
550 | The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
551 | use_sliding_windows (`bool`, *optional*):
552 | Whether to activate sliding window attention.
553 | """
554 | if not self._flash_attn_uses_top_left_mask:
555 | causal = self.is_causal
556 | else:
557 | # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__.
558 | causal = self.is_causal and query_length != 1
559 |
560 | # Contains at least one padding token in the sequence
561 | if attention_mask is not None:
562 | batch_size = query_states.shape[0]
563 | query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
564 | query_states, key_states, value_states, attention_mask, query_length
565 | )
566 |
567 | cu_seqlens_q, cu_seqlens_k = cu_seq_lens
568 | max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
569 |
570 | if not use_sliding_windows:
571 | attn_output_unpad = flash_attn_varlen_func(
572 | query_states,
573 | key_states,
574 | value_states,
575 | cu_seqlens_q=cu_seqlens_q,
576 | cu_seqlens_k=cu_seqlens_k,
577 | max_seqlen_q=max_seqlen_in_batch_q,
578 | max_seqlen_k=max_seqlen_in_batch_k,
579 | dropout_p=dropout,
580 | softmax_scale=softmax_scale,
581 | causal=causal,
582 | )
583 | else:
584 | attn_output_unpad = flash_attn_varlen_func(
585 | query_states,
586 | key_states,
587 | value_states,
588 | cu_seqlens_q=cu_seqlens_q,
589 | cu_seqlens_k=cu_seqlens_k,
590 | max_seqlen_q=max_seqlen_in_batch_q,
591 | max_seqlen_k=max_seqlen_in_batch_k,
592 | dropout_p=dropout,
593 | softmax_scale=softmax_scale,
594 | causal=causal,
595 | window_size=(self.config.sliding_window, self.config.sliding_window),
596 | )
597 |
598 | attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
599 | else:
600 | if not use_sliding_windows:
601 | attn_output = flash_attn_func(
602 | query_states,
603 | key_states,
604 | value_states,
605 | dropout,
606 | softmax_scale=softmax_scale,
607 | causal=causal,
608 | )
609 | else:
610 | attn_output = flash_attn_func(
611 | query_states,
612 | key_states,
613 | value_states,
614 | dropout,
615 | softmax_scale=softmax_scale,
616 | causal=causal,
617 | window_size=(self.config.sliding_window, self.config.sliding_window),
618 | )
619 |
620 | return attn_output
621 |
622 | def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
623 | batch_size, kv_seq_len, num_heads, head_dim = key_layer.shape
624 |
625 | # On the first iteration we need to properly re-create the padding mask
626 |         # by slicing it at the proper place
627 | if kv_seq_len != attention_mask.shape[-1]:
628 | attention_mask_num_tokens = attention_mask.shape[-1]
629 | attention_mask = attention_mask[:, attention_mask_num_tokens - kv_seq_len :]
630 |
631 | indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
632 |
633 | key_layer = index_first_axis(key_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k)
634 | value_layer = index_first_axis(value_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k)
635 |
636 | if query_length == kv_seq_len:
637 | query_layer = index_first_axis(
638 | query_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k
639 | )
640 | cu_seqlens_q = cu_seqlens_k
641 | max_seqlen_in_batch_q = max_seqlen_in_batch_k
642 | indices_q = indices_k
643 | elif query_length == 1:
644 | max_seqlen_in_batch_q = 1
645 | cu_seqlens_q = torch.arange(
646 | batch_size + 1, dtype=torch.int32, device=query_layer.device
647 | ) # There is a memcpy here, that is very bad.
648 | indices_q = cu_seqlens_q[:-1]
649 | query_layer = query_layer.squeeze(1)
650 | else:
651 | # The -q_len: slice assumes left padding.
652 | attention_mask = attention_mask[:, -query_length:]
653 | query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)
654 |
655 | return (
656 | query_layer,
657 | key_layer,
658 | value_layer,
659 | indices_q,
660 | (cu_seqlens_q, cu_seqlens_k),
661 | (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
662 | )
663 |
664 |
665 | class MistralDecoderLayer(nn.Module):
666 | def __init__(self, config: MixtralConfig, layer_idx: int):
667 | super().__init__()
668 | self.hidden_size = config.hidden_size
669 | self.self_attn = (
670 | MistralAttention(config=config, layer_idx=layer_idx)
671 | if not getattr(config, "_flash_attn_2_enabled", False)
672 | else MistralFlashAttention2(config, layer_idx=layer_idx)
673 | )
674 | self.mlp = MoE(config)
675 | self.input_layernorm = MistralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
676 | self.post_attention_layernorm = MistralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
677 |
678 | def forward(
679 | self,
680 | hidden_states: torch.Tensor,
681 | attention_mask: Optional[torch.Tensor] = None,
682 | position_ids: Optional[torch.LongTensor] = None,
683 | past_key_value: Optional[Tuple[torch.Tensor]] = None,
684 | output_attentions: Optional[bool] = False,
685 | use_cache: Optional[bool] = False,
686 | **kwargs,
687 | ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
688 | if "padding_mask" in kwargs:
689 | warnings.warn(
690 |                 "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure to use `attention_mask` instead."
691 | )
692 | """
693 | Args:
694 | hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
695 | attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
696 | `(batch, sequence_length)` where padding elements are indicated by 0.
697 | output_attentions (`bool`, *optional*):
698 | Whether or not to return the attentions tensors of all attention layers. See `attentions` under
699 | returned tensors for more detail.
700 | use_cache (`bool`, *optional*):
701 | If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
702 | (see `past_key_values`).
703 | past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
704 | """
705 |
706 | residual = hidden_states
707 |
708 | hidden_states = self.input_layernorm(hidden_states)
709 |
710 | # Self Attention
711 | hidden_states, self_attn_weights, present_key_value = self.self_attn(
712 | hidden_states=hidden_states,
713 | attention_mask=attention_mask,
714 | position_ids=position_ids,
715 | past_key_value=past_key_value,
716 | output_attentions=output_attentions,
717 | use_cache=use_cache,
718 | )
719 | hidden_states = residual + hidden_states
720 |
721 | # Fully Connected
722 | residual = hidden_states
723 | hidden_states = self.post_attention_layernorm(hidden_states)
724 | hidden_states = self.mlp(hidden_states)
725 | hidden_states = residual + hidden_states
726 |
727 | outputs = (hidden_states,)
728 |
729 | if output_attentions:
730 | outputs += (self_attn_weights,)
731 |
732 | if use_cache:
733 | outputs += (present_key_value,)
734 |
735 | return outputs
736 |
737 |
738 | MISTRAL_START_DOCSTRING = r"""
739 | This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
740 | library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
741 | etc.)
742 |
743 | This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
744 | Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
745 | and behavior.
746 |
747 | Parameters:
748 | config ([`MixtralConfig`]):
749 | Model configuration class with all the parameters of the model. Initializing with a config file does not
750 | load the weights associated with the model, only the configuration. Check out the
751 | [`~PreTrainedModel.from_pretrained`] method to load the model weights.
752 | """
753 |
754 |
755 | @add_start_docstrings(
756 | "The bare Mistral Model outputting raw hidden-states without any specific head on top.",
757 | MISTRAL_START_DOCSTRING,
758 | )
759 | class MistralPreTrainedModel(PreTrainedModel):
760 | config_class = MixtralConfig
761 | base_model_prefix = "model"
762 | supports_gradient_checkpointing = True
763 | _no_split_modules = ["MistralDecoderLayer"]
764 | _skip_keys_device_placement = "past_key_values"
765 | _supports_flash_attn_2 = True
766 | _supports_cache_class = True
767 |
768 | def _init_weights(self, module):
769 | std = self.config.initializer_range
770 | if isinstance(module, nn.Linear):
771 | module.weight.data.normal_(mean=0.0, std=std)
772 | if module.bias is not None:
773 | module.bias.data.zero_()
774 | elif isinstance(module, nn.Embedding):
775 | module.weight.data.normal_(mean=0.0, std=std)
776 | if module.padding_idx is not None:
777 | module.weight.data[module.padding_idx].zero_()
778 |
779 |
780 | MISTRAL_INPUTS_DOCSTRING = r"""
781 | Args:
782 | input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
783 | Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
784 | it.
785 |
786 | Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
787 | [`PreTrainedTokenizer.__call__`] for details.
788 |
789 | [What are input IDs?](../glossary#input-ids)
790 | attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
791 | Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
792 |
793 | - 1 for tokens that are **not masked**,
794 | - 0 for tokens that are **masked**.
795 |
796 | [What are attention masks?](../glossary#attention-mask)
797 |
798 | Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
799 | [`PreTrainedTokenizer.__call__`] for details.
800 |
801 | If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
802 | `past_key_values`).
803 |
804 | If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
805 | and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
806 | information on the default strategy.
807 |
808 | - 1 indicates the head is **not masked**,
809 | - 0 indicates the head is **masked**.
810 | position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
811 | Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
812 | config.n_positions - 1]`.
813 |
814 | [What are position IDs?](../glossary#position-ids)
815 | past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
816 | Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
817 | blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values`
818 | returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
819 |
820 | Two formats are allowed:
821 | - a [`~cache_utils.Cache`] instance;
822 | - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
823 | shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy
824 | cache format.
825 |
826 | The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
827 | legacy cache format will be returned.
828 |
829 | If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
830 | have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
831 | of shape `(batch_size, sequence_length)`.
832 | inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
833 | Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
834 | is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
835 | model's internal embedding lookup matrix.
836 | use_cache (`bool`, *optional*):
837 | If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
838 | `past_key_values`).
839 | output_attentions (`bool`, *optional*):
840 | Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
841 | tensors for more detail.
842 | output_hidden_states (`bool`, *optional*):
843 | Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
844 | more detail.
845 | return_dict (`bool`, *optional*):
846 | Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
847 | """
848 |
849 |
850 | @add_start_docstrings(
851 | "The bare Mistral Model outputting raw hidden-states without any specific head on top.",
852 | MISTRAL_START_DOCSTRING,
853 | )
854 | class MistralModel(MistralPreTrainedModel):
855 | """
856 | Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`MistralDecoderLayer`]
857 |
858 | Args:
859 | config: MixtralConfig
860 | """
861 |
862 | def __init__(self, config: MixtralConfig):
863 | super().__init__(config)
864 | self.padding_idx = config.pad_token_id
865 | self.vocab_size = config.vocab_size
866 |
867 | self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
868 | self.layers = nn.ModuleList(
869 | [MistralDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
870 | )
871 | self.norm = MistralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
872 |
873 | self.gradient_checkpointing = False
874 | # Initialize weights and apply final processing
875 | self.post_init()
876 |
877 | def get_input_embeddings(self):
878 | return self.embed_tokens
879 |
880 | def set_input_embeddings(self, value):
881 | self.embed_tokens = value
882 |
883 | @add_start_docstrings_to_model_forward(MISTRAL_INPUTS_DOCSTRING)
884 | def forward(
885 | self,
886 | input_ids: torch.LongTensor = None,
887 | attention_mask: Optional[torch.Tensor] = None,
888 | position_ids: Optional[torch.LongTensor] = None,
889 | past_key_values: Optional[List[torch.FloatTensor]] = None,
890 | inputs_embeds: Optional[torch.FloatTensor] = None,
891 | use_cache: Optional[bool] = None,
892 | output_attentions: Optional[bool] = None,
893 | output_hidden_states: Optional[bool] = None,
894 | return_dict: Optional[bool] = None,
895 | decoder_layer_idx: Optional[int] = None,
896 | ) -> Union[Tuple, BaseModelOutputWithPast]:
897 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
898 | output_hidden_states = (
899 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
900 | )
901 | use_cache = use_cache if use_cache is not None else self.config.use_cache
902 |
903 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict
904 |
905 | # retrieve input_ids and inputs_embeds
906 | if input_ids is not None and inputs_embeds is not None:
907 |             raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
908 | elif input_ids is not None:
909 | batch_size, seq_length = input_ids.shape
910 | elif inputs_embeds is not None:
911 | batch_size, seq_length, _ = inputs_embeds.shape
912 | else:
913 |             raise ValueError("You have to specify either input_ids or inputs_embeds")
914 |
915 | seq_length_with_past = seq_length
916 | past_key_values_length = 0
917 |
918 | if use_cache:
919 | use_legacy_cache = not isinstance(past_key_values, Cache)
920 | if use_legacy_cache:
921 | past_key_values = DynamicCache.from_legacy_cache(past_key_values)
922 | past_key_values_length = past_key_values.get_seq_length()
923 | seq_length_with_past = seq_length_with_past + past_key_values_length
924 |
925 | if position_ids is None:
926 | device = input_ids.device if input_ids is not None else inputs_embeds.device
927 | position_ids = torch.arange(
928 | past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
929 | )
930 | position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
931 | else:
932 | position_ids = position_ids.view(-1, seq_length).long()
933 |
934 | if inputs_embeds is None:
935 | inputs_embeds = self.embed_tokens(input_ids)
936 |
937 | if (
938 | attention_mask is not None
939 | and hasattr(self.config, "_flash_attn_2_enabled")
940 | and self.config._flash_attn_2_enabled
941 | and use_cache
942 | ):
943 | is_padding_right = attention_mask[:, -1].sum().item() != batch_size
944 | if is_padding_right:
945 | raise ValueError(
946 |                     "You are attempting to perform batched generation with padding_side='right', which"
947 |                     " may lead to unexpected behaviour with the Flash Attention version of Mistral. Make sure to"
948 | " call `tokenizer.padding_side = 'left'` before tokenizing the input. "
949 | )
950 |
951 | if getattr(self.config, "_flash_attn_2_enabled", False):
952 | # 2d mask is passed through the layers
953 | attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
954 | else:
955 | # 4d mask is passed through the layers
956 | attention_mask = _prepare_4d_causal_attention_mask(
957 | attention_mask,
958 | (batch_size, seq_length),
959 | inputs_embeds,
960 | past_key_values_length
961 | )
962 |
963 | hidden_states = inputs_embeds
964 |
965 | if self.gradient_checkpointing and self.training:
966 | if use_cache:
967 | logger.warning_once(
968 | "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
969 | )
970 | use_cache = False
971 |
972 | # decoder layers
973 | all_hidden_states = () if output_hidden_states else None
974 | all_self_attns = () if output_attentions else None
975 | next_decoder_cache = None
976 |
977 | for i, decoder_layer in enumerate(self.layers):
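            # `decoder_layer_idx` appears to be a repo-specific hook: when set, all other
            # decoder layers are skipped so that a single layer can be run in isolation
            # (presumably for the layer-wise analyses).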
978 | if decoder_layer_idx is not None and i != decoder_layer_idx:
979 | continue
980 | if output_hidden_states:
981 | all_hidden_states += (hidden_states,)
982 |
983 | if self.gradient_checkpointing and self.training:
984 | layer_outputs = self._gradient_checkpointing_func(
985 | decoder_layer.__call__,
986 | hidden_states,
987 | attention_mask,
988 | position_ids,
989 | past_key_values,
990 | output_attentions,
991 | use_cache,
992 | )
993 | else:
994 | layer_outputs = decoder_layer(
995 | hidden_states,
996 | attention_mask=attention_mask,
997 | position_ids=position_ids,
998 | past_key_value=past_key_values,
999 | output_attentions=output_attentions,
1000 | use_cache=use_cache,
1001 | )
1002 |
1003 | hidden_states = layer_outputs[0]
1004 |
1005 | if use_cache:
1006 | next_decoder_cache = layer_outputs[2 if output_attentions else 1]
1007 |
1008 | if output_attentions:
1009 | all_self_attns += (layer_outputs[1],)
1010 |
1011 | hidden_states = self.norm(hidden_states)
1012 |
1013 | # add hidden states from the last decoder layer
1014 | if output_hidden_states:
1015 | all_hidden_states += (hidden_states,)
1016 |
1017 | next_cache = None
1018 | if use_cache:
1019 | next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache
1020 |
1021 | if not return_dict:
1022 | return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
1023 | return BaseModelOutputWithPast(
1024 | last_hidden_state=hidden_states,
1025 | past_key_values=next_cache,
1026 | hidden_states=all_hidden_states,
1027 | attentions=all_self_attns,
1028 | )
1029 |
1030 |
1031 | class MixtralForCausalLM(MistralPreTrainedModel):
1032 | _tied_weights_keys = ["lm_head.weight"]
1033 |
1034 | def __init__(self, config):
1035 | super().__init__(config)
1036 | self.model = MistralModel(config)
1037 | self.vocab_size = config.vocab_size
1038 | self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
1039 |
1040 | # Initialize weights and apply final processing
1041 | self.post_init()
1042 |
1043 | def get_input_embeddings(self):
1044 | return self.model.embed_tokens
1045 |
1046 | def set_input_embeddings(self, value):
1047 | self.model.embed_tokens = value
1048 |
1049 | def get_output_embeddings(self):
1050 | return self.lm_head
1051 |
1052 | def set_output_embeddings(self, new_embeddings):
1053 | self.lm_head = new_embeddings
1054 |
1055 | def set_decoder(self, decoder):
1056 | self.model = decoder
1057 |
1058 | def get_decoder(self):
1059 | return self.model
1060 |
1061 | def _init_weights(self, module):
1062 | return
1063 |
1064 | @add_start_docstrings_to_model_forward(MISTRAL_INPUTS_DOCSTRING)
1065 | @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
1066 | def forward(
1067 | self,
1068 | input_ids: torch.LongTensor = None,
1069 | attention_mask: Optional[torch.Tensor] = None,
1070 | position_ids: Optional[torch.LongTensor] = None,
1071 | past_key_values: Optional[List[torch.FloatTensor]] = None,
1072 | inputs_embeds: Optional[torch.FloatTensor] = None,
1073 | labels: Optional[torch.LongTensor] = None,
1074 | use_cache: Optional[bool] = None,
1075 | output_attentions: Optional[bool] = None,
1076 | output_hidden_states: Optional[bool] = None,
1077 | return_dict: Optional[bool] = None,
1078 | decoder_layer_idx: Optional[int] = None,
1079 | ) -> Union[Tuple, CausalLMOutputWithPast]:
1080 | r"""
1081 | Args:
1082 | labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1083 | Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
1084 | config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
1085 | (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
1086 |
1087 | Returns:
1088 |
1089 | Example:
1090 |
1091 | ```python
1092 | >>> from transformers import AutoTokenizer, MistralForCausalLM
1093 |
1094 | >>> model = MistralForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
1095 | >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
1096 |
1097 | >>> prompt = "Hey, are you conscious? Can you talk to me?"
1098 | >>> inputs = tokenizer(prompt, return_tensors="pt")
1099 |
1100 | >>> # Generate
1101 | >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
1102 | >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
1103 | "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
1104 | ```"""
1105 |
1106 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1107 | output_hidden_states = (
1108 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1109 | )
1110 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1111 |
1112 | # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
1113 | outputs = self.model(
1114 | input_ids=input_ids,
1115 | attention_mask=attention_mask,
1116 | position_ids=position_ids,
1117 | past_key_values=past_key_values,
1118 | inputs_embeds=inputs_embeds,
1119 | use_cache=use_cache,
1120 | output_attentions=output_attentions,
1121 | output_hidden_states=output_hidden_states,
1122 | return_dict=return_dict,
1123 | decoder_layer_idx=decoder_layer_idx,
1124 | )
1125 |
1126 | hidden_states = outputs[0]
1127 | logits = self.lm_head(hidden_states)
1128 | logits = logits.float()
1129 |
1130 | loss = None
1131 | if labels is not None:
1132 | # Shift so that tokens < n predict n
1133 | shift_logits = logits[..., :-1, :].contiguous()
1134 | shift_labels = labels[..., 1:].contiguous()
1135 | # Flatten the tokens
1136 | loss_fct = CrossEntropyLoss()
1137 | shift_logits = shift_logits.view(-1, self.config.vocab_size)
1138 | shift_labels = shift_labels.view(-1)
1139 | # Enable model parallelism
1140 | shift_labels = shift_labels.to(shift_logits.device)
1141 | loss = loss_fct(shift_logits, shift_labels)
1142 |
1143 | if not return_dict:
1144 | output = (logits,) + outputs[1:]
1145 | return (loss,) + output if loss is not None else output
1146 |
1147 | return CausalLMOutputWithPast(
1148 | loss=loss,
1149 | logits=logits,
1150 | past_key_values=outputs.past_key_values,
1151 | hidden_states=outputs.hidden_states,
1152 | attentions=outputs.attentions,
1153 | )
1154 |
1155 | def prepare_inputs_for_generation(
1156 | self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
1157 | ):
1158 | # Omit tokens covered by past_key_values
1159 | if past_key_values is not None:
1160 | if isinstance(past_key_values, Cache):
1161 | cache_length = past_key_values.get_seq_length()
1162 | past_length = past_key_values.seen_tokens
1163 | else:
1164 | cache_length = past_length = past_key_values[0][0].shape[2]
1165 |
1166 | # Keep only the unprocessed tokens:
1167 | # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
1168 |             # some of the inputs are exclusively passed as part of the cache (e.g. when passing inputs_embeds as
1169 | # input)
1170 | if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
1171 | input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
1172 | # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
1173 | # input_ids based on the past_length.
1174 | elif past_length < input_ids.shape[1]:
1175 | input_ids = input_ids[:, past_length:]
1176 | # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
1177 |
1178 | # If the cache has seen more tokens than it can hold, then the cache has a size limit. Let's discard the
1179 | # older attention values, as their corresponding values are not part of the input.
1180 | if cache_length < past_length and attention_mask is not None:
1181 | attention_mask = attention_mask[:, -(cache_length + input_ids.shape[1]) :]
1182 |
1183 | position_ids = kwargs.get("position_ids", None)
1184 | if attention_mask is not None and position_ids is None:
1185 | # create position_ids on the fly for batch generation
1186 | position_ids = attention_mask.long().cumsum(-1) - 1
1187 | position_ids.masked_fill_(attention_mask == 0, 1)
1188 | if past_key_values:
1189 | position_ids = position_ids[:, -input_ids.shape[1] :]
1190 |
1191 | # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
1192 | if inputs_embeds is not None and past_key_values is None:
1193 | model_inputs = {"inputs_embeds": inputs_embeds}
1194 | else:
1195 | model_inputs = {"input_ids": input_ids}
1196 |
1197 | model_inputs.update(
1198 | {
1199 | "position_ids": position_ids,
1200 | "past_key_values": past_key_values,
1201 | "use_cache": kwargs.get("use_cache"),
1202 | "attention_mask": attention_mask,
1203 | }
1204 | )
1205 | return model_inputs
1206 |
1207 | @staticmethod
1208 | def _reorder_cache(past_key_values, beam_idx):
1209 | reordered_past = ()
1210 | for layer_past in past_key_values:
1211 | reordered_past += (
1212 | tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
1213 | )
1214 | return reordered_past
1215 |
1216 |
1217 | @add_start_docstrings(
1218 | """
1219 | The Mistral Model transformer with a sequence classification head on top (linear layer).
1220 |
1221 | [`MistralForSequenceClassification`] uses the last token in order to do the classification, as other causal models
1222 | (e.g. GPT-2) do.
1223 |
1224 | Since it does classification on the last token, it requires to know the position of the last token. If a
1225 | `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
1226 | no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
1227 | padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
1228 | each row of the batch).
1229 | """,
1230 | MISTRAL_START_DOCSTRING,
1231 | )
1232 | # Copied from transformers.models.llama.modeling_llama.LlamaForSequenceClassification with Llama->Mistral, LLAMA->MISTRAL
1233 | class MistralForSequenceClassification(MistralPreTrainedModel):
1234 | def __init__(self, config):
1235 | super().__init__(config)
1236 | self.num_labels = config.num_labels
1237 | self.model = MistralModel(config)
1238 | self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
1239 |
1240 | # Initialize weights and apply final processing
1241 | self.post_init()
1242 |
1243 | def get_input_embeddings(self):
1244 | return self.model.embed_tokens
1245 |
1246 | def set_input_embeddings(self, value):
1247 | self.model.embed_tokens = value
1248 |
1249 | @add_start_docstrings_to_model_forward(MISTRAL_INPUTS_DOCSTRING)
1250 | def forward(
1251 | self,
1252 | input_ids: torch.LongTensor = None,
1253 | attention_mask: Optional[torch.Tensor] = None,
1254 | position_ids: Optional[torch.LongTensor] = None,
1255 | past_key_values: Optional[List[torch.FloatTensor]] = None,
1256 | inputs_embeds: Optional[torch.FloatTensor] = None,
1257 | labels: Optional[torch.LongTensor] = None,
1258 | use_cache: Optional[bool] = None,
1259 | output_attentions: Optional[bool] = None,
1260 | output_hidden_states: Optional[bool] = None,
1261 | return_dict: Optional[bool] = None,
1262 | ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
1263 | r"""
1264 | labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1265 | Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
1266 | config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
1267 | `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1268 | """
1269 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1270 |
1271 | transformer_outputs = self.model(
1272 | input_ids,
1273 | attention_mask=attention_mask,
1274 | position_ids=position_ids,
1275 | past_key_values=past_key_values,
1276 | inputs_embeds=inputs_embeds,
1277 | use_cache=use_cache,
1278 | output_attentions=output_attentions,
1279 | output_hidden_states=output_hidden_states,
1280 | return_dict=return_dict,
1281 | )
1282 | hidden_states = transformer_outputs[0]
1283 | logits = self.score(hidden_states)
1284 |
1285 | if input_ids is not None:
1286 | batch_size = input_ids.shape[0]
1287 | else:
1288 | batch_size = inputs_embeds.shape[0]
1289 |
1290 | if self.config.pad_token_id is None and batch_size != 1:
1291 | raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
1292 | if self.config.pad_token_id is None:
1293 | sequence_lengths = -1
1294 | else:
1295 | if input_ids is not None:
1296 | sequence_lengths = (torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1).to(
1297 | logits.device
1298 | )
1299 | else:
1300 | sequence_lengths = -1
1301 |
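        # Pool each sequence at its last non-padding token: argmax finds the first pad
        # position, -1 steps back to the token before it, and the index wraps to the
        # final position (via negative indexing) when a row contains no padding.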
1302 | pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
1303 |
1304 | loss = None
1305 | if labels is not None:
1306 | labels = labels.to(logits.device)
1307 | if self.config.problem_type is None:
1308 | if self.num_labels == 1:
1309 | self.config.problem_type = "regression"
1310 | elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
1311 | self.config.problem_type = "single_label_classification"
1312 | else:
1313 | self.config.problem_type = "multi_label_classification"
1314 |
1315 | if self.config.problem_type == "regression":
1316 | loss_fct = MSELoss()
1317 | if self.num_labels == 1:
1318 | loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
1319 | else:
1320 | loss = loss_fct(pooled_logits, labels)
1321 | elif self.config.problem_type == "single_label_classification":
1322 | loss_fct = CrossEntropyLoss()
1323 | loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
1324 | elif self.config.problem_type == "multi_label_classification":
1325 | loss_fct = BCEWithLogitsLoss()
1326 | loss = loss_fct(pooled_logits, labels)
1327 | if not return_dict:
1328 | output = (pooled_logits,) + transformer_outputs[1:]
1329 | return ((loss,) + output) if loss is not None else output
1330 |
1331 | return SequenceClassifierOutputWithPast(
1332 | loss=loss,
1333 | logits=pooled_logits,
1334 | past_key_values=transformer_outputs.past_key_values,
1335 | hidden_states=transformer_outputs.hidden_states,
1336 | attentions=transformer_outputs.attentions,
1337 | )
--------------------------------------------------------------------------------
/mixtral_base22/config.json:
--------------------------------------------------------------------------------
1 | {
2 | "architectures": [
3 | "MixtralForCausalLM"
4 | ],
5 | "attention_dropout": 0.0,
6 | "bos_token_id": 1,
7 | "eos_token_id": 2,
8 | "hidden_act": "silu",
9 | "hidden_size": 6144,
10 | "initializer_range": 0.02,
11 | "intermediate_size": 16384,
12 | "max_position_embeddings": 65536,
13 | "model_type": "mixtral",
14 | "num_attention_heads": 48,
15 | "num_experts_per_tok": 2,
16 | "num_hidden_layers": 56,
17 | "num_key_value_heads": 8,
18 | "num_experts": 8,
19 | "rms_norm_eps": 1e-05,
20 | "rope_theta": 1000000,
21 | "tie_word_embeddings": false,
22 | "torch_dtype": "bfloat16",
23 | "transformers_version": "4.38.0",
24 | "use_cache": true,
25 | "vocab_size": 32000
26 | }
27 |
--------------------------------------------------------------------------------
/mixtral_instruct/config.json:
--------------------------------------------------------------------------------
1 | {
2 | "architectures": [
3 | "MixtralForCausalLM"
4 | ],
5 | "attention_dropout": 0.0,
6 | "bos_token_id": 1,
7 | "eos_token_id": 2,
8 | "hidden_act": "silu",
9 | "hidden_size": 4096,
10 | "initializer_range": 0.02,
11 | "intermediate_size": 14336,
12 | "max_position_embeddings": 32768,
13 | "model_type": "mixtral",
14 | "num_attention_heads": 32,
15 | "num_experts_per_tok": 2,
16 | "num_hidden_layers": 32,
17 | "num_key_value_heads": 8,
18 | "num_local_experts": 8,
19 | "output_router_logits": false,
20 | "rms_norm_eps": 1e-05,
21 | "rope_theta": 1000000.0,
22 | "router_aux_loss_coef": 0.02,
23 | "sliding_window": null,
24 | "tie_word_embeddings": false,
25 | "torch_dtype": "bfloat16",
26 | "transformers_version": "4.36.0.dev0",
27 | "use_cache": true,
28 | "vocab_size": 32000
29 | }
30 |
--------------------------------------------------------------------------------
/mixtral_instruct/configuration_moe_mistral.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2023 Mistral AI and the HuggingFace Inc. team. All rights reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """ Mistral model configuration"""
16 |
17 | from transformers.configuration_utils import PretrainedConfig
18 | from transformers.utils import logging
19 |
20 |
21 | logger = logging.get_logger(__name__)
22 |
23 | MISTRAL_PRETRAINED_CONFIG_ARCHIVE_MAP = {
24 | "mistralai/Mistral-7B-v0.1": "https://huggingface.co/mistralai/Mistral-7B-v0.1/resolve/main/config.json",
25 | "mistralai/Mistral-7B-Instruct-v0.1": "https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1/resolve/main/config.json",
26 | }
27 |
28 |
29 | class MixtralConfig(PretrainedConfig):
30 | r"""
31 |     This is the configuration class to store the configuration of a [`MistralModel`]. It is used to instantiate a
32 | Mistral model according to the specified arguments, defining the model architecture. Instantiating a configuration
33 | with the defaults will yield a similar configuration to that of the Mistral-7B-v0.1 or Mistral-7B-Instruct-v0.1.
34 |
35 | [mistralai/Mistral-7B-v0.1](https://huggingface.co/mistralai/Mistral-7B-v0.1)
36 | [mistralai/Mistral-7B-Instruct-v0.1](https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1)
37 |
38 | Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
39 | documentation from [`PretrainedConfig`] for more information.
40 |
41 |
42 | Args:
43 | vocab_size (`int`, *optional*, defaults to 32000):
44 | Vocabulary size of the Mistral model. Defines the number of different tokens that can be represented by the
45 | `inputs_ids` passed when calling [`MistralModel`]
46 | hidden_size (`int`, *optional*, defaults to 4096):
47 | Dimension of the hidden representations.
48 | intermediate_size (`int`, *optional*, defaults to 14336):
49 | Dimension of the MLP representations.
50 | num_hidden_layers (`int`, *optional*, defaults to 32):
51 |             Number of hidden layers in the Transformer decoder.
52 |         num_attention_heads (`int`, *optional*, defaults to 32):
53 |             Number of attention heads for each attention layer in the Transformer decoder.
54 | num_key_value_heads (`int`, *optional*, defaults to 8):
55 | This is the number of key_value heads that should be used to implement Grouped Query Attention. If
56 | `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
57 |             `num_key_value_heads=1`, the model will use Multi Query Attention (MQA); otherwise GQA is used. When
58 |             converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
59 |             by mean-pooling all the original heads within that group. For more details, check out [this
60 | paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `8`.
61 | hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
62 | The non-linear activation function (function or string) in the decoder.
63 | max_position_embeddings (`int`, *optional*, defaults to `4096*32`):
64 | The maximum sequence length that this model might ever be used with. Mistral's sliding window attention
65 |             allows sequences of up to 4096*32 tokens.
66 | initializer_range (`float`, *optional*, defaults to 0.02):
67 | The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
68 | rms_norm_eps (`float`, *optional*, defaults to 1e-06):
69 | The epsilon used by the rms normalization layers.
70 | use_cache (`bool`, *optional*, defaults to `True`):
71 | Whether or not the model should return the last key/values attentions (not used by all models). Only
72 | relevant if `config.is_decoder=True`.
73 | pad_token_id (`int`, *optional*):
74 | The id of the padding token.
75 | bos_token_id (`int`, *optional*, defaults to 1):
76 | The id of the "beginning-of-sequence" token.
77 | eos_token_id (`int`, *optional*, defaults to 2):
78 | The id of the "end-of-sequence" token.
79 | tie_word_embeddings (`bool`, *optional*, defaults to `False`):
80 | Whether the model's input and output word embeddings should be tied.
81 | rope_theta (`float`, *optional*, defaults to 10000.0):
82 | The base period of the RoPE embeddings.
83 | sliding_window (`int`, *optional*, defaults to 4096):
84 | Sliding window attention window size. If not specified, will default to `4096`.
85 | attention_dropout (`float`, *optional*, defaults to 0.0):
86 | The dropout ratio for the attention probabilities.
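        num_experts (`int`, *optional*, defaults to 8):
            Number of experts per MoE layer (argument added in this repo's configuration).
        num_experts_per_token (`int`, *optional*, defaults to 2):
            Number of experts each token is routed to (argument added in this repo's
            configuration).
        single_expert (`bool`, *optional*, defaults to `False`):
            If `True`, bypass the router and always use the first expert (ablation switch
            added in this repo's configuration).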
87 |
88 | ```python
89 | >>> from transformers import MistralModel, MistralConfig
90 |
91 | >>> # Initializing a Mistral 7B style configuration
92 | >>> configuration = MistralConfig()
93 |
94 | >>> # Initializing a model from the Mistral 7B style configuration
95 | >>> model = MistralModel(configuration)
96 |
97 | >>> # Accessing the model configuration
98 | >>> configuration = model.config
99 | ```"""
100 |
101 | model_type = "mistral"
102 | keys_to_ignore_at_inference = ["past_key_values"]
103 |
104 | def __init__(
105 | self,
106 | vocab_size=32000,
107 | hidden_size=4096,
108 | intermediate_size=14336,
109 | num_hidden_layers=32,
110 | num_attention_heads=32,
111 | num_key_value_heads=8,
112 | hidden_act="silu",
113 | max_position_embeddings=4096 * 32,
114 | initializer_range=0.02,
115 | rms_norm_eps=1e-6,
116 | use_cache=True,
117 | pad_token_id=None,
118 | bos_token_id=1,
119 | eos_token_id=2,
120 | tie_word_embeddings=False,
121 | rope_theta=10000.0,
122 | attention_dropout=0.0,
123 | num_experts_per_token=2,
124 | num_experts=8,
125 | single_expert=False,
126 | **kwargs,
127 | ):
128 | self.vocab_size = vocab_size
129 | self.max_position_embeddings = max_position_embeddings
130 | self.hidden_size = hidden_size
131 | self.intermediate_size = intermediate_size
132 | self.num_hidden_layers = num_hidden_layers
133 | self.num_attention_heads = num_attention_heads
134 |
135 | # for backward compatibility
136 | if num_key_value_heads is None:
137 | num_key_value_heads = num_attention_heads
138 |
139 | self.num_key_value_heads = num_key_value_heads
140 | self.hidden_act = hidden_act
141 | self.initializer_range = initializer_range
142 | self.rms_norm_eps = rms_norm_eps
143 | self.use_cache = use_cache
144 | self.rope_theta = rope_theta
145 | self.attention_dropout = attention_dropout
146 | self.num_experts = num_experts
147 | self.num_experts_per_token = num_experts_per_token
148 |
149 | self.single_expert = single_expert
150 |
151 | super().__init__(
152 | pad_token_id=pad_token_id,
153 | bos_token_id=bos_token_id,
154 | eos_token_id=eos_token_id,
155 | tie_word_embeddings=tie_word_embeddings,
156 | **kwargs,
157 | )
158 |
--------------------------------------------------------------------------------