├── README.md
├── resources
│   ├── CodeFuse-AI.png
│   └── result.png
└── utils
    └── vllm_codefuse_cge_large.py
/README.md:
--------------------------------------------------------------------------------
1 | ## CodeFuse-CGE
2 |
3 |
4 |
5 |
6 | In this project, we introduce CodeFuse-CGE (Code General Embedding), which stands out on text2code tasks thanks to its powerful ability to capture the semantic relationship between text and code.
7 | This model has the following notable features:
8 | ● Instruction tuning is enabled on both the query side and the code-snippet side.
9 | ● Sentence-level and code-level representations are obtained through a cross-attention pooling module (see the sketch below).
10 | ● The model uses a smaller embedding dimension without significant degradation in performance.
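As a rough illustration of the pooling bullet above, here is a minimal sketch of cross-attention pooling with a single learned query. It is our own simplified example, not the released implementation (the actual module is the `PMA` class in `utils/vllm_codefuse_cge_large.py`); the class and argument names below are made up for illustration.
```
import torch
import torch.nn as nn

class AttentionPooling(nn.Module):
    """Compress (batch, seq, hidden) token states into one embedding per input
    by letting a single learned query attend over all tokens."""
    def __init__(self, hidden_dim, embed_dim):
        super().__init__()
        self.query = nn.Parameter(torch.randn(1, 1, embed_dim))
        self.k_proj = nn.Linear(hidden_dim, embed_dim)
        self.v_proj = nn.Linear(hidden_dim, embed_dim)

    def forward(self, hidden_states, attention_mask):
        k = self.k_proj(hidden_states)                        # (bs, seq, d)
        v = self.v_proj(hidden_states)                        # (bs, seq, d)
        q = self.query.expand(hidden_states.size(0), -1, -1)  # (bs, 1, d)
        scores = q @ k.transpose(1, 2) / k.size(-1) ** 0.5    # (bs, 1, seq)
        scores = scores.masked_fill(attention_mask.unsqueeze(1) == 0, -1e12)
        return (torch.softmax(scores, dim=-1) @ v).squeeze(1) # (bs, d)
```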
11 |
12 | CodeFuse-CGE-Large Model Configuration
13 | Hugging Face: [codefuse-ai/CodeFuse-CGE-Large](https://huggingface.co/codefuse-ai/CodeFuse-CGE-Large)
14 | Base Model: CodeQwen1.5-7B-Chat
15 | Model Size: 7B
16 | Embedding Dimension: 1024
17 | Hidden Layers: 32
18 |
19 | Requirements
20 | ```
21 | flash_attn==2.4.2
22 | torch==2.1.0
23 | accelerate==0.28.0
24 | transformers==4.39.2
25 | vllm==0.5.3
26 | ```
27 |
28 |
29 | CodeFuse-CGE-Small Model Configuration
30 | Hugging Face: [codefuse-ai/CodeFuse-CGE-Small](https://huggingface.co/codefuse-ai/CodeFuse-CGE-Small)
31 | Base Model: Phi-3.5-mini-instruct
32 | Model Size: 3.8B
33 | Embedding Dimension: 1024
34 | Hidden Layers: 32
35 |
36 | Requirements
37 | ```
38 | flash_attn==2.4.2
39 | torch==2.1.0
40 | accelerate==0.28.0
41 | transformers>=4.43.0
42 | ```
43 |
44 |
45 | ## Benchmark Performance
46 | We use the MRR metric to evaluate text2code retrieval performance on AdvTest, CosQA, and CSN.
47 |
48 | 
49 |
50 | ## How to Use
51 |
52 | First, download the model files from Hugging Face.
53 |
54 | ### Transformers
55 | ```
56 | import torch
57 | from transformers import AutoTokenizer, AutoModel
58 |
59 | model_name_or_path = "CodeFuse-CGE-Large"
60 | model = AutoModel.from_pretrained(model_name_or_path, trust_remote_code=True)
61 | tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, trust_remote_code=True, truncation_side='right', padding_side='right')
62 | if torch.cuda.is_available():
63 |     device = 'cuda'
64 | else:
65 |     device = 'cpu'
66 | model.to(device)
67 |
68 | prefix_dict = {'python':{'query':'Retrieve the Python code that solves the following query:', 'passage':'Python code:'},
69 | 'java':{'query':'Retrieve the Java code that solves the following query:', 'passage':'Java code:'},
70 | 'go':{'query':'Retrieve the Go code that solves the following query:', 'passage':'Go code:'},
71 | 'c++':{'query':'Retrieve the C++ code that solves the following query:', 'passage':'C++ code:'},
72 | 'javascript':{'query':'Retrieve the Javascript code that solves the following query:', 'passage':'Javascript code:'},
73 | 'php':{'query':'Retrieve the PHP code that solves the following query:', 'passage':'PHP code:'},
74 | 'ruby':{'query':'Retrieve the Ruby code that solves the following query:', 'passage':'Ruby code:'},
75 | 'default':{'query':'Retrieve the code that solves the following query:', 'passage':'Code:'}
76 | }
77 |
78 | text = ["Writes a Boolean to the stream.",
79 | "def writeBoolean(self, n): t = TYPE_BOOL_TRUE if n is False: t = TYPE_BOOL_FALSE self.stream.write(t)"]
80 | text[0] += prefix_dict['python']['query']
81 | text[1] += prefix_dict['python']['passage']
82 | embed = model.encode(tokenizer, text)
83 | score = embed[0] @ embed[1].T
84 | print("score", score)
85 | ```
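Retrieval over several candidates follows the same pattern: encode the query and all candidate snippets, then rank candidates by their dot product with the query embedding. Below is a small sketch reusing `model`, `tokenizer`, and `prefix_dict` from the example above; the candidate strings are toy placeholders.
```
import numpy as np

query = "Writes a Boolean to the stream." + prefix_dict['python']['query']
candidates = [  # toy snippets for illustration
    "def writeBoolean(self, n): ...",
    "def read_int(self): ...",
]
texts = [query] + [c + prefix_dict['python']['passage'] for c in candidates]
embeds = np.asarray(model.encode(tokenizer, texts))  # one embedding per input text
scores = embeds[1:] @ embeds[0]
ranking = np.argsort(-scores)                        # best-matching candidate first
print("ranked candidate indices:", ranking.tolist())
```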
86 |
87 | ### vLLM
88 | We have also adapted the model for vLLM to reduce latency during deployment.
89 | ```
90 | from vllm import LLM, ModelRegistry
91 | from utils.vllm_codefuse_cge_large import CodeFuse_CGE_Large
94 |
95 | def always_true_is_embedding_model(model_arch: str) -> bool:
96 |     return True
97 | ModelRegistry.is_embedding_model = always_true_is_embedding_model
98 | ModelRegistry.register_model("CodeFuse_CGE_Large", CodeFuse_CGE_Large)
99 |
100 |
101 | model_name_or_path = "CodeFuse-CGE-Large"
102 | model = LLM(model=model_name_or_path, trust_remote_code=True, enforce_eager=True, enable_chunked_prefill=False)
103 | prefix_dict = {'python':{'query':'Retrieve the Python code that solves the following query:', 'passage':'Python code:'},
104 | 'java':{'query':'Retrieve the Java code that solves the following query:', 'passage':'Java code:'},
105 | 'go':{'query':'Retrieve the Go code that solves the following query:', 'passage':'Go code:'},
106 | 'c++':{'query':'Retrieve the C++ code that solves the following query:', 'passage':'C++ code:'},
107 | 'javascript':{'query':'Retrieve the Javascript code that solves the following query:', 'passage':'Javascript code:'},
108 | 'php':{'query':'Retrieve the PHP code that solves the following query:', 'passage':'PHP code:'},
109 | 'ruby':{'query':'Retrieve the Ruby code that solves the following query:', 'passage':'Ruby code:'},
110 | 'default':{'query':'Retrieve the code that solves the following query:', 'passage':'Code:'}
111 | }
112 |
113 | text = ["Return the best fit based on rsquared",
114 | "def find_best_rsquared ( list_of_fits ) : res = sorted ( list_of_fits , key = lambda x : x . rsquared ) return res [ - 1 ]"]
115 | text[0] += prefix_dict['python']['query']
116 | text[1] += prefix_dict['python']['passage']
117 | embed_0 = model.encode([text[0]])[0].outputs.embedding
118 | embed_1 = model.encode([text[1]])[0].outputs.embedding
119 | ```
120 | Note:
121 | 1. With the vLLM adaptation, the model only accepts a batch size of 1 per request; larger batches result in an array overflow error (a simple workaround is sketched below).
122 | 2. Only CodeFuse-CGE-Large has been adapted so far; support for CodeFuse-CGE-Small will be available soon.
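Given note 1, a simple workaround is to send one request per text and compare the resulting vectors yourself. The sketch below continues from the vLLM example above; the helper name `encode_one_by_one` is ours, and the explicit normalization is only defensive since the pooler in `utils/vllm_codefuse_cge_large.py` already L2-normalizes its outputs.
```
import numpy as np

def encode_one_by_one(llm, texts):
    # Work around the batch-size-1 restriction by issuing one request per text.
    vectors = []
    for t in texts:
        emb = llm.encode([t])[0].outputs.embedding
        vectors.append(np.asarray(emb, dtype=np.float32))
    return vectors

embed_0, embed_1 = encode_one_by_one(model, text)
embed_0 /= np.linalg.norm(embed_0)
embed_1 /= np.linalg.norm(embed_1)
print("score", float(embed_0 @ embed_1))
```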
123 |
124 | ## Contact us
125 | Email:
126 |
127 | 
128 |
129 |
130 |
131 | ## Acknowledgement
132 | Thanks to the authors of the open-source datasets, including CSN, AdvTest, and CoSQA.
133 |
134 |
--------------------------------------------------------------------------------
/resources/CodeFuse-AI.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/codefuse-ai/CodeFuse-CGE/2e34e061856de876733c4e486d13c1aee538596d/resources/CodeFuse-AI.png
--------------------------------------------------------------------------------
/resources/result.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/codefuse-ai/CodeFuse-CGE/2e34e061856de876733c4e486d13c1aee538596d/resources/result.png
--------------------------------------------------------------------------------
/utils/vllm_codefuse_cge_large.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Adapted from
3 | # https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/qwen2/modeling_qwen2.py
4 | # Copyright 2024 The Qwen team.
5 | # Copyright 2023 The vLLM team.
6 | # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
7 | #
8 | # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
9 | # and OPT implementations in this library. It has been modified from its
10 | # original forms to accommodate minor architectural differences compared
11 | # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
12 | #
13 | # Licensed under the Apache License, Version 2.0 (the "License");
14 | # you may not use this file except in compliance with the License.
15 | # You may obtain a copy of the License at
16 | #
17 | # http://www.apache.org/licenses/LICENSE-2.0
18 | #
19 | # Unless required by applicable law or agreed to in writing, software
20 | # distributed under the License is distributed on an "AS IS" BASIS,
21 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
22 | # See the License for the specific language governing permissions and
23 | # limitations under the License.
24 | """Inference-only Qwen2 model compatible with HuggingFace weights."""
25 | import logging
26 | import math
27 | from typing import Iterable, List, Optional, Tuple
28 |
29 | import numpy as np
30 | import torch
31 | import torch.nn.functional as F
32 | from torch import nn
33 | from tqdm import trange
34 | from transformers import Qwen2Config
35 |
36 | from vllm.attention import Attention, AttentionMetadata
37 | from vllm.config import CacheConfig, LoRAConfig
38 | from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
39 | from vllm.model_executor.layers.activation import SiluAndMul
40 | from vllm.model_executor.layers.layernorm import RMSNorm
41 | from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
42 |                                                QKVParallelLinear,
43 |                                                RowParallelLinear)
44 | from vllm.model_executor.layers.logits_processor import LogitsProcessor
45 | from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
46 | from vllm.model_executor.layers.rotary_embedding import get_rope
47 | from vllm.model_executor.layers.sampler import Sampler
48 | from vllm.model_executor.layers.vocab_parallel_embedding import (ParallelLMHead,
49 |                                                                  VocabParallelEmbedding)
50 | from vllm.model_executor.model_loader.weight_utils import (default_weight_loader,
51 |                                                             maybe_remap_kv_scale_name)
52 | from vllm.model_executor.models.interfaces import SupportsLoRA
53 | from vllm.model_executor.models.utils import is_pp_missing_parameter, make_layers
54 | from vllm.model_executor.pooling_metadata import PoolingMetadata
55 | from vllm.model_executor.sampling_metadata import SamplingMetadata
56 | from vllm.sequence import (EmbeddingSequenceGroupOutput, IntermediateTensors,
57 |                            PoolerOutput, SamplerOutput)
58 |
59 | logger = logging.getLogger(__name__)
70 |
71 | class Qwen2MLP(nn.Module):
72 |
73 | def __init__(
74 | self,
75 | hidden_size: int,
76 | intermediate_size: int,
77 | hidden_act: str,
78 | quant_config: Optional[QuantizationConfig] = None,
79 | ) -> None:
80 | super().__init__()
81 | self.gate_up_proj = MergedColumnParallelLinear(
82 | hidden_size, [intermediate_size] * 2,
83 | bias=False,
84 | quant_config=quant_config)
85 | self.down_proj = RowParallelLinear(intermediate_size,
86 | hidden_size,
87 | bias=False,
88 | quant_config=quant_config)
89 | if hidden_act != "silu":
90 | raise ValueError(f"Unsupported activation: {hidden_act}. "
91 | "Only silu is supported for now.")
92 | self.act_fn = SiluAndMul()
93 |
94 | def forward(self, x):
95 | gate_up, _ = self.gate_up_proj(x)
96 | x = self.act_fn(gate_up)
97 | x, _ = self.down_proj(x)
98 | return x
99 |
100 |
101 | class Qwen2Attention(nn.Module):
102 |
103 | def __init__(self,
104 | hidden_size: int,
105 | num_heads: int,
106 | num_kv_heads: int,
107 | max_position: int = 4096 * 32,
108 | rope_theta: float = 10000,
109 | cache_config: Optional[CacheConfig] = None,
110 | quant_config: Optional[QuantizationConfig] = None,
111 | rope_scaling: Optional[Tuple] = None) -> None:
112 | super().__init__()
113 | self.hidden_size = hidden_size
114 | tp_size = get_tensor_model_parallel_world_size()
115 | self.total_num_heads = num_heads
116 | assert self.total_num_heads % tp_size == 0
117 | self.num_heads = self.total_num_heads // tp_size
118 | self.total_num_kv_heads = num_kv_heads
119 | if self.total_num_kv_heads >= tp_size:
120 | # Number of KV heads is greater than TP size, so we partition
121 | # the KV heads across multiple tensor parallel GPUs.
122 | assert self.total_num_kv_heads % tp_size == 0
123 | else:
124 | # Number of KV heads is less than TP size, so we replicate
125 | # the KV heads across multiple tensor parallel GPUs.
126 | assert tp_size % self.total_num_kv_heads == 0
127 | self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
128 | self.head_dim = hidden_size // self.total_num_heads
129 | self.q_size = self.num_heads * self.head_dim
130 | self.kv_size = self.num_kv_heads * self.head_dim
131 | self.scaling = self.head_dim**-0.5
132 | self.rope_theta = rope_theta
133 |
134 | self.qkv_proj = QKVParallelLinear(
135 | hidden_size,
136 | self.head_dim,
137 | self.total_num_heads,
138 | self.total_num_kv_heads,
139 | bias=True,
140 | quant_config=quant_config,
141 | )
142 | self.o_proj = RowParallelLinear(
143 | self.total_num_heads * self.head_dim,
144 | hidden_size,
145 | bias=False,
146 | quant_config=quant_config,
147 | )
148 |
149 | self.rotary_emb = get_rope(
150 | self.head_dim,
151 | rotary_dim=self.head_dim,
152 | max_position=max_position,
153 | base=self.rope_theta,
154 | rope_scaling=rope_scaling,
155 | )
156 | self.attn = Attention(self.num_heads,
157 | self.head_dim,
158 | self.scaling,
159 | num_kv_heads=self.num_kv_heads,
160 | cache_config=cache_config,
161 | quant_config=quant_config)
162 |
163 | def forward(
164 | self,
165 | positions: torch.Tensor,
166 | hidden_states: torch.Tensor,
167 | kv_cache: torch.Tensor,
168 | attn_metadata: AttentionMetadata,
169 | ) -> torch.Tensor:
170 | qkv, _ = self.qkv_proj(hidden_states)
171 | q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
172 | q, k = self.rotary_emb(positions, q, k)
173 | attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
174 | output, _ = self.o_proj(attn_output)
175 | return output
176 |
177 |
178 | class Qwen2DecoderLayer(nn.Module):
179 |
180 | def __init__(
181 | self,
182 | config: Qwen2Config,
183 | cache_config: Optional[CacheConfig] = None,
184 | quant_config: Optional[QuantizationConfig] = None,
185 | ) -> None:
186 | super().__init__()
187 | self.hidden_size = config.hidden_size
188 | # Requires transformers > 4.32.0
189 | rope_theta = getattr(config, "rope_theta", 1000000)
190 | rope_scaling = getattr(config, "rope_scaling", None)
191 | self.self_attn = Qwen2Attention(
192 | hidden_size=self.hidden_size,
193 | num_heads=config.num_attention_heads,
194 | max_position=config.max_position_embeddings,
195 | num_kv_heads=config.num_key_value_heads,
196 | rope_theta=rope_theta,
197 | cache_config=cache_config,
198 | quant_config=quant_config,
199 | rope_scaling=rope_scaling)
200 | self.mlp = Qwen2MLP(
201 | hidden_size=self.hidden_size,
202 | intermediate_size=config.intermediate_size,
203 | hidden_act=config.hidden_act,
204 | quant_config=quant_config,
205 | )
206 | self.input_layernorm = RMSNorm(config.hidden_size,
207 | eps=config.rms_norm_eps)
208 | self.post_attention_layernorm = RMSNorm(config.hidden_size,
209 | eps=config.rms_norm_eps)
210 |
211 | def forward(
212 | self,
213 | positions: torch.Tensor,
214 | hidden_states: torch.Tensor,
215 | kv_cache: torch.Tensor,
216 | attn_metadata: AttentionMetadata,
217 | residual: Optional[torch.Tensor],
218 | ) -> Tuple[torch.Tensor, torch.Tensor]:
219 | # Self Attention
220 | if residual is None:
221 | residual = hidden_states
222 | hidden_states = self.input_layernorm(hidden_states)
223 | else:
224 | hidden_states, residual = self.input_layernorm(
225 | hidden_states, residual)
226 | hidden_states = self.self_attn(
227 | positions=positions,
228 | hidden_states=hidden_states,
229 | kv_cache=kv_cache,
230 | attn_metadata=attn_metadata,
231 | )
232 |
233 | # Fully Connected
234 | hidden_states, residual = self.post_attention_layernorm(
235 | hidden_states, residual)
236 | hidden_states = self.mlp(hidden_states)
237 | return hidden_states, residual
238 |
239 |
240 | class Qwen2Model(nn.Module):
241 |
242 | def __init__(
243 | self,
244 | config: Qwen2Config,
245 | cache_config: Optional[CacheConfig] = None,
246 | quant_config: Optional[QuantizationConfig] = None,
247 | prefix: str = "",
248 | ) -> None:
249 | super().__init__()
250 | self.config = config
251 | self.padding_idx = config.pad_token_id
252 | self.vocab_size = config.vocab_size
253 |
254 | self.embed_tokens = VocabParallelEmbedding(
255 | config.vocab_size,
256 | config.hidden_size,
257 | quant_config=quant_config,
258 | )
259 | self.start_layer, self.end_layer, self.layers = make_layers(
260 | config.num_hidden_layers,
261 | lambda prefix: Qwen2DecoderLayer(config=config,
262 | cache_config=cache_config,
263 | quant_config=quant_config),
264 | prefix=f"{prefix}.layers",
265 | )
266 |
267 | self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
268 |
269 | def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
270 | return self.embed_tokens(input_ids)
271 |
272 | def forward(
273 | self,
274 | input_ids: torch.Tensor,
275 | positions: torch.Tensor,
276 | kv_caches: List[torch.Tensor],
277 | attn_metadata: AttentionMetadata,
278 | intermediate_tensors: Optional[IntermediateTensors] = None,
279 | inputs_embeds: Optional[torch.Tensor] = None,
280 | ) -> torch.Tensor:
281 | if get_pp_group().is_first_rank:
282 | if inputs_embeds is not None:
283 | hidden_states = inputs_embeds
284 | else:
285 | hidden_states = self.embed_tokens(input_ids)
286 | residual = None
287 | else:
288 | assert intermediate_tensors is not None
289 | hidden_states = intermediate_tensors["hidden_states"]
290 | residual = intermediate_tensors["residual"]
291 | for i in range(self.start_layer, self.end_layer):
292 | layer = self.layers[i]
293 | hidden_states, residual = layer(
294 | positions,
295 | hidden_states,
296 | kv_caches[i - self.start_layer],
297 | attn_metadata,
298 | residual,
299 | )
300 | if not get_pp_group().is_last_rank:
301 | return IntermediateTensors({
302 | "hidden_states": hidden_states,
303 | "residual": residual
304 | })
305 | hidden_states, _ = self.norm(hidden_states, residual)
306 | return hidden_states
307 |
308 |
309 | class Qwen2ForCausalLM(nn.Module, SupportsLoRA):
310 | packed_modules_mapping = {
311 | "qkv_proj": [
312 | "q_proj",
313 | "k_proj",
314 | "v_proj",
315 | ],
316 | "gate_up_proj": [
317 | "gate_proj",
318 | "up_proj",
319 | ],
320 | }
321 |
322 | # LoRA specific attributes
323 | supported_lora_modules = [
324 | "qkv_proj",
325 | "o_proj",
326 | "gate_up_proj",
327 | "down_proj",
328 | ]
329 | embedding_modules = {}
330 | embedding_padding_modules = []
331 |
332 | def __init__(
333 | self,
334 | config: Qwen2Config,
335 | cache_config: Optional[CacheConfig] = None,
336 | quant_config: Optional[QuantizationConfig] = None,
337 | lora_config: Optional[LoRAConfig] = None,
338 | ) -> None:
339 | # TODO (@robertgshaw2): see if this can be moved out
340 | if (cache_config.sliding_window is not None
341 | and hasattr(config, "max_window_layers")):
342 | raise ValueError("Sliding window for some but all layers is not "
343 | "supported. This model uses sliding window "
344 | "but `max_window_layers` = %s is less than "
345 | "`num_hidden_layers` = %s. Please open an issue "
346 | "to discuss this feature." % (
347 | config.max_window_layers,
348 | config.num_hidden_layers,
349 | ))
350 |
351 | super().__init__()
352 |
353 | self.config = config
354 | self.lora_config = lora_config
355 |
356 | self.quant_config = quant_config
357 | self.model = Qwen2Model(config, cache_config, quant_config)
358 |
359 | if config.tie_word_embeddings:
360 | self.lm_head = self.model.embed_tokens
361 | else:
362 | self.lm_head = ParallelLMHead(config.vocab_size,
363 | config.hidden_size,
364 | quant_config=quant_config)
365 |
366 | self.logits_processor = LogitsProcessor(config.vocab_size)
367 | self.sampler = Sampler()
368 |
369 | def forward(
370 | self,
371 | input_ids: torch.Tensor,
372 | positions: torch.Tensor,
373 | kv_caches: List[torch.Tensor],
374 | attn_metadata: AttentionMetadata,
375 | intermediate_tensors: Optional[IntermediateTensors] = None,
376 | ) -> torch.Tensor:
377 | hidden_states = self.model(input_ids, positions, kv_caches,
378 | attn_metadata, intermediate_tensors)
379 | return hidden_states
380 |
381 | def compute_logits(self, hidden_states: torch.Tensor,
382 | sampling_metadata: SamplingMetadata) -> torch.Tensor:
383 | logits = self.logits_processor(self.lm_head, hidden_states,
384 | sampling_metadata)
385 | return logits
386 |
387 | def make_empty_intermediate_tensors(
388 | self, batch_size: int, dtype: torch.dtype,
389 | device: torch.device) -> IntermediateTensors:
390 | return IntermediateTensors({
391 | "hidden_states":
392 | torch.zeros((batch_size, self.config.hidden_size),
393 | dtype=dtype,
394 | device=device),
395 | "residual":
396 | torch.zeros((batch_size, self.config.hidden_size),
397 | dtype=dtype,
398 | device=device),
399 | })
400 |
401 | def sample(
402 | self,
403 | logits: torch.Tensor,
404 | sampling_metadata: SamplingMetadata,
405 | ) -> Optional[SamplerOutput]:
406 | next_tokens = self.sampler(logits, sampling_metadata)
407 | return next_tokens
408 |
409 | def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
410 | stacked_params_mapping = [
411 | # (param_name, shard_name, shard_id)
412 | ("qkv_proj", "q_proj", "q"),
413 | ("qkv_proj", "k_proj", "k"),
414 | ("qkv_proj", "v_proj", "v"),
415 | ("gate_up_proj", "gate_proj", 0),
416 | ("gate_up_proj", "up_proj", 1),
417 | ]
418 | params_dict = dict(self.named_parameters(remove_duplicate=False))
419 | for name, loaded_weight in weights:
420 | if "rotary_emb.inv_freq" in name:
421 | continue
422 | if self.config.tie_word_embeddings and "lm_head.weight" in name:
423 | continue
424 | for (param_name, weight_name, shard_id) in stacked_params_mapping:
425 | if weight_name not in name:
426 | continue
427 | name = name.replace(weight_name, param_name)
428 | # Skip loading extra bias for GPTQ models.
429 | if name.endswith(".bias") and name not in params_dict:
430 | continue
431 | if is_pp_missing_parameter(name, self):
432 | continue
433 | param = params_dict[name]
434 | weight_loader = param.weight_loader
435 | weight_loader(param, loaded_weight, shard_id)
436 | break
437 | else:
438 | # Skip loading extra bias for GPTQ models.
439 | if name.endswith(".bias") and name not in params_dict:
440 | continue
441 | # Remapping the name of FP8 kv-scale.
442 | name = maybe_remap_kv_scale_name(name, params_dict)
443 | if name is None:
444 | continue
445 | if is_pp_missing_parameter(name, self):
446 | continue
447 | param = params_dict[name]
448 | weight_loader = getattr(param, "weight_loader",
449 | default_weight_loader)
450 | weight_loader(param, loaded_weight)
451 |
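# CodeFuse_CGE_Large wraps Qwen2ForCausalLM and exposes it as an embedding model for
# vLLM: forward() runs the decoder and compresses the final hidden states into a single
# vector with the PMA cross-attention pooling module defined further below, and pooler()
# L2-normalizes that vector before returning it as an EmbeddingSequenceGroupOutput.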
452 | class CodeFuse_CGE_Large(nn.Module, SupportsLoRA):
453 | packed_modules_mapping = {
454 | "qkv_proj": [
455 | "q_proj",
456 | "k_proj",
457 | "v_proj",
458 | ],
459 | "gate_up_proj": [
460 | "gate_proj",
461 | "up_proj",
462 | ],
463 | }
464 |
465 | # LoRA specific attributes
466 | supported_lora_modules = [
467 | "qkv_proj",
468 | "o_proj",
469 | "gate_up_proj",
470 | "down_proj",
471 | ]
472 | embedding_modules = {}
473 | embedding_padding_modules = []
474 |
475 | def __init__(
476 | self,
477 | config: Qwen2Config,
478 | cache_config: Optional[CacheConfig] = None,
479 | quant_config: Optional[QuantizationConfig] = None,
480 | lora_config: Optional[LoRAConfig] = None,
481 | ) -> None:
482 | # TODO (@robertgshaw2): see if this can be moved out
483 | if (cache_config.sliding_window is not None
484 | and hasattr(config, "max_window_layers")):
485 | raise ValueError("Sliding window for some but all layers is not "
486 | "supported. This model uses sliding window "
487 | "but `max_window_layers` = %s is less than "
488 | "`num_hidden_layers` = %s. Please open an issue "
489 | "to discuss this feature." % (
490 | config.max_window_layers,
491 | config.num_hidden_layers,
492 | ))
493 |
494 | super().__init__()
495 |
496 | self.config = config
497 | self.lora_config = lora_config
498 | self.quant_config = quant_config
499 | self.plm_model = Qwen2ForCausalLM(config, cache_config, quant_config)
500 | self.embedding_method = config.embedding_method
501 | self.inf_seq_length = config.inf_seq_length
502 | self.padding_side = config.padding_side
503 | self.keep_max_layer = config.keep_max_layer
504 | self.emb_dim = self.plm_model.model.embed_tokens.weight.size(1)
505 | self.num_heads = config.pma_num_heads
506 | self.ln = config.pma_ln
507 | self.norm = config.pma_norm
508 | self.pma_mode = config.pma_norm_mode
509 | self.compress_dim = config.compress_dim  # assumed to be present in the model config, mirroring the other pma_* fields read above
510 | self.mha_pma = PMA(self.emb_dim, self.compress_dim, self.num_heads, 1, ln=self.ln, pma_mode=self.pma_mode).to("cuda")
510 | if config.tie_word_embeddings:
511 | self.lm_head = self.plm_model.embed_tokens
512 | else:
513 | self.lm_head = ParallelLMHead(config.vocab_size,
514 | config.hidden_size,
515 | quant_config=quant_config)
516 |
517 | self.logits_processor = LogitsProcessor(config.vocab_size)
518 | self.sampler = Sampler()
519 | for param_tensor in self.mha_pma.state_dict():
520 | print(param_tensor, "\t", self.mha_pma.state_dict()[param_tensor])
521 |
522 | def forward(
523 | self,
524 | input_ids: torch.Tensor,
525 | positions: torch.Tensor,
526 | kv_caches: List[torch.Tensor],
527 | attn_metadata: AttentionMetadata,
528 | intermediate_tensors: Optional[IntermediateTensors] = None,
529 | ) -> torch.Tensor:
530 | hidden_states = self.plm_model(input_ids, positions, kv_caches,
531 | attn_metadata, intermediate_tensors)
532 |
533 | embedding = hidden_states.unsqueeze(0)
534 | res_embedding = self.pma_embedding(embedding, positions.unsqueeze(0))
535 | return res_embedding
536 |
537 | def pooler(
538 | self,
539 | hidden_states: torch.Tensor,
540 | pooling_metadata: PoolingMetadata,
541 | ) -> Optional[PoolerOutput]:
542 | hidden_states = nn.functional.normalize(hidden_states, p=2, dim=1)
543 | pooled_outputs = [
544 | EmbeddingSequenceGroupOutput(data.tolist()) for data in hidden_states
545 | ]
546 |
547 | return PoolerOutput(outputs=pooled_outputs)
548 |
549 | def compute_logits(self, hidden_states: torch.Tensor,
550 | sampling_metadata: SamplingMetadata) -> torch.Tensor:
551 | logits = self.logits_processor(self.lm_head, hidden_states,
552 | sampling_metadata)
553 | return logits
554 |
555 | def make_empty_intermediate_tensors(
556 | self, batch_size: int, dtype: torch.dtype,
557 | device: torch.device) -> IntermediateTensors:
558 | return IntermediateTensors({
559 | "hidden_states":
560 | torch.zeros((batch_size, self.config.hidden_size),
561 | dtype=dtype,
562 | device=device),
563 | "residual":
564 | torch.zeros((batch_size, self.config.hidden_size),
565 | dtype=dtype,
566 | device=device),
567 | })
568 |
569 | def sample(
570 | self,
571 | logits: torch.Tensor,
572 | sampling_metadata: SamplingMetadata,
573 | ) -> Optional[SamplerOutput]:
574 | next_tokens = self.sampler(logits, sampling_metadata)
575 | return next_tokens
576 |
577 | def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
578 | stacked_params_mapping = [
579 | # (param_name, shard_name, shard_id)
580 | ("qkv_proj", "q_proj", "q"),
581 | ("qkv_proj", "k_proj", "k"),
582 | ("qkv_proj", "v_proj", "v"),
583 | ("gate_up_proj", "gate_proj", 0),
584 | ("gate_up_proj", "up_proj", 1),
585 | ]
586 | params_dict = dict(self.named_parameters(remove_duplicate=False))
587 | for name, loaded_weight in weights:
588 | if "rotary_emb.inv_freq" in name:
589 | continue
590 | if self.config.tie_word_embeddings and "lm_head.weight" in name:
591 | continue
592 | for (param_name, weight_name, shard_id) in stacked_params_mapping:
593 | if weight_name not in name:
594 | continue
595 | name = name.replace(weight_name, param_name)
596 | # Skip loading extra bias for GPTQ models.
597 | if name.endswith(".bias") and name not in params_dict:
598 | continue
599 | if is_pp_missing_parameter(name, self):
600 | continue
601 | param = params_dict[name]
602 | weight_loader = param.weight_loader
603 | weight_loader(param, loaded_weight, shard_id)
604 | break
605 | else:
606 | # Skip loading extra bias for GPTQ models.
607 | if name.endswith(".bias") and name not in params_dict:
608 | continue
609 | # Remapping the name of FP8 kv-scale.
610 | name = maybe_remap_kv_scale_name(name, params_dict)
611 | if name is None:
612 | continue
613 | if is_pp_missing_parameter(name, self):
614 | continue
615 | param = params_dict[name]
616 | weight_loader = getattr(param, "weight_loader",
617 | default_weight_loader)
618 | weight_loader(param, loaded_weight)
619 | for param_tensor in self.mha_pma.state_dict():
620 | print(param_tensor, "\t", self.mha_pma.state_dict()[param_tensor])
621 |
622 | def last_embedding(self, A, index):
623 | bs, seq, emb = A.size()
624 | res = A[torch.arange(bs), index, :]
625 | return res
626 |
627 | def mean_embedding(self, A, mask):
628 | bs, seq, emb = A.size()
629 | res = (A * (mask.unsqueeze(-1))).sum(1) / (mask.sum(1).unsqueeze(-1))
630 | return res
631 |
632 | # A (bs, seq, emb_size), mask (bs, 1, seq)
633 | def weighted_embedding(self, A, mask):
634 | weights = (torch.arange(start=1, end=A.size(1) + 1).unsqueeze(0).unsqueeze(-1).expand(A.size()).float()).to(A.device)
635 | input_mask_expanded = (mask.squeeze(1).unsqueeze(-1).expand(A.size()).float()).to(A.device)
636 | sum_embedding = torch.sum(A * input_mask_expanded * weights, dim=1)
637 | sum_mask = torch.sum(input_mask_expanded * weights, dim=1)
638 | weighted_embedding = sum_embedding / sum_mask
639 |
640 | return weighted_embedding
641 |
642 | def pma_embedding(self, A, mask):
643 | res = self.mha_pma(A, mask).squeeze(1)
644 | return res
645 |
646 |
647 | def get_sentence_embedding(self, embedding_method, **inputs):
648 | outputs = self.plm_model(inputs['input_ids'], inputs['attention_mask'], output_hidden_states=True)
649 | if embedding_method == 'last':
650 | embedding = outputs.hidden_states[self.keep_max_layer]
651 | index = inputs['attention_mask'].sum(-1).long() - 1
652 | res_embedding = self.last_embedding(embedding, index)
653 | elif embedding_method == 'mean':
654 | embedding = outputs.hidden_states[self.keep_max_layer]
655 | res_embedding = self.mean_embedding(embedding, inputs['attention_mask'])
656 | elif embedding_method == 'weighted':
657 | embedding = outputs.hidden_states[self.keep_max_layer]
658 | res_embedding = self.weighted_embedding(embedding, inputs['attention_mask'])
659 | elif embedding_method == 'pma':
660 | embedding = outputs.hidden_states[self.keep_max_layer]
661 | attention_mask = inputs['attention_mask']
662 | res_embedding = self.pma_embedding(embedding, attention_mask)
663 | else:
664 | logger.debug('Error: no {} method to obtain embeddings'.format(embedding_method))
665 |
666 | if not self.norm:
667 | res_embedding = torch.nn.functional.normalize(res_embedding, p=2.0, dim=-1, eps=1e-12, out=None)
668 | return res_embedding
669 |
670 |
671 |
672 | def encode(self, tokenizer, sentences, batch_size=32, convert_to_numpy=True,
673 | convert_to_tensor=False, show_progress_bar=True, max_seq_length=None, **kwargs):
674 | if max_seq_length is None:
675 | max_seq_length = self.inf_seq_length
676 | input_is_string = False
677 | if isinstance(sentences, str) or not hasattr(sentences, "__len__"):
678 | sentences = [sentences]
679 | input_is_string = True
680 |
681 | all_embeddings = []
682 | length_sorted_idx = np.argsort([-len(s) for s in sentences])
683 | sentences_sorted = [sentences[idx] for idx in length_sorted_idx]  # re-order from longest to shortest
684 | with torch.no_grad():
685 | for start_index in trange(0, len(sentences), batch_size, desc="Batches", disable=not show_progress_bar):
686 | sentences_batch = sentences_sorted[start_index: start_index + batch_size]
687 | # Compute sentences embeddings
688 | with torch.no_grad():
689 | inputs = tokenizer(sentences_batch, padding=True, truncation=True, max_length=max_seq_length, add_special_tokens=False, return_tensors='pt').to(self.plm_model.device)
690 | embeddings = self.get_sentence_embedding(self.embedding_method, **inputs)
691 | embeddings = embeddings.detach()
692 | if convert_to_numpy:
693 | if embeddings.dtype == torch.bfloat16:
694 | embeddings = embeddings.cpu().to(torch.float32)
695 | else:
696 | embeddings = embeddings.cpu()
697 | all_embeddings.extend(embeddings)
698 | all_embeddings = [all_embeddings[idx] for idx in np.argsort(length_sorted_idx)]
699 | if convert_to_tensor:
700 | all_embeddings = torch.stack(all_embeddings)
701 | elif convert_to_numpy:
702 | all_embeddings = np.asarray([emb.numpy() for emb in all_embeddings])
703 |
704 | if input_is_string:
705 | all_embeddings = all_embeddings[0]
706 | return all_embeddings
707 |
708 |
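# The MAB_* blocks below implement masked multi-head cross-attention between a small set
# of query vectors and the token sequence, in three normalization variants: post-LayerNorm
# (MAB_POST), pre-LayerNorm (MAB_PRE_NORMAL), and a GPT-J-style parallel residual
# (MAB_PRE_GPTJ). The variant is selected through the `pma_norm_mode` config value via PMA.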
709 | class MAB_POST(nn.Module):
710 | def __init__(self, dim_Q, dim_K, dim_V, num_heads, ln=False):
711 | super(MAB_POST, self).__init__()
712 | self.dim_V = dim_V
713 | self.num_heads = num_heads
714 | self.fc_q = nn.Linear(dim_Q, dim_V)
715 | self.fc_k = nn.Linear(dim_K, dim_V)
716 | self.fc_v = nn.Linear(dim_K, dim_V)
717 |
718 | if ln:
719 | self.ln0 = nn.LayerNorm(dim_V)
720 | self.ln1 = nn.LayerNorm(dim_V)
721 | self.fc_o = nn.Linear(dim_V, dim_V)
722 | nn.init.xavier_uniform_(self.fc_q.weight)
723 | nn.init.xavier_uniform_(self.fc_k.weight)
724 | nn.init.xavier_uniform_(self.fc_v.weight)
725 | nn.init.xavier_uniform_(self.fc_o.weight)
726 |
727 | def forward(self, Q, K, pad_mask=None):
728 |
729 | Q_ = self.fc_q(Q)
730 | K_, V_ = self.fc_k(K), self.fc_v(K)
731 | dim_split = self.dim_V // self.num_heads
732 | Q_ = torch.cat(Q_.split(dim_split, 2), 0)
733 | K_ = torch.cat(K_.split(dim_split, 2), 0)
734 | V_ = torch.cat(V_.split(dim_split, 2), 0)
735 |
736 | pad_mask = pad_mask.unsqueeze(1).repeat(self.num_heads, Q.size(1), 1)
737 | score = Q_.bmm(K_.transpose(1,2))/math.sqrt(self.dim_V)
738 | score = score.masked_fill(pad_mask == 0, -1e12)
739 | A = torch.softmax(score, 2)
740 | A = A * pad_mask
741 | O = torch.cat(A.bmm(V_).split(Q.size(0), 0), 2)
742 | O = Q + O
743 | O = O if getattr(self, 'ln0', None) is None else self.ln0(O)
744 | O = O + F.relu(self.fc_o(O))
745 | O = O if getattr(self, 'ln1', None) is None else self.ln1(O)
746 | return O
747 |
748 |
749 | class MAB_PRE_NORMAL(nn.Module):
750 | def __init__(self, dim_Q, dim_K, dim_V, num_heads, ln=False):
751 | super(MAB_PRE_NORMAL, self).__init__()
752 | self.dim_V = dim_V
753 | self.num_heads = num_heads
754 | self.fc_q = nn.Linear(dim_Q, dim_V)
755 | self.fc_k = nn.Linear(dim_K, dim_V)
756 | self.fc_v = nn.Linear(dim_K, dim_V)
757 |
758 | if ln:
759 | self.ln_q = nn.LayerNorm(dim_V)
760 | self.ln_kv = nn.LayerNorm(dim_V)
761 | self.ln_o = nn.LayerNorm(dim_V)
762 | self.ln_final = nn.LayerNorm(dim_V)
763 |
764 | self.fc_o = nn.Linear(dim_V, dim_V)
765 | nn.init.xavier_uniform_(self.fc_q.weight)
766 | nn.init.xavier_uniform_(self.fc_k.weight)
767 | nn.init.xavier_uniform_(self.fc_v.weight)
768 | nn.init.xavier_uniform_(self.fc_o.weight)
769 |
770 | def forward(self, Q, K, pad_mask=None):
771 | Q_ = Q if getattr(self, 'ln_q', None) is None else self.ln_q(Q)
772 | K_ = K if getattr(self, 'ln_kv', None) is None else self.ln_kv(K)
773 | Q_ = self.fc_q(Q_)
774 | K_, V_ = self.fc_k(K_), self.fc_v(K_)
775 | dim_split = self.dim_V // self.num_heads
776 | Q_ = torch.cat(Q_.split(dim_split, 2), 0)
777 | K_ = torch.cat(K_.split(dim_split, 2), 0)
778 | V_ = torch.cat(V_.split(dim_split, 2), 0)
779 | pad_mask = pad_mask.unsqueeze(1).repeat(self.num_heads, Q.size(1), 1)
780 | score = Q_.bmm(K_.transpose(1,2))/math.sqrt(self.dim_V)
781 | score = score.masked_fill(pad_mask == 0, -1e12)
782 | A = torch.softmax(score, 2)
783 | A = A * pad_mask
784 | O = torch.cat(A.bmm(V_).split(Q.size(0), 0), 2)
785 | O = Q + O
786 | O_ = O if getattr(self, 'ln_o', None) is None else self.ln_o(O)
787 | O_ = O + F.relu(self.fc_o(O_))
788 | return O_ if getattr(self, 'ln_final', None) is None else self.ln_final(O_)
789 |
790 |
791 | class MAB_PRE_GPTJ(nn.Module):
792 | def __init__(self, dim_Q, dim_K, dim_V, num_heads, ln=False):
793 | super(MAB_PRE_GPTJ, self).__init__()
794 | self.dim_V = dim_V
795 | self.num_heads = num_heads
796 | self.fc_q = nn.Linear(dim_Q, dim_V)
797 | self.fc_k = nn.Linear(dim_K, dim_V)
798 | self.fc_v = nn.Linear(dim_K, dim_V)
799 | self.fc_o = nn.Linear(dim_V, dim_V)
800 |
801 | nn.init.xavier_uniform_(self.fc_q.weight)
802 | nn.init.xavier_uniform_(self.fc_k.weight)
803 | nn.init.xavier_uniform_(self.fc_v.weight)
804 | nn.init.xavier_uniform_(self.fc_o.weight)
805 | if ln:
806 | self.ln_q = nn.LayerNorm(dim_V)
807 | self.ln_kv = nn.LayerNorm(dim_V)
808 | self.ln_final = nn.LayerNorm(dim_V)
809 |
810 | def forward(self, Q, K, pad_mask=None):
811 | Q_ = Q if getattr(self, 'ln_q', None) is None else self.ln_q(Q)
812 | K_ = K if getattr(self, 'ln_kv', None) is None else self.ln_kv(K)
813 |
814 | Q1 = self.fc_q(Q_)
815 | K1, V1 = self.fc_k(K_), self.fc_v(K_)
816 | dim_split = self.dim_V // self.num_heads
817 |
818 | Q1 = torch.cat(Q1.split(dim_split, 2), 0)
819 | K1 = torch.cat(K1.split(dim_split, 2), 0)
820 | V1 = torch.cat(V1.split(dim_split, 2), 0)
821 |
822 |
823 | pad_mask = pad_mask.unsqueeze(1).repeat(self.num_heads, Q.size(1), 1)
824 | score = Q1.bmm(K1.transpose(1,2))/math.sqrt(self.dim_V)
825 | score = score.masked_fill(pad_mask == 0, -1e12)
826 | A = torch.softmax(score, 2)
827 | A = A * pad_mask
828 | O1 = torch.cat(A.bmm(V1).split(Q.size(0), 0), 2)
829 | O2 = F.relu(self.fc_o(Q_))
830 | O_final = Q + O1 + O2
831 | return O_final if getattr(self, 'ln_final', None) is None else self.ln_final(O_final)
832 |
833 |
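# PMA (Pooling by Multi-head Attention): a learned seed tensor S of shape
# (1, num_seeds, compress_dim) attends over the token states X, compressing
# (batch, seq, dim) into (batch, num_seeds, compress_dim). CodeFuse_CGE_Large
# instantiates it with num_seeds=1, yielding one embedding per input sequence.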
834 | class PMA(nn.Module):
835 | def __init__(self, dim, compress_dim, num_heads, num_seeds, ln=False, pma_mode=None):
836 | super(PMA, self).__init__()
837 | self.S = nn.Parameter(torch.Tensor(1, num_seeds, compress_dim))
838 | nn.init.xavier_uniform_(self.S)
839 | if pma_mode == 'post_normal':
840 | self.mab = MAB_POST(compress_dim, dim, compress_dim, num_heads, ln=ln)
841 | elif pma_mode == 'pre_normal':
842 | self.mab = MAB_PRE_NORMAL(compress_dim, dim, compress_dim, num_heads, ln=ln)
843 | elif pma_mode == 'pre_gptj':
844 | self.mab = MAB_PRE_GPTJ(compress_dim, dim, compress_dim, num_heads, ln=ln)
845 | else:
846 | raise ValueError(f"Error, the pma_mode {pma_mode} is not implemented !")
847 |
848 | def forward(self, X, pad_mask):
849 | if self.S.dtype != torch.bfloat16:
850 | X = X.float()
851 | return self.mab(self.S.repeat(X.size(0), 1, 1), X, pad_mask)
852 |
--------------------------------------------------------------------------------