├── .gitignore ├── .idea ├── .gitignore ├── inspectionProfiles │ └── profiles_settings.xml ├── medical_kb_chatbot.iml ├── modules.xml └── vcs.xml ├── README.md ├── README_en.md ├── agent └── bing_search.py ├── app.py ├── chains ├── local_doc_qa.py ├── modules │ ├── embeddings.py │ ├── excel_load.py │ ├── json_load.py │ └── vectorstores.py └── text_load.py ├── configs ├── common_config.py └── test_ass.yaml ├── demo_data ├── kb_drug_demo.jsonl └── lora_demo.xlsx ├── environment.yml ├── finetune ├── pulse │ ├── configs │ │ └── lora_config_bloom.json │ ├── convert_to_conv_data.py │ ├── finetune.py │ └── src │ │ ├── sample_generator.py │ │ ├── trainer.py │ │ └── utils.py └── pulse_utils.py ├── img ├── 1.jpg ├── 2.jpg ├── 3.jpg └── 4.jpg ├── loader ├── __init__.py ├── image_loader.py ├── models │ ├── __init__.py │ ├── __main__.py │ ├── base.py │ ├── bloomz_llm.py │ ├── extensions │ │ ├── callback.py │ │ ├── extensions.py │ │ ├── llamacpp_model_alternative.py │ │ └── thread_with_exception.py │ ├── loader │ │ ├── __init__.py │ │ ├── args.py │ │ └── loader.py │ └── shared.py ├── pdf_loader.py ├── textsplitter │ ├── __init__.py │ ├── ali_text_splitter.py │ └── chinese_text_splitter.py └── utils │ └── __init__.py ├── requirements.txt └── vector_store └── drug_kb ├── index.faiss └── index.pkl /.gitignore: -------------------------------------------------------------------------------- 1 | **/__pycache__ -------------------------------------------------------------------------------- /.idea/.gitignore: -------------------------------------------------------------------------------- 1 | # Default ignored files 2 | /shelf/ 3 | /workspace.xml 4 | -------------------------------------------------------------------------------- /.idea/inspectionProfiles/profiles_settings.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 6 | -------------------------------------------------------------------------------- /.idea/medical_kb_chatbot.iml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 12 | -------------------------------------------------------------------------------- /.idea/modules.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /.idea/vcs.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | [[中文版](https://github.com/JuneYaooo/medical_kb_chatbot/blob/main/README.md)] [[English](https://github.com/JuneYaooo/medical_kb_chatbot/blob/main/README_en.md)] 2 | 3 | # 医疗知识聊天机器人 4 | 5 | 欢迎使用 医疗知识聊天机器人,这是一款基于 PULSE 模型,引入知识库及微调训练的聊天机器人,旨在提供更实用的医疗相关功能和服务。用户可以自己添加相关知识库,进行模型微调,体验更丰富的应用场景: 6 | 7 | ## 你可以用它来做什么 8 | 9 | - **药物查询**:提供药物数据库,用户可以搜索特定药物的信息,如用途、剂量、副作用等。 10 | 11 | - **病症解释**:提供常见疾病、症状和医学术语的解释和定义,帮助用户更好地理解医学知识。 12 | 13 | - **医疗客服**:添加相关医疗产品文档,支持用户与聊天机器人进行个性化对话,回答医疗产品相关问题,提供准确和可靠的信息。 14 | 15 | ## 使用方法 16 | 17 | ### 下载模型与修改配置文件 18 | 19 | 如果直接使用有问题,可以将PULSE模型下载到本地:https://huggingface.co/OpenMEDLab/PULSE-7bv5 20 | 21 | 然后在configs/common_config.py文件中将模型路径修改为本地路径,如修改embedding_model_dict和llm_model_dict中的路径即可。 22 | 23 | ### 安装 24 | 25 | 首先,克隆本项目到本地计算机: 26 | 27 | ``` 28 | git clone https://github.com/JuneYaooo/medical_kb_chatbot.git 
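# (示意)克隆后请按上文说明修改 configs/common_config.py 中的模型路径;
# 下面的键名与路径仅为假设示例,请以实际配置文件内容为准:
#   llm_model_dict = {"pulse": "/your/local/path/PULSE-7bv5"}
#   embedding_model_dict = {"your_embedding_model": "/your/local/path/your_embedding_model"}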
29 | ``` 30 | 31 | #### 使用 pip 安装 32 | 33 | 确保您的计算机上已安装以下依赖项: 34 | 35 | - Python 3.9 36 | - pip 包管理器 37 | 38 | 进入项目目录并安装必要的依赖项: 39 | 40 | ``` 41 | cd medical_kb_chatbot 42 | pip install -r requirements.txt 43 | ``` 44 | 45 | #### 使用 conda 安装 46 | 47 | 确保您的计算机上已安装以下依赖项: 48 | 49 | - Anaconda 或 Miniconda 50 | 51 | 进入项目目录并创建一个新的 conda 环境: 52 | 53 | ``` 54 | cd medical_kb_chatbot 55 | conda env create -f environment.yml 56 | ``` 57 | 58 | 激活新创建的环境: 59 | 60 | ``` 61 | conda activate kb_chat 62 | ``` 63 | 64 | 然后运行聊天机器人: 65 | 66 | ``` 67 | python app.py 68 | ``` 69 | 70 | ### 使用说明 71 | #### 可选择在知识库页面配置知识库 72 | - 支持excel、json、非图片类型的pdf、word、txt等格式 73 | - 其中excel、json需要按要求格式上传 74 | - 鼓励挂载一些医疗知识库尝试效果,有好的案例欢迎分享 75 | - 提供了一点点药品[demo数据](https://github.com/JuneYaooo/medical_kb_chatbot/blob/main/demo_data/kb_drug_demo.jsonl) ,可以下下来试一下 76 | 77 | 78 | ![知识库配置](img/2.jpg) 79 | 80 | #### 可选择使用lora微调模型 81 | - 微调目前最小需要24G显卡(~一张3090) 82 | - 微调结束后,可看到更新时间 83 | - 提供了一点点训练[demo数据](https://github.com/JuneYaooo/medical_kb_chatbot/blob/main/demo_data/lora_demo.xlsx) ,可以下下来试一下 84 | 85 | 86 | ![Lora微调](img/3.jpg) 87 | 88 | #### 在医疗小助手页面选择配置自己的知识库聊天小助手(可自由选择是否使用某个知识库/微调的lora) 89 | - 配置prompt可参考模板多尝试,有发现好的prompt欢迎分享 90 | - prompt 设置可以参考如下格式 91 | ``` 92 | 假设你是用药助手,请根据文档来回复,如果文档内容为空或者None,则忽略,文档:{context}\n{chat_history}User:{question}Helper: 93 | ``` 94 | 95 | ![配置](img/4.jpg) 96 | 97 | #### 配置好小助手,来对话测试页面试试吧 98 | 99 | - 选择一个已经配置好的聊天小助手,来体验一下吧 100 | 101 | ![使用](img/1.jpg) 102 | 103 | ## 致谢 104 | 105 | - [PULSE](https://github.com/openmedlab/PULSE): 本项目模型来源于PULSE 106 | - [langchain-ChatGLM](https://github.com/imClumsyPanda/langchain-ChatGLM): 本项目知识库部分参考了langchain-ChatGLM的代码 107 | - [BELLE](https://github.com/LianjiaTech/BELLE): 本项目Lora微调部分参考了BELLE的代码 108 | 109 | ## 贡献 110 | 111 | 如果您对该项目感兴趣,欢迎贡献您的代码和改进建议。您可以通过以下方式参与: 112 | 113 | 1. 提交问题和建议到本项目的 Issue 页面。 114 | 2. Fork 本项目并提交您的改进建议,我们将会审查并合并合适的改动。 115 | -------------------------------------------------------------------------------- /README_en.md: -------------------------------------------------------------------------------- 1 | [[Chinese Version](https://github.com/JuneYaooo/medical_kb_chatbot/blob/main/README.md)] [[English Version](https://github.com/JuneYaooo/medical_kb_chatbot/blob/main/README_en.md)] 2 | 3 | # Medical Knowledge Chatbot 4 | 5 | Welcome to the Medical Knowledge Chatbot. This is a chatbot based on the PULSE model, incorporating knowledge base and fine-tuning training, aiming to provide more practical medical-related functions and services. Users can add relevant knowledge bases and perform model fine-tuning to experience richer application scenarios. 6 | 7 | ## Example Applications of the Medical Chatbot 8 | 9 | - **Drug Query**: Provides a drug database where users can search for specific drug information such as uses, dosage, side effects, etc. 10 | 11 | - **Symptom Explanation**: Provides explanations and definitions of common diseases, symptoms, and medical terminologies to help users better understand medical knowledge. 12 | 13 | - **Medical Customer Service**: Adds relevant medical product documentation, supports personalized conversations with the chatbot, answers questions related to medical products, and provides accurate and reliable information. 
14 | 15 | ## Usage 16 | 17 | ### Download Model and Modify Configuration File 18 | 19 | If there are any issues with direct usage, you can download the PULSE model to your local machine from: [https://huggingface.co/OpenMEDLab/PULSE-7bv5](https://huggingface.co/OpenMEDLab/PULSE-7bv5) 20 | 21 | Then, modify the model path in the `configs/common_config.py` file to the local path. You can modify the paths in the `embedding_model_dict` and `llm_model_dict` variables. 22 | 23 | ### Installation 24 | 25 | First, clone this project to your local machine: 26 | 27 | ``` 28 | git clone https://github.com/JuneYaooo/medical_kb_chatbot.git 29 | ``` 30 | 31 | #### Install using pip 32 | 33 | Make sure the following dependencies are installed on your machine: 34 | 35 | - Python 3.9 36 | - pip package manager 37 | 38 | Navigate to the project directory and install the necessary dependencies: 39 | 40 | ``` 41 | cd medical_kb_chatbot 42 | pip install -r requirements.txt 43 | ``` 44 | 45 | #### Install using conda 46 | 47 | Make sure the following dependencies are installed on your machine: 48 | 49 | - Anaconda or Miniconda 50 | 51 | Navigate to the project directory and create a new conda environment: 52 | 53 | ``` 54 | cd medical_kb_chatbot 55 | conda env create -f environment.yml 56 | ``` 57 | 58 | Activate the newly created environment: 59 | 60 | ``` 61 | conda activate kb_chat 62 | ``` 63 | 64 | Then run the chatbot: 65 | 66 | ``` 67 | python app.py 68 | ``` 69 | 70 | ### Instructions for Use 71 | 72 | #### Configure Knowledge Bases on the Knowledge Base Page 73 | 74 | - Supports formats such as Excel, JSON, non-image PDFs, Word documents, TXT files, etc. 75 | - For Excel and JSON formats, they need to be uploaded in the required format. 76 | - It is encouraged to mount some medical knowledge bases to try out the effectiveness. Good examples are welcome to be shared. 77 | 78 | ![Knowledge Base Configuration](img/2.jpg) 79 | 80 | #### Fine-Tune the Model using Lora 81 | 82 | - Fine-tuning currently requires a minimum of 24GB GPU (~one 3090). 83 | - After fine-tuning, you can see the update time. 84 | 85 | ![Lora Fine-Tuning](img/3.jpg) 86 | 87 | #### Configure Your Knowledge Base Chatbot on the Medical Chatbot Page (Optional to use a specific knowledge base/fine-tuned Lora) 88 | 89 | - You can configure your own knowledge base chatbot by selecting whether to use a particular knowledge base or the fine-tuned Lora. 90 | - Refer to the template for configuring prompts and try out different options. Good prompts are welcome to be shared. 91 | 92 | ![Configuration](img/4.jpg) 93 | 94 | #### Once the chatbot is configured, try it out on the conversation test page 95 | 96 | - Select a pre-configured chatbot and give it a try. 97 | 98 | ![Usage](img/1.jpg) 99 | 100 | ## Acknowledgments 101 | 102 | - [PULSE](https://github.com/openmedlab/PULSE): The model used in this project is based on PULSE. 103 | - [langchain-ChatGLM](https://github.com/imClumsyPanda/langchain-ChatGLM): The code for the knowledge base part of this project was inspired by langchain-ChatGLM. 104 | - [BELLE](https://github.com/LianjiaTech/BELLE): The code for the Lora fine-tuning part of this project was inspired by BELLE. 105 | 106 | ## Contribution 107 | 108 | If you are interested in this project, you are welcome to contribute your code and improvement suggestions. You can participate in the following ways: 109 | 110 | 1. Submit issues and suggestions on the Issue page of this project. 111 | 2. 
Fork this project, make your improvements, and submit a pull request. We will review and merge appropriate changes. -------------------------------------------------------------------------------- /agent/bing_search.py: -------------------------------------------------------------------------------- 1 | #coding=utf8 2 | 3 | from langchain.utilities import BingSearchAPIWrapper 4 | from configs.common_config import BING_SEARCH_URL, BING_SUBSCRIPTION_KEY 5 | 6 | 7 | def bing_search(text, result_len=3): 8 | if not (BING_SEARCH_URL and BING_SUBSCRIPTION_KEY): 9 | return [{"snippet": "please set BING_SUBSCRIPTION_KEY and BING_SEARCH_URL in os ENV", 10 | "title": "env inof not fould", 11 | "link": "https://python.langchain.com/en/latest/modules/agents/tools/examples/bing_search.html"}] 12 | search = BingSearchAPIWrapper(bing_subscription_key=BING_SUBSCRIPTION_KEY, 13 | bing_search_url=BING_SEARCH_URL) 14 | return search.results(text, result_len) 15 | 16 | 17 | if __name__ == "__main__": 18 | r = bing_search('python') 19 | print(r) 20 | -------------------------------------------------------------------------------- /app.py: -------------------------------------------------------------------------------- 1 | import gradio as gr 2 | import shutil 3 | 4 | from chains.local_doc_qa import LocalDocQA 5 | from configs.common_config import * 6 | import nltk 7 | from loader.models.base import (BaseAnswer, 8 | AnswerResult, 9 | AnswerResultStream, 10 | AnswerResultQueueSentinelTokenListenerQueue) 11 | import loader.models.shared as shared 12 | from loader.models.loader.args import parser 13 | from loader.models.loader import LoaderCheckPoint 14 | from finetune.pulse_utils import pulse_train_model, stop_train_process 15 | import shutil 16 | import time 17 | import datetime 18 | import re 19 | import os 20 | import glob 21 | 22 | def get_file_modify_time(filename): 23 | try: 24 | return datetime.datetime.fromtimestamp(os.stat(filename).st_mtime).strftime("%Y-%m-%d %H:%M:%S") 25 | except Exception as e: 26 | print('Failed to get modification time for {}'.format(filename)) 27 | print(e) 28 | return 'not available' 29 | 30 | def get_model_update_time(model_name, lora_name): 31 | if 'pulse' in model_name.lower(): 32 | update_time = get_file_modify_time(f"finetune/pulse/output/{lora_name}/adapter_model.bin") 33 | else: 34 | update_time = 'not available' 35 | return update_time 36 | 37 | def on_train(model_name, lora_name, training_data_file): 38 | training_data_path = 'data/'+os.path.basename(training_data_file.name) 39 | if 'pulse' in model_name.lower(): 40 | msg = pulse_train_model(model_name, lora_name, training_data_path) 41 | else: 42 | msg = 'please select one model!' 
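# 注:目前仅 pulse 系列模型支持微调;训练数据路径假定上传文件已先由下方的 upload_file 移动到 data/ 目录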
43 | return msg 44 | nltk.data.path = [NLTK_DATA_PATH] + nltk.data.path 45 | 46 | 47 | def upload_file(file): 48 | print('file',file) 49 | if not os.path.exists("data"): 50 | os.mkdir("data") 51 | filename = os.path.basename(file.name) 52 | shutil.move(file.name, "data/" + filename) 53 | # file_list首位插入新上传的文件 54 | filedir = "data/" + filename 55 | return filedir 56 | 57 | 58 | def get_vs_list(): 59 | lst_default = [] 60 | if not os.path.exists(VS_ROOT_PATH): 61 | return lst_default 62 | lst = os.listdir(VS_ROOT_PATH) 63 | if not lst: 64 | return lst_default 65 | lst.sort() 66 | return lst_default + lst 67 | 68 | 69 | 70 | def get_yaml_files(folder_path): 71 | yaml_files = glob.glob(os.path.join(folder_path, '*.yaml')) 72 | file_names = [os.path.splitext(os.path.basename(file))[0] for file in yaml_files] 73 | return file_names 74 | 75 | yaml_files = get_yaml_files('configs') 76 | 77 | 78 | vs_list = get_vs_list() 79 | 80 | 81 | embedding_model_dict_list = list(embedding_model_dict.keys()) 82 | 83 | llm_model_dict_list = list(llm_model_dict.keys()) 84 | 85 | 86 | local_doc_qa = LocalDocQA() 87 | 88 | flag_csv_logger = gr.CSVLogger() 89 | 90 | 91 | def read_config(ass_name_en): 92 | config_file_path = f"configs/{ass_name_en}.yaml" 93 | with open(config_file_path, 'r', encoding='utf-8') as file: 94 | yaml = ruamel.yaml.YAML() 95 | config = yaml.load(file) 96 | ass_name = config['ass_name'] 97 | llm_model = config['llm_model'] 98 | embedding_model = config['embedding_model'] 99 | llm_history_len = config['llm_history_len'] 100 | lora_name = config['lora_name'] 101 | top_k = config['top_k'] 102 | score_threshold = config['score_threshold'] 103 | chunk_content = config['chunk_content'] 104 | chunk_sizes = config['chunk_sizes'] 105 | show_reference = config['show_reference'] 106 | knowledge_set_name = config['knowledge_set_name'] 107 | prompt_template = config['prompt_template'] 108 | 109 | return ass_name_en,ass_name, llm_model, embedding_model, lora_name, llm_history_len, knowledge_set_name, top_k, score_threshold, chunk_content, chunk_sizes, show_reference,prompt_template 110 | 111 | def remove_html_tags(text): 112 | clean_text = re.sub('<.*?>', '', text) 113 | return clean_text 114 | 115 | from pynvml import (nvmlInit, nvmlDeviceGetCount, nvmlDeviceGetHandleByIndex, 116 | nvmlDeviceGetName, nvmlDeviceGetMemoryInfo, nvmlShutdown) 117 | def get_available_gpu(threshold=20000): 118 | # Initialize NVML 119 | nvmlInit() 120 | # Get the number of GPU devices 121 | device_count = nvmlDeviceGetCount() 122 | 123 | # Find GPU devices with available memory greater than the threshold 124 | available_gpus = [] 125 | for i in range(device_count): 126 | handle = nvmlDeviceGetHandleByIndex(i) 127 | info = nvmlDeviceGetMemoryInfo(handle) 128 | free_memory_mb = info.free / 1024 / 1024 129 | 130 | if free_memory_mb > threshold: 131 | available_gpus.append(i) 132 | 133 | # Shutdown NVML 134 | nvmlShutdown() 135 | # available_gpus = ['0'] 136 | 137 | return available_gpus 138 | 139 | 140 | def get_free_memory(): 141 | nvmlInit() 142 | # Get the number of GPU devices 143 | device_count = nvmlDeviceGetCount() 144 | 145 | # Find GPU devices with available memory greater than the threshold 146 | free_memory_gpus = [] 147 | for i in range(device_count): 148 | handle = nvmlDeviceGetHandleByIndex(i) 149 | info = nvmlDeviceGetMemoryInfo(handle) 150 | free_memory_mb = info.free / 1024 / 1024 151 | free_memory_gpus.append(free_memory_mb) 152 | 153 | # Shutdown NVML 154 | nvmlShutdown() 155 | return free_memory_gpus 156 | 157 
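# get_chat_answer:对话测试页的核心生成器。
# 大致流程:读取 configs/<ass_id>.yaml 中的小助手配置 -> 按 llm_history_len 截断历史 ->
# 若 LLM / Embedding 尚未加载,则选择剩余显存大于 20G 的第一块 GPU 进行加载(显存不足时直接返回提示)->
# 配置了知识库时调用 get_knowledge_based_answer 并附带出处,否则调用 get_base_answer 进行纯聊天。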
| def get_chat_answer(query, ass_id, history, score_threshold=VECTOR_SEARCH_SCORE_THRESHOLD, 158 | vector_search_top_k=VECTOR_SEARCH_TOP_K, chunk_content: bool = True, 159 | chunk_size=CHUNK_SIZE, streaming: bool = STREAMING): 160 | ass_name_en,ass_name, llm_model, embedding_model, lora_name, llm_history_len, knowledge_set_name, top_k, score_threshold, chunk_content, chunk_sizes,show_reference,prompt_template = read_config(ass_id) 161 | if len(history)>0 and len(history[-1])>0: 162 | history[-1][-1] = remove_html_tags(history[-1][-1]) 163 | history = history[-llm_history_len:] if history is not None and len(history) > llm_history_len else history 164 | local_doc_qa.top_k = top_k 165 | local_doc_qa.score_threshold = score_threshold 166 | local_doc_qa.chunk_content = chunk_content 167 | local_doc_qa.chunk_size = chunk_size 168 | available_gpus = get_available_gpu(threshold=20000) 169 | if local_doc_qa.llm is None: 170 | if len(available_gpus)>0: 171 | available_gpu = available_gpus[0] 172 | target_device = torch.device(f'cuda:{str(available_gpu)}') 173 | else: 174 | yield [[None,'GPU空间不够,请至少确保机器上GPU剩余空间>20G']],'','' 175 | args_dict = {'model':llm_model, 'lora':lora_name} if lora_name != '不使用' else {'model':llm_model} 176 | shared.loaderCheckPoint = LoaderCheckPoint(args_dict) 177 | llm_model_ins = shared.loaderLLM() 178 | llm_model_ins.set_history_len(llm_history_len) 179 | local_doc_qa.init_cfg(llm_model=llm_model_ins,embedding_model=embedding_model, 180 | embedding_device=target_device) 181 | if local_doc_qa.embeddings is None: 182 | if len(available_gpus)>0: 183 | available_gpu = available_gpus[0] 184 | target_device = torch.device(f'cuda:{str(available_gpu)}') 185 | else: 186 | yield [[None,'GPU空间不够,请至少确保机器上GPU剩余空间>20G']],'','' 187 | local_doc_qa.init_embedding(embedding_model=embedding_model,embedding_device=target_device) 188 | if knowledge_set_name !='不使用知识库': 189 | vs_path = os.path.join(VS_ROOT_PATH, knowledge_set_name) 190 | for resp, history in local_doc_qa.get_knowledge_based_answer(model_name=llm_model, 191 | query=query, vs_path=vs_path, prompt_template=prompt_template,chat_history=history, streaming=streaming): 192 | if len(resp["source_documents"])>0: 193 | source = "" 194 | source += "".join( 195 | [f"""出处 [{i + 1}] {os.path.split(doc.metadata["source"])[-1]}\n""" 196 | f"""{doc.page_content}\n""" 197 | for i, doc in 198 | enumerate(resp["source_documents"])]) 199 | else: 200 | source = "暂无" 201 | 202 | yield history, "", source 203 | else: 204 | print('纯聊天模式') 205 | history = history 206 | for answer_result, history in local_doc_qa.get_base_answer(model_name=llm_model, 207 | query=query, prompt_template=prompt_template,chat_history=history, streaming=streaming): 208 | yield history, "", "未挂载知识库" 209 | 210 | def get_knowledge_search(query, vs_path, history, score_threshold=VECTOR_SEARCH_SCORE_THRESHOLD, 211 | vector_search_top_k=VECTOR_SEARCH_TOP_K, chunk_content: bool = True, 212 | chunk_size=CHUNK_SIZE, streaming: bool = STREAMING): 213 | if local_doc_qa.embeddings is None: 214 | local_doc_qa.init_embedding() 215 | if os.path.exists(vs_path): 216 | resp, prompt = local_doc_qa.get_knowledge_based_conent_test(query=query, vs_path=vs_path, 217 | score_threshold=score_threshold, 218 | vector_search_top_k=vector_search_top_k, 219 | chunk_content=chunk_content, 220 | chunk_size=chunk_size) 221 | if not resp["source_documents"]: 222 | yield history + [[query, 223 | "根据您的设定,没有匹配到任何内容,请确认您设置的知识相关度 Score 阈值是否过小或其他参数是否正确。"]], "" 224 | else: 225 | source = "\n".join( 226 | [ 227 | f"""
【知识相关度 Score】:{doc.metadata["score"]} - 【出处{i + 1}】: {os.path.split(doc.metadata["source"])[-1]} \n""" 228 | f"""{doc.page_content}\n""" 229 | f"""
""" 230 | for i, doc in 231 | enumerate(resp["source_documents"])]) 232 | history.append([query, "以下内容为知识库中满足设置条件的匹配结果:\n\n" + source]) 233 | yield history, "" 234 | else: 235 | yield history + [[query, 236 | "请选择知识库后进行测试,当前未选择知识库。"]], "" 237 | 238 | 239 | def change_assistant_input(ass_id): 240 | 241 | ass_name_en,ass_name, llm_model, embedding_model, lora_name, llm_history_len, knowledge_set_name, top_k, score_threshold, chunk_content, chunk_sizes,show_reference,prompt_template = read_config(ass_id) 242 | 243 | init_hello = f"你好,我是{ass_name}" 244 | 245 | if show_reference: 246 | return [[None,init_hello]], gr.update(visible=True) 247 | else: 248 | return [[None,init_hello]], gr.update(visible=False) 249 | 250 | 251 | 252 | import ruamel.yaml 253 | def set_config(ass_name_en,ass_name, llm_model, embedding_model, lora_name, llm_history_len, knowledge_set_name, top_k, score_threshold, chunk_content, chunk_sizes,show_reference, prompt_template, ass_list): 254 | config = { 255 | 'ass_name':ass_name, 256 | 'llm_model': llm_model, 257 | 'embedding_model': embedding_model, 258 | 'lora_name': lora_name, 259 | 'llm_history_len': llm_history_len, 260 | 'knowledge_set_name':knowledge_set_name, 261 | 'top_k': top_k, 262 | 'score_threshold':score_threshold, 263 | 'chunk_content':chunk_content, 264 | 'chunk_sizes':chunk_sizes, 265 | 'show_reference':show_reference, 266 | 'prompt_template':prompt_template 267 | } 268 | yaml = ruamel.yaml.YAML() 269 | with open(f'configs/{ass_name_en}.yaml', 'w', encoding="utf-8") as file: 270 | yaml.dump(config, file) 271 | 272 | return gr.update(visible=True),f'configs/{ass_name_en}.yaml 保存成功!',gr.update(visible=True, choices=ass_list, value=ass_list[0]) 273 | 274 | 275 | def get_vector_store(vs_id, files, sentence_size, history, one_conent, one_content_segmentation): 276 | if local_doc_qa.embeddings is None: 277 | local_doc_qa.init_embedding() 278 | 279 | vs_path = os.path.join(VS_ROOT_PATH, vs_id) 280 | filelist = [] 281 | if not os.path.exists(os.path.join(UPLOAD_ROOT_PATH, vs_id)): 282 | os.makedirs(os.path.join(UPLOAD_ROOT_PATH, vs_id)) 283 | if isinstance(files, list): 284 | for file in files: 285 | filename = os.path.split(file.name)[-1] 286 | shutil.move(file.name, os.path.join(UPLOAD_ROOT_PATH, vs_id, filename)) 287 | filelist.append(os.path.join(UPLOAD_ROOT_PATH, vs_id, filename)) 288 | vs_path, loaded_files = local_doc_qa.init_knowledge_vector_store(filelist, vs_path, sentence_size) 289 | else: 290 | vs_path, loaded_files = local_doc_qa.one_knowledge_add(vs_path, files, one_conent, one_content_segmentation, 291 | sentence_size) 292 | if len(loaded_files): 293 | file_status = f"已添加 {'、'.join([os.path.split(i)[-1] for i in loaded_files if i])} 内容至知识库,并已加载知识库,请开始提问" 294 | else: 295 | file_status = "文件未成功加载,请重新上传文件" 296 | # if local_doc_qa.llm and local_doc_qa.embeddings: 297 | 298 | # else: 299 | # file_status = "模型未完成加载,请先在加载模型后再导入文件" 300 | # vs_path = None 301 | logger.info(file_status) 302 | return vs_path, None, history + [[None, file_status]] 303 | 304 | 305 | def change_vs_name_input(vs_id, history): 306 | if vs_id == "新建知识库": 307 | return gr.update(visible=True), gr.update(visible=True), gr.update(visible=False), None, history 308 | else: 309 | file_status = f"已加载知识库{vs_id},请开始提问" 310 | return gr.update(visible=False), gr.update(visible=False), gr.update(visible=True), os.path.join(VS_ROOT_PATH, 311 | vs_id), history + [ 312 | [None, file_status]] 313 | 314 | 315 | knowledge_base_test_mode_info = ("【注意】\n\n" 316 | "1.您已进入知识库测试模式,仅用于测试知识库相关参数配置\n\n" 317 | 
"2.知识相关度 Score,建议设置为 500~800,具体设置情况请结合实际使用调整。\n\n" 318 | "3.目前支持的处理格式包含非图片格式的pdf、word、txt、md、excel、json、jsonl格式\n\n" 319 | "其中excel格式可包含多个数据表,每个数据表里必须包含两列:问题|回答\n\n" 320 | """其中json、jsonl格式处理成多行,每行一个dict类似这样:{"docs": ["氨力农注射液 药物分类\n化学药品", "氨力农注射液 药物剂量\n10毫升:50毫克"]}""") 321 | 322 | 323 | def change_mode(mode, history): 324 | if mode == "知识库问答": 325 | return gr.update(visible=True), gr.update(visible=False), history 326 | # + [[None, "【注意】:您已进入知识库问答模式,您输入的任何查询都将进行知识库查询,然后会自动整理知识库关联内容进入模型查询!!!"]] 327 | elif mode == "知识库配置": 328 | return gr.update(visible=True), gr.update(visible=True), [[None, 329 | knowledge_base_test_mode_info]] 330 | else: 331 | return gr.update(visible=False), gr.update(visible=False), history 332 | 333 | 334 | def change_chunk_content(mode, label_conent, history): 335 | conent = "" 336 | if "chunk_content" in label_conent: 337 | conent = "搜索结果上下文关联" 338 | elif "one_content_segmentation" in label_conent: # 这里没用上,可以先留着 339 | conent = "内容分段入库" 340 | 341 | if mode: 342 | return gr.update(visible=True), history + [[None, f"【已开启{conent}】"]] 343 | else: 344 | return gr.update(visible=False), history + [[None, f"【已关闭{conent}】"]] 345 | 346 | 347 | def add_vs_name(vs_name, vs_list, chatbot): 348 | if not os.path.exists(VS_ROOT_PATH): 349 | os.makedirs(VS_ROOT_PATH) 350 | if vs_name in vs_list: 351 | vs_status = "与已有知识库名称冲突,请重新选择其他名称后提交" 352 | chatbot = chatbot + [[None, vs_status]] 353 | return gr.update(visible=True), vs_list, gr.update(visible=True), gr.update(visible=True), gr.update( 354 | visible=False), chatbot 355 | else: 356 | vs_status = f"""已新增知识库"{vs_name}",将在上传文件并载入成功后进行存储。请在开始对话前,先完成文件上传。 """ 357 | chatbot = chatbot + [[None, vs_status]] 358 | return gr.update(visible=True, choices=[vs_name] + vs_list, value=vs_name), [vs_name] + vs_list, gr.update( 359 | visible=False), gr.update(visible=False), gr.update(visible=True), chatbot 360 | 361 | def change_lora_name_input(model_name,lora_name_en): 362 | if lora_name_en == "新建Lora": 363 | return gr.update(visible=True), gr.update(visible=True) 364 | else: 365 | file_status = f"已加载{lora_name_en}" 366 | model_update_time = get_model_update_time(model_name, lora_name_en) 367 | return gr.update(visible=False), gr.update(visible=False), model_update_time 368 | 369 | 370 | def add_lora(lora_name_en,lora_list): 371 | if lora_name_en in lora_list: 372 | print('名称冲突,不新建') 373 | return gr.update(visible=True,value=lora_name_en), gr.update(visible=False), gr.update(visible=False), lora_list 374 | else: 375 | return gr.update(visible=True, choices=[lora_name_en] + lora_list, value=lora_name_en), gr.update(visible=False), gr.update(visible=False),[lora_name_en] + lora_list 376 | 377 | def change_assistant_name_input(ass_id): 378 | if ass_id == "新建小助手": 379 | return gr.update(visible=True), gr.update(visible=True), gr.update(visible=False), gr.update(visible=False), '医疗小助手', LLM_MODEL,EMBEDDING_MODEL,'不使用',LLM_HISTORY_LEN,cur_vs_list.value[0] if len(cur_vs_list.value) > 1 else '不使用知识库',VECTOR_SEARCH_TOP_K,500,True,250,True,'' 380 | else: 381 | try: 382 | ass_name_en,ass_name, llm_model, embedding_model, lora_name, llm_history_len, knowledge_set_name, top_k, score_threshold, chunk_content, chunk_sizes,show_reference,prompt_template = read_config(ass_id) 383 | file_status = f"已加载{ass_id}" 384 | print('file_status',file_status) 385 | return gr.update(visible=False), gr.update(visible=False), gr.update(visible=True), gr.update(visible=False), ass_name, llm_model, embedding_model, lora_name, llm_history_len, knowledge_set_name, top_k, 
score_threshold, chunk_content, chunk_sizes,show_reference,prompt_template 386 | except Exception as e: 387 | return gr.update(visible=False), gr.update(visible=False), gr.update(visible=True), gr.update(visible=False), '医疗小助手', LLM_MODEL,EMBEDDING_MODEL,'不使用',LLM_HISTORY_LEN,cur_vs_list.value[0] if len(cur_vs_list.value) > 1 else '不使用知识库',VECTOR_SEARCH_TOP_K,500,True,250,True,'' 388 | 389 | 390 | def add_ass_config(ass_id,ass_list): 391 | if ass_id in ass_list: 392 | print('名称冲突,不新建') 393 | return ass_id, gr.update(visible=False), gr.update(visible=False), gr.update(visible=True), gr.update(visible=False), '医疗小助手', LLM_MODEL,EMBEDDING_MODEL,'不使用',LLM_HISTORY_LEN,cur_vs_list.value[0] if len(cur_vs_list.value) > 1 else '不使用知识库',VECTOR_SEARCH_TOP_K,500,True,250,True,'',ass_list,gr.update(visible=True) 394 | else: 395 | return ass_id, gr.update(visible=False), gr.update(visible=False), gr.update(visible=True), gr.update(visible=False), '医疗小助手', LLM_MODEL,EMBEDDING_MODEL,'不使用',LLM_HISTORY_LEN,cur_vs_list.value[0] if len(cur_vs_list.value) > 1 else '不使用知识库',VECTOR_SEARCH_TOP_K,500,True,250,True,'',ass_list+[ass_id],gr.update(visible=True, choices=ass_list+[ass_id], value=ass_id) 396 | 397 | def find_folders(directory): 398 | folders = [] 399 | for item in os.listdir(directory): 400 | item_path = os.path.join(directory, item) 401 | if os.path.isdir(item_path): 402 | folders.append(item) 403 | return folders 404 | 405 | def change_model_name_input(model_name): 406 | if 'pulse' in model_name.lower(): 407 | model_name = 'pulse' 408 | else: 409 | model_name = '' 410 | model_dir = os.path.join(f"finetune", model_name,'output') 411 | lora_list = find_folders(model_dir) 412 | return lora_list,gr.update(visible=True, choices=lora_list+["新建Lora"], value=lora_list[0] if len(lora_list)>0 else "新建Lora") 413 | 414 | def change_model_name_select(model_name): 415 | if 'pulse' in model_name.lower(): 416 | model_name = 'pulse' 417 | else: 418 | model_name = '' 419 | model_dir = os.path.join(f"finetune", model_name,'output') 420 | lora_list = find_folders(model_dir) 421 | return lora_list,gr.update(visible=True, choices=lora_list+["不使用"], value=lora_list[0] if len(lora_list)>0 else "不使用") 422 | 423 | block_css = """.importantButton { 424 | background: linear-gradient(45deg, #7e0570,#5d1c99, #6e00ff) !important; 425 | border: none !important; 426 | } 427 | .importantButton:hover { 428 | background: linear-gradient(45deg, #ff00e0,#8500ff, #6e00ff) !important; 429 | border: none !important; 430 | }""" 431 | 432 | webui_title = """ 433 | # 💁医疗知识聊天机器人💁 434 | """ 435 | default_vs = vs_list[0] if len(vs_list) > 1 else "为空" 436 | init_message = f"""请先在右侧选择小助手,再开始对话测试 437 | """ 438 | 439 | # 初始化消息 440 | args = None 441 | args = parser.parse_args() 442 | 443 | 444 | model_status = '请手动加载模型' 445 | 446 | default_theme_args = dict( 447 | font=["Source Sans Pro", 'ui-sans-serif', 'system-ui', 'sans-serif'], 448 | font_mono=['IBM Plex Mono', 'ui-monospace', 'Consolas', 'monospace'], 449 | ) 450 | 451 | def get_lora_init_list(model_name): 452 | if 'pulse' in model_name.lower(): 453 | model_name = 'pulse' 454 | else: 455 | model_name = '' 456 | model_dir = os.path.join(f"finetune", model_name,'output') 457 | if not os.path.exists(model_dir): 458 | os.makedirs(model_dir) 459 | lora_list = find_folders(model_dir) 460 | return lora_list 461 | 462 | lora_init_list = get_lora_init_list(llm_model_dict_list[0]) 463 | 464 | with gr.Blocks(css=block_css, theme=gr.themes.Default(**default_theme_args)) as demo: 465 | vs_path, file_status, 
model_status, set_vs_list , cur_vs_list, set_lora_list, set_ass_list = gr.State( 466 | os.path.join(VS_ROOT_PATH, vs_list[0]) if len(vs_list) > 1 else ""), gr.State(""), gr.State( 467 | model_status), gr.State(vs_list+['新建知识库']), gr.State(vs_list+['不使用知识库']), gr.State(lora_init_list), gr.State(yaml_files) 468 | 469 | gr.Markdown(webui_title) 470 | with gr.Tab("对话测试"): 471 | with gr.Row(): 472 | with gr.Column(scale=10): 473 | chatbot = gr.Chatbot([[None, init_message]], 474 | elem_id="chat-box", 475 | show_label=False).style(height=750) 476 | query = gr.Textbox(show_label=False, 477 | placeholder="请输入提问内容,按回车进行提交").style(container=False) 478 | with gr.Column(scale=5): 479 | choose_ass = gr.Dropdown(yaml_files, 480 | label="选择要使用的小助手", 481 | value= yaml_files if len(yaml_files) > 1 else '暂无可使用的', 482 | interactive=True) 483 | reference = gr.Textbox(type="text", label='参考资料',visible=True) 484 | choose_ass.change(fn=change_assistant_input, 485 | inputs=[choose_ass], 486 | outputs=[chatbot,reference]) 487 | query.submit(get_chat_answer, 488 | [query, choose_ass, chatbot], 489 | [chatbot, query, reference]) 490 | 491 | with gr.Tab("知识库测试"): 492 | with gr.Row(): 493 | with gr.Column(scale=10): 494 | chatbot = gr.Chatbot([[None, knowledge_base_test_mode_info]], 495 | elem_id="chat-box", 496 | show_label=False).style(height=750) 497 | query = gr.Textbox(show_label=False, 498 | placeholder="请输入提问内容,按回车进行提交").style(container=False) 499 | with gr.Column(scale=5): 500 | knowledge_set = gr.Accordion("知识库设定", visible=True) 501 | vs_setting = gr.Accordion("配置知识库", visible=True) 502 | with knowledge_set: 503 | score_threshold = gr.Number(value=VECTOR_SEARCH_SCORE_THRESHOLD, 504 | label="知识相关度 Score 阈值,分值越低匹配度越高", 505 | precision=0, 506 | interactive=True) 507 | vector_search_top_k = gr.Number(value=VECTOR_SEARCH_TOP_K, precision=0, 508 | label="获取知识库内容条数", interactive=True) 509 | chunk_content = gr.Checkbox(value=False, 510 | label="是否启用上下文关联", 511 | interactive=True) 512 | chunk_sizes = gr.Number(value=CHUNK_SIZE, precision=0, 513 | label="匹配单段内容的连接上下文后最大长度", 514 | interactive=True, visible=True) 515 | chunk_content.change(fn=change_chunk_content, 516 | inputs=[chunk_content, gr.Textbox(value="chunk_content", visible=False), chatbot], 517 | outputs=[chunk_sizes, chatbot]) 518 | with vs_setting: 519 | select_vs = gr.Dropdown(set_vs_list.value, 520 | label="请选择要加载的知识库", 521 | interactive=True, 522 | value=set_vs_list.value[0] if len(set_vs_list.value) > 0 else None) 523 | vs_name = gr.Textbox(label="请输入新建知识库名称,当前知识库命名暂不支持中文", 524 | lines=1, 525 | interactive=True, 526 | visible=True) 527 | vs_add = gr.Button(value="添加至知识库选项", visible=True) 528 | file2vs = gr.Column(visible=False) 529 | with file2vs: 530 | # load_vs = gr.Button("加载知识库") 531 | gr.Markdown("向知识库中添加单条内容或文件") 532 | sentence_size = gr.Number(value=SENTENCE_SIZE, precision=0, 533 | label="文本入库分句长度限制", 534 | interactive=True, visible=True) 535 | with gr.Tab("上传文件"): 536 | files = gr.File(label="添加文件", 537 | file_types=['.txt', '.md', '.docx', '.pdf', '.jsonl'], 538 | file_count="multiple", 539 | show_label=False 540 | ) 541 | load_file_button = gr.Button("上传文件并加载知识库") 542 | with gr.Tab("上传文件夹"): 543 | folder_files = gr.File(label="添加文件", 544 | # file_types=['.txt', '.md', '.docx', '.pdf'], 545 | file_count="directory", 546 | show_label=False) 547 | load_folder_button = gr.Button("上传文件夹并加载知识库") 548 | with gr.Tab("添加单条内容"): 549 | one_title = gr.Textbox(label="标题", placeholder="请输入要添加单条段落的标题", lines=1) 550 | one_conent = gr.Textbox(label="内容", 
placeholder="请输入要添加单条段落的内容", lines=5) 551 | one_content_segmentation = gr.Checkbox(value=True, label="禁止内容分句入库", 552 | interactive=True) 553 | load_conent_button = gr.Button("添加内容并加载知识库") 554 | # 将上传的文件保存到content文件夹下,并更新下拉框 555 | vs_add.click(fn=add_vs_name, 556 | inputs=[vs_name, set_vs_list, chatbot], 557 | outputs=[select_vs, set_vs_list, vs_name, vs_add, file2vs, chatbot]) 558 | select_vs.change(fn=change_vs_name_input, 559 | inputs=[select_vs, chatbot], 560 | outputs=[vs_name, vs_add, file2vs, vs_path, chatbot]) 561 | load_file_button.click(get_vector_store, 562 | show_progress=True, 563 | inputs=[select_vs, files, sentence_size, chatbot, vs_add, vs_add], 564 | outputs=[vs_path, files, chatbot], ) 565 | load_folder_button.click(get_vector_store, 566 | show_progress=True, 567 | inputs=[select_vs, folder_files, sentence_size, chatbot, vs_add, 568 | vs_add], 569 | outputs=[vs_path, folder_files, chatbot], ) 570 | load_conent_button.click(get_vector_store, 571 | show_progress=True, 572 | inputs=[select_vs, one_title, sentence_size, chatbot, 573 | one_conent, one_content_segmentation], 574 | outputs=[vs_path, files, chatbot], ) 575 | # flag_csv_logger.setup([query, vs_path, chatbot, mode], "flagged") 576 | query.submit(get_knowledge_search, 577 | [query, vs_path, chatbot, score_threshold, vector_search_top_k, chunk_content, 578 | chunk_sizes], 579 | [chatbot, query]) 580 | with gr.Tab("lora微调"): 581 | with gr.Row(): 582 | with gr.Column(): 583 | model_name = gr.Radio(llm_model_dict_list, #'Bert', 584 | label="选择模型", 585 | value= llm_model_dict_list[0] if len(llm_model_dict_list)>0 else '暂无可选模型', 586 | interactive=True) 587 | with gr.Column(): 588 | select_lora = gr.Dropdown(set_lora_list.value+['新建Lora'], 589 | label= "选择或者新建一个Lora", 590 | value= set_lora_list.value[0] if len(set_lora_list.value) > 0 else '新建Lora', 591 | interactive=True) 592 | lora_name_en = gr.Textbox(label="请输入Lora英文名称,中间不能有空格,小写字母,单词间可用下划线分开", 593 | lines=1, 594 | interactive=True, 595 | visible=True) 596 | lora_add = gr.Button(value="确认添加Lora", visible=True) 597 | with gr.Row(): 598 | lastest_model = gr.outputs.Textbox(type="text", label='模型更新时间(请切换模型或Lora刷新显示)') 599 | gr.Markdown("## lora微调,目前只支持excel格式,要求语料格式为问题|回答两列,或者系统指示|问题|回答三列") 600 | train_data_file = gr.File(label="上传对话语料文件", file_types=['.xlsx']) 601 | train_button = gr.Button("开始训练", label="训练") 602 | kill_train_button = gr.Button("停止所有训练进程", label="训练") 603 | train_res = gr.outputs.Textbox(type="text", label='') 604 | train_data_file.upload(upload_file, 605 | inputs=train_data_file) 606 | train_button.click(on_train, inputs=[model_name, select_lora, train_data_file],outputs=[train_res]) 607 | model_name.change(fn=change_model_name_input, 608 | inputs=[model_name], 609 | outputs=[set_lora_list,select_lora]) 610 | select_lora.change(fn=change_lora_name_input, 611 | inputs=[model_name,select_lora], 612 | outputs=[lora_name_en, lora_add,lastest_model]) 613 | lora_add.click(fn=add_lora, 614 | inputs=[lora_name_en,set_lora_list], 615 | outputs=[select_lora, lora_name_en, lora_add,set_lora_list]) 616 | 617 | with gr.Tab("医疗小助手配置"): 618 | with gr.Column(): 619 | select_ass = gr.Dropdown(set_ass_list.value+['新建小助手'], 620 | label="选择或者新建一个医疗小助手", 621 | value= set_ass_list.value[0] if len(set_ass_list.value) > 0 else '新建小助手', 622 | interactive=True) 623 | ass_name_en = gr.Textbox(label="请输入小助手英文名称,中间不能有空格,小写字母,单词间可用下划线分开", 624 | lines=1, 625 | interactive=True, 626 | visible=True) 627 | ass_add = gr.Button(value="确认添加小助手", visible=True) 628 | ass_config = 
gr.Column(visible=False) 629 | with ass_config: 630 | ass_name = gr.Textbox(label="请给机器人取个名字,随便取,中英文均可,可以有空格", 631 | value='医疗小助手', 632 | lines=1, 633 | interactive=True, 634 | visible=True) 635 | llm_model = gr.Radio(llm_model_dict_list, 636 | label="LLM 模型", 637 | value=llm_model_dict_list[0] if len(llm_model_dict_list)>0 else '暂无可选模型', 638 | interactive=True) 639 | embedding_model = gr.Radio(embedding_model_dict_list, 640 | label="Embedding 模型", 641 | value=EMBEDDING_MODEL, 642 | interactive=True) 643 | lora_name = gr.Dropdown(set_lora_list.value+['不使用'], 644 | value='不使用', 645 | label="选择使用的Lora", 646 | interactive=True) 647 | llm_history_len = gr.Slider(0, 10, 648 | value=LLM_HISTORY_LEN, 649 | step=1, 650 | label="LLM 对话轮数", 651 | interactive=True) 652 | knowledge_set_name = gr.Dropdown(cur_vs_list.value, 653 | label="选择知识库", 654 | value= cur_vs_list.value[0] if len(cur_vs_list.value) > 1 else '不使用知识库', 655 | interactive=True) 656 | top_k = gr.Slider(1, 20, value=VECTOR_SEARCH_TOP_K, step=1, 657 | label="向量匹配 top k", interactive=True) 658 | score_threshold = gr.Number(value=700,label="知识相关度 Score 阈值,分值越低匹配度越高,数值范围约为0-1100,一般在700左右", 659 | precision=0, 660 | interactive=True) 661 | chunk_content = gr.Checkbox(value=True, 662 | label="是否启用上下文关联", 663 | interactive=True) 664 | chunk_sizes = gr.Number(value=250, precision=0, 665 | label="匹配单段内容的连接上下文后最大长度", 666 | interactive=True, visible=True) 667 | show_reference = gr.Checkbox(value=True, 668 | label="是否显示参考文献窗口", 669 | interactive=True) 670 | prompt_note = """prompt_template ,{context} 代表搜出来的文档,{chat_history}代表历史聊天记录,{question}代表最后一个问题,请在prompt里加上这些关键词。注意不使用知识库的情况下不生效。参考例子:假设你是用药助手,请根据文档来回复,如果文档内容为空或者None,则忽略,文档:{context}\n{chat_history}User: {question}Helper:""" 671 | gr.Markdown(prompt_note) 672 | prompt_template = gr.Textbox(label="内置prompt模板", 673 | value="假设你是用药助手,请根据文档来回复,如果文档内容为空或者None,则忽略,文档:{context}\n{chat_history}User:{question}Helper:", 674 | lines=8, 675 | interactive=True, 676 | visible=True) 677 | save_config_button = gr.Button("保存助手配置") 678 | save_res = gr.Textbox(type="text", label='', visible=False) 679 | save_config_button.click(set_config, show_progress=True, 680 | inputs=[select_ass,ass_name, llm_model, embedding_model, lora_name, llm_history_len, knowledge_set_name, top_k, score_threshold, chunk_content, chunk_sizes, show_reference, prompt_template,set_ass_list], outputs=[save_res,save_res,choose_ass]) 681 | 682 | llm_model.change(fn=change_model_name_select, 683 | inputs=[llm_model], 684 | outputs=[set_lora_list,lora_name]) 685 | select_ass.change(fn=change_assistant_name_input, 686 | inputs=[select_ass], 687 | outputs=[ass_name_en, ass_add, ass_config, save_res, ass_name, llm_model, embedding_model, lora_name, llm_history_len, knowledge_set_name, top_k, score_threshold, chunk_content, chunk_sizes, show_reference, prompt_template]) 688 | ass_add.click(fn=add_ass_config, 689 | inputs=[ass_name_en,set_ass_list], 690 | outputs=[select_ass, ass_name_en, ass_add, ass_config, save_res, ass_name, llm_model, embedding_model, lora_name, llm_history_len, knowledge_set_name, top_k, score_threshold, chunk_content, chunk_sizes, show_reference, prompt_template,set_ass_list,select_ass]) 691 | (demo 692 | .queue(concurrency_count=3) 693 | .launch(server_name='0.0.0.0', 694 | server_port=3355, 695 | show_api=False, 696 | share=True, 697 | debug= True, 698 | inbrowser=True)) -------------------------------------------------------------------------------- /chains/local_doc_qa.py: 
-------------------------------------------------------------------------------- 1 | from langchain.embeddings.huggingface import HuggingFaceEmbeddings 2 | from langchain.vectorstores import FAISS 3 | from langchain.document_loaders import UnstructuredFileLoader, TextLoader 4 | from configs.common_config import * 5 | import datetime 6 | from loader.textsplitter import ChineseTextSplitter 7 | from typing import List, Tuple, Dict 8 | from langchain.docstore.document import Document 9 | import numpy as np 10 | from loader.utils import torch_gc 11 | from tqdm import tqdm 12 | from pypinyin import lazy_pinyin 13 | from loader import UnstructuredPaddleImageLoader, UnstructuredPaddlePDFLoader 14 | from loader.models.base import (BaseAnswer, 15 | AnswerResult, 16 | AnswerResultStream, 17 | AnswerResultQueueSentinelTokenListenerQueue) 18 | from loader.models.loader.args import parser 19 | from loader.models.loader import LoaderCheckPoint 20 | import loader.models.shared as shared 21 | # from agent import bing_search 22 | from langchain.docstore.document import Document 23 | from .modules.json_load import JsonLoader 24 | from .modules.excel_load import ExcelLoader 25 | 26 | def load_file(filepath, sentence_size=SENTENCE_SIZE): 27 | if filepath.lower().endswith(".md"): 28 | loader = UnstructuredFileLoader(filepath, mode="elements") 29 | docs = loader.load() 30 | elif filepath.lower().endswith(".txt"): 31 | loader = TextLoader(filepath, autodetect_encoding=True) 32 | textsplitter = ChineseTextSplitter(pdf=False, sentence_size=sentence_size) 33 | docs = loader.load_and_split(textsplitter) 34 | elif filepath.lower().endswith(".pdf"): 35 | loader = UnstructuredPaddlePDFLoader(filepath) 36 | textsplitter = ChineseTextSplitter(pdf=True, sentence_size=sentence_size) 37 | docs = loader.load_and_split(textsplitter) 38 | elif filepath.lower().endswith(".jpg") or filepath.lower().endswith(".png"): 39 | loader = UnstructuredPaddleImageLoader(filepath, mode="elements") 40 | textsplitter = ChineseTextSplitter(pdf=False, sentence_size=sentence_size) 41 | docs = loader.load_and_split(text_splitter=textsplitter) 42 | elif filepath.lower().endswith(".jsonl") or filepath.lower().endswith(".json"): 43 | loader = JsonLoader(filepath, autodetect_encoding=True) 44 | docs = loader.load_json() 45 | elif filepath.lower().endswith(".xlsx"): 46 | loader = ExcelLoader(filepath, autodetect_encoding=True) 47 | docs = loader.load_excel() 48 | else: 49 | loader = UnstructuredFileLoader(filepath, mode="elements") 50 | textsplitter = ChineseTextSplitter(pdf=False, sentence_size=sentence_size) 51 | docs = loader.load_and_split(text_splitter=textsplitter) 52 | write_check_file(filepath, docs) 53 | return docs 54 | 55 | 56 | def write_check_file(filepath, docs): 57 | folder_path = os.path.join(os.path.dirname(filepath), "tmp_files") 58 | if not os.path.exists(folder_path): 59 | os.makedirs(folder_path) 60 | fp = os.path.join(folder_path, 'load_file.txt') 61 | with open(fp, 'a+', encoding='utf-8') as fout: 62 | fout.write("filepath=%s,len=%s" % (filepath, len(docs))) 63 | fout.write('\n') 64 | for i in docs: 65 | fout.write(str(i)) 66 | fout.write('\n') 67 | fout.close() 68 | 69 | def concat_history(model_name:str,chat_history: List[str]=[],query:str = '',process_type='chat'): 70 | formatted_history = "" 71 | if process_type=='chat': 72 | if 'pulse' in model_name.lower(): 73 | for sublist in chat_history: 74 | if sublist[0] is not None: 75 | formatted_history += f"\nUser: {sublist[0]}\nHelper: {sublist[1]}" 76 | else: 77 | for sublist in 
chat_history: 78 | if sublist[0] is not None: 79 | formatted_history += f"\nUser: {sublist[0]}\nHelper: {sublist[1]}" 80 | elif process_type=='search': 81 | chat_history = chat_history[-1:] 82 | for j,sublist in enumerate(chat_history, start=1): 83 | if sublist[0] is not None: 84 | formatted_history += f"User:{sublist[0]}Helper: {sublist[1]}" 85 | formatted_history+=f"User:{query}" 86 | return formatted_history 87 | 88 | def remove_duplicates(lst): 89 | seen = set() 90 | result = [] 91 | for i, item in enumerate(lst, start=1): 92 | page_content = item.page_content 93 | if page_content not in seen: 94 | seen.add(page_content) 95 | result.append(item) 96 | return result 97 | 98 | def generate_prompt(model_name:str, related_docs: List[str], 99 | query: str, 100 | chat_history: List[str]=[],prompt_template: str = PROMPT_TEMPLATE,) -> str: 101 | related_docs = remove_duplicates(related_docs) 102 | if 'pulse' in model_name.lower(): 103 | formatted_context = "" 104 | for i, doc in enumerate(related_docs, start=1): 105 | formatted_context += f"[{i}]\n```\n{doc.page_content}\n```\n" 106 | formatted_context = "None" if formatted_context == "" else formatted_context 107 | formatted_history = concat_history(model_name,chat_history,query,process_type='chat') 108 | prompt = prompt_template.replace("{question}", query).replace("{context}", formatted_context).replace("{chat_history}", formatted_history) 109 | else: 110 | if len(related_docs)>0: 111 | formatted_context = "" 112 | for i, doc in enumerate(related_docs, start=1): 113 | formatted_context += f"[{i}]{doc.page_content}\n" 114 | formatted_context = "None" if formatted_context == "" else formatted_context 115 | formatted_history = concat_history(model_name,chat_history,query,process_type='chat') 116 | prompt = prompt_template.replace("{question}", query).replace("{context}", formatted_context).replace("{chat_history}", formatted_history) 117 | else: 118 | prompt = prompt_template.replace("{question}", query) 119 | return prompt 120 | 121 | 122 | def seperate_list(ls: List[int]) -> List[List[int]]: 123 | lists = [] 124 | ls1 = [ls[0]] 125 | for i in range(1, len(ls)): 126 | if ls[i - 1] + 1 == ls[i]: 127 | ls1.append(ls[i]) 128 | else: 129 | lists.append(ls1) 130 | ls1 = [ls[i]] 131 | lists.append(ls1) 132 | return lists 133 | 134 | 135 | def similarity_search_with_score_by_vector( 136 | self, embedding: List[float], k: int = 4 137 | ) -> List[Tuple[Document, float]]: 138 | scores, indices = self.index.search(np.array([embedding], dtype=np.float32), k) 139 | docs = [] 140 | id_set = set() 141 | store_len = len(self.index_to_docstore_id) 142 | for j, i in enumerate(indices[0]): 143 | if i == -1 or 0 < self.score_threshold < scores[0][j]: 144 | # This happens when not enough docs are returned. 
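                # 此外,当 score_threshold > 0 时,这里还会过滤掉 FAISS 返回的距离分值高于阈值的结果
                # (分值越低表示匹配度越高,对应前端"知识相关度 Score 阈值"的设置)。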
145 | continue 146 | _id = self.index_to_docstore_id[i] 147 | doc = self.docstore.search(_id) 148 | if not self.chunk_content: 149 | if not isinstance(doc, Document): 150 | raise ValueError(f"Could not find document for id {_id}, got {doc}") 151 | doc.metadata["score"] = int(scores[0][j]) 152 | docs.append(doc) 153 | continue 154 | id_set.add(i) 155 | docs_len = len(doc.page_content) 156 | for k in range(1, max(i, store_len - i)): 157 | break_flag = False 158 | for l in [i + k, i - k]: 159 | if 0 <= l < len(self.index_to_docstore_id): 160 | _id0 = self.index_to_docstore_id[l] 161 | doc0 = self.docstore.search(_id0) 162 | if docs_len + len(doc0.page_content) > self.chunk_size: 163 | break_flag = True 164 | break 165 | elif doc0.metadata["source"] == doc.metadata["source"]: 166 | docs_len += len(doc0.page_content) 167 | id_set.add(l) 168 | if break_flag: 169 | break 170 | if not self.chunk_content: 171 | return docs 172 | if len(id_set) == 0 and self.score_threshold > 0: 173 | return [] 174 | id_list = sorted(list(id_set)) 175 | id_lists = seperate_list(id_list) 176 | for id_seq in id_lists: 177 | for id in id_seq: 178 | if id == id_seq[0]: 179 | _id = self.index_to_docstore_id[id] 180 | doc = self.docstore.search(_id) 181 | else: 182 | _id0 = self.index_to_docstore_id[id] 183 | doc0 = self.docstore.search(_id0) 184 | doc.page_content += " " + doc0.page_content 185 | if not isinstance(doc, Document): 186 | raise ValueError(f"Could not find document for id {_id}, got {doc}") 187 | doc_score = min([scores[0][id] for id in [indices[0].tolist().index(i) for i in id_seq if i in indices[0]]]) 188 | doc.metadata["score"] = int(doc_score) 189 | docs.append(doc) 190 | torch_gc() 191 | return docs 192 | 193 | 194 | def search_result2docs(search_results): 195 | docs = [] 196 | for result in search_results: 197 | doc = Document(page_content=result["snippet"] if "snippet" in result.keys() else "", 198 | metadata={"source": result["link"] if "link" in result.keys() else "", 199 | "filename": result["title"] if "title" in result.keys() else ""}) 200 | docs.append(doc) 201 | return docs 202 | 203 | 204 | class LocalDocQA: 205 | llm: BaseAnswer = None 206 | embeddings: object = None 207 | top_k: int = VECTOR_SEARCH_TOP_K 208 | chunk_size: int = CHUNK_SIZE 209 | chunk_content: bool = True 210 | score_threshold: int = VECTOR_SEARCH_SCORE_THRESHOLD 211 | 212 | def init_cfg(self, 213 | embedding_model: str = EMBEDDING_MODEL, 214 | embedding_device=EMBEDDING_DEVICE, 215 | llm_model: BaseAnswer = None, 216 | top_k=VECTOR_SEARCH_TOP_K, 217 | ): 218 | self.llm = llm_model 219 | print('embedding_device',embedding_device) 220 | self.embeddings = HuggingFaceEmbeddings(model_name=embedding_model_dict[embedding_model], 221 | model_kwargs={'device': embedding_device}) 222 | self.top_k = top_k 223 | 224 | def init_embedding(self, 225 | embedding_model: str = EMBEDDING_MODEL, 226 | embedding_device=EMBEDDING_DEVICE, 227 | top_k=VECTOR_SEARCH_TOP_K, 228 | ): 229 | self.embeddings = HuggingFaceEmbeddings(model_name=embedding_model_dict[embedding_model], 230 | model_kwargs={'device': embedding_device}) 231 | self.top_k = top_k 232 | 233 | def init_knowledge_vector_store(self, 234 | filepath: str or List[str], 235 | vs_path: str or os.PathLike = None, 236 | sentence_size=SENTENCE_SIZE): 237 | loaded_files = [] 238 | failed_files = [] 239 | if isinstance(filepath, str): 240 | if not os.path.exists(filepath): 241 | print("路径不存在") 242 | return None 243 | elif os.path.isfile(filepath): 244 | file = os.path.split(filepath)[-1] 245 | 
try: 246 | docs = load_file(filepath, sentence_size) 247 | logger.info(f"{file} 已成功加载") 248 | loaded_files.append(filepath) 249 | except Exception as e: 250 | logger.error(e) 251 | logger.info(f"{file} 未能成功加载") 252 | return None 253 | elif os.path.isdir(filepath): 254 | docs = [] 255 | for file in tqdm(os.listdir(filepath), desc="加载文件"): 256 | fullfilepath = os.path.join(filepath, file) 257 | try: 258 | docs += load_file(fullfilepath, sentence_size) 259 | loaded_files.append(fullfilepath) 260 | except Exception as e: 261 | logger.error(e) 262 | failed_files.append(file) 263 | 264 | if len(failed_files) > 0: 265 | logger.info("以下文件未能成功加载:") 266 | for file in failed_files: 267 | logger.info(f"{file}\n") 268 | 269 | else: 270 | docs = [] 271 | for file in filepath: 272 | try: 273 | docs += load_file(file) 274 | logger.info(f"{file} 已成功加载") 275 | loaded_files.append(file) 276 | except Exception as e: 277 | logger.error(e) 278 | logger.info(f"{file} 未能成功加载") 279 | if len(docs) > 0: 280 | logger.info("文件加载完毕,正在生成向量库") 281 | if vs_path and os.path.isdir(vs_path): 282 | vector_store = FAISS.load_local(vs_path, self.embeddings) 283 | vector_store.add_documents(docs) 284 | torch_gc() 285 | else: 286 | if not vs_path: 287 | vs_path = os.path.join(VS_ROOT_PATH, 288 | f"""{"".join(lazy_pinyin(os.path.splitext(file)[0]))}_FAISS_{datetime.datetime.now().strftime("%Y%m%d_%H%M%S")}""") 289 | vector_store = FAISS.from_documents(docs, self.embeddings) # docs 为Document列表 290 | torch_gc() 291 | 292 | vector_store.save_local(vs_path) 293 | return vs_path, loaded_files 294 | else: 295 | logger.info("文件均未成功加载,请检查依赖包或替换为其他文件再次上传。") 296 | return None, loaded_files 297 | 298 | def one_knowledge_add(self, vs_path, one_title, one_conent, one_content_segmentation, sentence_size): 299 | try: 300 | if not vs_path or not one_title or not one_conent: 301 | logger.info("知识库添加错误,请确认知识库名字、标题、内容是否正确!") 302 | return None, [one_title] 303 | docs = [Document(page_content=one_conent + "\n", metadata={"source": one_title})] 304 | if not one_content_segmentation: 305 | text_splitter = ChineseTextSplitter(pdf=False, sentence_size=sentence_size) 306 | docs = text_splitter.split_documents(docs) 307 | if os.path.isdir(vs_path): 308 | vector_store = FAISS.load_local(vs_path, self.embeddings) 309 | vector_store.add_documents(docs) 310 | else: 311 | vector_store = FAISS.from_documents(docs, self.embeddings) ##docs 为Document列表 312 | torch_gc() 313 | vector_store.save_local(vs_path) 314 | return vs_path, [one_title] 315 | except Exception as e: 316 | logger.error(e) 317 | return None, [one_title] 318 | 319 | def get_knowledge_based_answer(self, model_name, query, vs_path, prompt_template, chat_history=[], streaming: bool = STREAMING): 320 | vector_store = FAISS.load_local(vs_path, self.embeddings) 321 | FAISS.similarity_search_with_score_by_vector = similarity_search_with_score_by_vector 322 | vector_store.chunk_size = self.chunk_size 323 | vector_store.chunk_content = self.chunk_content 324 | vector_store.score_threshold = self.score_threshold 325 | formatted_history = concat_history(model_name,chat_history,query,process_type='search') 326 | # 根据历史记录来检索 327 | related_docs_with_score = vector_store.similarity_search_with_score(formatted_history, k=self.top_k) 328 | torch_gc() 329 | prompt = generate_prompt(model_name, related_docs_with_score, query, chat_history, prompt_template) 330 | 331 | for answer_result in self.llm.generatorAnswer(prompt=prompt, history=chat_history, 332 | streaming=streaming): 333 | resp = 
answer_result.llm_output["answer"] 334 | history = answer_result.history 335 | history[-1][0] = query 336 | response = {"query": query, 337 | "result": resp, 338 | "source_documents": related_docs_with_score} 339 | yield response, history 340 | 341 | def get_base_answer(self, model_name, query, prompt_template, chat_history=[], streaming: bool = STREAMING): 342 | formatted_history = concat_history(model_name,chat_history,query,process_type='chat') 343 | torch_gc() 344 | if prompt_template == '': 345 | prompt = f"{formatted_history}User:{query}Helper:" 346 | else: 347 | prompt = generate_prompt(model_name, [], query, chat_history, prompt_template) 348 | 349 | for answer_result in self.llm.generatorAnswer(prompt=prompt, history=chat_history, 350 | streaming=streaming): 351 | resp = answer_result.llm_output["answer"] 352 | history = answer_result.history 353 | history[-1][0] = query 354 | response = {"query": query, 355 | "result": resp} 356 | yield response, history 357 | 358 | # query 查询内容 359 | # vs_path 知识库路径 360 | # chunk_content 是否启用上下文关联 361 | # score_threshold 搜索匹配score阈值 362 | # vector_search_top_k 搜索知识库内容条数,默认搜索5条结果 363 | # chunk_sizes 匹配单段内容的连接上下文长度 364 | def get_knowledge_based_conent_test(self, query, vs_path, chunk_content, 365 | score_threshold=VECTOR_SEARCH_SCORE_THRESHOLD, 366 | vector_search_top_k=VECTOR_SEARCH_TOP_K, chunk_size=CHUNK_SIZE): 367 | vector_store = FAISS.load_local(vs_path, self.embeddings) 368 | FAISS.similarity_search_with_score_by_vector = similarity_search_with_score_by_vector 369 | vector_store.chunk_content = chunk_content 370 | vector_store.score_threshold = score_threshold 371 | vector_store.chunk_size = chunk_size 372 | related_docs_with_score = vector_store.similarity_search_with_score(query, k=vector_search_top_k) 373 | if not related_docs_with_score: 374 | response = {"query": query, 375 | "source_documents": []} 376 | return response, "" 377 | torch_gc() 378 | prompt = "\n".join([doc.page_content for doc in related_docs_with_score]) 379 | response = {"query": query, 380 | "source_documents": related_docs_with_score} 381 | return response, prompt 382 | 383 | def get_search_result_based_answer(self,model_name, query, chat_history=[],prompt_template=BASE_PROMPT_TEMPLATE, streaming: bool = STREAMING): 384 | results = bing_search(query) 385 | result_docs = search_result2docs(results) 386 | prompt = generate_prompt(model_name, result_docs, query, chat_history, prompt_template) 387 | 388 | for answer_result in self.llm.generatorAnswer(prompt=prompt, history=chat_history, 389 | streaming=streaming): 390 | resp = answer_result.llm_output["answer"] 391 | history = answer_result.history 392 | history[-1][0] = query 393 | response = {"query": query, 394 | "result": resp, 395 | "source_documents": result_docs} 396 | yield response, history 397 | 398 | 399 | -------------------------------------------------------------------------------- /chains/modules/embeddings.py: -------------------------------------------------------------------------------- 1 | from langchain.embeddings.huggingface import HuggingFaceEmbeddings 2 | 3 | from typing import Any, List 4 | 5 | 6 | class MyEmbeddings(HuggingFaceEmbeddings): 7 | def __init__(self, **kwargs: Any): 8 | super().__init__(**kwargs) 9 | 10 | def embed_documents(self, texts: List[str]) -> List[List[float]]: 11 | """Compute doc embeddings using a HuggingFace transformer model. 12 | 13 | Args: 14 | texts: The list of texts to embed. 15 | 16 | Returns: 17 | List of embeddings, one for each text. 
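            Note: newlines in each text are replaced with spaces before encoding,
            and the resulting vectors are L2-normalized via ``normalize_embeddings=True``.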
18 | """ 19 | texts = list(map(lambda x: x.replace("\n", " "), texts)) 20 | embeddings = self.client.encode(texts, normalize_embeddings=True) 21 | return embeddings.tolist() 22 | 23 | def embed_query(self, text: str) -> List[float]: 24 | """Compute query embeddings using a HuggingFace transformer model. 25 | 26 | Args: 27 | text: The text to embed. 28 | 29 | Returns: 30 | Embeddings for the text. 31 | """ 32 | text = text.replace("\n", " ") 33 | embedding = self.client.encode(text, normalize_embeddings=True) 34 | return embedding.tolist() 35 | -------------------------------------------------------------------------------- /chains/modules/excel_load.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from typing import List, Optional 3 | 4 | from langchain.docstore.document import Document 5 | from langchain.document_loaders.base import BaseLoader 6 | from langchain.document_loaders.helpers import detect_file_encodings 7 | import json 8 | import pandas as pd 9 | 10 | logger = logging.getLogger(__name__) 11 | 12 | def process_json(json_str): 13 | docs=[] 14 | for doc in json_str['docs']: 15 | new_doc = Document( 16 | page_content=doc, metadata={"source": doc.split(' ')[0]} 17 | ) 18 | docs.append(new_doc) 19 | 20 | return docs 21 | 22 | class ExcelLoader(BaseLoader): 23 | """Load text files. 24 | 25 | 26 | Args: 27 | file_path: Path to the file to load. 28 | 29 | encoding: File encoding to use. If `None`, the file will be loaded 30 | with the default system encoding. 31 | 32 | autodetect_encoding: Whether to try to autodetect the file encoding 33 | if the specified encoding fails. 34 | """ 35 | 36 | def __init__( 37 | self, 38 | file_path: str, 39 | encoding: Optional[str] = None, 40 | autodetect_encoding: bool = False, 41 | ): 42 | """Initialize with file path.""" 43 | self.file_path = file_path 44 | self.encoding = encoding 45 | self.autodetect_encoding = autodetect_encoding 46 | 47 | def load(self) -> List[Document]: 48 | """Load from file path.""" 49 | pass 50 | 51 | def load_excel(self) -> List[Document]: 52 | """Load from file path.""" 53 | documents = [] 54 | try: 55 | df = pd.read_excel(self.file_path, sheet_name=None) 56 | # 获取所有sheet的名字 57 | sheet_names = df.keys() 58 | # 遍历每个sheet 59 | for sheet_name in sheet_names: 60 | # 获取当前sheet的数据 61 | sheet_data = df[sheet_name] 62 | # 遍历每行数据 63 | for _, row in sheet_data.iterrows(): 64 | try: 65 | question = str(row["问题"]) 66 | answer = str(row["回答"]) 67 | doc = "【参考问题】"+question+"【参考回答】"+answer 68 | new_doc = Document( 69 | page_content=doc, metadata={"source": sheet_name} 70 | ) 71 | documents.append(new_doc) 72 | except Exception as e: 73 | print('文件表读取发生错误!',sheet_name,e) 74 | except Exception as e: 75 | print('e2',e) 76 | raise RuntimeError(f"Error loading {self.file_path}") from e 77 | return documents -------------------------------------------------------------------------------- /chains/modules/json_load.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from typing import List, Optional 3 | 4 | from langchain.docstore.document import Document 5 | from langchain.document_loaders.base import BaseLoader 6 | from langchain.document_loaders.helpers import detect_file_encodings 7 | import json 8 | logger = logging.getLogger(__name__) 9 | 10 | def process_json(json_str): 11 | docs=[] 12 | for doc in json_str['docs']: 13 | new_doc = Document( 14 | page_content=doc, metadata={"source": doc.split(' ')[0]} 15 | ) 16 | 
docs.append(new_doc) 17 | 18 | return docs 19 | 20 | class JsonLoader(BaseLoader): 21 | """Load text files. 22 | 23 | 24 | Args: 25 | file_path: Path to the file to load. 26 | 27 | encoding: File encoding to use. If `None`, the file will be loaded 28 | with the default system encoding. 29 | 30 | autodetect_encoding: Whether to try to autodetect the file encoding 31 | if the specified encoding fails. 32 | """ 33 | 34 | def __init__( 35 | self, 36 | file_path: str, 37 | encoding: Optional[str] = None, 38 | autodetect_encoding: bool = False, 39 | ): 40 | """Initialize with file path.""" 41 | self.file_path = file_path 42 | self.encoding = encoding 43 | self.autodetect_encoding = autodetect_encoding 44 | 45 | def load(self) -> List[Document]: 46 | """Load from file path.""" 47 | pass 48 | 49 | def load_json(self) -> List[Document]: 50 | """Load from file path.""" 51 | documents = [] 52 | try: 53 | with open(self.file_path, encoding=self.encoding) as f: 54 | for line in f: 55 | json_object = json.loads(line) 56 | docs = process_json(json_object) 57 | documents+= docs 58 | except UnicodeDecodeError as e: 59 | print('e1',e) 60 | if self.autodetect_encoding: 61 | detected_encodings = detect_file_encodings(self.file_path) 62 | for encoding in detected_encodings: 63 | logger.debug("Trying encoding: ", encoding.encoding) 64 | try: 65 | with open(self.file_path, 'r', encoding=encoding.encoding) as f: 66 | for line in f: 67 | json_object = json.loads(line) 68 | docs = process_json(json_object) 69 | documents+= docs 70 | break 71 | except UnicodeDecodeError: 72 | continue 73 | else: 74 | raise RuntimeError(f"Error loading {self.file_path}") from e 75 | except Exception as e: 76 | print('e2',e) 77 | raise RuntimeError(f"Error loading {self.file_path}") from e 78 | return documents -------------------------------------------------------------------------------- /chains/modules/vectorstores.py: -------------------------------------------------------------------------------- 1 | from langchain.vectorstores import FAISS 2 | from typing import Any, Callable, List, Optional, Tuple, Dict 3 | from langchain.docstore.document import Document 4 | from langchain.docstore.base import Docstore 5 | 6 | from langchain.vectorstores.utils import maximal_marginal_relevance 7 | from langchain.embeddings.base import Embeddings 8 | import uuid 9 | from langchain.docstore.in_memory import InMemoryDocstore 10 | 11 | import numpy as np 12 | 13 | def dependable_faiss_import() -> Any: 14 | """Import faiss if available, otherwise raise error.""" 15 | try: 16 | import faiss 17 | except ImportError: 18 | raise ValueError( 19 | "Could not import faiss python package. " 20 | "Please install it with `pip install faiss` " 21 | "or `pip install faiss-cpu` (depending on Python version)." 22 | ) 23 | return faiss 24 | 25 | class FAISSVS(FAISS): 26 | def __init__(self, 27 | embedding_function: Callable[..., Any], 28 | index: Any, 29 | docstore: Docstore, 30 | index_to_docstore_id: Dict[int, str]): 31 | super().__init__(embedding_function, index, docstore, index_to_docstore_id) 32 | 33 | def max_marginal_relevance_search_by_vector( 34 | self, embedding: List[float], k: int = 4, fetch_k: int = 20, **kwargs: Any 35 | ) -> List[Tuple[Document, float]]: 36 | """Return docs selected using the maximal marginal relevance. 37 | 38 | Maximal marginal relevance optimizes for similarity to query AND diversity 39 | among selected documents. 40 | 41 | Args: 42 | embedding: Embedding to look up documents similar to. 43 | k: Number of Documents to return. 
Defaults to 4. 44 | fetch_k: Number of Documents to fetch to pass to MMR algorithm. 45 | 46 | Returns: 47 | List of Documents with scores selected by maximal marginal relevance. 48 | """ 49 | scores, indices = self.index.search(np.array([embedding], dtype=np.float32), fetch_k) 50 | # -1 happens when not enough docs are returned. 51 | embeddings = [self.index.reconstruct(int(i)) for i in indices[0] if i != -1] 52 | mmr_selected = maximal_marginal_relevance( 53 | np.array([embedding], dtype=np.float32), embeddings, k=k 54 | ) 55 | selected_indices = [indices[0][i] for i in mmr_selected] 56 | selected_scores = [scores[0][i] for i in mmr_selected] 57 | docs = [] 58 | for i, score in zip(selected_indices, selected_scores): 59 | if i == -1: 60 | # This happens when not enough docs are returned. 61 | continue 62 | _id = self.index_to_docstore_id[i] 63 | doc = self.docstore.search(_id) 64 | if not isinstance(doc, Document): 65 | raise ValueError(f"Could not find document for id {_id}, got {doc}") 66 | docs.append((doc, score)) 67 | return docs 68 | 69 | def max_marginal_relevance_search( 70 | self, 71 | query: str, 72 | k: int = 4, 73 | fetch_k: int = 20, 74 | **kwargs: Any, 75 | ) -> List[Tuple[Document, float]]: 76 | """Return docs selected using the maximal marginal relevance. 77 | 78 | Maximal marginal relevance optimizes for similarity to query AND diversity 79 | among selected documents. 80 | 81 | Args: 82 | query: Text to look up documents similar to. 83 | k: Number of Documents to return. Defaults to 4. 84 | fetch_k: Number of Documents to fetch to pass to MMR algorithm. 85 | 86 | Returns: 87 | List of Documents with scores selected by maximal marginal relevance. 88 | """ 89 | embedding = self.embedding_function(query) 90 | docs = self.max_marginal_relevance_search_by_vector(embedding, k, fetch_k) 91 | return docs 92 | 93 | @classmethod 94 | def __from( 95 | cls, 96 | texts: List[str], 97 | embeddings: List[List[float]], 98 | embedding: Embeddings, 99 | metadatas: Optional[List[dict]] = None, 100 | **kwargs: Any, 101 | ) -> FAISS: 102 | faiss = dependable_faiss_import() 103 | index = faiss.IndexFlatIP(len(embeddings[0])) 104 | index.add(np.array(embeddings, dtype=np.float32)) 105 | 106 | # # my code, for speeding up search 107 | # quantizer = faiss.IndexFlatL2(len(embeddings[0])) 108 | # index = faiss.IndexIVFFlat(quantizer, len(embeddings[0]), 100) 109 | # index.train(np.array(embeddings, dtype=np.float32)) 110 | # index.add(np.array(embeddings, dtype=np.float32)) 111 | 112 | documents = [] 113 | for i, text in enumerate(texts): 114 | metadata = metadatas[i] if metadatas else {} 115 | documents.append(Document(page_content=text, metadata=metadata)) 116 | index_to_id = {i: str(uuid.uuid4()) for i in range(len(documents))} 117 | docstore = InMemoryDocstore( 118 | {index_to_id[i]: doc for i, doc in enumerate(documents)} 119 | ) 120 | return cls(embedding.embed_query, index, docstore, index_to_id) 121 | 122 | -------------------------------------------------------------------------------- /chains/text_load.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pinecone 3 | from tqdm import tqdm 4 | from langchain.llms import OpenAI 5 | from langchain.text_splitter import SpacyTextSplitter 6 | from langchain.document_loaders import TextLoader 7 | from langchain.document_loaders import DirectoryLoader 8 | from langchain.indexes import VectorstoreIndexCreator 9 | from langchain.embeddings.openai import OpenAIEmbeddings 10 | from 
langchain.vectorstores import Pinecone 11 | 12 | #一些配置文件 13 | openai_key="你的key" # 注册 openai.com 后获得 14 | pinecone_key="你的key" # 注册 app.pinecone.io 后获得 15 | pinecone_index="你的库" #app.pinecone.io 获得 16 | pinecone_environment="你的Environment" # 登录pinecone后,在indexes页面 查看Environment 17 | pinecone_namespace="你的Namespace" #如果不存在自动创建 18 | 19 | #科学上网你懂得 20 | os.environ['HTTP_PROXY'] = 'http://127.0.0.1:7890' 21 | os.environ['HTTPS_PROXY'] = 'http://127.0.0.1:7890' 22 | 23 | #初始化pinecone 24 | pinecone.init( 25 | api_key=pinecone_key, 26 | environment=pinecone_environment 27 | ) 28 | index = pinecone.Index(pinecone_index) 29 | 30 | #初始化OpenAI的embeddings 31 | embeddings = OpenAIEmbeddings(openai_api_key=openai_key) 32 | 33 | #初始化text_splitter 34 | text_splitter = SpacyTextSplitter(pipeline='zh_core_web_sm',chunk_size=1000,chunk_overlap=200) 35 | 36 | # 读取目录下所有后缀是txt的文件 37 | loader = DirectoryLoader('../docs', glob="**/*.txt", loader_cls=TextLoader) 38 | 39 | #读取文本文件 40 | documents = loader.load() 41 | 42 | # 使用text_splitter对文档进行分割 43 | split_text = text_splitter.split_documents(documents) 44 | try: 45 | for document in tqdm(split_text): 46 | # 获取向量并储存到pinecone 47 | Pinecone.from_documents([document], embeddings, index_name=pinecone_index) 48 | except Exception as e: 49 | print(f"Error: {e}") 50 | quit() 51 | 52 | 53 | -------------------------------------------------------------------------------- /configs/common_config.py: -------------------------------------------------------------------------------- 1 | import torch.cuda 2 | import torch.backends 3 | import os 4 | import logging 5 | import uuid 6 | import datetime 7 | 8 | def get_formatted_date(): 9 | current_date = datetime.date.today() 10 | formatted_date = current_date.strftime("%Y-%m-%d") 11 | return formatted_date 12 | 13 | date_string = get_formatted_date() 14 | 15 | LOG_FORMAT = "%(levelname) -5s %(asctime)s" "-1d: %(message)s" 16 | logger = logging.getLogger() 17 | logger.setLevel(logging.INFO) 18 | logging.basicConfig(format=LOG_FORMAT) 19 | 20 | embedding_model_dict = { 21 | "ernie-tiny": "nghuyong/ernie-3.0-nano-zh", 22 | "ernie-base": "nghuyong/ernie-3.0-base-zh", 23 | "text2vec-base": "shibing624/text2vec-base-chinese", 24 | "text2vec-large-chinese": "GanymedeNil/text2vec-large-chinese", 25 | "m3e-small": "moka-ai/m3e-small", 26 | "m3e-base": "moka-ai/m3e-base", 27 | } 28 | 29 | # Embedding model name 30 | EMBEDDING_MODEL = "text2vec-large-chinese" 31 | 32 | # Embedding running device 33 | EMBEDDING_DEVICE = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu" 34 | 35 | 36 | # supported LLM models 37 | # llm_model_dict 处理了loader的一些预设行为,如加载位置,模型名称,模型处理器实例 38 | llm_model_dict = { 39 | "PULSE": { 40 | "name": "PULSE", 41 | "pretrained_model_name": "OpenMEDLab/PULSE-7bv5", 42 | "local_model_path": "OpenMEDLab/PULSE-7bv5", 43 | "provides": "Bloomz" 44 | } 45 | } 46 | 47 | # LLM 名称 48 | LLM_MODEL = "PULSE" 49 | # 如果你需要加载本地的model,指定这个参数 ` --no-remote-model`,或者下方参数修改为 `True` 50 | NO_REMOTE_MODEL = False 51 | # 量化加载8bit 模型 52 | LOAD_IN_8BIT = False 53 | # Load the model with bfloat16 precision. Requires NVIDIA Ampere GPU. 
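# Hypothetical sketch of how a loader might honour LOAD_IN_8BIT above and the
# BF16 flag defined just below when instantiating the model. `model_path` is a
# placeholder and the exact call is an assumption; the project's real loading
# code appears to live under loader/models/.
#
#     import torch
#     from transformers import AutoModelForCausalLM
#
#     dtype = torch.bfloat16 if BF16 else (None if LOAD_IN_8BIT else torch.float16)
#     model = AutoModelForCausalLM.from_pretrained(
#         model_path,
#         load_in_8bit=LOAD_IN_8BIT,  # 8-bit weights need bitsandbytes (listed in environment.yml)
#         torch_dtype=dtype,
#         device_map="auto",          # device dispatch via accelerate
#     )
#
# The flag described by the comment above: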
54 | BF16 = False 55 | # 本地模型存放的位置 56 | MODEL_DIR = "model/" 57 | # 本地lora存放的位置 58 | LORA_DIR = "loras/" 59 | 60 | # LLM lora path,默认为空,如果有请直接指定文件夹路径 61 | LLM_LORA_PATH = "" 62 | USE_LORA = True if LLM_LORA_PATH else False 63 | 64 | # LLM streaming reponse 65 | STREAMING = False 66 | 67 | # Use p-tuning-v2 PrefixEncoder 68 | USE_PTUNING_V2 = False 69 | 70 | # LLM running device 71 | LLM_DEVICE = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu" 72 | 73 | 74 | VS_ROOT_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "vector_store") 75 | 76 | UPLOAD_ROOT_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "content") 77 | 78 | # 基于上下文的prompt模版,请务必保留"{question}"和"{context}" 79 | PROMPT_TEMPLATE = """已知信息: 80 | {context} 81 | 82 | 假设你是客服,根据上述已知信息,来回答用户的问题。如果无法从中得到答案,请说 “根据已知信息无法回答该问题” 或 “没有提供足够的相关信息”,不允许在答案中添加编造成分,答案请使用中文。 问题是:{question}""" 83 | 84 | BASE_PROMPT_TEMPLATE = """Instruction: 假如你是机器人。Input:{question}""" 85 | 86 | 87 | # 文本分句长度 88 | SENTENCE_SIZE = 20 89 | 90 | # 匹配后单段上下文长度 91 | CHUNK_SIZE = 250 92 | 93 | # LLM input history length 94 | LLM_HISTORY_LEN = 8 95 | 96 | # return top-k text chunk from vector store 97 | VECTOR_SEARCH_TOP_K = 3 98 | 99 | # 知识检索内容相关度 Score, 数值范围约为0-1100,如果为0,则不生效,经测试设置为小于500时,匹配结果更精准 100 | VECTOR_SEARCH_SCORE_THRESHOLD = 700 101 | 102 | NLTK_DATA_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "nltk_data") 103 | 104 | FLAG_USER_NAME = uuid.uuid4().hex 105 | 106 | logger.info(f""" 107 | loading model config 108 | llm device: {LLM_DEVICE} 109 | embedding device: {EMBEDDING_DEVICE} 110 | dir: {os.path.dirname(os.path.dirname(__file__))} 111 | flagging username: {FLAG_USER_NAME} 112 | """) 113 | 114 | -------------------------------------------------------------------------------- /configs/test_ass.yaml: -------------------------------------------------------------------------------- 1 | ass_name: 药物小助手 2 | llm_model: PULSE 3 | embedding_model: text2vec-large-chinese 4 | lora_name: 不使用 5 | llm_history_len: 8 6 | knowledge_set_name: drug_kb 7 | top_k: 3 8 | score_threshold: 500 9 | chunk_content: true 10 | chunk_sizes: 250 11 | show_reference: true 12 | prompt_template: '假设你是药物助手,请根据文档来回复,如果文档内容为空或者None,则忽略,文档:{context}\n{chat_history}User:{question}Helper:' 13 | -------------------------------------------------------------------------------- /demo_data/kb_drug_demo.jsonl: -------------------------------------------------------------------------------- 1 | {"docs": ["布洛芬分散片 英文名\nIbuprofen Dispersible Tablets", "布洛芬分散片 用药科室\n解热镇痛抗炎药", "布洛芬分散片 药物分类\n化学药品", "布洛芬分散片 药物剂量\n50mg", "布洛芬分散片 适应症\n适用于因感冒、急性上呼吸道感染、急性咽喉炎等疾病引起的发热、头痛、周身痛及关节痛;其他疾病如关节炎、牙疾等所致的轻、中度疼痛。", "布洛芬分散片 禁忌症\n1 活动期消化道溃疡。2 对本药物过敏者,因服用阿司匹林和其它非类固醇类抗炎药诱发哮喘、鼻炎或荨麻疹的患者。3 有失血倾向者。", "布洛芬分散片 药物用法\n口服或少许水制成混悬液服用。1 1~12岁儿童患者(1) 用于发热,推荐剂量为每日每公斤体重20mg(2/5片),分三次服用,或遵医嘱.(2) 用于镇痛,推荐剂量为每日每公斤体重30mg(3/5片),分三次服用,或遵医嘱。2 成人及12岁以上儿童,推荐剂量为一次0.2~0.4g(4~8片)一日3次,或遵医嘱。", "布洛芬分散片 不良反应\n本品的不良反应较少,一般为肠、胃不适或皮疹、头痛、耳鸣,偶见转氨酶升高等。也有引起胃肠道出血而加重溃疡的报道。", "布洛芬分散片 注意事项\n1.肠胃病患者慎用。但对其他抗风湿药物耐受性差者可能对本品有良好耐受性。2.有支气管哮喘病史患者,可能会引起支气管痉挛。3.并用抗凝血剂的患者,服药的最初几日应随时监测其凝血酶原时间。4.心功能不全患者慎用。5.连续服用三天发热不退时,应请医生诊治。6.严重肝、肾功能障碍,红斑狼疮或其它免疫疾病患者慎用。7.剧烈腹痛、粪便黑色或带血,皮肤、粘膜炎症,患眼疾者请在医生指导下用药。", "布洛芬分散片 药物相互作用\n口服降糖药、甲氨喋呤、苯妥英、毛地黄、锂剂及饮酒等可降低本品耐受量。", "布洛芬分散片 存储方式\n遮光密封,于阴凉干燥处保存。有效期暂定一年。"]} 2 | {"docs": ["西沙必利片 英文名\nCisapride Tablets", "西沙必利片 用药科室\n消化科用药", "西沙必利片 药物分类\n化学药品", "西沙必利片 药物剂量\n5mg", "西沙必利片 儿童禁忌\n婴幼儿禁用", "西沙必利片 
孕妇禁忌\n1.尽管在动物不影响胚胎形成,无原始的胚胎毒性,也无致畸作用,但若在\n妊娠期,尤其是在妊娠的头三个月应权衡利弊使用。\n2.尽管经乳汁排泄的量很少,仍建议哺乳母亲勿用。", "西沙必利片 老年人禁忌\n在老年人,由于中度延长了清除半衰期,稳态血浆浓度一般会增高,故治疗\n剂量应酌减。", "西沙必利片 适应症\n全胃肠促动力药。主要用于功能性消化不良,X线、内窥镜检查为阴性的上消化道不适,症状为早饱,饭后饱胀、食量减退、胃胀、嗳气过多、食欲缺乏、恶心、呕吐或类似溃疡的主诉(上腹部灼痛)。另可用于轻度返流性食管炎的治疗。", "西沙必利片 禁忌症\n已知对本品过敏者禁用。禁止同时口服或非肠道使用强效抑制CYP3A4酶的药物,包括:三唑类抗真菌药;大环内酯类抗生素;HIV蛋白酶抑制剂;萘法唑酮;(见药物相互作用)。心脏病、心律失常、QT间期延长者禁用,禁止与引起QT间期延长的药物一起用;有水、电解质紊乱的患者禁用,特别是低血钾或低血镁者禁用;心动过缓者禁用;先天QT间期延长或有先天QT间期延长综合症家族史者禁用;肺、肝、肾功能不全的病人禁用;婴幼儿禁用。", "西沙必利片 药物用法\n口服治疗:用量:每日最高服药剂量为30mg。成人:根据病情的程度,每日总量15~30mg,分2~3次给药,每次5mg(剂量可以加倍)。体重为25~50公斤的儿童:每次最大剂量为5mg,每日四次。或遵医嘱。体重为25公斤以下的儿童:每次0.2mg/kg体重,每日三至四次。或遵医嘱。建议尽量避免与西柚汁一起服用。在肾功能不全时,建议减半日用量。", "西沙必利片 不良反应\n因本品的药理活性,可能发生瞬时性腹部痉挛、腹鸣和腹泻。发生腹部痉挛时,可减半剂量。偶有过敏反应包括红疹、瘙痒、荨麻疹、支气管痉挛、轻度短暂的头痛或头晕以及与剂量相关的尿频的报道。极少数心律失常的报道,包括室性心动过速、室颤、尖端扭转型室速、QT延长。大多数此类病人常同时服用数种其它药物-其中包括抑制CYP3A4酶的药物,或已患有心脏病、已有心律失常的危险因素存在。罕见可逆性肝功能异常的报道,可伴或不伴胆汁郁积。虽然有男子女性乳房和乳溢的病例报道,但在大规模监测研究中发生率(<", "西沙必利片 注意事项\n用药过程中应注意检测心电图。胃肠道运动增加可造成危害的病人,必须慎用。有猝死的家庭史,要权衡利弊谨慎使用本品。QT间期大于450毫秒的病人或电解质紊乱的病人,不应使用本品。本品不影响神经运动性功能,不引起镇静和嗜睡。然而,本品可加速中枢神经系统抑制剂的吸收,如巴比妥酸盐、酒精等,因此同时给予应慎重。", "西沙必利片 药物相互作用\n本品的主要代谢途径是通过CYP3A4酶进行代谢。若同时口服或非肠道使用能抑制此酶的药物,可导致血浆西沙必利浓度升高,从而增加QT间期和心律失常的危险性,心律失常包括室性心动过速、室颤和尖端扭转型室速。所以禁止与这些药物同时服用。", "西沙必利片 存储方式\n15~30℃干燥处贮存,放于儿童不易拿到处。"]} 3 | {"docs": ["克林霉素磷酸酯氯化钠注射液 英文名\nClindamycin Phosphate and Sodium Chloride Injection", "克林霉素磷酸酯氯化钠注射液 用药科室\n感染科用药","克林霉素磷酸酯氯化钠注射液 药物剂量\n100毫升:克林霉素0.6克与氯化钠0.9克", "克林霉素磷酸酯氯化钠注射液 儿童禁忌\n小于4岁儿童慎用。", "克林霉素磷酸酯氯化钠注射液 孕妇禁忌\n孕妇及哺乳期妇女使用本品应注意。", "克林霉素磷酸酯氯化钠注射液 适应症\n1.革兰氏阳性菌引起的下列各种感染性疾病:\n(1)扁桃体炎、化脓性中耳炎、鼻窦炎等。\n(2)急性支气管炎、慢性支气管炎急性发作、肺炎、肺脓肿和支气管扩张合并感染等。\n(3)皮肤和软组织感染:疖、痈、脓肿、蜂窝组织炎、创伤和手术后感染等。\n(4)泌尿系统感染:急性尿道炎、急性肾盂肾炎、前列腺炎等。\n(5)其他:骨髓炎、败血症、腹膜炎和口腔感染等。\n2.厌氧菌引起的各种感染性疾病:\n(1)脓胸、肺脓肿、厌氧菌性肺病。\n(2)皮肤和软组织感染、败血症。\n(3)腹腔内感染:腹膜炎、腹腔内脓肿。\n(4)女性盆腔及生", "克林霉素磷酸酯氯化钠注射液 禁忌症\n本品与林可霉素、克林霉素有交叉耐药性,对克林霉素或林可霉素有过敏史者禁用。", "克林霉素磷酸酯氯化钠注射液 药物用法\n静脉滴注:成人每日0.6—2.7g(以克林霉素计),分2—3次;儿童15—40mg/kg(以克林霉素计),分2—4次,或遵医嘱。\n静滴速度:每瓶不少于30分钟。", "克林霉素磷酸酯氯化钠注射液 不良反应\n1.长期静脉滴注应注意静脉炎的出现。\n2.胃肠道反应:偶见恶心、呕吐、腹痛及腹泻。\n3.过敏反应:少数病人可出现药物性皮疹。\n4.对造血系统基本无毒性反应,偶可引起中性粒细胞减少,嗜酸性粒细胞增多,血小板减少等。一般轻微为一过性。\n5.少数病人可发生一过性碱性磷酸酶、血清转氨酶轻度升高及黄疸。\n6.极少数病人可产生伪膜性结肠炎。", "克林霉素磷酸酯氯化钠注射液 注意事项\n1.与青霉素、头孢菌素类抗生素无交叉过敏反应,可用于对青霉素过敏者。\n2.禁与氨苄青霉素、苯妥英钠、巴比妥盐、氨茶碱、葡萄糖酸钙及硫酸镁配伍;与红霉素呈拮抗作用,不宜合用。\n3.肝、肾功能损害者及小于4岁儿童慎用。孕妇及哺乳期妇女使用本品应注意其利弊。\n4.如出现伪膜性肠炎,选用万古霉素口服0.125—0.5g,每日4次进行治疗。\n5.本品使用前应仔细检查,如有异物、溶液混浊、封口松动、瓶身有裂纹者切勿使用。\n6.本品应一次性使用,输注后的剩余药液,切勿贮藏再用。", "克林霉素磷酸酯氯化钠注射液 药物相互作用\n本品与红霉素呈拮抗作用,不宜合用。", "克林霉素磷酸酯氯化钠注射液 存储方式\n遮光,密闭,在阴凉处保存。"]} 4 | {"docs": ["柴银口服液 用药科室\n呼吸科用药", "柴银口服液 药物分类\n中药", "柴银口服液 药物剂量\n每瓶装20毫升", "柴银口服液 适应症\n清热解毒,利咽止咳。用于上呼吸道感染外感风热证,症见:发热恶风,头痛、咽痛,汗出,鼻塞流涕,咳嗽,舌边尖红,苔薄黄。", "柴银口服液 禁忌症\n尚不明确。", "柴银口服液 药物用法\n口服。一次1瓶,一日3次,连服3天。", "柴银口服液 不良反应\n偶有腹泻。", "柴银口服液 注意事项\n脾胃虚寒者宜温服。", "柴银口服液 存储方式\n密封。"]} 5 | {"docs": ["麻仁润肠丸 英文名\nMaren Runchang Wan", "麻仁润肠丸 用药科室\n消化科用药", "麻仁润肠丸 药物分类\n中药", "麻仁润肠丸 药物剂量\n每丸重6克", "麻仁润肠丸 孕妇禁忌\n孕妇忌服。", "麻仁润肠丸 适应症\n润肠通便。用于肠胃积热,胸腹胀满,大便秘结。", "麻仁润肠丸 禁忌症\n孕妇忌服。", "麻仁润肠丸 药物用法\n口服。一次1-2丸,一日2次。", "麻仁润肠丸 不良反应\n尚不明确。", "麻仁润肠丸 注意事项\n1.孕妇忌服。月经期慎用。\n2.年青体壮者便秘时不宜用本药。\n3.忌食生冷、油腻、辛辣食物。\n4.严重气质性病变引起的排便困难,如结肠癌,严重的肠道憩窒,肠梗阻及炎症性肠病等忌用。\n5.服药三天后症状未改善,或出现其他症状时,应及时去医院就诊。\n6.按照用法用量服用,有慢性病史者、小儿及年老体虚者不宜长期服用,应在医师指导下服用。\n7.药品性状发生改变时禁止服用。\n8.儿童必须在成人的监护下使用。\n9.请将此药品放在儿童不能接触的地方。\n10.如正在服用其他药品,使用本品前请咨询医师或药师。", "麻仁润肠丸 药物相互作用\n如与其他药物同时使用可能会发生药物相互作用,详情请咨询医师或药师。", "麻仁润肠丸 存储方式\n密封"]} 6 | {"docs": ["布洛芬缓释片 英文名\nIbuprofen Sustained Release Tablets", "布洛芬缓释片 用药科室\n解热镇痛抗炎药", "布洛芬缓释片 药物分类\n化学药品", "布洛芬缓释片 药物剂量\n0.3g", "布洛芬缓释片 适应症\n用于减轻中度疼痛,如关节痛、神经痛、肌肉痛、偏头痛、头痛、痛经、牙痛;也可用于感冒和流感引起的发热。", "布洛芬缓释片 
禁忌症\n1  对其他非甾体抗炎药过敏者禁用;\n2  孕妇及哺乳期妇女禁用;\n3  对阿司匹林过敏的哮喘患者禁用。", "布洛芬缓释片 药物用法\n口服。成人,一次1片,一日2次(早晚各一次)。", "布洛芬缓释片 不良反应\n1  少数病人可出现恶心、呕吐、胃烧灼感或轻度消化不良、胃肠道溃疡及出血、转氨酶升高、头痛、头晕、耳鸣、视力模糊、精神紧张、嗜睡、下肢水肿或体重骤增;\n2  罕见皮疹、过敏性肾炎、膀胱炎、肾病综合征、肾乳头坏死或肾功能衰竭、支气管痉挛。", "布洛芬缓释片 注意事项\n1  本品为对症治疗药,不宜长期或大量使用,用于止痛不得超过5天,用于解热不得超过3天,如症状不缓解,请咨询医师或药师;\n2.必须整片吞服,不得碾碎或溶解后服用;\n3  不能同时服用其他含有解热镇痛药的药品(如某些复方抗感冒药);\n4  服用本品期间不得饮酒或含有酒精的饮料;\n5  有下列情况患者慎用:60岁以上、支气管哮喘、肝肾功能不全、凝血机制或血小板功能障碍(如血友病);\n6  下列情况患者应在医师指导下使用:有消化性溃疡史、胃肠道出血、心功能不全、高血压;\n7  如服用过量或出现严重不良反应,应立即就医;\n8  对本品过敏者禁用,过敏体质者慎用;\n9  本品性状发生改变时禁止使用;\n10  请将本品放在儿童不能接触的地方;\n11  如正在使用其他药品,使用本品前请咨询医师或药师。", "布洛芬缓释片 药物相互作用\n1  本品与其他解热、镇痛、抗炎药物同用时可增加胃肠道不良反应,并可能导致溃疡;\n2  本品与肝素、双香豆素等抗凝药同用时,可导致凝血酶原时间延长,增加出血倾向;\n3  本品与地高辛、甲氨蝶呤、口服降血糖药物同用时,能使这些药物的血药浓度增高,不宜同用;\n4 本品与呋塞米(呋喃苯胺酸)同用时,后者的排钠和降压作用减弱;与抗高血压药同用时,也降低后者的降压效果;\n5  如与其他药物同时使用可能会发生药物相互作用,详情请咨询医师或药师。", "布洛芬缓释片 存储方式\n避光,密闭保存 室温贮藏。"]} 7 | {"docs": ["板蓝根颗粒 用药科室\n感冒用药", "板蓝根颗粒 药物分类\n中药", "板蓝根颗粒 适应症\n清热解毒,凉血利咽。用于肺胃热盛所致的咽喉肿痛、口咽干燥;急性扁桃体炎见上述证候者。", "板蓝根颗粒 禁忌症\n忌烟,酒及辛辣,生冷,油腻食物。", "板蓝根颗粒 药物用法\n开水冲服,一次5~10g(含糖型),或一次3~6g(无糖型)一日3~4次。", "板蓝根颗粒 不良反应\n尚不明确。", "板蓝根颗粒 注意事项\n1不宜在服药期间同时服用滋补性中成药。\n2风寒感冒者不宜使用,其表现为恶寒重,发热轻,无汗,鼻塞流涕,口不渴,咳吐稀白痰。\n3有高血压、心脏病、肝病、糖尿病、肾病等慢性病严重者,孕妇或正在接受其他治疗的患者,应在医师指导下服用。\n4服药三天后,症状未改善,或出现发热咳嗽加重,并有其他症状如胸闷、心悸等应去医院就诊。\n5按照用法用量服用,小儿、年老体虚患者应在医师指导下服用。\n6连续服用应向医师咨询。\n7药品性状发生改变时禁止使用。\n8儿童必须在成人监护下使用。\n9请将此药品放在儿童不能接触的地方。\n10如正在服用其他药品,使用本品前请咨询医师或药师。", "板蓝根颗粒 药物相互作用\n如与其他药物同时使用可能会发生药物相互作用,详情请咨询医师或药师。", "板蓝根颗粒 存储方式\n密封。"]} 8 | {"docs": ["维生素B6片 英文名\nVitamin B6 Tablets", "维生素B6片 用药科室\n营养科用药", "维生素B6片 药物分类\n化学药品", "维生素B6片 药物剂量\n10mg", "维生素B6片 孕妇禁忌\n孕妇接受大量维生素B6,可致新生儿产生维生素B6依赖综合症。乳母摄入正常需要量对婴儿无不良影响。", "维生素B6片 适应症\n用于预防和治疗维生素B6缺乏症,如脂溢性皮炎、唇干裂。也可用于减轻妊娠呕吐。", "维生素B6片 禁忌症\n尚不明确。", "维生素B6片 药物用法\n口服。成人,一日1-2片;儿童,一日0.5~1片,连用3周。", "维生素B6片 不良反应\n维生素B6在肾功能正常时几乎不产生毒性。若每天服用200mg,持续30天以上,曾报道可产生维生素B6依赖综合征。每日应用2~6g,持续几个月,可引起严重神经感觉异常,进行性步态不稳至足麻木、手不灵活,停药后可缓解,但仍软弱无力。", "维生素B6片 注意事项\n1 必须按推荐剂量服用,不可超量服用,用药3周后应停药。2 孕妇及哺乳期妇女应在医师指导下使用。3 如服用过量或出现严重不良反应,应立即就医。4 对本品过敏者禁用,过敏体质者慎用。5 本品性状发生改变时禁止使用。6 请将本品放在儿童不能接触的地方。7 儿童必须在成人监护下使用。8 如正在使用其他药品,使用本品前请咨询医师或药师。", "维生素B6片 药物相互作用\n1.小剂量维生素B6(一日5毫克)与左旋多巴合用,可降低后者治疗帕金森病的疗效。但制剂中若含有脱羧酶抑制剂如卡比多巴时,对左旋多巴无影响。 2.氯霉素、盐酸肼酞嗪、异烟肼、青霉胺及免疫抑制剂包括糖皮质激素、环磷酰胺、环孢素等药物可拮抗维生素B6或增强维生素B6经肾排泄,甚至可引起贫血或周围神经炎。 3.服用雌激素时应增加维生素B6的用量,因为雌激素可使维生素B6在体内的活性降低。4.如与其他药物同时使用可能会发生药物相互作用,详情请咨询医师或药师。", "维生素B6片 存储方式\n遮光,密闭保存。"]} 9 | {"docs": ["氯雷他定胶囊 英文名\nLoratadine", "氯雷他定胶囊 用药科室\n呼吸科用药", "氯雷他定胶囊 药物分类\n化学药品", "氯雷他定胶囊 药物剂量\n10毫克", "氯雷他定胶囊 适应症\n用于缓解过敏性鼻炎有关的症状,如喷嚏、流涕、鼻痒、鼻塞以及眼部痒及烧灼感。口服药物后,鼻和眼部症状及体征得以迅速缓解。亦适用于缓解慢性荨麻疹、瘙痒性皮肤病及其他过敏性皮肤病的症状及体征。", "氯雷他定胶囊 禁忌症\n特异体质的病人禁用。", "氯雷他定胶囊 药物用法\n口服。成人及12岁以上儿童:一日1次,一次1粒(10毫克)。 2~12岁儿童:体重>30公斤:一日1次,一次1粒(10毫克)。", "氯雷他定胶囊 不良反应\n在每天10mg的推荐剂量下,本品未见明显的镇静作用。常见不良反应有乏力、头痛、嗜睡、口干、胃肠道不适包括恶心、胃炎以及皮疹等。罕见不良反应有脱发、过敏反应、肝功能异常、心动过速及心悸等。", "氯雷他定胶囊 注意事项\n1  严重肝功能不全的患者请在医生指导下使用。 \n2  妊娠期及哺乳期妇女慎用。\n3  在做皮试前约48小时左右应中止使用本品,因抗组胺药能阻止或降低皮试的阳性反应发生。\n4  对本品过敏者禁用,过敏体质者慎用。\n5  本品性状发生改变时禁止使用。\n6  请将本品放在儿童不能接触的地方。\n7  儿童必须在成人监护下使用。\n8  如正在使用其他药品,使用本品前请咨询医师或药师。", "氯雷他定胶囊 药物相互作用\n同时服用酮康唑、大环内酯类抗生素、西咪替丁、茶碱等药物,会提高氯雷他定在血浆中的浓度,应慎用。其他已知能抑制肝脏代谢的药物,在未明确与氯雷他定相互作用前应谨慎合用。", "氯雷他定胶囊 存储方式\n密封,遮光。"]} 10 | {"docs": ["板蓝根颗粒 药物分类\n中药", "板蓝根颗粒 药物剂量\n(1)每袋装5g(相当于饮片7g)(2)每袋装10g(相当于饮片14g)(3)每袋装3g(无蔗糖,相当于饮片7g)"]} 11 | {"docs": ["氯雷他定片 英文名\nLoratadine Tablets", "氯雷他定片 用药科室\n感冒用药", "氯雷他定片 药物分类\n化学药品", "氯雷他定片 药物剂量\n10毫克", "氯雷他定片 儿童禁忌\n12岁以下儿童应用本品的安全性尚未确定。", "氯雷他定片 孕妇禁忌\n孕妇慎用。服药期宜停止哺乳。", "氯雷他定片 老年人禁忌\n肝肾功能轻中度受损时,对本药的代谢和排泄无明显的影响,所以老年患者用药量与成人相同。", "氯雷他定片 
适应症\n用于缓解过敏性鼻炎有关的症状,如喷嚏、流涕、鼻痒、鼻塞(鼻塞是耳鼻咽喉科常见的症状之一,最常见的原因包括鼻炎,鼻中隔偏曲,鼻息肉,鼻窦炎等。理论上来说,鼻塞都可以通过不同的治疗方法进行解决。)以及眼部痒及烧灼感。口服药物后,鼻和眼部症状及体征得以迅速缓解。亦适用于缓解慢性荨麻疹、瘙痒性皮肤病及其他过敏性皮肤病的症状及体征。", "氯雷他定片 禁忌症\n对本品中的成分过敏或特异体质的病人禁用。", "氯雷他定片 药物用法\n1. 成人及12岁以上儿童:一日1次,一次1片(10毫克)。\n2. 2~12岁儿童:\n体重>30公斤:一日1次,一次1片(10毫克);\n体重≤30公斤:一日1次,一次半片(5毫克)。", "氯雷他定片 不良反应\n主要包括头痛,嗜睡,疲乏,口干视觉模糊,血压降低或升高,心悸,晕厥,运动机能亢进,肝功能改变,黄疸,肝炎,肝坏死,脱发,癫痫发作,乳房肿大,多形性红斑及全身性过敏反应。", "氯雷他定片 注意事项\n1.严重肝功能不全的患者请在医生指导下使用。\n2.妊娠期及哺乳期妇女慎用。\n3.在作皮试前的约48小时左右应中止使用本品,因抗组胺药能阻止或降低皮试的阳性反应发生。(皮试是皮肤(或皮内)敏感试验的简称。某些药物在临床使用过程中容易发生过敏反应,如青霉素、链霉素、细胞色素C等,常见的过敏反应包括皮疹、荨麻疹、皮炎、发热、血管神经性水肿、哮喘、过敏性休克等,其中以过敏性休克最为严重,甚至可导致死亡。)\n4.对本品过敏者禁用,过敏体质者慎用。\n5.本品性状发生改变时禁止使用。\n6.请将本品放在儿童不能接触的地方。\n7.儿童必须在成人监护下使用。\n8.如正在使用其他药品,使用本品前请咨询医师或药师。", "氯雷他定片 药物相互作用\n1.同时服用酮康唑、大环内酯类抗生素、西咪替丁、茶碱等药物,会提高氯雷他定在血浆中的浓度,应慎用。其他已知能抑制肝脏代谢的药物,在未明确与氯雷他定相互作用前应谨慎合用。\n2.抑制肝药物代谢酶功能的药物能使本品的代谢减慢。每日同服酮康唑400mg,可使氯雷他定及其活性代谢物去羧乙基氯雷他定的血浆浓度升高,但未观察到心电图改变。与大环内酯类抗生素、西咪替丁、茶碱等药物并用也可抑制氯雷他定的代谢。\n3.如正在服用其它药品,使用本品前请咨询医师或药师。", "氯雷他定片 存储方式\n密封保存。"]} 12 | {"docs": ["卡马西平片 英文名\nCarbamazepine", "卡马西平片 用药科室\n脑神经科用药", "卡马西平片 药物分类\n化学药品", "卡马西平片 药物剂量\n0.1g", "卡马西平片 儿童禁忌\n本品可用于各年龄段儿童,具体参考[用法用量]。", "卡马西平片 孕妇禁忌\n本品能通过胎盘,是否致畸尚不清楚,妊娠早期需慎用;\n本品能分泌入乳汁,约为血药浓度60%,哺乳期妇女不宜应用。", "卡马西平片 老年人禁忌\n老年患者对本品敏感者多,常可引起认知功能障碍、激越、不安、、\n焦虑、精神错乱、房室传导阻滞或心动过缓,也可引起再障。", "卡马西平片 适应症\n1.复杂部分性发作(亦称精神运动性发作或颞叶癫癎)、全身强直-阵孪性发作、上述两种混合性发作或其他部分性或全身性发作;对典型或不典型失神发作、肌阵孪或失神张力发作无效。2.三叉神经痛和舌咽神经痛发作,亦用作三叉神经痛缓解后的长期预防性用药。也可用于脊髓痨和多发性硬化、糖尿病性周围性神经痛、患肢痛和外伤后神经痛以及疱疹后神经痛。3.预防或治疗躁狂-抑郁症;对锂或抗精神病药或抗抑郁药无效的或不能耐受的躁狂-抑郁症,可单用或与锂盐和其它抗抑郁药合用。4.中枢性部分性尿崩症,可单用或氯磺丙脲或氯贝丁酯等合用。5.", "卡马西平片 禁忌症\n禁用:有房室传导阻滞,血清铁严重异常、骨髓抑制、严重肝功能不全等病史者。", "卡马西平片 药物用法\n成人常用量1.抗惊厥,开始一次0.1g,一日2~3次;第二日后每日增加0.1g,直到出现疗效为止;维持量根据调整至最低有效量,分次服用;注意个体化,最高量每日不超过1.2g。2.镇痛,开始一次0.1g,一日2次;第二日后每隔一日增加0.1~0.2g,直到疼痛缓解,维持量每日0.4~0.8g,分次服用;最高量每日不超过1.2g。3.尿崩症,单用时一日0.3~0.6g,如与其他抗利尿药合用,每日0.2~0.4g,分3次服用。4.抗燥狂或抗精神病,开始每日0.2~0.4g,每周逐渐增加至最大量1.6g,分3~4次服用。", "卡马西平片 不良反应\n1.较常见的不良反应是中枢神经系统的反应,表现为视力模糊、复视、眼球震颤。2.因刺激抗利尿激素分泌引起水的潴留和低钠血症(或水中毒),发生率约10~15%。3.较少见的不良反应有变态反应,Stevens-Johnson综合症或中毒性表皮坏死溶解症、皮疹、荨麻疹、瘙痒;儿童行为障碍,严重腹泻,红斑狼疮样综合症(荨麻疹、瘙痒、皮疹、发热、咽喉痛、骨或关节痛、乏力)。4.罕见的不良反应有腺体病,心律失常或房室传导阻滞(老年人尤其注意),骨髓抑制,中枢神经系统中毒(语言困难、精神不安、耳鸣、颤、幻视),过敏性肝炎,低钙血症,直接影响骨代谢导致骨质疏松,肾脏中毒,周围神经炎,急性尿紫质病,栓塞性脉管炎,过敏性肺炎,急性间歇性卟啉病,可致甲状腺功能减退。因注意有一例合并无菌性脑膜炎的肌阵孪性癫癎患者,接受本品治疗后引起脑膜炎复发。偶见粒细胞减少,可逆性血小板减少,再障,中毒性肝炎。", "卡马西平片 注意事项\n1.与三环类抗抑郁药有交叉过敏反应。2.用药期间注意检查:全血细胞检查(包括血小板、网织红细胞及血清铁,应经常复查达2~3年),尿常规,肝功能,眼科检查;卡马西平血药浓度测定。3.一般疼痛不要用本品。4.糖尿病人可能引起尿糖增加,应注意。5.癫癎患者不能突然撤药。6.已用其他抗癫癎药的病人,本品用量应逐渐递增,治疗4周后可能需要增加剂量,避免自身诱导所致血药浓度下降。7.下列情况应停药:肝中毒或骨髓抑制症状出现,心血管系统不良反应或皮疹出现。8.用于特异性疼痛综合征止痛时,如果疼痛完全缓解,应每月减量至停药。9.饭后服用可减少胃肠反应,漏服时应尽快补服,不可一次服双倍量,可一日内分次补足。10.下列情况应慎用:乙醇中毒,心脏损害,冠心病,糖尿病,青光眼,对其他药物有血液反应史者(易诱发骨髓抑制),肝病,抗利尿激素分泌异常或其他内分泌紊乱,尿潴留,肾病。\n【孕妇及哺乳期妇女用药】本品能通过胎盘,是否致畸尚不清楚,妊娠早期需慎用;本品能分泌入乳汁,约为血药浓度60%,哺乳期妇女不宜应用。", "卡马西平片 药物相互作用\n1.与对乙酰氨基酚合用,尤其是单次超量或长期大量,肝脏中毒的危险增加,有可能使后者疗效降低。2.与香豆素类抗凝药合用,由于本品的肝酶的正诱导作用,使抗凝药的血浓度降低,半衰期缩短,抗凝效应减弱,应测定凝血酶原时间而调整药量。3.与碳酸酐酶抑制药合用,骨质疏松的危险增加。4.由于本品的肝酶诱导作用,与氯磺丙脲、氯贝丁酯(安妥明)、去氨加压素(desmopressin)、赖氨加压素(lypressin)、垂体后叶素、加压素等合用,可加强抗利尿作用,合用的各药都需减量。5.与含雌激素的避孕药、环孢素、洋地黄类(可能地高辛除外)、雌激素、左旋甲状腺素或奎尼丁合用时,由于卡马西平对肝代谢酶的正诱导,这些药的效应都会降低,用量应作调整,改用仅含孕激素(黄体酮)的口服避孕药。与口服避孕药合用可能出现阴道大出血。6.与多西环素(强力霉素)合用,后者的血药浓度可能降低,必要时需要调整用量。7.红霉素与醋竹桃霉素(troleandomycin)以及右丙氧芬(detropropoxyphene)可抑制卡马西平的代谢,引起后者血药浓度的升高,出现毒性反应。8.氟哌啶醇、洛沙平、马普替林、噻吨类或三环类抗抑郁药可增强卡马西平的代谢,引起后者血药浓度升高,出现毒性反应。9.锂盐可以降低卡马西平的抗利尿作用。10.与单胺氧化酶(MAO)抑制合用,可引起高热或(和)高血压危象、严重惊厥甚至死亡,两药应用至少要间隔14天。当卡马西平用作抗惊厥剂时,MAO抑制药可以改变癫癎发作的类型。11.卡马西平可以降低诺米芬辛(nomifensine)的吸收并加快其消除。12、苯巴比妥和苯妥英加速卡马西平的代谢,可将卡马西平的t1/2降至9~10小时。", "卡马西平片 存储方式\n密封。"]} 13 | {"docs": ["头孢克肟片 英文名\nCefixime 
Tablets", "头孢克肟片 用药科室\n感染科用药", "头孢克肟片 药物分类\n化学药品", "头孢克肟片 适应症\n本品适用于敏感菌所致的咽炎、扁桃体炎、急性支气管炎和慢性支气管炎急性发作、中耳炎、尿路感染、单纯性淋病(宫颈炎或尿道炎)等。", "头孢克肟片 禁忌症\n对本品或头孢菌素类抗生素有过敏史者禁用。", "头孢克肟片 药物用法\n成人每日400mg,儿童每日8mg/kg,可单次或分2次口服。儿童体重≥50kg或年龄≥12岁时用成人剂量。治疗单纯性淋病时宜400mg单剂疗法。治疗单纯性淋病时宜400mg单剂疗法。肾功能不全的患者其肌酐清除率(CCr)为21~60ml/min并进行血液透析者给标准剂量的50%,即每日给药300mg; CCr≤20ml/min并进行腹膜透析者给标准剂量的50%,即每日给药200mg。", "头孢克肟片 不良反应\n头孢克肟不良反应大多短暂而轻微。最常见者为胃肠道反应,其中腹泻16%、大便次数增多6%、腹痛3%、恶心7%、消化不良3%、腹胀4%;发生率低于2%的不良反应有皮疹、荨麻疹、药物热、瘙痒、头痛、头昏。实验室异常表现为一过性ALT、AST、ALP、LDH、胆红素、BUN、Cr 升高,血小板和白细胞计数一过性减少和嗜酸性粒细胞增多,直接Coombs试验阳性等。", "头孢克肟片 注意事项\n1. 对头孢菌素类抗生素有过敏史者禁用。肠炎患者慎用,6月以下儿童不宜应用。过去有青霉素过敏休克病史的患者慎用本品,因亦有发生过敏性休克的可能。\n2. 肾功能不全者血清半衰期延长,须调整给药剂量。\n3. 相同剂量混悬剂与片剂服用后以前者为高。血药浓度以前者为高。\n4. 治疗化脓性链球菌感染疗程至少需10天。\n5.中耳炎患者宜用混悬剂治疗。", "头孢克肟片 药物相互作用\n1.本品与下列药物有配伍禁忌;硫酸阿米卡星、庆大霉素、卡那霉素、妥布霉素。新霉素、盐酸金霉素、盐酸四环素、盐酸土霉素、粘菌素甲磺酸钠、硫酸多粘菌素B、葡萄糖酸红霉素、乳糖酸红霉素、林可霉素、磺胺异恶唑、氨茶碱、可溶性巴比妥、氯化钙、葡萄糖酸钙、盐酸苯海拉明的其他抗组胺药、利多卡因、去甲肾上腺素、间羟胺、哌甲酯、琥珀胆碱等。偶亦可能与下列药品发生配伍禁 J忌:青霉素、甲氧西林、琥珀酸氢化可的松、苯妥英钠。丙氯拉嗪(Prochlorperazine)、维生素B族和维生素C、水解蛋白。\n2.呋塞米、依他尼酸、布美他尼等强利尿药,卡氮芥、链佐星(streptozocin)等抗肿瘤药以及氨基糖苷类抗生素与本品合用有增加肾毒性的可能\n3. 棒酸可增加本品对某些因产生B内酰胺酶而对之耐药的革兰氏阴性杆菌的抗菌活性。", "头孢克肟片 存储方式\n密封,置阴凉干燥处。"]} 14 | {"docs": ["小儿善存片 英文名\nCentrum Junior Tablets", "小儿善存片 用药科室\n儿科用药", "小儿善存片 药物分类\n化学药品", "小儿善存片 适应症\n本品为维生素及矿物质药品。用于3—12岁儿童维生素和矿物质的补充。", "小儿善存片 禁忌症\n慢性肾功能衰竭、高钙血症、高磷血症伴肾性佝偻病患者禁用。苯丙酮尿症患者禁用。", "小儿善存片 药物用法\n口服,一日1片", "小儿善存片 不良反应\n偶见胃部不适。", "小儿善存片 注意事项\n1  严格按规定的剂量服用,需要大量服用时,请咨询医师或药师。\n2  如服用过量或出现严重不良反应,应立即就医。\n3  对本品过敏者禁用,过敏体质者慎用。\n4  本品性状发生改变时禁止使用。\n5  请将本品放在儿童不能接触的地方。\n6  儿童必须在成人监护下使用。\n7  如正在使用其他药品,使用本品前请咨询医师或药师。", "小儿善存片 药物相互作用\n1.抗酸药可影响本品中维生素A的吸收,故不应同服。\n2.不应与含有大量镁、钙的药物合用,以免引起高镁、高钙血症。\n3.如与其他药物同时使用可能会发生药物相互作用,详情请咨询医师或药师。", "小儿善存片 存储方式\n遮光密封。"]} 15 | {"docs": ["木糖醇颗粒 英文名\nXylitol Granules", "木糖醇颗粒 用药科室\n内分泌科用药", "木糖醇颗粒 药物分类\n化学药品", "木糖醇颗粒 药物剂量\n10g:9.85g", "木糖醇颗粒 适应症\n用作糖尿病患者的糖的代用品。", "木糖醇颗粒 禁忌症\n1 对本品过敏者禁用.\n2 低血糖患者禁用.", "木糖醇颗粒 药物用法\n口服。成人一次1袋,一日3-5次。", "木糖醇颗粒 不良反应\n初服时可有肠鸣、腹胀、腹泻等症状。适当减少剂量,可减少不良反应。", "木糖醇颗粒 注意事项\n1 儿童用量请咨询医师或药师。2 本品不宜过量服用。3 对本品过敏者禁用,过敏体质者慎用。4 本品性状发生改变时禁止使用。5 请将本品放在儿童不能接触的地方。6 儿童必须在成人监护下使用。7 如正在使用其他药品,使用本品前请咨询医师或药师。", "木糖醇颗粒 药物相互作用\n如与其他药物同时使用可能会发生药物相互作用,详情请咨询医师或药师。", "木糖醇颗粒 存储方式\n密闭,阴凉干燥处保存."]} -------------------------------------------------------------------------------- /demo_data/lora_demo.xlsx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JuneYaooo/medical_kb_chatbot/be4ff60c47ffb31ac3052f82b5ed607d7c7ab089/demo_data/lora_demo.xlsx -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: kb_chat 2 | channels: 3 | - defaults 4 | dependencies: 5 | - _libgcc_mutex=0.1=main 6 | - _openmp_mutex=5.1=1_gnu 7 | - ca-certificates=2023.05.30=h06a4308_0 8 | - ld_impl_linux-64=2.38=h1181459_1 9 | - libffi=3.4.4=h6a678d5_0 10 | - libgcc-ng=11.2.0=h1234567_1 11 | - libgomp=11.2.0=h1234567_1 12 | - libstdcxx-ng=11.2.0=h1234567_1 13 | - ncurses=6.4=h6a678d5_0 14 | - openssl=3.0.9=h7f8727e_0 15 | - pip=23.1.2=py39h06a4308_0 16 | - python=3.9.16=h955ad1f_3 17 | - readline=8.2=h5eee18b_0 18 | - setuptools=67.8.0=py39h06a4308_0 19 | - sqlite=3.41.2=h5eee18b_0 20 | - tk=8.6.12=h1ccaba5_0 21 | - wheel=0.38.4=py39h06a4308_0 22 | - xz=5.4.2=h5eee18b_0 23 | - zlib=1.2.13=h5eee18b_0 24 | - pip: 25 | - accelerate==0.18.0 26 | - aiofiles==23.1.0 27 | - aiohttp==3.8.4 28 | - aiosignal==1.3.1 29 | - altair==5.0.1 30 | - antlr4-python3-runtime==4.9.3 31 | - 
anyio==3.7.0 32 | - argilla==1.11.0 33 | - astor==0.8.1 34 | - async-timeout==4.0.2 35 | - attrdict==2.0.1 36 | - attrs==23.1.0 37 | - azure-core==1.27.1 38 | - babel==2.12.1 39 | - backoff==2.2.1 40 | - bce-python-sdk==0.8.83 41 | - beautifulsoup4==4.12.2 42 | - bitsandbytes==0.39.1 43 | - blinker==1.6.2 44 | - brotli==1.0.9 45 | - cachetools==5.3.1 46 | - certifi==2023.5.7 47 | - cffi==1.15.1 48 | - chardet==5.1.0 49 | - charset-normalizer==3.1.0 50 | - click==8.1.3 51 | - cmake==3.26.4 52 | - coloredlogs==15.0.1 53 | - commonmark==0.9.1 54 | - contourpy==1.1.0 55 | - cpm-kernels==1.0.11 56 | - cryptography==41.0.1 57 | - cssselect==1.2.0 58 | - cssutils==2.7.1 59 | - cycler==0.11.0 60 | - cython==0.29.35 61 | - dataclasses-json==0.5.8 62 | - decorator==5.1.1 63 | - deprecated==1.2.14 64 | - dill==0.3.6 65 | - effdet==0.4.1 66 | - et-xmlfile==1.1.0 67 | - exceptiongroup==1.1.1 68 | - faiss-cpu==1.7.4 69 | - fastapi==0.95.1 70 | - ffmpy==0.3.0 71 | - filelock==3.12.2 72 | - filetype==1.2.0 73 | - fire==0.5.0 74 | - flask==2.3.2 75 | - flask-babel==3.1.0 76 | - flatbuffers==23.5.26 77 | - fonttools==4.40.0 78 | - frozenlist==1.3.3 79 | - fsspec==2023.6.0 80 | - future==0.18.3 81 | - gevent==22.10.2 82 | - geventhttpclient==2.0.2 83 | - gradio==3.28.3 84 | - gradio-client==0.2.7 85 | - greenlet==2.0.2 86 | - grpcio==1.56.0 87 | - h11==0.14.0 88 | - httpcore==0.16.3 89 | - httpx==0.23.3 90 | - huggingface-hub==0.15.1 91 | - humanfriendly==10.0 92 | - icetk==0.0.7 93 | - idna==3.4 94 | - imageio==2.31.1 95 | - imgaug==0.4.0 96 | - importlib-metadata==6.7.0 97 | - importlib-resources==5.12.0 98 | - iopath==0.1.10 99 | - itsdangerous==2.1.2 100 | - jinja2==3.1.2 101 | - joblib==1.2.0 102 | - jsonschema==4.17.3 103 | - kiwisolver==1.4.4 104 | - langchain==0.0.174 105 | - layoutparser==0.3.4 106 | - lazy-loader==0.2 107 | - linkify-it-py==2.0.2 108 | - lit==16.0.6 109 | - lmdb==1.4.1 110 | - lxml==4.9.2 111 | - markdown==3.4.3 112 | - markdown-it-py==2.2.0 113 | - markupsafe==2.1.3 114 | - marshmallow==3.19.0 115 | - marshmallow-enum==1.5.1 116 | - matplotlib==3.7.1 117 | - mdit-py-plugins==0.3.3 118 | - mdurl==0.1.2 119 | - monotonic==1.6 120 | - mpmath==1.3.0 121 | - msg-parser==1.2.0 122 | - multidict==6.0.4 123 | - multiprocess==0.70.14 124 | - mypy-extensions==1.0.0 125 | - networkx==3.1 126 | - nltk==3.8.1 127 | - numexpr==2.8.4 128 | - numpy==1.23.5 129 | - nvidia-cublas-cu11==11.10.3.66 130 | - nvidia-cuda-cupti-cu11==11.7.101 131 | - nvidia-cuda-nvrtc-cu11==11.7.99 132 | - nvidia-cuda-runtime-cu11==11.7.99 133 | - nvidia-cudnn-cu11==8.5.0.96 134 | - nvidia-cufft-cu11==10.9.0.58 135 | - nvidia-curand-cu11==10.2.10.91 136 | - nvidia-cusolver-cu11==11.4.0.1 137 | - nvidia-cusparse-cu11==11.7.4.91 138 | - nvidia-nccl-cu11==2.14.3 139 | - nvidia-nvtx-cu11==11.7.91 140 | - olefile==0.46 141 | - omegaconf==2.3.0 142 | - onnx==1.12.0 143 | - onnxruntime==1.15.1 144 | - openapi-schema-pydantic==1.2.4 145 | - opencv-contrib-python==4.6.0.66 146 | - opencv-python==4.6.0.66 147 | - openpyxl==3.1.2 148 | - opt-einsum==3.3.0 149 | - orjson==3.9.1 150 | - packaging==23.1 151 | - paddle-bfloat==0.1.7 152 | - paddleocr==2.6.1.3 153 | - paddlepaddle==2.4.2 154 | - pandas==1.5.3 155 | - pdf2docx==0.5.6 156 | - pdf2image==1.16.3 157 | - pdfminer-six==20221105 158 | - pdfplumber==0.9.0 159 | - peft==0.3.0 160 | - pillow==9.5.0 161 | - portalocker==2.7.0 162 | - premailer==3.10.0 163 | - protobuf==3.18.3 164 | - psutil==5.9.5 165 | - pyclipper==1.3.0.post4 166 | - pycocotools==2.0.6 167 | - pycparser==2.21 
168 | - pycryptodome==3.18.0 169 | - pydantic==1.10.9 170 | - pydub==0.25.1 171 | - pygments==2.15.1 172 | - pymupdf==1.20.2 173 | - pynvml==11.5.0 174 | - pypandoc==1.11 175 | - pyparsing==3.1.0 176 | - pypinyin==0.48.0 177 | - pyrsistent==0.19.3 178 | - pytesseract==0.3.10 179 | - python-dateutil==2.8.2 180 | - python-docx==0.8.11 181 | - python-magic==0.4.27 182 | - python-multipart==0.0.6 183 | - python-pptx==0.6.21 184 | - python-rapidjson==1.10 185 | - pytz==2023.3 186 | - pywavelets==1.4.1 187 | - pyyaml==6.0 188 | - rapidfuzz==3.1.1 189 | - rarfile==4.0 190 | - regex==2023.6.3 191 | - requests==2.28.2 192 | - rfc3986==1.5.0 193 | - rich==13.0.1 194 | - ruamel-yaml==0.17.32 195 | - ruamel-yaml-clib==0.2.7 196 | - safetensors==0.3.1 197 | - scikit-image==0.21.0 198 | - scikit-learn==1.2.2 199 | - scipy==1.11.0 200 | - semantic-version==2.10.0 201 | - sentence-transformers==2.2.2 202 | - sentencepiece==0.1.99 203 | - shapely==2.0.1 204 | - six==1.16.0 205 | - sniffio==1.3.0 206 | - soupsieve==2.4.1 207 | - sqlalchemy==2.0.17 208 | - starlette==0.26.1 209 | - sympy==1.12 210 | - tabulate==0.9.0 211 | - tenacity==8.2.2 212 | - termcolor==2.3.0 213 | - threadpoolctl==3.1.0 214 | - tifffile==2023.4.12 215 | - timm==0.9.2 216 | - tokenizers==0.13.3 217 | - toolz==0.12.0 218 | - torch==2.0.1 219 | - torchvision==0.15.2 220 | - tqdm==4.65.0 221 | - transformers==4.29.1 222 | - triton==2.0.0 223 | - tritonclient==2.34.0 224 | - typer==0.9.0 225 | - typing-extensions==4.6.3 226 | - typing-inspect==0.9.0 227 | - tzdata==2023.3 228 | - uc-micro-py==1.0.2 229 | - unstructured==0.7.9 230 | - unstructured-inference==0.5.1 231 | - urllib3==1.26.16 232 | - uvicorn==0.21.1 233 | - visualdl==2.5.0 234 | - wand==0.6.11 235 | - websockets==11.0.3 236 | - werkzeug==2.3.6 237 | - wrapt==1.14.1 238 | - x2paddle==1.4.1 239 | - xlrd==2.0.1 240 | - xlsxwriter==3.1.2 241 | - yarl==1.9.2 242 | - zipp==3.15.0 243 | - zope-event==5.0 244 | - zope-interface==6.0 245 | -------------------------------------------------------------------------------- /finetune/pulse/configs/lora_config_bloom.json: -------------------------------------------------------------------------------- 1 | { 2 | "lora_r": 8, 3 | "lora_alpha": 32, 4 | "lora_dropout": 0.05, 5 | "lora_target_modules": [ 6 | "query_key_value" 7 | ] 8 | } 9 | -------------------------------------------------------------------------------- /finetune/pulse/convert_to_conv_data.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import datetime 4 | import os 5 | 6 | ''' 7 | orig_data: [ 8 | { 9 | "instruction": "##结构化任务##根据下文中信息,判断主动脉根部内径是什么?请提取文中对应的值", 10 | "input": "检查途径:经体表 图像等级:乙 检查项目:二维 M型 彩色 多普勒(脉冲式 连续式)一、M型主要测值(单位mm): 名称 测量值 正常参考值 主动脉根部内径 39 20-37 左房内径 36 19-40 左室舒张末期内径 47 35-56 左室收缩末期内径 28 20-37 室间隔厚度 11 6-11 左室后壁厚度 10 6-11 二、二维超声心动图描述:1.各房室无明显扩大,主动脉根部稍增宽。2.各心瓣膜未见明显增厚,开放不受限。3.左室壁不增厚,静息状态下左室各节段收缩活动未见明显异常。三、彩色多普勒超声描述:1.各心瓣膜未见明显异常反流。2.舒张期经二尖瓣口血流频谱:E/A<1。舒张期经二尖瓣口血流:E=54cm/s,A=84cm/s。3.房、室间隔水平未见明显异常分流信号。四、左心功能测定: 名称 测量值 左室舒张末期容量(ml) 102 左室收缩末期容量(ml) 29 每搏输出量(ml) 74 左室短轴缩短率(%) 41 左室射血分数(%) 72 五、组织多普勒检查: 二尖瓣瓣环水平:室间隔侧: e'=6cm/s, E/e'=9。 左室侧壁: e'=9cm/s, E/e'=6。", 11 | "output": "39" 12 | }, 13 | { 14 | "instruction": "##结构化任务##根据下文中信息,判断主动脉根部内径是什么?请提取文中对应的值", 15 | "input": "检查途径:经体表 图像等级:乙 检查项目:二维 M型 彩色 多普勒(脉冲式 连续式)一、M型主要测值(单位mm): 名称 测量值 正常参考值 主动脉根部内径 38 20-37 左房内径 44 19-40 左室舒张末期内径 52 35-56 左室收缩末期内径 33 20-37 室间隔厚度 13 6-11 左室后壁厚度 13 6-11 
二、二维超声心动图描述:1.左房增大,主动脉根部内径增宽。2.各心瓣膜未见明显增厚,开放不受限。3.左室壁增厚,静息状态下左室各节段收缩活动未见明显异常。三、彩色多普勒超声描述:1.各心瓣膜示轻度微主动脉瓣反流。2.舒张期经二尖瓣口血流频谱:E/A<1。舒张期经二尖瓣口血流:E=63cm/s,A=91cm/s。3.房、室间隔水平未见明显异常分流信号。四、左心功能测定: 名称 测量值 左室舒张末期容量(ml) 129 左室收缩末期容量(ml) 45 每搏输出量(ml) 84 左室短轴缩短率(%) 36 左室射血分数(%) 65 五、组织多普勒检查: 二尖瓣瓣环水平:室间隔侧: e'=5cm/s, E/e'=12.6。 左室侧壁: e'=13cm/s, E/e'=4.9。", 16 | "output": "38" 17 | }] 18 | convert: { 19 | "id": xxx, 20 | "conversations":[ 21 | {"from": "human", "value": "题目:小明买了一支钢笔,花费了5元,又买了一本书,花费8元,现在他手里还有10元钱,他手上原来有多少钱?"}, 22 | {"from": "assistant", "value": "\n令小明手上原来有的钱为X元。根据题目描述,得出以下方程式:\nX - 5 - 8 = 10\n化简可得:\nX = 23\n因此,小明手上原来有23元钱。"}, 23 | ] 24 | } 25 | ''' 26 | 27 | def main(): 28 | parser = argparse.ArgumentParser() 29 | parser.add_argument( 30 | "--orig_data", 31 | ) 32 | parser.add_argument( 33 | "--write_data", 34 | ) 35 | parser.add_argument( 36 | "--dataset_name", 37 | ) 38 | args = parser.parse_args() 39 | f_write = open(args.write_data,"w") 40 | with open(args.orig_data, 'r', encoding="utf-8") as f: 41 | datas = json.load(f) 42 | num_id = 1 43 | for data in datas: 44 | conversations = [{"from": "human", "value": data['instruction']+'\n'+data['input']},{"from": "assistant", "value": data['output']}] 45 | uniq_id = data['id'] if "id" in data else args.dataset_name+"-"+str(num_id) 46 | item = {"id":uniq_id, "conversations": conversations} 47 | f_write.write(json.dumps(item, ensure_ascii=False)+"\n") 48 | num_id += 1 49 | f_write.close() 50 | 51 | 52 | if __name__ == "__main__": 53 | main() -------------------------------------------------------------------------------- /finetune/pulse/finetune.py: -------------------------------------------------------------------------------- 1 | 2 | from transformers.utils import add_start_docstrings 3 | from transformers.trainer_utils import get_last_checkpoint 4 | from transformers.trainer_pt_utils import torch_distributed_zero_first 5 | from transformers import (AutoModelForCausalLM, AutoTokenizer, 6 | HfArgumentParser, LlamaTokenizer, TrainingArguments, 7 | set_seed) 8 | from peft import LoraConfig, get_peft_model, prepare_model_for_int8_training 9 | from datasets import load_dataset 10 | import transformers 11 | import torch 12 | from typing import Optional 13 | from functools import partial 14 | from dataclasses import dataclass, field 15 | import os 16 | import math 17 | import logging 18 | import json 19 | import sys 20 | 21 | from src.utils import get_model_param_count 22 | from src.trainer import MyTrainer as Trainer 23 | from src.sample_generator import generate_and_tokenize_prompt 24 | 25 | 26 | logger = logging.getLogger(__name__) 27 | 28 | 29 | @dataclass 30 | class ModelArguments: 31 | """ 32 | Arguments pertaining to which model/config/tokenizer we are going to fine-tune, or train from scratch. 33 | """ 34 | 35 | model_name_or_path: Optional[str] = field( 36 | default=None, 37 | metadata={ 38 | "help": ( 39 | "The model checkpoint for weights initialization.Don't set if you want to train a model from scratch." 
40 | ) 41 | }, 42 | ) 43 | config_name: Optional[str] = field( 44 | default=None, 45 | metadata={ 46 | "help": "Pretrained config name or path if not the same as model_name" 47 | }, 48 | ) 49 | tokenizer_name: Optional[str] = field( 50 | default=None, 51 | metadata={ 52 | "help": "Pretrained tokenizer name or path if not the same as model_name" 53 | }, 54 | ) 55 | cache_dir: Optional[str] = field( 56 | default=None, 57 | metadata={ 58 | "help": "Where do you want to store the pretrained models downloaded from huggingface.co" 59 | }, 60 | ) 61 | use_fast_tokenizer: bool = field( 62 | default=True, 63 | metadata={ 64 | "help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not." 65 | }, 66 | ) 67 | torch_dtype: Optional[str] = field( 68 | default=None, 69 | metadata={ 70 | "help": ( 71 | "Override the default `torch.dtype` and load the model under this dtype. If `auto` is passed, the " 72 | "dtype will be automatically derived from the model's weights." 73 | ), 74 | "choices": ["auto", "bfloat16", "float16", "float32"], 75 | }, 76 | ) 77 | llama: bool = field(default=False, metadata={"help": "Llama model"}) 78 | 79 | 80 | @dataclass 81 | class DataArguments: 82 | """ 83 | Arguments pertaining to what data we are going to input our model for training and eval. 84 | """ 85 | 86 | dataset_name: Optional[str] = field( 87 | default=None, 88 | metadata={ 89 | "help": "The name of the dataset to use (via the datasets library)."}, 90 | ) 91 | dataset_config_name: Optional[str] = field( 92 | default=None, 93 | metadata={ 94 | "help": "The configuration name of the dataset to use (via the datasets library)." 95 | }, 96 | ) 97 | train_file: Optional[str] = field( 98 | default=None, metadata={"help": "The input training data file (a text file)."} 99 | ) 100 | validation_file: Optional[str] = field( 101 | default=None, 102 | metadata={ 103 | "help": "An optional input evaluation data file to evaluate the perplexity on (a text file)." 
104 | }, 105 | ) 106 | validation_split_percentage: Optional[int] = field( 107 | default=5, 108 | metadata={ 109 | "help": "The percentage of the train set used as validation set in case there's no validation split" 110 | }, 111 | ) 112 | 113 | 114 | @dataclass 115 | @add_start_docstrings(TrainingArguments.__doc__) 116 | class TrainingArguments(TrainingArguments): 117 | model_max_length: int = field( 118 | default=512, 119 | metadata={"help": "Maximum sequence length."}, 120 | ) 121 | use_lora: bool = field( 122 | default=False, 123 | metadata={"help": "Whether to use LoRA."} 124 | ) 125 | use_int8_training: bool = field( 126 | default=False, metadata={"help": "Whether to use int8 training."} 127 | ) 128 | lora_config: Optional[str] = field( 129 | default=None, 130 | metadata={"help": "LoRA config file."}, 131 | ) 132 | ddp_find_unused_parameters: bool = field( 133 | default=False, metadata={"help": "ddp_find_unused_parameters"} 134 | ) 135 | gradient_checkpointing: bool = field( 136 | default=False, metadata={"help": "gradient_checkpointing"} 137 | ) 138 | # https://discuss.huggingface.co/t/wandb-does-not-display-train-eval-loss-except-for-last-one/9170 139 | evaluation_strategy: str = field( 140 | default="steps", metadata={"help": "wandb bug fix"} 141 | ) 142 | save_total_limit: Optional[int] = field( 143 | default=3, 144 | metadata={ 145 | "help": "keep saved model less than save_total_limit, delete old checkpoints when save new model"} 146 | ) 147 | report_to: str = field( 148 | default="none", 149 | metadata={"help": "places where report the results"} 150 | ) 151 | 152 | 153 | def print_rank_0(msg, log_file, rank=0): 154 | if rank <= 0: 155 | with open(log_file, "a") as f: 156 | print(msg) 157 | f.write(msg + "\n") 158 | 159 | 160 | def main(): 161 | parser = HfArgumentParser( 162 | (ModelArguments, DataArguments, TrainingArguments) 163 | ) 164 | model_args, data_args, training_args = parser.parse_args_into_dataclasses() 165 | 166 | world_size = int(os.environ.get("WORLD_SIZE", 1)) 167 | global_rank = torch.distributed.get_rank() 168 | if not os.path.exists(training_args.output_dir): 169 | os.makedirs(training_args.output_dir) 170 | log_file = os.path.join(training_args.output_dir, "print_log.txt") 171 | 172 | # Setup logging 173 | logging.basicConfig( 174 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 175 | datefmt="%m/%d/%Y %H:%M:%S", 176 | handlers=[logging.StreamHandler(sys.stdout)], 177 | ) 178 | 179 | if training_args.should_log: 180 | # The default of training_args.log_level is passive, so we set log level at info here to have that default. 181 | transformers.utils.logging.set_verbosity_info() 182 | 183 | log_level = training_args.get_process_log_level() 184 | logger.setLevel(log_level) 185 | transformers.utils.logging.set_verbosity(log_level) 186 | transformers.utils.logging.enable_default_handler() 187 | transformers.utils.logging.enable_explicit_format() 188 | 189 | # Log on each process the small summary: 190 | logger.warning( 191 | f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}, distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}" 192 | ) 193 | logger.info(f"Training/evaluation parameters {training_args}") 194 | 195 | # Detecting last checkpoint. 
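# Hypothetical launch sketch (the torchrun command and every value below are
# assumptions for illustration, not commands taken from this repo's docs).
# HfArgumentParser exposes each field of ModelArguments / DataArguments /
# TrainingArguments above as a command-line flag, so a LoRA fine-tuning run on
# conversation-format jsonl data (as produced by convert_to_conv_data.py) might
# look like:
#
#   torchrun --nproc_per_node=2 finetune.py \
#       --model_name_or_path OpenMEDLab/PULSE-7bv5 \
#       --train_file data/train.jsonl \
#       --validation_file data/val.jsonl \
#       --output_dir output/pulse_lora \
#       --do_train \
#       --num_train_epochs 3 \
#       --per_device_train_batch_size 4 \
#       --gradient_accumulation_steps 8 \
#       --model_max_length 1024 \
#       --use_lora True \
#       --lora_config configs/lora_config_bloom.json
#
# (resuming the original code: look for an existing checkpoint in output_dir)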
196 | last_checkpoint = None 197 | if ( 198 | os.path.isdir(training_args.output_dir) 199 | and training_args.do_train 200 | and not training_args.overwrite_output_dir 201 | ): 202 | last_checkpoint = get_last_checkpoint(training_args.output_dir) 203 | if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0: 204 | raise ValueError( 205 | f"Output directory ({training_args.output_dir}) already exists and is not empty. " 206 | "Use --overwrite_output_dir to overcome." 207 | ) 208 | elif ( 209 | last_checkpoint is not None and training_args.resume_from_checkpoint is None 210 | ): 211 | logger.info( 212 | f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change " 213 | "the `--output_dir` or add `--overwrite_output_dir` to train from scratch." 214 | ) 215 | 216 | # Set seed before initializing model. 217 | set_seed(training_args.seed) 218 | 219 | torch_dtype = ( 220 | model_args.torch_dtype 221 | if model_args.torch_dtype in ["auto", None] 222 | else getattr(torch, model_args.torch_dtype) 223 | ) 224 | # int8 is not compatible with DeepSpeed (require not to pass device_map) 225 | if training_args.use_int8_training: 226 | print_rank_0( 227 | "int8 is not compatible with DeepSpeed. ", 228 | log_file, 229 | global_rank 230 | ) 231 | device_map = ( 232 | {"": int(os.environ.get("LOCAL_RANK") or 0)} 233 | if world_size != 1 else "auto" 234 | ) 235 | # device_map = "auto" 236 | model = AutoModelForCausalLM.from_pretrained( 237 | model_args.model_name_or_path, 238 | load_in_8bit=True, # xxx: int8 load in 239 | device_map=device_map, # xxx: int8 requires passing device_map 240 | torch_dtype=torch_dtype, 241 | ) 242 | else: 243 | model = AutoModelForCausalLM.from_pretrained( 244 | model_args.model_name_or_path, 245 | torch_dtype=torch_dtype, 246 | ) 247 | 248 | if model_args.llama: 249 | tokenizer = LlamaTokenizer.from_pretrained( 250 | model_args.model_name_or_path 251 | ) 252 | print_rank_0( 253 | "Set the eos_token_id and bos_token_id of LLama model tokenizer", 254 | log_file, 255 | global_rank, 256 | ) 257 | tokenizer.eos_token_id = 2 258 | tokenizer.bos_token_id = 1 259 | else: 260 | tokenizer = AutoTokenizer.from_pretrained( 261 | model_args.model_name_or_path 262 | ) 263 | 264 | tokenizer.pad_token_id = 0 265 | tokenizer.padding_side = "left" # Allow batched inference 266 | 267 | print_rank_0( 268 | "tokenizer.eos_token_id = {}".format(tokenizer.eos_token_id), 269 | log_file, 270 | global_rank, 271 | ) 272 | print_rank_0( 273 | "tokenizer.pad_token_id = {}".format(tokenizer.pad_token_id), 274 | log_file, 275 | global_rank, 276 | ) 277 | print_rank_0( 278 | "tokenizer.bos_token_id = {}".format(tokenizer.bos_token_id), 279 | log_file, 280 | global_rank, 281 | ) 282 | 283 | # peft model 284 | if training_args.use_lora: 285 | print_rank_0( 286 | "Loading lora config from {}".format(training_args.lora_config), 287 | log_file, 288 | global_rank, 289 | ) 290 | lora_config = json.load(open(training_args.lora_config)) 291 | print_rank_0( 292 | "Lora config: {}".format(lora_config), 293 | log_file, 294 | global_rank 295 | ) 296 | if training_args.use_int8_training: 297 | print_rank_0( 298 | "training_args.use_int8_training!!! 
(int8 is not compatible with DeepSpeed)", 299 | log_file, 300 | global_rank, 301 | ) 302 | model = prepare_model_for_int8_training(model) 303 | config = LoraConfig( 304 | r=lora_config["lora_r"], 305 | lora_alpha=lora_config["lora_alpha"], 306 | target_modules=lora_config["lora_target_modules"], 307 | lora_dropout=lora_config["lora_dropout"], 308 | bias="none", 309 | task_type="CAUSAL_LM", 310 | ) 311 | 312 | # "RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn" 313 | if hasattr(model, "enable_input_require_grads"): 314 | model.enable_input_require_grads() 315 | else: 316 | 317 | def make_inputs_require_grad(module, input, output): 318 | output.requires_grad_(True) 319 | 320 | model.get_input_embeddings().register_forward_hook(make_inputs_require_grad) 321 | 322 | model = get_peft_model(model, config) 323 | model.print_trainable_parameters() 324 | 325 | if training_args.gradient_checkpointing: 326 | model.gradient_checkpointing_enable() 327 | 328 | # model.is_parallelizable = True 329 | # model.model_parallel = True 330 | 331 | assert os.path.exists(data_args.train_file), "{} file not exists".format( 332 | data_args.train_file 333 | ) 334 | 335 | with torch_distributed_zero_first(global_rank): 336 | train_data = load_dataset( 337 | "json", 338 | data_files=data_args.train_file, 339 | cache_dir=model_args.cache_dir 340 | ) 341 | 342 | val_data = load_dataset( 343 | "json", 344 | data_files=data_args.validation_file, 345 | cache_dir=model_args.cache_dir 346 | ) 347 | 348 | train_data = train_data["train"].shuffle().map( 349 | partial( 350 | generate_and_tokenize_prompt, 351 | training_args.model_max_length, 352 | tokenizer 353 | ) 354 | ) 355 | 356 | val_data = val_data["train"].shuffle().map( 357 | partial( 358 | generate_and_tokenize_prompt, 359 | training_args.model_max_length, 360 | tokenizer 361 | ) 362 | ) 363 | 364 | for i in range(2): 365 | print_rank_0( 366 | "Eval tokenized example: {}".format(val_data[i]), 367 | log_file, 368 | global_rank 369 | ) 370 | for i in range(2): 371 | print_rank_0( 372 | "Train tokenized example: {}".format(train_data[i]), 373 | log_file, 374 | global_rank 375 | ) 376 | 377 | training_nums = len(train_data) 378 | num_gpus = torch.cuda.device_count() 379 | 380 | batch_size = ( 381 | training_args.per_device_train_batch_size 382 | * training_args.world_size 383 | * training_args.gradient_accumulation_steps 384 | ) 385 | # train steps 386 | t_total = math.ceil(training_nums / batch_size) * \ 387 | training_args.num_train_epochs 388 | # eval steps 389 | training_args.eval_steps = max(t_total // 5, 5) 390 | # save steps 391 | training_args.save_steps = training_args.eval_steps 392 | training_args.warmup_steps = ( 393 | int(t_total * training_args.warmup_ratio) 394 | if training_args.warmup_ratio > 0.0 395 | else training_args.warmup_steps 396 | ) 397 | print_rank_0( 398 | "num_gpus = {}, training_nums = {}, t_total = {}, warmup_steps = {}, eval_steps = {}, save_steps = {}".format( 399 | num_gpus, 400 | training_nums, 401 | t_total, 402 | training_args.warmup_steps, 403 | training_args.eval_steps, 404 | training_args.save_steps, 405 | ), 406 | log_file, 407 | global_rank, 408 | ) 409 | print_rank_0( 410 | "val data nums = {}, training_nums = {}, batch_size = {}".format( 411 | len(val_data), training_nums, batch_size 412 | ), 413 | log_file, 414 | global_rank, 415 | ) 416 | 417 | # Trainer 418 | # https://github.com/huggingface/transformers/blob/main/src/transformers/training_args.py 419 | # 
https://github.com/huggingface/transformers/blob/main/src/transformers/data/data_collator.py 420 | # https://github.com/huggingface/transformers/blob/main/src/transformers/trainer.py 421 | # https://www.deepspeed.ai/docs/config-json/ 422 | # https://huggingface.co/docs/accelerate/usage_guides/deepspeed 423 | # https://huggingface.co/transformers/v4.10.1/main_classes/deepspeed.html 424 | # https://github.com/tatsu-lab/stanford_alpaca/issues/176 425 | trainer = Trainer( 426 | model=model, 427 | args=training_args, 428 | train_dataset=train_data, 429 | eval_dataset=val_data, 430 | data_collator=transformers.DataCollatorForSeq2Seq( 431 | tokenizer, pad_to_multiple_of=8, return_tensors="pt", padding=True 432 | ), 433 | ) 434 | 435 | print_rank_0( 436 | f"Using {training_args.half_precision_backend} half precision backend", 437 | log_file, 438 | global_rank, 439 | ) 440 | # Train! 441 | len_dataloader = len(trainer.get_train_dataloader()) 442 | num_update_steps_per_epoch = ( 443 | len_dataloader // training_args.gradient_accumulation_steps 444 | ) 445 | 446 | total_train_batch_size = ( 447 | training_args.train_batch_size 448 | * training_args.gradient_accumulation_steps 449 | * training_args.world_size 450 | ) 451 | num_examples = trainer.num_examples(trainer.get_train_dataloader()) 452 | num_train_samples = num_examples * training_args.num_train_epochs 453 | max_steps = math.ceil(training_args.num_train_epochs * \ 454 | num_update_steps_per_epoch) 455 | print_rank_0("***** Running training *****", log_file, global_rank) 456 | print_rank_0(f" Num examples = {num_examples}", log_file, global_rank) 457 | print_rank_0( 458 | f" Num train samples = {num_train_samples}", 459 | log_file, 460 | global_rank 461 | ) 462 | print_rank_0(f" world_size = {world_size}", log_file, global_rank) 463 | print_rank_0( 464 | f" Total train batch size (w. parallel, distributed & accumulation) = {total_train_batch_size}", 465 | log_file, 466 | global_rank, 467 | ) 468 | print_rank_0( 469 | f" Gradient Accumulation steps = {training_args.gradient_accumulation_steps}", 470 | log_file, 471 | global_rank, 472 | ) 473 | print_rank_0( 474 | f" Total optimization steps = {max_steps}", 475 | log_file, 476 | global_rank 477 | ) 478 | 479 | print_rank_0( 480 | f" Number of trainable parameters = {get_model_param_count(model, trainable_only=True)}", 481 | log_file, 482 | global_rank, 483 | ) 484 | 485 | # https://discuss.huggingface.co/t/what-is-the-purpose-of-use-cache-in-decoder/958/3 486 | model.config.use_cache = False 487 | 488 | trainer.train(resume_from_checkpoint=None) 489 | 490 | print_rank_0( 491 | "\n Training completed!!! 
If there's a warning about missing keys above, please disregard :)", 492 | log_file, 493 | global_rank, 494 | ) 495 | 496 | 497 | if __name__ == "__main__": 498 | main() 499 | -------------------------------------------------------------------------------- /finetune/pulse/src/sample_generator.py: -------------------------------------------------------------------------------- 1 | import pudb 2 | import copy 3 | from transformers import PreTrainedTokenizer 4 | import json 5 | IGNORE_INDEX = -100 6 | 7 | 8 | def generate_and_tokenize_prompt(model_max_length: int, tokenizer: PreTrainedTokenizer, data_point): 9 | input_ids = [] 10 | labels = [] 11 | source = data_point["conversations"] 12 | for sentence in source: 13 | sentence_from = sentence["from"].lower() 14 | sentence_value = ( 15 | "Human: \n" + sentence["value"] + "\n\nAssistant: \n" 16 | if sentence_from == "human" 17 | else sentence["value"] 18 | ) # https://github.com/LianjiaTech/BELLE/issues/337 19 | # conversation += sentence_value 20 | sentence_ids = tokenizer.encode( 21 | sentence_value, add_special_tokens=False 22 | ) # do not add bos_token_id 23 | label = ( 24 | copy.deepcopy(sentence_ids) 25 | if sentence_from != "human" 26 | else [IGNORE_INDEX] * len(sentence_ids) 27 | ) 28 | input_ids += sentence_ids 29 | labels += label 30 | # add eos at every end of assistant sentence 31 | if sentence_from != "human": 32 | input_ids += [ 33 | tokenizer.eos_token_id 34 | ] # make sure eos_token_id is correct 35 | labels += [tokenizer.eos_token_id] 36 | 37 | input_ids = input_ids[: model_max_length - 1] 38 | labels = labels[: model_max_length - 1] 39 | if all(x == IGNORE_INDEX for x in labels): 40 | labels[18:24] = input_ids[ 41 | 18:24 42 | ] # labels can not have all values being -100. 18 and 24 are just random numbers 43 | 44 | attention_mask = [1] * len(input_ids) 45 | tokenized_full_prompt = { 46 | "input_ids": input_ids, 47 | "attention_mask": attention_mask, 48 | "labels": labels, 49 | } 50 | return tokenized_full_prompt 51 | 52 | 53 | def pretrain_generate(model_max_length: int, tokenizer: PreTrainedTokenizer, data_point): 54 | input_ids = tokenizer.encode(data_point['text']) 55 | labels = copy.deepcopy(input_ids) 56 | input_ids += [tokenizer.eos_token_id] 57 | labels += [tokenizer.eos_token_id] 58 | input_ids = input_ids[: model_max_length] 59 | labels = labels[: model_max_length] 60 | return { 61 | "input_ids": input_ids, 62 | "attention_mask": [1] * len(input_ids), 63 | "labels": labels, 64 | } 65 | 66 | 67 | def exam_generate(model_max_length: int, tokenizer: PreTrainedTokenizer, data_point): 68 | template = 'Human: \n{human}\n\nAssistant: \n' 69 | # pudb.set_trace() 70 | input_str = template.format( 71 | human=f'回答下面的{data_point["type"]}题,用json返回答案,包括原因和答案,如{{"reason":..., "answer":...}}\n{data_point["question"]}\n选项:{" ".join(data_point["candidates"])}' 72 | ) 73 | input_ids = tokenizer.encode( 74 | input_str, 75 | add_special_tokens=False 76 | ) 77 | labels = [IGNORE_INDEX] * len(input_ids) 78 | bot_ids = tokenizer.encode( 79 | json.dumps( 80 | { 81 | 'reason': data_point['reason'], 82 | 'answer': data_point['answer'] 83 | }, ensure_ascii=False 84 | ), 85 | add_special_tokens=False 86 | ) 87 | input_ids += bot_ids 88 | labels += bot_ids 89 | 90 | input_ids += [tokenizer.eos_token_id] 91 | labels += [tokenizer.eos_token_id] 92 | 93 | input_ids = input_ids[: model_max_length - 1] 94 | labels = labels[: model_max_length - 1] 95 | return { 96 | "input_ids": input_ids, 97 | "attention_mask": [1] * len(input_ids), 98 | "labels": 
labels, 99 | } 100 | -------------------------------------------------------------------------------- /finetune/pulse/src/trainer.py: -------------------------------------------------------------------------------- 1 | from peft import PeftModel 2 | from transformers.trainer import * 3 | 4 | from src.utils import get_ds_state_dict 5 | import re 6 | 7 | def remove_last_directory(path): 8 | pattern = r'^(.*)/[^/]+$' # 匹配目录路径和最后一层目录名 9 | match = re.match(pattern, path) 10 | if match: 11 | return match.group(1) 12 | else: 13 | return path 14 | 15 | class MyTrainer(Trainer): 16 | def save_model(self, output_dir: Optional[str] = None, _internal_call: bool = False): 17 | """ 18 | Add supports for peft 19 | 20 | Will save the model, so you can reload it using `from_pretrained()`. 21 | 22 | Will only save from the main process. 23 | """ 24 | 25 | if output_dir is None: 26 | output_dir = self.args.output_dir 27 | 28 | if is_torch_tpu_available(): 29 | self._save_tpu(output_dir) 30 | elif is_sagemaker_mp_enabled(): 31 | # Calling the state_dict needs to be done on the wrapped model and on all processes. 32 | os.makedirs(output_dir, exist_ok=True) 33 | state_dict = self.model_wrapped.state_dict() 34 | if self.args.should_save: 35 | self._save(output_dir, state_dict=state_dict) 36 | self._save(remove_last_directory(output_dir), state_dict=state_dict) 37 | if IS_SAGEMAKER_MP_POST_1_10: 38 | # 'user_content.pt' indicates model state_dict saved with smp >= 1.10 39 | Path(os.path.join(output_dir, "user_content.pt")).touch() 40 | elif ( 41 | ShardedDDPOption.ZERO_DP_2 in self.args.sharded_ddp 42 | or ShardedDDPOption.ZERO_DP_3 in self.args.sharded_ddp 43 | or self.fsdp is not None 44 | ): 45 | state_dict = self.model.state_dict() 46 | 47 | if self.args.should_save: 48 | self._save(output_dir, state_dict=state_dict) 49 | self._save(remove_last_directory(output_dir), state_dict=state_dict) 50 | elif self.deepspeed: 51 | # This must be called on all ranks in stage 3 52 | if is_deepspeed_zero3_enabled(): 53 | state_dict = get_ds_state_dict(self.deepspeed) 54 | else: 55 | # Only run on rank 0 except stage 3 56 | if self.args.should_save: 57 | state_dict = get_ds_state_dict(self.deepspeed) 58 | # this takes care of everything as long as we aren't under zero3 59 | # Only run on rank 0 60 | if self.args.should_save: 61 | # state_dict is available on rank 0 62 | self._save(output_dir, state_dict=state_dict) 63 | self._save(remove_last_directory(output_dir), state_dict=state_dict) 64 | 65 | elif self.args.should_save: 66 | self._save(output_dir) 67 | self._save(remove_last_directory(output_dir)) 68 | 69 | # Push to the Hub when `save_model` is called by the user. 70 | if self.args.push_to_hub and not _internal_call: 71 | self.push_to_hub(commit_message="Model save") 72 | 73 | def _save(self, output_dir: Optional[str] = None, state_dict=None): 74 | """ 75 | Add supports for peft 76 | """ 77 | # If we are executing this function, we are the process zero, so we don't check for that. 78 | output_dir = output_dir if output_dir is not None else self.args.output_dir 79 | os.makedirs(output_dir, exist_ok=True) 80 | logger.info(f"Saving model checkpoint to {output_dir}") 81 | # Save a trained model and configuration using `save_pretrained()`. 
82 | # They can then be reloaded using `from_pretrained()` 83 | if not isinstance(self.model, (PreTrainedModel, PeftModel)): 84 | if state_dict is None: 85 | state_dict = self.model.state_dict() 86 | 87 | if isinstance(unwrap_model(self.model), (PreTrainedModel, PeftModel)): 88 | unwrap_model(self.model).save_pretrained( 89 | output_dir, state_dict=state_dict, safe_serialization=self.args.save_safetensors 90 | ) 91 | unwrap_model(self.model).save_pretrained( 92 | remove_last_directory(output_dir), state_dict=state_dict, safe_serialization=self.args.save_safetensors 93 | ) 94 | else: 95 | logger.info("Trainer.model is not a `PreTrainedModel`, only saving its state dict.") 96 | if self.args.save_safetensors: 97 | safetensors.torch.save_file(state_dict, os.path.join(output_dir, SAFE_WEIGHTS_NAME)) 98 | else: 99 | torch.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME)) 100 | else: 101 | self.model.save_pretrained( 102 | output_dir, state_dict=state_dict, safe_serialization=self.args.save_safetensors 103 | ) 104 | 105 | if self.tokenizer is not None: 106 | self.tokenizer.save_pretrained(output_dir) 107 | 108 | # Good practice: save your training arguments together with the trained model 109 | torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME)) 110 | -------------------------------------------------------------------------------- /finetune/pulse/src/utils.py: -------------------------------------------------------------------------------- 1 | import time 2 | from typing import Any, List, Union 3 | from gradio_client import Client 4 | from tqdm import tqdm 5 | from transformers.deepspeed import is_deepspeed_zero3_enabled 6 | import torch 7 | import traceback 8 | try: 9 | from deepspeed.runtime.engine import DeepSpeedEngine 10 | def get_ds_state_dict(ds_engine: DeepSpeedEngine): 11 | """ 12 | 如果是zero stage 3,要对所有rank调用,无视掉stage3_gather_16bit_weights_on_model_save参数 13 | """ 14 | if ds_engine.zero_optimization_partition_weights(): 15 | # consolidation is expensive in time and memory and therefore isn't a default 16 | state_dict = ds_engine._zero3_consolidated_16bit_state_dict() 17 | else: 18 | state_dict = ds_engine.module.state_dict() 19 | return state_dict 20 | 21 | 22 | def get_model_param_count(model: Union[DeepSpeedEngine, torch.nn.Module], trainable_only=False): 23 | """ 24 | Calculate model's total param count. If trainable_only is True then count only those requiring grads 25 | """ 26 | if is_deepspeed_zero3_enabled() and isinstance(model, DeepSpeedEngine): 27 | def numel(p): 28 | return p.ds_numel 29 | 30 | else: 31 | def numel(p): 32 | return p.numel() 33 | 34 | return sum(numel(p) for p in model.parameters() if not trainable_only or p.requires_grad) 35 | except ModuleNotFoundError: 36 | traceback.print_exc() 37 | DeepSpeedEngine = None 38 | 39 | 40 | def get_model_param_count(model: Union[DeepSpeedEngine, torch.nn.Module], trainable_only=False): 41 | """ 42 | Calculate model's total param count. 
If trainable_only is True then count only those requiring grads 43 | """ 44 | if is_deepspeed_zero3_enabled() and isinstance(model, DeepSpeedEngine): 45 | def numel(p): 46 | return p.ds_numel 47 | 48 | else: 49 | def numel(p): 50 | return p.numel() 51 | 52 | return sum(numel(p) for p in model.parameters() if not trainable_only or p.requires_grad) 53 | 54 | def get_ds_state_dict(ds_engine: DeepSpeedEngine): 55 | 56 | return {} 57 | 58 | 59 | 60 | class MultiClient(object): 61 | def __init__(self, worker_addrs) -> None: 62 | self.clients = [Client(addr) for addr in worker_addrs] 63 | 64 | def predict(self, tasks: List[List], max_retries: int = 3) -> List[Any]: 65 | pbar = tqdm(total=len(tasks)) 66 | jobs = { 67 | client: (i, client.submit(*(tasks[i]), api_name="/predict")) 68 | for i, client in enumerate(self.clients) 69 | if i < len(tasks) 70 | } 71 | results = {} 72 | retries = {i: 0 for i in range(len(tasks))} 73 | 74 | while jobs: 75 | for client, (i, job) in list(jobs.items()): 76 | if job.done(): 77 | pbar.update(1) 78 | del jobs[client] 79 | try: 80 | result = job.result() 81 | results[i] = result 82 | except Exception as e: 83 | print("Job failed with error:", e) 84 | if retries[i] < max_retries: 85 | print("Retrying job...") 86 | retries[i] += 1 87 | new_job = client.submit( 88 | *tasks[i], api_name="/predict") 89 | jobs[client] = (i, new_job) 90 | continue # Skip the rest of the loop 91 | else: 92 | results[i] = None 93 | 94 | if tasks: 95 | new_i = len(results) + len(jobs) 96 | if new_i < len(tasks): 97 | new_task = tasks[new_i] 98 | new_job = client.submit( 99 | *new_task, api_name="/predict") 100 | jobs[client] = (new_i, new_job) 101 | time.sleep(0.1) 102 | pbar.close() 103 | 104 | predicts = [results[i] for i in sorted(results)] 105 | 106 | return predicts 107 | -------------------------------------------------------------------------------- /finetune/pulse_utils.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import json 3 | import datetime,time 4 | import shutil 5 | import os 6 | import re 7 | import subprocess 8 | from sklearn.metrics import classification_report 9 | from pynvml import (nvmlInit, nvmlDeviceGetCount, nvmlDeviceGetHandleByIndex, 10 | nvmlDeviceGetName, nvmlDeviceGetMemoryInfo, nvmlShutdown) 11 | from configs.common_config import * 12 | 13 | model_loaded = False 14 | project_change = False 15 | last_lora_name = '' 16 | max_new_tokens = 1500 17 | generation_config = dict( 18 | temperature=0.001, 19 | top_k=30, 20 | top_p=0.85, 21 | do_sample=False, 22 | num_beams=1, 23 | repetition_penalty=1.2, 24 | max_new_tokens=max_new_tokens 25 | ) 26 | def read_excel_file(file_path): 27 | df = pd.read_excel(file_path) 28 | return df 29 | 30 | def save_to_excel(df, file_path): 31 | df.to_excel(file_path, index=False) 32 | 33 | def process_data(training_data_path): 34 | # 读取 Excel 文件 35 | df = pd.read_excel(training_data_path) 36 | log = [] 37 | log.append(f'开始处理数据') 38 | 39 | all_data = [] 40 | # 遍历每一行数据 41 | for index, row in df.iterrows(): 42 | instruction = row['系统指示'] 43 | question = row['问题'] 44 | answer = row['回答'] 45 | 46 | # 创建字典并将数据添加到列表中 47 | data = {"instruction": instruction, "input": question, "output": answer} 48 | all_data.append(data) 49 | 50 | log = '\n'.join(log) # 使用换行符拼接log内容 51 | return all_data, log 52 | 53 | 54 | def get_available_gpu(threshold=20000): 55 | # Initialize NVML 56 | nvmlInit() 57 | 58 | # Get the number of GPU devices 59 | device_count = nvmlDeviceGetCount() 60 | 61 | # Find 
GPU devices with available memory greater than the threshold 62 | available_gpus = [] 63 | for i in range(device_count): 64 | handle = nvmlDeviceGetHandleByIndex(i) 65 | info = nvmlDeviceGetMemoryInfo(handle) 66 | free_memory_mb = info.free / 1024 / 1024 67 | 68 | if free_memory_mb > threshold: 69 | available_gpus.append(i) 70 | 71 | # Shutdown NVML 72 | nvmlShutdown() 73 | 74 | return available_gpus 75 | 76 | def pulse_train_model(model_name, lora_name, training_data_path): 77 | now_str = datetime.datetime.now().strftime('%Y%m%d_%H%M') 78 | print('now_str',now_str) 79 | all_data,log = process_data(training_data_path) 80 | log_file_path = f'data/logs/{now_str}.log' # 定义log文件路径 81 | os.makedirs(os.path.dirname(log_file_path), exist_ok=True) # 创建存储log的文件夹 82 | 83 | with open(log_file_path, 'w', encoding="utf-8") as f: 84 | f.write(log) # 将log内容写入文件 85 | with open(f"data/{lora_name}_dataset.json", "w", encoding="utf-8") as f: 86 | json.dump(all_data, f, indent=4, ensure_ascii=False) 87 | if not os.path.exists('finetune/pulse/data'): 88 | os.makedirs('finetune/pulse/data') 89 | if not os.path.exists('finetune/pulse/logs'): 90 | os.makedirs('finetune/pulse/logs') 91 | shutil.copyfile(f"data/{lora_name}_dataset.json", f"finetune/pulse/data/{lora_name}_dataset.json") 92 | 93 | available_gpus = get_available_gpu(threshold=20000) 94 | print('available_gpus[0]',available_gpus[0]) 95 | content = f'''python convert_to_conv_data.py --orig_data data/{lora_name}_dataset.json --write_data data/{lora_name}_dataset_conv.json --dataset_name {lora_name} 96 | 97 | CUDA_VISIBLE_DEVICES={available_gpus[0]} torchrun --nproc_per_node 1 finetune.py --model_name_or_path {llm_model_dict[model_name]["local_model_path"]} --use_lora True --use_int8_training --lora_config configs/lora_config_bloom.json --train_file data/{lora_name}_dataset_conv.json --validation_file data/{lora_name}_dataset_conv.json --per_device_train_batch_size 1 --per_device_eval_batch_size 1 --gradient_accumulation_steps 2 --num_train_epochs 2 --model_max_length 100 --save_strategy "steps" --save_total_limit 3 --learning_rate 3e-4 --weight_decay 0.00001 --warmup_ratio 0.05 --lr_scheduler_type "cosine" --logging_steps 10 --evaluation_strategy "steps" --seed 2048 --gradient_checkpointing True --cache_dir cache/{lora_name} --output_dir output/{lora_name} 98 | ''' 99 | sh_file_name = f'finetune/pulse/train_{lora_name}.sh' 100 | 101 | with open(sh_file_name , 'w') as file: 102 | file.write(content) 103 | 104 | # 设置文件可执行权限 105 | os.chmod(sh_file_name , 0o755) 106 | now_str = datetime.datetime.now().strftime('%Y%m%d_%H%M') 107 | print('now_str',now_str) 108 | subprocess.Popen(f"""cd finetune/pulse && . 
/home/pai/etc/profile.d/conda.sh && conda activate med_llm && nohup sh train_{lora_name}.sh > ./logs/train_{now_str}.log 2>&1 &""", shell=True) 109 | print('finish') 110 | # model.train(training_data_path) 111 | return f'{model_name} on training' 112 | 113 | def stop_train_process(): 114 | process = subprocess.Popen('ps -ef | grep finetune.py', shell=True, stdout=subprocess.PIPE) 115 | output, _ = process.communicate() 116 | process.kill() 117 | 118 | 119 | 120 | n = 0 121 | # 解析输出以获取进程ID 122 | print('output',output) 123 | try: 124 | lines = output.decode().split('\n') 125 | for line in lines: 126 | if 'finetune.py' in line: 127 | parts = line.split() 128 | pid = parts[1] 129 | # 杀死进程 130 | subprocess.call(['kill', '-9', pid]) 131 | n+=1 132 | except Exception as e: 133 | print('error!!',e) 134 | 135 | return f'停止了{n//2}个进程' -------------------------------------------------------------------------------- /img/1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JuneYaooo/medical_kb_chatbot/be4ff60c47ffb31ac3052f82b5ed607d7c7ab089/img/1.jpg -------------------------------------------------------------------------------- /img/2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JuneYaooo/medical_kb_chatbot/be4ff60c47ffb31ac3052f82b5ed607d7c7ab089/img/2.jpg -------------------------------------------------------------------------------- /img/3.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JuneYaooo/medical_kb_chatbot/be4ff60c47ffb31ac3052f82b5ed607d7c7ab089/img/3.jpg -------------------------------------------------------------------------------- /img/4.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JuneYaooo/medical_kb_chatbot/be4ff60c47ffb31ac3052f82b5ed607d7c7ab089/img/4.jpg -------------------------------------------------------------------------------- /loader/__init__.py: -------------------------------------------------------------------------------- 1 | from .image_loader import UnstructuredPaddleImageLoader 2 | from .pdf_loader import UnstructuredPaddlePDFLoader 3 | -------------------------------------------------------------------------------- /loader/image_loader.py: -------------------------------------------------------------------------------- 1 | """Loader that loads image files.""" 2 | from typing import List 3 | 4 | from langchain.document_loaders.unstructured import UnstructuredFileLoader 5 | from paddleocr import PaddleOCR 6 | import os 7 | import nltk 8 | from configs.common_config import NLTK_DATA_PATH 9 | 10 | nltk.data.path = [NLTK_DATA_PATH] + nltk.data.path 11 | 12 | class UnstructuredPaddleImageLoader(UnstructuredFileLoader): 13 | """Loader that uses unstructured to load image files, such as PNGs and JPGs.""" 14 | 15 | def _get_elements(self) -> List: 16 | def image_ocr_txt(filepath, dir_path="tmp_files"): 17 | full_dir_path = os.path.join(os.path.dirname(filepath), dir_path) 18 | if not os.path.exists(full_dir_path): 19 | os.makedirs(full_dir_path) 20 | filename = os.path.split(filepath)[-1] 21 | ocr = PaddleOCR(use_angle_cls=True, lang="ch", use_gpu=False, show_log=False) 22 | result = ocr.ocr(img=filepath) 23 | 24 | ocr_result = [i[1][0] for line in result for i in line] 25 | txt_file_path = os.path.join(full_dir_path, "%s.txt" % (filename)) 26 | with open(txt_file_path, 'w', 
encoding='utf-8') as fout: 27 | fout.write("\n".join(ocr_result)) 28 | return txt_file_path 29 | 30 | txt_file_path = image_ocr_txt(self.file_path) 31 | from unstructured.partition.text import partition_text 32 | return partition_text(filename=txt_file_path, **self.unstructured_kwargs) 33 | 34 | 35 | if __name__ == "__main__": 36 | filepath = os.path.join(os.path.dirname(os.path.dirname(__file__)), "content", "samples", "test.jpg") 37 | loader = UnstructuredPaddleImageLoader(filepath, mode="elements") 38 | docs = loader.load() 39 | for doc in docs: 40 | print(doc) 41 | -------------------------------------------------------------------------------- /loader/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .bloomz_llm import Bloomz 2 | -------------------------------------------------------------------------------- /loader/models/__main__.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | 4 | sys.path.append(os.path.dirname(os.path.abspath(__file__)) + '/../') 5 | import asyncio 6 | from argparse import Namespace 7 | from models.loader.args import parser 8 | from models.loader import LoaderCheckPoint 9 | 10 | from langchain.agents import initialize_agent, Tool 11 | from langchain.agents import AgentType 12 | 13 | import loader.models.shared as shared 14 | 15 | from langchain.chains import LLMChain 16 | from langchain.memory import ConversationBufferMemory, ReadOnlySharedMemory 17 | from langchain.prompts import PromptTemplate 18 | from langchain.agents import ZeroShotAgent, Tool, AgentExecutor 19 | from typing import List, Set 20 | 21 | 22 | 23 | class CustomLLMSingleActionAgent(ZeroShotAgent): 24 | allowed_tools: List[str] 25 | 26 | def __init__(self, *args, **kwargs): 27 | super(CustomLLMSingleActionAgent, self).__init__(*args, **kwargs) 28 | self.allowed_tools = kwargs['allowed_tools'] 29 | 30 | def get_allowed_tools(self) -> Set[str]: 31 | return set(self.allowed_tools) 32 | 33 | 34 | async def dispatch(args: Namespace): 35 | args_dict = vars(args) 36 | 37 | shared.loaderCheckPoint = LoaderCheckPoint(args_dict) 38 | llm_model_ins = shared.loaderLLM() 39 | 40 | template = """This is a conversation between a human and a bot: 41 | 42 | {chat_history} 43 | 44 | Write a summary of the conversation for {input}: 45 | """ 46 | 47 | prompt = PromptTemplate( 48 | input_variables=["input", "chat_history"], 49 | template=template 50 | ) 51 | memory = ConversationBufferMemory(memory_key="chat_history") 52 | readonlymemory = ReadOnlySharedMemory(memory=memory) 53 | summry_chain = LLMChain( 54 | llm=llm_model_ins, 55 | prompt=prompt, 56 | verbose=True, 57 | memory=readonlymemory, # use the read-only memory to prevent the tool from modifying the memory 58 | ) 59 | 60 | 61 | tools = [ 62 | Tool( 63 | name="Summary", 64 | func=summry_chain.run, 65 | description="useful for when you summarize a conversation. The input to this tool should be a string, representing who will read this summary." 66 | ) 67 | ] 68 | 69 | prefix = """Have a conversation with a human, answering the following questions as best you can. You have access to the following tools:""" 70 | suffix = """Begin! 
71 | 
72 | Question: {input}
73 | {agent_scratchpad}"""
74 | 
75 | 
76 | prompt = CustomLLMSingleActionAgent.create_prompt(
77 |     tools,
78 |     prefix=prefix,
79 |     suffix=suffix,
80 |     input_variables=["input", "agent_scratchpad"]
81 | )
82 | tool_names = [tool.name for tool in tools]
83 | llm_chain = LLMChain(llm=llm_model_ins, prompt=prompt)
84 | agent = CustomLLMSingleActionAgent(llm_chain=llm_chain, tools=tools, allowed_tools=tool_names)
85 | agent_chain = AgentExecutor.from_agent_and_tools(agent=agent, tools=tools)
86 | 
87 | agent_chain.run(input="你好")
88 | agent_chain.run(input="你是谁?")
89 | agent_chain.run(input="我们之前聊了什么?")
90 | 
91 | if __name__ == '__main__':
92 |     args = None
93 |     args = parser.parse_args(args=['--model-dir', '/media/checkpoint/', '--model', 'vicuna-13b-hf', '--no-remote-model', '--load-in-8bit'])
94 | 
95 |     loop = asyncio.new_event_loop()
96 |     asyncio.set_event_loop(loop)
97 |     loop.run_until_complete(dispatch(args))
98 | 
-------------------------------------------------------------------------------- /loader/models/base.py: --------------------------------------------------------------------------------
1 | from abc import ABC, abstractmethod
2 | from typing import Optional, List
3 | import traceback
4 | from collections import deque
5 | from queue import Queue
6 | from threading import Thread
7 | 
8 | import torch
9 | import transformers
10 | from loader.models.loader import LoaderCheckPoint
11 | 
12 | 
13 | class ListenerToken:
14 |     """
15 |     Observation result: a snapshot of one generation step (input_ids and scores).
16 |     """
17 | 
18 |     input_ids: torch.LongTensor
19 |     _scores: torch.FloatTensor
20 | 
21 |     def __init__(self, input_ids: torch.LongTensor, _scores: torch.FloatTensor):
22 |         self.input_ids = input_ids
23 |         self._scores = _scores
24 | 
25 | 
26 | class AnswerResult:
27 |     """
28 |     Message entity carrying one generated answer.
29 |     """
30 |     history: List[List[str]] = []
31 |     llm_output: Optional[dict] = None
32 |     listenerToken: ListenerToken = None
33 | 
34 | 
35 | class AnswerResultStream:
36 |     def __init__(self, callback_func=None):
37 |         self.callback_func = callback_func
38 | 
39 |     def __call__(self, answerResult: AnswerResult):
40 |         if self.callback_func is not None:
41 |             self.callback_func(answerResult)
42 | 
43 | 
44 | class AnswerResultQueueSentinelTokenListenerQueue(transformers.StoppingCriteria):
45 |     """
46 |     A stopping_criteria listener for the model: on every generation step it syncs the queued data into AnswerResult.
47 |     It exists because different models may not expose their predictions as plain tensors; the HF framework lets a custom transformers.StoppingCriteria receive the Tensor and scores of every prediction step.
48 |     A StoppingCriteriaList specifies the conditions under which the model stops generating an answer; each StoppingCriteria object represents one stop condition.
49 |     At every prediction step each StoppingCriteria receives the same prediction result, and the concrete subclass ultimately decides whether to stop.
50 |     The captured values can be observed through the custom parameters of generatorAnswer / generate_with_streaming for finer-grained control.
51 |     """
52 | 
53 |     listenerQueue: deque = deque(maxlen=1)
54 | 
55 |     def __init__(self):
56 |         transformers.StoppingCriteria.__init__(self)
57 | 
58 |     def __call__(self, input_ids: torch.LongTensor, _scores: torch.FloatTensor, **kwargs) -> bool:
59 |         """
60 |         Append the data of every generation step to the listener queue.
61 |         :param input_ids:
62 |         :param _scores:
63 |         :param kwargs:
64 |         :return:
65 |         """
66 |         self.listenerQueue.append(ListenerToken(input_ids=input_ids, _scores=_scores))
67 |         return False
68 | 
69 | 
70 | class Iteratorize:
71 |     """
72 |     Transforms a function that takes a callback
73 |     into a lazy iterator (generator).
74 |     """
75 | 
76 |     def __init__(self, func, kwargs={}):
77 |         self.mfunc = func
78 |         self.q = Queue()
79 |         self.sentinel = object()
80 |         self.kwargs = kwargs
81 |         self.stop_now = False
82 | 
83 |         def _callback(val):
84 |             """
85 |             Collects the model's prediction results.
86 |             The generate_with_callback collector (AnswerResultStream) gathers the AnswerResult responses produced by the model; the concrete subclass ultimately decides when generation ends.
87 |             Generation ends when one of the following happens:
88 |             1. The model finishes predicting and the collector queue self.q receives the self.sentinel marker.
89 |             2. A break exits the iterator while its queue messages are being processed, triggering StopIteration.
90 |             3. The model prediction raises an error.
91 |             Because this class is an iterator, breaking out of a for-in loop calls __exit__, which updates stop_now and then raises an exception that ends prediction.
92 |             The collection flow is as follows:
93 |             create an Iteratorize iterator object,
94 |             define the generate_with_callback collector AnswerResultStream,
95 |             start a thread that asynchronously calls the upstream checkpoint's _generate_answer implementation;
96 |             _generate_answer uses the collector defined by generate_with_callback to gather the AnswerResult messages wrapped by the upstream checkpoint.
97 |             Because self.q is blocking, each prediction is consumed before the next one runs,
98 |             so generate_with_callback blocks in the meantime.
99 |             The main thread consumes the blocked messages through the Iteratorize object's __next__ method:
100 |             1. If the message is an AnswerResult wrapped by the upstream checkpoint, it is returned downstream for processing.
101 |             2. If the message is the self.sentinel marker, StopIteration is raised.
102 |             When the main-thread Iteratorize object's __exit__ is invoked, stop_now is updated,
103 |             the async thread detects the updated stop_now and raises an exception to end prediction,
104 |             and the iteration ends.
105 |             :param val:
106 |             :return:
107 |             """
108 |             if self.stop_now:
109 |                 raise ValueError
110 |             self.q.put(val)
111 | 
112 |         def gen():
113 |             try:
114 |                 ret = self.mfunc(callback=_callback, **self.kwargs)
115 |             except ValueError:
116 |                 pass
117 |             except:
118 |                 traceback.print_exc()
119 |                 pass
120 | 
121 |             self.q.put(self.sentinel)
122 | 
123 |         self.thread = Thread(target=gen)
124 |         self.thread.start()
125 | 
126 |     def __iter__(self):
127 |         return self
128 | 
129 |     def __next__(self):
130 |         obj = self.q.get(True, None)
131 |         if obj is self.sentinel:
132 |             raise StopIteration
133 |         else:
134 |             return obj
135 | 
136 |     def __del__(self):
137 |         """
138 |         Not implemented yet.
139 |         :return:
140 |         """
141 |         pass
142 | 
143 |     def __enter__(self):
144 |         return self
145 | 
146 |     def __exit__(self, exc_type, exc_val, exc_tb):
147 |         """ Executed after a break in the consuming loop """
148 |         self.stop_now = True
149 | 
150 | 
151 | class BaseAnswer(ABC):
152 |     """Upper-layer wrapper that provides a unified API for answer generation."""
153 | 
154 |     @property
155 |     @abstractmethod
156 |     def _check_point(self) -> LoaderCheckPoint:
157 |         """Return _check_point of llm."""
158 | 
159 |     @property
160 |     @abstractmethod
161 |     def _history_len(self) -> int:
162 |         """Return _history_len of llm."""
163 | 
164 |     @abstractmethod
165 |     def set_history_len(self, history_len: int) -> None:
166 |         """Set _history_len of llm."""
167 | 
168 |     def generatorAnswer(self, prompt: str,
169 |                         history: List[List[str]] = [],
170 |                         streaming: bool = False):
171 |         def generate_with_callback(callback=None, **kwargs):
172 |             kwargs['generate_with_callback'] = AnswerResultStream(callback_func=callback)
173 |             self._generate_answer(**kwargs)
174 | 
175 |         def generate_with_streaming(**kwargs):
176 |             return Iteratorize(generate_with_callback, kwargs)
177 | 
178 |         """
179 |         eos_token_id is a designated token (for example, ""),
180 |         used to mark the end of a sequence. During text generation the model keeps emitting tokens until this special eos_token_id is produced, which signals that the sequence is complete.
181 |         In Hugging Face Transformers models, the eos_token_id is added to the input automatically by the tokenizer.
182 |         When the model emits the eos_token_id during generation, the generation process stops and the generated sequence is returned.
183 |         """
184 |         eos_token_ids = [
185 |             self._check_point.tokenizer.eos_token_id] if self._check_point.tokenizer.eos_token_id is not None else []
186 | 
187 |         with generate_with_streaming(prompt=prompt, history=history, streaming=streaming) as generator:
188 |             for answerResult in generator:
189 |                 if answerResult.listenerToken:
190 |                     output = answerResult.listenerToken.input_ids
191 |                 yield answerResult
192 | 
193 |     @abstractmethod
194 |     def _generate_answer(self, prompt: str,
195 |                          history: 
List[List[str]] = [], 196 | streaming: bool = False, 197 | generate_with_callback: AnswerResultStream = None) -> None: 198 | pass 199 | -------------------------------------------------------------------------------- /loader/models/bloomz_llm.py: -------------------------------------------------------------------------------- 1 | 2 | from abc import ABC 3 | 4 | from langchain.llms.base import LLM 5 | from typing import Optional, List 6 | from loader.models.loader import LoaderCheckPoint 7 | from loader.models.base import (BaseAnswer, 8 | AnswerResult, 9 | AnswerResultStream, 10 | AnswerResultQueueSentinelTokenListenerQueue) 11 | import re 12 | 13 | import transformers 14 | from transformers.generation.streamers import BaseStreamer 15 | from threading import Thread 16 | from queue import Queue 17 | from typing import Callable, Iterable, List, Optional, Tuple, Union 18 | # import torch 19 | 20 | class MyStreamer(BaseStreamer): 21 | 22 | def __init__( 23 | self, 24 | # stop_token_ids: torch.Tensor, 25 | # skip_token_count: int, 26 | # max_input_length: int, 27 | timeout: Optional[float] = None 28 | ): 29 | 30 | # 紧急停止策略 31 | # self.stop_token_ids = stop_token_ids 32 | # self.skip_token_count = skip_token_count 33 | # self.max_input_length = max_input_length 34 | 35 | 36 | self.token_queue = Queue() 37 | self.stop_signal = None 38 | self.timeout = timeout 39 | 40 | 41 | def put(self, value): 42 | list_value = value.tolist() 43 | if type(list_value[0]) == int: 44 | self.token_queue.put(list_value, timeout=self.timeout) 45 | 46 | def end(self): 47 | self.token_queue.put(self.stop_signal, timeout=self.timeout) 48 | 49 | def __iter__(self): 50 | return self 51 | 52 | def __next__(self): 53 | value = self.token_queue.get(timeout=self.timeout) 54 | if value == self.stop_signal: 55 | raise StopIteration() 56 | else: 57 | return value 58 | 59 | def remove_starting_symbols(string): 60 | pattern = r'^[,。,.!!]+' 61 | result = re.sub(pattern, '', string) 62 | return result 63 | 64 | def extract_content(replacement_text, string): 65 | pattern_helper = r'Helper:(.*?)' # r'Helper:(.*?)(?=<\/s>)' 66 | 67 | match_helper = re.findall(pattern_helper, string, re.DOTALL) 68 | 69 | if match_helper: 70 | content = match_helper[-1].strip() 71 | return content.replace('', '').replace('<\s>', '') 72 | else: 73 | replaced_string = re.sub(r'Input:.*?(?=\n)', replacement_text, string, re.DOTALL) 74 | replaced_string = replaced_string.replace('', '').replace('<\s>', '') 75 | return replaced_string 76 | 77 | import re 78 | 79 | def remove_prefix_suffix(string): 80 | pattern = r'^((?:Helper:|:)\s*)(.*?)()?$' 81 | string = re.sub(pattern, r'\2', string, flags=re.DOTALL) 82 | string = remove_starting_symbols(string).replace('', '').replace('<\s>', '') 83 | return string 84 | 85 | class Bloomz(BaseAnswer, LLM, ABC): 86 | max_token: int = 10000 87 | temperature: float = 0.01 88 | top_p = 0.9 89 | checkPoint: LoaderCheckPoint = None 90 | # history = [] 91 | history_len: int = 10 92 | 93 | def __init__(self, checkPoint: LoaderCheckPoint = None): 94 | super().__init__() 95 | self.checkPoint = checkPoint 96 | 97 | @property 98 | def _llm_type(self) -> str: 99 | return "Bloomz" 100 | 101 | @property 102 | def _check_point(self) -> LoaderCheckPoint: 103 | return self.checkPoint 104 | 105 | @property 106 | def _history_len(self) -> int: 107 | return self.history_len 108 | 109 | def set_history_len(self, history_len: int = 10) -> None: 110 | self.history_len = history_len 111 | 112 | def _call(self, prompt: str, stop: 
Optional[List[str]] = None) -> str: 113 | pass 114 | 115 | def _generate_answer(self, prompt: str, 116 | history: List[List[str]] = [], 117 | streaming: bool = False, 118 | generate_with_callback: AnswerResultStream = None) -> None: 119 | # Create the StoppingCriteriaList with the stopping strings 120 | stopping_criteria_list = transformers.StoppingCriteriaList() 121 | # 定义模型stopping_criteria 队列,在每次响应时将 torch.LongTensor, torch.FloatTensor同步到AnswerResult 122 | listenerQueue = AnswerResultQueueSentinelTokenListenerQueue() 123 | stopping_criteria_list.append(listenerQueue) 124 | model_device = next(self.checkPoint.model.parameters()).device 125 | if streaming: 126 | history += [[]] 127 | streamer = MyStreamer() # type: ignore 128 | # torch.manual_seed(23333) 129 | inputs = self.checkPoint.tokenizer([prompt], return_tensors="pt").input_ids.to(model_device) 130 | 131 | thread = Thread(target=self.checkPoint.model.generate, kwargs=dict( 132 | inputs=inputs, 133 | # attention_mask=attention_mask.cuda(), 134 | # gen kargs 135 | num_beams=1, 136 | do_sample=True, 137 | temperature=0.7, 138 | top_p=self.top_p, 139 | top_k=9, 140 | eos_token_id=self.checkPoint.tokenizer.eos_token_id, 141 | max_length=2000, 142 | min_length=1, 143 | streamer=streamer, 144 | )) 145 | 146 | thread.start() 147 | # stream_resp = '' 148 | token_list = [] 149 | for token in streamer: 150 | token_list += token 151 | stream_resp = self.checkPoint.tokenizer.decode(token_list) 152 | stream_resp = remove_prefix_suffix(stream_resp) 153 | history[-1] = [prompt, stream_resp] #stream_resp.lstrip(": ") 154 | answer_result = AnswerResult() 155 | answer_result.history = history 156 | answer_result.llm_output = {"answer": stream_resp} 157 | if listenerQueue.listenerQueue.__len__() > 0: 158 | answer_result.listenerToken = listenerQueue.listenerQueue.pop() 159 | generate_with_callback(answer_result) 160 | else: 161 | inputs = self.checkPoint.tokenizer([prompt], return_tensors="pt").input_ids.to(model_device) 162 | re_token_ids = self.checkPoint.model.generate( 163 | inputs=inputs, 164 | # gen kargs 165 | num_beams=1, 166 | do_sample=True, 167 | temperature=0.7, 168 | top_p=self.top_p, 169 | top_k=9, 170 | eos_token_id=self.checkPoint.tokenizer.eos_token_id, 171 | max_length=2000 #512, 172 | ) 173 | response = self.checkPoint.tokenizer.decode(re_token_ids[0]) 174 | response = extract_content(prompt, response) 175 | self.checkPoint.clear_torch_cache() 176 | history += [[prompt, response]] 177 | answer_result = AnswerResult() 178 | answer_result.history = history 179 | answer_result.llm_output = {"answer": response} 180 | if listenerQueue.listenerQueue.__len__() > 0: 181 | answer_result.listenerToken = listenerQueue.listenerQueue.pop() 182 | 183 | generate_with_callback(answer_result) -------------------------------------------------------------------------------- /loader/models/extensions/callback.py: -------------------------------------------------------------------------------- 1 | # import gc 2 | import traceback 3 | from queue import Queue 4 | # from threading import Thread 5 | # import threading 6 | from typing import Optional, List, Dict, Any, TypeVar, Deque 7 | from collections import deque 8 | import torch 9 | import transformers 10 | 11 | from models.extensions.thread_with_exception import ThreadWithException 12 | import models.shared as shared 13 | 14 | 15 | K = TypeVar('K') 16 | V = TypeVar('V') 17 | 18 | class LimitedLengthDict(Dict[K, V]): 19 | def __init__(self, maxlen=None, *args, **kwargs): 20 | self.maxlen = maxlen 21 | 
self._keys: Deque[K] = deque() 22 | super().__init__(*args, **kwargs) 23 | 24 | def __setitem__(self, key: K, value: V): 25 | if key not in self: 26 | if self.maxlen is not None and len(self) >= self.maxlen: 27 | oldest_key = self._keys.popleft() 28 | if oldest_key in self: 29 | del self[oldest_key] 30 | self._keys.append(key) 31 | super().__setitem__(key, value) 32 | 33 | 34 | class FixedLengthQueue: 35 | # 停止符号列表 36 | stop_sequence: Optional[str] = [] 37 | # 缓冲区 38 | max_length: int = 0 39 | # 缓冲区容器 40 | queue: deque = None 41 | # 输入区容器 42 | queue_in: LimitedLengthDict[int, str] = {} 43 | # 输出区容器 44 | queue_out: Dict[int, str] = {} 45 | 46 | def __new__(cls, *args, **kwargs): 47 | # 创建新的实例 48 | instance = super().__new__(cls) 49 | # 在这里可以对实例进行额外的设置 50 | return instance 51 | 52 | def __init__(self, stop_sequence): 53 | if stop_sequence is None: 54 | self.stop_sequence = [] 55 | self.max_length = 0 56 | elif isinstance(stop_sequence, str): 57 | self.stop_sequence = [stop_sequence] 58 | self.max_length = 1 59 | else: 60 | self.stop_sequence = stop_sequence 61 | self.max_length = len(''.join(stop_sequence)) 62 | 63 | self.queue = deque(maxlen=self.max_length) 64 | self.queue.clear() 65 | self.queue_in.clear() 66 | self.queue_out.clear() 67 | 68 | def add(self, index, item): 69 | self.queue_in[index] = item 70 | 71 | def _add_out(self, index, item): 72 | self.queue_out[index] = item 73 | 74 | def put_replace_out(self, index): 75 | return self.queue_out[index] 76 | 77 | def contains_replace_sequence(self): 78 | """ 79 | 替换字符 80 | :return: 81 | """ 82 | 83 | for key, value in self.queue_in.items(): 84 | 85 | word_index = value.rfind(":") 86 | if word_index != -1: 87 | value = value.replace(":", ":") 88 | 89 | word_index = value.rfind("[") 90 | if word_index != -1: 91 | value = value.replace("[", "") 92 | 93 | word_index = value.rfind("]") 94 | if word_index != -1: 95 | value = value.replace("]", "") 96 | 97 | self._add_out(key, value) 98 | 99 | def contains_stop_sequence(self): 100 | # 截取固定大小的数据判断 101 | self.queue.clear() 102 | last_three_keys = list(self.queue_out.keys())[-self.max_length:] 103 | joined_queue = ''.join([self.queue_out[key] for key in last_three_keys]) 104 | for char in joined_queue: 105 | self.queue.append(char) 106 | 107 | joined_queue = ''.join(self.queue) 108 | # Initialize a variable to store the index of the last found stop string 109 | last_stop_str_index = -1 110 | 111 | # Iterate through the stop string list 112 | for stop_word in self.stop_sequence: 113 | # Find the last occurrence of the stop string in the output 114 | stop_word_index = joined_queue.rfind(stop_word) 115 | 116 | # If the stop string is found, compare the index with the previously found index 117 | if stop_word_index != -1 and stop_word_index > last_stop_str_index: 118 | last_stop_str_index = stop_word_index 119 | 120 | # Handle the last found stop string index here 121 | return last_stop_str_index 122 | 123 | def __repr__(self): 124 | return str(self.queue) 125 | 126 | 127 | # Copied from https://github.com/PygmalionAI/gradio-ui/ 128 | class _SentinelTokenStoppingCriteria(transformers.StoppingCriteria): 129 | 130 | def __init__(self, sentinel_token_ids: list, starting_idx: int): 131 | transformers.StoppingCriteria.__init__(self) 132 | self.sentinel_token_ids = sentinel_token_ids 133 | self.starting_idx = starting_idx 134 | 135 | def __call__(self, input_ids: torch.LongTensor, _scores: torch.FloatTensor) -> bool: 136 | for sample in input_ids: 137 | trimmed_sample = sample[self.starting_idx:] 138 | 139 
| for i in range(len(self.sentinel_token_ids)): 140 | # Can't unfold, output is still too tiny. Skip. 141 | if trimmed_sample.shape[-1] < self.sentinel_token_ids[i].shape[-1]: 142 | continue 143 | for window in trimmed_sample.unfold(0, self.sentinel_token_ids[i].shape[-1], 1): 144 | if torch.all(torch.eq(self.sentinel_token_ids[i][0], window)): 145 | return True 146 | return False 147 | 148 | 149 | class Stream(transformers.StoppingCriteria): 150 | def __init__(self, callback_func=None): 151 | self.callback_func = callback_func 152 | 153 | def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool: 154 | if shared.stop_everything: 155 | raise ValueError 156 | if self.callback_func is not None: 157 | self.callback_func(input_ids[0]) 158 | return False 159 | 160 | 161 | class Iteratorize: 162 | """ 163 | Transforms a function that takes a callback 164 | into a lazy iterator (generator). 165 | """ 166 | 167 | thread: ThreadWithException = None 168 | 169 | def __new__(cls, *args, **kwargs): 170 | # 创建新的实例 171 | instance = super().__new__(cls) 172 | # 在这里可以对实例进行额外的设置 173 | return instance 174 | 175 | def __init__(self, func, kwargs={}, callback=None): 176 | self.mfunc = func 177 | self.c_callback = callback 178 | self.q = Queue() 179 | self.sentinel = object() 180 | self.kwargs = kwargs 181 | 182 | def _callback(val): 183 | if shared.stop_everything: 184 | raise ValueError 185 | self.q.put(val) 186 | 187 | def gen(): 188 | try: 189 | ret = self.mfunc(callback=_callback, **self.kwargs) 190 | except ValueError: 191 | print("print(ValueError)") 192 | except: 193 | traceback.print_exc() 194 | print("traceback.print_exc()") 195 | self.q.put(self.sentinel) 196 | 197 | self.thread = ThreadWithException(target=gen) 198 | self.thread.start() 199 | 200 | def __iter__(self): 201 | shared.stop_everything = False 202 | return self 203 | 204 | def __next__(self): 205 | obj = self.q.get(True, None) 206 | if obj is self.sentinel: 207 | raise StopIteration 208 | else: 209 | return obj 210 | 211 | def __del__(self): 212 | shared.stop_everything = False 213 | self.q.empty() 214 | shared.loaderCheckPoint.clear_torch_cache() 215 | 216 | def __enter__(self): 217 | shared.stop_everything = False 218 | return self 219 | 220 | def __exit__(self, exc_type, exc_val, exc_tb): 221 | shared.stop_everything = True 222 | shared.loaderCheckPoint.clear_torch_cache() 223 | self.thread.raise_exception() 224 | -------------------------------------------------------------------------------- /loader/models/extensions/extensions.py: -------------------------------------------------------------------------------- 1 | import gc 2 | import traceback 3 | import torch 4 | 5 | # This iterator returns the extensions in the order specified in the command-line 6 | def iterator(): 7 | state_extensions = {} 8 | for name in sorted(state_extensions, key=lambda x: state_extensions[x][1]): 9 | if state_extensions[name][0]: 10 | yield getattr(extensions, name).script, name -------------------------------------------------------------------------------- /loader/models/extensions/llamacpp_model_alternative.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Based on 3 | https://github.com/abetlen/llama-cpp-python 4 | 5 | Documentation: 6 | https://abetlen.github.io/llama-cpp-python/ 7 | ''' 8 | 9 | from llama_cpp import Llama, LlamaCache 10 | 11 | from modules import shared 12 | from modules.callbacks import Iteratorize 13 | 14 | 15 | class LlamaCppModel: 16 | def 
__init__(self): 17 | self.initialized = False 18 | 19 | @classmethod 20 | def from_pretrained(self, path): 21 | result = self() 22 | 23 | params = { 24 | 'model_path': str(path), 25 | 'n_ctx': 2048, 26 | 'seed': 0, 27 | 'n_threads': shared.args.threads or None 28 | } 29 | self.model = Llama(**params) 30 | self.model.set_cache(LlamaCache) 31 | 32 | # This is ugly, but the model and the tokenizer are the same object in this library. 33 | return result, result 34 | 35 | def encode(self, string): 36 | if type(string) is str: 37 | string = string.encode() 38 | return self.model.tokenize(string) 39 | 40 | def generate(self, context="", token_count=20, temperature=1, top_p=1, top_k=50, repetition_penalty=1, callback=None): 41 | if type(context) is str: 42 | context = context.encode() 43 | tokens = self.model.tokenize(context) 44 | 45 | output = b"" 46 | count = 0 47 | for token in self.model.generate(tokens, top_k=top_k, top_p=top_p, temp=temperature, repeat_penalty=repetition_penalty): 48 | text = self.model.detokenize([token]) 49 | output += text 50 | if callback: 51 | callback(text.decode()) 52 | 53 | count += 1 54 | if count >= token_count or (token == self.model.token_eos()): 55 | break 56 | 57 | return output.decode() 58 | 59 | def generate_with_streaming(self, **kwargs): 60 | with Iteratorize(self.generate, kwargs, callback=None) as generator: 61 | reply = '' 62 | for token in generator: 63 | reply += token 64 | yield reply 65 | -------------------------------------------------------------------------------- /loader/models/extensions/thread_with_exception.py: -------------------------------------------------------------------------------- 1 | # Python program raising 2 | # exceptions in a python 3 | # thread 4 | 5 | import threading 6 | import ctypes 7 | import time 8 | 9 | 10 | class ThreadWithException(threading.Thread): 11 | 12 | def get_id(self): 13 | return self.ident 14 | 15 | def raise_exception(self): 16 | """raises the exception, performs cleanup if needed""" 17 | try: 18 | thread_id = self.get_id() 19 | tid = ctypes.c_long(thread_id) 20 | res = ctypes.pythonapi.PyThreadState_SetAsyncExc(tid, ctypes.py_object(SystemExit)) 21 | if res == 0: 22 | # pass 23 | raise ValueError("invalid thread id") 24 | elif res != 1: 25 | # """if it returns a number greater than one, you're in trouble, 26 | # and you should call it again with exc=NULL to revert the effect""" 27 | ctypes.pythonapi.PyThreadState_SetAsyncExc(tid, None) 28 | raise SystemError("PyThreadState_SetAsyncExc failed") 29 | except Exception as err: 30 | print(err) 31 | -------------------------------------------------------------------------------- /loader/models/loader/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | from .loader import * 3 | -------------------------------------------------------------------------------- /loader/models/loader/args.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | from configs.common_config import * 4 | 5 | 6 | # Additional argparse types 7 | def path(string): 8 | if not string: 9 | return '' 10 | s = os.path.expanduser(string) 11 | if not os.path.exists(s): 12 | raise argparse.ArgumentTypeError(f'No such file or directory: "{string}"') 13 | return s 14 | 15 | 16 | def file_path(string): 17 | if not string: 18 | return '' 19 | s = os.path.expanduser(string) 20 | if not os.path.isfile(s): 21 | raise argparse.ArgumentTypeError(f'No such file: "{string}"') 22 | return s 23 | 
24 | 
25 | def dir_path(string):
26 |     if not string:
27 |         return ''
28 |     s = os.path.expanduser(string)
29 |     if not os.path.isdir(s):
30 |         raise argparse.ArgumentTypeError(f'No such directory: "{string}"')
31 |     return s
32 | 
33 | 
34 | parser = argparse.ArgumentParser(prog='知识库',
35 |                                  description='知识库')
36 | 
37 | parser.add_argument('--no-remote-model', action='store_true', default=NO_REMOTE_MODEL, help='Do not fetch the model '
38 |                     'from the remote hub; '
39 |                     'pass `--no-remote-model` '
40 |                     'when loading a local '
41 |                     'model checkpoint.')
42 | parser.add_argument('--model', type=str, default=LLM_MODEL, help='Name of the model to load by default.')
43 | parser.add_argument('--lora', type=str, help='Name of the LoRA to apply to the model by default.')
44 | parser.add_argument("--model-dir", type=str, default=MODEL_DIR, help="Path to directory with all the models")
45 | parser.add_argument("--lora-dir", type=str, default=LORA_DIR, help="Path to directory with all the loras")
46 | 
47 | # Accelerate/transformers
48 | parser.add_argument('--load-in-8bit', action='store_true', default=LOAD_IN_8BIT,
49 |                     help='Load the model with 8-bit precision.')
50 | parser.add_argument('--bf16', action='store_true', default=BF16,
51 |                     help='Load the model with bfloat16 precision. Requires NVIDIA Ampere GPU.')
52 | 
53 | args = parser.parse_args([])
54 | # Generates a dict with a default value for each argument
55 | DEFAULT_ARGS = vars(args)
56 | 
-------------------------------------------------------------------------------- /loader/models/loader/loader.py: --------------------------------------------------------------------------------
1 | import gc
2 | import json
3 | import os
4 | import re
5 | import time
6 | from pathlib import Path
7 | from peft import PeftModel
8 | from typing import Optional, List, Dict, Tuple, Union
9 | from pynvml import (nvmlInit, nvmlDeviceGetCount, nvmlDeviceGetHandleByIndex,
10 |                     nvmlDeviceGetName, nvmlDeviceGetMemoryInfo, nvmlShutdown)
11 | def get_available_gpu(threshold=20000):
12 |     # Initialize NVML
13 |     nvmlInit()
14 |     # Get the number of GPU devices
15 |     device_count = nvmlDeviceGetCount()
16 | 
17 |     # Find GPU devices with available memory greater than the threshold
18 |     available_gpus = []
19 |     for i in range(device_count):
20 |         handle = nvmlDeviceGetHandleByIndex(i)
21 |         info = nvmlDeviceGetMemoryInfo(handle)
22 |         free_memory_mb = info.free / 1024 / 1024
23 | 
24 |         if free_memory_mb > threshold:
25 |             available_gpus.append(i)
26 | 
27 |     # Shutdown NVML
28 |     nvmlShutdown()
29 |     # available_gpus = ['0']
30 | 
31 |     return available_gpus
32 | 
33 | 
34 | def get_free_memory():
35 |     nvmlInit()
36 |     # Get the number of GPU devices
37 |     device_count = nvmlDeviceGetCount()
38 | 
39 |     # Collect the free memory of every GPU device
40 |     free_memory_gpus = []
41 |     for i in range(device_count):
42 |         handle = nvmlDeviceGetHandleByIndex(i)
43 |         info = nvmlDeviceGetMemoryInfo(handle)
44 |         free_memory_mb = info.free / 1024 / 1024
45 |         free_memory_gpus.append(free_memory_mb)
46 | 
47 |     # Shutdown NVML
48 |     nvmlShutdown()
49 |     return free_memory_gpus
50 | 
51 | import torch
52 | import transformers
53 | 
54 | from transformers import (AutoConfig, AutoModel, AutoModelForCausalLM,
55 |                           AutoTokenizer, BitsAndBytesConfig, LlamaTokenizer,BloomTokenizerFast, BloomConfig, BloomForCausalLM)
56 | from transformers.dynamic_module_utils import get_class_from_dynamic_module
57 | from transformers.modeling_utils import no_init_weights
58 | from transformers.utils import ContextManagers
59 | from 
accelerate import init_empty_weights 60 | from accelerate.utils import get_balanced_memory, infer_auto_device_map 61 | from configs.common_config import LLM_DEVICE,llm_model_dict 62 | 63 | 64 | 65 | 66 | 67 | 68 | class LoaderCheckPoint: 69 | """ 70 | 加载自定义 model CheckPoint 71 | """ 72 | # remote in the model on loader checkpoint 73 | no_remote_model: bool = False 74 | # 模型名称 75 | model_name: str = None 76 | tokenizer: object = None 77 | # 模型全路径 78 | model_path: str = None 79 | model: object = None 80 | model_config: object = None 81 | lora_names: set = [] 82 | model_dir: str = None 83 | lora_dir: str = None 84 | ptuning_dir: str = None 85 | use_ptuning_v2: bool = False 86 | # 如果开启了8bit量化加载,项目无法启动,参考此位置,选择合适的cuda版本,https://github.com/TimDettmers/bitsandbytes/issues/156 87 | load_in_8bit: bool = False 88 | is_llamacpp: bool = False 89 | bf16: bool = False 90 | params: object = None 91 | # 自定义设备网络 92 | device_map: Optional[Dict[str, int]] = None 93 | # 默认 cuda ,如果不支持cuda使用多卡, 如果不支持多卡 使用cpu 94 | llm_device = LLM_DEVICE 95 | 96 | def __init__(self, params: dict = None): 97 | """ 98 | 模型初始化 99 | :param params: 100 | """ 101 | self.model_path = None 102 | self.model = None 103 | self.tokenizer = None 104 | self.params = params or {} 105 | self.no_remote_model = params.get('no_remote_model', True) 106 | self.model_name = params.get('model', '') 107 | self.lora = params.get('lora', '') 108 | self.use_ptuning_v2 = params.get('use_ptuning_v2', False) 109 | self.model_dir = params.get('model_dir', '') 110 | self.lora_dir = params.get('lora_dir', '') 111 | self.ptuning_dir = params.get('ptuning_dir', 'ptuning-v2') 112 | self.load_in_8bit = params.get('load_in_8bit', False) 113 | self.bf16 = params.get('bf16', False) 114 | 115 | def _load_model_config(self, model_name): 116 | checkpoint = Path(f'{self.model_dir}/{model_name}') 117 | 118 | if self.model_path: 119 | checkpoint = Path(f'{self.model_path}') 120 | else: 121 | if not self.no_remote_model: 122 | checkpoint = model_name 123 | 124 | model_config = AutoConfig.from_pretrained(checkpoint, trust_remote_code=True) 125 | 126 | return model_config 127 | 128 | def _load_model(self, model_name): 129 | """ 130 | 加载自定义位置的model 131 | :param model_name: 132 | :return: 133 | """ 134 | print(f"Loading {model_name}...") 135 | t0 = time.time() 136 | 137 | checkpoint = Path(f'{self.model_dir}/{model_name}') 138 | 139 | self.is_llamacpp = len(list(checkpoint.glob('ggml*.bin'))) > 0 140 | 141 | if self.model_path: 142 | checkpoint = Path(f'{self.model_path}') 143 | else: 144 | if not self.no_remote_model: 145 | checkpoint = model_name 146 | 147 | if 'bloomz' in model_name.lower(): 148 | LoaderClass = BloomForCausalLM 149 | else: 150 | LoaderClass = AutoModelForCausalLM 151 | 152 | # Load the model in simple 16-bit mode by default 153 | if not any([self.llm_device.lower()=="cpu", 154 | self.load_in_8bit, self.is_llamacpp]): 155 | 156 | if torch.cuda.is_available() and self.llm_device.lower().startswith("cuda"): 157 | available_gpus = get_available_gpu(threshold=20000) 158 | print('available_gpus',available_gpus) 159 | if len(available_gpus)>0: 160 | available_gpu = available_gpus[0] 161 | target_device = torch.device(f'cuda:{str(available_gpu)}') 162 | print('target_device==',target_device) 163 | model = ( 164 | LoaderClass.from_pretrained(checkpoint, 165 | config=self.model_config, 166 | torch_dtype=torch.bfloat16 if self.bf16 else torch.float16, 167 | trust_remote_code=True) 168 | .half() 169 | .to(target_device) 170 | ) 171 | else: 172 | print('没有满足要求的GPU设备可用') 
173 | model = None 174 | 175 | else: 176 | # print( 177 | # "Warning: torch.cuda.is_available() returned False.\nThis means that no GPU has been " 178 | # "detected.\nFalling back to CPU mode.\n") 179 | model = ( 180 | AutoModel.from_pretrained( 181 | checkpoint, 182 | config=self.model_config, 183 | trust_remote_code=True) 184 | .float() 185 | .to(self.llm_device) 186 | ) 187 | 188 | elif self.is_llamacpp: 189 | from models.extensions.llamacpp_model_alternative import LlamaCppModel 190 | 191 | model_file = list(checkpoint.glob('ggml*.bin'))[0] 192 | print(f"llama.cpp weights detected: {model_file}\n") 193 | 194 | model, tokenizer = LlamaCppModel.from_pretrained(model_file) 195 | return model, tokenizer 196 | 197 | # Custom 198 | else: 199 | params = {"low_cpu_mem_usage": True} 200 | 201 | if not self.llm_device.lower().startswith("cuda"): 202 | raise SystemError("8bit 模型需要 CUDA 支持,或者改用量化后模型!") 203 | else: 204 | params["device_map"] = 'auto' 205 | params["trust_remote_code"] = True 206 | if self.load_in_8bit: 207 | params['quantization_config'] = BitsAndBytesConfig(load_in_8bit=True, 208 | llm_int8_enable_fp32_cpu_offload=False) 209 | elif self.bf16: 210 | params["torch_dtype"] = torch.bfloat16 211 | else: 212 | params["torch_dtype"] = torch.float16 213 | 214 | if self.load_in_8bit and params.get('max_memory', None) is not None and params['device_map'] == 'auto': 215 | config = AutoConfig.from_pretrained(checkpoint) 216 | with init_empty_weights(): 217 | model = LoaderClass.from_config(config) 218 | model.tie_weights() 219 | if self.device_map is not None: 220 | params['device_map'] = self.device_map 221 | else: 222 | params['device_map'] = infer_auto_device_map( 223 | model, 224 | dtype=torch.int8, 225 | max_memory=params['max_memory'], 226 | no_split_module_classes=model._no_split_modules 227 | ) 228 | 229 | model = LoaderClass.from_pretrained(checkpoint, **params) 230 | 231 | # Loading the tokenizer 232 | if type(model) is transformers.LlamaForCausalLM: 233 | tokenizer = LlamaTokenizer.from_pretrained(checkpoint, clean_up_tokenization_spaces=True) 234 | # Leaving this here until the LLaMA tokenizer gets figured out. 235 | # For some people this fixes things, for others it causes an error. 236 | try: 237 | tokenizer.eos_token_id = 2 238 | tokenizer.bos_token_id = 1 239 | tokenizer.pad_token_id = 0 240 | except Exception as e: 241 | print(e) 242 | pass 243 | elif 'bloomz' in model_name.lower(): 244 | tokenizer = BloomTokenizerFast.from_pretrained(checkpoint, 245 | padding_side='left' 246 | ) 247 | else: 248 | tokenizer = AutoTokenizer.from_pretrained(checkpoint, trust_remote_code=True) 249 | 250 | print(f"Loaded the model in {(time.time() - t0):.2f} seconds.") 251 | return model, tokenizer 252 | 253 | 254 | def add_lora_to_model(self, lora): 255 | if 'pulse' in self.model_name.lower(): 256 | peft_path = f"finetune/pulse/output/{lora}" 257 | load_type = torch.float16 258 | self.model = PeftModel.from_pretrained(self.model, peft_path, torch_dtype=load_type) 259 | 260 | print('finished load lora:',lora) 261 | 262 | def _add_lora_to_model(self, lora_names): 263 | # 目前加载的lora 264 | prior_set = set(self.lora_names) 265 | # 需要加载的 266 | added_set = set(lora_names) - prior_set 267 | # 删除的lora 268 | removed_set = prior_set - set(lora_names) 269 | self.lora_names = list(lora_names) 270 | 271 | # Nothing to do = skip. 272 | if len(added_set) == 0 and len(removed_set) == 0: 273 | return 274 | 275 | # Only adding, and already peft? Do it the easy way. 
276 | if len(removed_set) == 0 and len(prior_set) > 0: 277 | print(f"Adding the LoRA(s) named {added_set} to the model...") 278 | for lora in added_set: 279 | self.model.load_adapter(Path(f"{self.lora_dir}/{lora}"), lora) 280 | return 281 | 282 | # If removing anything, disable all and re-add. 283 | if len(removed_set) > 0: 284 | self.model.disable_adapter() 285 | 286 | if len(lora_names) > 0: 287 | print("Applying the following LoRAs to {}: {}".format(self.model_name, ', '.join(lora_names))) 288 | params = {} 289 | if self.llm_device.lower() != "cpu": 290 | params['dtype'] = self.model.dtype 291 | if hasattr(self.model, "hf_device_map"): 292 | params['device_map'] = {"base_model.model." + k: v for k, v in self.model.hf_device_map.items()} 293 | elif self.load_in_8bit: 294 | params['device_map'] = {'': 0} 295 | self.model.resize_token_embeddings(len(self.tokenizer)) 296 | 297 | self.model = PeftModel.from_pretrained(self.model, Path(f"{self.lora_dir}/{lora_names[0]}"), **params) 298 | 299 | for lora in lora_names[1:]: 300 | self.model.load_adapter(Path(f"{self.lora_dir}/{lora}"), lora) 301 | 302 | if not self.load_in_8bit and self.llm_device.lower() != "cpu": 303 | 304 | if not hasattr(self.model, "hf_device_map"): 305 | if torch.has_mps: 306 | device = torch.device('mps') 307 | self.model = self.model.to(device) 308 | else: 309 | self.model = self.model.cuda() 310 | 311 | def clear_torch_cache(self): 312 | gc.collect() 313 | if self.llm_device.lower() != "cpu": 314 | if torch.has_mps: 315 | try: 316 | from torch.mps import empty_cache 317 | empty_cache() 318 | except Exception as e: 319 | print(e) 320 | print( 321 | "如果您使用的是 macOS 建议将 pytorch 版本升级至 2.0.0 或更高版本,以支持及时清理 torch 产生的内存占用。") 322 | elif torch.has_cuda: 323 | device_id = "0" if torch.cuda.is_available() else None 324 | CUDA_DEVICE = f"{self.llm_device}:{device_id}" if device_id else self.llm_device 325 | with torch.cuda.device(CUDA_DEVICE): 326 | torch.cuda.empty_cache() 327 | torch.cuda.ipc_collect() 328 | else: 329 | print("未检测到 cuda 或 mps,暂不支持清理显存") 330 | 331 | def unload_model(self): 332 | del self.model 333 | del self.tokenizer 334 | self.model = self.tokenizer = None 335 | self.clear_torch_cache() 336 | 337 | def set_model_path(self, model_path): 338 | self.model_path = model_path 339 | 340 | def reload_model(self): 341 | self.unload_model() 342 | self.model_config = self._load_model_config(self.model_name) 343 | 344 | if self.use_ptuning_v2: 345 | try: 346 | prefix_encoder_file = open(Path(f'{self.ptuning_dir}/config.json'), 'r') 347 | prefix_encoder_config = json.loads(prefix_encoder_file.read()) 348 | prefix_encoder_file.close() 349 | self.model_config.pre_seq_len = prefix_encoder_config['pre_seq_len'] 350 | self.model_config.prefix_projection = prefix_encoder_config['prefix_projection'] 351 | except Exception as e: 352 | print("加载PrefixEncoder config.json失败") 353 | 354 | self.model, self.tokenizer = self._load_model(self.model_name) 355 | 356 | if self.lora: 357 | self.add_lora_to_model(self.lora) 358 | 359 | if self.use_ptuning_v2: 360 | try: 361 | prefix_state_dict = torch.load(Path(f'{self.ptuning_dir}/pytorch_model.bin')) 362 | new_prefix_state_dict = {} 363 | for k, v in prefix_state_dict.items(): 364 | if k.startswith("transformer.prefix_encoder."): 365 | new_prefix_state_dict[k[len("transformer.prefix_encoder."):]] = v 366 | self.model.transformer.prefix_encoder.load_state_dict(new_prefix_state_dict) 367 | self.model.transformer.prefix_encoder.float() 368 | except Exception as e: 369 | 
print("加载PrefixEncoder模型参数失败") 370 | 371 | self.model = self.model.eval() 372 | -------------------------------------------------------------------------------- /loader/models/shared.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | from loader.models.loader.args import parser 4 | from loader.models.loader import LoaderCheckPoint 5 | from configs.common_config import (llm_model_dict, LLM_MODEL) 6 | from loader.models.base import BaseAnswer 7 | """迭代器是否停止状态""" 8 | stop_everything = False 9 | 10 | loaderCheckPoint: LoaderCheckPoint = None 11 | 12 | 13 | def loaderLLM(llm_model: str = None, no_remote_model: bool = False, use_ptuning_v2: bool = False) -> BaseAnswer: 14 | """ 15 | init llm_model_ins LLM 16 | :param llm_model: model_name 17 | :param no_remote_model: remote in the model on loader checkpoint, if your load local model to add the ` --no-remote-model 18 | :param use_ptuning_v2: Use p-tuning-v2 PrefixEncoder 19 | :return: 20 | """ 21 | pre_model_name = loaderCheckPoint.model_name 22 | llm_model_info = llm_model_dict[pre_model_name] 23 | 24 | if no_remote_model: 25 | loaderCheckPoint.no_remote_model = no_remote_model 26 | 27 | if use_ptuning_v2: 28 | loaderCheckPoint.use_ptuning_v2 = use_ptuning_v2 29 | 30 | if llm_model: 31 | llm_model_info = llm_model_dict[llm_model] 32 | 33 | if loaderCheckPoint.no_remote_model: 34 | loaderCheckPoint.model_name = llm_model_info['name'] 35 | else: 36 | loaderCheckPoint.model_name = llm_model_info['pretrained_model_name'] 37 | 38 | loaderCheckPoint.model_path = llm_model_info["local_model_path"] 39 | 40 | loaderCheckPoint.reload_model() 41 | 42 | provides_class = getattr(sys.modules['loader.models'], llm_model_info['provides']) 43 | modelInsLLM = provides_class(checkPoint=loaderCheckPoint) 44 | return modelInsLLM 45 | -------------------------------------------------------------------------------- /loader/pdf_loader.py: -------------------------------------------------------------------------------- 1 | """Loader that loads image files.""" 2 | from typing import List 3 | 4 | from langchain.document_loaders.unstructured import UnstructuredFileLoader 5 | from paddleocr import PaddleOCR 6 | import os 7 | import fitz 8 | import nltk 9 | from configs.common_config import NLTK_DATA_PATH 10 | 11 | nltk.data.path = [NLTK_DATA_PATH] + nltk.data.path 12 | 13 | class UnstructuredPaddlePDFLoader(UnstructuredFileLoader): 14 | """Loader that uses unstructured to load image files, such as PNGs and JPGs.""" 15 | 16 | def _get_elements(self) -> List: 17 | def pdf_ocr_txt(filepath, dir_path="tmp_files"): 18 | full_dir_path = os.path.join(os.path.dirname(filepath), dir_path) 19 | if not os.path.exists(full_dir_path): 20 | os.makedirs(full_dir_path) 21 | ocr = PaddleOCR(use_angle_cls=True, lang="ch", use_gpu=False, show_log=False) 22 | doc = fitz.open(filepath) 23 | txt_file_path = os.path.join(full_dir_path, f"{os.path.split(filepath)[-1]}.txt") 24 | img_name = os.path.join(full_dir_path, 'tmp.png') 25 | with open(txt_file_path, 'w', encoding='utf-8') as fout: 26 | for i in range(doc.page_count): 27 | page = doc[i] 28 | text = page.get_text("") 29 | fout.write(text) 30 | fout.write("\n") 31 | 32 | img_list = page.get_images() 33 | for img in img_list: 34 | pix = fitz.Pixmap(doc, img[0]) 35 | if pix.n - pix.alpha >= 4: 36 | pix = fitz.Pixmap(fitz.csRGB, pix) 37 | pix.save(img_name) 38 | 39 | result = ocr.ocr(img_name) 40 | ocr_result = [i[1][0] for line in result for i in line] 41 | 
fout.write("\n".join(ocr_result)) 42 | os.remove(img_name) 43 | return txt_file_path 44 | 45 | txt_file_path = pdf_ocr_txt(self.file_path) 46 | from unstructured.partition.text import partition_text 47 | return partition_text(filename=txt_file_path, **self.unstructured_kwargs) 48 | 49 | 50 | if __name__ == "__main__": 51 | filepath = os.path.join(os.path.dirname(os.path.dirname(__file__)), "content", "samples", "test.pdf") 52 | loader = UnstructuredPaddlePDFLoader(filepath, mode="elements") 53 | docs = loader.load() 54 | for doc in docs: 55 | print(doc) 56 | -------------------------------------------------------------------------------- /loader/textsplitter/__init__.py: -------------------------------------------------------------------------------- 1 | from .chinese_text_splitter import ChineseTextSplitter 2 | from .ali_text_splitter import AliTextSplitter -------------------------------------------------------------------------------- /loader/textsplitter/ali_text_splitter.py: -------------------------------------------------------------------------------- 1 | from langchain.text_splitter import CharacterTextSplitter 2 | import re 3 | from typing import List 4 | 5 | 6 | class AliTextSplitter(CharacterTextSplitter): 7 | def __init__(self, pdf: bool = False, **kwargs): 8 | super().__init__(**kwargs) 9 | self.pdf = pdf 10 | 11 | def split_text(self, text: str) -> List[str]: 12 | # use_document_segmentation参数指定是否用语义切分文档,此处采取的文档语义分割模型为达摩院开源的nlp_bert_document-segmentation_chinese-base,论文见https://arxiv.org/abs/2107.09278 13 | # 如果使用模型进行文档语义切分,那么需要安装modelscope[nlp]:pip install "modelscope[nlp]" -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html 14 | # 考虑到使用了三个模型,可能对于低配置gpu不太友好,因此这里将模型load进cpu计算,有需要的话可以替换device为自己的显卡id 15 | if self.pdf: 16 | text = re.sub(r"\n{3,}", r"\n", text) 17 | text = re.sub('\s', " ", text) 18 | text = re.sub("\n\n", "", text) 19 | from modelscope.pipelines import pipeline 20 | 21 | p = pipeline( 22 | task="document-segmentation", 23 | model='damo/nlp_bert_document-segmentation_chinese-base', 24 | device="cpu") 25 | result = p(documents=text) 26 | sent_list = [i for i in result["text"].split("\n\t") if i] 27 | return sent_list 28 | -------------------------------------------------------------------------------- /loader/textsplitter/chinese_text_splitter.py: -------------------------------------------------------------------------------- 1 | from langchain.text_splitter import CharacterTextSplitter 2 | import re 3 | from typing import List 4 | from configs.common_config import SENTENCE_SIZE 5 | 6 | 7 | class ChineseTextSplitter(CharacterTextSplitter): 8 | def __init__(self, pdf: bool = False, sentence_size: int = SENTENCE_SIZE, **kwargs): 9 | super().__init__(**kwargs) 10 | self.pdf = pdf 11 | self.sentence_size = sentence_size 12 | 13 | def split_text1(self, text: str) -> List[str]: 14 | if self.pdf: 15 | text = re.sub(r"\n{3,}", "\n", text) 16 | text = re.sub('\s', ' ', text) 17 | text = text.replace("\n\n", "") 18 | sent_sep_pattern = re.compile('([﹒﹔﹖﹗.。!?]["’”」』]{0,2}|(?=["‘“「『]{1,2}|$))') # del :; 19 | sent_list = [] 20 | for ele in sent_sep_pattern.split(text): 21 | if sent_sep_pattern.match(ele) and sent_list: 22 | sent_list[-1] += ele 23 | elif ele: 24 | sent_list.append(ele) 25 | return sent_list 26 | 27 | def split_text(self, text: str) -> List[str]: ##此处需要进一步优化逻辑 28 | if self.pdf: 29 | text = re.sub(r"\n{3,}", r"\n", text) 30 | text = re.sub('\s', " ", text) 31 | text = re.sub("\n\n", "", text) 32 | 33 | text = re.sub(r'([;;.!?。!?\?])([^”’])', 
r"\1\n\2", text) # 单字符断句符 34 | text = re.sub(r'(\.{6})([^"’”」』])', r"\1\n\2", text) # 英文省略号 35 | text = re.sub(r'(\…{2})([^"’”」』])', r"\1\n\2", text) # 中文省略号 36 | text = re.sub(r'([;;!?。!?\?]["’”」』]{0,2})([^;;!?,。!?\?])', r'\1\n\2', text) 37 | # 如果双引号前有终止符,那么双引号才是句子的终点,把分句符\n放到双引号后,注意前面的几句都小心保留了双引号 38 | text = text.rstrip() # 段尾如果有多余的\n就去掉它 39 | # 很多规则中会考虑分号;,但是这里我把它忽略不计,破折号、英文双引号等同样忽略,需要的再做些简单调整即可。 40 | ls = [i for i in text.split("\n") if i] 41 | for ele in ls: 42 | if len(ele) > self.sentence_size: 43 | ele1 = re.sub(r'([,,.]["’”」』]{0,2})([^,,.])', r'\1\n\2', ele) 44 | ele1_ls = ele1.split("\n") 45 | for ele_ele1 in ele1_ls: 46 | if len(ele_ele1) > self.sentence_size: 47 | ele_ele2 = re.sub(r'([\n]{1,}| {2,}["’”」』]{0,2})([^\s])', r'\1\n\2', ele_ele1) 48 | ele2_ls = ele_ele2.split("\n") 49 | for ele_ele2 in ele2_ls: 50 | if len(ele_ele2) > self.sentence_size: 51 | ele_ele3 = re.sub('( ["’”」』]{0,2})([^ ])', r'\1\n\2', ele_ele2) 52 | ele2_id = ele2_ls.index(ele_ele2) 53 | ele2_ls = ele2_ls[:ele2_id] + [i for i in ele_ele3.split("\n") if i] + ele2_ls[ 54 | ele2_id + 1:] 55 | ele_id = ele1_ls.index(ele_ele1) 56 | ele1_ls = ele1_ls[:ele_id] + [i for i in ele2_ls if i] + ele1_ls[ele_id + 1:] 57 | 58 | id = ls.index(ele) 59 | ls = ls[:id] + [i for i in ele1_ls if i] + ls[id + 1:] 60 | return ls 61 | -------------------------------------------------------------------------------- /loader/utils/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | def torch_gc(): 4 | if torch.cuda.is_available(): 5 | # with torch.cuda.device(DEVICE): 6 | torch.cuda.empty_cache() 7 | torch.cuda.ipc_collect() 8 | elif torch.backends.mps.is_available(): 9 | try: 10 | from torch.mps import empty_cache 11 | empty_cache() 12 | except Exception as e: 13 | print(e) 14 | print("如果您使用的是 macOS 建议将 pytorch 版本升级至 2.0.0 或更高版本,以支持及时清理 torch 产生的内存占用。") -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | accelerate==0.18.0 2 | aiofiles==23.1.0 3 | aiohttp==3.8.4 4 | aiosignal==1.3.1 5 | altair==5.0.1 6 | antlr4-python3-runtime==4.9.3 7 | anyio==3.7.0 8 | argilla==1.11.0 9 | astor==0.8.1 10 | async-timeout==4.0.2 11 | attrdict==2.0.1 12 | attrs==23.1.0 13 | azure-core==1.27.1 14 | Babel==2.12.1 15 | backoff==2.2.1 16 | bce-python-sdk==0.8.83 17 | beautifulsoup4==4.12.2 18 | bitsandbytes==0.39.1 19 | blinker==1.6.2 20 | Brotli==1.0.9 21 | cachetools==5.3.1 22 | certifi==2023.5.7 23 | cffi==1.15.1 24 | chardet==5.1.0 25 | charset-normalizer==3.1.0 26 | click==8.1.3 27 | cmake==3.26.4 28 | coloredlogs==15.0.1 29 | commonmark==0.9.1 30 | contourpy==1.1.0 31 | cpm-kernels==1.0.11 32 | cryptography==41.0.1 33 | cssselect==1.2.0 34 | cssutils==2.7.1 35 | cycler==0.11.0 36 | Cython==0.29.35 37 | dataclasses-json==0.5.8 38 | decorator==5.1.1 39 | Deprecated==1.2.14 40 | dill==0.3.6 41 | effdet==0.4.1 42 | et-xmlfile==1.1.0 43 | exceptiongroup==1.1.1 44 | faiss-cpu==1.7.4 45 | fastapi==0.95.1 46 | ffmpy==0.3.0 47 | filelock==3.12.2 48 | filetype==1.2.0 49 | fire==0.5.0 50 | Flask==2.3.2 51 | flask-babel==3.1.0 52 | flatbuffers==23.5.26 53 | fonttools==4.40.0 54 | frozenlist==1.3.3 55 | fsspec==2023.6.0 56 | future==0.18.3 57 | gevent==22.10.2 58 | geventhttpclient==2.0.2 59 | gradio==3.28.3 60 | gradio_client==0.2.7 61 | greenlet==2.0.2 62 | grpcio==1.56.0 63 | h11==0.14.0 64 | httpcore==0.16.3 65 | httpx==0.23.3 66 | huggingface-hub==0.15.1 67 | 
humanfriendly==10.0 68 | icetk==0.0.7 69 | idna==3.4 70 | imageio==2.31.1 71 | imgaug==0.4.0 72 | importlib-metadata==6.7.0 73 | importlib-resources==5.12.0 74 | iopath==0.1.10 75 | itsdangerous==2.1.2 76 | Jinja2==3.1.2 77 | joblib==1.2.0 78 | jsonschema==4.17.3 79 | kiwisolver==1.4.4 80 | langchain==0.0.174 81 | layoutparser==0.3.4 82 | lazy_loader==0.2 83 | linkify-it-py==2.0.2 84 | lit==16.0.6 85 | lmdb==1.4.1 86 | lxml==4.9.2 87 | Markdown==3.4.3 88 | markdown-it-py==2.2.0 89 | MarkupSafe==2.1.3 90 | marshmallow==3.19.0 91 | marshmallow-enum==1.5.1 92 | matplotlib==3.7.1 93 | mdit-py-plugins==0.3.3 94 | mdurl==0.1.2 95 | monotonic==1.6 96 | mpmath==1.3.0 97 | msg-parser==1.2.0 98 | multidict==6.0.4 99 | multiprocess==0.70.14 100 | mypy-extensions==1.0.0 101 | networkx==3.1 102 | nltk==3.8.1 103 | numexpr==2.8.4 104 | numpy==1.23.5 105 | nvidia-cublas-cu11==11.10.3.66 106 | nvidia-cuda-cupti-cu11==11.7.101 107 | nvidia-cuda-nvrtc-cu11==11.7.99 108 | nvidia-cuda-runtime-cu11==11.7.99 109 | nvidia-cudnn-cu11==8.5.0.96 110 | nvidia-cufft-cu11==10.9.0.58 111 | nvidia-curand-cu11==10.2.10.91 112 | nvidia-cusolver-cu11==11.4.0.1 113 | nvidia-cusparse-cu11==11.7.4.91 114 | nvidia-nccl-cu11==2.14.3 115 | nvidia-nvtx-cu11==11.7.91 116 | olefile==0.46 117 | omegaconf==2.3.0 118 | onnx==1.12.0 119 | onnxruntime==1.15.1 120 | openapi-schema-pydantic==1.2.4 121 | opencv-contrib-python==4.6.0.66 122 | opencv-python==4.6.0.66 123 | openpyxl==3.1.2 124 | opt-einsum==3.3.0 125 | orjson==3.9.1 126 | packaging==23.1 127 | paddle-bfloat==0.1.7 128 | paddleocr==2.6.1.3 129 | paddlepaddle==2.4.2 130 | pandas==1.5.3 131 | pdf2docx==0.5.6 132 | pdf2image==1.16.3 133 | pdfminer.six==20221105 134 | pdfplumber==0.9.0 135 | peft==0.3.0 136 | Pillow==9.5.0 137 | portalocker==2.7.0 138 | premailer==3.10.0 139 | protobuf==3.18.3 140 | psutil==5.9.5 141 | pyclipper==1.3.0.post4 142 | pycocotools==2.0.6 143 | pycparser==2.21 144 | pycryptodome==3.18.0 145 | pydantic==1.10.9 146 | pydub==0.25.1 147 | Pygments==2.15.1 148 | PyMuPDF==1.20.2 149 | pynvml==11.5.0 150 | pypandoc==1.11 151 | pyparsing==3.1.0 152 | pypinyin==0.48.0 153 | pyrsistent==0.19.3 154 | pytesseract==0.3.10 155 | python-dateutil==2.8.2 156 | python-docx==0.8.11 157 | python-magic==0.4.27 158 | python-multipart==0.0.6 159 | python-pptx==0.6.21 160 | python-rapidjson==1.10 161 | pytz==2023.3 162 | PyWavelets==1.4.1 163 | PyYAML==6.0 164 | rapidfuzz==3.1.1 165 | rarfile==4.0 166 | regex==2023.6.3 167 | requests==2.28.2 168 | rfc3986==1.5.0 169 | rich==13.0.1 170 | ruamel.yaml==0.17.32 171 | ruamel.yaml.clib==0.2.7 172 | safetensors==0.3.1 173 | scikit-image==0.21.0 174 | scikit-learn==1.2.2 175 | scipy==1.11.0 176 | semantic-version==2.10.0 177 | sentence-transformers==2.2.2 178 | sentencepiece==0.1.99 179 | shapely==2.0.1 180 | six==1.16.0 181 | sniffio==1.3.0 182 | soupsieve==2.4.1 183 | SQLAlchemy==2.0.17 184 | starlette==0.26.1 185 | sympy==1.12 186 | tabulate==0.9.0 187 | tenacity==8.2.2 188 | termcolor==2.3.0 189 | threadpoolctl==3.1.0 190 | tifffile==2023.4.12 191 | timm==0.9.2 192 | tokenizers==0.13.3 193 | toolz==0.12.0 194 | torch==2.0.1 195 | torchvision==0.15.2 196 | tqdm==4.65.0 197 | transformers==4.29.1 198 | triton==2.0.0 199 | tritonclient==2.34.0 200 | typer==0.9.0 201 | typing-inspect==0.9.0 202 | typing_extensions==4.6.3 203 | tzdata==2023.3 204 | uc-micro-py==1.0.2 205 | unstructured==0.7.9 206 | unstructured-inference==0.5.1 207 | urllib3==1.26.16 208 | uvicorn==0.21.1 209 | visualdl==2.5.0 210 | Wand==0.6.11 211 | 
websockets==11.0.3 212 | Werkzeug==2.3.6 213 | wrapt==1.14.1 214 | x2paddle==1.4.1 215 | xlrd==2.0.1 216 | XlsxWriter==3.1.2 217 | yarl==1.9.2 218 | zipp==3.15.0 219 | zope.event==5.0 220 | zope.interface==6.0 221 | -------------------------------------------------------------------------------- /vector_store/drug_kb/index.faiss: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JuneYaooo/medical_kb_chatbot/be4ff60c47ffb31ac3052f82b5ed607d7c7ab089/vector_store/drug_kb/index.faiss -------------------------------------------------------------------------------- /vector_store/drug_kb/index.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JuneYaooo/medical_kb_chatbot/be4ff60c47ffb31ac3052f82b5ed607d7c7ab089/vector_store/drug_kb/index.pkl --------------------------------------------------------------------------------
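
A minimal sketch of loading and querying the bundled drug_kb vector store with the pinned langchain 0.0.174 API. The embedding model name below is a placeholder assumption; it must match the embedding model configured in configs/common_config.py that was used to build the index, otherwise the vector dimensions will not agree.

```python
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import FAISS

# Placeholder embedding model; replace with the model configured in configs/common_config.py.
embeddings = HuggingFaceEmbeddings(model_name="GanymedeNil/text2vec-large-chinese")

# Load the prebuilt FAISS index shipped under vector_store/drug_kb and run a similarity search.
store = FAISS.load_local("vector_store/drug_kb", embeddings)
for doc in store.similarity_search("What is the recommended dosage of amoxicillin?", k=3):
    print(doc.page_content)
```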