├── docs ├── assets │ ├── qwk.png │ ├── workflow.png │ ├── annotation.png │ └── logo.svg └── Readme_cn.md ├── requirements.txt ├── .gitignore ├── prompts ├── few_shot_prompt.json └── predict.txt ├── utils ├── cal_steps.py ├── load_few_shot.py ├── errors_consistency_score.py └── collaborative_consistency_score.py ├── sas_pipelines ├── play.py ├── 3_compute_ccs.py ├── 2_process_prediction.py ├── 4_compute_ecs.py └── 1_predict_scores.py ├── main ├── models │ ├── chatglm.py │ └── chatglm_rlhf.py ├── loaders │ ├── llm_pure_text.py │ ├── llm_chat.py │ ├── chatglm_chat.py │ ├── qwen_chat.py │ └── chatglm_rlhf.py ├── analysis.py ├── predictor │ ├── openai.py │ ├── qwen_lora.py │ ├── vllm.py │ ├── chatglm.py │ ├── chatglm_lora.py │ ├── llm_lora.py │ └── llm.py ├── loader.py └── trainer │ ├── llm_lora.py │ └── chatglm_rlhf.py ├── discuss ├── sample_distribution.py ├── pred_gold.py ├── error_causes_f1.py └── compute_range_ccs.py ├── preprocess.py ├── LICENSE └── Readme.md /docs/assets/qwk.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PKU-DAIR/SAS-Bench/HEAD/docs/assets/qwk.png -------------------------------------------------------------------------------- /docs/assets/workflow.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PKU-DAIR/SAS-Bench/HEAD/docs/assets/workflow.png -------------------------------------------------------------------------------- /docs/assets/annotation.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PKU-DAIR/SAS-Bench/HEAD/docs/assets/annotation.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | protobuf 2 | transformers>=4.44.1 3 | cpm_kernels 4 | torch>=2.0 5 | gradio 6 | mdtex2html 7 | sentencepiece 8 | accelerate 9 | json_repair 10 | openai 11 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .vscode 2 | ./*.txt 3 | **/**/__pycache__ 4 | model/ 5 | save_model/ 6 | data_record/ 7 | data/ 8 | chroma_data/ 9 | datasets_*_Scored/ 10 | **/**/api_key.txt 11 | results/ -------------------------------------------------------------------------------- /prompts/few_shot_prompt.json: -------------------------------------------------------------------------------- 1 | { 2 | "prefix": "【参考示例】", 3 | "suffix": "以上为参考样例和参考输出格式,请参考这些样例的评分格式对下面的真实材料进行评估", 4 | "template": "\n- 试题内容:{question}\n- 题目分值:{total}\n- 标准答案:{reference}\n- 解析说明:{analysis}\n- 学生作答:{student_answer}\n输出: {output}" 5 | } -------------------------------------------------------------------------------- /utils/cal_steps.py: -------------------------------------------------------------------------------- 1 | # %% 2 | import os 3 | import json 4 | import numpy as np 5 | 6 | DIR = '/home/lpc/repos/SAS_Benchmark/datasets' 7 | FILES = os.listdir(DIR) 8 | 9 | for file_name in FILES: 10 | if len(file_name.split('_')) < 3: 11 | continue 12 | path = os.path.join(DIR, file_name) 13 | step_count = [] 14 | length_count = [] 15 | with open(path) as f: 16 | ori_data = f.readlines() 17 | ori_data = [json.loads(item) for item in ori_data] 18 | for item in ori_data: 19 | step_count.append(len(item['steps'])) 20 | len_count = 0 21 | for step_item in item['steps']: 
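# a step's length is measured as the number of characters in its 'response' text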
22 | len_count += len(list(step_item['response'])) 23 | length_count.append(len_count) 24 | 25 | print(file_name, np.mean(step_count), np.mean(length_count)) 26 | 27 | # %% 28 | -------------------------------------------------------------------------------- /prompts/predict.txt: -------------------------------------------------------------------------------- 1 | 请作为数学学科评分专家,根据以下要求对学生的作答进行专业评估: 2 | 3 | 【评估任务】 4 | 依据题目信息、参考答案及评分指南,对学生的分步解答进行精细化评分,并输出结构化评分结果。 5 | 6 | 【评分指南】 7 | {score_guideline} 8 | {few_shot_samples} 9 | 【评估材料】 10 | - 试题内容:{question} 11 | - 题目分值:{total} 12 | - 错因类型:{error_type} 13 | - 标准答案:{reference} 14 | - 解析说明:{analysis} 15 | - 学生作答:{student_answer} 16 | 17 | 【评估流程和要求】 18 | 1. 分步解析: 19 | - 拆解学生作答的每个解题步骤 20 | - 对每个步骤独立评估: 21 | * 判断正误('label') 22 | * 如存在错误,从错因列表中选取1项或多项主因('errors') 23 | - 单步评估格式:{{'step_score': 单步分数, 'errors': [错因]}} 24 | 25 | 2. 综合评定: 26 | - 汇总各步骤得分计算总分 27 | - 给出整体评价('label') 28 | 29 | 3. 结果输出: 30 | - 采用标准JSON格式输出: 31 | {{ 32 | 'total': 总分, 33 | 'pred_score': 评估总分数, 34 | 'steps': [各步骤评估结果] 35 | }} 36 | - 'pred_score'必须在'total'范围内 37 | - 分步的'step_score'累积值也必须在0到'pred_score'范围内 38 | 39 | 请按照上述规范完成评分,并以`JSON`格式输出标准化的评估结果。 -------------------------------------------------------------------------------- /docs/assets/logo.svg: -------------------------------------------------------------------------------- 1 | A -------------------------------------------------------------------------------- /sas_pipelines/play.py: -------------------------------------------------------------------------------- 1 | # %% 2 | import os 3 | os.environ['CUDA_VISIBLE_DEVICES'] = '1' 4 | import sys 5 | sys.path.append('../') 6 | from main.predictor.llm import Predictor 7 | 8 | pred = Predictor(model_from_pretrained='/home/lpc/models/glm-4-9b-chat/') 9 | 10 | # %% 11 | ask_content = '''请你现在扮演一位物理学科评分专家, 我们提供了一道题目和一个学生的回答, 请你根据题目分值、评分指南、错因列表,对照参考答案和解题分析,遵循以下规则, 对学生回答进行评分。 12 | 规则: 1. 你需要按学生回答的`每个步骤`分别进行评分,并对于有错误的答案从`错因列表`中选取一到多个错因作为打分依据,每个步骤的评估结果以{{'label': '', 'errors': []}}的形式输出。 13 | 2. 你需要根据每个步骤的评分结果输出最终的总分。 14 | 3. 最终的输出结果为json格式的数据,参考格式为{{'total': '', 'label': '', 'steps': []}}。其中`total`表示总分,`label`表示最终的评分结果,`steps`表示每个步骤的评分结果。 15 | 题目: {question} 16 | 分值: {total} 17 | 错因列表: {error_type} 18 | 参考答案: {reference} 19 | 解题分析: {analysis} 20 | 学生回答: {student_answer} 21 | 请你根据以上信息进行评分,并输出评分结果。 22 | ''' 23 | 24 | ask_content = '''请作为物理学科评分专家,根据以下要求对学生的作答进行专业评估: 25 | 26 | 【评估任务】 27 | 依据题目信息、参考答案及评分标准,对学生的分步解答进行精细化评分,并输出结构化评分结果。 28 | 29 | 【评估流程】 30 | 1. 分步解析: 31 | - 拆解学生作答的每个解题步骤 32 | - 对每个步骤独立评估: 33 | * 判断正误('label') 34 | * 如存在错误,从错因列表中选取1项或多项主因('errors') 35 | - 单步评估格式:{{'label': '', 'errors': []}} 36 | 37 | 2. 综合评定: 38 | - 汇总各步骤得分计算总分 39 | - 给出整体评价('label') 40 | 41 | 3. 结果输出: 42 | - 采用标准JSON格式: 43 | {{ 44 | 'total': '总分', 45 | 'label': '总体评价', 46 | 'steps': [各步骤评估结果] 47 | }} 48 | 49 | 【评估材料】 50 | - 试题内容:{question} 51 | - 题目分值:{total} 52 | - 错因类型:{error_type} 53 | - 标准答案:{reference} 54 | - 解析说明:{analysis} 55 | - 学生作答:{student_answer} 56 | 57 | 【评分准则】 58 | 1. 严格对照标准答案的解题逻辑链 59 | 2. 错因标注需精准对应学生错误本质 60 | 3. 保持不同步骤间的评分尺度一致性 61 | 4. 对于创新解法需额外验证其科学性 62 | 63 | 【特别说明】 64 | 1. 公式错误、单位遗漏等细节问题需单独标注 65 | 2. 概念性错误与计算错误需区分处理 66 | 3. 
部分正确的情况应给予相应步骤分 67 | 68 | 请按照上述规范完成评分,并输出标准化的评估结果。''' 69 | 70 | pred(ask_content, build_message=True) 71 | 72 | # %% 73 | -------------------------------------------------------------------------------- /main/models/chatglm.py: -------------------------------------------------------------------------------- 1 | from re import A 2 | import torch 3 | import torch.nn as nn 4 | from transformers import AutoTokenizer, AutoModel 5 | 6 | class CCGPTModel(): 7 | 8 | def __init__( 9 | self, 10 | model_name: str = None, 11 | model_from_pretrained: str = None, 12 | model_config_file_name: str = None, 13 | pretrained_file_name: str = None, 14 | ): 15 | ''' 16 | CCGPTModel: 生成式模型 (Generative model) 17 | 18 | ### Args: 19 | `model_name`: 模型名称 (the name of the model) 20 | `model_from_pretrained`: 从预训练模型中加载模型 (load model from pretrained model) 21 | `model_config_file_name`: bert配置文件名 (bert config file name) 22 | `pretrained_file_name`: 预训练模型文件名 (pretrained file name) 23 | `tagset_size`: 标签数量 (the number of tags) 24 | ''' 25 | self.model_name = model_name 26 | self.model_from_pretrained = model_from_pretrained 27 | self.model_config_file_name = model_config_file_name 28 | self.pretrained_file_name = pretrained_file_name 29 | 30 | if self.model_name is None: 31 | raise ValueError("model_name is required") 32 | if self.model_from_pretrained is None: 33 | if self.model_config_file_name is None: 34 | raise ValueError("model_config_file_name is required") 35 | if self.pretrained_file_name is None: 36 | raise ValueError("pretrained_file_name is required") 37 | 38 | self.load_model() 39 | 40 | def load_model(self): 41 | if self.model_name == 'ChatGLM2-6B': 42 | self.model = AutoModel.from_pretrained(self.model_from_pretrained, trust_remote_code=True).half().cuda() 43 | 44 | def get_model(self): 45 | return self.model 46 | 47 | def __call__(self): 48 | return self.get_model() 49 | -------------------------------------------------------------------------------- /discuss/sample_distribution.py: -------------------------------------------------------------------------------- 1 | # %% 2 | import os 3 | import json 4 | import numpy as np 5 | import matplotlib.pyplot as plt 6 | from argparse import ArgumentParser 7 | 8 | plt.style.use('seaborn-v0_8-bright') 9 | 10 | import sys 11 | sys.path.append("../") 12 | cmd_args = True 13 | 14 | parser = ArgumentParser() 15 | parser.add_argument('--file_dir', default='../datasets', help='the directory of the datasets.') 16 | parser.add_argument('--file_name', default='10_Math_gapfilling', help='file name of the dataset, you should make sure it contains `test.jsonl` file') 17 | parser.add_argument('--save_fig', default=0, help='whether save the figure') 18 | 19 | if not cmd_args: 20 | args = parser.parse_args([]) # You can directly set above parameters in the default. 
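# otherwise the arguments are read from the command line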
21 | else: 22 | args = parser.parse_args() 23 | 24 | SAMPLE_PATH = os.path.join(args.file_dir, args.file_name + '.jsonl') 25 | 26 | with open(SAMPLE_PATH) as f: 27 | ori_data = f.readlines() 28 | ori_data = [json.loads(item) for item in ori_data] 29 | 30 | labels = [] 31 | steps = [] 32 | lengths = [] 33 | for item in ori_data: 34 | label = float(item['manual_label']) / float(item['total']) 35 | step_count = len(item['steps']) 36 | length = 0 37 | for step in item['steps']: 38 | length += len(step['response']) 39 | labels.append(label) 40 | steps.append(step_count) 41 | lengths.append(length) 42 | 43 | def normalize(data): 44 | return (data - np.min(data)) / (np.max(data) - np.min(data)) 45 | 46 | 47 | # %% 48 | plt.hist([item for item in labels], range=(0, 1), bins=20, alpha=0.8, color='#FFC000', label='Label') 49 | plt.hist(normalize(steps), range=(0, 1), bins=20, alpha=0.6, color='#d3d3f9', label='Steps') 50 | # plt.hist(normalize(lengths), bins=20, alpha=0.3, color='#A6C9E8', label='Length') 51 | plt.legend() 52 | if str(args.save_fig) == '1': 53 | plt.savefig(f'{args.file_name}_distribution.svg', format='svg', dpi=300) 54 | 55 | # %% 56 | -------------------------------------------------------------------------------- /utils/load_few_shot.py: -------------------------------------------------------------------------------- 1 | # %% 2 | import json 3 | import random 4 | from tqdm import tqdm 5 | 6 | def get_few_shot_samples(file_name, num_samples=3): 7 | with open(file_name) as f: 8 | ori_data = f.readlines() 9 | ori_data = [json.loads(item) for item in ori_data] 10 | low = [] 11 | mid = [] 12 | high = [] 13 | for item in tqdm(ori_data): 14 | total = float(item['total']) 15 | manual_score = float(item['manual_label']) 16 | norm_score = manual_score / total 17 | if norm_score < 0.333: 18 | low.append(item) 19 | elif norm_score < 0.667: 20 | mid.append(item) 21 | else: 22 | high.append(item) 23 | results = [] 24 | for i in range(num_samples): 25 | if i % num_samples == 0 and len(low) > 0: 26 | results.append(random.choice(low)) 27 | continue 28 | if i % num_samples == 1 and len(mid) > 0: 29 | results.append(random.choice(mid)) 30 | continue 31 | if i % num_samples == 2 and len(high) > 0: 32 | results.append(random.choice(high)) 33 | return results 34 | 35 | def compute_few_shot_prompt(sample, prompt): 36 | output = {} 37 | output_steps = [] 38 | steps = sample['steps'] 39 | for s in steps: 40 | os = {} 41 | os['step_score'] = int(s['label']) 42 | os['errors'] = s['errors'] 43 | output_steps.append(os) 44 | reponse_content = [] 45 | for s_idx, step in enumerate(steps): 46 | response = step['response'] 47 | reponse_content.append(f'## Step {s_idx}. 
{response}') 48 | output['total'] = sample['total'] 49 | output['pred_score'] = sample['manual_label'] 50 | output['steps'] = output_steps 51 | format_prompt = prompt.format(question=sample['question'], total=sample['total'], reference=sample['reference'], analysis=sample['analysis'], student_answer=''.join(reponse_content), output=json.dumps(output, ensure_ascii=False)) 52 | return format_prompt 53 | 54 | # %% 55 | -------------------------------------------------------------------------------- /main/loaders/llm_pure_text.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import torch 4 | import random 5 | import pickle 6 | import numpy as np 7 | from tqdm import tqdm 8 | from torch.utils.data import Dataset, DataLoader 9 | from transformers import PretrainedConfig, PreTrainedTokenizer 10 | 11 | class LLMPureTextDataset(Dataset): 12 | 13 | config: PretrainedConfig 14 | tokenizer: PreTrainedTokenizer 15 | 16 | def __init__(self, tokenizer, config, file_name, max_length=512, do_shuffle=False): 17 | self.config = config 18 | self.tokenizer = tokenizer 19 | self.max_length = max_length 20 | self.do_shuffle = do_shuffle 21 | self.data = self.load_jsonl(file_name) 22 | self.random_list = [idx for idx in range(len(self.data))] 23 | if self.do_shuffle: 24 | random.shuffle(self.random_list) 25 | 26 | 27 | def load_jsonl(self, file_name): 28 | with open(file_name, 'r') as f: 29 | data = [json.loads(line) for line in f] 30 | return data 31 | 32 | def __getitem__(self, index): 33 | index = self.random_list[index] 34 | data = self.data[index] 35 | context = data['context'] 36 | target = data['target'] 37 | 38 | mx = self.max_length // 2 39 | context_ids = self.tokenizer.encode(context, max_length=mx, truncation=True) 40 | target_ids = self.tokenizer.encode( 41 | target, 42 | max_length=mx, 43 | truncation=True, 44 | add_special_tokens=False) 45 | ids = context_ids + target_ids + [self.config.eos_token_id] 46 | input_len = len(ids) 47 | context_len = len(context_ids) 48 | labels = [-100] * (context_len - 1) + ids[context_len - 1:] + [-100] * (self.max_length - input_len) 49 | 50 | f_ids = ids + [self.config.pad_token_id] * (self.max_length - input_len) 51 | 52 | input_ids = torch.tensor(f_ids) 53 | labels = torch.tensor(labels) 54 | 55 | return { 56 | 'input_ids': input_ids, 57 | 'labels': labels 58 | } 59 | 60 | def __len__(self): 61 | return len(self.data) -------------------------------------------------------------------------------- /main/loaders/llm_chat.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import torch 4 | import random 5 | import pickle 6 | import numpy as np 7 | from tqdm import tqdm 8 | from torch.utils.data import Dataset, DataLoader 9 | from transformers import PretrainedConfig, PreTrainedTokenizer 10 | 11 | class LLMChatDataset(Dataset): 12 | 13 | config: PretrainedConfig 14 | tokenizer: PreTrainedTokenizer 15 | 16 | def __init__(self, tokenizer, config, file_name, max_length=512, do_shuffle=False): 17 | self.config = config 18 | self.tokenizer = tokenizer 19 | self.max_length = max_length 20 | self.do_shuffle = do_shuffle 21 | self.data = self.load_jsonl(file_name) 22 | self.random_list = [idx for idx in range(len(self.data))] 23 | if self.do_shuffle: 24 | random.shuffle(self.random_list) 25 | 26 | 27 | def load_jsonl(self, file_name): 28 | with open(file_name, 'r') as f: 29 | lines = f.readlines() 30 | data = [json.loads(line) for line in 
lines] 31 | return data 32 | 33 | def process_item(self, item): 34 | conv = item['conversations'] if 'conversations' in item else item 35 | 36 | input_ids, labels = [], [] 37 | 38 | for t in conv: 39 | role = t['role'] 40 | ids = self.tokenizer.apply_chat_template([t]) 41 | ls = ids if role not in ['user', 'system'] else [-100 for _ in ids] 42 | input_ids.extend(ids) 43 | labels.extend(ls) 44 | 45 | max_length = self.max_length 46 | input_ids = input_ids[:max_length] 47 | labels = labels[:max_length] 48 | return {'input_ids': input_ids, 'labels': labels} 49 | 50 | def __getitem__(self, index): 51 | index = self.random_list[index] 52 | data = self.data[index] 53 | input_ids, labels = self.process_item(data).values() 54 | 55 | input_ids = torch.tensor(input_ids) 56 | labels = torch.tensor(labels) 57 | 58 | return { 59 | 'input_ids': input_ids, 60 | 'labels': labels 61 | } 62 | 63 | def __len__(self): 64 | return len(self.data) -------------------------------------------------------------------------------- /main/analysis.py: -------------------------------------------------------------------------------- 1 | import os 2 | import datetime 3 | import numpy as np 4 | 5 | 6 | class Analysis(): 7 | 8 | def __init__(self): 9 | self.train_record = {} 10 | self.eval_record = {} 11 | self.model_record = {} 12 | 13 | ''' 14 | append data record of train 15 | train_record_item: dict 16 | ''' 17 | 18 | def append_train_record(self, train_record_item): 19 | for key in train_record_item: 20 | if key not in self.train_record: 21 | self.train_record[key] = [] 22 | self.train_record[key].append(train_record_item[key]) 23 | 24 | ''' 25 | append data record of eval 26 | eval_record_item: dict 27 | ''' 28 | 29 | def append_eval_record(self, eval_record_item): 30 | for key in eval_record_item: 31 | if key not in self.eval_record: 32 | self.eval_record[key] = [] 33 | self.eval_record[key].append(eval_record_item[key]) 34 | 35 | ''' 36 | append data record of model 37 | uid: model uid 38 | ''' 39 | 40 | def append_model_record(self, uid): 41 | key = "model_uid" 42 | if key not in self.model_record: 43 | self.model_record[key] = [] 44 | self.model_record[key].append(uid) 45 | 46 | def save_all_records(self, uid): 47 | self.save_record('train_record', uid) 48 | self.save_record('eval_record', uid) 49 | self.save_record('model_record', uid) 50 | 51 | def save_record(self, record_name, uid): 52 | record_dict = getattr(self, record_name) 53 | path = f'./data_record/{uid}' 54 | if not os.path.exists(path): 55 | os.makedirs(path) 56 | head = [] 57 | for key in record_dict: 58 | head.append(key) 59 | if len(head) == 0: 60 | return uid 61 | result = '' 62 | for idx in range(len(record_dict[head[0]])): 63 | for key in head: 64 | result += str(record_dict[key][idx]) + '\t' 65 | result += '\n' 66 | 67 | result = "\t".join(head) + '\n' + result 68 | 69 | with open(f'{path}/{record_name}.csv', encoding='utf-8', mode='w+') as f: 70 | f.write(result) 71 | 72 | return uid 73 | -------------------------------------------------------------------------------- /preprocess.py: -------------------------------------------------------------------------------- 1 | # %% 2 | import json 3 | from tqdm import tqdm 4 | 5 | DATANAME = '7_Math_ShortAns' 6 | IGNORE_ZERO_WITHOUT_REASON = True 7 | NO_IGNORE_REASON = '计算错误' 8 | filename = f'/home/lpc/repos/SAS_Benchmark/backend_data/scores/{DATANAME}.jsonl' 9 | with open(filename) as f: 10 | ori_data = f.readlines() 11 | ori_data = [json.loads(line) for line in ori_data] 12 | 13 | result = [] 14 | 
for item in tqdm(ori_data): 15 | res_segs = item['bad_student_answer_segs'] 16 | last_idx = 0 17 | format_response = { 18 | 'id': item['id'], 19 | 'question': item['question'], 20 | 'reference': item['answer'], 21 | 'analysis': item['analysis'], 22 | 'total': item['score'], 23 | 'steps': [] 24 | } 25 | if 'scoreItem' not in item: 26 | result.append(format_response) 27 | continue 28 | scoreItem = item['scoreItem'] 29 | format_response['manual_label'] = scoreItem['label'] 30 | seg_labels = scoreItem['seg_labels'] 31 | seg_labels = json.loads(seg_labels) 32 | remain_score = int(float(scoreItem['label'])) 33 | 34 | for seg_item in seg_labels: 35 | seg_idx, seg_label, seg_errors = seg_item['idx'], seg_item['label'], seg_item['errors'] 36 | if seg_label == '': 37 | continue 38 | if seg_label != '' and int(float(seg_label)) == 0 and len(seg_errors) == 0: 39 | if IGNORE_ZERO_WITHOUT_REASON: 40 | continue 41 | else: 42 | seg_errors = [NO_IGNORE_REASON] 43 | if seg_label != '' and int(float(seg_label)) > 0 and len(seg_errors) == 0: 44 | # if it is english dataset, you may replace it with 'correct' 45 | seg_errors = ['步骤正确'] 46 | format_response['steps'].append({ 47 | 'response': '\n'.join(res_segs[last_idx: seg_idx + 1]), 48 | 'label': seg_label, 49 | 'errors': seg_errors 50 | }) 51 | last_idx = seg_idx + 1 52 | if seg_label != '' and int(float(seg_label)) > 0: 53 | remain_score -= int(float(seg_label)) 54 | if last_idx < len(res_segs): 55 | format_response['steps'].append({ 56 | 'response': '\n'.join(res_segs[last_idx:]), 57 | 'label': 0 if remain_score < 0 else remain_score, 58 | 'errors': [] 59 | }) 60 | result.append(format_response) 61 | 62 | with open(f'./datasets/{DATANAME}.jsonl', 'w') as f: 63 | for item in result: 64 | f.write(json.dumps(item, ensure_ascii=False) + '\n') 65 | 66 | # %% 67 | -------------------------------------------------------------------------------- /discuss/pred_gold.py: -------------------------------------------------------------------------------- 1 | # %% 2 | import os 3 | import json 4 | import numpy as np 5 | import matplotlib.pyplot as plt 6 | from argparse import ArgumentParser 7 | 8 | plt.style.use('seaborn-v0_8-bright') 9 | 10 | import sys 11 | sys.path.append("../") 12 | cmd_args = True 13 | 14 | parser = ArgumentParser() 15 | parser.add_argument('--file_dir', default='../datasets', help='the directory of the datasets.') 16 | parser.add_argument('--file_name', default='all', help='file name of the dataset without extension (e.g. 0_Physics_ShortAns), you can define `all` for all dataset files.') 17 | parser.add_argument('--save_type_name', default='mimo', help='the prefix name of save dir (usually is the LLM name)') 18 | parser.add_argument('--save_fig', default=0, help='whether save the figure') 19 | 20 | if not cmd_args: 21 | args = parser.parse_args([]) # You can directly set above parameters in the default. 
22 | else: 23 | args = parser.parse_args() 24 | 25 | SOURCE_DIR = os.path.join(args.file_dir + '_' + args.save_type_name + '_Scored') 26 | 27 | labels = [] 28 | preds = [] 29 | step_labels = [] 30 | step_preds = [] 31 | if args.file_name != 'all': 32 | files = [args.file_name + '_prediction.jsonl'] 33 | else: 34 | files = os.listdir(SOURCE_DIR) 35 | for file_name in files: 36 | if file_name.find('prediction') < 0: 37 | continue 38 | SOURCE_FILE = os.path.join(SOURCE_DIR, file_name) 39 | with open(SOURCE_FILE, encoding='utf-8') as f: 40 | ori_data = f.readlines() 41 | ori_data = [json.loads(line) for line in ori_data] 42 | for item in ori_data: 43 | total = float(item['total']) 44 | label = float(item['manual_label']) / total 45 | pred = float(item['pred_label']) / total 46 | labels.append(label) 47 | preds.append(pred) 48 | for step in item['steps']: 49 | step_labels.append(float(step['label']) / total) 50 | for step in item['pred_steps']: 51 | step_preds.append(float(step['step_score']) / total) 52 | 53 | # %% 54 | plt.hist([item for item in labels], range=(0, 1), bins=20, alpha=0.6, color='#FFC000', label='Gold') 55 | plt.hist([item for item in preds], range=(0, 1), bins=20, alpha=0.6, color='#d3d3f9', label='Pred') 56 | # plt.hist([item for item in step_labels], range=(0, 1), bins=20, alpha=0.3, color='#A6C9E8', label='Pred Step-wise Score') 57 | # plt.hist([item for item in step_preds], range=(0, 1), bins=20, alpha=0.3, color='#44B6A3', label='Gold Step-wise Score') 58 | plt.legend() 59 | if str(args.save_fig) == '1': 60 | plt.savefig(f'{args.save_type_name}_pred_gold.svg', format='svg', dpi=300) 61 | 62 | # %% 63 | -------------------------------------------------------------------------------- /utils/errors_consistency_score.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from scipy.stats import spearmanr 4 | 5 | def compute_ecs(pred_scores, ori_scores, total_scores, pred_errors, gold_errors, max_error=5): 6 | """ 7 | Computes the consistency score between predicted and ground truth evaluations. 8 | 9 | Args: 10 | pred_scores: List of predicted scores for each sample 11 | ori_scores: List of original ground truth scores for each sample 12 | total_scores: Maximum possible total score for normalization 13 | pred_errors: Predicted error frequencies as 2D list [[freq1, freq2,...], ...] 
14 | where each sublist represents error type frequencies for a sample 15 | gold_errors: Ground truth error frequencies with same structure as pred_errors 16 | max_error: Maximum possible number of errors for normalization 17 | """ 18 | 19 | all_norm_scores = [] 20 | Ln = 1e-10 21 | for score, max_s in zip(ori_scores, total_scores): 22 | norm_score = score / (max_s + Ln) 23 | all_norm_scores.append(norm_score) 24 | range1, range2 = 0, 0 25 | sorted_scores = sorted(all_norm_scores) 26 | range1 = sorted_scores[int(len(sorted_scores) * 0.33)] 27 | range2 = sorted_scores[int(len(sorted_scores) * 0.67)] 28 | 29 | ori_range_error_matrix = torch.zeros((3, max_error)) 30 | pred_range_error_matrix = torch.zeros((3, max_error)) 31 | for i in range(len(pred_errors)): 32 | pred, gold, max_s, p_errors, g_errors = pred_scores[i], ori_scores[i], total_scores[i], pred_errors[i], gold_errors[i] 33 | pred = pred / (max_s + Ln) 34 | gold = gold / (max_s + Ln) 35 | if pred <= range1: 36 | for j, freq in enumerate(p_errors): 37 | pred_range_error_matrix[0][j] += freq 38 | elif pred < range2: 39 | for j, freq in enumerate(p_errors): 40 | pred_range_error_matrix[1][j] += freq 41 | else: 42 | for j, freq in enumerate(p_errors): 43 | pred_range_error_matrix[2][j] += freq 44 | if gold <= range1: 45 | for j, freq in enumerate(g_errors): 46 | ori_range_error_matrix[0][j] += freq 47 | elif gold < range2: 48 | for j, freq in enumerate(g_errors): 49 | ori_range_error_matrix[1][j] += freq 50 | else: 51 | for j, freq in enumerate(g_errors): 52 | ori_range_error_matrix[2][j] += freq 53 | 54 | spearmans = [] 55 | for i in range(pred_range_error_matrix.shape[0]): 56 | spearman_score = spearmanr(pred_range_error_matrix[i].tolist(), ori_range_error_matrix[i].tolist()) 57 | spearmans.append(0 if str(spearman_score.correlation) == 'nan' else spearman_score.correlation) 58 | 59 | return np.mean(spearmans), spearmans 60 | -------------------------------------------------------------------------------- /main/loaders/chatglm_chat.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import torch 4 | import random 5 | import pickle 6 | import numpy as np 7 | from tqdm import tqdm 8 | from torch.utils.data import Dataset, DataLoader 9 | from transformers import PretrainedConfig, PreTrainedTokenizer 10 | 11 | class ChatGLM_ChatDataset(Dataset): 12 | 13 | config: PretrainedConfig 14 | tokenizer: PreTrainedTokenizer 15 | 16 | def __init__(self, tokenizer, config, file_name, max_length=512, do_shuffle=False): 17 | self.config = config 18 | self.tokenizer = tokenizer 19 | self.max_length = max_length 20 | self.do_shuffle = do_shuffle 21 | self.data = self.load_jsonl(file_name) 22 | self.random_list = [idx for idx in range(len(self.data))] 23 | if self.do_shuffle: 24 | random.shuffle(self.random_list) 25 | 26 | 27 | def load_jsonl(self, file_name): 28 | with open(file_name, 'r') as f: 29 | lines = f.readlines() 30 | data = [json.loads(line) for line in lines] 31 | return data 32 | 33 | def process_item(self, item): 34 | conv = item['conversations'] if 'conversations' in item else item 35 | 36 | input_ids, loss_masks = [ 37 | self.tokenizer.get_command('[gMASK]'), 38 | self.tokenizer.get_command('sop'), 39 | ], [False, False] 40 | 41 | for message in conv: 42 | if message['role'] in ('system', 'user'): 43 | loss_mask_val = False 44 | else: 45 | loss_mask_val = True 46 | 47 | if message['role'] == 'tool': 48 | raise NotImplementedError() 49 | else: 50 | new_input_ids = 
self.tokenizer.build_single_message( 51 | message['role'], '', message['content'] 52 | ) 53 | new_loss_masks = [loss_mask_val] * len(new_input_ids) 54 | 55 | input_ids += new_input_ids 56 | loss_masks += new_loss_masks 57 | 58 | input_ids.append(self.tokenizer.eos_token_id) 59 | loss_masks = [False, *loss_masks] 60 | labels = [] 61 | for input_id, mask in zip(input_ids, loss_masks): 62 | if mask: 63 | labels.append(input_id) 64 | else: 65 | labels.append(-100) 66 | max_length = self.max_length 67 | input_ids = input_ids[:max_length] 68 | labels = labels[:max_length] 69 | return {'input_ids': input_ids, 'labels': labels} 70 | 71 | def __getitem__(self, index): 72 | index = self.random_list[index] 73 | data = self.data[index] 74 | input_ids, labels = self.process_item(data).values() 75 | 76 | input_ids = torch.tensor(input_ids) 77 | labels = torch.tensor(labels) 78 | 79 | return { 80 | 'input_ids': input_ids, 81 | 'labels': labels 82 | } 83 | 84 | def __len__(self): 85 | return len(self.data) -------------------------------------------------------------------------------- /main/predictor/openai.py: -------------------------------------------------------------------------------- 1 | import json 2 | import torch 3 | from openai import OpenAI 4 | from openai._types import NotGiven 5 | from transformers import AutoTokenizer, AutoModel 6 | from typing import Tuple, List 7 | 8 | NOT_GIVEN = NotGiven() 9 | 10 | 11 | class Predictor(): 12 | 13 | def __init__(self, 14 | organization=None, 15 | api_key=None, 16 | base_url=None, 17 | **args 18 | ): 19 | ''' 20 | Predictor: OpenAI API预测器 (OpenAI API predictor) 21 | 22 | ### Args: 23 | 24 | `organization`: OpenAI组织名 (OpenAI organization name) 25 | 26 | `api_key`: OpenAI API密钥 (OpenAI API key) 27 | ''' 28 | 29 | self.organization = organization 30 | self.api_key = api_key 31 | self.base_url = base_url 32 | self.client = OpenAI(organization=self.organization, api_key=self.api_key, base_url=self.base_url) 33 | 34 | 35 | def predict(self, query: str = '', history: List = None, model: str = 'gpt-4o-mini', max_length=NOT_GIVEN, max_new_tokens=NOT_GIVEN, top_p: float = NOT_GIVEN, temperature=NOT_GIVEN): 36 | if history is None: 37 | history = [] 38 | raw = history + [{"role": "user", "content": query}] 39 | 40 | completion = self.client.chat.completions.create( 41 | model=model, 42 | messages=raw, 43 | max_tokens=max_length, 44 | max_completion_tokens=max_new_tokens, 45 | temperature=temperature, 46 | top_p=top_p, 47 | stream=False 48 | ) 49 | 50 | message = completion.choices[0].message.content 51 | raw.append({"role": "assistant", "content": message}) 52 | return message, raw 53 | 54 | def stream_chat(self, query: str = '', history: List = None, model: str = 'gpt-4o-mini', max_length=NOT_GIVEN, max_new_tokens=NOT_GIVEN, top_p: float = NOT_GIVEN, temperature=NOT_GIVEN): 55 | if history is None: 56 | history = [] 57 | raw = history + [{"role": "user", "content": query}] 58 | 59 | completion = self.client.chat.completions.create( 60 | model=model, 61 | messages=raw, 62 | max_tokens=max_length, 63 | max_completion_tokens=max_new_tokens, 64 | temperature=temperature, 65 | top_p=top_p, 66 | stream=True 67 | ) 68 | 69 | result = '' 70 | for chunk in completion: 71 | delta = chunk.choices[0].delta 72 | if delta.content is None: 73 | continue 74 | result += delta.content 75 | yield result, raw + [{"role": "assistant", "content": result}] 76 | 77 | def __call__(self, query: str = '', history: List = None, model: str = 'gpt-4o-mini', max_length=NOT_GIVEN, 
max_new_tokens=NOT_GIVEN, top_p: float = NOT_GIVEN, temperature=NOT_GIVEN): 78 | return self.predict(query, history, model, max_length, max_new_tokens, top_p, temperature) 79 | -------------------------------------------------------------------------------- /main/predictor/qwen_lora.py: -------------------------------------------------------------------------------- 1 | from transformers import AutoModelForCausalLM, AutoConfig 2 | from transformers import AutoTokenizer 3 | from peft import LoraConfig, TaskType, PeftModel, PeftModelForCausalLM 4 | from typing import Optional, List 5 | from copy import deepcopy 6 | from transformers.generation import GenerationMixin 7 | import torch 8 | 9 | class Predictor(GenerationMixin): 10 | true_model: PeftModelForCausalLM 11 | 12 | def __init__(self, 13 | num_gpus: list = [0], 14 | model_from_pretrained: str = None, 15 | resume_path: str = None 16 | ): 17 | self.config = AutoConfig.from_pretrained(model_from_pretrained, trust_remote_code=True) 18 | peft_config = LoraConfig( 19 | task_type=TaskType.CAUSAL_LM, 20 | inference_mode=False, 21 | r=16, 22 | target_modules=["c_attn", "c_proj", "w1", "w2"], 23 | lora_alpha=32, 24 | lora_dropout=0.1, 25 | ) 26 | self.tokenizer = AutoTokenizer.from_pretrained( 27 | model_from_pretrained, trust_remote_code=True) 28 | self.model = AutoModelForCausalLM.from_pretrained( 29 | model_from_pretrained, device_map="auto", trust_remote_code=True).eval() 30 | self.llm = self.model 31 | self.generation_config = self.llm.generation_config 32 | self.model = PeftModel.from_pretrained( 33 | self.model, resume_path, config=peft_config) 34 | 35 | def predict(self, text='', max_length=150, temperature=1.0): 36 | with torch.no_grad(): 37 | inputs = self.tokenizer.encode(text) 38 | input_ids = torch.LongTensor([inputs]).to(self.device) 39 | output = self.true_model.generate(**{ 40 | 'input_ids': input_ids, 41 | 'max_length': max_length, 42 | 'do_sample': False, 43 | 'temperature': temperature 44 | }) 45 | out_text = self.tokenizer.decode( 46 | output[0], skip_special_tokens=True) 47 | return out_text 48 | 49 | @torch.inference_mode() 50 | def chat(self, 51 | query: str, 52 | history = None, 53 | system: str = "You are a helpful assistant.", 54 | stop_words_ids: Optional[List[List[int]]] = None, 55 | generation_config=None, 56 | **kwargs,): 57 | tokenizer = self.tokenizer 58 | generation_config = generation_config if generation_config is not None else self.generation_config 59 | 60 | return self.model.chat( 61 | tokenizer=tokenizer, 62 | query=query, 63 | history=history, 64 | system=system, 65 | generation_config=generation_config, 66 | stop_words_ids=stop_words_ids, 67 | **kwargs, 68 | ) 69 | 70 | def __call__(self, text='', max_length=150, temperature=0): 71 | return self.predict(text=text, max_length=max_length, temperature=temperature) 72 | -------------------------------------------------------------------------------- /main/loaders/qwen_chat.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import torch 4 | import random 5 | import pickle 6 | import numpy as np 7 | from tqdm import tqdm 8 | from torch.utils.data import Dataset, DataLoader 9 | from transformers import PretrainedConfig, PreTrainedTokenizer 10 | 11 | class QwenChatDataset(Dataset): 12 | 13 | config: PretrainedConfig 14 | tokenizer: PreTrainedTokenizer 15 | 16 | def __init__(self, tokenizer, config, file_name, max_length=512, do_shuffle=False): 17 | self.config = config 18 | self.tokenizer = 
tokenizer 19 | self.max_length = max_length 20 | self.do_shuffle = do_shuffle 21 | self.data = self.load_jsonl(file_name) 22 | self.random_list = [idx for idx in range(len(self.data))] 23 | if self.do_shuffle: 24 | random.shuffle(self.random_list) 25 | 26 | 27 | def load_jsonl(self, file_name): 28 | with open(file_name, 'r') as f: 29 | data = [json.loads(line) for line in f] 30 | return data 31 | 32 | def process_item(self, item): 33 | conv = item['conversations'] if 'conversations' in item else item 34 | 35 | input_ids, loss_masks = [], [] 36 | 37 | # im_start, im_end = "<|im_start|>", "<|im_end|>" 38 | im_start_tokens = [self.tokenizer.im_start_id] 39 | im_end_tokens = [self.tokenizer.im_end_id] 40 | nl_tokens = self.tokenizer.encode("\n") 41 | 42 | def _tokenize_str(role, content): 43 | return f"{role}\n{content}", self.tokenizer.encode( 44 | role, allowed_special=set() 45 | ) + nl_tokens + self.tokenizer.encode(content, allowed_special=set()) 46 | 47 | for message in conv: 48 | if message['role'] in ('system', 'user'): 49 | loss_mask_val = False 50 | else: 51 | loss_mask_val = True 52 | 53 | if message['role'] == 'tool': 54 | raise NotImplementedError() 55 | else: 56 | _, msg_tokens = _tokenize_str(message['role'], message['content']) 57 | new_input_ids = im_start_tokens + msg_tokens + im_end_tokens 58 | new_input_ids = new_input_ids + nl_tokens 59 | new_loss_masks = [loss_mask_val] * len(new_input_ids) 60 | 61 | input_ids += new_input_ids 62 | loss_masks += new_loss_masks 63 | 64 | input_ids = input_ids + im_end_tokens 65 | loss_masks = [False, *loss_masks] 66 | labels = [] 67 | for input_id, mask in zip(input_ids, loss_masks): 68 | if mask: 69 | labels.append(input_id) 70 | else: 71 | labels.append(-100) 72 | max_length = self.max_length 73 | input_ids = input_ids[:max_length] 74 | labels = labels[:max_length] 75 | return {'input_ids': input_ids, 'labels': labels} 76 | 77 | def __getitem__(self, index): 78 | index = self.random_list[index] 79 | data = self.data[index] 80 | input_ids, labels = self.process_item(data).values() 81 | 82 | input_ids = torch.tensor(input_ids) 83 | labels = torch.tensor(labels) 84 | 85 | return { 86 | 'input_ids': input_ids, 87 | 'labels': labels 88 | } 89 | 90 | def __len__(self): 91 | return len(self.data) -------------------------------------------------------------------------------- /utils/collaborative_consistency_score.py: -------------------------------------------------------------------------------- 1 | import itertools 2 | import numpy as np 3 | from itertools import product 4 | from sklearn.metrics import confusion_matrix, cohen_kappa_score 5 | 6 | def composite_distance(true_tuple, pred_tuple, weights): 7 | """ 8 | Computes the weighted difference (e.g., Manhattan distance) between composite score tuples. 
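For example, with weights [0.5, 0.25, 0.25], true_tuple (4, 2, 2) and pred_tuple (3, 2, 1) give 0.5*1 + 0.25*0 + 0.25*1 = 0.75.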
9 | 10 | Args: 11 | true_tuple: Ground truth scores (total_score, step1_score, step2_score) 12 | pred_tuple: Predicted scores (total_score, step1_score, step2_score) 13 | weights: Weights for each component [w_total, w_step1, w_step2], e.g., [0.5, 0.25, 0.25] 14 | 15 | Returns: 16 | Weighted composite difference between the tuples 17 | """ 18 | return np.sum(weights * np.abs(np.array(true_tuple) - np.array(pred_tuple))) 19 | 20 | def compute_QWK(X, Y, grading_scale=1): 21 | X = [int(x * grading_scale) for x in X] 22 | Y = [int(y * grading_scale) for y in Y] 23 | X = np.array(X) 24 | Y = np.array(Y) 25 | 26 | qwk = cohen_kappa_score(Y, X, weights='quadratic') 27 | return qwk 28 | 29 | def adjusted_qwk(true_tuples, pred_tuples, weights, max_scores): 30 | """ 31 | Computes weighted overall and step-wise consistent scores between ground truth and predicted evaluation tuples. 32 | 33 | Args: 34 | true_tuples: List of ground truth score tuples [(total_score, step1, ..., step10), ...] 35 | pred_tuples: List of predicted score tuples [(total_score, step1, ..., step10), ...] 36 | weights: Weight coefficients for each dimension [w_total, w_step1, ..., w_step10] 37 | (should satisfy sum(weights) = 1) 38 | max_scores: Maximum possible scores for normalization [max_total, max_step1, ..., max_step10] 39 | """ 40 | # Validate each score entry is of tuple type 41 | true_tuples = [tuple(t) for t in true_tuples] 42 | pred_tuples = [tuple(p) for p in pred_tuples] 43 | 44 | # Construct unified scoring composition levels (with deduplication) 45 | unique_tuples = sorted(list(set(true_tuples + pred_tuples))) 46 | n_levels = len(unique_tuples) 47 | tuple_to_idx = {t: i for i, t in enumerate(unique_tuples)} 48 | 49 | # Build observation matrix O (actual score frequencies) 50 | O = np.zeros((n_levels, n_levels)) 51 | for t, p in zip(true_tuples, pred_tuples): 52 | i = tuple_to_idx[t] 53 | j = tuple_to_idx[p] 54 | O[i, j] += 1 55 | 56 | # Build expectation matrix E (theoretical score probabilities) 57 | row_sums = O.sum(axis=1) 58 | col_sums = O.sum(axis=0) 59 | E = np.outer(row_sums, col_sums) / np.sum(O) 60 | 61 | # Construct weight matrix W (normalized squared differences) 62 | max_scores = np.array(max_scores) 63 | weights = np.array(weights) 64 | 65 | W = np.zeros((n_levels, n_levels)) 66 | Ln = 1e-10 67 | for i, ti in enumerate(unique_tuples): 68 | for j, tj in enumerate(unique_tuples): 69 | ti = np.array(ti) 70 | tj = np.array(tj) 71 | diff = (ti - tj) / (max_scores + Ln) 72 | if i > len(weights) - 1: 73 | W[i, j] = np.sum(0) 74 | else: 75 | W[i, j] = np.sum(weights[i] * diff ** 2) 76 | 77 | # Compute adjusted Composite Consistency Score (CCS) 78 | ccs = 1 - np.sum(O * W) / np.sum(E * W) 79 | 80 | X = [item[0] for item in pred_tuples] 81 | Y = [item[0] for item in true_tuples] 82 | qwk = compute_QWK(X, Y) 83 | 84 | return ccs, qwk 85 | -------------------------------------------------------------------------------- /sas_pipelines/3_compute_ccs.py: -------------------------------------------------------------------------------- 1 | # %% 2 | import os 3 | import json 4 | import random 5 | import json_repair 6 | import numpy as np 7 | from tqdm import tqdm 8 | from argparse import ArgumentParser 9 | 10 | import sys 11 | sys.path.append("../") 12 | cmd_args = True 13 | 14 | parser = ArgumentParser() 15 | parser.add_argument('--file_dir', default='../datasets', help='the directory of the datasets.') 16 | parser.add_argument('--file_name', default='all', help='file name of the dataset without extension (e.g. 
0_Physics_ShortAns), you can define `all` for all dataset files.') 17 | parser.add_argument('--save_type_name', default='Deepseek', help='the prefix name of save dir (usually is the LLM name)') 18 | 19 | if not cmd_args: 20 | args = parser.parse_args([]) # You can directly set above parameters in the default. 21 | else: 22 | args = parser.parse_args() 23 | 24 | SOURCE_DIR = os.path.join(args.file_dir + '_' + args.save_type_name + '_Scored') 25 | 26 | from utils.collaborative_consistency_score import adjusted_qwk 27 | 28 | results = [] 29 | if args.file_name != 'all': 30 | files = [args.file_name + '_prediction.jsonl'] 31 | else: 32 | files = os.listdir(SOURCE_DIR) 33 | for file_name in files: 34 | if file_name.find('prediction') < 0: 35 | continue 36 | SOURCE_FILE = os.path.join(SOURCE_DIR, file_name) 37 | with open(SOURCE_FILE, encoding='utf-8') as f: 38 | ori_data = f.readlines() 39 | ori_data = [json.loads(line) for line in ori_data] 40 | 41 | pred_result = [] 42 | ori_result = [] 43 | weights_result = [] 44 | max_score = 0 45 | max_length = 0 46 | for item in tqdm(ori_data): 47 | if max_score < item['total']: 48 | max_score = item['total'] 49 | ori_steps = item['steps'] 50 | if 'pred_steps' not in item: 51 | pred_steps = [{"step_score": 0, "errors": []} for _ in ori_steps] 52 | else: 53 | pred_steps = item['pred_steps'] 54 | ori_score = item['manual_label'] 55 | if 'pred_label' not in item or item['pred_label'] == '': 56 | pred_score = 0 57 | else: 58 | pred_score = item['pred_label'] 59 | ori_labels = [int(ori_score)] 60 | pred_labels = [int(pred_score)] 61 | weights = [0.5] + [0.5 / len(ori_steps) for _ in range(len(ori_steps))] 62 | for i in range(len(ori_steps)): 63 | try: 64 | ori_labels.append(int(ori_steps[i]['label'])) 65 | except: 66 | ori_labels.append(0) 67 | if len(pred_steps) > i and type(pred_steps[i]) == dict and 'step_score' in pred_steps[i]: 68 | pred_labels.append(int(pred_steps[i]['step_score'])) 69 | else: 70 | pred_labels.append(0) 71 | if max_length < len(ori_labels): 72 | max_length = len(ori_labels) 73 | pred_result.append(pred_labels) 74 | ori_result.append(ori_labels) 75 | weights_result.append(weights) 76 | 77 | # Padding 78 | for i in range(len(pred_result)): 79 | pred_result[i] += [0] * (max_length - len(pred_result[i])) 80 | ori_result[i] += [0] * (max_length - len(ori_result[i])) 81 | weights_result[i] += [0] * (max_length - len(weights_result[i])) 82 | max_scores = [] 83 | for i in range(len(ori_result[0])): 84 | val_i = [] 85 | for item in ori_result: 86 | val_i.append(item[i]) 87 | for item in pred_result: 88 | val_i.append(item[i]) 89 | max_scores.append(max(val_i)) 90 | max_scores[0] = max_score 91 | results.append((file_name, adjusted_qwk(ori_result, pred_result, weights_result, max_scores))) 92 | 93 | # %% 94 | def sort_func(x): 95 | try: 96 | x = x[0] 97 | x = x.split('_')[0] 98 | x = int(x) 99 | return x 100 | except: 101 | return 0 102 | 103 | results = sorted(results, key=sort_func) 104 | for item in results: 105 | print(item) 106 | 107 | with open(f'{args.save_type_name}_ccs.csv', 'w+') as f: 108 | for item in results: 109 | cols = [item[0], str(round(item[1][0] * 100, 2)), str(round(item[1][1] * 100 ,2))] 110 | f.write(','.join(cols) + '\n') 111 | 112 | # %% 113 | -------------------------------------------------------------------------------- /sas_pipelines/2_process_prediction.py: -------------------------------------------------------------------------------- 1 | # %% 2 | import os 3 | import re 4 | import json 5 | import random 6 | import 
json_repair 7 | import numpy as np 8 | from tqdm import tqdm 9 | from argparse import ArgumentParser 10 | 11 | import sys 12 | sys.path.append("../") 13 | cmd_args = True 14 | 15 | parser = ArgumentParser() 16 | parser.add_argument('--file_dir', default='../datasets', help='the directory of the datasets.') 17 | parser.add_argument('--file_name', default='all', help='file name of the dataset without extension (e.g. 0_Physics_ShortAns), you can define `all` for all dataset files.') 18 | parser.add_argument('--save_type_name', default='Qwen3_32B', help='the prefix name of save dir (usually is the LLM name)') 19 | 20 | if not cmd_args: 21 | args = parser.parse_args([]) # You can directly set above parameters in the default. 22 | else: 23 | args = parser.parse_args() 24 | 25 | def unwrap_score_item(item, len_steps): 26 | item = json.loads(item) 27 | item = str(item) 28 | if item.find("{'total'") >= 0 and item.find('{"total"') < 0: 29 | item = item.replace('\'', '"') 30 | match_index = item.find('{"total"') 31 | if match_index >= 0: 32 | item = item[match_index:] 33 | item = item.replace('}```', '}') 34 | # Remove // Comments 35 | item = re.sub(r'//.*$', '', item, flags=re.MULTILINE) 36 | # Remove /* */ Comments 37 | item = re.sub(r'/\*[\s\S]*?\*/', '', item) 38 | item = json_repair.loads(item) 39 | if type(item) == list: 40 | item = item[0] 41 | else: 42 | item = { 43 | 'total': 0, 44 | 'pred_score': 0, 45 | 'steps': [{'step_score': 0, 'errors': []} for _ in range(len_steps)] 46 | } 47 | 48 | if 'steps' not in item or type(item['steps']) != list: 49 | item['steps'] = [{'step_score': 0, 'errors': []} for _ in range(len_steps)] 50 | 51 | if 'pred_score' not in item: 52 | item['pred_score'] = 0 53 | 54 | for step_idx, step in enumerate(item['steps']): 55 | if type(step) != dict: 56 | item['steps'][step_idx] = {} 57 | step = item['steps'][step_idx] 58 | if 'step_score' not in step: 59 | step['step_score'] = 0 60 | try: 61 | step['step_score'] = int(step['step_score']) 62 | except: 63 | step['step_score'] = 0 64 | if 'errors' not in step or type(step['errors']) != list: 65 | step['errors'] = [] 66 | return item 67 | 68 | SOURCE_DIR = os.path.join(args.file_dir) 69 | if args.file_name != 'all': 70 | files = [args.file_name + '.jsonl'] 71 | else: 72 | files = os.listdir(SOURCE_DIR) 73 | for file_name in files: 74 | if len(file_name.split('_')) < 3: 75 | continue 76 | SOURCE_FILE = os.path.join(SOURCE_DIR, file_name) 77 | with open(SOURCE_FILE, encoding='utf-8') as f: 78 | ori_data = f.readlines() 79 | ori_data = [json.loads(line) for line in ori_data] 80 | 81 | SCORED_FILE = os.path.join(args.file_dir + '_' + args.save_type_name + '_Scored', f'{file_name.split(".jsonl")[0]}_scored.jsonl') 82 | print(SCORED_FILE) 83 | if not os.path.exists(SCORED_FILE): 84 | continue 85 | with open(SCORED_FILE, encoding='utf-8') as f: 86 | scored_data = f.readlines() 87 | 88 | for item, score_item in tqdm(zip(ori_data, scored_data)): 89 | score_item = score_item.split('\t')[1] 90 | len_steps = len(item['steps']) 91 | score_item = unwrap_score_item(score_item, len_steps=len_steps) 92 | try: 93 | item['pred_label'] = int(score_item['pred_score']) 94 | except: 95 | item['pred_label'] = 0 96 | item['pred_steps'] = score_item['steps'] 97 | item['pred_steps'] = item['pred_steps'][:len_steps] 98 | if len(item['pred_steps']) < len_steps: 99 | for _ in range(len_steps - len(item['pred_steps'])): 100 | item['pred_steps'].append({'step_score': 0, 'errors': []}) 101 | 102 | SAVE_DIR = os.path.join(os.path.dirname(SCORED_FILE), 
f'{file_name.split(".jsonl")[0]}_prediction.jsonl') 103 | with open(SAVE_DIR, 'w', encoding='utf-8') as f: 104 | for item in ori_data: 105 | f.write(json.dumps(item, ensure_ascii=False) + '\n') 106 | 107 | # %% 108 | -------------------------------------------------------------------------------- /discuss/error_causes_f1.py: -------------------------------------------------------------------------------- 1 | # %% 2 | import os 3 | import json 4 | import random 5 | import json_repair 6 | import numpy as np 7 | from tqdm import tqdm 8 | from argparse import ArgumentParser 9 | 10 | import sys 11 | sys.path.append("../") 12 | cmd_args = True 13 | 14 | parser = ArgumentParser() 15 | parser.add_argument('--file_dir', default='../datasets', help='the directory of the datasets.') 16 | parser.add_argument('--file_name', default='all', help='file name of the dataset without extension (e.g. 0_Physics_ShortAns), you can define `all` for all dataset files.') 17 | parser.add_argument('--save_type_name', default='Deepseek', help='the prefix name of save dir (usually is the LLM name)') 18 | parser.add_argument('--skip_correct', default=True, help='batch size') 19 | 20 | if not cmd_args: 21 | args = parser.parse_args([]) # You can directly set above parameters in the default. 22 | else: 23 | args = parser.parse_args() 24 | 25 | SKIP_CORRECT = args.skip_correct 26 | # if it is english dataset, you may replace it with {'name': 'correct', 'description': 'the step is correct.'} 27 | # but it depends on how you predefined the `name` of the correct step. 28 | CORRECT_NAME = '步骤正确' 29 | CORRECT_DESCRIPTION = '该步骤正确' 30 | SOURCE_DIR = os.path.join(args.file_dir + '_' + args.save_type_name + '_Scored') 31 | 32 | from utils.errors_consistency_score import compute_ecs 33 | 34 | results = [] 35 | if args.file_name != 'all': 36 | files = [args.file_name + '_prediction.jsonl'] 37 | else: 38 | files = os.listdir(SOURCE_DIR) 39 | 40 | for file_name in files: 41 | if file_name.find('prediction') < 0: 42 | continue 43 | SOURCE_FILE = os.path.join(SOURCE_DIR, file_name) 44 | with open(SOURCE_FILE, encoding='utf-8') as f: 45 | ori_data = f.readlines() 46 | ori_data = [json.loads(line) for line in ori_data] 47 | 48 | ERROR_FILE = os.path.join(args.file_dir, 'error_type.jsonl') 49 | ID = file_name.split('_')[0] 50 | with open(ERROR_FILE, encoding='utf-8') as f: 51 | error_type_list = f.readlines() 52 | error_type_list = [json.loads(item) for item in error_type_list] 53 | error_type_item = [] 54 | score_guideline = '' 55 | for item in error_type_list: 56 | if str(item['q_id']) == str(ID): 57 | error_type_item = item['errors'] 58 | if not SKIP_CORRECT: 59 | 60 | error_type_item.append({'name': CORRECT_NAME, 'description': CORRECT_DESCRIPTION}) 61 | break 62 | 63 | error_to_id_dict = {} 64 | for i, item in enumerate(error_type_item): 65 | error_to_id_dict[item['name']] = i 66 | def err2idx(name): 67 | if name in error_to_id_dict: 68 | return error_to_id_dict[name] 69 | return len(error_to_id_dict) - 1 70 | 71 | tp = 0 72 | fp = 0 73 | fn = 0 74 | max_error_length = len(error_type_item) 75 | max_length = 0 76 | for item in tqdm(ori_data): 77 | ori_steps = item['steps'] 78 | if 'pred_steps' not in item: 79 | pred_steps = [{"step_score": 0, "errors": []} for _ in ori_steps] 80 | else: 81 | pred_steps = item['pred_steps'] 82 | p_errors = [0 for _ in range(max_error_length)] 83 | g_errors = [0 for _ in range(max_error_length)] 84 | for step in ori_steps: 85 | errors = step['errors'] 86 | for error in errors: 87 | if error == 
CORRECT_NAME and SKIP_CORRECT: 88 | continue 89 | g_errors[err2idx(error)] += 1 90 | for step in pred_steps: 91 | errors = step['errors'] 92 | for error in errors: 93 | if error == CORRECT_NAME and SKIP_CORRECT: 94 | continue 95 | p_errors[err2idx(str(error))] += 1 96 | for p, g in zip(p_errors, g_errors): 97 | if p > 0: 98 | if g > 0: 99 | tp += 1 100 | else: 101 | fp += 1 102 | else: 103 | if g > 0: 104 | fn += 1 105 | 106 | P = tp / (tp + fp) 107 | R = tp / (tp + fn) 108 | F1 = (2 * P * R) / (P + R) 109 | 110 | results.append((file_name, F1, P, R)) 111 | 112 | # %% 113 | def sort_func(x): 114 | try: 115 | x = x[0] 116 | x = x.split('_')[0] 117 | x = int(x) 118 | return x 119 | except: 120 | return 0 121 | 122 | results = sorted(results, key=sort_func) 123 | for item in results: 124 | print(item) 125 | 126 | with open(f'{args.save_type_name}_ef1.csv', 'w+') as f: 127 | for item in results: 128 | cols = [item[0], str(round(item[1] * 100, 2)), str(round(item[2] * 100 ,2)), str(round(item[3] * 100 ,2))] 129 | f.write(','.join(cols) + '\n') 130 | 131 | # %% 132 | -------------------------------------------------------------------------------- /sas_pipelines/4_compute_ecs.py: -------------------------------------------------------------------------------- 1 | # %% 2 | import os 3 | import json 4 | import random 5 | import json_repair 6 | import numpy as np 7 | from tqdm import tqdm 8 | from argparse import ArgumentParser 9 | 10 | import sys 11 | sys.path.append("../") 12 | cmd_args = True 13 | 14 | parser = ArgumentParser() 15 | parser.add_argument('--file_dir', default='../datasets', help='the directory of the datasets.') 16 | parser.add_argument('--file_name', default='all', help='file name of the dataset without extension (e.g. 0_Physics_ShortAns), you can define `all` for all dataset files.') 17 | parser.add_argument('--save_type_name', default='Deepseek', help='the prefix name of save dir (usually is the LLM name)') 18 | parser.add_argument('--skip_correct', default=True, help='batch size') 19 | 20 | if not cmd_args: 21 | args = parser.parse_args([]) # You can directly set above parameters in the default. 22 | else: 23 | args = parser.parse_args() 24 | 25 | SKIP_CORRECT = args.skip_correct 26 | SOURCE_DIR = os.path.join(args.file_dir + '_' + args.save_type_name + '_Scored') 27 | 28 | # if it is english dataset, you may replace it with {'name': 'correct', 'description': 'the step is correct.'} 29 | # but it depends on how you predefined the `name` of the correct step. 
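# an error_type.jsonl line is expected to look roughly like (shape inferred from the loading code below):
# {"q_id": 0, "errors": [{"name": "计算错误", "description": "..."}]}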
30 | CORRECT_NAME = '步骤正确' 31 | CORRECT_DESCRIPTION = '该步骤正确' 32 | 33 | from utils.errors_consistency_score import compute_ecs 34 | 35 | results = [] 36 | if args.file_name != 'all': 37 | files = [args.file_name + '_prediction.jsonl'] 38 | else: 39 | files = os.listdir(SOURCE_DIR) 40 | 41 | for file_name in files: 42 | if file_name.find('prediction') < 0: 43 | continue 44 | SOURCE_FILE = os.path.join(SOURCE_DIR, file_name) 45 | with open(SOURCE_FILE, encoding='utf-8') as f: 46 | ori_data = f.readlines() 47 | ori_data = [json.loads(line) for line in ori_data] 48 | 49 | ERROR_FILE = os.path.join(args.file_dir, 'error_type.jsonl') 50 | ID = file_name.split('_')[0] 51 | with open(ERROR_FILE, encoding='utf-8') as f: 52 | error_type_list = f.readlines() 53 | error_type_list = [json.loads(item) for item in error_type_list] 54 | error_type_item = [] 55 | score_guideline = '' 56 | for item in error_type_list: 57 | if str(item['q_id']) == str(ID): 58 | error_type_item = item['errors'] 59 | error_type_item.append({'name': CORRECT_NAME, 'description': CORRECT_DESCRIPTION}) 60 | break 61 | 62 | error_to_id_dict = {} 63 | for i, item in enumerate(error_type_item): 64 | error_to_id_dict[item['name']] = i 65 | def err2idx(name): 66 | if name in error_to_id_dict: 67 | return error_to_id_dict[name] 68 | return len(error_to_id_dict) - 1 69 | 70 | pred_scores = [] 71 | ori_scores = [] 72 | 73 | total_scores = [] 74 | 75 | pred_errors = [] 76 | gold_errors = [] 77 | max_error_length = len(error_type_item) 78 | max_length = 0 79 | for item in tqdm(ori_data): 80 | total_scores.append(item['total']) 81 | ori_score = item['manual_label'] 82 | if 'pred_label' not in item or item['pred_label'] == '': 83 | pred_score = 0 84 | else: 85 | pred_score = item['pred_label'] 86 | pred_scores.append(int(pred_score)) 87 | ori_scores.append(int(ori_score)) 88 | 89 | ori_steps = item['steps'] 90 | if 'pred_steps' not in item: 91 | pred_steps = [{"step_score": 0, "errors": []} for _ in ori_steps] 92 | else: 93 | pred_steps = item['pred_steps'] 94 | p_errors = [0 for _ in range(max_error_length)] 95 | g_errors = [0 for _ in range(max_error_length)] 96 | for step in ori_steps: 97 | errors = step['errors'] 98 | for error in errors: 99 | if error == CORRECT_NAME and SKIP_CORRECT: 100 | continue 101 | g_errors[err2idx(error)] += 1 102 | for step in pred_steps: 103 | errors = step['errors'] 104 | for error in errors: 105 | if error == CORRECT_NAME and SKIP_CORRECT: 106 | continue 107 | p_errors[err2idx(str(error))] += 1 108 | pred_errors.append(p_errors) 109 | gold_errors.append(g_errors) 110 | 111 | results.append((file_name, compute_ecs(pred_scores=pred_scores, ori_scores=ori_scores, total_scores=total_scores, pred_errors=pred_errors, gold_errors=gold_errors, max_error=max_error_length))) 112 | 113 | # %% 114 | def sort_func(x): 115 | try: 116 | x = x[0] 117 | x = x.split('_')[0] 118 | x = int(x) 119 | return x 120 | except: 121 | return 0 122 | 123 | results = sorted(results, key=sort_func) 124 | for item in results: 125 | print(item) 126 | 127 | with open(f'{args.save_type_name}_ecs.csv', 'w+') as f: 128 | for item in results: 129 | cols = [item[0], str(round(item[1][0] * 100, 2)), str(round(item[1][1][0] * 100 ,2)), str(round(item[1][1][1] * 100 ,2)), str(round(item[1][1][2] * 100 ,2))] 130 | f.write(','.join(cols) + '\n') 131 | 132 | # %% 133 | -------------------------------------------------------------------------------- /main/predictor/vllm.py: -------------------------------------------------------------------------------- 
1 | import json 2 | import torch 3 | from transformers import AutoTokenizer, AutoConfig 4 | from vllm import LLM, SamplingParams 5 | from typing import Tuple, List 6 | 7 | 8 | class Predictor(): 9 | 10 | def __init__(self, 11 | tensor_parallel_size: int = 1, 12 | model_from_pretrained: str = None, 13 | **args 14 | ): 15 | ''' 16 | Predictor: LLM预测器 (LLM predictor) 17 | 18 | ### Args: 19 | 20 | `tensor_parallel_size`: 张量并行使用的 GPU 数量(模型被拆分到多个卡) 21 | 22 | `model_config_file_name`: bert配置文件名 (bert config file name) 23 | ''' 24 | self.tp = tensor_parallel_size 25 | self.model_from_pretrained = model_from_pretrained 26 | self.model_init() 27 | 28 | def model_init(self): 29 | self.config = AutoConfig.from_pretrained( 30 | self.model_from_pretrained, trust_remote_code=True) 31 | self.tokenizer = AutoTokenizer.from_pretrained( 32 | self.model_from_pretrained, padding_side="left", trust_remote_code=True) 33 | self.llm = LLM(self.model_from_pretrained, tensor_parallel_size=self.tp, trust_remote_code=True) 34 | if hasattr(self.config, 'eos_token_id'): 35 | self.eos_token_id = [self.config.eos_token_id] 36 | if hasattr(self.config, 'bos_token_id'): 37 | self.bos_token_id = [self.config.bos_token_id] 38 | if self.config.model_type == 'chatglm': 39 | self.eos_token_id = self.config.eos_token_id 40 | elif self.config.model_type == 'llama': 41 | self.tokenizer.pad_token = self.tokenizer.eos_token 42 | terminators = [ 43 | self.tokenizer.eos_token_id, 44 | self.tokenizer.convert_tokens_to_ids("<|eot_id|>") 45 | ] 46 | if not hasattr(self, 'eos_token_id'): 47 | self.eos_token_id = [] 48 | for t in terminators: 49 | if t is not None: 50 | self.eos_token_id.append(t) 51 | elif self.config.model_type == "mimo": 52 | self.eos_token_id = self.config.eos_token_id 53 | self.bos_token_id = self.config.bos_token_id 54 | elif self.config.model_type == "tinyr1": 55 | self.eos_token_id = self.config.eos_token_id 56 | self.bos_token_id = self.config.bos_token_id 57 | 58 | def process_model_outputs(self, inputs, outputs, tokenizer): 59 | responses = [] 60 | for input_ids, output_ids in zip(inputs['input_ids'], outputs): 61 | response = tokenizer.decode(output_ids[len(input_ids):], skip_special_tokens=True).strip() 62 | responses.append(response) 63 | return responses 64 | 65 | def build_chat_input(self, query:str, history=None): 66 | if history is None: 67 | history = [] 68 | history.append(query) 69 | max_input_tokens = 0 70 | new_batch_input = self.tokenizer.apply_chat_template(history, add_generation_prompt=True, tokenize=False) 71 | max_input_tokens = max(max_input_tokens, len(new_batch_input)) 72 | return new_batch_input, max_input_tokens 73 | 74 | def predict(self, query: str | list = '', history: List = None, max_length=512, max_new_tokens=512, num_beams:int=1, top_p: float = 1.0, temperature=0, do_sample: bool = False, build_message=False): 75 | if not isinstance(query, list): 76 | query = [query] 77 | history = [history] if history is not None else None 78 | if build_message: 79 | inputs = [] 80 | batch_max_len = 0 81 | for i, t in enumerate(query): 82 | if isinstance(t, str): 83 | t = {'role': 'user', 'content': t} 84 | if history is not None and len(history) > 0: 85 | h_unit = history[i] 86 | else: 87 | h_unit = [] 88 | t, max_input_tokens = self.build_chat_input(t, h_unit) 89 | if batch_max_len < max_input_tokens: 90 | batch_max_len = max_input_tokens 91 | inputs.append(t) 92 | else: 93 | inputs = query 94 | batch_max_len = 0 95 | for i in range(len(query)): 96 | if len(query[i]) > batch_max_len: 97 | 
batch_max_len = len(query[i]) 98 | 99 | sampling_params = SamplingParams(temperature=temperature, top_p=top_p, max_tokens=max_new_tokens) 100 | 101 | outputs = self.llm.generate(inputs, sampling_params) 102 | results = [] 103 | for output in outputs: 104 | prompt = output.prompt 105 | generated_text = output.outputs[0].text 106 | results.append(generated_text) 107 | return results 108 | 109 | def __call__(self, query: str | list = '', history: List = None, max_length=512, max_new_tokens=512, num_beams:int=1, top_p: float = 0.8, temperature=1.0, do_sample: bool = False, build_message=False): 110 | return self.predict(query, history, max_length, max_new_tokens, num_beams, top_p, temperature, do_sample, build_message) 111 | -------------------------------------------------------------------------------- /main/loaders/chatglm_rlhf.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import torch 4 | import random 5 | import pickle 6 | import numpy as np 7 | from tqdm import tqdm 8 | from torch.utils.data import Dataset, DataLoader 9 | from transformers import PretrainedConfig, PreTrainedTokenizer 10 | 11 | 12 | class ChatGLM_RLHFDataset(Dataset): 13 | 14 | config: PretrainedConfig 15 | tokenizer: PreTrainedTokenizer 16 | 17 | def __init__(self, tokenizer, config, file_name, max_length=512, do_shuffle=False): 18 | self.config = config 19 | self.tokenizer = tokenizer 20 | self.max_length = max_length 21 | self.do_shuffle = do_shuffle 22 | self.data = self.load_jsonl(file_name) 23 | self.random_list = [idx for idx in range(len(self.data))] 24 | self.preprocess_list = [self.process_item(item) for item in self.data] 25 | if self.do_shuffle: 26 | random.shuffle(self.random_list) 27 | 28 | def load_jsonl(self, file_name): 29 | with open(file_name, 'r') as f: 30 | lines = f.readlines() 31 | data = [json.loads(line) for line in lines] 32 | return data 33 | 34 | def process_item(self, item): 35 | conv = item['conversations'] if 'conversations' in item else item 36 | 37 | input_ids, loss_masks = [ 38 | self.tokenizer.get_command('[gMASK]'), 39 | self.tokenizer.get_command('sop'), 40 | ], [False, False] 41 | 42 | last_input_len = 0 43 | last_user_content = '' 44 | last_assistant_content = '' 45 | 46 | for message in conv: 47 | if message['role'] in ('system', 'user'): 48 | loss_mask_val = False 49 | else: 50 | loss_mask_val = True 51 | 52 | if message['role'] == 'tool': 53 | raise NotImplementedError() 54 | else: 55 | new_input_ids = self.tokenizer.build_single_message( 56 | message['role'], '', message['content'] 57 | ) 58 | new_loss_masks = [loss_mask_val] * len(new_input_ids) 59 | if message['role'] == 'user': 60 | last_user_content = message['content'] 61 | last_input_len = len(new_input_ids) 62 | elif message['role'] == 'assistant': 63 | last_assistant_content = message['content'] 64 | role_ids = self.tokenizer.build_single_message( 65 | message['role'], '', '' 66 | ) 67 | last_input_len = len(new_input_ids) - len(role_ids) 68 | 69 | input_ids += new_input_ids 70 | loss_masks += new_loss_masks 71 | 72 | input_ids.append(self.tokenizer.eos_token_id) 73 | loss_masks = [False, *loss_masks] 74 | labels = [] 75 | for input_id, mask in zip(input_ids, loss_masks): 76 | if mask: 77 | labels.append(input_id) 78 | else: 79 | labels.append(-100) 80 | max_length = self.max_length 81 | 82 | input_ids_without_last_turn = input_ids[:-last_input_len] 83 | labels_without_last_turn = labels[:-last_input_len] 84 | 85 | exceed_len = len(input_ids) - 
max_length 86 | if exceed_len > 0: 87 | last_input_len = last_input_len - exceed_len 88 | input_ids = input_ids[:max_length] 89 | labels = labels[:max_length] 90 | input_ids_without_last_turn = input_ids_without_last_turn[:max_length] 91 | labels_without_last_turn = labels_without_last_turn[:max_length] 92 | return { 93 | 'input_ids': input_ids, 94 | 'labels': labels, 95 | 'input_ids_without_last_turn': input_ids_without_last_turn, 96 | 'labels_without_last_turn': labels_without_last_turn, 97 | 'last_input_len': last_input_len, 98 | 'last_user_content': last_user_content, 99 | 'last_assistant_content': last_assistant_content} 100 | 101 | def __getitem__(self, index): 102 | index = self.random_list[index] 103 | item = self.data[index] 104 | input_ids, labels, input_ids_without_last_turn, labels_without_last_turn, last_input_len, last_user_content, last_assistant_content = self.preprocess_list[index].values( 105 | ) 106 | gold_answers = item['gold_answers'] 107 | bad_answers = item['bad_answers'] 108 | 109 | input_ids = torch.tensor(input_ids) 110 | labels = torch.tensor(labels) 111 | input_ids_without_last_turn = torch.tensor(input_ids_without_last_turn) 112 | labels_without_last_turn = torch.tensor(labels_without_last_turn) 113 | 114 | return { 115 | 'query': last_user_content, 116 | 'gold_answers': gold_answers, 117 | 'bad_answers': bad_answers, 118 | 'input_ids': input_ids, 119 | 'input_ids_without_last_turn': input_ids_without_last_turn, 120 | 'labels_without_last_turn': labels_without_last_turn, 121 | 'last_input_len': torch.tensor(last_input_len), 122 | 'last_assistant_content': last_assistant_content, 123 | 'labels': labels 124 | } 125 | 126 | def __len__(self): 127 | return len(self.data) 128 | -------------------------------------------------------------------------------- /main/predictor/chatglm.py: -------------------------------------------------------------------------------- 1 | import json 2 | import torch 3 | from transformers import AutoTokenizer, AutoModel 4 | from typing import Tuple, List 5 | 6 | 7 | class Predictor(): 8 | 9 | def __init__(self, 10 | num_gpus: list = [0], 11 | model_from_pretrained: str = None, 12 | **args 13 | ): 14 | ''' 15 | Predictor: ChatGLM预测器 (ChatGLM predictor) 16 | 17 | ### Args: 18 | 19 | `num_gpus`: 使用的GPU编号列表 (the list of GPU numbers) 20 | 21 | `model_config_file_name`: bert配置文件名 (bert config file name) 22 | ''' 23 | self.num_gpus = num_gpus 24 | self.model_from_pretrained = model_from_pretrained 25 | self.model_init() 26 | 27 | def model_init(self): 28 | self.tokenizer = AutoTokenizer.from_pretrained( 29 | self.model_from_pretrained, trust_remote_code=True) 30 | self.model = AutoModel.from_pretrained( 31 | self.model_from_pretrained, trust_remote_code=True).half().cuda() 32 | self.model_to_device(gpu=self.num_gpus) 33 | self.model = self.model.eval() 34 | 35 | def model_to_device(self, gpu=[0]): 36 | self.device = torch.device( 37 | "cuda:0" if torch.cuda.is_available() else "cpu") 38 | self.model.cuda() 39 | self.model = torch.nn.DataParallel(self.model, device_ids=gpu).cuda() 40 | self.model.to(self.device) 41 | self.true_model = self.model.module if hasattr( 42 | self.model, 'module') else self.model 43 | 44 | def build_chat_input(self, query, history=None, role="user"): 45 | if history is None: 46 | history = [] 47 | input_ids = [] 48 | for item in history: 49 | content = item["content"] 50 | if item["role"] == "system" and "tools" in item: 51 | content = content + "\n" + json.dumps(item["tools"], indent=4, ensure_ascii=False) 52 | 
input_ids.extend(self.tokenizer.build_single_message(item["role"], item.get("metadata", ""), content)) 53 | input_ids.extend(self.tokenizer.build_single_message(role, "", query)) 54 | input_ids.extend([self.tokenizer.get_command("<|assistant|>")]) 55 | return input_ids 56 | 57 | def predict(self, query: str | list = '', history: List = None, max_new_tokens=512, num_beams:int=1, top_p: float = 0.8, temperature=1.0, do_sample: bool = False, build_message=False): 58 | if isinstance(query, str): 59 | query = [query] 60 | history = [history] if history is not None else None 61 | with torch.no_grad(): 62 | if build_message: 63 | inputs = [] 64 | batch_max_len = 0 65 | for i, t in enumerate(query): 66 | if history is not None and len(history) > 0: 67 | h_unit = history[i] 68 | t = self.build_chat_input(t, h_unit) 69 | else: 70 | t = self.tokenizer.build_single_message("user", "", t) 71 | t.extend([self.tokenizer.get_command("<|assistant|>")]) 72 | if batch_max_len < len(t): 73 | batch_max_len = len(t) 74 | inputs.append(t) 75 | for idx, t in enumerate(inputs): 76 | remain = batch_max_len - len(t) 77 | inputs[idx] = [self.tokenizer.pad_token_id] * remain + t 78 | else: 79 | inputs = self.tokenizer( 80 | query, padding=True, truncation=True)['input_ids'] 81 | input_ids = torch.LongTensor(inputs).to(self.device) 82 | output = self.true_model.generate(**{ 83 | 'input_ids': input_ids, 84 | 'max_new_tokens': max_new_tokens, 85 | 'num_beams': num_beams, 86 | 'do_sample': do_sample, 87 | 'top_p': top_p, 88 | "temperature": temperature, 89 | "eos_token_id": self.true_model.config.eos_token_id 90 | }) 91 | out_text = self.tokenizer.batch_decode( 92 | output, skip_special_tokens=True) 93 | if build_message: 94 | out_text = [self.true_model.process_response(t, [])[0] for t in out_text] 95 | return out_text 96 | 97 | @torch.inference_mode() 98 | def chat(self, query: str, history: List[Tuple[str, str]] = None, role: str = "user", 99 | max_length: int = 8192, num_beams=1, do_sample=True, top_p=0.8, temperature=0.8, logits_processor=None, 100 | **kwargs): 101 | response, history = self.true_model.chat( 102 | self.tokenizer, query, history, role, max_length, num_beams, do_sample, top_p, temperature, logits_processor, **kwargs) 103 | return response, history 104 | 105 | @torch.inference_mode() 106 | def stream_chat(self, query: str, history: List[Tuple[str, str]] = None, role: str = "user", 107 | past_key_values=None, max_length: int = 8192, do_sample=True, top_p=0.8, temperature=0.8, 108 | logits_processor=None, return_past_key_values=False, **kwargs): 109 | for result in self.true_model.stream_chat(self.tokenizer, query, history, role, past_key_values, max_length, do_sample, top_p, temperature, logits_processor, return_past_key_values, **kwargs): 110 | yield result 111 | 112 | def __call__(self, query: str | list = '', history: List = None, max_new_tokens=512, num_beams:int=1, top_p: float = 0.8, temperature=1.0, do_sample: bool = False, build_message=False): 113 | return self.predict(query, history, max_new_tokens, num_beams, top_p, temperature, do_sample, build_message) 114 | -------------------------------------------------------------------------------- /main/predictor/chatglm_lora.py: -------------------------------------------------------------------------------- 1 | from transformers import AutoModel 2 | from transformers import AutoTokenizer 3 | from peft import LoraConfig, TaskType, PeftModel, PeftModelForCausalLM 4 | from typing import Tuple, List 5 | import json 6 | import torch 7 | 8 | 9 | class 
Predictor(): 10 | true_model: PeftModelForCausalLM 11 | 12 | def __init__(self, 13 | num_gpus: list = [0], 14 | model_from_pretrained: str = None, 15 | resume_path: str = None 16 | ): 17 | peft_config = LoraConfig( 18 | task_type=TaskType.CAUSAL_LM, 19 | inference_mode=False, 20 | r=16, 21 | lora_alpha=32, 22 | lora_dropout=0.1, 23 | ) 24 | self.tokenizer = AutoTokenizer.from_pretrained( 25 | model_from_pretrained, trust_remote_code=True) 26 | self.model = AutoModel.from_pretrained( 27 | model_from_pretrained, trust_remote_code=True).half().cuda() 28 | self.model = PeftModel.from_pretrained( 29 | self.model, resume_path, config=peft_config) 30 | self.model_to_device(gpu=num_gpus) 31 | 32 | def model_to_device(self, gpu=[0]): 33 | self.device = torch.device( 34 | "cuda:0" if torch.cuda.is_available() else "cpu") 35 | self.model.cuda() 36 | self.model = torch.nn.DataParallel(self.model, device_ids=gpu).cuda() 37 | self.model.to(self.device) 38 | self.true_model = self.model.module if hasattr( 39 | self.model, 'module') else self.model 40 | 41 | def build_chat_input(self, query, history=None, role="user"): 42 | if history is None: 43 | history = [] 44 | input_ids = [] 45 | for item in history: 46 | content = item["content"] 47 | if item["role"] == "system" and "tools" in item: 48 | content = content + "\n" + json.dumps(item["tools"], indent=4, ensure_ascii=False) 49 | input_ids.extend(self.tokenizer.build_single_message(item["role"], item.get("metadata", ""), content)) 50 | input_ids.extend(self.tokenizer.build_single_message(role, "", query)) 51 | input_ids.extend([self.tokenizer.get_command("<|assistant|>")]) 52 | return input_ids 53 | 54 | def predict(self, query: str | list = '', history: List = None, max_new_tokens=512, num_beams:int=1, top_p: float = 0.8, temperature=1.0, do_sample: bool = False, build_message=False): 55 | if isinstance(query, str): 56 | query = [query] 57 | history = [history] if history is not None else None 58 | with torch.no_grad(): 59 | if build_message: 60 | inputs = [] 61 | batch_max_len = 0 62 | for i, t in enumerate(query): 63 | if history is not None and len(history) > 0: 64 | h_unit = history[i] 65 | t = self.build_chat_input(t, h_unit) 66 | else: 67 | t = self.tokenizer.build_single_message("user", "", t) 68 | t.extend([self.tokenizer.get_command("<|assistant|>")]) 69 | if batch_max_len < len(t): 70 | batch_max_len = len(t) 71 | inputs.append(t) 72 | for idx, t in enumerate(inputs): 73 | remain = batch_max_len - len(t) 74 | inputs[idx] = [self.tokenizer.pad_token_id] * remain + t 75 | else: 76 | inputs = self.tokenizer( 77 | query, padding=True, truncation=True)['input_ids'] 78 | input_ids = torch.LongTensor(inputs).to(self.device) 79 | output = self.true_model.generate(**{ 80 | 'input_ids': input_ids, 81 | 'max_new_tokens': max_new_tokens, 82 | 'num_beams': num_beams, 83 | 'do_sample': do_sample, 84 | 'top_p': top_p, 85 | "temperature": temperature, 86 | 'do_sample': False 87 | }) 88 | out_text = self.tokenizer.batch_decode( 89 | output, skip_special_tokens=True) 90 | if build_message: 91 | out_text = [self.true_model.process_response(t, [])[0] for t in out_text] 92 | return out_text 93 | 94 | @torch.inference_mode() 95 | def chat(self, query: str, history: List[Tuple[str, str]] = None, role: str = "user", 96 | max_length: int = 8192, num_beams=1, do_sample=True, top_p=0.8, temperature=0.8, logits_processor=None, 97 | **kwargs): 98 | response, history = self.true_model.chat( 99 | self.tokenizer, query, history, role, max_length, num_beams, do_sample, top_p, 
temperature, logits_processor, **kwargs) 100 | return response, history 101 | 102 | @torch.inference_mode() 103 | def stream_chat(self, query: str, history: List[Tuple[str, str]] = None, role: str = "user", 104 | past_key_values=None, max_length: int = 8192, do_sample=True, top_p=0.8, temperature=0.8, 105 | logits_processor=None, return_past_key_values=False, **kwargs): 106 | for result in self.true_model.stream_chat(self.tokenizer, query, history, role, past_key_values, max_length, do_sample, top_p, temperature, logits_processor, return_past_key_values, **kwargs): 107 | yield result 108 | 109 | def __call__(self, query: str | list = '', history: List = None, max_new_tokens=512, num_beams:int=1, top_p: float = 0.8, temperature=1.0, do_sample: bool = False, build_message=False): 110 | return self.predict(query=query, history=history, max_new_tokens=max_new_tokens, num_beams=num_beams, top_p=top_p, temperature=temperature, do_sample=do_sample, build_message=build_message) 111 | -------------------------------------------------------------------------------- /discuss/compute_range_ccs.py: -------------------------------------------------------------------------------- 1 | # %% 2 | import os 3 | import json 4 | import random 5 | import json_repair 6 | import numpy as np 7 | from tqdm import tqdm 8 | from argparse import ArgumentParser 9 | 10 | import sys 11 | sys.path.append("../") 12 | cmd_args = True 13 | 14 | parser = ArgumentParser() 15 | parser.add_argument('--file_dir', default='../datasets', help='the directory of the datasets.') 16 | parser.add_argument('--file_name', default='all', help='file name of the dataset without extension (e.g. 0_Physics_ShortAns), you can define `all` for all dataset files.') 17 | parser.add_argument('--save_type_name', default='Deepseek', help='the prefix name of save dir (usually is the LLM name)') 18 | 19 | if not cmd_args: 20 | args = parser.parse_args([]) # You can directly set above parameters in the default. 
21 | else: 22 | args = parser.parse_args() 23 | 24 | SOURCE_DIR = os.path.join(args.file_dir + '_' + args.save_type_name + '_Scored') 25 | 26 | from utils.collaborative_consistency_score import adjusted_qwk 27 | 28 | results = [] 29 | if args.file_name != 'all': 30 | files = [args.file_name + '_prediction.jsonl'] 31 | else: 32 | files = os.listdir(SOURCE_DIR) 33 | for file_name in files: 34 | if file_name.find('prediction') < 0: 35 | continue 36 | SOURCE_FILE = os.path.join(SOURCE_DIR, file_name) 37 | with open(SOURCE_FILE, encoding='utf-8') as f: 38 | ori_data = f.readlines() 39 | ori_data = [json.loads(line) for line in ori_data] 40 | 41 | pred_result = [] 42 | ori_result = [] 43 | weights_result = [] 44 | total_scores = [] 45 | max_score = 0 46 | max_length = 0 47 | for item in tqdm(ori_data): 48 | if max_score < item['total']: 49 | max_score = item['total'] 50 | total_scores.append(item['total']) 51 | ori_steps = item['steps'] 52 | if 'pred_steps' not in item: 53 | pred_steps = [{"step_score": 0, "errors": []} for _ in ori_steps] 54 | else: 55 | pred_steps = item['pred_steps'] 56 | ori_score = item['manual_label'] 57 | if 'pred_label' not in item or item['pred_label'] == '': 58 | pred_score = 0 59 | else: 60 | pred_score = item['pred_label'] 61 | ori_labels = [int(ori_score)] 62 | pred_labels = [int(pred_score)] 63 | weights = [0.5] + [0.5 / len(ori_steps) for _ in range(len(ori_steps))] 64 | for i in range(len(ori_steps)): 65 | try: 66 | ori_labels.append(int(ori_steps[i]['label'])) 67 | except: 68 | ori_labels.append(0) 69 | if len(pred_steps) > i and type(pred_steps[i]) == dict and 'step_score' in pred_steps[i]: 70 | pred_labels.append(int(pred_steps[i]['step_score'])) 71 | else: 72 | pred_labels.append(0) 73 | if max_length < len(ori_labels): 74 | max_length = len(ori_labels) 75 | pred_result.append(pred_labels) 76 | ori_result.append(ori_labels) 77 | weights_result.append(weights) 78 | 79 | # Padding 80 | for i in range(len(pred_result)): 81 | pred_result[i] += [0] * (max_length - len(pred_result[i])) 82 | ori_result[i] += [0] * (max_length - len(ori_result[i])) 83 | weights_result[i] += [0] * (max_length - len(weights_result[i])) 84 | max_scores = [] 85 | for i in range(len(ori_result[0])): 86 | val_i = [] 87 | for item in ori_result: 88 | val_i.append(item[i]) 89 | for item in pred_result: 90 | val_i.append(item[i]) 91 | max_scores.append(max(val_i)) 92 | max_scores[0] = max_score 93 | 94 | all_norm_scores = [] 95 | Ln = 1e-10 96 | for scores, max_s in zip(ori_result, total_scores): 97 | norm_score = scores[0] / (max_s + Ln) 98 | all_norm_scores.append(norm_score) 99 | range1, range2 = 0, 0 100 | sorted_scores = sorted(all_norm_scores) 101 | range1 = sorted_scores[int(len(sorted_scores) * 0.33)] 102 | range2 = sorted_scores[int(len(sorted_scores) * 0.67)] 103 | low_list = { 104 | 'ori_result': [], 105 | 'pred_result': [], 106 | 'weights_result': [], 107 | 'max_scores': max_scores 108 | } 109 | mid_list = { 110 | 'ori_result': [], 111 | 'pred_result': [], 112 | 'weights_result': [], 113 | 'max_scores': max_scores 114 | } 115 | high_list = { 116 | 'ori_result': [], 117 | 'pred_result': [], 118 | 'weights_result': [], 119 | 'max_scores': max_scores 120 | } 121 | for i in range(len(total_scores)): 122 | total = total_scores[i] 123 | gold = ori_result[i][0] 124 | gold = gold / total 125 | if gold <= range1: 126 | low_list['ori_result'].append(ori_result[i]) 127 | low_list['pred_result'].append(pred_result[i]) 128 | low_list['weights_result'].append(weights_result[i]) 129 | elif gold 
< range2: 130 | mid_list['ori_result'].append(ori_result[i]) 131 | mid_list['pred_result'].append(pred_result[i]) 132 | mid_list['weights_result'].append(weights_result[i]) 133 | else: 134 | high_list['ori_result'].append(ori_result[i]) 135 | high_list['pred_result'].append(pred_result[i]) 136 | high_list['weights_result'].append(weights_result[i]) 137 | results.append((file_name, adjusted_qwk(*low_list.values())[0], adjusted_qwk(*mid_list.values())[0], adjusted_qwk(*high_list.values())[0])) 138 | 139 | # %% 140 | def sort_func(x): 141 | try: 142 | x = x[0] 143 | x = x.split('_')[0] 144 | x = int(x) 145 | return x 146 | except: 147 | return 0 148 | 149 | results = sorted(results, key=sort_func) 150 | for item in results: 151 | print(item) 152 | 153 | with open(f'{args.save_type_name}_range_ccs.csv', 'w+') as f: 154 | for item in results: 155 | cols = [item[0], str(round(item[1] * 100, 2)), str(round(item[2] * 100 ,2)), str(round(item[3] * 100 ,2))] 156 | f.write(','.join(cols) + '\n') 157 | 158 | # %% 159 | -------------------------------------------------------------------------------- /main/predictor/llm_lora.py: -------------------------------------------------------------------------------- 1 | from transformers import AutoModel 2 | from transformers import AutoTokenizer, AutoConfig, LlamaForCausalLM, AutoModelForCausalLM 3 | from peft import LoraConfig, TaskType, PeftModel, PeftModelForCausalLM 4 | from typing import Tuple, List 5 | import json 6 | import torch 7 | 8 | 9 | class Predictor(): 10 | true_model: PeftModelForCausalLM 11 | 12 | def __init__(self, 13 | num_gpus: list = [0], 14 | model_from_pretrained: str = None, 15 | resume_path: str = None, 16 | lora_r=16, lora_alpha=32, lora_dropout=0.1, 17 | ): 18 | peft_config = LoraConfig( 19 | task_type=TaskType.CAUSAL_LM, 20 | inference_mode=False, 21 | r=lora_r, 22 | lora_alpha=lora_alpha, 23 | lora_dropout=lora_dropout, 24 | ) 25 | self.config = AutoConfig.from_pretrained( 26 | model_from_pretrained, trust_remote_code=True) 27 | self.tokenizer = AutoTokenizer.from_pretrained( 28 | model_from_pretrained, trust_remote_code=True) 29 | 30 | if self.config.model_type == 'chatglm': 31 | self.model = AutoModel.from_pretrained( 32 | self.model_from_pretrained, trust_remote_code=True).to(torch.bfloat16) 33 | self.eos_token_id = self.config.eos_token_id 34 | elif self.config.model_type == 'llama': 35 | self.tokenizer.pad_token = self.tokenizer.eos_token 36 | self.model = LlamaForCausalLM.from_pretrained( 37 | self.model_from_pretrained, trust_remote_code=True).to(torch.bfloat16) 38 | terminators = [ 39 | self.tokenizer.eos_token_id, 40 | self.tokenizer.convert_tokens_to_ids("<|eot_id|>") 41 | ] 42 | self.eos_token_id = terminators 43 | elif self.config.model_type == 'qwen': 44 | self.model = AutoModelForCausalLM.from_pretrained( 45 | self.model_from_pretrained, torch_dtype="auto", device_map="auto", trust_remote_code=True) 46 | elif self.config.model_type == 'qwen2': 47 | self.model = AutoModelForCausalLM.from_pretrained( 48 | self.model_from_pretrained, torch_dtype="auto", device_map="auto", trust_remote_code=True) 49 | 50 | self.model = PeftModel.from_pretrained( 51 | self.model, resume_path, config=peft_config) 52 | self.model_to_device(gpu=num_gpus) 53 | 54 | def model_to_device(self, gpu=[0]): 55 | self.device = torch.device( 56 | "cuda:0" if torch.cuda.is_available() else "cpu") 57 | self.model.cuda() 58 | self.model = torch.nn.DataParallel(self.model, device_ids=gpu).cuda() 59 | self.model.to(self.device) 60 | self.true_model = 
self.model.module if hasattr( 61 | self.model, 'module') else self.model 62 | 63 | def process_model_outputs(self, inputs, outputs, tokenizer): 64 | responses = [] 65 | for input_ids, output_ids in zip(inputs['input_ids'], outputs): 66 | response = tokenizer.decode(output_ids[len(input_ids):], skip_special_tokens=True).strip() 67 | responses.append(response) 68 | return responses 69 | 70 | def build_chat_input(self, query:str, history=None): 71 | if history is None: 72 | history = [] 73 | history.append(query) 74 | max_input_tokens = 0 75 | new_batch_input = self.tokenizer.apply_chat_template(history, add_generation_prompt=True, tokenize=False) 76 | max_input_tokens = max(max_input_tokens, len(new_batch_input)) 77 | return new_batch_input, max_input_tokens 78 | 79 | def predict(self, query: str | list = '', history: List = None, max_length=512, max_new_tokens=512, num_beams:int=1, top_p: float = 0.8, temperature=1.0, do_sample: bool = False, build_message=False): 80 | if not isinstance(query, list): 81 | query = [query] 82 | history = [history] if history is not None else None 83 | with torch.no_grad(): 84 | if build_message: 85 | inputs = [] 86 | batch_max_len = 0 87 | for i, t in enumerate(query): 88 | if isinstance(t, str): 89 | t = {'role': 'user', 'content': t} 90 | if history is not None and len(history) > 0: 91 | h_unit = history[i] 92 | else: 93 | h_unit = [] 94 | t, max_input_tokens = self.build_chat_input(t, h_unit) 95 | if batch_max_len < max_input_tokens: 96 | batch_max_len = max_input_tokens 97 | inputs.append(t) 98 | else: 99 | inputs = query 100 | batch_max_len = 0 101 | for i in range(len(query)): 102 | if len(query[i]) > batch_max_len: 103 | batch_max_len = len(query[i]) 104 | batched_inputs = self.tokenizer( 105 | inputs, 106 | return_tensors="pt", 107 | padding=True, 108 | truncation=True).to(self.device) 109 | if self.config.model_type == 'llama': 110 | batched_inputs = batched_inputs.data 111 | batched_outputs = self.true_model.generate(**batched_inputs, **{ 112 | 'max_new_tokens': max_new_tokens, 113 | 'num_beams': num_beams, 114 | 'do_sample': do_sample, 115 | 'top_p': top_p, 116 | "temperature": temperature, 117 | "eos_token_id": self.eos_token_id 118 | }) 119 | batched_response = self.process_model_outputs(batched_inputs, batched_outputs, self.tokenizer) 120 | return batched_response 121 | 122 | def __call__(self, query: str | list = '', history: List = None, max_length=512, max_new_tokens=512, num_beams:int=1, top_p: float = 0.8, temperature=1.0, do_sample: bool = False, build_message=False): 123 | return self.predict(query, history, max_length, max_new_tokens, num_beams, top_p, temperature, do_sample, build_message) 124 | -------------------------------------------------------------------------------- /main/predictor/llm.py: -------------------------------------------------------------------------------- 1 | import json 2 | import torch 3 | from transformers import AutoTokenizer, AutoModel, AutoConfig, LlamaForCausalLM, AutoModelForCausalLM 4 | from typing import Tuple, List 5 | 6 | 7 | class Predictor(): 8 | 9 | def __init__(self, 10 | num_gpus: list = [0], 11 | model_from_pretrained: str = None, 12 | **args 13 | ): 14 | ''' 15 | Predictor: LLM预测器 (LLM predictor) 16 | 17 | ### Args: 18 | 19 | `num_gpus`: 使用的GPU编号列表 (the list of GPU numbers) 20 | 21 | `model_config_file_name`: bert配置文件名 (bert config file name) 22 | ''' 23 | self.num_gpus = num_gpus 24 | self.model_from_pretrained = model_from_pretrained 25 | self.model_init() 26 | 27 | def model_init(self): 28 | 
self.config = AutoConfig.from_pretrained( 29 | self.model_from_pretrained, trust_remote_code=True) 30 | self.tokenizer = AutoTokenizer.from_pretrained( 31 | self.model_from_pretrained, padding_side="left", trust_remote_code=True) 32 | if hasattr(self.config, 'eos_token_id'): 33 | self.eos_token_id = [self.config.eos_token_id] 34 | if hasattr(self.config, 'bos_token_id'): 35 | self.bos_token_id = [self.config.bos_token_id] 36 | if self.config.model_type == 'chatglm': 37 | self.model = AutoModel.from_pretrained( 38 | self.model_from_pretrained, trust_remote_code=True).to(torch.bfloat16) 39 | self.eos_token_id = self.config.eos_token_id 40 | elif self.config.model_type == 'llama': 41 | self.tokenizer.pad_token = self.tokenizer.eos_token 42 | self.model = LlamaForCausalLM.from_pretrained( 43 | self.model_from_pretrained, trust_remote_code=True).to(torch.bfloat16) 44 | terminators = [ 45 | self.tokenizer.eos_token_id, 46 | self.tokenizer.convert_tokens_to_ids("<|eot_id|>") 47 | ] 48 | if not hasattr(self, 'eos_token_id'): 49 | self.eos_token_id = [] 50 | for t in terminators: 51 | if t is not None: 52 | self.eos_token_id.append(t) 53 | elif self.config.model_type == 'qwen': 54 | self.model = AutoModelForCausalLM.from_pretrained( 55 | self.model_from_pretrained, torch_dtype="auto", device_map="auto", trust_remote_code=True) 56 | elif self.config.model_type == 'qwen2': 57 | self.model = AutoModelForCausalLM.from_pretrained( 58 | self.model_from_pretrained, torch_dtype="auto", device_map="auto", trust_remote_code=True) 59 | elif self.config.model_type == "mimo": 60 | self.eos_token_id = self.config.eos_token_id 61 | self.bos_token_id = self.config.bos_token_id 62 | self.model = AutoModelForCausalLM.from_pretrained(self.model_from_pretrained, device_map="auto",trust_remote_code=True) 63 | elif self.config.model_type == "tinyr1": 64 | self.eos_token_id = self.config.eos_token_id 65 | self.bos_token_id = self.config.bos_token_id 66 | self.model = AutoModelForCausalLM.from_pretrained(self.model_from_pretrained, device_map="auto",trust_remote_code=True) 67 | self.model_to_device(gpu=self.num_gpus) 68 | self.model = self.model.eval() 69 | 70 | def model_to_device(self, gpu=[0]): 71 | self.device = torch.device( 72 | "cuda:0" if torch.cuda.is_available() else "cpu") 73 | self.model.cuda() 74 | self.model = torch.nn.DataParallel(self.model, device_ids=gpu).cuda() 75 | self.model.to(self.device) 76 | self.true_model = self.model.module if hasattr( 77 | self.model, 'module') else self.model 78 | 79 | def process_model_outputs(self, inputs, outputs, tokenizer): 80 | responses = [] 81 | for input_ids, output_ids in zip(inputs['input_ids'], outputs): 82 | response = tokenizer.decode(output_ids[len(input_ids):], skip_special_tokens=True).strip() 83 | responses.append(response) 84 | return responses 85 | 86 | def build_chat_input(self, query:str, history=None): 87 | if history is None: 88 | history = [] 89 | history.append(query) 90 | max_input_tokens = 0 91 | new_batch_input = self.tokenizer.apply_chat_template(history, add_generation_prompt=True, tokenize=False) 92 | max_input_tokens = max(max_input_tokens, len(new_batch_input)) 93 | return new_batch_input, max_input_tokens 94 | 95 | def predict(self, query: str | list = '', history: List = None, max_length=512, max_new_tokens=512, num_beams:int=1, top_p: float = 0.8, temperature=1.0, do_sample: bool = False, build_message=False): 96 | if not isinstance(query, list): 97 | query = [query] 98 | history = [history] if history is not None else None 99 | with 
torch.no_grad(): 100 | if build_message: 101 | inputs = [] 102 | batch_max_len = 0 103 | for i, t in enumerate(query): 104 | if isinstance(t, str): 105 | t = {'role': 'user', 'content': t} 106 | if history is not None and len(history) > 0: 107 | h_unit = history[i] 108 | else: 109 | h_unit = [] 110 | t, max_input_tokens = self.build_chat_input(t, h_unit) 111 | if batch_max_len < max_input_tokens: 112 | batch_max_len = max_input_tokens 113 | inputs.append(t) 114 | else: 115 | inputs = query 116 | batch_max_len = 0 117 | for i in range(len(query)): 118 | if len(query[i]) > batch_max_len: 119 | batch_max_len = len(query[i]) 120 | batched_inputs = self.tokenizer( 121 | inputs, 122 | return_tensors="pt", 123 | padding=True, 124 | truncation=True).to(self.device) 125 | if self.config.model_type == 'llama': 126 | batched_inputs = batched_inputs.data 127 | batched_outputs = self.true_model.generate(**batched_inputs, **{ 128 | 'max_new_tokens': max_new_tokens, 129 | 'num_beams': num_beams, 130 | 'do_sample': do_sample, 131 | 'top_p': top_p, 132 | "temperature": temperature, 133 | "eos_token_id": self.eos_token_id 134 | }) 135 | batched_response = self.process_model_outputs(batched_inputs, batched_outputs, self.tokenizer) 136 | return batched_response 137 | 138 | def __call__(self, query: str | list = '', history: List = None, max_length=512, max_new_tokens=512, num_beams:int=1, top_p: float = 0.8, temperature=1.0, do_sample: bool = False, build_message=False): 139 | return self.predict(query, history, max_length, max_new_tokens, num_beams, top_p, temperature, do_sample, build_message) 140 | -------------------------------------------------------------------------------- /main/loader.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import random 4 | from torch.utils.data import TensorDataset, DataLoader, Dataset 5 | from main.loaders.llm_pure_text import LLMPureTextDataset 6 | from main.loaders.chatglm_chat import ChatGLM_ChatDataset 7 | from main.loaders.qwen_chat import QwenChatDataset 8 | from main.loaders.llm_chat import LLMChatDataset 9 | from main.loaders.chatglm_rlhf import ChatGLM_RLHFDataset 10 | import torch 11 | 12 | def collate_fn_wrapper(tokenizer): 13 | def left_pad_collate_fn(batch): 14 | result = {} 15 | max_length = 0 16 | max_length_without_last_turn = 0 17 | for item in batch: 18 | if item['input_ids'].shape[0] > max_length: 19 | max_length = item['input_ids'].shape[0] 20 | if 'input_ids_without_last_turn' in item and item['input_ids_without_last_turn'].shape[0] > max_length_without_last_turn: 21 | max_length_without_last_turn = item['input_ids_without_last_turn'].shape[0] 22 | for item in batch: 23 | for key in item: 24 | if key not in result: 25 | result[key] = [] 26 | if key in ('input_ids'): 27 | pad_length = max_length - len(item[key]) 28 | item[key] = torch.cat([torch.LongTensor([tokenizer.pad_token_id] * pad_length), item[key]], dim=-1) 29 | elif key == ('labels'): 30 | pad_length = max_length - len(item[key]) 31 | item[key] = torch.cat([torch.LongTensor([-100] * pad_length), item[key]], dim=-1) 32 | if key in ('input_ids_without_last_turn'): 33 | pad_length = max_length_without_last_turn - len(item[key]) 34 | item[key] = torch.cat([torch.LongTensor([tokenizer.pad_token_id] * pad_length), item[key]], dim=-1) 35 | elif key == ('labels_without_last_turn'): 36 | pad_length = max_length_without_last_turn - len(item[key]) 37 | item[key] = torch.cat([torch.LongTensor([-100] * pad_length), item[key]], dim=-1) 38 
| result[key].append(item[key]) 39 | for key in result: 40 | if key in ('input_ids', 'labels', 'input_ids_without_last_turn', 'labels_without_last_turn', 'last_input_len'): 41 | result[key] = torch.stack(result[key]) 42 | return result 43 | return left_pad_collate_fn 44 | 45 | class AutoDataloader(): 46 | 47 | ''' 48 | loader_name: str; the dataloader name 49 | data_path: str or obj; the path of the data; if str, it will use the present dataset in data_present_path, or you should define the path like e.g. { 'train': './train.json', 'dev': './dev.json' } 50 | model_type: interactive or siamese 51 | data_present_path: str; the path of the data_present; the data_present is a json file which contains the path of the dataset, and the format is like e.g. { 'dataset_name': {'train': './train.json', 'dev': './dev.json'} } 52 | max_length: int; the length of the padding 53 | ''' 54 | 55 | def __init__(self, tokenizer, config, loader_name='LLM_Chat', data_path="Boss", data_present_path="./data/present.json", max_length=50): 56 | self.tokenizer = tokenizer 57 | self.loader_name = loader_name 58 | self.max_length = max_length 59 | self.data_present = self.get_data_present(data_present_path) 60 | self.data_path = self.data_present[data_path] if data_path in self.data_present else data_path 61 | if loader_name == 'LLM_Pure': 62 | self.train_set = LLMPureTextDataset( 63 | tokenizer, config, self.data_path['train'], max_length=self.max_length, do_shuffle=True) 64 | self.eval_set = LLMPureTextDataset( 65 | tokenizer, config, self.data_path['dev'], max_length=self.max_length, do_shuffle=False) 66 | if 'test' in self.data_path: 67 | self.test_set = LLMPureTextDataset( 68 | tokenizer, config, self.data_path['test'], max_length=self.max_length, do_shuffle=False) 69 | elif loader_name == 'ChatGLM_Chat': 70 | self.train_set = ChatGLM_ChatDataset( 71 | tokenizer, config, self.data_path['train'], max_length=self.max_length, do_shuffle=True) 72 | self.eval_set = ChatGLM_ChatDataset( 73 | tokenizer, config, self.data_path['dev'], max_length=self.max_length, do_shuffle=False) 74 | if 'test' in self.data_path: 75 | self.test_set = ChatGLM_ChatDataset( 76 | tokenizer, config, self.data_path['test'], max_length=self.max_length, do_shuffle=False) 77 | elif loader_name == 'Qwen_Chat': 78 | self.train_set = QwenChatDataset( 79 | tokenizer, config, self.data_path['train'], max_length=self.max_length, do_shuffle=True) 80 | self.eval_set = QwenChatDataset( 81 | tokenizer, config, self.data_path['dev'], max_length=self.max_length, do_shuffle=False) 82 | if 'test' in self.data_path: 83 | self.test_set = QwenChatDataset( 84 | tokenizer, config, self.data_path['test'], max_length=self.max_length, do_shuffle=False) 85 | elif loader_name == 'LLM_Chat': 86 | self.train_set = LLMChatDataset( 87 | tokenizer, config, self.data_path['train'], max_length=self.max_length, do_shuffle=True) 88 | self.eval_set = LLMChatDataset( 89 | tokenizer, config, self.data_path['dev'], max_length=self.max_length, do_shuffle=False) 90 | if 'test' in self.data_path: 91 | self.test_set = LLMChatDataset( 92 | tokenizer, config, self.data_path['test'], max_length=self.max_length, do_shuffle=False) 93 | elif loader_name == 'ChatGLM_RLHF': 94 | self.train_set = ChatGLM_RLHFDataset( 95 | tokenizer, config, self.data_path['train'], max_length=self.max_length, do_shuffle=True) 96 | self.eval_set = ChatGLM_RLHFDataset( 97 | tokenizer, config, self.data_path['dev'], max_length=self.max_length, do_shuffle=False) 98 | if 'test' in self.data_path: 99 | self.test_set = 
ChatGLM_RLHFDataset( 100 | tokenizer, config, self.data_path['test'], max_length=self.max_length, do_shuffle=False) 101 | 102 | def get_data_present(self, present_path): 103 | if not os.path.exists(present_path): 104 | return {} 105 | with open(present_path, encoding='utf-8') as f: 106 | present_json = f.read() 107 | data_present = json.loads(present_json) 108 | return data_present 109 | 110 | def __call__(self, batch_size=1, batch_size_eval=1, eval_mode='dev', use_collate=False): 111 | if not use_collate: 112 | dataiter = DataLoader(self.train_set, batch_size=batch_size) 113 | if eval_mode == 'dev': 114 | dataiter_eval = DataLoader( 115 | self.eval_set, batch_size=batch_size_eval) 116 | else: 117 | dataiter_eval = DataLoader( 118 | self.test_set, batch_size=batch_size_eval) 119 | else: 120 | left_pad_collate_fn = collate_fn_wrapper(self.tokenizer) 121 | dataiter = DataLoader(self.train_set, batch_size=batch_size, collate_fn=left_pad_collate_fn) 122 | if eval_mode == 'dev': 123 | dataiter_eval = DataLoader( 124 | self.eval_set, batch_size=batch_size_eval, collate_fn=left_pad_collate_fn) 125 | else: 126 | dataiter_eval = DataLoader( 127 | self.test_set, batch_size=batch_size_eval, collate_fn=left_pad_collate_fn) 128 | return dataiter, dataiter_eval 129 | -------------------------------------------------------------------------------- /main/trainer/llm_lora.py: -------------------------------------------------------------------------------- 1 | import os 2 | import uuid 3 | import torch 4 | import torch.nn as nn 5 | import torch.optim as optim 6 | import torch.nn.functional as F 7 | from transformers import AutoModel, AutoConfig, LlamaForCausalLM, AutoModelForCausalLM 8 | from transformers import get_linear_schedule_with_warmup 9 | from peft import get_peft_model, LoraConfig, TaskType, PeftModel 10 | import numpy as np 11 | import jieba 12 | from nltk.translate.bleu_score import SmoothingFunction, sentence_bleu 13 | from rouge_chinese import Rouge 14 | from tqdm import tqdm 15 | from main.loader import AutoDataloader 16 | from main.analysis import Analysis 17 | from accelerate import Accelerator 18 | accelerator = Accelerator() 19 | 20 | 21 | class Trainer(): 22 | 23 | def __init__(self, tokenizer, from_pretrained, loader_name, data_path, config=None, resume_path=None, max_length=512, batch_size=1, batch_size_eval=1, 24 | lora_r=16, lora_alpha=32, lora_dropout=0.1, 25 | eval_mode='dev', task_name='Sim'): 26 | self.tokenizer = tokenizer 27 | self.config = config 28 | self.accelerate = accelerator 29 | self.loader_name = loader_name 30 | self.data_path = data_path 31 | self.from_pretrained = from_pretrained 32 | self.data_path = data_path 33 | self.task_name = task_name 34 | self.max_length = max_length 35 | self.batch_size = batch_size 36 | self.batch_size_eval = batch_size_eval 37 | self.lora_r = lora_r 38 | self.lora_alpha = lora_alpha 39 | self.lora_dropout = lora_dropout 40 | self.eval_mode = eval_mode 41 | self.config_init() 42 | self.dataloader_init() 43 | self.model_init(resume_path=resume_path) 44 | self.analysis = Analysis() 45 | 46 | def config_init(self): 47 | self.config = AutoConfig.from_pretrained( 48 | self.from_pretrained, trust_remote_code=True) if self.config is None else self.config 49 | if self.config.model_type == 'llama': 50 | self.tokenizer.pad_token = self.tokenizer.eos_token 51 | 52 | def model_init(self, resume_path=None): 53 | if self.accelerate.is_local_main_process: 54 | print('AutoModel Choose Model: {}\n'.format( 55 | self.from_pretrained)) 56 | if 
self.config.model_type == 'chatglm': 57 | target_modules=['query_key_value'] 58 | self.model = AutoModel.from_pretrained( 59 | self.from_pretrained, trust_remote_code=True).to(torch.bfloat16) 60 | elif self.config.model_type == 'llama': 61 | target_modules=["q_proj", "k_proj", "v_proj"] 62 | self.model = LlamaForCausalLM.from_pretrained( 63 | self.from_pretrained, trust_remote_code=True).to(torch.bfloat16) 64 | elif self.config.model_type == 'qwen': 65 | target_modules=["c_attn", "c_proj", "w1", "w2"] 66 | self.model = AutoModelForCausalLM.from_pretrained( 67 | self.from_pretrained, torch_dtype="auto", device_map="auto", trust_remote_code=True) 68 | elif self.config.model_type == 'qwen2': 69 | target_modules=["q_proj", "k_proj", "v_proj"] 70 | self.model = AutoModelForCausalLM.from_pretrained( 71 | self.from_pretrained, torch_dtype="auto", device_map="auto", trust_remote_code=True) 72 | peft_config = LoraConfig( 73 | task_type=TaskType.CAUSAL_LM, 74 | inference_mode=False, 75 | r=self.lora_r, 76 | target_modules=target_modules, 77 | lora_alpha=self.lora_alpha, 78 | lora_dropout=self.lora_dropout 79 | ) 80 | if resume_path is not None: 81 | print('Accessing Resume PATH: {} ...\n'.format(resume_path)) 82 | self.model.enable_input_require_grads() 83 | self.model = PeftModel.from_pretrained( 84 | self.model, resume_path, config=peft_config) 85 | else: 86 | self.model = get_peft_model(self.model, peft_config) 87 | self.model.print_trainable_parameters() 88 | 89 | def dataloader_init(self): 90 | d = AutoDataloader(self.tokenizer, self.config, loader_name=self.loader_name, data_path=self.data_path, 91 | max_length=self.max_length) 92 | self.train_loader, self.eval_loader = d( 93 | self.batch_size, self.batch_size_eval, self.eval_mode, True) 94 | 95 | def __call__(self, resume_step=None, num_epochs=30, lr=1e-4, eval_call_epoch=None): 96 | return self.train(resume_step=resume_step, 97 | num_epochs=num_epochs, lr=lr, eval_call_epoch=eval_call_epoch) 98 | 99 | def train(self, resume_step=None, num_epochs=30, lr=1e-4, eval_call_epoch=None): 100 | 101 | optimizer = optim.Adam(self.model.parameters(), lr=lr, weight_decay=0.) 
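# The warmup (190) and total (80000) step counts passed to the scheduler below are fixed values;
# if the schedule should instead follow the actual data size, one illustrative alternative
# (an assumption, not taken from this repo) is:
#   total_steps = num_epochs * len(self.train_loader)
#   scheduler = get_linear_schedule_with_warmup(optimizer, int(0.03 * total_steps), total_steps)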
102 | scheduler = get_linear_schedule_with_warmup(optimizer, 190, 80000) 103 | self.model, optimizer, train_loader, scheduler = self.accelerate.prepare( 104 | self.model, optimizer, self.train_loader, scheduler) 105 | 106 | current_uid = str(uuid.uuid1()).split('-')[0] 107 | 108 | train_step = resume_step if resume_step is not None else 0 109 | for epoch in range(num_epochs): 110 | train_count = 0 111 | train_loss = 0 112 | eval_scores = { 113 | 'rouge-1': 0, 114 | 'rouge-2': 0, 115 | 'rouge-l': 0, 116 | 'bleu-4': 0 117 | } 118 | 119 | train_iter = tqdm(train_loader) 120 | self.model.train() 121 | 122 | for it in train_iter: 123 | 124 | output = self.model(**it) 125 | loss = output.loss 126 | loss = loss.mean() 127 | # loss.backward() 128 | self.accelerate.backward(loss) 129 | optimizer.step() 130 | scheduler.step() 131 | self.model.zero_grad() 132 | 133 | logits = output.logits 134 | labels = it['labels'] 135 | metrics = self.compute_metrics(logits, labels) 136 | for k, v in metrics.items(): 137 | eval_scores[k] += v 138 | 139 | train_loss += loss.data.item() 140 | train_count += 1 141 | train_step += 1 142 | 143 | train_iter.set_description( 144 | 'Train: {}/{}'.format(epoch + 1, num_epochs)) 145 | train_iter.set_postfix( 146 | train_loss=train_loss / train_count, **{k: v / train_count for k, v in eval_scores.items()}) 147 | 148 | self.analysis.append_train_record({ 149 | 'epoch': epoch + 1, 150 | 'train_loss': train_loss / train_count, 151 | **{k: v / train_count for k, v in eval_scores.items()} 152 | }) 153 | 154 | model_uid = self.save_model(train_step) 155 | if eval_call_epoch is None or eval_call_epoch(epoch): 156 | self.eval(epoch) 157 | 158 | self.analysis.save_all_records( 159 | uid=current_uid if self.task_name is None else self.task_name) 160 | yield (epoch, self.analysis.train_record, self.analysis.eval_record, self.analysis.model_record, model_uid) 161 | 162 | @accelerator.on_local_main_process 163 | def save_model(self, current_step=0): 164 | if self.task_name is None: 165 | dir = 'undefined' 166 | else: 167 | dir = self.task_name 168 | save_path = f'./save_model/{dir}/ChatGLM_{current_step}' 169 | if not os.path.exists(save_path): 170 | os.makedirs(save_path) 171 | save_model = self.accelerate.unwrap_model(self.model) 172 | save_model.save_pretrained( 173 | save_path, 174 | is_main_process=self.accelerate.is_main_process, 175 | save_function=self.accelerate.save, 176 | ) 177 | self.analysis.append_model_record(current_step) 178 | return current_step 179 | 180 | def eval(self, epoch, pure_eval=False): 181 | if pure_eval: 182 | self.model = self.accelerate.prepare_model(self.model) 183 | self.eval_loader = self.accelerate.prepare_data_loader( 184 | self.eval_loader) 185 | 186 | with torch.no_grad(): 187 | eval_count = 0 188 | eval_loss = 0 189 | eval_scores = { 190 | 'rouge-1': 0, 191 | 'rouge-2': 0, 192 | 'rouge-l': 0, 193 | 'bleu-4': 0 194 | } 195 | 196 | eval_iter = tqdm(self.eval_loader) 197 | self.model.eval() 198 | 199 | for it in eval_iter: 200 | 201 | output = self.model(**it) 202 | loss = output.loss 203 | loss = loss.mean() 204 | 205 | logits = output.logits 206 | labels = it['labels'] 207 | metrics = self.compute_metrics(logits, labels) 208 | for k, v in metrics.items(): 209 | eval_scores[k] += v 210 | 211 | eval_loss += loss.data.item() 212 | eval_count += 1 213 | 214 | eval_iter.set_description( 215 | f'Eval: {epoch + 1}') 216 | eval_iter.set_postfix( 217 | eval_loss=eval_loss / eval_count, **{k: v / eval_count for k, v in eval_scores.items()}) 218 | 219 | 
self.analysis.append_eval_record({ 220 | 'epoch': epoch + 1, 221 | 'eval_loss': eval_loss / eval_count, 222 | **{k: v / eval_count for k, v in eval_scores.items()} 223 | }) 224 | 225 | def compute_metrics(self, logits, labels): 226 | shift_logits = logits[..., :-1, :] 227 | pred_logits = shift_logits.argmax(-1) 228 | pred_logits = pred_logits.tolist() 229 | shift_labels = labels[..., 1:].tolist() 230 | 231 | metrics_dct = {'rouge-1': [], 'rouge-2': [], 232 | 'rouge-l': [], 'bleu-4': []} 233 | for pred_ids, label_ids in zip(pred_logits, shift_labels): 234 | try: 235 | answer_idx = 0 236 | for i in range(len(label_ids)): 237 | if label_ids[i] != -100: 238 | answer_idx = i 239 | break 240 | pred_ids = pred_ids[answer_idx:] 241 | label_ids = label_ids[answer_idx:] 242 | pred_txt = self.tokenizer.decode(pred_ids).strip() 243 | label_txt = self.tokenizer.decode(label_ids).strip() 244 | pred_tokens = list(jieba.cut(pred_txt)) 245 | label_tokens = list(jieba.cut(label_txt)) 246 | rouge = Rouge() 247 | scores = rouge.get_scores( 248 | ' '.join(pred_tokens), ' '.join(label_tokens)) 249 | for k, v in scores[0].items(): 250 | metrics_dct[k].append(round(v['f'] * 100, 4)) 251 | metrics_dct['bleu-4'].append( 252 | sentence_bleu( 253 | [label_tokens], 254 | pred_tokens, 255 | smoothing_function=SmoothingFunction().method3, 256 | ) 257 | ) 258 | except: 259 | continue 260 | return {k: np.mean(v) if len(v) > 0 else 0 for k, v in metrics_dct.items()} 261 | -------------------------------------------------------------------------------- /sas_pipelines/1_predict_scores.py: -------------------------------------------------------------------------------- 1 | # %% 2 | import os 3 | import json 4 | import random 5 | import json_repair 6 | import numpy as np 7 | from tqdm import tqdm 8 | from copy import deepcopy 9 | from argparse import ArgumentParser 10 | 11 | import sys 12 | sys.path.append("../") 13 | cmd_args = True 14 | # Add params n_gpu 15 | parser = ArgumentParser() 16 | parser.add_argument('--n_gpu', default=0, help='CUDA_VISIBLE_DEVICES') 17 | parser.add_argument('--file_dir', default='../datasets', help='the directory of the datasets.') 18 | parser.add_argument('--file_name', default='9_English_gapfilling', help='file name of the dataset without extension (e.g. 
0_Physics_ShortAns)') 19 | parser.add_argument('--llm_name', default='', help='the prefix name of save dir (usually is the LLM name)') 20 | parser.add_argument('--save_type_name', default='GLM4', help='the prefix name of save dir (usually is the LLM name)') 21 | parser.add_argument('--few_shot_num', default=0, help='decide the number of few-shot samples') 22 | parser.add_argument('--use_guideline', default='1', help='whether use scoring guideline') 23 | parser.add_argument('--model_from_pretrained', default='/home/lpc/models/glm-4-9b-chat/', help='model from pretrained') 24 | parser.add_argument('--vllm', default='0', help='whether use vllm') 25 | parser.add_argument('--tensor_parallel_size', default=1, help='tensor_parallel_size (TP) for vLLM') 26 | parser.add_argument('--max_new_tokens', default=1024, help='max new tokens') 27 | parser.add_argument('--do_sample', default='0', help='do_sample, useless for vLLM') 28 | parser.add_argument('--temperature', default=0.6, help='temperature, if temperture > 0, it will work on vLLM.') 29 | parser.add_argument('--top_p', default=0.95, help='top_p, if top_p < 1.0, it will work on vLLM') 30 | parser.add_argument('--skip_thinking', default='0', help='skip deep thinking in RL model with \n\n') 31 | parser.add_argument('--batch_size', default=5, help='batch size, suggest to set it larger when use vLLM') 32 | parser.add_argument('--fix_reasoning', default=1, help='Re-generate with longer length for the result without finishing thinking') 33 | 34 | if not cmd_args: 35 | args = parser.parse_args([]) # You can directly set above parameters in the default. 36 | else: 37 | args = parser.parse_args() 38 | 39 | os.environ["CUDA_VISIBLE_DEVICES"] = str(args.n_gpu) 40 | 41 | API_MODELS = ['gpt-4o-mini', 'deepseek-chat', 'deepseek-reasoner'] 42 | API_CONFIGS = [('OpenAI', None), ('Deepseek', 'https://api.deepseek.com'), ('Deepseek', 'https://api.deepseek.com')] 43 | 44 | USE_VLLM = str(args.vllm) == '1' 45 | 46 | llm_name = args.llm_name if args.llm_name != '' else args.save_type_name 47 | if llm_name == 'GLM3': 48 | from main.predictor.chatglm import Predictor 49 | elif llm_name in API_MODELS: 50 | from main.predictor.openai import Predictor 51 | elif USE_VLLM: 52 | from main.predictor.vllm import Predictor 53 | else: 54 | from main.predictor.llm import Predictor 55 | 56 | if llm_name not in API_MODELS: 57 | pred = Predictor(model_from_pretrained=args.model_from_pretrained, tensor_parallel_size=int(args.tensor_parallel_size)) 58 | else: 59 | CONFIG_INDEX = API_MODELS.index(llm_name) 60 | with open('api_key.txt') as f: 61 | api_keys = f.readlines() 62 | for key_item in api_keys: 63 | key_item = key_item.strip().split(' ') 64 | if len(key_item) == 1: 65 | api_key = key_item 66 | break 67 | else: 68 | if key_item[0] == API_CONFIGS[CONFIG_INDEX][0]: 69 | api_key = key_item[1] 70 | break 71 | pred = Predictor(api_key=api_key, base_url=API_CONFIGS[CONFIG_INDEX][1]) 72 | 73 | # %% 74 | SOURCE_FILE = os.path.join(args.file_dir, f'{args.file_name}.jsonl') 75 | ERROR_TYPE_FILE = os.path.join(args.file_dir, 'error_type.jsonl') 76 | SAVE_DIR = os.path.dirname(SOURCE_FILE) + f'_{args.save_type_name}_Scored' 77 | basename = os.path.basename(SOURCE_FILE) 78 | SAVE_FILE = os.path.join(SAVE_DIR, 79 | basename.split('.')[0]+'_scored.jsonl') 80 | FEW_SHOT_NUM = int(args.few_shot_num) 81 | BATCH_SIZE = int(args.batch_size) 82 | MAX_NEW_TOKENS = int(args.max_new_tokens) 83 | 84 | # if it is english dataset, you may replace it with {'name': 'correct', 'description': 'the step is 
correct.'} 85 | # but it depends on how you predefined the `name` of the correct step. 86 | CORRECT_NAME = '步骤正确' 87 | CORRECT_DESCRIPTION = '该步骤正确' 88 | PREDICT_PROMPT = '' 89 | FEW_SHOT_PROMPT = {'prefix': '', 'suffix': '', 'question': '', 'reference': '', 'total': '', 'analysis': '', 'student_answer': '', 'output': ''} 90 | 91 | with open('../prompts/predict.txt') as f: 92 | PREDICT_PROMPT = f.read().strip() 93 | 94 | with open('../prompts/few_shot_prompt.json') as f: 95 | FEW_SHOT_PROMPT = json.load(f) 96 | 97 | if not os.path.exists(SAVE_DIR): 98 | os.makedirs(SAVE_DIR) 99 | 100 | ID = args.file_name.split('_')[0] 101 | with open(ERROR_TYPE_FILE, encoding='utf-8') as f: 102 | error_type_list = f.readlines() 103 | error_type_list = [json.loads(item) for item in error_type_list] 104 | error_type_item = [] 105 | score_guideline = '' 106 | for item in error_type_list: 107 | if str(item['q_id']) == str(ID): 108 | score_guideline = item['guideline'] 109 | error_type_item = item['errors'] 110 | error_type_item.append({'name': CORRECT_NAME, 'description': CORRECT_DESCRIPTION}) 111 | break 112 | 113 | if str(args.use_guideline) != '1': 114 | score_guideline = '' 115 | 116 | # Read the JSON file 117 | with open(SOURCE_FILE, encoding='utf-8') as f: 118 | ori_data = f.readlines() 119 | ori_data = [json.loads(item) for item in ori_data] 120 | 121 | ori_data_id_dict = {} 122 | for idx, item in enumerate(ori_data): 123 | if item['id'] not in ori_data_id_dict: 124 | ori_data_id_dict[item['id']] = idx 125 | 126 | if str(args.fix_reasoning) == '1' and os.path.exists(SAVE_FILE): 127 | with open(SAVE_FILE) as f: 128 | save_data = f.readlines() 129 | 130 | count = 0 131 | for item in save_data: 132 | item = item.split('\t') 133 | id, content = item[0], item[1] 134 | if id in ori_data_id_dict: 135 | idx = ori_data_id_dict[id] 136 | content = json.loads(content.strip()) 137 | if type(content) != str: 138 | content = str(content) 139 | if content.find('{"total"') >= 0 or content.find('{\'total\'') >= 0: 140 | ori_data[idx]['cache'] = content.strip() 141 | count += 1 142 | print(f'Found correct generation: {count}') 143 | 144 | few_shot_prompt = '' 145 | if FEW_SHOT_NUM > 0: 146 | from utils.load_few_shot import get_few_shot_samples, compute_few_shot_prompt 147 | few_shot_samples = get_few_shot_samples(SOURCE_FILE, num_samples=FEW_SHOT_NUM) 148 | few_shot_samples = [compute_few_shot_prompt(item, prompt=FEW_SHOT_PROMPT['template']) for item in few_shot_samples] 149 | few_shot_prompt = f'{FEW_SHOT_PROMPT["prefix"]}\n' 150 | few_shot_prompt += '\n'.join(few_shot_samples) 151 | few_shot_prompt += f'\n{FEW_SHOT_PROMPT["suffix"]}\n' 152 | 153 | # %% 154 | prompt_prefix = PREDICT_PROMPT 155 | 156 | # %% 157 | all_examples = [] 158 | ask_list = [] 159 | 160 | error_type = [] 161 | for error_item in error_type_item: 162 | error_type.append(error_item['name']) 163 | error_type_content = json.dumps(error_type, ensure_ascii=False) 164 | 165 | for idx, response_item in tqdm(enumerate(ori_data)): 166 | id = response_item['id'] 167 | question = response_item.get('question', '') 168 | reference = response_item.get('reference', '') 169 | analysis = response_item.get('analysis', '') 170 | total = response_item.get('total', '') 171 | manual_label = response_item.get('manual_label', '') 172 | steps = response_item.get('steps', '') 173 | cache_content = response_item['cache'] if 'cache' in response_item else False 174 | 175 | reponse_content = [] 176 | for s_idx, step in enumerate(steps): 177 | response = step['response'] 178 | 
reponse_content.append(f'## Step {s_idx}. {response}') 179 | 180 | # Construct Q&A session content (context + questions + references) 181 | ask_content = prompt_prefix.format( 182 | question=question, 183 | total=total, 184 | score_guideline=score_guideline, 185 | few_shot_samples=few_shot_prompt, 186 | error_type=error_type_content, 187 | reference=reference, 188 | analysis=analysis, 189 | student_answer=''.join(reponse_content) 190 | ) 191 | ask_list.append((ask_content, id, cache_content)) 192 | 193 | save_content = None 194 | save_content_id_dict = {} 195 | def refresh_save_content(id, content): 196 | global save_content 197 | global save_content_id_dict 198 | if save_content is None: 199 | save_content = [] 200 | for idx, item in enumerate(ask_list): 201 | save_content.append((item[1], item[2])) 202 | save_content_id_dict[item[1]] = idx 203 | idx = save_content_id_dict[id] 204 | save_content[idx] = (id, content) 205 | with open(SAVE_FILE, 'w', encoding='utf-8') as f: 206 | for item in save_content: 207 | f.write(item[0] + '\t' + json.dumps(item[1], ensure_ascii=False) + '\n') 208 | 209 | def build_chat_custom(content): 210 | content = f'<|im_start|>user\n{content}<|im_end|>\n<|im_start|>assistant\n\n\n\n\n' 211 | return content 212 | 213 | if args.skip_thinking == '1': 214 | for idx, tp in enumerate(ask_list): 215 | ask_list[idx] = (build_chat_custom(tp[0]), tp[1], tp[2]) 216 | 217 | #%% 218 | # Calculate total number of evaluation batches 219 | if llm_name not in API_MODELS: 220 | num_batches = len(ask_list) // BATCH_SIZE + (1 if len(ask_list) % BATCH_SIZE != 0 else 0) 221 | 222 | # Run batch prediction and persist results (with progress tracking) 223 | for i in tqdm(range(num_batches)): 224 | batch = ask_list[i * BATCH_SIZE:(i + 1) * BATCH_SIZE] 225 | prompts = [item[0] for item in batch] 226 | ids = [item[1] for item in batch] 227 | cache_content_list = [item[2] for item in batch] 228 | max_length = [len(item[0]) for item in batch] 229 | max_length.sort(reverse=True) 230 | max_new_tokens = max_length[0] 231 | filter_prompts = [] 232 | filter_idxes = [] 233 | output_list = [] 234 | for idx, tp in enumerate(zip(prompts, cache_content_list)): 235 | prompt, cache = tp 236 | if cache == False: 237 | filter_prompts.append(prompt) 238 | filter_idxes.append(idx) 239 | output_list.append('') 240 | else: 241 | output_list.append(cache) 242 | 243 | if len(filter_prompts) > 0: 244 | outputs = pred(filter_prompts, max_new_tokens=MAX_NEW_TOKENS, build_message=args.skip_thinking != '1', do_sample=args.do_sample == '1', temperature=float(args.temperature), top_p=float(args.top_p)) 245 | else: 246 | outputs = [] 247 | 248 | for res, idx in zip(outputs, filter_idxes): 249 | res = res.replace('\n', '') 250 | res = res.replace(' ', '') 251 | output_list[idx] = res 252 | 253 | for res, id in zip(output_list, ids): 254 | refresh_save_content(id, res) 255 | else: 256 | for ask_content, id, cache in tqdm(ask_list): 257 | if cache != False: 258 | res = cache 259 | else: 260 | res = pred(ask_content, model=llm_name) 261 | res = res[0] 262 | res = res.replace('\n', '') 263 | res = res.replace(' ', '') 264 | if res.find("{'total'") >= 0: 265 | res = res.replace('\'', '"') 266 | refresh_save_content(id, res) 267 | 268 | #%% 269 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | 2 | Apache License 3 | Version 2.0, January 2004 4 | http://www.apache.org/licenses/ 5 | 6 | TERMS AND 
CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 7 | 8 | 1. Definitions. 9 | 10 | "License" shall mean the terms and conditions for use, reproduction, 11 | and distribution as defined by Sections 1 through 9 of this document. 12 | 13 | "Licensor" shall mean the copyright owner or entity authorized by 14 | the copyright owner that is granting the License. 15 | 16 | "Legal Entity" shall mean the union of the acting entity and all 17 | other entities that control, are controlled by, or are under common 18 | control with that entity. For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. 
Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 
123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | 177 | END OF TERMS AND CONDITIONS 178 | 179 | APPENDIX: How to apply the Apache License to your work. 
180 | 181 | To apply the Apache License to your work, attach the following 182 | boilerplate notice, with the fields enclosed by brackets "[]" 183 | replaced with your own identifying information. (Don't include 184 | the brackets!) The text should be enclosed in the appropriate 185 | comment syntax for the file format. We also recommend that a 186 | file or class name and description of purpose be included on the 187 | same "printed page" as the copyright notice for easier 188 | identification within third-party archives. 189 | 190 | Copyright [2025] [PKU-DAIR & DCML and FZU-ACM Team. All Rights Reserved.] 191 | 192 | Licensed under the Apache License, Version 2.0 (the "License"); 193 | you may not use this file except in compliance with the License. 194 | You may obtain a copy of the License at 195 | 196 | http://www.apache.org/licenses/LICENSE-2.0 197 | 198 | Unless required by applicable law or agreed to in writing, software 199 | distributed under the License is distributed on an "AS IS" BASIS, 200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 201 | See the License for the specific language governing permissions and 202 | limitations under the License. 203 | -------------------------------------------------------------------------------- /main/trainer/chatglm_rlhf.py: -------------------------------------------------------------------------------- 1 | import os 2 | import uuid 3 | import torch 4 | import torch.nn as nn 5 | import torch.optim as optim 6 | import torch.nn.functional as F 7 | from transformers import AutoModel 8 | from transformers import get_linear_schedule_with_warmup 9 | from peft import get_peft_model, LoraConfig, TaskType, PeftModel 10 | import numpy as np 11 | import jieba 12 | import random 13 | from nltk.translate.bleu_score import SmoothingFunction, sentence_bleu 14 | from rouge_chinese import Rouge 15 | from tqdm import tqdm 16 | from main.loader import AutoDataloader 17 | from main.models.chatglm_rlhf import CriticModel, RewardModel, PPO 18 | from main.analysis import Analysis 19 | import torch.distributed as dist 20 | from accelerate import Accelerator 21 | from accelerate.utils import DistributedDataParallelKwargs 22 | 23 | kwargs = DistributedDataParallelKwargs(find_unused_parameters=True) 24 | accelerator = Accelerator(kwargs_handlers=[kwargs]) 25 | 26 | 27 | class Trainer(): 28 | 29 | def __init__(self, tokenizer, config, from_pretrained, reward_from_pretrained, loader_name, data_path, ratio_for_rlhf=0.4, actor_resume_path=None, critic_resume_path=None, critic_layers_keep=1, max_length=512, batch_size=1, batch_size_eval=1, eval_mode='dev', task_name='Sim'): 30 | self.tokenizer = tokenizer 31 | self.config = config 32 | self.accelerate = accelerator 33 | self.loader_name = loader_name 34 | self.data_path = data_path 35 | self.model_from_pretrained = from_pretrained 36 | self.reward_from_pretrained = reward_from_pretrained 37 | self.critic_layers_keep = critic_layers_keep 38 | self.data_path = data_path 39 | self.ratio_for_rlhf = ratio_for_rlhf 40 | self.task_name = task_name 41 | self.max_length = max_length 42 | self.batch_size = batch_size 43 | self.batch_size_eval = batch_size_eval 44 | self.eval_mode = eval_mode 45 | self.decay_up_matrix_T = None 46 | self.dataloader_init() 47 | self.qa_logs = {} 48 | self.model_init(actor_resume_path=actor_resume_path, critic_resume_path=critic_resume_path) 49 | self.analysis = Analysis() 50 | 51 | def model_init(self, actor_resume_path=None, critic_resume_path=None): 52 | if 
self.accelerate.is_local_main_process: 53 | print('AutoModel Choose Model: {}\n'.format(self.model_from_pretrained)) 54 | self.model = AutoModel.from_pretrained( 55 | self.model_from_pretrained, trust_remote_code=True).to(torch.bfloat16) 56 | self.critic_model = CriticModel(model_from_pretrained=self.model_from_pretrained, resume_path=critic_resume_path, layers_keep=self.critic_layers_keep) 57 | self.reward_model = RewardModel(model_from_pretrained=self.reward_from_pretrained) 58 | self.ppo = PPO(self.tokenizer, self.qa_logs) 59 | peft_config = LoraConfig( 60 | task_type=TaskType.CAUSAL_LM, 61 | inference_mode=False, 62 | r=16, 63 | target_modules=['query_key_value'], 64 | lora_alpha=32, 65 | lora_dropout=0.1 66 | ) 67 | if actor_resume_path is not None: 68 | print('Accessing Resume PATH: {} ...\n'.format(actor_resume_path)) 69 | self.model.enable_input_require_grads() 70 | self.model = PeftModel.from_pretrained(self.model, actor_resume_path, config=peft_config) 71 | else: 72 | self.model = get_peft_model(self.model, peft_config) 73 | self.model.print_trainable_parameters() 74 | 75 | def dataloader_init(self): 76 | d = AutoDataloader(self.tokenizer, self.config, loader_name=self.loader_name, data_path=self.data_path, 77 | max_length=self.max_length) 78 | self.train_loader, self.eval_loader = d( 79 | self.batch_size, self.batch_size_eval, self.eval_mode, True) 80 | 81 | def __call__(self, resume_step=None, num_epochs=30, lr=1e-4, eval_call_epoch=None, ppo_epochs=5): 82 | return self.train(resume_step=resume_step, 83 | num_epochs=num_epochs, lr=lr, eval_call_epoch=eval_call_epoch, ppo_epochs=ppo_epochs) 84 | 85 | def train(self, resume_step=None, num_epochs=30, lr=1e-4, num_beams=3, num_return_sequences=2, eval_call_epoch=None, ppo_epochs=5): 86 | 87 | optimizer = optim.Adam(list(self.model.parameters()) + list(self.critic_model.output_linear.parameters()), lr=lr, weight_decay=0.) 
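# Note (added for clarity): only the LoRA-adapted actor parameters and the critic's output head
# are updated by this optimizer; the reward model stays frozen and is only queried for rewards.
# The warmup (190) and total steps (80000) passed to the scheduler below are fixed constants in
# this script rather than values derived from the dataloader length.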
88 | scheduler = get_linear_schedule_with_warmup(optimizer, 190, 80000) 89 | self.model, self.critic_model, self.reward_model, optimizer, train_loader, scheduler, self.ppo = self.accelerate.prepare(self.model, self.critic_model, self.reward_model, optimizer, self.train_loader, scheduler, self.ppo) 90 | 91 | current_uid = str(uuid.uuid1()).split('-')[0] 92 | 93 | train_step = resume_step if resume_step is not None else 0 94 | for epoch in range(num_epochs): 95 | train_count = 0 96 | train_loss = 0 97 | eval_scores = { 98 | 'rouge-1': 0, 99 | 'rouge-2': 0, 100 | 'rouge-l': 0, 101 | 'bleu-4': 0 102 | } 103 | 104 | train_iter = tqdm(train_loader) 105 | self.model.train() 106 | self.critic_model.train() 107 | 108 | for it in train_iter: 109 | is_rlhf = torch.tensor([random.randint(0, 10) / 10]).to(self.accelerate.device) 110 | self.accelerate.wait_for_everyone() 111 | dist.broadcast(is_rlhf, src=0) 112 | is_rlhf = is_rlhf.item() <= 0.4 113 | for loss, logits in self.ppo(is_rlhf, self.model, self.reward_model, self.critic_model, **it, num_beams=num_beams, num_return_sequences=num_return_sequences, ppo_epochs=ppo_epochs): 114 | # loss.backward() 115 | self.accelerate.backward(loss) 116 | optimizer.step() 117 | scheduler.step() 118 | self.model.zero_grad() 119 | self.critic_model.zero_grad() 120 | 121 | if not is_rlhf: 122 | labels = it['labels'] 123 | metrics = self.compute_metrics(logits, labels) 124 | for k, v in metrics.items(): 125 | eval_scores[k] += v 126 | 127 | train_loss += loss.data.item() 128 | train_count += 1 129 | train_step += 1 130 | 131 | train_iter.set_description( 132 | 'Train: {}/{}'.format(epoch + 1, num_epochs)) 133 | train_iter.set_postfix( 134 | train_loss=train_loss / train_count, **{k: v / train_count for k, v in eval_scores.items()}) 135 | 136 | self.analysis.append_train_record({ 137 | 'epoch': epoch + 1, 138 | 'train_loss': train_loss / train_count, 139 | **{k: v / train_count for k, v in eval_scores.items()} 140 | }) 141 | 142 | model_uid = self.save_model(train_step) 143 | if eval_call_epoch is None or eval_call_epoch(epoch): 144 | self.eval(epoch) 145 | 146 | self.analysis.save_all_records( 147 | uid=current_uid if self.task_name is None else self.task_name) 148 | import json 149 | with open(f'./data_record/{current_uid if self.task_name is None else self.task_name}/{train_step}_log', encoding='utf-8', mode='w+') as f: 150 | for item in self.qa_logs: 151 | f.write(json.dumps({f'{item}': self.qa_logs[item]}, ensure_ascii=False) + '\n') 152 | self.qa_logs.clear() 153 | yield (epoch, self.analysis.train_record, self.analysis.eval_record, self.analysis.model_record, model_uid) 154 | 155 | @accelerator.on_local_main_process 156 | def save_model(self, current_step=0): 157 | if self.task_name is None: 158 | dir = 'undefined' 159 | else: 160 | dir = self.task_name 161 | save_dir = f'./save_model/{dir}' 162 | peft_save_path = os.path.join(save_dir, f'ChatGLM_{current_step}') 163 | if not os.path.exists(peft_save_path): 164 | os.makedirs(peft_save_path) 165 | actor_model = self.accelerate.unwrap_model(self.model) 166 | actor_model.save_pretrained( 167 | peft_save_path, 168 | is_main_process=self.accelerate.is_main_process, 169 | save_function=self.accelerate.save, 170 | ) 171 | critic_save_path = os.path.join(save_dir, f'Critic_{current_step}') 172 | if not os.path.exists(critic_save_path): 173 | os.makedirs(critic_save_path) 174 | critic_model = self.accelerate.unwrap_model(self.critic_model) 175 | critic_linear = critic_model.output_linear 176 | 
torch.save(critic_linear.state_dict(), os.path.join(critic_save_path, 'linear.pth')) 177 | self.analysis.append_model_record(current_step) 178 | return current_step 179 | 180 | def eval(self, epoch, pure_eval=False): 181 | if pure_eval: 182 | self.model = self.accelerate.prepare_model(self.model) 183 | self.eval_loader = self.accelerate.prepare_data_loader(self.eval_loader) 184 | 185 | with torch.no_grad(): 186 | eval_count = 0 187 | eval_loss = 0 188 | eval_scores = { 189 | 'rouge-1': 0, 190 | 'rouge-2': 0, 191 | 'rouge-l': 0, 192 | 'bleu-4': 0 193 | } 194 | 195 | eval_iter = tqdm(self.eval_loader) 196 | self.model.eval() 197 | 198 | for it in eval_iter: 199 | 200 | output = self.model(input_ids=it['input_ids'], labels=it['labels']) 201 | loss = output.loss 202 | loss = loss.mean() 203 | 204 | logits = output.logits 205 | labels = it['labels'] 206 | metrics = self.compute_metrics(logits, labels) 207 | for k, v in metrics.items(): 208 | eval_scores[k] += v 209 | 210 | eval_loss += loss.data.item() 211 | eval_count += 1 212 | 213 | eval_iter.set_description( 214 | f'Eval: {epoch + 1}') 215 | eval_iter.set_postfix( 216 | eval_loss=eval_loss / eval_count, **{k: v / eval_count for k, v in eval_scores.items()}) 217 | 218 | self.analysis.append_eval_record({ 219 | 'epoch': epoch + 1, 220 | 'eval_loss': eval_loss / eval_count, 221 | **{k: v / eval_count for k, v in eval_scores.items()} 222 | }) 223 | 224 | def compute_metrics(self, logits, labels): 225 | shift_logits = logits[..., :-1, :] 226 | pred_logits = shift_logits.argmax(-1) 227 | pred_logits = pred_logits.tolist() 228 | shift_labels = labels[..., 1:].tolist() 229 | 230 | metrics_dct = {'rouge-1': [], 'rouge-2': [], 'rouge-l': [], 'bleu-4': []} 231 | for pred_ids, label_ids in zip(pred_logits, shift_labels): 232 | try: 233 | answer_idx = 0 234 | for i in range(len(label_ids)): 235 | if label_ids[i] != -100: 236 | answer_idx = i 237 | break 238 | pred_ids = pred_ids[answer_idx:] 239 | label_ids = label_ids[answer_idx:] 240 | pred_txt = self.tokenizer.decode(pred_ids).strip() 241 | label_txt = self.tokenizer.decode(label_ids).strip() 242 | pred_tokens = list(jieba.cut(pred_txt)) 243 | label_tokens = list(jieba.cut(label_txt)) 244 | rouge = Rouge() 245 | scores = rouge.get_scores(' '.join(pred_tokens), ' '.join(label_tokens)) 246 | for k, v in scores[0].items(): 247 | metrics_dct[k].append(round(v['f'] * 100, 4)) 248 | metrics_dct['bleu-4'].append( 249 | sentence_bleu( 250 | [label_tokens], 251 | pred_tokens, 252 | smoothing_function=SmoothingFunction().method3, 253 | ) 254 | ) 255 | except: 256 | continue 257 | return {k: np.mean(v) if len(v) > 0 else 0 for k, v in metrics_dct.items()} -------------------------------------------------------------------------------- /docs/Readme_cn.md: -------------------------------------------------------------------------------- 1 |

2 | Logo 3 |

4 | 5 | Static Badge 6 | 7 | 8 | Static Badge 9 | 10 | 11 | GitHub Repo stars 12 | 13 |

14 |

15 | 16 | ## SAS-Bench: A Fine-Grained Benchmark for Evaluating Short Answer Scoring with Large Language Models 17 | 18 | [数据集](https://huggingface.co/datasets/aleversn/SAS-Bench) | [论文](https://arxiv.org/pdf/2505.07247) | [代码](https://github.com/PKU-DAIR/SAS-Bench) 19 | 20 | ## 🔍 项目概述 21 | 22 | SAS-Bench是首个专门针对大语言模型(LLM)的简答题评分(SAS)基准测试。基于中国高考真实试题构建,本基准测试具有以下特点: 23 | 24 | - **1030道试题**覆盖9大学科领域 25 | - **4109份专家标注的学生答案** 26 | - **分步评分**与**分布错因分析** 27 | - **多维度评估体系**(整体评分、分步评分、错因诊断一致性) 28 | 29 | ## 🚀 核心特色 30 | 31 | ### 突破传统SAS系统局限 32 | SAS-Bench解决了传统简答题评分系统的关键缺陷: 33 | 34 | | 维度 | 传统SAS系统 | SAS-Bench优势 | 35 | | -------------- | ------------- | ------------------ | 36 | | **评分粒度** | 单一总分 | 分步分解评分 | 37 | | **可解释性** | 黑箱机制 | 完备的错因类型体系 | 38 | | **答案多样性** | 单一学科/题型 | 跨学科非模板化评估 | 39 | 40 | ### 数据集特性 41 | 42 |

43 | SAS人工标注系统 44 |

45 | 46 | 数据集包含三类题目及丰富标注: 47 | 48 | 1. **选择题**(自由填写形式) 49 | 2. **填空题** 50 | 3. **简答题**(含步骤分解) 51 | 52 | 每份答案包含: 53 | - ✅ 人工标注整体得分 54 | - 🔍 步骤划分与分项评分 55 | - ❌ 步骤错因归类 56 | 57 | ## 🌟 评估框架 58 | 59 | ### CCS评估(协同一致性评分) 60 | 61 | **目的** 62 | 衡量模型预测与人工评分在整体得分和步骤得分上的协同一致性,确保模型理解详细推理过程。 63 | 64 | **公式** 65 | 调整权重矩阵结合整体与步骤差异: 66 | ```math 67 | W_{i,j} = \alpha \cdot \frac{(r_i - r_j)^2}{(N_r - 1)^2} + \frac{1 - \alpha}{m} \sum_{k=1}^{m} \frac{(s_{i,k} - s_{j,k})^2}{(N_{s_k} - 1)^2} 68 | ``` 69 | 其中: 70 | - $r_i, r_j$:模型/人工整体评分 71 | - $s_{i,k}, s_{j,k}$:第$k$步得分 72 | - $\alpha=0.5$:平衡权重 73 | - $N_r, N_{s_k}$:可能得分等级 74 | 75 | 最终CCS计算: 76 | ```math 77 | \text{CCS} := 1 - \frac{\sum_{i,j} O_{i,j} \cdot W_{i,j}}{\sum_{i,j} E_{i,j} \cdot W_{i,j}} 78 | ``` 79 | 80 | ### ECS评估(错因一致性评分) 81 | 82 | **目的** 83 | 量化模型识别错因类型的能力,按答案质量分层评估。 84 | 85 | **公式** 86 | 1. 使用分位数阈值$\tau_1, \tau_2$将样本分为3组(低/中/高): 87 | ```math 88 | \phi(x) = \mathbb{I}(x \geq \tau_1) + \mathbb{I}(x \geq \tau_2) 89 | ``` 90 | 2. 计算每组的错因频率矩阵$\mathbf{M}^p_k, \mathbf{M}^g_k$ 91 | 3. 计算组内Spearman相关性: 92 | ```math 93 | \rho_k = \text{SpearmanR}(\mathbf{M}^p_k, \mathbf{M}^g_k) 94 | ``` 95 | 最终ECS: 96 | ```math 97 | \text{ECS} := \frac{1}{m} \sum_{k=0}^{2} \rho_k 98 | ``` 99 | 100 | **关键特性** 101 | - 采用**3级性能分层**(m=3)确保稳健评估 102 | - 关联**错因类型分布**(而非简单计数) 103 | - 标准化评分支持跨数据集比较 104 | 105 | ## ⚙️ 安装指南 106 | 107 | ### 核心依赖 108 | ```bash 109 | pip install protobuf transformers>=4.44.1 cpm_kernels torch>=2.0 gradio mdtex2html sentencepiece accelerate json_repair openai 110 | ``` 111 | 112 | 或: 113 | ```bash 114 | pip install -r requirements.txt 115 | ``` 116 | 117 | ### vLLM环境配置(推荐) 118 | ```bash 119 | conda create -n vllm python=3.12 -y 120 | conda activate vllm 121 | pip install vllm # 需CUDA 12.0+ 122 | ``` 123 | 124 | 其他配置请参考官方[vLLM安装指南](https://docs.vllm.ai/en/latest/getting_started/installation/gpu.html)。 125 | 126 | ## 📊 基准测试流程 127 | 128 | ![工作流程](./assets/workflow.png) 129 | 130 | ### 目录结构 131 | 132 | ``` 133 | |- discuss/ - 分析脚本 134 | |- docs/ - 文档资源 135 | |- main/ - 模型训练/推理代码 136 | |- prompts/ - 预定义提示模板 137 | |- sas_pipelines/ - 主要评估代码 138 | |- utils/ - 工具函数 139 | ``` 140 | 141 | ### 实施选项 142 | 143 | #### 0. 数据预处理(标注阶段) 144 | - 原始标注数据位于`backend_data` 145 | - 运行`preprocess.py`进行数据整合 146 | - 修改`DATANAME`变量指定源文件(不含扩展名) 147 | 148 | > 此流程处理来自我们标注系统(系统即将开源)的原始数据 149 | 150 | #### 1. 数据获取 151 | 数据集发布于[HuggingFace数据集](https://huggingface.co/datasets/aleversn/SAS-Bench)。下载文件存放于`datasets/`: 152 | - 文件命名格式为`{q_id}_{course}_{question_type}.jsonl` 153 | - 错因分类体系在`error_type.jsonl`中: 154 | ```json 155 | {"q_id": 2, "course": "", "question_type": "", "guideline": "", "errors": [{"name": "", "description": ""}...]} 156 | ``` 157 | - `ID_Dict.json`包含学科-ID映射 158 | 159 | #### 2. 
LLM预测 160 | 支持Jupyter或命令行执行: 161 | 162 | **选项A:Jupyter Notebook** 163 | - 在`1_predict_scores.py`中设置`cmd_args = False` 164 | - 配置: 165 | - `save_type_name`:模型标识符/输出前缀 166 | - `model_from_pretrained`:模型路径 167 | - `file_name`:数据集标识(如`7_Math_ShortAns`) 168 | 169 | **选项B:命令行** 170 | 设置`cmd_args = True` 171 | 172 | *使用vLLM(推荐)*: 173 | ```bash 174 | cd sas_pipelines/ 175 | python 1_predict_scores.py --file_name=6_Chinese_ShortAns --save_type_name=<模型ID> --model_from_pretrained=<路径> --batch_size=1000 --vllm=1 176 | ``` 177 | 178 | *启用Tensor并行*: 179 | ```bash 180 | python 1_predict_scores.py --n_gpu=0,1 --file_name=6_Chinese_ShortAns --save_type_name=<模型ID> --model_from_pretrained=<路径> --batch_size=1000 --vllm=1 --tensor_parallel_size=2 181 | ``` 182 | 183 | *HuggingFace预测器*: 184 | ```bash 185 | python 1_predict_scores.py --file_name=6_Chinese_ShortAns --save_type_name=<模型ID> --model_from_pretrained=<路径> --batch_size=5 186 | ``` 187 | 188 | *OpenAI API预测*: 189 | 1. 在`sas_pipeline/`创建`api_key.txt`,格式: 190 | ```text 191 | OpenAI 192 | Deepseek 193 | ``` 194 | 2. 执行: 195 | ```bash 196 | python 1_predict_scores.py --file_name=6_Chinese_ShortAns --llm_name=deepseek-chat --save_type_name=Deepseek_V3 197 | ``` 198 | 199 | **附加参数**: 200 | - 使用小样本示例:`--few_shot_num >0` 201 | - 禁用评分指南:`--use_guideline=0` 202 | - 跳过深度思考:`--skip_thinking=1` 203 | - `llm_name`默认为`save_type_name`(GLM3/OpenAI模型除外) 204 | 205 | #### 3. 预测处理 206 | **选项A:Jupyter** 207 | - 在`2_process_prediction.py`设置`cmd_args = False` 208 | - 配置`file_name`(使用`all`进行批量处理) 209 | 210 | **选项B:命令行** 211 | ```bash 212 | python 2_process_prediction.py --file_name=all 213 | ``` 214 | 215 | #### 4. CCS计算 216 | **选项A:Jupyter** 217 | - 在`3_compute_ccs.py`配置`file_name`和`save_type_name` 218 | 219 | **选项B:命令行** 220 | ```bash 221 | python 3_compute_ccs.py --save_type_name=<模型前缀> 222 | ``` 223 | 224 | #### 5. ECS计算 225 | **选项A:Jupyter** 226 | - 在`4_compute_ecs.py`调整参数 227 | 228 | **选项B:命令行** 229 | ```bash 230 | python 4_compute_ecs.py --save_type_name=<模型前缀> 231 | ``` 232 | 233 | ## 📈 Model Performance Insights 234 | 235 | 在16个LLMs上进行了实验: 236 | 237 | - QWK 238 | 239 | ![Workflow](./assets/qwk.png) 240 | 241 | - CCS 242 | 243 | | Models | Phy. (S.) | Phy. (M.) | His. (S.) | Geo. (S.) | Bio. (G.) | Chi. (G.) | Chi. (S.) | Math (S.) | Math (G.) | Pol. (S.) | Eng. (G.) | Che. (G.) | Avg. 
| 244 | | ---------------------- | --------- | --------- | --------- | --------- | --------- | --------- | --------- | --------- | --------- | --------- | --------- | --------- | --------- | 245 | | Deepseek-R1 | 38.43 | **95.01** | **80.98** | 67.92 | **79.12** | 95.09 | 69.07 | 57.85 | **83.56** | 71.92 | 73.19 | 72.92 | 73.76 | 246 | | QwQ-32B | 48.53 | 87.23 | 75.43 | **77.06** | 72.52 | **96.00** | 31.77 | 48.66 | 45.51 | 74.48 | 54.79 | 62.17 | 64.51 | 247 | | TinyR1-32B-Preview | 38.17 | 84.88 | 75.83 | 71.52 | 73.45 | 92.57 | 52.61 | 48.28 | 74.77 | 70.70 | 57.92 | 41.37 | 65.17 | 248 | | Qwen3-32B | 47.29 | 85.51 | 64.96 | 80.43 | 63.15 | 92.21 | 50.43 | 51.26 | 80.77 | 73.30 | 59.33 | 57.82 | 67.20 | 249 | | Qwen3-8B | 54.33 | 76.17 | 45.54 | 68.89 | 43.22 | 86.01 | 42.02 | 46.33 | 73.33 | 64.25 | 50.55 | 50.52 | 58.43 | 250 | | MiMo-7B-RL | 52.77 | 41.01 | 61.33 | 67.10 | 35.93 | 54.72 | 43.09 | 38.09 | 55.79 | 36.78 | 34.69 | 31.05 | 46.03 | 251 | | Deepseek-Prover-V2-7B | 22.59 | 10.75 | 2.92 | 30.71 | 50.63 | 55.48 | 12.95 | 0.87 | 2.29 | 10.44 | 30.19 | 28.76 | 21.55 | 252 | | DeepSeek-R1-Distill-7B | 33.71 | 29.24 | 50.92 | 32.35 | 52.18 | 52.44 | 44.29 | 29.52 | 39.55 | 53.77 | 32.98 | 34.27 | 40.44 | 253 | | Deepseek-V3 | 53.89 | 85.72 | 69.85 | 76.23 | 76.51 | 93.42 | **69.49** | **58.81** | 80.18 | **76.75** | **73.82** | **74.64** | **74.11** | 254 | | GPT 4o-mini-20240718 | **58.90** | 81.19 | 54.85 | 76.59 | 65.39 | 87.65 | 55.25 | 43.56 | 37.38 | 63.44 | 22.60 | 55.98 | 58.56 | 255 | | Llama3.3-70B-Instruct | 45.34 | 70.03 | 72.02 | 72.51 | 67.94 | 85.30 | 35.83 | 58.60 | 74.97 | 63.68 | 67.60 | 38.94 | 62.73 | 256 | | Mixtral 8×7B-Instruct | 30.78 | 42.27 | 33.43 | 4.99 | 44.45 | 29.85 | 24.00 | 26.73 | 70.04 | 43.92 | 33.40 | 42.05 | 35.49 | 257 | | Qwen2.5-32B-Instruct | 40.53 | 77.02 | 62.34 | 74.50 | 72.07 | 94.85 | 66.37 | 50.08 | 32.59 | 64.09 | 53.35 | 62.87 | 62.56 | 258 | | Qwen2.5-14B-Instruct | 53.76 | 66.12 | 60.96 | 74.30 | 67.50 | 92.81 | 63.08 | 43.28 | 75.62 | 62.03 | 56.34 | 57.53 | 64.44 | 259 | | GLM4-9B-Chat | 45.62 | 52.33 | 36.81 | 69.41 | 39.19 | 63.92 | 42.94 | 35.50 | 56.95 | 54.83 | 33.92 | 30.79 | 46.85 | 260 | | Llama3-8B-Instruct | 41.09 | 35.10 | 37.52 | 31.29 | 32.19 | 38.13 | 32.89 | 23.55 | 62.43 | 37.78 | 31.68 | 29.27 | 36.08 | 261 | 262 | - ECS 263 | 264 | | Models | Phy. (S.) | Phy. (M.) | His. (S.) | Geo. (S.) | Bio. (G.) | Chi. (G.) | Chi. (S.) | Math (S.) | Math (G.) | Pol. (S.) | Eng. (G.) | Che. (G.) | Avg. 
| 265 | | ---------------------- | --------- | --------- | --------- | --------- | --------- | --------- | --------- | --------- | --------- | --------- | --------- | --------- | --------- | 266 | | Deepseek-R1 | 23.25 | 30.59 | 57.53 | 56.08 | 69.20 | 86.04 | 72.68 | **94.29** | 15.20 | 65.56 | _18.65_ | _81.76_ | **55.90** | 267 | | QwQ-32B | 4.74 | **63.92** | 67.06 | _70.04_ | 53.68 | 51.08 | 69.20 | 79.05 | 16.82 | 48.81 | -22.53 | 48.94 | 45.90 | 268 | | TinyR1-32B-Preview | 3.10 | **63.92** | 65.71 | **77.02** | 56.61 | 64.42 | 74.83 | 82.86 | 23.33 | 40.17 | -31.52 | 17.35 | 44.82 | 269 | | Qwen3-32B | -4.17 | 24.18 | _69.52_ | 54.29 | 53.67 | 52.70 | 47.31 | 82.21 | 18.33 | 62.14 | -26.99 | 36.27 | 39.12 | 270 | | Qwen3-8B | 23.39 | **63.92** | 14.29 | -4.96 | 52.21 | 47.75 | 34.01 | 39.20 | -8.14 | 57.19 | -27.13 | 59.28 | 29.25 | 271 | | MiMo-7B-RL | **51.05** | 24.18 | 14.29 | 38.85 | 58.35 | _92.17_ | 63.07 | 13.39 | 35.12 | -27.10 | -4.41 | 1.04 | 30.00 | 272 | | Deepseek-Prover-V2-7B | -24.10 | -5.20 | 42.86 | -6.23 | 29.54 | -80.81 | 23.25 | 46.67 | -1.51 | -58.64 | -45.23 | -21.91 | -8.44 | 273 | | DeepSeek-R1-Distill-7B | -45.19 | 24.18 | 0.95 | -38.66 | 23.55 | -20.36 | 3.87 | -23.81 | -13.57 | -18.81 | -19.59 | -44.58 | -14.34 | 274 | | Deepseek-V3 | 7.79 | 46.58 | 58.10 | 32.62 | _72.38_ | **96.58** | 57.43 | _92.38_ | _33.33_ | 40.26 | **24.77** | **85.83** | _54.00_ | 275 | | GPT 4o-mini-20240718 | 17.91 | 24.18 | 62.14 | 36.68 | 55.20 | 79.01 | **78.00** | 67.62 | **46.90** | **92.31** | 10.04 | 36.39 | 50.53 | 276 | | Llama3.3-70B-Instruct | 22.56 | _57.35_ | 54.29 | 42.11 | 45.09 | 52.70 | 46.25 | 54.29 | 30.00 | 58.81 | -12.53 | -15.83 | 36.26 | 277 | | Mixtral 8×7B-Instruct | 11.99 | 17.34 | **80.38** | 35.84 | 32.74 | 42.77 | 75.82 | 56.19 | 30.00 | 6.84 | -31.16 | -7.18 | 29.30 | 278 | | Qwen2.5-32B-Instruct | 11.95 | 17.41 | 53.33 | 59.34 | 62.96 | 46.90 | 75.08 | 62.86 | 30.00 | 46.67 | -4.50 | 27.08 | 40.76 | 279 | | Qwen2.5-14B-Instruct | 21.50 | 24.18 | 47.92 | 37.43 | **73.36** | 64.97 | 74.32 | 64.94 | 18.21 | 61.97 | -20.00 | 47.39 | 43.02 | 280 | | GLM4-9B-Chat | 35.00 | 24.18 | 32.49 | 34.73 | 62.12 | 20.36 | _77.34_ | 63.81 | **46.90** | _82.40_ | -25.35 | 7.18 | 38.43 | 281 | | Llama3-8B-Instruct | _48.25_ | 27.46 | 17.23 | 31.58 | 61.37 | -14.05 | 41.23 | 57.77 | 21.55 | -69.07 | -26.50 | -27.19 | 14.14 | 282 | 283 | ## 📅 待办事项 284 | 285 | - [ ] 提供英文本地化版本数据集 286 | - [ ] 开源标注系统(前端 & 后端) 287 | 288 | ## 📜 许可声明 289 | 290 | SAS-Bench采用`Apache License 2.0`协议发布。本数据集仅限研究用途使用。 291 | 292 | ## 📚 引用方式 293 | 294 | ```bibtex 295 | @article{lai2025sasbenchfinegrainedbenchmarkevaluating, 296 | title={SAS-Bench: A Fine-Grained Benchmark for Evaluating Short Answer Scoring with Large Language Models}, 297 | author={Peichao Lai and Kexuan Zhang and Yi Lin and Linyihan Zhang and Feiyang Ye and Jinhao Yan and Yanwei Xu and Conghui He and Yilei Wang and Wentao Zhang and Bin Cui}, 298 | year={2025}, 299 | journal={arXiv preprint arXiv:2505.07247}, 300 | primaryClass={cs.CL}, 301 | url={https://arxiv.org/abs/2505.07247}, 302 | } 303 | ``` 304 | -------------------------------------------------------------------------------- /Readme.md: -------------------------------------------------------------------------------- 1 |

2 | Logo 3 |

4 | 5 | Static Badge 6 | 7 | 8 | Static Badge 9 | 10 | 11 | GitHub Repo stars 12 | 13 |

14 |

15 | 16 | ## SAS-Bench: A Fine-Grained Benchmark for Evaluating Short Answer Scoring with Large Language Models 17 | 18 | [Dataset](https://huggingface.co/datasets/aleversn/SAS-Bench) | [中文](./docs/Readme_cn.md) | [Paper](https://arxiv.org/pdf/2505.07247) | [Code](https://github.com/PKU-DAIR/SAS-Bench) 19 | 20 | ## 🔍 Overview 21 | 22 | SAS-Bench represents the first specialized benchmark for evaluating Large Language Models (LLMs) on Short Answer Scoring (SAS) tasks. Utilizing authentic questions from China's National College Entrance Examination (Gaokao), our benchmark offers: 23 | 24 | - **1,030 questions** spanning 9 academic disciplines 25 | - **4,109 expert-annotated student responses** 26 | - **Step-wise scoring** with **Step-wise error analysis** 27 | - **Multi-dimensional evaluation** (holistic scoring, step-wise scoring, and error diagnosis consistency) 28 | 29 | ## 🚀 Key Features 30 | 31 | ### Advancing Beyond Traditional SAS Limitations 32 | SAS-Bench addresses critical limitations of conventional SAS systems: 33 | 34 | | Aspect | Traditional SAS | SAS-Bench Advantage | 35 | | -------------------------- | ----------------------------- | -------------------------------------------- | 36 | | **Evaluation Granularity** | Single composite score | Step-wise scoring decomposition | 37 | | **Explainability** | Opaque scoring mechanism | Comprehensive error taxonomy | 38 | | **Response Diversity** | Single-subject/type focus | Cross-disciplinary template-free evaluation | 39 | 40 | ### Dataset Characteristics 41 | 42 |

43 | SAS human annotation system 44 |

45 | 46 | Our dataset features three question types with rich annotations: 47 | 48 | 1. **Multiple-Choice Questions** (Template-free responses) 49 | 2. **Gap Filling Questions** 50 | 3. **Short Answer Questions** (With logical step decomposition) 51 | 52 | Each response includes: 53 | - ✅ Human-annotated holistic score 54 | - 🔍 Step segmentation with individual scoring 55 | - ❌ Step-wise Error causes classification 56 | 57 | ## 🌟 Evaluation Framework 58 | 59 | ### CCS Evaluation (Collaborative Consistency Score) 60 | 61 | **Purpose** 62 | Evaluates alignment between model predictions and human grading on both *holistic scores* and *step-wise scores*, ensuring models understand detailed reasoning. 63 | 64 | **Formula** 65 | The adjusted weight matrix combines overall and step-wise differences: 66 | ```math 67 | W_{i,j} = \alpha \cdot \frac{(r_i - r_j)^2}{(N_r - 1)^2} + \frac{1 - \alpha}{m} \sum_{k=1}^{m} \frac{(s_{i,k} - s_{j,k})^2}{(N_{s_k} - 1)^2} 68 | ``` 69 | Where: 70 | - $r_i, r_j$: Model/human overall scores 71 | - $s_{i,k}, s_{j,k}$: Step scores for step $k$ 72 | - $\alpha=0.5$: Balance weight 73 | - $N_r, N_{s_k}$: Possible score levels 74 | 75 | Final CCS calculation: 76 | ```math 77 | \text{CCS} := 1 - \frac{\sum_{i,j} O_{i,j} \cdot W_{i,j}}{\sum_{i,j} E_{i,j} \cdot W_{i,j}} 78 | ``` 79 | 80 | 81 | ### ECS Evaluation (Error Consistency Score) 82 | 83 | **Purpose** 84 | Quantifies how well the model identifies error types compared to human annotators, stratified by answer quality tiers. 85 | 86 | **Formula** 87 | 1. Partition samples into 3 groups (Low/Medium/High) using quantile thresholds $\tau_1, \tau_2$: 88 | ```math 89 | \phi(x) = \mathbb{I}(x \geq \tau_1) + \mathbb{I}(x \geq \tau_2) 90 | ``` 91 | 2. Compute error frequency matrices $\mathbf{M}^p_k, \mathbf{M}^g_k$ per group $k$ 92 | 3. Calculate Spearman correlation per group: 93 | ```math 94 | \rho_k = \text{SpearmanR}(\mathbf{M}^p_k, \mathbf{M}^g_k) 95 | ``` 96 | Final ECS: 97 | ```math 98 | \text{ECS} := \frac{1}{m} \sum_{k=0}^{2} \rho_k 99 | ``` 100 | 101 | **Key Features** 102 | - Uses **3 performance tiers** (m=3) for robust evaluation 103 | - Correlates **error type distributions** (not just counts) 104 | - Normalized scoring for cross-dataset comparison 105 | 106 | ## ⚙️ Installation Guide 107 | 108 | ### Core Dependencies 109 | ```bash 110 | pip install protobuf transformers>=4.44.1 cpm_kernels torch>=2.0 gradio mdtex2html sentencepiece accelerate json_repair openai 111 | ``` 112 | 113 | Alternative: 114 | ```bash 115 | pip install -r requirements.txt 116 | ``` 117 | 118 | ### vLLM Setup (Recommended) 119 | ```bash 120 | conda create -n vllm python=3.12 -y 121 | conda activate vllm 122 | pip install vllm # Requires CUDA 12.0+ 123 | ``` 124 | 125 | For other configurations, refer to official [vLLM installation](https://docs.vllm.ai/en/latest/getting_started/installation/gpu.html). 126 | 127 | ## 📊 Benchmark Workflow 128 | 129 | ![Workflow](./docs/assets/workflow.png) 130 | 131 | ### Directory Structure 132 | ``` 133 | |- discuss/ - Analysis scripts 134 | |- docs/ - Documentation assets 135 | |- main/ - Model training/inference code 136 | |- prompts/ - Predefined prompt templates 137 | |- sas_pipelines/ - Core evaluation scripts 138 | |- utils/ - Utility functions 139 | ``` 140 | 141 | ### Implementation Options 142 | 143 | #### 0. 
Data Preprocessing (Annotation Phase) 144 | - Raw annotated data resides in `backend_data` 145 | - Execute `preprocess.py` for data consolidation 146 | - Modify `DATANAME` variable to specify source files (omit extensions) 147 | 148 | > This process handles raw data from our annotation system (our system to be open-sourced). 149 | 150 | #### 1. Data Acquisition 151 | The dataset is available on [HuggingFace Dataset](https://huggingface.co/datasets/aleversn/SAS-Bench). Store downloaded files in `datasets/`: 152 | - Files follow `{q_id}_{course}_{question_type}.jsonl` naming 153 | - Error taxonomy in `error_type.jsonl`: 154 | ```json 155 | {"q_id": 2, "course": "", "question_type": "", "guideline": "", "errors": [{"name": "", "description": ""}...]} 156 | ``` 157 | - `ID_Dict.json` contains subject-ID mappings 158 | 159 | #### 2. LLM Prediction 160 | Flexible execution via Jupyter or CLI: 161 | 162 | **Option A: Jupyter Notebook** 163 | - Set `cmd_args = False` in `1_predict_scores.py` 164 | - Configure: 165 | - `save_type_name`: Model identifier/output prefix 166 | - `model_from_pretrained`: Model path 167 | - `file_name`: Dataset identifier (e.g., `7_Math_ShortAns`) 168 | 169 | **Option B: Command Line** 170 | Set `cmd_args = True` 171 | 172 | *Using vLLM (Recommended)*: 173 | ```bash 174 | cd sas_pipelines/ 175 | python 1_predict_scores.py --file_name=6_Chinese_ShortAns --save_type_name= --model_from_pretrained= --batch_size=1000 --vllm=1 176 | ``` 177 | 178 | *With Tensor Parallelism*: 179 | ```bash 180 | python 1_predict_scores.py --n_gpu=0,1 --file_name=6_Chinese_ShortAns --save_type_name= --model_from_pretrained= --batch_size=1000 --vllm=1 --tensor_parallel_size=2 181 | ``` 182 | 183 | *HuggingFace Predictor*: 184 | ```bash 185 | python 1_predict_scores.py --file_name=6_Chinese_ShortAns --save_type_name= --model_from_pretrained= --batch_size=5 186 | ``` 187 | 188 | *OpenAI API Predictor*: 189 | 1. Create `api_key.txt` in `sas_pipeline/` with format: 190 | ```text 191 | OpenAI 192 | Deepseek 193 | ``` 194 | 2. Execute: 195 | ```bash 196 | python 1_predict_scores.py --file_name=6_Chinese_ShortAns --llm_name=deepseek-chat --save_type_name=Deepseek_V3 197 | ``` 198 | 199 | **Additional Parameters**: 200 | - Few-shot learning: `--few_shot_num >0` 201 | - Disable guidelines: `--use_guideline=0` 202 | - Skip reasoning: `--skip_thinking=1` 203 | - `llm_name` defaults to `save_type_name` except for GLM3/OpenAI models 204 | 205 | #### 3. Prediction Processing 206 | **Option A: Jupyter** 207 | - Set `cmd_args = False` in `2_process_prediction.py` 208 | - Configure `file_name` (use `all` for batch processing) 209 | 210 | **Option B: CLI** 211 | ```bash 212 | python 2_process_prediction.py --file_name=all 213 | ``` 214 | 215 | #### 4. CCS Computation 216 | **Option A: Jupyter** 217 | - Configure `file_name` and `save_type_name` in `3_compute_ccs.py` 218 | 219 | **Option B: CLI** 220 | ```bash 221 | python 3_compute_ccs.py --save_type_name= 222 | ``` 223 | 224 | #### 5. ECS Computation 225 | **Option A: Jupyter** 226 | - Adjust parameters in `4_compute_ecs.py` 227 | 228 | **Option B: CLI** 229 | ```bash 230 | python 4_compute_ecs.py --save_type_name= 231 | ``` 232 | 233 | ## 📈 Model Performance Insights 234 | 235 | Our experiments with 16 LLMs reveal: 236 | 237 | - QWK 238 | 239 | ![Workflow](./docs/assets/qwk.png) 240 | 241 | - CCS 242 | 243 | | Models | Phy. (S.) | Phy. (M.) | His. (S.) | Geo. (S.) | Bio. (G.) | Chi. (G.) | Chi. (S.) | Math (S.) | Math (G.) | Pol. (S.) | Eng. (G.) | Che. (G.) 
| Avg. | 244 | |----------------------------|-----------|-----------|-----------|-----------|-----------|-----------|-----------|-----------|-----------|-----------|-----------|-----------|--------| 245 | | Deepseek-R1 | 38.43 | **95.01** | **80.98** | 67.92 | **79.12** | 95.09 | 69.07 | 57.85 | **83.56** | 71.92 | 73.19 | 72.92 | 73.76 | 246 | | QwQ-32B | 48.53 | 87.23 | 75.43 | **77.06** | 72.52 | **96.00** | 31.77 | 48.66 | 45.51 | 74.48 | 54.79 | 62.17 | 64.51 | 247 | | TinyR1-32B-Preview | 38.17 | 84.88 | 75.83 | 71.52 | 73.45 | 92.57 | 52.61 | 48.28 | 74.77 | 70.70 | 57.92 | 41.37 | 65.17 | 248 | | Qwen3-32B | 47.29 | 85.51 | 64.96 | 80.43 | 63.15 | 92.21 | 50.43 | 51.26 | 80.77 | 73.30 | 59.33 | 57.82 | 67.20 | 249 | | Qwen3-8B | 54.33 | 76.17 | 45.54 | 68.89 | 43.22 | 86.01 | 42.02 | 46.33 | 73.33 | 64.25 | 50.55 | 50.52 | 58.43 | 250 | | MiMo-7B-RL | 52.77 | 41.01 | 61.33 | 67.10 | 35.93 | 54.72 | 43.09 | 38.09 | 55.79 | 36.78 | 34.69 | 31.05 | 46.03 | 251 | | Deepseek-Prover-V2-7B | 22.59 | 10.75 | 2.92 | 30.71 | 50.63 | 55.48 | 12.95 | 0.87 | 2.29 | 10.44 | 30.19 | 28.76 | 21.55 | 252 | | DeepSeek-R1-Distill-7B | 33.71 | 29.24 | 50.92 | 32.35 | 52.18 | 52.44 | 44.29 | 29.52 | 39.55 | 53.77 | 32.98 | 34.27 | 40.44 | 253 | | Deepseek-V3 | 53.89 | 85.72 | 69.85 | 76.23 | 76.51 | 93.42 | **69.49** | **58.81** | 80.18 | **76.75** | **73.82** | **74.64** | **74.11** | 254 | | GPT 4o-mini-20240718 | **58.90** | 81.19 | 54.85 | 76.59 | 65.39 | 87.65 | 55.25 | 43.56 | 37.38 | 63.44 | 22.60 | 55.98 | 58.56 | 255 | | Llama3.3-70B-Instruct | 45.34 | 70.03 | 72.02 | 72.51 | 67.94 | 85.30 | 35.83 | 58.60 | 74.97 | 63.68 | 67.60 | 38.94 | 62.73 | 256 | | Mixtral 8×7B-Instruct | 30.78 | 42.27 | 33.43 | 4.99 | 44.45 | 29.85 | 24.00 | 26.73 | 70.04 | 43.92 | 33.40 | 42.05 | 35.49 | 257 | | Qwen2.5-32B-Instruct | 40.53 | 77.02 | 62.34 | 74.50 | 72.07 | 94.85 | 66.37 | 50.08 | 32.59 | 64.09 | 53.35 | 62.87 | 62.56 | 258 | | Qwen2.5-14B-Instruct | 53.76 | 66.12 | 60.96 | 74.30 | 67.50 | 92.81 | 63.08 | 43.28 | 75.62 | 62.03 | 56.34 | 57.53 | 64.44 | 259 | | GLM4-9B-Chat | 45.62 | 52.33 | 36.81 | 69.41 | 39.19 | 63.92 | 42.94 | 35.50 | 56.95 | 54.83 | 33.92 | 30.79 | 46.85 | 260 | | Llama3-8B-Instruct | 41.09 | 35.10 | 37.52 | 31.29 | 32.19 | 38.13 | 32.89 | 23.55 | 62.43 | 37.78 | 31.68 | 29.27 | 36.08 | 261 | 262 | - ECS 263 | 264 | | Models | Phy. (S.) | Phy. (M.) | His. (S.) | Geo. (S.) | Bio. (G.) | Chi. (G.) | Chi. (S.) | Math (S.) | Math (G.) | Pol. (S.) | Eng. (G.) | Che. (G.) | Avg. 
| 265 | |----------------------------|-----------|-----------|-----------|-----------|-----------|-----------|-----------|-----------|-----------|-----------|-----------|-----------|--------| 266 | | Deepseek-R1 | 23.25 | 30.59 | 57.53 | 56.08 | 69.20 | 86.04 | 72.68 | **94.29** | 15.20 | 65.56 | _18.65_ | _81.76_ | **55.90** | 267 | | QwQ-32B | 4.74 | **63.92** | 67.06 | _70.04_ | 53.68 | 51.08 | 69.20 | 79.05 | 16.82 | 48.81 | -22.53 | 48.94 | 45.90 | 268 | | TinyR1-32B-Preview | 3.10 | **63.92** | 65.71 | **77.02** | 56.61 | 64.42 | 74.83 | 82.86 | 23.33 | 40.17 | -31.52 | 17.35 | 44.82 | 269 | | Qwen3-32B | -4.17 | 24.18 | _69.52_ | 54.29 | 53.67 | 52.70 | 47.31 | 82.21 | 18.33 | 62.14 | -26.99 | 36.27 | 39.12 | 270 | | Qwen3-8B | 23.39 | **63.92** | 14.29 | -4.96 | 52.21 | 47.75 | 34.01 | 39.20 | -8.14 | 57.19 | -27.13 | 59.28 | 29.25 | 271 | | MiMo-7B-RL | **51.05** | 24.18 | 14.29 | 38.85 | 58.35 | _92.17_ | 63.07 | 13.39 | 35.12 | -27.10 | -4.41 | 1.04 | 30.00 | 272 | | Deepseek-Prover-V2-7B | -24.10 | -5.20 | 42.86 | -6.23 | 29.54 | -80.81 | 23.25 | 46.67 | -1.51 | -58.64 | -45.23 | -21.91 | -8.44 | 273 | | DeepSeek-R1-Distill-7B | -45.19 | 24.18 | 0.95 | -38.66 | 23.55 | -20.36 | 3.87 | -23.81 | -13.57 | -18.81 | -19.59 | -44.58 | -14.34 | 274 | | Deepseek-V3 | 7.79 | 46.58 | 58.10 | 32.62 | _72.38_ | **96.58** | 57.43 | _92.38_ | _33.33_ | 40.26 | **24.77** | **85.83** | _54.00_ | 275 | | GPT 4o-mini-20240718 | 17.91 | 24.18 | 62.14 | 36.68 | 55.20 | 79.01 | **78.00** | 67.62 | **46.90** | **92.31** | 10.04 | 36.39 | 50.53 | 276 | | Llama3.3-70B-Instruct | 22.56 | _57.35_ | 54.29 | 42.11 | 45.09 | 52.70 | 46.25 | 54.29 | 30.00 | 58.81 | -12.53 | -15.83 | 36.26 | 277 | | Mixtral 8×7B-Instruct | 11.99 | 17.34 | **80.38** | 35.84 | 32.74 | 42.77 | 75.82 | 56.19 | 30.00 | 6.84 | -31.16 | -7.18 | 29.30 | 278 | | Qwen2.5-32B-Instruct | 11.95 | 17.41 | 53.33 | 59.34 | 62.96 | 46.90 | 75.08 | 62.86 | 30.00 | 46.67 | -4.50 | 27.08 | 40.76 | 279 | | Qwen2.5-14B-Instruct | 21.50 | 24.18 | 47.92 | 37.43 | **73.36** | 64.97 | 74.32 | 64.94 | 18.21 | 61.97 | -20.00 | 47.39 | 43.02 | 280 | | GLM4-9B-Chat | 35.00 | 24.18 | 32.49 | 34.73 | 62.12 | 20.36 | _77.34_ | 63.81 | **46.90** | _82.40_ | -25.35 | 7.18 | 38.43 | 281 | | Llama3-8B-Instruct | _48.25_ | 27.46 | 17.23 | 31.58 | 61.37 | -14.05 | 41.23 | 57.77 | 21.55 | -69.07 | -26.50 | -27.19 | 14.14 | 282 | 283 | ## 📅 TO-DO 284 | 285 | - [ ] Provide English-localized dataset version 286 | - [ ] Open-source the annotation system (frontend & backend) 287 | 288 | ## 📜 License 289 | SAS-Bench is released under `Apache License 2.0`. The dataset is available for research purposes only. 290 | 291 | > Our questions collect from a publicly available dataset [Gaokao-Bench](https://github.com/OpenLMLab/GAOKAO-Bench) based on China's National College Entrance Examination (Gaokao). 
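
For readers who want to sanity-check the ECS numbers reported above, the grouping-and-correlation idea behind ECS can be sketched in a few lines of Python. This is a simplified illustration only (it assumes integer holistic scores and per-sample lists of error-type names); the authoritative logic lives in `4_compute_ecs.py` and may differ in details:

```python
import numpy as np
from scipy.stats import spearmanr

def ecs_sketch(scores, pred_errors, gold_errors, error_types):
    """Simplified ECS: split samples into 3 quality tiers by score quantiles, then
    average the per-tier Spearman correlation between predicted and gold error-type
    frequency vectors."""
    scores = np.asarray(scores, dtype=float)
    tau1, tau2 = np.quantile(scores, [1 / 3, 2 / 3])
    tiers = (scores >= tau1).astype(int) + (scores >= tau2).astype(int)  # 0=Low, 1=Medium, 2=High

    rhos = []
    for k in range(3):
        idx = np.where(tiers == k)[0]
        # error-type frequency vectors for this tier
        freq_pred = np.array([sum(err == t for i in idx for err in pred_errors[i]) for t in error_types], dtype=float)
        freq_gold = np.array([sum(err == t for i in idx for err in gold_errors[i]) for t in error_types], dtype=float)
        rho, _ = spearmanr(freq_pred, freq_gold)
        rhos.append(0.0 if np.isnan(rho) else rho)  # constant vectors yield NaN; treat as no correlation
    return float(np.mean(rhos))
```

In practice, the per-question error taxonomy from `error_type.jsonl` would supply `error_types`, and the human holistic scores would supply `scores` for the tier split.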
292 | 293 | ## 📚 Citation 294 | ```bibtex 295 | @article{lai2025sasbenchfinegrainedbenchmarkevaluating, 296 | title={SAS-Bench: A Fine-Grained Benchmark for Evaluating Short Answer Scoring with Large Language Models}, 297 | author={Peichao Lai and Kexuan Zhang and Yi Lin and Linyihan Zhang and Feiyang Ye and Jinhao Yan and Yanwei Xu and Conghui He and Yilei Wang and Wentao Zhang and Bin Cui}, 298 | year={2025}, 299 | journal={arXiv preprint arXiv:2505.07247}, 300 | primaryClass={cs.CL}, 301 | url={https://arxiv.org/abs/2505.07247}, 302 | } 303 | ``` 304 | -------------------------------------------------------------------------------- /main/models/chatglm_rlhf.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import random 4 | import numpy as np 5 | from transformers import AutoTokenizer, AutoModel, AutoConfig 6 | from functools import partial 7 | 8 | 9 | class CriticModel(nn.Module): 10 | def __init__(self, model_from_pretrained, resume_path=None, layers_keep=1) -> None: 11 | super().__init__() 12 | self.config = AutoConfig.from_pretrained( 13 | model_from_pretrained, trust_remote_code=True) 14 | self.config.num_layers = layers_keep 15 | model = AutoModel.from_pretrained( 16 | model_from_pretrained, trust_remote_code=True, config=self.config).to(torch.bfloat16) 17 | model = model.transformer 18 | # solve RuntimeError: "LayerNormKernelImpl" not implemented for 'Half' 19 | self.model = model 20 | self.output_linear = nn.Linear( 21 | self.config.hidden_size, 1, device=self.model.device, dtype=self.model.dtype) 22 | if resume_path is not None: 23 | self.output_linear.load_state_dict(torch.load(resume_path)) 24 | 25 | def forward(self, **kwargs): 26 | output = self.model(**kwargs) 27 | values = torch.tanh(self.output_linear(output.last_hidden_state)) 28 | return values.transpose(0, 1).squeeze(-1) 29 | 30 | 31 | class RewardModel(nn.Module): 32 | def __init__(self, model_from_pretrained) -> None: 33 | super().__init__() 34 | # Load model from HuggingFace Hub 35 | tokenizer = AutoTokenizer.from_pretrained( 36 | model_from_pretrained) 37 | model = AutoModel.from_pretrained(model_from_pretrained) 38 | model.eval() 39 | self.model = model 40 | self.tokenizer = tokenizer 41 | 42 | def mean_pooling(self, model_output, attention_mask): 43 | # First element of model_output contains all token embeddings 44 | token_embeddings = model_output[0] 45 | input_mask_expanded = attention_mask.unsqueeze( 46 | -1).expand(token_embeddings.size()).float() 47 | return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9) 48 | 49 | @staticmethod 50 | def jaccard(s1, s2): 51 | assert len(s1)+len(s2) > 0 52 | s1 = set(s1) 53 | s2 = set(s2) 54 | s_or = s1 | s2 55 | s_and = s1 & s2 56 | jaccard_distance = len(s_and)/len(s_or) 57 | return jaccard_distance 58 | 59 | def forward(self, gen_texts=["I am the generated content from LLM."], 60 | gold_answers=['Generation from LLM.', "I am the output content from LLM."], 61 | bad_answers=['I am human.', 'I am the content from human writing.'], 62 | weight_for_cos_and_jaccard=[0.5, 0.5]): 63 | examples = gold_answers + bad_answers 64 | example_num = len(examples) 65 | assert len(gen_texts) > 0 and example_num > 0 66 | reward_direction = torch.ones(example_num, device=self.model.device) 67 | reward_direction[len(gold_answers):] = -1 68 | sentences = gen_texts + examples 69 | # Tokenize sentences 70 | encoded_input = self.tokenizer( 71 | sentences, 
padding=True, return_tensors='pt') 72 | ids = self.tokenizer.batch_encode_plus( 73 | sentences, add_special_tokens=False)["input_ids"] 74 | # temporary truncate position_ids 75 | batch_size, max_seq_len = encoded_input["input_ids"].shape 76 | if max_seq_len > self.model.config.max_position_embeddings: 77 | encoded_input["position_ids"] = torch.arange( 78 | max_seq_len).expand((1, -1)).repeat(batch_size, 1) 79 | encoded_input["position_ids"] = encoded_input["position_ids"] / \ 80 | max_seq_len*self.model.config.max_position_embeddings 81 | encoded_input["position_ids"] = encoded_input["position_ids"].floor( 82 | ).long() 83 | # Compute token embeddings 84 | with torch.no_grad(): 85 | encoded_input = encoded_input.to(self.model.device) 86 | model_output = self.model(**encoded_input) 87 | # Perform pooling. In this case, max pooling. 88 | sentence_embeddings = self.mean_pooling( 89 | model_output, encoded_input['attention_mask']) 90 | gen_text_vecs = sentence_embeddings[:len(gen_texts)] 91 | answers_vecs = sentence_embeddings[len(gen_texts):] 92 | reward_ = [] 93 | for i in range(gen_text_vecs.shape[0]): 94 | gen_text_vecs_ = gen_text_vecs[i:i+1] 95 | # 用一下广播计算cos 96 | coses = torch.cosine_similarity( 97 | gen_text_vecs_, answers_vecs, dim=1) 98 | # 余弦截断 99 | coses[(coses < 0)] = 0 100 | # 计算 jaccard距离 101 | jaccard_s1 = partial(RewardModel.jaccard, ids[i]) 102 | jaccards = torch.tensor(np.vectorize(jaccard_s1)(np.array( 103 | ids[-len(examples):], dtype=object)), dtype=coses.dtype, device=coses.device) 104 | similarity = weight_for_cos_and_jaccard[0] * \ 105 | coses + weight_for_cos_and_jaccard[1]*jaccards 106 | value, index = similarity.max(dim=-1) 107 | reward_.append(value*reward_direction[index]) 108 | reward = torch.stack(reward_) 109 | return reward 110 | 111 | class PPO(nn.Module): 112 | def __init__(self, tokenizer, qa_logs=None): 113 | super().__init__() 114 | self.tokenizer = tokenizer 115 | self.qa_logs = qa_logs 116 | # get weight matrix 117 | self.decay_up_matrix_T = self.get_decay_up_matrix_T() 118 | 119 | def get_log_prob(self, generated_outputs, input_ids, gen_method = "greedy_search"): 120 | # beam_search generate 给出来的scores就是log_prob了,所以直接gather获取即可 121 | gen_sequences = generated_outputs.sequences[:, input_ids.shape[-1]:] 122 | # let's stack the logits generated at each step to a tensor 123 | # 要小心greedy search 拿到的是score,需要再log_softmax 124 | # 而beam_search 拿到的已经是log_softmax了 125 | scores = torch.stack(generated_outputs.scores, dim=1) 126 | # if scores.max() >0 : 127 | # gen_method = "greedy_search" 128 | if gen_method == "beam_search": 129 | log_prob_stacked = scores 130 | else: 131 | log_prob_stacked = torch.stack(generated_outputs.scores, dim=1).log_softmax(dim=-1) 132 | # now we need to collect the log_prob of the generated token # we need to add a dummy dim in the end to make gather work 133 | log_prob = torch.gather(log_prob_stacked, 2, gen_sequences[:, :, None]).squeeze(-1) 134 | return log_prob 135 | 136 | def get_log_probs_with_input_ids(self, actor_model, states, gen_max_len): 137 | input_ids = states 138 | output = actor_model(input_ids) #将已经生成的序列放进去计算,再次计算得到目标action也就是后续字符的概率或者log_prob值 139 | logits = output.logits[:, -(gen_max_len+1):-1].log_softmax(dim=-1) # 比先softmax再log好,复杂度减小,并且解决些nan问题 140 | new_log_probs = logits.gather(dim=-1, index=input_ids[:, -gen_max_len:].unsqueeze(-1)).squeeze(-1) 141 | return new_log_probs, output.logits 142 | 143 | def process_response(self, output): 144 | content = "" 145 | for response in output.split("<|assistant|>"): 146 | 
147 |             if not metadata.strip():
148 |                 content = content.strip()
149 |                 content = content.replace("[[训练时间]]", "2023年")
150 |             else:
151 |                 content = {"name": metadata.strip(), "content": content}
152 |         return content
153 | 
154 |     def generate_with_rlhf(self, actor_model, input_ids, query, num_beams=1, num_return_sequences=1, max_new_tokens=8):
155 |         '''
156 |         `params:`
157 |         - input_ids: [batch_size, seq_len]
158 |         - query: list, the user query content
159 |         - num_beams: int, e.g. 3 or 2  # set larger if you have more compute memory
160 |         - num_return_sequences: int, e.g. 3 or 2  # set larger if you have more compute memory
161 |         - max_new_tokens: int, the maximum number of new tokens the LLM may generate
162 | 
163 |         `return:`
164 |         - sequences: the generated ids of sequences
165 |         - log_probs: the log_probs of the generated ids
166 |         - gen_texts: the generated texts of the sequences, clipped to the max length.
167 |         '''
168 |         assert num_beams >= num_return_sequences, "the number of candidates should be no smaller than the number of returned sequences"
169 |         gen_method = "greedy_search" if num_beams == 1 else "beam_search"
170 |         # feed the query into the model and obtain its output
171 |         if hasattr(actor_model, 'module'):
172 |             unwrapped_model = actor_model.module
173 |         else:
174 |             unwrapped_model = actor_model
175 |         generate_ = unwrapped_model.generate(input_ids=input_ids, do_sample=False, num_beams=num_beams, max_new_tokens=max_new_tokens,
176 |                                              num_return_sequences=num_return_sequences, use_cache=True, num_beam_groups=1, output_scores=True,
177 |                                              output_hidden_states=False, return_dict_in_generate=True)
178 |         sequences = generate_.sequences
179 |         log_probs = self.get_log_prob(generated_outputs=generate_, input_ids=input_ids, gen_method=gen_method)
180 |         gen_texts = self.tokenizer.batch_decode(sequences)
181 |         gen_texts = [self.process_response(text) for text in gen_texts]
182 | 
183 |         for i, q in enumerate(query):
184 |             cur_gen_texts = gen_texts[i * num_return_sequences : (i + 1) * num_return_sequences]
185 |             if self.qa_logs is not None:
186 |                 if q not in self.qa_logs:
187 |                     self.qa_logs[q] = []
188 |                 self.qa_logs[q] += cur_gen_texts  # store this query's answers in qa_logs; if the same query is answered multiple times, every response is appended
189 | 
190 |         return sequences, log_probs, gen_texts, None
191 | 
192 |     def generate_with_ft(self, actor_model, input_ids, last_assistant_content, gen_max_len):
193 |         '''
194 |         the target sentence is used directly, and RL simply raises its probability
195 | 
196 |         `params:`
197 |         - input_ids: [batch_size, seq_len], query ids with answer
198 |         - last_assistant_content: str, the standard answer
199 |         - gen_max_len: int, the max length of answer ids
200 | 
201 |         `return:`
202 |         - sequences: the generated ids of sequences; here simply the original input_ids
203 |         - log_probs: the log_probs of the input_ids.
204 |         - gen_texts: the generated texts of the sequences; here simply the original last_assistant_content.
205 |         '''
206 |         sequences = input_ids
207 |         with torch.no_grad():
208 |             log_probs, logits = self.get_log_probs_with_input_ids(actor_model, input_ids, gen_max_len=gen_max_len)
209 |         should_gen_texts = last_assistant_content
210 |         return sequences, log_probs, should_gen_texts, logits
211 | 
212 |     def get_decay_up_matrix_T(self, max_length=2048, gamma=0.99, tau=0.95):
213 |         '''
214 |         build the decay matrix used for GAE
215 | 
216 |         `params:`
217 |         - max_length: int
218 |         - gamma: float
219 |         - tau: float
220 | 
221 |         `return:`
222 |         - decay_up_matrix_T: torch.Tensor
223 |         '''
224 |         decay = gamma * tau  # decay coefficient
225 |         decay_row = torch.ones(max_length).float() * decay
226 |         decay_row[0] = 1
227 |         decay_row_cross_time = decay_row.cumprod(dim=-1)  # use cumprod to build the sequence (gamma*tau), (gamma*tau)^2, ..., (gamma*tau)^max_length
228 |         assert decay_row_cross_time.sign().min() == 0
229 |         decay_up_matrix = torch.zeros((max_length, max_length)).float()
230 |         for i in range(max_length):
231 |             decay_row = decay_row_cross_time.roll(i)
232 |             decay_row[:i] = 0  # make sure earlier positions are not visible
233 |             decay_up_matrix[i] = decay_row
234 |         decay_up_matrix_T = decay_up_matrix.T  # transpose now, since matrix multiplication is applied later
235 |         return decay_up_matrix_T
236 | 
237 |     def gae_vectorize(self, values, rewards, masks=None):
238 |         """
239 |         `params:`
240 |         - values: `[batch_size, sequence_length]`, the state value at each timestep.
241 |         - rewards: `[batch_size, sequence_length]`, the reward of the action taken at each timestep; for GPT the current action is also the corresponding next state, so the shape matches values.
242 |           **note that `rewards` here means the reward of the current action/state**
243 |         - masks: GAE (generalized advantage estimation) is computed over the generated `actions`,
244 |           so, as usual, the `mask` only needs to cover `padding`,
245 |           because the `delta` of `padding` would otherwise enter the weighted sum, while the `delta` before the `action`
246 |           is already invisible since the decay matrix is upper-triangular.
247 |           `0` means masked, `1` means kept.
248 |         """
249 |         action_rewards = rewards.roll(-1)  # the reward of the current action is given when the next state appears, and rewards are computed from states, so shift back one timestep
250 |         # to also learn from the final output, assign a reward to the last state as well
251 |         action_rewards = (action_rewards + rewards) / 2  # spread the reward over the last two steps
252 | 
253 |         values_estimator_1_order = action_rewards + values.roll(-1)  # note that roll is circular, so the value at the last position may be unusable
254 |         deltas = values_estimator_1_order - values  # advantage of the current action: reward plus next-step value minus current value
255 |         # compute GAE
256 |         max_goal_length = deltas.shape[-1]
257 |         sub_decay_up_matrix_T = self.decay_up_matrix_T[:max_goal_length, :max_goal_length].to(deltas.device)
258 |         if masks is not None:
259 |             deltas = deltas * masks
260 |         gae = deltas.matmul(sub_decay_up_matrix_T)
261 |         assert gae.shape == deltas.shape
262 |         return gae
263 | 
264 |     def forward(self, is_rlhf, actor_model, reward_model, critic_model, input_ids, input_ids_without_last_turn, last_input_len, query, last_assistant_content, gold_answers, bad_answers, num_beams, num_return_sequences, ppo_epochs, **args):
265 |         if is_rlhf:
266 |             input_ids = input_ids_without_last_turn
267 |             max_new_tokens = torch.max(last_input_len).item()
268 |             sequences, log_probs, gen_texts, logits = self.generate_with_rlhf(actor_model, input_ids, query, num_beams=num_beams, num_return_sequences=num_return_sequences, max_new_tokens=max_new_tokens)
269 | 
270 |         else:
271 |             max_new_tokens = torch.max(last_input_len).item()
272 |             sequences, log_probs, gen_texts, logits = self.generate_with_ft(actor_model, input_ids, last_assistant_content, max_new_tokens)
273 | 
274 |         # compute reward for generated sequences
275 |         reward = []
276 |         batch_size = input_ids.shape[0]
277 |         for i in range(batch_size):
278 |             if is_rlhf:
279 |                 b_gen_texts = gen_texts[i * num_return_sequences : (i + 1) * num_return_sequences]
280 |             else:
281 |                 b_gen_texts = [gen_texts[i]]
282 |             b_gold_answers = gold_answers[i]
283 |             b_bad_answers = bad_answers[i]
284 |             b_reward = reward_model(gen_texts=b_gen_texts, gold_answers=b_gold_answers, bad_answers=b_bad_answers).unsqueeze(1)
285 |             reward.append(b_reward)
286 |         reward = torch.cat(reward, dim=0)
287 |         assert reward.shape == (len(gen_texts), 1), "need unsqueeze for next scatter_"
288 |         rewards = torch.zeros_like(sequences, dtype=reward.dtype)
289 |         pad_id = self.tokenizer.convert_tokens_to_ids("")
290 |         masks = (sequences != pad_id).long()
291 |         final_position = (sequences[:, input_ids.size(-1):] != pad_id).sum(dim=-1) + input_ids.size(-1) - 1
292 |         index = final_position.unsqueeze(-1)
293 |         rewards.scatter_(dim=1, index=index, src=reward)
294 |         # make sure everything is on the same device as values
295 | 
296 |         torch.cuda.empty_cache()
297 | 
298 |         for ppo_epoch in range(ppo_epochs):
299 |             # compute new log probs
300 |             new_log_probs, _ = self.get_log_probs_with_input_ids(actor_model, sequences, log_probs.shape[1])
301 |             entropy = 0  # no entropy regularization for now
302 |             # compute value
303 |             # the reward model and the value model can take the same input: the generated sequences,
304 |             # which contain both the state and the next action
305 |             # prepare input for critic model
306 |             input_ids_critic = sequences
307 |             values = critic_model(input_ids=input_ids_critic)
308 |             # compute gae
309 |             gae = self.gae_vectorize(values=values, rewards=rewards, masks=masks)
310 |             advantages = gae[:, -log_probs.shape[-1]:]
311 |             # use the deviation of the value estimate as the critic loss,
312 |             # and the clipped surrogate objective as the PPO actor loss
313 |             value_estimator_delta = advantages
314 |             ratio = (new_log_probs - log_probs).exp()
315 |             # print("reward", reward, "ratio:", ratio, sep="\n")
316 |             if torch.isinf(ratio).any():
317 |                 break
318 |             surr1 = ratio * advantages
319 |             surr2 = torch.clamp(ratio, 1.0 - 0.2, 1.0 + 0.2) * advantages
320 |             actor_loss = - torch.min(surr1, surr2).mean()
321 |             critic_loss = value_estimator_delta.square().mean()
322 |             loss = 0.5 * (critic_loss + actor_loss) - 0.001 * entropy
323 |             yield loss, logits
--------------------------------------------------------------------------------
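Editor's note (illustrative, not part of the repository): `PPO.forward` above is a generator that yields one loss per PPO epoch. The minimal sketch below shows one way such a generator might be driven in a training loop; the `actor_model`, `critic_model`, `reward_model`, `dataloader`, learning rate, and batch field names are assumptions standing in for objects presumably constructed elsewhere (e.g. by `main/trainer/chatglm_rlhf.py`), not repository API.

```python
# Hypothetical driver loop -- all free names below are assumptions, not repository API.
import torch
from transformers import AutoTokenizer

# assumed to exist already: actor_model, critic_model, reward_model, dataloader
tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm3-6b", trust_remote_code=True)
ppo = PPO(tokenizer, qa_logs={})

optimizer = torch.optim.AdamW(
    list(actor_model.parameters()) + list(critic_model.parameters()), lr=1e-5)

for batch in dataloader:
    # PPO.forward yields ppo_epochs losses for this batch; step the optimizer on each one.
    for loss, _logits in ppo(
            is_rlhf=True,
            actor_model=actor_model,
            reward_model=reward_model,
            critic_model=critic_model,
            input_ids=batch["input_ids"],
            input_ids_without_last_turn=batch["input_ids_without_last_turn"],
            last_input_len=batch["last_input_len"],
            query=batch["query"],
            last_assistant_content=batch["last_assistant_content"],
            gold_answers=batch["gold_answers"],
            bad_answers=batch["bad_answers"],
            num_beams=2,
            num_return_sequences=2,
            ppo_epochs=2):
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
```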