├── F2LLM
│   ├── imgs
│   │   ├── overview.png
│   │   └── mteb_leaderboard.png
│   ├── requirements.txt
│   ├── configs
│   │   ├── config.json
│   │   └── accelerate_config.yaml
│   ├── arguments.py
│   ├── model.py
│   ├── tokenize_data_qwen.py
│   ├── README.md
│   ├── run.py
│   └── utils.py
├── CGE
│   ├── resources
│   │   ├── result.png
│   │   └── CodeFuse-AI.png
│   ├── README.md
│   └── utils
│       └── vllm_codefuse_cge_large.py
└── README.md
/F2LLM/imgs/overview.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/codefuse-ai/CodeFuse-Embeddings/main/F2LLM/imgs/overview.png
--------------------------------------------------------------------------------
/CGE/resources/result.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/codefuse-ai/CodeFuse-Embeddings/main/CGE/resources/result.png
--------------------------------------------------------------------------------
/F2LLM/requirements.txt:
--------------------------------------------------------------------------------
1 | accelerate
2 | datasets
3 | deepspeed
4 | flash-attn
5 | torch
6 | transformers
7 | tensorboard
8 |
--------------------------------------------------------------------------------
/CGE/resources/CodeFuse-AI.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/codefuse-ai/CodeFuse-Embeddings/main/CGE/resources/CodeFuse-AI.png
--------------------------------------------------------------------------------
/F2LLM/imgs/mteb_leaderboard.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/codefuse-ai/CodeFuse-Embeddings/main/F2LLM/imgs/mteb_leaderboard.png
--------------------------------------------------------------------------------
/F2LLM/configs/config.json:
--------------------------------------------------------------------------------
1 | {
2 | "model_path": "models/qwen3-4b",
3 | "experiment_id": "4b+lr.8e-6+bs.16x32+context.1024+2epochs",
4 | "train_data_path": "training_data/data_tokenized_qwen",
5 | "output_dir": "output",
6 | "tb_dir": "output/tb",
7 | "cache_dir": "cache",
8 | "train_batch_size": 16,
9 | "checkpointing_steps": 5000,
10 | "validation_steps": 5000,
11 | "max_seq_length": 1024,
12 | "learning_rate": 8e-6,
13 | "min_lr": 1e-7,
14 | "weight_decay": 0.01,
15 | "warmup_steps": 500,
16 | "train_epochs": 2,
17 | "log_interval": 100,
18 | "num_hard_neg": 7
19 | }
20 |
--------------------------------------------------------------------------------
/F2LLM/configs/accelerate_config.yaml:
--------------------------------------------------------------------------------
1 | compute_environment: LOCAL_MACHINE
2 | debug: false
3 | deepspeed_config:
4 | gradient_accumulation_steps: 1
5 | gradient_clipping: 1.0
6 | offload_optimizer_device: none
7 | offload_param_device: none
8 | zero3_init_flag: false
9 | zero_stage: 2
10 | distributed_type: DEEPSPEED
11 | downcast_bf16: "no"
12 | machine_rank: 0
13 | main_training_function: main
14 | mixed_precision: "bf16"
15 | num_machines: 1
16 | num_processes: 8
17 | rdzv_backend: static
18 | same_network: true
19 | tpu_env: []
20 | tpu_use_cluster: false
21 | tpu_use_sudo: false
22 | use_cpu: false
23 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ## CodeFuse Embeddings
2 |
3 |
4 |
5 |
6 |
7 | Embedding-related repos from CodeFuse, including:
8 |
9 | - [CGE](./CGE/README.md)
10 | - [D2LLM](https://github.com/codefuse-ai/D2LLM)
11 | - [F2LLM](./F2LLM/README.md)
12 |
13 | **Star History**
14 |
15 | [](https://www.star-history.com/#codefuse-ai/CodeFuse-Embeddings&type=date&legend=top-left)
16 |
--------------------------------------------------------------------------------
/F2LLM/arguments.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass, asdict
2 | import argparse, json
3 |
4 |
5 | @dataclass
6 | class Args:
7 |
8 | model_path: str
9 | experiment_id: str
10 | # save dir
11 | output_dir: str
12 | tb_dir: str
13 | cache_dir: str
14 | # training arguments
15 | train_data_path: str
16 | train_batch_size: int = 8
17 | max_seq_length: int = 2048
18 | learning_rate: float = 1e-4
19 | min_lr: float = 1e-6
20 | weight_decay: float = 1e-2
21 | warmup_steps: int = 100
22 | # embedding-related settings
23 | num_hard_neg: int = 7
24 | # train steps take precedence over epochs, set to -1 to disable
25 | train_steps: int = -1
26 | train_epochs: int = 5
27 | log_interval: int = 20
28 | checkpointing_steps: int = 100
29 | validation_steps: int = 100
30 |     # just a placeholder, for logging purposes
31 |     num_processes: int = 0
32 |
33 | def dict(self):
34 | return asdict(self)
35 |
36 |
37 | def parse_args():
38 | parser = argparse.ArgumentParser()
39 | parser.add_argument("--config", type=str)
40 | arg = parser.parse_args()
41 | with open(arg.config) as f:
42 | config = json.load(f)
43 | args = Args(**config)
44 | args.output_dir = f"{args.output_dir}/{args.experiment_id}"
45 | args.tb_dir = f"{args.tb_dir}/{args.experiment_id}"
46 | return args
--------------------------------------------------------------------------------
/F2LLM/model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from transformers import AutoModel, AutoTokenizer
3 |
4 |
5 | class F2LLM:
6 | def __init__(self,
7 | model_path,
8 | max_seq_length=512,
9 | args=None
10 | ):
11 |
12 | self.args = args
13 | self.dtype = torch.bfloat16
14 | self.device = None # set after accelerator.prepare
15 | self.lm = AutoModel.from_pretrained(model_path, trust_remote_code=True, torch_dtype=self.dtype, attn_implementation='flash_attention_2')
16 | self.lm.config.use_cache = False
17 | self.tokenizer = AutoTokenizer.from_pretrained(model_path)
18 | self.max_seq_length = max_seq_length
19 |
20 | def set_device(self):
21 | self.device = self.lm.device
22 |
23 | def forward(self, batch):
24 | bs = batch['bs']
25 | num_hard_neg = int((len(batch['input_ids']) - 2*bs) / bs)
26 |
27 | outputs = self.lm(batch['input_ids'],
28 | batch['attention_mask'],
29 | )
30 |
31 | passage_features_all_tokens = outputs.last_hidden_state
32 | return {
33 | 'query_passage_features': torch.stack([passage_features_all_tokens[i, [batch['seq_lens'][i]-1]] for i in range(bs)]),
34 | 'passage_passage_features': torch.stack([passage_features_all_tokens[i, [batch['seq_lens'][i]-1]] for i in range(bs, 2*bs)]),
35 | 'negative_passage_features': None if num_hard_neg == 0 else torch.stack([passage_features_all_tokens[i, [batch['seq_lens'][i]-1]] for i in range(2*bs, len(batch['seq_lens']))]).view(bs, num_hard_neg, -1)
36 | }
37 |
38 |
--------------------------------------------------------------------------------
/F2LLM/tokenize_data_qwen.py:
--------------------------------------------------------------------------------
1 | from multiprocessing import Pool
2 | import numpy as np
3 | import pandas as pd
4 | import os
5 | from transformers import AutoTokenizer
6 | from tqdm.auto import tqdm
7 |
8 |
9 | tokenizer = AutoTokenizer.from_pretrained('models/qwen3-0.6b')
10 | max_seq_length = 1023
11 |
12 |
13 | def process_sent(sentence):
14 |
15 | # We make sure there's always an eos token at the end of each sequence
16 | tokenizer_outputs = tokenizer(sentence, max_length=max_seq_length, truncation=True, add_special_tokens=False)
17 |
18 | return np.array(tokenizer_outputs.input_ids + [tokenizer.eos_token_id])
19 |
20 |
21 | def process_sent_batch(s):
22 | return s.apply(process_sent)
23 |
24 | def parallelize(data, func, num_of_processes=8):
25 | indices = np.array_split(data.index, num_of_processes)
26 | data_split = [data.iloc[idx] for idx in indices]
27 | with Pool(num_of_processes) as pool:
28 | data = pd.concat(pool.map(func, data_split))
29 | return data
30 |
31 |
32 | root_dir = 'training_data'
33 | for ds_name in tqdm(sorted(os.listdir(root_dir))):
34 | print(ds_name, flush=True)
35 |
36 | df = pd.read_parquet(f"{root_dir}/{ds_name}")
37 | df['query_input_ids'] = parallelize(df['query'], process_sent_batch, 62)
38 |
39 | num_neg = 24 if 'negative_2' in df.keys() else 1
40 |
41 | ls = df.passage.to_list()
42 | for i in range(1, num_neg+1):
43 | ls += df[f'negative_{i}'].to_list()
44 | ls = list(set(ls))
45 | df_tmp = pd.DataFrame({'text': ls})
46 | df_tmp['input_ids'] = parallelize(df_tmp['text'], process_sent_batch, 62)
47 | df_tmp = df_tmp.set_index('text')
48 |
49 | df['passage_input_ids'] = df.passage.map(df_tmp.input_ids)
50 |
51 | for i in range(1, num_neg+1):
52 | df[f'negative_{i}_input_ids'] = df[f'negative_{i}'].map(df_tmp.input_ids)
53 |
54 | df.to_parquet(f'data_tokenized_qwen/{ds_name}', index=False)
55 |
--------------------------------------------------------------------------------
/F2LLM/README.md:
--------------------------------------------------------------------------------
1 | ## F2LLM
2 |
3 | F2LLMs (Foundation-to-Feature Large Language Models) are foundation models directly finetuned on 6 million high-quality query-document pairs, striking a strong balance between model size, training cost, and embedding performance:
4 |
5 |
6 |
7 |
8 |
9 | On the MTEB leaderboard, F2LLM-4B ranks 2nd among models of ~4B size, and 7th overall, while F2LLM-1.7B ranks 1st among models of 1B-2B size.
10 |
11 |
12 |
13 |
14 |
15 | F2LLMs are fully open. Model checkpoints are available at:
16 |
17 | - [F2LLM 0.6B](https://huggingface.co/codefuse-ai/F2LLM-0.6B)
18 | - [F2LLM 1.7B](https://huggingface.co/codefuse-ai/F2LLM-1.7B)
19 | - [F2LLM 4B](https://huggingface.co/codefuse-ai/F2LLM-4B)
20 |
21 | Training data is available at [F2LLM data](https://huggingface.co/datasets/codefuse-ai/F2LLM).
22 |
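For reference, the training code in this repo extracts embeddings from the final hidden state at the last (EOS) token position (see `model.py` and `tokenize_data_qwen.py`). Below is a minimal, illustrative encoding sketch that mirrors this convention; it is not an official inference script, and the Hugging Face model cards remain the authoritative usage reference.

```
import torch
import torch.nn.functional as F
from transformers import AutoModel, AutoTokenizer

model_path = "codefuse-ai/F2LLM-0.6B"  # or F2LLM-1.7B / F2LLM-4B
tokenizer = AutoTokenizer.from_pretrained(model_path)
tokenizer.padding_side = "right"       # so the last real token sits at attention_mask.sum() - 1
model = AutoModel.from_pretrained(model_path).eval()

def encode(texts):
    # Append EOS and embed at its position, mirroring tokenize_data_qwen.py / model.py
    batch = tokenizer([t + tokenizer.eos_token for t in texts],
                      padding=True, add_special_tokens=False, return_tensors="pt")
    with torch.no_grad():
        hidden = model(**batch).last_hidden_state      # [batch, seq, dim]
    last = batch["attention_mask"].sum(dim=1) - 1       # EOS index per sequence
    emb = hidden[torch.arange(hidden.size(0)), last]    # [batch, dim]
    return F.normalize(emb, p=2, dim=-1)

query = encode(["How do I reverse a list in Python?"])
passage = encode(["Use list.reverse() for in-place reversal or reversed() for an iterator."])
print((query @ passage.T).item())                       # cosine similarity
```
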
23 | ### Train
24 |
25 | In this repo we provide a streamlined and efficient script for training embedding models. To reproduce the training of F2LLMs, please:
26 |
27 | - Set up the environment following `requirements.txt`. Note that transformers>=4.51.0 is required for training Qwen3 models.
28 | - Download the data and backbone models from Hugging Face (we use Qwen3 models).
29 | - Run `tokenize_data_qwen.py` to tokenize the downloaded data.
30 | - Modify the model path, data path, and other arguments in `configs/config.json`.
31 | - Start training with `accelerate launch --config_file configs/accelerate_config.yaml run.py --config configs/config.json`.
32 |
33 | Note: we recommend setting `num_processes` to 1 in `configs/accelerate_config.yaml` and launching the training code once to generate the cache for the training data before starting the actual training.
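
For example, this one-off cache-generation run can be launched by overriding the process count on the command line (assuming the config files are otherwise unchanged):

```
accelerate launch --config_file configs/accelerate_config.yaml --num_processes 1 run.py --config configs/config.json
```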
34 |
35 | For multi-node training, run on the main node:
36 |
37 | ```
38 | accelerate launch --config_file configs/accelerate_config.yaml --num_machines N_NODE --num_processes N_PROCESSES --machine_rank 0 --main_process_ip MASTER_IP --main_process_port MASTER_PORT run.py --config configs/config.json
39 | ```
40 |
41 | where N_NODE is the number of machines, N_PROCESSES is N_NODE\*8 (assuming 8 GPUs per node), MASTER_IP is the IP address of your master node, and MASTER_PORT is a port available on the master node (e.g. 6379).
42 |
43 | On worker nodes, also run the above command, modifying `machine_rank` accordingly.
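
For example, on the worker node of a 2-node (16-GPU) setup, the command differs from the main node's only in `--machine_rank`:

```
accelerate launch --config_file configs/accelerate_config.yaml --num_machines 2 --num_processes 16 --machine_rank 1 --main_process_ip MASTER_IP --main_process_port MASTER_PORT run.py --config configs/config.json
```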
44 |
45 | ### Citation
46 |
47 | If you use the F2LLM models, data, or code, please cite the following technical report.
48 |
49 | ```
50 | @article{2025F2LLM,
51 | title={F2LLM Technical Report: Matching SOTA Embedding Performance with 6 Million Open-Source Data},
52 | author={Ziyin Zhang and Zihan Liao and Hang Yu and Peng Di and Rui Wang},
53 | journal = {CoRR},
54 | volume = {abs/2510.02294},
55 | year = {2025},
56 | url = {https://doi.org/10.48550/arXiv.2510.02294},
57 | doi = {10.48550/ARXIV.2510.02294},
58 | eprinttype = {arXiv},
59 | eprint = {2510.02294}
60 | }
61 | ```
62 |
--------------------------------------------------------------------------------
/CGE/README.md:
--------------------------------------------------------------------------------
1 | ## CodeFuse-CGE
2 |
3 |
4 |
5 |
6 | In this project, we introduce CodeFuse-CGE (Code General Embedding), which stands out on text-to-code tasks for its powerful ability to capture the semantic relationship between text and code.
7 | This model has the following notable features:
8 | ● Instruction tuning is enabled on both the query and the code snippet sides.
9 | ● The model obtains sentence-level and code-level representations through a cross-attention computation module.
10 | ● The model has a smaller embedding dimension without significant degradation in performance.
11 |
12 | **CodeFuse-CGE-Large Model Configuration**
13 | - Hugging Face: [codefuse-ai/CodeFuse-CGE-Large](https://huggingface.co/codefuse-ai/CodeFuse-CGE-Large)
14 | - Base Model: CodeQwen1.5-7B-Chat
15 | - Model Size: 7B
16 | - Embedding Dimension: 1024
17 | - Hidden Layers: 32
18 | 
19 | **Requirements**
20 | ```
21 | flash_attn==2.4.2
22 | torch==2.1.0
23 | accelerate==0.28.0
24 | transformers==4.39.2
25 | vllm==0.5.3
26 | ```
27 |
28 |
29 | **CodeFuse-CGE-Small Model Configuration**
30 | - Hugging Face: [codefuse-ai/CodeFuse-CGE-Small](https://huggingface.co/codefuse-ai/CodeFuse-CGE-Small)
31 | - Base Model: Phi-3.5-mini-instruct
32 | - Model Size: 3.8B
33 | - Embedding Dimension: 1024
34 | - Hidden Layers: 32
35 | 
36 | **Requirements**
37 | ```
38 | flash_attn==2.4.2
39 | torch==2.1.0
40 | accelerate==0.28.0
41 | transformers>=4.43.0
42 | ```
43 |
44 |
45 | ## Benchmark the Performance
46 | We use the MRR metric to evaluate text-to-code retrieval performance on AdvTest, CosQA, and CSN.
47 |
48 | 
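
For reference, MRR (Mean Reciprocal Rank) averages the reciprocal rank of the first correct code snippet across queries. A minimal illustrative computation (not the evaluation script behind the reported numbers):

```
def mean_reciprocal_rank(ranks):
    # ranks: 1-based rank of the correct code snippet for each query
    return sum(1.0 / r for r in ranks) / len(ranks)

print(mean_reciprocal_rank([1, 2, 5]))  # (1 + 0.5 + 0.2) / 3 ≈ 0.567
```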
49 |
50 | ## How to Use
51 |
52 | You should first download the model files from Hugging Face.
53 |
54 | ### Transformers
55 | ```
56 | from transformers import AutoTokenizer, AutoModel
57 | import torch
58 | model_name_or_path = "CodeFuse-CGE-Large"
59 | model = AutoModel.from_pretrained(model_name_or_path, trust_remote_code=True)
60 | tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, trust_remote_code=True, truncation_side='right', padding_side='right')
61 |
62 | if torch.cuda.is_available():
63 | device = 'cuda'
64 | else:
65 | device = 'cpu'
66 | model.to(device)
67 |
68 | prefix_dict = {'python':{'query':'Retrieve the Python code that solves the following query:', 'passage':'Python code:'},
69 | 'java':{'query':'Retrieve the Java code that solves the following query:', 'passage':'Java code:'},
70 | 'go':{'query':'Retrieve the Go code that solves the following query:', 'passage':'Go code:'},
71 | 'c++':{'query':'Retrieve the C++ code that solves the following query:', 'passage':'C++ code:'},
72 | 'javascript':{'query':'Retrieve the Javascript code that solves the following query:', 'passage':'Javascript code:'},
73 | 'php':{'query':'Retrieve the PHP code that solves the following query:', 'passage':'PHP code:'},
74 | 'ruby':{'query':'Retrieve the Ruby code that solves the following query:', 'passage':'Ruby code:'},
75 | 'default':{'query':'Retrieve the code that solves the following query:', 'passage':'Code:'}
76 | }
77 |
78 | text = ["Writes a Boolean to the stream.",
79 | "def writeBoolean(self, n): t = TYPE_BOOL_TRUE if n is False: t = TYPE_BOOL_FALSE self.stream.write(t)"]
80 | text[0] += prefix_dict['python']['query']
81 | text[1] += prefix_dict['python']['passage']
82 | embed = model.encode(tokenizer, text)
83 | score = embed[0] @ embed[1].T
84 | print("score", score)
85 | ```
86 |
87 | ### vLLM
88 | We have also adapted vLLM to reduce latency during deployment.
89 | ```
90 | from vllm import LLM
91 | from vllm.model_executor.models import ModelRegistry
92 | 
93 | from utils.vllm_codefuse_cge_large import CodeFuse_CGE_Large
94 |
95 | def always_true_is_embedding_model(model_arch: str) -> bool:
96 | return True
97 | ModelRegistry.is_embedding_model = always_true_is_embedding_model
98 | ModelRegistry.register_model("CodeFuse_CGE_Large", CodeFuse_CGE_Large)
99 |
100 |
101 | model_name_or_path = "CodeFuse-CGE-Large"
102 | model = LLM(model=model_name_or_path, trust_remote_code=True, enforce_eager=True, enable_chunked_prefill=False)
103 | prefix_dict = {'python':{'query':'Retrieve the Python code that solves the following query:', 'passage':'Python code:'},
104 | 'java':{'query':'Retrieve the Java code that solves the following query:', 'passage':'Java code:'},
105 | 'go':{'query':'Retrieve the Go code that solves the following query:', 'passage':'Go code:'},
106 | 'c++':{'query':'Retrieve the C++ code that solves the following query:', 'passage':'C++ code:'},
107 | 'javascript':{'query':'Retrieve the Javascript code that solves the following query:', 'passage':'Javascript code:'},
108 | 'php':{'query':'Retrieve the PHP code that solves the following query:', 'passage':'PHP code:'},
109 | 'ruby':{'query':'Retrieve the Ruby code that solves the following query:', 'passage':'Ruby code:'},
110 | 'default':{'query':'Retrieve the code that solves the following query:', 'passage':'Code:'}
111 | }
112 |
113 | text = ["Return the best fit based on rsquared",
114 | "def find_best_rsquared ( list_of_fits ) : res = sorted ( list_of_fits , key = lambda x : x . rsquared ) return res [ - 1 ]"]
115 | text[0] += prefix_dict['python']['query']
116 | text[1] += prefix_dict['python']['passage']
117 | embed_0 = model.encode([text[0]])[0].outputs.embedding
118 | embed_1 = model.encode([text[1]])[0].outputs.embedding
119 | ```
120 | Note:
121 | 1. After the vLLM adaptation, the model's input can only have a batch size of 1; larger batches result in an array overflow error.
122 | 2. Only the CodeFuse-CGE-Large model has been adapted, and support for the CodeFuse-CGE-Small model will be available soon.
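
To score a query-passage pair with the embeddings returned above, cosine similarity can be computed directly. A minimal sketch, assuming `embed_0` and `embed_1` are the plain lists of floats produced by the previous snippet:

```
import torch
import torch.nn.functional as F

e0 = torch.tensor(embed_0)
e1 = torch.tensor(embed_1)
score = F.cosine_similarity(e0.unsqueeze(0), e1.unsqueeze(0)).item()
print("score", score)
```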
123 |
124 | ## Contact us
125 | Email:
126 |
127 | 
128 |
129 |
130 |
131 | ## Acknowledgement
132 | Thanks to the authors of the open-source datasets, including CSN, Adv, and CoSQA.
133 |
134 |
--------------------------------------------------------------------------------
/F2LLM/run.py:
--------------------------------------------------------------------------------
1 | from arguments import parse_args
2 | from utils import accelerate_train, CLASSIFICATION_DATASETS
3 | from transformers import (
4 | AutoTokenizer,
5 | set_seed,
6 | get_scheduler
7 | )
8 | import os, json, random
9 | from datasets import load_dataset
10 | from torch.utils.data import DataLoader
11 | from accelerate import Accelerator
12 | from accelerate.state import AcceleratorState
13 | import torch
14 | from torch.nn.utils.rnn import pad_sequence
15 | from torch.optim import AdamW
16 | from model import F2LLM
17 |
18 | os.environ["TOKENIZERS_PARALLELISM"] = "false"
19 |
20 | args = parse_args()
21 | accelerator = Accelerator()
22 | args.num_processes = accelerator.num_processes
23 | accelerator.print(args)
24 |
25 | def _stack(input_ids, max_len):
26 | data = [ids[:max_len] for ids in input_ids] # input_ids: list of lists
27 | lens = [len(x) for x in data]
28 | tensor = torch.tensor(sum(data, [])) # (total_tokens,)
29 | return tensor.split(lens) # list of 1-d tensors
30 |
31 |
32 | def collate_fn(batch_raw):
33 | '''
34 | length of input_ids: bs * (2 + num_hard_neg)
35 | 0 - bs-1: query input ids
36 | bs - 2*bs-1: passage input ids
37 | 2*bs - 2*bs+num_hard_neg-1: hard neg for sample 1
38 | 2*bs+num_hard_neg*(i-1) - 2*bs+num_hard_neg*i-1: hard neg for sample i (i from 1 to bs)
39 | '''
40 | num_hard_neg = 1 if batch_raw[0]['dataset_name'] in CLASSIFICATION_DATASETS else args.num_hard_neg
41 | # select args.num_hard_neg hard negatives from a total of 24
42 | hard_neg_indices = [0] if num_hard_neg == 1 else random.sample(list(range(24)), num_hard_neg)
43 | input_ids = _stack(
44 | [s['query_input_ids'] for s in batch_raw]+\
45 | [s['passage_input_ids'] for s in batch_raw]+\
46 | [s[f'negative_{i+1}_input_ids'] for s in batch_raw for i in hard_neg_indices],
47 | args.max_seq_length
48 | )
49 | seqlens = torch.tensor([ids.size(0) for ids in input_ids])
50 | # pad input ids to [bs, max_len]
51 | input_ids = pad_sequence(input_ids, batch_first=True, padding_value=tokenizer.pad_token_id)
52 | attention_masks = input_ids.ne(tokenizer.pad_token_id).long()
53 |
54 | return {'input_ids': input_ids, 'seq_lens': seqlens, 'attention_mask': attention_masks, 'bs': len(batch_raw), 'dataset_name': batch_raw[0]['dataset_name']}
55 |
56 |
57 | set_seed(0)
58 | if accelerator.is_main_process:
59 | os.makedirs(f"{args.output_dir}", exist_ok=True)
60 | with open(os.path.join(args.output_dir, "args.json"), "w") as f:
61 | json.dump(args.dict(), f, indent=2)
62 |
63 | train_datasets, valid_datasets = [], []
64 | for f in sorted(os.listdir(args.train_data_path)):
65 | dataset_name = f.split('.parquet')[0]
66 | dataset = load_dataset("parquet", data_files=os.path.join(args.train_data_path, f), cache_dir=args.cache_dir)['train']
67 | dataset = dataset.add_column("dataset_name", [dataset_name]*len(dataset))
68 | dataset = dataset.train_test_split(train_size=0.99, shuffle=True, seed=0)
69 | train_datasets.append((dataset_name, dataset['train']))
70 | valid_datasets.append((dataset_name, dataset['test']))
71 |
72 | tokenizer = AutoTokenizer.from_pretrained(args.model_path)
73 |
74 | train_loaders = {
75 | name: DataLoader(ds, shuffle=True, batch_size=args.train_batch_size, collate_fn=collate_fn)
76 | for name, ds in train_datasets
77 | }
78 | valid_loaders = {
79 | name: DataLoader(ds, shuffle=False, batch_size=args.train_batch_size, collate_fn=collate_fn)
80 | for name, ds in valid_datasets
81 | }
82 |
83 | class MultiLoader:
84 | """
85 | Iterates over a dict(name -> DataLoader) and returns complete batches.
86 | At every __iter__ a new random order is created;
87 | the epoch ends when every loader is exhausted once.
88 | """
89 | def __init__(self, loader_dict):
90 | self.loader_dict = loader_dict
91 | for k, v in self.loader_dict.items():
92 | self.loader_dict[k] = accelerator.prepare(v)
93 |
94 | def __len__(self):
95 | return sum(len(v) for v in self.loader_dict.values())
96 |
97 | def reset_epoch(self, epoch):
98 | self.rng = random.Random(epoch)
99 | self.iters = {k: iter(v) for k, v in self.loader_dict.items()}
100 | self.names = list(self.iters.keys())
101 | self.weights = [len(self.loader_dict[k]) for k in self.names]
102 |
103 | def __iter__(self):
104 | while self.names: # until every DataLoader is empty
105 | name = self.rng.choices(self.names, weights=self.weights)[0] # pick a data-source at random
106 | try:
107 | batch = next(self.iters[name])
108 | yield batch
109 | except StopIteration:
110 | idx = self.names.index(name)
111 | self.names.pop(idx) # this dataset has no batch left
112 | self.weights.pop(idx)
113 |
114 |
115 | # determine training steps
116 | override_train_step = False
117 | if args.train_steps < 0:
118 | args.train_steps = sum(len(v) for v in train_loaders.values()) * args.train_epochs
119 | override_train_step = True
120 |
121 | accelerator.print(f"******************************** Training step before prepare: {args.train_steps} ********************************")
122 | model = F2LLM(args.model_path, args.max_seq_length, args=args)
123 | model.lm.gradient_checkpointing_enable()
124 | # set seed again to make sure that different models share the same seed
125 | set_seed(0)
126 |
127 | optimizer = AdamW(model.lm.parameters(),
128 | weight_decay=args.weight_decay,
129 | lr=args.learning_rate,
130 | betas=(0.9, 0.98))
131 |
132 | lr_scheduler = get_scheduler("cosine",
133 | optimizer=optimizer,
134 | num_warmup_steps=args.warmup_steps,
135 | num_training_steps=args.train_steps)
136 |
137 | AcceleratorState().deepspeed_plugin.deepspeed_config['train_micro_batch_size_per_gpu'] = args.train_batch_size
138 | model.lm, optimizer, lr_scheduler = accelerator.prepare(
139 | model.lm, optimizer, lr_scheduler
140 | )
141 | model.set_device()
142 | train_dataloader = MultiLoader(train_loaders)
143 | for k, v in valid_loaders.items():
144 | valid_loaders[k] = accelerator.prepare(v)
145 |
146 | # if training on multiple GPUs, length of dataloader would have changed
147 | if override_train_step:
148 | args.train_steps = len(train_dataloader) * args.train_epochs
149 | accelerator.print(f"******************************** Training step after prepare: {args.train_steps} ********************************")
150 |
151 |
152 | accelerate_train(args, accelerator, model, train_dataloader, valid_loaders,
153 | optimizer, lr_scheduler, len(dataset))
--------------------------------------------------------------------------------
/F2LLM/utils.py:
--------------------------------------------------------------------------------
1 | from tqdm.auto import tqdm
2 | from torch.utils.tensorboard import SummaryWriter
3 | import torch
4 | import torch.nn.functional as F
5 | from torch.nn import CrossEntropyLoss
6 | import os
7 |
8 | CLASSIFICATION_DATASETS = ['amazon_counterfactual', 'amazon_polarity', 'imdb', 'toxic_conversations', 'cola']
9 | CLUSTERING_DATASETS = ['amazon_reviews', 'banking77', 'emotion', 'mtop_intent', 'mtop_domain', 'massive_scenario', 'massive_intent', 'tweet_sentiment_extraction', 'arxiv_clustering_p2p', 'arxiv_clustering_s2s', 'biorxiv_clustering_p2p', 'biorxiv_clustering_s2s', 'medrxiv_clustering_p2p', 'medrxiv_clustering_s2s', 'reddit_clustering_p2p', 'reddit_clustering_s2s', 'stackexchange_clustering_p2p', 'stackexchange_clustering_s2s', 'twentynewsgroups']
10 | RETRIEVAL_DATASETS = ['arguana', 'snli', 'mnli', 'anli', 'paq', 'squad', 'stackexchange', 'msmarco', 'natural_questions', 'hotpotqa', 'fever', 'eli5', 'fiqa', 'bioasq', 'nfcorpus', 'miracl', 'mrtidy', 'scifact', 'qqp', 'stackoverflowdupquestions', 'sts12', 'sts22', 'stsbenchmark', 'amazon_qa', 'cnn_dm', 'coliee', 'paq_part2', 'pubmedqa', 's2orc_abstract_citation', 's2orc_title_abstract', 's2orc_title_citation', 'sentence_compression', 'specter', 'triviaqa', 'xsum', 'stackexchange_part2', 'stackexchangedupquestions_s2s', 'stackexchangedupquestions_p2p']
11 |
12 |
13 | def write_tensorboard(summary_writer: SummaryWriter, log_dict: dict, completed_steps):
14 | for key, value in log_dict.items():
15 | summary_writer.add_scalar(key, value, completed_steps)
16 |
17 |
18 | def save_checkpoint(args, accelerator, model, output_dir, lr_scheduler):
19 | accelerator.wait_for_everyone()
20 | accelerator.print(f"Saving checkpoint to {output_dir}")
21 |
22 | if accelerator.is_main_process:
23 | model.tokenizer.save_pretrained(output_dir)
24 | unwrapped_model = accelerator.unwrap_model(model.lm)
25 | unwrapped_model.save_pretrained(
26 | output_dir,
27 | is_main_process=accelerator.is_main_process,
28 | save_function=accelerator.save,
29 | state_dict=accelerator.get_state_dict(model.lm), # this is required for zero 3
30 | )
31 | accelerator.wait_for_everyone()
32 |
33 |
34 | def inbatch_loss(
35 | query_embeddings, # [bs, d]
36 | context_embeddings, # [bs, d]
37 | criterion,
38 | accelerator,
39 | temperature=0.05,
40 | ):
41 |
42 | bs = query_embeddings.size(0)
43 | a_norm = F.normalize(query_embeddings, p=2, dim=-1)
44 | # b_norm = torch.nn.functional.normalize(context_embeddings, p=2, dim=-1)
45 | b_cross_gpus = accelerator.gather(context_embeddings) # [bs*process, d]
46 | # print((context_embeddings - b_cross_gpus[bs * accelerator.process_index : bs * accelerator.process_index+bs]).abs().sum())
47 | b_norm_cross_gpus = F.normalize(b_cross_gpus, p=2, dim=-1) # ()
48 |
49 | student_logits = torch.matmul(a_norm, b_norm_cross_gpus.t()) / temperature # [bs, bs*process]
50 |
51 | labels = torch.arange(bs, device=student_logits.device) + bs * accelerator.process_index
52 | loss_bs = criterion(student_logits, labels) # (bs)
53 |
54 | loss = loss_bs.mean()
55 |
56 | return loss
57 |
58 | def hard_loss(
59 | query_embeddings, # [bs, d]
60 | context_embeddings, # [bs, d]
61 | hard_neg_embeddings, # [bs, num, d]
62 | criterion,
63 | accelerator,
64 | temperature=0.05,
65 | ):
66 |
67 | if hard_neg_embeddings is None:
68 | return 0.0
69 |
70 | bs = query_embeddings.size(0)
71 | a_norm = F.normalize(query_embeddings, p=2, dim=-1)
72 |
73 | hard_neg_embeddings = torch.concat([
74 | context_embeddings.unsqueeze(1),
75 | hard_neg_embeddings
76 | ], dim=1) # [bs, num_hard+1, d]
77 |
78 | hard_norm = F.normalize(hard_neg_embeddings, p=2, dim=-1)
79 | logits = (a_norm.unsqueeze(1) * hard_norm).sum(-1) / temperature # [bs, num_hard+1]
80 |
81 | loss_hard = criterion(logits, torch.zeros((bs), dtype=torch.long, device=logits.device)).mean()
82 |
83 | return loss_hard
84 |
85 |
86 | def validate(args, accelerator, model, valid_loader_dict, criterion, completed_steps, summary_writer):
87 | eval_log_dict = {}
88 | for dataset_name, valid_dataloader in valid_loader_dict.items():
89 | loss_ls, loss_hard_ls = [], []
90 | for batch in valid_dataloader:
91 | with torch.no_grad():
92 | outputs = model.forward(batch)
93 | loss_hard = hard_loss(outputs['query_passage_features'].squeeze(1), outputs['passage_passage_features'].squeeze(1), outputs['negative_passage_features'], criterion, accelerator)
94 | loss_hard_ls.append(accelerator.gather(loss_hard).float())
95 | if dataset_name in RETRIEVAL_DATASETS:
96 | loss = inbatch_loss(outputs['query_passage_features'].squeeze(1), outputs['passage_passage_features'].squeeze(1), criterion, accelerator)
97 | loss_ls.append(accelerator.gather(loss).float())
98 |
99 | accelerator.wait_for_everyone()
100 | loss_hard_ls = torch.cat(loss_hard_ls)
101 | eval_log_dict[f'{dataset_name}/valid_loss_hard'] = loss_hard_ls.mean()
102 | if dataset_name in RETRIEVAL_DATASETS:
103 | loss_ls = torch.cat(loss_ls)
104 | eval_log_dict[f"{dataset_name}/valid_loss_in_batch"] = loss_ls.mean()
105 |
106 | eval_log_dict['Avg/retrieval/valid_loss_in_batch'] = torch.tensor([v for k, v in eval_log_dict.items() if k.split('/')[0] in RETRIEVAL_DATASETS and k.endswith('valid_loss_in_batch')]).mean()
107 | eval_log_dict['Avg/retrieval/valid_loss_hard'] = torch.tensor([v for k, v in eval_log_dict.items() if k.split('/')[0] in RETRIEVAL_DATASETS and k.endswith('valid_loss_hard')]).mean()
108 | eval_log_dict['Avg/classification/valid_loss_hard'] = torch.tensor([v for k, v in eval_log_dict.items() if k.split('/')[0] in CLASSIFICATION_DATASETS]).mean()
109 | eval_log_dict['Avg/clustering/valid_loss_hard'] = torch.tensor([v for k, v in eval_log_dict.items() if k.split('/')[0] in CLUSTERING_DATASETS]).mean()
110 | if accelerator.is_main_process:
111 | write_tensorboard(summary_writer, eval_log_dict, completed_steps)
112 | accelerator.print(f"[Validation] Step = {completed_steps}")
113 |
114 |
115 | def accelerate_train(args,
116 | accelerator,
117 | model,
118 | train_dataloader,
119 | valid_loader_dict,
120 | optimizer,
121 | lr_scheduler,
122 | num_train_samples):
123 | accelerator.print("**************************************** Start training ****************************************")
124 | accelerator.print(f" Num train samples = {num_train_samples}")
125 | accelerator.print(f" Num epochs = {args.train_epochs}")
126 | accelerator.print(f" Per device batch size = {args.train_batch_size}")
127 | accelerator.print(f" Global batch size = {args.train_batch_size * accelerator.num_processes}")
128 | accelerator.print(f" Step per epoch = {len(train_dataloader)}")
129 | accelerator.print(f" Total training steps = {args.train_steps}")
130 | accelerator.print("************************************************************************************************")
131 | global RETRIEVAL_DATASETS, CLASSIFICATION_DATASETS, CLUSTERING_DATASETS
132 | RETRIEVAL_DATASETS = [ds for ds in RETRIEVAL_DATASETS if ds in train_dataloader.loader_dict.keys()]
133 | CLASSIFICATION_DATASETS = [ds for ds in CLASSIFICATION_DATASETS if ds in train_dataloader.loader_dict.keys()]
134 | CLUSTERING_DATASETS = [ds for ds in CLUSTERING_DATASETS if ds in train_dataloader.loader_dict.keys()]
135 |
136 | summary_writer = SummaryWriter(log_dir=args.tb_dir) if accelerator.is_main_process else None
137 | criterion = CrossEntropyLoss(reduction='none')
138 | pbar = tqdm(range(args.train_steps), disable=not accelerator.is_local_main_process)
139 | completed_steps = 0
140 | loss_dict = {ds_name: torch.tensor(0.0, device=model.lm.device) for ds_name in RETRIEVAL_DATASETS}
141 | loss_hard_dict = {ds_name: torch.tensor(0.0, device=model.lm.device) for ds_name in train_dataloader.loader_dict.keys()}
142 | count_dict = {ds_name: torch.tensor(0, device=model.lm.device) for ds_name in RETRIEVAL_DATASETS}
143 | count_hard_dict = {ds_name: torch.tensor(0, device=model.lm.device) for ds_name in train_dataloader.loader_dict.keys()}
144 |
145 | model.lm.train()
146 | for epoch in range(args.train_epochs):
147 | accelerator.print(f"*************** Starting epoch {epoch+1} ***************")
148 | train_dataloader.reset_epoch(epoch)
149 | for batch in train_dataloader:
150 | # forward and compute loss
151 | outputs = model.forward(batch)
152 | # passage features: [bs, 1, d]
153 | # hard_neg_features: [bs, num_hard_neg, d]
154 |
155 | loss_hard = hard_loss(outputs['query_passage_features'].squeeze(1), outputs['passage_passage_features'].squeeze(1), outputs['negative_passage_features'], criterion, accelerator)
156 | dataset_name = batch['dataset_name']
157 | count_hard_dict[dataset_name] += 1
158 | loss_hard_dict[dataset_name] += loss_hard.detach().float()
159 | if dataset_name in RETRIEVAL_DATASETS:
160 | loss = inbatch_loss(outputs['query_passage_features'].squeeze(1), outputs['passage_passage_features'].squeeze(1), criterion, accelerator)
161 | count_dict[dataset_name] += 1
162 | loss_dict[dataset_name] += loss.detach().float()
163 | else:
164 | loss = 0.0
165 |
166 | loss_total = loss + loss_hard
167 |
168 | # backward, optimizer, scheduler
169 | accelerator.backward(loss_total)
170 | optimizer.step()
171 | lr_scheduler.step()
172 | optimizer.zero_grad()
173 | if optimizer.param_groups[0]['lr'] < args.min_lr:
174 | for i in range(len(optimizer.param_groups)):
175 | optimizer.param_groups[i]['lr'] = args.min_lr
176 |
177 | # log
178 | completed_steps += 1
179 | if completed_steps % args.log_interval == 0:
180 | pbar.update(args.log_interval)
181 |
182 | train_log_dict = {"lr": optimizer.param_groups[0]['lr']}
183 | for k in loss_dict.keys():
184 | count = accelerator.gather(count_dict[k]).sum()
185 | if count > 0:
186 | train_log_dict[f"{k}/training_loss_in_batch"] = accelerator.gather(loss_dict[k]).sum() / count
187 | for k in loss_hard_dict.keys():
188 | count = accelerator.gather(count_hard_dict[k]).sum()
189 | if count > 0:
190 | train_log_dict[f"{k}/training_loss_hard"] = accelerator.gather(loss_hard_dict[k]).sum() / count
191 | train_log_dict['Avg/retrieval/training_loss_in_batch'] = torch.tensor([v for k, v in train_log_dict.items() if k.split('/')[0] in RETRIEVAL_DATASETS and k.endswith('training_loss_in_batch')]).mean()
192 | train_log_dict['Avg/retrieval/training_loss_hard'] = torch.tensor([v for k, v in train_log_dict.items() if k.split('/')[0] in RETRIEVAL_DATASETS and k.endswith('training_loss_hard')]).mean()
193 | train_log_dict['Avg/classification/training_loss_hard'] = torch.tensor([v for k, v in train_log_dict.items() if k.split('/')[0] in CLASSIFICATION_DATASETS]).mean()
194 | train_log_dict['Avg/clustering/training_loss_hard'] = torch.tensor([v for k, v in train_log_dict.items() if k.split('/')[0] in CLUSTERING_DATASETS]).mean()
195 |
196 | accelerator.print(f"[Train] Step = {completed_steps}")
197 | if accelerator.is_main_process:
198 | write_tensorboard(summary_writer, train_log_dict, completed_steps)
199 | loss_dict = {ds_name: torch.tensor(0.0, device=model.lm.device) for ds_name in RETRIEVAL_DATASETS}
200 | loss_hard_dict = {ds_name: torch.tensor(0.0, device=model.lm.device) for ds_name in train_dataloader.loader_dict.keys()}
201 | count_dict = {ds_name: torch.tensor(0, device=model.lm.device) for ds_name in RETRIEVAL_DATASETS}
202 | count_hard_dict = {ds_name: torch.tensor(0, device=model.lm.device) for ds_name in train_dataloader.loader_dict.keys()}
203 |
204 | # validation
205 | if completed_steps % args.validation_steps == 0:
206 | model.lm.eval()
207 | validate(args, accelerator, model, valid_loader_dict, criterion, completed_steps, summary_writer)
208 | model.lm.train()
209 |
210 | # step checkpoint
211 | if args.checkpointing_steps and completed_steps % args.checkpointing_steps == 0:
212 | output_dir = os.path.join(args.output_dir, f"step_{completed_steps}")
213 | save_checkpoint(args, accelerator, model, output_dir, lr_scheduler)
214 |
215 | if completed_steps >= args.train_steps:
216 | break
217 |
218 | # epoch checkpoint
219 | output_dir = os.path.join(args.output_dir, f"epoch_{epoch+1}")
220 | save_checkpoint(args, accelerator, model, output_dir, lr_scheduler)
221 | if completed_steps % args.validation_steps != 0:
222 | model.lm.eval()
223 | validate(args, accelerator, model, valid_loader_dict, criterion, completed_steps, summary_writer)
224 | model.lm.train()
225 |
226 | if summary_writer:
227 | summary_writer.close()
--------------------------------------------------------------------------------
/CGE/utils/vllm_codefuse_cge_large.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Adapted from
3 | # https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/qwen2/modeling_qwen2.py
4 | # Copyright 2024 The Qwen team.
5 | # Copyright 2023 The vLLM team.
6 | # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
7 | #
8 | # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
9 | # and OPT implementations in this library. It has been modified from its
10 | # original forms to accommodate minor architectural differences compared
11 | # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
12 | #
13 | # Licensed under the Apache License, Version 2.0 (the "License");
14 | # you may not use this file except in compliance with the License.
15 | # You may obtain a copy of the License at
16 | #
17 | # http://www.apache.org/licenses/LICENSE-2.0
18 | #
19 | # Unless required by applicable law or agreed to in writing, software
20 | # distributed under the License is distributed on an "AS IS" BASIS,
21 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
22 | # See the License for the specific language governing permissions and
23 | # limitations under the License.
24 | """Inference-only Qwen2 model compatible with HuggingFace weights."""
25 | from typing import Iterable, List, Optional, Tuple
26 | from vllm.sequence import EmbeddingSequenceGroupOutput, PoolerOutput
27 | import torch
28 | from torch import nn
29 | from transformers import Qwen2Config
30 | from transformers import PretrainedConfig
31 | from vllm.attention import Attention, AttentionMetadata
32 | from vllm.config import CacheConfig, LoRAConfig
33 | from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
34 | from vllm.model_executor.layers.activation import SiluAndMul
35 | from vllm.model_executor.layers.layernorm import RMSNorm
36 | from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
37 | QKVParallelLinear,
38 | RowParallelLinear)
39 | from vllm.model_executor.layers.logits_processor import LogitsProcessor
40 | from vllm.model_executor.layers.quantization.base_config import (
41 | QuantizationConfig)
42 | from vllm.model_executor.layers.rotary_embedding import get_rope
43 | from vllm.model_executor.layers.sampler import Sampler
44 | from vllm.model_executor.layers.vocab_parallel_embedding import (
45 | ParallelLMHead, VocabParallelEmbedding)
46 | from vllm.model_executor.model_loader.weight_utils import (
47 | default_weight_loader, maybe_remap_kv_scale_name)
48 | from vllm.model_executor.sampling_metadata import SamplingMetadata
49 | from vllm.sequence import IntermediateTensors, SamplerOutput
50 |
51 | from vllm.model_executor.models.interfaces import SupportsLoRA
52 | from vllm.model_executor.models.utils import is_pp_missing_parameter, make_layers
53 | from typing import Iterable, List, Optional, Tuple
54 |
55 | import torch
56 | from torch import nn
57 |
58 | from vllm.attention import AttentionMetadata
59 | from vllm.model_executor.layers.pooler import Pooler, PoolingType
60 | from vllm.model_executor.model_loader.weight_utils import default_weight_loader
61 | from vllm.model_executor.models.llama import LlamaModel
62 | from vllm.model_executor.pooling_metadata import PoolingMetadata
63 | from vllm.sequence import PoolerOutput
64 | import math
65 | import sys
66 | import torch
67 | import torch.nn as nn
68 | import torch.nn.functional as F
69 |
70 |
71 | class Qwen2MLP(nn.Module):
72 |
73 | def __init__(
74 | self,
75 | hidden_size: int,
76 | intermediate_size: int,
77 | hidden_act: str,
78 | quant_config: Optional[QuantizationConfig] = None,
79 | ) -> None:
80 | super().__init__()
81 | self.gate_up_proj = MergedColumnParallelLinear(
82 | hidden_size, [intermediate_size] * 2,
83 | bias=False,
84 | quant_config=quant_config)
85 | self.down_proj = RowParallelLinear(intermediate_size,
86 | hidden_size,
87 | bias=False,
88 | quant_config=quant_config)
89 | if hidden_act != "silu":
90 | raise ValueError(f"Unsupported activation: {hidden_act}. "
91 | "Only silu is supported for now.")
92 | self.act_fn = SiluAndMul()
93 |
94 | def forward(self, x):
95 | gate_up, _ = self.gate_up_proj(x)
96 | x = self.act_fn(gate_up)
97 | x, _ = self.down_proj(x)
98 | return x
99 |
100 |
101 | class Qwen2Attention(nn.Module):
102 |
103 | def __init__(self,
104 | hidden_size: int,
105 | num_heads: int,
106 | num_kv_heads: int,
107 | max_position: int = 4096 * 32,
108 | rope_theta: float = 10000,
109 | cache_config: Optional[CacheConfig] = None,
110 | quant_config: Optional[QuantizationConfig] = None,
111 | rope_scaling: Optional[Tuple] = None) -> None:
112 | super().__init__()
113 | self.hidden_size = hidden_size
114 | tp_size = get_tensor_model_parallel_world_size()
115 | self.total_num_heads = num_heads
116 | assert self.total_num_heads % tp_size == 0
117 | self.num_heads = self.total_num_heads // tp_size
118 | self.total_num_kv_heads = num_kv_heads
119 | if self.total_num_kv_heads >= tp_size:
120 | # Number of KV heads is greater than TP size, so we partition
121 | # the KV heads across multiple tensor parallel GPUs.
122 | assert self.total_num_kv_heads % tp_size == 0
123 | else:
124 | # Number of KV heads is less than TP size, so we replicate
125 | # the KV heads across multiple tensor parallel GPUs.
126 | assert tp_size % self.total_num_kv_heads == 0
127 | self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
128 | self.head_dim = hidden_size // self.total_num_heads
129 | self.q_size = self.num_heads * self.head_dim
130 | self.kv_size = self.num_kv_heads * self.head_dim
131 | self.scaling = self.head_dim**-0.5
132 | self.rope_theta = rope_theta
133 |
134 | self.qkv_proj = QKVParallelLinear(
135 | hidden_size,
136 | self.head_dim,
137 | self.total_num_heads,
138 | self.total_num_kv_heads,
139 | bias=True,
140 | quant_config=quant_config,
141 | )
142 | self.o_proj = RowParallelLinear(
143 | self.total_num_heads * self.head_dim,
144 | hidden_size,
145 | bias=False,
146 | quant_config=quant_config,
147 | )
148 |
149 | self.rotary_emb = get_rope(
150 | self.head_dim,
151 | rotary_dim=self.head_dim,
152 | max_position=max_position,
153 | base=self.rope_theta,
154 | rope_scaling=rope_scaling,
155 | )
156 | self.attn = Attention(self.num_heads,
157 | self.head_dim,
158 | self.scaling,
159 | num_kv_heads=self.num_kv_heads,
160 | cache_config=cache_config,
161 | quant_config=quant_config)
162 |
163 | def forward(
164 | self,
165 | positions: torch.Tensor,
166 | hidden_states: torch.Tensor,
167 | kv_cache: torch.Tensor,
168 | attn_metadata: AttentionMetadata,
169 | ) -> torch.Tensor:
170 | qkv, _ = self.qkv_proj(hidden_states)
171 | q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
172 | q, k = self.rotary_emb(positions, q, k)
173 | attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
174 | output, _ = self.o_proj(attn_output)
175 | return output
176 |
177 |
178 | class Qwen2DecoderLayer(nn.Module):
179 |
180 | def __init__(
181 | self,
182 | config: Qwen2Config,
183 | cache_config: Optional[CacheConfig] = None,
184 | quant_config: Optional[QuantizationConfig] = None,
185 | ) -> None:
186 | super().__init__()
187 | self.hidden_size = config.hidden_size
188 | # Requires transformers > 4.32.0
189 | rope_theta = getattr(config, "rope_theta", 1000000)
190 | rope_scaling = getattr(config, "rope_scaling", None)
191 | self.self_attn = Qwen2Attention(
192 | hidden_size=self.hidden_size,
193 | num_heads=config.num_attention_heads,
194 | max_position=config.max_position_embeddings,
195 | num_kv_heads=config.num_key_value_heads,
196 | rope_theta=rope_theta,
197 | cache_config=cache_config,
198 | quant_config=quant_config,
199 | rope_scaling=rope_scaling)
200 | self.mlp = Qwen2MLP(
201 | hidden_size=self.hidden_size,
202 | intermediate_size=config.intermediate_size,
203 | hidden_act=config.hidden_act,
204 | quant_config=quant_config,
205 | )
206 | self.input_layernorm = RMSNorm(config.hidden_size,
207 | eps=config.rms_norm_eps)
208 | self.post_attention_layernorm = RMSNorm(config.hidden_size,
209 | eps=config.rms_norm_eps)
210 |
211 | def forward(
212 | self,
213 | positions: torch.Tensor,
214 | hidden_states: torch.Tensor,
215 | kv_cache: torch.Tensor,
216 | attn_metadata: AttentionMetadata,
217 | residual: Optional[torch.Tensor],
218 | ) -> Tuple[torch.Tensor, torch.Tensor]:
219 | # Self Attention
220 | if residual is None:
221 | residual = hidden_states
222 | hidden_states = self.input_layernorm(hidden_states)
223 | else:
224 | hidden_states, residual = self.input_layernorm(
225 | hidden_states, residual)
226 | hidden_states = self.self_attn(
227 | positions=positions,
228 | hidden_states=hidden_states,
229 | kv_cache=kv_cache,
230 | attn_metadata=attn_metadata,
231 | )
232 |
233 | # Fully Connected
234 | hidden_states, residual = self.post_attention_layernorm(
235 | hidden_states, residual)
236 | hidden_states = self.mlp(hidden_states)
237 | return hidden_states, residual
238 |
239 |
240 | class Qwen2Model(nn.Module):
241 |
242 | def __init__(
243 | self,
244 | config: Qwen2Config,
245 | cache_config: Optional[CacheConfig] = None,
246 | quant_config: Optional[QuantizationConfig] = None,
247 | prefix: str = "",
248 | ) -> None:
249 | super().__init__()
250 | self.config = config
251 | self.padding_idx = config.pad_token_id
252 | self.vocab_size = config.vocab_size
253 |
254 | self.embed_tokens = VocabParallelEmbedding(
255 | config.vocab_size,
256 | config.hidden_size,
257 | quant_config=quant_config,
258 | )
259 | self.start_layer, self.end_layer, self.layers = make_layers(
260 | config.num_hidden_layers,
261 | lambda prefix: Qwen2DecoderLayer(config=config,
262 | cache_config=cache_config,
263 | quant_config=quant_config),
264 | prefix=f"{prefix}.layers",
265 | )
266 |
267 | self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
268 |
269 | def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
270 | return self.embed_tokens(input_ids)
271 |
272 | def forward(
273 | self,
274 | input_ids: torch.Tensor,
275 | positions: torch.Tensor,
276 | kv_caches: List[torch.Tensor],
277 | attn_metadata: AttentionMetadata,
278 | intermediate_tensors: Optional[IntermediateTensors] = None,
279 | inputs_embeds: Optional[torch.Tensor] = None,
280 | ) -> torch.Tensor:
281 | if get_pp_group().is_first_rank:
282 | if inputs_embeds is not None:
283 | hidden_states = inputs_embeds
284 | else:
285 | hidden_states = self.embed_tokens(input_ids)
286 | residual = None
287 | else:
288 | assert intermediate_tensors is not None
289 | hidden_states = intermediate_tensors["hidden_states"]
290 | residual = intermediate_tensors["residual"]
291 | for i in range(self.start_layer, self.end_layer):
292 | layer = self.layers[i]
293 | hidden_states, residual = layer(
294 | positions,
295 | hidden_states,
296 | kv_caches[i - self.start_layer],
297 | attn_metadata,
298 | residual,
299 | )
300 | if not get_pp_group().is_last_rank:
301 | return IntermediateTensors({
302 | "hidden_states": hidden_states,
303 | "residual": residual
304 | })
305 | hidden_states, _ = self.norm(hidden_states, residual)
306 | return hidden_states
307 |
308 |
309 | class Qwen2ForCausalLM(nn.Module, SupportsLoRA):
310 | packed_modules_mapping = {
311 | "qkv_proj": [
312 | "q_proj",
313 | "k_proj",
314 | "v_proj",
315 | ],
316 | "gate_up_proj": [
317 | "gate_proj",
318 | "up_proj",
319 | ],
320 | }
321 |
322 | # LoRA specific attributes
323 | supported_lora_modules = [
324 | "qkv_proj",
325 | "o_proj",
326 | "gate_up_proj",
327 | "down_proj",
328 | ]
329 | embedding_modules = {}
330 | embedding_padding_modules = []
331 |
332 | def __init__(
333 | self,
334 | config: Qwen2Config,
335 | cache_config: Optional[CacheConfig] = None,
336 | quant_config: Optional[QuantizationConfig] = None,
337 | lora_config: Optional[LoRAConfig] = None,
338 | ) -> None:
339 | # TODO (@robertgshaw2): see if this can be moved out
340 | if (cache_config.sliding_window is not None
341 | and hasattr(config, "max_window_layers")):
342 | raise ValueError("Sliding window for some but all layers is not "
343 | "supported. This model uses sliding window "
344 | "but `max_window_layers` = %s is less than "
345 | "`num_hidden_layers` = %s. Please open an issue "
346 | "to discuss this feature." % (
347 | config.max_window_layers,
348 | config.num_hidden_layers,
349 | ))
350 |
351 | super().__init__()
352 |
353 | self.config = config
354 | self.lora_config = lora_config
355 |
356 | self.quant_config = quant_config
357 | self.model = Qwen2Model(config, cache_config, quant_config)
358 |
359 | if config.tie_word_embeddings:
360 | self.lm_head = self.model.embed_tokens
361 | else:
362 | self.lm_head = ParallelLMHead(config.vocab_size,
363 | config.hidden_size,
364 | quant_config=quant_config)
365 |
366 | self.logits_processor = LogitsProcessor(config.vocab_size)
367 | self.sampler = Sampler()
368 |
369 | def forward(
370 | self,
371 | input_ids: torch.Tensor,
372 | positions: torch.Tensor,
373 | kv_caches: List[torch.Tensor],
374 | attn_metadata: AttentionMetadata,
375 | intermediate_tensors: Optional[IntermediateTensors] = None,
376 | ) -> torch.Tensor:
377 | hidden_states = self.model(input_ids, positions, kv_caches,
378 | attn_metadata, intermediate_tensors)
379 | return hidden_states
380 |
381 | def compute_logits(self, hidden_states: torch.Tensor,
382 | sampling_metadata: SamplingMetadata) -> torch.Tensor:
383 | logits = self.logits_processor(self.lm_head, hidden_states,
384 | sampling_metadata)
385 | return logits
386 |
387 | def make_empty_intermediate_tensors(
388 | self, batch_size: int, dtype: torch.dtype,
389 | device: torch.device) -> IntermediateTensors:
390 | return IntermediateTensors({
391 | "hidden_states":
392 | torch.zeros((batch_size, self.config.hidden_size),
393 | dtype=dtype,
394 | device=device),
395 | "residual":
396 | torch.zeros((batch_size, self.config.hidden_size),
397 | dtype=dtype,
398 | device=device),
399 | })
400 |
401 | def sample(
402 | self,
403 | logits: torch.Tensor,
404 | sampling_metadata: SamplingMetadata,
405 | ) -> Optional[SamplerOutput]:
406 | next_tokens = self.sampler(logits, sampling_metadata)
407 | return next_tokens
408 |
409 | def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
410 | stacked_params_mapping = [
411 | # (param_name, shard_name, shard_id)
412 | ("qkv_proj", "q_proj", "q"),
413 | ("qkv_proj", "k_proj", "k"),
414 | ("qkv_proj", "v_proj", "v"),
415 | ("gate_up_proj", "gate_proj", 0),
416 | ("gate_up_proj", "up_proj", 1),
417 | ]
418 | params_dict = dict(self.named_parameters(remove_duplicate=False))
419 | for name, loaded_weight in weights:
420 | if "rotary_emb.inv_freq" in name:
421 | continue
422 | if self.config.tie_word_embeddings and "lm_head.weight" in name:
423 | continue
424 | for (param_name, weight_name, shard_id) in stacked_params_mapping:
425 | if weight_name not in name:
426 | continue
427 | name = name.replace(weight_name, param_name)
428 | # Skip loading extra bias for GPTQ models.
429 | if name.endswith(".bias") and name not in params_dict:
430 | continue
431 | if is_pp_missing_parameter(name, self):
432 | continue
433 | param = params_dict[name]
434 | weight_loader = param.weight_loader
435 | weight_loader(param, loaded_weight, shard_id)
436 | break
437 | else:
438 | # Skip loading extra bias for GPTQ models.
439 | if name.endswith(".bias") and name not in params_dict:
440 | continue
441 | # Remapping the name of FP8 kv-scale.
442 | name = maybe_remap_kv_scale_name(name, params_dict)
443 | if name is None:
444 | continue
445 | if is_pp_missing_parameter(name, self):
446 | continue
447 | param = params_dict[name]
448 | weight_loader = getattr(param, "weight_loader",
449 | default_weight_loader)
450 | weight_loader(param, loaded_weight)
451 |
452 | class CodeFuse_CGE_Large(nn.Module, SupportsLoRA):
453 | packed_modules_mapping = {
454 | "qkv_proj": [
455 | "q_proj",
456 | "k_proj",
457 | "v_proj",
458 | ],
459 | "gate_up_proj": [
460 | "gate_proj",
461 | "up_proj",
462 | ],
463 | }
464 |
465 | # LoRA specific attributes
466 | supported_lora_modules = [
467 | "qkv_proj",
468 | "o_proj",
469 | "gate_up_proj",
470 | "down_proj",
471 | ]
472 | embedding_modules = {}
473 | embedding_padding_modules = []
474 |
475 | def __init__(
476 | self,
477 | config: Qwen2Config,
478 | cache_config: Optional[CacheConfig] = None,
479 | quant_config: Optional[QuantizationConfig] = None,
480 | lora_config: Optional[LoRAConfig] = None,
481 | ) -> None:
482 | # TODO (@robertgshaw2): see if this can be moved out
483 | if (cache_config.sliding_window is not None
484 | and hasattr(config, "max_window_layers")):
485 | raise ValueError("Sliding window for some but all layers is not "
486 | "supported. This model uses sliding window "
487 | "but `max_window_layers` = %s is less than "
488 | "`num_hidden_layers` = %s. Please open an issue "
489 | "to discuss this feature." % (
490 | config.max_window_layers,
491 | config.num_hidden_layers,
492 | ))
493 |
494 | super().__init__()
495 |
496 | self.config = config
497 | self.lora_config = lora_config
498 | self.quant_config = quant_config
499 | self.plm_model = Qwen2ForCausalLM(config, cache_config, quant_config)
500 | self.embedding_method = config.embedding_method
501 | self.inf_seq_length = config.inf_seq_length
502 | self.padding_side = config.padding_side
503 | self.keep_max_layer = config.keep_max_layer
504 | self.emb_dim = self.plm_model.model.embed_tokens.weight.size(1)
505 | self.num_heads = config.pma_num_heads
506 | self.ln = config.pma_ln
507 | self.norm = config.pma_norm
508 | self.pma_mode = config.pma_norm_mode
509 | self.mha_pma = PMA(self.emb_dim, self.compress_dim, self.num_heads, 1, ln=self.ln, pma_mode=self.pma_mode).to("cuda")
510 | if config.tie_word_embeddings:
511 | self.lm_head = self.plm_model.embed_tokens
512 | else:
513 | self.lm_head = ParallelLMHead(config.vocab_size,
514 | config.hidden_size,
515 | quant_config=quant_config)
516 |
517 | self.logits_processor = LogitsProcessor(config.vocab_size)
518 | self.sampler = Sampler()
519 | for param_tensor in self.mha_pma.state_dict():
520 | print(param_tensor, "\t", self.mha_pma.state_dict()[param_tensor])
521 |
522 | def forward(
523 | self,
524 | input_ids: torch.Tensor,
525 | positions: torch.Tensor,
526 | kv_caches: List[torch.Tensor],
527 | attn_metadata: AttentionMetadata,
528 | intermediate_tensors: Optional[IntermediateTensors] = None,
529 | ) -> torch.Tensor:
530 | hidden_states = self.plm_model(input_ids, positions, kv_caches,
531 | attn_metadata, intermediate_tensors)
532 |
533 | embedding = hidden_states.unsqueeze(0)
534 | res_embedding = self.pma_embedding(embedding, positions.unsqueeze(0))
535 | return res_embedding
536 |
537 | def pooler(
538 | self,
539 | hidden_states: torch.Tensor,
540 | pooling_metadata: PoolingMetadata,
541 | ) -> Optional[PoolerOutput]:
542 | hidden_states = nn.functional.normalize(hidden_states, p=2, dim=1)
543 | pooled_outputs = [
544 | EmbeddingSequenceGroupOutput(data.tolist()) for data in hidden_states
545 | ]
546 |
547 | return PoolerOutput(outputs=pooled_outputs)
548 |
549 | def compute_logits(self, hidden_states: torch.Tensor,
550 | sampling_metadata: SamplingMetadata) -> torch.Tensor:
551 | logits = self.logits_processor(self.lm_head, hidden_states,
552 | sampling_metadata)
553 | return logits
554 |
555 | def make_empty_intermediate_tensors(
556 | self, batch_size: int, dtype: torch.dtype,
557 | device: torch.device) -> IntermediateTensors:
558 | return IntermediateTensors({
559 | "hidden_states":
560 | torch.zeros((batch_size, self.config.hidden_size),
561 | dtype=dtype,
562 | device=device),
563 | "residual":
564 | torch.zeros((batch_size, self.config.hidden_size),
565 | dtype=dtype,
566 | device=device),
567 | })
568 |
569 | def sample(
570 | self,
571 | logits: torch.Tensor,
572 | sampling_metadata: SamplingMetadata,
573 | ) -> Optional[SamplerOutput]:
574 | next_tokens = self.sampler(logits, sampling_metadata)
575 | return next_tokens
576 |
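    |     # Load checkpoint weights, remapping the per-projection q/k/v and gate/up
    |     # weights onto vLLM's fused qkv_proj / gate_up_proj parameters; everything
    |     # else (including the PMA head) goes through the default weight loader.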
577 | def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
578 | stacked_params_mapping = [
579 | # (param_name, shard_name, shard_id)
580 | ("qkv_proj", "q_proj", "q"),
581 | ("qkv_proj", "k_proj", "k"),
582 | ("qkv_proj", "v_proj", "v"),
583 | ("gate_up_proj", "gate_proj", 0),
584 | ("gate_up_proj", "up_proj", 1),
585 | ]
586 | params_dict = dict(self.named_parameters(remove_duplicate=False))
587 | for name, loaded_weight in weights:
588 | if "rotary_emb.inv_freq" in name:
589 | continue
590 | if self.config.tie_word_embeddings and "lm_head.weight" in name:
591 | continue
592 | for (param_name, weight_name, shard_id) in stacked_params_mapping:
593 | if weight_name not in name:
594 | continue
595 | name = name.replace(weight_name, param_name)
596 | # Skip loading extra bias for GPTQ models.
597 | if name.endswith(".bias") and name not in params_dict:
598 | continue
599 | if is_pp_missing_parameter(name, self):
600 | continue
601 | param = params_dict[name]
602 | weight_loader = param.weight_loader
603 | weight_loader(param, loaded_weight, shard_id)
604 | break
605 | else:
606 | # Skip loading extra bias for GPTQ models.
607 | if name.endswith(".bias") and name not in params_dict:
608 | continue
609 | # Remapping the name of FP8 kv-scale.
610 | name = maybe_remap_kv_scale_name(name, params_dict)
611 | if name is None:
612 | continue
613 | if is_pp_missing_parameter(name, self):
614 | continue
615 | param = params_dict[name]
616 | weight_loader = getattr(param, "weight_loader",
617 | default_weight_loader)
618 | weight_loader(param, loaded_weight)
    |         # Log the PMA parameters after loading so their values can be verified.
619 |         for param_tensor, value in self.mha_pma.state_dict().items():
620 |             logger.debug("%s\t%s", param_tensor, value)
621 |
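    |     # Gather one hidden state per sequence at the given token index
    |     # (used for 'last'-token pooling).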
622 | def last_embedding(self, A, index):
623 | bs, seq, emb = A.size()
624 | res = A[torch.arange(bs), index, :]
625 | return res
626 |
627 | def mean_embedding(self, A, mask):
628 | bs, seq, emb = A.size()
629 | res = (A * (mask.unsqueeze(-1))).sum(1) / (mask.sum(1).unsqueeze(-1))
630 | return res
631 |
632 | # A (bs, seq, emb_size), mask (bs, 1, seq)
633 | def weighted_embedding(self, A, mask):
634 | weights = (torch.arange(start=1, end=A.size(1) + 1).unsqueeze(0).unsqueeze(-1).expand(A.size()).float()).to(A.device)
635 | input_mask_expanded = (mask.squeeze(1).unsqueeze(-1).expand(A.size()).float()).to(A.device)
636 | sum_embedding = torch.sum(A * input_mask_expanded * weights, dim=1)
637 | sum_mask = torch.sum(input_mask_expanded * weights, dim=1)
638 | weighted_embedding = sum_embedding / sum_mask
639 |
640 | return weighted_embedding
641 |
642 | def pma_embedding(self, A, mask):
643 | res = self.mha_pma(A, mask).squeeze(1)
644 | return res
645 |
646 |
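    |     # HF-style pooling dispatch: take the hidden states of layer
    |     # `keep_max_layer` and pool them with the configured strategy
    |     # ('last', 'mean', 'weighted', or 'pma').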
647 | def get_sentence_embedding(self, embedding_method, **inputs):
648 | outputs = self.plm_model(inputs['input_ids'], inputs['attention_mask'], output_hidden_states=True)
649 | if embedding_method == 'last':
650 | embedding = outputs.hidden_states[self.keep_max_layer]
651 | index = inputs['attention_mask'].sum(-1).long() - 1
652 | res_embedding = self.last_embedding(embedding, index)
653 | elif embedding_method == 'mean':
654 | embedding = outputs.hidden_states[self.keep_max_layer]
655 | res_embedding = self.mean_embedding(embedding, inputs['attention_mask'])
656 | elif embedding_method == 'weighted':
657 | embedding = outputs.hidden_states[self.keep_max_layer]
658 | res_embedding = self.weighted_embedding(embedding, inputs['attention_mask'])
659 | elif embedding_method == 'pma':
660 | embedding = outputs.hidden_states[self.keep_max_layer]
661 | attention_mask = inputs['attention_mask']
662 | res_embedding = self.pma_embedding(embedding, attention_mask)
663 |         else:
664 |             raise ValueError("Unknown embedding_method {!r}; expected 'last', 'mean', 'weighted', or 'pma'".format(embedding_method))
665 |
666 | if not self.norm:
667 | res_embedding = torch.nn.functional.normalize(res_embedding, p=2.0, dim=-1, eps=1e-12, out=None)
668 | return res_embedding
669 |
670 |
671 |
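    |     # Batched embedding helper in the sentence-transformers style: sentences are
    |     # sorted by length (longest first) for padding efficiency, tokenized, pooled
    |     # via get_sentence_embedding, and returned in the original input order.
    |     # Hypothetical usage:
    |     #   embs = model.encode(tokenizer, ["def add(a, b): return a + b"], batch_size=8)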
672 | def encode(self, tokenizer, sentences, batch_size=32, convert_to_numpy=True,
673 | convert_to_tensor=False, show_progress_bar=True, max_seq_length=None, **kwargs):
674 | if max_seq_length is None:
675 | max_seq_length = self.inf_seq_length
676 | input_is_string = False
677 | if isinstance(sentences, str) or not hasattr(sentences, "__len__"):
678 | sentences = [sentences]
679 | input_is_string = True
680 |
681 | all_embeddings = []
682 | length_sorted_idx = np.argsort([-len(s) for s in sentences])
683 |         sentences_sorted = [sentences[idx] for idx in length_sorted_idx]  # re-order from longest to shortest
684 | with torch.no_grad():
685 | for start_index in trange(0, len(sentences), batch_size, desc="Batches", disable=not show_progress_bar):
686 | sentences_batch = sentences_sorted[start_index: start_index + batch_size]
687 |             # Compute sentence embeddings
688 | with torch.no_grad():
689 | inputs = tokenizer(sentences_batch, padding=True, truncation=True, max_length=max_seq_length, add_special_tokens=False, return_tensors='pt').to(self.plm_model.device)
690 | embeddings = self.get_sentence_embedding(self.embedding_method, **inputs)
691 | embeddings = embeddings.detach()
692 | if convert_to_numpy:
693 | if embeddings.dtype == torch.bfloat16:
694 | embeddings = embeddings.cpu().to(torch.float32)
695 | else:
696 | embeddings = embeddings.cpu()
697 | all_embeddings.extend(embeddings)
698 | all_embeddings = [all_embeddings[idx] for idx in np.argsort(length_sorted_idx)]
699 | if convert_to_tensor:
700 | all_embeddings = torch.stack(all_embeddings)
701 | elif convert_to_numpy:
702 | all_embeddings = np.asarray([emb.numpy() for emb in all_embeddings])
703 |
704 | if input_is_string:
705 | all_embeddings = all_embeddings[0]
706 | return all_embeddings
707 |
708 |
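    | # Multihead Attention Block with post-layer-norm residuals (Set-Transformer
    | # style): a query set attends over a key/value set, with padded key positions
    | # masked out of the attention scores.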
709 | class MAB_POST(nn.Module):
710 | def __init__(self, dim_Q, dim_K, dim_V, num_heads, ln=False):
711 | super(MAB_POST, self).__init__()
712 | self.dim_V = dim_V
713 | self.num_heads = num_heads
714 | self.fc_q = nn.Linear(dim_Q, dim_V)
715 | self.fc_k = nn.Linear(dim_K, dim_V)
716 | self.fc_v = nn.Linear(dim_K, dim_V)
717 |
718 | if ln:
719 | self.ln0 = nn.LayerNorm(dim_V)
720 | self.ln1 = nn.LayerNorm(dim_V)
721 | self.fc_o = nn.Linear(dim_V, dim_V)
722 | nn.init.xavier_uniform_(self.fc_q.weight)
723 | nn.init.xavier_uniform_(self.fc_k.weight)
724 | nn.init.xavier_uniform_(self.fc_v.weight)
725 | nn.init.xavier_uniform_(self.fc_o.weight)
726 |
727 | def forward(self, Q, K, pad_mask=None):
728 |
729 | Q_ = self.fc_q(Q)
730 | K_, V_ = self.fc_k(K), self.fc_v(K)
731 | dim_split = self.dim_V // self.num_heads
732 | Q_ = torch.cat(Q_.split(dim_split, 2), 0)
733 | K_ = torch.cat(K_.split(dim_split, 2), 0)
734 | V_ = torch.cat(V_.split(dim_split, 2), 0)
735 |
736 | pad_mask = pad_mask.unsqueeze(1).repeat(self.num_heads, Q.size(1), 1)
737 | score = Q_.bmm(K_.transpose(1,2))/math.sqrt(self.dim_V)
738 | score = score.masked_fill(pad_mask == 0, -1e12)
739 | A = torch.softmax(score, 2)
740 | A = A * pad_mask
741 | O = torch.cat(A.bmm(V_).split(Q.size(0), 0), 2)
742 | O = Q + O
743 | O = O if getattr(self, 'ln0', None) is None else self.ln0(O)
744 | O = O + F.relu(self.fc_o(O))
745 | O = O if getattr(self, 'ln1', None) is None else self.ln1(O)
746 | return O
747 |
748 |
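    | # Pre-layer-norm variant of the MAB: Q and K/V are normalised before the
    | # projections, and a final LayerNorm is applied to the output.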
749 | class MAB_PRE_NORMAL(nn.Module):
750 | def __init__(self, dim_Q, dim_K, dim_V, num_heads, ln=False):
751 | super(MAB_PRE_NORMAL, self).__init__()
752 | self.dim_V = dim_V
753 | self.num_heads = num_heads
754 | self.fc_q = nn.Linear(dim_Q, dim_V)
755 | self.fc_k = nn.Linear(dim_K, dim_V)
756 | self.fc_v = nn.Linear(dim_K, dim_V)
757 |
758 | if ln:
759 | self.ln_q = nn.LayerNorm(dim_V)
760 | self.ln_kv = nn.LayerNorm(dim_V)
761 | self.ln_o = nn.LayerNorm(dim_V)
762 | self.ln_final = nn.LayerNorm(dim_V)
763 |
764 | self.fc_o = nn.Linear(dim_V, dim_V)
765 | nn.init.xavier_uniform_(self.fc_q.weight)
766 | nn.init.xavier_uniform_(self.fc_k.weight)
767 | nn.init.xavier_uniform_(self.fc_v.weight)
768 | nn.init.xavier_uniform_(self.fc_o.weight)
769 |
770 | def forward(self, Q, K, pad_mask=None):
771 | Q_ = Q if getattr(self, 'ln_q', None) is None else self.ln_q(Q)
772 | K_ = K if getattr(self, 'ln_kv', None) is None else self.ln_kv(K)
773 | Q_ = self.fc_q(Q_)
774 | K_, V_ = self.fc_k(K_), self.fc_v(K_)
775 | dim_split = self.dim_V // self.num_heads
776 | Q_ = torch.cat(Q_.split(dim_split, 2), 0)
777 | K_ = torch.cat(K_.split(dim_split, 2), 0)
778 | V_ = torch.cat(V_.split(dim_split, 2), 0)
779 | pad_mask = pad_mask.unsqueeze(1).repeat(self.num_heads, Q.size(1), 1)
780 | score = Q_.bmm(K_.transpose(1,2))/math.sqrt(self.dim_V)
781 | score = score.masked_fill(pad_mask == 0, -1e12)
782 | A = torch.softmax(score, 2)
783 | A = A * pad_mask
784 | O = torch.cat(A.bmm(V_).split(Q.size(0), 0), 2)
785 | O = Q + O
786 | O_ = O if getattr(self, 'ln_o', None) is None else self.ln_o(O)
787 | O_ = O + F.relu(self.fc_o(O_))
788 | return O_ if getattr(self, 'ln_final', None) is None else self.ln_final(O_)
789 |
790 |
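    | # GPT-J-style MAB: the attention branch and the feed-forward branch are computed
    | # in parallel from the pre-normalised inputs and added to the residual together.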
791 | class MAB_PRE_GPTJ(nn.Module):
792 | def __init__(self, dim_Q, dim_K, dim_V, num_heads, ln=False):
793 | super(MAB_PRE_GPTJ, self).__init__()
794 | self.dim_V = dim_V
795 | self.num_heads = num_heads
796 | self.fc_q = nn.Linear(dim_Q, dim_V)
797 | self.fc_k = nn.Linear(dim_K, dim_V)
798 | self.fc_v = nn.Linear(dim_K, dim_V)
799 | self.fc_o = nn.Linear(dim_V, dim_V)
800 |
801 | nn.init.xavier_uniform_(self.fc_q.weight)
802 | nn.init.xavier_uniform_(self.fc_k.weight)
803 | nn.init.xavier_uniform_(self.fc_v.weight)
804 | nn.init.xavier_uniform_(self.fc_o.weight)
805 | if ln:
806 | self.ln_q = nn.LayerNorm(dim_V)
807 | self.ln_kv = nn.LayerNorm(dim_V)
808 | self.ln_final = nn.LayerNorm(dim_V)
809 |
810 | def forward(self, Q, K, pad_mask=None):
811 | Q_ = Q if getattr(self, 'ln_q', None) is None else self.ln_q(Q)
812 | K_ = K if getattr(self, 'ln_kv', None) is None else self.ln_kv(K)
813 |
814 | Q1 = self.fc_q(Q_)
815 | K1, V1 = self.fc_k(K_), self.fc_v(K_)
816 | dim_split = self.dim_V // self.num_heads
817 |
818 | Q1 = torch.cat(Q1.split(dim_split, 2), 0)
819 | K1 = torch.cat(K1.split(dim_split, 2), 0)
820 | V1 = torch.cat(V1.split(dim_split, 2), 0)
821 |
822 |
823 | pad_mask = pad_mask.unsqueeze(1).repeat(self.num_heads, Q.size(1), 1)
824 | score = Q1.bmm(K1.transpose(1,2))/math.sqrt(self.dim_V)
825 | score = score.masked_fill(pad_mask == 0, -1e12)
826 | A = torch.softmax(score, 2)
827 | A = A * pad_mask
828 | O1 = torch.cat(A.bmm(V1).split(Q.size(0), 0), 2)
829 | O2 = F.relu(self.fc_o(Q_))
830 | O_final = Q + O1 + O2
831 | return O_final if getattr(self, 'ln_final', None) is None else self.ln_final(O_final)
832 |
833 |
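    | # Pooling by Multihead Attention (Set-Transformer style): a learnable seed
    | # vector attends over all token states to produce a single pooled embedding;
    | # `pma_mode` selects which MAB variant is used.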
834 | class PMA(nn.Module):
835 | def __init__(self, dim, compress_dim, num_heads, num_seeds, ln=False, pma_mode=None):
836 | super(PMA, self).__init__()
837 | self.S = nn.Parameter(torch.Tensor(1, num_seeds, compress_dim))
838 | nn.init.xavier_uniform_(self.S)
839 | if pma_mode == 'post_normal':
840 | self.mab = MAB_POST(compress_dim, dim, compress_dim, num_heads, ln=ln)
841 | elif pma_mode == 'pre_normal':
842 | self.mab = MAB_PRE_NORMAL(compress_dim, dim, compress_dim, num_heads, ln=ln)
843 | elif pma_mode == 'pre_gptj':
844 | self.mab = MAB_PRE_GPTJ(compress_dim, dim, compress_dim, num_heads, ln=ln)
845 | else:
846 | raise ValueError(f"Error, the pma_mode {pma_mode} is not implemented !")
847 |
848 | def forward(self, X, pad_mask):
849 | if self.S.dtype != torch.bfloat16:
850 | X = X.float()
851 | return self.mab(self.S.repeat(X.size(0), 1, 1), X, pad_mask)
852 |
--------------------------------------------------------------------------------