├── image
│   ├── init
│   ├── few-shot.png
│   ├── logo_00.png
│   ├── zero-shot.png
│   └── TCMBench_logo.png
├── TCMBench_code
│   ├── init
│   ├── sari.py
│   └── explain_evaluation.py
├── pipline
│   ├── __pycache__
│   │   ├── Model_API.cpython-310.pyc
│   │   └── bench_function.cpython-310.pyc
│   ├── choice_bench.py
│   ├── Model_API.py
│   ├── Acc.py
│   └── bench_function.py
├── README_Chinese.md
└── README.md
/image/init: -------------------------------------------------------------------------------- 1 | 图表图片文件夹~ 2 | -------------------------------------------------------------------------------- /TCMBench_code/init: -------------------------------------------------------------------------------- 1 | metric需要的模型文件见 https://huggingface.co/WJing123/TCMBench_code 2 | -------------------------------------------------------------------------------- /image/few-shot.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ywjawmw/TCMBench/HEAD/image/few-shot.png -------------------------------------------------------------------------------- /image/logo_00.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ywjawmw/TCMBench/HEAD/image/logo_00.png -------------------------------------------------------------------------------- /image/zero-shot.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ywjawmw/TCMBench/HEAD/image/zero-shot.png -------------------------------------------------------------------------------- /image/TCMBench_logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ywjawmw/TCMBench/HEAD/image/TCMBench_logo.png -------------------------------------------------------------------------------- /pipline/__pycache__/Model_API.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ywjawmw/TCMBench/HEAD/pipline/__pycache__/Model_API.cpython-310.pyc -------------------------------------------------------------------------------- /pipline/__pycache__/bench_function.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ywjawmw/TCMBench/HEAD/pipline/__pycache__/bench_function.cpython-310.pyc -------------------------------------------------------------------------------- /pipline/choice_bench.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | parent_path = os.path.dirname(sys.path[0]) 4 | print(parent_path) 5 | if parent_path not in sys.path: 6 | sys.path.append(parent_path) 7 | # from LLAMAAPI import LlamaAPI 8 | from Model_API import API 9 | from bench_function import export_distribute_json, export_union_json 10 | import json 11 | import argparse 12 | 13 | def parse_args(): 14 | parser = argparse.ArgumentParser(description="parameter of LLMs") 15 | parser.add_argument( 16 | "--data_path", 17 | type=str, 18 | default='../data/first_level', # 注意按照三个不同的级别进行选择不同的测试任务 19 | help="测试数据", 20 | ) 21 | parser.add_argument( 22 | "--model_name", 23 | type=str, 24 | default='gpt-4-0613', 25 | help="The LLM name.", 26 | ) 27 | parser.add_argument( 28 | "--sys_prompt", 29 | type=str, 30 | default='FKU.json', 31 | help="选择不同测试题类型的指令.", 32 | ) 33 | parser.add_argument( 34 | "--start_num", 35 | type=int, 36 | default=0, 37 | help="保存文档的起始id", 38 | ) 39 | args = 
parser.parse_args() 40 | return args 41 | 42 | # 测试主函数, 一个问题+一个答案 --》 bench_function 43 | 44 | if __name__ == "__main__": 45 | args = parse_args() 46 | with open(f"{args.data_path}/{args.sys_prompt}", "r", encoding="utf-8") as f: 47 | data = json.load(f) 48 | f.close() 49 | directory = args.data_path 50 | model_name = args.model_name 51 | api_key = "XXX" # if closed model, using API key, else 为空 52 | api = API(api_key, model_name=model_name) 53 | # keyword = data[i]['keyword'] 54 | question_type = data['type'] 55 | zero_shot_prompt_text = data['prefix_prompt'] 56 | # print(question_type) 57 | print(question_type) 58 | export_distribute_json( 59 | api, 60 | model_name, 61 | directory, 62 | zero_shot_prompt_text, 63 | question_type, 64 | args, 65 | parallel_num=len(data['example']), 66 | ) 67 | 68 | export_union_json( 69 | directory, 70 | model_name, 71 | zero_shot_prompt_text, 72 | question_type 73 | ) 74 | 75 | 76 | 77 | # export JAVA_HOME=/usr/lib/jvm/java-11-openjdk-amd64 78 | # export JRE_HOME=${JAVA_HOME}/jre 79 | # export CLASSPATH=.:${JAVA_HOME}/lib:${JRE_HOME}/lib 80 | # export PATH=${JAVA_HOME}/bin:$PATH 81 | 82 | -------------------------------------------------------------------------------- /README_Chinese.md: -------------------------------------------------------------------------------- 1 | 2 | ## TCMBench: Benchmarking Large Language Models in Traditional Chinese Medicine from Knowledge to Clinical Reasoning 3 | Repo for TCMBench (“ShuzhiQihuang” LLMs series,The first comprehensive benchmark for evaluating LLMs in TCM) 4 | 5 | [**English**](./README.md) | [**中文**](./README_Chinese.md) 6 | 7 |
16 | 17 | ## 更新 18 | 19 | 💥 **TCMBench V2.0**来啦,这次加入了能体现中医多标准多因素的动态临床推理过程的测试题目外,还新生成了加入推理扰动的新问题,构成了三层不同难度的测评任务,13个子任务! 20 | 21 | 🚀 论文初始版本已经公开,欢迎引用,❗ 拒绝一切抄袭行为(微笑.jpg). 22 | 23 | ## ⚡ 简介 24 | 为了进一步有效、准确的评估大模型在中医药领域的表现,我们现建立了一个标准化、综合性的中医评测框架**TCMBench**,该评测框架将充分考虑中医药领域的复杂性和专业性,涵盖多个方面,以确保大语言模型在真实场景下的实用性和适用性。 25 | 26 | 27 | ## 📚 数据集:TCMEval 28 | 首先我们构建了首个中医评测数据集TCMEval。为了客观、真实地反映中医领域的知识体系与临床推理特点,以中医执业医师资格考试的高质量模拟题为数据来源,构建了评测数据集TCMEval。该数据集共包含6,482组问答样本,其中1,300组配有官方给定的标准解析文本,用于评估大语言模型的生成质量。所有数据不涉及个人隐私,内容聚焦于中医知识临床两方面。在中医专家的指导下,本章对原始题目进行筛选与确认,从每个学科每类题型中随机抽取了不超过100个样本,同时控制选项的均匀分布,避免数据偏斜。然后由两位中医研究生进行题目确认,确保覆盖考试的全部题型与学科。通过收集、整理和标注这些数据,我们旨在提供一个全面、准确、具有代表性的中医测试基准,来帮助评估和改大语言模型应用在中医领域性能。 29 | 30 | 31 | **🔎 任务类型** : 32 | - 🚀 **基础知识认知任务**:最低复杂度的任务为基础知识认知任务集,共包含5,473组问答样本。该任务集依据TCMLE考试中的标准题型,细分为三类代表性任务,分别对应不同维度的知识认知能力。涵盖基础知识理解任务(FKU)、知识点横向关联任务(KHC)以及临床逻辑纵向推理任务(CVR),数据见[./data/first_level](./data/first_level)。 33 | - 🚀 **综合动态临床分析任务** :在基础知识认知任务之上,结合中医专家建议进一步构建了六类具备多标准诊疗(包括辨证论治、同病异治、异病同治任务)与环境耦合多因素特征(包括社会环境、古籍经典理解以及哲学思想掌握任务)的临床推理任务集,用于评估大语言模型在真实中医临床情境下的知识融合、逻辑建模与复杂推理能力,共883组问答样本。其中每个任务均包含不少于50个样本,以支持稳定评估,数据见[./data/second_level](./data/second_level)。 34 | - 🚀 **复杂临床决策任务** :为进一步评估大语言模型在多标准中医临床环境中的综合推理能力和稳定性,在综合动态临床分析任务的基础上设计了四类复杂临床决策任务。该任务集通过对原始推理样本进行结构重构、语义扰动和任务形式转换(包括在辨证论治、同病异治以及异病同治任务上的扰动,以及中药推荐任务),生成了1,009组全新的问答样本。该任务集系统考察了模型在高复杂度推理中的一致性建模能力、推理鲁棒性与决策稳定性。数据见[./data/third_level](./data/third_level)。 35 | 36 | ## 👨‍⚕️ 数据处理 37 | 38 | ### 以基础知识认知任务为例 39 | 将试题转换为结构化的测评数据,其数据格式如下所示: 40 | 基础知识理解任务: 41 | ```json 42 | { 43 | "question": "《素问·咳论》:“五脏六腑皆令人咳”,但关系最密切的是( )。\nA.心肺\nB.肺肾\nC.肺脾\nD.肺胃\nE.肺大肠", 44 | "answer": [ 45 | "D" 46 | ], 47 | "analysis": "根据《素问·咳论》“此皆聚于胃,关于肺,使人多涕唾而面浮肿气逆也”可知与五脏六腑皆令人咳关系最密切的脏腑为肺胃。手太阴肺经起于中焦,还循胃口,上膈属肺。寒凉饮食入胃,导致中焦寒,寒气循手太阴肺经上入于肺中,导致肺寒,肺为娇脏,不耐寒热,外内寒邪并聚于肺,则肺失宣降,肺气上逆发生咳嗽。因此答案选D。", 48 | "knowledge_point": "中医经典", 49 | "index": 8196, 50 | "score": 1 51 | } 52 | ``` 53 | 临床逻辑纵向推理任务: 54 | ```json 55 | { 56 | "share_content": "刘×,男,46岁,刻下眩晕而见头重如蒙。胸闷恶心,食少多寐,苔白腻,脉濡滑。", 57 | "question": [ 58 | { 59 | "sub_question": "1).证属( )。\nA.肝阳上亢\nB.气血亏虚\nC.肾精不足\nD.痰浊中阻\nE.以上都不是\n", 60 | "answer": [ 61 | "D" 62 | ], 63 | "analysis": "" 64 | }, 65 | { 66 | "sub_question": "2).治法宜选( )。\nA.燥湿祛痰,健脾和胃\nB.补肾滋阴\nC.补肾助阳\nD.补养气血,健运脾胃\nE.平肝潜阳,滋养肝肾\n", 67 | "answer": [ 68 | "A" 69 | ], 70 | "analysis": "" 71 | }, 72 | { 73 | "sub_question": "3).方药宜选( )。\nA.右归丸\nB.左归丸\nC.半夏白术天麻汤\nD.归脾汤\nE.天麻钩藤饮\n", 74 | "answer": [ 75 | "C" 76 | ], 77 | "analysis": "" 78 | } 79 | ], 80 | "knowledge_point": "中医内科学", 81 | "index": 334, 82 | "score": 1 83 | } 84 | ``` 85 | 知识点横向关联任务: 86 | ```json 87 | { 88 | "share_content": "(共用备选答案)\nA.化痰息风,健脾祛湿\nB.清肺化痰,散结排脓\nC.疏风宣肺,化痰止咳\nD.清热化痰,平肝息风\nE.润肺清热,理气化痰\n", 89 | "question": [ 90 | { 91 | "sub_question": "1).贝母瓜蒌散的功用是( )。", 92 | "answer": [ 93 | "E" 94 | ], 95 | "analysis": "" 96 | }, 97 | { 98 | "sub_question": "2).半夏白术天麻汤的功用是( )。", 99 | "answer": [ 100 | "A" 101 | ], 102 | "analysis": "" 103 | } 104 | ], 105 | "knowledge_point": "方剂学", 106 | "index": 1938, 107 | "score": 1 108 | } 109 | ``` 110 | 111 | ## 🧐 测评细节 112 | 113 | 我们设计了任务自适应的prompt,要求LLM回答题目,并给出答案和分析,评测框架由如下部分组成: 114 | 115 | | 文件名 | 说明 | 116 | | -------------------------- | -------------- | 117 | | [./pipline/choice_bench.py](./pipline/choice_bench.py) | 设置不同的任务,引导LLMs生成答案与解析 | 118 | | [./pipline/bench_function.py](./pipline/bench_function.py) | 测试相关函数 | 119 | | [./pipline/Acc.py](./pipline/Acc.py) | 计算准确率 | 120 | | [./pipline/Model_API.py](./pipline/Model_API.py)| 调用模型接口,以openai为例,可根据测评模型进行调整 | 121 | | 
[./TCMBench_code/explain_evaluation.py](./TCMBench_code/explain_evaluation.py)| 采用ROUGE-1、ROUGE-L、SARI、BERTScore、BARTScore,以及我们提出的SKScore评估模型解析质量 | 122 | |[./HumanTrue.json](./HumanTrue.json)| HumanTrue数据集| 123 | 124 | 125 | 首先采用[./pipline/choice_bench.py](./pipline/choice_bench.py)测试模型,得到模型生成的答案与解析: 126 | ``` 127 | 首先若有必要,请设置代理: 128 | os.environ['HTTPS_PROXY']="your proxy" 129 | 其次,若采用闭源模型则将你的Key填写到指定位置,否则置空: 130 | api_key = "your key" 131 | 然后通过设置不同的--data_path 和 --sys_prompt 进行不同任务测试,设置--model_name 来调用不同的模型。 132 | 使用以下命令在FKU任务上对gpt-4-0613进行测试: 133 | python choice_bench.py --data_path ../data/first_level --sys_prompt FKU.json --model_name gpt-4-0613 134 | ``` 135 | 136 | 通过[./pipline/Acc.py](./pipline/Acc.py),设置--data_path、--question_type、--model_name,得到不同模型在不同任务上的准确率得分。 137 | ``` 138 | python Acc.py --data_path ../data/first_level --question_type FKU --model_name gpt-4-0613 139 | ``` 140 | 141 | 通过[./TCMBench_code/explain_evaluation.py](./TCMBench_code/explain_evaluation.py),设置--model_name,得到不同模型在6个指标上的解析得分。 142 | ``` 143 | python explain_evaluation.py --model_name gpt-4-0613 144 | ``` 145 | 其中指标加载的模型见[TCMBench_code](https://huggingface.co/WJing123/TCMBench_code)。 146 | 147 | 148 | 👨‍⚕️ 此外,该工作中也介绍了我们之前构建的中医LLMs——ShenNong,欢迎大家关注我们的中医大模型开源项目**ShenNong-TCM-LLM**: 149 | 150 | - 🚀 [ShenNong-TCM](https://github.com/ywjawmw/ShenNong-TCM-LLM) : 为推动LLM在中医药领域的发展和落地,提升LLM在中医药方面的知识与回答医学咨询的能力,我们推出了**ShenNong**中医药大规模语言模型,其基于[中医药指令数据集SN-QA](https://huggingface.co/datasets/michaelwzhu/ShenNong_TCM_Dataset)训练。 151 | 152 | 以及我们其他医疗大模型开源项目: 153 | - 🚀 [“医”心医意——智能中医传承创新辅助平台](https://github.com/ywjawmw/AI4TCM-Platform) : 数智岐黄系列平台。针对已有的中医传承平台无法覆盖全面的多模态数据这一挑战,我们构建了更全面的中西医知识图谱;针对中医经验传承效率低这一挑战,我们提出了可解释的药方分析技术来挖掘处方信息,自动分析从症状到中药这一立体诊疗过程并给出分析的科学依据。同时提供了一个公平的辅助平台,让青年医师、中医学生等人群快速掌握先进的中医知识,传承经验。 154 | - 🚀 [ChatMed-Consult](https://huggingface.co/michaelwzhu/ChatMed-Consult) : 基于[中文医疗在线问诊数据集ChatMed_Consult_Dataset](https://huggingface.co/datasets/michaelwzhu/ChatMed_Consult_Dataset)的50w+在线问诊与ChatGPT回复作为训练集。模型主干为[LlaMA-7b](https://github.com/facebookresearch/llama),融合了[Chinese-LlaMA-Alpaca](https://github.com/ymcui/Chinese-LLaMA-Alpaca)的LoRA权重与中文扩展词表,然后再进行基于LoRA的参数高效微调。我们将全部代码都进行了公开。 155 | - 🚀 [PromptCBLUE中文医疗大模型评测基准](https://github.com/michael-wzhu/PromptCBLUE): 将[CBLUE](https://tianchi.aliyun.com/dataset/95414)基准改造为提示学习模式,形成对大模型的中文医疗知识与医疗文本处理能力的评测基准。PromptCBLUE旨在采用一个生成式大模型即可完成医疗NLP相关的各种不同任务,如病历结构化、问诊、病例文书撰写等。 156 | 157 | ## 致谢 158 | 159 | 本项目基于大模型提供的API进行开发,同时参考了大语言模型在高考试题上的测评任务,在此对相关项目和研究开发人员表示感谢。 160 | 161 | - [ChatGPT](https://openai.com/blog/chatgpt) 162 | - [ChatGLM](https://github.com/THUDM/ChatGLM-6B) 163 | - [GaoKao-Bench](https://github.com/OpenLMLab/GAOKAO-Bench) 164 | 165 | 166 | ## 引用 167 | 168 | 如果你使用了本项目的数据或者代码,请声明引用: 169 | 170 | ```bibtex 171 | @misc{yue2023TCMBench, 172 | title={TCMBench: Benchmarking Large Language Models in Traditional Chinese Medicine from Knowledge to Clinical Reasoning}, 173 | author={Wenjing Yue and Ming Guan and Wei Zhu and Xiaoling Wang and Honglin Li}, 174 | year={2023}, 175 | publisher = {GitHub}, 176 | journal = {GitHub repository}, 177 | howpublished = {\url{https://github.com/ywjawmw/TCMBench}}, 178 | } 179 | 180 | ``` 181 | 182 | ## 团队介绍 183 | 184 | 本项目在华东师范大学计算机科学与技术学院智能知识管理与服务团队王晓玲教授,以及华东师范大学药学院院长、人工智能新药创智中心主任李洪林教授的指导下完成。 185 | 186 | 187 | 188 | -------------------------------------------------------------------------------- /TCMBench_code/sari.py: -------------------------------------------------------------------------------- 1 | # 
======================================================= 2 | # SARI -- Text Simplification Tunable Evaluation Metric 3 | # ======================================================= 4 | # 5 | # Author: Wei Xu (UPenn xwe@cis.upenn.edu) 6 | # 7 | # A Python implementation of the SARI metric for text simplification 8 | # evaluation in the following paper 9 | # 10 | # "Optimizing Statistical Machine Translation for Text Simplification" 11 | # Wei Xu, Courtney Napoles, Ellie Pavlick, Quanze Chen and Chris Callison-Burch 12 | # In Transactions of the Association for Computational Linguistics (TACL) 2015 13 | # 14 | # There is also a Java implementation of the SARI metric 15 | # that is integrated into the Joshua MT Decoder. It can 16 | # be used for tuning Joshua models for a real end-to-end 17 | # text simplification model. 18 | # 19 | 20 | from __future__ import division 21 | from collections import Counter 22 | import sys 23 | 24 | 25 | 26 | def ReadInFile (filename): 27 | 28 | with open(filename) as f: 29 | lines = f.readlines() 30 | lines = [x.strip() for x in lines] 31 | return lines 32 | 33 | 34 | def SARIngram(sgrams, cgrams, rgramslist, numref): 35 | rgramsall = [rgram for rgrams in rgramslist for rgram in rgrams] 36 | rgramcounter = Counter(rgramsall) 37 | 38 | sgramcounter = Counter(sgrams) 39 | sgramcounter_rep = Counter() 40 | for sgram, scount in sgramcounter.items(): 41 | sgramcounter_rep[sgram] = scount * numref 42 | 43 | cgramcounter = Counter(cgrams) 44 | cgramcounter_rep = Counter() 45 | for cgram, ccount in cgramcounter.items(): 46 | cgramcounter_rep[cgram] = ccount * numref 47 | 48 | 49 | # KEEP 50 | keepgramcounter_rep = sgramcounter_rep & cgramcounter_rep 51 | keepgramcountergood_rep = keepgramcounter_rep & rgramcounter # if sen 和 ref相同,那么这个值和keepgramcounter_rep是一样的,也就是计算s和c的相似(和r的相似) 52 | keepgramcounterall_rep = sgramcounter_rep & rgramcounter # if sen 和 ref相同,那么这个值和ref是一样的 (计算s和r的相似) 53 | 54 | keeptmpscore1 = 0 55 | keeptmpscore2 = 0 56 | for keepgram in keepgramcountergood_rep: 57 | keeptmpscore1 += keepgramcountergood_rep[keepgram] / keepgramcounter_rep[keepgram] 58 | keeptmpscore2 += keepgramcountergood_rep[keepgram] / keepgramcounterall_rep[keepgram] 59 | #print "KEEP", keepgram, keepscore, cgramcounter[keepgram], sgramcounter[keepgram], rgramcounter[keepgram] 60 | keepscore_precision = 0 61 | if len(keepgramcounter_rep) > 0: 62 | keepscore_precision = keeptmpscore1 / len(keepgramcounter_rep) 63 | keepscore_recall = 0 64 | if len(keepgramcounterall_rep) > 0: 65 | keepscore_recall = keeptmpscore2 / len(keepgramcounterall_rep) 66 | keepscore = 0 67 | if keepgramcounterall_rep == rgramcounter: # if sen 和 ref相同 68 | keepscore = keepscore_recall 69 | else: 70 | if keepscore_precision > 0 or keepscore_recall > 0: 71 | keepscore = 2 * keepscore_precision * keepscore_recall / (keepscore_precision + keepscore_recall) # f1 score 72 | 73 | 74 | # DELETION 75 | delgramcounter_rep = sgramcounter_rep - cgramcounter_rep # s和c的差异值 76 | delgramcountergood_rep = delgramcounter_rep - rgramcounter # s和c和r的差异值, if s=r, 这个值=0 77 | delgramcounterall_rep = sgramcounter_rep - rgramcounter # s和r的差异值, if s=r, 这个值=0 78 | deltmpscore1 = 0 79 | deltmpscore2 = 0 80 | for delgram in delgramcountergood_rep: 81 | deltmpscore1 += delgramcountergood_rep[delgram] / delgramcounter_rep[delgram] 82 | deltmpscore2 += delgramcountergood_rep[delgram] / delgramcounterall_rep[delgram] 83 | delscore_precision = 0 84 | if len(delgramcounter_rep) > 0: 85 | delscore_precision = deltmpscore1 / len(delgramcounter_rep) 
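    # Note: the deletion recall below reuses deltmpscore1 (deltmpscore2 ends up
    # unused), and SARIngram ultimately returns delscore_precision rather than
    # the combined delscore, so the deletion component of SARI here is
    # effectively a precision-only score.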
86 | delscore_recall = 0 87 | if len(delgramcounterall_rep) > 0: 88 | delscore_recall = deltmpscore1 / len(delgramcounterall_rep) 89 | delscore = 0 90 | if delscore_precision > 0 or delscore_recall > 0: 91 | delscore = 2 * delscore_precision * delscore_recall / (delscore_precision + delscore_recall) 92 | 93 | 94 | # ADDITION 95 | addgramcounter = set(cgramcounter) - set(sgramcounter) 96 | addgramcountergood = set(addgramcounter) & set(rgramcounter) 97 | addgramcounterall = set(rgramcounter) - set(sgramcounter) 98 | 99 | addtmpscore = 0 100 | for addgram in addgramcountergood: 101 | addtmpscore += 1 102 | 103 | addscore_precision = 0 104 | addscore_recall = 0 105 | if len(addgramcounter) > 0: 106 | addscore_precision = addtmpscore / len(addgramcounter) 107 | if len(addgramcounterall) > 0: 108 | addscore_recall = addtmpscore / len(addgramcounterall) 109 | addscore = 0 110 | if addscore_precision > 0 or addscore_recall > 0: 111 | addscore = 2 * addscore_precision * addscore_recall / (addscore_precision + addscore_recall) 112 | 113 | return (keepscore, delscore_precision, addscore) 114 | 115 | 116 | def SARIsent (ssent, csent, rsents) : 117 | numref = len(rsents) 118 | 119 | # s1grams = ssent.lower().split(" ") 120 | # c1grams = csent.lower().split(" ") 121 | # zh 122 | s1grams = [s for s in ssent] 123 | c1grams = [c for c in csent] 124 | s2grams = [] 125 | c2grams = [] 126 | s3grams = [] 127 | c3grams = [] 128 | s4grams = [] 129 | c4grams = [] 130 | 131 | r1gramslist = [] 132 | r2gramslist = [] 133 | r3gramslist = [] 134 | r4gramslist = [] 135 | for rsent in rsents: 136 | # r1grams = rsent.lower().split(" ") 137 | r1grams = [r for r in rsent] 138 | r2grams = [] 139 | r3grams = [] 140 | r4grams = [] 141 | r1gramslist.append(r1grams) 142 | for i in range(0, len(r1grams)-1) : 143 | if i < len(r1grams) - 1: 144 | # r2gram = r1grams[i] + " " + r1grams[i+1] 145 | r2gram = r1grams[i] + r1grams[i + 1] 146 | r2grams.append(r2gram) 147 | if i < len(r1grams)-2: 148 | # r3gram = r1grams[i] + " " + r1grams[i+1] + " " + r1grams[i+2] 149 | r3gram = r1grams[i] + r1grams[i + 1] + r1grams[i + 2] 150 | r3grams.append(r3gram) 151 | if i < len(r1grams)-3: 152 | # r4gram = r1grams[i] + " " + r1grams[i+1] + " " + r1grams[i+2] + " " + r1grams[i+3] 153 | r4gram = r1grams[i] + r1grams[i + 1] + r1grams[i + 2] + r1grams[i + 3] 154 | r4grams.append(r4gram) 155 | r2gramslist.append(r2grams) 156 | r3gramslist.append(r3grams) 157 | r4gramslist.append(r4grams) 158 | 159 | for i in range(0, len(s1grams)-1) : 160 | if i < len(s1grams) - 1: 161 | # s2gram = s1grams[i] + " " + s1grams[i+1] 162 | s2gram = s1grams[i] + s1grams[i + 1] 163 | s2grams.append(s2gram) 164 | if i < len(s1grams)-2: 165 | # s3gram = s1grams[i] + " " + s1grams[i+1] + " " + s1grams[i+2] 166 | s3gram = s1grams[i] + s1grams[i + 1] + s1grams[i + 2] 167 | s3grams.append(s3gram) 168 | if i < len(s1grams)-3: 169 | # s4gram = s1grams[i] + " " + s1grams[i+1] + " " + s1grams[i+2] + " " + s1grams[i+3] 170 | s4gram = s1grams[i] + s1grams[i + 1] + s1grams[i + 2] + s1grams[i + 3] 171 | s4grams.append(s4gram) 172 | 173 | for i in range(0, len(c1grams)-1) : 174 | if i < len(c1grams) - 1: 175 | # c2gram = c1grams[i] + " " + c1grams[i+1] 176 | c2gram = c1grams[i] + c1grams[i + 1] 177 | c2grams.append(c2gram) 178 | if i < len(c1grams)-2: 179 | # c3gram = c1grams[i] + " " + c1grams[i + 1] + " " + c1grams[i + 2] 180 | c3gram = c1grams[i] + c1grams[i + 1] + c1grams[i + 2] 181 | c3grams.append(c3gram) 182 | if i < len(c1grams)-3: 183 | # c4gram = c1grams[i] + " " + c1grams[i+1] + 
" " + c1grams[i+2] + " " + c1grams[i+3] 184 | c4gram = c1grams[i] + c1grams[i + 1] + c1grams[i + 2] + c1grams[i + 3] 185 | c4grams.append(c4gram) 186 | 187 | 188 | (keep1score, del1score, add1score) = SARIngram(s1grams, c1grams, r1gramslist, numref) 189 | (keep2score, del2score, add2score) = SARIngram(s2grams, c2grams, r2gramslist, numref) 190 | (keep3score, del3score, add3score) = SARIngram(s3grams, c3grams, r3gramslist, numref) 191 | (keep4score, del4score, add4score) = SARIngram(s4grams, c4grams, r4gramslist, numref) 192 | avgkeepscore = sum([keep1score,keep2score,keep3score,keep4score])/4 193 | avgdelscore = sum([del1score,del2score,del3score,del4score])/4 194 | avgaddscore = sum([add1score,add2score,add3score,add4score])/4 195 | finalscore = (avgkeepscore + avgdelscore + avgaddscore ) / 3 196 | 197 | return avgkeepscore, avgdelscore, avgaddscore, finalscore 198 | 199 | 200 | def main(): 201 | 202 | # fnamenorm = "./turkcorpus/test.8turkers.tok.norm" 203 | # fnamesimp = "./turkcorpus/test.8turkers.tok.simp" 204 | # fnameturk = "./turkcorpus/test.8turkers.tok.turk." 205 | 206 | 207 | ssent = "About 95 species are currently accepted ." 208 | csent1 = "About 95 you now get in ." 209 | csent2 = "About 95 species are now agreed ." 210 | csent3 = "About 95 species are currently agreed ." 211 | rsents = ["About 95 species are currently accepted ."] 212 | # rsents = ["About 95 species are currently known .", "About 95 species are now accepted .", "95 species are now accepted ."] 213 | 214 | print(SARIsent(ssent, csent1, rsents)) 215 | print(SARIsent(ssent, csent2, rsents)) 216 | print(SARIsent(ssent, csent3, rsents)) 217 | 218 | 219 | if __name__ == '__main__': 220 | main() -------------------------------------------------------------------------------- /pipline/Model_API.py: -------------------------------------------------------------------------------- 1 | # --*-- conding:utf-8 --*-- 2 | # @Time : 2024/1/11 17:21 3 | # @Author : YWJ 4 | # @Email : 52215901025@stu.ecnu.edu.cn 5 | # @File : Model_API.py 6 | # @Software : PyCharm 7 | # @Description : 各个模型的API接口 8 | import os 9 | import openai 10 | import requests 11 | import urllib 12 | import json 13 | import time 14 | import sys 15 | from http import HTTPStatus 16 | import dashscope 17 | import random 18 | # from vllm import LLM, SamplingParams 19 | # from transformers import AutoModelForCausalLM, AutoTokenizer 20 | from dashscope import Generation 21 | from dashscope.api_entities.dashscope_response import Role 22 | 23 | class API(): 24 | def __init__(self, api_key_list: str, model_name: str = "gpt-3.5-turbo", temperature: float = 0.0, 25 | max_tokens: int = 1024): 26 | self.api_key_list = api_key_list 27 | self.model_name = model_name # 新的model, 支持1w+ 28 | self.temperature = temperature 29 | self.max_tokens = max_tokens 30 | # if self.api_key_list == "": 31 | # self.llm = LLM("/data/xxx", tokenizer_mode='auto', # local model 32 | # trust_remote_code=True, 33 | # enforce_eager=True, 34 | # enable_prefix_caching=True) 35 | # self.sampling_params = SamplingParams(temperature=0.85, top_p=0.8, max_tokens=512) 36 | # self.tokenizer = AutoTokenizer.from_pretrained("/data/xxx", trust_remote_code=True) 37 | 38 | # GPT系列 39 | def send_request_turbo(self, prompt, question): 40 | """ 41 | """ 42 | zero_shot_prompt_message = {'role': 'system', 'content': prompt} 43 | 44 | messages = [zero_shot_prompt_message] 45 | question = sensitive(question) 46 | message = {"role": "user", "content": question} 47 | print(f"LLM的Prompt是{'*' * 
100}\n{zero_shot_prompt_message['content']}\n{message['content']}") 48 | messages.append(message) 49 | 50 | output = {} 51 | while True: 52 | try: 53 | # os.environ['HTTPS_PROXY'] = "http://127.0.0.1:10809" 54 | openai.api_base = "https://xiaoai.plus/v1" 55 | openai.api_key = self.api_key_list 56 | output = openai.ChatCompletion.create( 57 | model=self.model_name, 58 | messages=messages, 59 | temperature=self.temperature 60 | ) 61 | answer = revers_sensitive(output['choices'][0]['message']['content']) 62 | print(answer) 63 | return [answer] 64 | except Exception as e: 65 | print('Exception:', e) 66 | print("原始Prompt:") 67 | sys.exit() 68 | 69 | return [output] 70 | 71 | # 多轮会话 72 | def send_request_chat(self, prompt, share_content, question): 73 | """ 74 | """ 75 | zero_shot_prompt_message = {'role': 'system', 'content': prompt} 76 | 77 | messages = [zero_shot_prompt_message] 78 | share_content = sensitive(share_content) 79 | message = {"role": "user", "content": share_content} 80 | messages.append(message) 81 | output_chat = [] 82 | i = 0 83 | error_num = 0 84 | while i < len(question): 85 | sub_question = question[i] 86 | sub_question['sub_question'] = sensitive(sub_question['sub_question']) 87 | message = {"role": "user", "content": sub_question['sub_question']} 88 | messages.append(message) 89 | # os.environ['HTTPS_PROXY'] = "http://127.0.0.1:33210" 90 | # os.environ['HTTPS_PROXY'] = "http://127.0.0.1:10809" 91 | # os.environ['OPENAI_API_KEY'] = self.api_key_list 92 | # os.environ["OPENAI_BASE_URL"] = "https://api.xiaoai.plus/v1" 93 | openai.api_base = "https://xiaoai.plus/v1" 94 | openai.api_key = self.api_key_list 95 | try: 96 | output = openai.ChatCompletion.create( 97 | model=self.model_name, 98 | messages=messages, 99 | temperature=self.temperature 100 | ) 101 | answer = revers_sensitive(output['choices'][0]['message']['content']) 102 | answer = sensitive(answer) 103 | messages.append({"role": "assistant", "content": answer}) 104 | output_chat.append(answer) 105 | i += 1 106 | print(i, ":", "success!") 107 | # print(output) 108 | except Exception as e: 109 | print('Exception:', e) 110 | print("原始Prompt:") 111 | for m in messages: 112 | print(m) 113 | print("—" * 100) 114 | # if "overloaded" or "Bad" in e: 115 | if "max" in e.args[0]: # 说明到了最大的token, 将上面存储的靠前的子问题删除几个 116 | time.sleep(5) 117 | if error_num == 0: 118 | if len(messages) < 13: 119 | star_index = -1 * len(messages) + 2 120 | else: 121 | star_index = -11 # 前5个 122 | else: 123 | star_index += 2 # 如果还超长,那么就不断的逐个删除子问题 124 | if star_index >= -1: 125 | print("无法处理该问题") 126 | output_chat.append("") 127 | error_num = 0 128 | i += 1 129 | print("#" * 100) 130 | messages = messages[:2] + messages[star_index: -1] 131 | print("最大token, 保留历史前几个问题") 132 | error_num = 1 133 | for m in messages: 134 | print(m) 135 | print("*" * 100) 136 | else: 137 | time.sleep(5) # 递归调用自身进行重试(i不变) 138 | print("重复提问") 139 | messages = messages[:-1] 140 | for m in messages: 141 | print(m) 142 | print("*" * 100) 143 | error_num = 0 144 | # output_chat.append({}) 145 | # i += 1 146 | # print("失败,默认回答不出内容!") 147 | time.sleep(5) 148 | 149 | time.sleep(2) 150 | 151 | return output_chat 152 | 153 | # DDST hard 154 | def send_request_hard(self, prompt, option, question, option_num): 155 | """ 156 | """ 157 | prompt = prompt.replace("", option).replace("", str(option_num)) 158 | zero_shot_prompt_message = {'role': 'system', 'content': prompt} 159 | 160 | messages = [zero_shot_prompt_message] 161 | question = sensitive(question) 162 | message = {"role": "user", 
"content": question} 163 | print(f"LLM的Prompt是{'*' * 100}\n{message['content']}") 164 | messages.append(message) 165 | 166 | output = {} 167 | while True: 168 | try: 169 | # os.environ['HTTPS_PROXY'] = "http://127.0.0.1:10809" 170 | openai.api_base = "https://xiaoai.plus/v1" 171 | openai.api_key = self.api_key_list 172 | output = openai.ChatCompletion.create( 173 | model=self.model_name, 174 | messages=messages, 175 | temperature=self.temperature 176 | ) 177 | answer = revers_sensitive(output['choices'][0]['message']['content']) 178 | print(answer) 179 | return [answer] 180 | except Exception as e: 181 | print('Exception:', e) 182 | print("原始Prompt:") 183 | sys.exit() 184 | # herb_predict 185 | def send_request_TCM_Rec(self, prompt, question): 186 | """ 187 | """ 188 | zero_shot_prompt_message = {'role': 'system', 'content': "你是一个中药推荐系统,你需要根据症状信息推荐20个中药。"} 189 | 190 | messages = [zero_shot_prompt_message] 191 | question = sensitive(question) 192 | question = f"{prompt}\n症状信息为:{question}。\n" 193 | message = {"role": "user", "content": question} 194 | print(f"LLM的Prompt是{'*' * 100}\n{zero_shot_prompt_message['content']}\n{message['content']}") 195 | messages.append(message) 196 | 197 | output = {} 198 | while True: 199 | try: 200 | # os.environ['HTTPS_PROXY'] = "http://127.0.0.1:10809" 201 | openai.api_base = "https://xiaoai.plus/v1" 202 | openai.api_key = self.api_key_list 203 | output = openai.ChatCompletion.create( 204 | model=self.model_name, 205 | messages=messages, 206 | temperature=self.temperature 207 | ) 208 | answer = revers_sensitive(output['choices'][0]['message']['content']) 209 | print(answer) 210 | return [answer] 211 | except Exception as e: 212 | print('Exception:', e) 213 | print("原始Prompt:") 214 | sys.exit() 215 | 216 | def sensitive(sentence): 217 | sentence = sentence.replace("阴道", "term-YD") 218 | sentence = sentence.replace("射精", "term-SJ") 219 | return sentence 220 | 221 | def revers_sensitive(sentence): 222 | sentence = sentence.replace("term-YD", "阴道") 223 | sentence = sentence.replace("term-SJ", "射精") 224 | return sentence -------------------------------------------------------------------------------- /pipline/Acc.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # !/usr/bin/python3 3 | # -*- coding: utf-8 -*- 4 | # @Time : 2025/9/28 21:55 5 | # @Author : Wenjing 6 | # @File : Acc.py 7 | # @Desc : 计算各个任务的准确率 8 | 9 | 10 | import json 11 | import os 12 | from bench_function import test_correction_score_A12, test_correction_score_A34 13 | import matplotlib.pyplot as plt 14 | import numpy as np 15 | import argparse 16 | 17 | 18 | def parse_args(): 19 | parser = argparse.ArgumentParser(description="parameter of LLMs") 20 | parser.add_argument( 21 | "--data_path", 22 | type=str, 23 | default='../data/first_level', # 注意按照三个不同的级别进行选择不同的测试任务 24 | help="测试数据", 25 | ) 26 | parser.add_argument( 27 | "--model_name", 28 | type=str, 29 | default='gpt-4-0613', 30 | help="The LLM name.", 31 | ) 32 | parser.add_argument( 33 | "--question_type", 34 | type=str, 35 | default='FKU', 36 | help="选择不同测试题类型.", 37 | ) 38 | parser.add_argument( 39 | "--start_num", 40 | type=int, 41 | default=0, 42 | help="保存文档的起始id", 43 | ) 44 | args = parser.parse_args() 45 | return args 46 | 47 | def the_first_task(data_FKU, data_CVR, data_KHC): 48 | score_A12, false_dict_A12, A12_kp_dict, score_num_A12, all_num_A12, correct_dict_A12 = test_correction_score_A12(data_FKU) 49 | score_A34, false_dict_A34, A34_kp_dict, score_num_A34, all_num_A34 
= test_correction_score_A34(data_CVR) 50 | score_B1, false_dict_B1, B1_kp_dict, score_num_B1, all_num_B1 = test_correction_score_A34(data_KHC) 51 | 52 | print("基础知识认知任务测试结果") 53 | print("A1-A2题目正确率:%f \nA3-A4题目正确率:%f \nB1题目正确率:%f \n" % (score_A12, score_A34, score_B1)) 54 | print("A3-A4-k题目正确率:%f" % score_A34) 55 | # print("A3-A4-k-shot题目正确率:%f" % (score_A34_k)) 56 | print("总的准确率:%f", 57 | (score_num_A12 + score_num_A34 + score_num_B1) / (all_num_A12 + all_num_A34 + all_num_B1)) 58 | 59 | # 创建一个字典来保存相同Key的合并后的值 60 | merged_values = {} 61 | 62 | A1_num, A3_num, B1_num = 0, 0, 0 63 | for k, v in A12_kp_dict.items(): 64 | A1_num += v[1] 65 | 66 | for k, v in A34_kp_dict.items(): 67 | A3_num += v[1] 68 | 69 | for k, v in B1_kp_dict.items(): 70 | B1_num += v[1] 71 | 72 | print(A1_num, A3_num, B1_num) 73 | 74 | # 遍历所有相同的Key,并将它们的值合并成一个列表 75 | num = 0 76 | kp_num = 0 77 | 78 | A12_res_kp, A34_res_kp, B1_res_kp = dict(), dict(), dict() 79 | res_kp = dict() 80 | 81 | print("A12各个知识点的准确率:", "*" * 100) 82 | for key, value in A12_kp_dict.items(): 83 | print(key, "\t", value[0] / value[1]) 84 | A12_res_kp[key] = value[0] / value[1] 85 | 86 | print("A34各个知识点的准确率:", "*" * 100) 87 | for key, value in A34_kp_dict.items(): 88 | print(key, "\t", value[0] / value[1]) 89 | A34_res_kp[key] = value[0] / value[1] 90 | 91 | print("B1各个知识点的准确率:", "*" * 100) 92 | for key, value in B1_kp_dict.items(): 93 | print(key, "\t", value[0] / value[1]) 94 | B1_res_kp[key] = value[0] / value[1] 95 | 96 | print("总的知识点准确率:", "*" * 100) 97 | for key in A12_kp_dict.keys(): 98 | # 从每个字典中获取Key对应的值的列表 99 | values1 = A12_kp_dict.get(key, []) 100 | values2 = A34_kp_dict.get(key, []) 101 | values3 = B1_kp_dict.get(key, []) 102 | if len(values1) == 0: 103 | values1 = [0, 0] 104 | if len(values2) == 0: 105 | values2 = [0, 0] 106 | if len(values3) == 0: 107 | values3 = [0, 0] 108 | 109 | # 将三个字典中相同Key的值合并成一个列表 110 | merged_values[key] = [0, 0, 0.0] 111 | merged_values[key][0] = values1[0] + values2[0] + values3[0] 112 | merged_values[key][1] = values1[1] + values2[1] + values3[1] 113 | merged_values[key][2] = merged_values[key][0] / merged_values[key][1] 114 | # print(key, ": ", merged_values[key][0]/merged_values[key][1]) 115 | print(key, "\t", merged_values[key][0] / merged_values[key][1]) 116 | res_kp[key] = merged_values[key][0] / merged_values[key][1] 117 | num += merged_values[key][1] 118 | print(A12_res_kp) 119 | print(A34_res_kp) 120 | print(B1_res_kp) 121 | print(res_kp) 122 | 123 | def A12_type_task(data_dict): 124 | score = 0 125 | all_num = 0 126 | for data in data_dict['example']: 127 | all_num += 1 128 | true_answer = data['standard_answer'] 129 | model_answer = data['model_answer'] 130 | if true_answer == model_answer: 131 | score += 1 132 | print(score / all_num) 133 | 134 | def A34_type_task(data_dict): 135 | score = 0 136 | all_num = 0 137 | for data in data_dict['example']: 138 | question = data["question"] 139 | for sub_question in question: 140 | all_num += 1 141 | standard_answer = sub_question['standard_answer'] 142 | model_answer = sub_question['model_answer'] 143 | if standard_answer == model_answer: 144 | score += 1 145 | print(score / all_num) 146 | 147 | # 如果两个元素的sym_set列表的交集为sym_set的长度,那么我们将这些元素分为一组,添加到一个新的列表中 148 | def new_test_gt(file_path): 149 | with open(file_path, 'r', encoding='utf-8') as f: 150 | data = json.load(f)["example"] 151 | new_data = {} 152 | visited = set() 153 | count = 0 154 | for i, d1 in enumerate(data): 155 | if i in visited: 156 | continue 157 | sym_set1 = set(d1['sym_set']) 158 | 
sym_set_key = "-".join(sorted(sym_set1)) 159 | new_data[sym_set_key] = [d1["herb_set"]] 160 | visited.add(i) 161 | count += 1 162 | for j in range(i+1, len(data)): 163 | if j in visited: 164 | continue 165 | d2 = data[j] 166 | # 求d1和d2的sym_set的交集 167 | sym_set2 = set(d2['sym_set']) 168 | if len(sym_set1) == len(sym_set2): 169 | if len(sym_set1 & sym_set2) == len(sym_set1): 170 | new_data[sym_set_key].append(d2["herb_set"]) 171 | count += 1 172 | visited.add(j) 173 | # print(count) 174 | return new_data 175 | 176 | 177 | 178 | def test_metric(gt_data, test_data): 179 | test_group_list = test_data["example"] 180 | Ks = [20] 181 | result = {'precision': np.zeros(len(Ks)), 'recall': np.zeros(len(Ks)), 182 | 'ndcg': np.zeros(len(Ks)), 'rmrr': np.zeros(len(Ks))} 183 | 184 | precision_n = np.zeros(len(Ks)) 185 | recall_n = np.zeros(len(Ks)) 186 | ndcg_n = np.zeros(len(Ks)) 187 | rmrr_n = np.zeros(len(Ks)) 188 | topN = Ks 189 | 190 | for index in range(len(test_group_list)): 191 | entry = test_group_list[index] 192 | sym_set = sorted(set(entry["sym_set"])) 193 | if "-".join(sym_set) not in gt_data.keys(): 194 | print(test_group_list[index]) 195 | break 196 | v_list = gt_data["-".join(sym_set)] # sym-index's true herb set list 197 | rating = entry["model_output"] # sym-index's predicted herb set list 198 | K_max = topN[len(topN) - 1] 199 | for ii in range(len(topN)): # topN: [5, 10, 15, 20] 200 | top_recall, top_precision, top_ndcg, top_rmrr, top_iou = 0., 0., 0., 0., 0. 201 | for v in v_list: # v:对应的ground truth 202 | r = [] 203 | for i in rating[:K_max]: 204 | herb = i.replace("\"", "") 205 | if herb in v: 206 | r.append(1) 207 | else: 208 | r.append(0) 209 | number = 0 210 | all_list_number = 0 211 | herb_results = [] # 推荐列表中herb 集合 212 | for i in rating[:topN[ii]]: 213 | herb = i.replace("\"", "") 214 | herb_results.append(herb) 215 | if herb in v: 216 | number += 1 217 | herb_v = set(herb_results + v) 218 | all_list_number = len(herb_v) 219 | # todo: modified MRR to Rank-MRR 220 | mrr_score = 0. 221 | for a_rank in range(len(v)): # herb 在grand truth中的位置a_rank 222 | if v[a_rank] in herb_results: 223 | a_refer = herb_results.index(v[a_rank]) # herb 在推荐列表中的位置a_refer 224 | mrr_score += 1.0 / (abs(a_refer - a_rank) + 1) 225 | if float(number / topN[ii]) > top_precision: # 使用precision选择GT 226 | top_precision = float(number / topN[ii]) 227 | top_recall = float(number / len(v)) 228 | top_ndcg = ndcg_at_k(r, topN[ii]) 229 | top_rmrr = mrr_score / len(v) 230 | precision_n[ii] = precision_n[ii] + top_precision # [ii]所有测试数据top k的precision之和 231 | recall_n[ii] = recall_n[ii] + top_recall 232 | ndcg_n[ii] = ndcg_n[ii] + top_ndcg 233 | rmrr_n[ii] = rmrr_n[ii] + top_rmrr 234 | for ii in range(len(topN)): 235 | result['precision'][ii] = precision_n[ii] / len(test_group_list) 236 | result['recall'][ii] = recall_n[ii] / len(test_group_list) 237 | result['ndcg'][ii] = ndcg_n[ii] / len(test_group_list) 238 | result['rmrr'][ii] = rmrr_n[ii] / len(test_group_list) 239 | return result 240 | 241 | def dcg_at_k(r, k, method=1): 242 | """Score is discounted cumulative gain (dcg) 243 | Relevance is positive real values. Can use binary 244 | as the previous methods. 245 | Returns: 246 | Discounted cumulative gain 247 | """ 248 | r = np.asfarray(r)[:k] 249 | if r.size: 250 | if method == 0: 251 | return r[0] + np.sum(r[1:] / np.log2(np.arange(2, r.size + 1))) 252 | elif method == 1: 253 | return np.sum(r / np.log2(np.arange(2, r.size + 2))) 254 | else: 255 | raise ValueError('method must be 0 or 1.') 256 | return 0. 
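
# Worked example (comment only): for the binary relevance vector r = [1, 0, 1, 0, 0]
# with k = 5 and method=1,
#   dcg_at_k(r, 5) = 1/log2(2) + 1/log2(4) = 1.5
#   the ideal ordering [1, 1, 0, 0, 0] gives DCG = 1/log2(2) + 1/log2(3) ≈ 1.631
#   so ndcg_at_k(r, 5) ≈ 1.5 / 1.631 ≈ 0.92   (ndcg_at_k is defined below)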
257 | 258 | 259 | def ndcg_at_k(r, k, method=1): 260 | """Score is normalized discounted cumulative gain (ndcg) 261 | Relevance is positive real values. Can use binary 262 | as the previous methods. 263 | Returns: 264 | Normalized discounted cumulative gain 265 | """ 266 | # dcg_max = dcg_at_k(np.ones_like(r), k, method) 267 | dcg_max = dcg_at_k(sorted(r, reverse=True), k, method) 268 | if not dcg_max: 269 | return 0. 270 | return dcg_at_k(r, k, method) / dcg_max 271 | 272 | def herb_rec_test(gt_data, data): 273 | result = test_metric(gt_data, data) 274 | res_score = "" 275 | for key, value in result.items(): 276 | res = [str(round(v, 4)) for v in value] 277 | print(key + ":" + ", ".join(res)) 278 | 279 | 280 | if __name__ == "__main__": 281 | args = parse_args() 282 | # 第一级三个任务一起计算,因为需要计算一个总分 283 | if "first_level" in args.data_path: 284 | question_type = "FKU" 285 | with open(f"{args.data_path}/{args.model_name}_{question_type}.json", "r", encoding="utf-8") as f: 286 | data_FKU = json.load(f) 287 | f.close() 288 | 289 | question_type = "CVR" 290 | with open(f"{args.data_path}/{args.model_name}_{question_type}.json", "r", encoding="utf-8") as f: 291 | data_CVR = json.load(f) 292 | f.close() 293 | 294 | question_type = "KHC" 295 | with open(f"{args.data_path}/{args.model_name}_{question_type}.json", "r", encoding="utf-8") as f: 296 | data_KHC = json.load(f) 297 | f.close() 298 | 299 | the_first_task(data_FKU, data_CVR, data_KHC) 300 | 301 | # 利用args.question_type 来分别计算其他任务的ACC 302 | question_type = args.question_type 303 | with open(f"{args.data_path}/{args.model_name}_{question_type}/seperate_0-1.json", "r", encoding="utf-8") as f: 304 | data = json.load(f) 305 | f.close() 306 | if question_type in ["CBF", "SCF", "PF", "SDDT", "DDST", "SDDT_hard", "DDST_hard"]: 307 | A12_type_task(data) 308 | elif question_type in ["SDT", 'SDT_reverse', "SDT_shuffle"]: 309 | A34_type_task(data) 310 | elif question_type in ["herb_predict"]: 311 | gt_file_path = f"{args.data_path}/{question_type}.json" 312 | gt_data = new_test_gt(gt_file_path) 313 | herb_rec_test(gt_data, data) 314 | 315 | 316 | 317 | 318 | 319 | 320 | 321 | 322 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## TCMBench: Benchmarking Large Language Models in Traditional Chinese Medicine from Knowledge to Clinical Reasoning 2 | Repo for TCMBench (“ShuzhiQihuang” LLMs series,The first comprehensive benchmark for evaluating LLMs in TCM) 3 | 4 | [**English**](./README.md) | [**中文**](./README_Chinese.md) 5 | 6 |
15 | 16 | ## Updates 17 | 18 | 💥 TCMBench V2.0 is here! In this version, we added test questions that reflect the multi-standard and multi-factor characteristics of dynamic clinical reasoning in TCM. We also generated new questions with reasoning perturbations, forming three levels of evaluation tasks and 13 sub-tasks in total. 19 | 20 | 🚀 The initial version of the paper has been released. Citations are welcome. ❗ All forms of plagiarism are strictly rejected (smile.jpg). 21 | 22 | ## ⚡ Introduction 23 | 24 | To further evaluate the performance of large language models (LLMs) in Traditional Chinese Medicine (TCM) more effectively and accurately, we established a standardized and comprehensive benchmark framework: TCMBench. This benchmark fully considers the complexity and domain-specific nature of TCM, covering multiple aspects to ensure the practical usability and applicability of LLMs in real-world TCM scenarios. 25 | 26 | 📚 Dataset: TCMEval 27 | 28 | We first constructed the TCMEval dataset, the first benchmark dataset in TCM. To objectively and accurately reflect the knowledge system and clinical reasoning characteristics of TCM, TCMEval is built using high-quality simulated questions from the TCM Licensing Examination as the data source. 29 | 30 | The dataset contains 6,482 question–answer pairs, of which 1,300 pairs are accompanied by official standard explanations for evaluating the generation quality of LLMs. All data avoids personal information and focuses on TCM knowledge and clinical content. 31 | 32 | Under the guidance of TCM experts, the original questions were filtered and confirmed. From each subject and question type, no more than 100 samples were randomly selected while ensuring an even distribution of answer options to avoid data bias. Two graduate students in TCM further verified the questions, ensuring full coverage of all exam subjects and question types. Through this collection, organization, and annotation process, TCMEval provides a comprehensive, accurate, and representative TCM benchmark to support the evaluation and improvement of LLM applications in TCM. 33 | 34 | **🔎 Task Types**: 35 | - 🚀 **Fundamental Knowledge Cognition Tasks**:The lowest complexity level includes 5,473 Q&A pairs. This task set is based on standard question types in the TCMLE exam and subdivided into three representative tasks, each reflecting different dimensions of knowledge cognition: 36 | - Fundamental Knowledge Understanding (FKU) 37 | - Knowledge Horizontal Correlation (KHC) 38 | - Clinical Vertical Reasoning (CVR) 39 | Data available at [./data/first_level](./data/first_level). 40 | - 🚀 **Comprehensive Dynamic Clinical Analysis Tasks**:Built on top of the fundamental knowledge cognition tasks and designed with input from TCM experts, this set includes six types of tasks featuring multi-standard diagnosis and treatment (e.g., syndrome differentiation and treatment, one disease with different treatments, different diseases with the same treatment) and multi-factor reasoning (e.g., social environment, classical literature interpretation, and philosophical understanding). 41 | This set contains 883 Q&A pairs, with at least 50 samples per task to ensure stable evaluation. 42 | Data available at [./data/second_level](./data/second_level)l. 
43 | - 🚀 **Complex Clinical Decision-Making Tasks**:To further evaluate the comprehensive reasoning ability and stability of LLMs in multi-standard TCM clinical environments, we designed four types of complex decision-making tasks based on the dynamic clinical analysis tasks. By restructuring original reasoning samples, introducing semantic perturbations, and converting task formats (e.g., perturbations in syndrome differentiation and treatment, one disease with different treatments, different diseases with the same treatment, as well as Chinese medicine prescription tasks), we generated 1,009 new Q&A pairs.This task set systematically examines model consistency in high-complexity reasoning, robustness in inference, and stability in decision-making. 44 | Data available at [./data/third_level](./data/third_level). 45 | 46 | ## 👨‍⚕️ Data Processing 47 | 48 | ### Example: Fundamental Knowledge Cognition Tasks 49 | 50 | We converted the test questions into structured evaluation data. The data format is as follows: 51 | 52 | Fundamental Knowledge Understanding Task 53 | ```json 54 | { 55 | "question": "《素问·咳论》:“五脏六腑皆令人咳”,但关系最密切的是( )。\nA.心肺\nB.肺肾\nC.肺脾\nD.肺胃\nE.肺大肠", 56 | "answer": [ 57 | "D" 58 | ], 59 | "analysis": "根据《素问·咳论》“此皆聚于胃,关于肺,使人多涕唾而面浮肿气逆也”可知与五脏六腑皆令人咳关系最密切的脏腑为肺胃。手太阴肺经起于中焦,还循胃口,上膈属肺。寒凉饮食入胃,导致中焦寒,寒气循手太阴肺经上入于肺中,导致肺寒,肺为娇脏,不耐寒热,外内寒邪并聚于肺,则肺失宣降,肺气上逆发生咳嗽。因此答案选D。", 60 | "knowledge_point": "中医经典", 61 | "index": 8196, 62 | "score": 1 63 | } 64 | ``` 65 | Clinical Vertical Reasoning Task: 66 | ```json 67 | { 68 | "share_content": "刘×,男,46岁,刻下眩晕而见头重如蒙。胸闷恶心,食少多寐,苔白腻,脉濡滑。", 69 | "question": [ 70 | { 71 | "sub_question": "1).证属( )。\nA.肝阳上亢\nB.气血亏虚\nC.肾精不足\nD.痰浊中阻\nE.以上都不是\n", 72 | "answer": [ 73 | "D" 74 | ], 75 | "analysis": "" 76 | }, 77 | { 78 | "sub_question": "2).治法宜选( )。\nA.燥湿祛痰,健脾和胃\nB.补肾滋阴\nC.补肾助阳\nD.补养气血,健运脾胃\nE.平肝潜阳,滋养肝肾\n", 79 | "answer": [ 80 | "A" 81 | ], 82 | "analysis": "" 83 | }, 84 | { 85 | "sub_question": "3).方药宜选( )。\nA.右归丸\nB.左归丸\nC.半夏白术天麻汤\nD.归脾汤\nE.天麻钩藤饮\n", 86 | "answer": [ 87 | "C" 88 | ], 89 | "analysis": "" 90 | } 91 | ], 92 | "knowledge_point": "中医内科学", 93 | "index": 334, 94 | "score": 1 95 | } 96 | ``` 97 | Knowledge Horizontal Correlation Task: 98 | ```json 99 | { 100 | "share_content": "(共用备选答案)\nA.化痰息风,健脾祛湿\nB.清肺化痰,散结排脓\nC.疏风宣肺,化痰止咳\nD.清热化痰,平肝息风\nE.润肺清热,理气化痰\n", 101 | "question": [ 102 | { 103 | "sub_question": "1).贝母瓜蒌散的功用是( )。", 104 | "answer": [ 105 | "E" 106 | ], 107 | "analysis": "" 108 | }, 109 | { 110 | "sub_question": "2).半夏白术天麻汤的功用是( )。", 111 | "answer": [ 112 | "A" 113 | ], 114 | "analysis": "" 115 | } 116 | ], 117 | "knowledge_point": "方剂学", 118 | "index": 1938, 119 | "score": 1 120 | } 121 | ``` 122 | 123 | ## 🧐 Evaluation Details 124 | 125 | We designed task-adaptive prompts that require LLMs to answer questions and provide explanations. 
The evaluation framework consists of the following components: 126 | 127 | | File | Description | 128 | | -------------------------- | -------------- | 129 | | [./pipline/choice_bench.py](./pipline/choice_bench.py) | Set up different tasks and guide LLMs to generate answers and explanations | 130 | | [./pipline/bench_function.py](./pipline/bench_function.py) | Helper functions used by the tests | 131 | | [./pipline/Acc.py](./pipline/Acc.py) | Compute accuracy | 132 | | [./pipline/Model_API.py](./pipline/Model_API.py) | Call model APIs (OpenAI as an example); adjust for the model under evaluation | 133 | | [./TCMBench_code/explain_evaluation.py](./TCMBench_code/explain_evaluation.py) | Evaluate explanation quality using ROUGE-1, ROUGE-L, SARI, BERTScore, BARTScore, and our proposed SKScore | 134 | | [./HumanTrue.json](./HumanTrue.json) | HumanTrue dataset | 135 | 136 | 137 | First, run [./pipline/choice_bench.py](./pipline/choice_bench.py) to test a model and obtain its generated answers and explanations: 138 | ``` 139 | # (Optional) If needed, set your proxy: 140 | os.environ['HTTPS_PROXY']="your proxy" 141 | 142 | # If using a closed-source model, enter your API key; otherwise leave it blank: 143 | api_key = "your key" 144 | 145 | # Specify --data_path and --sys_prompt for different tasks, and --model_name to call different models. 146 | # Example: run the FKU task on gpt-4-0613 147 | python choice_bench.py --data_path ../data/first_level --sys_prompt FKU.json --model_name gpt-4-0613 148 | ``` 149 | 150 | Use [./pipline/Acc.py](./pipline/Acc.py) to compute accuracy scores for different models on different tasks by setting --data_path, --question_type, and --model_name: 151 | ``` 152 | python Acc.py --data_path ../data/first_level --question_type FKU --model_name gpt-4-0613 153 | ``` 154 | 155 | Use [./TCMBench_code/explain_evaluation.py](./TCMBench_code/explain_evaluation.py) to compute explanation scores across six metrics by specifying --model_name: 156 | ``` 157 | python explain_evaluation.py --model_name gpt-4-0613 158 | ``` 159 | The models used for these metrics can be found at [TCMBench_code](https://huggingface.co/WJing123/TCMBench_code). 160 | 161 | 162 | 👨‍⚕️ In addition, this work also introduces our previously built TCM LLM, ShenNong. We welcome everyone to follow our open-source TCM LLM project **ShenNong-TCM-LLM**: 163 | 164 | - 🚀 [ShenNong-TCM](https://github.com/ywjawmw/ShenNong-TCM-LLM) : To promote the development and real-world application of LLMs in Traditional Chinese Medicine, we released ShenNong, a large-scale TCM language model designed to improve knowledge coverage and the ability to answer medical consultations in TCM. It is built upon the [TCM instruction dataset SN-QA](https://huggingface.co/datasets/michaelwzhu/ShenNong_TCM_Dataset). 165 | 166 | We also introduce our other open-source healthcare LLM projects: 167 | - 🚀 [Intelligent TCM Inheritance and Innovation Platform](https://github.com/ywjawmw/AI4TCM-Platform) : Part of the Shuzhi Qihuang series, this platform addresses two main challenges: 1. Existing TCM inheritance platforms cannot cover comprehensive multimodal data, so we constructed a more complete knowledge graph integrating TCM and Western medicine. 2. TCM experience inheritance is inefficient, so we proposed interpretable prescription-analysis technology that mines prescription information, automatically analyzes the holistic diagnostic process from symptoms to herbs, and provides the scientific basis for the analysis.
It also provides a fair platform to help young doctors and TCM students quickly master advanced TCM knowledge and inherit clinical experience. 168 | - 🚀 [ChatMed-Consult](https://huggingface.co/michaelwzhu/ChatMed-Consult): Built from the [ChatMed_Consult_Dataset](https://huggingface.co/datasets/michaelwzhu/ChatMed_Consult_Dataset), which pairs over 500k online medical consultations with ChatGPT responses, used as the training set. The model backbone is [LLaMA-7b](https://github.com/facebookresearch/llama), merged with the LoRA weights and extended Chinese vocabulary of [Chinese-LlaMA-Alpaca](https://github.com/ymcui/Chinese-LLaMA-Alpaca), followed by parameter-efficient fine-tuning with LoRA. All code is publicly released. 169 | 170 | - 🚀 [PromptCBLUE](https://github.com/michael-wzhu/PromptCBLUE): A prompt-based adaptation of the [CBLUE](https://tianchi.aliyun.com/dataset/95414) benchmark, designed to evaluate the Chinese medical knowledge and medical-text-processing abilities of LLMs. PromptCBLUE enables a single generative LLM to handle a variety of medical NLP tasks such as medical record structuring, consultation, and clinical documentation writing. 171 | 172 | ## Acknowledgements 173 | 174 | This project was developed based on the APIs of large language models and inspired by evaluation work on Gaokao examination questions. We thank the related projects and developers for their contributions: 175 | 176 | - [ChatGPT](https://openai.com/blog/chatgpt) 177 | - [ChatGLM](https://github.com/THUDM/ChatGLM-6B) 178 | - [GaoKao-Bench](https://github.com/OpenLMLab/GAOKAO-Bench) 179 | 180 | 181 | ## Citation 182 | 183 | If you use the data or code from this project, please cite: 184 | 185 | ```bibtex 186 | @misc{yue2023TCMBench, 187 | title={TCMBench: Benchmarking Large Language Models in Traditional Chinese Medicine from Knowledge to Clinical Reasoning}, 188 | author={Wenjing Yue and Ming Guan and Wei Zhu and Xiaoling Wang and Honglin Li}, 189 | year={2023}, 190 | publisher = {GitHub}, 191 | journal = {GitHub repository}, 192 | howpublished = {\url{https://github.com/ywjawmw/TCMBench}}, 193 | } 194 | 195 | ``` 196 | 197 | ## Team 198 | 199 | This project was completed under the guidance of **Prof. Xiaoling Wang** from the School of Computer Science and Technology and **Prof. Honglin Li**, Dean of the School of Pharmacy and Director of the Innovation Center for AI and Drug Discovery, East China Normal University. 
200 | 201 | 202 | 203 | 204 | 205 | -------------------------------------------------------------------------------- /TCMBench_code/explain_evaluation.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # !/usr/bin/python3 3 | # -*- coding: utf-8 -*- 4 | # @Time : 2025/9/28 23:10 5 | # @Author : Wenjing 6 | # @File : explain_evaluation.py 7 | # @Desc : 评估LLMs生成解析的质量 8 | 9 | import argparse 10 | import json 11 | import sys 12 | import json 13 | import math 14 | import os 15 | import time 16 | from typing import List, Dict 17 | from sari import SARIsent 18 | import numpy as np 19 | from BARTScore.bart_score_chinese import BARTScorerChinese 20 | # from BARTScore.bart_score import BARTScorer 21 | import evaluate, torch 22 | import statistics 23 | from rouge_chinese import Rouge 24 | import jieba 25 | from transformers import ( 26 | AutoConfig, 27 | AutoModelForSequenceClassification, 28 | AutoTokenizer 29 | ) 30 | from sklearn.metrics import roc_auc_score 31 | from sklearn import metrics 32 | from matplotlib import pyplot as plt 33 | 34 | jieba.load_userdict("jieba_tcm.txt") # load进入TCM术语 35 | with open('jieba_tcm.txt', 'r', encoding='utf-8') as f: 36 | tcm_terms = f.readlines() 37 | tcm_term_db = [t.strip() for t in tcm_terms] 38 | 39 | def parse_args(): 40 | parser = argparse.ArgumentParser(description="parameter of LLMs") 41 | parser.add_argument( 42 | "--model_name", 43 | type=str, 44 | default='gpt-4-0613', 45 | help="The LLM name.", 46 | ) 47 | args = parser.parse_args() 48 | return args 49 | 50 | 51 | def get_analysis_A12(data_dict_predict): 52 | correct_analysis, predict_analysis = [], [] 53 | for data in data_dict_predict['example']: 54 | if len(data['analysis']) > 0: 55 | correct_analysis.append(data['analysis']) 56 | predict_analysis.append(data['model_output']) 57 | return correct_analysis, predict_analysis 58 | 59 | 60 | def get_analysis_A34(data_dict_predict): 61 | correct_analysis, predict_analysis = [], [] 62 | for data in data_dict_predict['example']: 63 | question = data["question"] 64 | for sub_question in question: 65 | if len(sub_question['analysis']) > 0: 66 | correct_analysis.append(sub_question['analysis']) 67 | predict_analysis.append(sub_question['model_output']) 68 | return correct_analysis, predict_analysis 69 | 70 | 71 | def rouge_score(correct_analysis, predict_analysis): 72 | # rouge = evaluate.load('rouge') 73 | # results = rouge.compute(predictions=correct_analysis, 74 | # references=predict_analysis, 75 | # rouge_types=['rouge1', 'rougeL'], 76 | # use_aggregator=False) 77 | rouge = Rouge() 78 | empty_num = 0 79 | correct_analysis_cal, predict_analysis_cal = [], [] 80 | for ca, pa in zip(correct_analysis, predict_analysis): 81 | if len(pa) == 0: 82 | empty_num += 1 83 | else: 84 | correct_analysis_cal.append(' '.join(jieba.cut(ca))) 85 | predict_analysis_cal.append(' '.join(jieba.cut(pa))) 86 | scores = rouge.get_scores(predict_analysis_cal, correct_analysis_cal) 87 | results_rouge1 = [round(score['rouge-1']['r'],2) for score in scores] + [0.0] * empty_num 88 | results_rougel = [round(score['rouge-l']['f'], 2) for score in scores] + [0.0] * empty_num 89 | return np.mean(results_rouge1), np.mean(results_rougel) 90 | 91 | def bert_score(correct_analysis, predict_analysis): 92 | bertscore = evaluate.load('../TCMBench_code/bertscore') 93 | results = bertscore.compute(predictions=predict_analysis, references=correct_analysis, lang="zh", 94 | model_type="bert-base-chinese") 95 | score = [round(v, 2) for v 
in results["f1"]] 96 | return np.mean(score) 97 | 98 | def bart_Score(correct_analysis, predict_analysis): 99 | bart_scorer = BARTScorerChinese(checkpoint='bart-base-chinese') 100 | bart_scores = bart_scorer.score(correct_analysis, predict_analysis, batch_size=4) 101 | # print("BART Score", np.mean(bart_scores)) 102 | return np.mean(bart_scores) 103 | 104 | def calculate_sari( 105 | input_lns: List[str], output_lns: List[str], reference_lns: List[str] 106 | ) -> Dict: 107 | a, b, c, d = [], [], [], [] 108 | for input, output, ref in zip(input_lns, output_lns, reference_lns): 109 | a_, b_, c_, d_ = SARIsent(input, output, [ref]) 110 | 111 | a.append(round(a_,2)) 112 | return a 113 | 114 | 115 | def sari_score(correct_analysis, predict_analysis): 116 | sariscores = calculate_sari(correct_analysis, predict_analysis, correct_analysis) # 参考答案根据reference 只看sari_avgkeepscore 117 | # print(f"SARI score_解析: {sariscores}") 118 | return np.mean(sariscores) 119 | 120 | print("使用 x / sum_x,缩放f1的值") 121 | def softmax(x): 122 | e_x = x 123 | sum_x = e_x.sum(axis=0) 124 | if sum_x == 2 * x.size: # 标准解析中就没有中医术语,那么解析中的每一句话的f1score都=2 125 | return [1/x.size] * x.size # 加和平均 126 | elif sum_x == 0: 127 | return [0] * x.size # 也就是LLMs中没有中医术语,那么这个解析是不好的,给一个特别低的分数 128 | else: 129 | return x / sum_x 130 | 131 | import re 132 | def split_sentences(text): 133 | # 利用正则表达式按照句号、感叹号、问号进行划分 134 | sentences = re.split(r'(?<=[。!?])\s*', text) 135 | # 去除空字符串和空白符 136 | sentences = [s.strip() for s in sentences if s.strip()] 137 | return sentences 138 | 139 | from collections import Counter 140 | def tcm_score_f1(analysis_tcm_terms_counter, analysis_terms_counter, doc, tcm_term_db): 141 | """ 142 | 中医术语匹配度 143 | :param analysis_tcm_terms_counter: 解析中的中医术语以及计数 144 | :param analysis_terms_counter: 解析 145 | :param doc: 需要检测的语句 146 | :param tcm_term_db: 中医术语库 147 | :return: 148 | """ 149 | 150 | doc_terms_list = list(jieba.cut(doc)) 151 | doc_terms_counter = Counter(doc_terms_list) 152 | doc_tcm_terms_list = [term for term in doc_terms_list if term in tcm_term_db] 153 | doc_tcm_terms_counter = Counter(doc_tcm_terms_list) # 片段中所有的中医术语计数 154 | if len(analysis_tcm_terms_counter) == 0: 155 | return 2 # 如果解析中没有中医术语,那么就不需要进行F1score的计算 156 | elif len(doc_tcm_terms_counter) == 0: 157 | return 0 # # 如果LLMs中没有中医术语,那么F1score=0, 说明这句话不符合中医诊疗语言,或者没有什么信息量 158 | comment_term_counter = analysis_tcm_terms_counter & doc_tcm_terms_counter # 重复的中医术语 159 | recall_comment_score, precision_comment_score = 0, 0 160 | for term in comment_term_counter: 161 | recall_comment_score += comment_term_counter[term] / analysis_tcm_terms_counter[term] 162 | precision_comment_score += comment_term_counter[term] / doc_tcm_terms_counter[term] 163 | recall = recall_comment_score / len(analysis_tcm_terms_counter) # 重复的中医术语个数/解析中的中医术语个数 —— 重叠度 164 | precision = precision_comment_score / len(doc_tcm_terms_counter) # 重复的中医术语个数/ 文档的中医术语个数 —— 冗余度 165 | informational = len(list(set(doc_tcm_terms_list))) / len(list(set(doc_terms_list))) 166 | # informational = len(doc_tcm_terms_counter) / len(doc_terms_counter) * (sum(doc_terms_counter.values()) / sum(analysis_terms_counter.values())) # 片段中的中医术语 / 片段中术语个数(不重复) * (片段的术语总个数 / 解析的术语总个数)[长度的惩罚项] —— 信息度 167 | if precision == 0 or recall == 0: 168 | return 0 169 | else: 170 | f1_score = 3 * (precision * recall * informational) / (precision + recall + informational) 171 | return f1_score 172 | 173 | def calculate_score(sentence, sample, model, tokenizer): 174 | inputs = tokenizer.batch_encode_plus( 175 | 
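        # encode the two sentences as a text pair for the sequence-classification
        # model that scores how closely they are related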
batch_text_or_text_pairs=[(sentence, sample)], 176 | add_special_tokens=True, padding="longest", 177 | truncation=True, return_tensors="pt", 178 | return_token_type_ids=True, return_attention_mask=True, 179 | max_length=512 180 | ) 181 | 182 | logits = model(**inputs).logits # neutral is already removed 183 | # print(logits) 184 | probs = torch.argmax(logits, dim=-1) 185 | prob_label = probs[0].item() # 类别 186 | probs1 = torch.softmax(logits, dim=-1) 187 | prob_1 = probs1[0][0].item() # prob(相关程度) 188 | return prob_label, prob_1 189 | 190 | def f1_score_tcm_term(analysis_tcm_terms_list, analysis_terms_list, doc, tcm_term_db): 191 | """ 192 | 中医术语匹配度 193 | :param analysis_tcm_terms_list: 解析中的中医术语 194 | :param analysis_terms_list: 解析 195 | :param doc: 需要检测的语句 196 | :param tcm_term_db: 中医术语库 197 | :return: 198 | """ 199 | 200 | doc_terms_list = list(jieba.cut(doc)) 201 | doc_tcm_terms_list = [term for term in doc_terms_list if term in tcm_term_db] 202 | doc_tcm_terms_list = doc_tcm_terms_list # 片段中所有的中医术语(含有重复元素) 203 | if len(analysis_tcm_terms_list) == 0: 204 | return 2 # 如果解析中没有中医术语,那么就不需要进行F1score的计算 205 | elif len(doc_tcm_terms_list) == 0: 206 | return 0 # # 如果LLMs中没有中医术语,那么F1score=0, 说明这句话不符合中医诊疗语言,或者没有什么信息量 207 | comment_term = set(analysis_tcm_terms_list) & set(doc_terms_list) # 重复的中医术语 208 | comment_num = 0 209 | for doc_term in doc_tcm_terms_list: 210 | if doc_term in comment_term: 211 | comment_num += 1 212 | recall = comment_num / len(analysis_tcm_terms_list) # 重复的中医术语个数/真正对的中医术语个数 —— 重叠度 213 | precision = comment_num / len(doc_tcm_terms_list) # 重复的中医术语个数/ 文档的中医术语个数 —— 冗余度 214 | informational = len(list(set(doc_tcm_terms_list))) / len(list(set(doc_terms_list))) 215 | if precision == 0 or recall == 0 or informational == 0: 216 | return 0 217 | else: 218 | f1_score = 3 * (precision * recall * informational) / (precision + recall + informational) 219 | return f1_score 220 | 221 | 222 | def predict_analysys_response(sentences: List[str], sampled_passages: List[str], model, tokenizer): 223 | """ 224 | :param sentences: list of 标准解析 225 | :param sampled_passages: LLMs的解析 226 | """ 227 | scores1 = list() # 计算LLMs生成的解析与标准解析之间的分数 228 | scores1_counter = list() # _counter 229 | scores1_nof1 = list() 230 | scores1_dotf1 = list() 231 | num = 0 232 | for sentence, sample in zip(sentences, sampled_passages): # 解析 233 | if num == 0: 234 | print(f"sentence: {sentence}") 235 | print(f"sample: {sample}") 236 | num += 1 237 | # 分句 238 | response_sentence_list = split_sentences(sample) 239 | analysis_sentence_list = split_sentences(sentence) 240 | tcm_score = [] 241 | tcm_score_counter = [] 242 | prob_score_list = [] 243 | tcm_score_dotf1 = [] 244 | for analysis_sentence in analysis_sentence_list: # 分句 245 | f1_score, prob_score = [], [] 246 | f1_score_counter = [] 247 | if len(response_sentence_list) > 0: 248 | for response_sentence in response_sentence_list: # LLMs分析分句 249 | # 统计标准解析中的中医术语 250 | analysis_terms = list(jieba.cut(analysis_sentence)) 251 | analysis_terms_counter = Counter(analysis_terms) 252 | analysis_tcm_terms_list = [term for term in analysis_terms if term in tcm_term_db] 253 | analysis_tcm_terms_counter = Counter(analysis_tcm_terms_list) # 解析中的中医术语计数 254 | analysis_tcm_terms_list = list(analysis_tcm_terms_counter.keys()) # 解析中的中医术语列表 255 | 256 | prob_label_A, prob_1_A = calculate_score(analysis_sentence, response_sentence, model, tokenizer) 257 | prob_label_reverse_A, prob_1_reserve_A = calculate_score(response_sentence, analysis_sentence, 258 | model, tokenizer) 259 | prob = 
(prob_1_A + prob_1_reserve_A) / 2 260 | ################## TCM score #################### 261 | f1_term_score = f1_score_tcm_term(analysis_tcm_terms_list, analysis_terms, response_sentence, tcm_term_db) 262 | f1_term_score_counter = tcm_score_f1(analysis_tcm_terms_counter, analysis_terms_counter, response_sentence, tcm_term_db) 263 | f1_score.append(f1_term_score) 264 | prob_score.append(prob) 265 | f1_score_counter.append(f1_term_score_counter) 266 | f1_score = np.array(f1_score) 267 | 268 | ####不加f1 score####### 269 | average_prob_score = statistics.mean(prob_score) 270 | prob_score_list.append(average_prob_score) 271 | 272 | prob_score = np.array(prob_score) 273 | f1_score_counter = np.array(f1_score_counter) 274 | # 对列表进行归一化 275 | try: 276 | normalized_f1_score = softmax(f1_score) 277 | normalized_f1_score_counter = softmax(f1_score_counter) 278 | except: 279 | print(response_sentence_list) 280 | print(f1_score) 281 | sys.exit() 282 | 283 | #####直接与F1 score相乘##### 284 | analysis_sentence_dotf1 = np.sum(f1_score * prob_score) / prob_score.size 285 | tcm_score_dotf1.append(analysis_sentence_dotf1) 286 | ###### 计算相乘并相加的结果,加权平均 287 | analysis_sentence_score = np.sum(normalized_f1_score * prob_score) 288 | tcm_score.append(analysis_sentence_score) 289 | # 用Counter计算的score 290 | analysis_sentence_score_counter = np.sum(normalized_f1_score_counter * prob_score) 291 | tcm_score_counter.append(analysis_sentence_score_counter) 292 | else: 293 | tcm_score.append(0) 294 | tcm_score_counter.append(0) 295 | tcm_score_dotf1.append(0) 296 | prob_score_list.append(0) 297 | if len(sample) < len(sentence): 298 | if len(sample) > 1: 299 | length_penalty = math.exp(1 - math.log(len(sentence)) / math.log(len(sample))) 300 | elif len(sample) == 1: 301 | length_penalty = math.exp(1 - math.log(len(sentence)) / math.log(len(sample) + 1)) 302 | else: 303 | length_penalty = 0 304 | else: 305 | # length_penalty = 1 306 | length_penalty = math.exp(1 - math.log(len(sample)) / math.log(len(sentence))) 307 | scores_per_response = statistics.mean(tcm_score) * length_penalty 308 | scores1.append(scores_per_response) 309 | # 用Counter计算的score 310 | scores_per_response_counter = statistics.mean(tcm_score_counter) * length_penalty 311 | scores1_counter.append(scores_per_response_counter) 312 | ############不加f1 score 313 | scores1_nof1.append(statistics.mean(prob_score_list)) 314 | #####直接与F1 score相乘##### 315 | scores_per_response_dotf1 = statistics.mean(tcm_score_dotf1) 316 | scores1_dotf1.append(scores_per_response_dotf1) 317 | scores_per_doc = statistics.mean(scores1) 318 | scores_per_doc_counter = scores1_counter 319 | print("解析与回答之间的TCM Score,用Counter计算:", scores_per_doc_counter) 320 | return scores_per_doc_counter 321 | 322 | def nli_score(references, predictions, model, tokenizer): 323 | n_score = predict_analysys_response(sentences=references, sampled_passages=predictions, model=model, tokenizer=tokenizer) 324 | return n_score 325 | 326 | 327 | if __name__ == "__main__": 328 | args = parse_args() 329 | question_type = "FKU" 330 | with open(f"../data/first_level/{args.model_name}_{question_type}.json", "r", encoding="utf-8") as f: 331 | data_FKU = json.load(f) 332 | f.close() 333 | 334 | question_type = "CVR" 335 | with open(f"../data/first_level/{args.model_name}_{question_type}.json", "r", encoding="utf-8") as f: 336 | data_CVR = json.load(f) 337 | f.close() 338 | 339 | question_type = "KHC" 340 | with open(f"../data/first_level/{args.model_name}_{question_type}.json", "r", encoding="utf-8") as f: 341 | data_KHC = 
json.load(f) 342 | f.close() 343 | 344 | correct_analysis, predict_analysis = [], [] 345 | correct_analysis_A12, predict_analysis_A12 = get_analysis_A12(data_FKU) 346 | correct_analysis_A3, predict_analysis_A3 = get_analysis_A34(data_CVR) 347 | correct_analysis_B1, predict_analysis_B1 = get_analysis_A34(data_KHC) 348 | correct_analysis += correct_analysis_A12 + correct_analysis_A3 + correct_analysis_B1 349 | predict_analysis += predict_analysis_A12 + predict_analysis_A3 + predict_analysis_B1 350 | 351 | rouge1, rouge_L = rouge_score(correct_analysis, predict_analysis) 352 | print("ROUGE-1:", rouge1) 353 | print("ROUGE-L:", rouge_L) 354 | sari_scores = sari_score(correct_analysis, predict_analysis) 355 | print("SARI:", sari_scores) 356 | bert_scores = bert_score(correct_analysis, predict_analysis) 357 | print("BERTScore:", bert_scores) 358 | bart_scores = bart_Score(correct_analysis, predict_analysis) 359 | print("BARTScore:", bart_scores) 360 | 361 | # SKScore 362 | model_name_or_path = "../TCMBench_code/Deberta-V3-base-tmnli-QAC" 363 | print(model_name_or_path) 364 | config = AutoConfig.from_pretrained( 365 | model_name_or_path, 366 | num_labels=3, 367 | finetuning_task="mnli", 368 | trust_remote_code=False 369 | ) 370 | tokenizer = AutoTokenizer.from_pretrained( 371 | model_name_or_path, use_fast=not False, trust_remote_code=False 372 | ) 373 | # print(tokenizer) 374 | model = AutoModelForSequenceClassification.from_pretrained( 375 | f"{model_name_or_path}", 376 | from_tf=bool(".ckpt" in model_name_or_path), 377 | config=config, 378 | ignore_mismatched_sizes=False, 379 | ) 380 | sk_score = nli_score(correct_analysis, predict_analysis, model, tokenizer) 381 | SKScore = np.mean(sk_score) 382 | print("SKScore:", SKScore) 383 | 384 | 385 | -------------------------------------------------------------------------------- /pipline/bench_function.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | import json 4 | 5 | import re 6 | 7 | from typing import List 8 | 9 | from tqdm import tqdm 10 | from collections import Counter 11 | 12 | # args = parse_args() 13 | 14 | 15 | def get_api_key(filename: str, start_num: int, end_num: int) -> List[str]: 16 | """ 17 | Retrieves API keys from a file. 18 | 19 | :param filename: Name of the file containing API keys 20 | :param start_num: Starting line number for reading the file 21 | :param end_num: Ending line number for reading the file 22 | :return: List of API keys 23 | """ 24 | with open(filename, 'r') as file: 25 | lines = file.readlines() 26 | 27 | pattern = re.compile(r'sk-[\s\S]*?(?=\s*\n)') 28 | api_key_list = [] 29 | 30 | for i in range(start_num, end_num): 31 | api_key = pattern.findall(lines[i]) 32 | if len(api_key) != 0: 33 | api_key_list.append(api_key[0]) 34 | 35 | return api_key_list 36 | 37 | 38 | def extract_choice_answer(model_output, question_type, answer_lenth=None): 39 | """ 40 | Extract choice answer from model output 41 | 42 | Format of model_output that is expected: 43 | 'single_choice': choice answer should be the last Capital Letter of the model_output, e.g.: "...【答案】 A " 44 | 'multi_question_choice': "...【答案】A ... 【答案】C ..." or write the choice answers at the beginning of the model_output, e.g. "A C D E F...." 45 | 'multi_choice': "...【答案】 ABD " or write the choice answers at the end of the model_output, e.g. "... ACD" 46 | 'five_out_of_seven': choice answers should be the first five Capital Letters of the model_output, e.g. "A C D F B ...." 
47 | """ 48 | model_answer = [] 49 | model_output = model_output[::-1] 50 | pattern = r"([A-Z]).*?案答" 51 | check_info = re.search(pattern, model_output) 52 | if check_info: 53 | pattern = r"\.[A-Z]" 54 | temp = re.findall(pattern, model_output) 55 | if len(temp) > 0: 56 | # answer = temp[0] 57 | answer = check_info.group(1) 58 | model_answer.append(answer) 59 | else: 60 | temp = re.findall(r'[A-E]', model_output) 61 | if len(temp) != 0: 62 | answer = temp[0] 63 | model_answer.append(answer) 64 | else: 65 | temp = re.findall(r'[A-E]', model_output) 66 | if len(temp) != 0: 67 | answer = temp[0] 68 | model_answer.append(answer) 69 | 70 | return model_answer 71 | 72 | def extract_choice_answer_hard(model_output, question_type, answer_lenth=None): 73 | model_answer = [] 74 | # DDST 75 | if question_type == "DDST_hard": 76 | model_answer = [] 77 | model_output = model_output[0].replace("[", "").replace("]", "").replace("", "").replace("", 78 | "").replace(" ", 79 | "") 80 | model_output = ''.join([char for char in model_output if char.isdigit()]) 81 | temp = re.findall(r'[0-9]', model_output) 82 | if len(temp) != 0: 83 | model_answer.append(model_output) 84 | elif question_type == "herb_predict": 85 | model_answer = [] 86 | model_output = model_output[0].replace("[", "").replace("]", "").replace("", "").replace("", 87 | "").replace(" ", 88 | "") 89 | temp = model_output.split(",") 90 | if len(temp) != 0: 91 | if len(temp) < 20: 92 | # 在temp补若干个-1,让其长度为20 93 | for i in range(20 - len(temp)): 94 | temp.append("-1") 95 | if len(temp) > 20: 96 | temp = temp[:20] 97 | model_answer.append(temp) 98 | return model_answer 99 | 100 | def A3_second_check(model_output): 101 | pattern = r"【答案】" 102 | temp = re.findall(pattern, model_output) 103 | check_answer = list() 104 | if len(temp) > 0: 105 | model_output = model_output.replace("\n", "") 106 | pattern = r"答案】.*?([A-Z])" 107 | temp = re.findall(pattern, model_output) 108 | if len(temp) > 0: 109 | answer = temp[-1] 110 | check_answer.append(answer) 111 | else: 112 | check_answer = [] 113 | else: 114 | pattern = r"答案.*?([A-Z])" 115 | check_info = re.findall(pattern, model_output) 116 | if check_info: 117 | answer = check_info[-1] 118 | check_answer.append(answer) 119 | else: 120 | model_output = model_output[::-1] 121 | temp = re.findall(r'[A-E]', model_output) 122 | if len(temp) != 0: 123 | answer = temp[0] 124 | check_answer.append(answer) 125 | return check_answer 126 | 127 | def pattern_second_check(ans_list, model_output): 128 | check_answer = list() 129 | ans_list = [ 130 | ans.replace("A", "").replace("B", "").replace("C", "").replace("D", "").replace("E", "").replace( 131 | ".", "").replace(".", "") for ans in ans_list if len(ans) > 0] 132 | ans_id = -1 133 | candidate_ans = {} 134 | for ans in ans_list: 135 | a_list = re.split(r'[,。;\s]', ans) 136 | max_count = 0 137 | for a in a_list: 138 | if a in model_output: 139 | ans_id = ans_list.index(ans) 140 | c = model_output.count(a) 141 | max_count += c 142 | if max_count > 0: 143 | candidate_ans[ans] = max_count 144 | if len(candidate_ans) > 1: 145 | # 有多个选项 146 | # 如果多个选项中,有的选项出现的频率超过了其它选项,那么认为这个选项是正确答案 147 | max_value = max(candidate_ans.values()) 148 | value_clist = Counter(candidate_ans.values()) 149 | if value_clist[max_value] == 1: 150 | unique_max_key = [key for key, value in candidate_ans.items() if value == max_value][0] 151 | ans_id = ans_list.index(unique_max_key) 152 | check_answer.append(chr(ans_id + 65)) 153 | else: 154 | print(candidate_ans) 155 | elif ans_id >= 0: 156 | 
check_answer.append(chr(ans_id + 65)) 157 | return check_answer 158 | 159 | def herb_second_check(model_output): 160 | no_answer = 0 161 | no_standard = 0 162 | if "-1" in model_output: 163 | # 统计有多少个-1 164 | count = model_output.count("-1") 165 | if 19 <= count <= 20: 166 | no_answer += 1 167 | else: 168 | no_standard += 1 169 | res_list = [] 170 | if type(model_output[0]) is list: 171 | model_output = model_output[0] 172 | for res in model_output: 173 | if "、" in res: 174 | res_list += res.split("、") 175 | if len(res_list) > 5: 176 | new_res_list = [] 177 | for res in model_output: 178 | if res != "-1" and "、" not in res: 179 | new_res_list += res 180 | else: 181 | if res == "-1": 182 | continue 183 | new_res_list.extend(res_list) 184 | if len(new_res_list) < 20: 185 | # 在temp补若干个-1,让其长度为20 186 | for i in range(20 - len(new_res_list)): 187 | new_res_list.append("-1") 188 | 189 | model_output = new_res_list 190 | if len(set(model_output)) != 20: 191 | # 找到重复的非-1的元素 192 | repeat_list = [] 193 | for i, res in enumerate(model_output): 194 | if res != "-1": 195 | if res not in repeat_list: 196 | repeat_list.append(res) 197 | else: 198 | model_output[i] = "-1" 199 | return model_output 200 | 201 | 202 | def choice_test_A12(**kwargs): 203 | model_api = kwargs['model_api'] 204 | model_name = kwargs['model_name'] 205 | start_num = kwargs['start_num'] 206 | end_num = kwargs['end_num'] 207 | data = kwargs['data']['example'] 208 | keyword = kwargs['keyword'] 209 | prompt = kwargs['prompt'] 210 | question_type = kwargs['question_type'] 211 | save_directory = kwargs['save_directory'] 212 | args = kwargs['args'] 213 | 214 | model_answer_dict = [] 215 | for i in range(start_num, end_num): 216 | index = data[i]['index'] 217 | question = data[i]['question'].strip() + '\n' 218 | score = 1 219 | standard_answer = data[i]['answer'] 220 | try: 221 | analysis = data[i]['analysis'] 222 | except: 223 | analysis = '' 224 | try: 225 | knowledge_point = data[i]['knowledge_point'] 226 | except: 227 | knowledge_point = '' 228 | model_output = model_api.send_request_turbo(prompt, question)[0] 229 | model_answer = extract_choice_answer(model_output, question_type, 5) 230 | if len(model_answer) == 0: 231 | ans_list = question.split("\n")[1:] 232 | model_answer = pattern_second_check(ans_list, model_output) 233 | # TODO: which content of temp we expect 234 | dict = { 235 | 'index': index, 236 | # 'year': year, 237 | # 'category': category, 238 | 'score': score, 239 | 'question': question, 240 | 'standard_answer': standard_answer, 241 | 'analysis': analysis, 242 | 'knowledge_point': knowledge_point, 243 | 'model_answer': model_answer, 244 | 'model_output': model_output 245 | } 246 | for key, value in dict.items(): 247 | print(key, ":", value) 248 | # print(dict) 249 | print("*" * 100, "index-", dict["index"], "*" * 100) 250 | model_answer_dict.append(dict) 251 | 252 | file_name = f"seperate_{start_num}-{end_num}.json" 253 | file_path = os.path.join(save_directory, file_name) 254 | with open(file_path, 'w', encoding='utf-8') as f: 255 | output = { 256 | 'keyword': keyword, 257 | 'example': model_answer_dict 258 | } 259 | json.dump(output, f, ensure_ascii=False, indent=4) 260 | f.close() 261 | 262 | 263 | def choice_test_A34(**kwargs): 264 | model_api = kwargs['model_api'] 265 | model_name = kwargs['model_name'] 266 | start_num = kwargs['start_num'] 267 | end_num = kwargs['end_num'] 268 | data = kwargs['data']['example'] 269 | keyword = kwargs['keyword'] 270 | prompt = kwargs['prompt'] 271 | question_type = 
kwargs['question_type'] 272 | save_directory = kwargs['save_directory'] 273 | args = kwargs['args'] 274 | 275 | model_answer_dict = [] 276 | for i in range(start_num, end_num): 277 | index = data[i]['index'] 278 | question = data[i]['question'] # list of sub-questions and their answers 279 | score = 1 280 | try: 281 | knowledge_point = data[i]['knowledge_point'] 282 | except: 283 | knowledge_point = '' 284 | share_content = data[i]['share_content'] 285 | model_output = model_api.send_request_chat(prompt, share_content, question) 286 | 287 | question_list = [] 288 | for sub_question, output in zip(question, model_output): 289 | standard_answer = sub_question['answer'] 290 | try: 291 | analysis = sub_question['analysis'] 292 | except: 293 | analysis = '' 294 | model_answer = extract_choice_answer(output, question_type, 5) 295 | if question_type in ["CVR", "SDT", 'SDT_reverse', "SDT_shuffle"]: 296 | model_answer = A3_second_check(output) 297 | if len(model_answer) == 0: 298 | if question_type in ["CVR", "SDT", 'SDT_reverse', "SDT_shuffle"]: 299 | ans_list = sub_question['sub_question'].split("\n")[1:] 300 | else: 301 | ans_list = share_content.split("\n")[1:] 302 | model_answer = pattern_second_check(ans_list, output) 303 | sub_question_dict = { 304 | 'sub_question': sub_question['sub_question'], 305 | 'standard_answer': standard_answer, 306 | 'analysis': analysis, 307 | 'model_answer': model_answer, 308 | 'model_output': output 309 | } 310 | question_list.append(sub_question_dict) 311 | # TODO: which content of temp we expect 312 | 313 | dict = { 314 | 'index': index, 315 | 'score': score, 316 | 'share_content': share_content, 317 | 'question': question_list, 318 | 'knowledge_point': knowledge_point, 319 | } 320 | model_answer_dict.append(dict) 321 | 322 | file_name = f"seperate_{start_num}-{end_num}.json" 323 | file_path = os.path.join(save_directory, file_name) 324 | with open(file_path, 'w', encoding='utf-8') as f: 325 | output = { 326 | 'keyword': keyword, 327 | 'example': model_answer_dict 328 | } 329 | json.dump(output, f, ensure_ascii=False, indent=4) 330 | f.close() 331 | 332 | def choice_test_DDST_hard(**kwargs): 333 | model_api = kwargs['model_api'] 334 | model_name = kwargs['model_name'] 335 | start_num = kwargs['start_num'] 336 | end_num = kwargs['end_num'] 337 | data = kwargs['data']['example'] 338 | keyword = kwargs['keyword'] 339 | prompt = kwargs['prompt'] 340 | save_directory = kwargs['save_directory'] 341 | 342 | model_answer_dict = [] 343 | for i in range(start_num, end_num): 344 | question = data[i]['question'] 345 | option = data[i]['option'] 346 | standard_answer = data[i]['answer'] 347 | model_output = model_api.send_request_hard(prompt, question, option, int(standard_answer)) 348 | model_answer = extract_choice_answer_hard(model_output, keyword, 0) 349 | # TODO: which content of temp we expect 350 | dict = { 351 | 'question': question, 352 | 'option': option, 353 | 'standard_answer': [standard_answer], 354 | 'model_answer': model_answer 355 | } 356 | # print("*" * 100, "index-", dict["index"], "*" * 100) 357 | for key, value in dict.items(): 358 | print(key, ":", value) 359 | # print(dict) 360 | model_answer_dict.append(dict) 361 | 362 | file_name = f"seperate_{start_num}-{end_num}.json" 363 | file_path = os.path.join(save_directory, file_name) 364 | with open(file_path, 'w', encoding='utf-8') as f: 365 | output = { 366 | 'keyword': keyword, 367 | 'example': model_answer_dict 368 | } 369 | json.dump(output, f, ensure_ascii=False, indent=4) 370 | f.close() 371 | 372 | def choice_test_TCM_Rec(**kwargs): 373
| model_api = kwargs['model_api'] 374 | model_name = kwargs['model_name'] 375 | start_num = kwargs['start_num'] 376 | end_num = kwargs['end_num'] 377 | data = kwargs['data']['example'] 378 | keyword = kwargs['keyword'] 379 | prompt = kwargs['prompt'] 380 | save_directory = kwargs['save_directory'] 381 | question_type = kwargs['question_type'] 382 | 383 | model_answer_dict = [] 384 | for i in range(start_num, end_num): 385 | question = "、".join(data[i]['sym_set']) 386 | # option = data[i]['option'] 387 | standard_answer = data[i]['herb_set'] 388 | model_output = model_api.send_request_TCM_Rec(prompt, question) 389 | model_answer = extract_choice_answer_hard(model_output, keyword, 0)[0] 390 | model_answer = herb_second_check(model_answer) 391 | # TODO: which content of temp we expect 392 | dict = { 393 | 'sym_set': data[i]['sym_set'], 394 | 'model_output': model_answer, 395 | 'herb_set': standard_answer, 396 | } 397 | # print("*" * 100, "index-", dict["index"], "*" * 100) 398 | for key, value in dict.items(): 399 | print(key, ":", value) 400 | # print(dict) 401 | model_answer_dict.append(dict) 402 | 403 | file_name = f"seperate_{start_num}-{end_num}.json" 404 | file_path = os.path.join(save_directory, file_name) 405 | with open(file_path, 'w', encoding='utf-8') as f: 406 | output = { 407 | 'keyword': keyword, 408 | 'example': model_answer_dict 409 | } 410 | json.dump(output, f, ensure_ascii=False, indent=4) 411 | f.close() 412 | 413 | 414 | def export_union_json(directory: str, model_name: str, zero_shot_prompt_text: str or list[str], question_type: str) -> None: 415 | """ 416 | Merges JSON files containing processed examples in a directory into a single JSON file. 417 | 418 | :param directory: Directory containing the JSON files 419 | :param model_name: Name of the model used to process the examples 420 | # :param keyword: Keyword used to identify the JSON files 421 | :param zero_shot_prompt_text: Prompt text for zero-shot learning 422 | :param question_type: Type of questions in the JSON files (e.g. single_choice, five_out_of_seven, etc.) 423 | """ 424 | 425 | save_directory = os.path.join(directory, f'{model_name}_{question_type}') 426 | if os.path.exists(save_directory): 427 | output = { 428 | 'keywords': question_type, 429 | 'model_name': model_name, 430 | 'prompt': zero_shot_prompt_text, 431 | 'example': [] 432 | } 433 | 434 | # Iterate through the JSON files with the specified keyword in the directory 435 | 436 | print("Start to merge json files") 437 | files = [file for file in os.listdir(save_directory) if file.endswith('.json')] 438 | for file in files: 439 | file_path = os.path.join(save_directory, file) 440 | 441 | # Load and merge the data from the JSON files 442 | with open(file_path, "r", encoding='utf-8') as f: 443 | data = json.load(f) 444 | output['example'] += (data['example']) 445 | # Save the merged data into a single JSON file 446 | merge_file = os.path.join(directory, f'{model_name}_{question_type}_predictions.json') 447 | output['example'] = sorted(output['example'], key=lambda x: x['index']) 448 | with open(merge_file, 'w', encoding='utf-8') as f: 449 | json.dump(output, f, ensure_ascii=False, indent=4) 450 | 451 | def export_distribute_json( 452 | model_api, 453 | model_name: str, 454 | directory: str, 455 | # keyword: str, 456 | zero_shot_prompt_text: str or List[str], 457 | question_type: str, 458 | args, 459 | parallel_num: int = 1 460 | ) -> None: 461 | """ 462 | Distributes the task of processing examples in a JSON file across multiple processes. 
463 | 464 | :param model_api: API wrapper object used to query the model 465 | :param model_name: Name of the model to use 466 | :param directory: Directory containing the JSON file 467 | :param zero_shot_prompt_text: Prompt text for zero-shot learning 468 | :param question_type: Type of questions in the JSON file (e.g. single_choice, five_out_of_seven, etc.) 469 | :param parallel_num: Number of parallel processes to use (default: 1) 470 | 471 | """ 472 | # Find the JSON file with the specified keyword 473 | for root, _, files in os.walk(directory): 474 | for file in files: 475 | if file == f'{question_type}.json': 476 | filepath = os.path.join(root, file) 477 | with open(filepath, 'r', encoding='utf-8') as f: 478 | data = json.load(f) 479 | 480 | example_num = len(data['example']) 481 | 482 | # Prepare the list of keyword arguments for parallel processing 483 | kwargs_list = [] 484 | batch_size = example_num // parallel_num 485 | save_directory = os.path.join(directory, f'{model_name}_{question_type}') 486 | if not os.path.exists(save_directory): 487 | os.makedirs(save_directory) 488 | # os.system(f'mkdir {save_directory}') 489 | 490 | for idx in range(args.start_num, parallel_num): 491 | start_num = idx * batch_size 492 | end_num = min(start_num + batch_size, example_num) 493 | if start_num >= example_num: 494 | break 495 | 496 | kwargs = { 497 | 'model_api': model_api, 498 | 'start_num': start_num, 499 | 'end_num': end_num, 500 | 'model_name': model_name, 501 | 'data': data, 502 | 'keyword': question_type, 503 | 'prompt': zero_shot_prompt_text, 504 | 'question_type': question_type, 505 | 'save_directory': save_directory, 506 | 'args': args, 507 | } 508 | 509 | kwargs_list.append(kwargs) 510 | 511 | # Run parallel processing based on the question type 512 | if question_type in ["FKU", "CBF", "SCF", "PF", "SDDT", "DDST", "SDDT_hard"]: 513 | for kwargs in kwargs_list: 514 | choice_test_A12(**kwargs) 515 | elif question_type in ["CVR", "KHC", "SDT", 'SDT_reverse', "SDT_shuffle"]: 516 | for kwargs in kwargs_list: 517 | choice_test_A34(**kwargs) 518 | elif question_type in ["DDST_hard"]: 519 | for kwargs in kwargs_list: 520 | choice_test_DDST_hard(**kwargs) 521 | elif question_type in ["herb_predict"]: 522 | for kwargs in kwargs_list: 523 | choice_test_TCM_Rec(**kwargs) 524 | 525 | 526 | 527 | def test_correction_score_A12(data_dict): 528 | score = 0 529 | all_num = 0 530 | model_answer_dict = [] 531 | correct_answer_list = [] 532 | model_kpoint = {} 533 | for data in data_dict['example']: 534 | all_num += 1 535 | true_answer = data['standard_answer'] 536 | model_answer = data['model_answer'] 537 | knowledge_point = data["knowledge_point"] 538 | if knowledge_point == "": 539 | knowledge_point = "其他" 540 | if knowledge_point not in model_kpoint.keys(): 541 | model_kpoint[knowledge_point] = [0, 0] 542 | model_kpoint[knowledge_point][1] += 1 543 | dict = { 544 | 'index': data["index"], 545 | 'question': data["question"], 546 | 'standard_answer': true_answer, 547 | 'analysis': data["analysis"], 548 | 'knowledge_point': data["knowledge_point"], 549 | 'model_answer': model_answer, 550 | 'model_output': data["model_output"] 551 | } 552 | if true_answer == model_answer: 553 | score += 1 554 | model_kpoint[knowledge_point][0] += 1 555 | correct_answer_list.append(dict) 556 | else: 557 | model_answer_dict.append(dict) 558 | output = {'keyword': data_dict["keyword"], 559 | 'correct_num': score, 560 | 'all_num': all_num} 561 | if len(model_answer_dict) > 0: 562 | output['example'] = model_answer_dict 563 | return score / all_num, output, model_kpoint, score,
all_num, correct_answer_list 564 | 565 | 566 | def test_correction_score_A34(data_dict): 567 | score = 0 568 | all_num = 0 569 | model_answer_dict = [] 570 | model_kpoint = {} 571 | for data in data_dict['example']: 572 | correction_flag = True 573 | question = data["question"] 574 | knowledge_point = data["knowledge_point"] 575 | question_list = [] 576 | if knowledge_point == "": 577 | knowledge_point = "其他" 578 | if knowledge_point not in model_kpoint.keys(): 579 | model_kpoint[knowledge_point] = [0, 0] 580 | # all_num += len(question) 581 | for sub_question in question: 582 | all_num += 1 583 | model_kpoint[knowledge_point][1] += 1 584 | standard_answer = sub_question['standard_answer'] 585 | model_answer = sub_question['model_answer'] 586 | if standard_answer == model_answer: 587 | score += 1 588 | model_kpoint[knowledge_point][0] += 1 589 | else: 590 | correction_flag = False 591 | sub_question_dict = { 592 | 'sub_question': sub_question['sub_question'], 593 | 'standard_answer': standard_answer, 594 | 'analysis': sub_question['analysis'], 595 | 'model_answer': model_answer, 596 | 'model_output': sub_question['model_output'] 597 | } 598 | question_list.append(sub_question_dict) 599 | if correction_flag == False: # 有一个错就存起来当错题 600 | dict = { 601 | 'index': data["index"], 602 | 'share_content': data["share_content"], 603 | 'question': question_list, 604 | 'knowledge_point': data["knowledge_point"], 605 | } 606 | model_answer_dict.append(dict) 607 | output = {'keyword': data_dict["keyword"], 608 | 'correct_num': score, 609 | 'all_num': all_num} 610 | if len(model_answer_dict) > 0: 611 | output['example'] = model_answer_dict 612 | 613 | return score / all_num, output, model_kpoint, score, all_num 614 | 615 | 616 | 617 | --------------------------------------------------------------------------------
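A note for readers of the metric code in TCMBench_code/explain_evaluation.py: the TCM-term scores (`tcm_score_f1`, `f1_score_tcm_term`) combine precision, recall, and a TCM-term informativeness ratio through a three-way harmonic mean, 3·p·r·i / (p + r + i). The snippet below is a minimal, self-contained sketch of that idea only; the whitespace tokenizer, the toy vocabulary, and the name `toy_tcm_f1` are illustrative assumptions and are not part of the repository, which segments text with jieba and loads its term list from jieba_tcm.txt.

```python
# Toy sketch of the three-way harmonic mean behind the TCM-term scores (illustrative only).
# Assumes whitespace-separated tokens instead of jieba and a tiny hand-made term set.

def toy_tcm_f1(reference_terms, candidate_sentence, tcm_vocab):
    cand_terms = candidate_sentence.split()                    # stand-in for jieba.cut
    ref_tcm = [t for t in reference_terms if t in tcm_vocab]   # TCM terms in the reference analysis
    cand_tcm = [t for t in cand_terms if t in tcm_vocab]       # TCM terms in the candidate sentence
    if not ref_tcm:
        return 2.0                                             # sentinel: nothing to match against
    if not cand_tcm:
        return 0.0                                             # candidate contains no TCM language
    overlap = sum(1 for t in cand_tcm if t in set(ref_tcm))    # shared TCM terms
    recall = overlap / len(ref_tcm)
    precision = overlap / len(cand_tcm)
    informational = len(set(cand_tcm)) / len(set(cand_terms))  # TCM-term density of the candidate
    if precision == 0 or recall == 0 or informational == 0:
        return 0.0
    return 3 * precision * recall * informational / (precision + recall + informational)


if __name__ == "__main__":
    vocab = {"气虚", "血瘀", "健脾"}
    # reference mentions two TCM terms; the candidate repeats them plus two ordinary words
    print(round(toy_tcm_f1(["气虚", "血瘀"], "气虚 日久 导致 血瘀", vocab), 3))  # 0.6
```

The actual implementation in the repository additionally weights each sentence pair by an NLI-based relatedness probability and applies a length penalty before averaging; see predict_analysys_response above.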