├── .idea
│   ├── .gitignore
│   ├── inspectionProfiles
│   │   └── profiles_settings.xml
│   ├── lmchain.iml
│   ├── misc.xml
│   └── modules.xml
├── GLM3_version_REAM.md
├── LICENSE
├── LMchain.egg-info
│   ├── PKG-INFO
│   ├── SOURCES.txt
│   ├── dependency_links.txt
│   └── top_level.txt
├── MANIFEST.in
├── README.md
├── __pycache__
│   └── tool_register.cpython-311.pyc
├── build
│   └── lib
│       └── lmchain
│           ├── __init__.py
│           ├── agents
│           │   ├── __init__.py
│           │   ├── llmAgent.py
│           │   ├── llmMultiActionAgent.py
│           │   └── llmMultiAgent.py
│           ├── callbacks
│           │   ├── __init__.py
│           │   ├── base.py
│           │   ├── manager.py
│           │   └── stdout.py
│           ├── chains
│           │   ├── __init__.py
│           │   ├── base.py
│           │   ├── cmd.py
│           │   ├── conversationalRetrievalChain.py
│           │   ├── mathChain.py
│           │   ├── question_answering.py
│           │   ├── toolchain.py
│           │   └── urlRequestChain.py
│           ├── embeddings
│           │   ├── __init__.py
│           │   └── embeddings.py
│           ├── hello.py
│           ├── index
│           │   ├── __init__.py
│           │   └── indexChain.py
│           ├── llms
│           │   ├── __init__.py
│           │   └── base.py
│           ├── load
│           │   ├── __init__.py
│           │   └── serializable.py
│           ├── memory
│           │   ├── __init__.py
│           │   ├── chat_memory.py
│           │   ├── messageHistory.py
│           │   └── utils.py
│           ├── model
│           │   ├── __init__.py
│           │   └── language_model.py
│           ├── prompts
│           │   ├── __init__.py
│           │   ├── base.py
│           │   ├── chat.py
│           │   ├── example_selectors.py
│           │   ├── few_shot_templates.py
│           │   ├── loading.py
│           │   ├── prompt.py
│           │   ├── templates.py
│           │   └── tool_templates.py
│           ├── schema
│           │   ├── __init__.py
│           │   ├── agent.py
│           │   ├── document.py
│           │   ├── language_model.py
│           │   ├── memory.py
│           │   ├── messages.py
│           │   ├── output.py
│           │   ├── output_parser.py
│           │   ├── prompt.py
│           │   ├── prompt_template.py
│           │   ├── runnable.py
│           │   ├── runnable
│           │   │   ├── __init__.py
│           │   │   ├── base.py
│           │   │   ├── branch.py
│           │   │   ├── config.py
│           │   │   ├── configurable.py
│           │   │   ├── fallbacks.py
│           │   │   ├── passthrough.py
│           │   │   ├── retry.py
│           │   │   ├── router.py
│           │   │   └── utils.py
│           │   ├── runnable_utils.py
│           │   └── schema.py
│           ├── tool_register.py
│           ├── tools
│           │   ├── __init__.py
│           │   └── tool_register.py
│           ├── utils
│           │   ├── __init__.py
│           │   ├── formatting.py
│           │   ├── input.py
│           │   ├── loading.py
│           │   └── math.py
│           └── vectorstores
│               ├── __init__.py
│               ├── chroma.py
│               ├── embeddings.py
│               ├── laiss.py
│               ├── utils.py
│               └── vectorstore.py
├── dist
│   ├── LMchain-0.1.60-py3-none-any.whl
│   ├── LMchain-0.1.60.tar.gz
│   ├── LMchain-0.1.61-py3-none-any.whl
│   ├── LMchain-0.1.61.tar.gz
│   ├── LMchain-0.1.62-py3-none-any.whl
│   └── LMchain-0.1.62.tar.gz
├── lmchain
│   ├── __init__.py
│   ├── __pycache__
│   │   └── __init__.cpython-311.pyc
│   ├── agents
│   │   ├── __init__.py
│   │   ├── __pycache__
│   │   │   ├── __init__.cpython-311.pyc
│   │   │   ├── llmAgent.cpython-311.pyc
│   │   │   └── llmMultiAgent.cpython-311.pyc
│   │   ├── llmAgent.py
│   │   └── llmMultiAgent.py
│   ├── callbacks
│   │   ├── __init__.py
│   │   ├── base.py
│   │   ├── manager.py
│   │   └── stdout.py
│   ├── chains
│   │   ├── __init__.py
│   │   ├── __pycache__
│   │   │   ├── __init__.cpython-311.pyc
│   │   │   ├── cmd.cpython-311.pyc
│   │   │   ├── mathChain.cpython-311.pyc
│   │   │   └── urlRequestChain.cpython-311.pyc
│   │   ├── base.py
│   │   ├── cmd.py
│   │   ├── conversationalRetrievalChain.py
│   │   ├── mathChain.py
│   │   ├── question_answering.py
│   │   ├── subQuestChain.py
│   │   ├── toolchain.py
│   │   └── urlRequestChain.py
│   ├── embeddings
│   │   └── __init__.py
│   ├── index
│   │   ├── __init__.py
│   │   └── indexChain.py
│   ├── llms
│   │   ├── __init__.py
│   │   └── base.py
│   ├── load
│   │   ├── __init__.py
│   │   ├── __pycache__
│   │   │   ├── __init__.cpython-311.pyc
│   │   │   └── serializable.cpython-311.pyc
│   │   └── serializable.py
│   ├── memory
│   │   ├── __init__.py
│   │   ├── chat_memory.py
│   │   ├── messageHistory.py
│   │   └── utils.py
│   ├── model
│   │   ├── __init__.py
│   │   └── language_model.py
│   ├── prompts
│   │   ├── __init__.py
│   │   ├── __pycache__
│   │   │   ├── __init__.cpython-311.pyc
│   │   │   └── base.cpython-311.pyc
│   │   ├── base.py
│   │   ├── chat.py
│   │   ├── example_selectors.py
│   │   ├── few_shot_templates.py
│   │   ├── loading.py
│   │   ├── prompt.py
│   │   ├── templates.py
│   │   └── tool_templates.py
│   ├── schema
│   │   ├── __init__.py
│   │   ├── __pycache__
│   │   │   ├── __init__.cpython-311.pyc
│   │   │   ├── document.cpython-311.pyc
│   │   │   ├── messages.cpython-311.pyc
│   │   │   ├── output.cpython-311.pyc
│   │   │   ├── output_parser.cpython-311.pyc
│   │   │   ├── prompt.cpython-311.pyc
│   │   │   └── prompt_template.cpython-311.pyc
│   │   ├── agent.py
│   │   ├── document.py
│   │   ├── language_model.py
│   │   ├── memory.py
│   │   ├── messages.py
│   │   ├── output.py
│   │   ├── output_parser.py
│   │   ├── prompt.py
│   │   ├── prompt_template.py
│   │   ├── runnable.py
│   │   ├── runnable
│   │   │   ├── __init__.py
│   │   │   ├── base.py
│   │   │   ├── branch.py
│   │   │   ├── config.py
│   │   │   ├── configurable.py
│   │   │   ├── fallbacks.py
│   │   │   ├── passthrough.py
│   │   │   ├── retry.py
│   │   │   ├── router.py
│   │   │   └── utils.py
│   │   ├── runnable_utils.py
│   │   └── schema.py
│   ├── tool_register.py
│   ├── tools
│   │   ├── __init__.py
│   │   └── tool_register.py
│   ├── utils
│   │   ├── __init__.py
│   │   ├── __pycache__
│   │   │   ├── __init__.cpython-311.pyc
│   │   │   ├── formatting.cpython-311.pyc
│   │   │   └── math.cpython-311.pyc
│   │   ├── formatting.py
│   │   ├── input.py
│   │   ├── loading.py
│   │   └── math.py
│   └── vectorstores
│       ├── __init__.py
│       ├── __pycache__
│       │   └── vectorstore.cpython-311.pyc
│       ├── chroma.py
│       ├── embeddings.py
│       ├── laiss.py
│       ├── utils.py
│       └── vectorstore.py
├── pyproject.toml
├── setup.py
├── tool_register.py
└── upload
/.idea/.gitignore:
--------------------------------------------------------------------------------
1 | # Default ignored files
2 | /shelf/
3 | /workspace.xml
4 | # Editor-based HTTP Client requests
5 | /httpRequests/
6 | # Datasource local storage ignored files
7 | /dataSources/
8 | /dataSources.local.xml
9 |
--------------------------------------------------------------------------------
/.idea/inspectionProfiles/profiles_settings.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
--------------------------------------------------------------------------------
/.idea/lmchain.iml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
--------------------------------------------------------------------------------
/.idea/misc.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
--------------------------------------------------------------------------------
/.idea/modules.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
--------------------------------------------------------------------------------
/GLM3_version_REAM.md:
--------------------------------------------------------------------------------
1 | Since GLM has been upgraded to GLM4, updates for GLM3 are now paused. The lmchain==0.1.X releases are all based on GLM3,
2 | 
3 | while lmchain==0.2.01 and later releases all use GLM4 as the base development model.
4 | 
5 | Either line remains fully usable; pin the release that matches your model, for example (a minimal sketch; the 0.1.78 cap is the one stated in README.md):
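6 | 
7 | ```
8 | pip install "lmchain==0.1.78"   # last release of the GLM3-based 0.1.X line
9 | pip install --upgrade lmchain   # GLM4-based releases, 0.2.01 and later
10 | ```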
11 | 
12 | LMchain is a toolkit specifically adapted for Chinese large language model chains.
13 | 
14 | LMchain is a toolkit dedicated to providing free large-model services for users in mainland China; chatGLM is currently the recommended free model.
15 | 
16 | Free users can register at https://open.bigmodel.cn
17 | to obtain a free API key, or use the free key bundled with lmchain.
18 | 
19 | Features are being added continuously; feel free to post in the issues or contact the author at 5847713@qq.com
20 | with your ideas and suggestions.
21 | -----------------------------------------------------------------------------
22 | Installation: ```pip install lmchain```
23 | -----------------------------------------------------------------------------
24 | 
25 | >1. Start with a simple text Q&A:
26 | ```
27 | from lmchain.agents import llmMultiAgent
28 | llm = llmMultiAgent.AgentZhipuAI()
29 | llm.zhipuai.api_key = "1f565e40af1198e11ff1fd8a5b42771d.SjNfezc40YFsz2KC"  # your own registered, working API KEY
30 | response = llm("南京是哪里的省会?")
31 | print(response)
32 | 
33 | response = llm("那里有什么好玩的地方?")
34 | print(response)
35 | ```
36 | 
37 | >2. lmchain can also break a complex task into sub-questions, for example:
38 | ```
39 | from lmchain.agents import llmMultiAgent
40 | llm = llmMultiAgent.AgentZhipuAI()
41 | 
42 | 
43 | query = "工商银行财报中,2023 Q3相比,2024 Q1的收益增长了多少?"
44 | 
45 | from lmchain.chains import subQuestChain
46 | subQC = subQuestChain.SubQuestChain(llm)
47 | response = subQC.run(query=query)
48 | 
49 | print(response)
50 | ```
51 | >3. Computing text embeddings with the large-model embedding tool:
52 | ```
53 | from lmchain.vectorstores import embeddings  # import the embeddings module
54 | embedding_tool = embeddings.GLMEmbedding()  # create a GLMEmbedding object
55 | embedding_tool.zhipuai.api_key = "1f565e40af1198e11ff1fd8a5b42771d.SjNfezc40YFsz2KC"  # your own registered, working API KEY
56 | 
57 | inputs = ["lmchain还有对复杂任务拆解的功能", "目前lmchain还提供了对工具函数的调用方法", "Lmchain是专用为中国大陆用户提供免费大模型服务的工具包"] * 50
58 | 
59 | # the embedding step changes the order of the input texts,
60 | # so use the returned text list for the new ordering
61 | aembeddings, atexts = (embedding_tool.aembed_documents(inputs))
62 | print(aembeddings)
63 | 
64 | # each text is embedded into a vector of shape [1, 1024]
65 | import numpy as np
66 | aembeddings = (np.array(aembeddings))
67 | print(aembeddings.shape)
68 | ```
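69 | 
70 | With the embeddings in hand, nearest-neighbour lookups reduce to vector arithmetic. A minimal sketch (not part of the original examples; it assumes the `aembeddings` and `atexts` variables produced by the snippet above):
71 | ```
72 | import numpy as np
73 | 
74 | emb = np.array(aembeddings).reshape(len(atexts), -1)    # [N, 1024]
75 | emb = emb / np.linalg.norm(emb, axis=1, keepdims=True)  # L2-normalise each row
76 | sims = emb @ emb[0]                                     # cosine similarity to the first text
77 | print(atexts[int(np.argsort(-sims)[1])])                # closest other text (index 0 is itself)
78 | ```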
79 | >4. lmchain also provides a way to call tool functions:
80 | ```
81 | from lmchain.agents import llmMultiAgent
82 | llm = llmMultiAgent.AgentZhipuAI()
83 | 
84 | from lmchain.chains import toolchain
85 | 
86 | tool_chain = toolchain.GLMToolChain(llm)
87 | 
88 | query = "说一下上海的天气"
89 | response = tool_chain.run(query)
90 | print(response)
91 | ```
92 | 
93 | >5. Registering and calling a custom tool:
94 | ```
95 | from lmchain.agents import llmMultiAgent
96 | llm = llmMultiAgent.AgentZhipuAI()
97 | 
98 | from lmchain.chains import toolchain
99 | tool_chain = toolchain.GLMToolChain(llm)
100 | 
101 | from typing import Annotated
102 | # play_game below is a custom tool
103 | def play_game(
104 |     # Annotated marks each parameter with [type, description, required]
105 |     num: Annotated[int, 'use the num to play game', True],
106 | ):
107 |     # the docstring tells the model what the function is for
108 |     """
109 |     一个数字游戏,
110 |     随机输入数字,按游戏规则输出结果的游戏
111 |     """
112 |     if num % 3:
113 |         return 3
114 |     if num % 5:
115 |         return 5
116 |     return 0
117 | 
118 | tool_chain.add_tools(play_game)
119 | query = "玩一个数字游戏,输入数字3"
120 | result = tool_chain.run(query)
121 | 
122 | print(result)
123 | 
124 | ```
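125 | 
126 | A note on the mechanism: `add_tools` hands the function to `lmchain.tools.tool_register.register_tool`; judging from the annotations above, the `Annotated` metadata ([type, description, required]) and the docstring become the JSON tool description that the model selects from.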
127 | More features are being added; readers are welcome to leave feedback or contact the author.
128 | 
129 | 
130 | 
131 | 
132 | 
133 | 
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Copyright (c) 2018 The Python Packaging Authority
2 | Permission is hereby granted, free of charge, to any person obtaining a copy
3 | of this software and associated documentation files (the "Software"), to deal
4 | in the Software without restriction, including without limitation the rights
5 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
6 | copies of the Software, and to permit persons to whom the Software is
7 | furnished to do so, subject to the following conditions:
8 | The above copyright notice and this permission notice shall be included in all
9 | copies or substantial portions of the Software.
10 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
11 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
12 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
13 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
14 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
15 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
16 | SOFTWARE.
17 |
--------------------------------------------------------------------------------
/LMchain.egg-info/PKG-INFO:
--------------------------------------------------------------------------------
1 | Metadata-Version: 2.1
2 | Name: LMchain
3 | Version: 0.1.62
4 | Summary: A toolkit for large language model chains
5 | Author: xiaohuaWang
6 | Author-email: xiaohuaWang <5847713@qq.com>
7 | Project-URL: Homepage, https://github.com/pypa/sampleproject
8 | Project-URL: Bug Tracker, https://github.com/pypa/sampleproject/issues
9 | Classifier: Programming Language :: Python :: 3
10 | Classifier: License :: OSI Approved :: MIT License
11 | Classifier: Operating System :: OS Independent
12 | Requires-Python: >=3
13 | Description-Content-Type: text/markdown
14 | License-File: LICENSE
15 |
16 | LMchain is a toolkit specifically adapted for large model chains
17 |
--------------------------------------------------------------------------------
/LMchain.egg-info/SOURCES.txt:
--------------------------------------------------------------------------------
1 | LICENSE
2 | MANIFEST.in
3 | README.md
4 | pyproject.toml
5 | setup.py
6 | LMchain.egg-info/PKG-INFO
7 | LMchain.egg-info/SOURCES.txt
8 | LMchain.egg-info/dependency_links.txt
9 | LMchain.egg-info/top_level.txt
10 | lmchain/__init__.py
11 | lmchain/tool_register.py
12 | lmchain/agents/__init__.py
13 | lmchain/agents/llmAgent.py
14 | lmchain/agents/llmMultiAgent.py
15 | lmchain/callbacks/__init__.py
16 | lmchain/callbacks/base.py
17 | lmchain/callbacks/manager.py
18 | lmchain/callbacks/stdout.py
19 | lmchain/chains/__init__.py
20 | lmchain/chains/base.py
21 | lmchain/chains/cmd.py
22 | lmchain/chains/conversationalRetrievalChain.py
23 | lmchain/chains/mathChain.py
24 | lmchain/chains/question_answering.py
25 | lmchain/chains/toolchain.py
26 | lmchain/chains/urlRequestChain.py
27 | lmchain/embeddings/__init__.py
28 | lmchain/index/__init__.py
29 | lmchain/index/indexChain.py
30 | lmchain/llms/__init__.py
31 | lmchain/llms/base.py
32 | lmchain/load/__init__.py
33 | lmchain/load/serializable.py
34 | lmchain/memory/__init__.py
35 | lmchain/memory/chat_memory.py
36 | lmchain/memory/messageHistory.py
37 | lmchain/memory/utils.py
38 | lmchain/model/__init__.py
39 | lmchain/model/language_model.py
40 | lmchain/prompts/__init__.py
41 | lmchain/prompts/base.py
42 | lmchain/prompts/chat.py
43 | lmchain/prompts/example_selectors.py
44 | lmchain/prompts/few_shot_templates.py
45 | lmchain/prompts/loading.py
46 | lmchain/prompts/prompt.py
47 | lmchain/prompts/templates.py
48 | lmchain/prompts/tool_templates.py
49 | lmchain/schema/__init__.py
50 | lmchain/schema/agent.py
51 | lmchain/schema/document.py
52 | lmchain/schema/language_model.py
53 | lmchain/schema/memory.py
54 | lmchain/schema/messages.py
55 | lmchain/schema/output.py
56 | lmchain/schema/output_parser.py
57 | lmchain/schema/prompt.py
58 | lmchain/schema/prompt_template.py
59 | lmchain/schema/runnable.py
60 | lmchain/schema/runnable_utils.py
61 | lmchain/schema/schema.py
62 | lmchain/schema/runnable/__init__.py
63 | lmchain/schema/runnable/base.py
64 | lmchain/schema/runnable/branch.py
65 | lmchain/schema/runnable/config.py
66 | lmchain/schema/runnable/configurable.py
67 | lmchain/schema/runnable/fallbacks.py
68 | lmchain/schema/runnable/passthrough.py
69 | lmchain/schema/runnable/retry.py
70 | lmchain/schema/runnable/router.py
71 | lmchain/schema/runnable/utils.py
72 | lmchain/tools/__init__.py
73 | lmchain/tools/tool_register.py
74 | lmchain/utils/__init__.py
75 | lmchain/utils/formatting.py
76 | lmchain/utils/input.py
77 | lmchain/utils/loading.py
78 | lmchain/utils/math.py
79 | lmchain/vectorstores/__init__.py
80 | lmchain/vectorstores/chroma.py
81 | lmchain/vectorstores/embeddings.py
82 | lmchain/vectorstores/laiss.py
83 | lmchain/vectorstores/utils.py
84 | lmchain/vectorstores/vectorstore.py
--------------------------------------------------------------------------------
/LMchain.egg-info/dependency_links.txt:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/LMchain.egg-info/top_level.txt:
--------------------------------------------------------------------------------
1 | lmchain
2 |
--------------------------------------------------------------------------------
/MANIFEST.in:
--------------------------------------------------------------------------------
1 | recursive-include lmchain *.py *.txt
2 | 
3 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | LMchain is a toolkit specifically adapted for Chinese large language model chains.
2 | 
3 | LMchain is a toolkit dedicated to providing free large-model services for users in mainland China; chatGLM is currently the recommended free model.
4 | 
5 | Free users can register at https://open.bigmodel.cn
6 | to obtain a free API key, or use the free key bundled with lmchain.
7 | 
8 | Features are being added continuously; feel free to post in the issues or contact the author at 5847713@qq.com
9 | with your ideas and suggestions.
10 | 
11 | Note: with the GLM4 update, lmchain has been fully migrated to the new API; users of the older GLM3-based releases can keep using them (the highest such version is 0.1.78).
12 | -----------------------------------------------------------------------------
13 | Installation: ```pip install lmchain```
14 | -----------------------------------------------------------------------------
15 | 
16 | >1. Start with a simple text Q&A:
17 | ```
18 | from lmchain.agents import AgentZhipuAI
19 | llm = AgentZhipuAI()
20 | 
21 | response = llm("你好")
22 | print(response)
23 | 
24 | response = llm("南京是哪里的省会")
25 | print(response)
26 | 
27 | response = llm("那里有什么好玩的地方")
28 | print(response)
29 | ```
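30 | 
31 | >2. The agent also drops into langchain chains as a regular LLM. A minimal sketch, adapted from the self-test in lmchain/agents/llmMultiAgent.py and reusing `llm` from the snippet above (that the GLM4 API accepts the same pattern is an assumption):
32 | ```
33 | from langchain.prompts import PromptTemplate
34 | from langchain.chains import LLMChain
35 | 
36 | prompt = PromptTemplate(
37 |     input_variables=["location"],
38 |     template="作为一名专业的旅游顾问,简单的说一下{location}有什么好玩的景点?只要说一个就可以。",
39 | )
40 | chain = LLMChain(llm=llm, prompt=prompt)
41 | print(chain.run({"location": "南京"}))
42 | ```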
43 | 
44 | 
--------------------------------------------------------------------------------
/__pycache__/tool_register.cpython-311.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/virgo777/lmchain/3fc4a55df07100d4fe0beef4da89e3fbc6288938/__pycache__/tool_register.cpython-311.pyc
--------------------------------------------------------------------------------
/build/lib/lmchain/__init__.py:
--------------------------------------------------------------------------------
1 | name = "lmchain"
--------------------------------------------------------------------------------
/build/lib/lmchain/agents/__init__.py:
--------------------------------------------------------------------------------
1 | name = "agents"
--------------------------------------------------------------------------------
/build/lib/lmchain/agents/llmAgent.py:
--------------------------------------------------------------------------------
1 | import time
2 | import logging
3 | import requests
4 | from typing import Optional, List, Dict, Mapping, Any
5 | 
6 | import langchain
7 | from langchain.llms.base import LLM
8 | from langchain.cache import InMemoryCache
9 | 
10 | logging.basicConfig(level=logging.INFO)
11 | # enable the LLM cache
12 | langchain.llm_cache = InMemoryCache()
13 | 
14 | 
15 | class AgentChatGLM(LLM):
16 |     # model-serving URL
17 |     url = "http://127.0.0.1:7866/chat"
18 |     #url = "http://192.168.3.20:7866/chat"  # on the 3050 server
19 |     history = []
20 | 
21 |     @property
22 |     def _llm_type(self) -> str:
23 |         return "chatglm"
24 | 
25 |     def _construct_query(self, prompt: str) -> Dict:
26 |         """Build the request body."""
27 |         query = {"query": prompt, "history": self.history}
28 |         import json
29 |         query = json.dumps(query)  # JSON-encode the request parameters
30 |         return query
31 | 
32 |     def _construct_query_tools(self, prompt: str, tools: list) -> Dict:
33 |         """Build the request body that carries the tool definitions."""
34 |         tools_info = {"role": "system",
35 |                       "content": "你现在是一个查找使用何种工具以及传递何种参数的工具助手,你会一步步的思考问题。你根据需求查找工具函数箱中最合适的工具函数,然后返回工具函数名称和所工具函数对应的参数,参数必须要和需求中的目标对应。",
36 |                       "tools": tools}
37 |         query = {"query": prompt, "history": tools_info}
38 |         import json
39 |         query = json.dumps(query)  # JSON-encode the request parameters
40 |         return query
41 | 
42 |     def _post(self, url: str, query: Dict) -> Any:
43 |         """POST request."""
44 |         response = requests.post(url, data=query).json()
45 |         return response
46 | 
47 |     def _call(self, prompt: str, stop: Optional[List[str]] = None, tools: list = None) -> str:
48 |         """_call"""
49 |         if tools is None:
50 |             # construct the query
51 |             query = self._construct_query(prompt=prompt)
52 |             # post it to the model service
53 |             response = self._post(url=self.url, query=query)
54 |             response_chat = response["response"]
55 |             self.history = response["history"]
56 |             return response_chat
57 |         else:
58 |             query = self._construct_query_tools(prompt=prompt, tools=tools)
59 |             response = self._post(url=self.url, query=query)
60 |             self.history = response["history"]  # update history before unpacking the reply
61 |             response = response["response"]
62 |             try:
63 |                 from lmchain.tools import tool_register  # tool registry (lazy import)
64 |                 ret = tool_register.dispatch_tool(response["name"], response["parameters"])
65 |                 response_chat = self(prompt=str(ret))  # feed the tool result back through the model
66 |             except Exception:
67 |                 response_chat = response
68 |             return str(response_chat)
69 | 
70 |     @property
71 |     def _identifying_params(self) -> Mapping[str, Any]:
72 |         """Get the identifying parameters."""
73 |         _param_dict = {
74 |             "url": self.url
75 |         }
76 |         return _param_dict
77 | 
78 | 
79 | if __name__ == "__main__":
80 | 
81 |     import tool_register
82 | 
83 |     # get all registered tools, returned as JSON
84 |     tools = tool_register.get_tools()
85 |     "-------------------------------------- tools are defined above ---------------------------------------"
86 | 
87 |     llm = AgentChatGLM()
88 |     llm.url = "http://192.168.3.20:7866/chat"
89 |     while True:
90 |         while True:
91 |             human_input = input("Human: ")
92 |             if human_input == "tools":
93 |                 break
94 | 
95 |             begin_time = time.time() * 1000
96 |             # query the model
97 |             response = llm(human_input)
98 |             end_time = time.time() * 1000
99 |             used_time = round(end_time - begin_time, 3)
100 |             #logging.info(f"chatGLM process time: {used_time}ms")
101 |             print(f"Chat: {response}")
102 | 
103 |         human_input = input("Human_with_tools_Ask: ")
104 |         response = llm(prompt=human_input, tools=tools)
105 |         print(f"Chat_with_tools_Que: {response}")
106 | 
--------------------------------------------------------------------------------
/build/lib/lmchain/agents/llmMultiActionAgent.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/virgo777/lmchain/3fc4a55df07100d4fe0beef4da89e3fbc6288938/build/lib/lmchain/agents/llmMultiActionAgent.py
--------------------------------------------------------------------------------
/build/lib/lmchain/agents/llmMultiAgent.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import requests
3 | from typing import Optional, List, Dict, Mapping, Any
4 | import langchain
5 | from langchain.llms.base import LLM
6 | from langchain.cache import InMemoryCache
7 | 
8 | logging.basicConfig(level=logging.INFO)
9 | # enable the LLM cache
10 | langchain.llm_cache = InMemoryCache()
11 | 
12 | 
13 | class AgentZhipuAI(LLM):
14 |     import zhipuai as zhipuai
15 |     # model-serving URL
16 |     url = "127.0.0.1"
17 |     zhipuai.api_key = "1f565e40af1198e11ff1fd8a5b42771d.SjNfezc40YFsz2KC"  # API key obtained from the console
18 |     model = "chatglm_pro"  # model version
19 | 
20 |     history = []
21 | 
22 |     def getText(self, role, content):
23 |         # role names the speaker; content is the prompt text
24 |         jsoncon = {}
25 |         jsoncon["role"] = role
26 |         jsoncon["content"] = content
27 |         self.history.append(jsoncon)
28 |         return self.history
29 | 
30 |     @property
31 |     def _llm_type(self) -> str:
32 |         return "AgentZhipuAI"
33 | 
34 |     def _post(self, url: str, query: Dict) -> Any:
35 |         """POST request."""
36 |         response = requests.post(url, data=query).json()
37 |         return response
38 | 
39 |     def _call(self, prompt: str, stop: Optional[List[str]] = None, role="user") -> str:
40 |         """_call"""
41 |         # construct and send the query
42 |         response = self.zhipuai.model_api.invoke(
43 |             model=self.model,
44 |             prompt=self.getText(role=role, content=prompt)
45 |         )
46 |         choices = (response['data']['choices'])[0]
47 |         self.history.append(choices)
48 |         return choices["content"]
49 | 
50 |     @property
51 |     def _identifying_params(self) -> Mapping[str, Any]:
52 |         """Get the identifying parameters."""
53 |         _param_dict = {
54 |             "url": self.url
55 |         }
56 |         return _param_dict
57 | 
58 | 
59 | if __name__ == '__main__':
60 |     from langchain.prompts import PromptTemplate
61 |     from langchain.chains import LLMChain
62 | 
63 |     llm = AgentZhipuAI()
64 | 
65 |     # an example prompt with no input variables
66 |     no_input_prompt = PromptTemplate(input_variables=[], template="给我讲个笑话。")
67 |     no_input_prompt.format()
68 | 
69 |     prompt = PromptTemplate(
70 |         input_variables=["location", "street"],
71 |         template="作为一名专业的旅游顾问,简单的说一下{location}有什么好玩的景点,特别是在{street}?只要说一个就可以。",
72 |     )
73 | 
74 |     chain = LLMChain(llm=llm, prompt=prompt)
75 |     print(chain.run({"location": "南京", "street": "新街口"}))
76 | 
77 |     from langchain.chains import ConversationChain
78 |     conversation = ConversationChain(llm=llm, verbose=True)
79 | 
80 |     output = conversation.predict(input="你好!")
81 |     print(output)
82 | 
83 |     output = conversation.predict(input="南京是哪里的省会?")
84 |     print(output)
85 | 
86 |     output = conversation.predict(input="那里有什么好玩的地方,简单的说一个就好。")
87 |     print(output)
88 | 
--------------------------------------------------------------------------------
/build/lib/lmchain/callbacks/__init__.py:
--------------------------------------------------------------------------------
1 | name = "callbacks"
--------------------------------------------------------------------------------
/build/lib/lmchain/callbacks/stdout.py:
--------------------------------------------------------------------------------
1 | """Callback Handler that prints to std out."""
2 | from typing import Any, Dict, List, Optional
3 | 
4 | from langchain.callbacks.base import BaseCallbackHandler
5 | from langchain.schema import AgentAction, AgentFinish, LLMResult
6 | from lmchain.utils.input import print_text
7 | 
8 | 
9 | class StdOutCallbackHandler(BaseCallbackHandler):
10 |     """Callback Handler that prints to std out."""
11 | 
12 |     def __init__(self, color: Optional[str] = None) -> None:
13 |         """Initialize callback handler."""
14 |         self.color = color
15 | 
16 |     def on_llm_start(
17 |         self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
18 |     ) -> None:
19 |         """Print out the prompts."""
20 |         pass
21 | 
22 |     def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
23 |         """Do nothing."""
24 |         pass
25 | 
26 |     def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
27 |         """Do nothing."""
28 |         pass
29 | 
30 |     def on_llm_error(self, error: BaseException, **kwargs: Any) -> None:
31 |         """Do nothing."""
32 |         pass
33 | 
34 |     def on_chain_start(
35 |         self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
36 |     ) -> None:
37 |         """Print out that we are entering a chain."""
38 |         class_name = serialized.get("name", serialized.get("id", [""])[-1])
39 |         print(f"\n\n\033[1m> Entering new {class_name} chain...\033[0m")
40 | 
41 |     def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
42 |         """Print out that we finished a chain."""
43 |         print("\n\033[1m> Finished chain.\033[0m")
44 | 
45 |     def on_chain_error(self, error: BaseException, **kwargs: Any) -> None:
46 |         """Do nothing."""
47 |         pass
48 | 
49 |     def on_tool_start(
50 |         self,
51 |         serialized: Dict[str, Any],
52 |         input_str: str,
53 |         **kwargs: Any,
54 |     ) -> None:
55 |         """Do nothing."""
56 |         pass
57 | 
58 |     def on_agent_action(
59 |         self, action: AgentAction, color: Optional[str] = None, **kwargs: Any
60 |     ) -> Any:
61 |         """Run on agent action."""
62 |         print_text(action.log, color=color or self.color)
63 | 
64 |     def on_tool_end(
65 |         self,
66 |         output: str,
67 |         color: Optional[str] = None,
68 |         observation_prefix: Optional[str] = None,
69 |         llm_prefix: Optional[str] = None,
70 |         **kwargs: Any,
71 |     ) -> None:
72 |         """If not the final action, print out observation."""
73 |         if observation_prefix is not None:
74 |             print_text(f"\n{observation_prefix}")
75 |         print_text(output, color=color or self.color)
76 |         if llm_prefix is not None:
77 |             print_text(f"\n{llm_prefix}")
78 | 
79 |     def on_tool_error(self, error: BaseException, **kwargs: Any) -> None:
80 |         """Do nothing."""
81 |         pass
82 | 
83 |     def on_text(
84 |         self,
85 |         text: str,
86 |         color: Optional[str] = None,
87 |         end: str = "",
88 |         **kwargs: Any,
89 |     ) -> None:
90 |         """Run when agent ends."""
91 |         print_text(text, color=color or self.color, end=end)
92 | 
93 |     def on_agent_finish(
94 |         self, finish: AgentFinish, color: Optional[str] = None, **kwargs: Any
95 |     ) -> None:
96 |         """Run on agent end."""
97 |         print_text(finish.log, color=color or self.color, end="\n")
98 | 
--------------------------------------------------------------------------------
/build/lib/lmchain/chains/__init__.py:
--------------------------------------------------------------------------------
1 | name = "chains"
--------------------------------------------------------------------------------
/build/lib/lmchain/chains/cmd.py:
--------------------------------------------------------------------------------
1 | # Chain that asks the LLM for a Windows command line and executes it
2 | from langchain.chains.llm import LLMChain
3 | from langchain.prompts import PromptTemplate
4 | from lmchain.agents import llmAgent
5 | import os, re
6 | 
7 | class LLMCMDChain:
8 |     def __init__(self, llm):
9 |         qa_prompt = PromptTemplate(
10 |             template="""你现在根据需要完成对命令行的编写,要根据需求编写对应的在Windows系统终端运行的命令,不要用%question形参这种指代的参数形式,直接给出可以运行的命令。
11 | Question: 给我一个在Windows系统终端中可以准确执行{question}的命令。
12 | answer:""",
13 |             input_variables=["question"],
14 |         )
15 |         self.qa_chain = LLMChain(llm=llm, prompt=qa_prompt)
16 |         self.pattern = r"```(.*?)```"
17 | 
18 |     def run(self, text):
19 |         cmd_response = self.qa_chain.run(question=text)
20 |         cmd_string = str(cmd_response).split("```")[-2][1:-1]  # take the last fenced block, trimming its surrounding newlines
21 |         os.system(cmd_string)
22 |         return cmd_string
23 | 
--------------------------------------------------------------------------------
/build/lib/lmchain/chains/conversationalRetrievalChain.py:
--------------------------------------------------------------------------------
1 | from langchain.docstore.document import Document
2 | from langchain.text_splitter import RecursiveCharacterTextSplitter
3 | from lmchain.embeddings import embeddings
4 | from lmchain.vectorstores import laiss
5 | from lmchain.agents import llmMultiAgent
6 | from langchain.memory import ConversationBufferMemory
7 | from langchain.prompts import (
8 |     ChatPromptTemplate,  # builds chat templates
9 |     MessagesPlaceholder,  # inserts a message placeholder into the template
10 |     SystemMessagePromptTemplate,  # builds system-message templates
11 |     HumanMessagePromptTemplate  # builds human-message templates
12 | )
13 | from langchain.chains import ConversationChain
14 | 
15 | class ConversationalRetrievalChain:
16 |     def __init__(self, document, chunk_size=1280, chunk_overlap=50, file_name="这是一份辅助材料"):
17 |         """
18 |         :param document: the input content, a single text string
19 |         :param chunk_size: number of characters per chunk after splitting
20 |         :param chunk_overlap: number of overlapping characters between adjacent chunks
21 |         :param file_name: document name / path
22 |         """
23 |         self.text_splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
24 |         self.embedding_tool = embeddings.GLMEmbedding()
25 | 
26 |         self.lmaiss = laiss.LMASS()  # converts texts to vectors and computes query similarity
27 |         self.llm = llmMultiAgent.AgentZhipuAI()
28 |         self.memory = ConversationBufferMemory(return_messages=True)
29 | 
30 |         conversation_prompt = ChatPromptTemplate.from_messages([
31 |             SystemMessagePromptTemplate.from_template("你是一个最强大的人工智能程序,可以知无不答,但是你不懂的东西会直接回答不知道。"),
32 |             MessagesPlaceholder(variable_name="history"),  # placeholder for past messages
33 |             HumanMessagePromptTemplate.from_template("{input}")  # template for the human input
34 |         ])
35 | 
36 |         self.qa_chain = ConversationChain(memory=self.memory, prompt=conversation_prompt, llm=self.llm)
37 |         "---------------------------"
38 |         document = [Document(page_content=document, metadata={"source": file_name})]  # normalise the input document
39 |         self.documents = self.text_splitter.split_documents(document)  # split into chunks
40 |         self.vectorstore = self.lmaiss.from_documents(self.documents, embedding_class=self.embedding_tool)
41 | 
42 |     def __call__(self, query):
43 |         query_embedding = self.embedding_tool.embed_query(query)
44 | 
45 |         # find the vector closest to the query
46 |         close_id = self.lmaiss.get_similarity_vector_indexs(query_embedding, self.vectorstore, k=1)[0]
47 |         # look up the corresponding paragraph
48 |         doc = self.documents[close_id]
49 | 
50 |         # build the final query
51 |         query = f"你现在要回答问题'{query}',你可以参考文献'{doc}',你如果找不到对应的内容,就从自己的记忆体中查找,就回答'请提供更为准确的查询内容',注意你要一步步的思考再回答。"
52 |         result = (self.qa_chain.predict(input=query))
53 |         return result
54 | 
55 |     def predict(self, input):
56 |         result = self.__call__(input)
57 |         return result
--------------------------------------------------------------------------------
/build/lib/lmchain/chains/mathChain.py:
--------------------------------------------------------------------------------
1 | # Chain that turns a natural-language request into a numexpr expression and evaluates it
2 | 
3 | from langchain.chains.llm import LLMChain
4 | from langchain.prompts import PromptTemplate
5 | from lmchain.agents import llmAgent
6 | import os, re, math
7 | 
8 | try:
9 |     import numexpr  # noqa: F401
10 | except ImportError:
11 |     raise ImportError(
12 |         "LMchain requires the numexpr package. "
13 |         "Please install it with `pip install numexpr`."
14 |     )
15 | 
16 | 
17 | class LLMMathChain:
18 |     def __init__(self, llm):
19 |         qa_prompt = PromptTemplate(
20 |             template="""现在给你一个中文命令,请你把这个命令转化成数学公式。直接给出数学公式。这个公式会在numexpr包中调用。
21 | Question: 我现在需要计算{question},结果需要在numexpr包中调用。
22 | answer:""",
23 |             input_variables=["question"],
24 |         )
25 |         self.qa_chain = LLMChain(llm=llm, prompt=qa_prompt)
26 | 
27 |     def run(self, text):
28 |         cmd_response = self.qa_chain.run(question=text)
29 |         result = self._evaluate_expression(str(cmd_response))
30 |         return result
31 | 
32 |     def _evaluate_expression(self, expression: str) -> str:
33 |         import numexpr  # noqa: F401
34 | 
35 |         try:
36 |             local_dict = {"pi": math.pi, "e": math.e}
37 |             output = str(
38 |                 numexpr.evaluate(
39 |                     expression.strip(),
40 |                     global_dict={},  # restrict access to globals
41 |                     local_dict=local_dict,  # expose common mathematical constants
42 |                 )
43 |             )
44 |         except Exception as e:
45 |             raise ValueError(
46 |                 f'LMchain._evaluate("{expression}") raised error: {e}.'
47 |                 " Please try again with a valid numerical expression"
48 |             )
49 | 
50 |         # Remove any leading and trailing brackets from the output
51 |         return re.sub(r"^\[|\]$", "", output)
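52 | 
53 | # Illustrative usage (not in the original file; assumes a reachable LLM agent):
54 | # if __name__ == '__main__':
55 | #     from lmchain.agents import llmMultiAgent
56 | #     math_chain = LLMMathChain(llmMultiAgent.AgentZhipuAI())
57 | #     print(math_chain.run("三加五乘以二"))  # the model should emit an expression like 3 + 5 * 2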
--------------------------------------------------------------------------------
/build/lib/lmchain/chains/question_answering.py:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
--------------------------------------------------------------------------------
/build/lib/lmchain/chains/toolchain.py:
--------------------------------------------------------------------------------
1 | from langchain.chains import LLMChain
2 | from langchain.prompts import PromptTemplate
3 | 
4 | from tqdm import tqdm
5 | from lmchain.tools import tool_register
6 | 
7 | 
8 | class GLMToolChain:
9 |     def __init__(self, llm):
10 |         self.llm = llm
11 |         self.tool_register = tool_register
12 |         self.tools = tool_register.get_tools()
13 | 
14 |     def __call__(self, query="", tools=None):
15 |         if query == "":
16 |             raise ValueError("query需要填入查询问题")
17 |         if tools is not None:
18 |             self.tools = tools
19 |         # otherwise the default registered tools are used
20 |         template = f"""
21 | 你现在是一个专业的人工智能助手,你现在的需求是{query}。而你需要借助于工具在{self.tools}中找到对应的函数,用json格式返回对应的函数名和参数。
22 | 函数名定义为function_name,参数名为params,还要求写入详细的形参与实参。
23 | 
24 | 如果找到合适的函数,就返回json格式的函数名和需要的参数,不要回答任何描述和解释。
25 | 
26 | 如果没有找到合适的函数,则返回:'未找到合适参数,请提供更详细的描述。'
27 | """
28 | 
29 |         flag = True
30 |         counter = 0
31 |         while flag:
32 |             try:
33 |                 res = self.llm(template)
34 | 
35 |                 import json
36 |                 res_dict = json.loads(res)
37 |                 if isinstance(res_dict, str):  # some replies come back double-encoded
38 |                     res_dict = json.loads(res_dict)
39 |                 flag = False
40 |             except Exception:
41 |                 # the reply was not valid JSON; ask again
42 |                 template = f"""
43 | 你现在是一个专业的人工智能助手,你现在的需求是{query}。而你需要借助于工具在{self.tools}中找到对应的函数,用json格式返回对应的函数名和参数。
44 | 函数名定义为function_name,参数名为params,还要求写入详细的形参与实参。
45 | 
46 | 如果找到合适的函数,就返回json格式的函数名和需要的参数,不要回答任何描述和解释。
47 | 
48 | 如果没有找到合适的函数,则返回:'未找到合适参数,请提供更详细的描述。'
49 | 
50 | 你刚才生成了一组结果,但是返回不符合json格式,现在请你重新按json格式生成并返回结果。
51 | """
52 |                 counter += 1
53 |                 if counter >= 5:
54 |                     return '未找到合适参数,请提供更详细的描述。'
55 |         return res_dict
56 | 
57 |     def run(self, query, tools=None):
58 |         tools = (self.tool_register.get_tools())  # always refresh from the registry so newly added tools are included
59 |         result = self.__call__(query, tools)
60 | 
61 |         if result == "未找到合适参数,请提供更详细的描述。":
62 |             return "未找到合适参数,请提供更详细的描述。"
63 |         else:
64 |             print("找到对应工具函数,格式如下:", result)
65 |             result = self.dispatch_tool(result)
66 |             from lmchain.prompts.templates import PromptTemplate
67 |             tool_prompt = PromptTemplate(
68 |                 input_variables=["query", "result"],
69 |                 template="你现在是一个私人助手,现在你的查询任务是{query},而你通过工具从网上查询的结果是{result},现在根据查询的内容与查询的结果,生成最终答案。",
70 |             )
71 |             from langchain.chains import LLMChain
72 |             chain = LLMChain(llm=self.llm, prompt=tool_prompt)
73 | 
74 |             response = (chain.run({"query": query, "result": result}))
75 | 
76 |             return response
77 | 
78 |     def add_tools(self, tool):
79 |         self.tool_register.register_tool(tool)
80 |         return True
81 | 
82 |     def dispatch_tool(self, tool_result) -> str:
83 |         tool_name = tool_result["function_name"]
84 |         tool_params = tool_result["params"]
85 |         if tool_name not in self.tool_register._TOOL_HOOKS:
86 |             return f"Tool `{tool_name}` not found. Please use a provided tool."
87 |         tool_call = self.tool_register._TOOL_HOOKS[tool_name]
88 | 
89 |         try:
90 |             ret = tool_call(**tool_params)
91 |         except Exception:
92 |             import traceback
93 |             ret = traceback.format_exc()
94 |         return str(ret)
95 | 
96 |     def get_tools(self):
97 |         return (self.tool_register.get_tools())
98 | 
99 | 
100 | if __name__ == '__main__':
101 |     from lmchain.agents import llmMultiAgent
102 | 
103 |     llm = llmMultiAgent.AgentZhipuAI()
104 | 
105 |     from lmchain.chains import toolchain
106 | 
107 |     tool_chain = toolchain.GLMToolChain(llm)
108 | 
109 |     from typing import Annotated
110 | 
111 |     def rando_numbr(
112 |         seed: Annotated[int, 'The random seed used by the generator', True],
113 |         range: Annotated[tuple[int, int], 'The range of the generated numbers', True],
114 |     ) -> int:
115 |         """
116 |         Generates a random number x, s.t. range[0] <= x < range[1]
117 |         """
118 |         import random
119 |         return random.Random(seed).randint(*range)
120 | 
121 |     tool_chain.add_tools(rando_numbr)
122 | 
123 |     print("------------------------------------------------------")
124 |     query = "今天shanghai的天气是什么?"
125 |     result = tool_chain.run(query)  # run() already dispatches the tool and summarises the result
126 |     print(result)
127 | 
--------------------------------------------------------------------------------
/build/lib/lmchain/chains/urlRequestChain.py:
--------------------------------------------------------------------------------
1 | from langchain.chains import LLMRequestsChain, LLMChain
2 | from langchain.prompts import PromptTemplate
3 | 
4 | import requests
5 | from bs4 import BeautifulSoup
6 | from tqdm import tqdm
7 | 
8 | 
9 | class LMRequestsChain:
10 |     def __init__(self, llm, max_url_num=2):
11 |         template = """Between >>> and <<< are the raw search result text from google.
12 | Extract the answer to the question '{query}' or say "not found" if the information is not contained.
13 | Use the format
14 | Extracted:
15 | >>> {requests_result} <<<
16 | Extracted:"""
17 |         PROMPT = PromptTemplate(
18 |             input_variables=["query", "requests_result"],
19 |             template=template,
20 |         )
21 |         self.chain = LLMRequestsChain(llm_chain=LLMChain(llm=llm, prompt=PROMPT))
22 |         self.max_url_num = max_url_num
23 | 
24 |         query_prompt = PromptTemplate(
25 |             input_variables=["query", "responses"],
26 |             template="作为一名专业的信息总结员,我需要查询的信息为{query},根据提供的信息{responses}回答一下查询的结果。")
27 |         self.query_chain = LLMChain(llm=llm, prompt=query_prompt)
28 | 
29 |     def __call__(self, query, target_site=""):
30 |         url_list = self.get_urls(query, target_site=target_site)
31 |         print(f"查找到{len(url_list)}条url内容,现在开始解析其中的{self.max_url_num}条内容。")
32 |         responses = []
33 |         for url in tqdm(url_list[:self.max_url_num]):
34 |             inputs = {
35 |                 "query": query,
36 |                 "url": url
37 |             }
38 | 
39 |             response = self.chain(inputs)
40 |             output = response["output"]
41 |             responses.append(output)
42 |         if len(responses) != 0:
43 |             output = self.query_chain.run({"query": query, "responses": responses})
44 |             return output
45 |         else:
46 |             return "查找内容为空,请更换查找词"
47 | 
48 |     def query_form_url(self, query="LMchain是什么?", url=""):
49 |         assert url != "", "url link must be set"
50 |         inputs = {
51 |             "query": query,
52 |             "url": url
53 |         }
54 |         response = self.chain(inputs)
55 |         return response
56 | 
57 |     def get_urls(self, query='lmchain是什么?', target_site=""):
58 |         def bing_search(query, count=30):
59 |             url = f'https://cn.bing.com/search?q={query}'
60 |             headers = {
61 |                 'User-Agent': 'Mozilla/6.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/58.0.3029.110 Safari/537.3'}
62 |             response = requests.get(url, headers=headers)
63 |             if response.status_code == 200:
64 |                 html = response.text
65 |                 # parse the HTML with BeautifulSoup
66 |                 soup = BeautifulSoup(html, 'html.parser')
67 |                 results = soup.find_all('li', class_='b_algo')
68 |                 return [result.find('a').text for result in results[:count]]
69 |             else:
70 |                 print(f'请求失败,状态码:{response.status_code}')
71 |                 return []
72 | 
73 |         results = bing_search(query)
74 |         if len(results) == 0:
75 |             return None
76 |         url_list = []
77 |         if target_site != "":
78 |             for i, result in enumerate(results):
79 |                 if "https" in result and target_site in result:
80 |                     url = "https://" + result.split("https://")[1]
81 |                     url_list.append(url)
82 |         else:
83 |             for i, result in enumerate(results):
84 |                 if "https" in result:
85 |                     url = "https://" + result.split("https://")[1]
86 |                     url_list.append(url)
87 |         if len(url_list) > 0:
88 |             return url_list
89 |         else:
90 |             # fall back to all results so something is returned when the target site (e.g. Zhihu) has no match
91 |             for i, result in enumerate(results):
92 |                 if "https" in result:
93 |                     url = "https://" + result.split("https://")[1]
94 |                     url_list.append(url)
95 |             return url_list
96 | 
--------------------------------------------------------------------------------
/build/lib/lmchain/embeddings/__init__.py:
--------------------------------------------------------------------------------
1 | name = "embeddings"
--------------------------------------------------------------------------------
/build/lib/lmchain/embeddings/embeddings.py:
--------------------------------------------------------------------------------
1 | import asyncio
2 | from abc import ABC, abstractmethod
3 | from typing import List
4 | 
5 | 
6 | class Embeddings(ABC):
7 |     """Interface for embedding models."""
8 | 
9 |     @abstractmethod
10 |     def embed_documents(self, texts: List[str]) -> List[List[float]]:
11 |         """Embed search docs."""
12 | 
13 |     @abstractmethod
14 |     def embed_query(self, text: str) -> List[float]:
15 |         """Embed query text."""
16 | 
17 |     async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
18 |         """Asynchronously embed search docs."""
19 |         return await asyncio.get_running_loop().run_in_executor(
20 |             None, self.embed_documents, texts
21 |         )
22 | 
23 |     async def aembed_query(self, text: str) -> List[float]:
24 |         """Asynchronously embed query text."""
25 |         return await asyncio.get_running_loop().run_in_executor(
26 |             None, self.embed_query, text
27 |         )
28 | 
29 | class LMEmbedding(Embeddings):
30 |     from modelscope.pipelines import pipeline
31 |     from modelscope.utils.constant import Tasks
32 |     pipeline_se = pipeline(Tasks.sentence_embedding, model='thomas/text2vec-base-chinese', model_revision='v1.0.0', device="cuda")
33 | 
34 |     def _construct_inputs(self, texts):
35 |         inputs = {
36 |             "source_sentence": texts
37 |         }
38 |         return inputs
39 | 
40 |     def embed_documents(self, texts: List[str]) -> List[List[float]]:
41 |         """Embed search docs."""
42 |         inputs = self._construct_inputs(texts)
43 |         result_embeddings = self.pipeline_se(input=inputs)
44 |         return result_embeddings["text_embedding"]
45 | 
46 |     def embed_query(self, text: str) -> List[float]:
47 |         """Embed query text."""
48 |         inputs = self._construct_inputs([text])
49 |         result_embeddings = self.pipeline_se(input=inputs)
50 |         return result_embeddings["text_embedding"]
51 | 
52 | 
53 | class GLMEmbedding(Embeddings):
54 |     import zhipuai as zhipuai
55 |     zhipuai.api_key = "1f565e40af1198e11ff1fd8a5b42771d.SjNfezc40YFsz2KC"  # API key obtained from the console
56 | 
57 |     def _construct_inputs(self, texts):
58 |         inputs = {
59 |             "source_sentence": texts
60 |         }
61 |         return inputs
62 | 
63 |     def embed_documents(self, texts: List[str]) -> List[List[float]]:
64 |         """Embed search docs."""
65 |         result_embeddings = []
66 |         for text in texts:
67 |             embedding = self.embed_query(text)
68 |             result_embeddings.append(embedding)
69 |         return result_embeddings
70 | 
71 |     def embed_query(self, text: str) -> List[float]:
72 |         """Embed query text."""
73 |         result_embeddings = self.zhipuai.model_api.invoke(
74 |             model="text_embedding", prompt=text)
75 |         return result_embeddings["data"]["embedding"]
76 | 
77 | 
78 | if __name__ == '__main__':
79 |     inputs = ["不可以,早晨喝牛奶不科学", "不可以,今天早晨喝牛奶不科学", "早晨喝牛奶不科学"]
80 |     print(GLMEmbedding().embed_documents(inputs))
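81 | 
82 |     # Illustrative addition (not in the original file): the async wrappers defined on
83 |     # Embeddings delegate to a thread executor, so they can be driven with asyncio.run:
84 |     # print(asyncio.run(GLMEmbedding().aembed_query("早晨喝牛奶不科学")))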
85 | 
86 | 
--------------------------------------------------------------------------------
/build/lib/lmchain/hello.py:
--------------------------------------------------------------------------------
1 | print("hello world")
--------------------------------------------------------------------------------
/build/lib/lmchain/index/__init__.py:
--------------------------------------------------------------------------------
1 | name = "index"
--------------------------------------------------------------------------------
/build/lib/lmchain/index/indexChain.py:
--------------------------------------------------------------------------------
1 | from typing import Any, Dict, List, Optional, Type
2 | 
3 | from langchain.chains.qa_with_sources.retrieval import RetrievalQAWithSourcesChain
4 | from langchain.chains.retrieval_qa.base import RetrievalQA
5 | from langchain.document_loaders.base import BaseLoader
6 | from pydantic.v1 import BaseModel, Extra, Field
7 | from langchain.schema import Document
8 | from langchain.schema.embeddings import Embeddings
9 | from langchain.schema.language_model import BaseLanguageModel
10 | from langchain.schema.vectorstore import VectorStore
11 | from langchain.text_splitter import RecursiveCharacterTextSplitter, TextSplitter
12 | from langchain.vectorstores.chroma import Chroma
13 | 
14 | 
15 | def _get_default_text_splitter() -> TextSplitter:
16 |     return RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
17 | 
18 | from lmchain.embeddings import embeddings
19 | embedding_tool = embeddings.GLMEmbedding()
20 | 
21 | class VectorstoreIndexCreator(BaseModel):
22 |     """Logic for creating indexes."""
23 | 
24 |     class Config:
25 |         """Configuration for this pydantic object."""
26 |         extra = Extra.forbid
27 |         arbitrary_types_allowed = True
28 | 
29 |     chunk_size = 1280  # characters per chunk
30 |     chunk_overlap = 32  # overlapping characters between chunks
31 |     text_splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
32 | 
33 |     def from_loaders(self, loaders: List[BaseLoader]):
34 |         """Create a vectorstore index from loaders."""
35 |         docs = []
36 |         for loader in loaders:
37 |             docs.extend(loader.load())
38 |         return self.from_documents(docs)
39 | 
40 |     def from_documents(self, documents: List[Document]):
41 |         # builds and returns a retrieval QA chain over the split documents
42 |         sub_docs = self.text_splitter.split_documents(documents)
43 | 
44 |         # texts = [d.page_content for d in sub_docs]
45 |         # metadatas = [d.metadata for d in sub_docs]
46 | 
47 |         qa_chain = ConversationalRetrievalChain(document=sub_docs)
48 |         return qa_chain
49 | 
50 | 
51 | from langchain.docstore.document import Document
52 | from langchain.text_splitter import RecursiveCharacterTextSplitter
53 | from lmchain.embeddings import embeddings
54 | from lmchain.vectorstores import laiss
55 | from lmchain.agents import llmMultiAgent
56 | from langchain.memory import ConversationBufferMemory
57 | from langchain.prompts import (
58 |     ChatPromptTemplate,  # builds chat templates
59 |     MessagesPlaceholder,  # inserts a message placeholder into the template
60 |     SystemMessagePromptTemplate,  # builds system-message templates
61 |     HumanMessagePromptTemplate  # builds human-message templates
62 | )
63 | from langchain.chains import ConversationChain
64 | 
65 | class ConversationalRetrievalChain:
66 |     def __init__(self, document, chunk_size=1280, chunk_overlap=50, file_name="这是一份辅助材料"):
67 |         """
68 |         :param document: the input content, a single text string
69 |         :param chunk_size: number of characters per chunk after splitting
70 |         :param chunk_overlap: number of overlapping characters between adjacent chunks
71 |         :param file_name: document name / path
72 |         """
73 |         self.text_splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
74 |         self.embedding_tool = embedding_tool
75 | 
76 |         self.lmaiss = laiss.LMASS()  # converts texts to vectors and computes query similarity
77 |         self.llm = llmMultiAgent.AgentZhipuAI()
78 |         self.memory = ConversationBufferMemory(return_messages=True)
79 | 
80 |         conversation_prompt = ChatPromptTemplate.from_messages([
81 |             SystemMessagePromptTemplate.from_template("你是一个最强大的人工智能程序,可以知无不答,但是你不懂的东西会直接回答不知道。"),
82 |             MessagesPlaceholder(variable_name="history"),  # placeholder for past messages
83 |             HumanMessagePromptTemplate.from_template("{input}")  # template for the human input
84 |         ])
85 | 
86 |         self.qa_chain = ConversationChain(memory=self.memory, prompt=conversation_prompt, llm=self.llm)
87 |         "---------------------------"
88 |         self.metadatas = []
89 |         for doc in document:
90 |             self.metadatas.append(doc.metadata)
91 |         self.documents = self.text_splitter.split_documents(document)  # split into chunks
92 |         self.vectorstore = self.lmaiss.from_documents(self.documents, embedding_class=self.embedding_tool)
93 | 
94 |     def __call__(self, query):
95 |         query_embedding = self.embedding_tool.embed_query(query)
96 | 
97 |         # find the vector closest to the query
98 |         close_id = self.lmaiss.get_similarity_vector_indexs(query_embedding, self.vectorstore, k=1)[0]
99 |         # look up the corresponding paragraph and its metadata
100 |         doc = self.documents[close_id]
101 |         meta = self.metadatas[close_id]
102 |         # build the final query
103 |         query = f"你现在要回答问题'{query}',你可以参考文献'{doc}',你如果找不到对应的内容,就从自己的记忆体中查找,就回答'请提供更为准确的查询内容'。"
104 |         result = (self.qa_chain.predict(input=query))
105 |         return result, meta
106 | 
107 |     def query(self, input):
108 |         result, meta = self.__call__(input)
109 |         return result
110 | 
111 |     # returns the answer together with its source metadata
112 |     def query_with_sources(self, input):
113 |         result, meta = self.__call__(input)
114 |         return {"answer": result, "sources": meta}
115 | 
--------------------------------------------------------------------------------
/build/lib/lmchain/llms/__init__.py:
--------------------------------------------------------------------------------
1 | name = "llms"
--------------------------------------------------------------------------------
/build/lib/lmchain/llms/base.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/virgo777/lmchain/3fc4a55df07100d4fe0beef4da89e3fbc6288938/build/lib/lmchain/llms/base.py
--------------------------------------------------------------------------------
/build/lib/lmchain/load/__init__.py:
--------------------------------------------------------------------------------
1 | name = "load"
--------------------------------------------------------------------------------
/build/lib/lmchain/memory/__init__.py:
--------------------------------------------------------------------------------
1 | name = "memory"
--------------------------------------------------------------------------------
/build/lib/lmchain/memory/chat_memory.py:
--------------------------------------------------------------------------------
1 | from abc import ABC
2 | from typing import Any, Dict, Optional, Tuple
3 | 
4 | from lmchain.memory.utils import get_prompt_input_key
5 | 
6 | from lmchain.schema.memory import BaseMemory
7 | 
8 | 
9 | class BaseChatMemory(BaseMemory, ABC):
10 |     """Abstract base class for chat memory."""
11 | 
12 |     from lmchain.memory import messageHistory
13 |     chat_memory = messageHistory.ChatMessageHistory()
14 |     output_key: Optional[str] = None
15 |     input_key: Optional[str] = None
16 |     return_messages: bool = False
17 | 
18 |     def _get_input_output(
19 |         self, inputs: Dict[str, Any], outputs: Dict[str, str]
20 |     ) -> Tuple[str, str]:
21 |         if self.input_key is None:
22 |             prompt_input_key = get_prompt_input_key(inputs, self.memory_variables)
23 |         else:
24 |             prompt_input_key = self.input_key
25 |         if self.output_key is None:
26 |             if len(outputs) != 1:
27 |                 raise ValueError(f"One output key expected, got {outputs.keys()}")
28 |             output_key = list(outputs.keys())[0]
29 |         else:
30 |             output_key = self.output_key
31 |         return inputs[prompt_input_key], outputs[output_key]
32 | 
33 |     def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
34 |         """Save context from this conversation to buffer."""
35 |         input_str, output_str = self._get_input_output(inputs, outputs)
36 |         self.chat_memory.add_user_message(input_str)
37 |         self.chat_memory.add_ai_message(output_str)
38 | 
39 |     def clear(self) -> None:
40 |         """Clear memory contents."""
41 |         self.chat_memory.clear()
42 | 
--------------------------------------------------------------------------------
/build/lib/lmchain/memory/messageHistory.py:
--------------------------------------------------------------------------------
1 | from typing import TYPE_CHECKING, Any, Dict, List, Sequence, Union
2 | from typing_extensions import Literal
3 | 
4 | 
5 | class ChatMessageHistory:
6 |     """In memory implementation of chat message history.
7 | 
8 |     Stores messages in an in memory list.
9 |     """
10 | 
11 |     messages = []
12 | 
13 |     def add_message(self, message) -> None:
14 |         """Add a self-created message to the store"""
15 |         self.messages.append(message)
16 | 
17 |     def clear(self) -> None:
18 |         self.messages = []
19 | 
20 |     def __str__(self):
21 |         return ", ".join(str(message) for message in self.messages)
22 | 
23 | 
24 | class ChatMessageHistory(ChatMessageHistory):  # extends the base class above with convenience helpers
25 |     def __init__(self):
26 |         super().__init__()
27 |         self.messages = []  # per-instance list, so histories are not shared between instances
28 | 
29 |     def add_user_message(self, content: str) -> None:
30 |         """Convenience method for adding a human message string to the store.
31 | 
32 |         Args:
33 |             content: The string contents of a human message.
34 |         """
35 |         mes = f"HumanMessage(content={content})"
36 |         self.messages.append(mes)
37 | 
38 |     def add_ai_message(self, content: str) -> None:
39 |         """Convenience method for adding an AI message string to the store.
40 | 
41 |         Args:
42 |             content: The string contents of an AI message.
43 |         """
44 |         mes = f"AIMessage(content={content})"
45 |         self.messages.append(mes)
46 | 
47 | 
48 | from typing import Any, Dict, List, Optional
49 | 
50 | from langchain.memory.chat_memory import BaseChatMemory, BaseMemory
51 | from langchain.memory.utils import get_prompt_input_key
52 | from pydantic.v1 import root_validator
53 | from langchain.schema.messages import BaseMessage, get_buffer_string
54 | 
55 | 
56 | class ConversationBufferMemory(BaseChatMemory):
57 |     """Buffer for storing conversation memory."""
58 | 
59 |     human_prefix: str = "Human"
60 |     ai_prefix: str = "AI"
61 |     memory_key: str = "history"  #: :meta private:
62 | 
63 |     @property
64 |     def buffer(self) -> Any:
65 |         """String buffer of memory."""
66 |         return self.buffer_as_messages if self.return_messages else self.buffer_as_str
67 | 
68 |     @property
69 |     def buffer_as_str(self) -> str:
70 |         """Exposes the buffer as a string in case return_messages is True."""
71 |         return get_buffer_string(
72 |             self.chat_memory.messages,
73 |             human_prefix=self.human_prefix,
74 |             ai_prefix=self.ai_prefix,
75 |         )
76 | 
77 |     @property
78 |     def buffer_as_messages(self) -> List[BaseMessage]:
79 |         """Exposes the buffer as a list of messages in case return_messages is False."""
80 |         return self.chat_memory.messages
81 | 
82 |     @property
83 |     def memory_variables(self) -> List[str]:
84 |         """Will always return list of memory variables.
85 | 
86 |         :meta private:
87 |         """
88 |         return [self.memory_key]
89 | 
90 |     def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
91 |         """Return history buffer."""
92 |         return {self.memory_key: self.buffer}
93 | 
94 | 
95 | class ConversationStringBufferMemory(BaseMemory):
96 |     """Buffer for storing conversation memory."""
97 | 
98 |     human_prefix: str = "Human"
99 |     ai_prefix: str = "AI"
100 |     """Prefix to use for AI generated responses."""
101 |     buffer: str = ""
102 |     output_key: Optional[str] = None
103 |     input_key: Optional[str] = None
104 |     memory_key: str = "history"  #: :meta private:
105 | 
106 |     @root_validator()
107 |     def validate_chains(cls, values: Dict) -> Dict:
108 |         """Validate that return messages is not True."""
109 |         if values.get("return_messages", False):
110 |             raise ValueError(
111 |                 "return_messages must be False for ConversationStringBufferMemory"
112 |             )
113 |         return values
114 | 
115 |     @property
116 |     def memory_variables(self) -> List[str]:
117 |         """Will always return list of memory variables.
118 |         :meta private:
119 |         """
120 |         return [self.memory_key]
121 | 
122 |     def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, str]:
123 |         """Return history buffer."""
124 |         return {self.memory_key: self.buffer}
125 | 
126 |     def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
127 |         """Save context from this conversation to buffer."""
128 |         if self.input_key is None:
129 |             prompt_input_key = get_prompt_input_key(inputs, self.memory_variables)
130 |         else:
131 |             prompt_input_key = self.input_key
132 |         if self.output_key is None:
133 |             if len(outputs) != 1:
134 |                 raise ValueError(f"One output key expected, got {outputs.keys()}")
135 |             output_key = list(outputs.keys())[0]
136 |         else:
137 |             output_key = self.output_key
138 |         human = f"{self.human_prefix}: " + inputs[prompt_input_key]
139 |         ai = f"{self.ai_prefix}: " + outputs[output_key]
140 |         self.buffer += "\n" + "\n".join([human, ai])
141 | 
142 |     def clear(self) -> None:
143 |         """Clear memory contents."""
144 |         self.buffer = ""
145 | 
--------------------------------------------------------------------------------
/build/lib/lmchain/memory/utils.py:
--------------------------------------------------------------------------------
1 | from typing import Any, Dict, List
2 |
3 |
4 | def get_prompt_input_key(inputs: Dict[str, Any], memory_variables: List[str]) -> str:
5 | """
6 | Get the prompt input key.
7 |
8 | Args:
9 | inputs: Dict[str, Any]
10 | memory_variables: List[str]
11 |
12 | Returns:
13 | A prompt input key.
14 | """
15 | # "stop" is a special key that can be passed as input but is not used to
16 | # format the prompt.
17 | prompt_input_keys = list(set(inputs).difference(memory_variables + ["stop"]))
18 | if len(prompt_input_keys) != 1:
19 |         raise ValueError(f"One input key expected, got {prompt_input_keys}")
20 | return prompt_input_keys[0]
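A quick sketch of how get_prompt_input_key resolves the prompt key: everything except the memory variables and the special "stop" key must leave exactly one candidate. The dict below is illustrative:

inputs = {"question": "What is LMchain?", "history": "...", "stop": ["\n"]}
print(get_prompt_input_key(inputs, memory_variables=["history"]))  # -> "question"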
--------------------------------------------------------------------------------
/build/lib/lmchain/model/__init__.py:
--------------------------------------------------------------------------------
1 | name = "model"
--------------------------------------------------------------------------------
/build/lib/lmchain/prompts/__init__.py:
--------------------------------------------------------------------------------
1 | name = "prompts"
--------------------------------------------------------------------------------
/build/lib/lmchain/prompts/base.py:
--------------------------------------------------------------------------------
1 | """BasePrompt schema definition."""
2 | from __future__ import annotations
3 |
4 | import warnings
5 | from abc import ABC
6 | from string import Formatter
7 | from typing import Any, Callable, Dict, List, Literal, Set
8 |
9 | from lmchain.schema.messages import BaseMessage, HumanMessage
10 | from lmchain.schema.prompt import PromptValue
11 | from lmchain.schema.prompt_template import BasePromptTemplate
12 | #from langchain.schema.prompt_template import BasePromptTemplate
13 | from lmchain.utils.formatting import formatter
14 |
15 |
16 | def jinja2_formatter(template: str, **kwargs: Any) -> str:
17 | """Format a template using jinja2.
18 |
19 | *Security warning*: As of LangChain 0.0.329, this method uses Jinja2's
20 | SandboxedEnvironment by default. However, this sand-boxing should
21 | be treated as a best-effort approach rather than a guarantee of security.
22 | Do not accept jinja2 templates from untrusted sources as they may lead
23 | to arbitrary Python code execution.
24 |
25 | https://jinja.palletsprojects.com/en/3.1.x/sandbox/
26 | """
27 | try:
28 | from jinja2.sandbox import SandboxedEnvironment
29 | except ImportError:
30 | raise ImportError(
31 | "jinja2 not installed, which is needed to use the jinja2_formatter. "
32 |             "Please install it with `pip install jinja2`. "
33 | "Please be cautious when using jinja2 templates. "
34 | "Do not expand jinja2 templates using unverified or user-controlled "
35 | "inputs as that can result in arbitrary Python code execution."
36 | )
37 |
38 | # This uses a sandboxed environment to prevent arbitrary code execution.
39 | # Jinja2 uses an opt-out rather than opt-in approach for sand-boxing.
40 | # Please treat this sand-boxing as a best-effort approach rather than
41 | # a guarantee of security.
42 | # We recommend never using jinja2 templates with untrusted inputs.
43 | # https://jinja.palletsprojects.com/en/3.1.x/sandbox/
44 |
45 | return SandboxedEnvironment().from_string(template).render(**kwargs)
46 |
47 |
48 | def validate_jinja2(template: str, input_variables: List[str]) -> None:
49 | """
50 | Validate that the input variables are valid for the template.
51 | Issues a warning if missing or extra variables are found.
52 |
53 | Args:
54 | template: The template string.
55 | input_variables: The input variables.
56 | """
57 | input_variables_set = set(input_variables)
58 | valid_variables = _get_jinja2_variables_from_template(template)
59 | missing_variables = valid_variables - input_variables_set
60 | extra_variables = input_variables_set - valid_variables
61 |
62 | warning_message = ""
63 | if missing_variables:
64 | warning_message += f"Missing variables: {missing_variables} "
65 |
66 | if extra_variables:
67 | warning_message += f"Extra variables: {extra_variables}"
68 |
69 | if warning_message:
70 | warnings.warn(warning_message.strip())
71 |
72 |
73 | def _get_jinja2_variables_from_template(template: str) -> Set[str]:
74 | try:
75 | from jinja2 import Environment, meta
76 | except ImportError:
77 | raise ImportError(
78 | "jinja2 not installed, which is needed to use the jinja2_formatter. "
79 | "Please install it with `pip install jinja2`."
80 | )
81 | env = Environment()
82 | ast = env.parse(template)
83 | variables = meta.find_undeclared_variables(ast)
84 | return variables
85 |
86 |
87 | DEFAULT_FORMATTER_MAPPING: Dict[str, Callable] = {
88 | "f-string": formatter.format,
89 | "jinja2": jinja2_formatter,
90 | }
91 |
92 | DEFAULT_VALIDATOR_MAPPING: Dict[str, Callable] = {
93 | "f-string": formatter.validate_input_variables,
94 | "jinja2": validate_jinja2,
95 | }
96 |
97 |
98 | def check_valid_template(
99 | template: str, template_format: str, input_variables: List[str]
100 | ) -> None:
101 | """Check that template string is valid.
102 |
103 | Args:
104 | template: The template string.
105 | template_format: The template format. Should be one of "f-string" or "jinja2".
106 | input_variables: The input variables.
107 |
108 | Raises:
109 | ValueError: If the template format is not supported.
110 | """
111 | if template_format not in DEFAULT_FORMATTER_MAPPING:
112 | valid_formats = list(DEFAULT_FORMATTER_MAPPING)
113 | raise ValueError(
114 | f"Invalid template format. Got `{template_format}`;"
115 | f" should be one of {valid_formats}"
116 | )
117 | try:
118 | validator_func = DEFAULT_VALIDATOR_MAPPING[template_format]
119 | validator_func(template, input_variables)
120 | except KeyError as e:
121 | raise ValueError(
122 | "Invalid prompt schema; check for mismatched or missing input parameters. "
123 | + str(e)
124 | )
125 |
126 |
127 | def get_template_variables(template: str, template_format: str) -> List[str]:
128 | """Get the variables from the template.
129 |
130 | Args:
131 | template: The template string.
132 | template_format: The template format. Should be one of "f-string" or "jinja2".
133 |
134 | Returns:
135 | The variables from the template.
136 |
137 | Raises:
138 | ValueError: If the template format is not supported.
139 | """
140 | if template_format == "jinja2":
141 | # Get the variables for the template
142 | input_variables = _get_jinja2_variables_from_template(template)
143 | elif template_format == "f-string":
144 | input_variables = {
145 | v for _, v, _, _ in Formatter().parse(template) if v is not None
146 | }
147 | else:
148 | raise ValueError(f"Unsupported template format: {template_format}")
149 |
150 | return sorted(input_variables)
151 |
152 |
153 | class StringPromptValue(PromptValue):
154 | """String prompt value."""
155 |
156 | text: str
157 | """Prompt text."""
158 | type: Literal["StringPromptValue"] = "StringPromptValue"
159 |
160 | def to_string(self) -> str:
161 | """Return prompt as string."""
162 | return self.text
163 |
164 | def to_messages(self) -> List[BaseMessage]:
165 | """Return prompt as messages."""
166 | return [HumanMessage(content=self.text)]
167 |
168 |
169 | class StringPromptTemplate(BasePromptTemplate, ABC):
170 | """String prompt that exposes the format method, returning a prompt."""
171 |
172 | def format_prompt(self, **kwargs: Any) -> PromptValue:
173 | """Create Chat Messages."""
174 | return StringPromptValue(text=self.format(**kwargs))
175 |
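A small sketch of the validation helpers above, using the f-string path (no jinja2 required); the template text is illustrative:

print(get_template_variables("Hello {name}, you are {age}.", "f-string"))
# ['age', 'name']
check_valid_template("Hello {name}!", "f-string", ["name"])  # passes silently

value = StringPromptValue(text="Hello!")
print(value.to_string())    # Hello!
print(value.to_messages())  # [HumanMessage(content='Hello!')]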
--------------------------------------------------------------------------------
/build/lib/lmchain/schema/__init__.py:
--------------------------------------------------------------------------------
1 | name = "schema"
--------------------------------------------------------------------------------
/build/lib/lmchain/schema/agent.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | from typing import Any, Literal, Sequence, Union
4 |
5 | from lmchain.load.serializable import Serializable
6 | from lmchain.schema.messages import BaseMessage
7 |
8 |
9 | class AgentAction(Serializable):
10 | """A full description of an action for an ActionAgent to execute."""
11 |
12 | tool: str
13 | """The name of the Tool to execute."""
14 | tool_input: Union[str, dict]
15 | """The input to pass in to the Tool."""
16 | log: str
17 | """Additional information to log about the action.
18 | This log can be used in a few ways. First, it can be used to audit
19 | what exactly the LLM predicted to lead to this (tool, tool_input).
20 | Second, it can be used in future iterations to show the LLMs prior
21 | thoughts. This is useful when (tool, tool_input) does not contain
22 | full information about the LLM prediction (for example, any `thought`
23 | before the tool/tool_input)."""
24 | type: Literal["AgentAction"] = "AgentAction"
25 |
26 | def __init__(
27 | self, tool: str, tool_input: Union[str, dict], log: str, **kwargs: Any
28 | ):
29 | """Override init to support instantiation by position for backward compat."""
30 | super().__init__(tool=tool, tool_input=tool_input, log=log, **kwargs)
31 |
32 | @classmethod
33 | def is_lc_serializable(cls) -> bool:
34 | """Return whether or not the class is serializable."""
35 | return True
36 |
37 |
38 | class AgentActionMessageLog(AgentAction):
39 | message_log: Sequence[BaseMessage]
40 | """Similar to log, this can be used to pass along extra
41 | information about what exact messages were predicted by the LLM
42 | before parsing out the (tool, tool_input). This is again useful
43 | if (tool, tool_input) cannot be used to fully recreate the LLM
44 | prediction, and you need that LLM prediction (for future agent iteration).
45 | Compared to `log`, this is useful when the underlying LLM is a
46 | ChatModel (and therefore returns messages rather than a string)."""
47 | # Ignoring type because we're overriding the type from AgentAction.
48 | # And this is the correct thing to do in this case.
49 | # The type literal is used for serialization purposes.
50 | type: Literal["AgentActionMessageLog"] = "AgentActionMessageLog" # type: ignore
51 |
52 |
53 | class AgentFinish(Serializable):
54 | """The final return value of an ActionAgent."""
55 |
56 | return_values: dict
57 | """Dictionary of return values."""
58 | log: str
59 | """Additional information to log about the return value.
60 | This is used to pass along the full LLM prediction, not just the parsed out
61 | return value. For example, if the full LLM prediction was
62 | `Final Answer: 2` you may want to just return `2` as a return value, but pass
63 | along the full string as a `log` (for debugging or observability purposes).
64 | """
65 | type: Literal["AgentFinish"] = "AgentFinish"
66 |
67 | def __init__(self, return_values: dict, log: str, **kwargs: Any):
68 | """Override init to support instantiation by position for backward compat."""
69 | super().__init__(return_values=return_values, log=log, **kwargs)
70 |
71 | @classmethod
72 | def is_lc_serializable(cls) -> bool:
73 | """Return whether or not the class is serializable."""
74 | return True
75 |
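A sketch of constructing the two schema objects above; positional instantiation works because of the overridden __init__s. The values are illustrative:

action = AgentAction("get_weather", {"city_name": "shanghai"}, log="calling the weather tool")
print(action.tool, action.tool_input)

finish = AgentFinish({"output": "2"}, log="Final Answer: 2")
print(finish.return_values["output"])  # 2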
--------------------------------------------------------------------------------
/build/lib/lmchain/schema/document.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | import asyncio
4 | from abc import ABC, abstractmethod
5 | from functools import partial
6 | from typing import Any, Literal, Sequence
7 |
8 | from lmchain.load.serializable import Serializable
9 | from pydantic.v1 import Field
10 |
11 | class Document(Serializable):
12 | """Class for storing a piece of text and associated metadata."""
13 |
14 | page_content: str
15 | """String text."""
16 | metadata: dict = Field(default_factory=dict)
17 | """Arbitrary metadata about the page content (e.g., source, relationships to other
18 | documents, etc.).
19 | """
20 | type: Literal["Document"] = "Document"
21 |
22 | @classmethod
23 | def is_lc_serializable(cls) -> bool:
24 | """Return whether this class is serializable."""
25 | return True
26 |
27 |
28 | class BaseDocumentTransformer(ABC):
29 | """Abstract base class for document transformation systems.
30 |
31 | A document transformation system takes a sequence of Documents and returns a
32 | sequence of transformed Documents.
33 |
34 | Example:
35 | .. code-block:: python
36 |
37 | class EmbeddingsRedundantFilter(BaseDocumentTransformer, BaseModel):
38 | embeddings: Embeddings
39 | similarity_fn: Callable = cosine_similarity
40 | similarity_threshold: float = 0.95
41 |
42 | class Config:
43 | arbitrary_types_allowed = True
44 |
45 | def transform_documents(
46 | self, documents: Sequence[Document], **kwargs: Any
47 | ) -> Sequence[Document]:
48 | stateful_documents = get_stateful_documents(documents)
49 | embedded_documents = _get_embeddings_from_stateful_docs(
50 | self.embeddings, stateful_documents
51 | )
52 | included_idxs = _filter_similar_embeddings(
53 | embedded_documents, self.similarity_fn, self.similarity_threshold
54 | )
55 | return [stateful_documents[i] for i in sorted(included_idxs)]
56 |
57 | async def atransform_documents(
58 | self, documents: Sequence[Document], **kwargs: Any
59 | ) -> Sequence[Document]:
60 | raise NotImplementedError
61 |
62 | """ # noqa: E501
63 |
64 | @abstractmethod
65 | def transform_documents(
66 | self, documents: Sequence[Document], **kwargs: Any
67 | ) -> Sequence[Document]:
68 | """Transform a list of documents.
69 |
70 | Args:
71 | documents: A sequence of Documents to be transformed.
72 |
73 | Returns:
74 | A list of transformed Documents.
75 | """
76 |
77 | async def atransform_documents(
78 | self, documents: Sequence[Document], **kwargs: Any
79 | ) -> Sequence[Document]:
80 | """Asynchronously transform a list of documents.
81 |
82 | Args:
83 | documents: A sequence of Documents to be transformed.
84 |
85 | Returns:
86 | A list of transformed Documents.
87 | """
88 | return await asyncio.get_running_loop().run_in_executor(
89 | None, partial(self.transform_documents, **kwargs), documents
90 | )
91 |
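A minimal concrete transformer, sketched only to show the BaseDocumentTransformer contract; the class name and behavior are illustrative:

from typing import Any, Sequence

class UppercaseTransformer(BaseDocumentTransformer):
    def transform_documents(
        self, documents: Sequence[Document], **kwargs: Any
    ) -> Sequence[Document]:
        # return new Documents with upper-cased text, preserving metadata
        return [Document(page_content=d.page_content.upper(), metadata=d.metadata) for d in documents]

docs = [Document(page_content="hello", metadata={"source": "demo"})]
print(UppercaseTransformer().transform_documents(docs)[0].page_content)  # HELLO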
--------------------------------------------------------------------------------
/build/lib/lmchain/schema/memory.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | from abc import ABC, abstractmethod
4 | from typing import Any, Dict, List
5 |
6 |
7 | class BaseMemory(ABC):
8 | """Abstract base class for memory in Chains.
9 |
10 | Memory refers to state in Chains. Memory can be used to store information about
11 | past executions of a Chain and inject that information into the inputs of
12 | future executions of the Chain. For example, for conversational Chains Memory
13 | can be used to store conversations and automatically add them to future model
14 | prompts so that the model has the necessary context to respond coherently to
15 | the latest input.
16 |
17 | Example:
18 | .. code-block:: python
19 |
20 | class SimpleMemory(BaseMemory):
21 | memories: Dict[str, Any] = dict()
22 |
23 | @property
24 | def memory_variables(self) -> List[str]:
25 | return list(self.memories.keys())
26 |
27 | def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, str]:
28 | return self.memories
29 |
30 | def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
31 | pass
32 |
33 | def clear(self) -> None:
34 | pass
35 | """ # noqa: E501
36 |
37 | class Config:
38 | """Configuration for this pydantic object."""
39 |
40 | arbitrary_types_allowed = True
41 |
42 | @property
43 | @abstractmethod
44 | def memory_variables(self) -> List[str]:
45 | """The string keys this memory class will add to chain inputs."""
46 |
47 | @abstractmethod
48 | def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
49 | """Return key-value pairs given the text input to the chain."""
50 |
51 | @abstractmethod
52 | def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
53 | """Save the context of this chain run to memory."""
54 |
55 | @abstractmethod
56 | def clear(self) -> None:
57 | """Clear memory contents."""
58 |
--------------------------------------------------------------------------------
/build/lib/lmchain/schema/prompt.py:
--------------------------------------------------------------------------------
1 | # This code defines an abstract base class named PromptValue, which represents the input to any language model.
2 | # It inherits from Serializable and ABC (Abstract Base Class), making it a serializable abstract base class.
3 |
4 |
5 | # Import the annotations feature from the __future__ module to enable postponed evaluation of type annotations.
6 | from __future__ import annotations
7 |
8 | # Import the ABC (abstract base class) and abstractmethod decorator from the abc module.
9 | from abc import ABC, abstractmethod
10 | # Import the List type from the typing module for type annotations.
11 | from typing import List
12 |
13 | # Import the Serializable class from lmchain.load.serializable, used for serializing and deserializing objects.
14 | from lmchain.load.serializable import Serializable
15 | # Import the BaseMessage class from lmchain.schema.messages as the base class for messages.
16 | from lmchain.schema.messages import BaseMessage
17 |
18 |
19 | # Define an abstract base class named PromptValue that inherits from Serializable and ABC.
20 | class PromptValue(Serializable, ABC):
21 | """Base abstract class for inputs to any language model.
22 |
23 | PromptValues can be converted to both LLM (pure text-generation) inputs and
24 | ChatModel inputs.
25 | """
26 |
27 |     # Class method that returns a bool indicating whether this class is serializable. Always True here.
28 | @classmethod
29 | def is_lc_serializable(cls) -> bool:
30 | """Return whether this class is serializable."""
31 | return True
32 |
33 |     # Abstract method to be implemented by subclasses. Returns the prompt value as a string.
34 | @abstractmethod
35 | def to_string(self) -> str:
36 | """Return prompt value as string."""
37 |
38 |     # Abstract method to be implemented by subclasses. Returns the prompt as a list of BaseMessage objects.
39 | @abstractmethod
40 | def to_messages(self) -> List[BaseMessage]:
41 | """Return prompt as a list of Messages."""
42 |
--------------------------------------------------------------------------------
/build/lib/lmchain/schema/runnable/__init__.py:
--------------------------------------------------------------------------------
1 | name = "schema.runnable"
--------------------------------------------------------------------------------
/build/lib/lmchain/schema/runnable/config.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/virgo777/lmchain/3fc4a55df07100d4fe0beef4da89e3fbc6288938/build/lib/lmchain/schema/runnable/config.py
--------------------------------------------------------------------------------
/build/lib/lmchain/schema/schema.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | from typing import Any, Literal, Sequence, Union
4 |
5 | from lmchain.load.serializable import Serializable
6 | from lmchain.schema.messages import BaseMessage
7 |
8 |
9 | class AgentAction(Serializable):
10 | """A full description of an action for an ActionAgent to execute."""
11 |
12 | tool: str
13 | """The name of the Tool to execute."""
14 | tool_input: Union[str, dict]
15 | """The input to pass in to the Tool."""
16 | log: str
17 | """Additional information to log about the action.
18 | This log can be used in a few ways. First, it can be used to audit
19 | what exactly the LLM predicted to lead to this (tool, tool_input).
20 | Second, it can be used in future iterations to show the LLMs prior
21 | thoughts. This is useful when (tool, tool_input) does not contain
22 | full information about the LLM prediction (for example, any `thought`
23 | before the tool/tool_input)."""
24 | type: Literal["AgentAction"] = "AgentAction"
25 |
26 | def __init__(
27 | self, tool: str, tool_input: Union[str, dict], log: str, **kwargs: Any
28 | ):
29 | """Override init to support instantiation by position for backward compat."""
30 | super().__init__(tool=tool, tool_input=tool_input, log=log, **kwargs)
31 |
32 | @classmethod
33 | def is_lc_serializable(cls) -> bool:
34 | """Return whether or not the class is serializable."""
35 | return True
36 |
37 |
38 | class AgentActionMessageLog(AgentAction):
39 | message_log: Sequence[BaseMessage]
40 | """Similar to log, this can be used to pass along extra
41 | information about what exact messages were predicted by the LLM
42 | before parsing out the (tool, tool_input). This is again useful
43 | if (tool, tool_input) cannot be used to fully recreate the LLM
44 | prediction, and you need that LLM prediction (for future agent iteration).
45 | Compared to `log`, this is useful when the underlying LLM is a
46 | ChatModel (and therefore returns messages rather than a string)."""
47 | # Ignoring type because we're overriding the type from AgentAction.
48 | # And this is the correct thing to do in this case.
49 | # The type literal is used for serialization purposes.
50 | type: Literal["AgentActionMessageLog"] = "AgentActionMessageLog" # type: ignore
51 |
52 |
53 | class AgentFinish(Serializable):
54 | """The final return value of an ActionAgent."""
55 |
56 | return_values: dict
57 | """Dictionary of return values."""
58 | log: str
59 | """Additional information to log about the return value.
60 | This is used to pass along the full LLM prediction, not just the parsed out
61 | return value. For example, if the full LLM prediction was
62 | `Final Answer: 2` you may want to just return `2` as a return value, but pass
63 | along the full string as a `log` (for debugging or observability purposes).
64 | """
65 | type: Literal["AgentFinish"] = "AgentFinish"
66 |
67 | def __init__(self, return_values: dict, log: str, **kwargs: Any):
68 | """Override init to support instantiation by position for backward compat."""
69 | super().__init__(return_values=return_values, log=log, **kwargs)
70 |
71 | @classmethod
72 | def is_lc_serializable(cls) -> bool:
73 | """Return whether or not the class is serializable."""
74 | return True
75 |
--------------------------------------------------------------------------------
/build/lib/lmchain/tool_register.py:
--------------------------------------------------------------------------------
1 | import inspect
2 | import traceback
3 | from copy import deepcopy
4 | from pprint import pformat
5 | from types import GenericAlias
6 | from typing import get_origin, Annotated
7 |
8 | _TOOL_HOOKS = {}
9 | _TOOL_DESCRIPTIONS = {}
10 |
11 |
12 | def register_tool(func: callable):
13 | tool_name = func.__name__
14 | tool_description = inspect.getdoc(func).strip()
15 | python_params = inspect.signature(func).parameters
16 | tool_params = []
17 | for name, param in python_params.items():
18 | annotation = param.annotation
19 | if annotation is inspect.Parameter.empty:
20 | raise TypeError(f"Parameter `{name}` missing type annotation")
21 | if get_origin(annotation) != Annotated:
22 | raise TypeError(f"Annotation type for `{name}` must be typing.Annotated")
23 |
24 | typ, (description, required) = annotation.__origin__, annotation.__metadata__
25 | typ: str = str(typ) if isinstance(typ, GenericAlias) else typ.__name__
26 | if not isinstance(description, str):
27 | raise TypeError(f"Description for `{name}` must be a string")
28 | if not isinstance(required, bool):
29 | raise TypeError(f"Required for `{name}` must be a bool")
30 |
31 | tool_params.append({
32 | "name": name,
33 | "description": description,
34 | "type": typ,
35 | "required": required
36 | })
37 | tool_def = {
38 | "name": tool_name,
39 | "description": tool_description,
40 | "params": tool_params
41 | }
42 |
43 | # print("[registered tool] " + pformat(tool_def))
44 | _TOOL_HOOKS[tool_name] = func
45 | _TOOL_DESCRIPTIONS[tool_name] = tool_def
46 |
47 | return func
48 |
49 |
50 | def dispatch_tool(tool_name: str, tool_params: dict) -> str:
51 | if tool_name not in _TOOL_HOOKS:
52 | return f"Tool `{tool_name}` not found. Please use a provided tool."
53 | tool_call = _TOOL_HOOKS[tool_name]
54 | try:
55 | ret = tool_call(**tool_params)
56 | except:
57 | ret = traceback.format_exc()
58 | return str(ret)
59 |
60 |
61 | def get_tools() -> dict:
62 | return deepcopy(_TOOL_DESCRIPTIONS)
63 |
64 |
65 | # Tool Definitions
66 |
67 | # @register_tool
68 | # def random_number_generator(
69 | # seed: Annotated[int, 'The random seed used by the generator', True],
70 | # range: Annotated[tuple[int, int], 'The range of the generated numbers', True],
71 | # ) -> int:
72 | # """
73 | # Generates a random number x, s.t. range[0] <= x < range[1]
74 | # """
75 | # if not isinstance(seed, int):
76 | # raise TypeError("Seed must be an integer")
77 | # if not isinstance(range, tuple):
78 | # raise TypeError("Range must be a tuple")
79 | # if not isinstance(range[0], int) or not isinstance(range[1], int):
80 | # raise TypeError("Range must be a tuple of integers")
81 | #
82 | # import random
83 | # return random.Random(seed).randint(*range)
84 | #
85 | #
86 | # @register_tool
87 | # def get_weather(
88 | # city_name: Annotated[str, 'The name of the city to be queried', True],
89 | # ) -> str:
90 | # """
91 | # Get the current weather for `city_name`
92 | # """
93 | #
94 | # if not isinstance(city_name, str):
95 | # raise TypeError("City name must be a string")
96 | #
97 | # key_selection = {
98 | # "current_condition": ["temp_C", "FeelsLikeC", "humidity", "weatherDesc", "observation_time"],
99 | # }
100 | # import requests
101 | # try:
102 | # resp = requests.get(f"https://wttr.in/{city_name}?format=j1")
103 | # resp.raise_for_status()
104 | # resp = resp.json()
105 | # ret = {k: {_v: resp[k][0][_v] for _v in v} for k, v in key_selection.items()}
106 | # except:
107 | # import traceback
108 | # ret = "Error encountered while fetching weather data!\n" + traceback.format_exc()
109 | #
110 | # return str(ret)
111 | #
112 | #
113 | # @register_tool
114 | # def get_customer_weather(location: Annotated[str, "name of the location to query, given as a Chinese place name", True] = ""):
115 | #     """A self-written weather lookup function."""
116 | #
117 | #     if location == "上海":
118 | #         return 23.0
119 | #     elif location == "南京":
120 | #         return 25.0
121 | #     else:
122 | #         return "未查询相关内容"  # "no matching result found"
123 | #
124 | #
125 | # @register_tool
126 | # def get_random_fun(location: Annotated[str, "a random parameter", True] = ""):
127 | #     """A decoy function written to confuse the model."""
128 | #     location = location
129 | #     return "你上当啦"  # "you've been tricked"
130 |
131 |
132 | if __name__ == "__main__":
133 |     print(dispatch_tool("get_weather", {"city_name": "shanghai"}))  # get_weather above is commented out, so this prints the "not found" message
134 | print(get_tools())
135 |
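A registration/dispatch sketch mirroring the commented-out examples above; add_numbers is a hypothetical tool, not part of the module:

from typing import Annotated

@register_tool
def add_numbers(
    a: Annotated[int, 'first addend', True],
    b: Annotated[int, 'second addend', True],
) -> int:
    """Add two integers."""
    return a + b

print(dispatch_tool("add_numbers", {"a": 1, "b": 2}))  # "3"
print(get_tools()["add_numbers"]["params"])            # the registered parameter schema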
--------------------------------------------------------------------------------
/build/lib/lmchain/tools/__init__.py:
--------------------------------------------------------------------------------
1 | name = "tools"
--------------------------------------------------------------------------------
/build/lib/lmchain/tools/tool_register.py:
--------------------------------------------------------------------------------
1 | import inspect
2 | import traceback
3 | from copy import deepcopy
4 | from pprint import pformat
5 | from types import GenericAlias
6 | from typing import get_origin, Annotated
7 |
8 | _TOOL_HOOKS = {}
9 | _TOOL_DESCRIPTIONS = {}
10 |
11 |
12 | def register_tool(func: callable):
13 | tool_name = func.__name__
14 | tool_description = inspect.getdoc(func).strip()
15 | python_params = inspect.signature(func).parameters
16 | tool_params = []
17 | for name, param in python_params.items():
18 | annotation = param.annotation
19 | if annotation is inspect.Parameter.empty:
20 | raise TypeError(f"Parameter `{name}` missing type annotation")
21 | if get_origin(annotation) != Annotated:
22 | raise TypeError(f"Annotation type for `{name}` must be typing.Annotated")
23 |
24 | typ, (description, required) = annotation.__origin__, annotation.__metadata__
25 | typ: str = str(typ) if isinstance(typ, GenericAlias) else typ.__name__
26 | if not isinstance(description, str):
27 | raise TypeError(f"Description for `{name}` must be a string")
28 | if not isinstance(required, bool):
29 | raise TypeError(f"Required for `{name}` must be a bool")
30 |
31 | tool_params.append({
32 | "name": name,
33 | "description": description,
34 | "type": typ,
35 | "required": required
36 | })
37 | tool_def = {
38 | "name": tool_name,
39 | "description": tool_description,
40 | "params": tool_params
41 | }
42 |
43 | # print("[registered tool] " + pformat(tool_def))
44 | _TOOL_HOOKS[tool_name] = func
45 | _TOOL_DESCRIPTIONS[tool_name] = tool_def
46 |
47 | return func
48 |
49 |
50 | def dispatch_tool(tool_name: str, tool_params: dict) -> str:
51 | if tool_name not in _TOOL_HOOKS:
52 | return f"Tool `{tool_name}` not found. Please use a provided tool."
53 | tool_call = _TOOL_HOOKS[tool_name]
54 | try:
55 | ret = tool_call(**tool_params)
56 | except:
57 | ret = traceback.format_exc()
58 | return str(ret)
59 |
60 |
61 | def get_tools() -> dict:
62 | return deepcopy(_TOOL_DESCRIPTIONS)
63 |
64 |
65 | # Tool Definitions
66 |
67 | @register_tool
68 | def random_number_generator(
69 | seed: Annotated[int, 'The random seed used by the generator', True],
70 | range: Annotated[tuple[int, int], 'The range of the generated numbers', True],
71 | ) -> int:
72 | """
73 | Generates a random number x, s.t. range[0] <= x < range[1]
74 | """
75 | if not isinstance(seed, int):
76 | raise TypeError("Seed must be an integer")
77 | if not isinstance(range, tuple):
78 | raise TypeError("Range must be a tuple")
79 | if not isinstance(range[0], int) or not isinstance(range[1], int):
80 | raise TypeError("Range must be a tuple of integers")
81 |
82 | import random
83 | return random.Random(seed).randint(*range)
84 |
85 |
86 | @register_tool
87 | def get_weather(
88 | city_name: Annotated[str, 'The name of the city to be queried', True],
89 | ) -> str:
90 | """
91 | Get the current weather for `city_name`
92 | """
93 |
94 | if not isinstance(city_name, str):
95 | raise TypeError("City name must be a string")
96 |
97 | key_selection = {
98 | "current_condition": ["temp_C", "FeelsLikeC", "humidity", "weatherDesc", "observation_time"],
99 | }
100 | import requests
101 | try:
102 | resp = requests.get(f"https://wttr.in/{city_name}?format=j1")
103 | resp.raise_for_status()
104 | resp = resp.json()
105 | ret = {k: {_v: resp[k][0][_v] for _v in v} for k, v in key_selection.items()}
106 | except:
107 | import traceback
108 | ret = "Error encountered while fetching weather data!\n" + traceback.format_exc()
109 |
110 | return str(ret)
111 |
112 |
113 | if __name__ == "__main__":
114 | # print(dispatch_tool("get_weather", {"city_name": "beijing"}))
115 | tools = (get_tools())
116 | import zhipuai as zhipuai
117 |
118 |     zhipuai.api_key = "1f565e40af1198e11ff1fd8a5b42771d.SjNfezc40YFsz2KC"  # APIKey obtained from the console
119 |
120 | query = "今天shanghai的天气是什么?"
121 | prompt = f"""
122 | 你现在是一个专业的人工智能助手,你现在的需求是{query}。而你需要借助于工具在{tools}中找到对应的函数,用json格式返回对应的函数名和需要的参数。
123 |
124 | 只返回json格式的函数名和需要的参数,不要做描述。
125 |
126 | 如果没有找到合适的函数,则返回:'未找到合适参数,请提供更详细的描述。'
127 | """
128 |
129 | from lmchain.agents import llmMultiAgent
130 |
131 | llm = llmMultiAgent.AgentZhipuAI()
132 | res = llm(prompt)
133 | print(res)
134 |
135 | import json
136 |
137 |     res_dict = json.loads(res)       # first decode: the reply itself is a JSON-encoded string
138 |     res_dict = json.loads(res_dict)  # second decode: that string contains the actual JSON object
139 |
140 | print(dispatch_tool(tool_name=res_dict["function_name"], tool_params=res_dict["params"]))
141 |
142 |
143 |
144 |
--------------------------------------------------------------------------------
/build/lib/lmchain/utils/__init__.py:
--------------------------------------------------------------------------------
1 | name = "utils"
--------------------------------------------------------------------------------
/build/lib/lmchain/utils/formatting.py:
--------------------------------------------------------------------------------
1 | """Utilities for formatting strings."""
2 | from string import Formatter
3 | from typing import Any, List, Mapping, Sequence, Union
4 |
5 |
6 | class StrictFormatter(Formatter):
7 | """A subclass of formatter that checks for extra keys."""
8 |
9 | def check_unused_args(
10 | self,
11 | used_args: Sequence[Union[int, str]],
12 | args: Sequence,
13 | kwargs: Mapping[str, Any],
14 | ) -> None:
15 | """Check to see if extra parameters are passed."""
16 | extra = set(kwargs).difference(used_args)
17 | if extra:
18 | raise KeyError(extra)
19 |
20 | def vformat(
21 | self, format_string: str, args: Sequence, kwargs: Mapping[str, Any]
22 | ) -> str:
23 | """Check that no arguments are provided."""
24 | if len(args) > 0:
25 | raise ValueError(
26 | "No arguments should be provided, "
27 | "everything should be passed as keyword arguments."
28 | )
29 | return super().vformat(format_string, args, kwargs)
30 |
31 | def validate_input_variables(
32 | self, format_string: str, input_variables: List[str]
33 | ) -> None:
34 | dummy_inputs = {input_variable: "foo" for input_variable in input_variables}
35 | super().format(format_string, **dummy_inputs)
36 |
37 |
38 | formatter = StrictFormatter()
39 |
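A short sketch of the StrictFormatter behavior: keyword-only formatting, with extra keys rejected. Inputs are illustrative:

print(formatter.format("Hello {name}!", name="Ada"))            # Hello Ada!
formatter.validate_input_variables("Hello {name}!", ["name"])   # passes silently

try:
    formatter.format("Hello {name}!", name="Ada", extra="oops")
except KeyError as e:
    print("unused key rejected:", e)                            # unused key rejected: {'extra'}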
--------------------------------------------------------------------------------
/build/lib/lmchain/utils/input.py:
--------------------------------------------------------------------------------
1 | """Handle chained inputs."""
2 | from typing import Dict, List, Optional, TextIO
3 |
4 | _TEXT_COLOR_MAPPING = {
5 | "blue": "36;1",
6 | "yellow": "33;1",
7 | "pink": "38;5;200",
8 | "green": "32;1",
9 | "red": "31;1",
10 | }
11 |
12 |
13 | def get_color_mapping(
14 | items: List[str], excluded_colors: Optional[List] = None
15 | ) -> Dict[str, str]:
16 | """Get mapping for items to a support color."""
17 | colors = list(_TEXT_COLOR_MAPPING.keys())
18 | if excluded_colors is not None:
19 | colors = [c for c in colors if c not in excluded_colors]
20 | color_mapping = {item: colors[i % len(colors)] for i, item in enumerate(items)}
21 | return color_mapping
22 |
23 |
24 | def get_colored_text(text: str, color: str) -> str:
25 | """Get colored text."""
26 | color_str = _TEXT_COLOR_MAPPING[color]
27 | return f"\u001b[{color_str}m\033[1;3m{text}\u001b[0m"
28 |
29 |
30 | def get_bolded_text(text: str) -> str:
31 | """Get bolded text."""
32 | return f"\033[1m{text}\033[0m"
33 |
34 |
35 | def print_text(
36 | text: str, color: Optional[str] = None, end: str = "", file: Optional[TextIO] = None
37 | ) -> None:
38 | """Print text with highlighting and no end characters."""
39 | text_to_print = get_colored_text(text, color) if color else text
40 | print(text_to_print, end=end, file=file)
41 | if file:
42 |         file.flush()  # ensure all printed content is written to the file
43 |
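A usage sketch for the coloring helpers above; the item names are illustrative:

mapping = get_color_mapping(["chain-1", "chain-2"], excluded_colors=["red"])
print(mapping)                      # e.g. {'chain-1': 'blue', 'chain-2': 'yellow'}
print_text("finished", color=mapping["chain-1"], end="\n")
print(get_bolded_text("done"))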
--------------------------------------------------------------------------------
/build/lib/lmchain/utils/loading.py:
--------------------------------------------------------------------------------
1 | """Utilities for loading configurations from langchain-hub."""
2 |
3 | import os
4 | import re
5 | import tempfile
6 | from pathlib import Path, PurePosixPath
7 | from typing import Any, Callable, Optional, Set, TypeVar, Union
8 | from urllib.parse import urljoin
9 |
10 | import requests
11 |
12 | DEFAULT_REF = os.environ.get("LANGCHAIN_HUB_DEFAULT_REF", "master")
13 | URL_BASE = os.environ.get(
14 | "LANGCHAIN_HUB_URL_BASE",
15 | "https://raw.githubusercontent.com/hwchase17/langchain-hub/{ref}/",
16 | )
17 | HUB_PATH_RE = re.compile(r"lc(?P<ref>@[^:]+)?://(?P<path>.*)")
18 |
19 | T = TypeVar("T")
20 |
21 |
22 | def try_load_from_hub(
23 | path: Union[str, Path],
24 | loader: Callable[[str], T],
25 | valid_prefix: str,
26 | valid_suffixes: Set[str],
27 | **kwargs: Any,
28 | ) -> Optional[T]:
29 | """Load configuration from hub. Returns None if path is not a hub path."""
30 | if not isinstance(path, str) or not (match := HUB_PATH_RE.match(path)):
31 | return None
32 | ref, remote_path_str = match.groups()
33 | ref = ref[1:] if ref else DEFAULT_REF
34 | remote_path = Path(remote_path_str)
35 | if remote_path.parts[0] != valid_prefix:
36 | return None
37 | if remote_path.suffix[1:] not in valid_suffixes:
38 | raise ValueError(f"Unsupported file type, must be one of {valid_suffixes}.")
39 |
40 | # Using Path with URLs is not recommended, because on Windows
41 | # the backslash is used as the path separator, which can cause issues
42 | # when working with URLs that use forward slashes as the path separator.
43 | # Instead, use PurePosixPath to ensure that forward slashes are used as the
44 | # path separator, regardless of the operating system.
45 | full_url = urljoin(URL_BASE.format(ref=ref), PurePosixPath(remote_path).__str__())
46 |
47 | r = requests.get(full_url, timeout=5)
48 | if r.status_code != 200:
49 | raise ValueError(f"Could not find file at {full_url}")
50 | with tempfile.TemporaryDirectory() as tmpdirname:
51 | file = Path(tmpdirname) / remote_path.name
52 | with open(file, "wb") as f:
53 | f.write(r.content)
54 | return loader(str(file), **kwargs)
55 |
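A sketch of calling try_load_from_hub; the lc:// path below is hypothetical and the langchain-hub repository may no longer serve it, so treat this as illustrating the argument shapes only:

# non-hub paths simply return None
assert try_load_from_hub("local_prompt.yaml", loader=str,
                         valid_prefix="prompts", valid_suffixes={"yaml"}) is None

# a hub-style path is fetched over HTTP and handed to the loader
loaded = try_load_from_hub(
    "lc://prompts/example/prompt.yaml",   # hypothetical path; raises ValueError if not found
    loader=lambda path: open(path).read(),
    valid_prefix="prompts",
    valid_suffixes={"yaml", "json"},
)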
--------------------------------------------------------------------------------
/build/lib/lmchain/utils/math.py:
--------------------------------------------------------------------------------
1 | """Math utils."""
2 | import logging
3 | from typing import List, Optional, Tuple, Union
4 |
5 | import numpy as np
6 |
7 | logger = logging.getLogger(__name__)
8 |
9 | Matrix = Union[List[List[float]], List[np.ndarray], np.ndarray]
10 |
11 |
12 | def cosine_similarity(X: Matrix, Y: Matrix) -> np.ndarray:
13 | """Row-wise cosine similarity between two equal-width matrices."""
14 | if len(X) == 0 or len(Y) == 0:
15 | return np.array([])
16 |
17 | X = np.array(X)
18 | Y = np.array(Y)
19 | if X.shape[1] != Y.shape[1]:
20 | raise ValueError(
21 | f"Number of columns in X and Y must be the same. X has shape {X.shape} "
22 | f"and Y has shape {Y.shape}."
23 | )
24 | try:
25 | import simsimd as simd
26 |
27 | X = np.array(X, dtype=np.float32)
28 | Y = np.array(Y, dtype=np.float32)
29 | Z = 1 - simd.cdist(X, Y, metric="cosine")
30 | if isinstance(Z, float):
31 | return np.array([Z])
32 | return Z
33 | except ImportError:
34 | logger.info(
35 | "Unable to import simsimd, defaulting to NumPy implementation. If you want "
36 | "to use simsimd please install with `pip install simsimd`."
37 | )
38 | X_norm = np.linalg.norm(X, axis=1)
39 | Y_norm = np.linalg.norm(Y, axis=1)
40 | # Ignore divide by zero errors run time warnings as those are handled below.
41 | with np.errstate(divide="ignore", invalid="ignore"):
42 | similarity = np.dot(X, Y.T) / np.outer(X_norm, Y_norm)
43 | similarity[np.isnan(similarity) | np.isinf(similarity)] = 0.0
44 | return similarity
45 |
46 |
47 | def cosine_similarity_top_k(
48 | X: Matrix,
49 | Y: Matrix,
50 | top_k: Optional[int] = 5,
51 | score_threshold: Optional[float] = None,
52 | ) -> Tuple[List[Tuple[int, int]], List[float]]:
53 | """Row-wise cosine similarity with optional top-k and score threshold filtering.
54 |
55 | Args:
56 | X: Matrix.
57 | Y: Matrix, same width as X.
58 | top_k: Max number of results to return.
59 | score_threshold: Minimum cosine similarity of results.
60 |
61 | Returns:
62 | Tuple of two lists. First contains two-tuples of indices (X_idx, Y_idx),
63 | second contains corresponding cosine similarities.
64 | """
65 | if len(X) == 0 or len(Y) == 0:
66 | return [], []
67 | score_array = cosine_similarity(X, Y)
68 | score_threshold = score_threshold or -1.0
69 | score_array[score_array < score_threshold] = 0
70 | top_k = min(top_k or len(score_array), np.count_nonzero(score_array))
71 | top_k_idxs = np.argpartition(score_array, -top_k, axis=None)[-top_k:]
72 | top_k_idxs = top_k_idxs[np.argsort(score_array.ravel()[top_k_idxs])][::-1]
73 | ret_idxs = np.unravel_index(top_k_idxs, score_array.shape)
74 | scores = score_array.ravel()[top_k_idxs].tolist()
75 | return list(zip(*ret_idxs)), scores # type: ignore
76 |
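A small numeric sketch of the two functions above (requires only numpy); the values are illustrative:

X = [[1.0, 0.0], [0.0, 1.0]]
Y = [[1.0, 0.0], [1.0, 1.0]]
print(cosine_similarity(X, Y))
# [[1.     0.7071]
#  [0.     0.7071]]

idxs, scores = cosine_similarity_top_k(X, Y, top_k=2)
print(idxs, scores)  # pairs of (X_idx, Y_idx), best-scoring first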
--------------------------------------------------------------------------------
/build/lib/lmchain/vectorstores/__init__.py:
--------------------------------------------------------------------------------
1 | name = "vectorstores"
--------------------------------------------------------------------------------
/build/lib/lmchain/vectorstores/chroma.py:
--------------------------------------------------------------------------------
1 | import json
2 |
3 | from langchain.docstore.document import Document
4 | from langchain.text_splitter import RecursiveCharacterTextSplitter
5 | from lmchain.embeddings import embeddings
6 | from lmchain.vectorstores import laiss
7 |
8 | from langchain.memory import ConversationBufferMemory
9 | from langchain.prompts import (
10 |     ChatPromptTemplate,  # class for building chat templates
11 |     MessagesPlaceholder,  # class for inserting message placeholders into a template
12 |     SystemMessagePromptTemplate,  # class for building system message templates
13 |     HumanMessagePromptTemplate  # class for building human message templates
14 | )
15 | from langchain.chains import ConversationChain
16 |
17 | class Chroma:
18 |     def __init__(self, documents, embedding_tool, chunk_size=1280, chunk_overlap=50, source="这是一份辅助材料"):
19 |         """
20 |         :param documents: the input text contents; each item is a plain-text string
21 |         :param chunk_size: number of characters in each chunk after splitting
22 |         :param chunk_overlap: number of overlapping characters between adjacent chunks
23 |         :param source: document name / document path
24 |         """
25 |         self.text_splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
26 |         self.embedding_tool = embedding_tool
27 |
28 |         self.lmaiss = laiss.LMASS()  # converts texts to vectors and computes query similarity against them
29 |
30 | self.documents = []
31 | self.vectorstores = []
32 |
33 | "---------------------------"
34 |         for document in documents:
35 |             document = [Document(page_content=document, metadata={"source": source})]  # normalize the raw string into a Document
36 |             doc = self.text_splitter.split_documents(document)  # split into chunks of chunk_size with chunk_overlap
37 |             self.documents.extend(doc)
38 |
39 |             vector = self.lmaiss.from_documents(doc, embedding_class=self.embedding_tool)  # embed the split chunks so vectors stay aligned with self.documents
40 |             self.vectorstores.extend(vector)
41 |
42 | # def __call__(self, query):
43 | # query_embedding = self.embedding_tool.embed_query(query)
44 | #
45 | # #根据query查找最近的那个序列
46 | # close_id = self.lmaiss.get_similarity_vector_indexs(query_embedding, self.vectorstore, k=1)[0]
47 | # #查找最近的那个段落id
48 | # doc = self.documents[close_id]
49 | #
50 | #
51 | # return doc
52 |
53 |     def similarity_search(self, query):
54 |         query_embedding = self.embedding_tool.embed_query(query)
55 |
56 |         # find the index of the closest vector for the query
57 |         close_id = self.lmaiss.get_similarity_vector_indexs(query_embedding, self.vectorstores, k=1)[0]
58 |         # look up the corresponding chunk by that index
59 |         doc = self.documents[close_id]
60 |         return doc
61 |
62 |     def add_texts(self, texts, metadata=""):
63 |         for document in texts:
64 |             document = [Document(page_content=document, metadata={"source": metadata})]  # normalize the raw string into a Document
65 |             doc = self.text_splitter.split_documents(document)  # split into chunks of chunk_size with chunk_overlap
66 |             self.documents.extend(doc)
67 |
68 |             vector = self.lmaiss.from_documents(doc, embedding_class=self.embedding_tool)  # embed the split chunks so vectors stay aligned with self.documents
69 |             self.vectorstores.extend(vector)
70 |
71 |         return True
72 |
73 |
74 | def from_texts(texts, embeddings, source=""):
75 |     docsearch = Chroma(documents=texts, embedding_tool=embeddings, source=source)
76 | return docsearch
77 |
78 |
79 | # def from_texts(texts,embeddings):
80 | # embs = embeddings.embed_documents(texts=texts)
81 | # return embs
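A usage sketch for the module-level from_texts above, assuming an embedding object exposing embed_query/embed_documents; the GLMEmbedding import path is an assumption based on this repo's layout:

from lmchain.vectorstores.embeddings import GLMEmbedding  # assumed location of GLMEmbedding

docsearch = from_texts(["第一段文本", "第二段文本"], embeddings=GLMEmbedding(), source="notes.txt")
print(docsearch.similarity_search("第一段"))  # returns the closest chunk as a Document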
--------------------------------------------------------------------------------
/build/lib/lmchain/vectorstores/embeddings.py:
--------------------------------------------------------------------------------
1 | import warnings
2 | warnings.filterwarnings("ignore")
3 |
4 | import asyncio
5 | from abc import ABC, abstractmethod
6 | from typing import List
7 |
8 |
9 | class Embeddings(ABC):
10 | """Interface for embedding models."""
11 |
12 | @abstractmethod
13 | def embed_documents(self, texts: List[str]) -> List[List[float]]:
14 | """Embed search docs."""
15 |
16 | @abstractmethod
17 | def embed_query(self, text: str) -> List[float]:
18 | """Embed query text."""
19 |
20 | async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
21 |         """Asynchronously embed search docs."""
22 | return await asyncio.get_running_loop().run_in_executor(
23 | None, self.embed_documents, texts
24 | )
25 |
26 | async def aembed_query(self, text: str) -> List[float]:
27 |         """Asynchronously embed query text."""
28 | return await asyncio.get_running_loop().run_in_executor(
29 | None, self.embed_query, text
30 | )
31 |
32 |
33 | # class LMEmbedding(Embeddings):
34 | # from modelscope.pipelines import pipeline
35 | # from modelscope.utils.constant import Tasks
36 | # pipeline_se = pipeline(Tasks.sentence_embedding, model='thomas/text2vec-base-chinese', model_revision='v1.0.0',
37 | # device="cuda")
38 | #
39 | # def _costruct_inputs(self, texts):
40 | # inputs = {
41 | # "source_sentence": texts
42 | # }
43 | #
44 | # return inputs
45 | #
46 | # def embed_documents(self, texts: List[str]) -> List[List[float]]:
47 | # """Embed search docs."""
48 | #
49 | # inputs = self._costruct_inputs(texts)
50 | # result_embeddings = self.pipeline_se(input=inputs)
51 | # return result_embeddings["text_embedding"]
52 | #
53 | # def embed_query(self, text: str) -> List[float]:
54 | # """Embed query text."""
55 | # inputs = self._costruct_inputs([text])
56 | # result_embeddings = self.pipeline_se(input=inputs)
57 | # return result_embeddings["text_embedding"]
58 |
59 |
60 | class GLMEmbedding(Embeddings):
61 | import zhipuai as zhipuai
62 |     zhipuai.api_key = "1f565e40af1198e11ff1fd8a5b42771d.SjNfezc40YFsz2KC"  # APIKey obtained from the console
63 |
64 | def _costruct_inputs(self, texts):
65 | inputs = {
66 | "source_sentence": texts
67 | }
68 |
69 | return inputs
70 |
71 |     aembeddings = []  # stores embedding results while they are fetched concurrently
72 | atexts = []
73 |
74 | def embed_documents(self, texts: List[str]) -> List[List[float]]:
75 | """Embed search docs."""
76 | result_embeddings = []
77 | for text in texts:
78 | embedding = self.embed_query(text)
79 | result_embeddings.append(embedding)
80 | return result_embeddings
81 |
82 | def embed_query(self, text: str) -> List[float]:
83 | """Embed query text."""
84 | result_embeddings = self.zhipuai.model_api.invoke(
85 | model="text_embedding", prompt=text)
86 | return result_embeddings["data"]["embedding"]
87 |
88 |     def aembed_query(self, text: str) -> None:
89 |         """Embed query text and store the result (used as a thread target)."""
90 |         result_embeddings = self.zhipuai.model_api.invoke(
91 |             model="text_embedding", prompt=text)
92 |         emb = result_embeddings["data"]["embedding"]
93 |
94 |         self.aembeddings.append(emb)
95 |         self.atexts.append(text)
96 |
97 |     # concurrent embedding retrieval with threads (note: this shadows the async base-class method with a sync one)
98 |     def aembed_documents(self, texts: List[str], thread_num=5, wait_sec=0.3):  # returns (embeddings, texts)
99 |         import threading
100 |         text_length = len(texts)
101 |         thread_batch = (text_length + thread_num - 1) // thread_num  # ceil, so a trailing partial batch is not dropped
102 |
103 |         for i in range(thread_batch):
104 |             start = i * thread_num
105 |             end = (i + 1) * thread_num
106 |
107 |             # list of worker threads
108 |             threads = []
109 |             # create and start up to thread_num threads, one embedding request each
110 |             for text in texts[start:end]:
111 |                 thread = threading.Thread(target=self.aembed_query, args=(text,))
112 |                 thread.start()
113 |                 threads.append(thread)
114 |             for thread in threads:
115 |                 thread.join(wait_sec)  # wait up to wait_sec seconds per thread; late results may be missing
116 | return self.aembeddings, self.atexts
117 |
118 |
119 | if __name__ == '__main__':
120 | import time
121 |
122 | inputs = ["不可以,早晨喝牛奶不科学", "今天早晨喝牛奶不科学", "早晨喝牛奶不科学"] * 50
123 |
124 | start_time = time.time()
125 |     aembeddings, atexts = GLMEmbedding().aembed_documents(inputs, thread_num=5, wait_sec=0.3)
126 | print(aembeddings)
127 | print(len(aembeddings))
128 | end_time = time.time()
129 |     # compute and print the elapsed time
130 |     execution_time = end_time - start_time
131 |     print(f"execution time: {execution_time} s")
132 | print("----------------------------------------------------------------------------------")
133 | start_time = time.time()
134 | aembeddings = (GLMEmbedding().embed_documents(inputs))
135 | print(len(aembeddings))
136 | end_time = time.time()
137 |     # compute and print the elapsed time
138 |     execution_time = end_time - start_time
139 |     print(f"execution time: {execution_time} s")
140 |
--------------------------------------------------------------------------------
/build/lib/lmchain/vectorstores/utils.py:
--------------------------------------------------------------------------------
1 | """Utility functions for working with vectors and vectorstores."""
2 |
3 | from enum import Enum
4 | from typing import List, Tuple, Type
5 |
6 | import numpy as np
7 |
8 | from lmchain.schema.document import Document
9 | from lmchain.utils.math import cosine_similarity
10 |
11 | class DistanceStrategy(str, Enum):
12 | """Enumerator of the Distance strategies for calculating distances
13 | between vectors."""
14 |
15 | EUCLIDEAN_DISTANCE = "EUCLIDEAN_DISTANCE"
16 | MAX_INNER_PRODUCT = "MAX_INNER_PRODUCT"
17 | DOT_PRODUCT = "DOT_PRODUCT"
18 | JACCARD = "JACCARD"
19 | COSINE = "COSINE"
20 |
21 |
22 | def maximal_marginal_relevance(
23 | query_embedding: np.ndarray,
24 | embedding_list: list,
25 | lambda_mult: float = 0.5,
26 | k: int = 4,
27 | ) -> List[int]:
28 | """Calculate maximal marginal relevance."""
29 | if min(k, len(embedding_list)) <= 0:
30 | return []
31 | if query_embedding.ndim == 1:
32 | query_embedding = np.expand_dims(query_embedding, axis=0)
33 | similarity_to_query = cosine_similarity(query_embedding, embedding_list)[0]
34 | most_similar = int(np.argmax(similarity_to_query))
35 | idxs = [most_similar]
36 | selected = np.array([embedding_list[most_similar]])
37 | while len(idxs) < min(k, len(embedding_list)):
38 | best_score = -np.inf
39 | idx_to_add = -1
40 | similarity_to_selected = cosine_similarity(embedding_list, selected)
41 | for i, query_score in enumerate(similarity_to_query):
42 | if i in idxs:
43 | continue
44 | redundant_score = max(similarity_to_selected[i])
45 | equation_score = (
46 | lambda_mult * query_score - (1 - lambda_mult) * redundant_score
47 | )
48 | if equation_score > best_score:
49 | best_score = equation_score
50 | idx_to_add = i
51 | idxs.append(idx_to_add)
52 | selected = np.append(selected, [embedding_list[idx_to_add]], axis=0)
53 | return idxs
54 |
55 |
56 | def filter_complex_metadata(
57 | documents: List[Document],
58 | *,
59 | allowed_types: Tuple[Type, ...] = (str, bool, int, float),
60 | ) -> List[Document]:
61 | """Filter out metadata types that are not supported for a vector store."""
62 | updated_documents = []
63 | for document in documents:
64 | filtered_metadata = {}
65 | for key, value in document.metadata.items():
66 | if not isinstance(value, allowed_types):
67 | continue
68 | filtered_metadata[key] = value
69 |
70 | document.metadata = filtered_metadata
71 | updated_documents.append(document)
72 |
73 | return updated_documents
74 |
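A small MMR sketch: with a low lambda_mult, the near-duplicate of the best match is skipped in favor of a diverse result. The vectors are illustrative:

import numpy as np

query = np.array([1.0, 0.0])
candidates = [
    np.array([1.0, 0.0]),    # exact match
    np.array([0.95, 0.05]),  # near-duplicate of the match
    np.array([0.0, 1.0]),    # orthogonal, diverse
]
print(maximal_marginal_relevance(query, candidates, lambda_mult=0.3, k=2))  # [0, 2]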
--------------------------------------------------------------------------------
/dist/LMchain-0.1.60-py3-none-any.whl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/virgo777/lmchain/3fc4a55df07100d4fe0beef4da89e3fbc6288938/dist/LMchain-0.1.60-py3-none-any.whl
--------------------------------------------------------------------------------
/dist/LMchain-0.1.60.tar.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/virgo777/lmchain/3fc4a55df07100d4fe0beef4da89e3fbc6288938/dist/LMchain-0.1.60.tar.gz
--------------------------------------------------------------------------------
/dist/LMchain-0.1.61-py3-none-any.whl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/virgo777/lmchain/3fc4a55df07100d4fe0beef4da89e3fbc6288938/dist/LMchain-0.1.61-py3-none-any.whl
--------------------------------------------------------------------------------
/dist/LMchain-0.1.61.tar.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/virgo777/lmchain/3fc4a55df07100d4fe0beef4da89e3fbc6288938/dist/LMchain-0.1.61.tar.gz
--------------------------------------------------------------------------------
/dist/LMchain-0.1.62-py3-none-any.whl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/virgo777/lmchain/3fc4a55df07100d4fe0beef4da89e3fbc6288938/dist/LMchain-0.1.62-py3-none-any.whl
--------------------------------------------------------------------------------
/dist/LMchain-0.1.62.tar.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/virgo777/lmchain/3fc4a55df07100d4fe0beef4da89e3fbc6288938/dist/LMchain-0.1.62.tar.gz
--------------------------------------------------------------------------------
/lmchain/__init__.py:
--------------------------------------------------------------------------------
1 | name = "lmchain"
--------------------------------------------------------------------------------
/lmchain/__pycache__/__init__.cpython-311.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/virgo777/lmchain/3fc4a55df07100d4fe0beef4da89e3fbc6288938/lmchain/__pycache__/__init__.cpython-311.pyc
--------------------------------------------------------------------------------
/lmchain/agents/__init__.py:
--------------------------------------------------------------------------------
1 | name = "agents"
--------------------------------------------------------------------------------
/lmchain/agents/__pycache__/__init__.cpython-311.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/virgo777/lmchain/3fc4a55df07100d4fe0beef4da89e3fbc6288938/lmchain/agents/__pycache__/__init__.cpython-311.pyc
--------------------------------------------------------------------------------
/lmchain/agents/__pycache__/llmAgent.cpython-311.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/virgo777/lmchain/3fc4a55df07100d4fe0beef4da89e3fbc6288938/lmchain/agents/__pycache__/llmAgent.cpython-311.pyc
--------------------------------------------------------------------------------
/lmchain/agents/__pycache__/llmMultiAgent.cpython-311.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/virgo777/lmchain/3fc4a55df07100d4fe0beef4da89e3fbc6288938/lmchain/agents/__pycache__/llmMultiAgent.cpython-311.pyc
--------------------------------------------------------------------------------
/lmchain/agents/llmAgent.py:
--------------------------------------------------------------------------------
1 | import time
2 | import logging
3 | import requests
4 | from typing import Optional, List, Dict, Mapping, Any
5 |
6 | import langchain
7 | from langchain.llms.base import LLM
8 | from langchain.cache import InMemoryCache
9 |
10 | logging.basicConfig(level=logging.INFO)
11 | # enable the LLM cache
12 | langchain.llm_cache = InMemoryCache()
13 |
14 |
15 | class AgentChatGLM(LLM):
16 |     # model service url
17 |     url = "http://127.0.0.1:7866/chat"
18 |     #url = "http://192.168.3.20:7866/chat"  # on the 3050 server
19 |     history = []
20 |
21 | @property
22 | def _llm_type(self) -> str:
23 | return "chatglm"
24 |
25 |     def _construct_query(self, prompt: str) -> Dict:
26 |         """Build the request body.
27 |         """
28 |         query = {"query": prompt, "history": self.history}
29 |         import json
30 |         query = json.dumps(query)  # JSON-encode the request parameters
31 |
32 |         return query
33 |
34 |     def _construct_query_tools(self, prompt: str, tools: list) -> Dict:
35 |         """Build the request body for a tool-calling query.
36 |         """
37 |         tools_info = {"role": "system",
38 |                       "content": "你现在是一个查找使用何种工具以及传递何种参数的工具助手,你会一步步的思考问题。你根据需求查找工具函数箱中最合适的工具函数,然后返回工具函数名称和所工具函数对应的参数,参数必须要和需求中的目标对应。",
39 |                       "tools": tools}
40 |         query = {"query": prompt, "history": tools_info}
41 |         import json
42 |         query = json.dumps(query)  # JSON-encode the request parameters
43 |
44 |         return query
45 |
46 |
47 |     @classmethod
48 |     def _post(cls, url: str, query: Dict) -> Any:
49 |
50 |         """POST request."""
51 |         response = requests.post(url, data=query).json()
52 |         return response
53 |
54 |     def _call(self, prompt: str, stop: Optional[List[str]] = None, tools: list = None) -> str:
55 |         """_call"""
56 |         if tools is None:
57 |             # construct query
58 |             query = self._construct_query(prompt=prompt)
59 |
60 |             # post
61 |             response = self._post(url=self.url, query=query)
62 |
63 |             response_chat = response["response"]
64 |             self.history = response["history"]
65 |
66 |             return response_chat
67 |         else:
68 |
69 |             query = self._construct_query_tools(prompt=prompt, tools=tools)
70 |             # post
71 |             response = self._post(url=self.url, query=query)
72 |             self.history = response["history"]  # history must be saved before response is overwritten below
73 |             response = response["response"]
74 |             try:
75 |                 #import ast
76 |                 #response = ast.literal_eval(response)
77 |                 ret = tool_register.dispatch_tool(response["name"], response["parameters"])  # relies on module-level tool_register and llm (defined under __main__); failures fall back to the raw response
78 |                 response_chat = llm(prompt=ret)
79 |             except:
80 |                 response_chat = response
81 |             return str(response_chat)
82 |
83 | @property
84 | def _identifying_params(self) -> Mapping[str, Any]:
85 | """Get the identifying parameters.
86 | """
87 | _param_dict = {
88 | "url": self.url
89 | }
90 | return _param_dict
91 |
92 |
93 | if __name__ == "__main__":
94 |
95 | import tool_register
96 |
97 |     # fetch all registered tools, returned in JSON form
98 |     tools = tool_register.get_tools()
99 |     # -------------------------------------- first, the definition of tools ---------------------------------------
100 |
101 | llm = AgentChatGLM()
102 | llm.url = "http://192.168.3.20:7866/chat"
103 | while True:
104 | while True:
105 | human_input = input("Human: ")
106 | if human_input == "tools":
107 | break
108 |
109 | begin_time = time.time() * 1000
110 |         # query the model
111 | response = llm(human_input)
112 | end_time = time.time() * 1000
113 | used_time = round(end_time - begin_time, 3)
114 | #logging.info(f"chatGLM process time: {used_time}ms")
115 | print(f"Chat: {response}")
116 |
117 | human_input = input("Human_with_tools_Ask: ")
118 | response = llm(prompt=human_input,tools=tools)
119 | print(f"Chat_with_tools_Que: {response}")
120 |
121 |
122 |
123 |
124 |
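125 | # A sketch (not part of the original module) of the JSON the /chat service is
126 | # expected to return, inferred from how _call() reads it:
127 | #
128 | #   {"response": "<model reply, or {'name': ..., 'parameters': ...} when tools are used>",
129 | #    "history": [...]}
130 | #
131 | # Any HTTP endpoint that accepts {"query": ..., "history": ...} and answers in
132 | # this shape can back AgentChatGLM.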
--------------------------------------------------------------------------------
/lmchain/agents/llmMultiAgent.py:
--------------------------------------------------------------------------------
1 |
2 |
3 | import logging
4 | import requests
5 | from typing import Optional, List, Dict, Mapping, Any
6 | import langchain
7 | from langchain.llms.base import LLM
8 | from langchain.cache import InMemoryCache
9 |
10 | logging.basicConfig(level=logging.INFO)
11 | # Enable the LLM response cache
12 | langchain.llm_cache = InMemoryCache()
13 |
14 |
15 | class AgentZhipuAI(LLM):
16 |     import zhipuai
17 |     # model service URL
18 |     url = "127.0.0.1"
19 |     zhipuai.api_key = "1f565e40af1198e11ff1fd8a5b42771d.SjNfezc40YFsz2KC"  # API key obtained from the console
20 |     model = "chatglm_pro"  # model version
21 |
22 | history = []
23 |
24 |     def getText(self, role, content):
25 |         # role specifies the speaker; content is the prompt text
26 | jsoncon = {}
27 | jsoncon["role"] = role
28 | jsoncon["content"] = content
29 | self.history.append(jsoncon)
30 | return self.history
31 |
32 | @property
33 | def _llm_type(self) -> str:
34 | return "AgentZhipuAI"
35 |
36 |     @staticmethod
37 |     def _post(url: str, query: Dict) -> Any:
38 |
39 |         """POST request."""
40 | response = requests.post(url, data=query).json()
41 | return response
42 |
43 |     def _call(self, prompt: str, stop: Optional[List[str]] = None, role: str = "user") -> str:
44 | """_call"""
45 | # construct query
46 | response = self.zhipuai.model_api.invoke(
47 | model=self.model,
48 | prompt=self.getText(role=role, content=prompt)
49 | )
50 | choices = (response['data']['choices'])[0]
51 | self.history.append(choices)
52 | return choices["content"]
53 |
54 | @property
55 | def _identifying_params(self) -> Mapping[str, Any]:
56 | """Get the identifying parameters.
57 | """
58 | _param_dict = {
59 | "url": self.url
60 | }
61 | return _param_dict
62 |
63 |
64 | if __name__ == '__main__':
65 | from langchain.prompts import PromptTemplate
66 | from langchain.chains import LLMChain
67 |
68 | llm = AgentZhipuAI()
69 |
70 |     # example prompt with no input variables
71 | no_input_prompt = PromptTemplate(input_variables=[], template="给我讲个笑话。")
72 | no_input_prompt.format()
73 |
74 | prompt = PromptTemplate(
75 | input_variables=["location", "street"],
76 | template="作为一名专业的旅游顾问,简单的说一下{location}有什么好玩的景点,特别是在{street}?只要说一个就可以。",
77 | )
78 |
79 | chain = LLMChain(llm=llm, prompt=prompt)
80 | print(chain.run({"location": "南京", "street": "新街口"}))
81 |
82 |
83 | from langchain.chains import ConversationChain
84 | conversation = ConversationChain(llm=llm, verbose=True)
85 |
86 | output = conversation.predict(input="你好!")
87 | print(output)
88 |
89 | output = conversation.predict(input="南京是哪里的省会?")
90 | print(output)
91 |
92 | output = conversation.predict(input="那里有什么好玩的地方,简单的说一个就好。")
93 | print(output)
94 |
95 |
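96 | # Note (not in the original): `zhipuai.model_api.invoke` is the synchronous call
97 | # of the zhipuai 1.x SDK; newer SDK versions replace it with a client object.
98 | # The response shape this class relies on, inferred from _call() above:
99 | #
100 | #   {"data": {"choices": [{"role": "assistant", "content": "..."}]}}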
--------------------------------------------------------------------------------
/lmchain/callbacks/__init__.py:
--------------------------------------------------------------------------------
1 | name = "callbacks"
--------------------------------------------------------------------------------
/lmchain/callbacks/stdout.py:
--------------------------------------------------------------------------------
1 | """Callback Handler that prints to std out."""
2 | from typing import Any, Dict, List, Optional
3 |
4 | from langchain.callbacks.base import BaseCallbackHandler
5 | from langchain.schema import AgentAction, AgentFinish, LLMResult
6 | from lmchain.utils.input import print_text
7 |
8 |
9 | class StdOutCallbackHandler(BaseCallbackHandler):
10 | """Callback Handler that prints to std out."""
11 |
12 | def __init__(self, color: Optional[str] = None) -> None:
13 | """Initialize callback handler."""
14 | self.color = color
15 |
16 | def on_llm_start(
17 | self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
18 | ) -> None:
19 | """Print out the prompts."""
20 | pass
21 |
22 | def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
23 | """Do nothing."""
24 | pass
25 |
26 | def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
27 | """Do nothing."""
28 | pass
29 |
30 | def on_llm_error(self, error: BaseException, **kwargs: Any) -> None:
31 | """Do nothing."""
32 | pass
33 |
34 | def on_chain_start(
35 | self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
36 | ) -> None:
37 | """Print out that we are entering a chain."""
38 | class_name = serialized.get("name", serialized.get("id", [""])[-1])
39 | print(f"\n\n\033[1m> Entering new {class_name} chain...\033[0m")
40 |
41 | def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
42 | """Print out that we finished a chain."""
43 | print("\n\033[1m> Finished chain.\033[0m")
44 |
45 | def on_chain_error(self, error: BaseException, **kwargs: Any) -> None:
46 | """Do nothing."""
47 | pass
48 |
49 | def on_tool_start(
50 | self,
51 | serialized: Dict[str, Any],
52 | input_str: str,
53 | **kwargs: Any,
54 | ) -> None:
55 | """Do nothing."""
56 | pass
57 |
58 | def on_agent_action(
59 | self, action: AgentAction, color: Optional[str] = None, **kwargs: Any
60 | ) -> Any:
61 | """Run on agent action."""
62 | print_text(action.log, color=color or self.color)
63 |
64 | def on_tool_end(
65 | self,
66 | output: str,
67 | color: Optional[str] = None,
68 | observation_prefix: Optional[str] = None,
69 | llm_prefix: Optional[str] = None,
70 | **kwargs: Any,
71 | ) -> None:
72 | """If not the final action, print out observation."""
73 | if observation_prefix is not None:
74 | print_text(f"\n{observation_prefix}")
75 | print_text(output, color=color or self.color)
76 | if llm_prefix is not None:
77 | print_text(f"\n{llm_prefix}")
78 |
79 | def on_tool_error(self, error: BaseException, **kwargs: Any) -> None:
80 | """Do nothing."""
81 | pass
82 |
83 | def on_text(
84 | self,
85 | text: str,
86 | color: Optional[str] = None,
87 | end: str = "",
88 | **kwargs: Any,
89 | ) -> None:
90 | """Run when agent ends."""
91 | print_text(text, color=color or self.color, end=end)
92 |
93 | def on_agent_finish(
94 | self, finish: AgentFinish, color: Optional[str] = None, **kwargs: Any
95 | ) -> None:
96 | """Run on agent end."""
97 | print_text(finish.log, color=color or self.color, end="\n")
98 |
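99 | # A minimal usage sketch (not part of the original module), assuming a LangChain
100 | # chain that accepts a callbacks list:
101 | #
102 | #   from langchain.chains import LLMChain
103 | #   chain = LLMChain(llm=llm, prompt=prompt, callbacks=[StdOutCallbackHandler()])
104 | #   chain.run(...)  # prints "> Entering new ... chain" / "> Finished chain."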
--------------------------------------------------------------------------------
/lmchain/chains/__init__.py:
--------------------------------------------------------------------------------
1 | name = "chains"
--------------------------------------------------------------------------------
/lmchain/chains/__pycache__/__init__.cpython-311.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/virgo777/lmchain/3fc4a55df07100d4fe0beef4da89e3fbc6288938/lmchain/chains/__pycache__/__init__.cpython-311.pyc
--------------------------------------------------------------------------------
/lmchain/chains/__pycache__/cmd.cpython-311.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/virgo777/lmchain/3fc4a55df07100d4fe0beef4da89e3fbc6288938/lmchain/chains/__pycache__/cmd.cpython-311.pyc
--------------------------------------------------------------------------------
/lmchain/chains/__pycache__/mathChain.cpython-311.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/virgo777/lmchain/3fc4a55df07100d4fe0beef4da89e3fbc6288938/lmchain/chains/__pycache__/mathChain.cpython-311.pyc
--------------------------------------------------------------------------------
/lmchain/chains/__pycache__/urlRequestChain.cpython-311.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/virgo777/lmchain/3fc4a55df07100d4fe0beef4da89e3fbc6288938/lmchain/chains/__pycache__/urlRequestChain.cpython-311.pyc
--------------------------------------------------------------------------------
/lmchain/chains/cmd.py:
--------------------------------------------------------------------------------
1 | # Chain that asks the model for a Windows CMD command and executes it
2 | from langchain.chains.llm import LLMChain
3 | from langchain.prompts import PromptTemplate
4 | from lmchain.agents import llmAgent
5 | import os,re
6 |
7 | class LLMCMDChain:
8 |     def __init__(self, llm):
9 |         qa_prompt = PromptTemplate(
10 |             template="""你现在根据需要完成对命令行的编写,要根据需求编写对应的在Windows系统终端运行的命令,不要用%question形参这种指代的参数形式,直接给出可以运行的命令。
11 | Question: 给我一个在Windows系统终端中可以准确执行{question}的命令。
12 | answer:""",
13 |             input_variables=["question"],
14 |         )
15 |         self.qa_chain = LLMChain(llm=llm, prompt=qa_prompt)
16 |         self.pattern = r"```(.*?)```"
17 |
18 |     def run(self, text):
19 |         cmd_response = self.qa_chain.run(question=text)
20 |         cmd_string = str(cmd_response).split("```")[-2][1:-1]  # contents of the last fenced block, trimming one char on each side
21 |         os.system(cmd_string)
22 |         return cmd_string
23 |
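24 | # A minimal usage sketch (not part of the original module). run() executes the
25 | # generated command via os.system, so only pass requests you intend to run:
26 | #
27 | #   from lmchain.agents import llmMultiAgent
28 | #   llm = llmMultiAgent.AgentZhipuAI()
29 | #   LLMCMDChain(llm).run("列出当前目录下的所有文件")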
--------------------------------------------------------------------------------
/lmchain/chains/conversationalRetrievalChain.py:
--------------------------------------------------------------------------------
1 | from langchain.docstore.document import Document
2 | from langchain.text_splitter import RecursiveCharacterTextSplitter
3 | from lmchain.embeddings import embeddings
4 | from lmchain.vectorstores import laiss
5 | from lmchain.agents import llmMultiAgent
6 | from langchain.memory import ConversationBufferMemory
7 | from langchain.prompts import (
8 |     ChatPromptTemplate,  # builds chat prompt templates
9 |     MessagesPlaceholder,  # placeholder for message history inside a template
10 |     SystemMessagePromptTemplate,  # builds system-message templates
11 |     HumanMessagePromptTemplate  # builds human-message templates
12 | )
13 | from langchain.chains import ConversationChain
14 |
15 | class ConversationalRetrievalChain:
16 |     def __init__(self, document, chunk_size=1280, chunk_overlap=50, file_name="这是一份辅助材料"):
17 |         """
18 |         :param document: the input text, as a single string
19 |         :param chunk_size: number of characters per chunk after splitting
20 |         :param chunk_overlap: number of characters shared by adjacent chunks
21 |         :param file_name: document name / path
22 |         """
23 | self.text_splitter = RecursiveCharacterTextSplitter( chunk_size=chunk_size, chunk_overlap=chunk_overlap)
24 | self.embedding_tool = embeddings.GLMEmbedding()
25 |
26 |         self.lmaiss = laiss.LMASS()  # converts text into vectors and computes query similarity
27 | self.llm = llmMultiAgent.AgentZhipuAI()
28 | self.memory = ConversationBufferMemory(return_messages=True)
29 |
30 | conversation_prompt = ChatPromptTemplate.from_messages([
31 | SystemMessagePromptTemplate.from_template("你是一个最强大的人工智能程序,可以知无不答,但是你不懂的东西会直接回答不知道。"),
32 | MessagesPlaceholder(variable_name="history"), # 历史消息占位符
33 | HumanMessagePromptTemplate.from_template("{input}") # 人类消息输入模板
34 | ])
35 |
36 | self.qa_chain = ConversationChain(memory=self.memory, prompt=conversation_prompt, llm=self.llm)
37 | "---------------------------"
38 |         document = [Document(page_content=document, metadata={"source": file_name})]  # normalize the input into Document form
39 |         self.documents = self.text_splitter.split_documents(document)  # split by chunk_size / chunk_overlap
40 | self.vectorstore = self.lmaiss.from_documents(self.documents, embedding_class=self.embedding_tool)
41 |
42 | def __call__(self, query):
43 | query_embedding = self.embedding_tool.embed_query(query)
44 |
45 |         # find the index of the chunk closest to the query
46 |         close_id = self.lmaiss.get_similarity_vector_indexs(query_embedding, self.vectorstore, k=1)[0]
47 |         # fetch the closest paragraph by id
48 | doc = self.documents[close_id]
49 |
50 |         # build the final query
51 | query = f"你现在要回答问题'{query}',你可以参考文献'{doc}',你如果找不到对应的内容,就从自己的记忆体中查找,就回答'请提供更为准确的查询内容',注意你要一步步的思考再回答。"
52 | result = (self.qa_chain.predict(input=query))
53 | return result
54 |
55 |     def predict(self, input):
56 | result = self.__call__(input)
57 | return result
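58 |
59 | # A minimal usage sketch (not part of the original module), assuming `raw_text`
60 | # is a long string to index:
61 | #
62 | #   chain = ConversationalRetrievalChain(raw_text, chunk_size=1280, chunk_overlap=50)
63 | #   print(chain.predict("文中提到了哪些要点?"))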
--------------------------------------------------------------------------------
/lmchain/chains/mathChain.py:
--------------------------------------------------------------------------------
1 | # Chain that turns a natural-language request into a numexpr expression and evaluates it
2 |
3 | from langchain.chains.llm import LLMChain
4 | from langchain.prompts import PromptTemplate
5 | from lmchain.agents import llmAgent
6 | import os,re,math
7 |
8 | try:
9 | import numexpr # noqa: F401
10 | except ImportError:
11 | raise ImportError(
12 | "LMchain requires the numexpr package. "
13 | "Please install it with `pip install numexpr`."
14 | )
15 |
16 |
17 | class LLMMathChain:
18 |     def __init__(self, llm):
19 |         qa_prompt = PromptTemplate(
20 |             template="""现在给你一个中文命令,请你把这个命令转化成数学公式。直接给出数学公式。这个公式会在numexpr包中调用。
21 | Question: 我现在需要计算{question},结果需要在numexpr包中调用。
22 | answer:""",
23 |             input_variables=["question"],
24 |         )
25 | self.qa_chain = LLMChain(llm=llm, prompt=qa_prompt)
26 |
27 |
28 |     def run(self, text):
29 | cmd_response = self.qa_chain.run(question=text)
30 | result = self._evaluate_expression(str(cmd_response))
31 | return result
32 |
33 |
34 | def _evaluate_expression(self, expression: str) -> str:
35 | import numexpr # noqa: F401
36 |
37 | try:
38 | local_dict = {"pi": math.pi, "e": math.e}
39 | output = str(
40 | numexpr.evaluate(
41 | expression.strip(),
42 | global_dict={}, # restrict access to globals
43 |                 local_dict=local_dict,  # add common mathematical constants (pi, e)
44 | )
45 | )
46 | except Exception as e:
47 | raise ValueError(
48 | f'LMchain._evaluate("{expression}") raised error: {e}.'
49 | " Please try again with a valid numerical expression"
50 | )
51 |
52 | # Remove any leading and trailing brackets from the output
53 | return re.sub(r"^\[|\]$", "", output)
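54 |
55 | # A minimal usage sketch (not part of the original module):
56 | #
57 | #   from lmchain.agents import llmMultiAgent
58 | #   llm = llmMultiAgent.AgentZhipuAI()
59 | #   print(LLMMathChain(llm).run("三的平方加四的平方"))  # "25" if the model emits 3**2 + 4**2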
--------------------------------------------------------------------------------
/lmchain/chains/question_answering.py:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
--------------------------------------------------------------------------------
/lmchain/chains/subQuestChain.py:
--------------------------------------------------------------------------------
1 | from langchain.chains import LLMChain
2 | from langchain.prompts import PromptTemplate
3 |
4 | from tqdm import tqdm
5 | from lmchain.tools import tool_register
6 |
7 |
8 | class SubQuestChain:
9 | def __init__(self, llm):
10 | self.llm = llm
11 |
12 | def __call__(self, query=""):
13 | if query == "":
14 |             raise ValueError("query must contain the question to decompose")
15 |
16 | decomp_template = """
17 | GENERAL INSTRUCTIONS
18 | You are a domain expert. Your task is to break down a complex question into simpler sub-parts.
19 |
20 | USER QUESTION
21 | {user_question}
22 |
23 | ANSWER FORMAT
24 | ["sub-questions_1","sub-questions_2","sub-questions_3",...]
25 | """
26 |
27 | from langchain.prompts import PromptTemplate
28 | prompt = PromptTemplate(
29 | input_variables=["user_question"],
30 | template=decomp_template,
31 | )
32 |
33 | from langchain.chains import LLMChain
34 | chain = LLMChain(llm=self.llm, prompt=prompt)
35 | response = (chain.run({"user_question": query}))
36 |
37 | import json
38 |         sub_list = json.loads(response)  # assumes the model returns a bare JSON list
39 |
40 | return sub_list
41 |
42 | def run(self, query):
43 | sub_list = self.__call__(query)
44 | return sub_list
45 |
46 |
47 | if __name__ == '__main__':
48 | from lmchain.agents import llmMultiAgent
49 |
50 | llm = llmMultiAgent.AgentZhipuAI()
51 |
52 | subQC = SubQuestChain(llm)
53 | response = subQC.run(query="工商银行财报中,2024财年Q1与Q2 之间,利润增长了多少?")
54 | print(response)
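55 |
56 | # Note (a sketch, not in the original): json.loads above raises if the model
57 | # wraps its answer in prose or a code fence; a defensive caller can do:
58 | #
59 | #   try:
60 | #       subs = subQC.run(query)
61 | #   except ValueError:
62 | #       subs = [query]  # fall back to the undecomposed question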
--------------------------------------------------------------------------------
/lmchain/chains/toolchain.py:
--------------------------------------------------------------------------------
1 | from langchain.chains import LLMChain
2 | from langchain.prompts import PromptTemplate
3 |
4 | from tqdm import tqdm
5 | from lmchain.tools import tool_register
6 |
7 |
8 | class GLMToolChain:
9 | def __init__(self, llm):
10 |
11 | self.llm = llm
12 | self.tool_register = tool_register
13 | self.tools = tool_register.get_tools()
14 |
15 | def __call__(self, query="", tools=None):
16 |
17 | if query == "":
18 |             raise ValueError("query must contain a question")
19 |         if tools is not None:
20 |             self.tools = tools
21 |         else:
22 |             pass  # fall back to the tools registered at construction time
23 | template = f"""
24 | 你现在是一个专业的人工智能助手,你现在的需求是{query}。而你需要借助于工具在{self.tools}中找到对应的函数,用json格式返回对应的函数名和参数。
25 | 函数名定义为function_name,参数名为params,还要求写入详细的形参与实参。
26 |
27 | 如果找到合适的函数,就返回json格式的函数名和需要的参数,不要回答任何描述和解释。
28 |
29 | 如果没有找到合适的函数,则返回:'未找到合适参数,请提供更详细的描述。'
30 | """
31 |
32 | flag = True
33 | counter = 0
34 | while flag:
35 | try:
36 | res = self.llm(template)
37 |
38 |                 import json
39 |                 res_dict = json.loads(res)  # first decode: the raw model output
40 |                 res_dict = json.loads(res_dict)  # second decode: the model returns a JSON-encoded JSON string
41 | flag = False
42 |             except Exception:
43 | # print("失败输出,现在开始重新验证")
44 | template = f"""
45 | 你现在是一个专业的人工智能助手,你现在的需求是{query}。而你需要借助于工具在{self.tools}中找到对应的函数,用json格式返回对应的函数名和参数。
46 | 函数名定义为function_name,参数名为params,还要求写入详细的形参与实参。
47 |
48 | 如果找到合适的函数,就返回json格式的函数名和需要的参数,不要回答任何描述和解释。
49 |
50 | 如果没有找到合适的函数,则返回:'未找到合适参数,请提供更详细的描述。'
51 |
52 | 你刚才生成了一组结果,但是返回不符合json格式,现在请你重新按json格式生成并返回结果。
53 | """
54 | counter += 1
55 | if counter >= 5:
56 | return '未找到合适参数,请提供更详细的描述。'
57 | return res_dict
58 |
59 |     def run(self, query, tools=None):
60 |         tools = tools or self.tool_register.get_tools()
61 | result = self.__call__(query, tools)
62 |
63 | if result == "未找到合适参数,请提供更详细的描述。":
64 | return "未找到合适参数,请提供更详细的描述。"
65 | else:
66 | print("找到对应工具函数,格式如下:", result)
67 | result = self.dispatch_tool(result)
68 | from lmchain.prompts.templates import PromptTemplate
69 | tool_prompt = PromptTemplate(
70 |                 input_variables=["query", "result"],  # input variables for the prompt
71 |                 template="你现在是一个私人助手,现在你的查询任务是{query},而你通过工具从网上查询的结果是{result},现在根据查询的内容与查询的结果,生成最终答案。",
72 |                 # the template formats the query and the tool result into the final answer prompt
73 | )
74 | from langchain.chains import LLMChain
75 | chain = LLMChain(llm=self.llm, prompt=tool_prompt)
76 |
77 | response = (chain.run({"query": query, "result": result}))
78 |
79 | return response
80 |
81 | def add_tools(self, tool):
82 | self.tool_register.register_tool(tool)
83 | return True
84 |
85 | def dispatch_tool(self, tool_result) -> str:
86 | tool_name = tool_result["function_name"]
87 | tool_params = tool_result["params"]
88 | if tool_name not in self.tool_register._TOOL_HOOKS:
89 | return f"Tool `{tool_name}` not found. Please use a provided tool."
90 | tool_call = self.tool_register._TOOL_HOOKS[tool_name]
91 |
92 | try:
93 | ret = tool_call(**tool_params)
94 |         except Exception:
95 | import traceback
96 | ret = traceback.format_exc()
97 | return str(ret)
98 |
99 | def get_tools(self):
100 | return (self.tool_register.get_tools())
101 |
102 |
103 | if __name__ == '__main__':
104 | from lmchain.agents import llmMultiAgent
105 |
106 | llm = llmMultiAgent.AgentZhipuAI()
107 |
108 | from lmchain.chains import toolchain
109 |
110 | tool_chain = toolchain.GLMToolChain(llm)
111 |
112 | from typing import Annotated
113 |
114 |
115 | def rando_numbr(
116 | seed: Annotated[int, 'The random seed used by the generator', True],
117 | range: Annotated[tuple[int, int], 'The range of the generated numbers', True],
118 | ) -> int:
119 | """
120 | Generates a random number x, s.t. range[0] <= x < range[1]
121 | """
122 | import random
123 | return random.Random(seed).randint(*range)
124 |
125 |
126 | tool_chain.add_tools(rando_numbr)
127 |
128 |     print("------------------------------------------------------")
129 |     query = "今天shanghai的天气是什么?"
130 |     result = tool_chain.run(query)
131 |     # run() already dispatches the matched tool and summarizes its output,
132 |     # so the result can be printed directly
133 |     print(result)
134 |
135 |
136 |
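137 | # A sketch (not in the original) of the JSON shape __call__ expects the model
138 | # to return, which dispatch_tool() then consumes:
139 | #
140 | #   {"function_name": "rando_numbr", "params": {"seed": 42, "range": [0, 10]}}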
--------------------------------------------------------------------------------
/lmchain/chains/urlRequestChain.py:
--------------------------------------------------------------------------------
1 | from langchain.chains import LLMRequestsChain, LLMChain
2 | from langchain.prompts import PromptTemplate
3 |
4 | import requests
5 | from bs4 import BeautifulSoup
6 | from tqdm import tqdm
7 |
8 |
9 | class LMRequestsChain:
10 |     def __init__(self, llm, max_url_num=2):
11 | template = """Between >>> and <<< are the raw search result text from google.
12 | Extract the answer to the question '{query}' or say "not found" if the information is not contained.
13 | Use the format
14 | Extracted:
15 | >>> {requests_result} <<<
16 | Extracted:"""
17 | PROMPT = PromptTemplate(
18 | input_variables=["query", "requests_result"],
19 | template=template,
20 | )
21 | self.chain = LLMRequestsChain(llm_chain=LLMChain(llm=llm, prompt=PROMPT))
22 | self.max_url_num = max_url_num
23 |
24 | query_prompt = PromptTemplate(
25 | input_variables=["query","responses"],
26 | template = "作为一名专业的信息总结员,我需要查询的信息为{query},根据提供的信息{responses}回答一下查询的结果。")
27 | self.query_chain = LLMChain(llm=llm, prompt=query_prompt)
28 |
29 |     def __call__(self, query, target_site=""):
30 | url_list = self.get_urls(query,target_site = target_site)
31 | print(f"查找到{len(url_list)}条url内容,现在开始解析其中的{self.max_url_num}条内容。")
32 | responses = []
33 | for url in tqdm(url_list[:self.max_url_num]):
34 | inputs = {
35 | "query": query,
36 | "url": url
37 | }
38 |
39 | response = self.chain(inputs)
40 | output = response["output"]
41 | responses.append(output)
42 | if len(responses) != 0:
43 | output = self.query_chain.run({"query":query,"responses":responses})
44 | return output
45 | else:
46 | return "查找内容为空,请更换查找词"
47 |
48 |     def query_form_url(self, query="LMchain是什么?", url=""):
49 |         assert url != "", "url link must be set"
50 | inputs = {
51 | "query": query,
52 | "url": url
53 | }
54 | response = self.chain(inputs)
55 | return response
56 |
57 |     def get_urls(self, query='lmchain是什么?', target_site=""):
58 | def bing_search(query, count=30):
59 | url = f'https://cn.bing.com/search?q={query}'
60 | headers = {
61 | 'User-Agent': 'Mozilla/6.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/58.0.3029.110 Safari/537.3'}
62 | response = requests.get(url, headers=headers)
63 | if response.status_code == 200:
64 | html = response.text
65 |                 # parse the HTML with BeautifulSoup
66 |
67 | soup = BeautifulSoup(html, 'html.parser')
68 | results = soup.find_all('li', class_='b_algo')
69 | return [result.find('a').text for result in results[:count]]
70 | else:
71 | print(f'请求失败,状态码:{response.status_code}')
72 | return []
73 | results = bing_search(query)
74 | if len(results) == 0:
75 | return None
76 | url_list = []
77 | if target_site != "":
78 | for i, result in enumerate(results):
79 | if "https" in result and target_site in result:
80 | url = "https://" + result.split("https://")[1]
81 | url_list.append(url)
82 | else:
83 | for i, result in enumerate(results):
84 | if "https" in result:
85 | url = "https://" + result.split("https://")[1]
86 | url_list.append(url)
87 | if len(url_list) > 0:
88 | return url_list
89 | else:
90 |             # fallback: if nothing matched the target site, return any https result
91 | for i, result in enumerate(results):
92 | if "https" in result:
93 | url = "https://" + result.split("https://")[1]
94 | url_list.append(url)
95 | return url_list
96 |
97 |
98 |
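99 | # A minimal usage sketch (not part of the original module):
100 | #
101 | #   chain = LMRequestsChain(llm, max_url_num=2)
102 | #   print(chain("LMchain是什么?"))  # search, fetch, extract, then summarize
103 | #   chain.query_form_url(query="LMchain是什么?", url="https://example.com")  # one known page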
--------------------------------------------------------------------------------
/lmchain/embeddings/__init__.py:
--------------------------------------------------------------------------------
1 | name = "embeddings"
--------------------------------------------------------------------------------
/lmchain/index/__init__.py:
--------------------------------------------------------------------------------
1 | name = "index"
--------------------------------------------------------------------------------
/lmchain/index/indexChain.py:
--------------------------------------------------------------------------------
1 | from typing import Any, Dict, List, Optional, Type
2 |
3 | from langchain.chains.qa_with_sources.retrieval import RetrievalQAWithSourcesChain
4 | from langchain.chains.retrieval_qa.base import RetrievalQA
5 | from langchain.document_loaders.base import BaseLoader
6 | from pydantic.v1 import BaseModel, Extra, Field
7 | from langchain.schema import Document
8 | from langchain.schema.embeddings import Embeddings
9 | from langchain.schema.language_model import BaseLanguageModel
10 | from langchain.schema.vectorstore import VectorStore
11 | from langchain.text_splitter import RecursiveCharacterTextSplitter, TextSplitter
12 | from langchain.vectorstores.chroma import Chroma
13 |
14 |
15 | def _get_default_text_splitter() -> TextSplitter:
16 | return RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
17 |
18 | from lmchain.embeddings import embeddings
19 | embedding_tool = embeddings.GLMEmbedding()
20 |
21 | class VectorstoreIndexCreator(BaseModel):
22 | """Logic for creating indexes."""
23 |
24 | class Config:
25 | """Configuration for this pydantic object."""
26 | extra = Extra.forbid
27 | arbitrary_types_allowed = True
28 |
29 |
30 |
31 |
32 |     chunk_size = 1280  # characters per chunk
33 |     chunk_overlap = 32  # characters of overlap between adjacent chunks
34 | text_splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
35 |
36 |
37 |
38 |
39 |
40 | def from_loaders(self, loaders: List[BaseLoader]):
41 | """Create a vectorstore index from loaders."""
42 | docs = []
43 | for loader in loaders:
44 | docs.extend(loader.load())
45 | return self.from_documents(docs)
46 |
47 |
48 | def from_documents(self, documents: List[Document]):
49 |         # Split the documents and wrap them in a conversational retrieval chain.
50 | sub_docs = self.text_splitter.split_documents(documents)
51 |
52 | # texts = [d.page_content for d in sub_docs]
53 | # metadatas = [d.metadata for d in sub_docs]
54 |
55 |         qa_chain = ConversationalRetrievalChain(document=sub_docs)  # defined later in this module
56 | return qa_chain
57 |
58 |
59 | from langchain.docstore.document import Document
60 | from langchain.text_splitter import RecursiveCharacterTextSplitter
61 | from lmchain.embeddings import embeddings
62 | from lmchain.vectorstores import laiss
63 | from lmchain.agents import llmMultiAgent
64 | from langchain.memory import ConversationBufferMemory
65 | from langchain.prompts import (
66 |     ChatPromptTemplate,  # builds chat prompt templates
67 |     MessagesPlaceholder,  # placeholder for message history inside a template
68 |     SystemMessagePromptTemplate,  # builds system-message templates
69 |     HumanMessagePromptTemplate  # builds human-message templates
70 | )
71 | from langchain.chains import ConversationChain
72 |
73 | class ConversationalRetrievalChain:
74 |     def __init__(self, document, chunk_size=1280, chunk_overlap=50, file_name="这是一份辅助材料"):
75 |         """
76 |         :param document: the input documents to index
77 |         :param chunk_size: number of characters per chunk after splitting
78 |         :param chunk_overlap: number of characters shared by adjacent chunks
79 |         :param file_name: document name / path
80 |         """
81 | self.text_splitter = RecursiveCharacterTextSplitter( chunk_size=chunk_size, chunk_overlap=chunk_overlap)
82 | self.embedding_tool = embedding_tool
83 |
84 |         self.lmaiss = laiss.LMASS()  # converts text into vectors and computes query similarity
85 | self.llm = llmMultiAgent.AgentZhipuAI()
86 | self.memory = ConversationBufferMemory(return_messages=True)
87 |
88 | conversation_prompt = ChatPromptTemplate.from_messages([
89 | SystemMessagePromptTemplate.from_template("你是一个最强大的人工智能程序,可以知无不答,但是你不懂的东西会直接回答不知道。"),
90 | MessagesPlaceholder(variable_name="history"), # 历史消息占位符
91 | HumanMessagePromptTemplate.from_template("{input}") # 人类消息输入模板
92 | ])
93 |
94 | self.qa_chain = ConversationChain(memory=self.memory, prompt=conversation_prompt, llm=self.llm)
95 | "---------------------------"
96 | self.metadatas = []
97 | for doc in document:
98 | self.metadatas.append(doc.metadata)
99 |         self.documents = self.text_splitter.split_documents(document)  # split by chunk_size / chunk_overlap
100 | self.vectorstore = self.lmaiss.from_documents(self.documents, embedding_class=self.embedding_tool)
101 |
102 |
103 |
104 | def __call__(self, query):
105 | query_embedding = self.embedding_tool.embed_query(query)
106 |
107 |         # find the index of the chunk closest to the query
108 |         close_id = self.lmaiss.get_similarity_vector_indexs(query_embedding, self.vectorstore, k=1)[0]
109 |         # fetch the closest paragraph by id
110 | doc = self.documents[close_id]
111 | meta = self.metadatas[close_id]
112 |         # build the final query
113 | query = f"你现在要回答问题'{query}',你可以参考文献'{doc}',你如果找不到对应的内容,就从自己的记忆体中查找,就回答'请提供更为准确的查询内容'。"
114 | result = (self.qa_chain.predict(input=query))
115 | return result,meta
116 |
117 |
118 |     def query(self, input):
119 | result,meta = self.__call__(input)
120 | return result
121 |
122 |     # Variant of query() that also returns the source metadata of the matched chunk.
123 |     def query_with_sources(self, input):
124 | result,meta = self.__call__(input)
125 | return {"answer":result,"sources":meta}
126 |
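127 | # A minimal usage sketch (not part of the original module), assuming a LangChain
128 | # document loader:
129 | #
130 | #   from langchain.document_loaders import TextLoader
131 | #   index = VectorstoreIndexCreator().from_loaders([TextLoader("doc.txt", encoding="utf-8")])
132 | #   print(index.query("文档讲了什么?"))
133 | #   print(index.query_with_sources("文档讲了什么?"))  # adds the source metadata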
--------------------------------------------------------------------------------
/lmchain/llms/__init__.py:
--------------------------------------------------------------------------------
1 | name = "llms"
--------------------------------------------------------------------------------
/lmchain/llms/base.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/virgo777/lmchain/3fc4a55df07100d4fe0beef4da89e3fbc6288938/lmchain/llms/base.py
--------------------------------------------------------------------------------
/lmchain/load/__init__.py:
--------------------------------------------------------------------------------
1 | name = "load"
--------------------------------------------------------------------------------
/lmchain/load/__pycache__/__init__.cpython-311.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/virgo777/lmchain/3fc4a55df07100d4fe0beef4da89e3fbc6288938/lmchain/load/__pycache__/__init__.cpython-311.pyc
--------------------------------------------------------------------------------
/lmchain/load/__pycache__/serializable.cpython-311.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/virgo777/lmchain/3fc4a55df07100d4fe0beef4da89e3fbc6288938/lmchain/load/__pycache__/serializable.cpython-311.pyc
--------------------------------------------------------------------------------
/lmchain/memory/__init__.py:
--------------------------------------------------------------------------------
1 | name = "memory"
--------------------------------------------------------------------------------
/lmchain/memory/chat_memory.py:
--------------------------------------------------------------------------------
1 | from abc import ABC
2 | from typing import Any, Dict, Optional, Tuple
3 |
4 | from lmchain.memory.utils import get_prompt_input_key
5 |
6 | from lmchain.schema.memory import BaseMemory
7 |
8 |
9 | class BaseChatMemory(BaseMemory, ABC):
10 | """Abstract base class for chat memory."""
11 |
12 | from lmchain.memory import messageHistory
13 |     chat_memory = messageHistory.ChatMessageHistory()  # created once at class definition and shared as a default
14 | output_key: Optional[str] = None
15 | input_key: Optional[str] = None
16 | return_messages: bool = False
17 |
18 | def _get_input_output(
19 | self, inputs: Dict[str, Any], outputs: Dict[str, str]
20 | ) -> Tuple[str, str]:
21 | if self.input_key is None:
22 | prompt_input_key = get_prompt_input_key(inputs, self.memory_variables)
23 | else:
24 | prompt_input_key = self.input_key
25 | if self.output_key is None:
26 | if len(outputs) != 1:
27 | raise ValueError(f"One output key expected, got {outputs.keys()}")
28 | output_key = list(outputs.keys())[0]
29 | else:
30 | output_key = self.output_key
31 | return inputs[prompt_input_key], outputs[output_key]
32 |
33 | def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
34 | """Save context from this conversation to buffer."""
35 | input_str, output_str = self._get_input_output(inputs, outputs)
36 | self.chat_memory.add_user_message(input_str)
37 | self.chat_memory.add_ai_message(output_str)
38 |
39 | def clear(self) -> None:
40 | """Clear memory contents."""
41 | self.chat_memory.clear()
42 |
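43 | # Sketch (not in the original) of how save_context pairs one input with one output:
44 | #
45 | #   memory.save_context({"input": "hi"}, {"output": "hello"})
46 | #   # chat_memory now holds "HumanMessage(content=hi)" and "AIMessage(content=hello)"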
--------------------------------------------------------------------------------
/lmchain/memory/messageHistory.py:
--------------------------------------------------------------------------------
1 | from typing import TYPE_CHECKING, Any, Dict, List, Sequence, Union
2 | from typing_extensions import Literal
3 |
4 |
5 | class ChatMessageHistory:
6 | """In memory implementation of chat message history.
7 |
8 | Stores messages in an in memory list.
9 | """
10 |
11 |     messages = []  # class-level default; the subclass below keeps a per-instance list
12 |
13 | def add_message(self, message) -> None:
14 | """Add a self-created message to the store"""
15 | self.messages.append(message)
16 |
17 | def clear(self) -> None:
18 | self.messages = []
19 |
20 | def __str__(self):
21 | return ", ".join(str(message) for message in self.messages)
22 |
23 |
24 | class ChatMessageHistory(ChatMessageHistory):
25 |     def __init__(self):
26 |         self.messages = []  # per-instance store, so separate histories do not share state
27 |
28 | def add_user_message(self, content: str) -> None:
29 | """Convenience method for adding a human message string to the store.
30 |
31 | Args:
32 | content: The string contents of a human message.
33 | """
34 | mes = f"HumanMessage(content={content})"
35 | self.messages.append(mes)
36 |
37 | def add_ai_message(self, content: str) -> None:
38 | """Convenience method for adding an AI message string to the store.
39 |
40 | Args:
41 | content: The string contents of an AI message.
42 | """
43 | mes = f"AIMessage(content={content})"
44 | self.messages.append(mes)
45 |
46 |
47 | from typing import Any, Dict, List, Optional
48 |
49 | from langchain.memory.chat_memory import BaseChatMemory, BaseMemory
50 | from langchain.memory.utils import get_prompt_input_key
51 | from pydantic.v1 import root_validator
52 | from langchain.schema.messages import BaseMessage, get_buffer_string
53 |
54 |
55 | class ConversationBufferMemory(BaseChatMemory):
56 | """Buffer for storing conversation memory."""
57 |
58 | human_prefix: str = "Human"
59 | ai_prefix: str = "AI"
60 | memory_key: str = "history" #: :meta private:
61 |
62 | @property
63 | def buffer(self) -> Any:
64 | """String buffer of memory."""
65 | return self.buffer_as_messages if self.return_messages else self.buffer_as_str
66 |
67 | @property
68 | def buffer_as_str(self) -> str:
69 | """Exposes the buffer as a string in case return_messages is True."""
70 | return get_buffer_string(
71 | self.chat_memory.messages,
72 | human_prefix=self.human_prefix,
73 | ai_prefix=self.ai_prefix,
74 | )
75 |
76 | @property
77 | def buffer_as_messages(self) -> List[BaseMessage]:
78 | """Exposes the buffer as a list of messages in case return_messages is False."""
79 | return self.chat_memory.messages
80 |
81 | @property
82 | def memory_variables(self) -> List[str]:
83 | """Will always return list of memory variables.
84 |
85 | :meta private:
86 | """
87 | return [self.memory_key]
88 |
89 | def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
90 | """Return history buffer."""
91 | return {self.memory_key: self.buffer}
92 |
93 |
94 | class ConversationStringBufferMemory(BaseMemory):
95 | """Buffer for storing conversation memory."""
96 |
97 | human_prefix: str = "Human"
98 | ai_prefix: str = "AI"
99 | """Prefix to use for AI generated responses."""
100 | buffer: str = ""
101 | output_key: Optional[str] = None
102 | input_key: Optional[str] = None
103 | memory_key: str = "history" #: :meta private:
104 |
105 | @root_validator()
106 | def validate_chains(cls, values: Dict) -> Dict:
107 | """Validate that return messages is not True."""
108 | if values.get("return_messages", False):
109 | raise ValueError(
110 | "return_messages must be False for ConversationStringBufferMemory"
111 | )
112 | return values
113 |
114 | @property
115 | def memory_variables(self) -> List[str]:
116 | """Will always return list of memory variables.
117 | :meta private:
118 | """
119 | return [self.memory_key]
120 |
121 | def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, str]:
122 | """Return history buffer."""
123 | return {self.memory_key: self.buffer}
124 |
125 | def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
126 | """Save context from this conversation to buffer."""
127 | if self.input_key is None:
128 | prompt_input_key = get_prompt_input_key(inputs, self.memory_variables)
129 | else:
130 | prompt_input_key = self.input_key
131 | if self.output_key is None:
132 | if len(outputs) != 1:
133 | raise ValueError(f"One output key expected, got {outputs.keys()}")
134 | output_key = list(outputs.keys())[0]
135 | else:
136 | output_key = self.output_key
137 | human = f"{self.human_prefix}: " + inputs[prompt_input_key]
138 | ai = f"{self.ai_prefix}: " + outputs[output_key]
139 | self.buffer += "\n" + "\n".join([human, ai])
140 |
141 | def clear(self) -> None:
142 | """Clear memory contents."""
143 | self.buffer = ""
144 |
145 |
146 |
147 |
148 |
149 |
150 |
151 |
152 |
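153 | # Sketch (not in the original) of the string-buffer variant:
154 | #
155 | #   mem = ConversationStringBufferMemory()
156 | #   mem.save_context({"input": "hi"}, {"output": "hello"})
157 | #   mem.load_memory_variables({})  # -> {"history": "\nHuman: hi\nAI: hello"}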
--------------------------------------------------------------------------------
/lmchain/memory/utils.py:
--------------------------------------------------------------------------------
1 | from typing import Any, Dict, List
2 |
3 |
4 | def get_prompt_input_key(inputs: Dict[str, Any], memory_variables: List[str]) -> str:
5 | """
6 | Get the prompt input key.
7 |
8 | Args:
9 | inputs: Dict[str, Any]
10 | memory_variables: List[str]
11 |
12 | Returns:
13 | A prompt input key.
14 | """
15 | # "stop" is a special key that can be passed as input but is not used to
16 | # format the prompt.
17 | prompt_input_keys = list(set(inputs).difference(memory_variables + ["stop"]))
18 | if len(prompt_input_keys) != 1:
19 | raise ValueError(f"One input key expected got {prompt_input_keys}")
20 | return prompt_input_keys[0]
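21 |
22 | # Sketch (not in the original): with inputs {"input": "hi", "stop": ["\n"]} and
23 | # memory_variables ["history"], the only remaining key is "input", which is
24 | # returned as the prompt input key.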
--------------------------------------------------------------------------------
/lmchain/model/__init__.py:
--------------------------------------------------------------------------------
1 | name = "model"
--------------------------------------------------------------------------------
/lmchain/prompts/__init__.py:
--------------------------------------------------------------------------------
1 | name = "prompts"
--------------------------------------------------------------------------------
/lmchain/prompts/__pycache__/__init__.cpython-311.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/virgo777/lmchain/3fc4a55df07100d4fe0beef4da89e3fbc6288938/lmchain/prompts/__pycache__/__init__.cpython-311.pyc
--------------------------------------------------------------------------------
/lmchain/prompts/__pycache__/base.cpython-311.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/virgo777/lmchain/3fc4a55df07100d4fe0beef4da89e3fbc6288938/lmchain/prompts/__pycache__/base.cpython-311.pyc
--------------------------------------------------------------------------------
/lmchain/prompts/base.py:
--------------------------------------------------------------------------------
1 | """BasePrompt schema definition."""
2 | from __future__ import annotations
3 |
4 | import warnings
5 | from abc import ABC
6 | from string import Formatter
7 | from typing import Any, Callable, Dict, List, Literal, Set
8 |
9 | from lmchain.schema.messages import BaseMessage, HumanMessage
10 | from lmchain.schema.prompt import PromptValue
11 | from lmchain.schema.prompt_template import BasePromptTemplate
12 | #from langchain.schema.prompt_template import BasePromptTemplate
13 | from lmchain.utils.formatting import formatter
14 |
15 |
16 | def jinja2_formatter(template: str, **kwargs: Any) -> str:
17 | """Format a template using jinja2.
18 |
19 | *Security warning*: As of LangChain 0.0.329, this method uses Jinja2's
20 | SandboxedEnvironment by default. However, this sand-boxing should
21 | be treated as a best-effort approach rather than a guarantee of security.
22 | Do not accept jinja2 templates from untrusted sources as they may lead
23 | to arbitrary Python code execution.
24 |
25 | https://jinja.palletsprojects.com/en/3.1.x/sandbox/
26 | """
27 | try:
28 | from jinja2.sandbox import SandboxedEnvironment
29 | except ImportError:
30 | raise ImportError(
31 | "jinja2 not installed, which is needed to use the jinja2_formatter. "
32 | "Please install it with `pip install jinja2`."
33 | "Please be cautious when using jinja2 templates. "
34 | "Do not expand jinja2 templates using unverified or user-controlled "
35 | "inputs as that can result in arbitrary Python code execution."
36 | )
37 |
38 | # This uses a sandboxed environment to prevent arbitrary code execution.
39 | # Jinja2 uses an opt-out rather than opt-in approach for sand-boxing.
40 | # Please treat this sand-boxing as a best-effort approach rather than
41 | # a guarantee of security.
42 | # We recommend to never use jinja2 templates with untrusted inputs.
43 | # https://jinja.palletsprojects.com/en/3.1.x/sandbox/
44 |
45 | return SandboxedEnvironment().from_string(template).render(**kwargs)
46 |
47 |
48 | def validate_jinja2(template: str, input_variables: List[str]) -> None:
49 | """
50 | Validate that the input variables are valid for the template.
51 | Issues a warning if missing or extra variables are found.
52 |
53 | Args:
54 | template: The template string.
55 | input_variables: The input variables.
56 | """
57 | input_variables_set = set(input_variables)
58 | valid_variables = _get_jinja2_variables_from_template(template)
59 | missing_variables = valid_variables - input_variables_set
60 | extra_variables = input_variables_set - valid_variables
61 |
62 | warning_message = ""
63 | if missing_variables:
64 | warning_message += f"Missing variables: {missing_variables} "
65 |
66 | if extra_variables:
67 | warning_message += f"Extra variables: {extra_variables}"
68 |
69 | if warning_message:
70 | warnings.warn(warning_message.strip())
71 |
72 |
73 | def _get_jinja2_variables_from_template(template: str) -> Set[str]:
74 | try:
75 | from jinja2 import Environment, meta
76 | except ImportError:
77 | raise ImportError(
78 | "jinja2 not installed, which is needed to use the jinja2_formatter. "
79 | "Please install it with `pip install jinja2`."
80 | )
81 | env = Environment()
82 | ast = env.parse(template)
83 | variables = meta.find_undeclared_variables(ast)
84 | return variables
85 |
86 |
87 | DEFAULT_FORMATTER_MAPPING: Dict[str, Callable] = {
88 | "f-string": formatter.format,
89 | "jinja2": jinja2_formatter,
90 | }
91 |
92 | DEFAULT_VALIDATOR_MAPPING: Dict[str, Callable] = {
93 | "f-string": formatter.validate_input_variables,
94 | "jinja2": validate_jinja2,
95 | }
96 |
97 |
98 | def check_valid_template(
99 | template: str, template_format: str, input_variables: List[str]
100 | ) -> None:
101 | """Check that template string is valid.
102 |
103 | Args:
104 | template: The template string.
105 | template_format: The template format. Should be one of "f-string" or "jinja2".
106 | input_variables: The input variables.
107 |
108 | Raises:
109 | ValueError: If the template format is not supported.
110 | """
111 | if template_format not in DEFAULT_FORMATTER_MAPPING:
112 | valid_formats = list(DEFAULT_FORMATTER_MAPPING)
113 | raise ValueError(
114 | f"Invalid template format. Got `{template_format}`;"
115 | f" should be one of {valid_formats}"
116 | )
117 | try:
118 | validator_func = DEFAULT_VALIDATOR_MAPPING[template_format]
119 | validator_func(template, input_variables)
120 | except KeyError as e:
121 | raise ValueError(
122 | "Invalid prompt schema; check for mismatched or missing input parameters. "
123 | + str(e)
124 | )
125 |
126 |
127 | def get_template_variables(template: str, template_format: str) -> List[str]:
128 | """Get the variables from the template.
129 |
130 | Args:
131 | template: The template string.
132 | template_format: The template format. Should be one of "f-string" or "jinja2".
133 |
134 | Returns:
135 | The variables from the template.
136 |
137 | Raises:
138 | ValueError: If the template format is not supported.
139 | """
140 | if template_format == "jinja2":
141 | # Get the variables for the template
142 | input_variables = _get_jinja2_variables_from_template(template)
143 | elif template_format == "f-string":
144 | input_variables = {
145 | v for _, v, _, _ in Formatter().parse(template) if v is not None
146 | }
147 | else:
148 | raise ValueError(f"Unsupported template format: {template_format}")
149 |
150 | return sorted(input_variables)
151 |
152 |
153 | class StringPromptValue(PromptValue):
154 | """String prompt value."""
155 |
156 | text: str
157 | """Prompt text."""
158 | type: Literal["StringPromptValue"] = "StringPromptValue"
159 |
160 | def to_string(self) -> str:
161 | """Return prompt as string."""
162 | return self.text
163 |
164 | def to_messages(self) -> List[BaseMessage]:
165 | """Return prompt as messages."""
166 | return [HumanMessage(content=self.text)]
167 |
168 |
169 | class StringPromptTemplate(BasePromptTemplate, ABC):
170 | """String prompt that exposes the format method, returning a prompt."""
171 |
172 | def format_prompt(self, **kwargs: Any) -> PromptValue:
173 | """Create Chat Messages."""
174 | return StringPromptValue(text=self.format(**kwargs))
175 |
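176 | # Sketch (not in the original) of the two template formats handled above:
177 | #
178 | #   jinja2_formatter("Hello {{ name }}", name="world")   # -> "Hello world"
179 | #   get_template_variables("Hello {name}", "f-string")   # -> ["name"]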
--------------------------------------------------------------------------------
/lmchain/prompts/loading.py:
--------------------------------------------------------------------------------
1 | """Load prompts."""
2 | import json
3 | import logging
4 | from pathlib import Path
5 | from typing import Callable, Dict, Union
6 |
7 | import yaml
8 |
9 | from lmchain.prompts.few_shot_templates import FewShotPromptTemplate
10 | from lmchain.prompts.prompt import PromptTemplate
11 | #from langchain.schema import BaseLLMOutputParser, BasePromptTemplate, StrOutputParser
12 | from lmchain.schema.output_parser import BaseLLMOutputParser,StrOutputParser
13 | from lmchain.schema.prompt_template import BasePromptTemplate
14 |
15 |
16 |
17 | from lmchain.utils.loading import try_load_from_hub
18 |
19 | URL_BASE = "https://raw.githubusercontent.com/hwchase17/langchain-hub/master/prompts/"
20 | logger = logging.getLogger(__name__)
21 |
22 |
23 | def load_prompt_from_config(config: dict) -> BasePromptTemplate:
24 | """Load prompt from Config Dict."""
25 | if "_type" not in config:
26 | logger.warning("No `_type` key found, defaulting to `prompt`.")
27 | config_type = config.pop("_type", "prompt")
28 |
29 | if config_type not in type_to_loader_dict:
30 | raise ValueError(f"Loading {config_type} prompt not supported")
31 |
32 | prompt_loader = type_to_loader_dict[config_type]
33 | return prompt_loader(config)
34 |
35 |
36 | def _load_template(var_name: str, config: dict) -> dict:
37 | """Load template from the path if applicable."""
38 | # Check if template_path exists in config.
39 | if f"{var_name}_path" in config:
40 | # If it does, make sure template variable doesn't also exist.
41 | if var_name in config:
42 | raise ValueError(
43 | f"Both `{var_name}_path` and `{var_name}` cannot be provided."
44 | )
45 | # Pop the template path from the config.
46 | template_path = Path(config.pop(f"{var_name}_path"))
47 | # Load the template.
48 | if template_path.suffix == ".txt":
49 | with open(template_path) as f:
50 | template = f.read()
51 | else:
52 | raise ValueError
53 | # Set the template variable to the extracted variable.
54 | config[var_name] = template
55 | return config
56 |
57 |
58 | def _load_examples(config: dict) -> dict:
59 | """Load examples if necessary."""
60 | if isinstance(config["examples"], list):
61 | pass
62 | elif isinstance(config["examples"], str):
63 | with open(config["examples"]) as f:
64 | if config["examples"].endswith(".json"):
65 | examples = json.load(f)
66 | elif config["examples"].endswith((".yaml", ".yml")):
67 | examples = yaml.safe_load(f)
68 | else:
69 | raise ValueError(
70 | "Invalid file format. Only json or yaml formats are supported."
71 | )
72 | config["examples"] = examples
73 | else:
74 | raise ValueError("Invalid examples format. Only list or string are supported.")
75 | return config
76 |
77 |
78 | def _load_output_parser(config: dict) -> dict:
79 | """Load output parser."""
80 | if "output_parser" in config and config["output_parser"]:
81 | _config = config.pop("output_parser")
82 | output_parser_type = _config.pop("_type")
83 | if output_parser_type == "regex_parser":
84 | from langchain.output_parsers.regex import RegexParser
85 |
86 | output_parser: BaseLLMOutputParser = RegexParser(**_config)
87 | elif output_parser_type == "default":
88 | output_parser = StrOutputParser(**_config)
89 | else:
90 | raise ValueError(f"Unsupported output parser {output_parser_type}")
91 | config["output_parser"] = output_parser
92 | return config
93 |
94 |
95 | def _load_few_shot_prompt(config: dict) -> FewShotPromptTemplate:
96 | """Load the "few shot" prompt from the config."""
97 | # Load the suffix and prefix templates.
98 | config = _load_template("suffix", config)
99 | config = _load_template("prefix", config)
100 | # Load the example prompt.
101 | if "example_prompt_path" in config:
102 | if "example_prompt" in config:
103 | raise ValueError(
104 | "Only one of example_prompt and example_prompt_path should "
105 | "be specified."
106 | )
107 | config["example_prompt"] = load_prompt(config.pop("example_prompt_path"))
108 | else:
109 | config["example_prompt"] = load_prompt_from_config(config["example_prompt"])
110 | # Load the examples.
111 | config = _load_examples(config)
112 | config = _load_output_parser(config)
113 | return FewShotPromptTemplate(**config)
114 |
115 |
116 | def _load_prompt(config: dict) -> PromptTemplate:
117 | """Load the prompt template from config."""
118 | # Load the template from disk if necessary.
119 | config = _load_template("template", config)
120 | config = _load_output_parser(config)
121 |
122 | template_format = config.get("template_format", "f-string")
123 | if template_format == "jinja2":
124 | # Disabled due to:
125 | # https://github.com/langchain-ai/langchain/issues/4394
126 | raise ValueError(
127 | f"Loading templates with '{template_format}' format is no longer supported "
128 | f"since it can lead to arbitrary code execution. Please migrate to using "
129 | f"the 'f-string' template format, which does not suffer from this issue."
130 | )
131 |
132 | return PromptTemplate(**config)
133 |
134 |
135 | def load_prompt(path: Union[str, Path]) -> BasePromptTemplate:
136 | """Unified method for loading a prompt from LangChainHub or local fs."""
137 | if hub_result := try_load_from_hub(
138 | path, _load_prompt_from_file, "prompts", {"py", "json", "yaml"}
139 | ):
140 | return hub_result
141 | else:
142 | return _load_prompt_from_file(path)
143 |
144 |
145 | def _load_prompt_from_file(file: Union[str, Path]) -> BasePromptTemplate:
146 | """Load prompt from file."""
147 | # Convert file to a Path object.
148 | if isinstance(file, str):
149 | file_path = Path(file)
150 | else:
151 | file_path = file
152 | # Load from either json or yaml.
153 | if file_path.suffix == ".json":
154 |         with open(file_path, encoding="UTF-8") as f:
155 | config = json.load(f)
156 | elif file_path.suffix == ".yaml":
157 | with open(file_path, "r") as f:
158 | config = yaml.safe_load(f)
159 | else:
160 | raise ValueError(f"Got unsupported file type {file_path.suffix}")
161 | # Load the prompt from the config now.
162 | return load_prompt_from_config(config)
163 |
164 |
165 | type_to_loader_dict: Dict[str, Callable[[dict], BasePromptTemplate]] = {
166 | "prompt": _load_prompt,
167 | "few_shot": _load_few_shot_prompt,
168 | }
169 |
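170 | # Sketch (not in the original) of a minimal config accepted by
171 | # load_prompt_from_config; "_type" picks the loader, the rest feeds PromptTemplate:
172 | #
173 | #   load_prompt_from_config({
174 | #       "_type": "prompt",
175 | #       "template": "Say {thing}.",
176 | #       "input_variables": ["thing"],
177 | #   })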
--------------------------------------------------------------------------------
/lmchain/schema/__init__.py:
--------------------------------------------------------------------------------
1 | name = "schema"
--------------------------------------------------------------------------------
/lmchain/schema/__pycache__/__init__.cpython-311.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/virgo777/lmchain/3fc4a55df07100d4fe0beef4da89e3fbc6288938/lmchain/schema/__pycache__/__init__.cpython-311.pyc
--------------------------------------------------------------------------------
/lmchain/schema/__pycache__/document.cpython-311.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/virgo777/lmchain/3fc4a55df07100d4fe0beef4da89e3fbc6288938/lmchain/schema/__pycache__/document.cpython-311.pyc
--------------------------------------------------------------------------------
/lmchain/schema/__pycache__/messages.cpython-311.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/virgo777/lmchain/3fc4a55df07100d4fe0beef4da89e3fbc6288938/lmchain/schema/__pycache__/messages.cpython-311.pyc
--------------------------------------------------------------------------------
/lmchain/schema/__pycache__/output.cpython-311.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/virgo777/lmchain/3fc4a55df07100d4fe0beef4da89e3fbc6288938/lmchain/schema/__pycache__/output.cpython-311.pyc
--------------------------------------------------------------------------------
/lmchain/schema/__pycache__/output_parser.cpython-311.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/virgo777/lmchain/3fc4a55df07100d4fe0beef4da89e3fbc6288938/lmchain/schema/__pycache__/output_parser.cpython-311.pyc
--------------------------------------------------------------------------------
/lmchain/schema/__pycache__/prompt.cpython-311.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/virgo777/lmchain/3fc4a55df07100d4fe0beef4da89e3fbc6288938/lmchain/schema/__pycache__/prompt.cpython-311.pyc
--------------------------------------------------------------------------------
/lmchain/schema/__pycache__/prompt_template.cpython-311.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/virgo777/lmchain/3fc4a55df07100d4fe0beef4da89e3fbc6288938/lmchain/schema/__pycache__/prompt_template.cpython-311.pyc
--------------------------------------------------------------------------------
/lmchain/schema/agent.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | from typing import Any, Literal, Sequence, Union
4 |
5 | from lmchain.load.serializable import Serializable
6 | from lmchain.schema.messages import BaseMessage
7 |
8 |
9 | class AgentAction(Serializable):
10 | """A full description of an action for an ActionAgent to execute."""
11 |
12 | tool: str
13 | """The name of the Tool to execute."""
14 | tool_input: Union[str, dict]
15 | """The input to pass in to the Tool."""
16 | log: str
17 | """Additional information to log about the action.
18 | This log can be used in a few ways. First, it can be used to audit
19 | what exactly the LLM predicted to lead to this (tool, tool_input).
20 | Second, it can be used in future iterations to show the LLMs prior
21 | thoughts. This is useful when (tool, tool_input) does not contain
22 | full information about the LLM prediction (for example, any `thought`
23 | before the tool/tool_input)."""
24 | type: Literal["AgentAction"] = "AgentAction"
25 |
26 | def __init__(
27 | self, tool: str, tool_input: Union[str, dict], log: str, **kwargs: Any
28 | ):
29 | """Override init to support instantiation by position for backward compat."""
30 | super().__init__(tool=tool, tool_input=tool_input, log=log, **kwargs)
31 |
32 | @classmethod
33 | def is_lc_serializable(cls) -> bool:
34 | """Return whether or not the class is serializable."""
35 | return True
36 |
37 |
38 | class AgentActionMessageLog(AgentAction):
39 | message_log: Sequence[BaseMessage]
40 | """Similar to log, this can be used to pass along extra
41 | information about what exact messages were predicted by the LLM
42 | before parsing out the (tool, tool_input). This is again useful
43 | if (tool, tool_input) cannot be used to fully recreate the LLM
44 | prediction, and you need that LLM prediction (for future agent iteration).
45 | Compared to `log`, this is useful when the underlying LLM is a
46 | ChatModel (and therefore returns messages rather than a string)."""
47 | # Ignoring type because we're overriding the type from AgentAction.
48 | # And this is the correct thing to do in this case.
49 | # The type literal is used for serialization purposes.
50 | type: Literal["AgentActionMessageLog"] = "AgentActionMessageLog" # type: ignore
51 |
52 |
53 | class AgentFinish(Serializable):
54 | """The final return value of an ActionAgent."""
55 |
56 | return_values: dict
57 | """Dictionary of return values."""
58 | log: str
59 | """Additional information to log about the return value.
60 | This is used to pass along the full LLM prediction, not just the parsed out
61 | return value. For example, if the full LLM prediction was
62 | `Final Answer: 2` you may want to just return `2` as a return value, but pass
63 | along the full string as a `log` (for debugging or observability purposes).
64 | """
65 | type: Literal["AgentFinish"] = "AgentFinish"
66 |
67 | def __init__(self, return_values: dict, log: str, **kwargs: Any):
68 | """Override init to support instantiation by position for backward compat."""
69 | super().__init__(return_values=return_values, log=log, **kwargs)
70 |
71 | @classmethod
72 | def is_lc_serializable(cls) -> bool:
73 | """Return whether or not the class is serializable."""
74 | return True
75 |
--------------------------------------------------------------------------------
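
A minimal usage sketch for the classes above (not part of the package): the positional __init__ overrides mean both AgentAction and AgentFinish can be built without keyword arguments, which is exactly the backward compatibility the docstrings describe.

from lmchain.schema.agent import AgentAction, AgentFinish

# Sketch: constructing agent steps positionally.
action = AgentAction("get_weather", {"city_name": "shanghai"}, "Thought: call the weather tool")
print(action.tool, action.tool_input)   # get_weather {'city_name': 'shanghai'}

finish = AgentFinish({"output": "2"}, "Final Answer: 2")
print(finish.return_values["output"])   # 2

--------------------------------------------------------------------------------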
/lmchain/schema/document.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | import asyncio
4 | from abc import ABC, abstractmethod
5 | from functools import partial
6 | from typing import Any, Literal, Sequence
7 |
8 | from lmchain.load.serializable import Serializable
9 | from pydantic.v1 import Field
10 |
11 | class Document(Serializable):
12 | """Class for storing a piece of text and associated metadata."""
13 |
14 | page_content: str
15 | """String text."""
16 | metadata: dict = Field(default_factory=dict)
17 | """Arbitrary metadata about the page content (e.g., source, relationships to other
18 | documents, etc.).
19 | """
20 | type: Literal["Document"] = "Document"
21 |
22 | @classmethod
23 | def is_lc_serializable(cls) -> bool:
24 | """Return whether this class is serializable."""
25 | return True
26 |
27 |
28 | class BaseDocumentTransformer(ABC):
29 | """Abstract base class for document transformation systems.
30 |
31 | A document transformation system takes a sequence of Documents and returns a
32 | sequence of transformed Documents.
33 |
34 | Example:
35 | .. code-block:: python
36 |
37 | class EmbeddingsRedundantFilter(BaseDocumentTransformer, BaseModel):
38 | embeddings: Embeddings
39 | similarity_fn: Callable = cosine_similarity
40 | similarity_threshold: float = 0.95
41 |
42 | class Config:
43 | arbitrary_types_allowed = True
44 |
45 | def transform_documents(
46 | self, documents: Sequence[Document], **kwargs: Any
47 | ) -> Sequence[Document]:
48 | stateful_documents = get_stateful_documents(documents)
49 | embedded_documents = _get_embeddings_from_stateful_docs(
50 | self.embeddings, stateful_documents
51 | )
52 | included_idxs = _filter_similar_embeddings(
53 | embedded_documents, self.similarity_fn, self.similarity_threshold
54 | )
55 | return [stateful_documents[i] for i in sorted(included_idxs)]
56 |
57 | async def atransform_documents(
58 | self, documents: Sequence[Document], **kwargs: Any
59 | ) -> Sequence[Document]:
60 | raise NotImplementedError
61 |
62 | """ # noqa: E501
63 |
64 | @abstractmethod
65 | def transform_documents(
66 | self, documents: Sequence[Document], **kwargs: Any
67 | ) -> Sequence[Document]:
68 | """Transform a list of documents.
69 |
70 | Args:
71 | documents: A sequence of Documents to be transformed.
72 |
73 | Returns:
74 | A list of transformed Documents.
75 | """
76 |
77 | async def atransform_documents(
78 | self, documents: Sequence[Document], **kwargs: Any
79 | ) -> Sequence[Document]:
80 | """Asynchronously transform a list of documents.
81 |
82 | Args:
83 | documents: A sequence of Documents to be transformed.
84 |
85 | Returns:
86 | A list of transformed Documents.
87 | """
88 | return await asyncio.get_running_loop().run_in_executor(
89 | None, partial(self.transform_documents, **kwargs), documents
90 | )
91 |
--------------------------------------------------------------------------------
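
A minimal concrete transformer, sketched here to show the contract above; the class name and behavior are illustrative only. Because atransform_documents has a default implementation, the subclass gets a threaded async variant for free.

from typing import Any, Sequence

from lmchain.schema.document import BaseDocumentTransformer, Document

class WhitespaceNormalizer(BaseDocumentTransformer):
    """Collapses runs of whitespace in each document's page_content."""

    def transform_documents(
        self, documents: Sequence[Document], **kwargs: Any
    ) -> Sequence[Document]:
        return [
            Document(page_content=" ".join(d.page_content.split()), metadata=d.metadata)
            for d in documents
        ]

docs = [Document(page_content="  hello \n  world  ")]
print(WhitespaceNormalizer().transform_documents(docs)[0].page_content)  # "hello world"

--------------------------------------------------------------------------------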
/lmchain/schema/memory.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | from abc import ABC, abstractmethod
4 | from typing import Any, Dict, List
5 |
6 |
7 | class BaseMemory(ABC):
8 | """Abstract base class for memory in Chains.
9 |
10 | Memory refers to state in Chains. Memory can be used to store information about
11 | past executions of a Chain and inject that information into the inputs of
12 | future executions of the Chain. For example, for conversational Chains Memory
13 | can be used to store conversations and automatically add them to future model
14 | prompts so that the model has the necessary context to respond coherently to
15 | the latest input.
16 |
17 | Example:
18 | .. code-block:: python
19 |
20 | class SimpleMemory(BaseMemory):
21 | memories: Dict[str, Any] = dict()
22 |
23 | @property
24 | def memory_variables(self) -> List[str]:
25 | return list(self.memories.keys())
26 |
27 | def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, str]:
28 | return self.memories
29 |
30 | def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
31 | pass
32 |
33 | def clear(self) -> None:
34 | pass
35 | """ # noqa: E501
36 |
37 | class Config:
38 | """Configuration for this pydantic object."""
39 |
40 | arbitrary_types_allowed = True
41 |
42 | @property
43 | @abstractmethod
44 | def memory_variables(self) -> List[str]:
45 | """The string keys this memory class will add to chain inputs."""
46 |
47 | @abstractmethod
48 | def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
49 | """Return key-value pairs given the text input to the chain."""
50 |
51 | @abstractmethod
52 | def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
53 | """Save the context of this chain run to memory."""
54 |
55 | @abstractmethod
56 | def clear(self) -> None:
57 | """Clear memory contents."""
58 |
--------------------------------------------------------------------------------
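
A sketch of a concrete memory, along the lines of the SimpleMemory example in the docstring above; BufferMemory is a hypothetical name, not a class shipped by the package.

from typing import Any, Dict, List

from lmchain.schema.memory import BaseMemory

class BufferMemory(BaseMemory):
    """Accumulates (input, output) turns and exposes them as one 'history' string."""

    def __init__(self) -> None:
        self.buffer: List[str] = []

    @property
    def memory_variables(self) -> List[str]:
        return ["history"]

    def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        return {"history": "\n".join(self.buffer)}

    def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
        self.buffer.append(f"Human: {inputs['input']}\nAI: {outputs['output']}")

    def clear(self) -> None:
        self.buffer.clear()

memory = BufferMemory()
memory.save_context({"input": "hi"}, {"output": "hello"})
print(memory.load_memory_variables({})["history"])

--------------------------------------------------------------------------------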
/lmchain/schema/prompt.py:
--------------------------------------------------------------------------------
1 | # This code defines an abstract base class named PromptValue, which represents the input to any language model.
2 | # The class inherits from Serializable and ABC (Abstract Base Class), so it is a serializable abstract base class.
3 |
4 |
5 | # Import annotations from the __future__ module to enable postponed evaluation of type annotations.
6 | from __future__ import annotations
7 |
8 | # Import ABC (abstract base class) and the abstractmethod decorator from the abc module.
9 | from abc import ABC, abstractmethod
10 | # Import the List type from the typing module for type annotations.
11 | from typing import List
12 |
13 | # Import the Serializable class from lmchain.load.serializable for serializing and deserializing objects.
14 | from lmchain.load.serializable import Serializable
15 | # Import the BaseMessage class from lmchain.schema.messages as the base class for messages.
16 | from lmchain.schema.messages import BaseMessage
17 |
18 |
19 | # Define an abstract base class named PromptValue that inherits from Serializable and ABC.
20 | class PromptValue(Serializable, ABC):
21 | """Base abstract class for inputs to any language model.
22 |
23 | PromptValues can be converted to both LLM (pure text-generation) inputs and
24 | ChatModel inputs.
25 | """
26 |
27 | # Class method returning a boolean that indicates whether this class is serializable. Here it always returns True.
28 | @classmethod
29 | def is_lc_serializable(cls) -> bool:
30 | """Return whether this class is serializable."""
31 | return True
32 |
33 | # Abstract method to be implemented by subclasses. Returns the prompt value as a string.
34 | @abstractmethod
35 | def to_string(self) -> str:
36 | """Return prompt value as string."""
37 |
38 | # Abstract method to be implemented by subclasses. Returns the prompt as a list of BaseMessage objects.
39 | @abstractmethod
40 | def to_messages(self) -> List[BaseMessage]:
41 | """Return prompt as a list of Messages."""
42 |
--------------------------------------------------------------------------------
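
A sketch of the simplest concrete PromptValue. StringPromptValue is a hypothetical name, and it assumes lmchain.schema.messages defines HumanMessage (mirroring the LangChain message classes this package follows).

from typing import List

from lmchain.schema.messages import BaseMessage, HumanMessage  # HumanMessage assumed to exist
from lmchain.schema.prompt import PromptValue

class StringPromptValue(PromptValue):
    """Wraps a plain string so it can feed both LLMs and ChatModels."""

    text: str

    def to_string(self) -> str:
        return self.text

    def to_messages(self) -> List[BaseMessage]:
        return [HumanMessage(content=self.text)]

pv = StringPromptValue(text="What is 2 + 2?")
print(pv.to_string())    # plain-text form for an LLM
print(pv.to_messages())  # message form for a ChatModel

--------------------------------------------------------------------------------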
/lmchain/schema/runnable/__init__.py:
--------------------------------------------------------------------------------
1 | name = "schema.runnable"
--------------------------------------------------------------------------------
/lmchain/schema/runnable/config.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/virgo777/lmchain/3fc4a55df07100d4fe0beef4da89e3fbc6288938/lmchain/schema/runnable/config.py
--------------------------------------------------------------------------------
/lmchain/schema/schema.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | from typing import Any, Literal, Sequence, Union
4 |
5 | from lmchain.load.serializable import Serializable
6 | from lmchain.schema.messages import BaseMessage
7 |
8 |
9 | class AgentAction(Serializable):
10 | """A full description of an action for an ActionAgent to execute."""
11 |
12 | tool: str
13 | """The name of the Tool to execute."""
14 | tool_input: Union[str, dict]
15 | """The input to pass in to the Tool."""
16 | log: str
17 | """Additional information to log about the action.
18 | This log can be used in a few ways. First, it can be used to audit
19 | what exactly the LLM predicted to lead to this (tool, tool_input).
20 | Second, it can be used in future iterations to show the LLM's prior
21 | thoughts. This is useful when (tool, tool_input) does not contain
22 | full information about the LLM prediction (for example, any `thought`
23 | before the tool/tool_input)."""
24 | type: Literal["AgentAction"] = "AgentAction"
25 |
26 | def __init__(
27 | self, tool: str, tool_input: Union[str, dict], log: str, **kwargs: Any
28 | ):
29 | """Override init to support instantiation by position for backward compat."""
30 | super().__init__(tool=tool, tool_input=tool_input, log=log, **kwargs)
31 |
32 | @classmethod
33 | def is_lc_serializable(cls) -> bool:
34 | """Return whether or not the class is serializable."""
35 | return True
36 |
37 |
38 | class AgentActionMessageLog(AgentAction):
39 | message_log: Sequence[BaseMessage]
40 | """Similar to log, this can be used to pass along extra
41 | information about what exact messages were predicted by the LLM
42 | before parsing out the (tool, tool_input). This is again useful
43 | if (tool, tool_input) cannot be used to fully recreate the LLM
44 | prediction, and you need that LLM prediction (for future agent iteration).
45 | Compared to `log`, this is useful when the underlying LLM is a
46 | ChatModel (and therefore returns messages rather than a string)."""
47 | # Ignoring type because we're overriding the type from AgentAction.
48 | # And this is the correct thing to do in this case.
49 | # The type literal is used for serialization purposes.
50 | type: Literal["AgentActionMessageLog"] = "AgentActionMessageLog" # type: ignore
51 |
52 |
53 | class AgentFinish(Serializable):
54 | """The final return value of an ActionAgent."""
55 |
56 | return_values: dict
57 | """Dictionary of return values."""
58 | log: str
59 | """Additional information to log about the return value.
60 | This is used to pass along the full LLM prediction, not just the parsed out
61 | return value. For example, if the full LLM prediction was
62 | `Final Answer: 2` you may want to just return `2` as a return value, but pass
63 | along the full string as a `log` (for debugging or observability purposes).
64 | """
65 | type: Literal["AgentFinish"] = "AgentFinish"
66 |
67 | def __init__(self, return_values: dict, log: str, **kwargs: Any):
68 | """Override init to support instantiation by position for backward compat."""
69 | super().__init__(return_values=return_values, log=log, **kwargs)
70 |
71 | @classmethod
72 | def is_lc_serializable(cls) -> bool:
73 | """Return whether or not the class is serializable."""
74 | return True
75 |
--------------------------------------------------------------------------------
/lmchain/tool_register.py:
--------------------------------------------------------------------------------
1 |
2 |
3 | import inspect
4 | import traceback
5 | from copy import deepcopy
6 | from pprint import pformat
7 | from types import GenericAlias
8 | from typing import get_origin, Annotated
9 |
10 | _TOOL_HOOKS = {}
11 | _TOOL_DESCRIPTIONS = {}
12 |
13 |
14 | def register_tool(func: callable):
15 | tool_name = func.__name__
16 | tool_description = inspect.getdoc(func).strip()
17 | python_params = inspect.signature(func).parameters
18 | tool_params = []
19 | for name, param in python_params.items():
20 | annotation = param.annotation
21 | if annotation is inspect.Parameter.empty:
22 | raise TypeError(f"Parameter `{name}` missing type annotation")
23 | if get_origin(annotation) != Annotated:
24 | raise TypeError(f"Annotation type for `{name}` must be typing.Annotated")
25 |
26 | typ, (description, required) = annotation.__origin__, annotation.__metadata__
27 | typ: str = str(typ) if isinstance(typ, GenericAlias) else typ.__name__
28 | if not isinstance(description, str):
29 | raise TypeError(f"Description for `{name}` must be a string")
30 | if not isinstance(required, bool):
31 | raise TypeError(f"Required for `{name}` must be a bool")
32 |
33 | tool_params.append({
34 | "name": name,
35 | "description": description,
36 | "type": typ,
37 | "required": required
38 | })
39 | tool_def = {
40 | "name": tool_name,
41 | "description": tool_description,
42 | "params": tool_params
43 | }
44 |
45 | # print("[registered tool] " + pformat(tool_def))
46 | _TOOL_HOOKS[tool_name] = func
47 | _TOOL_DESCRIPTIONS[tool_name] = tool_def
48 |
49 | return func
50 |
51 |
52 | def dispatch_tool(tool_name: str, tool_params: dict) -> str:
53 | if tool_name not in _TOOL_HOOKS:
54 | return f"Tool `{tool_name}` not found. Please use a provided tool."
55 | tool_call = _TOOL_HOOKS[tool_name]
56 | try:
57 | ret = tool_call(**tool_params)
58 | except:
59 | ret = traceback.format_exc()
60 | return str(ret)
61 |
62 |
63 | def get_tools() -> dict:
64 | return deepcopy(_TOOL_DESCRIPTIONS)
65 |
66 |
67 | # Tool Definitions
68 |
69 | # @register_tool
70 | # def random_number_generator(
71 | # seed: Annotated[int, 'The random seed used by the generator', True],
72 | # range: Annotated[tuple[int, int], 'The range of the generated numbers', True],
73 | # ) -> int:
74 | # """
75 | # Generates a random number x, s.t. range[0] <= x < range[1]
76 | # """
77 | # if not isinstance(seed, int):
78 | # raise TypeError("Seed must be an integer")
79 | # if not isinstance(range, tuple):
80 | # raise TypeError("Range must be a tuple")
81 | # if not isinstance(range[0], int) or not isinstance(range[1], int):
82 | # raise TypeError("Range must be a tuple of integers")
83 | #
84 | # import random
85 | # return random.Random(seed).randint(*range)
86 | #
87 | #
88 | # @register_tool
89 | # def get_weather(
90 | # city_name: Annotated[str, 'The name of the city to be queried', True],
91 | # ) -> str:
92 | # """
93 | # Get the current weather for `city_name`
94 | # """
95 | #
96 | # if not isinstance(city_name, str):
97 | # raise TypeError("City name must be a string")
98 | #
99 | # key_selection = {
100 | # "current_condition": ["temp_C", "FeelsLikeC", "humidity", "weatherDesc", "observation_time"],
101 | # }
102 | # import requests
103 | # try:
104 | # resp = requests.get(f"https://wttr.in/{city_name}?format=j1")
105 | # resp.raise_for_status()
106 | # resp = resp.json()
107 | # ret = {k: {_v: resp[k][0][_v] for _v in v} for k, v in key_selection.items()}
108 | # except:
109 | # import traceback
110 | # ret = "Error encountered while fetching weather data!\n" + traceback.format_exc()
111 | #
112 | # return str(ret)
113 | #
114 | #
115 | # @register_tool
116 | # def get_customer_weather(location: Annotated[str, "Name of the location to query, written in Chinese", True] = ""):
117 | # """A hand-written weather lookup function."""
118 | #
119 | # if location == "上海": # Shanghai
120 | # return 23.0
121 | # elif location == "南京": # Nanjing
122 | # return 25.0
123 | # else:
124 | # return "未查询相关内容" # "No matching information found"
125 | #
126 | #
127 | # @register_tool
128 | # def get_random_fun(location: Annotated[str, "a random parameter", True] = ""):
129 | # """A decoy function written to confuse tool selection."""
130 | # location = location
131 | # return "你上当啦" # "You've been fooled!"
132 |
133 |
134 | if __name__ == "__main__":
135 | print(dispatch_tool("get_weather", {"city_name": "shanghai"}))  # all sample tools above are commented out, so this returns a 'not found' message
136 | print(get_tools())
137 |
--------------------------------------------------------------------------------
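
A usage sketch for the registry above. get_time is a made-up example tool; the three fields of each Annotated parameter (type, description string, required flag) are exactly what register_tool unpacks into the tool schema.

from typing import Annotated

from lmchain.tool_register import dispatch_tool, get_tools, register_tool

@register_tool
def get_time(
    timezone: Annotated[str, 'IANA timezone name, e.g. "Asia/Shanghai"', True],
) -> str:
    """Return the current time in the given timezone."""
    from datetime import datetime
    from zoneinfo import ZoneInfo
    return datetime.now(ZoneInfo(timezone)).isoformat()

print(get_tools()["get_time"]["params"])  # schema extracted from the Annotated metadata
print(dispatch_tool("get_time", {"timezone": "Asia/Shanghai"}))

--------------------------------------------------------------------------------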
/lmchain/tools/__init__.py:
--------------------------------------------------------------------------------
1 | name = "tools"
--------------------------------------------------------------------------------
/lmchain/tools/tool_register.py:
--------------------------------------------------------------------------------
1 | import inspect
2 | import traceback
3 | from copy import deepcopy
4 | from pprint import pformat
5 | from types import GenericAlias
6 | from typing import get_origin, Annotated
7 |
8 | _TOOL_HOOKS = {}
9 | _TOOL_DESCRIPTIONS = {}
10 |
11 |
12 | def register_tool(func: callable):
13 | tool_name = func.__name__
14 | tool_description = inspect.getdoc(func).strip()
15 | python_params = inspect.signature(func).parameters
16 | tool_params = []
17 | for name, param in python_params.items():
18 | annotation = param.annotation
19 | if annotation is inspect.Parameter.empty:
20 | raise TypeError(f"Parameter `{name}` missing type annotation")
21 | if get_origin(annotation) != Annotated:
22 | raise TypeError(f"Annotation type for `{name}` must be typing.Annotated")
23 |
24 | typ, (description, required) = annotation.__origin__, annotation.__metadata__
25 | typ: str = str(typ) if isinstance(typ, GenericAlias) else typ.__name__
26 | if not isinstance(description, str):
27 | raise TypeError(f"Description for `{name}` must be a string")
28 | if not isinstance(required, bool):
29 | raise TypeError(f"Required for `{name}` must be a bool")
30 |
31 | tool_params.append({
32 | "name": name,
33 | "description": description,
34 | "type": typ,
35 | "required": required
36 | })
37 | tool_def = {
38 | "name": tool_name,
39 | "description": tool_description,
40 | "params": tool_params
41 | }
42 |
43 | # print("[registered tool] " + pformat(tool_def))
44 | _TOOL_HOOKS[tool_name] = func
45 | _TOOL_DESCRIPTIONS[tool_name] = tool_def
46 |
47 | return func
48 |
49 |
50 | def dispatch_tool(tool_name: str, tool_params: dict) -> str:
51 | if tool_name not in _TOOL_HOOKS:
52 | return f"Tool `{tool_name}` not found. Please use a provided tool."
53 | tool_call = _TOOL_HOOKS[tool_name]
54 | try:
55 | ret = tool_call(**tool_params)
56 | except:
57 | ret = traceback.format_exc()
58 | return str(ret)
59 |
60 |
61 | def get_tools() -> dict:
62 | return deepcopy(_TOOL_DESCRIPTIONS)
63 |
64 |
65 | # Tool Definitions
66 |
67 | @register_tool
68 | def random_number_generator(
69 | seed: Annotated[int, 'The random seed used by the generator', True],
70 | range: Annotated[tuple[int, int], 'The range of the generated numbers', True],
71 | ) -> int:
72 | """
73 | Generates a random number x, s.t. range[0] <= x < range[1]
74 | """
75 | if not isinstance(seed, int):
76 | raise TypeError("Seed must be an integer")
77 | if not isinstance(range, tuple):
78 | raise TypeError("Range must be a tuple")
79 | if not isinstance(range[0], int) or not isinstance(range[1], int):
80 | raise TypeError("Range must be a tuple of integers")
81 |
82 | import random
83 | return random.Random(seed).randint(*range)
84 |
85 |
86 | @register_tool
87 | def get_weather(
88 | city_name: Annotated[str, 'The name of the city to be queried', True],
89 | ) -> str:
90 | """
91 | Get the current weather for `city_name`
92 | """
93 |
94 | if not isinstance(city_name, str):
95 | raise TypeError("City name must be a string")
96 |
97 | key_selection = {
98 | "current_condition": ["temp_C", "FeelsLikeC", "humidity", "weatherDesc", "observation_time"],
99 | }
100 | import requests
101 | try:
102 | resp = requests.get(f"https://wttr.in/{city_name}?format=j1")
103 | resp.raise_for_status()
104 | resp = resp.json()
105 | ret = {k: {_v: resp[k][0][_v] for _v in v} for k, v in key_selection.items()}
106 | except:
107 | import traceback
108 | ret = "Error encountered while fetching weather data!\n" + traceback.format_exc()
109 |
110 | return str(ret)
111 |
112 |
113 | if __name__ == "__main__":
114 | # print(dispatch_tool("get_weather", {"city_name": "beijing"}))
115 | tools = (get_tools())
116 | import zhipuai as zhipuai
117 |
118 | zhipuai.api_key = "1f565e40af1198e11ff1fd8a5b42771d.SjNfezc40YFsz2KC"  # API key obtained from the console
119 |
120 | query = "今天shanghai的天气是什么?"  # "What is the weather in shanghai today?"; the Chinese prompt below asks the model to pick a tool from `tools` and reply with JSON only
121 | prompt = f"""
122 | 你现在是一个专业的人工智能助手,你现在的需求是{query}。而你需要借助于工具在{tools}中找到对应的函数,用json格式返回对应的函数名和需要的参数。
123 |
124 | 只返回json格式的函数名和需要的参数,不要做描述。
125 |
126 | 如果没有找到合适的函数,则返回:'未找到合适参数,请提供更详细的描述。'
127 | """
128 |
129 | from lmchain.agents import llmMultiAgent
130 |
131 | llm = llmMultiAgent.AgentZhipuAI()
132 | res = llm(prompt)
133 | print(res)
134 |
135 | import json
136 |
137 | res_dict = json.loads(res)  # first pass: the model reply is itself a JSON-encoded string
138 | res_dict = json.loads(res_dict)  # second pass: decode the inner object holding function_name and params
139 |
140 | print(dispatch_tool(tool_name=res_dict["function_name"], tool_params=res_dict["params"]))
141 |
142 |
143 |
144 |
--------------------------------------------------------------------------------
/lmchain/utils/__init__.py:
--------------------------------------------------------------------------------
1 | name = "utils"
--------------------------------------------------------------------------------
/lmchain/utils/__pycache__/__init__.cpython-311.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/virgo777/lmchain/3fc4a55df07100d4fe0beef4da89e3fbc6288938/lmchain/utils/__pycache__/__init__.cpython-311.pyc
--------------------------------------------------------------------------------
/lmchain/utils/__pycache__/formatting.cpython-311.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/virgo777/lmchain/3fc4a55df07100d4fe0beef4da89e3fbc6288938/lmchain/utils/__pycache__/formatting.cpython-311.pyc
--------------------------------------------------------------------------------
/lmchain/utils/__pycache__/math.cpython-311.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/virgo777/lmchain/3fc4a55df07100d4fe0beef4da89e3fbc6288938/lmchain/utils/__pycache__/math.cpython-311.pyc
--------------------------------------------------------------------------------
/lmchain/utils/formatting.py:
--------------------------------------------------------------------------------
1 | """Utilities for formatting strings."""
2 | from string import Formatter
3 | from typing import Any, List, Mapping, Sequence, Union
4 |
5 |
6 | class StrictFormatter(Formatter):
7 | """A subclass of formatter that checks for extra keys."""
8 |
9 | def check_unused_args(
10 | self,
11 | used_args: Sequence[Union[int, str]],
12 | args: Sequence,
13 | kwargs: Mapping[str, Any],
14 | ) -> None:
15 | """Check to see if extra parameters are passed."""
16 | extra = set(kwargs).difference(used_args)
17 | if extra:
18 | raise KeyError(extra)
19 |
20 | def vformat(
21 | self, format_string: str, args: Sequence, kwargs: Mapping[str, Any]
22 | ) -> str:
23 | """Check that no arguments are provided."""
24 | if len(args) > 0:
25 | raise ValueError(
26 | "No arguments should be provided, "
27 | "everything should be passed as keyword arguments."
28 | )
29 | return super().vformat(format_string, args, kwargs)
30 |
31 | def validate_input_variables(
32 | self, format_string: str, input_variables: List[str]
33 | ) -> None:
34 | dummy_inputs = {input_variable: "foo" for input_variable in input_variables}
35 | super().format(format_string, **dummy_inputs)
36 |
37 |
38 | formatter = StrictFormatter()
39 |
--------------------------------------------------------------------------------
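
A short sketch of how the formatter singleton behaves; both failure modes below come straight from the overrides above.

from lmchain.utils.formatting import formatter

print(formatter.format("Hello {name}", name="world"))  # "Hello world"

try:
    formatter.format("Hello {name}", name="world", extra=1)
except KeyError as err:
    print("unused kwarg rejected:", err)

try:
    formatter.format("Hello {0}", "world")
except ValueError as err:
    print("positional args rejected:", err)

--------------------------------------------------------------------------------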
/lmchain/utils/input.py:
--------------------------------------------------------------------------------
1 | """Handle chained inputs."""
2 | from typing import Dict, List, Optional, TextIO
3 |
4 | _TEXT_COLOR_MAPPING = {
5 | "blue": "36;1",
6 | "yellow": "33;1",
7 | "pink": "38;5;200",
8 | "green": "32;1",
9 | "red": "31;1",
10 | }
11 |
12 |
13 | def get_color_mapping(
14 | items: List[str], excluded_colors: Optional[List] = None
15 | ) -> Dict[str, str]:
16 | """Get mapping for items to a support color."""
17 | colors = list(_TEXT_COLOR_MAPPING.keys())
18 | if excluded_colors is not None:
19 | colors = [c for c in colors if c not in excluded_colors]
20 | color_mapping = {item: colors[i % len(colors)] for i, item in enumerate(items)}
21 | return color_mapping
22 |
23 |
24 | def get_colored_text(text: str, color: str) -> str:
25 | """Get colored text."""
26 | color_str = _TEXT_COLOR_MAPPING[color]
27 | return f"\u001b[{color_str}m\033[1;3m{text}\u001b[0m"
28 |
29 |
30 | def get_bolded_text(text: str) -> str:
31 | """Get bolded text."""
32 | return f"\033[1m{text}\033[0m"
33 |
34 |
35 | def print_text(
36 | text: str, color: Optional[str] = None, end: str = "", file: Optional[TextIO] = None
37 | ) -> None:
38 | """Print text with highlighting and no end characters."""
39 | text_to_print = get_colored_text(text, color) if color else text
40 | print(text_to_print, end=end, file=file)
41 | if file:
42 | file.flush()  # ensure all printed content is written to the file
43 |
--------------------------------------------------------------------------------
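
A quick sketch of the helpers above: items are assigned colors round-robin from the supported palette, and print_text emits ANSI-highlighted output.

from lmchain.utils.input import get_color_mapping, print_text

mapping = get_color_mapping(["step1", "step2"], excluded_colors=["red"])
print(mapping)  # e.g. {'step1': 'blue', 'step2': 'yellow'}

for item, color in mapping.items():
    print_text(item + "\n", color=color)

--------------------------------------------------------------------------------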
/lmchain/utils/loading.py:
--------------------------------------------------------------------------------
1 | """Utilities for loading configurations from langchain-hub."""
2 |
3 | import os
4 | import re
5 | import tempfile
6 | from pathlib import Path, PurePosixPath
7 | from typing import Any, Callable, Optional, Set, TypeVar, Union
8 | from urllib.parse import urljoin
9 |
10 | import requests
11 |
12 | DEFAULT_REF = os.environ.get("LANGCHAIN_HUB_DEFAULT_REF", "master")
13 | URL_BASE = os.environ.get(
14 | "LANGCHAIN_HUB_URL_BASE",
15 | "https://raw.githubusercontent.com/hwchase17/langchain-hub/{ref}/",
16 | )
17 | HUB_PATH_RE = re.compile(r"lc(?P<ref>@[^:]+)?://(?P<path>.*)")
18 |
19 | T = TypeVar("T")
20 |
21 |
22 | def try_load_from_hub(
23 | path: Union[str, Path],
24 | loader: Callable[[str], T],
25 | valid_prefix: str,
26 | valid_suffixes: Set[str],
27 | **kwargs: Any,
28 | ) -> Optional[T]:
29 | """Load configuration from hub. Returns None if path is not a hub path."""
30 | if not isinstance(path, str) or not (match := HUB_PATH_RE.match(path)):
31 | return None
32 | ref, remote_path_str = match.groups()
33 | ref = ref[1:] if ref else DEFAULT_REF
34 | remote_path = Path(remote_path_str)
35 | if remote_path.parts[0] != valid_prefix:
36 | return None
37 | if remote_path.suffix[1:] not in valid_suffixes:
38 | raise ValueError(f"Unsupported file type, must be one of {valid_suffixes}.")
39 |
40 | # Using Path with URLs is not recommended, because on Windows
41 | # the backslash is used as the path separator, which can cause issues
42 | # when working with URLs that use forward slashes as the path separator.
43 | # Instead, use PurePosixPath to ensure that forward slashes are used as the
44 | # path separator, regardless of the operating system.
45 | full_url = urljoin(URL_BASE.format(ref=ref), PurePosixPath(remote_path).__str__())
46 |
47 | r = requests.get(full_url, timeout=5)
48 | if r.status_code != 200:
49 | raise ValueError(f"Could not find file at {full_url}")
50 | with tempfile.TemporaryDirectory() as tmpdirname:
51 | file = Path(tmpdirname) / remote_path.name
52 | with open(file, "wb") as f:
53 | f.write(r.content)
54 | return loader(str(file), **kwargs)
55 |
--------------------------------------------------------------------------------
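
A sketch of calling try_load_from_hub. The hub path and loader are illustrative; the call only succeeds if the file still exists in the langchain-hub repository, and a non-hub path simply returns None.

import json

from lmchain.utils.loading import try_load_from_hub

def load_json(path: str) -> dict:
    with open(path) as f:
        return json.load(f)

# Not a hub path, so this returns None rather than raising.
print(try_load_from_hub("local/prompt.json", load_json, "prompts", {"json"}))

# Hub-style path: lc[@ref]://<prefix>/<file>. Raises ValueError if the file is gone upstream.
config = try_load_from_hub(
    "lc://prompts/example/prompt.json",  # hypothetical hub path
    load_json,
    valid_prefix="prompts",
    valid_suffixes={"json"},
)

--------------------------------------------------------------------------------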
/lmchain/utils/math.py:
--------------------------------------------------------------------------------
1 | """Math utils."""
2 | import logging
3 | from typing import List, Optional, Tuple, Union
4 |
5 | import numpy as np
6 |
7 | logger = logging.getLogger(__name__)
8 |
9 | Matrix = Union[List[List[float]], List[np.ndarray], np.ndarray]
10 |
11 |
12 | def cosine_similarity(X: Matrix, Y: Matrix) -> np.ndarray:
13 | """Row-wise cosine similarity between two equal-width matrices."""
14 | if len(X) == 0 or len(Y) == 0:
15 | return np.array([])
16 |
17 | X = np.array(X)
18 | Y = np.array(Y)
19 | if X.shape[1] != Y.shape[1]:
20 | raise ValueError(
21 | f"Number of columns in X and Y must be the same. X has shape {X.shape} "
22 | f"and Y has shape {Y.shape}."
23 | )
24 | try:
25 | import simsimd as simd
26 |
27 | X = np.array(X, dtype=np.float32)
28 | Y = np.array(Y, dtype=np.float32)
29 | Z = 1 - simd.cdist(X, Y, metric="cosine")
30 | if isinstance(Z, float):
31 | return np.array([Z])
32 | return Z
33 | except ImportError:
34 | logger.info(
35 | "Unable to import simsimd, defaulting to NumPy implementation. If you want "
36 | "to use simsimd please install with `pip install simsimd`."
37 | )
38 | X_norm = np.linalg.norm(X, axis=1)
39 | Y_norm = np.linalg.norm(Y, axis=1)
40 | # Ignore divide by zero errors run time warnings as those are handled below.
41 | with np.errstate(divide="ignore", invalid="ignore"):
42 | similarity = np.dot(X, Y.T) / np.outer(X_norm, Y_norm)
43 | similarity[np.isnan(similarity) | np.isinf(similarity)] = 0.0
44 | return similarity
45 |
46 |
47 | def cosine_similarity_top_k(
48 | X: Matrix,
49 | Y: Matrix,
50 | top_k: Optional[int] = 5,
51 | score_threshold: Optional[float] = None,
52 | ) -> Tuple[List[Tuple[int, int]], List[float]]:
53 | """Row-wise cosine similarity with optional top-k and score threshold filtering.
54 |
55 | Args:
56 | X: Matrix.
57 | Y: Matrix, same width as X.
58 | top_k: Max number of results to return.
59 | score_threshold: Minimum cosine similarity of results.
60 |
61 | Returns:
62 | Tuple of two lists. First contains two-tuples of indices (X_idx, Y_idx),
63 | second contains corresponding cosine similarities.
64 | """
65 | if len(X) == 0 or len(Y) == 0:
66 | return [], []
67 | score_array = cosine_similarity(X, Y)
68 | score_threshold = score_threshold or -1.0
69 | score_array[score_array < score_threshold] = 0
70 | top_k = min(top_k or len(score_array), np.count_nonzero(score_array))
71 | top_k_idxs = np.argpartition(score_array, -top_k, axis=None)[-top_k:]
72 | top_k_idxs = top_k_idxs[np.argsort(score_array.ravel()[top_k_idxs])][::-1]
73 | ret_idxs = np.unravel_index(top_k_idxs, score_array.shape)
74 | scores = score_array.ravel()[top_k_idxs].tolist()
75 | return list(zip(*ret_idxs)), scores # type: ignore
76 |
--------------------------------------------------------------------------------
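
A worked toy example for the two functions above. With unit-length rows the numbers are easy to verify by hand: identical vectors score 1.0 and orthogonal ones 0.0.

import numpy as np

from lmchain.utils.math import cosine_similarity, cosine_similarity_top_k

X = [[1.0, 0.0], [0.0, 1.0]]
Y = [[1.0, 0.0], [1.0, 1.0]]

print(cosine_similarity(X, Y))
# [[1.         0.70710678]
#  [0.         0.70710678]]

idxs, scores = cosine_similarity_top_k(X, Y, top_k=2, score_threshold=0.5)
print(scores)  # [1.0, 0.7071...] -- the 0.0 entry is filtered by the threshold
print(idxs)    # [(0, 0), ...] -- the exact match first; the two 0.707 scores tie

--------------------------------------------------------------------------------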
/lmchain/vectorstores/__init__.py:
--------------------------------------------------------------------------------
1 | name = "vectorstores"
--------------------------------------------------------------------------------
/lmchain/vectorstores/__pycache__/vectorstore.cpython-311.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/virgo777/lmchain/3fc4a55df07100d4fe0beef4da89e3fbc6288938/lmchain/vectorstores/__pycache__/vectorstore.cpython-311.pyc
--------------------------------------------------------------------------------
/lmchain/vectorstores/chroma.py:
--------------------------------------------------------------------------------
1 | import json
2 |
3 | from langchain.docstore.document import Document
4 | from langchain.text_splitter import RecursiveCharacterTextSplitter
5 | from lmchain.embeddings import embeddings
6 | from lmchain.vectorstores import laiss
7 |
8 | from langchain.memory import ConversationBufferMemory
9 | from langchain.prompts import (
10 | ChatPromptTemplate, # 用于构建聊天模板的类
11 | MessagesPlaceholder, # 用于在模板中插入消息占位的类
12 | SystemMessagePromptTemplate, # 用于构建系统消息模板的类
13 | HumanMessagePromptTemplate # 用于构建人类消息模板的类
14 | )
15 | from langchain.chains import ConversationChain
16 |
17 | class Chroma:
18 | def __init__(self,documents,embedding_tool,chunk_size = 1280,chunk_overlap = 50,source = "这是一份辅助材料"):
19 | """
20 | :param documents: the input texts; each item is a single plain-text string
21 | :param chunk_size: number of characters in each chunk after splitting
22 | :param chunk_overlap: number of characters of overlap between adjacent chunks
23 | :param source: document name / document path (the default literal means "this is a supplementary document")
24 | """
25 | self.text_splitter = RecursiveCharacterTextSplitter( chunk_size=chunk_size, chunk_overlap=chunk_overlap)
26 | self.embedding_tool = embedding_tool
27 |
28 | self.lmaiss = laiss.LMASS()  # class that converts texts to vectors and computes query similarity against them
29 |
30 | self.documents = []
31 | self.vectorstores = []
32 |
33 | "---------------------------"
34 | for document in documents:
35 | document = [Document(page_content=document, metadata={"source": source})] #对输入的document进行格式化处理
36 | doc= self.text_splitter.split_documents(document) #根据
37 | self.documents.extend(doc)
38 |
39 | vector = self.lmaiss.from_documents(doc, embedding_class=self.embedding_tool)  # embed the split chunks so vector indexes stay aligned with self.documents
40 | self.vectorstores.extend(vector)
41 |
42 | # def __call__(self, query):
43 | # query_embedding = self.embedding_tool.embed_query(query)
44 | #
45 | # # find the stored vector closest to the query
46 | # close_id = self.lmaiss.get_similarity_vector_indexs(query_embedding, self.vectorstore, k=1)[0]
47 | # # look up the corresponding paragraph by id
48 | # doc = self.documents[close_id]
49 | #
50 | #
51 | # return doc
52 |
53 | def similarity_search(self, query):
54 | query_embedding = self.embedding_tool.embed_query(query)
55 |
56 | # find the stored vector closest to the query
57 | close_id = self.lmaiss.get_similarity_vector_indexs(query_embedding, self.vectorstores, k=1)[0]
58 | # look up the corresponding paragraph by id
59 | doc = self.documents[close_id]
60 | return doc
61 |
62 | def add_texts(self,texts,metadata = ""):
63 | for document in texts:
64 | document = [Document(page_content=document, metadata={"source": metadata})]  # wrap each raw text as a Document
65 | doc = self.text_splitter.split_documents(document)  # split into chunks per chunk_size/chunk_overlap
66 | self.documents.extend(doc)
67 |
68 | vector = self.lmaiss.from_documents(doc, embedding_class=self.embedding_tool)  # embed the split chunks so vector indexes stay aligned with self.documents
69 | self.vectorstores.extend(vector)
70 |
71 | return True
72 |
73 |
74 | def from_texts(texts,embeddings,source = ""):
75 | docsearch = Chroma(documents = texts,embedding_tool=embeddings,source = source)
76 | return docsearch
77 |
78 |
79 | # def from_texts(texts,embeddings):
80 | # embs = embeddings.embed_documents(texts=texts)
81 | # return embs
--------------------------------------------------------------------------------
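
A usage sketch for the wrapper above, assuming a working embedding tool such as GLMEmbedding from lmchain/vectorstores/embeddings.py (which needs network access) and the langchain splitter imported at the top of this file.

from lmchain.vectorstores import chroma
from lmchain.vectorstores.embeddings import GLMEmbedding

texts = [
    "LMchain is a toolkit for chaining large language models.",
    "The weather in Shanghai is mild in autumn.",
]

docsearch = chroma.from_texts(texts, GLMEmbedding(), source="demo")
doc = docsearch.similarity_search("What is LMchain?")
print(doc.page_content)

--------------------------------------------------------------------------------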
/lmchain/vectorstores/embeddings.py:
--------------------------------------------------------------------------------
1 | import warnings
2 | warnings.filterwarnings("ignore")
3 |
4 | import asyncio
5 | from abc import ABC, abstractmethod
6 | from typing import List, Tuple
7 |
8 |
9 | class Embeddings(ABC):
10 | """Interface for embedding models."""
11 |
12 | @abstractmethod
13 | def embed_documents(self, texts: List[str]) -> List[List[float]]:
14 | """Embed search docs."""
15 |
16 | @abstractmethod
17 | def embed_query(self, text: str) -> List[float]:
18 | """Embed query text."""
19 |
20 | async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
21 | """Asynchronous Embed search docs."""
22 | return await asyncio.get_running_loop().run_in_executor(
23 | None, self.embed_documents, texts
24 | )
25 |
26 | async def aembed_query(self, text: str) -> List[float]:
27 | """Asynchronous Embed query text."""
28 | return await asyncio.get_running_loop().run_in_executor(
29 | None, self.embed_query, text
30 | )
31 |
32 |
33 | # class LMEmbedding(Embeddings):
34 | # from modelscope.pipelines import pipeline
35 | # from modelscope.utils.constant import Tasks
36 | # pipeline_se = pipeline(Tasks.sentence_embedding, model='thomas/text2vec-base-chinese', model_revision='v1.0.0',
37 | # device="cuda")
38 | #
39 | # def _costruct_inputs(self, texts):
40 | # inputs = {
41 | # "source_sentence": texts
42 | # }
43 | #
44 | # return inputs
45 | #
46 | # def embed_documents(self, texts: List[str]) -> List[List[float]]:
47 | # """Embed search docs."""
48 | #
49 | # inputs = self._costruct_inputs(texts)
50 | # result_embeddings = self.pipeline_se(input=inputs)
51 | # return result_embeddings["text_embedding"]
52 | #
53 | # def embed_query(self, text: str) -> List[float]:
54 | # """Embed query text."""
55 | # inputs = self._costruct_inputs([text])
56 | # result_embeddings = self.pipeline_se(input=inputs)
57 | # return result_embeddings["text_embedding"]
58 |
59 |
60 | class GLMEmbedding(Embeddings):
61 | import zhipuai as zhipuai
62 | zhipuai.api_key = "1f565e40af1198e11ff1fd8a5b42771d.SjNfezc40YFsz2KC"  # API key obtained from the console
63 |
64 | def _costruct_inputs(self, texts):
65 | inputs = {
66 | "source_sentence": texts
67 | }
68 |
69 | return inputs
70 |
71 | aembeddings = []  # stores the embedding vectors collected by the concurrent aembed_query calls
72 | atexts = []
73 |
74 | def embed_documents(self, texts: List[str]) -> List[List[float]]:
75 | """Embed search docs."""
76 | result_embeddings = []
77 | for text in texts:
78 | embedding = self.embed_query(text)
79 | result_embeddings.append(embedding)
80 | return result_embeddings
81 |
82 | def embed_query(self, text: str) -> List[float]:
83 | """Embed query text."""
84 | result_embeddings = self.zhipuai.model_api.invoke(
85 | model="text_embedding", prompt=text)
86 | return result_embeddings["data"]["embedding"]
87 |
88 | def aembed_query(self, text: str) -> None:
89 | """Embed query text and append the result to the shared buffers (used by the threaded aembed_documents)."""
90 | result_embeddings = self.zhipuai.model_api.invoke(
91 | model="text_embedding", prompt=text)
92 | emb = result_embeddings["data"]["embedding"]
93 |
94 | self.aembeddings.append(emb)
95 | self.atexts.append(text)
96 |
97 | # Concurrent embedding retrieval: fan the texts out across threads in batches of thread_num.
98 | def aembed_documents(self, texts: List[str], thread_num=5, wait_sec=0.3) -> Tuple[List[List[float]], List[str]]:
99 | import threading
100 | text_length = len(texts)
101 | thread_batch = text_length // thread_num  # note: any remainder beyond thread_num * thread_batch is never embedded
102 |
103 | for i in range(thread_batch):
104 | start = i * thread_num
105 | end = (i + 1) * thread_num
106 |
107 | # create the thread list
108 | threads = []
109 | # create and start thread_num threads, each invoking the embedding model once
110 | for text in texts[start:end]:
111 | thread = threading.Thread(target=self.aembed_query, args=(text,))
112 | thread.start()
113 | threads.append(thread)
114 | for thread in threads:
115 | thread.join(wait_sec)  # wait for each thread with a timeout of wait_sec seconds (0.3 by default)
116 | return self.aembeddings, self.atexts
117 |
118 |
119 | if __name__ == '__main__':
120 | import time
121 |
122 | inputs = ["不可以,早晨喝牛奶不科学", "今天早晨喝牛奶不科学", "早晨喝牛奶不科学"] * 50  # Chinese sample sentences (the embedding model is Chinese-language), 150 items in total
123 |
124 | start_time = time.time()
125 | aembeddings, atexts = GLMEmbedding().aembed_documents(inputs, thread_num=5, wait_sec=0.3)  # aembed_documents returns (embeddings, texts)
126 | print(aembeddings)
127 | print(len(aembeddings))
128 | end_time = time.time()
129 | # compute and print the elapsed time
130 | execution_time = end_time - start_time
131 | print(f"函数执行时间: {execution_time} 秒")
132 | print("----------------------------------------------------------------------------------")
133 | start_time = time.time()
134 | aembeddings = (GLMEmbedding().embed_documents(inputs))
135 | print(len(aembeddings))
136 | end_time = time.time()
138 | # compute and print the elapsed time
138 | execution_time = end_time - start_time
139 | print(f"函数执行时间: {execution_time} 秒")
140 |
--------------------------------------------------------------------------------
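
A dependency-free sketch of the interface above: a deterministic embedder that is handy for unit tests. HashEmbedding is a hypothetical class, not part of the package.

import hashlib
from typing import List

from lmchain.vectorstores.embeddings import Embeddings

class HashEmbedding(Embeddings):
    """Maps each text to a fixed 4-dimensional vector derived from its SHA-256 digest."""

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        return [self.embed_query(t) for t in texts]

    def embed_query(self, text: str) -> List[float]:
        digest = hashlib.sha256(text.encode("utf-8")).digest()
        return [b / 255.0 for b in digest[:4]]

emb = HashEmbedding()
print(emb.embed_query("hello"))               # same input always gives the same vector
print(len(emb.embed_documents(["a", "b"])))   # 2

--------------------------------------------------------------------------------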
/lmchain/vectorstores/utils.py:
--------------------------------------------------------------------------------
1 | """Utility functions for working with vectors and vectorstores."""
2 |
3 | from enum import Enum
4 | from typing import List, Tuple, Type
5 |
6 | import numpy as np
7 |
8 | from lmchain.schema.document import Document
9 | from lmchain.utils.math import cosine_similarity
10 |
11 | class DistanceStrategy(str, Enum):
12 | """Enumerator of the Distance strategies for calculating distances
13 | between vectors."""
14 |
15 | EUCLIDEAN_DISTANCE = "EUCLIDEAN_DISTANCE"
16 | MAX_INNER_PRODUCT = "MAX_INNER_PRODUCT"
17 | DOT_PRODUCT = "DOT_PRODUCT"
18 | JACCARD = "JACCARD"
19 | COSINE = "COSINE"
20 |
21 |
22 | def maximal_marginal_relevance(
23 | query_embedding: np.ndarray,
24 | embedding_list: list,
25 | lambda_mult: float = 0.5,
26 | k: int = 4,
27 | ) -> List[int]:
28 | """Calculate maximal marginal relevance."""
29 | if min(k, len(embedding_list)) <= 0:
30 | return []
31 | if query_embedding.ndim == 1:
32 | query_embedding = np.expand_dims(query_embedding, axis=0)
33 | similarity_to_query = cosine_similarity(query_embedding, embedding_list)[0]
34 | most_similar = int(np.argmax(similarity_to_query))
35 | idxs = [most_similar]
36 | selected = np.array([embedding_list[most_similar]])
37 | while len(idxs) < min(k, len(embedding_list)):
38 | best_score = -np.inf
39 | idx_to_add = -1
40 | similarity_to_selected = cosine_similarity(embedding_list, selected)
41 | for i, query_score in enumerate(similarity_to_query):
42 | if i in idxs:
43 | continue
44 | redundant_score = max(similarity_to_selected[i])
45 | equation_score = (
46 | lambda_mult * query_score - (1 - lambda_mult) * redundant_score
47 | )
48 | if equation_score > best_score:
49 | best_score = equation_score
50 | idx_to_add = i
51 | idxs.append(idx_to_add)
52 | selected = np.append(selected, [embedding_list[idx_to_add]], axis=0)
53 | return idxs
54 |
55 |
56 | def filter_complex_metadata(
57 | documents: List[Document],
58 | *,
59 | allowed_types: Tuple[Type, ...] = (str, bool, int, float),
60 | ) -> List[Document]:
61 | """Filter out metadata types that are not supported for a vector store."""
62 | updated_documents = []
63 | for document in documents:
64 | filtered_metadata = {}
65 | for key, value in document.metadata.items():
66 | if not isinstance(value, allowed_types):
67 | continue
68 | filtered_metadata[key] = value
69 |
70 | document.metadata = filtered_metadata
71 | updated_documents.append(document)
72 |
73 | return updated_documents
74 |
--------------------------------------------------------------------------------
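
A toy run of both helpers above. With a low lambda_mult, maximal marginal relevance penalizes redundancy heavily, so it skips the near-duplicate and picks the orthogonal vector second; filter_complex_metadata drops the list-valued key.

import numpy as np

from lmchain.schema.document import Document
from lmchain.vectorstores.utils import filter_complex_metadata, maximal_marginal_relevance

query = np.array([1.0, 0.0])
candidates = [
    np.array([0.90, 0.10]),  # most similar to the query
    np.array([0.89, 0.11]),  # near-duplicate of the first
    np.array([0.00, 1.00]),  # orthogonal: adds diversity
]

print(maximal_marginal_relevance(query, candidates, lambda_mult=0.1, k=2))  # [0, 2]

doc = Document(page_content="x", metadata={"source": "a.txt", "tags": ["a", "b"]})
print(filter_complex_metadata([doc])[0].metadata)  # {'source': 'a.txt'}

--------------------------------------------------------------------------------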
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [project]
2 | name = "LMchain"
3 | version = "0.1.62"
4 | authors = [
5 | { name="xiaohuaWang", email="5847713@qq.com" },
6 | ]
7 | description = "A large language chain tools"
8 | readme = "README.md"
9 | requires-python = ">=3.10"
10 | classifiers = [
11 | "Programming Language :: Python :: 3",
12 | "License :: OSI Approved :: MIT License",
13 | "Operating System :: OS Independent",
14 | ]
15 |
16 | [project.urls]
17 | "Homepage" = "https://github.com/pypa/sampleproject"
18 | "Bug Tracker" = "https://github.com/pypa/sampleproject/issues"
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | import setuptools
2 | with open("README.md", "r") as fh:
3 | long_description = fh.read()
4 | setuptools.setup(
5 | name="lmchain", # 模块名称
6 | version="0.1.62", # 当前版本
7 | author="xiaohuaWang", # 作者
8 | author_email="5847713@qq.com", # 作者邮箱
9 | description="LMchain是一个专门适配大模型chain的工具包", # 模块简介
10 | long_description=long_description, # 模块详细介绍
11 | long_description_content_type="text/markdown", # 模块详细介绍格式
12 | # url="https://github.com/", # 模块github地址
13 | packages=setuptools.find_packages(), # 自动找到项目中导入的模块
14 | include_package_data=True,
15 | # 模块相关的元数据
16 | classifiers=[
17 | "Programming Language :: Python :: 3",
18 | "License :: OSI Approved :: MIT License",
19 | "Operating System :: OS Independent",
20 | ],
21 | # 依赖模块
22 | install_requires=[
23 | 'uvicorn', 'fastapi','typing',"numexpr","langchain","zhipuai","nltk"
24 | ],
25 | python_requires='>=3',
26 | )
27 |
--------------------------------------------------------------------------------
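
A short sketch of the packaging workflow these two files support, assuming the standard `build` tool is installed; the wheel name matches the 0.1.62 artifacts in dist/.

# python -m pip install build
# python -m build                                        # writes sdist + wheel into dist/
# python -m pip install dist/LMchain-0.1.62-py3-none-any.whl

--------------------------------------------------------------------------------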
/tool_register.py:
--------------------------------------------------------------------------------
1 | import inspect
2 | import traceback
3 | from copy import deepcopy
4 | from pprint import pformat
5 | from types import GenericAlias
6 | from typing import get_origin, Annotated
7 |
8 | _TOOL_HOOKS = {}
9 | _TOOL_DESCRIPTIONS = {}
10 |
11 |
12 | def register_tool(func: callable):
13 | tool_name = func.__name__
14 | tool_description = inspect.getdoc(func).strip()
15 | python_params = inspect.signature(func).parameters
16 | tool_params = []
17 | for name, param in python_params.items():
18 | annotation = param.annotation
19 | if annotation is inspect.Parameter.empty:
20 | raise TypeError(f"Parameter `{name}` missing type annotation")
21 | if get_origin(annotation) != Annotated:
22 | raise TypeError(f"Annotation type for `{name}` must be typing.Annotated")
23 |
24 | typ, (description, required) = annotation.__origin__, annotation.__metadata__
25 | typ: str = str(typ) if isinstance(typ, GenericAlias) else typ.__name__
26 | if not isinstance(description, str):
27 | raise TypeError(f"Description for `{name}` must be a string")
28 | if not isinstance(required, bool):
29 | raise TypeError(f"Required for `{name}` must be a bool")
30 |
31 | tool_params.append({
32 | "name": name,
33 | "description": description,
34 | "type": typ,
35 | "required": required
36 | })
37 | tool_def = {
38 | "name": tool_name,
39 | "description": tool_description,
40 | "params": tool_params
41 | }
42 |
43 | # print("[registered tool] " + pformat(tool_def))
44 | _TOOL_HOOKS[tool_name] = func
45 | _TOOL_DESCRIPTIONS[tool_name] = tool_def
46 |
47 | return func
48 |
49 |
50 | def dispatch_tool(tool_name: str, tool_params: dict) -> str:
51 | if tool_name not in _TOOL_HOOKS:
52 | return f"Tool `{tool_name}` not found. Please use a provided tool."
53 | tool_call = _TOOL_HOOKS[tool_name]
54 | try:
55 | ret = tool_call(**tool_params)
56 | except:
57 | ret = traceback.format_exc()
58 | return str(ret)
59 |
60 |
61 | def get_tools() -> dict:
62 | return deepcopy(_TOOL_DESCRIPTIONS)
63 |
64 |
65 | # Tool Definitions
66 |
67 | # @register_tool
68 | # def random_number_generator(
69 | # seed: Annotated[int, 'The random seed used by the generator', True],
70 | # range: Annotated[tuple[int, int], 'The range of the generated numbers', True],
71 | # ) -> int:
72 | # """
73 | # Generates a random number x, s.t. range[0] <= x < range[1]
74 | # """
75 | # if not isinstance(seed, int):
76 | # raise TypeError("Seed must be an integer")
77 | # if not isinstance(range, tuple):
78 | # raise TypeError("Range must be a tuple")
79 | # if not isinstance(range[0], int) or not isinstance(range[1], int):
80 | # raise TypeError("Range must be a tuple of integers")
81 | #
82 | # import random
83 | # return random.Random(seed).randint(*range)
84 | #
85 | #
86 | # @register_tool
87 | # def get_weather(
88 | # city_name: Annotated[str, 'The name of the city to be queried', True],
89 | # ) -> str:
90 | # """
91 | # Get the current weather for `city_name`
92 | # """
93 | #
94 | # if not isinstance(city_name, str):
95 | # raise TypeError("City name must be a string")
96 | #
97 | # key_selection = {
98 | # "current_condition": ["temp_C", "FeelsLikeC", "humidity", "weatherDesc", "observation_time"],
99 | # }
100 | # import requests
101 | # try:
102 | # resp = requests.get(f"https://wttr.in/{city_name}?format=j1")
103 | # resp.raise_for_status()
104 | # resp = resp.json()
105 | # ret = {k: {_v: resp[k][0][_v] for _v in v} for k, v in key_selection.items()}
106 | # except:
107 | # import traceback
108 | # ret = "Error encountered while fetching weather data!\n" + traceback.format_exc()
109 | #
110 | # return str(ret)
111 | #
112 | #
113 | # @register_tool
114 | # def get_customer_weather(location: Annotated[str, "Name of the location to query, written in Chinese", True] = ""):
115 | # """A hand-written weather lookup function."""
116 | #
117 | # if location == "上海": # Shanghai
118 | # return 23.0
119 | # elif location == "南京": # Nanjing
120 | # return 25.0
121 | # else:
122 | # return "未查询相关内容" # "No matching information found"
123 | #
124 | #
125 | # @register_tool
126 | # def get_random_fun(location: Annotated[str, "a random parameter", True] = ""):
127 | # """A decoy function written to confuse tool selection."""
128 | # location = location
129 | # return "你上当啦" # "You've been fooled!"
130 |
131 |
132 | if __name__ == "__main__":
133 | print(dispatch_tool("get_weather", {"city_name": "shanghai"}))  # all sample tools above are commented out, so this returns a 'not found' message
134 | print(get_tools())
135 |
--------------------------------------------------------------------------------
/upload:
--------------------------------------------------------------------------------
1 | upload file
2 |
3 |
4 |
5 |
--------------------------------------------------------------------------------