├── utils
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── __init__.cpython-38.pyc
│   │   ├── __init__.cpython-310.pyc
│   │   ├── common_utils.cpython-310.pyc
│   │   └── common_utils.cpython-38.pyc
│   └── common_utils.py
├── dataset
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── __init__.cpython-310.pyc
│   │   ├── __init__.cpython-38.pyc
│   │   ├── dataset.cpython-310.pyc
│   │   └── dataset.cpython-38.pyc
│   └── dataset.py
├── model
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── pma.cpython-38.pyc
│   │   ├── pma.cpython-310.pyc
│   │   ├── __init__.cpython-38.pyc
│   │   ├── __init__.cpython-310.pyc
│   │   ├── gpt4o_mini.cpython-310.pyc
│   │   ├── gpt4o_mini.cpython-38.pyc
│   │   ├── encoder_model_bert.cpython-38.pyc
│   │   ├── eval_model_for_70b.cpython-38.pyc
│   │   ├── pro_model_qwen_new.cpython-38.pyc
│   │   ├── encoder_model_bert.cpython-310.pyc
│   │   ├── pro_model_llama_new.cpython-310.pyc
│   │   ├── pro_model_llama_new.cpython-38.pyc
│   │   └── pro_model_qwen_new.cpython-310.pyc
│   ├── pma.py
│   ├── encoder_model_bert.py
│   └── pro_model.py
├── img
│   └── E2LLM.png
├── pefts
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── __init__.cpython-38.pyc
│   │   ├── __init__.cpython-310.pyc
│   │   ├── e2llm_args.cpython-310.pyc
│   │   ├── e2llm_args.cpython-38.pyc
│   │   ├── e2llm_trainer.cpython-38.pyc
│   │   └── e2llm_trainer.cpython-310.pyc
│   ├── e2llm_args.py
│   └── e2llm_trainer.py
├── evaluate
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── f1_qa.cpython-38.pyc
│   │   ├── __init__.cpython-38.pyc
│   │   ├── f1_qa.cpython-310.pyc
│   │   ├── __init__.cpython-310.pyc
│   │   ├── rouge_sum.cpython-310.pyc
│   │   ├── rouge_sum.cpython-38.pyc
│   │   ├── niah_metric.cpython-310.pyc
│   │   └── niah_metric.cpython-38.pyc
│   ├── rouge_cn.py
│   ├── rouge_sum.py
│   ├── em_quality.py
│   └── f1_qa.py
├── .gitignore
├── prompts
│   ├── 0shot_cot.txt
│   ├── 0shot.txt
│   ├── 0shot_no_context.txt
│   ├── 0shot_rag.txt
│   └── 0shot_cot_ans.txt
├── configs
│   ├── lora_modules.json
│   ├── model2maxlen.json
│   ├── eval_config.json
│   └── train_config.json
├── requirements.txt
├── LEGAL.md
├── local
│   └── ds_config_zero2.yaml
├── train_local_machine.sh
├── train_multi_node.sh
├── eval.sh
├── README.md
├── preprocess
│   └── preshuffle_data_and_chunk.py
├── LICENSE.md
├── train_accelerate.py
└── eval.py
/utils/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/dataset/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/model/__init__.py:
--------------------------------------------------------------------------------
1 | from .pma import PMA_v1, PMA_v2
--------------------------------------------------------------------------------
/img/E2LLM.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/codefuse-ai/E2LLM/master/img/E2LLM.png
--------------------------------------------------------------------------------
/pefts/__init__.py:
--------------------------------------------------------------------------------
1 | from .e2llm_args import E2LLMTrainArgs
2 | from .e2llm_trainer import E2LLMTrainer
3 |
--------------------------------------------------------------------------------
/model/__pycache__/pma.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/codefuse-ai/E2LLM/master/model/__pycache__/pma.cpython-38.pyc
--------------------------------------------------------------------------------
/evaluate/__init__.py:
--------------------------------------------------------------------------------
1 | from .rouge_sum import compute_rouge
2 | from .f1_qa import compute_f1
3 | from .niah_metric import compute_score
--------------------------------------------------------------------------------
/model/__pycache__/pma.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/codefuse-ai/E2LLM/master/model/__pycache__/pma.cpython-310.pyc
--------------------------------------------------------------------------------
/evaluate/__pycache__/f1_qa.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/codefuse-ai/E2LLM/master/evaluate/__pycache__/f1_qa.cpython-38.pyc
--------------------------------------------------------------------------------
/model/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/codefuse-ai/E2LLM/master/model/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/pefts/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/codefuse-ai/E2LLM/master/pefts/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/codefuse-ai/E2LLM/master/utils/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/dataset/__pycache__/__init__.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/codefuse-ai/E2LLM/master/dataset/__pycache__/__init__.cpython-310.pyc
--------------------------------------------------------------------------------
/dataset/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/codefuse-ai/E2LLM/master/dataset/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/dataset/__pycache__/dataset.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/codefuse-ai/E2LLM/master/dataset/__pycache__/dataset.cpython-310.pyc
--------------------------------------------------------------------------------
/dataset/__pycache__/dataset.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/codefuse-ai/E2LLM/master/dataset/__pycache__/dataset.cpython-38.pyc
--------------------------------------------------------------------------------
/evaluate/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/codefuse-ai/E2LLM/master/evaluate/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/evaluate/__pycache__/f1_qa.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/codefuse-ai/E2LLM/master/evaluate/__pycache__/f1_qa.cpython-310.pyc
--------------------------------------------------------------------------------
/model/__pycache__/__init__.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/codefuse-ai/E2LLM/master/model/__pycache__/__init__.cpython-310.pyc
--------------------------------------------------------------------------------
/model/__pycache__/gpt4o_mini.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/codefuse-ai/E2LLM/master/model/__pycache__/gpt4o_mini.cpython-310.pyc
--------------------------------------------------------------------------------
/model/__pycache__/gpt4o_mini.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/codefuse-ai/E2LLM/master/model/__pycache__/gpt4o_mini.cpython-38.pyc
--------------------------------------------------------------------------------
/pefts/__pycache__/__init__.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/codefuse-ai/E2LLM/master/pefts/__pycache__/__init__.cpython-310.pyc
--------------------------------------------------------------------------------
/pefts/__pycache__/e2llm_args.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/codefuse-ai/E2LLM/master/pefts/__pycache__/e2llm_args.cpython-310.pyc
--------------------------------------------------------------------------------
/pefts/__pycache__/e2llm_args.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/codefuse-ai/E2LLM/master/pefts/__pycache__/e2llm_args.cpython-38.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/__init__.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/codefuse-ai/E2LLM/master/utils/__pycache__/__init__.cpython-310.pyc
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | do*
2 | show*
3 | draw*
4 | submit*
5 | init*
6 | high*
7 | run*
8 | tools*
9 | test*
10 | pred*
11 | data_statics.py
12 | train
--------------------------------------------------------------------------------
/evaluate/__pycache__/__init__.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/codefuse-ai/E2LLM/master/evaluate/__pycache__/__init__.cpython-310.pyc
--------------------------------------------------------------------------------
/evaluate/__pycache__/rouge_sum.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/codefuse-ai/E2LLM/master/evaluate/__pycache__/rouge_sum.cpython-310.pyc
--------------------------------------------------------------------------------
/evaluate/__pycache__/rouge_sum.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/codefuse-ai/E2LLM/master/evaluate/__pycache__/rouge_sum.cpython-38.pyc
--------------------------------------------------------------------------------
/pefts/__pycache__/e2llm_trainer.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/codefuse-ai/E2LLM/master/pefts/__pycache__/e2llm_trainer.cpython-38.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/common_utils.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/codefuse-ai/E2LLM/master/utils/__pycache__/common_utils.cpython-310.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/common_utils.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/codefuse-ai/E2LLM/master/utils/__pycache__/common_utils.cpython-38.pyc
--------------------------------------------------------------------------------
/evaluate/__pycache__/niah_metric.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/codefuse-ai/E2LLM/master/evaluate/__pycache__/niah_metric.cpython-310.pyc
--------------------------------------------------------------------------------
/evaluate/__pycache__/niah_metric.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/codefuse-ai/E2LLM/master/evaluate/__pycache__/niah_metric.cpython-38.pyc
--------------------------------------------------------------------------------
/pefts/__pycache__/e2llm_trainer.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/codefuse-ai/E2LLM/master/pefts/__pycache__/e2llm_trainer.cpython-310.pyc
--------------------------------------------------------------------------------
/model/__pycache__/encoder_model_bert.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/codefuse-ai/E2LLM/master/model/__pycache__/encoder_model_bert.cpython-38.pyc
--------------------------------------------------------------------------------
/model/__pycache__/eval_model_for_70b.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/codefuse-ai/E2LLM/master/model/__pycache__/eval_model_for_70b.cpython-38.pyc
--------------------------------------------------------------------------------
/model/__pycache__/pro_model_qwen_new.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/codefuse-ai/E2LLM/master/model/__pycache__/pro_model_qwen_new.cpython-38.pyc
--------------------------------------------------------------------------------
/model/__pycache__/encoder_model_bert.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/codefuse-ai/E2LLM/master/model/__pycache__/encoder_model_bert.cpython-310.pyc
--------------------------------------------------------------------------------
/model/__pycache__/pro_model_llama_new.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/codefuse-ai/E2LLM/master/model/__pycache__/pro_model_llama_new.cpython-310.pyc
--------------------------------------------------------------------------------
/model/__pycache__/pro_model_llama_new.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/codefuse-ai/E2LLM/master/model/__pycache__/pro_model_llama_new.cpython-38.pyc
--------------------------------------------------------------------------------
/model/__pycache__/pro_model_qwen_new.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/codefuse-ai/E2LLM/master/model/__pycache__/pro_model_qwen_new.cpython-310.pyc
--------------------------------------------------------------------------------
/prompts/0shot_cot.txt:
--------------------------------------------------------------------------------
1 | What is the correct answer to this question: $Q$
2 | Choices:
3 | (A) $C_A$
4 | (B) $C_B$
5 | (C) $C_C$
6 | (D) $C_D$
7 |
8 | Let’s think step by step:
--------------------------------------------------------------------------------
/prompts/0shot.txt:
--------------------------------------------------------------------------------
1 | What is the correct answer to this question: $Q$
2 | Choices:
3 | (A) $C_A$
4 | (B) $C_B$
5 | (C) $C_C$
6 | (D) $C_D$
7 |
8 | Format your response as follows: "The correct answer is (insert answer here)".
--------------------------------------------------------------------------------
/prompts/0shot_no_context.txt:
--------------------------------------------------------------------------------
1 | What is the correct answer to this question: $Q$
2 | Choices:
3 | (A) $C_A$
4 | (B) $C_B$
5 | (C) $C_C$
6 | (D) $C_D$
7 |
8 | What is the single, most likely answer choice? Format your response as follows: "The correct answer is (insert answer here)".
--------------------------------------------------------------------------------
/configs/lora_modules.json:
--------------------------------------------------------------------------------
1 | {
2 | "gte-large-en-v1.5": ["qkv_proj", "o_proj"],
3 | "Llama-2-7b-chat-hf": ["q_proj", "k_proj", "v_proj", "o_proj"],
4 | "Llama-2-13b-chat-hf": ["q_proj", "k_proj", "v_proj", "o_proj"],
5 | "Llama2-70B-Chat-hf": ["q_proj", "k_proj", "v_proj", "o_proj"],
6 | "Qwen2.5-7B-Instruct": "all-linear"
7 | }
--------------------------------------------------------------------------------
/prompts/0shot_rag.txt:
--------------------------------------------------------------------------------
1 | Please read the following retrieved text chunks and answer the question below.
2 |
3 |
4 | $DOC$
5 |
6 |
7 | What is the correct answer to this question: $Q$
8 | Choices:
9 | (A) $C_A$
10 | (B) $C_B$
11 | (C) $C_C$
12 | (D) $C_D$
13 |
14 | Format your response as follows: "The correct answer is (insert answer here)".
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | accelerate==0.23.0
2 | bitsandbytes==0.40.2
3 | datasets==2.14.1
4 | deepspeed==0.9.3
5 | flash-attn==2.3.6
6 | numpy==1.23.5
7 | peft==0.7.0
8 | scikit-learn==1.3.0
9 | tensorboard==2.11.0
10 | tiktoken==0.4.0
11 | tokenizers==0.15.0
12 | torch==2.0.1
13 | transformers==4.36.0
14 | transformers-stream-generator==0.0.4
15 | xformers==0.0.21
--------------------------------------------------------------------------------
/LEGAL.md:
--------------------------------------------------------------------------------
1 | Legal Disclaimer
2 | Within this source code, the comments in Chinese shall be the original, governing version. Any comments in other languages are for reference only. In the event of any conflict between the Chinese language version comments and other language version comments, the Chinese language version shall prevail.
3 | 法律免责声明
4 | 关于代码注释部分,中文注释为官方版本,其它语言注释仅做参考。中文注释可能与其它语言注释存在不一致,当中文注释与其它语言注释存在不一致时,请以中文注释为准。
--------------------------------------------------------------------------------
/prompts/0shot_cot_ans.txt:
--------------------------------------------------------------------------------
1 | Please read the following text and answer the questions below.
2 |
3 | The text is too long and omitted here.
4 |
5 | What is the correct answer to this question: $Q$
6 | Choices:
7 | (A) $C_A$
8 | (B) $C_B$
9 | (C) $C_C$
10 | (D) $C_D$
11 |
12 | Let’s think step by step: $COT$
13 |
14 | Based on the above, what is the single, most likely answer choice? Format your response as follows: "The correct answer is (insert answer here)".
--------------------------------------------------------------------------------
/local/ds_config_zero2.yaml:
--------------------------------------------------------------------------------
1 | compute_environment: LOCAL_MACHINE
2 | deepspeed_config:
3 | gradient_accumulation_steps: 1
4 | gradient_clipping: 1.0
5 | offload_optimizer_device: cpu
6 | offload_param_device: none
7 | zero3_init_flag: false
8 | zero3_save_16bit_model: true
9 | zero_stage: 2
10 | # steps_per_print: 1
11 | distributed_type: DEEPSPEED
12 | downcast_bf16: 'no'
13 | dynamo_backend: 'NO'
14 | fsdp_config: {}
15 | machine_rank: 0
16 | main_training_function: main
17 | megatron_lm_config: {}
18 | mixed_precision: 'bf16'
19 | num_machines: 1
20 | num_processes: 1
21 | rdzv_backend: static
22 | same_network: true
23 | use_cpu: false
--------------------------------------------------------------------------------
/configs/model2maxlen.json:
--------------------------------------------------------------------------------
1 | {
2 | "Llama-2-7B-Instruct": 4000,
3 | "Qwen2.5-7B-Instruct-YaRN": 64000,
4 | "Qwen2.5-7B-Instruct": 64000,
5 | "YaRN": 58000,
6 | "LongLoRA": 58000,
7 | "RAG": 4000,
8 | "GLM-4-9B-Chat": 120000,
9 | "Llama-3.1-8B-Instruct": 120000,
10 | "Llama-3.1-70B-Instruct": 120000,
11 | "Llama-3.3-70B-Instruct": 120000,
12 | "Llama-3.1-Nemotron-70B-Instruct": 120000,
13 | "Qwen2.5-72B-Instruct": 120000,
14 | "Mistral-Large-Instruct-2407": 120000,
15 | "Mistral-Large-Instruct-2411": 120000,
16 | "c4ai-command-r-plus-08-2024": 120000,
17 | "glm-4-plus": 120000,
18 | "gpt-4o-2024-08-06": 120000,
19 | "gpt-4o-mini-2024-07-18": 120000,
20 | "o1-mini-2024-09-12": 120000,
21 | "o1-preview-2024-09-12": 120000,
22 | "claude-3.5-sonnet-20241022": 200000
23 | }
--------------------------------------------------------------------------------
/train_local_machine.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | TRAIN_CONFIG="./configs/train_config.json"
3 |
4 |
5 |
6 | DISTRIBUTED_TYPE="DeepSpeed2"
7 | # DISTRIBUTED_TYPE="DeepSpeed3"
8 | # DISTRIBUTED_TYPE="FSDP"
9 |
10 | if [ "$DISTRIBUTED_TYPE" = "DeepSpeed2" ]; then
11 | echo "DISTRIBUTED_TYPE is DeepSpeed ZeRO2"
12 | accelerate launch --config_file "./local/ds_config_zero2.yaml" train_accelerate.py --train_config "$TRAIN_CONFIG"
13 | elif [ "$DISTRIBUTED_TYPE" = "DeepSpeed3" ]; then
14 | echo "DISTRIBUTED_TYPE is DeepSpeed ZeRO3"
15 | accelerate launch --config_file "./local/ds_config_zero3.yaml" train_accelerate.py --train_config "$TRAIN_CONFIG"
16 | elif [ "$DISTRIBUTED_TYPE" = "FSDP" ]; then
17 | echo "DISTRIBUTED_TYPE is FSDP"
18 | accelerate launch --config_file "./local/fsdp_config.yaml" train_accelerate.py --train_config "$TRAIN_CONFIG"
19 | fi
--------------------------------------------------------------------------------
/evaluate/rouge_cn.py:
--------------------------------------------------------------------------------
1 | from rouge_chinese import Rouge
2 | import jieba
3 |
4 |
5 | def compute_rouge(rouge, hypothesis, reference):
6 |
7 |
8 | hypothesis = ' '.join(jieba.cut(hypothesis))
9 |
10 | reference = ' '.join(jieba.cut(reference))
11 |
12 | scores = rouge.get_scores(hypothesis, reference)
13 | r1_f = scores[0]['rouge-1']['f']
14 | r2_f = scores[0]['rouge-2']['f']
15 | rl_f = scores[0]['rouge-l']['f']
16 |
17 | return r1_f, r2_f, rl_f
18 |
19 |
20 | if __name__ == '__main__':
21 | avg_r1 = 0
22 | avg_r2 = 0
23 | avg_rl = 0
24 | rouge = Rouge()
25 | dataset = [('', '')]
26 | for h, r in dataset:
27 | count+=1
28 | r1, r2, rl = compute_rouge(rouge, h, r)
29 | avg_r1 += r1
30 | avg_r2 += r2
31 | avg_rl += rl
32 | avg_r1 = avg_r1 / count
33 | avg_r2 = avg_r2 / count
34 | avg_rl = avg_rl / count
35 |
--------------------------------------------------------------------------------
/configs/eval_config.json:
--------------------------------------------------------------------------------
1 | {
2 | "encoder_model_dir": "/path/to/encoder/model",
3 | "decoder_model_dir": "/path/to/decoder/model",
4 | "chunk_size": 512,
5 | "max_num_chunks": 500,
6 | "max_seq_len": 1024,
7 | "max_sliding_windows": null,
8 | "enc_pool_fn": "CLS",
9 | "tokens_for_each_chunk": 1,
10 | "ln": true,
11 | "proj_arch": "mlpnew2x_gelu",
12 | "alpha": 1e-9,
13 | "peft_fn": "qlora",
14 | "lora_rank_enc": 32,
15 | "lora_alpha_enc": 32,
16 | "lora_rank_dec": 8,
17 | "lora_alpha_dec": 8,
18 | "num_epochs": 20,
19 | "lr": 1e-4,
20 | "bs": 6,
21 | "log_interval": 5,
22 | "eval_interval": 5,
23 | "patience": 20,
24 | "num_ckpt": 20,
25 | "mark": "vtest3",
26 | "attn_implementation": "flash_attention_2",
27 | "num_warmup_steps": 100,
28 | "recover_training": false,
29 | "recover_global_steps": -1,
30 | "bf16": false,
31 | "distributed_type": "deepspeed",
32 | "mode": "eval"
33 | }
--------------------------------------------------------------------------------
/train_multi_node.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | DISTRIBUTED_TYPE="DeepSpeed2"
3 | export OMP_NUM_THREADS=16
4 |
5 |
6 | if [ "$DISTRIBUTED_TYPE" = "DeepSpeed2" ]; then
7 | echo "DISTRIBUTED_TYPE is DeepSpeed ZeRO2"
8 | accelerate launch \
9 | --num_machines $N_NODE \
10 | --num_processes $(($N_NODE*$N_GPU_PER_NODE)) \
11 | --use_deepspeed \
12 | --deepspeed_multinode_launcher 'standard' \
13 | --zero_stage 2 \
14 | --offload_optimizer_device 'cpu' \
15 | --offload_param_device 'none' \
16 | --gradient_accumulation_steps 1 \
17 | --gradient_clipping 1.0 \
18 | --zero3_init_flag false \
19 | --zero3_save_16bit_model true \
20 | --main_training_function 'main' \
21 | --mixed_precision 'bf16' \
22 | --dynamo_backend 'no' \
23 | --same_network \
24 | --machine_rank $RANK \
25 | --main_process_ip $MASTER_ADDR \
26 | --main_process_port $MASTER_PORT \
27 | --rdzv_backend 'static' \
28 | train_accelerate.py --train_config "./configs/train_config.json"
29 |
--------------------------------------------------------------------------------
/configs/train_config.json:
--------------------------------------------------------------------------------
1 | {
2 | "train_data_list": "[qmsum]",
3 | "val_data_list": "[qmsum]",
4 | "encoder_model_dir": "/path/to/encoder/model",
5 | "decoder_model_dir": "/path/to/decoder/model",
6 | "chunk_size": 512,
7 | "max_num_chunks": 700,
8 | "max_seq_len": 4096,
9 | "max_sliding_windows": null,
10 | "enc_pool_fn": "PMA",
11 | "tokens_for_each_chunk": 4,
12 | "ln": true,
13 | "proj_arch": null,
14 | "alpha": 1e-9,
15 | "peft_fn": "lora",
16 | "lora_rank_enc": 32,
17 | "lora_alpha_enc": 32,
18 | "lora_rank_dec": 16,
19 | "lora_alpha_dec": 16,
20 | "train_retrieve": false,
21 | "eval_retrieve": false,
22 | "top_k": 3,
23 | "num_epochs": 20,
24 | "lr": 1e-4,
25 | "bs": 4,
26 | "log_interval": 10,
27 | "eval_interval": 150,
28 | "patience": 100,
29 | "num_ckpt": 100,
30 | "mark": "mark_of_your_model_version",
31 | "attn_implementation": "flash_attention_2",
32 | "num_warmup_steps": 100,
33 | "recover_training": false,
34 | "recover_global_steps": -1,
35 | "bf16": true,
36 | "distributed_type": "deepspeed",
37 | "mode": "train"
38 | }
--------------------------------------------------------------------------------
/evaluate/rouge_sum.py:
--------------------------------------------------------------------------------
1 | from rouge import Rouge
2 | import math
3 |
4 | def compute_rouge_(rouger, hypothesis, reference):
5 |
6 | print('*'*20)
7 | print(hypothesis)
8 | print('='*20)
9 | print(reference)
10 | print('*'*20)
11 | scores = rouger.get_scores(hypothesis, reference)
12 | r1_f = scores[0]['rouge-1']['f']
13 | r2_f = scores[0]['rouge-2']['f']
14 | rl_f = scores[0]['rouge-l']['f']
15 |
16 | return r1_f, r2_f, rl_f
17 |
18 | def compute_rouge(rouger, dataname, hypothesis_list, reference_list):
19 | assert len(hypothesis_list) == len(reference_list)
20 | avg_r1 = 0
21 | avg_r2 = 0
22 | avg_rl = 0
23 | rouger = Rouge()
24 | count = 0
25 | for i in range(len(hypothesis_list)):
26 | h = hypothesis_list[i]
27 | r = reference_list[i]
28 | count+=1
29 | try:
30 | r1, r2, rl = compute_rouge_(rouger, h, r)
31 | except:
32 | print('ignoring as 0.0 score')
33 | r1, r2, rl = 0.0, 0.0, 0.0
34 | avg_r1 += r1
35 | avg_r2 += r2
36 | avg_rl += rl
37 | avg_r1 = avg_r1 / count
38 | avg_r2 = avg_r2 / count
39 | avg_rl = avg_rl / count
40 | g_mean = math.pow(avg_r1*avg_r2*avg_rl, 1/3)
41 | return {f'{dataname}-r1':avg_r1, f'{dataname}-r2':avg_r2, f'{dataname}-rl':avg_rl, f'{dataname}-g_mean':g_mean}
42 |
43 |
--------------------------------------------------------------------------------
/evaluate/em_quality.py:
--------------------------------------------------------------------------------
1 | import re
2 | import string
3 |
4 |
5 | def normalize_answer(s):
6 | """Lower text and remove punctuation, articles and extra whitespace."""
7 |
8 | def remove_articles(text):
9 | return re.sub(r"\b(a|an|the)\b", " ", text)
10 |
11 | def white_space_fix(text):
12 | return " ".join(text.split())
13 |
14 | def remove_punc(text):
15 | exclude = set(string.punctuation)
16 | return "".join(ch for ch in text if ch not in exclude)
17 |
18 | def lower(text):
19 | return text.lower()
20 |
21 | return white_space_fix(remove_articles(remove_punc(lower(s))))
22 |
23 |
24 | def exact_match_score(prediction, ground_truth):
25 | return normalize_answer(prediction) == normalize_answer(ground_truth)
26 |
27 |
28 | # def metric_max_over_ground_truths(metric_fn, prediction, ground_truths):
29 | # scores_for_ground_truths = []
30 | # for ground_truth in ground_truths:
31 | # score = metric_fn(prediction, ground_truth)
32 | # scores_for_ground_truths.append(score)
33 | # return max(scores_for_ground_truths)
34 |
35 |
36 | def compute_exact_match(predictions, references):
37 | exact_match = 0
38 | for prediction, ground_truths in zip(predictions, references):
39 | em = metric_max_over_ground_truths(exact_match_score, prediction, ground_truths)
40 | exact_match += em
41 |
42 |
43 | return exact_match / len(predictions)
44 |
45 |
--------------------------------------------------------------------------------
/eval.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | GLOBAL_STEP=step_of_step
3 | MARK="mark_of_your_model_version"
4 | TEST_DATA_LIST="[qmsum]"
5 | ENCODER_MODEL_DIR="/path/to/encoder/model"
6 | DECODER_MODEL_DIR="/path/to/decoder/model"
7 | CHUNK_SIZE=chunk_size
8 | MAX_NUM_CHUNKS=500
9 | ENC_POOL_FN="PMA"
10 | TOKENS_FOR_EACH_CHUNK=4
11 | LN="True"
12 | PROJ_ARCH="None"
13 | BS=bs
14 | LORA_RANK_ENC=lora_rank_enc
15 | LORA_ALPHA_ENC=lora_alpha_enc
16 | LORA_RANK_DEC=lora_rank_dec
17 | LORA_ALPHA_DEC=lora_alpha_dec
18 | EVAL_RETRIEVE="False"
19 | TOP_K=3
20 | THRES=0.5
21 | MAX_NEW_TOKENS=128
22 | ATTN_IMPLEMENTATION="eager"
23 |
24 | EVAL_CONFIG="./configs/eval_config.json"
25 |
26 | python eval.py --eval_config "$EVAL_CONFIG" \
27 | --test_data_list $TEST_DATA_LIST \
28 | --global_step $GLOBAL_STEP \
29 | --mark $MARK \
30 | --encoder_model_dir $ENCODER_MODEL_DIR \
31 | --decoder_model_dir $DECODER_MODEL_DIR \
32 | --chunk_size $CHUNK_SIZE \
33 | --max_num_chunks $MAX_NUM_CHUNKS \
34 | --enc_pool_fn $ENC_POOL_FN \
35 | --tokens_for_each_chunk $TOKENS_FOR_EACH_CHUNK \
36 | --ln $LN \
37 | --proj_arch $PROJ_ARCH \
38 | --lora_rank_enc $LORA_RANK_ENC \
39 | --lora_alpha_enc $LORA_ALPHA_ENC \
40 | --lora_rank_dec $LORA_RANK_DEC \
41 | --lora_alpha_dec $LORA_ALPHA_DEC \
42 | --eval_retrieve $EVAL_RETRIEVE \
43 | --top_k $TOP_K \
44 | --thres $THRES \
45 | --attn_implementation $ATTN_IMPLEMENTATION \
46 | --bs $BS \
47 | --max_new_tokens $MAX_NEW_TOKENS
--------------------------------------------------------------------------------
/evaluate/f1_qa.py:
--------------------------------------------------------------------------------
1 | # Copied from https://github.com/huggingface/datasets/blob/d3c7b9481d427ce41256edaf6773c47570f06f3b/metrics/squad/evaluate.py
2 |
3 | import re
4 | import string
5 | from collections import Counter
6 |
7 |
8 | def normalize_answer(s):
9 | """Lower text and remove punctuation, articles and extra whitespace."""
10 |
11 | def remove_articles(text):
12 | return re.sub(r"\b(a|an|the)\b", " ", text)
13 |
14 | def white_space_fix(text):
15 | return " ".join(text.split())
16 |
17 | def remove_punc(text):
18 | exclude = set(string.punctuation)
19 | return "".join(ch for ch in text if ch not in exclude)
20 |
21 | def lower(text):
22 | return text.lower()
23 |
24 | return white_space_fix(remove_articles(remove_punc(lower(s))))
25 |
26 |
27 | def f1_score(prediction, ground_truth):
28 | print('*'*20)
29 | print(prediction)
30 | print('='*20)
31 | print(ground_truth)
32 | print('*'*20)
33 | prediction_tokens = normalize_answer(prediction).split()
34 | ground_truth_tokens = normalize_answer(ground_truth).split()
35 | common = Counter(prediction_tokens) & Counter(ground_truth_tokens)
36 | num_same = sum(common.values())
37 | if num_same == 0:
38 | return 0,0,0
39 | precision = 1.0 * num_same / len(prediction_tokens)
40 | recall = 1.0 * num_same / len(ground_truth_tokens)
41 | f1 = (2 * precision * recall) / (precision + recall)
42 | return precision, recall, f1
43 |
44 |
45 | def compute_f1(dataname, predictions, references):
46 | precision_all = 0
47 | recall_all = 0
48 | f1_all = 0
49 | for prediction, ground_truth in zip(predictions, references):
50 | precision, recall, f1 = f1_score(prediction, ground_truth)
51 | precision_all += precision
52 | recall_all += recall
53 | f1_all += f1
54 | res = {f"{dataname}-precision": precision_all/len(predictions), f"{dataname}-recall": recall_all/len(predictions), f"{dataname}-f1": f1_all/len(predictions)}
55 | return res
56 |
57 |
--------------------------------------------------------------------------------
/pefts/e2llm_args.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass, asdict
2 | from typing import List, Union
3 |
4 |
5 | @dataclass
6 | class E2LLMTrainArgs:
7 |
8 | # train datasets
9 | train_data_list: Union[str, List[str]] = "[qmsum]"
10 |
11 | # val datasets
12 | val_data_list: Union[str, List[str]] = "[qmsum]"
13 |
14 | # encoder model dir
15 | encoder_model_dir: str = "/path/to/encoder/model"
16 |
17 | # decoder model dir
18 | decoder_model_dir: str = "/path/to/decoder/model"
19 |
20 | # chunk size
21 | chunk_size: int = 512
22 |
23 | # max num chunks
24 | max_num_chunks: int = 500
25 |
26 | # max sequence length during training
27 | max_seq_len: int = 4096
28 |
29 | # max sliding windows
30 | max_sliding_windows: Union[None, int] = None
31 |
32 | # encoder pooling method
33 | enc_pool_fn: str = "CLS"
34 |
35 | # only for pma
36 | tokens_for_each_chunk: int = 4
37 |
38 | # only for pma
39 | ln: bool = True
40 |
41 | # adapter architecture, if enc_pool_fn=='PMA', set this to None (null for config.json)
42 | proj_arch: Union[None, str] = "mlpnew2x_gelu"
43 |
44 | # trade-off weight for reconstruct task
45 | alpha: float = 1e-7
46 |
47 | # 'lora' or 'qlora'
48 | peft_fn: Union[None, str] = "lora"
49 |
50 | # lora params
51 | lora_rank_enc: int = 32
52 | lora_alpha_enc: int = 32
53 | lora_rank_dec: int = 8
54 | lora_alpha_dec: int = 8
55 |
56 | # whether retrieve
57 | train_retrieve: bool = True
58 | eval_retrieve: bool = True
59 | top_k: int = 3
60 |
61 | # datasets params
62 | shuffle_sample_ratio: float = 0.0
63 | noisy_sample_ratio: float = 0.0
64 | noise_rate: float = 0.0
65 |
66 | # training epochs
67 | num_epochs: int = 5
68 |
69 | # learning rate
70 | lr: float = 1e-4
71 |
72 | # batch size
73 | bs: int = 2
74 |
75 | # logging per n steps
76 | log_interval: int = 10
77 |
78 | # evaluating per n steps
79 | eval_interval: int = 200
80 |
81 | # random seed
82 | seed: int = 2023
83 |
84 | # patience for early stop
85 | patience: int = 5
86 |
87 | # number of ckpts
88 | num_ckpt: int = 5
89 |
90 | # mark
91 | mark: str = "test"
92 |
93 | # flash attention
94 | attn_implementation: str = "flash_attention_2"
95 |
96 | # bf16
97 | bf16: bool = True
98 |
99 | # weight decay
100 | weight_decay: float = 1e-2
101 |
102 | # gradient clipping
103 | gradient_clipping: float = 1.0
104 |
105 | # warmup steps
106 | num_warmup_steps: int = 100
107 |
108 | # max training steps
109 | max_train_steps: Union[None, int] = None
110 |
113 |
114 | # gradient accumulation steps
115 | gradient_accumulation_steps: int = 1
116 |
117 |
118 | # recover training
119 | recover_training: bool = False
120 |
121 | # recover global steps
122 | recover_global_steps: int = -1
123 |
124 | # verbose
125 | verbose: bool = True
126 |
127 | # distributed type
128 | distributed_type: str = "deepspeed"
129 |
130 | # mode, "train" or "eval"
131 | mode: str = "train"
132 |
133 | # legacy, leave them
134 | use_xformers: bool = True
135 | trust_remote_code: bool = True
136 | model_parallel_size: int = 1
137 | use_slow_tokenizer: bool = False
138 | world_size: int = 8
139 | init_timeout_seconds: Union[None, int] = 3600
140 |
141 | def dict(self):
142 | return {k: str(v) for k, v in asdict(self).items()}
143 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # E2LLM: Encoder Elongated Large Language Models for Long-Context Understanding and Reasoning
2 |
3 | This is the PyTorch implementation of E2LLM from the EMNLP'25 paper: E2LLM: Encoder Elongated Large Language Models for Long-Context Understanding and Reasoning.
4 |
5 | ## Overview
6 | 
7 |
8 |
9 | Abstract
10 |
11 |
12 | - We propose E2LLM, a novel long-context modeling framework built on pre-trained text encoders and decoder-only LLMs to effectively address the "impossible triangle" challenge (see the schematic sketch below).
13 | 
14 | - We introduce two training objectives, soft prompt reconstruction and long-context instruction fine-tuning, which enable the LLM to understand the soft prompts and reason over them to produce accurate outputs.
15 | 
16 | - Comprehensive experiments on diverse benchmarks demonstrate the efficiency and practicality of E2LLM, showing its superiority over 8 SOTA baselines and its competitive performance on LongBench v2.
17 |
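As a rough mental model, the sketch below shows the core idea; it is only a schematic with illustrative names (`E2LLMSketch`, `adapter`, and the tensor shapes are assumptions for exposition), not the repository's `model/pro_model.py` implementation: the long context is split into chunks, each chunk is compressed by the text encoder, the chunk embeddings are projected into the decoder's embedding space, and the result is prepended to the prompt as a soft prompt.

```python
import torch
import torch.nn as nn


class E2LLMSketch(nn.Module):
    """Schematic only: encode chunks -> project into decoder space -> prepend as a soft prompt."""

    def __init__(self, encoder, decoder, enc_dim, dec_dim):
        super().__init__()
        self.encoder = encoder  # chunk encoder (a BERT-style model, cf. model/encoder_model_bert.py)
        self.decoder = decoder  # decoder-only LLM
        # adapter that maps encoder embeddings into the decoder's token-embedding space
        self.adapter = nn.Sequential(nn.Linear(enc_dim, dec_dim), nn.GELU(), nn.Linear(dec_dim, dec_dim))

    def forward(self, chunk_input_ids, chunk_attention_mask, prompt_embeds):
        # 1) compress each context chunk into one (or a few, with PMA) embedding vectors
        chunk_vecs = self.encoder(chunk_input_ids, chunk_attention_mask)  # (num_chunks, enc_dim)
        # 2) project the chunk embeddings into the decoder's embedding space (the "soft prompt")
        soft_prompt = self.adapter(chunk_vecs).unsqueeze(0)               # (1, num_chunks, dec_dim)
        # 3) prepend the soft prompt to the embedded instruction/question and decode as usual
        inputs_embeds = torch.cat([soft_prompt, prompt_embeds], dim=1)    # (1, num_chunks + prompt_len, dec_dim)
        return self.decoder(inputs_embeds=inputs_embeds)
```

The same forward pass serves both objectives: `preprocess/preshuffle_data_and_chunk.py` mixes reconstruction samples ("Restate the aforementioned contexts:") with instruction samples, and the reconstruction loss is weighted by the `alpha` parameter in `configs/train_config.json`.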
18 |
19 |
20 |
21 |
22 | ## Requirements
23 |
24 | * Ubuntu OS
25 | * python==3.10
26 | * torch==2.0.1
27 | * cuda==11.7
28 | * accelerate==0.23.0
29 | * transformers==4.36.0
30 | * deepspeed==0.9.3
31 | * flash-attn==2.3.6
32 | * peft==0.7.0
33 | * scikit-learn==1.3.0
34 |
35 | Dependencies can be installed by:
36 |
37 | pip install -r requirements.txt
38 |
39 |
40 | The overall directory structure is as follows:
41 |
42 | ${CODE_ROOT}
43 | |-- configs
44 | |   |-- eval_config.json
45 | |   |-- lora_modules.json
46 | |   |-- model2maxlen.json
47 | |   |-- train_config.json
48 | |-- dataset
49 | |   |-- __init__.py
50 | |   |-- dataset.py
51 | |-- evaluate
52 | |   |-- __init__.py
53 | |   |-- em_quality.py
54 | |   |-- f1_qa.py
55 | |   |-- niah_metric.py
56 | |   |-- rouge_sum.py
57 | |-- local
58 | |   |-- ds_config_zero2.yaml
59 | |-- model
60 | |   |-- __init__.py
61 | |   |-- encoder_model_bert.py
62 | |   |-- pma.py
63 | |   |-- pro_model.py
64 | |-- pefts
65 | |   |-- __init__.py
66 | |   |-- e2llm_args.py
67 | |   |-- e2llm_trainer.py
68 | |-- preprocess
69 | |   |-- preshuffle_data_and_chunk.py
70 | |-- prompts
71 | |-- utils
72 | |   |-- __init__.py
73 | |   |-- common_utils.py
74 | |-- eval.py
75 | |-- eval.sh
76 | |-- train_accelerate.py
77 | |-- train_local_machine.sh
78 | |-- train_multi_node.sh
79 |
80 |
81 | ## Data preparation
82 |
83 | The five datasets (QMSum, GovReport, Quality, NarrativeQA and TriviaQA) used in this paper can be downloaded from the following links:
84 |
85 | * [QMSum](https://github.com/Yale-LILY/QMSum)
86 | * [GovReport](https://huggingface.co/datasets/ccdv/govreport-summarization)
87 | * [Quality](https://huggingface.co/datasets/emozilla/quality)
88 | * [NarrativeQA](https://github.com/google-deepmind/narrativeqa)
89 | * [TriviaQA](https://huggingface.co/datasets/mandarjoshi/trivia_qa)
90 |
91 | Before training, first convert each dataset into a JSONL file whose lines have the format `{'context': 'xxx', 'prompt': 'xxx', 'answer': 'xxx'}`. Then run
92 |
93 | ```bash
94 | python preprocess/preshuffle_data_and_chunk.py
95 | ```
96 |
97 | after setting the `chunk_size` parameter and the input/output paths at the top of the script.
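
For illustration, a minimal conversion sketch is shown below; the raw file name and its field names (`document`, `question`, `answers`) are hypothetical and should be adapted to your dataset, while the output keys `context`, `prompt`, and `answer` are the ones expected by the preprocessing script:

```python
import jsonlines

# Hypothetical raw file: one JSON object per line with dataset-specific keys.
with jsonlines.open('raw_qmsum.jsonl', 'r') as reader, jsonlines.open('qmsum_converted.jsonl', 'w') as writer:
    for record in reader:
        writer.write({
            'context': record['document'],   # the long input text
            'prompt': record['question'],    # the instruction / question
            'answer': record['answers'][0],  # the reference output
        })
```

`preprocess/preshuffle_data_and_chunk.py` then splits each `context` into chunks of at most `chunk_size` characters and writes two shuffled training samples per record: a context-reconstruction sample (`task_id` 0) and a question-answering sample (`task_id` 1).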
98 |
99 | ## Train
100 |
101 | To train, first set the desired parameters in `configs/train_config.json`, then run the script that matches your environment:
102 |
103 | - If you are training on a local machine:
104 |
105 | ```
106 | sh train_local_machine.sh
107 | ```
108 |
109 | - If you are training on a cluster / multi-node setup:
110 |
111 | ```
112 | sh train_multi_node.sh
113 | ```
114 |
115 | ## Evaluate
116 |
117 | For inference, run
118 |
119 | ```
120 | sh eval.sh
121 | ```
122 |
123 | and adjust its parameters so that they match the ones used during training.
124 |
125 | ## Citation
126 |
127 | If you find our repository helpful, please cite us as follows:
128 |
129 | @misc{liao2025e2llmencoderelongatedlarge,
130 | title={E2LLM: Encoder Elongated Large Language Models for Long-Context Understanding and Reasoning},
131 | author={Zihan Liao and Jun Wang and Hang Yu and Lingxiao Wei and Jianguo Li and Jun Wang and Wei Zhang},
132 | year={2025},
133 | eprint={2409.06679},
134 | archivePrefix={arXiv},
135 | primaryClass={cs.CL},
136 | url={https://arxiv.org/abs/2409.06679},
137 | }
--------------------------------------------------------------------------------
/model/pma.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | import math
6 |
7 |
8 | # PMA block, post-LayerNorm (Post-LN) variant
9 | class MAB_POST(nn.Module):
10 | def __init__(self, dim_Q, dim_K, dim_V, num_heads, ln=False):
11 | super(MAB_POST, self).__init__()
12 | self.dim_V = dim_V
13 | self.num_heads = num_heads
14 | self.fc_q = nn.Linear(dim_Q, dim_V)
15 | self.fc_k = nn.Linear(dim_K, dim_V)
16 | self.fc_v = nn.Linear(dim_K, dim_V)
17 |
18 | if ln:
19 | self.ln0 = nn.LayerNorm(dim_V)
20 | self.ln1 = nn.LayerNorm(dim_V)
21 | self.fc_o = nn.Linear(dim_V, dim_V)
22 | nn.init.xavier_uniform_(self.fc_q.weight)
23 | nn.init.xavier_uniform_(self.fc_k.weight)
24 | nn.init.xavier_uniform_(self.fc_v.weight)
25 | nn.init.xavier_uniform_(self.fc_o.weight)
26 |
27 |
28 |
29 | # Q(bs, 1, emb), pad_mask (bs, seq) Post-LN
30 | def forward(self, Q, K, pad_mask=None):
31 |
32 | Q_ = self.fc_q(Q)
33 | K_, V_ = self.fc_k(K), self.fc_v(K)
34 |
35 | dim_split = self.dim_V // self.num_heads
36 | Q_ = torch.cat(Q_.split(dim_split, 2), 0) # (bs* num_head, 1, emb)
37 | K_ = torch.cat(K_.split(dim_split, 2), 0) # (bs* num_head, seq, emb)
38 | V_ = torch.cat(V_.split(dim_split, 2), 0)
39 |
40 | pad_mask = pad_mask.unsqueeze(1).repeat(self.num_heads, Q.size(1), 1) # (bs*num_head, 1, seq)
41 | score = Q_.bmm(K_.transpose(1,2))/math.sqrt(self.dim_V)
42 | score = score.masked_fill(pad_mask == 0, -1e12)
43 | A = torch.softmax(score, 2) # (bs*num_head, 1, seq)
44 | A = A * pad_mask
45 | O = torch.cat(A.bmm(V_).split(Q.size(0), 0), 2) # (bs, 1, emb)
46 | O = Q + O
47 | # O = torch.cat((Q_ + A.bmm(V_)).split(Q.size(0), 0), 2)
48 | O = O if getattr(self, 'ln0', None) is None else self.ln0(O)
49 | O = O + F.relu(self.fc_o(O))
50 | O = O if getattr(self, 'ln1', None) is None else self.ln1(O)
51 | return O
52 |
53 |
54 |
55 | class PMA_v1(nn.Module):
56 | def __init__(self, dim, compressed_dim, num_heads, num_seeds, ln=False):
57 | super(PMA_v1, self).__init__()
58 | self.S = nn.Parameter(torch.Tensor(1, num_seeds, compressed_dim))
59 | nn.init.xavier_uniform_(self.S)
60 | self.mab = MAB_POST(compressed_dim, dim, compressed_dim, num_heads, ln=ln)
61 |
62 | def forward(self, X, pad_mask):
63 | if self.S.dtype != torch.bfloat16:
64 | X = X.float()
65 | return self.mab(self.S.expand(X.size(0), -1, -1).to(X.device), X, pad_mask)
66 |
67 |
68 | # PMA block, post-LayerNorm (Post-LN) variant
69 | class MAB_POST_v2(nn.Module):
70 | def __init__(self, dim_Q, dim_K, dim_V, num_heads, ln=False):
71 | super(MAB_POST_v2, self).__init__()
72 | self.dim_V = dim_V
73 | self.num_heads = num_heads
74 | self.fc_q = nn.Linear(dim_Q, dim_V)
75 | self.fc_k = nn.Linear(dim_K, dim_V)
76 | self.fc_v = nn.Linear(dim_K, dim_V)
77 |
78 | if ln:
79 | self.ln0 = nn.LayerNorm(dim_V)
80 | self.ln1 = nn.LayerNorm(dim_V)
81 | self.fc_o = nn.Linear(dim_V, dim_V)
82 | nn.init.xavier_uniform_(self.fc_q.weight)
83 | nn.init.xavier_uniform_(self.fc_k.weight)
84 | nn.init.xavier_uniform_(self.fc_v.weight)
85 | nn.init.xavier_uniform_(self.fc_o.weight)
86 |
87 |
88 |
89 | # Q(B, num_seed, D), pad_mask (bs, seq) Post-LN
90 | def forward(self, Q, K, pad_mask=None):
91 | Q_tmp = self.fc_q(Q) # B, num_seed, C
92 | K_, V_ = self.fc_k(K), self.fc_v(K) # B, L, C
93 |
94 | dim_split = self.dim_V // self.num_heads
95 | Q_ = torch.cat(Q_tmp.split(dim_split, 2), 0) # (B* num_head, num_seed, C)
96 | K_ = torch.cat(K_.split(dim_split, 2), 0) # (B* num_head, L, C)
97 | V_ = torch.cat(V_.split(dim_split, 2), 0) # (B* num_head,L, C)
98 |
99 | pad_mask = pad_mask.unsqueeze(1).repeat(self.num_heads, Q.size(1), 1) # (B*num_head, num_seed, L)
100 | score = Q_.bmm(K_.transpose(1,2))/math.sqrt(self.dim_V) # (B*num_head, num_seed, L)
101 | score = score.masked_fill(pad_mask == 0, -1e12) # B,num_seed,L
102 | A = torch.softmax(score, 2) # (B*num_head, num_seed, L)
103 | A = A * pad_mask
104 | O = torch.cat(A.bmm(V_).split(Q.size(0), 0), 2) # (B, num_seed, D)
105 | O = Q_tmp + O
106 | # O = torch.cat((Q_ + A.bmm(V_)).split(Q.size(0), 0), 2)
107 | O = O if getattr(self, 'ln0', None) is None else self.ln0(O)
108 | O = O + F.relu(self.fc_o(O))
109 | O = O if getattr(self, 'ln1', None) is None else self.ln1(O)
110 | return O
111 |
112 | class PMA_v2(nn.Module):
113 | def __init__(self, dim, compressed_dim, num_heads, num_seeds, ln=False):
114 | super(PMA_v2, self).__init__()
115 | self.S = nn.Parameter(torch.Tensor(1, num_seeds, dim))
116 | nn.init.xavier_uniform_(self.S)
117 | self.mab = MAB_POST_v2(dim, dim, compressed_dim, num_heads, ln=ln)
118 |
119 | # X: (bs, seq, emb), pad_mask: (bs, seq)
120 | def forward(self, X, pad_mask):
121 | if self.S.dtype != torch.bfloat16:
122 | X = X.float()
123 | return self.mab(self.S.expand(X.size(0), -1, -1).to(X.device), X, pad_mask)
--------------------------------------------------------------------------------
/preprocess/preshuffle_data_and_chunk.py:
--------------------------------------------------------------------------------
1 | import re
2 | import os
3 | import pathlib
4 | import sys
5 | import jsonlines
6 | import math
7 | from tqdm import tqdm
8 | import random
9 | from langchain_text_splitters import CharacterTextSplitter, RecursiveCharacterTextSplitter
10 |
11 | def makedirs(path):
12 | p = pathlib.Path(path)
13 | p.parent.mkdir(parents=True, exist_ok=True)
14 | return path
15 |
16 | def split_text(text, max_length=512):
17 |     # create an empty list to store the text chunks
18 | chunks = []
19 |
20 | current_position = 0
21 |
22 | while current_position < len(text):
23 | if current_position + max_length >= len(text):
24 | chunks.append(text[current_position:])
25 | break
26 | chunk_candidate = text[current_position:current_position + max_length]
27 | last_punctuation_pos = max([chunk_candidate.rfind('.'), chunk_candidate.rfind('\n')])
28 | if last_punctuation_pos == -1 or last_punctuation_pos==0:
29 | chunks.append(chunk_candidate)
30 | current_position += max_length
31 | else:
32 | chunks.append(text[current_position:current_position + last_punctuation_pos + 1])
33 | current_position += last_punctuation_pos + 1
34 | return chunks
35 |
36 | def split_text_overlap(text, max_length=512, overlap_ratio=0.3):
37 | chunks = []
38 |
39 | current_position = 0
40 | overlap_length = int(max_length * overlap_ratio)
41 |
42 | while current_position < len(text):
43 | if current_position + max_length >= len(text):
44 | chunks.append(text[current_position:])
45 | break
46 |
47 | chunk_candidate = text[current_position:current_position + max_length]
48 |
49 | last_punctuation_pos = max([chunk_candidate.rfind('.'), chunk_candidate.rfind('\n')])
50 | if last_punctuation_pos == -1 or last_punctuation_pos == 0:
51 | chunks.append(chunk_candidate)
52 | current_position += max_length
53 | else:
54 | chunks.append(text[current_position:current_position + last_punctuation_pos + 1])
55 | current_position += last_punctuation_pos + 1
56 |
57 | if current_position >= len(text):
58 | break
59 |
60 | current_position = max(current_position - overlap_length, 0)
61 |
62 | return chunks
63 |
64 |
65 |
66 | def preshuffle_and_chunkfile(column_names, chunk_size, data_path, write_data_path, max_samples):
67 | if os.path.exists(write_data_path):
68 | raise ValueError(f'Existed file: {write_data_path}')
69 | else:
70 | makedirs(write_data_path)
71 | all_data = []
72 | with jsonlines.open(data_path, 'r') as fr:
73 | for i, line in tqdm(enumerate(fr)):
74 | if max_samples is not None:
75 | if i >= max_samples:
76 | break
77 | input_ = line[column_names['input']]
78 | prompt_ = line[column_names['prompt']]
79 | output_ = line[column_names['output']]
80 |
81 | chunked_input = split_text(input_, chunk_size)
82 | all_data.append({'input':chunked_input, 'prompt':'Restate the aforementioned contexts:', 'output':input_, 'task_id':0})
83 | all_data.append({'input':chunked_input, 'prompt':f'Answer the question:{prompt_}', 'output':output_, 'task_id':1})
84 | random.shuffle(all_data)
85 | with jsonlines.open(write_data_path, 'w') as fw:
86 | for data in all_data:
87 | fw.write(data)
88 |
89 | def preshuffle_and_chunkfile_summary(column_names, chunk_size, data_path, write_data_path, max_samples):
90 | if os.path.exists(write_data_path):
91 | raise ValueError(f'Existed file: {write_data_path}')
92 | else:
93 | makedirs(write_data_path)
94 | all_data = []
95 | with jsonlines.open(data_path, 'r') as fr:
96 | for i, line in tqdm(enumerate(fr)):
97 | if max_samples is not None:
98 | if i >= max_samples:
99 | break
100 | input_ = line[column_names['input']]
101 | prompt_ = 'What does the above context describe?'
102 | output_ = line[column_names['output']]
103 |
104 |
105 | chunked_input = split_text(input_, chunk_size)
106 | all_data.append({'input':chunked_input, 'prompt':'Restate the aforementioned contexts:', 'output':input_, 'task_id':0})
107 | all_data.append({'input':chunked_input, 'prompt':f'Answer the question:{prompt_}', 'output':output_, 'task_id':1})
108 | random.shuffle(all_data)
109 | with jsonlines.open(write_data_path, 'w') as fw:
110 | for data in all_data:
111 | fw.write(data)
112 |
113 |
114 |
115 | if __name__ == '__main__':
116 | chunk_size = 512
117 |
118 |
119 | # qmsum
120 | max_samples_eval = None
121 | column_names = {'input':'context', 'prompt':'prompt', 'output':'answer'}
122 | write_data_path = f'/path/to/chunked/dataset/with/chunk_size/eval.jsonl'
123 | data_path = '/data/to/raw_eval.jsonl'
124 | preshuffle_and_chunkfile(column_names, chunk_size, data_path, write_data_path, max_samples_eval)
125 | max_samples_train = None
126 | data_path = '/data/to/raw_train.jsonl'
127 | write_data_path = f'/path/to/chunked/dataset/with/chunk_size/train.jsonl'
128 | preshuffle_and_chunkfile(column_names, chunk_size, data_path, write_data_path, max_samples_train)
129 |
--------------------------------------------------------------------------------
/utils/common_utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import random
4 | import pathlib
5 | import numpy as np
6 | import jsonlines
7 | from scipy.stats import pearsonr, spearmanr
8 | import torch
9 | from loguru import logger
10 | import shutil
11 | from torch.utils.tensorboard import SummaryWriter
12 | import pickle
13 | import linecache
14 | import tracemalloc
15 | import torch.distributed as dist
16 | import oss2
17 | import base64
18 |
19 | def set_seed(seed):
20 | random.seed(seed)
21 | np.random.seed(seed)
22 | torch.manual_seed(seed)
23 | if torch.cuda.is_available():
24 | torch.cuda.manual_seed_all(seed)
25 |
28 |
29 | def remove_earlier_ckpt(path, start_name, current_step_num, max_save_num):
30 |
31 | filenames=os.listdir(path)
32 | ckpts = [dir_name for dir_name in filenames if os.path.isdir(os.path.join(path, dir_name)) and dir_name.startswith(start_name) and int(dir_name.split('-')[1])<=int(current_step_num)]
33 | current_ckpt_num = len(ckpts)
34 | print(f'existing ckpts to remove:{ckpts}')
35 | for dir_name in ckpts:
36 | if dir_name.startswith(start_name) and int(dir_name.split('-')[1]) <= int(current_step_num) and current_ckpt_num > (max_save_num-1):
37 | try:
38 | shutil.rmtree(os.path.join(path, dir_name))
39 | except Exception as e:
40 | print('Error:', e)
41 |
42 | def makedirs(path):
43 | p = pathlib.Path(path)
44 | p.parent.mkdir(parents=True, exist_ok=True)
45 | return path
46 |
47 | def write_jsonl(obj, file_path):
48 | with jsonlines.open(file_path, 'a') as writer:
49 | writer.write(obj)
50 |
51 |
52 | def load_pickle(path):
53 | with open(path, "rb") as f:
54 | return pickle.load(f)
55 |
56 | def write_pickle(obj, path:str):
57 | if not os.path.exists(path):
58 | makedirs(path)
59 | with open(path, "wb") as f:
60 | return pickle.dump(obj, f)
61 |
62 | def write_tensorboard(summary_writer, log_dict, completed_steps):
63 | for key, value in log_dict.items():
64 | summary_writer.add_scalar(f'{key}', value, completed_steps)
65 |
66 | def cos_sim(a, b):
67 | """
68 | Computes the cosine similarity cos_sim(a[i], b[j]) for all i and j.
69 | :return: Matrix with res[i][j] = cos_sim(a[i], b[j])
70 | """
71 | if not isinstance(a, torch.Tensor):
72 | a = torch.tensor(a)
73 |
74 | if not isinstance(b, torch.Tensor):
75 | b = torch.tensor(b)
76 |
77 | if len(a.shape) == 1:
78 | a = a.unsqueeze(0)
79 |
80 | if len(b.shape) == 1:
81 | b = b.unsqueeze(0)
82 |
83 | a_norm = torch.nn.functional.normalize(a, p=2, dim=1)
84 | b_norm = torch.nn.functional.normalize(b, p=2, dim=1)
85 | return torch.mm(a_norm, b_norm.transpose(0, 1))
86 |
87 | def gather_across_devices(v, global_rank, world_size):
88 | if v is None:
89 | return None
90 | v = v.contiguous()
91 |
92 | all_tensors = [torch.empty_like(v) for _ in range(world_size)]
93 | dist.all_gather(all_tensors, v)
94 |
95 | all_tensors[global_rank] = v
96 | all_tensors = torch.cat(all_tensors, dim=0)
97 |
98 | return all_tensors
99 |
100 |
101 |
102 |
105 |
106 | def save_model(model_engine, ckpt_dir, client_state):
107 | model_engine.save_checkpoint(ckpt_dir, client_state=client_state, exclude_frozen_parameters=True)
108 |
109 | def punctuation_format(text: str):
110 | # Replace non-breaking space with space
111 | # text = text.strip() + '\n'
112 | text = text.replace("\u202f", " ").replace("\xa0", " ")
113 | # change chinese punctuation to english ones
114 | # text = text.translate(table)
115 | if not text.endswith("\n"):
116 | text += "\n"
117 | return text
118 |
119 | def randomly_insert_list1_into_list2(list1, list2):
120 | len_combined = len(list1) + len(list2)
121 | insertion_points = sorted(random.sample(range(len_combined), len(list2)))
122 |
123 | combined_list = []
124 | insert_index = 0
125 | list2_index = 0
126 | for i in range(len_combined):
127 | if insert_index < len(insertion_points) and i == insertion_points[insert_index]:
128 | combined_list.append(list2[list2_index])
129 | list2_index += 1
130 | insert_index += 1
131 | else:
132 | combined_list.append(list1[i - list2_index])
133 |
134 | return combined_list
135 |
136 | def get_parameter_number(model):
137 | total_num = sum(p.numel() for p in model.parameters())
138 | trainable_num = sum(p.numel() for p in model.parameters() if p.requires_grad)
139 | return total_num, trainable_num
140 |
141 | def find_overlap(s1, s2):
142 | max_overlap = 0
143 | overlap_str = s1 + s2
144 | for i in range(1, min(len(s1), len(s2)) + 1):
145 | if s1[-i:] == s2[:i]:
146 | if i > max_overlap:
147 | max_overlap = i
148 | overlap_str = s1 + s2[i:]
149 |
150 | return overlap_str
151 |
152 | def merge_strings(strings):
153 |
154 | if not strings:
155 | return ""
156 |
157 | merged = strings[0]
158 | for s in strings[1:]:
159 | merged = find_overlap(merged, s)
160 |
161 | return merged
162 |
163 | def compile_helper():
164 | """Compile helper function at runtime. Make sure this
165 | is invoked on a single process."""
166 | import os
167 | import subprocess
168 |
169 | path = os.path.abspath(os.path.dirname(__file__))
170 | ret = subprocess.run(["make", "-C", path])
171 | if ret.returncode != 0:
172 | print("Making C++ dataset helpers module failed, exiting.")
173 | import sys
174 |
175 | sys.exit(1)
176 | else:
177 | print("Making C++ dataset helpers module successfully.")
178 |
179 |
180 | def print_rank_0(*message):
181 | """If distributed is initialized print only on rank 0."""
182 | if torch.distributed.is_initialized():
183 | if torch.distributed.get_rank() == 0:
184 | print(*message, flush=True)
185 | else:
186 | print(*message, flush=True)
--------------------------------------------------------------------------------
/model/encoder_model_bert.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import os
3 | import json
4 | import argparse
5 | import torch
6 | import warnings
7 | import deepspeed
8 | from enum import Enum
9 | from typing import Union, List
10 | import torch.nn as nn
11 | from tqdm import tqdm, trange
12 | import numpy as np
13 | warnings.filterwarnings('ignore')
14 | from transformers import AutoTokenizer, AutoModel, AutoModelForCausalLM, AutoConfig
15 | from peft import LoraConfig, get_peft_model, TaskType
16 | from model import PMA_v1, PMA_v2
17 |
18 |
19 |
20 | class EncoderType(Enum):
21 | FIRST_LAST_AVG = 0
22 | LAST_AVG = 1
23 | CLS = 2
24 | POOLER = 3
25 | MEAN = 4
26 | PMA = 5
27 |
28 | def __str__(self):
29 | return self.name
30 |
31 | @staticmethod
32 | def from_string(s):
33 | try:
34 | return EncoderType[s]
35 | except KeyError:
36 | raise ValueError(f"Unknown encoder_type: {s}")
37 |
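# Illustrative usage (editorial sketch): config strings are mapped to pooling strategies
# via from_string, e.g.
#     EncoderType.from_string("PMA")   # -> <EncoderType.PMA: 5>
#     str(EncoderType.CLS)             # -> "CLS"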
38 | class BaseBertModel(nn.Module):
39 | def __init__(
40 | self,
41 | model_name_or_path = None,
42 | max_seq_length = 512,
43 | encoder_type = 'CLS',
44 | pma_output_dim = -1, # PMA parameter
45 | tokens_for_each_chunk = -1, # PMA parameter
46 | ln = True, # PMA parameter
47 | alias = None
48 | ):
49 | super().__init__()
50 | self.model_name_or_path = model_name_or_path
51 | encoder_type = encoder_type.upper()
52 | encoder_type = EncoderType.from_string(encoder_type) if isinstance(encoder_type, str) else encoder_type
53 |
54 | if encoder_type not in list(EncoderType):
55 | raise ValueError(f'encoder_type must be in {list(EncoderType)}')
56 | self.encoder_type = encoder_type
57 | self.max_seq_length = max_seq_length
58 | self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, truncation_side='right', padding_side='right')
59 | self.plm_model = AutoModel.from_pretrained(model_name_or_path, trust_remote_code=True)
60 | self.results = {}
61 | if encoder_type == EncoderType.PMA:
62 | pma_input_dim = self.plm_model.config.hidden_size
63 |
64 | self.pma = PMA_v1(dim=pma_input_dim, compressed_dim=pma_output_dim, num_heads=self.plm_model.config.num_attention_heads, num_seeds=tokens_for_each_chunk, ln=ln)
65 |
66 |
67 |
68 | def get_sentence_embeddings(self, input_ids, attention_mask, token_type_ids=None):
69 | """
70 | Returns the model output by encoder_type as embeddings.
71 |
72 | Utility function applied to the outputs of self.plm_model.
73 | """
74 | model_output = self.plm_model(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids, output_hidden_states=True)
75 |
76 | if self.encoder_type == EncoderType.FIRST_LAST_AVG:
77 |
78 | first = model_output.hidden_states[1]
79 | last = model_output.hidden_states[-1]
80 | seq_length = first.size(1) # Sequence length
81 |
82 | first_avg = torch.avg_pool1d(first.transpose(1, 2), kernel_size=seq_length).squeeze(-1) # [batch, hid_size]
83 | last_avg = torch.avg_pool1d(last.transpose(1, 2), kernel_size=seq_length).squeeze(-1) # [batch, hid_size]
84 | final_encoding = torch.avg_pool1d(
85 | torch.cat([first_avg.unsqueeze(1), last_avg.unsqueeze(1)], dim=1).transpose(1, 2),
86 | kernel_size=2).squeeze(-1)
87 | return final_encoding
88 |
89 | if self.encoder_type == EncoderType.LAST_AVG:
90 | sequence_output = model_output.last_hidden_state # [batch_size, max_len, hidden_size]
91 | seq_length = sequence_output.size(1)
92 | final_encoding = torch.avg_pool1d(sequence_output.transpose(1, 2), kernel_size=seq_length).squeeze(-1)
93 | return final_encoding
94 |
95 | if self.encoder_type == EncoderType.CLS:
96 | sequence_output = model_output.last_hidden_state
97 | return sequence_output[:, 0] # [batch, hid_size]
98 |
99 | if self.encoder_type == EncoderType.POOLER:
100 | return model_output.pooler_output # [batch, hid_size]
101 |
102 | if self.encoder_type == EncoderType.MEAN:
103 | """
104 | Mean Pooling - Take attention mask into account for correct averaging
105 | """
106 | token_embeddings = model_output.last_hidden_state # Contains all token embeddings
107 | input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
108 | final_encoding = torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(
109 | input_mask_expanded.sum(1), min=1e-9)
110 | return final_encoding # [batch, hid_size]
111 | if self.encoder_type == EncoderType.PMA:
112 | token_embeddings = model_output.last_hidden_state # Contains all token embeddings
113 | res_emb = self.pma(token_embeddings, attention_mask)
114 | res_emb = res_emb.reshape(res_emb.size(0), -1)
115 | return res_emb
116 | def batch_to_device(self, batch, device):
117 | """
118 | send a pytorch batch to a device (CPU/GPU)
119 | """
120 | for key in batch:
121 | if isinstance(batch[key], torch.Tensor):
122 | batch[key] = batch[key].to(device)
123 | return batch
124 |
125 |
126 | def encode_with_grad(
127 | self,
128 | sentences: Union[str, List[str]],
129 | batch_size: int = 32,
130 | show_progress_bar: bool = False,
131 | convert_to_numpy: bool = False,
132 | convert_to_tensor: bool = True,
133 | device: str = None,
134 | normalize_embeddings: bool = True,
135 | max_seq_length: int = None,
136 | ):
137 | if device is None:
138 | device = self.plm_model.device
139 | # self.plm_model.to(device)
140 |
141 | if max_seq_length is None:
142 | max_seq_length = self.max_seq_length
143 | if convert_to_tensor:
144 | convert_to_numpy = False
145 | input_is_string = False
146 | if isinstance(sentences, str) or not hasattr(sentences, "__len__"):
147 | sentences = [sentences]
148 | input_is_string = True
149 |
150 | all_embeddings = []
151 | length_sorted_idx = np.argsort([-len(s) for s in sentences])
152 | sentences_sorted = [sentences[idx] for idx in length_sorted_idx]
153 | for start_index in trange(0, len(sentences), batch_size, desc="Batches", disable=not show_progress_bar):
154 | sentences_batch = sentences_sorted[start_index: start_index + batch_size]
155 | # Compute sentences embeddings
156 | features = self.tokenizer(
157 | sentences_batch, max_length=max_seq_length,
158 | padding=True, truncation=True, return_tensors='pt'
159 | )
160 | features = self.batch_to_device(features, device)
161 | embeddings = self.get_sentence_embeddings(**features)
162 | # no detach() here: encode_with_grad must keep the graph so the encoder can be trained
163 | if normalize_embeddings:
164 | embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)
165 |
166 | if convert_to_numpy:
167 | embeddings = embeddings.cpu()
168 | all_embeddings.extend(embeddings)
169 | all_embeddings = [all_embeddings[idx] for idx in np.argsort(length_sorted_idx)]
170 | if convert_to_tensor:
171 | all_embeddings = torch.stack(all_embeddings)
172 | elif convert_to_numpy:
173 | all_embeddings = np.asarray([emb.numpy() for emb in all_embeddings])
174 |
175 | if input_is_string:
176 | all_embeddings = all_embeddings[0]
177 |
178 | return all_embeddings
179 |
180 | def encode_without_grad(
181 | self,
182 | sentences: Union[str, List[str]],
183 | batch_size: int = 32,
184 | show_progress_bar: bool = False,
185 | convert_to_numpy: bool = False,
186 | convert_to_tensor: bool = True,
187 | device: str = None,
188 | normalize_embeddings: bool = True,
189 | max_seq_length: int = None,
190 | ):
191 | self.plm_model.eval()
192 | if device is None:
193 | device = self.plm_model.device
194 | # self.plm_model.to(device)
195 |
196 | if max_seq_length is None:
197 | max_seq_length = self.max_seq_length
198 | if convert_to_tensor:
199 | convert_to_numpy = False
200 | input_is_string = False
201 | if isinstance(sentences, str) or not hasattr(sentences, "__len__"):
202 | sentences = [sentences]
203 | input_is_string = True
204 |
205 | all_embeddings = []
206 | length_sorted_idx = np.argsort([-len(s) for s in sentences])
207 | sentences_sorted = [sentences[idx] for idx in length_sorted_idx]
208 | for start_index in trange(0, len(sentences), batch_size, desc="Batches", disable=not show_progress_bar):
209 | sentences_batch = sentences_sorted[start_index: start_index + batch_size]
210 | # Compute sentences embeddings
211 | with torch.no_grad():
212 | features = self.tokenizer(
213 | sentences_batch, max_length=max_seq_length,
214 | padding=True, truncation=True, return_tensors='pt'
215 | )
216 | features = self.batch_to_device(features, device)
217 | embeddings = self.get_sentence_embeddings(**features)
218 | embeddings = embeddings.detach()
219 | if normalize_embeddings:
220 | embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)
221 |
222 | if convert_to_numpy:
223 | embeddings = embeddings.cpu()
224 | all_embeddings.extend(embeddings)
225 | all_embeddings = [all_embeddings[idx] for idx in np.argsort(length_sorted_idx)]
226 | if convert_to_tensor:
227 | all_embeddings = torch.stack(all_embeddings)
228 | elif convert_to_numpy:
229 | all_embeddings = np.asarray([emb.numpy() for emb in all_embeddings])
230 |
231 | if input_is_string:
232 | all_embeddings = all_embeddings[0]
233 |
234 | return all_embeddings
235 |
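# Illustrative usage (editorial sketch; the model path below is a placeholder):
#     enc = BaseBertModel(model_name_or_path="/path/to/bert-like-encoder",
#                         encoder_type="CLS", max_seq_length=512)
#     embs = enc.encode_without_grad(["first chunk ...", "second chunk ..."])
#     print(embs.shape)  # -> torch.Size([2, hidden_size]) with CLS pooling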
--------------------------------------------------------------------------------
/dataset/dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import math
4 | import numpy as np
5 | import torch
6 | from torch.utils.data import Dataset
7 | from datasets import load_dataset
8 | from transformers import PreTrainedTokenizer, AutoTokenizer
9 | import csv
10 | from loguru import logger
11 | import random
12 | import jsonlines
13 | from utils.common_utils import load_pickle, set_seed, merge_strings, randomly_insert_list1_into_list2
14 |
15 | def load_qmsum_data(data_path, res_data, chunks_repo, global_rank, world_size, mode):
16 | data = []
17 | if data_path is None:
18 | raise ValueError(f'data path is {data_path}')
19 | with open(data_path, 'r', encoding='utf8') as f:
20 | for i,line in enumerate(jsonlines.Reader(f)):
21 | chunks_repo.extend(line['input'])
22 | if i % world_size != global_rank:
23 | continue
24 | data.append((line['input'], line['prompt'], line['output'], line['task_id']))
25 | num_samples = len(data)
26 | val_samples = math.ceil(0.05*num_samples)
27 | if mode == 'train':
28 | res_data.extend(data[:-val_samples])
29 | elif mode == 'validation':
30 | res_data.extend(data[-val_samples:])
31 | else:
32 | raise ValueError(f"Unknown mode: {mode} (expected 'train' or 'validation')")
33 | return res_data, list(set(chunks_repo))
34 |
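# Illustrative sketch (editorial note, field values made up): each jsonl line is expected
# to carry the chunked context, the query/prompt, the reference output and a task id, e.g.
#     {"input": ["chunk 1 ...", "chunk 2 ..."], "prompt": "Summarize the meeting.",
#      "output": "The committee discussed ...", "task_id": 1}
# task_id 0 marks reconstruction samples; task_id 1 marks the downstream QA/summary task.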
35 | def load_qmsum_test_data(data_path):
36 | data = []
37 | if data_path is None:
38 | raise ValueError(f'data path is {data_path}')
39 | with open(data_path, 'r', encoding='utf8') as f:
40 | for line in jsonlines.Reader(f):
41 | if line['task_id'] == 1:
42 | data.append((line['input'], line['prompt'], line['output'], line['task_id']))
43 | return data
44 |
45 |
46 | def collate_fn(data):
47 |
48 | res_inputs = []
49 | res_prompt = []
50 | res_output = []
51 | res_task_id = []
52 | for d in data:
53 | inputs = d[0] # list of str, chunk
54 | prompt = d[1]
55 | output = d[2]
56 | task_id = d[3]
57 | res_inputs.append(inputs)
58 | res_prompt.append(prompt)
59 | res_output.append(output)
60 | res_task_id.append(task_id)
61 | return res_inputs, res_prompt, res_output, res_task_id
62 | class TrainDataset(Dataset):
63 |
64 | def __init__(self, names=None, max_seq_len=0, chunk_size=0, shuffle_sample_ratio=0, noisy_sample_ratio=0, noise_rate=0, max_num_chunks=0, max_sliding_windows=None, process_index=0, num_processes=1, seed=2024):
65 | self.data = []
66 | self.chunks_repo = []
67 | self.max_seq_len = max_seq_len
68 | self.chunk_size = chunk_size
69 | self.shuffle_sample_ratio = shuffle_sample_ratio
70 | self.noisy_sample_ratio = noisy_sample_ratio
71 | self.noise_rate = noise_rate
72 | self.max_num_chunks = max_num_chunks
73 | self.max_sliding_windows = max_sliding_windows
74 | self.process_index = process_index
75 | self.num_processes = num_processes
76 | self.deterministic_generator = np.random.default_rng(seed)
77 | data_path = None
78 | names = names[1:-1].split(',')
79 | names = [name.lower() for name in names]
80 | for name in names:
81 | if name == 'qmsum':
82 | data_path = f'/path/to/chunked/dataset/with/chunk_size/train.jsonl'
83 | self.data, self.chunks_repo = load_qmsum_data(data_path, self.data, self.chunks_repo, self.process_index, self.num_processes, 'train')
84 |
85 | self.window_size = math.floor((self.max_seq_len-self.max_num_chunks-60) / self.chunk_size)
86 | self.augmented_data = self.get_augmented_data()
87 |
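# Illustrative arithmetic (editorial note, numbers made up): with max_seq_len=4096,
# max_num_chunks=500 and chunk_size=512, each reconstruction window holds
# floor((4096 - 500 - 60) / 512) = floor(3536 / 512) = 6 chunks.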
88 | def get_augmented_data(self):
89 | res_data = []
90 | for sample in self.data:
91 | chunks = sample[0] # list of str (str represents chunks)
92 | prompt = sample[1]
93 | output = sample[2]
94 | task_id = int(sample[3])
95 | if task_id == 0:
96 | # normal reconstruction task: keep chunk order (no shuffle)
97 | num_chunks = len(chunks)
98 | num_sliding_windows = math.ceil(num_chunks / self.window_size)
99 | if self.max_sliding_windows is not None:
100 | num_sliding_windows = min(self.max_sliding_windows, num_sliding_windows)
101 | start = 0
102 | for i in range(num_sliding_windows):
103 | end = min(start + self.window_size, num_chunks)
104 | chunks_in_window = chunks[start:end]
105 | # chunks_in_window_to_string = ''.join(chunks_in_window)
106 | chunks_in_window_to_string = merge_strings(chunks_in_window)
107 | res_data.append((chunks_in_window, prompt, chunks_in_window_to_string, task_id))
108 | start = end
109 | if task_id == 1:
110 | random_value = random.random()
111 | if random_value < self.shuffle_sample_ratio:
112 | random.shuffle(chunks)
113 | res_data.append((chunks, prompt, output, task_id))
114 | elif random_value >= self.shuffle_sample_ratio and random_value < self.shuffle_sample_ratio+self.noisy_sample_ratio:
115 | num_noise_chunks = min(math.floor(self.noise_rate*len(chunks)), len(self.chunks_repo))
116 | noisy_chunks = random.sample(self.chunks_repo, num_noise_chunks)
117 | chunks = randomly_insert_list1_into_list2(chunks, noisy_chunks)
118 | res_data.append((chunks, prompt, output, task_id))
119 | else:
120 | res_data.append((chunks, prompt, output, task_id))
121 |
122 | self.step = 0
123 | self.steps_per_epoch = len(res_data)
124 |
125 | return res_data
126 |
127 |
128 | def __getitem__(self, index):
129 | if self.step > self.steps_per_epoch - 1:
130 | self.augmented_data = self.get_augmented_data()
131 | self.step += 1
132 | return self.augmented_data[index]
133 |
134 | def __len__(self):
135 | return self.steps_per_epoch
136 |
137 | class ValDataset(Dataset):
138 | def __init__(self, names=None, max_seq_len=0, chunk_size=0, shuffle_sample_ratio=0, noisy_sample_ratio=0, noise_rate=0, max_num_chunks=0, max_sliding_windows=None, process_index=0, num_processes=1, seed=2024):
139 | self.data = []
140 | self.chunks_repo = []
141 | self.max_seq_len = max_seq_len
142 | self.chunk_size = chunk_size
143 | self.max_num_chunks = max_num_chunks
144 | self.shuffle_sample_ratio = shuffle_sample_ratio
145 | self.noisy_sample_ratio = noisy_sample_ratio
146 | self.noise_rate = noise_rate
147 | self.process_index = process_index
148 | self.max_sliding_windows = max_sliding_windows
149 | self.num_processes = num_processes
150 | self.deterministic_generator = np.random.default_rng(seed)
151 | data_path = None
152 | names = names[1:-1].split(',')
153 | names = [name.lower() for name in names]
154 | for name in names:
155 | if name == 'qmsum':
156 | data_path = f'/path/to/chunked/dataset/with/chunk_size/train.jsonl'
157 | self.data, self.chunks_repo = load_qmsum_data(data_path, self.data, self.chunks_repo, self.process_index, self.num_processes, 'validation')
158 | self.window_size = math.floor((self.max_seq_len-self.max_num_chunks-60) / self.chunk_size)
159 |
160 | self.augmented_data = self.get_augmented_data()
161 |
162 | def get_augmented_data(self):
163 | res_data = []
164 | for sample in self.data:
165 | chunks = sample[0] # list of str (str represents chunks)
166 | prompt = sample[1]
167 | output = sample[2]
168 | task_id = int(sample[3])
169 | if task_id == 0:
170 | # normal reconstruction task: keep chunk order (no shuffle)
171 | num_chunks = len(chunks)
172 | num_sliding_windows = math.ceil(num_chunks / self.window_size)
173 | if self.max_sliding_windows is not None:
174 | num_sliding_windows = min(self.max_sliding_windows, num_sliding_windows)
175 | start = 0
176 | for i in range(num_sliding_windows):
177 | end = min(start + self.window_size, num_chunks)
178 | chunks_in_window = chunks[start:end]
179 | # chunks_in_window_to_string = ''.join(chunks_in_window)
180 | chunks_in_window_to_string = merge_strings(chunks_in_window)
181 | res_data.append((chunks_in_window, prompt, chunks_in_window_to_string, task_id))
182 | start = end
183 | if task_id == 1:
184 | random_value = random.random()
185 | if random_value < self.shuffle_sample_ratio:
186 | random.shuffle(chunks)
187 | res_data.append((chunks, prompt, output, task_id))
188 | elif random_value >= self.shuffle_sample_ratio and random_value < self.shuffle_sample_ratio+self.noisy_sample_ratio:
189 | num_noise_chunks = min(math.floor(self.noise_rate*len(chunks)), len(self.chunks_repo))
190 | noisy_chunks = random.sample(self.chunks_repo, num_noise_chunks)
191 | chunks = randomly_insert_list1_into_list2(chunks, noisy_chunks)
192 | res_data.append((chunks, prompt, output, task_id))
193 | else:
194 | res_data.append((chunks, prompt, output, task_id))
195 |
196 | self.step = 0
197 | self.steps_per_epoch = len(res_data)
198 |
199 | return res_data
200 |
201 |
202 | def __getitem__(self, index):
203 | if self.step > self.steps_per_epoch - 1:
204 | self.augmented_data = self.get_augmented_data()
205 | self.step += 1
206 | return self.augmented_data[index]
207 |
208 | def __len__(self):
209 | return self.steps_per_epoch
210 |
211 |
212 |
213 |
214 |
215 | class TestqmsumDataset(Dataset):
216 | def __init__(self, chunk_size=0, seed=2024):
217 | self.deterministic_generator = np.random.default_rng(seed)
218 | data_path = f'/path/to/chunked/dataset/with/chunk_size/eval.jsonl'
219 | self.data = load_qmsum_test_data(data_path)
220 |
221 | def __getitem__(self, index):
222 |
223 | return self.data[index]
224 |
225 | def __len__(self):
226 | return len(self.data)
227 |
228 | if __name__ == '__main__':
229 | d = TrainDataset(names='[qmsum]', max_seq_len=4096, chunk_size=512, max_num_chunks=500)  # names is a bracketed string; non-zero sizes avoid a divide-by-zero in window_size
230 | print(len(d.data))
--------------------------------------------------------------------------------
/LICENSE.md:
--------------------------------------------------------------------------------
1 | Copyright [2023] [Ant Group]
2 | Licensed under the Apache License, Version 2.0 (the "License");
3 | you may not use this file except in compliance with the License.
4 | You may obtain a copy of the License at
5 | http://www.apache.org/licenses/LICENSE-2.0
6 |
7 | Unless required by applicable law or agreed to in writing, software
8 | distributed under the License is distributed on an "AS IS" BASIS,
9 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
10 | See the License for the specific language governing permissions and
11 | limitations under the License.
12 |
13 |
14 | Apache License
15 | Version 2.0, January 2004
16 | http://www.apache.org/licenses/
17 |
18 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
19 |
20 | 1. Definitions.
21 |
22 | "License" shall mean the terms and conditions for use, reproduction,
23 | and distribution as defined by Sections 1 through 9 of this document.
24 |
25 | "Licensor" shall mean the copyright owner or entity authorized by
26 | the copyright owner that is granting the License.
27 |
28 | "Legal Entity" shall mean the union of the acting entity and all
29 | other entities that control, are controlled by, or are under common
30 | control with that entity. For the purposes of this definition,
31 | "control" means (i) the power, direct or indirect, to cause the
32 | direction or management of such entity, whether by contract or
33 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
34 | outstanding shares, or (iii) beneficial ownership of such entity.
35 |
36 | "You" (or "Your") shall mean an individual or Legal Entity
37 | exercising permissions granted by this License.
38 |
39 | "Source" form shall mean the preferred form for making modifications,
40 | including but not limited to software source code, documentation
41 | source, and configuration files.
42 |
43 | "Object" form shall mean any form resulting from mechanical
44 | transformation or translation of a Source form, including but
45 | not limited to compiled object code, generated documentation,
46 | and conversions to other media types.
47 |
48 | "Work" shall mean the work of authorship, whether in Source or
49 | Object form, made available under the License, as indicated by a
50 | copyright notice that is included in or attached to the work
51 | (an example is provided in the Appendix below).
52 |
53 | "Derivative Works" shall mean any work, whether in Source or Object
54 | form, that is based on (or derived from) the Work and for which the
55 | editorial revisions, annotations, elaborations, or other modifications
56 | represent, as a whole, an original work of authorship. For the purposes
57 | of this License, Derivative Works shall not include works that remain
58 | separable from, or merely link (or bind by name) to the interfaces of,
59 | the Work and Derivative Works thereof.
60 |
61 | "Contribution" shall mean any work of authorship, including
62 | the original version of the Work and any modifications or additions
63 | to that Work or Derivative Works thereof, that is intentionally
64 | submitted to Licensor for inclusion in the Work by the copyright owner
65 | or by an individual or Legal Entity authorized to submit on behalf of
66 | the copyright owner. For the purposes of this definition, "submitted"
67 | means any form of electronic, verbal, or written communication sent
68 | to the Licensor or its representatives, including but not limited to
69 | communication on electronic mailing lists, source code control systems,
70 | and issue tracking systems that are managed by, or on behalf of, the
71 | Licensor for the purpose of discussing and improving the Work, but
72 | excluding communication that is conspicuously marked or otherwise
73 | designated in writing by the copyright owner as "Not a Contribution."
74 |
75 | "Contributor" shall mean Licensor and any individual or Legal Entity
76 | on behalf of whom a Contribution has been received by Licensor and
77 | subsequently incorporated within the Work.
78 |
79 | 2. Grant of Copyright License. Subject to the terms and conditions of
80 | this License, each Contributor hereby grants to You a perpetual,
81 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
82 | copyright license to reproduce, prepare Derivative Works of,
83 | publicly display, publicly perform, sublicense, and distribute the
84 | Work and such Derivative Works in Source or Object form.
85 |
86 | 3. Grant of Patent License. Subject to the terms and conditions of
87 | this License, each Contributor hereby grants to You a perpetual,
88 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
89 | (except as stated in this section) patent license to make, have made,
90 | use, offer to sell, sell, import, and otherwise transfer the Work,
91 | where such license applies only to those patent claims licensable
92 | by such Contributor that are necessarily infringed by their
93 | Contribution(s) alone or by combination of their Contribution(s)
94 | with the Work to which such Contribution(s) was submitted. If You
95 | institute patent litigation against any entity (including a
96 | cross-claim or counterclaim in a lawsuit) alleging that the Work
97 | or a Contribution incorporated within the Work constitutes direct
98 | or contributory patent infringement, then any patent licenses
99 | granted to You under this License for that Work shall terminate
100 | as of the date such litigation is filed.
101 |
102 | 4. Redistribution. You may reproduce and distribute copies of the
103 | Work or Derivative Works thereof in any medium, with or without
104 | modifications, and in Source or Object form, provided that You
105 | meet the following conditions:
106 |
107 | (a) You must give any other recipients of the Work or
108 | Derivative Works a copy of this License; and
109 |
110 | (b) You must cause any modified files to carry prominent notices
111 | stating that You changed the files; and
112 |
113 | (c) You must retain, in the Source form of any Derivative Works
114 | that You distribute, all copyright, patent, trademark, and
115 | attribution notices from the Source form of the Work,
116 | excluding those notices that do not pertain to any part of
117 | the Derivative Works; and
118 |
119 | (d) If the Work includes a "NOTICE" text file as part of its
120 | distribution, then any Derivative Works that You distribute must
121 | include a readable copy of the attribution notices contained
122 | within such NOTICE file, excluding those notices that do not
123 | pertain to any part of the Derivative Works, in at least one
124 | of the following places: within a NOTICE text file distributed
125 | as part of the Derivative Works; within the Source form or
126 | documentation, if provided along with the Derivative Works; or,
127 | within a display generated by the Derivative Works, if and
128 | wherever such third-party notices normally appear. The contents
129 | of the NOTICE file are for informational purposes only and
130 | do not modify the License. You may add Your own attribution
131 | notices within Derivative Works that You distribute, alongside
132 | or as an addendum to the NOTICE text from the Work, provided
133 | that such additional attribution notices cannot be construed
134 | as modifying the License.
135 |
136 | You may add Your own copyright statement to Your modifications and
137 | may provide additional or different license terms and conditions
138 | for use, reproduction, or distribution of Your modifications, or
139 | for any such Derivative Works as a whole, provided Your use,
140 | reproduction, and distribution of the Work otherwise complies with
141 | the conditions stated in this License.
142 |
143 | 5. Submission of Contributions. Unless You explicitly state otherwise,
144 | any Contribution intentionally submitted for inclusion in the Work
145 | by You to the Licensor shall be under the terms and conditions of
146 | this License, without any additional terms or conditions.
147 | Notwithstanding the above, nothing herein shall supersede or modify
148 | the terms of any separate license agreement you may have executed
149 | with Licensor regarding such Contributions.
150 |
151 | 6. Trademarks. This License does not grant permission to use the trade
152 | names, trademarks, service marks, or product names of the Licensor,
153 | except as required for reasonable and customary use in describing the
154 | origin of the Work and reproducing the content of the NOTICE file.
155 |
156 | 7. Disclaimer of Warranty. Unless required by applicable law or
157 | agreed to in writing, Licensor provides the Work (and each
158 | Contributor provides its Contributions) on an "AS IS" BASIS,
159 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
160 | implied, including, without limitation, any warranties or conditions
161 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
162 | PARTICULAR PURPOSE. You are solely responsible for determining the
163 | appropriateness of using or redistributing the Work and assume any
164 | risks associated with Your exercise of permissions under this License.
165 |
166 | 8. Limitation of Liability. In no event and under no legal theory,
167 | whether in tort (including negligence), contract, or otherwise,
168 | unless required by applicable law (such as deliberate and grossly
169 | negligent acts) or agreed to in writing, shall any Contributor be
170 | liable to You for damages, including any direct, indirect, special,
171 | incidental, or consequential damages of any character arising as a
172 | result of this License or out of the use or inability to use the
173 | Work (including but not limited to damages for loss of goodwill,
174 | work stoppage, computer failure or malfunction, or any and all
175 | other commercial damages or losses), even if such Contributor
176 | has been advised of the possibility of such damages.
177 |
178 | 9. Accepting Warranty or Additional Liability. While redistributing
179 | the Work or Derivative Works thereof, You may choose to offer,
180 | and charge a fee for, acceptance of support, warranty, indemnity,
181 | or other liability obligations and/or rights consistent with this
182 | License. However, in accepting such obligations, You may act only
183 | on Your own behalf and on Your sole responsibility, not on behalf
184 | of any other Contributor, and only if You agree to indemnify,
185 | defend, and hold each Contributor harmless for any liability
186 | incurred by, or claims asserted against, such Contributor by reason
187 | of your accepting any such warranty or additional liability.
188 |
189 | END OF TERMS AND CONDITIONS
190 |
191 | APPENDIX: How to apply the Apache License to your work.
192 |
193 | To apply the Apache License to your work, attach the following
194 | boilerplate notice, with the fields enclosed by brackets "[]"
195 | replaced with your own identifying information. (Don't include
196 | the brackets!) The text should be enclosed in the appropriate
197 | comment syntax for the file format. We also recommend that a
198 | file or class name and description of purpose be included on the
199 | same "printed page" as the copyright notice for easier
200 | identification within third-party archives.
201 |
202 | Copyright [yyyy] [name of copyright owner]
203 |
204 | Licensed under the Apache License, Version 2.0 (the "License");
205 | you may not use this file except in compliance with the License.
206 | You may obtain a copy of the License at
207 |
208 | http://www.apache.org/licenses/LICENSE-2.0
209 |
210 | Unless required by applicable law or agreed to in writing, software
211 | distributed under the License is distributed on an "AS IS" BASIS,
212 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
213 | See the License for the specific language governing permissions and
214 | limitations under the License.
--------------------------------------------------------------------------------
/train_accelerate.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import math
4 | import time
5 | import argparse
6 | import json
7 | import logging
8 | from utils.common_utils import compile_helper
9 | from tqdm.auto import tqdm
10 | import transformers
11 | import numpy as np
12 | import torch
13 | from torch import nn
14 | import datasets
15 |
16 | from torch.utils.data.distributed import DistributedSampler
17 | from torch.utils.data import DataLoader, Dataset, RandomSampler
18 | from transformers import (
19 | AutoModelForCausalLM,
20 | AutoTokenizer,
21 | get_linear_schedule_with_warmup,
22 | set_seed,
23 | BitsAndBytesConfig,
24 | get_scheduler,
25 | )
26 | from peft import (
27 | LoraConfig,
28 | TaskType,
29 | get_peft_model,
30 | prepare_model_for_kbit_training,
31 | PeftModel,
32 | )
33 | from accelerate import Accelerator, DistributedType, FullyShardedDataParallelPlugin, DataLoaderConfiguration
34 | from accelerate.logging import get_logger
35 | from datetime import timedelta
36 | from accelerate.utils import InitProcessGroupKwargs
37 | from transformers.optimization import Adafactor
38 |
39 | # insert project root as import path, just in case
40 | current_path = os.path.abspath(__file__)
41 | parent_dir = os.path.dirname(os.path.dirname(current_path))
42 | sys.path.insert(0, parent_dir)
43 |
44 | from dataset.dataset import *
45 | from utils.common_utils import print_rank_0, get_parameter_number
46 | from pefts import E2LLMTrainArgs, E2LLMTrainer
47 | from model.pro_model import E2LLMModel
48 | logger = get_logger(__name__)
49 |
50 | def pprint_args(args, accelerator):
51 | max_key_length = max(len(str(key)) for key in vars(args).keys())
52 | message = ""
53 | message += "====" * 60 + "\n"
54 | message += "\n".join([f"{k:<{max_key_length}} : {v}" for k, v in vars(args).items()]) + "\n"
55 | message += "====" * 60 + "\n"
56 | accelerator.print(message)
57 | accelerator.print("GPU: {}".format(torch.cuda.current_device()))
58 |
59 | def str2bool(v):
60 | return v.lower() in ('yes', 'true', 't', '1')
61 |
62 | def prepare_args():
63 | parser = argparse.ArgumentParser()
64 | parser.add_argument("--train_config", type=str, default="./configs/train_config.json")
65 | parsed = parser.parse_args()
66 |
67 | with open(parsed.train_config, "r") as f:
68 | train_config = json.load(f)
69 |
70 | args = E2LLMTrainArgs(**train_config)
71 | args.output_dir = f"/path/to/your/output/of/ckpt/{args.mark}"
72 | args.tb_dir = f"/path/to/your/output/of/ckpt/{args.mark}/log"
73 |
74 | return args
75 |
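# Illustrative, abridged sketch (editorial note) of ./configs/train_config.json; the field
# names are taken from the attributes used below, the values are placeholders rather than
# recommendations:
#     {
#         "mark": "e2llm_qmsum_run1",
#         "seed": 2024,
#         "train_data_list": "[qmsum]",
#         "val_data_list": "[qmsum]",
#         "max_seq_len": 4096,
#         "chunk_size": 512,
#         "max_num_chunks": 500,
#         "peft_fn": "lora",
#         "lr": 1e-4,
#         "bs": 2,
#         "gradient_accumulation_steps": 8,
#         "num_epochs": 3
#     }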
76 | def main():
77 | os.environ["TOKENIZERS_PARALLELISM"] = "false"
78 | os.environ["HF_HUB_OFFLINE"] = "false"
79 | args = prepare_args()
80 |
81 | if args.seed is not None:
82 | set_seed(args.seed)
83 |
84 | os.makedirs(args.output_dir, exist_ok=True)
85 | logging.basicConfig(filename=f"{args.output_dir}/training_{args.mark}.log", level=logging.INFO)
86 |
87 | # define accelerator
88 | init_process_kwargs = InitProcessGroupKwargs(timeout=timedelta(seconds=args.init_timeout_seconds))
89 | if args.distributed_type is not None and args.distributed_type.lower() == "fsdp":
90 | fsdp_plugin = FullyShardedDataParallelPlugin(
91 | # state_dict_config=FullStateDictConfig(offload_to_cpu=True, rank0_only=True),
92 | # optim_state_dict_config=FullOptimStateDictConfig(offload_to_cpu=True, rank0_only=True),
93 | limit_all_gathers=True,
94 | sync_module_states=True,
95 | use_orig_params=True,
96 | cpu_offload=False,
97 | )
98 | accelerator = Accelerator(
99 | gradient_accumulation_steps=args.gradient_accumulation_steps,
100 | fsdp_plugin=fsdp_plugin,
101 | kwargs_handlers=[init_process_kwargs])
102 | else:
103 | accelerator = Accelerator(
104 | gradient_accumulation_steps=args.gradient_accumulation_steps,
105 | kwargs_handlers=[init_process_kwargs])
106 |
107 | args.world_size = accelerator.num_processes
108 | pprint_args(args, accelerator)
109 | if accelerator.is_main_process:
110 | with open(os.path.join(args.output_dir, "args.json"), "a") as f:
111 | json.dump(args.dict(), f, indent=2)
112 |
113 |
114 |
115 | global_rank = accelerator.process_index
116 | local_rank = accelerator.local_process_index
117 | accelerator.print(f"world_size: {args.world_size}, global_rank: {global_rank}, local_rank: {local_rank}")
118 | logger.info(accelerator.state, main_process_only=False)
119 |
120 | t0 = time.time()
121 | train_dataset = TrainDataset(names=args.train_data_list, max_seq_len=args.max_seq_len, chunk_size=args.chunk_size, shuffle_sample_ratio=args.shuffle_sample_ratio, noisy_sample_ratio=args.noisy_sample_ratio, noise_rate=args.noise_rate, max_num_chunks=args.max_num_chunks, max_sliding_windows=args.max_sliding_windows, process_index=global_rank, num_processes=args.world_size)
122 | val_dataset = ValDataset(names=args.val_data_list, max_seq_len=args.max_seq_len, chunk_size=args.chunk_size, shuffle_sample_ratio=args.shuffle_sample_ratio, noisy_sample_ratio=args.noisy_sample_ratio, noise_rate=args.noise_rate, max_num_chunks=args.max_num_chunks, max_sliding_windows=args.max_sliding_windows, process_index=global_rank, num_processes=args.world_size)
123 | train_sampler = RandomSampler(train_dataset)
124 | val_sampler = RandomSampler(val_dataset)
125 | train_dataloader = DataLoader(train_dataset, batch_size=args.bs, shuffle=False, sampler=train_sampler, collate_fn=collate_fn, num_workers=0)
126 | val_dataloader = DataLoader(val_dataset, batch_size=args.bs, shuffle=False, sampler=val_sampler, collate_fn=collate_fn, num_workers=0)
127 | t1 = time.time()
128 | logger.info(f"dataset loading time: {t1 - t0:.4f}")
129 |
130 | # cuda memory
131 | free_in_GB = int(torch.cuda.mem_get_info()[0] / 1024 ** 3)
132 | max_memory = f"{free_in_GB - 2}GB"
133 | n_gpus = torch.cuda.device_count()
134 | max_memory = {i: max_memory for i in range(n_gpus)}
135 | accelerator.print("max memory: ", max_memory, n_gpus)
136 |
137 |
138 | with open("./configs/lora_modules.json", 'r') as fr:
139 | lora_modules = json.load(fr)
140 |
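# Illustrative sketch (editorial note): lora_modules.json is expected to map the basename
# of each encoder/decoder model directory to its LoRA target modules; the keys and module
# names below are assumptions, not the shipped config:
#     {
#         "bert-base-uncased": ["query", "key", "value"],
#         "Llama-2-7b-hf": ["q_proj", "k_proj", "v_proj", "o_proj"]
#     }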
141 | if args.peft_fn == "qlora":
142 | print_rank_0("[INFO] args.peft_fn is set to 'qlora', setting quantization to '4bit'")
143 | args.quantization = "4bit"
144 | else:
145 | args.quantization = None
146 |
147 | enc_model = args.encoder_model_dir.split('/')[-1]
148 | dec_model = args.decoder_model_dir.split('/')[-1]
149 |
150 | # peft config
151 | if args.peft_fn in ["lora", "qlora"]:
152 | lora_config_enc = LoraConfig(
153 | r=args.lora_rank_enc,
154 | lora_alpha=args.lora_alpha_enc,
155 | # target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
156 | target_modules=lora_modules[enc_model],
157 | lora_dropout=0.05,
158 | bias="lora_only",
159 | inference_mode=False,
160 | task_type=TaskType.FEATURE_EXTRACTION
161 | )
162 | lora_config_dec = LoraConfig(
163 | r=args.lora_rank_dec,
164 | lora_alpha=args.lora_alpha_dec,
165 | # target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "down_proj", "up_proj"],
166 | target_modules=lora_modules[dec_model] if dec_model in lora_modules else "all-linear",
167 | lora_dropout=0.05,
168 | bias="lora_only",
169 | inference_mode=False,
170 | task_type=TaskType.CAUSAL_LM
171 | )
172 |
173 | model = E2LLMModel(args=args)
174 | model.decoder_model.gradient_checkpointing_enable()
175 |
176 | if args.quantization == "4bit":
177 | model.decoder_model = prepare_model_for_kbit_training(model.decoder_model)
178 | if args.peft_fn in ["lora", "qlora"]:
179 | model.encoder_model = get_peft_model(model.encoder_model, lora_config_enc)
180 | model.decoder_model = get_peft_model(model.decoder_model, lora_config_dec)
181 | if accelerator.is_main_process:
182 | model.encoder_model.print_trainable_parameters()
183 | model.decoder_model.print_trainable_parameters()
184 | else:
185 | accelerator.print("[WARNING] Full-parameter training takes up a lot of disk space; setting num_ckpt to 2")
186 | args.num_ckpt = 2
187 | # model.config.use_cache = False # silence the warnings. Please re-enable for inference!
188 |
189 | _, trainable_parameters = get_parameter_number(model)
190 | logger.info(f"Trainable parameters: {trainable_parameters}")
191 | print(f"Trainable parameters: {trainable_parameters}")
192 |
193 | t2 = time.time()
194 | if accelerator.is_main_process:
195 | logging.info(f"model loading time: {t2 - t1:.4f}")
196 |
197 | if hasattr(model.decoder_model.config, "use_logn_attn"):
198 | model.decoder_model.config.use_logn_attn = False # special for qwen model
199 |
200 | # load balance for moe training
201 | if hasattr(model.decoder_model.config, "output_router_logits"):
202 | model.decoder_model.config.output_router_logits = True
203 |
204 | num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
205 | if args.max_train_steps is None:
206 | args.max_train_steps = args.num_epochs * num_update_steps_per_epoch
207 |
208 | if accelerator.distributed_type == DistributedType.DEEPSPEED:
209 | adam_optimizer = torch.optim.AdamW
210 | elif accelerator.distributed_type == DistributedType.FSDP:
211 | if args.peft_fn and getattr(accelerator.state, "fsdp_plugin", None) is not None:
212 | from peft.utils.other import fsdp_auto_wrap_policy
213 |
214 | accelerator.state.fsdp_plugin.auto_wrap_policy = fsdp_auto_wrap_policy(model)
215 | model = accelerator.prepare(model)
216 | adam_optimizer = torch.optim.AdamW
217 | else:
218 | raise ValueError("Only support DeepSpeed and FSDP")
219 |
220 | optimizer = adam_optimizer(
221 | model.parameters(),
222 | weight_decay=args.weight_decay,
223 | lr=args.lr,
224 | betas=(0.9, 0.999),
225 | )
226 |
227 | lr_scheduler = get_scheduler(
228 | name= "linear", #args.lr_scheduler_type,
229 | optimizer=optimizer,
230 | num_warmup_steps=args.num_warmup_steps * args.gradient_accumulation_steps * accelerator.num_processes,
231 | num_training_steps=args.max_train_steps * args.gradient_accumulation_steps * accelerator.num_processes
232 | )
233 |
234 |
235 | # prepare all
236 | if accelerator.distributed_type == DistributedType.DEEPSPEED:
237 | model, _, optimizer, lr_scheduler = accelerator.prepare(model, val_dataloader, optimizer, lr_scheduler)
238 | elif accelerator.distributed_type == DistributedType.FSDP:
239 | _, optimizer, lr_scheduler = accelerator.prepare(val_dataloader, optimizer, lr_scheduler)
240 | else:
241 | raise ValueError("Only support DeepSpeed and FSDP")
242 |
243 | print(model.device)
244 | accelerator.print(model)
245 |
246 | # Recalculate our total training steps as the size of the training dataloader may have changed.
247 | # num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
248 | # if overrode_max_train_steps:
249 | # args.max_train_steps = args.num_epochs * num_update_steps_per_epoch
250 | # # Afterward we recalculate our number of training epochs
251 | # args.num_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
252 |
253 |
254 | # zero 3 flag
255 | is_ds_zero_3 = False
256 | if getattr(accelerator.state, "deepspeed_plugin", None):
257 | is_ds_zero_3 = accelerator.state.deepspeed_plugin.zero_stage == 3
258 | accelerator.print(f"DEEPSPEED plugin: {accelerator.state.deepspeed_plugin}")
259 | elif getattr(accelerator.state, "fsdp_plugin", None):
260 | accelerator.print(f"FSDP plugin: {accelerator.state.fsdp_plugin}")
261 |
262 | trainer = E2LLMTrainer(
263 | accelerator=accelerator,
264 | model=model,
265 | train_dataloader=train_dataloader,
266 | val_dataloader=val_dataloader,
267 | optimizer=optimizer,
268 | lr_scheduler=lr_scheduler,
269 | tokenizer=model.tokenizer,
270 | total_train_dataset_size=len(train_dataset),
271 | args=args,
272 | )
273 | trainer.accelerate_train()
274 | logger.info(f"Training Finished!")
275 |
276 | if __name__ == '__main__':
277 | main()
--------------------------------------------------------------------------------
/eval.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import math
4 | import time
5 | import re
6 | import argparse
7 | import json
8 | import logging
9 | from utils.common_utils import compile_helper
10 | from tqdm.auto import tqdm
11 | import transformers
12 | import numpy as np
13 | import torch
14 | from torch import nn
15 | import datasets
16 | from torch.utils.data.distributed import DistributedSampler
17 | from torch.utils.data import DataLoader, Dataset, RandomSampler
18 | from transformers import (
19 | AutoModelForCausalLM,
20 | AutoTokenizer,
21 | get_linear_schedule_with_warmup,
22 | set_seed,
23 | BitsAndBytesConfig,
24 | get_scheduler,
25 | )
26 | from peft import (
27 | LoraConfig,
28 | TaskType,
29 | get_peft_model,
30 | prepare_model_for_kbit_training,
31 | PeftModel,
32 | )
33 | from accelerate import Accelerator, DistributedType, FullyShardedDataParallelPlugin, DataLoaderConfiguration
34 | from accelerate.logging import get_logger
35 | from datetime import timedelta
36 | from accelerate.utils import InitProcessGroupKwargs
37 | from transformers.optimization import Adafactor
38 |
39 | # insert project root as import path, just in case
40 | current_path = os.path.abspath(__file__)
41 | parent_dir = os.path.dirname(os.path.dirname(current_path))
42 | sys.path.insert(0, parent_dir)
43 |
44 | from dataset.dataset import *
45 | from utils.common_utils import print_rank_0, upload_to_oss
46 | from pefts import E2LLMTrainArgs, E2LLMTrainer
47 | from model.pro_model import E2LLMModel
48 | logger = get_logger(__name__)
49 | from rouge import Rouge
50 | from evaluate import compute_rouge, compute_f1, compute_score
51 | from draw import draw_results
52 |
53 | def write_json_fancy(json_data, path):
54 | with open(path, 'a') as file:
55 | json_str = json.dumps(json_data, indent=4)
56 | file.write(json_str + '\n')
57 |
58 | def str2bool(v):
59 | return v.lower() in ('yes', 'true', 't', '1')
60 |
61 |
62 | def filter_top_k_chunks(model, chunks, queries, task_ids, top_k, thres):
63 |
64 | res_retrieved_chunks = []
65 | chunks_nums = [len(chunks_per) for chunks_per in chunks]
66 |
67 | chunks_cpy = [chunk for chunks_per in chunks for chunk in chunks_per]
68 | queries_chunks = queries + chunks_cpy
69 | queries_chunks_emb = model.encoder_model.encode_without_grad(queries_chunks)
70 | queries_emb = queries_chunks_emb[:len(queries)] # (bs, emb)
71 | chunks_emb = queries_chunks_emb[len(queries):]
72 | splitted_chunks_emb = torch.split(chunks_emb, chunks_nums)
73 |
74 | for i, t in enumerate(splitted_chunks_emb):
75 | if task_ids[i]==0:
76 | res_retrieved_chunks.append([])
77 | continue
78 | sim = torch.matmul(queries_emb[i].unsqueeze(0), splitted_chunks_emb[i].T)
79 | values, indices = torch.topk(sim, min(top_k, sim.size(1)), dim=1)
80 | mask = values > thres
81 | indices = indices[mask].to('cpu')
82 | # indices = indices[0].to('cpu')
83 | retrieved_chunks = np.array(chunks[i])[indices]
84 | res_retrieved_chunks.append(retrieved_chunks)
85 |
86 | a_sorted = sorted(indices, reverse=True)
87 | for index in a_sorted:
88 | del chunks[i][index]
89 | return chunks, res_retrieved_chunks
90 |
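# Illustrative behavior (editorial note, numbers made up): embeddings come back
# L2-normalized from encode_without_grad, so the matmul above is a cosine similarity.
# If a query's similarities to 5 chunks are [0.91, 0.85, 0.40, 0.95, 0.10] with top_k=3
# and thres=0.8, chunks 3, 0 and 1 are returned as retrieved context and deleted from
# the chunk list that is passed on to model.predict.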
91 |
92 | def postprocess_pred(predict_str):
93 | predict_str = predict_str.strip()
94 | # Remove all non-printable characters
95 | np_pattern = re.compile(r'[\x00-\x1f]')
96 | predict_str = np_pattern.sub('\n', predict_str).strip()
97 |
98 | return predict_str
99 |
100 | def string_match_part(preds, refs):
101 | score = sum([max([1.0 if r.lower() in pred.lower() else 0.0 for r in ref]) for pred, ref in zip(preds, refs)]) / len(preds) * 100
102 | return round(score, 2)
103 |
104 | def string_match_all(preds, refs):
105 | score = sum([sum([1.0 if r.lower() in pred.lower() else 0.0 for r in ref]) / len(ref) for pred, ref in zip(preds, refs)]) / len(preds) * 100
106 | return round(score, 2)
107 |
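# Worked example (editorial note, strings made up):
#     preds = ["Paris is the capital of France"]
#     refs  = [["paris", "Lyon"]]
#     string_match_part(preds, refs)  # -> 100.0 (at least one reference substring hit)
#     string_match_all(preds, refs)   # -> 50.0  (1 of the 2 references appears)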
108 | def extract_answer(response):
109 | response = response.replace('*', '')
110 | match = re.search(r'The correct answer is \(([A-D])\)', response)
111 | if match:
112 | return match.group(1)
113 | else:
114 | match = re.search(r'The correct answer is ([A-D])', response)
115 | if match:
116 | return match.group(1)
117 | else:
118 | return None
119 |
120 | def gen_results(details_file, output_file, name):
121 | files = [details_file] # "./wrongresults.jsonl", './results.jsonl'
122 | output = ["Model\tOverall\tEasy\tHard\tShort\tMedium\tLong"]
123 | compensated = False
124 |
125 | for file1 in files:
126 | filename = file1
127 | try:
128 | pred_data = json.load(open(filename, encoding='utf-8'))
129 | except Exception as e:
130 | pred_data = [json.loads(line) for line in open(filename, encoding='utf-8')]
131 | easy, hard, short, medium, long = 0, 0, 0, 0, 0
132 | easy_acc, hard_acc, short_acc, medium_acc, long_acc = 0, 0, 0, 0, 0
133 | for pred in pred_data:
134 | acc = int(pred['judge'])
135 | if compensated and pred["pred"] == None:
136 | acc = 0.25
137 | if pred["difficulty"] == "easy":
138 | easy += 1
139 | easy_acc += acc
140 | else:
141 | hard += 1
142 | hard_acc += acc
143 |
144 | if pred['length'] == "short":
145 | short += 1
146 | short_acc += acc
147 | elif pred['length'] == "medium":
148 | medium += 1
149 | medium_acc += acc
150 | else:
151 | long += 1
152 | long_acc += acc
153 |
154 | output.append(name+'\t'+str(round(100*(easy_acc+hard_acc)/len(pred_data), 1))+'\t'+str(round(100*easy_acc/(easy+1e-5), 1))+'\t'+str(round(100*hard_acc/(hard+1e-5), 1))+'\t'+str(round(100*short_acc/(short+1e-5), 1))+'\t'+str(round(100*medium_acc/(medium+1e-5), 1))+'\t'+str(round(100*long_acc/(long+1e-5), 1)))
155 |
156 | open(output_file, 'a', encoding='utf-8').write('\n'.join(output))
157 |
158 |
159 | def prepare_args():
160 | parser = argparse.ArgumentParser()
161 | parser.add_argument("--eval_config", type=str, default="./configs/eval_config.json")
162 | parser.add_argument('--test_data_list', type=str, default='[qmsum]')
163 | parser.add_argument('--global_step', default=5, type=int)
164 | parser.add_argument('--mark', required=True)
165 | parser.add_argument('--encoder_model_dir', default='/path/to/encoder/model', type=str, help='Model directory')
166 | parser.add_argument('--decoder_model_dir', default='/path/to/decoder/model', type=str, help='Model directory')
167 | parser.add_argument('--chunk_size', default=512, type=int, help='chunk_size')
168 | parser.add_argument('--max_num_chunks', default=500, type=int, help='max_num_chunks')
169 | parser.add_argument('--enc_pool_fn', default='CLS', type=str, help='')
170 | parser.add_argument('--tokens_for_each_chunk', default=4, type=int, help='')
171 | parser.add_argument('--ln', default='True', type=str2bool, help='')
172 | parser.add_argument('--proj_arch', default='', type=str, help='proj_arch')
173 | parser.add_argument('--lora_rank_enc', default=32, type=int, help='lora rank')
174 | parser.add_argument('--lora_alpha_enc', default=32, type=int, help='lora alpha')
175 | parser.add_argument('--lora_rank_dec', default=8, type=int, help='lora rank')
176 | parser.add_argument('--lora_alpha_dec', default=8, type=int, help='lora alpha')
177 | parser.add_argument('--eval_retrieve', default='True', type=str2bool, help='')
178 | parser.add_argument('--top_k', default=3, type=int, help='')
179 | parser.add_argument('--thres', default=0.8, type=float, help='')
180 | parser.add_argument('--attn_implementation', default='eager', type=str, help='')
181 | parser.add_argument('--bs', default=10, type=int)
182 | parser.add_argument('--max_new_tokens', default=100, type=int)
183 | parser.add_argument('--api', default="gpt4o", type=str) # qwen72b
184 |
185 |
186 | parsed = parser.parse_args()
187 |
188 | with open(parsed.eval_config, "r") as f:
189 | eval_config = json.load(f)
190 |
191 | args = E2LLMTrainArgs(**eval_config)
192 | args.test_data_list = parsed.test_data_list
193 | args.global_step = parsed.global_step
194 | args.mark = parsed.mark
195 | args.encoder_model_dir = parsed.encoder_model_dir
196 | args.decoder_model_dir = parsed.decoder_model_dir
197 | args.chunk_size = parsed.chunk_size
198 | args.max_num_chunks = parsed.max_num_chunks
199 | args.enc_pool_fn = parsed.enc_pool_fn
200 | args.tokens_for_each_chunk = parsed.tokens_for_each_chunk
201 | args.ln = parsed.ln
202 | args.api = parsed.api
203 |
204 | if parsed.proj_arch not in [None, '', 'None', 'null']:
205 | args.proj_arch = parsed.proj_arch
206 | else:
207 | args.proj_arch = None
208 |
209 | args.lora_rank_enc = parsed.lora_rank_enc
210 | args.lora_alpha_enc = parsed.lora_alpha_enc
211 | args.lora_rank_dec = parsed.lora_rank_dec
212 | args.lora_alpha_dec = parsed.lora_alpha_dec
213 |
214 | args.eval_retrieve = parsed.eval_retrieve
215 | args.top_k = parsed.top_k
216 | args.thres = parsed.thres
217 | args.attn_implementation = parsed.attn_implementation
218 | args.bs = parsed.bs
219 | args.max_new_tokens = parsed.max_new_tokens
220 | args.bf16 = True
221 | args.output_result_path=f"/path/to/result/json"
222 |
223 | return args
224 |
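# Illustrative invocation (editorial sketch; mark and model paths are placeholders):
#     python eval.py --eval_config ./configs/eval_config.json \
#         --test_data_list '[qmsum]' --mark my_run \
#         --encoder_model_dir /path/to/encoder/model \
#         --decoder_model_dir /path/to/decoder/model \
#         --eval_retrieve True --top_k 3 --thres 0.8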
225 | def main():
226 | os.environ["TOKENIZERS_PARALLELISM"] = "false"
227 | os.environ["HF_HUB_OFFLINE"] = "false"
228 | args = prepare_args()
229 | print(args)
230 |
231 | ckpt_root = f'/path/to/ckpt'
232 | if args.seed is not None:
233 | set_seed(args.seed)
234 | with open("./configs/lora_modules.json", 'r') as fr:
235 | lora_modules = json.load(fr)
236 |
237 |
238 | enc_model = args.encoder_model_dir.split('/')[-1]
239 | dec_model = args.decoder_model_dir.split('/')[-1]
240 |
241 | if args.peft_fn in ["lora", "qlora"]:
242 | lora_config_enc = LoraConfig(
243 | r=args.lora_rank_enc,
244 | lora_alpha=args.lora_alpha_enc,
245 | target_modules=lora_modules[enc_model],
246 | lora_dropout=0,
247 | bias="lora_only",
248 | inference_mode=False,
249 | task_type=TaskType.FEATURE_EXTRACTION
250 | )
251 | lora_config_dec = LoraConfig(
252 | r=args.lora_rank_dec,
253 | lora_alpha=args.lora_alpha_dec,
254 | target_modules=lora_modules[dec_model] if dec_model in lora_modules else "all-linear",
255 | lora_dropout=0,
256 | bias="lora_only",
257 | inference_mode=False,
258 | task_type=TaskType.CAUSAL_LM
259 | )
260 |
261 | model = E2LLMModel(args=args)
262 | encoder_lora_ckpt_path = os.path.join(ckpt_root, 'encoder')
263 | decoder_lora_ckpt_path = os.path.join(ckpt_root, 'decoder')
264 | projector_ckpt_path = os.path.join(ckpt_root, 'projector.pth')
265 |
266 | model.encoder_model = PeftModel.from_pretrained(model.encoder_model, encoder_lora_ckpt_path)
267 | model.decoder_model = PeftModel.from_pretrained(model.decoder_model, decoder_lora_ckpt_path)
268 | if args.proj_arch is not None:
269 | projector_state_dict = torch.load(projector_ckpt_path)
270 | model.projector.load_state_dict(projector_state_dict)
271 |
272 |
273 | device = model.decoder_model.device
274 |
275 | # model.encoder_model.plm_model.to(device)
276 | model.encoder_model.to(device)
277 | if args.proj_arch is not None:
278 | for p in model.projector.parameters():
279 | p = p.to(device)
280 | model.eval()
281 | res_results = {}
282 | names = args.test_data_list[1:-1].split(',')
283 | names = [name.strip().lower() for name in names]
284 | for name in names:
285 | if name == 'qmsum':
286 | with torch.no_grad():
287 | test_qmsum_dataset = TestqmsumDataset(chunk_size=args.chunk_size)
288 | test_qmsum_dataloader = DataLoader(test_qmsum_dataset, batch_size=args.bs, shuffle=False, collate_fn=collate_fn)
289 | test_qmsum_iterator = tqdm(test_qmsum_dataloader, mininterval=0)
290 | references = []
291 | responses = []
292 | with torch.no_grad():
293 | for step, batch in enumerate(test_qmsum_iterator):
294 | chunks, prompts, outputs, task_ids = batch
295 | if args.eval_retrieve:
296 | chunks, retrieved_chunks = filter_top_k_chunks(model, chunks, prompts, task_ids, args.top_k, args.thres)
297 | response = model.predict(chunks, prompts, outputs, args.max_new_tokens, retrieved_chunks)
298 | else:
299 | response = model.predict(chunks, prompts, outputs, args.max_new_tokens)
300 |
301 | responses.extend(response)
302 | references.extend(outputs)
303 | rouger = Rouge()
304 | responses = [response if response.strip() else '.' for response in responses]
305 | results = compute_rouge(rouger, name, responses, references)
306 | write_json_fancy(results, args.output_result_path)
307 |
308 |
309 | if __name__ == '__main__':
310 | main()
--------------------------------------------------------------------------------
/pefts/e2llm_trainer.py:
--------------------------------------------------------------------------------
1 | from transformers import AutoTokenizer, AutoConfig
2 | import gc
3 | import os
4 | import sys
5 | import threading
6 | import argparse
7 | import math
8 | import json
9 | import time
10 | import transformers
11 | import numpy as np
12 | import psutil
13 | import shutil
14 | import torch
15 | from torch import nn
16 | import torch.distributed as dist
17 | from torch.utils.tensorboard import SummaryWriter
18 | from typing import List, Optional, Tuple, Union
19 | from tqdm.auto import tqdm
20 | from tqdm import trange
21 | from accelerate.logging import get_logger
22 | from accelerate import Accelerator
23 | from transformers import set_seed
24 | logger = get_logger(__name__)
25 |
26 | def write_tensorboard(summary_writer: SummaryWriter, log_dict: dict, completed_steps):
27 | for key, value in log_dict.items():
28 | summary_writer.add_scalar(f"{key}", value, completed_steps)
29 |
30 | def check_existing_ckpts(output_dir):
31 | prefix = "step_"
32 |
33 | if not os.path.exists(output_dir):
34 | return []
35 | # list all files and dirs
36 | contents = os.listdir(output_dir)
37 |
38 | # find dirs starting with "step_"
39 | matching_folders = [
40 | folder for folder in contents if os.path.isdir(os.path.join(output_dir, folder)) and folder.startswith(prefix)
41 | ]
42 |
43 | return matching_folders
44 |
45 | def delete_ckpts_over_limits(output_dir, saving_limit, best_step):
46 | """delete ckpts more than saving_limits except for the best_step ckpt"""
47 | existing_ckpts = check_existing_ckpts(output_dir)
48 | logger.info(f"Existing step ckpts folders: {existing_ckpts}, best step ckpt: step_{best_step}", main_process_only=True)
49 | # sort step numbers in ascending order
50 | ckpt_steps = sorted([int(ckpt.replace("step_", "")) for ckpt in existing_ckpts])
51 | # delete the oldest steps except for the best step at present
52 | if len(ckpt_steps) > saving_limit:
53 | deletable_steps = [ckpt_step for ckpt_step in ckpt_steps if ckpt_step != best_step]
54 | # print(deletable_steps[:len(ckpt_steps) - saving_limit])
55 | for del_step in deletable_steps[: len(ckpt_steps) - saving_limit]:
56 | shutil.rmtree(os.path.join(output_dir, f"step_{del_step}"))
57 | logger.info(f"Removed ckpt step_{del_step}", main_process_only=True)
58 |
59 |
60 |
61 | class E2LLMTrainer:
62 | def __init__(
63 | self,
64 | accelerator: Accelerator,
65 | model,
66 | train_dataloader,
67 | val_dataloader,
68 | optimizer,
69 | lr_scheduler,
70 | tokenizer,
71 | total_train_dataset_size,
72 | args,
73 | ):
74 | self.accelerator = accelerator
75 | self.model = model
76 | self.train_dataloader = train_dataloader
77 | self.val_dataloader = val_dataloader
78 | self.optimizer = optimizer
79 | self.lr_scheduler = lr_scheduler
80 | self.tokenizer = tokenizer
81 | self.total_train_dataset_size = total_train_dataset_size
82 | self.args = args
83 | self.summary_writer = SummaryWriter(log_dir=args.tb_dir)
84 |
85 | def print(self, msg: str):
86 | """
 87 |         Print via the accelerator; by default only on the main process.
 88 |         Args:
 89 |             msg: message to print
90 |
91 | Returns:
92 |
93 | """
94 | self.accelerator.print(msg)
95 |
96 | @staticmethod
97 | def format_tensor(tensor, n):
98 | return list(map(lambda x: round(x, n), tensor.tolist()))
99 |
100 |
101 | def accelerate_saving_checkpoint(self, output_dir: str, completed_steps: int):
102 | """
103 |         Save the LoRA adapter or the full checkpoint using the accelerator.
104 |         Args:
105 |             output_dir: exact dir for saving the ckpt
106 |             completed_steps: number of completed optimization steps
107 |
108 | Returns:
109 |
110 | """
111 | self.accelerator.wait_for_everyone()
112 |
113 | logger.info(f"[CHECKPOINT] Saving checkpoint", main_process_only=True)
114 | if self.accelerator.is_main_process:
115 |
116 | if self.args.proj_arch is not None:
117 | torch.save(self.accelerator.get_state_dict(self.model.module.projector), f'{output_dir}/projector.pth')
118 |
119 | unwrapped_model = self.accelerator.unwrap_model(self.model)
120 | unwrapped_model.encoder_model.save_pretrained(
121 | f'{output_dir}/encoder',
122 | is_main_process=self.accelerator.is_main_process,
123 | save_function=self.accelerator.save,
124 | state_dict=self.accelerator.get_state_dict(self.model.encoder_model),)
125 | unwrapped_model.decoder_model.save_pretrained(
126 | f'{output_dir}/decoder',
127 | is_main_process=self.accelerator.is_main_process,
128 | save_function=self.accelerator.save,
129 | state_dict=self.accelerator.get_state_dict(self.model.decoder_model),
130 | )
131 | self.accelerator.wait_for_everyone()
132 | # for full-parameter training, save whole ckpt and tokenizer together because it does not need a merge.
133 | if not self.args.peft_fn and self.accelerator.is_main_process:
134 | self.tokenizer.save_pretrained(output_dir)
135 |
136 | sf = os.path.join(output_dir, "model.safetensors")
137 | index_file = os.path.join(output_dir, "model.safetensors.index.json")
138 | if os.path.isfile(sf) and os.path.isfile(index_file):
139 | self.print(f"Remove bug dummy ckpt {sf}")
140 | os.remove(sf)
141 |
142 | if self.accelerator.is_main_process:
143 | latest = {
144 | "latest_ckpt": output_dir,
145 | "lr": self.optimizer.param_groups[0]["lr"],
146 | }
147 | with open(os.path.join(self.args.output_dir, "latest"), "w") as f:
148 | json.dump(latest, f, indent=2)
149 |
150 | logger.info(
151 | f"[CHECKPOINT][complete_steps={completed_steps}], checkpoint {output_dir} saved, latest: {latest}",
152 | main_process_only=True,
153 | )
154 | self.accelerator.wait_for_everyone()
155 |
156 | def filter_top_k_chunks(self, chunks, queries, task_ids, top_k):
157 |
158 | res_retrieved_chunks = []
159 | chunks_nums = [len(chunks_per) for chunks_per in chunks]
160 |
161 | chunks_cpy = [chunk for chunks_per in chunks for chunk in chunks_per]
162 | queries_chunks = queries + chunks_cpy
163 | queries_chunks_emb = self.model.encoder_model.encode_without_grad(queries_chunks)
164 | queries_emb = queries_chunks_emb[:len(queries)] # (bs, emb)
165 | chunks_emb = queries_chunks_emb[len(queries):]
166 | splitted_chunks_emb = torch.split(chunks_emb, chunks_nums)
167 |
168 | for i, t in enumerate(splitted_chunks_emb):
169 | if task_ids[i]==0:
170 | res_retrieved_chunks.append([])
171 | continue
172 | sim = torch.matmul(queries_emb[i].unsqueeze(0), splitted_chunks_emb[i].T)
173 | values, indices = torch.topk(sim, top_k, dim=1)
174 | indices = indices[0].to('cpu')
175 | retrieved_chunks = np.array(chunks[i])[indices]
176 | res_retrieved_chunks.append(retrieved_chunks)
177 |
178 | a_sorted = sorted(indices, reverse=True)
179 | for index in a_sorted:
180 | del chunks[i][index]
181 | del queries_chunks_emb, queries_emb, chunks_emb, splitted_chunks_emb
182 | return chunks, res_retrieved_chunks
183 |
184 |
185 | def accelerate_monitor(
186 | self,
187 | current_epoch,
188 | reduce_loss,
189 | global_steps,
190 | ):
191 |
192 | reduce_losses = self.accelerator.gather(reduce_loss).detach().float()
193 |
194 | train_loss = torch.mean(reduce_losses) / (self.args.log_interval * self.args.gradient_accumulation_steps)
195 |
196 | logger.info(f'Epoch: {current_epoch+1}, Global Step: {global_steps}, Loss: {train_loss:.5f}', main_process_only=True)
197 |
198 | logger.info(
199 | f"[TRAIN][complete_steps={global_steps}][train_loss={train_loss:.6f}]"
200 | # f"[gather shape={list(reduce_losses.shape)}]"
201 | f"[lr={self.lr_scheduler.get_lr()[0]:.4e}, {self.optimizer.param_groups[0]['lr']:.4e}]",
202 | main_process_only=True,
203 | )
204 | train_log_dict = {"Loss/train": train_loss}
205 |
206 | if self.accelerator.is_main_process:
207 | write_tensorboard(self.summary_writer, train_log_dict, global_steps)
208 |
209 | def accelerate_evaluate(self, val_dataloader, global_steps):
210 | self.model.eval()
211 | reduce_loss_eval = torch.tensor(0.0).to(self.model.decoder_model.device)
212 | batch_iterator_eval = tqdm(val_dataloader,
213 |                                    disable=(self.accelerator.process_index != 0),
214 | mininterval=0)
215 | with torch.no_grad():
216 | for step, batch in enumerate(batch_iterator_eval):
217 | chunks, prompts, outputs, task_ids = batch
218 | if self.args.eval_retrieve:
219 | chunks, retrieved_chunks = self.filter_top_k_chunks(chunks, prompts, task_ids, self.args.top_k)
220 | loss_batch_eval = self.model(chunks, prompts, outputs, task_ids, retrieved_chunks)
221 | else:
222 | loss_batch_eval = self.model(chunks, prompts, outputs, task_ids)
223 | if not torch.isnan(loss_batch_eval):
224 | reduce_loss_eval += loss_batch_eval.detach().float()
225 | self.accelerator.wait_for_everyone()
226 | reduce_losses_eval = self.accelerator.gather(reduce_loss_eval).detach().float()
227 | val_loss = torch.mean(reduce_losses_eval) / len(val_dataloader)
228 |
229 | logger.info(f'***Evaluating***, Global Step: {global_steps}, Loss: {val_loss:.5f}', main_process_only=True)
230 |
231 | logger.info(
232 | f"[VAL][complete_steps={global_steps}][val_loss={val_loss:.6f}]",
233 | main_process_only=True,
234 | )
235 | val_log_dict = {"Loss/val": val_loss}
236 |
237 | if self.accelerator.is_main_process:
238 | write_tensorboard(self.summary_writer, val_log_dict, global_steps)
239 |
240 | self.model.train()
241 | return val_loss
242 |
243 | def accelerate_train(self):
244 |
245 | if self.args.seed is not None:
246 | set_seed(self.args.seed)
247 |
248 | global_batch_size = (
249 | self.args.bs
250 | * self.accelerator.num_processes
251 | * self.args.gradient_accumulation_steps
252 | )
253 | logger.info("************************************** Running training ****************************************")
254 | logger.info(f" Num examples = {self.total_train_dataset_size}")
255 | logger.info(f" Num Epochs = {self.args.num_epochs}")
256 | logger.info(f" Instantaneous batch size per device = {self.args.bs}")
257 | logger.info(f" Total global train batch size (w. parallel, distributed & accumulation) = {global_batch_size}")
258 | logger.info(f" Gradient Accumulation steps = {self.args.gradient_accumulation_steps}")
259 | logger.info(f" Total optimization(update/completed) steps = {self.args.max_train_steps}")
260 | logger.info(f" Complete/optimize steps per Epoch = {self.args.max_train_steps // self.args.num_epochs}")
261 | logger.info("************************************************************************************************")
262 | progress_bar = tqdm(range(self.args.max_train_steps), disable=not self.accelerator.is_local_main_process)
263 | # set starting_epoch, completed_steps and resume_step of train_dataloader
264 |
265 | trained_steps = 0
266 | best_eval_metric = 0
267 |         best_epoch, best_step = 0, 0
268 | stop = 0
269 | global_step = 0
270 | num_saved_ckpts = 0
271 | ckpt_achieve = False
272 | reduce_loss = torch.tensor(0.0).to(self.model.decoder_model.device)
273 | # reduce_loss_eval = torch.tensor(0.0).to(self.model.decoder_model.device)
274 | min_reduce_loss_eval = torch.tensor(float('inf')).to(self.model.decoder_model.device)
275 | for current_epoch in trange(int(self.args.num_epochs), desc="Epoch", disable=(self.accelerator.process_index!=0), mininterval=0):
276 | if stop >= self.args.patience:
277 | logger.info(f'Early Stop at {current_epoch+1}-th epoch {global_step}-th step', main_process_only=True)
278 | logger.info(f'Model trained!\nThe best model at {best_epoch+1}-th epoch {best_step}-th step', main_process_only=True)
279 | break
280 | if ckpt_achieve:
281 | logger.info(f'Num of ckpts achieve {self.args.num_ckpt}, Stop training.', main_process_only=True )
282 | break
283 | self.model.train()
284 | batch_iterator = tqdm(self.train_dataloader,
285 | desc=f"Running Epoch {current_epoch + 1} of {self.args.num_epochs}",
286 | disable=(self.accelerator.process_index!=0),
287 | mininterval=0)
288 | for step, batch in enumerate(batch_iterator):
289 |
290 | chunks, prompts, outputs, task_ids = batch
291 | if self.args.train_retrieve:
292 |
293 | chunks, retrieved_chunks = self.filter_top_k_chunks(chunks, prompts, task_ids, self.args.top_k)
294 | loss_batch = self.model(chunks, prompts, outputs, task_ids, retrieved_chunks)
295 | else:
296 | loss_batch = self.model(chunks, prompts, outputs, task_ids)
297 |
298 | # backward
299 | self.accelerator.backward(loss_batch)
300 | # print(self.lr_scheduler.state_dict(), self.accelerator.process_index)
301 | # update(sync_gradients)
302 | self.optimizer.step()
303 | self.lr_scheduler.step()
304 | self.optimizer.zero_grad()
305 | if not torch.isnan(loss_batch):
306 | reduce_loss += loss_batch.detach().float()
307 | if self.accelerator.sync_gradients:
308 | global_step += 1
309 |
310 | if global_step % self.args.log_interval == 0:
311 | progress_bar.update(self.args.log_interval)
312 |
313 | self.accelerate_monitor(
314 | current_epoch,
315 | reduce_loss,
316 | global_step,
317 | )
318 |
319 | # reset loss
320 | reduce_loss = torch.tensor(0.0).to(self.model.decoder_model.device)
321 |
322 | if self.args.eval_interval and global_step % self.args.eval_interval == 0:
323 |
324 | reduce_loss_eval = self.accelerate_evaluate(self.val_dataloader, global_step)
325 |
326 | save_flag = False
327 | if stop >= self.args.patience:
328 | break
329 | if reduce_loss_eval <= min_reduce_loss_eval:
330 | min_reduce_loss_eval = reduce_loss_eval
331 | best_epoch = current_epoch
332 | best_step = global_step
333 | stop = 0
334 |
335 | # remove
336 | max_save_num = 20
337 | if self.accelerator.is_main_process:
338 | try:
339 | delete_ckpts_over_limits(self.args.output_dir, max_save_num, best_step)
340 |                                 except Exception:
341 | logger.info('No ckpt to remove.', main_process_only=True)
342 | else:
343 | stop += 1
344 |
345 | if stop < self.args.num_ckpt:
346 | save_flag = True
347 | else:
348 | ckpt_achieve = True
349 |
350 | if save_flag:
351 | output_dir = f"step_{global_step}"
352 | if self.args.output_dir is not None:
353 | output_dir = os.path.join(self.args.output_dir, output_dir)
354 | os.makedirs(output_dir, exist_ok=True)
355 | self.accelerate_saving_checkpoint(output_dir, global_step)
356 |
357 |
358 |
--------------------------------------------------------------------------------
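A self-contained sketch of the top-k chunk retrieval performed by `filter_top_k_chunks` in e2llm_trainer.py, with random tensors standing in for the encoder embeddings; the retrieved chunks are selected by dot-product similarity and removed from the remaining pool, deleting indices from the back so earlier positions stay valid.

    import numpy as np
    import torch

    chunks = [f"chunk {i}" for i in range(5)]
    query_emb = torch.randn(1, 8)     # stand-in for the encoded query
    chunk_embs = torch.randn(5, 8)    # stand-in for the encoded chunks

    sim = torch.matmul(query_emb, chunk_embs.T)     # (1, num_chunks) dot-product similarity
    values, indices = torch.topk(sim, k=2, dim=1)
    retrieved = np.array(chunks)[indices[0].cpu()]  # chunks later prepended to the query

    for idx in sorted(indices[0].tolist(), reverse=True):
        del chunks[idx]                             # drop retrieved chunks from the pool
    print(list(retrieved), chunks)
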
/model/pro_model.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import math
4 | import re
5 | import numpy as np
6 | import time
7 | import torch
8 | import torch.nn as nn
9 | from model.encoder_model_bert import BaseBertModel
10 | from transformers import AutoTokenizer, AutoModel, AutoModelForCausalLM, AutoConfig, BitsAndBytesConfig
11 | from peft import prepare_model_for_kbit_training
12 |
13 | # from model.modeling_llama import LlamaForCausalLM
14 | # from model.configuration_llama import LlamaConfig
15 | from peft import (
16 | TaskType,
17 | LoraConfig,
18 | get_peft_model,
19 | )
20 |
21 |
22 |
23 | def vectorized_split_and_concat(original_tensor, split_indices, concat_tensor, bf16, device):
24 | bs, seq_len, emb = original_tensor.shape
25 | max_len = seq_len + concat_tensor.shape[1]
26 | if bf16:
27 | new_tensor = torch.zeros((bs, max_len, emb)).to(torch.bfloat16)
28 | else:
29 | new_tensor = torch.zeros((bs, max_len, emb))
30 |
31 | split_indices = split_indices.long()
32 | total_lengths = split_indices + concat_tensor.shape[1] + (seq_len - split_indices)
33 |
34 | for i, split_idx in enumerate(split_indices):
35 | pad_num = seq_len - split_idx
36 | new_tensor[i, :pad_num] = original_tensor[i, split_idx:]
37 | new_tensor[i, pad_num:pad_num+concat_tensor.shape[1]] = concat_tensor[i]
38 | new_tensor[i, pad_num+concat_tensor.shape[1]:total_lengths[i]] = original_tensor[i, :split_idx]
39 |
40 | new_tensor = new_tensor[:, :new_tensor.size(1) - (max_len - total_lengths).max()]
41 | return new_tensor.to(device)
42 |
43 | class Projector(nn.Module):
44 | def __init__(self, arch, input_size, output_size, bf16):
45 | super(Projector, self).__init__()
46 | self.arch_old = re.match(r'^mlp(\d+)x_gelu$', arch)
47 | self.arch_new = re.match(r'^mlpnew(\d+)x_gelu$', arch)
48 | self.input_size = input_size
49 | self.output_size = output_size
50 | if self.arch_old:
51 | mlp_depth = int(self.arch_old.group(1))
52 | if bf16:
53 | modules = [nn.Linear(self.input_size, self.output_size).bfloat16()]
54 | else:
55 | modules = [nn.Linear(self.input_size, self.output_size)]
56 | for _ in range(1, mlp_depth):
57 | modules.append(nn.GELU())
58 | if bf16:
59 | modules.append(nn.Linear(self.output_size, self.output_size).bfloat16())
60 | else:
61 | modules.append(nn.Linear(self.output_size, self.output_size))
62 |
63 | elif self.arch_new:
64 | mlp_depth = int(self.arch_new.group(1))
65 | modules = []
66 | if mlp_depth == 1:
67 | if bf16:
68 | modules = [nn.Linear(self.input_size, self.output_size).bfloat16()]
69 | else:
70 | modules = [nn.Linear(self.input_size, self.output_size)]
71 | elif mlp_depth > 1:
72 | for _ in range(mlp_depth-1):
73 | if bf16:
74 | modules.append(nn.Linear(self.input_size, self.input_size).bfloat16())
75 | else:
76 | modules.append(nn.Linear(self.input_size, self.input_size))
77 | modules.append(nn.GELU())
78 | if bf16:
79 | modules.append(nn.Linear(self.input_size, self.output_size).bfloat16())
80 | else:
81 | modules.append(nn.Linear(self.input_size, self.output_size))
 82 |         else: raise NotImplementedError(f"Unsupported projector arch: {arch}")
83 | self.model = nn.Sequential(*modules)
84 | # self.model = nn.Linear(self.input_size, self.output_size).bfloat16()
85 | def forward(self, x):
86 | return self.model(x)
87 |
88 |
89 |
90 | class E2LLMModel(nn.Module):
91 |
92 | def __init__(self,
93 | args=None,
94 | ):
95 | super(E2LLMModel, self).__init__()
96 | self.args = args
97 | self.peft_fn = args.peft_fn
98 | self.alpha = args.alpha
99 | self.encoder_model_dir = args.encoder_model_dir
100 | self.decoder_model_dir = args.decoder_model_dir
101 | self.chunk_size = args.chunk_size
102 | self.proj_arch = args.proj_arch
103 | self.max_num_chunks = args.max_num_chunks
104 | self.max_seq_len = args.max_seq_len
105 | self.tokens_for_each_chunk = args.tokens_for_each_chunk
106 | self.padding_side = 'right'
107 | self.dtype = torch.bfloat16 if args.bf16 is True else torch.float32
108 | self.attn_implementation = args.attn_implementation
109 | self.ln = args.ln
110 |
111 |
112 | if 'Llama-2-7b-chat-hf' in self.decoder_model_dir.split('/')[-1] or 'Llama-2-13b-chat-hf' in self.decoder_model_dir.split('/')[-1] or 'Llama2-70B-Chat-hf' in self.decoder_model_dir.split('/')[-1]:
113 | self.tokenizer = AutoTokenizer.from_pretrained(self.decoder_model_dir, trust_remote_code=True, truncation_side='right', padding_side=self.padding_side)
114 | self.tokenizer.pad_token = self.tokenizer.eos_token
115 | if self.args.mode == "train":
116 | self.decoder_model = AutoModelForCausalLM.from_pretrained(self.decoder_model_dir,
117 | trust_remote_code=self.args.trust_remote_code,
118 | torch_dtype=self.dtype,
119 | attn_implementation=self.attn_implementation,
120 | quantization_config=(
121 | BitsAndBytesConfig(
122 | load_in_4bit=(self.args.quantization == "4bit"),
123 | bnb_4bit_compute_dtype=torch.bfloat16,
124 | bnb_4bit_use_double_quant=True,
125 | bnb_4bit_quant_type="nf4",
126 | bnb_4bit_quant_storage=torch.bfloat16,
127 | )
128 | if self.args.quantization == "4bit"
129 | else None),
130 | )
131 | elif self.args.mode == "eval":
132 | self.decoder_model = AutoModelForCausalLM.from_pretrained(self.decoder_model_dir,
133 | trust_remote_code=self.args.trust_remote_code,
134 | torch_dtype=self.dtype,
135 | attn_implementation=self.attn_implementation,
136 | device_map="auto"
137 | )
138 |
139 | else:
140 |                 raise ValueError("args.mode can only be 'train' or 'eval'")
141 |
142 | assert self.decoder_model.config.hidden_size % self.tokens_for_each_chunk == 0
143 | self.pma_output_dim = int(self.decoder_model.config.hidden_size / self.tokens_for_each_chunk)
144 |
145 |             self.max_seq_len = min(4096, args.max_seq_len)-self.max_num_chunks-500 if self.args.train_retrieve or self.args.eval_retrieve else min(4096, args.max_seq_len)-self.max_num_chunks-60 # reserve 500 tokens for the template, the query, and the retrieved chunks
146 | self.start_mark = self.tokenizer.bos_token
147 | self.user_start_mark = '[INST]'
148 | self.user_end_mark = '[/INST]'
149 | self.sep = ' '
150 | self._n = '\n'
151 | elif "Qwen2.5-7B-Instruct" in self.decoder_model_dir.split('/')[-1] or "Qwen2.5-72B-Instruct" in self.decoder_model_dir.split('/')[-1]:
152 | self.tokenizer = AutoTokenizer.from_pretrained(self.decoder_model_dir, trust_remote_code=True, truncation_side='right', padding_side=self.padding_side)
153 | if self.args.mode == "train":
154 | self.decoder_model = AutoModelForCausalLM.from_pretrained(self.decoder_model_dir,
155 | trust_remote_code=self.args.trust_remote_code,
156 | torch_dtype=self.dtype,
157 | attn_implementation=self.attn_implementation,
158 | quantization_config=(
159 | BitsAndBytesConfig(
160 | load_in_4bit=(self.args.quantization == "4bit"),
161 | bnb_4bit_compute_dtype=torch.bfloat16,
162 | bnb_4bit_use_double_quant=True,
163 | bnb_4bit_quant_type="nf4",
164 | bnb_4bit_quant_storage=torch.bfloat16,
165 | )
166 | if self.args.quantization == "4bit"
167 | else None),
168 | )
169 | elif self.args.mode == "eval":
170 | self.decoder_model = AutoModelForCausalLM.from_pretrained(self.decoder_model_dir,
171 | trust_remote_code=self.args.trust_remote_code,
172 | torch_dtype=self.dtype,
173 | attn_implementation=self.attn_implementation,
174 | device_map="auto"
175 | )
176 | else:
177 |                 raise ValueError("args.mode can only be 'train' or 'eval'")
178 | assert self.decoder_model.config.hidden_size % self.tokens_for_each_chunk == 0
179 | self.pma_output_dim = int(self.decoder_model.config.hidden_size / self.tokens_for_each_chunk)
180 |             self.max_seq_len = min(32000, args.max_seq_len)-self.max_num_chunks-500 if self.args.train_retrieve or self.args.eval_retrieve else min(32000, args.max_seq_len)-self.max_num_chunks-60 # reserve 500 tokens for the template, the query, and the retrieved chunks
181 | self.system_mark = 'system'
182 | self.user_mark = 'user'
183 | self.assistant_mark = 'assistant'
184 | self.start_mark = '<|im_start|>'
185 | self.end_mark = '<|im_end|>'
186 | self.sep = '\n'
187 | else:
188 | raise NotImplementedError(f'NotImplementedError of {self.decoder_model_dir}')
189 |
190 | if 'bge-m3' in self.encoder_model_dir.split('/')[-1] or 'gte-large-en-v1.5' in self.encoder_model_dir.split('/')[-1] or 'gte-base-en-v1.5' in self.encoder_model_dir.split('/')[-1]:
191 | self.encoder_model = BaseBertModel(model_name_or_path=self.encoder_model_dir, max_seq_length=self.chunk_size, encoder_type=self.args.enc_pool_fn, pma_output_dim=self.pma_output_dim, tokens_for_each_chunk=self.tokens_for_each_chunk, ln=self.ln)
192 | else:
193 | raise NotImplementedError(f'UnImplemented Encoder Model :{self.encoder_model_dir}')
194 |
195 | self.encoder_emb_size = self.encoder_model.plm_model.config.hidden_size
196 | self.decoder_emb_size = self.decoder_model.config.hidden_size
197 | if self.proj_arch is not None:
198 | if self.args.enc_pool_fn != "PMA":
199 | self.projector = Projector(self.proj_arch, self.encoder_emb_size, self.decoder_emb_size, self.args.bf16)
200 | else:
201 | self.projector = Projector(self.proj_arch, self.decoder_emb_size, self.decoder_emb_size, self.args.bf16)
202 |
203 | self.loss_fct = nn.CrossEntropyLoss(ignore_index=-100, reduction='none')
204 | self.system_content = 'You are a helpful assistant.'
205 | self.user_prefix = 'Given the contexts:'
206 | self.user_suffix = 'Please follow the instruction:'
207 |
208 | self.user_prefix_prompt = f"{self.start_mark}{self.system_mark}{self.sep}{self.system_content}{self.end_mark}{self.sep}{self.start_mark}{self.user_mark}{self.sep}{self.user_prefix}"
209 | # self.user_prefix_prompt = f"{self.start_mark}{self.user_start_mark}{self.sep}{self.user_prefix}"
210 |
211 | self.user_prefix_input_ids = self.tokenizer(self.user_prefix_prompt, return_tensors='pt', add_special_tokens=False)['input_ids']
212 | self.user_prefix_prompt_len = len(self.user_prefix_input_ids[0]) # 18
213 |
214 |
215 | self.user_suffix_input_ids = self.tokenizer(self.user_suffix, return_tensors='pt', add_special_tokens=False)['input_ids']
216 | self.user_suffix_len = len(self.user_suffix_input_ids[0])
217 |
218 | self.start_mark_len = len(self.tokenizer(self.start_mark, return_tensors='pt', add_special_tokens=False)['input_ids'][0])
219 | self.end_mark_len = len(self.tokenizer(self.end_mark, return_tensors='pt', add_special_tokens=False)['input_ids'][0])
220 | self.assistant_mark_len = len(self.tokenizer(self.assistant_mark, return_tensors='pt', add_special_tokens=False)['input_ids'][0])
221 |
222 |
223 |         self.sep_id = self.tokenizer(self.sep, return_tensors='pt', add_special_tokens=False)['input_ids'][0] # [0] pulls out the id list ([id] or [id, id])
224 | self.sep_len = len(self.tokenizer(self.sep, return_tensors='pt', add_special_tokens=False)['input_ids'][0])
225 | self.sft_end_marker = self.tokenizer.eos_token
226 | self.sft_end_marker_id = self.tokenizer.eos_token_id
227 | self.sft_end_marker_len = 1
228 |
229 | def update_queries(self, queries, retrieved_chunks):
230 | queries = ['Some information that may be useful:' + '\n'.join(chunks) + '\nQuestion:' + query for query, chunks in zip(queries, retrieved_chunks)]
231 | return queries
232 |
233 |
234 | def encode(self, context_batch, mode):
235 |
236 | bs = len(context_batch)
237 | num_of_chunks_bs = [min(len(chunks), self.max_num_chunks) for chunks in context_batch]
238 | attn_mask_chunks = torch.zeros((bs, self.max_num_chunks))
239 | cols = torch.arange(attn_mask_chunks.size(1)).unsqueeze(0).expand_as(attn_mask_chunks)
240 | mask = cols < torch.LongTensor(num_of_chunks_bs).unsqueeze(1)
241 | attn_mask_chunks[mask] = 1
242 | chunks_flatten = [chunk for chunks in context_batch for chunk in chunks]
243 | if mode == 'train':
244 | chunks_emb = self.encoder_model.encode_with_grad(chunks_flatten, batch_size=50) # (num_all_chunks, emb)
245 | else:
246 | chunks_emb = self.encoder_model.encode_without_grad(chunks_flatten, batch_size=50) # (num_all_chunks, emb)
247 |
248 |
249 | res_chunks_emb = torch.zeros((bs, self.max_num_chunks, chunks_emb.size(-1)), dtype=chunks_emb.dtype)
250 | start = 0
251 | for b_id in range(bs):
252 | end = start + num_of_chunks_bs[b_id]
253 | res_chunks_emb[b_id][:num_of_chunks_bs[b_id]] = chunks_emb[start:end]
254 | start = end
255 |
256 |
257 | return res_chunks_emb, num_of_chunks_bs, attn_mask_chunks
258 |
259 |
260 |
261 | def decode_clm_template(self, aligned_emb, attn_mask_chunks, queries, answers, task_ids, retrieved_chunks=None):
262 | loss_task_weight = {0:self.alpha, 1:1}
263 | def loss_multiplier(idx):
264 | return loss_task_weight[idx]
265 |
266 | if retrieved_chunks:
267 | queries = self.update_queries(queries, retrieved_chunks)
268 |
269 | bs = len(queries)
270 |
271 | queries_ids = self.tokenizer(queries, padding=False, add_special_tokens=False)['input_ids']
272 | queries_lens = torch.LongTensor([len(query) for query in queries_ids]).unsqueeze(1)
273 |
274 | user_prefix_attn_mask = torch.ones((bs, self.user_prefix_prompt_len))
275 | user_prefix_input_ids_ = self.user_prefix_input_ids.long().to(self.decoder_model.device)
276 | if self.peft_fn in ['lora', 'qlora']:
277 | user_prefix_input_embs = self.decoder_model.get_base_model().model.embed_tokens(user_prefix_input_ids_)
278 | else:
279 |             user_prefix_input_embs = self.decoder_model.model.embed_tokens(user_prefix_input_ids_)
280 | user_prefix_input_embs = user_prefix_input_embs.expand(bs, -1, -1)
281 |
282 |
283 | assistant_prompts = [f"{self.sep}{self.user_suffix}{queries[i]}{self.end_mark}{self.sep}{self.start_mark}{self.assistant_mark}{self.sep}{answers[i]}{self.sft_end_marker}" for i in range(len(queries))]
284 | assistant_inputs = self.tokenizer(assistant_prompts, padding='max_length', max_length=self.max_seq_len, truncation=True, add_special_tokens=False, return_tensors='pt')
285 | assistant_ids = assistant_inputs['input_ids'].to(self.decoder_model.device) # bs, seq
286 | assistant_attn_mask = assistant_inputs['attention_mask'].to(self.decoder_model.device) # bs, seq
287 |
288 | if self.peft_fn in ['lora', 'qlora']:
289 | assistant_embs = self.decoder_model.get_base_model().model.embed_tokens(assistant_ids.long())
290 | else:
291 | assistant_embs = self.decoder_model.model.embed_tokens(assistant_ids.long())
292 |
293 | chunk_nums = attn_mask_chunks.sum(-1)
294 |
295 | new_embs = vectorized_split_and_concat(aligned_emb, chunk_nums, user_prefix_input_embs, self.args.bf16, self.decoder_model.device)
296 |
297 | input_embs = torch.cat((new_embs, assistant_embs), dim=1)
298 | pad_nums = (self.max_num_chunks-chunk_nums).unsqueeze(1)
299 | attn_mask_cols = torch.arange(self.user_prefix_prompt_len+self.max_num_chunks).unsqueeze(0).expand(bs, -1)
300 | if self.args.bf16:
301 | new_attn_mask = torch.ones(attn_mask_cols.size(), device=self.decoder_model.device).to(torch.bfloat16)
302 | else:
303 | new_attn_mask = torch.ones(attn_mask_cols.size(), device=self.decoder_model.device)
304 | new_attn_mask = new_attn_mask.masked_fill((attn_mask_cols