├── .gitignore
├── LICENSE
├── README.md
├── docs
│   ├── config.md
│   ├── description.md
│   ├── pic
│   │   ├── emoji_logo.jpg
│   │   ├── example_video.gif
│   │   ├── flowstruct.drawio
│   │   ├── flowstruct.png
│   │   ├── introduce.jpg
│   │   ├── logo.jpg
│   │   ├── wx.png
│   │   └── wxemoji.png
│   └── systemd.md
├── frontend
│   └── emoji.py
├── langchain_emoji
│   ├── __init__.py
│   ├── __main__.py
│   ├── __version__.py
│   ├── components
│   │   ├── embedding
│   │   │   ├── __init__.py
│   │   │   ├── custom
│   │   │   │   └── zhipuai
│   │   │   │       ├── __init__.py
│   │   │   │       └── zhipuai_custom.py
│   │   │   └── embedding_component.py
│   │   ├── llm
│   │   │   ├── __init__.py
│   │   │   ├── custom
│   │   │   │   ├── .gitkeep
│   │   │   │   └── zhipuai
│   │   │   │       ├── __init__.py
│   │   │   │       ├── zhipuai_custom.py
│   │   │   │       └── zhipuai_info.py
│   │   │   └── llm_component.py
│   │   ├── minio
│   │   │   ├── __init__.py
│   │   │   └── minio_component.py
│   │   ├── trace
│   │   │   ├── __init__.py
│   │   │   └── trace_component.py
│   │   └── vector_store
│   │       ├── __init__.py
│   │       ├── chroma
│   │       │   ├── __init__.py
│   │       │   └── chroma.py
│   │       ├── tencent
│   │       │   ├── __init__.py
│   │       │   └── tencent.py
│   │       └── vector_store_component.py
│   ├── constants.py
│   ├── di.py
│   ├── launcher.py
│   ├── main.py
│   ├── paths.py
│   ├── server
│   │   ├── config
│   │   │   ├── __init__.py
│   │   │   └── config_router.py
│   │   ├── emoji
│   │   │   ├── __init__.py
│   │   │   ├── emoji_prompt.py
│   │   │   ├── emoji_router.py
│   │   │   └── emoji_service.py
│   │   ├── health
│   │   │   ├── __init__.py
│   │   │   └── health_router.py
│   │   ├── trace
│   │   │   ├── __init__.py
│   │   │   └── trace_router.py
│   │   ├── utils
│   │   │   ├── __init__.py
│   │   │   ├── auth.py
│   │   │   └── model.py
│   │   └── vector_store
│   │       ├── __init__.py
│   │       ├── vector_store_router.py
│   │       └── vector_store_server.py
│   ├── settings
│   │   ├── __init__.py
│   │   ├── settings.py
│   │   ├── settings_loader.py
│   │   └── yaml.py
│   └── utils
│       ├── __init__.py
│       └── _compat.py
├── local_data
│   └── .gitkeep
├── log
│   └── .gitkeep
├── poetry.lock
├── pyproject.toml
├── settings.yaml
├── tests
│   └── __init__.py
└── tools
    ├── datainit.py
    └── json2jsonl.py
/.gitignore:
--------------------------------------------------------------------------------
1 | .vs/
2 | .vscode/
3 | .idea/
4 | # Byte-compiled / optimized / DLL files
5 | __pycache__/
6 | *.py[cod]
7 | *$py.class
8 |
9 | # C extensions
10 | *.so
11 |
12 | # Distribution / packaging
13 | .Python
14 | build/
15 | develop-eggs/
16 | dist/
17 | downloads/
18 | eggs/
19 | .eggs/
20 | lib/
21 | lib64/
22 | parts/
23 | sdist/
24 | var/
25 | wheels/
26 | pip-wheel-metadata/
27 | share/python-wheels/
28 | *.egg-info/
29 | .installed.cfg
30 | *.egg
31 | MANIFEST
32 |
33 | # PyInstaller
34 | # Usually these files are written by a python script from a template
35 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
36 | *.manifest
37 | *.spec
38 |
39 | # Installer logs
40 | pip-log.txt
41 | pip-delete-this-directory.txt
42 |
43 | # Unit test / coverage reports
44 | htmlcov/
45 | .tox/
46 | .nox/
47 | .coverage
48 | .coverage.*
49 | .cache
50 | nosetests.xml
51 | coverage.xml
52 | *.cover
53 | *.py,cover
54 | .hypothesis/
55 | .pytest_cache/
56 |
57 | # Translations
58 | *.mo
59 | *.pot
60 |
61 | # Django stuff:
62 | *.log
63 | local_settings.py
64 | db.sqlite3
65 | db.sqlite3-journal
66 |
67 | # Flask stuff:
68 | instance/
69 | .webassets-cache
70 |
71 | # Scrapy stuff:
72 | .scrapy
73 |
74 | # Sphinx documentation
75 | docs/_build/
76 | docs/docs/_build/
77 |
78 | # PyBuilder
79 | target/
80 |
81 | # Jupyter Notebook
82 | .ipynb_checkpoints
83 | notebooks/
84 |
85 | # IPython
86 | profile_default/
87 | ipython_config.py
88 |
89 | # pyenv
90 | .python-version
91 |
92 | # pipenv
93 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
94 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
95 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
96 | # install all needed dependencies.
97 | #Pipfile.lock
98 |
99 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
100 | __pypackages__/
101 |
102 | # Celery stuff
103 | celerybeat-schedule
104 | celerybeat.pid
105 |
106 | # SageMath parsed files
107 | *.sage.py
108 |
109 | # Environments
110 | .env
111 | .envrc
112 | .venv
113 | .venvs
114 | env/
115 | venv/
116 | ENV/
117 | env.bak/
118 | venv.bak/
119 |
120 | # Spyder project settings
121 | .spyderproject
122 | .spyproject
123 |
124 | # Rope project settings
125 | .ropeproject
126 |
127 | # mkdocs documentation
128 | /site
129 |
130 | # mypy
131 | .mypy_cache/
132 | .dmypy.json
133 | dmypy.json
134 |
135 | # Pyre type checker
136 | .pyre/
137 |
138 | # macOS display setting files
139 | .DS_Store
140 |
141 | # Wandb directory
142 | wandb/
143 |
144 | # asdf tool versions
145 | .tool-versions
146 | /.ruff_cache/
147 |
148 | *.pkl
149 | *.bin
150 |
151 | local_data/emo-visual-data.zip
152 | local_data/emo-visual-data
153 | local_data/chromadb
154 |
155 | settings-pro.yaml
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # 🥳 LangChain-Emoji
2 |
3 | Simplified Chinese | [English](README-en.md)
4 |
5 |
6 |
7 |
8 |
9 |
10 | 
11 | An open-source emoji-battle Agent built on LangChain
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 |
20 |
21 | > The emoji dataset used in this project comes from the ZhipuAI team; the data source and its introduction are here:
22 | > https://github.com/LLM-Red-Team/emo-visual-data
23 | > Thanks for open-sourcing it; let's have fun with LLMs together 🎉🌟🌟
24 |
25 | 💡💡💡 **Update: you can now try it in the WeChat client!**
26 |
27 | #### Scan the QR codes below, add us on WeChat, and start an emoji battle
28 |
29 |
30 | 
31 | 
32 |
33 |
34 | ## 🚀 Quick Install
35 |
36 | ### 1. Set up the Python environment
37 |
38 | - Install miniconda
39 |
40 | ```shell
41 | mkdir ~/miniconda3
42 | wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh -O ~/miniconda3/miniconda.sh
43 | bash ~/miniconda3/miniconda.sh -b -u -p ~/miniconda3
44 | rm -rf ~/miniconda3/miniconda.sh
45 | ~/miniconda3/bin/conda init bash
46 | ```
47 |
48 | - Create a virtual environment

49 |
50 | ```shell
51 | # create the environment
52 | conda create -n LangChain-Emoji python==3.10.11
53 | ```
54 |
55 | - Install poetry
56 |
57 | ```shell
58 | # install
59 | curl -sSL https://install.python-poetry.org | python3 -
60 | ```
61 |
62 | ### 2. Run LangChain-Emoji
63 |
64 | - Install dependencies
65 |
66 | ```shell
67 | # clone the project
68 | git clone https://github.com/ptonlix/LangChain-Emoji.git
69 | conda activate LangChain-Emoji # activate the environment
70 | cd LangChain-Emoji # enter the project directory
71 | poetry install # install dependencies
72 | ```
73 |
74 | - Edit the configuration file
75 |
76 | [OpenAI docs](https://platform.openai.com/docs/introduction)
77 | [ZhipuAI docs](https://open.bigmodel.cn/dev/howuse/introduction)
78 | [LangChain API](https://smith.langchain.com)
79 |
80 | ```shell
81 | # settings.yaml
82 |
83 | Set the following variables in the config file or via environment variables
84 |
85 | # OpenAI LLM API
86 | OPENAI_API_BASE
87 | OPENAI_API_KEY
88 |
89 | # ZhipuAI API
90 | ZHIPUAI_API_KEY
91 |
92 | # LangChain tracing API
93 | LANGCHAIN_API_KEY
94 |
95 | # Vector database; chromadb is used by default
96 | embedding # embedding model, zhipuai by default
97 |
98 | # Tencent Cloud VectorDB config (optional)
99 | TCVERCTORDB_API_HOST
100 | TCVERCTORDB_API_KEY
101 |
102 | # Minio object-storage config (optional)
103 | MINIO_HOST
104 | MINIO_ACCESS_KEY
105 | MINIO_SECRET_KEY
106 |
107 | ```
108 |
109 | See [LangChain-Emoji configuration](./docs/config.md) for a detailed description of the config file
110 |
111 | - Data initialization
112 |
113 | These steps use the `tools/datainit.py` data-initialization tool
114 |
115 | I. Deploy from a local data file
116 |
117 | [➡️ Baidu Netdisk download](https://pan.baidu.com/s/11iwqoxLtjV-DOQli81vZ6Q?pwd=tab4)
118 |
119 | ```
120 | Download the data from Baidu Netdisk:
121 | https://pan.baidu.com/s/11iwqoxLtjV-DOQli81vZ6Q?pwd=tab4
122 | Download it into local_data and unzip it
123 | ```
124 |
125 | ➡️ Google Drive download (via the tool below, using the `google_driver_id` from settings.yaml)
126 |
127 | ```
128 | cd tools
129 | python datainit.py --download
130 | # wait for the dataset to download and unzip
131 | ```
132 |
133 | II. Deploy with Minio object storage (optional)
134 |
135 | Complete step I first, so the data is downloaded into the local_data directory and unzipped
136 |
137 | In the `settings.yaml` config file, fill in the Minio configuration:
138 | `MINIO_HOST`
139 | `MINIO_ACCESS_KEY`
140 | `MINIO_SECRET_KEY`
141 | All three parameters are required
142 |
143 | ```
144 | cd tools
145 | python datainit.py --upload
146 | # wait for the data upload to Minio to finish
147 | ```
148 |
149 | III. Sync the metadata to the vector database (ChromaDB by default)
150 |
151 | ```
152 | cd tools
153 | python datainit.py --vectordb
154 | # wait for the data upload to the vector database to finish
155 | ```
156 |
157 | > **Tencent Cloud VectorDB (optional)**
158 | > In the `settings.yaml` config file, fill in the vector-database parameters
159 | > `TCVERCTORDB_API_HOST`
160 | > `TCVERCTORDB_API_KEY`
161 | > Both parameters are required
162 | > Then set `vectorstore` `database` to `tcvectordb`
163 |
164 | - Start the project
165 |
166 | ```shell
167 | # start the service
168 | python -m langchain_emoji
169 |
170 | # view the API
171 | Visit http://localhost:8003/docs for the API reference
172 | ```
173 |
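Once the service is running, you can also call the emoji endpoint programmatically. Below is a minimal sketch (not part of the repo) mirroring the request and response shapes used in `frontend/emoji.py`; the port and the `Authorization` value are assumptions taken from the defaults in this README and `settings.yaml`, so adjust both to your deployment:

```python
import uuid
import requests  # any HTTP client works; requests is assumed to be installed

resp = requests.post(
    "http://localhost:8003/v1/emoji",  # endpoint used by frontend/emoji.py
    json={"prompt": "今天很开心~", "req_id": str(uuid.uuid4()), "llm": "zhipuai"},
    headers={"Authorization": "Basic c2VjcmV0OmtleQ=="},  # the server.auth.secret value
    timeout=600,
)
data = resp.json()["data"]
print(data["emojiinfo"]["filename"], data["emojiinfo"]["content"])
# data["emojidetail"]["base64"] holds the matched image, base64-encoded
```
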
174 | - Start the Web Demo
175 |
176 | ```shell
177 | # enter the frontend directory
178 | cd frontend
179 | # launch
180 | streamlit run emoji.py
181 | ```
182 |
183 | ## 💡 Demo
184 |
185 | ![Demo](./docs/pic/example_video.gif)
186 |
187 | ### WeChat client
188 |
189 | You can now try it in the WeChat client!
190 |
191 | ## 📖 Project Overview
192 |
193 | ### 1. Architecture
194 |
195 |
196 |
197 | LangChain-Emoji architecture diagram
198 |
199 |
200 | Core flow:
201 |
202 | 1. Data initialization: download the emoji dataset and sync it to the vector database, object storage, etc.
203 | 2. The frontend sends a prompt, and a Retriever recalls a list of emoji records (filename and description), four by default
204 | 3. An LLM then picks the emoji from that list that best matches the input prompt
205 | 4. The image data is fetched from the data store by filename and returned to the frontend for display (see the sketch below)
206 |
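A conceptual sketch of this flow with stubbed-out components (hypothetical names and placeholder data, not the actual service code; the real implementation lives under `langchain_emoji/server/emoji/`):

```python
from typing import Dict, List

def retrieve(prompt: str, k: int = 4) -> List[Dict]:
    """Stub for the vector-store retriever: recall emoji records (filename, description)."""
    return [{"filename": "happy-cat.gif", "content": "a cat grinning widely"}]

def pick_best(prompt: str, candidates: List[Dict]) -> Dict:
    """Stub for the LLM step: choose the candidate that best matches the prompt."""
    return candidates[0]

def fetch_image(filename: str) -> bytes:
    """Stub for the data store (local_data or Minio): load image bytes by filename."""
    return b""

best = pick_best("今天很开心~", retrieve("今天很开心~"))  # steps 2 and 3
image = fetch_image(best["filename"])  # step 4
```
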
207 | ### 2. Directory layout
208 |
209 | ```
210 | ├── docs # documentation
211 | ├── local_data # dataset directory
212 | ├── langchain_emoji
213 | │   ├── components # custom components
214 | │   ├── server # API server
215 | │   ├── settings # settings service
216 | │   ├── utils
217 | │   ├── constants.py
218 | │   ├── di.py
219 | │   ├── launcher.py
220 | │   ├── main.py
221 | │   ├── paths.py
222 | │   ├── __init__.py
223 | │   ├── __main__.py # entry point
224 | │   └── __version__.py
225 | ├── log # log directory
226 | ```
227 |
228 | ### 3. Features
229 |
230 | - Supports the `openai`, `zhipuai`, and `deepseek` LLMs
231 | - Supports the local `chroma` vector database and Tencent Cloud VectorDB
232 | - Supports dynamic config reloading
233 | - Ships with a Web Demo
234 |
235 | ## 🚩 Roadmap
236 |
237 | - [x] Build the initial LangChain-Emoji framework and core features
238 | - [x] Support the local Chroma vector database
239 | - [x] Build the frontend Web Demo
240 | - [x] LLM selection
241 | - [ ] Support more models
242 |   - [ ] Hosted LLMs: DeepSeek ⏳ in testing
243 |   - [ ] Local LLMs
244 | - [x] WeChat client integration for emoji battles
245 |
246 | ## 🌏 Community
247 |
248 |
249 |
250 | 🎉 Scan the QR code to contact the author if you are also interested in this project
251 | 🎉 You are welcome to join the LangChain-X (DeepRead developer community) project group to discuss and exchange ideas
252 |
253 | ## 💥 Contributing
254 |
255 | Everyone is welcome to pitch in and build LangChain-Emoji together; any helpful contribution counts
256 |
257 | - Report bugs
258 | - Suggest improvements
259 | - Improve the documentation
260 | - Contribute code
261 | ...
262 | 👏👏👏
263 |
264 | ---
265 |
266 | ### [About DeepRead (帝阅)](https://dread.run/#/)
267 |
268 | > DeepRead (帝阅)
269 | > is an AI-native product for personal knowledge management and creation.
270 | > It gives each user a dedicated reading assistant that helps them acquire knowledge more efficiently and be more creative,
271 | > so they can better accumulate, manage, and apply knowledge.
272 |
273 | LangChain-Emoji is a sub-project of DeepRead
274 |
275 | Feel free to try [DeepRead](https://dread.run/#/) and send us your valuable feedback
276 |
277 | ---
278 |
279 |
280 |
281 | 
282 |
283 | 帝阅DeepRead
284 |
285 |
--------------------------------------------------------------------------------
/docs/config.md:
--------------------------------------------------------------------------------
1 | ## Configuration file reference
2 |
3 | ```yaml
4 | # Server configuration
5 | server:
6 |   env_name: ${APP_ENV:prod}
7 |   port: ${PORT:8002}
8 |   cors:
9 |     enabled: false
10 |     allow_origins: ["*"]
11 |     allow_methods: ["*"]
12 |     allow_headers: ["*"]
13 |   auth:
14 |     enabled: true # enable authentication
15 |     secret: "Basic c2VjcmV0OmtleQ==" # HTTP Authorization header value
16 |
17 | # LLM configuration
18 | # Four options: openai, zhipuai, mock, openai+zhipuai
19 | # openai+zhipuai enables both models; the parameter passed in the API request decides which one is used
20 | llm:
21 |   mode: openai+zhipuai
22 |
23 | # Embedding model
24 | # Three options: openai, zhipuai, mock
25 | # In mainland China, zhipuai is the more stable choice
26 | # This parameter can be ignored when using Tencent Cloud VectorDB
27 | embedding:
28 |   mode: zhipuai
29 |
30 | # OpenAI model parameters
31 | openai:
32 |   temperature: 1
33 |   modelname: "gpt-3.5-turbo-0125" # or "gpt-3.5-turbo-1106"
34 |   api_base: ${OPENAI_API_BASE:}
35 |   api_key: ${OPENAI_API_KEY:}
36 |
37 | # ZhipuAI model parameters
38 | zhipuai:
39 |   temperature: 0.95
40 |   top_p: 0.6
41 |   modelname: "glm-3-turbo"
42 |   api_key: ${ZHIPUAI_API_KEY:}
43 |
44 | # LangSmith tracing parameters
45 | # See https://smith.langchain.com for details
46 | langsmith:
47 |   trace_version_v2: true
48 |   api_key: ${LANGCHAIN_API_KEY:}
49 |
50 | # Vector database parameters
51 | vectorstore:
52 |   database: tcvectordb # vector database type: chroma (local) or tcvectordb (Tencent Cloud)
53 |   tcvectordb: # see https://cloud.tencent.com/document/product/1709 for details
54 |     url: ${TCVERCTORDB_API_HOST:} # Tencent Cloud API endpoint
55 |     username: root # account
56 |     api_key: ${TCVERCTORDB_API_KEY:} # Tencent Cloud VectorDB API key
57 |     collection_name: EmojiCollection # collection name
58 |     database_name: DeepReadDatabase # database name
59 |
60 | # Emoji dataset information
61 | dataset:
62 |   name: emo-visual-data # dataset file name
63 |   google_driver_id: 1r3uO0wvgQ791M_6iIyBODo_8GekBjPMf # Google Drive ID
64 |   mode: local # dataset loading mode; currently local or minio (object storage)
65 |
66 | # Local data storage
67 | data:
68 |   local_data_folder: local_data # local storage path, relative to the project root
69 |
70 | # Use Minio to store the data
71 | # https://www.minio.org.cn/docs/minio/linux/operations/installation.html
72 | minio:
73 |   host: ${MINIO_HOST:} # Minio endpoint
74 |   bucket_name: emoji # bucket name
75 |   access_key: ${MINIO_ACCESS_KEY:} # access key
76 |   secret_key: ${MINIO_SECRET_KEY:} # secret key
77 | ```
78 |
79 | ## Private configuration file
80 |
81 | Because the config file contains secrets such as API keys, you can load a separate private config file without changing the default one
82 |
83 | See the `langchain_emoji/settings` code for details
84 |
85 | ```shell
86 | # 1. Set the environment variable
87 | export LE_PROFILES=pro
88 |
89 | # 2. Create the new config file
90 | vim settings-pro.yaml
91 |
92 | # 3. Copy the default config and add API keys and other secrets
93 |
94 | # 4. Start the project; the two config files are merged automatically, with settings-pro.yaml taking precedence where they conflict
95 |
96 | ```
97 |
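For example, a minimal `settings-pro.yaml` (illustrative values, not taken from the repo) only needs the keys you want to override; everything else falls back to `settings.yaml`:

```yaml
# settings-pro.yaml: a hypothetical override containing only the secrets
openai:
  api_key: sk-your-openai-key
zhipuai:
  api_key: your-zhipuai-key
langsmith:
  api_key: your-langchain-key
```
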
98 | ## Dynamic config reloading
99 |
100 | The program watches the yaml config file for content changes and reloads it, making it easy to adjust parameters in real time
101 |
102 | See the `langchain_emoji/__main__.py` code for details
103 |
--------------------------------------------------------------------------------
/docs/description.md:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ptonlix/LangChain-Emoji/4c0a926464a7419e48df32b3c5e26794e2f9b312/docs/description.md
--------------------------------------------------------------------------------
/docs/pic/emoji_logo.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ptonlix/LangChain-Emoji/4c0a926464a7419e48df32b3c5e26794e2f9b312/docs/pic/emoji_logo.jpg
--------------------------------------------------------------------------------
/docs/pic/example_video.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ptonlix/LangChain-Emoji/4c0a926464a7419e48df32b3c5e26794e2f9b312/docs/pic/example_video.gif
--------------------------------------------------------------------------------
/docs/pic/flowstruct.drawio:
--------------------------------------------------------------------------------
(flowstruct.drawio: draw.io XML diagram source; content not preserved in this dump)
--------------------------------------------------------------------------------
/docs/pic/flowstruct.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ptonlix/LangChain-Emoji/4c0a926464a7419e48df32b3c5e26794e2f9b312/docs/pic/flowstruct.png
--------------------------------------------------------------------------------
/docs/pic/introduce.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ptonlix/LangChain-Emoji/4c0a926464a7419e48df32b3c5e26794e2f9b312/docs/pic/introduce.jpg
--------------------------------------------------------------------------------
/docs/pic/logo.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ptonlix/LangChain-Emoji/4c0a926464a7419e48df32b3c5e26794e2f9b312/docs/pic/logo.jpg
--------------------------------------------------------------------------------
/docs/pic/wx.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ptonlix/LangChain-Emoji/4c0a926464a7419e48df32b3c5e26794e2f9b312/docs/pic/wx.png
--------------------------------------------------------------------------------
/docs/pic/wxemoji.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ptonlix/LangChain-Emoji/4c0a926464a7419e48df32b3c5e26794e2f9b312/docs/pic/wxemoji.png
--------------------------------------------------------------------------------
/docs/systemd.md:
--------------------------------------------------------------------------------
1 | # Linux: configure start on boot with systemd
2 |
3 | ```
4 | # /usr/lib/systemd/system/langchain-emoji.service
5 | [Unit]
6 | Description=Langchain-Emoji
7 | After=network.target
8 |
9 | [Service]
10 | Type=simple
11 | ExecStart=/root/miniconda3/envs/LangChain-Emoji/bin/python -m langchain_emoji
12 |
13 | [Install]
14 | WantedBy=multi-user.target
15 | ```
16 |
17 | ```
18 | systemctl enable langchain-emoji.service  # enable start on boot
19 | systemctl start langchain-emoji
20 |
21 | # If the service fails to start, inspect the logs with:
22 | journalctl -u langchain-emoji -f
23 |
24 | ```
25 |
--------------------------------------------------------------------------------
/frontend/emoji.py:
--------------------------------------------------------------------------------
1 | import aiohttp
2 | from typing import Optional, Dict, Any
3 | from langchain_emoji.settings.settings import settings
4 | from pydantic import ValidationError
5 | import logging
6 | import asyncio
7 | import streamlit as st
8 | import uuid
9 | import base64
10 | from PIL import Image
11 | from io import BytesIO
12 |
13 | logger = logging.getLogger(__name__)
14 |
15 |
16 | class ApiRequest:
17 | def __init__(self, loop: asyncio.AbstractEventLoop, base_url: str):
18 | self.base_url = base_url
19 | self.loop = loop
20 | self.timeout = aiohttp.ClientTimeout(total=10 * 60)
21 | self.session = None
22 |
23 | async def initialize_session(self):
24 | self.session = aiohttp.ClientSession(loop=self.loop, timeout=self.timeout)
25 |
26 | async def _request(
27 | self,
28 | method: str,
29 | endpoint: str,
30 | params: Optional[Dict[str, Any]] = None,
31 | json: Optional[Dict[str, Any]] = None,
32 | headers: Optional[Dict[str, str]] = None,
33 | ) -> Dict[str, Any]:
34 | if self.session is None:
35 | await self.initialize_session()
36 | url = f"{self.base_url}{endpoint}"
37 | headers = headers or {}
38 | headers["Content-Type"] = "application/json"
39 | try:
40 | async with self.session.request(
41 | method, url, params=params, json=json, headers=headers
42 | ) as response:
43 | return await response.json()
44 | except aiohttp.ServerTimeoutError as timeout_error:
45 | logger.warning(
46 | f"Timeout error in {method} request to {url}: {timeout_error}"
47 | )
48 | raise
49 | except aiohttp.ClientError as e:
50 | logger.warning(f"Error in {method} request to {url}: {e}")
51 | raise
52 | except Exception as e:
53 | logger.warning(f"Unexpected error in {method} request to {url}: {e}")
54 | raise
55 |
56 | async def get(
57 | self,
58 | endpoint: str,
59 | params: Optional[Dict[str, Any]] = None,
60 | headers: Optional[Dict[str, str]] = None,
61 | ) -> Dict[str, Any]:
62 | return await self._request("GET", endpoint, params=params, headers=headers)
63 |
64 | async def post(
65 | self,
66 | endpoint: str,
67 | data: Optional[Dict[str, Any]] = None,
68 | headers: Optional[Dict[str, str]] = None,
69 | ) -> Dict[str, Any]:
70 | return await self._request("POST", endpoint, json=data, headers=headers)
71 |
72 | async def close(self):
73 | await self.session.close()
74 |
75 |
76 | class EmojiApiHandler(ApiRequest):
77 | def __init__(self, loop: asyncio.AbstractEventLoop):
78 | super().__init__(
79 | loop=loop, base_url="http://127.0.0.1:" + str(settings().server.port)
80 | )
81 | self.accesstoken = settings().server.auth.secret
82 |
83 | async def emoji(
84 | self,
85 | data: Dict,
86 | endpoint: str = "/v1/emoji",
87 | headers: Optional[Dict[str, str]] = None,
88 | ) -> Dict | None:
89 | headers = headers or {}
90 | headers["Authorization"] = f"Basic {self.accesstoken}"
91 | logger.info(data)
92 | try:
93 | resp_dict = await self.post(endpoint=endpoint, data=data, headers=headers)
94 | # logger.info(resp_dict)
95 | return resp_dict
96 |         except ValidationError as e:
97 |             logger.warning(f"Unexpected ValidationError in emoji request: {e}")
98 |             return None
99 |         except Exception as e:
100 |             logger.exception(f"Unexpected error in emoji request: {e}")
101 |             return None
102 |
103 |
104 | async def handler_progress_bar(progress_bar):
105 | for i in range(8):
106 | progress_bar.progress(10 * i)
107 | await asyncio.sleep(0.1)
108 |
109 |
110 | def decodebase64(context: str):
111 |     img_data = base64.b64decode(context)  # decode only the payload part of the string
112 |
113 | image = Image.open(BytesIO(img_data))
114 | return image
115 |
116 |
117 | def fetch_emoji(prompt: str, llm: str, progress_bar):
118 |
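    # Streamlit scripts run synchronously, so spin up a dedicated event loop per
    # request and drive the API call and the progress bar concurrently.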
119 | loop = asyncio.new_event_loop()
120 | asyncio.set_event_loop(loop)
121 | emoji_api = EmojiApiHandler(loop)
122 |
123 | async def emoji_request():
124 | try:
125 | response = await emoji_api.emoji(
126 | data={"prompt": prompt, "req_id": str(uuid.uuid4()), "llm": llm}
127 | )
128 | return response.get("data")
129 | finally:
130 | await emoji_api.close()
131 |
132 | try:
133 | emoji_task = loop.create_task(emoji_request())
134 | progress_task = loop.create_task(handler_progress_bar(progress_bar))
135 | result = loop.run_until_complete(asyncio.gather(emoji_task, progress_task))
136 | return result[0]
137 | finally:
138 | loop.close()
139 |
140 |
141 | def area_image(area_placeholder, image, filename):
142 | col1, col2, col3 = area_placeholder.columns([1, 3, 1])
143 | with col1:
144 | st.write(" ")
145 |
146 | with col2:
147 | st.image(
148 | image=image,
149 | caption=filename,
150 | use_column_width=True,
151 | )
152 | with col3:
153 | st.write(" ")
154 |
155 |
156 | st.subheader("🥳 LangChain-Emoji")
157 | st.markdown(
158 | "基于LangChain的开源表情包斗图Agent [Github](https://github.com/ptonlix/LangChain-Emoji) [作者: Baird](https://github.com/ptonlix)"
159 | )
160 |
161 | with st.container():
162 | image_path = "../docs/pic/logo.jpg"
163 |
164 | with st.container():
165 | image_placeholder = st.empty()
166 | area_placeholder = st.empty()
167 | content_placeholder = st.empty()
168 | area_image(area_placeholder, image_path, "帝阅DeepRead")
169 |
170 | with st.form("emoji_form"):
171 | st.write("Emoji Input")
172 | select_llm = st.radio(
173 | "Please select a LLM",
174 | ["ChatGPT", "ZhipuAI", "DeepSeek"],
175 | captions=["OpenAI", "智谱清言", "深度求索"],
176 | )
177 | llm_mapping = {
178 | "ChatGPT": "openai",
179 | "ZhipuAI": "zhipuai",
180 | "DeepSeek": "deepseek",
181 | }
182 | llm = llm_mapping[select_llm]
183 |
184 | prompt = st.text_input("Enter Emoji Prompt:", value="今天很开心~")
185 |
186 | submitted = st.form_submit_button("Submit", help="点击获取最佳表情包")
187 | if submitted:
188 | bar = st.progress(0)
189 | response = fetch_emoji(prompt, llm, bar)
190 | bar.progress(100)
191 | base64_str = response["emojidetail"]["base64"]
192 | filename = response["emojiinfo"]["filename"]
193 | content = response["emojiinfo"]["content"]
194 | area_image(area_placeholder, decodebase64(base64_str), filename)
195 | content_placeholder.write(content)
196 |
--------------------------------------------------------------------------------
/langchain_emoji/__init__.py:
--------------------------------------------------------------------------------
1 | """langchain_emoji."""
2 |
3 | import logging
4 | from logging.handlers import TimedRotatingFileHandler
5 | from langchain_emoji.constants import PROJECT_ROOT_PATH
6 |
7 | # Set to 'DEBUG' to have extensive logging turned on, even for libraries
8 | ROOT_LOG_LEVEL = "INFO"
9 |
10 | PRETTY_LOG_FORMAT = "%(asctime)s.%(msecs)03d [%(levelname)-5s] [%(filename)-12s][line:%(lineno)d] - %(message)s"
11 | logging.basicConfig(level=ROOT_LOG_LEVEL, format=PRETTY_LOG_FORMAT, datefmt="%H:%M:%S")
12 | logging.captureWarnings(True)
13 |
14 |
15 | file_handler = TimedRotatingFileHandler(
16 | filename=PROJECT_ROOT_PATH / "log/app.log",
17 | when="midnight",
18 | interval=1,
19 | backupCount=7,
20 | )
21 | file_handler.suffix = "%Y-%m-%d.log"
22 | file_handler.encoding = "utf-8"
23 | file_handler.setLevel(ROOT_LOG_LEVEL)
24 | # Create a formatter and attach it to the file handler
25 | file_formatter = logging.Formatter(PRETTY_LOG_FORMAT)
26 | file_handler.setFormatter(file_formatter)
27 |
28 | logging.getLogger().addHandler(file_handler)
29 |
--------------------------------------------------------------------------------
/langchain_emoji/__main__.py:
--------------------------------------------------------------------------------
1 | # start a fastapi server with uvicorn
2 |
3 | import glob
4 | from langchain_emoji.settings.settings import settings
5 | from langchain_emoji.constants import PROJECT_ROOT_PATH
6 |
7 | from uvicorn import Config, Server
8 | from uvicorn.supervisors.watchfilesreload import WatchFilesReload
9 |
10 | from pathlib import Path
11 | import hashlib
12 | from socket import socket
13 | from typing import Callable
14 | from watchfiles import Change, watch
15 |
16 |
17 | """
18 | 自定义uvicorn启动类
19 | """
20 |
21 |
22 | class FileFilter:
23 | def __init__(self, config: Config):
24 | default_includes = ["*.py"]
25 | self.includes = [
26 | default
27 | for default in default_includes
28 | if default not in config.reload_excludes
29 | ]
30 | self.includes.extend(config.reload_includes)
31 | self.includes = list(set(self.includes))
32 |
33 | default_excludes = [".*", ".py[cod]", ".sw.*", "~*"]
34 | self.excludes = [
35 | default
36 | for default in default_excludes
37 | if default not in config.reload_includes
38 | ]
39 | self.exclude_dirs = []
40 | for e in config.reload_excludes:
41 | p = Path(e)
42 | try:
43 | is_dir = p.is_dir()
44 | except OSError: # pragma: no cover
45 | # gets raised on Windows for values like "*.py"
46 | is_dir = False
47 |
48 | if is_dir:
49 | self.exclude_dirs.append(p)
50 | else:
51 | self.excludes.append(e)
52 | self.excludes = list(set(self.excludes))
53 |
54 | def __call__(self, change: "Change", path: Path | str) -> bool:
55 | if isinstance(path, str):
56 | path = Path(path)
57 | for include_pattern in self.includes:
58 | if path.match(include_pattern):
59 | if str(path).endswith(include_pattern):
60 | return True
61 |
62 | for exclude_dir in self.exclude_dirs:
63 | if exclude_dir in path.parents:
64 | return False
65 |
66 | for exclude_pattern in self.excludes:
67 | if path.match(exclude_pattern):
68 | return False
69 |
70 | return True
71 | return False
72 |
73 |
74 | class CustomWatchFilesReload(WatchFilesReload):
75 |
76 | def __init__(
77 | self,
78 | config: Config,
79 | target: Callable[[list[socket] | None], None],
80 | sockets: list[socket],
81 | ) -> None:
82 | super().__init__(config, target, sockets)
83 | self.reloader_name = "WatchFiles"
84 | self.reload_dirs = []
85 | self.ignore_paths = []
86 | for directory in config.reload_dirs:
87 | self.reload_dirs.append(directory)
88 |
89 | self.watch_filter = FileFilter(config)
90 | self.watcher = watch(
91 | *self.reload_dirs,
92 | watch_filter=self.watch_filter,
93 | stop_event=self.should_exit,
94 | # using yield_on_timeout here mostly to make sure tests don't
95 | # hang forever, won't affect the class's behavior
96 | yield_on_timeout=True,
97 | )
98 |
99 | self.file_hashes = {} # Store file hashes
100 |
101 | # Calculate and store hashes for initial files
102 | for directory in self.reload_dirs:
103 | for file_path in directory.rglob("*"):
104 | if file_path.is_file() and self.watch_filter(None, file_path):
105 | self.file_hashes[str(file_path)] = self.calculate_file_hash(
106 | file_path
107 | )
108 |
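    # Restart only when a watched file's content hash actually changes, not on
    # every raw filesystem event (editors may touch files without modifying them).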
109 | def should_restart(self) -> list[Path] | None:
110 | self.pause()
111 |
112 | changes = next(self.watcher)
113 | if changes:
114 | changed_paths = []
115 | for event_type, path in changes:
116 | if event_type == Change.modified or event_type == Change.added:
117 | file_hash = self.calculate_file_hash(path)
118 | if (
119 | path not in self.file_hashes
120 | or self.file_hashes[path] != file_hash
121 | ):
122 | changed_paths.append(Path(path))
123 | self.file_hashes[path] = file_hash
124 |
125 | if changed_paths:
126 | return [p for p in changed_paths if self.watch_filter(None, p)]
127 |
128 | return None
129 |
130 | def calculate_file_hash(self, file_path: str) -> str:
131 | with open(file_path, "rb") as file:
132 | file_contents = file.read()
133 | return hashlib.md5(file_contents).hexdigest()
134 |
135 |
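# Collect every non-yaml path in the project so it can be passed as reload_excludes;
# combined with reload_includes="*.yaml" below, the intent (per docs/config.md) is
# that only yaml config changes trigger a reload.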
136 | non_yaml_files = [
137 | f
138 | for f in glob.glob("**", root_dir=PROJECT_ROOT_PATH, recursive=True)
139 | if not f.lower().endswith((".yaml", ".yml"))
140 | ]
141 | try:
142 | config = Config(
143 | app="langchain_emoji.main:app",
144 | host="0.0.0.0",
145 | port=settings().server.port,
146 | reload=True,
147 | reload_dirs=str(PROJECT_ROOT_PATH),
148 | reload_excludes=non_yaml_files,
149 | reload_includes="*.yaml",
150 | log_config=None,
151 | )
152 |
153 | server = Server(config=config)
154 |
155 | sock = config.bind_socket()
156 | CustomWatchFilesReload(config, target=server.run, sockets=[sock]).run()
157 | except KeyboardInterrupt:
158 | ...
159 |
--------------------------------------------------------------------------------
/langchain_emoji/__version__.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | from typing import TYPE_CHECKING
4 |
5 | from langchain_emoji.utils._compat import metadata
6 |
7 |
8 | if TYPE_CHECKING:
9 | from collections.abc import Callable
10 |
11 |
12 | # The metadata.version that we import for Python 3.7 is untyped, work around
13 | # that.
14 | version: Callable[[str], str] = metadata.version
15 |
16 | __version__ = version("langchain_emoji")
17 |
--------------------------------------------------------------------------------
/langchain_emoji/components/embedding/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ptonlix/LangChain-Emoji/4c0a926464a7419e48df32b3c5e26794e2f9b312/langchain_emoji/components/embedding/__init__.py
--------------------------------------------------------------------------------
/langchain_emoji/components/embedding/custom/zhipuai/__init__.py:
--------------------------------------------------------------------------------
1 | from langchain_emoji.components.embedding.custom.zhipuai.zhipuai_custom import (
2 | ZhipuaiTextEmbeddings,
3 | )
4 |
5 | __all__ = ["ZhipuaiTextEmbeddings"]
6 |
--------------------------------------------------------------------------------
/langchain_emoji/components/embedding/custom/zhipuai/zhipuai_custom.py:
--------------------------------------------------------------------------------
1 | from typing import Any, Dict, List, Optional
2 |
3 | from langchain_core.embeddings import Embeddings
4 | from langchain_core.pydantic_v1 import BaseModel, root_validator
5 | from langchain_core.utils import get_from_dict_or_env
6 | from packaging.version import parse
7 | from importlib.metadata import version
8 | import asyncio
9 | import logging
10 |
11 | logger = logging.getLogger(__name__)
12 |
13 |
14 | def is_zhipu_v2() -> bool:
15 | """Return whether zhipu API is v2 or more."""
16 | _version = parse(version("zhipuai"))
17 | return _version.major >= 2
18 |
19 |
20 | # Official Website: https://open.bigmodel.cn/dev/api#text_embedding
21 | # An API-key is required to use this embedding model. You can get one by registering
22 | class ZhipuaiTextEmbeddings(BaseModel, Embeddings):
23 | """Zhipuai Text Embedding models."""
24 |
25 | client: Any # ZhipuAI #: :meta private:
26 | model_name: str = "embedding-2"
27 | zhipuai_api_key: Optional[str] = None
28 | count_token: int = 0
29 |
30 | @root_validator(allow_reuse=True)
31 | def validate_environment(cls, values: Dict) -> Dict:
32 | """Validate that auth token exists in environment."""
33 |
34 | try:
35 | from zhipuai import ZhipuAI
36 |
37 | if not is_zhipu_v2():
38 | raise RuntimeError(
39 | "zhipuai package version is too low"
40 | "Please install it via 'pip install --upgrade zhipuai'"
41 | )
42 |
43 | zhipuai_api_key = get_from_dict_or_env(
44 | values, "zhipuai_api_key", "ZHIPUAI_API_KEY"
45 | )
46 |
47 | client = ZhipuAI(
48 | api_key=zhipuai_api_key,
49 | )
50 | values["client"] = client
51 | return values
52 | except ImportError:
53 | raise RuntimeError(
54 | "Could not import zhipuai package. "
55 | "Please install it via 'pip install zhipuai'"
56 | )
57 |
58 | def _embed(self, texts: List[str]) -> Optional[List[List[float]]]:
59 | """Internal method to call Zhipuai Embedding API and return embeddings.
60 |
61 | Args:
62 | texts: A list of texts to embed.
63 |
64 | Returns:
65 | A list of list of floats representing the embeddings, or None if an
66 | error occurs.
67 | """
68 | # try:
69 | # return [self._get_embedding(text) for text in texts]
70 |
71 | # except Exception as e:
72 | # logger.exception(e)
73 | # # Log the exception or handle it as needed
74 | # logger.info(
75 | # f"Exception occurred while trying to get embeddings: {str(e)}"
76 | # ) # noqa: T201
77 | # return None
78 | return asyncio.run(self._aembed(texts))
79 |
80 | def embed_documents(self, texts: List[str]) -> Optional[List[List[float]]]: # type: ignore[override]
81 | """Public method to get embeddings for a list of documents.
82 |
83 | Args:
84 | texts: The list of texts to embed.
85 |
86 | Returns:
87 | A list of embeddings, one for each text, or None if an error occurs.
88 | """
89 | return self._embed(texts)
90 |
91 | def embed_query(self, text: str) -> Optional[List[float]]: # type: ignore[override]
92 | """Public method to get embedding for a single query text.
93 |
94 | Args:
95 | text: The text to embed.
96 |
97 | Returns:
98 | Embeddings for the text, or None if an error occurs.
99 | """
100 | result = self._embed([text])
101 | return result[0] if result is not None else None
102 |
103 | def _get_embedding(self, text: str) -> List[float]:
104 | response = self.client.embeddings.create(model=self.model_name, input=text)
105 | self.count_token += response.usage.total_tokens
106 | return response.data[0].embedding
107 |
108 |     async def _aget_embedding(self, text: str) -> List[float]:
109 |         response = self.client.embeddings.create(model=self.model_name, input=text)
110 |         self.count_token += response.usage.total_tokens  # count tokens
111 | return response.data[0].embedding
112 |
113 | async def _aembed(self, texts: List[str]) -> Optional[List[List[float]]]:
114 | """Internal method to call Zhipuai Embedding API and return embeddings.
115 |
116 | Args:
117 | texts: A list of texts to embed.
118 |
119 | Returns:
120 | A list of list of floats representing the embeddings, or None if an
121 | error occurs.
122 | """
123 | try:
124 |             tasks = [asyncio.create_task(self._aget_embedding(text)) for text in texts]
125 | return await asyncio.gather(*tasks)
126 |
127 | except Exception as e:
128 | logger.exception(e)
129 | # Log the exception or handle it as needed
130 | logger.info(
131 | f"Exception occurred while trying to get embeddings: {str(e)}"
132 | ) # noqa: T201
133 | return None
134 |
135 | async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
136 | """Asynchronous Embed search docs."""
137 | return await self._aembed(texts)
138 |
139 | async def aembed_query(self, text: str) -> List[float]:
140 | """Asynchronous Embed query text."""
141 | result = await self._aembed([text])
142 | return result[0] if result is not None else None
143 |
144 |
145 | if __name__ == "__main__":
146 |
147 | # async def main():
148 | # client = ZhipuaiTextEmbeddings(
149 |     #     zhipuai_api_key="your-zhipuai-api-key"
150 | # )
151 |
152 | # print(
153 | # await client.aembed_documents(
154 | # ["你好,帝阅AI搜索", "帝阅DeepRead", "帝阅No.1"]
155 | # )
156 | # )
157 |
158 | # asyncio.run(main())
159 |
160 | client = ZhipuaiTextEmbeddings(
161 |         zhipuai_api_key="your-zhipuai-api-key"
162 | )
163 | client.embed_documents(["你好,帝阅AI搜索", "帝阅DeepRead", "帝阅No.1"])
164 | print(client.count_token)
165 |
--------------------------------------------------------------------------------
/langchain_emoji/components/embedding/embedding_component.py:
--------------------------------------------------------------------------------
1 | import logging
2 |
3 | from injector import inject
4 | from langchain_community.embeddings import OpenAIEmbeddings
5 | from langchain_core.embeddings import Embeddings, DeterministicFakeEmbedding
6 |
7 | from langchain_emoji.settings.settings import Settings
8 | from langchain_emoji.components.embedding.custom.zhipuai import ZhipuaiTextEmbeddings
9 |
10 |
11 | logger = logging.getLogger(__name__)
12 |
13 |
14 | class EmbeddingComponent:
15 | @inject
16 | def __init__(self, settings: Settings) -> None:
17 | embedding_mode = settings.embedding.mode
18 | logger.info("Initializing the embedding in mode=%s", embedding_mode)
19 | match embedding_mode:
20 | case "local":
21 | ... # Todo
22 | case "openai":
23 | openai_settings = settings.openai
24 | self._embedding = OpenAIEmbeddings(
25 | api_key=openai_settings.api_key,
26 | openai_api_base=openai_settings.api_base,
27 | )
28 | case "zhipuai":
29 | zhipuai_settings = settings.zhipuai
30 | self._embedding = ZhipuaiTextEmbeddings(
31 | zhipuai_api_key=zhipuai_settings.api_key
32 | )
33 | case "mock":
34 | self._embedding = DeterministicFakeEmbedding(size=1352)
35 |
36 | @property
37 | def embedding(self) -> Embeddings:
38 | return self._embedding
39 |
40 | @property
41 | def total_tokens(self) -> int:
42 | try:
43 |             return self._embedding.count_token  # embedding-token counting is currently only supported for Zhipuai
44 | except Exception:
45 | return 0
46 |
--------------------------------------------------------------------------------
/langchain_emoji/components/llm/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ptonlix/LangChain-Emoji/4c0a926464a7419e48df32b3c5e26794e2f9b312/langchain_emoji/components/llm/__init__.py
--------------------------------------------------------------------------------
/langchain_emoji/components/llm/custom/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ptonlix/LangChain-Emoji/4c0a926464a7419e48df32b3c5e26794e2f9b312/langchain_emoji/components/llm/custom/.gitkeep
--------------------------------------------------------------------------------
/langchain_emoji/components/llm/custom/zhipuai/__init__.py:
--------------------------------------------------------------------------------
1 | from langchain_emoji.components.llm.custom.zhipuai.zhipuai_custom import (
2 | ChatZhipuAI,
3 | )
4 | from langchain_emoji.components.llm.custom.zhipuai.zhipuai_info import (
5 | ZhipuAICallbackHandler,
6 | get_zhipuai_callback,
7 | )
8 |
9 |
10 | __all__ = ["ChatZhipuAI", "ZhipuAICallbackHandler", "get_zhipuai_callback"]
11 |
--------------------------------------------------------------------------------
/langchain_emoji/components/llm/custom/zhipuai/zhipuai_custom.py:
--------------------------------------------------------------------------------
1 | """ZHIPU AI chat models wrapper."""
2 | from __future__ import annotations
3 |
4 | import asyncio
5 | import logging
6 | from functools import partial
7 | from importlib.metadata import version
8 | from typing import (
9 | Any,
10 | Callable,
11 | Dict,
12 | Iterator,
13 | List,
14 | Mapping,
15 | Optional,
16 | Tuple,
17 | Type,
18 | Union,
19 | )
20 |
21 | from langchain_core.callbacks import (
22 | AsyncCallbackManagerForLLMRun,
23 | CallbackManagerForLLMRun,
24 | )
25 | from langchain_core.language_models.chat_models import (
26 | BaseChatModel,
27 | generate_from_stream,
28 | )
29 | from langchain_core.language_models.llms import create_base_retry_decorator
30 | from langchain_core.messages import (
31 | AIMessage,
32 | AIMessageChunk,
33 | BaseMessage,
34 | BaseMessageChunk,
35 | ChatMessage,
36 | ChatMessageChunk,
37 | HumanMessage,
38 | HumanMessageChunk,
39 | SystemMessage,
40 | SystemMessageChunk,
41 | ToolMessage,
42 | ToolMessageChunk,
43 | )
44 | from langchain_core.outputs import (
45 | ChatGeneration,
46 | ChatGenerationChunk,
47 | ChatResult,
48 | )
49 | from langchain_core.pydantic_v1 import BaseModel, Field
50 | from packaging.version import parse
51 |
52 | logger = logging.getLogger(__name__)
53 |
54 |
55 | def is_zhipu_v2() -> bool:
56 | """Return whether zhipu API is v2 or more."""
57 | _version = parse(version("zhipuai"))
58 | return _version.major >= 2
59 |
60 |
61 | def _create_retry_decorator(
62 | llm: ChatZhipuAI,
63 | run_manager: Optional[
64 | Union[AsyncCallbackManagerForLLMRun, CallbackManagerForLLMRun]
65 | ] = None,
66 | ) -> Callable[[Any], Any]:
67 | import zhipuai
68 |
69 | errors = [
70 | zhipuai.ZhipuAIError,
71 | zhipuai.APIStatusError,
72 | zhipuai.APIRequestFailedError,
73 | zhipuai.APIReachLimitError,
74 | zhipuai.APIInternalError,
75 | zhipuai.APIServerFlowExceedError,
76 | zhipuai.APIResponseError,
77 | zhipuai.APIResponseValidationError,
78 | zhipuai.APITimeoutError,
79 | ]
80 | return create_base_retry_decorator(
81 | error_types=errors, max_retries=llm.max_retries, run_manager=run_manager
82 | )
83 |
84 |
85 | def convert_message_to_dict(message: BaseMessage) -> dict:
86 | """Convert a LangChain message to a dictionary.
87 |
88 | Args:
89 | message: The LangChain message.
90 |
91 | Returns:
92 | The dictionary.
93 | """
94 | message_dict: Dict[str, Any]
95 | if isinstance(message, ChatMessage):
96 | message_dict = {"role": message.role, "content": message.content}
97 | elif isinstance(message, HumanMessage):
98 | message_dict = {"role": "user", "content": message.content}
99 | elif isinstance(message, AIMessage):
100 | message_dict = {"role": "assistant", "content": message.content}
101 | if "tool_calls" in message.additional_kwargs:
102 | message_dict["tool_calls"] = message.additional_kwargs["tool_calls"]
103 | # If tool calls only, content is None not empty string
104 | if message_dict["content"] == "":
105 | message_dict["content"] = None
106 | elif isinstance(message, SystemMessage):
107 | message_dict = {"role": "system", "content": message.content}
108 | elif isinstance(message, ToolMessage):
109 | message_dict = {
110 | "role": "tool",
111 | "content": message.content,
112 | "tool_call_id": message.tool_call_id,
113 | }
114 | else:
115 | raise TypeError(f"Got unknown type {message}")
116 | if "name" in message.additional_kwargs:
117 | message_dict["name"] = message.additional_kwargs["name"]
118 | return message_dict
119 |
120 |
121 | def convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
122 | """Convert a dictionary to a LangChain message.
123 |
124 | Args:
125 | _dict: The dictionary.
126 |
127 | Returns:
128 | The LangChain message.
129 | """
130 | role = _dict.get("role")
131 | if role == "user":
132 | return HumanMessage(content=_dict.get("content", ""))
133 | elif role == "assistant":
134 | content = _dict.get("content", "") or ""
135 | additional_kwargs: Dict = {}
136 | if tool_calls := _dict.get("tool_calls"):
137 | additional_kwargs["tool_calls"] = tool_calls
138 | return AIMessage(content=content, additional_kwargs=additional_kwargs)
139 | elif role == "system":
140 | return SystemMessage(content=_dict.get("content", ""))
141 | elif role == "tool":
142 | additional_kwargs = {}
143 | if "name" in _dict:
144 | additional_kwargs["name"] = _dict["name"]
145 | return ToolMessage(
146 | content=_dict.get("content", ""),
147 | tool_call_id=_dict.get("tool_call_id"),
148 | additional_kwargs=additional_kwargs,
149 | )
150 | else:
151 | return ChatMessage(content=_dict.get("content", ""), role=role)
152 |
153 |
154 | def _convert_delta_to_message_chunk(
155 | _dict: Mapping[str, Any], default_class: Type[BaseMessageChunk]
156 | ) -> BaseMessageChunk:
157 | role = _dict.get("role")
158 | content = _dict.get("content") or ""
159 | additional_kwargs: Dict = {}
160 | if _dict.get("tool_calls"):
161 | additional_kwargs["tool_calls"] = _dict["tool_calls"]
162 |
163 | if role == "user" or default_class == HumanMessageChunk:
164 | return HumanMessageChunk(content=content)
165 | elif role == "assistant" or default_class == AIMessageChunk:
166 | return AIMessageChunk(content=content, additional_kwargs=additional_kwargs)
167 | elif role == "system" or default_class == SystemMessageChunk:
168 | return SystemMessageChunk(content=content)
169 | elif role == "tool" or default_class == ToolMessageChunk:
170 | return ToolMessageChunk(content=content, tool_call_id=_dict["tool_call_id"])
171 | elif role or default_class == ChatMessageChunk:
172 | return ChatMessageChunk(content=content, role=role)
173 | else:
174 | return default_class(content=content)
175 |
176 |
177 | class ChatZhipuAI(BaseChatModel):
178 | """
179 | `ZHIPU AI` large language chat models API.
180 |
181 | To use, you should have the ``zhipuai`` python package installed.
182 |
183 | Example:
184 | .. code-block:: python
185 |
186 | from langchain_community.chat_models import ChatZhipuAI
187 |
188 | zhipuai_chat = ChatZhipuAI(
189 | temperature=0.5,
190 | api_key="your-api-key",
191 | model_name="glm-3-turbo",
192 | )
193 |
194 | """
195 |
196 | zhipuai: Any
197 | zhipuai_api_key: Optional[str] = Field(default=None, alias="api_key")
198 | """Automatically inferred from env var `ZHIPUAI_API_KEY` if not provided."""
199 |
200 | client: Any = Field(default=None, exclude=True) #: :meta private:
201 |
202 | model_name: str = Field("glm-3-turbo", alias="model")
203 | """
204 | Model name to use.
205 |     - glm-3-turbo:
206 |         Completes a variety of language tasks from natural language
207 |         instructions; the SSE or asynchronous call interface is
208 |         recommended.
209 |     - glm-4:
210 |         Completes a variety of language tasks from natural language
211 |         instructions; the SSE or asynchronous call interface is
212 |         recommended.
213 | """
214 |
215 | temperature: float = Field(0.95)
216 | """
217 | What sampling temperature to use. The value ranges from 0.0 to 1.0 and cannot
218 | be equal to 0.
219 |     The larger the value, the more random and creative the output; the smaller
220 | the value, the more stable or certain the output will be.
221 | You are advised to adjust top_p or temperature parameters based on application
222 | scenarios, but do not adjust the two parameters at the same time.
223 | """
224 |
225 | top_p: float = Field(0.7)
226 | """
227 |     An alternative to temperature sampling, known as nucleus sampling. The value
228 |     ranges from 0.0 to 1.0 and cannot be equal to 0 or 1.
229 |     The model considers only the tokens within the top_p probability mass.
230 |     For example, 0.1 means that the decoder only considers tokens from the
231 |     top 10% of the probability mass of the candidate set.
232 | You are advised to adjust top_p or temperature parameters based on application
233 | scenarios, but do not adjust the two parameters at the same time.
234 | """
235 |
236 | request_id: Optional[str] = Field(None)
237 | """
238 |     A unique identifier used to distinguish each request. If the client
239 |     provides it, the client must guarantee its uniqueness; otherwise the
240 |     platform generates one by default.
241 | """
242 | do_sample: Optional[bool] = Field(True)
243 | """
244 |     When do_sample is true, the sampling policy is enabled; when it is false,
245 |     the sampling parameters temperature and top_p take no effect.
246 | """
247 | streaming: bool = Field(False)
248 | """Whether to stream the results or not."""
249 |
250 | model_kwargs: Dict[str, Any] = Field(default_factory=dict)
251 | """Holds any model parameters valid for `create` call not explicitly specified."""
252 |
253 | max_tokens: Optional[int] = None
254 |     """Maximum number of tokens to generate in the chat completion."""
255 |
256 | max_retries: int = 2
257 | """Maximum number of retries to make when generating."""
258 |
259 | @property
260 | def _identifying_params(self) -> Dict[str, Any]:
261 | """Get the identifying parameters."""
262 | return {**{"model_name": self.model_name}, **self._default_params}
263 |
264 | @property
265 | def _llm_type(self) -> str:
266 | """Return the type of chat model."""
267 | return "zhipuai"
268 |
269 | @property
270 | def lc_secrets(self) -> Dict[str, str]:
271 | return {"zhipuai_api_key": "ZHIPUAI_API_KEY"}
272 |
273 | @classmethod
274 | def get_lc_namespace(cls) -> List[str]:
275 | """Get the namespace of the langchain object."""
276 | return ["langchain", "chat_models", "zhipuai"]
277 |
278 | @property
279 | def lc_attributes(self) -> Dict[str, Any]:
280 | attributes: Dict[str, Any] = {}
281 |
282 | if self.model_name:
283 | attributes["model"] = self.model_name
284 |
285 | if self.streaming:
286 | attributes["streaming"] = self.streaming
287 |
288 | if self.max_tokens:
289 | attributes["max_tokens"] = self.max_tokens
290 |
291 | return attributes
292 |
293 | @property
294 | def _default_params(self) -> Dict[str, Any]:
295 | """Get the default parameters for calling ZhipuAI API."""
296 | params = {
297 | "model": self.model_name,
298 | "stream": self.streaming,
299 | "temperature": self.temperature,
300 | "top_p": self.top_p,
301 | "do_sample": self.do_sample,
302 | **self.model_kwargs,
303 | }
304 | if self.max_tokens is not None:
305 | params["max_tokens"] = self.max_tokens
306 | return params
307 |
308 | @property
309 | def _client_params(self) -> Dict[str, Any]:
310 | """Get the parameters used for the zhipuai client."""
311 | zhipuai_creds: Dict[str, Any] = {
312 | "request_id": self.request_id,
313 | }
314 | return {**self._default_params, **zhipuai_creds}
315 |
316 | def __init__(self, *args, **kwargs):
317 | super().__init__(*args, **kwargs)
318 | try:
319 | from zhipuai import ZhipuAI
320 |
321 | if not is_zhipu_v2():
322 | raise RuntimeError(
323 |                         "zhipuai package version is too low. "
324 | "Please install it via 'pip install --upgrade zhipuai'"
325 | )
326 |
327 | self.client = ZhipuAI(
328 |                 api_key=self.zhipuai_api_key,  # your ZhipuAI API key
329 | )
330 | except ImportError:
331 | raise RuntimeError(
332 | "Could not import zhipuai package. "
333 | "Please install it via 'pip install zhipuai'"
334 | )
335 |
336 | def completions(self, **kwargs) -> Any | None:
337 | return self.client.chat.completions.create(**kwargs)
338 |
339 | async def async_completions(self, **kwargs) -> Any:
340 | loop = asyncio.get_running_loop()
341 | partial_func = partial(self.client.chat.completions.create, **kwargs)
342 | response = await loop.run_in_executor(
343 | None,
344 | partial_func,
345 | )
346 | return response
347 |
348 | async def async_completions_result(self, task_id):
349 | loop = asyncio.get_running_loop()
350 | response = await loop.run_in_executor(
351 | None,
352 | self.client.asyncCompletions.retrieve_completion_result,
353 | task_id,
354 | )
355 | return response
356 |
357 | def _create_chat_result(self, response: Union[dict, BaseModel]) -> ChatResult:
358 | generations = []
359 | if not isinstance(response, dict):
360 | response = response.dict()
361 | for res in response["choices"]:
362 | message = convert_dict_to_message(res["message"])
363 | generation_info = dict(finish_reason=res.get("finish_reason"))
364 | if "index" in res:
365 | generation_info["index"] = res["index"]
366 | gen = ChatGeneration(
367 | message=message,
368 | generation_info=generation_info,
369 | )
370 | generations.append(gen)
371 | token_usage = response.get("usage", {})
372 | llm_output = {
373 | "token_usage": token_usage,
374 | "model_name": self.model_name,
375 | "task_id": response.get("id", ""),
376 | "created_time": response.get("created", ""),
377 | }
378 | return ChatResult(generations=generations, llm_output=llm_output)
379 |
380 | def _create_message_dicts(
381 | self, messages: List[BaseMessage], stop: Optional[List[str]]
382 | ) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]:
383 | params = self._client_params
384 | if stop is not None:
385 | if "stop" in params:
386 | raise ValueError("`stop` found in both the input and default params.")
387 | params["stop"] = stop
388 | message_dicts = [convert_message_to_dict(m) for m in messages]
389 | return message_dicts, params
390 |
391 | def completion_with_retry(
392 | self, run_manager: Optional[CallbackManagerForLLMRun] = None, **kwargs: Any
393 | ) -> Any:
394 | """Use tenacity to retry the completion call."""
395 |
396 | retry_decorator = _create_retry_decorator(self, run_manager=run_manager)
397 |
398 | @retry_decorator
399 | def _completion_with_retry(**kwargs: Any) -> Any:
400 | return self.completions(**kwargs)
401 |
402 | return _completion_with_retry(**kwargs)
403 |
404 | async def acompletion_with_retry(
405 | self,
406 | run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
407 | **kwargs: Any,
408 | ) -> Any:
409 | """Use tenacity to retry the async completion call."""
410 |
411 | retry_decorator = _create_retry_decorator(self, run_manager=run_manager)
412 |
413 | @retry_decorator
414 | async def _completion_with_retry(**kwargs: Any) -> Any:
415 | return await self.async_completions(**kwargs)
416 |
417 | return await _completion_with_retry(**kwargs)
418 |
419 | def _generate(
420 | self,
421 | messages: List[BaseMessage],
422 | stop: Optional[List[str]] = None,
423 | run_manager: Optional[CallbackManagerForLLMRun] = None,
424 | stream: Optional[bool] = None,
425 | **kwargs: Any,
426 | ) -> ChatResult:
427 | """Generate a chat response."""
428 |
429 | should_stream = stream if stream is not None else self.streaming
430 | if should_stream:
431 | stream_iter = self._stream(
432 | messages, stop=stop, run_manager=run_manager, **kwargs
433 | )
434 | return generate_from_stream(stream_iter)
435 |
436 | message_dicts, params = self._create_message_dicts(messages, stop)
437 | params = {
438 | **params,
439 | **({"stream": stream} if stream is not None else {}),
440 | **kwargs,
441 | }
442 | response = self.completion_with_retry(
443 | messages=message_dicts, run_manager=run_manager, **params
444 | )
445 | return self._create_chat_result(response)
446 |
447 | async def _agenerate(
448 | self,
449 | messages: List[BaseMessage],
450 | stop: Optional[List[str]] = None,
451 | run_manager: Optional[CallbackManagerForLLMRun] = None,
452 |         stream: Optional[bool] = None,
453 | **kwargs: Any,
454 | ) -> ChatResult:
455 | """Asynchronously generate a chat response."""
456 | should_stream = stream if stream is not None else self.streaming
457 | if should_stream:
458 | stream_iter = self._astream(
459 | messages, stop=stop, run_manager=run_manager, **kwargs
460 | )
461 |             return await agenerate_from_stream(stream_iter)  # from langchain_core.language_models.chat_models
462 |
463 | message_dicts, params = self._create_message_dicts(messages, stop)
464 | params = {
465 | **params,
466 | **({"stream": stream} if stream is not None else {}),
467 | **kwargs,
468 | }
469 | response = await self.acompletion_with_retry(
470 | messages=message_dicts, run_manager=run_manager, **params
471 | )
472 | return self._create_chat_result(response)
473 |
474 | def _stream(
475 | self,
476 | messages: List[BaseMessage],
477 | stop: Optional[List[str]] = None,
478 | run_manager: Optional[CallbackManagerForLLMRun] = None,
479 | **kwargs: Any,
480 | ) -> Iterator[ChatGenerationChunk]:
481 | """Stream the chat response in chunks."""
482 | message_dicts, params = self._create_message_dicts(messages, stop)
483 | params = {**params, **kwargs, "stream": True}
484 |
485 | default_chunk_class = AIMessageChunk
486 | for chunk in self.completion_with_retry(
487 | messages=message_dicts, run_manager=run_manager, **params
488 | ):
489 | if not isinstance(chunk, dict):
490 | chunk = chunk.dict()
491 | if len(chunk["choices"]) == 0:
492 | continue
493 | choice = chunk["choices"][0]
494 | chunk = _convert_delta_to_message_chunk(
495 | choice["delta"], default_chunk_class
496 | )
497 |
498 | finish_reason = choice.get("finish_reason")
499 | generation_info = (
500 | dict(finish_reason=finish_reason) if finish_reason is not None else None
501 | )
502 | default_chunk_class = chunk.__class__
503 | chunk = ChatGenerationChunk(message=chunk, generation_info=generation_info)
504 | yield chunk
505 | if run_manager:
506 | run_manager.on_llm_new_token(chunk.text, chunk=chunk)
507 |
--------------------------------------------------------------------------------
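A minimal usage sketch for the ChatZhipuAI class above, assuming the zhipuai package is
installed and a valid ZHIPUAI_API_KEY is set; the prompt text and model name are illustrative:

    from langchain_core.messages import HumanMessage, SystemMessage

    chat = ChatZhipuAI(model="glm-3-turbo", temperature=0.5)

    # Non-streaming: _generate() builds the message dicts and calls the client.
    result = chat.invoke(
        [
            SystemMessage(content="You are a helpful assistant."),
            HumanMessage(content="Hello!"),
        ]
    )
    print(result.content)

    # Streaming: _stream() yields one ChatGenerationChunk per delta.
    for chunk in chat.stream([HumanMessage(content="Hello!")]):
        print(chunk.content, end="", flush=True)
--------------------------------------------------------------------------------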
/langchain_emoji/components/llm/custom/zhipuai/zhipuai_info.py:
--------------------------------------------------------------------------------
1 | """Callback Handler that prints to std out."""
2 | import threading
3 | from typing import Any, Dict, List
4 |
5 | from langchain_core.callbacks import BaseCallbackHandler
6 | from langchain_core.outputs import LLMResult
7 | from contextlib import contextmanager
8 | from typing import (
9 | Generator,
10 | )
11 |
12 | MODEL_COST_PER_1K_TOKENS = {
13 | # input
14 | "glm-4": 0.1,
15 | "glm-3-turbo": 0.005,
16 | # output
17 | "glm-4-completion": 0.1,
18 | "glm-3-turbo-completion": 0.005,
19 | }
20 |
21 |
22 | def standardize_model_name(
23 | model_name: str,
24 | is_completion: bool = False,
25 | ) -> str:
26 | """
27 |     Standardize the model name to the key format used in the cost table.
28 |
29 | Args:
30 | model_name: Model name to standardize.
31 | is_completion: Whether the model is used for completion or not.
32 | Defaults to False.
33 |
34 | Returns:
35 | Standardized model name.
36 |
37 | """
38 | model_name = model_name.lower()
39 | if is_completion and (
40 | model_name.startswith("glm-4") or model_name.startswith("glm-3-turbo")
41 | ):
42 | return model_name + "-completion"
43 | else:
44 | return model_name
45 |
46 |
47 | def get_zhipuai_token_cost_for_model(
48 | model_name: str, num_tokens: int, is_completion: bool = False
49 | ) -> float:
50 | """
51 |     Get the cost in CNY for a given model and number of tokens.
52 |
53 | Args:
54 | model_name: Name of the model
55 | num_tokens: Number of tokens.
56 | is_completion: Whether the model is used for completion or not.
57 | Defaults to False.
58 |
59 | Returns:
60 | Cost in CNY.
61 | """
62 | model_name = standardize_model_name(model_name, is_completion=is_completion)
63 | if model_name not in MODEL_COST_PER_1K_TOKENS:
64 | raise ValueError(
65 |             f"Unknown model: {model_name}. Please provide a valid ZhipuAI model name. "
66 | "Known models are: " + ", ".join(MODEL_COST_PER_1K_TOKENS.keys())
67 | )
68 | return MODEL_COST_PER_1K_TOKENS[model_name] * (num_tokens / 1000)
69 |
70 |
71 | class ZhipuAICallbackHandler(BaseCallbackHandler):
72 | """Callback Handler that tracks ZhipuAI info."""
73 |
74 | total_tokens: int = 0
75 | prompt_tokens: int = 0
76 | completion_tokens: int = 0
77 | successful_requests: int = 0
78 | total_cost: float = 0.0
79 |
80 | def __init__(self) -> None:
81 | super().__init__()
82 | self._lock = threading.Lock()
83 |
84 | def __repr__(self) -> str:
85 | return (
86 | f"Tokens Used: {self.total_tokens}\n"
87 | f"\tPrompt Tokens: {self.prompt_tokens}\n"
88 | f"\tCompletion Tokens: {self.completion_tokens}\n"
89 | f"Successful Requests: {self.successful_requests}\n"
90 |             f"Total Cost (CNY): ¥{self.total_cost}"
91 | )
92 |
93 | @property
94 | def always_verbose(self) -> bool:
95 | """Whether to call verbose callbacks even if verbose is False."""
96 | return True
97 |
98 | def on_llm_start(
99 | self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
100 | ) -> None:
101 |         """Do nothing on LLM start; usage is collected in on_llm_end."""
102 | pass
103 |
104 | def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
105 |         """Do nothing on new tokens; usage is collected in on_llm_end."""
106 | pass
107 |
108 | def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
109 | """Collect token usage."""
110 | if response.llm_output is None:
111 | return None
112 |
113 | if "token_usage" not in response.llm_output:
114 | with self._lock:
115 | self.successful_requests += 1
116 | return None
117 |
118 | # compute tokens and cost for this request
119 | token_usage = response.llm_output["token_usage"]
120 | completion_tokens = token_usage.get("completion_tokens", 0)
121 | prompt_tokens = token_usage.get("prompt_tokens", 0)
122 | model_name = standardize_model_name(response.llm_output.get("model_name", ""))
123 | if model_name in MODEL_COST_PER_1K_TOKENS:
124 | completion_cost = get_zhipuai_token_cost_for_model(
125 | model_name, completion_tokens, is_completion=True
126 | )
127 | prompt_cost = get_zhipuai_token_cost_for_model(model_name, prompt_tokens)
128 | else:
129 | completion_cost = 0
130 | prompt_cost = 0
131 |
132 | # update shared state behind lock
133 | with self._lock:
134 | self.total_cost += prompt_cost + completion_cost
135 | self.total_tokens += token_usage.get("total_tokens", 0)
136 | self.prompt_tokens += prompt_tokens
137 | self.completion_tokens += completion_tokens
138 | self.successful_requests += 1
139 |
140 | def __copy__(self) -> "ZhipuAICallbackHandler":
141 | """Return a copy of the callback handler."""
142 | return self
143 |
144 | def __deepcopy__(self, memo: Any) -> "ZhipuAICallbackHandler":
145 | """Return a deep copy of the callback handler."""
146 | return self
147 |
148 |
149 | @contextmanager
150 | def get_zhipuai_callback() -> Generator[ZhipuAICallbackHandler, None, None]:
151 |     """Get the ZhipuAI callback handler in a context manager,
152 |     which conveniently exposes token and cost information.
153 |
154 | Returns:
155 | ZhipuAICallbackHandler: The ZhipuAI callback handler.
156 |
157 | Example:
158 | >>> with get_zhipuai_callback() as cb:
159 | ... # Use the ZhipuAI callback handler
160 | """
161 | cb = ZhipuAICallbackHandler()
162 | yield cb
163 |
--------------------------------------------------------------------------------
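A usage sketch for the callback above: because get_zhipuai_callback() does not register the
handler globally, it is passed explicitly via config. ChatZhipuAI comes from
zhipuai_custom.py; the model name is illustrative:

    from langchain_core.messages import HumanMessage

    with get_zhipuai_callback() as cb:
        chat = ChatZhipuAI(model="glm-3-turbo")
        chat.invoke([HumanMessage(content="Hello!")], config={"callbacks": [cb]})
        print(cb)             # tokens, successful requests and total cost
        print(cb.total_cost)  # CNY, per MODEL_COST_PER_1K_TOKENS
--------------------------------------------------------------------------------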
/langchain_emoji/components/llm/llm_component.py:
--------------------------------------------------------------------------------
1 | import logging
2 |
3 | from injector import inject, singleton
4 | from langchain_community.chat_models import ChatOpenAI
5 | from langchain.llms.base import BaseLanguageModel
6 | from langchain.llms.fake import FakeListLLM
7 | from langchain_emoji.settings.settings import Settings
8 | from langchain_emoji.components.llm.custom.zhipuai import ChatZhipuAI
9 | from langchain.schema.runnable import ConfigurableField
10 |
11 | logger = logging.getLogger(__name__)
12 |
13 |
14 | @singleton
15 | class LLMComponent:
16 | @inject
17 | def __init__(self, settings: Settings) -> None:
18 | llm_mode = settings.llm.mode
19 | logger.info("Initializing the LLM in mode=%s", llm_mode)
20 | self.modelname = settings.openai.modelname
21 | match settings.llm.mode:
22 | case "local":
23 |                 ...  # TODO: local model support is not implemented yet
24 | case "openai":
25 | openai_settings = settings.openai
26 | self._llm = ChatOpenAI(
27 | temperature=openai_settings.temperature,
28 | model_name=openai_settings.modelname,
29 | api_key=openai_settings.api_key,
30 | openai_api_base=openai_settings.api_base,
31 | model_kwargs={"response_format": {"type": "json_object"}},
32 | ).configurable_alternatives(
33 | # This gives this field an id
34 | # When configuring the end runnable, we can then use this id to configure this field
35 | ConfigurableField(id="llm"),
36 | default_key="openai",
37 | )
38 |
39 | case "zhipuai":
40 | zhipuai_settings = settings.zhipuai
41 | self._llm = ChatZhipuAI(
42 | model=zhipuai_settings.modelname,
43 | temperature=zhipuai_settings.temperature,
44 | top_p=zhipuai_settings.top_p,
45 | api_key=zhipuai_settings.api_key,
46 | ).configurable_alternatives(
47 | # This gives this field an id
48 | # When configuring the end runnable, we can then use this id to configure this field
49 | ConfigurableField(id="llm"),
50 | default_key="zhipuai",
51 | )
52 | case "deepseek":
53 | deepseek_settings = settings.deepseek
54 | self._llm = ChatOpenAI(
55 | model=deepseek_settings.modelname,
56 | temperature=deepseek_settings.temperature,
57 | api_key=deepseek_settings.api_key,
58 | openai_api_base=deepseek_settings.api_base,
59 | ).configurable_alternatives(
60 | # This gives this field an id
61 | # When configuring the end runnable, we can then use this id to configure this field
62 | ConfigurableField(id="llm"),
63 | default_key="deepseek",
64 | )
65 | case "all":
66 | openai_settings = settings.openai
67 | zhipuai_settings = settings.zhipuai
68 | deepseek_settings = settings.deepseek
69 | self._llm = ChatOpenAI(
70 | temperature=openai_settings.temperature,
71 | model_name=openai_settings.modelname,
72 | api_key=openai_settings.api_key,
73 | openai_api_base=openai_settings.api_base,
74 | model_kwargs={"response_format": {"type": "json_object"}},
75 | ).configurable_alternatives(
76 | # This gives this field an id
77 | # When configuring the end runnable, we can then use this id to configure this field
78 | ConfigurableField(id="llm"),
79 | default_key="openai",
80 | zhipuai=ChatZhipuAI(
81 | model=zhipuai_settings.modelname,
82 | temperature=zhipuai_settings.temperature,
83 | top_p=zhipuai_settings.top_p,
84 | api_key=zhipuai_settings.api_key,
85 | ),
86 | deepseek=ChatOpenAI(
87 | model=deepseek_settings.modelname,
88 | temperature=deepseek_settings.temperature,
89 | api_key=deepseek_settings.api_key,
90 | openai_api_base=deepseek_settings.api_base,
91 | ),
92 | )
93 |
94 | case "mock":
95 | self._llm = FakeListLLM(
96 | responses=["你好,帝阅AI表情包"]
97 | ).configurable_alternatives(
98 | # This gives this field an id
99 | # When configuring the end runnable, we can then use this id to configure this field
100 | ConfigurableField(id="llm"),
101 | default_key="mock",
102 | )
103 |
104 | @property
105 | def llm(self) -> BaseLanguageModel:
106 | return self._llm
107 |
--------------------------------------------------------------------------------
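The repeated ConfigurableField(id="llm") wiring above is what lets callers pick a provider
per request. A sketch of selecting an alternative at call time in mode "all", assuming the
component is resolved through the injector:

    from langchain_emoji.di import global_injector
    from langchain_emoji.components.llm.llm_component import LLMComponent

    llm = global_injector.get(LLMComponent).llm

    # The same runnable, routed to different providers via the "llm" field.
    openai_answer = llm.with_config(configurable={"llm": "openai"}).invoke("hi")
    zhipu_answer = llm.with_config(configurable={"llm": "zhipuai"}).invoke("hi")
--------------------------------------------------------------------------------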
/langchain_emoji/components/minio/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ptonlix/LangChain-Emoji/4c0a926464a7419e48df32b3c5e26794e2f9b312/langchain_emoji/components/minio/__init__.py
--------------------------------------------------------------------------------
/langchain_emoji/components/minio/minio_component.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import base64
3 | from injector import inject, singleton
4 | from langchain_emoji.settings.settings import Settings
5 | from minio import Minio
6 | from minio.error import MinioException
7 |
8 | logger = logging.getLogger(__name__)
9 |
10 |
11 | @singleton
12 | class MinioComponent:
13 | @inject
14 | def __init__(self, settings: Settings) -> None:
15 | if not settings.minio:
16 |             raise Exception("minio config does not exist! please check the settings")
17 | self.minio_settings = settings.minio
18 | self.minio_client = Minio(
19 | endpoint=self.minio_settings.host,
20 | access_key=self.minio_settings.access_key,
21 | secret_key=self.minio_settings.secret_key,
22 | secure=False,
23 | )
24 |
25 |     def get_file_base64(self, file_name: str) -> str | None:
26 | try:
27 | response = self.minio_client.get_object(
28 | self.minio_settings.bucket_name, file_name
29 | )
30 | # Read the object content
31 | object_data = response.read()
32 |
33 | # Encode object data to base64
34 | base64_data = base64.b64encode(object_data)
35 |
36 | return base64_data.decode("utf-8")
37 | except MinioException as e:
38 | logger.error(f"get file base64 failed : {e}")
39 | return None
40 |
41 |     def get_download_link(self, file_name: str) -> str | None:
42 | try:
43 | # Generate presigned URL for download
44 | presigned_url = self.minio_client.presigned_get_object(
45 | self.minio_settings.bucket_name, file_name
46 | )
47 | return presigned_url
48 | except MinioException as e:
49 | logger.error(f"get share link failed : {e}")
50 | return None
51 |
52 |
53 | if __name__ == "__main__":
54 | from langchain_emoji.settings.settings import settings
55 |
56 | mc = MinioComponent(settings())
57 |
58 | obj = "06e07a24-df07-4781-a1da-58739ac65404.jpg"
59 |
60 | print(mc.get_file_base64(obj))
61 | print(mc.get_download_link(obj))
62 |
--------------------------------------------------------------------------------
/langchain_emoji/components/trace/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ptonlix/LangChain-Emoji/4c0a926464a7419e48df32b3c5e26794e2f9b312/langchain_emoji/components/trace/__init__.py
--------------------------------------------------------------------------------
/langchain_emoji/components/trace/trace_component.py:
--------------------------------------------------------------------------------
1 | import logging
2 |
3 | from injector import inject, singleton
4 | from langchain_emoji.settings.settings import Settings
5 | from langsmith import Client
6 | from langsmith.utils import LangSmithError
7 | import os
8 | import asyncio
9 |
10 | logger = logging.getLogger(__name__)
11 |
12 |
13 | @singleton
14 | class TraceComponent:
15 | @inject
16 | def __init__(self, settings: Settings) -> None:
17 | os.environ["LANGCHAIN_TRACING_V2"] = str(settings.langsmith.trace_version_v2)
18 | os.environ["LANGCHAIN_PROJECT"] = str(settings.langsmith.langchain_project)
19 | os.environ["LANGCHAIN_API_KEY"] = settings.langsmith.api_key
20 | self.trace_client = Client(api_key=settings.langsmith.api_key)
21 |
22 | async def _arun(self, func, *args, **kwargs):
23 | return await asyncio.get_running_loop().run_in_executor(
24 | None, func, *args, **kwargs
25 | )
26 |
27 | async def aget_trace_url(self, run_id: str) -> str:
28 | for i in range(5):
29 | try:
30 | await self._arun(self.trace_client.read_run, run_id)
31 | break
32 | except LangSmithError:
33 |                 await asyncio.sleep(2**i)  # exponential backoff
34 |
35 | if await self._arun(self.trace_client.run_is_shared, run_id):
36 | return await self._arun(self.trace_client.read_run_shared_link, run_id)
37 | return await self._arun(self.trace_client.share_run, run_id)
38 |
--------------------------------------------------------------------------------
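A sketch of fetching a share link for a finished run; the run id is illustrative, and valid
LangSmith credentials in the settings are assumed:

    import asyncio

    from langchain_emoji.settings.settings import settings

    async def main() -> None:
        trace = TraceComponent(settings())
        url = await trace.aget_trace_url("00000000-0000-0000-0000-000000000000")
        print(url)

    asyncio.run(main())
--------------------------------------------------------------------------------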
/langchain_emoji/components/vector_store/__init__.py:
--------------------------------------------------------------------------------
1 | from langchain_emoji.components.vector_store.vector_store_component import (
2 | VectorStoreComponent,
3 | )
4 |
5 | __all__ = ["VectorStoreComponent"]
6 |
--------------------------------------------------------------------------------
/langchain_emoji/components/vector_store/chroma/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ptonlix/LangChain-Emoji/4c0a926464a7419e48df32b3c5e26794e2f9b312/langchain_emoji/components/vector_store/chroma/__init__.py
--------------------------------------------------------------------------------
/langchain_emoji/components/vector_store/chroma/chroma.py:
--------------------------------------------------------------------------------
1 | from langchain_chroma import Chroma
2 | from langchain_core.documents import Document
3 | from typing import Any, Iterable, List, Optional
4 |
5 |
6 | class EmojiChroma(Chroma):
7 |
8 | def add_original_texts_with_filename(
9 | self,
10 | filename: str,
11 | texts: Iterable[str],
12 | metadatas: Optional[List[dict]] = None,
13 | timeout: Optional[int] = None,
14 | batch_size: int = 1000,
15 | **kwargs: Any,
16 | ) -> List[str]:
17 |         texts = list(texts)
18 |         if not metadatas:
19 |             # add_texts expects one metadata dict per text
20 |             metadatas = [{"filename": filename} for _ in texts]
21 |         return self.add_texts(
22 |             texts=texts,
23 |             metadatas=metadatas,
24 |         )
25 |
26 | def similarity_search_by_filenames(
27 | self, query: str, filenames: List[str], k: int = 4
28 | ) -> List[Document]:
29 | where = None
30 | if len(filenames) > 0:
31 | where = {"filename": {"$in": filenames}}
32 |
33 | return self.similarity_search(query, k=k, filter=where)
34 |
35 | def delete_texts_with_filenames(
36 | self,
37 | document_ids: List[str],
38 | filenames: List[str] = [],
39 | batch_size: int = 20,
40 | expr: Optional[str] = None,
41 | timeout: Optional[int] = None,
42 | ):
43 | common_ids = document_ids
44 |
45 | if len(filenames) > 0:
46 | where = {"filename": {"$in": filenames}}
47 | result = self._collection.get(where=where)
48 |             # result["ids"]: every id stored under the given filenames
49 | common_ids = [value for value in result.get("ids") if value in document_ids]
50 |
51 | return self.delete(
52 | ids=common_ids,
53 | )
54 |
--------------------------------------------------------------------------------
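A sketch of the filename-scoped workflow that EmojiChroma adds on top of Chroma. The
embeddings object is assumed to be configured elsewhere (any LangChain Embeddings
implementation); the names below are illustrative:

    store = EmojiChroma(
        collection_name="emoji",
        embedding_function=my_embeddings,  # hypothetical Embeddings instance
    )
    ids = store.add_original_texts_with_filename(
        filename="emoji.jsonl",
        texts=["a laughing cat", "a crying dog"],
    )
    docs = store.similarity_search_by_filenames("laughing", ["emoji.jsonl"])
    store.delete_texts_with_filenames(document_ids=ids, filenames=["emoji.jsonl"])
--------------------------------------------------------------------------------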
/langchain_emoji/components/vector_store/tencent/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ptonlix/LangChain-Emoji/4c0a926464a7419e48df32b3c5e26794e2f9b312/langchain_emoji/components/vector_store/tencent/__init__.py
--------------------------------------------------------------------------------
/langchain_emoji/components/vector_store/tencent/tencent.py:
--------------------------------------------------------------------------------
1 | """Wrapper around the Tencent vector database."""
2 |
3 | from __future__ import annotations
4 |
5 | import json
6 | import logging
7 | import time
8 | from typing import Any, Dict, Iterable, List, Optional, Tuple
9 |
10 | import numpy as np
11 | from langchain_core.documents import Document
12 | from langchain_core.embeddings import Embeddings
13 | from langchain_core.utils import guard_import
14 | from langchain_core.vectorstores import VectorStore
15 |
16 | from langchain.vectorstores.utils import maximal_marginal_relevance
17 |
18 |
19 | logger = logging.getLogger(__name__)
20 |
21 |
22 | class ConnectionParams:
23 | """Tencent vector DB Connection params.
24 |
25 | See the following documentation for details:
26 | https://cloud.tencent.com/document/product/1709/95820
27 |
28 | Attribute:
29 | url (str) : The access address of the vector database server
30 | that the client needs to connect to.
31 | key (str): API key for client to access the vector database server,
32 | which is used for authentication.
33 | username (str) : Account for client to access the vector database server.
34 | timeout (int) : Request Timeout.
35 | """
36 |
37 | def __init__(self, url: str, key: str, username: str = "root", timeout: int = 10):
38 | self.url = url
39 | self.key = key
40 | self.username = username
41 | self.timeout = timeout
42 |
43 |
44 | class IndexParams:
45 | """Tencent vector DB Index params.
46 |
47 | See the following documentation for details:
48 | https://cloud.tencent.com/document/product/1709/95826
49 | """
50 |
51 | def __init__(
52 | self,
53 | dimension: int,
54 | shard: int = 1,
55 | replicas: int = 2,
56 | index_type: str = "HNSW",
57 | metric_type: str = "L2",
58 | params: Optional[Dict] = None,
59 | ):
60 | self.dimension = dimension
61 | self.shard = shard
62 | self.replicas = replicas
63 | self.index_type = index_type
64 | self.metric_type = metric_type
65 | self.params = params
66 |
67 |
68 | class TencentVectorDB(VectorStore):
69 | """Initialize wrapper around the tencent vector database.
70 |
71 | In order to use this you need to have a database instance.
72 | See the following documentation for details:
73 | https://cloud.tencent.com/document/product/1709/94951
74 | """
75 |
76 | field_id: str = "id"
77 | field_vector: str = "vector"
78 | field_text: str = "text"
79 | field_metadata: str = "metadata"
80 |
81 | def __init__(
82 | self,
83 | connection_params: ConnectionParams,
84 | index_params: IndexParams = IndexParams(128),
85 | database_name: str = "LangChainDatabase",
86 | collection_name: str = "LangChainCollection",
87 | drop_old: Optional[bool] = False,
88 | embedding: Optional[Embeddings] = None,
89 | ):
90 | self.document = guard_import("tcvectordb.model.document")
91 | tcvectordb = guard_import("tcvectordb")
92 | self.embedding_func = embedding
93 | self.index_params = index_params
94 |         self.ebd_own = (
95 |             embedding is None
96 |         )  # whether to use Tencent Cloud's built-in embedding
97 | self.vdb_client = tcvectordb.VectorDBClient(
98 | url=connection_params.url,
99 | username=connection_params.username,
100 | key=connection_params.key,
101 | timeout=connection_params.timeout,
102 | )
103 | db_list = self.vdb_client.list_databases()
104 | db_exist: bool = False
105 | for db in db_list:
106 | if database_name == db.database_name:
107 | db_exist = True
108 | break
109 | if db_exist:
110 | self.database = self.vdb_client.database(database_name)
111 | else:
112 | self.database = self.vdb_client.create_database(database_name)
113 | try:
114 | self.collection = self.database.describe_collection(collection_name)
115 | if drop_old:
116 | self.database.drop_collection(collection_name)
117 | self._create_collection(collection_name)
118 | except tcvectordb.exceptions.VectorDBException:
119 | self._create_collection(collection_name)
120 |
121 | def _create_collection(self, collection_name: str) -> None:
122 | enum = guard_import("tcvectordb.model.enum")
123 | vdb_index = guard_import("tcvectordb.model.index")
124 | coll = guard_import("tcvectordb.model.collection")
125 | index_type = None
126 | for k, v in enum.IndexType.__members__.items():
127 | if k == self.index_params.index_type:
128 | index_type = v
129 | if index_type is None:
130 | raise ValueError("unsupported index_type")
131 | metric_type = None
132 | for k, v in enum.MetricType.__members__.items():
133 | if k == self.index_params.metric_type:
134 | metric_type = v
135 | if metric_type is None:
136 | raise ValueError("unsupported metric_type")
137 | if self.index_params.params is None:
138 | params = vdb_index.HNSWParams(m=16, efconstruction=200)
139 | else:
140 | params = vdb_index.HNSWParams(
141 | m=self.index_params.params.get("M", 16),
142 | efconstruction=self.index_params.params.get("efConstruction", 200),
143 | )
144 | index = vdb_index.Index(
145 | vdb_index.FilterIndex(
146 | self.field_id, enum.FieldType.String, enum.IndexType.PRIMARY_KEY
147 | ),
148 | vdb_index.VectorIndex(
149 | self.field_vector,
150 | self.index_params.dimension,
151 | index_type,
152 | metric_type,
153 | params,
154 | ),
155 | vdb_index.FilterIndex(
156 | self.field_text, enum.FieldType.String, enum.IndexType.FILTER
157 | ),
158 | vdb_index.FilterIndex(
159 | self.field_metadata, enum.FieldType.String, enum.IndexType.FILTER
160 | ),
161 | )
162 | # Embedding config
163 | ebd = None
164 | if self.ebd_own:
165 | ebd = coll.Embedding(
166 | vector_field=self.field_vector,
167 | field=self.field_text,
168 | model=enum.EmbeddingModel.BGE_BASE_ZH,
169 | )
170 |
171 | self.collection = self.database.create_collection(
172 | name=collection_name,
173 | shard=self.index_params.shard,
174 | replicas=self.index_params.replicas,
175 | description="Collection for LangChain",
176 | index=index,
177 | embedding=ebd,
178 | )
179 |
180 | @property
181 | def embeddings(self) -> Embeddings:
182 | return self.embedding_func
183 |
184 | @classmethod
185 | def from_texts(
186 | cls,
187 | texts: List[str],
188 | metadatas: Optional[List[dict]] = None,
189 | connection_params: Optional[ConnectionParams] = None,
190 | index_params: Optional[IndexParams] = None,
191 | database_name: str = "LangChainDatabase",
192 | collection_name: str = "LangChainCollection",
193 | drop_old: Optional[bool] = False,
194 | embedding: Optional[Embeddings] = None,
195 | **kwargs: Any,
196 | ) -> TencentVectorDB:
197 | """Create a collection, indexes it with HNSW, and insert data."""
198 | if len(texts) == 0:
199 | raise ValueError("texts is empty")
200 | if connection_params is None:
201 | raise ValueError("connection_params is empty")
202 | vector_db = None
203 | if embedding:
204 | try:
205 | embeddings = embedding.embed_documents(texts[0:1])
206 | except NotImplementedError:
207 | embeddings = [embedding.embed_query(texts[0])]
208 | dimension = len(embeddings[0])
209 | if index_params is None:
210 | index_params = IndexParams(dimension=dimension)
211 | else:
212 | index_params.dimension = dimension
213 | vector_db = cls(
214 | embedding=embedding,
215 | connection_params=connection_params,
216 | index_params=index_params,
217 | database_name=database_name,
218 | collection_name=collection_name,
219 | drop_old=drop_old,
220 | )
221 | vector_db.add_texts(texts=texts, metadatas=metadatas, **kwargs)
222 |         else:  # use Tencent Cloud's built-in embedding
223 | if index_params is None:
224 | index_params = IndexParams(dimension=768)
225 | vector_db = cls(
226 | embedding=embedding,
227 | connection_params=connection_params,
228 | index_params=index_params,
229 | database_name=database_name,
230 | collection_name=collection_name,
231 | drop_old=drop_old,
232 | )
233 | vector_db.add_original_texts(texts=texts, metadatas=metadatas, **kwargs)
234 | return vector_db
235 |
236 | def add_texts(
237 | self,
238 | texts: Iterable[str],
239 | metadatas: Optional[List[dict]] = None,
240 | timeout: Optional[int] = None,
241 | batch_size: int = 1000,
242 | **kwargs: Any,
243 | ) -> List[str]:
244 | """Insert text data into TencentVectorDB."""
245 | texts = list(texts)
246 | pks: list[str] = []
247 | try:
248 | embeddings = self.embedding_func.embed_documents(texts)
249 | except NotImplementedError:
250 | embeddings = [self.embedding_func.embed_query(x) for x in texts]
251 | if len(embeddings) == 0:
252 | logger.debug("Nothing to insert, skipping.")
253 | return []
254 | total_count = len(embeddings)
255 | for start in range(0, total_count, batch_size):
256 | # Grab end index
257 | docs = []
258 | end = min(start + batch_size, total_count)
259 | for id in range(start, end, 1):
260 | metadata = "{}"
261 | if metadatas is not None:
262 | metadata = json.dumps(metadatas[id])
263 | doc = self.document.Document(
264 | id="{}-{}-{}".format(time.time_ns(), hash(texts[id]), id),
265 | vector=embeddings[id],
266 | text=texts[id],
267 | metadata=metadata,
268 | **kwargs,
269 | )
270 | docs.append(doc)
271 | pks.append(str(id))
272 | self.collection.upsert(docs, timeout)
273 | return pks
274 |
275 | def add_original_texts(
276 | self,
277 | texts: Iterable[str],
278 | metadatas: Optional[List[dict]] = None,
279 | timeout: Optional[int] = None,
280 | batch_size: int = 1000,
281 | **kwargs: Any,
282 | ) -> List[str]:
283 | """Insert text data into TencentVectorDB."""
284 | texts = list(texts)
285 | pks: list[str] = []
286 | total_count = len(texts)
287 | for start in range(0, total_count, batch_size):
288 | # Grab end index
289 | docs = []
290 | end = min(start + batch_size, total_count)
291 | for id in range(start, end, 1):
292 | metadata = "{}"
293 | if metadatas is not None:
294 | metadata = json.dumps(metadatas[id])
295 | vdb_id = "{}-{}-{}".format(time.time_ns(), hash(texts[id]), id)
296 | doc = self.document.Document(
297 | id=vdb_id,
298 | text=texts[id],
299 | metadata=metadata,
300 | **kwargs,
301 | )
302 | docs.append(doc)
303 |                 pks.append(vdb_id)  # return the ids assigned in the vector database
304 | self.collection.upsert(docs, timeout)
305 | return pks
306 |
307 | def similarity_search(
308 | self,
309 | query: str,
310 | k: int = 4,
311 | param: Optional[dict] = None,
312 | expr: Optional[str] = None,
313 | timeout: Optional[int] = None,
314 | **kwargs: Any,
315 | ) -> List[Document]:
316 | """Perform a similarity search against the query string."""
317 | res = self.similarity_search_with_score(
318 | query=query, k=k, param=param, expr=expr, timeout=timeout, **kwargs
319 | )
320 | return [doc for doc, _ in res]
321 |
322 | def similarity_search_with_score(
323 | self,
324 | query: str,
325 | k: int = 4,
326 | param: Optional[dict] = None,
327 | expr: Optional[str] = None,
328 | timeout: Optional[int] = None,
329 | **kwargs: Any,
330 | ) -> List[Tuple[Document, float]]:
331 | """Perform a search on a query string and return results with score."""
332 | res: List[Tuple[Document, float]] = []
333 | if self.ebd_own:
334 | res = self.similarity_search_with_score_by_text(
335 | text=[query],
336 | k=k,
337 | param=param,
338 | expr=expr,
339 | timeout=timeout,
340 | **kwargs,
341 | )
342 | else:
343 | # Embed the query text.
344 | embedding = self.embedding_func.embed_query(query)
345 | res = self.similarity_search_with_score_by_vector(
346 | embedding=embedding,
347 | k=k,
348 | param=param,
349 | expr=expr,
350 | timeout=timeout,
351 | **kwargs,
352 | )
353 | return res
354 |
355 | def similarity_search_by_vector(
356 | self,
357 | embedding: List[float],
358 | k: int = 4,
359 | param: Optional[dict] = None,
360 | expr: Optional[str] = None,
361 | timeout: Optional[int] = None,
362 | **kwargs: Any,
363 | ) -> List[Document]:
364 | """Perform a similarity search against the query string."""
365 | res = self.similarity_search_with_score_by_vector(
366 | embedding=embedding, k=k, param=param, expr=expr, timeout=timeout, **kwargs
367 | )
368 | return [doc for doc, _ in res]
369 |
370 | def similarity_search_with_score_by_text(
371 | self,
372 | text: List[str],
373 | k: int = 4,
374 | param: Optional[dict] = None,
375 | expr: Optional[str] = None,
376 | timeout: Optional[int] = None,
377 | **kwargs: Any,
378 | ) -> List[Tuple[Document, float]]:
379 | """Perform a search on a query string and return results with score."""
380 | filter = None if expr is None else self.document.Filter(expr)
381 | ef = 10 if param is None else param.get("ef", 10)
382 | res_data = self.collection.searchByText(
383 | embeddingItems=text,
384 | filter=filter,
385 | params=self.document.HNSWSearchParams(ef=ef),
386 | retrieve_vector=False,
387 | limit=k,
388 | timeout=timeout,
389 | )
390 | if "documents" not in res_data:
391 | raise ValueError(res_data)
392 | res: List[List[Dict]] = res_data.get("documents")
393 | # Organize results.
394 | ret: List[Tuple[Document, float]] = []
395 | if res is None or len(res) == 0:
396 | return ret
397 | for result in res[0]:
398 | meta = result.get(self.field_metadata)
399 | if meta is not None:
400 | meta = json.loads(meta)
401 | doc = Document(page_content=result.get(self.field_text), metadata=meta)
402 | pair = (doc, result.get("score", 0.0))
403 | ret.append(pair)
404 | return ret
405 |
406 | def similarity_search_with_score_by_vector(
407 | self,
408 | embedding: List[float],
409 | k: int = 4,
410 | param: Optional[dict] = None,
411 | expr: Optional[str] = None,
412 | timeout: Optional[int] = None,
413 | **kwargs: Any,
414 | ) -> List[Tuple[Document, float]]:
415 | """Perform a search on a query string and return results with score."""
416 | filter = None if expr is None else self.document.Filter(expr)
417 | ef = 10 if param is None else param.get("ef", 10)
418 | res_data = self.collection.search(
419 | vectors=[embedding],
420 | filter=filter,
421 | params=self.document.HNSWSearchParams(ef=ef),
422 | retrieve_vector=False,
423 | limit=k,
424 | timeout=timeout,
425 | )
426 | if "documents" not in res_data:
427 | raise ValueError(res_data)
428 |         res: List[List[Dict]] = res_data.get("documents")
429 |         ret: List[Tuple[Document, float]] = []  # Organize results.
430 | if res is None or len(res) == 0:
431 | return ret
432 | for result in res[0]:
433 | meta = result.get(self.field_metadata)
434 | if meta is not None:
435 | meta = json.loads(meta)
436 | doc = Document(page_content=result.get(self.field_text), metadata=meta)
437 | pair = (doc, result.get("score", 0.0))
438 | ret.append(pair)
439 | return ret
440 |
441 | def max_marginal_relevance_search(
442 | self,
443 | query: str,
444 | k: int = 4,
445 | fetch_k: int = 20,
446 | lambda_mult: float = 0.5,
447 | param: Optional[dict] = None,
448 | expr: Optional[str] = None,
449 | timeout: Optional[int] = None,
450 | **kwargs: Any,
451 | ) -> List[Document]:
452 | """Perform a search and return results that are reordered by MMR."""
453 | embedding = self.embedding_func.embed_query(query)
454 | return self.max_marginal_relevance_search_by_vector(
455 | embedding=embedding,
456 | k=k,
457 | fetch_k=fetch_k,
458 | lambda_mult=lambda_mult,
459 | param=param,
460 | expr=expr,
461 | timeout=timeout,
462 | **kwargs,
463 | )
464 |
465 | def max_marginal_relevance_search_by_vector(
466 | self,
467 | embedding: list[float],
468 | k: int = 4,
469 | fetch_k: int = 20,
470 | lambda_mult: float = 0.5,
471 | param: Optional[dict] = None,
472 | expr: Optional[str] = None,
473 | timeout: Optional[int] = None,
474 | **kwargs: Any,
475 | ) -> List[Document]:
476 | """Perform a search and return results that are reordered by MMR."""
477 | filter = None if expr is None else self.document.Filter(expr)
478 | ef = 10 if param is None else param.get("ef", 10)
479 | res: List[List[Dict]] = self.collection.search(
480 | vectors=[embedding],
481 | filter=filter,
482 | params=self.document.HNSWSearchParams(ef=ef),
483 | retrieve_vector=True,
484 | limit=fetch_k,
485 | timeout=timeout,
486 | )
487 | # Organize results.
488 | documents = []
489 | ordered_result_embeddings = []
490 | for result in res[0]:
491 | meta = result.get(self.field_metadata)
492 | if meta is not None:
493 | meta = json.loads(meta)
494 | doc = Document(page_content=result.get(self.field_text), metadata=meta)
495 | documents.append(doc)
496 | ordered_result_embeddings.append(result.get(self.field_vector))
497 | # Get the new order of results.
498 | new_ordering = maximal_marginal_relevance(
499 | np.array(embedding), ordered_result_embeddings, k=k, lambda_mult=lambda_mult
500 | )
501 | # Reorder the values and return.
502 | ret = []
503 | for x in new_ordering:
504 | # Function can return -1 index
505 | if x == -1:
506 | break
507 | else:
508 | ret.append(documents[x])
509 | return ret
510 |
511 | def delete_texts_by_ids(
512 | self,
513 | document_ids: List[str],
514 | batch_size: int = 20,
515 | expr: Optional[str] = None,
516 | timeout: Optional[int] = None,
517 | ) -> List[dict]:
518 | filter = None if expr is None else self.document.Filter(expr)
519 | total_count = len(document_ids)
520 | ret = []
521 | for start in range(0, total_count, batch_size):
522 | # Grab end index
523 | docs = []
524 | end = min(start + batch_size, total_count)
525 | docs = document_ids[start:end]
526 | res_data = self.collection.delete(
527 | document_ids=docs, filter=filter, timeout=timeout
528 | )
529 |
530 | if res_data.get("code") != 0:
531 | raise ValueError(f"delete texts failed: {res_data}")
532 | ret.append(res_data)
533 | return ret
534 |
535 |
536 | class EmojiTencentVectorDB(TencentVectorDB):
537 | field_filename: str = "filename"
538 |
539 | def __init__(
540 | self,
541 | connection_params: ConnectionParams,
542 |         index_params: IndexParams = IndexParams(768, replicas=0),  # replicas=0: single instance by default
543 | database_name: str = "DeepReadDatabase",
544 | collection_name: str = "EmojiCollection",
545 | embedding: Optional[Embeddings] = None,
546 | drop_old: Optional[bool] = False,
547 | ):
548 | super().__init__(
549 | embedding=embedding,
550 | connection_params=connection_params,
551 | index_params=index_params,
552 | database_name=database_name,
553 | collection_name=collection_name,
554 | drop_old=drop_old,
555 | )
556 |
557 | def _create_collection(self, collection_name: str) -> None:
558 | enum = guard_import("tcvectordb.model.enum")
559 | vdb_index = guard_import("tcvectordb.model.index")
560 | coll = guard_import("tcvectordb.model.collection")
561 | index_type = None
562 | for k, v in enum.IndexType.__members__.items():
563 | if k == self.index_params.index_type:
564 | index_type = v
565 | if index_type is None:
566 | raise ValueError("unsupported index_type")
567 | metric_type = None
568 | for k, v in enum.MetricType.__members__.items():
569 | if k == self.index_params.metric_type:
570 | metric_type = v
571 | if metric_type is None:
572 | raise ValueError("unsupported metric_type")
573 | if self.index_params.params is None:
574 | params = vdb_index.HNSWParams(m=16, efconstruction=200)
575 | else:
576 | params = vdb_index.HNSWParams(
577 | m=self.index_params.params.get("M", 16),
578 | efconstruction=self.index_params.params.get("efConstruction", 200),
579 | )
580 | index = vdb_index.Index(
581 | vdb_index.FilterIndex(
582 | self.field_id, enum.FieldType.String, enum.IndexType.PRIMARY_KEY
583 | ),
584 | vdb_index.VectorIndex(
585 | self.field_vector,
586 | self.index_params.dimension,
587 | index_type,
588 | metric_type,
589 | params,
590 | ),
591 | vdb_index.FilterIndex(
592 | self.field_text, enum.FieldType.String, enum.IndexType.FILTER
593 | ),
594 | vdb_index.FilterIndex(
595 | self.field_metadata, enum.FieldType.String, enum.IndexType.FILTER
596 | ),
597 | vdb_index.FilterIndex(
598 | self.field_filename, enum.FieldType.String, enum.IndexType.FILTER
599 | ),
600 | )
601 | # Embedding config
602 | ebd = None
603 | if self.ebd_own:
604 | ebd = coll.Embedding(
605 | vector_field=self.field_vector,
606 | field=self.field_text,
607 | model=enum.EmbeddingModel.BGE_BASE_ZH,
608 | )
609 |
610 | self.collection = self.database.create_collection(
611 | name=collection_name,
612 | shard=self.index_params.shard,
613 | replicas=self.index_params.replicas,
614 | description="Collection for DeepRead",
615 | index=index,
616 | embedding=ebd,
617 | )
618 |
619 | @classmethod
620 | def from_texts_with_filename(
621 | cls,
622 | filename: str,
623 | texts: List[str],
624 | metadatas: Optional[List[dict]] = None,
625 | connection_params: Optional[ConnectionParams] = None,
626 | database_name: str = "DeepReadDatabase",
627 | collection_name: str = "DeepReadCollection",
628 | ) -> TencentVectorDB:
629 | return cls.from_texts(
630 | texts=texts,
631 | metadatas=metadatas,
632 | connection_params=connection_params,
633 | database_name=database_name,
634 | collection_name=collection_name,
635 | filename=filename,
636 | )
637 |
638 | def add_original_texts_with_filename(
639 | self,
640 | filename: str,
641 | texts: Iterable[str],
642 | metadatas: Optional[List[dict]] = None,
643 | timeout: Optional[int] = None,
644 | batch_size: int = 1000,
645 | **kwargs: Any,
646 | ) -> List[str]:
647 | return self.add_original_texts(
648 | texts=texts,
649 | metadatas=metadatas,
650 | timeout=timeout,
651 | batch_size=batch_size,
652 | filename=filename,
653 | )
654 |
655 | def similarity_search_by_filenames(
656 | self, query: str, filenames: List[str], k: int = 4
657 | ) -> List[Document]:
658 |         expr = None
659 |         if len(filenames) > 0:
660 |             # Tencent VDB filter syntax: field="value" terms joined by " or "
661 |             expr = " or ".join(
662 |                 f'{self.field_filename}="{name}"' for name in filenames)
663 |
664 | return self.similarity_search(query, k=k, expr=expr)
665 |
666 | def delete_texts_with_filenames(
667 | self,
668 | document_ids: List[str],
669 | filenames: List[str] = [],
670 | batch_size: int = 20,
671 | expr: Optional[str] = None,
672 | timeout: Optional[int] = None,
673 | ):
674 |         expr = None
675 |         if len(filenames) > 0:
676 |             # Same filter syntax as similarity_search_by_filenames above
677 |             expr = " or ".join(
678 |                 f'{self.field_filename}="{name}"' for name in filenames)
679 |
680 | return self.delete_texts_by_ids(
681 | document_ids=document_ids,
682 | batch_size=batch_size,
683 | expr=expr,
684 | timeout=timeout,
685 | )
686 |
--------------------------------------------------------------------------------
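A sketch of the same filename-scoped workflow against Tencent VectorDB. The endpoint and key
are placeholders; with embedding=None the collection relies on Tencent's built-in
BGE_BASE_ZH embedding (the ebd_own path above):

    params = ConnectionParams(url="http://your-vdb-endpoint:80", key="your-api-key")
    vdb = EmojiTencentVectorDB(connection_params=params)

    ids = vdb.add_original_texts_with_filename(
        filename="emoji.jsonl",
        texts=["a laughing cat", "a crying dog"],
    )
    docs = vdb.similarity_search_by_filenames("laughing", ["emoji.jsonl"])
    vdb.delete_texts_with_filenames(document_ids=ids, filenames=["emoji.jsonl"])
--------------------------------------------------------------------------------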
/langchain_emoji/components/vector_store/vector_store_component.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import os
3 | from injector import inject, singleton
4 | from pathlib import Path
5 | from langchain_emoji.settings.settings import Settings
6 | from langchain_emoji.components.vector_store.tencent.tencent import (
7 | EmojiTencentVectorDB,
8 | ConnectionParams,
9 | )
10 | from langchain_emoji.components.vector_store.chroma.chroma import EmojiChroma
11 | from chromadb.config import Settings as ChromaSettings
12 |
13 | from langchain_emoji.constants import PROJECT_ROOT_PATH
14 | from langchain_emoji.components.embedding.embedding_component import EmbeddingComponent
15 |
16 | logger = logging.getLogger(__name__)
17 |
18 |
19 | @singleton
20 | class VectorStoreComponent:
21 | @inject
22 | def __init__(self, embed: EmbeddingComponent, settings: Settings) -> None:
23 | self.embedcom = embed
24 | match settings.vectorstore.database:
25 | case "tcvectordb":
26 | tcvectorconf = settings.vectorstore.tcvectordb
27 | connect_params = ConnectionParams(
28 | url=tcvectorconf.url, key=tcvectorconf.api_key
29 | )
30 | self.vector_store = EmojiTencentVectorDB(
31 | connection_params=connect_params,
32 | database_name=tcvectorconf.database_name,
33 | collection_name=tcvectorconf.collection_name,
34 | )
35 | case "chromadb":
36 | data_dir = PROJECT_ROOT_PATH / settings.vectorstore.chromadb.persist_dir
37 | persist_directory = self.create_persist_directory("chromadb", data_dir)
38 | collection_name = settings.vectorstore.chromadb.collection_name
39 | self.vector_store = EmojiChroma(
40 | collection_name,
41 | embed._embedding,
42 | client_settings=ChromaSettings(
43 | anonymized_telemetry=False,
44 | is_persistent=True,
45 | persist_directory=persist_directory,
46 | ),
47 | )
48 | case _:
49 | # Should be unreachable
50 | # The settings validator should have caught this
51 | raise ValueError(
52 | f"Vectorstore database {settings.vectorstore.database} not supported"
53 | )
54 |
55 | def create_persist_directory(self, vectordb: str, data_dir: Path) -> str:
56 | if not (os.path.exists(data_dir) and os.path.isdir(data_dir)):
57 |             raise FileNotFoundError(f"{data_dir} does not exist, please check the config")
58 |
59 | persist_directory = data_dir / vectordb
60 |         # Create the directory if it does not exist
61 | if not os.path.exists(persist_directory):
62 | os.makedirs(persist_directory)
63 |
64 | return str(persist_directory)
65 |
66 | def close(self) -> None: ...
67 |
--------------------------------------------------------------------------------
/langchain_emoji/constants.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 |
3 | PROJECT_ROOT_PATH: Path = Path(__file__).parents[1]
4 |
--------------------------------------------------------------------------------
/langchain_emoji/di.py:
--------------------------------------------------------------------------------
1 | from injector import Injector
2 |
3 | from langchain_emoji.settings.settings import Settings, unsafe_typed_settings
4 |
5 |
6 | def create_application_injector() -> Injector:
7 | _injector = Injector(auto_bind=True)
8 | _injector.binder.bind(Settings, to=unsafe_typed_settings)
9 | return _injector
10 |
11 |
12 | """
13 | Global injector for the application.
14 |
15 | Avoid using this reference; it will make your code harder to test.
16 | 
17 | Instead, use the `request.state.injector` reference, which is bound to every request.
18 | """
19 | global_injector: Injector = create_application_injector()
20 |
--------------------------------------------------------------------------------
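A sketch of how the injector is used: with auto_bind=True, dependencies are constructed on
first request, and Settings resolves to the unsafe_typed_settings binding above:

    from langchain_emoji.di import global_injector
    from langchain_emoji.settings.settings import Settings

    settings = global_injector.get(Settings)  # the bound unsafe_typed_settings
    print(settings.llm.mode)                  # e.g. "openai", "zhipuai", "mock"
--------------------------------------------------------------------------------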
/langchain_emoji/launcher.py:
--------------------------------------------------------------------------------
1 | """FastAPI app creation, logger configuration and main API routes."""
2 |
3 | import logging
4 | from typing import Any
5 |
6 | from fastapi import Depends, FastAPI, Request
7 | from fastapi.middleware.cors import CORSMiddleware
8 | from fastapi.openapi.utils import get_openapi
9 | from injector import Injector
10 | from langchain_emoji.paths import docs_path
11 | from langchain_emoji.settings.settings import Settings
12 | from langchain_emoji.server.emoji.emoji_router import emoji_router
13 | from langchain_emoji.server.vector_store.vector_store_router import vector_store_router
14 | from langchain_emoji.server.trace.trace_router import trace_router
15 | from langchain_emoji.server.health.health_router import health_router
16 | from langchain_emoji.server.config.config_router import (
17 | config_router_no_auth,
18 | config_router,
19 | )
20 |
21 | logger = logging.getLogger(__name__)
22 |
23 |
24 | def create_app(root_injector: Injector) -> FastAPI:
25 | # Start the API
26 | with open(docs_path / "description.md") as description_file:
27 | description = description_file.read()
28 |
29 | tags_metadata = [
30 | {
31 | "name": "Emoji",
32 | "description": "AI-enabled Emoji engine",
33 | },
34 | {
35 | "name": "VectorStore",
36 | "description": "Store the emoji vectorically",
37 | },
38 | {
39 | "name": "Config",
40 | "description": "Obtain and modify project configuration files",
41 | },
42 | {
43 | "name": "Health",
44 | "description": "Simple health API to make sure the server is up and running.",
45 | },
46 | ]
47 |
48 | async def bind_injector_to_request(request: Request) -> None:
49 | request.state.injector = root_injector
50 |
51 | app = FastAPI(dependencies=[Depends(bind_injector_to_request)])
52 |
53 | def custom_openapi() -> dict[str, Any]:
54 | if app.openapi_schema:
55 | return app.openapi_schema
56 | openapi_schema = get_openapi(
57 |             title="LangChain-Emoji",
58 | description=description,
59 | version="0.1.0",
60 |             summary="AI Emoji Argue Agent 🚀 An open-source emoji battle agent built on LangChain",
61 | contact={
62 | "url": "https://github.com/ptonlix",
63 | },
64 | license_info={
65 | "name": "Apache 2.0",
66 | "url": "https://www.apache.org/licenses/LICENSE-2.0.html",
67 | },
68 | routes=app.routes,
69 | tags=tags_metadata,
70 | )
71 | openapi_schema["info"]["x-logo"] = {
72 | "url": "https://lh3.googleusercontent.com/drive-viewer"
73 | "/AK7aPaD_iNlMoTquOBsw4boh4tIYxyEuhz6EtEs8nzq3yNkNAK00xGj"
74 | "E1KUCmPJSk3TYOjcs6tReG6w_cLu1S7L_gPgT9z52iw=s2560"
75 | }
76 |
77 | app.openapi_schema = openapi_schema
78 | return app.openapi_schema
79 |
80 | app.openapi = custom_openapi # type: ignore[method-assign]
81 |
82 | app.include_router(emoji_router)
83 | app.include_router(trace_router)
84 | app.include_router(vector_store_router)
85 | app.include_router(health_router)
86 | app.include_router(config_router_no_auth)
87 | app.include_router(config_router)
88 |
89 | settings = root_injector.get(Settings)
90 | if settings.server.cors.enabled:
91 | logger.debug("Setting up CORS middleware")
92 | app.add_middleware(
93 | CORSMiddleware,
94 | allow_credentials=settings.server.cors.allow_credentials,
95 | allow_origins=settings.server.cors.allow_origins,
96 | allow_origin_regex=settings.server.cors.allow_origin_regex,
97 | allow_methods=settings.server.cors.allow_methods,
98 | allow_headers=settings.server.cors.allow_headers,
99 | )
100 |
101 | return app
102 |
--------------------------------------------------------------------------------
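A sketch of serving the app; host and port are illustrative (the project also ships its own
entry point in __main__.py, not shown here):

    import uvicorn

    from langchain_emoji.di import global_injector
    from langchain_emoji.launcher import create_app

    app = create_app(global_injector)
    uvicorn.run(app, host="0.0.0.0", port=8000)
--------------------------------------------------------------------------------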
/langchain_emoji/main.py:
--------------------------------------------------------------------------------
1 | """FastAPI app creation, logger configuration and main API routes."""
2 |
3 | from langchain_emoji.di import global_injector
4 | from langchain_emoji.launcher import create_app
5 |
6 | app = create_app(global_injector)
7 |
--------------------------------------------------------------------------------
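
`main.py` only exposes the module-level `app`; one way to serve it is to point uvicorn at it (a sketch; port 8003 mirrors the default in `settings.yaml`, and the project may ship its own entry point in `__main__.py`):

```python
# Sketch: serve langchain_emoji.main:app with uvicorn on the default port.
import uvicorn

from langchain_emoji.main import app

if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=8003)
```
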
/langchain_emoji/paths.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 |
3 | from langchain_emoji.constants import PROJECT_ROOT_PATH
4 | from langchain_emoji.settings.settings import settings
5 |
6 |
7 | def _absolute_or_from_project_root(path: str) -> Path:
8 | if path.startswith("/"):
9 | return Path(path)
10 | return PROJECT_ROOT_PATH / path
11 |
12 |
13 | docs_path: Path = PROJECT_ROOT_PATH / "docs"
14 |
15 | local_data_path: Path = _absolute_or_from_project_root(
16 | settings().data.local_data_folder
17 | )
18 |
--------------------------------------------------------------------------------
/langchain_emoji/server/config/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ptonlix/LangChain-Emoji/4c0a926464a7419e48df32b3c5e26794e2f9b312/langchain_emoji/server/config/__init__.py
--------------------------------------------------------------------------------
/langchain_emoji/server/config/config_router.py:
--------------------------------------------------------------------------------
1 | import logging
2 | from fastapi import APIRouter, Depends, Request
3 | from typing import Dict, List
4 | from langchain_emoji.server.utils.auth import authenticated
5 | from langchain_emoji.server.utils.model import (
6 | RestfulModel,
7 | SystemErrorCode,
8 | )
9 | from langchain_emoji.settings.settings_loader import (
10 | get_active_settings,
11 | save_active_settings,
12 | )
13 |
14 |
15 | logger = logging.getLogger(__name__)
16 |
17 | config_router_no_auth = APIRouter(prefix="/v1")
18 |
19 | config_router = APIRouter(prefix="/v1", dependencies=[Depends(authenticated)])
20 |
21 |
22 | @config_router_no_auth.get(
23 | "/config",
24 | response_model=RestfulModel[List[Dict]],
25 | tags=["Config"],
26 | )
27 | async def get_config(request: Request) -> RestfulModel:
28 |
29 | try:
30 | return RestfulModel(data=get_active_settings())
31 | except Exception as e:
32 | return RestfulModel(code=SystemErrorCode, msg=str(e), data=None)
33 |
34 |
35 | @config_router.post(
36 | "/config",
37 | response_model=RestfulModel[None],
38 | tags=["Config"],
39 | )
40 | async def edit_config(request: Request, body: List[Dict]) -> RestfulModel:
41 | try:
42 | for config_dict in body:
43 | for profile, config in config_dict.items():
44 | logger.debug("Updating profile %s: %s", profile, config)
45 | save_active_settings(profile, config)
46 |
47 | return RestfulModel(data=None)
48 | except Exception as e:
49 | logger.exception(e)
50 | return RestfulModel(code=SystemErrorCode, msg=str(e), data=None)
51 |
--------------------------------------------------------------------------------
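
A hypothetical client session against a local instance, matching the two routes above (the Basic credential and port 8003 mirror `settings.yaml`; `requests` is an assumed extra dependency):

```python
import requests

BASE = "http://localhost:8003/v1"
AUTH = {"Authorization": "Basic c2VjcmV0OmtleQ=="}

# GET /v1/config needs no auth and returns one {profile: config} map per active profile.
print(requests.get(f"{BASE}/config").json())

# POST /v1/config accepts the same shape back and persists each profile's YAML file.
body = [{"default": {"llm": {"mode": "openai"}}}]
print(requests.post(f"{BASE}/config", json=body, headers=AUTH).json())
```
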
/langchain_emoji/server/emoji/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ptonlix/LangChain-Emoji/4c0a926464a7419e48df32b3c5e26794e2f9b312/langchain_emoji/server/emoji/__init__.py
--------------------------------------------------------------------------------
/langchain_emoji/server/emoji/emoji_prompt.py:
--------------------------------------------------------------------------------
1 | RESPONSE_TEMPLATE = """\
2 | # Role: 一个表情包专家,擅长根据用户描述为用户选取最合适表情包
3 |
4 | ## Language: 中文
5 |
6 | ## Workflow
7 | 1. 学习 ##EmojiList 中给出表情包列表中每一个表情包的含义。其filename属性记录了表情包文件名`filename`,内容则是表情包的含义表示。
8 | 2. 根据 ##UserInput,选取一个最符合的表情包并返回, 一定不要自己构造数据,按照指定的JSON格式结构输出结果, 包含以下2个关键输出字段: `filename`、`content`。
9 |
10 | ## EmojiList
11 |
12 | {context}
13 |
14 |
15 | ## UserInput
16 |
17 | {prompt}
18 |
19 |
20 | ## Output format
21 |
22 | The output should be formatted as a JSON instance that conforms to the JSON schema below.
23 | filename: str
24 | content: str
25 | As an example, for the schema
26 | {{
27 | "filename":"",
28 | "content":"",
29 | }}
30 |
31 | 输出示例:
32 | ```json
33 | {{
34 | "filename": "5a122755-9316-4d05-81f4-26da5396c04e.jpg",
35 | "content": "这个表情包中的内容和笑点在于它展示了许多带有悲伤或不满情绪的表情符号,这些表情符号的脸部表情看起来都非常忧郁或不高兴。图片下方的文字“我的世界一片灰色”可能意味着这个表情包的使用者感到沮丧或情绪低落,就像世界失去了颜色一样。这种夸张的表达方式和文字与表情符号的结合,使得这个表情包在传达负面情绪的同时,也带有一定的幽默感。"
36 | }}
37 | ```
38 |
39 |
40 | ## Start
41 | 作为一个 #Role, 你默认使用的是##Language,你不需要介绍自己,请根据##Workflow开始工作,你必须严格遵守输出格式##Output format,输出格式指定的JSON格式要求。
42 | """
43 |
44 | # Backup prompt below (ZhipuAI variant)
45 |
46 | ZHIPUAI_RESPONSE_TEMPLATE = """
47 | 表情包列表:
48 | {context}
49 |
50 | 用户描述:
51 | {prompt}
52 |
53 | 请根据以下要求,根据用户描述为用户选取最合适表情包:
54 |
55 | 1. 学习`表情包列表`中每一个表情包的含义, 其中metadata属性记录了表情包文件名,内容则是表情包的含义表示
56 |
57 | 2. 根据`用户描述`,选取一个最符合的表情包,一定不要自己构造数据,请按照指定的JSON格式结构输出结果, 包含以下2个关键输出字段: `filename`、`content`,具体格式如下:
58 | ```json
59 | {{
60 | "filename": string,
61 | "content": string
62 | }}
63 | ```
64 | 输出示例:
65 | ```json
66 | {{
67 | "filename": "5a122755-9316-4d05-81f4-26da5396c04e.jpg",
68 | "content": "这个表情包中的内容和笑点在于它展示了许多带有悲伤或不满情绪的表情符号,这些表情符号的脸部表情看起来都非常忧郁或不高兴。图片下方的文字“我的世界一片灰色”可能意味着这个表情包的使用者感到沮丧或情绪低落,就像世界失去了颜色一样。这种夸张的表达方式和文字与表情符号的结合,使得这个表情包在传达负面情绪的同时,也带有一定的幽默感。"
69 | }}
70 | ```
71 |
72 | 请严格按照上述要求进行信息提取、格式输出,并遵守输出格式指定的JSON格式要求。
73 | """
74 |
--------------------------------------------------------------------------------
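
A note on the doubled braces in both templates: LangChain's f-string templating treats `{{` and `}}` as literal braces, so only `{context}` and `{prompt}` survive as input variables. A quick sketch to confirm:

```python
from langchain.prompts import ChatPromptTemplate

from langchain_emoji.server.emoji.emoji_prompt import RESPONSE_TEMPLATE

prompt = ChatPromptTemplate.from_messages([("human", RESPONSE_TEMPLATE)])
print(prompt.input_variables)  # ['context', 'prompt']
```
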
/langchain_emoji/server/emoji/emoji_router.py:
--------------------------------------------------------------------------------
1 | import logging
2 | from fastapi import APIRouter, Depends, Request
3 | from langchain_emoji.server.utils.auth import authenticated
4 | from langchain_emoji.server.emoji.emoji_service import (
5 | EmojiService,
6 | EmojiRequest,
7 | EmojiResponse,
8 | )
9 | from langchain_emoji.server.utils.model import (
10 | RestfulModel,
11 | SystemErrorCode,
12 | )
13 |
14 | logger = logging.getLogger(__name__)
15 |
16 | emoji_router = APIRouter(prefix="/v1", dependencies=[Depends(authenticated)])
17 |
18 |
19 | @emoji_router.post(
20 | "/emoji",
21 | response_model=RestfulModel[EmojiResponse | int | None],
22 | tags=["Emoji"],
23 | )
24 | async def emoji_invoke(request: Request, body: EmojiRequest) -> RestfulModel:
25 | """
26 | Call directly to return search results
27 | """
28 | service = request.state.injector.get(EmojiService)
29 | try:
30 | return RestfulModel(data=await service.get_emoji(body))
31 | except Exception as e:
32 | logger.exception(e)
33 | return RestfulModel(code=SystemErrorCode, msg=str(e), data=None)
34 |
--------------------------------------------------------------------------------
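
A hypothetical call to this endpoint on a local instance; `req_id` is a caller-chosen correlation id and `llm` selects the backend branch used by the service chain:

```python
import requests

resp = requests.post(
    "http://localhost:8003/v1/emoji",
    headers={"Authorization": "Basic c2VjcmV0OmtleQ=="},
    json={"prompt": "exhausted after working overtime", "req_id": "demo-001", "llm": "openai"},
)
payload = resp.json()
print(payload["data"]["emojiinfo"])  # {'filename': '...', 'content': '...'}
```
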
/langchain_emoji/server/emoji/emoji_service.py:
--------------------------------------------------------------------------------
1 | from injector import inject
2 | from langchain_emoji.components.llm.llm_component import LLMComponent
3 | from langchain_emoji.components.trace.trace_component import TraceComponent
4 | from langchain_emoji.components.minio.minio_component import MinioComponent
5 | from langchain_emoji.components.vector_store.vector_store_component import (
6 | VectorStoreComponent,
7 | )
8 |
9 | from langchain_emoji.server.emoji.emoji_prompt import (
10 | RESPONSE_TEMPLATE,
11 | ZHIPUAI_RESPONSE_TEMPLATE,
12 | )
13 | from langchain.schema.output_parser import StrOutputParser
14 | from pydantic import BaseModel, Field
15 | import logging
16 | import tiktoken
17 | from langchain_emoji.settings.settings import Settings
18 | from langchain.schema.document import Document
19 | from langchain.schema.runnable import (
20 | Runnable,
21 | RunnableLambda,
22 | RunnableBranch,
23 | RunnableMap,
24 | )
25 | from langchain.schema.language_model import BaseLanguageModel
26 | from langchain.schema.messages import BaseMessage
27 | from langchain.schema.retriever import BaseRetriever
28 | from langchain.prompts import ChatPromptTemplate
29 | from langchain.callbacks.base import AsyncCallbackHandler
30 | from langchain.schema.runnable import ConfigurableField
31 | from operator import itemgetter
32 |
33 | from langchain_community.callbacks import get_openai_callback
34 | from langchain_emoji.components.llm.custom.zhipuai import get_zhipuai_callback
35 | from langchain_emoji.paths import local_data_path
36 | from uuid import UUID
37 | from json.decoder import JSONDecodeError
38 | import json
39 | import base64
40 | import re
41 | from typing import (
42 | List,
43 | Optional,
44 | Sequence,
45 | Dict,
46 | Any,
47 | )
48 |
49 |
50 | logger = logging.getLogger(__name__)
51 |
52 |
53 | class TokenInfo(BaseModel):
54 | model: str
55 | total_tokens: int = 0
56 | prompt_tokens: int = 0
57 | completion_tokens: int = 0
58 | embedding_tokens: int = 0
59 | successful_requests: int = 0
60 | total_cost: float = 0.0
61 |
62 | def clear(self):
63 | self.total_tokens = 0
64 | self.prompt_tokens = 0
65 | self.completion_tokens = 0
66 | self.successful_requests = 0
67 | self.total_cost = 0.0
68 |
69 |
70 | class EmojiRequest(BaseModel):
71 | prompt: str
72 | req_id: str
74 | llm: str = Field(default="openai", description="LLM backend to use: openai, deepseek or zhipuai")
74 |
75 | model_config = {
76 | "json_schema_extra": {
77 | "examples": [{"prompt": "xxxx", "req_id": "xxxx", "llm": "xxx"}]
78 | }
79 | }
80 |
81 |
82 | class EmojiInfo(BaseModel):
83 | filename: str
84 | content: str
85 |
86 |
87 | class EmojiDetail(BaseModel):
88 | download_link: Optional[str] = None
89 | base64: str
90 |
91 |
92 | class EmojiResponse(BaseModel):
93 | run_id: UUID
94 | emojiinfo: EmojiInfo
95 | emojidetail: EmojiDetail
96 | token_info: TokenInfo
97 |
98 |
99 | """
100 | 读取chain run_id的回调
101 | """
102 |
103 |
104 | class ReadRunIdAsyncHandler(AsyncCallbackHandler):
105 |
106 | def __init__(self):
107 | self.runid: Optional[UUID] = None
108 |
109 | async def on_chain_start(
110 | self,
111 | serialized: Dict[str, Any],
112 | inputs: Dict[str, Any],
113 | *,
114 | run_id: UUID,
115 | parent_run_id: Optional[UUID] = None,
116 | tags: Optional[List[str]] = None,
117 | metadata: Optional[Dict[str, Any]] = None,
118 | **kwargs: Any,
119 | ) -> None:
120 | """Run when chain starts running."""
121 | if not self.runid:
122 | self.runid = run_id
123 |
124 | async def on_chat_model_start(
125 | self,
126 | serialized: Dict[str, Any],
127 | messages: List[List[BaseMessage]],
128 | *,
129 | run_id: UUID,
130 | parent_run_id: Optional[UUID] = None,
131 | tags: Optional[List[str]] = None,
132 | metadata: Optional[Dict[str, Any]] = None,
133 | **kwargs: Any,
134 | ) -> Any: ...
135 |
136 | async def on_retriever_end(
137 | self,
138 | documents: Sequence[Document],
139 | *,
140 | run_id: UUID,
141 | parent_run_id: Optional[UUID] = None,
142 | tags: Optional[List[str]] = None,
143 | **kwargs: Any,
144 | ) -> None:
145 | """Run on retriever end."""
146 | # TODO: no suitable filter option yet; the source link is still obtained via a global variable
147 | ...
148 |
149 | def get_runid(self) -> UUID:
150 | return self.runid
151 |
152 |
153 | def fix_json(json_str: str) -> dict:
154 | # Use regexes to strip trailing commas before closing braces/brackets
155 | fixed_json_str = re.sub(r",\s*}", "}", json_str)
156 | fixed_json_str = re.sub(r",\s*]", "]", fixed_json_str)
157 |
158 | # Try to parse the repaired JSON string
159 | return json.loads(fixed_json_str)
160 |
161 |
162 | class EmojiService:
163 |
164 | @inject
165 | def __init__(
166 | self,
167 | llm_component: LLMComponent,
168 | vector_component: VectorStoreComponent,
169 | trace_component: TraceComponent,
170 | minio_component: MinioComponent,
171 | settings: Settings,
172 | ) -> None:
173 | self.settings = settings
174 | self.llm_service = llm_component
175 | self.vector_service = vector_component
176 | self.trace_service = trace_component
177 | self.minio_service = minio_component
178 | self.chain = self.create_chain(
179 | self.llm_service.llm,
180 | self.get_vector_retriever(),
181 | )
182 |
183 | def num_tokens_from_string(self, string: str) -> int:
184 | """Returns the number of tokens in a text string."""
185 | encoding = tiktoken.encoding_for_model(self.llm_service.modelname)
186 | num_tokens = len(encoding.encode(string))
187 | return num_tokens
188 |
189 | """
190 | 防止返回不是json格式,增加校验处理
191 | """
192 |
193 | def output_handle(self, output: str) -> dict:
194 | try:
195 | # Locate the JSON substring in the raw output
196 | start_index = output.find("{")
197 | end_index = output.rfind("}") + 1
198 | json_str = output[start_index:end_index]
199 |
200 | # Parse it into a dict
201 | return json.loads(json_str)
202 | except JSONDecodeError as e:
203 | logger.exception(e)
204 | return fix_json(json_str)
205 |
206 | async def get_emoji(self, body: EmojiRequest) -> EmojiResponse | None:
207 | logger.info(body)
208 | token_callback = (
209 | get_openai_callback if body.llm == "openai" else get_zhipuai_callback
210 | )
211 | with token_callback() as cb:
212 | read_runid = ReadRunIdAsyncHandler()  # callback to capture the run_id
213 | result = await self.chain.ainvoke(
214 | input={"prompt": body.prompt, "llm": body.llm},
215 | config={
216 | "metadata": {
217 | "req_id": body.req_id,
218 | },
219 | "configurable": {"llm": body.llm},
220 | "callbacks": [cb, read_runid],
221 | },
222 | )
223 |
224 | logger.info(result)
225 | emojiinfo = EmojiInfo(**result)
226 |
227 | embed_tokens = self.vector_service.embedcom.total_tokens
228 | tokeninfo = TokenInfo(
229 | model=body.llm,
230 | total_tokens=cb.total_tokens + int(embed_tokens / 10),
231 | prompt_tokens=cb.prompt_tokens,
232 | completion_tokens=cb.completion_tokens,
233 | embedding_tokens=int(embed_tokens / 10),
234 | successful_requests=cb.successful_requests,
235 | total_cost=cb.total_cost,
236 | )
237 |
238 | resobj = EmojiResponse(
239 | run_id=read_runid.get_runid(),
240 | emojiinfo=emojiinfo,
241 | emojidetail=self.get_file_desc(emojiinfo),
242 | token_info=tokeninfo,
243 | )
244 | return resobj
245 |
246 | def get_file_desc(self, info: EmojiInfo) -> EmojiDetail:
247 | logger.info(self.settings.dataset.mode)
248 | if self.settings.dataset.mode == "local":
249 | emoji_file = (
250 | local_data_path / self.settings.dataset.name / "emo" / info.filename
251 | )
252 | with open(emoji_file, "rb") as image_file:
253 | encoded_string = base64.b64encode(image_file.read())
254 | file_base64 = encoded_string.decode("utf-8")
255 | return EmojiDetail(base64=file_base64)
256 | elif self.settings.dataset.mode == "minio":
257 | file_base64 = self.minio_service.get_file_base64(info.filename)
258 | file_download_link = self.minio_service.get_download_link(info.filename)
259 | return EmojiDetail(base64=file_base64, download_link=file_download_link)
260 |
261 | def get_vector_retriever(self):
262 | base_vector = self.vector_service.vector_store.as_retriever().configurable_alternatives(
263 | # This gives this field an id
264 | # When configuring the end runnable, we can then use this id to configure this field
265 | ConfigurableField(id="vectordb"),
266 | default_key=self.settings.vectorstore.database,
267 | )
268 |
269 | return base_vector.with_config(run_name="VectorRetriever")
270 |
271 | def create_retriever_chain(self, retriever: BaseRetriever) -> Runnable:
272 | return (
273 | RunnableLambda(itemgetter("prompt")).with_config(
274 | run_name="Itemgetter:prompt"
275 | )
276 | | retriever
277 | ).with_config(run_name="RetrievalChain")
278 |
279 | def format_docs(self, docs: Sequence[Document]) -> str:
280 | logger.info(docs)
281 | formatted_docs = []
282 | for doc in docs:
283 | filename = doc.metadata.get("filename")
284 | doc_string = (
285 | f"filename: {filename}, content: {doc.page_content}"
286 | )
287 | formatted_docs.append(doc_string)
288 | return "\n".join(formatted_docs)
289 |
290 | def create_chain(
291 | self,
292 | llm: BaseLanguageModel,
293 | retriever: BaseRetriever,
294 | ) -> Runnable:
295 | retriever_chain = self.create_retriever_chain(retriever) | RunnableLambda(
296 | self.format_docs
297 | )
298 | _context = RunnableMap(
299 | {
300 | "context": retriever_chain.with_config(run_name="RetrievalChain"),
301 | "prompt": RunnableLambda(itemgetter("prompt")).with_config(
302 | run_name="Itemgetter:prompt"
303 | ),
304 | "llm": RunnableLambda(itemgetter("llm")).with_config(
305 | run_name="Itemgetter:llm"
306 | ),
307 | }
308 | )
309 | _prompt = RunnableBranch(
310 | (
311 | RunnableLambda(
312 | lambda x: bool(
313 | x.get("llm") == "openai" or x.get("llm") == "deepseek"
314 | )
315 | ).with_config(run_name="CheckLLM"),
316 | ChatPromptTemplate.from_messages(
317 | [
318 | ("human", RESPONSE_TEMPLATE),
319 | ]
320 | ).with_config(run_name="OpenaiPrompt"),
321 | ),
322 | (
323 | ChatPromptTemplate.from_messages(
324 | [
325 | ("human", ZHIPUAI_RESPONSE_TEMPLATE),
326 | ]
327 | ).with_config(run_name="ZhipuaiPrompt")
328 | ),
329 | ).with_config(run_name="ChoiceLLMPrompt")
330 |
331 | response_synthesizer = (_prompt | llm | StrOutputParser()).with_config(
332 | run_name="GenerateResponse",
333 | ) | RunnableLambda(self.output_handle).with_config(run_name="ResponseHandle")
334 |
335 | return (
336 | {
337 | "prompt": RunnableLambda(itemgetter("prompt")).with_config(
338 | run_name="Itemgetter:prompt"
339 | ),
340 | "llm": RunnableLambda(itemgetter("llm")).with_config(
341 | run_name="Itemgetter:llm"
342 | ),
343 | }
344 | | _context
345 | | response_synthesizer
346 | ).with_config(run_name="EmojiChain")
347 |
--------------------------------------------------------------------------------
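
The JSON-repair path in miniature: `output_handle` slices the first `{...}` span out of the raw completion, and `fix_json` strips trailing commas when plain `json.loads` fails. A sketch (run inside a configured checkout so the module imports cleanly):

```python
from langchain_emoji.server.emoji.emoji_service import fix_json

raw = '{"filename": "a.jpg", "content": "demo",}'  # note the trailing comma
print(fix_json(raw))  # {'filename': 'a.jpg', 'content': 'demo'}
```
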
/langchain_emoji/server/health/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ptonlix/LangChain-Emoji/4c0a926464a7419e48df32b3c5e26794e2f9b312/langchain_emoji/server/health/__init__.py
--------------------------------------------------------------------------------
/langchain_emoji/server/health/health_router.py:
--------------------------------------------------------------------------------
1 | from typing import Literal
2 |
3 | from fastapi import APIRouter
4 | from pydantic import BaseModel, Field
5 | from langchain_emoji.server.utils.model import RestfulModel
6 | import logging
7 |
8 | logger = logging.getLogger(__name__)
9 |
10 | # No authentication or authorization is required to get the health status.
11 | health_router = APIRouter(prefix="/v1")
12 |
13 |
14 | class HealthResponse(BaseModel):
15 | status: Literal["ok"] = Field(default="ok")
16 |
17 |
18 | @health_router.get(
19 | "/health",
20 | tags=["Health"],
21 | response_model=RestfulModel[HealthResponse | None],
22 | )
23 | def health() -> RestfulModel:
24 | """Return ok if the system is up."""
25 | return RestfulModel(data=HealthResponse(status="ok"))
26 |
--------------------------------------------------------------------------------
/langchain_emoji/server/trace/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ptonlix/LangChain-Emoji/4c0a926464a7419e48df32b3c5e26794e2f9b312/langchain_emoji/server/trace/__init__.py
--------------------------------------------------------------------------------
/langchain_emoji/server/trace/trace_router.py:
--------------------------------------------------------------------------------
1 | import logging
2 | from fastapi import APIRouter, Depends, Request
3 | from pydantic import BaseModel
4 | from typing import Union, Optional
5 | from langchain_emoji.server.utils.auth import authenticated
6 | from langchain_emoji.server.emoji.emoji_service import (
7 | EmojiService,
8 | )
9 | from langchain_emoji.server.utils.model import (
10 | RestfulModel,
11 | SystemErrorCode,
12 | )
13 | from uuid import UUID
14 |
15 | logger = logging.getLogger(__name__)
16 |
17 | trace_router = APIRouter(prefix="/v1", dependencies=[Depends(authenticated)])
18 |
19 |
20 | class SendFeedbackBody(BaseModel):
21 | run_id: UUID
22 | key: str = "user_score"
23 |
24 | score: Union[float, int, bool, None] = None
25 | feedback_id: Optional[UUID] = None
26 | comment: Optional[str] = None
27 |
28 |
29 | @trace_router.post(
30 | "/feedback",
31 | tags=["Trace"],
32 | )
33 | async def send_feedback(request: Request, body: SendFeedbackBody):
34 | service = request.state.injector.get(EmojiService)
35 | service.trace_service.trace_client.create_feedback(
36 | body.run_id,
37 | body.key,
38 | score=body.score,
39 | comment=body.comment,
40 | feedback_id=body.feedback_id,
41 | )
42 | return RestfulModel(data="posted feedback successfully")
43 |
44 |
45 | class GetTraceBody(BaseModel):
46 | run_id: UUID
47 |
48 |
49 | @trace_router.post(
50 | "/get_trace",
51 | tags=["Trace"],
52 | )
53 | async def get_trace(request: Request, body: GetTraceBody):
54 | service = request.state.injector.get(EmojiService)
55 |
56 | run_id = body.run_id
57 | if run_id is None:
58 | return RestfulModel(
59 | code=SystemErrorCode, msg="No LangSmith run ID provided", data=None
60 | )
61 | return RestfulModel(data=await service.trace_service.aget_trace_url(str(run_id)))
62 |
--------------------------------------------------------------------------------
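
A hypothetical feedback round trip: the `run_id` returned by `/v1/emoji` is posted back so LangSmith records a user score for that run (placeholder UUID shown):

```python
import requests

HEADERS = {"Authorization": "Basic c2VjcmV0OmtleQ=="}
run_id = "00000000-0000-0000-0000-000000000000"  # take this from the /v1/emoji response

requests.post(
    "http://localhost:8003/v1/feedback",
    headers=HEADERS,
    json={"run_id": run_id, "key": "user_score", "score": 1},
)
```
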
/langchain_emoji/server/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ptonlix/LangChain-Emoji/4c0a926464a7419e48df32b3c5e26794e2f9b312/langchain_emoji/server/utils/__init__.py
--------------------------------------------------------------------------------
/langchain_emoji/server/utils/auth.py:
--------------------------------------------------------------------------------
1 | """Authentication mechanism for the API.
2 |
3 | Define a simple mechanism to authenticate requests.
4 | More complex authentication mechanisms can be defined here, and be placed in the
5 | `authenticated` method (being a 'bean' injected in fastapi routers).
6 |
7 | Authorization can also be made after the authentication, and depends on
8 | the authentication. Authorization should not be implemented in this file.
9 |
10 | Authorization can be done by following fastapi's guides:
11 | * https://fastapi.tiangolo.com/advanced/security/oauth2-scopes/
12 | * https://fastapi.tiangolo.com/tutorial/security/
13 | * https://fastapi.tiangolo.com/tutorial/dependencies/dependencies-in-path-operation-decorators/
14 | """
15 |
16 | # mypy: ignore-errors
17 | # Disabled mypy error: All conditional function variants must have identical signatures
18 | # We are changing the implementation of the authenticated method, based on
19 | # the config. If the auth is not enabled, we are not defining the complex method
20 | # with its dependencies.
21 | import logging
22 | import secrets
23 | from typing import Annotated
24 |
25 | from fastapi import Depends, Header, HTTPException
26 |
27 | from langchain_emoji.settings.settings import settings
28 |
29 | # 401 signify that the request requires authentication.
30 | # 403 signify that the authenticated user is not authorized to perform the operation.
31 | NOT_AUTHENTICATED = HTTPException(
32 | status_code=401,
33 | detail="Not authenticated",
34 | headers={"WWW-Authenticate": 'Basic realm="All the API", charset="UTF-8"'},
35 | )
36 |
37 | logger = logging.getLogger(__name__)
38 |
39 |
40 | def _simple_authentication(authorization: Annotated[str, Header()] = "") -> bool:
41 | """Check if the request is authenticated."""
42 | if not secrets.compare_digest(authorization, settings().server.auth.secret):
43 | # If the "Authorization" header is not the expected one, raise an exception.
44 | raise NOT_AUTHENTICATED
45 | return True
46 |
47 |
48 | if not settings().server.auth.enabled:
49 | logger.debug(
50 | "Defining a dummy authentication mechanism for fastapi, always authenticating requests"
51 | )
52 |
53 | # Define a dummy authentication method that always returns True.
54 | def authenticated() -> bool:
55 | """Check if the request is authenticated."""
56 | return True
57 |
58 | else:
59 | logger.info("Defining the given authentication mechanism for the API")
60 |
61 | # Method to be used as a dependency to check if the request is authenticated.
62 | def authenticated(
63 | _simple_authentication: Annotated[bool, Depends(_simple_authentication)]
64 | ) -> bool:
65 | """Check if the request is authenticated."""
66 | assert settings().server.auth.enabled
67 | if not _simple_authentication:
68 | raise NOT_AUTHENTICATED
69 | return True
70 |
--------------------------------------------------------------------------------
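
How the expected `Authorization` value is produced for HTTP basic auth, matching the hint in `settings.yaml` ("secret" is the username, "key" the password):

```python
import base64

token = "Basic " + base64.b64encode("secret:key".encode()).decode()
print(token)  # Basic c2VjcmV0OmtleQ==
```
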
/langchain_emoji/server/utils/model.py:
--------------------------------------------------------------------------------
1 | from pydantic import BaseModel
2 | from typing import Generic, TypeVar
3 |
4 |
5 | T = TypeVar("T") # 泛型类型 T
6 |
7 |
8 | class RestfulModel(BaseModel, Generic[T]):
9 | code: int = 0
10 | msg: str = "success"
11 | data: T
12 |
13 |
14 | """
15 | Server Error Code List
16 |
17 | """
18 |
19 | """
20 | system: 10000-10099
21 | """
22 | SystemErrorCode = 10001
23 |
--------------------------------------------------------------------------------
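
`RestfulModel` is a generic envelope: parametrizing it pins the type of `data` both for validation and for the generated OpenAPI schema. A small sketch:

```python
from langchain_emoji.server.utils.model import RestfulModel, SystemErrorCode

ok = RestfulModel[int](data=42)  # defaults: code=0, msg="success"
err = RestfulModel[None](code=SystemErrorCode, msg="boom", data=None)
print(ok.model_dump())  # {'code': 0, 'msg': 'success', 'data': 42}
```
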
/langchain_emoji/server/vector_store/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ptonlix/LangChain-Emoji/4c0a926464a7419e48df32b3c5e26794e2f9b312/langchain_emoji/server/vector_store/__init__.py
--------------------------------------------------------------------------------
/langchain_emoji/server/vector_store/vector_store_router.py:
--------------------------------------------------------------------------------
1 | import logging
2 | from fastapi import APIRouter, Depends, Request
3 | from typing import Any, List
4 | from pydantic import BaseModel, Field
5 | from langchain_emoji.server.utils.auth import authenticated
6 | from langchain_emoji.server.vector_store.vector_store_server import (
7 | VectorStoreService,
8 | EmojiFragment,
9 | )
10 | from langchain_emoji.server.utils.model import (
11 | RestfulModel,
12 | SystemErrorCode,
13 | )
14 |
15 | logger = logging.getLogger(__name__)
16 |
17 | vector_store_router = APIRouter(prefix="/v1", dependencies=[Depends(authenticated)])
18 |
19 |
20 | class AddEmojiBody(BaseModel):
21 | content: str = Field(description="Emoji description")
22 | filename: str = Field(description="Emoji filename")
23 | model_config = {
24 | "json_schema_extra": {
25 | "examples": [
26 | {
27 | "content": "xxxx",
28 | "filename": "xxxx",
29 | }
30 | ]
31 | }
32 | }
33 |
34 |
35 | class RagEmojiBody(BaseModel):
36 | prompt: str = Field(description="Emoji description to search for")
37 | filenames: List[str] = Field(default=[], description="List of emoji filenames to filter by")
38 |
39 | model_config = {
40 | "json_schema_extra": {
41 | "examples": [{"prompt": "xxxx", "filenames": ["xxx", "xxx"]}]
42 | }
43 | }
44 |
45 |
46 | class DelEmojiBody(BaseModel):
47 | vdb_ids: List[str] = Field(description="Vector database IDs of the emoji records")
48 | filenames: List[str] = Field(default=[], description="List of emoji filenames")
49 | model_config = {
50 | "json_schema_extra": {
51 | "examples": [{"vdb_ids": ["xxxx"], "filenames": ["xxx", "xxx"]}]
52 | }
53 | }
54 |
55 |
56 | @vector_store_router.post(
57 | "/vector_store/add_emoji",
58 | response_model=RestfulModel[Any | None],
59 | tags=["VectorStore"],
60 | )
61 | def add_emoji(request: Request, body: AddEmojiBody) -> RestfulModel:
62 | """
63 | New article into vector database
64 | """
65 | service = request.state.injector.get(VectorStoreService)
66 | try:
67 | return RestfulModel(
68 | data=service.add_emoji(
69 | content=body.content,
70 | filename=body.filename,
71 | )
72 | )
73 | except Exception as e:
74 | logger.exception(e)
75 | return RestfulModel(code=SystemErrorCode, msg=str(e), data=None)
76 |
77 |
78 | @vector_store_router.post(
79 | "/vector_store/rag_emoji",
80 | response_model=RestfulModel[List[EmojiFragment] | None],
81 | tags=["VectorStore"],
82 | )
83 | def rag_emoji(request: Request, body: RagEmojiBody) -> RestfulModel:
84 | """
85 | Recall articles from vector database
86 | """
87 | service = request.state.injector.get(VectorStoreService)
88 | try:
89 | return RestfulModel(
90 | data=service.rag_emoji(prompt=body.prompt, filenames=body.filenames)
91 | )
92 | except Exception as e:
93 | logger.exception(e)
94 | return RestfulModel(code=SystemErrorCode, msg=str(e), data=None)
95 |
96 |
97 | @vector_store_router.post(
98 | "/vector_store/del_emoji",
99 | response_model=RestfulModel[List[dict] | None],
100 | tags=["VectorStore"],
101 | )
102 | def del_emoji(request: Request, body: DelEmojiBody) -> RestfulModel:
103 | """
104 | Delete articles from vector database
105 | """
106 | service = request.state.injector.get(VectorStoreService)
107 | try:
108 | return RestfulModel(
109 | data=service.del_emoji(vdb_ids=body.vdb_ids, filenames=body.filenames)
110 | )
111 | except Exception as e:
112 | logger.exception(e)
113 | return RestfulModel(code=SystemErrorCode, msg=str(e), data=None)
114 |
--------------------------------------------------------------------------------
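
Hypothetical calls that index one emoji and then recall it, with field names mirroring `AddEmojiBody` and `RagEmojiBody` above:

```python
import requests

BASE = "http://localhost:8003/v1/vector_store"
HEADERS = {"Authorization": "Basic c2VjcmV0OmtleQ=="}

requests.post(
    f"{BASE}/add_emoji",
    headers=HEADERS,
    json={"content": "a gloomy face captioned 'my world is gray'", "filename": "a.jpg"},
)
hits = requests.post(
    f"{BASE}/rag_emoji", headers=HEADERS, json={"prompt": "feeling down", "filenames": []}
).json()
print(hits["data"])  # up to k=3 EmojiFragment records
```
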
/langchain_emoji/server/vector_store/vector_store_server.py:
--------------------------------------------------------------------------------
1 | from injector import inject, singleton
2 | from langchain_emoji.components.vector_store import VectorStoreComponent
3 | from pydantic import BaseModel, Field
4 | import logging
5 | from typing import List, Optional
6 |
7 | logger = logging.getLogger(__name__)
8 |
9 |
10 | class EmojiFragment(BaseModel):
11 | content: str = Field(description="Emoji description fragment")
12 | filename: Optional[str] = Field(default=None, description="Emoji filename")
13 |
14 |
15 | @singleton
16 | class VectorStoreService:
17 | @inject
18 | def __init__(
19 | self,
20 | vector_store: VectorStoreComponent,
21 | ) -> None:
22 | self.client = vector_store.vector_store
23 |
24 | def add_emoji(
25 | self,
26 | content: str,
27 | filename: str,
28 | ) -> List[str]:
29 | metadata = {
30 | "filename": filename,
31 | }
32 | return self.client.add_original_texts_with_filename(
33 | filename=filename, texts=[content], metadatas=[metadata]
34 | )
35 |
36 | def rag_emoji(self, prompt: str, filenames: List[str] = []) -> List[EmojiFragment]:
37 | fragment_list = self.client.similarity_search_by_filenames(
38 | query=prompt, filenames=filenames, k=3
39 | )
40 | res = []
41 | for fragment in fragment_list:
42 | res.append(
43 | EmojiFragment(content=fragment.page_content, **fragment.metadata)
44 | )
45 | return res
46 |
47 | def del_emoji(self, vdb_ids: List[str], filenames: List[str] = []) -> List[dict]:
48 | return self.client.delete_texts_with_filenames(
49 | document_ids=vdb_ids, filenames=filenames
50 | )
51 |
--------------------------------------------------------------------------------
/langchain_emoji/settings/__init__.py:
--------------------------------------------------------------------------------
1 | """Settings."""
2 |
--------------------------------------------------------------------------------
/langchain_emoji/settings/settings.py:
--------------------------------------------------------------------------------
1 | from typing import Literal, Optional
2 |
3 | from pydantic import BaseModel, Field
4 |
5 | from langchain_emoji.settings.settings_loader import load_active_settings
6 |
7 |
8 | class CorsSettings(BaseModel):
9 | """CORS configuration.
10 |
11 | For more details on the CORS configuration, see:
12 | # * https://fastapi.tiangolo.com/tutorial/cors/
13 | # * https://developer.mozilla.org/en-US/docs/Web/HTTP/CORS
14 | """
15 |
16 | enabled: bool = Field(
17 | description="Flag indicating if CORS headers are set or not."
18 | "If set to True, the CORS headers will be set to allow all origins, methods and headers.",
19 | default=False,
20 | )
21 | allow_credentials: bool = Field(
22 | description="Indicate that cookies should be supported for cross-origin requests",
23 | default=False,
24 | )
25 | allow_origins: list[str] = Field(
26 | description="A list of origins that should be permitted to make cross-origin requests.",
27 | default=[],
28 | )
29 | allow_origin_regex: Optional[str] = Field(
30 | description="A regex string to match against origins that should be permitted to make cross-origin requests.",
31 | default=None,
32 | )
33 | allow_methods: list[str] = Field(
34 | description="A list of HTTP methods that should be allowed for cross-origin requests.",
35 | default=[
36 | "GET",
37 | ],
38 | )
39 | allow_headers: list[str] = Field(
40 | description="A list of HTTP request headers that should be supported for cross-origin requests.",
41 | default=[],
42 | )
43 |
44 |
45 | class AuthSettings(BaseModel):
46 | """Authentication configuration.
47 |
48 | The implementation of the authentication strategy must
49 | """
50 |
51 | enabled: bool = Field(
52 | description="Flag indicating if authentication is enabled or not.",
53 | default=False,
54 | )
55 | secret: str = Field(
56 | description="The secret to be used for authentication. "
57 | "It can be any non-blank string. For HTTP basic authentication, "
58 | "this value should be the whole 'Authorization' header that is expected"
59 | )
60 |
61 |
62 | class ServerSettings(BaseModel):
63 | env_name: str = Field(
64 | description="Name of the environment (prod, staging, local...)"
65 | )
66 | port: int = Field(
67 | description="Port of Langchain-DeepRead FastAPI server, defaults to 8001"
68 | )
69 | cors: CorsSettings = Field(
70 | description="CORS configuration", default=CorsSettings(enabled=False)
71 | )
72 | auth: AuthSettings = Field(
73 | description="Authentication configuration",
74 | default_factory=lambda: AuthSettings(enabled=False, secret="secret-key"),
75 | )
76 |
77 |
78 | class LLMSettings(BaseModel):
79 | mode: Literal["local", "openai", "zhipuai", "deepseek", "all", "mock"]
80 | max_new_tokens: int = Field(
81 | 256,
82 | description="The maximum number of token that the LLM is authorized to generate in one completion.",
83 | )
84 |
85 |
86 | class OpenAISettings(BaseModel):
87 | temperature: float
88 | modelname: str
89 | api_key: str
90 | api_base: str
91 |
92 |
93 | class DeepSeekSettings(BaseModel):
94 | temperature: float
95 | modelname: str
96 | api_key: str
97 | api_base: str
98 |
99 |
100 | class ZhipuAISettings(BaseModel):
101 | temperature: float
102 | top_p: float
103 | modelname: str
104 | api_key: str
105 |
106 |
107 | class LangSmithSettings(BaseModel):
108 | trace_version_v2: bool
109 | langchain_project: str
110 | api_key: str
111 |
112 |
113 | class TvectordbSettings(BaseModel):
114 | url: str
115 | username: str
116 | api_key: str
117 | collection_name: str
118 | database_name: str
119 |
120 |
121 | class EmbeddingSettings(BaseModel):
122 | mode: Literal["local", "openai", "zhipuai", "mock"]
123 |
124 |
125 | class ChromadbSettings(BaseModel):
126 | persist_dir: str
127 | collection_name: str
128 |
129 |
130 | class VectorstoreSettings(BaseModel):
131 | database: Literal["tcvectordb", "chromadb"]
132 | tcvectordb: TvectordbSettings
133 | chromadb: ChromadbSettings
134 |
135 |
136 | class DataSettings(BaseModel):
137 | local_data_folder: str = Field(
138 | description="Path to local storage."
139 | "It will be treated as an absolute path if it starts with /"
140 | )
141 |
142 |
143 | class MinioSettings(BaseModel):
144 | host: str
145 | bucket_name: str
146 | access_key: str
147 | secret_key: str
148 |
149 |
150 | class DatasetSettings(BaseModel):
151 | name: str
152 | google_driver_id: str
153 | mode: Literal["minio", "local"]
154 |
155 |
156 | class Settings(BaseModel):
157 | server: ServerSettings
158 | llm: LLMSettings
159 | openai: OpenAISettings
160 | deepseek: DeepSeekSettings
161 | zhipuai: ZhipuAISettings
162 | langsmith: LangSmithSettings
163 | vectorstore: VectorstoreSettings
164 | embedding: EmbeddingSettings
165 | data: DataSettings
166 | minio: Optional[MinioSettings] = None
167 | dataset: DatasetSettings
168 |
169 |
170 | """
171 | This is visible just for DI or testing purposes.
172 |
173 | Use dependency injection or `settings()` method instead.
174 | """
175 | unsafe_settings = load_active_settings()
176 |
177 | """
178 | This is visible just for DI or testing purposes.
179 |
180 | Use dependency injection or `settings()` method instead.
181 | """
182 | unsafe_typed_settings = Settings(**unsafe_settings)
183 |
184 |
185 | def settings() -> Settings:
186 | """Get the current loaded settings from the DI container.
187 |
188 | This method exists to keep compatibility with the existing code,
189 | that require global access to the settings.
190 |
191 | For regular components use dependency injection instead.
192 | """
193 | from langchain_emoji.di import global_injector
194 |
195 | return global_injector.get(Settings)
196 |
--------------------------------------------------------------------------------
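
Reading the typed settings through the DI container, as `settings()` does internally (a sketch; requires a resolvable `settings.yaml`):

```python
from langchain_emoji.di import global_injector
from langchain_emoji.settings.settings import Settings

cfg = global_injector.get(Settings)
print(cfg.server.port, cfg.vectorstore.database)  # e.g. 8003 chromadb
```
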
/langchain_emoji/settings/settings_loader.py:
--------------------------------------------------------------------------------
1 | import functools
2 | import logging
3 | import os
4 | import sys
5 | from collections.abc import Iterable
6 | from pathlib import Path
7 | from typing import Any, List
8 |
9 | from pydantic.v1.utils import deep_update, unique_list
10 |
11 | from langchain_emoji.constants import PROJECT_ROOT_PATH
12 | from langchain_emoji.settings.yaml import (
13 | load_yaml_with_envvars,
14 | update_yaml_config_file,
15 | )
16 |
17 | logger = logging.getLogger(__name__)
18 |
19 | _settings_folder = os.environ.get("LE_SETTINGS_FOLDER", PROJECT_ROOT_PATH)
20 |
21 | # if running in unittest, use the test profile
22 | _test_profile = ["test"] if "unittest" in sys.modules else []
23 |
24 | active_profiles: list[str] = unique_list(
25 | ["default"]
26 | + [
27 | item.strip()
28 | for item in os.environ.get("LE_PROFILES", "").split(",")
29 | if item.strip()
30 | ]
31 | + _test_profile
32 | )
33 |
34 |
35 | def merge_settings(settings: Iterable[dict[str, Any]]) -> dict[str, Any]:
36 | return functools.reduce(deep_update, settings, {})
37 |
38 |
39 | def load_settings_from_profile(profile: str) -> dict[str, Any]:
40 | if profile == "default":
41 | profile_file_name = "settings.yaml"
42 | else:
43 | profile_file_name = f"settings-{profile}.yaml"
44 |
45 | path = Path(_settings_folder) / profile_file_name
46 | with Path(path).open("r") as f:
47 | config = load_yaml_with_envvars(f)
48 | if not isinstance(config, dict):
49 | raise TypeError(f"Config file has no top-level mapping: {path}")
50 | return config
51 |
52 |
53 | def load_active_settings() -> dict[str, Any]:
54 | """Load active profiles and merge them."""
55 | logger.info("Starting application with profiles=%s", active_profiles)
56 | loaded_profiles = [
57 | load_settings_from_profile(profile) for profile in active_profiles
58 | ]
59 | merged: dict[str, Any] = merge_settings(loaded_profiles)
60 | return merged
61 |
62 |
63 | def get_active_settings() -> List[dict[str, Any]]:
64 | """Load active profiles and merge them."""
65 | loaded_profiles = [
66 | {profile: load_settings_from_profile(profile)} for profile in active_profiles
67 | ]
68 |
69 | return loaded_profiles
70 |
71 |
72 | def save_active_settings(profile: str, config: dict[str, Any]):
73 |
74 | if profile == "default":
75 | profile_file_name = "settings.yaml"
76 | else:
77 | profile_file_name = f"settings-{profile}.yaml"
78 |
79 | path = Path(_settings_folder) / profile_file_name
80 | update_yaml_config_file(path, config)
81 |
--------------------------------------------------------------------------------
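
Profile merging in miniature: later profiles deep-override earlier ones, mirroring what `load_active_settings()` does when `LE_PROFILES="local"` layers `settings-local.yaml` on top of `settings.yaml`:

```python
from langchain_emoji.settings.settings_loader import merge_settings

merged = merge_settings(
    [
        {"server": {"port": 8003, "env_name": "prod"}},  # settings.yaml
        {"server": {"env_name": "local"}},               # settings-local.yaml
    ]
)
print(merged)  # {'server': {'port': 8003, 'env_name': 'local'}}
```
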
/langchain_emoji/settings/yaml.py:
--------------------------------------------------------------------------------
1 | import os
2 | import re
3 | import typing
4 | from typing import Any, TextIO
5 |
6 | import yaml
7 | from yaml import SafeLoader
8 |
9 | _env_replace_matcher = re.compile(r"\$\{(\w|_)+:?.*}")
10 |
11 |
12 | @typing.no_type_check  # pyyaml does not have good hints, everything is Any
13 | def load_yaml_with_envvars(
14 | stream: TextIO, environ: dict[str, Any] = os.environ
15 | ) -> dict[str, Any]:
16 | """Load yaml file with environment variable expansion.
17 |
18 | The pattern ${VAR} or ${VAR:default} will be replaced with
19 | the value of the environment variable.
20 | """
21 | loader = SafeLoader(stream)
22 |
23 | def load_env_var(_, node) -> str:
24 | """Extract the matched value, expand env variable, and replace the match."""
25 | value = str(node.value).removeprefix("${").removesuffix("}")
26 | split = value.split(":", 1)
27 | env_var = split[0]
28 | value = environ.get(env_var)
29 | default = None if len(split) == 1 else split[1]
30 | if value is None and default is None:
31 | raise ValueError(
32 | f"Environment variable {env_var} is not set and not default was provided"
33 | )
34 | return value or default
35 |
36 | loader.add_implicit_resolver("env_var_replacer", _env_replace_matcher, None)
37 | loader.add_constructor("env_var_replacer", load_env_var)
38 |
39 | try:
40 | return loader.get_single_data()
41 | finally:
42 | loader.dispose()
43 |
44 |
45 | @typing.no_type_check  # pyyaml does not have good hints, everything is Any
46 | def update_yaml_config_file(file_path: str, update_dict: dict[str, Any]):
47 | """Update YAML configuration file with given key-value pairs."""
48 | with open(file_path, "r") as file:
49 | config = yaml.safe_load(file)
50 |
51 | # Update the config dictionary with the provided key-value pairs
52 | for key, value in update_dict.items():
53 | config[key] = value
54 |
55 | # Write the updated config back to the file
56 | with open(file_path, "w") as file:
57 | yaml.safe_dump(config, file, sort_keys=False)
58 |
--------------------------------------------------------------------------------
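
`${VAR:default}` expansion as implemented above: unset variables fall back to their default, and expanded values come back as strings. A sketch:

```python
import io

from langchain_emoji.settings.yaml import load_yaml_with_envvars

cfg = load_yaml_with_envvars(io.StringIO("port: ${PORT:8003}"), environ={})
print(cfg)  # {'port': '8003'}
```
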
/langchain_emoji/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ptonlix/LangChain-Emoji/4c0a926464a7419e48df32b3c5e26794e2f9b312/langchain_emoji/utils/__init__.py
--------------------------------------------------------------------------------
/langchain_emoji/utils/_compat.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | import sys
4 |
5 | from contextlib import suppress
6 |
7 | if sys.version_info < (3, 10):
8 | # compatibility for python <3.10
9 | import importlib_metadata as metadata
10 | else:
11 | from importlib import metadata
12 |
13 | WINDOWS = sys.platform == "win32"
14 |
15 |
16 | def decode(string: bytes | str, encodings: list[str] | None = None) -> str:
17 | if not isinstance(string, bytes):
18 | return string
19 |
20 | encodings = encodings or ["utf-8", "latin1", "ascii"]
21 |
22 | for encoding in encodings:
23 | with suppress(UnicodeEncodeError, UnicodeDecodeError):
24 | return string.decode(encoding)
25 |
26 | return string.decode(encodings[0], errors="ignore")
27 |
28 |
29 | def encode(string: str, encodings: list[str] | None = None) -> bytes:
30 | if isinstance(string, bytes):
31 | return string
32 |
33 | encodings = encodings or ["utf-8", "latin1", "ascii"]
34 |
35 | for encoding in encodings:
36 | with suppress(UnicodeEncodeError, UnicodeDecodeError):
37 | return string.encode(encoding)
38 |
39 | return string.encode(encodings[0], errors="ignore")
40 |
41 |
42 | __all__ = [
43 | "WINDOWS",
44 | "decode",
45 | "encode",
46 | "metadata",
47 | ]
48 |
--------------------------------------------------------------------------------
/local_data/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ptonlix/LangChain-Emoji/4c0a926464a7419e48df32b3c5e26794e2f9b312/local_data/.gitkeep
--------------------------------------------------------------------------------
/log/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ptonlix/LangChain-Emoji/4c0a926464a7419e48df32b3c5e26794e2f9b312/log/.gitkeep
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [tool.poetry]
2 | name = "langchain-emoji"
3 | version = "0.1.0"
4 | description = ""
5 | authors = ["ptonlix <260431910@qq.com>"]
6 | readme = "README.md"
7 |
8 | [tool.poetry.dependencies]
9 | python = "~3.10"
10 | uvicorn = "^0.29.0"
11 | langchain = "^0.1.16"
12 | fastapi = "^0.110.1"
13 | injector = "^0.21.0"
14 | tiktoken = "^0.6.0"
15 | openai = "^1.23.1"
16 | zhipuai = "^2.0.1"
17 | watchfiles = "^0.21.0"
18 | tcvectordb = "^1.3.2"
19 | jsonlines = "^4.0.0"
20 | gdown = "^5.1.0"
21 | minio = "^7.2.7"
22 | langchain-chroma = "^0.1.0"
23 | streamlit = "^1.34.0"
24 | onnxruntime = "1.16.3"
25 |
26 |
27 | [build-system]
28 | requires = ["poetry-core"]
29 | build-backend = "poetry.core.masonry.api"
30 |
--------------------------------------------------------------------------------
/settings.yaml:
--------------------------------------------------------------------------------
1 | # The default configuration file.
2 | # More information about configuration can be found in the documentation: https://github.com/ptonlix/LangChain-Emoji
3 | # Syntax in `langchain_emoji/settings/settings.py`
4 | server:
5 | env_name: ${APP_ENV:prod}
6 | port: ${PORT:8003}
7 | cors:
8 | enabled: false
9 | allow_origins: ["*"]
10 | allow_methods: ["*"]
11 | allow_headers: ["*"]
12 | auth:
13 | enabled: false
14 | # python -c 'import base64; print("Basic " + base64.b64encode("secret:key".encode()).decode())'
15 | # 'secret' is the username and 'key' is the password for basic auth by default
16 | # If the auth is enabled, this value must be set in the "Authorization" header of the request.
17 | secret: "Basic c2VjcmV0OmtleQ=="
18 |
19 | llm:
20 | mode: all
21 |
22 | embedding:
23 | mode: zhipuai
24 |
25 | openai:
26 | temperature: 1
27 | modelname: "gpt-3.5-turbo-0125" #"gpt-3.5-turbo-1106"
28 | api_base: ${OPENAI_API_BASE:}
29 | api_key: ${OPENAI_API_KEY:}
30 |
31 | deepseek:
32 | temperature: 1
33 | modelname: "deepseek-chat"
34 | api_base: ${DEEPSEEK_API_BASE:}
35 | api_key: ${DEEPSEEK_API_KEY:}
36 |
37 | zhipuai:
38 | temperature: 0.95
39 | top_p: 0.6
40 | modelname: "glm-3-turbo"
41 | api_key: ${ZHIPUAI_API_KEY:}
42 |
43 | langsmith:
44 | trace_version_v2: true
45 | langchain_project: langchain-emoji
46 | api_key: ${LANGCHAIN_API_KEY:}
47 |
48 | vectorstore:
49 | database: chromadb
50 | tcvectordb:
51 | url: ${TCVERCTORDB_API_HOST:}
52 | username: root
53 | api_key: ${TCVERCTORDB_API_KEY:}
54 | collection_name: EmojiCollection
55 | database_name: DeepReadDatabase
56 | chromadb:
57 | persist_dir: local_data
58 | collection_name: EmojiCollection
59 |
60 | dataset:
61 | name: emo-visual-data
62 | google_driver_id: 1r3uO0wvgQ791M_6iIyBODo_8GekBjPMf
63 | mode: local
64 |
65 | data:
66 | local_data_folder: local_data
67 |
68 | minio:
69 | host: ${MINIO_HOST:}
70 | bucket_name: emoji
71 | access_key: ${MINIO_ACCESS_KEY:}
72 | secret_key: ${MINIO_SECRET_KEY:}
73 |
--------------------------------------------------------------------------------
/tests/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ptonlix/LangChain-Emoji/4c0a926464a7419e48df32b3c5e26794e2f9b312/tests/__init__.py
--------------------------------------------------------------------------------
/tools/datainit.py:
--------------------------------------------------------------------------------
1 | import gdown
2 | import zipfile
3 | import os
4 | import sys
5 | import argparse
6 | from tqdm import tqdm
7 | from pathlib import Path
8 | import hashlib
9 | from minio import Minio
10 | from minio.error import MinioException, S3Error
11 | from langchain_core.vectorstores import VectorStore
12 |
13 | import concurrent.futures
14 | import threading
15 | import jsonlines
16 | from queue import Queue
17 |
18 | import logging
19 |
20 | logging.getLogger("httpx").setLevel(logging.WARNING)
21 | logger = logging.getLogger(__name__)
22 | logger.setLevel(logging.INFO)
23 |
24 | # Download the dataset from Baidu Netdisk and extract it:
25 | # https://pan.baidu.com/s/11iwqoxLtjV-DOQli81vZ6Q?pwd=tab4
26 | # Download into local_data
27 |
28 |
29 | # Or download from Google Drive (optional)
30 | # Define the Google Drive file ID
31 | def download_and_extract_data(file_id: str, output_dir: Path) -> bool:
32 | """
33 | Download a file from Google Drive and extract it to the specified directory.
34 |
35 | Parameters:
36 | - file_id (str): Google Drive file ID.
37 | - output_dir (str): Directory to extract the downloaded file.
38 |
39 | Returns:
40 | - bool: True if successful, False otherwise.
41 | """
42 | try:
43 | # Define Google Drive download URL
44 | url = f"https://drive.google.com/uc?id={file_id}"
45 |
46 | # Define the output file path
47 | output_path = str(output_dir / "emo-visual-data.zip")
48 |
49 | # Download the file
50 | gdown.download(url, output_path, quiet=False)
51 |
52 | # Create the extraction directory if it doesn't exist
53 | os.makedirs(output_dir, exist_ok=True)
54 |
55 | # Extract the downloaded file
56 | with zipfile.ZipFile(output_path, "r") as zip_ref:
57 | zip_ref.extractall(output_dir)
58 |
59 | # Clean up: remove the downloaded zip file
60 | os.remove(output_path)
61 |
62 | logger.info("File downloaded and extracted successfully.")
63 | return True
64 |
65 | except Exception as e:
66 | logger.exception(f"Error: {e}")
67 | return False
68 |
69 |
70 | def calculate_md5(file_path: str) -> str:
71 | """
72 | Calculate the MD5 hash of a file.
73 |
74 | Parameters:
75 | - file_path (str): Path to the file.
76 |
77 | Returns:
78 | - str: MD5 hash of the file.
79 | """
80 | md5_hash = hashlib.md5()
81 | with open(file_path, "rb") as f:
82 | # Read the file in binary chunks and update the MD5 hash
83 | for chunk in iter(lambda: f.read(4096), b""):
84 | md5_hash.update(chunk)
85 | return md5_hash.hexdigest()
86 |
87 |
88 | def create_minio_bucket(minio_client: Minio, bucket_name: str) -> bool:
89 | """
90 | Create a MinIO bucket if it doesn't already exist.
91 |
92 | Parameters:
93 | - minio_client (Minio): MinIO client object.
94 | - bucket_name (str): Name of the MinIO bucket to create.
95 |
96 | Returns:
97 | - bool: True if successful or the bucket already exists, False otherwise.
98 | """
99 | try:
100 | # Check if the bucket already exists
101 | if minio_client.bucket_exists(bucket_name):
102 | logger.info(f"Bucket '{bucket_name}' already exists.")
103 | return True
104 | else:
105 | # Create the bucket
106 | minio_client.make_bucket(bucket_name)
107 | logger.info(f"Bucket '{bucket_name}' created successfully.")
108 | return True
109 |
110 | except MinioException as err:
111 | logger.exception(f"MinIO error: {err}")
112 | return False
113 |
114 |
115 | def upload_file(
116 | minio_client: Minio,
117 | bucket_name: str,
118 | file_path: str,
119 | object_name: str,
120 | file_md5: str,
121 | failed_files: list,
122 | progress_queue: Queue,
123 | ) -> bool:
124 | """
125 | Upload a single file to MinIO.
126 |
127 | Parameters:
128 | - minio_client (Minio): MinIO client object.
129 | - bucket_name (str): Name of the MinIO bucket.
130 | - file_path (str): Path to the file.
131 | - object_name (str): Object name in MinIO.
132 | - file_md5 (str): MD5 hash of the file.
133 | - failed_files (list): List to store failed file names.
134 | - progress_queue (Queue): Queue used to report per-file
135 | progress to the tracking thread.
136 |
137 | Returns:
138 | - bool: True if successful, False otherwise.
139 | """
140 | try:
141 | try:
142 | # Check whether an identical object (same MD5) already exists in MinIO.
143 | object_stat = minio_client.stat_object(bucket_name, object_name)
144 | # If the object exists with an equal MD5, skip it; the ETag looks like '"c204c0f5ab34d3caa8909efd81b24c47"'
145 | if object_stat.metadata["ETag"][1:-1] == file_md5:
146 | logger.debug(
147 | f"Skipping {file_path}: File already exists in MinIO with same MD5."
148 | )
149 | progress_queue.put(1)
150 | return True
151 | except S3Error as e:
152 | # If the object does not exist (or the MD5 differs), proceed with the upload.
153 | if e.code != "NoSuchKey":
154 | logger.error(f"Error checking object: {e}")
155 | return False
156 |
157 | # Upload the file to MinIO.
158 | minio_client.fput_object(
159 | bucket_name,
160 | object_name,
161 | file_path,
162 | )
163 | logger.debug(f"Uploaded {file_path} to MinIO.")
164 | progress_queue.put(1)
165 | return True
166 |
167 | except MinioException as e:
168 | logger.error(f"Error uploading {file_path} to MinIO: {e}")
169 | failed_files.append(file_path)
170 | progress_queue.put(1)
171 | return False
172 |
173 |
174 | def track_progress(total: int, progress_queue: Queue):
175 | # Create the overall progress bar.
176 | with tqdm(
177 | total=total,
178 | leave=True,
179 | ncols=100,
180 | file=sys.stdout,
181 | desc=f"Total Files {total}",
182 | unit="files",
183 | ) as total_pbar:
184 |
185 | while True:
186 | progress_queue.get()
187 | total_pbar.update(1)
188 | if total_pbar.n >= total_pbar.total:
189 | break
190 |
191 |
192 | # Upload to a remote MinIO instance (optional)
193 | def upload_to_minio(minio_client: Minio, source_dir: Path, bucket_name: str) -> bool:
194 | """
195 | Upload files from a local directory to a MinIO bucket.
196 |
197 | Parameters:
198 | - minio_client (Minio): MinIO client object.
199 | - source_dir (str): Directory containing the files to upload.
200 | - bucket_name (str): Name of the MinIO bucket.
201 |
202 | Returns:
203 | - bool: True if successful, False otherwise.
204 | """
205 |
206 | # Path of the file that records names of files that failed to upload
207 | failed_files_path = "failed_files.txt"
208 | # List of files that failed to upload
209 | failed_files = []
210 |
211 | if not create_minio_bucket(minio_client, bucket_name):
212 | return False
213 |
214 | try:
215 |
216 | # Upload files concurrently with a thread pool
217 | with concurrent.futures.ThreadPoolExecutor() as executor:
218 | futures = []
219 | progress_queue = Queue()
220 | total_files = sum(len(files) for _, _, files in os.walk(source_dir))
221 | # Walk every file in the source directory.
222 | for root, dirs, files in os.walk(source_dir):
223 | for file in files:
224 | file_path = os.path.join(root, file)
225 | # Compute the object name (path relative to source_dir).
226 | object_name = os.path.relpath(file_path, source_dir)
227 | # Compute the file's MD5 hash.
228 | file_md5 = calculate_md5(file_path)
229 | # Submit the upload task to the thread pool.
230 | future = executor.submit(
231 | upload_file,
232 | minio_client,
233 | bucket_name,
234 | file_path,
235 | object_name,
236 | file_md5,
237 | failed_files,
238 | progress_queue,
239 | )
240 | futures.append(future)
241 |
242 | # Start the progress tracking thread
243 | progress_thread = threading.Thread(
244 | target=track_progress, args=(len(futures), progress_queue), daemon=True
245 | )
246 | progress_thread.start()
247 |
248 | # Wait for all upload tasks to complete.
249 | for future in concurrent.futures.as_completed(futures):
250 | if not future.result():
251 | logger.error(f"result error:{future.result()}")
252 |
253 | logger.info(f"All {total_files} files uploaded to MinIO.")
254 | # Report the number of successful and failed uploads
255 | success_count = total_files - len(failed_files)
256 | logger.info(f"Total files uploaded successfully: {success_count}")
257 | logger.info(f"Total files failed to upload: {len(failed_files)}")
258 |
259 | # Persist the names of files that failed to upload
260 | if failed_files:
261 | with open(failed_files_path, "w") as f:
262 | f.write("\n".join(failed_files))
263 |
264 | progress_thread.join()
265 |
266 | return True
267 |
268 | except Exception as err:
269 | logger.exception(f"An error occurred: {err}")
270 | return False
271 |
272 |
273 | # Load records into the vector database
274 |
275 |
276 | def upload_file_vectordb(
277 | client: VectorStore,
278 | data: dict,
279 | failed_files: list,
280 | progress_queue: Queue,
281 | ) -> bool:
282 | try:
283 | filename = data["filename"]
284 | content = data["content"]
285 |
286 | if isinstance(client, VectorStore):
287 | metadata = {
288 | "filename": filename,
289 | }
290 |
291 | result = client.add_original_texts_with_filename(
292 | filename=filename, texts=[content], metadatas=[metadata]
293 | )
294 |
295 | progress_queue.put(1)
296 | return bool(result)
297 |
298 | except Exception as e:
299 | logger.error(f"Error uploading {data.get('filename')} to VectorDB: {e}")
300 | failed_files.append(data.get("filename", "<unknown>"))
301 | progress_queue.put(1)
302 |
303 |
304 | def upload_vectordb(client: VectorStore, dataset_file: str) -> bool:
305 |
306 | # Path of the file that records names of entries that failed to upload
307 | failed_files_path = "vector_failed_files.txt"
308 | # List of entries that failed to upload
309 | failed_files = []
310 |
311 | try:
312 | with jsonlines.open(dataset_file) as reader:
313 |
314 | # Upload entries concurrently with a thread pool
315 | with concurrent.futures.ThreadPoolExecutor() as executor:
316 | futures = []
317 | progress_queue = Queue()
318 | for json_line in reader:
319 | future = executor.submit(
320 | upload_file_vectordb,
321 | client,
322 | json_line,
323 | failed_files,
324 | progress_queue,
325 | )
326 | futures.append(future)
327 |
328 | # Start the progress tracking thread
329 | progress_thread = threading.Thread(
330 | target=track_progress,
331 | args=(len(futures), progress_queue),
332 | daemon=True,
333 | )
334 | progress_thread.start()
335 |
336 | # Wait for all upload tasks to complete.
337 | for future in concurrent.futures.as_completed(futures):
338 | if not future.result():
339 | # logger.error(f"result error:{future.result()}")
340 | ...
341 |
342 | logger.info(f"All {len(futures)} files uploaded to VectorDB.")
343 | # Report the number of successful and failed uploads
344 | success_count = len(futures) - len(failed_files)
345 | logger.info(f"Total files uploaded successfully: {success_count}")
346 | logger.info(f"Total files failed to upload: {len(failed_files)}")
347 |
348 |                 # Persist the names of the records that failed to upload
349 | if failed_files:
350 | with open(failed_files_path, "w") as f:
351 | f.write("\n".join(failed_files))
352 |
353 | progress_thread.join()
354 |
355 | return True
356 |
357 | except Exception as err:
358 | logger.exception(f"An error occurred: {err}")
359 | return False
360 |
361 |
362 | if __name__ == "__main__":
363 |
364 | parser = argparse.ArgumentParser(description="Emoji data initialization tool")
365 | parser.add_argument(
366 | "--download", action="store_true", help="Download and extract emoji data"
367 | )
368 | parser.add_argument("--upload", action="store_true", help="Upload files to MinIO")
369 | parser.add_argument(
370 |         "--vectordb", action="store_true", help="Load files into the vector database"
371 | )
372 |
373 | args = parser.parse_args()
374 |
375 |     # Make sure at least one optional argument was provided
376 | if not (args.download or args.upload or args.vectordb):
377 |         print(
378 |             "Hint: none of the optional arguments '--download', '--upload', '--vectordb' was provided; please specify at least one operation."
379 |         )
380 | parser.print_help()
381 | exit(1)
382 |
383 | if args.download:
384 |
385 | from langchain_emoji.paths import local_data_path
386 | from langchain_emoji.settings.settings import settings
387 |
388 | # Define Google Drive file ID
389 | file_id = settings().dataset.google_driver_id
390 |
391 | # Define the directory to extract the file
392 | extract_dir = local_data_path
393 |
394 | # Download and extract the file
395 | success = download_and_extract_data(file_id, extract_dir)
396 |
397 | if not success:
398 |             print("downloading and extracting emoji data failed, exit!")
399 | exit(1)
400 |
401 | if args.upload:
402 |
403 | from langchain_emoji.settings.settings import settings
404 | from langchain_emoji.paths import local_data_path
405 |
406 | # MinIO configuration
407 | minio_endpoint = settings().minio.host
408 | minio_access_key = settings().minio.access_key
409 | minio_secret_key = settings().minio.secret_key
410 |         secure = False  # Set to True if the MinIO server uses SSL/TLS
411 |
412 |         # Initialize MinIO client
413 | minio_client = Minio(
414 | minio_endpoint,
415 | access_key=minio_access_key,
416 | secret_key=minio_secret_key,
417 | secure=secure,
418 | )
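419 |         # Note: the minio client expects a bare "host:port" endpoint without a
420 |         #     URL scheme; HTTP vs. HTTPS is controlled by the `secure` flag above.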
419 |
420 | dataset_name = settings().dataset.name
421 |         # Source directory containing the files to upload
422 | source_dir = local_data_path / dataset_name / "emo"
423 | if not (os.path.exists(source_dir) and os.path.isdir(source_dir)):
424 |             print("emoji dataset does not exist, exit!")
425 | exit(1)
426 |
427 |         # Name of the MinIO bucket
428 | bucket_name = "emoji"
429 |
430 |         # Upload files to MinIO
431 | success = upload_to_minio(minio_client, source_dir, bucket_name)
432 |
433 | if not success:
434 |             print("upload to MinIO failed, exit!")
435 | exit(1)
436 |
437 | if args.vectordb:
438 |
439 | from langchain_emoji.paths import local_data_path
440 | from langchain_emoji.settings.settings import settings
441 | from langchain_emoji.components.vector_store import VectorStoreComponent
442 | from langchain_emoji.components.embedding.embedding_component import (
443 | EmbeddingComponent,
444 | )
445 |
446 | embed = EmbeddingComponent(settings())
447 | vsc = VectorStoreComponent(embed, settings())
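448 |         # The embedding and vector store components are wired by hand here;
449 |         #     the server itself presumably does this via langchain_emoji/di.py.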
448 |
449 | dataset_name = settings().dataset.name
450 | dataset_file = local_data_path / dataset_name / "data.jsonl"
451 | if not (os.path.exists(dataset_file) and os.path.isfile(dataset_file)):
452 |             print("emoji data.jsonl does not exist, exit!")
453 | exit(1)
454 |
455 |         # Load the dataset into the vector database
456 | success = upload_vectordb(vsc.vector_store, dataset_file)
457 |
458 | if not success:
459 |             print("upload to vectordb failed, exit!")
460 | exit(1)
461 |
463 |
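462 | # Usage sketch (assumed invocation; run from the repository root with
463 | #     settings.yaml configured for the dataset, MinIO, and the vector store):
464 | #
465 | #     python tools/datainit.py --download --upload --vectordb
466 | #
467 | # Each stage can also be run on its own once its prerequisites exist
468 | #     (e.g. --vectordb alone after the dataset has already been downloaded).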
--------------------------------------------------------------------------------
/tools/json2jsonl.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import json
3 | import jsonlines
4 | import sys
5 |
6 |
7 | def convert_json_to_jsonl(json_file, jsonl_file):
8 |     with open(json_file, "r", encoding="utf-8") as file:
9 | json_data = json.load(file)
10 |
11 | with jsonlines.open(jsonl_file, "w") as writer:
12 | for item in json_data:
13 | writer.write(item)
14 |
15 |
16 | def main():
17 | parser = argparse.ArgumentParser(description="Convert JSON to JSONL")
18 | parser.add_argument("input_json", help="Input JSON file")
19 | parser.add_argument("output_jsonl", help="Output JSONL file")
20 | args = parser.parse_args()
21 |
22 |     try:
23 |         convert_json_to_jsonl(args.input_json, args.output_jsonl)
24 |     except FileNotFoundError:
25 |         print("Error: Input JSON file not found.", file=sys.stderr)
26 |         sys.exit(1)
27 |     except json.JSONDecodeError as err:
28 |         print(f"Error: Input file is not valid JSON: {err}", file=sys.stderr)
29 |         sys.exit(1)
30 |
31 |
32 | if __name__ == "__main__":
33 |     main()
34 |
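35 | # Usage sketch (assumed invocation):
36 | #     python tools/json2jsonl.py data.json data.jsonl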
--------------------------------------------------------------------------------