├── README.md
├── __pycache__
│   └── cli.cpython-310.pyc
├── cli.py
├── docs
│   └── content.txt
├── helloworld.py
├── models
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── __init__.cpython-310.pyc
│   │   ├── chinese_text_splitter.cpython-310.pyc
│   │   ├── config.cpython-310.pyc
│   │   ├── custom_agent.cpython-310.pyc
│   │   ├── custom_llm.cpython-310.pyc
│   │   ├── custom_search.cpython-310.pyc
│   │   └── util.cpython-310.pyc
│   ├── chinese_text_splitter.py
│   ├── config.py
│   ├── custom_agent.py
│   ├── custom_llm.py
│   ├── custom_search.py
│   └── util.py
├── server.py
└── vector_store
    └── FAISS
        ├── index.faiss
        └── index.pkl

/README.md:
--------------------------------------------------------------------------------
# Local Knowledge Base LLM: langchain + chatglm + Custom Agent

The agents you get from langchain + chatgpt are excellent, so this project implements a simple custom agent built on chatglm instead. It supports a local knowledge base and web search. The LLM is custom as well, and you can swap in OpenAI if you prefer.

A detailed introduction is available here:

## Introduction

langchain's agent design is very clever, but that cleverness leans on chatgpt's strong comprehension. The prompts that ship with the agent are not understood well by chatglm-6b — the Action and Input fields in particular come out wrong again and again — so I wrote a simple custom agent tailored to chatglm. It cannot follow prompt instructions with 100% precision; in my tests it responds correctly about 80% of the time, which, combined with Tools, is enough to build some fairly complex applications.

## Hardware Requirements

The large model is not run locally here, so there are essentially no hardware constraints.

## Setup

1. Deploy your chatglm model and make sure it can be called over an API, then change the request URL inside the `_call` method of `models/custom_llm.py`. Alternatively, substitute OpenAI — see the [langchain documentation](https://python.langchain.com/) and the sketch below this list.
2. Set your `RapidAPIKey = ""` in `models/custom_search.py`. [The sign-up steps are here](https://rapidapi.com/microsoft-azure-org-microsoft-cognitive-services/api/bing-web-search1) (you can find the API by searching for `bing-web-search1`).
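If you would rather use OpenAI than a self-hosted chatglm, the swap is a one-line change where `CustomLLM()` is constructed. A minimal sketch, not part of the repo — it assumes the `openai` package is installed, an `OPENAI_API_KEY` is set in your environment, and a langchain version where this import path exists:

```
# Hypothetical drop-in replacement for CustomLLM in helloworld.py / cli.py.
from langchain.llms import OpenAI

llm = OpenAI(temperature=0)  # instead of: llm = CustomLLM()
```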
After checking out the code, run:

```
python helloworld.py
```

If you get a sensible response back, you are good to go.

## Starting the Server

```
python server.py
```

This starts a service listening on port 8899, which you can call like this:

```
curl -d "ask=helloworld1" \
-H "Content-Type: application/x-www-form-urlencoded" \
-X POST http://127.0.0.1:8899/ai/langchain/
```

It returns:

```
{"content":"\u60a8\u9700\u8981\u9884\u5b9a\u673a\u7968\u5417?","status":200}
```

Notes:

- Don't put spaces in the parameters passed on the curl command line; for real testing, a tool like Postman — or the Python client sketched below — is a better choice
- The server runs on Flask. If other machines need to reach it, change the startup call in server.py to bind your local host: `app.run(debug=False, port=8899, host="192.168.0.11")`
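A small Python client sidesteps the curl quoting issues entirely. A minimal sketch, assuming the server is running locally on port 8899 with the route shown in server.py:

```
# Hypothetical client for the /ai/langchain/ endpoint; not part of the repo.
import requests

resp = requests.post(
    "http://127.0.0.1:8899/ai/langchain/",
    data={"ask": "携程最近有什么大新闻?"},  # form-encoded, matching request.form on the server
)
print(resp.json()["content"])
```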
## Example

Here is a real run. Execute `python cli.py`:

```
> Entering new AgentExecutor chain...
DeepSearch('携程最近有什么大新闻?')

Observation:到中国旅游更偏爱小红书 马来西亚年轻人对携程没兴趣: 【蓝科技观察】小红书成为马来西亚年轻人了解中国旅游市场的首选,而不是中国旅游平台携程。在马来西亚年轻人尤其是华人看来,携程的商业属性更明显,而小红书则能给他们带来更多的价值。 最近,马来西亚第四代华人Emma计划来中国旅游,而她是通过小红书来了解中国的旅游、文化、时尚等信息。 “按照我的思路,获得我想要的结果。”Emma表示,如她一样喜欢旅游的年轻人在马来西亚比比皆是,他们在小红书上follow热
携程:每10人有1人游览博物馆 00后文博爱好者增速最快: 随着“5.18国际博物馆日”临近,根据博物馆预订人次,结合线上搜索热度及馆藏数量,携程口碑榜发布了“国内博物馆20佳”,分别是:故宫、中国 ...
十人就有一人逛博物馆 携程口碑榜发布“国内博物馆20佳”: 近期,携程还与浦东美术馆合作,正式开启人工讲解服务,包含快速入场通道、配套耳麦设备、专人接待讲解等多项特别礼宾服务。除了为观众深入讲述馆内展览与重要展品背后的故事,还将全方位介绍美术馆建筑设计及功能空间的独到之处,让艺术不再“有 ...
出国热又起,各大领馆一约难求,堪比春节抢票?!有的地方拒签率 ...: 上海新闻广播微信公众号消息,积压了整整三年的旅游热情,终于在今年有了可“用武之地”,被戏称为“最强五一”的旅游 ...
根据已知信息,携程最近在博物馆预订方面取得了显著进展,每10人有1人游览博物馆,00后文博爱好者增速最快。此外,携程还与浦东美术馆合作,开启了人工讲解服务,并提升了博物馆服务质量。这些新闻表明携程在旅游业领域继续扩大其影响力,并不断提升服务质量,以满足消费者的需求。

> Finished chain.
根据已知信息,携程最近在博物馆预订方面取得了显著进展,每10人有1人游览博物馆,00后文博爱好者增速最快。此外,携程还与浦东美术馆合作,开启了人工讲解服务,并提升了博物馆服务质量。这些新闻表明携程在旅游业领域继续扩大其影响力,并不断提升服务质量,以满足消费者的需求。
```

--------------------------------------------------------------------------------
/__pycache__/cli.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jayli/langchain-GLM_Agent/3d31bf8da1fca1f772be8b77a2e3369790a2d0db/__pycache__/cli.cpython-310.pyc
--------------------------------------------------------------------------------
/cli.py:
--------------------------------------------------------------------------------
import torch.cuda
import torch.backends
from langchain.embeddings.huggingface import HuggingFaceEmbeddings
from models.custom_agent import DeepAgent
from models.util import LocalDocQA
from models.config import *

# Pick the best available device for the embedding model.
EMBEDDING_DEVICE = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"

deep_agent = DeepAgent()

embeddings = HuggingFaceEmbeddings(model_name="GanymedeNil/text2vec-base-chinese",
                                   model_kwargs={'device': EMBEDDING_DEVICE})

# Build (or extend) the local FAISS knowledge base at import time.
qa_doc = LocalDocQA(filepath=LOCAL_CONTENT,
                    vs_path=VS_PATH,
                    embeddings=embeddings,
                    init=True)

def answer(query: str = ""):
    """Retrieve related local knowledge, then let the agent answer."""
    related_content = qa_doc.query_knowledge(query=query)
    formed_related_content = "\n" + related_content
    return deep_agent.query(related_content=formed_related_content, query=query)

if __name__ == "__main__":
    print(answer("携程最近有什么大新闻?"))
--------------------------------------------------------------------------------
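cli.py exposes `answer()`, which server.py imports. For quick local experiments you could wrap it in a trivial loop — a hypothetical convenience, not part of the repo; note that importing cli builds or extends the FAISS index at import time:

```
# Hypothetical interactive loop around cli.answer() for local testing.
from cli import answer

while True:
    question = input("> ")
    if not question.strip():
        break
    print(answer(question))
```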
/docs/content.txt:
--------------------------------------------------------------------------------
Madam Speaker, Madam Vice President, our First Lady and Second Gentleman. Members of Congress and the Cabinet. Justices of the Supreme Court. My fellow Americans.
--------------------------------------------------------------------------------
/helloworld.py:
--------------------------------------------------------------------------------
from langchain.prompts import PromptTemplate
from langchain.chains import LLMChain
from models.custom_llm import CustomLLM

llm = CustomLLM()

prompt = PromptTemplate(
    input_variables=["product"],
    template="请给一个制造{product}的公司起一个好听的名字",
)

chain = LLMChain(llm=llm, prompt=prompt)
print(chain.run("彩虹色的袜子"))
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jayli/langchain-GLM_Agent/3d31bf8da1fca1f772be8b77a2e3369790a2d0db/models/__init__.py
--------------------------------------------------------------------------------
/models/__pycache__/__init__.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jayli/langchain-GLM_Agent/3d31bf8da1fca1f772be8b77a2e3369790a2d0db/models/__pycache__/__init__.cpython-310.pyc
--------------------------------------------------------------------------------
/models/__pycache__/chinese_text_splitter.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jayli/langchain-GLM_Agent/3d31bf8da1fca1f772be8b77a2e3369790a2d0db/models/__pycache__/chinese_text_splitter.cpython-310.pyc
--------------------------------------------------------------------------------
/models/__pycache__/config.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jayli/langchain-GLM_Agent/3d31bf8da1fca1f772be8b77a2e3369790a2d0db/models/__pycache__/config.cpython-310.pyc
--------------------------------------------------------------------------------
/models/__pycache__/custom_agent.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jayli/langchain-GLM_Agent/3d31bf8da1fca1f772be8b77a2e3369790a2d0db/models/__pycache__/custom_agent.cpython-310.pyc
--------------------------------------------------------------------------------
/models/__pycache__/custom_llm.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jayli/langchain-GLM_Agent/3d31bf8da1fca1f772be8b77a2e3369790a2d0db/models/__pycache__/custom_llm.cpython-310.pyc
--------------------------------------------------------------------------------
/models/__pycache__/custom_search.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jayli/langchain-GLM_Agent/3d31bf8da1fca1f772be8b77a2e3369790a2d0db/models/__pycache__/custom_search.cpython-310.pyc
--------------------------------------------------------------------------------
/models/__pycache__/util.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jayli/langchain-GLM_Agent/3d31bf8da1fca1f772be8b77a2e3369790a2d0db/models/__pycache__/util.cpython-310.pyc
--------------------------------------------------------------------------------
/models/chinese_text_splitter.py:
--------------------------------------------------------------------------------
import re
from typing import List

from langchain.text_splitter import CharacterTextSplitter


class ChineseTextSplitter(CharacterTextSplitter):
    def __init__(self, pdf: bool = False, **kwargs):
        super().__init__(**kwargs)
        self.pdf = pdf

    def split_text(self, text: str) -> List[str]:
        if self.pdf:
            # PDF extraction tends to produce spurious line breaks; collapse them.
            text = re.sub(r"\n{3,}", "\n", text)
            text = re.sub(r"\s", " ", text)
            text = text.replace("\n\n", "")
        # Split on Chinese/Western sentence-ending punctuation, keeping any
        # closing quotation marks attached to the sentence they terminate.
        sent_sep_pattern = re.compile(
            '([﹒﹔﹖﹗.。!?]["’”」』]{0,2}|(?=["‘“「『]{1,2}|$))')
        sent_list = []
        for ele in sent_sep_pattern.split(text):
            if sent_sep_pattern.match(ele) and sent_list:
                sent_list[-1] += ele
            elif ele:
                sent_list.append(ele)
        return sent_list
--------------------------------------------------------------------------------
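A quick worked example of how the splitter behaves (a sketch, assuming langchain is installed; the expected result is shown as a comment):

```
from models.chinese_text_splitter import ChineseTextSplitter

splitter = ChineseTextSplitter()
print(splitter.split_text("今天天气不错。明天会下雨吗?"))
# expected: ['今天天气不错。', '明天会下雨吗?']
```

The delimiter itself is captured by the regex and glued back onto the preceding fragment, which is why each sentence keeps its ending punctuation.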
/models/config.py:
--------------------------------------------------------------------------------
import os

LOCAL_CONTENT = os.path.join(os.path.dirname(__file__), "../docs")
VS_PATH = os.path.join(os.path.dirname(__file__), "../vector_store/FAISS")
CHUNK_SIZE = 800
CHUNK_OVERLAP = 70
VECTOR_SEARCH_TOP_K = 2
os.environ["SERPAPI_API_KEY"] = "Your SerpAPI Key"

PROMPT_TEMPLATE = """已知信息:
{context}

根据上述已知信息,简洁和专业的来回答用户的问题。如果无法从中得到答案,请给出你认为最合理的回答。答案请使用中文。 问题是:{question}"""
--------------------------------------------------------------------------------
/models/custom_agent.py:
--------------------------------------------------------------------------------
import re
from typing import List, Tuple, Any, Union, Optional

from langchain.agents import Tool, BaseSingleActionAgent, AgentOutputParser, \
    LLMSingleActionAgent, AgentExecutor
from langchain import LLMChain
from langchain.schema import AgentAction, AgentFinish
from langchain.prompts import StringPromptTemplate
from langchain.tools import BaseTool
from langchain.callbacks.manager import CallbackManagerForToolRun
from models.custom_search import DeepSearch
from models.custom_llm import CustomLLM

agent_template = """
你现在是一个{role}。这里是一些已知信息:
{related_content}
{background_information}
{question_guide}:{input}

{answer_format}
"""

class CustomPromptTemplate(StringPromptTemplate):
    template: str
    tools: List[Tool]

    def format(self, **kwargs) -> str:
        intermediate_steps = kwargs.pop("intermediate_steps")
        if len(intermediate_steps) == 0:
            # First round: no web search results yet. Ask the model to either
            # answer directly or emit a DeepSearch('...') call and nothing else.
            background_information = "\n"
            role = "傻瓜机器人"
            question_guide = "我现在有一个问题"
            answer_format = "如果你知道答案,请直接给出你的回答!如果你不知道答案,请你只回答\"DeepSearch('搜索词')\",并将'搜索词'替换为你认为需要搜索的关键词,除此之外不要回答其他任何内容。\n\n下面请回答我上面提出的问题!"
        else:
            # Second round: fold the observation returned by the tool back into
            # the prompt as background information.
            background_information = "\n\n你还有这些已知信息作为参考:\n\n"
            action, observation = intermediate_steps[0]
            background_information += f"{observation}\n"
            role = "聪明的 AI 助手"
            question_guide = "请根据这些已知信息回答我的问题"
            answer_format = ""

        kwargs["background_information"] = background_information
        kwargs["role"] = role
        kwargs["question_guide"] = question_guide
        kwargs["answer_format"] = answer_format
        return self.template.format(**kwargs)

class CustomSearchTool(BaseTool):
    name: str = "DeepSearch"
    description: str = ""

    def _run(self, query: str, run_manager: Optional[CallbackManagerForToolRun] = None):
        return DeepSearch.search(query=query)

    async def _arun(self, query: str):
        raise NotImplementedError("DeepSearch does not support async")

class CustomAgent(BaseSingleActionAgent):
    @property
    def input_keys(self):
        return ["input"]

    def plan(self, intermediate_steps: List[Tuple[AgentAction, str]],
             **kwargs: Any) -> Union[AgentAction, AgentFinish]:
        return AgentAction(tool="DeepSearch", tool_input=kwargs["input"], log="")

class CustomOutputParser(AgentOutputParser):
    def parse(self, llm_output: str) -> Union[AgentAction, AgentFinish]:
        # group(1) = tool name, group(2) = tool argument
        match = re.match(r'^[\s\w]*(DeepSearch)\(([^\)]+)\)', llm_output, re.DOTALL)

        # If the LLM did not emit DeepSearch(...), treat its output as final.
        if not match:
            return AgentFinish(
                return_values={"output": llm_output.strip()},
                log=llm_output,
            )
        # Otherwise dispatch to the tool, stripping surrounding quotes from the
        # argument (the prompt asks for single quotes, so strip both kinds).
        action = match.group(1).strip()
        action_input = match.group(2).strip().strip("'\"")
        return AgentAction(tool=action, tool_input=action_input, log=llm_output)


class DeepAgent:
    tool_name: str = "DeepSearch"
    agent_executor: Any
    tools: List[Tool]
    llm_chain: Any

    def query(self, related_content: str = "", query: str = ""):
        return self.agent_executor.run(related_content=related_content,
                                       input=query,
                                       tool_name=self.tool_name)

    def __init__(self, **kwargs):
        llm = CustomLLM()
        tools = [
            Tool.from_function(
                func=DeepSearch.search,
                name="DeepSearch",
                description=""
            )
        ]
        self.tools = tools
        tool_names = [tool.name for tool in tools]
        output_parser = CustomOutputParser()
        prompt = CustomPromptTemplate(template=agent_template,
                                      tools=tools,
                                      input_variables=["related_content", "tool_name",
                                                       "input", "intermediate_steps"])

        self.llm_chain = LLMChain(llm=llm, prompt=prompt)

        # Stop generation at "\nObservation:" so the model cannot hallucinate
        # its own tool output.
        agent = LLMSingleActionAgent(
            llm_chain=self.llm_chain,
            output_parser=output_parser,
            stop=["\nObservation:"],
            allowed_tools=tool_names
        )

        self.agent_executor = AgentExecutor.from_agent_and_tools(agent=agent,
                                                                 tools=tools,
                                                                 verbose=True)


if __name__ == "__main__":
    deep_agent = DeepAgent()
    print(deep_agent.query(related_content="", query="请问近期携程有什么大的新闻"))
--------------------------------------------------------------------------------
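To make the parsing contract concrete, here is what `CustomOutputParser` does with the two kinds of model output (a sketch; in normal operation the agent framework calls `parse` for you):

```
from models.custom_agent import CustomOutputParser

parser = CustomOutputParser()

# Model decided to search: becomes an AgentAction routed to the DeepSearch tool.
step = parser.parse("DeepSearch('携程 新闻')")
print(step.tool, step.tool_input)   # DeepSearch 携程 新闻

# Model answered directly: becomes an AgentFinish carrying the text as output.
done = parser.parse("携程近期发布了博物馆口碑榜。")
print(done.return_values["output"])
```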
/models/custom_llm.py:
--------------------------------------------------------------------------------
from typing import Any, List, Mapping, Optional

from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.requests import TextRequestsWrapper
from langchain.llms.base import LLM


# Usage:
#   llm = CustomLLM()
#   print(llm("who are you?"))
class CustomLLM(LLM):

    logging: bool = False
    output_keys: List[str] = ["output"]

    llm_type: str = "chatglm"

    @property
    def _llm_type(self) -> str:
        return self.llm_type

    def log(self, log_str):
        if self.logging:
            print(log_str)

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
    ) -> str:
        self.log('----------' + self._llm_type + '----------> llm._call()')
        self.log(prompt)
        requests = TextRequestsWrapper()

        # Replace this URL with your own chatglm API endpoint (see the README).
        response = requests.post(f"http://js-perf.cn:7001/test/{self._llm_type}", {
            "ask": prompt
        })
        if self._llm_type == "chatglm":
            self.log('<--------chatglm------------')
            self.log(response)
            return response
        return "不支持该类型的 llm"

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        """Get the identifying parameters."""
        return {"n": 10}
--------------------------------------------------------------------------------
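The README asks you to point `_call` at your own chatglm deployment. If, for example, you serve ChatGLM-6B with the `api.py` script from its official repo, the request shape differs from the demo endpoint above. The following is a sketch built on assumptions about such a deployment (URL, JSON fields `prompt`/`history`, and a `response` field in the reply are all assumptions — check your own server's contract):

```
# Hypothetical request body for _call against a locally served ChatGLM-6B api.py.
import requests as http

def call_chatglm(prompt: str) -> str:
    resp = http.post("http://127.0.0.1:8000",            # assumed endpoint
                     json={"prompt": prompt, "history": []},
                     timeout=300)
    return resp.json()["response"]                        # assumed response field
```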
/models/custom_search.py:
--------------------------------------------------------------------------------
import requests

RapidAPIKey = ""

class DeepSearch:
    @staticmethod
    def search(query: str = ""):
        query = query.strip()

        if query == "":
            return ""

        if RapidAPIKey == "":
            return "请配置你的 RapidAPIKey"

        url = "https://bing-web-search1.p.rapidapi.com/search"

        querystring = {"q": query,
                       "mkt": "zh-cn", "textDecorations": "false",
                       "setLang": "CN", "safeSearch": "Off", "textFormat": "Raw"}

        headers = {
            "Accept": "application/json",
            "X-BingApis-SDK": "true",
            "X-RapidAPI-Key": RapidAPIKey,
            "X-RapidAPI-Host": "bing-web-search1.p.rapidapi.com"
        }

        response = requests.get(url, headers=headers, params=querystring)

        data_list = response.json()['value']

        if len(data_list) == 0:
            return ""

        # Format at most the first six results as "title: description" lines
        # (slicing avoids an IndexError when fewer than six come back).
        result_arr = []
        for item in data_list[:6]:
            title = item["name"]
            description = item["description"]
            result_arr.append(f"{title}: {description}")

        return "\n".join(result_arr)
--------------------------------------------------------------------------------
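A quick smoke test once the key is configured (this hits the live RapidAPI endpoint, so it consumes quota):

```
# Assumes RapidAPIKey has been filled in inside models/custom_search.py.
from models.custom_search import DeepSearch

print(DeepSearch.search("携程 新闻"))
```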
/models/util.py:
--------------------------------------------------------------------------------
import datetime
import os
from typing import List, Union

import torch
from tqdm import tqdm
from langchain.document_loaders import TextLoader, UnstructuredFileLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import FAISS
from langchain import PromptTemplate
from langchain.memory import ConversationSummaryBufferMemory
from langchain.chains import ConversationChain

from models.chinese_text_splitter import ChineseTextSplitter
from models.custom_llm import CustomLLM
from models.config import *

conversation_template = """你是一个正在跟某个人类对话的机器人.

{chat_history}
人类: {human_input}
机器人:"""

def load_txt_file(filepath):
    loader = TextLoader(filepath, encoding="utf8")
    textsplitter = RecursiveCharacterTextSplitter(chunk_size=CHUNK_SIZE,
                                                  chunk_overlap=CHUNK_OVERLAP,
                                                  length_function=len)
    return loader.load_and_split(text_splitter=textsplitter)

def torch_gc():
    """Release cached GPU/MPS memory after heavy embedding work."""
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()
    elif torch.backends.mps.is_available():
        try:
            from torch.mps import empty_cache
            empty_cache()
        except Exception as e:
            print(e)
            print("如果您使用的是 macOS 建议将 pytorch 版本升级至 2.0.0 或更高版本,以支持及时清理 torch 产生的内存占用。")

def load_file(filepath):
    if filepath.lower().endswith(".md"):
        loader = UnstructuredFileLoader(filepath, mode="elements")
        docs = loader.load()
    elif filepath.lower().endswith(".pdf"):
        loader = UnstructuredFileLoader(filepath)
        textsplitter = ChineseTextSplitter(pdf=True)
        docs = loader.load_and_split(textsplitter)
    else:
        docs = load_txt_file(filepath)
    return docs

def get_related_content(related_docs):
    return "\n".join(doc.page_content for doc in related_docs)

def get_docs_with_score(docs_with_score):
    docs = []
    for doc, score in docs_with_score:
        doc.metadata["score"] = score
        docs.append(doc)
    return docs

# filepath may be a directory, a single file, or a list of files
def init_knowledge_vector_store(filepath: Union[str, List[str]],
                                vs_path: Union[str, os.PathLike] = None,
                                embeddings: object = None):
    loaded_files = []
    failed_files = []
    docs = []
    # single path
    if isinstance(filepath, str):
        if not os.path.exists(filepath):
            print(f"{filepath} 路径不存在")
            return None, loaded_files
        elif os.path.isfile(filepath):
            file = os.path.split(filepath)[-1]
            try:
                docs = load_file(filepath)
                print(f"{file} 已成功加载")
                loaded_files.append(filepath)
            except Exception as e:
                print(e)
                print(f"{file} 未能成功加载")
                return None, loaded_files
        elif os.path.isdir(filepath):
            for file in tqdm(os.listdir(filepath), desc="加载文件"):
                fullfilepath = os.path.join(filepath, file)
                try:
                    docs += load_file(fullfilepath)
                    loaded_files.append(fullfilepath)
                except Exception as e:
                    failed_files.append(file)

            if len(failed_files) > 0:
                print("以下文件未能成功加载:")
                for file in failed_files:
                    print(file)
    # list of files
    else:
        for file in filepath:
            try:
                docs += load_file(file)
                print(f"{file} 已成功加载")
                loaded_files.append(file)
            except Exception as e:
                print(e)
                print(f"{file} 未能成功加载")

    if len(docs) > 0:
        print("文件加载完毕,正在生成向量库")
        if vs_path and os.path.isdir(vs_path):
            # extend an existing index
            vector_store = FAISS.load_local(vs_path, embeddings)
            vector_store.add_documents(docs)
            torch_gc()
        else:
            if not vs_path:
                # no store path given: fall back to a timestamped directory
                # next to this module
                vs_path = os.path.join(
                    os.path.dirname(__file__),
                    f"""FAISS_{datetime.datetime.now().strftime("%Y%m%d_%H%M%S")}""")
            vector_store = FAISS.from_documents(docs, embeddings)
            torch_gc()

        vector_store.save_local(vs_path)
        print("向量生成成功")
        return vs_path, loaded_files
    else:
        print("文件均未成功加载,请检查依赖包或替换为其他文件再次上传。")
        return None, loaded_files


class LocalDocQA:
    filepath: str
    vs_path: str
    load_files: List[str] = []
    top_k: int
    embeddings: object
    llm: object
    conversation_with_summary: object
    init: bool = True

    def __init__(self, filepath: str, vs_path: str, embeddings: object,
                 init: bool = True):
        if init:
            # (re)build the vector store from the local documents
            vs_path, loaded_files = init_knowledge_vector_store(filepath=filepath,
                                                                vs_path=vs_path,
                                                                embeddings=embeddings)
        else:
            loaded_files = []

        self.load_files = loaded_files
        self.vs_path = vs_path
        self.filepath = filepath
        self.embeddings = embeddings
        self.top_k = VECTOR_SEARCH_TOP_K
        self.llm = CustomLLM()
        self.conversation_with_summary = ConversationChain(
            llm=self.llm,
            memory=ConversationSummaryBufferMemory(llm=self.llm, max_token_limit=40),
            verbose=True)

    def query_knowledge(self, query: str):
        vector_store = FAISS.load_local(self.vs_path, self.embeddings)
        vector_store.chunk_size = CHUNK_SIZE
        related_docs_with_score = vector_store.similarity_search_with_score(query, k=self.top_k)
        related_docs = get_docs_with_score(related_docs_with_score)
        return get_related_content(related_docs)

    def get_knowledge_based_answer(self, query: str):
        related_content = self.query_knowledge(query)
        prompt = PromptTemplate(
            input_variables=["context", "question"],
            template=PROMPT_TEMPLATE,
        )
        pmt = prompt.format(context=related_content,
                            question=query)

        # answer = self.conversation_with_summary.predict(input=pmt)
        answer = self.llm(pmt)
        return answer
--------------------------------------------------------------------------------
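If the FAISS index under `vector_store/FAISS` has already been built, you can skip re-indexing on startup by constructing `LocalDocQA` with `init=False`. A sketch, assuming the saved index was produced with the same embedding model:

```
from langchain.embeddings.huggingface import HuggingFaceEmbeddings
from models.util import LocalDocQA
from models.config import LOCAL_CONTENT, VS_PATH

embeddings = HuggingFaceEmbeddings(model_name="GanymedeNil/text2vec-base-chinese")
qa_doc = LocalDocQA(filepath=LOCAL_CONTENT, vs_path=VS_PATH,
                    embeddings=embeddings, init=False)  # reuse the existing index
print(qa_doc.query_knowledge("Madam Speaker"))
```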
/server.py:
--------------------------------------------------------------------------------
# coding = utf-8
from flask import Flask, jsonify, request, make_response
from cli import answer

app = Flask(__name__)

@app.route('/ai/langchain/', methods=["POST"])
def handle_langchain_ask():
    if not request.form or 'ask' not in request.form:
        return make_response(jsonify({"status": 500,
                                      "error": "error form params"}), 500)
    ask = request.form.get('ask')
    content = answer(ask)
    return make_response(jsonify({'status': 200,
                                  'content': content}), 200)

if __name__ == "__main__":
    app.run(debug=False, port=8899, host="127.0.0.1")
--------------------------------------------------------------------------------
/vector_store/FAISS/index.faiss:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jayli/langchain-GLM_Agent/3d31bf8da1fca1f772be8b77a2e3369790a2d0db/vector_store/FAISS/index.faiss
--------------------------------------------------------------------------------
/vector_store/FAISS/index.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jayli/langchain-GLM_Agent/3d31bf8da1fca1f772be8b77a2e3369790a2d0db/vector_store/FAISS/index.pkl
--------------------------------------------------------------------------------