├── docs
│   ├── assets
│   │   ├── qwk.png
│   │   ├── workflow.png
│   │   ├── annotation.png
│   │   └── logo.svg
│   └── Readme_cn.md
├── requirements.txt
├── .gitignore
├── prompts
│   ├── few_shot_prompt.json
│   └── predict.txt
├── utils
│   ├── cal_steps.py
│   ├── load_few_shot.py
│   ├── errors_consistency_score.py
│   └── collaborative_consistency_score.py
├── sas_pipelines
│   ├── play.py
│   ├── 3_compute_ccs.py
│   ├── 2_process_prediction.py
│   ├── 4_compute_ecs.py
│   └── 1_predict_scores.py
├── main
│   ├── models
│   │   ├── chatglm.py
│   │   └── chatglm_rlhf.py
│   ├── loaders
│   │   ├── llm_pure_text.py
│   │   ├── llm_chat.py
│   │   ├── chatglm_chat.py
│   │   ├── qwen_chat.py
│   │   └── chatglm_rlhf.py
│   ├── analysis.py
│   ├── predictor
│   │   ├── openai.py
│   │   ├── qwen_lora.py
│   │   ├── vllm.py
│   │   ├── chatglm.py
│   │   ├── chatglm_lora.py
│   │   ├── llm_lora.py
│   │   └── llm.py
│   ├── loader.py
│   └── trainer
│       ├── llm_lora.py
│       └── chatglm_rlhf.py
├── discuss
│   ├── sample_distribution.py
│   ├── pred_gold.py
│   ├── error_causes_f1.py
│   └── compute_range_ccs.py
├── preprocess.py
├── LICENSE
└── Readme.md
/docs/assets/qwk.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PKU-DAIR/SAS-Bench/HEAD/docs/assets/qwk.png
--------------------------------------------------------------------------------
/docs/assets/workflow.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PKU-DAIR/SAS-Bench/HEAD/docs/assets/workflow.png
--------------------------------------------------------------------------------
/docs/assets/annotation.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PKU-DAIR/SAS-Bench/HEAD/docs/assets/annotation.png
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | protobuf
2 | transformers>=4.44.1
3 | cpm_kernels
4 | torch>=2.0
5 | gradio
6 | mdtex2html
7 | sentencepiece
8 | accelerate
9 | json_repair
10 | openai
11 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | .vscode
2 | ./*.txt
3 | **/**/__pycache__
4 | model/
5 | save_model/
6 | data_record/
7 | data/
8 | chroma_data/
9 | datasets_*_Scored/
10 | **/**/api_key.txt
11 | results/
--------------------------------------------------------------------------------
/prompts/few_shot_prompt.json:
--------------------------------------------------------------------------------
1 | {
2 | "prefix": "【参考示例】",
3 | "suffix": "以上为参考样例和参考输出格式,请参考这些样例的评分格式对下面的真实材料进行评估",
4 | "template": "\n- 试题内容:{question}\n- 题目分值:{total}\n- 标准答案:{reference}\n- 解析说明:{analysis}\n- 学生作答:{student_answer}\n输出: {output}"
5 | }
--------------------------------------------------------------------------------
/utils/cal_steps.py:
--------------------------------------------------------------------------------
1 | # %%
2 | import os
3 | import json
4 | import numpy as np
5 |
6 | DIR = '/home/lpc/repos/SAS_Benchmark/datasets'
7 | FILES = os.listdir(DIR)
8 |
9 | for file_name in FILES:
10 | if len(file_name.split('_')) < 3:
11 | continue
12 | path = os.path.join(DIR, file_name)
13 | step_count = []
14 | length_count = []
15 | with open(path) as f:
16 | ori_data = f.readlines()
17 | ori_data = [json.loads(item) for item in ori_data]
18 | for item in ori_data:
19 | step_count.append(len(item['steps']))
20 | len_count = 0
21 | for step_item in item['steps']:
22 | len_count += len(list(step_item['response']))
23 | length_count.append(len_count)
24 |
25 | print(file_name, np.mean(step_count), np.mean(length_count))
26 |
27 | # %%
28 |
--------------------------------------------------------------------------------
/prompts/predict.txt:
--------------------------------------------------------------------------------
1 | 请作为数学学科评分专家,根据以下要求对学生的作答进行专业评估:
2 |
3 | 【评估任务】
4 | 依据题目信息、参考答案及评分指南,对学生的分步解答进行精细化评分,并输出结构化评分结果。
5 |
6 | 【评分指南】
7 | {score_guideline}
8 | {few_shot_samples}
9 | 【评估材料】
10 | - 试题内容:{question}
11 | - 题目分值:{total}
12 | - 错因类型:{error_type}
13 | - 标准答案:{reference}
14 | - 解析说明:{analysis}
15 | - 学生作答:{student_answer}
16 |
17 | 【评估流程和要求】
18 | 1. 分步解析:
19 | - 拆解学生作答的每个解题步骤
20 | - 对每个步骤独立评估:
21 | * 判断正误('label')
22 | * 如存在错误,从错因列表中选取1项或多项主因('errors')
23 | - 单步评估格式:{{'step_score': 单步分数, 'errors': [错因]}}
24 |
25 | 2. 综合评定:
26 | - 汇总各步骤得分计算总分
27 | - 给出整体评价('label')
28 |
29 | 3. 结果输出:
30 | - 采用标准JSON格式输出:
31 | {{
32 | 'total': 总分,
33 | 'pred_score': 评估总分数,
34 | 'steps': [各步骤评估结果]
35 | }}
36 | - 'pred_score'必须在'total'范围内
37 | - 分步的'step_score'累积值也必须在0到'pred_score'范围内
38 |
39 | 请按照上述规范完成评分,并以`JSON`格式输出标准化的评估结果。
--------------------------------------------------------------------------------
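The prompt asks the model for a single JSON object with `total`, `pred_score` and per-step `step_score`/`errors`. Because replies are often wrapped in commentary or use single quotes, the pipeline repairs them before use (see `sas_pipelines/2_process_prediction.py`). A minimal sketch of that recovery, with a made-up model reply:

```python
# A minimal sketch of recovering the JSON object the prompt above asks for.
# The cleanup mirrors sas_pipelines/2_process_prediction.py; the reply is made up.
import json_repair

reply = ("评分如下:{'total': 4, 'pred_score': 3, 'steps': ["
         "{'step_score': 1, 'errors': []}, {'step_score': 2, 'errors': ['计算错误']}]}")
start = reply.find("{'total'")              # drop any leading commentary
parsed = json_repair.loads(reply[start:])   # tolerant of single quotes / minor damage
assert int(parsed['pred_score']) <= int(parsed['total'])
print(parsed['pred_score'], [s['step_score'] for s in parsed['steps']])
```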
/docs/assets/logo.svg:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/sas_pipelines/play.py:
--------------------------------------------------------------------------------
1 | # %%
2 | import os
3 | os.environ['CUDA_VISIBLE_DEVICES'] = '1'
4 | import sys
5 | sys.path.append('../')
6 | from main.predictor.llm import Predictor
7 |
8 | pred = Predictor(model_from_pretrained='/home/lpc/models/glm-4-9b-chat/')
9 |
10 | # %%
11 | ask_content = '''请你现在扮演一位物理学科评分专家, 我们提供了一道题目和一个学生的回答, 请你根据题目分值、评分指南、错因列表,对照参考答案和解题分析,遵循以下规则, 对学生回答进行评分。
12 | 规则: 1. 你需要按学生回答的`每个步骤`分别进行评分,并对于有错误的答案从`错因列表`中选取一到多个错因作为打分依据,每个步骤的评估结果以{{'label': '', 'errors': []}}的形式输出。
13 | 2. 你需要根据每个步骤的评分结果输出最终的总分。
14 | 3. 最终的输出结果为json格式的数据,参考格式为{{'total': '', 'label': '', 'steps': []}}。其中`total`表示总分,`label`表示最终的评分结果,`steps`表示每个步骤的评分结果。
15 | 题目: {question}
16 | 分值: {total}
17 | 错因列表: {error_type}
18 | 参考答案: {reference}
19 | 解题分析: {analysis}
20 | 学生回答: {student_answer}
21 | 请你根据以上信息进行评分,并输出评分结果。
22 | '''
23 |
24 | ask_content = '''请作为物理学科评分专家,根据以下要求对学生的作答进行专业评估:
25 |
26 | 【评估任务】
27 | 依据题目信息、参考答案及评分标准,对学生的分步解答进行精细化评分,并输出结构化评分结果。
28 |
29 | 【评估流程】
30 | 1. 分步解析:
31 | - 拆解学生作答的每个解题步骤
32 | - 对每个步骤独立评估:
33 | * 判断正误('label')
34 | * 如存在错误,从错因列表中选取1项或多项主因('errors')
35 | - 单步评估格式:{{'label': '', 'errors': []}}
36 |
37 | 2. 综合评定:
38 | - 汇总各步骤得分计算总分
39 | - 给出整体评价('label')
40 |
41 | 3. 结果输出:
42 | - 采用标准JSON格式:
43 | {{
44 | 'total': '总分',
45 | 'label': '总体评价',
46 | 'steps': [各步骤评估结果]
47 | }}
48 |
49 | 【评估材料】
50 | - 试题内容:{question}
51 | - 题目分值:{total}
52 | - 错因类型:{error_type}
53 | - 标准答案:{reference}
54 | - 解析说明:{analysis}
55 | - 学生作答:{student_answer}
56 |
57 | 【评分准则】
58 | 1. 严格对照标准答案的解题逻辑链
59 | 2. 错因标注需精准对应学生错误本质
60 | 3. 保持不同步骤间的评分尺度一致性
61 | 4. 对于创新解法需额外验证其科学性
62 |
63 | 【特别说明】
64 | 1. 公式错误、单位遗漏等细节问题需单独标注
65 | 2. 概念性错误与计算错误需区分处理
66 | 3. 部分正确的情况应给予相应步骤分
67 |
68 | 请按照上述规范完成评分,并输出标准化的评估结果。'''
69 |
70 | pred(ask_content, build_message=True)
71 |
72 | # %%
73 |
--------------------------------------------------------------------------------
/main/models/chatglm.py:
--------------------------------------------------------------------------------
1 | from re import A
2 | import torch
3 | import torch.nn as nn
4 | from transformers import AutoTokenizer, AutoModel
5 |
6 | class CCGPTModel():
7 |
8 | def __init__(
9 | self,
10 | model_name: str = None,
11 | model_from_pretrained: str = None,
12 | model_config_file_name: str = None,
13 | pretrained_file_name: str = None,
14 | ):
15 | '''
16 | CCGPTModel: 生成式模型 (Generative model)
17 |
18 | ### Args:
19 | `model_name`: 模型名称 (the name of the model)
20 | `model_from_pretrained`: 从预训练模型中加载模型 (load model from pretrained model)
21 | `model_config_file_name`: 模型配置文件名 (model config file name)
22 | `pretrained_file_name`: 预训练模型文件名 (pretrained file name)
23 | `tagset_size`: 标签数量 (the number of tags)
24 | '''
25 | self.model_name = model_name
26 | self.model_from_pretrained = model_from_pretrained
27 | self.model_config_file_name = model_config_file_name
28 | self.pretrained_file_name = pretrained_file_name
29 |
30 | if self.model_name is None:
31 | raise ValueError("model_name is required")
32 | if self.model_from_pretrained is None:
33 | if self.model_config_file_name is None:
34 | raise ValueError("model_config_file_name is required")
35 | if self.pretrained_file_name is None:
36 | raise ValueError("pretrained_file_name is required")
37 |
38 | self.load_model()
39 |
40 | def load_model(self):
41 | if self.model_name == 'ChatGLM2-6B':
42 | self.model = AutoModel.from_pretrained(self.model_from_pretrained, trust_remote_code=True).half().cuda()
43 |
44 | def get_model(self):
45 | return self.model
46 |
47 | def __call__(self):
48 | return self.get_model()
49 |
--------------------------------------------------------------------------------
/discuss/sample_distribution.py:
--------------------------------------------------------------------------------
1 | # %%
2 | import os
3 | import json
4 | import numpy as np
5 | import matplotlib.pyplot as plt
6 | from argparse import ArgumentParser
7 |
8 | plt.style.use('seaborn-v0_8-bright')
9 |
10 | import sys
11 | sys.path.append("../")
12 | cmd_args = True
13 |
14 | parser = ArgumentParser()
15 | parser.add_argument('--file_dir', default='../datasets', help='the directory of the datasets.')
16 | parser.add_argument('--file_name', default='10_Math_gapfilling', help='file name of the dataset without extension (e.g. 10_Math_gapfilling); a matching `.jsonl` file must exist in `--file_dir`')
17 | parser.add_argument('--save_fig', default=0, help='whether to save the figure (set 1 to save)')
18 |
19 | if not cmd_args:
20 | args = parser.parse_args([]) # You can directly set above parameters in the default.
21 | else:
22 | args = parser.parse_args()
23 |
24 | SAMPLE_PATH = os.path.join(args.file_dir, args.file_name + '.jsonl')
25 |
26 | with open(SAMPLE_PATH) as f:
27 | ori_data = f.readlines()
28 | ori_data = [json.loads(item) for item in ori_data]
29 |
30 | labels = []
31 | steps = []
32 | lengths = []
33 | for item in ori_data:
34 | label = float(item['manual_label']) / float(item['total'])
35 | step_count = len(item['steps'])
36 | length = 0
37 | for step in item['steps']:
38 | length += len(step['response'])
39 | labels.append(label)
40 | steps.append(step_count)
41 | lengths.append(length)
42 |
43 | def normalize(data):
44 | return (data - np.min(data)) / (np.max(data) - np.min(data))
45 |
46 |
47 | # %%
48 | plt.hist([item for item in labels], range=(0, 1), bins=20, alpha=0.8, color='#FFC000', label='Label')
49 | plt.hist(normalize(steps), range=(0, 1), bins=20, alpha=0.6, color='#d3d3f9', label='Steps')
50 | # plt.hist(normalize(lengths), bins=20, alpha=0.3, color='#A6C9E8', label='Length')
51 | plt.legend()
52 | if str(args.save_fig) == '1':
53 | plt.savefig(f'{args.file_name}_distribution.svg', format='svg', dpi=300)
54 |
55 | # %%
56 |
--------------------------------------------------------------------------------
/utils/load_few_shot.py:
--------------------------------------------------------------------------------
1 | # %%
2 | import json
3 | import random
4 | from tqdm import tqdm
5 |
6 | def get_few_shot_samples(file_name, num_samples=3):
7 | with open(file_name) as f:
8 | ori_data = f.readlines()
9 | ori_data = [json.loads(item) for item in ori_data]
10 | low = []
11 | mid = []
12 | high = []
13 | for item in tqdm(ori_data):
14 | total = float(item['total'])
15 | manual_score = float(item['manual_label'])
16 | norm_score = manual_score / total
17 | if norm_score < 0.333:
18 | low.append(item)
19 | elif norm_score < 0.667:
20 | mid.append(item)
21 | else:
22 | high.append(item)
23 | results = []
24 | for i in range(num_samples):
25 | if i % num_samples == 0 and len(low) > 0:
26 | results.append(random.choice(low))
27 | continue
28 | if i % num_samples == 1 and len(mid) > 0:
29 | results.append(random.choice(mid))
30 | continue
31 | if i % num_samples == 2 and len(high) > 0:
32 | results.append(random.choice(high))
33 | return results
34 |
35 | def compute_few_shot_prompt(sample, prompt):
36 | output = {}
37 | output_steps = []
38 | steps = sample['steps']
39 | for s in steps:
40 | os = {}
41 | os['step_score'] = int(s['label'])
42 | os['errors'] = s['errors']
43 | output_steps.append(os)
44 | response_content = []
45 | for s_idx, step in enumerate(steps):
46 | response = step['response']
47 | response_content.append(f'## Step {s_idx}. {response}')
48 | output['total'] = sample['total']
49 | output['pred_score'] = sample['manual_label']
50 | output['steps'] = output_steps
51 | format_prompt = prompt.format(question=sample['question'], total=sample['total'], reference=sample['reference'], analysis=sample['analysis'], student_answer=''.join(response_content), output=json.dumps(output, ensure_ascii=False))
52 | return format_prompt
53 |
54 | # %%
55 |
--------------------------------------------------------------------------------
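A hypothetical way to combine the two helpers above with `prompts/few_shot_prompt.json`: sample one exemplar from each score band, format each with the template, and wrap the result with `prefix`/`suffix`. The dataset path below is an assumption.

```python
# Hypothetical usage of get_few_shot_samples / compute_few_shot_prompt;
# the dataset path is an assumption, the prompt file ships with the repo.
import json
from utils.load_few_shot import get_few_shot_samples, compute_few_shot_prompt

with open('prompts/few_shot_prompt.json', encoding='utf-8') as f:
    fs_cfg = json.load(f)

samples = get_few_shot_samples('datasets/7_Math_ShortAns.jsonl', num_samples=3)
shots = ''.join(compute_few_shot_prompt(s, fs_cfg['template']) for s in samples)
few_shot_block = fs_cfg['prefix'] + shots + '\n' + fs_cfg['suffix'] + '\n'
# few_shot_block can then fill the {few_shot_samples} slot in prompts/predict.txt
```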
/main/loaders/llm_pure_text.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import torch
4 | import random
5 | import pickle
6 | import numpy as np
7 | from tqdm import tqdm
8 | from torch.utils.data import Dataset, DataLoader
9 | from transformers import PretrainedConfig, PreTrainedTokenizer
10 |
11 | class LLMPureTextDataset(Dataset):
12 |
13 | config: PretrainedConfig
14 | tokenizer: PreTrainedTokenizer
15 |
16 | def __init__(self, tokenizer, config, file_name, max_length=512, do_shuffle=False):
17 | self.config = config
18 | self.tokenizer = tokenizer
19 | self.max_length = max_length
20 | self.do_shuffle = do_shuffle
21 | self.data = self.load_jsonl(file_name)
22 | self.random_list = [idx for idx in range(len(self.data))]
23 | if self.do_shuffle:
24 | random.shuffle(self.random_list)
25 |
26 |
27 | def load_jsonl(self, file_name):
28 | with open(file_name, 'r') as f:
29 | data = [json.loads(line) for line in f]
30 | return data
31 |
32 | def __getitem__(self, index):
33 | index = self.random_list[index]
34 | data = self.data[index]
35 | context = data['context']
36 | target = data['target']
37 |
38 | mx = self.max_length // 2
39 | context_ids = self.tokenizer.encode(context, max_length=mx, truncation=True)
40 | target_ids = self.tokenizer.encode(
41 | target,
42 | max_length=mx,
43 | truncation=True,
44 | add_special_tokens=False)
45 | ids = context_ids + target_ids + [self.config.eos_token_id]
46 | input_len = len(ids)
47 | context_len = len(context_ids)
48 | labels = [-100] * (context_len - 1) + ids[context_len - 1:] + [-100] * (self.max_length - input_len)
49 |
50 | f_ids = ids + [self.config.pad_token_id] * (self.max_length - input_len)
51 |
52 | input_ids = torch.tensor(f_ids)
53 | labels = torch.tensor(labels)
54 |
55 | return {
56 | 'input_ids': input_ids,
57 | 'labels': labels
58 | }
59 |
60 | def __len__(self):
61 | return len(self.data)
--------------------------------------------------------------------------------
/main/loaders/llm_chat.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import torch
4 | import random
5 | import pickle
6 | import numpy as np
7 | from tqdm import tqdm
8 | from torch.utils.data import Dataset, DataLoader
9 | from transformers import PretrainedConfig, PreTrainedTokenizer
10 |
11 | class LLMChatDataset(Dataset):
12 |
13 | config: PretrainedConfig
14 | tokenizer: PreTrainedTokenizer
15 |
16 | def __init__(self, tokenizer, config, file_name, max_length=512, do_shuffle=False):
17 | self.config = config
18 | self.tokenizer = tokenizer
19 | self.max_length = max_length
20 | self.do_shuffle = do_shuffle
21 | self.data = self.load_jsonl(file_name)
22 | self.random_list = [idx for idx in range(len(self.data))]
23 | if self.do_shuffle:
24 | random.shuffle(self.random_list)
25 |
26 |
27 | def load_jsonl(self, file_name):
28 | with open(file_name, 'r') as f:
29 | lines = f.readlines()
30 | data = [json.loads(line) for line in lines]
31 | return data
32 |
33 | def process_item(self, item):
34 | conv = item['conversations'] if 'conversations' in item else item
35 |
36 | input_ids, labels = [], []
37 |
38 | for t in conv:
39 | role = t['role']
40 | ids = self.tokenizer.apply_chat_template([t])
41 | ls = ids if role not in ['user', 'system'] else [-100 for _ in ids]
42 | input_ids.extend(ids)
43 | labels.extend(ls)
44 |
45 | max_length = self.max_length
46 | input_ids = input_ids[:max_length]
47 | labels = labels[:max_length]
48 | return {'input_ids': input_ids, 'labels': labels}
49 |
50 | def __getitem__(self, index):
51 | index = self.random_list[index]
52 | data = self.data[index]
53 | input_ids, labels = self.process_item(data).values()
54 |
55 | input_ids = torch.tensor(input_ids)
56 | labels = torch.tensor(labels)
57 |
58 | return {
59 | 'input_ids': input_ids,
60 | 'labels': labels
61 | }
62 |
63 | def __len__(self):
64 | return len(self.data)
--------------------------------------------------------------------------------
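`LLMChatDataset.process_item` accepts either a bare list of chat turns or an object with a `conversations` key, tokenizes each turn with `apply_chat_template`, and masks the loss (-100) on `system`/`user` turns. A made-up JSONL record in that shape:

```python
# A made-up training record in the shape LLMChatDataset expects; one such
# object per line of the JSONL file passed as `file_name`.
record = {
    "conversations": [
        {"role": "system", "content": "你是一名数学阅卷专家"},   # loss masked
        {"role": "user", "content": "请评估以下作答……"},         # loss masked
        {"role": "assistant",                                    # supervised
         "content": "{\"total\": 4, \"pred_score\": 3, \"steps\": []}"},
    ]
}
```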
/main/analysis.py:
--------------------------------------------------------------------------------
1 | import os
2 | import datetime
3 | import numpy as np
4 |
5 |
6 | class Analysis():
7 |
8 | def __init__(self):
9 | self.train_record = {}
10 | self.eval_record = {}
11 | self.model_record = {}
12 |
13 | '''
14 | append data record of train
15 | train_record_item: dict
16 | '''
17 |
18 | def append_train_record(self, train_record_item):
19 | for key in train_record_item:
20 | if key not in self.train_record:
21 | self.train_record[key] = []
22 | self.train_record[key].append(train_record_item[key])
23 |
24 | '''
25 | append data record of eval
26 | eval_record_item: dict
27 | '''
28 |
29 | def append_eval_record(self, eval_record_item):
30 | for key in eval_record_item:
31 | if key not in self.eval_record:
32 | self.eval_record[key] = []
33 | self.eval_record[key].append(eval_record_item[key])
34 |
35 | '''
36 | append data record of model
37 | uid: model uid
38 | '''
39 |
40 | def append_model_record(self, uid):
41 | key = "model_uid"
42 | if key not in self.model_record:
43 | self.model_record[key] = []
44 | self.model_record[key].append(uid)
45 |
46 | def save_all_records(self, uid):
47 | self.save_record('train_record', uid)
48 | self.save_record('eval_record', uid)
49 | self.save_record('model_record', uid)
50 |
51 | def save_record(self, record_name, uid):
52 | record_dict = getattr(self, record_name)
53 | path = f'./data_record/{uid}'
54 | if not os.path.exists(path):
55 | os.makedirs(path)
56 | head = []
57 | for key in record_dict:
58 | head.append(key)
59 | if len(head) == 0:
60 | return uid
61 | result = ''
62 | for idx in range(len(record_dict[head[0]])):
63 | for key in head:
64 | result += str(record_dict[key][idx]) + '\t'
65 | result += '\n'
66 |
67 | result = "\t".join(head) + '\n' + result
68 |
69 | with open(f'{path}/{record_name}.csv', encoding='utf-8', mode='w+') as f:
70 | f.write(result)
71 |
72 | return uid
73 |
--------------------------------------------------------------------------------
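A hypothetical training-loop use of `Analysis` (metric names and the run id are made up): records accumulate per call and are flushed as tab-separated `.csv` files under `./data_record/<uid>/`.

```python
# Hypothetical usage of Analysis; metric names and the uid are made up.
from main.analysis import Analysis

analysis = Analysis()
for epoch in range(2):
    analysis.append_train_record({'epoch': epoch, 'loss': 0.5 / (epoch + 1)})
    analysis.append_eval_record({'epoch': epoch, 'qwk': 0.6 + 0.1 * epoch})
analysis.append_model_record('demo_run_epoch_1.pth')
analysis.save_all_records('demo_run')   # writes ./data_record/demo_run/*.csv
```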
/preprocess.py:
--------------------------------------------------------------------------------
1 | # %%
2 | import json
3 | from tqdm import tqdm
4 |
5 | DATANAME = '7_Math_ShortAns'
6 | IGNORE_ZERO_WITHOUT_REASON = True
7 | NO_IGNORE_REASON = '计算错误'
8 | filename = f'/home/lpc/repos/SAS_Benchmark/backend_data/scores/{DATANAME}.jsonl'
9 | with open(filename) as f:
10 | ori_data = f.readlines()
11 | ori_data = [json.loads(line) for line in ori_data]
12 |
13 | result = []
14 | for item in tqdm(ori_data):
15 | res_segs = item['bad_student_answer_segs']
16 | last_idx = 0
17 | format_response = {
18 | 'id': item['id'],
19 | 'question': item['question'],
20 | 'reference': item['answer'],
21 | 'analysis': item['analysis'],
22 | 'total': item['score'],
23 | 'steps': []
24 | }
25 | if 'scoreItem' not in item:
26 | result.append(format_response)
27 | continue
28 | scoreItem = item['scoreItem']
29 | format_response['manual_label'] = scoreItem['label']
30 | seg_labels = scoreItem['seg_labels']
31 | seg_labels = json.loads(seg_labels)
32 | remain_score = int(float(scoreItem['label']))
33 |
34 | for seg_item in seg_labels:
35 | seg_idx, seg_label, seg_errors = seg_item['idx'], seg_item['label'], seg_item['errors']
36 | if seg_label == '':
37 | continue
38 | if seg_label != '' and int(float(seg_label)) == 0 and len(seg_errors) == 0:
39 | if IGNORE_ZERO_WITHOUT_REASON:
40 | continue
41 | else:
42 | seg_errors = [NO_IGNORE_REASON]
43 | if seg_label != '' and int(float(seg_label)) > 0 and len(seg_errors) == 0:
44 | # for an English dataset, you may replace this with 'correct'
45 | seg_errors = ['步骤正确']
46 | format_response['steps'].append({
47 | 'response': '\n'.join(res_segs[last_idx: seg_idx + 1]),
48 | 'label': seg_label,
49 | 'errors': seg_errors
50 | })
51 | last_idx = seg_idx + 1
52 | if seg_label != '' and int(float(seg_label)) > 0:
53 | remain_score -= int(float(seg_label))
54 | if last_idx < len(res_segs):
55 | format_response['steps'].append({
56 | 'response': '\n'.join(res_segs[last_idx:]),
57 | 'label': 0 if remain_score < 0 else remain_score,
58 | 'errors': []
59 | })
60 | result.append(format_response)
61 |
62 | with open(f'./datasets/{DATANAME}.jsonl', 'w') as f:
63 | for item in result:
64 | f.write(json.dumps(item, ensure_ascii=False) + '\n')
65 |
66 | # %%
67 |
--------------------------------------------------------------------------------
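The script emits one JSON object per line in `datasets/<name>.jsonl`. A made-up record showing the fields the downstream pipelines read (`manual_label` is omitted when the export carried no `scoreItem`):

```python
# A made-up record in the datasets/<name>.jsonl schema produced above.
record = {
    "id": "q7_0001",
    "question": "……",
    "reference": "……",
    "analysis": "……",
    "total": 4,
    "manual_label": "3",
    "steps": [
        {"response": "设未知数并列方程……", "label": "1", "errors": ["步骤正确"]},
        {"response": "解得 x = 2……", "label": "2", "errors": ["计算错误"]},
    ],
}
```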
/discuss/pred_gold.py:
--------------------------------------------------------------------------------
1 | # %%
2 | import os
3 | import json
4 | import numpy as np
5 | import matplotlib.pyplot as plt
6 | from argparse import ArgumentParser
7 |
8 | plt.style.use('seaborn-v0_8-bright')
9 |
10 | import sys
11 | sys.path.append("../")
12 | cmd_args = True
13 |
14 | parser = ArgumentParser()
15 | parser.add_argument('--file_dir', default='../datasets', help='the directory of the datasets.')
16 | parser.add_argument('--file_name', default='all', help='file name of the dataset without extension (e.g. 0_Physics_ShortAns), you can define `all` for all dataset files.')
17 | parser.add_argument('--save_type_name', default='mimo', help='the prefix name of the save dir (usually the LLM name)')
18 | parser.add_argument('--save_fig', default=0, help='whether to save the figure (set 1 to save)')
19 |
20 | if not cmd_args:
21 | args = parser.parse_args([]) # You can directly set above parameters in the default.
22 | else:
23 | args = parser.parse_args()
24 |
25 | SOURCE_DIR = os.path.join(args.file_dir + '_' + args.save_type_name + '_Scored')
26 |
27 | labels = []
28 | preds = []
29 | step_labels = []
30 | step_preds = []
31 | if args.file_name != 'all':
32 | files = [args.file_name + '_prediction.jsonl']
33 | else:
34 | files = os.listdir(SOURCE_DIR)
35 | for file_name in files:
36 | if file_name.find('prediction') < 0:
37 | continue
38 | SOURCE_FILE = os.path.join(SOURCE_DIR, file_name)
39 | with open(SOURCE_FILE, encoding='utf-8') as f:
40 | ori_data = f.readlines()
41 | ori_data = [json.loads(line) for line in ori_data]
42 | for item in ori_data:
43 | total = float(item['total'])
44 | label = float(item['manual_label']) / total
45 | pred = float(item['pred_label']) / total
46 | labels.append(label)
47 | preds.append(pred)
48 | for step in item['steps']:
49 | step_labels.append(float(step['label']) / total)
50 | for step in item['pred_steps']:
51 | step_preds.append(float(step['step_score']) / total)
52 |
53 | # %%
54 | plt.hist([item for item in labels], range=(0, 1), bins=20, alpha=0.6, color='#FFC000', label='Gold')
55 | plt.hist([item for item in preds], range=(0, 1), bins=20, alpha=0.6, color='#d3d3f9', label='Pred')
56 | # plt.hist([item for item in step_labels], range=(0, 1), bins=20, alpha=0.3, color='#A6C9E8', label='Pred Step-wise Score')
57 | # plt.hist([item for item in step_preds], range=(0, 1), bins=20, alpha=0.3, color='#44B6A3', label='Gold Step-wise Score')
58 | plt.legend()
59 | if str(args.save_fig) == '1':
60 | plt.savefig(f'{args.save_type_name}_pred_gold.svg', format='svg', dpi=300)
61 |
62 | # %%
63 |
--------------------------------------------------------------------------------
/utils/errors_consistency_score.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | from scipy.stats import spearmanr
4 |
5 | def compute_ecs(pred_scores, ori_scores, total_scores, pred_errors, gold_errors, max_error=5):
6 | """
7 | Computes the consistency score between predicted and ground truth evaluations.
8 |
9 | Args:
10 | pred_scores: List of predicted scores for each sample
11 | ori_scores: List of original ground truth scores for each sample
12 | total_scores: List of per-sample maximum possible scores used for normalization
13 | pred_errors: Predicted error frequencies as 2D list [[freq1, freq2,...], ...]
14 | where each sublist represents error type frequencies for a sample
15 | gold_errors: Ground truth error frequencies with same structure as pred_errors
16 | max_error: Number of predefined error types (width of each error-frequency row)
17 | """
18 |
19 | all_norm_scores = []
20 | Ln = 1e-10
21 | for score, max_s in zip(ori_scores, total_scores):
22 | norm_score = score / (max_s + Ln)
23 | all_norm_scores.append(norm_score)
24 | range1, range2 = 0, 0
25 | sorted_scores = sorted(all_norm_scores)
26 | range1 = sorted_scores[int(len(sorted_scores) * 0.33)]
27 | range2 = sorted_scores[int(len(sorted_scores) * 0.67)]
28 |
29 | ori_range_error_matrix = torch.zeros((3, max_error))
30 | pred_range_error_matrix = torch.zeros((3, max_error))
31 | for i in range(len(pred_errors)):
32 | pred, gold, max_s, p_errors, g_errors = pred_scores[i], ori_scores[i], total_scores[i], pred_errors[i], gold_errors[i]
33 | pred = pred / (max_s + Ln)
34 | gold = gold / (max_s + Ln)
35 | if pred <= range1:
36 | for j, freq in enumerate(p_errors):
37 | pred_range_error_matrix[0][j] += freq
38 | elif pred < range2:
39 | for j, freq in enumerate(p_errors):
40 | pred_range_error_matrix[1][j] += freq
41 | else:
42 | for j, freq in enumerate(p_errors):
43 | pred_range_error_matrix[2][j] += freq
44 | if gold <= range1:
45 | for j, freq in enumerate(g_errors):
46 | ori_range_error_matrix[0][j] += freq
47 | elif gold < range2:
48 | for j, freq in enumerate(g_errors):
49 | ori_range_error_matrix[1][j] += freq
50 | else:
51 | for j, freq in enumerate(g_errors):
52 | ori_range_error_matrix[2][j] += freq
53 |
54 | spearmans = []
55 | for i in range(pred_range_error_matrix.shape[0]):
56 | spearman_score = spearmanr(pred_range_error_matrix[i].tolist(), ori_range_error_matrix[i].tolist())
57 | spearmans.append(0 if str(spearman_score.correlation) == 'nan' else spearman_score.correlation)
58 |
59 | return np.mean(spearmans), spearmans
60 |
--------------------------------------------------------------------------------
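A toy call with made-up numbers showing the expected shapes: the three score lists are per-sample scalars, and each row of `pred_errors`/`gold_errors` holds `max_error` per-type frequency counts. The returned value is the mean Spearman correlation of the error distributions over the low/mid/high score bands.

```python
# Toy example of compute_ecs with made-up numbers: 4 samples, 3 error types.
from utils.errors_consistency_score import compute_ecs

pred_scores  = [1, 2, 4, 0]
ori_scores   = [1, 3, 4, 1]
total_scores = [4, 4, 4, 4]
pred_errors  = [[1, 0, 0], [0, 1, 0], [0, 0, 0], [1, 1, 0]]
gold_errors  = [[1, 0, 0], [0, 1, 1], [0, 0, 0], [1, 0, 0]]

ecs, per_band = compute_ecs(pred_scores, ori_scores, total_scores,
                            pred_errors, gold_errors, max_error=3)
print(ecs, per_band)
```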
/main/loaders/chatglm_chat.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import torch
4 | import random
5 | import pickle
6 | import numpy as np
7 | from tqdm import tqdm
8 | from torch.utils.data import Dataset, DataLoader
9 | from transformers import PretrainedConfig, PreTrainedTokenizer
10 |
11 | class ChatGLM_ChatDataset(Dataset):
12 |
13 | config: PretrainedConfig
14 | tokenizer: PreTrainedTokenizer
15 |
16 | def __init__(self, tokenizer, config, file_name, max_length=512, do_shuffle=False):
17 | self.config = config
18 | self.tokenizer = tokenizer
19 | self.max_length = max_length
20 | self.do_shuffle = do_shuffle
21 | self.data = self.load_jsonl(file_name)
22 | self.random_list = [idx for idx in range(len(self.data))]
23 | if self.do_shuffle:
24 | random.shuffle(self.random_list)
25 |
26 |
27 | def load_jsonl(self, file_name):
28 | with open(file_name, 'r') as f:
29 | lines = f.readlines()
30 | data = [json.loads(line) for line in lines]
31 | return data
32 |
33 | def process_item(self, item):
34 | conv = item['conversations'] if 'conversations' in item else item
35 |
36 | input_ids, loss_masks = [
37 | self.tokenizer.get_command('[gMASK]'),
38 | self.tokenizer.get_command('sop'),
39 | ], [False, False]
40 |
41 | for message in conv:
42 | if message['role'] in ('system', 'user'):
43 | loss_mask_val = False
44 | else:
45 | loss_mask_val = True
46 |
47 | if message['role'] == 'tool':
48 | raise NotImplementedError()
49 | else:
50 | new_input_ids = self.tokenizer.build_single_message(
51 | message['role'], '', message['content']
52 | )
53 | new_loss_masks = [loss_mask_val] * len(new_input_ids)
54 |
55 | input_ids += new_input_ids
56 | loss_masks += new_loss_masks
57 |
58 | input_ids.append(self.tokenizer.eos_token_id)
59 | loss_masks = [False, *loss_masks]
60 | labels = []
61 | for input_id, mask in zip(input_ids, loss_masks):
62 | if mask:
63 | labels.append(input_id)
64 | else:
65 | labels.append(-100)
66 | max_length = self.max_length
67 | input_ids = input_ids[:max_length]
68 | labels = labels[:max_length]
69 | return {'input_ids': input_ids, 'labels': labels}
70 |
71 | def __getitem__(self, index):
72 | index = self.random_list[index]
73 | data = self.data[index]
74 | input_ids, labels = self.process_item(data).values()
75 |
76 | input_ids = torch.tensor(input_ids)
77 | labels = torch.tensor(labels)
78 |
79 | return {
80 | 'input_ids': input_ids,
81 | 'labels': labels
82 | }
83 |
84 | def __len__(self):
85 | return len(self.data)
--------------------------------------------------------------------------------
/main/predictor/openai.py:
--------------------------------------------------------------------------------
1 | import json
2 | import torch
3 | from openai import OpenAI
4 | from openai._types import NotGiven
5 | from transformers import AutoTokenizer, AutoModel
6 | from typing import Tuple, List
7 |
8 | NOT_GIVEN = NotGiven()
9 |
10 |
11 | class Predictor():
12 |
13 | def __init__(self,
14 | organization=None,
15 | api_key=None,
16 | base_url=None,
17 | **args
18 | ):
19 | '''
20 | Predictor: OpenAI API预测器 (OpenAI API predictor)
21 |
22 | ### Args:
23 |
24 | `organization`: OpenAI组织名 (OpenAI organization name)
25 |
26 | `api_key`: OpenAI API密钥 (OpenAI API key)
27 | '''
28 |
29 | self.organization = organization
30 | self.api_key = api_key
31 | self.base_url = base_url
32 | self.client = OpenAI(organization=self.organization, api_key=self.api_key, base_url=self.base_url)
33 |
34 |
35 | def predict(self, query: str = '', history: List = None, model: str = 'gpt-4o-mini', max_length=NOT_GIVEN, max_new_tokens=NOT_GIVEN, top_p: float = NOT_GIVEN, temperature=NOT_GIVEN):
36 | if history is None:
37 | history = []
38 | raw = history + [{"role": "user", "content": query}]
39 |
40 | completion = self.client.chat.completions.create(
41 | model=model,
42 | messages=raw,
43 | max_tokens=max_length,
44 | max_completion_tokens=max_new_tokens,
45 | temperature=temperature,
46 | top_p=top_p,
47 | stream=False
48 | )
49 |
50 | message = completion.choices[0].message.content
51 | raw.append({"role": "assistant", "content": message})
52 | return message, raw
53 |
54 | def stream_chat(self, query: str = '', history: List = None, model: str = 'gpt-4o-mini', max_length=NOT_GIVEN, max_new_tokens=NOT_GIVEN, top_p: float = NOT_GIVEN, temperature=NOT_GIVEN):
55 | if history is None:
56 | history = []
57 | raw = history + [{"role": "user", "content": query}]
58 |
59 | completion = self.client.chat.completions.create(
60 | model=model,
61 | messages=raw,
62 | max_tokens=max_length,
63 | max_completion_tokens=max_new_tokens,
64 | temperature=temperature,
65 | top_p=top_p,
66 | stream=True
67 | )
68 |
69 | result = ''
70 | for chunk in completion:
71 | delta = chunk.choices[0].delta
72 | if delta.content is None:
73 | continue
74 | result += delta.content
75 | yield result, raw + [{"role": "assistant", "content": result}]
76 |
77 | def __call__(self, query: str = '', history: List = None, model: str = 'gpt-4o-mini', max_length=NOT_GIVEN, max_new_tokens=NOT_GIVEN, top_p: float = NOT_GIVEN, temperature=NOT_GIVEN):
78 | return self.predict(query, history, model, max_length, max_new_tokens, top_p, temperature)
79 |
--------------------------------------------------------------------------------
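A hypothetical call against an OpenAI-compatible endpoint (the key, base URL and model name are placeholders). `predict` returns both the reply and the updated message list, so follow-up turns can pass `history` back in.

```python
# Hypothetical usage; api_key, base_url and the model name are placeholders.
from main.predictor.openai import Predictor

pred = Predictor(api_key='sk-...', base_url='https://api.example.com/v1')
answer, history = pred('1 + 1 等于几?', model='gpt-4o-mini')
followup, history = pred.predict('请给出理由。', history=history, model='gpt-4o-mini')
print(answer)
print(followup)
```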
/main/predictor/qwen_lora.py:
--------------------------------------------------------------------------------
1 | from transformers import AutoModelForCausalLM, AutoConfig
2 | from transformers import AutoTokenizer
3 | from peft import LoraConfig, TaskType, PeftModel, PeftModelForCausalLM
4 | from typing import Optional, List
5 | from copy import deepcopy
6 | from transformers.generation import GenerationMixin
7 | import torch
8 |
9 | class Predictor(GenerationMixin):
10 | true_model: PeftModelForCausalLM
11 |
12 | def __init__(self,
13 | num_gpus: list = [0],
14 | model_from_pretrained: str = None,
15 | resume_path: str = None
16 | ):
17 | self.config = AutoConfig.from_pretrained(model_from_pretrained, trust_remote_code=True)
18 | peft_config = LoraConfig(
19 | task_type=TaskType.CAUSAL_LM,
20 | inference_mode=False,
21 | r=16,
22 | target_modules=["c_attn", "c_proj", "w1", "w2"],
23 | lora_alpha=32,
24 | lora_dropout=0.1,
25 | )
26 | self.tokenizer = AutoTokenizer.from_pretrained(
27 | model_from_pretrained, trust_remote_code=True)
28 | self.model = AutoModelForCausalLM.from_pretrained(
29 | model_from_pretrained, device_map="auto", trust_remote_code=True).eval()
30 | self.llm = self.model
31 | self.generation_config = self.llm.generation_config
32 | self.model = PeftModel.from_pretrained(self.model, resume_path, config=peft_config)
33 | self.true_model, self.device = self.model, next(self.model.parameters()).device  # required by predict()
34 |
35 | def predict(self, text='', max_length=150, temperature=1.0):
36 | with torch.no_grad():
37 | inputs = self.tokenizer.encode(text)
38 | input_ids = torch.LongTensor([inputs]).to(self.device)
39 | output = self.true_model.generate(**{
40 | 'input_ids': input_ids,
41 | 'max_length': max_length,
42 | 'do_sample': False,
43 | 'temperature': temperature
44 | })
45 | out_text = self.tokenizer.decode(
46 | output[0], skip_special_tokens=True)
47 | return out_text
48 |
49 | @torch.inference_mode()
50 | def chat(self,
51 | query: str,
52 | history = None,
53 | system: str = "You are a helpful assistant.",
54 | stop_words_ids: Optional[List[List[int]]] = None,
55 | generation_config=None,
56 | **kwargs,):
57 | tokenizer = self.tokenizer
58 | generation_config = generation_config if generation_config is not None else self.generation_config
59 |
60 | return self.model.chat(
61 | tokenizer=tokenizer,
62 | query=query,
63 | history=history,
64 | system=system,
65 | generation_config=generation_config,
66 | stop_words_ids=stop_words_ids,
67 | **kwargs,
68 | )
69 |
70 | def __call__(self, text='', max_length=150, temperature=0):
71 | return self.predict(text=text, max_length=max_length, temperature=temperature)
72 |
--------------------------------------------------------------------------------
/main/loaders/qwen_chat.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import torch
4 | import random
5 | import pickle
6 | import numpy as np
7 | from tqdm import tqdm
8 | from torch.utils.data import Dataset, DataLoader
9 | from transformers import PretrainedConfig, PreTrainedTokenizer
10 |
11 | class QwenChatDataset(Dataset):
12 |
13 | config: PretrainedConfig
14 | tokenizer: PreTrainedTokenizer
15 |
16 | def __init__(self, tokenizer, config, file_name, max_length=512, do_shuffle=False):
17 | self.config = config
18 | self.tokenizer = tokenizer
19 | self.max_length = max_length
20 | self.do_shuffle = do_shuffle
21 | self.data = self.load_jsonl(file_name)
22 | self.random_list = [idx for idx in range(len(self.data))]
23 | if self.do_shuffle:
24 | random.shuffle(self.random_list)
25 |
26 |
27 | def load_jsonl(self, file_name):
28 | with open(file_name, 'r') as f:
29 | data = [json.loads(line) for line in f]
30 | return data
31 |
32 | def process_item(self, item):
33 | conv = item['conversations'] if 'conversations' in item else item
34 |
35 | input_ids, loss_masks = [], []
36 |
37 | # im_start, im_end = "<|im_start|>", "<|im_end|>"
38 | im_start_tokens = [self.tokenizer.im_start_id]
39 | im_end_tokens = [self.tokenizer.im_end_id]
40 | nl_tokens = self.tokenizer.encode("\n")
41 |
42 | def _tokenize_str(role, content):
43 | return f"{role}\n{content}", self.tokenizer.encode(
44 | role, allowed_special=set()
45 | ) + nl_tokens + self.tokenizer.encode(content, allowed_special=set())
46 |
47 | for message in conv:
48 | if message['role'] in ('system', 'user'):
49 | loss_mask_val = False
50 | else:
51 | loss_mask_val = True
52 |
53 | if message['role'] == 'tool':
54 | raise NotImplementedError()
55 | else:
56 | _, msg_tokens = _tokenize_str(message['role'], message['content'])
57 | new_input_ids = im_start_tokens + msg_tokens + im_end_tokens
58 | new_input_ids = new_input_ids + nl_tokens
59 | new_loss_masks = [loss_mask_val] * len(new_input_ids)
60 |
61 | input_ids += new_input_ids
62 | loss_masks += new_loss_masks
63 |
64 | input_ids = input_ids + im_end_tokens
65 | loss_masks = [False, *loss_masks]
66 | labels = []
67 | for input_id, mask in zip(input_ids, loss_masks):
68 | if mask:
69 | labels.append(input_id)
70 | else:
71 | labels.append(-100)
72 | max_length = self.max_length
73 | input_ids = input_ids[:max_length]
74 | labels = labels[:max_length]
75 | return {'input_ids': input_ids, 'labels': labels}
76 |
77 | def __getitem__(self, index):
78 | index = self.random_list[index]
79 | data = self.data[index]
80 | input_ids, labels = self.process_item(data).values()
81 |
82 | input_ids = torch.tensor(input_ids)
83 | labels = torch.tensor(labels)
84 |
85 | return {
86 | 'input_ids': input_ids,
87 | 'labels': labels
88 | }
89 |
90 | def __len__(self):
91 | return len(self.data)
--------------------------------------------------------------------------------
/utils/collaborative_consistency_score.py:
--------------------------------------------------------------------------------
1 | import itertools
2 | import numpy as np
3 | from itertools import product
4 | from sklearn.metrics import confusion_matrix, cohen_kappa_score
5 |
6 | def composite_distance(true_tuple, pred_tuple, weights):
7 | """
8 | Computes the weighted difference (e.g., Manhattan distance) between composite score tuples.
9 |
10 | Args:
11 | true_tuple: Ground truth scores (total_score, step1_score, step2_score)
12 | pred_tuple: Predicted scores (total_score, step1_score, step2_score)
13 | weights: Weights for each component [w_total, w_step1, w_step2], e.g., [0.5, 0.25, 0.25]
14 |
15 | Returns:
16 | Weighted composite difference between the tuples
17 | """
18 | return np.sum(weights * np.abs(np.array(true_tuple) - np.array(pred_tuple)))
19 |
20 | def compute_QWK(X, Y, grading_scale=1):
21 | X = [int(x * grading_scale) for x in X]
22 | Y = [int(y * grading_scale) for y in Y]
23 | X = np.array(X)
24 | Y = np.array(Y)
25 |
26 | qwk = cohen_kappa_score(Y, X, weights='quadratic')
27 | return qwk
28 |
29 | def adjusted_qwk(true_tuples, pred_tuples, weights, max_scores):
30 | """
31 | Computes weighted overall and step-wise consistent scores between ground truth and predicted evaluation tuples.
32 |
33 | Args:
34 | true_tuples: List of ground truth score tuples [(total_score, step1, ..., step10), ...]
35 | pred_tuples: List of predicted score tuples [(total_score, step1, ..., step10), ...]
36 | weights: Per-sample weight vectors [[w_total, w_step1, ..., w_step10], ...]
37 | (each vector should sum to 1)
38 | max_scores: Maximum possible scores for normalization [max_total, max_step1, ..., max_step10]
39 | """
40 | # Validate each score entry is of tuple type
41 | true_tuples = [tuple(t) for t in true_tuples]
42 | pred_tuples = [tuple(p) for p in pred_tuples]
43 |
44 | # Construct unified scoring composition levels (with deduplication)
45 | unique_tuples = sorted(list(set(true_tuples + pred_tuples)))
46 | n_levels = len(unique_tuples)
47 | tuple_to_idx = {t: i for i, t in enumerate(unique_tuples)}
48 |
49 | # Build observation matrix O (actual score frequencies)
50 | O = np.zeros((n_levels, n_levels))
51 | for t, p in zip(true_tuples, pred_tuples):
52 | i = tuple_to_idx[t]
53 | j = tuple_to_idx[p]
54 | O[i, j] += 1
55 |
56 | # Build expectation matrix E (theoretical score probabilities)
57 | row_sums = O.sum(axis=1)
58 | col_sums = O.sum(axis=0)
59 | E = np.outer(row_sums, col_sums) / np.sum(O)
60 |
61 | # Construct weight matrix W (normalized squared differences)
62 | max_scores = np.array(max_scores)
63 | weights = np.array(weights)
64 |
65 | W = np.zeros((n_levels, n_levels))
66 | Ln = 1e-10
67 | for i, ti in enumerate(unique_tuples):
68 | for j, tj in enumerate(unique_tuples):
69 | ti = np.array(ti)
70 | tj = np.array(tj)
71 | diff = (ti - tj) / (max_scores + Ln)
72 | if i > len(weights) - 1:
73 | W[i, j] = np.sum(0)
74 | else:
75 | W[i, j] = np.sum(weights[i] * diff ** 2)
76 |
77 | # Compute adjusted Composite Consistency Score (CCS)
78 | ccs = 1 - np.sum(O * W) / np.sum(E * W)
79 |
80 | X = [item[0] for item in pred_tuples]
81 | Y = [item[0] for item in true_tuples]
82 | qwk = compute_QWK(X, Y)
83 |
84 | return ccs, qwk
85 |
--------------------------------------------------------------------------------
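A toy call mirroring how `sas_pipelines/3_compute_ccs.py` builds its inputs: every tuple is `(total score, step scores...)` padded to a common length, each per-sample weight vector puts 0.5 on the total and splits 0.5 across the steps, and `max_scores` holds the per-position maxima used for normalization. All numbers here are made up.

```python
# Toy example of adjusted_qwk with made-up, already-padded score tuples.
from utils.collaborative_consistency_score import adjusted_qwk

true_tuples = [(3, 1, 2), (4, 2, 2), (1, 1, 0)]
pred_tuples = [(3, 1, 2), (2, 1, 1), (1, 0, 1)]
weights     = [[0.5, 0.25, 0.25]] * 3    # 0.5 on the total, 0.5 split over steps
max_scores  = [4, 2, 2]                  # per-position maxima for normalization

ccs, qwk = adjusted_qwk(true_tuples, pred_tuples, weights, max_scores)
print(round(ccs, 3), round(qwk, 3))
```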
/sas_pipelines/3_compute_ccs.py:
--------------------------------------------------------------------------------
1 | # %%
2 | import os
3 | import json
4 | import random
5 | import json_repair
6 | import numpy as np
7 | from tqdm import tqdm
8 | from argparse import ArgumentParser
9 |
10 | import sys
11 | sys.path.append("../")
12 | cmd_args = True
13 |
14 | parser = ArgumentParser()
15 | parser.add_argument('--file_dir', default='../datasets', help='the directory of the datasets.')
16 | parser.add_argument('--file_name', default='all', help='file name of the dataset without extension (e.g. 0_Physics_ShortAns), you can define `all` for all dataset files.')
17 | parser.add_argument('--save_type_name', default='Deepseek', help='the prefix name of the save dir (usually the LLM name)')
18 |
19 | if not cmd_args:
20 | args = parser.parse_args([]) # You can directly set above parameters in the default.
21 | else:
22 | args = parser.parse_args()
23 |
24 | SOURCE_DIR = os.path.join(args.file_dir + '_' + args.save_type_name + '_Scored')
25 |
26 | from utils.collaborative_consistency_score import adjusted_qwk
27 |
28 | results = []
29 | if args.file_name != 'all':
30 | files = [args.file_name + '_prediction.jsonl']
31 | else:
32 | files = os.listdir(SOURCE_DIR)
33 | for file_name in files:
34 | if file_name.find('prediction') < 0:
35 | continue
36 | SOURCE_FILE = os.path.join(SOURCE_DIR, file_name)
37 | with open(SOURCE_FILE, encoding='utf-8') as f:
38 | ori_data = f.readlines()
39 | ori_data = [json.loads(line) for line in ori_data]
40 |
41 | pred_result = []
42 | ori_result = []
43 | weights_result = []
44 | max_score = 0
45 | max_length = 0
46 | for item in tqdm(ori_data):
47 | if max_score < item['total']:
48 | max_score = item['total']
49 | ori_steps = item['steps']
50 | if 'pred_steps' not in item:
51 | pred_steps = [{"step_score": 0, "errors": []} for _ in ori_steps]
52 | else:
53 | pred_steps = item['pred_steps']
54 | ori_score = item['manual_label']
55 | if 'pred_label' not in item or item['pred_label'] == '':
56 | pred_score = 0
57 | else:
58 | pred_score = item['pred_label']
59 | ori_labels = [int(ori_score)]
60 | pred_labels = [int(pred_score)]
61 | weights = [0.5] + [0.5 / len(ori_steps) for _ in range(len(ori_steps))]
62 | for i in range(len(ori_steps)):
63 | try:
64 | ori_labels.append(int(ori_steps[i]['label']))
65 | except:
66 | ori_labels.append(0)
67 | if len(pred_steps) > i and type(pred_steps[i]) == dict and 'step_score' in pred_steps[i]:
68 | pred_labels.append(int(pred_steps[i]['step_score']))
69 | else:
70 | pred_labels.append(0)
71 | if max_length < len(ori_labels):
72 | max_length = len(ori_labels)
73 | pred_result.append(pred_labels)
74 | ori_result.append(ori_labels)
75 | weights_result.append(weights)
76 |
77 | # Padding
78 | for i in range(len(pred_result)):
79 | pred_result[i] += [0] * (max_length - len(pred_result[i]))
80 | ori_result[i] += [0] * (max_length - len(ori_result[i]))
81 | weights_result[i] += [0] * (max_length - len(weights_result[i]))
82 | max_scores = []
83 | for i in range(len(ori_result[0])):
84 | val_i = []
85 | for item in ori_result:
86 | val_i.append(item[i])
87 | for item in pred_result:
88 | val_i.append(item[i])
89 | max_scores.append(max(val_i))
90 | max_scores[0] = max_score
91 | results.append((file_name, adjusted_qwk(ori_result, pred_result, weights_result, max_scores)))
92 |
93 | # %%
94 | def sort_func(x):
95 | try:
96 | x = x[0]
97 | x = x.split('_')[0]
98 | x = int(x)
99 | return x
100 | except:
101 | return 0
102 |
103 | results = sorted(results, key=sort_func)
104 | for item in results:
105 | print(item)
106 |
107 | with open(f'{args.save_type_name}_ccs.csv', 'w+') as f:
108 | for item in results:
109 | cols = [item[0], str(round(item[1][0] * 100, 2)), str(round(item[1][1] * 100 ,2))]
110 | f.write(','.join(cols) + '\n')
111 |
112 | # %%
113 |
--------------------------------------------------------------------------------
/sas_pipelines/2_process_prediction.py:
--------------------------------------------------------------------------------
1 | # %%
2 | import os
3 | import re
4 | import json
5 | import random
6 | import json_repair
7 | import numpy as np
8 | from tqdm import tqdm
9 | from argparse import ArgumentParser
10 |
11 | import sys
12 | sys.path.append("../")
13 | cmd_args = True
14 |
15 | parser = ArgumentParser()
16 | parser.add_argument('--file_dir', default='../datasets', help='the directory of the datasets.')
17 | parser.add_argument('--file_name', default='all', help='file name of the dataset without extension (e.g. 0_Physics_ShortAns), you can define `all` for all dataset files.')
18 | parser.add_argument('--save_type_name', default='Qwen3_32B', help='the prefix name of the save dir (usually the LLM name)')
19 |
20 | if not cmd_args:
21 | args = parser.parse_args([]) # You can directly set above parameters in the default.
22 | else:
23 | args = parser.parse_args()
24 |
25 | def unwrap_score_item(item, len_steps):
26 | item = json.loads(item)
27 | item = str(item)
28 | if item.find("{'total'") >= 0 and item.find('{"total"') < 0:
29 | item = item.replace('\'', '"')
30 | match_index = item.find('{"total"')
31 | if match_index >= 0:
32 | item = item[match_index:]
33 | item = item.replace('}```', '}')
34 | # Remove // Comments
35 | item = re.sub(r'//.*$', '', item, flags=re.MULTILINE)
36 | # Remove /* */ Comments
37 | item = re.sub(r'/\*[\s\S]*?\*/', '', item)
38 | item = json_repair.loads(item)
39 | if type(item) == list:
40 | item = item[0]
41 | else:
42 | item = {
43 | 'total': 0,
44 | 'pred_score': 0,
45 | 'steps': [{'step_score': 0, 'errors': []} for _ in range(len_steps)]
46 | }
47 |
48 | if 'steps' not in item or type(item['steps']) != list:
49 | item['steps'] = [{'step_score': 0, 'errors': []} for _ in range(len_steps)]
50 |
51 | if 'pred_score' not in item:
52 | item['pred_score'] = 0
53 |
54 | for step_idx, step in enumerate(item['steps']):
55 | if type(step) != dict:
56 | item['steps'][step_idx] = {}
57 | step = item['steps'][step_idx]
58 | if 'step_score' not in step:
59 | step['step_score'] = 0
60 | try:
61 | step['step_score'] = int(step['step_score'])
62 | except:
63 | step['step_score'] = 0
64 | if 'errors' not in step or type(step['errors']) != list:
65 | step['errors'] = []
66 | return item
67 |
68 | SOURCE_DIR = os.path.join(args.file_dir)
69 | if args.file_name != 'all':
70 | files = [args.file_name + '.jsonl']
71 | else:
72 | files = os.listdir(SOURCE_DIR)
73 | for file_name in files:
74 | if len(file_name.split('_')) < 3:
75 | continue
76 | SOURCE_FILE = os.path.join(SOURCE_DIR, file_name)
77 | with open(SOURCE_FILE, encoding='utf-8') as f:
78 | ori_data = f.readlines()
79 | ori_data = [json.loads(line) for line in ori_data]
80 |
81 | SCORED_FILE = os.path.join(args.file_dir + '_' + args.save_type_name + '_Scored', f'{file_name.split(".jsonl")[0]}_scored.jsonl')
82 | print(SCORED_FILE)
83 | if not os.path.exists(SCORED_FILE):
84 | continue
85 | with open(SCORED_FILE, encoding='utf-8') as f:
86 | scored_data = f.readlines()
87 |
88 | for item, score_item in tqdm(zip(ori_data, scored_data)):
89 | score_item = score_item.split('\t')[1]
90 | len_steps = len(item['steps'])
91 | score_item = unwrap_score_item(score_item, len_steps=len_steps)
92 | try:
93 | item['pred_label'] = int(score_item['pred_score'])
94 | except:
95 | item['pred_label'] = 0
96 | item['pred_steps'] = score_item['steps']
97 | item['pred_steps'] = item['pred_steps'][:len_steps]
98 | if len(item['pred_steps']) < len_steps:
99 | for _ in range(len_steps - len(item['pred_steps'])):
100 | item['pred_steps'].append({'step_score': 0, 'errors': []})
101 |
102 | SAVE_DIR = os.path.join(os.path.dirname(SCORED_FILE), f'{file_name.split(".jsonl")[0]}_prediction.jsonl')
103 | with open(SAVE_DIR, 'w', encoding='utf-8') as f:
104 | for item in ori_data:
105 | f.write(json.dumps(item, ensure_ascii=False) + '\n')
106 |
107 | # %%
108 |
--------------------------------------------------------------------------------
/discuss/error_causes_f1.py:
--------------------------------------------------------------------------------
1 | # %%
2 | import os
3 | import json
4 | import random
5 | import json_repair
6 | import numpy as np
7 | from tqdm import tqdm
8 | from argparse import ArgumentParser
9 |
10 | import sys
11 | sys.path.append("../")
12 | cmd_args = True
13 |
14 | parser = ArgumentParser()
15 | parser.add_argument('--file_dir', default='../datasets', help='the directory of the datasets.')
16 | parser.add_argument('--file_name', default='all', help='file name of the dataset without extension (e.g. 0_Physics_ShortAns), you can define `all` for all dataset files.')
17 | parser.add_argument('--save_type_name', default='Deepseek', help='the prefix name of the save dir (usually the LLM name)')
18 | parser.add_argument('--skip_correct', default=True, help='whether to skip the correct-step label when counting error causes')
19 |
20 | if not cmd_args:
21 | args = parser.parse_args([]) # You can directly set above parameters in the default.
22 | else:
23 | args = parser.parse_args()
24 |
25 | SKIP_CORRECT = args.skip_correct
26 | # for an English dataset, you may replace this with {'name': 'correct', 'description': 'the step is correct.'},
27 | # depending on how the `name` of the correct step was predefined.
28 | CORRECT_NAME = '步骤正确'
29 | CORRECT_DESCRIPTION = '该步骤正确'
30 | SOURCE_DIR = os.path.join(args.file_dir + '_' + args.save_type_name + '_Scored')
31 |
32 | from utils.errors_consistency_score import compute_ecs
33 |
34 | results = []
35 | if args.file_name != 'all':
36 | files = [args.file_name + '_prediction.jsonl']
37 | else:
38 | files = os.listdir(SOURCE_DIR)
39 |
40 | for file_name in files:
41 | if file_name.find('prediction') < 0:
42 | continue
43 | SOURCE_FILE = os.path.join(SOURCE_DIR, file_name)
44 | with open(SOURCE_FILE, encoding='utf-8') as f:
45 | ori_data = f.readlines()
46 | ori_data = [json.loads(line) for line in ori_data]
47 |
48 | ERROR_FILE = os.path.join(args.file_dir, 'error_type.jsonl')
49 | ID = file_name.split('_')[0]
50 | with open(ERROR_FILE, encoding='utf-8') as f:
51 | error_type_list = f.readlines()
52 | error_type_list = [json.loads(item) for item in error_type_list]
53 | error_type_item = []
54 | score_guideline = ''
55 | for item in error_type_list:
56 | if str(item['q_id']) == str(ID):
57 | error_type_item = item['errors']
58 | if not SKIP_CORRECT:
59 |
60 | error_type_item.append({'name': CORRECT_NAME, 'description': CORRECT_DESCRIPTION})
61 | break
62 |
63 | error_to_id_dict = {}
64 | for i, item in enumerate(error_type_item):
65 | error_to_id_dict[item['name']] = i
66 | def err2idx(name):
67 | if name in error_to_id_dict:
68 | return error_to_id_dict[name]
69 | return len(error_to_id_dict) - 1
70 |
71 | tp = 0
72 | fp = 0
73 | fn = 0
74 | max_error_length = len(error_type_item)
75 | max_length = 0
76 | for item in tqdm(ori_data):
77 | ori_steps = item['steps']
78 | if 'pred_steps' not in item:
79 | pred_steps = [{"step_score": 0, "errors": []} for _ in ori_steps]
80 | else:
81 | pred_steps = item['pred_steps']
82 | p_errors = [0 for _ in range(max_error_length)]
83 | g_errors = [0 for _ in range(max_error_length)]
84 | for step in ori_steps:
85 | errors = step['errors']
86 | for error in errors:
87 | if error == CORRECT_NAME and SKIP_CORRECT:
88 | continue
89 | g_errors[err2idx(error)] += 1
90 | for step in pred_steps:
91 | errors = step['errors']
92 | for error in errors:
93 | if error == CORRECT_NAME and SKIP_CORRECT:
94 | continue
95 | p_errors[err2idx(str(error))] += 1
96 | for p, g in zip(p_errors, g_errors):
97 | if p > 0:
98 | if g > 0:
99 | tp += 1
100 | else:
101 | fp += 1
102 | else:
103 | if g > 0:
104 | fn += 1
105 |
106 | P = tp / (tp + fp)
107 | R = tp / (tp + fn)
108 | F1 = (2 * P * R) / (P + R)
109 |
110 | results.append((file_name, F1, P, R))
111 |
112 | # %%
113 | def sort_func(x):
114 | try:
115 | x = x[0]
116 | x = x.split('_')[0]
117 | x = int(x)
118 | return x
119 | except:
120 | return 0
121 |
122 | results = sorted(results, key=sort_func)
123 | for item in results:
124 | print(item)
125 |
126 | with open(f'{args.save_type_name}_ef1.csv', 'w+') as f:
127 | for item in results:
128 | cols = [item[0], str(round(item[1] * 100, 2)), str(round(item[2] * 100 ,2)), str(round(item[3] * 100 ,2))]
129 | f.write(','.join(cols) + '\n')
130 |
131 | # %%
132 |
--------------------------------------------------------------------------------
/sas_pipelines/4_compute_ecs.py:
--------------------------------------------------------------------------------
1 | # %%
2 | import os
3 | import json
4 | import random
5 | import json_repair
6 | import numpy as np
7 | from tqdm import tqdm
8 | from argparse import ArgumentParser
9 |
10 | import sys
11 | sys.path.append("../")
12 | cmd_args = True
13 |
14 | parser = ArgumentParser()
15 | parser.add_argument('--file_dir', default='../datasets', help='the directory of the datasets.')
16 | parser.add_argument('--file_name', default='all', help='file name of the dataset without extension (e.g. 0_Physics_ShortAns), you can define `all` for all dataset files.')
17 | parser.add_argument('--save_type_name', default='Deepseek', help='the prefix name of the save dir (usually the LLM name)')
18 | parser.add_argument('--skip_correct', default=True, help='whether to skip the correct-step label when counting error causes')
19 |
20 | if not cmd_args:
21 | args = parser.parse_args([]) # You can directly set above parameters in the default.
22 | else:
23 | args = parser.parse_args()
24 |
25 | SKIP_CORRECT = args.skip_correct
26 | SOURCE_DIR = os.path.join(args.file_dir + '_' + args.save_type_name + '_Scored')
27 |
28 | # for an English dataset, you may replace this with {'name': 'correct', 'description': 'the step is correct.'},
29 | # depending on how the `name` of the correct step was predefined.
30 | CORRECT_NAME = '步骤正确'
31 | CORRECT_DESCRIPTION = '该步骤正确'
32 |
33 | from utils.errors_consistency_score import compute_ecs
34 |
35 | results = []
36 | if args.file_name != 'all':
37 | files = [args.file_name + '_prediction.jsonl']
38 | else:
39 | files = os.listdir(SOURCE_DIR)
40 |
41 | for file_name in files:
42 | if file_name.find('prediction') < 0:
43 | continue
44 | SOURCE_FILE = os.path.join(SOURCE_DIR, file_name)
45 | with open(SOURCE_FILE, encoding='utf-8') as f:
46 | ori_data = f.readlines()
47 | ori_data = [json.loads(line) for line in ori_data]
48 |
49 | ERROR_FILE = os.path.join(args.file_dir, 'error_type.jsonl')
50 | ID = file_name.split('_')[0]
51 | with open(ERROR_FILE, encoding='utf-8') as f:
52 | error_type_list = f.readlines()
53 | error_type_list = [json.loads(item) for item in error_type_list]
54 | error_type_item = []
55 | score_guideline = ''
56 | for item in error_type_list:
57 | if str(item['q_id']) == str(ID):
58 | error_type_item = item['errors']
59 | error_type_item.append({'name': CORRECT_NAME, 'description': CORRECT_DESCRIPTION})
60 | break
61 |
62 | error_to_id_dict = {}
63 | for i, item in enumerate(error_type_item):
64 | error_to_id_dict[item['name']] = i
65 | def err2idx(name):
66 | if name in error_to_id_dict:
67 | return error_to_id_dict[name]
68 | return len(error_to_id_dict) - 1
69 |
70 | pred_scores = []
71 | ori_scores = []
72 |
73 | total_scores = []
74 |
75 | pred_errors = []
76 | gold_errors = []
77 | max_error_length = len(error_type_item)
78 | max_length = 0
79 | for item in tqdm(ori_data):
80 | total_scores.append(item['total'])
81 | ori_score = item['manual_label']
82 | if 'pred_label' not in item or item['pred_label'] == '':
83 | pred_score = 0
84 | else:
85 | pred_score = item['pred_label']
86 | pred_scores.append(int(pred_score))
87 | ori_scores.append(int(ori_score))
88 |
89 | ori_steps = item['steps']
90 | if 'pred_steps' not in item:
91 | pred_steps = [{"step_score": 0, "errors": []} for _ in ori_steps]
92 | else:
93 | pred_steps = item['pred_steps']
94 | p_errors = [0 for _ in range(max_error_length)]
95 | g_errors = [0 for _ in range(max_error_length)]
96 | for step in ori_steps:
97 | errors = step['errors']
98 | for error in errors:
99 | if error == CORRECT_NAME and SKIP_CORRECT:
100 | continue
101 | g_errors[err2idx(error)] += 1
102 | for step in pred_steps:
103 | errors = step['errors']
104 | for error in errors:
105 | if error == CORRECT_NAME and SKIP_CORRECT:
106 | continue
107 | p_errors[err2idx(str(error))] += 1
108 | pred_errors.append(p_errors)
109 | gold_errors.append(g_errors)
110 |
111 | results.append((file_name, compute_ecs(pred_scores=pred_scores, ori_scores=ori_scores, total_scores=total_scores, pred_errors=pred_errors, gold_errors=gold_errors, max_error=max_error_length)))
112 |
113 | # %%
114 | def sort_func(x):
115 | try:
116 | x = x[0]
117 | x = x.split('_')[0]
118 | x = int(x)
119 | return x
120 | except:
121 | return 0
122 |
123 | results = sorted(results, key=sort_func)
124 | for item in results:
125 | print(item)
126 |
127 | with open(f'{args.save_type_name}_ecs.csv', 'w+') as f:
128 | for item in results:
129 |         cols = [item[0], str(round(item[1][0] * 100, 2)), str(round(item[1][1][0] * 100, 2)), str(round(item[1][1][1] * 100, 2)), str(round(item[1][1][2] * 100, 2))]
130 | f.write(','.join(cols) + '\n')
131 |
132 | # %%
133 |
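A typical way to run this step once predictions exist under `<file_dir>_<save_type_name>_Scored` (run from the `sas_pipelines` directory; the values below are placeholders, and you can equally execute the cells with `cmd_args = False`):

import subprocess

# Illustrative invocation; adjust the directory and LLM name to your own layout.
subprocess.run([
    "python", "4_compute_ecs.py",
    "--file_dir", "../datasets",
    "--file_name", "0_Physics_ShortAns",   # or "all" to score every *_prediction.jsonl file
    "--save_type_name", "Deepseek",
], check=True)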
--------------------------------------------------------------------------------
/main/predictor/vllm.py:
--------------------------------------------------------------------------------
1 | import json
2 | import torch
3 | from transformers import AutoTokenizer, AutoConfig
4 | from vllm import LLM, SamplingParams
5 | from typing import Tuple, List
6 |
7 |
8 | class Predictor():
9 |
10 | def __init__(self,
11 | tensor_parallel_size: int = 1,
12 | model_from_pretrained: str = None,
13 | **args
14 | ):
15 | '''
16 | Predictor: LLM预测器 (LLM predictor)
17 |
18 | ### Args:
19 |
20 |         `tensor_parallel_size`: 张量并行使用的 GPU 数量,模型会被拆分到多张卡上 (the number of GPUs used for tensor parallelism; the model is sharded across them)
21 | 
22 |         `model_from_pretrained`: 预训练模型的路径 (the path of the pretrained model)
23 | '''
24 | self.tp = tensor_parallel_size
25 | self.model_from_pretrained = model_from_pretrained
26 | self.model_init()
27 |
28 | def model_init(self):
29 | self.config = AutoConfig.from_pretrained(
30 | self.model_from_pretrained, trust_remote_code=True)
31 | self.tokenizer = AutoTokenizer.from_pretrained(
32 | self.model_from_pretrained, padding_side="left", trust_remote_code=True)
33 | self.llm = LLM(self.model_from_pretrained, tensor_parallel_size=self.tp, trust_remote_code=True)
34 | if hasattr(self.config, 'eos_token_id'):
35 | self.eos_token_id = [self.config.eos_token_id]
36 | if hasattr(self.config, 'bos_token_id'):
37 | self.bos_token_id = [self.config.bos_token_id]
38 | if self.config.model_type == 'chatglm':
39 | self.eos_token_id = self.config.eos_token_id
40 | elif self.config.model_type == 'llama':
41 | self.tokenizer.pad_token = self.tokenizer.eos_token
42 | terminators = [
43 | self.tokenizer.eos_token_id,
44 | self.tokenizer.convert_tokens_to_ids("<|eot_id|>")
45 | ]
46 | if not hasattr(self, 'eos_token_id'):
47 | self.eos_token_id = []
48 | for t in terminators:
49 | if t is not None:
50 | self.eos_token_id.append(t)
51 | elif self.config.model_type == "mimo":
52 | self.eos_token_id = self.config.eos_token_id
53 | self.bos_token_id = self.config.bos_token_id
54 | elif self.config.model_type == "tinyr1":
55 | self.eos_token_id = self.config.eos_token_id
56 | self.bos_token_id = self.config.bos_token_id
57 |
58 | def process_model_outputs(self, inputs, outputs, tokenizer):
59 | responses = []
60 | for input_ids, output_ids in zip(inputs['input_ids'], outputs):
61 | response = tokenizer.decode(output_ids[len(input_ids):], skip_special_tokens=True).strip()
62 | responses.append(response)
63 | return responses
64 |
65 | def build_chat_input(self, query:str, history=None):
66 | if history is None:
67 | history = []
68 | history.append(query)
69 | max_input_tokens = 0
70 | new_batch_input = self.tokenizer.apply_chat_template(history, add_generation_prompt=True, tokenize=False)
71 | max_input_tokens = max(max_input_tokens, len(new_batch_input))
72 | return new_batch_input, max_input_tokens
73 |
74 | def predict(self, query: str | list = '', history: List = None, max_length=512, max_new_tokens=512, num_beams:int=1, top_p: float = 1.0, temperature=0, do_sample: bool = False, build_message=False):
75 | if not isinstance(query, list):
76 | query = [query]
77 | history = [history] if history is not None else None
78 | if build_message:
79 | inputs = []
80 | batch_max_len = 0
81 | for i, t in enumerate(query):
82 | if isinstance(t, str):
83 | t = {'role': 'user', 'content': t}
84 | if history is not None and len(history) > 0:
85 | h_unit = history[i]
86 | else:
87 | h_unit = []
88 | t, max_input_tokens = self.build_chat_input(t, h_unit)
89 | if batch_max_len < max_input_tokens:
90 | batch_max_len = max_input_tokens
91 | inputs.append(t)
92 | else:
93 | inputs = query
94 | batch_max_len = 0
95 | for i in range(len(query)):
96 | if len(query[i]) > batch_max_len:
97 | batch_max_len = len(query[i])
98 |
99 | sampling_params = SamplingParams(temperature=temperature, top_p=top_p, max_tokens=max_new_tokens)
100 |
101 | outputs = self.llm.generate(inputs, sampling_params)
102 | results = []
103 | for output in outputs:
104 | prompt = output.prompt
105 | generated_text = output.outputs[0].text
106 | results.append(generated_text)
107 | return results
108 |
109 | def __call__(self, query: str | list = '', history: List = None, max_length=512, max_new_tokens=512, num_beams:int=1, top_p: float = 0.8, temperature=1.0, do_sample: bool = False, build_message=False):
110 | return self.predict(query, history, max_length, max_new_tokens, num_beams, top_p, temperature, do_sample, build_message)
111 |
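A minimal usage sketch of this predictor, assuming a local chat checkpoint (the path is a placeholder):

# Hypothetical example; replace the path with a real chat-model checkpoint.
predictor = Predictor(tensor_parallel_size=1, model_from_pretrained='/path/to/chat-model')
answers = predictor(
    ['Score the following student answer step by step.'],
    build_message=True,      # wrap each string as {'role': 'user', 'content': ...} via the chat template
    max_new_tokens=256,
    temperature=0.0,         # near-greedy decoding for scoring
)
print(answers[0])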
--------------------------------------------------------------------------------
/main/loaders/chatglm_rlhf.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import torch
4 | import random
5 | import pickle
6 | import numpy as np
7 | from tqdm import tqdm
8 | from torch.utils.data import Dataset, DataLoader
9 | from transformers import PretrainedConfig, PreTrainedTokenizer
10 |
11 |
12 | class ChatGLM_RLHFDataset(Dataset):
13 |
14 | config: PretrainedConfig
15 | tokenizer: PreTrainedTokenizer
16 |
17 | def __init__(self, tokenizer, config, file_name, max_length=512, do_shuffle=False):
18 | self.config = config
19 | self.tokenizer = tokenizer
20 | self.max_length = max_length
21 | self.do_shuffle = do_shuffle
22 | self.data = self.load_jsonl(file_name)
23 | self.random_list = [idx for idx in range(len(self.data))]
24 | self.preprocess_list = [self.process_item(item) for item in self.data]
25 | if self.do_shuffle:
26 | random.shuffle(self.random_list)
27 |
28 | def load_jsonl(self, file_name):
29 | with open(file_name, 'r') as f:
30 | lines = f.readlines()
31 | data = [json.loads(line) for line in lines]
32 | return data
33 |
34 | def process_item(self, item):
35 | conv = item['conversations'] if 'conversations' in item else item
36 |
37 | input_ids, loss_masks = [
38 | self.tokenizer.get_command('[gMASK]'),
39 | self.tokenizer.get_command('sop'),
40 | ], [False, False]
41 |
42 | last_input_len = 0
43 | last_user_content = ''
44 | last_assistant_content = ''
45 |
46 | for message in conv:
47 | if message['role'] in ('system', 'user'):
48 | loss_mask_val = False
49 | else:
50 | loss_mask_val = True
51 |
52 | if message['role'] == 'tool':
53 | raise NotImplementedError()
54 | else:
55 | new_input_ids = self.tokenizer.build_single_message(
56 | message['role'], '', message['content']
57 | )
58 | new_loss_masks = [loss_mask_val] * len(new_input_ids)
59 | if message['role'] == 'user':
60 | last_user_content = message['content']
61 | last_input_len = len(new_input_ids)
62 | elif message['role'] == 'assistant':
63 | last_assistant_content = message['content']
64 | role_ids = self.tokenizer.build_single_message(
65 | message['role'], '', ''
66 | )
67 | last_input_len = len(new_input_ids) - len(role_ids)
68 |
69 | input_ids += new_input_ids
70 | loss_masks += new_loss_masks
71 |
72 | input_ids.append(self.tokenizer.eos_token_id)
73 | loss_masks = [False, *loss_masks]
74 | labels = []
75 | for input_id, mask in zip(input_ids, loss_masks):
76 | if mask:
77 | labels.append(input_id)
78 | else:
79 | labels.append(-100)
80 | max_length = self.max_length
81 |
82 | input_ids_without_last_turn = input_ids[:-last_input_len]
83 | labels_without_last_turn = labels[:-last_input_len]
84 |
85 | exceed_len = len(input_ids) - max_length
86 | if exceed_len > 0:
87 | last_input_len = last_input_len - exceed_len
88 | input_ids = input_ids[:max_length]
89 | labels = labels[:max_length]
90 | input_ids_without_last_turn = input_ids_without_last_turn[:max_length]
91 | labels_without_last_turn = labels_without_last_turn[:max_length]
92 | return {
93 | 'input_ids': input_ids,
94 | 'labels': labels,
95 | 'input_ids_without_last_turn': input_ids_without_last_turn,
96 | 'labels_without_last_turn': labels_without_last_turn,
97 | 'last_input_len': last_input_len,
98 | 'last_user_content': last_user_content,
99 | 'last_assistant_content': last_assistant_content}
100 |
101 | def __getitem__(self, index):
102 | index = self.random_list[index]
103 | item = self.data[index]
104 | input_ids, labels, input_ids_without_last_turn, labels_without_last_turn, last_input_len, last_user_content, last_assistant_content = self.preprocess_list[index].values(
105 | )
106 | gold_answers = item['gold_answers']
107 | bad_answers = item['bad_answers']
108 |
109 | input_ids = torch.tensor(input_ids)
110 | labels = torch.tensor(labels)
111 | input_ids_without_last_turn = torch.tensor(input_ids_without_last_turn)
112 | labels_without_last_turn = torch.tensor(labels_without_last_turn)
113 |
114 | return {
115 | 'query': last_user_content,
116 | 'gold_answers': gold_answers,
117 | 'bad_answers': bad_answers,
118 | 'input_ids': input_ids,
119 | 'input_ids_without_last_turn': input_ids_without_last_turn,
120 | 'labels_without_last_turn': labels_without_last_turn,
121 | 'last_input_len': torch.tensor(last_input_len),
122 | 'last_assistant_content': last_assistant_content,
123 | 'labels': labels
124 | }
125 |
126 | def __len__(self):
127 | return len(self.data)
128 |
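One line of the expected JSONL training file might look like the record below; the field names come from `process_item` and `__getitem__`, while the concrete values (and the assumption that `gold_answers`/`bad_answers` are lists of strings) are illustrative:

import json

# Illustrative record; each line of the file is one JSON object of this shape.
record = {
    "conversations": [
        {"role": "user", "content": "请给下面的学生作答打分。"},
        {"role": "assistant", "content": "得分:2"}
    ],
    "gold_answers": ["得分:2"],
    "bad_answers": ["得分:0"]
}
print(json.dumps(record, ensure_ascii=False))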
--------------------------------------------------------------------------------
/main/predictor/chatglm.py:
--------------------------------------------------------------------------------
1 | import json
2 | import torch
3 | from transformers import AutoTokenizer, AutoModel
4 | from typing import Tuple, List
5 |
6 |
7 | class Predictor():
8 |
9 | def __init__(self,
10 | num_gpus: list = [0],
11 | model_from_pretrained: str = None,
12 | **args
13 | ):
14 | '''
15 | Predictor: ChatGLM预测器 (ChatGLM predictor)
16 |
17 | ### Args:
18 |
19 | `num_gpus`: 使用的GPU编号列表 (the list of GPU numbers)
20 |
21 |         `model_from_pretrained`: 预训练模型的路径 (the path of the pretrained model)
22 | '''
23 | self.num_gpus = num_gpus
24 | self.model_from_pretrained = model_from_pretrained
25 | self.model_init()
26 |
27 | def model_init(self):
28 | self.tokenizer = AutoTokenizer.from_pretrained(
29 | self.model_from_pretrained, trust_remote_code=True)
30 | self.model = AutoModel.from_pretrained(
31 | self.model_from_pretrained, trust_remote_code=True).half().cuda()
32 | self.model_to_device(gpu=self.num_gpus)
33 | self.model = self.model.eval()
34 |
35 | def model_to_device(self, gpu=[0]):
36 | self.device = torch.device(
37 | "cuda:0" if torch.cuda.is_available() else "cpu")
38 | self.model.cuda()
39 | self.model = torch.nn.DataParallel(self.model, device_ids=gpu).cuda()
40 | self.model.to(self.device)
41 | self.true_model = self.model.module if hasattr(
42 | self.model, 'module') else self.model
43 |
44 | def build_chat_input(self, query, history=None, role="user"):
45 | if history is None:
46 | history = []
47 | input_ids = []
48 | for item in history:
49 | content = item["content"]
50 | if item["role"] == "system" and "tools" in item:
51 | content = content + "\n" + json.dumps(item["tools"], indent=4, ensure_ascii=False)
52 | input_ids.extend(self.tokenizer.build_single_message(item["role"], item.get("metadata", ""), content))
53 | input_ids.extend(self.tokenizer.build_single_message(role, "", query))
54 | input_ids.extend([self.tokenizer.get_command("<|assistant|>")])
55 | return input_ids
56 |
57 | def predict(self, query: str | list = '', history: List = None, max_new_tokens=512, num_beams:int=1, top_p: float = 0.8, temperature=1.0, do_sample: bool = False, build_message=False):
58 | if isinstance(query, str):
59 | query = [query]
60 | history = [history] if history is not None else None
61 | with torch.no_grad():
62 | if build_message:
63 | inputs = []
64 | batch_max_len = 0
65 | for i, t in enumerate(query):
66 | if history is not None and len(history) > 0:
67 | h_unit = history[i]
68 | t = self.build_chat_input(t, h_unit)
69 | else:
70 | t = self.tokenizer.build_single_message("user", "", t)
71 | t.extend([self.tokenizer.get_command("<|assistant|>")])
72 | if batch_max_len < len(t):
73 | batch_max_len = len(t)
74 | inputs.append(t)
75 | for idx, t in enumerate(inputs):
76 | remain = batch_max_len - len(t)
77 | inputs[idx] = [self.tokenizer.pad_token_id] * remain + t
78 | else:
79 | inputs = self.tokenizer(
80 | query, padding=True, truncation=True)['input_ids']
81 | input_ids = torch.LongTensor(inputs).to(self.device)
82 | output = self.true_model.generate(**{
83 | 'input_ids': input_ids,
84 | 'max_new_tokens': max_new_tokens,
85 | 'num_beams': num_beams,
86 | 'do_sample': do_sample,
87 | 'top_p': top_p,
88 | "temperature": temperature,
89 | "eos_token_id": self.true_model.config.eos_token_id
90 | })
91 | out_text = self.tokenizer.batch_decode(
92 | output, skip_special_tokens=True)
93 | if build_message:
94 | out_text = [self.true_model.process_response(t, [])[0] for t in out_text]
95 | return out_text
96 |
97 | @torch.inference_mode()
98 | def chat(self, query: str, history: List[Tuple[str, str]] = None, role: str = "user",
99 | max_length: int = 8192, num_beams=1, do_sample=True, top_p=0.8, temperature=0.8, logits_processor=None,
100 | **kwargs):
101 | response, history = self.true_model.chat(
102 | self.tokenizer, query, history, role, max_length, num_beams, do_sample, top_p, temperature, logits_processor, **kwargs)
103 | return response, history
104 |
105 | @torch.inference_mode()
106 | def stream_chat(self, query: str, history: List[Tuple[str, str]] = None, role: str = "user",
107 | past_key_values=None, max_length: int = 8192, do_sample=True, top_p=0.8, temperature=0.8,
108 | logits_processor=None, return_past_key_values=False, **kwargs):
109 | for result in self.true_model.stream_chat(self.tokenizer, query, history, role, past_key_values, max_length, do_sample, top_p, temperature, logits_processor, return_past_key_values, **kwargs):
110 | yield result
111 |
112 | def __call__(self, query: str | list = '', history: List = None, max_new_tokens=512, num_beams:int=1, top_p: float = 0.8, temperature=1.0, do_sample: bool = False, build_message=False):
113 | return self.predict(query, history, max_new_tokens, num_beams, top_p, temperature, do_sample, build_message)
114 |
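A short interactive sketch using the `chat` helper (the checkpoint path is a placeholder and assumes a ChatGLM-style model):

# Hypothetical example; point model_from_pretrained at a local ChatGLM checkpoint.
predictor = Predictor(num_gpus=[0], model_from_pretrained='/path/to/chatglm3-6b')
response, history = predictor.chat('请简要介绍一下短答案评分任务。', history=[])
follow_up, history = predictor.chat('请用一句话总结。', history=history)
print(response, follow_up)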
--------------------------------------------------------------------------------
/main/predictor/chatglm_lora.py:
--------------------------------------------------------------------------------
1 | from transformers import AutoModel
2 | from transformers import AutoTokenizer
3 | from peft import LoraConfig, TaskType, PeftModel, PeftModelForCausalLM
4 | from typing import Tuple, List
5 | import json
6 | import torch
7 |
8 |
9 | class Predictor():
10 | true_model: PeftModelForCausalLM
11 |
12 | def __init__(self,
13 | num_gpus: list = [0],
14 | model_from_pretrained: str = None,
15 | resume_path: str = None
16 | ):
17 | peft_config = LoraConfig(
18 | task_type=TaskType.CAUSAL_LM,
19 | inference_mode=False,
20 | r=16,
21 | lora_alpha=32,
22 | lora_dropout=0.1,
23 | )
24 | self.tokenizer = AutoTokenizer.from_pretrained(
25 | model_from_pretrained, trust_remote_code=True)
26 | self.model = AutoModel.from_pretrained(
27 | model_from_pretrained, trust_remote_code=True).half().cuda()
28 | self.model = PeftModel.from_pretrained(
29 | self.model, resume_path, config=peft_config)
30 | self.model_to_device(gpu=num_gpus)
31 |
32 | def model_to_device(self, gpu=[0]):
33 | self.device = torch.device(
34 | "cuda:0" if torch.cuda.is_available() else "cpu")
35 | self.model.cuda()
36 | self.model = torch.nn.DataParallel(self.model, device_ids=gpu).cuda()
37 | self.model.to(self.device)
38 | self.true_model = self.model.module if hasattr(
39 | self.model, 'module') else self.model
40 |
41 | def build_chat_input(self, query, history=None, role="user"):
42 | if history is None:
43 | history = []
44 | input_ids = []
45 | for item in history:
46 | content = item["content"]
47 | if item["role"] == "system" and "tools" in item:
48 | content = content + "\n" + json.dumps(item["tools"], indent=4, ensure_ascii=False)
49 | input_ids.extend(self.tokenizer.build_single_message(item["role"], item.get("metadata", ""), content))
50 | input_ids.extend(self.tokenizer.build_single_message(role, "", query))
51 | input_ids.extend([self.tokenizer.get_command("<|assistant|>")])
52 | return input_ids
53 |
54 | def predict(self, query: str | list = '', history: List = None, max_new_tokens=512, num_beams:int=1, top_p: float = 0.8, temperature=1.0, do_sample: bool = False, build_message=False):
55 | if isinstance(query, str):
56 | query = [query]
57 | history = [history] if history is not None else None
58 | with torch.no_grad():
59 | if build_message:
60 | inputs = []
61 | batch_max_len = 0
62 | for i, t in enumerate(query):
63 | if history is not None and len(history) > 0:
64 | h_unit = history[i]
65 | t = self.build_chat_input(t, h_unit)
66 | else:
67 | t = self.tokenizer.build_single_message("user", "", t)
68 | t.extend([self.tokenizer.get_command("<|assistant|>")])
69 | if batch_max_len < len(t):
70 | batch_max_len = len(t)
71 | inputs.append(t)
72 | for idx, t in enumerate(inputs):
73 | remain = batch_max_len - len(t)
74 | inputs[idx] = [self.tokenizer.pad_token_id] * remain + t
75 | else:
76 | inputs = self.tokenizer(
77 | query, padding=True, truncation=True)['input_ids']
78 | input_ids = torch.LongTensor(inputs).to(self.device)
79 | output = self.true_model.generate(**{
80 | 'input_ids': input_ids,
81 | 'max_new_tokens': max_new_tokens,
82 | 'num_beams': num_beams,
83 | 'do_sample': do_sample,
84 | 'top_p': top_p,
85 | "temperature": temperature,
86 |                 "eos_token_id": self.true_model.config.eos_token_id
87 | })
88 | out_text = self.tokenizer.batch_decode(
89 | output, skip_special_tokens=True)
90 | if build_message:
91 | out_text = [self.true_model.process_response(t, [])[0] for t in out_text]
92 | return out_text
93 |
94 | @torch.inference_mode()
95 | def chat(self, query: str, history: List[Tuple[str, str]] = None, role: str = "user",
96 | max_length: int = 8192, num_beams=1, do_sample=True, top_p=0.8, temperature=0.8, logits_processor=None,
97 | **kwargs):
98 | response, history = self.true_model.chat(
99 | self.tokenizer, query, history, role, max_length, num_beams, do_sample, top_p, temperature, logits_processor, **kwargs)
100 | return response, history
101 |
102 | @torch.inference_mode()
103 | def stream_chat(self, query: str, history: List[Tuple[str, str]] = None, role: str = "user",
104 | past_key_values=None, max_length: int = 8192, do_sample=True, top_p=0.8, temperature=0.8,
105 | logits_processor=None, return_past_key_values=False, **kwargs):
106 | for result in self.true_model.stream_chat(self.tokenizer, query, history, role, past_key_values, max_length, do_sample, top_p, temperature, logits_processor, return_past_key_values, **kwargs):
107 | yield result
108 |
109 | def __call__(self, query: str | list = '', history: List = None, max_new_tokens=512, num_beams:int=1, top_p: float = 0.8, temperature=1.0, do_sample: bool = False, build_message=False):
110 | return self.predict(query=query, history=history, max_new_tokens=max_new_tokens, num_beams=num_beams, top_p=top_p, temperature=temperature, do_sample=do_sample, build_message=build_message)
111 |
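Loading a LoRA adapter produced by the trainer might look like this (both paths are placeholders; the `ChatGLM_<step>` directory name follows the trainer's `save_model` convention):

# Hypothetical example; the base checkpoint and adapter directory are placeholders.
predictor = Predictor(
    num_gpus=[0],
    model_from_pretrained='/path/to/chatglm3-6b',
    resume_path='./save_model/my_task/ChatGLM_1000',
)
outputs = predictor('请对下面的作答逐步评分。', build_message=True, max_new_tokens=256)
print(outputs[0])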
--------------------------------------------------------------------------------
/discuss/compute_range_ccs.py:
--------------------------------------------------------------------------------
1 | # %%
2 | import os
3 | import json
4 | import random
5 | import json_repair
6 | import numpy as np
7 | from tqdm import tqdm
8 | from argparse import ArgumentParser
9 |
10 | import sys
11 | sys.path.append("../")
12 | cmd_args = True
13 |
14 | parser = ArgumentParser()
15 | parser.add_argument('--file_dir', default='../datasets', help='the directory of the datasets.')
16 | parser.add_argument('--file_name', default='all', help='file name of the dataset without extension (e.g. 0_Physics_ShortAns); use `all` to process every dataset file.')
17 | parser.add_argument('--save_type_name', default='Deepseek', help='the prefix name of the save dir (usually the LLM name)')
18 |
19 | if not cmd_args:
20 |     args = parser.parse_args([]) # You can set the parameters above directly through their defaults.
21 | else:
22 | args = parser.parse_args()
23 |
24 | SOURCE_DIR = os.path.join(args.file_dir + '_' + args.save_type_name + '_Scored')
25 |
26 | from utils.collaborative_consistency_score import adjusted_qwk
27 |
28 | results = []
29 | if args.file_name != 'all':
30 | files = [args.file_name + '_prediction.jsonl']
31 | else:
32 | files = os.listdir(SOURCE_DIR)
33 | for file_name in files:
34 | if file_name.find('prediction') < 0:
35 | continue
36 | SOURCE_FILE = os.path.join(SOURCE_DIR, file_name)
37 | with open(SOURCE_FILE, encoding='utf-8') as f:
38 | ori_data = f.readlines()
39 | ori_data = [json.loads(line) for line in ori_data]
40 |
41 | pred_result = []
42 | ori_result = []
43 | weights_result = []
44 | total_scores = []
45 | max_score = 0
46 | max_length = 0
47 | for item in tqdm(ori_data):
48 | if max_score < item['total']:
49 | max_score = item['total']
50 | total_scores.append(item['total'])
51 | ori_steps = item['steps']
52 | if 'pred_steps' not in item:
53 | pred_steps = [{"step_score": 0, "errors": []} for _ in ori_steps]
54 | else:
55 | pred_steps = item['pred_steps']
56 | ori_score = item['manual_label']
57 | if 'pred_label' not in item or item['pred_label'] == '':
58 | pred_score = 0
59 | else:
60 | pred_score = item['pred_label']
61 | ori_labels = [int(ori_score)]
62 | pred_labels = [int(pred_score)]
63 | weights = [0.5] + [0.5 / len(ori_steps) for _ in range(len(ori_steps))]
64 | for i in range(len(ori_steps)):
65 | try:
66 | ori_labels.append(int(ori_steps[i]['label']))
67 | except:
68 | ori_labels.append(0)
69 | if len(pred_steps) > i and type(pred_steps[i]) == dict and 'step_score' in pred_steps[i]:
70 | pred_labels.append(int(pred_steps[i]['step_score']))
71 | else:
72 | pred_labels.append(0)
73 | if max_length < len(ori_labels):
74 | max_length = len(ori_labels)
75 | pred_result.append(pred_labels)
76 | ori_result.append(ori_labels)
77 | weights_result.append(weights)
78 |
79 | # Padding
80 | for i in range(len(pred_result)):
81 | pred_result[i] += [0] * (max_length - len(pred_result[i]))
82 | ori_result[i] += [0] * (max_length - len(ori_result[i]))
83 | weights_result[i] += [0] * (max_length - len(weights_result[i]))
84 | max_scores = []
85 | for i in range(len(ori_result[0])):
86 | val_i = []
87 | for item in ori_result:
88 | val_i.append(item[i])
89 | for item in pred_result:
90 | val_i.append(item[i])
91 | max_scores.append(max(val_i))
92 | max_scores[0] = max_score
93 |
94 | all_norm_scores = []
95 | Ln = 1e-10
96 | for scores, max_s in zip(ori_result, total_scores):
97 | norm_score = scores[0] / (max_s + Ln)
98 | all_norm_scores.append(norm_score)
99 | range1, range2 = 0, 0
100 | sorted_scores = sorted(all_norm_scores)
101 | range1 = sorted_scores[int(len(sorted_scores) * 0.33)]
102 | range2 = sorted_scores[int(len(sorted_scores) * 0.67)]
103 | low_list = {
104 | 'ori_result': [],
105 | 'pred_result': [],
106 | 'weights_result': [],
107 | 'max_scores': max_scores
108 | }
109 | mid_list = {
110 | 'ori_result': [],
111 | 'pred_result': [],
112 | 'weights_result': [],
113 | 'max_scores': max_scores
114 | }
115 | high_list = {
116 | 'ori_result': [],
117 | 'pred_result': [],
118 | 'weights_result': [],
119 | 'max_scores': max_scores
120 | }
121 | for i in range(len(total_scores)):
122 | total = total_scores[i]
123 | gold = ori_result[i][0]
124 | gold = gold / total
125 | if gold <= range1:
126 | low_list['ori_result'].append(ori_result[i])
127 | low_list['pred_result'].append(pred_result[i])
128 | low_list['weights_result'].append(weights_result[i])
129 | elif gold < range2:
130 | mid_list['ori_result'].append(ori_result[i])
131 | mid_list['pred_result'].append(pred_result[i])
132 | mid_list['weights_result'].append(weights_result[i])
133 | else:
134 | high_list['ori_result'].append(ori_result[i])
135 | high_list['pred_result'].append(pred_result[i])
136 | high_list['weights_result'].append(weights_result[i])
137 | results.append((file_name, adjusted_qwk(*low_list.values())[0], adjusted_qwk(*mid_list.values())[0], adjusted_qwk(*high_list.values())[0]))
138 |
139 | # %%
140 | def sort_func(x):
141 | try:
142 | x = x[0]
143 | x = x.split('_')[0]
144 | x = int(x)
145 | return x
146 | except:
147 | return 0
148 |
149 | results = sorted(results, key=sort_func)
150 | for item in results:
151 | print(item)
152 |
153 | with open(f'{args.save_type_name}_range_ccs.csv', 'w+') as f:
154 | for item in results:
155 |         cols = [item[0], str(round(item[1] * 100, 2)), str(round(item[2] * 100, 2)), str(round(item[3] * 100, 2))]
156 | f.write(','.join(cols) + '\n')
157 |
158 | # %%
159 |
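The banding above splits samples into roughly equal thirds by their normalized gold score; a toy illustration of how `range1` and `range2` are picked (values are made up):

# Toy illustration of the 33%/67% thresholds used above.
norm = sorted([0.1, 0.2, 0.4, 0.5, 0.6, 0.8, 0.9, 1.0, 1.0])
range1 = norm[int(len(norm) * 0.33)]  # upper bound of the "low" band -> 0.4
range2 = norm[int(len(norm) * 0.67)]  # lower bound of the "high" band -> 0.9
print(range1, range2)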
--------------------------------------------------------------------------------
/main/predictor/llm_lora.py:
--------------------------------------------------------------------------------
1 | from transformers import AutoModel
2 | from transformers import AutoTokenizer, AutoConfig, LlamaForCausalLM, AutoModelForCausalLM
3 | from peft import LoraConfig, TaskType, PeftModel, PeftModelForCausalLM
4 | from typing import Tuple, List
5 | import json
6 | import torch
7 |
8 |
9 | class Predictor():
10 | true_model: PeftModelForCausalLM
11 |
12 | def __init__(self,
13 | num_gpus: list = [0],
14 | model_from_pretrained: str = None,
15 | resume_path: str = None,
16 | lora_r=16, lora_alpha=32, lora_dropout=0.1,
17 | ):
18 | peft_config = LoraConfig(
19 | task_type=TaskType.CAUSAL_LM,
20 | inference_mode=False,
21 | r=lora_r,
22 | lora_alpha=lora_alpha,
23 | lora_dropout=lora_dropout,
24 | )
25 | self.config = AutoConfig.from_pretrained(
26 | model_from_pretrained, trust_remote_code=True)
27 | self.tokenizer = AutoTokenizer.from_pretrained(
28 | model_from_pretrained, trust_remote_code=True)
29 |         self.model_from_pretrained = model_from_pretrained
30 | if self.config.model_type == 'chatglm':
31 | self.model = AutoModel.from_pretrained(
32 | self.model_from_pretrained, trust_remote_code=True).to(torch.bfloat16)
33 | self.eos_token_id = self.config.eos_token_id
34 | elif self.config.model_type == 'llama':
35 | self.tokenizer.pad_token = self.tokenizer.eos_token
36 | self.model = LlamaForCausalLM.from_pretrained(
37 | self.model_from_pretrained, trust_remote_code=True).to(torch.bfloat16)
38 | terminators = [
39 | self.tokenizer.eos_token_id,
40 | self.tokenizer.convert_tokens_to_ids("<|eot_id|>")
41 | ]
42 | self.eos_token_id = terminators
43 | elif self.config.model_type == 'qwen':
44 | self.model = AutoModelForCausalLM.from_pretrained(
45 | self.model_from_pretrained, torch_dtype="auto", device_map="auto", trust_remote_code=True)
46 | elif self.config.model_type == 'qwen2':
47 | self.model = AutoModelForCausalLM.from_pretrained(
48 | self.model_from_pretrained, torch_dtype="auto", device_map="auto", trust_remote_code=True)
49 |
50 | self.model = PeftModel.from_pretrained(
51 | self.model, resume_path, config=peft_config)
52 | self.model_to_device(gpu=num_gpus)
53 |
54 | def model_to_device(self, gpu=[0]):
55 | self.device = torch.device(
56 | "cuda:0" if torch.cuda.is_available() else "cpu")
57 | self.model.cuda()
58 | self.model = torch.nn.DataParallel(self.model, device_ids=gpu).cuda()
59 | self.model.to(self.device)
60 | self.true_model = self.model.module if hasattr(
61 | self.model, 'module') else self.model
62 |
63 | def process_model_outputs(self, inputs, outputs, tokenizer):
64 | responses = []
65 | for input_ids, output_ids in zip(inputs['input_ids'], outputs):
66 | response = tokenizer.decode(output_ids[len(input_ids):], skip_special_tokens=True).strip()
67 | responses.append(response)
68 | return responses
69 |
70 | def build_chat_input(self, query:str, history=None):
71 | if history is None:
72 | history = []
73 | history.append(query)
74 | max_input_tokens = 0
75 | new_batch_input = self.tokenizer.apply_chat_template(history, add_generation_prompt=True, tokenize=False)
76 | max_input_tokens = max(max_input_tokens, len(new_batch_input))
77 | return new_batch_input, max_input_tokens
78 |
79 | def predict(self, query: str | list = '', history: List = None, max_length=512, max_new_tokens=512, num_beams:int=1, top_p: float = 0.8, temperature=1.0, do_sample: bool = False, build_message=False):
80 | if not isinstance(query, list):
81 | query = [query]
82 | history = [history] if history is not None else None
83 | with torch.no_grad():
84 | if build_message:
85 | inputs = []
86 | batch_max_len = 0
87 | for i, t in enumerate(query):
88 | if isinstance(t, str):
89 | t = {'role': 'user', 'content': t}
90 | if history is not None and len(history) > 0:
91 | h_unit = history[i]
92 | else:
93 | h_unit = []
94 | t, max_input_tokens = self.build_chat_input(t, h_unit)
95 | if batch_max_len < max_input_tokens:
96 | batch_max_len = max_input_tokens
97 | inputs.append(t)
98 | else:
99 | inputs = query
100 | batch_max_len = 0
101 | for i in range(len(query)):
102 | if len(query[i]) > batch_max_len:
103 | batch_max_len = len(query[i])
104 | batched_inputs = self.tokenizer(
105 | inputs,
106 | return_tensors="pt",
107 | padding=True,
108 | truncation=True).to(self.device)
109 | if self.config.model_type == 'llama':
110 | batched_inputs = batched_inputs.data
111 | batched_outputs = self.true_model.generate(**batched_inputs, **{
112 | 'max_new_tokens': max_new_tokens,
113 | 'num_beams': num_beams,
114 | 'do_sample': do_sample,
115 | 'top_p': top_p,
116 | "temperature": temperature,
117 | "eos_token_id": self.eos_token_id
118 | })
119 | batched_response = self.process_model_outputs(batched_inputs, batched_outputs, self.tokenizer)
120 | return batched_response
121 |
122 | def __call__(self, query: str | list = '', history: List = None, max_length=512, max_new_tokens=512, num_beams:int=1, top_p: float = 0.8, temperature=1.0, do_sample: bool = False, build_message=False):
123 | return self.predict(query, history, max_length, max_new_tokens, num_beams, top_p, temperature, do_sample, build_message)
124 |
--------------------------------------------------------------------------------
/main/predictor/llm.py:
--------------------------------------------------------------------------------
1 | import json
2 | import torch
3 | from transformers import AutoTokenizer, AutoModel, AutoConfig, LlamaForCausalLM, AutoModelForCausalLM
4 | from typing import Tuple, List
5 |
6 |
7 | class Predictor():
8 |
9 | def __init__(self,
10 | num_gpus: list = [0],
11 | model_from_pretrained: str = None,
12 | **args
13 | ):
14 | '''
15 | Predictor: LLM预测器 (LLM predictor)
16 |
17 | ### Args:
18 |
19 | `num_gpus`: 使用的GPU编号列表 (the list of GPU numbers)
20 |
21 |         `model_from_pretrained`: 预训练模型的路径 (the path of the pretrained model)
22 | '''
23 | self.num_gpus = num_gpus
24 | self.model_from_pretrained = model_from_pretrained
25 | self.model_init()
26 |
27 | def model_init(self):
28 | self.config = AutoConfig.from_pretrained(
29 | self.model_from_pretrained, trust_remote_code=True)
30 | self.tokenizer = AutoTokenizer.from_pretrained(
31 | self.model_from_pretrained, padding_side="left", trust_remote_code=True)
32 | if hasattr(self.config, 'eos_token_id'):
33 | self.eos_token_id = [self.config.eos_token_id]
34 | if hasattr(self.config, 'bos_token_id'):
35 | self.bos_token_id = [self.config.bos_token_id]
36 | if self.config.model_type == 'chatglm':
37 | self.model = AutoModel.from_pretrained(
38 | self.model_from_pretrained, trust_remote_code=True).to(torch.bfloat16)
39 | self.eos_token_id = self.config.eos_token_id
40 | elif self.config.model_type == 'llama':
41 | self.tokenizer.pad_token = self.tokenizer.eos_token
42 | self.model = LlamaForCausalLM.from_pretrained(
43 | self.model_from_pretrained, trust_remote_code=True).to(torch.bfloat16)
44 | terminators = [
45 | self.tokenizer.eos_token_id,
46 | self.tokenizer.convert_tokens_to_ids("<|eot_id|>")
47 | ]
48 | if not hasattr(self, 'eos_token_id'):
49 | self.eos_token_id = []
50 | for t in terminators:
51 | if t is not None:
52 | self.eos_token_id.append(t)
53 | elif self.config.model_type == 'qwen':
54 | self.model = AutoModelForCausalLM.from_pretrained(
55 | self.model_from_pretrained, torch_dtype="auto", device_map="auto", trust_remote_code=True)
56 | elif self.config.model_type == 'qwen2':
57 | self.model = AutoModelForCausalLM.from_pretrained(
58 | self.model_from_pretrained, torch_dtype="auto", device_map="auto", trust_remote_code=True)
59 | elif self.config.model_type == "mimo":
60 | self.eos_token_id = self.config.eos_token_id
61 | self.bos_token_id = self.config.bos_token_id
62 |             self.model = AutoModelForCausalLM.from_pretrained(self.model_from_pretrained, device_map="auto", trust_remote_code=True)
63 | elif self.config.model_type == "tinyr1":
64 | self.eos_token_id = self.config.eos_token_id
65 | self.bos_token_id = self.config.bos_token_id
66 |             self.model = AutoModelForCausalLM.from_pretrained(self.model_from_pretrained, device_map="auto", trust_remote_code=True)
67 | self.model_to_device(gpu=self.num_gpus)
68 | self.model = self.model.eval()
69 |
70 | def model_to_device(self, gpu=[0]):
71 | self.device = torch.device(
72 | "cuda:0" if torch.cuda.is_available() else "cpu")
73 | self.model.cuda()
74 | self.model = torch.nn.DataParallel(self.model, device_ids=gpu).cuda()
75 | self.model.to(self.device)
76 | self.true_model = self.model.module if hasattr(
77 | self.model, 'module') else self.model
78 |
79 | def process_model_outputs(self, inputs, outputs, tokenizer):
80 | responses = []
81 | for input_ids, output_ids in zip(inputs['input_ids'], outputs):
82 | response = tokenizer.decode(output_ids[len(input_ids):], skip_special_tokens=True).strip()
83 | responses.append(response)
84 | return responses
85 |
86 | def build_chat_input(self, query:str, history=None):
87 | if history is None:
88 | history = []
89 | history.append(query)
90 | max_input_tokens = 0
91 | new_batch_input = self.tokenizer.apply_chat_template(history, add_generation_prompt=True, tokenize=False)
92 | max_input_tokens = max(max_input_tokens, len(new_batch_input))
93 | return new_batch_input, max_input_tokens
94 |
95 | def predict(self, query: str | list = '', history: List = None, max_length=512, max_new_tokens=512, num_beams:int=1, top_p: float = 0.8, temperature=1.0, do_sample: bool = False, build_message=False):
96 | if not isinstance(query, list):
97 | query = [query]
98 | history = [history] if history is not None else None
99 | with torch.no_grad():
100 | if build_message:
101 | inputs = []
102 | batch_max_len = 0
103 | for i, t in enumerate(query):
104 | if isinstance(t, str):
105 | t = {'role': 'user', 'content': t}
106 | if history is not None and len(history) > 0:
107 | h_unit = history[i]
108 | else:
109 | h_unit = []
110 | t, max_input_tokens = self.build_chat_input(t, h_unit)
111 | if batch_max_len < max_input_tokens:
112 | batch_max_len = max_input_tokens
113 | inputs.append(t)
114 | else:
115 | inputs = query
116 | batch_max_len = 0
117 | for i in range(len(query)):
118 | if len(query[i]) > batch_max_len:
119 | batch_max_len = len(query[i])
120 | batched_inputs = self.tokenizer(
121 | inputs,
122 | return_tensors="pt",
123 | padding=True,
124 | truncation=True).to(self.device)
125 | if self.config.model_type == 'llama':
126 | batched_inputs = batched_inputs.data
127 | batched_outputs = self.true_model.generate(**batched_inputs, **{
128 | 'max_new_tokens': max_new_tokens,
129 | 'num_beams': num_beams,
130 | 'do_sample': do_sample,
131 | 'top_p': top_p,
132 | "temperature": temperature,
133 | "eos_token_id": self.eos_token_id
134 | })
135 | batched_response = self.process_model_outputs(batched_inputs, batched_outputs, self.tokenizer)
136 | return batched_response
137 |
138 | def __call__(self, query: str | list = '', history: List = None, max_length=512, max_new_tokens=512, num_beams:int=1, top_p: float = 0.8, temperature=1.0, do_sample: bool = False, build_message=False):
139 | return self.predict(query, history, max_length, max_new_tokens, num_beams, top_p, temperature, do_sample, build_message)
140 |
--------------------------------------------------------------------------------
/main/loader.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import random
4 | from torch.utils.data import TensorDataset, DataLoader, Dataset
5 | from main.loaders.llm_pure_text import LLMPureTextDataset
6 | from main.loaders.chatglm_chat import ChatGLM_ChatDataset
7 | from main.loaders.qwen_chat import QwenChatDataset
8 | from main.loaders.llm_chat import LLMChatDataset
9 | from main.loaders.chatglm_rlhf import ChatGLM_RLHFDataset
10 | import torch
11 |
12 | def collate_fn_wrapper(tokenizer):
13 | def left_pad_collate_fn(batch):
14 | result = {}
15 | max_length = 0
16 | max_length_without_last_turn = 0
17 | for item in batch:
18 | if item['input_ids'].shape[0] > max_length:
19 | max_length = item['input_ids'].shape[0]
20 | if 'input_ids_without_last_turn' in item and item['input_ids_without_last_turn'].shape[0] > max_length_without_last_turn:
21 | max_length_without_last_turn = item['input_ids_without_last_turn'].shape[0]
22 | for item in batch:
23 | for key in item:
24 | if key not in result:
25 | result[key] = []
26 |                 if key == 'input_ids':
27 |                     pad_length = max_length - len(item[key])
28 |                     item[key] = torch.cat([torch.LongTensor([tokenizer.pad_token_id] * pad_length), item[key]], dim=-1)
29 |                 elif key == 'labels':
30 |                     pad_length = max_length - len(item[key])
31 |                     item[key] = torch.cat([torch.LongTensor([-100] * pad_length), item[key]], dim=-1)
32 |                 elif key == 'input_ids_without_last_turn':
33 |                     pad_length = max_length_without_last_turn - len(item[key])
34 |                     item[key] = torch.cat([torch.LongTensor([tokenizer.pad_token_id] * pad_length), item[key]], dim=-1)
35 |                 elif key == 'labels_without_last_turn':
36 |                     pad_length = max_length_without_last_turn - len(item[key])
37 |                     item[key] = torch.cat([torch.LongTensor([-100] * pad_length), item[key]], dim=-1)
38 | result[key].append(item[key])
39 | for key in result:
40 | if key in ('input_ids', 'labels', 'input_ids_without_last_turn', 'labels_without_last_turn', 'last_input_len'):
41 | result[key] = torch.stack(result[key])
42 | return result
43 | return left_pad_collate_fn
44 |
45 | class AutoDataloader():
46 |
47 | '''
48 |     loader_name: str; the name of the dataloader to use
49 |     data_path: str or dict; the path of the data; if str, it is looked up as a preset dataset name in data_present_path, otherwise define the paths explicitly, e.g. { 'train': './train.json', 'dev': './dev.json' }
50 |     model_type: interactive or siamese
51 |     data_present_path: str; the path of the data-present file, a JSON file that maps each dataset name to its file paths, e.g. { 'dataset_name': {'train': './train.json', 'dev': './dev.json'} }
52 |     max_length: int; the maximum (padded) sequence length
53 | '''
54 |
55 | def __init__(self, tokenizer, config, loader_name='LLM_Chat', data_path="Boss", data_present_path="./data/present.json", max_length=50):
56 | self.tokenizer = tokenizer
57 | self.loader_name = loader_name
58 | self.max_length = max_length
59 | self.data_present = self.get_data_present(data_present_path)
60 | self.data_path = self.data_present[data_path] if data_path in self.data_present else data_path
61 | if loader_name == 'LLM_Pure':
62 | self.train_set = LLMPureTextDataset(
63 | tokenizer, config, self.data_path['train'], max_length=self.max_length, do_shuffle=True)
64 | self.eval_set = LLMPureTextDataset(
65 | tokenizer, config, self.data_path['dev'], max_length=self.max_length, do_shuffle=False)
66 | if 'test' in self.data_path:
67 | self.test_set = LLMPureTextDataset(
68 | tokenizer, config, self.data_path['test'], max_length=self.max_length, do_shuffle=False)
69 | elif loader_name == 'ChatGLM_Chat':
70 | self.train_set = ChatGLM_ChatDataset(
71 | tokenizer, config, self.data_path['train'], max_length=self.max_length, do_shuffle=True)
72 | self.eval_set = ChatGLM_ChatDataset(
73 | tokenizer, config, self.data_path['dev'], max_length=self.max_length, do_shuffle=False)
74 | if 'test' in self.data_path:
75 | self.test_set = ChatGLM_ChatDataset(
76 | tokenizer, config, self.data_path['test'], max_length=self.max_length, do_shuffle=False)
77 | elif loader_name == 'Qwen_Chat':
78 | self.train_set = QwenChatDataset(
79 | tokenizer, config, self.data_path['train'], max_length=self.max_length, do_shuffle=True)
80 | self.eval_set = QwenChatDataset(
81 | tokenizer, config, self.data_path['dev'], max_length=self.max_length, do_shuffle=False)
82 | if 'test' in self.data_path:
83 | self.test_set = QwenChatDataset(
84 | tokenizer, config, self.data_path['test'], max_length=self.max_length, do_shuffle=False)
85 | elif loader_name == 'LLM_Chat':
86 | self.train_set = LLMChatDataset(
87 | tokenizer, config, self.data_path['train'], max_length=self.max_length, do_shuffle=True)
88 | self.eval_set = LLMChatDataset(
89 | tokenizer, config, self.data_path['dev'], max_length=self.max_length, do_shuffle=False)
90 | if 'test' in self.data_path:
91 | self.test_set = LLMChatDataset(
92 | tokenizer, config, self.data_path['test'], max_length=self.max_length, do_shuffle=False)
93 | elif loader_name == 'ChatGLM_RLHF':
94 | self.train_set = ChatGLM_RLHFDataset(
95 | tokenizer, config, self.data_path['train'], max_length=self.max_length, do_shuffle=True)
96 | self.eval_set = ChatGLM_RLHFDataset(
97 | tokenizer, config, self.data_path['dev'], max_length=self.max_length, do_shuffle=False)
98 | if 'test' in self.data_path:
99 | self.test_set = ChatGLM_RLHFDataset(
100 | tokenizer, config, self.data_path['test'], max_length=self.max_length, do_shuffle=False)
101 |
102 | def get_data_present(self, present_path):
103 | if not os.path.exists(present_path):
104 | return {}
105 | with open(present_path, encoding='utf-8') as f:
106 | present_json = f.read()
107 | data_present = json.loads(present_json)
108 | return data_present
109 |
110 | def __call__(self, batch_size=1, batch_size_eval=1, eval_mode='dev', use_collate=False):
111 | if not use_collate:
112 | dataiter = DataLoader(self.train_set, batch_size=batch_size)
113 | if eval_mode == 'dev':
114 | dataiter_eval = DataLoader(
115 | self.eval_set, batch_size=batch_size_eval)
116 | else:
117 | dataiter_eval = DataLoader(
118 | self.test_set, batch_size=batch_size_eval)
119 | else:
120 | left_pad_collate_fn = collate_fn_wrapper(self.tokenizer)
121 | dataiter = DataLoader(self.train_set, batch_size=batch_size, collate_fn=left_pad_collate_fn)
122 | if eval_mode == 'dev':
123 | dataiter_eval = DataLoader(
124 | self.eval_set, batch_size=batch_size_eval, collate_fn=left_pad_collate_fn)
125 | else:
126 | dataiter_eval = DataLoader(
127 | self.test_set, batch_size=batch_size_eval, collate_fn=left_pad_collate_fn)
128 | return dataiter, dataiter_eval
129 |
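A minimal `present.json` entry and the corresponding loader call might look like the sketch below (dataset name, file paths, and batch sizes are placeholders):

# Illustrative ./data/present.json content, following the format described in the docstring above.
present = {
    "Boss": {
        "train": "./data/boss_train.jsonl",
        "dev": "./data/boss_dev.jsonl",
        "test": "./data/boss_test.jsonl"
    }
}
# With such a file in place, the loader is constructed and iterated as:
# loader = AutoDataloader(tokenizer, config, loader_name='LLM_Chat', data_path='Boss', max_length=1024)
# train_iter, eval_iter = loader(batch_size=2, batch_size_eval=2, use_collate=True)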
--------------------------------------------------------------------------------
/main/trainer/llm_lora.py:
--------------------------------------------------------------------------------
1 | import os
2 | import uuid
3 | import torch
4 | import torch.nn as nn
5 | import torch.optim as optim
6 | import torch.nn.functional as F
7 | from transformers import AutoModel, AutoConfig, LlamaForCausalLM, AutoModelForCausalLM
8 | from transformers import get_linear_schedule_with_warmup
9 | from peft import get_peft_model, LoraConfig, TaskType, PeftModel
10 | import numpy as np
11 | import jieba
12 | from nltk.translate.bleu_score import SmoothingFunction, sentence_bleu
13 | from rouge_chinese import Rouge
14 | from tqdm import tqdm
15 | from main.loader import AutoDataloader
16 | from main.analysis import Analysis
17 | from accelerate import Accelerator
18 | accelerator = Accelerator()
19 |
20 |
21 | class Trainer():
22 |
23 | def __init__(self, tokenizer, from_pretrained, loader_name, data_path, config=None, resume_path=None, max_length=512, batch_size=1, batch_size_eval=1,
24 | lora_r=16, lora_alpha=32, lora_dropout=0.1,
25 | eval_mode='dev', task_name='Sim'):
26 | self.tokenizer = tokenizer
27 | self.config = config
28 | self.accelerate = accelerator
29 | self.loader_name = loader_name
30 | self.data_path = data_path
31 | self.from_pretrained = from_pretrained
32 | self.data_path = data_path
33 | self.task_name = task_name
34 | self.max_length = max_length
35 | self.batch_size = batch_size
36 | self.batch_size_eval = batch_size_eval
37 | self.lora_r = lora_r
38 | self.lora_alpha = lora_alpha
39 | self.lora_dropout = lora_dropout
40 | self.eval_mode = eval_mode
41 | self.config_init()
42 | self.dataloader_init()
43 | self.model_init(resume_path=resume_path)
44 | self.analysis = Analysis()
45 |
46 | def config_init(self):
47 | self.config = AutoConfig.from_pretrained(
48 | self.from_pretrained, trust_remote_code=True) if self.config is None else self.config
49 | if self.config.model_type == 'llama':
50 | self.tokenizer.pad_token = self.tokenizer.eos_token
51 |
52 | def model_init(self, resume_path=None):
53 | if self.accelerate.is_local_main_process:
54 |             print('Chosen base model: {}\n'.format(
55 | self.from_pretrained))
56 | if self.config.model_type == 'chatglm':
57 | target_modules=['query_key_value']
58 | self.model = AutoModel.from_pretrained(
59 | self.from_pretrained, trust_remote_code=True).to(torch.bfloat16)
60 | elif self.config.model_type == 'llama':
61 | target_modules=["q_proj", "k_proj", "v_proj"]
62 | self.model = LlamaForCausalLM.from_pretrained(
63 | self.from_pretrained, trust_remote_code=True).to(torch.bfloat16)
64 | elif self.config.model_type == 'qwen':
65 | target_modules=["c_attn", "c_proj", "w1", "w2"]
66 | self.model = AutoModelForCausalLM.from_pretrained(
67 | self.from_pretrained, torch_dtype="auto", device_map="auto", trust_remote_code=True)
68 | elif self.config.model_type == 'qwen2':
69 | target_modules=["q_proj", "k_proj", "v_proj"]
70 | self.model = AutoModelForCausalLM.from_pretrained(
71 | self.from_pretrained, torch_dtype="auto", device_map="auto", trust_remote_code=True)
72 | peft_config = LoraConfig(
73 | task_type=TaskType.CAUSAL_LM,
74 | inference_mode=False,
75 | r=self.lora_r,
76 | target_modules=target_modules,
77 | lora_alpha=self.lora_alpha,
78 | lora_dropout=self.lora_dropout
79 | )
80 | if resume_path is not None:
81 | print('Accessing Resume PATH: {} ...\n'.format(resume_path))
82 | self.model.enable_input_require_grads()
83 | self.model = PeftModel.from_pretrained(
84 | self.model, resume_path, config=peft_config)
85 | else:
86 | self.model = get_peft_model(self.model, peft_config)
87 | self.model.print_trainable_parameters()
88 |
89 | def dataloader_init(self):
90 | d = AutoDataloader(self.tokenizer, self.config, loader_name=self.loader_name, data_path=self.data_path,
91 | max_length=self.max_length)
92 | self.train_loader, self.eval_loader = d(
93 | self.batch_size, self.batch_size_eval, self.eval_mode, True)
94 |
95 | def __call__(self, resume_step=None, num_epochs=30, lr=1e-4, eval_call_epoch=None):
96 | return self.train(resume_step=resume_step,
97 | num_epochs=num_epochs, lr=lr, eval_call_epoch=eval_call_epoch)
98 |
99 | def train(self, resume_step=None, num_epochs=30, lr=1e-4, eval_call_epoch=None):
100 |
101 | optimizer = optim.Adam(self.model.parameters(), lr=lr, weight_decay=0.)
102 |         scheduler = get_linear_schedule_with_warmup(optimizer, 190, 80000)  # warmup steps and total training steps are hard-coded here
103 | self.model, optimizer, train_loader, scheduler = self.accelerate.prepare(
104 | self.model, optimizer, self.train_loader, scheduler)
105 |
106 | current_uid = str(uuid.uuid1()).split('-')[0]
107 |
108 | train_step = resume_step if resume_step is not None else 0
109 | for epoch in range(num_epochs):
110 | train_count = 0
111 | train_loss = 0
112 | eval_scores = {
113 | 'rouge-1': 0,
114 | 'rouge-2': 0,
115 | 'rouge-l': 0,
116 | 'bleu-4': 0
117 | }
118 |
119 | train_iter = tqdm(train_loader)
120 | self.model.train()
121 |
122 | for it in train_iter:
123 |
124 | output = self.model(**it)
125 | loss = output.loss
126 | loss = loss.mean()
127 | # loss.backward()
128 | self.accelerate.backward(loss)
129 | optimizer.step()
130 | scheduler.step()
131 | self.model.zero_grad()
132 |
133 | logits = output.logits
134 | labels = it['labels']
135 | metrics = self.compute_metrics(logits, labels)
136 | for k, v in metrics.items():
137 | eval_scores[k] += v
138 |
139 | train_loss += loss.data.item()
140 | train_count += 1
141 | train_step += 1
142 |
143 | train_iter.set_description(
144 | 'Train: {}/{}'.format(epoch + 1, num_epochs))
145 | train_iter.set_postfix(
146 | train_loss=train_loss / train_count, **{k: v / train_count for k, v in eval_scores.items()})
147 |
148 | self.analysis.append_train_record({
149 | 'epoch': epoch + 1,
150 | 'train_loss': train_loss / train_count,
151 | **{k: v / train_count for k, v in eval_scores.items()}
152 | })
153 |
154 | model_uid = self.save_model(train_step)
155 | if eval_call_epoch is None or eval_call_epoch(epoch):
156 | self.eval(epoch)
157 |
158 | self.analysis.save_all_records(
159 | uid=current_uid if self.task_name is None else self.task_name)
160 | yield (epoch, self.analysis.train_record, self.analysis.eval_record, self.analysis.model_record, model_uid)
161 |
162 | @accelerator.on_local_main_process
163 | def save_model(self, current_step=0):
164 | if self.task_name is None:
165 | dir = 'undefined'
166 | else:
167 | dir = self.task_name
168 | save_path = f'./save_model/{dir}/ChatGLM_{current_step}'
169 | if not os.path.exists(save_path):
170 | os.makedirs(save_path)
171 | save_model = self.accelerate.unwrap_model(self.model)
172 | save_model.save_pretrained(
173 | save_path,
174 | is_main_process=self.accelerate.is_main_process,
175 | save_function=self.accelerate.save,
176 | )
177 | self.analysis.append_model_record(current_step)
178 | return current_step
179 |
180 | def eval(self, epoch, pure_eval=False):
181 | if pure_eval:
182 | self.model = self.accelerate.prepare_model(self.model)
183 | self.eval_loader = self.accelerate.prepare_data_loader(
184 | self.eval_loader)
185 |
186 | with torch.no_grad():
187 | eval_count = 0
188 | eval_loss = 0
189 | eval_scores = {
190 | 'rouge-1': 0,
191 | 'rouge-2': 0,
192 | 'rouge-l': 0,
193 | 'bleu-4': 0
194 | }
195 |
196 | eval_iter = tqdm(self.eval_loader)
197 | self.model.eval()
198 |
199 | for it in eval_iter:
200 |
201 | output = self.model(**it)
202 | loss = output.loss
203 | loss = loss.mean()
204 |
205 | logits = output.logits
206 | labels = it['labels']
207 | metrics = self.compute_metrics(logits, labels)
208 | for k, v in metrics.items():
209 | eval_scores[k] += v
210 |
211 | eval_loss += loss.data.item()
212 | eval_count += 1
213 |
214 | eval_iter.set_description(
215 | f'Eval: {epoch + 1}')
216 | eval_iter.set_postfix(
217 | eval_loss=eval_loss / eval_count, **{k: v / eval_count for k, v in eval_scores.items()})
218 |
219 | self.analysis.append_eval_record({
220 | 'epoch': epoch + 1,
221 | 'eval_loss': eval_loss / eval_count,
222 | **{k: v / eval_count for k, v in eval_scores.items()}
223 | })
224 |
225 | def compute_metrics(self, logits, labels):
226 | shift_logits = logits[..., :-1, :]
227 | pred_logits = shift_logits.argmax(-1)
228 | pred_logits = pred_logits.tolist()
229 | shift_labels = labels[..., 1:].tolist()
230 |
231 | metrics_dct = {'rouge-1': [], 'rouge-2': [],
232 | 'rouge-l': [], 'bleu-4': []}
233 | for pred_ids, label_ids in zip(pred_logits, shift_labels):
234 | try:
235 | answer_idx = 0
236 | for i in range(len(label_ids)):
237 | if label_ids[i] != -100:
238 | answer_idx = i
239 | break
240 | pred_ids = pred_ids[answer_idx:]
241 | label_ids = label_ids[answer_idx:]
242 | pred_txt = self.tokenizer.decode(pred_ids).strip()
243 | label_txt = self.tokenizer.decode(label_ids).strip()
244 | pred_tokens = list(jieba.cut(pred_txt))
245 | label_tokens = list(jieba.cut(label_txt))
246 | rouge = Rouge()
247 | scores = rouge.get_scores(
248 | ' '.join(pred_tokens), ' '.join(label_tokens))
249 | for k, v in scores[0].items():
250 | metrics_dct[k].append(round(v['f'] * 100, 4))
251 | metrics_dct['bleu-4'].append(
252 | sentence_bleu(
253 | [label_tokens],
254 | pred_tokens,
255 | smoothing_function=SmoothingFunction().method3,
256 | )
257 | )
258 | except:
259 | continue
260 | return {k: np.mean(v) if len(v) > 0 else 0 for k, v in metrics_dct.items()}
261 |
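A fine-tuning sketch under the usual assumptions (placeholder model path, a dataset key registered in `present.json`, and launching through `accelerate launch` for multi-GPU runs):

from transformers import AutoTokenizer

# Hypothetical setup; the checkpoint path and dataset key are placeholders.
tokenizer = AutoTokenizer.from_pretrained('/path/to/base-llm', trust_remote_code=True)
trainer = Trainer(
    tokenizer=tokenizer,
    from_pretrained='/path/to/base-llm',
    loader_name='LLM_Chat',
    data_path='Boss',
    max_length=1024,
    batch_size=2,
    task_name='sas_lora',
)
# train() is a generator: it yields once per epoch after saving a LoRA checkpoint.
for epoch, train_rec, eval_rec, model_rec, model_uid in trainer(num_epochs=3, lr=1e-4):
    print(epoch, model_uid)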
--------------------------------------------------------------------------------
/sas_pipelines/1_predict_scores.py:
--------------------------------------------------------------------------------
1 | # %%
2 | import os
3 | import json
4 | import random
5 | import json_repair
6 | import numpy as np
7 | from tqdm import tqdm
8 | from copy import deepcopy
9 | from argparse import ArgumentParser
10 |
11 | import sys
12 | sys.path.append("../")
13 | cmd_args = True
14 | # Add params n_gpu
15 | parser = ArgumentParser()
16 | parser.add_argument('--n_gpu', default=0, help='CUDA_VISIBLE_DEVICES')
17 | parser.add_argument('--file_dir', default='../datasets', help='the directory of the datasets.')
18 | parser.add_argument('--file_name', default='9_English_gapfilling', help='file name of the dataset without extension (e.g. 0_Physics_ShortAns)')
19 | parser.add_argument('--llm_name', default='', help='the prefix name of the save dir (usually the LLM name)')
20 | parser.add_argument('--save_type_name', default='GLM4', help='the prefix name of the save dir (usually the LLM name)')
21 | parser.add_argument('--few_shot_num', default=0, help='the number of few-shot samples')
22 | parser.add_argument('--use_guideline', default='1', help='whether to use the scoring guideline')
23 | parser.add_argument('--model_from_pretrained', default='/home/lpc/models/glm-4-9b-chat/', help='model from pretrained')
24 | parser.add_argument('--vllm', default='0', help='whether use vllm')
25 | parser.add_argument('--tensor_parallel_size', default=1, help='tensor_parallel_size (TP) for vLLM')
26 | parser.add_argument('--max_new_tokens', default=1024, help='max new tokens')
27 | parser.add_argument('--do_sample', default='0', help='do_sample, ignored when using vLLM')
28 | parser.add_argument('--temperature', default=0.6, help='temperature; if temperature > 0 it takes effect with vLLM.')
29 | parser.add_argument('--top_p', default=0.95, help='top_p; if top_p < 1.0 it takes effect with vLLM')
30 | parser.add_argument('--skip_thinking', default='0', help='skip deep thinking in RL model with
--------------------------------------------------------------------------------