├── F2LLM ├── imgs │ ├── overview.png │ └── mteb_leaderboard.png ├── requirements.txt ├── configs │ ├── config.json │ └── accelerate_config.yaml ├── arguments.py ├── model.py ├── tokenize_data_qwen.py ├── README.md ├── run.py └── utils.py ├── CGE ├── resources │ ├── result.png │ └── CodeFuse-AI.png ├── README.md └── utils │ └── vllm_codefuse_cge_large.py └── README.md /F2LLM/imgs/overview.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/codefuse-ai/CodeFuse-Embeddings/main/F2LLM/imgs/overview.png -------------------------------------------------------------------------------- /CGE/resources/result.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/codefuse-ai/CodeFuse-Embeddings/main/CGE/resources/result.png -------------------------------------------------------------------------------- /F2LLM/requirements.txt: -------------------------------------------------------------------------------- 1 | accelerate 2 | datasets 3 | deepspeed 4 | flash-attn 5 | torch 6 | transformers 7 | tensorboard 8 | -------------------------------------------------------------------------------- /CGE/resources/CodeFuse-AI.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/codefuse-ai/CodeFuse-Embeddings/main/CGE/resources/CodeFuse-AI.png -------------------------------------------------------------------------------- /F2LLM/imgs/mteb_leaderboard.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/codefuse-ai/CodeFuse-Embeddings/main/F2LLM/imgs/mteb_leaderboard.png -------------------------------------------------------------------------------- /F2LLM/configs/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_path": "models/qwen3-4b", 3 | "experiment_id": "4b+lr.8e-6+bs.16x32+context.1024+2epochs", 4 | "train_data_path": "training_data/data_tokenized_qwen", 5 | "output_dir": "output", 6 | "tb_dir": "output/tb", 7 | "cache_dir": "cache", 8 | "train_batch_size": 16, 9 | "checkpointing_steps": 5000, 10 | "validation_steps": 5000, 11 | "max_seq_length": 1024, 12 | "learning_rate": 8e-6, 13 | "min_lr": 1e-7, 14 | "weight_decay": 0.01, 15 | "warmup_steps": 500, 16 | "train_epochs": 2, 17 | "log_interval": 100, 18 | "num_hard_neg": 7 19 | } 20 | -------------------------------------------------------------------------------- /F2LLM/configs/accelerate_config.yaml: -------------------------------------------------------------------------------- 1 | compute_environment: LOCAL_MACHINE 2 | debug: false 3 | deepspeed_config: 4 | gradient_accumulation_steps: 1 5 | gradient_clipping: 1.0 6 | offload_optimizer_device: none 7 | offload_param_device: none 8 | zero3_init_flag: false 9 | zero_stage: 2 10 | distributed_type: DEEPSPEED 11 | downcast_bf16: "no" 12 | machine_rank: 0 13 | main_training_function: main 14 | mixed_precision: "bf16" 15 | num_machines: 1 16 | num_processes: 8 17 | rdzv_backend: static 18 | same_network: true 19 | tpu_env: [] 20 | tpu_use_cluster: false 21 | tpu_use_sudo: false 22 | use_cpu: false 23 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## CodeFuse Embeddings 2 | 3 |

4 | 5 |

6 | 7 | Embedding-related repos from CodeFuse, including: 8 | 9 | - [CGE](./CGE/README.md) 10 | - [D2LLM](https://github.com/codefuse-ai/D2LLM) 11 | - [F2LLM](./F2LLM/README.md) 12 | 13 | **Star History** 14 | 15 | [![Star History Chart](https://api.star-history.com/svg?repos=codefuse-ai/CodeFuse-Embeddings&type=date&legend=top-left)](https://www.star-history.com/#codefuse-ai/CodeFuse-Embeddings&type=date&legend=top-left) 16 | -------------------------------------------------------------------------------- /F2LLM/arguments.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, asdict 2 | import argparse, json 3 | 4 | 5 | @dataclass 6 | class Args: 7 | 8 | model_path: str 9 | experiment_id: str 10 | # save dir 11 | output_dir: str 12 | tb_dir: str 13 | cache_dir: str 14 | # training arguments 15 | train_data_path: str 16 | train_batch_size: int = 8 17 | max_seq_length: int = 2048 18 | learning_rate: float = 1e-4 19 | min_lr: float = 1e-6 20 | weight_decay: float = 1e-2 21 | warmup_steps: int = 100 22 | # embedding-related settings 23 | num_hard_neg: int = 7 24 | # train steps take precedence over epochs, set to -1 to disable 25 | train_steps: int = -1 26 | train_epochs: int = 5 27 | log_interval: int = 20 28 | checkpointing_steps: int = 100 29 | validation_steps: int = 100 30 | # just placeholder, for logging purpose 31 | num_processes: int=0 32 | 33 | def dict(self): 34 | return asdict(self) 35 | 36 | 37 | def parse_args(): 38 | parser = argparse.ArgumentParser() 39 | parser.add_argument("--config", type=str) 40 | arg = parser.parse_args() 41 | with open(arg.config) as f: 42 | config = json.load(f) 43 | args = Args(**config) 44 | args.output_dir = f"{args.output_dir}/{args.experiment_id}" 45 | args.tb_dir = f"{args.tb_dir}/{args.experiment_id}" 46 | return args -------------------------------------------------------------------------------- /F2LLM/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from transformers import AutoModel, AutoTokenizer 3 | 4 | 5 | class F2LLM: 6 | def __init__(self, 7 | model_path, 8 | max_seq_length=512, 9 | args=None 10 | ): 11 | 12 | self.args = args 13 | self.dtype = torch.bfloat16 14 | self.device = None # set after accelerator.prepare 15 | self.lm = AutoModel.from_pretrained(model_path, trust_remote_code=True, torch_dtype=self.dtype, attn_implementation='flash_attention_2') 16 | self.lm.config.use_cache = False 17 | self.tokenizer = AutoTokenizer.from_pretrained(model_path) 18 | self.max_seq_length = max_seq_length 19 | 20 | def set_device(self): 21 | self.device = self.lm.device 22 | 23 | def forward(self, batch): 24 | bs = batch['bs'] 25 | num_hard_neg = int((len(batch['input_ids']) - 2*bs) / bs) 26 | 27 | outputs = self.lm(batch['input_ids'], 28 | batch['attention_mask'], 29 | ) 30 | 31 | passage_features_all_tokens = outputs.last_hidden_state 32 | return { 33 | 'query_passage_features': torch.stack([passage_features_all_tokens[i, [batch['seq_lens'][i]-1]] for i in range(bs)]), 34 | 'passage_passage_features': torch.stack([passage_features_all_tokens[i, [batch['seq_lens'][i]-1]] for i in range(bs, 2*bs)]), 35 | 'negative_passage_features': None if num_hard_neg == 0 else torch.stack([passage_features_all_tokens[i, [batch['seq_lens'][i]-1]] for i in range(2*bs, len(batch['seq_lens']))]).view(bs, num_hard_neg, -1) 36 | } 37 | 38 | -------------------------------------------------------------------------------- 
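Note on `model.py` above: the forward pass embeds each query/passage with the hidden state of its final (EOS) token, selected via the `seq_lens[i]-1` index. Below is a minimal, self-contained sketch of the same last-token pooling at inference time; the checkpoint name, the right-padding setup, and the final L2 normalization are illustrative assumptions rather than code from this repo.

```
import torch
from transformers import AutoModel, AutoTokenizer

model_path = "codefuse-ai/F2LLM-0.6B"  # assumed checkpoint, for illustration only
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModel.from_pretrained(model_path, torch_dtype=torch.bfloat16).eval()

texts = ["how to sort a list in python", "sorted(x) returns a new sorted list"]
# Append an EOS token to each text (mirroring tokenize_data_qwen.py) and pad on the right.
batch = tokenizer([t + tokenizer.eos_token for t in texts],
                  padding=True, add_special_tokens=False, return_tensors="pt")
with torch.no_grad():
    hidden = model(**batch).last_hidden_state        # [bs, seq_len, d]
last_idx = batch["attention_mask"].sum(dim=1) - 1     # position of the EOS token per sequence
emb = hidden[torch.arange(hidden.size(0)), last_idx]  # [bs, d], same gather as F2LLM.forward
emb = torch.nn.functional.normalize(emb, dim=-1)      # assumption: cosine-style scoring
print((emb[0] @ emb[1]).item())
```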
/F2LLM/tokenize_data_qwen.py: -------------------------------------------------------------------------------- 1 | from multiprocessing import Pool 2 | import numpy as np 3 | import pandas as pd 4 | import os 5 | from transformers import AutoTokenizer 6 | from tqdm.auto import tqdm 7 | 8 | 9 | tokenizer = AutoTokenizer.from_pretrained('models/qwen3-0.6b') 10 | max_seq_length = 1023 11 | 12 | 13 | def process_sent(sentence): 14 | 15 | # We make sure there's always an eos token at the end of each sequence 16 | tokenizer_outputs = tokenizer(sentence, max_length=max_seq_length, truncation=True, add_special_tokens=False) 17 | 18 | return np.array(tokenizer_outputs.input_ids + [tokenizer.eos_token_id]) 19 | 20 | 21 | def process_sent_batch(s): 22 | return s.apply(process_sent) 23 | 24 | def parallelize(data, func, num_of_processes=8): 25 | indices = np.array_split(data.index, num_of_processes) 26 | data_split = [data.iloc[idx] for idx in indices] 27 | with Pool(num_of_processes) as pool: 28 | data = pd.concat(pool.map(func, data_split)) 29 | return data 30 | 31 | 32 | root_dir = 'training_data' 33 | for ds_name in tqdm(sorted(os.listdir(root_dir))): 34 | print(ds_name, flush=True) 35 | 36 | df = pd.read_parquet(f"{root_dir}/{ds_name}") 37 | df['query_input_ids'] = parallelize(df['query'], process_sent_batch, 62) 38 | 39 | num_neg = 24 if 'negative_2' in df.keys() else 1 40 | 41 | ls = df.passage.to_list() 42 | for i in range(1, num_neg+1): 43 | ls += df[f'negative_{i}'].to_list() 44 | ls = list(set(ls)) 45 | df_tmp = pd.DataFrame({'text': ls}) 46 | df_tmp['input_ids'] = parallelize(df_tmp['text'], process_sent_batch, 62) 47 | df_tmp = df_tmp.set_index('text') 48 | 49 | df['passage_input_ids'] = df.passage.map(df_tmp.input_ids) 50 | 51 | for i in range(1, num_neg+1): 52 | df[f'negative_{i}_input_ids'] = df[f'negative_{i}'].map(df_tmp.input_ids) 53 | 54 | df.to_parquet(f'data_tokenized_qwen/{ds_name}', index=False) 55 | -------------------------------------------------------------------------------- /F2LLM/README.md: -------------------------------------------------------------------------------- 1 | ## F2LLM 2 | 3 | F2LLMs (Foundation-to-Feature Large Language Models) are foundation models directly finetuned on 6 million high-quality query-document pairs, striking a strong balance between model size, training cost, and embedding performance: 4 | 5 |

6 | 7 |

8 | 9 | On the MTEB leaderboard, F2LLM-4B ranks 2nd among models of ~4B size, and 7th overall, while F2LLM-1.7B ranks 1st among models of 1B-2B size. 10 | 11 |

12 | 13 |

14 | 
15 | F2LLMs are fully open. Model checkpoints are available at:
16 | 
17 | - [F2LLM 0.6B](https://huggingface.co/codefuse-ai/F2LLM-0.6B)
18 | - [F2LLM 1.7B](https://huggingface.co/codefuse-ai/F2LLM-1.7B)
19 | - [F2LLM 4B](https://huggingface.co/codefuse-ai/F2LLM-4B)
20 | 
21 | Training data is available at [F2LLM data](https://huggingface.co/datasets/codefuse-ai/F2LLM).
22 | 
23 | ### Train
24 | 
25 | In this repo we provide a streamlined and efficient script for training embedding models. To reproduce the training of F2LLMs, please:
26 | 
27 | - Set up the environment following `requirements.txt`. Note that transformers>=4.51.0 is required for training Qwen3 models.
28 | - Download the data and backbone models from Hugging Face (we use Qwen3 models).
29 | - Run `tokenize_data_qwen.py` to tokenize the downloaded data.
30 | - Modify the model path, data path, and other arguments in `configs/config.json`.
31 | - Start training with `accelerate launch --config_file configs/accelerate_config.yaml run.py --config configs/config.json`.
32 | 
33 | Note: we recommend setting `num_processes` to 1 in `configs/accelerate_config.yaml` and launching the training code once to generate the cache for the training data before starting the actual training.
34 | 
35 | For multi-node training, run on the main node:
36 | 
37 | ```
38 | accelerate launch --config_file configs/accelerate_config.yaml --num_machines N_NODE --num_processes N_PROCESSES --machine_rank 0 --main_process_ip MASTER_IP --main_process_port MASTER_PORT run.py --config configs/config.json
39 | ```
40 | 
41 | where N_NODE is the number of machines; N_PROCESSES is N_NODE\*8; MASTER_IP is the IP address of your master node, and MASTER_PORT is a port available on your machine (e.g. 6379).
42 | 
43 | On worker nodes, also run the above command but modify `machine_rank` accordingly.
44 | 
45 | ### Citation
46 | 
47 | If you use the F2LLM models, data, or code, please cite the following technical report:
48 | 
49 | ```
50 | @article{2025F2LLM,
51 |   title={F2LLM Technical Report: Matching SOTA Embedding Performance with 6 Million Open-Source Data},
52 |   author={Ziyin Zhang and Zihan Liao and Hang Yu and Peng Di and Rui Wang},
53 |   journal = {CoRR},
54 |   volume = {abs/2510.02294},
55 |   year = {2025},
56 |   url = {https://doi.org/10.48550/arXiv.2510.02294},
57 |   doi = {10.48550/ARXIV.2510.02294},
58 |   eprinttype = {arXiv},
59 |   eprint = {2510.02294}
60 | }
61 | ```
62 | 
--------------------------------------------------------------------------------
/CGE/README.md:
--------------------------------------------------------------------------------
1 | ## CodeFuse-CGE
2 | 

3 | 4 |

5 | 
6 | In this project, we introduce CodeFuse-CGE (Code General Embedding), which stands out on text2code tasks thanks to its powerful ability to capture the semantic relationship between text and code.
7 | This model has the following notable features:
8 | ● Instruction-tuning is enabled for both the query and code snippet sides.
9 | ● The model obtains sentence-level and code-level representations through a cross-attention module.
10 | ● The model has a smaller embedding dimension without significant degradation in performance.
11 | 
12 | CodeFuse-CGE-Large Model Configuration
13 | Hugging Face: [codefuse-ai/CodeFuse-CGE-Large](https://huggingface.co/codefuse-ai/CodeFuse-CGE-Large)
14 | Base Model: CodeQwen1.5-7B-Chat
15 | Model Size: 7B
16 | Embedding Dimension: 1024
17 | Hidden Layers: 32
18 | 
19 | Requirements
20 | ```
21 | flash_attn==2.4.2
22 | torch==2.1.0
23 | accelerate==0.28.0
24 | transformers==4.39.2
25 | vllm==0.5.3
26 | ```
27 | 
28 | 
29 | CodeFuse-CGE-Small Model Configuration
30 | Hugging Face: [codefuse-ai/CodeFuse-CGE-Small](https://huggingface.co/codefuse-ai/CodeFuse-CGE-Small)
31 | Base Model: Phi-3.5-mini-instruct
32 | Model Size: 3.8B
33 | Embedding Dimension: 1024
34 | Hidden Layers: 32
35 | 
36 | Requirements
37 | ```
38 | flash_attn==2.4.2
39 | torch==2.1.0
40 | accelerate==0.28.0
41 | transformers>=4.43.0
42 | ```
43 | 
44 | 
45 | ## Benchmark the Performance
46 | We use the MRR metric to evaluate performance on text2code retrieval tasks: AdvTest, CosQA, and CSN.
47 | 
48 | ![result](./resources/result.png)
49 | 
50 | ## How to Use
51 | 
52 | You should first download the model files from Hugging Face.
53 | 
54 | ### Transformers
55 | ```
56 | from transformers import AutoTokenizer, AutoModel
57 | import torch
58 | model_name_or_path = "CodeFuse-CGE-Large"
59 | model = AutoModel.from_pretrained(model_name_or_path, trust_remote_code=True)
60 | tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, trust_remote_code=True, truncation_side='right', padding_side='right')
61 | 
62 | if torch.cuda.is_available():
63 |     device = 'cuda'
64 | else:
65 |     device = 'cpu'
66 | model.to(device)
67 | 
68 | prefix_dict = {'python':{'query':'Retrieve the Python code that solves the following query:', 'passage':'Python code:'},
69 |                'java':{'query':'Retrieve the Java code that solves the following query:', 'passage':'Java code:'},
70 |                'go':{'query':'Retrieve the Go code that solves the following query:', 'passage':'Go code:'},
71 |                'c++':{'query':'Retrieve the C++ code that solves the following query:', 'passage':'C++ code:'},
72 |                'javascript':{'query':'Retrieve the Javascript code that solves the following query:', 'passage':'Javascript code:'},
73 |                'php':{'query':'Retrieve the PHP code that solves the following query:', 'passage':'PHP code:'},
74 |                'ruby':{'query':'Retrieve the Ruby code that solves the following query:', 'passage':'Ruby code:'},
75 |                'default':{'query':'Retrieve the code that solves the following query:', 'passage':'Code:'}
76 |                }
77 | 
78 | text = ["Writes a Boolean to the stream.",
79 |         "def writeBoolean(self, n): t = TYPE_BOOL_TRUE if n is False: t = TYPE_BOOL_FALSE self.stream.write(t)"]
80 | text[0] += prefix_dict['python']['query']
81 | text[1] += prefix_dict['python']['passage']
82 | embed = model.encode(tokenizer, text)
83 | score = embed[0] @ embed[1].T
84 | print("score", score)
85 | ```
86 | 
87 | ### Vllm
88 | We have also adapted Vllm to reduce latency during deployment.
89 | ``` 90 | from vllm import ModelRegistry 91 | from utils.vllm_codefuse_cge_large import CodeFuse_CGE_Large 92 | from vllm.model_executor.models import ModelRegistry 93 | from vllm import LLM 94 | 95 | def always_true_is_embedding_model(model_arch: str) -> bool: 96 | return True 97 | ModelRegistry.is_embedding_model = always_true_is_embedding_model 98 | ModelRegistry.register_model("CodeFuse_CGE_Large", CodeFuse_CGE_Large) 99 | 100 | 101 | model_name_or_path = "CodeFuse-CGE-Large" 102 | model = LLM(model=model_name_or_path, trust_remote_code=True, enforce_eager=True, enable_chunked_prefill=False) 103 | prefix_dict = {'python':{'query':'Retrieve the Python code that solves the following query:', 'passage':'Python code:'}, 104 | 'java':{'query':'Retrieve the Java code that solves the following query:', 'passage':'Java code:'}, 105 | 'go':{'query':'Retrieve the Go code that solves the following query:', 'passage':'Go code:'}, 106 | 'c++':{'query':'Retrieve the C++ code that solves the following query:', 'passage':'C++ code:'}, 107 | 'javascript':{'query':'Retrieve the Javascript code that solves the following query:', 'passage':'Javascript code:'}, 108 | 'php':{'query':'Retrieve the PHP code that solves the following query:', 'passage':'PHP code:'}, 109 | 'ruby':{'query':'Retrieve the Ruby code that solves the following query:', 'passage':'Ruby code:'}, 110 | 'default':{'query':'Retrieve the code that solves the following query:', 'passage':'Code:'} 111 | } 112 | 113 | text = ["Return the best fit based on rsquared", 114 | "def find_best_rsquared ( list_of_fits ) : res = sorted ( list_of_fits , key = lambda x : x . rsquared ) return res [ - 1 ]"] 115 | text[0] += prefix_dict['python']['query'] 116 | text[1] += prefix_dict['python']['passage'] 117 | embed_0 = model.encode([text[0]])[0].outputs.embedding 118 | embed_1 = model.encode([text[1]])[0].outputs.embedding 119 | ``` 120 | Note: 121 | 1. After adapting Vllm, the model's input can only have a batch size of 1; otherwise, it will result in an array overflow error. 122 | 2. Only the CodeFuse-CGE-Large model has been adapted, and support for the CodeFuse-CGE-Small model will be available soon. 123 | 124 | ## Contact us 125 | Email: 126 | 127 | ![CodeFuse-AI](./resources/CodeFuse-AI.png) 128 | 129 | 130 | 131 | ## Acknowledgement 132 | Thanks to the authors of open-sourced datasets, including CSN, Adv, CoSQA. 
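The Vllm example above stops at obtaining `embed_0` and `embed_1`. As a small follow-up sketch (the helper below is hypothetical, not part of this repo): since the pooler in `utils/vllm_codefuse_cge_large.py` L2-normalizes its outputs, a plain dot product already acts as a cosine similarity, mirroring the `embed[0] @ embed[1].T` scoring in the Transformers example.

```
import numpy as np

def cosine_score(embed_a, embed_b):
    # embed_a / embed_b are the plain Python lists returned by model.encode(...)[0].outputs.embedding
    a = np.asarray(embed_a, dtype=np.float32)
    b = np.asarray(embed_b, dtype=np.float32)
    return float(a @ b)  # embeddings are already L2-normalized by the pooler

print("score", cosine_score(embed_0, embed_1))
```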
133 | 134 | -------------------------------------------------------------------------------- /F2LLM/run.py: -------------------------------------------------------------------------------- 1 | from arguments import parse_args 2 | from utils import accelerate_train, CLASSIFICATION_DATASETS 3 | from transformers import ( 4 | AutoTokenizer, 5 | set_seed, 6 | get_scheduler 7 | ) 8 | import os, json, random 9 | from datasets import load_dataset 10 | from torch.utils.data import DataLoader 11 | from accelerate import Accelerator 12 | from accelerate.state import AcceleratorState 13 | import torch 14 | from torch.nn.utils.rnn import pad_sequence 15 | from torch.optim import AdamW 16 | from model import F2LLM 17 | 18 | os.environ["TOKENIZERS_PARALLELISM"] = "false" 19 | 20 | args = parse_args() 21 | accelerator = Accelerator() 22 | args.num_processes = accelerator.num_processes 23 | accelerator.print(args) 24 | 25 | def _stack(input_ids, max_len): 26 | data = [ids[:max_len] for ids in input_ids] # input_ids: list of lists 27 | lens = [len(x) for x in data] 28 | tensor = torch.tensor(sum(data, [])) # (total_tokens,) 29 | return tensor.split(lens) # list of 1-d tensors 30 | 31 | 32 | def collate_fn(batch_raw): 33 | ''' 34 | length of input_ids: bs * (2 + num_hard_neg) 35 | 0 - bs-1: query input ids 36 | bs - 2*bs-1: passage input ids 37 | 2*bs - 2*bs+num_hard_neg-1: hard neg for sample 1 38 | 2*bs+num_hard_neg*(i-1) - 2*bs+num_hard_neg*i-1: hard neg for sample i (i from 1 to bs) 39 | ''' 40 | num_hard_neg = 1 if batch_raw[0]['dataset_name'] in CLASSIFICATION_DATASETS else args.num_hard_neg 41 | # select args.num_hard_neg hard negatives from a total of 24 42 | hard_neg_indices = [0] if num_hard_neg == 1 else random.sample(list(range(24)), num_hard_neg) 43 | input_ids = _stack( 44 | [s['query_input_ids'] for s in batch_raw]+\ 45 | [s['passage_input_ids'] for s in batch_raw]+\ 46 | [s[f'negative_{i+1}_input_ids'] for s in batch_raw for i in hard_neg_indices], 47 | args.max_seq_length 48 | ) 49 | seqlens = torch.tensor([ids.size(0) for ids in input_ids]) 50 | # pad input ids to [bs, max_len] 51 | input_ids = pad_sequence(input_ids, batch_first=True, padding_value=tokenizer.pad_token_id) 52 | attention_masks = input_ids.ne(tokenizer.pad_token_id).long() 53 | 54 | return {'input_ids': input_ids, 'seq_lens': seqlens, 'attention_mask': attention_masks, 'bs': len(batch_raw), 'dataset_name': batch_raw[0]['dataset_name']} 55 | 56 | 57 | set_seed(0) 58 | if accelerator.is_main_process: 59 | os.makedirs(f"{args.output_dir}", exist_ok=True) 60 | with open(os.path.join(args.output_dir, "args.json"), "w") as f: 61 | json.dump(args.dict(), f, indent=2) 62 | 63 | train_datasets, valid_datasets = [], [] 64 | for f in sorted(os.listdir(args.train_data_path)): 65 | dataset_name = f.split('.parquet')[0] 66 | dataset = load_dataset("parquet", data_files=os.path.join(args.train_data_path, f), cache_dir=args.cache_dir)['train'] 67 | dataset = dataset.add_column("dataset_name", [dataset_name]*len(dataset)) 68 | dataset = dataset.train_test_split(train_size=0.99, shuffle=True, seed=0) 69 | train_datasets.append((dataset_name, dataset['train'])) 70 | valid_datasets.append((dataset_name, dataset['test'])) 71 | 72 | tokenizer = AutoTokenizer.from_pretrained(args.model_path) 73 | 74 | train_loaders = { 75 | name: DataLoader(ds, shuffle=True, batch_size=args.train_batch_size, collate_fn=collate_fn) 76 | for name, ds in train_datasets 77 | } 78 | valid_loaders = { 79 | name: DataLoader(ds, shuffle=False, 
batch_size=args.train_batch_size, collate_fn=collate_fn) 80 | for name, ds in valid_datasets 81 | } 82 | 83 | class MultiLoader: 84 | """ 85 | Iterates over a dict(name -> DataLoader) and returns complete batches. 86 | At every __iter__ a new random order is created; 87 | the epoch ends when every loader is exhausted once. 88 | """ 89 | def __init__(self, loader_dict): 90 | self.loader_dict = loader_dict 91 | for k, v in self.loader_dict.items(): 92 | self.loader_dict[k] = accelerator.prepare(v) 93 | 94 | def __len__(self): 95 | return sum(len(v) for v in self.loader_dict.values()) 96 | 97 | def reset_epoch(self, epoch): 98 | self.rng = random.Random(epoch) 99 | self.iters = {k: iter(v) for k, v in self.loader_dict.items()} 100 | self.names = list(self.iters.keys()) 101 | self.weights = [len(self.loader_dict[k]) for k in self.names] 102 | 103 | def __iter__(self): 104 | while self.names: # until every DataLoader is empty 105 | name = self.rng.choices(self.names, weights=self.weights)[0] # pick a data-source at random 106 | try: 107 | batch = next(self.iters[name]) 108 | yield batch 109 | except StopIteration: 110 | idx = self.names.index(name) 111 | self.names.pop(idx) # this dataset has no batch left 112 | self.weights.pop(idx) 113 | 114 | 115 | # determine training steps 116 | override_train_step = False 117 | if args.train_steps < 0: 118 | args.train_steps = sum(len(v) for v in train_loaders.values()) * args.train_epochs 119 | override_train_step = True 120 | 121 | accelerator.print(f"******************************** Training step before prepare: {args.train_steps} ********************************") 122 | model = F2LLM(args.model_path, args.max_seq_length, args=args) 123 | model.lm.gradient_checkpointing_enable() 124 | # set seed again to make sure that different models share the same seed 125 | set_seed(0) 126 | 127 | optimizer = AdamW(model.lm.parameters(), 128 | weight_decay=args.weight_decay, 129 | lr=args.learning_rate, 130 | betas=(0.9, 0.98)) 131 | 132 | lr_scheduler = get_scheduler("cosine", 133 | optimizer=optimizer, 134 | num_warmup_steps=args.warmup_steps, 135 | num_training_steps=args.train_steps) 136 | 137 | AcceleratorState().deepspeed_plugin.deepspeed_config['train_micro_batch_size_per_gpu'] = args.train_batch_size 138 | model.lm, optimizer, lr_scheduler = accelerator.prepare( 139 | model.lm, optimizer, lr_scheduler 140 | ) 141 | model.set_device() 142 | train_dataloader = MultiLoader(train_loaders) 143 | for k, v in valid_loaders.items(): 144 | valid_loaders[k] = accelerator.prepare(v) 145 | 146 | # if training on multiple GPUs, length of dataloader would have changed 147 | if override_train_step: 148 | args.train_steps = len(train_dataloader) * args.train_epochs 149 | accelerator.print(f"******************************** Training step after prepare: {args.train_steps} ********************************") 150 | 151 | 152 | accelerate_train(args, accelerator, model, train_dataloader, valid_loaders, 153 | optimizer, lr_scheduler, len(dataset)) -------------------------------------------------------------------------------- /F2LLM/utils.py: -------------------------------------------------------------------------------- 1 | from tqdm.auto import tqdm 2 | from torch.utils.tensorboard import SummaryWriter 3 | import torch 4 | import torch.nn.functional as F 5 | from torch.nn import CrossEntropyLoss 6 | import os 7 | 8 | CLASSIFICATION_DATASETS = ['amazon_counterfactual', 'amazon_polarity', 'imdb', 'toxic_conversations', 'cola'] 9 | CLUSTERING_DATASETS = ['amazon_reviews', 
'banking77', 'emotion', 'mtop_intent', 'mtop_domain', 'massive_scenario', 'massive_intent', 'tweet_sentiment_extraction', 'arxiv_clustering_p2p', 'arxiv_clustering_s2s', 'biorxiv_clustering_p2p', 'biorxiv_clustering_s2s', 'medrxiv_clustering_p2p', 'medrxiv_clustering_s2s', 'reddit_clustering_p2p', 'reddit_clustering_s2s', 'stackexchange_clustering_p2p', 'stackexchange_clustering_s2s', 'twentynewsgroups'] 10 | RETRIEVAL_DATASETS = ['arguana', 'snli', 'mnli', 'anli', 'paq', 'squad', 'stackexchange', 'msmarco', 'natural_questions', 'hotpotqa', 'fever', 'eli5', 'fiqa', 'bioasq', 'nfcorpus', 'miracl', 'mrtidy', 'scifact', 'qqp', 'stackoverflowdupquestions', 'sts12', 'sts22', 'stsbenchmark', 'amazon_qa', 'cnn_dm', 'coliee', 'paq_part2', 'pubmedqa', 's2orc_abstract_citation', 's2orc_title_abstract', 's2orc_title_citation', 'sentence_compression', 'specter', 'triviaqa', 'xsum', 'stackexchange_part2', 'stackexchangedupquestions_s2s', 'stackexchangedupquestions_p2p'] 11 | 12 | 13 | def write_tensorboard(summary_writer: SummaryWriter, log_dict: dict, completed_steps): 14 | for key, value in log_dict.items(): 15 | summary_writer.add_scalar(key, value, completed_steps) 16 | 17 | 18 | def save_checkpoint(args, accelerator, model, output_dir, lr_scheduler): 19 | accelerator.wait_for_everyone() 20 | accelerator.print(f"Saving checkpoint to {output_dir}") 21 | 22 | if accelerator.is_main_process: 23 | model.tokenizer.save_pretrained(output_dir) 24 | unwrapped_model = accelerator.unwrap_model(model.lm) 25 | unwrapped_model.save_pretrained( 26 | output_dir, 27 | is_main_process=accelerator.is_main_process, 28 | save_function=accelerator.save, 29 | state_dict=accelerator.get_state_dict(model.lm), # this is required for zero 3 30 | ) 31 | accelerator.wait_for_everyone() 32 | 33 | 34 | def inbatch_loss( 35 | query_embeddings, # [bs, d] 36 | context_embeddings, # [bs, d] 37 | criterion, 38 | accelerator, 39 | temperature=0.05, 40 | ): 41 | 42 | bs = query_embeddings.size(0) 43 | a_norm = F.normalize(query_embeddings, p=2, dim=-1) 44 | # b_norm = torch.nn.functional.normalize(context_embeddings, p=2, dim=-1) 45 | b_cross_gpus = accelerator.gather(context_embeddings) # [bs*process, d] 46 | # print((context_embeddings - b_cross_gpus[bs * accelerator.process_index : bs * accelerator.process_index+bs]).abs().sum()) 47 | b_norm_cross_gpus = F.normalize(b_cross_gpus, p=2, dim=-1) # () 48 | 49 | student_logits = torch.matmul(a_norm, b_norm_cross_gpus.t()) / temperature # [bs, bs*process] 50 | 51 | labels = torch.arange(bs, device=student_logits.device) + bs * accelerator.process_index 52 | loss_bs = criterion(student_logits, labels) # (bs) 53 | 54 | loss = loss_bs.mean() 55 | 56 | return loss 57 | 58 | def hard_loss( 59 | query_embeddings, # [bs, d] 60 | context_embeddings, # [bs, d] 61 | hard_neg_embeddings, # [bs, num, d] 62 | criterion, 63 | accelerator, 64 | temperature=0.05, 65 | ): 66 | 67 | if hard_neg_embeddings is None: 68 | return 0.0 69 | 70 | bs = query_embeddings.size(0) 71 | a_norm = F.normalize(query_embeddings, p=2, dim=-1) 72 | 73 | hard_neg_embeddings = torch.concat([ 74 | context_embeddings.unsqueeze(1), 75 | hard_neg_embeddings 76 | ], dim=1) # [bs, num_hard+1, d] 77 | 78 | hard_norm = F.normalize(hard_neg_embeddings, p=2, dim=-1) 79 | logits = (a_norm.unsqueeze(1) * hard_norm).sum(-1) / temperature # [bs, num_hard+1] 80 | 81 | loss_hard = criterion(logits, torch.zeros((bs), dtype=torch.long, device=logits.device)).mean() 82 | 83 | return loss_hard 84 | 85 | 86 | def validate(args, accelerator, 
model, valid_loader_dict, criterion, completed_steps, summary_writer): 87 | eval_log_dict = {} 88 | for dataset_name, valid_dataloader in valid_loader_dict.items(): 89 | loss_ls, loss_hard_ls = [], [] 90 | for batch in valid_dataloader: 91 | with torch.no_grad(): 92 | outputs = model.forward(batch) 93 | loss_hard = hard_loss(outputs['query_passage_features'].squeeze(1), outputs['passage_passage_features'].squeeze(1), outputs['negative_passage_features'], criterion, accelerator) 94 | loss_hard_ls.append(accelerator.gather(loss_hard).float()) 95 | if dataset_name in RETRIEVAL_DATASETS: 96 | loss = inbatch_loss(outputs['query_passage_features'].squeeze(1), outputs['passage_passage_features'].squeeze(1), criterion, accelerator) 97 | loss_ls.append(accelerator.gather(loss).float()) 98 | 99 | accelerator.wait_for_everyone() 100 | loss_hard_ls = torch.cat(loss_hard_ls) 101 | eval_log_dict[f'{dataset_name}/valid_loss_hard'] = loss_hard_ls.mean() 102 | if dataset_name in RETRIEVAL_DATASETS: 103 | loss_ls = torch.cat(loss_ls) 104 | eval_log_dict[f"{dataset_name}/valid_loss_in_batch"] = loss_ls.mean() 105 | 106 | eval_log_dict['Avg/retrieval/valid_loss_in_batch'] = torch.tensor([v for k, v in eval_log_dict.items() if k.split('/')[0] in RETRIEVAL_DATASETS and k.endswith('valid_loss_in_batch')]).mean() 107 | eval_log_dict['Avg/retrieval/valid_loss_hard'] = torch.tensor([v for k, v in eval_log_dict.items() if k.split('/')[0] in RETRIEVAL_DATASETS and k.endswith('valid_loss_hard')]).mean() 108 | eval_log_dict['Avg/classification/valid_loss_hard'] = torch.tensor([v for k, v in eval_log_dict.items() if k.split('/')[0] in CLASSIFICATION_DATASETS]).mean() 109 | eval_log_dict['Avg/clustering/valid_loss_hard'] = torch.tensor([v for k, v in eval_log_dict.items() if k.split('/')[0] in CLUSTERING_DATASETS]).mean() 110 | if accelerator.is_main_process: 111 | write_tensorboard(summary_writer, eval_log_dict, completed_steps) 112 | accelerator.print(f"[Validation] Step = {completed_steps}") 113 | 114 | 115 | def accelerate_train(args, 116 | accelerator, 117 | model, 118 | train_dataloader, 119 | valid_loader_dict, 120 | optimizer, 121 | lr_scheduler, 122 | num_train_samples): 123 | accelerator.print("**************************************** Start training ****************************************") 124 | accelerator.print(f" Num train samples = {num_train_samples}") 125 | accelerator.print(f" Num epochs = {args.train_epochs}") 126 | accelerator.print(f" Per device batch size = {args.train_batch_size}") 127 | accelerator.print(f" Global batch size = {args.train_batch_size * accelerator.num_processes}") 128 | accelerator.print(f" Step per epoch = {len(train_dataloader)}") 129 | accelerator.print(f" Total training steps = {args.train_steps}") 130 | accelerator.print("************************************************************************************************") 131 | global RETRIEVAL_DATASETS, CLASSIFICATION_DATASETS, CLUSTERING_DATASETS 132 | RETRIEVAL_DATASETS = [ds for ds in RETRIEVAL_DATASETS if ds in train_dataloader.loader_dict.keys()] 133 | CLASSIFICATION_DATASETS = [ds for ds in CLASSIFICATION_DATASETS if ds in train_dataloader.loader_dict.keys()] 134 | CLUSTERING_DATASETS = [ds for ds in CLUSTERING_DATASETS if ds in train_dataloader.loader_dict.keys()] 135 | 136 | summary_writer = SummaryWriter(log_dir=args.tb_dir) if accelerator.is_main_process else None 137 | criterion = CrossEntropyLoss(reduction='none') 138 | pbar = tqdm(range(args.train_steps), disable=not accelerator.is_local_main_process) 139 | 
completed_steps = 0 140 | loss_dict = {ds_name: torch.tensor(0.0, device=model.lm.device) for ds_name in RETRIEVAL_DATASETS} 141 | loss_hard_dict = {ds_name: torch.tensor(0.0, device=model.lm.device) for ds_name in train_dataloader.loader_dict.keys()} 142 | count_dict = {ds_name: torch.tensor(0, device=model.lm.device) for ds_name in RETRIEVAL_DATASETS} 143 | count_hard_dict = {ds_name: torch.tensor(0, device=model.lm.device) for ds_name in train_dataloader.loader_dict.keys()} 144 | 145 | model.lm.train() 146 | for epoch in range(args.train_epochs): 147 | accelerator.print(f"*************** Starting epoch {epoch+1} ***************") 148 | train_dataloader.reset_epoch(epoch) 149 | for batch in train_dataloader: 150 | # forward and compute loss 151 | outputs = model.forward(batch) 152 | # passage features: [bs, 1, d] 153 | # hard_neg_features: [bs, num_hard_neg, d] 154 | 155 | loss_hard = hard_loss(outputs['query_passage_features'].squeeze(1), outputs['passage_passage_features'].squeeze(1), outputs['negative_passage_features'], criterion, accelerator) 156 | dataset_name = batch['dataset_name'] 157 | count_hard_dict[dataset_name] += 1 158 | loss_hard_dict[dataset_name] += loss_hard.detach().float() 159 | if dataset_name in RETRIEVAL_DATASETS: 160 | loss = inbatch_loss(outputs['query_passage_features'].squeeze(1), outputs['passage_passage_features'].squeeze(1), criterion, accelerator) 161 | count_dict[dataset_name] += 1 162 | loss_dict[dataset_name] += loss.detach().float() 163 | else: 164 | loss = 0.0 165 | 166 | loss_total = loss + loss_hard 167 | 168 | # backward, optimizer, scheduler 169 | accelerator.backward(loss_total) 170 | optimizer.step() 171 | lr_scheduler.step() 172 | optimizer.zero_grad() 173 | if optimizer.param_groups[0]['lr'] < args.min_lr: 174 | for i in range(len(optimizer.param_groups)): 175 | optimizer.param_groups[i]['lr'] = args.min_lr 176 | 177 | # log 178 | completed_steps += 1 179 | if completed_steps % args.log_interval == 0: 180 | pbar.update(args.log_interval) 181 | 182 | train_log_dict = {"lr": optimizer.param_groups[0]['lr']} 183 | for k in loss_dict.keys(): 184 | count = accelerator.gather(count_dict[k]).sum() 185 | if count > 0: 186 | train_log_dict[f"{k}/training_loss_in_batch"] = accelerator.gather(loss_dict[k]).sum() / count 187 | for k in loss_hard_dict.keys(): 188 | count = accelerator.gather(count_hard_dict[k]).sum() 189 | if count > 0: 190 | train_log_dict[f"{k}/training_loss_hard"] = accelerator.gather(loss_hard_dict[k]).sum() / count 191 | train_log_dict['Avg/retrieval/training_loss_in_batch'] = torch.tensor([v for k, v in train_log_dict.items() if k.split('/')[0] in RETRIEVAL_DATASETS and k.endswith('training_loss_in_batch')]).mean() 192 | train_log_dict['Avg/retrieval/training_loss_hard'] = torch.tensor([v for k, v in train_log_dict.items() if k.split('/')[0] in RETRIEVAL_DATASETS and k.endswith('training_loss_hard')]).mean() 193 | train_log_dict['Avg/classification/training_loss_hard'] = torch.tensor([v for k, v in train_log_dict.items() if k.split('/')[0] in CLASSIFICATION_DATASETS]).mean() 194 | train_log_dict['Avg/clustering/training_loss_hard'] = torch.tensor([v for k, v in train_log_dict.items() if k.split('/')[0] in CLUSTERING_DATASETS]).mean() 195 | 196 | accelerator.print(f"[Train] Step = {completed_steps}") 197 | if accelerator.is_main_process: 198 | write_tensorboard(summary_writer, train_log_dict, completed_steps) 199 | loss_dict = {ds_name: torch.tensor(0.0, device=model.lm.device) for ds_name in RETRIEVAL_DATASETS} 200 | loss_hard_dict = 
{ds_name: torch.tensor(0.0, device=model.lm.device) for ds_name in train_dataloader.loader_dict.keys()} 201 | count_dict = {ds_name: torch.tensor(0, device=model.lm.device) for ds_name in RETRIEVAL_DATASETS} 202 | count_hard_dict = {ds_name: torch.tensor(0, device=model.lm.device) for ds_name in train_dataloader.loader_dict.keys()} 203 | 204 | # validation 205 | if completed_steps % args.validation_steps == 0: 206 | model.lm.eval() 207 | validate(args, accelerator, model, valid_loader_dict, criterion, completed_steps, summary_writer) 208 | model.lm.train() 209 | 210 | # step checkpoint 211 | if args.checkpointing_steps and completed_steps % args.checkpointing_steps == 0: 212 | output_dir = os.path.join(args.output_dir, f"step_{completed_steps}") 213 | save_checkpoint(args, accelerator, model, output_dir, lr_scheduler) 214 | 215 | if completed_steps >= args.train_steps: 216 | break 217 | 218 | # epoch checkpoint 219 | output_dir = os.path.join(args.output_dir, f"epoch_{epoch+1}") 220 | save_checkpoint(args, accelerator, model, output_dir, lr_scheduler) 221 | if completed_steps % args.validation_steps != 0: 222 | model.lm.eval() 223 | validate(args, accelerator, model, valid_loader_dict, criterion, completed_steps, summary_writer) 224 | model.lm.train() 225 | 226 | if summary_writer: 227 | summary_writer.close() -------------------------------------------------------------------------------- /CGE/utils/vllm_codefuse_cge_large.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Adapted from 3 | # https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/qwen2/modeling_qwen2.py 4 | # Copyright 2024 The Qwen team. 5 | # Copyright 2023 The vLLM team. 6 | # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. 7 | # 8 | # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX 9 | # and OPT implementations in this library. It has been modified from its 10 | # original forms to accommodate minor architectural differences compared 11 | # to GPT-NeoX and OPT used by the Meta AI team that trained the model. 12 | # 13 | # Licensed under the Apache License, Version 2.0 (the "License"); 14 | # you may not use this file except in compliance with the License. 15 | # You may obtain a copy of the License at 16 | # 17 | # http://www.apache.org/licenses/LICENSE-2.0 18 | # 19 | # Unless required by applicable law or agreed to in writing, software 20 | # distributed under the License is distributed on an "AS IS" BASIS, 21 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 22 | # See the License for the specific language governing permissions and 23 | # limitations under the License. 
24 | """Inference-only Qwen2 model compatible with HuggingFace weights.""" 25 | from typing import Iterable, List, Optional, Tuple 26 | from vllm.sequence import EmbeddingSequenceGroupOutput, PoolerOutput 27 | import torch 28 | from torch import nn 29 | from transformers import Qwen2Config 30 | from transformers import PretrainedConfig 31 | from vllm.attention import Attention, AttentionMetadata 32 | from vllm.config import CacheConfig, LoRAConfig 33 | from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size 34 | from vllm.model_executor.layers.activation import SiluAndMul 35 | from vllm.model_executor.layers.layernorm import RMSNorm 36 | from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, 37 | QKVParallelLinear, 38 | RowParallelLinear) 39 | from vllm.model_executor.layers.logits_processor import LogitsProcessor 40 | from vllm.model_executor.layers.quantization.base_config import ( 41 | QuantizationConfig) 42 | from vllm.model_executor.layers.rotary_embedding import get_rope 43 | from vllm.model_executor.layers.sampler import Sampler 44 | from vllm.model_executor.layers.vocab_parallel_embedding import ( 45 | ParallelLMHead, VocabParallelEmbedding) 46 | from vllm.model_executor.model_loader.weight_utils import ( 47 | default_weight_loader, maybe_remap_kv_scale_name) 48 | from vllm.model_executor.sampling_metadata import SamplingMetadata 49 | from vllm.sequence import IntermediateTensors, SamplerOutput 50 | 51 | from vllm.model_executor.models.interfaces import SupportsLoRA 52 | from vllm.model_executor.models.utils import is_pp_missing_parameter, make_layers 53 | from typing import Iterable, List, Optional, Tuple 54 | 55 | import torch 56 | from torch import nn 57 | 58 | from vllm.attention import AttentionMetadata 59 | from vllm.model_executor.layers.pooler import Pooler, PoolingType 60 | from vllm.model_executor.model_loader.weight_utils import default_weight_loader 61 | from vllm.model_executor.models.llama import LlamaModel 62 | from vllm.model_executor.pooling_metadata import PoolingMetadata 63 | from vllm.sequence import PoolerOutput 64 | import math 65 | import sys 66 | import torch 67 | import torch.nn as nn 68 | import torch.nn.functional as F 69 | 70 | 71 | class Qwen2MLP(nn.Module): 72 | 73 | def __init__( 74 | self, 75 | hidden_size: int, 76 | intermediate_size: int, 77 | hidden_act: str, 78 | quant_config: Optional[QuantizationConfig] = None, 79 | ) -> None: 80 | super().__init__() 81 | self.gate_up_proj = MergedColumnParallelLinear( 82 | hidden_size, [intermediate_size] * 2, 83 | bias=False, 84 | quant_config=quant_config) 85 | self.down_proj = RowParallelLinear(intermediate_size, 86 | hidden_size, 87 | bias=False, 88 | quant_config=quant_config) 89 | if hidden_act != "silu": 90 | raise ValueError(f"Unsupported activation: {hidden_act}. 
" 91 | "Only silu is supported for now.") 92 | self.act_fn = SiluAndMul() 93 | 94 | def forward(self, x): 95 | gate_up, _ = self.gate_up_proj(x) 96 | x = self.act_fn(gate_up) 97 | x, _ = self.down_proj(x) 98 | return x 99 | 100 | 101 | class Qwen2Attention(nn.Module): 102 | 103 | def __init__(self, 104 | hidden_size: int, 105 | num_heads: int, 106 | num_kv_heads: int, 107 | max_position: int = 4096 * 32, 108 | rope_theta: float = 10000, 109 | cache_config: Optional[CacheConfig] = None, 110 | quant_config: Optional[QuantizationConfig] = None, 111 | rope_scaling: Optional[Tuple] = None) -> None: 112 | super().__init__() 113 | self.hidden_size = hidden_size 114 | tp_size = get_tensor_model_parallel_world_size() 115 | self.total_num_heads = num_heads 116 | assert self.total_num_heads % tp_size == 0 117 | self.num_heads = self.total_num_heads // tp_size 118 | self.total_num_kv_heads = num_kv_heads 119 | if self.total_num_kv_heads >= tp_size: 120 | # Number of KV heads is greater than TP size, so we partition 121 | # the KV heads across multiple tensor parallel GPUs. 122 | assert self.total_num_kv_heads % tp_size == 0 123 | else: 124 | # Number of KV heads is less than TP size, so we replicate 125 | # the KV heads across multiple tensor parallel GPUs. 126 | assert tp_size % self.total_num_kv_heads == 0 127 | self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) 128 | self.head_dim = hidden_size // self.total_num_heads 129 | self.q_size = self.num_heads * self.head_dim 130 | self.kv_size = self.num_kv_heads * self.head_dim 131 | self.scaling = self.head_dim**-0.5 132 | self.rope_theta = rope_theta 133 | 134 | self.qkv_proj = QKVParallelLinear( 135 | hidden_size, 136 | self.head_dim, 137 | self.total_num_heads, 138 | self.total_num_kv_heads, 139 | bias=True, 140 | quant_config=quant_config, 141 | ) 142 | self.o_proj = RowParallelLinear( 143 | self.total_num_heads * self.head_dim, 144 | hidden_size, 145 | bias=False, 146 | quant_config=quant_config, 147 | ) 148 | 149 | self.rotary_emb = get_rope( 150 | self.head_dim, 151 | rotary_dim=self.head_dim, 152 | max_position=max_position, 153 | base=self.rope_theta, 154 | rope_scaling=rope_scaling, 155 | ) 156 | self.attn = Attention(self.num_heads, 157 | self.head_dim, 158 | self.scaling, 159 | num_kv_heads=self.num_kv_heads, 160 | cache_config=cache_config, 161 | quant_config=quant_config) 162 | 163 | def forward( 164 | self, 165 | positions: torch.Tensor, 166 | hidden_states: torch.Tensor, 167 | kv_cache: torch.Tensor, 168 | attn_metadata: AttentionMetadata, 169 | ) -> torch.Tensor: 170 | qkv, _ = self.qkv_proj(hidden_states) 171 | q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) 172 | q, k = self.rotary_emb(positions, q, k) 173 | attn_output = self.attn(q, k, v, kv_cache, attn_metadata) 174 | output, _ = self.o_proj(attn_output) 175 | return output 176 | 177 | 178 | class Qwen2DecoderLayer(nn.Module): 179 | 180 | def __init__( 181 | self, 182 | config: Qwen2Config, 183 | cache_config: Optional[CacheConfig] = None, 184 | quant_config: Optional[QuantizationConfig] = None, 185 | ) -> None: 186 | super().__init__() 187 | self.hidden_size = config.hidden_size 188 | # Requires transformers > 4.32.0 189 | rope_theta = getattr(config, "rope_theta", 1000000) 190 | rope_scaling = getattr(config, "rope_scaling", None) 191 | self.self_attn = Qwen2Attention( 192 | hidden_size=self.hidden_size, 193 | num_heads=config.num_attention_heads, 194 | max_position=config.max_position_embeddings, 195 | num_kv_heads=config.num_key_value_heads, 
196 | rope_theta=rope_theta, 197 | cache_config=cache_config, 198 | quant_config=quant_config, 199 | rope_scaling=rope_scaling) 200 | self.mlp = Qwen2MLP( 201 | hidden_size=self.hidden_size, 202 | intermediate_size=config.intermediate_size, 203 | hidden_act=config.hidden_act, 204 | quant_config=quant_config, 205 | ) 206 | self.input_layernorm = RMSNorm(config.hidden_size, 207 | eps=config.rms_norm_eps) 208 | self.post_attention_layernorm = RMSNorm(config.hidden_size, 209 | eps=config.rms_norm_eps) 210 | 211 | def forward( 212 | self, 213 | positions: torch.Tensor, 214 | hidden_states: torch.Tensor, 215 | kv_cache: torch.Tensor, 216 | attn_metadata: AttentionMetadata, 217 | residual: Optional[torch.Tensor], 218 | ) -> Tuple[torch.Tensor, torch.Tensor]: 219 | # Self Attention 220 | if residual is None: 221 | residual = hidden_states 222 | hidden_states = self.input_layernorm(hidden_states) 223 | else: 224 | hidden_states, residual = self.input_layernorm( 225 | hidden_states, residual) 226 | hidden_states = self.self_attn( 227 | positions=positions, 228 | hidden_states=hidden_states, 229 | kv_cache=kv_cache, 230 | attn_metadata=attn_metadata, 231 | ) 232 | 233 | # Fully Connected 234 | hidden_states, residual = self.post_attention_layernorm( 235 | hidden_states, residual) 236 | hidden_states = self.mlp(hidden_states) 237 | return hidden_states, residual 238 | 239 | 240 | class Qwen2Model(nn.Module): 241 | 242 | def __init__( 243 | self, 244 | config: Qwen2Config, 245 | cache_config: Optional[CacheConfig] = None, 246 | quant_config: Optional[QuantizationConfig] = None, 247 | prefix: str = "", 248 | ) -> None: 249 | super().__init__() 250 | self.config = config 251 | self.padding_idx = config.pad_token_id 252 | self.vocab_size = config.vocab_size 253 | 254 | self.embed_tokens = VocabParallelEmbedding( 255 | config.vocab_size, 256 | config.hidden_size, 257 | quant_config=quant_config, 258 | ) 259 | self.start_layer, self.end_layer, self.layers = make_layers( 260 | config.num_hidden_layers, 261 | lambda prefix: Qwen2DecoderLayer(config=config, 262 | cache_config=cache_config, 263 | quant_config=quant_config), 264 | prefix=f"{prefix}.layers", 265 | ) 266 | 267 | self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) 268 | 269 | def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: 270 | return self.embed_tokens(input_ids) 271 | 272 | def forward( 273 | self, 274 | input_ids: torch.Tensor, 275 | positions: torch.Tensor, 276 | kv_caches: List[torch.Tensor], 277 | attn_metadata: AttentionMetadata, 278 | intermediate_tensors: Optional[IntermediateTensors] = None, 279 | inputs_embeds: Optional[torch.Tensor] = None, 280 | ) -> torch.Tensor: 281 | if get_pp_group().is_first_rank: 282 | if inputs_embeds is not None: 283 | hidden_states = inputs_embeds 284 | else: 285 | hidden_states = self.embed_tokens(input_ids) 286 | residual = None 287 | else: 288 | assert intermediate_tensors is not None 289 | hidden_states = intermediate_tensors["hidden_states"] 290 | residual = intermediate_tensors["residual"] 291 | for i in range(self.start_layer, self.end_layer): 292 | layer = self.layers[i] 293 | hidden_states, residual = layer( 294 | positions, 295 | hidden_states, 296 | kv_caches[i - self.start_layer], 297 | attn_metadata, 298 | residual, 299 | ) 300 | if not get_pp_group().is_last_rank: 301 | return IntermediateTensors({ 302 | "hidden_states": hidden_states, 303 | "residual": residual 304 | }) 305 | hidden_states, _ = self.norm(hidden_states, residual) 306 | return hidden_states 
307 | 308 | 309 | class Qwen2ForCausalLM(nn.Module, SupportsLoRA): 310 | packed_modules_mapping = { 311 | "qkv_proj": [ 312 | "q_proj", 313 | "k_proj", 314 | "v_proj", 315 | ], 316 | "gate_up_proj": [ 317 | "gate_proj", 318 | "up_proj", 319 | ], 320 | } 321 | 322 | # LoRA specific attributes 323 | supported_lora_modules = [ 324 | "qkv_proj", 325 | "o_proj", 326 | "gate_up_proj", 327 | "down_proj", 328 | ] 329 | embedding_modules = {} 330 | embedding_padding_modules = [] 331 | 332 | def __init__( 333 | self, 334 | config: Qwen2Config, 335 | cache_config: Optional[CacheConfig] = None, 336 | quant_config: Optional[QuantizationConfig] = None, 337 | lora_config: Optional[LoRAConfig] = None, 338 | ) -> None: 339 | # TODO (@robertgshaw2): see if this can be moved out 340 | if (cache_config.sliding_window is not None 341 | and hasattr(config, "max_window_layers")): 342 | raise ValueError("Sliding window for some but all layers is not " 343 | "supported. This model uses sliding window " 344 | "but `max_window_layers` = %s is less than " 345 | "`num_hidden_layers` = %s. Please open an issue " 346 | "to discuss this feature." % ( 347 | config.max_window_layers, 348 | config.num_hidden_layers, 349 | )) 350 | 351 | super().__init__() 352 | 353 | self.config = config 354 | self.lora_config = lora_config 355 | 356 | self.quant_config = quant_config 357 | self.model = Qwen2Model(config, cache_config, quant_config) 358 | 359 | if config.tie_word_embeddings: 360 | self.lm_head = self.model.embed_tokens 361 | else: 362 | self.lm_head = ParallelLMHead(config.vocab_size, 363 | config.hidden_size, 364 | quant_config=quant_config) 365 | 366 | self.logits_processor = LogitsProcessor(config.vocab_size) 367 | self.sampler = Sampler() 368 | 369 | def forward( 370 | self, 371 | input_ids: torch.Tensor, 372 | positions: torch.Tensor, 373 | kv_caches: List[torch.Tensor], 374 | attn_metadata: AttentionMetadata, 375 | intermediate_tensors: Optional[IntermediateTensors] = None, 376 | ) -> torch.Tensor: 377 | hidden_states = self.model(input_ids, positions, kv_caches, 378 | attn_metadata, intermediate_tensors) 379 | return hidden_states 380 | 381 | def compute_logits(self, hidden_states: torch.Tensor, 382 | sampling_metadata: SamplingMetadata) -> torch.Tensor: 383 | logits = self.logits_processor(self.lm_head, hidden_states, 384 | sampling_metadata) 385 | return logits 386 | 387 | def make_empty_intermediate_tensors( 388 | self, batch_size: int, dtype: torch.dtype, 389 | device: torch.device) -> IntermediateTensors: 390 | return IntermediateTensors({ 391 | "hidden_states": 392 | torch.zeros((batch_size, self.config.hidden_size), 393 | dtype=dtype, 394 | device=device), 395 | "residual": 396 | torch.zeros((batch_size, self.config.hidden_size), 397 | dtype=dtype, 398 | device=device), 399 | }) 400 | 401 | def sample( 402 | self, 403 | logits: torch.Tensor, 404 | sampling_metadata: SamplingMetadata, 405 | ) -> Optional[SamplerOutput]: 406 | next_tokens = self.sampler(logits, sampling_metadata) 407 | return next_tokens 408 | 409 | def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): 410 | stacked_params_mapping = [ 411 | # (param_name, shard_name, shard_id) 412 | ("qkv_proj", "q_proj", "q"), 413 | ("qkv_proj", "k_proj", "k"), 414 | ("qkv_proj", "v_proj", "v"), 415 | ("gate_up_proj", "gate_proj", 0), 416 | ("gate_up_proj", "up_proj", 1), 417 | ] 418 | params_dict = dict(self.named_parameters(remove_duplicate=False)) 419 | for name, loaded_weight in weights: 420 | if "rotary_emb.inv_freq" in name: 421 | continue 
422 | if self.config.tie_word_embeddings and "lm_head.weight" in name: 423 | continue 424 | for (param_name, weight_name, shard_id) in stacked_params_mapping: 425 | if weight_name not in name: 426 | continue 427 | name = name.replace(weight_name, param_name) 428 | # Skip loading extra bias for GPTQ models. 429 | if name.endswith(".bias") and name not in params_dict: 430 | continue 431 | if is_pp_missing_parameter(name, self): 432 | continue 433 | param = params_dict[name] 434 | weight_loader = param.weight_loader 435 | weight_loader(param, loaded_weight, shard_id) 436 | break 437 | else: 438 | # Skip loading extra bias for GPTQ models. 439 | if name.endswith(".bias") and name not in params_dict: 440 | continue 441 | # Remapping the name of FP8 kv-scale. 442 | name = maybe_remap_kv_scale_name(name, params_dict) 443 | if name is None: 444 | continue 445 | if is_pp_missing_parameter(name, self): 446 | continue 447 | param = params_dict[name] 448 | weight_loader = getattr(param, "weight_loader", 449 | default_weight_loader) 450 | weight_loader(param, loaded_weight) 451 | 452 | class CodeFuse_CGE_Large(nn.Module, SupportsLoRA): 453 | packed_modules_mapping = { 454 | "qkv_proj": [ 455 | "q_proj", 456 | "k_proj", 457 | "v_proj", 458 | ], 459 | "gate_up_proj": [ 460 | "gate_proj", 461 | "up_proj", 462 | ], 463 | } 464 | 465 | # LoRA specific attributes 466 | supported_lora_modules = [ 467 | "qkv_proj", 468 | "o_proj", 469 | "gate_up_proj", 470 | "down_proj", 471 | ] 472 | embedding_modules = {} 473 | embedding_padding_modules = [] 474 | 475 | def __init__( 476 | self, 477 | config: Qwen2Config, 478 | cache_config: Optional[CacheConfig] = None, 479 | quant_config: Optional[QuantizationConfig] = None, 480 | lora_config: Optional[LoRAConfig] = None, 481 | ) -> None: 482 | # TODO (@robertgshaw2): see if this can be moved out 483 | if (cache_config.sliding_window is not None 484 | and hasattr(config, "max_window_layers")): 485 | raise ValueError("Sliding window for some but all layers is not " 486 | "supported. This model uses sliding window " 487 | "but `max_window_layers` = %s is less than " 488 | "`num_hidden_layers` = %s. Please open an issue " 489 | "to discuss this feature." 
% ( 490 | config.max_window_layers, 491 | config.num_hidden_layers, 492 | )) 493 | 494 | super().__init__() 495 | 496 | self.config = config 497 | self.lora_config = lora_config 498 | self.quant_config = quant_config 499 | self.plm_model = Qwen2ForCausalLM(config, cache_config, quant_config) 500 | self.embedding_method = config.embedding_method 501 | self.inf_seq_length = config.inf_seq_length 502 | self.padding_side = config.padding_side 503 | self.keep_max_layer = config.keep_max_layer 504 | self.emb_dim = self.plm_model.model.embed_tokens.weight.size(1) 505 | self.num_heads = config.pma_num_heads 506 | self.ln = config.pma_ln 507 | self.norm = config.pma_norm 508 | self.pma_mode = config.pma_norm_mode 509 | self.mha_pma = PMA(self.emb_dim, self.compress_dim, self.num_heads, 1, ln=self.ln, pma_mode=self.pma_mode).to("cuda") 510 | if config.tie_word_embeddings: 511 | self.lm_head = self.plm_model.embed_tokens 512 | else: 513 | self.lm_head = ParallelLMHead(config.vocab_size, 514 | config.hidden_size, 515 | quant_config=quant_config) 516 | 517 | self.logits_processor = LogitsProcessor(config.vocab_size) 518 | self.sampler = Sampler() 519 | for param_tensor in self.mha_pma.state_dict(): 520 | print(param_tensor, "\t", self.mha_pma.state_dict()[param_tensor]) 521 | 522 | def forward( 523 | self, 524 | input_ids: torch.Tensor, 525 | positions: torch.Tensor, 526 | kv_caches: List[torch.Tensor], 527 | attn_metadata: AttentionMetadata, 528 | intermediate_tensors: Optional[IntermediateTensors] = None, 529 | ) -> torch.Tensor: 530 | hidden_states = self.plm_model(input_ids, positions, kv_caches, 531 | attn_metadata, intermediate_tensors) 532 | 533 | embedding = hidden_states.unsqueeze(0) 534 | res_embedding = self.pma_embedding(embedding, positions.unsqueeze(0)) 535 | return res_embedding 536 | 537 | def pooler( 538 | self, 539 | hidden_states: torch.Tensor, 540 | pooling_metadata: PoolingMetadata, 541 | ) -> Optional[PoolerOutput]: 542 | hidden_states = nn.functional.normalize(hidden_states, p=2, dim=1) 543 | pooled_outputs = [ 544 | EmbeddingSequenceGroupOutput(data.tolist()) for data in hidden_states 545 | ] 546 | 547 | return PoolerOutput(outputs=pooled_outputs) 548 | 549 | def compute_logits(self, hidden_states: torch.Tensor, 550 | sampling_metadata: SamplingMetadata) -> torch.Tensor: 551 | logits = self.logits_processor(self.lm_head, hidden_states, 552 | sampling_metadata) 553 | return logits 554 | 555 | def make_empty_intermediate_tensors( 556 | self, batch_size: int, dtype: torch.dtype, 557 | device: torch.device) -> IntermediateTensors: 558 | return IntermediateTensors({ 559 | "hidden_states": 560 | torch.zeros((batch_size, self.config.hidden_size), 561 | dtype=dtype, 562 | device=device), 563 | "residual": 564 | torch.zeros((batch_size, self.config.hidden_size), 565 | dtype=dtype, 566 | device=device), 567 | }) 568 | 569 | def sample( 570 | self, 571 | logits: torch.Tensor, 572 | sampling_metadata: SamplingMetadata, 573 | ) -> Optional[SamplerOutput]: 574 | next_tokens = self.sampler(logits, sampling_metadata) 575 | return next_tokens 576 | 577 | def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): 578 | stacked_params_mapping = [ 579 | # (param_name, shard_name, shard_id) 580 | ("qkv_proj", "q_proj", "q"), 581 | ("qkv_proj", "k_proj", "k"), 582 | ("qkv_proj", "v_proj", "v"), 583 | ("gate_up_proj", "gate_proj", 0), 584 | ("gate_up_proj", "up_proj", 1), 585 | ] 586 | params_dict = dict(self.named_parameters(remove_duplicate=False)) 587 | for name, loaded_weight in weights: 
588 | if "rotary_emb.inv_freq" in name: 589 | continue 590 | if self.config.tie_word_embeddings and "lm_head.weight" in name: 591 | continue 592 | for (param_name, weight_name, shard_id) in stacked_params_mapping: 593 | if weight_name not in name: 594 | continue 595 | name = name.replace(weight_name, param_name) 596 | # Skip loading extra bias for GPTQ models. 597 | if name.endswith(".bias") and name not in params_dict: 598 | continue 599 | if is_pp_missing_parameter(name, self): 600 | continue 601 | param = params_dict[name] 602 | weight_loader = param.weight_loader 603 | weight_loader(param, loaded_weight, shard_id) 604 | break 605 | else: 606 | # Skip loading extra bias for GPTQ models. 607 | if name.endswith(".bias") and name not in params_dict: 608 | continue 609 | # Remapping the name of FP8 kv-scale. 610 | name = maybe_remap_kv_scale_name(name, params_dict) 611 | if name is None: 612 | continue 613 | if is_pp_missing_parameter(name, self): 614 | continue 615 | param = params_dict[name] 616 | weight_loader = getattr(param, "weight_loader", 617 | default_weight_loader) 618 | weight_loader(param, loaded_weight) 619 | for param_tensor in self.mha_pma.state_dict(): 620 | print(param_tensor, "\t", self.mha_pma.state_dict()[param_tensor]) 621 | 622 | def last_embedding(self, A, index): 623 | bs, seq, emb = A.size() 624 | res = A[torch.arange(bs), index, :] 625 | return res 626 | 627 | def mean_embedding(self, A, mask): 628 | bs, seq, emb = A.size() 629 | res = (A * (mask.unsqueeze(-1))).sum(1) / (mask.sum(1).unsqueeze(-1)) 630 | return res 631 | 632 | # A (bs, seq, emb_size), mask (bs, 1, seq) 633 | def weighted_embedding(self, A, mask): 634 | weights = (torch.arange(start=1, end=A.size(1) + 1).unsqueeze(0).unsqueeze(-1).expand(A.size()).float()).to(A.device) 635 | input_mask_expanded = (mask.squeeze(1).unsqueeze(-1).expand(A.size()).float()).to(A.device) 636 | sum_embedding = torch.sum(A * input_mask_expanded * weights, dim=1) 637 | sum_mask = torch.sum(input_mask_expanded * weights, dim=1) 638 | weighted_embedding = sum_embedding / sum_mask 639 | 640 | return weighted_embedding 641 | 642 | def pma_embedding(self, A, mask): 643 | res = self.mha_pma(A, mask).squeeze(1) 644 | return res 645 | 646 | 647 | def get_sentence_embedding(self, embedding_method, **inputs): 648 | outputs = self.plm_model(inputs['input_ids'], inputs['attention_mask'], output_hidden_states=True) 649 | if embedding_method == 'last': 650 | embedding = outputs.hidden_states[self.keep_max_layer] 651 | index = inputs['attention_mask'].sum(-1).long() - 1 652 | res_embedding = self.last_embedding(embedding, index) 653 | elif embedding_method == 'mean': 654 | embedding = outputs.hidden_states[self.keep_max_layer] 655 | res_embedding = self.mean_embedding(embedding, inputs['attention_mask']) 656 | elif embedding_method == 'weighted': 657 | embedding = outputs.hidden_states[self.keep_max_layer] 658 | res_embedding = self.weighted_embedding(embedding, inputs['attention_mask']) 659 | elif embedding_method == 'pma': 660 | embedding = outputs.hidden_states[self.keep_max_layer] 661 | attention_mask = inputs['attention_mask'] 662 | res_embedding = self.pma_embedding(embedding, attention_mask) 663 | else: 664 | logger.debug('Error, no {} way to obtain embbedings'.format(embedding_method)) 665 | 666 | if not self.norm: 667 | res_embedding = torch.nn.functional.normalize(res_embedding, p=2.0, dim=-1, eps=1e-12, out=None) 668 | return res_embedding 669 | 670 | 671 | 672 | def encode(self, tokenizer, sentences, batch_size=32, 
672 |     def encode(self, tokenizer, sentences, batch_size=32, convert_to_numpy=True,
673 |                convert_to_tensor=False, show_progress_bar=True, max_seq_length=None, **kwargs):
674 |         if max_seq_length is None:
675 |             max_seq_length = self.inf_seq_length
676 |         input_is_string = False
677 |         if isinstance(sentences, str) or not hasattr(sentences, "__len__"):
678 |             sentences = [sentences]
679 |             input_is_string = True
680 |
681 |         all_embeddings = []
682 |         length_sorted_idx = np.argsort([-len(s) for s in sentences])
683 |         sentences_sorted = [sentences[idx] for idx in length_sorted_idx]  # sort from longest to shortest
684 |         with torch.no_grad():
685 |             for start_index in trange(0, len(sentences), batch_size, desc="Batches", disable=not show_progress_bar):
686 |                 sentences_batch = sentences_sorted[start_index: start_index + batch_size]
687 |                 # Compute sentences embeddings
688 |                 with torch.no_grad():
689 |                     inputs = tokenizer(sentences_batch, padding=True, truncation=True, max_length=max_seq_length, add_special_tokens=False, return_tensors='pt').to(self.plm_model.device)
690 |                     embeddings = self.get_sentence_embedding(self.embedding_method, **inputs)
691 |                     embeddings = embeddings.detach()
692 |                     if convert_to_numpy:
693 |                         if embeddings.dtype == torch.bfloat16:
694 |                             embeddings = embeddings.cpu().to(torch.float32)
695 |                         else:
696 |                             embeddings = embeddings.cpu()
697 |                     all_embeddings.extend(embeddings)
698 |         all_embeddings = [all_embeddings[idx] for idx in np.argsort(length_sorted_idx)]
699 |         if convert_to_tensor:
700 |             all_embeddings = torch.stack(all_embeddings)
701 |         elif convert_to_numpy:
702 |             all_embeddings = np.asarray([emb.numpy() for emb in all_embeddings])
703 |
704 |         if input_is_string:
705 |             all_embeddings = all_embeddings[0]
706 |         return all_embeddings
707 |
708 |
709 | class MAB_POST(nn.Module):
710 |     def __init__(self, dim_Q, dim_K, dim_V, num_heads, ln=False):
711 |         super(MAB_POST, self).__init__()
712 |         self.dim_V = dim_V
713 |         self.num_heads = num_heads
714 |         self.fc_q = nn.Linear(dim_Q, dim_V)
715 |         self.fc_k = nn.Linear(dim_K, dim_V)
716 |         self.fc_v = nn.Linear(dim_K, dim_V)
717 |
718 |         if ln:
719 |             self.ln0 = nn.LayerNorm(dim_V)
720 |             self.ln1 = nn.LayerNorm(dim_V)
721 |         self.fc_o = nn.Linear(dim_V, dim_V)
722 |         nn.init.xavier_uniform_(self.fc_q.weight)
723 |         nn.init.xavier_uniform_(self.fc_k.weight)
724 |         nn.init.xavier_uniform_(self.fc_v.weight)
725 |         nn.init.xavier_uniform_(self.fc_o.weight)
726 |
727 |     def forward(self, Q, K, pad_mask=None):
728 |
729 |         Q_ = self.fc_q(Q)
730 |         K_, V_ = self.fc_k(K), self.fc_v(K)
731 |         dim_split = self.dim_V // self.num_heads
732 |         Q_ = torch.cat(Q_.split(dim_split, 2), 0)
733 |         K_ = torch.cat(K_.split(dim_split, 2), 0)
734 |         V_ = torch.cat(V_.split(dim_split, 2), 0)
735 |
736 |         pad_mask = pad_mask.unsqueeze(1).repeat(self.num_heads, Q.size(1), 1)
737 |         score = Q_.bmm(K_.transpose(1,2))/math.sqrt(self.dim_V)
738 |         score = score.masked_fill(pad_mask == 0, -1e12)
739 |         A = torch.softmax(score, 2)
740 |         A = A * pad_mask
741 |         O = torch.cat(A.bmm(V_).split(Q.size(0), 0), 2)
742 |         O = Q + O
743 |         O = O if getattr(self, 'ln0', None) is None else self.ln0(O)
744 |         O = O + F.relu(self.fc_o(O))
745 |         O = O if getattr(self, 'ln1', None) is None else self.ln1(O)
746 |         return O
747 |
748 |
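# MAB_POST above normalizes after the residual additions (post-LN). The two variants
# below normalize before the projections instead: MAB_PRE_NORMAL is a conventional
# pre-LN block with a final LayerNorm, while MAB_PRE_GPTJ computes the attention and
# feed-forward branches in parallel from the normalized input and adds both back onto
# the residual stream.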
749 | class MAB_PRE_NORMAL(nn.Module):
750 |     def __init__(self, dim_Q, dim_K, dim_V, num_heads, ln=False):
751 |         super(MAB_PRE_NORMAL, self).__init__()
752 |         self.dim_V = dim_V
753 |         self.num_heads = num_heads
754 |         self.fc_q = nn.Linear(dim_Q, dim_V)
755 |         self.fc_k = nn.Linear(dim_K, dim_V)
756 |         self.fc_v = nn.Linear(dim_K, dim_V)
757 |
758 |         if ln:
759 |             self.ln_q = nn.LayerNorm(dim_V)
760 |             self.ln_kv = nn.LayerNorm(dim_V)
761 |             self.ln_o = nn.LayerNorm(dim_V)
762 |             self.ln_final = nn.LayerNorm(dim_V)
763 |
764 |         self.fc_o = nn.Linear(dim_V, dim_V)
765 |         nn.init.xavier_uniform_(self.fc_q.weight)
766 |         nn.init.xavier_uniform_(self.fc_k.weight)
767 |         nn.init.xavier_uniform_(self.fc_v.weight)
768 |         nn.init.xavier_uniform_(self.fc_o.weight)
769 |
770 |     def forward(self, Q, K, pad_mask=None):
771 |         Q_ = Q if getattr(self, 'ln_q', None) is None else self.ln_q(Q)
772 |         K_ = K if getattr(self, 'ln_kv', None) is None else self.ln_kv(K)
773 |         Q_ = self.fc_q(Q_)
774 |         K_, V_ = self.fc_k(K_), self.fc_v(K_)
775 |         dim_split = self.dim_V // self.num_heads
776 |         Q_ = torch.cat(Q_.split(dim_split, 2), 0)
777 |         K_ = torch.cat(K_.split(dim_split, 2), 0)
778 |         V_ = torch.cat(V_.split(dim_split, 2), 0)
779 |         pad_mask = pad_mask.unsqueeze(1).repeat(self.num_heads, Q.size(1), 1)
780 |         score = Q_.bmm(K_.transpose(1,2))/math.sqrt(self.dim_V)
781 |         score = score.masked_fill(pad_mask == 0, -1e12)
782 |         A = torch.softmax(score, 2)
783 |         A = A * pad_mask
784 |         O = torch.cat(A.bmm(V_).split(Q.size(0), 0), 2)
785 |         O = Q + O
786 |         O_ = O if getattr(self, 'ln_o', None) is None else self.ln_o(O)
787 |         O_ = O + F.relu(self.fc_o(O_))
788 |         return O_ if getattr(self, 'ln_final', None) is None else self.ln_final(O_)
789 |
790 |
791 | class MAB_PRE_GPTJ(nn.Module):
792 |     def __init__(self, dim_Q, dim_K, dim_V, num_heads, ln=False):
793 |         super(MAB_PRE_GPTJ, self).__init__()
794 |         self.dim_V = dim_V
795 |         self.num_heads = num_heads
796 |         self.fc_q = nn.Linear(dim_Q, dim_V)
797 |         self.fc_k = nn.Linear(dim_K, dim_V)
798 |         self.fc_v = nn.Linear(dim_K, dim_V)
799 |         self.fc_o = nn.Linear(dim_V, dim_V)
800 |
801 |         nn.init.xavier_uniform_(self.fc_q.weight)
802 |         nn.init.xavier_uniform_(self.fc_k.weight)
803 |         nn.init.xavier_uniform_(self.fc_v.weight)
804 |         nn.init.xavier_uniform_(self.fc_o.weight)
805 |         if ln:
806 |             self.ln_q = nn.LayerNorm(dim_V)
807 |             self.ln_kv = nn.LayerNorm(dim_V)
808 |             self.ln_final = nn.LayerNorm(dim_V)
809 |
810 |     def forward(self, Q, K, pad_mask=None):
811 |         Q_ = Q if getattr(self, 'ln_q', None) is None else self.ln_q(Q)
812 |         K_ = K if getattr(self, 'ln_kv', None) is None else self.ln_kv(K)
813 |
814 |         Q1 = self.fc_q(Q_)
815 |         K1, V1 = self.fc_k(K_), self.fc_v(K_)
816 |         dim_split = self.dim_V // self.num_heads
817 |
818 |         Q1 = torch.cat(Q1.split(dim_split, 2), 0)
819 |         K1 = torch.cat(K1.split(dim_split, 2), 0)
820 |         V1 = torch.cat(V1.split(dim_split, 2), 0)
821 |
822 |
823 |         pad_mask = pad_mask.unsqueeze(1).repeat(self.num_heads, Q.size(1), 1)
824 |         score = Q1.bmm(K1.transpose(1,2))/math.sqrt(self.dim_V)
825 |         score = score.masked_fill(pad_mask == 0, -1e12)
826 |         A = torch.softmax(score, 2)
827 |         A = A * pad_mask
828 |         O1 = torch.cat(A.bmm(V1).split(Q.size(0), 0), 2)
829 |         O2 = F.relu(self.fc_o(Q_))
830 |         O_final = Q + O1 + O2
831 |         return O_final if getattr(self, 'ln_final', None) is None else self.ln_final(O_final)
832 |
833 |
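# PMA (pooling by multi-head attention): a learnable seed tensor S of shape
# (1, num_seeds, compress_dim) attends over the token embeddings through the selected
# MAB variant, with pad_mask keeping padded positions out of the attention weights.
# With num_seeds=1, as constructed in __init__ above, each input sequence is pooled
# into a single fixed-size embedding.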
834 | class PMA(nn.Module):
835 |     def __init__(self, dim, compress_dim, num_heads, num_seeds, ln=False, pma_mode=None):
836 |         super(PMA, self).__init__()
837 |         self.S = nn.Parameter(torch.Tensor(1, num_seeds, compress_dim))
838 |         nn.init.xavier_uniform_(self.S)
839 |         if pma_mode == 'post_normal':
840 |             self.mab = MAB_POST(compress_dim, dim, compress_dim, num_heads, ln=ln)
841 |         elif pma_mode == 'pre_normal':
842 |             self.mab = MAB_PRE_NORMAL(compress_dim, dim, compress_dim, num_heads, ln=ln)
843 |         elif pma_mode == 'pre_gptj':
844 |             self.mab = MAB_PRE_GPTJ(compress_dim, dim, compress_dim, num_heads, ln=ln)
845 |         else:
846 |             raise ValueError(f"Error, the pma_mode {pma_mode} is not implemented !")
847 |
848 |     def forward(self, X, pad_mask):
849 |         if self.S.dtype != torch.bfloat16:
850 |             X = X.float()
851 |         return self.mab(self.S.repeat(X.size(0), 1, 1), X, pad_mask)
852 |
--------------------------------------------------------------------------------