├── README.md
├── configuration_qwen2_parscale.py
├── cost_analysis.py
├── figures
├── 1t.png
├── cost.png
├── logo.jpg
├── scaling_comparison.png
├── scaling_law.png
├── scaling_law2.png
└── teaser.png
├── modeling_qwen2_parscale.py
└── parametric_fit.py
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 | # Parallel Scaling Law for Language Models
5 |
6 |
7 | _Yet Another Scaling Law beyond Parameters and Inference Time Scaling_
8 |
9 | [📄 Paper](https://arxiv.org/abs/2505.10475)
10 | [🤗 Models](https://huggingface.co/ParScale)
11 |
12 |
13 |

14 |
15 |
16 |
17 |
18 | 💡 Key Findings
19 | | 📈 Scaling Law
20 | | ⚡ Cost Analysis
21 | | 🔥 Models
22 | | 📚 Citation
23 |
24 |
25 |
26 | ## 🌟 About
27 |
28 | - Most believe that scaling language models requires a heavy cost in either **space** (parameter scaling) or **time** (inference-time scaling).
29 | - We introduce a *third* paradigm for scaling LLMs: Parallel Scaling (*ParScale*), which leverages **parallel computation** during both training and inference.
30 | - We apply $P$ diverse and learnable transformations to the input, execute forward passes of the model in parallel, and dynamically aggregate the $P$ outputs (see the sketch below).
31 |
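For intuition, here is a minimal, self-contained sketch of the idea (illustrative only; in the released `modeling_qwen2_parscale.py`, the input transformations are realized as learned KV-cache prefixes and the aggregation is the `aggregate_layer` MLP — names below such as `ParScaleSketch` and `stream_bias` are hypothetical):

```python
import torch
import torch.nn as nn

class ParScaleSketch(nn.Module):
    """Conceptual sketch of ParScale: P parallel streams over one shared backbone."""
    def __init__(self, backbone: nn.Module, hidden: int, P: int):
        super().__init__()
        self.backbone, self.P = backbone, P
        # learnable per-stream input transformations (prefix tokens in the real model)
        self.stream_bias = nn.Parameter(torch.zeros(P, 1, hidden))
        # dynamic aggregation: predicts one weight per stream, per position
        self.aggregate = nn.Sequential(
            nn.Linear(P * hidden, hidden), nn.SiLU(), nn.Linear(hidden, P)
        )

    def forward(self, x):                        # x: [batch, seq, hidden]
        streams = [self.backbone(x + self.stream_bias[i]) for i in range(self.P)]
        h = torch.stack(streams, dim=2)          # [batch, seq, P, hidden]
        w = torch.softmax(self.aggregate(h.flatten(2)), dim=-1)   # [batch, seq, P]
        return (h * w.unsqueeze(-1)).sum(dim=2)  # dynamic weighted sum over streams
```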
32 |

33 |
34 |
35 | ---
36 |
37 | ## 💡 Key Findings
38 |
39 |

40 |
41 |
42 | Here are the core insights and benefits distilled from our theoretical analysis and empirical evaluations:
43 |
44 | 📈 **Logarithmic Scaling Law**: We theoretically and empirically establish that **scaling with $P$ parallel streams is comparable to scaling the number of parameters by** $O(\log P)$. This suggests that parallel computation can serve as an efficient substitute for parameter growth, especially for larger models.
45 |
46 | ✅ **Universal Applicability**: Unlike inference-time scaling, which requires specialized training data and has limited applicability, ParScale works with any model architecture, optimization method, data, or downstream task.
47 |
48 |
49 | 🧠 **Stronger Performance on Reasoning Tasks**: Reasoning-intensive tasks (e.g., coding or math) benefit more from ParScale, which suggests that scaling computation can effectively push the boundary of reasoning.
50 |
51 | ⚡ **Superior Inference Efficiency**: Compared to parameter scaling that achieves the same performance improvement, ParScale incurs up to **22× smaller memory increase** and **6× smaller latency increase** (batch size = 1).
52 |
53 | 🧱 **Cost-Efficient Training via Two-Stage Strategy**: Training a parallel-scaled model doesn't require starting from scratch. With a two-stage training strategy, we can post-train the parallel components using only a small amount of data.
54 |
55 | 🔁 **Dynamic Adaptation at Inference Time**: We find that ParScale remains effective with frozen main parameters for different $P$. This illustrates the potential of dynamic parallel scaling: switching $P$ to dynamically adapt model capabilities during inference.
56 |
57 | We release the inference code in `modeling_qwen2_parscale.py` and `configuration_qwen2_parscale.py`. Our 67 checkpoints are available at [🤗 HuggingFace](https://huggingface.co/ParScale).
58 |
59 | ---
60 |
61 | ## 📈 Scaling Law
62 |
63 | - We carry out large-scale pre-training experiments on the Stack-V2 and Pile corpora, varying $P$ from 1 to 8 and the model parameters from 500M to 4.4B.
64 | - We use the results to fit a new *parallel scaling law* that generalizes the Chinchilla scaling law.
65 | - We release our parametric fitting code in `parametric_fit.py` (a simplified illustrative sketch is shown below).
66 | - Feel free to try the [🤗 HuggingFace Space](https://huggingface.co/spaces/ParScale/Parallel_Scaling_Law) for a nice visualization of the parallel scaling law!
67 |
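The exact functional form and fitted coefficients live in the paper and in `parametric_fit.py`. As a rough illustration of the idea only, the sketch below fits a hypothetical Chinchilla-style form in which $P$ streams act like multiplying the parameter count $N$ by $(1 + k\log P)$, on synthetic data (the form, the helper `parscale_loss`, and all numbers are assumptions, not results from the paper):

```python
import numpy as np
from scipy.optimize import curve_fit

def parscale_loss(X, E, A, alpha, k):
    """Hypothetical illustrative form: loss ≈ E + A / (N * (1 + k * log P))**alpha."""
    N, P = X
    return E + A / (N * (1.0 + k * np.log(P))) ** alpha

# Synthetic (N, P) grid and losses generated from the form above -- NOT measured results.
N = np.repeat([5e8, 1.8e9, 4.4e9], 4)
P = np.tile([1, 2, 4, 8], 3)
rng = np.random.default_rng(0)
loss = parscale_loss((N, P), E=1.4, A=800.0, alpha=0.35, k=0.4) + rng.normal(0, 0.003, N.shape)

params, _ = curve_fit(parscale_loss, (N, P), loss, p0=[1.5, 1e3, 0.3, 0.3], maxfev=20000)
print(dict(zip(["E", "A", "alpha", "k"], params)))
```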
68 |

69 |

70 |
71 |
72 | ---
73 |
74 | ## ⚡ Cost Analysis
75 |
76 |
77 |

78 |
79 |
80 | - We further compare the inference efficiency between parallel scaling and parameter scaling at equivalent performance levels.
81 | - We release our analysis code in `cost_analysis.py`. Before using it, you should first install [llm-analysis](https://github.com/cli99/llm-analysis):
82 |
83 | ```bash
84 | git clone https://github.com/cli99/llm-analysis.git
85 | cd llm-analysis
86 | pip install .
87 | ```
88 |
89 | - You can use the following command to analyze the inference memory and latency cost for our 4.4B model, with $P=2$ and batch size=2:
90 | ```bash
91 | python cost_analysis.py --hidden_size 2560 --intermediate_size 13824 --P 2 --batch_size 2
92 | ```
93 |
94 | ---
95 |
96 | ## 🔥 Models
97 |
98 | ✨ marks our recommended strong models!
99 |
100 | ### Base models for scaling training data to 1T tokens
101 |
102 | These models are strongly competitive with existing small models, including SmolLM, Gemma, and Llama-3.2.
103 |
104 | |Model|Description|Download|
105 | |:-:|:-:|:-:|
106 | |ParScale-1.8B-P1|✨ Baseline $P=1$|[🤗 ParScale/ParScale-1.8B-P1](https://huggingface.co/ParScale/ParScale-1.8B-P1)|
107 | |ParScale-1.8B-P2|✨ ParScale $P=2$|[🤗 ParScale/ParScale-1.8B-P2](https://huggingface.co/ParScale/ParScale-1.8B-P2)|
108 | |ParScale-1.8B-P4|✨ ParScale $P=4$|[🤗 ParScale/ParScale-1.8B-P4](https://huggingface.co/ParScale/ParScale-1.8B-P4)|
109 | |ParScale-1.8B-P8|✨ ParScale $P=8$|[🤗 ParScale/ParScale-1.8B-P8](https://huggingface.co/ParScale/ParScale-1.8B-P8)|
110 |
111 | ### Instruct models for scaling training data to 1T tokens
112 |
113 | We post-trained the aforementioned base models on SmolTalk-1M to enable conversational capabilities.
114 |
115 | |Model|Description|Download|
116 | |:-:|:-:|:-:|
117 | |ParScale-1.8B-P1-Inst|✨ Baseline $P=1$|[🤗 ParScale/ParScale-1.8B-P1-Inst](https://huggingface.co/ParScale/ParScale-1.8B-P1-Inst)|
118 | |ParScale-1.8B-P2-Inst|✨ ParScale $P=2$|[🤗 ParScale/ParScale-1.8B-P2-Inst](https://huggingface.co/ParScale/ParScale-1.8B-P2-Inst)|
119 | |ParScale-1.8B-P4-Inst|✨ ParScale $P=4$|[🤗 ParScale/ParScale-1.8B-P4-Inst](https://huggingface.co/ParScale/ParScale-1.8B-P4-Inst)|
120 | |ParScale-1.8B-P8-Inst|✨ ParScale $P=8$|[🤗 ParScale/ParScale-1.8B-P8-Inst](https://huggingface.co/ParScale/ParScale-1.8B-P8-Inst)|
121 |
122 |
123 | ### Continual Pretraining of Qwen-2.5-3B
124 |
125 | We froze the parameters of Qwen-2.5-3B and only fine-tuned the newly introduced parameters on Stack-V2-Python. Since the following models share the same backbone parameters as Qwen-2.5-3B, they have the potential for dynamic ParScale: switching $P$ to dynamically adapt model capabilities during inference.
126 |
127 | |Model|Description|Download|
128 | |:-:|:-:|:-:|
129 | |ParScale-Qwen-3B-P2-Python|✨ ParScale $P=2$|[🤗 ParScale/ParScale-Qwen-3B-P2-Python](https://huggingface.co/ParScale/ParScale-Qwen-3B-P2-Python)|
130 | |ParScale-Qwen-3B-P4-Python|✨ ParScale $P=4$|[🤗 ParScale/ParScale-Qwen-3B-P4-Python](https://huggingface.co/ParScale/ParScale-Qwen-3B-P4-Python)|
131 | |ParScale-Qwen-3B-P8-Python|✨ ParScale $P=8$|[🤗 ParScale/ParScale-Qwen-3B-P8-Python](https://huggingface.co/ParScale/ParScale-Qwen-3B-P8-Python)|
132 |
133 | - For full continual pretraining on Stack-V2-Python:
134 |
135 | |Model|Description|Download|
136 | |:-:|:-:|:-:|
137 | |ParScale-QwenInit-3B-P1-Python|Baseline $P=1$|[🤗 ParScale/ParScale-QwenInit-3B-P1-Python](https://huggingface.co/ParScale/ParScale-QwenInit-3B-P1-Python)|
138 | |ParScale-QwenInit-3B-P2-Python|ParScale $P=2$|[🤗 ParScale/ParScale-QwenInit-3B-P2-Python](https://huggingface.co/ParScale/ParScale-QwenInit-3B-P2-Python)|
139 | |ParScale-QwenInit-3B-P4-Python|ParScale $P=4$|[🤗 ParScale/ParScale-QwenInit-3B-P4-Python](https://huggingface.co/ParScale/ParScale-QwenInit-3B-P4-Python)|
140 | |ParScale-QwenInit-3B-P8-Python|ParScale $P=8$|[🤗 ParScale/ParScale-QwenInit-3B-P8-Python](https://huggingface.co/ParScale/ParScale-QwenInit-3B-P8-Python)|
141 |
142 | - For full continual pretraining on Pile:
143 |
144 | |Model|Description|Download|
145 | |:-:|:-:|:-:|
146 | |ParScale-QwenInit-3B-P1-Pile|Baseline $P=1$|[🤗 ParScale/ParScale-QwenInit-3B-P1-Pile](https://huggingface.co/ParScale/ParScale-QwenInit-3B-P1-Pile)|
147 | |ParScale-QwenInit-3B-P2-Pile|ParScale $P=2$|[🤗 ParScale/ParScale-QwenInit-3B-P2-Pile](https://huggingface.co/ParScale/ParScale-QwenInit-3B-P2-Pile)|
148 | |ParScale-QwenInit-3B-P4-Pile|ParScale $P=4$|[🤗 ParScale/ParScale-QwenInit-3B-P4-Pile](https://huggingface.co/ParScale/ParScale-QwenInit-3B-P4-Pile)|
149 | |ParScale-QwenInit-3B-P8-Pile|ParScale $P=8$|[🤗 ParScale/ParScale-QwenInit-3B-P8-Pile](https://huggingface.co/ParScale/ParScale-QwenInit-3B-P8-Pile)|
150 |
151 |
152 | ### Checkpoints Used to Fit the Scaling Law
153 |
154 | Download link: https://huggingface.co/ParScale/ParScale-{size}-{P}-{dataset}
155 |
156 | - {size}: model size, from {0.7B, 0.9B, 1.3B, 1.8B, 3B, 4.7B}
157 | - {P}: number of parallels, from {P1, P2, P4, P8}
158 | - {dataset}: training dataset, from {Python, Pile}
159 | - $6\times 4 \times 2=48$ checkpoints in total (see the snippet below for enumerating them).
160 |
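A quick way to build all 48 repo ids, assuming the naming pattern above:

```python
sizes = ["0.7B", "0.9B", "1.3B", "1.8B", "3B", "4.7B"]
parallels = ["P1", "P2", "P4", "P8"]
datasets = ["Python", "Pile"]

repos = [f"ParScale/ParScale-{s}-{p}-{d}" for s in sizes for p in parallels for d in datasets]
print(len(repos))  # 48
```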
161 | ### Usage Example with 🤗 Hugging Face
162 |
163 | ```python
164 | from transformers import AutoModelForCausalLM, AutoTokenizer
165 | name = "ParScale/ParScale-1.8B-P8" # or anything else you like
166 | model = AutoModelForCausalLM.from_pretrained(name, trust_remote_code=True).to("cuda")
167 | tokenizer = AutoTokenizer.from_pretrained(name)
168 | inputs = tokenizer.encode("Hello, how are you today?", return_tensors="pt").to("cuda")
169 | outputs = model.generate(inputs, max_new_tokens=128)[0]
170 | print(tokenizer.decode(outputs))
171 | ```
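
For the `-Inst` models, a minimal chat sketch (this assumes the tokenizer ships a chat template; fall back to plain `encode` otherwise):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

name = "ParScale/ParScale-1.8B-P8-Inst"
model = AutoModelForCausalLM.from_pretrained(name, trust_remote_code=True).to("cuda")
tokenizer = AutoTokenizer.from_pretrained(name)

messages = [{"role": "user", "content": "Write a haiku about parallel streams."}]
inputs = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt").to("cuda")
outputs = model.generate(inputs, max_new_tokens=128)[0]
print(tokenizer.decode(outputs[inputs.shape[-1]:], skip_special_tokens=True))
```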
172 |
173 |
174 | ## 📚 Citation
175 |
176 | ```bibtex
177 | @article{ParScale,
178 | title={Parallel Scaling Law for Language Models},
179 | author={Mouxiang Chen and Binyuan Hui and Zeyu Cui and Jiaxi Yang and Dayiheng Liu and Jianling Sun and Junyang Lin and Zhongxin Liu},
180 | year={2025},
181 | eprint={2505.10475},
182 | archivePrefix={arXiv},
183 | primaryClass={cs.LG},
184 | journal={arXiv preprint arXiv:2505.10475},
185 | url={https://arxiv.org/abs/2505.10475},
186 | }
187 | ```
188 |
--------------------------------------------------------------------------------
/configuration_qwen2_parscale.py:
--------------------------------------------------------------------------------
1 | """Qwen2 model configuration, with support for ParScale"""
2 |
3 | from transformers.configuration_utils import PretrainedConfig
4 | from transformers.modeling_rope_utils import rope_config_validation
5 | from transformers.utils import logging
6 |
7 |
8 | logger = logging.get_logger(__name__)
9 |
10 |
11 | class Qwen2ParScaleConfig(PretrainedConfig):
12 | r"""
13 | This is the configuration class to store the configuration of a [`Qwen2Model`]. It is used to instantiate a
14 | Qwen2 model according to the specified arguments, defining the model architecture. Instantiating a configuration
15 | with the defaults will yield a similar configuration to that of
16 | Qwen2-7B-beta [Qwen/Qwen2-7B-beta](https://huggingface.co/Qwen/Qwen2-7B-beta).
17 |
18 | Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
19 | documentation from [`PretrainedConfig`] for more information.
20 |
21 |
22 | Args:
23 | vocab_size (`int`, *optional*, defaults to 151936):
24 | Vocabulary size of the Qwen2 model. Defines the number of different tokens that can be represented by the
25 | `inputs_ids` passed when calling [`Qwen2Model`]
26 | hidden_size (`int`, *optional*, defaults to 4096):
27 | Dimension of the hidden representations.
28 | intermediate_size (`int`, *optional*, defaults to 22016):
29 | Dimension of the MLP representations.
30 | num_hidden_layers (`int`, *optional*, defaults to 32):
31 | Number of hidden layers in the Transformer encoder.
32 | num_attention_heads (`int`, *optional*, defaults to 32):
33 | Number of attention heads for each attention layer in the Transformer encoder.
34 | num_key_value_heads (`int`, *optional*, defaults to 32):
35 | This is the number of key_value heads that should be used to implement Grouped Query Attention. If
36 | `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
37 | `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
38 | converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
39 | by meanpooling all the original heads within that group. For more details checkout [this
40 | paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `32`.
41 | hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
42 | The non-linear activation function (function or string) in the decoder.
43 | max_position_embeddings (`int`, *optional*, defaults to 32768):
44 | The maximum sequence length that this model might ever be used with.
45 | initializer_range (`float`, *optional*, defaults to 0.02):
46 | The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
47 | rms_norm_eps (`float`, *optional*, defaults to 1e-06):
48 | The epsilon used by the rms normalization layers.
49 | use_cache (`bool`, *optional*, defaults to `True`):
50 | Whether or not the model should return the last key/values attentions (not used by all models). Only
51 | relevant if `config.is_decoder=True`.
52 | tie_word_embeddings (`bool`, *optional*, defaults to `False`):
53 | Whether the model's input and output word embeddings should be tied.
54 | rope_theta (`float`, *optional*, defaults to 10000.0):
55 | The base period of the RoPE embeddings.
56 | rope_scaling (`Dict`, *optional*):
57 | Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type
58 | and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value
59 | accordingly.
60 | Expected contents:
61 | `rope_type` (`str`):
62 | The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
63 | 'llama3'], with 'default' being the original RoPE implementation.
64 | `factor` (`float`, *optional*):
65 | Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
66 | most scaling types, a `factor` of x will enable the model to handle sequences of length x *
67 | original maximum pre-trained length.
68 | `original_max_position_embeddings` (`int`, *optional*):
69 | Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
70 | pretraining.
71 | `attention_factor` (`float`, *optional*):
72 | Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
73 | computation. If unspecified, it defaults to value recommended by the implementation, using the
74 | `factor` field to infer the suggested value.
75 | `beta_fast` (`float`, *optional*):
76 | Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
77 | ramp function. If unspecified, it defaults to 32.
78 | `beta_slow` (`float`, *optional*):
79 | Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
80 | ramp function. If unspecified, it defaults to 1.
81 | `short_factor` (`List[float]`, *optional*):
82 | Only used with 'longrope'. The scaling factor to be applied to short contexts (<
83 | `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
84 | size divided by the number of attention heads divided by 2
85 | `long_factor` (`List[float]`, *optional*):
86 | Only used with 'longrope'. The scaling factor to be applied to long contexts (<
87 | `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
88 | size divided by the number of attention heads divided by 2
89 | `low_freq_factor` (`float`, *optional*):
90 | Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE
91 | `high_freq_factor` (`float`, *optional*):
92 | Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
93 | use_sliding_window (`bool`, *optional*, defaults to `False`):
94 | Whether to use sliding window attention.
95 | sliding_window (`int`, *optional*, defaults to 4096):
96 | Sliding window attention (SWA) window size. If not specified, will default to `4096`.
97 | max_window_layers (`int`, *optional*, defaults to 28):
98 | The number of layers that use SWA (Sliding Window Attention). The bottom layers use SWA while the top use full attention.
99 | attention_dropout (`float`, *optional*, defaults to 0.0):
100 | The dropout ratio for the attention probabilities.
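parscale_n (`int`, *optional*, defaults to 1):
Number of parallel streams P. With `parscale_n=1` the model behaves as a standard Qwen2 model.
parscale_n_tokens (`int`, *optional*, defaults to 48):
Number of learnable prefix tokens per stream, stored as key/value prefixes in each attention layer.
parscale_attn_smooth (`float`, *optional*, defaults to 0.01):
Smoothing factor applied to the output-aggregation weights, mixing them toward a uniform distribution over the streams.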
101 |
102 | ```python
103 | >>> from transformers import Qwen2Model, Qwen2Config
104 |
105 | >>> # Initializing a Qwen2 style configuration
106 | >>> configuration = Qwen2Config()
107 |
108 | >>> # Initializing a model from the Qwen2-7B style configuration
109 | >>> model = Qwen2Model(configuration)
110 |
111 | >>> # Accessing the model configuration
112 | >>> configuration = model.config
113 | ```"""
114 |
115 | model_type = "qwen2_parscale"
116 | keys_to_ignore_at_inference = ["past_key_values"]
117 |
118 | # Default tensor parallel plan for base model `Qwen2`
119 | base_model_tp_plan = {
120 | "layers.*.self_attn.q_proj": "colwise",
121 | "layers.*.self_attn.k_proj": "colwise",
122 | "layers.*.self_attn.v_proj": "colwise",
123 | "layers.*.self_attn.o_proj": "rowwise",
124 | "layers.*.mlp.gate_proj": "colwise",
125 | "layers.*.mlp.up_proj": "colwise",
126 | "layers.*.mlp.down_proj": "rowwise",
127 | }
128 |
129 | def __init__(
130 | self,
131 | vocab_size=151936,
132 | hidden_size=4096,
133 | intermediate_size=22016,
134 | num_hidden_layers=32,
135 | num_attention_heads=32,
136 | num_key_value_heads=32,
137 | hidden_act="silu",
138 | max_position_embeddings=32768,
139 | initializer_range=0.02,
140 | rms_norm_eps=1e-6,
141 | use_cache=True,
142 | tie_word_embeddings=False,
143 | rope_theta=10000.0,
144 | rope_scaling=None,
145 | use_sliding_window=False,
146 | sliding_window=4096,
147 | max_window_layers=28,
148 | attention_dropout=0.0,
149 | parscale_n=1,
150 | parscale_n_tokens=48,
151 | parscale_attn_smooth=0.01,
152 | **kwargs,
153 | ):
154 | self.vocab_size = vocab_size
155 | self.max_position_embeddings = max_position_embeddings
156 | self.hidden_size = hidden_size
157 | self.intermediate_size = intermediate_size
158 | self.num_hidden_layers = num_hidden_layers
159 | self.num_attention_heads = num_attention_heads
160 | self.use_sliding_window = use_sliding_window
161 | self.sliding_window = sliding_window if use_sliding_window else None
162 | self.max_window_layers = max_window_layers
163 | self.parscale_n = parscale_n
164 | self.parscale_n_tokens = parscale_n_tokens
165 | self.parscale_attn_smooth = parscale_attn_smooth
166 |
167 | # for backward compatibility
168 | if num_key_value_heads is None:
169 | num_key_value_heads = num_attention_heads
170 |
171 | self.num_key_value_heads = num_key_value_heads
172 | self.hidden_act = hidden_act
173 | self.initializer_range = initializer_range
174 | self.rms_norm_eps = rms_norm_eps
175 | self.use_cache = use_cache
176 | self.rope_theta = rope_theta
177 | self.rope_scaling = rope_scaling
178 | self.attention_dropout = attention_dropout
179 | # Validate the correctness of rotary position embeddings parameters
180 | # BC: if there is a 'type' field, move it to 'rope_type'.
181 | if self.rope_scaling is not None and "type" in self.rope_scaling:
182 | self.rope_scaling["rope_type"] = self.rope_scaling["type"]
183 | rope_config_validation(self)
184 |
185 | super().__init__(
186 | tie_word_embeddings=tie_word_embeddings,
187 | **kwargs,
188 | )
189 |
--------------------------------------------------------------------------------
/cost_analysis.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import json
3 | import os
4 | from llm_analysis.analysis import LLMAnalysis, get_gpu_config_by_name, ModelConfig, ActivationRecomputation, BYTES_FP16
5 |
6 |
7 | if __name__ == "__main__":
8 | import argparse
9 |
10 | parser = argparse.ArgumentParser()
11 |
12 | # General model config
13 | parser.add_argument('--hidden_size', type=int, required=True)
14 | parser.add_argument('--intermediate_size', type=int, required=True)
15 | parser.add_argument('--num_hidden_layers', type=int, default=36)
16 | parser.add_argument('--num_attention_heads', type=int, default=16)
17 | parser.add_argument('--max_position_embeddings', type=int, default=2048)
18 | parser.add_argument('--num_key_value_heads', type=int, default=2)
19 | parser.add_argument('--vocab_size', type=int, default=151936)
20 |
21 | # Parscale config
22 | parser.add_argument('--P', type=int, default=1) # Number of parallel streams
23 | parser.add_argument('--parscale_prefix_tokens', type=int, default=48) # Number of prefix tokens
24 |
25 | # Data config
26 | parser.add_argument('--batch_size', type=int, default=1)
27 | parser.add_argument('--input_length', type=int, default=64)
28 | parser.add_argument('--output_length', type=int, default=64)
29 |
30 | # GPU config
31 | parser.add_argument('--gpu_config', type=str, default="a100-sxm-80gb")
32 | parser.add_argument('--flops_efficiency', type=float, default=0.7) # Recommended by llm-analysis
33 | parser.add_argument('--hbm_memory_efficiency', type=float, default=0.9) # Recommended by llm-analysis
34 |
35 | args = parser.parse_args()
36 | p = args.P
37 | model_config = ModelConfig(
38 | name="",
39 | num_layers=args.num_hidden_layers,
40 | n_head=args.num_attention_heads,
41 | hidden_dim=args.hidden_size, vocab_size=args.vocab_size,
42 | max_seq_len=args.max_position_embeddings + (args.parscale_prefix_tokens if p > 1 else 0),
43 | num_key_value_heads=args.num_key_value_heads,
44 | ffn_embed_dim=args.intermediate_size,
45 | mlp_gated_linear_units=True
46 | )
47 | gpu_config = get_gpu_config_by_name(args.gpu_config)
48 | gpu_config.mem_per_GPU_in_GB = 10000  # effectively unlimited memory, so the analysis is not capped by physical HBM size
49 |
50 | analysis = LLMAnalysis(
51 | model_config,
52 | gpu_config,
53 | flops_efficiency=args.flops_efficiency,
54 | hbm_memory_efficiency=args.hbm_memory_efficiency,
55 | )
56 | seq_len = args.input_length + (args.parscale_prefix_tokens if p > 1 else 0)
57 | summary_dict = analysis.inference(
58 | batch_size_per_gpu=args.batch_size * p,
59 | seq_len=seq_len,
60 | num_tokens_to_generate=args.output_length,
61 | )
62 |
63 | # We consider the influence of the aggregation layer.
64 | aggregate_param = (args.hidden_size + 1) * args.hidden_size * p if p > 1 else 0
65 | aggregate_param_vs_fwd_param = aggregate_param / analysis.get_num_params_per_layer_mlp()
66 | aggregate_latency = aggregate_param_vs_fwd_param * analysis.get_latency_fwd_per_layer_mlp(args.batch_size, args.input_length + args.output_length) if p > 1 else 0
67 | aggregate_memory = aggregate_param * analysis.dtype_config.weight_bits / 8
68 |
69 | prefill_activation_memory_per_gpu = max(
70 | # Each layer's activation memory increases by a factor of P
71 | analysis.get_activation_memory_per_layer(
72 | args.batch_size * p,
73 | seq_len,
74 | is_inference=True,
75 | layernorm_dtype_bytes=BYTES_FP16,
76 | ),
77 | # The output embedding's activation memory does not participate in the parallel streams and is independent of P.
78 | analysis.get_activation_memory_output_embedding(
79 | args.batch_size, seq_len
80 | )
81 | )
82 |
83 | # Since we pass batch_size * p as the new batch size, llm-analysis also computes the embedding latency with this enlarged batch size. However, ParScale does not increase the embedding computation.
84 | # Therefore, we correct for it below.
85 | embedding_latency_estimate_for_embedding = (
86 | analysis.get_latency_fwd_input_embedding(args.batch_size * p, args.input_length + args.output_length, dtype_bytes=analysis.dtype_config.embedding_bits) +
87 | analysis.get_latency_fwd_output_embedding_loss(args.batch_size * p, args.input_length + args.output_length)
88 | )
89 | embedding_latency_real_for_embedding = (
90 | analysis.get_latency_fwd_input_embedding(args.batch_size, args.input_length + args.output_length, dtype_bytes=analysis.dtype_config.embedding_bits) +
91 | analysis.get_latency_fwd_output_embedding_loss(args.batch_size, args.input_length + args.output_length)
92 | )
93 |
94 | total_memory = (
95 | summary_dict['kv_cache_memory_per_gpu'] +
96 | summary_dict['weight_memory_per_gpu'] +
97 | aggregate_memory +
98 | prefill_activation_memory_per_gpu
99 | )
100 | total_latency = (
101 | summary_dict['total_latency'] + aggregate_latency
102 | - embedding_latency_estimate_for_embedding
103 | + embedding_latency_real_for_embedding
104 | )
105 | print(f"Memory: {total_memory / 2**30:.3f}GB; Latency: {total_latency:.3f}s")
--------------------------------------------------------------------------------
/figures/1t.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/QwenLM/ParScale/cd6acb48ba6d3b1f9715785ff0d212751210177d/figures/1t.png
--------------------------------------------------------------------------------
/figures/cost.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/QwenLM/ParScale/cd6acb48ba6d3b1f9715785ff0d212751210177d/figures/cost.png
--------------------------------------------------------------------------------
/figures/logo.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/QwenLM/ParScale/cd6acb48ba6d3b1f9715785ff0d212751210177d/figures/logo.jpg
--------------------------------------------------------------------------------
/figures/scaling_comparison.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/QwenLM/ParScale/cd6acb48ba6d3b1f9715785ff0d212751210177d/figures/scaling_comparison.png
--------------------------------------------------------------------------------
/figures/scaling_law.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/QwenLM/ParScale/cd6acb48ba6d3b1f9715785ff0d212751210177d/figures/scaling_law.png
--------------------------------------------------------------------------------
/figures/scaling_law2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/QwenLM/ParScale/cd6acb48ba6d3b1f9715785ff0d212751210177d/figures/scaling_law2.png
--------------------------------------------------------------------------------
/figures/teaser.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/QwenLM/ParScale/cd6acb48ba6d3b1f9715785ff0d212751210177d/figures/teaser.png
--------------------------------------------------------------------------------
/modeling_qwen2_parscale.py:
--------------------------------------------------------------------------------
1 | """
2 | This is the inference code for ParScale, based on Qwen2. It can also be used directly to load existing Qwen2 models (with parscale_n = 1, the default).
3 | All modifications are wrapped within the condition 'parscale_n > 1'.
4 | If you are interested in how ParScale is implemented, please search for "parscale_n" in this file.
5 | """
6 |
7 | from typing import Any, Callable, Dict, List, Optional, Tuple, Union
8 |
9 | import torch
10 | from torch import nn
11 | from einops import repeat, rearrange
12 |
13 | from transformers.activations import ACT2FN
14 | from transformers.cache_utils import Cache, DynamicCache, StaticCache
15 | from transformers.generation import GenerationMixin
16 | from transformers.modeling_attn_mask_utils import AttentionMaskConverter
17 | from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
18 | from transformers.modeling_outputs import (
19 | BaseModelOutputWithPast,
20 | CausalLMOutputWithPast,
21 | QuestionAnsweringModelOutput,
22 | SequenceClassifierOutputWithPast,
23 | TokenClassifierOutput,
24 | )
25 | from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS
26 | from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
27 | from transformers.processing_utils import Unpack
28 | from transformers.utils import (
29 | LossKwargs,
30 | add_code_sample_docstrings,
31 | add_start_docstrings,
32 | add_start_docstrings_to_model_forward,
33 | logging,
34 | replace_return_docstrings,
35 | )
36 | from .configuration_qwen2_parscale import Qwen2ParScaleConfig
38 |
39 |
40 | logger = logging.get_logger(__name__)
41 |
42 | _CHECKPOINT_FOR_DOC = "Qwen/Qwen2-7B"
43 | _CONFIG_FOR_DOC = "Qwen2ParScaleConfig"
44 |
45 |
46 | class Qwen2MLP(nn.Module):
47 | def __init__(self, config):
48 | super().__init__()
49 | self.config = config
50 | self.hidden_size = config.hidden_size
51 | self.intermediate_size = config.intermediate_size
52 | self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
53 | self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
54 | self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
55 | self.act_fn = ACT2FN[config.hidden_act]
56 |
57 | def forward(self, x):
58 | down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
59 | return down_proj
60 |
61 |
62 | def rotate_half(x):
63 | """Rotates half the hidden dims of the input."""
64 | x1 = x[..., : x.shape[-1] // 2]
65 | x2 = x[..., x.shape[-1] // 2 :]
66 | return torch.cat((-x2, x1), dim=-1)
67 |
68 |
69 | def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
70 | """Applies Rotary Position Embedding to the query and key tensors.
71 |
72 | Args:
73 | q (`torch.Tensor`): The query tensor.
74 | k (`torch.Tensor`): The key tensor.
75 | cos (`torch.Tensor`): The cosine part of the rotary embedding.
76 | sin (`torch.Tensor`): The sine part of the rotary embedding.
77 | position_ids (`torch.Tensor`, *optional*):
78 | Deprecated and unused.
79 | unsqueeze_dim (`int`, *optional*, defaults to 1):
80 | The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
81 | sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
82 | that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
83 | k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
84 | cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
85 | the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
86 | Returns:
87 | `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
88 | """
89 | cos = cos.unsqueeze(unsqueeze_dim)
90 | sin = sin.unsqueeze(unsqueeze_dim)
91 | q_embed = (q * cos) + (rotate_half(q) * sin)
92 | k_embed = (k * cos) + (rotate_half(k) * sin)
93 | return q_embed, k_embed
94 |
95 |
96 | def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
97 | """
98 | This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
99 | num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
100 | """
101 | batch, num_key_value_heads, slen, head_dim = hidden_states.shape
102 | if n_rep == 1:
103 | return hidden_states
104 | hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
105 | return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
106 |
107 |
108 | def eager_attention_forward(
109 | module: nn.Module,
110 | query: torch.Tensor,
111 | key: torch.Tensor,
112 | value: torch.Tensor,
113 | attention_mask: Optional[torch.Tensor],
114 | scaling: float,
115 | dropout: float = 0.0,
116 | **kwargs,
117 | ):
118 | key_states = repeat_kv(key, module.num_key_value_groups)
119 | value_states = repeat_kv(value, module.num_key_value_groups)
120 |
121 | attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
122 | if attention_mask is not None:
123 | causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
124 | attn_weights = attn_weights + causal_mask
125 |
126 | attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
127 | attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
128 | attn_output = torch.matmul(attn_weights, value_states)
129 | attn_output = attn_output.transpose(1, 2).contiguous()
130 |
131 | return attn_output, attn_weights
132 |
133 | class ParscaleCache(DynamicCache):
134 | def __init__(self, prefix_k, prefix_v) -> None:
135 | super().__init__()
136 | self._seen_tokens = 0 # Used in `generate` to keep tally of how many tokens the cache has seen
137 | self.key_cache: List[torch.Tensor] = prefix_k
138 | self.value_cache: List[torch.Tensor] = prefix_v
139 | self.parscale_n = prefix_k[0].size(0)
140 | self.n_prefix_tokens = prefix_k[0].size(2)
141 | def update(
142 | self,
143 | key_states: torch.Tensor,
144 | value_states: torch.Tensor,
145 | layer_idx: int,
146 | cache_kwargs: Optional[Dict[str, Any]] = None,
147 | ) -> Tuple[torch.Tensor, torch.Tensor]:
148 | if self.key_cache[layer_idx].size(0) != key_states.size(0):
149 | # first time generation
150 | self.key_cache[layer_idx] = repeat(self.key_cache[layer_idx], 'n_parscale ... -> (n_parscale b) ...', b=key_states.size(0) // self.parscale_n)
151 | self.value_cache[layer_idx] = repeat(self.value_cache[layer_idx], 'n_parscale ... -> (n_parscale b) ...', b=key_states.size(0) // self.parscale_n)
152 | return super().update(key_states, value_states, layer_idx, cache_kwargs)
153 |
154 | def get_seq_length(self, layer_idx = 0):
155 | seq_len = super().get_seq_length(layer_idx)
156 | if seq_len != 0:
157 | seq_len -= self.n_prefix_tokens
158 | return seq_len
159 |
160 | def reorder_cache(self, beam_idx: torch.LongTensor):
161 | """Reorders the cache for beam search, given the selected beam indices."""
162 | b = self.key_cache[0].size(0) // self.parscale_n
163 | beam_idx = torch.cat([beam_idx + b * i for i in range(self.parscale_n)])
164 | super().reorder_cache(beam_idx)
165 |
166 | class Qwen2Attention(nn.Module):
167 | """Multi-headed attention from 'Attention Is All You Need' paper"""
168 |
169 | def __init__(self, config: Qwen2ParScaleConfig, layer_idx: int):
170 | super().__init__()
171 | self.config = config
172 | self.layer_idx = layer_idx
173 | self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
174 | self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
175 | self.scaling = self.head_dim**-0.5
176 | self.attention_dropout = config.attention_dropout
177 | self.is_causal = True
178 | self.q_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=True)
179 | self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=True)
180 | self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=True)
181 | self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False)
182 | if config.parscale_n > 1:
183 | self.prefix_k = nn.Parameter(torch.empty((config.parscale_n, config.num_key_value_heads, config.parscale_n_tokens, self.head_dim)))
184 | self.prefix_v = nn.Parameter(torch.empty((config.parscale_n, config.num_key_value_heads, config.parscale_n_tokens, self.head_dim)))
185 |
186 |
187 | def forward(
188 | self,
189 | hidden_states: torch.Tensor,
190 | position_embeddings: Tuple[torch.Tensor, torch.Tensor],
191 | attention_mask: Optional[torch.Tensor],
192 | past_key_value: Optional[Cache] = None,
193 | cache_position: Optional[torch.LongTensor] = None,
194 | **kwargs: Unpack[FlashAttentionKwargs],
195 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
196 | input_shape = hidden_states.shape[:-1]
197 | hidden_shape = (*input_shape, -1, self.head_dim)
198 |
199 | query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
200 | key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
201 | value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
202 |
203 | cos, sin = position_embeddings
204 | query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
205 |
206 | if past_key_value is not None:
207 | # sin and cos are specific to RoPE models; cache_position needed for the static cache
208 | cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
209 | key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
210 |
211 | if self.config.parscale_n > 1:
212 |
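# ParScale stores parscale_n_tokens learned prefix tokens per stream as key/value prefixes
# (see ParscaleCache). The attention mask is extended so every query can attend to these
# prefix positions; during prefill, zero-valued queries are prepended to keep query and key
# positions aligned for the attention kernels, and the extra outputs are removed below.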
213 | # Expand attention mask to contain the prefix tokens
214 | n_virtual_tokens = self.config.parscale_n_tokens
215 |
216 | if attention_mask is not None:
217 | attention_mask = torch.cat([
218 | torch.zeros((attention_mask.shape[0], attention_mask.shape[1], attention_mask.shape[2], self.config.parscale_n_tokens), dtype=attention_mask.dtype, device=attention_mask.device),
219 | attention_mask
220 | ], dim=3)
221 |
222 | if query_states.size(2) != 1:
223 | query_states = torch.cat([torch.zeros([query_states.size(0), query_states.size(1), n_virtual_tokens, query_states.size(3)], dtype=query_states.dtype, device=query_states.device), query_states], dim=2)
224 | if attention_mask is not None:
225 | attention_mask = torch.cat([
226 | torch.zeros((attention_mask.shape[0], attention_mask.shape[1], self.config.parscale_n_tokens, attention_mask.shape[3]), dtype=attention_mask.dtype, device=attention_mask.device),
227 | attention_mask
228 | ], dim=2)
229 |
230 | sliding_window = None
231 | if (
232 | self.config.use_sliding_window
233 | and getattr(self.config, "sliding_window", None) is not None
234 | and self.layer_idx >= self.config.max_window_layers
235 | ):
236 | sliding_window = self.config.sliding_window
237 |
238 | attention_interface: Callable = eager_attention_forward
239 | if self.config._attn_implementation != "eager":
240 | if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False):
241 | logger.warning_once(
242 | "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
243 | 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
244 | )
245 | else:
246 | attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
247 |
248 | attn_output, attn_weights = attention_interface(
249 | self,
250 | query_states,
251 | key_states,
252 | value_states,
253 | attention_mask,
254 | dropout=0.0 if not self.training else self.attention_dropout,
255 | scaling=self.scaling,
256 | sliding_window=sliding_window, # main diff with Llama
257 | # is_causal=True,
258 | **kwargs,
259 | )
260 |
261 | if self.config.parscale_n > 1 and query_states.size(2) != 1:
262 | # Remove the prefix part
263 | attn_output = attn_output[:, n_virtual_tokens:]
264 | attn_output = attn_output.reshape(*input_shape, -1).contiguous()
265 | attn_output = self.o_proj(attn_output)
266 | return attn_output, attn_weights
267 |
268 |
269 | class Qwen2RMSNorm(nn.Module):
270 | def __init__(self, hidden_size, eps=1e-6):
271 | """
272 | Qwen2RMSNorm is equivalent to T5LayerNorm
273 | """
274 | super().__init__()
275 | self.weight = nn.Parameter(torch.ones(hidden_size))
276 | self.variance_epsilon = eps
277 |
278 | def forward(self, hidden_states):
279 | input_dtype = hidden_states.dtype
280 | hidden_states = hidden_states.to(torch.float32)
281 | variance = hidden_states.pow(2).mean(-1, keepdim=True)
282 | hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
283 | return self.weight * hidden_states.to(input_dtype)
284 |
285 | def extra_repr(self):
286 | return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
287 |
288 |
289 | class Qwen2DecoderLayer(nn.Module):
290 | def __init__(self, config: Qwen2ParScaleConfig, layer_idx: int):
291 | super().__init__()
292 | self.hidden_size = config.hidden_size
293 | self.self_attn = Qwen2Attention(config=config, layer_idx=layer_idx)
294 | self.mlp = Qwen2MLP(config)
295 | self.input_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
296 | self.post_attention_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
297 | if config.sliding_window and config._attn_implementation != "flash_attention_2":
298 | logger.warning_once(
299 | f"Sliding Window Attention is enabled but not implemented for `{config._attn_implementation}`; "
300 | "unexpected results may be encountered."
301 | )
302 |
303 | def forward(
304 | self,
305 | hidden_states: torch.Tensor,
306 | attention_mask: Optional[torch.Tensor] = None,
307 | position_ids: Optional[torch.LongTensor] = None,
308 | past_key_value: Optional[Cache] = None,
309 | output_attentions: Optional[bool] = False,
310 | use_cache: Optional[bool] = False,
311 | cache_position: Optional[torch.LongTensor] = None,
312 | position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
313 | **kwargs: Unpack[FlashAttentionKwargs],
314 | ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
315 | residual = hidden_states
316 |
317 | hidden_states = self.input_layernorm(hidden_states)
318 |
319 | # Self Attention
320 | hidden_states, self_attn_weights = self.self_attn(
321 | hidden_states=hidden_states,
322 | attention_mask=attention_mask,
323 | position_ids=position_ids,
324 | past_key_value=past_key_value,
325 | output_attentions=output_attentions,
326 | use_cache=use_cache,
327 | cache_position=cache_position,
328 | position_embeddings=position_embeddings,
329 | **kwargs,
330 | )
331 | hidden_states = residual + hidden_states
332 |
333 | # Fully Connected
334 | residual = hidden_states
335 | hidden_states = self.post_attention_layernorm(hidden_states)
336 | hidden_states = self.mlp(hidden_states)
337 | hidden_states = residual + hidden_states
338 |
339 | outputs = (hidden_states,)
340 | if output_attentions:
341 | outputs += (self_attn_weights,)
342 |
343 | return outputs
344 |
345 |
346 | class Qwen2RotaryEmbedding(nn.Module):
347 | def __init__(self, config: Qwen2ParScaleConfig, device=None):
348 | super().__init__()
349 | # BC: "rope_type" was originally "type"
350 | if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
351 | self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
352 | else:
353 | self.rope_type = "default"
354 | self.max_seq_len_cached = config.max_position_embeddings
355 | self.original_max_seq_len = config.max_position_embeddings
356 |
357 | self.config = config
358 | self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
359 |
360 | inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
361 | self.register_buffer("inv_freq", inv_freq, persistent=False)
362 | self.original_inv_freq = self.inv_freq
363 |
364 | def _dynamic_frequency_update(self, position_ids, device):
365 | """
366 | dynamic RoPE layers should recompute `inv_freq` in the following situations:
367 | 1 - growing beyond the cached sequence length (allow scaling)
368 | 2 - the current sequence length is in the original scale (avoid losing precision with small sequences)
369 | """
370 | seq_len = torch.max(position_ids) + 1
371 | if seq_len > self.max_seq_len_cached: # growth
372 | inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, seq_len=seq_len)
373 | self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation
374 | self.max_seq_len_cached = seq_len
375 |
376 | if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset
377 | # This .to() is needed if the model has been moved to a device after being initialized (because
378 | # the buffer is automatically moved, but not the original copy)
379 | self.original_inv_freq = self.original_inv_freq.to(device)
380 | self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
381 | self.max_seq_len_cached = self.original_max_seq_len
382 |
383 | @torch.no_grad()
384 | def forward(self, x, position_ids):
385 | if "dynamic" in self.rope_type:
386 | self._dynamic_frequency_update(position_ids, device=x.device)
387 |
388 | # Core RoPE block
389 | inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
390 | position_ids_expanded = position_ids[:, None, :].float()
391 | # Force float32 (see https://github.com/huggingface/transformers/pull/29285)
392 | device_type = x.device.type
393 | device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
394 | with torch.autocast(device_type=device_type, enabled=False):
395 | freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
396 | emb = torch.cat((freqs, freqs), dim=-1)
397 | cos = emb.cos()
398 | sin = emb.sin()
399 |
400 | # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention
401 | cos = cos * self.attention_scaling
402 | sin = sin * self.attention_scaling
403 |
404 | return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
405 |
406 |
407 | QWEN2_START_DOCSTRING = r"""
408 | This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
409 | library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
410 | etc.)
411 |
412 | This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
413 | Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
414 | and behavior.
415 |
416 | Parameters:
417 | config ([`Qwen2ParScaleConfig`]):
418 | Model configuration class with all the parameters of the model. Initializing with a config file does not
419 | load the weights associated with the model, only the configuration. Check out the
420 | [`~PreTrainedModel.from_pretrained`] method to load the model weights.
421 | """
422 |
423 |
424 | @add_start_docstrings(
425 | "The bare Qwen2 Model outputting raw hidden-states without any specific head on top.",
426 | QWEN2_START_DOCSTRING,
427 | )
428 | class Qwen2PreTrainedModel(PreTrainedModel):
429 | config_class = Qwen2ParScaleConfig
430 | base_model_prefix = "model"
431 | supports_gradient_checkpointing = True
432 | _no_split_modules = ["Qwen2DecoderLayer"]
433 | _skip_keys_device_placement = ["past_key_values"]
434 | _supports_flash_attn_2 = True
435 | _supports_sdpa = True
436 | _supports_flex_attn = True
437 | _supports_cache_class = True
438 | _supports_quantized_cache = True
439 | _supports_static_cache = True
440 |
441 | def _init_weights(self, module):
442 | std = self.config.initializer_range
443 | if isinstance(module, nn.Linear):
444 | module.weight.data.normal_(mean=0.0, std=std)
445 | if module.bias is not None:
446 | module.bias.data.zero_()
447 | elif isinstance(module, nn.Embedding):
448 | module.weight.data.normal_(mean=0.0, std=std)
449 | if module.padding_idx is not None:
450 | module.weight.data[module.padding_idx].zero_()
451 |
452 |
453 | QWEN2_INPUTS_DOCSTRING = r"""
454 | Args:
455 | input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
456 | Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
457 | it.
458 |
459 | Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
460 | [`PreTrainedTokenizer.__call__`] for details.
461 |
462 | [What are input IDs?](../glossary#input-ids)
463 | attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
464 | Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
465 |
466 | - 1 for tokens that are **not masked**,
467 | - 0 for tokens that are **masked**.
468 |
469 | [What are attention masks?](../glossary#attention-mask)
470 |
471 | Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
472 | [`PreTrainedTokenizer.__call__`] for details.
473 |
474 | If `past_key_values` is used, optionally only the last `input_ids` have to be input (see
475 | `past_key_values`).
476 |
477 | If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
478 | and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
479 | information on the default strategy.
480 |
481 | - 1 indicates the head is **not masked**,
482 | - 0 indicates the head is **masked**.
483 | position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
484 | Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
485 | config.n_positions - 1]`.
486 |
487 | [What are position IDs?](../glossary#position-ids)
488 | past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
489 | Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
490 | blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values`
491 | returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
492 |
493 | Two formats are allowed:
494 | - a [`~cache_utils.Cache`] instance, see our
495 | [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache);
496 | - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
497 | shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy
498 | cache format.
499 |
500 | The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
501 | legacy cache format will be returned.
502 |
503 | If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
504 | have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
505 | of shape `(batch_size, sequence_length)`.
506 | inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
507 | Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
508 | is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
509 | model's internal embedding lookup matrix.
510 | use_cache (`bool`, *optional*):
511 | If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
512 | `past_key_values`).
513 | output_attentions (`bool`, *optional*):
514 | Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
515 | tensors for more detail.
516 | output_hidden_states (`bool`, *optional*):
517 | Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
518 | more detail.
519 | return_dict (`bool`, *optional*):
520 | Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
521 | cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
522 | Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`,
523 | this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
524 | the complete sequence length.
525 | """
526 |
527 |
528 | @add_start_docstrings(
529 | "The bare Qwen2 Model outputting raw hidden-states without any specific head on top.",
530 | QWEN2_START_DOCSTRING,
531 | )
532 | class Qwen2Model(Qwen2PreTrainedModel):
533 | """
534 | Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`Qwen2DecoderLayer`]
535 |
536 | Args:
537 | config: Qwen2ParScaleConfig
538 | """
539 |
540 | def __init__(self, config: Qwen2ParScaleConfig):
541 | super().__init__(config)
542 | self.padding_idx = config.pad_token_id
543 | self.vocab_size = config.vocab_size
544 |
545 | self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
546 | self.layers = nn.ModuleList(
547 | [Qwen2DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
548 | )
549 | self.norm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
550 | self.rotary_emb = Qwen2RotaryEmbedding(config=config)
551 | self.gradient_checkpointing = False
552 |
553 | self.parscale_n = config.parscale_n
554 | if config.parscale_n > 1:
555 | self.aggregate_layer = torch.nn.Sequential(
556 | torch.nn.Linear(config.parscale_n * config.hidden_size, config.hidden_size),
557 | torch.nn.SiLU(),
558 | torch.nn.Linear(config.hidden_size, config.parscale_n)
559 | )
560 | self.parscale_aggregate_attn_smoothing = config.parscale_attn_smooth
561 |
562 | # Initialize weights and apply final processing
563 | self.post_init()
564 |
565 | def get_input_embeddings(self):
566 | return self.embed_tokens
567 |
568 | def set_input_embeddings(self, value):
569 | self.embed_tokens = value
570 |
571 | @add_start_docstrings_to_model_forward(QWEN2_INPUTS_DOCSTRING)
572 | def forward(
573 | self,
574 | input_ids: torch.LongTensor = None,
575 | attention_mask: Optional[torch.Tensor] = None,
576 | position_ids: Optional[torch.LongTensor] = None,
577 | past_key_values: Optional[Cache] = None,
578 | inputs_embeds: Optional[torch.FloatTensor] = None,
579 | use_cache: Optional[bool] = None,
580 | output_attentions: Optional[bool] = None,
581 | output_hidden_states: Optional[bool] = None,
582 | return_dict: Optional[bool] = None,
583 | cache_position: Optional[torch.LongTensor] = None,
584 | **flash_attn_kwargs: Unpack[FlashAttentionKwargs],
585 | ) -> Union[Tuple, BaseModelOutputWithPast]:
586 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
587 | output_hidden_states = (
588 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
589 | )
590 | use_cache = use_cache if use_cache is not None else self.config.use_cache
591 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict
592 |
593 | if (input_ids is None) ^ (inputs_embeds is not None):
594 | raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
595 |
596 | if self.gradient_checkpointing and self.training and use_cache:
597 | logger.warning_once(
598 | "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
599 | )
600 | use_cache = False
601 |
602 | if inputs_embeds is None:
603 | inputs_embeds = self.embed_tokens(input_ids)
604 |
605 | if self.parscale_n > 1:
606 | # Input transformation: we directly replicate the input n_parscale times.
607 | # The transformation is implemented through KVCache (ParscaleCache).
608 | inputs_embeds = repeat(inputs_embeds, "b s h -> (n_parscale b) s h", n_parscale=self.parscale_n)
609 | if attention_mask is not None:
610 | attention_mask = repeat(attention_mask, "b s -> (n_parscale b) s", n_parscale=self.parscale_n)
611 | if position_ids is not None:
612 | position_ids = repeat(position_ids, "b s -> (n_parscale b) s", n_parscale=self.parscale_n)
613 |
614 | # The trained prefix is saved in layer.self_attn.prefix_k / layer.self_attn.prefix_v
615 | # We extract them to construct ParscaleCache.
616 | if past_key_values is None or past_key_values.get_seq_length() == 0:
617 | past_key_values = ParscaleCache([layer.self_attn.prefix_k for layer in self.layers], [layer.self_attn.prefix_v for layer in self.layers])
618 |
619 | if use_cache and past_key_values is None:
620 | past_key_values = DynamicCache()
621 |
622 | if cache_position is None:
623 | past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
624 | cache_position = torch.arange(
625 | past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
626 | )
627 |
628 | if position_ids is None:
629 | position_ids = cache_position.unsqueeze(0)
630 |
631 | causal_mask = self._update_causal_mask(
632 | attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
633 | )
634 |
635 | hidden_states = inputs_embeds
636 |
637 | # create position embeddings to be shared across the decoder layers
638 | position_embeddings = self.rotary_emb(hidden_states, position_ids)
639 |
640 | # decoder layers
641 | all_hidden_states = () if output_hidden_states else None
642 | all_self_attns = () if output_attentions else None
643 |
644 | for decoder_layer in self.layers[: self.config.num_hidden_layers]:
645 | if output_hidden_states:
646 | all_hidden_states += (hidden_states,)
647 |
648 | if self.gradient_checkpointing and self.training:
649 | layer_outputs = self._gradient_checkpointing_func(
650 | decoder_layer.__call__,
651 | hidden_states,
652 | causal_mask,
653 | position_ids,
654 | past_key_values,
655 | output_attentions,
656 | use_cache,
657 | cache_position,
658 | position_embeddings,
659 | )
660 | else:
661 | layer_outputs = decoder_layer(
662 | hidden_states,
663 | attention_mask=causal_mask,
664 | position_ids=position_ids,
665 | past_key_value=past_key_values,
666 | output_attentions=output_attentions,
667 | use_cache=use_cache,
668 | cache_position=cache_position,
669 | position_embeddings=position_embeddings,
670 | **flash_attn_kwargs,
671 | )
672 |
673 | hidden_states = layer_outputs[0]
674 |
675 | if output_attentions:
676 | all_self_attns += (layer_outputs[1],)
677 |
678 | hidden_states = self.norm(hidden_states)
679 |
680 | if self.parscale_n > 1:
681 | # output aggregation, based on dynamic weighted sum.
682 | attn = torch.unsqueeze(torch.softmax(self.aggregate_layer(
683 | rearrange(hidden_states, "(n_parscale b) s h -> b s (h n_parscale)", n_parscale=self.parscale_n)
684 | ).float(), dim=-1), dim=-1) # [b s n_parscale 1]
685 | if self.parscale_aggregate_attn_smoothing != 0.0:
686 | attn = attn * (1 - self.parscale_aggregate_attn_smoothing) + (self.parscale_aggregate_attn_smoothing / self.parscale_n)
687 | hidden_states = torch.sum(
688 | rearrange(hidden_states, "(n_parscale b) s h -> b s n_parscale h", n_parscale=self.parscale_n) * attn,
689 | dim=2, keepdim=False
690 | ).to(hidden_states.dtype)
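    |             # In effect, each position's parscale_n stream outputs h_1..h_P are combined as
    |             # sum_i w_i * h_i, with w = softmax(aggregate_layer([h_1; ...; h_P])); the optional
    |             # smoothing interpolates w toward the uniform average 1 / parscale_n.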
691 |
692 | # add hidden states from the last decoder layer
693 | if output_hidden_states:
694 | all_hidden_states += (hidden_states,)
695 |
696 | output = BaseModelOutputWithPast(
697 | last_hidden_state=hidden_states,
698 | past_key_values=past_key_values if use_cache else None,
699 | hidden_states=all_hidden_states,
700 | attentions=all_self_attns,
701 | )
702 | return output if return_dict else output.to_tuple()
703 |
704 | def _update_causal_mask(
705 | self,
706 | attention_mask: torch.Tensor,
707 | input_tensor: torch.Tensor,
708 | cache_position: torch.Tensor,
709 | past_key_values: Cache,
710 | output_attentions: bool,
711 | ):
712 | if self.config._attn_implementation == "flash_attention_2":
713 | if attention_mask is not None and (attention_mask == 0.0).any():
714 | return attention_mask
715 | return None
716 |
717 | # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
718 | # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
719 | # to infer the attention mask.
720 | past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
721 | using_static_cache = isinstance(past_key_values, StaticCache)
722 |
723 | # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
724 | if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions:
725 | if AttentionMaskConverter._ignore_causal_mask_sdpa(
726 | attention_mask,
727 | inputs_embeds=input_tensor,
728 | past_key_values_length=past_seen_tokens,
729 | is_training=self.training,
730 | ):
731 | return None
732 |
733 | dtype, device = input_tensor.dtype, input_tensor.device
734 | sequence_length = input_tensor.shape[1]
735 | if using_static_cache:
736 | target_length = past_key_values.get_max_cache_shape()
737 | else:
738 | target_length = (
739 | attention_mask.shape[-1]
740 | if isinstance(attention_mask, torch.Tensor)
741 | else past_seen_tokens + sequence_length + 1
742 | )
743 |
744 | causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
745 | attention_mask,
746 | sequence_length=sequence_length,
747 | target_length=target_length,
748 | dtype=dtype,
749 | device=device,
750 | cache_position=cache_position,
751 | batch_size=input_tensor.shape[0],
752 | )
753 |
754 | if (
755 | self.config._attn_implementation == "sdpa"
756 | and attention_mask is not None
757 | and attention_mask.device.type == "cuda"
758 | and not output_attentions
759 | ):
760 | # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
761 | # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
762 | # Details: https://github.com/pytorch/pytorch/issues/110213
763 | min_dtype = torch.finfo(dtype).min
764 | causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
765 |
766 | return causal_mask
767 |
768 | @staticmethod
769 | def _prepare_4d_causal_attention_mask_with_cache_position(
770 | attention_mask: torch.Tensor,
771 | sequence_length: int,
772 | target_length: int,
773 | dtype: torch.dtype,
774 | device: torch.device,
775 | cache_position: torch.Tensor,
776 | batch_size: int,
777 | **kwargs,
778 | ):
779 | """
780 | Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
781 | `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
782 |
783 | Args:
784 | attention_mask (`torch.Tensor`):
785 | A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
786 | `(batch_size, 1, query_length, key_value_length)`.
787 | sequence_length (`int`):
788 | The sequence length being processed.
789 | target_length (`int`):
790 | The target length: when generating with static cache, the mask should be as long as the static cache,
791 |                 to account for the zero padding (the part of the cache that is not yet filled).
792 | dtype (`torch.dtype`):
793 | The dtype to use for the 4D attention mask.
794 | device (`torch.device`):
795 |                 The device to place the 4D attention mask on.
796 | cache_position (`torch.Tensor`):
797 | Indices depicting the position of the input sequence tokens in the sequence.
798 |             batch_size (`int`):
799 | Batch size.
800 | """
801 | if attention_mask is not None and attention_mask.dim() == 4:
802 | # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
803 | causal_mask = attention_mask
804 | else:
805 | min_dtype = torch.finfo(dtype).min
806 | causal_mask = torch.full(
807 | (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
808 | )
809 | if sequence_length != 1:
810 | causal_mask = torch.triu(causal_mask, diagonal=1)
811 | causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
812 | causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
813 | if attention_mask is not None:
814 | causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
815 | mask_length = attention_mask.shape[-1]
816 | padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
817 | padding_mask = padding_mask == 0
818 | causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
819 | padding_mask, min_dtype
820 | )
821 |
822 | return causal_mask
823 |
824 |
825 | class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
826 |
827 |
828 | class Qwen2ParScaleForCausalLM(Qwen2PreTrainedModel, GenerationMixin):
829 | _tied_weights_keys = ["lm_head.weight"]
830 | _tp_plan = {"lm_head": "colwise_rep"}
831 |
832 | def __init__(self, config):
833 | super().__init__(config)
834 | self.model = Qwen2Model(config)
835 | self.vocab_size = config.vocab_size
836 | self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
837 |
838 | # Initialize weights and apply final processing
839 | self.post_init()
840 |
841 | def get_input_embeddings(self):
842 | return self.model.embed_tokens
843 |
844 | def set_input_embeddings(self, value):
845 | self.model.embed_tokens = value
846 |
847 | def get_output_embeddings(self):
848 | return self.lm_head
849 |
850 | def set_output_embeddings(self, new_embeddings):
851 | self.lm_head = new_embeddings
852 |
853 | def set_decoder(self, decoder):
854 | self.model = decoder
855 |
856 | def get_decoder(self):
857 | return self.model
858 |
859 | @add_start_docstrings_to_model_forward(QWEN2_INPUTS_DOCSTRING)
860 | @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
861 | def forward(
862 | self,
863 | input_ids: torch.LongTensor = None,
864 | attention_mask: Optional[torch.Tensor] = None,
865 | position_ids: Optional[torch.LongTensor] = None,
866 | past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
867 | inputs_embeds: Optional[torch.FloatTensor] = None,
868 | labels: Optional[torch.LongTensor] = None,
869 | use_cache: Optional[bool] = None,
870 | output_attentions: Optional[bool] = None,
871 | output_hidden_states: Optional[bool] = None,
872 | return_dict: Optional[bool] = None,
873 | cache_position: Optional[torch.LongTensor] = None,
874 | num_logits_to_keep: int = 0,
875 | **kwargs: Unpack[KwargsForCausalLM],
876 | ) -> Union[Tuple, CausalLMOutputWithPast]:
877 | r"""
878 | Args:
879 | labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
880 |                 Labels for computing the language modeling loss. Indices should either be in `[0, ...,
881 | config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
882 | (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
883 |
884 | num_logits_to_keep (`int`, *optional*):
885 | Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all
886 |                 `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
887 |                 token can save memory, which becomes significant for long sequences or large vocabulary sizes.
888 |
889 | Returns:
890 |
891 | Example:
892 |
893 | ```python
894 |         >>> from transformers import AutoTokenizer, AutoModelForCausalLM
895 |
896 |         >>> model = AutoModelForCausalLM.from_pretrained("ParScale/<checkpoint>", trust_remote_code=True)
897 |         >>> tokenizer = AutoTokenizer.from_pretrained("ParScale/<checkpoint>")
898 |
899 | >>> prompt = "Hey, are you conscious? Can you talk to me?"
900 | >>> inputs = tokenizer(prompt, return_tensors="pt")
901 |
902 | >>> # Generate
903 | >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
904 | >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
905 | "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
906 | ```"""
907 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
908 | output_hidden_states = (
909 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
910 | )
911 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict
912 |
913 |         # decoder outputs consist of (dec_features, layer_state, dec_hidden, dec_attn)
914 | outputs = self.model(
915 | input_ids=input_ids,
916 | attention_mask=attention_mask,
917 | position_ids=position_ids,
918 | past_key_values=past_key_values,
919 | inputs_embeds=inputs_embeds,
920 | use_cache=use_cache,
921 | output_attentions=output_attentions,
922 | output_hidden_states=output_hidden_states,
923 | return_dict=return_dict,
924 | cache_position=cache_position,
925 | **kwargs,
926 | )
927 |
928 | hidden_states = outputs[0]
929 | # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
930 | logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
931 |
932 | loss = None
933 | if labels is not None:
934 | loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
935 |
936 | if not return_dict:
937 | output = (logits,) + outputs[1:]
938 | return (loss,) + output if loss is not None else output
939 |
940 | return CausalLMOutputWithPast(
941 | loss=loss,
942 | logits=logits,
943 | past_key_values=outputs.past_key_values,
944 | hidden_states=outputs.hidden_states,
945 | attentions=outputs.attentions,
946 | )
947 |
948 |
949 | @add_start_docstrings(
950 | """
951 | The Qwen2 Model transformer with a sequence classification head on top (linear layer).
952 |
953 | [`Qwen2ForSequenceClassification`] uses the last token in order to do the classification, as other causal models
954 | (e.g. GPT-2) do.
955 |
956 |     Since it does classification on the last token, it needs to know the position of the last token. If a
957 | `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
958 | no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
959 | padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
960 | each row of the batch).
961 | """,
962 | QWEN2_START_DOCSTRING,
963 | )
964 | class Qwen2ForSequenceClassification(Qwen2PreTrainedModel):
965 | def __init__(self, config):
966 | super().__init__(config)
967 | self.num_labels = config.num_labels
968 | self.model = Qwen2Model(config)
969 | self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
970 |
971 | # Initialize weights and apply final processing
972 | self.post_init()
973 |
974 | def get_input_embeddings(self):
975 | return self.model.embed_tokens
976 |
977 | def set_input_embeddings(self, value):
978 | self.model.embed_tokens = value
979 |
980 | @add_start_docstrings_to_model_forward(QWEN2_INPUTS_DOCSTRING)
981 | def forward(
982 | self,
983 | input_ids: Optional[torch.LongTensor] = None,
984 | attention_mask: Optional[torch.Tensor] = None,
985 | position_ids: Optional[torch.LongTensor] = None,
986 | past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
987 | inputs_embeds: Optional[torch.FloatTensor] = None,
988 | labels: Optional[torch.LongTensor] = None,
989 | use_cache: Optional[bool] = None,
990 | output_attentions: Optional[bool] = None,
991 | output_hidden_states: Optional[bool] = None,
992 | return_dict: Optional[bool] = None,
993 | ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
994 | r"""
995 | labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
996 | Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
997 |             config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss); if
998 | `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
999 | """
1000 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1001 |
1002 | transformer_outputs = self.model(
1003 | input_ids,
1004 | attention_mask=attention_mask,
1005 | position_ids=position_ids,
1006 | past_key_values=past_key_values,
1007 | inputs_embeds=inputs_embeds,
1008 | use_cache=use_cache,
1009 | output_attentions=output_attentions,
1010 | output_hidden_states=output_hidden_states,
1011 | return_dict=return_dict,
1012 | )
1013 | hidden_states = transformer_outputs[0]
1014 | logits = self.score(hidden_states)
1015 |
1016 | if input_ids is not None:
1017 | batch_size = input_ids.shape[0]
1018 | else:
1019 | batch_size = inputs_embeds.shape[0]
1020 |
1021 | if self.config.pad_token_id is None and batch_size != 1:
1022 | raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
1023 | if self.config.pad_token_id is None:
1024 | sequence_lengths = -1
1025 | else:
1026 | if input_ids is not None:
1027 | # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
1028 | sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
1029 | sequence_lengths = sequence_lengths % input_ids.shape[-1]
1030 | sequence_lengths = sequence_lengths.to(logits.device)
1031 | else:
1032 | sequence_lengths = -1
1033 |
1034 | pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
1035 |
1036 | loss = None
1037 | if labels is not None:
1038 | loss = self.loss_function(logits=logits, labels=labels, pooled_logits=pooled_logits, config=self.config)
1039 |
1040 | if not return_dict:
1041 | output = (pooled_logits,) + transformer_outputs[1:]
1042 | return ((loss,) + output) if loss is not None else output
1043 |
1044 | return SequenceClassifierOutputWithPast(
1045 | loss=loss,
1046 | logits=pooled_logits,
1047 | past_key_values=transformer_outputs.past_key_values,
1048 | hidden_states=transformer_outputs.hidden_states,
1049 | attentions=transformer_outputs.attentions,
1050 | )
1051 |
1052 |
1053 | @add_start_docstrings(
1054 | """
1055 | The Qwen2 Model transformer with a token classification head on top (a linear layer on top of the hidden-states
1056 | output) e.g. for Named-Entity-Recognition (NER) tasks.
1057 | """,
1058 | QWEN2_START_DOCSTRING,
1059 | )
1060 | class Qwen2ForTokenClassification(Qwen2PreTrainedModel):
1061 | def __init__(self, config):
1062 | super().__init__(config)
1063 | self.num_labels = config.num_labels
1064 | self.model = Qwen2Model(config)
1065 | if getattr(config, "classifier_dropout", None) is not None:
1066 | classifier_dropout = config.classifier_dropout
1067 | elif getattr(config, "hidden_dropout", None) is not None:
1068 | classifier_dropout = config.hidden_dropout
1069 | else:
1070 | classifier_dropout = 0.1
1071 | self.dropout = nn.Dropout(classifier_dropout)
1072 | self.score = nn.Linear(config.hidden_size, config.num_labels)
1073 |
1074 | # Initialize weights and apply final processing
1075 | self.post_init()
1076 |
1077 | def get_input_embeddings(self):
1078 | return self.model.embed_tokens
1079 |
1080 | def set_input_embeddings(self, value):
1081 | self.model.embed_tokens = value
1082 |
1083 | @add_start_docstrings_to_model_forward(QWEN2_INPUTS_DOCSTRING)
1084 | @add_code_sample_docstrings(
1085 | checkpoint=_CHECKPOINT_FOR_DOC,
1086 | output_type=TokenClassifierOutput,
1087 | config_class=_CONFIG_FOR_DOC,
1088 | )
1089 | def forward(
1090 | self,
1091 | input_ids: Optional[torch.LongTensor] = None,
1092 | attention_mask: Optional[torch.Tensor] = None,
1093 | position_ids: Optional[torch.LongTensor] = None,
1094 | past_key_values: Optional[List[torch.FloatTensor]] = None,
1095 | inputs_embeds: Optional[torch.FloatTensor] = None,
1096 | labels: Optional[torch.LongTensor] = None,
1097 | use_cache: Optional[bool] = None,
1098 | output_attentions: Optional[bool] = None,
1099 | output_hidden_states: Optional[bool] = None,
1100 | return_dict: Optional[bool] = None,
1101 | ) -> Union[Tuple, TokenClassifierOutput]:
1102 | r"""
1103 |         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1104 |             Labels for computing the token classification loss. Indices should be in `[0, ...,
1105 |             config.num_labels - 1]`. A cross-entropy classification loss is computed over the
1106 |             `config.num_labels` classes for each token.
1107 | """
1108 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1109 |
1110 | outputs = self.model(
1111 | input_ids,
1112 | attention_mask=attention_mask,
1113 | position_ids=position_ids,
1114 | past_key_values=past_key_values,
1115 | inputs_embeds=inputs_embeds,
1116 | use_cache=use_cache,
1117 | output_attentions=output_attentions,
1118 | output_hidden_states=output_hidden_states,
1119 | return_dict=return_dict,
1120 | )
1121 | sequence_output = outputs[0]
1122 | sequence_output = self.dropout(sequence_output)
1123 | logits = self.score(sequence_output)
1124 |
1125 | loss = None
1126 | if labels is not None:
1127 | loss = self.loss_function(logits, labels, self.config)
1128 |
1129 | if not return_dict:
1130 | output = (logits,) + outputs[2:]
1131 | return ((loss,) + output) if loss is not None else output
1132 |
1133 | return TokenClassifierOutput(
1134 | loss=loss,
1135 | logits=logits,
1136 | hidden_states=outputs.hidden_states,
1137 | attentions=outputs.attentions,
1138 | )
1139 |
1140 |
1141 | @add_start_docstrings(
1142 | """
1143 | The Qwen2 Model transformer with a span classification head on top for extractive question-answering tasks like
1144 | SQuAD (a linear layer on top of the hidden-states output to compute `span start logits` and `span end logits`).
1145 | """,
1146 | QWEN2_START_DOCSTRING,
1147 | )
1148 | class Qwen2ForQuestionAnswering(Qwen2PreTrainedModel):
1149 | base_model_prefix = "transformer"
1150 |
1151 | def __init__(self, config):
1152 | super().__init__(config)
1153 | self.transformer = Qwen2Model(config)
1154 | self.qa_outputs = nn.Linear(config.hidden_size, 2)
1155 |
1156 | # Initialize weights and apply final processing
1157 | self.post_init()
1158 |
1159 | def get_input_embeddings(self):
1160 | return self.transformer.embed_tokens
1161 |
1162 | def set_input_embeddings(self, value):
1163 | self.transformer.embed_tokens = value
1164 |
1165 | @add_start_docstrings_to_model_forward(QWEN2_INPUTS_DOCSTRING)
1166 | def forward(
1167 | self,
1168 | input_ids: Optional[torch.LongTensor] = None,
1169 | attention_mask: Optional[torch.FloatTensor] = None,
1170 | position_ids: Optional[torch.LongTensor] = None,
1171 | past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
1172 | inputs_embeds: Optional[torch.FloatTensor] = None,
1173 | start_positions: Optional[torch.LongTensor] = None,
1174 | end_positions: Optional[torch.LongTensor] = None,
1175 | output_attentions: Optional[bool] = None,
1176 | output_hidden_states: Optional[bool] = None,
1177 | return_dict: Optional[bool] = None,
1178 | **kwargs,
1179 | ) -> Union[Tuple, QuestionAnsweringModelOutput]:
1180 | r"""
1181 | start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1182 | Labels for position (index) of the start of the labelled span for computing the token classification loss.
1183 |             Positions are clamped to the length of the sequence (`sequence_length`). Positions outside of the sequence
1184 | are not taken into account for computing the loss.
1185 | end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1186 | Labels for position (index) of the end of the labelled span for computing the token classification loss.
1187 |             Positions are clamped to the length of the sequence (`sequence_length`). Positions outside of the sequence
1188 | are not taken into account for computing the loss.
1189 | """
1190 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1191 |
1192 | outputs = self.transformer(
1193 | input_ids,
1194 | attention_mask=attention_mask,
1195 | position_ids=position_ids,
1196 | past_key_values=past_key_values,
1197 | inputs_embeds=inputs_embeds,
1198 | output_attentions=output_attentions,
1199 | output_hidden_states=output_hidden_states,
1200 | return_dict=return_dict,
1201 | )
1202 |
1203 | sequence_output = outputs[0]
1204 |
1205 | logits = self.qa_outputs(sequence_output)
1206 | start_logits, end_logits = logits.split(1, dim=-1)
1207 | start_logits = start_logits.squeeze(-1).contiguous()
1208 | end_logits = end_logits.squeeze(-1).contiguous()
1209 |
1210 | loss = None
1211 | if start_positions is not None and end_positions is not None:
1212 | loss = self.loss_function(start_logits, end_logits, start_positions, end_positions, **kwargs)
1213 |
1214 | if not return_dict:
1215 | output = (start_logits, end_logits) + outputs[2:]
1216 | return ((loss,) + output) if loss is not None else output
1217 |
1218 | return QuestionAnsweringModelOutput(
1219 | loss=loss,
1220 | start_logits=start_logits,
1221 | end_logits=end_logits,
1222 | hidden_states=outputs.hidden_states,
1223 | attentions=outputs.attentions,
1224 | )
1225 |
--------------------------------------------------------------------------------
/parametric_fit.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from scipy.optimize import minimize
3 | from sklearn.linear_model import LinearRegression
4 | import matplotlib.pyplot as plt
5 | from sklearn.metrics import r2_score
6 | import json
7 | import os
8 | import pandas as pd
9 |
10 | def parametric_fit(param_list, p_list, loss_list):
11 | param_list = np.asarray(param_list).reshape((-1, ))
12 | loss_list = np.asarray(loss_list).reshape((-1, ))
13 | p_list = np.asarray(p_list).reshape((-1, ))
14 |
15 | def huber_loss(y_true, y_pred, delta=0.001):
16 | error = y_true - y_pred
17 | is_small_error = np.abs(error) <= delta
18 | squared_loss = np.square(error) / 2
19 | linear_loss = delta * (np.abs(error) - delta / 2)
20 | return np.where(is_small_error, squared_loss, linear_loss).sum()
21 |
22 | def pred_loss(params):
23 | E, A, alpha, k = params
24 | return E + (A * 1e9 / (param_list * (np.log(p_list) * k + 1))) ** alpha
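   |     # Parametric form being fitted: L = E + (A / N_eff)^alpha with N_eff = N * (1 + k * ln P),
   |     # i.e. running P parallel streams is modeled as multiplying the non-embedding parameter count N
   |     # by (1 + k * ln P); the 1e9 factor expresses A in units of billions of parameters.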
25 |
26 | def objective_function(params):
27 | pred = pred_loss(params)
28 | return huber_loss(np.log(loss_list), np.log(pred))
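   |     # The objective is a Huber loss on log-space residuals (log of true vs. predicted loss),
   |     # in the spirit of the Chinchilla fitting recipe; the grid below restarts L-BFGS-B from
   |     # several initializations and keeps the best optimum found.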
29 |
30 | best_param = None
31 | best_func = 1000000
32 |     for log_E in [-1, -0.5, 0]:
33 |         for log_A in [-4, -2, 0, 2, 4]:
34 |             for alpha in [0, 0.5, 1, 1.5, 2]:
35 |                 for k in [0.2, 0.4, 0.6, 0.8]:
36 |                     initial_params = [np.exp(log_E), np.exp(log_A), alpha, k]
37 | bounds = [(1e-8, None), (1e-8, None), (1e-8, None), (1e-8, None)]
38 | result = minimize(objective_function, initial_params, method='L-BFGS-B', bounds=bounds)
39 | if result.fun < best_func:
40 | best_param = result.x
41 | best_func = result.fun
42 | print(f"{result = }")
43 | print(f"{best_param = }")
44 | print(f"{best_func = }")
45 |
46 | pred_key = "$\\mathcal L_{\\text{pred}}$"
47 | true_key = "$\\mathcal L_{\\text{true}}$"
48 | df = pd.DataFrame({
49 | "$P$": p_list,
50 | "Parameters (Non-Embedding)": param_list,
51 | pred_key: pred_loss(best_param),
52 | true_key: loss_list,
53 | "Error": pred_loss(best_param) - loss_list
54 | })
55 | df['Parameters (Non-Embedding)'] = df['Parameters (Non-Embedding)'].apply(lambda x: f"{x:,}")
56 | r2 = r2_score(df[true_key].to_numpy().reshape(-1, 1), df[pred_key].to_numpy().reshape(-1, 1))
57 |
58 | print(df.to_latex(float_format=lambda x: f"{x:.4f}", index=False, column_format='rrrrr'))
59 | print(f"{r2 = }")
60 |
61 |
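   | # Running this file fits the law separately to the Stack-V2 (Python) and Pile losses below and
   | # prints the fitted parameters (E, A, alpha, k), a LaTeX table of predicted vs. true losses,
   | # and the R^2 of the fit.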
62 | if __name__ == "__main__":
63 |
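   |     # Fitting data: each row corresponds to P in {1, 2, 4, 8} and each column to one of the six
   |     # model sizes (~0.5B to ~4.4B non-embedding parameters). The parameter counts grow slightly
   |     # with P, presumably reflecting the extra prefix and aggregation parameters added by ParScale.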
64 | params = [
65 | [535813376, 693753856, 1088376320, 1571472384, 2774773760, 4353203200],
66 | [538195842, 696738818, 1092762882, 1577522690, 2784937986, 4368529922],
67 | [540577412, 699722756, 1097148164, 1583571460, 2795100164, 4383854084],
68 | [545340552, 705690632, 1105918728, 1595669000, 2815424520, 4414502408],
69 | ]
70 |
71 | stack_loss = [
72 | [1.1722, 1.1496, 1.1131, 1.0817, 1.0451, 1.0213], # 1.0006], # P1
73 | [1.1507, 1.1262, 1.094, 1.0623, 1.0244, 1.0025], # P2
74 | [1.1354, 1.1124, 1.0808, 1.049, 1.0126, 0.9906], # P4
75 | [1.1231, 1.0997, 1.0688, 1.0383, 1.0016, 0.9794], # P8
76 | ]
77 |
78 | pile_loss = [
79 | [2.1113, 2.0671, 2.0027, 1.9539, 1.8876, 1.8451], # P1
80 | [2.0772, 2.0363, 1.973, 1.9266, 1.861, 1.8137], # P2
81 | [2.0544, 2.0128, 1.9509, 1.904, 1.8394, 1.7938], # P4
82 | [2.0364, 1.9933, 1.9318, 1.8856, 1.8218, 1.7772], # P8
83 | ]
84 |
85 | p = [
86 | [1] * 6,
87 | [2] * 6,
88 | [4] * 6,
89 | [8] * 6,
90 | ]
91 |
92 | print("=" * 10 + " Stack-V2 Python " + "=" * 10)
93 | parametric_fit(params, p, stack_loss)
94 | print("=" * 10 + " Pile " + "=" * 10)
95 | parametric_fit(params, p, pile_loss)
--------------------------------------------------------------------------------