├── LICENSE
├── README.md
├── configs
│   ├── config.ini
│   └── params.py
├── data
│   └── .gitignore
├── doc_search.py
├── docker
│   └── Dockerfile
├── docs
│   ├── Docker部署Elasticsearch教程.md
│   └── demo_pic.png
├── embedding.py
├── model
│   ├── base.py
│   └── chatglm_llm.py
├── requirements.txt
├── utils
│   └── es_tool.py
└── web.py

/LICENSE:
--------------------------------------------------------------------------------
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/

TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

1. Definitions.

"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.

"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.

"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.

"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.

"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.

"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.

"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).

"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.

"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner.
For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."

"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.

2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.

3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.

4. Redistribution.
You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:

(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and

(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and

(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and

(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.

You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.

5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.

6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.

7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE.
You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.

8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.

9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.

END OF TERMS AND CONDITIONS

APPENDIX: How to apply the Apache License to your work.

To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.

Copyright [yyyy] [name of copyright owner]

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# 🔥ElasticSearch-Langchain-Chatglm2

# ✨Project Introduction

Inspired by the [langchain-ChatGLM](https://github.com/imClumsyPanda/langchain-ChatGLM) project. Because Elasticsearch supports hybrid retrieval over both full text and vectors, and is far more common in production systems, this project replaces Faiss with Elasticsearch as the knowledge store and uses Langchain + Chatglm2 to build question answering over your own knowledge base.

The project is intended as a starting point to help you quickly validate the approach and choose a technical direction.

The default embedding model is [moka-ai/m3e-large](https://huggingface.co/moka-ai/m3e-large).

Currently only text-format files such as txt, docx, and md can be uploaded.

Text similarity is computed with cosine similarity by default.

# 🚀Usage

### Edit the configuration file

Edit [config.ini](https://github.com/iMagist486/ElasticSearch-Langchain-Chatglm2/blob/main/configs/config.ini) to configure the Elasticsearch connection.

Model names can be replaced with local paths.

**Support for [InternLM](https://github.com/InternLM/InternLM) has been added:** simply set `llm_model` to `internlm/internlm-chat-7b`.

### Run the web demo

Run [web.py](https://github.com/iMagist486/ElasticSearch-Langchain-Chatglm2/blob/main/web.py):

```sh
python web.py
```

# 📑Demo Walkthrough

![demo_pic](docs/demo_pic.png)

### Document interaction panel:

When documents are inserted into ES, this panel reports whether the insertion succeeded or shows the raised exception; during question answering it displays the retrieved passages, including document source, content, and similarity score.

### Query settings panel:

**Three search modes**; see the official Elasticsearch documentation for the exact differences:

近似查询 (approximate): [Approximate kNN](https://www.elastic.co/guide/en/elasticsearch/reference/current/knn-search.html#approximate-knn)

混合查询 (hybrid): [Combine approximate kNN with other features](https://www.elastic.co/guide/en/elasticsearch/reference/current/knn-search.html#_combine_approximate_knn_with_other_features)

精确查询 (exact): [Exact, brute-force kNN](https://www.elastic.co/guide/en/elasticsearch/reference/current/knn-search.html#exact-knn)

**Score threshold**:

Only results whose similarity score exceeds the threshold are returned; 0 disables the filter.

**top_k**:

Return the k most relevant passages.

**knn_boost**:

Hybrid search only; the weight given to the kNN score.

# 🐳Docker Deployment

Build the Docker image:

```sh
docker build -f docker/Dockerfile -t es-chatglm:v1.0 .
```

Start the container:

```sh
docker run --gpus "device=0" -p 8000:8000 -it es-chatglm:v1.0 bash
```

# ❤️Citations and Acknowledgements

1. [THUDM/chatglm2-6b](https://huggingface.co/THUDM/chatglm2-6b)
2. [moka-ai/m3e-large](https://huggingface.co/moka-ai/m3e-large)
3. [LangChain](https://github.com/hwchase17/langchain)
4. [langchain-ChatGLM](https://github.com/imClumsyPanda/langchain-ChatGLM)

# 📧Contact

wzh486@outlook.com

Questions and discussion are welcome!
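
# 🧪Example: Programmatic Retrieval

The retrieval layer can also be used without the web UI. Below is a minimal sketch using the `ES` class from `doc_search.py`; it assumes `configs/config.ini` has been filled in and the index already contains documents.

```python
from configs.params import ModelParams
from doc_search import ES

model_config = ModelParams()
es = ES(model_config.embedding_model)

# method must be one of the UI labels: "近似查询" (approximate),
# "混合查询" (hybrid) or "精确查询" (exact); knn_boost only affects hybrid search.
results = es.doc_search(method="混合查询", query="你的问题", top_k=3, knn_boost=0.5)
for hit in results:
    print(hit["source"], hit["score"])
    print(hit["content"])
```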
--------------------------------------------------------------------------------
/configs/config.ini:
--------------------------------------------------------------------------------
[model_configs]
# Embedding model name or path
embedding_model = moka-ai/m3e-large
# LLM model name or path
llm_model = THUDM/chatglm2-6b

[es_configs]
username = elastic
passwd = your ES password
url = your ES ip or url
port = 9200
index_name = test
--------------------------------------------------------------------------------
/configs/params.py:
--------------------------------------------------------------------------------
# -*- coding: UTF-8 -*-
from configparser import ConfigParser


class BaseParams(object):
    """
    Base class for all parameter groups.
    """

    def __init__(self, conf_fp: str = 'configs/config.ini'):
        self.config = ConfigParser()
        self.config.read(conf_fp, encoding='utf8')


class ModelParams(BaseParams):
    """
    Model configuration parameters.
    """

    def __init__(self, conf_fp: str = 'configs/config.ini'):
        super(ModelParams, self).__init__(conf_fp)
        section_name = 'model_configs'
        self.embedding_model = self.config.get(section_name, 'embedding_model')
        self.llm_model = self.config.get(section_name, 'llm_model')


class ESParams(BaseParams):
    """
    Elasticsearch connection parameters.
    """

    def __init__(self, conf_fp: str = 'configs/config.ini'):
        super(ESParams, self).__init__(conf_fp)
        section_name = 'es_configs'
        self.username = self.config.get(section_name, 'username')
        self.passwd = self.config.get(section_name, 'passwd')
        self.url = self.config.get(section_name, 'url')
        self.port = self.config.get(section_name, 'port')
        self.index_name = self.config.get(section_name, 'index_name')
--------------------------------------------------------------------------------
/data/.gitignore:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/iMagist486/ElasticSearch-Langchain-Chatglm2/ef229c3c7eb3cd9087853e1a5edc785a16bed748/data/.gitignore
--------------------------------------------------------------------------------
/doc_search.py:
--------------------------------------------------------------------------------
import os
import shutil
from typing import Dict

from elasticsearch import Elasticsearch
from langchain.vectorstores import ElasticKnnSearch
from langchain.document_loaders import TextLoader
from langchain.text_splitter import CharacterTextSplitter

from configs.params import ESParams
from embedding import Embeddings


def _default_knn_mapping(dims: int) -> Dict:
    """Generates a default index mapping for kNN search."""
    return {
        "properties": {
            "text": {"type": "text"},
            "vector": {
                "type": "dense_vector",
                "dims": dims,
                "index": True,
                "similarity": "cosine",
            },
        }
    }


def generate_search_query(vec, size) -> Dict:
    """Exact, brute-force kNN: score every document with cosine similarity."""
    query = {
        "query": {
            "script_score": {
                "query": {
                    "match_all": {}
                },
                "script": {
                    "source": "cosineSimilarity(params.queryVector, 'vector') + 1.0",
                    "params": {
                        "queryVector": vec
                    }
                }
            }
        },
        "size": size
    }
    return query


def generate_knn_query(vec, size) -> Dict:
    """Approximate kNN search against the dense_vector field."""
    query = {
        "knn": {
            "field": "vector",
            "query_vector": vec,
            "k": 10,
            "num_candidates": 100
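            # Note: "k" is the number of nearest neighbours gathered and
            # "num_candidates" the per-shard candidate pool; "size" below caps
            # the hits actually returned, so top_k never exceeds k here.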
        },
        "size": size
    }
    return query


def generate_hybrid_query(text, vec, size, knn_boost) -> Dict:
    """Hybrid search: blend a full-text match (weight 1 - knn_boost) with
    approximate kNN (weight knn_boost)."""
    query = {
        "query": {
            "match": {
                "text": {
                    "query": text,
                    "boost": 1 - knn_boost
                }
            }
        },
        "knn": {
            "field": "vector",
            "query_vector": vec,
            "k": 10,
            "num_candidates": 100,
            "boost": knn_boost
        },
        "size": size
    }
    return query


def load_file(filepath, chunk_size, chunk_overlap):
    """Load a text file and split it into overlapping chunks."""
    loader = TextLoader(filepath, encoding='utf-8')
    documents = loader.load()
    text_splitter = CharacterTextSplitter(separator='\n', chunk_size=chunk_size, chunk_overlap=chunk_overlap)
    docs = text_splitter.split_documents(documents)
    return docs


class ES:
    def __init__(self, embedding_model_path):
        self.es_params = ESParams()
        self.client = Elasticsearch(['{}:{}'.format(self.es_params.url, self.es_params.port)],
                                    basic_auth=(self.es_params.username, self.es_params.passwd),
                                    verify_certs=False)
        self.embedding = Embeddings(embedding_model_path)
        self.es = ElasticKnnSearch(index_name=self.es_params.index_name, embedding=self.embedding,
                                   es_connection=self.client)

    def doc_upload(self, file_obj, chunk_size, chunk_overlap):
        """Move an uploaded file into data/, split it, and index the chunks."""
        try:
            if not self.client.indices.exists(index=self.es_params.index_name):
                # Create the index lazily; embed a probe string to get the vector dims.
                dims = len(self.embedding.embed_query("test"))
                mapping = _default_knn_mapping(dims)
                self.client.indices.create(index=self.es_params.index_name, body={"mappings": mapping})
            filename = os.path.split(file_obj.name)[-1]
            file_path = 'data/' + filename
            shutil.move(file_obj.name, file_path)
            docs = load_file(file_path, chunk_size, chunk_overlap)
            self.es.add_documents(docs)
            return "插入成功"  # "insertion succeeded"
        except Exception as e:
            # Return a string so the Gradio textbox can render the error.
            return repr(e)

    def doc_search(self, method, query, top_k, knn_boost):
        result = []
        query_vector = self.embedding.embed_query(query)
        if method == "近似查询":  # approximate kNN
            query_body = generate_knn_query(vec=query_vector, size=top_k)
        elif method == "混合查询":  # hybrid
            query_body = generate_hybrid_query(text=query, vec=query_vector, size=top_k, knn_boost=knn_boost)
        else:  # exact, brute-force kNN
            query_body = generate_search_query(vec=query_vector, size=top_k)
        response = self.client.search(index=self.es_params.index_name, body=query_body)
        hits = response["hits"]["hits"]
        for i in hits:
            result.append({
                'content': i['_source']['text'],
                'source': i['_source']['metadata']['source'],
                'score': i['_score']
            })
        return result
--------------------------------------------------------------------------------
/docker/Dockerfile:
--------------------------------------------------------------------------------
FROM pytorch/pytorch:2.0.1-cuda11.7-cudnn8-runtime
LABEL maintainer="wenzehua"

RUN ln -sf /usr/share/zoneinfo/Asia/Shanghai /etc/localtime && echo 'Asia/Shanghai' >/etc/timezone
RUN pip config set global.index-url https://mirror.sjtu.edu.cn/pypi/web/simple

RUN apt-get update && apt-get install -y vim

COPY ./requirements.txt /requirements.txt
RUN pip install --upgrade pip
RUN pip install -r /requirements.txt
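
# No CMD/ENTRYPOINT is set: per the README the container is started with
# `docker run ... bash`, and the demo is then launched manually with
# `python web.py`, which serves on the port exposed below.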
COPY . /data/app
WORKDIR /data/app

EXPOSE 8000
--------------------------------------------------------------------------------
/docs/Docker部署Elasticsearch教程.md:
--------------------------------------------------------------------------------
# Deploying Elasticsearch with Docker

### 1. Pull the Elasticsearch and Kibana images

```sh
docker pull elastic/elasticsearch:8.8.2
docker pull elastic/kibana:8.8.2
```

### 2. Set max_map_count

```sh
cat /proc/sys/vm/max_map_count
sysctl -w vm.max_map_count=262144
```

### 3. Create a Docker network for Elasticsearch and Kibana

```sh
docker network create elastic
```

### 4. Create the bind-mount directories and open their permissions

```sh
mkdir /data/es
chmod 777 -R /data/es/
```

### 5. Start the Elasticsearch container

```sh
docker run --name es01 \
--net elastic -p 9200:9200 \
-v /data/es/data:/usr/share/elasticsearch/data \
-v /data/es/logs:/usr/share/elasticsearch/logs \
-v /data/es/plugins:/usr/share/elasticsearch/plugins \
-it elastic/elasticsearch:8.8.2
```

On first start this prints a password for the elastic account and an enrollment token for Kibana.

### 6. Start the Kibana container

```sh
docker run --name kib-01 --net elastic -p 5601:5601 elastic/kibana:8.8.2
```

### Other useful commands

##### Copy out the CA certificate

```sh
docker cp es01:/usr/share/elasticsearch/config/certs/http_ca.crt .
```

##### Reset the elastic password

```sh
docker exec -it es01 /usr/share/elasticsearch/bin/elasticsearch-reset-password -u elastic
```

##### Regenerate the Kibana enrollment token

```sh
docker exec -it es01 /usr/share/elasticsearch/bin/elasticsearch-create-enrollment-token -s kibana
```
--------------------------------------------------------------------------------
/docs/demo_pic.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/iMagist486/ElasticSearch-Langchain-Chatglm2/ef229c3c7eb3cd9087853e1a5edc785a16bed748/docs/demo_pic.png
--------------------------------------------------------------------------------
/embedding.py:
--------------------------------------------------------------------------------
from sentence_transformers import SentenceTransformer


class Embeddings:
    """Thin SentenceTransformer wrapper matching LangChain's embeddings interface."""

    def __init__(self, model_path):
        self.model = SentenceTransformer(model_path)

    def embed_documents(self, text_list):
        embeddings = self.model.encode(text_list)
        encod_list = embeddings.tolist()
        return encod_list

    def embed_query(self, text):
        embeddings = self.model.encode([text])
        encod_list = embeddings.tolist()
        return encod_list[0]
--------------------------------------------------------------------------------
/model/base.py:
--------------------------------------------------------------------------------
from abc import ABC, abstractmethod
from typing import Optional, List


class AnswerResult:
    """
    Answer message entity.
    """
    history: List[List[str]] = []
    llm_output: Optional[dict] = None


class BaseAnswer(ABC):
    """Upper-layer business wrapper: a unified API for answer generation."""

    @property
    @abstractmethod
    def _history_len(self) -> int:
        """Return _history_len of llm."""

    @abstractmethod
    def set_history_len(self, history_len: int) -> None:
        """Set _history_len of llm."""

    def generatorAnswer(self, prompt: str,
                        history: List[List[str]] = [],
                        streaming: bool = False):
        pass
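

# Illustrative sketch (not an original part of this file): a hypothetical
# minimal subclass showing the contract concrete wrappers are expected to
# follow; model/chatglm_llm.py is the real implementation. "EchoLLM" is made up.
class EchoLLM(BaseAnswer):
    def __init__(self):
        self._len = 10

    @property
    def _history_len(self) -> int:
        return self._len

    def set_history_len(self, history_len: int) -> None:
        self._len = history_len

    def generatorAnswer(self, prompt: str,
                        history: List[List[str]] = [],
                        streaming: bool = False):
        # Echo the prompt back and append the turn to the history, yielding
        # AnswerResult objects exactly as web.py consumes them.
        answer_result = AnswerResult()
        answer_result.history = history + [[prompt, prompt]]
        answer_result.llm_output = {"answer": prompt}
        yield answer_result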
--------------------------------------------------------------------------------
/model/chatglm_llm.py:
--------------------------------------------------------------------------------
from typing import Optional, List

import torch
from transformers import AutoModel, AutoTokenizer, AutoModelForCausalLM
from langchain.llms.base import LLM

from model.base import AnswerResult
from configs.params import ModelParams

model_config = ModelParams()


class ChatLLM(LLM):
    max_token: int = 8192
    temperature: float = 0.95
    top_p = 0.8
    history_len = 10
    history = []
    model_type: str = "ChatGLM"
    model_path: str = model_config.llm_model
    tokenizer: object = None
    model: object = None

    def __init__(self):
        super().__init__()

    @property
    def _llm_type(self) -> str:
        return "ChatLLM"

    def load_llm(self):
        """Load either InternLM (CausalLM) or ChatGLM, depending on model_path."""
        if 'internlm' in self.model_path.lower():
            self.tokenizer = AutoTokenizer.from_pretrained(self.model_path, device_map="auto", trust_remote_code=True,
                                                           torch_dtype=torch.float16)
            self.model = AutoModelForCausalLM.from_pretrained(self.model_path, device_map="auto",
                                                              trust_remote_code=True,
                                                              torch_dtype=torch.float16)
            self.model = self.model.eval()
            self.model_type = "InternLM"
        else:
            self.tokenizer = AutoTokenizer.from_pretrained(self.model_path, trust_remote_code=True)
            self.model = AutoModel.from_pretrained(self.model_path, trust_remote_code=True).cuda()
            self.model = self.model.eval()

    def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
        print(f"__call:{prompt}")
        response, _ = self.model.chat(
            self.tokenizer,
            prompt,
            history=[],
            max_length=self.max_token,
            temperature=self.temperature,
            top_p=self.top_p
        )
        print(f"response:{response}")
        print("+++++++++++++++++++++++++++++++++++")
        return response

    def generatorAnswer(self, prompt: str,
                        history: List[List[str]] = [],
                        streaming: bool = False):

        if streaming:
            history += [[]]
            if self.model_type == "InternLM":
                # InternLM's stream_chat takes max_new_tokens instead of max_length.
                response = self.model.stream_chat(
                    self.tokenizer,
                    prompt,
                    history=history[-self.history_len:-1] if self.history_len > 1 else [],
                    max_new_tokens=self.max_token,
                    temperature=self.temperature,
                    top_p=self.top_p
                )
            else:
                response = self.model.stream_chat(
                    self.tokenizer,
                    prompt,
                    history=history[-self.history_len:-1] if self.history_len > 1 else [],
                    max_length=self.max_token,
                    temperature=self.temperature,
                    top_p=self.top_p
                )
            for stream_resp, _ in response:
                history[-1] = [prompt, stream_resp]
                answer_result = AnswerResult()
                answer_result.history = history
                answer_result.llm_output = {"answer": stream_resp}
                yield answer_result
        else:
            response, _ = self.model.chat(
                self.tokenizer,
                prompt,
                history=history[-self.history_len:] if self.history_len > 0 else [],
                max_length=self.max_token,
                temperature=self.temperature,
                top_p=self.top_p
            )
            torch.cuda.empty_cache()  # release cached GPU memory between turns
            history += [[prompt, response]]
            answer_result = AnswerResult()
            answer_result.history = history
            answer_result.llm_output = {"answer": response}
            yield answer_result
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
transformers >= 4.30.2
sentence_transformers >= 2.2.2
langchain >= 0.0.294
elasticsearch >= 8.8.0
gradio >= 3.36.1
--------------------------------------------------------------------------------
/utils/es_tool.py:
--------------------------------------------------------------------------------
from configs.params import ESParams
from elasticsearch import Elasticsearch

es_params = ESParams()
index_name = es_params.index_name

# Example payloads for the snippets below; replace them with real values.
mapping = {"mappings": {"properties": {"text": {"type": "text"}}}}
document_id = "1"
data = {"text": "hello elasticsearch"}

# %% Initialise the ES client
client = Elasticsearch(['{}:{}'.format(es_params.url, es_params.port)],
                       basic_auth=(es_params.username, es_params.passwd),
                       verify_certs=False)

# %% Connectivity test
client.ping()

# %% Check whether the index exists
index_exists = client.indices.exists(index=index_name)

# %% Create an index
response = client.indices.create(index=index_name, body=mapping)

# %% Insert a document
response = client.index(index=index_name, id=document_id, document=data)

# %% Update a document
rp = client.update(index=index_name, id=document_id, body={"doc": data})

# %% Check whether a document exists
document_exists = client.exists(index=index_name, id=document_id)

# %% Delete a document by ID
response = client.delete(index=index_name, id=document_id)
--------------------------------------------------------------------------------
/web.py:
--------------------------------------------------------------------------------
import re

import gradio as gr

from doc_search import ES
from model.chatglm_llm import ChatLLM
from configs.params import ModelParams

# Chinese RAG prompt, roughly: "Known information: {context}. Based on the known
# information above, answer the user's question concisely and professionally.
# If no answer can be derived, say so rather than making one up; answer in
# Chinese. The question is: {question}"
PROMPT_TEMPLATE = """已知信息:
{context}

根据上述已知信息，简洁和专业的来回答用户的问题。如果无法从中得到答案，请说 “根据已知信息无法回答该问题” 或 “没有提供足够的相关信息”，不允许在答案中添加编造成分，答案请使用中文。 问题是:{question}"""

model_config = ModelParams()
es = ES(model_config.embedding_model)
llm = ChatLLM()
llm.load_llm()


def clear_session():
    # Reset the chatbot display, the history state, and the document panel.
    return [], [], ''


def search_doc(question, search_method, top_k, knn_boost, threshold):
    res = es.doc_search(method=search_method, query=question, top_k=top_k, knn_boost=knn_boost)
    if threshold > 0:
        result = [i for i in res if i['score'] > threshold]
    else:
        result = res
    return result


def doc_format(doc_list):
    result = ''
    for i in doc_list:
        source = re.sub('data/', '', i['source'])
        result += f"source: {source}\nscore: {i['score']}\ncontent: {i['content']}\n"
    return result


def predict(question, search_method, top_k, max_token, temperature, top_p, knn_boost, history, history_length,
            threshold):
    llm.max_token = max_token
    llm.temperature = temperature
    llm.top_p = top_p
    llm.history_len = history_length
    search_res = search_doc(question, search_method, top_k, knn_boost, threshold)
    search_result = doc_format(search_res)

    # Build the prompt from the retrieved passages and stream the answer.
    informed_context = ''
    for i in search_res:
        informed_context += i['content'] + '\n'
    prompt = PROMPT_TEMPLATE.replace("{question}", question).replace("{context}", informed_context)
    for answer_result in llm.generatorAnswer(prompt=prompt, history=history, streaming=True):
        history = answer_result.history
        history[-1][0] = question
        yield history, history, search_result, ""


if __name__ == "__main__":
    title = """
# Elasticsearch + ChatGLM demo
[https://github.com/iMagist486/ElasticSearch-Langchain-Chatglm2](https://github.com/iMagist486/ElasticSearch-Langchain-Chatglm2)
"""
    with gr.Blocks() as demo:
        gr.Markdown(title)

        with gr.Row():
            with gr.Column(scale=2):
                chatbot = gr.Chatbot()
                user_input = gr.Textbox(show_label=False, placeholder="Input...", lines=4, container=False)
                with gr.Row():
                    submitBtn = gr.Button("Submit", variant="primary")
                    emptyBtn = gr.Button("Clear History")
            search_out = gr.Textbox(label="文档交互", lines=25, max_lines=25, interactive=False, scale=1)

        with gr.Row(variant='compact'):
            with gr.Column():
                gr.Markdown("""LLM设置""")
                max_length = gr.Slider(0, 32768, value=8192, step=1.0, label="Maximum length", interactive=True)
                top_p = gr.Slider(0, 1, value=0.8, step=0.01, label="Top P", interactive=True)
                temperature = gr.Slider(0, 1, value=0.01, step=0.01, label="Temperature", interactive=True)
                history_length = gr.Slider(0, 10, value=3, step=1, label="history_length", interactive=True)

            with gr.Column():
                gr.Markdown("""查询设置""")
                search_method = gr.Radio(['近似查询', '混合查询', '精确查询'],
                                         value='精确查询',
                                         label="Search Method")
                threshold = gr.Number(label="查询阈值(0为不设限)", value=0.00, interactive=True)
                top_k = gr.Slider(0, 10, value=3, step=1.0, label="top_k", interactive=True)
                knn_boost = gr.Slider(0, 1, value=0.5, step=0.1, label="knn_boost", interactive=True)

            with gr.Column():
                gr.Markdown("""知识库管理""")
                file = gr.File(label='请上传知识库文件', file_types=['.txt', '.md', '.doc', '.docx'])
                chunk_size = gr.Number(label="chunk_size", value=300, interactive=True)
                chunk_overlap = gr.Number(label="chunk_overlap", value=10, interactive=True)
                doc_upload = gr.Button("ES存储")

        history = gr.State([])

        submitBtn.click(predict,
                        inputs=[user_input, search_method, top_k, max_length, temperature, top_p, knn_boost, history,
                                history_length, threshold],
                        outputs=[chatbot, history, search_out, user_input]
                        )
        doc_upload.click(
            fn=es.doc_upload,
            show_progress=True,
            inputs=[file, chunk_size, chunk_overlap],
            outputs=[search_out],
        )

        emptyBtn.click(fn=clear_session, inputs=[], outputs=[chatbot, history, search_out], queue=False)

    demo.queue().launch(share=False, inbrowser=True, server_name="0.0.0.0", server_port=8000)
--------------------------------------------------------------------------------