├── README.md └── LLM Self-Data Augmentation ├── finetune ├── sft │ ├── test.py │ ├── ds_config_zero2.json │ ├── ds_config_zero3.json │ ├── finetune.sh │ └── finetune.py ├── init.py ├── start_lora.py └── apilora.py ├── requirements.txt ├── readme.md ├── score.py └── main.py /README.md: -------------------------------------------------------------------------------- 1 | # Event-Extraction -------------------------------------------------------------------------------- /LLM Self-Data Augmentation/finetune/sft/test.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | with open('/home/ubri/llm/qwen/data/test.json', 'r') as f: 4 | data = json.load(f) 5 | print(data) 6 | -------------------------------------------------------------------------------- /LLM Self-Data Augmentation/requirements.txt: -------------------------------------------------------------------------------- 1 | accelerate 2 | bitsandbytes 3 | deepspeed 4 | flask 5 | Levenshtein 6 | peft 7 | requests 8 | torch 9 | tqdm 10 | transformers -------------------------------------------------------------------------------- /LLM Self-Data Augmentation/finetune/init.py: -------------------------------------------------------------------------------- 1 | from transformers import AutoConfig, AutoModelForCausalLM 2 | 3 | # 定义模型配置 4 | config = AutoConfig.from_pretrained("./qwen0.5b_15") # 示例:假设Qwen1.5_0.5B的配置类似于GPT-2,但具体层数和尺寸需要你根据实际情况设置 5 | 6 | # 随机初始化模型 7 | model = AutoModelForCausalLM.from_config(config) 8 | 9 | # 保存未训练的模型 10 | model.save_pretrained('./tmp') 11 | 12 | # 保存配置(如果需要) 13 | config.save_pretrained('./tmp') 14 | -------------------------------------------------------------------------------- /LLM Self-Data Augmentation/finetune/start_lora.py: -------------------------------------------------------------------------------- 1 | import subprocess 2 | import time 3 | from flask import Flask, request, jsonify 4 | 5 | app = Flask(__name__) 6 | # 启动脚本 7 | 8 
| 9 | 10 | process = None 11 | @app.route('/start', methods=['GET']) 12 | def start(): 13 | global process 14 | if process is not None: 15 | process.terminate() 16 | process.wait() 17 | 18 | process=subprocess.Popen(['python', './apilora.py']) 19 | 20 | return jsonify({'msg': 'ok'}), 200 21 | 22 | 23 | if __name__ == '__main__': 24 | app.run(debug=False, host='0.0.0.0', port=8016) -------------------------------------------------------------------------------- /LLM Self-Data Augmentation/finetune/sft/ds_config_zero2.json: -------------------------------------------------------------------------------- 1 | { 2 | "fp16": { 3 | "enabled": "auto", 4 | "loss_scale": 0, 5 | "loss_scale_window": 1000, 6 | "initial_scale_power": 16, 7 | "hysteresis": 2, 8 | "min_loss_scale": 1 9 | }, 10 | "bf16": { 11 | "enabled": "auto" 12 | }, 13 | "optimizer": { 14 | "type": "AdamW", 15 | "params": { 16 | "lr": "auto", 17 | "betas": "auto", 18 | "eps": "auto", 19 | "weight_decay": "auto" 20 | } 21 | }, 22 | 23 | "scheduler": { 24 | "type": "WarmupLR", 25 | "params": { 26 | "warmup_min_lr": "auto", 27 | "warmup_max_lr": "auto", 28 | "warmup_num_steps": "auto" 29 | } 30 | }, 31 | 32 | "zero_optimization": { 33 | "stage": 2, 34 | "offload_optimizer": { 35 | "device": "none", 36 | "pin_memory": true 37 | }, 38 | "allgather_partitions": true, 39 | "allgather_bucket_size": 2e8, 40 | "overlap_comm": true, 41 | "reduce_scatter": true, 42 | "reduce_bucket_size": 2e8, 43 | "contiguous_gradients": true 44 | }, 45 | 46 | "gradient_accumulation_steps": "auto", 47 | "gradient_clipping": "auto", 48 | "steps_per_print": 100, 49 | "train_batch_size": "auto", 50 | "train_micro_batch_size_per_gpu": "auto", 51 | "wall_clock_breakdown": false 52 | } -------------------------------------------------------------------------------- /LLM Self-Data Augmentation/finetune/sft/ds_config_zero3.json: -------------------------------------------------------------------------------- 1 | { 2 | "fp16": { 3 | "enabled": 
"auto", 4 | "loss_scale": 0, 5 | "loss_scale_window": 1000, 6 | "initial_scale_power": 16, 7 | "hysteresis": 2, 8 | "min_loss_scale": 1 9 | }, 10 | "bf16": { 11 | "enabled": "auto" 12 | }, 13 | "optimizer": { 14 | "type": "AdamW", 15 | "params": { 16 | "lr": "auto", 17 | "betas": "auto", 18 | "eps": "auto", 19 | "weight_decay": "auto" 20 | } 21 | }, 22 | 23 | "scheduler": { 24 | "type": "WarmupLR", 25 | "params": { 26 | "warmup_min_lr": "auto", 27 | "warmup_max_lr": "auto", 28 | "warmup_num_steps": "auto" 29 | } 30 | }, 31 | 32 | "zero_optimization": { 33 | "stage": 3, 34 | "offload_optimizer": { 35 | "device": "none", 36 | "pin_memory": true 37 | }, 38 | "offload_param": { 39 | "device": "none", 40 | "pin_memory": true 41 | }, 42 | "overlap_comm": true, 43 | "contiguous_gradients": true, 44 | "sub_group_size": 1e9, 45 | "reduce_bucket_size": "auto", 46 | "stage3_prefetch_bucket_size": "auto", 47 | "stage3_param_persistence_threshold": "auto", 48 | "stage3_max_live_parameters": 1e9, 49 | "stage3_max_reuse_distance": 1e9, 50 | "stage3_gather_16bit_weights_on_model_save": true 51 | }, 52 | 53 | "gradient_accumulation_steps": "auto", 54 | "gradient_clipping": "auto", 55 | "steps_per_print": 100, 56 | "train_batch_size": "auto", 57 | "train_micro_batch_size_per_gpu": "auto", 58 | "wall_clock_breakdown": false 59 | } 60 | -------------------------------------------------------------------------------- /LLM Self-Data Augmentation/finetune/apilora.py: -------------------------------------------------------------------------------- 1 | from flask import Flask, request, jsonify 2 | import os 3 | os.environ['CUDA_VISIBLE_DEVICES'] = '0' 4 | 5 | from peft import AutoPeftModelForCausalLM 6 | from transformers import AutoTokenizer 7 | import gc 8 | 9 | app = Flask(__name__) 10 | # 加载第一个模型和分词器 11 | model = AutoPeftModelForCausalLM.from_pretrained("./sft/output_qwen/1/checkpoint-600", device_map="auto").eval() 12 | tokenizer = 
AutoTokenizer.from_pretrained("./sft/output_qwen/1/checkpoint-600") 13 | device="cuda" 14 | 15 | @app.route('/hello', methods=['GET']) 16 | def hello(): 17 | return jsonify({'msg': 'ok'}), 200 18 | 19 | 20 | @app.route('/ask', methods=['POST']) 21 | def ask(): 22 | data = request.json 23 | question = data.get('question') 24 | system = data.get('system','You are a helpful assistant.') 25 | history = data.get('history') 26 | print("Q:" + question) 27 | if not question: 28 | return jsonify({'error': 'No question provided'}), 400 29 | 30 | try: 31 | messages = [ 32 | {"role": "system", "content": system} 33 | ] 34 | for item in history: 35 | messages.append({'role': 'user', 'content':item[0]}) 36 | messages.append({'role': 'bot', 'content':item[1]}) 37 | messages.append({'role': 'user', 'content':question}) 38 | text = tokenizer.apply_chat_template( 39 | messages, 40 | tokenize=False, 41 | add_generation_prompt=True 42 | ) 43 | model_inputs = tokenizer([text], return_tensors="pt").to(device) 44 | 45 | generated_ids = model.generate( 46 | model_inputs.input_ids, 47 | max_new_tokens=4096 48 | ) 49 | generated_ids = [ 50 | output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids) 51 | ] 52 | 53 | response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0] 54 | print("A:" + response) 55 | return jsonify({'answer': response}) 56 | except Exception as e: 57 | print(e) 58 | return jsonify({'error': str(e)}), 500 59 | 60 | if __name__ == '__main__': 61 | app.run(debug=False, host='0.0.0.0', port=8012) 62 | -------------------------------------------------------------------------------- /LLM Self-Data Augmentation/finetune/sft/finetune.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | export CUDA_DEVICE_MAX_CONNECTIONS=1 3 | DIR=`pwd` 4 | 5 | # Guide: 6 | # This script supports distributed training on multi-gpu workers (as well as single-worker training). 
7 | # Please set the options below according to the comments. 8 | # For multi-gpu workers training, these options should be manually set for each worker. 9 | # After setting the options, please run the script on each worker. 10 | 11 | # Number of GPUs per GPU worker 12 | GPUS_PER_NODE=$(python -c 'import torch; print(torch.cuda.device_count())') 13 | 14 | # Number of GPU workers, for single-worker training, please set to 1 15 | NNODES=${NNODES:-1} 16 | 17 | # The rank of this worker, should be in {0, ..., WORKER_CNT-1}, for single-worker training, please set to 0 18 | NODE_RANK=${NODE_RANK:-0} 19 | 20 | # The ip address of the rank-0 worker, for single-worker training, please set to localhost 21 | MASTER_ADDR=${MASTER_ADDR:-localhost} 22 | 23 | # The port for communication 24 | MASTER_PORT=${MASTER_PORT:-6001} 25 | 26 | MODEL="/home/risa/qwen15/qwen14b_15" # Set the path if you do not want to load from huggingface directly 27 | # ATTENTION: specify the path to your training data, which should be a json file consisting of a list of conversations. 28 | # See the section for finetuning in README for more information. 
29 | DATA="/home/risa/qwen15/sft/data/1.jsonl" 30 | DS_CONFIG_PATH="/home/risa/qwen15/sft/ds_config_zero3.json" 31 | USE_LORA=True 32 | Q_LORA=False 33 | 34 | function usage() { 35 | echo ' 36 | Usage: bash /home/ubri/llm/qwen/sft/finetune_lora_ds.sh [-m MODEL_PATH] [-d DATA_PATH] [--deepspeed DS_CONFIG_PATH] [--use_lora USE_LORA] [--q_lora Q_LORA] 37 | ' 38 | } 39 | 40 | while [[ "$1" != "" ]]; do 41 | case $1 in 42 | -m | --model ) 43 | shift 44 | MODEL=$1 45 | ;; 46 | -d | --data ) 47 | shift 48 | DATA=$1 49 | ;; 50 | --deepspeed ) 51 | shift 52 | DS_CONFIG_PATH=$1 53 | ;; 54 | --use_lora ) 55 | shift 56 | USE_LORA=$1 57 | ;; 58 | --q_lora ) 59 | shift 60 | Q_LORA=$1 61 | ;; 62 | -h | --help ) 63 | usage 64 | exit 0 65 | ;; 66 | * ) 67 | echo "Unknown argument ${1}" 68 | exit 1 69 | ;; 70 | esac 71 | shift 72 | done 73 | 74 | DISTRIBUTED_ARGS=" 75 | --nproc_per_node $GPUS_PER_NODE \ 76 | --nnodes $NNODES \ 77 | --node_rank $NODE_RANK \ 78 | --master_addr $MASTER_ADDR \ 79 | --master_port $MASTER_PORT 80 | " 81 | 82 | torchrun $DISTRIBUTED_ARGS finetune.py \ 83 | --model_name_or_path $MODEL \ 84 | --data_path $DATA \ 85 | --bf16 True \ 86 | --output_dir output_qwen/1 \ 87 | --num_train_epochs 1 \ 88 | --per_device_train_batch_size 4 \ 89 | --per_device_eval_batch_size 1 \ 90 | --gradient_accumulation_steps 2 \ 91 | --evaluation_strategy "no" \ 92 | --save_strategy "steps" \ 93 | --save_steps 100 \ 94 | --save_total_limit 600 \ 95 | --learning_rate 3e-4 \ 96 | --weight_decay 0.01 \ 97 | --adam_beta2 0.95 \ 98 | --warmup_ratio 0.01 \ 99 | --lr_scheduler_type "cosine" \ 100 | --logging_steps 1 \ 101 | --report_to "none" \ 102 | --model_max_length 4096 \ 103 | --lazy_preprocess True \ 104 | --use_lora ${USE_LORA} \ 105 | --q_lora ${Q_LORA} \ 106 | --gradient_checkpointing \ 107 | --deepspeed ${DS_CONFIG_PATH} 108 | -------------------------------------------------------------------------------- /LLM Self-Data Augmentation/readme.md: 
-------------------------------------------------------------------------------- 1 | # **基于大型语言模型自数据增强的事件抽取** 2 | 3 | ## **1\. 简介** 4 | 5 | 本论文核心思想是利用大型语言模型(LLM)的自数据增强能力来扩充训练数据,然后使用增强后的数据对模型(如 Qwen)进行监督式微调(SFT)。项目不仅包含了数据增强和模型微调的代码,还实现了一个轻量级的 Flask API 来部署微调后的 LoRA 模型,以及一个评估脚本来计算抽取结果的精确率、召回率和 F1 分数。 6 | 7 | ## **2\. 文件结构** 8 | 9 | . 10 | ├── main.py \# 实验主入口:用于数据增强、调用模型生成结果 11 | ├── score.py \# 评估脚本:计算 P/R/F1 分数 12 | ├── data/ \# 存放原始数据和增强数据 13 | │ └── ... 14 | ├── output/ \# 存放模型生成的预测结果 15 | │ └── ... 16 | ├── finetune/ 17 | │ ├── apilora.py \# LoRA 模型的 API 接口 (Flask) 18 | │ ├── init.py \# 模型初始化脚本 19 | │ ├── start\_lora.py \# 启动 API 服务的管理脚本 20 | │ └── sft/ 21 | │ ├── finetune.py \# 监督式微调 (SFT) 脚本 22 | │ └── test.py \# 简单的测试脚本 23 | └── README.md \# 本文档 24 | 25 | ## **3\. 主要模块功能** 26 | 27 | * **main.py** 28 | * **数据增强 (main2)**:实现了论文中的“反向数据增强” (Reverse Data Augmentation) 流程。它调用 ask\_openai 接口(需配置 openai\_api.py),先修改原始事件 JSON (Prompt P1'),再让 LLM 根据修改后的 JSON 重建文本 (Prompt P2'),从而生成新的 (text, event\_list) 数据对。 29 | * **批量推理 (main3, main4)**:加载测试数据,调用已部署的模型 API (apilora.py) 或 OpenAI API 进行提问,并将结果保存到 output/ 目录以便后续评估。 30 | * **finetune/sft/finetune.py** 31 | * 核心的模型微调脚本。 32 | * 基于 transformers 和 peft 库,使用 LoRA (Low-Rank Adaptation) 技术对大模型(如 Qwen)进行监督式微调。 33 | * 它会加载 main.py 生成的 jsonl 格式数据,并执行训练过程,最终保存模型的 checkpoint。 34 | * **finetune/apilora.py** 35 | * 一个 Flask Web 服务器,用于加载微调后的 LoRA 模型 checkpoint。 36 | * 提供 /ask 接口,接收文本输入,并返回模型的推理结果。 37 | * **finetune/start\_lora.py** 38 | * 用于管理和启动 apilora.py 服务的脚本。 39 | * 它提供 /start 接口,可以在不中断主服务的情况下,重新加载或切换 apilora.py 所使用的模型 checkpoint(通过修改 file\_path.txt 实现)。 40 | * **score.py** 41 | * 评估脚本,用于对比模型的预测结果和“黄金标准” (Ground Truth) 答案。 42 | * 它会加载 data/ 中的 gt\_...jsonl 和 output/ 中的预测文件。 43 | * 实现了 best\_match 贪心算法和字符串相似度(LCS、Levenshtein)来计算事件抽取任务的**精确率 (Precision)**、**召回率 (Recall)** 和 **F1 分数**。 44 | 45 | ## **4\. 基本使用流程** 46 | 47 | ### **步骤 1: 数据增强** 48 | 49 | 1. 准备好原始数据集。 50 | 2. 
配置 openai\_api.py(或 main.py 中的 ask\_openai 函数)使其能够访问一个强大的 LLM(如 GPT-4)。 51 | 3. 运行 main.py 中的 main2 函数: 52 | \# (可能需要修改 main.py 的 \_\_name\_\_ \== '\_\_main\_\_' 部分来调用 main2) 53 | python main.py 54 | 55 | 4. 此步骤将生成增强后的训练数据(如 data/train5000\_augmented\_2.jsonl)和测试数据。 56 | 57 | ### **步骤 2: 模型微调** 58 | 59 | 1. 使用 finetune/sft/finetune.py 脚本来微调模型。 60 | \# 示例命令(具体参数请参照 finetune.py 中的 TrainingArguments) 61 | deepspeed finetune/sft/finetune.py \\ 62 | \--model\_name\_or\_path Qwen/Qwen-7B-Chat \\ 63 | \--data\_path ./data/train5000\_augmented\_2.jsonl \\ 64 | \--output\_dir ./finetune/sft/output\_qwen/my\_model \\ 65 | \--num\_train\_epochs 5 \\ 66 | \--per\_device\_train\_batch\_size 8 \\ 67 | \--use\_lora True \\ 68 | \--q\_lora True \\ 69 | \--gradient\_checkpointing True \\ 70 | \--deepspeed ./finetune/sft/ds\_config\_zero2.json 71 | 72 | 2. 训练完成后,LoRA checkpoint 将保存在 output\_dir 中。 73 | 74 | ### **步骤 3: 启动模型 API** 75 | 76 | 1. 修改 finetune/apilora.py,将 from\_pretrained 的路径指向你训练好的 checkpoint (如 ./sft/output\_qwen/my\_model/checkpoint-xxx)。 77 | 2. 首先在 finetune/ 目录下启动模型管理服务(start\_lora.py 通过相对路径 ./apilora.py 启动子进程,因此必须在该目录下运行): 78 | cd finetune && python start\_lora.py 79 | 80 | 3. 然后访问 http://127.0.0.1:8016/start 来启动 apilora.py 服务(该服务运行在 8012 端口)。 81 | 82 | ### **步骤 4: 推理与评估** 83 | 84 | 1. 运行 main.py 中的 main4 函数(确保 ask 函数的 URL 为 http://127.0.0.1:8012/ask)。 85 | \# (可能需要修改 main.py 的 \_\_name\_\_ \== '\_\_main\_\_' 部分来调用 main4) 86 | python main.py 87 | 88 | 2. 该脚本会调用在 8012 端口运行的模型 API,对测试集进行推理,并将结果保存到 output/ 目录。 89 | 3.
最后,运行 score.py 来计算最终得分。 90 | \# (修改 score.py 中 evaluate 函数的文件路径) 91 | python score.py 92 | 93 | 输出示例: 94 | Precision=XX.XX%, Recall=XX.XX%, F1=XX.XX% -------------------------------------------------------------------------------- /LLM Self-Data Augmentation/score.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Author: Derry 3 | Date: 2024-04-11 05:02:03 4 | LastEditors: Derry 5 | Email: drlv@mail.ustc.edu.cn 6 | LastEditTime: 2024-04-12 23:29:33 7 | FilePath: /LLM/score1.py 8 | Description: Coding by drlv of USTC 9 | ''' 10 | import Levenshtein 11 | import json 12 | import difflib 13 | from itertools import permutations 14 | from tqdm import tqdm 15 | import random 16 | 17 | 18 | def load_jsonl(file_path): 19 | data = [] 20 | with open(file_path, 'r', encoding='utf-8') as f: 21 | for line in f: 22 | try: 23 | data.append(json.loads(line)) 24 | except: 25 | data.append(None) 26 | return data 27 | 28 | 29 | def load_json(file_path): 30 | with open(file_path, 'r', encoding='utf-8') as f: 31 | return json.load(f) 32 | 33 | 34 | def save_json(file_path, data): 35 | with open(file_path, 'w', encoding='utf-8') as f: 36 | json.dump(data, f, indent=4, ensure_ascii=False) 37 | 38 | 39 | def string_similar1(s1, s2): 40 | if s1 == s2: 41 | return 1 42 | # 最长公共子序列 43 | return difflib.SequenceMatcher(None, s1, s2).ratio() 44 | 45 | 46 | def string_similar2(s1, s2): 47 | # 使用编辑距离度量 48 | return 1 - 2 * Levenshtein.distance(s1, s2) / (len(s1) + len(s2)) 49 | 50 | 51 | def calculate_metrics(correct, total_gt, total_pred): 52 | precision = correct / total_pred if total_pred else 0 53 | recall = correct / total_gt if total_gt else 0 54 | f1 = 2 * precision * recall / (precision + recall) if ( 55 | precision + recall) else 0 56 | return precision, recall, f1 57 | 58 | 59 | def compare_annotations(gt_annotation, pred_annotation): 60 | correct, total_gt, total_pred = 0, 0, 0 61 | 62 | if type(gt_annotation) is dict: 63 | total_gt += 
len(gt_annotation) 64 | if type(pred_annotation) is dict: 65 | total_pred += len(pred_annotation) 66 | 67 | if total_gt == 0 or total_pred == 0: 68 | return correct, total_gt, total_pred 69 | 70 | if type(gt_annotation) is dict and type(pred_annotation) is dict: 71 | total_keys = list(set(gt_annotation.keys() | pred_annotation.keys())) 72 | # if 'mention' in total_keys: 73 | # total_keys.remove('mention') 74 | for key in total_keys: 75 | if key in gt_annotation and key in pred_annotation: 76 | cur_gt = gt_annotation[key] 77 | cur_pred = pred_annotation[key] 78 | if type(cur_gt) is list: 79 | cur_gt = ''.join(cur_gt) 80 | if type(cur_pred) is list: 81 | cur_pred = ''.join(cur_pred) 82 | try: 83 | if string_similar1(cur_gt, cur_pred) >= 0.6 or string_similar2(cur_gt, cur_pred) >= 0.6: 84 | correct += 1 85 | except: 86 | pass 87 | 88 | return correct, total_gt, total_pred 89 | 90 | 91 | 92 | def best_match(gt_annotations, pred_annotations): 93 | ret_correct, ret_gt_count, ret_pred_count = 0, 0, 0 94 | 95 | gt_tmp = [] 96 | pred_tmp = [] 97 | #print(pred_annotations) 98 | for i in range(len(gt_annotations)): 99 | gt_tmp.append(gt_annotations[i]) 100 | 101 | for i in range(len(pred_annotations)): 102 | pred_tmp.append(pred_annotations[i]) 103 | while True: 104 | if len(gt_tmp) == 0 or len(pred_tmp) == 0: 105 | break 106 | cur_corrent = -1 107 | cur_gt_count = 0 108 | cur_pred_count = 0 109 | 110 | if len(gt_tmp) > len(pred_tmp): 111 | cur_pred_index = 0 112 | for i in range(len(gt_tmp)): 113 | cur_gt = gt_tmp[i] 114 | cur_pred = pred_tmp[cur_pred_index] 115 | correct, gt_count, pred_count = compare_annotations( 116 | cur_gt, cur_pred) 117 | if correct > cur_corrent: 118 | cur_gt_index = i 119 | cur_corrent = correct 120 | cur_gt_count = gt_count 121 | cur_pred_count = pred_count 122 | else: 123 | cur_gt_index = 0 124 | for i in range(len(pred_tmp)): 125 | cur_gt = gt_tmp[cur_gt_index] 126 | cur_pred = pred_tmp[i] 127 | correct, gt_count, pred_count = 
compare_annotations( 128 | cur_gt, cur_pred) 129 | if correct > cur_corrent: 130 | cur_pred_index = i 131 | cur_corrent = correct 132 | cur_gt_count = gt_count 133 | cur_pred_count = pred_count 134 | 135 | ret_correct += cur_corrent 136 | ret_gt_count += cur_gt_count 137 | ret_pred_count += cur_pred_count 138 | # print("############################################") 139 | # print(cur_gt_index) 140 | # print(cur_pred_index) 141 | gt_tmp.pop(cur_gt_index) 142 | pred_tmp.pop(cur_pred_index) 143 | 144 | for i in range(len(pred_tmp)): 145 | correct, gt_count, pred_count = compare_annotations(None, pred_tmp[i]) 146 | ret_correct += correct 147 | ret_gt_count += gt_count 148 | ret_pred_count += pred_count 149 | 150 | for i in range(len(gt_tmp)): 151 | correct, gt_count, pred_count = compare_annotations(gt_tmp[i], None) 152 | ret_correct += correct 153 | ret_gt_count += gt_count 154 | ret_pred_count += pred_count 155 | 156 | return ret_correct, ret_gt_count, ret_pred_count 157 | 158 | 159 | 160 | def evaluate(gt_file, pred_file): 161 | gt_data = load_jsonl(gt_file) 162 | pred_data = load_jsonl(pred_file) 163 | total_correct, total_gt, total_pred = 0, 0, 0 164 | 165 | for i in range(len(gt_data)): 166 | print(i) 167 | cur_gt = gt_data[i] 168 | cur_pred = pred_data[i] 169 | if cur_gt == None: 170 | cur_gt=[] 171 | if cur_pred == None: 172 | cur_pred=[] 173 | correct, gt_num, pred_num = best_match( 174 | cur_gt, cur_pred) 175 | total_correct += correct 176 | total_gt += gt_num 177 | total_pred += pred_num 178 | 179 | print(total_correct, total_gt, total_pred) 180 | precision, recall, f1 = calculate_metrics( 181 | total_correct, total_gt, total_pred) 182 | return precision, recall, f1 183 | 184 | 185 | 186 | # todo = range(226, 236) # [71,72,73,74,75,76,77,78,79,80] 187 | # 188 | # for i in todo: 189 | # precision, recall, f1 = evaluate(f'./data/p0_gt.jsonl', f'./output/{i}.jsonl') 190 | # print(f'Precision={100 * precision:.2f}%, Recall={100 * recall:.2f}%, F1={100 * 
f1:.2f}%') 191 | 192 | precision, recall, f1 = evaluate('./data/p0_gtns.jsonl', './output/p0_qwen_14.jsonl') 193 | print(f'Precision={100*precision:.2f}%, Recall={100*recall:.2f}%, F1={100*f1:.2f}%') 194 | -------------------------------------------------------------------------------- /LLM Self-Data Augmentation/main.py: -------------------------------------------------------------------------------- 1 | import requests 2 | import json 3 | import os 4 | import threading 5 | from queue import Queue 6 | from concurrent.futures import ThreadPoolExecutor 7 | from functools import partial 8 | import random 9 | import time 10 | import openai_api 11 | 12 | prompt_z = "请抽取出文本中的事件,如果有的元素抽取不出来请置为\"\"," \ 13 | '输出严格遵守如下的json具体格式输出:[{"class": "事件类别","actor": "行为方","action": "行为","object": "受体","time": "时间","location": "地点"},…]\n' \ 14 | "文本如下:\n" 15 | 16 | def ask_openai(question): 17 | response= openai_api.query4(question) 18 | return response 19 | 20 | def ask(question,history=[],url="http://127.0.0.1:8012/ask"): 21 | data = {'question': question, 'history': history} 22 | response = requests.post(url, json=data) 23 | return response.json()['answer'] 24 | 25 | def load_data(): 26 | with open("./data/p0_gt.json",'r',encoding='utf-8') as f: 27 | data=json.load(f) 28 | 29 | ret=[] 30 | for item in data: 31 | ret.append((item["text"],item['event_list'])) 32 | 33 | 34 | random.seed(42) 35 | random.shuffle(ret) 36 | 37 | return ret 38 | 39 | def load_data2(): 40 | with open("./data/5.json",'r',encoding='utf-8') as f: 41 | data=json.load(f) 42 | 43 | ret=[] 44 | for item in data: 45 | ret.append((item["text"],item['casuality_list'])) 46 | 47 | to_pop=[] 48 | for item in ret: 49 | if len(item[1])>=4: 50 | to_pop.append(item) 51 | 52 | for item in to_pop: 53 | ret.remove(item) 54 | 55 | random.seed(42) 56 | random.shuffle(ret) 57 | 58 | return ret 59 | 60 | def list2jsonl(list_): 61 | return_str="" 62 | system="你是一个因果分析专家,擅长发现文本中的因果关系。" 63 | for i,item in enumerate(list_): 64 | 
print(i) 65 | q=item[0] 66 | a=item[1] 67 | tmp_dict={"type": "chatml","messages": [{"role": "system","content": system},{"role": "user","content": q},{"role": "assistant","content": a}],"source": "unknown"} 68 | return_str += json.dumps(tmp_dict,ensure_ascii=False) + "\n" 69 | return return_str 70 | 71 | def test_hello_route(): 72 | # 设置目标 URL 73 | url = 'http://127.0.0.1:8012/hello' 74 | 75 | # 发送 GET 请求 76 | try: 77 | response = requests.get(url) 78 | except: 79 | return False 80 | 81 | # 检查响应状态码 82 | if response.status_code == 200: 83 | return True 84 | else: 85 | return False 86 | 87 | def start_model(): 88 | # 设置目标 URL 89 | url = 'http://127.0.0.1:8016/start' 90 | 91 | # 发送 GET 请求 92 | response = requests.get(url) 93 | 94 | # 检查响应状态码 95 | if response.status_code == 200: 96 | print("请求成功!") 97 | # 打印响应内容(JSON) 98 | print("响应内容:", response.json()) 99 | else: 100 | print("请求失败,状态码:", response.status_code) 101 | 102 | to_eval=['/home/ubri/llm/qwen/sft/output_qwen/q72_5000_1e4_3/checkpoint-84','/home/ubri/llm/qwen/sft/output_qwen/q72_5000_1e4_3/checkpoint-168', 103 | '/home/ubri/llm/qwen/sft/output_qwen/q72_5000_1e4_3/checkpoint-252','/home/ubri/llm/qwen/sft/output_qwen/q72_5000_1e4_3/checkpoint-336', 104 | '/home/ubri/llm/qwen/sft/output_qwen/q72_5000_1e4_3/checkpoint-420','/home/ubri/llm/qwen/sft/output_qwen/q72_5000_1e4_3/checkpoint-504', 105 | '/home/ubri/llm/qwen/sft/output_qwen/q72_5000_1e4_3/checkpoint-588','/home/ubri/llm/qwen/sft/output_qwen/q72_5000_1e4_3/checkpoint-672', 106 | '/home/ubri/llm/qwen/sft/output_qwen/q72_5000_1e4_3/checkpoint-756','/home/ubri/llm/qwen/sft/output_qwen/q72_5000_1e4_3/checkpoint-840'] 107 | 108 | to_save=['/home/ubri/llm/tmp1/output/20/','/home/ubri/llm/tmp1/output/21/', 109 | '/home/ubri/llm/tmp1/output/22/','/home/ubri/llm/tmp1/output/23/', 110 | '/home/ubri/llm/tmp1/output/24/','/home/ubri/llm/tmp1/output/25/', 111 | '/home/ubri/llm/tmp1/output/26/','/home/ubri/llm/tmp1/output/27/', 112 | 
'/home/ubri/llm/tmp1/output/28/','/home/ubri/llm/tmp1/output/29/',] 113 | 114 | def main4(): 115 | data=load_data()[:100] 116 | for i in range(len(to_eval)): 117 | with open('/home/ubri/llm/qwen/change_model/file_path.txt','w',encoding='utf-8') as f: 118 | f.write(to_eval[i]) 119 | start_model() 120 | while True: 121 | if test_hello_route(): 122 | break 123 | time.sleep(1) 124 | if not os.path.exists(to_save[i]): 125 | os.mkdir(to_save[i]) 126 | for j,item in enumerate(data): 127 | print(i,j) 128 | # print(item[0]) 129 | answer=ask(prompt_z+item[0]) 130 | with open(to_save[i]+str(j)+'.txt','w',encoding='utf-8') as f: 131 | f.write(answer) 132 | print(answer) 133 | 134 | 135 | def main():#问问题 136 | data=load_data()[:100] 137 | to_do=["./sft/output_qwen/q72_5000_1e4_2/checkpoint-84"] 138 | to_save=["./output/20/"] 139 | for ii in range(len(to_do)): 140 | if not os.path.exists(to_save[ii]): 141 | os.mkdir(to_save[ii]) 142 | for i,item in enumerate(data): 143 | print(i) 144 | print(item[0]) 145 | 146 | 147 | def main2(): # 生成数据集 148 | 149 | prompt_p1 = """%EVENT_EXTRACTION% 150 | 上述内容是JSON格式的从文本中提取的几个事件。请根据以下要求进行重写提取: 151 | 1. 保持原有的JSON格式。JSON对象的字段(class, actor, action, object, time, location)不得更改,但JSON数组的大小可以修改。 152 | 2. 请修改事件的 "class"(事件类别),并相应地修改 "actor"(行为方), "action"(行为), "object"(受体), "time"(时间), 和 "location"(地点)。 153 | 3. 修改后的 "actor" 和 "object" 必须在整个JSON中保持一致(如果适用)。 154 | 4. 修改后的 "time" 和 "location" 必须是合理的。 155 | 5. 修改后的JSON必须描绘一个完整的、逻辑连贯的场景,其中每个事件都是这个场景的一部分。 156 | 6. 请只输出JSON格式,确保不包含任何多余的内容。 157 | 158 | 请开始重写: 159 | """ 160 | 161 | # 提示2 (基于论文 P2'): 从修改后的JSON重建文本 162 | prompt_p2 = """%EVENT_EXTRACTION% 163 | 上述内容是JSON格式的从文本中提取的几个事件。请根据以下要求重建原始文本片段: 164 | 1. 重建的文本必须包含JSON中的所有信息,不得遗漏。 165 | 2. JSON中的每一条信息都必须完整地出现在文本片段中,不得被截断或打断。 166 | 3. 重建的文本片段长度必须超过200个字符。您可以包含一些不描述事件的句子来满足字数要求。 167 | 4. 重建的文本片段必须语义连贯、完整,确保整个文本片段处于一个一致的场景中。 168 | 5. 
请只输出重建的文本片段,不包含任何多余的内容。 169 | 170 | 请开始重建文本: 171 | """ 172 | 173 | print("Loading data...") 174 | data = load_data() 175 | if not data: 176 | print("No data loaded. Exiting main2.") 177 | return 178 | 179 | # 确保数据量足够 180 | if len(data) < 100: 181 | print(f"Error: Not enough data. Loaded {len(data)}, but need at least 100.") 182 | return 183 | 184 | # 1. 分离训练集和测试集 185 | # 保持与原始代码一致的分割 186 | test_data = data[:100] 187 | train_data = data[100:5140] # 假设原始数据至少有5140条 188 | 189 | print(f"Loaded {len(data)} total samples.") 190 | print(f"Test set size: {len(test_data)}") 191 | print(f"Train set size: {len(train_data)}") 192 | 193 | # 2. 处理和保存测试集 (与原始逻辑相同) 194 | list_test = [] 195 | list_test_gt = [] 196 | 197 | print("Processing test set...") 198 | for i, item in enumerate(test_data): 199 | q = prompt_z + item[0] 200 | a = json.dumps(item[1], ensure_ascii=False) 201 | list_test.append((q, a)) 202 | list_test_gt.append(item[1]) 203 | 204 | output_list_test = list2jsonl(list_test) 205 | with open(f'./data/test_2.jsonl', 'w', encoding='utf-8') as f: 206 | f.write(output_list_test) 207 | 208 | with open(f'./data/gt_2.jsonl', 'w', encoding='utf-8') as f: 209 | for item in list_test_gt: 210 | f.write(json.dumps(item, ensure_ascii=False) + '\n') 211 | print("Test set saved.") 212 | 213 | # 3. 处理和增强训练集 214 | list_train_augmented = [] 215 | 216 | # 3a. 添加所有原始训练数据 217 | print("Adding original training data...") 218 | for i, item in enumerate(train_data): 219 | q = prompt_z + item[0] 220 | a = json.dumps(item[1], ensure_ascii=False) 221 | list_train_augmented.append((q, a)) 222 | 223 | print(f"Added {len(train_data)} original samples to training set.") 224 | 225 | # 3b. 
添加增强数据 226 | print(f"Starting self-data augmentation for {len(train_data)} samples...") 227 | print("This will take a long time as it requires 2 LLM calls per sample.") 228 | 229 | for i, item in enumerate(train_data): 230 | print(f"Augmenting sample {i + 1}/{len(train_data)}...") 231 | try: 232 | # 准备阶段1:修改事件JSON 233 | original_event_json_str = json.dumps(item[1], ensure_ascii=False) 234 | prompt1_full = prompt_p1.replace("%EVENT_EXTRACTION%", original_event_json_str) 235 | 236 | # 调用LLM(阶段1) 237 | modified_event_json_str = ask_openai(prompt1_full) 238 | 239 | if not modified_event_json_str: 240 | print(f"Warning: LLM call 1 (Modify) returned None for sample {i}. Skipping.") 241 | continue 242 | 243 | # 简单的JSON验证 244 | try: 245 | # 确保它至少是有效的JSON 246 | json.loads(modified_event_json_str) 247 | # 确保它是一个列表 248 | if not modified_event_json_str.strip().startswith("["): 249 | print(f"Warning: LLM call 1 (Modify) did not return a JSON list for sample {i}. Skipping.") 250 | print(f"Received: {modified_event_json_str}") 251 | continue 252 | except json.JSONDecodeError as e: 253 | print(f"Warning: LLM call 1 (Modify) returned invalid JSON for sample {i}: {e}. Skipping.") 254 | print(f"Received: {modified_event_json_str}") 255 | continue 256 | 257 | # 准备阶段2:重建文本 258 | prompt2_full = prompt_p2.replace("%EVENT_EXTRACTION%", modified_event_json_str) 259 | 260 | # 调用LLM(阶段2) 261 | new_text = ask_openai(prompt2_full) 262 | 263 | if not new_text or new_text.strip().startswith("["): 264 | print(f"Warning: LLM call 2 (Reconstruct) returned None or invalid text for sample {i}. Skipping.") 265 | print(f"Received: {new_text}") 266 | continue 267 | 268 | # 格式化并添加新的增强样本 269 | q_aug = prompt_z + new_text 270 | a_aug = modified_event_json_str # 答案已经是JSON字符串了 271 | 272 | list_train_augmented.append((q_aug, a_aug)) 273 | 274 | if (i + 1) % 100 == 0: 275 | print( 276 | f"Checkpoint: Completed {i + 1} augmentations. 
Total training samples: {len(list_train_augmented)}") 277 | 278 | except Exception as e: 279 | print(f"Error processing augmentation for sample {i}: {e}. Skipping.") 280 | # 建议在长时间运行时添加更稳健的重试逻辑 281 | time.sleep(1) # 发生错误时稍作等待 282 | 283 | # 4. 保存增强后的训练集 284 | print(f"Augmentation complete. Total augmented training samples: {len(list_train_augmented)}") 285 | print("Saving augmented training set...") 286 | output_list_train = list2jsonl(list_train_augmented) 287 | 288 | # 保存到新文件以避免覆盖原始文件 289 | with open(f'./data/train5000_augmented_2.jsonl', 'w', encoding='utf-8') as f: 290 | f.write(output_list_train) 291 | 292 | print("Augmented training set saved to ./data/train5000_augmented_2.jsonl") 293 | print(f"Original total data length: {len(data)}") 294 | print(f"Final augmented training set size (lines): {len(list_train_augmented)}") 295 | 296 | def worker(item, index, model_dir, save_dir): 297 | 298 | print(index) 299 | answer = ask(prompt_z + item[0]) # 生成答案 300 | file_path = os.path.join(save_dir, f'{index}.txt') 301 | 302 | with open(file_path, 'w', encoding='utf-8') as f: 303 | f.write(answer) 304 | 305 | 306 | def main3(): # 问问题 307 | data = load_data() 308 | to_do = ["./sft/output_qwen/q72_5000_1e4_2/checkpoint-84"] 309 | to_save = ["./output/p0_gpt4/"] 310 | for ii in range(len(to_do)): 311 | if not os.path.exists(to_save[ii]): 312 | os.mkdir(to_save[ii]) 313 | for i, item in enumerate(data): 314 | print(i) 315 | answer = ask_openai(prompt_z + item[0]) 316 | if answer is not None: 317 | with open(to_save[ii] + str(i) + '.txt', 'w', encoding='utf-8') as f: 318 | json.dump(answer, f, ensure_ascii=False, indent=4) 319 | print(answer) 320 | 321 | if __name__=='__main__': 322 | main4() 323 | -------------------------------------------------------------------------------- /LLM Self-Data Augmentation/finetune/sft/finetune.py: -------------------------------------------------------------------------------- 1 | # This code is based on the revised code from fastchat based 
# on tatsu-lab/stanford_alpaca.


from dataclasses import dataclass, field
import json
import logging
import os
import pathlib
from typing import Dict, Optional, List
import torch
from torch.utils.data import Dataset
from deepspeed import zero
from deepspeed.utils.zero_to_fp32 import get_fp32_state_dict_from_zero_checkpoint
from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus
import transformers
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers import Trainer, BitsAndBytesConfig, deepspeed
from transformers.trainer_pt_utils import LabelSmoother
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from accelerate.utils import DistributedType


# Label value written over padding positions in `preprocess`.
IGNORE_TOKEN_ID = LabelSmoother.ignore_index

# Set by `train()` from the parsed training arguments; None until then.
local_rank = None


def rank0_print(*args):
    """Print *args* only when the module-level ``local_rank`` equals 0."""
    if local_rank == 0:
        print(*args)


@dataclass
class ModelArguments:
    """Command-line arguments selecting the base model."""

    # Passed to ``from_pretrained`` for the config, model and tokenizer.
    model_name_or_path: Optional[str] = field(default="Qwen/Qwen-7B")


@dataclass
class DataArguments:
    """Command-line arguments describing the training / evaluation data."""

    data_path: str = field(
        default=None, metadata={"help": "Path to the training data."}
    )
    eval_data_path: str = field(
        default=None, metadata={"help": "Path to the evaluation data."}
    )
    # When True, samples are tokenized on first access instead of up front.
    lazy_preprocess: bool = False


@dataclass
class TrainingArguments(transformers.TrainingArguments):
    """``transformers.TrainingArguments`` extended with script-specific options."""

    cache_dir: Optional[str] = field(default=None)
    optim: str = field(default="adamw_torch")
    model_max_length: int = field(
        default=8192,
        metadata={
            "help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."
        },
    )
    # Wrap the model with a LoRA adapter before training.
    use_lora: bool = False


@dataclass
class LoraArguments:
    """Command-line arguments configuring the LoRA / QLoRA adapter."""

    lora_r: int = 64
    lora_alpha: int = 16
    lora_dropout: float = 0.05
    lora_target_modules: List[str] = field(
        default_factory=lambda: [
            "q_proj",
            "k_proj",
            "v_proj",
            "o_proj",
            "up_proj",
            "gate_proj",
            "down_proj",
        ]
    )
    lora_weight_path: str = ""
    lora_bias: str = "none"
    # Load the base model 4-bit quantized (only honoured together with use_lora).
    q_lora: bool = False


def maybe_zero_3(param):
    """Return a detached CPU clone of *param*.

    Parameters carrying a ``ds_id`` attribute are first gathered with
    ``zero.GatheredParameters`` so the clone holds the full tensor.
    """
    if hasattr(param, "ds_id"):
        assert param.ds_status == ZeroParamStatus.NOT_AVAILABLE
        with zero.GatheredParameters([param]):
            param = param.data.detach().cpu().clone()
    else:
        param = param.detach().cpu().clone()
    return param


# Borrowed from peft.utils.get_peft_model_state_dict
def get_peft_state_maybe_zero_3(named_params, bias):
    """Collect the LoRA (and optionally bias) parameters as CPU tensors.

    Args:
        named_params: Iterable of ``(name, parameter)`` pairs, e.g.
            ``model.named_parameters()``. It is consumed once.
        bias: ``"none"`` keeps only names containing ``"lora_"``; ``"all"``
            additionally keeps every name containing ``"bias"``;
            ``"lora_only"`` additionally keeps the bias of each module that
            has a LoRA parameter.

    Returns:
        Dict mapping parameter name to a detached CPU clone.

    Raises:
        NotImplementedError: If *bias* is not one of the three modes above.
    """
    if bias == "none":
        to_return = {k: t for k, t in named_params if "lora_" in k}
    elif bias == "all":
        to_return = {k: t for k, t in named_params if "lora_" in k or "bias" in k}
    elif bias == "lora_only":
        to_return = {}
        maybe_lora_bias = {}
        lora_bias_names = set()
        for k, t in named_params:
            if "lora_" in k:
                to_return[k] = t
                # Name of the bias belonging to the module that owns this
                # LoRA weight.
                lora_bias_names.add(k.split("lora_")[0] + "bias")
            elif "bias" in k:
                maybe_lora_bias[k] = t
        # Iterate over (name, tensor) pairs and test each candidate's own
        # name; iterating the dict directly would yield only the key strings.
        for k, t in maybe_lora_bias.items():
            if k in lora_bias_names:
                to_return[k] = t
    else:
        raise NotImplementedError(f"Unsupported bias mode: {bias!r}")
    to_return = {k: maybe_zero_3(v) for k, v in to_return.items()}
    return to_return


def safe_save_model_for_hf_trainer(
    trainer: transformers.Trainer, output_dir: str, bias="none"
):
    """Collects the state dict and dump to disk.

    Under DeepSpeed ZeRO-3 the consolidated 16-bit state dict is taken from
    the wrapped model. Otherwise only the LoRA state is saved when
    ``use_lora`` is set, else the full model state dict. The write itself is
    performed only when ``should_save`` is set and ``local_rank`` is 0.
    """
    # check if zero3 mode enabled
    if deepspeed.is_deepspeed_zero3_enabled():
        state_dict = trainer.model_wrapped._zero3_consolidated_16bit_state_dict()
    else:
        if trainer.args.use_lora:
            state_dict = get_peft_state_maybe_zero_3(
                trainer.model.named_parameters(), bias
            )
        else:
            state_dict = trainer.model.state_dict()
    if trainer.args.should_save and trainer.args.local_rank == 0:
        trainer._save(output_dir, state_dict=state_dict)


def preprocess(
    messages,
    tokenizer: transformers.PreTrainedTokenizer,
    max_len: int,
) -> Dict:
    """Preprocesses the data for supervised fine-tuning.

    Args:
        messages: Sequence of conversations; each one is passed unchanged to
            ``tokenizer.apply_chat_template``.
        tokenizer: Tokenizer providing the chat template and ``pad_token_id``.
        max_len: Truncation length handed to the tokenizer.

    Returns:
        Dict with ``input_ids``, ``target_ids`` (a copy of ``input_ids`` with
        padding positions set to ``IGNORE_TOKEN_ID``) and ``attention_mask``
        (True where ``input_ids`` differs from the pad id).

    Raises:
        ValueError: If rows need padding but the tokenizer has no pad token.
    """
    token_rows = [
        list(
            tokenizer.apply_chat_template(
                msg,
                tokenize=True,
                add_generation_prompt=False,
                padding=True,
                max_length=max_len,
                truncation=True,
            )
        )
        for msg in messages
    ]

    # Every conversation is tokenized by its own call, so the tokenizer's
    # padding cannot equalise lengths across conversations. Right-pad to the
    # longest row so the rows form a rectangular tensor; a single-row call
    # (lazy mode) is left untouched.
    longest = max((len(row) for row in token_rows), default=0)
    if any(len(row) != longest for row in token_rows):
        pad_id = tokenizer.pad_token_id
        if pad_id is None:
            raise ValueError(
                "tokenizer.pad_token_id must be set to batch sequences of "
                "different lengths"
            )
        token_rows = [row + [pad_id] * (longest - len(row)) for row in token_rows]

    input_ids = torch.tensor(token_rows, dtype=torch.int)
    target_ids = input_ids.clone()
    # NOTE(review): this also masks genuine tokens whose id equals the pad id
    # (e.g. when pad and EOS share an id) - confirm that is intended.
    target_ids[target_ids == tokenizer.pad_token_id] = IGNORE_TOKEN_ID
    attention_mask = input_ids.ne(tokenizer.pad_token_id)

    return dict(
        input_ids=input_ids, target_ids=target_ids, attention_mask=attention_mask
    )


class SupervisedDataset(Dataset):
    """Dataset for supervised fine-tuning that tokenizes everything up front."""

    def __init__(
        self, raw_data, tokenizer: transformers.PreTrainedTokenizer, max_len: int
    ):
        super().__init__()

        rank0_print("Formatting inputs...")
        messages = [example["messages"] for example in raw_data]
        data_dict = preprocess(messages, tokenizer, max_len)

        self.input_ids = data_dict["input_ids"]
        self.target_ids = data_dict["target_ids"]
        self.attention_mask = data_dict["attention_mask"]

    def __len__(self):
        return len(self.input_ids)

    def __getitem__(self, i) -> Dict[str, torch.Tensor]:
        return dict(
            input_ids=self.input_ids[i],
            labels=self.target_ids[i],
            attention_mask=self.attention_mask[i],
        )
class LazySupervisedDataset(Dataset):
    """Dataset for supervised fine-tuning that tokenizes samples on demand.

    Each sample is preprocessed the first time it is requested and the result
    is cached by index for later accesses.
    """

    def __init__(
        self, raw_data, tokenizer: transformers.PreTrainedTokenizer, max_len: int
    ):
        super().__init__()
        self.tokenizer = tokenizer
        self.max_len = max_len

        rank0_print("Formatting inputs...Skip in lazy mode")
        self.raw_data = raw_data
        self.cached_data_dict = {}

    def __len__(self):
        return len(self.raw_data)

    def __getitem__(self, i) -> Dict[str, torch.Tensor]:
        if i in self.cached_data_dict:
            return self.cached_data_dict[i]

        # preprocess() works on a batch; wrap the sample and unwrap row 0.
        ret = preprocess([self.raw_data[i]["messages"]], self.tokenizer, self.max_len)
        ret = dict(
            input_ids=ret["input_ids"][0],
            labels=ret["target_ids"][0],
            attention_mask=ret["attention_mask"][0],
        )
        self.cached_data_dict[i] = ret

        return ret


def _load_jsonl(path):
    """Read a JSON Lines file into a list, skipping blank lines."""
    # Explicit UTF-8 so non-ASCII samples decode the same on every platform.
    with open(path, "r", encoding="utf-8") as f:
        return [json.loads(line) for line in f if line.strip()]


def make_supervised_data_module(
    tokenizer: transformers.PreTrainedTokenizer,
    data_args,
    max_len,
) -> Dict:
    """Make dataset and collator for supervised fine-tuning.

    Args:
        tokenizer: Tokenizer handed to the dataset class.
        data_args: Object exposing ``data_path``, ``eval_data_path`` and
            ``lazy_preprocess``.
        max_len: Maximum sequence length handed to the dataset class.

    Returns:
        Dict with ``train_dataset`` and ``eval_dataset``; the latter is None
        when ``data_args.eval_data_path`` is empty.
    """
    dataset_cls = (
        LazySupervisedDataset if data_args.lazy_preprocess else SupervisedDataset
    )
    rank0_print("Loading data...")

    train_data = _load_jsonl(data_args.data_path)
    train_dataset = dataset_cls(train_data, tokenizer=tokenizer, max_len=max_len)

    if data_args.eval_data_path:
        eval_data = _load_jsonl(data_args.eval_data_path)
        eval_dataset = dataset_cls(eval_data, tokenizer=tokenizer, max_len=max_len)
    else:
        eval_dataset = None

    return dict(train_dataset=train_dataset, eval_dataset=eval_dataset)
def train():
    """Parse the CLI arguments, build the model and tokenizer, and run training.

    Optionally loads the base model 4-bit quantized and wraps it with a LoRA
    adapter, builds the datasets, trains (resuming from an existing
    checkpoint for non-LoRA runs), then saves the trainer state and model.
    """
    global local_rank

    arg_parser = transformers.HfArgumentParser(
        (ModelArguments, DataArguments, TrainingArguments, LoraArguments)
    )
    model_args, data_args, training_args, lora_args = (
        arg_parser.parse_args_into_dataclasses()
    )

    world_size = int(os.environ.get("WORLD_SIZE", 1))
    is_distributed = world_size != 1

    # This serves for single-gpu qlora.
    if getattr(training_args, "deepspeed", None) and not is_distributed:
        training_args.distributed_state.distributed_type = DistributedType.DEEPSPEED

    local_rank = training_args.local_rank

    device_map = None
    if lora_args.q_lora:
        if is_distributed:
            device_map = {"": int(os.environ.get("LOCAL_RANK") or 0)}
        else:
            device_map = "auto"
        if len(training_args.fsdp) > 0 or deepspeed.is_deepspeed_zero3_enabled():
            logging.warning("FSDP or ZeRO3 is incompatible with QLoRA.")

    model_load_kwargs = {
        "low_cpu_mem_usage": not deepspeed.is_deepspeed_zero3_enabled(),
    }

    if training_args.fp16:
        compute_dtype = torch.float16
    elif training_args.bf16:
        compute_dtype = torch.bfloat16
    else:
        compute_dtype = torch.float32

    # 4-bit quantization is requested only for LoRA runs with q_lora set.
    quantization_config = None
    if training_args.use_lora and lora_args.q_lora:
        quantization_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=compute_dtype,
        )

    # Load model and tokenizer
    config = transformers.AutoConfig.from_pretrained(
        model_args.model_name_or_path,
        cache_dir=training_args.cache_dir,
    )
    config.use_cache = False

    model = AutoModelForCausalLM.from_pretrained(
        model_args.model_name_or_path,
        config=config,
        cache_dir=training_args.cache_dir,
        device_map=device_map,
        quantization_config=quantization_config,
        **model_load_kwargs,
    )
    tokenizer = AutoTokenizer.from_pretrained(
        model_args.model_name_or_path,
        cache_dir=training_args.cache_dir,
        model_max_length=training_args.model_max_length,
        padding_side="right",
        use_fast=False,
    )

    if training_args.use_lora:
        lora_config = LoraConfig(
            r=lora_args.lora_r,
            lora_alpha=lora_args.lora_alpha,
            target_modules=lora_args.lora_target_modules,
            lora_dropout=lora_args.lora_dropout,
            bias=lora_args.lora_bias,
            task_type="CAUSAL_LM",
        )
        if lora_args.q_lora:
            model = prepare_model_for_kbit_training(
                model, use_gradient_checkpointing=training_args.gradient_checkpointing
            )

        model = get_peft_model(model, lora_config)

        # Print peft trainable params
        model.print_trainable_parameters()

        if training_args.gradient_checkpointing:
            model.enable_input_require_grads()

    # Load data
    data_module = make_supervised_data_module(
        tokenizer=tokenizer, data_args=data_args, max_len=training_args.model_max_length
    )

    # Start trainer
    trainer = Trainer(
        model=model, tokenizer=tokenizer, args=training_args, **data_module
    )

    # `not training_args.use_lora` is a temporary workaround for the issue that there are problems with
    # loading the checkpoint when using LoRA with DeepSpeed.
    # Check this issue https://github.com/huggingface/peft/issues/746 for more information.
    existing_checkpoints = list(
        pathlib.Path(training_args.output_dir).glob("checkpoint-*")
    )
    should_resume = bool(existing_checkpoints) and not training_args.use_lora
    if should_resume:
        trainer.train(resume_from_checkpoint=True)
    else:
        trainer.train()
    trainer.save_state()

    safe_save_model_for_hf_trainer(
        trainer=trainer, output_dir=training_args.output_dir, bias=lora_args.lora_bias
    )


if __name__ == "__main__":
    train()