├── image
│   ├── init
│   ├── few-shot.png
│   ├── logo_00.png
│   ├── zero-shot.png
│   └── TCMBench_logo.png
├── TCMBench_code
│   ├── init
│   ├── sari.py
│   └── explain_evaluation.py
├── pipline
│   ├── __pycache__
│   │   ├── Model_API.cpython-310.pyc
│   │   └── bench_function.cpython-310.pyc
│   ├── choice_bench.py
│   ├── Model_API.py
│   ├── Acc.py
│   └── bench_function.py
├── README_Chinese.md
└── README.md
/image/init:
--------------------------------------------------------------------------------
1 | Folder for chart and figure images~
2 |
--------------------------------------------------------------------------------
/TCMBench_code/init:
--------------------------------------------------------------------------------
1 | The model files required by the metrics are available at https://huggingface.co/WJing123/TCMBench_code
2 |
--------------------------------------------------------------------------------
/image/few-shot.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ywjawmw/TCMBench/HEAD/image/few-shot.png
--------------------------------------------------------------------------------
/image/logo_00.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ywjawmw/TCMBench/HEAD/image/logo_00.png
--------------------------------------------------------------------------------
/image/zero-shot.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ywjawmw/TCMBench/HEAD/image/zero-shot.png
--------------------------------------------------------------------------------
/image/TCMBench_logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ywjawmw/TCMBench/HEAD/image/TCMBench_logo.png
--------------------------------------------------------------------------------
/pipline/__pycache__/Model_API.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ywjawmw/TCMBench/HEAD/pipline/__pycache__/Model_API.cpython-310.pyc
--------------------------------------------------------------------------------
/pipline/__pycache__/bench_function.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ywjawmw/TCMBench/HEAD/pipline/__pycache__/bench_function.cpython-310.pyc
--------------------------------------------------------------------------------
/pipline/choice_bench.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import os
3 | parent_path = os.path.dirname(sys.path[0])
4 | print(parent_path)
5 | if parent_path not in sys.path:
6 | sys.path.append(parent_path)
7 | # from LLAMAAPI import LlamaAPI
8 | from Model_API import API
9 | from bench_function import export_distribute_json, export_union_json
10 | import json
11 | import argparse
12 |
13 | def parse_args():
14 | parser = argparse.ArgumentParser(description="parameter of LLMs")
15 | parser.add_argument(
16 | "--data_path",
17 | type=str,
18 | default='../data/first_level', # 注意按照三个不同的级别进行选择不同的测试任务
19 | help="测试数据",
20 | )
21 | parser.add_argument(
22 | "--model_name",
23 | type=str,
24 | default='gpt-4-0613',
25 | help="The LLM name.",
26 | )
27 | parser.add_argument(
28 | "--sys_prompt",
29 | type=str,
30 | default='FKU.json',
31 | help="选择不同测试题类型的指令.",
32 | )
33 | parser.add_argument(
34 | "--start_num",
35 | type=int,
36 | default=0,
37 | help="保存文档的起始id",
38 | )
39 | args = parser.parse_args()
40 | return args
41 |
42 | # Main test entry: one question + one answer --> bench_function
43 |
44 | if __name__ == "__main__":
45 | args = parse_args()
46 | with open(f"{args.data_path}/{args.sys_prompt}", "r", encoding="utf-8") as f:
47 | data = json.load(f)
48 | f.close()
49 | directory = args.data_path
50 | model_name = args.model_name
51 | api_key = "XXX"  # for closed-source models, set your API key here; otherwise leave it empty
52 | api = API(api_key, model_name=model_name)
53 | # keyword = data[i]['keyword']
54 | question_type = data['type']
55 | zero_shot_prompt_text = data['prefix_prompt']
56 | # print(question_type)
57 | print(question_type)
58 | export_distribute_json(
59 | api,
60 | model_name,
61 | directory,
62 | zero_shot_prompt_text,
63 | question_type,
64 | args,
65 | parallel_num=len(data['example']),
66 | )
67 |
68 | export_union_json(
69 | directory,
70 | model_name,
71 | zero_shot_prompt_text,
72 | question_type
73 | )
74 |
75 |
76 |
77 | # export JAVA_HOME=/usr/lib/jvm/java-11-openjdk-amd64
78 | # export JRE_HOME=${JAVA_HOME}/jre
79 | # export CLASSPATH=.:${JAVA_HOME}/lib:${JRE_HOME}/lib
80 | # export PATH=${JAVA_HOME}/bin:$PATH
81 |
82 |
--------------------------------------------------------------------------------
/README_Chinese.md:
--------------------------------------------------------------------------------
1 |
2 | ## TCMBench: Benchmarking Large Language Models in Traditional Chinese Medicine from Knowledge to Clinical Reasoning
3 | Repo for TCMBench (“ShuzhiQihuang” LLMs series, the first comprehensive benchmark for evaluating LLMs in TCM)
4 |
5 | [**English**](./README.md) | [**中文**](./README_Chinese.md)
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 | ## 更新
18 |
19 | 💥 **TCMBench V2.0**来啦!这次除了加入能体现中医多标准多因素的动态临床推理过程的测试题目外,还新生成了加入推理扰动的新问题,构成了三层不同难度的测评任务,共13个子任务!
20 |
21 | 🚀 论文初始版本已经公开,欢迎引用,❗ 拒绝一切抄袭行为(微笑.jpg).
22 |
23 | ## ⚡ 简介
24 | 为了进一步有效、准确的评估大模型在中医药领域的表现,我们现建立了一个标准化、综合性的中医评测框架**TCMBench**,该评测框架将充分考虑中医药领域的复杂性和专业性,涵盖多个方面,以确保大语言模型在真实场景下的实用性和适用性。
25 |
26 |
27 | ## 📚 数据集:TCMEval
28 | 我们构建了首个中医评测数据集TCMEval。为了客观、真实地反映中医领域的知识体系与临床推理特点,该数据集以中医执业医师资格考试的高质量模拟题为数据来源,共包含6,482组问答样本,其中1,300组配有官方给定的标准解析文本,用于评估大语言模型的生成质量。所有数据不涉及个人隐私,内容聚焦于中医知识与临床两方面。在中医专家的指导下,我们对原始题目进行了筛选与确认:从每个学科每类题型中随机抽取不超过100个样本,同时控制选项的均匀分布以避免数据偏斜;再由两位中医研究生进行题目确认,确保覆盖考试的全部题型与学科。通过收集、整理和标注这些数据,我们旨在提供一个全面、准确、具有代表性的中医测试基准,帮助评估和改进大语言模型在中医药领域应用的性能。
29 |
30 |
31 | **🔎 任务类型** :
32 | - 🚀 **基础知识认知任务**:最低复杂度的任务为基础知识认知任务集,共包含5,473组问答样本。该任务集依据TCMLE考试中的标准题型,细分为三类代表性任务,分别对应不同维度的知识认知能力。涵盖基础知识理解任务(FKU)、知识点横向关联任务(KHC)以及临床逻辑纵向推理任务(CVR),数据见[./data/first_level](./data/first_level)。
33 | - 🚀 **综合动态临床分析任务** :在基础知识认知任务之上,结合中医专家建议进一步构建了六类具备多标准诊疗(包括辨证论治、同病异治、异病同治任务)与环境耦合多因素特征(包括社会环境、古籍经典理解以及哲学思想掌握任务)的临床推理任务集,用于评估大语言模型在真实中医临床情境下的知识融合、逻辑建模与复杂推理能力,共883组问答样本。其中每个任务均包含不少于50个样本,以支持稳定评估,数据见[./data/second_level](./data/second_level)。
34 | - 🚀 **复杂临床决策任务** :为进一步评估大语言模型在多标准中医临床环境中的综合推理能力和稳定性,在综合动态临床分析任务的基础上设计了四类复杂临床决策任务。该任务集通过对原始推理样本进行结构重构、语义扰动和任务形式转换(包括在辨证论治、同病异治以及异病同治任务上的扰动,以及中药推荐任务),生成了1,009组全新的问答样本。该任务集系统考察了模型在高复杂度推理中的一致性建模能力、推理鲁棒性与决策稳定性。数据见[./data/third_level](./data/third_level)。
35 |
36 | ## 👨⚕️ 数据处理
37 |
38 | ### 以基础知识认知任务为例
39 | 将试题转换为结构化的测评数据,其数据格式如下所示:
40 | 基础知识理解任务:
41 | ```json
42 | {
43 | "question": "《素问·咳论》:“五脏六腑皆令人咳”,但关系最密切的是( )。\nA.心肺\nB.肺肾\nC.肺脾\nD.肺胃\nE.肺大肠",
44 | "answer": [
45 | "D"
46 | ],
47 | "analysis": "根据《素问·咳论》“此皆聚于胃,关于肺,使人多涕唾而面浮肿气逆也”可知与五脏六腑皆令人咳关系最密切的脏腑为肺胃。手太阴肺经起于中焦,还循胃口,上膈属肺。寒凉饮食入胃,导致中焦寒,寒气循手太阴肺经上入于肺中,导致肺寒,肺为娇脏,不耐寒热,外内寒邪并聚于肺,则肺失宣降,肺气上逆发生咳嗽。因此答案选D。",
48 | "knowledge_point": "中医经典",
49 | "index": 8196,
50 | "score": 1
51 | }
52 | ```
53 | 临床逻辑纵向推理任务:
54 | ```json
55 | {
56 | "share_content": "刘×,男,46岁,刻下眩晕而见头重如蒙。胸闷恶心,食少多寐,苔白腻,脉濡滑。",
57 | "question": [
58 | {
59 | "sub_question": "1).证属( )。\nA.肝阳上亢\nB.气血亏虚\nC.肾精不足\nD.痰浊中阻\nE.以上都不是\n",
60 | "answer": [
61 | "D"
62 | ],
63 | "analysis": ""
64 | },
65 | {
66 | "sub_question": "2).治法宜选( )。\nA.燥湿祛痰,健脾和胃\nB.补肾滋阴\nC.补肾助阳\nD.补养气血,健运脾胃\nE.平肝潜阳,滋养肝肾\n",
67 | "answer": [
68 | "A"
69 | ],
70 | "analysis": ""
71 | },
72 | {
73 | "sub_question": "3).方药宜选( )。\nA.右归丸\nB.左归丸\nC.半夏白术天麻汤\nD.归脾汤\nE.天麻钩藤饮\n",
74 | "answer": [
75 | "C"
76 | ],
77 | "analysis": ""
78 | }
79 | ],
80 | "knowledge_point": "中医内科学",
81 | "index": 334,
82 | "score": 1
83 | }
84 | ```
85 | 知识点横向关联任务:
86 | ```json
87 | {
88 | "share_content": "(共用备选答案)\nA.化痰息风,健脾祛湿\nB.清肺化痰,散结排脓\nC.疏风宣肺,化痰止咳\nD.清热化痰,平肝息风\nE.润肺清热,理气化痰\n",
89 | "question": [
90 | {
91 | "sub_question": "1).贝母瓜蒌散的功用是( )。",
92 | "answer": [
93 | "E"
94 | ],
95 | "analysis": ""
96 | },
97 | {
98 | "sub_question": "2).半夏白术天麻汤的功用是( )。",
99 | "answer": [
100 | "A"
101 | ],
102 | "analysis": ""
103 | }
104 | ],
105 | "knowledge_point": "方剂学",
106 | "index": 1938,
107 | "score": 1
108 | }
109 | ```
110 |
111 | ## 🧐 测评细节
112 |
113 | 我们设计了任务自适应的prompt,要求LLM回答题目,并给出答案和分析,评测框架由如下部分组成:
114 |
115 | | 文件名 | 说明 |
116 | | -------------------------- | -------------- |
117 | | [./pipline/choice_bench.py](./pipline/choice_bench.py) | 设置不同的任务,引导LLMs生成答案与解析 |
118 | | [./pipline/bench_function.py](./pipline/bench_function.py) | 测试相关函数 |
119 | | [./pipline/Acc.py](./pipline/Acc.py) | 计算准确率 |
120 | | [./pipline/Model_API.py](./pipline/Model_API.py)| 调用模型接口,以openai为例,可根据测评模型进行调整 |
121 | | [./TCMBench_code/explain_evaluation.py](./TCMBench_code/explain_evaluation.py)| 采用ROUGE-1、ROUGE-L、SARI、BERTScore、BARTScore以及我们提出的SKScore评估模型解析质量 |
122 | |[./HumanTrue.json](./HumanTrue.json)| HumanTrue数据集|
123 |
124 |
125 | 首先运行[./pipline/choice_bench.py](./pipline/choice_bench.py)测试模型,得到模型生成的答案与解析:
126 | ```
127 | 首先若有必要,请设置代理:
128 | os.environ['HTTPS_PROXY']="your proxy"
129 | 其次,若采用闭源模型则将你的Key填写到指定位置,否则置空:
130 | api_key = "your key"
131 | 然后通过设置不同的--data_path 和 --sys_prompt进行不同任务测试,设置--model_name 来调用不同的模型
132 | 使用以下命令在FKU任务对gpt-4-0613进行测试:
133 | python choice_bench.py --data_path ../data/first_level --sys_prompt FKU.json --model_name gpt-4-0613
134 | ```
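
其中 `--sys_prompt` 指向的任务文件(如 `FKU.json`)即任务数据文件本身:从 [./pipline/choice_bench.py](./pipline/choice_bench.py) 的读取方式可知,它至少包含 `type`、`prefix_prompt` 与 `example` 三个顶层字段,其中 `example` 即上文数据处理部分所示格式的题目列表。下面给出一个最小示意(字段取值仅作示例,具体指令文本请以 `./data` 下的实际文件为准):
```json
{
  "type": "FKU",
  "prefix_prompt": "<针对该任务的指令文本>",
  "example": [
    {
      "question": "……",
      "answer": ["D"],
      "analysis": "……",
      "knowledge_point": "中医经典",
      "index": 8196,
      "score": 1
    }
  ]
}
```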
135 |
136 | 通过[./pipline/Acc.py](./pipline/Acc.py),设置--data_path、--question_type、--model_name,得到不同模型在不同任务上的准确率得分。
137 | ```
138 | python Acc.py --data_path ../data/first_level --question_type FKU --model_name gpt-4-0613
139 | ```
140 |
141 | 通过[./TCMBench_code/explain_evaluation.py](./TCMBench_code/explain_evaluation.py),设置--model_name,得到不同模型在6个指标上的解析得分。
142 | ```
143 | python explain_evaluation.py --model_name gpt-4-0613
144 | ```
145 | 其中指标加载的模型见[code](https://huggingface.co/WJing123/TCMBench_code)
146 |
147 |
148 | 👨⚕️ 此外,该工作中也介绍了我们之前构建的中医大模型ShenNong,欢迎大家关注我们的中医大模型开源项目**ShenNong-TCM-LLM**:
149 |
150 | - 🚀 [ShenNong-TCM](https://github.com/ywjawmw/ShenNong-TCM-LLM) : 为推动LLM在中医药领域的发展和落地,提升LLM在中医药方面的知识与回答医学咨询的能力,我们推出了**ShenNong**中医药大规模语言模型。该模型基于[中医药指令数据集SN-QA](https://huggingface.co/datasets/michaelwzhu/ShenNong_TCM_Dataset)训练得到。
151 |
152 | 以及我们其他医疗大模型开源项目:
153 | - 🚀 [“医”心医意——智能中医传承创新辅助平台](https://github.com/ywjawmw/AI4TCM-Platform) : 数智岐黄系列平台,针对已有的中医传承平台无法覆盖全面的多模态数据这一挑战,我们构建了更全面的中西医知识图谱。其次,针对中医经验传承效率低这一挑战,我们提出了可解释的药方分析技术来挖掘处方信息,自动分析从症状到中药这一立体诊疗过程并给出分析的科学依据。同时提供了一个公平的辅助平台,让青年医师、中医学生等人群快速掌握先进的中医知识,传承经验。
154 | - 🚀 [ChatMed-Consult](https://huggingface.co/michaelwzhu/ChatMed-Consult) : 基于[中文医疗在线问诊数据集ChatMed_Consult_Dataset](https://huggingface.co/datasets/michaelwzhu/ChatMed_Consult_Dataset)的50w+在线问诊+ChatGPT回复作为训练集。模型主干为[LlaMA-7b](https://github.com/facebookresearch/llama),融合了[Chinese-LlaMA-Alpaca](https://github.com/ymcui/Chinese-LLaMA-Alpaca)的LoRA权重与中文扩展词表,然后再进行基于LoRA的参数高效微调。我们将全部代码都进行了公开;
155 | - 🚀 [PromptCBLUE中文医疗大模型评测基准](https://github.com/michael-wzhu/PromptCBLUE): 将[CBLUE](https://tianchi.aliyun.com/dataset/95414)基准进行改造为提示学习模式,形成对大模型的中文医疗知识与医疗文本处理能力的评测基准。PromptCBLUE旨在采用一个生成式大模型即可完成医疗NLP相关的各种不同任务,如病历结构化,问诊,病例文书撰写等。
156 |
157 | ## 致谢
158 |
159 | 本项目基于大模型给出的API进行开发,同时参考了大语言模型在高考试题上的测评任务,在此对相关项目和研究开发人员表示感谢。
160 |
161 | - [ChatGPT](https://openai.com/blog/chatgpt)
162 | - [ChatGLM](https://github.com/THUDM/ChatGLM-6B)
163 | - [GaoKao-Bench](https://github.com/OpenLMLab/GAOKAO-Bench)
164 |
165 |
166 | ## 引用
167 |
168 | 如果你使用了本项目的数据或者代码,请声明引用:
169 |
170 | ```bibtex
171 | @misc{yue2023tcmbench,
172 | title={TCMBench: Benchmarking Large Language Models in Traditional Chinese Medicine from Knowledge to Clinical Reasoning},
173 | author={Wenjing Yue and Ming Guan and Wei Zhu and Xiaoling Wang and Honglin Li},
174 | year={2023},
175 | publisher = {GitHub},
176 | journal = {GitHub repository},
177 | howpublished = {\url{https://github.com/ywjawmw/TCMBench}},
178 | }
179 |
180 | ```
181 |
182 | ## 团队介绍
183 |
184 | 本项目在华东师范大学计算机科学与技术学院智能知识管理与服务团队王晓玲教授,以及华东师范大学药学院院长、人工智能新药创智中心主任李洪林教授的指导下完成。
185 |
186 |
187 |
188 |
--------------------------------------------------------------------------------
/TCMBench_code/sari.py:
--------------------------------------------------------------------------------
1 | # =======================================================
2 | # SARI -- Text Simplification Tunable Evaluation Metric
3 | # =======================================================
4 | #
5 | # Author: Wei Xu (UPenn xwe@cis.upenn.edu)
6 | #
7 | # A Python implementation of the SARI metric for text simplification
8 | # evaluation in the following paper
9 | #
10 | # "Optimizing Statistical Machine Translation for Text Simplification"
11 | # Wei Xu, Courtney Napoles, Ellie Pavlick, Quanze Chen and Chris Callison-Burch
12 | # In Transactions of the Association for Computational Linguistics (TACL) 2015
13 | #
14 | # There is also a Java implementation of the SARI metric
15 | # that is integrated into the Joshua MT Decoder. It can
16 | # be used for tuning Joshua models for a real end-to-end
17 | # text simplification model.
18 | #
19 |
20 | from __future__ import division
21 | from collections import Counter
22 | import sys
23 |
24 |
25 |
26 | def ReadInFile (filename):
27 |
28 | with open(filename) as f:
29 | lines = f.readlines()
30 | lines = [x.strip() for x in lines]
31 | return lines
32 |
33 |
34 | def SARIngram(sgrams, cgrams, rgramslist, numref):
35 | rgramsall = [rgram for rgrams in rgramslist for rgram in rgrams]
36 | rgramcounter = Counter(rgramsall)
37 |
38 | sgramcounter = Counter(sgrams)
39 | sgramcounter_rep = Counter()
40 | for sgram, scount in sgramcounter.items():
41 | sgramcounter_rep[sgram] = scount * numref
42 |
43 | cgramcounter = Counter(cgrams)
44 | cgramcounter_rep = Counter()
45 | for cgram, ccount in cgramcounter.items():
46 | cgramcounter_rep[cgram] = ccount * numref
47 |
48 |
49 | # KEEP
50 | keepgramcounter_rep = sgramcounter_rep & cgramcounter_rep
51 | keepgramcountergood_rep = keepgramcounter_rep & rgramcounter # if sen 和 ref相同,那么这个值和keepgramcounter_rep是一样的,也就是计算s和c的相似(和r的相似)
52 | keepgramcounterall_rep = sgramcounter_rep & rgramcounter # if sen 和 ref相同,那么这个值和ref是一样的 (计算s和r的相似)
53 |
54 | keeptmpscore1 = 0
55 | keeptmpscore2 = 0
56 | for keepgram in keepgramcountergood_rep:
57 | keeptmpscore1 += keepgramcountergood_rep[keepgram] / keepgramcounter_rep[keepgram]
58 | keeptmpscore2 += keepgramcountergood_rep[keepgram] / keepgramcounterall_rep[keepgram]
59 | #print "KEEP", keepgram, keepscore, cgramcounter[keepgram], sgramcounter[keepgram], rgramcounter[keepgram]
60 | keepscore_precision = 0
61 | if len(keepgramcounter_rep) > 0:
62 | keepscore_precision = keeptmpscore1 / len(keepgramcounter_rep)
63 | keepscore_recall = 0
64 | if len(keepgramcounterall_rep) > 0:
65 | keepscore_recall = keeptmpscore2 / len(keepgramcounterall_rep)
66 | keepscore = 0
67 | if keepgramcounterall_rep == rgramcounter: # if sen 和 ref相同
68 | keepscore = keepscore_recall
69 | else:
70 | if keepscore_precision > 0 or keepscore_recall > 0:
71 | keepscore = 2 * keepscore_precision * keepscore_recall / (keepscore_precision + keepscore_recall) # f1 score
72 |
73 |
74 | # DELETION
75 | delgramcounter_rep = sgramcounter_rep - cgramcounter_rep # s和c的差异值
76 | delgramcountergood_rep = delgramcounter_rep - rgramcounter # s和c和r的差异值, if s=r, 这个值=0
77 | delgramcounterall_rep = sgramcounter_rep - rgramcounter # s和r的差异值, if s=r, 这个值=0
78 | deltmpscore1 = 0
79 | deltmpscore2 = 0
80 | for delgram in delgramcountergood_rep:
81 | deltmpscore1 += delgramcountergood_rep[delgram] / delgramcounter_rep[delgram]
82 | deltmpscore2 += delgramcountergood_rep[delgram] / delgramcounterall_rep[delgram]
83 | delscore_precision = 0
84 | if len(delgramcounter_rep) > 0:
85 | delscore_precision = deltmpscore1 / len(delgramcounter_rep)
86 | delscore_recall = 0
87 | if len(delgramcounterall_rep) > 0:
88 | delscore_recall = deltmpscore1 / len(delgramcounterall_rep)
89 | delscore = 0
90 | if delscore_precision > 0 or delscore_recall > 0:
91 | delscore = 2 * delscore_precision * delscore_recall / (delscore_precision + delscore_recall)
92 |
93 |
94 | # ADDITION
95 | addgramcounter = set(cgramcounter) - set(sgramcounter)
96 | addgramcountergood = set(addgramcounter) & set(rgramcounter)
97 | addgramcounterall = set(rgramcounter) - set(sgramcounter)
98 |
99 | addtmpscore = 0
100 | for addgram in addgramcountergood:
101 | addtmpscore += 1
102 |
103 | addscore_precision = 0
104 | addscore_recall = 0
105 | if len(addgramcounter) > 0:
106 | addscore_precision = addtmpscore / len(addgramcounter)
107 | if len(addgramcounterall) > 0:
108 | addscore_recall = addtmpscore / len(addgramcounterall)
109 | addscore = 0
110 | if addscore_precision > 0 or addscore_recall > 0:
111 | addscore = 2 * addscore_precision * addscore_recall / (addscore_precision + addscore_recall)
112 |
113 | return (keepscore, delscore_precision, addscore)
114 |
115 |
116 | def SARIsent (ssent, csent, rsents) :
117 | numref = len(rsents)
118 |
119 | # s1grams = ssent.lower().split(" ")
120 | # c1grams = csent.lower().split(" ")
121 | # zh
122 | s1grams = [s for s in ssent]
123 | c1grams = [c for c in csent]
124 | s2grams = []
125 | c2grams = []
126 | s3grams = []
127 | c3grams = []
128 | s4grams = []
129 | c4grams = []
130 |
131 | r1gramslist = []
132 | r2gramslist = []
133 | r3gramslist = []
134 | r4gramslist = []
135 | for rsent in rsents:
136 | # r1grams = rsent.lower().split(" ")
137 | r1grams = [r for r in rsent]
138 | r2grams = []
139 | r3grams = []
140 | r4grams = []
141 | r1gramslist.append(r1grams)
142 | for i in range(0, len(r1grams)-1) :
143 | if i < len(r1grams) - 1:
144 | # r2gram = r1grams[i] + " " + r1grams[i+1]
145 | r2gram = r1grams[i] + r1grams[i + 1]
146 | r2grams.append(r2gram)
147 | if i < len(r1grams)-2:
148 | # r3gram = r1grams[i] + " " + r1grams[i+1] + " " + r1grams[i+2]
149 | r3gram = r1grams[i] + r1grams[i + 1] + r1grams[i + 2]
150 | r3grams.append(r3gram)
151 | if i < len(r1grams)-3:
152 | # r4gram = r1grams[i] + " " + r1grams[i+1] + " " + r1grams[i+2] + " " + r1grams[i+3]
153 | r4gram = r1grams[i] + r1grams[i + 1] + r1grams[i + 2] + r1grams[i + 3]
154 | r4grams.append(r4gram)
155 | r2gramslist.append(r2grams)
156 | r3gramslist.append(r3grams)
157 | r4gramslist.append(r4grams)
158 |
159 | for i in range(0, len(s1grams)-1) :
160 | if i < len(s1grams) - 1:
161 | # s2gram = s1grams[i] + " " + s1grams[i+1]
162 | s2gram = s1grams[i] + s1grams[i + 1]
163 | s2grams.append(s2gram)
164 | if i < len(s1grams)-2:
165 | # s3gram = s1grams[i] + " " + s1grams[i+1] + " " + s1grams[i+2]
166 | s3gram = s1grams[i] + s1grams[i + 1] + s1grams[i + 2]
167 | s3grams.append(s3gram)
168 | if i < len(s1grams)-3:
169 | # s4gram = s1grams[i] + " " + s1grams[i+1] + " " + s1grams[i+2] + " " + s1grams[i+3]
170 | s4gram = s1grams[i] + s1grams[i + 1] + s1grams[i + 2] + s1grams[i + 3]
171 | s4grams.append(s4gram)
172 |
173 | for i in range(0, len(c1grams)-1) :
174 | if i < len(c1grams) - 1:
175 | # c2gram = c1grams[i] + " " + c1grams[i+1]
176 | c2gram = c1grams[i] + c1grams[i + 1]
177 | c2grams.append(c2gram)
178 | if i < len(c1grams)-2:
179 | # c3gram = c1grams[i] + " " + c1grams[i + 1] + " " + c1grams[i + 2]
180 | c3gram = c1grams[i] + c1grams[i + 1] + c1grams[i + 2]
181 | c3grams.append(c3gram)
182 | if i < len(c1grams)-3:
183 | # c4gram = c1grams[i] + " " + c1grams[i+1] + " " + c1grams[i+2] + " " + c1grams[i+3]
184 | c4gram = c1grams[i] + c1grams[i + 1] + c1grams[i + 2] + c1grams[i + 3]
185 | c4grams.append(c4gram)
186 |
187 |
188 | (keep1score, del1score, add1score) = SARIngram(s1grams, c1grams, r1gramslist, numref)
189 | (keep2score, del2score, add2score) = SARIngram(s2grams, c2grams, r2gramslist, numref)
190 | (keep3score, del3score, add3score) = SARIngram(s3grams, c3grams, r3gramslist, numref)
191 | (keep4score, del4score, add4score) = SARIngram(s4grams, c4grams, r4gramslist, numref)
192 | avgkeepscore = sum([keep1score,keep2score,keep3score,keep4score])/4
193 | avgdelscore = sum([del1score,del2score,del3score,del4score])/4
194 | avgaddscore = sum([add1score,add2score,add3score,add4score])/4
195 | finalscore = (avgkeepscore + avgdelscore + avgaddscore ) / 3
196 |
197 | return avgkeepscore, avgdelscore, avgaddscore, finalscore
198 |
199 |
200 | def main():
201 |
202 | # fnamenorm = "./turkcorpus/test.8turkers.tok.norm"
203 | # fnamesimp = "./turkcorpus/test.8turkers.tok.simp"
204 | # fnameturk = "./turkcorpus/test.8turkers.tok.turk."
205 |
206 |
207 | ssent = "About 95 species are currently accepted ."
208 | csent1 = "About 95 you now get in ."
209 | csent2 = "About 95 species are now agreed ."
210 | csent3 = "About 95 species are currently agreed ."
211 | rsents = ["About 95 species are currently accepted ."]
212 | # rsents = ["About 95 species are currently known .", "About 95 species are now accepted .", "95 species are now accepted ."]
213 |
214 | print(SARIsent(ssent, csent1, rsents))
215 | print(SARIsent(ssent, csent2, rsents))
216 | print(SARIsent(ssent, csent3, rsents))
217 |
218 |
219 | if __name__ == '__main__':
220 | main()
--------------------------------------------------------------------------------
/pipline/Model_API.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Time : 2024/1/11 17:21
3 | # @Author : YWJ
4 | # @Email : 52215901025@stu.ecnu.edu.cn
5 | # @File : Model_API.py
6 | # @Software : PyCharm
7 | # @Description : 各个模型的API接口
8 | import os
9 | import openai
10 | import requests
11 | import urllib
12 | import json
13 | import time
14 | import sys
15 | from http import HTTPStatus
16 | import dashscope
17 | import random
18 | # from vllm import LLM, SamplingParams
19 | # from transformers import AutoModelForCausalLM, AutoTokenizer
20 | from dashscope import Generation
21 | from dashscope.api_entities.dashscope_response import Role
22 |
23 | class API():
24 | def __init__(self, api_key_list: str, model_name: str = "gpt-3.5-turbo", temperature: float = 0.0,
25 | max_tokens: int = 1024):
26 | self.api_key_list = api_key_list
27 | self.model_name = model_name # 新的model, 支持1w+
28 | self.temperature = temperature
29 | self.max_tokens = max_tokens
30 | # if self.api_key_list == "":
31 | # self.llm = LLM("/data/xxx", tokenizer_mode='auto', # local model
32 | # trust_remote_code=True,
33 | # enforce_eager=True,
34 | # enable_prefix_caching=True)
35 | # self.sampling_params = SamplingParams(temperature=0.85, top_p=0.8, max_tokens=512)
36 | # self.tokenizer = AutoTokenizer.from_pretrained("/data/xxx", trust_remote_code=True)
37 |
38 | # GPT系列
39 | def send_request_turbo(self, prompt, question):
40 | """
41 | """
42 | zero_shot_prompt_message = {'role': 'system', 'content': prompt}
43 |
44 | messages = [zero_shot_prompt_message]
45 | question = sensitive(question)
46 | message = {"role": "user", "content": question}
47 | print(f"LLM的Prompt是{'*' * 100}\n{zero_shot_prompt_message['content']}\n{message['content']}")
48 | messages.append(message)
49 |
50 | output = {}
51 | while True:
52 | try:
53 | # os.environ['HTTPS_PROXY'] = "http://127.0.0.1:10809"
54 | openai.api_base = "https://xiaoai.plus/v1"
55 | openai.api_key = self.api_key_list
56 | output = openai.ChatCompletion.create(
57 | model=self.model_name,
58 | messages=messages,
59 | temperature=self.temperature
60 | )
61 | answer = revers_sensitive(output['choices'][0]['message']['content'])
62 | print(answer)
63 | return [answer]
64 | except Exception as e:
65 | print('Exception:', e)
66 | print("原始Prompt:")
67 | sys.exit()
68 |
69 | return [output]
70 |
71 | # 多轮会话
72 | def send_request_chat(self, prompt, share_content, question):
73 | """
74 | """
75 | zero_shot_prompt_message = {'role': 'system', 'content': prompt}
76 |
77 | messages = [zero_shot_prompt_message]
78 | share_content = sensitive(share_content)
79 | message = {"role": "user", "content": share_content}
80 | messages.append(message)
81 | output_chat = []
82 | i = 0
83 | error_num = 0
84 | while i < len(question):
85 | sub_question = question[i]
86 | sub_question['sub_question'] = sensitive(sub_question['sub_question'])
87 | message = {"role": "user", "content": sub_question['sub_question']}
88 | messages.append(message)
89 | # os.environ['HTTPS_PROXY'] = "http://127.0.0.1:33210"
90 | # os.environ['HTTPS_PROXY'] = "http://127.0.0.1:10809"
91 | # os.environ['OPENAI_API_KEY'] = self.api_key_list
92 | # os.environ["OPENAI_BASE_URL"] = "https://api.xiaoai.plus/v1"
93 | openai.api_base = "https://xiaoai.plus/v1"
94 | openai.api_key = self.api_key_list
95 | try:
96 | output = openai.ChatCompletion.create(
97 | model=self.model_name,
98 | messages=messages,
99 | temperature=self.temperature
100 | )
101 | answer = revers_sensitive(output['choices'][0]['message']['content'])
102 | answer = sensitive(answer)
103 | messages.append({"role": "assistant", "content": answer})
104 | output_chat.append(answer)
105 | i += 1
106 | print(i, ":", "success!")
107 | # print(output)
108 | except Exception as e:
109 | print('Exception:', e)
110 | print("原始Prompt:")
111 | for m in messages:
112 | print(m)
113 | print("—" * 100)
114 | # if "overloaded" or "Bad" in e:
115 | if "max" in e.args[0]: # 说明到了最大的token, 将上面存储的靠前的子问题删除几个
116 | time.sleep(5)
117 | if error_num == 0:
118 | if len(messages) < 13:
119 | star_index = -1 * len(messages) + 2
120 | else:
121 | star_index = -11 # 前5个
122 | else:
123 | star_index += 2 # 如果还超长,那么就不断的逐个删除子问题
124 | if star_index >= -1:
125 | print("无法处理该问题")
126 | output_chat.append("")
127 | error_num = 0
128 | i += 1
129 | print("#" * 100)
130 | messages = messages[:2] + messages[star_index: -1]
131 | print("最大token, 保留历史前几个问题")
132 | error_num = 1
133 | for m in messages:
134 | print(m)
135 | print("*" * 100)
136 | else:
137 | time.sleep(5) # 递归调用自身进行重试(i不变)
138 | print("重复提问")
139 | messages = messages[:-1]
140 | for m in messages:
141 | print(m)
142 | print("*" * 100)
143 | error_num = 0
144 | # output_chat.append({})
145 | # i += 1
146 | # print("失败,默认回答不出内容!")
147 | time.sleep(5)
148 |
149 | time.sleep(2)
150 |
151 | return output_chat
152 |
153 | # DDST hard
154 | def send_request_hard(self, prompt, option, question, option_num):
155 | """
156 | """
157 | prompt = prompt.replace("", option).replace("", str(option_num))
158 | zero_shot_prompt_message = {'role': 'system', 'content': prompt}
159 |
160 | messages = [zero_shot_prompt_message]
161 | question = sensitive(question)
162 | message = {"role": "user", "content": question}
163 | print(f"LLM的Prompt是{'*' * 100}\n{message['content']}")
164 | messages.append(message)
165 |
166 | output = {}
167 | while True:
168 | try:
169 | # os.environ['HTTPS_PROXY'] = "http://127.0.0.1:10809"
170 | openai.api_base = "https://xiaoai.plus/v1"
171 | openai.api_key = self.api_key_list
172 | output = openai.ChatCompletion.create(
173 | model=self.model_name,
174 | messages=messages,
175 | temperature=self.temperature
176 | )
177 | answer = revers_sensitive(output['choices'][0]['message']['content'])
178 | print(answer)
179 | return [answer]
180 | except Exception as e:
181 | print('Exception:', e)
182 | print("原始Prompt:")
183 | sys.exit()
184 | # herb_predict
185 | def send_request_TCM_Rec(self, prompt, question):
186 | """
187 | """
188 | zero_shot_prompt_message = {'role': 'system', 'content': "你是一个中药推荐系统,你需要根据症状信息推荐20个中药。"}
189 |
190 | messages = [zero_shot_prompt_message]
191 | question = sensitive(question)
192 | question = f"{prompt}\n症状信息为:{question}。\n"
193 | message = {"role": "user", "content": question}
194 | print(f"LLM的Prompt是{'*' * 100}\n{zero_shot_prompt_message['content']}\n{message['content']}")
195 | messages.append(message)
196 |
197 | output = {}
198 | while True:
199 | try:
200 | # os.environ['HTTPS_PROXY'] = "http://127.0.0.1:10809"
201 | openai.api_base = "https://xiaoai.plus/v1"
202 | openai.api_key = self.api_key_list
203 | output = openai.ChatCompletion.create(
204 | model=self.model_name,
205 | messages=messages,
206 | temperature=self.temperature
207 | )
208 | answer = revers_sensitive(output['choices'][0]['message']['content'])
209 | print(answer)
210 | return [answer]
211 | except Exception as e:
212 | print('Exception:', e)
213 | print("原始Prompt:")
214 | sys.exit()
215 |
216 | def sensitive(sentence):
217 | sentence = sentence.replace("阴道", "term-YD")
218 | sentence = sentence.replace("射精", "term-SJ")
219 | return sentence
220 |
221 | def revers_sensitive(sentence):
222 | sentence = sentence.replace("term-YD", "阴道")
223 | sentence = sentence.replace("term-SJ", "射精")
224 | return sentence
--------------------------------------------------------------------------------
/pipline/Acc.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # !/usr/bin/python3
3 | # -*- coding: utf-8 -*-
4 | # @Time : 2025/9/28 21:55
5 | # @Author : Wenjing
6 | # @File : Acc.py
7 | # @Desc : 计算各个任务的准确率
8 |
9 |
10 | import json
11 | import os
12 | from bench_function import test_correction_score_A12, test_correction_score_A34
13 | import matplotlib.pyplot as plt
14 | import numpy as np
15 | import argparse
16 |
17 |
18 | def parse_args():
19 | parser = argparse.ArgumentParser(description="parameter of LLMs")
20 | parser.add_argument(
21 | "--data_path",
22 | type=str,
23 | default='../data/first_level', # 注意按照三个不同的级别进行选择不同的测试任务
24 | help="测试数据",
25 | )
26 | parser.add_argument(
27 | "--model_name",
28 | type=str,
29 | default='gpt-4-0613',
30 | help="The LLM name.",
31 | )
32 | parser.add_argument(
33 | "--question_type",
34 | type=str,
35 | default='FKU',
36 | help="选择不同测试题类型.",
37 | )
38 | parser.add_argument(
39 | "--start_num",
40 | type=int,
41 | default=0,
42 | help="保存文档的起始id",
43 | )
44 | args = parser.parse_args()
45 | return args
46 |
47 | def the_first_task(data_FKU, data_CVR, data_KHC):
48 | score_A12, false_dict_A12, A12_kp_dict, score_num_A12, all_num_A12, correct_dict_A12 = test_correction_score_A12(data_FKU)
49 | score_A34, false_dict_A34, A34_kp_dict, score_num_A34, all_num_A34 = test_correction_score_A34(data_CVR)
50 | score_B1, false_dict_B1, B1_kp_dict, score_num_B1, all_num_B1 = test_correction_score_A34(data_KHC)
51 |
52 | print("基础知识认知任务测试结果")
53 | print("A1-A2题目正确率:%f \nA3-A4题目正确率:%f \nB1题目正确率:%f \n" % (score_A12, score_A34, score_B1))
54 | print("A3-A4-k题目正确率:%f" % score_A34)
55 | # print("A3-A4-k-shot题目正确率:%f" % (score_A34_k))
56 | print("总的准确率:%f",
57 | (score_num_A12 + score_num_A34 + score_num_B1) / (all_num_A12 + all_num_A34 + all_num_B1))
58 |
59 | # 创建一个字典来保存相同Key的合并后的值
60 | merged_values = {}
61 |
62 | A1_num, A3_num, B1_num = 0, 0, 0
63 | for k, v in A12_kp_dict.items():
64 | A1_num += v[1]
65 |
66 | for k, v in A34_kp_dict.items():
67 | A3_num += v[1]
68 |
69 | for k, v in B1_kp_dict.items():
70 | B1_num += v[1]
71 |
72 | print(A1_num, A3_num, B1_num)
73 |
74 | # 遍历所有相同的Key,并将它们的值合并成一个列表
75 | num = 0
76 | kp_num = 0
77 |
78 | A12_res_kp, A34_res_kp, B1_res_kp = dict(), dict(), dict()
79 | res_kp = dict()
80 |
81 | print("A12各个知识点的准确率:", "*" * 100)
82 | for key, value in A12_kp_dict.items():
83 | print(key, "\t", value[0] / value[1])
84 | A12_res_kp[key] = value[0] / value[1]
85 |
86 | print("A34各个知识点的准确率:", "*" * 100)
87 | for key, value in A34_kp_dict.items():
88 | print(key, "\t", value[0] / value[1])
89 | A34_res_kp[key] = value[0] / value[1]
90 |
91 | print("B1各个知识点的准确率:", "*" * 100)
92 | for key, value in B1_kp_dict.items():
93 | print(key, "\t", value[0] / value[1])
94 | B1_res_kp[key] = value[0] / value[1]
95 |
96 | print("总的知识点准确率:", "*" * 100)
97 | for key in A12_kp_dict.keys():
98 | # 从每个字典中获取Key对应的值的列表
99 | values1 = A12_kp_dict.get(key, [])
100 | values2 = A34_kp_dict.get(key, [])
101 | values3 = B1_kp_dict.get(key, [])
102 | if len(values1) == 0:
103 | values1 = [0, 0]
104 | if len(values2) == 0:
105 | values2 = [0, 0]
106 | if len(values3) == 0:
107 | values3 = [0, 0]
108 |
109 | # 将三个字典中相同Key的值合并成一个列表
110 | merged_values[key] = [0, 0, 0.0]
111 | merged_values[key][0] = values1[0] + values2[0] + values3[0]
112 | merged_values[key][1] = values1[1] + values2[1] + values3[1]
113 | merged_values[key][2] = merged_values[key][0] / merged_values[key][1]
114 | # print(key, ": ", merged_values[key][0]/merged_values[key][1])
115 | print(key, "\t", merged_values[key][0] / merged_values[key][1])
116 | res_kp[key] = merged_values[key][0] / merged_values[key][1]
117 | num += merged_values[key][1]
118 | print(A12_res_kp)
119 | print(A34_res_kp)
120 | print(B1_res_kp)
121 | print(res_kp)
122 |
123 | def A12_type_task(data_dict):
124 | score = 0
125 | all_num = 0
126 | for data in data_dict['example']:
127 | all_num += 1
128 | true_answer = data['standard_answer']
129 | model_answer = data['model_answer']
130 | if true_answer == model_answer:
131 | score += 1
132 | print(score / all_num)
133 |
134 | def A34_type_task(data_dict):
135 | score = 0
136 | all_num = 0
137 | for data in data_dict['example']:
138 | question = data["question"]
139 | for sub_question in question:
140 | all_num += 1
141 | standard_answer = sub_question['standard_answer']
142 | model_answer = sub_question['model_answer']
143 | if standard_answer == model_answer:
144 | score += 1
145 | print(score / all_num)
146 |
147 | # 如果两个元素的sym_set列表的交集为sym_set的长度,那么我们将这些元素分为一组,添加到一个新的列表中
148 | def new_test_gt(file_path):
149 | with open(file_path, 'r', encoding='utf-8') as f:
150 | data = json.load(f)["example"]
151 | new_data = {}
152 | visited = set()
153 | count = 0
154 | for i, d1 in enumerate(data):
155 | if i in visited:
156 | continue
157 | sym_set1 = set(d1['sym_set'])
158 | sym_set_key = "-".join(sorted(sym_set1))
159 | new_data[sym_set_key] = [d1["herb_set"]]
160 | visited.add(i)
161 | count += 1
162 | for j in range(i+1, len(data)):
163 | if j in visited:
164 | continue
165 | d2 = data[j]
166 | # 求d1和d2的sym_set的交集
167 | sym_set2 = set(d2['sym_set'])
168 | if len(sym_set1) == len(sym_set2):
169 | if len(sym_set1 & sym_set2) == len(sym_set1):
170 | new_data[sym_set_key].append(d2["herb_set"])
171 | count += 1
172 | visited.add(j)
173 | # print(count)
174 | return new_data
175 |
176 |
177 |
178 | def test_metric(gt_data, test_data):
179 | test_group_list = test_data["example"]
180 | Ks = [20]
181 | result = {'precision': np.zeros(len(Ks)), 'recall': np.zeros(len(Ks)),
182 | 'ndcg': np.zeros(len(Ks)), 'rmrr': np.zeros(len(Ks))}
183 |
184 | precision_n = np.zeros(len(Ks))
185 | recall_n = np.zeros(len(Ks))
186 | ndcg_n = np.zeros(len(Ks))
187 | rmrr_n = np.zeros(len(Ks))
188 | topN = Ks
189 |
190 | for index in range(len(test_group_list)):
191 | entry = test_group_list[index]
192 | sym_set = sorted(set(entry["sym_set"]))
193 | if "-".join(sym_set) not in gt_data.keys():
194 | print(test_group_list[index])
195 | break
196 | v_list = gt_data["-".join(sym_set)] # sym-index's true herb set list
197 | rating = entry["model_output"] # sym-index's predicted herb set list
198 | K_max = topN[len(topN) - 1]
199 | for ii in range(len(topN)): # topN: [5, 10, 15, 20]
200 | top_recall, top_precision, top_ndcg, top_rmrr, top_iou = 0., 0., 0., 0., 0.
201 | for v in v_list: # v:对应的ground truth
202 | r = []
203 | for i in rating[:K_max]:
204 | herb = i.replace("\"", "")
205 | if herb in v:
206 | r.append(1)
207 | else:
208 | r.append(0)
209 | number = 0
210 | all_list_number = 0
211 | herb_results = [] # 推荐列表中herb 集合
212 | for i in rating[:topN[ii]]:
213 | herb = i.replace("\"", "")
214 | herb_results.append(herb)
215 | if herb in v:
216 | number += 1
217 | herb_v = set(herb_results + v)
218 | all_list_number = len(herb_v)
219 | # todo: modified MRR to Rank-MRR
220 | mrr_score = 0.
221 | for a_rank in range(len(v)): # herb 在grand truth中的位置a_rank
222 | if v[a_rank] in herb_results:
223 | a_refer = herb_results.index(v[a_rank]) # herb 在推荐列表中的位置a_refer
224 | mrr_score += 1.0 / (abs(a_refer - a_rank) + 1)
225 | if float(number / topN[ii]) > top_precision: # 使用precision选择GT
226 | top_precision = float(number / topN[ii])
227 | top_recall = float(number / len(v))
228 | top_ndcg = ndcg_at_k(r, topN[ii])
229 | top_rmrr = mrr_score / len(v)
230 | precision_n[ii] = precision_n[ii] + top_precision # [ii]所有测试数据top k的precision之和
231 | recall_n[ii] = recall_n[ii] + top_recall
232 | ndcg_n[ii] = ndcg_n[ii] + top_ndcg
233 | rmrr_n[ii] = rmrr_n[ii] + top_rmrr
234 | for ii in range(len(topN)):
235 | result['precision'][ii] = precision_n[ii] / len(test_group_list)
236 | result['recall'][ii] = recall_n[ii] / len(test_group_list)
237 | result['ndcg'][ii] = ndcg_n[ii] / len(test_group_list)
238 | result['rmrr'][ii] = rmrr_n[ii] / len(test_group_list)
239 | return result
240 |
241 | def dcg_at_k(r, k, method=1):
242 | """Score is discounted cumulative gain (dcg)
243 | Relevance is positive real values. Can use binary
244 | as the previous methods.
245 | Returns:
246 | Discounted cumulative gain
247 | """
248 | r = np.asfarray(r)[:k]
249 | if r.size:
250 | if method == 0:
251 | return r[0] + np.sum(r[1:] / np.log2(np.arange(2, r.size + 1)))
252 | elif method == 1:
253 | return np.sum(r / np.log2(np.arange(2, r.size + 2)))
254 | else:
255 | raise ValueError('method must be 0 or 1.')
256 | return 0.
257 |
258 |
259 | def ndcg_at_k(r, k, method=1):
260 | """Score is normalized discounted cumulative gain (ndcg)
261 | Relevance is positive real values. Can use binary
262 | as the previous methods.
263 | Returns:
264 | Normalized discounted cumulative gain
265 | """
266 | # dcg_max = dcg_at_k(np.ones_like(r), k, method)
267 | dcg_max = dcg_at_k(sorted(r, reverse=True), k, method)
268 | if not dcg_max:
269 | return 0.
270 | return dcg_at_k(r, k, method) / dcg_max
271 |
272 | def herb_rec_test(gt_data, data):
273 | result = test_metric(gt_data, data)
274 | res_score = ""
275 | for key, value in result.items():
276 | res = [str(round(v, 4)) for v in value]
277 | print(key + ":" + ", ".join(res))
278 |
279 |
280 | if __name__ == "__main__":
281 | args = parse_args()
282 | # 第一级三个任务一起计算,因为需要计算一个总分
283 | if "first_level" in args.data_path:
284 | question_type = "FKU"
285 | with open(f"{args.data_path}/{args.model_name}_{question_type}.json", "r", encoding="utf-8") as f:
286 | data_FKU = json.load(f)
287 | f.close()
288 |
289 | question_type = "CVR"
290 | with open(f"{args.data_path}/{args.model_name}_{question_type}.json", "r", encoding="utf-8") as f:
291 | data_CVR = json.load(f)
292 | f.close()
293 |
294 | question_type = "KHC"
295 | with open(f"{args.data_path}/{args.model_name}_{question_type}.json", "r", encoding="utf-8") as f:
296 | data_KHC = json.load(f)
297 | f.close()
298 |
299 | the_first_task(data_FKU, data_CVR, data_KHC)
300 |
301 | # 利用args.question_type 来分别计算其他任务的ACC
302 | question_type = args.question_type
303 | with open(f"{args.data_path}/{args.model_name}_{question_type}/seperate_0-1.json", "r", encoding="utf-8") as f:
304 | data = json.load(f)
305 | f.close()
306 | if question_type in ["CBF", "SCF", "PF", "SDDT", "DDST", "SDDT_hard", "DDST_hard"]:
307 | A12_type_task(data)
308 | elif question_type in ["SDT", 'SDT_reverse', "SDT_shuffle"]:
309 | A34_type_task(data)
310 | elif question_type in ["herb_predict"]:
311 | gt_file_path = f"{args.data_path}/{question_type}.json"
312 | gt_data = new_test_gt(gt_file_path)
313 | herb_rec_test(gt_data, data)
314 |
315 |
316 |
317 |
318 |
319 |
320 |
321 |
322 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ## TCMBench: Benchmarking Large Language Models in Traditional Chinese Medicine from Knowledge to Clinical Reasoning
2 | Repo for TCMBench (“ShuzhiQihuang” LLMs series, the first comprehensive benchmark for evaluating LLMs in TCM)
3 |
4 | [**English**](./README.md) | [**中文**](./README_Chinese.md)
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 | ## Updates
17 |
18 | 💥 TCMBench V2.0 is here! In this version, we added test questions that reflect the multi-standard and multi-factor characteristics of dynamic clinical reasoning in TCM. We also generated new questions with reasoning perturbations, forming three levels of evaluation tasks and 13 sub-tasks in total.
19 |
20 | 🚀 The initial version of the paper has been released. Citations are welcome. ❗ All forms of plagiarism are strictly rejected (smile.jpg).
21 |
22 | ## ⚡ Introduction
23 |
24 | To further evaluate the performance of large language models (LLMs) in Traditional Chinese Medicine (TCM) more effectively and accurately, we established a standardized and comprehensive benchmark framework: TCMBench. This benchmark fully considers the complexity and domain-specific nature of TCM, covering multiple aspects to ensure the practical usability and applicability of LLMs in real-world TCM scenarios.
25 |
26 | ## 📚 Dataset: TCMEval
27 |
28 | We first constructed TCMEval, the first benchmark dataset for TCM. To objectively and accurately reflect the knowledge system and clinical reasoning characteristics of TCM, TCMEval is built from high-quality mock questions of the TCM Licensing Examination (TCMLE).
29 |
30 | The dataset contains 6,482 question–answer pairs, of which 1,300 pairs are accompanied by official standard explanations for evaluating the generation quality of LLMs. All data avoids personal information and focuses on TCM knowledge and clinical content.
31 |
32 | Under the guidance of TCM experts, the original questions were filtered and confirmed. From each subject and question type, no more than 100 samples were randomly selected while ensuring an even distribution of answer options to avoid data bias. Two graduate students in TCM further verified the questions, ensuring full coverage of all exam subjects and question types. Through this collection, organization, and annotation process, TCMEval provides a comprehensive, accurate, and representative TCM benchmark to support the evaluation and improvement of LLM applications in TCM.
33 |
34 | **🔎 Task Types**:
35 | - 🚀 **Fundamental Knowledge Cognition Tasks**: The lowest-complexity level includes 5,473 Q&A pairs. This task set is based on standard question types in the TCMLE exam and subdivided into three representative tasks, each reflecting different dimensions of knowledge cognition:
36 | - Fundamental Knowledge Understanding (FKU)
37 | - Knowledge Horizontal Correlation (KHC)
38 | - Clinical Vertical Reasoning (CVR)
39 | Data available at [./data/first_level](./data/first_level).
40 | - 🚀 **Comprehensive Dynamic Clinical Analysis Tasks**: Built on top of the fundamental knowledge cognition tasks and designed with input from TCM experts, this set includes six types of tasks featuring multi-standard diagnosis and treatment (e.g., syndrome differentiation and treatment, one disease with different treatments, different diseases with the same treatment) and multi-factor reasoning (e.g., social environment, classical literature interpretation, and philosophical understanding).
41 | This set contains 883 Q&A pairs, with at least 50 samples per task to ensure stable evaluation.
42 | Data available at [./data/second_level](./data/second_level).
43 | - 🚀 **Complex Clinical Decision-Making Tasks**: To further evaluate the comprehensive reasoning ability and stability of LLMs in multi-standard TCM clinical environments, we designed four types of complex decision-making tasks based on the dynamic clinical analysis tasks. By restructuring original reasoning samples, introducing semantic perturbations, and converting task formats (e.g., perturbations in syndrome differentiation and treatment, one disease with different treatments, different diseases with the same treatment, as well as Chinese medicine prescription tasks), we generated 1,009 new Q&A pairs. This task set systematically examines model consistency in high-complexity reasoning, robustness in inference, and stability in decision-making.
44 | Data available at [./data/third_level](./data/third_level).
45 |
46 | ## 👨⚕️ Data Processing
47 |
48 | ### Example: Fundamental Knowledge Cognition Tasks
49 |
50 | We converted the test questions into structured evaluation data. The data format is as follows:
51 |
52 | Fundamental Knowledge Understanding Task
53 | ```json
54 | {
55 | "question": "《素问·咳论》:“五脏六腑皆令人咳”,但关系最密切的是( )。\nA.心肺\nB.肺肾\nC.肺脾\nD.肺胃\nE.肺大肠",
56 | "answer": [
57 | "D"
58 | ],
59 | "analysis": "根据《素问·咳论》“此皆聚于胃,关于肺,使人多涕唾而面浮肿气逆也”可知与五脏六腑皆令人咳关系最密切的脏腑为肺胃。手太阴肺经起于中焦,还循胃口,上膈属肺。寒凉饮食入胃,导致中焦寒,寒气循手太阴肺经上入于肺中,导致肺寒,肺为娇脏,不耐寒热,外内寒邪并聚于肺,则肺失宣降,肺气上逆发生咳嗽。因此答案选D。",
60 | "knowledge_point": "中医经典",
61 | "index": 8196,
62 | "score": 1
63 | }
64 | ```
65 | Clinical Vertical Reasoning Task:
66 | ```json
67 | {
68 | "share_content": "刘×,男,46岁,刻下眩晕而见头重如蒙。胸闷恶心,食少多寐,苔白腻,脉濡滑。",
69 | "question": [
70 | {
71 | "sub_question": "1).证属( )。\nA.肝阳上亢\nB.气血亏虚\nC.肾精不足\nD.痰浊中阻\nE.以上都不是\n",
72 | "answer": [
73 | "D"
74 | ],
75 | "analysis": ""
76 | },
77 | {
78 | "sub_question": "2).治法宜选( )。\nA.燥湿祛痰,健脾和胃\nB.补肾滋阴\nC.补肾助阳\nD.补养气血,健运脾胃\nE.平肝潜阳,滋养肝肾\n",
79 | "answer": [
80 | "A"
81 | ],
82 | "analysis": ""
83 | },
84 | {
85 | "sub_question": "3).方药宜选( )。\nA.右归丸\nB.左归丸\nC.半夏白术天麻汤\nD.归脾汤\nE.天麻钩藤饮\n",
86 | "answer": [
87 | "C"
88 | ],
89 | "analysis": ""
90 | }
91 | ],
92 | "knowledge_point": "中医内科学",
93 | "index": 334,
94 | "score": 1
95 | }
96 | ```
97 | Knowledge Horizontal Correlation Task:
98 | ```json
99 | {
100 | "share_content": "(共用备选答案)\nA.化痰息风,健脾祛湿\nB.清肺化痰,散结排脓\nC.疏风宣肺,化痰止咳\nD.清热化痰,平肝息风\nE.润肺清热,理气化痰\n",
101 | "question": [
102 | {
103 | "sub_question": "1).贝母瓜蒌散的功用是( )。",
104 | "answer": [
105 | "E"
106 | ],
107 | "analysis": ""
108 | },
109 | {
110 | "sub_question": "2).半夏白术天麻汤的功用是( )。",
111 | "answer": [
112 | "A"
113 | ],
114 | "analysis": ""
115 | }
116 | ],
117 | "knowledge_point": "方剂学",
118 | "index": 1938,
119 | "score": 1
120 | }
121 | ```
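
For the multi-part items (CVR and KHC), the shared stem lives in `share_content` and each entry of `question` carries its own `sub_question` and `answer`. A short sketch of walking such a file (the file name is illustrative; use the actual task files under `./data`):
```python
import json

# Iterate a CVR/KHC-style task file: a shared case description plus sub-questions.
# "CVR.json" is an illustrative name; point this at the actual file under ./data.
with open("./data/first_level/CVR.json", "r", encoding="utf-8") as f:
    task = json.load(f)

for item in task["example"]:
    print(item["share_content"])
    for sub in item["question"]:
        print(sub["sub_question"], "gold answer:", sub["answer"])
```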
122 |
123 | ## 🧐 Evaluation Details
124 |
125 | We designed task-adaptive prompts that require LLMs to answer questions and provide explanations. The evaluation framework consists of the following components:
126 |
127 | | File | Description |
128 | | -------------------------- | -------------- |
129 | | [./pipline/choice_bench.py](./pipline/choice_bench.py) | Set up different tasks and guide LLMs to generate answers and explanations|
130 | | [./pipline/bench_function.py](./pipline/bench_function.py) | Functions for testing |
131 | | [./pipline/Acc.py](./pipline/Acc.py) | Compute accuracy |
132 | | [./pipline/Model_API.py](./pipline/Model_API.py)| Call model APIs (OpenAI as an example), adjustable for different models |
133 | | [./TCMBench_code/explain_evaluation.py](./TCMBench_code/explain_evaluation.py)| Evaluate explanation quality using ROUGE-1, ROUGE-L, SARI, BERTScore, BARTScore, and our proposed SKScore |
134 | |[./HumanTrue.json](./HumanTrue.json)| HumanTrue Dataset|
135 |
136 |
137 | First, run [/pipline/choice_bench.py](./pipline/choice_bench.py) to test models and obtain their generated answers and explanations:
138 | ```
139 | # (Optional) If needed, set your proxy:
140 | os.environ['HTTPS_PROXY']="your proxy"
141 |
142 | # If using closed-source models, enter your API key; otherwise leave blank:
143 | api_key = "your key"
144 |
145 | # Specify --data_path and --sys_prompt for different tasks, and --model_name to call different models.
146 | # Example: run the FKU task on gpt-4-0613
147 | python choice_bench.py --data_path ../data/first_level --sys_prompt FKU.json --model_name gpt-4-0613
148 | ```
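
The file passed via `--sys_prompt` (e.g. `FKU.json`) is the task data file itself: judging from how `choice_bench.py` reads it, it carries at least the top-level fields `type`, `prefix_prompt`, and `example`, where `example` is the list of questions in the formats shown under Data Processing above. A minimal sketch (the field values are illustrative only; see the actual files under `./data` for the real instruction text):
```json
{
  "type": "FKU",
  "prefix_prompt": "<task-specific instruction given to the LLM>",
  "example": [
    {
      "question": "...",
      "answer": ["D"],
      "analysis": "...",
      "knowledge_point": "中医经典",
      "index": 8196,
      "score": 1
    }
  ]
}
```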
149 |
150 | Use [./pipline/Acc.py](./pipline/Acc.py) to compute accuracy scores for different models on different tasks by setting --data_path, --question_type, and --model_name:
151 | ```
152 | python Acc.py --data_path ../data/first_level --question_type FKU --model_name gpt-4-0613
153 | ```
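
For reference, a minimal sketch of what the single-choice (A1/A2-style) branch of `Acc.py` computes: it reads the saved result file and counts exact matches between `standard_answer` and `model_answer`. The file path in the snippet is an assumption for illustration; `Acc.py` builds it from `--data_path`, `--model_name`, and `--question_type` (`{data_path}/{model_name}_{question_type}.json` for the first level, `{data_path}/{model_name}_{question_type}/seperate_0-1.json` for the other levels).
```python
import json

# Minimal sketch of the single-choice (A1/A2-style) accuracy computation in Acc.py:
# each entry in data["example"] stores the gold "standard_answer" and the parsed
# "model_answer", and accuracy is the fraction of exact matches.
# The path below is only an assumption for illustration.
result_file = "../data/second_level/gpt-4-0613_CBF/seperate_0-1.json"

with open(result_file, "r", encoding="utf-8") as f:
    data = json.load(f)

examples = data["example"]
correct = sum(1 for ex in examples if ex["standard_answer"] == ex["model_answer"])
print(f"accuracy: {correct / len(examples):.4f}")
```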
154 |
155 | Use [./TCMBench_code/explain_evaluation.py](./TCMBench_code/explain_evaluation.py) to compute explanation scores across six metrics by specifying --model_name:
156 | ```
157 | python explain_evaluation.py --model_name gpt-4-0613
158 | ```
159 | The models used for these metrics can be found at [code link](https://huggingface.co/WJing123/TCMBench_code).
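
If you want to try one of the n-gram metrics outside the full script, the character-level SARI implementation shipped in `TCMBench_code/sari.py` can be called directly; in `explain_evaluation.py` it is invoked with the gold explanation serving as both source and reference, and only the keep component is reported. A minimal sketch (the two explanation strings are placeholders, not real data):
```python
from sari import SARIsent  # character-level SARI shipped in TCMBench_code/sari.py

reference = "与五脏六腑皆令人咳关系最密切的脏腑为肺胃,因此答案选D。"  # gold explanation (placeholder)
generated = "答案选D,因为与咳嗽关系最密切的脏腑是肺与胃。"          # model explanation (placeholder)

# explain_evaluation.py calls SARIsent(source, output, [reference]) with the gold
# explanation also acting as the source text, and keeps only the "keep" score.
keep, delete, add, sari = SARIsent(reference, generated, [reference])
print(f"SARI keep score: {keep:.4f}")
```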
160 |
161 |
162 | 👨⚕️ In addition, this work also introduces our previously developed TCM LLM, ShenNong. We welcome everyone to follow our open-source TCM large language model project **ShenNong-TCM-LLM**:
163 |
164 | - 🚀 [ShenNong-TCM](https://github.com/ywjawmw/ShenNong-TCM-LLM) : To promote the development and real-world application of LLMs in Traditional Chinese Medicine, we released ShenNong, a large-scale TCM language model designed to improve knowledge coverage and enhance its ability to answer medical consultations in TCM. It is built upon the [TCM instruction dataset SN-QA](https://huggingface.co/datasets/michaelwzhu/ShenNong_TCM_Dataset).
165 |
166 | We also introduce our other open-source healthcare LLM projects:
167 | - 🚀 [Intelligent TCM Inheritance and Innovation Platform](https://github.com/ywjawmw/AI4TCM-Platform) : As part of the Shuzhi Qihuang series, this platform addresses two main challenges: (1) existing TCM inheritance platforms cannot cover comprehensive multimodal data, so we constructed a more complete knowledge graph integrating TCM and Western medicine; and (2) TCM experience is inherited inefficiently, so we proposed interpretable prescription analysis technology that automatically analyzes the holistic diagnostic process from symptoms to prescriptions and provides scientific evidence for the analysis. The platform also offers a fair environment that helps young doctors and TCM students quickly master advanced TCM knowledge and inherit medical expertise.
168 | - 🚀 [ChatMed-Consult](https://huggingface.co/michaelwzhu/ChatMed-Consult) : Built from the [ChatMed_Consult_Dataset](https://huggingface.co/datasets/michaelwzhu/ChatMed_Consult_Dataset) with over 500k online medical consultations paired with ChatGPT responses. The model backbone is [LlaMA-7b](https://github.com/facebookresearch/llama), combined with LoRA weights from [Chinese-LlaMA-Alpaca](https://github.com/ymcui/Chinese-LLaMA-Alpaca) and extended Chinese vocabulary, followed by efficient parameter tuning using LoRA. All codes are publicly released.
169 |
170 | - 🚀 [PromptCBLUE: a Chinese medical LLM evaluation benchmark](https://github.com/michael-wzhu/PromptCBLUE): A prompt-based adaptation of the [CBLUE](https://tianchi.aliyun.com/dataset/95414) benchmark, designed to evaluate the Chinese medical knowledge and medical text-processing abilities of LLMs. PromptCBLUE enables a single generative LLM to handle a variety of medical NLP tasks such as medical record structuring, consultation, and clinical documentation writing.
171 |
172 | ## Acknowledgements
173 |
174 | This project was developed based on APIs of large language models and inspired by evaluation tasks on Gaokao examination questions. We thank the related projects and developers for their contributions:
175 |
176 | - [ChatGPT](https://openai.com/blog/chatgpt)
177 | - [ChatGLM](https://github.com/THUDM/ChatGLM-6B)
178 | - [GaoKao-Bench](https://github.com/OpenLMLab/GAOKAO-Bench)
179 |
180 |
181 | ## Citation
182 |
183 | If you use the data or code from this project, please cite:
184 |
185 | ```bibtex
186 | @misc{yue2023tcmbench,
187 | title={TCMBench: Benchmarking Large Language Models in Traditional Chinese Medicine from Knowledge to Clinical Reasoning},
188 | author={Wenjing Yue and Ming Guan and Wei Zhu and Xiaoling Wang and Honglin Li},
189 | year={2023},
190 | publisher = {GitHub},
191 | journal = {GitHub repository},
192 | howpublished = {\url{https://github.com/ywjawmw/TCMBench}},
193 | }
194 |
195 | ```
196 |
197 | ## Team
198 |
199 | This project was completed under the guidance of **Prof. Xiaoling Wang** from the School of Computer Science and Technology and **Prof. Honglin Li**, Dean of the School of Pharmacy and Director of the Innovation Center for AI and Drug Discovery, East China Normal University.
200 |
201 |
202 |
203 |
204 |
205 |
--------------------------------------------------------------------------------
/TCMBench_code/explain_evaluation.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # !/usr/bin/python3
3 | # -*- coding: utf-8 -*-
4 | # @Time : 2025/9/28 23:10
5 | # @Author : Wenjing
6 | # @File : explain_evaluation.py
7 | # @Desc : 评估LLMs生成解析的质量
8 |
9 | import argparse
10 | import json
11 | import sys
12 | import json
13 | import math
14 | import os
15 | import time
16 | from typing import List, Dict
17 | from sari import SARIsent
18 | import numpy as np
19 | from BARTScore.bart_score_chinese import BARTScorerChinese
20 | # from BARTScore.bart_score import BARTScorer
21 | import evaluate, torch
22 | import statistics
23 | from rouge_chinese import Rouge
24 | import jieba
25 | from transformers import (
26 | AutoConfig,
27 | AutoModelForSequenceClassification,
28 | AutoTokenizer
29 | )
30 | from sklearn.metrics import roc_auc_score
31 | from sklearn import metrics
32 | from matplotlib import pyplot as plt
33 |
34 | jieba.load_userdict("jieba_tcm.txt")  # load the TCM terminology into jieba
35 | with open('jieba_tcm.txt', 'r', encoding='utf-8') as f:
36 | tcm_terms = f.readlines()
37 | tcm_term_db = [t.strip() for t in tcm_terms]
38 |
39 | def parse_args():
40 | parser = argparse.ArgumentParser(description="parameter of LLMs")
41 | parser.add_argument(
42 | "--model_name",
43 | type=str,
44 | default='gpt-4-0613',
45 | help="The LLM name.",
46 | )
47 | args = parser.parse_args()
48 | return args
49 |
50 |
51 | def get_analysis_A12(data_dict_predict):
52 | correct_analysis, predict_analysis = [], []
53 | for data in data_dict_predict['example']:
54 | if len(data['analysis']) > 0:
55 | correct_analysis.append(data['analysis'])
56 | predict_analysis.append(data['model_output'])
57 | return correct_analysis, predict_analysis
58 |
59 |
60 | def get_analysis_A34(data_dict_predict):
61 | correct_analysis, predict_analysis = [], []
62 | for data in data_dict_predict['example']:
63 | question = data["question"]
64 | for sub_question in question:
65 | if len(sub_question['analysis']) > 0:
66 | correct_analysis.append(sub_question['analysis'])
67 | predict_analysis.append(sub_question['model_output'])
68 | return correct_analysis, predict_analysis
69 |
70 |
71 | def rouge_score(correct_analysis, predict_analysis):
72 | # rouge = evaluate.load('rouge')
73 | # results = rouge.compute(predictions=correct_analysis,
74 | # references=predict_analysis,
75 | # rouge_types=['rouge1', 'rougeL'],
76 | # use_aggregator=False)
77 | rouge = Rouge()
78 | empty_num = 0
79 | correct_analysis_cal, predict_analysis_cal = [], []
80 | for ca, pa in zip(correct_analysis, predict_analysis):
81 | if len(pa) == 0:
82 | empty_num += 1
83 | else:
84 | correct_analysis_cal.append(' '.join(jieba.cut(ca)))
85 | predict_analysis_cal.append(' '.join(jieba.cut(pa)))
86 | scores = rouge.get_scores(predict_analysis_cal, correct_analysis_cal)
87 | results_rouge1 = [round(score['rouge-1']['r'],2) for score in scores] + [0.0] * empty_num
88 | results_rougel = [round(score['rouge-l']['f'], 2) for score in scores] + [0.0] * empty_num
89 | return np.mean(results_rouge1), np.mean(results_rougel)
90 |
91 | def bert_score(correct_analysis, predict_analysis):
92 | bertscore = evaluate.load('../TCMBench_code/bertscore')
93 | results = bertscore.compute(predictions=predict_analysis, references=correct_analysis, lang="zh",
94 | model_type="bert-base-chinese")
95 | score = [round(v, 2) for v in results["f1"]]
96 | return np.mean(score)
97 |
98 | def bart_Score(correct_analysis, predict_analysis):
99 | bart_scorer = BARTScorerChinese(checkpoint='bart-base-chinese')
100 | bart_scores = bart_scorer.score(correct_analysis, predict_analysis, batch_size=4)
101 | # print("BART Score", np.mean(bart_scores))
102 | return np.mean(bart_scores)
103 |
104 | def calculate_sari(
105 | input_lns: List[str], output_lns: List[str], reference_lns: List[str]
106 | ) -> Dict:
107 | a, b, c, d = [], [], [], []
108 | for input, output, ref in zip(input_lns, output_lns, reference_lns):
109 | a_, b_, c_, d_ = SARIsent(input, output, [ref])
110 |
111 | a.append(round(a_,2))
112 | return a
113 |
114 |
115 | def sari_score(correct_analysis, predict_analysis):
116 | sariscores = calculate_sari(correct_analysis, predict_analysis, correct_analysis) # 参考答案根据reference 只看sari_avgkeepscore
117 | # print(f"SARI score_解析: {sariscores}")
118 | return np.mean(sariscores)
119 |
120 | print("使用 x / sum_x,缩放f1的值")
121 | def softmax(x):
122 | e_x = x
123 | sum_x = e_x.sum(axis=0)
124 | if sum_x == 2 * x.size: # 标准解析中就没有中医术语,那么解析中的每一句话的f1score都=2
125 | return [1/x.size] * x.size # 加和平均
126 | elif sum_x == 0:
127 | return [0] * x.size # 也就是LLMs中没有中医术语,那么这个解析是不好的,给一个特别低的分数
128 | else:
129 | return x / sum_x
130 |
131 | import re
132 | def split_sentences(text):
133 | # 利用正则表达式按照句号、感叹号、问号进行划分
134 | sentences = re.split(r'(?<=[。!?])\s*', text)
135 | # 去除空字符串和空白符
136 | sentences = [s.strip() for s in sentences if s.strip()]
137 | return sentences
138 |
139 | from collections import Counter
140 | def tcm_score_f1(analysis_tcm_terms_counter, analysis_terms_counter, doc, tcm_term_db):
141 | """
142 | 中医术语匹配度
143 | :param analysis_tcm_terms_counter: 解析中的中医术语以及计数
144 | :param analysis_terms_counter: 解析
145 | :param doc: 需要检测的语句
146 | :param tcm_term_db: 中医术语库
147 | :return:
148 | """
149 |
150 | doc_terms_list = list(jieba.cut(doc))
151 | doc_terms_counter = Counter(doc_terms_list)
152 | doc_tcm_terms_list = [term for term in doc_terms_list if term in tcm_term_db]
153 | doc_tcm_terms_counter = Counter(doc_tcm_terms_list) # 片段中所有的中医术语计数
154 | if len(analysis_tcm_terms_counter) == 0:
155 | return 2 # 如果解析中没有中医术语,那么就不需要进行F1score的计算
156 | elif len(doc_tcm_terms_counter) == 0:
157 | return 0 # # 如果LLMs中没有中医术语,那么F1score=0, 说明这句话不符合中医诊疗语言,或者没有什么信息量
158 | comment_term_counter = analysis_tcm_terms_counter & doc_tcm_terms_counter # 重复的中医术语
159 | recall_comment_score, precision_comment_score = 0, 0
160 | for term in comment_term_counter:
161 | recall_comment_score += comment_term_counter[term] / analysis_tcm_terms_counter[term]
162 | precision_comment_score += comment_term_counter[term] / doc_tcm_terms_counter[term]
163 | recall = recall_comment_score / len(analysis_tcm_terms_counter) # 重复的中医术语个数/解析中的中医术语个数 —— 重叠度
164 | precision = precision_comment_score / len(doc_tcm_terms_counter) # 重复的中医术语个数/ 文档的中医术语个数 —— 冗余度
165 | informational = len(list(set(doc_tcm_terms_list))) / len(list(set(doc_terms_list)))
166 | # informational = len(doc_tcm_terms_counter) / len(doc_terms_counter) * (sum(doc_terms_counter.values()) / sum(analysis_terms_counter.values())) # 片段中的中医术语 / 片段中术语个数(不重复) * (片段的术语总个数 / 解析的术语总个数)[长度的惩罚项] —— 信息度
167 | if precision == 0 or recall == 0:
168 | return 0
169 | else:
170 | f1_score = 3 * (precision * recall * informational) / (precision + recall + informational)
171 | return f1_score
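# Note (added): tcm_score_f1 above and f1_score_tcm_term below combine precision, recall and the
# informativeness ratio with a three-way harmonic-style mean,
#     f1 = 3 * (precision * recall * informational) / (precision + recall + informational),
# and both return the sentinel value 2 when the standard analysis contains no TCM terms, a case
# handled by the softmax() normalisation above.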
172 |
173 | def calculate_score(sentence, sample, model, tokenizer):
174 | inputs = tokenizer.batch_encode_plus(
175 | batch_text_or_text_pairs=[(sentence, sample)],
176 | add_special_tokens=True, padding="longest",
177 | truncation=True, return_tensors="pt",
178 | return_token_type_ids=True, return_attention_mask=True,
179 | max_length=512
180 | )
181 |
182 |     logits = model(**inputs).logits  # neutral is already removed
183 |     # print(logits)
184 |     probs = torch.argmax(logits, dim=-1)
185 |     prob_label = probs[0].item()  # predicted class index
186 |     probs1 = torch.softmax(logits, dim=-1)
187 |     prob_1 = probs1[0][0].item()  # probability of the first class, used as the relatedness score
188 |     return prob_label, prob_1
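# Illustrative usage (added), assuming `model` and `tokenizer` are the sequence-classification
# checkpoint loaded in __main__ below:
#     label, prob = calculate_score("标准解析中的一句话", "模型生成的一句话", model, tokenizer)
# predict_analysys_response() averages `prob` over both sentence orders to get a symmetric score.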
189 |
190 | def f1_score_tcm_term(analysis_tcm_terms_list, analysis_terms_list, doc, tcm_term_db):
191 | """
192 | 中医术语匹配度
193 | :param analysis_tcm_terms_list: 解析中的中医术语
194 | :param analysis_terms_list: 解析
195 | :param doc: 需要检测的语句
196 | :param tcm_term_db: 中医术语库
197 | :return:
198 | """
199 |
200 | doc_terms_list = list(jieba.cut(doc))
201 | doc_tcm_terms_list = [term for term in doc_terms_list if term in tcm_term_db]
202 | doc_tcm_terms_list = doc_tcm_terms_list # 片段中所有的中医术语(含有重复元素)
203 | if len(analysis_tcm_terms_list) == 0:
204 | return 2 # 如果解析中没有中医术语,那么就不需要进行F1score的计算
205 | elif len(doc_tcm_terms_list) == 0:
206 | return 0 # # 如果LLMs中没有中医术语,那么F1score=0, 说明这句话不符合中医诊疗语言,或者没有什么信息量
207 | comment_term = set(analysis_tcm_terms_list) & set(doc_terms_list) # 重复的中医术语
208 | comment_num = 0
209 | for doc_term in doc_tcm_terms_list:
210 | if doc_term in comment_term:
211 | comment_num += 1
212 | recall = comment_num / len(analysis_tcm_terms_list) # 重复的中医术语个数/真正对的中医术语个数 —— 重叠度
213 | precision = comment_num / len(doc_tcm_terms_list) # 重复的中医术语个数/ 文档的中医术语个数 —— 冗余度
214 | informational = len(list(set(doc_tcm_terms_list))) / len(list(set(doc_terms_list)))
215 | if precision == 0 or recall == 0 or informational == 0:
216 | return 0
217 | else:
218 | f1_score = 3 * (precision * recall * informational) / (precision + recall + informational)
219 | return f1_score
220 |
221 |
222 | def predict_analysys_response(sentences: List[str], sampled_passages: List[str], model, tokenizer):
223 | """
224 | :param sentences: list of 标准解析
225 | :param sampled_passages: LLMs的解析
226 | """
227 | scores1 = list() # 计算LLMs生成的解析与标准解析之间的分数
228 | scores1_counter = list() # _counter
229 | scores1_nof1 = list()
230 | scores1_dotf1 = list()
231 | num = 0
232 | for sentence, sample in zip(sentences, sampled_passages): # 解析
233 | if num == 0:
234 | print(f"sentence: {sentence}")
235 | print(f"sample: {sample}")
236 | num += 1
237 |         # split both analyses into sentences
238 |         response_sentence_list = split_sentences(sample)
239 |         analysis_sentence_list = split_sentences(sentence)
240 |         tcm_score = []
241 |         tcm_score_counter = []
242 |         prob_score_list = []
243 |         tcm_score_dotf1 = []
244 |         for analysis_sentence in analysis_sentence_list:  # each sentence of the standard analysis
245 | f1_score, prob_score = [], []
246 | f1_score_counter = []
247 |             if len(response_sentence_list) > 0:
248 |                 for response_sentence in response_sentence_list:  # each sentence of the LLM analysis
249 |                     # collect the TCM terms in this standard-analysis sentence
250 |                     analysis_terms = list(jieba.cut(analysis_sentence))
251 |                     analysis_terms_counter = Counter(analysis_terms)
252 |                     analysis_tcm_terms_list = [term for term in analysis_terms if term in tcm_term_db]
253 |                     analysis_tcm_terms_counter = Counter(analysis_tcm_terms_list)  # counts of TCM terms in the standard analysis
254 |                     analysis_tcm_terms_list = list(analysis_tcm_terms_counter.keys())  # distinct TCM terms in the standard analysis
255 |
256 | prob_label_A, prob_1_A = calculate_score(analysis_sentence, response_sentence, model, tokenizer)
257 | prob_label_reverse_A, prob_1_reserve_A = calculate_score(response_sentence, analysis_sentence,
258 | model, tokenizer)
259 | prob = (prob_1_A + prob_1_reserve_A) / 2
260 | ################## TCM score ####################
261 | f1_term_score = f1_score_tcm_term(analysis_tcm_terms_list, analysis_terms, response_sentence, tcm_term_db)
262 | f1_term_score_counter = tcm_score_f1(analysis_tcm_terms_counter, analysis_terms_counter, response_sentence, tcm_term_db)
263 | f1_score.append(f1_term_score)
264 | prob_score.append(prob)
265 | f1_score_counter.append(f1_term_score_counter)
266 | f1_score = np.array(f1_score)
267 |
268 |                 #### without the f1 score ####
269 |                 average_prob_score = statistics.mean(prob_score)
270 |                 prob_score_list.append(average_prob_score)
271 | 
272 |                 prob_score = np.array(prob_score)
273 |                 f1_score_counter = np.array(f1_score_counter)
274 |                 # normalise the f1 score lists
275 | try:
276 | normalized_f1_score = softmax(f1_score)
277 | normalized_f1_score_counter = softmax(f1_score_counter)
278 | except:
279 | print(response_sentence_list)
280 | print(f1_score)
281 | sys.exit()
282 |
283 |                 #### multiply directly by the F1 score ####
284 |                 analysis_sentence_dotf1 = np.sum(f1_score * prob_score) / prob_score.size
285 |                 tcm_score_dotf1.append(analysis_sentence_dotf1)
286 |                 ###### weighted average: multiply the normalised f1 weights by the probabilities and sum
287 |                 analysis_sentence_score = np.sum(normalized_f1_score * prob_score)
288 |                 tcm_score.append(analysis_sentence_score)
289 |                 # score computed with the Counter-based f1
290 |                 analysis_sentence_score_counter = np.sum(normalized_f1_score_counter * prob_score)
291 |                 tcm_score_counter.append(analysis_sentence_score_counter)
292 | else:
293 | tcm_score.append(0)
294 | tcm_score_counter.append(0)
295 | tcm_score_dotf1.append(0)
296 | prob_score_list.append(0)
297 | if len(sample) < len(sentence):
298 | if len(sample) > 1:
299 | length_penalty = math.exp(1 - math.log(len(sentence)) / math.log(len(sample)))
300 | elif len(sample) == 1:
301 | length_penalty = math.exp(1 - math.log(len(sentence)) / math.log(len(sample) + 1))
302 | else:
303 | length_penalty = 0
304 | else:
305 | # length_penalty = 1
306 | length_penalty = math.exp(1 - math.log(len(sample)) / math.log(len(sentence)))
307 | scores_per_response = statistics.mean(tcm_score) * length_penalty
308 | scores1.append(scores_per_response)
309 |         # score computed with the Counter-based f1
310 |         scores_per_response_counter = statistics.mean(tcm_score_counter) * length_penalty
311 |         scores1_counter.append(scores_per_response_counter)
312 |         ############ without the f1 score
313 |         scores1_nof1.append(statistics.mean(prob_score_list))
314 |         #### multiply directly by the F1 score ####
315 |         scores_per_response_dotf1 = statistics.mean(tcm_score_dotf1)
316 |         scores1_dotf1.append(scores_per_response_dotf1)
317 |     scores_per_doc = statistics.mean(scores1)
318 |     scores_per_doc_counter = scores1_counter
319 |     print("TCM Score between each analysis and its response (Counter-based):", scores_per_doc_counter)
320 | return scores_per_doc_counter
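# Summary note (added): for each (standard analysis, LLM analysis) pair, every standard-analysis
# sentence is compared with every LLM sentence; the NLI relatedness probability is weighted by the
# normalised TCM-term f1 scores, a length penalty exp(1 - log(len_longer) / log(len_shorter)) is
# applied when the two texts differ in length, and the Counter-based per-pair scores are returned
# and averaged into SKScore in __main__.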
321 |
322 | def nli_score(references, predictions, model, tokenizer):
323 | n_score = predict_analysys_response(sentences=references, sampled_passages=predictions, model=model, tokenizer=tokenizer)
324 | return n_score
325 |
326 |
327 | if __name__ == "__main__":
328 | args = parse_args()
329 | question_type = "FKU"
330 | with open(f"../data/first_level/{args.model_name}_{question_type}.json", "r", encoding="utf-8") as f:
331 | data_FKU = json.load(f)
332 | f.close()
333 |
334 | question_type = "CVR"
335 | with open(f"../data/first_level/{args.model_name}_{question_type}.json", "r", encoding="utf-8") as f:
336 | data_CVR = json.load(f)
337 | f.close()
338 |
339 | question_type = "KHC"
340 | with open(f"../data/first_level/{args.model_name}_{question_type}.json", "r", encoding="utf-8") as f:
341 | data_KHC = json.load(f)
342 | f.close()
343 |
344 | correct_analysis, predict_analysis = [], []
345 | correct_analysis_A12, predict_analysis_A12 = get_analysis_A12(data_FKU)
346 | correct_analysis_A3, predict_analysis_A3 = get_analysis_A34(data_CVR)
347 | correct_analysis_B1, predict_analysis_B1 = get_analysis_A34(data_KHC)
348 | correct_analysis += correct_analysis_A12 + correct_analysis_A3 + correct_analysis_B1
349 | predict_analysis += predict_analysis_A12 + predict_analysis_A3 + predict_analysis_B1
350 |
351 | rouge1, rouge_L = rouge_score(correct_analysis, predict_analysis)
352 | print("ROUGE-1:", rouge1)
353 | print("ROUGE-L:", rouge_L)
354 | sari_scores = sari_score(correct_analysis, predict_analysis)
355 | print("SARI:", sari_scores)
356 | bert_scores = bert_score(correct_analysis, predict_analysis)
357 | print("BERTScore:", bert_scores)
358 | bart_scores = bart_Score(correct_analysis, predict_analysis)
359 | print("BARTScore:", bart_scores)
360 |
361 | # SKScore
362 | model_name_or_path = "../TCMBench_code/Deberta-V3-base-tmnli-QAC"
363 | print(model_name_or_path)
364 | config = AutoConfig.from_pretrained(
365 | model_name_or_path,
366 | num_labels=3,
367 | finetuning_task="mnli",
368 | trust_remote_code=False
369 | )
370 |     tokenizer = AutoTokenizer.from_pretrained(
371 |         model_name_or_path, use_fast=True, trust_remote_code=False
372 |     )
373 | # print(tokenizer)
374 | model = AutoModelForSequenceClassification.from_pretrained(
375 | f"{model_name_or_path}",
376 | from_tf=bool(".ckpt" in model_name_or_path),
377 | config=config,
378 | ignore_mismatched_sizes=False,
379 | )
380 | sk_score = nli_score(correct_analysis, predict_analysis, model, tokenizer)
381 | SKScore = np.mean(sk_score)
382 | print("SKScore:", SKScore)
383 |
384 |
385 |
--------------------------------------------------------------------------------
/pipline/bench_function.py:
--------------------------------------------------------------------------------
1 |
2 | import os
3 | import json
4 |
5 | import re
6 |
7 | from typing import List, Union
8 |
9 | from tqdm import tqdm
10 | from collections import Counter
11 |
12 | # args = parse_args()
13 |
14 |
15 | def get_api_key(filename: str, start_num: int, end_num: int) -> List[str]:
16 | """
17 | Retrieves API keys from a file.
18 |
19 | :param filename: Name of the file containing API keys
20 | :param start_num: Starting line number for reading the file
21 | :param end_num: Ending line number for reading the file
22 | :return: List of API keys
23 | """
24 | with open(filename, 'r') as file:
25 | lines = file.readlines()
26 |
27 | pattern = re.compile(r'sk-[\s\S]*?(?=\s*\n)')
28 | api_key_list = []
29 |
30 | for i in range(start_num, end_num):
31 | api_key = pattern.findall(lines[i])
32 | if len(api_key) != 0:
33 | api_key_list.append(api_key[0])
34 |
35 | return api_key_list
36 |
37 |
38 | def extract_choice_answer(model_output, question_type, answer_lenth=None):
39 | """
40 | Extract choice answer from model output
41 |
42 | Format of model_output that is expected:
43 | 'single_choice': choice answer should be the last Capital Letter of the model_output, e.g.: "...【答案】 A "
44 | 'multi_question_choice': "...【答案】A ... 【答案】C ..." or write the choice answers at the beginning of the model_output, e.g. "A C D E F...."
45 | 'multi_choice': "...【答案】 ABD " or write the choice answers at the end of the model_output, e.g. "... ACD"
46 | 'five_out_of_seven': choice answers should be the first five Capital Letters of the model_output, e.g. "A C D F B ...."
47 | """
48 | model_answer = []
49 | model_output = model_output[::-1]
50 | pattern = r"([A-Z]).*?案答"
51 | check_info = re.search(pattern, model_output)
52 | if check_info:
53 | pattern = r"\.[A-Z]"
54 | temp = re.findall(pattern, model_output)
55 | if len(temp) > 0:
56 | # answer = temp[0]
57 | answer = check_info.group(1)
58 | model_answer.append(answer)
59 | else:
60 | temp = re.findall(r'[A-E]', model_output)
61 | if len(temp) != 0:
62 | answer = temp[0]
63 | model_answer.append(answer)
64 | else:
65 | temp = re.findall(r'[A-E]', model_output)
66 | if len(temp) != 0:
67 | answer = temp[0]
68 | model_answer.append(answer)
69 |
70 | return model_answer
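# Illustrative example (added): the output is scanned from the end, so the choice letter nearest the
# final 【答案】 marker is extracted, e.g.
#     extract_choice_answer("……综上所述,本题选【答案】A", "single_choice") -> ["A"]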
71 |
72 | def extract_choice_answer_hard(model_output, question_type, answer_lenth=None):
73 | model_answer = []
74 | # DDST
75 | if question_type == "DDST_hard":
76 | model_answer = []
77 | model_output = model_output[0].replace("[", "").replace("]", "").replace("", "").replace("",
78 | "").replace(" ",
79 | "")
80 | model_output = ''.join([char for char in model_output if char.isdigit()])
81 | temp = re.findall(r'[0-9]', model_output)
82 | if len(temp) != 0:
83 | model_answer.append(model_output)
84 | elif question_type == "herb_predict":
85 | model_answer = []
86 | model_output = model_output[0].replace("[", "").replace("]", "").replace("", "").replace("",
87 | "").replace(" ",
88 | "")
89 | temp = model_output.split(",")
90 | if len(temp) != 0:
91 | if len(temp) < 20:
92 |                 # pad temp with "-1" entries until its length is 20
93 | for i in range(20 - len(temp)):
94 | temp.append("-1")
95 | if len(temp) > 20:
96 | temp = temp[:20]
97 | model_answer.append(temp)
98 | return model_answer
99 |
100 | def A3_second_check(model_output):
101 | pattern = r"【答案】"
102 | temp = re.findall(pattern, model_output)
103 | check_answer = list()
104 | if len(temp) > 0:
105 | model_output = model_output.replace("\n", "")
106 | pattern = r"答案】.*?([A-Z])"
107 | temp = re.findall(pattern, model_output)
108 | if len(temp) > 0:
109 | answer = temp[-1]
110 | check_answer.append(answer)
111 | else:
112 | check_answer = []
113 | else:
114 | pattern = r"答案.*?([A-Z])"
115 | check_info = re.findall(pattern, model_output)
116 | if check_info:
117 | answer = check_info[-1]
118 | check_answer.append(answer)
119 | else:
120 | model_output = model_output[::-1]
121 | temp = re.findall(r'[A-E]', model_output)
122 | if len(temp) != 0:
123 | answer = temp[0]
124 | check_answer.append(answer)
125 | return check_answer
126 |
127 | def pattern_second_check(ans_list, model_output):
128 | check_answer = list()
129 | ans_list = [
130 | ans.replace("A", "").replace("B", "").replace("C", "").replace("D", "").replace("E", "").replace(
131 | ".", "").replace(".", "") for ans in ans_list if len(ans) > 0]
132 | ans_id = -1
133 | candidate_ans = {}
134 | for ans in ans_list:
135 | a_list = re.split(r'[,。;\s]', ans)
136 | max_count = 0
137 | for a in a_list:
138 | if a in model_output:
139 | ans_id = ans_list.index(ans)
140 | c = model_output.count(a)
141 | max_count += c
142 | if max_count > 0:
143 | candidate_ans[ans] = max_count
144 | if len(candidate_ans) > 1:
145 |         # more than one candidate option matched
146 |         # if one option occurs more often than all the others, take it as the answer
147 | max_value = max(candidate_ans.values())
148 | value_clist = Counter(candidate_ans.values())
149 | if value_clist[max_value] == 1:
150 | unique_max_key = [key for key, value in candidate_ans.items() if value == max_value][0]
151 | ans_id = ans_list.index(unique_max_key)
152 | check_answer.append(chr(ans_id + 65))
153 | else:
154 | print(candidate_ans)
155 | elif ans_id >= 0:
156 | check_answer.append(chr(ans_id + 65))
157 | return check_answer
158 |
159 | def herb_second_check(model_output):
160 | no_answer = 0
161 | no_standard = 0
162 | if "-1" in model_output:
163 |         # count how many "-1" placeholders are present
164 | count = model_output.count("-1")
165 | if 19 <= count <= 20:
166 | no_answer += 1
167 | else:
168 | no_standard += 1
169 | res_list = []
170 | if type(model_output[0]) is list:
171 | model_output = model_output[0]
172 | for res in model_output:
173 | if "、" in res:
174 | res_list += res.split("、")
175 | if len(res_list) > 5:
176 | new_res_list = []
177 | for res in model_output:
178 | if res != "-1" and "、" not in res:
179 |                 new_res_list.append(res)  # append the herb name as a whole; "+= res" would add its characters one by one
180 | else:
181 | if res == "-1":
182 | continue
183 | new_res_list.extend(res_list)
184 | if len(new_res_list) < 20:
185 |             # pad with "-1" entries until the list has length 20
186 | for i in range(20 - len(new_res_list)):
187 | new_res_list.append("-1")
188 |
189 | model_output = new_res_list
190 | if len(set(model_output)) != 20:
191 |         # find duplicated non-"-1" elements and blank the repeats
192 | repeat_list = []
193 | for i, res in enumerate(model_output):
194 | if res != "-1":
195 | if res not in repeat_list:
196 | repeat_list.append(res)
197 | else:
198 | model_output[i] = "-1"
199 | return model_output
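# Note (added): herb_second_check post-processes the 20-slot herb list for "herb_predict" items:
# "、"-joined strings are split into individual herb names, the list is padded back to 20 entries
# with "-1" placeholders where needed, and duplicated herb names are blanked to "-1" so each herb
# appears at most once.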
200 |
201 |
202 | def choice_test_A12(**kwargs):
203 | model_api = kwargs['model_api']
204 | model_name = kwargs['model_name']
205 | start_num = kwargs['start_num']
206 | end_num = kwargs['end_num']
207 | data = kwargs['data']['example']
208 | keyword = kwargs['keyword']
209 | prompt = kwargs['prompt']
210 | question_type = kwargs['question_type']
211 | save_directory = kwargs['save_directory']
212 | args = kwargs['args']
213 |
214 | model_answer_dict = []
215 | for i in range(start_num, end_num):
216 | index = data[i]['index']
217 | question = data[i]['question'].strip() + '\n'
218 | score = 1
219 | standard_answer = data[i]['answer']
220 | try:
221 | analysis = data[i]['analysis']
222 | except:
223 | analysis = ''
224 | try:
225 | knowledge_point = data[i]['knowledge_point']
226 | except:
227 | knowledge_point = ''
228 | model_output = model_api.send_request_turbo(prompt, question)[0]
229 | model_answer = extract_choice_answer(model_output, question_type, 5)
230 | if len(model_answer) == 0:
231 | ans_list = question.split("\n")[1:]
232 | model_answer = pattern_second_check(ans_list, model_output)
233 | # TODO: which content of temp we expect
234 | dict = {
235 | 'index': index,
236 | # 'year': year,
237 | # 'category': category,
238 | 'score': score,
239 | 'question': question,
240 | 'standard_answer': standard_answer,
241 | 'analysis': analysis,
242 | 'knowledge_point': knowledge_point,
243 | 'model_answer': model_answer,
244 | 'model_output': model_output
245 | }
246 | for key, value in dict.items():
247 | print(key, ":", value)
248 | # print(dict)
249 | print("*" * 100, "index-", dict["index"], "*" * 100)
250 | model_answer_dict.append(dict)
251 |
252 | file_name = f"seperate_{start_num}-{end_num}.json"
253 | file_path = os.path.join(save_directory, file_name)
254 | with open(file_path, 'w', encoding='utf-8') as f:
255 | output = {
256 | 'keyword': keyword,
257 | 'example': model_answer_dict
258 | }
259 | json.dump(output, f, ensure_ascii=False, indent=4)
260 | f.close()
261 |
262 |
263 | def choice_test_A34(**kwargs):
264 | model_api = kwargs['model_api']
265 | model_name = kwargs['model_name']
266 | start_num = kwargs['start_num']
267 | end_num = kwargs['end_num']
268 | data = kwargs['data']['example']
269 | keyword = kwargs['keyword']
270 | prompt = kwargs['prompt']
271 | question_type = kwargs['question_type']
272 | save_directory = kwargs['save_directory']
273 | args = kwargs['args']
274 |
275 | model_answer_dict = []
276 | for i in range(start_num, end_num):
277 | index = data[i]['index']
278 |         question = data[i]['question']  # a list containing several sub-questions and their answers
279 | score = 1
280 | try:
281 | knowledge_point = data[i]['knowledge_point']
282 | except:
283 | knowledge_point = ''
284 | share_content = data[i]['share_content']
285 | model_output = model_api.send_request_chat(prompt, share_content, question)
286 |
287 | question_list = []
288 | for sub_question, output in zip(question, model_output):
289 | standard_answer = sub_question['answer']
290 | try:
291 | analysis = sub_question['analysis']
292 | except:
293 | analysis = ''
294 | model_answer = extract_choice_answer(output, question_type, 5)
295 | if question_type in ["CVR", "SDT", 'SDT_reverse', "SDT_shuffle"]:
296 | model_answer = A3_second_check(output)
297 | if len(model_answer) == 0:
298 | if question_type in ["CVR", "SDT", 'SDT_reverse', "SDT_shuffle"]:
299 |                     ans_list = sub_question['sub_question'].split("\n")[1:]  # sub_question is a dict; the options are in its text
300 | else:
301 | ans_list = share_content.split("\n")[1:]
302 | model_answer = pattern_second_check(ans_list, output)
303 | sub_question_dict = {
304 | 'sub_question': sub_question['sub_question'],
305 | 'standard_answer': standard_answer,
306 | 'analysis': analysis,
307 | 'model_answer': model_answer,
308 | 'model_output': output
309 | }
310 | question_list.append(sub_question_dict)
311 | # TODO: which content of temp we expect
312 |
313 | dict = {
314 | 'index': index,
315 | 'score': score,
316 | 'share_content': share_content,
317 | 'question': question_list,
318 | 'knowledge_point': knowledge_point,
319 | }
320 | model_answer_dict.append(dict)
321 |
322 | file_name = f"seperate_{start_num}-{end_num}.json"
323 | file_path = os.path.join(save_directory, file_name)
324 | with open(file_path, 'w', encoding='utf-8') as f:
325 | output = {
326 | 'keyword': keyword,
327 | 'example': model_answer_dict
328 | }
329 | json.dump(output, f, ensure_ascii=False, indent=4)
330 | f.close()
331 |
332 | def choice_test_DDST_hard(**kwargs):
333 | model_api = kwargs['model_api']
334 | model_name = kwargs['model_name']
335 | start_num = kwargs['start_num']
336 | end_num = kwargs['end_num']
337 | data = kwargs['data']['example']
338 | keyword = kwargs['keyword']
339 | prompt = kwargs['prompt']
340 | save_directory = kwargs['save_directory']
341 |
342 | model_answer_dict = []
343 | for i in range(start_num, end_num):
344 | question = data[i]['question']
345 | option = data[i]['option']
346 | standard_answer = data[i]['answer']
347 | model_output = model_api.send_request_hard(prompt, question, option, int(standard_answer))
348 | model_answer = extract_choice_answer_hard(model_output, keyword, 0)
349 | # TODO: which content of temp we expect
350 | dict = {
351 | 'question': question,
352 | 'option': option,
353 | 'standard_answer': [standard_answer],
354 | 'model_answer': model_answer
355 | }
356 | # print("*" * 100, "index-", dict["index"], "*" * 100)
357 | for key, value in dict.items():
358 | print(key, ":", value)
359 | # print(dict)
360 | model_answer_dict.append(dict)
361 |
362 | file_name = f"seperate_{start_num}-{end_num}.json"
363 | file_path = os.path.join(save_directory, file_name)
364 | with open(file_path, 'w', encoding='utf-8') as f:
365 | output = {
366 | 'keyword': keyword,
367 | 'example': model_answer_dict
368 | }
369 | json.dump(output, f, ensure_ascii=False, indent=4)
370 | f.close()
371 |
372 | def choice_test_TCM_Rec(**kwargs):
373 | model_api = kwargs['model_api']
374 | model_name = kwargs['model_name']
375 | start_num = kwargs['start_num']
376 | end_num = kwargs['end_num']
377 | data = kwargs['data']['example']
378 | keyword = kwargs['keyword']
379 | prompt = kwargs['prompt']
380 | save_directory = kwargs['save_directory']
381 | question_type = kwargs['question_type']
382 |
383 | model_answer_dict = []
384 | for i in range(start_num, end_num):
385 | question = "、".join(data[i]['sym_set'])
386 | # option = data[i]['option']
387 | standard_answer = data[i]['herb_set']
388 | model_output = model_api.send_request_TCM_Rec(prompt, question)
389 | model_answer = extract_choice_answer_hard(model_output, keyword, 0)[0]
390 | model_answer = herb_second_check(model_answer)
391 | # TODO: which content of temp we expect
392 | dict = {
393 | 'sym_set': data[i]['sym_set'],
394 | 'model_output': model_answer,
395 | 'herb_set': standard_answer,
396 | }
397 | # print("*" * 100, "index-", dict["index"], "*" * 100)
398 | for key, value in dict.items():
399 | print(key, ":", value)
400 | # print(dict)
401 | model_answer_dict.append(dict)
402 |
403 | file_name = f"seperate_{start_num}-{end_num}.json"
404 | file_path = os.path.join(save_directory, file_name)
405 | with open(file_path, 'w', encoding='utf-8') as f:
406 | output = {
407 | 'keyword': keyword,
408 | 'example': model_answer_dict
409 | }
410 | json.dump(output, f, ensure_ascii=False, indent=4)
411 | f.close()
412 |
413 |
414 | def export_union_json(directory: str, model_name: str, zero_shot_prompt_text: Union[str, List[str]], question_type: str) -> None:
415 | """
416 | Merges JSON files containing processed examples in a directory into a single JSON file.
417 |
418 | :param directory: Directory containing the JSON files
419 | :param model_name: Name of the model used to process the examples
420 | # :param keyword: Keyword used to identify the JSON files
421 | :param zero_shot_prompt_text: Prompt text for zero-shot learning
422 | :param question_type: Type of questions in the JSON files (e.g. single_choice, five_out_of_seven, etc.)
423 | """
424 |
425 | save_directory = os.path.join(directory, f'{model_name}_{question_type}')
426 | if os.path.exists(save_directory):
427 | output = {
428 | 'keywords': question_type,
429 | 'model_name': model_name,
430 | 'prompt': zero_shot_prompt_text,
431 | 'example': []
432 | }
433 |
434 | # Iterate through the JSON files with the specified keyword in the directory
435 |
436 | print("Start to merge json files")
437 | files = [file for file in os.listdir(save_directory) if file.endswith('.json')]
438 | for file in files:
439 | file_path = os.path.join(save_directory, file)
440 |
441 | # Load and merge the data from the JSON files
442 | with open(file_path, "r", encoding='utf-8') as f:
443 | data = json.load(f)
444 | output['example'] += (data['example'])
445 | # Save the merged data into a single JSON file
446 | merge_file = os.path.join(directory, f'{model_name}_{question_type}_predictions.json')
447 | output['example'] = sorted(output['example'], key=lambda x: x['index'])
448 | with open(merge_file, 'w', encoding='utf-8') as f:
449 | json.dump(output, f, ensure_ascii=False, indent=4)
450 |
451 | def export_distribute_json(
452 | model_api,
453 | model_name: str,
454 | directory: str,
455 | # keyword: str,
456 |     zero_shot_prompt_text: Union[str, List[str]],
457 | question_type: str,
458 | args,
459 | parallel_num: int = 1
460 | ) -> None:
461 | """
462 | Distributes the task of processing examples in a JSON file across multiple processes.
463 |
464 | :param model_name: Name of the model to use
465 | :param directory: Directory containing the JSON file
466 |
467 | :param zero_shot_prompt_text: Prompt text for zero-shot learning
468 | :param question_type: Type of questions in the JSON file (e.g. single_choice, five_out_of_seven, etc.)
469 | :param parallel_num: Number of parallel processes to use (default: 5)
470 |
471 | """
472 | # Find the JSON file with the specified keyword
473 | for root, _, files in os.walk(directory):
474 | for file in files:
475 | if file == f'{question_type}.json':
476 | filepath = os.path.join(root, file)
477 | with open(filepath, 'r', encoding='utf-8') as f:
478 | data = json.load(f)
479 |
480 | example_num = len(data['example'])
481 |
482 | # Prepare the list of keyword arguments for parallel processing
483 | kwargs_list = []
484 | batch_size = example_num // parallel_num
485 | save_directory = os.path.join(directory, f'{model_name}_{question_type}')
486 | if not os.path.exists(save_directory):
487 | os.makedirs(save_directory)
488 | # os.system(f'mkdir {save_directory}')
489 |
490 | for idx in range(args.start_num, parallel_num):
491 | start_num = idx * batch_size
492 | end_num = min(start_num + batch_size, example_num)
493 | if start_num >= example_num:
494 | break
495 |
496 | kwargs = {
497 | 'model_api': model_api,
498 | 'start_num': start_num,
499 | 'end_num': end_num,
500 | 'model_name': model_name,
501 | 'data': data,
502 | 'keyword': question_type,
503 | 'prompt': zero_shot_prompt_text,
504 | 'question_type': question_type,
505 | 'save_directory': save_directory,
506 | 'args': args,
507 | }
508 |
509 | kwargs_list.append(kwargs)
510 |
511 | # Run parallel processing based on the question type
512 | if question_type in ["FKU", "CBF", "SCF", "PF", "SDDT", "DDST", "SDDT_hard"]:
513 | for kwargs in kwargs_list:
514 | choice_test_A12(**kwargs)
515 | elif question_type in ["CVR", "KHC", "SDT", 'SDT_reverse', "SDT_shuffle"]:
516 | for kwargs in kwargs_list:
517 | choice_test_A34(**kwargs)
518 | elif question_type in ["DDST_hard"]:
519 | for kwargs in kwargs_list:
520 | choice_test_DDST_hard(**kwargs)
521 | elif question_type in ["herb_predict"]:
522 | for kwargs in kwargs_list:
523 | choice_test_TCM_Rec(**kwargs)
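# Illustrative usage (added), mirroring how pipline/choice_bench.py would call these helpers
# (the variable names here are assumptions, not taken from that script):
#     export_distribute_json(model_api, args.model_name, args.data_path, prompt, question_type, args)
#     export_union_json(args.data_path, args.model_name, prompt, question_type)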
524 |
525 |
526 |
527 | def test_correction_score_A12(data_dict):
528 | score = 0
529 | all_num = 0
530 | model_answer_dict = []
531 | correct_answer_list = []
532 | model_kpoint = {}
533 | for data in data_dict['example']:
534 | all_num += 1
535 | true_answer = data['standard_answer']
536 | model_answer = data['model_answer']
537 | knowledge_point = data["knowledge_point"]
538 | if knowledge_point == "":
539 | knowledge_point = "其他"
540 | if knowledge_point not in model_kpoint.keys():
541 | model_kpoint[knowledge_point] = [0, 0]
542 | model_kpoint[knowledge_point][1] += 1
543 | dict = {
544 | 'index': data["index"],
545 | 'question': data["question"],
546 | 'standard_answer': true_answer,
547 | 'analysis': data["analysis"],
548 | 'knowledge_point': data["knowledge_point"],
549 | 'model_answer': model_answer,
550 | 'model_output': data["model_output"]
551 | }
552 | if true_answer == model_answer:
553 | score += 1
554 | model_kpoint[knowledge_point][0] += 1
555 | correct_answer_list.append(dict)
556 | else:
557 | model_answer_dict.append(dict)
558 | output = {'keyword': data_dict["keyword"],
559 | 'correct_num': score,
560 | 'all_num': all_num}
561 | if len(model_answer_dict) > 0:
562 | output['example'] = model_answer_dict
563 | return score / all_num, output, model_kpoint, score, all_num, correct_answer_list
564 |
565 |
566 | def test_correction_score_A34(data_dict):
567 | score = 0
568 | all_num = 0
569 | model_answer_dict = []
570 | model_kpoint = {}
571 | for data in data_dict['example']:
572 | correction_flag = True
573 | question = data["question"]
574 | knowledge_point = data["knowledge_point"]
575 | question_list = []
576 | if knowledge_point == "":
577 | knowledge_point = "其他"
578 | if knowledge_point not in model_kpoint.keys():
579 | model_kpoint[knowledge_point] = [0, 0]
580 | # all_num += len(question)
581 | for sub_question in question:
582 | all_num += 1
583 | model_kpoint[knowledge_point][1] += 1
584 | standard_answer = sub_question['standard_answer']
585 | model_answer = sub_question['model_answer']
586 | if standard_answer == model_answer:
587 | score += 1
588 | model_kpoint[knowledge_point][0] += 1
589 | else:
590 | correction_flag = False
591 | sub_question_dict = {
592 | 'sub_question': sub_question['sub_question'],
593 | 'standard_answer': standard_answer,
594 | 'analysis': sub_question['analysis'],
595 | 'model_answer': model_answer,
596 | 'model_output': sub_question['model_output']
597 | }
598 | question_list.append(sub_question_dict)
599 |         if not correction_flag:  # if any sub-question is wrong, keep the whole item as an error case
600 | dict = {
601 | 'index': data["index"],
602 | 'share_content': data["share_content"],
603 | 'question': question_list,
604 | 'knowledge_point': data["knowledge_point"],
605 | }
606 | model_answer_dict.append(dict)
607 | output = {'keyword': data_dict["keyword"],
608 | 'correct_num': score,
609 | 'all_num': all_num}
610 | if len(model_answer_dict) > 0:
611 | output['example'] = model_answer_dict
612 |
613 | return score / all_num, output, model_kpoint, score, all_num
614 |
615 |
616 |
617 |
--------------------------------------------------------------------------------