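├── README.md
├── answer_search.py
├── build_medicalgraph.py
├── chat_with_llm.py
├── chatbot_graph.py
├── llm_server.py
├── question_classifier.py
├── question_parser.py
├── qwen7b_server.py
└── wechat.jpg


/README.md:
--------------------------------------------------------------------------------
1 | # LLMRAGOnMedicaKG
2 | A self-built, disease-centered medical knowledge graph, constructed from scratch and served as the knowledge base for question answering. The project builds a medical knowledge graph of moderate scale around diseases and combines it with an LLM to provide automatic question answering and analysis.
3 | 
4 | # 1. Project Overview
5 | 
6 | Knowledge graphs are now widely used in education, healthcare, law, finance, and other fields. This project focuses on the medical domain: using a vertical medical website as its data source and diseases as the core, it builds a knowledge graph with about 44,000 entities of 7 types and about 300,000 relationships of 11 types.
7 | The project covers two parts:
8 | 1) Building a medical knowledge graph from vertical-website data
9 | 2) LLM-based automatic question answering over the medical knowledge graph
10 | 
11 | Our earlier project (https://github.com/liuhuanyong/QABasedOnMedicalKnowledgeGraph) already open-sourced question answering built on a plain KG. The code and data for the graph-construction part can be reused from that project.
12 | 
13 | # 2. How to Run
14 | 
15 | 1. Requirements: a Neo4j database and the corresponding Python dependencies. Note your Neo4j username and password and update them in the relevant files.
16 | 2. Import the knowledge graph data: python build_medicalgraph.py. The dataset is large, so the import is expected to take several hours.
17 | 3. The project uses qwen-7b-chat as the underlying LLM; run python qwen7b_server.py to start the model service.
18 | 4. Configure the service address: model = ModelAPI(MODEL_URL="http://YOUR_IP/generate")
19 | 5. Start question answering: python chat_with_llm.py
20 | 
21 | # 3. Building the Medical Knowledge Graph
22 | # 3.1 Business-driven knowledge graph construction framework
23 | ![image](https://github.com/liuhuanyong/QABasedOnMedicalKnowledgeGraph/blob/master/img/kg_route.png)
24 | 
25 | # 3.2 Scripts
26 | prepare_data/datasoider.py: web content crawling script
27 | 
28 | prepare_data/max_cut.py: dictionary-based forward/backward maximum-matching segmentation script
29 | build_medicalgraph.py: knowledge graph import script
30 | 
31 | # 3.3 Scale of the medical knowledge graph
32 | 3.3.1 Storage scale in the Neo4j graph database
33 | ![image](https://github.com/liuhuanyong/QABasedOnMedicalKnowledgeGraph/blob/master/img/graph_summary.png)
34 | 
35 | 3.3.2 Entity types
36 | 
37 | | Entity type | Meaning | Count | Examples |
38 | | :--- | :---: | :---: | :--- |
39 | | Check | Diagnostic check item | 3,353| 支气管造影;关节镜检查|
40 | | Department | Medical department | 54 | 整形美容科;烧伤科|
41 | | Disease | Disease | 8,807 | 血栓闭塞性脉管炎;胸降主动脉动脉瘤|
42 | | Drug | Drug | 3,828 | 京万红痔疮膏;布林佐胺滴眼液|
43 | | Food | Food | 4,870 | 番茄冲菜牛肉丸汤;竹笋炖羊肉|
44 | | 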
Producer | Marketed drug product | 17,201 | 通药制药青霉素V钾片;青阳醋酸地塞米松片|
45 | | Symptom | Disease symptom | 5,998 | 乳腺组织肥厚;脑实质深部出血|
46 | | Total | Total | 44,111 | About 44,000 entities|
47 | 
48 | 
49 | 3.3.3 Relationship types
50 | 
51 | | Relationship type | Meaning | Count | Example |
52 | | :--- | :---: | :---: | :--- |
53 | | belongs_to | Belongs to | 8,844| <妇科,属于,妇产科>|
54 | | common_drug | Commonly used drug for a disease | 14,649 | <阳强,常用,甲磺酸酚妥拉明分散片>|
55 | | do_eat | Food suitable for a disease | 22,238| <胸椎骨折,宜吃,黑鱼>|
56 | | drugs_of | Marketed product of a drug | 17,315| <青霉素V钾片,在售,通药制药青霉素V钾片>|
57 | | need_check | Check required for a disease | 39,422| <单侧肺气肿,所需检查,支气管造影>|
58 | | no_eat | Food to avoid for a disease | 22,247| <唇病,忌吃,杏仁>|
59 | | recommand_drug | Recommended drug for a disease | 59,467 | <混合痔,推荐用药,京万红痔疮膏>|
60 | | recommand_eat | Recommended recipe for a disease | 40,221 | <鞘膜积液,推荐食谱,番茄冲菜牛肉丸汤>|
61 | | has_symptom | Symptom of a disease | 5,998 | <早期乳腺癌,疾病症状,乳腺组织肥厚>|
62 | | acompany_with | Complication of a disease | 12,029 | <下肢交通静脉瓣膜关闭不全,并发疾病,血栓闭塞性脉管炎>|
63 | | Total | Total | 294,149 | About 300,000 relationships|
64 | 
65 | 
66 | # 4. Automatic Question Answering over the Medical Knowledge Graph
67 | 
68 | Basic idea:
69 | 
70 | step1: linking entity. Recognize the entities mentioned in the question. This project matches them with an Aho-Corasick automaton loaded with the graph's vocabulary.
71 | 
72 | step2: recall kg facts. For each entity found in the previous step, prompt the LLM to identify which of the entity's attributes or relationships the question asks about, convert the result into a Cypher query, then filter and prune the query results to obtain the subgraph paths.
73 | 
74 | step3: generate answer. Concatenate the recalled subgraph into a prompt and let the LLM produce the answer.
75 | 
76 | 
77 |     def chat(self, query):
78 |         print("step1: linking entity.....")
79 |         entity_dict = self.entity_linking(query)
80 |         depth = 1
81 |         facts = list()
82 |         answer = ""
83 |         default = "抱歉,我在知识库中没有找到对应的实体,无法回答。"
84 |         if not entity_dict:
85 |             print("no entity found...finished...")
86 |             return default
87 |         print("step2:recall kg facts....")
88 |         for entity_name, types in entity_dict.items():
89 |             for entity_type in types:
90 |                 rels = self.link_entity_rel(query, entity_name, entity_type)
91 |                 entity_triples = self.recall_facts(rels, entity_type, entity_name, depth)
92 |                 facts += entity_triples
93 |         fact_prompt = self.format_prompt(query, facts)
94 |         print("step3:generate answer...")
95 |         answer, _ = model.chat(query=fact_prompt, history=[])  # model.chat returns (response, history)
96 |         return answer
97 | 
98 | # Summary
99 | 
100 | 1. This project provides an open-source approach to medical-domain RAG that combines an LLM with a knowledge graph;
101 | 
2. The core ideas are entity recognition, subgraph recall, and intent classification, each of which leaves plenty of room for optimization;
102 | 3. The point of open source is to share ideas that others can build on, not to copy, take, and wait for ready-made results. Let's build a healthy ecosystem together.
103 | 
104 | 
105 | 
106 | 
107 | 
108 | 
--------------------------------------------------------------------------------
/answer_search.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # coding: utf-8
3 | # File: answer_search.py
4 | # Author: lhy
5 | # Date: 18-10-5
6 | from py2neo import Graph
7 | 
8 | class AnswerSearcher:
9 |     def __init__(self):
10 |         self.g = Graph(
11 |             host="127.0.0.1",
12 |             http_port=7474,
13 |             user="neo4j",
14 |             password="lhy123")
15 |         self.num_limit = 20
16 | 
17 |     # Run the Cypher queries and return the formatted answers
18 |     def search_main(self, sqls):
19 |         final_answers = []
20 |         for sql_ in sqls:
21 |             question_type = sql_['question_type']
22 |             queries = sql_['sql']
23 |             answers = []
24 |             for query in queries:
25 |                 ress = self.g.run(query).data()
26 |                 answers += ress
27 |             final_answer = self.answer_prettify(question_type, answers)
28 |             if final_answer:
29 |                 final_answers.append(final_answer)
30 |         return final_answers
31 | 
32 |     # Pick the reply template that matches the question_type
33 |     def answer_prettify(self, question_type, answers):
34 |         final_answer = []
35 |         if not answers:
36 |             return ''
37 |         if question_type == 'disease_symptom':
38 |             desc = [i['n.name'] for i in answers]
39 |             subject = answers[0]['m.name']
40 |             final_answer = '{0}的症状包括:{1}'.format(subject, ';'.join(list(set(desc))[:self.num_limit]))
41 | 
42 |         elif question_type == 'symptom_disease':
43 |             desc = [i['m.name'] for i in answers]
44 |             subject = answers[0]['n.name']
45 |             final_answer = '症状{0}可能染上的疾病有:{1}'.format(subject, ';'.join(list(set(desc))[:self.num_limit]))
46 | 
47 |         elif question_type == 'disease_cause':
48 |             desc = [i['m.cause'] for i in answers]
49 |             subject = answers[0]['m.name']
50 |             final_answer = '{0}可能的成因有:{1}'.format(subject, ';'.join(list(set(desc))[:self.num_limit]))
51 | 
52 |         elif question_type == 'disease_prevent':
53 |             desc = [i['m.prevent'] for i in answers]
54 
|             subject = answers[0]['m.name']
55 |             final_answer = '{0}的预防措施包括:{1}'.format(subject, ';'.join(list(set(desc))[:self.num_limit]))
56 | 
57 |         elif question_type == 'disease_lasttime':
58 |             desc = [i['m.cure_lasttime'] for i in answers]
59 |             subject = answers[0]['m.name']
60 |             final_answer = '{0}治疗可能持续的周期为:{1}'.format(subject, ';'.join(list(set(desc))[:self.num_limit]))
61 | 
62 |         elif question_type == 'disease_cureway':
63 |             desc = [';'.join(i['m.cure_way']) for i in answers]
64 |             subject = answers[0]['m.name']
65 |             final_answer = '{0}可以尝试如下治疗:{1}'.format(subject, ';'.join(list(set(desc))[:self.num_limit]))
66 | 
67 |         elif question_type == 'disease_cureprob':
68 |             desc = [i['m.cured_prob'] for i in answers]
69 |             subject = answers[0]['m.name']
70 |             final_answer = '{0}治愈的概率为(仅供参考):{1}'.format(subject, ';'.join(list(set(desc))[:self.num_limit]))
71 | 
72 |         elif question_type == 'disease_easyget':
73 |             desc = [i['m.easy_get'] for i in answers]
74 |             subject = answers[0]['m.name']
75 | 
76 |             final_answer = '{0}的易感人群包括:{1}'.format(subject, ';'.join(list(set(desc))[:self.num_limit]))
77 | 
78 |         elif question_type == 'disease_desc':
79 |             desc = [i['m.desc'] for i in answers]
80 |             subject = answers[0]['m.name']
81 |             final_answer = '{0},熟悉一下:{1}'.format(subject, ';'.join(list(set(desc))[:self.num_limit]))
82 | 
83 |         elif question_type == 'disease_acompany':
84 |             desc1 = [i['n.name'] for i in answers]
85 |             desc2 = [i['m.name'] for i in answers]
86 |             subject = answers[0]['m.name']
87 |             desc = [i for i in desc1 + desc2 if i != subject]
88 |             final_answer = '{0}的并发症包括:{1}'.format(subject, ';'.join(list(set(desc))[:self.num_limit]))
89 | 
90 |         elif question_type == 'disease_not_food':
91 |             desc = [i['n.name'] for i in answers]
92 |             subject = answers[0]['m.name']
93 |             final_answer = '{0}忌食的食物包括有:{1}'.format(subject, ';'.join(list(set(desc))[:self.num_limit]))
94 | 
95 |         elif question_type == 'disease_do_food':
96 |             do_desc = [i['n.name'] for i in answers if i['r.name'] == '宜吃']
97 | 
recommand_desc = [i['n.name'] for i in answers if i['r.name'] == '推荐食谱']
98 |             subject = answers[0]['m.name']
99 |             final_answer = '{0}宜食的食物包括有:{1}\n推荐食谱包括有:{2}'.format(subject, ';'.join(list(set(do_desc))[:self.num_limit]), ';'.join(list(set(recommand_desc))[:self.num_limit]))
100 | 
101 |         elif question_type == 'food_not_disease':
102 |             desc = [i['m.name'] for i in answers]
103 |             subject = answers[0]['n.name']
104 |             final_answer = '患有{0}的人最好不要吃{1}'.format(';'.join(list(set(desc))[:self.num_limit]), subject)
105 | 
106 |         elif question_type == 'food_do_disease':
107 |             desc = [i['m.name'] for i in answers]
108 |             subject = answers[0]['n.name']
109 |             final_answer = '患有{0}的人建议多试试{1}'.format(';'.join(list(set(desc))[:self.num_limit]), subject)
110 | 
111 |         elif question_type == 'disease_drug':
112 |             desc = [i['n.name'] for i in answers]
113 |             subject = answers[0]['m.name']
114 |             final_answer = '{0}通常的使用的药品包括:{1}'.format(subject, ';'.join(list(set(desc))[:self.num_limit]))
115 | 
116 |         elif question_type == 'drug_disease':
117 |             desc = [i['m.name'] for i in answers]
118 |             subject = answers[0]['n.name']
119 |             final_answer = '{0}主治的疾病有{1},可以试试'.format(subject, ';'.join(list(set(desc))[:self.num_limit]))
120 | 
121 |         elif question_type == 'disease_check':
122 |             desc = [i['n.name'] for i in answers]
123 |             subject = answers[0]['m.name']
124 |             final_answer = '{0}通常可以通过以下方式检查出来:{1}'.format(subject, ';'.join(list(set(desc))[:self.num_limit]))
125 | 
126 |         elif question_type == 'check_disease':
127 |             desc = [i['m.name'] for i in answers]
128 |             subject = answers[0]['n.name']
129 |             final_answer = '通常可以通过{0}检查出来的疾病有{1}'.format(subject, ';'.join(list(set(desc))[:self.num_limit]))
130 | 
131 |         return final_answer
132 | 
133 | if __name__ == '__main__':
134 |     searcher = AnswerSearcher()
--------------------------------------------------------------------------------
/build_medicalgraph.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # coding: utf-8 3 | # File: MedicalGraph.py 4 | # Author: lhy 5 | # Date: 18-10-3 6 | 7 | import os 8 | import json 9 | from py2neo import Graph,Node 10 | 11 | class MedicalGraph: 12 | def __init__(self): 13 | cur_dir = '/'.join(os.path.abspath(__file__).split('/')[:-1]) 14 | self.data_path = os.path.join(cur_dir, 'data/medical.json') 15 | self.g = Graph( 16 | host="127.0.0.1", # neo4j 搭载服务器的ip地址,ifconfig可获取到 17 | http_port=7474, # neo4j 服务器监听的端口号 18 | user="neo4j", # 数据库user name,如果没有更改过,应该是neo4j 19 | password="123456") 20 | 21 | '''读取文件''' 22 | def read_nodes(self): 23 | # 共7类节点 24 | drugs = [] # 药品 25 | foods = [] # 食物 26 | checks = [] # 检查 27 | departments = [] #科室 28 | producers = [] #药品大类 29 | diseases = [] #疾病 30 | symptoms = []#症状 31 | 32 | disease_infos = []#疾病信息 33 | 34 | # 构建节点实体关系 35 | rels_department = [] # 科室-科室关系 36 | rels_noteat = [] # 疾病-忌吃食物关系 37 | rels_doeat = [] # 疾病-宜吃食物关系 38 | rels_recommandeat = [] # 疾病-推荐吃食物关系 39 | rels_commonddrug = [] # 疾病-通用药品关系 40 | rels_recommanddrug = [] # 疾病-热门药品关系 41 | rels_check = [] # 疾病-检查关系 42 | rels_drug_producer = [] # 厂商-药物关系 43 | 44 | rels_symptom = [] #疾病症状关系 45 | rels_acompany = [] # 疾病并发关系 46 | rels_category = [] # 疾病与科室之间的关系 47 | 48 | 49 | count = 0 50 | for data in open(self.data_path): 51 | disease_dict = {} 52 | count += 1 53 | print(count) 54 | data_json = json.loads(data) 55 | disease = data_json['name'] 56 | disease_dict['name'] = disease 57 | diseases.append(disease) 58 | disease_dict['desc'] = '' 59 | disease_dict['prevent'] = '' 60 | disease_dict['cause'] = '' 61 | disease_dict['easy_get'] = '' 62 | disease_dict['cure_department'] = '' 63 | disease_dict['cure_way'] = '' 64 | disease_dict['cure_lasttime'] = '' 65 | disease_dict['symptom'] = '' 66 | disease_dict['cured_prob'] = '' 67 | 68 | if 'symptom' in data_json: 69 | symptoms += data_json['symptom'] 70 | for symptom in data_json['symptom']: 71 | rels_symptom.append([disease, symptom]) 72 | 73 | if 'acompany' in data_json: 74 | for acompany 
in data_json['acompany']: 75 | rels_acompany.append([disease, acompany]) 76 | 77 | if 'desc' in data_json: 78 | disease_dict['desc'] = data_json['desc'] 79 | 80 | if 'prevent' in data_json: 81 | disease_dict['prevent'] = data_json['prevent'] 82 | 83 | if 'cause' in data_json: 84 | disease_dict['cause'] = data_json['cause'] 85 | 86 | if 'get_prob' in data_json: 87 | disease_dict['get_prob'] = data_json['get_prob'] 88 | 89 | if 'easy_get' in data_json: 90 | disease_dict['easy_get'] = data_json['easy_get'] 91 | 92 | if 'cure_department' in data_json: 93 | cure_department = data_json['cure_department'] 94 | if len(cure_department) == 1: 95 | rels_category.append([disease, cure_department[0]]) 96 | if len(cure_department) == 2: 97 | big = cure_department[0] 98 | small = cure_department[1] 99 | rels_department.append([small, big]) 100 | rels_category.append([disease, small]) 101 | 102 | disease_dict['cure_department'] = cure_department 103 | departments += cure_department 104 | 105 | if 'cure_way' in data_json: 106 | disease_dict['cure_way'] = data_json['cure_way'] 107 | 108 | if 'cure_lasttime' in data_json: 109 | disease_dict['cure_lasttime'] = data_json['cure_lasttime'] 110 | 111 | if 'cured_prob' in data_json: 112 | disease_dict['cured_prob'] = data_json['cured_prob'] 113 | 114 | if 'common_drug' in data_json: 115 | common_drug = data_json['common_drug'] 116 | for drug in common_drug: 117 | rels_commonddrug.append([disease, drug]) 118 | drugs += common_drug 119 | 120 | if 'recommand_drug' in data_json: 121 | recommand_drug = data_json['recommand_drug'] 122 | drugs += recommand_drug 123 | for drug in recommand_drug: 124 | rels_recommanddrug.append([disease, drug]) 125 | 126 | if 'not_eat' in data_json: 127 | not_eat = data_json['not_eat'] 128 | for _not in not_eat: 129 | rels_noteat.append([disease, _not]) 130 | 131 | foods += not_eat 132 | do_eat = data_json['do_eat'] 133 | for _do in do_eat: 134 | rels_doeat.append([disease, _do]) 135 | 136 | foods += do_eat 137 | 
recommand_eat = data_json['recommand_eat'] 138 | 139 | for _recommand in recommand_eat: 140 | rels_recommandeat.append([disease, _recommand]) 141 | foods += recommand_eat 142 | 143 | if 'check' in data_json: 144 | check = data_json['check'] 145 | for _check in check: 146 | rels_check.append([disease, _check]) 147 | checks += check 148 | if 'drug_detail' in data_json: 149 | drug_detail = data_json['drug_detail'] 150 | producer = [i.split('(')[0] for i in drug_detail] 151 | rels_drug_producer += [[i.split('(')[0], i.split('(')[-1].replace(')', '')] for i in drug_detail] 152 | producers += producer 153 | disease_infos.append(disease_dict) 154 | return set(drugs), set(foods), set(checks), set(departments), set(producers), set(symptoms), set(diseases), disease_infos,\ 155 | rels_check, rels_recommandeat, rels_noteat, rels_doeat, rels_department, rels_commonddrug, rels_drug_producer, rels_recommanddrug,\ 156 | rels_symptom, rels_acompany, rels_category 157 | 158 | '''建立节点''' 159 | def create_node(self, label, nodes): 160 | count = 0 161 | for node_name in nodes: 162 | node = Node(label, name=node_name) 163 | self.g.create(node) 164 | count += 1 165 | print(count, len(nodes)) 166 | return 167 | 168 | '''创建知识图谱中心疾病的节点''' 169 | def create_diseases_nodes(self, disease_infos): 170 | count = 0 171 | for disease_dict in disease_infos: 172 | node = Node("Disease", name=disease_dict['name'], desc=disease_dict['desc'], 173 | prevent=disease_dict['prevent'] ,cause=disease_dict['cause'], 174 | easy_get=disease_dict['easy_get'],cure_lasttime=disease_dict['cure_lasttime'], 175 | cure_department=disease_dict['cure_department'] 176 | ,cure_way=disease_dict['cure_way'] , cured_prob=disease_dict['cured_prob']) 177 | self.g.create(node) 178 | count += 1 179 | print(count) 180 | return 181 | 182 | '''创建知识图谱实体节点类型schema''' 183 | def create_graphnodes(self): 184 | Drugs, Foods, Checks, Departments, Producers, Symptoms, Diseases, disease_infos,rels_check, rels_recommandeat, rels_noteat, 
rels_doeat, rels_department, rels_commonddrug, rels_drug_producer, rels_recommanddrug,rels_symptom, rels_acompany, rels_category = self.read_nodes() 185 | self.create_diseases_nodes(disease_infos) 186 | self.create_node('Drug', Drugs) 187 | print(len(Drugs)) 188 | self.create_node('Food', Foods) 189 | print(len(Foods)) 190 | self.create_node('Check', Checks) 191 | print(len(Checks)) 192 | self.create_node('Department', Departments) 193 | print(len(Departments)) 194 | self.create_node('Producer', Producers) 195 | print(len(Producers)) 196 | self.create_node('Symptom', Symptoms) 197 | return 198 | 199 | '''创建实体关系边''' 200 | def create_graphrels(self): 201 | Drugs, Foods, Checks, Departments, Producers, Symptoms, Diseases, disease_infos, rels_check, rels_recommandeat, rels_noteat, rels_doeat, rels_department, rels_commonddrug, rels_drug_producer, rels_recommanddrug,rels_symptom, rels_acompany, rels_category = self.read_nodes() 202 | self.create_relationship('Disease', 'Food', rels_recommandeat, 'recommand_eat', '推荐食谱') 203 | self.create_relationship('Disease', 'Food', rels_noteat, 'no_eat', '忌吃') 204 | self.create_relationship('Disease', 'Food', rels_doeat, 'do_eat', '宜吃') 205 | self.create_relationship('Department', 'Department', rels_department, 'belongs_to', '属于') 206 | self.create_relationship('Disease', 'Drug', rels_commonddrug, 'common_drug', '常用药品') 207 | self.create_relationship('Producer', 'Drug', rels_drug_producer, 'drugs_of', '生产药品') 208 | self.create_relationship('Disease', 'Drug', rels_recommanddrug, 'recommand_drug', '好评药品') 209 | self.create_relationship('Disease', 'Check', rels_check, 'need_check', '诊断检查') 210 | self.create_relationship('Disease', 'Symptom', rels_symptom, 'has_symptom', '症状') 211 | self.create_relationship('Disease', 'Disease', rels_acompany, 'acompany_with', '并发症') 212 | self.create_relationship('Disease', 'Department', rels_category, 'belongs_to', '所属科室') 213 | 214 | '''创建实体关联边''' 215 | def create_relationship(self, start_node, 
end_node, edges, rel_type, rel_name): 216 | count = 0 217 | # 去重处理 218 | set_edges = [] 219 | for edge in edges: 220 | set_edges.append('###'.join(edge)) 221 | all = len(set(set_edges)) 222 | for edge in set(set_edges): 223 | edge = edge.split('###') 224 | p = edge[0] 225 | q = edge[1] 226 | query = "match(p:%s),(q:%s) where p.name='%s'and q.name='%s' create (p)-[rel:%s{name:'%s'}]->(q)" % ( 227 | start_node, end_node, p, q, rel_type, rel_name) 228 | try: 229 | self.g.run(query) 230 | count += 1 231 | print(rel_type, count, all) 232 | except Exception as e: 233 | print(e) 234 | return 235 | 236 | '''导出数据''' 237 | def export_data(self): 238 | Drugs, Foods, Checks, Departments, Producers, Symptoms, Diseases, disease_infos, rels_check, rels_recommandeat, rels_noteat, rels_doeat, rels_department, rels_commonddrug, rels_drug_producer, rels_recommanddrug, rels_symptom, rels_acompany, rels_category = self.read_nodes() 239 | f_drug = open('drug.txt', 'w+') 240 | f_food = open('food.txt', 'w+') 241 | f_check = open('check.txt', 'w+') 242 | f_department = open('department.txt', 'w+') 243 | f_producer = open('producer.txt', 'w+') 244 | f_symptom = open('symptoms.txt', 'w+') 245 | f_disease = open('disease.txt', 'w+') 246 | 247 | f_drug.write('\n'.join(list(Drugs))) 248 | f_food.write('\n'.join(list(Foods))) 249 | f_check.write('\n'.join(list(Checks))) 250 | f_department.write('\n'.join(list(Departments))) 251 | f_producer.write('\n'.join(list(Producers))) 252 | f_symptom.write('\n'.join(list(Symptoms))) 253 | f_disease.write('\n'.join(list(Diseases))) 254 | 255 | f_drug.close() 256 | f_food.close() 257 | f_check.close() 258 | f_department.close() 259 | f_producer.close() 260 | f_symptom.close() 261 | f_disease.close() 262 | 263 | return 264 | 265 | 266 | 267 | if __name__ == '__main__': 268 | handler = MedicalGraph() 269 | print("step1:导入图谱节点中") 270 | handler.create_graphnodes() 271 | print("step2:导入图谱边中") 272 | handler.create_graphrels() 273 | 274 | 
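
The `create_relationship` method in build_medicalgraph.py above builds each Cypher statement with `%` string formatting, so an entity name that contains a single quote would produce an invalid query. A minimal sketch of a parameterized alternative follows. The helper names `build_rel_query` and `dedupe_edges` are hypothetical and not part of the repository, and the sketch assumes the driver accepts named query parameters (py2neo's `Graph.run` takes them as keyword arguments). Labels and relationship types cannot be passed as Cypher parameters, so they are validated and interpolated instead.

```python
import re

# Labels and relationship types are interpolated into the query text,
# so restrict them to plain identifiers.
_IDENTIFIER = re.compile(r"[A-Za-z_][A-Za-z0-9_]*")


def build_rel_query(start_label, end_label, rel_type):
    """Return a Cypher statement that links two existing nodes by name.

    Node names and the relationship's display name stay as $parameters,
    so the driver takes care of quoting them.
    """
    for ident in (start_label, end_label, rel_type):
        if not _IDENTIFIER.fullmatch(ident):
            raise ValueError("unsafe identifier: %r" % (ident,))
    return (
        "MATCH (p:%s), (q:%s) "
        "WHERE p.name = $p_name AND q.name = $q_name "
        "CREATE (p)-[rel:%s {name: $rel_name}]->(q)"
        % (start_label, end_label, rel_type)
    )


def dedupe_edges(edges):
    """Drop repeated [start, end] pairs while keeping first-seen order."""
    seen = set()
    unique = []
    for start, end in edges:
        key = (start, end)
        if key not in seen:
            seen.add(key)
            unique.append(key)
    return unique


if __name__ == "__main__":
    query = build_rel_query("Disease", "Food", "do_eat")
    print(query)
    for p_name, q_name in dedupe_edges([["胸椎骨折", "黑鱼"], ["胸椎骨折", "黑鱼"]]):
        # With a live connection this would be, for example:
        # graph.run(query, p_name=p_name, q_name=q_name, rel_name="宜吃")
        print(p_name, q_name)
```

Deduplicating on tuples rather than on `'###'.join(edge)` also avoids splitting incorrectly when a name happens to contain the separator.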
--------------------------------------------------------------------------------
/chat_with_llm.py:
--------------------------------------------------------------------------------
1 | # coding: utf-8
2 | import os
3 | from question_classifier import *
4 | from question_parser import *
5 | from llm_server import *
6 | from build_medicalgraph import *
7 | import re
8 | 
9 | entity_parser = QuestionClassifier()
10 | 
11 | kg = MedicalGraph()
12 | model = ModelAPI(MODEL_URL="http://YOUR_IP:3001/generate")
13 | 
14 | class KGRAG():
15 |     def __init__(self):
16 |         self.cn_dict = {  # Chinese display names for node properties, relationship types, and node labels
17 |             "name":"名称",
18 |             "desc":"疾病简介",
19 |             "cause":"疾病病因",
20 |             "prevent":"预防措施",
21 |             "cure_department":"治疗科室",
22 |             "cure_lasttime":"治疗周期",
23 |             "cure_way":"治疗方式",
24 |             "cured_prob":"治愈概率",
25 |             "easy_get":"易感人群",
26 |             "belongs_to":"所属科室",
27 |             "common_drug":"常用药品",
28 |             "do_eat":"宜吃",
29 |             "drugs_of":"生产药品",
30 |             "need_check":"诊断检查",
31 |             "no_eat":"忌吃",
32 |             "recommand_drug":"好评药品",
33 |             "recommand_eat":"推荐食谱",
34 |             "has_symptom":"症状",
35 |             "acompany_with":"并发症",
36 |             "Check":"诊断检查项目",
37 |             "Department":"医疗科目",
38 |             "Disease":"疾病",
39 |             "Drug":"药品",
40 |             "Food":"食物",
41 |             "Producer":"在售药品",
42 |             "Symptom":"疾病症状"
43 |         }
44 |         self.entity_rel_dict = {  # properties and relationships the LLM may choose from, per entity type
45 |             "check":["name", 'need_check'],
46 |             "department":["name", 'belongs_to'],
47 |             "disease":["prevent", "cure_way", "name", "cure_lasttime", "cured_prob", "cause", "cure_department", "desc", "easy_get", 'recommand_eat', 'no_eat', 'do_eat', "common_drug", 'drugs_of', 'recommand_drug', 'need_check', 'has_symptom', 'acompany_with', 'belongs_to'],
48 |             "drug":["name", "common_drug", 'drugs_of', 'recommand_drug'],
49 |             "food":["name"],
50 |             "producer":["name"],
51 |             "symptom":["name", 'has_symptom'],
52 |         }
53 |         return
54 | 
55 |     def entity_linking(self, query):
56 |         return entity_parser.check_medical(query)
57 | 
58 |     def link_entity_rel(self, query, entity, entity_type):
59 |         cate = [self.cn_dict.get(i) for i in self.entity_rel_dict.get(entity_type)]
60 | 
prompt = "请判定问题:{query}所提及的是{entity}的哪几个信息,请从{cate}中进行选择,并以列表形式返回。".format(query=query, entity=entity, cate=cate)
61 |         answer, history = model.chat(query=prompt, history=[])
62 |         cls_rel = set([i for i in re.split(r"[\[。、, ;'\]]", answer)]).intersection(set(cate))
63 |         print([prompt, answer, cls_rel])
64 |         return cls_rel
65 | 
66 |     def recall_facts(self, cls_rel, entity_type, entity_name, depth=1):
67 |         entity_dict = {
68 |             "check":"Check",
69 |             "department":"Department",
70 |             "disease":"Disease",
71 |             "drug":"Drug",
72 |             "food":"Food",
73 |             "producer":"Producer",
74 |             "symptom":"Symptom"
75 |         }
76 |         # "MATCH p=(m:Disease)-[r*..2]-(n) where m.name = '耳聋' return p "
77 |         sql = "MATCH p=(m:{entity_type})-[r*..{depth}]-(n) where m.name = '{entity_name}' return p".format(depth=depth, entity_type=entity_dict.get(entity_type), entity_name=entity_name)
78 |         print(sql)
79 |         ress = kg.g.run(sql).data()
80 |         triples = set()
81 |         for res in ress:
82 |             p_data = res["p"]
83 |             nodes = p_data.nodes
84 |             rels = p_data.relationships
85 |             for node in nodes:
86 |                 node_name = node["name"]
87 |                 for k,v in node.items():
88 |                     # print(k)
89 |                     if v == node_name:
90 |                         continue
91 |                     if self.cn_dict[k] not in cls_rel:
92 |                         continue
93 |                     triples.add("<" + ','.join([str(node_name), str(self.cn_dict[k]), str(v)]) + ">")
94 |             for rel in rels:
95 |                 if rel.start_node["name"] == rel.end_node["name"]:
96 |                     continue
97 |                 # print(rel["name"])
98 |                 if rel["name"] not in cls_rel:
99 |                     continue
100 |                 triples.add("<" + ','.join([str(rel.start_node["name"]), str(rel["name"]), str(rel.end_node["name"])]) + ">")
101 |         print(len(triples), list(triples)[:3])
102 |         return list(triples)
103 | 
104 | 
105 |     def format_prompt(self, query, context):
106 |         prompt = "这是一个关于医疗领域的问题。给定以下知识三元组集合,三元组形式为<subject,relation,object>,表示subject和object之间存在relation关系。" \
107 |                  "请先从这些三元组集合中找到能够支撑问题的部分,在这里叫做证据,并基于此回答问题。如果没有找到,那么直接回答没有找到证据,回答不知道,如果找到了,请先回答证据的内容,然后再给出最终答案。" \
108 |                  "知识三元组集合为:" + str(context) + "\n问题是:" + query + "\n请回答:"
109 |         return prompt
110 | 
111 |     def chat(self, query):
112 |         # Example entity_dict: {'耳聋': ['disease', 'symptom']}
113 |         print("step1: linking entity.....")
114 |         entity_dict = self.entity_linking(query)
115 |         depth = 1
116 |         facts = list()
117 |         answer = ""
118 |         default = "抱歉,我在知识库中没有找到对应的实体,无法回答。"
119 |         if not entity_dict:
120 |             print("no entity found...finished...")
121 |             return default
122 |         print("step2:recall kg facts....")
123 |         for entity_name, types in entity_dict.items():
124 |             for entity_type in types:
125 |                 rels = self.link_entity_rel(query, entity_name, entity_type)
126 |                 entity_triples = self.recall_facts(rels, entity_type, entity_name, depth)
127 |                 facts += entity_triples
128 |         fact_prompt = self.format_prompt(query, facts)
129 |         print("step3:generate answer...")
130 |         answer, _ = model.chat(query=fact_prompt, history=[])  # model.chat returns (response, history)
131 |         return answer
132 | 
133 | if __name__ == "__main__":
134 |     chatbot = KGRAG()
135 |     while 1:
136 |         query = input("USER:").strip()
137 |         answer = chatbot.chat(query)
138 |         print("KGRAG_BOT:", answer)
139 | 
--------------------------------------------------------------------------------
/chatbot_graph.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # coding: utf-8
3 | # File: chatbot_graph.py
4 | # Author: lhy
5 | # Date: 18-10-4
6 | 
7 | from question_classifier import *
8 | from question_parser import *
9 | from answer_search import *
10 | 
11 | # Template-based question-answering class (no LLM involved)
12 | class ChatBotGraph:
13 |     def __init__(self):
14 |         self.classifier = QuestionClassifier()
15 |         self.parser = QuestionPaser()
16 |         self.searcher = AnswerSearcher()
17 | 
18 |     def chat_main(self, sent):
19 |         answer = '您好,我是小勇医药智能助理,希望可以帮到您。如果没答上来,可联系https://liuhuanyong.github.io/。祝您身体棒棒!'
20 |         res_classify = self.classifier.classify(sent)
21 |         if not res_classify:
22 |             return answer
23 |         res_sql = self.parser.parser_main(res_classify)
24 |         final_answers = self.searcher.search_main(res_sql)
25 |         if not final_answers:
26 |             return answer
27 |         else:
28 |             return '\n'.join(final_answers)
29 | 
30 | if __name__ == '__main__':
31 |     handler = ChatBotGraph()
32 |     while 1:
33 |         question = input('用户:')
34 |         answer = handler.chat_main(question)
35 |         print('小勇:', answer)
36 | 
37 | 
--------------------------------------------------------------------------------
/llm_server.py:
--------------------------------------------------------------------------------
1 | # coding: utf-8
2 | import os
3 | import re
4 | from tqdm import tqdm
5 | import requests
6 | import json
7 | import time
8 | 
9 | 
10 | class ModelAPI():
11 |     def __init__(self, MODEL_URL):
12 |         self.url = MODEL_URL
13 |         return
14 | 
15 |     def send_request(self, message, history):
16 |         data = json.dumps({"message":message, "history":history})
17 |         headers = {'Content-Type': 'application/json'}
18 |         try:
19 |             res = requests.post(self.url, data=data, headers=headers)
20 |             print(res)
21 |             reply = json.loads(res.text)
22 |             predict = reply["output"][0]
23 |             history = reply["history"]
24 |             return predict, history
25 |         except Exception as e:
26 |             print("request error", e)
27 |             return "", []
28 | 
29 |     # The service may be unstable, so retry the request several times
30 |     def chat(self, query, history=None):
31 |         message = [{"role": "user", "content": query}]
32 |         count = 0
33 |         response = ''
34 |         history = list(history) if history else []
35 |         while count <= 10:
36 |             count += 1
37 |             response, new_history = self.send_request(message, history)
38 |             if response:
39 |                 return response, new_history
40 |             # send_request swallows its own errors and returns "", so wait before retrying
41 |             time.sleep(1)
42 |         return response, history
43 | 
44 | if __name__ == '__main__':
45 |     model = ModelAPI(MODEL_URL="http://xxxxxxx:6666/generate")
46 |     res = model.chat(query="你叫啥", history=[])
47 |     print(res)
48 | 
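
The client in llm_server.py above fixes the wire format of the model service: it posts a JSON body with `message` and `history` keys and reads `output[0]` and `history` from the reply. The sketch below isolates that contract in two pure functions so it can be checked without a running model service. `build_payload` and `parse_reply` are hypothetical names introduced for illustration, and falling back to an empty reply on malformed input is an assumption that mirrors the `"", []` returned by `send_request` on errors.

```python
import json


def build_payload(query, history=None):
    """Serialize one user turn in the shape that send_request posts."""
    message = [{"role": "user", "content": query}]
    return json.dumps({"message": message, "history": history or []})


def parse_reply(text):
    """Extract (answer, history) from a reply body; ("", []) if malformed."""
    try:
        reply = json.loads(text)
        return reply["output"][0], reply["history"]
    except (ValueError, KeyError, IndexError, TypeError):
        # ValueError covers invalid JSON; the others cover missing or
        # wrongly shaped fields.
        return "", []


if __name__ == "__main__":
    print(build_payload("你叫啥"))
    print(parse_reply('{"output": ["hello"], "history": []}'))
    print(parse_reply("not json"))
```

Any server placed behind `MODEL_URL` would need to accept and return these shapes for `ModelAPI.chat` to work unchanged.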
-------------------------------------------------------------------------------- /question_classifier.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # coding: utf-8 3 | # File: question_classifier.py 4 | # Author: lhy 5 | # Date: 18-10-4 6 | 7 | import os 8 | import ahocorasick 9 | 10 | class QuestionClassifier: 11 | def __init__(self): 12 | cur_dir = '/'.join(os.path.abspath(__file__).split('/')[:-1]) 13 | # 特征词路径 14 | self.disease_path = os.path.join(cur_dir, 'dict/disease.txt') 15 | self.department_path = os.path.join(cur_dir, 'dict/department.txt') 16 | self.check_path = os.path.join(cur_dir, 'dict/check.txt') 17 | self.drug_path = os.path.join(cur_dir, 'dict/drug.txt') 18 | self.food_path = os.path.join(cur_dir, 'dict/food.txt') 19 | self.producer_path = os.path.join(cur_dir, 'dict/producer.txt') 20 | self.symptom_path = os.path.join(cur_dir, 'dict/symptom.txt') 21 | self.deny_path = os.path.join(cur_dir, 'dict/deny.txt') 22 | # 加载特征词 23 | self.disease_wds= [i.strip() for i in open(self.disease_path) if i.strip()] 24 | self.department_wds= [i.strip() for i in open(self.department_path) if i.strip()] 25 | self.check_wds= [i.strip() for i in open(self.check_path) if i.strip()] 26 | self.drug_wds= [i.strip() for i in open(self.drug_path) if i.strip()] 27 | self.food_wds= [i.strip() for i in open(self.food_path) if i.strip()] 28 | self.producer_wds= [i.strip() for i in open(self.producer_path) if i.strip()] 29 | self.symptom_wds= [i.strip() for i in open(self.symptom_path) if i.strip()] 30 | self.region_words = set(self.department_wds + self.disease_wds + self.check_wds + self.drug_wds + self.food_wds + self.producer_wds + self.symptom_wds) 31 | self.deny_words = [i.strip() for i in open(self.deny_path) if i.strip()] 32 | # 构造领域actree 33 | self.region_tree = self.build_actree(list(self.region_words)) 34 | # 构建词典 35 | self.wdtype_dict = self.build_wdtype_dict() 36 | # 问句疑问词 37 | self.symptom_qwds 
= ['症状', '表征', '现象', '症候', '表现'] 38 | self.cause_qwds = ['原因','成因', '为什么', '怎么会', '怎样才', '咋样才', '怎样会', '如何会', '为啥', '为何', '如何才会', '怎么才会', '会导致', '会造成'] 39 | self.acompany_qwds = ['并发症', '并发', '一起发生', '一并发生', '一起出现', '一并出现', '一同发生', '一同出现', '伴随发生', '伴随', '共现'] 40 | self.food_qwds = ['饮食', '饮用', '吃', '食', '伙食', '膳食', '喝', '菜' ,'忌口', '补品', '保健品', '食谱', '菜谱', '食用', '食物','补品'] 41 | self.drug_qwds = ['药', '药品', '用药', '胶囊', '口服液', '炎片'] 42 | self.prevent_qwds = ['预防', '防范', '抵制', '抵御', '防止','躲避','逃避','避开','免得','逃开','避开','避掉','躲开','躲掉','绕开', 43 | '怎样才能不', '怎么才能不', '咋样才能不','咋才能不', '如何才能不', 44 | '怎样才不', '怎么才不', '咋样才不','咋才不', '如何才不', 45 | '怎样才可以不', '怎么才可以不', '咋样才可以不', '咋才可以不', '如何可以不', 46 | '怎样才可不', '怎么才可不', '咋样才可不', '咋才可不', '如何可不'] 47 | self.lasttime_qwds = ['周期', '多久', '多长时间', '多少时间', '几天', '几年', '多少天', '多少小时', '几个小时', '多少年'] 48 | self.cureway_qwds = ['怎么治疗', '如何医治', '怎么医治', '怎么治', '怎么医', '如何治', '医治方式', '疗法', '咋治', '怎么办', '咋办', '咋治'] 49 | self.cureprob_qwds = ['多大概率能治好', '多大几率能治好', '治好希望大么', '几率', '几成', '比例', '可能性', '能治', '可治', '可以治', '可以医'] 50 | self.easyget_qwds = ['易感人群', '容易感染', '易发人群', '什么人', '哪些人', '感染', '染上', '得上'] 51 | self.check_qwds = ['检查', '检查项目', '查出', '检查', '测出', '试出'] 52 | self.belong_qwds = ['属于什么科', '属于', '什么科', '科室'] 53 | self.cure_qwds = ['治疗什么', '治啥', '治疗啥', '医治啥', '治愈啥', '主治啥', '主治什么', '有什么用', '有何用', '用处', '用途', 54 | '有什么好处', '有什么益处', '有何益处', '用来', '用来做啥', '用来作甚', '需要', '要'] 55 | 56 | print('model init finished ......') 57 | 58 | return 59 | 60 | '''分类主函数''' 61 | def classify(self, question): 62 | data = {} 63 | medical_dict = self.check_medical(question) 64 | if not medical_dict: 65 | return {} 66 | data['args'] = medical_dict 67 | #收集问句当中所涉及到的实体类型 68 | types = [] 69 | for type_ in medical_dict.values(): 70 | types += type_ 71 | question_type = 'others' 72 | 73 | question_types = [] 74 | 75 | # 症状 76 | if self.check_words(self.symptom_qwds, question) and ('disease' in types): 77 | question_type = 'disease_symptom' 78 | 
question_types.append(question_type) 79 | 80 | if self.check_words(self.symptom_qwds, question) and ('symptom' in types): 81 | question_type = 'symptom_disease' 82 | question_types.append(question_type) 83 | 84 | # 原因 85 | if self.check_words(self.cause_qwds, question) and ('disease' in types): 86 | question_type = 'disease_cause' 87 | question_types.append(question_type) 88 | # 并发症 89 | if self.check_words(self.acompany_qwds, question) and ('disease' in types): 90 | question_type = 'disease_acompany' 91 | question_types.append(question_type) 92 | 93 | # 推荐食品 94 | if self.check_words(self.food_qwds, question) and 'disease' in types: 95 | deny_status = self.check_words(self.deny_words, question) 96 | if deny_status: 97 | question_type = 'disease_not_food' 98 | else: 99 | question_type = 'disease_do_food' 100 | question_types.append(question_type) 101 | 102 | #已知食物找疾病 103 | if self.check_words(self.food_qwds+self.cure_qwds, question) and 'food' in types: 104 | deny_status = self.check_words(self.deny_words, question) 105 | if deny_status: 106 | question_type = 'food_not_disease' 107 | else: 108 | question_type = 'food_do_disease' 109 | question_types.append(question_type) 110 | 111 | # 推荐药品 112 | if self.check_words(self.drug_qwds, question) and 'disease' in types: 113 | question_type = 'disease_drug' 114 | question_types.append(question_type) 115 | 116 | # 药品治啥病 117 | if self.check_words(self.cure_qwds, question) and 'drug' in types: 118 | question_type = 'drug_disease' 119 | question_types.append(question_type) 120 | 121 | # 疾病接受检查项目 122 | if self.check_words(self.check_qwds, question) and 'disease' in types: 123 | question_type = 'disease_check' 124 | question_types.append(question_type) 125 | 126 | # 已知检查项目查相应疾病 127 | if self.check_words(self.check_qwds+self.cure_qwds, question) and 'check' in types: 128 | question_type = 'check_disease' 129 | question_types.append(question_type) 130 | 131 | # 症状防御 132 | if self.check_words(self.prevent_qwds, question) and 
                'disease' in types:
            question_type = 'disease_prevent'
            question_types.append(question_type)

        # Treatment duration of a disease
        if self.check_words(self.lasttime_qwds, question) and 'disease' in types:
            question_type = 'disease_lasttime'
            question_types.append(question_type)

        # Treatment methods for a disease
        if self.check_words(self.cureway_qwds, question) and 'disease' in types:
            question_type = 'disease_cureway'
            question_types.append(question_type)

        # Probability of curing a disease
        if self.check_words(self.cureprob_qwds, question) and 'disease' in types:
            question_type = 'disease_cureprob'
            question_types.append(question_type)

        # Population susceptible to a disease
        if self.check_words(self.easyget_qwds, question) and 'disease' in types:
            question_type = 'disease_easyget'
            question_types.append(question_type)

        # No question type matched but a disease was mentioned:
        # fall back to returning the description of that disease
        if not question_types and 'disease' in types:
            question_types = ['disease_desc']

        # No question type matched but a symptom was mentioned:
        # fall back to listing the diseases that show that symptom
        if not question_types and 'symptom' in types:
            question_types = ['symptom_disease']

        # Merge all matched question types into the result dictionary
        data['question_types'] = question_types

        return data

    '''Map each dictionary word to the entity types it belongs to'''
    def build_wdtype_dict(self):
        wd_dict = dict()
        for wd in self.region_words:
            wd_dict[wd] = []
            if wd in self.disease_wds:
                wd_dict[wd].append('disease')
            if wd in self.department_wds:
                wd_dict[wd].append('department')
            if wd in self.check_wds:
                wd_dict[wd].append('check')
            if wd in self.drug_wds:
                wd_dict[wd].append('drug')
            if wd in self.food_wds:
                wd_dict[wd].append('food')
            if wd in self.symptom_wds:
                wd_dict[wd].append('symptom')
            if wd in self.producer_wds:
                wd_dict[wd].append('producer')
        return wd_dict

    '''Build the Aho-Corasick automaton used to speed up entity matching'''
    def build_actree(self, wordlist):
        actree = ahocorasick.Automaton()
        for index, word \
                in enumerate(wordlist):
            actree.add_word(word, (index, word))
        actree.make_automaton()
        return actree

    '''Extract the medical entities mentioned in a question'''
    def check_medical(self, question):
        # Every dictionary word found in the question
        region_wds = []
        for i in self.region_tree.iter(question):
            wd = i[1][1]
            region_wds.append(wd)
        # Drop any match that is a proper substring of a longer match
        stop_wds = []
        for wd1 in region_wds:
            for wd2 in region_wds:
                if wd1 in wd2 and wd1 != wd2:
                    stop_wds.append(wd1)
        final_wds = [i for i in region_wds if i not in stop_wds]
        final_dict = {i: self.wdtype_dict.get(i) for i in final_wds}

        return final_dict

    '''Return True if any of the keywords occurs in the sentence'''
    def check_words(self, wds, sent):
        for wd in wds:
            if wd in sent:
                return True
        return False

if __name__ == '__main__':
    handler = QuestionClassifier()
    while True:
        question = input('input a question:')
        data = handler.classify(question)
        print(data)
--------------------------------------------------------------------------------
/question_parser.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python3
# coding: utf-8
# File: question_parser.py
# Author: lhy
# Date: 18-10-4

class QuestionPaser:
    '''Group the recognized entities by entity type'''
    def build_entitydict(self, args):
        entity_dict = {}
        for arg, types in args.items():
            for type_ in types:
                if type_ not in entity_dict:
                    entity_dict[type_] = [arg]
                else:
                    entity_dict[type_].append(arg)
        return entity_dict

    '''Main parsing routine'''
    def parser_main(self, res_classify):
        args = res_classify['args']
        entity_dict = self.build_entitydict(args)
        question_types = res_classify['question_types']
        sqls = []
        for question_type in question_types:
            sql_ = {}
            sql_['question_type'] = question_type
            sql = []
            if question_type == 'disease_symptom':
                sql = self.sql_transfer(question_type, entity_dict.get('disease'))

            elif \
                    question_type == 'symptom_disease':
                sql = self.sql_transfer(question_type, entity_dict.get('symptom'))

            elif question_type == 'disease_cause':
                sql = self.sql_transfer(question_type, entity_dict.get('disease'))

            elif question_type == 'disease_acompany':
                sql = self.sql_transfer(question_type, entity_dict.get('disease'))

            elif question_type == 'disease_not_food':
                sql = self.sql_transfer(question_type, entity_dict.get('disease'))

            elif question_type == 'disease_do_food':
                sql = self.sql_transfer(question_type, entity_dict.get('disease'))

            elif question_type == 'food_not_disease':
                sql = self.sql_transfer(question_type, entity_dict.get('food'))

            elif question_type == 'food_do_disease':
                sql = self.sql_transfer(question_type, entity_dict.get('food'))

            elif question_type == 'disease_drug':
                sql = self.sql_transfer(question_type, entity_dict.get('disease'))

            elif question_type == 'drug_disease':
                sql = self.sql_transfer(question_type, entity_dict.get('drug'))

            elif question_type == 'disease_check':
                sql = self.sql_transfer(question_type, entity_dict.get('disease'))

            elif question_type == 'check_disease':
                sql = self.sql_transfer(question_type, entity_dict.get('check'))

            elif question_type == 'disease_prevent':
                sql = self.sql_transfer(question_type, entity_dict.get('disease'))

            elif question_type == 'disease_lasttime':
                sql = self.sql_transfer(question_type, entity_dict.get('disease'))

            elif question_type == 'disease_cureway':
                sql = self.sql_transfer(question_type, entity_dict.get('disease'))

            elif question_type == 'disease_cureprob':
                sql = self.sql_transfer(question_type, entity_dict.get('disease'))

            elif question_type == 'disease_easyget':
                sql = self.sql_transfer(question_type, entity_dict.get('disease'))

            elif question_type == 'disease_desc':
                sql = \
                    self.sql_transfer(question_type, entity_dict.get('disease'))

            if sql:
                sql_['sql'] = sql

                sqls.append(sql_)

        return sqls

    '''Build the Cypher queries for each question type'''
    def sql_transfer(self, question_type, entities):
        if not entities:
            return []

        # Cypher query strings, one per entity
        sql = []
        # Cause of a disease
        if question_type == 'disease_cause':
            sql = ["MATCH (m:Disease) where m.name = '{0}' return m.name, m.cause".format(i) for i in entities]

        # Prevention measures for a disease
        elif question_type == 'disease_prevent':
            sql = ["MATCH (m:Disease) where m.name = '{0}' return m.name, m.prevent".format(i) for i in entities]

        # Treatment duration of a disease
        elif question_type == 'disease_lasttime':
            sql = ["MATCH (m:Disease) where m.name = '{0}' return m.name, m.cure_lasttime".format(i) for i in entities]

        # Probability of curing a disease
        elif question_type == 'disease_cureprob':
            sql = ["MATCH (m:Disease) where m.name = '{0}' return m.name, m.cured_prob".format(i) for i in entities]

        # Treatment methods for a disease
        elif question_type == 'disease_cureway':
            sql = ["MATCH (m:Disease) where m.name = '{0}' return m.name, m.cure_way".format(i) for i in entities]

        # Population susceptible to a disease
        elif question_type == 'disease_easyget':
            sql = ["MATCH (m:Disease) where m.name = '{0}' return m.name, m.easy_get".format(i) for i in entities]

        # Description of a disease
        elif question_type == 'disease_desc':
            sql = ["MATCH (m:Disease) where m.name = '{0}' return m.name, m.desc".format(i) for i in entities]

        # Symptoms of a disease
        elif question_type == 'disease_symptom':
            sql = ["MATCH (m:Disease)-[r:has_symptom]->(n:Symptom) where m.name = '{0}' return m.name, r.name, n.name".format(i) for i in entities]

        # Diseases that show a given symptom
        elif question_type == 'symptom_disease':
            sql = ["MATCH (m:Disease)-[r:has_symptom]->(n:Symptom) where n.name = '{0}' return m.name, r.name, n.name".format(i) for i in entities]

        # Complications of a disease
        elif question_type == 'disease_acompany':
            sql1 = ["MATCH (m:Disease)-[r:acompany_with]->(n:Disease) where m.name = '{0}' return m.name, r.name, n.name".format(i) for i in entities]
            sql2 = ["MATCH (m:Disease)-[r:acompany_with]->(n:Disease) where n.name = '{0}' return m.name, r.name, n.name".format(i) for i in entities]
            sql = sql1 + sql2

        # Foods to avoid with a disease
        elif question_type == 'disease_not_food':
            sql = ["MATCH (m:Disease)-[r:no_eat]->(n:Food) where m.name = '{0}' return m.name, r.name, n.name".format(i) for i in entities]

        # Foods recommended for a disease
        elif question_type == 'disease_do_food':
            sql1 = ["MATCH (m:Disease)-[r:do_eat]->(n:Food) where m.name = '{0}' return m.name, r.name, n.name".format(i) for i in entities]
            sql2 = ["MATCH (m:Disease)-[r:recommand_eat]->(n:Food) where m.name = '{0}' return m.name, r.name, n.name".format(i) for i in entities]
            sql = sql1 + sql2

        # Given a food to avoid, find the related diseases
        elif question_type == 'food_not_disease':
            sql = ["MATCH (m:Disease)-[r:no_eat]->(n:Food) where n.name = '{0}' return m.name, r.name, n.name".format(i) for i in entities]

        # Given a recommended food, find the related diseases
        elif question_type == 'food_do_disease':
            sql1 = ["MATCH (m:Disease)-[r:do_eat]->(n:Food) where n.name = '{0}' return m.name, r.name, n.name".format(i) for i in entities]
            sql2 = ["MATCH (m:Disease)-[r:recommand_eat]->(n:Food) where n.name = '{0}' return m.name, r.name, n.name".format(i) for i in entities]
            sql = sql1 + sql2

        # Drugs commonly used for a disease (remember to expand drug aliases)
        elif question_type == 'disease_drug':
            sql1 = ["MATCH (m:Disease)-[r:common_drug]->(n:Drug) where m.name = '{0}' return m.name, r.name, n.name".format(i) for i in entities]
            sql2 = ["MATCH (m:Disease)-[r:recommand_drug]->(n:Drug) where m.name = '{0}' return m.name, r.name, n.name".format(i) for i in entities]
            sql = sql1 + sql2

        # Given a drug, find the diseases it can treat
        elif question_type == 'drug_disease':
            sql1 = ["MATCH \
(m:Disease)-[r:common_drug]->(n:Drug) where n.name = '{0}' return m.name, r.name, n.name".format(i) for i in entities]
            sql2 = ["MATCH (m:Disease)-[r:recommand_drug]->(n:Drug) where n.name = '{0}' return m.name, r.name, n.name".format(i) for i in entities]
            sql = sql1 + sql2

        # Checks required for a disease
        elif question_type == 'disease_check':
            sql = ["MATCH (m:Disease)-[r:need_check]->(n:Check) where m.name = '{0}' return m.name, r.name, n.name".format(i) for i in entities]

        # Given a check, find the related diseases
        elif question_type == 'check_disease':
            sql = ["MATCH (m:Disease)-[r:need_check]->(n:Check) where n.name = '{0}' return m.name, r.name, n.name".format(i) for i in entities]

        return sql



if __name__ == '__main__':
    handler = QuestionPaser()

--------------------------------------------------------------------------------
/qwen7b_server.py:
--------------------------------------------------------------------------------
# coding = utf-8
import os
import torch
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
device = torch.device("cuda:0")
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.generation.utils import GenerationConfig
import json
from flask import Flask, request, jsonify


## Note: Qwen-7B-Chat must be downloaded separately (from huggingface or modelscope) and placed in the current directory

tokenizer = AutoTokenizer.from_pretrained("Qwen-7B-Chat", trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained("Qwen-7B-Chat", trust_remote_code=True)
model = model.to(device)
model.generation_config = GenerationConfig.from_pretrained("Qwen-7B-Chat")

def predict_model(data):
    # Only the first message of the request is used as the prompt
    text = data["message"][0]["content"]
    inputs = tokenizer(text, return_tensors='pt').to(device)
    outputs = model.generate(**inputs, max_new_tokens=data["max_tokens"], top_k=data["top_k"],
                             top_p=data["top_p"], temperature=data["temperature"],
                             repetition_penalty=data["repetition_penalty"], num_beams=data["num_beams"])
    # Decode only the newly generated tokens, skipping the prompt tokens
    response = tokenizer.decode(outputs[0][len(inputs["input_ids"][0]):], skip_special_tokens=True)
    return response

app = Flask(import_name=__name__)
@app.route("/generate", methods=["POST", "GET"])
def generate():
    data = json.loads(request.data)
    print(data)
    try:
        res = predict_model(data)
        label = "success"
    except Exception as e:
        res = ""
        label = "error"
        print(e)
    return jsonify({"output": [res], "status": label})

if __name__ == '__main__':
    app.run(port=3001, debug=False, host='0.0.0.0')  # binding to 0.0.0.0 makes the server reachable from other hosts

--------------------------------------------------------------------------------
/wechat.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/liuhuanyong/RAGOnMedicalKG/a037e393d1ebdf120b28118cbedc268020a14e0b/wechat.jpg
--------------------------------------------------------------------------------
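The `/generate` route in qwen7b_server.py reads seven keys from the posted JSON body (`message`, `max_tokens`, `top_k`, `top_p`, `temperature`, `repetition_penalty`, `num_beams`) and answers with `{"output": [...], "status": ...}`. A minimal client sketch follows, assuming the server is reachable on port 3001 of the local machine. The names `MODEL_URL`, `build_payload` and `call_generate`, and all default sampling values, are illustrative assumptions and not part of the repository.

```python
import json
import urllib.request

# Hypothetical address; point it at wherever qwen7b_server.py is running.
MODEL_URL = "http://127.0.0.1:3001/generate"


def build_payload(prompt, max_tokens=512, top_k=5, top_p=0.85,
                  temperature=0.7, repetition_penalty=1.1, num_beams=1):
    """Build a request body with exactly the keys predict_model() reads.

    The default values are placeholders chosen for illustration only.
    """
    return {
        "message": [{"content": prompt}],
        "max_tokens": max_tokens,
        "top_k": top_k,
        "top_p": top_p,
        "temperature": temperature,
        "repetition_penalty": repetition_penalty,
        "num_beams": num_beams,
    }


def call_generate(prompt, url=MODEL_URL, timeout=60):
    """POST the prompt to the model server and return the generated text."""
    body = json.dumps(build_payload(prompt)).encode("utf-8")
    req = urllib.request.Request(
        url,
        data=body,
        headers={"Content-Type": "application/json"},
        method="POST",
    )
    with urllib.request.urlopen(req, timeout=timeout) as resp:
        result = json.loads(resp.read().decode("utf-8"))
    # The server reports failures through the "status" field, not the HTTP code.
    if result.get("status") != "success":
        raise RuntimeError("model server reported an error")
    return result["output"][0]
```

With the server running, `call_generate("...")` would return the generated text as a string. Because the server catches every exception and still replies with `"status": "error"`, the sketch checks that field explicitly rather than relying on the HTTP status code.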