├── README.md
├── resources
│   ├── CodeFuse-AI.png
│   └── result.png
└── utils
    └── vllm_codefuse_cge_large.py

/README.md:
--------------------------------------------------------------------------------
## CodeFuse-CGE


In this project, we introduce CodeFuse-CGE (Code General Embedding), which stands out on text2code tasks thanks to its strong ability to capture the semantic relationship between text and code.
The model has the following notable features:
● Instruction-tuning is enabled for both the query and the code snippet sides.
● Sentence-level and code-level representations are obtained through a cross-attention computation module.
● The embedding dimension is relatively small, with no significant degradation in performance.

CodeFuse-CGE-Large Model Configuration
Hugging Face: [codefuse-ai/CodeFuse-CGE-Large](https://huggingface.co/codefuse-ai/CodeFuse-CGE-Large)
Base Model: CodeQwen1.5-7B-Chat
Model Size: 7B
Embedding Dimension: 1024
Hidden Layers: 32

Requirements
```
flash_attn==2.4.2
torch==2.1.0
accelerate==0.28.0
transformers==4.39.2
vllm==0.5.3
```

CodeFuse-CGE-Small Model Configuration
Hugging Face: [codefuse-ai/CodeFuse-CGE-Small](https://huggingface.co/codefuse-ai/CodeFuse-CGE-Small)
Base Model: Phi-3.5-mini-instruct
Model Size: 3.8B
Embedding Dimension: 1024
Hidden Layers: 32

Requirements
```
flash_attn==2.4.2
torch==2.1.0
accelerate==0.28.0
transformers>=4.43.0
```

## Benchmark the Performance
We use the MRR (Mean Reciprocal Rank) metric to evaluate text2code retrieval on AdvTest, CosQA, and CSN:

![result](./resources/result.png)

## How to Use

Download the model files from Hugging Face first.

### Transformers
```
import torch
from transformers import AutoTokenizer, AutoModel

model_name_or_path = "CodeFuse-CGE-Large"
model = AutoModel.from_pretrained(model_name_or_path, trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, trust_remote_code=True, truncation_side='right', padding_side='right')

if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
model.to(device)

prefix_dict = {'python': {'query': 'Retrieve the Python code that solves the following query:', 'passage': 'Python code:'},
               'java': {'query': 'Retrieve the Java code that solves the following query:', 'passage': 'Java code:'},
               'go': {'query': 'Retrieve the Go code that solves the following query:', 'passage': 'Go code:'},
               'c++': {'query': 'Retrieve the C++ code that solves the following query:', 'passage': 'C++ code:'},
               'javascript': {'query': 'Retrieve the Javascript code that solves the following query:', 'passage': 'Javascript code:'},
               'php': {'query': 'Retrieve the PHP code that solves the following query:', 'passage': 'PHP code:'},
               'ruby': {'query': 'Retrieve the Ruby code that solves the following query:', 'passage': 'Ruby code:'},
               'default': {'query': 'Retrieve the code that solves the following query:', 'passage': 'Code:'}
               }

text = ["Writes a Boolean to the stream.",
        "def writeBoolean(self, n): t = TYPE_BOOL_TRUE if n is False: t = TYPE_BOOL_FALSE self.stream.write(t)"]
text[0] += prefix_dict['python']['query']
text[1] += prefix_dict['python']['passage']
embed = model.encode(tokenizer, text)
score = embed[0] @ embed[1].T
print("score", score)
```
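The pair above generalizes to ranking: one query embedding can be scored against several candidate snippets at once. A minimal sketch, reusing `model`, `tokenizer`, and `prefix_dict` from the example above and assuming `encode` returns one embedding per input text (the candidate snippets are illustrative):

```
query = "Writes a Boolean to the stream." + prefix_dict['python']['query']
candidates = [
    "def writeBoolean(self, n): t = TYPE_BOOL_TRUE if n is False: t = TYPE_BOOL_FALSE self.stream.write(t)",
    "def find_best_rsquared(list_of_fits): return sorted(list_of_fits, key=lambda x: x.rsquared)[-1]",
]

# One embedding per input text; row 0 is the query, the remaining rows are candidates.
embed = model.encode(tokenizer, [query] + [c + prefix_dict['python']['passage'] for c in candidates])
scores = embed[0] @ embed[1:].T
print("scores", scores)
print("best candidate index", int(scores.argmax()))
```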
### vLLM
We have also adapted vLLM to reduce latency during deployment.
```
from utils.vllm_codefuse_cge_large import CodeFuse_CGE_Large
from vllm.model_executor.models import ModelRegistry
from vllm import LLM

def always_true_is_embedding_model(model_arch: str) -> bool:
    return True
ModelRegistry.is_embedding_model = always_true_is_embedding_model
ModelRegistry.register_model("CodeFuse_CGE_Large", CodeFuse_CGE_Large)


model_name_or_path = "CodeFuse-CGE-Large"
model = LLM(model=model_name_or_path, trust_remote_code=True, enforce_eager=True, enable_chunked_prefill=False)
prefix_dict = {'python': {'query': 'Retrieve the Python code that solves the following query:', 'passage': 'Python code:'},
               'java': {'query': 'Retrieve the Java code that solves the following query:', 'passage': 'Java code:'},
               'go': {'query': 'Retrieve the Go code that solves the following query:', 'passage': 'Go code:'},
               'c++': {'query': 'Retrieve the C++ code that solves the following query:', 'passage': 'C++ code:'},
               'javascript': {'query': 'Retrieve the Javascript code that solves the following query:', 'passage': 'Javascript code:'},
               'php': {'query': 'Retrieve the PHP code that solves the following query:', 'passage': 'PHP code:'},
               'ruby': {'query': 'Retrieve the Ruby code that solves the following query:', 'passage': 'Ruby code:'},
               'default': {'query': 'Retrieve the code that solves the following query:', 'passage': 'Code:'}
               }

text = ["Return the best fit based on rsquared",
        "def find_best_rsquared ( list_of_fits ) : res = sorted ( list_of_fits , key = lambda x : x . rsquared ) return res [ - 1 ]"]
text[0] += prefix_dict['python']['query']
text[1] += prefix_dict['python']['passage']
embed_0 = model.encode([text[0]])[0].outputs.embedding
embed_1 = model.encode([text[1]])[0].outputs.embedding
```
Note:
1. With the vLLM adaptation, the model only accepts inputs with a batch size of 1; larger batches cause an array out-of-bounds error.
2. Only CodeFuse-CGE-Large has been adapted so far; support for CodeFuse-CGE-Small will be available soon.
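Given the batch-size-1 constraint in note 1, each text is encoded separately; the returned embeddings can then be scored just as in the Transformers example. A minimal sketch, reusing `model` and `text` from the snippet above (the dot-product scoring is illustrative):

```
import numpy as np

# Encode one input at a time (batch size 1) and collect the embedding vectors.
embeddings = [np.asarray(model.encode([t])[0].outputs.embedding) for t in text]

score = embeddings[0] @ embeddings[1]
print("score", score)
```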
## Contact us
Email:

![CodeFuse-AI](./resources/CodeFuse-AI.png)


## Acknowledgement
Thanks to the authors of the open-sourced datasets, including CSN, AdvTest, and CosQA.

--------------------------------------------------------------------------------
/resources/CodeFuse-AI.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/codefuse-ai/CodeFuse-CGE/2e34e061856de876733c4e486d13c1aee538596d/resources/CodeFuse-AI.png
--------------------------------------------------------------------------------
/resources/result.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/codefuse-ai/CodeFuse-CGE/2e34e061856de876733c4e486d13c1aee538596d/resources/result.png
--------------------------------------------------------------------------------
/utils/vllm_codefuse_cge_large.py:
--------------------------------------------------------------------------------
# coding=utf-8
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/qwen2/modeling_qwen2.py
# Copyright 2024 The Qwen team.
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Qwen2 model compatible with HuggingFace weights."""
from typing import Iterable, List, Optional, Tuple
from vllm.sequence import EmbeddingSequenceGroupOutput, PoolerOutput
import torch
from torch import nn
from transformers import Qwen2Config
from transformers import PretrainedConfig
from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig, LoRAConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
                                               QKVParallelLinear,
                                               RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
    ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import (
    default_weight_loader, maybe_remap_kv_scale_name)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors, SamplerOutput

from vllm.model_executor.models.interfaces import SupportsLoRA
from vllm.model_executor.models.utils import is_pp_missing_parameter, make_layers

from vllm.model_executor.layers.pooler import Pooler, PoolingType
from vllm.model_executor.models.llama import LlamaModel
from vllm.model_executor.pooling_metadata import PoolingMetadata
import math
import sys
import torch.nn.functional as F

# Needed by CodeFuse_CGE_Large.encode() further below; these were not imported in the
# original file (the logging-based `logger` is an assumption).
import logging
import numpy as np
from tqdm import trange

logger = logging.getLogger(__name__)


class Qwen2MLP(nn.Module):

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size, [intermediate_size] * 2,
            bias=False,
            quant_config=quant_config)
self.down_proj = RowParallelLinear(intermediate_size, 86 | hidden_size, 87 | bias=False, 88 | quant_config=quant_config) 89 | if hidden_act != "silu": 90 | raise ValueError(f"Unsupported activation: {hidden_act}. " 91 | "Only silu is supported for now.") 92 | self.act_fn = SiluAndMul() 93 | 94 | def forward(self, x): 95 | gate_up, _ = self.gate_up_proj(x) 96 | x = self.act_fn(gate_up) 97 | x, _ = self.down_proj(x) 98 | return x 99 | 100 | 101 | class Qwen2Attention(nn.Module): 102 | 103 | def __init__(self, 104 | hidden_size: int, 105 | num_heads: int, 106 | num_kv_heads: int, 107 | max_position: int = 4096 * 32, 108 | rope_theta: float = 10000, 109 | cache_config: Optional[CacheConfig] = None, 110 | quant_config: Optional[QuantizationConfig] = None, 111 | rope_scaling: Optional[Tuple] = None) -> None: 112 | super().__init__() 113 | self.hidden_size = hidden_size 114 | tp_size = get_tensor_model_parallel_world_size() 115 | self.total_num_heads = num_heads 116 | assert self.total_num_heads % tp_size == 0 117 | self.num_heads = self.total_num_heads // tp_size 118 | self.total_num_kv_heads = num_kv_heads 119 | if self.total_num_kv_heads >= tp_size: 120 | # Number of KV heads is greater than TP size, so we partition 121 | # the KV heads across multiple tensor parallel GPUs. 122 | assert self.total_num_kv_heads % tp_size == 0 123 | else: 124 | # Number of KV heads is less than TP size, so we replicate 125 | # the KV heads across multiple tensor parallel GPUs. 126 | assert tp_size % self.total_num_kv_heads == 0 127 | self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) 128 | self.head_dim = hidden_size // self.total_num_heads 129 | self.q_size = self.num_heads * self.head_dim 130 | self.kv_size = self.num_kv_heads * self.head_dim 131 | self.scaling = self.head_dim**-0.5 132 | self.rope_theta = rope_theta 133 | 134 | self.qkv_proj = QKVParallelLinear( 135 | hidden_size, 136 | self.head_dim, 137 | self.total_num_heads, 138 | self.total_num_kv_heads, 139 | bias=True, 140 | quant_config=quant_config, 141 | ) 142 | self.o_proj = RowParallelLinear( 143 | self.total_num_heads * self.head_dim, 144 | hidden_size, 145 | bias=False, 146 | quant_config=quant_config, 147 | ) 148 | 149 | self.rotary_emb = get_rope( 150 | self.head_dim, 151 | rotary_dim=self.head_dim, 152 | max_position=max_position, 153 | base=self.rope_theta, 154 | rope_scaling=rope_scaling, 155 | ) 156 | self.attn = Attention(self.num_heads, 157 | self.head_dim, 158 | self.scaling, 159 | num_kv_heads=self.num_kv_heads, 160 | cache_config=cache_config, 161 | quant_config=quant_config) 162 | 163 | def forward( 164 | self, 165 | positions: torch.Tensor, 166 | hidden_states: torch.Tensor, 167 | kv_cache: torch.Tensor, 168 | attn_metadata: AttentionMetadata, 169 | ) -> torch.Tensor: 170 | qkv, _ = self.qkv_proj(hidden_states) 171 | q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) 172 | q, k = self.rotary_emb(positions, q, k) 173 | attn_output = self.attn(q, k, v, kv_cache, attn_metadata) 174 | output, _ = self.o_proj(attn_output) 175 | return output 176 | 177 | 178 | class Qwen2DecoderLayer(nn.Module): 179 | 180 | def __init__( 181 | self, 182 | config: Qwen2Config, 183 | cache_config: Optional[CacheConfig] = None, 184 | quant_config: Optional[QuantizationConfig] = None, 185 | ) -> None: 186 | super().__init__() 187 | self.hidden_size = config.hidden_size 188 | # Requires transformers > 4.32.0 189 | rope_theta = getattr(config, "rope_theta", 1000000) 190 | rope_scaling = getattr(config, "rope_scaling", None) 191 | 
self.self_attn = Qwen2Attention( 192 | hidden_size=self.hidden_size, 193 | num_heads=config.num_attention_heads, 194 | max_position=config.max_position_embeddings, 195 | num_kv_heads=config.num_key_value_heads, 196 | rope_theta=rope_theta, 197 | cache_config=cache_config, 198 | quant_config=quant_config, 199 | rope_scaling=rope_scaling) 200 | self.mlp = Qwen2MLP( 201 | hidden_size=self.hidden_size, 202 | intermediate_size=config.intermediate_size, 203 | hidden_act=config.hidden_act, 204 | quant_config=quant_config, 205 | ) 206 | self.input_layernorm = RMSNorm(config.hidden_size, 207 | eps=config.rms_norm_eps) 208 | self.post_attention_layernorm = RMSNorm(config.hidden_size, 209 | eps=config.rms_norm_eps) 210 | 211 | def forward( 212 | self, 213 | positions: torch.Tensor, 214 | hidden_states: torch.Tensor, 215 | kv_cache: torch.Tensor, 216 | attn_metadata: AttentionMetadata, 217 | residual: Optional[torch.Tensor], 218 | ) -> Tuple[torch.Tensor, torch.Tensor]: 219 | # Self Attention 220 | if residual is None: 221 | residual = hidden_states 222 | hidden_states = self.input_layernorm(hidden_states) 223 | else: 224 | hidden_states, residual = self.input_layernorm( 225 | hidden_states, residual) 226 | hidden_states = self.self_attn( 227 | positions=positions, 228 | hidden_states=hidden_states, 229 | kv_cache=kv_cache, 230 | attn_metadata=attn_metadata, 231 | ) 232 | 233 | # Fully Connected 234 | hidden_states, residual = self.post_attention_layernorm( 235 | hidden_states, residual) 236 | hidden_states = self.mlp(hidden_states) 237 | return hidden_states, residual 238 | 239 | 240 | class Qwen2Model(nn.Module): 241 | 242 | def __init__( 243 | self, 244 | config: Qwen2Config, 245 | cache_config: Optional[CacheConfig] = None, 246 | quant_config: Optional[QuantizationConfig] = None, 247 | prefix: str = "", 248 | ) -> None: 249 | super().__init__() 250 | self.config = config 251 | self.padding_idx = config.pad_token_id 252 | self.vocab_size = config.vocab_size 253 | 254 | self.embed_tokens = VocabParallelEmbedding( 255 | config.vocab_size, 256 | config.hidden_size, 257 | quant_config=quant_config, 258 | ) 259 | self.start_layer, self.end_layer, self.layers = make_layers( 260 | config.num_hidden_layers, 261 | lambda prefix: Qwen2DecoderLayer(config=config, 262 | cache_config=cache_config, 263 | quant_config=quant_config), 264 | prefix=f"{prefix}.layers", 265 | ) 266 | 267 | self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) 268 | 269 | def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: 270 | return self.embed_tokens(input_ids) 271 | 272 | def forward( 273 | self, 274 | input_ids: torch.Tensor, 275 | positions: torch.Tensor, 276 | kv_caches: List[torch.Tensor], 277 | attn_metadata: AttentionMetadata, 278 | intermediate_tensors: Optional[IntermediateTensors] = None, 279 | inputs_embeds: Optional[torch.Tensor] = None, 280 | ) -> torch.Tensor: 281 | if get_pp_group().is_first_rank: 282 | if inputs_embeds is not None: 283 | hidden_states = inputs_embeds 284 | else: 285 | hidden_states = self.embed_tokens(input_ids) 286 | residual = None 287 | else: 288 | assert intermediate_tensors is not None 289 | hidden_states = intermediate_tensors["hidden_states"] 290 | residual = intermediate_tensors["residual"] 291 | for i in range(self.start_layer, self.end_layer): 292 | layer = self.layers[i] 293 | hidden_states, residual = layer( 294 | positions, 295 | hidden_states, 296 | kv_caches[i - self.start_layer], 297 | attn_metadata, 298 | residual, 299 | ) 300 | if not 
get_pp_group().is_last_rank: 301 | return IntermediateTensors({ 302 | "hidden_states": hidden_states, 303 | "residual": residual 304 | }) 305 | hidden_states, _ = self.norm(hidden_states, residual) 306 | return hidden_states 307 | 308 | 309 | class Qwen2ForCausalLM(nn.Module, SupportsLoRA): 310 | packed_modules_mapping = { 311 | "qkv_proj": [ 312 | "q_proj", 313 | "k_proj", 314 | "v_proj", 315 | ], 316 | "gate_up_proj": [ 317 | "gate_proj", 318 | "up_proj", 319 | ], 320 | } 321 | 322 | # LoRA specific attributes 323 | supported_lora_modules = [ 324 | "qkv_proj", 325 | "o_proj", 326 | "gate_up_proj", 327 | "down_proj", 328 | ] 329 | embedding_modules = {} 330 | embedding_padding_modules = [] 331 | 332 | def __init__( 333 | self, 334 | config: Qwen2Config, 335 | cache_config: Optional[CacheConfig] = None, 336 | quant_config: Optional[QuantizationConfig] = None, 337 | lora_config: Optional[LoRAConfig] = None, 338 | ) -> None: 339 | # TODO (@robertgshaw2): see if this can be moved out 340 | if (cache_config.sliding_window is not None 341 | and hasattr(config, "max_window_layers")): 342 | raise ValueError("Sliding window for some but all layers is not " 343 | "supported. This model uses sliding window " 344 | "but `max_window_layers` = %s is less than " 345 | "`num_hidden_layers` = %s. Please open an issue " 346 | "to discuss this feature." % ( 347 | config.max_window_layers, 348 | config.num_hidden_layers, 349 | )) 350 | 351 | super().__init__() 352 | 353 | self.config = config 354 | self.lora_config = lora_config 355 | 356 | self.quant_config = quant_config 357 | self.model = Qwen2Model(config, cache_config, quant_config) 358 | 359 | if config.tie_word_embeddings: 360 | self.lm_head = self.model.embed_tokens 361 | else: 362 | self.lm_head = ParallelLMHead(config.vocab_size, 363 | config.hidden_size, 364 | quant_config=quant_config) 365 | 366 | self.logits_processor = LogitsProcessor(config.vocab_size) 367 | self.sampler = Sampler() 368 | 369 | def forward( 370 | self, 371 | input_ids: torch.Tensor, 372 | positions: torch.Tensor, 373 | kv_caches: List[torch.Tensor], 374 | attn_metadata: AttentionMetadata, 375 | intermediate_tensors: Optional[IntermediateTensors] = None, 376 | ) -> torch.Tensor: 377 | hidden_states = self.model(input_ids, positions, kv_caches, 378 | attn_metadata, intermediate_tensors) 379 | return hidden_states 380 | 381 | def compute_logits(self, hidden_states: torch.Tensor, 382 | sampling_metadata: SamplingMetadata) -> torch.Tensor: 383 | logits = self.logits_processor(self.lm_head, hidden_states, 384 | sampling_metadata) 385 | return logits 386 | 387 | def make_empty_intermediate_tensors( 388 | self, batch_size: int, dtype: torch.dtype, 389 | device: torch.device) -> IntermediateTensors: 390 | return IntermediateTensors({ 391 | "hidden_states": 392 | torch.zeros((batch_size, self.config.hidden_size), 393 | dtype=dtype, 394 | device=device), 395 | "residual": 396 | torch.zeros((batch_size, self.config.hidden_size), 397 | dtype=dtype, 398 | device=device), 399 | }) 400 | 401 | def sample( 402 | self, 403 | logits: torch.Tensor, 404 | sampling_metadata: SamplingMetadata, 405 | ) -> Optional[SamplerOutput]: 406 | next_tokens = self.sampler(logits, sampling_metadata) 407 | return next_tokens 408 | 409 | def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): 410 | stacked_params_mapping = [ 411 | # (param_name, shard_name, shard_id) 412 | ("qkv_proj", "q_proj", "q"), 413 | ("qkv_proj", "k_proj", "k"), 414 | ("qkv_proj", "v_proj", "v"), 415 | ("gate_up_proj", 
"gate_proj", 0), 416 | ("gate_up_proj", "up_proj", 1), 417 | ] 418 | params_dict = dict(self.named_parameters(remove_duplicate=False)) 419 | for name, loaded_weight in weights: 420 | if "rotary_emb.inv_freq" in name: 421 | continue 422 | if self.config.tie_word_embeddings and "lm_head.weight" in name: 423 | continue 424 | for (param_name, weight_name, shard_id) in stacked_params_mapping: 425 | if weight_name not in name: 426 | continue 427 | name = name.replace(weight_name, param_name) 428 | # Skip loading extra bias for GPTQ models. 429 | if name.endswith(".bias") and name not in params_dict: 430 | continue 431 | if is_pp_missing_parameter(name, self): 432 | continue 433 | param = params_dict[name] 434 | weight_loader = param.weight_loader 435 | weight_loader(param, loaded_weight, shard_id) 436 | break 437 | else: 438 | # Skip loading extra bias for GPTQ models. 439 | if name.endswith(".bias") and name not in params_dict: 440 | continue 441 | # Remapping the name of FP8 kv-scale. 442 | name = maybe_remap_kv_scale_name(name, params_dict) 443 | if name is None: 444 | continue 445 | if is_pp_missing_parameter(name, self): 446 | continue 447 | param = params_dict[name] 448 | weight_loader = getattr(param, "weight_loader", 449 | default_weight_loader) 450 | weight_loader(param, loaded_weight) 451 | 452 | class CodeFuse_CGE_Large(nn.Module, SupportsLoRA): 453 | packed_modules_mapping = { 454 | "qkv_proj": [ 455 | "q_proj", 456 | "k_proj", 457 | "v_proj", 458 | ], 459 | "gate_up_proj": [ 460 | "gate_proj", 461 | "up_proj", 462 | ], 463 | } 464 | 465 | # LoRA specific attributes 466 | supported_lora_modules = [ 467 | "qkv_proj", 468 | "o_proj", 469 | "gate_up_proj", 470 | "down_proj", 471 | ] 472 | embedding_modules = {} 473 | embedding_padding_modules = [] 474 | 475 | def __init__( 476 | self, 477 | config: Qwen2Config, 478 | cache_config: Optional[CacheConfig] = None, 479 | quant_config: Optional[QuantizationConfig] = None, 480 | lora_config: Optional[LoRAConfig] = None, 481 | ) -> None: 482 | # TODO (@robertgshaw2): see if this can be moved out 483 | if (cache_config.sliding_window is not None 484 | and hasattr(config, "max_window_layers")): 485 | raise ValueError("Sliding window for some but all layers is not " 486 | "supported. This model uses sliding window " 487 | "but `max_window_layers` = %s is less than " 488 | "`num_hidden_layers` = %s. Please open an issue " 489 | "to discuss this feature." 
% ( 490 | config.max_window_layers, 491 | config.num_hidden_layers, 492 | )) 493 | 494 | super().__init__() 495 | 496 | self.config = config 497 | self.lora_config = lora_config 498 | self.quant_config = quant_config 499 | self.plm_model = Qwen2ForCausalLM(config, cache_config, quant_config) 500 | self.embedding_method = config.embedding_method 501 | self.inf_seq_length = config.inf_seq_length 502 | self.padding_side = config.padding_side 503 | self.keep_max_layer = config.keep_max_layer 504 | self.emb_dim = self.plm_model.model.embed_tokens.weight.size(1) 505 | self.num_heads = config.pma_num_heads 506 | self.ln = config.pma_ln 507 | self.norm = config.pma_norm 508 | self.pma_mode = config.pma_norm_mode 509 | self.mha_pma = PMA(self.emb_dim, self.compress_dim, self.num_heads, 1, ln=self.ln, pma_mode=self.pma_mode).to("cuda") 510 | if config.tie_word_embeddings: 511 | self.lm_head = self.plm_model.embed_tokens 512 | else: 513 | self.lm_head = ParallelLMHead(config.vocab_size, 514 | config.hidden_size, 515 | quant_config=quant_config) 516 | 517 | self.logits_processor = LogitsProcessor(config.vocab_size) 518 | self.sampler = Sampler() 519 | for param_tensor in self.mha_pma.state_dict(): 520 | print(param_tensor, "\t", self.mha_pma.state_dict()[param_tensor]) 521 | 522 | def forward( 523 | self, 524 | input_ids: torch.Tensor, 525 | positions: torch.Tensor, 526 | kv_caches: List[torch.Tensor], 527 | attn_metadata: AttentionMetadata, 528 | intermediate_tensors: Optional[IntermediateTensors] = None, 529 | ) -> torch.Tensor: 530 | hidden_states = self.plm_model(input_ids, positions, kv_caches, 531 | attn_metadata, intermediate_tensors) 532 | 533 | embedding = hidden_states.unsqueeze(0) 534 | res_embedding = self.pma_embedding(embedding, positions.unsqueeze(0)) 535 | return res_embedding 536 | 537 | def pooler( 538 | self, 539 | hidden_states: torch.Tensor, 540 | pooling_metadata: PoolingMetadata, 541 | ) -> Optional[PoolerOutput]: 542 | hidden_states = nn.functional.normalize(hidden_states, p=2, dim=1) 543 | pooled_outputs = [ 544 | EmbeddingSequenceGroupOutput(data.tolist()) for data in hidden_states 545 | ] 546 | 547 | return PoolerOutput(outputs=pooled_outputs) 548 | 549 | def compute_logits(self, hidden_states: torch.Tensor, 550 | sampling_metadata: SamplingMetadata) -> torch.Tensor: 551 | logits = self.logits_processor(self.lm_head, hidden_states, 552 | sampling_metadata) 553 | return logits 554 | 555 | def make_empty_intermediate_tensors( 556 | self, batch_size: int, dtype: torch.dtype, 557 | device: torch.device) -> IntermediateTensors: 558 | return IntermediateTensors({ 559 | "hidden_states": 560 | torch.zeros((batch_size, self.config.hidden_size), 561 | dtype=dtype, 562 | device=device), 563 | "residual": 564 | torch.zeros((batch_size, self.config.hidden_size), 565 | dtype=dtype, 566 | device=device), 567 | }) 568 | 569 | def sample( 570 | self, 571 | logits: torch.Tensor, 572 | sampling_metadata: SamplingMetadata, 573 | ) -> Optional[SamplerOutput]: 574 | next_tokens = self.sampler(logits, sampling_metadata) 575 | return next_tokens 576 | 577 | def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): 578 | stacked_params_mapping = [ 579 | # (param_name, shard_name, shard_id) 580 | ("qkv_proj", "q_proj", "q"), 581 | ("qkv_proj", "k_proj", "k"), 582 | ("qkv_proj", "v_proj", "v"), 583 | ("gate_up_proj", "gate_proj", 0), 584 | ("gate_up_proj", "up_proj", 1), 585 | ] 586 | params_dict = dict(self.named_parameters(remove_duplicate=False)) 587 | for name, loaded_weight in weights: 
588 | if "rotary_emb.inv_freq" in name: 589 | continue 590 | if self.config.tie_word_embeddings and "lm_head.weight" in name: 591 | continue 592 | for (param_name, weight_name, shard_id) in stacked_params_mapping: 593 | if weight_name not in name: 594 | continue 595 | name = name.replace(weight_name, param_name) 596 | # Skip loading extra bias for GPTQ models. 597 | if name.endswith(".bias") and name not in params_dict: 598 | continue 599 | if is_pp_missing_parameter(name, self): 600 | continue 601 | param = params_dict[name] 602 | weight_loader = param.weight_loader 603 | weight_loader(param, loaded_weight, shard_id) 604 | break 605 | else: 606 | # Skip loading extra bias for GPTQ models. 607 | if name.endswith(".bias") and name not in params_dict: 608 | continue 609 | # Remapping the name of FP8 kv-scale. 610 | name = maybe_remap_kv_scale_name(name, params_dict) 611 | if name is None: 612 | continue 613 | if is_pp_missing_parameter(name, self): 614 | continue 615 | param = params_dict[name] 616 | weight_loader = getattr(param, "weight_loader", 617 | default_weight_loader) 618 | weight_loader(param, loaded_weight) 619 | for param_tensor in self.mha_pma.state_dict(): 620 | print(param_tensor, "\t", self.mha_pma.state_dict()[param_tensor]) 621 | 622 | def last_embedding(self, A, index): 623 | bs, seq, emb = A.size() 624 | res = A[torch.arange(bs), index, :] 625 | return res 626 | 627 | def mean_embedding(self, A, mask): 628 | bs, seq, emb = A.size() 629 | res = (A * (mask.unsqueeze(-1))).sum(1) / (mask.sum(1).unsqueeze(-1)) 630 | return res 631 | 632 | # A (bs, seq, emb_size), mask (bs, 1, seq) 633 | def weighted_embedding(self, A, mask): 634 | weights = (torch.arange(start=1, end=A.size(1) + 1).unsqueeze(0).unsqueeze(-1).expand(A.size()).float()).to(A.device) 635 | input_mask_expanded = (mask.squeeze(1).unsqueeze(-1).expand(A.size()).float()).to(A.device) 636 | sum_embedding = torch.sum(A * input_mask_expanded * weights, dim=1) 637 | sum_mask = torch.sum(input_mask_expanded * weights, dim=1) 638 | weighted_embedding = sum_embedding / sum_mask 639 | 640 | return weighted_embedding 641 | 642 | def pma_embedding(self, A, mask): 643 | res = self.mha_pma(A, mask).squeeze(1) 644 | return res 645 | 646 | 647 | def get_sentence_embedding(self, embedding_method, **inputs): 648 | outputs = self.plm_model(inputs['input_ids'], inputs['attention_mask'], output_hidden_states=True) 649 | if embedding_method == 'last': 650 | embedding = outputs.hidden_states[self.keep_max_layer] 651 | index = inputs['attention_mask'].sum(-1).long() - 1 652 | res_embedding = self.last_embedding(embedding, index) 653 | elif embedding_method == 'mean': 654 | embedding = outputs.hidden_states[self.keep_max_layer] 655 | res_embedding = self.mean_embedding(embedding, inputs['attention_mask']) 656 | elif embedding_method == 'weighted': 657 | embedding = outputs.hidden_states[self.keep_max_layer] 658 | res_embedding = self.weighted_embedding(embedding, inputs['attention_mask']) 659 | elif embedding_method == 'pma': 660 | embedding = outputs.hidden_states[self.keep_max_layer] 661 | attention_mask = inputs['attention_mask'] 662 | res_embedding = self.pma_embedding(embedding, attention_mask) 663 | else: 664 | logger.debug('Error, no {} way to obtain embbedings'.format(embedding_method)) 665 | 666 | if not self.norm: 667 | res_embedding = torch.nn.functional.normalize(res_embedding, p=2.0, dim=-1, eps=1e-12, out=None) 668 | return res_embedding 669 | 670 | 671 | 672 | def encode(self, tokenizer, sentences, batch_size=32, 
convert_to_numpy=True, 673 | convert_to_tensor=False, show_progress_bar=True, max_seq_length=None, **kwargs): 674 | if max_seq_length is None: 675 | max_seq_length = self.inf_seq_length 676 | input_is_string = False 677 | if isinstance(sentences, str) or not hasattr(sentences, "__len__"): 678 | sentences = [sentences] 679 | input_is_string = True 680 | 681 | all_embeddings = [] 682 | length_sorted_idx = np.argsort([-len(s) for s in sentences]) 683 | sentences_sorted = [sentences[idx] for idx in length_sorted_idx] # 大到小重排 684 | with torch.no_grad(): 685 | for start_index in trange(0, len(sentences), batch_size, desc="Batches", disable=not show_progress_bar): 686 | sentences_batch = sentences_sorted[start_index: start_index + batch_size] 687 | # Compute sentences embeddings 688 | with torch.no_grad(): 689 | inputs = tokenizer(sentences_batch, padding=True, truncation=True, max_length=max_seq_length, add_special_tokens=False, return_tensors='pt').to(self.plm_model.device) 690 | embeddings = self.get_sentence_embedding(self.embedding_method, **inputs) 691 | embeddings = embeddings.detach() 692 | if convert_to_numpy: 693 | if embeddings.dtype == torch.bfloat16: 694 | embeddings = embeddings.cpu().to(torch.float32) 695 | else: 696 | embeddings = embeddings.cpu() 697 | all_embeddings.extend(embeddings) 698 | all_embeddings = [all_embeddings[idx] for idx in np.argsort(length_sorted_idx)] 699 | if convert_to_tensor: 700 | all_embeddings = torch.stack(all_embeddings) 701 | elif convert_to_numpy: 702 | all_embeddings = np.asarray([emb.numpy() for emb in all_embeddings]) 703 | 704 | if input_is_string: 705 | all_embeddings = all_embeddings[0] 706 | return all_embeddings 707 | 708 | 709 | class MAB_POST(nn.Module): 710 | def __init__(self, dim_Q, dim_K, dim_V, num_heads, ln=False): 711 | super(MAB_POST, self).__init__() 712 | self.dim_V = dim_V 713 | self.num_heads = num_heads 714 | self.fc_q = nn.Linear(dim_Q, dim_V) 715 | self.fc_k = nn.Linear(dim_K, dim_V) 716 | self.fc_v = nn.Linear(dim_K, dim_V) 717 | 718 | if ln: 719 | self.ln0 = nn.LayerNorm(dim_V) 720 | self.ln1 = nn.LayerNorm(dim_V) 721 | self.fc_o = nn.Linear(dim_V, dim_V) 722 | nn.init.xavier_uniform_(self.fc_q.weight) 723 | nn.init.xavier_uniform_(self.fc_k.weight) 724 | nn.init.xavier_uniform_(self.fc_v.weight) 725 | nn.init.xavier_uniform_(self.fc_o.weight) 726 | 727 | def forward(self, Q, K, pad_mask=None): 728 | 729 | Q_ = self.fc_q(Q) 730 | K_, V_ = self.fc_k(K), self.fc_v(K) 731 | dim_split = self.dim_V // self.num_heads 732 | Q_ = torch.cat(Q_.split(dim_split, 2), 0) 733 | K_ = torch.cat(K_.split(dim_split, 2), 0) 734 | V_ = torch.cat(V_.split(dim_split, 2), 0) 735 | 736 | pad_mask = pad_mask.unsqueeze(1).repeat(self.num_heads, Q.size(1), 1) 737 | score = Q_.bmm(K_.transpose(1,2))/math.sqrt(self.dim_V) 738 | score = score.masked_fill(pad_mask == 0, -1e12) 739 | A = torch.softmax(score, 2) 740 | A = A * pad_mask 741 | O = torch.cat(A.bmm(V_).split(Q.size(0), 0), 2) 742 | O = Q + O 743 | O = O if getattr(self, 'ln0', None) is None else self.ln0(O) 744 | O = O + F.relu(self.fc_o(O)) 745 | O = O if getattr(self, 'ln1', None) is None else self.ln1(O) 746 | return O 747 | 748 | 749 | class MAB_PRE_NORMAL(nn.Module): 750 | def __init__(self, dim_Q, dim_K, dim_V, num_heads, ln=False): 751 | super(MAB_PRE_NORMAL, self).__init__() 752 | self.dim_V = dim_V 753 | self.num_heads = num_heads 754 | self.fc_q = nn.Linear(dim_Q, dim_V) 755 | self.fc_k = nn.Linear(dim_K, dim_V) 756 | self.fc_v = nn.Linear(dim_K, dim_V) 757 | 758 | if ln: 759 | 
self.ln_q = nn.LayerNorm(dim_V) 760 | self.ln_kv = nn.LayerNorm(dim_V) 761 | self.ln_o = nn.LayerNorm(dim_V) 762 | self.ln_final = nn.LayerNorm(dim_V) 763 | 764 | self.fc_o = nn.Linear(dim_V, dim_V) 765 | nn.init.xavier_uniform_(self.fc_q.weight) 766 | nn.init.xavier_uniform_(self.fc_k.weight) 767 | nn.init.xavier_uniform_(self.fc_v.weight) 768 | nn.init.xavier_uniform_(self.fc_o.weight) 769 | 770 | def forward(self, Q, K, pad_mask=None): 771 | Q_ = Q if getattr(self, 'ln_q', None) is None else self.ln_q(Q) 772 | K_ = K if getattr(self, 'ln_kv', None) is None else self.ln_kv(K) 773 | Q_ = self.fc_q(Q_) 774 | K_, V_ = self.fc_k(K_), self.fc_v(K_) 775 | dim_split = self.dim_V // self.num_heads 776 | Q_ = torch.cat(Q_.split(dim_split, 2), 0) 777 | K_ = torch.cat(K_.split(dim_split, 2), 0) 778 | V_ = torch.cat(V_.split(dim_split, 2), 0) 779 | pad_mask = pad_mask.unsqueeze(1).repeat(self.num_heads, Q.size(1), 1) 780 | score = Q_.bmm(K_.transpose(1,2))/math.sqrt(self.dim_V) 781 | score = score.masked_fill(pad_mask == 0, -1e12) 782 | A = torch.softmax(score, 2) 783 | A = A * pad_mask 784 | O = torch.cat(A.bmm(V_).split(Q.size(0), 0), 2) 785 | O = Q + O 786 | O_ = O if getattr(self, 'ln_o', None) is None else self.ln_o(O) 787 | O_ = O + F.relu(self.fc_o(O_)) 788 | return O_ if getattr(self, 'ln_final', None) is None else self.ln_final(O_) 789 | 790 | 791 | class MAB_PRE_GPTJ(nn.Module): 792 | def __init__(self, dim_Q, dim_K, dim_V, num_heads, ln=False): 793 | super(MAB_PRE_GPTJ, self).__init__() 794 | self.dim_V = dim_V 795 | self.num_heads = num_heads 796 | self.fc_q = nn.Linear(dim_Q, dim_V) 797 | self.fc_k = nn.Linear(dim_K, dim_V) 798 | self.fc_v = nn.Linear(dim_K, dim_V) 799 | self.fc_o = nn.Linear(dim_V, dim_V) 800 | 801 | nn.init.xavier_uniform_(self.fc_q.weight) 802 | nn.init.xavier_uniform_(self.fc_k.weight) 803 | nn.init.xavier_uniform_(self.fc_v.weight) 804 | nn.init.xavier_uniform_(self.fc_o.weight) 805 | if ln: 806 | self.ln_q = nn.LayerNorm(dim_V) 807 | self.ln_kv = nn.LayerNorm(dim_V) 808 | self.ln_final = nn.LayerNorm(dim_V) 809 | 810 | def forward(self, Q, K, pad_mask=None): 811 | Q_ = Q if getattr(self, 'ln_q', None) is None else self.ln_q(Q) 812 | K_ = K if getattr(self, 'ln_kv', None) is None else self.ln_kv(K) 813 | 814 | Q1 = self.fc_q(Q_) 815 | K1, V1 = self.fc_k(K_), self.fc_v(K_) 816 | dim_split = self.dim_V // self.num_heads 817 | 818 | Q1 = torch.cat(Q1.split(dim_split, 2), 0) 819 | K1 = torch.cat(K1.split(dim_split, 2), 0) 820 | V1 = torch.cat(V1.split(dim_split, 2), 0) 821 | 822 | 823 | pad_mask = pad_mask.unsqueeze(1).repeat(self.num_heads, Q.size(1), 1) 824 | score = Q1.bmm(K1.transpose(1,2))/math.sqrt(self.dim_V) 825 | score = score.masked_fill(pad_mask == 0, -1e12) 826 | A = torch.softmax(score, 2) 827 | A = A * pad_mask 828 | O1 = torch.cat(A.bmm(V1).split(Q.size(0), 0), 2) 829 | O2 = F.relu(self.fc_o(Q_)) 830 | O_final = Q + O1 + O2 831 | return O_final if getattr(self, 'ln_final', None) is None else self.ln_final(O_final) 832 | 833 | 834 | class PMA(nn.Module): 835 | def __init__(self, dim, compress_dim, num_heads, num_seeds, ln=False, pma_mode=None): 836 | super(PMA, self).__init__() 837 | self.S = nn.Parameter(torch.Tensor(1, num_seeds, compress_dim)) 838 | nn.init.xavier_uniform_(self.S) 839 | if pma_mode == 'post_normal': 840 | self.mab = MAB_POST(compress_dim, dim, compress_dim, num_heads, ln=ln) 841 | elif pma_mode == 'pre_normal': 842 | self.mab = MAB_PRE_NORMAL(compress_dim, dim, compress_dim, num_heads, ln=ln) 843 | elif pma_mode == 'pre_gptj': 844 | 
self.mab = MAB_PRE_GPTJ(compress_dim, dim, compress_dim, num_heads, ln=ln) 845 | else: 846 | raise ValueError(f"Error, the pma_mode {pma_mode} is not implemented !") 847 | 848 | def forward(self, X, pad_mask): 849 | if self.S.dtype != torch.bfloat16: 850 | X = X.float() 851 | return self.mab(self.S.repeat(X.size(0), 1, 1), X, pad_mask) 852 | --------------------------------------------------------------------------------
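For reference, the PMA pooling module defined above can be exercised on its own. Below is a minimal smoke test with illustrative sizes; it is not part of the original file, and the real dimensions and `pma_mode` come from the model config:

```
import torch

# Illustrative sizes: token states of width 2048 compressed to 1024-dim vectors.
pma = PMA(dim=2048, compress_dim=1024, num_heads=8, num_seeds=1, ln=True, pma_mode='post_normal')

hidden = torch.randn(2, 16, 2048)  # (batch, seq_len, hidden_size) token states
mask = torch.ones(2, 16)           # (batch, seq_len); 1 marks real tokens, 0 padding

pooled = pma(hidden, mask)         # -> (batch, num_seeds, compress_dim)
print(pooled.shape)                # torch.Size([2, 1, 1024])
```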