├── src
│   ├── utils
│   │   ├── constant.py
│   │   ├── logger.py
│   │   ├── metrics.py
│   │   └── llama_utils.py
│   ├── args.py
│   ├── infer.py
│   ├── pooling_layers.py
│   ├── dataset.py
│   ├── model.py
│   ├── generation_utils.py
│   └── train.py
├── config
│   └── bf16.yaml
├── README.md
├── infer.sh
├── train.sh
└── reranker.py
/src/utils/constant.py: -------------------------------------------------------------------------------- 1 | PREFIX_CHECKPOINT_DIR = 'checkpoint' 2 | WEIGHTS_NAME = "pytorch_model.bin" 3 | CONFIG_NAME = "config.json" 4 | OPTIMIZER_NAME = "optimizer.pt" 5 | SCHEDULER_NAME = "scheduler.pt" 6 | TRAINING_ARGS_NAME = "training_args.bin" 7 | TRAINER_STATE_NAME = "trainer_state.json" 8 | 9 | FFN_WEIGHTS_NAME = "ffn.pytorch_model.bin" 10 | COMPRESSOR_WEIGHTS_NAME = "compressor.pytorch_model.bin" 11 | POOLING_WEIGHTS_NAME = "pooling_layer.pytorch_model.bin" 12 | -------------------------------------------------------------------------------- /config/bf16.yaml: -------------------------------------------------------------------------------- 1 | compute_environment: LOCAL_MACHINE 2 | deepspeed_config: 3 | gradient_accumulation_steps: 4 4 | gradient_clipping: 2.0 5 | offload_optimizer_device: none 6 | offload_param_device: none 7 | zero3_init_flag: false 8 | zero_stage: 2 9 | 10 | distributed_type: DEEPSPEED 11 | downcast_bf16: 'no' 12 | machine_rank: 0 13 | main_training_function: main 14 | mixed_precision: bf16 15 | num_machines: 1 16 | num_processes: 8 17 | rdzv_backend: static 18 | same_network: true 19 | tpu_env: [] 20 | tpu_use_cluster: false 21 | tpu_use_sudo: false 22 | use_cpu: false 23 | main_process_port: 24242 -------------------------------------------------------------------------------- /src/utils/logger.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import logging 3 | import datasets 4 | import transformers 5 | import accelerate 6 | 7 | formatter = logging.Formatter( 8 | fmt="%(asctime)s | %(levelname)s | %(filename)s - %(lineno)d | %(message)s", 9 | datefmt="%Y-%m-%d %H:%M:%S", 10 | ) 11 | 12 | def get_logger(name: str) -> logging.Logger: 13 | 14 | handler = logging.StreamHandler(sys.stdout) 15 | handler.setFormatter(formatter) 16 | 17 | logger = logging.getLogger(name) 18 | logger.addHandler(handler) 19 | 20 | accelerator = accelerate.Accelerator() 21 | if accelerator.is_main_process: 22 | logger.setLevel(logging.INFO) 23 | # datasets.utils.logging.set_verbosity_info() 24 | # transformers.utils.logging.set_verbosity_info() 25 | else: 26 | logger.setLevel(logging.ERROR) 27 | datasets.utils.logging.set_verbosity_error() 28 | transformers.utils.logging.set_verbosity_error() 29 | 30 | return logger -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Query-Guided Compressor (QGC) 2 | Code for "Retaining Key Information under High Compression Rates: Query-Guided Compressor for LLMs" (ACL 2024) 3 | 4 | ## Requirements 5 | 6 | ``` 7 | datasets==2.15.0 8 | flash-attn==2.3.3 9 | jsonlines==4.0.0 10 | torch==2.0.0 11 | torchvision==0.15.0 12 | transformers==4.35.0 13 | ``` 14 | 15 | ## Instructions 16 | 17 | We use an example to show how to use our code. 18 | 19 | ### LLMs and Datasets 20 | 21 | We use [LongChat-13B](https://huggingface.co/lmsys/longchat-13b-16k) as the target LLM, and use Llama-2-7B to initialize the compressor parameters (see the sketch below).
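The compressor reuses only the bottom decoder layers of Llama-2-7B rather than the full model. A minimal sketch of that initialization, mirroring what `src/infer.py` does (the checkpoint path is a placeholder):

```
import transformers

# Placeholder path; point this at a local Llama-2-7B checkpoint.
compressor_path = "/path-to-llama-2-7B"

# Keep only the first `num_compressor_layers` decoder layers
# (4 by default in train.sh / infer.sh).
compressor_config = transformers.AutoConfig.from_pretrained(compressor_path)
compressor_config.num_hidden_layers = 4
compressor = transformers.LlamaModel.from_pretrained(compressor_path, config=compressor_config)
```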
For datasets, we use open-source QA datasets (NaturalQuestions, TriviaQA, HotpotQA) to train our compressor and evaluate it. All datasets can be downloaded from [this site](https://drive.google.com/drive/folders/1HhwPP6iZUBbAjWeWRkbEPtgXVIRZUz6V?usp=drive_link). 22 | 23 | ### QGC Training and Inference 24 | 25 | ``` 26 | # train compressor 27 | bash train.sh 28 | 29 | # evaluate compressor 30 | bash infer.sh 31 | ``` 32 | -------------------------------------------------------------------------------- /infer.sh: -------------------------------------------------------------------------------- 1 | data_path=/path-to-test-data-file 2 | compressor_path=/path-to-llama-2-7B 3 | lm_model_path=/path-to-longchat-13B 4 | from_checkpoint=/path-to-compressor-checkpoint 5 | save_path=/path-to-save-generation 6 | 7 | batch_size=4 8 | lm_model_name=longchat 9 | compressor_hidden_size=4096 10 | lm_model_hidden_size=5120 11 | num_compressor_layers=4 12 | num_compressor_encoder_layers=2 13 | benchmark_metric=accuracy 14 | instruction_name=base 15 | num_eval_documents=4 16 | pw_window_sizes=(2 4 6 8) 17 | pw_window_sizes_str=$(printf "_%s" "${pw_window_sizes[@]}") 18 | pw_window_sizes_str=${pw_window_sizes_str:1} 19 | 20 | mkdir -p $save_path 21 | 22 | accelerate launch --config_file config/bf16.yaml \ 23 | src/infer.py \ 24 | --data_path $data_path \ 25 | --compressor_path $compressor_path \ 26 | --lm_model_name $lm_model_name \ 27 | --lm_model_path $lm_model_path \ 28 | --compressor_hidden_size $compressor_hidden_size \ 29 | --lm_model_hidden_size $lm_model_hidden_size \ 30 | --num_compressor_layers $num_compressor_layers \ 31 | --num_compressor_encoder_layers $num_compressor_encoder_layers \ 32 | --eval_batch_size $batch_size \ 33 | --save_path $save_path \ 34 | --num_eval_documents $num_eval_documents \ 35 | --pw_window_sizes ${pw_window_sizes[@]} \ 36 | --from_checkpoint $from_checkpoint \ 37 | --benchmark_metric $benchmark_metric \ 38 | --instruction_name $instruction_name \ 39 | --fix_compressor_mlp_parameters \ 40 | | tee ${save_path}/infer.log 41 | -------------------------------------------------------------------------------- /src/args.py: -------------------------------------------------------------------------------- 1 | import math 2 | from dataclasses import dataclass 3 | from typing import List 4 | 5 | @dataclass 6 | class QGCArguments: 7 | seed: int = 42 8 | 9 | # model args 10 | compressor_path: str = None 11 | lm_model_path: str = None 12 | lm_model_name: str = 'longchat' 13 | num_compressor_layers: int = 4 14 | num_compressor_encoder_layers: int = 2 15 | fix_compressor_mlp_parameters: bool = False 16 | num_attention_heads: int = 32 17 | attn_doc_topp: float = 0.25 18 | compressor_hidden_size: int = 4096 19 | lm_model_hidden_size: int = 5120 20 | 21 | # training args 22 | from_checkpoint: str = None 23 | generation_split_token: str = None 24 | 25 | pool_window_size: int = 4 26 | random_pool_window_size: bool = False 27 | cand_pool_window_sizes: List[int] = None 28 | 29 | train_batch_size: int = 4 30 | eval_batch_size: int = 4 31 | gradient_accumulation_steps: int = 4 32 | 33 | max_steps: int = 150000 34 | learning_rate: float = 5e-5 35 | lr_scheduler_type: str = 'linear' 36 | warmup_ratio: float = 0.0 37 | max_grad_norm: float = 1.0 38 | 39 | logging_steps: int = 100 40 | dev_steps: int = 500 41 | test_steps: int = 500 42 | save_steps: int = 1000 43 | 44 | do_benchmark: bool = False 45 | benchmark_dev_steps: int = 1000 46 | benchmark_test_steps: int = 1000 47 | benchmark_metric: str = 
None 48 | 49 | label_pad_token_id: int = -100 50 | 51 | # inference args 52 | pw_window_sizes: List[int] = None 53 | 54 | # data args 55 | save_path: str = None 56 | data_path: str = None 57 | output_dir: str = None 58 | train_data_path: str = None 59 | dev_data_path: str = None 60 | test_data_path: str = None 61 | num_eval_documents: int = 5 62 | 63 | num_gold_documents: int = 1 64 | use_answer_as_target: bool = False 65 | instruction_name: str = 'base' 66 | gold_first_for_kd: bool = False 67 | 68 | min_num_documents: int = 1 69 | max_num_documents: int = 5 70 | random_num_documents: bool = False 71 | 72 | max_new_tokens: int = 100 73 | max_doc_tokens: int = 512 74 | question_mask_ratio: float = 0.5 75 | 76 | def get_warmup_steps(self, num_training_steps): 77 | return math.ceil(num_training_steps * self.warmup_ratio) -------------------------------------------------------------------------------- /train.sh: -------------------------------------------------------------------------------- 1 | batch_size=2 2 | question_mask_ratio=0.5 3 | distillation_temp=1.0 4 | compressor_hidden_size=4096 5 | num_compressor_layers=4 6 | num_compressor_encoder_layers=2 7 | pool_window_size=4 8 | cand_pool_window_sizes=(4 6 8 10) 9 | min_num_documents=1 # 1 for NQ and TQA, 2 for HQA 10 | max_num_documents=5 11 | compressor_path=/path-to-llama-2-7B # used to initial compressor parameters 12 | 13 | # target LLM is LongChat-13B 14 | lm_model_name=longchat 15 | lm_model_hidden_size=5120 16 | lm_model_path=/path-to-longchat-13B 17 | 18 | # # target LLM is LLaMA-2-7B 19 | # lm_model_name=llama 20 | # lm_model_hidden_size=4096 21 | # lm_model_path=/path-to-llama-2-7B 22 | 23 | data_path=/path-to-dataset 24 | max_steps=20000 25 | dev_steps=500 26 | test_steps=500 27 | save_steps=1000 28 | logging_steps=100 29 | benchmark_dev_steps=1000 30 | benchmark_test_steps=1000 31 | 32 | instruction_name=base # 'base' for NQ, 'short' for TQA and HQA 33 | benchmark_metric=accuracy # NQ: accuracy; TQA: em; HQA: f1 34 | 35 | output_dir=/path-to-save 36 | mkdir -p ${output_dir} 37 | 38 | accelerate launch --config_file config/bf16.yaml \ 39 | src/train.py \ 40 | --data_path $data_path \ 41 | --compressor_path $compressor_path \ 42 | --lm_model_name $lm_model_name \ 43 | --lm_model_path $lm_model_path \ 44 | --output_dir $output_dir \ 45 | --question_mask_ratio $question_mask_ratio \ 46 | --instruction_name $instruction_name \ 47 | --compressor_hidden_size $compressor_hidden_size \ 48 | --lm_model_hidden_size $lm_model_hidden_size \ 49 | --num_compressor_layers $num_compressor_layers \ 50 | --num_compressor_encoder_layers $num_compressor_encoder_layers \ 51 | --random_num_documents \ 52 | --max_num_documents $max_num_documents \ 53 | --min_num_documents $min_num_documents \ 54 | --pool_window_size $pool_window_size \ 55 | --train_batch_size $batch_size \ 56 | --eval_batch_size $batch_size \ 57 | --max_steps $max_steps \ 58 | --dev_steps $dev_steps \ 59 | --test_steps $test_steps \ 60 | --save_steps $save_steps \ 61 | --logging_steps $logging_steps \ 62 | --do_benchmark \ 63 | --benchmark_dev_steps $benchmark_dev_steps \ 64 | --benchmark_test_steps $benchmark_test_steps \ 65 | --benchmark_metric $benchmark_metric \ 66 | --gold_first_for_kd \ 67 | --random_pool_window_size \ 68 | --cand_pool_window_sizes ${cand_pool_window_sizes[@]} \ 69 | | tee ${output_dir}/train.log -------------------------------------------------------------------------------- /src/utils/metrics.py: 
-------------------------------------------------------------------------------- 1 | import string 2 | from typing import List 3 | from collections import Counter 4 | import regex 5 | from rouge import Rouge 6 | 7 | def normalize_text(s: str) -> str: 8 | """Normalization from the SQuAD evaluation script. 9 | 10 | See https://worksheets.codalab.org/rest/bundles/0x6b567e1cf2e041ec80d7098f031c5c9e/contents/blob/ 11 | """ 12 | 13 | def remove_articles(text): 14 | return regex.sub(r"\b(a|an|the)\b", " ", text) 15 | 16 | def white_space_fix(text): 17 | return " ".join(text.split()) 18 | 19 | def remove_punc(text): 20 | exclude = set(string.punctuation) 21 | return "".join(ch for ch in text if ch not in exclude) 22 | 23 | def lower(text): 24 | return text.lower() 25 | 26 | return white_space_fix(remove_articles(remove_punc(lower(s)))) 27 | 28 | 29 | def best_subspan_em(prediction: str, ground_truths: List[str]) -> float: 30 | normalized_prediction = normalize_text(prediction) 31 | 32 | for ground_truth in ground_truths: 33 | normalized_ground_truth = normalize_text(ground_truth) 34 | if normalized_ground_truth.lower() in normalized_prediction.lower(): 35 | return 1.0 36 | return 0.0 37 | 38 | 39 | def exact_match(prediction: str, ground_truths: List[str]) -> float: 40 | normalized_prediction = normalize_text(prediction) 41 | 42 | for ground_truth in ground_truths: 43 | normalized_ground_truth = normalize_text(ground_truth) 44 | if normalized_prediction.lower() == normalized_ground_truth.lower(): 45 | return 1.0 46 | return 0.0 47 | 48 | 49 | def f1_score(prediction, ground_truths): 50 | '''F1 Score from the HotpotQA evaluation script. 51 | 52 | See https://github.com/hotpotqa/hotpot/blob/master/hotpot_evaluate_v1.py 53 | ''' 54 | 55 | normalized_prediction = normalize_text(prediction) 56 | normalized_ground_truth = normalize_text(ground_truths[0]) 57 | 58 | ZERO_METRIC = 0.0 59 | 60 | if normalized_prediction in ['yes', 'no', 'noanswer'] and normalized_prediction != normalized_ground_truth: 61 | return ZERO_METRIC 62 | if normalized_ground_truth in ['yes', 'no', 'noanswer'] and normalized_prediction != normalized_ground_truth: 63 | return ZERO_METRIC 64 | 65 | prediction_tokens = normalized_prediction.split() 66 | ground_truth_tokens = normalized_ground_truth.split() 67 | common = Counter(prediction_tokens) & Counter(ground_truth_tokens) 68 | num_same = sum(common.values()) 69 | if num_same == 0: 70 | return ZERO_METRIC 71 | precision = 1.0 * num_same / len(prediction_tokens) 72 | recall = 1.0 * num_same / len(ground_truth_tokens) 73 | f1 = (2 * precision * recall) / (precision + recall) 74 | return f1 75 | 76 | 77 | benchmark_function_map = { 78 | 'accuracy': best_subspan_em, 79 | 'em': exact_match, 80 | 'f1': f1_score, 81 | } -------------------------------------------------------------------------------- /src/utils/llama_utils.py: -------------------------------------------------------------------------------- 1 | # code adapted from https://huggingface.co/kaiokendev/superhot-13b-8k-no-rlhf-test/blob/main/llama_rope_scaled_monkey_patch.py 2 | import torch 3 | import transformers 4 | import transformers.models.llama.modeling_llama 5 | 6 | from functools import partial 7 | from torch.distributed import get_rank, is_initialized 8 | 9 | def rank0_print(*args): 10 | if is_initialized(): 11 | if get_rank() == 0: 12 | print(*args) 13 | else: 14 | print(*args) 15 | 16 | class CondenseRotaryEmbedding(torch.nn.Module): 17 | def __init__(self, dim, ratio, max_position_embeddings=2048, base=10000, 
device=None): 18 | super().__init__() 19 | inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim)) 20 | self.register_buffer("inv_freq", inv_freq) 21 | 22 | # Build here to make `torch.jit.trace` work. 23 | self.ratio = ratio 24 | max_position_embeddings *= ratio 25 | rank0_print(f"Condensing Positional embeddings from {max_position_embeddings} to {max_position_embeddings // ratio}") 26 | self.max_seq_len_cached = max_position_embeddings 27 | t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=self.inv_freq.dtype) / ratio 28 | freqs = torch.einsum("i,j->ij", t, self.inv_freq) 29 | # Different from paper, but it uses a different permutation in order to obtain the same calculation 30 | emb = torch.cat((freqs, freqs), dim=-1) 31 | dtype = torch.get_default_dtype() 32 | self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False) 33 | self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False) 34 | 35 | def forward(self, x, seq_len=None): 36 | # x: [bs, num_attention_heads, seq_len, head_size] 37 | # This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case. 38 | if seq_len > self.max_seq_len_cached: 39 | self.max_seq_len_cached = seq_len 40 | t = torch.arange(self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype) / self.ratio 41 | freqs = torch.einsum("i,j->ij", t, self.inv_freq) 42 | # Different from paper, but it uses a different permutation in order to obtain the same calculation 43 | emb = torch.cat((freqs, freqs), dim=-1).to(x.device) 44 | self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(x.dtype), persistent=False) 45 | self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(x.dtype), persistent=False) 46 | return ( 47 | self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype), 48 | self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype), 49 | ) 50 | 51 | def replace_llama_with_condense(ratio): 52 | transformers.models.llama.modeling_llama.LlamaRotaryEmbedding = partial(CondenseRotaryEmbedding, ratio=ratio) -------------------------------------------------------------------------------- /src/infer.py: -------------------------------------------------------------------------------- 1 | import os 2 | import jsonlines 3 | from tqdm.auto import tqdm 4 | 5 | import torch 6 | import transformers 7 | from accelerate import Accelerator 8 | 9 | from src.args import QGCArguments 10 | from src.model import ModelWithQGC 11 | from src.dataset import InferDataset 12 | from src.pooling_layers import InferPoolingLayer 13 | from src.utils.constant import * 14 | from src.utils.logger import get_logger 15 | from src.utils.metrics import benchmark_function_map 16 | logger = get_logger(__name__) 17 | 18 | 19 | def load_dataloader(args: QGCArguments, cmp_tokenizer, llm_tokenizer): 20 | dataset = InferDataset( 21 | filepath=args.data_path, 22 | cmp_tokenizer=cmp_tokenizer, 23 | llm_tokenizer=llm_tokenizer, 24 | max_doc_tokens=args.max_doc_tokens, 25 | max_num_documents=args.num_eval_documents, 26 | llm_with_neg_documents=True, 27 | instruction_name=args.instruction_name, 28 | ) 29 | dataloader = torch.utils.data.DataLoader( 30 | dataset, 31 | batch_size=args.eval_batch_size, 32 | shuffle=False, 33 | collate_fn=dataset.collate_fn, 34 | ) 35 | return dataloader 36 | 37 | def main(args: QGCArguments): 38 | transformers.trainer_utils.set_seed(args.seed) 39 | 40 | logger.info('load tokenizer ...') 41 | cmp_tokenizer = 
transformers.AutoTokenizer.from_pretrained(args.compressor_path) 42 | llm_tokenizer = transformers.AutoTokenizer.from_pretrained(args.lm_model_path) 43 | cmp_tokenizer.pad_token = cmp_tokenizer.unk_token 44 | 45 | additional_special_tokens = ['', '', ''] 46 | cmp_tokenizer.add_special_tokens({"additional_special_tokens": additional_special_tokens}) 47 | 48 | generation_end_token = ( 49 | '' if args.lm_model_name == 'longchat' or 'nq' in args.data_path 50 | else '\n\n' 51 | ) 52 | 53 | logger.info('load dataset ...') 54 | test_dataloader = load_dataloader(args, cmp_tokenizer, llm_tokenizer) 55 | 56 | accelerator = Accelerator() 57 | device = accelerator.device 58 | 59 | logger.info('load compressor ...') 60 | compressor_config = transformers.AutoConfig.from_pretrained(args.compressor_path) 61 | compressor_config.num_hidden_layers = args.num_compressor_layers 62 | compressor = transformers.LlamaModel.from_pretrained(args.compressor_path, config=compressor_config) 63 | compressor.resize_token_embeddings(len(cmp_tokenizer)) 64 | 65 | pooling_layer = InferPoolingLayer(args) 66 | 67 | logger.info('load lm_model ...') 68 | if args.lm_model_name == 'longchat': 69 | llm_config = transformers.AutoConfig.from_pretrained(args.lm_model_path) 70 | llm_config._flash_attn_2_enabled = True 71 | llm_config.use_cache = False 72 | from src.utils.llama_utils import replace_llama_with_condense 73 | replace_llama_with_condense(8) 74 | lm_model = transformers.LlamaForCausalLM.from_pretrained(args.lm_model_path, config=llm_config) 75 | 76 | elif args.lm_model_name == 'llama': 77 | llm_config = transformers.AutoConfig.from_pretrained(args.lm_model_path) 78 | llm_config._flash_attn_2_enabled = True 79 | llm_config.use_cache = False 80 | lm_model = transformers.LlamaForCausalLM.from_pretrained(args.lm_model_path, config=llm_config) 81 | 82 | else: 83 | raise NotImplementedError(args.lm_model_name) 84 | 85 | logger.info(f'build model and load checkpoint from {args.from_checkpoint}') 86 | model = ModelWithQGC(args, compressor=compressor, pooling_layer=pooling_layer, lm_model=lm_model) 87 | model.semantic_alignment_layer.load_state_dict(torch.load(os.path.join(args.from_checkpoint, FFN_WEIGHTS_NAME), map_location='cpu')) 88 | model.compressor.load_state_dict(torch.load(os.path.join(args.from_checkpoint, COMPRESSOR_WEIGHTS_NAME), map_location='cpu')) 89 | model.pooling_layer.load_state_dict(torch.load(os.path.join(args.from_checkpoint, POOLING_WEIGHTS_NAME), map_location='cpu')) 90 | 91 | optimizer = torch.optim.AdamW(model.parameters(), lr=args.learning_rate) 92 | model, optimizer, test_dataloader = accelerator.prepare(model, optimizer, test_dataloader) 93 | 94 | logger.info(f'Model Structure = {model}') 95 | 96 | @torch.no_grad() 97 | def benchmark_step(model, inputs): 98 | model.eval() 99 | benchmark_answers = inputs['answers'] 100 | 101 | cmp_llm_doc_embeds, cmp_llm_doc_mask = model.compress_doc(**inputs) 102 | first_llm_inputs = model.construct_llm_inputs_for_generation( 103 | **inputs, 104 | cmp_llm_doc_embeds=cmp_llm_doc_embeds, 105 | cmp_llm_doc_mask=cmp_llm_doc_mask, 106 | ) 107 | first_llm_outputs = model(**first_llm_inputs, use_cache=True) 108 | 109 | second_llm_inputs = { 110 | 'input_ids': inputs['llm_que_tokens'], 111 | 'attention_mask': torch.cat([first_llm_inputs['attention_mask'], inputs['llm_que_mask']], dim=1), 112 | 'past_key_values': first_llm_outputs.past_key_values, 113 | } 114 | outputs = model.generate(**second_llm_inputs, do_sample=False, max_new_tokens=args.max_new_tokens, use_cache=True) 115 | 
context_length = second_llm_inputs['input_ids'].size(1) 116 | 117 | llm_generations = [elem.strip().split(generation_end_token)[0] for elem in llm_tokenizer.batch_decode(outputs[:, context_length:])] 118 | benchmark_function = benchmark_function_map[args.benchmark_metric] 119 | score_values = [benchmark_function(generation, answer) for generation, answer in zip(llm_generations, benchmark_answers)] 120 | scores = torch.tensor(score_values, device=device) 121 | 122 | benchmark_outputs = [ 123 | { 124 | 'question': question, 125 | 'raw_generation': raw_generation, 126 | 'ext_generation': ext_generation, 127 | 'answers': answers, 128 | 'score': score, 129 | } 130 | for question, raw_generation, ext_generation, answers, score in zip( 131 | llm_tokenizer.batch_decode(inputs['llm_que_tokens']), 132 | llm_tokenizer.batch_decode(outputs), 133 | llm_generations, 134 | benchmark_answers, 135 | score_values, 136 | ) 137 | ] 138 | return scores, benchmark_outputs 139 | 140 | 141 | def benchmark(model, dataloader, prefix='benchmark'): 142 | benchmark_bar = tqdm( 143 | total=len(dataloader), leave=True, dynamic_ncols=True, 144 | disable=not accelerator.is_main_process, desc='benchmark' 145 | ) 146 | model.eval() 147 | scores_host = () 148 | outputs_host = [] 149 | 150 | for inputs in dataloader: 151 | scores, outputs = benchmark_step(model, inputs) 152 | scores_host += (accelerator.gather_for_metrics(scores),) 153 | outputs_host += outputs 154 | benchmark_bar.update(1) 155 | 156 | benchmark_bar.close() 157 | mean_scores = torch.cat(scores_host, dim=0).mean() 158 | return [ 159 | { 160 | f'{prefix}_score': round(mean_scores.item(), 4), 161 | }, 162 | outputs_host, 163 | ] 164 | 165 | benchmark_metrics, benchmark_outputs = benchmark(model, test_dataloader, prefix='test') 166 | logger.info(benchmark_metrics) 167 | 168 | if accelerator.is_main_process: 169 | print(benchmark_metrics) 170 | 171 | with jsonlines.open(os.path.join(args.save_path, f'benchmark.{accelerator.process_index}.jsonl'), 'w') as fw: 172 | for element in benchmark_outputs: 173 | fw.write(element) 174 | 175 | if __name__ == '__main__': 176 | parser = transformers.HfArgumentParser(QGCArguments) 177 | args = parser.parse_args_into_dataclasses()[0] 178 | logger.info(args) 179 | main(args) 180 | -------------------------------------------------------------------------------- /src/pooling_layers.py: -------------------------------------------------------------------------------- 1 | import random 2 | import math 3 | import torch 4 | import torch.nn as nn 5 | from transformers.models.llama.modeling_llama import LlamaRMSNorm 6 | 7 | 8 | def fix_window_size_pooling(hidden_states, attention_mask, weights): 9 | bsz, pooled_length, window_size, hidden_size = hidden_states.size() 10 | scatter_matrix = torch.zeros_like(attention_mask) 11 | scatter_matrix[..., ::window_size] = 1 12 | scatter_index = scatter_matrix.cumsum(dim=-1) - 1 13 | 14 | hidden_states_after_weighting = (hidden_states * weights).view(bsz, -1, hidden_size) 15 | pooling_hidden_states = torch.zeros([bsz, pooled_length, hidden_size], device=hidden_states.device).to(hidden_states.dtype) 16 | pooling_hidden_states.scatter_add_(1, scatter_index[..., None].repeat(1, 1, hidden_size), hidden_states_after_weighting) 17 | 18 | pooling_attention_mask = torch.zeros([bsz, pooled_length], device=hidden_states.device).to(attention_mask.dtype) 19 | pooling_attention_mask.scatter_add_(1, scatter_index, attention_mask) 20 | pooling_attention_mask = 
pooling_attention_mask.greater(0).to(attention_mask.dtype) 21 | 22 | return pooling_hidden_states, pooling_attention_mask 23 | 24 | 25 | class PoolingLayer(nn.Module): 26 | def __init__(self, args): 27 | super().__init__() 28 | self.args = args 29 | self.hidden_size = args.compressor_hidden_size 30 | self.num_heads = args.num_attention_heads 31 | self.head_dim = self.hidden_size // self.num_heads 32 | self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False) 33 | self.k_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False) 34 | self.doc_layernorm = LlamaRMSNorm(args.compressor_hidden_size, eps=1e-05) 35 | self.que_layernorm = LlamaRMSNorm(args.compressor_hidden_size, eps=1e-05) 36 | 37 | 38 | def forward(self, que_hidden_states, doc_hidden_states, enc_doc_mask, enc_que_mask, window_size=None, **kwargs): 39 | if window_size is None: 40 | if self.args.random_pool_window_size: 41 | window_size = random.choice(self.args.cand_pool_window_sizes) 42 | else: 43 | window_size = self.args.pool_window_size 44 | 45 | bsz, d_len, hidden_size = doc_hidden_states.size() 46 | if d_len % window_size != 0: 47 | 48 | def padding(tensor, shape): 49 | return torch.cat([torch.zeros(shape, dtype=tensor.dtype, device=tensor.device), tensor], dim=1) 50 | 51 | padding_length = window_size - d_len % window_size 52 | doc_hidden_states = padding(doc_hidden_states, shape=(bsz, padding_length, hidden_size)) 53 | enc_doc_mask = padding(enc_doc_mask, shape=(bsz, padding_length)) 54 | d_len = enc_doc_mask.size(1) 55 | 56 | doc_hidden_states = self.doc_layernorm(doc_hidden_states) 57 | que_mean_hidden_states = que_hidden_states.masked_fill(~enc_que_mask[..., None].bool(), 0.0) 58 | que_mean_hidden_states = que_mean_hidden_states.sum(dim=1) / enc_que_mask[..., None].sum(dim=1) 59 | que_mean_hidden_states = self.que_layernorm(que_mean_hidden_states) 60 | 61 | query_states = self.q_proj(que_mean_hidden_states).view(bsz, self.num_heads, self.head_dim) 62 | key_states = self.k_proj(doc_hidden_states).view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2) 63 | value_states = doc_hidden_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2) 64 | 65 | pooling_weights = torch.einsum('bnh,bndh->bnd', query_states, key_states) / math.sqrt(self.head_dim) 66 | pooling_weights.masked_fill_(~enc_doc_mask.unsqueeze(1).bool(), torch.finfo(query_states.dtype).min) 67 | pooling_weights = pooling_weights.view(bsz, self.num_heads, -1, window_size) 68 | pooling_weights = pooling_weights.softmax(dim=-1, dtype=torch.float32).to(query_states.dtype) 69 | 70 | combined_pooling_weights = pooling_weights.permute(0, 2, 3, 1) 71 | combined_pooling_weights = combined_pooling_weights[..., None].repeat(1, 1, 1, 1, self.head_dim).view(bsz, -1, window_size, self.hidden_size) 72 | combined_value_states = value_states.permute(0, 2, 1, 3).view(bsz, -1, window_size, self.hidden_size) 73 | 74 | pooling_hidden_states, pooling_attention_mask = fix_window_size_pooling( 75 | hidden_states=combined_value_states, 76 | attention_mask=enc_doc_mask, 77 | weights=combined_pooling_weights, 78 | ) 79 | 80 | return pooling_hidden_states, pooling_attention_mask 81 | 82 | 83 | class InferPoolingLayer(nn.Module): 84 | def __init__(self, args): 85 | super().__init__() 86 | self.args = args 87 | self.hidden_size = args.compressor_hidden_size 88 | self.num_heads = args.num_attention_heads 89 | self.head_dim = self.hidden_size // self.num_heads 90 | self.q_proj = nn.Linear(self.hidden_size, self.num_heads * 
self.head_dim, bias=False) 91 | self.k_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False) 92 | self.doc_layernorm = LlamaRMSNorm(args.compressor_hidden_size, eps=1e-05) 93 | self.que_layernorm = LlamaRMSNorm(args.compressor_hidden_size, eps=1e-05) 94 | 95 | 96 | def forward(self, que_hidden_states, doc_hidden_states, enc_doc_mask, enc_que_mask, **kwargs): 97 | def left_padding(tensor, shape, dim): 98 | return torch.cat([torch.zeros(shape, dtype=tensor.dtype, device=tensor.device), tensor], dim=dim) 99 | 100 | pw_window_sizes = self.args.pw_window_sizes 101 | bsz, d_len, hidden_size = doc_hidden_states.size() 102 | doc_hidden_states = self.doc_layernorm(doc_hidden_states) 103 | que_mean_hidden_states = que_hidden_states.masked_fill(~enc_que_mask[..., None].bool(), 0.0) 104 | que_mean_hidden_states = que_mean_hidden_states.sum(dim=1) / enc_que_mask[..., None].sum(dim=1) 105 | que_mean_hidden_states = self.que_layernorm(que_mean_hidden_states) 106 | 107 | query_states = self.q_proj(que_mean_hidden_states).view(bsz, self.num_heads, self.head_dim) 108 | key_states = self.k_proj(doc_hidden_states).view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2) 109 | value_states = doc_hidden_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2) 110 | 111 | pooling_weights = torch.einsum('bnh,bndh->bnd', query_states, key_states) / math.sqrt(self.head_dim) 112 | pooling_weights.masked_fill_(~enc_doc_mask.unsqueeze(1).bool(), torch.finfo(query_states.dtype).min) 113 | 114 | pooling_max_len = math.ceil(d_len / min(pw_window_sizes)) 115 | pooling_hidden_states = torch.zeros([bsz, pooling_max_len, hidden_size], device=query_states.device).to(query_states.dtype) 116 | pooling_attention_mask = torch.zeros([bsz, pooling_max_len], device=query_states.device).to(enc_doc_mask.dtype) 117 | 118 | for index in range(self.args.num_eval_documents): 119 | current_batch_size = bsz // self.args.num_eval_documents 120 | current_weight = pooling_weights[index::self.args.num_eval_documents] 121 | current_attention_mask = enc_doc_mask[index::self.args.num_eval_documents] 122 | current_value_states = value_states[index::self.args.num_eval_documents].permute(0, 2, 1, 3).view(current_batch_size, -1, hidden_size) 123 | current_window_size = pw_window_sizes[index % len(pw_window_sizes)] 124 | 125 | if current_weight.size(-1) % current_window_size != 0: 126 | padding_length = current_window_size - current_weight.size(-1) % current_window_size 127 | current_weight = left_padding(current_weight, shape=(current_batch_size, self.num_heads, padding_length), dim=2) 128 | current_attention_mask = left_padding(current_attention_mask, shape=(current_batch_size, padding_length), dim=1) 129 | current_value_states = left_padding(current_value_states, shape=(current_batch_size, padding_length, hidden_size), dim=1) 130 | 131 | current_weight = current_weight.view(current_batch_size, self.num_heads, -1, current_window_size) 132 | current_weight = current_weight.softmax(dim=-1, dtype=torch.float32).to(query_states.dtype) 133 | current_softmax_weight = current_weight.permute(0, 2, 3, 1) 134 | current_softmax_weight = current_softmax_weight[..., None].repeat(1, 1, 1, 1, self.head_dim).view(current_batch_size, -1, current_window_size, hidden_size) 135 | current_value_states = current_value_states.view(current_batch_size, -1, current_window_size, hidden_size) 136 | current_weight_hidden_states = (current_softmax_weight * current_value_states).sum(dim=2) 137 | current_attention_mask = 
current_attention_mask.view(current_batch_size, -1, current_window_size).sum(dim=-1).greater(0).to(enc_doc_mask.dtype) 138 | 139 | current_len = current_weight_hidden_states.size(1) 140 | pooling_hidden_states[index::self.args.num_eval_documents, -current_len:] = current_weight_hidden_states 141 | pooling_attention_mask[index::self.args.num_eval_documents, -current_len:] = current_attention_mask 142 | 143 | return pooling_hidden_states, pooling_attention_mask -------------------------------------------------------------------------------- /reranker.py: -------------------------------------------------------------------------------- 1 | def get_distance_bm25(corpus, query): 2 | from rank_bm25 import BM25Okapi 3 | tokenized_corpus = [doc.split(" ") for doc in corpus] 4 | bm25 = BM25Okapi(tokenized_corpus) 5 | tokenized_query = query.split(" ") 6 | doc_scores = bm25.get_scores(tokenized_query) 7 | idx = [(ii, 0) for ii in (-doc_scores).argsort()] 8 | return idx 9 | 10 | 11 | def get_rank_results( 12 | self, 13 | context: list, 14 | question: str, 15 | rank_method: str, 16 | condition_in_question: str, 17 | context_tokens_length: list, 18 | ): 19 | def get_distance_bm25(corpus, query): 20 | from rank_bm25 import BM25Okapi 21 | 22 | tokenized_corpus = [doc.split(" ") for doc in corpus] 23 | bm25 = BM25Okapi(tokenized_corpus) 24 | tokenized_query = query.split(" ") 25 | doc_scores = bm25.get_scores(tokenized_query) 26 | idx = [(ii, 0) for ii in (-doc_scores).argsort()] 27 | return idx 28 | 29 | def get_distance_gzip(corpus, query): 30 | def get_score(x, y): 31 | cx, cy = len(gzip.compress(x.encode())), len(gzip.compress(y.encode())) 32 | cxy = len(gzip.compress(f"{x} {y}".encode())) 33 | return (cxy - min(cx, cy)) / max(cx, cy) 34 | 35 | import gzip 36 | 37 | doc_scores = [get_score(doc, query) for doc in corpus] 38 | idx = [(ii, 0) for ii in np.argsort(doc_scores)] 39 | return idx 40 | 41 | def get_distance_sentbert(corpus, query): 42 | from sentence_transformers import SentenceTransformer, util 43 | 44 | if self.retrieval_model is None or self.retrieval_model_name != rank_method: 45 | self.retrieval_model = SentenceTransformer("multi-qa-mpnet-base-dot-v1") 46 | self.retrieval_model_name = rank_method 47 | doc_embeds = self.retrieval_model.encode(corpus) 48 | query = self.retrieval_model.encode(query) 49 | doc_scores = -util.dot_score(doc_embeds, query).cpu().numpy().reshape(-1) 50 | idx = [(ii, 0) for ii in np.argsort(doc_scores)] 51 | return idx 52 | 53 | def get_distance_openai(corpus, query): 54 | import openai 55 | from sentence_transformers import util 56 | 57 | openai.api_key = self.open_api_config.get("api_key", "") 58 | openai.api_base = self.open_api_config.get( 59 | "api_base", "https://api.openai.com/v1" 60 | ) 61 | openai.api_type = self.open_api_config.get("api_type", "open_ai") 62 | openai.api_version = self.open_api_config.get("api_version", "2023-05-15") 63 | engine = self.open_api_config.get("engine", "text-embedding-ada-002") 64 | 65 | def get_embed(text): 66 | return openai.Embedding.create( 67 | input=[text.replace("\n", " ")], engine=engine 68 | )["data"][0]["embedding"] 69 | 70 | doc_embeds = [get_embed(i) for i in corpus] 71 | query = get_embed(query) 72 | doc_scores = -util.dot_score(doc_embeds, query).cpu().numpy().reshape(-1) 73 | idx = [(ii, 0) for ii in np.argsort(doc_scores)] 74 | return idx 75 | 76 | def get_distance_sentbert_bge(corpus, query): 77 | from sentence_transformers import SentenceTransformer, util 78 | 79 | if self.retrieval_model is None or
self.retrieval_model_name != rank_method: 80 | self.retrieval_model = SentenceTransformer("BAAI/bge-large-en-v1.5") 81 | self.retrieval_model_name = rank_method 82 | doc_embeds = self.retrieval_model.encode( 83 | [i for i in corpus], normalize_embeddings=True 84 | ) 85 | query = self.retrieval_model.encode(query, normalize_embeddings=True) 86 | doc_scores = -util.dot_score(doc_embeds, query).cpu().numpy().reshape(-1) 87 | idx = [(ii, 0) for ii in np.argsort(doc_scores)] 88 | return idx 89 | 90 | def get_distance_bge_ranker(corpus, query): 91 | from transformers import AutoModelForSequenceClassification, AutoTokenizer 92 | 93 | pairs = [[i, query] for i in corpus] 94 | if self.retrieval_model is None or self.retrieval_model_name != rank_method: 95 | tokenizer = AutoTokenizer.from_pretrained("BAAI/bge-reranker-large") 96 | model = ( 97 | AutoModelForSequenceClassification.from_pretrained( 98 | "BAAI/bge-reranker-large" 99 | ) 100 | .eval() 101 | .to(self.device) 102 | ) 103 | self.retrieval_model = [tokenizer, model] 104 | self.retrieval_model_name = rank_method 105 | with torch.no_grad(): 106 | inputs = self.retrieval_model[0]( 107 | pairs, 108 | padding=True, 109 | truncation=True, 110 | return_tensors="pt", 111 | max_length=512, 112 | ).to(self.device) 113 | scores = ( 114 | self.retrieval_model[1](**inputs, return_dict=True) 115 | .logits.view( 116 | -1, 117 | ) 118 | .float() 119 | ) 120 | idx = [(ii, 0) for ii in np.argsort(-scores.cpu())] 121 | return idx 122 | 123 | def get_distance_bge_llmembedder(corpus, query): 124 | from transformers import AutoModel, AutoTokenizer 125 | 126 | if self.retrieval_model is None or self.retrieval_model_name != rank_method: 127 | tokenizer = AutoTokenizer.from_pretrained("BAAI/llm-embedder") 128 | model = ( 129 | AutoModel.from_pretrained("BAAI/llm-embedder") 130 | .eval() 131 | .to(self.device) 132 | ) 133 | self.retrieval_model = [tokenizer, model] 134 | self.retrieval_model_name = rank_method 135 | 136 | instruction_qa_query = ( 137 | "Represent this query for retrieving relevant documents: " 138 | ) 139 | instruction_qa_key = "Represent this document for retrieval: " 140 | queries = [instruction_qa_query + query for _ in corpus] 141 | keys = [instruction_qa_key + key for key in corpus] 142 | with torch.no_grad(): 143 | query_inputs = self.retrieval_model[0]( 144 | queries, 145 | padding=True, 146 | truncation=True, 147 | return_tensors="pt", 148 | max_length=512, 149 | ).to(self.device) 150 | key_inputs = self.retrieval_model[0]( 151 | keys, 152 | padding=True, 153 | truncation=True, 154 | return_tensors="pt", 155 | max_length=512, 156 | ).to(self.device) 157 | query_outputs = self.retrieval_model[1](**query_inputs) 158 | key_outputs = self.retrieval_model[1](**key_inputs) 159 | # CLS pooling 160 | query_embeddings = query_outputs.last_hidden_state[:, 0] 161 | key_embeddings = key_outputs.last_hidden_state[:, 0] 162 | # Normalize 163 | query_embeddings = torch.nn.functional.normalize( 164 | query_embeddings, p=2, dim=1 165 | ) 166 | key_embeddings = torch.nn.functional.normalize( 167 | key_embeddings, p=2, dim=1 168 | ) 169 | similarity = query_embeddings @ key_embeddings.T 170 | idx = [(ii, 0) for ii in np.argsort(-similarity[0].cpu())] 171 | return idx 172 | 173 | def get_distance_jinza(corpus, query): 174 | from numpy.linalg import norm 175 | 176 | from transformers import AutoModel 177 | 178 | def cos_sim(a, b): 179 | return (a @ b.T) / (norm(a) * norm(b)) 180 | 181 | if self.retrieval_model is None or self.retrieval_model_name != rank_method: 
182 | model = ( 183 | AutoModel.from_pretrained( 184 | "jinaai/jina-embeddings-v2-base-en", trust_remote_code=True 185 | ) 186 | .eval() 187 | .to(self.device) 188 | ) 189 | self.retrieval_model = model 190 | self.retrieval_model_name = rank_method 191 | 192 | doc_embeds = self.retrieval_model.encode(corpus) 193 | query = self.retrieval_model.encode(query) 194 | doc_scores = cos_sim(doc_embeds, query) 195 | idx = [(ii, 0) for ii in np.argsort(-doc_scores)] 196 | return idx 197 | 198 | def get_distance_voyageai(corpus, query): 199 | import voyageai 200 | from sentence_transformers import util 201 | 202 | voyageai.api_key = self.open_api_config.get("voyageai_api_key", "") 203 | 204 | def get_embed(text): 205 | return voyageai.get_embedding(text, model="voyage-01") 206 | 207 | doc_embeds = [get_embed(i) for i in corpus] 208 | query = get_embed(query) 209 | doc_scores = -util.dot_score(doc_embeds, query).cpu().numpy().reshape(-1) 210 | idx = [(ii, 0) for ii in np.argsort(doc_scores)] 211 | return idx 212 | 213 | def get_distance_cohere(corpus, query): 214 | import cohere 215 | 216 | api_key = self.open_api_config.get("cohere_api_key", "") 217 | co = cohere.Client(api_key) 218 | results = co.rerank( 219 | model="rerank-english-v2.0", query=query, documents=corpus, top_n=20 220 | ) 221 | c_map = {jj: ii for ii, jj in enumerate(corpus)} 222 | doc_rank = [c_map[ii.document["text"]] for ii in results] 223 | idx = [(ii, 0) for ii in doc_rank] 224 | return idx 225 | 226 | def get_distance_longllmlingua(corpus, query): 227 | context_ppl = [ 228 | self.get_condition_ppl( 229 | d, 230 | query 231 | + " We can get the answer to this question in the given documents.", 232 | condition_in_question, 233 | ) 234 | - dl * 2 / 250 * 0 235 | for d, dl in zip(corpus, context_tokens_length) 236 | ] 237 | sort_direct = -1 if condition_in_question == "none" else 1 238 | ys = sorted(enumerate(context_ppl), key=lambda x: sort_direct * x[1]) 239 | return ys 240 | 241 | method = None 242 | if rank_method == "bm25": 243 | method = get_distance_bm25 244 | elif rank_method == "gzip": 245 | method = get_distance_gzip 246 | elif rank_method == "sentbert": 247 | method = get_distance_sentbert 248 | elif rank_method == "openai": 249 | method = get_distance_openai 250 | elif rank_method in ["longllmlingua", "llmlingua"]: 251 | method = get_distance_longllmlingua 252 | elif rank_method == "bge": 253 | method = get_distance_sentbert_bge 254 | elif rank_method == "bge_reranker": 255 | method = get_distance_bge_ranker 256 | elif rank_method == "bge_llmembedder": 257 | method = get_distance_bge_llmembedder 258 | elif rank_method == "jinza": 259 | method = get_distance_jinza 260 | elif rank_method == "voyageai": 261 | method = get_distance_voyageai 262 | elif rank_method == "cohere": 263 | method = get_distance_cohere 264 | return method(context, question) -------------------------------------------------------------------------------- /src/dataset.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import random 3 | from datasets import load_dataset 4 | from torch.utils.data import Dataset 5 | from torch.nn.utils.rnn import pad_sequence 6 | 7 | instructions_map = { 8 | 'base': 'Write a high-quality answer for the given question using only the provided search results(some of which might be irrelevant).\n\n', 9 | 'short': 'Using only the provided search results (some of which might be irrelevant), answer the following question with one or few words.\n\n', 10 | } 11 | 12 | def 
format_document(document, tokenizer, max_tokens): 13 | return tokenizer.decode( 14 | tokenizer( 15 | document['title'] + ' ' + document['text'] if 'title' in document else document['text'], 16 | add_special_tokens=False, 17 | )['input_ids'][:max_tokens] 18 | ) 19 | 20 | class TrainDataset(Dataset): 21 | def __init__( 22 | self, 23 | filepath, 24 | cmp_tokenizer, 25 | llm_tokenizer, 26 | max_doc_tokens, 27 | instruction_name, 28 | que_mask_ratio=None, 29 | max_num_documents=None, 30 | min_num_documents=None, 31 | random_num_documents=False, 32 | num_gold_documents=1, 33 | use_answer_as_target=False, 34 | gold_first_for_kd=False, 35 | **kwargs, 36 | ): 37 | self.dataset = load_dataset('json', data_files=filepath, split='train') 38 | self.max_doc_tokens = max_doc_tokens 39 | self.cmp_tokenizer = cmp_tokenizer 40 | self.llm_tokenizer = llm_tokenizer 41 | self.que_mask_ratio = que_mask_ratio 42 | self.max_num_documents = max_num_documents 43 | self.min_num_documents = min_num_documents 44 | self.random_num_documents = random_num_documents 45 | self.num_gold_documents = num_gold_documents 46 | self.use_answer_as_target = use_answer_as_target 47 | self.gold_first_for_kd = gold_first_for_kd 48 | 49 | self.llm_tokenizer.padding_side = 'left' 50 | if self.llm_tokenizer.pad_token is None: 51 | self.llm_tokenizer.pad_token = llm_tokenizer.unk_token 52 | self.llm_tokenizer.pad_token_id = llm_tokenizer.unk_token_id 53 | 54 | self.instruction_text = instructions_map[instruction_name] 55 | 56 | def __len__(self): 57 | return len(self.dataset) 58 | 59 | def __getitem__(self, index): 60 | example = self.dataset[index] 61 | question = example['question'] 62 | 63 | neg_documents = [ 64 | format_document(document, self.cmp_tokenizer, self.max_doc_tokens) 65 | for document in example['ctxs'] if document['isgold'] is False 66 | ] 67 | 68 | if len(neg_documents) > self.max_num_documents: 69 | neg_documents = random.sample(neg_documents, k = self.max_num_documents) 70 | else: 71 | random.shuffle(neg_documents) 72 | 73 | pos_documents = [ 74 | format_document(document, self.cmp_tokenizer, self.max_doc_tokens) 75 | for document in example['ctxs'] if document['isgold'] is True 76 | ] 77 | 78 | if len(pos_documents) > self.num_gold_documents: 79 | num_gold_documents = self.num_gold_documents 80 | if len(neg_documents) < self.max_num_documents: 81 | num_gold_documents = self.max_num_documents - len(neg_documents) 82 | pos_documents = random.sample(pos_documents, k = num_gold_documents) 83 | 84 | else: 85 | random.shuffle(pos_documents) 86 | 87 | if self.use_answer_as_target: 88 | appeared_answer_list = [] 89 | for answer in example['answers']: 90 | if answer in '\n\n'.join(pos_documents): 91 | appeared_answer_list.append(answer) 92 | 93 | target = random.choice( 94 | appeared_answer_list if appeared_answer_list != [] else example['answers'] 95 | ) 96 | else: 97 | target = example['target'] 98 | 99 | answers = example['answers'] 100 | 101 | return { 102 | 'question': question, 103 | 'neg_documents': neg_documents, 104 | 'pos_documents': pos_documents, 105 | 'target': target, 106 | 'answers': answers, 107 | } 108 | 109 | def collate_fn(self, batch): 110 | if len(batch) == 0: 111 | return {} 112 | 113 | enc_documents = [] 114 | llm_prefix_tokens = [] 115 | llm_documents = [] 116 | 117 | num_documents = ( 118 | random.randint(self.min_num_documents, self.max_num_documents) 119 | if self.random_num_documents else self.max_num_documents 120 | ) 121 | for instance in batch: 122 | instance_enc_documents = ['' + document for 
document in instance['pos_documents'] + instance['neg_documents']][:num_documents] 123 | random.shuffle(instance_enc_documents) 124 | enc_documents += instance_enc_documents 125 | 126 | llm_candiate_documents = instance['pos_documents'] + instance['neg_documents'] 127 | llm_candiate_documents = llm_candiate_documents[:num_documents] 128 | if not self.gold_first_for_kd: 129 | random.shuffle(llm_candiate_documents) 130 | 131 | llm_documents += [''.join(['\nDocument:' + document for document in llm_candiate_documents])] 132 | 133 | enc_questions = ['' + instance['question'] for instance in batch] 134 | llm_questions = ['\nQuestion:' + instance['question'] + '\nAnswer:' for instance in batch] 135 | llm_targets = [instance['target'] for instance in batch] 136 | llm_instructions = [self.instruction_text for _ in batch] 137 | answers = [instance['answers'] for instance in batch] 138 | 139 | llm_prefix_tokens = ['\nDocument:' for _ in enc_documents] 140 | enc_que_outputs = self.cmp_tokenizer(enc_questions, return_tensors='pt', padding=True, add_special_tokens=False) 141 | enc_doc_outputs = self.cmp_tokenizer(enc_documents, return_tensors='pt', padding=True, add_special_tokens=False) 142 | 143 | llm_ins_outputs = self.llm_tokenizer(llm_instructions, return_tensors='pt', padding=True) 144 | llm_doc_outputs = self.llm_tokenizer(llm_documents, return_tensors='pt', padding=True, add_special_tokens=False) 145 | llm_que_outputs = self.llm_tokenizer(llm_questions, return_tensors='pt', padding=True, add_special_tokens=False) 146 | llm_pfx_outputs = self.llm_tokenizer(llm_prefix_tokens, return_tensors='pt', padding=True, add_special_tokens=False) 147 | 148 | def right_padding(value, padding_value): 149 | padded_value = pad_sequence( 150 | [torch.tensor(v) for v in value], 151 | batch_first=True, 152 | padding_value=padding_value, 153 | ) 154 | return padded_value 155 | 156 | llm_tgt_outputs = [self.llm_tokenizer(ans, add_special_tokens=False).input_ids for ans in llm_targets] 157 | llm_tgt_tokens = right_padding(llm_tgt_outputs, self.llm_tokenizer.pad_token_id) 158 | llm_tgt_mask = right_padding([[1] * len(elem) for elem in llm_tgt_outputs], 0) 159 | 160 | if self.que_mask_ratio is not None and self.que_mask_ratio > 0: 161 | llm_que_tokens = llm_que_outputs.input_ids 162 | random_indices = torch.rand_like(llm_que_outputs.input_ids[:, :-2].float()).sort().indices 163 | mask_indices = random_indices[:, :int(self.que_mask_ratio * llm_que_tokens.size(1))] 164 | llm_que_outputs.input_ids = llm_que_tokens.scatter(1, mask_indices, self.llm_tokenizer.unk_token_id) 165 | 166 | return { 167 | 'enc_doc_tokens': enc_doc_outputs.input_ids, 168 | 'enc_que_tokens': enc_que_outputs.input_ids, 169 | 'enc_doc_mask': enc_doc_outputs.attention_mask, 170 | 'enc_que_mask': enc_que_outputs.attention_mask, 171 | 'llm_ins_tokens': llm_ins_outputs.input_ids, 172 | 'llm_doc_tokens': llm_doc_outputs.input_ids, 173 | 'llm_que_tokens': llm_que_outputs.input_ids, 174 | 'llm_ins_mask': llm_ins_outputs.attention_mask, 175 | 'llm_doc_mask': llm_doc_outputs.attention_mask, 176 | 'llm_que_mask': llm_que_outputs.attention_mask, 177 | 'llm_tgt_tokens': llm_tgt_tokens, 178 | 'llm_tgt_mask': llm_tgt_mask, 179 | 'llm_pfx_tokens': llm_pfx_outputs.input_ids, 180 | 'llm_pfx_mask': llm_pfx_outputs.attention_mask, 181 | 'answers': answers, 182 | } 183 | 184 | 185 | class InferDataset(Dataset): 186 | def __init__( 187 | self, 188 | filepath, 189 | cmp_tokenizer, 190 | llm_tokenizer, 191 | max_doc_tokens, 192 | instruction_name, 193 | 
max_num_documents=None, 194 | **kwargs, 195 | ): 196 | self.dataset = load_dataset('json', data_files=filepath, split='train') 197 | self.max_doc_tokens = max_doc_tokens 198 | self.cmp_tokenizer = cmp_tokenizer 199 | self.llm_tokenizer = llm_tokenizer 200 | self.max_num_documents = max_num_documents 201 | 202 | self.llm_tokenizer.padding_side = 'left' 203 | if self.llm_tokenizer.pad_token is None: 204 | self.llm_tokenizer.pad_token = llm_tokenizer.unk_token 205 | self.llm_tokenizer.pad_token_id = llm_tokenizer.unk_token_id 206 | 207 | self.instruction_text = instructions_map[instruction_name] 208 | 209 | 210 | def __len__(self): 211 | return len(self.dataset) 212 | 213 | 214 | def __getitem__(self, index): 215 | example = self.dataset[index] 216 | question = example['question'] 217 | 218 | documents = [ 219 | format_document(document, self.cmp_tokenizer, self.max_doc_tokens) 220 | for document in example['ctxs'][:self.max_num_documents] 221 | ] 222 | 223 | if len(documents) < self.max_num_documents: 224 | documents += ['\n' for _ in range(self.max_num_documents - len(documents))] 225 | 226 | answers = example['answers'] 227 | 228 | return { 229 | 'question': question, 230 | 'documents': documents, 231 | 'answers': answers, 232 | } 233 | 234 | 235 | def collate_fn(self, batch): 236 | if len(batch) == 0: 237 | return {} 238 | 239 | enc_documents = [] 240 | llm_prefix_tokens = [] 241 | for instance in batch: 242 | instance_documents = instance['documents'] 243 | instance_enc_docuemnts = ['' + document for document in instance_documents] 244 | enc_documents += instance_enc_docuemnts 245 | 246 | enc_questions = ['' + instance['question'] for instance in batch] 247 | llm_prefix_tokens = [f'\nDocument:' for instance in batch for _ in instance['documents']] 248 | llm_questions = ['\nQuestion:' + instance['question'] + '\nAnswer:' for instance in batch] 249 | llm_instructions = [self.instruction_text for _ in batch] 250 | answers = [instance['answers'] for instance in batch] 251 | 252 | enc_que_outputs = self.cmp_tokenizer(enc_questions, return_tensors='pt', padding=True, add_special_tokens=False) 253 | enc_doc_outputs = self.cmp_tokenizer(enc_documents, return_tensors='pt', padding=True, add_special_tokens=False) 254 | 255 | llm_ins_outputs = self.llm_tokenizer(llm_instructions, return_tensors='pt', padding=True) 256 | llm_que_outputs = self.llm_tokenizer(llm_questions, return_tensors='pt', padding=True, add_special_tokens=False) 257 | llm_pfx_outputs = self.llm_tokenizer(llm_prefix_tokens, return_tensors='pt', padding=True, add_special_tokens=False) 258 | 259 | return { 260 | 'enc_doc_tokens': enc_doc_outputs.input_ids, 261 | 'enc_que_tokens': enc_que_outputs.input_ids, 262 | 'enc_doc_mask': enc_doc_outputs.attention_mask, 263 | 'enc_que_mask': enc_que_outputs.attention_mask, 264 | 'llm_ins_tokens': llm_ins_outputs.input_ids, 265 | 'llm_que_tokens': llm_que_outputs.input_ids, 266 | 'llm_ins_mask': llm_ins_outputs.attention_mask, 267 | 'llm_que_mask': llm_que_outputs.attention_mask, 268 | 'llm_pfx_tokens': llm_pfx_outputs.input_ids, 269 | 'llm_pfx_mask': llm_pfx_outputs.attention_mask, 270 | 'answers': answers 271 | } -------------------------------------------------------------------------------- /src/model.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, List 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from transformers import LlamaForCausalLM, LlamaPreTrainedModel, LlamaModel 7 | from 
src.generation_utils import CCGenerationMixin 8 | 9 | 10 | def expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): 11 | bsz, src_len = mask.size() 12 | tgt_len = tgt_len if tgt_len is not None else src_len 13 | expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype) 14 | inverted_mask = 1.0 - expanded_mask 15 | return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min) 16 | 17 | 18 | class ModelWithQGC(LlamaPreTrainedModel, CCGenerationMixin): 19 | def __init__(self, args, compressor: LlamaModel, pooling_layer, lm_model: LlamaForCausalLM): 20 | super().__init__(lm_model.config) 21 | self.args = args 22 | 23 | self.compressor = compressor 24 | self.lm_model = lm_model 25 | self.pooling_layer = pooling_layer 26 | self.semantic_alignment_layer = nn.Linear(args.compressor_hidden_size, args.lm_model_hidden_size) 27 | 28 | for param in self.lm_model.parameters(): 29 | param.requires_grad = False 30 | 31 | if args.fix_compressor_mlp_parameters: 32 | for name, param in self.compressor.named_parameters(): 33 | if 'mlp' in name: 34 | param.requires_grad = False 35 | 36 | 37 | @property 38 | def llm_embed_tokens(self): 39 | return self.lm_model.get_input_embeddings() 40 | 41 | def context_encoder(self, input_ids, attention_mask): 42 | position_ids = attention_mask.long().cumsum(-1) - 1 43 | position_ids.masked_fill_(~attention_mask.bool(), 1) 44 | hidden_states = self.compressor.embed_tokens(input_ids) 45 | 46 | attention_mask = expand_mask( 47 | attention_mask, hidden_states.dtype, attention_mask.size(1) 48 | ).to(hidden_states.device) 49 | 50 | for idx in range(self.args.num_compressor_encoder_layers): 51 | layer = self.compressor.layers[idx] 52 | layer_outputs = layer( 53 | hidden_states, 54 | attention_mask=attention_mask, 55 | position_ids=position_ids, 56 | ) 57 | hidden_states = layer_outputs[0] 58 | return hidden_states 59 | 60 | def reviewing_layer(self, hidden_states, attention_mask): 61 | position_ids = attention_mask.long().cumsum(-1) - 1 62 | position_ids.masked_fill_(~attention_mask.bool(), 1) 63 | 64 | attention_mask = expand_mask( 65 | attention_mask, hidden_states.dtype, attention_mask.size(1) 66 | ).to(hidden_states.device) 67 | 68 | for idx in range(self.args.num_compressor_encoder_layers, self.args.num_compressor_layers): 69 | layer = self.compressor.layers[idx] 70 | layer_outputs = layer( 71 | hidden_states, 72 | attention_mask=attention_mask, 73 | position_ids=position_ids, 74 | ) 75 | hidden_states = layer_outputs[0] 76 | hidden_states = self.compressor.norm(hidden_states) 77 | return hidden_states 78 | 79 | def compress_doc( 80 | self, 81 | enc_doc_tokens, 82 | enc_doc_mask, 83 | enc_que_tokens, 84 | enc_que_mask, 85 | llm_pfx_tokens=None, 86 | llm_pfx_mask=None, 87 | **kwargs, 88 | ): 89 | doc_bsz = enc_doc_tokens.size(0) 90 | que_bsz = enc_que_tokens.size(0) 91 | 92 | if doc_bsz > que_bsz: 93 | repeat_n = doc_bsz // que_bsz 94 | enc_repeated_que_tokens = (enc_que_tokens[None, ...].repeat(repeat_n, 1, 1).view(-1, enc_que_tokens.size(1))) 95 | enc_repeated_que_mask = (enc_que_mask[None, ...].repeat(repeat_n, 1, 1).view(-1, enc_que_mask.size(1))) 96 | context_input_ids = torch.cat([enc_repeated_que_tokens, enc_doc_tokens], dim=1) 97 | context_attention_mask = torch.cat([enc_repeated_que_mask, enc_doc_mask], dim=1) 98 | else: 99 | enc_repeated_que_mask = enc_que_mask 100 | context_input_ids = torch.cat([enc_que_tokens, enc_doc_tokens], dim=1) 101 | context_attention_mask = 
torch.cat([enc_que_mask, enc_doc_mask], dim=1) 102 | 103 | context_hidden_states = self.context_encoder(context_input_ids, context_attention_mask) 104 | context_que_hidden_states = context_hidden_states[:, :enc_que_tokens.size(1)] 105 | context_doc_hidden_states = context_hidden_states[:, enc_que_tokens.size(1):] 106 | 107 | weighted_hidden_states, weighted_attention_mask = self.pooling_layer( 108 | que_hidden_states=context_que_hidden_states, 109 | doc_hidden_states=context_doc_hidden_states, 110 | enc_que_mask=enc_repeated_que_mask, 111 | enc_doc_mask=enc_doc_mask, 112 | **kwargs, 113 | ) 114 | 115 | reviewing_input_hidden_states = torch.cat([context_hidden_states, weighted_hidden_states], dim=1) 116 | reviewing_input_attention_mask = torch.cat([context_attention_mask, weighted_attention_mask], dim=1) 117 | 118 | reviewing_hidden_states = self.reviewing_layer(reviewing_input_hidden_states, reviewing_input_attention_mask) 119 | final_hidden_states = self.semantic_alignment_layer(reviewing_hidden_states[:, -weighted_hidden_states.size(1):]) 120 | 121 | prefix_embeds = self.llm_embed_tokens(llm_pfx_tokens) 122 | cmp_llm_hidden_states = torch.cat([prefix_embeds, final_hidden_states], dim=1) 123 | cmp_llm_attention_mask = torch.cat([llm_pfx_mask, weighted_attention_mask], dim=1) 124 | 125 | if doc_bsz > que_bsz: 126 | cmp_llm_hidden_states = cmp_llm_hidden_states.view(que_bsz, -1, cmp_llm_hidden_states.size(-1)) 127 | cmp_llm_attention_mask = cmp_llm_attention_mask.view(que_bsz, -1) 128 | 129 | return cmp_llm_hidden_states, cmp_llm_attention_mask 130 | 131 | 132 | def construct_llm_inputs( 133 | self, 134 | llm_ins_tokens, 135 | llm_ins_mask, 136 | cmp_llm_doc_embeds, 137 | cmp_llm_doc_mask, 138 | llm_que_tokens, 139 | llm_que_mask, 140 | llm_tgt_tokens, 141 | llm_tgt_mask, 142 | **kwargs, 143 | ): 144 | llm_inputs_embeds = torch.cat( 145 | [ 146 | self.llm_embed_tokens(llm_ins_tokens), 147 | cmp_llm_doc_embeds, 148 | self.llm_embed_tokens(llm_que_tokens), 149 | self.llm_embed_tokens(llm_tgt_tokens), 150 | ], 151 | dim=1, 152 | ) 153 | llm_attention_mask = torch.cat([llm_ins_mask, cmp_llm_doc_mask, llm_que_mask, llm_tgt_mask], dim=1) 154 | llm_position_ids = llm_attention_mask.long().cumsum(-1) - 1 155 | llm_position_ids.masked_fill_(~llm_attention_mask.bool(), 1) 156 | 157 | llm_labels = torch.full_like(llm_attention_mask, self.args.label_pad_token_id) 158 | llm_labels[:, -llm_tgt_tokens.size(1):] = llm_tgt_tokens.masked_fill( 159 | ~llm_tgt_mask.bool(), self.args.label_pad_token_id, 160 | ) 161 | 162 | return { 163 | 'inputs_embeds': llm_inputs_embeds, 164 | 'attention_mask': llm_attention_mask, 165 | 'position_ids': llm_position_ids, 166 | 'labels': llm_labels, 167 | } 168 | 169 | 170 | def construct_llm_inputs_for_generation( 171 | self, 172 | llm_ins_tokens, 173 | llm_ins_mask, 174 | cmp_llm_doc_embeds, 175 | cmp_llm_doc_mask, 176 | **kwargs, 177 | ): 178 | llm_inputs_embeds = torch.cat([self.llm_embed_tokens(llm_ins_tokens), cmp_llm_doc_embeds], dim=1) 179 | llm_attention_mask = torch.cat([llm_ins_mask, cmp_llm_doc_mask], dim=1) 180 | llm_position_ids = llm_attention_mask.long().cumsum(-1) - 1 181 | llm_position_ids.masked_fill_(llm_attention_mask == 0, 1) 182 | 183 | return { 184 | 'inputs_embeds': llm_inputs_embeds, 185 | 'attention_mask': llm_attention_mask, 186 | 'position_ids': llm_position_ids, 187 | } 188 | 189 | 190 | @torch.no_grad() 191 | def get_text_logits( 192 | self, 193 | llm_ins_tokens, llm_ins_mask, llm_doc_tokens, llm_doc_mask, 194 | llm_que_tokens, llm_que_mask, 
llm_tgt_tokens, llm_tgt_mask, 195 | ): 196 | text_llm_input_ids = torch.cat([llm_ins_tokens, llm_doc_tokens, llm_que_tokens, llm_tgt_tokens], dim=1) 197 | text_llm_attention_mask = torch.cat([llm_ins_mask, llm_doc_mask, llm_que_mask, llm_tgt_mask], dim=1) 198 | text_llm_position_ids = text_llm_attention_mask.long().cumsum(-1) - 1 199 | text_llm_position_ids.masked_fill_(text_llm_attention_mask == 0, 1) 200 | 201 | text_llm_outputs = self.lm_model( 202 | input_ids=text_llm_input_ids, 203 | attention_mask=text_llm_attention_mask, 204 | position_ids=text_llm_position_ids, 205 | ) 206 | text_logits = text_llm_outputs.logits[:, -llm_tgt_tokens.size(1):] 207 | return text_logits 208 | 209 | 210 | def joint_forward( 211 | self, 212 | enc_doc_tokens, 213 | enc_doc_mask, 214 | enc_que_tokens, 215 | enc_que_mask, 216 | llm_ins_tokens, 217 | llm_ins_mask, 218 | llm_doc_tokens, 219 | llm_doc_mask, 220 | llm_que_tokens, 221 | llm_que_mask, 222 | llm_tgt_tokens, 223 | llm_tgt_mask, 224 | llm_pfx_tokens=None, 225 | llm_pfx_mask=None, 226 | **kwargs, 227 | ): 228 | cmp_llm_doc_embeds, cmp_llm_doc_mask = self.compress_doc( 229 | enc_doc_tokens, enc_doc_mask, enc_que_tokens, enc_que_mask, 230 | llm_pfx_tokens=llm_pfx_tokens, llm_pfx_mask=llm_pfx_mask, 231 | ) 232 | 233 | text_logits = self.get_text_logits( 234 | llm_ins_tokens, llm_ins_mask, llm_doc_tokens, llm_doc_mask, 235 | llm_que_tokens, llm_que_mask, llm_tgt_tokens, llm_tgt_mask, 236 | ) 237 | 238 | embed_llm_inputs = self.construct_llm_inputs( 239 | llm_ins_tokens, llm_ins_mask, cmp_llm_doc_embeds, cmp_llm_doc_mask, 240 | llm_que_tokens, llm_que_mask, llm_tgt_tokens, llm_tgt_mask, 241 | ) 242 | embed_llm_output = self.lm_model(**embed_llm_inputs) 243 | embed_logits = embed_llm_output.logits[:, -llm_tgt_tokens.size(1):] 244 | 245 | distillation_loss = F.kl_div( 246 | F.log_softmax(embed_logits, dim=-1), 247 | F.softmax(text_logits, dim=-1), 248 | reduction='none', 249 | ) 250 | distillation_loss = distillation_loss.sum(dim=-1).masked_fill(~llm_tgt_mask.bool(), 0.0) 251 | distillation_loss = distillation_loss.sum() / llm_tgt_mask.sum() 252 | 253 | return embed_llm_output.loss, distillation_loss 254 | 255 | 256 | def prepare_inputs_for_generation( 257 | self, 258 | input_ids, 259 | past_key_values=None, 260 | attention_mask=None, 261 | inputs_embeds=None, 262 | position_ids=None, 263 | first_time=False, 264 | **kwargs, 265 | ): 266 | if past_key_values and not first_time: 267 | input_ids = input_ids[:, -1:] 268 | 269 | if attention_mask is not None and position_ids is None: 270 | position_ids = attention_mask.long().cumsum(-1) - 1 271 | position_ids.masked_fill_(attention_mask == 0, 1) 272 | 273 | if past_key_values: 274 | position_ids = position_ids[:, -input_ids.size(1):] 275 | 276 | if inputs_embeds is not None and past_key_values is None: 277 | model_inputs = {"inputs_embeds": inputs_embeds} 278 | else: 279 | model_inputs = {"input_ids": input_ids} 280 | 281 | model_inputs.update( 282 | { 283 | "past_key_values": past_key_values, 284 | "use_cache": kwargs.get("use_cache"), 285 | "attention_mask": attention_mask, 286 | "position_ids": position_ids, 287 | } 288 | ) 289 | return model_inputs 290 | 291 | 292 | def forward( 293 | self, 294 | input_ids: torch.LongTensor = None, 295 | attention_mask: Optional[torch.Tensor] = None, 296 | position_ids: Optional[torch.LongTensor] = None, 297 | past_key_values: Optional[List[torch.FloatTensor]] = None, 298 | inputs_embeds: Optional[torch.FloatTensor] = None, 299 | labels: Optional[torch.LongTensor] = None, 300 
| use_cache: Optional[bool] = None, 301 | output_attentions: Optional[bool] = None, 302 | output_hidden_states: Optional[bool] = None, 303 | return_dict: Optional[bool] = None, 304 | ): 305 | return self.lm_model( 306 | input_ids=input_ids, 307 | attention_mask=attention_mask, 308 | position_ids=position_ids, 309 | past_key_values=past_key_values, 310 | inputs_embeds=inputs_embeds, 311 | labels=labels, 312 | use_cache=use_cache, 313 | output_attentions=output_attentions, 314 | output_hidden_states=output_hidden_states, 315 | return_dict=return_dict, 316 | ) -------------------------------------------------------------------------------- /src/generation_utils.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | from typing import Optional, Union, List 3 | 4 | import torch 5 | import torch.distributed as dist 6 | 7 | from transformers.generation.logits_process import LogitsProcessorList 8 | from transformers.generation.stopping_criteria import ( 9 | StoppingCriteriaList, 10 | validate_stopping_criteria, 11 | ) 12 | from transformers.generation.utils import ( 13 | GenerationMixin, 14 | GreedySearchDecoderOnlyOutput, 15 | GreedySearchEncoderDecoderOutput, 16 | GreedySearchOutput, 17 | ) 18 | 19 | class CCGenerationMixin(GenerationMixin): 20 | def greedy_search( 21 | self, 22 | input_ids: torch.LongTensor, 23 | logits_processor: Optional[LogitsProcessorList] = None, 24 | stopping_criteria: Optional[StoppingCriteriaList] = None, 25 | max_length: Optional[int] = None, 26 | pad_token_id: Optional[int] = None, 27 | eos_token_id: Optional[Union[int, List[int]]] = None, 28 | output_attentions: Optional[bool] = None, 29 | output_hidden_states: Optional[bool] = None, 30 | output_scores: Optional[bool] = None, 31 | return_dict_in_generate: Optional[bool] = None, 32 | synced_gpus: Optional[bool] = False, 33 | **model_kwargs, 34 | ) -> Union[GreedySearchOutput, torch.LongTensor]: 35 | r""" 36 | NOTE: The only change between the huggingface greedy search and this 37 | function is the introduction of a "first_time" variable that doesn't 38 | truncate input ids when kv seqs are passed for the first time (for gist 39 | caching). 40 | 41 | Generates sequences of token ids for models with a language modeling head using **greedy decoding** and can be 42 | used for text-decoder, text-to-text, speech-to-text, and vision-to-text models. 43 | 44 | 45 | 46 | In most cases, you do not need to call [`~generation.GenerationMixin.greedy_search`] directly. Use generate() 47 | instead. For an overview of generation strategies and code examples, check the [following 48 | guide](../generation_strategies). 49 | 50 | 51 | 52 | 53 | Parameters: 54 | input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): 55 | The sequence used as a prompt for the generation. 56 | logits_processor (`LogitsProcessorList`, *optional*): 57 | An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`] 58 | used to modify the prediction scores of the language modeling head applied at each generation step. 59 | stopping_criteria (`StoppingCriteriaList`, *optional*): 60 | An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`] 61 | used to tell if the generation loop should stop. 62 | 63 | max_length (`int`, *optional*, defaults to 20): 64 | **DEPRECATED**. Use `logits_processor` or `stopping_criteria` directly to cap the number of generated 65 | tokens. 
The maximum length of the sequence to be generated. 66 | pad_token_id (`int`, *optional*): 67 | The id of the *padding* token. 68 | eos_token_id (`Union[int, List[int]]`, *optional*): 69 | The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens. 70 | output_attentions (`bool`, *optional*, defaults to `False`): 71 | Whether or not to return the attentions tensors of all attention layers. See `attentions` under 72 | returned tensors for more details. 73 | output_hidden_states (`bool`, *optional*, defaults to `False`): 74 | Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors 75 | for more details. 76 | output_scores (`bool`, *optional*, defaults to `False`): 77 | Whether or not to return the prediction scores. See `scores` under returned tensors for more details. 78 | return_dict_in_generate (`bool`, *optional*, defaults to `False`): 79 | Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. 80 | synced_gpus (`bool`, *optional*, defaults to `False`): 81 | Whether to continue running the while loop until max_length (needed for ZeRO stage 3) 82 | model_kwargs: 83 | Additional model specific keyword arguments will be forwarded to the `forward` function of the model. 84 | If model is an encoder-decoder model the kwargs should include `encoder_outputs`. 85 | 86 | Return: 87 | [`~generation.GreedySearchDecoderOnlyOutput`], [`~generation.GreedySearchEncoderDecoderOutput`] or 88 | `torch.LongTensor`: A `torch.LongTensor` containing the generated tokens (default behaviour) or a 89 | [`~generation.GreedySearchDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and 90 | `return_dict_in_generate=True` or a [`~generation.GreedySearchEncoderDecoderOutput`] if 91 | `model.config.is_encoder_decoder=True`. 92 | 93 | Examples: 94 | 95 | ```python 96 | >>> from transformers import ( 97 | ... AutoTokenizer, 98 | ... AutoModelForCausalLM, 99 | ... LogitsProcessorList, 100 | ... MinLengthLogitsProcessor, 101 | ... StoppingCriteriaList, 102 | ... MaxLengthCriteria, 103 | ... ) 104 | 105 | >>> tokenizer = AutoTokenizer.from_pretrained("gpt2") 106 | >>> model = AutoModelForCausalLM.from_pretrained("gpt2") 107 | 108 | >>> # set pad_token_id to eos_token_id because GPT2 does not have a PAD token 109 | >>> model.generation_config.pad_token_id = model.generation_config.eos_token_id 110 | 111 | >>> input_prompt = "It might be possible to" 112 | >>> input_ids = tokenizer(input_prompt, return_tensors="pt").input_ids 113 | 114 | >>> # instantiate logits processors 115 | >>> logits_processor = LogitsProcessorList( 116 | ... [ 117 | ... MinLengthLogitsProcessor(10, eos_token_id=model.generation_config.eos_token_id), 118 | ... ] 119 | ... ) 120 | >>> stopping_criteria = StoppingCriteriaList([MaxLengthCriteria(max_length=20)]) 121 | 122 | >>> outputs = model.greedy_search( 123 | ... input_ids, logits_processor=logits_processor, stopping_criteria=stopping_criteria 124 | ... 
) 125 | 126 | >>> tokenizer.batch_decode(outputs, skip_special_tokens=True) 127 | ["It might be possible to get a better understanding of the nature of the problem, but it's not"] 128 | ```""" 129 | # init values 130 | logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() 131 | stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() 132 | if max_length is not None: 133 | warnings.warn( 134 | "`max_length` is deprecated in this function, use" 135 | " `stopping_criteria=StoppingCriteriaList([MaxLengthCriteria(max_length=max_length)])` instead.", 136 | UserWarning, 137 | ) 138 | stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length) 139 | pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id 140 | eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id 141 | if isinstance(eos_token_id, int): 142 | eos_token_id = [eos_token_id] 143 | eos_token_id_tensor = torch.tensor(eos_token_id).to(input_ids.device) if eos_token_id is not None else None 144 | output_scores = output_scores if output_scores is not None else self.generation_config.output_scores 145 | output_attentions = ( 146 | output_attentions if output_attentions is not None else self.generation_config.output_attentions 147 | ) 148 | output_hidden_states = ( 149 | output_hidden_states if output_hidden_states is not None else self.generation_config.output_hidden_states 150 | ) 151 | return_dict_in_generate = ( 152 | return_dict_in_generate 153 | if return_dict_in_generate is not None 154 | else self.generation_config.return_dict_in_generate 155 | ) 156 | 157 | # init attention / hidden states / scores tuples 158 | scores = () if (return_dict_in_generate and output_scores) else None 159 | decoder_attentions = () if (return_dict_in_generate and output_attentions) else None 160 | cross_attentions = () if (return_dict_in_generate and output_attentions) else None 161 | decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None 162 | 163 | # if model is an encoder-decoder, retrieve encoder attention weights and hidden states 164 | if return_dict_in_generate and self.config.is_encoder_decoder: 165 | encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None 166 | encoder_hidden_states = ( 167 | model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None 168 | ) 169 | 170 | # keep track of which sequences are already finished 171 | unfinished_sequences = torch.ones(input_ids.shape[0], dtype=torch.long, device=input_ids.device) 172 | 173 | this_peer_finished = False # used by synced_gpus only 174 | first_time = True 175 | while True: 176 | if synced_gpus: 177 | # Under synced_gpus the `forward` call must continue until all gpus complete their sequence. 178 | # The following logic allows an early break if all peers finished generating their sequence 179 | this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device) 180 | # send 0.0 if we finished, 1.0 otherwise 181 | dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM) 182 | # did all peers finish? 
the reduced sum will be 0.0 then 183 | if this_peer_finished_flag.item() == 0.0: 184 | break 185 | 186 | # prepare model inputs 187 | model_inputs = self.prepare_inputs_for_generation(input_ids, first_time=first_time, **model_kwargs) 188 | first_time = False 189 | 190 | # forward pass to get next token 191 | outputs = self( 192 | **model_inputs, 193 | return_dict=True, 194 | output_attentions=output_attentions, 195 | output_hidden_states=output_hidden_states, 196 | ) 197 | 198 | if synced_gpus and this_peer_finished: 199 | continue # don't waste resources running the code we don't need 200 | 201 | next_token_logits = outputs.logits[:, -1, :] 202 | 203 | # pre-process distribution 204 | next_tokens_scores = logits_processor(input_ids, next_token_logits) 205 | 206 | # Store scores, attentions and hidden_states when required 207 | if return_dict_in_generate: 208 | if output_scores: 209 | scores += (next_tokens_scores,) 210 | if output_attentions: 211 | decoder_attentions += ( 212 | (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,) 213 | ) 214 | if self.config.is_encoder_decoder: 215 | cross_attentions += (outputs.cross_attentions,) 216 | 217 | if output_hidden_states: 218 | decoder_hidden_states += ( 219 | (outputs.decoder_hidden_states,) 220 | if self.config.is_encoder_decoder 221 | else (outputs.hidden_states,) 222 | ) 223 | 224 | # argmax 225 | next_tokens = torch.argmax(next_tokens_scores, dim=-1) 226 | 227 | # finished sentences should have their next token be a padding token 228 | if eos_token_id is not None: 229 | if pad_token_id is None: 230 | raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.") 231 | next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences) 232 | 233 | # update generated ids, model inputs, and length for next step 234 | input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1) 235 | model_kwargs = self._update_model_kwargs_for_generation( 236 | outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder 237 | ) 238 | 239 | # if eos_token was found in one sentence, set sentence to finished 240 | if eos_token_id_tensor is not None: 241 | unfinished_sequences = unfinished_sequences.mul( 242 | next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0) 243 | ) 244 | 245 | # stop when each sentence is finished, or if we exceed the maximum length 246 | if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores): 247 | if not synced_gpus: 248 | break 249 | else: 250 | this_peer_finished = True 251 | 252 | if return_dict_in_generate: 253 | if self.config.is_encoder_decoder: 254 | return GreedySearchEncoderDecoderOutput( 255 | sequences=input_ids, 256 | scores=scores, 257 | encoder_attentions=encoder_attentions, 258 | encoder_hidden_states=encoder_hidden_states, 259 | decoder_attentions=decoder_attentions, 260 | cross_attentions=cross_attentions, 261 | decoder_hidden_states=decoder_hidden_states, 262 | ) 263 | else: 264 | return GreedySearchDecoderOnlyOutput( 265 | sequences=input_ids, 266 | scores=scores, 267 | attentions=decoder_attentions, 268 | hidden_states=decoder_hidden_states, 269 | ) 270 | else: 271 | return input_ids 272 | -------------------------------------------------------------------------------- /src/train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import jsonlines 3 | from tqdm.auto import tqdm 4 | 5 | import 
torch 6 | import transformers 7 | from accelerate import Accelerator 8 | 9 | from src.args import QGCArguments 10 | from src.model import ModelWithQGC 11 | from src.dataset import TrainDataset 12 | from src.utils.constant import * 13 | from src.utils.logger import get_logger 14 | from src.utils.metrics import benchmark_function_map 15 | from pooling_layers import PoolingLayer 16 | 17 | logger = get_logger(__name__) 18 | def get_model_param_count(model, trainable_only): 19 | param_count = 0 20 | for param in model.parameters(): 21 | if not trainable_only or param.requires_grad: 22 | param_count += param.numel() 23 | return param_count 24 | 25 | def load_dataloader(args: QGCArguments, cmp_tokenizer, llm_tokenizer, split): 26 | if hasattr(args, f'{split}_data_path') and getattr(args, f'{split}_data_path') != None: 27 | filepath = getattr(args, f'{split}_data_path') 28 | else: 29 | filepath = os.path.join(args.data_path, f'{split}.jsonl') 30 | if not os.path.isfile(filepath): 31 | return None 32 | 33 | is_training = split == 'train' 34 | logger.info(f'load {split} data from {filepath}') 35 | 36 | dataset = TrainDataset( 37 | filepath=filepath, 38 | cmp_tokenizer=cmp_tokenizer, 39 | llm_tokenizer=llm_tokenizer, 40 | max_doc_tokens=args.max_doc_tokens, 41 | que_mask_ratio=args.question_mask_ratio if is_training else None, 42 | max_num_documents=args.max_num_documents, 43 | min_num_documents=args.min_num_documents, 44 | random_num_documents=args.random_num_documents, 45 | num_gold_documents=args.num_gold_documents, 46 | use_answer_as_target=args.use_answer_as_target, 47 | instruction_name=args.instruction_name, 48 | gold_first_for_kd=args.gold_first_for_kd, 49 | ) 50 | 51 | dataloader = torch.utils.data.DataLoader( 52 | dataset, 53 | batch_size=args.train_batch_size if is_training else args.eval_batch_size, 54 | shuffle=is_training, 55 | collate_fn=dataset.collate_fn, 56 | ) 57 | return dataloader 58 | 59 | 60 | def main(args: QGCArguments): 61 | transformers.trainer_utils.set_seed(args.seed) 62 | 63 | logger.info('load tokenizer ...') 64 | cmp_tokenizer = transformers.AutoTokenizer.from_pretrained(args.compressor_path) 65 | llm_tokenizer = transformers.AutoTokenizer.from_pretrained(args.lm_model_path) 66 | cmp_tokenizer.pad_token = cmp_tokenizer.unk_token 67 | 68 | additional_special_tokens = ['', '', ''] 69 | cmp_tokenizer.add_special_tokens({"additional_special_tokens": additional_special_tokens}) 70 | 71 | logger.info('load dataset ...') 72 | train_dataloader = load_dataloader(args, cmp_tokenizer, llm_tokenizer, 'train') 73 | dev_dataloader = load_dataloader(args, cmp_tokenizer, llm_tokenizer, 'dev') 74 | test_dataloader = load_dataloader(args, cmp_tokenizer, llm_tokenizer, 'test') 75 | 76 | if args.generation_split_token is None: 77 | args.generation_split_token = ( 78 | '' if args.lm_model_name == 'longchat' or 'natural_questions' in args.data_path 79 | else '\n\n' 80 | ) 81 | 82 | accelerator = Accelerator() 83 | device = accelerator.device 84 | 85 | logger.info('load compressor ...') 86 | compressor_config = transformers.AutoConfig.from_pretrained(args.compressor_path) 87 | compressor_config.num_hidden_layers = args.num_compressor_layers 88 | compressor = transformers.LlamaModel.from_pretrained(args.compressor_path, config=compressor_config) 89 | compressor.resize_token_embeddings(len(cmp_tokenizer)) 90 | with torch.no_grad(): 91 | compressor.get_input_embeddings().weight[-len(additional_special_tokens):] \ 92 | = 
compressor.get_input_embeddings().weight[:-len(additional_special_tokens)].mean(dim=0) 93 | 94 | logger.info('load lm_model ...') 95 | if args.lm_model_name == 'longchat': 96 | llm_config = transformers.AutoConfig.from_pretrained(args.lm_model_path) 97 | llm_config._flash_attn_2_enabled = True 98 | llm_config.use_cache = False 99 | from src.utils.llama_utils import replace_llama_with_condense 100 | replace_llama_with_condense(8) 101 | lm_model = transformers.LlamaForCausalLM.from_pretrained(args.lm_model_path, config=llm_config) 102 | 103 | elif args.lm_model_name == 'llama': 104 | llm_config = transformers.AutoConfig.from_pretrained(args.lm_model_path) 105 | llm_config._flash_attn_2_enabled = True 106 | llm_config.use_cache = False 107 | lm_model = transformers.LlamaForCausalLM.from_pretrained(args.lm_model_path, config=llm_config) 108 | 109 | else: 110 | raise NotImplementedError(args.lm_model_name) 111 | 112 | pooling_layer = PoolingLayer(args) 113 | model = ModelWithQGC(args, compressor=compressor, pooling_layer=pooling_layer, lm_model=lm_model) 114 | 115 | max_steps = args.max_steps // accelerator.num_processes 116 | num_examples = len(train_dataloader) 117 | total_train_batch_size = args.train_batch_size * args.gradient_accumulation_steps 118 | num_update_steps_per_epoch = num_examples // args.gradient_accumulation_steps 119 | num_update_steps_per_epoch = max(num_update_steps_per_epoch, 1) 120 | num_train_epochs = max_steps // num_update_steps_per_epoch + int( 121 | max_steps % num_update_steps_per_epoch > 0 122 | ) 123 | 124 | optimizer = torch.optim.AdamW(model.parameters(), lr=args.learning_rate) 125 | lr_scheduler = transformers.get_scheduler( 126 | args.lr_scheduler_type, optimizer=optimizer, 127 | num_warmup_steps=args.get_warmup_steps(max_steps), num_training_steps=max_steps, 128 | ) 129 | 130 | if args.from_checkpoint is not None: 131 | logger.info(f'load model checkpoint from {args.from_checkpoint}') 132 | model.semantic_alignment_layer.load_state_dict(torch.load(os.path.join(args.from_checkpoint, FFN_WEIGHTS_NAME), map_location='cpu')) 133 | model.compressor.load_state_dict(torch.load(os.path.join(args.from_checkpoint, COMPRESSOR_WEIGHTS_NAME), map_location='cpu')) 134 | model.pooling_layer.load_state_dict(torch.load(os.path.join(args.from_checkpoint, POOLING_WEIGHTS_NAME), map_location='cpu')) 135 | 136 | 137 | model, optimizer, train_dataloader, dev_dataloader, test_dataloader = \ 138 | accelerator.prepare(model, optimizer, train_dataloader, dev_dataloader, test_dataloader) 139 | 140 | logger.info("***** Running training *****") 141 | logger.info(f" Num examples = {num_examples:,}") 142 | logger.info(f" Num Epochs = {num_train_epochs:,}") 143 | logger.info(f" Instantaneous batch size per device = {args.train_batch_size:,}") 144 | logger.info(f" Total train batch size (w. 
parallel, distributed & accumulation) = {total_train_batch_size:,}") 145 | logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") 146 | logger.info(f" Total optimization steps = {max_steps:,}") 147 | logger.info(f" Number of trainable parameters = {get_model_param_count(model, trainable_only=True):,}") 148 | logger.info(f" Model Structure = {model}") 149 | 150 | global_step = 0 151 | current_step = 0 152 | global_epoch = 0 153 | global_step_last_logged = 0 154 | tr_loss = torch.tensor(0.0).to(device) 155 | tr_ce_loss = torch.tensor(0.0).to(device) 156 | tr_kd_loss = torch.tensor(0.0).to(device) 157 | 158 | def training_step(model, inputs): 159 | model.train() 160 | ce_loss, kd_loss = model.joint_forward(**inputs) 161 | loss = ce_loss + kd_loss 162 | accelerator.backward(loss) 163 | 164 | return [ 165 | loss_item / args.gradient_accumulation_steps 166 | for loss_item in [loss, ce_loss, kd_loss] 167 | ] 168 | 169 | @torch.no_grad() 170 | def prediction_step(model, inputs): 171 | model.eval() 172 | ce_loss, kd_loss = model.joint_forward(**inputs) 173 | loss = ce_loss + kd_loss 174 | return loss, ce_loss, kd_loss 175 | 176 | @torch.no_grad() 177 | def benchmark_step(model, inputs): 178 | model.eval() 179 | benchmark_answers = inputs['answers'] 180 | 181 | cmp_llm_doc_embeds, cmp_llm_doc_mask = model.compress_doc(**inputs, window_size=args.pool_window_size) 182 | first_llm_inputs = model.construct_llm_inputs_for_generation( 183 | **inputs, 184 | cmp_llm_doc_embeds=cmp_llm_doc_embeds, 185 | cmp_llm_doc_mask=cmp_llm_doc_mask, 186 | ) 187 | first_llm_outputs = model(**first_llm_inputs, use_cache=True) 188 | 189 | second_llm_inputs = { 190 | 'input_ids': inputs['llm_que_tokens'], 191 | 'attention_mask': torch.cat([first_llm_inputs['attention_mask'], inputs['llm_que_mask']], dim=1), 192 | 'past_key_values': first_llm_outputs.past_key_values, 193 | } 194 | outputs = model.generate(**second_llm_inputs, do_sample=False, max_new_tokens=args.max_new_tokens, use_cache=True) 195 | 196 | raw_generations = [ 197 | element.split('Answer:')[1] for element in llm_tokenizer.batch_decode(outputs) 198 | ] 199 | 200 | # longchat use to represent end, but llama use \n\n 201 | generations = [elem.strip().split(args.generation_split_token)[0] for elem in raw_generations] 202 | benchmark_function = benchmark_function_map[args.benchmark_metric] 203 | score_values = [benchmark_function(generation, answer) for generation, answer in zip(generations, benchmark_answers)] 204 | scores = torch.tensor(score_values, device=device) 205 | 206 | benchmark_outputs = [ 207 | { 208 | 'question': question, 209 | 'raw_generation': raw_generation, 210 | 'ext_generation': ext_generation, 211 | 'answers': answers, 212 | 'score': score 213 | } 214 | for question, raw_generation, ext_generation, answers, score in zip( 215 | llm_tokenizer.batch_decode(inputs['llm_que_tokens']), 216 | llm_tokenizer.batch_decode(outputs), 217 | generations, 218 | benchmark_answers, 219 | score_values, 220 | ) 221 | ] 222 | return scores, benchmark_outputs 223 | 224 | def evaluate(model, dataloader, prefix='eval'): 225 | evaluate_bar = tqdm( 226 | total=len(dataloader), leave=True, dynamic_ncols=True, 227 | disable=not accelerator.is_main_process, desc='evaluate' 228 | ) 229 | model.eval() 230 | losses_host = () 231 | ce_losses_host = () 232 | kd_losses_host = () 233 | for inputs in dataloader: 234 | bsz = inputs['llm_tgt_tokens'].size(0) 235 | loss, ce_loss, kd_loss = prediction_step(model, inputs) 236 | losses_host += 
(accelerator.gather_for_metrics(loss.repeat(bsz)),) 237 | ce_losses_host += (accelerator.gather_for_metrics(ce_loss.repeat(bsz)),) 238 | kd_losses_host += (accelerator.gather_for_metrics(kd_loss.repeat(bsz)),) 239 | evaluate_bar.update(1) 240 | 241 | evaluate_bar.close() 242 | eval_loss = torch.cat(losses_host, dim=0) 243 | eval_ce_loss = torch.cat(ce_losses_host, dim=0) 244 | eval_kd_loss = torch.cat(kd_losses_host, dim=0) 245 | return { 246 | f'{prefix}_loss': round(eval_loss.mean().item(), 4), 247 | f'{prefix}_ce_loss': round(eval_ce_loss.mean().item(), 4), 248 | f'{prefix}_kd_loss': round(eval_kd_loss.mean().item(), 4), 249 | } 250 | 251 | def benchmark(model, dataloader, prefix='benchmark'): 252 | benchmark_bar = tqdm( 253 | total=len(dataloader), leave=True, dynamic_ncols=True, 254 | disable=not accelerator.is_main_process, desc='benchmark' 255 | ) 256 | model.eval() 257 | scores_host = () 258 | outputs_host = [] 259 | for inputs in dataloader: 260 | scores, outputs = benchmark_step(model, inputs) 261 | scores_host += (accelerator.gather_for_metrics(scores),) 262 | outputs_host += outputs 263 | benchmark_bar.update(1) 264 | 265 | benchmark_bar.close() 266 | mean_scores = torch.cat(scores_host, dim=0).mean() 267 | return [ 268 | { 269 | f'{prefix}_score': round(mean_scores.item(), 4), 270 | }, 271 | outputs_host, 272 | ] 273 | 274 | def save_checkpoint(model, dev_benchmark_outputs=None, test_benchmark_outputs=None): 275 | checkpoint_folder = f'{PREFIX_CHECKPOINT_DIR}-{global_step}' 276 | output_dir = os.path.join(args.output_dir, checkpoint_folder) 277 | os.makedirs(output_dir, exist_ok=True) 278 | logger.info(f"Saving model checkpoint to {output_dir}") 279 | torch.save(args, os.path.join(output_dir, TRAINING_ARGS_NAME)) 280 | 281 | compressor_state_dict = accelerator.get_state_dict(model.compressor) 282 | pooling_layer_state_dict = accelerator.get_state_dict(model.pooling_layer) 283 | ffn_state_dict = accelerator.get_state_dict(model.semantic_alignment_layer) 284 | torch.save(compressor_state_dict, os.path.join(output_dir, COMPRESSOR_WEIGHTS_NAME)) 285 | torch.save(pooling_layer_state_dict, os.path.join(output_dir, POOLING_WEIGHTS_NAME)) 286 | torch.save(ffn_state_dict, os.path.join(output_dir, FFN_WEIGHTS_NAME)) 287 | 288 | if accelerator.is_main_process: 289 | torch.save(optimizer.state_dict(), os.path.join(output_dir, OPTIMIZER_NAME)) 290 | torch.save(lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME)) 291 | 292 | if dev_benchmark_outputs is not None: 293 | with jsonlines.open(os.path.join(output_dir, f'dev_benchmark.{accelerator.process_index}.jsonl'), 'w') as fw: 294 | for element in dev_benchmark_outputs: 295 | fw.write(element) 296 | 297 | if test_benchmark_outputs is not None: 298 | with jsonlines.open(os.path.join(output_dir, f'test_benchmark.{accelerator.process_index}.jsonl'), 'w') as fw: 299 | for element in test_benchmark_outputs: 300 | fw.write(element) 301 | 302 | model.train() 303 | model.zero_grad() 304 | total_batched_samples = 0 305 | trainning_bar = tqdm(total=max_steps, dynamic_ncols=True, disable=not accelerator.is_main_process, desc='train') 306 | for epoch in range(num_train_epochs): 307 | epoch_iterator = train_dataloader 308 | steps_in_epoch = len(epoch_iterator) 309 | 310 | step = -1 311 | for step, inputs in enumerate(epoch_iterator): 312 | total_batched_samples += 1 313 | with accelerator.accumulate(model): 314 | tr_loss_step, tr_ce_loss_step, tr_kd_loss_step = training_step(model, inputs) 315 | 316 | tr_loss += tr_loss_step 317 | 
tr_ce_loss += tr_ce_loss_step 318 | tr_kd_loss += tr_kd_loss_step 319 | 320 | if total_batched_samples % args.gradient_accumulation_steps == 0: 321 | if args.max_grad_norm is not None: 322 | accelerator.clip_grad_norm_(model.parameters(), args.max_grad_norm) 323 | 324 | optimizer.step() 325 | lr_scheduler.step() 326 | model.zero_grad() 327 | 328 | global_epoch = epoch + (step + 1) / steps_in_epoch 329 | global_step += 1 330 | trainning_bar.update(global_step - current_step) 331 | current_step = global_step 332 | 333 | logs = { 334 | 'step': global_step, 335 | 'epoch': round(global_epoch, 2), 336 | } 337 | if global_step % args.logging_steps == 0: 338 | tr_loss_scalar = accelerator.gather(tr_loss).mean().item() 339 | tr_ce_loss_scalar = accelerator.gather(tr_ce_loss).mean().item() 340 | tr_kd_loss_scalar = accelerator.gather(tr_kd_loss).mean().item() 341 | 342 | tr_loss -= tr_loss 343 | tr_ce_loss -= tr_ce_loss 344 | tr_kd_loss -= tr_kd_loss 345 | 346 | logs.update( 347 | { 348 | 'loss': round(tr_loss_scalar / (global_step - global_step_last_logged), 4), 349 | 'ce_loss': round(tr_ce_loss_scalar / (global_step - global_step_last_logged), 4), 350 | 'kd_loss': round(tr_kd_loss_scalar / (global_step - global_step_last_logged), 4), 351 | 'lr': round(lr_scheduler.get_last_lr()[0], 6), 352 | } 353 | ) 354 | global_step_last_logged = global_step 355 | logger.info(logs) 356 | 357 | base_metrics = { 358 | 'step': global_step, 359 | 'epoch': round(global_epoch, 2), 360 | } 361 | if dev_dataloader is not None and global_step % args.dev_steps == 0: 362 | dev_metrics = evaluate(model, dev_dataloader, prefix='dev') 363 | dev_metrics.update(base_metrics) 364 | logger.info(dev_metrics) 365 | 366 | if test_dataloader is not None and global_step % args.test_steps == 0: 367 | test_metrics = evaluate(model, test_dataloader, prefix='test') 368 | test_metrics.update(base_metrics) 369 | logger.info(test_metrics) 370 | 371 | dev_benchmark_outputs = None 372 | if dev_dataloader is not None and args.do_benchmark and global_step % args.benchmark_dev_steps == 0: 373 | dev_benchmark_metrics, dev_benchmark_outputs = benchmark(model, dev_dataloader, prefix='dev') 374 | dev_benchmark_metrics.update(base_metrics) 375 | logger.info(dev_benchmark_metrics) 376 | 377 | test_benchmark_outputs = None 378 | if test_dataloader is not None and args.do_benchmark and global_step % args.benchmark_test_steps == 0: 379 | test_benchmark_metrics, test_benchmark_outputs = benchmark(model, test_dataloader, prefix='test') 380 | test_benchmark_metrics.update(base_metrics) 381 | logger.info(test_benchmark_metrics) 382 | 383 | if global_step % args.save_steps == 0: 384 | save_checkpoint(model, dev_benchmark_outputs=dev_benchmark_outputs, test_benchmark_outputs=test_benchmark_outputs) 385 | 386 | if global_step > max_steps: 387 | trainning_bar.close() 388 | break 389 | 390 | logger.info("\n\nTraining completed. =)\n\n") 391 | save_checkpoint(model) 392 | 393 | 394 | if __name__ == '__main__': 395 | parser = transformers.HfArgumentParser(QGCArguments) 396 | args = parser.parse_args_into_dataclasses()[0] 397 | logger.info(args) 398 | main(args) 399 | --------------------------------------------------------------------------------
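The distillation term optimized in `ModelWithQGC.joint_forward` (src/model.py) is a token-level KL divergence between the target LLM's output distribution on the original full-text prompt and its distribution on the compressed prompt, masked to the target tokens. The snippet below is a minimal, self-contained sketch of that term; the shapes and random tensors are illustrative placeholders, not values from the repository.

```
import torch
import torch.nn.functional as F

# Illustrative sizes only; the real model uses the LLM vocabulary and the batch sizes from args.
batch_size, tgt_len, vocab_size = 2, 5, 320

# "Teacher" logits: the LLM fed the original, uncompressed documents (get_text_logits).
text_logits = torch.randn(batch_size, tgt_len, vocab_size)
# "Student" logits: the LLM fed the compressed document embeddings (compress_doc + construct_llm_inputs).
embed_logits = torch.randn(batch_size, tgt_len, vocab_size)

# 1 for real target tokens, 0 for padding.
llm_tgt_mask = torch.tensor([[1, 1, 1, 0, 0],
                             [1, 1, 1, 1, 1]])

# Per-position KL divergence; F.kl_div expects log-probabilities as input and probabilities as target.
kd = F.kl_div(
    F.log_softmax(embed_logits, dim=-1),
    F.softmax(text_logits, dim=-1),
    reduction='none',
)
# Sum over the vocabulary, zero out padded positions, and average over the real target tokens.
kd = kd.sum(dim=-1).masked_fill(~llm_tgt_mask.bool(), 0.0)
kd_loss = kd.sum() / llm_tgt_mask.sum()
print(kd_loss)
```

In `training_step` (src/train.py) this term is simply added to the cross-entropy loss returned by the LLM on the compressed inputs, `loss = ce_loss + kd_loss`, before `accelerator.backward(loss)`.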