├── .gitignore
├── .idea
├── .gitignore
├── inspectionProfiles
│ └── profiles_settings.xml
├── medical_kb_chatbot.iml
├── modules.xml
└── vcs.xml
├── README.md
├── README_en.md
├── agent
└── bing_search.py
├── app.py
├── chains
├── local_doc_qa.py
├── modules
│ ├── embeddings.py
│ ├── excel_load.py
│ ├── json_load.py
│ └── vectorstores.py
└── text_load.py
├── configs
├── common_config.py
└── test_ass.yaml
├── demo_data
├── kb_drug_demo.jsonl
└── lora_demo.xlsx
├── environment.yml
├── finetune
├── pulse
│ ├── configs
│ │ └── lora_config_bloom.json
│ ├── convert_to_conv_data.py
│ ├── finetune.py
│ └── src
│ │ ├── sample_generator.py
│ │ ├── trainer.py
│ │ └── utils.py
└── pulse_utils.py
├── img
├── 1.jpg
├── 2.jpg
├── 3.jpg
└── 4.jpg
├── loader
├── __init__.py
├── image_loader.py
├── models
│ ├── __init__.py
│ ├── __main__.py
│ ├── base.py
│ ├── bloomz_llm.py
│ ├── extensions
│ │ ├── callback.py
│ │ ├── extensions.py
│ │ ├── llamacpp_model_alternative.py
│ │ └── thread_with_exception.py
│ ├── loader
│ │ ├── __init__.py
│ │ ├── args.py
│ │ └── loader.py
│ └── shared.py
├── pdf_loader.py
├── textsplitter
│ ├── __init__.py
│ ├── ali_text_splitter.py
│ └── chinese_text_splitter.py
└── utils
│ └── __init__.py
├── requirements.txt
└── vector_store
└── drug_kb
├── index.faiss
└── index.pkl
/.gitignore:
--------------------------------------------------------------------------------
1 | **/__pycache__
--------------------------------------------------------------------------------
/.idea/.gitignore:
--------------------------------------------------------------------------------
1 | # Default ignored files
2 | /shelf/
3 | /workspace.xml
4 |
--------------------------------------------------------------------------------
/.idea/inspectionProfiles/profiles_settings.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
--------------------------------------------------------------------------------
/.idea/medical_kb_chatbot.iml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
--------------------------------------------------------------------------------
/.idea/modules.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
--------------------------------------------------------------------------------
/.idea/vcs.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | [[中文版](https://github.com/JuneYaooo/medical_kb_chatbot/blob/main/README.md)] [[English](https://github.com/JuneYaooo/medical_kb_chatbot/blob/main/README_en.md)]
2 |
3 | # 医疗知识聊天机器人
4 |
5 | 欢迎使用 医疗知识聊天机器人,这是一款基于 PULSE 模型,引入知识库及微调训练的聊天机器人,旨在提供更实用的医疗相关功能和服务。用户可以自己添加相关知识库,进行模型微调,体验更丰富的应用场景:
6 |
7 | ## 你可以用它来做什么
8 |
9 | - **药物查询**:提供药物数据库,用户可以搜索特定药物的信息,如用途、剂量、副作用等。
10 |
11 | - **病症解释**:提供常见疾病、症状和医学术语的解释和定义,帮助用户更好地理解医学知识。
12 |
13 | - **医疗客服**:添加相关医疗产品文档,支持用户与聊天机器人进行个性化对话,回答医疗产品相关问题,提供准确和可靠的信息。
14 |
15 | ## 使用方法
16 |
17 | ### 下载模型与修改配置文件
18 |
19 | 如果直接使用有问题,可以将PULSE模型下载到本地:https://huggingface.co/OpenMEDLab/PULSE-7bv5
20 |
21 | 然后在configs/common_config.py文件中将模型路径修改为本地路径,如修改embedding_model_dict和llm_model_dict中的路径即可。
22 |
23 | ### 安装
24 |
25 | 首先,克隆本项目到本地计算机:
26 |
27 | ```
28 | git clone https://github.com/JuneYaooo/medical_kb_chatbot.git
29 | ```
30 |
31 | #### 使用 pip 安装
32 |
33 | 确保您的计算机上已安装以下依赖项:
34 |
35 | - Python 3.9
36 | - pip 包管理器
37 |
38 | 进入项目目录并安装必要的依赖项:
39 |
40 | ```
41 | cd medical_kb_chatbot
42 | pip install -r requirements.txt
43 | ```
44 |
45 | #### 使用 conda 安装
46 |
47 | 确保您的计算机上已安装以下依赖项:
48 |
49 | - Anaconda 或 Miniconda
50 |
51 | 进入项目目录并创建一个新的 conda 环境:
52 |
53 | ```
54 | cd medical_kb_chatbot
55 | conda env create -f environment.yml
56 | ```
57 |
58 | 激活新创建的环境:
59 |
60 | ```
61 | conda activate kb_chat
62 | ```
63 |
64 | 然后运行聊天机器人:
65 |
66 | ```
67 | python app.py
68 | ```
69 |
70 | ### 使用说明
71 | #### 可选择在知识库页面配置知识库
72 | - 支持excel、json、非图片类型的pdf、word、txt等格式
73 | - 其中excel、json需要按要求格式上传
74 | - 鼓励挂载一些医疗知识库尝试效果,有好的案例欢迎分享
75 | - 提供了一点点药品[demo数据](https://github.com/JuneYaooo/medical_kb_chatbot/blob/main/demo_data/kb_drug_demo.jsonl) ,可以下下来试一下
76 |
77 |
78 | 
79 |
80 | #### 可选择使用lora微调模型
81 | - 微调目前最小需要24G显卡(~一张3090)
82 | - 微调结束后,可看到更新时间
83 | - 提供了一点点训练[demo数据](https://github.com/JuneYaooo/medical_kb_chatbot/blob/main/demo_data/lora_demo.xlsx) ,可以下下来试一下
84 |
85 |
86 | 
87 |
88 | #### 在医疗小助手页面选择配置自己的知识库聊天小助手(可自由选择是否使用某个知识库/微调的lora)
89 | - 配置prompt可参考模板多尝试,有发现好的prompt欢迎分享
90 | - prompt 设置可以参考如下格式
91 | ```
92 | 假设你是用药助手,请根据文档来回复,如果文档内容为空或者None,则忽略,文档:{context}\n{chat_history}User:{question}Helper:
93 | ```
94 |
95 | 
96 |
97 | #### 配置好小助手,来对话测试页面试试吧
98 |
99 | - 选择一个已经配置好的聊天小助手,来体验一下吧
100 |
101 | 
102 |
103 | ## 致谢
104 |
105 | - [PULSE](https://github.com/openmedlab/PULSE): 本项目模型来源于PULSE
106 | - [langchain-ChatGLM](https://github.com/imClumsyPanda/langchain-ChatGLM): 本项目知识库部分参考了langchain-ChatGLM的代码
107 | - [BELLE](https://github.com/LianjiaTech/BELLE): 本项目Lora微调部分参考了BELLE的代码
108 |
109 | ## 贡献
110 |
111 | 如果您对该项目感兴趣,欢迎贡献您的代码和改进建议。您可以通过以下方式参与:
112 |
113 | 1. 提交问题和建议到本项目的 Issue 页面。
114 | 2. Fork 本项目并提交您的改进建议,我们将会审查并合并合适的改动。
115 |
--------------------------------------------------------------------------------
/README_en.md:
--------------------------------------------------------------------------------
1 | [[Chinese Version](https://github.com/JuneYaooo/medical_kb_chatbot/blob/main/README.md)] [[English Version](https://github.com/JuneYaooo/medical_kb_chatbot/blob/main/README_en.md)]
2 |
3 | # Medical Knowledge Chatbot
4 |
5 | Welcome to the Medical Knowledge Chatbot. This is a chatbot based on the PULSE model, incorporating knowledge base and fine-tuning training, aiming to provide more practical medical-related functions and services. Users can add relevant knowledge bases and perform model fine-tuning to experience richer application scenarios.
6 |
7 | ## Example Applications of the Medical Chatbot
8 |
9 | - **Drug Query**: Provides a drug database where users can search for specific drug information such as uses, dosage, side effects, etc.
10 |
11 | - **Symptom Explanation**: Provides explanations and definitions of common diseases, symptoms, and medical terminologies to help users better understand medical knowledge.
12 |
13 | - **Medical Customer Service**: Adds relevant medical product documentation, supports personalized conversations with the chatbot, answers questions related to medical products, and provides accurate and reliable information.
14 |
15 | ## Usage
16 |
17 | ### Download Model and Modify Configuration File
18 |
19 | If there are any issues with direct usage, you can download the PULSE model to your local machine from: [https://huggingface.co/OpenMEDLab/PULSE-7bv5](https://huggingface.co/OpenMEDLab/PULSE-7bv5)
20 |
21 | Then, modify the model path in the `configs/common_config.py` file to the local path. You can modify the paths in the `embedding_model_dict` and `llm_model_dict` variables.
22 |
23 | ### Installation
24 |
25 | First, clone this project to your local machine:
26 |
27 | ```
28 | git clone https://github.com/JuneYaooo/medical_kb_chatbot.git
29 | ```
30 |
31 | #### Install using pip
32 |
33 | Make sure the following dependencies are installed on your machine:
34 |
35 | - Python 3.9
36 | - pip package manager
37 |
38 | Navigate to the project directory and install the necessary dependencies:
39 |
40 | ```
41 | cd medical_kb_chatbot
42 | pip install -r requirements.txt
43 | ```
44 |
45 | #### Install using conda
46 |
47 | Make sure the following dependencies are installed on your machine:
48 |
49 | - Anaconda or Miniconda
50 |
51 | Navigate to the project directory and create a new conda environment:
52 |
53 | ```
54 | cd medical_kb_chatbot
55 | conda env create -f environment.yml
56 | ```
57 |
58 | Activate the newly created environment:
59 |
60 | ```
61 | conda activate kb_chat
62 | ```
63 |
64 | Then run the chatbot:
65 |
66 | ```
67 | python app.py
68 | ```
69 |
70 | ### Instructions for Use
71 |
72 | #### Configure Knowledge Bases on the Knowledge Base Page
73 |
74 | - Supports formats such as Excel, JSON, non-image PDFs, Word documents, TXT files, etc.
75 | - For Excel and JSON formats, they need to be uploaded in the required format.
76 | - It is encouraged to mount some medical knowledge bases to try out the effectiveness. Good examples are welcome to be shared.
77 |
78 | 
79 |
80 | #### Fine-Tune the Model using Lora
81 |
82 | - Fine-tuning currently requires a minimum of 24GB GPU (~one 3090).
83 | - After fine-tuning, you can see the update time.
84 |
85 | 
86 |
87 | #### Configure Your Knowledge Base Chatbot on the Medical Chatbot Page (Optional to use a specific knowledge base/fine-tuned Lora)
88 |
89 | - You can configure your own knowledge base chatbot by selecting whether to use a particular knowledge base or the fine-tuned Lora.
90 | - Refer to the template for configuring prompts and try out different options. Good prompts are welcome to be shared.
91 |
92 | 
93 |
94 | #### Once the chatbot is configured, try it out on the conversation test page
95 |
96 | - Select a pre-configured chatbot and give it a try.
97 |
98 | 
99 |
100 | ## Acknowledgments
101 |
102 | - [PULSE](https://github.com/openmedlab/PULSE): The model used in this project is based on PULSE.
103 | - [langchain-ChatGLM](https://github.com/imClumsyPanda/langchain-ChatGLM): The code for the knowledge base part of this project was inspired by langchain-ChatGLM.
104 | - [BELLE](https://github.com/LianjiaTech/BELLE): The code for the Lora fine-tuning part of this project was inspired by BELLE.
105 |
106 | ## Contribution
107 |
108 | If you are interested in this project, you are welcome to contribute your code and improvement suggestions. You can participate in the following ways:
109 |
110 | 1. Submit issues and suggestions on the Issue page of this project.
111 | 2. Fork this project, make your improvements, and submit a pull request. We will review and merge appropriate changes.
--------------------------------------------------------------------------------
/agent/bing_search.py:
--------------------------------------------------------------------------------
1 | #coding=utf8
2 |
3 | from langchain.utilities import BingSearchAPIWrapper
4 | from configs.common_config import BING_SEARCH_URL, BING_SUBSCRIPTION_KEY
5 |
6 |
7 | def bing_search(text, result_len=3):
8 | if not (BING_SEARCH_URL and BING_SUBSCRIPTION_KEY):
9 | return [{"snippet": "please set BING_SUBSCRIPTION_KEY and BING_SEARCH_URL in os ENV",
10 | "title": "env inof not fould",
11 | "link": "https://python.langchain.com/en/latest/modules/agents/tools/examples/bing_search.html"}]
12 | search = BingSearchAPIWrapper(bing_subscription_key=BING_SUBSCRIPTION_KEY,
13 | bing_search_url=BING_SEARCH_URL)
14 | return search.results(text, result_len)
15 |
16 |
17 | if __name__ == "__main__":
18 | r = bing_search('python')
19 | print(r)
20 |
--------------------------------------------------------------------------------
/app.py:
--------------------------------------------------------------------------------
1 | import gradio as gr
2 | import shutil
3 |
4 | from chains.local_doc_qa import LocalDocQA
5 | from configs.common_config import *
6 | import nltk
7 | from loader.models.base import (BaseAnswer,
8 | AnswerResult,
9 | AnswerResultStream,
10 | AnswerResultQueueSentinelTokenListenerQueue)
11 | import loader.models.shared as shared
12 | from loader.models.loader.args import parser
13 | from loader.models.loader import LoaderCheckPoint
14 | from finetune.pulse_utils import pulse_train_model, stop_train_process
15 | import shutil
16 | import time
17 | import datetime
18 | import re
19 | import os
20 | import glob
21 |
22 | def get_file_modify_time(filename):
23 | try:
24 | return datetime.datetime.fromtimestamp(os.stat(filename).st_mtime).strftime("%Y-%m-%d %H:%M:%S")
25 | except Exception as e:
26 | print('Failed to get modification time for {}'.format(filename))
27 | print(e)
28 | return 'not available'
29 |
30 | def get_model_update_time(model_name, lora_name):
31 | if 'pulse' in model_name.lower():
32 | update_time = get_file_modify_time(f"finetune/pulse/output/{lora_name}/adapter_model.bin")
33 | else:
34 | update_time = 'not available'
35 | return update_time
36 |
37 | def on_train(model_name, lora_name, training_data_file):
38 | training_data_path = 'data/'+os.path.basename(training_data_file.name)
39 | if 'pulse' in model_name.lower():
40 | msg = pulse_train_model(model_name, lora_name, training_data_path)
41 | else:
42 | msg = 'please select one model!'
43 | return msg
44 | nltk.data.path = [NLTK_DATA_PATH] + nltk.data.path
45 |
46 |
47 | def upload_file(file):
48 | print('file',file)
49 | if not os.path.exists("data"):
50 | os.mkdir("data")
51 | filename = os.path.basename(file.name)
52 | shutil.move(file.name, "data/" + filename)
53 | # file_list首位插入新上传的文件
54 | filedir = "data/" + filename
55 | return filedir
56 |
57 |
58 | def get_vs_list():
59 | lst_default = []
60 | if not os.path.exists(VS_ROOT_PATH):
61 | return lst_default
62 | lst = os.listdir(VS_ROOT_PATH)
63 | if not lst:
64 | return lst_default
65 | lst.sort()
66 | return lst_default + lst
67 |
68 |
69 |
70 | def get_yaml_files(folder_path):
71 | yaml_files = glob.glob(os.path.join(folder_path, '*.yaml'))
72 | file_names = [os.path.splitext(os.path.basename(file))[0] for file in yaml_files]
73 | return file_names
74 |
75 | yaml_files = get_yaml_files('configs')
76 |
77 |
78 | vs_list = get_vs_list()
79 |
80 |
81 | embedding_model_dict_list = list(embedding_model_dict.keys())
82 |
83 | llm_model_dict_list = list(llm_model_dict.keys())
84 |
85 |
86 | local_doc_qa = LocalDocQA()
87 |
88 | flag_csv_logger = gr.CSVLogger()
89 |
90 |
91 | def read_config(ass_name_en):
92 | config_file_path = f"configs/{ass_name_en}.yaml"
93 | with open(config_file_path, 'r', encoding='utf-8') as file:
94 | yaml = ruamel.yaml.YAML()
95 | config = yaml.load(file)
96 | ass_name = config['ass_name']
97 | llm_model = config['llm_model']
98 | embedding_model = config['embedding_model']
99 | llm_history_len = config['llm_history_len']
100 | lora_name = config['lora_name']
101 | top_k = config['top_k']
102 | score_threshold = config['score_threshold']
103 | chunk_content = config['chunk_content']
104 | chunk_sizes = config['chunk_sizes']
105 | show_reference = config['show_reference']
106 | knowledge_set_name = config['knowledge_set_name']
107 | prompt_template = config['prompt_template']
108 |
109 | return ass_name_en,ass_name, llm_model, embedding_model, lora_name, llm_history_len, knowledge_set_name, top_k, score_threshold, chunk_content, chunk_sizes, show_reference,prompt_template
110 |
111 | def remove_html_tags(text):
112 | clean_text = re.sub('<.*?>', '', text)
113 | return clean_text
114 |
115 | from pynvml import (nvmlInit, nvmlDeviceGetCount, nvmlDeviceGetHandleByIndex,
116 | nvmlDeviceGetName, nvmlDeviceGetMemoryInfo, nvmlShutdown)
117 | def get_available_gpu(threshold=20000):
118 | # Initialize NVML
119 | nvmlInit()
120 | # Get the number of GPU devices
121 | device_count = nvmlDeviceGetCount()
122 |
123 | # Find GPU devices with available memory greater than the threshold
124 | available_gpus = []
125 | for i in range(device_count):
126 | handle = nvmlDeviceGetHandleByIndex(i)
127 | info = nvmlDeviceGetMemoryInfo(handle)
128 | free_memory_mb = info.free / 1024 / 1024
129 |
130 | if free_memory_mb > threshold:
131 | available_gpus.append(i)
132 |
133 | # Shutdown NVML
134 | nvmlShutdown()
135 | # available_gpus = ['0']
136 |
137 | return available_gpus
138 |
139 |
140 | def get_free_memory():
141 | nvmlInit()
142 | # Get the number of GPU devices
143 | device_count = nvmlDeviceGetCount()
144 |
145 | # Find GPU devices with available memory greater than the threshold
146 | free_memory_gpus = []
147 | for i in range(device_count):
148 | handle = nvmlDeviceGetHandleByIndex(i)
149 | info = nvmlDeviceGetMemoryInfo(handle)
150 | free_memory_mb = info.free / 1024 / 1024
151 | free_memory_gpus.append(free_memory_mb)
152 |
153 | # Shutdown NVML
154 | nvmlShutdown()
155 | return free_memory_gpus
156 |
157 | def get_chat_answer(query, ass_id, history, score_threshold=VECTOR_SEARCH_SCORE_THRESHOLD,
158 | vector_search_top_k=VECTOR_SEARCH_TOP_K, chunk_content: bool = True,
159 | chunk_size=CHUNK_SIZE, streaming: bool = STREAMING):
160 | ass_name_en,ass_name, llm_model, embedding_model, lora_name, llm_history_len, knowledge_set_name, top_k, score_threshold, chunk_content, chunk_sizes,show_reference,prompt_template = read_config(ass_id)
161 | if len(history)>0 and len(history[-1])>0:
162 | history[-1][-1] = remove_html_tags(history[-1][-1])
163 | history = history[-llm_history_len:] if history is not None and len(history) > llm_history_len else history
164 | local_doc_qa.top_k = top_k
165 | local_doc_qa.score_threshold = score_threshold
166 | local_doc_qa.chunk_content = chunk_content
167 | local_doc_qa.chunk_size = chunk_size
168 | available_gpus = get_available_gpu(threshold=20000)
169 | if local_doc_qa.llm is None:
170 | if len(available_gpus)>0:
171 | available_gpu = available_gpus[0]
172 | target_device = torch.device(f'cuda:{str(available_gpu)}')
173 | else:
174 | yield [[None,'GPU空间不够,请至少确保机器上GPU剩余空间>20G']],'',''
175 | args_dict = {'model':llm_model, 'lora':lora_name} if lora_name != '不使用' else {'model':llm_model}
176 | shared.loaderCheckPoint = LoaderCheckPoint(args_dict)
177 | llm_model_ins = shared.loaderLLM()
178 | llm_model_ins.set_history_len(llm_history_len)
179 | local_doc_qa.init_cfg(llm_model=llm_model_ins,embedding_model=embedding_model,
180 | embedding_device=target_device)
181 | if local_doc_qa.embeddings is None:
182 | if len(available_gpus)>0:
183 | available_gpu = available_gpus[0]
184 | target_device = torch.device(f'cuda:{str(available_gpu)}')
185 | else:
186 | yield [[None,'GPU空间不够,请至少确保机器上GPU剩余空间>20G']],'',''
187 | local_doc_qa.init_embedding(embedding_model=embedding_model,embedding_device=target_device)
188 | if knowledge_set_name !='不使用知识库':
189 | vs_path = os.path.join(VS_ROOT_PATH, knowledge_set_name)
190 | for resp, history in local_doc_qa.get_knowledge_based_answer(model_name=llm_model,
191 | query=query, vs_path=vs_path, prompt_template=prompt_template,chat_history=history, streaming=streaming):
192 | if len(resp["source_documents"])>0:
193 | source = ""
194 | source += "".join(
195 | [f"""出处 [{i + 1}] {os.path.split(doc.metadata["source"])[-1]}\n"""
196 | f"""{doc.page_content}\n"""
197 | for i, doc in
198 | enumerate(resp["source_documents"])])
199 | else:
200 | source = "暂无"
201 |
202 | yield history, "", source
203 | else:
204 | print('纯聊天模式')
205 | history = history
206 | for answer_result, history in local_doc_qa.get_base_answer(model_name=llm_model,
207 | query=query, prompt_template=prompt_template,chat_history=history, streaming=streaming):
208 | yield history, "", "未挂载知识库"
209 |
210 | def get_knowledge_search(query, vs_path, history, score_threshold=VECTOR_SEARCH_SCORE_THRESHOLD,
211 | vector_search_top_k=VECTOR_SEARCH_TOP_K, chunk_content: bool = True,
212 | chunk_size=CHUNK_SIZE, streaming: bool = STREAMING):
213 | if local_doc_qa.embeddings is None:
214 | local_doc_qa.init_embedding()
215 | if os.path.exists(vs_path):
216 | resp, prompt = local_doc_qa.get_knowledge_based_conent_test(query=query, vs_path=vs_path,
217 | score_threshold=score_threshold,
218 | vector_search_top_k=vector_search_top_k,
219 | chunk_content=chunk_content,
220 | chunk_size=chunk_size)
221 | if not resp["source_documents"]:
222 | yield history + [[query,
223 | "根据您的设定,没有匹配到任何内容,请确认您设置的知识相关度 Score 阈值是否过小或其他参数是否正确。"]], ""
224 | else:
225 | source = "\n".join(
226 | [
227 | f""" 【知识相关度 Score】:{doc.metadata["score"]} - 【出处{i + 1}】: {os.path.split(doc.metadata["source"])[-1]}
\n"""
228 | f"""{doc.page_content}\n"""
229 | f""" """
230 | for i, doc in
231 | enumerate(resp["source_documents"])])
232 | history.append([query, "以下内容为知识库中满足设置条件的匹配结果:\n\n" + source])
233 | yield history, ""
234 | else:
235 | yield history + [[query,
236 | "请选择知识库后进行测试,当前未选择知识库。"]], ""
237 |
238 |
239 | def change_assistant_input(ass_id):
240 |
241 | ass_name_en,ass_name, llm_model, embedding_model, lora_name, llm_history_len, knowledge_set_name, top_k, score_threshold, chunk_content, chunk_sizes,show_reference,prompt_template = read_config(ass_id)
242 |
243 | init_hello = f"你好,我是{ass_name}"
244 |
245 | if show_reference:
246 | return [[None,init_hello]], gr.update(visible=True)
247 | else:
248 | return [[None,init_hello]], gr.update(visible=False)
249 |
250 |
251 |
252 | import ruamel.yaml
253 | def set_config(ass_name_en,ass_name, llm_model, embedding_model, lora_name, llm_history_len, knowledge_set_name, top_k, score_threshold, chunk_content, chunk_sizes,show_reference, prompt_template, ass_list):
254 | config = {
255 | 'ass_name':ass_name,
256 | 'llm_model': llm_model,
257 | 'embedding_model': embedding_model,
258 | 'lora_name': lora_name,
259 | 'llm_history_len': llm_history_len,
260 | 'knowledge_set_name':knowledge_set_name,
261 | 'top_k': top_k,
262 | 'score_threshold':score_threshold,
263 | 'chunk_content':chunk_content,
264 | 'chunk_sizes':chunk_sizes,
265 | 'show_reference':show_reference,
266 | 'prompt_template':prompt_template
267 | }
268 | yaml = ruamel.yaml.YAML()
269 | with open(f'configs/{ass_name_en}.yaml', 'w', encoding="utf-8") as file:
270 | yaml.dump(config, file)
271 |
272 | return gr.update(visible=True),f'configs/{ass_name_en}.yaml 保存成功!',gr.update(visible=True, choices=ass_list, value=ass_list[0])
273 |
274 |
275 | def get_vector_store(vs_id, files, sentence_size, history, one_conent, one_content_segmentation):
276 | if local_doc_qa.embeddings is None:
277 | local_doc_qa.init_embedding()
278 |
279 | vs_path = os.path.join(VS_ROOT_PATH, vs_id)
280 | filelist = []
281 | if not os.path.exists(os.path.join(UPLOAD_ROOT_PATH, vs_id)):
282 | os.makedirs(os.path.join(UPLOAD_ROOT_PATH, vs_id))
283 | if isinstance(files, list):
284 | for file in files:
285 | filename = os.path.split(file.name)[-1]
286 | shutil.move(file.name, os.path.join(UPLOAD_ROOT_PATH, vs_id, filename))
287 | filelist.append(os.path.join(UPLOAD_ROOT_PATH, vs_id, filename))
288 | vs_path, loaded_files = local_doc_qa.init_knowledge_vector_store(filelist, vs_path, sentence_size)
289 | else:
290 | vs_path, loaded_files = local_doc_qa.one_knowledge_add(vs_path, files, one_conent, one_content_segmentation,
291 | sentence_size)
292 | if len(loaded_files):
293 | file_status = f"已添加 {'、'.join([os.path.split(i)[-1] for i in loaded_files if i])} 内容至知识库,并已加载知识库,请开始提问"
294 | else:
295 | file_status = "文件未成功加载,请重新上传文件"
296 | # if local_doc_qa.llm and local_doc_qa.embeddings:
297 |
298 | # else:
299 | # file_status = "模型未完成加载,请先在加载模型后再导入文件"
300 | # vs_path = None
301 | logger.info(file_status)
302 | return vs_path, None, history + [[None, file_status]]
303 |
304 |
305 | def change_vs_name_input(vs_id, history):
306 | if vs_id == "新建知识库":
307 | return gr.update(visible=True), gr.update(visible=True), gr.update(visible=False), None, history
308 | else:
309 | file_status = f"已加载知识库{vs_id},请开始提问"
310 | return gr.update(visible=False), gr.update(visible=False), gr.update(visible=True), os.path.join(VS_ROOT_PATH,
311 | vs_id), history + [
312 | [None, file_status]]
313 |
314 |
315 | knowledge_base_test_mode_info = ("【注意】\n\n"
316 | "1.您已进入知识库测试模式,仅用于测试知识库相关参数配置\n\n"
317 | "2.知识相关度 Score,建议设置为 500~800,具体设置情况请结合实际使用调整。\n\n"
318 | "3.目前支持的处理格式包含非图片格式的pdf、word、txt、md、excel、json、jsonl格式\n\n"
319 | "其中excel格式可包含多个数据表,每个数据表里必须包含两列:问题|回答\n\n"
320 | """其中json、jsonl格式处理成多行,每行一个dict类似这样:{"docs": ["氨力农注射液 药物分类\n化学药品", "氨力农注射液 药物剂量\n10毫升:50毫克"]}""")
321 |
322 |
323 | def change_mode(mode, history):
324 | if mode == "知识库问答":
325 | return gr.update(visible=True), gr.update(visible=False), history
326 | # + [[None, "【注意】:您已进入知识库问答模式,您输入的任何查询都将进行知识库查询,然后会自动整理知识库关联内容进入模型查询!!!"]]
327 | elif mode == "知识库配置":
328 | return gr.update(visible=True), gr.update(visible=True), [[None,
329 | knowledge_base_test_mode_info]]
330 | else:
331 | return gr.update(visible=False), gr.update(visible=False), history
332 |
333 |
334 | def change_chunk_content(mode, label_conent, history):
335 | conent = ""
336 | if "chunk_content" in label_conent:
337 | conent = "搜索结果上下文关联"
338 | elif "one_content_segmentation" in label_conent: # 这里没用上,可以先留着
339 | conent = "内容分段入库"
340 |
341 | if mode:
342 | return gr.update(visible=True), history + [[None, f"【已开启{conent}】"]]
343 | else:
344 | return gr.update(visible=False), history + [[None, f"【已关闭{conent}】"]]
345 |
346 |
347 | def add_vs_name(vs_name, vs_list, chatbot):
348 | if not os.path.exists(VS_ROOT_PATH):
349 | os.makedirs(VS_ROOT_PATH)
350 | if vs_name in vs_list:
351 | vs_status = "与已有知识库名称冲突,请重新选择其他名称后提交"
352 | chatbot = chatbot + [[None, vs_status]]
353 | return gr.update(visible=True), vs_list, gr.update(visible=True), gr.update(visible=True), gr.update(
354 | visible=False), chatbot
355 | else:
356 | vs_status = f"""已新增知识库"{vs_name}",将在上传文件并载入成功后进行存储。请在开始对话前,先完成文件上传。 """
357 | chatbot = chatbot + [[None, vs_status]]
358 | return gr.update(visible=True, choices=[vs_name] + vs_list, value=vs_name), [vs_name] + vs_list, gr.update(
359 | visible=False), gr.update(visible=False), gr.update(visible=True), chatbot
360 |
361 | def change_lora_name_input(model_name,lora_name_en):
362 | if lora_name_en == "新建Lora":
363 | return gr.update(visible=True), gr.update(visible=True)
364 | else:
365 | file_status = f"已加载{lora_name_en}"
366 | model_update_time = get_model_update_time(model_name, lora_name_en)
367 | return gr.update(visible=False), gr.update(visible=False), model_update_time
368 |
369 |
370 | def add_lora(lora_name_en,lora_list):
371 | if lora_name_en in lora_list:
372 | print('名称冲突,不新建')
373 | return gr.update(visible=True,value=lora_name_en), gr.update(visible=False), gr.update(visible=False), lora_list
374 | else:
375 | return gr.update(visible=True, choices=[lora_name_en] + lora_list, value=lora_name_en), gr.update(visible=False), gr.update(visible=False),[lora_name_en] + lora_list
376 |
377 | def change_assistant_name_input(ass_id):
378 | if ass_id == "新建小助手":
379 | return gr.update(visible=True), gr.update(visible=True), gr.update(visible=False), gr.update(visible=False), '医疗小助手', LLM_MODEL,EMBEDDING_MODEL,'不使用',LLM_HISTORY_LEN,cur_vs_list.value[0] if len(cur_vs_list.value) > 1 else '不使用知识库',VECTOR_SEARCH_TOP_K,500,True,250,True,''
380 | else:
381 | try:
382 | ass_name_en,ass_name, llm_model, embedding_model, lora_name, llm_history_len, knowledge_set_name, top_k, score_threshold, chunk_content, chunk_sizes,show_reference,prompt_template = read_config(ass_id)
383 | file_status = f"已加载{ass_id}"
384 | print('file_status',file_status)
385 | return gr.update(visible=False), gr.update(visible=False), gr.update(visible=True), gr.update(visible=False), ass_name, llm_model, embedding_model, lora_name, llm_history_len, knowledge_set_name, top_k, score_threshold, chunk_content, chunk_sizes,show_reference,prompt_template
386 | except Exception as e:
387 | return gr.update(visible=False), gr.update(visible=False), gr.update(visible=True), gr.update(visible=False), '医疗小助手', LLM_MODEL,EMBEDDING_MODEL,'不使用',LLM_HISTORY_LEN,cur_vs_list.value[0] if len(cur_vs_list.value) > 1 else '不使用知识库',VECTOR_SEARCH_TOP_K,500,True,250,True,''
388 |
389 |
390 | def add_ass_config(ass_id,ass_list):
391 | if ass_id in ass_list:
392 | print('名称冲突,不新建')
393 | return ass_id, gr.update(visible=False), gr.update(visible=False), gr.update(visible=True), gr.update(visible=False), '医疗小助手', LLM_MODEL,EMBEDDING_MODEL,'不使用',LLM_HISTORY_LEN,cur_vs_list.value[0] if len(cur_vs_list.value) > 1 else '不使用知识库',VECTOR_SEARCH_TOP_K,500,True,250,True,'',ass_list,gr.update(visible=True)
394 | else:
395 | return ass_id, gr.update(visible=False), gr.update(visible=False), gr.update(visible=True), gr.update(visible=False), '医疗小助手', LLM_MODEL,EMBEDDING_MODEL,'不使用',LLM_HISTORY_LEN,cur_vs_list.value[0] if len(cur_vs_list.value) > 1 else '不使用知识库',VECTOR_SEARCH_TOP_K,500,True,250,True,'',ass_list+[ass_id],gr.update(visible=True, choices=ass_list+[ass_id], value=ass_id)
396 |
397 | def find_folders(directory):
398 | folders = []
399 | for item in os.listdir(directory):
400 | item_path = os.path.join(directory, item)
401 | if os.path.isdir(item_path):
402 | folders.append(item)
403 | return folders
404 |
405 | def change_model_name_input(model_name):
406 | if 'pulse' in model_name.lower():
407 | model_name = 'pulse'
408 | else:
409 | model_name = ''
410 | model_dir = os.path.join(f"finetune", model_name,'output')
411 | lora_list = find_folders(model_dir)
412 | return lora_list,gr.update(visible=True, choices=lora_list+["新建Lora"], value=lora_list[0] if len(lora_list)>0 else "新建Lora")
413 |
414 | def change_model_name_select(model_name):
415 | if 'pulse' in model_name.lower():
416 | model_name = 'pulse'
417 | else:
418 | model_name = ''
419 | model_dir = os.path.join(f"finetune", model_name,'output')
420 | lora_list = find_folders(model_dir)
421 | return lora_list,gr.update(visible=True, choices=lora_list+["不使用"], value=lora_list[0] if len(lora_list)>0 else "不使用")
422 |
423 | block_css = """.importantButton {
424 | background: linear-gradient(45deg, #7e0570,#5d1c99, #6e00ff) !important;
425 | border: none !important;
426 | }
427 | .importantButton:hover {
428 | background: linear-gradient(45deg, #ff00e0,#8500ff, #6e00ff) !important;
429 | border: none !important;
430 | }"""
431 |
432 | webui_title = """
433 | # 💁医疗知识聊天机器人💁
434 | """
435 | default_vs = vs_list[0] if len(vs_list) > 1 else "为空"
436 | init_message = f"""请先在右侧选择小助手,再开始对话测试
437 | """
438 |
439 | # 初始化消息
440 | args = None
441 | args = parser.parse_args()
442 |
443 |
444 | model_status = '请手动加载模型'
445 |
446 | default_theme_args = dict(
447 | font=["Source Sans Pro", 'ui-sans-serif', 'system-ui', 'sans-serif'],
448 | font_mono=['IBM Plex Mono', 'ui-monospace', 'Consolas', 'monospace'],
449 | )
450 |
451 | def get_lora_init_list(model_name):
452 | if 'pulse' in model_name.lower():
453 | model_name = 'pulse'
454 | else:
455 | model_name = ''
456 | model_dir = os.path.join(f"finetune", model_name,'output')
457 | if not os.path.exists(model_dir):
458 | os.makedirs(model_dir)
459 | lora_list = find_folders(model_dir)
460 | return lora_list
461 |
462 | lora_init_list = get_lora_init_list(llm_model_dict_list[0])
463 |
464 | with gr.Blocks(css=block_css, theme=gr.themes.Default(**default_theme_args)) as demo:
465 | vs_path, file_status, model_status, set_vs_list , cur_vs_list, set_lora_list, set_ass_list = gr.State(
466 | os.path.join(VS_ROOT_PATH, vs_list[0]) if len(vs_list) > 1 else ""), gr.State(""), gr.State(
467 | model_status), gr.State(vs_list+['新建知识库']), gr.State(vs_list+['不使用知识库']), gr.State(lora_init_list), gr.State(yaml_files)
468 |
469 | gr.Markdown(webui_title)
470 | with gr.Tab("对话测试"):
471 | with gr.Row():
472 | with gr.Column(scale=10):
473 | chatbot = gr.Chatbot([[None, init_message]],
474 | elem_id="chat-box",
475 | show_label=False).style(height=750)
476 | query = gr.Textbox(show_label=False,
477 | placeholder="请输入提问内容,按回车进行提交").style(container=False)
478 | with gr.Column(scale=5):
479 | choose_ass = gr.Dropdown(yaml_files,
480 | label="选择要使用的小助手",
481 | value= yaml_files if len(yaml_files) > 1 else '暂无可使用的',
482 | interactive=True)
483 | reference = gr.Textbox(type="text", label='参考资料',visible=True)
484 | choose_ass.change(fn=change_assistant_input,
485 | inputs=[choose_ass],
486 | outputs=[chatbot,reference])
487 | query.submit(get_chat_answer,
488 | [query, choose_ass, chatbot],
489 | [chatbot, query, reference])
490 |
491 | with gr.Tab("知识库测试"):
492 | with gr.Row():
493 | with gr.Column(scale=10):
494 | chatbot = gr.Chatbot([[None, knowledge_base_test_mode_info]],
495 | elem_id="chat-box",
496 | show_label=False).style(height=750)
497 | query = gr.Textbox(show_label=False,
498 | placeholder="请输入提问内容,按回车进行提交").style(container=False)
499 | with gr.Column(scale=5):
500 | knowledge_set = gr.Accordion("知识库设定", visible=True)
501 | vs_setting = gr.Accordion("配置知识库", visible=True)
502 | with knowledge_set:
503 | score_threshold = gr.Number(value=VECTOR_SEARCH_SCORE_THRESHOLD,
504 | label="知识相关度 Score 阈值,分值越低匹配度越高",
505 | precision=0,
506 | interactive=True)
507 | vector_search_top_k = gr.Number(value=VECTOR_SEARCH_TOP_K, precision=0,
508 | label="获取知识库内容条数", interactive=True)
509 | chunk_content = gr.Checkbox(value=False,
510 | label="是否启用上下文关联",
511 | interactive=True)
512 | chunk_sizes = gr.Number(value=CHUNK_SIZE, precision=0,
513 | label="匹配单段内容的连接上下文后最大长度",
514 | interactive=True, visible=True)
515 | chunk_content.change(fn=change_chunk_content,
516 | inputs=[chunk_content, gr.Textbox(value="chunk_content", visible=False), chatbot],
517 | outputs=[chunk_sizes, chatbot])
518 | with vs_setting:
519 | select_vs = gr.Dropdown(set_vs_list.value,
520 | label="请选择要加载的知识库",
521 | interactive=True,
522 | value=set_vs_list.value[0] if len(set_vs_list.value) > 0 else None)
523 | vs_name = gr.Textbox(label="请输入新建知识库名称,当前知识库命名暂不支持中文",
524 | lines=1,
525 | interactive=True,
526 | visible=True)
527 | vs_add = gr.Button(value="添加至知识库选项", visible=True)
528 | file2vs = gr.Column(visible=False)
529 | with file2vs:
530 | # load_vs = gr.Button("加载知识库")
531 | gr.Markdown("向知识库中添加单条内容或文件")
532 | sentence_size = gr.Number(value=SENTENCE_SIZE, precision=0,
533 | label="文本入库分句长度限制",
534 | interactive=True, visible=True)
535 | with gr.Tab("上传文件"):
536 | files = gr.File(label="添加文件",
537 | file_types=['.txt', '.md', '.docx', '.pdf', '.jsonl'],
538 | file_count="multiple",
539 | show_label=False
540 | )
541 | load_file_button = gr.Button("上传文件并加载知识库")
542 | with gr.Tab("上传文件夹"):
543 | folder_files = gr.File(label="添加文件",
544 | # file_types=['.txt', '.md', '.docx', '.pdf'],
545 | file_count="directory",
546 | show_label=False)
547 | load_folder_button = gr.Button("上传文件夹并加载知识库")
548 | with gr.Tab("添加单条内容"):
549 | one_title = gr.Textbox(label="标题", placeholder="请输入要添加单条段落的标题", lines=1)
550 | one_conent = gr.Textbox(label="内容", placeholder="请输入要添加单条段落的内容", lines=5)
551 | one_content_segmentation = gr.Checkbox(value=True, label="禁止内容分句入库",
552 | interactive=True)
553 | load_conent_button = gr.Button("添加内容并加载知识库")
554 | # 将上传的文件保存到content文件夹下,并更新下拉框
555 | vs_add.click(fn=add_vs_name,
556 | inputs=[vs_name, set_vs_list, chatbot],
557 | outputs=[select_vs, set_vs_list, vs_name, vs_add, file2vs, chatbot])
558 | select_vs.change(fn=change_vs_name_input,
559 | inputs=[select_vs, chatbot],
560 | outputs=[vs_name, vs_add, file2vs, vs_path, chatbot])
561 | load_file_button.click(get_vector_store,
562 | show_progress=True,
563 | inputs=[select_vs, files, sentence_size, chatbot, vs_add, vs_add],
564 | outputs=[vs_path, files, chatbot], )
565 | load_folder_button.click(get_vector_store,
566 | show_progress=True,
567 | inputs=[select_vs, folder_files, sentence_size, chatbot, vs_add,
568 | vs_add],
569 | outputs=[vs_path, folder_files, chatbot], )
570 | load_conent_button.click(get_vector_store,
571 | show_progress=True,
572 | inputs=[select_vs, one_title, sentence_size, chatbot,
573 | one_conent, one_content_segmentation],
574 | outputs=[vs_path, files, chatbot], )
575 | # flag_csv_logger.setup([query, vs_path, chatbot, mode], "flagged")
576 | query.submit(get_knowledge_search,
577 | [query, vs_path, chatbot, score_threshold, vector_search_top_k, chunk_content,
578 | chunk_sizes],
579 | [chatbot, query])
580 | with gr.Tab("lora微调"):
581 | with gr.Row():
582 | with gr.Column():
583 | model_name = gr.Radio(llm_model_dict_list, #'Bert',
584 | label="选择模型",
585 | value= llm_model_dict_list[0] if len(llm_model_dict_list)>0 else '暂无可选模型',
586 | interactive=True)
587 | with gr.Column():
588 | select_lora = gr.Dropdown(set_lora_list.value+['新建Lora'],
589 | label= "选择或者新建一个Lora",
590 | value= set_lora_list.value[0] if len(set_lora_list.value) > 0 else '新建Lora',
591 | interactive=True)
592 | lora_name_en = gr.Textbox(label="请输入Lora英文名称,中间不能有空格,小写字母,单词间可用下划线分开",
593 | lines=1,
594 | interactive=True,
595 | visible=True)
596 | lora_add = gr.Button(value="确认添加Lora", visible=True)
597 | with gr.Row():
598 | lastest_model = gr.outputs.Textbox(type="text", label='模型更新时间(请切换模型或Lora刷新显示)')
599 | gr.Markdown("## lora微调,目前只支持excel格式,要求语料格式为问题|回答两列,或者系统指示|问题|回答三列")
600 | train_data_file = gr.File(label="上传对话语料文件", file_types=['.xlsx'])
601 | train_button = gr.Button("开始训练", label="训练")
602 | kill_train_button = gr.Button("停止所有训练进程", label="训练")
603 | train_res = gr.outputs.Textbox(type="text", label='')
604 | train_data_file.upload(upload_file,
605 | inputs=train_data_file)
606 | train_button.click(on_train, inputs=[model_name, select_lora, train_data_file],outputs=[train_res])
607 | model_name.change(fn=change_model_name_input,
608 | inputs=[model_name],
609 | outputs=[set_lora_list,select_lora])
610 | select_lora.change(fn=change_lora_name_input,
611 | inputs=[model_name,select_lora],
612 | outputs=[lora_name_en, lora_add,lastest_model])
613 | lora_add.click(fn=add_lora,
614 | inputs=[lora_name_en,set_lora_list],
615 | outputs=[select_lora, lora_name_en, lora_add,set_lora_list])
616 |
617 | with gr.Tab("医疗小助手配置"):
618 | with gr.Column():
619 | select_ass = gr.Dropdown(set_ass_list.value+['新建小助手'],
620 | label="选择或者新建一个医疗小助手",
621 | value= set_ass_list.value[0] if len(set_ass_list.value) > 0 else '新建小助手',
622 | interactive=True)
623 | ass_name_en = gr.Textbox(label="请输入小助手英文名称,中间不能有空格,小写字母,单词间可用下划线分开",
624 | lines=1,
625 | interactive=True,
626 | visible=True)
627 | ass_add = gr.Button(value="确认添加小助手", visible=True)
628 | ass_config = gr.Column(visible=False)
629 | with ass_config:
630 | ass_name = gr.Textbox(label="请给机器人取个名字,随便取,中英文均可,可以有空格",
631 | value='医疗小助手',
632 | lines=1,
633 | interactive=True,
634 | visible=True)
635 | llm_model = gr.Radio(llm_model_dict_list,
636 | label="LLM 模型",
637 | value=llm_model_dict_list[0] if len(llm_model_dict_list)>0 else '暂无可选模型',
638 | interactive=True)
639 | embedding_model = gr.Radio(embedding_model_dict_list,
640 | label="Embedding 模型",
641 | value=EMBEDDING_MODEL,
642 | interactive=True)
643 | lora_name = gr.Dropdown(set_lora_list.value+['不使用'],
644 | value='不使用',
645 | label="选择使用的Lora",
646 | interactive=True)
647 | llm_history_len = gr.Slider(0, 10,
648 | value=LLM_HISTORY_LEN,
649 | step=1,
650 | label="LLM 对话轮数",
651 | interactive=True)
652 | knowledge_set_name = gr.Dropdown(cur_vs_list.value,
653 | label="选择知识库",
654 | value= cur_vs_list.value[0] if len(cur_vs_list.value) > 1 else '不使用知识库',
655 | interactive=True)
656 | top_k = gr.Slider(1, 20, value=VECTOR_SEARCH_TOP_K, step=1,
657 | label="向量匹配 top k", interactive=True)
658 | score_threshold = gr.Number(value=700,label="知识相关度 Score 阈值,分值越低匹配度越高,数值范围约为0-1100,一般在700左右",
659 | precision=0,
660 | interactive=True)
661 | chunk_content = gr.Checkbox(value=True,
662 | label="是否启用上下文关联",
663 | interactive=True)
664 | chunk_sizes = gr.Number(value=250, precision=0,
665 | label="匹配单段内容的连接上下文后最大长度",
666 | interactive=True, visible=True)
667 | show_reference = gr.Checkbox(value=True,
668 | label="是否显示参考文献窗口",
669 | interactive=True)
670 | prompt_note = """prompt_template ,{context} 代表搜出来的文档,{chat_history}代表历史聊天记录,{question}代表最后一个问题,请在prompt里加上这些关键词。注意不使用知识库的情况下不生效。参考例子:假设你是用药助手,请根据文档来回复,如果文档内容为空或者None,则忽略,文档:{context}\n{chat_history}User: {question}Helper:"""
671 | gr.Markdown(prompt_note)
672 | prompt_template = gr.Textbox(label="内置prompt模板",
673 | value="假设你是用药助手,请根据文档来回复,如果文档内容为空或者None,则忽略,文档:{context}\n{chat_history}User:{question}Helper:",
674 | lines=8,
675 | interactive=True,
676 | visible=True)
677 | save_config_button = gr.Button("保存助手配置")
678 | save_res = gr.Textbox(type="text", label='', visible=False)
679 | save_config_button.click(set_config, show_progress=True,
680 | inputs=[select_ass,ass_name, llm_model, embedding_model, lora_name, llm_history_len, knowledge_set_name, top_k, score_threshold, chunk_content, chunk_sizes, show_reference, prompt_template,set_ass_list], outputs=[save_res,save_res,choose_ass])
681 |
682 | llm_model.change(fn=change_model_name_select,
683 | inputs=[llm_model],
684 | outputs=[set_lora_list,lora_name])
685 | select_ass.change(fn=change_assistant_name_input,
686 | inputs=[select_ass],
687 | outputs=[ass_name_en, ass_add, ass_config, save_res, ass_name, llm_model, embedding_model, lora_name, llm_history_len, knowledge_set_name, top_k, score_threshold, chunk_content, chunk_sizes, show_reference, prompt_template])
688 | ass_add.click(fn=add_ass_config,
689 | inputs=[ass_name_en,set_ass_list],
690 | outputs=[select_ass, ass_name_en, ass_add, ass_config, save_res, ass_name, llm_model, embedding_model, lora_name, llm_history_len, knowledge_set_name, top_k, score_threshold, chunk_content, chunk_sizes, show_reference, prompt_template,set_ass_list,select_ass])
691 | (demo
692 | .queue(concurrency_count=3)
693 | .launch(server_name='0.0.0.0',
694 | server_port=3355,
695 | show_api=False,
696 | share=True,
697 | debug= True,
698 | inbrowser=True))
--------------------------------------------------------------------------------
/chains/local_doc_qa.py:
--------------------------------------------------------------------------------
1 | from langchain.embeddings.huggingface import HuggingFaceEmbeddings
2 | from langchain.vectorstores import FAISS
3 | from langchain.document_loaders import UnstructuredFileLoader, TextLoader
4 | from configs.common_config import *
5 | import datetime
6 | from loader.textsplitter import ChineseTextSplitter
7 | from typing import List, Tuple, Dict
8 | from langchain.docstore.document import Document
9 | import numpy as np
10 | from loader.utils import torch_gc
11 | from tqdm import tqdm
12 | from pypinyin import lazy_pinyin
13 | from loader import UnstructuredPaddleImageLoader, UnstructuredPaddlePDFLoader
14 | from loader.models.base import (BaseAnswer,
15 | AnswerResult,
16 | AnswerResultStream,
17 | AnswerResultQueueSentinelTokenListenerQueue)
18 | from loader.models.loader.args import parser
19 | from loader.models.loader import LoaderCheckPoint
20 | import loader.models.shared as shared
21 | # from agent import bing_search
22 | from langchain.docstore.document import Document
23 | from .modules.json_load import JsonLoader
24 | from .modules.excel_load import ExcelLoader
25 |
26 | def load_file(filepath, sentence_size=SENTENCE_SIZE):
27 | if filepath.lower().endswith(".md"):
28 | loader = UnstructuredFileLoader(filepath, mode="elements")
29 | docs = loader.load()
30 | elif filepath.lower().endswith(".txt"):
31 | loader = TextLoader(filepath, autodetect_encoding=True)
32 | textsplitter = ChineseTextSplitter(pdf=False, sentence_size=sentence_size)
33 | docs = loader.load_and_split(textsplitter)
34 | elif filepath.lower().endswith(".pdf"):
35 | loader = UnstructuredPaddlePDFLoader(filepath)
36 | textsplitter = ChineseTextSplitter(pdf=True, sentence_size=sentence_size)
37 | docs = loader.load_and_split(textsplitter)
38 | elif filepath.lower().endswith(".jpg") or filepath.lower().endswith(".png"):
39 | loader = UnstructuredPaddleImageLoader(filepath, mode="elements")
40 | textsplitter = ChineseTextSplitter(pdf=False, sentence_size=sentence_size)
41 | docs = loader.load_and_split(text_splitter=textsplitter)
42 | elif filepath.lower().endswith(".jsonl") or filepath.lower().endswith(".json"):
43 | loader = JsonLoader(filepath, autodetect_encoding=True)
44 | docs = loader.load_json()
45 | elif filepath.lower().endswith(".xlsx"):
46 | loader = ExcelLoader(filepath, autodetect_encoding=True)
47 | docs = loader.load_excel()
48 | else:
49 | loader = UnstructuredFileLoader(filepath, mode="elements")
50 | textsplitter = ChineseTextSplitter(pdf=False, sentence_size=sentence_size)
51 | docs = loader.load_and_split(text_splitter=textsplitter)
52 | write_check_file(filepath, docs)
53 | return docs
54 |
55 |
56 | def write_check_file(filepath, docs):
57 | folder_path = os.path.join(os.path.dirname(filepath), "tmp_files")
58 | if not os.path.exists(folder_path):
59 | os.makedirs(folder_path)
60 | fp = os.path.join(folder_path, 'load_file.txt')
61 | with open(fp, 'a+', encoding='utf-8') as fout:
62 | fout.write("filepath=%s,len=%s" % (filepath, len(docs)))
63 | fout.write('\n')
64 | for i in docs:
65 | fout.write(str(i))
66 | fout.write('\n')
67 | fout.close()
68 |
69 | def concat_history(model_name:str,chat_history: List[str]=[],query:str = '',process_type='chat'):
70 | formatted_history = ""
71 | if process_type=='chat':
72 | if 'pulse' in model_name.lower():
73 | for sublist in chat_history:
74 | if sublist[0] is not None:
75 | formatted_history += f"\nUser: {sublist[0]}\nHelper: {sublist[1]}"
76 | else:
77 | for sublist in chat_history:
78 | if sublist[0] is not None:
79 | formatted_history += f"\nUser: {sublist[0]}\nHelper: {sublist[1]}"
80 | elif process_type=='search':
81 | chat_history = chat_history[-1:]
82 | for j,sublist in enumerate(chat_history, start=1):
83 | if sublist[0] is not None:
84 | formatted_history += f"User:{sublist[0]}Helper: {sublist[1]}"
85 | formatted_history+=f"User:{query}"
86 | return formatted_history
87 |
88 | def remove_duplicates(lst):
89 | seen = set()
90 | result = []
91 | for i, item in enumerate(lst, start=1):
92 | page_content = item.page_content
93 | if page_content not in seen:
94 | seen.add(page_content)
95 | result.append(item)
96 | return result
97 |
98 | def generate_prompt(model_name:str, related_docs: List[str],
99 | query: str,
100 | chat_history: List[str]=[],prompt_template: str = PROMPT_TEMPLATE,) -> str:
101 | related_docs = remove_duplicates(related_docs)
102 | if 'pulse' in model_name.lower():
103 | formatted_context = ""
104 | for i, doc in enumerate(related_docs, start=1):
105 | formatted_context += f"[{i}]\n```\n{doc.page_content}\n```\n"
106 | formatted_context = "None" if formatted_context == "" else formatted_context
107 | formatted_history = concat_history(model_name,chat_history,query,process_type='chat')
108 | prompt = prompt_template.replace("{question}", query).replace("{context}", formatted_context).replace("{chat_history}", formatted_history)
109 | else:
110 | if len(related_docs)>0:
111 | formatted_context = ""
112 | for i, doc in enumerate(related_docs, start=1):
113 | formatted_context += f"[{i}]{doc.page_content}\n"
114 | formatted_context = "None" if formatted_context == "" else formatted_context
115 | formatted_history = concat_history(model_name,chat_history,query,process_type='chat')
116 | prompt = prompt_template.replace("{question}", query).replace("{context}", formatted_context).replace("{chat_history}", formatted_history)
117 | else:
118 | prompt = prompt_template.replace("{question}", query)
119 | return prompt
120 |
121 |
122 | def seperate_list(ls: List[int]) -> List[List[int]]:
123 | lists = []
124 | ls1 = [ls[0]]
125 | for i in range(1, len(ls)):
126 | if ls[i - 1] + 1 == ls[i]:
127 | ls1.append(ls[i])
128 | else:
129 | lists.append(ls1)
130 | ls1 = [ls[i]]
131 | lists.append(ls1)
132 | return lists
133 |
134 |
135 | def similarity_search_with_score_by_vector(
136 | self, embedding: List[float], k: int = 4
137 | ) -> List[Tuple[Document, float]]:
138 | scores, indices = self.index.search(np.array([embedding], dtype=np.float32), k)
139 | docs = []
140 | id_set = set()
141 | store_len = len(self.index_to_docstore_id)
142 | for j, i in enumerate(indices[0]):
143 | if i == -1 or 0 < self.score_threshold < scores[0][j]:
144 | # This happens when not enough docs are returned.
145 | continue
146 | _id = self.index_to_docstore_id[i]
147 | doc = self.docstore.search(_id)
148 | if not self.chunk_content:
149 | if not isinstance(doc, Document):
150 | raise ValueError(f"Could not find document for id {_id}, got {doc}")
151 | doc.metadata["score"] = int(scores[0][j])
152 | docs.append(doc)
153 | continue
154 | id_set.add(i)
155 | docs_len = len(doc.page_content)
156 | for k in range(1, max(i, store_len - i)):
157 | break_flag = False
158 | for l in [i + k, i - k]:
159 | if 0 <= l < len(self.index_to_docstore_id):
160 | _id0 = self.index_to_docstore_id[l]
161 | doc0 = self.docstore.search(_id0)
162 | if docs_len + len(doc0.page_content) > self.chunk_size:
163 | break_flag = True
164 | break
165 | elif doc0.metadata["source"] == doc.metadata["source"]:
166 | docs_len += len(doc0.page_content)
167 | id_set.add(l)
168 | if break_flag:
169 | break
170 | if not self.chunk_content:
171 | return docs
172 | if len(id_set) == 0 and self.score_threshold > 0:
173 | return []
174 | id_list = sorted(list(id_set))
175 | id_lists = seperate_list(id_list)
176 | for id_seq in id_lists:
177 | for id in id_seq:
178 | if id == id_seq[0]:
179 | _id = self.index_to_docstore_id[id]
180 | doc = self.docstore.search(_id)
181 | else:
182 | _id0 = self.index_to_docstore_id[id]
183 | doc0 = self.docstore.search(_id0)
184 | doc.page_content += " " + doc0.page_content
185 | if not isinstance(doc, Document):
186 | raise ValueError(f"Could not find document for id {_id}, got {doc}")
187 | doc_score = min([scores[0][id] for id in [indices[0].tolist().index(i) for i in id_seq if i in indices[0]]])
188 | doc.metadata["score"] = int(doc_score)
189 | docs.append(doc)
190 | torch_gc()
191 | return docs
192 |
193 |
194 | def search_result2docs(search_results):
195 | docs = []
196 | for result in search_results:
197 | doc = Document(page_content=result["snippet"] if "snippet" in result.keys() else "",
198 | metadata={"source": result["link"] if "link" in result.keys() else "",
199 | "filename": result["title"] if "title" in result.keys() else ""})
200 | docs.append(doc)
201 | return docs
202 |
203 |
204 | class LocalDocQA:
205 | llm: BaseAnswer = None
206 | embeddings: object = None
207 | top_k: int = VECTOR_SEARCH_TOP_K
208 | chunk_size: int = CHUNK_SIZE
209 | chunk_content: bool = True
210 | score_threshold: int = VECTOR_SEARCH_SCORE_THRESHOLD
211 |
212 | def init_cfg(self,
213 | embedding_model: str = EMBEDDING_MODEL,
214 | embedding_device=EMBEDDING_DEVICE,
215 | llm_model: BaseAnswer = None,
216 | top_k=VECTOR_SEARCH_TOP_K,
217 | ):
218 | self.llm = llm_model
219 | print('embedding_device',embedding_device)
220 | self.embeddings = HuggingFaceEmbeddings(model_name=embedding_model_dict[embedding_model],
221 | model_kwargs={'device': embedding_device})
222 | self.top_k = top_k
223 |
224 | def init_embedding(self,
225 | embedding_model: str = EMBEDDING_MODEL,
226 | embedding_device=EMBEDDING_DEVICE,
227 | top_k=VECTOR_SEARCH_TOP_K,
228 | ):
229 | self.embeddings = HuggingFaceEmbeddings(model_name=embedding_model_dict[embedding_model],
230 | model_kwargs={'device': embedding_device})
231 | self.top_k = top_k
232 |
233 | def init_knowledge_vector_store(self,
234 | filepath: str or List[str],
235 | vs_path: str or os.PathLike = None,
236 | sentence_size=SENTENCE_SIZE):
237 | loaded_files = []
238 | failed_files = []
239 | if isinstance(filepath, str):
240 | if not os.path.exists(filepath):
241 | print("路径不存在")
242 | return None
243 | elif os.path.isfile(filepath):
244 | file = os.path.split(filepath)[-1]
245 | try:
246 | docs = load_file(filepath, sentence_size)
247 | logger.info(f"{file} 已成功加载")
248 | loaded_files.append(filepath)
249 | except Exception as e:
250 | logger.error(e)
251 | logger.info(f"{file} 未能成功加载")
252 | return None
253 | elif os.path.isdir(filepath):
254 | docs = []
255 | for file in tqdm(os.listdir(filepath), desc="加载文件"):
256 | fullfilepath = os.path.join(filepath, file)
257 | try:
258 | docs += load_file(fullfilepath, sentence_size)
259 | loaded_files.append(fullfilepath)
260 | except Exception as e:
261 | logger.error(e)
262 | failed_files.append(file)
263 |
264 | if len(failed_files) > 0:
265 | logger.info("以下文件未能成功加载:")
266 | for file in failed_files:
267 | logger.info(f"{file}\n")
268 |
269 | else:
270 | docs = []
271 | for file in filepath:
272 | try:
273 | docs += load_file(file)
274 | logger.info(f"{file} 已成功加载")
275 | loaded_files.append(file)
276 | except Exception as e:
277 | logger.error(e)
278 | logger.info(f"{file} 未能成功加载")
279 | if len(docs) > 0:
280 | logger.info("文件加载完毕,正在生成向量库")
281 | if vs_path and os.path.isdir(vs_path):
282 | vector_store = FAISS.load_local(vs_path, self.embeddings)
283 | vector_store.add_documents(docs)
284 | torch_gc()
285 | else:
286 | if not vs_path:
287 | vs_path = os.path.join(VS_ROOT_PATH,
288 | f"""{"".join(lazy_pinyin(os.path.splitext(file)[0]))}_FAISS_{datetime.datetime.now().strftime("%Y%m%d_%H%M%S")}""")
289 | vector_store = FAISS.from_documents(docs, self.embeddings) # docs 为Document列表
290 | torch_gc()
291 |
292 | vector_store.save_local(vs_path)
293 | return vs_path, loaded_files
294 | else:
295 | logger.info("文件均未成功加载,请检查依赖包或替换为其他文件再次上传。")
296 | return None, loaded_files
297 |
298 | def one_knowledge_add(self, vs_path, one_title, one_conent, one_content_segmentation, sentence_size):
299 | try:
300 | if not vs_path or not one_title or not one_conent:
301 | logger.info("知识库添加错误,请确认知识库名字、标题、内容是否正确!")
302 | return None, [one_title]
303 | docs = [Document(page_content=one_conent + "\n", metadata={"source": one_title})]
304 | if not one_content_segmentation:
305 | text_splitter = ChineseTextSplitter(pdf=False, sentence_size=sentence_size)
306 | docs = text_splitter.split_documents(docs)
307 | if os.path.isdir(vs_path):
308 | vector_store = FAISS.load_local(vs_path, self.embeddings)
309 | vector_store.add_documents(docs)
310 | else:
311 | vector_store = FAISS.from_documents(docs, self.embeddings) ##docs 为Document列表
312 | torch_gc()
313 | vector_store.save_local(vs_path)
314 | return vs_path, [one_title]
315 | except Exception as e:
316 | logger.error(e)
317 | return None, [one_title]
318 |
319 | def get_knowledge_based_answer(self, model_name, query, vs_path, prompt_template, chat_history=[], streaming: bool = STREAMING):
320 | vector_store = FAISS.load_local(vs_path, self.embeddings)
321 | FAISS.similarity_search_with_score_by_vector = similarity_search_with_score_by_vector
322 | vector_store.chunk_size = self.chunk_size
323 | vector_store.chunk_content = self.chunk_content
324 | vector_store.score_threshold = self.score_threshold
325 | formatted_history = concat_history(model_name,chat_history,query,process_type='search')
326 | # 根据历史记录来检索
327 | related_docs_with_score = vector_store.similarity_search_with_score(formatted_history, k=self.top_k)
328 | torch_gc()
329 | prompt = generate_prompt(model_name, related_docs_with_score, query, chat_history, prompt_template)
330 |
331 | for answer_result in self.llm.generatorAnswer(prompt=prompt, history=chat_history,
332 | streaming=streaming):
333 | resp = answer_result.llm_output["answer"]
334 | history = answer_result.history
335 | history[-1][0] = query
336 | response = {"query": query,
337 | "result": resp,
338 | "source_documents": related_docs_with_score}
339 | yield response, history
340 |
341 | def get_base_answer(self, model_name, query, prompt_template, chat_history=[], streaming: bool = STREAMING):
342 | formatted_history = concat_history(model_name,chat_history,query,process_type='chat')
343 | torch_gc()
344 | if prompt_template == '':
345 | prompt = f"{formatted_history}User:{query}Helper:"
346 | else:
347 | prompt = generate_prompt(model_name, [], query, chat_history, prompt_template)
348 |
349 | for answer_result in self.llm.generatorAnswer(prompt=prompt, history=chat_history,
350 | streaming=streaming):
351 | resp = answer_result.llm_output["answer"]
352 | history = answer_result.history
353 | history[-1][0] = query
354 | response = {"query": query,
355 | "result": resp}
356 | yield response, history
357 |
358 | # query 查询内容
359 | # vs_path 知识库路径
360 | # chunk_content 是否启用上下文关联
361 | # score_threshold 搜索匹配score阈值
362 | # vector_search_top_k 搜索知识库内容条数,默认搜索5条结果
363 | # chunk_sizes 匹配单段内容的连接上下文长度
364 | def get_knowledge_based_conent_test(self, query, vs_path, chunk_content,
365 | score_threshold=VECTOR_SEARCH_SCORE_THRESHOLD,
366 | vector_search_top_k=VECTOR_SEARCH_TOP_K, chunk_size=CHUNK_SIZE):
367 | vector_store = FAISS.load_local(vs_path, self.embeddings)
368 | FAISS.similarity_search_with_score_by_vector = similarity_search_with_score_by_vector
369 | vector_store.chunk_content = chunk_content
370 | vector_store.score_threshold = score_threshold
371 | vector_store.chunk_size = chunk_size
372 | related_docs_with_score = vector_store.similarity_search_with_score(query, k=vector_search_top_k)
373 | if not related_docs_with_score:
374 | response = {"query": query,
375 | "source_documents": []}
376 | return response, ""
377 | torch_gc()
378 | prompt = "\n".join([doc.page_content for doc in related_docs_with_score])
379 | response = {"query": query,
380 | "source_documents": related_docs_with_score}
381 | return response, prompt
382 |
383 | def get_search_result_based_answer(self,model_name, query, chat_history=[],prompt_template=BASE_PROMPT_TEMPLATE, streaming: bool = STREAMING):
384 | results = bing_search(query)
385 | result_docs = search_result2docs(results)
386 | prompt = generate_prompt(model_name, result_docs, query, chat_history, prompt_template)
387 |
388 | for answer_result in self.llm.generatorAnswer(prompt=prompt, history=chat_history,
389 | streaming=streaming):
390 | resp = answer_result.llm_output["answer"]
391 | history = answer_result.history
392 | history[-1][0] = query
393 | response = {"query": query,
394 | "result": resp,
395 | "source_documents": result_docs}
396 | yield response, history
397 |
398 |
399 |
--------------------------------------------------------------------------------
/chains/modules/embeddings.py:
--------------------------------------------------------------------------------
1 | from langchain.embeddings.huggingface import HuggingFaceEmbeddings
2 |
3 | from typing import Any, List
4 |
5 |
6 | class MyEmbeddings(HuggingFaceEmbeddings):
7 | def __init__(self, **kwargs: Any):
8 | super().__init__(**kwargs)
9 |
10 | def embed_documents(self, texts: List[str]) -> List[List[float]]:
11 | """Compute doc embeddings using a HuggingFace transformer model.
12 |
13 | Args:
14 | texts: The list of texts to embed.
15 |
16 | Returns:
17 | List of embeddings, one for each text.
18 | """
19 | texts = list(map(lambda x: x.replace("\n", " "), texts))
20 | embeddings = self.client.encode(texts, normalize_embeddings=True)
21 | return embeddings.tolist()
22 |
23 | def embed_query(self, text: str) -> List[float]:
24 | """Compute query embeddings using a HuggingFace transformer model.
25 |
26 | Args:
27 | text: The text to embed.
28 |
29 | Returns:
30 | Embeddings for the text.
31 | """
32 | text = text.replace("\n", " ")
33 | embedding = self.client.encode(text, normalize_embeddings=True)
34 | return embedding.tolist()
35 |
--------------------------------------------------------------------------------
/chains/modules/excel_load.py:
--------------------------------------------------------------------------------
1 | import logging
2 | from typing import List, Optional
3 |
4 | from langchain.docstore.document import Document
5 | from langchain.document_loaders.base import BaseLoader
6 | from langchain.document_loaders.helpers import detect_file_encodings
7 | import json
8 | import pandas as pd
9 |
10 | logger = logging.getLogger(__name__)
11 |
12 | def process_json(json_str):
13 | docs=[]
14 | for doc in json_str['docs']:
15 | new_doc = Document(
16 | page_content=doc, metadata={"source": doc.split(' ')[0]}
17 | )
18 | docs.append(new_doc)
19 |
20 | return docs
21 |
22 | class ExcelLoader(BaseLoader):
23 | """Load text files.
24 |
25 |
26 | Args:
27 | file_path: Path to the file to load.
28 |
29 | encoding: File encoding to use. If `None`, the file will be loaded
30 | with the default system encoding.
31 |
32 | autodetect_encoding: Whether to try to autodetect the file encoding
33 | if the specified encoding fails.
34 | """
35 |
36 | def __init__(
37 | self,
38 | file_path: str,
39 | encoding: Optional[str] = None,
40 | autodetect_encoding: bool = False,
41 | ):
42 | """Initialize with file path."""
43 | self.file_path = file_path
44 | self.encoding = encoding
45 | self.autodetect_encoding = autodetect_encoding
46 |
47 | def load(self) -> List[Document]:
48 | """Load from file path."""
49 | pass
50 |
51 | def load_excel(self) -> List[Document]:
52 | """Load from file path."""
53 | documents = []
54 | try:
55 | df = pd.read_excel(self.file_path, sheet_name=None)
56 | # 获取所有sheet的名字
57 | sheet_names = df.keys()
58 | # 遍历每个sheet
59 | for sheet_name in sheet_names:
60 | # 获取当前sheet的数据
61 | sheet_data = df[sheet_name]
62 | # 遍历每行数据
63 | for _, row in sheet_data.iterrows():
64 | try:
65 | question = str(row["问题"])
66 | answer = str(row["回答"])
67 | doc = "【参考问题】"+question+"【参考回答】"+answer
68 | new_doc = Document(
69 | page_content=doc, metadata={"source": sheet_name}
70 | )
71 | documents.append(new_doc)
72 | except Exception as e:
73 | print('文件表读取发生错误!',sheet_name,e)
74 | except Exception as e:
75 | print('e2',e)
76 | raise RuntimeError(f"Error loading {self.file_path}") from e
77 | return documents
--------------------------------------------------------------------------------
/chains/modules/json_load.py:
--------------------------------------------------------------------------------
1 | import logging
2 | from typing import List, Optional
3 |
4 | from langchain.docstore.document import Document
5 | from langchain.document_loaders.base import BaseLoader
6 | from langchain.document_loaders.helpers import detect_file_encodings
7 | import json
8 | logger = logging.getLogger(__name__)
9 |
10 | def process_json(json_str):
11 | docs=[]
12 | for doc in json_str['docs']:
13 | new_doc = Document(
14 | page_content=doc, metadata={"source": doc.split(' ')[0]}
15 | )
16 | docs.append(new_doc)
17 |
18 | return docs
19 |
20 | class JsonLoader(BaseLoader):
21 | """Load text files.
22 |
23 |
24 | Args:
25 | file_path: Path to the file to load.
26 |
27 | encoding: File encoding to use. If `None`, the file will be loaded
28 | with the default system encoding.
29 |
30 | autodetect_encoding: Whether to try to autodetect the file encoding
31 | if the specified encoding fails.
32 | """
33 |
34 | def __init__(
35 | self,
36 | file_path: str,
37 | encoding: Optional[str] = None,
38 | autodetect_encoding: bool = False,
39 | ):
40 | """Initialize with file path."""
41 | self.file_path = file_path
42 | self.encoding = encoding
43 | self.autodetect_encoding = autodetect_encoding
44 |
45 | def load(self) -> List[Document]:
46 | """Load from file path."""
47 | pass
48 |
49 | def load_json(self) -> List[Document]:
50 | """Load from file path."""
51 | documents = []
52 | try:
53 | with open(self.file_path, encoding=self.encoding) as f:
54 | for line in f:
55 | json_object = json.loads(line)
56 | docs = process_json(json_object)
57 | documents+= docs
58 | except UnicodeDecodeError as e:
59 | print('e1',e)
60 | if self.autodetect_encoding:
61 | detected_encodings = detect_file_encodings(self.file_path)
62 | for encoding in detected_encodings:
63 | logger.debug("Trying encoding: ", encoding.encoding)
64 | try:
65 | with open(self.file_path, 'r', encoding=encoding.encoding) as f:
66 | for line in f:
67 | json_object = json.loads(line)
68 | docs = process_json(json_object)
69 | documents+= docs
70 | break
71 | except UnicodeDecodeError:
72 | continue
73 | else:
74 | raise RuntimeError(f"Error loading {self.file_path}") from e
75 | except Exception as e:
76 | print('e2',e)
77 | raise RuntimeError(f"Error loading {self.file_path}") from e
78 | return documents
--------------------------------------------------------------------------------
/chains/modules/vectorstores.py:
--------------------------------------------------------------------------------
1 | from langchain.vectorstores import FAISS
2 | from typing import Any, Callable, List, Optional, Tuple, Dict
3 | from langchain.docstore.document import Document
4 | from langchain.docstore.base import Docstore
5 |
6 | from langchain.vectorstores.utils import maximal_marginal_relevance
7 | from langchain.embeddings.base import Embeddings
8 | import uuid
9 | from langchain.docstore.in_memory import InMemoryDocstore
10 |
11 | import numpy as np
12 |
13 | def dependable_faiss_import() -> Any:
14 | """Import faiss if available, otherwise raise error."""
15 | try:
16 | import faiss
17 | except ImportError:
18 | raise ValueError(
19 | "Could not import faiss python package. "
20 | "Please install it with `pip install faiss` "
21 | "or `pip install faiss-cpu` (depending on Python version)."
22 | )
23 | return faiss
24 |
25 | class FAISSVS(FAISS):
26 | def __init__(self,
27 | embedding_function: Callable[..., Any],
28 | index: Any,
29 | docstore: Docstore,
30 | index_to_docstore_id: Dict[int, str]):
31 | super().__init__(embedding_function, index, docstore, index_to_docstore_id)
32 |
33 | def max_marginal_relevance_search_by_vector(
34 | self, embedding: List[float], k: int = 4, fetch_k: int = 20, **kwargs: Any
35 | ) -> List[Tuple[Document, float]]:
36 | """Return docs selected using the maximal marginal relevance.
37 |
38 | Maximal marginal relevance optimizes for similarity to query AND diversity
39 | among selected documents.
40 |
41 | Args:
42 | embedding: Embedding to look up documents similar to.
43 | k: Number of Documents to return. Defaults to 4.
44 | fetch_k: Number of Documents to fetch to pass to MMR algorithm.
45 |
46 | Returns:
47 | List of Documents with scores selected by maximal marginal relevance.
48 | """
49 | scores, indices = self.index.search(np.array([embedding], dtype=np.float32), fetch_k)
50 | # -1 happens when not enough docs are returned.
51 | embeddings = [self.index.reconstruct(int(i)) for i in indices[0] if i != -1]
52 | mmr_selected = maximal_marginal_relevance(
53 | np.array([embedding], dtype=np.float32), embeddings, k=k
54 | )
55 | selected_indices = [indices[0][i] for i in mmr_selected]
56 | selected_scores = [scores[0][i] for i in mmr_selected]
57 | docs = []
58 | for i, score in zip(selected_indices, selected_scores):
59 | if i == -1:
60 | # This happens when not enough docs are returned.
61 | continue
62 | _id = self.index_to_docstore_id[i]
63 | doc = self.docstore.search(_id)
64 | if not isinstance(doc, Document):
65 | raise ValueError(f"Could not find document for id {_id}, got {doc}")
66 | docs.append((doc, score))
67 | return docs
68 |
69 | def max_marginal_relevance_search(
70 | self,
71 | query: str,
72 | k: int = 4,
73 | fetch_k: int = 20,
74 | **kwargs: Any,
75 | ) -> List[Tuple[Document, float]]:
76 | """Return docs selected using the maximal marginal relevance.
77 |
78 | Maximal marginal relevance optimizes for similarity to query AND diversity
79 | among selected documents.
80 |
81 | Args:
82 | query: Text to look up documents similar to.
83 | k: Number of Documents to return. Defaults to 4.
84 | fetch_k: Number of Documents to fetch to pass to MMR algorithm.
85 |
86 | Returns:
87 | List of Documents with scores selected by maximal marginal relevance.
88 | """
89 | embedding = self.embedding_function(query)
90 | docs = self.max_marginal_relevance_search_by_vector(embedding, k, fetch_k)
91 | return docs
92 |
93 | @classmethod
94 | def __from(
95 | cls,
96 | texts: List[str],
97 | embeddings: List[List[float]],
98 | embedding: Embeddings,
99 | metadatas: Optional[List[dict]] = None,
100 | **kwargs: Any,
101 | ) -> FAISS:
102 | faiss = dependable_faiss_import()
103 | index = faiss.IndexFlatIP(len(embeddings[0]))
104 | index.add(np.array(embeddings, dtype=np.float32))
105 |
106 | # # my code, for speeding up search
107 | # quantizer = faiss.IndexFlatL2(len(embeddings[0]))
108 | # index = faiss.IndexIVFFlat(quantizer, len(embeddings[0]), 100)
109 | # index.train(np.array(embeddings, dtype=np.float32))
110 | # index.add(np.array(embeddings, dtype=np.float32))
111 |
112 | documents = []
113 | for i, text in enumerate(texts):
114 | metadata = metadatas[i] if metadatas else {}
115 | documents.append(Document(page_content=text, metadata=metadata))
116 | index_to_id = {i: str(uuid.uuid4()) for i in range(len(documents))}
117 | docstore = InMemoryDocstore(
118 | {index_to_id[i]: doc for i, doc in enumerate(documents)}
119 | )
120 | return cls(embedding.embed_query, index, docstore, index_to_id)
121 |
122 |
--------------------------------------------------------------------------------
/chains/text_load.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pinecone
3 | from tqdm import tqdm
4 | from langchain.llms import OpenAI
5 | from langchain.text_splitter import SpacyTextSplitter
6 | from langchain.document_loaders import TextLoader
7 | from langchain.document_loaders import DirectoryLoader
8 | from langchain.indexes import VectorstoreIndexCreator
9 | from langchain.embeddings.openai import OpenAIEmbeddings
10 | from langchain.vectorstores import Pinecone
11 |
12 | #一些配置文件
13 | openai_key="你的key" # 注册 openai.com 后获得
14 | pinecone_key="你的key" # 注册 app.pinecone.io 后获得
15 | pinecone_index="你的库" #app.pinecone.io 获得
16 | pinecone_environment="你的Environment" # 登录pinecone后,在indexes页面 查看Environment
17 | pinecone_namespace="你的Namespace" #如果不存在自动创建
18 |
19 | #科学上网你懂得
20 | os.environ['HTTP_PROXY'] = 'http://127.0.0.1:7890'
21 | os.environ['HTTPS_PROXY'] = 'http://127.0.0.1:7890'
22 |
23 | #初始化pinecone
24 | pinecone.init(
25 | api_key=pinecone_key,
26 | environment=pinecone_environment
27 | )
28 | index = pinecone.Index(pinecone_index)
29 |
30 | #初始化OpenAI的embeddings
31 | embeddings = OpenAIEmbeddings(openai_api_key=openai_key)
32 |
33 | #初始化text_splitter
34 | text_splitter = SpacyTextSplitter(pipeline='zh_core_web_sm',chunk_size=1000,chunk_overlap=200)
35 |
36 | # 读取目录下所有后缀是txt的文件
37 | loader = DirectoryLoader('../docs', glob="**/*.txt", loader_cls=TextLoader)
38 |
39 | #读取文本文件
40 | documents = loader.load()
41 |
42 | # 使用text_splitter对文档进行分割
43 | split_text = text_splitter.split_documents(documents)
44 | try:
45 | for document in tqdm(split_text):
46 | # 获取向量并储存到pinecone
47 | Pinecone.from_documents([document], embeddings, index_name=pinecone_index)
48 | except Exception as e:
49 | print(f"Error: {e}")
50 | quit()
51 |
52 |
53 |
--------------------------------------------------------------------------------
/configs/common_config.py:
--------------------------------------------------------------------------------
1 | import torch.cuda
2 | import torch.backends
3 | import os
4 | import logging
5 | import uuid
6 | import datetime
7 |
8 | def get_formatted_date():
9 | current_date = datetime.date.today()
10 | formatted_date = current_date.strftime("%Y-%m-%d")
11 | return formatted_date
12 |
13 | date_string = get_formatted_date()
14 |
15 | LOG_FORMAT = "%(levelname) -5s %(asctime)s" "-1d: %(message)s"
16 | logger = logging.getLogger()
17 | logger.setLevel(logging.INFO)
18 | logging.basicConfig(format=LOG_FORMAT)
19 |
20 | embedding_model_dict = {
21 | "ernie-tiny": "nghuyong/ernie-3.0-nano-zh",
22 | "ernie-base": "nghuyong/ernie-3.0-base-zh",
23 | "text2vec-base": "shibing624/text2vec-base-chinese",
24 | "text2vec-large-chinese": "GanymedeNil/text2vec-large-chinese",
25 | "m3e-small": "moka-ai/m3e-small",
26 | "m3e-base": "moka-ai/m3e-base",
27 | }
28 |
29 | # Embedding model name
30 | EMBEDDING_MODEL = "text2vec-large-chinese"
31 |
32 | # Embedding running device
33 | EMBEDDING_DEVICE = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
34 |
35 |
36 | # supported LLM models
37 | # llm_model_dict 处理了loader的一些预设行为,如加载位置,模型名称,模型处理器实例
38 | llm_model_dict = {
39 | "PULSE": {
40 | "name": "PULSE",
41 | "pretrained_model_name": "OpenMEDLab/PULSE-7bv5",
42 | "local_model_path": "OpenMEDLab/PULSE-7bv5",
43 | "provides": "Bloomz"
44 | }
45 | }
46 |
47 | # LLM 名称
48 | LLM_MODEL = "PULSE"
49 | # 如果你需要加载本地的model,指定这个参数 ` --no-remote-model`,或者下方参数修改为 `True`
50 | NO_REMOTE_MODEL = False
51 | # 量化加载8bit 模型
52 | LOAD_IN_8BIT = False
53 | # Load the model with bfloat16 precision. Requires NVIDIA Ampere GPU.
54 | BF16 = False
55 | # 本地模型存放的位置
56 | MODEL_DIR = "model/"
57 | # 本地lora存放的位置
58 | LORA_DIR = "loras/"
59 |
60 | # LLM lora path,默认为空,如果有请直接指定文件夹路径
61 | LLM_LORA_PATH = ""
62 | USE_LORA = True if LLM_LORA_PATH else False
63 |
64 | # LLM streaming reponse
65 | STREAMING = False
66 |
67 | # Use p-tuning-v2 PrefixEncoder
68 | USE_PTUNING_V2 = False
69 |
70 | # LLM running device
71 | LLM_DEVICE = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
72 |
73 |
74 | VS_ROOT_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "vector_store")
75 |
76 | UPLOAD_ROOT_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "content")
77 |
78 | # 基于上下文的prompt模版,请务必保留"{question}"和"{context}"
79 | PROMPT_TEMPLATE = """已知信息:
80 | {context}
81 |
82 | 假设你是客服,根据上述已知信息,来回答用户的问题。如果无法从中得到答案,请说 “根据已知信息无法回答该问题” 或 “没有提供足够的相关信息”,不允许在答案中添加编造成分,答案请使用中文。 问题是:{question}"""
83 |
84 | BASE_PROMPT_TEMPLATE = """Instruction: 假如你是机器人。Input:{question}"""
85 |
86 |
87 | # 文本分句长度
88 | SENTENCE_SIZE = 20
89 |
90 | # 匹配后单段上下文长度
91 | CHUNK_SIZE = 250
92 |
93 | # LLM input history length
94 | LLM_HISTORY_LEN = 8
95 |
96 | # return top-k text chunk from vector store
97 | VECTOR_SEARCH_TOP_K = 3
98 |
99 | # 知识检索内容相关度 Score, 数值范围约为0-1100,如果为0,则不生效,经测试设置为小于500时,匹配结果更精准
100 | VECTOR_SEARCH_SCORE_THRESHOLD = 700
101 |
102 | NLTK_DATA_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "nltk_data")
103 |
104 | FLAG_USER_NAME = uuid.uuid4().hex
105 |
106 | logger.info(f"""
107 | loading model config
108 | llm device: {LLM_DEVICE}
109 | embedding device: {EMBEDDING_DEVICE}
110 | dir: {os.path.dirname(os.path.dirname(__file__))}
111 | flagging username: {FLAG_USER_NAME}
112 | """)
113 |
114 |
--------------------------------------------------------------------------------
/configs/test_ass.yaml:
--------------------------------------------------------------------------------
1 | ass_name: 药物小助手
2 | llm_model: PULSE
3 | embedding_model: text2vec-large-chinese
4 | lora_name: 不使用
5 | llm_history_len: 8
6 | knowledge_set_name: drug_kb
7 | top_k: 3
8 | score_threshold: 500
9 | chunk_content: true
10 | chunk_sizes: 250
11 | show_reference: true
12 | prompt_template: '假设你是药物助手,请根据文档来回复,如果文档内容为空或者None,则忽略,文档:{context}\n{chat_history}User:{question}Helper:'
13 |
--------------------------------------------------------------------------------
/demo_data/kb_drug_demo.jsonl:
--------------------------------------------------------------------------------
1 | {"docs": ["布洛芬分散片 英文名\nIbuprofen Dispersible Tablets", "布洛芬分散片 用药科室\n解热镇痛抗炎药", "布洛芬分散片 药物分类\n化学药品", "布洛芬分散片 药物剂量\n50mg", "布洛芬分散片 适应症\n适用于因感冒、急性上呼吸道感染、急性咽喉炎等疾病引起的发热、头痛、周身痛及关节痛;其他疾病如关节炎、牙疾等所致的轻、中度疼痛。", "布洛芬分散片 禁忌症\n1 活动期消化道溃疡。2 对本药物过敏者,因服用阿司匹林和其它非类固醇类抗炎药诱发哮喘、鼻炎或荨麻疹的患者。3 有失血倾向者。", "布洛芬分散片 药物用法\n口服或少许水制成混悬液服用。1 1~12岁儿童患者(1) 用于发热,推荐剂量为每日每公斤体重20mg(2/5片),分三次服用,或遵医嘱.(2) 用于镇痛,推荐剂量为每日每公斤体重30mg(3/5片),分三次服用,或遵医嘱。2 成人及12岁以上儿童,推荐剂量为一次0.2~0.4g(4~8片)一日3次,或遵医嘱。", "布洛芬分散片 不良反应\n本品的不良反应较少,一般为肠、胃不适或皮疹、头痛、耳鸣,偶见转氨酶升高等。也有引起胃肠道出血而加重溃疡的报道。", "布洛芬分散片 注意事项\n1.肠胃病患者慎用。但对其他抗风湿药物耐受性差者可能对本品有良好耐受性。2.有支气管哮喘病史患者,可能会引起支气管痉挛。3.并用抗凝血剂的患者,服药的最初几日应随时监测其凝血酶原时间。4.心功能不全患者慎用。5.连续服用三天发热不退时,应请医生诊治。6.严重肝、肾功能障碍,红斑狼疮或其它免疫疾病患者慎用。7.剧烈腹痛、粪便黑色或带血,皮肤、粘膜炎症,患眼疾者请在医生指导下用药。", "布洛芬分散片 药物相互作用\n口服降糖药、甲氨喋呤、苯妥英、毛地黄、锂剂及饮酒等可降低本品耐受量。", "布洛芬分散片 存储方式\n遮光密封,于阴凉干燥处保存。有效期暂定一年。"]}
2 | {"docs": ["西沙必利片 英文名\nCisapride Tablets", "西沙必利片 用药科室\n消化科用药", "西沙必利片 药物分类\n化学药品", "西沙必利片 药物剂量\n5mg", "西沙必利片 儿童禁忌\n婴幼儿禁用", "西沙必利片 孕妇禁忌\n1.尽管在动物不影响胚胎形成,无原始的胚胎毒性,也无致畸作用,但若在\n妊娠期,尤其是在妊娠的头三个月应权衡利弊使用。\n2.尽管经乳汁排泄的量很少,仍建议哺乳母亲勿用。", "西沙必利片 老年人禁忌\n在老年人,由于中度延长了清除半衰期,稳态血浆浓度一般会增高,故治疗\n剂量应酌减。", "西沙必利片 适应症\n全胃肠促动力药。主要用于功能性消化不良,X线、内窥镜检查为阴性的上消化道不适,症状为早饱,饭后饱胀、食量减退、胃胀、嗳气过多、食欲缺乏、恶心、呕吐或类似溃疡的主诉(上腹部灼痛)。另可用于轻度返流性食管炎的治疗。", "西沙必利片 禁忌症\n已知对本品过敏者禁用。禁止同时口服或非肠道使用强效抑制CYP3A4酶的药物,包括:三唑类抗真菌药;大环内酯类抗生素;HIV蛋白酶抑制剂;萘法唑酮;(见药物相互作用)。心脏病、心律失常、QT间期延长者禁用,禁止与引起QT间期延长的药物一起用;有水、电解质紊乱的患者禁用,特别是低血钾或低血镁者禁用;心动过缓者禁用;先天QT间期延长或有先天QT间期延长综合症家族史者禁用;肺、肝、肾功能不全的病人禁用;婴幼儿禁用。", "西沙必利片 药物用法\n口服治疗:用量:每日最高服药剂量为30mg。成人:根据病情的程度,每日总量15~30mg,分2~3次给药,每次5mg(剂量可以加倍)。体重为25~50公斤的儿童:每次最大剂量为5mg,每日四次。或遵医嘱。体重为25公斤以下的儿童:每次0.2mg/kg体重,每日三至四次。或遵医嘱。建议尽量避免与西柚汁一起服用。在肾功能不全时,建议减半日用量。", "西沙必利片 不良反应\n因本品的药理活性,可能发生瞬时性腹部痉挛、腹鸣和腹泻。发生腹部痉挛时,可减半剂量。偶有过敏反应包括红疹、瘙痒、荨麻疹、支气管痉挛、轻度短暂的头痛或头晕以及与剂量相关的尿频的报道。极少数心律失常的报道,包括室性心动过速、室颤、尖端扭转型室速、QT延长。大多数此类病人常同时服用数种其它药物-其中包括抑制CYP3A4酶的药物,或已患有心脏病、已有心律失常的危险因素存在。罕见可逆性肝功能异常的报道,可伴或不伴胆汁郁积。虽然有男子女性乳房和乳溢的病例报道,但在大规模监测研究中发生率(<", "西沙必利片 注意事项\n用药过程中应注意检测心电图。胃肠道运动增加可造成危害的病人,必须慎用。有猝死的家庭史,要权衡利弊谨慎使用本品。QT间期大于450毫秒的病人或电解质紊乱的病人,不应使用本品。本品不影响神经运动性功能,不引起镇静和嗜睡。然而,本品可加速中枢神经系统抑制剂的吸收,如巴比妥酸盐、酒精等,因此同时给予应慎重。", "西沙必利片 药物相互作用\n本品的主要代谢途径是通过CYP3A4酶进行代谢。若同时口服或非肠道使用能抑制此酶的药物,可导致血浆西沙必利浓度升高,从而增加QT间期和心律失常的危险性,心律失常包括室性心动过速、室颤和尖端扭转型室速。所以禁止与这些药物同时服用。", "西沙必利片 存储方式\n15~30℃干燥处贮存,放于儿童不易拿到处。"]}
3 | {"docs": ["克林霉素磷酸酯氯化钠注射液 英文名\nClindamycin Phosphate and Sodium Chloride Injection", "克林霉素磷酸酯氯化钠注射液 用药科室\n感染科用药","克林霉素磷酸酯氯化钠注射液 药物剂量\n100毫升:克林霉素0.6克与氯化钠0.9克", "克林霉素磷酸酯氯化钠注射液 儿童禁忌\n小于4岁儿童慎用。", "克林霉素磷酸酯氯化钠注射液 孕妇禁忌\n孕妇及哺乳期妇女使用本品应注意。", "克林霉素磷酸酯氯化钠注射液 适应症\n1.革兰氏阳性菌引起的下列各种感染性疾病:\n(1)扁桃体炎、化脓性中耳炎、鼻窦炎等。\n(2)急性支气管炎、慢性支气管炎急性发作、肺炎、肺脓肿和支气管扩张合并感染等。\n(3)皮肤和软组织感染:疖、痈、脓肿、蜂窝组织炎、创伤和手术后感染等。\n(4)泌尿系统感染:急性尿道炎、急性肾盂肾炎、前列腺炎等。\n(5)其他:骨髓炎、败血症、腹膜炎和口腔感染等。\n2.厌氧菌引起的各种感染性疾病:\n(1)脓胸、肺脓肿、厌氧菌性肺病。\n(2)皮肤和软组织感染、败血症。\n(3)腹腔内感染:腹膜炎、腹腔内脓肿。\n(4)女性盆腔及生", "克林霉素磷酸酯氯化钠注射液 禁忌症\n本品与林可霉素、克林霉素有交叉耐药性,对克林霉素或林可霉素有过敏史者禁用。", "克林霉素磷酸酯氯化钠注射液 药物用法\n静脉滴注:成人每日0.6—2.7g(以克林霉素计),分2—3次;儿童15—40mg/kg(以克林霉素计),分2—4次,或遵医嘱。\n静滴速度:每瓶不少于30分钟。", "克林霉素磷酸酯氯化钠注射液 不良反应\n1.长期静脉滴注应注意静脉炎的出现。\n2.胃肠道反应:偶见恶心、呕吐、腹痛及腹泻。\n3.过敏反应:少数病人可出现药物性皮疹。\n4.对造血系统基本无毒性反应,偶可引起中性粒细胞减少,嗜酸性粒细胞增多,血小板减少等。一般轻微为一过性。\n5.少数病人可发生一过性碱性磷酸酶、血清转氨酶轻度升高及黄疸。\n6.极少数病人可产生伪膜性结肠炎。", "克林霉素磷酸酯氯化钠注射液 注意事项\n1.与青霉素、头孢菌素类抗生素无交叉过敏反应,可用于对青霉素过敏者。\n2.禁与氨苄青霉素、苯妥英钠、巴比妥盐、氨茶碱、葡萄糖酸钙及硫酸镁配伍;与红霉素呈拮抗作用,不宜合用。\n3.肝、肾功能损害者及小于4岁儿童慎用。孕妇及哺乳期妇女使用本品应注意其利弊。\n4.如出现伪膜性肠炎,选用万古霉素口服0.125—0.5g,每日4次进行治疗。\n5.本品使用前应仔细检查,如有异物、溶液混浊、封口松动、瓶身有裂纹者切勿使用。\n6.本品应一次性使用,输注后的剩余药液,切勿贮藏再用。", "克林霉素磷酸酯氯化钠注射液 药物相互作用\n本品与红霉素呈拮抗作用,不宜合用。", "克林霉素磷酸酯氯化钠注射液 存储方式\n遮光,密闭,在阴凉处保存。"]}
4 | {"docs": ["柴银口服液 用药科室\n呼吸科用药", "柴银口服液 药物分类\n中药", "柴银口服液 药物剂量\n每瓶装20毫升", "柴银口服液 适应症\n清热解毒,利咽止咳。用于上呼吸道感染外感风热证,症见:发热恶风,头痛、咽痛,汗出,鼻塞流涕,咳嗽,舌边尖红,苔薄黄。", "柴银口服液 禁忌症\n尚不明确。", "柴银口服液 药物用法\n口服。一次1瓶,一日3次,连服3天。", "柴银口服液 不良反应\n偶有腹泻。", "柴银口服液 注意事项\n脾胃虚寒者宜温服。", "柴银口服液 存储方式\n密封。"]}
5 | {"docs": ["麻仁润肠丸 英文名\nMaren Runchang Wan", "麻仁润肠丸 用药科室\n消化科用药", "麻仁润肠丸 药物分类\n中药", "麻仁润肠丸 药物剂量\n每丸重6克", "麻仁润肠丸 孕妇禁忌\n孕妇忌服。", "麻仁润肠丸 适应症\n润肠通便。用于肠胃积热,胸腹胀满,大便秘结。", "麻仁润肠丸 禁忌症\n孕妇忌服。", "麻仁润肠丸 药物用法\n口服。一次1-2丸,一日2次。", "麻仁润肠丸 不良反应\n尚不明确。", "麻仁润肠丸 注意事项\n1.孕妇忌服。月经期慎用。\n2.年青体壮者便秘时不宜用本药。\n3.忌食生冷、油腻、辛辣食物。\n4.严重气质性病变引起的排便困难,如结肠癌,严重的肠道憩窒,肠梗阻及炎症性肠病等忌用。\n5.服药三天后症状未改善,或出现其他症状时,应及时去医院就诊。\n6.按照用法用量服用,有慢性病史者、小儿及年老体虚者不宜长期服用,应在医师指导下服用。\n7.药品性状发生改变时禁止服用。\n8.儿童必须在成人的监护下使用。\n9.请将此药品放在儿童不能接触的地方。\n10.如正在服用其他药品,使用本品前请咨询医师或药师。", "麻仁润肠丸 药物相互作用\n如与其他药物同时使用可能会发生药物相互作用,详情请咨询医师或药师。", "麻仁润肠丸 存储方式\n密封"]}
6 | {"docs": ["布洛芬缓释片 英文名\nIbuprofen Sustained Release Tablets", "布洛芬缓释片 用药科室\n解热镇痛抗炎药", "布洛芬缓释片 药物分类\n化学药品", "布洛芬缓释片 药物剂量\n0.3g", "布洛芬缓释片 适应症\n用于减轻中度疼痛,如关节痛、神经痛、肌肉痛、偏头痛、头痛、痛经、牙痛;也可用于感冒和流感引起的发热。", "布洛芬缓释片 禁忌症\n1 对其他非甾体抗炎药过敏者禁用;\n2 孕妇及哺乳期妇女禁用;\n3 对阿司匹林过敏的哮喘患者禁用。", "布洛芬缓释片 药物用法\n口服。成人,一次1片,一日2次(早晚各一次)。", "布洛芬缓释片 不良反应\n1 少数病人可出现恶心、呕吐、胃烧灼感或轻度消化不良、胃肠道溃疡及出血、转氨酶升高、头痛、头晕、耳鸣、视力模糊、精神紧张、嗜睡、下肢水肿或体重骤增;\n2 罕见皮疹、过敏性肾炎、膀胱炎、肾病综合征、肾乳头坏死或肾功能衰竭、支气管痉挛。", "布洛芬缓释片 注意事项\n1 本品为对症治疗药,不宜长期或大量使用,用于止痛不得超过5天,用于解热不得超过3天,如症状不缓解,请咨询医师或药师;\n2.必须整片吞服,不得碾碎或溶解后服用;\n3 不能同时服用其他含有解热镇痛药的药品(如某些复方抗感冒药);\n4 服用本品期间不得饮酒或含有酒精的饮料;\n5 有下列情况患者慎用:60岁以上、支气管哮喘、肝肾功能不全、凝血机制或血小板功能障碍(如血友病);\n6 下列情况患者应在医师指导下使用:有消化性溃疡史、胃肠道出血、心功能不全、高血压;\n7 如服用过量或出现严重不良反应,应立即就医;\n8 对本品过敏者禁用,过敏体质者慎用;\n9 本品性状发生改变时禁止使用;\n10 请将本品放在儿童不能接触的地方;\n11 如正在使用其他药品,使用本品前请咨询医师或药师。", "布洛芬缓释片 药物相互作用\n1 本品与其他解热、镇痛、抗炎药物同用时可增加胃肠道不良反应,并可能导致溃疡;\n2 本品与肝素、双香豆素等抗凝药同用时,可导致凝血酶原时间延长,增加出血倾向;\n3 本品与地高辛、甲氨蝶呤、口服降血糖药物同用时,能使这些药物的血药浓度增高,不宜同用;\n4 本品与呋塞米(呋喃苯胺酸)同用时,后者的排钠和降压作用减弱;与抗高血压药同用时,也降低后者的降压效果;\n5 如与其他药物同时使用可能会发生药物相互作用,详情请咨询医师或药师。", "布洛芬缓释片 存储方式\n避光,密闭保存 室温贮藏。"]}
7 | {"docs": ["板蓝根颗粒 用药科室\n感冒用药", "板蓝根颗粒 药物分类\n中药", "板蓝根颗粒 适应症\n清热解毒,凉血利咽。用于肺胃热盛所致的咽喉肿痛、口咽干燥;急性扁桃体炎见上述证候者。", "板蓝根颗粒 禁忌症\n忌烟,酒及辛辣,生冷,油腻食物。", "板蓝根颗粒 药物用法\n开水冲服,一次5~10g(含糖型),或一次3~6g(无糖型)一日3~4次。", "板蓝根颗粒 不良反应\n尚不明确。", "板蓝根颗粒 注意事项\n1不宜在服药期间同时服用滋补性中成药。\n2风寒感冒者不宜使用,其表现为恶寒重,发热轻,无汗,鼻塞流涕,口不渴,咳吐稀白痰。\n3有高血压、心脏病、肝病、糖尿病、肾病等慢性病严重者,孕妇或正在接受其他治疗的患者,应在医师指导下服用。\n4服药三天后,症状未改善,或出现发热咳嗽加重,并有其他症状如胸闷、心悸等应去医院就诊。\n5按照用法用量服用,小儿、年老体虚患者应在医师指导下服用。\n6连续服用应向医师咨询。\n7药品性状发生改变时禁止使用。\n8儿童必须在成人监护下使用。\n9请将此药品放在儿童不能接触的地方。\n10如正在服用其他药品,使用本品前请咨询医师或药师。", "板蓝根颗粒 药物相互作用\n如与其他药物同时使用可能会发生药物相互作用,详情请咨询医师或药师。", "板蓝根颗粒 存储方式\n密封。"]}
8 | {"docs": ["维生素B6片 英文名\nVitamin B6 Tablets", "维生素B6片 用药科室\n营养科用药", "维生素B6片 药物分类\n化学药品", "维生素B6片 药物剂量\n10mg", "维生素B6片 孕妇禁忌\n孕妇接受大量维生素B6,可致新生儿产生维生素B6依赖综合症。乳母摄入正常需要量对婴儿无不良影响。", "维生素B6片 适应症\n用于预防和治疗维生素B6缺乏症,如脂溢性皮炎、唇干裂。也可用于减轻妊娠呕吐。", "维生素B6片 禁忌症\n尚不明确。", "维生素B6片 药物用法\n口服。成人,一日1-2片;儿童,一日0.5~1片,连用3周。", "维生素B6片 不良反应\n维生素B6在肾功能正常时几乎不产生毒性。若每天服用200mg,持续30天以上,曾报道可产生维生素B6依赖综合征。每日应用2~6g,持续几个月,可引起严重神经感觉异常,进行性步态不稳至足麻木、手不灵活,停药后可缓解,但仍软弱无力。", "维生素B6片 注意事项\n1 必须按推荐剂量服用,不可超量服用,用药3周后应停药。2 孕妇及哺乳期妇女应在医师指导下使用。3 如服用过量或出现严重不良反应,应立即就医。4 对本品过敏者禁用,过敏体质者慎用。5 本品性状发生改变时禁止使用。6 请将本品放在儿童不能接触的地方。7 儿童必须在成人监护下使用。8 如正在使用其他药品,使用本品前请咨询医师或药师。", "维生素B6片 药物相互作用\n1.小剂量维生素B6(一日5毫克)与左旋多巴合用,可降低后者治疗帕金森病的疗效。但制剂中若含有脱羧酶抑制剂如卡比多巴时,对左旋多巴无影响。 2.氯霉素、盐酸肼酞嗪、异烟肼、青霉胺及免疫抑制剂包括糖皮质激素、环磷酰胺、环孢素等药物可拮抗维生素B6或增强维生素B6经肾排泄,甚至可引起贫血或周围神经炎。 3.服用雌激素时应增加维生素B6的用量,因为雌激素可使维生素B6在体内的活性降低。4.如与其他药物同时使用可能会发生药物相互作用,详情请咨询医师或药师。", "维生素B6片 存储方式\n遮光,密闭保存。"]}
9 | {"docs": ["氯雷他定胶囊 英文名\nLoratadine", "氯雷他定胶囊 用药科室\n呼吸科用药", "氯雷他定胶囊 药物分类\n化学药品", "氯雷他定胶囊 药物剂量\n10毫克", "氯雷他定胶囊 适应症\n用于缓解过敏性鼻炎有关的症状,如喷嚏、流涕、鼻痒、鼻塞以及眼部痒及烧灼感。口服药物后,鼻和眼部症状及体征得以迅速缓解。亦适用于缓解慢性荨麻疹、瘙痒性皮肤病及其他过敏性皮肤病的症状及体征。", "氯雷他定胶囊 禁忌症\n特异体质的病人禁用。", "氯雷他定胶囊 药物用法\n口服。成人及12岁以上儿童:一日1次,一次1粒(10毫克)。 2~12岁儿童:体重>30公斤:一日1次,一次1粒(10毫克)。", "氯雷他定胶囊 不良反应\n在每天10mg的推荐剂量下,本品未见明显的镇静作用。常见不良反应有乏力、头痛、嗜睡、口干、胃肠道不适包括恶心、胃炎以及皮疹等。罕见不良反应有脱发、过敏反应、肝功能异常、心动过速及心悸等。", "氯雷他定胶囊 注意事项\n1 严重肝功能不全的患者请在医生指导下使用。 \n2 妊娠期及哺乳期妇女慎用。\n3 在做皮试前约48小时左右应中止使用本品,因抗组胺药能阻止或降低皮试的阳性反应发生。\n4 对本品过敏者禁用,过敏体质者慎用。\n5 本品性状发生改变时禁止使用。\n6 请将本品放在儿童不能接触的地方。\n7 儿童必须在成人监护下使用。\n8 如正在使用其他药品,使用本品前请咨询医师或药师。", "氯雷他定胶囊 药物相互作用\n同时服用酮康唑、大环内酯类抗生素、西咪替丁、茶碱等药物,会提高氯雷他定在血浆中的浓度,应慎用。其他已知能抑制肝脏代谢的药物,在未明确与氯雷他定相互作用前应谨慎合用。", "氯雷他定胶囊 存储方式\n密封,遮光。"]}
10 | {"docs": ["板蓝根颗粒 药物分类\n中药", "板蓝根颗粒 药物剂量\n(1)每袋装5g(相当于饮片7g)(2)每袋装10g(相当于饮片14g)(3)每袋装3g(无蔗糖,相当于饮片7g)"]}
11 | {"docs": ["氯雷他定片 英文名\nLoratadine Tablets", "氯雷他定片 用药科室\n感冒用药", "氯雷他定片 药物分类\n化学药品", "氯雷他定片 药物剂量\n10毫克", "氯雷他定片 儿童禁忌\n12岁以下儿童应用本品的安全性尚未确定。", "氯雷他定片 孕妇禁忌\n孕妇慎用。服药期宜停止哺乳。", "氯雷他定片 老年人禁忌\n肝肾功能轻中度受损时,对本药的代谢和排泄无明显的影响,所以老年患者用药量与成人相同。", "氯雷他定片 适应症\n用于缓解过敏性鼻炎有关的症状,如喷嚏、流涕、鼻痒、鼻塞(鼻塞是耳鼻咽喉科常见的症状之一,最常见的原因包括鼻炎,鼻中隔偏曲,鼻息肉,鼻窦炎等。理论上来说,鼻塞都可以通过不同的治疗方法进行解决。)以及眼部痒及烧灼感。口服药物后,鼻和眼部症状及体征得以迅速缓解。亦适用于缓解慢性荨麻疹、瘙痒性皮肤病及其他过敏性皮肤病的症状及体征。", "氯雷他定片 禁忌症\n对本品中的成分过敏或特异体质的病人禁用。", "氯雷他定片 药物用法\n1. 成人及12岁以上儿童:一日1次,一次1片(10毫克)。\n2. 2~12岁儿童:\n体重>30公斤:一日1次,一次1片(10毫克);\n体重≤30公斤:一日1次,一次半片(5毫克)。", "氯雷他定片 不良反应\n主要包括头痛,嗜睡,疲乏,口干视觉模糊,血压降低或升高,心悸,晕厥,运动机能亢进,肝功能改变,黄疸,肝炎,肝坏死,脱发,癫痫发作,乳房肿大,多形性红斑及全身性过敏反应。", "氯雷他定片 注意事项\n1.严重肝功能不全的患者请在医生指导下使用。\n2.妊娠期及哺乳期妇女慎用。\n3.在作皮试前的约48小时左右应中止使用本品,因抗组胺药能阻止或降低皮试的阳性反应发生。(皮试是皮肤(或皮内)敏感试验的简称。某些药物在临床使用过程中容易发生过敏反应,如青霉素、链霉素、细胞色素C等,常见的过敏反应包括皮疹、荨麻疹、皮炎、发热、血管神经性水肿、哮喘、过敏性休克等,其中以过敏性休克最为严重,甚至可导致死亡。)\n4.对本品过敏者禁用,过敏体质者慎用。\n5.本品性状发生改变时禁止使用。\n6.请将本品放在儿童不能接触的地方。\n7.儿童必须在成人监护下使用。\n8.如正在使用其他药品,使用本品前请咨询医师或药师。", "氯雷他定片 药物相互作用\n1.同时服用酮康唑、大环内酯类抗生素、西咪替丁、茶碱等药物,会提高氯雷他定在血浆中的浓度,应慎用。其他已知能抑制肝脏代谢的药物,在未明确与氯雷他定相互作用前应谨慎合用。\n2.抑制肝药物代谢酶功能的药物能使本品的代谢减慢。每日同服酮康唑400mg,可使氯雷他定及其活性代谢物去羧乙基氯雷他定的血浆浓度升高,但未观察到心电图改变。与大环内酯类抗生素、西咪替丁、茶碱等药物并用也可抑制氯雷他定的代谢。\n3.如正在服用其它药品,使用本品前请咨询医师或药师。", "氯雷他定片 存储方式\n密封保存。"]}
12 | {"docs": ["卡马西平片 英文名\nCarbamazepine", "卡马西平片 用药科室\n脑神经科用药", "卡马西平片 药物分类\n化学药品", "卡马西平片 药物剂量\n0.1g", "卡马西平片 儿童禁忌\n本品可用于各年龄段儿童,具体参考[用法用量]。", "卡马西平片 孕妇禁忌\n本品能通过胎盘,是否致畸尚不清楚,妊娠早期需慎用;\n本品能分泌入乳汁,约为血药浓度60%,哺乳期妇女不宜应用。", "卡马西平片 老年人禁忌\n老年患者对本品敏感者多,常可引起认知功能障碍、激越、不安、、\n焦虑、精神错乱、房室传导阻滞或心动过缓,也可引起再障。", "卡马西平片 适应症\n1.复杂部分性发作(亦称精神运动性发作或颞叶癫癎)、全身强直-阵孪性发作、上述两种混合性发作或其他部分性或全身性发作;对典型或不典型失神发作、肌阵孪或失神张力发作无效。2.三叉神经痛和舌咽神经痛发作,亦用作三叉神经痛缓解后的长期预防性用药。也可用于脊髓痨和多发性硬化、糖尿病性周围性神经痛、患肢痛和外伤后神经痛以及疱疹后神经痛。3.预防或治疗躁狂-抑郁症;对锂或抗精神病药或抗抑郁药无效的或不能耐受的躁狂-抑郁症,可单用或与锂盐和其它抗抑郁药合用。4.中枢性部分性尿崩症,可单用或氯磺丙脲或氯贝丁酯等合用。5.", "卡马西平片 禁忌症\n禁用:有房室传导阻滞,血清铁严重异常、骨髓抑制、严重肝功能不全等病史者。", "卡马西平片 药物用法\n成人常用量1.抗惊厥,开始一次0.1g,一日2~3次;第二日后每日增加0.1g,直到出现疗效为止;维持量根据调整至最低有效量,分次服用;注意个体化,最高量每日不超过1.2g。2.镇痛,开始一次0.1g,一日2次;第二日后每隔一日增加0.1~0.2g,直到疼痛缓解,维持量每日0.4~0.8g,分次服用;最高量每日不超过1.2g。3.尿崩症,单用时一日0.3~0.6g,如与其他抗利尿药合用,每日0.2~0.4g,分3次服用。4.抗燥狂或抗精神病,开始每日0.2~0.4g,每周逐渐增加至最大量1.6g,分3~4次服用。", "卡马西平片 不良反应\n1.较常见的不良反应是中枢神经系统的反应,表现为视力模糊、复视、眼球震颤。2.因刺激抗利尿激素分泌引起水的潴留和低钠血症(或水中毒),发生率约10~15%。3.较少见的不良反应有变态反应,Stevens-Johnson综合症或中毒性表皮坏死溶解症、皮疹、荨麻疹、瘙痒;儿童行为障碍,严重腹泻,红斑狼疮样综合症(荨麻疹、瘙痒、皮疹、发热、咽喉痛、骨或关节痛、乏力)。4.罕见的不良反应有腺体病,心律失常或房室传导阻滞(老年人尤其注意),骨髓抑制,中枢神经系统中毒(语言困难、精神不安、耳鸣、颤、幻视),过敏性肝炎,低钙血症,直接影响骨代谢导致骨质疏松,肾脏中毒,周围神经炎,急性尿紫质病,栓塞性脉管炎,过敏性肺炎,急性间歇性卟啉病,可致甲状腺功能减退。因注意有一例合并无菌性脑膜炎的肌阵孪性癫癎患者,接受本品治疗后引起脑膜炎复发。偶见粒细胞减少,可逆性血小板减少,再障,中毒性肝炎。", "卡马西平片 注意事项\n1.与三环类抗抑郁药有交叉过敏反应。2.用药期间注意检查:全血细胞检查(包括血小板、网织红细胞及血清铁,应经常复查达2~3年),尿常规,肝功能,眼科检查;卡马西平血药浓度测定。3.一般疼痛不要用本品。4.糖尿病人可能引起尿糖增加,应注意。5.癫癎患者不能突然撤药。6.已用其他抗癫癎药的病人,本品用量应逐渐递增,治疗4周后可能需要增加剂量,避免自身诱导所致血药浓度下降。7.下列情况应停药:肝中毒或骨髓抑制症状出现,心血管系统不良反应或皮疹出现。8.用于特异性疼痛综合征止痛时,如果疼痛完全缓解,应每月减量至停药。9.饭后服用可减少胃肠反应,漏服时应尽快补服,不可一次服双倍量,可一日内分次补足。10.下列情况应慎用:乙醇中毒,心脏损害,冠心病,糖尿病,青光眼,对其他药物有血液反应史者(易诱发骨髓抑制),肝病,抗利尿激素分泌异常或其他内分泌紊乱,尿潴留,肾病。\n【孕妇及哺乳期妇女用药】本品能通过胎盘,是否致畸尚不清楚,妊娠早期需慎用;本品能分泌入乳汁,约为血药浓度60%,哺乳期妇女不宜应用。", "卡马西平片 药物相互作用\n1.与对乙酰氨基酚合用,尤其是单次超量或长期大量,肝脏中毒的危险增加,有可能使后者疗效降低。2.与香豆素类抗凝药合用,由于本品的肝酶的正诱导作用,使抗凝药的血浓度降低,半衰期缩短,抗凝效应减弱,应测定凝血酶原时间而调整药量。3.与碳酸酐酶抑制药合用,骨质疏松的危险增加。4.由于本品的肝酶诱导作用,与氯磺丙脲、氯贝丁酯(安妥明)、去氨加压素(desmopressin)、赖氨加压素(lypressin)、垂体后叶素、加压素等合用,可加强抗利尿作用,合用的各药都需减量。5.与含雌激素的避孕药、环孢素、洋地黄类(可能地高辛除外)、雌激素、左旋甲状腺素或奎尼丁合用时,由于卡马西平对肝代谢酶的正诱导,这些药的效应都会降低,用量应作调整,改用仅含孕激素(黄体酮)的口服避孕药。与口服避孕药合用可能出现阴道大出血。6.与多西环素(强力霉素)合用,后者的血药浓度可能降低,必要时需要调整用量。7.红霉素与醋竹桃霉素(troleandomycin)以及右丙氧芬(detropropoxyphene)可抑制卡马西平的代谢,引起后者血药浓度的升高,出现毒性反应。8.氟哌啶醇、洛沙平、马普替林、噻吨类或三环类抗抑郁药可增强卡马西平的代谢,引起后者血药浓度升高,出现毒性反应。9.锂盐可以降低卡马西平的抗利尿作用。10.与单胺氧化酶(MAO)抑制合用,可引起高热或(和)高血压危象、严重惊厥甚至死亡,两药应用至少要间隔14天。当卡马西平用作抗惊厥剂时,MAO抑制药可以改变癫癎发作的类型。11.卡马西平可以降低诺米芬辛(nomifensine)的吸收并加快其消除。12、苯巴比妥和苯妥英加速卡马西平的代谢,可将卡马西平的t1/2降至9~10小时。", "卡马西平片 存储方式\n密封。"]}
13 | {"docs": ["头孢克肟片 英文名\nCefixime Tablets", "头孢克肟片 用药科室\n感染科用药", "头孢克肟片 药物分类\n化学药品", "头孢克肟片 适应症\n本品适用于敏感菌所致的咽炎、扁桃体炎、急性支气管炎和慢性支气管炎急性发作、中耳炎、尿路感染、单纯性淋病(宫颈炎或尿道炎)等。", "头孢克肟片 禁忌症\n对本品或头孢菌素类抗生素有过敏史者禁用。", "头孢克肟片 药物用法\n成人每日400mg,儿童每日8mg/kg,可单次或分2次口服。儿童体重≥50kg或年龄≥12岁时用成人剂量。治疗单纯性淋病时宜400mg单剂疗法。治疗单纯性淋病时宜400mg单剂疗法。肾功能不全的患者其肌酐清除率(CCr)为21~60ml/min并进行血液透析者给标准剂量的50%,即每日给药300mg; CCr≤20ml/min并进行腹膜透析者给标准剂量的50%,即每日给药200mg。", "头孢克肟片 不良反应\n头孢克肟不良反应大多短暂而轻微。最常见者为胃肠道反应,其中腹泻16%、大便次数增多6%、腹痛3%、恶心7%、消化不良3%、腹胀4%;发生率低于2%的不良反应有皮疹、荨麻疹、药物热、瘙痒、头痛、头昏。实验室异常表现为一过性ALT、AST、ALP、LDH、胆红素、BUN、Cr 升高,血小板和白细胞计数一过性减少和嗜酸性粒细胞增多,直接Coombs试验阳性等。", "头孢克肟片 注意事项\n1. 对头孢菌素类抗生素有过敏史者禁用。肠炎患者慎用,6月以下儿童不宜应用。过去有青霉素过敏休克病史的患者慎用本品,因亦有发生过敏性休克的可能。\n2. 肾功能不全者血清半衰期延长,须调整给药剂量。\n3. 相同剂量混悬剂与片剂服用后以前者为高。血药浓度以前者为高。\n4. 治疗化脓性链球菌感染疗程至少需10天。\n5.中耳炎患者宜用混悬剂治疗。", "头孢克肟片 药物相互作用\n1.本品与下列药物有配伍禁忌;硫酸阿米卡星、庆大霉素、卡那霉素、妥布霉素。新霉素、盐酸金霉素、盐酸四环素、盐酸土霉素、粘菌素甲磺酸钠、硫酸多粘菌素B、葡萄糖酸红霉素、乳糖酸红霉素、林可霉素、磺胺异恶唑、氨茶碱、可溶性巴比妥、氯化钙、葡萄糖酸钙、盐酸苯海拉明的其他抗组胺药、利多卡因、去甲肾上腺素、间羟胺、哌甲酯、琥珀胆碱等。偶亦可能与下列药品发生配伍禁 J忌:青霉素、甲氧西林、琥珀酸氢化可的松、苯妥英钠。丙氯拉嗪(Prochlorperazine)、维生素B族和维生素C、水解蛋白。\n2.呋塞米、依他尼酸、布美他尼等强利尿药,卡氮芥、链佐星(streptozocin)等抗肿瘤药以及氨基糖苷类抗生素与本品合用有增加肾毒性的可能\n3. 棒酸可增加本品对某些因产生B内酰胺酶而对之耐药的革兰氏阴性杆菌的抗菌活性。", "头孢克肟片 存储方式\n密封,置阴凉干燥处。"]}
14 | {"docs": ["小儿善存片 英文名\nCentrum Junior Tablets", "小儿善存片 用药科室\n儿科用药", "小儿善存片 药物分类\n化学药品", "小儿善存片 适应症\n本品为维生素及矿物质药品。用于3—12岁儿童维生素和矿物质的补充。", "小儿善存片 禁忌症\n慢性肾功能衰竭、高钙血症、高磷血症伴肾性佝偻病患者禁用。苯丙酮尿症患者禁用。", "小儿善存片 药物用法\n口服,一日1片", "小儿善存片 不良反应\n偶见胃部不适。", "小儿善存片 注意事项\n1 严格按规定的剂量服用,需要大量服用时,请咨询医师或药师。\n2 如服用过量或出现严重不良反应,应立即就医。\n3 对本品过敏者禁用,过敏体质者慎用。\n4 本品性状发生改变时禁止使用。\n5 请将本品放在儿童不能接触的地方。\n6 儿童必须在成人监护下使用。\n7 如正在使用其他药品,使用本品前请咨询医师或药师。", "小儿善存片 药物相互作用\n1.抗酸药可影响本品中维生素A的吸收,故不应同服。\n2.不应与含有大量镁、钙的药物合用,以免引起高镁、高钙血症。\n3.如与其他药物同时使用可能会发生药物相互作用,详情请咨询医师或药师。", "小儿善存片 存储方式\n遮光密封。"]}
15 | {"docs": ["木糖醇颗粒 英文名\nXylitol Granules", "木糖醇颗粒 用药科室\n内分泌科用药", "木糖醇颗粒 药物分类\n化学药品", "木糖醇颗粒 药物剂量\n10g:9.85g", "木糖醇颗粒 适应症\n用作糖尿病患者的糖的代用品。", "木糖醇颗粒 禁忌症\n1 对本品过敏者禁用.\n2 低血糖患者禁用.", "木糖醇颗粒 药物用法\n口服。成人一次1袋,一日3-5次。", "木糖醇颗粒 不良反应\n初服时可有肠鸣、腹胀、腹泻等症状。适当减少剂量,可减少不良反应。", "木糖醇颗粒 注意事项\n1 儿童用量请咨询医师或药师。2 本品不宜过量服用。3 对本品过敏者禁用,过敏体质者慎用。4 本品性状发生改变时禁止使用。5 请将本品放在儿童不能接触的地方。6 儿童必须在成人监护下使用。7 如正在使用其他药品,使用本品前请咨询医师或药师。", "木糖醇颗粒 药物相互作用\n如与其他药物同时使用可能会发生药物相互作用,详情请咨询医师或药师。", "木糖醇颗粒 存储方式\n密闭,阴凉干燥处保存."]}
--------------------------------------------------------------------------------
/demo_data/lora_demo.xlsx:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JuneYaooo/medical_kb_chatbot/be4ff60c47ffb31ac3052f82b5ed607d7c7ab089/demo_data/lora_demo.xlsx
--------------------------------------------------------------------------------
/environment.yml:
--------------------------------------------------------------------------------
1 | name: kb_chat
2 | channels:
3 | - defaults
4 | dependencies:
5 | - _libgcc_mutex=0.1=main
6 | - _openmp_mutex=5.1=1_gnu
7 | - ca-certificates=2023.05.30=h06a4308_0
8 | - ld_impl_linux-64=2.38=h1181459_1
9 | - libffi=3.4.4=h6a678d5_0
10 | - libgcc-ng=11.2.0=h1234567_1
11 | - libgomp=11.2.0=h1234567_1
12 | - libstdcxx-ng=11.2.0=h1234567_1
13 | - ncurses=6.4=h6a678d5_0
14 | - openssl=3.0.9=h7f8727e_0
15 | - pip=23.1.2=py39h06a4308_0
16 | - python=3.9.16=h955ad1f_3
17 | - readline=8.2=h5eee18b_0
18 | - setuptools=67.8.0=py39h06a4308_0
19 | - sqlite=3.41.2=h5eee18b_0
20 | - tk=8.6.12=h1ccaba5_0
21 | - wheel=0.38.4=py39h06a4308_0
22 | - xz=5.4.2=h5eee18b_0
23 | - zlib=1.2.13=h5eee18b_0
24 | - pip:
25 | - accelerate==0.18.0
26 | - aiofiles==23.1.0
27 | - aiohttp==3.8.4
28 | - aiosignal==1.3.1
29 | - altair==5.0.1
30 | - antlr4-python3-runtime==4.9.3
31 | - anyio==3.7.0
32 | - argilla==1.11.0
33 | - astor==0.8.1
34 | - async-timeout==4.0.2
35 | - attrdict==2.0.1
36 | - attrs==23.1.0
37 | - azure-core==1.27.1
38 | - babel==2.12.1
39 | - backoff==2.2.1
40 | - bce-python-sdk==0.8.83
41 | - beautifulsoup4==4.12.2
42 | - bitsandbytes==0.39.1
43 | - blinker==1.6.2
44 | - brotli==1.0.9
45 | - cachetools==5.3.1
46 | - certifi==2023.5.7
47 | - cffi==1.15.1
48 | - chardet==5.1.0
49 | - charset-normalizer==3.1.0
50 | - click==8.1.3
51 | - cmake==3.26.4
52 | - coloredlogs==15.0.1
53 | - commonmark==0.9.1
54 | - contourpy==1.1.0
55 | - cpm-kernels==1.0.11
56 | - cryptography==41.0.1
57 | - cssselect==1.2.0
58 | - cssutils==2.7.1
59 | - cycler==0.11.0
60 | - cython==0.29.35
61 | - dataclasses-json==0.5.8
62 | - decorator==5.1.1
63 | - deprecated==1.2.14
64 | - dill==0.3.6
65 | - effdet==0.4.1
66 | - et-xmlfile==1.1.0
67 | - exceptiongroup==1.1.1
68 | - faiss-cpu==1.7.4
69 | - fastapi==0.95.1
70 | - ffmpy==0.3.0
71 | - filelock==3.12.2
72 | - filetype==1.2.0
73 | - fire==0.5.0
74 | - flask==2.3.2
75 | - flask-babel==3.1.0
76 | - flatbuffers==23.5.26
77 | - fonttools==4.40.0
78 | - frozenlist==1.3.3
79 | - fsspec==2023.6.0
80 | - future==0.18.3
81 | - gevent==22.10.2
82 | - geventhttpclient==2.0.2
83 | - gradio==3.28.3
84 | - gradio-client==0.2.7
85 | - greenlet==2.0.2
86 | - grpcio==1.56.0
87 | - h11==0.14.0
88 | - httpcore==0.16.3
89 | - httpx==0.23.3
90 | - huggingface-hub==0.15.1
91 | - humanfriendly==10.0
92 | - icetk==0.0.7
93 | - idna==3.4
94 | - imageio==2.31.1
95 | - imgaug==0.4.0
96 | - importlib-metadata==6.7.0
97 | - importlib-resources==5.12.0
98 | - iopath==0.1.10
99 | - itsdangerous==2.1.2
100 | - jinja2==3.1.2
101 | - joblib==1.2.0
102 | - jsonschema==4.17.3
103 | - kiwisolver==1.4.4
104 | - langchain==0.0.174
105 | - layoutparser==0.3.4
106 | - lazy-loader==0.2
107 | - linkify-it-py==2.0.2
108 | - lit==16.0.6
109 | - lmdb==1.4.1
110 | - lxml==4.9.2
111 | - markdown==3.4.3
112 | - markdown-it-py==2.2.0
113 | - markupsafe==2.1.3
114 | - marshmallow==3.19.0
115 | - marshmallow-enum==1.5.1
116 | - matplotlib==3.7.1
117 | - mdit-py-plugins==0.3.3
118 | - mdurl==0.1.2
119 | - monotonic==1.6
120 | - mpmath==1.3.0
121 | - msg-parser==1.2.0
122 | - multidict==6.0.4
123 | - multiprocess==0.70.14
124 | - mypy-extensions==1.0.0
125 | - networkx==3.1
126 | - nltk==3.8.1
127 | - numexpr==2.8.4
128 | - numpy==1.23.5
129 | - nvidia-cublas-cu11==11.10.3.66
130 | - nvidia-cuda-cupti-cu11==11.7.101
131 | - nvidia-cuda-nvrtc-cu11==11.7.99
132 | - nvidia-cuda-runtime-cu11==11.7.99
133 | - nvidia-cudnn-cu11==8.5.0.96
134 | - nvidia-cufft-cu11==10.9.0.58
135 | - nvidia-curand-cu11==10.2.10.91
136 | - nvidia-cusolver-cu11==11.4.0.1
137 | - nvidia-cusparse-cu11==11.7.4.91
138 | - nvidia-nccl-cu11==2.14.3
139 | - nvidia-nvtx-cu11==11.7.91
140 | - olefile==0.46
141 | - omegaconf==2.3.0
142 | - onnx==1.12.0
143 | - onnxruntime==1.15.1
144 | - openapi-schema-pydantic==1.2.4
145 | - opencv-contrib-python==4.6.0.66
146 | - opencv-python==4.6.0.66
147 | - openpyxl==3.1.2
148 | - opt-einsum==3.3.0
149 | - orjson==3.9.1
150 | - packaging==23.1
151 | - paddle-bfloat==0.1.7
152 | - paddleocr==2.6.1.3
153 | - paddlepaddle==2.4.2
154 | - pandas==1.5.3
155 | - pdf2docx==0.5.6
156 | - pdf2image==1.16.3
157 | - pdfminer-six==20221105
158 | - pdfplumber==0.9.0
159 | - peft==0.3.0
160 | - pillow==9.5.0
161 | - portalocker==2.7.0
162 | - premailer==3.10.0
163 | - protobuf==3.18.3
164 | - psutil==5.9.5
165 | - pyclipper==1.3.0.post4
166 | - pycocotools==2.0.6
167 | - pycparser==2.21
168 | - pycryptodome==3.18.0
169 | - pydantic==1.10.9
170 | - pydub==0.25.1
171 | - pygments==2.15.1
172 | - pymupdf==1.20.2
173 | - pynvml==11.5.0
174 | - pypandoc==1.11
175 | - pyparsing==3.1.0
176 | - pypinyin==0.48.0
177 | - pyrsistent==0.19.3
178 | - pytesseract==0.3.10
179 | - python-dateutil==2.8.2
180 | - python-docx==0.8.11
181 | - python-magic==0.4.27
182 | - python-multipart==0.0.6
183 | - python-pptx==0.6.21
184 | - python-rapidjson==1.10
185 | - pytz==2023.3
186 | - pywavelets==1.4.1
187 | - pyyaml==6.0
188 | - rapidfuzz==3.1.1
189 | - rarfile==4.0
190 | - regex==2023.6.3
191 | - requests==2.28.2
192 | - rfc3986==1.5.0
193 | - rich==13.0.1
194 | - ruamel-yaml==0.17.32
195 | - ruamel-yaml-clib==0.2.7
196 | - safetensors==0.3.1
197 | - scikit-image==0.21.0
198 | - scikit-learn==1.2.2
199 | - scipy==1.11.0
200 | - semantic-version==2.10.0
201 | - sentence-transformers==2.2.2
202 | - sentencepiece==0.1.99
203 | - shapely==2.0.1
204 | - six==1.16.0
205 | - sniffio==1.3.0
206 | - soupsieve==2.4.1
207 | - sqlalchemy==2.0.17
208 | - starlette==0.26.1
209 | - sympy==1.12
210 | - tabulate==0.9.0
211 | - tenacity==8.2.2
212 | - termcolor==2.3.0
213 | - threadpoolctl==3.1.0
214 | - tifffile==2023.4.12
215 | - timm==0.9.2
216 | - tokenizers==0.13.3
217 | - toolz==0.12.0
218 | - torch==2.0.1
219 | - torchvision==0.15.2
220 | - tqdm==4.65.0
221 | - transformers==4.29.1
222 | - triton==2.0.0
223 | - tritonclient==2.34.0
224 | - typer==0.9.0
225 | - typing-extensions==4.6.3
226 | - typing-inspect==0.9.0
227 | - tzdata==2023.3
228 | - uc-micro-py==1.0.2
229 | - unstructured==0.7.9
230 | - unstructured-inference==0.5.1
231 | - urllib3==1.26.16
232 | - uvicorn==0.21.1
233 | - visualdl==2.5.0
234 | - wand==0.6.11
235 | - websockets==11.0.3
236 | - werkzeug==2.3.6
237 | - wrapt==1.14.1
238 | - x2paddle==1.4.1
239 | - xlrd==2.0.1
240 | - xlsxwriter==3.1.2
241 | - yarl==1.9.2
242 | - zipp==3.15.0
243 | - zope-event==5.0
244 | - zope-interface==6.0
245 |
--------------------------------------------------------------------------------
/finetune/pulse/configs/lora_config_bloom.json:
--------------------------------------------------------------------------------
1 | {
2 | "lora_r": 8,
3 | "lora_alpha": 32,
4 | "lora_dropout": 0.05,
5 | "lora_target_modules": [
6 | "query_key_value"
7 | ]
8 | }
9 |
--------------------------------------------------------------------------------
/finetune/pulse/convert_to_conv_data.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import json
3 | import datetime
4 | import os
5 |
6 | '''
7 | orig_data: [
8 | {
9 | "instruction": "##结构化任务##根据下文中信息,判断主动脉根部内径是什么?请提取文中对应的值",
10 | "input": "检查途径:经体表 图像等级:乙 检查项目:二维 M型 彩色 多普勒(脉冲式 连续式)一、M型主要测值(单位mm): 名称 测量值 正常参考值 主动脉根部内径 39 20-37 左房内径 36 19-40 左室舒张末期内径 47 35-56 左室收缩末期内径 28 20-37 室间隔厚度 11 6-11 左室后壁厚度 10 6-11 二、二维超声心动图描述:1.各房室无明显扩大,主动脉根部稍增宽。2.各心瓣膜未见明显增厚,开放不受限。3.左室壁不增厚,静息状态下左室各节段收缩活动未见明显异常。三、彩色多普勒超声描述:1.各心瓣膜未见明显异常反流。2.舒张期经二尖瓣口血流频谱:E/A<1。舒张期经二尖瓣口血流:E=54cm/s,A=84cm/s。3.房、室间隔水平未见明显异常分流信号。四、左心功能测定: 名称 测量值 左室舒张末期容量(ml) 102 左室收缩末期容量(ml) 29 每搏输出量(ml) 74 左室短轴缩短率(%) 41 左室射血分数(%) 72 五、组织多普勒检查: 二尖瓣瓣环水平:室间隔侧: e'=6cm/s, E/e'=9。 左室侧壁: e'=9cm/s, E/e'=6。",
11 | "output": "39"
12 | },
13 | {
14 | "instruction": "##结构化任务##根据下文中信息,判断主动脉根部内径是什么?请提取文中对应的值",
15 | "input": "检查途径:经体表 图像等级:乙 检查项目:二维 M型 彩色 多普勒(脉冲式 连续式)一、M型主要测值(单位mm): 名称 测量值 正常参考值 主动脉根部内径 38 20-37 左房内径 44 19-40 左室舒张末期内径 52 35-56 左室收缩末期内径 33 20-37 室间隔厚度 13 6-11 左室后壁厚度 13 6-11 二、二维超声心动图描述:1.左房增大,主动脉根部内径增宽。2.各心瓣膜未见明显增厚,开放不受限。3.左室壁增厚,静息状态下左室各节段收缩活动未见明显异常。三、彩色多普勒超声描述:1.各心瓣膜示轻度微主动脉瓣反流。2.舒张期经二尖瓣口血流频谱:E/A<1。舒张期经二尖瓣口血流:E=63cm/s,A=91cm/s。3.房、室间隔水平未见明显异常分流信号。四、左心功能测定: 名称 测量值 左室舒张末期容量(ml) 129 左室收缩末期容量(ml) 45 每搏输出量(ml) 84 左室短轴缩短率(%) 36 左室射血分数(%) 65 五、组织多普勒检查: 二尖瓣瓣环水平:室间隔侧: e'=5cm/s, E/e'=12.6。 左室侧壁: e'=13cm/s, E/e'=4.9。",
16 | "output": "38"
17 | }]
18 | convert: {
19 | "id": xxx,
20 | "conversations":[
21 | {"from": "human", "value": "题目:小明买了一支钢笔,花费了5元,又买了一本书,花费8元,现在他手里还有10元钱,他手上原来有多少钱?"},
22 | {"from": "assistant", "value": "\n令小明手上原来有的钱为X元。根据题目描述,得出以下方程式:\nX - 5 - 8 = 10\n化简可得:\nX = 23\n因此,小明手上原来有23元钱。"},
23 | ]
24 | }
25 | '''
26 |
27 | def main():
28 | parser = argparse.ArgumentParser()
29 | parser.add_argument(
30 | "--orig_data",
31 | )
32 | parser.add_argument(
33 | "--write_data",
34 | )
35 | parser.add_argument(
36 | "--dataset_name",
37 | )
38 | args = parser.parse_args()
39 | f_write = open(args.write_data,"w")
40 | with open(args.orig_data, 'r', encoding="utf-8") as f:
41 | datas = json.load(f)
42 | num_id = 1
43 | for data in datas:
44 | conversations = [{"from": "human", "value": data['instruction']+'\n'+data['input']},{"from": "assistant", "value": data['output']}]
45 | uniq_id = data['id'] if "id" in data else args.dataset_name+"-"+str(num_id)
46 | item = {"id":uniq_id, "conversations": conversations}
47 | f_write.write(json.dumps(item, ensure_ascii=False)+"\n")
48 | num_id += 1
49 | f_write.close()
50 |
51 |
52 | if __name__ == "__main__":
53 | main()
--------------------------------------------------------------------------------
/finetune/pulse/finetune.py:
--------------------------------------------------------------------------------
1 |
2 | from transformers.utils import add_start_docstrings
3 | from transformers.trainer_utils import get_last_checkpoint
4 | from transformers.trainer_pt_utils import torch_distributed_zero_first
5 | from transformers import (AutoModelForCausalLM, AutoTokenizer,
6 | HfArgumentParser, LlamaTokenizer, TrainingArguments,
7 | set_seed)
8 | from peft import LoraConfig, get_peft_model, prepare_model_for_int8_training
9 | from datasets import load_dataset
10 | import transformers
11 | import torch
12 | from typing import Optional
13 | from functools import partial
14 | from dataclasses import dataclass, field
15 | import os
16 | import math
17 | import logging
18 | import json
19 | import sys
20 |
21 | from src.utils import get_model_param_count
22 | from src.trainer import MyTrainer as Trainer
23 | from src.sample_generator import generate_and_tokenize_prompt
24 |
25 |
26 | logger = logging.getLogger(__name__)
27 |
28 |
29 | @dataclass
30 | class ModelArguments:
31 | """
32 | Arguments pertaining to which model/config/tokenizer we are going to fine-tune, or train from scratch.
33 | """
34 |
35 | model_name_or_path: Optional[str] = field(
36 | default=None,
37 | metadata={
38 | "help": (
39 | "The model checkpoint for weights initialization.Don't set if you want to train a model from scratch."
40 | )
41 | },
42 | )
43 | config_name: Optional[str] = field(
44 | default=None,
45 | metadata={
46 | "help": "Pretrained config name or path if not the same as model_name"
47 | },
48 | )
49 | tokenizer_name: Optional[str] = field(
50 | default=None,
51 | metadata={
52 | "help": "Pretrained tokenizer name or path if not the same as model_name"
53 | },
54 | )
55 | cache_dir: Optional[str] = field(
56 | default=None,
57 | metadata={
58 | "help": "Where do you want to store the pretrained models downloaded from huggingface.co"
59 | },
60 | )
61 | use_fast_tokenizer: bool = field(
62 | default=True,
63 | metadata={
64 | "help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."
65 | },
66 | )
67 | torch_dtype: Optional[str] = field(
68 | default=None,
69 | metadata={
70 | "help": (
71 | "Override the default `torch.dtype` and load the model under this dtype. If `auto` is passed, the "
72 | "dtype will be automatically derived from the model's weights."
73 | ),
74 | "choices": ["auto", "bfloat16", "float16", "float32"],
75 | },
76 | )
77 | llama: bool = field(default=False, metadata={"help": "Llama model"})
78 |
79 |
80 | @dataclass
81 | class DataArguments:
82 | """
83 | Arguments pertaining to what data we are going to input our model for training and eval.
84 | """
85 |
86 | dataset_name: Optional[str] = field(
87 | default=None,
88 | metadata={
89 | "help": "The name of the dataset to use (via the datasets library)."},
90 | )
91 | dataset_config_name: Optional[str] = field(
92 | default=None,
93 | metadata={
94 | "help": "The configuration name of the dataset to use (via the datasets library)."
95 | },
96 | )
97 | train_file: Optional[str] = field(
98 | default=None, metadata={"help": "The input training data file (a text file)."}
99 | )
100 | validation_file: Optional[str] = field(
101 | default=None,
102 | metadata={
103 | "help": "An optional input evaluation data file to evaluate the perplexity on (a text file)."
104 | },
105 | )
106 | validation_split_percentage: Optional[int] = field(
107 | default=5,
108 | metadata={
109 | "help": "The percentage of the train set used as validation set in case there's no validation split"
110 | },
111 | )
112 |
113 |
114 | @dataclass
115 | @add_start_docstrings(TrainingArguments.__doc__)
116 | class TrainingArguments(TrainingArguments):
117 | model_max_length: int = field(
118 | default=512,
119 | metadata={"help": "Maximum sequence length."},
120 | )
121 | use_lora: bool = field(
122 | default=False,
123 | metadata={"help": "Whether to use LoRA."}
124 | )
125 | use_int8_training: bool = field(
126 | default=False, metadata={"help": "Whether to use int8 training."}
127 | )
128 | lora_config: Optional[str] = field(
129 | default=None,
130 | metadata={"help": "LoRA config file."},
131 | )
132 | ddp_find_unused_parameters: bool = field(
133 | default=False, metadata={"help": "ddp_find_unused_parameters"}
134 | )
135 | gradient_checkpointing: bool = field(
136 | default=False, metadata={"help": "gradient_checkpointing"}
137 | )
138 | # https://discuss.huggingface.co/t/wandb-does-not-display-train-eval-loss-except-for-last-one/9170
139 | evaluation_strategy: str = field(
140 | default="steps", metadata={"help": "wandb bug fix"}
141 | )
142 | save_total_limit: Optional[int] = field(
143 | default=3,
144 | metadata={
145 | "help": "keep saved model less than save_total_limit, delete old checkpoints when save new model"}
146 | )
147 | report_to: str = field(
148 | default="none",
149 | metadata={"help": "places where report the results"}
150 | )
151 |
152 |
153 | def print_rank_0(msg, log_file, rank=0):
154 | if rank <= 0:
155 | with open(log_file, "a") as f:
156 | print(msg)
157 | f.write(msg + "\n")
158 |
159 |
160 | def main():
161 | parser = HfArgumentParser(
162 | (ModelArguments, DataArguments, TrainingArguments)
163 | )
164 | model_args, data_args, training_args = parser.parse_args_into_dataclasses()
165 |
166 | world_size = int(os.environ.get("WORLD_SIZE", 1))
167 | global_rank = torch.distributed.get_rank()
168 | if not os.path.exists(training_args.output_dir):
169 | os.makedirs(training_args.output_dir)
170 | log_file = os.path.join(training_args.output_dir, "print_log.txt")
171 |
172 | # Setup logging
173 | logging.basicConfig(
174 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
175 | datefmt="%m/%d/%Y %H:%M:%S",
176 | handlers=[logging.StreamHandler(sys.stdout)],
177 | )
178 |
179 | if training_args.should_log:
180 | # The default of training_args.log_level is passive, so we set log level at info here to have that default.
181 | transformers.utils.logging.set_verbosity_info()
182 |
183 | log_level = training_args.get_process_log_level()
184 | logger.setLevel(log_level)
185 | transformers.utils.logging.set_verbosity(log_level)
186 | transformers.utils.logging.enable_default_handler()
187 | transformers.utils.logging.enable_explicit_format()
188 |
189 | # Log on each process the small summary:
190 | logger.warning(
191 | f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}, distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
192 | )
193 | logger.info(f"Training/evaluation parameters {training_args}")
194 |
195 | # Detecting last checkpoint.
196 | last_checkpoint = None
197 | if (
198 | os.path.isdir(training_args.output_dir)
199 | and training_args.do_train
200 | and not training_args.overwrite_output_dir
201 | ):
202 | last_checkpoint = get_last_checkpoint(training_args.output_dir)
203 | if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
204 | raise ValueError(
205 | f"Output directory ({training_args.output_dir}) already exists and is not empty. "
206 | "Use --overwrite_output_dir to overcome."
207 | )
208 | elif (
209 | last_checkpoint is not None and training_args.resume_from_checkpoint is None
210 | ):
211 | logger.info(
212 | f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
213 | "the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
214 | )
215 |
216 | # Set seed before initializing model.
217 | set_seed(training_args.seed)
218 |
219 | torch_dtype = (
220 | model_args.torch_dtype
221 | if model_args.torch_dtype in ["auto", None]
222 | else getattr(torch, model_args.torch_dtype)
223 | )
224 | # int8 is not compatible with DeepSpeed (require not to pass device_map)
225 | if training_args.use_int8_training:
226 | print_rank_0(
227 | "int8 is not compatible with DeepSpeed. ",
228 | log_file,
229 | global_rank
230 | )
231 | device_map = (
232 | {"": int(os.environ.get("LOCAL_RANK") or 0)}
233 | if world_size != 1 else "auto"
234 | )
235 | # device_map = "auto"
236 | model = AutoModelForCausalLM.from_pretrained(
237 | model_args.model_name_or_path,
238 | load_in_8bit=True, # xxx: int8 load in
239 | device_map=device_map, # xxx: int8 requires passing device_map
240 | torch_dtype=torch_dtype,
241 | )
242 | else:
243 | model = AutoModelForCausalLM.from_pretrained(
244 | model_args.model_name_or_path,
245 | torch_dtype=torch_dtype,
246 | )
247 |
248 | if model_args.llama:
249 | tokenizer = LlamaTokenizer.from_pretrained(
250 | model_args.model_name_or_path
251 | )
252 | print_rank_0(
253 | "Set the eos_token_id and bos_token_id of LLama model tokenizer",
254 | log_file,
255 | global_rank,
256 | )
257 | tokenizer.eos_token_id = 2
258 | tokenizer.bos_token_id = 1
259 | else:
260 | tokenizer = AutoTokenizer.from_pretrained(
261 | model_args.model_name_or_path
262 | )
263 |
264 | tokenizer.pad_token_id = 0
265 | tokenizer.padding_side = "left" # Allow batched inference
266 |
267 | print_rank_0(
268 | "tokenizer.eos_token_id = {}".format(tokenizer.eos_token_id),
269 | log_file,
270 | global_rank,
271 | )
272 | print_rank_0(
273 | "tokenizer.pad_token_id = {}".format(tokenizer.pad_token_id),
274 | log_file,
275 | global_rank,
276 | )
277 | print_rank_0(
278 | "tokenizer.bos_token_id = {}".format(tokenizer.bos_token_id),
279 | log_file,
280 | global_rank,
281 | )
282 |
283 | # peft model
284 | if training_args.use_lora:
285 | print_rank_0(
286 | "Loading lora config from {}".format(training_args.lora_config),
287 | log_file,
288 | global_rank,
289 | )
290 | lora_config = json.load(open(training_args.lora_config))
291 | print_rank_0(
292 | "Lora config: {}".format(lora_config),
293 | log_file,
294 | global_rank
295 | )
296 | if training_args.use_int8_training:
297 | print_rank_0(
298 | "training_args.use_int8_training!!! (int8 is not compatible with DeepSpeed)",
299 | log_file,
300 | global_rank,
301 | )
302 | model = prepare_model_for_int8_training(model)
303 | config = LoraConfig(
304 | r=lora_config["lora_r"],
305 | lora_alpha=lora_config["lora_alpha"],
306 | target_modules=lora_config["lora_target_modules"],
307 | lora_dropout=lora_config["lora_dropout"],
308 | bias="none",
309 | task_type="CAUSAL_LM",
310 | )
311 |
312 | # "RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn"
313 | if hasattr(model, "enable_input_require_grads"):
314 | model.enable_input_require_grads()
315 | else:
316 |
317 | def make_inputs_require_grad(module, input, output):
318 | output.requires_grad_(True)
319 |
320 | model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
321 |
322 | model = get_peft_model(model, config)
323 | model.print_trainable_parameters()
324 |
325 | if training_args.gradient_checkpointing:
326 | model.gradient_checkpointing_enable()
327 |
328 | # model.is_parallelizable = True
329 | # model.model_parallel = True
330 |
331 | assert os.path.exists(data_args.train_file), "{} file not exists".format(
332 | data_args.train_file
333 | )
334 |
335 | with torch_distributed_zero_first(global_rank):
336 | train_data = load_dataset(
337 | "json",
338 | data_files=data_args.train_file,
339 | cache_dir=model_args.cache_dir
340 | )
341 |
342 | val_data = load_dataset(
343 | "json",
344 | data_files=data_args.validation_file,
345 | cache_dir=model_args.cache_dir
346 | )
347 |
348 | train_data = train_data["train"].shuffle().map(
349 | partial(
350 | generate_and_tokenize_prompt,
351 | training_args.model_max_length,
352 | tokenizer
353 | )
354 | )
355 |
356 | val_data = val_data["train"].shuffle().map(
357 | partial(
358 | generate_and_tokenize_prompt,
359 | training_args.model_max_length,
360 | tokenizer
361 | )
362 | )
363 |
364 | for i in range(2):
365 | print_rank_0(
366 | "Eval tokenized example: {}".format(val_data[i]),
367 | log_file,
368 | global_rank
369 | )
370 | for i in range(2):
371 | print_rank_0(
372 | "Train tokenized example: {}".format(train_data[i]),
373 | log_file,
374 | global_rank
375 | )
376 |
377 | training_nums = len(train_data)
378 | num_gpus = torch.cuda.device_count()
379 |
380 | batch_size = (
381 | training_args.per_device_train_batch_size
382 | * training_args.world_size
383 | * training_args.gradient_accumulation_steps
384 | )
385 | # train steps
386 | t_total = math.ceil(training_nums / batch_size) * \
387 | training_args.num_train_epochs
388 | # eval steps
389 | training_args.eval_steps = max(t_total // 5, 5)
390 | # save steps
391 | training_args.save_steps = training_args.eval_steps
392 | training_args.warmup_steps = (
393 | int(t_total * training_args.warmup_ratio)
394 | if training_args.warmup_ratio > 0.0
395 | else training_args.warmup_steps
396 | )
397 | print_rank_0(
398 | "num_gpus = {}, training_nums = {}, t_total = {}, warmup_steps = {}, eval_steps = {}, save_steps = {}".format(
399 | num_gpus,
400 | training_nums,
401 | t_total,
402 | training_args.warmup_steps,
403 | training_args.eval_steps,
404 | training_args.save_steps,
405 | ),
406 | log_file,
407 | global_rank,
408 | )
409 | print_rank_0(
410 | "val data nums = {}, training_nums = {}, batch_size = {}".format(
411 | len(val_data), training_nums, batch_size
412 | ),
413 | log_file,
414 | global_rank,
415 | )
416 |
417 | # Trainer
418 | # https://github.com/huggingface/transformers/blob/main/src/transformers/training_args.py
419 | # https://github.com/huggingface/transformers/blob/main/src/transformers/data/data_collator.py
420 | # https://github.com/huggingface/transformers/blob/main/src/transformers/trainer.py
421 | # https://www.deepspeed.ai/docs/config-json/
422 | # https://huggingface.co/docs/accelerate/usage_guides/deepspeed
423 | # https://huggingface.co/transformers/v4.10.1/main_classes/deepspeed.html
424 | # https://github.com/tatsu-lab/stanford_alpaca/issues/176
425 | trainer = Trainer(
426 | model=model,
427 | args=training_args,
428 | train_dataset=train_data,
429 | eval_dataset=val_data,
430 | data_collator=transformers.DataCollatorForSeq2Seq(
431 | tokenizer, pad_to_multiple_of=8, return_tensors="pt", padding=True
432 | ),
433 | )
434 |
435 | print_rank_0(
436 | f"Using {training_args.half_precision_backend} half precision backend",
437 | log_file,
438 | global_rank,
439 | )
440 | # Train!
441 | len_dataloader = len(trainer.get_train_dataloader())
442 | num_update_steps_per_epoch = (
443 | len_dataloader // training_args.gradient_accumulation_steps
444 | )
445 |
446 | total_train_batch_size = (
447 | training_args.train_batch_size
448 | * training_args.gradient_accumulation_steps
449 | * training_args.world_size
450 | )
451 | num_examples = trainer.num_examples(trainer.get_train_dataloader())
452 | num_train_samples = num_examples * training_args.num_train_epochs
453 | max_steps = math.ceil(training_args.num_train_epochs * \
454 | num_update_steps_per_epoch)
455 | print_rank_0("***** Running training *****", log_file, global_rank)
456 | print_rank_0(f" Num examples = {num_examples}", log_file, global_rank)
457 | print_rank_0(
458 | f" Num train samples = {num_train_samples}",
459 | log_file,
460 | global_rank
461 | )
462 | print_rank_0(f" world_size = {world_size}", log_file, global_rank)
463 | print_rank_0(
464 | f" Total train batch size (w. parallel, distributed & accumulation) = {total_train_batch_size}",
465 | log_file,
466 | global_rank,
467 | )
468 | print_rank_0(
469 | f" Gradient Accumulation steps = {training_args.gradient_accumulation_steps}",
470 | log_file,
471 | global_rank,
472 | )
473 | print_rank_0(
474 | f" Total optimization steps = {max_steps}",
475 | log_file,
476 | global_rank
477 | )
478 |
479 | print_rank_0(
480 | f" Number of trainable parameters = {get_model_param_count(model, trainable_only=True)}",
481 | log_file,
482 | global_rank,
483 | )
484 |
485 | # https://discuss.huggingface.co/t/what-is-the-purpose-of-use-cache-in-decoder/958/3
486 | model.config.use_cache = False
487 |
488 | trainer.train(resume_from_checkpoint=None)
489 |
490 | print_rank_0(
491 | "\n Training completed!!! If there's a warning about missing keys above, please disregard :)",
492 | log_file,
493 | global_rank,
494 | )
495 |
496 |
497 | if __name__ == "__main__":
498 | main()
499 |
--------------------------------------------------------------------------------
/finetune/pulse/src/sample_generator.py:
--------------------------------------------------------------------------------
1 | import pudb
2 | import copy
3 | from transformers import PreTrainedTokenizer
4 | import json
5 | IGNORE_INDEX = -100
6 |
7 |
8 | def generate_and_tokenize_prompt(model_max_length: int, tokenizer: PreTrainedTokenizer, data_point):
9 | input_ids = []
10 | labels = []
11 | source = data_point["conversations"]
12 | for sentence in source:
13 | sentence_from = sentence["from"].lower()
14 | sentence_value = (
15 | "Human: \n" + sentence["value"] + "\n\nAssistant: \n"
16 | if sentence_from == "human"
17 | else sentence["value"]
18 | ) # https://github.com/LianjiaTech/BELLE/issues/337
19 | # conversation += sentence_value
20 | sentence_ids = tokenizer.encode(
21 | sentence_value, add_special_tokens=False
22 | ) # do not add bos_token_id
23 | label = (
24 | copy.deepcopy(sentence_ids)
25 | if sentence_from != "human"
26 | else [IGNORE_INDEX] * len(sentence_ids)
27 | )
28 | input_ids += sentence_ids
29 | labels += label
30 | # add eos at every end of assistant sentence
31 | if sentence_from != "human":
32 | input_ids += [
33 | tokenizer.eos_token_id
34 | ] # make sure eos_token_id is correct
35 | labels += [tokenizer.eos_token_id]
36 |
37 | input_ids = input_ids[: model_max_length - 1]
38 | labels = labels[: model_max_length - 1]
39 | if all(x == IGNORE_INDEX for x in labels):
40 | labels[18:24] = input_ids[
41 | 18:24
42 | ] # labels can not have all values being -100. 18 and 24 are just random numbers
43 |
44 | attention_mask = [1] * len(input_ids)
45 | tokenized_full_prompt = {
46 | "input_ids": input_ids,
47 | "attention_mask": attention_mask,
48 | "labels": labels,
49 | }
50 | return tokenized_full_prompt
51 |
52 |
53 | def pretrain_generate(model_max_length: int, tokenizer: PreTrainedTokenizer, data_point):
54 | input_ids = tokenizer.encode(data_point['text'])
55 | labels = copy.deepcopy(input_ids)
56 | input_ids += [tokenizer.eos_token_id]
57 | labels += [tokenizer.eos_token_id]
58 | input_ids = input_ids[: model_max_length]
59 | labels = labels[: model_max_length]
60 | return {
61 | "input_ids": input_ids,
62 | "attention_mask": [1] * len(input_ids),
63 | "labels": labels,
64 | }
65 |
66 |
67 | def exam_generate(model_max_length: int, tokenizer: PreTrainedTokenizer, data_point):
68 | template = 'Human: \n{human}\n\nAssistant: \n'
69 | # pudb.set_trace()
70 | input_str = template.format(
71 | human=f'回答下面的{data_point["type"]}题,用json返回答案,包括原因和答案,如{{"reason":..., "answer":...}}\n{data_point["question"]}\n选项:{" ".join(data_point["candidates"])}'
72 | )
73 | input_ids = tokenizer.encode(
74 | input_str,
75 | add_special_tokens=False
76 | )
77 | labels = [IGNORE_INDEX] * len(input_ids)
78 | bot_ids = tokenizer.encode(
79 | json.dumps(
80 | {
81 | 'reason': data_point['reason'],
82 | 'answer': data_point['answer']
83 | }, ensure_ascii=False
84 | ),
85 | add_special_tokens=False
86 | )
87 | input_ids += bot_ids
88 | labels += bot_ids
89 |
90 | input_ids += [tokenizer.eos_token_id]
91 | labels += [tokenizer.eos_token_id]
92 |
93 | input_ids = input_ids[: model_max_length - 1]
94 | labels = labels[: model_max_length - 1]
95 | return {
96 | "input_ids": input_ids,
97 | "attention_mask": [1] * len(input_ids),
98 | "labels": labels,
99 | }
100 |
--------------------------------------------------------------------------------
/finetune/pulse/src/trainer.py:
--------------------------------------------------------------------------------
1 | from peft import PeftModel
2 | from transformers.trainer import *
3 |
4 | from src.utils import get_ds_state_dict
5 | import re
6 |
7 | def remove_last_directory(path):
8 | pattern = r'^(.*)/[^/]+$' # 匹配目录路径和最后一层目录名
9 | match = re.match(pattern, path)
10 | if match:
11 | return match.group(1)
12 | else:
13 | return path
14 |
15 | class MyTrainer(Trainer):
16 | def save_model(self, output_dir: Optional[str] = None, _internal_call: bool = False):
17 | """
18 | Add supports for peft
19 |
20 | Will save the model, so you can reload it using `from_pretrained()`.
21 |
22 | Will only save from the main process.
23 | """
24 |
25 | if output_dir is None:
26 | output_dir = self.args.output_dir
27 |
28 | if is_torch_tpu_available():
29 | self._save_tpu(output_dir)
30 | elif is_sagemaker_mp_enabled():
31 | # Calling the state_dict needs to be done on the wrapped model and on all processes.
32 | os.makedirs(output_dir, exist_ok=True)
33 | state_dict = self.model_wrapped.state_dict()
34 | if self.args.should_save:
35 | self._save(output_dir, state_dict=state_dict)
36 | self._save(remove_last_directory(output_dir), state_dict=state_dict)
37 | if IS_SAGEMAKER_MP_POST_1_10:
38 | # 'user_content.pt' indicates model state_dict saved with smp >= 1.10
39 | Path(os.path.join(output_dir, "user_content.pt")).touch()
40 | elif (
41 | ShardedDDPOption.ZERO_DP_2 in self.args.sharded_ddp
42 | or ShardedDDPOption.ZERO_DP_3 in self.args.sharded_ddp
43 | or self.fsdp is not None
44 | ):
45 | state_dict = self.model.state_dict()
46 |
47 | if self.args.should_save:
48 | self._save(output_dir, state_dict=state_dict)
49 | self._save(remove_last_directory(output_dir), state_dict=state_dict)
50 | elif self.deepspeed:
51 | # This must be called on all ranks in stage 3
52 | if is_deepspeed_zero3_enabled():
53 | state_dict = get_ds_state_dict(self.deepspeed)
54 | else:
55 | # Only run on rank 0 except stage 3
56 | if self.args.should_save:
57 | state_dict = get_ds_state_dict(self.deepspeed)
58 | # this takes care of everything as long as we aren't under zero3
59 | # Only run on rank 0
60 | if self.args.should_save:
61 | # state_dict is available on rank 0
62 | self._save(output_dir, state_dict=state_dict)
63 | self._save(remove_last_directory(output_dir), state_dict=state_dict)
64 |
65 | elif self.args.should_save:
66 | self._save(output_dir)
67 | self._save(remove_last_directory(output_dir))
68 |
69 | # Push to the Hub when `save_model` is called by the user.
70 | if self.args.push_to_hub and not _internal_call:
71 | self.push_to_hub(commit_message="Model save")
72 |
73 | def _save(self, output_dir: Optional[str] = None, state_dict=None):
74 | """
75 | Add supports for peft
76 | """
77 | # If we are executing this function, we are the process zero, so we don't check for that.
78 | output_dir = output_dir if output_dir is not None else self.args.output_dir
79 | os.makedirs(output_dir, exist_ok=True)
80 | logger.info(f"Saving model checkpoint to {output_dir}")
81 | # Save a trained model and configuration using `save_pretrained()`.
82 | # They can then be reloaded using `from_pretrained()`
83 | if not isinstance(self.model, (PreTrainedModel, PeftModel)):
84 | if state_dict is None:
85 | state_dict = self.model.state_dict()
86 |
87 | if isinstance(unwrap_model(self.model), (PreTrainedModel, PeftModel)):
88 | unwrap_model(self.model).save_pretrained(
89 | output_dir, state_dict=state_dict, safe_serialization=self.args.save_safetensors
90 | )
91 | unwrap_model(self.model).save_pretrained(
92 | remove_last_directory(output_dir), state_dict=state_dict, safe_serialization=self.args.save_safetensors
93 | )
94 | else:
95 | logger.info("Trainer.model is not a `PreTrainedModel`, only saving its state dict.")
96 | if self.args.save_safetensors:
97 | safetensors.torch.save_file(state_dict, os.path.join(output_dir, SAFE_WEIGHTS_NAME))
98 | else:
99 | torch.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME))
100 | else:
101 | self.model.save_pretrained(
102 | output_dir, state_dict=state_dict, safe_serialization=self.args.save_safetensors
103 | )
104 |
105 | if self.tokenizer is not None:
106 | self.tokenizer.save_pretrained(output_dir)
107 |
108 | # Good practice: save your training arguments together with the trained model
109 | torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))
110 |
--------------------------------------------------------------------------------
/finetune/pulse/src/utils.py:
--------------------------------------------------------------------------------
1 | import time
2 | from typing import Any, List, Union
3 | from gradio_client import Client
4 | from tqdm import tqdm
5 | from transformers.deepspeed import is_deepspeed_zero3_enabled
6 | import torch
7 | import traceback
8 | try:
9 | from deepspeed.runtime.engine import DeepSpeedEngine
10 | def get_ds_state_dict(ds_engine: DeepSpeedEngine):
11 | """
12 | 如果是zero stage 3,要对所有rank调用,无视掉stage3_gather_16bit_weights_on_model_save参数
13 | """
14 | if ds_engine.zero_optimization_partition_weights():
15 | # consolidation is expensive in time and memory and therefore isn't a default
16 | state_dict = ds_engine._zero3_consolidated_16bit_state_dict()
17 | else:
18 | state_dict = ds_engine.module.state_dict()
19 | return state_dict
20 |
21 |
22 | def get_model_param_count(model: Union[DeepSpeedEngine, torch.nn.Module], trainable_only=False):
23 | """
24 | Calculate model's total param count. If trainable_only is True then count only those requiring grads
25 | """
26 | if is_deepspeed_zero3_enabled() and isinstance(model, DeepSpeedEngine):
27 | def numel(p):
28 | return p.ds_numel
29 |
30 | else:
31 | def numel(p):
32 | return p.numel()
33 |
34 | return sum(numel(p) for p in model.parameters() if not trainable_only or p.requires_grad)
35 | except ModuleNotFoundError:
36 | traceback.print_exc()
37 | DeepSpeedEngine = None
38 |
39 |
40 | def get_model_param_count(model: Union[DeepSpeedEngine, torch.nn.Module], trainable_only=False):
41 | """
42 | Calculate model's total param count. If trainable_only is True then count only those requiring grads
43 | """
44 | if is_deepspeed_zero3_enabled() and isinstance(model, DeepSpeedEngine):
45 | def numel(p):
46 | return p.ds_numel
47 |
48 | else:
49 | def numel(p):
50 | return p.numel()
51 |
52 | return sum(numel(p) for p in model.parameters() if not trainable_only or p.requires_grad)
53 |
54 | def get_ds_state_dict(ds_engine: DeepSpeedEngine):
55 |
56 | return {}
57 |
58 |
59 |
60 | class MultiClient(object):
61 | def __init__(self, worker_addrs) -> None:
62 | self.clients = [Client(addr) for addr in worker_addrs]
63 |
64 | def predict(self, tasks: List[List], max_retries: int = 3) -> List[Any]:
65 | pbar = tqdm(total=len(tasks))
66 | jobs = {
67 | client: (i, client.submit(*(tasks[i]), api_name="/predict"))
68 | for i, client in enumerate(self.clients)
69 | if i < len(tasks)
70 | }
71 | results = {}
72 | retries = {i: 0 for i in range(len(tasks))}
73 |
74 | while jobs:
75 | for client, (i, job) in list(jobs.items()):
76 | if job.done():
77 | pbar.update(1)
78 | del jobs[client]
79 | try:
80 | result = job.result()
81 | results[i] = result
82 | except Exception as e:
83 | print("Job failed with error:", e)
84 | if retries[i] < max_retries:
85 | print("Retrying job...")
86 | retries[i] += 1
87 | new_job = client.submit(
88 | *tasks[i], api_name="/predict")
89 | jobs[client] = (i, new_job)
90 | continue # Skip the rest of the loop
91 | else:
92 | results[i] = None
93 |
94 | if tasks:
95 | new_i = len(results) + len(jobs)
96 | if new_i < len(tasks):
97 | new_task = tasks[new_i]
98 | new_job = client.submit(
99 | *new_task, api_name="/predict")
100 | jobs[client] = (new_i, new_job)
101 | time.sleep(0.1)
102 | pbar.close()
103 |
104 | predicts = [results[i] for i in sorted(results)]
105 |
106 | return predicts
107 |
--------------------------------------------------------------------------------
/finetune/pulse_utils.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 | import json
3 | import datetime,time
4 | import shutil
5 | import os
6 | import re
7 | import subprocess
8 | from sklearn.metrics import classification_report
9 | from pynvml import (nvmlInit, nvmlDeviceGetCount, nvmlDeviceGetHandleByIndex,
10 | nvmlDeviceGetName, nvmlDeviceGetMemoryInfo, nvmlShutdown)
11 | from configs.common_config import *
12 |
13 | model_loaded = False
14 | project_change = False
15 | last_lora_name = ''
16 | max_new_tokens = 1500
17 | generation_config = dict(
18 | temperature=0.001,
19 | top_k=30,
20 | top_p=0.85,
21 | do_sample=False,
22 | num_beams=1,
23 | repetition_penalty=1.2,
24 | max_new_tokens=max_new_tokens
25 | )
26 | def read_excel_file(file_path):
27 | df = pd.read_excel(file_path)
28 | return df
29 |
30 | def save_to_excel(df, file_path):
31 | df.to_excel(file_path, index=False)
32 |
33 | def process_data(training_data_path):
34 | # 读取 Excel 文件
35 | df = pd.read_excel(training_data_path)
36 | log = []
37 | log.append(f'开始处理数据')
38 |
39 | all_data = []
40 | # 遍历每一行数据
41 | for index, row in df.iterrows():
42 | instruction = row['系统指示']
43 | question = row['问题']
44 | answer = row['回答']
45 |
46 | # 创建字典并将数据添加到列表中
47 | data = {"instruction": instruction, "input": question, "output": answer}
48 | all_data.append(data)
49 |
50 | log = '\n'.join(log) # 使用换行符拼接log内容
51 | return all_data, log
52 |
53 |
54 | def get_available_gpu(threshold=20000):
55 | # Initialize NVML
56 | nvmlInit()
57 |
58 | # Get the number of GPU devices
59 | device_count = nvmlDeviceGetCount()
60 |
61 | # Find GPU devices with available memory greater than the threshold
62 | available_gpus = []
63 | for i in range(device_count):
64 | handle = nvmlDeviceGetHandleByIndex(i)
65 | info = nvmlDeviceGetMemoryInfo(handle)
66 | free_memory_mb = info.free / 1024 / 1024
67 |
68 | if free_memory_mb > threshold:
69 | available_gpus.append(i)
70 |
71 | # Shutdown NVML
72 | nvmlShutdown()
73 |
74 | return available_gpus
75 |
76 | def pulse_train_model(model_name, lora_name, training_data_path):
77 | now_str = datetime.datetime.now().strftime('%Y%m%d_%H%M')
78 | print('now_str',now_str)
79 | all_data,log = process_data(training_data_path)
80 | log_file_path = f'data/logs/{now_str}.log' # 定义log文件路径
81 | os.makedirs(os.path.dirname(log_file_path), exist_ok=True) # 创建存储log的文件夹
82 |
83 | with open(log_file_path, 'w', encoding="utf-8") as f:
84 | f.write(log) # 将log内容写入文件
85 | with open(f"data/{lora_name}_dataset.json", "w", encoding="utf-8") as f:
86 | json.dump(all_data, f, indent=4, ensure_ascii=False)
87 | if not os.path.exists('finetune/pulse/data'):
88 | os.makedirs('finetune/pulse/data')
89 | if not os.path.exists('finetune/pulse/logs'):
90 | os.makedirs('finetune/pulse/logs')
91 | shutil.copyfile(f"data/{lora_name}_dataset.json", f"finetune/pulse/data/{lora_name}_dataset.json")
92 |
93 | available_gpus = get_available_gpu(threshold=20000)
94 | print('available_gpus[0]',available_gpus[0])
95 | content = f'''python convert_to_conv_data.py --orig_data data/{lora_name}_dataset.json --write_data data/{lora_name}_dataset_conv.json --dataset_name {lora_name}
96 |
97 | CUDA_VISIBLE_DEVICES={available_gpus[0]} torchrun --nproc_per_node 1 finetune.py --model_name_or_path {llm_model_dict[model_name]["local_model_path"]} --use_lora True --use_int8_training --lora_config configs/lora_config_bloom.json --train_file data/{lora_name}_dataset_conv.json --validation_file data/{lora_name}_dataset_conv.json --per_device_train_batch_size 1 --per_device_eval_batch_size 1 --gradient_accumulation_steps 2 --num_train_epochs 2 --model_max_length 100 --save_strategy "steps" --save_total_limit 3 --learning_rate 3e-4 --weight_decay 0.00001 --warmup_ratio 0.05 --lr_scheduler_type "cosine" --logging_steps 10 --evaluation_strategy "steps" --seed 2048 --gradient_checkpointing True --cache_dir cache/{lora_name} --output_dir output/{lora_name}
98 | '''
99 | sh_file_name = f'finetune/pulse/train_{lora_name}.sh'
100 |
101 | with open(sh_file_name , 'w') as file:
102 | file.write(content)
103 |
104 | # 设置文件可执行权限
105 | os.chmod(sh_file_name , 0o755)
106 | now_str = datetime.datetime.now().strftime('%Y%m%d_%H%M')
107 | print('now_str',now_str)
108 | subprocess.Popen(f"""cd finetune/pulse && . /home/pai/etc/profile.d/conda.sh && conda activate med_llm && nohup sh train_{lora_name}.sh > ./logs/train_{now_str}.log 2>&1 &""", shell=True)
109 | print('finish')
110 | # model.train(training_data_path)
111 | return f'{model_name} on training'
112 |
113 | def stop_train_process():
114 | process = subprocess.Popen('ps -ef | grep finetune.py', shell=True, stdout=subprocess.PIPE)
115 | output, _ = process.communicate()
116 | process.kill()
117 |
118 |
119 |
120 | n = 0
121 | # 解析输出以获取进程ID
122 | print('output',output)
123 | try:
124 | lines = output.decode().split('\n')
125 | for line in lines:
126 | if 'finetune.py' in line:
127 | parts = line.split()
128 | pid = parts[1]
129 | # 杀死进程
130 | subprocess.call(['kill', '-9', pid])
131 | n+=1
132 | except Exception as e:
133 | print('error!!',e)
134 |
135 | return f'停止了{n//2}个进程'
--------------------------------------------------------------------------------
/img/1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JuneYaooo/medical_kb_chatbot/be4ff60c47ffb31ac3052f82b5ed607d7c7ab089/img/1.jpg
--------------------------------------------------------------------------------
/img/2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JuneYaooo/medical_kb_chatbot/be4ff60c47ffb31ac3052f82b5ed607d7c7ab089/img/2.jpg
--------------------------------------------------------------------------------
/img/3.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JuneYaooo/medical_kb_chatbot/be4ff60c47ffb31ac3052f82b5ed607d7c7ab089/img/3.jpg
--------------------------------------------------------------------------------
/img/4.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JuneYaooo/medical_kb_chatbot/be4ff60c47ffb31ac3052f82b5ed607d7c7ab089/img/4.jpg
--------------------------------------------------------------------------------
/loader/__init__.py:
--------------------------------------------------------------------------------
1 | from .image_loader import UnstructuredPaddleImageLoader
2 | from .pdf_loader import UnstructuredPaddlePDFLoader
3 |
--------------------------------------------------------------------------------
/loader/image_loader.py:
--------------------------------------------------------------------------------
1 | """Loader that loads image files."""
2 | from typing import List
3 |
4 | from langchain.document_loaders.unstructured import UnstructuredFileLoader
5 | from paddleocr import PaddleOCR
6 | import os
7 | import nltk
8 | from configs.common_config import NLTK_DATA_PATH
9 |
10 | nltk.data.path = [NLTK_DATA_PATH] + nltk.data.path
11 |
12 | class UnstructuredPaddleImageLoader(UnstructuredFileLoader):
13 | """Loader that uses unstructured to load image files, such as PNGs and JPGs."""
14 |
15 | def _get_elements(self) -> List:
16 | def image_ocr_txt(filepath, dir_path="tmp_files"):
17 | full_dir_path = os.path.join(os.path.dirname(filepath), dir_path)
18 | if not os.path.exists(full_dir_path):
19 | os.makedirs(full_dir_path)
20 | filename = os.path.split(filepath)[-1]
21 | ocr = PaddleOCR(use_angle_cls=True, lang="ch", use_gpu=False, show_log=False)
22 | result = ocr.ocr(img=filepath)
23 |
24 | ocr_result = [i[1][0] for line in result for i in line]
25 | txt_file_path = os.path.join(full_dir_path, "%s.txt" % (filename))
26 | with open(txt_file_path, 'w', encoding='utf-8') as fout:
27 | fout.write("\n".join(ocr_result))
28 | return txt_file_path
29 |
30 | txt_file_path = image_ocr_txt(self.file_path)
31 | from unstructured.partition.text import partition_text
32 | return partition_text(filename=txt_file_path, **self.unstructured_kwargs)
33 |
34 |
35 | if __name__ == "__main__":
36 | filepath = os.path.join(os.path.dirname(os.path.dirname(__file__)), "content", "samples", "test.jpg")
37 | loader = UnstructuredPaddleImageLoader(filepath, mode="elements")
38 | docs = loader.load()
39 | for doc in docs:
40 | print(doc)
41 |
--------------------------------------------------------------------------------
/loader/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .bloomz_llm import Bloomz
2 |
--------------------------------------------------------------------------------
/loader/models/__main__.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import os
3 |
4 | sys.path.append(os.path.dirname(os.path.abspath(__file__)) + '/../')
5 | import asyncio
6 | from argparse import Namespace
7 | from models.loader.args import parser
8 | from models.loader import LoaderCheckPoint
9 |
10 | from langchain.agents import initialize_agent, Tool
11 | from langchain.agents import AgentType
12 |
13 | import loader.models.shared as shared
14 |
15 | from langchain.chains import LLMChain
16 | from langchain.memory import ConversationBufferMemory, ReadOnlySharedMemory
17 | from langchain.prompts import PromptTemplate
18 | from langchain.agents import ZeroShotAgent, Tool, AgentExecutor
19 | from typing import List, Set
20 |
21 |
22 |
23 | class CustomLLMSingleActionAgent(ZeroShotAgent):
24 | allowed_tools: List[str]
25 |
26 | def __init__(self, *args, **kwargs):
27 | super(CustomLLMSingleActionAgent, self).__init__(*args, **kwargs)
28 | self.allowed_tools = kwargs['allowed_tools']
29 |
30 | def get_allowed_tools(self) -> Set[str]:
31 | return set(self.allowed_tools)
32 |
33 |
34 | async def dispatch(args: Namespace):
35 | args_dict = vars(args)
36 |
37 | shared.loaderCheckPoint = LoaderCheckPoint(args_dict)
38 | llm_model_ins = shared.loaderLLM()
39 |
40 | template = """This is a conversation between a human and a bot:
41 |
42 | {chat_history}
43 |
44 | Write a summary of the conversation for {input}:
45 | """
46 |
47 | prompt = PromptTemplate(
48 | input_variables=["input", "chat_history"],
49 | template=template
50 | )
51 | memory = ConversationBufferMemory(memory_key="chat_history")
52 | readonlymemory = ReadOnlySharedMemory(memory=memory)
53 | summry_chain = LLMChain(
54 | llm=llm_model_ins,
55 | prompt=prompt,
56 | verbose=True,
57 | memory=readonlymemory, # use the read-only memory to prevent the tool from modifying the memory
58 | )
59 |
60 |
61 | tools = [
62 | Tool(
63 | name="Summary",
64 | func=summry_chain.run,
65 | description="useful for when you summarize a conversation. The input to this tool should be a string, representing who will read this summary."
66 | )
67 | ]
68 |
69 | prefix = """Have a conversation with a human, answering the following questions as best you can. You have access to the following tools:"""
70 | suffix = """Begin!
71 |
72 | Question: {input}
73 | {agent_scratchpad}"""
74 |
75 |
76 | prompt = CustomLLMSingleActionAgent.create_prompt(
77 | tools,
78 | prefix=prefix,
79 | suffix=suffix,
80 | input_variables=["input", "agent_scratchpad"]
81 | )
82 | tool_names = [tool.name for tool in tools]
83 | llm_chain = LLMChain(llm=llm_model_ins, prompt=prompt)
84 | agent = CustomLLMSingleActionAgent(llm_chain=llm_chain, tools=tools, allowed_tools=tool_names)
85 | agent_chain = AgentExecutor.from_agent_and_tools(agent=agent, tools=tools)
86 |
87 | agent_chain.run(input="你好")
88 | agent_chain.run(input="你是谁?")
89 | agent_chain.run(input="我们之前聊了什么?")
90 |
91 | if __name__ == '__main__':
92 | args = None
93 | args = parser.parse_args(args=['--model-dir', '/media/checkpoint/', '--model', 'vicuna-13b-hf', '--no-remote-model', '--load-in-8bit'])
94 |
95 | loop = asyncio.new_event_loop()
96 | asyncio.set_event_loop(loop)
97 | loop.run_until_complete(dispatch(args))
98 |
--------------------------------------------------------------------------------
/loader/models/base.py:
--------------------------------------------------------------------------------
1 | from abc import ABC, abstractmethod
2 | from typing import Optional, List
3 | import traceback
4 | from collections import deque
5 | from queue import Queue
6 | from threading import Thread
7 |
8 | import torch
9 | import transformers
10 | from loader.models.loader import LoaderCheckPoint
11 |
12 |
13 | class ListenerToken:
14 | """
15 | 观测结果
16 | """
17 |
18 | input_ids: torch.LongTensor
19 | _scores: torch.FloatTensor
20 |
21 | def __init__(self, input_ids: torch.LongTensor, _scores: torch.FloatTensor):
22 | self.input_ids = input_ids
23 | self._scores = _scores
24 |
25 |
26 | class AnswerResult:
27 | """
28 | 消息实体
29 | """
30 | history: List[List[str]] = []
31 | llm_output: Optional[dict] = None
32 | listenerToken: ListenerToken = None
33 |
34 |
35 | class AnswerResultStream:
36 | def __init__(self, callback_func=None):
37 | self.callback_func = callback_func
38 |
39 | def __call__(self, answerResult: AnswerResult):
40 | if self.callback_func is not None:
41 | self.callback_func(answerResult)
42 |
43 |
44 | class AnswerResultQueueSentinelTokenListenerQueue(transformers.StoppingCriteria):
45 | """
46 | 定义模型stopping_criteria 监听者,在每次响应时将队列数据同步到AnswerResult
47 | 实现此监听器的目的是,不同模型的预测输出可能不是矢量信息,hf框架可以自定义transformers.StoppingCriteria入参来接收每次预测的Tensor和损失函数,
48 | 通过给 StoppingCriteriaList指定模型生成答案时停止的条件。每个 StoppingCriteria 对象表示一个停止条件
49 | 当每轮预测任务开始时,StoppingCriteria都会收到相同的预测结果,最终由下层实现类确认是否结束
50 | 输出值可用于 generatorAnswer generate_with_streaming的自定义参数观测,以实现更加精细的控制
51 | """
52 |
53 | listenerQueue: deque = deque(maxlen=1)
54 |
55 | def __init__(self):
56 | transformers.StoppingCriteria.__init__(self)
57 |
58 | def __call__(self, input_ids: torch.LongTensor, _scores: torch.FloatTensor, **kwargs) -> bool:
59 | """
60 | 每次响应时将数据添加到响应队列
61 | :param input_ids:
62 | :param _scores:
63 | :param kwargs:
64 | :return:
65 | """
66 | self.listenerQueue.append(ListenerToken(input_ids=input_ids, _scores=_scores))
67 | return False
68 |
69 |
70 | class Iteratorize:
71 | """
72 | Transforms a function that takes a callback
73 | into a lazy iterator (generator).
74 | """
75 |
76 | def __init__(self, func, kwargs={}):
77 | self.mfunc = func
78 | self.q = Queue()
79 | self.sentinel = object()
80 | self.kwargs = kwargs
81 | self.stop_now = False
82 |
83 | def _callback(val):
84 | """
85 | 模型输出预测结果收集
86 | 通过定义generate_with_callback收集器AnswerResultStream,收集模型预测的AnswerResult响应结果,最终由下层实现类确认是否结束
87 | 结束条件包含如下
88 | 1、模型预测结束、收集器self.q队列收到 self.sentinel标识
89 | 2、在处理迭代器队列消息时返回了break跳出迭代器,触发了StopIteration事件
90 | 3、模型预测出错
91 | 因为当前类是迭代器,所以在for in 中执行了break后 __exit__ 方法会被调用,最终stop_now属性会被更新,然后抛出异常结束预测行为
92 | 迭代器收集的行为如下
93 | 创建Iteratorize迭代对象,
94 | 定义generate_with_callback收集器AnswerResultStream
95 | 启动一个线程异步预测结果来调用上游checkpoint的实现方法_generate_answer
96 | _generate_answer通过generate_with_callback定义的收集器,收集上游checkpoint包装的AnswerResult消息体
97 | 由于self.q是阻塞模式,每次预测后会被消费后才会执行下次预测
98 | 这时generate_with_callback会被阻塞
99 | 主线程Iteratorize对象的__next__方法调用获取阻塞消息并消费
100 | 1、消息为上游checkpoint包装的AnswerResult消息体,返回下游处理
101 | 2、消息为self.sentinel标识,抛出StopIteration异常
102 | 主线程Iteratorize对象__exit__收到消息,最终stop_now属性会被更新
103 | 异步线程检测stop_now属性被更新,抛出异常结束预测行为
104 | 迭代行为结束
105 | :param val:
106 | :return:
107 | """
108 | if self.stop_now:
109 | raise ValueError
110 | self.q.put(val)
111 |
112 | def gen():
113 | try:
114 | ret = self.mfunc(callback=_callback, **self.kwargs)
115 | except ValueError:
116 | pass
117 | except:
118 | traceback.print_exc()
119 | pass
120 |
121 | self.q.put(self.sentinel)
122 |
123 | self.thread = Thread(target=gen)
124 | self.thread.start()
125 |
126 | def __iter__(self):
127 | return self
128 |
129 | def __next__(self):
130 | obj = self.q.get(True, None)
131 | if obj is self.sentinel:
132 | raise StopIteration
133 | else:
134 | return obj
135 |
136 | def __del__(self):
137 | """
138 | 暂无实现
139 | :return:
140 | """
141 | pass
142 |
143 | def __enter__(self):
144 | return self
145 |
146 | def __exit__(self, exc_type, exc_val, exc_tb):
147 | """ break 后会执行 """
148 | self.stop_now = True
149 |
150 |
151 | class BaseAnswer(ABC):
152 | """上层业务包装器.用于结果生成统一api调用"""
153 |
154 | @property
155 | @abstractmethod
156 | def _check_point(self) -> LoaderCheckPoint:
157 | """Return _check_point of llm."""
158 |
159 | @property
160 | @abstractmethod
161 | def _history_len(self) -> int:
162 | """Return _history_len of llm."""
163 |
164 | @abstractmethod
165 | def set_history_len(self, history_len: int) -> None:
166 | """Return _history_len of llm."""
167 |
168 | def generatorAnswer(self, prompt: str,
169 | history: List[List[str]] = [],
170 | streaming: bool = False):
171 | def generate_with_callback(callback=None, **kwargs):
172 | kwargs['generate_with_callback'] = AnswerResultStream(callback_func=callback)
173 | self._generate_answer(**kwargs)
174 |
175 | def generate_with_streaming(**kwargs):
176 | return Iteratorize(generate_with_callback, kwargs)
177 |
178 | """
179 | eos_token_id是指定token(例如,""),
180 | 用于表示序列的结束。在生成文本任务中,生成器在生成序列时,将不断地生成token,直到生成此特殊的eos_token_id,表示序列生成已经完成。
181 | 在Hugging Face Transformer模型中,eos_token_id是由tokenizer自动添加到输入中的。
182 | 在模型生成输出时,如果模型生成了eos_token_id,则生成过程将停止并返回生成的序列。
183 | """
184 | eos_token_ids = [
185 | self._check_point.tokenizer.eos_token_id] if self._check_point.tokenizer.eos_token_id is not None else []
186 |
187 | with generate_with_streaming(prompt=prompt, history=history, streaming=streaming) as generator:
188 | for answerResult in generator:
189 | if answerResult.listenerToken:
190 | output = answerResult.listenerToken.input_ids
191 | yield answerResult
192 |
193 | @abstractmethod
194 | def _generate_answer(self, prompt: str,
195 | history: List[List[str]] = [],
196 | streaming: bool = False,
197 | generate_with_callback: AnswerResultStream = None) -> None:
198 | pass
199 |
--------------------------------------------------------------------------------
/loader/models/bloomz_llm.py:
--------------------------------------------------------------------------------
1 |
2 | from abc import ABC
3 |
4 | from langchain.llms.base import LLM
5 | from typing import Optional, List
6 | from loader.models.loader import LoaderCheckPoint
7 | from loader.models.base import (BaseAnswer,
8 | AnswerResult,
9 | AnswerResultStream,
10 | AnswerResultQueueSentinelTokenListenerQueue)
11 | import re
12 |
13 | import transformers
14 | from transformers.generation.streamers import BaseStreamer
15 | from threading import Thread
16 | from queue import Queue
17 | from typing import Callable, Iterable, List, Optional, Tuple, Union
18 | # import torch
19 |
20 | class MyStreamer(BaseStreamer):
21 |
22 | def __init__(
23 | self,
24 | # stop_token_ids: torch.Tensor,
25 | # skip_token_count: int,
26 | # max_input_length: int,
27 | timeout: Optional[float] = None
28 | ):
29 |
30 | # 紧急停止策略
31 | # self.stop_token_ids = stop_token_ids
32 | # self.skip_token_count = skip_token_count
33 | # self.max_input_length = max_input_length
34 |
35 |
36 | self.token_queue = Queue()
37 | self.stop_signal = None
38 | self.timeout = timeout
39 |
40 |
41 | def put(self, value):
42 | list_value = value.tolist()
43 | if type(list_value[0]) == int:
44 | self.token_queue.put(list_value, timeout=self.timeout)
45 |
46 | def end(self):
47 | self.token_queue.put(self.stop_signal, timeout=self.timeout)
48 |
49 | def __iter__(self):
50 | return self
51 |
52 | def __next__(self):
53 | value = self.token_queue.get(timeout=self.timeout)
54 | if value == self.stop_signal:
55 | raise StopIteration()
56 | else:
57 | return value
58 |
59 | def remove_starting_symbols(string):
60 | pattern = r'^[,。,.!!]+'
61 | result = re.sub(pattern, '', string)
62 | return result
63 |
64 | def extract_content(replacement_text, string):
65 | pattern_helper = r'Helper:(.*?)' # r'Helper:(.*?)(?=<\/s>)'
66 |
67 | match_helper = re.findall(pattern_helper, string, re.DOTALL)
68 |
69 | if match_helper:
70 | content = match_helper[-1].strip()
71 | return content.replace('', '').replace('<\s>', '')
72 | else:
73 | replaced_string = re.sub(r'Input:.*?(?=\n)', replacement_text, string, re.DOTALL)
74 | replaced_string = replaced_string.replace('', '').replace('<\s>', '')
75 | return replaced_string
76 |
77 | import re
78 |
79 | def remove_prefix_suffix(string):
80 | pattern = r'^((?:Helper:|:)\s*)(.*?)()?$'
81 | string = re.sub(pattern, r'\2', string, flags=re.DOTALL)
82 | string = remove_starting_symbols(string).replace('', '').replace('<\s>', '')
83 | return string
84 |
85 | class Bloomz(BaseAnswer, LLM, ABC):
86 | max_token: int = 10000
87 | temperature: float = 0.01
88 | top_p = 0.9
89 | checkPoint: LoaderCheckPoint = None
90 | # history = []
91 | history_len: int = 10
92 |
93 | def __init__(self, checkPoint: LoaderCheckPoint = None):
94 | super().__init__()
95 | self.checkPoint = checkPoint
96 |
97 | @property
98 | def _llm_type(self) -> str:
99 | return "Bloomz"
100 |
101 | @property
102 | def _check_point(self) -> LoaderCheckPoint:
103 | return self.checkPoint
104 |
105 | @property
106 | def _history_len(self) -> int:
107 | return self.history_len
108 |
109 | def set_history_len(self, history_len: int = 10) -> None:
110 | self.history_len = history_len
111 |
112 | def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
113 | pass
114 |
115 | def _generate_answer(self, prompt: str,
116 | history: List[List[str]] = [],
117 | streaming: bool = False,
118 | generate_with_callback: AnswerResultStream = None) -> None:
119 | # Create the StoppingCriteriaList with the stopping strings
120 | stopping_criteria_list = transformers.StoppingCriteriaList()
121 | # 定义模型stopping_criteria 队列,在每次响应时将 torch.LongTensor, torch.FloatTensor同步到AnswerResult
122 | listenerQueue = AnswerResultQueueSentinelTokenListenerQueue()
123 | stopping_criteria_list.append(listenerQueue)
124 | model_device = next(self.checkPoint.model.parameters()).device
125 | if streaming:
126 | history += [[]]
127 | streamer = MyStreamer() # type: ignore
128 | # torch.manual_seed(23333)
129 | inputs = self.checkPoint.tokenizer([prompt], return_tensors="pt").input_ids.to(model_device)
130 |
131 | thread = Thread(target=self.checkPoint.model.generate, kwargs=dict(
132 | inputs=inputs,
133 | # attention_mask=attention_mask.cuda(),
134 | # gen kargs
135 | num_beams=1,
136 | do_sample=True,
137 | temperature=0.7,
138 | top_p=self.top_p,
139 | top_k=9,
140 | eos_token_id=self.checkPoint.tokenizer.eos_token_id,
141 | max_length=2000,
142 | min_length=1,
143 | streamer=streamer,
144 | ))
145 |
146 | thread.start()
147 | # stream_resp = ''
148 | token_list = []
149 | for token in streamer:
150 | token_list += token
151 | stream_resp = self.checkPoint.tokenizer.decode(token_list)
152 | stream_resp = remove_prefix_suffix(stream_resp)
153 | history[-1] = [prompt, stream_resp] #stream_resp.lstrip(": ")
154 | answer_result = AnswerResult()
155 | answer_result.history = history
156 | answer_result.llm_output = {"answer": stream_resp}
157 | if listenerQueue.listenerQueue.__len__() > 0:
158 | answer_result.listenerToken = listenerQueue.listenerQueue.pop()
159 | generate_with_callback(answer_result)
160 | else:
161 | inputs = self.checkPoint.tokenizer([prompt], return_tensors="pt").input_ids.to(model_device)
162 | re_token_ids = self.checkPoint.model.generate(
163 | inputs=inputs,
164 | # gen kargs
165 | num_beams=1,
166 | do_sample=True,
167 | temperature=0.7,
168 | top_p=self.top_p,
169 | top_k=9,
170 | eos_token_id=self.checkPoint.tokenizer.eos_token_id,
171 | max_length=2000 #512,
172 | )
173 | response = self.checkPoint.tokenizer.decode(re_token_ids[0])
174 | response = extract_content(prompt, response)
175 | self.checkPoint.clear_torch_cache()
176 | history += [[prompt, response]]
177 | answer_result = AnswerResult()
178 | answer_result.history = history
179 | answer_result.llm_output = {"answer": response}
180 | if listenerQueue.listenerQueue.__len__() > 0:
181 | answer_result.listenerToken = listenerQueue.listenerQueue.pop()
182 |
183 | generate_with_callback(answer_result)
--------------------------------------------------------------------------------
/loader/models/extensions/callback.py:
--------------------------------------------------------------------------------
1 | # import gc
2 | import traceback
3 | from queue import Queue
4 | # from threading import Thread
5 | # import threading
6 | from typing import Optional, List, Dict, Any, TypeVar, Deque
7 | from collections import deque
8 | import torch
9 | import transformers
10 |
11 | from models.extensions.thread_with_exception import ThreadWithException
12 | import models.shared as shared
13 |
14 |
15 | K = TypeVar('K')
16 | V = TypeVar('V')
17 |
18 | class LimitedLengthDict(Dict[K, V]):
19 | def __init__(self, maxlen=None, *args, **kwargs):
20 | self.maxlen = maxlen
21 | self._keys: Deque[K] = deque()
22 | super().__init__(*args, **kwargs)
23 |
24 | def __setitem__(self, key: K, value: V):
25 | if key not in self:
26 | if self.maxlen is not None and len(self) >= self.maxlen:
27 | oldest_key = self._keys.popleft()
28 | if oldest_key in self:
29 | del self[oldest_key]
30 | self._keys.append(key)
31 | super().__setitem__(key, value)
32 |
33 |
34 | class FixedLengthQueue:
35 | # 停止符号列表
36 | stop_sequence: Optional[str] = []
37 | # 缓冲区
38 | max_length: int = 0
39 | # 缓冲区容器
40 | queue: deque = None
41 | # 输入区容器
42 | queue_in: LimitedLengthDict[int, str] = {}
43 | # 输出区容器
44 | queue_out: Dict[int, str] = {}
45 |
46 | def __new__(cls, *args, **kwargs):
47 | # 创建新的实例
48 | instance = super().__new__(cls)
49 | # 在这里可以对实例进行额外的设置
50 | return instance
51 |
52 | def __init__(self, stop_sequence):
53 | if stop_sequence is None:
54 | self.stop_sequence = []
55 | self.max_length = 0
56 | elif isinstance(stop_sequence, str):
57 | self.stop_sequence = [stop_sequence]
58 | self.max_length = 1
59 | else:
60 | self.stop_sequence = stop_sequence
61 | self.max_length = len(''.join(stop_sequence))
62 |
63 | self.queue = deque(maxlen=self.max_length)
64 | self.queue.clear()
65 | self.queue_in.clear()
66 | self.queue_out.clear()
67 |
68 | def add(self, index, item):
69 | self.queue_in[index] = item
70 |
71 | def _add_out(self, index, item):
72 | self.queue_out[index] = item
73 |
74 | def put_replace_out(self, index):
75 | return self.queue_out[index]
76 |
77 | def contains_replace_sequence(self):
78 | """
79 | 替换字符
80 | :return:
81 | """
82 |
83 | for key, value in self.queue_in.items():
84 |
85 | word_index = value.rfind(":")
86 | if word_index != -1:
87 | value = value.replace(":", ":")
88 |
89 | word_index = value.rfind("[")
90 | if word_index != -1:
91 | value = value.replace("[", "")
92 |
93 | word_index = value.rfind("]")
94 | if word_index != -1:
95 | value = value.replace("]", "")
96 |
97 | self._add_out(key, value)
98 |
99 | def contains_stop_sequence(self):
100 | # 截取固定大小的数据判断
101 | self.queue.clear()
102 | last_three_keys = list(self.queue_out.keys())[-self.max_length:]
103 | joined_queue = ''.join([self.queue_out[key] for key in last_three_keys])
104 | for char in joined_queue:
105 | self.queue.append(char)
106 |
107 | joined_queue = ''.join(self.queue)
108 | # Initialize a variable to store the index of the last found stop string
109 | last_stop_str_index = -1
110 |
111 | # Iterate through the stop string list
112 | for stop_word in self.stop_sequence:
113 | # Find the last occurrence of the stop string in the output
114 | stop_word_index = joined_queue.rfind(stop_word)
115 |
116 | # If the stop string is found, compare the index with the previously found index
117 | if stop_word_index != -1 and stop_word_index > last_stop_str_index:
118 | last_stop_str_index = stop_word_index
119 |
120 | # Handle the last found stop string index here
121 | return last_stop_str_index
122 |
123 | def __repr__(self):
124 | return str(self.queue)
125 |
126 |
127 | # Copied from https://github.com/PygmalionAI/gradio-ui/
128 | class _SentinelTokenStoppingCriteria(transformers.StoppingCriteria):
129 |
130 | def __init__(self, sentinel_token_ids: list, starting_idx: int):
131 | transformers.StoppingCriteria.__init__(self)
132 | self.sentinel_token_ids = sentinel_token_ids
133 | self.starting_idx = starting_idx
134 |
135 | def __call__(self, input_ids: torch.LongTensor, _scores: torch.FloatTensor) -> bool:
136 | for sample in input_ids:
137 | trimmed_sample = sample[self.starting_idx:]
138 |
139 | for i in range(len(self.sentinel_token_ids)):
140 | # Can't unfold, output is still too tiny. Skip.
141 | if trimmed_sample.shape[-1] < self.sentinel_token_ids[i].shape[-1]:
142 | continue
143 | for window in trimmed_sample.unfold(0, self.sentinel_token_ids[i].shape[-1], 1):
144 | if torch.all(torch.eq(self.sentinel_token_ids[i][0], window)):
145 | return True
146 | return False
147 |
148 |
149 | class Stream(transformers.StoppingCriteria):
150 | def __init__(self, callback_func=None):
151 | self.callback_func = callback_func
152 |
153 | def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
154 | if shared.stop_everything:
155 | raise ValueError
156 | if self.callback_func is not None:
157 | self.callback_func(input_ids[0])
158 | return False
159 |
160 |
161 | class Iteratorize:
162 | """
163 | Transforms a function that takes a callback
164 | into a lazy iterator (generator).
165 | """
166 |
167 | thread: ThreadWithException = None
168 |
169 | def __new__(cls, *args, **kwargs):
170 | # 创建新的实例
171 | instance = super().__new__(cls)
172 | # 在这里可以对实例进行额外的设置
173 | return instance
174 |
175 | def __init__(self, func, kwargs={}, callback=None):
176 | self.mfunc = func
177 | self.c_callback = callback
178 | self.q = Queue()
179 | self.sentinel = object()
180 | self.kwargs = kwargs
181 |
182 | def _callback(val):
183 | if shared.stop_everything:
184 | raise ValueError
185 | self.q.put(val)
186 |
187 | def gen():
188 | try:
189 | ret = self.mfunc(callback=_callback, **self.kwargs)
190 | except ValueError:
191 | print("print(ValueError)")
192 | except:
193 | traceback.print_exc()
194 | print("traceback.print_exc()")
195 | self.q.put(self.sentinel)
196 |
197 | self.thread = ThreadWithException(target=gen)
198 | self.thread.start()
199 |
200 | def __iter__(self):
201 | shared.stop_everything = False
202 | return self
203 |
204 | def __next__(self):
205 | obj = self.q.get(True, None)
206 | if obj is self.sentinel:
207 | raise StopIteration
208 | else:
209 | return obj
210 |
211 | def __del__(self):
212 | shared.stop_everything = False
213 | self.q.empty()
214 | shared.loaderCheckPoint.clear_torch_cache()
215 |
216 | def __enter__(self):
217 | shared.stop_everything = False
218 | return self
219 |
220 | def __exit__(self, exc_type, exc_val, exc_tb):
221 | shared.stop_everything = True
222 | shared.loaderCheckPoint.clear_torch_cache()
223 | self.thread.raise_exception()
224 |
--------------------------------------------------------------------------------
/loader/models/extensions/extensions.py:
--------------------------------------------------------------------------------
1 | import gc
2 | import traceback
3 | import torch
4 |
5 | # This iterator returns the extensions in the order specified in the command-line
6 | def iterator():
7 | state_extensions = {}
8 | for name in sorted(state_extensions, key=lambda x: state_extensions[x][1]):
9 | if state_extensions[name][0]:
10 | yield getattr(extensions, name).script, name
--------------------------------------------------------------------------------
/loader/models/extensions/llamacpp_model_alternative.py:
--------------------------------------------------------------------------------
1 | '''
2 | Based on
3 | https://github.com/abetlen/llama-cpp-python
4 |
5 | Documentation:
6 | https://abetlen.github.io/llama-cpp-python/
7 | '''
8 |
9 | from llama_cpp import Llama, LlamaCache
10 |
11 | from modules import shared
12 | from modules.callbacks import Iteratorize
13 |
14 |
15 | class LlamaCppModel:
16 | def __init__(self):
17 | self.initialized = False
18 |
19 | @classmethod
20 | def from_pretrained(self, path):
21 | result = self()
22 |
23 | params = {
24 | 'model_path': str(path),
25 | 'n_ctx': 2048,
26 | 'seed': 0,
27 | 'n_threads': shared.args.threads or None
28 | }
29 | self.model = Llama(**params)
30 | self.model.set_cache(LlamaCache)
31 |
32 | # This is ugly, but the model and the tokenizer are the same object in this library.
33 | return result, result
34 |
35 | def encode(self, string):
36 | if type(string) is str:
37 | string = string.encode()
38 | return self.model.tokenize(string)
39 |
40 | def generate(self, context="", token_count=20, temperature=1, top_p=1, top_k=50, repetition_penalty=1, callback=None):
41 | if type(context) is str:
42 | context = context.encode()
43 | tokens = self.model.tokenize(context)
44 |
45 | output = b""
46 | count = 0
47 | for token in self.model.generate(tokens, top_k=top_k, top_p=top_p, temp=temperature, repeat_penalty=repetition_penalty):
48 | text = self.model.detokenize([token])
49 | output += text
50 | if callback:
51 | callback(text.decode())
52 |
53 | count += 1
54 | if count >= token_count or (token == self.model.token_eos()):
55 | break
56 |
57 | return output.decode()
58 |
59 | def generate_with_streaming(self, **kwargs):
60 | with Iteratorize(self.generate, kwargs, callback=None) as generator:
61 | reply = ''
62 | for token in generator:
63 | reply += token
64 | yield reply
65 |
--------------------------------------------------------------------------------
/loader/models/extensions/thread_with_exception.py:
--------------------------------------------------------------------------------
1 | # Python program raising
2 | # exceptions in a python
3 | # thread
4 |
5 | import threading
6 | import ctypes
7 | import time
8 |
9 |
10 | class ThreadWithException(threading.Thread):
11 |
12 | def get_id(self):
13 | return self.ident
14 |
15 | def raise_exception(self):
16 | """raises the exception, performs cleanup if needed"""
17 | try:
18 | thread_id = self.get_id()
19 | tid = ctypes.c_long(thread_id)
20 | res = ctypes.pythonapi.PyThreadState_SetAsyncExc(tid, ctypes.py_object(SystemExit))
21 | if res == 0:
22 | # pass
23 | raise ValueError("invalid thread id")
24 | elif res != 1:
25 | # """if it returns a number greater than one, you're in trouble,
26 | # and you should call it again with exc=NULL to revert the effect"""
27 | ctypes.pythonapi.PyThreadState_SetAsyncExc(tid, None)
28 | raise SystemError("PyThreadState_SetAsyncExc failed")
29 | except Exception as err:
30 | print(err)
31 |
--------------------------------------------------------------------------------
/loader/models/loader/__init__.py:
--------------------------------------------------------------------------------
1 |
2 | from .loader import *
3 |
--------------------------------------------------------------------------------
/loader/models/loader/args.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | from configs.common_config import *
4 |
5 |
6 | # Additional argparse types
7 | def path(string):
8 | if not string:
9 | return ''
10 | s = os.path.expanduser(string)
11 | if not os.path.exists(s):
12 | raise argparse.ArgumentTypeError(f'No such file or directory: "{string}"')
13 | return s
14 |
15 |
16 | def file_path(string):
17 | if not string:
18 | return ''
19 | s = os.path.expanduser(string)
20 | if not os.path.isfile(s):
21 | raise argparse.ArgumentTypeError(f'No such file: "{string}"')
22 | return s
23 |
24 |
25 | def dir_path(string):
26 | if not string:
27 | return ''
28 | s = os.path.expanduser(string)
29 | if not os.path.isdir(s):
30 | raise argparse.ArgumentTypeError(f'No such directory: "{string}"')
31 | return s
32 |
33 |
34 | parser = argparse.ArgumentParser(prog='知识库',
35 | description='知识库')
36 |
37 | parser.add_argument('--no-remote-model', action='store_true', default=NO_REMOTE_MODEL, help='remote in the model on '
38 | 'loader checkpoint, '
39 | 'if your load local '
40 | 'model to add the ` '
41 | '--no-remote-model`')
42 | parser.add_argument('--model', type=str, default=LLM_MODEL, help='Name of the model to load by default.')
43 | parser.add_argument('--lora', type=str, help='Name of the LoRA to apply to the model by default.')
44 | parser.add_argument("--model-dir", type=str, default=MODEL_DIR, help="Path to directory with all the models")
45 | parser.add_argument("--lora-dir", type=str, default=LORA_DIR, help="Path to directory with all the loras")
46 |
47 | # Accelerate/transformers
48 | parser.add_argument('--load-in-8bit', action='store_true', default=LOAD_IN_8BIT,
49 | help='Load the model with 8-bit precision.')
50 | parser.add_argument('--bf16', action='store_true', default=BF16,
51 | help='Load the model with bfloat16 precision. Requires NVIDIA Ampere GPU.')
52 |
53 | args = parser.parse_args([])
54 | # Generares dict with a default value for each argument
55 | DEFAULT_ARGS = vars(args)
56 |
--------------------------------------------------------------------------------
/loader/models/loader/loader.py:
--------------------------------------------------------------------------------
1 | import gc
2 | import json
3 | import os
4 | import re
5 | import time
6 | from pathlib import Path
7 | from peft import PeftModel
8 | from typing import Optional, List, Dict, Tuple, Union
9 | from pynvml import (nvmlInit, nvmlDeviceGetCount, nvmlDeviceGetHandleByIndex,
10 | nvmlDeviceGetName, nvmlDeviceGetMemoryInfo, nvmlShutdown)
11 | def get_available_gpu(threshold=20000):
12 | # Initialize NVML
13 | nvmlInit()
14 | # Get the number of GPU devices
15 | device_count = nvmlDeviceGetCount()
16 |
17 | # Find GPU devices with available memory greater than the threshold
18 | available_gpus = []
19 | for i in range(device_count):
20 | handle = nvmlDeviceGetHandleByIndex(i)
21 | info = nvmlDeviceGetMemoryInfo(handle)
22 | free_memory_mb = info.free / 1024 / 1024
23 |
24 | if free_memory_mb > threshold:
25 | available_gpus.append(i)
26 |
27 | # Shutdown NVML
28 | nvmlShutdown()
29 | # available_gpus = ['0']
30 |
31 | return available_gpus
32 |
33 |
34 | def get_free_memory():
35 | nvmlInit()
36 | # Get the number of GPU devices
37 | device_count = nvmlDeviceGetCount()
38 |
39 | # Find GPU devices with available memory greater than the threshold
40 | free_memory_gpus = []
41 | for i in range(device_count):
42 | handle = nvmlDeviceGetHandleByIndex(i)
43 | info = nvmlDeviceGetMemoryInfo(handle)
44 | free_memory_mb = info.free / 1024 / 1024
45 | free_memory_gpus.append(free_memory_mb)
46 |
47 | # Shutdown NVML
48 | nvmlShutdown()
49 | return free_memory_gpus
50 |
51 | import torch
52 | import transformers
53 |
54 | from transformers import (AutoConfig, AutoModel, AutoModelForCausalLM,
55 | AutoTokenizer, BitsAndBytesConfig, LlamaTokenizer,BloomTokenizerFast, BloomConfig, BloomForCausalLM)
56 | from transformers.dynamic_module_utils import get_class_from_dynamic_module
57 | from transformers.modeling_utils import no_init_weights
58 | from transformers.utils import ContextManagers
59 | from accelerate import init_empty_weights
60 | from accelerate.utils import get_balanced_memory, infer_auto_device_map
61 | from configs.common_config import LLM_DEVICE,llm_model_dict
62 |
63 |
64 |
65 |
66 |
67 |
68 | class LoaderCheckPoint:
69 | """
70 | 加载自定义 model CheckPoint
71 | """
72 | # remote in the model on loader checkpoint
73 | no_remote_model: bool = False
74 | # 模型名称
75 | model_name: str = None
76 | tokenizer: object = None
77 | # 模型全路径
78 | model_path: str = None
79 | model: object = None
80 | model_config: object = None
81 | lora_names: set = []
82 | model_dir: str = None
83 | lora_dir: str = None
84 | ptuning_dir: str = None
85 | use_ptuning_v2: bool = False
86 | # 如果开启了8bit量化加载,项目无法启动,参考此位置,选择合适的cuda版本,https://github.com/TimDettmers/bitsandbytes/issues/156
87 | load_in_8bit: bool = False
88 | is_llamacpp: bool = False
89 | bf16: bool = False
90 | params: object = None
91 | # 自定义设备网络
92 | device_map: Optional[Dict[str, int]] = None
93 | # 默认 cuda ,如果不支持cuda使用多卡, 如果不支持多卡 使用cpu
94 | llm_device = LLM_DEVICE
95 |
96 | def __init__(self, params: dict = None):
97 | """
98 | 模型初始化
99 | :param params:
100 | """
101 | self.model_path = None
102 | self.model = None
103 | self.tokenizer = None
104 | self.params = params or {}
105 | self.no_remote_model = params.get('no_remote_model', True)
106 | self.model_name = params.get('model', '')
107 | self.lora = params.get('lora', '')
108 | self.use_ptuning_v2 = params.get('use_ptuning_v2', False)
109 | self.model_dir = params.get('model_dir', '')
110 | self.lora_dir = params.get('lora_dir', '')
111 | self.ptuning_dir = params.get('ptuning_dir', 'ptuning-v2')
112 | self.load_in_8bit = params.get('load_in_8bit', False)
113 | self.bf16 = params.get('bf16', False)
114 |
115 | def _load_model_config(self, model_name):
116 | checkpoint = Path(f'{self.model_dir}/{model_name}')
117 |
118 | if self.model_path:
119 | checkpoint = Path(f'{self.model_path}')
120 | else:
121 | if not self.no_remote_model:
122 | checkpoint = model_name
123 |
124 | model_config = AutoConfig.from_pretrained(checkpoint, trust_remote_code=True)
125 |
126 | return model_config
127 |
128 | def _load_model(self, model_name):
129 | """
130 | 加载自定义位置的model
131 | :param model_name:
132 | :return:
133 | """
134 | print(f"Loading {model_name}...")
135 | t0 = time.time()
136 |
137 | checkpoint = Path(f'{self.model_dir}/{model_name}')
138 |
139 | self.is_llamacpp = len(list(checkpoint.glob('ggml*.bin'))) > 0
140 |
141 | if self.model_path:
142 | checkpoint = Path(f'{self.model_path}')
143 | else:
144 | if not self.no_remote_model:
145 | checkpoint = model_name
146 |
147 | if 'bloomz' in model_name.lower():
148 | LoaderClass = BloomForCausalLM
149 | else:
150 | LoaderClass = AutoModelForCausalLM
151 |
152 | # Load the model in simple 16-bit mode by default
153 | if not any([self.llm_device.lower()=="cpu",
154 | self.load_in_8bit, self.is_llamacpp]):
155 |
156 | if torch.cuda.is_available() and self.llm_device.lower().startswith("cuda"):
157 | available_gpus = get_available_gpu(threshold=20000)
158 | print('available_gpus',available_gpus)
159 | if len(available_gpus)>0:
160 | available_gpu = available_gpus[0]
161 | target_device = torch.device(f'cuda:{str(available_gpu)}')
162 | print('target_device==',target_device)
163 | model = (
164 | LoaderClass.from_pretrained(checkpoint,
165 | config=self.model_config,
166 | torch_dtype=torch.bfloat16 if self.bf16 else torch.float16,
167 | trust_remote_code=True)
168 | .half()
169 | .to(target_device)
170 | )
171 | else:
172 | print('没有满足要求的GPU设备可用')
173 | model = None
174 |
175 | else:
176 | # print(
177 | # "Warning: torch.cuda.is_available() returned False.\nThis means that no GPU has been "
178 | # "detected.\nFalling back to CPU mode.\n")
179 | model = (
180 | AutoModel.from_pretrained(
181 | checkpoint,
182 | config=self.model_config,
183 | trust_remote_code=True)
184 | .float()
185 | .to(self.llm_device)
186 | )
187 |
188 | elif self.is_llamacpp:
189 | from models.extensions.llamacpp_model_alternative import LlamaCppModel
190 |
191 | model_file = list(checkpoint.glob('ggml*.bin'))[0]
192 | print(f"llama.cpp weights detected: {model_file}\n")
193 |
194 | model, tokenizer = LlamaCppModel.from_pretrained(model_file)
195 | return model, tokenizer
196 |
197 | # Custom
198 | else:
199 | params = {"low_cpu_mem_usage": True}
200 |
201 | if not self.llm_device.lower().startswith("cuda"):
202 | raise SystemError("8bit 模型需要 CUDA 支持,或者改用量化后模型!")
203 | else:
204 | params["device_map"] = 'auto'
205 | params["trust_remote_code"] = True
206 | if self.load_in_8bit:
207 | params['quantization_config'] = BitsAndBytesConfig(load_in_8bit=True,
208 | llm_int8_enable_fp32_cpu_offload=False)
209 | elif self.bf16:
210 | params["torch_dtype"] = torch.bfloat16
211 | else:
212 | params["torch_dtype"] = torch.float16
213 |
214 | if self.load_in_8bit and params.get('max_memory', None) is not None and params['device_map'] == 'auto':
215 | config = AutoConfig.from_pretrained(checkpoint)
216 | with init_empty_weights():
217 | model = LoaderClass.from_config(config)
218 | model.tie_weights()
219 | if self.device_map is not None:
220 | params['device_map'] = self.device_map
221 | else:
222 | params['device_map'] = infer_auto_device_map(
223 | model,
224 | dtype=torch.int8,
225 | max_memory=params['max_memory'],
226 | no_split_module_classes=model._no_split_modules
227 | )
228 |
229 | model = LoaderClass.from_pretrained(checkpoint, **params)
230 |
231 | # Loading the tokenizer
232 | if type(model) is transformers.LlamaForCausalLM:
233 | tokenizer = LlamaTokenizer.from_pretrained(checkpoint, clean_up_tokenization_spaces=True)
234 | # Leaving this here until the LLaMA tokenizer gets figured out.
235 | # For some people this fixes things, for others it causes an error.
236 | try:
237 | tokenizer.eos_token_id = 2
238 | tokenizer.bos_token_id = 1
239 | tokenizer.pad_token_id = 0
240 | except Exception as e:
241 | print(e)
242 | pass
243 | elif 'bloomz' in model_name.lower():
244 | tokenizer = BloomTokenizerFast.from_pretrained(checkpoint,
245 | padding_side='left'
246 | )
247 | else:
248 | tokenizer = AutoTokenizer.from_pretrained(checkpoint, trust_remote_code=True)
249 |
250 | print(f"Loaded the model in {(time.time() - t0):.2f} seconds.")
251 | return model, tokenizer
252 |
253 |
254 | def add_lora_to_model(self, lora):
255 | if 'pulse' in self.model_name.lower():
256 | peft_path = f"finetune/pulse/output/{lora}"
257 | load_type = torch.float16
258 | self.model = PeftModel.from_pretrained(self.model, peft_path, torch_dtype=load_type)
259 |
260 | print('finished load lora:',lora)
261 |
262 | def _add_lora_to_model(self, lora_names):
263 | # 目前加载的lora
264 | prior_set = set(self.lora_names)
265 | # 需要加载的
266 | added_set = set(lora_names) - prior_set
267 | # 删除的lora
268 | removed_set = prior_set - set(lora_names)
269 | self.lora_names = list(lora_names)
270 |
271 | # Nothing to do = skip.
272 | if len(added_set) == 0 and len(removed_set) == 0:
273 | return
274 |
275 | # Only adding, and already peft? Do it the easy way.
276 | if len(removed_set) == 0 and len(prior_set) > 0:
277 | print(f"Adding the LoRA(s) named {added_set} to the model...")
278 | for lora in added_set:
279 | self.model.load_adapter(Path(f"{self.lora_dir}/{lora}"), lora)
280 | return
281 |
282 | # If removing anything, disable all and re-add.
283 | if len(removed_set) > 0:
284 | self.model.disable_adapter()
285 |
286 | if len(lora_names) > 0:
287 | print("Applying the following LoRAs to {}: {}".format(self.model_name, ', '.join(lora_names)))
288 | params = {}
289 | if self.llm_device.lower() != "cpu":
290 | params['dtype'] = self.model.dtype
291 | if hasattr(self.model, "hf_device_map"):
292 | params['device_map'] = {"base_model.model." + k: v for k, v in self.model.hf_device_map.items()}
293 | elif self.load_in_8bit:
294 | params['device_map'] = {'': 0}
295 | self.model.resize_token_embeddings(len(self.tokenizer))
296 |
297 | self.model = PeftModel.from_pretrained(self.model, Path(f"{self.lora_dir}/{lora_names[0]}"), **params)
298 |
299 | for lora in lora_names[1:]:
300 | self.model.load_adapter(Path(f"{self.lora_dir}/{lora}"), lora)
301 |
302 | if not self.load_in_8bit and self.llm_device.lower() != "cpu":
303 |
304 | if not hasattr(self.model, "hf_device_map"):
305 | if torch.has_mps:
306 | device = torch.device('mps')
307 | self.model = self.model.to(device)
308 | else:
309 | self.model = self.model.cuda()
310 |
311 | def clear_torch_cache(self):
312 | gc.collect()
313 | if self.llm_device.lower() != "cpu":
314 | if torch.has_mps:
315 | try:
316 | from torch.mps import empty_cache
317 | empty_cache()
318 | except Exception as e:
319 | print(e)
320 | print(
321 | "如果您使用的是 macOS 建议将 pytorch 版本升级至 2.0.0 或更高版本,以支持及时清理 torch 产生的内存占用。")
322 | elif torch.has_cuda:
323 | device_id = "0" if torch.cuda.is_available() else None
324 | CUDA_DEVICE = f"{self.llm_device}:{device_id}" if device_id else self.llm_device
325 | with torch.cuda.device(CUDA_DEVICE):
326 | torch.cuda.empty_cache()
327 | torch.cuda.ipc_collect()
328 | else:
329 | print("未检测到 cuda 或 mps,暂不支持清理显存")
330 |
331 | def unload_model(self):
332 | del self.model
333 | del self.tokenizer
334 | self.model = self.tokenizer = None
335 | self.clear_torch_cache()
336 |
337 | def set_model_path(self, model_path):
338 | self.model_path = model_path
339 |
340 | def reload_model(self):
341 | self.unload_model()
342 | self.model_config = self._load_model_config(self.model_name)
343 |
344 | if self.use_ptuning_v2:
345 | try:
346 | prefix_encoder_file = open(Path(f'{self.ptuning_dir}/config.json'), 'r')
347 | prefix_encoder_config = json.loads(prefix_encoder_file.read())
348 | prefix_encoder_file.close()
349 | self.model_config.pre_seq_len = prefix_encoder_config['pre_seq_len']
350 | self.model_config.prefix_projection = prefix_encoder_config['prefix_projection']
351 | except Exception as e:
352 | print("加载PrefixEncoder config.json失败")
353 |
354 | self.model, self.tokenizer = self._load_model(self.model_name)
355 |
356 | if self.lora:
357 | self.add_lora_to_model(self.lora)
358 |
359 | if self.use_ptuning_v2:
360 | try:
361 | prefix_state_dict = torch.load(Path(f'{self.ptuning_dir}/pytorch_model.bin'))
362 | new_prefix_state_dict = {}
363 | for k, v in prefix_state_dict.items():
364 | if k.startswith("transformer.prefix_encoder."):
365 | new_prefix_state_dict[k[len("transformer.prefix_encoder."):]] = v
366 | self.model.transformer.prefix_encoder.load_state_dict(new_prefix_state_dict)
367 | self.model.transformer.prefix_encoder.float()
368 | except Exception as e:
369 | print("加载PrefixEncoder模型参数失败")
370 |
371 | self.model = self.model.eval()
372 |
--------------------------------------------------------------------------------
/loader/models/shared.py:
--------------------------------------------------------------------------------
1 | import sys
2 |
3 | from loader.models.loader.args import parser
4 | from loader.models.loader import LoaderCheckPoint
5 | from configs.common_config import (llm_model_dict, LLM_MODEL)
6 | from loader.models.base import BaseAnswer
7 | """迭代器是否停止状态"""
8 | stop_everything = False
9 |
10 | loaderCheckPoint: LoaderCheckPoint = None
11 |
12 |
13 | def loaderLLM(llm_model: str = None, no_remote_model: bool = False, use_ptuning_v2: bool = False) -> BaseAnswer:
14 | """
15 | init llm_model_ins LLM
16 | :param llm_model: model_name
17 | :param no_remote_model: remote in the model on loader checkpoint, if your load local model to add the ` --no-remote-model
18 | :param use_ptuning_v2: Use p-tuning-v2 PrefixEncoder
19 | :return:
20 | """
21 | pre_model_name = loaderCheckPoint.model_name
22 | llm_model_info = llm_model_dict[pre_model_name]
23 |
24 | if no_remote_model:
25 | loaderCheckPoint.no_remote_model = no_remote_model
26 |
27 | if use_ptuning_v2:
28 | loaderCheckPoint.use_ptuning_v2 = use_ptuning_v2
29 |
30 | if llm_model:
31 | llm_model_info = llm_model_dict[llm_model]
32 |
33 | if loaderCheckPoint.no_remote_model:
34 | loaderCheckPoint.model_name = llm_model_info['name']
35 | else:
36 | loaderCheckPoint.model_name = llm_model_info['pretrained_model_name']
37 |
38 | loaderCheckPoint.model_path = llm_model_info["local_model_path"]
39 |
40 | loaderCheckPoint.reload_model()
41 |
42 | provides_class = getattr(sys.modules['loader.models'], llm_model_info['provides'])
43 | modelInsLLM = provides_class(checkPoint=loaderCheckPoint)
44 | return modelInsLLM
45 |
--------------------------------------------------------------------------------
/loader/pdf_loader.py:
--------------------------------------------------------------------------------
1 | """Loader that loads image files."""
2 | from typing import List
3 |
4 | from langchain.document_loaders.unstructured import UnstructuredFileLoader
5 | from paddleocr import PaddleOCR
6 | import os
7 | import fitz
8 | import nltk
9 | from configs.common_config import NLTK_DATA_PATH
10 |
11 | nltk.data.path = [NLTK_DATA_PATH] + nltk.data.path
12 |
13 | class UnstructuredPaddlePDFLoader(UnstructuredFileLoader):
14 | """Loader that uses unstructured to load image files, such as PNGs and JPGs."""
15 |
16 | def _get_elements(self) -> List:
17 | def pdf_ocr_txt(filepath, dir_path="tmp_files"):
18 | full_dir_path = os.path.join(os.path.dirname(filepath), dir_path)
19 | if not os.path.exists(full_dir_path):
20 | os.makedirs(full_dir_path)
21 | ocr = PaddleOCR(use_angle_cls=True, lang="ch", use_gpu=False, show_log=False)
22 | doc = fitz.open(filepath)
23 | txt_file_path = os.path.join(full_dir_path, f"{os.path.split(filepath)[-1]}.txt")
24 | img_name = os.path.join(full_dir_path, 'tmp.png')
25 | with open(txt_file_path, 'w', encoding='utf-8') as fout:
26 | for i in range(doc.page_count):
27 | page = doc[i]
28 | text = page.get_text("")
29 | fout.write(text)
30 | fout.write("\n")
31 |
32 | img_list = page.get_images()
33 | for img in img_list:
34 | pix = fitz.Pixmap(doc, img[0])
35 | if pix.n - pix.alpha >= 4:
36 | pix = fitz.Pixmap(fitz.csRGB, pix)
37 | pix.save(img_name)
38 |
39 | result = ocr.ocr(img_name)
40 | ocr_result = [i[1][0] for line in result for i in line]
41 | fout.write("\n".join(ocr_result))
42 | os.remove(img_name)
43 | return txt_file_path
44 |
45 | txt_file_path = pdf_ocr_txt(self.file_path)
46 | from unstructured.partition.text import partition_text
47 | return partition_text(filename=txt_file_path, **self.unstructured_kwargs)
48 |
49 |
50 | if __name__ == "__main__":
51 | filepath = os.path.join(os.path.dirname(os.path.dirname(__file__)), "content", "samples", "test.pdf")
52 | loader = UnstructuredPaddlePDFLoader(filepath, mode="elements")
53 | docs = loader.load()
54 | for doc in docs:
55 | print(doc)
56 |
--------------------------------------------------------------------------------
/loader/textsplitter/__init__.py:
--------------------------------------------------------------------------------
1 | from .chinese_text_splitter import ChineseTextSplitter
2 | from .ali_text_splitter import AliTextSplitter
--------------------------------------------------------------------------------
/loader/textsplitter/ali_text_splitter.py:
--------------------------------------------------------------------------------
1 | from langchain.text_splitter import CharacterTextSplitter
2 | import re
3 | from typing import List
4 |
5 |
6 | class AliTextSplitter(CharacterTextSplitter):
7 | def __init__(self, pdf: bool = False, **kwargs):
8 | super().__init__(**kwargs)
9 | self.pdf = pdf
10 |
11 | def split_text(self, text: str) -> List[str]:
12 | # use_document_segmentation参数指定是否用语义切分文档,此处采取的文档语义分割模型为达摩院开源的nlp_bert_document-segmentation_chinese-base,论文见https://arxiv.org/abs/2107.09278
13 | # 如果使用模型进行文档语义切分,那么需要安装modelscope[nlp]:pip install "modelscope[nlp]" -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html
14 | # 考虑到使用了三个模型,可能对于低配置gpu不太友好,因此这里将模型load进cpu计算,有需要的话可以替换device为自己的显卡id
15 | if self.pdf:
16 | text = re.sub(r"\n{3,}", r"\n", text)
17 | text = re.sub('\s', " ", text)
18 | text = re.sub("\n\n", "", text)
19 | from modelscope.pipelines import pipeline
20 |
21 | p = pipeline(
22 | task="document-segmentation",
23 | model='damo/nlp_bert_document-segmentation_chinese-base',
24 | device="cpu")
25 | result = p(documents=text)
26 | sent_list = [i for i in result["text"].split("\n\t") if i]
27 | return sent_list
28 |
--------------------------------------------------------------------------------
/loader/textsplitter/chinese_text_splitter.py:
--------------------------------------------------------------------------------
1 | from langchain.text_splitter import CharacterTextSplitter
2 | import re
3 | from typing import List
4 | from configs.common_config import SENTENCE_SIZE
5 |
6 |
7 | class ChineseTextSplitter(CharacterTextSplitter):
8 | def __init__(self, pdf: bool = False, sentence_size: int = SENTENCE_SIZE, **kwargs):
9 | super().__init__(**kwargs)
10 | self.pdf = pdf
11 | self.sentence_size = sentence_size
12 |
13 | def split_text1(self, text: str) -> List[str]:
14 | if self.pdf:
15 | text = re.sub(r"\n{3,}", "\n", text)
16 | text = re.sub('\s', ' ', text)
17 | text = text.replace("\n\n", "")
18 | sent_sep_pattern = re.compile('([﹒﹔﹖﹗.。!?]["’”」』]{0,2}|(?=["‘“「『]{1,2}|$))') # del :;
19 | sent_list = []
20 | for ele in sent_sep_pattern.split(text):
21 | if sent_sep_pattern.match(ele) and sent_list:
22 | sent_list[-1] += ele
23 | elif ele:
24 | sent_list.append(ele)
25 | return sent_list
26 |
27 | def split_text(self, text: str) -> List[str]: ##此处需要进一步优化逻辑
28 | if self.pdf:
29 | text = re.sub(r"\n{3,}", r"\n", text)
30 | text = re.sub('\s', " ", text)
31 | text = re.sub("\n\n", "", text)
32 |
33 | text = re.sub(r'([;;.!?。!?\?])([^”’])', r"\1\n\2", text) # 单字符断句符
34 | text = re.sub(r'(\.{6})([^"’”」』])', r"\1\n\2", text) # 英文省略号
35 | text = re.sub(r'(\…{2})([^"’”」』])', r"\1\n\2", text) # 中文省略号
36 | text = re.sub(r'([;;!?。!?\?]["’”」』]{0,2})([^;;!?,。!?\?])', r'\1\n\2', text)
37 | # 如果双引号前有终止符,那么双引号才是句子的终点,把分句符\n放到双引号后,注意前面的几句都小心保留了双引号
38 | text = text.rstrip() # 段尾如果有多余的\n就去掉它
39 | # 很多规则中会考虑分号;,但是这里我把它忽略不计,破折号、英文双引号等同样忽略,需要的再做些简单调整即可。
40 | ls = [i for i in text.split("\n") if i]
41 | for ele in ls:
42 | if len(ele) > self.sentence_size:
43 | ele1 = re.sub(r'([,,.]["’”」』]{0,2})([^,,.])', r'\1\n\2', ele)
44 | ele1_ls = ele1.split("\n")
45 | for ele_ele1 in ele1_ls:
46 | if len(ele_ele1) > self.sentence_size:
47 | ele_ele2 = re.sub(r'([\n]{1,}| {2,}["’”」』]{0,2})([^\s])', r'\1\n\2', ele_ele1)
48 | ele2_ls = ele_ele2.split("\n")
49 | for ele_ele2 in ele2_ls:
50 | if len(ele_ele2) > self.sentence_size:
51 | ele_ele3 = re.sub('( ["’”」』]{0,2})([^ ])', r'\1\n\2', ele_ele2)
52 | ele2_id = ele2_ls.index(ele_ele2)
53 | ele2_ls = ele2_ls[:ele2_id] + [i for i in ele_ele3.split("\n") if i] + ele2_ls[
54 | ele2_id + 1:]
55 | ele_id = ele1_ls.index(ele_ele1)
56 | ele1_ls = ele1_ls[:ele_id] + [i for i in ele2_ls if i] + ele1_ls[ele_id + 1:]
57 |
58 | id = ls.index(ele)
59 | ls = ls[:id] + [i for i in ele1_ls if i] + ls[id + 1:]
60 | return ls
61 |
--------------------------------------------------------------------------------
/loader/utils/__init__.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | def torch_gc():
4 | if torch.cuda.is_available():
5 | # with torch.cuda.device(DEVICE):
6 | torch.cuda.empty_cache()
7 | torch.cuda.ipc_collect()
8 | elif torch.backends.mps.is_available():
9 | try:
10 | from torch.mps import empty_cache
11 | empty_cache()
12 | except Exception as e:
13 | print(e)
14 | print("如果您使用的是 macOS 建议将 pytorch 版本升级至 2.0.0 或更高版本,以支持及时清理 torch 产生的内存占用。")
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | accelerate==0.18.0
2 | aiofiles==23.1.0
3 | aiohttp==3.8.4
4 | aiosignal==1.3.1
5 | altair==5.0.1
6 | antlr4-python3-runtime==4.9.3
7 | anyio==3.7.0
8 | argilla==1.11.0
9 | astor==0.8.1
10 | async-timeout==4.0.2
11 | attrdict==2.0.1
12 | attrs==23.1.0
13 | azure-core==1.27.1
14 | Babel==2.12.1
15 | backoff==2.2.1
16 | bce-python-sdk==0.8.83
17 | beautifulsoup4==4.12.2
18 | bitsandbytes==0.39.1
19 | blinker==1.6.2
20 | Brotli==1.0.9
21 | cachetools==5.3.1
22 | certifi==2023.5.7
23 | cffi==1.15.1
24 | chardet==5.1.0
25 | charset-normalizer==3.1.0
26 | click==8.1.3
27 | cmake==3.26.4
28 | coloredlogs==15.0.1
29 | commonmark==0.9.1
30 | contourpy==1.1.0
31 | cpm-kernels==1.0.11
32 | cryptography==41.0.1
33 | cssselect==1.2.0
34 | cssutils==2.7.1
35 | cycler==0.11.0
36 | Cython==0.29.35
37 | dataclasses-json==0.5.8
38 | decorator==5.1.1
39 | Deprecated==1.2.14
40 | dill==0.3.6
41 | effdet==0.4.1
42 | et-xmlfile==1.1.0
43 | exceptiongroup==1.1.1
44 | faiss-cpu==1.7.4
45 | fastapi==0.95.1
46 | ffmpy==0.3.0
47 | filelock==3.12.2
48 | filetype==1.2.0
49 | fire==0.5.0
50 | Flask==2.3.2
51 | flask-babel==3.1.0
52 | flatbuffers==23.5.26
53 | fonttools==4.40.0
54 | frozenlist==1.3.3
55 | fsspec==2023.6.0
56 | future==0.18.3
57 | gevent==22.10.2
58 | geventhttpclient==2.0.2
59 | gradio==3.28.3
60 | gradio_client==0.2.7
61 | greenlet==2.0.2
62 | grpcio==1.56.0
63 | h11==0.14.0
64 | httpcore==0.16.3
65 | httpx==0.23.3
66 | huggingface-hub==0.15.1
67 | humanfriendly==10.0
68 | icetk==0.0.7
69 | idna==3.4
70 | imageio==2.31.1
71 | imgaug==0.4.0
72 | importlib-metadata==6.7.0
73 | importlib-resources==5.12.0
74 | iopath==0.1.10
75 | itsdangerous==2.1.2
76 | Jinja2==3.1.2
77 | joblib==1.2.0
78 | jsonschema==4.17.3
79 | kiwisolver==1.4.4
80 | langchain==0.0.174
81 | layoutparser==0.3.4
82 | lazy_loader==0.2
83 | linkify-it-py==2.0.2
84 | lit==16.0.6
85 | lmdb==1.4.1
86 | lxml==4.9.2
87 | Markdown==3.4.3
88 | markdown-it-py==2.2.0
89 | MarkupSafe==2.1.3
90 | marshmallow==3.19.0
91 | marshmallow-enum==1.5.1
92 | matplotlib==3.7.1
93 | mdit-py-plugins==0.3.3
94 | mdurl==0.1.2
95 | monotonic==1.6
96 | mpmath==1.3.0
97 | msg-parser==1.2.0
98 | multidict==6.0.4
99 | multiprocess==0.70.14
100 | mypy-extensions==1.0.0
101 | networkx==3.1
102 | nltk==3.8.1
103 | numexpr==2.8.4
104 | numpy==1.23.5
105 | nvidia-cublas-cu11==11.10.3.66
106 | nvidia-cuda-cupti-cu11==11.7.101
107 | nvidia-cuda-nvrtc-cu11==11.7.99
108 | nvidia-cuda-runtime-cu11==11.7.99
109 | nvidia-cudnn-cu11==8.5.0.96
110 | nvidia-cufft-cu11==10.9.0.58
111 | nvidia-curand-cu11==10.2.10.91
112 | nvidia-cusolver-cu11==11.4.0.1
113 | nvidia-cusparse-cu11==11.7.4.91
114 | nvidia-nccl-cu11==2.14.3
115 | nvidia-nvtx-cu11==11.7.91
116 | olefile==0.46
117 | omegaconf==2.3.0
118 | onnx==1.12.0
119 | onnxruntime==1.15.1
120 | openapi-schema-pydantic==1.2.4
121 | opencv-contrib-python==4.6.0.66
122 | opencv-python==4.6.0.66
123 | openpyxl==3.1.2
124 | opt-einsum==3.3.0
125 | orjson==3.9.1
126 | packaging==23.1
127 | paddle-bfloat==0.1.7
128 | paddleocr==2.6.1.3
129 | paddlepaddle==2.4.2
130 | pandas==1.5.3
131 | pdf2docx==0.5.6
132 | pdf2image==1.16.3
133 | pdfminer.six==20221105
134 | pdfplumber==0.9.0
135 | peft==0.3.0
136 | Pillow==9.5.0
137 | portalocker==2.7.0
138 | premailer==3.10.0
139 | protobuf==3.18.3
140 | psutil==5.9.5
141 | pyclipper==1.3.0.post4
142 | pycocotools==2.0.6
143 | pycparser==2.21
144 | pycryptodome==3.18.0
145 | pydantic==1.10.9
146 | pydub==0.25.1
147 | Pygments==2.15.1
148 | PyMuPDF==1.20.2
149 | pynvml==11.5.0
150 | pypandoc==1.11
151 | pyparsing==3.1.0
152 | pypinyin==0.48.0
153 | pyrsistent==0.19.3
154 | pytesseract==0.3.10
155 | python-dateutil==2.8.2
156 | python-docx==0.8.11
157 | python-magic==0.4.27
158 | python-multipart==0.0.6
159 | python-pptx==0.6.21
160 | python-rapidjson==1.10
161 | pytz==2023.3
162 | PyWavelets==1.4.1
163 | PyYAML==6.0
164 | rapidfuzz==3.1.1
165 | rarfile==4.0
166 | regex==2023.6.3
167 | requests==2.28.2
168 | rfc3986==1.5.0
169 | rich==13.0.1
170 | ruamel.yaml==0.17.32
171 | ruamel.yaml.clib==0.2.7
172 | safetensors==0.3.1
173 | scikit-image==0.21.0
174 | scikit-learn==1.2.2
175 | scipy==1.11.0
176 | semantic-version==2.10.0
177 | sentence-transformers==2.2.2
178 | sentencepiece==0.1.99
179 | shapely==2.0.1
180 | six==1.16.0
181 | sniffio==1.3.0
182 | soupsieve==2.4.1
183 | SQLAlchemy==2.0.17
184 | starlette==0.26.1
185 | sympy==1.12
186 | tabulate==0.9.0
187 | tenacity==8.2.2
188 | termcolor==2.3.0
189 | threadpoolctl==3.1.0
190 | tifffile==2023.4.12
191 | timm==0.9.2
192 | tokenizers==0.13.3
193 | toolz==0.12.0
194 | torch==2.0.1
195 | torchvision==0.15.2
196 | tqdm==4.65.0
197 | transformers==4.29.1
198 | triton==2.0.0
199 | tritonclient==2.34.0
200 | typer==0.9.0
201 | typing-inspect==0.9.0
202 | typing_extensions==4.6.3
203 | tzdata==2023.3
204 | uc-micro-py==1.0.2
205 | unstructured==0.7.9
206 | unstructured-inference==0.5.1
207 | urllib3==1.26.16
208 | uvicorn==0.21.1
209 | visualdl==2.5.0
210 | Wand==0.6.11
211 | websockets==11.0.3
212 | Werkzeug==2.3.6
213 | wrapt==1.14.1
214 | x2paddle==1.4.1
215 | xlrd==2.0.1
216 | XlsxWriter==3.1.2
217 | yarl==1.9.2
218 | zipp==3.15.0
219 | zope.event==5.0
220 | zope.interface==6.0
221 |
--------------------------------------------------------------------------------
/vector_store/drug_kb/index.faiss:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JuneYaooo/medical_kb_chatbot/be4ff60c47ffb31ac3052f82b5ed607d7c7ab089/vector_store/drug_kb/index.faiss
--------------------------------------------------------------------------------
/vector_store/drug_kb/index.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JuneYaooo/medical_kb_chatbot/be4ff60c47ffb31ac3052f82b5ed607d7c7ab089/vector_store/drug_kb/index.pkl
--------------------------------------------------------------------------------