├── utils ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-38.pyc │ ├── __init__.cpython-310.pyc │ ├── common_utils.cpython-310.pyc │ └── common_utils.cpython-38.pyc └── common_utils.py ├── dataset ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-310.pyc │ ├── __init__.cpython-38.pyc │ ├── dataset.cpython-310.pyc │ └── dataset.cpython-38.pyc └── dataset.py ├── model ├── __init__.py ├── __pycache__ │ ├── pma.cpython-38.pyc │ ├── pma.cpython-310.pyc │ ├── __init__.cpython-38.pyc │ ├── __init__.cpython-310.pyc │ ├── gpt4o_mini.cpython-310.pyc │ ├── gpt4o_mini.cpython-38.pyc │ ├── encoder_model_bert.cpython-38.pyc │ ├── eval_model_for_70b.cpython-38.pyc │ ├── pro_model_qwen_new.cpython-38.pyc │ ├── encoder_model_bert.cpython-310.pyc │ ├── pro_model_llama_new.cpython-310.pyc │ ├── pro_model_llama_new.cpython-38.pyc │ └── pro_model_qwen_new.cpython-310.pyc ├── pma.py ├── encoder_model_bert.py └── pro_model.py ├── img └── E2LLM.png ├── pefts ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-38.pyc │ ├── __init__.cpython-310.pyc │ ├── e2llm_args.cpython-310.pyc │ ├── e2llm_args.cpython-38.pyc │ ├── e2llm_trainer.cpython-38.pyc │ └── e2llm_trainer.cpython-310.pyc ├── e2llm_args.py └── e2llm_trainer.py ├── evaluate ├── __init__.py ├── __pycache__ │ ├── f1_qa.cpython-38.pyc │ ├── __init__.cpython-38.pyc │ ├── f1_qa.cpython-310.pyc │ ├── __init__.cpython-310.pyc │ ├── rouge_sum.cpython-310.pyc │ ├── rouge_sum.cpython-38.pyc │ ├── niah_metric.cpython-310.pyc │ └── niah_metric.cpython-38.pyc ├── rouge_cn.py ├── rouge_sum.py ├── em_quality.py └── f1_qa.py ├── .gitignore ├── prompts ├── 0shot_cot.txt ├── 0shot.txt ├── 0shot_no_context.txt ├── 0shot_rag.txt └── 0shot_cot_ans.txt ├── configs ├── lora_modules.json ├── model2maxlen.json ├── eval_config.json └── train_config.json ├── requirements.txt ├── LEGAL.md ├── local └── ds_config_zero2.yaml ├── train_local_machine.sh ├── train_multi_node.sh ├── eval.sh ├── README.md ├── preprocess └── preshuffle_data_and_chunk.py ├── LICENSE.md ├── train_accelerate.py └── eval.py /utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /dataset/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /model/__init__.py: -------------------------------------------------------------------------------- 1 | from .pma import PMA_v1, PMA_v2 -------------------------------------------------------------------------------- /img/E2LLM.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/codefuse-ai/E2LLM/master/img/E2LLM.png -------------------------------------------------------------------------------- /pefts/__init__.py: -------------------------------------------------------------------------------- 1 | from .e2llm_args import E2LLMTrainArgs 2 | from .e2llm_trainer import E2LLMTrainer 3 | -------------------------------------------------------------------------------- /model/__pycache__/pma.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/codefuse-ai/E2LLM/master/model/__pycache__/pma.cpython-38.pyc -------------------------------------------------------------------------------- /evaluate/__init__.py: 
-------------------------------------------------------------------------------- 1 | from .rouge_sum import compute_rouge 2 | from .f1_qa import compute_f1 3 | from .niah_metric import compute_score -------------------------------------------------------------------------------- /model/__pycache__/pma.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/codefuse-ai/E2LLM/master/model/__pycache__/pma.cpython-310.pyc -------------------------------------------------------------------------------- /evaluate/__pycache__/f1_qa.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/codefuse-ai/E2LLM/master/evaluate/__pycache__/f1_qa.cpython-38.pyc -------------------------------------------------------------------------------- /model/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/codefuse-ai/E2LLM/master/model/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /pefts/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/codefuse-ai/E2LLM/master/pefts/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /utils/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/codefuse-ai/E2LLM/master/utils/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /dataset/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/codefuse-ai/E2LLM/master/dataset/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /dataset/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/codefuse-ai/E2LLM/master/dataset/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /dataset/__pycache__/dataset.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/codefuse-ai/E2LLM/master/dataset/__pycache__/dataset.cpython-310.pyc -------------------------------------------------------------------------------- /dataset/__pycache__/dataset.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/codefuse-ai/E2LLM/master/dataset/__pycache__/dataset.cpython-38.pyc -------------------------------------------------------------------------------- /evaluate/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/codefuse-ai/E2LLM/master/evaluate/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /evaluate/__pycache__/f1_qa.cpython-310.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/codefuse-ai/E2LLM/master/evaluate/__pycache__/f1_qa.cpython-310.pyc -------------------------------------------------------------------------------- /model/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/codefuse-ai/E2LLM/master/model/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /model/__pycache__/gpt4o_mini.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/codefuse-ai/E2LLM/master/model/__pycache__/gpt4o_mini.cpython-310.pyc -------------------------------------------------------------------------------- /model/__pycache__/gpt4o_mini.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/codefuse-ai/E2LLM/master/model/__pycache__/gpt4o_mini.cpython-38.pyc -------------------------------------------------------------------------------- /pefts/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/codefuse-ai/E2LLM/master/pefts/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /pefts/__pycache__/e2llm_args.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/codefuse-ai/E2LLM/master/pefts/__pycache__/e2llm_args.cpython-310.pyc -------------------------------------------------------------------------------- /pefts/__pycache__/e2llm_args.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/codefuse-ai/E2LLM/master/pefts/__pycache__/e2llm_args.cpython-38.pyc -------------------------------------------------------------------------------- /utils/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/codefuse-ai/E2LLM/master/utils/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | do* 2 | show* 3 | draw* 4 | submit* 5 | init* 6 | high* 7 | run* 8 | tools* 9 | test* 10 | pred* 11 | data_statics.py 12 | train -------------------------------------------------------------------------------- /evaluate/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/codefuse-ai/E2LLM/master/evaluate/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /evaluate/__pycache__/rouge_sum.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/codefuse-ai/E2LLM/master/evaluate/__pycache__/rouge_sum.cpython-310.pyc -------------------------------------------------------------------------------- /evaluate/__pycache__/rouge_sum.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/codefuse-ai/E2LLM/master/evaluate/__pycache__/rouge_sum.cpython-38.pyc 
-------------------------------------------------------------------------------- /pefts/__pycache__/e2llm_trainer.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/codefuse-ai/E2LLM/master/pefts/__pycache__/e2llm_trainer.cpython-38.pyc -------------------------------------------------------------------------------- /utils/__pycache__/common_utils.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/codefuse-ai/E2LLM/master/utils/__pycache__/common_utils.cpython-310.pyc -------------------------------------------------------------------------------- /utils/__pycache__/common_utils.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/codefuse-ai/E2LLM/master/utils/__pycache__/common_utils.cpython-38.pyc -------------------------------------------------------------------------------- /evaluate/__pycache__/niah_metric.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/codefuse-ai/E2LLM/master/evaluate/__pycache__/niah_metric.cpython-310.pyc -------------------------------------------------------------------------------- /evaluate/__pycache__/niah_metric.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/codefuse-ai/E2LLM/master/evaluate/__pycache__/niah_metric.cpython-38.pyc -------------------------------------------------------------------------------- /pefts/__pycache__/e2llm_trainer.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/codefuse-ai/E2LLM/master/pefts/__pycache__/e2llm_trainer.cpython-310.pyc -------------------------------------------------------------------------------- /model/__pycache__/encoder_model_bert.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/codefuse-ai/E2LLM/master/model/__pycache__/encoder_model_bert.cpython-38.pyc -------------------------------------------------------------------------------- /model/__pycache__/eval_model_for_70b.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/codefuse-ai/E2LLM/master/model/__pycache__/eval_model_for_70b.cpython-38.pyc -------------------------------------------------------------------------------- /model/__pycache__/pro_model_qwen_new.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/codefuse-ai/E2LLM/master/model/__pycache__/pro_model_qwen_new.cpython-38.pyc -------------------------------------------------------------------------------- /model/__pycache__/encoder_model_bert.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/codefuse-ai/E2LLM/master/model/__pycache__/encoder_model_bert.cpython-310.pyc -------------------------------------------------------------------------------- /model/__pycache__/pro_model_llama_new.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/codefuse-ai/E2LLM/master/model/__pycache__/pro_model_llama_new.cpython-310.pyc 
-------------------------------------------------------------------------------- /model/__pycache__/pro_model_llama_new.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/codefuse-ai/E2LLM/master/model/__pycache__/pro_model_llama_new.cpython-38.pyc -------------------------------------------------------------------------------- /model/__pycache__/pro_model_qwen_new.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/codefuse-ai/E2LLM/master/model/__pycache__/pro_model_qwen_new.cpython-310.pyc -------------------------------------------------------------------------------- /prompts/0shot_cot.txt: -------------------------------------------------------------------------------- 1 | What is the correct answer to this question: $Q$ 2 | Choices: 3 | (A) $C_A$ 4 | (B) $C_B$ 5 | (C) $C_C$ 6 | (D) $C_D$ 7 | 8 | Let’s think step by step: -------------------------------------------------------------------------------- /prompts/0shot.txt: -------------------------------------------------------------------------------- 1 | What is the correct answer to this question: $Q$ 2 | Choices: 3 | (A) $C_A$ 4 | (B) $C_B$ 5 | (C) $C_C$ 6 | (D) $C_D$ 7 | 8 | Format your response as follows: "The correct answer is (insert answer here)". -------------------------------------------------------------------------------- /prompts/0shot_no_context.txt: -------------------------------------------------------------------------------- 1 | What is the correct answer to this question: $Q$ 2 | Choices: 3 | (A) $C_A$ 4 | (B) $C_B$ 5 | (C) $C_C$ 6 | (D) $C_D$ 7 | 8 | What is the single, most likely answer choice? Format your response as follows: "The correct answer is (insert answer here)". -------------------------------------------------------------------------------- /configs/lora_modules.json: -------------------------------------------------------------------------------- 1 | { 2 | "gte-large-en-v1.5": ["qkv_proj", "o_proj"], 3 | "Llama-2-7b-chat-hf": ["q_proj", "k_proj", "v_proj", "o_proj"], 4 | "Llama-2-13b-chat-hf": ["q_proj", "k_proj", "v_proj", "o_proj"], 5 | "Llama2-70B-Chat-hf": ["q_proj", "k_proj", "v_proj", "o_proj"], 6 | "Qwen2.5-7B-Instruct": "all-linear" 7 | } -------------------------------------------------------------------------------- /prompts/0shot_rag.txt: -------------------------------------------------------------------------------- 1 | Please read the following retrieved text chunks and answer the question below. 2 | 3 | 4 | $DOC$ 5 | 6 | 7 | What is the correct answer to this question: $Q$ 8 | Choices: 9 | (A) $C_A$ 10 | (B) $C_B$ 11 | (C) $C_C$ 12 | (D) $C_D$ 13 | 14 | Format your response as follows: "The correct answer is (insert answer here)". 
-------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | accelerate==0.23.0 2 | bitsandbytes==0.40.2 3 | datasets==2.14.1 4 | deepspeed==0.9.3 5 | flash-attn==2.3.6 6 | numpy==1.23.5 7 | peft==0.7.0 8 | scikit-learn==1.3.0 9 | tensorboard==2.11.0 10 | tiktoken==0.4.0 11 | tokenizers==0.15.0 12 | torch==2.0.1 13 | transformers==4.36.0 14 | transformers-stream-generator==0.0.4 15 | xformers==0.0.21 -------------------------------------------------------------------------------- /LEGAL.md: -------------------------------------------------------------------------------- 1 | Legal Disclaimer 2 | Within this source code, the comments in Chinese shall be the original, governing version. Any comment in other languages are for reference only. In the event of any conflict between the Chinese language version comments and other language version comments, the Chinese language version shall prevail. 3 | 法律免责声明 4 | 关于代码注释部分,中文注释为官方版本,其它语言注释仅做参考。中文注释可能与其它语言注释存在不一致,当中文注释与其它语言注释存在不一致时,请以中文注释为准。 -------------------------------------------------------------------------------- /prompts/0shot_cot_ans.txt: -------------------------------------------------------------------------------- 1 | Please read the following text and answer the questions below. 2 | 3 | The text is too long and omitted here. 4 | 5 | What is the correct answer to this question: $Q$ 6 | Choices: 7 | (A) $C_A$ 8 | (B) $C_B$ 9 | (C) $C_C$ 10 | (D) $C_D$ 11 | 12 | Let’s think step by step: $COT$ 13 | 14 | Based on the above, what is the single, most likely answer choice? Format your response as follows: "The correct answer is (insert answer here)". -------------------------------------------------------------------------------- /local/ds_config_zero2.yaml: -------------------------------------------------------------------------------- 1 | compute_environment: LOCAL_MACHINE 2 | deepspeed_config: 3 | gradient_accumulation_steps: 1 4 | gradient_clipping: 1.0 5 | offload_optimizer_device: cpu 6 | offload_param_device: none 7 | zero3_init_flag: false 8 | zero3_save_16bit_model: true 9 | zero_stage: 2 10 | # steps_per_print: 1 11 | distributed_type: DEEPSPEED 12 | downcast_bf16: 'no' 13 | dynamo_backend: 'NO' 14 | fsdp_config: {} 15 | machine_rank: 0 16 | main_training_function: main 17 | megatron_lm_config: {} 18 | mixed_precision: 'bf16' 19 | num_machines: 1 20 | num_processes: 1 21 | rdzv_backend: static 22 | same_network: true 23 | use_cpu: false -------------------------------------------------------------------------------- /configs/model2maxlen.json: -------------------------------------------------------------------------------- 1 | { 2 | "Llama-2-7B-Instruct": 4000, 3 | "Qwen2.5-7B-Instruct-YaRN": 64000, 4 | "Qwen2.5-7B-Instruct": 64000, 5 | "YaRN": 58000, 6 | "LongLoRA": 58000, 7 | "RAG": 4000, 8 | "GLM-4-9B-Chat": 120000, 9 | "Llama-3.1-8B-Instruct": 120000, 10 | "Llama-3.1-70B-Instruct": 120000, 11 | "Llama-3.3-70B-Instruct": 120000, 12 | "Llama-3.1-Nemotron-70B-Instruct": 120000, 13 | "Qwen2.5-72B-Instruct": 120000, 14 | "Mistral-Large-Instruct-2407": 120000, 15 | "Mistral-Large-Instruct-2411": 120000, 16 | "c4ai-command-r-plus-08-2024": 120000, 17 | "glm-4-plus": 120000, 18 | "gpt-4o-2024-08-06": 120000, 19 | "gpt-4o-mini-2024-07-18": 120000, 20 | "o1-mini-2024-09-12": 120000, 21 | "o1-preview-2024-09-12": 120000, 22 | "claude-3.5-sonnet-20241022": 200000 23 | } 
23 | }
-------------------------------------------------------------------------------- /train_local_machine.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | TRAIN_CONFIG="./configs/train_config.json" 3 | 4 | 5 | 6 | DISTRIBUTED_TYPE="DeepSpeed2" 7 | # DISTRIBUTED_TYPE="DeepSpeed3" 8 | # DISTRIBUTED_TYPE="FSDP" 9 | 10 | if [ "$DISTRIBUTED_TYPE" = "DeepSpeed2" ]; then 11 | echo "DISTRIBUTED_TYPE is DeepSpeed ZeRO2" 12 | accelerate launch --config_file "./local/ds_config_zero2.yaml" train_accelerate.py --train_config "$TRAIN_CONFIG" 13 | elif [ "$DISTRIBUTED_TYPE" = "DeepSpeed3" ]; then 14 | echo "DISTRIBUTED_TYPE is DeepSpeed ZeRO3" 15 | accelerate launch --config_file "./local/ds_config_zero3.yaml" train_accelerate.py --train_config "$TRAIN_CONFIG" 16 | elif [ "$DISTRIBUTED_TYPE" = "FSDP" ]; then 17 | echo "DISTRIBUTED_TYPE is FSDP" 18 | accelerate launch --config_file "./local/fsdp_config.yaml" train_accelerate.py --train_config "$TRAIN_CONFIG" 19 | fi -------------------------------------------------------------------------------- /evaluate/rouge_cn.py: -------------------------------------------------------------------------------- 1 | from rouge_chinese import Rouge 2 | import jieba 3 | 4 | 5 | def compute_rouge(rouge, hypothesis, reference): 6 | 7 | 8 | hypothesis = ' '.join(jieba.cut(hypothesis)) 9 | 10 | reference = ' '.join(jieba.cut(reference)) 11 | 12 | scores = rouge.get_scores(hypothesis, reference) 13 | r1_f = scores[0]['rouge-1']['f'] 14 | r2_f = scores[0]['rouge-2']['f'] 15 | rl_f = scores[0]['rouge-l']['f'] 16 | 17 | return r1_f, r2_f, rl_f 18 | 19 | 20 | if __name__ == '__main__': 21 | avg_r1 = 0 22 | avg_r2 = 0 23 | avg_rl = 0 24 | count = 0 25 | rouge = Rouge() 26 | dataset = [('', '')] 27 | for h, r in dataset: 28 | count += 1 29 | r1, r2, rl = compute_rouge(rouge, h, r) 30 | avg_r1 += r1 31 | avg_r2 += r2 32 | avg_rl += rl 33 | avg_r1 = avg_r1 / count 34 | avg_r2 = avg_r2 / count 35 | avg_rl = avg_rl / count 36 | -------------------------------------------------------------------------------- /configs/eval_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "encoder_model_dir": "/path/to/encoder/model", 3 | "decoder_model_dir": "/path/to/decoder/model", 4 | "chunk_size": 512, 5 | "max_num_chunks": 500, 6 | "max_seq_len": 1024, 7 | "max_sliding_windows": null, 8 | "enc_pool_fn": "CLS", 9 | "tokens_for_each_chunk": 1, 10 | "ln": true, 11 | "proj_arch": "mlpnew2x_gelu", 12 | "alpha": 1e-9, 13 | "peft_fn": "qlora", 14 | "lora_rank_enc": 32, 15 | "lora_alpha_enc": 32, 16 | "lora_rank_dec": 8, 17 | "lora_alpha_dec": 8, 18 | "num_epochs": 20, 19 | "lr": 1e-4, 20 | "bs": 6, 21 | "log_interval": 5, 22 | "eval_interval": 5, 23 | "patience": 20, 24 | "num_ckpt": 20, 25 | "mark": "vtest3", 26 | "attn_implementation": "flash_attention_2", 27 | "num_warmup_steps": 100, 28 | "recover_training": false, 29 | "recover_global_steps": -1, 30 | "bf16": false, 31 | "distributed_type": "deepspeed", 32 | "mode": "eval" 33 | } -------------------------------------------------------------------------------- /train_multi_node.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | DISTRIBUTED_TYPE="DeepSpeed2" 3 | export OMP_NUM_THREADS=16 4 | 5 | 6 | if [ "$DISTRIBUTED_TYPE" = "DeepSpeed2" ]; then 7 | echo "DISTRIBUTED_TYPE is DeepSpeed ZeRO2" 8 | accelerate launch \ 9 | --num_machines $N_NODE \ 10 | --num_processes $(($N_NODE*$N_GPU_PER_NODE)) \ 11 |
--use_deepspeed \ 12 | --deepspeed_multinode_launcher 'standard' \ 13 | --zero_stage 2 \ 14 | --offload_optimizer_device 'cpu' \ 15 | --offload_param_device 'none' \ 16 | --gradient_accumulation_steps 1 \ 17 | --gradient_clipping 1.0 \ 18 | --zero3_init_flag false \ 19 | --zero3_save_16bit_model true \ 20 | --main_training_function 'main' \ 21 | --mixed_precision 'bf16' \ 22 | --dynamo_backend 'no' \ 23 | --same_network \ 24 | --machine_rank $RANK \ 25 | --main_process_ip $MASTER_ADDR \ 26 | --main_process_port $MASTER_PORT \ 27 | --rdzv_backend 'static' \ 28 | train_accelerate.py --train_config "./configs/train_config.json" 29 | -------------------------------------------------------------------------------- /configs/train_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "train_data_list": "[qmsum]", 3 | "val_data_list": "[qmsum]", 4 | "encoder_model_dir": "/path/to/encoder/model", 5 | "decoder_model_dir": "/path/to/decoder/model", 6 | "chunk_size": 512, 7 | "max_num_chunks": 700, 8 | "max_seq_len": 4096, 9 | "max_sliding_windows": null, 10 | "enc_pool_fn": "PMA", 11 | "tokens_for_each_chunk": 4, 12 | "ln": true, 13 | "proj_arch": null, 14 | "alpha": 1e-9, 15 | "peft_fn": "lora", 16 | "lora_rank_enc": 32, 17 | "lora_alpha_enc": 32, 18 | "lora_rank_dec": 16, 19 | "lora_alpha_dec": 16, 20 | "train_retrieve": false, 21 | "eval_retrieve": false, 22 | "top_k": 3, 23 | "num_epochs": 20, 24 | "lr": 1e-4, 25 | "bs": 4, 26 | "log_interval": 10, 27 | "eval_interval": 150, 28 | "patience": 100, 29 | "num_ckpt": 100, 30 | "mark": "mark_of_your_model_version", 31 | "attn_implementation": "flash_attention_2", 32 | "num_warmup_steps": 100, 33 | "recover_training": false, 34 | "recover_global_steps": -1, 35 | "bf16": true, 36 | "distributed_type": "deepspeed", 37 | "mode": "train" 38 | } -------------------------------------------------------------------------------- /evaluate/rouge_sum.py: -------------------------------------------------------------------------------- 1 | from rouge import Rouge 2 | import math 3 | 4 | def compute_rouge_(rouger, hypothesis, reference): 5 | 6 | print('*'*20) 7 | print(hypothesis) 8 | print('='*20) 9 | print(reference) 10 | print('*'*20) 11 | scores = rouger.get_scores(hypothesis, reference) 12 | r1_f = scores[0]['rouge-1']['f'] 13 | r2_f = scores[0]['rouge-2']['f'] 14 | rl_f = scores[0]['rouge-l']['f'] 15 | 16 | return r1_f, r2_f, rl_f 17 | 18 | def compute_rouge(rouger, dataname, hypothesis_list, reference_list): 19 | assert len(hypothesis_list) == len(reference_list) 20 | avg_r1 = 0 21 | avg_r2 = 0 22 | avg_rl = 0 23 | rouger = Rouge() 24 | count = 0 25 | for i in range(len(hypothesis_list)): 26 | h = hypothesis_list[i] 27 | r = reference_list[i] 28 | count+=1 29 | try: 30 | r1, r2, rl = compute_rouge_(rouger, h, r) 31 | except: 32 | print('ignoring as 0.0 score') 33 | r1, r2, rl = 0.0, 0.0, 0.0 34 | avg_r1 += r1 35 | avg_r2 += r2 36 | avg_rl += rl 37 | avg_r1 = avg_r1 / count 38 | avg_r2 = avg_r2 / count 39 | avg_rl = avg_rl / count 40 | g_mean = math.pow(avg_r1*avg_r2*avg_rl, 1/3) 41 | return {f'{dataname}-r1':avg_r1, f'{dataname}-r2':avg_r2, f'{dataname}-rl':avg_rl, f'{dataname}-g_mean':g_mean} 42 | 43 | -------------------------------------------------------------------------------- /evaluate/em_quality.py: -------------------------------------------------------------------------------- 1 | import re 2 | import string 3 | 4 | 5 | def normalize_answer(s): 6 | """Lower text and remove punctuation, articles and 
extra whitespace.""" 7 | 8 | def remove_articles(text): 9 | return re.sub(r"\b(a|an|the)\b", " ", text) 10 | 11 | def white_space_fix(text): 12 | return " ".join(text.split()) 13 | 14 | def remove_punc(text): 15 | exclude = set(string.punctuation) 16 | return "".join(ch for ch in text if ch not in exclude) 17 | 18 | def lower(text): 19 | return text.lower() 20 | 21 | return white_space_fix(remove_articles(remove_punc(lower(s)))) 22 | 23 | 24 | def exact_match_score(prediction, ground_truth): 25 | return normalize_answer(prediction) == normalize_answer(ground_truth) 26 | 27 | 28 | def metric_max_over_ground_truths(metric_fn, prediction, ground_truths): 29 | scores_for_ground_truths = [] 30 | for ground_truth in ground_truths: 31 | score = metric_fn(prediction, ground_truth) 32 | scores_for_ground_truths.append(score) 33 | return max(scores_for_ground_truths) 34 | 35 | 36 | def compute_exact_match(predictions, references): 37 | exact_match = 0 38 | for prediction, ground_truths in zip(predictions, references): 39 | em = metric_max_over_ground_truths(exact_match_score, prediction, ground_truths) 40 | exact_match += em 41 | 42 | 43 | return exact_match / len(predictions) 44 | 45 | -------------------------------------------------------------------------------- /eval.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | GLOBAL_STEP=step_of_step 3 | MARK="mark_of_your_model_version" 4 | TEST_DATA_LIST="[qmsum]" 5 | ENCODER_MODEL_DIR="/path/to/encoder/model" 6 | DECODER_MODEL_DIR="/path/to/decoder/model" 7 | CHUNK_SIZE=chunk_size 8 | MAX_NUM_CHUNKS=500 9 | ENC_POOL_FN="PMA" 10 | TOKENS_FOR_EACH_CHUNK=4 11 | LN="True" 12 | PROJ_ARCH="None" 13 | BS=bs 14 | LORA_RANK_ENC=lora_rank_enc 15 | LORA_ALPHA_ENC=lora_alpha_enc 16 | LORA_RANK_DEC=lora_rank_dec 17 | LORA_ALPHA_DEC=lora_alpha_dec 18 | EVAL_RETRIEVE="False" 19 | TOP_K=3 20 | THRES=0.5 21 | MAX_NEW_TOKENS=128 22 | ATTN_IMPLEMENTATION="eager" 23 | 24 | EVAL_CONFIG="./configs/eval_config.json" 25 | 26 | python eval.py --eval_config "$EVAL_CONFIG" \ 27 | --test_data_list $TEST_DATA_LIST \ 28 | --global_step $GLOBAL_STEP \ 29 | --mark $MARK \ 30 | --encoder_model_dir $ENCODER_MODEL_DIR \ 31 | --decoder_model_dir $DECODER_MODEL_DIR \ 32 | --chunk_size $CHUNK_SIZE \ 33 | --max_num_chunks $MAX_NUM_CHUNKS \ 34 | --enc_pool_fn $ENC_POOL_FN \ 35 | --tokens_for_each_chunk $TOKENS_FOR_EACH_CHUNK \ 36 | --ln $LN \ 37 | --proj_arch $PROJ_ARCH \ 38 | --lora_rank_enc $LORA_RANK_ENC \ 39 | --lora_alpha_enc $LORA_ALPHA_ENC \ 40 | --lora_rank_dec $LORA_RANK_DEC \ 41 | --lora_alpha_dec $LORA_ALPHA_DEC \ 42 | --eval_retrieve $EVAL_RETRIEVE \ 43 | --top_k $TOP_K \ 44 | --thres $THRES \ 45 | --attn_implementation $ATTN_IMPLEMENTATION \ 46 | --bs $BS \ 47 | --max_new_tokens $MAX_NEW_TOKENS -------------------------------------------------------------------------------- /evaluate/f1_qa.py: -------------------------------------------------------------------------------- 1 | # Copied from https://github.com/huggingface/datasets/blob/d3c7b9481d427ce41256edaf6773c47570f06f3b/metrics/squad/evaluate.py 2 | 3 | import re 4 | import string 5 | from collections import Counter 6 | 7 | 8 | def normalize_answer(s): 9 | """Lower text and remove punctuation, articles and extra whitespace.""" 10 | 11 | def remove_articles(text): 12 | return re.sub(r"\b(a|an|the)\b", " ", text) 13 | 14 | def white_space_fix(text): 15 | return " ".join(text.split()) 16 | 17 | def remove_punc(text): 18 | exclude = set(string.punctuation) 19 |
return "".join(ch for ch in text if ch not in exclude) 20 | 21 | def lower(text): 22 | return text.lower() 23 | 24 | return white_space_fix(remove_articles(remove_punc(lower(s)))) 25 | 26 | 27 | def f1_score(prediction, ground_truth): 28 | print('*'*20) 29 | print(prediction) 30 | print('='*20) 31 | print(ground_truth) 32 | print('*'*20) 33 | prediction_tokens = normalize_answer(prediction).split() 34 | ground_truth_tokens = normalize_answer(ground_truth).split() 35 | common = Counter(prediction_tokens) & Counter(ground_truth_tokens) 36 | num_same = sum(common.values()) 37 | if num_same == 0: 38 | return 0,0,0 39 | precision = 1.0 * num_same / len(prediction_tokens) 40 | recall = 1.0 * num_same / len(ground_truth_tokens) 41 | f1 = (2 * precision * recall) / (precision + recall) 42 | return precision, recall, f1 43 | 44 | 45 | def compute_f1(dataname, predictions, references): 46 | precision_all = 0 47 | recall_all = 0 48 | f1_all = 0 49 | for prediction, ground_truth in zip(predictions, references): 50 | precision, recall, f1 = f1_score(prediction, ground_truth) 51 | precision_all += precision 52 | recall_all += recall 53 | f1_all += f1 54 | res = {f"{dataname}-precision": precision_all/len(predictions), f"{dataname}-recall": recall_all/len(predictions), f"{dataname}-f1": f1_all/len(predictions)} 55 | return res 56 | 57 | -------------------------------------------------------------------------------- /pefts/e2llm_args.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, asdict 2 | from typing import List, Union 3 | 4 | 5 | @dataclass 6 | class E2LLMTrainArgs: 7 | 8 | # train datasets 9 | train_data_list: Union[str, List[str]] = "[qmsum]" 10 | 11 | # val datasets 12 | val_data_list: Union[str, List[str]] = "[qmsum]" 13 | 14 | # encoder model dir 15 | encoder_model_dir: str = "/path/to/encoder/model" 16 | 17 | # decoder model dir 18 | decoder_model_dir: str = "/path/to/decoder/model" 19 | 20 | # chunk size 21 | chunk_size: int = 512 22 | 23 | # max num chunks 24 | max_num_chunks: int = 500 25 | 26 | # max sequence length during training 27 | max_seq_len: int = 4096 28 | 29 | # max sliding windows 30 | max_sliding_windows: Union[None, int] = None 31 | 32 | # encoder pooling method 33 | enc_pool_fn: str = "CLS" 34 | 35 | # only for pma 36 | tokens_for_each_chunk: int = 4 37 | 38 | # only for pma 39 | ln: bool = True 40 | 41 | # adapter architecture, if enc_pool_fn=='PMA', set this to None (null for config.json) 42 | proj_arch: Union[None, str] = "mlpnew2x_gelu" 43 | 44 | # trade-off weight for reconstruct task 45 | alpha: float = 1e-7 46 | 47 | # 'lora' or 'qlora' 48 | peft_fn: Union[None, str] = "lora" 49 | 50 | # lora params 51 | lora_rank_enc: int = 32 52 | lora_alpha_enc: int = 32 53 | lora_rank_dec: int = 8 54 | lora_alpha_dec: int = 8 55 | 56 | # whether retrieve 57 | train_retrieve: bool = True 58 | eval_retrieve: bool = True 59 | top_k: int = 3 60 | 61 | # datasets params 62 | shuffle_sample_ratio: float = 0.0 63 | noisy_sample_ratio: float = 0.0 64 | noise_rate: float = 0.0 65 | 66 | # training epochs 67 | num_epochs: int = 5 68 | 69 | # learning rate 70 | lr: float = 1e-4 71 | 72 | # batch size 73 | bs: int = 2 74 | 75 | # logging per n steps 76 | log_interval: int = 10 77 | 78 | # evaluating per n steps 79 | eval_interval: int = 200 80 | 81 | # random seed 82 | seed: int = 2023 83 | 84 | # patience for early stop 85 | patience: int = 5 86 | 87 | # number of ckpts 88 | num_ckpt: int = 5 89 | 90 | # mark 91 | 
mark: str = "test" 92 | 93 | # flash attention 94 | attn_implementation: str = "flash_attention_2" 95 | 96 | # bf16 97 | bf16: bool = True 98 | 99 | # weight decay 100 | weight_decay: float = 1e-2 101 | 102 | # gradient clipping 103 | gradient_clipping: float = 1.0 104 | 105 | # warmup steps 106 | num_warmup_steps: int = 100 107 | 108 | # max training steps 109 | max_train_steps: Union[None, int] = None 110 | 111 | # max sliding windows 112 | max_sliding_windows: Union[None, int] = None 113 | 114 | # gradient accumulation steps 115 | gradient_accumulation_steps: int = 1 116 | 117 | 118 | # recover training 119 | recover_training: bool = False 120 | 121 | # recover global steps 122 | recover_global_steps: int = -1 123 | 124 | # verbose 125 | verbose: bool = True 126 | 127 | # distributed type 128 | distributed_type: str = "deepspeed" 129 | 130 | # mode, "train" or "eval" 131 | mode: str = "train" 132 | 133 | # legacy, leave them 134 | use_xformers: bool = True 135 | trust_remote_code: bool = True 136 | model_parallel_size: int = 1 137 | use_slow_tokenizer: bool = False 138 | world_size: int = 8 139 | init_timeout_seconds: Union[None, int] = 3600 140 | 141 | def dict(self): 142 | return {k: str(v) for k, v in asdict(self).items()} 143 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # E2LLM: Encoder Elongated Large Language Models for Long-Context Understanding and Reasoning 2 | 3 | This is the Pytorch implementation of E2LLM in the EMNLP'25 paper: E2LLM: Encoder Elongated Large Language Models for Long-Context Understanding and Reasoning. 4 | 5 | ## Overview 6 | ![The network architecture of E2LLM.](./img/E2LLM.png) 7 | 8 | 9 |
Abstract 9 | 10 | 11 | 12 | - We propose E2LLM, a novel long-context modeling framework built on pre-trained text encoders and decoder-only LLMs to effectively address the "impossible triangle" challenge. 13 | 14 | - We introduce two training objectives: soft prompt reconstruction and long-context instruction fine-tuning, enabling the LLM to understand the soft prompt while reasoning to produce accurate outputs. 15 | 16 | - Comprehensive experiments conducted on diverse benchmarks demonstrate the efficiency and practicality of E2LLM, revealing its superiority over 8 SOTA baselines and its competitiveness on LongBench v2. 17 |
19 | 20 | 21 | 22 | ## Requirements 23 | 24 | * Ubuntu OS 25 | * python==3.10 26 | * torch==2.0.1 27 | * cuda==11.7 28 | * accelerate==0.23.0 29 | * transformers==4.36.0 30 | * deepspeed==0.9.3 31 | * flash-attn==2.3.6 32 | * peft==0.7.0 33 | * scikit-learn==1.3.0 34 | 35 | Dependencies can be installed by: 36 | 37 | pip install -r requirements.txt 38 | 39 | 40 | The overall directory structure is as follows: 41 | 42 | ${CODE_ROOT} 43 | |-- configs 44 | |-- eval_config.json 45 | |-- lora_modules.json 46 | |-- model2maxlen.json 47 | |-- train_config.json 48 | |-- dataset 49 | |-- __init__.py 50 | |-- dataset.py 51 | |-- evaluate 52 | |-- __init__.py 53 | |-- em_quality.py 54 | |-- f1_qa.py 55 | |-- niah_metric.py 56 | |-- rouge_sum.py 57 | |-- local 58 | |-- ds_config_zero2.yaml 59 | |-- model 60 | |-- __init__.py 61 | |-- encoder_model_bert.py 62 | |-- pma.py 63 | |-- pro_model.py 64 | |-- pefts 65 | |-- __init__.py 66 | |-- e2llm_args.py 67 | |-- e2llm_trainer.py 68 | |-- preprocess 69 | |-- preshuffle_data_and_chunk.py 70 | |-- prompts 71 | |-- utils 72 | |-- __init__.py 73 | |-- common_utils.py 74 | |-- eval.py 75 | |-- eval.sh 76 | |-- train_accelerate.py 77 | |-- train_local_machine.sh 78 | |-- train_multi_node.sh 79 | 80 | 81 | ## Data preparation 82 | 83 | The five datasets (QMSum, GovReport, Quality, NarrativeQA and TriviaQA) used in this paper can be downloaded from the following links: 84 | 85 | * [QMSum](https://github.com/Yale-LILY/QMSum) 86 | * [GovReport](https://huggingface.co/datasets/ccdv/govreport-summarization) 87 | * [Quality](https://huggingface.co/datasets/emozilla/quality) 88 | * [NarrativeQA](https://github.com/google-deepmind/narrativeqa) 89 | * [TriviaQA](https://huggingface.co/datasets/mandarjoshi/trivia_qa) 90 | 91 | Before training, first convert the data into a JSONL file in the format {'context': 'xxx', 'prompt': 'xxx', 'answer': 'xxx'}. Then run 92 | 93 | ```python 94 | python preprocess/preshuffle_data_and_chunk.py 95 | ``` 96 | 97 | and set the chunk_size parameter during execution. 98 | 99 | ## Train 100 | 101 | During training, first set the desired parameters in configs/train_config.json, then run the appropriate script according to your environment: 102 | 103 | - If you are training on a local machine: 104 | 105 | ``` 106 | sh train_local_machine.sh 107 | ``` 108 | 109 | - If you are training on a cluster / multi-node setup: 110 | 111 | ``` 112 | sh train_multi_node.sh 113 | ``` 114 | 115 | ## Evaluate 116 | 117 | For inference, run 118 | 119 | ``` 120 | sh eval.sh 121 | ``` 122 | 123 | and adjust its parameters so that they match the ones used during training. 
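The metric helpers that `eval.py` relies on are also importable on their own, which is convenient for re-scoring saved predictions. The snippet below is a minimal sketch rather than part of the official pipeline: the prediction/reference strings are placeholders, and it assumes you run it from the repository root so that the local `evaluate` package (not the HuggingFace library of the same name) is picked up.

```python
from rouge import Rouge

# compute_rouge and compute_f1 are re-exported in evaluate/__init__.py
from evaluate import compute_rouge, compute_f1

predictions = ["the committee agreed to extend the project deadline"]
references = ["the committee decided to extend the deadline by two weeks"]

# Summarization-style scoring: ROUGE-1/2/L F1 plus their geometric mean,
# returned as a dict keyed by the dataset name you pass in.
print(compute_rouge(Rouge(), "qmsum", predictions, references))

# QA-style scoring: token-level precision, recall and F1.
print(compute_f1("narrativeqa", predictions, references))
```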
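As a complement to the Data preparation section above, the following sketch illustrates the record layout end to end. The file name and example text are placeholders; the keys mirror the `column_names` mapping and the two samples per record (context reconstruction with `task_id=0` and instruction following with `task_id=1`) that `preprocess/preshuffle_data_and_chunk.py` writes out.

```python
import jsonlines

# One raw record in the {'context', 'prompt', 'answer'} format expected before chunking.
record = {
    "context": "The quick brown fox jumps over the lazy dog. " * 200,  # a long document
    "prompt": "Which animal jumps over the dog?",
    "answer": "A fox.",
}
with jsonlines.open("raw_train.jsonl", mode="w") as writer:
    writer.write(record)

# After preprocessing, each raw record becomes two chunked training samples:
#   {'input': [chunk_1, ..., chunk_k], 'prompt': 'Restate the aforementioned contexts:',
#    'output': <full context>, 'task_id': 0}
#   {'input': [chunk_1, ..., chunk_k], 'prompt': 'Answer the question:<original prompt>',
#    'output': <answer>, 'task_id': 1}
```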
124 | 125 | ## Citation 126 | 127 | If you find our repository helpful, please cite us as follows: 128 | 129 | @misc{liao2025e2llmencoderelongatedlarge, 130 | title={E2LLM: Encoder Elongated Large Language Models for Long-Context Understanding and Reasoning}, 131 | author={Zihan Liao and Jun Wang and Hang Yu and Lingxiao Wei and Jianguo Li and Jun Wang and Wei Zhang}, 132 | year={2025}, 133 | eprint={2409.06679}, 134 | archivePrefix={arXiv}, 135 | primaryClass={cs.CL}, 136 | url={https://arxiv.org/abs/2409.06679}, 137 | } -------------------------------------------------------------------------------- /model/pma.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import math 6 | 7 | 8 | # PMA部分 post_normal 9 | class MAB_POST(nn.Module): 10 | def __init__(self, dim_Q, dim_K, dim_V, num_heads, ln=False): 11 | super(MAB_POST, self).__init__() 12 | self.dim_V = dim_V 13 | self.num_heads = num_heads 14 | self.fc_q = nn.Linear(dim_Q, dim_V) 15 | self.fc_k = nn.Linear(dim_K, dim_V) 16 | self.fc_v = nn.Linear(dim_K, dim_V) 17 | 18 | if ln: 19 | self.ln0 = nn.LayerNorm(dim_V) 20 | self.ln1 = nn.LayerNorm(dim_V) 21 | self.fc_o = nn.Linear(dim_V, dim_V) 22 | nn.init.xavier_uniform_(self.fc_q.weight) 23 | nn.init.xavier_uniform_(self.fc_k.weight) 24 | nn.init.xavier_uniform_(self.fc_v.weight) 25 | nn.init.xavier_uniform_(self.fc_o.weight) 26 | 27 | 28 | 29 | # Q(bs, 1, emb), pad_mask (bs, seq) Post-LN 30 | def forward(self, Q, K, pad_mask=None): 31 | 32 | Q_ = self.fc_q(Q) 33 | K_, V_ = self.fc_k(K), self.fc_v(K) 34 | 35 | dim_split = self.dim_V // self.num_heads 36 | Q_ = torch.cat(Q_.split(dim_split, 2), 0) # (bs* num_head, 1, emb) 37 | K_ = torch.cat(K_.split(dim_split, 2), 0) # (bs* num_head, seq, emb) 38 | V_ = torch.cat(V_.split(dim_split, 2), 0) 39 | 40 | pad_mask = pad_mask.unsqueeze(1).repeat(self.num_heads, Q.size(1), 1) # (bs*num_head, 1, seq) 41 | score = Q_.bmm(K_.transpose(1,2))/math.sqrt(self.dim_V) 42 | score = score.masked_fill(pad_mask == 0, -1e12) 43 | A = torch.softmax(score, 2) # (bs*num_head, 1, seq) 44 | A = A * pad_mask 45 | O = torch.cat(A.bmm(V_).split(Q.size(0), 0), 2) # (bs, 1, emb) 46 | O = Q + O 47 | # O = torch.cat((Q_ + A.bmm(V_)).split(Q.size(0), 0), 2) 48 | O = O if getattr(self, 'ln0', None) is None else self.ln0(O) 49 | O = O + F.relu(self.fc_o(O)) 50 | O = O if getattr(self, 'ln1', None) is None else self.ln1(O) 51 | return O 52 | 53 | 54 | 55 | class PMA_v1(nn.Module): 56 | def __init__(self, dim, compressed_dim, num_heads, num_seeds, ln=False): 57 | super(PMA_v1, self).__init__() 58 | self.S = nn.Parameter(torch.Tensor(1, num_seeds, compressed_dim)) 59 | nn.init.xavier_uniform_(self.S) 60 | self.mab = MAB_POST(compressed_dim, dim, compressed_dim, num_heads, ln=ln) 61 | self.mab.to() 62 | def forward(self, X, pad_mask): 63 | if self.S.dtype != torch.bfloat16: 64 | X = X.float() 65 | return self.mab(self.S.expand(X.size(0), -1, -1).to(X.device), X, pad_mask) 66 | 67 | 68 | # PMA部分 post_normal 69 | class MAB_POST_v2(nn.Module): 70 | def __init__(self, dim_Q, dim_K, dim_V, num_heads, ln=False): 71 | super(MAB_POST_v2, self).__init__() 72 | self.dim_V = dim_V 73 | self.num_heads = num_heads 74 | self.fc_q = nn.Linear(dim_Q, dim_V) 75 | self.fc_k = nn.Linear(dim_K, dim_V) 76 | self.fc_v = nn.Linear(dim_K, dim_V) 77 | 78 | if ln: 79 | self.ln0 = nn.LayerNorm(dim_V) 80 | self.ln1 = nn.LayerNorm(dim_V) 81 | self.fc_o = nn.Linear(dim_V, dim_V) 
82 | nn.init.xavier_uniform_(self.fc_q.weight) 83 | nn.init.xavier_uniform_(self.fc_k.weight) 84 | nn.init.xavier_uniform_(self.fc_v.weight) 85 | nn.init.xavier_uniform_(self.fc_o.weight) 86 | 87 | 88 | 89 | # Q(B, num_seed, D), pad_mask (bs, seq) Post-LN 90 | def forward(self, Q, K, pad_mask=None): 91 | Q_tmp = self.fc_q(Q) # B, num_seed, C 92 | K_, V_ = self.fc_k(K), self.fc_v(K) # B, L, C 93 | 94 | dim_split = self.dim_V // self.num_heads 95 | Q_ = torch.cat(Q_tmp.split(dim_split, 2), 0) # (B* num_head, num_seed, C) 96 | K_ = torch.cat(K_.split(dim_split, 2), 0) # (B* num_head, L, C) 97 | V_ = torch.cat(V_.split(dim_split, 2), 0) # (B* num_head,L, C) 98 | 99 | pad_mask = pad_mask.unsqueeze(1).repeat(self.num_heads, Q.size(1), 1) # (B*num_head, num_seed, L) 100 | score = Q_.bmm(K_.transpose(1,2))/math.sqrt(self.dim_V) # (B*num_head, num_seed, L) 101 | score = score.masked_fill(pad_mask == 0, -1e12) # B,num_seed,L 102 | A = torch.softmax(score, 2) # (B*num_head, num_seed, L) 103 | A = A * pad_mask 104 | O = torch.cat(A.bmm(V_).split(Q.size(0), 0), 2) # (B, num_seed, D) 105 | O = Q_tmp + O 106 | # O = torch.cat((Q_ + A.bmm(V_)).split(Q.size(0), 0), 2) 107 | O = O if getattr(self, 'ln0', None) is None else self.ln0(O) 108 | O = O + F.relu(self.fc_o(O)) 109 | O = O if getattr(self, 'ln1', None) is None else self.ln1(O) 110 | return O 111 | 112 | class PMA_v2(nn.Module): 113 | def __init__(self, dim, compressed_dim, num_heads, num_seeds, ln=False): 114 | super(PMA_v2, self).__init__() 115 | self.S = nn.Parameter(torch.Tensor(1, num_seeds, dim)) 116 | nn.init.xavier_uniform_(self.S) 117 | self.mab = MAB_POST_v2(dim, dim, compressed_dim, num_heads, ln=ln) 118 | 119 | # X: (bs, seq, emb), pad_mask: (bs, seq) 120 | def forward(self, X, pad_mask): 121 | if self.S.dtype != torch.bfloat16: 122 | X = X.float() 123 | return self.mab(self.S.expand(X.size(0), -1, -1).to(X.device), X, pad_mask) -------------------------------------------------------------------------------- /preprocess/preshuffle_data_and_chunk.py: -------------------------------------------------------------------------------- 1 | import re 2 | import os 3 | import pathlib 4 | import sys 5 | import jsonlines 6 | import math 7 | from tqdm import tqdm 8 | import random 9 | from langchain_text_splitters import CharacterTextSplitter, RecursiveCharacterTextSplitter 10 | 11 | def makedirs(path): 12 | p = pathlib.Path(path) 13 | p.parent.mkdir(parents=True, exist_ok=True) 14 | return path 15 | 16 | def split_text(text, max_length=512): 17 | # 创建一个空列表来存储文本块 18 | chunks = [] 19 | 20 | current_position = 0 21 | 22 | while current_position < len(text): 23 | if current_position + max_length >= len(text): 24 | chunks.append(text[current_position:]) 25 | break 26 | chunk_candidate = text[current_position:current_position + max_length] 27 | last_punctuation_pos = max([chunk_candidate.rfind('.'), chunk_candidate.rfind('\n')]) 28 | if last_punctuation_pos == -1 or last_punctuation_pos==0: 29 | chunks.append(chunk_candidate) 30 | current_position += max_length 31 | else: 32 | chunks.append(text[current_position:current_position + last_punctuation_pos + 1]) 33 | current_position += last_punctuation_pos + 1 34 | return chunks 35 | 36 | def split_text_overlap(text, max_length=512, overlap_ratio=0.3): 37 | chunks = [] 38 | 39 | current_position = 0 40 | overlap_length = int(max_length * overlap_ratio) 41 | 42 | while current_position < len(text): 43 | if current_position + max_length >= len(text): 44 | chunks.append(text[current_position:]) 45 | break 46 | 
47 | chunk_candidate = text[current_position:current_position + max_length] 48 | 49 | last_punctuation_pos = max([chunk_candidate.rfind('.'), chunk_candidate.rfind('\n')]) 50 | if last_punctuation_pos == -1 or last_punctuation_pos == 0: 51 | chunks.append(chunk_candidate) 52 | current_position += max_length 53 | else: 54 | chunks.append(text[current_position:current_position + last_punctuation_pos + 1]) 55 | current_position += last_punctuation_pos + 1 56 | 57 | if current_position >= len(text): 58 | break 59 | 60 | current_position = max(current_position - overlap_length, 0) 61 | 62 | return chunks 63 | 64 | 65 | 66 | def preshuffle_and_chunkfile(column_names, chunk_size, data_path, write_data_path, max_samples): 67 | if os.path.exists(write_data_path): 68 | raise ValueError(f'Existed file: {write_data_path}') 69 | else: 70 | makedirs(write_data_path) 71 | all_data = [] 72 | with jsonlines.open(data_path, 'r') as fr: 73 | for i, line in tqdm(enumerate(fr)): 74 | if max_samples is not None: 75 | if i >= max_samples: 76 | break 77 | input_ = line[column_names['input']] 78 | prompt_ = line[column_names['prompt']] 79 | output_ = line[column_names['output']] 80 | 81 | chunked_input = split_text(input_, chunk_size) 82 | all_data.append({'input':chunked_input, 'prompt':'Restate the aforementioned contexts:', 'output':input_, 'task_id':0}) 83 | all_data.append({'input':chunked_input, 'prompt':f'Answer the question:{prompt_}', 'output':output_, 'task_id':1}) 84 | random.shuffle(all_data) 85 | with jsonlines.open(write_data_path, 'w') as fw: 86 | for data in all_data: 87 | fw.write(data) 88 | 89 | def preshuffle_and_chunkfile_summary(column_names, chunk_size, data_path, write_data_path, max_samples): 90 | if os.path.exists(write_data_path): 91 | raise ValueError(f'Existed file: {write_data_path}') 92 | else: 93 | makedirs(write_data_path) 94 | all_data = [] 95 | with jsonlines.open(data_path, 'r') as fr: 96 | for i, line in tqdm(enumerate(fr)): 97 | if max_samples is not None: 98 | if i >= max_samples: 99 | break 100 | input_ = line[column_names['input']] 101 | prompt_ = 'What does the above context describe?' 
102 | output_ = line[column_names['output']] 103 | 104 | 105 | chunked_input = split_text(input_, chunk_size) 106 | all_data.append({'input':chunked_input, 'prompt':'Restate the aforementioned contexts:', 'output':input_, 'task_id':0}) 107 | all_data.append({'input':chunked_input, 'prompt':f'Answer the question:{prompt_}', 'output':output_, 'task_id':1}) 108 | random.shuffle(all_data) 109 | with jsonlines.open(write_data_path, 'w') as fw: 110 | for data in all_data: 111 | fw.write(data) 112 | 113 | 114 | 115 | if __name__ == '__main__': 116 | chunk_size = 512 117 | 118 | 119 | # qmsum 120 | max_samples_eval = None 121 | column_names = {'input':'context', 'prompt':'prompt', 'output':'answer'} 122 | write_data_path = f'/path/to/chunked/dataset/with/chunk_size/eval.jsonl' 123 | data_path = '/data/to/raw_eval.jsonl' 124 | preshuffle_and_chunkfile(column_names, chunk_size, data_path, write_data_path, max_samples_eval) 125 | max_samples_train = None 126 | data_path = '/data/to/raw_train.jsonl' 127 | write_data_path = f'/path/to/chunked/dataset/with/chunk_size/train.jsonl' 128 | preshuffle_and_chunkfile(column_names, chunk_size, data_path, write_data_path, max_samples_train) 129 | -------------------------------------------------------------------------------- /utils/common_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import random 4 | import pathlib 5 | import numpy as np 6 | import jsonlines 7 | from scipy.stats import pearsonr, spearmanr 8 | import torch 9 | from loguru import logger 10 | import shutil 11 | from torch.utils.tensorboard import SummaryWriter 12 | import pickle 13 | import linecache 14 | import tracemalloc 15 | import torch.distributed as dist 16 | import oss2 17 | import base64 18 | 19 | def set_seed(seed): 20 | random.seed(seed) 21 | np.random.seed(seed) 22 | torch.manual_seed(seed) 23 | if torch.cuda.is_available(): 24 | torch.cuda.manual_seed_all(seed) 25 | 26 | def save_model(model_engine, ckpt_dir, client_state): 27 | model_engine.save_checkpoint(ckpt_dir, client_state=client_state) 28 | 29 | def remove_earlier_ckpt(path, start_name, current_step_num, max_save_num): 30 | 31 | filenames=os.listdir(path) 32 | ckpts = [dir_name for dir_name in filenames if os.path.isdir(os.path.join(path, dir_name)) and dir_name.startswith(start_name) and int(dir_name.split('-')[1])<=int(current_step_num)] 33 | current_ckpt_num = len(ckpts) 34 | print(f'existing ckpts to remove:{ckpts}') 35 | for dir_name in ckpts: 36 | if dir_name.startswith(start_name) and int(dir_name.split('-')[1]) <= int(current_step_num) and current_ckpt_num > (max_save_num-1): 37 | try: 38 | shutil.rmtree(os.path.join(path, dir_name)) 39 | except Exception as e: 40 | print('Error:', e) 41 | 42 | def makedirs(path): 43 | p = pathlib.Path(path) 44 | p.parent.mkdir(parents=True, exist_ok=True) 45 | return path 46 | 47 | def write_jsonl(obj, file_path): 48 | with jsonlines.open(file_path, 'a') as writer: 49 | writer.write(obj) 50 | 51 | 52 | def load_pickle(path): 53 | with open(path, "rb") as f: 54 | return pickle.load(f) 55 | 56 | def write_pickle(obj, path:str): 57 | if not os.path.exists(path): 58 | makedirs(path) 59 | with open(path, "wb") as f: 60 | return pickle.dump(obj, f) 61 | 62 | def write_tensorboard(summary_writer, log_dict, completed_steps): 63 | for key, value in log_dict.items(): 64 | summary_writer.add_scalar(f'{key}', value, completed_steps) 65 | 66 | def cos_sim(a, b): 67 | """ 68 | Computes the cosine similarity cos_sim(a[i], 
b[j]) for all i and j. 69 | :return: Matrix with res[i][j] = cos_sim(a[i], b[j]) 70 | """ 71 | if not isinstance(a, torch.Tensor): 72 | a = torch.tensor(a) 73 | 74 | if not isinstance(b, torch.Tensor): 75 | b = torch.tensor(b) 76 | 77 | if len(a.shape) == 1: 78 | a = a.unsqueeze(0) 79 | 80 | if len(b.shape) == 1: 81 | b = b.unsqueeze(0) 82 | 83 | a_norm = torch.nn.functional.normalize(a, p=2, dim=1) 84 | b_norm = torch.nn.functional.normalize(b, p=2, dim=1) 85 | return torch.mm(a_norm, b_norm.transpose(0, 1)) 86 | 87 | def gather_across_devices(v, global_rank, world_size): 88 | if v is None: 89 | return None 90 | v = v.contiguous() 91 | 92 | all_tensors = [torch.empty_like(v) for _ in range(world_size)] 93 | dist.all_gather(all_tensors, v) 94 | 95 | all_tensors[global_rank] = v 96 | all_tensors = torch.cat(all_tensors, dim=0) 97 | 98 | return all_tensors 99 | 100 | 101 | 102 | 103 | # def save_model(model_engine, ckpt_dir, client_state): 104 | # model_engine.save_checkpoint(ckpt_dir, client_state=client_state) 105 | 106 | def save_model(model_engine, ckpt_dir, client_state): 107 | model_engine.save_checkpoint(ckpt_dir, client_state=client_state, exclude_frozen_parameters=True) 108 | 109 | def punctuation_format(text: str): 110 | # Replace non-breaking space with space 111 | # text = text.strip() + '\n' 112 | text = text.replace("\u202f", " ").replace("\xa0", " ") 113 | # change chinese punctuation to english ones 114 | # text = text.translate(table) 115 | if not text.endswith("\n"): 116 | text += "\n" 117 | return text 118 | 119 | def randomly_insert_list1_into_list2(list1, list2): 120 | len_combined = len(list1) + len(list2) 121 | insertion_points = sorted(random.sample(range(len_combined), len(list2))) 122 | 123 | combined_list = [] 124 | insert_index = 0 125 | list2_index = 0 126 | for i in range(len_combined): 127 | if insert_index < len(insertion_points) and i == insertion_points[insert_index]: 128 | combined_list.append(list2[list2_index]) 129 | list2_index += 1 130 | insert_index += 1 131 | else: 132 | combined_list.append(list1[i - list2_index]) 133 | 134 | return combined_list 135 | 136 | def get_parameter_number(model): 137 | total_num = sum(p.numel() for p in model.parameters()) 138 | trainable_num = sum(p.numel() for p in model.parameters() if p.requires_grad) 139 | return total_num, trainable_num 140 | 141 | def find_overlap(s1, s2): 142 | max_overlap = 0 143 | overlap_str = s1 + s2 144 | for i in range(1, min(len(s1), len(s2)) + 1): 145 | if s1[-i:] == s2[:i]: 146 | if i > max_overlap: 147 | max_overlap = i 148 | overlap_str = s1 + s2[i:] 149 | 150 | return overlap_str 151 | 152 | def merge_strings(strings): 153 | 154 | if not strings: 155 | return "" 156 | 157 | merged = strings[0] 158 | for s in strings[1:]: 159 | merged = find_overlap(merged, s) 160 | 161 | return merged 162 | 163 | def compile_helper(): 164 | """Compile helper function at runtime. 
Make sure this 165 | is invoked on a single process.""" 166 | import os 167 | import subprocess 168 | 169 | path = os.path.abspath(os.path.dirname(__file__)) 170 | ret = subprocess.run(["make", "-C", path]) 171 | if ret.returncode != 0: 172 | print("Making C++ dataset helpers module failed, exiting.") 173 | import sys 174 | 175 | sys.exit(1) 176 | else: 177 | print("Making C++ dataset helpers module successfully.") 178 | 179 | 180 | def print_rank_0(*message): 181 | """If distributed is initialized print only on rank 0.""" 182 | if torch.distributed.is_initialized(): 183 | if torch.distributed.get_rank() == 0: 184 | print(*message, flush=True) 185 | else: 186 | print(*message, flush=True) -------------------------------------------------------------------------------- /model/encoder_model_bert.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | import json 4 | import argparse 5 | import torch 6 | import warnings 7 | import deepspeed 8 | from enum import Enum 9 | from typing import Union, List 10 | import torch.nn as nn 11 | from tqdm import tqdm, trange 12 | import numpy as np 13 | warnings.filterwarnings('ignore') 14 | from transformers import AutoTokenizer, AutoModel, AutoModelForCausalLM, AutoConfig 15 | from peft import LoraConfig, get_peft_model, TaskType 16 | from model import PMA_v1, PMA_v2 17 | 18 | 19 | 20 | class EncoderType(Enum): 21 | FIRST_LAST_AVG = 0 22 | LAST_AVG = 1 23 | CLS = 2 24 | POOLER = 3 25 | MEAN = 4 26 | PMA = 5 27 | 28 | def __str__(self): 29 | return self.name 30 | 31 | @staticmethod 32 | def from_string(s): 33 | try: 34 | return EncoderType[s] 35 | except KeyError: 36 | raise ValueError() 37 | 38 | class BaseBertModel(nn.Module): 39 | def __init__( 40 | self, 41 | model_name_or_path = None, 42 | max_seq_length = 512, 43 | encoder_type = 'CLS', 44 | pma_output_dim = -1, # pma参数 45 | tokens_for_each_chunk = -1, # pma参数 46 | ln = True, # pma参数 47 | alias = None 48 | ): 49 | super().__init__() 50 | self.model_name_or_path = model_name_or_path 51 | encoder_type = encoder_type.upper() 52 | encoder_type = EncoderType.from_string(encoder_type) if isinstance(encoder_type, str) else encoder_type 53 | 54 | if encoder_type not in list(EncoderType): 55 | raise ValueError(f'encoder_type must be in {list(EncoderType)}') 56 | self.encoder_type = encoder_type 57 | self.max_seq_length = max_seq_length 58 | self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, truncation_side='right', padding_side='right') 59 | self.plm_model = AutoModel.from_pretrained(model_name_or_path, trust_remote_code=True) 60 | self.results = {} 61 | if encoder_type == EncoderType.PMA: 62 | pma_input_dim = self.plm_model.config.hidden_size 63 | 64 | self.pma = PMA_v1(dim=pma_input_dim, compressed_dim=pma_output_dim, num_heads=self.plm_model.config.num_attention_heads, num_seeds=tokens_for_each_chunk, ln=ln) 65 | 66 | 67 | 68 | def get_sentence_embeddings(self, input_ids, attention_mask, token_type_ids=None): 69 | """ 70 | Returns the model output by encoder_type as embeddings. 71 | 72 | Utility function for self.bert() method. 
73 | """ 74 | model_output = self.plm_model(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids, output_hidden_states=True) 75 | 76 | if self.encoder_type == EncoderType.FIRST_LAST_AVG: 77 | 78 | first = model_output.hidden_states[1] 79 | last = model_output.hidden_states[-1] 80 | seq_length = first.size(1) # Sequence length 81 | 82 | first_avg = torch.avg_pool1d(first.transpose(1, 2), kernel_size=seq_length).squeeze(-1) # [batch, hid_size] 83 | last_avg = torch.avg_pool1d(last.transpose(1, 2), kernel_size=seq_length).squeeze(-1) # [batch, hid_size] 84 | final_encoding = torch.avg_pool1d( 85 | torch.cat([first_avg.unsqueeze(1), last_avg.unsqueeze(1)], dim=1).transpose(1, 2), 86 | kernel_size=2).squeeze(-1) 87 | return final_encoding 88 | 89 | if self.encoder_type == EncoderType.LAST_AVG: 90 | sequence_output = model_output.last_hidden_state # [batch_size, max_len, hidden_size] 91 | seq_length = sequence_output.size(1) 92 | final_encoding = torch.avg_pool1d(sequence_output.transpose(1, 2), kernel_size=seq_length).squeeze(-1) 93 | return final_encoding 94 | 95 | if self.encoder_type == EncoderType.CLS: 96 | sequence_output = model_output.last_hidden_state 97 | return sequence_output[:, 0] # [batch, hid_size] 98 | 99 | if self.encoder_type == EncoderType.POOLER: 100 | return model_output.pooler_output # [batch, hid_size] 101 | 102 | if self.encoder_type == EncoderType.MEAN: 103 | """ 104 | Mean Pooling - Take attention mask into account for correct averaging 105 | """ 106 | token_embeddings = model_output.last_hidden_state # Contains all token embeddings 107 | input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float() 108 | final_encoding = torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp( 109 | input_mask_expanded.sum(1), min=1e-9) 110 | return final_encoding # [batch, hid_size] 111 | if self.encoder_type == EncoderType.PMA: 112 | token_embeddings = model_output.last_hidden_state # Contains all token embeddings 113 | res_emb = self.pma(token_embeddings, attention_mask) 114 | res_emb = res_emb.reshape(res_emb.size(0), -1) 115 | return res_emb 116 | def batch_to_device(self, batch, device): 117 | """ 118 | send a pytorch batch to a device (CPU/GPU) 119 | """ 120 | for key in batch: 121 | if isinstance(batch[key], torch.Tensor): 122 | batch[key] = batch[key].to(device) 123 | return batch 124 | 125 | 126 | def encode_with_grad( 127 | self, 128 | sentences: Union[str, List[str]], 129 | batch_size: int = 32, 130 | show_progress_bar: bool = False, 131 | convert_to_numpy: bool = False, 132 | convert_to_tensor: bool = True, 133 | device: str = None, 134 | normalize_embeddings: bool = True, 135 | max_seq_length: int = None, 136 | ): 137 | if device is None: 138 | device = self.plm_model.device 139 | # self.plm_model.to(device) 140 | 141 | if max_seq_length is None: 142 | max_seq_length = self.max_seq_length 143 | if convert_to_tensor: 144 | convert_to_numpy = False 145 | input_is_string = False 146 | if isinstance(sentences, str) or not hasattr(sentences, "__len__"): 147 | sentences = [sentences] 148 | input_is_string = True 149 | 150 | all_embeddings = [] 151 | length_sorted_idx = np.argsort([-len(s) for s in sentences]) 152 | sentences_sorted = [sentences[idx] for idx in length_sorted_idx] 153 | for start_index in trange(0, len(sentences), batch_size, desc="Batches", disable=not show_progress_bar): 154 | sentences_batch = sentences_sorted[start_index: start_index + batch_size] 155 | # Compute sentences embeddings 156 | 
features = self.tokenizer( 157 | sentences_batch, max_length=max_seq_length, 158 | padding=True, truncation=True, return_tensors='pt' 159 | ) 160 | features = self.batch_to_device(features, device) 161 | embeddings = self.get_sentence_embeddings(**features) 162 | embeddings = embeddings.detach() 163 | if normalize_embeddings: 164 | embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1) 165 | 166 | if convert_to_numpy: 167 | embeddings = embeddings.cpu() 168 | all_embeddings.extend(embeddings) 169 | all_embeddings = [all_embeddings[idx] for idx in np.argsort(length_sorted_idx)] 170 | if convert_to_tensor: 171 | all_embeddings = torch.stack(all_embeddings) 172 | elif convert_to_numpy: 173 | all_embeddings = np.asarray([emb.numpy() for emb in all_embeddings]) 174 | 175 | if input_is_string: 176 | all_embeddings = all_embeddings[0] 177 | 178 | return all_embeddings 179 | 180 | def encode_without_grad( 181 | self, 182 | sentences: Union[str, List[str]], 183 | batch_size: int = 32, 184 | show_progress_bar: bool = False, 185 | convert_to_numpy: bool = False, 186 | convert_to_tensor: bool = True, 187 | device: str = None, 188 | normalize_embeddings: bool = True, 189 | max_seq_length: int = None, 190 | ): 191 | self.plm_model.eval() 192 | if device is None: 193 | device = self.plm_model.device 194 | # self.plm_model.to(device) 195 | 196 | if max_seq_length is None: 197 | max_seq_length = self.max_seq_length 198 | if convert_to_tensor: 199 | convert_to_numpy = False 200 | input_is_string = False 201 | if isinstance(sentences, str) or not hasattr(sentences, "__len__"): 202 | sentences = [sentences] 203 | input_is_string = True 204 | 205 | all_embeddings = [] 206 | length_sorted_idx = np.argsort([-len(s) for s in sentences]) 207 | sentences_sorted = [sentences[idx] for idx in length_sorted_idx] 208 | for start_index in trange(0, len(sentences), batch_size, desc="Batches", disable=not show_progress_bar): 209 | sentences_batch = sentences_sorted[start_index: start_index + batch_size] 210 | # Compute sentences embeddings 211 | with torch.no_grad(): 212 | features = self.tokenizer( 213 | sentences_batch, max_length=max_seq_length, 214 | padding=True, truncation=True, return_tensors='pt' 215 | ) 216 | features = self.batch_to_device(features, device) 217 | embeddings = self.get_sentence_embeddings(**features) 218 | embeddings = embeddings.detach() 219 | if normalize_embeddings: 220 | embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1) 221 | 222 | if convert_to_numpy: 223 | embeddings = embeddings.cpu() 224 | all_embeddings.extend(embeddings) 225 | all_embeddings = [all_embeddings[idx] for idx in np.argsort(length_sorted_idx)] 226 | if convert_to_tensor: 227 | all_embeddings = torch.stack(all_embeddings) 228 | elif convert_to_numpy: 229 | all_embeddings = np.asarray([emb.numpy() for emb in all_embeddings]) 230 | 231 | if input_is_string: 232 | all_embeddings = all_embeddings[0] 233 | 234 | return all_embeddings 235 | -------------------------------------------------------------------------------- /dataset/dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import math 4 | import numpy as np 5 | import torch 6 | from torch.utils.data import Dataset 7 | from datasets import load_dataset 8 | from transformers import PreTrainedTokenizer, AutoTokenizer 9 | import csv 10 | from loguru import logger 11 | import random 12 | import jsonlines 13 | from utils.common_utils import load_pickle, set_seed, merge_strings, 
randomly_insert_list1_into_list2 14 | 15 | def load_qmsum_data(data_path, res_data, chunks_repo, global_rank, world_size, mode): 16 | data = [] 17 | if data_path is None: 18 | raise ValueError(f'data path is {data_path}') 19 | with open(data_path, 'r', encoding='utf8') as f: 20 | for i,line in enumerate(jsonlines.Reader(f)): 21 | chunks_repo.extend(line['input']) 22 | if i % world_size != global_rank: 23 | continue 24 | data.append((line['input'], line['prompt'], line['output'], line['task_id'])) 25 | num_samples = len(data) 26 | val_samples = math.ceil(0.05*num_samples) 27 | if mode == 'train': 28 | res_data.extend(data[:-val_samples]) 29 | elif mode == 'validation': 30 | res_data.extend(data[-val_samples:]) 31 | else: 32 | raise ValueError(f'Unsupported mode: {mode}') 33 | return res_data, list(set(chunks_repo)) 34 | 35 | def load_qmsum_test_data(data_path): 36 | data = [] 37 | if data_path is None: 38 | raise ValueError(f'data path is {data_path}') 39 | with open(data_path, 'r', encoding='utf8') as f: 40 | for line in jsonlines.Reader(f): 41 | if line['task_id'] == 1: 42 | data.append((line['input'], line['prompt'], line['output'], line['task_id'])) 43 | return data 44 | 45 | 46 | def collate_fn(data): 47 | 48 | res_inputs = [] 49 | res_prompt = [] 50 | res_output = [] 51 | res_task_id = [] 52 | for d in data: 53 | inputs = d[0] # list of str, chunk 54 | prompt = d[1] 55 | output = d[2] 56 | task_id = d[3] 57 | res_inputs.append(inputs) 58 | res_prompt.append(prompt) 59 | res_output.append(output) 60 | res_task_id.append(task_id) 61 | return res_inputs, res_prompt, res_output, res_task_id 62 | class TrainDataset(Dataset): 63 | 64 | def __init__(self, names=None, max_seq_len=0, chunk_size=0, shuffle_sample_ratio=0, noisy_sample_ratio=0, noise_rate=0, max_num_chunks=0, max_sliding_windows=None, process_index=0, num_processes=1, seed=2024): 65 | self.data = [] 66 | self.chunks_repo = [] 67 | self.max_seq_len = max_seq_len 68 | self.chunk_size = chunk_size 69 | self.shuffle_sample_ratio = shuffle_sample_ratio 70 | self.noisy_sample_ratio = noisy_sample_ratio 71 | self.noise_rate = noise_rate 72 | self.max_num_chunks = max_num_chunks 73 | self.max_sliding_windows = max_sliding_windows 74 | self.process_index = process_index 75 | self.num_processes = num_processes 76 | self.deterministic_generator = np.random.default_rng(seed) 77 | data_path = None 78 | names = names[1:-1].split(',') 79 | names = [name.lower() for name in names] 80 | for name in names: 81 | if name == 'qmsum': 82 | data_path = f'/path/to/chunked/dataset/with/chunk_size/train.jsonl' 83 | self.data, self.chunks_repo = load_qmsum_data(data_path, self.data, self.chunks_repo, self.process_index, self.num_processes, 'train') 84 | 85 | self.window_size = math.floor((self.max_seq_len-self.max_num_chunks-60) / self.chunk_size) 86 | self.augmented_data = self.get_augmented_data() 87 | 88 | def get_augmented_data(self): 89 | res_data = [] 90 | for sample in self.data: 91 | chunks = sample[0] # list of str (str represents chunks) 92 | prompt = sample[1] 93 | output = sample[2] 94 | task_id = int(sample[3]) 95 | if task_id == 0: 96 | # normal reconstruction, no shuffling 97 | num_chunks = len(chunks) 98 | num_sliding_windows = math.ceil(num_chunks / self.window_size) 99 | if self.max_sliding_windows is not None: 100 | num_sliding_windows = min(self.max_sliding_windows, num_sliding_windows) 101 | start = 0 102 | for i in range(num_sliding_windows): 103 | end = min(start + self.window_size, num_chunks) 104 | chunks_in_window = chunks[start:end] 105 | #
chunks_in_window_to_string = ''.join(chunks_in_window) 106 | chunks_in_window_to_string = merge_strings(chunks_in_window) 107 | res_data.append((chunks_in_window, prompt, chunks_in_window_to_string, task_id)) 108 | start = end 109 | if task_id == 1: 110 | random_value = random.random() 111 | if random_value < self.shuffle_sample_ratio: 112 | random.shuffle(chunks) 113 | res_data.append((chunks, prompt, output, task_id)) 114 | elif random_value >= self.shuffle_sample_ratio and random_value < self.shuffle_sample_ratio+self.noisy_sample_ratio: 115 | num_noise_chunks = min(math.floor(self.noise_rate*len(chunks)), len(self.chunks_repo)) 116 | noisy_chunks = random.sample(self.chunks_repo, num_noise_chunks) 117 | chunks = randomly_insert_list1_into_list2(chunks, noisy_chunks) 118 | res_data.append((chunks, prompt, output, task_id)) 119 | else: 120 | res_data.append((chunks, prompt, output, task_id)) 121 | 122 | self.step = 0 123 | self.steps_per_epoch = len(res_data) 124 | 125 | return res_data 126 | 127 | 128 | def __getitem__(self, index): 129 | if self.step > self.steps_per_epoch - 1: 130 | self.augmented_data = self.get_augmented_data() 131 | self.step += 1 132 | return self.augmented_data[index] 133 | 134 | def __len__(self): 135 | return self.steps_per_epoch 136 | 137 | class ValDataset(Dataset): 138 | def __init__(self, names=None, max_seq_len=0, chunk_size=0, shuffle_sample_ratio=0, noisy_sample_ratio=0, noise_rate=0, max_num_chunks=0, max_sliding_windows=None, process_index=0, num_processes=1, seed=2024): 139 | self.data = [] 140 | self.chunks_repo = [] 141 | self.max_seq_len = max_seq_len 142 | self.chunk_size = chunk_size 143 | self.max_num_chunks = max_num_chunks 144 | self.shuffle_sample_ratio = shuffle_sample_ratio 145 | self.noisy_sample_ratio = noisy_sample_ratio 146 | self.noise_rate = noise_rate 147 | self.process_index = process_index 148 | self.max_sliding_windows = max_sliding_windows 149 | self.num_processes = num_processes 150 | self.deterministic_generator = np.random.default_rng(seed) 151 | data_path = None 152 | names = names[1:-1].split(',') 153 | names = [name.lower() for name in names] 154 | for name in names: 155 | if name == 'qmsum': 156 | data_path = f'/path/to/chunked/dataset/with/chunk_size/train.jsonl' 157 | self.data, self.chunks_repo = load_qmsum_data(data_path, self.data, self.chunks_repo, self.process_index, self.num_processes, 'validation') 158 | self.window_size = math.floor((self.max_seq_len-self.max_num_chunks-60) / self.chunk_size) 159 | 160 | self.augmented_data = self.get_augmented_data() 161 | 162 | def get_augmented_data(self): 163 | res_data = [] 164 | for sample in self.data: 165 | chunks = sample[0] # list of str (str represents chunks) 166 | prompt = sample[1] 167 | output = sample[2] 168 | task_id = int(sample[3]) 169 | if task_id == 0: 170 | # normal reconstruction, no shuffling 171 | num_chunks = len(chunks) 172 | num_sliding_windows = math.ceil(num_chunks / self.window_size) 173 | if self.max_sliding_windows is not None: 174 | num_sliding_windows = min(self.max_sliding_windows, num_sliding_windows) 175 | start = 0 176 | for i in range(num_sliding_windows): 177 | end = min(start + self.window_size, num_chunks) 178 | chunks_in_window = chunks[start:end] 179 | # chunks_in_window_to_string = ''.join(chunks_in_window) 180 | chunks_in_window_to_string = merge_strings(chunks_in_window) 181 | res_data.append((chunks_in_window, prompt, chunks_in_window_to_string, task_id)) 182 | start = end 183 | if task_id == 1: 184 | random_value = random.random() 185 | if random_value
< self.shuffle_sample_ratio: 186 | random.shuffle(chunks) 187 | res_data.append((chunks, prompt, output, task_id)) 188 | elif random_value >= self.shuffle_sample_ratio and random_value < self.shuffle_sample_ratio+self.noisy_sample_ratio: 189 | num_noise_chunks = min(math.floor(self.noise_rate*len(chunks)), len(self.chunks_repo)) 190 | noisy_chunks = random.sample(self.chunks_repo, num_noise_chunks) 191 | chunks = randomly_insert_list1_into_list2(chunks, noisy_chunks) 192 | res_data.append((chunks, prompt, output, task_id)) 193 | else: 194 | res_data.append((chunks, prompt, output, task_id)) 195 | 196 | self.step = 0 197 | self.steps_per_epoch = len(res_data) 198 | 199 | return res_data 200 | 201 | 202 | def __getitem__(self, index): 203 | if self.step > self.steps_per_epoch - 1: 204 | self.augmented_data = self.get_augmented_data() 205 | self.step += 1 206 | return self.augmented_data[index] 207 | 208 | def __len__(self): 209 | return self.steps_per_epoch 210 | 211 | 212 | 213 | 214 | 215 | class TestqmsumDataset(Dataset): 216 | def __init__(self, chunk_size=0, seed=2024): 217 | self.deterministic_generator = np.random.default_rng(seed) 218 | data_path = f'/path/to/chunked/dataset/with/chunk_size/eval.jsonl' 219 | self.data = load_qmsum_test_data(data_path) 220 | 221 | def __getitem__(self, index): 222 | 223 | return self.data[index] 224 | 225 | def __len__(self): 226 | return len(self.data) 227 | 228 | if __name__ == '__main__': 229 | d = TrainDataset(names = ['qmsum']) 230 | print(len(d.data)) -------------------------------------------------------------------------------- /LICENSE.md: -------------------------------------------------------------------------------- 1 | Copyright [2023] [Ant Group] 2 | Licensed under the Apache License, Version 2.0 (the "License"); 3 | you may not use this file except in compliance with the License. 4 | You may obtain a copy of the License at 5 | http://www.apache.org/licenses/LICENSE-2.0 6 | 7 | Unless required by applicable law or agreed to in writing, software 8 | distributed under the License is distributed on an "AS IS" BASIS, 9 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 10 | See the License for the specific language governing permissions and 11 | limitations under the License. 12 | 13 | 14 | Apache License 15 | Version 2.0, January 2004 16 | http://www.apache.org/licenses/ 17 | 18 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 19 | 20 | 1. Definitions. 21 | 22 | "License" shall mean the terms and conditions for use, reproduction, 23 | and distribution as defined by Sections 1 through 9 of this document. 24 | 25 | "Licensor" shall mean the copyright owner or entity authorized by 26 | the copyright owner that is granting the License. 27 | 28 | "Legal Entity" shall mean the union of the acting entity and all 29 | other entities that control, are controlled by, or are under common 30 | control with that entity. For the purposes of this definition, 31 | "control" means (i) the power, direct or indirect, to cause the 32 | direction or management of such entity, whether by contract or 33 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 34 | outstanding shares, or (iii) beneficial ownership of such entity. 35 | 36 | "You" (or "Your") shall mean an individual or Legal Entity 37 | exercising permissions granted by this License. 
38 | 39 | "Source" form shall mean the preferred form for making modifications, 40 | including but not limited to software source code, documentation 41 | source, and configuration files. 42 | 43 | "Object" form shall mean any form resulting from mechanical 44 | transformation or translation of a Source form, including but 45 | not limited to compiled object code, generated documentation, 46 | and conversions to other media types. 47 | 48 | "Work" shall mean the work of authorship, whether in Source or 49 | Object form, made available under the License, as indicated by a 50 | copyright notice that is included in or attached to the work 51 | (an example is provided in the Appendix below). 52 | 53 | "Derivative Works" shall mean any work, whether in Source or Object 54 | form, that is based on (or derived from) the Work and for which the 55 | editorial revisions, annotations, elaborations, or other modifications 56 | represent, as a whole, an original work of authorship. For the purposes 57 | of this License, Derivative Works shall not include works that remain 58 | separable from, or merely link (or bind by name) to the interfaces of, 59 | the Work and Derivative Works thereof. 60 | 61 | "Contribution" shall mean any work of authorship, including 62 | the original version of the Work and any modifications or additions 63 | to that Work or Derivative Works thereof, that is intentionally 64 | submitted to Licensor for inclusion in the Work by the copyright owner 65 | or by an individual or Legal Entity authorized to submit on behalf of 66 | the copyright owner. For the purposes of this definition, "submitted" 67 | means any form of electronic, verbal, or written communication sent 68 | to the Licensor or its representatives, including but not limited to 69 | communication on electronic mailing lists, source code control systems, 70 | and issue tracking systems that are managed by, or on behalf of, the 71 | Licensor for the purpose of discussing and improving the Work, but 72 | excluding communication that is conspicuously marked or otherwise 73 | designated in writing by the copyright owner as "Not a Contribution." 74 | 75 | "Contributor" shall mean Licensor and any individual or Legal Entity 76 | on behalf of whom a Contribution has been received by Licensor and 77 | subsequently incorporated within the Work. 78 | 79 | 2. Grant of Copyright License. Subject to the terms and conditions of 80 | this License, each Contributor hereby grants to You a perpetual, 81 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 82 | copyright license to reproduce, prepare Derivative Works of, 83 | publicly display, publicly perform, sublicense, and distribute the 84 | Work and such Derivative Works in Source or Object form. 85 | 86 | 3. Grant of Patent License. Subject to the terms and conditions of 87 | this License, each Contributor hereby grants to You a perpetual, 88 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 89 | (except as stated in this section) patent license to make, have made, 90 | use, offer to sell, sell, import, and otherwise transfer the Work, 91 | where such license applies only to those patent claims licensable 92 | by such Contributor that are necessarily infringed by their 93 | Contribution(s) alone or by combination of their Contribution(s) 94 | with the Work to which such Contribution(s) was submitted. 
If You 95 | institute patent litigation against any entity (including a 96 | cross-claim or counterclaim in a lawsuit) alleging that the Work 97 | or a Contribution incorporated within the Work constitutes direct 98 | or contributory patent infringement, then any patent licenses 99 | granted to You under this License for that Work shall terminate 100 | as of the date such litigation is filed. 101 | 102 | 4. Redistribution. You may reproduce and distribute copies of the 103 | Work or Derivative Works thereof in any medium, with or without 104 | modifications, and in Source or Object form, provided that You 105 | meet the following conditions: 106 | 107 | (a) You must give any other recipients of the Work or 108 | Derivative Works a copy of this License; and 109 | 110 | (b) You must cause any modified files to carry prominent notices 111 | stating that You changed the files; and 112 | 113 | (c) You must retain, in the Source form of any Derivative Works 114 | that You distribute, all copyright, patent, trademark, and 115 | attribution notices from the Source form of the Work, 116 | excluding those notices that do not pertain to any part of 117 | the Derivative Works; and 118 | 119 | (d) If the Work includes a "NOTICE" text file as part of its 120 | distribution, then any Derivative Works that You distribute must 121 | include a readable copy of the attribution notices contained 122 | within such NOTICE file, excluding those notices that do not 123 | pertain to any part of the Derivative Works, in at least one 124 | of the following places: within a NOTICE text file distributed 125 | as part of the Derivative Works; within the Source form or 126 | documentation, if provided along with the Derivative Works; or, 127 | within a display generated by the Derivative Works, if and 128 | wherever such third-party notices normally appear. The contents 129 | of the NOTICE file are for informational purposes only and 130 | do not modify the License. You may add Your own attribution 131 | notices within Derivative Works that You distribute, alongside 132 | or as an addendum to the NOTICE text from the Work, provided 133 | that such additional attribution notices cannot be construed 134 | as modifying the License. 135 | 136 | You may add Your own copyright statement to Your modifications and 137 | may provide additional or different license terms and conditions 138 | for use, reproduction, or distribution of Your modifications, or 139 | for any such Derivative Works as a whole, provided Your use, 140 | reproduction, and distribution of the Work otherwise complies with 141 | the conditions stated in this License. 142 | 143 | 5. Submission of Contributions. Unless You explicitly state otherwise, 144 | any Contribution intentionally submitted for inclusion in the Work 145 | by You to the Licensor shall be under the terms and conditions of 146 | this License, without any additional terms or conditions. 147 | Notwithstanding the above, nothing herein shall supersede or modify 148 | the terms of any separate license agreement you may have executed 149 | with Licensor regarding such Contributions. 150 | 151 | 6. Trademarks. This License does not grant permission to use the trade 152 | names, trademarks, service marks, or product names of the Licensor, 153 | except as required for reasonable and customary use in describing the 154 | origin of the Work and reproducing the content of the NOTICE file. 155 | 156 | 7. Disclaimer of Warranty. 
Unless required by applicable law or 157 | agreed to in writing, Licensor provides the Work (and each 158 | Contributor provides its Contributions) on an "AS IS" BASIS, 159 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 160 | implied, including, without limitation, any warranties or conditions 161 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 162 | PARTICULAR PURPOSE. You are solely responsible for determining the 163 | appropriateness of using or redistributing the Work and assume any 164 | risks associated with Your exercise of permissions under this License. 165 | 166 | 8. Limitation of Liability. In no event and under no legal theory, 167 | whether in tort (including negligence), contract, or otherwise, 168 | unless required by applicable law (such as deliberate and grossly 169 | negligent acts) or agreed to in writing, shall any Contributor be 170 | liable to You for damages, including any direct, indirect, special, 171 | incidental, or consequential damages of any character arising as a 172 | result of this License or out of the use or inability to use the 173 | Work (including but not limited to damages for loss of goodwill, 174 | work stoppage, computer failure or malfunction, or any and all 175 | other commercial damages or losses), even if such Contributor 176 | has been advised of the possibility of such damages. 177 | 178 | 9. Accepting Warranty or Additional Liability. While redistributing 179 | the Work or Derivative Works thereof, You may choose to offer, 180 | and charge a fee for, acceptance of support, warranty, indemnity, 181 | or other liability obligations and/or rights consistent with this 182 | License. However, in accepting such obligations, You may act only 183 | on Your own behalf and on Your sole responsibility, not on behalf 184 | of any other Contributor, and only if You agree to indemnify, 185 | defend, and hold each Contributor harmless for any liability 186 | incurred by, or claims asserted against, such Contributor by reason 187 | of your accepting any such warranty or additional liability. 188 | 189 | END OF TERMS AND CONDITIONS 190 | 191 | APPENDIX: How to apply the Apache License to your work. 192 | 193 | To apply the Apache License to your work, attach the following 194 | boilerplate notice, with the fields enclosed by brackets "[]" 195 | replaced with your own identifying information. (Don't include 196 | the brackets!) The text should be enclosed in the appropriate 197 | comment syntax for the file format. We also recommend that a 198 | file or class name and description of purpose be included on the 199 | same "printed page" as the copyright notice for easier 200 | identification within third-party archives. 201 | 202 | Copyright [yyyy] [name of copyright owner] 203 | 204 | Licensed under the Apache License, Version 2.0 (the "License"); 205 | you may not use this file except in compliance with the License. 206 | You may obtain a copy of the License at 207 | 208 | http://www.apache.org/licenses/LICENSE-2.0 209 | 210 | Unless required by applicable law or agreed to in writing, software 211 | distributed under the License is distributed on an "AS IS" BASIS, 212 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 213 | See the License for the specific language governing permissions and 214 | limitations under the License. 
-------------------------------------------------------------------------------- /train_accelerate.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import math 4 | import time 5 | import argparse 6 | import json 7 | import logging 8 | from utils.common_utils import compile_helper 9 | from tqdm.auto import tqdm 10 | import transformers 11 | import numpy as np 12 | import torch 13 | from torch import nn 14 | import datasets 15 | 16 | from torch.utils.data.distributed import DistributedSampler 17 | from torch.utils.data import DataLoader, Dataset, RandomSampler 18 | from transformers import ( 19 | AutoModelForCausalLM, 20 | AutoTokenizer, 21 | get_linear_schedule_with_warmup, 22 | set_seed, 23 | BitsAndBytesConfig, 24 | get_scheduler, 25 | ) 26 | from peft import ( 27 | LoraConfig, 28 | TaskType, 29 | get_peft_model, 30 | prepare_model_for_kbit_training, 31 | PeftModel, 32 | ) 33 | from accelerate import Accelerator, DistributedType, FullyShardedDataParallelPlugin, DataLoaderConfiguration 34 | from accelerate.logging import get_logger 35 | from datetime import timedelta 36 | from accelerate.utils import InitProcessGroupKwargs 37 | from transformers.optimization import Adafactor 38 | 39 | # insert project root as import path, just in case 40 | current_path = os.path.abspath(__file__) 41 | parent_dir = os.path.dirname(os.path.dirname(current_path)) 42 | sys.path.insert(0, parent_dir) 43 | 44 | from dataset.dataset import * 45 | from utils.common_utils import print_rank_0, get_parameter_number 46 | from pefts import E2LLMTrainArgs, E2LLMTrainer 47 | from model.pro_model import E2LLMModel 48 | logger = get_logger(__name__) 49 | 50 | def pprint_args(args, accelerator): 51 | max_key_length = max(len(str(key)) for key in vars(args).keys()) 52 | message = "" 53 | message += "====" * 60 + "\n" 54 | message += "\n".join([f"{k:<{max_key_length}} : {v}" for k, v in vars(args).items()]) + "\n" 55 | message += "====" * 60 + "\n" 56 | accelerator.print(message) 57 | accelerator.print("GPU: {}".format(torch.cuda.current_device())) 58 | 59 | def str2bool(v): 60 | return v.lower() in ('yes', 'true', 't', '1') 61 | 62 | def prepare_args(): 63 | parser = argparse.ArgumentParser() 64 | parser.add_argument("--train_config", type=str, default="./configs/train_config.json") 65 | parsed = parser.parse_args() 66 | 67 | with open(parsed.train_config, "r") as f: 68 | train_config = json.load(f) 69 | 70 | args = E2LLMTrainArgs(**train_config) 71 | args.output_dir = f"/path/to/your/output/of/ckpt/{args.mark}" 72 | args.tb_dir = f"/path/to/your/output/of/ckpt/{args.mark}/log" 73 | 74 | return args 75 | 76 | def main(): 77 | os.environ["TOKENIZERS_PARALLELISM"] = "false" 78 | os.environ["HF_HUB_OFFLINE"] = "false" 79 | args = prepare_args() 80 | 81 | if args.seed is not None: 82 | set_seed(args.seed) 83 | 84 | os.makedirs(args.output_dir, exist_ok=True) 85 | logging.basicConfig(filename=f"{args.output_dir}/training_{args.mark}.log", level=logging.INFO) 86 | 87 | # define accelerator 88 | init_process_kwargs = InitProcessGroupKwargs(timeout=timedelta(seconds=args.init_timeout_seconds)) 89 | if args.distributed_type is not None and args.distributed_type.lower() == "fsdp": 90 | fsdp_plugin = FullyShardedDataParallelPlugin( 91 | # state_dict_config=FullStateDictConfig(offload_to_cpu=True, rank0_only=True), 92 | # optim_state_dict_config=FullOptimStateDictConfig(offload_to_cpu=True, rank0_only=True), 93 | limit_all_gathers=True, 94 | sync_module_states=True, 95 
| use_orig_params=True, 96 | cpu_offload=False, 97 | ) 98 | accelerator = Accelerator( 99 | gradient_accumulation_steps=args.gradient_accumulation_steps, 100 | fsdp_plugin=fsdp_plugin, 101 | kwargs_handlers=[init_process_kwargs]) 102 | else: 103 | accelerator = Accelerator( 104 | gradient_accumulation_steps=args.gradient_accumulation_steps, 105 | kwargs_handlers=[init_process_kwargs]) 106 | 107 | args.world_size = accelerator.num_processes 108 | pprint_args(args, accelerator) 109 | if accelerator.is_main_process: 110 | with open(os.path.join(args.output_dir, "args.json"), "a") as f: 111 | json.dump(args.dict(), f, indent=2) 112 | 113 | 114 | 115 | global_rank = accelerator.process_index 116 | local_rank = accelerator.local_process_index 117 | accelerator.print(f"world_size: {args.world_size}, global_rank: {global_rank}, local_rank: {local_rank}") 118 | logger.info(accelerator.state, main_process_only=False) 119 | 120 | t0 = time.time() 121 | train_dataset = TrainDataset(names=args.train_data_list, max_seq_len=args.max_seq_len, chunk_size=args.chunk_size, shuffle_sample_ratio=args.shuffle_sample_ratio, noisy_sample_ratio=args.noisy_sample_ratio, noise_rate=args.noise_rate, max_num_chunks=args.max_num_chunks, max_sliding_windows=args.max_sliding_windows, process_index=global_rank, num_processes=args.world_size) 122 | val_dataset = ValDataset(names=args.val_data_list, max_seq_len=args.max_seq_len, chunk_size=args.chunk_size, shuffle_sample_ratio=args.shuffle_sample_ratio, noisy_sample_ratio=args.noisy_sample_ratio, noise_rate=args.noise_rate, max_num_chunks=args.max_num_chunks, max_sliding_windows=args.max_sliding_windows, process_index=global_rank, num_processes=args.world_size) 123 | train_sampler = RandomSampler(train_dataset) 124 | val_sampler = RandomSampler(val_dataset) 125 | train_dataloader = DataLoader(train_dataset, batch_size=args.bs, shuffle=False, sampler=train_sampler, collate_fn=collate_fn, num_workers=0) 126 | val_dataloader = DataLoader(val_dataset, batch_size=args.bs, shuffle=False, sampler=val_sampler, collate_fn=collate_fn, num_workers=0) 127 | t1 = time.time() 128 | logger.info(f"dataset loading time: {t1 - t0:.4f}") 129 | 130 | # cuda memory 131 | free_in_GB = int(torch.cuda.mem_get_info()[0] / 1024 ** 3) 132 | max_memory = f"{free_in_GB - 2}GB" 133 | n_gpus = torch.cuda.device_count() 134 | max_memory = {i: max_memory for i in range(n_gpus)} 135 | accelerator.print("max memory: ", max_memory, n_gpus) 136 | 137 | 138 | with open("./configs/lora_modules.json", 'r') as fr: 139 | lora_modules = json.load(fr) 140 | 141 | if args.peft_fn == "qlora": 142 | print_rank_0(f"[INFO] args.peft_type is set 'qlora', setting quantization to '4bit'") 143 | args.quantization = "4bit" 144 | else: 145 | args.quantization = None 146 | 147 | enc_model = args.encoder_model_dir.split('/')[-1] 148 | dec_model = args.decoder_model_dir.split('/')[-1] 149 | 150 | # peft config 151 | if args.peft_fn in ["lora", "qlora"]: 152 | lora_config_enc = LoraConfig( 153 | r=args.lora_rank_enc, 154 | lora_alpha=args.lora_rank_enc, 155 | # target_modules=["q_proj", "k_proj", "v_proj", "o_proj"], 156 | target_modules=lora_modules[enc_model], 157 | lora_dropout=0.05, 158 | bias="lora_only", 159 | inference_mode=False, 160 | task_type=TaskType.FEATURE_EXTRACTION 161 | ) 162 | lora_config_dec = LoraConfig( 163 | r=args.lora_rank_dec, 164 | lora_alpha=args.lora_alpha_dec, 165 | # target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "down_proj", "up_proj"], 166 | 
target_modules=lora_modules[dec_model] if dec_model in lora_modules else "all-linear", 167 | lora_dropout=0.05, 168 | bias="lora_only", 169 | inference_mode=False, 170 | task_type=TaskType.CAUSAL_LM 171 | ) 172 | 173 | model = E2LLMModel(args=args) 174 | model.decoder_model.gradient_checkpointing_enable() 175 | 176 | if args.quantization == "4bit": 177 | model.decoder_model = prepare_model_for_kbit_training(model.decoder_model) 178 | if args.peft_fn in ["lora", "qlora"]: 179 | model.encoder_model = get_peft_model(model.encoder_model, lora_config_enc) 180 | model.decoder_model = get_peft_model(model.decoder_model, lora_config_dec) 181 | if accelerator.is_main_process: 182 | model.encoder_model.print_trainable_parameters() 183 | model.decoder_model.print_trainable_parameters() 184 | else: 185 | accelerator.print("[WARNING] Full-Parameters Training takes up a lot of space, we set num_ckpt to 2") 186 | args.num_ckpt = 2 187 | # model.config.use_cache = False # silence the warnings. Please re-enable for inference! 188 | 189 | _, trainable_parameters = get_parameter_number(model) 190 | logger.info(f"Trainable parameters: {trainable_parameters}") 191 | print(f"Trainable parameters: {trainable_parameters}") 192 | 193 | t2 = time.time() 194 | if accelerator.is_main_process: 195 | logging.info(f"model loading time: {t2 - t1:.4f}") 196 | 197 | if hasattr(model.decoder_model.config, "use_logn_attn"): 198 | model.decoder_model.config.use_logn_attn = False # special for qwen model 199 | 200 | # load balance for moe training 201 | if hasattr(model.decoder_model.config, "output_router_logits"): 202 | model.decoder_model.config.output_router_logits = True 203 | 204 | num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) 205 | if args.max_train_steps is None: 206 | args.max_train_steps = args.num_epochs * num_update_steps_per_epoch 207 | 208 | if accelerator.distributed_type == DistributedType.DEEPSPEED: 209 | adam_optimizer = torch.optim.AdamW 210 | elif accelerator.distributed_type == DistributedType.FSDP: 211 | if args.peft_type and getattr(accelerator.state, "fsdp_plugin", None) is not None: 212 | from peft.utils.other import fsdp_auto_wrap_policy 213 | 214 | accelerator.state.fsdp_plugin.auto_wrap_policy = fsdp_auto_wrap_policy(model) 215 | model = accelerator.prepare(model) 216 | adam_optimizer = torch.optim.AdamW 217 | else: 218 | raise ValueError("Only support DeepSpeed and FSDP") 219 | 220 | optimizer = adam_optimizer( 221 | model.parameters(), 222 | weight_decay=args.weight_decay, 223 | lr=args.lr, 224 | betas=(0.9, 0.999), 225 | ) 226 | 227 | lr_scheduler = get_scheduler( 228 | name= "linear", #args.lr_scheduler_type, 229 | optimizer=optimizer, 230 | num_warmup_steps=args.num_warmup_steps * args.gradient_accumulation_steps * accelerator.num_processes, 231 | num_training_steps=args.max_train_steps * args.gradient_accumulation_steps * accelerator.num_processes 232 | ) 233 | 234 | 235 | # prepare all 236 | if accelerator.distributed_type == DistributedType.DEEPSPEED: 237 | model, _, optimizer, lr_scheduler = accelerator.prepare(model, val_dataloader, optimizer, lr_scheduler) 238 | elif accelerator.distributed_type == DistributedType.FSDP: 239 | _, optimizer, lr_scheduler = accelerator.prepare(val_dataloader, optimizer, lr_scheduler) 240 | else: 241 | raise ValueError("Only support DeepSpeed and FSDP") 242 | 243 | print(model.device) 244 | accelerator.print(model) 245 | 246 | # Recalculate our total training steps as the size of the training dataloader may 
have changed. 247 | # num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) 248 | # if overrode_max_train_steps: 249 | # args.max_train_steps = args.num_epochs * num_update_steps_per_epoch 250 | # # Afterward we recalculate our number of training epochs 251 | # args.num_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) 252 | 253 | 254 | # zero 3 flag 255 | is_ds_zero_3 = False 256 | if getattr(accelerator.state, "deepspeed_plugin", None): 257 | is_ds_zero_3 = accelerator.state.deepspeed_plugin.zero_stage == 3 258 | accelerator.print(f"DEEPSPEED plugin: {accelerator.state.deepspeed_plugin}") 259 | elif getattr(accelerator.state, "fsdp_plugin", None): 260 | accelerator.print(f"FSDP plugin: {accelerator.state.fsdp_plugin}") 261 | 262 | trainer = E2LLMTrainer( 263 | accelerator=accelerator, 264 | model=model, 265 | train_dataloader=train_dataloader, 266 | val_dataloader=val_dataloader, 267 | optimizer=optimizer, 268 | lr_scheduler=lr_scheduler, 269 | tokenizer=model.tokenizer, 270 | total_train_dataset_size=len(train_dataset), 271 | args=args, 272 | ) 273 | trainer.accelerate_train() 274 | logger.info(f"Training Finished!") 275 | 276 | if __name__ == '__main__': 277 | main() -------------------------------------------------------------------------------- /eval.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import math 4 | import time 5 | import re 6 | import argparse 7 | import json 8 | import logging 9 | from utils.common_utils import compile_helper 10 | from tqdm.auto import tqdm 11 | import transformers 12 | import numpy as np 13 | import torch 14 | from torch import nn 15 | import datasets 16 | from torch.utils.data.distributed import DistributedSampler 17 | from torch.utils.data import DataLoader, Dataset, RandomSampler 18 | from transformers import ( 19 | AutoModelForCausalLM, 20 | AutoTokenizer, 21 | get_linear_schedule_with_warmup, 22 | set_seed, 23 | BitsAndBytesConfig, 24 | get_scheduler, 25 | ) 26 | from peft import ( 27 | LoraConfig, 28 | TaskType, 29 | get_peft_model, 30 | prepare_model_for_kbit_training, 31 | PeftModel, 32 | ) 33 | from accelerate import Accelerator, DistributedType, FullyShardedDataParallelPlugin, DataLoaderConfiguration 34 | from accelerate.logging import get_logger 35 | from datetime import timedelta 36 | from accelerate.utils import InitProcessGroupKwargs 37 | from transformers.optimization import Adafactor 38 | 39 | # insert project root as import path, just in case 40 | current_path = os.path.abspath(__file__) 41 | parent_dir = os.path.dirname(os.path.dirname(current_path)) 42 | sys.path.insert(0, parent_dir) 43 | 44 | from dataset.dataset import * 45 | from utils.common_utils import print_rank_0, upload_to_oss 46 | from pefts import E2LLMTrainArgs, E2LLMTrainer 47 | from model.pro_model import E2LLMModel 48 | logger = get_logger(__name__) 49 | from rouge import Rouge 50 | from evaluate import compute_rouge, compute_f1, compute_score 51 | from draw import draw_results 52 | 53 | def write_json_fancy(json_data, path): 54 | with open(path, 'a') as file: 55 | json_str = json.dumps(json_data, indent=4) 56 | file.write(json_str + '\n') 57 | 58 | def str2bool(v): 59 | return v.lower() in ('yes', 'true', 't', '1') 60 | 61 | 62 | def filter_top_k_chunks(model, chunks, queries, task_ids, top_k, thres): 63 | 64 | res_retrieved_chunks = [] 65 | chunks_nums = [len(chunks_per) for chunks_per in chunks] 66 | 67 | chunks_cpy = [chunk for 
chunks_per in chunks for chunk in chunks_per] 68 | queries_chunks = queries + chunks_cpy 69 | queries_chunks_emb = model.encoder_model.encode_without_grad(queries_chunks) 70 | queries_emb = queries_chunks_emb[:len(queries)] # (bs, emb) 71 | chunks_emb = queries_chunks_emb[len(queries):] 72 | splitted_chunks_emb = torch.split(chunks_emb, chunks_nums) 73 | 74 | for i, t in enumerate(splitted_chunks_emb): 75 | if task_ids[i]==0: 76 | res_retrieved_chunks.append([]) 77 | continue 78 | sim = torch.matmul(queries_emb[i].unsqueeze(0), splitted_chunks_emb[i].T) 79 | values, indices = torch.topk(sim, min(top_k, sim.size(1)), dim=1) 80 | mask = values > thres 81 | indices = indices[mask].to('cpu') 82 | # indices = indices[0].to('cpu') 83 | retrieved_chunks = np.array(chunks[i])[indices] 84 | res_retrieved_chunks.append(retrieved_chunks) 85 | 86 | a_sorted = sorted(indices, reverse=True) 87 | for index in a_sorted: 88 | del chunks[i][index] 89 | return chunks, res_retrieved_chunks 90 | 91 | 92 | def postprocess_pred(predict_str): 93 | predict_str = predict_str.strip() 94 | # Remove all non-printable characters 95 | np_pattern = re.compile(r'[\x00-\x1f]') 96 | predict_str = np_pattern.sub('\n', predict_str).strip() 97 | 98 | return predict_str 99 | 100 | def string_match_part(preds, refs): 101 | score = sum([max([1.0 if r.lower() in pred.lower() else 0.0 for r in ref]) for pred, ref in zip(preds, refs)]) / len(preds) * 100 102 | return round(score, 2) 103 | 104 | def string_match_all(preds, refs): 105 | score = sum([sum([1.0 if r.lower() in pred.lower() else 0.0 for r in ref]) / len(ref) for pred, ref in zip(preds, refs)]) / len(preds) * 100 106 | return round(score, 2) 107 | 108 | def extract_answer(response): 109 | response = response.replace('*', '') 110 | match = re.search(r'The correct answer is \(([A-D])\)', response) 111 | if match: 112 | return match.group(1) 113 | else: 114 | match = re.search(r'The correct answer is ([A-D])', response) 115 | if match: 116 | return match.group(1) 117 | else: 118 | return None 119 | 120 | def gen_results(details_file, output_file, name): 121 | files = [details_file] # "./wrongresults.jsonl", './results.jsonl' 122 | output = ["Model\tOverall\tEasy\tHard\tShort\tMedium\tLong"] 123 | compensated = False 124 | 125 | for file1 in files: 126 | filename = file1 127 | try: 128 | pred_data = json.load(open(filename, encoding='utf-8')) 129 | except Exception as e: 130 | pred_data = [json.loads(line) for line in open(filename, encoding='utf-8')] 131 | easy, hard, short, medium, long = 0, 0, 0, 0, 0 132 | easy_acc, hard_acc, short_acc, medium_acc, long_acc = 0, 0, 0, 0, 0 133 | for pred in pred_data: 134 | acc = int(pred['judge']) 135 | if compensated and pred["pred"] == None: 136 | acc = 0.25 137 | if pred["difficulty"] == "easy": 138 | easy += 1 139 | easy_acc += acc 140 | else: 141 | hard += 1 142 | hard_acc += acc 143 | 144 | if pred['length'] == "short": 145 | short += 1 146 | short_acc += acc 147 | elif pred['length'] == "medium": 148 | medium += 1 149 | medium_acc += acc 150 | else: 151 | long += 1 152 | long_acc += acc 153 | 154 | output.append(name+'\t'+str(round(100*(easy_acc+hard_acc)/len(pred_data), 1))+'\t'+str(round(100*easy_acc/(easy+1e-5), 1))+'\t'+str(round(100*hard_acc/(hard+1e-5), 1))+'\t'+str(round(100*short_acc/(short+1e-5), 1))+'\t'+str(round(100*medium_acc/(medium+1e-5), 1))+'\t'+str(round(100*long_acc/(long+1e-5), 1))) 155 | 156 | open(output_file, 'a', encoding='utf-8').write('\n'.join(output)) 157 | 158 | 159 | def prepare_args(): 160 | parser = 
argparse.ArgumentParser() 161 | parser.add_argument("--eval_config", type=str, default="./configs/eval_config.json") 162 | parser.add_argument('--test_data_list', type=str, default='[qmsum]') 163 | parser.add_argument('--global_step', default=5, type=int) 164 | parser.add_argument('--mark', required=True) 165 | parser.add_argument('--encoder_model_dir', default='/path/to/encoder/model', type=str, help='Model directory') 166 | parser.add_argument('--decoder_model_dir', default='/path/to/decoder/model', type=str, help='Model directory') 167 | parser.add_argument('--chunk_size', default=512, type=int, help='chunk_size') 168 | parser.add_argument('--max_num_chunks', default=500, type=int, help='max_num_chunks') 169 | parser.add_argument('--enc_pool_fn', default='CLS', type=str, help='') 170 | parser.add_argument('--tokens_for_each_chunk', default=4, type=int, help='') 171 | parser.add_argument('--ln', default='True', type=str2bool, help='') 172 | parser.add_argument('--proj_arch', default='', type=str, help='proj_arch') 173 | parser.add_argument('--lora_rank_enc', default=32, type=int, help='lora rank') 174 | parser.add_argument('--lora_alpha_enc', default=32, type=int, help='lora alpha') 175 | parser.add_argument('--lora_rank_dec', default=8, type=int, help='lora rank') 176 | parser.add_argument('--lora_alpha_dec', default=8, type=int, help='lora alpha') 177 | parser.add_argument('--eval_retrieve', default='True', type=str2bool, help='') 178 | parser.add_argument('--top_k', default=3, type=int, help='') 179 | parser.add_argument('--thres', default=0.8, type=float, help='') 180 | parser.add_argument('--attn_implementation', default='eager', type=str, help='') 181 | parser.add_argument('--bs', default=10, type=int) 182 | parser.add_argument('--max_new_tokens', default=100, type=int) 183 | parser.add_argument('--api', default="gpt4o", type=str) # qwen72b 184 | 185 | 186 | parsed = parser.parse_args() 187 | 188 | with open(parsed.eval_config, "r") as f: 189 | eval_config = json.load(f) 190 | 191 | args = E2LLMTrainArgs(**eval_config) 192 | args.test_data_list = parsed.test_data_list 193 | args.global_step = parsed.global_step 194 | args.mark = parsed.mark 195 | args.encoder_model_dir = parsed.encoder_model_dir 196 | args.decoder_model_dir = parsed.decoder_model_dir 197 | args.chunk_size = parsed.chunk_size 198 | args.max_num_chunks = parsed.max_num_chunks 199 | args.enc_pool_fn = parsed.enc_pool_fn 200 | args.tokens_for_each_chunk = parsed.tokens_for_each_chunk 201 | args.ln = parsed.ln 202 | args.api = parsed.api 203 | 204 | if parsed.proj_arch not in [None, '', 'None', 'null']: 205 | args.proj_arch = parsed.proj_arch 206 | else: 207 | args.proj_arch = None 208 | 209 | args.lora_rank_enc = parsed.lora_rank_enc 210 | args.lora_alpha_enc = parsed.lora_alpha_enc 211 | args.lora_rank_dec = parsed.lora_rank_dec 212 | args.lora_alpha_dec = parsed.lora_alpha_dec 213 | 214 | args.eval_retrieve = parsed.eval_retrieve 215 | args.top_k = parsed.top_k 216 | args.thres = parsed.thres 217 | args.attn_implementation = parsed.attn_implementation 218 | args.bs = parsed.bs 219 | args.max_new_tokens = parsed.max_new_tokens 220 | args.bf16 = True 221 | args.output_result_path=f"/path/to/result/json" 222 | 223 | return args 224 | 225 | def main(): 226 | os.environ["TOKENIZERS_PARALLELISM"] = "false" 227 | os.environ["HF_HUB_OFFLINE"] = "false" 228 | args = prepare_args() 229 | print(args) 230 | 231 | ckpt_root = f'/path/to/ckpt' 232 | if args.seed is not None: 233 | set_seed(args.seed) 234 | with 
open("./configs/lora_modules.json", 'r') as fr: 235 | lora_modules = json.load(fr) 236 | 237 | 238 | enc_model = args.encoder_model_dir.split('/')[-1] 239 | dec_model = args.decoder_model_dir.split('/')[-1] 240 | 241 | if args.peft_fn in ["lora", "qlora"]: 242 | lora_config_enc = LoraConfig( 243 | r=args.lora_rank_enc, 244 | lora_alpha=args.lora_rank_enc, 245 | target_modules=lora_modules[enc_model], 246 | lora_dropout=0, 247 | bias="lora_only", 248 | inference_mode=False, 249 | task_type=TaskType.FEATURE_EXTRACTION 250 | ) 251 | lora_config_dec = LoraConfig( 252 | r=args.lora_rank_dec, 253 | lora_alpha=args.lora_alpha_dec, 254 | target_modules=lora_modules[dec_model] if dec_model in lora_modules else "all-linear", 255 | lora_dropout=0, 256 | bias="lora_only", 257 | inference_mode=False, 258 | task_type=TaskType.CAUSAL_LM 259 | ) 260 | 261 | model = E2LLMModel(args=args) 262 | encoder_lora_ckpt_path = os.path.join(ckpt_root, 'encoder') 263 | decoder_lora_ckpt_path = os.path.join(ckpt_root, 'decoder') 264 | projector_ckpt_path = os.path.join(ckpt_root, 'projector.pth') 265 | 266 | model.encoder_model = PeftModel.from_pretrained(model.encoder_model, encoder_lora_ckpt_path) 267 | model.decoder_model = PeftModel.from_pretrained(model.decoder_model, decoder_lora_ckpt_path) 268 | if args.proj_arch is not None: 269 | projector_state_dict = torch.load(projector_ckpt_path) 270 | model.projector.load_state_dict(projector_state_dict) 271 | 272 | 273 | device = model.decoder_model.device 274 | 275 | # model.encoder_model.plm_model.to(device) 276 | model.encoder_model.to(device) 277 | if args.proj_arch is not None: 278 | for p in model.projector.parameters(): 279 | p = p.to(device) 280 | model.eval() 281 | res_results = {} 282 | names = args.test_data_list[1:-1].split(',') 283 | names = [name.strip().lower() for name in names] 284 | for name in names: 285 | if name == 'qmsum': 286 | with torch.no_grad(): 287 | test_qmsum_dataset = TestqmsumDataset(chunk_size=args.chunk_size) 288 | test_qmsum_dataloader = DataLoader(test_qmsum_dataset, batch_size=args.bs, shuffle=False, collate_fn=collate_fn) 289 | test_qmsum_iterator = tqdm(test_qmsum_dataloader, mininterval=0) 290 | references = [] 291 | responses = [] 292 | with torch.no_grad(): 293 | for step, batch in enumerate(test_qmsum_iterator): 294 | chunks, prompts, outputs, task_ids = batch 295 | if args.eval_retrieve: 296 | chunks, retrieved_chunks = filter_top_k_chunks(model, chunks, prompts, task_ids, args.top_k, args.thres) 297 | response = model.predict(chunks, prompts, outputs, args.max_new_tokens, retrieved_chunks) 298 | else: 299 | response = model.predict(chunks, prompts, outputs, args.max_new_tokens) 300 | 301 | responses.extend(response) 302 | references.extend(outputs) 303 | rouger = Rouge() 304 | responses = [response if response.strip() else '.' 
for response in responses] 305 | results = compute_rouge(rouger, name, responses, references) 306 | write_json_fancy(results, args.output_result_path) 307 | 308 | 309 | if __name__ == '__main__': 310 | main() -------------------------------------------------------------------------------- /pefts/e2llm_trainer.py: -------------------------------------------------------------------------------- 1 | from transformers import AutoTokenizer, AutoConfig 2 | import gc 3 | import os 4 | import sys 5 | import threading 6 | import argparse 7 | import math 8 | import json 9 | import time 10 | import transformers 11 | import numpy as np 12 | import psutil 13 | import shutil 14 | import torch 15 | from torch import nn 16 | import torch.distributed as dist 17 | from torch.utils.tensorboard import SummaryWriter 18 | from typing import List, Optional, Tuple, Union 19 | from tqdm.auto import tqdm 20 | from tqdm import trange 21 | from accelerate.logging import get_logger 22 | from accelerate import Accelerator 23 | from transformers import set_seed 24 | logger = get_logger(__name__) 25 | 26 | def write_tensorboard(summary_writer: SummaryWriter, log_dict: dict, completed_steps): 27 | for key, value in log_dict.items(): 28 | summary_writer.add_scalar(f"{key}", value, completed_steps) 29 | 30 | def check_existing_ckpts(output_dir): 31 | prefix = "step_" 32 | 33 | if not os.path.exists(output_dir): 34 | return [] 35 | # list all files and dirs 36 | contents = os.listdir(output_dir) 37 | 38 | # find dirs starts with "step_" 39 | matching_folders = [ 40 | folder for folder in contents if os.path.isdir(os.path.join(output_dir, folder)) and folder.startswith(prefix) 41 | ] 42 | 43 | return matching_folders 44 | 45 | def delete_ckpts_over_limits(output_dir, saving_limit, best_step): 46 | """delete ckpts more than saving_limits except for the best_step ckpt""" 47 | existing_ckpts = check_existing_ckpts(output_dir) 48 | logger.info(f"Existing step ckpts folders: {existing_ckpts}, best step ckpt: step_{best_step}", main_process_only=True) 49 | # sorted only step num ascendingly 50 | ckpt_steps = sorted([int(ckpt.replace("step_", "")) for ckpt in existing_ckpts]) 51 | # delete the oldest steps except for the best step at present 52 | if len(ckpt_steps) > saving_limit: 53 | deletable_steps = [ckpt_step for ckpt_step in ckpt_steps if ckpt_step != best_step] 54 | # print(deletable_steps[:len(ckpt_steps) - saving_limit]) 55 | for del_step in deletable_steps[: len(ckpt_steps) - saving_limit]: 56 | shutil.rmtree(os.path.join(output_dir, f"step_{del_step}")) 57 | logger.info(f"Removed ckpt step_{del_step}", main_process_only=True) 58 | 59 | 60 | 61 | class E2LLMTrainer: 62 | def __init__( 63 | self, 64 | accelerator: Accelerator, 65 | model, 66 | train_dataloader, 67 | val_dataloader, 68 | optimizer, 69 | lr_scheduler, 70 | tokenizer, 71 | total_train_dataset_size, 72 | args, 73 | ): 74 | self.accelerator = accelerator 75 | self.model = model 76 | self.train_dataloader = train_dataloader 77 | self.val_dataloader = val_dataloader 78 | self.optimizer = optimizer 79 | self.lr_scheduler = lr_scheduler 80 | self.tokenizer = tokenizer 81 | self.total_train_dataset_size = total_train_dataset_size 82 | self.args = args 83 | self.summary_writer = SummaryWriter(log_dir=args.tb_dir) 84 | 85 | def print(self, msg: str): 86 | """ 87 | accelerator print, default on main process 88 | Args: 89 | msg: 90 | 91 | Returns: 92 | 93 | """ 94 | self.accelerator.print(msg) 95 | 96 | @staticmethod 97 | def format_tensor(tensor, n): 98 | return 
list(map(lambda x: round(x, n), tensor.tolist())) 99 | 100 | 101 | def accelerate_saving_checkpoint(self, output_dir: str, completed_steps: int): 102 | """ 103 | Saving lora adaptor or full checkpoint using accelerator 104 | Args: 105 | output_dir: exact dir for saving ckpt 106 | completed_steps: 107 | 108 | Returns: 109 | 110 | """ 111 | self.accelerator.wait_for_everyone() 112 | 113 | logger.info(f"[CHECKPOINT] Saving checkpoint", main_process_only=True) 114 | if self.accelerator.is_main_process: 115 | 116 | if self.args.proj_arch is not None: 117 | torch.save(self.accelerator.get_state_dict(self.model.module.projector), f'{output_dir}/projector.pth') 118 | 119 | unwrapped_model = self.accelerator.unwrap_model(self.model) 120 | unwrapped_model.encoder_model.save_pretrained( 121 | f'{output_dir}/encoder', 122 | is_main_process=self.accelerator.is_main_process, 123 | save_function=self.accelerator.save, 124 | state_dict=self.accelerator.get_state_dict(self.model.encoder_model),) 125 | unwrapped_model.decoder_model.save_pretrained( 126 | f'{output_dir}/decoder', 127 | is_main_process=self.accelerator.is_main_process, 128 | save_function=self.accelerator.save, 129 | state_dict=self.accelerator.get_state_dict(self.model.decoder_model), 130 | ) 131 | self.accelerator.wait_for_everyone() 132 | # for full-parameter training, save whole ckpt and tokenizer together because it does not need a merge. 133 | if not self.args.peft_fn and self.accelerator.is_main_process: 134 | self.tokenizer.save_pretrained(output_dir) 135 | 136 | sf = os.path.join(output_dir, "model.safetensors") 137 | index_file = os.path.join(output_dir, "model.safetensors.index.json") 138 | if os.path.isfile(sf) and os.path.isfile(index_file): 139 | self.print(f"Remove bug dummy ckpt {sf}") 140 | os.remove(sf) 141 | 142 | if self.accelerator.is_main_process: 143 | latest = { 144 | "latest_ckpt": output_dir, 145 | "lr": self.optimizer.param_groups[0]["lr"], 146 | } 147 | with open(os.path.join(self.args.output_dir, "latest"), "w") as f: 148 | json.dump(latest, f, indent=2) 149 | 150 | logger.info( 151 | f"[CHECKPOINT][complete_steps={completed_steps}], checkpoint {output_dir} saved, latest: {latest}", 152 | main_process_only=True, 153 | ) 154 | self.accelerator.wait_for_everyone() 155 | 156 | def filter_top_k_chunks(self, chunks, queries, task_ids, top_k): 157 | 158 | res_retrieved_chunks = [] 159 | chunks_nums = [len(chunks_per) for chunks_per in chunks] 160 | 161 | chunks_cpy = [chunk for chunks_per in chunks for chunk in chunks_per] 162 | queries_chunks = queries + chunks_cpy 163 | queries_chunks_emb = self.model.encoder_model.encode_without_grad(queries_chunks) 164 | queries_emb = queries_chunks_emb[:len(queries)] # (bs, emb) 165 | chunks_emb = queries_chunks_emb[len(queries):] 166 | splitted_chunks_emb = torch.split(chunks_emb, chunks_nums) 167 | 168 | for i, t in enumerate(splitted_chunks_emb): 169 | if task_ids[i]==0: 170 | res_retrieved_chunks.append([]) 171 | continue 172 | sim = torch.matmul(queries_emb[i].unsqueeze(0), splitted_chunks_emb[i].T) 173 | values, indices = torch.topk(sim, top_k, dim=1) 174 | indices = indices[0].to('cpu') 175 | retrieved_chunks = np.array(chunks[i])[indices] 176 | res_retrieved_chunks.append(retrieved_chunks) 177 | 178 | a_sorted = sorted(indices, reverse=True) 179 | for index in a_sorted: 180 | del chunks[i][index] 181 | del queries_chunks_emb, queries_emb, chunks_emb, splitted_chunks_emb 182 | return chunks, res_retrieved_chunks 183 | 184 | 185 | def accelerate_monitor( 186 | self, 187 | 
current_epoch, 188 | reduce_loss, 189 | global_steps, 190 | ): 191 | 192 | reduce_losses = self.accelerator.gather(reduce_loss).detach().float() 193 | 194 | train_loss = torch.mean(reduce_losses) / (self.args.log_interval * self.args.gradient_accumulation_steps) 195 | 196 | logger.info(f'Epoch: {current_epoch+1}, Global Step: {global_steps}, Loss: {train_loss:.5f}', main_process_only=True) 197 | 198 | logger.info( 199 | f"[TRAIN][complete_steps={global_steps}][train_loss={train_loss:.6f}]" 200 | # f"[gather shape={list(reduce_losses.shape)}]" 201 | f"[lr={self.lr_scheduler.get_lr()[0]:.4e}, {self.optimizer.param_groups[0]['lr']:.4e}]", 202 | main_process_only=True, 203 | ) 204 | train_log_dict = {"Loss/train": train_loss} 205 | 206 | if self.accelerator.is_main_process: 207 | write_tensorboard(self.summary_writer, train_log_dict, global_steps) 208 | 209 | def accelerate_evaluate(self, val_dataloader, global_steps): 210 | self.model.eval() 211 | reduce_loss_eval = torch.tensor(0.0).to(self.model.decoder_model.device) 212 | batch_iterator_eval = tqdm(val_dataloader, 213 | disable=(self.accelerator.process_index!=0), 214 | mininterval=0) 215 | with torch.no_grad(): 216 | for step, batch in enumerate(batch_iterator_eval): 217 | chunks, prompts, outputs, task_ids = batch 218 | if self.args.eval_retrieve: 219 | chunks, retrieved_chunks = self.filter_top_k_chunks(chunks, prompts, task_ids, self.args.top_k) 220 | loss_batch_eval = self.model(chunks, prompts, outputs, task_ids, retrieved_chunks) 221 | else: 222 | loss_batch_eval = self.model(chunks, prompts, outputs, task_ids) 223 | if not torch.isnan(loss_batch_eval): 224 | reduce_loss_eval += loss_batch_eval.detach().float() 225 | self.accelerator.wait_for_everyone() 226 | reduce_losses_eval = self.accelerator.gather(reduce_loss_eval).detach().float() 227 | val_loss = torch.mean(reduce_losses_eval) / len(val_dataloader) 228 | 229 | logger.info(f'***Evaluating***, Global Step: {global_steps}, Loss: {val_loss:.5f}', main_process_only=True) 230 | 231 | logger.info( 232 | f"[VAL][complete_steps={global_steps}][val_loss={val_loss:.6f}]", 233 | main_process_only=True, 234 | ) 235 | val_log_dict = {"Loss/val": val_loss} 236 | 237 | if self.accelerator.is_main_process: 238 | write_tensorboard(self.summary_writer, val_log_dict, global_steps) 239 | 240 | self.model.train() 241 | return val_loss 242 | 243 | def accelerate_train(self): 244 | 245 | if self.args.seed is not None: 246 | set_seed(self.args.seed) 247 | 248 | global_batch_size = ( 249 | self.args.bs 250 | * self.accelerator.num_processes 251 | * self.args.gradient_accumulation_steps 252 | ) 253 | logger.info("************************************** Running training ****************************************") 254 | logger.info(f" Num examples = {self.total_train_dataset_size}") 255 | logger.info(f" Num Epochs = {self.args.num_epochs}") 256 | logger.info(f" Instantaneous batch size per device = {self.args.bs}") 257 | logger.info(f" Total global train batch size (w.
parallel, distributed & accumulation) = {global_batch_size}") 258 | logger.info(f" Gradient Accumulation steps = {self.args.gradient_accumulation_steps}") 259 | logger.info(f" Total optimization(update/completed) steps = {self.args.max_train_steps}") 260 | logger.info(f" Complete/optimize steps per Epoch = {self.args.max_train_steps // self.args.num_epochs}") 261 | logger.info("************************************************************************************************") 262 | progress_bar = tqdm(range(self.args.max_train_steps), disable=not self.accelerator.is_local_main_process) 263 | # set starting_epoch, completed_steps and resume_step of train_dataloader 264 | 265 | trained_steps = 0 266 | best_eval_metric = 0 267 | best_epoch = 0 268 | stop = 0 269 | global_step = 0 270 | num_saved_ckpts = 0 271 | ckpt_achieve = False 272 | reduce_loss = torch.tensor(0.0).to(self.model.decoder_model.device) 273 | # reduce_loss_eval = torch.tensor(0.0).to(self.model.decoder_model.device) 274 | min_reduce_loss_eval = torch.tensor(float('inf')).to(self.model.decoder_model.device) 275 | for current_epoch in trange(int(self.args.num_epochs), desc="Epoch", disable=(self.accelerator.process_index!=0), mininterval=0): 276 | if stop >= self.args.patience: 277 | logger.info(f'Early Stop at {current_epoch+1}-th epoch {global_step}-th step', main_process_only=True) 278 | logger.info(f'Model trained!\nThe best model at {best_epoch+1}-th epoch {best_step}-th step', main_process_only=True) 279 | break 280 | if ckpt_achieve: 281 | logger.info(f'Num of ckpts achieve {self.args.num_ckpt}, Stop training.', main_process_only=True ) 282 | break 283 | self.model.train() 284 | batch_iterator = tqdm(self.train_dataloader, 285 | desc=f"Running Epoch {current_epoch + 1} of {self.args.num_epochs}", 286 | disable=(self.accelerator.process_index!=0), 287 | mininterval=0) 288 | for step, batch in enumerate(batch_iterator): 289 | 290 | chunks, prompts, outputs, task_ids = batch 291 | if self.args.train_retrieve: 292 | 293 | chunks, retrieved_chunks = self.filter_top_k_chunks(chunks, prompts, task_ids, self.args.top_k) 294 | loss_batch = self.model(chunks, prompts, outputs, task_ids, retrieved_chunks) 295 | else: 296 | loss_batch = self.model(chunks, prompts, outputs, task_ids) 297 | 298 | # backward 299 | self.accelerator.backward(loss_batch) 300 | # print(self.lr_scheduler.state_dict(), self.accelerator.process_index) 301 | # update(sync_gradients) 302 | self.optimizer.step() 303 | self.lr_scheduler.step() 304 | self.optimizer.zero_grad() 305 | if not torch.isnan(loss_batch): 306 | reduce_loss += loss_batch.detach().float() 307 | if self.accelerator.sync_gradients: 308 | global_step += 1 309 | 310 | if global_step % self.args.log_interval == 0: 311 | progress_bar.update(self.args.log_interval) 312 | 313 | self.accelerate_monitor( 314 | current_epoch, 315 | reduce_loss, 316 | global_step, 317 | ) 318 | 319 | # reset loss 320 | reduce_loss = torch.tensor(0.0).to(self.model.decoder_model.device) 321 | 322 | if self.args.eval_interval and global_step % self.args.eval_interval == 0: 323 | 324 | reduce_loss_eval = self.accelerate_evaluate(self.val_dataloader, global_step) 325 | 326 | save_flag = False 327 | if stop >= self.args.patience: 328 | break 329 | if reduce_loss_eval <= min_reduce_loss_eval: 330 | min_reduce_loss_eval = reduce_loss_eval 331 | best_epoch = current_epoch 332 | best_step = global_step 333 | stop = 0 334 | 335 | # remove 336 | max_save_num = 20 337 | if self.accelerator.is_main_process: 338 | try: 339 | 
delete_ckpts_over_limits(self.args.output_dir, max_save_num, best_step) 340 | except: 341 | logger.info('No ckpt to remove.', main_process_only=True) 342 | else: 343 | stop += 1 344 | 345 | if stop < self.args.num_ckpt: 346 | save_flag = True 347 | else: 348 | ckpt_achieve = True 349 | 350 | if save_flag: 351 | output_dir = f"step_{global_step}" 352 | if self.args.output_dir is not None: 353 | output_dir = os.path.join(self.args.output_dir, output_dir) 354 | os.makedirs(output_dir, exist_ok=True) 355 | self.accelerate_saving_checkpoint(output_dir, global_step) 356 | 357 | 358 | -------------------------------------------------------------------------------- /model/pro_model.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import math 4 | import re 5 | import numpy as np 6 | import time 7 | import torch 8 | import torch.nn as nn 9 | from model.encoder_model_bert import BaseBertModel 10 | from transformers import AutoTokenizer, AutoModel, AutoModelForCausalLM, AutoConfig, BitsAndBytesConfig 11 | from peft import prepare_model_for_kbit_training 12 | 13 | # from model.modeling_llama import LlamaForCausalLM 14 | # from model.configuration_llama import LlamaConfig 15 | from peft import ( 16 | TaskType, 17 | LoraConfig, 18 | get_peft_model, 19 | ) 20 | 21 | 22 | 23 | def vectorized_split_and_concat(original_tensor, split_indices, concat_tensor, bf16, device): 24 | bs, seq_len, emb = original_tensor.shape 25 | max_len = seq_len + concat_tensor.shape[1] 26 | if bf16: 27 | new_tensor = torch.zeros((bs, max_len, emb)).to(torch.bfloat16) 28 | else: 29 | new_tensor = torch.zeros((bs, max_len, emb)) 30 | 31 | split_indices = split_indices.long() 32 | total_lengths = split_indices + concat_tensor.shape[1] + (seq_len - split_indices) 33 | 34 | for i, split_idx in enumerate(split_indices): 35 | pad_num = seq_len - split_idx 36 | new_tensor[i, :pad_num] = original_tensor[i, split_idx:] 37 | new_tensor[i, pad_num:pad_num+concat_tensor.shape[1]] = concat_tensor[i] 38 | new_tensor[i, pad_num+concat_tensor.shape[1]:total_lengths[i]] = original_tensor[i, :split_idx] 39 | 40 | new_tensor = new_tensor[:, :new_tensor.size(1) - (max_len - total_lengths).max()] 41 | return new_tensor.to(device) 42 | 43 | class Projector(nn.Module): 44 | def __init__(self, arch, input_size, output_size, bf16): 45 | super(Projector, self).__init__() 46 | self.arch_old = re.match(r'^mlp(\d+)x_gelu$', arch) 47 | self.arch_new = re.match(r'^mlpnew(\d+)x_gelu$', arch) 48 | self.input_size = input_size 49 | self.output_size = output_size 50 | if self.arch_old: 51 | mlp_depth = int(self.arch_old.group(1)) 52 | if bf16: 53 | modules = [nn.Linear(self.input_size, self.output_size).bfloat16()] 54 | else: 55 | modules = [nn.Linear(self.input_size, self.output_size)] 56 | for _ in range(1, mlp_depth): 57 | modules.append(nn.GELU()) 58 | if bf16: 59 | modules.append(nn.Linear(self.output_size, self.output_size).bfloat16()) 60 | else: 61 | modules.append(nn.Linear(self.output_size, self.output_size)) 62 | 63 | elif self.arch_new: 64 | mlp_depth = int(self.arch_new.group(1)) 65 | modules = [] 66 | if mlp_depth == 1: 67 | if bf16: 68 | modules = [nn.Linear(self.input_size, self.output_size).bfloat16()] 69 | else: 70 | modules = [nn.Linear(self.input_size, self.output_size)] 71 | elif mlp_depth > 1: 72 | for _ in range(mlp_depth-1): 73 | if bf16: 74 | modules.append(nn.Linear(self.input_size, self.input_size).bfloat16()) 75 | else: 76 | modules.append(nn.Linear(self.input_size, 
self.input_size)) 77 | modules.append(nn.GELU()) 78 | if bf16: 79 | modules.append(nn.Linear(self.input_size, self.output_size).bfloat16()) 80 | else: 81 | modules.append(nn.Linear(self.input_size, self.output_size)) 82 | 83 | self.model = nn.Sequential(*modules) 84 | # self.model = nn.Linear(self.input_size, self.output_size).bfloat16() 85 | def forward(self, x): 86 | return self.model(x) 87 | 88 | 89 | 90 | class E2LLMModel(nn.Module): 91 | 92 | def __init__(self, 93 | args=None, 94 | ): 95 | super(E2LLMModel, self).__init__() 96 | self.args = args 97 | self.peft_fn = args.peft_fn 98 | self.alpha = args.alpha 99 | self.encoder_model_dir = args.encoder_model_dir 100 | self.decoder_model_dir = args.decoder_model_dir 101 | self.chunk_size = args.chunk_size 102 | self.proj_arch = args.proj_arch 103 | self.max_num_chunks = args.max_num_chunks 104 | self.max_seq_len = args.max_seq_len 105 | self.tokens_for_each_chunk = args.tokens_for_each_chunk 106 | self.padding_side = 'right' 107 | self.dtype = torch.bfloat16 if args.bf16 is True else torch.float32 108 | self.attn_implementation = args.attn_implementation 109 | self.ln = args.ln 110 | 111 | 112 | if 'Llama-2-7b-chat-hf' in self.decoder_model_dir.split('/')[-1] or 'Llama-2-13b-chat-hf' in self.decoder_model_dir.split('/')[-1] or 'Llama2-70B-Chat-hf' in self.decoder_model_dir.split('/')[-1]: 113 | self.tokenizer = AutoTokenizer.from_pretrained(self.decoder_model_dir, trust_remote_code=True, truncation_side='right', padding_side=self.padding_side) 114 | self.tokenizer.pad_token = self.tokenizer.eos_token 115 | if self.args.mode == "train": 116 | self.decoder_model = AutoModelForCausalLM.from_pretrained(self.decoder_model_dir, 117 | trust_remote_code=self.args.trust_remote_code, 118 | torch_dtype=self.dtype, 119 | attn_implementation=self.attn_implementation, 120 | quantization_config=( 121 | BitsAndBytesConfig( 122 | load_in_4bit=(self.args.quantization == "4bit"), 123 | bnb_4bit_compute_dtype=torch.bfloat16, 124 | bnb_4bit_use_double_quant=True, 125 | bnb_4bit_quant_type="nf4", 126 | bnb_4bit_quant_storage=torch.bfloat16, 127 | ) 128 | if self.args.quantization == "4bit" 129 | else None), 130 | ) 131 | elif self.args.mode == "eval": 132 | self.decoder_model = AutoModelForCausalLM.from_pretrained(self.decoder_model_dir, 133 | trust_remote_code=self.args.trust_remote_code, 134 | torch_dtype=self.dtype, 135 | attn_implementation=self.attn_implementation, 136 | device_map="auto" 137 | ) 138 | 139 | else: 140 | raise ValueError("args.mode can only be 'train' or 'eval'") 141 | 142 | assert self.decoder_model.config.hidden_size % self.tokens_for_each_chunk == 0 143 | self.pma_output_dim = int(self.decoder_model.config.hidden_size / self.tokens_for_each_chunk) 144 | 145 | self.max_seq_len = min(4096, args.max_seq_len)-self.max_num_chunks-500 if self.args.train_retrieve or self.args.eval_retrieve else min(4096, args.max_seq_len)-self.max_num_chunks-60 # reserve 500 tokens for the prompt template, the query and the retrieved chunks (60 when retrieval is off) 146 | self.start_mark = self.tokenizer.bos_token 147 | self.user_start_mark = '[INST]' 148 | self.user_end_mark = '[/INST]' 149 | self.sep = ' ' 150 | self._n = '\n' 151 | elif "Qwen2.5-7B-Instruct" in self.decoder_model_dir.split('/')[-1] or "Qwen2.5-72B-Instruct" in self.decoder_model_dir.split('/')[-1]: 152 | self.tokenizer = AutoTokenizer.from_pretrained(self.decoder_model_dir, trust_remote_code=True, truncation_side='right', padding_side=self.padding_side) 153 | if self.args.mode == "train": 154 | self.decoder_model =
AutoModelForCausalLM.from_pretrained(self.decoder_model_dir, 155 | trust_remote_code=self.args.trust_remote_code, 156 | torch_dtype=self.dtype, 157 | attn_implementation=self.attn_implementation, 158 | quantization_config=( 159 | BitsAndBytesConfig( 160 | load_in_4bit=(self.args.quantization == "4bit"), 161 | bnb_4bit_compute_dtype=torch.bfloat16, 162 | bnb_4bit_use_double_quant=True, 163 | bnb_4bit_quant_type="nf4", 164 | bnb_4bit_quant_storage=torch.bfloat16, 165 | ) 166 | if self.args.quantization == "4bit" 167 | else None), 168 | ) 169 | elif self.args.mode == "eval": 170 | self.decoder_model = AutoModelForCausalLM.from_pretrained(self.decoder_model_dir, 171 | trust_remote_code=self.args.trust_remote_code, 172 | torch_dtype=self.dtype, 173 | attn_implementation=self.attn_implementation, 174 | device_map="auto" 175 | ) 176 | else: 177 | raise ValueError("args.mode can only be 'train' or 'eval'") 178 | assert self.decoder_model.config.hidden_size % self.tokens_for_each_chunk == 0 179 | self.pma_output_dim = int(self.decoder_model.config.hidden_size / self.tokens_for_each_chunk) 180 | self.max_seq_len = min(32000, args.max_seq_len)-self.max_num_chunks-500 if self.args.train_retrieve or self.args.eval_retrieve else min(32000, args.max_seq_len)-self.max_num_chunks-60 # reserve 500 tokens for the prompt template, the query and the retrieved chunks (60 when retrieval is off) 181 | self.system_mark = 'system' 182 | self.user_mark = 'user' 183 | self.assistant_mark = 'assistant' 184 | self.start_mark = '<|im_start|>' 185 | self.end_mark = '<|im_end|>' 186 | self.sep = '\n' 187 | else: 188 | raise NotImplementedError(f'Unsupported decoder model: {self.decoder_model_dir}') 189 | 190 | if 'bge-m3' in self.encoder_model_dir.split('/')[-1] or 'gte-large-en-v1.5' in self.encoder_model_dir.split('/')[-1] or 'gte-base-en-v1.5' in self.encoder_model_dir.split('/')[-1]: 191 | self.encoder_model = BaseBertModel(model_name_or_path=self.encoder_model_dir, max_seq_length=self.chunk_size, encoder_type=self.args.enc_pool_fn, pma_output_dim=self.pma_output_dim, tokens_for_each_chunk=self.tokens_for_each_chunk, ln=self.ln) 192 | else: 193 | raise NotImplementedError(f'Unsupported encoder model: {self.encoder_model_dir}') 194 | 195 | self.encoder_emb_size = self.encoder_model.plm_model.config.hidden_size 196 | self.decoder_emb_size = self.decoder_model.config.hidden_size 197 | if self.proj_arch is not None: 198 | if self.args.enc_pool_fn != "PMA": 199 | self.projector = Projector(self.proj_arch, self.encoder_emb_size, self.decoder_emb_size, self.args.bf16) 200 | else: 201 | self.projector = Projector(self.proj_arch, self.decoder_emb_size, self.decoder_emb_size, self.args.bf16) 202 | 203 | self.loss_fct = nn.CrossEntropyLoss(ignore_index=-100, reduction='none') 204 | self.system_content = 'You are a helpful assistant.'
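# NOTE (illustrative sketch, not executed): with the Qwen2.5 chat markers configured above and
# the prefix/suffix strings defined just below, the decoder input assembled in decode_clm_template
# roughly follows this layout, where [chunk embeddings] are soft tokens produced by the encoder
# (and projector) and spliced in at the embedding level rather than as text tokens:
#   <|im_start|>system\nYou are a helpful assistant.<|im_end|>\n
#   <|im_start|>user\nGiven the contexts:[chunk embeddings]\nPlease follow the instruction:{query}<|im_end|>\n
#   <|im_start|>assistant\n{answer}{eos_token}
# The exact interleaving (left padding of unused chunk slots, embedding-level concatenation)
# is handled by vectorized_split_and_concat and decode_clm_template.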
205 | self.user_prefix = 'Given the contexts:' 206 | self.user_suffix = 'Please follow the instruction:' 207 | 208 | self.user_prefix_prompt = f"{self.start_mark}{self.system_mark}{self.sep}{self.system_content}{self.end_mark}{self.sep}{self.start_mark}{self.user_mark}{self.sep}{self.user_prefix}" 209 | # self.user_prefix_prompt = f"{self.start_mark}{self.user_start_mark}{self.sep}{self.user_prefix}" 210 | 211 | self.user_prefix_input_ids = self.tokenizer(self.user_prefix_prompt, return_tensors='pt', add_special_tokens=False)['input_ids'] 212 | self.user_prefix_prompt_len = len(self.user_prefix_input_ids[0]) # 18 213 | 214 | 215 | self.user_suffix_input_ids = self.tokenizer(self.user_suffix, return_tensors='pt', add_special_tokens=False)['input_ids'] 216 | self.user_suffix_len = len(self.user_suffix_input_ids[0]) 217 | 218 | self.start_mark_len = len(self.tokenizer(self.start_mark, return_tensors='pt', add_special_tokens=False)['input_ids'][0]) 219 | self.end_mark_len = len(self.tokenizer(self.end_mark, return_tensors='pt', add_special_tokens=False)['input_ids'][0]) 220 | self.assistant_mark_len = len(self.tokenizer(self.assistant_mark, return_tensors='pt', add_special_tokens=False)['input_ids'][0]) 221 | 222 | 223 | self.sep_id = self.tokenizer(self.sep, return_tensors='pt', add_special_tokens=False)['input_ids'][0] # [0] drops the batch dim to get the raw id list ([id] or [id, id]) 224 | self.sep_len = len(self.tokenizer(self.sep, return_tensors='pt', add_special_tokens=False)['input_ids'][0]) 225 | self.sft_end_marker = self.tokenizer.eos_token 226 | self.sft_end_marker_id = self.tokenizer.eos_token_id 227 | self.sft_end_marker_len = 1 228 | 229 | def update_queries(self, queries, retrieved_chunks): 230 | queries = ['Some information that may be useful:' + '\n'.join(chunks) + '\nQuestion:' + query for query, chunks in zip(queries, retrieved_chunks)] 231 | return queries 232 | 233 | 234 | def encode(self, context_batch, mode): 235 | 236 | bs = len(context_batch) 237 | num_of_chunks_bs = [min(len(chunks), self.max_num_chunks) for chunks in context_batch] 238 | attn_mask_chunks = torch.zeros((bs, self.max_num_chunks)) 239 | cols = torch.arange(attn_mask_chunks.size(1)).unsqueeze(0).expand_as(attn_mask_chunks) 240 | mask = cols < torch.LongTensor(num_of_chunks_bs).unsqueeze(1) 241 | attn_mask_chunks[mask] = 1 242 | chunks_flatten = [chunk for chunks in context_batch for chunk in chunks[:self.max_num_chunks]] # cap each sample at max_num_chunks so the offsets below stay aligned 243 | if mode == 'train': 244 | chunks_emb = self.encoder_model.encode_with_grad(chunks_flatten, batch_size=50) # (num_all_chunks, emb) 245 | else: 246 | chunks_emb = self.encoder_model.encode_without_grad(chunks_flatten, batch_size=50) # (num_all_chunks, emb) 247 | 248 | 249 | res_chunks_emb = torch.zeros((bs, self.max_num_chunks, chunks_emb.size(-1)), dtype=chunks_emb.dtype) 250 | start = 0 251 | for b_id in range(bs): 252 | end = start + num_of_chunks_bs[b_id] 253 | res_chunks_emb[b_id][:num_of_chunks_bs[b_id]] = chunks_emb[start:end] 254 | start = end 255 | 256 | 257 | return res_chunks_emb, num_of_chunks_bs, attn_mask_chunks 258 | 259 | 260 | 261 | def decode_clm_template(self, aligned_emb, attn_mask_chunks, queries, answers, task_ids, retrieved_chunks=None): 262 | loss_task_weight = {0:self.alpha, 1:1} 263 | def loss_multiplier(idx): 264 | return loss_task_weight[idx] 265 | 266 | if retrieved_chunks: 267 | queries = self.update_queries(queries, retrieved_chunks) 268 | 269 | bs = len(queries) 270 | 271 | queries_ids = self.tokenizer(queries, padding=False, add_special_tokens=False)['input_ids'] 272 | queries_lens = torch.LongTensor([len(query) for
query in queries_ids]).unsqueeze(1) 273 | 274 | user_prefix_attn_mask = torch.ones((bs, self.user_prefix_prompt_len)) 275 | user_prefix_input_ids_ = self.user_prefix_input_ids.long().to(self.decoder_model.device) 276 | if self.peft_fn in ['lora', 'qlora']: 277 | user_prefix_input_embs = self.decoder_model.get_base_model().model.embed_tokens(user_prefix_input_ids_) 278 | else: 279 | user_prefix_input_embs = self.decoder_model.model.embed_tokens(self.user_prefix_input_ids.long()) 280 | user_prefix_input_embs = user_prefix_input_embs.expand(bs, -1, -1) 281 | 282 | 283 | assistant_prompts = [f"{self.sep}{self.user_suffix}{queries[i]}{self.end_mark}{self.sep}{self.start_mark}{self.assistant_mark}{self.sep}{answers[i]}{self.sft_end_marker}" for i in range(len(queries))] 284 | assistant_inputs = self.tokenizer(assistant_prompts, padding='max_length', max_length=self.max_seq_len, truncation=True, add_special_tokens=False, return_tensors='pt') 285 | assistant_ids = assistant_inputs['input_ids'].to(self.decoder_model.device) # bs, seq 286 | assistant_attn_mask = assistant_inputs['attention_mask'].to(self.decoder_model.device) # bs, seq 287 | 288 | if self.peft_fn in ['lora', 'qlora']: 289 | assistant_embs = self.decoder_model.get_base_model().model.embed_tokens(assistant_ids.long()) 290 | else: 291 | assistant_embs = self.decoder_model.model.embed_tokens(assistant_ids.long()) 292 | 293 | chunk_nums = attn_mask_chunks.sum(-1) 294 | 295 | new_embs = vectorized_split_and_concat(aligned_emb, chunk_nums, user_prefix_input_embs, self.args.bf16, self.decoder_model.device) 296 | 297 | input_embs = torch.cat((new_embs, assistant_embs), dim=1) 298 | pad_nums = (self.max_num_chunks-chunk_nums).unsqueeze(1) 299 | attn_mask_cols = torch.arange(self.user_prefix_prompt_len+self.max_num_chunks).unsqueeze(0).expand(bs, -1) 300 | if self.args.bf16: 301 | new_attn_mask = torch.ones(attn_mask_cols.size(), device=self.decoder_model.device).to(torch.bfloat16) 302 | else: 303 | new_attn_mask = torch.ones(attn_mask_cols.size(), device=self.decoder_model.device) 304 | new_attn_mask = new_attn_mask.masked_fill((attn_mask_cols