├── README.md
├── configuration_qwen2_parscale.py
├── cost_analysis.py
├── figures
│   ├── 1t.png
│   ├── cost.png
│   ├── logo.jpg
│   ├── scaling_comparison.png
│   ├── scaling_law.png
│   ├── scaling_law2.png
│   └── teaser.png
├── modeling_qwen2_parscale.py
└── parametric_fit.py

/README.md:
--------------------------------------------------------------------------------
# Parallel Scaling Law for Language Models

_Yet Another Scaling Law beyond Parameters and Inference-Time Scaling_

[![Paper](https://img.shields.io/badge/arXiv-2505.10475-red)](https://arxiv.org/abs/2505.10475)
[![huggingface](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Models-FFD21E)](https://huggingface.co/ParScale)

💡 Key Findings | 📈 Scaling Law | ⚡ Cost Analysis | 🔥 Models | 📚 Citation

## 🌟 About

- It is widely believed that scaling language models requires a heavy cost in either **space** (parameter scaling) or **time** (inference-time scaling).
- We introduce a *third* paradigm for scaling LLMs that leverages **parallel computation** during both training and inference: Parallel Scaling, or *ParScale*.
- We apply $P$ diverse, learnable transformations to the input, execute the model's forward passes on them in parallel, and dynamically aggregate the $P$ outputs.

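
The three steps in the last bullet can be sketched in plain Python. This toy is not the repository's implementation (in `modeling_qwen2_parscale.py` each stream's transformation is a set of learnable prefix key/value tokens, `prefix_k`/`prefix_v`, and the $P$ streams are stacked into one batch). Here the transformations, the shared `model`, and the stream-scoring function `score_fn` are hypothetical stand-ins, and the dynamic aggregation is assumed to be a softmax-weighted sum.

```python
import math
from typing import Callable, List, Sequence

Vector = List[float]


def softmax(scores: Sequence[float]) -> List[float]:
    """Numerically stable softmax over a list of scores."""
    peak = max(scores)
    exps = [math.exp(s - peak) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def parscale_forward(
    x: Vector,
    transforms: Sequence[Callable[[Vector], Vector]],
    model: Callable[[Vector], Vector],
    score_fn: Callable[[List[Vector]], List[float]],
) -> Vector:
    """Toy ParScale step: P transformed copies -> shared model -> weighted sum."""
    # 1. Apply P diverse transformations to the same input.
    streams = [transform(x) for transform in transforms]
    # 2. Run the *same* model (shared parameters) on every stream.
    #    A real implementation would stack the P streams into one batch.
    outputs = [model(stream) for stream in streams]
    # 3. Dynamic aggregation: the weights are computed from the outputs themselves.
    weights = softmax(score_fn(outputs))
    dim = len(outputs[0])
    return [sum(w * out[d] for w, out in zip(weights, outputs)) for d in range(dim)]


if __name__ == "__main__":
    # Hypothetical stand-ins: P = 3 streams, each shifting the input by its own offset.
    offsets = (0.0, 0.5, -0.5)
    transforms = [lambda v, b=b: [vi + b for vi in v] for b in offsets]

    def model(v: Vector) -> Vector:  # shared "network", same weights for all streams
        return [math.tanh(vi) for vi in v]

    def score_fn(outs: List[Vector]) -> List[float]:  # one scalar score per stream
        return [sum(o) for o in outs]

    print(parscale_forward([0.1, -0.2, 0.3], transforms, model, score_fn))
```

Because the weights are recomputed from the outputs for every input, the mixture is dynamic rather than a fixed average; with identical transformations and uniform scores it collapses to a single ordinary forward pass.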

---

## 💡 Key Findings

Here are the core insights and benefits distilled from our theoretical analysis and empirical evaluations:

📈 **Logarithmic Scaling Law**: We theoretically and empirically establish that **scaling with $P$ parallel streams is comparable to scaling the number of parameters by** $O(\log P)$. This suggests that parallel computation can serve as an efficient substitute for parameter growth, especially for larger models.

✅ **Universal Applicability**: Unlike inference-time scaling, which requires specialized data and has limited applicability, ParScale works with any model architecture, optimization method, data, or downstream task.

🧠 **Stronger Performance on Reasoning Tasks**: Reasoning-intensive tasks (e.g., coding or math) benefit more from ParScale, which suggests that scaling computation can effectively push the boundary of reasoning.

⚡ **Superior Inference Efficiency**: Compared with parameter scaling that achieves the same performance improvement, ParScale incurs up to **22x less memory increase** and **6x less latency increase** (batch size = 1).

🧱 **Cost-Efficient Training via Two-Stage Strategy**: Training a parallel-scaled model doesn't require starting from scratch. With a two-stage training strategy, we can post-train the parallel components using only a small amount of data.

🔁 **Dynamic Adaptation at Inference Time**: We find that ParScale remains effective for different $P$ even when the main parameters are frozen. This illustrates the potential of dynamic parallel scaling: switching $P$ to adapt model capabilities during inference.

We release the inference code in `modeling_qwen2_parscale.py` and `configuration_qwen2_parscale.py`. Our 67 checkpoints are available at [🤗 HuggingFace](https://huggingface.co/ParScale).
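
To make the $O(\log P)$ statement concrete, the sketch below assumes that a model with $N$ parameters and $P$ streams behaves like a model with $N(1 + k\ln P)$ parameters. This functional form and the value of `k` are illustrative assumptions, not constants fitted in the paper; the point is only that every doubling of $P$ adds the same amount, $k\ln 2$, to the multiplier.

```python
import math
from typing import Dict, Iterable


def equivalent_param_multiplier(p: int, k: float) -> float:
    """Hypothetical factor by which P parallel streams 'multiply' N.

    Assumes the form 1 + k * ln(P); k is an illustrative constant.
    """
    if p < 1:
        raise ValueError("P must be >= 1")
    return 1.0 + k * math.log(p)


def multiplier_table(ps: Iterable[int], k: float) -> Dict[int, float]:
    """Multiplier for each P in `ps`."""
    return {p: equivalent_param_multiplier(p, k) for p in ps}


if __name__ == "__main__":
    table = multiplier_table([1, 2, 4, 8], k=0.4)  # k = 0.4 is a made-up value
    for p, m in table.items():
        print(f"P={p}: behaves like ~{m:.3f} x N parameters")
    # Each doubling of P adds the same k * ln(2): growth is logarithmic in P.
```

Under this assumption, going from $P=4$ to $P=8$ adds the same increment to the multiplier as going from $P=1$ to $P=2$, so the gain contributed by each additional stream shrinks as $P$ grows.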

---

## 📈 Scaling Law

- We carry out large-scale pre-training experiments on the Stack-V2 and Pile corpora, varying $P$ from 1 to 8 and the model size from 500M to 4.4B parameters.
- We use the results to fit a new *parallel scaling law* that generalizes the Chinchilla scaling law.
- We release our parametric fitting code in `parametric_fit.py`.
- Try the [🤗 HuggingFace Space](https://huggingface.co/spaces/ParScale/Parallel_Scaling_Law) for a visualization of the parallel scaling law!
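
The actual fitting code is in `parametric_fit.py`. As a hedged illustration of how a law can generalize Chinchilla, the sketch keeps only the parameter term of a Chinchilla-style loss, $L = (A/N)^{\alpha} + E$, and replaces $N$ with an assumed effective count $N(1 + k\ln P)$. The form and every constant (`a`, `alpha`, `e`, `k`) are made-up placeholders rather than values from the paper; with $P = 1$ the expression reduces exactly to the original term.

```python
import math


def parscale_loss(n_params: float, p: int, a: float, alpha: float, e: float, k: float) -> float:
    """Hypothetical parallel-scaling loss.

    Chinchilla-style parameter term (a / N) ** alpha + e, with N replaced by an
    assumed effective count N * (1 + k * ln P). All constants are illustrative.
    """
    n_eff = n_params * (1.0 + k * math.log(p))  # equals n_params when p == 1
    return (a / n_eff) ** alpha + e


if __name__ == "__main__":
    consts = dict(a=1.5e8, alpha=0.35, e=1.7, k=0.4)  # made-up constants
    for p in (1, 2, 4, 8):
        print(p, round(parscale_loss(1.8e9, p, **consts), 4))
```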

---

## ⚡ Cost Analysis

- We further compare the inference efficiency of parallel scaling and parameter scaling at equivalent performance levels.
- We release our analysis code in `cost_analysis.py`. Before using it, first install [llm-analysis](https://github.com/cli99/llm-analysis):

```bash
git clone https://github.com/cli99/llm-analysis.git
cd llm-analysis
pip install .
```

- To analyze the inference memory and latency cost of our 4.4B model with $P=2$ and batch size 2, run:

```bash
python cost_analysis.py --hidden_size 2560 --intermediate_size 13824 --P 2 --batch_size 2
```

---

## 🔥 Models

Models marked with ✨ are our recommended strong models!

### Base models for scaling training data to 1T tokens

These models are highly competitive with existing small models, including SmolLM, Gemma, and Llama-3.2.

|Model|Description|Download|
|:-:|:-:|:-:|
|ParScale-1.8B-P1|✨ Baseline $P=1$|[🤗 ParScale/ParScale-1.8B-P1](https://huggingface.co/ParScale/ParScale-1.8B-P1)|
|ParScale-1.8B-P2|✨ ParScale $P=2$|[🤗 ParScale/ParScale-1.8B-P2](https://huggingface.co/ParScale/ParScale-1.8B-P2)|
|ParScale-1.8B-P4|✨ ParScale $P=4$|[🤗 ParScale/ParScale-1.8B-P4](https://huggingface.co/ParScale/ParScale-1.8B-P4)|
|ParScale-1.8B-P8|✨ ParScale $P=8$|[🤗 ParScale/ParScale-1.8B-P8](https://huggingface.co/ParScale/ParScale-1.8B-P8)|

### Instruct models for scaling training data to 1T tokens

We post-trained the aforementioned base models on SmolTalk-1M to enable conversational capabilities.

|Model|Description|Download|
|:-:|:-:|:-:|
|ParScale-1.8B-P1-Inst|✨ Baseline $P=1$|[🤗 ParScale/ParScale-1.8B-P1-Inst](https://huggingface.co/ParScale/ParScale-1.8B-P1-Inst)|
|ParScale-1.8B-P2-Inst|✨ ParScale $P=2$|[🤗 ParScale/ParScale-1.8B-P2-Inst](https://huggingface.co/ParScale/ParScale-1.8B-P2-Inst)|
|ParScale-1.8B-P4-Inst|✨ ParScale $P=4$|[🤗 ParScale/ParScale-1.8B-P4-Inst](https://huggingface.co/ParScale/ParScale-1.8B-P4-Inst)|
|ParScale-1.8B-P8-Inst|✨ ParScale $P=8$|[🤗 ParScale/ParScale-1.8B-P8-Inst](https://huggingface.co/ParScale/ParScale-1.8B-P8-Inst)|

### Continual Pretraining Qwen-2.5-3B

We froze the parameters of Qwen-2.5-3B and fine-tuned only the newly introduced parameters on Stack-V2-Python. Since the following models share the same backbone parameters as Qwen-2.5-3B, they have the potential for dynamic ParScale: switching $P$ to adapt model capabilities during inference.

|Model|Description|Download|
|:-:|:-:|:-:|
|ParScale-Qwen-3B-P2-Python|✨ ParScale $P=2$|[🤗 ParScale/ParScale-Qwen-3B-P2-Python](https://huggingface.co/ParScale/ParScale-Qwen-3B-P2-Python)|
|ParScale-Qwen-3B-P4-Python|✨ ParScale $P=4$|[🤗 ParScale/ParScale-Qwen-3B-P4-Python](https://huggingface.co/ParScale/ParScale-Qwen-3B-P4-Python)|
|ParScale-Qwen-3B-P8-Python|✨ ParScale $P=8$|[🤗 ParScale/ParScale-Qwen-3B-P8-Python](https://huggingface.co/ParScale/ParScale-Qwen-3B-P8-Python)|

- Full continual pretraining on Stack-V2-Python:

|Model|Description|Download|
|:-:|:-:|:-:|
|ParScale-QwenInit-3B-P1-Python|Baseline $P=1$|[🤗 ParScale/ParScale-QwenInit-3B-P1-Python](https://huggingface.co/ParScale/ParScale-QwenInit-3B-P1-Python)|
|ParScale-QwenInit-3B-P2-Python|ParScale $P=2$|[🤗 ParScale/ParScale-QwenInit-3B-P2-Python](https://huggingface.co/ParScale/ParScale-QwenInit-3B-P2-Python)|
|ParScale-QwenInit-3B-P4-Python|ParScale $P=4$|[🤗 ParScale/ParScale-QwenInit-3B-P4-Python](https://huggingface.co/ParScale/ParScale-QwenInit-3B-P4-Python)|
|ParScale-QwenInit-3B-P8-Python|ParScale $P=8$|[🤗 ParScale/ParScale-QwenInit-3B-P8-Python](https://huggingface.co/ParScale/ParScale-QwenInit-3B-P8-Python)|

- Full continual pretraining on Pile:

|Model|Description|Download|
|:-:|:-:|:-:|
|ParScale-QwenInit-3B-P1-Pile|Baseline $P=1$|[🤗 ParScale/ParScale-QwenInit-3B-P1-Pile](https://huggingface.co/ParScale/ParScale-QwenInit-3B-P1-Pile)|
|ParScale-QwenInit-3B-P2-Pile|ParScale $P=2$|[🤗 ParScale/ParScale-QwenInit-3B-P2-Pile](https://huggingface.co/ParScale/ParScale-QwenInit-3B-P2-Pile)|
|ParScale-QwenInit-3B-P4-Pile|ParScale $P=4$|[🤗 ParScale/ParScale-QwenInit-3B-P4-Pile](https://huggingface.co/ParScale/ParScale-QwenInit-3B-P4-Pile)|
|ParScale-QwenInit-3B-P8-Pile|ParScale $P=8$|[🤗 ParScale/ParScale-QwenInit-3B-P8-Pile](https://huggingface.co/ParScale/ParScale-QwenInit-3B-P8-Pile)|

### Checkpoints Used to Fit the Scaling Law

Download link template: `https://huggingface.co/ParScale/ParScale-{size}-{P}-{dataset}`

- {size}: model size, from {0.7B, 0.9B, 1.3B, 1.8B, 3B, 4.7B}
- {P}: number of parallel streams, from {P1, P2, P4, P8}
- {dataset}: training dataset, from {Python, Pile}
- $6\times 4 \times 2=48$ checkpoints in total.
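
The template above expands mechanically into Hub repository IDs. The helper below only builds the names from the listed values; it assumes every combination in the template is published and does not check that the repositories exist. Each ID can then be loaded like any other checkpoint, e.g. with `AutoModelForCausalLM.from_pretrained(repo_id, trust_remote_code=True)`.

```python
from itertools import product
from typing import List

# Values copied from the template description above.
SIZES = ["0.7B", "0.9B", "1.3B", "1.8B", "3B", "4.7B"]
PARALLELS = ["P1", "P2", "P4", "P8"]
DATASETS = ["Python", "Pile"]


def scaling_law_repo_ids() -> List[str]:
    """Expand ParScale-{size}-{P}-{dataset} into Hugging Face repo IDs."""
    return [
        f"ParScale/ParScale-{size}-{p}-{dataset}"
        for size, p, dataset in product(SIZES, PARALLELS, DATASETS)
    ]


if __name__ == "__main__":
    ids = scaling_law_repo_ids()
    print(len(ids))
    print(f"https://huggingface.co/{ids[0]}")
```

Running it prints `48`, matching the count above, followed by the URL of the first checkpoint in the grid.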

### Usage Example with 🤗 Hugging Face

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

name = "ParScale/ParScale-1.8B-P8"  # or any other ParScale checkpoint listed above
model = AutoModelForCausalLM.from_pretrained(name, trust_remote_code=True).to("cuda")
tokenizer = AutoTokenizer.from_pretrained(name)
inputs = tokenizer.encode("Hello, how are you today?", return_tensors="pt").to("cuda")
outputs = model.generate(inputs, max_new_tokens=128)[0]
print(tokenizer.decode(outputs))
```

## 📚 Citation

```bibtex
@article{ParScale,
  title={Parallel Scaling Law for Language Models},
  author={Mouxiang Chen and Binyuan Hui and Zeyu Cui and Jiaxi Yang and Dayiheng Liu and Jianling Sun and Junyang Lin and Zhongxin Liu},
  year={2025},
  eprint={2505.10475},
  archivePrefix={arXiv},
  primaryClass={cs.LG},
  journal={arXiv preprint arXiv:2505.10475},
  url={https://arxiv.org/abs/2505.10475},
}
```

--------------------------------------------------------------------------------
/configuration_qwen2_parscale.py:
--------------------------------------------------------------------------------
1 | """Qwen2 model configuration, with support for ParScale""" 2 | 3 | from transformers.configuration_utils import PretrainedConfig 4 | from transformers.modeling_rope_utils import rope_config_validation 5 | from transformers.utils import logging 6 | 7 | 8 | logger = logging.get_logger(__name__) 9 | 10 | 11 | class Qwen2ParScaleConfig(PretrainedConfig): 12 | r""" 13 | This is the configuration class to store the configuration of a [`Qwen2Model`]. It is used to instantiate a 14 | Qwen2 model according to the specified arguments, defining the model architecture.
Instantiating a configuration 15 | with the defaults will yield a similar configuration to that of 16 | Qwen2-7B-beta [Qwen/Qwen2-7B-beta](https://huggingface.co/Qwen/Qwen2-7B-beta). 17 | 18 | Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the 19 | documentation from [`PretrainedConfig`] for more information. 20 | 21 | 22 | Args: 23 | vocab_size (`int`, *optional*, defaults to 151936): 24 | Vocabulary size of the Qwen2 model. Defines the number of different tokens that can be represented by the 25 | `inputs_ids` passed when calling [`Qwen2Model`] 26 | hidden_size (`int`, *optional*, defaults to 4096): 27 | Dimension of the hidden representations. 28 | intermediate_size (`int`, *optional*, defaults to 22016): 29 | Dimension of the MLP representations. 30 | num_hidden_layers (`int`, *optional*, defaults to 32): 31 | Number of hidden layers in the Transformer encoder. 32 | num_attention_heads (`int`, *optional*, defaults to 32): 33 | Number of attention heads for each attention layer in the Transformer encoder. 34 | num_key_value_heads (`int`, *optional*, defaults to 32): 35 | This is the number of key_value heads that should be used to implement Grouped Query Attention. If 36 | `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if 37 | `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When 38 | converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed 39 | by meanpooling all the original heads within that group. For more details checkout [this 40 | paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `32`. 41 | hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): 42 | The non-linear activation function (function or string) in the decoder. 
43 | max_position_embeddings (`int`, *optional*, defaults to 32768): 44 | The maximum sequence length that this model might ever be used with. 45 | initializer_range (`float`, *optional*, defaults to 0.02): 46 | The standard deviation of the truncated_normal_initializer for initializing all weight matrices. 47 | rms_norm_eps (`float`, *optional*, defaults to 1e-06): 48 | The epsilon used by the rms normalization layers. 49 | use_cache (`bool`, *optional*, defaults to `True`): 50 | Whether or not the model should return the last key/values attentions (not used by all models). Only 51 | relevant if `config.is_decoder=True`. 52 | tie_word_embeddings (`bool`, *optional*, defaults to `False`): 53 | Whether the model's input and output word embeddings should be tied. 54 | rope_theta (`float`, *optional*, defaults to 10000.0): 55 | The base period of the RoPE embeddings. 56 | rope_scaling (`Dict`, *optional*): 57 | Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type 58 | and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value 59 | accordingly. 60 | Expected contents: 61 | `rope_type` (`str`): 62 | The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope', 63 | 'llama3'], with 'default' being the original RoPE implementation. 64 | `factor` (`float`, *optional*): 65 | Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In 66 | most scaling types, a `factor` of x will enable the model to handle sequences of length x * 67 | original maximum pre-trained length. 68 | `original_max_position_embeddings` (`int`, *optional*): 69 | Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during 70 | pretraining. 71 | `attention_factor` (`float`, *optional*): 72 | Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention 73 | computation. 
If unspecified, it defaults to value recommended by the implementation, using the 74 | `factor` field to infer the suggested value. 75 | `beta_fast` (`float`, *optional*): 76 | Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear 77 | ramp function. If unspecified, it defaults to 32. 78 | `beta_slow` (`float`, *optional*): 79 | Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear 80 | ramp function. If unspecified, it defaults to 1. 81 | `short_factor` (`List[float]`, *optional*): 82 | Only used with 'longrope'. The scaling factor to be applied to short contexts (< 83 | `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden 84 | size divided by the number of attention heads divided by 2 85 | `long_factor` (`List[float]`, *optional*): 86 | Only used with 'longrope'. The scaling factor to be applied to long contexts (< 87 | `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden 88 | size divided by the number of attention heads divided by 2 89 | `low_freq_factor` (`float`, *optional*): 90 | Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE 91 | `high_freq_factor` (`float`, *optional*): 92 | Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE 93 | use_sliding_window (`bool`, *optional*, defaults to `False`): 94 | Whether to use sliding window attention. 95 | sliding_window (`int`, *optional*, defaults to 4096): 96 | Sliding window attention (SWA) window size. If not specified, will default to `4096`. 97 | max_window_layers (`int`, *optional*, defaults to 28): 98 | The number of layers that use SWA (Sliding Window Attention). The bottom layers use SWA while the top use full attention. 99 | attention_dropout (`float`, *optional*, defaults to 0.0): 100 | The dropout ratio for the attention probabilities. 
101 | 102 | ```python 103 | >>> from transformers import Qwen2Model, Qwen2Config 104 | 105 | >>> # Initializing a Qwen2 style configuration 106 | >>> configuration = Qwen2Config() 107 | 108 | >>> # Initializing a model from the Qwen2-7B style configuration 109 | >>> model = Qwen2Model(configuration) 110 | 111 | >>> # Accessing the model configuration 112 | >>> configuration = model.config 113 | ```""" 114 | 115 | model_type = "qwen2_parscale" 116 | keys_to_ignore_at_inference = ["past_key_values"] 117 | 118 | # Default tensor parallel plan for base model `Qwen2` 119 | base_model_tp_plan = { 120 | "layers.*.self_attn.q_proj": "colwise", 121 | "layers.*.self_attn.k_proj": "colwise", 122 | "layers.*.self_attn.v_proj": "colwise", 123 | "layers.*.self_attn.o_proj": "rowwise", 124 | "layers.*.mlp.gate_proj": "colwise", 125 | "layers.*.mlp.up_proj": "colwise", 126 | "layers.*.mlp.down_proj": "rowwise", 127 | } 128 | 129 | def __init__( 130 | self, 131 | vocab_size=151936, 132 | hidden_size=4096, 133 | intermediate_size=22016, 134 | num_hidden_layers=32, 135 | num_attention_heads=32, 136 | num_key_value_heads=32, 137 | hidden_act="silu", 138 | max_position_embeddings=32768, 139 | initializer_range=0.02, 140 | rms_norm_eps=1e-6, 141 | use_cache=True, 142 | tie_word_embeddings=False, 143 | rope_theta=10000.0, 144 | rope_scaling=None, 145 | use_sliding_window=False, 146 | sliding_window=4096, 147 | max_window_layers=28, 148 | attention_dropout=0.0, 149 | parscale_n=1, 150 | parscale_n_tokens=48, 151 | parscale_attn_smooth=0.01, 152 | **kwargs, 153 | ): 154 | self.vocab_size = vocab_size 155 | self.max_position_embeddings = max_position_embeddings 156 | self.hidden_size = hidden_size 157 | self.intermediate_size = intermediate_size 158 | self.num_hidden_layers = num_hidden_layers 159 | self.num_attention_heads = num_attention_heads 160 | self.use_sliding_window = use_sliding_window 161 | self.sliding_window = sliding_window if use_sliding_window else None 162 | 
self.max_window_layers = max_window_layers 163 | self.parscale_n = parscale_n 164 | self.parscale_n_tokens = parscale_n_tokens 165 | self.parscale_attn_smooth = parscale_attn_smooth 166 | 167 | # for backward compatibility 168 | if num_key_value_heads is None: 169 | num_key_value_heads = num_attention_heads 170 | 171 | self.num_key_value_heads = num_key_value_heads 172 | self.hidden_act = hidden_act 173 | self.initializer_range = initializer_range 174 | self.rms_norm_eps = rms_norm_eps 175 | self.use_cache = use_cache 176 | self.rope_theta = rope_theta 177 | self.rope_scaling = rope_scaling 178 | self.attention_dropout = attention_dropout 179 | # Validate the correctness of rotary position embeddings parameters 180 | # BC: if there is a 'type' field, move it to 'rope_type'. 181 | if self.rope_scaling is not None and "type" in self.rope_scaling: 182 | self.rope_scaling["rope_type"] = self.rope_scaling["type"] 183 | rope_config_validation(self) 184 | 185 | super().__init__( 186 | tie_word_embeddings=tie_word_embeddings, 187 | **kwargs, 188 | ) 189 | -------------------------------------------------------------------------------- /cost_analysis.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import json 3 | import os 4 | from llm_analysis.analysis import LLMAnalysis, get_gpu_config_by_name, ModelConfig, ActivationRecomputation, BYTES_FP16 5 | 6 | 7 | if __name__ == "__main__": 8 | import argparse 9 | 10 | parser = argparse.ArgumentParser() 11 | 12 | # General model config 13 | parser.add_argument('--hidden_size', type=int, required=True) 14 | parser.add_argument('--intermediate_size', type=int, required=True) 15 | parser.add_argument('--num_hidden_layers', type=int, default=36) 16 | parser.add_argument('--num_attention_heads', type=int, default=16) 17 | parser.add_argument('--max_position_embeddings', type=int, default=2048) 18 | parser.add_argument('--num_key_value_heads', type=int, default=2) 19 | 
parser.add_argument('--vocab_size', type=int, default=151936) 20 | 21 | # Parscale config 22 | parser.add_argument('--P', type=int, default=1) # Number of parallel streams 23 | parser.add_argument('--parscale_prefix_tokens', type=int, default=48) # Number of prefix tokens 24 | 25 | # Data config 26 | parser.add_argument('--batch_size', type=int, default=1) 27 | parser.add_argument('--input_length', type=int, default=64) 28 | parser.add_argument('--output_length', type=int, default=64) 29 | 30 | # GPU config 31 | parser.add_argument('--gpu_config', type=str, default="a100-sxm-80gb") 32 | parser.add_argument('--flops_efficiency', type=float, default=0.7) # Recommended by llm-analysis 33 | parser.add_argument('--hbm_memory_efficiency', type=float, default=0.9) # Recommended by llm-analysis 34 | 35 | args = parser.parse_args() 36 | p = args.P 37 | model_config = ModelConfig( 38 | name="", 39 | num_layers=args.num_hidden_layers, 40 | n_head=args.num_attention_heads, 41 | hidden_dim=args.hidden_size, vocab_size=args.vocab_size, 42 | max_seq_len=args.max_position_embeddings + (args.parscale_prefix_tokens if p > 1 else 0), 43 | num_key_value_heads=args.num_key_value_heads, 44 | ffn_embed_dim=args.intermediate_size, 45 | mlp_gated_linear_units=True 46 | ) 47 | gpu_config = get_gpu_config_by_name(args.gpu_config) 48 | gpu_config.mem_per_GPU_in_GB = 10000 49 | 50 | analysis = LLMAnalysis( 51 | model_config, 52 | gpu_config, 53 | flops_efficiency=args.flops_efficiency, 54 | hbm_memory_efficiency=args.hbm_memory_efficiency, 55 | ) 56 | seq_len = args.input_length + (args.parscale_prefix_tokens if p > 1 else 0) 57 | summary_dict = analysis.inference( 58 | batch_size_per_gpu=args.batch_size * p, 59 | seq_len=seq_len, 60 | num_tokens_to_generate=args.output_length, 61 | ) 62 | 63 | # We consider the influence of the aggregation layer.
64 | aggregate_param = (args.hidden_size + 1) * args.hidden_size * p if p > 1 else 0 65 | aggregate_param_vs_fwd_param = aggregate_param / analysis.get_num_params_per_layer_mlp() 66 | aggregate_latency = aggregate_param_vs_fwd_param * analysis.get_latency_fwd_per_layer_mlp(args.batch_size, args.input_length + args.output_length) if p > 1 else 0 67 | aggregate_memory = aggregate_param * analysis.dtype_config.weight_bits / 8 68 | 69 | prefill_activation_memory_per_gpu = max( 70 | # Each layer's activation memory will increase by P times 71 | analysis.get_activation_memory_per_layer( 72 | args.batch_size * p, 73 | seq_len, 74 | is_inference=True, 75 | layernorm_dtype_bytes=BYTES_FP16, 76 | ), 77 | # The embedding's activation memory will not participate in parallel and independent of P. 78 | analysis.get_activation_memory_output_embedding( 79 | args.batch_size, seq_len 80 | ) 81 | ) 82 | 83 | # Since we use batch_size * p as the new batch size, the latency for llm-analysis assumes the embedding latency is also computed in this new batch size. However, ParScale will not increase the computation for embedding. 84 | # Therefore, we should make a fix toward it. 
85 | embedding_latency_estimate_for_embedding = ( 86 | analysis.get_latency_fwd_input_embedding(args.batch_size * p, args.input_length + args.output_length, dtype_bytes=analysis.dtype_config.embedding_bits) + 87 | analysis.get_latency_fwd_output_embedding_loss(args.batch_size * p, args.input_length + args.output_length) 88 | ) 89 | embedding_latency_real_for_embedding = ( 90 | analysis.get_latency_fwd_input_embedding(args.batch_size, args.input_length + args.output_length, dtype_bytes=analysis.dtype_config.embedding_bits) + 91 | analysis.get_latency_fwd_output_embedding_loss(args.batch_size, args.input_length + args.output_length) 92 | ) 93 | 94 | total_memory = ( 95 | summary_dict['kv_cache_memory_per_gpu'] + 96 | summary_dict['weight_memory_per_gpu'] + 97 | aggregate_memory + 98 | prefill_activation_memory_per_gpu 99 | ) 100 | total_latency = ( 101 | summary_dict['total_latency'] + aggregate_latency 102 | - embedding_latency_estimate_for_embedding 103 | + embedding_latency_real_for_embedding 104 | ) 105 | print(f"Memory: {total_memory / 2**30:.3f}GB; Latency: {total_latency:.3f}s") -------------------------------------------------------------------------------- /figures/1t.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/QwenLM/ParScale/cd6acb48ba6d3b1f9715785ff0d212751210177d/figures/1t.png -------------------------------------------------------------------------------- /figures/cost.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/QwenLM/ParScale/cd6acb48ba6d3b1f9715785ff0d212751210177d/figures/cost.png -------------------------------------------------------------------------------- /figures/logo.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/QwenLM/ParScale/cd6acb48ba6d3b1f9715785ff0d212751210177d/figures/logo.jpg 
-------------------------------------------------------------------------------- /figures/scaling_comparison.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/QwenLM/ParScale/cd6acb48ba6d3b1f9715785ff0d212751210177d/figures/scaling_comparison.png -------------------------------------------------------------------------------- /figures/scaling_law.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/QwenLM/ParScale/cd6acb48ba6d3b1f9715785ff0d212751210177d/figures/scaling_law.png -------------------------------------------------------------------------------- /figures/scaling_law2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/QwenLM/ParScale/cd6acb48ba6d3b1f9715785ff0d212751210177d/figures/scaling_law2.png -------------------------------------------------------------------------------- /figures/teaser.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/QwenLM/ParScale/cd6acb48ba6d3b1f9715785ff0d212751210177d/figures/teaser.png -------------------------------------------------------------------------------- /modeling_qwen2_parscale.py: -------------------------------------------------------------------------------- 1 | """ 2 | This is the inference code for ParScale, Based on Qwen2. It can be used directly to load existing Qwen2 models (setting parscale_n = 1 by default). 3 | All modifications are wrapped within the condition 'parscale_n > 1'. 4 | If you are interested in how ParScale is implemented, please search for "parscale_n" in this file. 
5 | """ 6 | 7 | from typing import Callable, List, Optional, Tuple, Union 8 | 9 | import torch 10 | from torch import nn 11 | from einops import repeat, rearrange 12 | 13 | from transformers.activations import ACT2FN 14 | from transformers.cache_utils import Cache, DynamicCache, StaticCache 15 | from transformers.generation import GenerationMixin 16 | from transformers.modeling_attn_mask_utils import AttentionMaskConverter 17 | from transformers.modeling_flash_attention_utils import FlashAttentionKwargs 18 | from transformers.modeling_outputs import ( 19 | BaseModelOutputWithPast, 20 | CausalLMOutputWithPast, 21 | QuestionAnsweringModelOutput, 22 | SequenceClassifierOutputWithPast, 23 | TokenClassifierOutput, 24 | ) 25 | from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS 26 | from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel 27 | from transformers.processing_utils import Unpack 28 | from transformers.utils import ( 29 | LossKwargs, 30 | add_code_sample_docstrings, 31 | add_start_docstrings, 32 | add_start_docstrings_to_model_forward, 33 | logging, 34 | replace_return_docstrings, 35 | ) 36 | from .configuration_qwen2_parscale import Qwen2ParScaleConfig 37 | from typing import Any, Dict, List, Optional, Tuple, Union 38 | 39 | 40 | logger = logging.get_logger(__name__) 41 | 42 | _CHECKPOINT_FOR_DOC = "meta-qwen2/Qwen2-2-7b-hf" 43 | _CONFIG_FOR_DOC = "Qwen2ParScaleConfig" 44 | 45 | 46 | class Qwen2MLP(nn.Module): 47 | def __init__(self, config): 48 | super().__init__() 49 | self.config = config 50 | self.hidden_size = config.hidden_size 51 | self.intermediate_size = config.intermediate_size 52 | self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) 53 | self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) 54 | self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) 55 | self.act_fn = ACT2FN[config.hidden_act] 56 | 57 | def forward(self, x): 58 | 
down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) 59 | return down_proj 60 | 61 | 62 | def rotate_half(x): 63 | """Rotates half the hidden dims of the input.""" 64 | x1 = x[..., : x.shape[-1] // 2] 65 | x2 = x[..., x.shape[-1] // 2 :] 66 | return torch.cat((-x2, x1), dim=-1) 67 | 68 | 69 | def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): 70 | """Applies Rotary Position Embedding to the query and key tensors. 71 | 72 | Args: 73 | q (`torch.Tensor`): The query tensor. 74 | k (`torch.Tensor`): The key tensor. 75 | cos (`torch.Tensor`): The cosine part of the rotary embedding. 76 | sin (`torch.Tensor`): The sine part of the rotary embedding. 77 | position_ids (`torch.Tensor`, *optional*): 78 | Deprecated and unused. 79 | unsqueeze_dim (`int`, *optional*, defaults to 1): 80 | The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and 81 | sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note 82 | that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and 83 | k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes 84 | cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have 85 | the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. 86 | Returns: 87 | `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. 88 | """ 89 | cos = cos.unsqueeze(unsqueeze_dim) 90 | sin = sin.unsqueeze(unsqueeze_dim) 91 | q_embed = (q * cos) + (rotate_half(q) * sin) 92 | k_embed = (k * cos) + (rotate_half(k) * sin) 93 | return q_embed, k_embed 94 | 95 | 96 | def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: 97 | """ 98 | This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). 
The hidden states go from (batch, 99 | num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) 100 | """ 101 | batch, num_key_value_heads, slen, head_dim = hidden_states.shape 102 | if n_rep == 1: 103 | return hidden_states 104 | hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) 105 | return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) 106 | 107 | 108 | def eager_attention_forward( 109 | module: nn.Module, 110 | query: torch.Tensor, 111 | key: torch.Tensor, 112 | value: torch.Tensor, 113 | attention_mask: Optional[torch.Tensor], 114 | scaling: float, 115 | dropout: float = 0.0, 116 | **kwargs, 117 | ): 118 | key_states = repeat_kv(key, module.num_key_value_groups) 119 | value_states = repeat_kv(value, module.num_key_value_groups) 120 | 121 | attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling 122 | if attention_mask is not None: 123 | causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] 124 | attn_weights = attn_weights + causal_mask 125 | 126 | attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) 127 | attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) 128 | attn_output = torch.matmul(attn_weights, value_states) 129 | attn_output = attn_output.transpose(1, 2).contiguous() 130 | 131 | return attn_output, attn_weights 132 | 133 | class ParscaleCache(DynamicCache): 134 | def __init__(self, prefix_k, prefix_v) -> None: 135 | super().__init__() 136 | self._seen_tokens = 0 # Used in `generate` to keep tally of how many tokens the cache has seen 137 | self.key_cache: List[torch.Tensor] = prefix_k 138 | self.value_cache: List[torch.Tensor] = prefix_v 139 | self.parscale_n = prefix_k[0].size(0) 140 | self.n_prefix_tokens = prefix_k[0].size(2) 141 | def update( 142 | self, 143 | key_states: torch.Tensor, 144 | value_states: torch.Tensor, 145 | 
        layer_idx: int,
        cache_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        if self.key_cache[layer_idx].size(0) != key_states.size(0):
            # First update for this layer: repeat each stream's prefix (leading dim parscale_n) across the batch
            # so that it matches the (parscale_n * batch) layout of key_states.
            self.key_cache[layer_idx] = repeat(self.key_cache[layer_idx], 'n_parscale ... -> (n_parscale b) ...', b=key_states.size(0) // self.parscale_n)
            self.value_cache[layer_idx] = repeat(self.value_cache[layer_idx], 'n_parscale ... -> (n_parscale b) ...', b=key_states.size(0) // self.parscale_n)
        return super().update(key_states, value_states, layer_idx, cache_kwargs)

    def get_seq_length(self, layer_idx=0):
        seq_len = super().get_seq_length(layer_idx)
        # Exclude the prefix tokens stored at the front of the cache
        if seq_len != 0:
            seq_len -= self.n_prefix_tokens
        return seq_len

    def reorder_cache(self, beam_idx: torch.LongTensor):
        """Reorders the cache for beam search, given the selected beam indices."""
        b = self.key_cache[0].size(0) // self.parscale_n
        beam_idx = torch.cat([beam_idx + b * i for i in range(self.parscale_n)])
        super().reorder_cache(beam_idx)


class Qwen2Attention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(self, config: Qwen2ParScaleConfig, layer_idx: int):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
        self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
        self.scaling = self.head_dim**-0.5
        self.attention_dropout = config.attention_dropout
        self.is_causal = True
        self.q_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=True)
        self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=True)
        self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=True)
        self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False)
        if config.parscale_n > 1:
            self.prefix_k = nn.Parameter(torch.empty((config.parscale_n, config.num_key_value_heads, config.parscale_n_tokens, self.head_dim)))
            self.prefix_v = nn.Parameter(torch.empty((config.parscale_n, config.num_key_value_heads, config.parscale_n_tokens, self.head_dim)))


    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: Tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor],
        past_key_value: Optional[Cache] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        input_shape = hidden_states.shape[:-1]
        hidden_shape = (*input_shape, -1, self.head_dim)

        query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)

        cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        if past_key_value is not None:
            # sin and cos are specific to RoPE models; cache_position needed for the static cache
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

        if self.config.parscale_n > 1:

            # Prepend mask columns for the prefix tokens (zeros, so the prefix is never masked out)
            n_virtual_tokens = self.config.parscale_n_tokens

            if attention_mask is not None:
                attention_mask = torch.cat([
                    torch.zeros((attention_mask.shape[0], attention_mask.shape[1], attention_mask.shape[2], self.config.parscale_n_tokens), dtype=attention_mask.dtype, device=attention_mask.device),
                    attention_mask
                ], dim=3)

            # With more than one query (e.g. prefill), prepend zero-valued queries for the prefix positions
            # (with unmasked mask rows); their outputs are sliced off after attention.
            if query_states.size(2) != 1:
                query_states = torch.cat([torch.zeros([query_states.size(0), query_states.size(1), n_virtual_tokens, query_states.size(3)], dtype=query_states.dtype, device=query_states.device), query_states], dim=2)
                if attention_mask is not None:
                    attention_mask = torch.cat([
                        torch.zeros((attention_mask.shape[0], attention_mask.shape[1], self.config.parscale_n_tokens, attention_mask.shape[3]), dtype=attention_mask.dtype, device=attention_mask.device),
                        attention_mask
                    ], dim=2)

        sliding_window = None
        if (
            self.config.use_sliding_window
            and getattr(self.config, "sliding_window", None) is not None
            and self.layer_idx >= self.config.max_window_layers
        ):
            sliding_window = self.config.sliding_window

        attention_interface: Callable = eager_attention_forward
        if self.config._attn_implementation != "eager":
            if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False):
                logger.warning_once(
                    "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
                    'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
                )
            else:
                attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

        attn_output, attn_weights = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask,
            dropout=0.0 if not self.training else self.attention_dropout,
            scaling=self.scaling,
            sliding_window=sliding_window,  # main diff with Llama
            # is_causal=True,
            **kwargs,
        )

        if self.config.parscale_n > 1 and query_states.size(2) != 1:
            # Remove the prefix part
            attn_output = attn_output[:, n_virtual_tokens:]
        attn_output = attn_output.reshape(*input_shape, -1).contiguous()
        attn_output = self.o_proj(attn_output)
        return attn_output, attn_weights


class Qwen2RMSNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        """
        Qwen2RMSNorm is equivalent to T5LayerNorm
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * hidden_states.to(input_dtype)

    def extra_repr(self):
        return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"


class Qwen2DecoderLayer(nn.Module):
    def __init__(self, config: Qwen2ParScaleConfig, layer_idx: int):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.self_attn = Qwen2Attention(config=config, layer_idx=layer_idx)
        self.mlp = Qwen2MLP(config)
        self.input_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        if config.sliding_window and config._attn_implementation != "flash_attention_2":
            logger.warning_once(
                f"Sliding Window Attention is enabled but not implemented for `{config._attn_implementation}`; "
                "unexpected results may be encountered."
            )

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
        residual = hidden_states

        hidden_states = self.input_layernorm(hidden_states)

        # Self Attention
        hidden_states, self_attn_weights = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
            cache_position=cache_position,
            position_embeddings=position_embeddings,
            **kwargs,
        )
        hidden_states = residual + hidden_states

        # Fully Connected
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        outputs = (hidden_states,)
        if output_attentions:
            outputs += (self_attn_weights,)

        return outputs


class Qwen2RotaryEmbedding(nn.Module):
    def __init__(self, config: Qwen2ParScaleConfig, device=None):
        super().__init__()
        # BC: "rope_type" was originally "type"
        if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
            self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
        else:
            self.rope_type = "default"
        self.max_seq_len_cached = config.max_position_embeddings
        self.original_max_seq_len = config.max_position_embeddings

        self.config = config
        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]

        inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self.original_inv_freq = self.inv_freq

    def _dynamic_frequency_update(self, position_ids, device):
        """
        dynamic RoPE layers should recompute `inv_freq` in the following situations:
        1 - growing beyond the cached sequence length (allow scaling)
        2 - the current sequence length is in the original scale (avoid losing precision with small sequences)
        """
        seq_len = torch.max(position_ids) + 1
        if seq_len > self.max_seq_len_cached:  # growth
            inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, seq_len=seq_len)
            self.register_buffer("inv_freq", inv_freq, persistent=False)  # TODO joao: may break with compilation
            self.max_seq_len_cached = seq_len

        if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len:  # reset
            # This .to() is needed if the model has been moved to a device after being initialized (because
            # the buffer is automatically moved, but not the original copy)
            self.original_inv_freq = self.original_inv_freq.to(device)
            self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
            self.max_seq_len_cached = self.original_max_seq_len

    @torch.no_grad()
    def forward(self, x, position_ids):
        if "dynamic" in self.rope_type:
            self._dynamic_frequency_update(position_ids, device=x.device)

        # Core RoPE block
        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
        position_ids_expanded = position_ids[:, None, :].float()
        # Force float32 (see https://github.com/huggingface/transformers/pull/29285)
        device_type = x.device.type
        device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
        with torch.autocast(device_type=device_type, enabled=False):
            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
            emb = torch.cat((freqs, freqs), dim=-1)
            cos = emb.cos()
            sin = emb.sin()

        # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention
        cos = cos * self.attention_scaling
        sin = sin * self.attention_scaling

        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)


QWEN2_START_DOCSTRING = r"""
    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
    etc.)

    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage
    and behavior.

    Parameters:
        config ([`Qwen2ParScaleConfig`]):
            Model configuration class with all the parameters of the model. Initializing with a config file does not
            load the weights associated with the model, only the configuration. Check out the
            [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""


@add_start_docstrings(
    "The bare Qwen2 Model outputting raw hidden-states without any specific head on top.",
    QWEN2_START_DOCSTRING,
)
class Qwen2PreTrainedModel(PreTrainedModel):
    config_class = Qwen2ParScaleConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["Qwen2DecoderLayer"]
    _skip_keys_device_placement = ["past_key_values"]
    _supports_flash_attn_2 = True
    _supports_sdpa = True
    _supports_flex_attn = True
    _supports_cache_class = True
    _supports_quantized_cache = True
    _supports_static_cache = True

    def _init_weights(self, module):
        std = self.config.initializer_range
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()


QWEN2_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
            it.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            If `past_key_values` is used, optionally only the last `input_ids` have to be input (see
            `past_key_values`).

            If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
            and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
            information on the default strategy.
        position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Indices of positions of each input sequence token in the position embeddings. Selected in the range `[0,
            config.max_position_embeddings - 1]`.

            [What are position IDs?](../glossary#position-ids)
        past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
            Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
            blocks) that can be used to speed up sequential decoding. This typically consists of the `past_key_values`
            returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.

            Two formats are allowed:
            - a [`~cache_utils.Cache`] instance, see our
            [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache);
            - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
            shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy
            cache format.

            If no `past_key_values` are passed, the model creates a new cache and returns it when `use_cache=True`: a
            `ParscaleCache` holding the learned prefix states when `config.parscale_n > 1`, otherwise a `DynamicCache`.

            If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
            have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
            of shape `(batch_size, sequence_length)`.
        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
            is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
            model's internal embedding lookup matrix.
        use_cache (`bool`, *optional*):
            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
            `past_key_values`).
        output_attentions (`bool`, *optional*):
            Whether or not to return the attention tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
        cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
            Indices depicting the position of the input sequence tokens in the sequence. Unlike `position_ids`,
            this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
            the complete sequence length.
"""


@add_start_docstrings(
    "The bare Qwen2 Model outputting raw hidden-states without any specific head on top.",
    QWEN2_START_DOCSTRING,
)
class Qwen2Model(Qwen2PreTrainedModel):
    """
    Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`Qwen2DecoderLayer`].

    Args:
        config: Qwen2ParScaleConfig
    """

    def __init__(self, config: Qwen2ParScaleConfig):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        self.layers = nn.ModuleList(
            [Qwen2DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )
        self.norm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.rotary_emb = Qwen2RotaryEmbedding(config=config)
        self.gradient_checkpointing = False

        self.parscale_n = config.parscale_n
        if config.parscale_n > 1:
            self.aggregate_layer = torch.nn.Sequential(
                torch.nn.Linear(config.parscale_n * config.hidden_size, config.hidden_size),
                torch.nn.SiLU(),
                torch.nn.Linear(config.hidden_size, config.parscale_n)
            )
            self.parscale_aggregate_attn_smoothing = config.parscale_attn_smooth

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.embed_tokens

    def set_input_embeddings(self, value):
        self.embed_tokens = value

    @add_start_docstrings_to_model_forward(QWEN2_INPUTS_DOCSTRING)
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **flash_attn_kwargs: Unpack[FlashAttentionKwargs],
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        if self.gradient_checkpointing and self.training and use_cache:
            logger.warning_once(
                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
            )
            use_cache = False

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        if self.parscale_n > 1:
            # Input transformation: replicate the input parscale_n times along the batch dimension.
            # The stream-specific transformations come from the learned prefix key/value states in ParscaleCache.
            inputs_embeds = repeat(inputs_embeds, "b s h -> (n_parscale b) s h", n_parscale=self.parscale_n)
            if attention_mask is not None:
                attention_mask = repeat(attention_mask, "b s -> (n_parscale b) s", n_parscale=self.parscale_n)
            if position_ids is not None:
                position_ids = repeat(position_ids, "b s -> (n_parscale b) s", n_parscale=self.parscale_n)

            # The learned prefixes are stored in layer.self_attn.prefix_k / layer.self_attn.prefix_v;
            # collect them to build a ParscaleCache.
            if past_key_values is None or past_key_values.get_seq_length() == 0:
                past_key_values = ParscaleCache([layer.self_attn.prefix_k for layer in self.layers], [layer.self_attn.prefix_v for layer in self.layers])

        if use_cache and past_key_values is None:
            past_key_values = DynamicCache()

        if cache_position is None:
            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
            cache_position = torch.arange(
                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
            )

        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)

        causal_mask = self._update_causal_mask(
            attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
        )

        hidden_states = inputs_embeds

        # create position embeddings to be shared across the decoder layers
        position_embeddings = self.rotary_emb(hidden_states, position_ids)

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None

        for decoder_layer in self.layers[: self.config.num_hidden_layers]:
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            if self.gradient_checkpointing and self.training:
                layer_outputs = self._gradient_checkpointing_func(
                    decoder_layer.__call__,
                    hidden_states,
                    causal_mask,
                    position_ids,
                    past_key_values,
                    output_attentions,
                    use_cache,
                    cache_position,
                    position_embeddings,
                )
            else:
                layer_outputs = decoder_layer(
                    hidden_states,
                    attention_mask=causal_mask,
                    position_ids=position_ids,
                    past_key_value=past_key_values,
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                    cache_position=cache_position,
                    position_embeddings=position_embeddings,
                    **flash_attn_kwargs,
                )

            hidden_states = layer_outputs[0]

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

        hidden_states = self.norm(hidden_states)

        if self.parscale_n > 1:
            # Output aggregation: a dynamic weighted sum over the parscale_n streams. `aggregate_layer` maps the
            # stacked stream outputs to one logit per stream, and a softmax over streams turns them into weights.
            attn = torch.unsqueeze(torch.softmax(self.aggregate_layer(
                rearrange(hidden_states, "(n_parscale b) s h -> b s (h n_parscale)", n_parscale=self.parscale_n)
            ).float(), dim=-1), dim=-1)  # [b s n_parscale 1]
            # Optionally smooth the weights toward the uniform value 1 / parscale_n
            if self.parscale_aggregate_attn_smoothing != 0.0:
                attn = attn * (1 - self.parscale_aggregate_attn_smoothing) + (self.parscale_aggregate_attn_smoothing / self.parscale_n)
            hidden_states = torch.sum(
                rearrange(hidden_states, "(n_parscale b) s h -> b s n_parscale h", n_parscale=self.parscale_n) * attn,
                dim=2, keepdim=False
            ).to(hidden_states.dtype)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        output = BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=past_key_values if use_cache else None,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )
        return output if return_dict else output.to_tuple()

    def _update_causal_mask(
        self,
        attention_mask: torch.Tensor,
        input_tensor: torch.Tensor,
        cache_position: torch.Tensor,
        past_key_values: Cache,
        output_attentions: bool,
    ):
        if self.config._attn_implementation == "flash_attention_2":
            if attention_mask is not None and (attention_mask == 0.0).any():
                return attention_mask
            return None

        # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
        # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
        # to infer the attention mask.
        past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
        using_static_cache = isinstance(past_key_values, StaticCache)

        # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
        if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions:
            if AttentionMaskConverter._ignore_causal_mask_sdpa(
                attention_mask,
                inputs_embeds=input_tensor,
                past_key_values_length=past_seen_tokens,
                is_training=self.training,
            ):
                return None

        dtype, device = input_tensor.dtype, input_tensor.device
        sequence_length = input_tensor.shape[1]
        if using_static_cache:
            target_length = past_key_values.get_max_cache_shape()
        else:
            target_length = (
                attention_mask.shape[-1]
                if isinstance(attention_mask, torch.Tensor)
                else past_seen_tokens + sequence_length + 1
            )

        causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
            attention_mask,
            sequence_length=sequence_length,
            target_length=target_length,
            dtype=dtype,
            device=device,
            cache_position=cache_position,
            batch_size=input_tensor.shape[0],
        )

        if (
            self.config._attn_implementation == "sdpa"
            and attention_mask is not None
            and attention_mask.device.type == "cuda"
            and not output_attentions
        ):
            # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
            # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
            # Details: https://github.com/pytorch/pytorch/issues/110213
            min_dtype = torch.finfo(dtype).min
            causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)

        return causal_mask

    @staticmethod
    def _prepare_4d_causal_attention_mask_with_cache_position(
        attention_mask: torch.Tensor,
        sequence_length: int,
        target_length: int,
        dtype: torch.dtype,
        device: torch.device,
        cache_position: torch.Tensor,
        batch_size: int,
        **kwargs,
    ):
        """
        Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
        `(batch_size, key_value_length)`, or returns the input `attention_mask` unchanged if it is already 4D.

        Args:
            attention_mask (`torch.Tensor`):
                A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
                `(batch_size, 1, query_length, key_value_length)`.
            sequence_length (`int`):
                The sequence length being processed.
            target_length (`int`):
                The target length: when generating with static cache, the mask should be as long as the static cache,
                to account for the 0 padding, the part of the cache that is not filled yet.
            dtype (`torch.dtype`):
                The dtype to use for the 4D attention mask.
            device (`torch.device`):
                The device to place the 4D attention mask on.
            cache_position (`torch.Tensor`):
                Indices depicting the position of the input sequence tokens in the sequence.
            batch_size (`int`):
                Batch size.
        """
        if attention_mask is not None and attention_mask.dim() == 4:
            # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
            causal_mask = attention_mask
        else:
            min_dtype = torch.finfo(dtype).min
            causal_mask = torch.full(
                (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
            )
            if sequence_length != 1:
                causal_mask = torch.triu(causal_mask, diagonal=1)
            causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
            causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
            if attention_mask is not None:
                causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
                mask_length = attention_mask.shape[-1]
                padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
                padding_mask = padding_mask == 0
                causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
                    padding_mask, min_dtype
                )

        return causal_mask


class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
826 | 827 | 828 | class Qwen2ParScaleForCausalLM(Qwen2PreTrainedModel, GenerationMixin): 829 | _tied_weights_keys = ["lm_head.weight"] 830 | _tp_plan = {"lm_head": "colwise_rep"} 831 | 832 | def __init__(self, config): 833 | super().__init__(config) 834 | self.model = Qwen2Model(config) 835 | self.vocab_size = config.vocab_size 836 | self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) 837 | 838 | # Initialize weights and apply final processing 839 | self.post_init() 840 | 841 | def get_input_embeddings(self): 842 | return self.model.embed_tokens 843 | 844 | def set_input_embeddings(self, value): 845 | self.model.embed_tokens = value 846 | 847 | def get_output_embeddings(self): 848 | return self.lm_head 849 | 850 | def set_output_embeddings(self, new_embeddings): 851 | self.lm_head = new_embeddings 852 | 853 | def set_decoder(self, decoder): 854 | self.model = decoder 855 | 856 | def get_decoder(self): 857 | return self.model 858 | 859 | @add_start_docstrings_to_model_forward(QWEN2_INPUTS_DOCSTRING) 860 | @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) 861 | def forward( 862 | self, 863 | input_ids: torch.LongTensor = None, 864 | attention_mask: Optional[torch.Tensor] = None, 865 | position_ids: Optional[torch.LongTensor] = None, 866 | past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, 867 | inputs_embeds: Optional[torch.FloatTensor] = None, 868 | labels: Optional[torch.LongTensor] = None, 869 | use_cache: Optional[bool] = None, 870 | output_attentions: Optional[bool] = None, 871 | output_hidden_states: Optional[bool] = None, 872 | return_dict: Optional[bool] = None, 873 | cache_position: Optional[torch.LongTensor] = None, 874 | num_logits_to_keep: int = 0, 875 | **kwargs: Unpack[KwargsForCausalLM], 876 | ) -> Union[Tuple, CausalLMOutputWithPast]: 877 | r""" 878 | Args: 879 | labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): 880 | Labels for 
computing the language modeling loss. Indices should either be in `[0, ..., 881 | config.vocab_size - 1]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored 882 | (masked); the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size - 1]`. 883 | 884 | num_logits_to_keep (`int`, *optional*): 885 | Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all 886 | `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that 887 | token can save memory, which becomes significant for long sequences or large vocabulary size. 888 | 889 | Returns: 890 | 891 | Example: 892 | 893 | ```python 894 | >>> from transformers import AutoTokenizer, Qwen2ForCausalLM 895 | 896 | >>> model = Qwen2ForCausalLM.from_pretrained("meta-qwen2/Qwen2-2-7b-hf") 897 | >>> tokenizer = AutoTokenizer.from_pretrained("meta-qwen2/Qwen2-2-7b-hf") 898 | 899 | >>> prompt = "Hey, are you conscious? Can you talk to me?" 900 | >>> inputs = tokenizer(prompt, return_tensors="pt") 901 | 902 | >>> # Generate 903 | >>> generate_ids = model.generate(inputs.input_ids, max_length=30) 904 | >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] 905 | "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
906 | ```""" 907 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions 908 | output_hidden_states = ( 909 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states 910 | ) 911 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 912 | 913 | # decoder outputs: last_hidden_state, then (when requested) past_key_values, hidden_states, attentions 914 | outputs = self.model( 915 | input_ids=input_ids, 916 | attention_mask=attention_mask, 917 | position_ids=position_ids, 918 | past_key_values=past_key_values, 919 | inputs_embeds=inputs_embeds, 920 | use_cache=use_cache, 921 | output_attentions=output_attentions, 922 | output_hidden_states=output_hidden_states, 923 | return_dict=return_dict, 924 | cache_position=cache_position, 925 | **kwargs, 926 | ) 927 | 928 | hidden_states = outputs[0] 929 | # Only compute necessary logits, and do not upcast them to float if we are not computing the loss 930 | logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]) 931 | 932 | loss = None 933 | if labels is not None: 934 | loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs) 935 | 936 | if not return_dict: 937 | output = (logits,) + outputs[1:] 938 | return (loss,) + output if loss is not None else output 939 | 940 | return CausalLMOutputWithPast( 941 | loss=loss, 942 | logits=logits, 943 | past_key_values=outputs.past_key_values, 944 | hidden_states=outputs.hidden_states, 945 | attentions=outputs.attentions, 946 | ) 947 | 948 | 949 | @add_start_docstrings( 950 | """ 951 | The Qwen2 Model transformer with a sequence classification head on top (linear layer). 952 | 953 | [`Qwen2ForSequenceClassification`] uses the last token in order to do the classification, as other causal models 954 | (e.g. GPT-2) do. 955 | 956 | Since it does classification on the last token, it needs to know the position of the last token.
If a 957 | `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If 958 | no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the 959 | padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in 960 | each row of the batch). 961 | """, 962 | QWEN2_START_DOCSTRING, 963 | ) 964 | class Qwen2ForSequenceClassification(Qwen2PreTrainedModel): 965 | def __init__(self, config): 966 | super().__init__(config) 967 | self.num_labels = config.num_labels 968 | self.model = Qwen2Model(config) 969 | self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False) 970 | 971 | # Initialize weights and apply final processing 972 | self.post_init() 973 | 974 | def get_input_embeddings(self): 975 | return self.model.embed_tokens 976 | 977 | def set_input_embeddings(self, value): 978 | self.model.embed_tokens = value 979 | 980 | @add_start_docstrings_to_model_forward(QWEN2_INPUTS_DOCSTRING) 981 | def forward( 982 | self, 983 | input_ids: Optional[torch.LongTensor] = None, 984 | attention_mask: Optional[torch.Tensor] = None, 985 | position_ids: Optional[torch.LongTensor] = None, 986 | past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, 987 | inputs_embeds: Optional[torch.FloatTensor] = None, 988 | labels: Optional[torch.LongTensor] = None, 989 | use_cache: Optional[bool] = None, 990 | output_attentions: Optional[bool] = None, 991 | output_hidden_states: Optional[bool] = None, 992 | return_dict: Optional[bool] = None, 993 | ) -> Union[Tuple, SequenceClassifierOutputWithPast]: 994 | r""" 995 | labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): 996 | Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., 997 | config.num_labels - 1]`. 
If `config.num_labels == 1` a regression loss is computed (Mean-Square loss); if 998 | `config.num_labels > 1` a classification loss is computed (Cross-Entropy). 999 | """ 1000 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 1001 | 1002 | transformer_outputs = self.model( 1003 | input_ids, 1004 | attention_mask=attention_mask, 1005 | position_ids=position_ids, 1006 | past_key_values=past_key_values, 1007 | inputs_embeds=inputs_embeds, 1008 | use_cache=use_cache, 1009 | output_attentions=output_attentions, 1010 | output_hidden_states=output_hidden_states, 1011 | return_dict=return_dict, 1012 | ) 1013 | hidden_states = transformer_outputs[0] 1014 | logits = self.score(hidden_states) 1015 | 1016 | if input_ids is not None: 1017 | batch_size = input_ids.shape[0] 1018 | else: 1019 | batch_size = inputs_embeds.shape[0] 1020 | 1021 | if self.config.pad_token_id is None and batch_size != 1: 1022 | raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.") 1023 | if self.config.pad_token_id is None: 1024 | sequence_lengths = -1 1025 | else: 1026 | if input_ids is not None: 1027 | # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility 1028 | sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1 1029 | sequence_lengths = sequence_lengths % input_ids.shape[-1] 1030 | sequence_lengths = sequence_lengths.to(logits.device) 1031 | else: 1032 | sequence_lengths = -1 1033 | 1034 | pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths] 1035 | 1036 | loss = None 1037 | if labels is not None: 1038 | loss = self.loss_function(logits=logits, labels=labels, pooled_logits=pooled_logits, config=self.config) 1039 | 1040 | if not return_dict: 1041 | output = (pooled_logits,) + transformer_outputs[1:] 1042 | return ((loss,) + output) if loss is not None else output 1043 | 1044 | return SequenceClassifierOutputWithPast( 1045 |
loss=loss, 1046 | logits=pooled_logits, 1047 | past_key_values=transformer_outputs.past_key_values, 1048 | hidden_states=transformer_outputs.hidden_states, 1049 | attentions=transformer_outputs.attentions, 1050 | ) 1051 | 1052 | 1053 | @add_start_docstrings( 1054 | """ 1055 | The Qwen2 Model transformer with a token classification head on top (a linear layer on top of the hidden-states 1056 | output) e.g. for Named-Entity-Recognition (NER) tasks. 1057 | """, 1058 | QWEN2_START_DOCSTRING, 1059 | ) 1060 | class Qwen2ForTokenClassification(Qwen2PreTrainedModel): 1061 | def __init__(self, config): 1062 | super().__init__(config) 1063 | self.num_labels = config.num_labels 1064 | self.model = Qwen2Model(config) 1065 | if getattr(config, "classifier_dropout", None) is not None: 1066 | classifier_dropout = config.classifier_dropout 1067 | elif getattr(config, "hidden_dropout", None) is not None: 1068 | classifier_dropout = config.hidden_dropout 1069 | else: 1070 | classifier_dropout = 0.1 1071 | self.dropout = nn.Dropout(classifier_dropout) 1072 | self.score = nn.Linear(config.hidden_size, config.num_labels) 1073 | 1074 | # Initialize weights and apply final processing 1075 | self.post_init() 1076 | 1077 | def get_input_embeddings(self): 1078 | return self.model.embed_tokens 1079 | 1080 | def set_input_embeddings(self, value): 1081 | self.model.embed_tokens = value 1082 | 1083 | @add_start_docstrings_to_model_forward(QWEN2_INPUTS_DOCSTRING) 1084 | @add_code_sample_docstrings( 1085 | checkpoint=_CHECKPOINT_FOR_DOC, 1086 | output_type=TokenClassifierOutput, 1087 | config_class=_CONFIG_FOR_DOC, 1088 | ) 1089 | def forward( 1090 | self, 1091 | input_ids: Optional[torch.LongTensor] = None, 1092 | attention_mask: Optional[torch.Tensor] = None, 1093 | position_ids: Optional[torch.LongTensor] = None, 1094 | past_key_values: Optional[List[torch.FloatTensor]] = None, 1095 | inputs_embeds: Optional[torch.FloatTensor] = None, 1096 | labels: Optional[torch.LongTensor] = None, 1097 | 
use_cache: Optional[bool] = None, 1098 | output_attentions: Optional[bool] = None, 1099 | output_hidden_states: Optional[bool] = None, 1100 | return_dict: Optional[bool] = None, 1101 | ) -> Union[Tuple, TokenClassifierOutput]: 1102 | r""" 1103 | labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): 1104 | Labels for computing the token classification loss. Indices should be in `[0, ..., 1105 | config.num_labels - 1]`; tokens labelled `-100` are ignored. A classification loss (Cross-Entropy) is computed 1106 | over the remaining tokens. 1107 | """ 1108 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 1109 | 1110 | outputs = self.model( 1111 | input_ids, 1112 | attention_mask=attention_mask, 1113 | position_ids=position_ids, 1114 | past_key_values=past_key_values, 1115 | inputs_embeds=inputs_embeds, 1116 | use_cache=use_cache, 1117 | output_attentions=output_attentions, 1118 | output_hidden_states=output_hidden_states, 1119 | return_dict=return_dict, 1120 | ) 1121 | sequence_output = outputs[0] 1122 | sequence_output = self.dropout(sequence_output) 1123 | logits = self.score(sequence_output) 1124 | 1125 | loss = None 1126 | if labels is not None: 1127 | loss = self.loss_function(logits, labels, self.config) 1128 | 1129 | if not return_dict: 1130 | output = (logits,) + outputs[2:] 1131 | return ((loss,) + output) if loss is not None else output 1132 | 1133 | return TokenClassifierOutput( 1134 | loss=loss, 1135 | logits=logits, 1136 | hidden_states=outputs.hidden_states, 1137 | attentions=outputs.attentions, 1138 | ) 1139 | 1140 | 1141 | @add_start_docstrings( 1142 | """ 1143 | The Qwen2 Model transformer with a span classification head on top for extractive question-answering tasks like 1144 | SQuAD (a linear layer on top of the hidden-states output to compute `span start logits` and `span end logits`).
1145 | """, 1146 | QWEN2_START_DOCSTRING, 1147 | ) 1148 | class Qwen2ForQuestionAnswering(Qwen2PreTrainedModel): 1149 | base_model_prefix = "transformer" 1150 | 1151 | def __init__(self, config): 1152 | super().__init__(config) 1153 | self.transformer = Qwen2Model(config) 1154 | self.qa_outputs = nn.Linear(config.hidden_size, 2) 1155 | 1156 | # Initialize weights and apply final processing 1157 | self.post_init() 1158 | 1159 | def get_input_embeddings(self): 1160 | return self.transformer.embed_tokens 1161 | 1162 | def set_input_embeddings(self, value): 1163 | self.transformer.embed_tokens = value 1164 | 1165 | @add_start_docstrings_to_model_forward(QWEN2_INPUTS_DOCSTRING) 1166 | def forward( 1167 | self, 1168 | input_ids: Optional[torch.LongTensor] = None, 1169 | attention_mask: Optional[torch.FloatTensor] = None, 1170 | position_ids: Optional[torch.LongTensor] = None, 1171 | past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, 1172 | inputs_embeds: Optional[torch.FloatTensor] = None, 1173 | start_positions: Optional[torch.LongTensor] = None, 1174 | end_positions: Optional[torch.LongTensor] = None, 1175 | output_attentions: Optional[bool] = None, 1176 | output_hidden_states: Optional[bool] = None, 1177 | return_dict: Optional[bool] = None, 1178 | **kwargs, 1179 | ) -> Union[Tuple, QuestionAnsweringModelOutput]: 1180 | r""" 1181 | start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): 1182 | Labels for position (index) of the start of the labelled span for computing the token classification loss. 1183 | Positions are clamped to the length of the sequence (`sequence_length`). Positions outside of the sequence 1184 | are not taken into account for computing the loss. 1185 | end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): 1186 | Labels for position (index) of the end of the labelled span for computing the token classification loss.
1187 | Positions are clamped to the length of the sequence (`sequence_length`). Positions outside of the sequence 1188 | are not taken into account for computing the loss. 1189 | """ 1190 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 1191 | 1192 | outputs = self.transformer( 1193 | input_ids, 1194 | attention_mask=attention_mask, 1195 | position_ids=position_ids, 1196 | past_key_values=past_key_values, 1197 | inputs_embeds=inputs_embeds, 1198 | output_attentions=output_attentions, 1199 | output_hidden_states=output_hidden_states, 1200 | return_dict=return_dict, 1201 | ) 1202 | 1203 | sequence_output = outputs[0] 1204 | 1205 | logits = self.qa_outputs(sequence_output) 1206 | start_logits, end_logits = logits.split(1, dim=-1) 1207 | start_logits = start_logits.squeeze(-1).contiguous() 1208 | end_logits = end_logits.squeeze(-1).contiguous() 1209 | 1210 | loss = None 1211 | if start_positions is not None and end_positions is not None: 1212 | loss = self.loss_function(start_logits, end_logits, start_positions, end_positions, **kwargs) 1213 | 1214 | if not return_dict: 1215 | output = (start_logits, end_logits) + outputs[2:] 1216 | return ((loss,) + output) if loss is not None else output 1217 | 1218 | return QuestionAnsweringModelOutput( 1219 | loss=loss, 1220 | start_logits=start_logits, 1221 | end_logits=end_logits, 1222 | hidden_states=outputs.hidden_states, 1223 | attentions=outputs.attentions, 1224 | ) 1225 | -------------------------------------------------------------------------------- /parametric_fit.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from scipy.optimize import minimize 3 | from sklearn.linear_model import LinearRegression 4 | import matplotlib.pyplot as plt 5 | from sklearn.metrics import r2_score 6 | import json 7 | import os 8 | import pandas as pd 9 | 10 | def parametric_fit(param_list, p_list, loss_list): 11 | param_list =
np.asarray(param_list).reshape((-1, )) 12 | loss_list = np.asarray(loss_list).reshape((-1, )) 13 | p_list = np.asarray(p_list).reshape((-1, )) 14 | 15 | def huber_loss(y_true, y_pred, delta=0.001): 16 | error = y_true - y_pred 17 | is_small_error = np.abs(error) <= delta 18 | squared_loss = np.square(error) / 2 19 | linear_loss = delta * (np.abs(error) - delta / 2) 20 | return np.where(is_small_error, squared_loss, linear_loss).sum() 21 | 22 | def pred_loss(params): 23 | E, A, alpha, k = params 24 | return E + (A * 1e9 / (param_list * (np.log(p_list) * k + 1))) ** alpha 25 | 26 | def objective_function(params): 27 | pred = pred_loss(params) 28 | return huber_loss(np.log(loss_list), np.log(pred)) 29 | 30 | best_param = None 31 | best_func = 1000000 32 | for E in [-1, -0.5, 0]: 33 | for log_A in [-4, -2, 0, 2, 4]: 34 | for alpha in [0, 0.5, 1, 1.5, 2]: 35 | for k in [0.2, 0.4, 0.6, 0.8]: 36 | initial_params = [np.exp(E), np.exp(log_A), alpha, k] 37 | bounds = [(1e-8, None), (1e-8, None), (1e-8, None), (1e-8, None)] 38 | result = minimize(objective_function, initial_params, method='L-BFGS-B', bounds=bounds) 39 | if result.fun < best_func: 40 | best_param = result.x 41 | best_func = result.fun 42 | print(f"{result = }") 43 | print(f"{best_param = }") 44 | print(f"{best_func = }") 45 | 46 | pred_key = "$\\mathcal L_{\\text{pred}}$" 47 | true_key = "$\\mathcal L_{\\text{true}}$" 48 | df = pd.DataFrame({ 49 | "$P$": p_list, 50 | "Parameters (Non-Embedding)": param_list, 51 | pred_key: pred_loss(best_param), 52 | true_key: loss_list, 53 | "Error": pred_loss(best_param) - loss_list 54 | }) 55 | df['Parameters (Non-Embedding)'] = df['Parameters (Non-Embedding)'].apply(lambda x: f"{x:,}") 56 | r2 = r2_score(df[true_key].to_numpy().reshape(-1, 1), df[pred_key].to_numpy().reshape(-1, 1)) 57 | 58 | print(df.to_latex(float_format=lambda x: f"{x:.4f}", index=False, column_format='rrrrr')) 59 | print(f"{r2 = }") 60 | 61 | 62 | if __name__ == "__main__": 63 | 64 | params = [ 65 | 
[535813376, 693753856, 1088376320, 1571472384, 2774773760, 4353203200], 66 | [538195842, 696738818, 1092762882, 1577522690, 2784937986, 4368529922], 67 | [540577412, 699722756, 1097148164, 1583571460, 2795100164, 4383854084], 68 | [545340552, 705690632, 1105918728, 1595669000, 2815424520, 4414502408], 69 | ] 70 | 71 | stack_loss = [ 72 | [1.1722, 1.1496, 1.1131, 1.0817, 1.0451, 1.0213], # 1.0006], # P1 73 | [1.1507, 1.1262, 1.094, 1.0623, 1.0244, 1.0025], # P2 74 | [1.1354, 1.1124, 1.0808, 1.049, 1.0126, 0.9906], # P4 75 | [1.1231, 1.0997, 1.0688, 1.0383, 1.0016, 0.9794], # P8 76 | ] 77 | 78 | pile_loss = [ 79 | [2.1113, 2.0671, 2.0027, 1.9539, 1.8876, 1.8451], # P1 80 | [2.0772, 2.0363, 1.973, 1.9266, 1.861, 1.8137], # P2 81 | [2.0544, 2.0128, 1.9509, 1.904, 1.8394, 1.7938], # P4 82 | [2.0364, 1.9933, 1.9318, 1.8856, 1.8218, 1.7772], # P8 83 | ] 84 | 85 | p = [ 86 | [1] * 6, 87 | [2] * 6, 88 | [4] * 6, 89 | [8] * 6, 90 | ] 91 | 92 | print("=" * 10 + " Stack-V2 Python " + "=" * 10) 93 | parametric_fit(params, p, stack_loss) 94 | print("=" * 10 + " Pile " + "=" * 10) 95 | parametric_fit(params, p, pile_loss) --------------------------------------------------------------------------------
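`pred_loss` in `parametric_fit.py` fits the form $\mathcal L(N, P) = E + \left(\frac{A \cdot 10^9}{N\,(k \ln P + 1)}\right)^{\alpha}$, where $N$ is the non-embedding parameter count and $P$ the number of parallel streams. The minimal sketch below (NumPy only; the helper names `parscale_loss` and `equivalent_params` and every constant are hypothetical illustrations, not the values fitted on Stack-V2 or Pile) evaluates that form and makes its logarithmic reading explicit: under it, $P$ streams at $N$ parameters predict the same loss as a single stream with $N(k \ln P + 1)$ parameters, so each doubling of $P$ adds the same $k \ln 2$ to the effective-parameter multiplier.

```python
import numpy as np


def parscale_loss(n_params, p, E, A, alpha, k):
    """Loss predicted by the same functional form as pred_loss in parametric_fit.py.

    n_params: non-embedding parameter count N; p: number of parallel streams P.
    A is expressed in units of 1e9 parameters, matching the original code.
    """
    n_params = np.asarray(n_params, dtype=float)
    p = np.asarray(p, dtype=float)
    return E + (A * 1e9 / (n_params * (k * np.log(p) + 1.0))) ** alpha


def equivalent_params(n_params, p, k):
    """Hypothetical helper: single-stream parameter count with the same predicted loss."""
    p = np.asarray(p, dtype=float)
    return np.asarray(n_params, dtype=float) * (k * np.log(p) + 1.0)


# Illustrative constants only (NOT fitted values from this repository).
E, A, alpha, k = 0.5, 0.3, 0.2, 0.4
N = 1.5e9

for P in [1, 2, 4, 8]:
    n_eq = equivalent_params(N, P, k)
    loss = parscale_loss(N, P, E, A, alpha, k)
    print(f"P={P}: behaves like {float(n_eq):.3e} params, predicted loss {float(loss):.4f}")
```

Because the multiplier grows with $\ln P$, going from $P=1$ to $P=8$ with this illustrative $k=0.4$ multiplies the effective parameter count by $1 + 0.4 \ln 8 \approx 1.83$.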
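The additive attention mask built in `modeling_qwen2_parscale.py` (lines 805 to 822 of that file) combines a causal part with a padding part: `torch.triu(..., diagonal=1)` blocks keys after each query, the `torch.arange(target_length) > cache_position` product re-allows every key up to the query's absolute position, and padding keys are then blocked only where the causal part would otherwise allow them. The sketch below is a NumPy transcription of those steps, written in NumPy (an assumption, so it runs without PyTorch); `build_4d_mask` is a hypothetical name, not a function in the repository.

```python
import numpy as np


def build_4d_mask(attention_mask, cache_position, target_length, dtype=np.float32):
    """NumPy sketch of the causal + padding mask logic in modeling_qwen2_parscale.py.

    attention_mask: (batch, mask_length) array, 1 for real tokens and 0 for padding.
    cache_position: (sequence_length,) absolute positions of the current queries.
    Returns an additive mask of shape (batch, 1, sequence_length, target_length):
    0 where attention is allowed, the dtype's minimum where it is blocked.
    """
    min_value = np.finfo(dtype).min
    batch_size, mask_length = attention_mask.shape
    sequence_length = cache_position.shape[0]

    mask = np.full((sequence_length, target_length), min_value, dtype=dtype)
    if sequence_length != 1:
        mask = np.triu(mask, k=1)  # keep the block only strictly above the diagonal
    # re-allow every key whose index is <= the query's absolute cache position
    mask = mask * (np.arange(target_length) > cache_position.reshape(-1, 1))
    mask = np.broadcast_to(mask[None, None], (batch_size, 1, sequence_length, target_length)).copy()

    # a key is padding-blocked where the causal part allows it (0) and attention_mask is 0
    padding_mask = (mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]) == 0
    mask[:, :, :, :mask_length] = np.where(padding_mask, min_value, mask[:, :, :, :mask_length])
    return mask


# One unpadded row and one left-padded row, prefill of 3 tokens with no cache offset.
am = np.array([[1, 1, 1], [0, 1, 1]])
m = build_4d_mask(am, np.arange(3), target_length=3)
```

In the padded row, column 0 comes out blocked for every query, while the remaining pattern stays lower-triangular; for a single decoding step (`sequence_length == 1`) the `triu` call is skipped and only the position comparison applies.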
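`Qwen2ForSequenceClassification.forward` pools the logits at the last non-padding token using `argmax` over a pad-equality mask, minus one, modulo the sequence length. A small NumPy sketch of that index rule (NumPy assumed in place of PyTorch; `last_non_pad_index` is a hypothetical name) shows why the modulo is there: when a row contains no pad token, `argmax` returns 0, the index becomes -1, and the modulo maps it to the last column without negative indexing (the comment in the original cites ONNX compatibility).

```python
import numpy as np


def last_non_pad_index(input_ids, pad_token_id):
    """Sketch of the pooling-index rule in Qwen2ForSequenceClassification.forward.

    Under right padding, the first pad position minus one is the last real token.
    With no pad token in a row, argmax returns 0, so (0 - 1) % seq_len = seq_len - 1.
    """
    first_pad = (input_ids == pad_token_id).astype(int).argmax(-1)
    return (first_pad - 1) % input_ids.shape[-1]


ids = np.array([
    [11, 12, 13, 0, 0],    # right-padded: last real token at index 2
    [21, 22, 23, 24, 25],  # no padding: last real token at index 4
])
idx = last_non_pad_index(ids, pad_token_id=0)
logits = np.random.default_rng(0).normal(size=(2, 5, 3))  # (batch, seq, num_labels)
pooled = logits[np.arange(ids.shape[0]), idx]  # (batch, num_labels), as in the original
```

Note that the rule looks for the *first* occurrence of `pad_token_id`; if that id also appears as a real token inside a sequence, the pooled position is the token just before it. Fully left-padded rows happen to map to the last column as well, because their first pad sits at index 0.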