├── .idea ├── .gitignore ├── inspectionProfiles │ └── profiles_settings.xml ├── lmchain.iml ├── misc.xml └── modules.xml ├── GLM3_version_REAM.md ├── LICENSE ├── LMchain.egg-info ├── PKG-INFO ├── SOURCES.txt ├── dependency_links.txt └── top_level.txt ├── MANIFEST.in ├── README.md ├── __pycache__ └── tool_register.cpython-311.pyc ├── build └── lib │ └── lmchain │ ├── __init__.py │ ├── agents │ ├── __init__.py │ ├── llmAgent.py │ ├── llmMultiActionAgent.py │ └── llmMultiAgent.py │ ├── callbacks │ ├── __init__.py │ ├── base.py │ ├── manager.py │ └── stdout.py │ ├── chains │ ├── __init__.py │ ├── base.py │ ├── cmd.py │ ├── conversationalRetrievalChain.py │ ├── mathChain.py │ ├── question_answering.py │ ├── toolchain.py │ └── urlRequestChain.py │ ├── embeddings │ ├── __init__.py │ └── embeddings.py │ ├── hello.py │ ├── index │ ├── __init__.py │ └── indexChain.py │ ├── llms │ ├── __init__.py │ └── base.py │ ├── load │ ├── __init__.py │ └── serializable.py │ ├── memory │ ├── __init__.py │ ├── chat_memory.py │ ├── messageHistory.py │ └── utils.py │ ├── model │ ├── __init__.py │ └── language_model.py │ ├── prompts │ ├── __init__.py │ ├── base.py │ ├── chat.py │ ├── example_selectors.py │ ├── few_shot_templates.py │ ├── loading.py │ ├── prompt.py │ ├── templates.py │ └── tool_templates.py │ ├── schema │ ├── __init__.py │ ├── agent.py │ ├── document.py │ ├── language_model.py │ ├── memory.py │ ├── messages.py │ ├── output.py │ ├── output_parser.py │ ├── prompt.py │ ├── prompt_template.py │ ├── runnable.py │ ├── runnable │ │ ├── __init__.py │ │ ├── base.py │ │ ├── branch.py │ │ ├── config.py │ │ ├── configurable.py │ │ ├── fallbacks.py │ │ ├── passthrough.py │ │ ├── retry.py │ │ ├── router.py │ │ └── utils.py │ ├── runnable_utils.py │ └── schema.py │ ├── tool_register.py │ ├── tools │ ├── __init__.py │ └── tool_register.py │ ├── utils │ ├── __init__.py │ ├── formatting.py │ ├── input.py │ ├── loading.py │ └── math.py │ └── vectorstores │ ├── __init__.py │ ├── chroma.py │ ├── embeddings.py │ ├── laiss.py │ ├── utils.py │ └── vectorstore.py ├── dist ├── LMchain-0.1.60-py3-none-any.whl ├── LMchain-0.1.60.tar.gz ├── LMchain-0.1.61-py3-none-any.whl ├── LMchain-0.1.61.tar.gz ├── LMchain-0.1.62-py3-none-any.whl └── LMchain-0.1.62.tar.gz ├── lmchain ├── __init__.py ├── __pycache__ │ └── __init__.cpython-311.pyc ├── agents │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-311.pyc │ │ ├── llmAgent.cpython-311.pyc │ │ └── llmMultiAgent.cpython-311.pyc │ ├── llmAgent.py │ └── llmMultiAgent.py ├── callbacks │ ├── __init__.py │ ├── base.py │ ├── manager.py │ └── stdout.py ├── chains │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-311.pyc │ │ ├── cmd.cpython-311.pyc │ │ ├── mathChain.cpython-311.pyc │ │ └── urlRequestChain.cpython-311.pyc │ ├── base.py │ ├── cmd.py │ ├── conversationalRetrievalChain.py │ ├── mathChain.py │ ├── question_answering.py │ ├── subQuestChain.py │ ├── toolchain.py │ └── urlRequestChain.py ├── embeddings │ └── __init__.py ├── index │ ├── __init__.py │ └── indexChain.py ├── llms │ ├── __init__.py │ └── base.py ├── load │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-311.pyc │ │ └── serializable.cpython-311.pyc │ └── serializable.py ├── memory │ ├── __init__.py │ ├── chat_memory.py │ ├── messageHistory.py │ └── utils.py ├── model │ ├── __init__.py │ └── language_model.py ├── prompts │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-311.pyc │ │ └── base.cpython-311.pyc │ ├── base.py │ ├── chat.py │ ├── example_selectors.py │ ├── 
few_shot_templates.py
│ ├── loading.py
│ ├── prompt.py
│ ├── templates.py
│ └── tool_templates.py
├── schema
│ ├── __init__.py
│ ├── __pycache__
│ │ ├── __init__.cpython-311.pyc
│ │ ├── document.cpython-311.pyc
│ │ ├── messages.cpython-311.pyc
│ │ ├── output.cpython-311.pyc
│ │ ├── output_parser.cpython-311.pyc
│ │ ├── prompt.cpython-311.pyc
│ │ └── prompt_template.cpython-311.pyc
│ ├── agent.py
│ ├── document.py
│ ├── language_model.py
│ ├── memory.py
│ ├── messages.py
│ ├── output.py
│ ├── output_parser.py
│ ├── prompt.py
│ ├── prompt_template.py
│ ├── runnable.py
│ ├── runnable
│ │ ├── __init__.py
│ │ ├── base.py
│ │ ├── branch.py
│ │ ├── config.py
│ │ ├── configurable.py
│ │ ├── fallbacks.py
│ │ ├── passthrough.py
│ │ ├── retry.py
│ │ ├── router.py
│ │ └── utils.py
│ ├── runnable_utils.py
│ └── schema.py
├── tool_register.py
├── tools
│ ├── __init__.py
│ └── tool_register.py
├── utils
│ ├── __init__.py
│ ├── __pycache__
│ │ ├── __init__.cpython-311.pyc
│ │ ├── formatting.cpython-311.pyc
│ │ └── math.cpython-311.pyc
│ ├── formatting.py
│ ├── input.py
│ ├── loading.py
│ └── math.py
└── vectorstores
│ ├── __init__.py
│ ├── __pycache__
│ │ └── vectorstore.cpython-311.pyc
│ ├── chroma.py
│ ├── embeddings.py
│ ├── laiss.py
│ ├── utils.py
│ └── vectorstore.py
├── pyproject.toml
├── setup.py
├── tool_register.py
└── upload

/.idea/.gitignore:
--------------------------------------------------------------------------------
1 | # Default ignored files
2 | /shelf/
3 | /workspace.xml
4 | # Editor-based HTTP Client requests
5 | /httpRequests/
6 | # Datasource local storage ignored files
7 | /dataSources/
8 | /dataSources.local.xml
9 |
--------------------------------------------------------------------------------
/.idea/inspectionProfiles/profiles_settings.xml:
--------------------------------------------------------------------------------
1 | 2 | 3 | 6 |
--------------------------------------------------------------------------------
/.idea/lmchain.iml:
--------------------------------------------------------------------------------
1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 12 |
--------------------------------------------------------------------------------
/.idea/misc.xml:
--------------------------------------------------------------------------------
1 | 2 | 3 | 4 | 6 | 7 |
--------------------------------------------------------------------------------
/.idea/modules.xml:
--------------------------------------------------------------------------------
1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 |
--------------------------------------------------------------------------------
/GLM3_version_REAM.md:
--------------------------------------------------------------------------------
1 | Since GLM has been updated to GLM4, updates to the GLM3 line are paused for now: all lmchain==0.1.X releases are based on GLM3,
2 |
3 | while lmchain==0.2.01 and above use GLM4 as the underlying base model.
4 |
5 | Either way, lmchain remains fully usable.
6 |
7 | LMchain is a toolkit specifically adapted for Chinese large model chains.
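For example, to choose a line explicitly you can pin the version with standard pip specifiers (nothing lmchain-specific here):
```
pip install "lmchain<0.2"    # stays on the GLM3-based 0.1.X releases
pip install "lmchain>=0.2"   # GLM4-based releases
```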
8 |
9 | Lmchain is a toolkit that provides free large-model services for users in mainland China; the currently recommended free model is chatGLM.
10 |
11 | Free users can register at https://open.bigmodel.cn
12 | to obtain a free API key, or use the free key bundled with lmchain.
13 |
14 | Features are being added continuously; you can post in the issues or contact the author at 5847713@qq.com
15 | Your ideas and suggestions are welcome.
16 | -----------------------------------------------------------------------------
17 | Installation: ```pip install lmchain```
18 | -----------------------------------------------------------------------------
19 |
20 | >1. A simple text Q&A session:
21 | ```
22 | from lmchain.agents import llmMultiAgent
23 | llm = llmMultiAgent.AgentZhipuAI()
24 | llm.zhipuai.api_key = "1f565e40af1198e11ff1fd8a5b42771d.SjNfezc40YFsz2KC" # your own registered, working API key
25 | response = llm("南京是哪里的省会?")
26 | print(response)
27 |
28 | response = llm("那里有什么好玩的地方?")
29 | print(response)
30 | ```
31 |
32 | >2. lmchain can also decompose a complex task into sub-questions, for example:
33 | ```
34 | from lmchain.agents import llmMultiAgent
35 | llm = llmMultiAgent.AgentZhipuAI()
36 |
37 |
38 | query = "工商银行财报中,2023 Q3相比,2024 Q1的收益增长了多少?"
39 |
40 | from lmchain.chains import subQuestChain
41 | subQC = subQuestChain.SubQuestChain(llm)
42 | response = subQC.run(query=query)
43 |
44 | print(response)
45 | ```
46 | >3. Computing text embeddings with the large-model embedding tool:
47 | ```
48 | from lmchain.vectorstores import embeddings # import the embeddings module
49 | embedding_tool = embeddings.GLMEmbedding() # create a GLMEmbedding object
50 | embedding_tool.zhipuai.api_key = "1f565e40af1198e11ff1fd8a5b42771d.SjNfezc40YFsz2KC" # your own registered, working API key
51 |
52 | inputs = ["lmchain还有对复杂任务拆解的功能", "目前lmchain还提供了对工具函数的调用方法", "Lmchain是专用为中国大陆用户提供免费大模型服务的工具包"] * 50
53 |
54 | # The embedding step changes the order of the original input texts,
55 | # so the reordered text list must be used afterwards
56 | aembeddings,atexts = (embedding_tool.aembed_documents(inputs))
57 | print(aembeddings)
58 |
59 | # Each text is embedded as a sequence of shape [1, 1024]
60 | import numpy as np
61 | aembeddings = (np.array(aembeddings))
62 | print(aembeddings.shape)
63 | ```
64 | >4. lmchain also provides a way to call tool functions:
65 | ```
66 | from lmchain.agents import llmMultiAgent
67 | llm = llmMultiAgent.AgentZhipuAI()
68 |
69 | from lmchain.chains import toolchain
70 |
71 | tool_chain = toolchain.GLMToolChain(llm)
72 |
73 | query = "说一下上海的天气"
74 | response = tool_chain.run(query)
75 | print(response)
76 | ```
77 |
78 | >5. Adding a custom tool and calling it:
79 | ```
80 | from lmchain.agents import llmMultiAgent
81 | llm = llmMultiAgent.AgentZhipuAI()
82 |
83 | from lmchain.chains import toolchain
84 | tool_chain = toolchain.GLMToolChain(llm)
85 |
86 | from typing import Annotated
87 | # play_game below is a custom tool
88 | def play_game(
89 |     # Annotated marks each parameter as [type, description, required]
90 |     num: Annotated[int, 'use the num to play game', True],
91 | ):
92 |     # the docstring explains the function's purpose to the model
93 |     """
94 |     一个数字游戏,
95 |     随机输入数字,按游戏规则输出结果的游戏
96 |     """
97 |     if num % 3:
98 |         return 3
99 |     if num % 5:
100 |         return 5
101 |     return 0
102 |
103 | tool_chain.add_tools(play_game)
104 | query = "玩一个数字游戏,输入数字3"
105 | result = tool_chain.run(query)
106 |
107 | print(result)
108 |
109 | ```
110 | More features are on the way; feel free to leave your feedback or contact the author.
111 | 112 | 113 | 114 | 115 | 116 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Copyright (c) 2018 The Python Packaging Authority
2 | Permission is hereby granted, free of charge, to any person obtaining a copy
3 | of this software and associated documentation files (the "Software"), to deal
4 | in the Software without restriction, including without limitation the rights
5 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
6 | copies of the Software, and to
permit persons to whom the Software is
7 | furnished to do so, subject to the following conditions:
8 | The above copyright notice and this permission notice shall be included in all
9 | copies or substantial portions of the Software.
10 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
11 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
12 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
13 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
14 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
15 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
16 | SOFTWARE.
17 |
--------------------------------------------------------------------------------
/LMchain.egg-info/PKG-INFO:
--------------------------------------------------------------------------------
1 | Metadata-Version: 2.1
2 | Name: LMchain
3 | Version: 0.1.62
4 | Summary: A toolkit for large language model chains
5 | Author: xiaohuaWang
6 | Author-email: xiaohuaWang <5847713@qq.com>
7 | Project-URL: Homepage, https://github.com/pypa/sampleproject
8 | Project-URL: Bug Tracker, https://github.com/pypa/sampleproject/issues
9 | Classifier: Programming Language :: Python :: 3
10 | Classifier: License :: OSI Approved :: MIT License
11 | Classifier: Operating System :: OS Independent
12 | Requires-Python: >=3
13 | Description-Content-Type: text/markdown
14 | License-File: LICENSE
15 |
16 | LMchain is a toolkit specifically adapted for large model chains
17 |
--------------------------------------------------------------------------------
/LMchain.egg-info/SOURCES.txt:
--------------------------------------------------------------------------------
1 | LICENSE
2 | MANIFEST.in
3 | README.md
4 | pyproject.toml
5 | setup.py
6 | LMchain.egg-info/PKG-INFO
7 | LMchain.egg-info/SOURCES.txt
8 | LMchain.egg-info/dependency_links.txt
9 | LMchain.egg-info/top_level.txt
10 | lmchain/__init__.py
11 | lmchain/tool_register.py
12 | lmchain/agents/__init__.py
13 | lmchain/agents/llmAgent.py
14 | lmchain/agents/llmMultiAgent.py
15 | lmchain/callbacks/__init__.py
16 | lmchain/callbacks/base.py
17 | lmchain/callbacks/manager.py
18 | lmchain/callbacks/stdout.py
19 | lmchain/chains/__init__.py
20 | lmchain/chains/base.py
21 | lmchain/chains/cmd.py
22 | lmchain/chains/conversationalRetrievalChain.py
23 | lmchain/chains/mathChain.py
24 | lmchain/chains/question_answering.py
25 | lmchain/chains/toolchain.py
26 | lmchain/chains/urlRequestChain.py
27 | lmchain/embeddings/__init__.py
28 | lmchain/index/__init__.py
29 | lmchain/index/indexChain.py
30 | lmchain/llms/__init__.py
31 | lmchain/llms/base.py
32 | lmchain/load/__init__.py
33 | lmchain/load/serializable.py
34 | lmchain/memory/__init__.py
35 | lmchain/memory/chat_memory.py
36 | lmchain/memory/messageHistory.py
37 | lmchain/memory/utils.py
38 | lmchain/model/__init__.py
39 | lmchain/model/language_model.py
40 | lmchain/prompts/__init__.py
41 | lmchain/prompts/base.py
42 | lmchain/prompts/chat.py
43 | lmchain/prompts/example_selectors.py
44 | lmchain/prompts/few_shot_templates.py
45 | lmchain/prompts/loading.py
46 | lmchain/prompts/prompt.py
47 | lmchain/prompts/templates.py
48 | lmchain/prompts/tool_templates.py
49 | lmchain/schema/__init__.py
50 | lmchain/schema/agent.py
51 | lmchain/schema/document.py
52 | lmchain/schema/language_model.py
53 | lmchain/schema/memory.py
54 | lmchain/schema/messages.py
55 | lmchain/schema/output.py
56 |
lmchain/schema/output_parser.py
57 | lmchain/schema/prompt.py
58 | lmchain/schema/prompt_template.py
59 | lmchain/schema/runnable.py
60 | lmchain/schema/runnable_utils.py
61 | lmchain/schema/schema.py
62 | lmchain/schema/runnable/__init__.py
63 | lmchain/schema/runnable/base.py
64 | lmchain/schema/runnable/branch.py
65 | lmchain/schema/runnable/config.py
66 | lmchain/schema/runnable/configurable.py
67 | lmchain/schema/runnable/fallbacks.py
68 | lmchain/schema/runnable/passthrough.py
69 | lmchain/schema/runnable/retry.py
70 | lmchain/schema/runnable/router.py
71 | lmchain/schema/runnable/utils.py
72 | lmchain/tools/__init__.py
73 | lmchain/tools/tool_register.py
74 | lmchain/utils/__init__.py
75 | lmchain/utils/formatting.py
76 | lmchain/utils/input.py
77 | lmchain/utils/loading.py
78 | lmchain/utils/math.py
79 | lmchain/vectorstores/__init__.py
80 | lmchain/vectorstores/chroma.py
81 | lmchain/vectorstores/embeddings.py
82 | lmchain/vectorstores/laiss.py
83 | lmchain/vectorstores/utils.py
84 | lmchain/vectorstores/vectorstore.py
--------------------------------------------------------------------------------
/LMchain.egg-info/dependency_links.txt:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/LMchain.egg-info/top_level.txt:
--------------------------------------------------------------------------------
1 | lmchain
2 |
--------------------------------------------------------------------------------
/MANIFEST.in:
--------------------------------------------------------------------------------
1 | recursive-include lmchain *.py *.txt
2 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | LMchain is a toolkit specifically adapted for Chinese large model chains
2 |
3 | Lmchain is a toolkit that provides free large-model services for users in mainland China; the currently recommended free model is chatGLM.
4 |
5 | Free users can register at https://open.bigmodel.cn
6 | to obtain a free API key, or use the free key bundled with lmchain.
7 |
8 | Features are being added continuously; you can post in the issues or contact the author at 5847713@qq.com
9 | Your ideas and suggestions are welcome.
10 |
11 | Note: with the GLM4 update, lmchain has been fully migrated to the new API; users of the older GLM3-based releases can keep using them (up to version 0.1.78).
12 | -----------------------------------------------------------------------------
13 | Installation: ```pip install lmchain```
14 | -----------------------------------------------------------------------------
15 |
16 | >1. A simple text Q&A session:
17 | ```
18 | from lmchain.agents import AgentZhipuAI
19 | llm = AgentZhipuAI()
20 |
21 | response = llm("你好")
22 | print(response)
23 |
24 | response = llm("南京是哪里的省会")
25 | print(response)
26 |
27 | response = llm("那里有什么好玩的地方")
28 | print(response)
29 | ```
30 | 31 | 32 | 33 | 34 | 35 |
--------------------------------------------------------------------------------
/__pycache__/tool_register.cpython-311.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/virgo777/lmchain/3fc4a55df07100d4fe0beef4da89e3fbc6288938/__pycache__/tool_register.cpython-311.pyc
--------------------------------------------------------------------------------
/build/lib/lmchain/__init__.py:
--------------------------------------------------------------------------------
1 | name = "lmchain"
--------------------------------------------------------------------------------
/build/lib/lmchain/agents/__init__.py:
--------------------------------------------------------------------------------
1 |
name = "agents" -------------------------------------------------------------------------------- /build/lib/lmchain/agents/llmAgent.py: -------------------------------------------------------------------------------- 1 | import time 2 | import logging 3 | import requests 4 | from typing import Optional, List, Dict, Mapping, Any 5 | 6 | import langchain 7 | from langchain.llms.base import LLM 8 | from langchain.cache import InMemoryCache 9 | 10 | logging.basicConfig(level=logging.INFO) 11 | # 启动llm的缓存 12 | langchain.llm_cache = InMemoryCache() 13 | 14 | 15 | class AgentChatGLM(LLM): 16 | # 模型服务url 17 | url = "http://127.0.0.1:7866/chat" 18 | #url = "http://192.168.3.20:7866/chat" #3050服务器上 19 | history = [] 20 | 21 | @property 22 | def _llm_type(self) -> str: 23 | return "chatglm" 24 | 25 | def _construct_query(self, prompt: str) -> Dict: 26 | """构造请求体 27 | """ 28 | query = {"query": prompt, "history": self.history} 29 | import json 30 | query = json.dumps(query) # 对请求参数进行JSON编码 31 | 32 | return query 33 | 34 | def _construct_query_tools(self, prompt: str , tools: list ) -> Dict: 35 | """构造请求体 36 | """ 37 | tools_info = {"role": "system", 38 | "content": "你现在是一个查找使用何种工具以及传递何种参数的工具助手,你会一步步的思考问题。你根据需求查找工具函数箱中最合适的工具函数,然后返回工具函数名称和所工具函数对应的参数,参数必须要和需求中的目标对应。", 39 | "tools": tools} 40 | query = {"query": prompt, "history": tools_info} 41 | import json 42 | query = json.dumps(query) # 对请求参数进行JSON编码 43 | 44 | return query 45 | 46 | 47 | @classmethod 48 | def _post(self, url: str, query: Dict) -> Any: 49 | 50 | """POST请求""" 51 | response = requests.post(url, data=query).json() 52 | return response 53 | 54 | def _call(self, prompt: str, stop: Optional[List[str]] = None, tools:list = None) -> str: 55 | """_call""" 56 | if tools == None: 57 | # construct query 58 | query = self._construct_query(prompt=prompt) 59 | 60 | # post 61 | response = self._post(url=self.url,query=query) 62 | 63 | response_chat = response["response"]; 64 | self.history = response["history"] 65 | 66 | return response_chat 67 | else: 68 | 69 | query = self._construct_query_tools(prompt=prompt,tools=tools) 70 | # post 71 | response = self._post(url=self.url, query=query) 72 | self.history = response["history"] #这个history要放上面 73 | response = response["response"] 74 | try: 75 | #import ast 76 | #response = ast.literal_eval(response) 77 | ret = tool_register.dispatch_tool(response["name"], response["parameters"]) 78 | response_chat = llm(prompt=ret) 79 | except: 80 | response_chat = response 81 | return str(response_chat) 82 | 83 | @property 84 | def _identifying_params(self) -> Mapping[str, Any]: 85 | """Get the identifying parameters. 
86 |         """
87 |         _param_dict = {
88 |             "url": self.url
89 |         }
90 |         return _param_dict
91 |
92 |
93 | if __name__ == "__main__":
94 |
95 |     import tool_register
96 |
97 |     # fetch all registered tools, returned in JSON form
98 |     tools = tool_register.get_tools()
99 |     "-------------------------------------- first, the tools are defined ---------------------------------------"
100 |
101 |     llm = AgentChatGLM()
102 |     llm.url = "http://192.168.3.20:7866/chat"
103 |     while True:
104 |         while True:
105 |             human_input = input("Human: ")
106 |             if human_input == "tools":
107 |                 break
108 |
109 |             begin_time = time.time() * 1000
110 |             # query the model
111 |             response = llm(human_input)
112 |             end_time = time.time() * 1000
113 |             used_time = round(end_time - begin_time, 3)
114 |             #logging.info(f"chatGLM process time: {used_time}ms")
115 |             print(f"Chat: {response}")
116 |
117 |         human_input = input("Human_with_tools_Ask: ")
118 |         response = llm(prompt=human_input, tools=tools)
119 |         print(f"Chat_with_tools_Que: {response}")
120 | 121 | 122 | 123 | 124 |
--------------------------------------------------------------------------------
/build/lib/lmchain/agents/llmMultiActionAgent.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/virgo777/lmchain/3fc4a55df07100d4fe0beef4da89e3fbc6288938/build/lib/lmchain/agents/llmMultiActionAgent.py
--------------------------------------------------------------------------------
/build/lib/lmchain/agents/llmMultiAgent.py:
--------------------------------------------------------------------------------
1 | 2 |
3 | import logging
4 | import requests
5 | from typing import Optional, List, Dict, Mapping, Any
6 | import langchain
7 | from langchain.llms.base import LLM
8 | from langchain.cache import InMemoryCache
9 |
10 | logging.basicConfig(level=logging.INFO)
11 | # enable the LLM cache
12 | langchain.llm_cache = InMemoryCache()
13 |
14 |
15 | class AgentZhipuAI(LLM):
16 |     import zhipuai as zhipuai
17 |     # model service URL
18 |     url = "127.0.0.1"
19 |     zhipuai.api_key = "1f565e40af1198e11ff1fd8a5b42771d.SjNfezc40YFsz2KC"  # the API key obtained from the console
20 |     model = "chatglm_pro"  # model version
21 |
22 |     history = []
23 |
24 |     def getText(self, role, content):
25 |         # role specifies the speaker, content is the prompt text
26 |         jsoncon = {}
27 |         jsoncon["role"] = role
28 |         jsoncon["content"] = content
29 |         self.history.append(jsoncon)
30 |         return self.history
31 |
32 |     @property
33 |     def _llm_type(self) -> str:
34 |         return "AgentZhipuAI"
35 |
36 |     @classmethod
37 |     def _post(self, url: str, query: Dict) -> Any:
38 |
39 |         """Send a POST request."""
40 |         response = requests.post(url, data=query).json()
41 |         return response
42 |
43 |     def _call(self, prompt: str, stop: Optional[List[str]] = None, role="user") -> str:
44 |         """Invoke the ZhipuAI model with the accumulated history."""
45 |         # construct query
46 |         response = self.zhipuai.model_api.invoke(
47 |             model=self.model,
48 |             prompt=self.getText(role=role, content=prompt)
49 |         )
50 |         choices = (response['data']['choices'])[0]
51 |         self.history.append(choices)
52 |         return choices["content"]
53 |
54 |     @property
55 |     def _identifying_params(self) -> Mapping[str, Any]:
56 |         """Get the identifying parameters.
57 | """ 58 | _param_dict = { 59 | "url": self.url 60 | } 61 | return _param_dict 62 | 63 | 64 | if __name__ == '__main__': 65 | from langchain.prompts import PromptTemplate 66 | from langchain.chains import LLMChain 67 | 68 | llm = AgentZhipuAI() 69 | 70 | # 没有输入变量的示例prompt 71 | no_input_prompt = PromptTemplate(input_variables=[], template="给我讲个笑话。") 72 | no_input_prompt.format() 73 | 74 | prompt = PromptTemplate( 75 | input_variables=["location", "street"], 76 | template="作为一名专业的旅游顾问,简单的说一下{location}有什么好玩的景点,特别是在{street}?只要说一个就可以。", 77 | ) 78 | 79 | chain = LLMChain(llm=llm, prompt=prompt) 80 | print(chain.run({"location": "南京", "street": "新街口"})) 81 | 82 | 83 | from langchain.chains import ConversationChain 84 | conversation = ConversationChain(llm=llm, verbose=True) 85 | 86 | output = conversation.predict(input="你好!") 87 | print(output) 88 | 89 | output = conversation.predict(input="南京是哪里的省会?") 90 | print(output) 91 | 92 | output = conversation.predict(input="那里有什么好玩的地方,简单的说一个就好。") 93 | print(output) 94 | 95 | -------------------------------------------------------------------------------- /build/lib/lmchain/callbacks/__init__.py: -------------------------------------------------------------------------------- 1 | name = "callbacks" -------------------------------------------------------------------------------- /build/lib/lmchain/callbacks/stdout.py: -------------------------------------------------------------------------------- 1 | """Callback Handler that prints to std out.""" 2 | from typing import Any, Dict, List, Optional 3 | 4 | from langchain.callbacks.base import BaseCallbackHandler 5 | from langchain.schema import AgentAction, AgentFinish, LLMResult 6 | from lmchain.utils.input import print_text 7 | 8 | 9 | class StdOutCallbackHandler(BaseCallbackHandler): 10 | """Callback Handler that prints to std out.""" 11 | 12 | def __init__(self, color: Optional[str] = None) -> None: 13 | """Initialize callback handler.""" 14 | self.color = color 15 | 16 | def on_llm_start( 17 | self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any 18 | ) -> None: 19 | """Print out the prompts.""" 20 | pass 21 | 22 | def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None: 23 | """Do nothing.""" 24 | pass 25 | 26 | def on_llm_new_token(self, token: str, **kwargs: Any) -> None: 27 | """Do nothing.""" 28 | pass 29 | 30 | def on_llm_error(self, error: BaseException, **kwargs: Any) -> None: 31 | """Do nothing.""" 32 | pass 33 | 34 | def on_chain_start( 35 | self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any 36 | ) -> None: 37 | """Print out that we are entering a chain.""" 38 | class_name = serialized.get("name", serialized.get("id", [""])[-1]) 39 | print(f"\n\n\033[1m> Entering new {class_name} chain...\033[0m") 40 | 41 | def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None: 42 | """Print out that we finished a chain.""" 43 | print("\n\033[1m> Finished chain.\033[0m") 44 | 45 | def on_chain_error(self, error: BaseException, **kwargs: Any) -> None: 46 | """Do nothing.""" 47 | pass 48 | 49 | def on_tool_start( 50 | self, 51 | serialized: Dict[str, Any], 52 | input_str: str, 53 | **kwargs: Any, 54 | ) -> None: 55 | """Do nothing.""" 56 | pass 57 | 58 | def on_agent_action( 59 | self, action: AgentAction, color: Optional[str] = None, **kwargs: Any 60 | ) -> Any: 61 | """Run on agent action.""" 62 | print_text(action.log, color=color or self.color) 63 | 64 | def on_tool_end( 65 | self, 66 | output: str, 67 | color: Optional[str] = None, 68 | 
observation_prefix: Optional[str] = None,
69 |         llm_prefix: Optional[str] = None,
70 |         **kwargs: Any,
71 |     ) -> None:
72 |         """If not the final action, print out observation."""
73 |         if observation_prefix is not None:
74 |             print_text(f"\n{observation_prefix}")
75 |         print_text(output, color=color or self.color)
76 |         if llm_prefix is not None:
77 |             print_text(f"\n{llm_prefix}")
78 |
79 |     def on_tool_error(self, error: BaseException, **kwargs: Any) -> None:
80 |         """Do nothing."""
81 |         pass
82 |
83 |     def on_text(
84 |         self,
85 |         text: str,
86 |         color: Optional[str] = None,
87 |         end: str = "",
88 |         **kwargs: Any,
89 |     ) -> None:
90 |         """Run when agent ends."""
91 |         print_text(text, color=color or self.color, end=end)
92 |
93 |     def on_agent_finish(
94 |         self, finish: AgentFinish, color: Optional[str] = None, **kwargs: Any
95 |     ) -> None:
96 |         """Run on agent end."""
97 |         print_text(finish.log, color=color or self.color, end="\n")
98 |
--------------------------------------------------------------------------------
/build/lib/lmchain/chains/__init__.py:
--------------------------------------------------------------------------------
1 | name = "chains"
--------------------------------------------------------------------------------
/build/lib/lmchain/chains/cmd.py:
--------------------------------------------------------------------------------
1 | # This chain turns a natural-language request into a Windows shell command and executes it
2 | from langchain.chains.llm import LLMChain
3 | from langchain.prompts import PromptTemplate
4 | from lmchain.agents import llmAgent
5 | import os, re
6 |
7 | class LLMCMDChain:
8 |     def __init__(self, llm):
9 |         qa_prompt = PromptTemplate(
10 |             template="""你现在根据需要完成对命令行的编写,要根据需求编写对应的在Windows系统终端运行的命令,不要用%question形参这种指代的参数形式,直接给出可以运行的命令。
11 | Question: 给我一个在Windows系统终端中可以准确执行{question}的命令。
12 | answer:""",
13 |             input_variables=["question"],
14 |         )
15 |         self.qa_chain = LLMChain(llm=llm, prompt=qa_prompt)
16 |         self.pattern = r"```(.*?)```"
17 |
18 |     def run(self, text):
19 |         cmd_response = self.qa_chain.run(question=text)
20 |         cmd_string = str(cmd_response).split("```")[-2][1:-1]
21 |         os.system(cmd_string)
22 |         return cmd_string
--------------------------------------------------------------------------------
/build/lib/lmchain/chains/conversationalRetrievalChain.py:
--------------------------------------------------------------------------------
1 | from langchain.docstore.document import Document
2 | from langchain.text_splitter import RecursiveCharacterTextSplitter
3 | from lmchain.embeddings import embeddings
4 | from lmchain.vectorstores import laiss
5 | from lmchain.agents import llmMultiAgent
6 | from langchain.memory import ConversationBufferMemory
7 | from langchain.prompts import (
8 |     ChatPromptTemplate,  # builds chat templates
9 |     MessagesPlaceholder,  # inserts a message placeholder into a template
10 |     SystemMessagePromptTemplate,  # builds system message templates
11 |     HumanMessagePromptTemplate  # builds human message templates
12 | )
13 | from langchain.chains import ConversationChain
14 |
15 | class ConversationalRetrievalChain:
16 |     def __init__(self, document, chunk_size=1280, chunk_overlap=50, file_name="这是一份辅助材料"):
17 |         """
18 |         :param document: the input text content, a single text string
19 |         :param chunk_size: number of characters in each chunk after splitting
20 |         :param chunk_overlap: number of characters overlapping between adjacent chunks
21 |         :param file_name: document name / document path
22 |         """
23 |         self.text_splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
24 |         self.embedding_tool = embeddings.GLMEmbedding()
25 |
26 |         self.lmaiss = laiss.LMASS()  # converts texts to vectors and computes query similarity
27 |         self.llm =
llmMultiAgent.AgentZhipuAI()
28 |         self.memory = ConversationBufferMemory(return_messages=True)
29 |
30 |         conversation_prompt = ChatPromptTemplate.from_messages([
31 |             SystemMessagePromptTemplate.from_template("你是一个最强大的人工智能程序,可以知无不答,但是你不懂的东西会直接回答不知道。"),
32 |             MessagesPlaceholder(variable_name="history"),  # placeholder for history messages
33 |             HumanMessagePromptTemplate.from_template("{input}")  # human input template
34 |         ])
35 |
36 |         self.qa_chain = ConversationChain(memory=self.memory, prompt=conversation_prompt, llm=self.llm)
37 |         "---------------------------"
38 |         document = [Document(page_content=document, metadata={"source": file_name})]  # wrap the input in Document format
39 |         self.documents = self.text_splitter.split_documents(document)  # split into chunks of the configured size
40 |         self.vectorstore = self.lmaiss.from_documents(self.documents, embedding_class=self.embedding_tool)
41 |
42 |     def __call__(self, query):
43 |         query_embedding = self.embedding_tool.embed_query(query)
44 |
45 |         # find the vector closest to the query
46 |         close_id = self.lmaiss.get_similarity_vector_indexs(query_embedding, self.vectorstore, k=1)[0]
47 |         # look up the paragraph with that id
48 |         doc = self.documents[close_id]
49 |
50 |         # build the final query
51 |         query = f"你现在要回答问题'{query}',你可以参考文献'{doc}',你如果找不到对应的内容,就从自己的记忆体中查找,就回答'请提供更为准确的查询内容',注意你要一步步的思考再回答。"
52 |         result = (self.qa_chain.predict(input=query))
53 |         return result
54 |
55 |     def predict(self, input):
56 |         result = self.__call__(input)
57 |         return result
--------------------------------------------------------------------------------
/build/lib/lmchain/chains/mathChain.py:
--------------------------------------------------------------------------------
1 | # This chain converts a natural-language request into a math expression evaluated with numexpr
2 |
3 | from langchain.chains.llm import LLMChain
4 | from langchain.prompts import PromptTemplate
5 | from lmchain.agents import llmAgent
6 | import os, re, math
7 |
8 | try:
9 |     import numexpr  # noqa: F401
10 | except ImportError:
11 |     raise ImportError(
12 |         "LMchain requires the numexpr package. "
13 |         "Please install it with `pip install numexpr`."
14 |     )
15 |
16 |
17 | class LLMMathChain:
18 |     def __init__(self, llm):
19 |         qa_prompt = PromptTemplate(
20 |             template="""现在给你一个中文命令,请你把这个命令转化成数学公式。直接给出数学公式。这个公式会在numexpr包中调用。
21 | Question: 我现在需要计算{question},结果需要在numexpr包中调用。
22 | answer:""",
23 |             input_variables=["question"],
24 |         )
25 |         self.qa_chain = LLMChain(llm=llm, prompt=qa_prompt)
26 |
27 |
28 |     def run(self, text):
29 |         cmd_response = self.qa_chain.run(question=text)
30 |         result = self._evaluate_expression(str(cmd_response))
31 |         return result
32 |
33 |
34 |     def _evaluate_expression(self, expression: str) -> str:
35 |         import numexpr  # noqa: F401
36 |
37 |         try:
38 |             local_dict = {"pi": math.pi, "e": math.e}
39 |             output = str(
40 |                 numexpr.evaluate(
41 |                     expression.strip(),
42 |                     global_dict={},  # restrict access to globals
43 |                     local_dict=local_dict,  # common mathematical constants (pi, e)
44 |                 )
45 |             )
46 |         except Exception as e:
47 |             raise ValueError(
48 |                 f'LMchain._evaluate("{expression}") raised error: {e}.'
49 |                 " Please try again with a valid numerical expression"
50 |             )
51 |
52 |         # Remove any leading and trailing brackets from the output
53 |         return re.sub(r"^\[|\]$", "", output)
--------------------------------------------------------------------------------
/build/lib/lmchain/chains/question_answering.py:
--------------------------------------------------------------------------------
1 | 2 | 3 | 4 | 5 | 6 | 7 |
--------------------------------------------------------------------------------
/build/lib/lmchain/chains/toolchain.py:
--------------------------------------------------------------------------------
1 | from langchain.chains import LLMChain
2 | from langchain.prompts import PromptTemplate
3 |
4 | from tqdm import tqdm
5 | from lmchain.tools import tool_register
6 |
7 |
8 | class GLMToolChain:
9 |     def __init__(self, llm):
10 |
11 |         self.llm = llm
12 |         self.tool_register = tool_register
13 |         self.tools = tool_register.get_tools()
14 |
15 |     def __call__(self, query="", tools=None):
16 |
17 |         if query == "":
18 |             raise ValueError("query must contain the question to answer")
19 |         if tools is not None:
20 |             self.tools = tools
21 |         else:
22 |             pass  # fall back to the default registered tools from __init__
23 |         template = f"""
24 |         你现在是一个专业的人工智能助手,你现在的需求是{query}。而你需要借助于工具在{self.tools}中找到对应的函数,用json格式返回对应的函数名和参数。
25 |         函数名定义为function_name,参数名为params,还要求写入详细的形参与实参。
26 |
27 |         如果找到合适的函数,就返回json格式的函数名和需要的参数,不要回答任何描述和解释。
28 |
29 |         如果没有找到合适的函数,则返回:'未找到合适参数,请提供更详细的描述。'
30 |         """
31 |
32 |         flag = True
33 |         counter = 0
34 |         while flag:
35 |             try:
36 |                 res = self.llm(template)
37 |
38 |                 import json
39 |                 res_dict = json.loads(res)
40 |                 res_dict = json.loads(res_dict)
41 |                 flag = False
42 |             except:
43 |                 # print("JSON parse failed, asking the model to regenerate")
44 |                 template = f"""
45 |                 你现在是一个专业的人工智能助手,你现在的需求是{query}。而你需要借助于工具在{self.tools}中找到对应的函数,用json格式返回对应的函数名和参数。
46 |                 函数名定义为function_name,参数名为params,还要求写入详细的形参与实参。
47 |
48 |                 如果找到合适的函数,就返回json格式的函数名和需要的参数,不要回答任何描述和解释。
49 |
50 |                 如果没有找到合适的函数,则返回:'未找到合适参数,请提供更详细的描述。'
51 |
52 |                 你刚才生成了一组结果,但是返回不符合json格式,现在请你重新按json格式生成并返回结果。
53 |                 """
54 |                 counter += 1
55 |                 if counter >= 5:
56 |                     return '未找到合适参数,请提供更详细的描述。'
57 |         return res_dict
58 |
59 |     def run(self, query, tools=None):
60 |         if tools is None: tools = self.tool_register.get_tools()
61 |         result = self.__call__(query, tools)
62 |
63 |         if result == "未找到合适参数,请提供更详细的描述。":
64 |             return "未找到合适参数,请提供更详细的描述。"
65 |         else:
66 |             print("找到对应工具函数,格式如下:", result)
67 |             result = self.dispatch_tool(result)
68 |             from lmchain.prompts.templates import PromptTemplate
69 |             tool_prompt = PromptTemplate(
70 |                 input_variables=["query", "result"],  # the user query and the tool result
71 |                 template="你现在是一个私人助手,现在你的查询任务是{query},而你通过工具从网上查询的结果是{result},现在根据查询的内容与查询的结果,生成最终答案。",  # merges query and tool result into a final answer
72 |             )
73 |             from langchain.chains import LLMChain
74 |             chain = LLMChain(llm=self.llm, prompt=tool_prompt)
75 |
76 |             response = (chain.run({"query": query, "result": result}))
77 |
78 |             return response
79 |
80 |     def add_tools(self, tool):
81 |         self.tool_register.register_tool(tool)
82 |         return True
83 |
84 |     def dispatch_tool(self, tool_result) -> str:
85 |         tool_name = tool_result["function_name"]
86 |         tool_params = tool_result["params"]
87 |         if tool_name not in self.tool_register._TOOL_HOOKS:
88 |             return f"Tool `{tool_name}` not found. Please use a provided tool."
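        # The registered hook table maps the function name chosen by the model
        # to the actual Python callable; it is invoked below with the
        # model-supplied keyword arguments, and any exception is returned as a
        # traceback string instead of crashing the chain.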
90 |         tool_call = self.tool_register._TOOL_HOOKS[tool_name]
91 |
92 |         try:
93 |             ret = tool_call(**tool_params)
94 |         except:
95 |             import traceback
96 |             ret = traceback.format_exc()
97 |         return str(ret)
98 |
99 |     def get_tools(self):
100 |         return (self.tool_register.get_tools())
101 |
102 |
103 | if __name__ == '__main__':
104 |     from lmchain.agents import llmMultiAgent
105 |
106 |     llm = llmMultiAgent.AgentZhipuAI()
107 |
108 |     from lmchain.chains import toolchain
109 |
110 |     tool_chain = toolchain.GLMToolChain(llm)
111 |
112 |     from typing import Annotated
113 |
114 |
115 |     def rando_numbr(
116 |         seed: Annotated[int, 'The random seed used by the generator', True],
117 |         range: Annotated[tuple[int, int], 'The range of the generated numbers', True],
118 |     ) -> int:
119 |         """
120 |         Generates a random number x, s.t. range[0] <= x <= range[1] (randint is inclusive)
121 |         """
122 |         import random
123 |         return random.Random(seed).randint(*range)
124 |
125 |
126 |     tool_chain.add_tools(rando_numbr)
127 |
128 |     print("------------------------------------------------------")
129 |     query = "今天shanghai的天气是什么?"
130 |     result = tool_chain.run(query)
131 |
132 |     # run() already dispatches the chosen tool and composes the final answer
133 |     print(result)
134 | 135 | 136 |
--------------------------------------------------------------------------------
/build/lib/lmchain/chains/urlRequestChain.py:
--------------------------------------------------------------------------------
1 | from langchain.chains import LLMRequestsChain, LLMChain
2 | from langchain.prompts import PromptTemplate
3 |
4 | import requests
5 | from bs4 import BeautifulSoup
6 | from tqdm import tqdm
7 |
8 |
9 | class LMRequestsChain:
10 |     def __init__(self, llm, max_url_num=2):
11 |         template = """Between >>> and <<< are the raw search result text from google.
12 | Extract the answer to the question '{query}' or say "not found" if the information is not contained.
13 | Use the format
14 | Extracted:
15 | >>> {requests_result} <<<
16 | Extracted:"""
17 |         PROMPT = PromptTemplate(
18 |             input_variables=["query", "requests_result"],
19 |             template=template,
20 |         )
21 |         self.chain = LLMRequestsChain(llm_chain=LLMChain(llm=llm, prompt=PROMPT))
22 |         self.max_url_num = max_url_num
23 |
24 |         query_prompt = PromptTemplate(
25 |             input_variables=["query", "responses"],
26 |             template="作为一名专业的信息总结员,我需要查询的信息为{query},根据提供的信息{responses}回答一下查询的结果。")
27 |         self.query_chain = LLMChain(llm=llm, prompt=query_prompt)
28 |
29 |     def __call__(self, query, target_site=""):
30 |         url_list = self.get_urls(query, target_site=target_site)
31 |         print(f"查找到{len(url_list)}条url内容,现在开始解析其中的{self.max_url_num}条内容。")
32 |         responses = []
33 |         for url in tqdm(url_list[:self.max_url_num]):
34 |             inputs = {
35 |                 "query": query,
36 |                 "url": url
37 |             }
38 |
39 |             response = self.chain(inputs)
40 |             output = response["output"]
41 |             responses.append(output)
42 |         if len(responses) != 0:
43 |             output = self.query_chain.run({"query": query, "responses": responses})
44 |             return output
45 |         else:
46 |             return "查找内容为空,请更换查找词"
47 |
48 |     def query_form_url(self, query="LMchain是什么?", url=""):
49 |         assert url != "", "url link must be set"
50 |         inputs = {
51 |             "query": query,
52 |             "url": url
53 |         }
54 |         response = self.chain(inputs)
55 |         return response
56 |
57 |     def get_urls(self, query='lmchain是什么?', target_site=""):
58 |         def bing_search(query, count=30):
59 |             url = f'https://cn.bing.com/search?q={query}'
60 |             headers = {
61 |                 'User-Agent': 'Mozilla/6.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/58.0.3029.110 Safari/537.3'}
62 |             response = requests.get(url, headers=headers)
63 |             if response.status_code == 200:
64 |                 html = response.text
65 |                 # parse the HTML with BeautifulSoup
66 |
67 |                 soup = BeautifulSoup(html, 'html.parser')
68 |                 results = soup.find_all('li', class_='b_algo')
69 |                 return [result.find('a').text for result in results[:count]]
70 |             else:
71 |                 print(f'请求失败,状态码:{response.status_code}')
72 |                 return []
73 |         results = bing_search(query)
74 |         if len(results) == 0:
75 |             return None
76 |         url_list = []
77 |         if target_site != "":
78 |             for i, result in enumerate(results):
79 |                 if "https" in result and target_site in result:
80 |                     url = "https://" + result.split("https://")[1]
81 |                     url_list.append(url)
82 |         else:
83 |             for i, result in enumerate(results):
84 |                 if "https" in result:
85 |                     url = "https://" + result.split("https://")[1]
86 |                     url_list.append(url)
87 |         if len(url_list) > 0:
88 |             return url_list
89 |         else:
90 |             # fallback: if nothing matches the target site, return whatever results were found
91 |             for i, result in enumerate(results):
92 |                 if "https" in result:
93 |                     url = "https://" + result.split("https://")[1]
94 |                     url_list.append(url)
95 |             return url_list
96 | 97 | 98 |
--------------------------------------------------------------------------------
/build/lib/lmchain/embeddings/__init__.py:
--------------------------------------------------------------------------------
1 | name = "embeddings"
--------------------------------------------------------------------------------
/build/lib/lmchain/embeddings/embeddings.py:
--------------------------------------------------------------------------------
1 | import asyncio
2 | from abc import ABC, abstractmethod
3 | from typing import List
4 |
5 |
6 | class Embeddings(ABC):
7 |     """Interface for embedding models."""
8 |
9 |     @abstractmethod
10 |     def embed_documents(self, texts: List[str]) -> List[List[float]]:
11 |         """Embed search docs."""
12 |
13 |     @abstractmethod
14 |     def
embed_query(self, text: str) -> List[float]:
15 |         """Embed query text."""
16 |
17 |     async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
18 |         """Asynchronously embed search docs."""
19 |         return await asyncio.get_running_loop().run_in_executor(
20 |             None, self.embed_documents, texts
21 |         )
22 |
23 |     async def aembed_query(self, text: str) -> List[float]:
24 |         """Asynchronously embed query text."""
25 |         return await asyncio.get_running_loop().run_in_executor(
26 |             None, self.embed_query, text
27 |         )
28 |
29 | class LMEmbedding(Embeddings):
30 |     from modelscope.pipelines import pipeline
31 |     from modelscope.utils.constant import Tasks
32 |     pipeline_se = pipeline(Tasks.sentence_embedding, model='thomas/text2vec-base-chinese', model_revision='v1.0.0', device="cuda")
33 |
34 |
35 |     def _construct_inputs(self, texts):
36 |
37 |         inputs = {
38 |             "source_sentence": texts
39 |         }
40 |
41 |         return inputs
42 |
43 |     def embed_documents(self, texts: List[str]) -> List[List[float]]:
44 |         """Embed search docs."""
45 |
46 |         inputs = self._construct_inputs(texts)
47 |         result_embeddings = self.pipeline_se(input=inputs)
48 |         return result_embeddings["text_embedding"]
49 |
50 |     def embed_query(self, text: str) -> List[float]:
51 |         """Embed query text."""
52 |         inputs = self._construct_inputs([text])
53 |         result_embeddings = self.pipeline_se(input=inputs)
54 |         return result_embeddings["text_embedding"]
55 |
56 |
57 | class GLMEmbedding(Embeddings):
58 |     import zhipuai as zhipuai
59 |     zhipuai.api_key = "1f565e40af1198e11ff1fd8a5b42771d.SjNfezc40YFsz2KC"  # the API key obtained from the console
60 |     def _construct_inputs(self, texts):
61 |         inputs = {
62 |             "source_sentence": texts
63 |         }
64 |
65 |         return inputs
66 |
67 |     def embed_documents(self, texts: List[str]) -> List[List[float]]:
68 |         """Embed search docs."""
69 |         result_embeddings = []
70 |         for text in texts:
71 |             embedding = self.embed_query(text)
72 |             result_embeddings.append(embedding)
73 |         return result_embeddings
74 |
75 |     def embed_query(self, text: str) -> List[float]:
76 |         """Embed query text."""
77 |         result_embeddings = self.zhipuai.model_api.invoke(
78 |             model="text_embedding", prompt=text)
79 |         return result_embeddings["data"]["embedding"]
80 | 81 | 82 | 83 |
84 | if __name__ == '__main__':
85 |     inputs = ["不可以,早晨喝牛奶不科学","不可以,今天早晨喝牛奶不科学","早晨喝牛奶不科学"]
86 |     print(GLMEmbedding().embed_documents(inputs))
87 | 88 | 89 |
--------------------------------------------------------------------------------
/build/lib/lmchain/hello.py:
--------------------------------------------------------------------------------
1 | print("hello world")
--------------------------------------------------------------------------------
/build/lib/lmchain/index/__init__.py:
--------------------------------------------------------------------------------
1 | name = "index"
--------------------------------------------------------------------------------
/build/lib/lmchain/index/indexChain.py:
--------------------------------------------------------------------------------
1 | from typing import Any, Dict, List, Optional, Type
2 |
3 | from langchain.chains.qa_with_sources.retrieval import RetrievalQAWithSourcesChain
4 | from langchain.chains.retrieval_qa.base import RetrievalQA
5 | from langchain.document_loaders.base import BaseLoader
6 | from pydantic.v1 import BaseModel, Extra, Field
7 | from langchain.schema import Document
8 | from langchain.schema.embeddings import Embeddings
9 | from langchain.schema.language_model import BaseLanguageModel
10 | from langchain.schema.vectorstore import
VectorStore
11 | from langchain.text_splitter import RecursiveCharacterTextSplitter, TextSplitter
12 | from langchain.vectorstores.chroma import Chroma
13 |
14 |
15 | def _get_default_text_splitter() -> TextSplitter:
16 |     return RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
17 |
18 | from lmchain.embeddings import embeddings
19 | embedding_tool = embeddings.GLMEmbedding()
20 |
21 | class VectorstoreIndexCreator(BaseModel):
22 |     """Logic for creating indexes."""
23 |
24 |     class Config:
25 |         """Configuration for this pydantic object."""
26 |         extra = Extra.forbid
27 |         arbitrary_types_allowed = True
28 | 29 | 30 | 31 |
32 |     chunk_size = 1280  # characters per chunk
33 |     chunk_overlap = 32  # overlapping characters between chunks
34 |     text_splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
35 | 36 | 37 | 38 | 39 |
40 |     def from_loaders(self, loaders: List[BaseLoader]):
41 |         """Create a vectorstore index from loaders."""
42 |         docs = []
43 |         for loader in loaders:
44 |             docs.extend(loader.load())
45 |         return self.from_documents(docs)
46 |
47 |
48 |     def from_documents(self, documents: List[Document]):
49 |         # this index wraps the split documents in a retrieval QA chain
50 |         sub_docs = self.text_splitter.split_documents(documents)
51 |
52 |         # texts = [d.page_content for d in sub_docs]
53 |         # metadatas = [d.metadata for d in sub_docs]
54 |
55 |         qa_chain = ConversationalRetrievalChain(document=sub_docs)
56 |         return qa_chain
57 |
58 |
59 | from langchain.docstore.document import Document
60 | from langchain.text_splitter import RecursiveCharacterTextSplitter
61 | from lmchain.embeddings import embeddings
62 | from lmchain.vectorstores import laiss
63 | from lmchain.agents import llmMultiAgent
64 | from langchain.memory import ConversationBufferMemory
65 | from langchain.prompts import (
66 |     ChatPromptTemplate,  # builds chat templates
67 |     MessagesPlaceholder,  # inserts a message placeholder into a template
68 |     SystemMessagePromptTemplate,  # builds system message templates
69 |     HumanMessagePromptTemplate  # builds human message templates
70 | )
71 | from langchain.chains import ConversationChain
72 |
73 | class ConversationalRetrievalChain:
74 |     def __init__(self, document, chunk_size=1280, chunk_overlap=50, file_name="这是一份辅助材料"):
75 |         """
76 |         :param document: the input text content, a single text string
77 |         :param chunk_size: number of characters in each chunk after splitting
78 |         :param chunk_overlap: number of characters overlapping between adjacent chunks
79 |         :param file_name: document name / document path
80 |         """
81 |         self.text_splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
82 |         self.embedding_tool = embedding_tool
83 |
84 |         self.lmaiss = laiss.LMASS()  # converts texts to vectors and computes query similarity
85 |         self.llm = llmMultiAgent.AgentZhipuAI()
86 |         self.memory = ConversationBufferMemory(return_messages=True)
87 |
88 |         conversation_prompt = ChatPromptTemplate.from_messages([
89 |             SystemMessagePromptTemplate.from_template("你是一个最强大的人工智能程序,可以知无不答,但是你不懂的东西会直接回答不知道。"),
90 |             MessagesPlaceholder(variable_name="history"),  # placeholder for history messages
91 |             HumanMessagePromptTemplate.from_template("{input}")  # human input template
92 |         ])
93 |
94 |         self.qa_chain = ConversationChain(memory=self.memory, prompt=conversation_prompt, llm=self.llm)
95 |         "---------------------------"
96 |         self.metadatas = []
97 |         for doc in document:
98 |             self.metadatas.append(doc.metadata)
99 |         self.documents = self.text_splitter.split_documents(document)  # split into chunks of the configured size
100 |         self.vectorstore = self.lmaiss.from_documents(self.documents, embedding_class=self.embedding_tool)
101 | 102 | 103 |
104 |     def __call__(self, query):
105 |         query_embedding = self.embedding_tool.embed_query(query)
106 |
107 |         # find the vector closest to the query
108 |         close_id =
self.lmaiss.get_similarity_vector_indexs(query_embedding, self.vectorstore, k=1)[0]
109 |         # look up the paragraph with that id
110 |         doc = self.documents[close_id]
111 |         meta = self.metadatas[close_id]
112 |         # build the final query
113 |         query = f"你现在要回答问题'{query}',你可以参考文献'{doc}',你如果找不到对应的内容,就从自己的记忆体中查找,就回答'请提供更为准确的查询内容'。"
114 |         result = (self.qa_chain.predict(input=query))
115 |         return result, meta
116 | 117 |
118 |     def query(self, input):
119 |         result, meta = self.__call__(input)
120 |         return result
121 |
122 |     # returns the answer together with its source metadata
123 |     def query_with_sources(self, input):
124 |         result, meta = self.__call__(input)
125 |         return {"answer": result, "sources": meta}
--------------------------------------------------------------------------------
/build/lib/lmchain/llms/__init__.py:
--------------------------------------------------------------------------------
1 | name = "llms"
--------------------------------------------------------------------------------
/build/lib/lmchain/llms/base.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/virgo777/lmchain/3fc4a55df07100d4fe0beef4da89e3fbc6288938/build/lib/lmchain/llms/base.py
--------------------------------------------------------------------------------
/build/lib/lmchain/load/__init__.py:
--------------------------------------------------------------------------------
1 | name = "load"
--------------------------------------------------------------------------------
/build/lib/lmchain/memory/__init__.py:
--------------------------------------------------------------------------------
1 | name = "memory"
--------------------------------------------------------------------------------
/build/lib/lmchain/memory/chat_memory.py:
--------------------------------------------------------------------------------
1 | from abc import ABC
2 | from typing import Any, Dict, Optional, Tuple
3 |
4 | from lmchain.memory.utils import get_prompt_input_key
5 |
6 | from lmchain.schema.memory import BaseMemory
7 |
8 |
9 | class BaseChatMemory(BaseMemory, ABC):
10 |     """Abstract base class for chat memory."""
11 |
12 |     from lmchain.memory import messageHistory
13 |     chat_memory = messageHistory.ChatMessageHistory()
14 |     output_key: Optional[str] = None
15 |     input_key: Optional[str] = None
16 |     return_messages: bool = False
17 |
18 |     def _get_input_output(
19 |         self, inputs: Dict[str, Any], outputs: Dict[str, str]
20 |     ) -> Tuple[str, str]:
21 |         if self.input_key is None:
22 |             prompt_input_key = get_prompt_input_key(inputs, self.memory_variables)
23 |         else:
24 |             prompt_input_key = self.input_key
25 |         if self.output_key is None:
26 |             if len(outputs) != 1:
27 |                 raise ValueError(f"One output key expected, got {outputs.keys()}")
28 |             output_key = list(outputs.keys())[0]
29 |         else:
30 |             output_key = self.output_key
31 |         return inputs[prompt_input_key], outputs[output_key]
32 |
33 |     def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
34 |         """Save context from this conversation to buffer."""
35 |         input_str, output_str = self._get_input_output(inputs, outputs)
36 |         self.chat_memory.add_user_message(input_str)
37 |         self.chat_memory.add_ai_message(output_str)
38 |
39 |     def clear(self) -> None:
40 |         """Clear memory contents."""
41 |         self.chat_memory.clear()
--------------------------------------------------------------------------------
/build/lib/lmchain/memory/messageHistory.py:
--------------------------------------------------------------------------------
1 | from typing import TYPE_CHECKING, Any, Dict, List, Sequence,
Union 2 | from typing_extensions import Literal 3 | 4 | 5 | class ChatMessageHistory: 6 | """In memory implementation of chat message history. 7 | 8 | Stores messages in an in memory list. 9 | """ 10 | 11 | messages = [] 12 | 13 | def add_message(self, message) -> None: 14 | """Add a self-created message to the store""" 15 | self.messages.append(message) 16 | 17 | def clear(self) -> None: 18 | self.messages = [] 19 | 20 | def __str__(self): 21 | return ", ".join(str(message) for message in self.messages) 22 | 23 | 24 | class ChatMessageHistory(ChatMessageHistory): 25 | def __init__(self): 26 | super(ChatMessageHistory).__init__() 27 | 28 | def add_user_message(self, content: str) -> None: 29 | """Convenience method for adding a human message string to the store. 30 | 31 | Args: 32 | content: The string contents of a human message. 33 | """ 34 | mes = f"HumanMessage(content={content})" 35 | self.messages.append(mes) 36 | 37 | def add_ai_message(self, content: str) -> None: 38 | """Convenience method for adding an AI message string to the store. 39 | 40 | Args: 41 | content: The string contents of an AI message. 42 | """ 43 | mes = f"AIMessage(content={content})" 44 | self.messages.append(mes) 45 | 46 | 47 | from typing import Any, Dict, List, Optional 48 | 49 | from langchain.memory.chat_memory import BaseChatMemory, BaseMemory 50 | from langchain.memory.utils import get_prompt_input_key 51 | from pydantic.v1 import root_validator 52 | from langchain.schema.messages import BaseMessage, get_buffer_string 53 | 54 | 55 | class ConversationBufferMemory(BaseChatMemory): 56 | """Buffer for storing conversation memory.""" 57 | 58 | human_prefix: str = "Human" 59 | ai_prefix: str = "AI" 60 | memory_key: str = "history" #: :meta private: 61 | 62 | @property 63 | def buffer(self) -> Any: 64 | """String buffer of memory.""" 65 | return self.buffer_as_messages if self.return_messages else self.buffer_as_str 66 | 67 | @property 68 | def buffer_as_str(self) -> str: 69 | """Exposes the buffer as a string in case return_messages is True.""" 70 | return get_buffer_string( 71 | self.chat_memory.messages, 72 | human_prefix=self.human_prefix, 73 | ai_prefix=self.ai_prefix, 74 | ) 75 | 76 | @property 77 | def buffer_as_messages(self) -> List[BaseMessage]: 78 | """Exposes the buffer as a list of messages in case return_messages is False.""" 79 | return self.chat_memory.messages 80 | 81 | @property 82 | def memory_variables(self) -> List[str]: 83 | """Will always return list of memory variables. 
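For this class that is always `[self.memory_key]`, i.e. `["history"]` unless overridden.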
84 | 85 | :meta private: 86 | """ 87 | return [self.memory_key] 88 | 89 | def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]: 90 | """Return history buffer.""" 91 | return {self.memory_key: self.buffer} 92 | 93 | 94 | class ConversationStringBufferMemory(BaseMemory): 95 | """Buffer for storing conversation memory.""" 96 | 97 | human_prefix: str = "Human" 98 | ai_prefix: str = "AI" 99 | """Prefix to use for AI generated responses.""" 100 | buffer: str = "" 101 | output_key: Optional[str] = None 102 | input_key: Optional[str] = None 103 | memory_key: str = "history" #: :meta private: 104 | 105 | @root_validator() 106 | def validate_chains(cls, values: Dict) -> Dict: 107 | """Validate that return messages is not True.""" 108 | if values.get("return_messages", False): 109 | raise ValueError( 110 | "return_messages must be False for ConversationStringBufferMemory" 111 | ) 112 | return values 113 | 114 | @property 115 | def memory_variables(self) -> List[str]: 116 | """Will always return list of memory variables. 117 | :meta private: 118 | """ 119 | return [self.memory_key] 120 | 121 | def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, str]: 122 | """Return history buffer.""" 123 | return {self.memory_key: self.buffer} 124 | 125 | def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None: 126 | """Save context from this conversation to buffer.""" 127 | if self.input_key is None: 128 | prompt_input_key = get_prompt_input_key(inputs, self.memory_variables) 129 | else: 130 | prompt_input_key = self.input_key 131 | if self.output_key is None: 132 | if len(outputs) != 1: 133 | raise ValueError(f"One output key expected, got {outputs.keys()}") 134 | output_key = list(outputs.keys())[0] 135 | else: 136 | output_key = self.output_key 137 | human = f"{self.human_prefix}: " + inputs[prompt_input_key] 138 | ai = f"{self.ai_prefix}: " + outputs[output_key] 139 | self.buffer += "\n" + "\n".join([human, ai]) 140 | 141 | def clear(self) -> None: 142 | """Clear memory contents.""" 143 | self.buffer = "" 144 | 145 | 146 | 147 | 148 | 149 | 150 | 151 | 152 | -------------------------------------------------------------------------------- /build/lib/lmchain/memory/utils.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, List 2 | 3 | 4 | def get_prompt_input_key(inputs: Dict[str, Any], memory_variables: List[str]) -> str: 5 | """ 6 | Get the prompt input key. 7 | 8 | Args: 9 | inputs: Dict[str, Any] 10 | memory_variables: List[str] 11 | 12 | Returns: 13 | A prompt input key. 14 | """ 15 | # "stop" is a special key that can be passed as input but is not used to 16 | # format the prompt. 
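    # Everything in `inputs` that is neither a memory variable nor the special
    # "stop" key counts as the prompt input; exactly one such key must remain.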
17 |     prompt_input_keys = list(set(inputs).difference(memory_variables + ["stop"]))
18 |     if len(prompt_input_keys) != 1:
19 |         raise ValueError(f"One input key expected, got {prompt_input_keys}")
20 |     return prompt_input_keys[0]
--------------------------------------------------------------------------------
/build/lib/lmchain/model/__init__.py:
--------------------------------------------------------------------------------
1 | name = "model"
--------------------------------------------------------------------------------
/build/lib/lmchain/prompts/__init__.py:
--------------------------------------------------------------------------------
1 | name = "prompts"
--------------------------------------------------------------------------------
/build/lib/lmchain/prompts/base.py:
--------------------------------------------------------------------------------
1 | """BasePrompt schema definition."""
2 | from __future__ import annotations
3 | 
4 | import warnings
5 | from abc import ABC
6 | from string import Formatter
7 | from typing import Any, Callable, Dict, List, Literal, Set
8 | 
9 | from lmchain.schema.messages import BaseMessage, HumanMessage
10 | from lmchain.schema.prompt import PromptValue
11 | from lmchain.schema.prompt_template import BasePromptTemplate
12 | # from langchain.schema.prompt_template import BasePromptTemplate
13 | from lmchain.utils.formatting import formatter
14 | 
15 | 
16 | def jinja2_formatter(template: str, **kwargs: Any) -> str:
17 |     """Format a template using jinja2.
18 | 
19 |     *Security warning*: As of LangChain 0.0.329, this method uses Jinja2's
20 |     SandboxedEnvironment by default. However, this sand-boxing should
21 |     be treated as a best-effort approach rather than a guarantee of security.
22 |     Do not accept jinja2 templates from untrusted sources as they may lead
23 |     to arbitrary Python code execution.
24 | 
25 |     https://jinja.palletsprojects.com/en/3.1.x/sandbox/
26 |     """
27 |     try:
28 |         from jinja2.sandbox import SandboxedEnvironment
29 |     except ImportError:
30 |         raise ImportError(
31 |             "jinja2 not installed, which is needed to use the jinja2_formatter. "
32 |             "Please install it with `pip install jinja2`. "
33 |             "Please be cautious when using jinja2 templates. "
34 |             "Do not expand jinja2 templates using unverified or user-controlled "
35 |             "inputs as that can result in arbitrary Python code execution."
36 |         )
37 | 
38 |     # This uses a sandboxed environment to prevent arbitrary code execution.
39 |     # Jinja2 uses an opt-out rather than opt-in approach for sand-boxing.
40 |     # Please treat this sand-boxing as a best-effort approach rather than
41 |     # a guarantee of security.
42 |     # We recommend to never use jinja2 templates with untrusted inputs.
43 |     # https://jinja.palletsprojects.com/en/3.1.x/sandbox/
44 |     # Sandboxing is an approach, not a guarantee of security.
45 |     return SandboxedEnvironment().from_string(template).render(**kwargs)
46 | 
47 | 
48 | def validate_jinja2(template: str, input_variables: List[str]) -> None:
49 |     """
50 |     Validate that the input variables are valid for the template.
51 |     Issues a warning if missing or extra variables are found.
52 | 
53 |     Args:
54 |         template: The template string.
55 |         input_variables: The input variables.
56 | """ 57 | input_variables_set = set(input_variables) 58 | valid_variables = _get_jinja2_variables_from_template(template) 59 | missing_variables = valid_variables - input_variables_set 60 | extra_variables = input_variables_set - valid_variables 61 | 62 | warning_message = "" 63 | if missing_variables: 64 | warning_message += f"Missing variables: {missing_variables} " 65 | 66 | if extra_variables: 67 | warning_message += f"Extra variables: {extra_variables}" 68 | 69 | if warning_message: 70 | warnings.warn(warning_message.strip()) 71 | 72 | 73 | def _get_jinja2_variables_from_template(template: str) -> Set[str]: 74 | try: 75 | from jinja2 import Environment, meta 76 | except ImportError: 77 | raise ImportError( 78 | "jinja2 not installed, which is needed to use the jinja2_formatter. " 79 | "Please install it with `pip install jinja2`." 80 | ) 81 | env = Environment() 82 | ast = env.parse(template) 83 | variables = meta.find_undeclared_variables(ast) 84 | return variables 85 | 86 | 87 | DEFAULT_FORMATTER_MAPPING: Dict[str, Callable] = { 88 | "f-string": formatter.format, 89 | "jinja2": jinja2_formatter, 90 | } 91 | 92 | DEFAULT_VALIDATOR_MAPPING: Dict[str, Callable] = { 93 | "f-string": formatter.validate_input_variables, 94 | "jinja2": validate_jinja2, 95 | } 96 | 97 | 98 | def check_valid_template( 99 | template: str, template_format: str, input_variables: List[str] 100 | ) -> None: 101 | """Check that template string is valid. 102 | 103 | Args: 104 | template: The template string. 105 | template_format: The template format. Should be one of "f-string" or "jinja2". 106 | input_variables: The input variables. 107 | 108 | Raises: 109 | ValueError: If the template format is not supported. 110 | """ 111 | if template_format not in DEFAULT_FORMATTER_MAPPING: 112 | valid_formats = list(DEFAULT_FORMATTER_MAPPING) 113 | raise ValueError( 114 | f"Invalid template format. Got `{template_format}`;" 115 | f" should be one of {valid_formats}" 116 | ) 117 | try: 118 | validator_func = DEFAULT_VALIDATOR_MAPPING[template_format] 119 | validator_func(template, input_variables) 120 | except KeyError as e: 121 | raise ValueError( 122 | "Invalid prompt schema; check for mismatched or missing input parameters. " 123 | + str(e) 124 | ) 125 | 126 | 127 | def get_template_variables(template: str, template_format: str) -> List[str]: 128 | """Get the variables from the template. 129 | 130 | Args: 131 | template: The template string. 132 | template_format: The template format. Should be one of "f-string" or "jinja2". 133 | 134 | Returns: 135 | The variables from the template. 136 | 137 | Raises: 138 | ValueError: If the template format is not supported. 
139 | """ 140 | if template_format == "jinja2": 141 | # Get the variables for the template 142 | input_variables = _get_jinja2_variables_from_template(template) 143 | elif template_format == "f-string": 144 | input_variables = { 145 | v for _, v, _, _ in Formatter().parse(template) if v is not None 146 | } 147 | else: 148 | raise ValueError(f"Unsupported template format: {template_format}") 149 | 150 | return sorted(input_variables) 151 | 152 | 153 | class StringPromptValue(PromptValue): 154 | """String prompt value.""" 155 | 156 | text: str 157 | """Prompt text.""" 158 | type: Literal["StringPromptValue"] = "StringPromptValue" 159 | 160 | def to_string(self) -> str: 161 | """Return prompt as string.""" 162 | return self.text 163 | 164 | def to_messages(self) -> List[BaseMessage]: 165 | """Return prompt as messages.""" 166 | return [HumanMessage(content=self.text)] 167 | 168 | 169 | class StringPromptTemplate(BasePromptTemplate, ABC): 170 | """String prompt that exposes the format method, returning a prompt.""" 171 | 172 | def format_prompt(self, **kwargs: Any) -> PromptValue: 173 | """Create Chat Messages.""" 174 | return StringPromptValue(text=self.format(**kwargs)) 175 | -------------------------------------------------------------------------------- /build/lib/lmchain/schema/__init__.py: -------------------------------------------------------------------------------- 1 | name = "schema" -------------------------------------------------------------------------------- /build/lib/lmchain/schema/agent.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from typing import Any, Literal, Sequence, Union 4 | 5 | from lmchain.load.serializable import Serializable 6 | from lmchain.schema.messages import BaseMessage 7 | 8 | 9 | class AgentAction(Serializable): 10 | """A full description of an action for an ActionAgent to execute.""" 11 | 12 | tool: str 13 | """The name of the Tool to execute.""" 14 | tool_input: Union[str, dict] 15 | """The input to pass in to the Tool.""" 16 | log: str 17 | """Additional information to log about the action. 18 | This log can be used in a few ways. First, it can be used to audit 19 | what exactly the LLM predicted to lead to this (tool, tool_input). 20 | Second, it can be used in future iterations to show the LLMs prior 21 | thoughts. This is useful when (tool, tool_input) does not contain 22 | full information about the LLM prediction (for example, any `thought` 23 | before the tool/tool_input).""" 24 | type: Literal["AgentAction"] = "AgentAction" 25 | 26 | def __init__( 27 | self, tool: str, tool_input: Union[str, dict], log: str, **kwargs: Any 28 | ): 29 | """Override init to support instantiation by position for backward compat.""" 30 | super().__init__(tool=tool, tool_input=tool_input, log=log, **kwargs) 31 | 32 | @classmethod 33 | def is_lc_serializable(cls) -> bool: 34 | """Return whether or not the class is serializable.""" 35 | return True 36 | 37 | 38 | class AgentActionMessageLog(AgentAction): 39 | message_log: Sequence[BaseMessage] 40 | """Similar to log, this can be used to pass along extra 41 | information about what exact messages were predicted by the LLM 42 | before parsing out the (tool, tool_input). This is again useful 43 | if (tool, tool_input) cannot be used to fully recreate the LLM 44 | prediction, and you need that LLM prediction (for future agent iteration). 
45 | Compared to `log`, this is useful when the underlying LLM is a 46 | ChatModel (and therefore returns messages rather than a string).""" 47 | # Ignoring type because we're overriding the type from AgentAction. 48 | # And this is the correct thing to do in this case. 49 | # The type literal is used for serialization purposes. 50 | type: Literal["AgentActionMessageLog"] = "AgentActionMessageLog" # type: ignore 51 | 52 | 53 | class AgentFinish(Serializable): 54 | """The final return value of an ActionAgent.""" 55 | 56 | return_values: dict 57 | """Dictionary of return values.""" 58 | log: str 59 | """Additional information to log about the return value. 60 | This is used to pass along the full LLM prediction, not just the parsed out 61 | return value. For example, if the full LLM prediction was 62 | `Final Answer: 2` you may want to just return `2` as a return value, but pass 63 | along the full string as a `log` (for debugging or observability purposes). 64 | """ 65 | type: Literal["AgentFinish"] = "AgentFinish" 66 | 67 | def __init__(self, return_values: dict, log: str, **kwargs: Any): 68 | """Override init to support instantiation by position for backward compat.""" 69 | super().__init__(return_values=return_values, log=log, **kwargs) 70 | 71 | @classmethod 72 | def is_lc_serializable(cls) -> bool: 73 | """Return whether or not the class is serializable.""" 74 | return True 75 | -------------------------------------------------------------------------------- /build/lib/lmchain/schema/document.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import asyncio 4 | from abc import ABC, abstractmethod 5 | from functools import partial 6 | from typing import Any, Literal, Sequence 7 | 8 | from lmchain.load.serializable import Serializable 9 | from pydantic.v1 import Field 10 | 11 | class Document(Serializable): 12 | """Class for storing a piece of text and associated metadata.""" 13 | 14 | page_content: str 15 | """String text.""" 16 | metadata: dict = Field(default_factory=dict) 17 | """Arbitrary metadata about the page content (e.g., source, relationships to other 18 | documents, etc.). 19 | """ 20 | type: Literal["Document"] = "Document" 21 | 22 | @classmethod 23 | def is_lc_serializable(cls) -> bool: 24 | """Return whether this class is serializable.""" 25 | return True 26 | 27 | 28 | class BaseDocumentTransformer(ABC): 29 | """Abstract base class for document transformation systems. 30 | 31 | A document transformation system takes a sequence of Documents and returns a 32 | sequence of transformed Documents. 33 | 34 | Example: 35 | .. 
code-block:: python 36 | 37 | class EmbeddingsRedundantFilter(BaseDocumentTransformer, BaseModel): 38 | embeddings: Embeddings 39 | similarity_fn: Callable = cosine_similarity 40 | similarity_threshold: float = 0.95 41 | 42 | class Config: 43 | arbitrary_types_allowed = True 44 | 45 | def transform_documents( 46 | self, documents: Sequence[Document], **kwargs: Any 47 | ) -> Sequence[Document]: 48 | stateful_documents = get_stateful_documents(documents) 49 | embedded_documents = _get_embeddings_from_stateful_docs( 50 | self.embeddings, stateful_documents 51 | ) 52 | included_idxs = _filter_similar_embeddings( 53 | embedded_documents, self.similarity_fn, self.similarity_threshold 54 | ) 55 | return [stateful_documents[i] for i in sorted(included_idxs)] 56 | 57 | async def atransform_documents( 58 | self, documents: Sequence[Document], **kwargs: Any 59 | ) -> Sequence[Document]: 60 | raise NotImplementedError 61 | 62 | """ # noqa: E501 63 | 64 | @abstractmethod 65 | def transform_documents( 66 | self, documents: Sequence[Document], **kwargs: Any 67 | ) -> Sequence[Document]: 68 | """Transform a list of documents. 69 | 70 | Args: 71 | documents: A sequence of Documents to be transformed. 72 | 73 | Returns: 74 | A list of transformed Documents. 75 | """ 76 | 77 | async def atransform_documents( 78 | self, documents: Sequence[Document], **kwargs: Any 79 | ) -> Sequence[Document]: 80 | """Asynchronously transform a list of documents. 81 | 82 | Args: 83 | documents: A sequence of Documents to be transformed. 84 | 85 | Returns: 86 | A list of transformed Documents. 87 | """ 88 | return await asyncio.get_running_loop().run_in_executor( 89 | None, partial(self.transform_documents, **kwargs), documents 90 | ) 91 | -------------------------------------------------------------------------------- /build/lib/lmchain/schema/memory.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from abc import ABC, abstractmethod 4 | from typing import Any, Dict, List 5 | 6 | 7 | class BaseMemory( ABC): 8 | """Abstract base class for memory in Chains. 9 | 10 | Memory refers to state in Chains. Memory can be used to store information about 11 | past executions of a Chain and inject that information into the inputs of 12 | future executions of the Chain. For example, for conversational Chains Memory 13 | can be used to store conversations and automatically add them to future model 14 | prompts so that the model has the necessary context to respond coherently to 15 | the latest input. 16 | 17 | Example: 18 | .. 
code-block:: python
19 | 
20 |         class SimpleMemory(BaseMemory):
21 |             memories: Dict[str, Any] = dict()
22 | 
23 |             @property
24 |             def memory_variables(self) -> List[str]:
25 |                 return list(self.memories.keys())
26 | 
27 |             def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, str]:
28 |                 return self.memories
29 | 
30 |             def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
31 |                 pass
32 | 
33 |             def clear(self) -> None:
34 |                 pass
35 |     """  # noqa: E501
36 | 
37 |     class Config:
38 |         """Configuration for this pydantic object."""
39 | 
40 |         arbitrary_types_allowed = True
41 | 
42 |     @property
43 |     @abstractmethod
44 |     def memory_variables(self) -> List[str]:
45 |         """The string keys this memory class will add to chain inputs."""
46 | 
47 |     @abstractmethod
48 |     def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
49 |         """Return key-value pairs given the text input to the chain."""
50 | 
51 |     @abstractmethod
52 |     def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
53 |         """Save the context of this chain run to memory."""
54 | 
55 |     @abstractmethod
56 |     def clear(self) -> None:
57 |         """Clear memory contents."""
58 | 
--------------------------------------------------------------------------------
/build/lib/lmchain/schema/prompt.py:
--------------------------------------------------------------------------------
1 | # This module defines an abstract base class named PromptValue, which represents the input to any language model.
2 | # The class inherits from Serializable and ABC (Abstract Base Class), i.e. it is a serializable abstract base class.
3 | 
4 | 
5 | # Import annotations from __future__ so that postponed evaluation of type annotations also works on older Python versions.
6 | from __future__ import annotations
7 | 
8 | # Import ABC (abstract base class) and the abstractmethod decorator from the abc module.
9 | from abc import ABC, abstractmethod
10 | # Import the List type from typing, used for type annotations.
11 | from typing import List
12 | 
13 | # Import the Serializable class from lmchain.load.serializable, used to serialize and deserialize objects.
14 | from lmchain.load.serializable import Serializable
15 | # Import the BaseMessage class from lmchain.schema.messages as the message base class.
16 | from lmchain.schema.messages import BaseMessage
17 | 
18 | 
19 | # Define an abstract base class named PromptValue that inherits from Serializable and ABC.
20 | class PromptValue(Serializable, ABC):
21 |     """Base abstract class for inputs to any language model.
22 | 
23 |     PromptValues can be converted to both LLM (pure text-generation) inputs and
24 |     ChatModel inputs.
25 |     """
26 | 
27 |     # Class method returning a boolean that indicates whether this class is serializable. Here it always returns True.
28 |     @classmethod
29 |     def is_lc_serializable(cls) -> bool:
30 |         """Return whether this class is serializable."""
31 |         return True
32 | 
33 |     # Abstract method to be implemented by subclasses. Returns the prompt value as a string.
34 |     @abstractmethod
35 |     def to_string(self) -> str:
36 |         """Return prompt value as string."""
37 | 
38 |     # Abstract method to be implemented by subclasses. Returns the prompt as a list of BaseMessage objects.
39 |     @abstractmethod
40 |     def to_messages(self) -> List[BaseMessage]:
41 |         """Return prompt as a list of Messages."""
42 | 
--------------------------------------------------------------------------------
/build/lib/lmchain/schema/runnable/__init__.py:
--------------------------------------------------------------------------------
1 | name = "schema.runnable"
--------------------------------------------------------------------------------
/build/lib/lmchain/schema/runnable/config.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/virgo777/lmchain/3fc4a55df07100d4fe0beef4da89e3fbc6288938/build/lib/lmchain/schema/runnable/config.py
--------------------------------------------------------------------------------
/build/lib/lmchain/schema/schema.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 | 
3 | from typing import Any, Literal, Sequence, Union
4 | 
5 | from lmchain.load.serializable import Serializable
6 | from lmchain.schema.messages import BaseMessage
7 | 
8 | 
9 | class AgentAction(Serializable):
10 |     """A full description of an action for an ActionAgent to execute."""
11 | 
12 |     tool: str
13 |     """The name of the Tool to execute."""
14 |     tool_input: Union[str, dict]
15 |     """The input to pass in to the Tool."""
16 |     log: str
17 |     """Additional information to log about the action.
18 |     This log can be used in a few ways. First, it can be used to audit
19 |     what exactly the LLM predicted to lead to this (tool, tool_input).
20 |     Second, it can be used in future iterations to show the LLMs prior
21 |     thoughts. This is useful when (tool, tool_input) does not contain
22 |     full information about the LLM prediction (for example, any `thought`
23 |     before the tool/tool_input)."""
24 |     type: Literal["AgentAction"] = "AgentAction"
25 | 
26 |     def __init__(
27 |         self, tool: str, tool_input: Union[str, dict], log: str, **kwargs: Any
28 |     ):
29 |         """Override init to support instantiation by position for backward compat."""
30 |         super().__init__(tool=tool, tool_input=tool_input, log=log, **kwargs)
31 | 
32 |     @classmethod
33 |     def is_lc_serializable(cls) -> bool:
34 |         """Return whether or not the class is serializable."""
35 |         return True
36 | 
37 | 
38 | class AgentActionMessageLog(AgentAction):
39 |     message_log: Sequence[BaseMessage]
40 |     """Similar to log, this can be used to pass along extra
41 |     information about what exact messages were predicted by the LLM
42 |     before parsing out the (tool, tool_input). This is again useful
43 |     if (tool, tool_input) cannot be used to fully recreate the LLM
44 |     prediction, and you need that LLM prediction (for future agent iteration).
45 |     Compared to `log`, this is useful when the underlying LLM is a
46 |     ChatModel (and therefore returns messages rather than a string)."""
47 |     # Ignoring type because we're overriding the type from AgentAction.
48 |     # And this is the correct thing to do in this case.
49 |     # The type literal is used for serialization purposes.
50 | type: Literal["AgentActionMessageLog"] = "AgentActionMessageLog" # type: ignore 51 | 52 | 53 | class AgentFinish(Serializable): 54 | """The final return value of an ActionAgent.""" 55 | 56 | return_values: dict 57 | """Dictionary of return values.""" 58 | log: str 59 | """Additional information to log about the return value. 60 | This is used to pass along the full LLM prediction, not just the parsed out 61 | return value. For example, if the full LLM prediction was 62 | `Final Answer: 2` you may want to just return `2` as a return value, but pass 63 | along the full string as a `log` (for debugging or observability purposes). 64 | """ 65 | type: Literal["AgentFinish"] = "AgentFinish" 66 | 67 | def __init__(self, return_values: dict, log: str, **kwargs: Any): 68 | """Override init to support instantiation by position for backward compat.""" 69 | super().__init__(return_values=return_values, log=log, **kwargs) 70 | 71 | @classmethod 72 | def is_lc_serializable(cls) -> bool: 73 | """Return whether or not the class is serializable.""" 74 | return True 75 | -------------------------------------------------------------------------------- /build/lib/lmchain/tool_register.py: -------------------------------------------------------------------------------- 1 | import inspect 2 | import traceback 3 | from copy import deepcopy 4 | from pprint import pformat 5 | from types import GenericAlias 6 | from typing import get_origin, Annotated 7 | 8 | _TOOL_HOOKS = {} 9 | _TOOL_DESCRIPTIONS = {} 10 | 11 | 12 | def register_tool(func: callable): 13 | tool_name = func.__name__ 14 | tool_description = inspect.getdoc(func).strip() 15 | python_params = inspect.signature(func).parameters 16 | tool_params = [] 17 | for name, param in python_params.items(): 18 | annotation = param.annotation 19 | if annotation is inspect.Parameter.empty: 20 | raise TypeError(f"Parameter `{name}` missing type annotation") 21 | if get_origin(annotation) != Annotated: 22 | raise TypeError(f"Annotation type for `{name}` must be typing.Annotated") 23 | 24 | typ, (description, required) = annotation.__origin__, annotation.__metadata__ 25 | typ: str = str(typ) if isinstance(typ, GenericAlias) else typ.__name__ 26 | if not isinstance(description, str): 27 | raise TypeError(f"Description for `{name}` must be a string") 28 | if not isinstance(required, bool): 29 | raise TypeError(f"Required for `{name}` must be a bool") 30 | 31 | tool_params.append({ 32 | "name": name, 33 | "description": description, 34 | "type": typ, 35 | "required": required 36 | }) 37 | tool_def = { 38 | "name": tool_name, 39 | "description": tool_description, 40 | "params": tool_params 41 | } 42 | 43 | # print("[registered tool] " + pformat(tool_def)) 44 | _TOOL_HOOKS[tool_name] = func 45 | _TOOL_DESCRIPTIONS[tool_name] = tool_def 46 | 47 | return func 48 | 49 | 50 | def dispatch_tool(tool_name: str, tool_params: dict) -> str: 51 | if tool_name not in _TOOL_HOOKS: 52 | return f"Tool `{tool_name}` not found. Please use a provided tool." 
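    # Editor's note: the lines below look up the registered callable and invoke
    # it with the supplied keyword arguments; any exception is converted into a
    # traceback string, so the caller always receives a str it can show the model.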
53 | tool_call = _TOOL_HOOKS[tool_name] 54 | try: 55 | ret = tool_call(**tool_params) 56 | except: 57 | ret = traceback.format_exc() 58 | return str(ret) 59 | 60 | 61 | def get_tools() -> dict: 62 | return deepcopy(_TOOL_DESCRIPTIONS) 63 | 64 | 65 | # Tool Definitions 66 | 67 | # @register_tool 68 | # def random_number_generator( 69 | # seed: Annotated[int, 'The random seed used by the generator', True], 70 | # range: Annotated[tuple[int, int], 'The range of the generated numbers', True], 71 | # ) -> int: 72 | # """ 73 | # Generates a random number x, s.t. range[0] <= x < range[1] 74 | # """ 75 | # if not isinstance(seed, int): 76 | # raise TypeError("Seed must be an integer") 77 | # if not isinstance(range, tuple): 78 | # raise TypeError("Range must be a tuple") 79 | # if not isinstance(range[0], int) or not isinstance(range[1], int): 80 | # raise TypeError("Range must be a tuple of integers") 81 | # 82 | # import random 83 | # return random.Random(seed).randint(*range) 84 | # 85 | # 86 | # @register_tool 87 | # def get_weather( 88 | # city_name: Annotated[str, 'The name of the city to be queried', True], 89 | # ) -> str: 90 | # """ 91 | # Get the current weather for `city_name` 92 | # """ 93 | # 94 | # if not isinstance(city_name, str): 95 | # raise TypeError("City name must be a string") 96 | # 97 | # key_selection = { 98 | # "current_condition": ["temp_C", "FeelsLikeC", "humidity", "weatherDesc", "observation_time"], 99 | # } 100 | # import requests 101 | # try: 102 | # resp = requests.get(f"https://wttr.in/{city_name}?format=j1") 103 | # resp.raise_for_status() 104 | # resp = resp.json() 105 | # ret = {k: {_v: resp[k][0][_v] for _v in v} for k, v in key_selection.items()} 106 | # except: 107 | # import traceback 108 | # ret = "Error encountered while fetching weather data!\n" + traceback.format_exc() 109 | # 110 | # return str(ret) 111 | # 112 | # 113 | # @register_tool 114 | # def get_customer_weather(location: Annotated[str, "需要查询位置的名称,用中文表示的地点名称", True] = ""): 115 | # """ 自己编写的天气查询函数""" 116 | # 117 | # if location == "上海": 118 | # return 23.0 119 | # elif location == "南京": 120 | # return 25.0 121 | # else: 122 | # return "未查询相关内容" 123 | # 124 | # 125 | # @register_tool 126 | # def get_random_fun(location: Annotated[str, "随机参数", True] = ""): 127 | # """编写的一个混淆随机函数""" 128 | # location = location 129 | # return "你上当啦" 130 | 131 | 132 | if __name__ == "__main__": 133 | print(dispatch_tool("get_weather", {"city_name": "shanghai"})) 134 | print(get_tools()) 135 | -------------------------------------------------------------------------------- /build/lib/lmchain/tools/__init__.py: -------------------------------------------------------------------------------- 1 | name = "tools" -------------------------------------------------------------------------------- /build/lib/lmchain/tools/tool_register.py: -------------------------------------------------------------------------------- 1 | import inspect 2 | import traceback 3 | from copy import deepcopy 4 | from pprint import pformat 5 | from types import GenericAlias 6 | from typing import get_origin, Annotated 7 | 8 | _TOOL_HOOKS = {} 9 | _TOOL_DESCRIPTIONS = {} 10 | 11 | 12 | def register_tool(func: callable): 13 | tool_name = func.__name__ 14 | tool_description = inspect.getdoc(func).strip() 15 | python_params = inspect.signature(func).parameters 16 | tool_params = [] 17 | for name, param in python_params.items(): 18 | annotation = param.annotation 19 | if annotation is inspect.Parameter.empty: 20 | raise TypeError(f"Parameter `{name}` 
missing type annotation") 21 | if get_origin(annotation) != Annotated: 22 | raise TypeError(f"Annotation type for `{name}` must be typing.Annotated") 23 | 24 | typ, (description, required) = annotation.__origin__, annotation.__metadata__ 25 | typ: str = str(typ) if isinstance(typ, GenericAlias) else typ.__name__ 26 | if not isinstance(description, str): 27 | raise TypeError(f"Description for `{name}` must be a string") 28 | if not isinstance(required, bool): 29 | raise TypeError(f"Required for `{name}` must be a bool") 30 | 31 | tool_params.append({ 32 | "name": name, 33 | "description": description, 34 | "type": typ, 35 | "required": required 36 | }) 37 | tool_def = { 38 | "name": tool_name, 39 | "description": tool_description, 40 | "params": tool_params 41 | } 42 | 43 | # print("[registered tool] " + pformat(tool_def)) 44 | _TOOL_HOOKS[tool_name] = func 45 | _TOOL_DESCRIPTIONS[tool_name] = tool_def 46 | 47 | return func 48 | 49 | 50 | def dispatch_tool(tool_name: str, tool_params: dict) -> str: 51 | if tool_name not in _TOOL_HOOKS: 52 | return f"Tool `{tool_name}` not found. Please use a provided tool." 53 | tool_call = _TOOL_HOOKS[tool_name] 54 | try: 55 | ret = tool_call(**tool_params) 56 | except: 57 | ret = traceback.format_exc() 58 | return str(ret) 59 | 60 | 61 | def get_tools() -> dict: 62 | return deepcopy(_TOOL_DESCRIPTIONS) 63 | 64 | 65 | # Tool Definitions 66 | 67 | @register_tool 68 | def random_number_generator( 69 | seed: Annotated[int, 'The random seed used by the generator', True], 70 | range: Annotated[tuple[int, int], 'The range of the generated numbers', True], 71 | ) -> int: 72 | """ 73 | Generates a random number x, s.t. range[0] <= x < range[1] 74 | """ 75 | if not isinstance(seed, int): 76 | raise TypeError("Seed must be an integer") 77 | if not isinstance(range, tuple): 78 | raise TypeError("Range must be a tuple") 79 | if not isinstance(range[0], int) or not isinstance(range[1], int): 80 | raise TypeError("Range must be a tuple of integers") 81 | 82 | import random 83 | return random.Random(seed).randint(*range) 84 | 85 | 86 | @register_tool 87 | def get_weather( 88 | city_name: Annotated[str, 'The name of the city to be queried', True], 89 | ) -> str: 90 | """ 91 | Get the current weather for `city_name` 92 | """ 93 | 94 | if not isinstance(city_name, str): 95 | raise TypeError("City name must be a string") 96 | 97 | key_selection = { 98 | "current_condition": ["temp_C", "FeelsLikeC", "humidity", "weatherDesc", "observation_time"], 99 | } 100 | import requests 101 | try: 102 | resp = requests.get(f"https://wttr.in/{city_name}?format=j1") 103 | resp.raise_for_status() 104 | resp = resp.json() 105 | ret = {k: {_v: resp[k][0][_v] for _v in v} for k, v in key_selection.items()} 106 | except: 107 | import traceback 108 | ret = "Error encountered while fetching weather data!\n" + traceback.format_exc() 109 | 110 | return str(ret) 111 | 112 | 113 | if __name__ == "__main__": 114 | # print(dispatch_tool("get_weather", {"city_name": "beijing"})) 115 | tools = (get_tools()) 116 | import zhipuai as zhipuai 117 | 118 | zhipuai.api_key = "1f565e40af1198e11ff1fd8a5b42771d.SjNfezc40YFsz2KC" # 控制台中获取的 APIKey 信息 119 | 120 | query = "今天shanghai的天气是什么?" 
121 | prompt = f""" 122 | 你现在是一个专业的人工智能助手,你现在的需求是{query}。而你需要借助于工具在{tools}中找到对应的函数,用json格式返回对应的函数名和需要的参数。 123 | 124 | 只返回json格式的函数名和需要的参数,不要做描述。 125 | 126 | 如果没有找到合适的函数,则返回:'未找到合适参数,请提供更详细的描述。' 127 | """ 128 | 129 | from lmchain.agents import llmMultiAgent 130 | 131 | llm = llmMultiAgent.AgentZhipuAI() 132 | res = llm(prompt) 133 | print(res) 134 | 135 | import json 136 | 137 | res_dict = json.loads(res) 138 | res_dict = json.loads(res_dict) 139 | 140 | print(dispatch_tool(tool_name=res_dict["function_name"], tool_params=res_dict["params"])) 141 | 142 | 143 | 144 | -------------------------------------------------------------------------------- /build/lib/lmchain/utils/__init__.py: -------------------------------------------------------------------------------- 1 | name = "utils" -------------------------------------------------------------------------------- /build/lib/lmchain/utils/formatting.py: -------------------------------------------------------------------------------- 1 | """Utilities for formatting strings.""" 2 | from string import Formatter 3 | from typing import Any, List, Mapping, Sequence, Union 4 | 5 | 6 | class StrictFormatter(Formatter): 7 | """A subclass of formatter that checks for extra keys.""" 8 | 9 | def check_unused_args( 10 | self, 11 | used_args: Sequence[Union[int, str]], 12 | args: Sequence, 13 | kwargs: Mapping[str, Any], 14 | ) -> None: 15 | """Check to see if extra parameters are passed.""" 16 | extra = set(kwargs).difference(used_args) 17 | if extra: 18 | raise KeyError(extra) 19 | 20 | def vformat( 21 | self, format_string: str, args: Sequence, kwargs: Mapping[str, Any] 22 | ) -> str: 23 | """Check that no arguments are provided.""" 24 | if len(args) > 0: 25 | raise ValueError( 26 | "No arguments should be provided, " 27 | "everything should be passed as keyword arguments." 
28 | ) 29 | return super().vformat(format_string, args, kwargs) 30 | 31 | def validate_input_variables( 32 | self, format_string: str, input_variables: List[str] 33 | ) -> None: 34 | dummy_inputs = {input_variable: "foo" for input_variable in input_variables} 35 | super().format(format_string, **dummy_inputs) 36 | 37 | 38 | formatter = StrictFormatter() 39 | -------------------------------------------------------------------------------- /build/lib/lmchain/utils/input.py: -------------------------------------------------------------------------------- 1 | """Handle chained inputs.""" 2 | from typing import Dict, List, Optional, TextIO 3 | 4 | _TEXT_COLOR_MAPPING = { 5 | "blue": "36;1", 6 | "yellow": "33;1", 7 | "pink": "38;5;200", 8 | "green": "32;1", 9 | "red": "31;1", 10 | } 11 | 12 | 13 | def get_color_mapping( 14 | items: List[str], excluded_colors: Optional[List] = None 15 | ) -> Dict[str, str]: 16 | """Get mapping for items to a support color.""" 17 | colors = list(_TEXT_COLOR_MAPPING.keys()) 18 | if excluded_colors is not None: 19 | colors = [c for c in colors if c not in excluded_colors] 20 | color_mapping = {item: colors[i % len(colors)] for i, item in enumerate(items)} 21 | return color_mapping 22 | 23 | 24 | def get_colored_text(text: str, color: str) -> str: 25 | """Get colored text.""" 26 | color_str = _TEXT_COLOR_MAPPING[color] 27 | return f"\u001b[{color_str}m\033[1;3m{text}\u001b[0m" 28 | 29 | 30 | def get_bolded_text(text: str) -> str: 31 | """Get bolded text.""" 32 | return f"\033[1m{text}\033[0m" 33 | 34 | 35 | def print_text( 36 | text: str, color: Optional[str] = None, end: str = "", file: Optional[TextIO] = None 37 | ) -> None: 38 | """Print text with highlighting and no end characters.""" 39 | text_to_print = get_colored_text(text, color) if color else text 40 | print(text_to_print, end=end, file=file) 41 | if file: 42 | file.flush() # ensure all printed content are written to file 43 | -------------------------------------------------------------------------------- /build/lib/lmchain/utils/loading.py: -------------------------------------------------------------------------------- 1 | """Utilities for loading configurations from langchain-hub.""" 2 | 3 | import os 4 | import re 5 | import tempfile 6 | from pathlib import Path, PurePosixPath 7 | from typing import Any, Callable, Optional, Set, TypeVar, Union 8 | from urllib.parse import urljoin 9 | 10 | import requests 11 | 12 | DEFAULT_REF = os.environ.get("LANGCHAIN_HUB_DEFAULT_REF", "master") 13 | URL_BASE = os.environ.get( 14 | "LANGCHAIN_HUB_URL_BASE", 15 | "https://raw.githubusercontent.com/hwchase17/langchain-hub/{ref}/", 16 | ) 17 | HUB_PATH_RE = re.compile(r"lc(?P@[^:]+)?://(?P.*)") 18 | 19 | T = TypeVar("T") 20 | 21 | 22 | def try_load_from_hub( 23 | path: Union[str, Path], 24 | loader: Callable[[str], T], 25 | valid_prefix: str, 26 | valid_suffixes: Set[str], 27 | **kwargs: Any, 28 | ) -> Optional[T]: 29 | """Load configuration from hub. 
Returns None if path is not a hub path.""" 30 | if not isinstance(path, str) or not (match := HUB_PATH_RE.match(path)): 31 | return None 32 | ref, remote_path_str = match.groups() 33 | ref = ref[1:] if ref else DEFAULT_REF 34 | remote_path = Path(remote_path_str) 35 | if remote_path.parts[0] != valid_prefix: 36 | return None 37 | if remote_path.suffix[1:] not in valid_suffixes: 38 | raise ValueError(f"Unsupported file type, must be one of {valid_suffixes}.") 39 | 40 | # Using Path with URLs is not recommended, because on Windows 41 | # the backslash is used as the path separator, which can cause issues 42 | # when working with URLs that use forward slashes as the path separator. 43 | # Instead, use PurePosixPath to ensure that forward slashes are used as the 44 | # path separator, regardless of the operating system. 45 | full_url = urljoin(URL_BASE.format(ref=ref), PurePosixPath(remote_path).__str__()) 46 | 47 | r = requests.get(full_url, timeout=5) 48 | if r.status_code != 200: 49 | raise ValueError(f"Could not find file at {full_url}") 50 | with tempfile.TemporaryDirectory() as tmpdirname: 51 | file = Path(tmpdirname) / remote_path.name 52 | with open(file, "wb") as f: 53 | f.write(r.content) 54 | return loader(str(file), **kwargs) 55 | -------------------------------------------------------------------------------- /build/lib/lmchain/utils/math.py: -------------------------------------------------------------------------------- 1 | """Math utils.""" 2 | import logging 3 | from typing import List, Optional, Tuple, Union 4 | 5 | import numpy as np 6 | 7 | logger = logging.getLogger(__name__) 8 | 9 | Matrix = Union[List[List[float]], List[np.ndarray], np.ndarray] 10 | 11 | 12 | def cosine_similarity(X: Matrix, Y: Matrix) -> np.ndarray: 13 | """Row-wise cosine similarity between two equal-width matrices.""" 14 | if len(X) == 0 or len(Y) == 0: 15 | return np.array([]) 16 | 17 | X = np.array(X) 18 | Y = np.array(Y) 19 | if X.shape[1] != Y.shape[1]: 20 | raise ValueError( 21 | f"Number of columns in X and Y must be the same. X has shape {X.shape} " 22 | f"and Y has shape {Y.shape}." 23 | ) 24 | try: 25 | import simsimd as simd 26 | 27 | X = np.array(X, dtype=np.float32) 28 | Y = np.array(Y, dtype=np.float32) 29 | Z = 1 - simd.cdist(X, Y, metric="cosine") 30 | if isinstance(Z, float): 31 | return np.array([Z]) 32 | return Z 33 | except ImportError: 34 | logger.info( 35 | "Unable to import simsimd, defaulting to NumPy implementation. If you want " 36 | "to use simsimd please install with `pip install simsimd`." 37 | ) 38 | X_norm = np.linalg.norm(X, axis=1) 39 | Y_norm = np.linalg.norm(Y, axis=1) 40 | # Ignore divide by zero errors run time warnings as those are handled below. 41 | with np.errstate(divide="ignore", invalid="ignore"): 42 | similarity = np.dot(X, Y.T) / np.outer(X_norm, Y_norm) 43 | similarity[np.isnan(similarity) | np.isinf(similarity)] = 0.0 44 | return similarity 45 | 46 | 47 | def cosine_similarity_top_k( 48 | X: Matrix, 49 | Y: Matrix, 50 | top_k: Optional[int] = 5, 51 | score_threshold: Optional[float] = None, 52 | ) -> Tuple[List[Tuple[int, int]], List[float]]: 53 | """Row-wise cosine similarity with optional top-k and score threshold filtering. 54 | 55 | Args: 56 | X: Matrix. 57 | Y: Matrix, same width as X. 58 | top_k: Max number of results to return. 59 | score_threshold: Minimum cosine similarity of results. 60 | 61 | Returns: 62 | Tuple of two lists. First contains two-tuples of indices (X_idx, Y_idx), 63 | second contains corresponding cosine similarities. 
64 |     """
65 |     if len(X) == 0 or len(Y) == 0:
66 |         return [], []
67 |     score_array = cosine_similarity(X, Y)
68 |     score_threshold = score_threshold or -1.0
69 |     score_array[score_array < score_threshold] = 0
70 |     top_k = min(top_k or len(score_array), np.count_nonzero(score_array))
71 |     top_k_idxs = np.argpartition(score_array, -top_k, axis=None)[-top_k:]
72 |     top_k_idxs = top_k_idxs[np.argsort(score_array.ravel()[top_k_idxs])][::-1]
73 |     ret_idxs = np.unravel_index(top_k_idxs, score_array.shape)
74 |     scores = score_array.ravel()[top_k_idxs].tolist()
75 |     return list(zip(*ret_idxs)), scores  # type: ignore
--------------------------------------------------------------------------------
/build/lib/lmchain/vectorstores/__init__.py:
--------------------------------------------------------------------------------
1 | name = "vectorstores"
--------------------------------------------------------------------------------
/build/lib/lmchain/vectorstores/chroma.py:
--------------------------------------------------------------------------------
1 | import json
2 | 
3 | from langchain.docstore.document import Document
4 | from langchain.text_splitter import RecursiveCharacterTextSplitter
5 | from lmchain.embeddings import embeddings
6 | from lmchain.vectorstores import laiss
7 | 
8 | from langchain.memory import ConversationBufferMemory
9 | from langchain.prompts import (
10 |     ChatPromptTemplate,  # class for building chat templates
11 |     MessagesPlaceholder,  # class for inserting message placeholders into a template
12 |     SystemMessagePromptTemplate,  # class for building system message templates
13 |     HumanMessagePromptTemplate  # class for building human message templates
14 | )
15 | from langchain.chains import ConversationChain
16 | 
17 | class Chroma:
18 |     def __init__(self, documents, embedding_tool, chunk_size=1280, chunk_overlap=50, source="这是一份辅助材料"):
19 |         """
20 |         :param document: the input text content; a single text string
21 |         :param chunk_size: number of characters in each chunk after splitting
22 |         :param chunk_overlap: number of overlapping characters between adjacent chunks
23 |         :param source: name/location of the text
24 |         """
25 |         self.text_splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
26 |         self.embedding_tool = embedding_tool
27 | 
28 |         self.lmaiss = laiss.LMASS()  # the class used to turn text into vectors and to compute similarity against a query
29 | 
30 |         self.documents = []
31 |         self.vectorstores = []
32 | 
33 |         "---------------------------"
34 |         for document in documents:
35 |             document = [Document(page_content=document, metadata={"source": source})]  # normalize the format of the input document
36 |             doc = self.text_splitter.split_documents(document)  # split according to the chunk settings
37 |             self.documents.extend(doc)
38 | 
39 |             vector = self.lmaiss.from_documents(document, embedding_class=self.embedding_tool)
40 |             self.vectorstores.extend(vector)
41 | 
42 |     # def __call__(self, query):
43 |     #     query_embedding = self.embedding_tool.embed_query(query)
44 |     # 
45 |     #     # find the vector closest to the query
46 |     #     close_id = self.lmaiss.get_similarity_vector_indexs(query_embedding, self.vectorstore, k=1)[0]
47 |     #     # look up the id of the closest chunk
48 |     #     doc = self.documents[close_id]
49 |     # 
50 |     # 
51 |     #     return doc
52 | 
53 |     def similarity_search(self, query):
54 |         query_embedding = self.embedding_tool.embed_query(query)
55 | 
56 |         # find the vector closest to the query
57 |         close_id = self.lmaiss.get_similarity_vector_indexs(query_embedding, self.vectorstores, k=1)[0]
58 |         # look up the id of the closest chunk
59 |         doc = self.documents[close_id]
60 |         return doc
61 | 
62 |     def add_texts(self, texts, metadata=""):
63 |         for document in texts:
64 |             document = [Document(page_content=document, metadata={"source": metadata})]  # normalize the format of the input document
65 |             doc = self.text_splitter.split_documents(document)  # split according to the chunk settings
66 |             self.documents.extend(doc)
67 | 
68 |             vector = self.lmaiss.from_documents(document, embedding_class=self.embedding_tool)
69 |             self.vectorstores.extend(vector)
70 | 
71 |         return True
72 | 
73 | 
74 | def from_texts(texts, embeddings, source=""):
75 |     docsearch = Chroma(documents=texts, embedding_tool=embeddings, source=source)
76 |     return docsearch
77 | 
78 | 
79 | # def from_texts(texts,embeddings):
80 | #     embs = embeddings.embed_documents(texts=texts)
81 | #     return embs
--------------------------------------------------------------------------------
/build/lib/lmchain/vectorstores/embeddings.py:
--------------------------------------------------------------------------------
1 | import warnings
2 | warnings.filterwarnings("ignore")
3 | 
4 | import asyncio
5 | from abc import ABC, abstractmethod
6 | from typing import List
7 | 
8 | 
9 | class Embeddings(ABC):
10 |     """Interface for embedding models."""
11 | 
12 |     @abstractmethod
13 |     def embed_documents(self, texts: List[str]) -> List[List[float]]:
14 |         """Embed search docs."""
15 | 
16 |     @abstractmethod
17 |     def embed_query(self, text: str) -> List[float]:
18 |         """Embed query text."""
19 | 
20 |     async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
21 |         """Asynchronous Embed search docs."""
22 |         return await asyncio.get_running_loop().run_in_executor(
23 |             None, self.embed_documents, texts
24 |         )
25 | 
26 |     async def aembed_query(self, text: str) -> List[float]:
27 |         """Asynchronous Embed query text."""
28 |         return await asyncio.get_running_loop().run_in_executor(
29 |             None, self.embed_query, text
30 |         )
31 | 
32 | 
33 | # class LMEmbedding(Embeddings):
34 | #     from modelscope.pipelines import pipeline
35 | #     from modelscope.utils.constant import Tasks
36 | #     pipeline_se = pipeline(Tasks.sentence_embedding, model='thomas/text2vec-base-chinese', model_revision='v1.0.0',
37 | #                            device="cuda")
38 | # 
39 | #     def _costruct_inputs(self, texts):
40 | #         inputs = {
41 | #             "source_sentence": texts
42 | #         }
43 | # 
44 | #         return inputs
45 | # 
46 | #     def embed_documents(self, texts: List[str]) -> List[List[float]]:
47 | #         """Embed search docs."""
48 | # 
49 | #         inputs = self._costruct_inputs(texts)
50 | #         result_embeddings = self.pipeline_se(input=inputs)
51 | #         return result_embeddings["text_embedding"]
52 | # 
53 | #     def embed_query(self, text: str) -> List[float]:
54 | #         """Embed query text."""
55 | #         inputs = self._costruct_inputs([text])
56 | #         result_embeddings = self.pipeline_se(input=inputs)
57 | #         return result_embeddings["text_embedding"]
58 | 
59 | 
60 | class GLMEmbedding(Embeddings):
61 |     import zhipuai as zhipuai
62 |     zhipuai.api_key = "1f565e40af1198e11ff1fd8a5b42771d.SjNfezc40YFsz2KC"  # APIKey obtained from the console
63 | 
64 |     def _costruct_inputs(self, texts):
65 |         inputs = {
66 |             "source_sentence": texts
67 |         }
68 | 
69 |         return inputs
70 | 
71 |     aembeddings = []  # stores the embedding list contents while embedding values are fetched concurrently.
72 |     atexts = []
73 | 
74 |     def embed_documents(self, texts: List[str]) -> List[List[float]]:
75 |         """Embed search docs."""
76 |         result_embeddings = []
77 |         for text in texts:
78 |             embedding = self.embed_query(text)
79 |             result_embeddings.append(embedding)
80 |         return result_embeddings
81 | 
82 |     def embed_query(self, text: str) -> List[float]:
83 |         """Embed query text."""
84 |         result_embeddings = self.zhipuai.model_api.invoke(
85 |             model="text_embedding", prompt=text)
86 |         return result_embeddings["data"]["embedding"]
87 | 
88 |     def aembed_query(self, text: str) -> None:
89 |         """Embed query text and store the result (used as a thread target)."""
90 |         result_embeddings = self.zhipuai.model_api.invoke(
91 |             model="text_embedding", prompt=text)
92 |         emb = result_embeddings["data"]["embedding"]
93 | 
94 |         self.aembeddings.append(emb)
95 |         self.atexts.append(text)
96 | 
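    # Editor's note on the concurrency below: aembed_query above appends into
    # the shared aembeddings/atexts lists instead of returning, so results come
    # back in completion order, not input order (atexts is returned so callers
    # can re-align them). Also note that len(texts) not divisible by thread_num
    # leaves a remainder batch unprocessed, and thread.join(wait_sec) only
    # bounds the wait without cancelling a slow request.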
97 |     # Concurrent embedding retrieval is implemented here.
98 |     def aembed_documents(self, texts: List[str], thread_num=5, wait_sec=0.3) -> List[List[float]]:
99 |         import threading
100 |         text_length = len(texts)
101 |         thread_batch = text_length // thread_num
102 | 
103 |         for i in range(thread_batch):
104 |             start = i * thread_num
105 |             end = (i + 1) * thread_num
106 | 
107 |             # create the thread list
108 |             threads = []
109 |             # create and start thread_num threads, each issuing one model call
110 |             for text in texts[start:end]:
111 |                 thread = threading.Thread(target=self.aembed_query, args=(text,))
112 |                 thread.start()
113 |                 threads.append(thread)
114 |             for thread in threads:
115 |                 thread.join(wait_sec)  # set the join timeout to 0.3 seconds
116 |         return self.aembeddings, self.atexts
117 | 
118 | 
119 | if __name__ == '__main__':
120 |     import time
121 | 
122 |     inputs = ["不可以，早晨喝牛奶不科学", "今天早晨喝牛奶不科学", "早晨喝牛奶不科学"] * 50
123 | 
124 |     start_time = time.time()
125 |     aembeddings = (GLMEmbedding().aembed_documents(inputs, thread_num=5, wait_sec=0.3))
126 |     print(aembeddings)
127 |     print(len(aembeddings))
128 |     end_time = time.time()
129 |     # compute and print the elapsed time
130 |     execution_time = end_time - start_time
131 |     print(f"execution time: {execution_time} s")
132 |     print("----------------------------------------------------------------------------------")
133 |     start_time = time.time()
134 |     aembeddings = (GLMEmbedding().embed_documents(inputs))
135 |     print(len(aembeddings))
136 |     end_time = time.time()
137 |     # compute and print the elapsed time
138 |     execution_time = end_time - start_time
139 |     print(f"execution time: {execution_time} s")
140 | 
--------------------------------------------------------------------------------
/build/lib/lmchain/vectorstores/utils.py:
--------------------------------------------------------------------------------
1 | """Utility functions for working with vectors and vectorstores."""
2 | 
3 | from enum import Enum
4 | from typing import List, Tuple, Type
5 | 
6 | import numpy as np
7 | 
8 | from lmchain.schema.document import Document
9 | from lmchain.utils.math import cosine_similarity
10 | 
11 | class DistanceStrategy(str, Enum):
12 |     """Enumerator of the Distance strategies for calculating distances
13 |     between vectors."""
14 | 
15 |     EUCLIDEAN_DISTANCE = "EUCLIDEAN_DISTANCE"
16 |     MAX_INNER_PRODUCT = "MAX_INNER_PRODUCT"
17 |     DOT_PRODUCT = "DOT_PRODUCT"
18 |     JACCARD = "JACCARD"
19 |     COSINE = "COSINE"
20 | 
21 | 
22 | def maximal_marginal_relevance(
23 |     query_embedding: np.ndarray,
24 |     embedding_list: list,
25 |     lambda_mult: float = 0.5,
26 |     k: int = 4,
27 | ) -> List[int]:
28 |     """Calculate maximal marginal relevance."""
29 |     if min(k, len(embedding_list)) <= 0:
30 |         return []
31 |     if query_embedding.ndim == 1:
32 |         query_embedding = np.expand_dims(query_embedding, axis=0)
33 |     similarity_to_query = cosine_similarity(query_embedding, embedding_list)[0]
34 |     most_similar = int(np.argmax(similarity_to_query))
35 |     idxs = [most_similar]
36 |     selected = np.array([embedding_list[most_similar]])
37 |     while len(idxs) < min(k, len(embedding_list)):
38 |         best_score = -np.inf
39 |         idx_to_add = -1
40 |         similarity_to_selected = cosine_similarity(embedding_list, selected)
41 |         for i, query_score in enumerate(similarity_to_query):
42 |             if i in idxs:
43 |                 continue
44 |             redundant_score = max(similarity_to_selected[i])
45 |             equation_score = (
46 |                 lambda_mult * query_score - (1 - lambda_mult) * redundant_score
47 |             )
48 |             if equation_score > best_score:
49 |                 best_score = equation_score
50 |                 idx_to_add = i
51 |         idxs.append(idx_to_add)
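        # Editor's note: each pass greedily picks the candidate maximizing
        # lambda_mult * relevance - (1 - lambda_mult) * redundancy; the chosen
        # embedding is appended to `selected` below so later picks are
        # penalized for resembling it.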
52 |         selected = np.append(selected, [embedding_list[idx_to_add]], axis=0)
53 |     return idxs
54 | 
55 | 
56 | def filter_complex_metadata(
57 |     documents: List[Document],
58 |     *,
59 |     allowed_types: Tuple[Type, ...] = (str, bool, int, float),
60 | ) -> List[Document]:
61 |     """Filter out metadata types that are not supported for a vector store."""
62 |     updated_documents = []
63 |     for document in documents:
64 |         filtered_metadata = {}
65 |         for key, value in document.metadata.items():
66 |             if not isinstance(value, allowed_types):
67 |                 continue
68 |             filtered_metadata[key] = value
69 | 
70 |         document.metadata = filtered_metadata
71 |         updated_documents.append(document)
72 | 
73 |     return updated_documents
74 | 
--------------------------------------------------------------------------------
/dist/LMchain-0.1.60-py3-none-any.whl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/virgo777/lmchain/3fc4a55df07100d4fe0beef4da89e3fbc6288938/dist/LMchain-0.1.60-py3-none-any.whl
--------------------------------------------------------------------------------
/dist/LMchain-0.1.60.tar.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/virgo777/lmchain/3fc4a55df07100d4fe0beef4da89e3fbc6288938/dist/LMchain-0.1.60.tar.gz
--------------------------------------------------------------------------------
/dist/LMchain-0.1.61-py3-none-any.whl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/virgo777/lmchain/3fc4a55df07100d4fe0beef4da89e3fbc6288938/dist/LMchain-0.1.61-py3-none-any.whl
--------------------------------------------------------------------------------
/dist/LMchain-0.1.61.tar.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/virgo777/lmchain/3fc4a55df07100d4fe0beef4da89e3fbc6288938/dist/LMchain-0.1.61.tar.gz
--------------------------------------------------------------------------------
/dist/LMchain-0.1.62-py3-none-any.whl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/virgo777/lmchain/3fc4a55df07100d4fe0beef4da89e3fbc6288938/dist/LMchain-0.1.62-py3-none-any.whl
--------------------------------------------------------------------------------
/dist/LMchain-0.1.62.tar.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/virgo777/lmchain/3fc4a55df07100d4fe0beef4da89e3fbc6288938/dist/LMchain-0.1.62.tar.gz
--------------------------------------------------------------------------------
/lmchain/__init__.py:
--------------------------------------------------------------------------------
1 | name = "lmchain"
--------------------------------------------------------------------------------
/lmchain/__pycache__/__init__.cpython-311.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/virgo777/lmchain/3fc4a55df07100d4fe0beef4da89e3fbc6288938/lmchain/__pycache__/__init__.cpython-311.pyc
--------------------------------------------------------------------------------
/lmchain/agents/__init__.py:
--------------------------------------------------------------------------------
1 | name = "agents"
--------------------------------------------------------------------------------
/lmchain/agents/__pycache__/__init__.cpython-311.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/virgo777/lmchain/3fc4a55df07100d4fe0beef4da89e3fbc6288938/lmchain/agents/__pycache__/__init__.cpython-311.pyc
--------------------------------------------------------------------------------
/lmchain/agents/__pycache__/llmAgent.cpython-311.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/virgo777/lmchain/3fc4a55df07100d4fe0beef4da89e3fbc6288938/lmchain/agents/__pycache__/llmAgent.cpython-311.pyc
--------------------------------------------------------------------------------
/lmchain/agents/__pycache__/llmMultiAgent.cpython-311.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/virgo777/lmchain/3fc4a55df07100d4fe0beef4da89e3fbc6288938/lmchain/agents/__pycache__/llmMultiAgent.cpython-311.pyc
--------------------------------------------------------------------------------
/lmchain/agents/llmAgent.py:
--------------------------------------------------------------------------------
1 | import time
2 | import logging
3 | import requests
4 | from typing import Optional, List, Dict, Mapping, Any
5 | 
6 | import langchain
7 | from langchain.llms.base import LLM
8 | from langchain.cache import InMemoryCache
9 | 
10 | logging.basicConfig(level=logging.INFO)
11 | # enable the LLM cache
12 | langchain.llm_cache = InMemoryCache()
13 | 
14 | 
15 | class AgentChatGLM(LLM):
16 |     # model service url
17 |     url = "http://127.0.0.1:7866/chat"
18 |     # url = "http://192.168.3.20:7866/chat"  # on the 3050 server
19 |     history = []
20 | 
21 |     @property
22 |     def _llm_type(self) -> str:
23 |         return "chatglm"
24 | 
25 |     def _construct_query(self, prompt: str) -> Dict:
26 |         """Build the request body.
27 |         """
28 |         query = {"query": prompt, "history": self.history}
29 |         import json
30 |         query = json.dumps(query)  # JSON-encode the request parameters
31 | 
32 |         return query
33 | 
34 |     def _construct_query_tools(self, prompt: str, tools: list) -> Dict:
35 |         """Build the request body for a tool-assisted query.
36 |         """
37 |         tools_info = {"role": "system",
38 |                       "content": "你现在是一个查找使用何种工具以及传递何种参数的工具助手，你会一步步的思考问题。你根据需求查找工具函数箱中最合适的工具函数，然后返回工具函数名称和所工具函数对应的参数，参数必须要和需求中的目标对应。",
39 |                       "tools": tools}
40 |         query = {"query": prompt, "history": tools_info}
41 |         import json
42 |         query = json.dumps(query)  # JSON-encode the request parameters
43 | 
44 |         return query
45 | 
46 | 
47 |     @classmethod
48 |     def _post(self, url: str, query: Dict) -> Any:
49 | 
50 |         """POST request"""
51 |         response = requests.post(url, data=query).json()
52 |         return response
53 | 
54 |     def _call(self, prompt: str, stop: Optional[List[str]] = None, tools: list = None) -> str:
55 |         """_call"""
56 |         if tools is None:
57 |             # construct query
58 |             query = self._construct_query(prompt=prompt)
59 | 
60 |             # post
61 |             response = self._post(url=self.url, query=query)
62 | 
63 |             response_chat = response["response"]
64 |             self.history = response["history"]
65 | 
66 |             return response_chat
67 |         else:
68 | 
69 |             query = self._construct_query_tools(prompt=prompt, tools=tools)
70 |             # post
71 |             response = self._post(url=self.url, query=query)
72 |             self.history = response["history"]  # history must be updated before the response is consumed
73 |             response = response["response"]
74 |             try:
75 |                 # import ast
76 |                 # response = ast.literal_eval(response)
77 |                 ret = tool_register.dispatch_tool(response["name"], response["parameters"])
78 |                 response_chat = llm(prompt=ret)
79 |             except:
80 |                 response_chat = response
81 |             return str(response_chat)
82 | 
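    # Illustrative sketch (editor's addition, not in the original file): a
    # plain call posts the prompt to the service and keeps history on the
    # instance, while passing `tools=` takes the dispatch branch above. Note
    # that branch resolves `tool_register` and `llm` as module-level names, so
    # it only works when this module is run as a script (see __main__ below).
    #
    #     llm = AgentChatGLM()
    #     print(llm("Hello"))                                   # plain chat
    #     print(llm("What's the weather in Shanghai?",
    #               tools=tool_register.get_tools()))           # tool dispatch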
83 |     @property
84 |     def _identifying_params(self) -> Mapping[str, Any]:
85 |         """Get the identifying parameters.
86 |         """
87 |         _param_dict = {
88 |             "url": self.url
89 |         }
90 |         return _param_dict
91 | 
92 | 
93 | if __name__ == "__main__":
94 | 
95 |     import tool_register
96 | 
97 |     # fetch all registered tools, returned in json form
98 |     tools = tool_register.get_tools()
99 |     "--------------------------------------first, the tools are defined---------------------------------------"
100 | 
101 |     llm = AgentChatGLM()
102 |     llm.url = "http://192.168.3.20:7866/chat"
103 |     while True:
104 |         while True:
105 |             human_input = input("Human: ")
106 |             if human_input == "tools":
107 |                 break
108 | 
109 |             begin_time = time.time() * 1000
110 |             # query the model
111 |             response = llm(human_input)
112 |             end_time = time.time() * 1000
113 |             used_time = round(end_time - begin_time, 3)
114 |             # logging.info(f"chatGLM process time: {used_time}ms")
115 |             print(f"Chat: {response}")
116 | 
117 |         human_input = input("Human_with_tools_Ask: ")
118 |         response = llm(prompt=human_input, tools=tools)
119 |         print(f"Chat_with_tools_Que: {response}")
120 | 
--------------------------------------------------------------------------------
/lmchain/agents/llmMultiAgent.py:
--------------------------------------------------------------------------------
1 | 
2 | 
3 | import logging
4 | import requests
5 | from typing import Optional, List, Dict, Mapping, Any
6 | import langchain
7 | from langchain.llms.base import LLM
8 | from langchain.cache import InMemoryCache
9 | 
10 | logging.basicConfig(level=logging.INFO)
11 | # enable the LLM cache
12 | langchain.llm_cache = InMemoryCache()
13 | 
14 | 
15 | class AgentZhipuAI(LLM):
16 |     import zhipuai as zhipuai
17 |     # model service url
18 |     url = "127.0.0.1"
19 |     zhipuai.api_key = "1f565e40af1198e11ff1fd8a5b42771d.SjNfezc40YFsz2KC"  # APIKey obtained from the console
20 |     model = "chatglm_pro"  # large model version
21 | 
22 |     history = []
23 | 
24 |     def getText(self, role, content):
25 |         # role specifies the speaker role; content is the prompt text
26 |         jsoncon = {}
27 |         jsoncon["role"] = role
28 |         jsoncon["content"] = content
29 |         self.history.append(jsoncon)
30 |         return self.history
31 | 
32 |     @property
33 |     def _llm_type(self) -> str:
34 |         return "AgentZhipuAI"
35 | 
36 |     @classmethod
37 |     def _post(self, url: str, query: Dict) -> Any:
38 | 
39 |         """POST request"""
40 |         response = requests.post(url, data=query).json()
41 |         return response
42 | 
43 |     def _call(self, prompt: str, stop: Optional[List[str]] = None, role="user") -> str:
44 |         """_call"""
45 |         # construct query
46 |         response = self.zhipuai.model_api.invoke(
47 |             model=self.model,
48 |             prompt=self.getText(role=role, content=prompt)
49 |         )
50 |         choices = (response['data']['choices'])[0]
51 |         self.history.append(choices)
52 |         return choices["content"]
53 | 
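    # Editor's note: each _call appends both the user turn (via getText) and
    # the model's reply to self.history, and `history` is a class attribute, so
    # the context grows without bound and is shared across AgentZhipuAI
    # instances; long-running sessions may want to trim or reset it.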
57 | """ 58 | _param_dict = { 59 | "url": self.url 60 | } 61 | return _param_dict 62 | 63 | 64 | if __name__ == '__main__': 65 | from langchain.prompts import PromptTemplate 66 | from langchain.chains import LLMChain 67 | 68 | llm = AgentZhipuAI() 69 | 70 | # 没有输入变量的示例prompt 71 | no_input_prompt = PromptTemplate(input_variables=[], template="给我讲个笑话。") 72 | no_input_prompt.format() 73 | 74 | prompt = PromptTemplate( 75 | input_variables=["location", "street"], 76 | template="作为一名专业的旅游顾问,简单的说一下{location}有什么好玩的景点,特别是在{street}?只要说一个就可以。", 77 | ) 78 | 79 | chain = LLMChain(llm=llm, prompt=prompt) 80 | print(chain.run({"location": "南京", "street": "新街口"})) 81 | 82 | 83 | from langchain.chains import ConversationChain 84 | conversation = ConversationChain(llm=llm, verbose=True) 85 | 86 | output = conversation.predict(input="你好!") 87 | print(output) 88 | 89 | output = conversation.predict(input="南京是哪里的省会?") 90 | print(output) 91 | 92 | output = conversation.predict(input="那里有什么好玩的地方,简单的说一个就好。") 93 | print(output) 94 | 95 | -------------------------------------------------------------------------------- /lmchain/callbacks/__init__.py: -------------------------------------------------------------------------------- 1 | name = "callbacks" -------------------------------------------------------------------------------- /lmchain/callbacks/stdout.py: -------------------------------------------------------------------------------- 1 | """Callback Handler that prints to std out.""" 2 | from typing import Any, Dict, List, Optional 3 | 4 | from langchain.callbacks.base import BaseCallbackHandler 5 | from langchain.schema import AgentAction, AgentFinish, LLMResult 6 | from lmchain.utils.input import print_text 7 | 8 | 9 | class StdOutCallbackHandler(BaseCallbackHandler): 10 | """Callback Handler that prints to std out.""" 11 | 12 | def __init__(self, color: Optional[str] = None) -> None: 13 | """Initialize callback handler.""" 14 | self.color = color 15 | 16 | def on_llm_start( 17 | self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any 18 | ) -> None: 19 | """Print out the prompts.""" 20 | pass 21 | 22 | def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None: 23 | """Do nothing.""" 24 | pass 25 | 26 | def on_llm_new_token(self, token: str, **kwargs: Any) -> None: 27 | """Do nothing.""" 28 | pass 29 | 30 | def on_llm_error(self, error: BaseException, **kwargs: Any) -> None: 31 | """Do nothing.""" 32 | pass 33 | 34 | def on_chain_start( 35 | self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any 36 | ) -> None: 37 | """Print out that we are entering a chain.""" 38 | class_name = serialized.get("name", serialized.get("id", [""])[-1]) 39 | print(f"\n\n\033[1m> Entering new {class_name} chain...\033[0m") 40 | 41 | def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None: 42 | """Print out that we finished a chain.""" 43 | print("\n\033[1m> Finished chain.\033[0m") 44 | 45 | def on_chain_error(self, error: BaseException, **kwargs: Any) -> None: 46 | """Do nothing.""" 47 | pass 48 | 49 | def on_tool_start( 50 | self, 51 | serialized: Dict[str, Any], 52 | input_str: str, 53 | **kwargs: Any, 54 | ) -> None: 55 | """Do nothing.""" 56 | pass 57 | 58 | def on_agent_action( 59 | self, action: AgentAction, color: Optional[str] = None, **kwargs: Any 60 | ) -> Any: 61 | """Run on agent action.""" 62 | print_text(action.log, color=color or self.color) 63 | 64 | def on_tool_end( 65 | self, 66 | output: str, 67 | color: Optional[str] = None, 68 | observation_prefix: 
Optional[str] = None,
69 |         llm_prefix: Optional[str] = None,
70 |         **kwargs: Any,
71 |     ) -> None:
72 |         """If not the final action, print out observation."""
73 |         if observation_prefix is not None:
74 |             print_text(f"\n{observation_prefix}")
75 |         print_text(output, color=color or self.color)
76 |         if llm_prefix is not None:
77 |             print_text(f"\n{llm_prefix}")
78 | 
79 |     def on_tool_error(self, error: BaseException, **kwargs: Any) -> None:
80 |         """Do nothing."""
81 |         pass
82 | 
83 |     def on_text(
84 |         self,
85 |         text: str,
86 |         color: Optional[str] = None,
87 |         end: str = "",
88 |         **kwargs: Any,
89 |     ) -> None:
90 |         """Run when agent ends."""
91 |         print_text(text, color=color or self.color, end=end)
92 | 
93 |     def on_agent_finish(
94 |         self, finish: AgentFinish, color: Optional[str] = None, **kwargs: Any
95 |     ) -> None:
96 |         """Run on agent end."""
97 |         print_text(finish.log, color=color or self.color, end="\n")
98 | 
-------------------------------------------------------------------------------- /lmchain/chains/__init__.py: -------------------------------------------------------------------------------- 1 | name = "chains" -------------------------------------------------------------------------------- /lmchain/chains/__pycache__/__init__.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/virgo777/lmchain/3fc4a55df07100d4fe0beef4da89e3fbc6288938/lmchain/chains/__pycache__/__init__.cpython-311.pyc -------------------------------------------------------------------------------- /lmchain/chains/__pycache__/cmd.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/virgo777/lmchain/3fc4a55df07100d4fe0beef4da89e3fbc6288938/lmchain/chains/__pycache__/cmd.cpython-311.pyc -------------------------------------------------------------------------------- /lmchain/chains/__pycache__/mathChain.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/virgo777/lmchain/3fc4a55df07100d4fe0beef4da89e3fbc6288938/lmchain/chains/__pycache__/mathChain.cpython-311.pyc -------------------------------------------------------------------------------- /lmchain/chains/__pycache__/urlRequestChain.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/virgo777/lmchain/3fc4a55df07100d4fe0beef4da89e3fbc6288938/lmchain/chains/__pycache__/urlRequestChain.cpython-311.pyc -------------------------------------------------------------------------------- /lmchain/chains/cmd.py: --------------------------------------------------------------------------------
1 | # Chain that asks the LLM to write a runnable Windows terminal command and then executes it
2 | from langchain.chains.llm import LLMChain
3 | from langchain.prompts import PromptTemplate
4 | from lmchain.agents import llmAgent
5 | import os, re
6 | 
7 | class LLMCMDChain:
8 |     def __init__(self, llm):
9 |         qa_prompt = PromptTemplate(
10 |             template="""你现在根据需要完成对命令行的编写,要根据需求编写对应的在Windows系统终端运行的命令,不要用%question形参这种指代的参数形式,直接给出可以运行的命令。
11 | Question: 给我一个在Windows系统终端中可以准确执行{question}的命令。
12 | Answer:""",
13 |             input_variables=["question"],
14 |         )
15 |         self.qa_chain = LLMChain(llm=llm, prompt=qa_prompt)
16 |         self.pattern = r"```(.*?)```"  # matches a fenced code block (run() below splits on ``` directly)
17 | 
18 |     def run(self, text):
19 |         cmd_response = self.qa_chain.run(question=text)
20 |         cmd_string = str(cmd_response).split("```")[-2][1:-1]  # take the fenced block, trimming the newlines around it
21 |         os.system(cmd_string)
22 |
        return cmd_string
23 | 
-------------------------------------------------------------------------------- /lmchain/chains/conversationalRetrievalChain.py: --------------------------------------------------------------------------------
1 | from langchain.docstore.document import Document
2 | from langchain.text_splitter import RecursiveCharacterTextSplitter
3 | from lmchain.embeddings import embeddings
4 | from lmchain.vectorstores import laiss
5 | from lmchain.agents import llmMultiAgent
6 | from langchain.memory import ConversationBufferMemory
7 | from langchain.prompts import (
8 |     ChatPromptTemplate,  # builds chat prompt templates
9 |     MessagesPlaceholder,  # inserts a message placeholder into the template
10 |     SystemMessagePromptTemplate,  # builds system message templates
11 |     HumanMessagePromptTemplate  # builds human message templates
12 | )
13 | from langchain.chains import ConversationChain
14 | 
15 | class ConversationalRetrievalChain:
16 |     def __init__(self, document, chunk_size=1280, chunk_overlap=50, file_name="这是一份辅助材料"):
17 |         """
18 |         :param document: the input text, as a single string
19 |         :param chunk_size: number of characters per chunk after splitting
20 |         :param chunk_overlap: number of characters adjacent chunks share
21 |         :param file_name: document name / path
22 |         """
23 |         self.text_splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
24 |         self.embedding_tool = embeddings.GLMEmbedding()
25 | 
26 |         self.lmaiss = laiss.LMASS()  # converts the texts to vectors and computes query similarities
27 |         self.llm = llmMultiAgent.AgentZhipuAI()
28 |         self.memory = ConversationBufferMemory(return_messages=True)
29 | 
30 |         conversation_prompt = ChatPromptTemplate.from_messages([
31 |             SystemMessagePromptTemplate.from_template("你是一个最强大的人工智能程序,可以知无不答,但是你不懂的东西会直接回答不知道。"),
32 |             MessagesPlaceholder(variable_name="history"),  # placeholder for past messages
33 |             HumanMessagePromptTemplate.from_template("{input}")  # template for the human input
34 |         ])
35 | 
36 |         self.qa_chain = ConversationChain(memory=self.memory, prompt=conversation_prompt, llm=self.llm)
37 |         "---------------------------"
38 |         document = [Document(page_content=document, metadata={"source": file_name})]  # normalise the input document
39 |         self.documents = self.text_splitter.split_documents(document)  # split into overlapping chunks
40 |         self.vectorstore = self.lmaiss.from_documents(self.documents, embedding_class=self.embedding_tool)
41 | 
42 |     def __call__(self, query):
43 |         query_embedding = self.embedding_tool.embed_query(query)
44 | 
45 |         # find the index of the chunk most similar to the query
46 |         close_id = self.lmaiss.get_similarity_vector_indexs(query_embedding, self.vectorstore, k=1)[0]
47 |         # look up the corresponding paragraph
48 |         doc = self.documents[close_id]
49 | 
50 |         # build the final query
51 |         query = f"你现在要回答问题'{query}',你可以参考文献'{doc}',你如果找不到对应的内容,就从自己的记忆体中查找,就回答'请提供更为准确的查询内容',注意你要一步步的思考再回答。"
52 |         result = (self.qa_chain.predict(input=query))
53 |         return result
54 | 
55 |     def predict(self, input):
56 |         result = self.__call__(input)
57 |         return result
-------------------------------------------------------------------------------- /lmchain/chains/mathChain.py: --------------------------------------------------------------------------------
1 | # Chain that turns a natural-language request into a numexpr expression and evaluates it
2 | 
3 | from langchain.chains.llm import LLMChain
4 | from langchain.prompts import PromptTemplate
5 | from lmchain.agents import llmAgent
6 | import os, re, math
7 | 
8 | try:
9 |     import numexpr  # noqa: F401
10 | except ImportError:
11 |     raise ImportError(
12 |         "LMchain requires the numexpr package. "
13 |         "Please install it with `pip install numexpr`."
14 |     )
15 | 
16 | 
17 | class LLMMathChain:
18 |     def __init__(self, llm):
19 |         qa_prompt = PromptTemplate(
20 |             template="""现在给你一个中文命令,请你把这个命令转化成数学公式。直接给出数学公式,这个公式会在numexpr包中调用。
21 | Question: 我现在需要计算{question},结果需要在numexpr包中调用。
22 | Answer:""",
23 |             input_variables=["question"],
24 |         )
25 |         self.qa_chain = LLMChain(llm=llm, prompt=qa_prompt)
26 | 
27 | 
28 |     def run(self, text):
29 |         cmd_response = self.qa_chain.run(question=text)
30 |         result = self._evaluate_expression(str(cmd_response))
31 |         return result
32 | 
33 | 
34 |     def _evaluate_expression(self, expression: str) -> str:
35 |         import numexpr  # noqa: F401
36 | 
37 |         try:
38 |             local_dict = {"pi": math.pi, "e": math.e}
39 |             output = str(
40 |                 numexpr.evaluate(
41 |                     expression.strip(),
42 |                     global_dict={},  # restrict access to globals
43 |                     local_dict=local_dict,  # add common mathematical constants
44 |                 )
45 |             )
46 |         except Exception as e:
47 |             raise ValueError(
48 |                 f'LMchain._evaluate("{expression}") raised error: {e}.'
49 |                 " Please try again with a valid numerical expression"
50 |             )
51 | 
52 |         # Remove any leading and trailing brackets from the output
53 |         return re.sub(r"^\[|\]$", "", output)
-------------------------------------------------------------------------------- /lmchain/chains/question_answering.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | -------------------------------------------------------------------------------- /lmchain/chains/subQuestChain.py: --------------------------------------------------------------------------------
1 | from langchain.chains import LLMChain
2 | from langchain.prompts import PromptTemplate
3 | 
4 | from tqdm import tqdm
5 | from lmchain.tools import tool_register
6 | 
7 | 
8 | class SubQuestChain:
9 |     def __init__(self, llm):
10 |         self.llm = llm
11 | 
12 |     def __call__(self, query=""):
13 |         if query == "":
14 |             raise ValueError("query must contain the question to decompose")
15 | 
16 |         decomp_template = """
17 |         GENERAL INSTRUCTIONS
18 |         You are a domain expert. Your task is to break down a complex question into simpler sub-parts.
19 | 
20 |         USER QUESTION
21 |         {user_question}
22 | 
23 |         ANSWER FORMAT
24 |         ["sub-questions_1","sub-questions_2","sub-questions_3",...]
25 | """ 26 | 27 | from langchain.prompts import PromptTemplate 28 | prompt = PromptTemplate( 29 | input_variables=["user_question"], 30 | template=decomp_template, 31 | ) 32 | 33 | from langchain.chains import LLMChain 34 | chain = LLMChain(llm=self.llm, prompt=prompt) 35 | response = (chain.run({"user_question": query})) 36 | 37 | import json 38 | sub_list = json.loads(response) 39 | 40 | return sub_list 41 | 42 | def run(self, query): 43 | sub_list = self.__call__(query) 44 | return sub_list 45 | 46 | 47 | if __name__ == '__main__': 48 | from lmchain.agents import llmMultiAgent 49 | 50 | llm = llmMultiAgent.AgentZhipuAI() 51 | 52 | subQC = SubQuestChain(llm) 53 | response = subQC.run(query="工商银行财报中,2024财年Q1与Q2 之间,利润增长了多少?") 54 | print(response) -------------------------------------------------------------------------------- /lmchain/chains/toolchain.py: -------------------------------------------------------------------------------- 1 | from langchain.chains import LLMChain 2 | from langchain.prompts import PromptTemplate 3 | 4 | from tqdm import tqdm 5 | from lmchain.tools import tool_register 6 | 7 | 8 | class GLMToolChain: 9 | def __init__(self, llm): 10 | 11 | self.llm = llm 12 | self.tool_register = tool_register 13 | self.tools = tool_register.get_tools() 14 | 15 | def __call__(self, query="", tools=None): 16 | 17 | if query == "": 18 | raise "query需要填入查询问题" 19 | if tools != None: 20 | self.tools = tools 21 | else: 22 | raise "将使用默认tools完成函数工具调用~" 23 | template = f""" 24 | 你现在是一个专业的人工智能助手,你现在的需求是{query}。而你需要借助于工具在{self.tools}中找到对应的函数,用json格式返回对应的函数名和参数。 25 | 函数名定义为function_name,参数名为params,还要求写入详细的形参与实参。 26 | 27 | 如果找到合适的函数,就返回json格式的函数名和需要的参数,不要回答任何描述和解释。 28 | 29 | 如果没有找到合适的函数,则返回:'未找到合适参数,请提供更详细的描述。' 30 | """ 31 | 32 | flag = True 33 | counter = 0 34 | while flag: 35 | try: 36 | res = self.llm(template) 37 | 38 | import json 39 | res_dict = json.loads(res) 40 | res_dict = json.loads(res_dict) 41 | flag = False 42 | except: 43 | # print("失败输出,现在开始重新验证") 44 | template = f""" 45 | 你现在是一个专业的人工智能助手,你现在的需求是{query}。而你需要借助于工具在{self.tools}中找到对应的函数,用json格式返回对应的函数名和参数。 46 | 函数名定义为function_name,参数名为params,还要求写入详细的形参与实参。 47 | 48 | 如果找到合适的函数,就返回json格式的函数名和需要的参数,不要回答任何描述和解释。 49 | 50 | 如果没有找到合适的函数,则返回:'未找到合适参数,请提供更详细的描述。' 51 | 52 | 你刚才生成了一组结果,但是返回不符合json格式,现在请你重新按json格式生成并返回结果。 53 | """ 54 | counter += 1 55 | if counter >= 5: 56 | return '未找到合适参数,请提供更详细的描述。' 57 | return res_dict 58 | 59 | def run(self, query, tools=None): 60 | tools = (self.tool_register.get_tools()) 61 | result = self.__call__(query, tools) 62 | 63 | if result == "未找到合适参数,请提供更详细的描述。": 64 | return "未找到合适参数,请提供更详细的描述。" 65 | else: 66 | print("找到对应工具函数,格式如下:", result) 67 | result = self.dispatch_tool(result) 68 | from lmchain.prompts.templates import PromptTemplate 69 | tool_prompt = PromptTemplate( 70 | input_variables=["query", "result"], # 输入变量包括中文和英文。 71 | template="你现在是一个私人助手,现在你的查询任务是{query},而你通过工具从网上查询的结果是{result},现在根据查询的内容与查询的结果,生成最终答案。", 72 | # 使用模板格式化输入和输出。 73 | ) 74 | from langchain.chains import LLMChain 75 | chain = LLMChain(llm=self.llm, prompt=tool_prompt) 76 | 77 | response = (chain.run({"query": query, "result": result})) 78 | 79 | return response 80 | 81 | def add_tools(self, tool): 82 | self.tool_register.register_tool(tool) 83 | return True 84 | 85 | def dispatch_tool(self, tool_result) -> str: 86 | tool_name = tool_result["function_name"] 87 | tool_params = tool_result["params"] 88 | if tool_name not in self.tool_register._TOOL_HOOKS: 89 | return f"Tool `{tool_name}` not found. Please use a provided tool." 
90 |         tool_call = self.tool_register._TOOL_HOOKS[tool_name]
91 | 
92 |         try:
93 |             ret = tool_call(**tool_params)
94 |         except Exception:
95 |             import traceback
96 |             ret = traceback.format_exc()
97 |         return str(ret)
98 | 
99 |     def get_tools(self):
100 |         return (self.tool_register.get_tools())
101 | 
102 | 
103 | if __name__ == '__main__':
104 |     from lmchain.agents import llmMultiAgent
105 | 
106 |     llm = llmMultiAgent.AgentZhipuAI()
107 | 
108 |     from lmchain.chains import toolchain
109 | 
110 |     tool_chain = toolchain.GLMToolChain(llm)
111 | 
112 |     from typing import Annotated
113 | 
114 | 
115 |     def rando_numbr(
116 |             seed: Annotated[int, 'The random seed used by the generator', True],
117 |             range: Annotated[tuple[int, int], 'The range of the generated numbers', True],
118 |     ) -> int:
119 |         """
120 |         Generates a random number x, s.t. range[0] <= x <= range[1]
121 |         """
122 |         import random
123 |         return random.Random(seed).randint(*range)
124 | 
125 | 
126 |     tool_chain.add_tools(rando_numbr)
127 | 
128 |     print("------------------------------------------------------")
129 |     query = "今天shanghai的天气是什么?"
130 |     result = tool_chain.run(query)
131 | 
132 |     # run() already dispatches the chosen tool internally and returns the final answer
133 |     print(result)
134 | 
135 | 
-------------------------------------------------------------------------------- /lmchain/chains/urlRequestChain.py: --------------------------------------------------------------------------------
1 | from langchain.chains import LLMRequestsChain, LLMChain
2 | from langchain.prompts import PromptTemplate
3 | 
4 | import requests
5 | from bs4 import BeautifulSoup
6 | from tqdm import tqdm
7 | 
8 | 
9 | class LMRequestsChain:
10 |     def __init__(self, llm, max_url_num=2):
11 |         template = """Between >>> and <<< are the raw search result text from a Bing web search.
12 | Extract the answer to the question '{query}' or say "not found" if the information is not contained.
13 | Use the format
14 | Extracted:<answer or "not found">
15 | >>> {requests_result} <<<
16 | Extracted:"""
17 |         PROMPT = PromptTemplate(
18 |             input_variables=["query", "requests_result"],
19 |             template=template,
20 |         )
21 |         self.chain = LLMRequestsChain(llm_chain=LLMChain(llm=llm, prompt=PROMPT))
22 |         self.max_url_num = max_url_num
23 | 
24 |         query_prompt = PromptTemplate(
25 |             input_variables=["query", "responses"],
26 |             template="作为一名专业的信息总结员,我需要查询的信息为{query},根据提供的信息{responses}回答一下查询的结果。")
27 |         self.query_chain = LLMChain(llm=llm, prompt=query_prompt)
28 | 
29 |     def __call__(self, query, target_site=""):
30 |         url_list = self.get_urls(query, target_site=target_site)
31 |         print(f"查找到{len(url_list)}条url内容,现在开始解析其中的{self.max_url_num}条内容。")
32 |         responses = []
33 |         for url in tqdm(url_list[:self.max_url_num]):
34 |             inputs = {
35 |                 "query": query,
36 |                 "url": url
37 |             }
38 | 
39 |             response = self.chain(inputs)
40 |             output = response["output"]
41 |             responses.append(output)
42 |         if len(responses) != 0:
43 |             output = self.query_chain.run({"query": query, "responses": responses})
44 |             return output
45 |         else:
46 |             return "查找内容为空,请更换查找词"
47 | 
48 |     def query_form_url(self, query="LMchain是什么?", url=""):
49 |         assert url != "", "url link must be set"
50 |         inputs = {
51 |             "query": query,
52 |             "url": url
53 |         }
54 |         response = self.chain(inputs)
55 |         return response
56 | 
57 |     def get_urls(self, query='lmchain是什么?', target_site=""):
58 |         def bing_search(query, count=30):
59 |             url = f'https://cn.bing.com/search?q={query}'
60 |             headers = {
61 |                 'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/58.0.3029.110 Safari/537.3'}
62 |             response = requests.get(url, headers=headers)
63 |             if response.status_code == 200:
64 |                 html = response.text
65 |                 # parse the result page with BeautifulSoup
66 | 
67 |                 soup = BeautifulSoup(html, 'html.parser')
68 |                 results = soup.find_all('li', class_='b_algo')
69 |                 return [result.find('a').text for result in results[:count]]
70 |             else:
71 |                 print(f'请求失败,状态码:{response.status_code}')
72 |                 return []
73 |         results = bing_search(query)
74 |         if len(results) == 0:
75 |             return None
76 |         url_list = []
77 |         if target_site != "":
78 |             for i, result in enumerate(results):
79 |                 if "https" in result and target_site in result:
80 |                     url = "https://" + result.split("https://")[1]
81 |                     url_list.append(url)
82 |         else:
83 |             for i, result in enumerate(results):
84 |                 if "https" in result:
85 |                     url = "https://" + result.split("https://")[1]
86 |                     url_list.append(url)
87 |         if len(url_list) > 0:
88 |             return url_list
89 |         else:
90 |             # fallback: if no result matched the target site, return any https link that was found
91 |             for i, result in enumerate(results):
92 |                 if "https" in result:
93 |                     url = "https://" + result.split("https://")[1]
94 |                     url_list.append(url)
95 |             return url_list
96 | 
97 | 
-------------------------------------------------------------------------------- /lmchain/embeddings/__init__.py: -------------------------------------------------------------------------------- 1 | name = "embeddings" -------------------------------------------------------------------------------- /lmchain/index/__init__.py: -------------------------------------------------------------------------------- 1 | name = "index" -------------------------------------------------------------------------------- /lmchain/index/indexChain.py: --------------------------------------------------------------------------------
1 | from typing import Any, Dict, List, Optional, Type
2 | 
3 | from langchain.chains.qa_with_sources.retrieval import RetrievalQAWithSourcesChain
4
| from langchain.chains.retrieval_qa.base import RetrievalQA 5 | from langchain.document_loaders.base import BaseLoader 6 | from pydantic.v1 import BaseModel, Extra, Field 7 | from langchain.schema import Document 8 | from langchain.schema.embeddings import Embeddings 9 | from langchain.schema.language_model import BaseLanguageModel 10 | from langchain.schema.vectorstore import VectorStore 11 | from langchain.text_splitter import RecursiveCharacterTextSplitter, TextSplitter 12 | from langchain.vectorstores.chroma import Chroma 13 | 14 | 15 | def _get_default_text_splitter() -> TextSplitter: 16 | return RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=0) 17 | 18 | from lmchain.embeddings import embeddings 19 | embedding_tool = embeddings.GLMEmbedding() 20 | 21 | class VectorstoreIndexCreator(BaseModel): 22 | """Logic for creating indexes.""" 23 | 24 | class Config: 25 | """Configuration for this pydantic object.""" 26 | extra = Extra.forbid 27 | arbitrary_types_allowed = True 28 | 29 | 30 | 31 | 32 | chunk_size = 1280 # 每段字数长度 33 | chunk_overlap = 32 # 重叠的字数 34 | text_splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap) 35 | 36 | 37 | 38 | 39 | 40 | def from_loaders(self, loaders: List[BaseLoader]): 41 | """Create a vectorstore index from loaders.""" 42 | docs = [] 43 | for loader in loaders: 44 | docs.extend(loader.load()) 45 | return self.from_documents(docs) 46 | 47 | 48 | def from_documents(self, documents: List[Document]): 49 | #说一下这个index的作用就是返回 50 | sub_docs = self.text_splitter.split_documents(documents) 51 | 52 | # texts = [d.page_content for d in sub_docs] 53 | # metadatas = [d.metadata for d in sub_docs] 54 | 55 | qa_chain = ConversationalRetrievalChain(document=sub_docs) 56 | return qa_chain 57 | 58 | 59 | from langchain.docstore.document import Document 60 | from langchain.text_splitter import RecursiveCharacterTextSplitter 61 | from lmchain.embeddings import embeddings 62 | from lmchain.vectorstores import laiss 63 | from lmchain.agents import llmMultiAgent 64 | from langchain.memory import ConversationBufferMemory 65 | from langchain.prompts import ( 66 | ChatPromptTemplate, # 用于构建聊天模板的类 67 | MessagesPlaceholder, # 用于在模板中插入消息占位的类 68 | SystemMessagePromptTemplate, # 用于构建系统消息模板的类 69 | HumanMessagePromptTemplate # 用于构建人类消息模板的类 70 | ) 71 | from langchain.chains import ConversationChain 72 | 73 | class ConversationalRetrievalChain: 74 | def __init__(self,document,chunk_size = 1280,chunk_overlap = 50,file_name = "这是一份辅助材料"): 75 | """ 76 | :param document: 输入的文本内容,只要一个text文本 77 | :param chunk_size: 切分后每段的字数 78 | :param chunk_overlap: 每个相隔段落重叠的字数 79 | :param file_name: 文本名称/文本地址 80 | """ 81 | self.text_splitter = RecursiveCharacterTextSplitter( chunk_size=chunk_size, chunk_overlap=chunk_overlap) 82 | self.embedding_tool = embedding_tool 83 | 84 | self.lmaiss = laiss.LMASS() #这里是用于将文本转化为vector,并且计算query相应的相似度的类 85 | self.llm = llmMultiAgent.AgentZhipuAI() 86 | self.memory = ConversationBufferMemory(return_messages=True) 87 | 88 | conversation_prompt = ChatPromptTemplate.from_messages([ 89 | SystemMessagePromptTemplate.from_template("你是一个最强大的人工智能程序,可以知无不答,但是你不懂的东西会直接回答不知道。"), 90 | MessagesPlaceholder(variable_name="history"), # 历史消息占位符 91 | HumanMessagePromptTemplate.from_template("{input}") # 人类消息输入模板 92 | ]) 93 | 94 | self.qa_chain = ConversationChain(memory=self.memory, prompt=conversation_prompt, llm=self.llm) 95 | "---------------------------" 96 | self.metadatas = [] 97 | for doc in document: 98 | 
self.metadatas.append(doc.metadata) 99 | self.documents = self.text_splitter.split_documents(document) #根据 100 | self.vectorstore = self.lmaiss.from_documents(self.documents, embedding_class=self.embedding_tool) 101 | 102 | 103 | 104 | def __call__(self, query): 105 | query_embedding = self.embedding_tool.embed_query(query) 106 | 107 | #根据query查找最近的那个序列 108 | close_id = self.lmaiss.get_similarity_vector_indexs(query_embedding, self.vectorstore, k=1)[0] 109 | #查找最近的那个段落id 110 | doc = self.documents[close_id] 111 | meta = self.metadatas[close_id] 112 | #构建查询的query 113 | query = f"你现在要回答问题'{query}',你可以参考文献'{doc}',你如果找不到对应的内容,就从自己的记忆体中查找,就回答'请提供更为准确的查询内容'。" 114 | result = (self.qa_chain.predict(input=query)) 115 | return result,meta 116 | 117 | 118 | def query(self,input): 119 | result,meta = self.__call__(input) 120 | return result 121 | 122 | #这里的模型的意思是 123 | def query_with_sources(self,input): 124 | result,meta = self.__call__(input) 125 | return {"answer":result,"sources":meta} 126 | -------------------------------------------------------------------------------- /lmchain/llms/__init__.py: -------------------------------------------------------------------------------- 1 | name = "llms" -------------------------------------------------------------------------------- /lmchain/llms/base.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/virgo777/lmchain/3fc4a55df07100d4fe0beef4da89e3fbc6288938/lmchain/llms/base.py -------------------------------------------------------------------------------- /lmchain/load/__init__.py: -------------------------------------------------------------------------------- 1 | name = "load" -------------------------------------------------------------------------------- /lmchain/load/__pycache__/__init__.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/virgo777/lmchain/3fc4a55df07100d4fe0beef4da89e3fbc6288938/lmchain/load/__pycache__/__init__.cpython-311.pyc -------------------------------------------------------------------------------- /lmchain/load/__pycache__/serializable.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/virgo777/lmchain/3fc4a55df07100d4fe0beef4da89e3fbc6288938/lmchain/load/__pycache__/serializable.cpython-311.pyc -------------------------------------------------------------------------------- /lmchain/memory/__init__.py: -------------------------------------------------------------------------------- 1 | name = "memory" -------------------------------------------------------------------------------- /lmchain/memory/chat_memory.py: -------------------------------------------------------------------------------- 1 | from abc import ABC 2 | from typing import Any, Dict, Optional, Tuple 3 | 4 | from lmchain.memory.utils import get_prompt_input_key 5 | 6 | from lmchain.schema.memory import BaseMemory 7 | 8 | 9 | class BaseChatMemory(BaseMemory, ABC): 10 | """Abstract base class for chat memory.""" 11 | 12 | from lmchain.memory import messageHistory 13 | chat_memory = messageHistory.ChatMessageHistory() 14 | output_key: Optional[str] = None 15 | input_key: Optional[str] = None 16 | return_messages: bool = False 17 | 18 | def _get_input_output( 19 | self, inputs: Dict[str, Any], outputs: Dict[str, str] 20 | ) -> Tuple[str, str]: 21 | if self.input_key is None: 22 | prompt_input_key = get_prompt_input_key(inputs, 
self.memory_variables)
23 |         else:
24 |             prompt_input_key = self.input_key
25 |         if self.output_key is None:
26 |             if len(outputs) != 1:
27 |                 raise ValueError(f"One output key expected, got {outputs.keys()}")
28 |             output_key = list(outputs.keys())[0]
29 |         else:
30 |             output_key = self.output_key
31 |         return inputs[prompt_input_key], outputs[output_key]
32 | 
33 |     def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
34 |         """Save context from this conversation to buffer."""
35 |         input_str, output_str = self._get_input_output(inputs, outputs)
36 |         self.chat_memory.add_user_message(input_str)
37 |         self.chat_memory.add_ai_message(output_str)
38 | 
39 |     def clear(self) -> None:
40 |         """Clear memory contents."""
41 |         self.chat_memory.clear()
42 | 
-------------------------------------------------------------------------------- /lmchain/memory/messageHistory.py: --------------------------------------------------------------------------------
1 | from typing import TYPE_CHECKING, Any, Dict, List, Sequence, Union
2 | from typing_extensions import Literal
3 | 
4 | 
5 | class BaseChatMessageHistory:
6 |     """In memory implementation of chat message history.
7 | 
8 |     Stores messages in an in memory list.
9 |     """
10 | 
11 |     messages = []
12 | 
13 |     def add_message(self, message) -> None:
14 |         """Add a self-created message to the store"""
15 |         self.messages.append(message)
16 | 
17 |     def clear(self) -> None:
18 |         self.messages = []
19 | 
20 |     def __str__(self):
21 |         return ", ".join(str(message) for message in self.messages)
22 | 
23 | 
24 | class ChatMessageHistory(BaseChatMessageHistory):
25 |     def __init__(self):
26 |         self.messages = []  # per-instance store, so separate histories never share one list
27 | 
28 |     def add_user_message(self, content: str) -> None:
29 |         """Convenience method for adding a human message string to the store.
30 | 
31 |         Args:
32 |             content: The string contents of a human message.
33 |         """
34 |         mes = f"HumanMessage(content={content})"
35 |         self.messages.append(mes)
36 | 
37 |     def add_ai_message(self, content: str) -> None:
38 |         """Convenience method for adding an AI message string to the store.
39 | 
40 |         Args:
41 |             content: The string contents of an AI message.
42 |         """
43 |         mes = f"AIMessage(content={content})"
44 |         self.messages.append(mes)
45 | 
46 | 
47 | from typing import Any, Dict, List, Optional
48 | 
49 | from langchain.memory.chat_memory import BaseChatMemory, BaseMemory
50 | from langchain.memory.utils import get_prompt_input_key
51 | from pydantic.v1 import root_validator
52 | from langchain.schema.messages import BaseMessage, get_buffer_string
53 | 
54 | 
55 | class ConversationBufferMemory(BaseChatMemory):
56 |     """Buffer for storing conversation memory."""
57 | 
58 |     human_prefix: str = "Human"
59 |     ai_prefix: str = "AI"
60 |     memory_key: str = "history"  #: :meta private:
61 | 
62 |     @property
63 |     def buffer(self) -> Any:
64 |         """String buffer of memory."""
65 |         return self.buffer_as_messages if self.return_messages else self.buffer_as_str
66 | 
67 |     @property
68 |     def buffer_as_str(self) -> str:
69 |         """Exposes the buffer as a string in case return_messages is True."""
70 |         return get_buffer_string(
71 |             self.chat_memory.messages,
72 |             human_prefix=self.human_prefix,
73 |             ai_prefix=self.ai_prefix,
74 |         )
75 | 
76 |     @property
77 |     def buffer_as_messages(self) -> List[BaseMessage]:
78 |         """Exposes the buffer as a list of messages in case return_messages is False."""
79 |         return self.chat_memory.messages
80 | 
81 |     @property
82 |     def memory_variables(self) -> List[str]:
83 |         """Will always return list of memory variables.
84 | 85 | :meta private: 86 | """ 87 | return [self.memory_key] 88 | 89 | def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]: 90 | """Return history buffer.""" 91 | return {self.memory_key: self.buffer} 92 | 93 | 94 | class ConversationStringBufferMemory(BaseMemory): 95 | """Buffer for storing conversation memory.""" 96 | 97 | human_prefix: str = "Human" 98 | ai_prefix: str = "AI" 99 | """Prefix to use for AI generated responses.""" 100 | buffer: str = "" 101 | output_key: Optional[str] = None 102 | input_key: Optional[str] = None 103 | memory_key: str = "history" #: :meta private: 104 | 105 | @root_validator() 106 | def validate_chains(cls, values: Dict) -> Dict: 107 | """Validate that return messages is not True.""" 108 | if values.get("return_messages", False): 109 | raise ValueError( 110 | "return_messages must be False for ConversationStringBufferMemory" 111 | ) 112 | return values 113 | 114 | @property 115 | def memory_variables(self) -> List[str]: 116 | """Will always return list of memory variables. 117 | :meta private: 118 | """ 119 | return [self.memory_key] 120 | 121 | def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, str]: 122 | """Return history buffer.""" 123 | return {self.memory_key: self.buffer} 124 | 125 | def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None: 126 | """Save context from this conversation to buffer.""" 127 | if self.input_key is None: 128 | prompt_input_key = get_prompt_input_key(inputs, self.memory_variables) 129 | else: 130 | prompt_input_key = self.input_key 131 | if self.output_key is None: 132 | if len(outputs) != 1: 133 | raise ValueError(f"One output key expected, got {outputs.keys()}") 134 | output_key = list(outputs.keys())[0] 135 | else: 136 | output_key = self.output_key 137 | human = f"{self.human_prefix}: " + inputs[prompt_input_key] 138 | ai = f"{self.ai_prefix}: " + outputs[output_key] 139 | self.buffer += "\n" + "\n".join([human, ai]) 140 | 141 | def clear(self) -> None: 142 | """Clear memory contents.""" 143 | self.buffer = "" 144 | 145 | 146 | 147 | 148 | 149 | 150 | 151 | 152 | -------------------------------------------------------------------------------- /lmchain/memory/utils.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, List 2 | 3 | 4 | def get_prompt_input_key(inputs: Dict[str, Any], memory_variables: List[str]) -> str: 5 | """ 6 | Get the prompt input key. 7 | 8 | Args: 9 | inputs: Dict[str, Any] 10 | memory_variables: List[str] 11 | 12 | Returns: 13 | A prompt input key. 14 | """ 15 | # "stop" is a special key that can be passed as input but is not used to 16 | # format the prompt. 
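    # e.g. inputs={"question": "...", "stop": ["\n"], "history": "..."} with
    # memory_variables=["history"] leaves exactly one candidate key: "question".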
17 |     prompt_input_keys = list(set(inputs).difference(memory_variables + ["stop"]))
18 |     if len(prompt_input_keys) != 1:
19 |         raise ValueError(f"One input key expected, got {prompt_input_keys}")
20 |     return prompt_input_keys[0]
-------------------------------------------------------------------------------- /lmchain/model/__init__.py: -------------------------------------------------------------------------------- 1 | name = "model" -------------------------------------------------------------------------------- /lmchain/prompts/__init__.py: -------------------------------------------------------------------------------- 1 | name = "prompts" -------------------------------------------------------------------------------- /lmchain/prompts/__pycache__/__init__.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/virgo777/lmchain/3fc4a55df07100d4fe0beef4da89e3fbc6288938/lmchain/prompts/__pycache__/__init__.cpython-311.pyc -------------------------------------------------------------------------------- /lmchain/prompts/__pycache__/base.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/virgo777/lmchain/3fc4a55df07100d4fe0beef4da89e3fbc6288938/lmchain/prompts/__pycache__/base.cpython-311.pyc -------------------------------------------------------------------------------- /lmchain/prompts/base.py: --------------------------------------------------------------------------------
1 | """BasePrompt schema definition."""
2 | from __future__ import annotations
3 | 
4 | import warnings
5 | from abc import ABC
6 | from string import Formatter
7 | from typing import Any, Callable, Dict, List, Literal, Set
8 | 
9 | from lmchain.schema.messages import BaseMessage, HumanMessage
10 | from lmchain.schema.prompt import PromptValue
11 | from lmchain.schema.prompt_template import BasePromptTemplate
12 | #from langchain.schema.prompt_template import BasePromptTemplate
13 | from lmchain.utils.formatting import formatter
14 | 
15 | 
16 | def jinja2_formatter(template: str, **kwargs: Any) -> str:
17 |     """Format a template using jinja2.
18 | 
19 |     *Security warning*: As of LangChain 0.0.329, this method uses Jinja2's
20 |     SandboxedEnvironment by default. However, this sand-boxing should
21 |     be treated as a best-effort approach rather than a guarantee of security.
22 |     Do not accept jinja2 templates from untrusted sources as they may lead
23 |     to arbitrary Python code execution.
24 | 
25 |     https://jinja.palletsprojects.com/en/3.1.x/sandbox/
26 |     """
27 |     try:
28 |         from jinja2.sandbox import SandboxedEnvironment
29 |     except ImportError:
30 |         raise ImportError(
31 |             "jinja2 not installed, which is needed to use the jinja2_formatter. "
32 |             "Please install it with `pip install jinja2`."
33 |             "Please be cautious when using jinja2 templates. "
34 |             "Do not expand jinja2 templates using unverified or user-controlled "
35 |             "inputs as that can result in arbitrary Python code execution."
36 |         )
37 | 
38 |     # This uses a sandboxed environment to prevent arbitrary code execution.
39 |     # Jinja2 uses an opt-out rather than opt-in approach for sand-boxing.
40 |     # Please treat this sand-boxing as a best-effort approach rather than
41 |     # a guarantee of security.
42 |     # We recommend to never use jinja2 templates with untrusted inputs.
43 |     # https://jinja.palletsprojects.com/en/3.1.x/sandbox/
44 |     # The sandbox mitigates, but does not eliminate, this risk.
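    # e.g. jinja2_formatter("Hello {{ name }}!", name="GLM") -> "Hello GLM!"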
45 | return SandboxedEnvironment().from_string(template).render(**kwargs) 46 | 47 | 48 | def validate_jinja2(template: str, input_variables: List[str]) -> None: 49 | """ 50 | Validate that the input variables are valid for the template. 51 | Issues a warning if missing or extra variables are found. 52 | 53 | Args: 54 | template: The template string. 55 | input_variables: The input variables. 56 | """ 57 | input_variables_set = set(input_variables) 58 | valid_variables = _get_jinja2_variables_from_template(template) 59 | missing_variables = valid_variables - input_variables_set 60 | extra_variables = input_variables_set - valid_variables 61 | 62 | warning_message = "" 63 | if missing_variables: 64 | warning_message += f"Missing variables: {missing_variables} " 65 | 66 | if extra_variables: 67 | warning_message += f"Extra variables: {extra_variables}" 68 | 69 | if warning_message: 70 | warnings.warn(warning_message.strip()) 71 | 72 | 73 | def _get_jinja2_variables_from_template(template: str) -> Set[str]: 74 | try: 75 | from jinja2 import Environment, meta 76 | except ImportError: 77 | raise ImportError( 78 | "jinja2 not installed, which is needed to use the jinja2_formatter. " 79 | "Please install it with `pip install jinja2`." 80 | ) 81 | env = Environment() 82 | ast = env.parse(template) 83 | variables = meta.find_undeclared_variables(ast) 84 | return variables 85 | 86 | 87 | DEFAULT_FORMATTER_MAPPING: Dict[str, Callable] = { 88 | "f-string": formatter.format, 89 | "jinja2": jinja2_formatter, 90 | } 91 | 92 | DEFAULT_VALIDATOR_MAPPING: Dict[str, Callable] = { 93 | "f-string": formatter.validate_input_variables, 94 | "jinja2": validate_jinja2, 95 | } 96 | 97 | 98 | def check_valid_template( 99 | template: str, template_format: str, input_variables: List[str] 100 | ) -> None: 101 | """Check that template string is valid. 102 | 103 | Args: 104 | template: The template string. 105 | template_format: The template format. Should be one of "f-string" or "jinja2". 106 | input_variables: The input variables. 107 | 108 | Raises: 109 | ValueError: If the template format is not supported. 110 | """ 111 | if template_format not in DEFAULT_FORMATTER_MAPPING: 112 | valid_formats = list(DEFAULT_FORMATTER_MAPPING) 113 | raise ValueError( 114 | f"Invalid template format. Got `{template_format}`;" 115 | f" should be one of {valid_formats}" 116 | ) 117 | try: 118 | validator_func = DEFAULT_VALIDATOR_MAPPING[template_format] 119 | validator_func(template, input_variables) 120 | except KeyError as e: 121 | raise ValueError( 122 | "Invalid prompt schema; check for mismatched or missing input parameters. " 123 | + str(e) 124 | ) 125 | 126 | 127 | def get_template_variables(template: str, template_format: str) -> List[str]: 128 | """Get the variables from the template. 129 | 130 | Args: 131 | template: The template string. 132 | template_format: The template format. Should be one of "f-string" or "jinja2". 133 | 134 | Returns: 135 | The variables from the template. 136 | 137 | Raises: 138 | ValueError: If the template format is not supported. 
139 | """ 140 | if template_format == "jinja2": 141 | # Get the variables for the template 142 | input_variables = _get_jinja2_variables_from_template(template) 143 | elif template_format == "f-string": 144 | input_variables = { 145 | v for _, v, _, _ in Formatter().parse(template) if v is not None 146 | } 147 | else: 148 | raise ValueError(f"Unsupported template format: {template_format}") 149 | 150 | return sorted(input_variables) 151 | 152 | 153 | class StringPromptValue(PromptValue): 154 | """String prompt value.""" 155 | 156 | text: str 157 | """Prompt text.""" 158 | type: Literal["StringPromptValue"] = "StringPromptValue" 159 | 160 | def to_string(self) -> str: 161 | """Return prompt as string.""" 162 | return self.text 163 | 164 | def to_messages(self) -> List[BaseMessage]: 165 | """Return prompt as messages.""" 166 | return [HumanMessage(content=self.text)] 167 | 168 | 169 | class StringPromptTemplate(BasePromptTemplate, ABC): 170 | """String prompt that exposes the format method, returning a prompt.""" 171 | 172 | def format_prompt(self, **kwargs: Any) -> PromptValue: 173 | """Create Chat Messages.""" 174 | return StringPromptValue(text=self.format(**kwargs)) 175 | -------------------------------------------------------------------------------- /lmchain/prompts/loading.py: -------------------------------------------------------------------------------- 1 | """Load prompts.""" 2 | import json 3 | import logging 4 | from pathlib import Path 5 | from typing import Callable, Dict, Union 6 | 7 | import yaml 8 | 9 | from lmchain.prompts.few_shot_templates import FewShotPromptTemplate 10 | from lmchain.prompts.prompt import PromptTemplate 11 | #from langchain.schema import BaseLLMOutputParser, BasePromptTemplate, StrOutputParser 12 | from lmchain.schema.output_parser import BaseLLMOutputParser,StrOutputParser 13 | from lmchain.schema.prompt_template import BasePromptTemplate 14 | 15 | 16 | 17 | from lmchain.utils.loading import try_load_from_hub 18 | 19 | URL_BASE = "https://raw.githubusercontent.com/hwchase17/langchain-hub/master/prompts/" 20 | logger = logging.getLogger(__name__) 21 | 22 | 23 | def load_prompt_from_config(config: dict) -> BasePromptTemplate: 24 | """Load prompt from Config Dict.""" 25 | if "_type" not in config: 26 | logger.warning("No `_type` key found, defaulting to `prompt`.") 27 | config_type = config.pop("_type", "prompt") 28 | 29 | if config_type not in type_to_loader_dict: 30 | raise ValueError(f"Loading {config_type} prompt not supported") 31 | 32 | prompt_loader = type_to_loader_dict[config_type] 33 | return prompt_loader(config) 34 | 35 | 36 | def _load_template(var_name: str, config: dict) -> dict: 37 | """Load template from the path if applicable.""" 38 | # Check if template_path exists in config. 39 | if f"{var_name}_path" in config: 40 | # If it does, make sure template variable doesn't also exist. 41 | if var_name in config: 42 | raise ValueError( 43 | f"Both `{var_name}_path` and `{var_name}` cannot be provided." 44 | ) 45 | # Pop the template path from the config. 46 | template_path = Path(config.pop(f"{var_name}_path")) 47 | # Load the template. 48 | if template_path.suffix == ".txt": 49 | with open(template_path) as f: 50 | template = f.read() 51 | else: 52 | raise ValueError 53 | # Set the template variable to the extracted variable. 
54 | config[var_name] = template 55 | return config 56 | 57 | 58 | def _load_examples(config: dict) -> dict: 59 | """Load examples if necessary.""" 60 | if isinstance(config["examples"], list): 61 | pass 62 | elif isinstance(config["examples"], str): 63 | with open(config["examples"]) as f: 64 | if config["examples"].endswith(".json"): 65 | examples = json.load(f) 66 | elif config["examples"].endswith((".yaml", ".yml")): 67 | examples = yaml.safe_load(f) 68 | else: 69 | raise ValueError( 70 | "Invalid file format. Only json or yaml formats are supported." 71 | ) 72 | config["examples"] = examples 73 | else: 74 | raise ValueError("Invalid examples format. Only list or string are supported.") 75 | return config 76 | 77 | 78 | def _load_output_parser(config: dict) -> dict: 79 | """Load output parser.""" 80 | if "output_parser" in config and config["output_parser"]: 81 | _config = config.pop("output_parser") 82 | output_parser_type = _config.pop("_type") 83 | if output_parser_type == "regex_parser": 84 | from langchain.output_parsers.regex import RegexParser 85 | 86 | output_parser: BaseLLMOutputParser = RegexParser(**_config) 87 | elif output_parser_type == "default": 88 | output_parser = StrOutputParser(**_config) 89 | else: 90 | raise ValueError(f"Unsupported output parser {output_parser_type}") 91 | config["output_parser"] = output_parser 92 | return config 93 | 94 | 95 | def _load_few_shot_prompt(config: dict) -> FewShotPromptTemplate: 96 | """Load the "few shot" prompt from the config.""" 97 | # Load the suffix and prefix templates. 98 | config = _load_template("suffix", config) 99 | config = _load_template("prefix", config) 100 | # Load the example prompt. 101 | if "example_prompt_path" in config: 102 | if "example_prompt" in config: 103 | raise ValueError( 104 | "Only one of example_prompt and example_prompt_path should " 105 | "be specified." 106 | ) 107 | config["example_prompt"] = load_prompt(config.pop("example_prompt_path")) 108 | else: 109 | config["example_prompt"] = load_prompt_from_config(config["example_prompt"]) 110 | # Load the examples. 111 | config = _load_examples(config) 112 | config = _load_output_parser(config) 113 | return FewShotPromptTemplate(**config) 114 | 115 | 116 | def _load_prompt(config: dict) -> PromptTemplate: 117 | """Load the prompt template from config.""" 118 | # Load the template from disk if necessary. 119 | config = _load_template("template", config) 120 | config = _load_output_parser(config) 121 | 122 | template_format = config.get("template_format", "f-string") 123 | if template_format == "jinja2": 124 | # Disabled due to: 125 | # https://github.com/langchain-ai/langchain/issues/4394 126 | raise ValueError( 127 | f"Loading templates with '{template_format}' format is no longer supported " 128 | f"since it can lead to arbitrary code execution. Please migrate to using " 129 | f"the 'f-string' template format, which does not suffer from this issue." 130 | ) 131 | 132 | return PromptTemplate(**config) 133 | 134 | 135 | def load_prompt(path: Union[str, Path]) -> BasePromptTemplate: 136 | """Unified method for loading a prompt from LangChainHub or local fs.""" 137 | if hub_result := try_load_from_hub( 138 | path, _load_prompt_from_file, "prompts", {"py", "json", "yaml"} 139 | ): 140 | return hub_result 141 | else: 142 | return _load_prompt_from_file(path) 143 | 144 | 145 | def _load_prompt_from_file(file: Union[str, Path]) -> BasePromptTemplate: 146 | """Load prompt from file.""" 147 | # Convert file to a Path object. 
148 | if isinstance(file, str): 149 | file_path = Path(file) 150 | else: 151 | file_path = file 152 | # Load from either json or yaml. 153 | if file_path.suffix == ".json": 154 | with open(file_path,encoding="UTF-8") as f: 155 | config = json.load(f) 156 | elif file_path.suffix == ".yaml": 157 | with open(file_path, "r") as f: 158 | config = yaml.safe_load(f) 159 | else: 160 | raise ValueError(f"Got unsupported file type {file_path.suffix}") 161 | # Load the prompt from the config now. 162 | return load_prompt_from_config(config) 163 | 164 | 165 | type_to_loader_dict: Dict[str, Callable[[dict], BasePromptTemplate]] = { 166 | "prompt": _load_prompt, 167 | "few_shot": _load_few_shot_prompt, 168 | } 169 | -------------------------------------------------------------------------------- /lmchain/schema/__init__.py: -------------------------------------------------------------------------------- 1 | name = "schema" -------------------------------------------------------------------------------- /lmchain/schema/__pycache__/__init__.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/virgo777/lmchain/3fc4a55df07100d4fe0beef4da89e3fbc6288938/lmchain/schema/__pycache__/__init__.cpython-311.pyc -------------------------------------------------------------------------------- /lmchain/schema/__pycache__/document.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/virgo777/lmchain/3fc4a55df07100d4fe0beef4da89e3fbc6288938/lmchain/schema/__pycache__/document.cpython-311.pyc -------------------------------------------------------------------------------- /lmchain/schema/__pycache__/messages.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/virgo777/lmchain/3fc4a55df07100d4fe0beef4da89e3fbc6288938/lmchain/schema/__pycache__/messages.cpython-311.pyc -------------------------------------------------------------------------------- /lmchain/schema/__pycache__/output.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/virgo777/lmchain/3fc4a55df07100d4fe0beef4da89e3fbc6288938/lmchain/schema/__pycache__/output.cpython-311.pyc -------------------------------------------------------------------------------- /lmchain/schema/__pycache__/output_parser.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/virgo777/lmchain/3fc4a55df07100d4fe0beef4da89e3fbc6288938/lmchain/schema/__pycache__/output_parser.cpython-311.pyc -------------------------------------------------------------------------------- /lmchain/schema/__pycache__/prompt.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/virgo777/lmchain/3fc4a55df07100d4fe0beef4da89e3fbc6288938/lmchain/schema/__pycache__/prompt.cpython-311.pyc -------------------------------------------------------------------------------- /lmchain/schema/__pycache__/prompt_template.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/virgo777/lmchain/3fc4a55df07100d4fe0beef4da89e3fbc6288938/lmchain/schema/__pycache__/prompt_template.cpython-311.pyc -------------------------------------------------------------------------------- 
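Before moving on to the schema package, here is a minimal usage sketch for the prompt loader above. The file name is hypothetical, and it assumes lmchain's PromptTemplate exposes the same format() API as LangChain's, which the imports in loading.py suggest:

    import json
    from lmchain.prompts.loading import load_prompt

    # A minimal config of _type "prompt"; load_prompt_from_config dispatches it
    # to _load_prompt via type_to_loader_dict.
    config = {
        "_type": "prompt",
        "input_variables": ["topic"],
        "template": "Write one sentence about {topic}.",
    }
    with open("demo_prompt.json", "w", encoding="UTF-8") as f:
        json.dump(config, f)

    prompt = load_prompt("demo_prompt.json")  # reads the .json file and builds a PromptTemplate
    print(prompt.format(topic="LMchain"))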
/lmchain/schema/agent.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from typing import Any, Literal, Sequence, Union 4 | 5 | from lmchain.load.serializable import Serializable 6 | from lmchain.schema.messages import BaseMessage 7 | 8 | 9 | class AgentAction(Serializable): 10 | """A full description of an action for an ActionAgent to execute.""" 11 | 12 | tool: str 13 | """The name of the Tool to execute.""" 14 | tool_input: Union[str, dict] 15 | """The input to pass in to the Tool.""" 16 | log: str 17 | """Additional information to log about the action. 18 | This log can be used in a few ways. First, it can be used to audit 19 | what exactly the LLM predicted to lead to this (tool, tool_input). 20 | Second, it can be used in future iterations to show the LLMs prior 21 | thoughts. This is useful when (tool, tool_input) does not contain 22 | full information about the LLM prediction (for example, any `thought` 23 | before the tool/tool_input).""" 24 | type: Literal["AgentAction"] = "AgentAction" 25 | 26 | def __init__( 27 | self, tool: str, tool_input: Union[str, dict], log: str, **kwargs: Any 28 | ): 29 | """Override init to support instantiation by position for backward compat.""" 30 | super().__init__(tool=tool, tool_input=tool_input, log=log, **kwargs) 31 | 32 | @classmethod 33 | def is_lc_serializable(cls) -> bool: 34 | """Return whether or not the class is serializable.""" 35 | return True 36 | 37 | 38 | class AgentActionMessageLog(AgentAction): 39 | message_log: Sequence[BaseMessage] 40 | """Similar to log, this can be used to pass along extra 41 | information about what exact messages were predicted by the LLM 42 | before parsing out the (tool, tool_input). This is again useful 43 | if (tool, tool_input) cannot be used to fully recreate the LLM 44 | prediction, and you need that LLM prediction (for future agent iteration). 45 | Compared to `log`, this is useful when the underlying LLM is a 46 | ChatModel (and therefore returns messages rather than a string).""" 47 | # Ignoring type because we're overriding the type from AgentAction. 48 | # And this is the correct thing to do in this case. 49 | # The type literal is used for serialization purposes. 50 | type: Literal["AgentActionMessageLog"] = "AgentActionMessageLog" # type: ignore 51 | 52 | 53 | class AgentFinish(Serializable): 54 | """The final return value of an ActionAgent.""" 55 | 56 | return_values: dict 57 | """Dictionary of return values.""" 58 | log: str 59 | """Additional information to log about the return value. 60 | This is used to pass along the full LLM prediction, not just the parsed out 61 | return value. For example, if the full LLM prediction was 62 | `Final Answer: 2` you may want to just return `2` as a return value, but pass 63 | along the full string as a `log` (for debugging or observability purposes). 
64 | """ 65 | type: Literal["AgentFinish"] = "AgentFinish" 66 | 67 | def __init__(self, return_values: dict, log: str, **kwargs: Any): 68 | """Override init to support instantiation by position for backward compat.""" 69 | super().__init__(return_values=return_values, log=log, **kwargs) 70 | 71 | @classmethod 72 | def is_lc_serializable(cls) -> bool: 73 | """Return whether or not the class is serializable.""" 74 | return True 75 | -------------------------------------------------------------------------------- /lmchain/schema/document.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import asyncio 4 | from abc import ABC, abstractmethod 5 | from functools import partial 6 | from typing import Any, Literal, Sequence 7 | 8 | from lmchain.load.serializable import Serializable 9 | from pydantic.v1 import Field 10 | 11 | class Document(Serializable): 12 | """Class for storing a piece of text and associated metadata.""" 13 | 14 | page_content: str 15 | """String text.""" 16 | metadata: dict = Field(default_factory=dict) 17 | """Arbitrary metadata about the page content (e.g., source, relationships to other 18 | documents, etc.). 19 | """ 20 | type: Literal["Document"] = "Document" 21 | 22 | @classmethod 23 | def is_lc_serializable(cls) -> bool: 24 | """Return whether this class is serializable.""" 25 | return True 26 | 27 | 28 | class BaseDocumentTransformer(ABC): 29 | """Abstract base class for document transformation systems. 30 | 31 | A document transformation system takes a sequence of Documents and returns a 32 | sequence of transformed Documents. 33 | 34 | Example: 35 | .. code-block:: python 36 | 37 | class EmbeddingsRedundantFilter(BaseDocumentTransformer, BaseModel): 38 | embeddings: Embeddings 39 | similarity_fn: Callable = cosine_similarity 40 | similarity_threshold: float = 0.95 41 | 42 | class Config: 43 | arbitrary_types_allowed = True 44 | 45 | def transform_documents( 46 | self, documents: Sequence[Document], **kwargs: Any 47 | ) -> Sequence[Document]: 48 | stateful_documents = get_stateful_documents(documents) 49 | embedded_documents = _get_embeddings_from_stateful_docs( 50 | self.embeddings, stateful_documents 51 | ) 52 | included_idxs = _filter_similar_embeddings( 53 | embedded_documents, self.similarity_fn, self.similarity_threshold 54 | ) 55 | return [stateful_documents[i] for i in sorted(included_idxs)] 56 | 57 | async def atransform_documents( 58 | self, documents: Sequence[Document], **kwargs: Any 59 | ) -> Sequence[Document]: 60 | raise NotImplementedError 61 | 62 | """ # noqa: E501 63 | 64 | @abstractmethod 65 | def transform_documents( 66 | self, documents: Sequence[Document], **kwargs: Any 67 | ) -> Sequence[Document]: 68 | """Transform a list of documents. 69 | 70 | Args: 71 | documents: A sequence of Documents to be transformed. 72 | 73 | Returns: 74 | A list of transformed Documents. 75 | """ 76 | 77 | async def atransform_documents( 78 | self, documents: Sequence[Document], **kwargs: Any 79 | ) -> Sequence[Document]: 80 | """Asynchronously transform a list of documents. 81 | 82 | Args: 83 | documents: A sequence of Documents to be transformed. 84 | 85 | Returns: 86 | A list of transformed Documents. 
87 | """ 88 | return await asyncio.get_running_loop().run_in_executor( 89 | None, partial(self.transform_documents, **kwargs), documents 90 | ) 91 | -------------------------------------------------------------------------------- /lmchain/schema/memory.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from abc import ABC, abstractmethod 4 | from typing import Any, Dict, List 5 | 6 | 7 | class BaseMemory( ABC): 8 | """Abstract base class for memory in Chains. 9 | 10 | Memory refers to state in Chains. Memory can be used to store information about 11 | past executions of a Chain and inject that information into the inputs of 12 | future executions of the Chain. For example, for conversational Chains Memory 13 | can be used to store conversations and automatically add them to future model 14 | prompts so that the model has the necessary context to respond coherently to 15 | the latest input. 16 | 17 | Example: 18 | .. code-block:: python 19 | 20 | class SimpleMemory(BaseMemory): 21 | memories: Dict[str, Any] = dict() 22 | 23 | @property 24 | def memory_variables(self) -> List[str]: 25 | return list(self.memories.keys()) 26 | 27 | def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, str]: 28 | return self.memories 29 | 30 | def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None: 31 | pass 32 | 33 | def clear(self) -> None: 34 | pass 35 | """ # noqa: E501 36 | 37 | class Config: 38 | """Configuration for this pydantic object.""" 39 | 40 | arbitrary_types_allowed = True 41 | 42 | @property 43 | @abstractmethod 44 | def memory_variables(self) -> List[str]: 45 | """The string keys this memory class will add to chain inputs.""" 46 | 47 | @abstractmethod 48 | def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]: 49 | """Return key-value pairs given the text input to the chain.""" 50 | 51 | @abstractmethod 52 | def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None: 53 | """Save the context of this chain run to memory.""" 54 | 55 | @abstractmethod 56 | def clear(self) -> None: 57 | """Clear memory contents.""" 58 | -------------------------------------------------------------------------------- /lmchain/schema/prompt.py: -------------------------------------------------------------------------------- 1 | # 这段代码定义了一个名为 PromptValue 的抽象基类,该类用于表示任何语言模型的输入。 2 | # 这个类继承自 Serializable 和 ABC(Abstract Base Class),意味着它是一个可序列化的抽象基类。 3 | 4 | 5 | # 导入 __future__ 模块中的 annotations 功能,使得在 Python 3.7 以下版本中也可以使用类型注解的延迟评估功能。 6 | from __future__ import annotations 7 | 8 | # 导入 abc 模块中的 ABC(抽象基类)和 abstractmethod(抽象方法)装饰器。 9 | from abc import ABC, abstractmethod 10 | # 导入 typing 模块中的 List 类型,用于类型注解。 11 | from typing import List 12 | 13 | # 从 lmchain.load.serializable 模块中导入 Serializable 类,用于序列化和反序列化对象。 14 | from lmchain.load.serializable import Serializable 15 | # 从 lmchain.schema.messages 模块中导入 BaseMessage 类,作为消息基类。 16 | from lmchain.schema.messages import BaseMessage 17 | 18 | 19 | # 定义一个名为 PromptValue 的抽象基类,继承自 Serializable 和 ABC。 20 | class PromptValue(Serializable, ABC): 21 | """Base abstract class for inputs to any language model. 22 | 23 | PromptValues can be converted to both LLM (pure text-generation) inputs and 24 | ChatModel inputs. 
25 | """ 26 | 27 | # 类方法,返回一个布尔值,表示这个类是否可序列化。在这个类中,始终返回 True。 28 | @classmethod 29 | def is_lc_serializable(cls) -> bool: 30 | """Return whether this class is serializable.""" 31 | return True 32 | 33 | # 抽象方法,需要子类实现。返回一个字符串,表示 prompt 的值。 34 | @abstractmethod 35 | def to_string(self) -> str: 36 | """Return prompt value as string.""" 37 | 38 | # 抽象方法,需要子类实现。返回一个 BaseMessage 对象的列表,表示 prompt。 39 | @abstractmethod 40 | def to_messages(self) -> List[BaseMessage]: 41 | """Return prompt as a list of Messages.""" 42 | -------------------------------------------------------------------------------- /lmchain/schema/runnable/__init__.py: -------------------------------------------------------------------------------- 1 | name = "schema.runnable" -------------------------------------------------------------------------------- /lmchain/schema/runnable/config.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/virgo777/lmchain/3fc4a55df07100d4fe0beef4da89e3fbc6288938/lmchain/schema/runnable/config.py -------------------------------------------------------------------------------- /lmchain/schema/schema.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from typing import Any, Literal, Sequence, Union 4 | 5 | from lmchain.load.serializable import Serializable 6 | from lmchain.schema.messages import BaseMessage 7 | 8 | 9 | class AgentAction(Serializable): 10 | """A full description of an action for an ActionAgent to execute.""" 11 | 12 | tool: str 13 | """The name of the Tool to execute.""" 14 | tool_input: Union[str, dict] 15 | """The input to pass in to the Tool.""" 16 | log: str 17 | """Additional information to log about the action. 18 | This log can be used in a few ways. First, it can be used to audit 19 | what exactly the LLM predicted to lead to this (tool, tool_input). 20 | Second, it can be used in future iterations to show the LLMs prior 21 | thoughts. This is useful when (tool, tool_input) does not contain 22 | full information about the LLM prediction (for example, any `thought` 23 | before the tool/tool_input).""" 24 | type: Literal["AgentAction"] = "AgentAction" 25 | 26 | def __init__( 27 | self, tool: str, tool_input: Union[str, dict], log: str, **kwargs: Any 28 | ): 29 | """Override init to support instantiation by position for backward compat.""" 30 | super().__init__(tool=tool, tool_input=tool_input, log=log, **kwargs) 31 | 32 | @classmethod 33 | def is_lc_serializable(cls) -> bool: 34 | """Return whether or not the class is serializable.""" 35 | return True 36 | 37 | 38 | class AgentActionMessageLog(AgentAction): 39 | message_log: Sequence[BaseMessage] 40 | """Similar to log, this can be used to pass along extra 41 | information about what exact messages were predicted by the LLM 42 | before parsing out the (tool, tool_input). This is again useful 43 | if (tool, tool_input) cannot be used to fully recreate the LLM 44 | prediction, and you need that LLM prediction (for future agent iteration). 45 | Compared to `log`, this is useful when the underlying LLM is a 46 | ChatModel (and therefore returns messages rather than a string).""" 47 | # Ignoring type because we're overriding the type from AgentAction. 48 | # And this is the correct thing to do in this case. 49 | # The type literal is used for serialization purposes. 
50 | type: Literal["AgentActionMessageLog"] = "AgentActionMessageLog" # type: ignore 51 | 52 | 53 | class AgentFinish(Serializable): 54 | """The final return value of an ActionAgent.""" 55 | 56 | return_values: dict 57 | """Dictionary of return values.""" 58 | log: str 59 | """Additional information to log about the return value. 60 | This is used to pass along the full LLM prediction, not just the parsed out 61 | return value. For example, if the full LLM prediction was 62 | `Final Answer: 2` you may want to just return `2` as a return value, but pass 63 | along the full string as a `log` (for debugging or observability purposes). 64 | """ 65 | type: Literal["AgentFinish"] = "AgentFinish" 66 | 67 | def __init__(self, return_values: dict, log: str, **kwargs: Any): 68 | """Override init to support instantiation by position for backward compat.""" 69 | super().__init__(return_values=return_values, log=log, **kwargs) 70 | 71 | @classmethod 72 | def is_lc_serializable(cls) -> bool: 73 | """Return whether or not the class is serializable.""" 74 | return True 75 | -------------------------------------------------------------------------------- /lmchain/tool_register.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | import inspect 4 | import traceback 5 | from copy import deepcopy 6 | from pprint import pformat 7 | from types import GenericAlias 8 | from typing import get_origin, Annotated 9 | 10 | _TOOL_HOOKS = {} 11 | _TOOL_DESCRIPTIONS = {} 12 | 13 | 14 | def register_tool(func: callable): 15 | tool_name = func.__name__ 16 | tool_description = inspect.getdoc(func).strip() 17 | python_params = inspect.signature(func).parameters 18 | tool_params = [] 19 | for name, param in python_params.items(): 20 | annotation = param.annotation 21 | if annotation is inspect.Parameter.empty: 22 | raise TypeError(f"Parameter `{name}` missing type annotation") 23 | if get_origin(annotation) != Annotated: 24 | raise TypeError(f"Annotation type for `{name}` must be typing.Annotated") 25 | 26 | typ, (description, required) = annotation.__origin__, annotation.__metadata__ 27 | typ: str = str(typ) if isinstance(typ, GenericAlias) else typ.__name__ 28 | if not isinstance(description, str): 29 | raise TypeError(f"Description for `{name}` must be a string") 30 | if not isinstance(required, bool): 31 | raise TypeError(f"Required for `{name}` must be a bool") 32 | 33 | tool_params.append({ 34 | "name": name, 35 | "description": description, 36 | "type": typ, 37 | "required": required 38 | }) 39 | tool_def = { 40 | "name": tool_name, 41 | "description": tool_description, 42 | "params": tool_params 43 | } 44 | 45 | # print("[registered tool] " + pformat(tool_def)) 46 | _TOOL_HOOKS[tool_name] = func 47 | _TOOL_DESCRIPTIONS[tool_name] = tool_def 48 | 49 | return func 50 | 51 | 52 | def dispatch_tool(tool_name: str, tool_params: dict) -> str: 53 | if tool_name not in _TOOL_HOOKS: 54 | return f"Tool `{tool_name}` not found. Please use a provided tool." 
55 |     tool_call = _TOOL_HOOKS[tool_name]
56 |     try:
57 |         ret = tool_call(**tool_params)
58 |     except Exception:
59 |         ret = traceback.format_exc()
60 |     return str(ret)
61 | 
62 | 
63 | def get_tools() -> dict:
64 |     return deepcopy(_TOOL_DESCRIPTIONS)
65 | 
66 | 
67 | # Tool Definitions
68 | 
69 | # @register_tool
70 | # def random_number_generator(
71 | #         seed: Annotated[int, 'The random seed used by the generator', True],
72 | #         range: Annotated[tuple[int, int], 'The range of the generated numbers', True],
73 | # ) -> int:
74 | #     """
75 | #     Generates a random number x, s.t. range[0] <= x < range[1]
76 | #     """
77 | #     if not isinstance(seed, int):
78 | #         raise TypeError("Seed must be an integer")
79 | #     if not isinstance(range, tuple):
80 | #         raise TypeError("Range must be a tuple")
81 | #     if not isinstance(range[0], int) or not isinstance(range[1], int):
82 | #         raise TypeError("Range must be a tuple of integers")
83 | #
84 | #     import random
85 | #     return random.Random(seed).randint(*range)
86 | #
87 | #
88 | # @register_tool
89 | # def get_weather(
90 | #         city_name: Annotated[str, 'The name of the city to be queried', True],
91 | # ) -> str:
92 | #     """
93 | #     Get the current weather for `city_name`
94 | #     """
95 | #
96 | #     if not isinstance(city_name, str):
97 | #         raise TypeError("City name must be a string")
98 | #
99 | #     key_selection = {
100 | #         "current_condition": ["temp_C", "FeelsLikeC", "humidity", "weatherDesc", "observation_time"],
101 | #     }
102 | #     import requests
103 | #     try:
104 | #         resp = requests.get(f"https://wttr.in/{city_name}?format=j1")
105 | #         resp.raise_for_status()
106 | #         resp = resp.json()
107 | #         ret = {k: {_v: resp[k][0][_v] for _v in v} for k, v in key_selection.items()}
108 | #     except Exception:
109 | #         import traceback
110 | #         ret = "Error encountered while fetching weather data!\n" + traceback.format_exc()
111 | #
112 | #     return str(ret)
113 | #
114 | #
115 | # @register_tool
116 | # def get_customer_weather(location: Annotated[str, "需要查询位置的名称,用中文表示的地点名称", True] = ""):
117 | #     """A hand-written weather lookup function."""
118 | #
119 | #     if location == "上海":
120 | #         return 23.0
121 | #     elif location == "南京":
122 | #         return 25.0
123 | #     else:
124 | #         return "未查询相关内容"
125 | #
126 | ##
127 | # @register_tool
128 | # def get_random_fun(location: Annotated[str, "随机参数", True] = ""):
129 | #     """A decoy function written to confuse tool selection."""
130 | #     location = location
131 | #     return "你上当啦"
132 | 
133 | 
134 | if __name__ == "__main__":
135 |     print(dispatch_tool("get_weather", {"city_name": "shanghai"}))
136 |     print(get_tools())
137 | 
--------------------------------------------------------------------------------
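To make the registration flow concrete, here is a minimal usage sketch of the module above. register_tool, get_tools, and dispatch_tool are the module's own API; the echo tool and its parameter are made up for illustration.

    from typing import Annotated

    from lmchain.tool_register import register_tool, get_tools, dispatch_tool


    @register_tool
    def echo(
        message: Annotated[str, 'The text to echo back', True],
    ) -> str:
        """Return the input text unchanged."""
        return message


    # get_tools() returns {"echo": {"name": ..., "description": ..., "params": [...]}},
    # a plain dict that can be embedded into a prompt so the model can pick a tool.
    print(get_tools()["echo"])

    # dispatch_tool looks the function up in _TOOL_HOOKS and calls it with the params.
    print(dispatch_tool("echo", {"message": "hello"}))  # -> "hello"

Note that register_tool requires every parameter to be typing.Annotated with a (description, required) pair, which is exactly what the validation loop above enforces.

--------------------------------------------------------------------------------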
annotation") 21 | if get_origin(annotation) != Annotated: 22 | raise TypeError(f"Annotation type for `{name}` must be typing.Annotated") 23 | 24 | typ, (description, required) = annotation.__origin__, annotation.__metadata__ 25 | typ: str = str(typ) if isinstance(typ, GenericAlias) else typ.__name__ 26 | if not isinstance(description, str): 27 | raise TypeError(f"Description for `{name}` must be a string") 28 | if not isinstance(required, bool): 29 | raise TypeError(f"Required for `{name}` must be a bool") 30 | 31 | tool_params.append({ 32 | "name": name, 33 | "description": description, 34 | "type": typ, 35 | "required": required 36 | }) 37 | tool_def = { 38 | "name": tool_name, 39 | "description": tool_description, 40 | "params": tool_params 41 | } 42 | 43 | # print("[registered tool] " + pformat(tool_def)) 44 | _TOOL_HOOKS[tool_name] = func 45 | _TOOL_DESCRIPTIONS[tool_name] = tool_def 46 | 47 | return func 48 | 49 | 50 | def dispatch_tool(tool_name: str, tool_params: dict) -> str: 51 | if tool_name not in _TOOL_HOOKS: 52 | return f"Tool `{tool_name}` not found. Please use a provided tool." 53 | tool_call = _TOOL_HOOKS[tool_name] 54 | try: 55 | ret = tool_call(**tool_params) 56 | except: 57 | ret = traceback.format_exc() 58 | return str(ret) 59 | 60 | 61 | def get_tools() -> dict: 62 | return deepcopy(_TOOL_DESCRIPTIONS) 63 | 64 | 65 | # Tool Definitions 66 | 67 | @register_tool 68 | def random_number_generator( 69 | seed: Annotated[int, 'The random seed used by the generator', True], 70 | range: Annotated[tuple[int, int], 'The range of the generated numbers', True], 71 | ) -> int: 72 | """ 73 | Generates a random number x, s.t. range[0] <= x < range[1] 74 | """ 75 | if not isinstance(seed, int): 76 | raise TypeError("Seed must be an integer") 77 | if not isinstance(range, tuple): 78 | raise TypeError("Range must be a tuple") 79 | if not isinstance(range[0], int) or not isinstance(range[1], int): 80 | raise TypeError("Range must be a tuple of integers") 81 | 82 | import random 83 | return random.Random(seed).randint(*range) 84 | 85 | 86 | @register_tool 87 | def get_weather( 88 | city_name: Annotated[str, 'The name of the city to be queried', True], 89 | ) -> str: 90 | """ 91 | Get the current weather for `city_name` 92 | """ 93 | 94 | if not isinstance(city_name, str): 95 | raise TypeError("City name must be a string") 96 | 97 | key_selection = { 98 | "current_condition": ["temp_C", "FeelsLikeC", "humidity", "weatherDesc", "observation_time"], 99 | } 100 | import requests 101 | try: 102 | resp = requests.get(f"https://wttr.in/{city_name}?format=j1") 103 | resp.raise_for_status() 104 | resp = resp.json() 105 | ret = {k: {_v: resp[k][0][_v] for _v in v} for k, v in key_selection.items()} 106 | except: 107 | import traceback 108 | ret = "Error encountered while fetching weather data!\n" + traceback.format_exc() 109 | 110 | return str(ret) 111 | 112 | 113 | if __name__ == "__main__": 114 | # print(dispatch_tool("get_weather", {"city_name": "beijing"})) 115 | tools = (get_tools()) 116 | import zhipuai as zhipuai 117 | 118 | zhipuai.api_key = "1f565e40af1198e11ff1fd8a5b42771d.SjNfezc40YFsz2KC" # 控制台中获取的 APIKey 信息 119 | 120 | query = "今天shanghai的天气是什么?" 
121 |     prompt = f"""
122 |     你现在是一个专业的人工智能助手,你现在的需求是{query}。而你需要借助于工具在{tools}中找到对应的函数,用json格式返回对应的函数名和需要的参数。
123 | 
124 |     只返回json格式的函数名和需要的参数,不要做描述。
125 | 
126 |     如果没有找到合适的函数,则返回:'未找到合适参数,请提供更详细的描述。'
127 |     """
128 | 
129 |     from lmchain.agents import llmMultiAgent
130 | 
131 |     llm = llmMultiAgent.AgentZhipuAI()
132 |     res = llm(prompt)
133 |     print(res)
134 | 
135 |     import json
136 |     # The model returns a JSON string whose payload is itself JSON-encoded, so it is decoded twice.
137 |     res_dict = json.loads(res)
138 |     res_dict = json.loads(res_dict)
139 | 
140 |     print(dispatch_tool(tool_name=res_dict["function_name"], tool_params=res_dict["params"]))
141 | 
142 | 
143 | 
144 | 
--------------------------------------------------------------------------------
/lmchain/utils/__init__.py:
--------------------------------------------------------------------------------
1 | name = "utils"
--------------------------------------------------------------------------------
/lmchain/utils/__pycache__/__init__.cpython-311.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/virgo777/lmchain/3fc4a55df07100d4fe0beef4da89e3fbc6288938/lmchain/utils/__pycache__/__init__.cpython-311.pyc
--------------------------------------------------------------------------------
/lmchain/utils/__pycache__/formatting.cpython-311.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/virgo777/lmchain/3fc4a55df07100d4fe0beef4da89e3fbc6288938/lmchain/utils/__pycache__/formatting.cpython-311.pyc
--------------------------------------------------------------------------------
/lmchain/utils/__pycache__/math.cpython-311.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/virgo777/lmchain/3fc4a55df07100d4fe0beef4da89e3fbc6288938/lmchain/utils/__pycache__/math.cpython-311.pyc
--------------------------------------------------------------------------------
/lmchain/utils/formatting.py:
--------------------------------------------------------------------------------
1 | """Utilities for formatting strings."""
2 | from string import Formatter
3 | from typing import Any, List, Mapping, Sequence, Union
4 | 
5 | 
6 | class StrictFormatter(Formatter):
7 |     """A subclass of formatter that checks for extra keys."""
8 | 
9 |     def check_unused_args(
10 |         self,
11 |         used_args: Sequence[Union[int, str]],
12 |         args: Sequence,
13 |         kwargs: Mapping[str, Any],
14 |     ) -> None:
15 |         """Check to see if extra parameters are passed."""
16 |         extra = set(kwargs).difference(used_args)
17 |         if extra:
18 |             raise KeyError(extra)
19 | 
20 |     def vformat(
21 |         self, format_string: str, args: Sequence, kwargs: Mapping[str, Any]
22 |     ) -> str:
23 |         """Check that no positional arguments are provided."""
24 |         if len(args) > 0:
25 |             raise ValueError(
26 |                 "No arguments should be provided, "
27 |                 "everything should be passed as keyword arguments."
28 |             )
29 |         return super().vformat(format_string, args, kwargs)
30 | 
31 |     def validate_input_variables(
32 |         self, format_string: str, input_variables: List[str]
33 |     ) -> None:
34 |         dummy_inputs = {input_variable: "foo" for input_variable in input_variables}
35 |         super().format(format_string, **dummy_inputs)
36 | 
37 | 
38 | formatter = StrictFormatter()
39 | 
--------------------------------------------------------------------------------
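A short sketch of how StrictFormatter differs from plain str.format; both failure modes below are raised by the overridden methods defined in the file above.

    from lmchain.utils.formatting import formatter

    # Normal keyword formatting works as usual.
    print(formatter.format("Hello, {name}!", name="world"))  # Hello, world!

    # Unused keyword arguments raise a KeyError instead of being silently ignored.
    try:
        formatter.format("Hello, {name}!", name="world", extra="oops")
    except KeyError as e:
        print("extra key rejected:", e)

    # Positional arguments are rejected outright by the vformat override.
    try:
        formatter.format("Hello, {0}!", "world")
    except ValueError as e:
        print("positional args rejected:", e)

--------------------------------------------------------------------------------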
/lmchain/utils/input.py:
--------------------------------------------------------------------------------
1 | """Handle chained inputs."""
2 | from typing import Dict, List, Optional, TextIO
3 | 
4 | _TEXT_COLOR_MAPPING = {
5 |     "blue": "36;1",
6 |     "yellow": "33;1",
7 |     "pink": "38;5;200",
8 |     "green": "32;1",
9 |     "red": "31;1",
10 | }
11 | 
12 | 
13 | def get_color_mapping(
14 |     items: List[str], excluded_colors: Optional[List] = None
15 | ) -> Dict[str, str]:
16 |     """Get mapping for items to a supported color."""
17 |     colors = list(_TEXT_COLOR_MAPPING.keys())
18 |     if excluded_colors is not None:
19 |         colors = [c for c in colors if c not in excluded_colors]
20 |     color_mapping = {item: colors[i % len(colors)] for i, item in enumerate(items)}
21 |     return color_mapping
22 | 
23 | 
24 | def get_colored_text(text: str, color: str) -> str:
25 |     """Get colored text."""
26 |     color_str = _TEXT_COLOR_MAPPING[color]
27 |     return f"\u001b[{color_str}m\033[1;3m{text}\u001b[0m"
28 | 
29 | 
30 | def get_bolded_text(text: str) -> str:
31 |     """Get bolded text."""
32 |     return f"\033[1m{text}\033[0m"
33 | 
34 | 
35 | def print_text(
36 |     text: str, color: Optional[str] = None, end: str = "", file: Optional[TextIO] = None
37 | ) -> None:
38 |     """Print text with highlighting and no end characters."""
39 |     text_to_print = get_colored_text(text, color) if color else text
40 |     print(text_to_print, end=end, file=file)
41 |     if file:
42 |         file.flush()  # ensure all printed content is written to the file
43 | 
--------------------------------------------------------------------------------
/lmchain/utils/loading.py:
--------------------------------------------------------------------------------
1 | """Utilities for loading configurations from langchain-hub."""
2 | 
3 | import os
4 | import re
5 | import tempfile
6 | from pathlib import Path, PurePosixPath
7 | from typing import Any, Callable, Optional, Set, TypeVar, Union
8 | from urllib.parse import urljoin
9 | 
10 | import requests
11 | 
12 | DEFAULT_REF = os.environ.get("LANGCHAIN_HUB_DEFAULT_REF", "master")
13 | URL_BASE = os.environ.get(
14 |     "LANGCHAIN_HUB_URL_BASE",
15 |     "https://raw.githubusercontent.com/hwchase17/langchain-hub/{ref}/",
16 | )
17 | HUB_PATH_RE = re.compile(r"lc(?P<ref>@[^:]+)?://(?P<path>.*)")
18 | 
19 | T = TypeVar("T")
20 | 
21 | 
22 | def try_load_from_hub(
23 |     path: Union[str, Path],
24 |     loader: Callable[[str], T],
25 |     valid_prefix: str,
26 |     valid_suffixes: Set[str],
27 |     **kwargs: Any,
28 | ) -> Optional[T]:
29 |     """Load configuration from hub. Returns None if path is not a hub path."""
30 |     if not isinstance(path, str) or not (match := HUB_PATH_RE.match(path)):
31 |         return None
32 |     ref, remote_path_str = match.groups()
33 |     ref = ref[1:] if ref else DEFAULT_REF
34 |     remote_path = Path(remote_path_str)
35 |     if remote_path.parts[0] != valid_prefix:
36 |         return None
37 |     if remote_path.suffix[1:] not in valid_suffixes:
38 |         raise ValueError(f"Unsupported file type, must be one of {valid_suffixes}.")
39 | 
40 |     # Using Path with URLs is not recommended, because on Windows
41 |     # the backslash is used as the path separator, which can cause issues
42 |     # when working with URLs that use forward slashes as the path separator.
43 |     # Instead, use PurePosixPath to ensure that forward slashes are used as the
44 |     # path separator, regardless of the operating system.
45 |     full_url = urljoin(URL_BASE.format(ref=ref), PurePosixPath(remote_path).__str__())
46 | 
47 |     r = requests.get(full_url, timeout=5)
48 |     if r.status_code != 200:
49 |         raise ValueError(f"Could not find file at {full_url}")
50 |     with tempfile.TemporaryDirectory() as tmpdirname:
51 |         file = Path(tmpdirname) / remote_path.name
52 |         with open(file, "wb") as f:
53 |             f.write(r.content)
54 |         return loader(str(file), **kwargs)
55 | 
--------------------------------------------------------------------------------
/lmchain/utils/math.py:
--------------------------------------------------------------------------------
1 | """Math utils."""
2 | import logging
3 | from typing import List, Optional, Tuple, Union
4 | 
5 | import numpy as np
6 | 
7 | logger = logging.getLogger(__name__)
8 | 
9 | Matrix = Union[List[List[float]], List[np.ndarray], np.ndarray]
10 | 
11 | 
12 | def cosine_similarity(X: Matrix, Y: Matrix) -> np.ndarray:
13 |     """Row-wise cosine similarity between two equal-width matrices."""
14 |     if len(X) == 0 or len(Y) == 0:
15 |         return np.array([])
16 | 
17 |     X = np.array(X)
18 |     Y = np.array(Y)
19 |     if X.shape[1] != Y.shape[1]:
20 |         raise ValueError(
21 |             f"Number of columns in X and Y must be the same. X has shape {X.shape} "
22 |             f"and Y has shape {Y.shape}."
23 |         )
24 |     try:
25 |         import simsimd as simd
26 | 
27 |         X = np.array(X, dtype=np.float32)
28 |         Y = np.array(Y, dtype=np.float32)
29 |         Z = 1 - simd.cdist(X, Y, metric="cosine")
30 |         if isinstance(Z, float):
31 |             return np.array([Z])
32 |         return Z
33 |     except ImportError:
34 |         logger.info(
35 |             "Unable to import simsimd, defaulting to NumPy implementation. If you want "
36 |             "to use simsimd please install with `pip install simsimd`."
37 |         )
38 |         X_norm = np.linalg.norm(X, axis=1)
39 |         Y_norm = np.linalg.norm(Y, axis=1)
40 |         # Ignore divide-by-zero and invalid-value runtime warnings; those cases are handled below.
41 |         with np.errstate(divide="ignore", invalid="ignore"):
42 |             similarity = np.dot(X, Y.T) / np.outer(X_norm, Y_norm)
43 |         similarity[np.isnan(similarity) | np.isinf(similarity)] = 0.0
44 |         return similarity
45 | 
46 | 
47 | def cosine_similarity_top_k(
48 |     X: Matrix,
49 |     Y: Matrix,
50 |     top_k: Optional[int] = 5,
51 |     score_threshold: Optional[float] = None,
52 | ) -> Tuple[List[Tuple[int, int]], List[float]]:
53 |     """Row-wise cosine similarity with optional top-k and score threshold filtering.
54 | 
55 |     Args:
56 |         X: Matrix.
57 |         Y: Matrix, same width as X.
58 |         top_k: Max number of results to return.
59 |         score_threshold: Minimum cosine similarity of results.
60 | 
61 |     Returns:
62 |         Tuple of two lists. First contains two-tuples of indices (X_idx, Y_idx),
63 |         second contains corresponding cosine similarities.
64 |     """
65 |     if len(X) == 0 or len(Y) == 0:
66 |         return [], []
67 |     score_array = cosine_similarity(X, Y)
68 |     score_threshold = score_threshold or -1.0
69 |     score_array[score_array < score_threshold] = 0
70 |     top_k = min(top_k or len(score_array), np.count_nonzero(score_array))
71 |     top_k_idxs = np.argpartition(score_array, -top_k, axis=None)[-top_k:]
72 |     top_k_idxs = top_k_idxs[np.argsort(score_array.ravel()[top_k_idxs])][::-1]
73 |     ret_idxs = np.unravel_index(top_k_idxs, score_array.shape)
74 |     scores = score_array.ravel()[top_k_idxs].tolist()
75 |     return list(zip(*ret_idxs)), scores  # type: ignore
--------------------------------------------------------------------------------
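As a quick sanity check of the top-k helper, here is a minimal sketch with toy vectors (the numbers are illustrative only):

    import numpy as np

    from lmchain.utils.math import cosine_similarity_top_k

    # Two query vectors against three candidate vectors.
    X = np.array([[1.0, 0.0], [0.0, 1.0]])
    Y = np.array([[1.0, 0.0], [0.7, 0.7], [0.0, 1.0]])

    # Keep the 2 best (X_idx, Y_idx) pairs with similarity >= 0.9;
    # the 0.707 cross-similarities are zeroed out by the threshold.
    idxs, scores = cosine_similarity_top_k(X, Y, top_k=2, score_threshold=0.9)
    print(idxs)    # e.g. [(0, 0), (1, 2)]: each query matches its identical candidate
    print(scores)  # [1.0, 1.0] (up to floating-point error)

--------------------------------------------------------------------------------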
64 | """ 65 | if len(X) == 0 or len(Y) == 0: 66 | return [], [] 67 | score_array = cosine_similarity(X, Y) 68 | score_threshold = score_threshold or -1.0 69 | score_array[score_array < score_threshold] = 0 70 | top_k = min(top_k or len(score_array), np.count_nonzero(score_array)) 71 | top_k_idxs = np.argpartition(score_array, -top_k, axis=None)[-top_k:] 72 | top_k_idxs = top_k_idxs[np.argsort(score_array.ravel()[top_k_idxs])][::-1] 73 | ret_idxs = np.unravel_index(top_k_idxs, score_array.shape) 74 | scores = score_array.ravel()[top_k_idxs].tolist() 75 | return list(zip(*ret_idxs)), scores # type: ignore 76 | -------------------------------------------------------------------------------- /lmchain/vectorstores/__init__.py: -------------------------------------------------------------------------------- 1 | name = "vectorstores" -------------------------------------------------------------------------------- /lmchain/vectorstores/__pycache__/vectorstore.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/virgo777/lmchain/3fc4a55df07100d4fe0beef4da89e3fbc6288938/lmchain/vectorstores/__pycache__/vectorstore.cpython-311.pyc -------------------------------------------------------------------------------- /lmchain/vectorstores/chroma.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | from langchain.docstore.document import Document 4 | from langchain.text_splitter import RecursiveCharacterTextSplitter 5 | from lmchain.embeddings import embeddings 6 | from lmchain.vectorstores import laiss 7 | 8 | from langchain.memory import ConversationBufferMemory 9 | from langchain.prompts import ( 10 | ChatPromptTemplate, # 用于构建聊天模板的类 11 | MessagesPlaceholder, # 用于在模板中插入消息占位的类 12 | SystemMessagePromptTemplate, # 用于构建系统消息模板的类 13 | HumanMessagePromptTemplate # 用于构建人类消息模板的类 14 | ) 15 | from langchain.chains import ConversationChain 16 | 17 | class Chroma: 18 | def __init__(self,documents,embedding_tool,chunk_size = 1280,chunk_overlap = 50,source = "这是一份辅助材料"): 19 | """ 20 | :param document: 输入的文本内容,只要一个text文本 21 | :param chunk_size: 切分后每段的字数 22 | :param chunk_overlap: 每个相隔段落重叠的字数 23 | :param source: 文本名称/文本地址 24 | """ 25 | self.text_splitter = RecursiveCharacterTextSplitter( chunk_size=chunk_size, chunk_overlap=chunk_overlap) 26 | self.embedding_tool = embedding_tool 27 | 28 | self.lmaiss = laiss.LMASS() #这里是用于将文本转化为vector,并且计算query相应的相似度的类 29 | 30 | self.documents = [] 31 | self.vectorstores = [] 32 | 33 | "---------------------------" 34 | for document in documents: 35 | document = [Document(page_content=document, metadata={"source": source})] #对输入的document进行格式化处理 36 | doc= self.text_splitter.split_documents(document) #根据 37 | self.documents.extend(doc) 38 | 39 | vector = self.lmaiss.from_documents(document, embedding_class=self.embedding_tool) 40 | self.vectorstores.extend(vector) 41 | 42 | # def __call__(self, query): 43 | # query_embedding = self.embedding_tool.embed_query(query) 44 | # 45 | # #根据query查找最近的那个序列 46 | # close_id = self.lmaiss.get_similarity_vector_indexs(query_embedding, self.vectorstore, k=1)[0] 47 | # #查找最近的那个段落id 48 | # doc = self.documents[close_id] 49 | # 50 | # 51 | # return doc 52 | 53 | def similarity_search(self, query): 54 | query_embedding = self.embedding_tool.embed_query(query) 55 | 56 | #根据query查找最近的那个序列 57 | close_id = self.lmaiss.get_similarity_vector_indexs(query_embedding, self.vectorstores, k=1)[0] 58 | #查找最近的那个段落id 59 | doc = 
/lmchain/vectorstores/embeddings.py:
--------------------------------------------------------------------------------
1 | import warnings
2 | warnings.filterwarnings("ignore")
3 | 
4 | import asyncio
5 | from abc import ABC, abstractmethod
6 | from typing import List
7 | 
8 | 
9 | class Embeddings(ABC):
10 |     """Interface for embedding models."""
11 | 
12 |     @abstractmethod
13 |     def embed_documents(self, texts: List[str]) -> List[List[float]]:
14 |         """Embed search docs."""
15 | 
16 |     @abstractmethod
17 |     def embed_query(self, text: str) -> List[float]:
18 |         """Embed query text."""
19 | 
20 |     async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
21 |         """Asynchronously embed search docs."""
22 |         return await asyncio.get_running_loop().run_in_executor(
23 |             None, self.embed_documents, texts
24 |         )
25 | 
26 |     async def aembed_query(self, text: str) -> List[float]:
27 |         """Asynchronously embed query text."""
28 |         return await asyncio.get_running_loop().run_in_executor(
29 |             None, self.embed_query, text
30 |         )
31 | 
32 | 
33 | # class LMEmbedding(Embeddings):
34 | #     from modelscope.pipelines import pipeline
35 | #     from modelscope.utils.constant import Tasks
36 | #     pipeline_se = pipeline(Tasks.sentence_embedding, model='thomas/text2vec-base-chinese', model_revision='v1.0.0',
37 | #                            device="cuda")
38 | #
39 | #     def _construct_inputs(self, texts):
40 | #         inputs = {
41 | #             "source_sentence": texts
42 | #         }
43 | #
44 | #         return inputs
45 | #
46 | #     def embed_documents(self, texts: List[str]) -> List[List[float]]:
47 | #         """Embed search docs."""
48 | #
49 | #         inputs = self._construct_inputs(texts)
50 | #         result_embeddings = self.pipeline_se(input=inputs)
51 | #         return result_embeddings["text_embedding"]
52 | #
53 | #     def embed_query(self, text: str) -> List[float]:
54 | #         """Embed query text."""
55 | #         inputs = self._construct_inputs([text])
56 | #         result_embeddings = self.pipeline_se(input=inputs)
57 | #         return result_embeddings["text_embedding"]
58 | 
59 | 
60 | class GLMEmbedding(Embeddings):
61 |     import zhipuai as zhipuai
62 |     zhipuai.api_key = "1f565e40af1198e11ff1fd8a5b42771d.SjNfezc40YFsz2KC"  # APIKey obtained from the console
63 | 
64 |     def _construct_inputs(self, texts):
65 |         inputs = {
66 |             "source_sentence": texts
67 |         }
68 | 
69 |         return inputs
70 | 
71 |     aembeddings = []  # stores the embedding values collected while fetching embeddings concurrently
72 |     atexts = []
73 | 
74 |     def embed_documents(self, texts: List[str]) -> List[List[float]]:
75 |         """Embed search docs."""
76 |         result_embeddings = []
77 |         for text in texts:
78 |             embedding = self.embed_query(text)
79 |             result_embeddings.append(embedding)
80 |         return result_embeddings
81 | 
82 |     def embed_query(self, text: str) -> List[float]:
83 |         """Embed query text."""
text.""" 84 | result_embeddings = self.zhipuai.model_api.invoke( 85 | model="text_embedding", prompt=text) 86 | return result_embeddings["data"]["embedding"] 87 | 88 | def aembed_query(self, text: str) -> List[float]: 89 | """Embed query text.""" 90 | result_embeddings = self.zhipuai.model_api.invoke( 91 | model="text_embedding", prompt=text) 92 | emb = result_embeddings["data"]["embedding"] 93 | 94 | self.aembeddings.append(emb) 95 | self.atexts.append(text) 96 | 97 | # 这里实现了并发embedding获取 98 | def aembed_documents(self, texts: List[str], thread_num=5, wait_sec=0.3) -> List[List[float]]: 99 | import threading 100 | text_length = len(texts) 101 | thread_batch = text_length // thread_num 102 | 103 | for i in range(thread_batch): 104 | start = i * thread_num 105 | end = (i + 1) * thread_num 106 | 107 | # 创建线程列表 108 | threads = [] 109 | # 创建并启动5个线程,每个线程调用一个模型 110 | for text in texts[start:end]: 111 | thread = threading.Thread(target=self.aembed_query, args=(text,)) 112 | thread.start() 113 | threads.append(thread) 114 | for thread in threads: 115 | thread.join(wait_sec) # 设置超时时间为0.3秒 116 | return self.aembeddings, self.atexts 117 | 118 | 119 | if __name__ == '__main__': 120 | import time 121 | 122 | inputs = ["不可以,早晨喝牛奶不科学", "今天早晨喝牛奶不科学", "早晨喝牛奶不科学"] * 50 123 | 124 | start_time = time.time() 125 | aembeddings = (GLMEmbedding().aembed_documents(inputs, thread_num=5, thread_sec=0.3)) 126 | print(aembeddings) 127 | print(len(aembeddings)) 128 | end_time = time.time() 129 | # 计算函数执行时间并打印结果 130 | execution_time = end_time - start_time 131 | print(f"函数执行时间: {execution_time} 秒") 132 | print("----------------------------------------------------------------------------------") 133 | start_time = time.time() 134 | aembeddings = (GLMEmbedding().embed_documents(inputs)) 135 | print(len(aembeddings)) 136 | end_time = time.time() 137 | # 计算函数执行时间并打印结果 138 | execution_time = end_time - start_time 139 | print(f"函数执行时间: {execution_time} 秒") 140 | -------------------------------------------------------------------------------- /lmchain/vectorstores/utils.py: -------------------------------------------------------------------------------- 1 | """Utility functions for working with vectors and vectorstores.""" 2 | 3 | from enum import Enum 4 | from typing import List, Tuple, Type 5 | 6 | import numpy as np 7 | 8 | from lmchain.schema.document import Document 9 | from lmchain.utils.math import cosine_similarity 10 | 11 | class DistanceStrategy(str, Enum): 12 | """Enumerator of the Distance strategies for calculating distances 13 | between vectors.""" 14 | 15 | EUCLIDEAN_DISTANCE = "EUCLIDEAN_DISTANCE" 16 | MAX_INNER_PRODUCT = "MAX_INNER_PRODUCT" 17 | DOT_PRODUCT = "DOT_PRODUCT" 18 | JACCARD = "JACCARD" 19 | COSINE = "COSINE" 20 | 21 | 22 | def maximal_marginal_relevance( 23 | query_embedding: np.ndarray, 24 | embedding_list: list, 25 | lambda_mult: float = 0.5, 26 | k: int = 4, 27 | ) -> List[int]: 28 | """Calculate maximal marginal relevance.""" 29 | if min(k, len(embedding_list)) <= 0: 30 | return [] 31 | if query_embedding.ndim == 1: 32 | query_embedding = np.expand_dims(query_embedding, axis=0) 33 | similarity_to_query = cosine_similarity(query_embedding, embedding_list)[0] 34 | most_similar = int(np.argmax(similarity_to_query)) 35 | idxs = [most_similar] 36 | selected = np.array([embedding_list[most_similar]]) 37 | while len(idxs) < min(k, len(embedding_list)): 38 | best_score = -np.inf 39 | idx_to_add = -1 40 | similarity_to_selected = cosine_similarity(embedding_list, selected) 41 | for i, query_score in 
41 |         for i, query_score in enumerate(similarity_to_query):
42 |             if i in idxs:
43 |                 continue
44 |             redundant_score = max(similarity_to_selected[i])
45 |             equation_score = (
46 |                 lambda_mult * query_score - (1 - lambda_mult) * redundant_score
47 |             )
48 |             if equation_score > best_score:
49 |                 best_score = equation_score
50 |                 idx_to_add = i
51 |         idxs.append(idx_to_add)
52 |         selected = np.append(selected, [embedding_list[idx_to_add]], axis=0)
53 |     return idxs
54 | 
55 | 
56 | def filter_complex_metadata(
57 |     documents: List[Document],
58 |     *,
59 |     allowed_types: Tuple[Type, ...] = (str, bool, int, float),
60 | ) -> List[Document]:
61 |     """Filter out metadata types that are not supported for a vector store."""
62 |     updated_documents = []
63 |     for document in documents:
64 |         filtered_metadata = {}
65 |         for key, value in document.metadata.items():
66 |             if not isinstance(value, allowed_types):
67 |                 continue
68 |             filtered_metadata[key] = value
69 | 
70 |         document.metadata = filtered_metadata
71 |         updated_documents.append(document)
72 | 
73 |     return updated_documents
74 | 
--------------------------------------------------------------------------------
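The MMR helper above trades relevance against diversity through lambda_mult; a small sketch with toy embeddings (the vectors are illustrative only):

    import numpy as np

    from lmchain.vectorstores.utils import maximal_marginal_relevance

    query = np.array([1.0, 0.0])
    candidates = [
        np.array([0.9, 0.1]),   # most relevant to the query
        np.array([0.9, 0.1]),   # exact duplicate of the first
        np.array([1.0, -0.2]),  # slightly less relevant, but adds diversity
    ]

    # lambda_mult=1.0 would rank purely by relevance; at 0.5 the redundancy
    # penalty makes the exact duplicate lose to the more diverse candidate.
    print(maximal_marginal_relevance(query, candidates, lambda_mult=0.5, k=2))  # [0, 2]

--------------------------------------------------------------------------------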
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [project]
2 | name = "LMchain"
3 | version = "0.1.62"
4 | authors = [
5 |   { name="xiaohuaWang", email="5847713@qq.com" },
6 | ]
7 | description = "A toolkit for large language model chains"
8 | readme = "README.md"
9 | requires-python = ">=3.10"
10 | classifiers = [
11 |     "Programming Language :: Python :: 3",
12 |     "License :: OSI Approved :: MIT License",
13 |     "Operating System :: OS Independent",
14 | ]
15 | 
16 | [project.urls]
17 | "Homepage" = "https://github.com/pypa/sampleproject"
18 | "Bug Tracker" = "https://github.com/pypa/sampleproject/issues"
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | import setuptools
2 | with open("README.md", "r") as fh:
3 |     long_description = fh.read()
4 | setuptools.setup(
5 |     name="lmchain",  # package name
6 |     version="0.1.62",  # current version
7 |     author="xiaohuaWang",  # author
8 |     author_email="5847713@qq.com",  # author email
9 |     description="LMchain是一个专门适配大模型chain的工具包",  # short description
10 |     long_description=long_description,  # long description
11 |     long_description_content_type="text/markdown",  # long description format
12 |     # url="https://github.com/",  # GitHub URL of the package
13 |     packages=setuptools.find_packages(),  # automatically discover the project's packages
14 |     include_package_data=True,
15 |     # package metadata
16 |     classifiers=[
17 |         "Programming Language :: Python :: 3",
18 |         "License :: OSI Approved :: MIT License",
19 |         "Operating System :: OS Independent",
20 |     ],
21 |     # dependencies
22 |     install_requires=[
23 |         'uvicorn', 'fastapi', 'typing', "numexpr", "langchain", "zhipuai", "nltk"
24 |     ],
25 |     python_requires='>=3',
26 | )
--------------------------------------------------------------------------------
/tool_register.py:
--------------------------------------------------------------------------------
1 | import inspect
2 | import traceback
3 | from copy import deepcopy
4 | from pprint import pformat
5 | from types import GenericAlias
6 | from typing import get_origin, Annotated
7 | 
8 | _TOOL_HOOKS = {}
9 | _TOOL_DESCRIPTIONS = {}
10 | 
11 | 
12 | def register_tool(func: callable):
13 |     tool_name = func.__name__
14 |     tool_description = inspect.getdoc(func).strip()
15 |     python_params = inspect.signature(func).parameters
16 |     tool_params = []
17 |     for name, param in python_params.items():
18 |         annotation = param.annotation
19 |         if annotation is inspect.Parameter.empty:
20 |             raise TypeError(f"Parameter `{name}` missing type annotation")
21 |         if get_origin(annotation) != Annotated:
22 |             raise TypeError(f"Annotation type for `{name}` must be typing.Annotated")
23 | 
24 |         typ, (description, required) = annotation.__origin__, annotation.__metadata__
25 |         typ: str = str(typ) if isinstance(typ, GenericAlias) else typ.__name__
26 |         if not isinstance(description, str):
27 |             raise TypeError(f"Description for `{name}` must be a string")
28 |         if not isinstance(required, bool):
29 |             raise TypeError(f"Required for `{name}` must be a bool")
30 | 
31 |         tool_params.append({
32 |             "name": name,
33 |             "description": description,
34 |             "type": typ,
35 |             "required": required
36 |         })
37 |     tool_def = {
38 |         "name": tool_name,
39 |         "description": tool_description,
40 |         "params": tool_params
41 |     }
42 | 
43 |     # print("[registered tool] " + pformat(tool_def))
44 |     _TOOL_HOOKS[tool_name] = func
45 |     _TOOL_DESCRIPTIONS[tool_name] = tool_def
46 | 
47 |     return func
48 | 
49 | 
50 | def dispatch_tool(tool_name: str, tool_params: dict) -> str:
51 |     if tool_name not in _TOOL_HOOKS:
52 |         return f"Tool `{tool_name}` not found. Please use a provided tool."
53 |     tool_call = _TOOL_HOOKS[tool_name]
54 |     try:
55 |         ret = tool_call(**tool_params)
56 |     except Exception:
57 |         ret = traceback.format_exc()
58 |     return str(ret)
59 | 
60 | 
61 | def get_tools() -> dict:
62 |     return deepcopy(_TOOL_DESCRIPTIONS)
63 | 
64 | 
65 | # Tool Definitions
66 | 
67 | # @register_tool
68 | # def random_number_generator(
69 | #         seed: Annotated[int, 'The random seed used by the generator', True],
70 | #         range: Annotated[tuple[int, int], 'The range of the generated numbers', True],
71 | # ) -> int:
72 | #     """
73 | #     Generates a random number x, s.t. range[0] <= x < range[1]
74 | #     """
75 | #     if not isinstance(seed, int):
76 | #         raise TypeError("Seed must be an integer")
77 | #     if not isinstance(range, tuple):
78 | #         raise TypeError("Range must be a tuple")
79 | #     if not isinstance(range[0], int) or not isinstance(range[1], int):
80 | #         raise TypeError("Range must be a tuple of integers")
81 | #
82 | #     import random
83 | #     return random.Random(seed).randint(*range)
84 | #
85 | #
86 | # @register_tool
87 | # def get_weather(
88 | #         city_name: Annotated[str, 'The name of the city to be queried', True],
89 | # ) -> str:
90 | #     """
91 | #     Get the current weather for `city_name`
92 | #     """
93 | #
94 | #     if not isinstance(city_name, str):
95 | #         raise TypeError("City name must be a string")
96 | #
97 | #     key_selection = {
98 | #         "current_condition": ["temp_C", "FeelsLikeC", "humidity", "weatherDesc", "observation_time"],
99 | #     }
100 | #     import requests
101 | #     try:
102 | #         resp = requests.get(f"https://wttr.in/{city_name}?format=j1")
103 | #         resp.raise_for_status()
104 | #         resp = resp.json()
105 | #         ret = {k: {_v: resp[k][0][_v] for _v in v} for k, v in key_selection.items()}
106 | #     except Exception:
107 | #         import traceback
108 | #         ret = "Error encountered while fetching weather data!\n" + traceback.format_exc()
109 | #
110 | #     return str(ret)
111 | #
112 | #
113 | # @register_tool
114 | # def get_customer_weather(location: Annotated[str, "需要查询位置的名称,用中文表示的地点名称", True] = ""):
115 | #     """A hand-written weather lookup function."""
116 | #
117 | #     if location == "上海":
118 | #         return 23.0
119 | #     elif location == "南京":
120 | #         return 25.0
121 | #     else:
122 | #         return "未查询相关内容"
123 | #
124 | #
125 | # @register_tool
126 | # def get_random_fun(location: Annotated[str, "随机参数", True] = ""):
127 | #     """A decoy function written to confuse tool selection."""
128 | #     location = location
# return "你上当啦" 130 | 131 | 132 | if __name__ == "__main__": 133 | print(dispatch_tool("get_weather", {"city_name": "shanghai"})) 134 | print(get_tools()) 135 | -------------------------------------------------------------------------------- /upload: -------------------------------------------------------------------------------- 1 | upload file 2 | 3 | 4 | 5 | --------------------------------------------------------------------------------