├── .gitignore
├── LICENSE
├── README.md
├── docs
│   ├── config.md
│   ├── description.md
│   ├── pic
│   │   ├── emoji_logo.jpg
│   │   ├── example_video.gif
│   │   ├── flowstruct.drawio
│   │   ├── flowstruct.png
│   │   ├── introduce.jpg
│   │   ├── logo.jpg
│   │   ├── wx.png
│   │   └── wxemoji.png
│   └── systemd.md
├── frontend
│   └── emoji.py
├── langchain_emoji
│   ├── __init__.py
│   ├── __main__.py
│   ├── __version__.py
│   ├── components
│   │   ├── embedding
│   │   │   ├── __init__.py
│   │   │   ├── custom
│   │   │   │   └── zhipuai
│   │   │   │       ├── __init__.py
│   │   │   │       └── zhipuai_custom.py
│   │   │   └── embedding_component.py
│   │   ├── llm
│   │   │   ├── __init__.py
│   │   │   ├── custom
│   │   │   │   ├── .gitkeep
│   │   │   │   └── zhipuai
│   │   │   │       ├── __init__.py
│   │   │   │       ├── zhipuai_custom.py
│   │   │   │       └── zhipuai_info.py
│   │   │   └── llm_component.py
│   │   ├── minio
│   │   │   ├── __init__.py
│   │   │   └── minio_component.py
│   │   ├── trace
│   │   │   ├── __init__.py
│   │   │   └── trace_component.py
│   │   └── vector_store
│   │       ├── __init__.py
│   │       ├── chroma
│   │       │   ├── __init__.py
│   │       │   └── chroma.py
│   │       ├── tencent
│   │       │   ├── __init__.py
│   │       │   └── tencent.py
│   │       └── vector_store_component.py
│   ├── constants.py
│   ├── di.py
│   ├── launcher.py
│   ├── main.py
│   ├── paths.py
│   ├── server
│   │   ├── config
│   │   │   ├── __init__.py
│   │   │   └── config_router.py
│   │   ├── emoji
│   │   │   ├── __init__.py
│   │   │   ├── emoji_prompt.py
│   │   │   ├── emoji_router.py
│   │   │   └── emoji_service.py
│   │   ├── health
│   │   │   ├── __init__.py
│   │   │   └── health_router.py
│   │   ├── trace
│   │   │   ├── __init__.py
│   │   │   └── trace_router.py
│   │   ├── utils
│   │   │   ├── __init__.py
│   │   │   ├── auth.py
│   │   │   └── model.py
│   │   └── vector_store
│   │       ├── __init__.py
│   │       ├── vector_store_router.py
│   │       └── vector_store_server.py
│   ├── settings
│   │   ├── __init__.py
│   │   ├── settings.py
│   │   ├── settings_loader.py
│   │   └── yaml.py
│   └── utils
│       ├── __init__.py
│       └── _compat.py
├── local_data
│   └── .gitkeep
├── log
│   └── .gitkeep
├── poetry.lock
├── pyproject.toml
├── settings.yaml
├── tests
│   └── __init__.py
└── tools
    ├── datainit.py
    └── json2jsonl.py
/.gitignore:
--------------------------------------------------------------------------------
1 | .vs/
2 | .vscode/
3 | .idea/
4 | # Byte-compiled / optimized / DLL files
5 | __pycache__/
6 | *.py[cod]
7 | *$py.class
8 |
9 | # C extensions
10 | *.so
11 |
12 | # Distribution / packaging
13 | .Python
14 | build/
15 | develop-eggs/
16 | dist/
17 | downloads/
18 | eggs/
19 | .eggs/
20 | lib/
21 | lib64/
22 | parts/
23 | sdist/
24 | var/
25 | wheels/
26 | pip-wheel-metadata/
27 | share/python-wheels/
28 | *.egg-info/
29 | .installed.cfg
30 | *.egg
31 | MANIFEST
32 |
33 | # PyInstaller
34 | # Usually these files are written by a python script from a template
35 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
36 | *.manifest 37 | *.spec 38 | 39 | # Installer logs 40 | pip-log.txt 41 | pip-delete-this-directory.txt 42 | 43 | # Unit test / coverage reports 44 | htmlcov/ 45 | .tox/ 46 | .nox/ 47 | .coverage 48 | .coverage.* 49 | .cache 50 | nosetests.xml 51 | coverage.xml 52 | *.cover 53 | *.py,cover 54 | .hypothesis/ 55 | .pytest_cache/ 56 | 57 | # Translations 58 | *.mo 59 | *.pot 60 | 61 | # Django stuff: 62 | *.log 63 | local_settings.py 64 | db.sqlite3 65 | db.sqlite3-journal 66 | 67 | # Flask stuff: 68 | instance/ 69 | .webassets-cache 70 | 71 | # Scrapy stuff: 72 | .scrapy 73 | 74 | # Sphinx documentation 75 | docs/_build/ 76 | docs/docs/_build/ 77 | 78 | # PyBuilder 79 | target/ 80 | 81 | # Jupyter Notebook 82 | .ipynb_checkpoints 83 | notebooks/ 84 | 85 | # IPython 86 | profile_default/ 87 | ipython_config.py 88 | 89 | # pyenv 90 | .python-version 91 | 92 | # pipenv 93 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 94 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 95 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 96 | # install all needed dependencies. 97 | #Pipfile.lock 98 | 99 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 100 | __pypackages__/ 101 | 102 | # Celery stuff 103 | celerybeat-schedule 104 | celerybeat.pid 105 | 106 | # SageMath parsed files 107 | *.sage.py 108 | 109 | # Environments 110 | .env 111 | .envrc 112 | .venv 113 | .venvs 114 | env/ 115 | venv/ 116 | ENV/ 117 | env.bak/ 118 | venv.bak/ 119 | 120 | # Spyder project settings 121 | .spyderproject 122 | .spyproject 123 | 124 | # Rope project settings 125 | .ropeproject 126 | 127 | # mkdocs documentation 128 | /site 129 | 130 | # mypy 131 | .mypy_cache/ 132 | .dmypy.json 133 | dmypy.json 134 | 135 | # Pyre type checker 136 | .pyre/ 137 | 138 | # macOS display setting files 139 | .DS_Store 140 | 141 | # Wandb directory 142 | wandb/ 143 | 144 | # asdf tool versions 145 | .tool-versions 146 | /.ruff_cache/ 147 | 148 | *.pkl 149 | *.bin 150 | 151 | local_data/emo-visual-data.zip 152 | local_data/emo-visual-data 153 | local_data/chromadb 154 | 155 | settings-pro.yaml -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 
25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. 
If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. 
Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # 🥳 LangChain-Emoji 2 | 3 | 简体中文 | [English](README-en.md) 4 | 5 |

6 |

7 | 8 |

9 |

10 |
11 | An open-source LangChain-based emoji meme-battle agent
12 |

13 |

14 |

15 | Python 16 | LangChain 17 | tcvectordb 18 | license 19 |

20 |
21 | > The emoji dataset used by this project comes from the Zhipu AI team; see the data source and introduction at
22 | > https://github.com/LLM-Red-Team/emo-visual-data
23 | > Thanks for open-sourcing it. Let's have fun with large models together 🎉🌟🌟
24 |
25 | 💡💡💡 **Update: a WeChat client experience is now available!**
26 |
27 | #### Scan the QR code below to add us on WeChat and start a meme battle
28 |
29 |

30 |

31 |
32 |

33 |
34 | ## 🚀 Quick Install
35 |
36 | ### 1. Set up the Python environment
37 |
38 | - Install miniconda
39 |
40 | ```shell
41 | mkdir ~/miniconda3
42 | wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh -O ~/miniconda3/miniconda.sh
43 | bash ~/miniconda3/miniconda.sh -b -u -p ~/miniconda3
44 | rm -rf ~/miniconda3/miniconda.sh
45 | ~/miniconda3/bin/conda init bash
46 | ```
47 |
48 | - Create a virtual environment
49 |
50 | ```shell
51 | # Create the environment
52 | conda create -n LangChain-Emoji python==3.10.11
53 | ```
54 |
55 | - Install poetry
56 |
57 | ```shell
58 | # Install
59 | curl -sSL https://install.python-poetry.org | python3 -
60 | ```
61 |
62 | ### 2. Run LangChain-Emoji
63 |
64 | - Install dependencies
65 |
66 | ```shell
67 | # Clone the project code
68 | git clone https://github.com/ptonlix/LangChain-Emoji.git
69 | conda activate LangChain-Emoji # Activate the environment
70 | cd LangChain-Emoji # Enter the project directory
71 | poetry install # Install dependencies
72 | ```
73 |
74 | - Edit the configuration file
75 |
76 | [OpenAI docs](https://platform.openai.com/docs/introduction)
77 | [ZhipuAI docs](https://open.bigmodel.cn/dev/howuse/introduction)
78 | [LangChain API](https://smith.langchain.com)
79 |
80 | ```shell
81 | # settings.yaml
82 |
83 | Set the following variables in the configuration file or via environment variables
84 |
85 | # OpenAI LLM API
86 | OPENAI_API_BASE
87 | OPENAI_API_KEY
88 |
89 | # ZhipuAI API
90 | ZHIPUAI_API_KEY
91 |
92 | # LangChain tracing API
93 | LANGCHAIN_API_KEY
94 |
95 | # Vector database, chromadb by default
96 | embedding # embedding model config, zhipuai by default
97 |
98 | # Tencent Cloud vector database config (optional)
99 | TCVERCTORDB_API_HOST
100 | TCVERCTORDB_API_KEY
101 |
102 | # Minio cloud-storage config (optional)
103 | MINIO_HOST
104 | MINIO_ACCESS_KEY
105 | MINIO_SECRET_KEY
106 |
107 | ```
108 |
109 | See the configuration guide for details: [LangChain-Emoji configuration](./docs/config.md)
110 |
111 | - Data initialization
112 |
113 | These steps are done with the `tools/datainit.py` data-initialization tool
114 |
115 | I. Deploy from a local data file
116 |
117 | [➡️ Baidu Netdisk download](https://pan.baidu.com/s/11iwqoxLtjV-DOQli81vZ6Q?pwd=tab4)
118 |
119 | ```
120 | Download the data from Baidu and extract it
121 | URL: https://pan.baidu.com/s/11iwqoxLtjV-DOQli81vZ6Q?pwd=tab4
122 | Download into local_data and unzip
123 | ```
124 |
125 | [➡️ Google Drive download](https://drive.google.com/file/d/1r3uO0wvgQ791M_6iIyBODo_8GekBjPMf/view)
126 |
127 | ```
128 | cd tools
129 | python datainit.py --download
130 | # Wait for the data package to finish downloading and unzipping
131 | ```
132 |
133 | Ⅱ. Deploy with Minio cloud storage (optional)
134 |
135 | Complete Step I to download the data into the local_data directory and unzip it
136 |
137 | In the `settings.yaml` configuration file, fill in the minio settings:
138 | `MINIO_HOST`
139 | `MINIO_ACCESS_KEY`
140 | `MINIO_SECRET_KEY`
141 | Set these three parameters
142 |
143 | ```
144 | cd tools
145 | python datainit.py --upload
146 | # Wait for the data to finish uploading to minio
147 | ```
148 |
149 | Ⅲ. Sync metadata to the vector database (ChromaDB by default)
150 |
151 | ```
152 | cd tools
153 | python datainit.py --vectordb
154 | # Wait for the data to finish uploading to the vector database
155 | ```
156 |
157 | > **Tencent Cloud vector database (optional)**
158 | > In the `settings.yaml` configuration file, fill in the vector-database settings:
159 | > `TCVERCTORDB_API_HOST`
160 | > `TCVERCTORDB_API_KEY`
161 | > Set these two parameters
162 | > and set `vectorstore` `database` to `tcvectordb`
163 |
164 | - Start the project
165 |
166 | ```shell
167 | # Start the project
168 | python -m langchain_emoji
169 |
170 | # Check the API
171 | Visit http://localhost:8003/docs for API information
172 | ```
173 |
174 | - Start the Web Demo
175 |
176 | ```shell
177 | # Enter the frontend directory
178 | cd frontend
179 | # Start
180 | streamlit run emoji.py
181 | ```
182 |
183 | ## 💡 Demo
184 |
185 | [![IMAGE ALT TEXT](./docs/pic/example_video.gif)](./docs/pic/example_video.gif)
186 |
187 | ### WeChat client
188 |
189 | A WeChat client experience is now available!
190 |
191 | ## 📖 Project Introduction
192 |
193 | ### 1. Flow architecture
194 |
195 |

196 |
197 | LangChain-Emoji flow architecture diagram
198 |

199 |
200 | Core flow:
201 |
202 | 1. Initialize the project data: download the emoji dataset and sync it to the vector database, cloud storage, etc.
203 | 2. The frontend sends a Prompt, and a Retriever recalls a list of emoji entries (filename, file description), 4 by default
204 | 3. A large language model picks, from that candidate list, the emoji that best matches the input Prompt
205 | 4. The emoji's filename is used to fetch the image data from the data store, which is returned to the frontend for display (a minimal code sketch of this flow appears at the end of this README)
206 |
207 | ### 2. Directory structure
208 |
209 | ```
210 | ├── docs # documentation
211 | ├── local_data # dataset directory
212 | ├── langchain_emoji
213 | │   ├── components # custom components
214 | │   ├── server # API services
215 | │   ├── settings # configuration services
216 | │   ├── utils
217 | │   ├── constants.py
218 | │   ├── di.py
219 | │   ├── launcher.py
220 | │   ├── main.py
221 | │   ├── paths.py
222 | │   ├── __init__.py
223 | │   ├── __main__.py # entry point
224 | │   └── __version__.py
225 | ├── log # log directory
226 | ```
227 |
228 | ### 3. Features
229 |
230 | - Supports the `openai`, `zhipuai`, and `deepseek` LLMs
231 | - Supports the local `chroma` vector database and the Tencent Cloud vector database
232 | - Supports dynamic reloading of the configuration file
233 | - Supports a Web Demo
234 |
235 | ## 🚩 Roadmap
236 |
237 | - [x] Build the initial LangChain-Emoji framework and flesh out the basic features
238 | - [x] Support the local Chroma vector database
239 | - [x] Build the frontend Web Demo
240 | - [x] Select the LLM
241 | - [ ] Support more models
242 | - [ ] Hosted LLM: DeepSeek ⏳ in testing
243 | - [ ] Local LLMs
244 | - [x] Integrate the WeChat client for meme battles
245 |
246 | ## 🌏 Community
247 |
248 |
249 |
250 | 🎉 Scan the QR code to contact the author if you are interested in this project
251 | 🎉 You are welcome to join the LangChain-X (DeepRead developer community) project group for discussion and exchange
252 |
253 | ## 💥 Contributing
254 |
255 | Everyone is welcome to contribute and build LangChain-Emoji together; you can help in any useful way
256 |
257 | - Report bugs
258 | - Suggest improvements
259 | - Contribute documentation
260 | - Contribute code
261 | ...
262 | 👏👏👏
263 |
264 | ---
265 |
266 | ### [About DeepRead (帝阅)](https://dread.run/#/)
267 |
268 | > "DeepRead" (帝阅)
269 | > is an AI Native product for personal knowledge management and creation
270 | > It gives each user a dedicated reading assistant to help them acquire knowledge more efficiently and unleash their creativity
271 | > so they can better accumulate, manage, and apply knowledge
272 |
273 | LangChain-Emoji is a sub-project of the DeepRead project
274 |
275 | You are welcome to try [DeepRead](https://dread.run/#/) and send us your valuable feedback
276 |
277 | ---
278 |
279 |
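To make the core flow above concrete, here is a minimal, illustrative Python sketch of steps 2 and 3 (with a note on step 4). It assumes a populated Chroma collection whose documents carry each emoji's description as `page_content` and its filename in metadata; the function and field names are illustrative assumptions, not the project's actual modules.

```python
from langchain_community.vectorstores import Chroma
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_openai import ChatOpenAI


def pick_emoji(query: str, vectordb: Chroma) -> str:
    # Step 2: recall candidate emojis (filename + description), 4 by default.
    docs = vectordb.as_retriever(search_kwargs={"k": 4}).invoke(query)
    candidates = "\n".join(
        # "filename" in metadata is an assumption about how the dataset was indexed.
        f"- filename: {d.metadata['filename']}, description: {d.page_content}"
        for d in docs
    )
    # Step 3: let the LLM pick the single best match from the candidate list.
    prompt = ChatPromptTemplate.from_template(
        "User prompt: {query}\n"
        "Candidate emojis:\n{candidates}\n"
        "Reply with the filename of the single best-matching emoji, and nothing else."
    )
    chain = prompt | ChatOpenAI() | StrOutputParser()
    # Step 4 would use the returned filename to fetch the image from
    # local_data or Minio and send it back to the frontend.
    return chain.invoke({"query": query, "candidates": candidates})
```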

280 | 281 |
282 |
283 | 帝阅DeepRead 284 |

285 |
--------------------------------------------------------------------------------
/docs/config.md:
--------------------------------------------------------------------------------
1 | ## Configuration file guide
2 |
3 | ```yaml
4 | # Server configuration
5 | server:
6 |   env_name: ${APP_ENV:prod}
7 |   port: ${PORT:8002}
8 |   cors:
9 |     enabled: false
10 |     allow_origins: ["*"]
11 |     allow_methods: ["*"]
12 |     allow_headers: ["*"]
13 |   auth:
14 |     enabled: true # whether to enable authentication
15 |     secret: "Basic c2VjcmV0OmtleQ==" # HTTP Authorization credential
16 |
17 | # LLM configuration
18 | # There are 4 options: openai, zhipuai, mock, openai+zhipuai
19 | # openai+zhipuai means both models are available, and the API request parameter decides which LLM is used
20 | llm:
21 |   mode: openai+zhipuai
22 |
23 | # Embedding model
24 | # There are 3 options: openai, zhipuai, mock
25 | # Inside mainland China, zhipuai is recommended for stability
26 | # If you use the Tencent Cloud vector database, this parameter can be ignored
27 | embedding:
28 |   mode: zhipuai
29 |
30 | # openai model parameters
31 | openai:
32 |   temperature: 1
33 |   modelname: "gpt-3.5-turbo-0125" # "gpt-3.5-turbo-1106"
34 |   api_base: ${OPENAI_API_BASE:}
35 |   api_key: ${OPENAI_API_KEY:}
36 |
37 | # zhipuai model parameters
38 | zhipuai:
39 |   temperature: 0.95
40 |   top_p: 0.6
41 |   modelname: "glm-3-turbo"
42 |   api_key: ${ZHIPUAI_API_KEY:}
43 |
44 | # LangSmith tracing parameters
45 | # See https://smith.langchain.com for details
46 | langsmith:
47 |   trace_version_v2: true
48 |   api_key: ${LANGCHAIN_API_KEY:}
49 |
50 | # Vector database parameters
51 | vectorstore:
52 |   database: tcvectordb # vector database type; for now only the Tencent Cloud vector database is covered here
53 |   tcvectordb: # see https://cloud.tencent.com/document/product/1709 for configuration details
54 |     url: ${TCVERCTORDB_API_HOST:} # Tencent Cloud API endpoint
55 |     username: root # account
56 |     api_key: ${TCVERCTORDB_API_KEY:} # Tencent Cloud vector database API key
57 |     collection_name: EmojiCollection # collection name
58 |     database_name: DeepReadDatabase # database name
59 |
60 | # Emoji dataset information
61 | dataset:
62 |   name: emo-visual-data # dataset file name
63 |   google_driver_id: 1r3uO0wvgQ791M_6iIyBODo_8GekBjPMf # Google Drive ID
64 |   mode: local # dataset loading mode; local (filesystem) and minio (cloud storage) are currently supported
65 |
66 | # Local data-storage settings
67 | data:
68 |   local_data_folder: local_data # local storage path, relative to the project root
69 |
70 | # Use Minio to store the data
71 | # https://www.minio.org.cn/docs/minio/linux/operations/installation.html
72 | minio:
73 |   host: ${MINIO_HOST:} # Minio endpoint
74 |   bucket_name: emoji # bucket name
75 |   access_key: ${MINIO_ACCESS_KEY:} # access key
76 |   secret_key: ${MINIO_SECRET_KEY:} # secret key
77 | ```
78 |
79 | ## Private configuration file
80 |
81 | Because the configuration file contains private information such as API keys, you can load an extra private configuration file without changing the default one
82 |
83 | See the `langchain_emoji/settings` code for details
84 |
85 | ```shell
86 | # 1. Set the environment variable
87 | export LE_PROFILES=pro
88 |
89 | # 2. Create the new configuration file
90 | vim settings-pro.yaml
91 |
92 | # 3. Copy the default configuration file and add API_KEY and other private values
93 |
94 | # 4. 
Start the project; the program automatically merges the two configuration files, and settings-pro.yaml takes precedence where they conflict
95 |
96 | ```
97 |
98 | ## Dynamic configuration reloading
99 |
100 | The program watches the yaml configuration file for content changes and reloads it, which makes it easy to adjust parameters in real time
101 |
102 | See the `langchain_emoji/__main__.py` code for details
103 |
--------------------------------------------------------------------------------
/docs/description.md:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ptonlix/LangChain-Emoji/4c0a926464a7419e48df32b3c5e26794e2f9b312/docs/description.md
--------------------------------------------------------------------------------
/docs/pic/emoji_logo.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ptonlix/LangChain-Emoji/4c0a926464a7419e48df32b3c5e26794e2f9b312/docs/pic/emoji_logo.jpg
--------------------------------------------------------------------------------
/docs/pic/example_video.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ptonlix/LangChain-Emoji/4c0a926464a7419e48df32b3c5e26794e2f9b312/docs/pic/example_video.gif
--------------------------------------------------------------------------------
/docs/pic/flowstruct.drawio:
--------------------------------------------------------------------------------
(draw.io XML diagram source, 163 lines; markup stripped during extraction)
--------------------------------------------------------------------------------
/docs/pic/flowstruct.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ptonlix/LangChain-Emoji/4c0a926464a7419e48df32b3c5e26794e2f9b312/docs/pic/flowstruct.png
--------------------------------------------------------------------------------
/docs/pic/introduce.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ptonlix/LangChain-Emoji/4c0a926464a7419e48df32b3c5e26794e2f9b312/docs/pic/introduce.jpg
--------------------------------------------------------------------------------
/docs/pic/logo.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ptonlix/LangChain-Emoji/4c0a926464a7419e48df32b3c5e26794e2f9b312/docs/pic/logo.jpg
--------------------------------------------------------------------------------
/docs/pic/wx.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ptonlix/LangChain-Emoji/4c0a926464a7419e48df32b3c5e26794e2f9b312/docs/pic/wx.png
--------------------------------------------------------------------------------
/docs/pic/wxemoji.png: 
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ptonlix/LangChain-Emoji/4c0a926464a7419e48df32b3c5e26794e2f9b312/docs/pic/wxemoji.png
--------------------------------------------------------------------------------
/docs/systemd.md:
--------------------------------------------------------------------------------
1 | # Linux: configure start on boot
2 |
3 | ```
4 | # /usr/lib/systemd/system/langchain-emoji.service
5 | [Unit]
6 | Description=Langchain-Emoji
7 | After=network.target
8 |
9 | [Service]
10 | Type=simple
11 | ExecStart=/root/miniconda3/envs/LangChain-Emoji/bin/python -m langchain_emoji
12 |
13 | [Install]
14 | WantedBy=multi-user.target
15 | ```
16 |
17 | ```
18 | systemctl enable langchain-emoji.service # enable start on boot
19 | systemctl start langchain-emoji
20 |
21 | # If the service fails to start, check the cause with the command below
22 | journalctl -u langchain-emoji -f
23 |
24 | ```
25 |
--------------------------------------------------------------------------------
/frontend/emoji.py:
--------------------------------------------------------------------------------
1 | import aiohttp
2 | from typing import Optional, Dict, Any
3 | from langchain_emoji.settings.settings import settings
4 | from pydantic import ValidationError
5 | import logging
6 | import asyncio
7 | import streamlit as st
8 | import uuid
9 | import base64
10 | from PIL import Image
11 | from io import BytesIO
12 |
13 | logger = logging.getLogger(__name__)
14 |
15 |
16 | class ApiRequest:
17 |     def __init__(self, loop: asyncio.AbstractEventLoop, base_url: str):
18 |         self.base_url = base_url
19 |         self.loop = loop
20 |         self.timeout = aiohttp.ClientTimeout(total=10 * 60)
21 |         self.session = None
22 |
23 |     async def initialize_session(self):
24 |         self.session = aiohttp.ClientSession(loop=self.loop, timeout=self.timeout)
25 |
26 |     async def _request(
27 |         self,
28 |         method: str,
29 |         endpoint: str,
30 |         params: Optional[Dict[str, Any]] = None,
31 |         json: Optional[Dict[str, Any]] = None,
32 |         headers: Optional[Dict[str, str]] = None,
33 |     ) -> Dict[str, Any]:
34 |         if self.session is None:
35 |             await self.initialize_session()
36 |         url = f"{self.base_url}{endpoint}"
37 |         headers = headers or {}
38 |         headers["Content-Type"] = "application/json"
39 |         try:
40 |             async with self.session.request(
41 |                 method, url, params=params, json=json, headers=headers
42 |             ) as response:
43 |                 return await response.json()
44 |         except aiohttp.ServerTimeoutError as timeout_error:
45 |             logger.warning(
46 |                 f"Timeout error in {method} request to {url}: {timeout_error}"
47 |             )
48 |             raise
49 |         except aiohttp.ClientError as e:
50 |             logger.warning(f"Error in {method} request to {url}: {e}")
51 |             raise
52 |         except Exception as e:
53 |             logger.warning(f"Unexpected error in {method} request to {url}: {e}")
54 |             raise
55 |
56 |     async def get(
57 |         self,
58 |         endpoint: str,
59 |         params: Optional[Dict[str, Any]] = None,
60 |         headers: Optional[Dict[str, str]] = None,
61 |     ) -> Dict[str, Any]:
62 |         return await self._request("GET", endpoint, params=params, headers=headers)
63 |
64 |     async def post(
65 |         self,
66 |         endpoint: str,
67 |         data: Optional[Dict[str, Any]] = None,
68 |         headers: Optional[Dict[str, str]] = None,
69 |     ) -> Dict[str, Any]:
70 |         return await self._request("POST", endpoint, json=data, headers=headers)
71 |
72 |     async def close(self):
73 |         await self.session.close()
74 |
75 |
76 | class EmojiApiHandler(ApiRequest):
77 |     def __init__(self, loop: asyncio.AbstractEventLoop):
78 |         super().__init__(
79 |             loop=loop, base_url="http://127.0.0.1:" + 
str(settings().server.port) 80 | ) 81 | self.accesstoken = settings().server.auth.secret 82 | 83 | async def emoji( 84 | self, 85 | data: Dict, 86 | endpoint: str = "/v1/emoji", 87 | headers: Optional[Dict[str, str]] = None, 88 | ) -> Dict | None: 89 | headers = headers or {} 90 | headers["Authorization"] = f"Basic {self.accesstoken}" 91 | logger.info(data) 92 | try: 93 | resp_dict = await self.post(endpoint=endpoint, data=data, headers=headers) 94 | # logger.info(resp_dict) 95 | return resp_dict 96 | except ValidationError as e: 97 | logger.warning(f"Unexpected ValidationError error in UserReg request : {e}") 98 | return None 99 | except Exception as e: 100 | logger.exception(f"Unexpected error in UserReg request : {e}") 101 | return None 102 | 103 | 104 | async def handler_progress_bar(progress_bar): 105 | for i in range(8): 106 | progress_bar.progress(10 * i) 107 | await asyncio.sleep(0.1) 108 | 109 | 110 | def decodebase64(context: str): 111 | img_data = base64.b64decode(context) # 解码时只要内容部分 112 | 113 | image = Image.open(BytesIO(img_data)) 114 | return image 115 | 116 | 117 | def fetch_emoji(prompt: str, llm: str, progress_bar): 118 | 119 | loop = asyncio.new_event_loop() 120 | asyncio.set_event_loop(loop) 121 | emoji_api = EmojiApiHandler(loop) 122 | 123 | async def emoji_request(): 124 | try: 125 | response = await emoji_api.emoji( 126 | data={"prompt": prompt, "req_id": str(uuid.uuid4()), "llm": llm} 127 | ) 128 | return response.get("data") 129 | finally: 130 | await emoji_api.close() 131 | 132 | try: 133 | emoji_task = loop.create_task(emoji_request()) 134 | progress_task = loop.create_task(handler_progress_bar(progress_bar)) 135 | result = loop.run_until_complete(asyncio.gather(emoji_task, progress_task)) 136 | return result[0] 137 | finally: 138 | loop.close() 139 | 140 | 141 | def area_image(area_placeholder, image, filename): 142 | col1, col2, col3 = area_placeholder.columns([1, 3, 1]) 143 | with col1: 144 | st.write(" ") 145 | 146 | with col2: 147 | st.image( 148 | image=image, 149 | caption=filename, 150 | use_column_width=True, 151 | ) 152 | with col3: 153 | st.write(" ") 154 | 155 | 156 | st.subheader("🥳 LangChain-Emoji") 157 | st.markdown( 158 | "基于LangChain的开源表情包斗图Agent [Github](https://github.com/ptonlix/LangChain-Emoji) [作者: Baird](https://github.com/ptonlix)" 159 | ) 160 | 161 | with st.container(): 162 | image_path = "../docs/pic/logo.jpg" 163 | 164 | with st.container(): 165 | image_placeholder = st.empty() 166 | area_placeholder = st.empty() 167 | content_placeholder = st.empty() 168 | area_image(area_placeholder, image_path, "帝阅DeepRead") 169 | 170 | with st.form("emoji_form"): 171 | st.write("Emoji Input") 172 | select_llm = st.radio( 173 | "Please select a LLM", 174 | ["ChatGPT", "ZhipuAI", "DeepSeek"], 175 | captions=["OpenAI", "智谱清言", "深度求索"], 176 | ) 177 | llm_mapping = { 178 | "ChatGPT": "openai", 179 | "ZhipuAI": "zhipuai", 180 | "DeepSeek": "deepseek", 181 | } 182 | llm = llm_mapping[select_llm] 183 | 184 | prompt = st.text_input("Enter Emoji Prompt:", value="今天很开心~") 185 | 186 | submitted = st.form_submit_button("Submit", help="点击获取最佳表情包") 187 | if submitted: 188 | bar = st.progress(0) 189 | response = fetch_emoji(prompt, llm, bar) 190 | bar.progress(100) 191 | base64_str = response["emojidetail"]["base64"] 192 | filename = response["emojiinfo"]["filename"] 193 | content = response["emojiinfo"]["content"] 194 | area_image(area_placeholder, decodebase64(base64_str), filename) 195 | content_placeholder.write(content) 196 | 
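For quick manual testing of the same `/v1/emoji` endpoint the demo calls, here is a minimal synchronous sketch. It mirrors the payload and Authorization header built by `EmojiApiHandler` above; the port, secret value, and the `requests` dependency are assumptions you would adjust to your own settings.yaml.

```python
# Hypothetical standalone test client (not part of the repo); assumes the
# `requests` package and a server already running via `python -m langchain_emoji`.
import uuid

import requests

PORT = 8002  # must match server.port in your settings.yaml
# Passed verbatim as the Authorization header, mirroring the `secret`
# value shown in docs/config.md.
SECRET = "Basic c2VjcmV0OmtleQ=="

resp = requests.post(
    f"http://127.0.0.1:{PORT}/v1/emoji",
    json={"prompt": "今天很开心~", "req_id": str(uuid.uuid4()), "llm": "openai"},
    headers={"Authorization": SECRET},
    timeout=600,
)
data = resp.json()["data"]
# The demo decodes emojidetail.base64 into an image; here we just print
# the chosen file's name and its description.
print(data["emojiinfo"]["filename"])
print(data["emojiinfo"]["content"])
```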
-------------------------------------------------------------------------------- /langchain_emoji/__init__.py: -------------------------------------------------------------------------------- 1 | """langchain_emoji.""" 2 | 3 | import logging 4 | from logging.handlers import TimedRotatingFileHandler 5 | from langchain_emoji.constants import PROJECT_ROOT_PATH 6 | 7 | # Set to 'DEBUG' to have extensive logging turned on, even for libraries 8 | ROOT_LOG_LEVEL = "INFO" 9 | 10 | PRETTY_LOG_FORMAT = "%(asctime)s.%(msecs)03d [%(levelname)-5s] [%(filename)-12s][line:%(lineno)d] - %(message)s" 11 | logging.basicConfig(level=ROOT_LOG_LEVEL, format=PRETTY_LOG_FORMAT, datefmt="%H:%M:%S") 12 | logging.captureWarnings(True) 13 | 14 | 15 | file_handler = TimedRotatingFileHandler( 16 | filename=PROJECT_ROOT_PATH / "log/app.log", 17 | when="midnight", 18 | interval=1, 19 | backupCount=7, 20 | ) 21 | file_handler.suffix = "%Y-%m-%d.log" 22 | file_handler.encoding = "utf-8" 23 | file_handler.setLevel(ROOT_LOG_LEVEL) 24 | # 创建格式化器并将其添加到处理器 25 | file_formatter = logging.Formatter(PRETTY_LOG_FORMAT) 26 | file_handler.setFormatter(file_formatter) 27 | 28 | logging.getLogger().addHandler(file_handler) 29 | -------------------------------------------------------------------------------- /langchain_emoji/__main__.py: -------------------------------------------------------------------------------- 1 | # start a fastapi server with uvicorn 2 | 3 | import glob 4 | from langchain_emoji.settings.settings import settings 5 | from langchain_emoji.constants import PROJECT_ROOT_PATH 6 | 7 | from uvicorn import Config, Server 8 | from uvicorn.supervisors.watchfilesreload import WatchFilesReload 9 | 10 | from pathlib import Path 11 | import hashlib 12 | from socket import socket 13 | from typing import Callable 14 | from watchfiles import Change, watch 15 | 16 | 17 | """ 18 | 自定义uvicorn启动类 19 | """ 20 | 21 | 22 | class FileFilter: 23 | def __init__(self, config: Config): 24 | default_includes = ["*.py"] 25 | self.includes = [ 26 | default 27 | for default in default_includes 28 | if default not in config.reload_excludes 29 | ] 30 | self.includes.extend(config.reload_includes) 31 | self.includes = list(set(self.includes)) 32 | 33 | default_excludes = [".*", ".py[cod]", ".sw.*", "~*"] 34 | self.excludes = [ 35 | default 36 | for default in default_excludes 37 | if default not in config.reload_includes 38 | ] 39 | self.exclude_dirs = [] 40 | for e in config.reload_excludes: 41 | p = Path(e) 42 | try: 43 | is_dir = p.is_dir() 44 | except OSError: # pragma: no cover 45 | # gets raised on Windows for values like "*.py" 46 | is_dir = False 47 | 48 | if is_dir: 49 | self.exclude_dirs.append(p) 50 | else: 51 | self.excludes.append(e) 52 | self.excludes = list(set(self.excludes)) 53 | 54 | def __call__(self, change: "Change", path: Path | str) -> bool: 55 | if isinstance(path, str): 56 | path = Path(path) 57 | for include_pattern in self.includes: 58 | if path.match(include_pattern): 59 | if str(path).endswith(include_pattern): 60 | return True 61 | 62 | for exclude_dir in self.exclude_dirs: 63 | if exclude_dir in path.parents: 64 | return False 65 | 66 | for exclude_pattern in self.excludes: 67 | if path.match(exclude_pattern): 68 | return False 69 | 70 | return True 71 | return False 72 | 73 | 74 | class CustomWatchFilesReload(WatchFilesReload): 75 | 76 | def __init__( 77 | self, 78 | config: Config, 79 | target: Callable[[list[socket] | None], None], 80 | sockets: list[socket], 81 | ) -> None: 82 | super().__init__(config, target, 
sockets) 83 | self.reloader_name = "WatchFiles" 84 | self.reload_dirs = [] 85 | self.ignore_paths = [] 86 | for directory in config.reload_dirs: 87 | self.reload_dirs.append(directory) 88 | 89 | self.watch_filter = FileFilter(config) 90 | self.watcher = watch( 91 | *self.reload_dirs, 92 | watch_filter=self.watch_filter, 93 | stop_event=self.should_exit, 94 | # using yield_on_timeout here mostly to make sure tests don't 95 | # hang forever, won't affect the class's behavior 96 | yield_on_timeout=True, 97 | ) 98 | 99 | self.file_hashes = {} # Store file hashes 100 | 101 | # Calculate and store hashes for initial files 102 | for directory in self.reload_dirs: 103 | for file_path in directory.rglob("*"): 104 | if file_path.is_file() and self.watch_filter(None, file_path): 105 | self.file_hashes[str(file_path)] = self.calculate_file_hash( 106 | file_path 107 | ) 108 | 109 | def should_restart(self) -> list[Path] | None: 110 | self.pause() 111 | 112 | changes = next(self.watcher) 113 | if changes: 114 | changed_paths = [] 115 | for event_type, path in changes: 116 | if event_type == Change.modified or event_type == Change.added: 117 | file_hash = self.calculate_file_hash(path) 118 | if ( 119 | path not in self.file_hashes 120 | or self.file_hashes[path] != file_hash 121 | ): 122 | changed_paths.append(Path(path)) 123 | self.file_hashes[path] = file_hash 124 | 125 | if changed_paths: 126 | return [p for p in changed_paths if self.watch_filter(None, p)] 127 | 128 | return None 129 | 130 | def calculate_file_hash(self, file_path: str) -> str: 131 | with open(file_path, "rb") as file: 132 | file_contents = file.read() 133 | return hashlib.md5(file_contents).hexdigest() 134 | 135 | 136 | non_yaml_files = [ 137 | f 138 | for f in glob.glob("**", root_dir=PROJECT_ROOT_PATH, recursive=True) 139 | if not f.lower().endswith((".yaml", ".yml")) 140 | ] 141 | try: 142 | config = Config( 143 | app="langchain_emoji.main:app", 144 | host="0.0.0.0", 145 | port=settings().server.port, 146 | reload=True, 147 | reload_dirs=str(PROJECT_ROOT_PATH), 148 | reload_excludes=non_yaml_files, 149 | reload_includes="*.yaml", 150 | log_config=None, 151 | ) 152 | 153 | server = Server(config=config) 154 | 155 | sock = config.bind_socket() 156 | CustomWatchFilesReload(config, target=server.run, sockets=[sock]).run() 157 | except KeyboardInterrupt: 158 | ... 159 | -------------------------------------------------------------------------------- /langchain_emoji/__version__.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from typing import TYPE_CHECKING 4 | 5 | from langchain_emoji.utils._compat import metadata 6 | 7 | 8 | if TYPE_CHECKING: 9 | from collections.abc import Callable 10 | 11 | 12 | # The metadata.version that we import for Python 3.7 is untyped, work around 13 | # that. 
14 | version: Callable[[str], str] = metadata.version 15 | 16 | __version__ = version("langchain_emoji") 17 | -------------------------------------------------------------------------------- /langchain_emoji/components/embedding/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ptonlix/LangChain-Emoji/4c0a926464a7419e48df32b3c5e26794e2f9b312/langchain_emoji/components/embedding/__init__.py -------------------------------------------------------------------------------- /langchain_emoji/components/embedding/custom/zhipuai/__init__.py: -------------------------------------------------------------------------------- 1 | from langchain_emoji.components.embedding.custom.zhipuai.zhipuai_custom import ( 2 | ZhipuaiTextEmbeddings, 3 | ) 4 | 5 | __all__ = ["ZhipuaiTextEmbeddings"] 6 | -------------------------------------------------------------------------------- /langchain_emoji/components/embedding/custom/zhipuai/zhipuai_custom.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, List, Optional 2 | 3 | from langchain_core.embeddings import Embeddings 4 | from langchain_core.pydantic_v1 import BaseModel, root_validator 5 | from langchain_core.utils import get_from_dict_or_env 6 | from packaging.version import parse 7 | from importlib.metadata import version 8 | import asyncio 9 | import logging 10 | 11 | logger = logging.getLogger(__name__) 12 | 13 | 14 | def is_zhipu_v2() -> bool: 15 | """Return whether zhipu API is v2 or more.""" 16 | _version = parse(version("zhipuai")) 17 | return _version.major >= 2 18 | 19 | 20 | # Official Website: https://open.bigmodel.cn/dev/api#text_embedding 21 | # An API-key is required to use this embedding model. You can get one by registering 22 | class ZhipuaiTextEmbeddings(BaseModel, Embeddings): 23 | """Zhipuai Text Embedding models.""" 24 | 25 | client: Any # ZhipuAI #: :meta private: 26 | model_name: str = "embedding-2" 27 | zhipuai_api_key: Optional[str] = None 28 | count_token: int = 0 29 | 30 | @root_validator(allow_reuse=True) 31 | def validate_environment(cls, values: Dict) -> Dict: 32 | """Validate that auth token exists in environment.""" 33 | 34 | try: 35 | from zhipuai import ZhipuAI 36 | 37 | if not is_zhipu_v2(): 38 | raise RuntimeError( 39 | "zhipuai package version is too low" 40 | "Please install it via 'pip install --upgrade zhipuai'" 41 | ) 42 | 43 | zhipuai_api_key = get_from_dict_or_env( 44 | values, "zhipuai_api_key", "ZHIPUAI_API_KEY" 45 | ) 46 | 47 | client = ZhipuAI( 48 | api_key=zhipuai_api_key, 49 | ) 50 | values["client"] = client 51 | return values 52 | except ImportError: 53 | raise RuntimeError( 54 | "Could not import zhipuai package. " 55 | "Please install it via 'pip install zhipuai'" 56 | ) 57 | 58 | def _embed(self, texts: List[str]) -> Optional[List[List[float]]]: 59 | """Internal method to call Zhipuai Embedding API and return embeddings. 60 | 61 | Args: 62 | texts: A list of texts to embed. 63 | 64 | Returns: 65 | A list of list of floats representing the embeddings, or None if an 66 | error occurs. 
67 |         """
68 |         # try:
69 |         #     return [self._get_embedding(text) for text in texts]
70 |
71 |         # except Exception as e:
72 |         #     logger.exception(e)
73 |         #     # Log the exception or handle it as needed
74 |         #     logger.info(
75 |         #         f"Exception occurred while trying to get embeddings: {str(e)}"
76 |         #     )  # noqa: T201
77 |         #     return None
78 |         return asyncio.run(self._aembed(texts))
79 |
80 |     def embed_documents(self, texts: List[str]) -> Optional[List[List[float]]]:  # type: ignore[override]
81 |         """Public method to get embeddings for a list of documents.
82 |
83 |         Args:
84 |             texts: The list of texts to embed.
85 |
86 |         Returns:
87 |             A list of embeddings, one for each text, or None if an error occurs.
88 |         """
89 |         return self._embed(texts)
90 |
91 |     def embed_query(self, text: str) -> Optional[List[float]]:  # type: ignore[override]
92 |         """Public method to get embedding for a single query text.
93 |
94 |         Args:
95 |             text: The text to embed.
96 |
97 |         Returns:
98 |             Embeddings for the text, or None if an error occurs.
99 |         """
100 |         result = self._embed([text])
101 |         return result[0] if result is not None else None
102 |
103 |     def _get_embedding(self, text: str) -> List[float]:
104 |         response = self.client.embeddings.create(model=self.model_name, input=text)
105 |         self.count_token += response.usage.total_tokens
106 |         return response.data[0].embedding
107 |
108 |     async def _aget_embedding(self, text: str) -> List[float]:
109 |         response = self.client.embeddings.create(model=self.model_name, input=text)
110 |         self.count_token += response.usage.total_tokens  # count tokens
111 |         return response.data[0].embedding
112 |
113 |     async def _aembed(self, texts: List[str]) -> Optional[List[List[float]]]:
114 |         """Internal method to call Zhipuai Embedding API and return embeddings.
115 |
116 |         Args:
117 |             texts: A list of texts to embed.
118 |
119 |         Returns:
120 |             A list of list of floats representing the embeddings, or None if an
121 |             error occurs.
122 |         """
123 |         try:
124 |             tasks = [asyncio.create_task(self._aget_embedding(text)) for text in texts]
125 |             return await asyncio.gather(*tasks)
126 |
127 |         except Exception as e:
128 |             logger.exception(e)
129 |             # Log the exception or handle it as needed
130 |             logger.info(
131 |                 f"Exception occurred while trying to get embeddings: {str(e)}"
132 |             )  # noqa: T201
133 |             return None
134 |
135 |     async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
136 |         """Asynchronous Embed search docs."""
137 |         return await self._aembed(texts)
138 |
139 |     async def aembed_query(self, text: str) -> List[float]:
140 |         """Asynchronous Embed query text."""
141 |         result = await self._aembed([text])
142 |         return result[0] if result is not None else None
143 |
144 |
145 | if __name__ == "__main__":
146 |
147 |     # async def main():
148 |     #     client = ZhipuaiTextEmbeddings(
149 |     #         zhipuai_api_key="your-zhipuai-api-key"  # key redacted
150 |     #     )
151 |
152 |     #     print(
153 |     #         await client.aembed_documents(
154 |     #             ["你好,帝阅AI搜索", "帝阅DeepRead", "帝阅No.1"]
155 |     #         )
156 |     #     )
157 |
158 |     # asyncio.run(main())
159 |
160 |     client = ZhipuaiTextEmbeddings(
161 |         zhipuai_api_key="your-zhipuai-api-key"  # key redacted; do not commit real keys
162 |     )
163 |     client.embed_documents(["你好,帝阅AI搜索", "帝阅DeepRead", "帝阅No.1"])
164 |     print(client.count_token)
165 |
--------------------------------------------------------------------------------
/langchain_emoji/components/embedding/embedding_component.py:
--------------------------------------------------------------------------------
1 | import logging
2 |
3 | from injector import inject
4 | from langchain_community.embeddings import OpenAIEmbeddings
5 | from langchain_core.embeddings import Embeddings, DeterministicFakeEmbedding
6 |
7 | from langchain_emoji.settings.settings import Settings
8 | from langchain_emoji.components.embedding.custom.zhipuai import ZhipuaiTextEmbeddings
9 |
10 |
11 | logger = logging.getLogger(__name__)
12 |
13 |
14 | class EmbeddingComponent:
15 |     @inject
16 |     def __init__(self, settings: Settings) -> None:
17 |         embedding_mode = settings.embedding.mode
18 |         logger.info("Initializing the embedding in mode=%s", embedding_mode)
19 |         match embedding_mode:
20 |             case "local":
21 |                 ...
# Todo
22 |             case "openai":
23 |                 openai_settings = settings.openai
24 |                 self._embedding = OpenAIEmbeddings(
25 |                     api_key=openai_settings.api_key,
26 |                     openai_api_base=openai_settings.api_base,
27 |                 )
28 |             case "zhipuai":
29 |                 zhipuai_settings = settings.zhipuai
30 |                 self._embedding = ZhipuaiTextEmbeddings(
31 |                     zhipuai_api_key=zhipuai_settings.api_key
32 |                 )
33 |             case "mock":
34 |                 self._embedding = DeterministicFakeEmbedding(size=1352)
35 |
36 |     @property
37 |     def embedding(self) -> Embeddings:
38 |         return self._embedding
39 |
40 |     @property
41 |     def total_tokens(self) -> int:
42 |         try:
43 |             return self._embedding.count_token  # currently only the Zhipuai embedding tracks token usage
44 |         except Exception:
45 |             return 0
46 |
--------------------------------------------------------------------------------
/langchain_emoji/components/llm/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ptonlix/LangChain-Emoji/4c0a926464a7419e48df32b3c5e26794e2f9b312/langchain_emoji/components/llm/__init__.py
--------------------------------------------------------------------------------
/langchain_emoji/components/llm/custom/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ptonlix/LangChain-Emoji/4c0a926464a7419e48df32b3c5e26794e2f9b312/langchain_emoji/components/llm/custom/.gitkeep
--------------------------------------------------------------------------------
/langchain_emoji/components/llm/custom/zhipuai/__init__.py:
--------------------------------------------------------------------------------
1 | from langchain_emoji.components.llm.custom.zhipuai.zhipuai_custom import (
2 |     ChatZhipuAI,
3 | )
4 | from langchain_emoji.components.llm.custom.zhipuai.zhipuai_info import (
5 |     ZhipuAICallbackHandler,
6 |     get_zhipuai_callback,
7 | )
8 |
9 |
10 | __all__ = ["ChatZhipuAI", "ZhipuAICallbackHandler", "get_zhipuai_callback"]
11 |
--------------------------------------------------------------------------------
/langchain_emoji/components/llm/custom/zhipuai/zhipuai_custom.py:
--------------------------------------------------------------------------------
1 | """ZHIPU AI chat models wrapper."""
2 | from __future__ import annotations
3 |
4 | import asyncio
5 | import logging
6 | from functools import partial
7 | from importlib.metadata import version
8 | from typing import (
9 |     Any,
10 |     Callable,
11 |     Dict,
12 |     Iterator,
13 |     List,
14 |     Mapping,
15 |     Optional,
16 |     Tuple,
17 |     Type,
18 |     Union,
19 | )
20 |
21 | from langchain_core.callbacks import (
22 |     AsyncCallbackManagerForLLMRun,
23 |     CallbackManagerForLLMRun,
24 | )
25 | from langchain_core.language_models.chat_models import (
26 |     BaseChatModel,
27 |     generate_from_stream,
28 | )
29 | from langchain_core.language_models.llms import create_base_retry_decorator
30 | from langchain_core.messages import (
31 |     AIMessage,
32 |     AIMessageChunk,
33 |     BaseMessage,
34 |     BaseMessageChunk,
35 |     ChatMessage,
36 |     ChatMessageChunk,
37 |     HumanMessage,
38 |     HumanMessageChunk,
39 |     SystemMessage,
40 |     SystemMessageChunk,
41 |     ToolMessage,
42 |     ToolMessageChunk,
43 | )
44 | from langchain_core.outputs import (
45 |     ChatGeneration,
46 |     ChatGenerationChunk,
47 |     ChatResult,
48 | )
49 | from langchain_core.pydantic_v1 import BaseModel, Field
50 | from packaging.version import parse
51 |
52 | logger = logging.getLogger(__name__)
53 |
54 |
55 | def is_zhipu_v2() -> bool:
56 |     """Return whether zhipu API is v2 or more."""
57 |     _version = 
parse(version("zhipuai")) 58 | return _version.major >= 2 59 | 60 | 61 | def _create_retry_decorator( 62 | llm: ChatZhipuAI, 63 | run_manager: Optional[ 64 | Union[AsyncCallbackManagerForLLMRun, CallbackManagerForLLMRun] 65 | ] = None, 66 | ) -> Callable[[Any], Any]: 67 | import zhipuai 68 | 69 | errors = [ 70 | zhipuai.ZhipuAIError, 71 | zhipuai.APIStatusError, 72 | zhipuai.APIRequestFailedError, 73 | zhipuai.APIReachLimitError, 74 | zhipuai.APIInternalError, 75 | zhipuai.APIServerFlowExceedError, 76 | zhipuai.APIResponseError, 77 | zhipuai.APIResponseValidationError, 78 | zhipuai.APITimeoutError, 79 | ] 80 | return create_base_retry_decorator( 81 | error_types=errors, max_retries=llm.max_retries, run_manager=run_manager 82 | ) 83 | 84 | 85 | def convert_message_to_dict(message: BaseMessage) -> dict: 86 | """Convert a LangChain message to a dictionary. 87 | 88 | Args: 89 | message: The LangChain message. 90 | 91 | Returns: 92 | The dictionary. 93 | """ 94 | message_dict: Dict[str, Any] 95 | if isinstance(message, ChatMessage): 96 | message_dict = {"role": message.role, "content": message.content} 97 | elif isinstance(message, HumanMessage): 98 | message_dict = {"role": "user", "content": message.content} 99 | elif isinstance(message, AIMessage): 100 | message_dict = {"role": "assistant", "content": message.content} 101 | if "tool_calls" in message.additional_kwargs: 102 | message_dict["tool_calls"] = message.additional_kwargs["tool_calls"] 103 | # If tool calls only, content is None not empty string 104 | if message_dict["content"] == "": 105 | message_dict["content"] = None 106 | elif isinstance(message, SystemMessage): 107 | message_dict = {"role": "system", "content": message.content} 108 | elif isinstance(message, ToolMessage): 109 | message_dict = { 110 | "role": "tool", 111 | "content": message.content, 112 | "tool_call_id": message.tool_call_id, 113 | } 114 | else: 115 | raise TypeError(f"Got unknown type {message}") 116 | if "name" in message.additional_kwargs: 117 | message_dict["name"] = message.additional_kwargs["name"] 118 | return message_dict 119 | 120 | 121 | def convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage: 122 | """Convert a dictionary to a LangChain message. 123 | 124 | Args: 125 | _dict: The dictionary. 126 | 127 | Returns: 128 | The LangChain message. 
129 | """ 130 | role = _dict.get("role") 131 | if role == "user": 132 | return HumanMessage(content=_dict.get("content", "")) 133 | elif role == "assistant": 134 | content = _dict.get("content", "") or "" 135 | additional_kwargs: Dict = {} 136 | if tool_calls := _dict.get("tool_calls"): 137 | additional_kwargs["tool_calls"] = tool_calls 138 | return AIMessage(content=content, additional_kwargs=additional_kwargs) 139 | elif role == "system": 140 | return SystemMessage(content=_dict.get("content", "")) 141 | elif role == "tool": 142 | additional_kwargs = {} 143 | if "name" in _dict: 144 | additional_kwargs["name"] = _dict["name"] 145 | return ToolMessage( 146 | content=_dict.get("content", ""), 147 | tool_call_id=_dict.get("tool_call_id"), 148 | additional_kwargs=additional_kwargs, 149 | ) 150 | else: 151 | return ChatMessage(content=_dict.get("content", ""), role=role) 152 | 153 | 154 | def _convert_delta_to_message_chunk( 155 | _dict: Mapping[str, Any], default_class: Type[BaseMessageChunk] 156 | ) -> BaseMessageChunk: 157 | role = _dict.get("role") 158 | content = _dict.get("content") or "" 159 | additional_kwargs: Dict = {} 160 | if _dict.get("tool_calls"): 161 | additional_kwargs["tool_calls"] = _dict["tool_calls"] 162 | 163 | if role == "user" or default_class == HumanMessageChunk: 164 | return HumanMessageChunk(content=content) 165 | elif role == "assistant" or default_class == AIMessageChunk: 166 | return AIMessageChunk(content=content, additional_kwargs=additional_kwargs) 167 | elif role == "system" or default_class == SystemMessageChunk: 168 | return SystemMessageChunk(content=content) 169 | elif role == "tool" or default_class == ToolMessageChunk: 170 | return ToolMessageChunk(content=content, tool_call_id=_dict["tool_call_id"]) 171 | elif role or default_class == ChatMessageChunk: 172 | return ChatMessageChunk(content=content, role=role) 173 | else: 174 | return default_class(content=content) 175 | 176 | 177 | class ChatZhipuAI(BaseChatModel): 178 | """ 179 | `ZHIPU AI` large language chat models API. 180 | 181 | To use, you should have the ``zhipuai`` python package installed. 182 | 183 | Example: 184 | .. code-block:: python 185 | 186 | from langchain_community.chat_models import ChatZhipuAI 187 | 188 | zhipuai_chat = ChatZhipuAI( 189 | temperature=0.5, 190 | api_key="your-api-key", 191 | model_name="glm-3-turbo", 192 | ) 193 | 194 | """ 195 | 196 | zhipuai: Any 197 | zhipuai_api_key: Optional[str] = Field(default=None, alias="api_key") 198 | """Automatically inferred from env var `ZHIPUAI_API_KEY` if not provided.""" 199 | 200 | client: Any = Field(default=None, exclude=True) #: :meta private: 201 | 202 | model_name: str = Field("glm-3-turbo", alias="model") 203 | """ 204 | Model name to use. 205 | -glm-3-turbo: 206 | According to the input of natural language instructions to complete a 207 | variety of language tasks, it is recommended to use SSE or asynchronous 208 | call request interface. 209 | -glm-4: 210 | According to the input of natural language instructions to complete a 211 | variety of language tasks, it is recommended to use SSE or asynchronous 212 | call request interface. 213 | """ 214 | 215 | temperature: float = Field(0.95) 216 | """ 217 | What sampling temperature to use. The value ranges from 0.0 to 1.0 and cannot 218 | be equal to 0. 219 | The larger the value, the more random and creative the output; The smaller 220 | the value, the more stable or certain the output will be. 
221 |     You are advised to adjust either top_p or temperature for your application
222 |     scenario, but not both at the same time.
223 |     """
224 | 
225 |     top_p: float = Field(0.7)
226 |     """
227 |     Another sampling method, known as nucleus sampling. The value ranges from
228 |     0.0 to 1.0 and cannot be equal to 0 or 1.
229 |     The model only considers the tokens within the top_p probability mass.
230 |     For example, 0.1 means the decoder only considers tokens from the top 10%
231 |     of the probability candidate set.
232 |     You are advised to adjust either top_p or temperature for your application
233 |     scenario, but not both at the same time.
234 |     """
235 | 
236 |     request_id: Optional[str] = Field(None)
237 |     """
238 |     A unique identifier used to distinguish each request. If the client
239 |     provides one it must guarantee uniqueness; otherwise the platform
240 |     generates one by default.
241 |     """
242 |     do_sample: Optional[bool] = Field(True)
243 |     """
244 |     When do_sample is true, the sampling policy is enabled. When do_sample is
245 |     false, the sampling parameters temperature and top_p are ignored.
246 |     """
247 |     streaming: bool = Field(False)
248 |     """Whether to stream the results or not."""
249 | 
250 |     model_kwargs: Dict[str, Any] = Field(default_factory=dict)
251 |     """Holds any model parameters valid for the `create` call that are not explicitly specified."""
252 | 
253 |     max_tokens: Optional[int] = None
254 |     """Maximum number of tokens to generate for each completion."""
255 | 
256 |     max_retries: int = 2
257 |     """Maximum number of retries to make when generating."""
258 | 
259 |     @property
260 |     def _identifying_params(self) -> Dict[str, Any]:
261 |         """Get the identifying parameters."""
262 |         return {**{"model_name": self.model_name}, **self._default_params}
263 | 
264 |     @property
265 |     def _llm_type(self) -> str:
266 |         """Return the type of chat model."""
267 |         return "zhipuai"
268 | 
269 |     @property
270 |     def lc_secrets(self) -> Dict[str, str]:
271 |         return {"zhipuai_api_key": "ZHIPUAI_API_KEY"}
272 | 
273 |     @classmethod
274 |     def get_lc_namespace(cls) -> List[str]:
275 |         """Get the namespace of the langchain object."""
276 |         return ["langchain", "chat_models", "zhipuai"]
277 | 
278 |     @property
279 |     def lc_attributes(self) -> Dict[str, Any]:
280 |         attributes: Dict[str, Any] = {}
281 | 
282 |         if self.model_name:
283 |             attributes["model"] = self.model_name
284 | 
285 |         if self.streaming:
286 |             attributes["streaming"] = self.streaming
287 | 
288 |         if self.max_tokens:
289 |             attributes["max_tokens"] = self.max_tokens
290 | 
291 |         return attributes
292 | 
293 |     @property
294 |     def _default_params(self) -> Dict[str, Any]:
295 |         """Get the default parameters for calling the ZhipuAI API."""
296 |         params = {
297 |             "model": self.model_name,
298 |             "stream": self.streaming,
299 |             "temperature": self.temperature,
300 |             "top_p": self.top_p,
301 |             "do_sample": self.do_sample,
302 |             **self.model_kwargs,
303 |         }
304 |         if self.max_tokens is not None:
305 |             params["max_tokens"] = self.max_tokens
306 |         return params
307 | 
308 |     @property
309 |     def _client_params(self) -> Dict[str, Any]:
310 |         """Get the parameters used for the zhipuai client."""
311 |         zhipuai_creds: Dict[str, Any] = {
312 |             "request_id": self.request_id,
313 |         }
314 |         return {**self._default_params, **zhipuai_creds}
315 | 
316 |     def __init__(self, *args, **kwargs):
317 |         super().__init__(*args, **kwargs)
318 |         try:
319 |             from zhipuai import ZhipuAI
320 | 
321 |             if not is_zhipu_v2():
322 |                 raise RuntimeError(
323 |                     "zhipuai package version is too low. "
324 |                     "Please upgrade it via 'pip install --upgrade zhipuai'"
325 |                 )
326 | 
327 |             self.client = ZhipuAI(
328 |                 api_key=self.zhipuai_api_key,  # your ZhipuAI API key
329 |             )
330 |         except ImportError:
331 |             raise RuntimeError(
332 |                 "Could not import zhipuai package. "
333 |                 "Please install it via 'pip install zhipuai'"
334 |             )
335 | 
336 |     def completions(self, **kwargs) -> Any | None:
337 |         return self.client.chat.completions.create(**kwargs)
338 | 
339 |     async def async_completions(self, **kwargs) -> Any:
340 |         loop = asyncio.get_running_loop()
341 |         partial_func = partial(self.client.chat.completions.create, **kwargs)
342 |         response = await loop.run_in_executor(
343 |             None,
344 |             partial_func,
345 |         )
346 |         return response
347 | 
348 |     async def async_completions_result(self, task_id):
349 |         loop = asyncio.get_running_loop()
350 |         response = await loop.run_in_executor(
351 |             None,
352 |             self.client.asyncCompletions.retrieve_completion_result,
353 |             task_id,
354 |         )
355 |         return response
356 | 
357 |     def _create_chat_result(self, response: Union[dict, BaseModel]) -> ChatResult:
358 |         generations = []
359 |         if not isinstance(response, dict):
360 |             response = response.dict()
361 |         for res in response["choices"]:
362 |             message = convert_dict_to_message(res["message"])
363 |             generation_info = dict(finish_reason=res.get("finish_reason"))
364 |             if "index" in res:
365 |                 generation_info["index"] = res["index"]
366 |             gen = ChatGeneration(
367 |                 message=message,
368 |                 generation_info=generation_info,
369 |             )
370 |             generations.append(gen)
371 |         token_usage = response.get("usage", {})
372 |         llm_output = {
373 |             "token_usage": token_usage,
374 |             "model_name": self.model_name,
375 |             "task_id": response.get("id", ""),
376 |             "created_time": response.get("created", ""),
377 |         }
378 |         return ChatResult(generations=generations, llm_output=llm_output)
379 | 
380 |     def _create_message_dicts(
381 |         self, messages: List[BaseMessage], stop: Optional[List[str]]
382 |     ) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]:
383 |         params = self._client_params
384 |         if stop is not None:
385 |             if "stop" in params:
386 |                 raise ValueError("`stop` found in both the input and default params.")
387 |             params["stop"] = stop
388 |         message_dicts = [convert_message_to_dict(m) for m in messages]
389 |         return message_dicts, params
390 | 
391 |     def completion_with_retry(
392 |         self, run_manager: Optional[CallbackManagerForLLMRun] = None, **kwargs: Any
393 |     ) -> Any:
394 |         """Use tenacity to retry the completion call."""
395 | 
396 |         retry_decorator = _create_retry_decorator(self, run_manager=run_manager)
397 | 
398 |         @retry_decorator
399 |         def _completion_with_retry(**kwargs: Any) -> Any:
400 |             return self.completions(**kwargs)
401 | 
402 |         return _completion_with_retry(**kwargs)
403 | 
404 |     async def acompletion_with_retry(
405 |         self,
406 |         run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
407 |         **kwargs: Any,
408 |     ) -> Any:
409 |         """Use tenacity to retry the async completion call."""
410 | 
411 |         retry_decorator = _create_retry_decorator(self, run_manager=run_manager)
412 | 
413 |         @retry_decorator
414 |         async def _completion_with_retry(**kwargs: Any) -> Any:
415 |             return await self.async_completions(**kwargs)
416 | 
417 |         return await _completion_with_retry(**kwargs)
418 | 
419 |     def _generate(
420 |         self,
421 |         messages: List[BaseMessage],
422 |         stop: Optional[List[str]] = None,
423 |         run_manager: Optional[CallbackManagerForLLMRun] = None,
424 |         stream: Optional[bool] = None,
425 |         **kwargs: Any,
426 |     ) -> ChatResult:
427 |         """Generate a chat response."""
428 | 
429 |         should_stream = stream if stream is not None else self.streaming
430 |         if should_stream:
431 |             stream_iter = self._stream(
432 |                 messages, stop=stop, run_manager=run_manager, **kwargs
433 |             )
434 |             return generate_from_stream(stream_iter)
435 | 
436 |         message_dicts, params = self._create_message_dicts(messages, stop)
437 |         params = {
438 |             **params,
439 |             **({"stream": stream} if stream is not None else {}),
440 |             **kwargs,
441 |         }
442 |         response = self.completion_with_retry(
443 |             messages=message_dicts, run_manager=run_manager, **params
444 |         )
445 |         return self._create_chat_result(response)
446 | 
447 |     async def _agenerate(
448 |         self,
449 |         messages: List[BaseMessage],
450 |         stop: Optional[List[str]] = None,
451 |         run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
452 |         stream: Optional[bool] = None,
453 |         **kwargs: Any,
454 |     ) -> ChatResult:
455 |         """Asynchronously generate a chat response."""
456 |         should_stream = stream if stream is not None else self.streaming
457 |         if should_stream:
458 |             stream_iter = self._astream(
459 |                 messages, stop=stop, run_manager=run_manager, **kwargs
460 |             )
461 |             return await agenerate_from_stream(stream_iter)  # async iterators must be drained with agenerate_from_stream (from langchain_core.language_models.chat_models), not generate_from_stream
462 | 
463 |         message_dicts, params = self._create_message_dicts(messages, stop)
464 |         params = {
465 |             **params,
466 |             **({"stream": stream} if stream is not None else {}),
467 |             **kwargs,
468 |         }
469 |         response = await self.acompletion_with_retry(
470 |             messages=message_dicts, run_manager=run_manager, **params
471 |         )
472 |         return self._create_chat_result(response)
473 | 
474 |     def _stream(
475 |         self,
476 |         messages: List[BaseMessage],
477 |         stop: Optional[List[str]] = None,
478 |         run_manager: Optional[CallbackManagerForLLMRun] = None,
479 |         **kwargs: Any,
480 |     ) -> Iterator[ChatGenerationChunk]:
481 |         """Stream the chat response in chunks."""
482 |         message_dicts, params = self._create_message_dicts(messages, stop)
483 |         params = {**params, **kwargs, "stream": True}
484 | 
485 |         default_chunk_class = AIMessageChunk
486 |         for chunk in self.completion_with_retry(
487 |             messages=message_dicts, run_manager=run_manager, **params
488 |         ):
489 |             if not isinstance(chunk, dict):
490 |                 chunk = chunk.dict()
491 |             if len(chunk["choices"]) == 0:
492 |                 continue
493 |             choice = chunk["choices"][0]
494 |             chunk = _convert_delta_to_message_chunk(
495 |                 choice["delta"], default_chunk_class
496 |             )
497 | 
498 |             finish_reason = choice.get("finish_reason")
499 |             generation_info = (
500 |                 dict(finish_reason=finish_reason) if finish_reason is not None else None
501 |             )
502 |             default_chunk_class = chunk.__class__
503 |             chunk = ChatGenerationChunk(message=chunk, generation_info=generation_info)
504 |             yield chunk
505 |             if run_manager:
506 |                 run_manager.on_llm_new_token(chunk.text, chunk=chunk)
507 | 
-------------------------------------------------------------------------------- /langchain_emoji/components/llm/custom/zhipuai/zhipuai_info.py: --------------------------------------------------------------------------------
1 | """Callback handler that tracks ZhipuAI token usage and cost."""
2 | import threading
3 | from typing import Any, Dict, List
4 | 
5 | from langchain_core.callbacks import BaseCallbackHandler
6 | from langchain_core.outputs import LLMResult
7 | from contextlib import contextmanager
8 | from typing import (
9 |     Generator,
10 | )
11 | 
12 | MODEL_COST_PER_1K_TOKENS = {
13 |     # input
14 |     "glm-4": 0.1,
15 |     "glm-3-turbo": 0.005,
16 |     # output
17 |     "glm-4-completion": 0.1,
18 |     "glm-3-turbo-completion": 0.005,
19 | }
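# A quick sanity check of the table above: a glm-4 request with 1,500 prompt
# tokens and 500 completion tokens costs 1.5 * 0.1 + 0.5 * 0.1 = 0.2 CNY,
# since both glm-4 rates are 0.1 CNY per 1K tokens.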
"glm-4-completion": 0.1, 18 | "glm-3-turbo-completion": 0.005, 19 | } 20 | 21 | 22 | def standardize_model_name( 23 | model_name: str, 24 | is_completion: bool = False, 25 | ) -> str: 26 | """ 27 | Standardize the model name to a format that can be used in the ZhipuAI API. 28 | 29 | Args: 30 | model_name: Model name to standardize. 31 | is_completion: Whether the model is used for completion or not. 32 | Defaults to False. 33 | 34 | Returns: 35 | Standardized model name. 36 | 37 | """ 38 | model_name = model_name.lower() 39 | if is_completion and ( 40 | model_name.startswith("glm-4") or model_name.startswith("glm-3-turbo") 41 | ): 42 | return model_name + "-completion" 43 | else: 44 | return model_name 45 | 46 | 47 | def get_zhipuai_token_cost_for_model( 48 | model_name: str, num_tokens: int, is_completion: bool = False 49 | ) -> float: 50 | """ 51 | Get the cost in USD for a given model and number of tokens. 52 | 53 | Args: 54 | model_name: Name of the model 55 | num_tokens: Number of tokens. 56 | is_completion: Whether the model is used for completion or not. 57 | Defaults to False. 58 | 59 | Returns: 60 | Cost in CNY. 61 | """ 62 | model_name = standardize_model_name(model_name, is_completion=is_completion) 63 | if model_name not in MODEL_COST_PER_1K_TOKENS: 64 | raise ValueError( 65 | f"Unknown model: {model_name}. Please provide a valid ZhipuAI model name." 66 | "Known models are: " + ", ".join(MODEL_COST_PER_1K_TOKENS.keys()) 67 | ) 68 | return MODEL_COST_PER_1K_TOKENS[model_name] * (num_tokens / 1000) 69 | 70 | 71 | class ZhipuAICallbackHandler(BaseCallbackHandler): 72 | """Callback Handler that tracks ZhipuAI info.""" 73 | 74 | total_tokens: int = 0 75 | prompt_tokens: int = 0 76 | completion_tokens: int = 0 77 | successful_requests: int = 0 78 | total_cost: float = 0.0 79 | 80 | def __init__(self) -> None: 81 | super().__init__() 82 | self._lock = threading.Lock() 83 | 84 | def __repr__(self) -> str: 85 | return ( 86 | f"Tokens Used: {self.total_tokens}\n" 87 | f"\tPrompt Tokens: {self.prompt_tokens}\n" 88 | f"\tCompletion Tokens: {self.completion_tokens}\n" 89 | f"Successful Requests: {self.successful_requests}\n" 90 | f"Total Cost (CNY): ${self.total_cost}" 91 | ) 92 | 93 | @property 94 | def always_verbose(self) -> bool: 95 | """Whether to call verbose callbacks even if verbose is False.""" 96 | return True 97 | 98 | def on_llm_start( 99 | self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any 100 | ) -> None: 101 | """Print out the prompts.""" 102 | pass 103 | 104 | def on_llm_new_token(self, token: str, **kwargs: Any) -> None: 105 | """Print out the token.""" 106 | pass 107 | 108 | def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None: 109 | """Collect token usage.""" 110 | if response.llm_output is None: 111 | return None 112 | 113 | if "token_usage" not in response.llm_output: 114 | with self._lock: 115 | self.successful_requests += 1 116 | return None 117 | 118 | # compute tokens and cost for this request 119 | token_usage = response.llm_output["token_usage"] 120 | completion_tokens = token_usage.get("completion_tokens", 0) 121 | prompt_tokens = token_usage.get("prompt_tokens", 0) 122 | model_name = standardize_model_name(response.llm_output.get("model_name", "")) 123 | if model_name in MODEL_COST_PER_1K_TOKENS: 124 | completion_cost = get_zhipuai_token_cost_for_model( 125 | model_name, completion_tokens, is_completion=True 126 | ) 127 | prompt_cost = get_zhipuai_token_cost_for_model(model_name, prompt_tokens) 128 | else: 129 | completion_cost = 0 
130 |             prompt_cost = 0
131 | 
132 |         # update shared state behind lock
133 |         with self._lock:
134 |             self.total_cost += prompt_cost + completion_cost
135 |             self.total_tokens += token_usage.get("total_tokens", 0)
136 |             self.prompt_tokens += prompt_tokens
137 |             self.completion_tokens += completion_tokens
138 |             self.successful_requests += 1
139 | 
140 |     def __copy__(self) -> "ZhipuAICallbackHandler":
141 |         """Return a copy of the callback handler."""
142 |         return self
143 | 
144 |     def __deepcopy__(self, memo: Any) -> "ZhipuAICallbackHandler":
145 |         """Return a deep copy of the callback handler."""
146 |         return self
147 | 
148 | 
149 | @contextmanager
150 | def get_zhipuai_callback() -> Generator[ZhipuAICallbackHandler, None, None]:
151 |     """Get the ZhipuAI callback handler in a context manager,
152 |     which conveniently exposes token and cost information.
153 | 
154 |     Returns:
155 |         ZhipuAICallbackHandler: The ZhipuAI callback handler.
156 | 
157 |     Example:
158 |         >>> with get_zhipuai_callback() as cb:
159 |         ...     llm.invoke(messages, config={"callbacks": [cb]})  # pass cb as a callback
160 |     """
161 |     cb = ZhipuAICallbackHandler()
162 |     yield cb
163 | 
-------------------------------------------------------------------------------- /langchain_emoji/components/llm/llm_component.py: --------------------------------------------------------------------------------
1 | import logging
2 | 
3 | from injector import inject, singleton
4 | from langchain_community.chat_models import ChatOpenAI
5 | from langchain.llms.base import BaseLanguageModel
6 | from langchain.llms.fake import FakeListLLM
7 | from langchain_emoji.settings.settings import Settings
8 | from langchain_emoji.components.llm.custom.zhipuai import ChatZhipuAI
9 | from langchain.schema.runnable import ConfigurableField
10 | 
11 | logger = logging.getLogger(__name__)
12 | 
13 | 
14 | @singleton
15 | class LLMComponent:
16 |     @inject
17 |     def __init__(self, settings: Settings) -> None:
18 |         llm_mode = settings.llm.mode
19 |         logger.info("Initializing the LLM in mode=%s", llm_mode)
20 |         self.modelname = settings.openai.modelname
21 |         match settings.llm.mode:
22 |             case "local":
23 |                 ...
# Todo 24 | case "openai": 25 | openai_settings = settings.openai 26 | self._llm = ChatOpenAI( 27 | temperature=openai_settings.temperature, 28 | model_name=openai_settings.modelname, 29 | api_key=openai_settings.api_key, 30 | openai_api_base=openai_settings.api_base, 31 | model_kwargs={"response_format": {"type": "json_object"}}, 32 | ).configurable_alternatives( 33 | # This gives this field an id 34 | # When configuring the end runnable, we can then use this id to configure this field 35 | ConfigurableField(id="llm"), 36 | default_key="openai", 37 | ) 38 | 39 | case "zhipuai": 40 | zhipuai_settings = settings.zhipuai 41 | self._llm = ChatZhipuAI( 42 | model=zhipuai_settings.modelname, 43 | temperature=zhipuai_settings.temperature, 44 | top_p=zhipuai_settings.top_p, 45 | api_key=zhipuai_settings.api_key, 46 | ).configurable_alternatives( 47 | # This gives this field an id 48 | # When configuring the end runnable, we can then use this id to configure this field 49 | ConfigurableField(id="llm"), 50 | default_key="zhipuai", 51 | ) 52 | case "deepseek": 53 | deepseek_settings = settings.deepseek 54 | self._llm = ChatOpenAI( 55 | model=deepseek_settings.modelname, 56 | temperature=deepseek_settings.temperature, 57 | api_key=deepseek_settings.api_key, 58 | openai_api_base=deepseek_settings.api_base, 59 | ).configurable_alternatives( 60 | # This gives this field an id 61 | # When configuring the end runnable, we can then use this id to configure this field 62 | ConfigurableField(id="llm"), 63 | default_key="deepseek", 64 | ) 65 | case "all": 66 | openai_settings = settings.openai 67 | zhipuai_settings = settings.zhipuai 68 | deepseek_settings = settings.deepseek 69 | self._llm = ChatOpenAI( 70 | temperature=openai_settings.temperature, 71 | model_name=openai_settings.modelname, 72 | api_key=openai_settings.api_key, 73 | openai_api_base=openai_settings.api_base, 74 | model_kwargs={"response_format": {"type": "json_object"}}, 75 | ).configurable_alternatives( 76 | # This gives this field an id 77 | # When configuring the end runnable, we can then use this id to configure this field 78 | ConfigurableField(id="llm"), 79 | default_key="openai", 80 | zhipuai=ChatZhipuAI( 81 | model=zhipuai_settings.modelname, 82 | temperature=zhipuai_settings.temperature, 83 | top_p=zhipuai_settings.top_p, 84 | api_key=zhipuai_settings.api_key, 85 | ), 86 | deepseek=ChatOpenAI( 87 | model=deepseek_settings.modelname, 88 | temperature=deepseek_settings.temperature, 89 | api_key=deepseek_settings.api_key, 90 | openai_api_base=deepseek_settings.api_base, 91 | ), 92 | ) 93 | 94 | case "mock": 95 | self._llm = FakeListLLM( 96 | responses=["你好,帝阅AI表情包"] 97 | ).configurable_alternatives( 98 | # This gives this field an id 99 | # When configuring the end runnable, we can then use this id to configure this field 100 | ConfigurableField(id="llm"), 101 | default_key="mock", 102 | ) 103 | 104 | @property 105 | def llm(self) -> BaseLanguageModel: 106 | return self._llm 107 | -------------------------------------------------------------------------------- /langchain_emoji/components/minio/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ptonlix/LangChain-Emoji/4c0a926464a7419e48df32b3c5e26794e2f9b312/langchain_emoji/components/minio/__init__.py -------------------------------------------------------------------------------- /langchain_emoji/components/minio/minio_component.py: 
--------------------------------------------------------------------------------
1 | import logging
2 | import base64
3 | from injector import inject, singleton
4 | from langchain_emoji.settings.settings import Settings
5 | from minio import Minio
6 | from minio.error import MinioException
7 | 
8 | logger = logging.getLogger(__name__)
9 | 
10 | 
11 | @singleton
12 | class MinioComponent:
13 |     @inject
14 |     def __init__(self, settings: Settings) -> None:
15 |         if not settings.minio:
16 |             raise Exception("minio config does not exist, please check the settings")
17 |         self.minio_settings = settings.minio
18 |         self.minio_client = Minio(
19 |             endpoint=self.minio_settings.host,
20 |             access_key=self.minio_settings.access_key,
21 |             secret_key=self.minio_settings.secret_key,
22 |             secure=False,
23 |         )
24 | 
25 |     def get_file_base64(self, file_name: str) -> str | None:
26 |         try:
27 |             response = self.minio_client.get_object(
28 |                 self.minio_settings.bucket_name, file_name
29 |             )
30 |             # Read the object content
31 |             object_data = response.read()
32 | 
33 |             # Encode object data to base64
34 |             base64_data = base64.b64encode(object_data)
35 | 
36 |             return base64_data.decode("utf-8")
37 |         except MinioException as e:
38 |             logger.error(f"get file base64 failed : {e}")
39 |             return None
40 | 
41 |     def get_download_link(self, file_name: str) -> str | None:
42 |         try:
43 |             # Generate presigned URL for download
44 |             presigned_url = self.minio_client.presigned_get_object(
45 |                 self.minio_settings.bucket_name, file_name
46 |             )
47 |             return presigned_url
48 |         except MinioException as e:
49 |             logger.error(f"get share link failed : {e}")
50 |             return None
51 | 
52 | 
53 | if __name__ == "__main__":
54 |     from langchain_emoji.settings.settings import settings
55 | 
56 |     mc = MinioComponent(settings())
57 | 
58 |     obj = "06e07a24-df07-4781-a1da-58739ac65404.jpg"
59 | 
60 |     print(mc.get_file_base64(obj))
61 |     print(mc.get_download_link(obj))
62 | 
-------------------------------------------------------------------------------- /langchain_emoji/components/trace/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ptonlix/LangChain-Emoji/4c0a926464a7419e48df32b3c5e26794e2f9b312/langchain_emoji/components/trace/__init__.py -------------------------------------------------------------------------------- /langchain_emoji/components/trace/trace_component.py: --------------------------------------------------------------------------------
1 | import logging
2 | 
3 | from injector import inject, singleton
4 | from langchain_emoji.settings.settings import Settings
5 | from langsmith import Client
6 | from langsmith.utils import LangSmithError
7 | import os
8 | import asyncio
9 | 
10 | logger = logging.getLogger(__name__)
11 | 
12 | 
13 | @singleton
14 | class TraceComponent:
15 |     @inject
16 |     def __init__(self, settings: Settings) -> None:
17 |         os.environ["LANGCHAIN_TRACING_V2"] = str(settings.langsmith.trace_version_v2)
18 |         os.environ["LANGCHAIN_PROJECT"] = str(settings.langsmith.langchain_project)
19 |         os.environ["LANGCHAIN_API_KEY"] = settings.langsmith.api_key
20 |         self.trace_client = Client(api_key=settings.langsmith.api_key)
21 | 
22 |     async def _arun(self, func, *args):  # run_in_executor only forwards positional args
23 |         return await asyncio.get_running_loop().run_in_executor(
24 |             None, func, *args
25 |         )
26 | 
27 |     async def aget_trace_url(self, run_id: str) -> str:
28 |         for i in range(5):
29 |             try:
30 |                 await self._arun(self.trace_client.read_run, run_id)
31 |                 break
32 |             except LangSmithError:
33 |                 await asyncio.sleep(2**i)  # exponential backoff while the run becomes visible
34 | 
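        # Prefer an existing public share link for the run; otherwise create one.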
35 |         if await self._arun(self.trace_client.run_is_shared, run_id):
36 |             return await self._arun(self.trace_client.read_run_shared_link, run_id)
37 |         return await self._arun(self.trace_client.share_run, run_id)
38 | 
-------------------------------------------------------------------------------- /langchain_emoji/components/vector_store/__init__.py: --------------------------------------------------------------------------------
1 | from langchain_emoji.components.vector_store.vector_store_component import (
2 |     VectorStoreComponent,
3 | )
4 | 
5 | __all__ = ["VectorStoreComponent"]
6 | 
-------------------------------------------------------------------------------- /langchain_emoji/components/vector_store/chroma/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ptonlix/LangChain-Emoji/4c0a926464a7419e48df32b3c5e26794e2f9b312/langchain_emoji/components/vector_store/chroma/__init__.py -------------------------------------------------------------------------------- /langchain_emoji/components/vector_store/chroma/chroma.py: --------------------------------------------------------------------------------
1 | from langchain_chroma import Chroma
2 | from langchain_core.documents import Document
3 | from typing import Any, Iterable, List, Optional
4 | 
5 | 
6 | class EmojiChroma(Chroma):
7 | 
8 |     def add_original_texts_with_filename(
9 |         self,
10 |         filename: str,
11 |         texts: Iterable[str],
12 |         metadatas: Optional[List[dict]] = None,
13 |         timeout: Optional[int] = None,
14 |         batch_size: int = 1000,
15 |         **kwargs: Any,
16 |     ) -> List[str]:
17 |         if not metadatas:
18 |             texts = list(texts)
19 |             # Chroma expects one metadata dict per text, not a single dict
20 |             metadatas = [{"filename": filename} for _ in texts]
21 |         return self.add_texts(
22 |             texts=texts,
23 |             metadatas=metadatas,
24 |         )
25 | 
26 |     def similarity_search_by_filenames(
27 |         self, query: str, filenames: List[str], k: int = 4
28 |     ) -> List[Document]:
29 |         where = None
30 |         if len(filenames) > 0:
31 |             where = {"filename": {"$in": filenames}}
32 | 
33 |         return self.similarity_search(query, k=k, filter=where)
34 | 
35 |     def delete_texts_with_filenames(
36 |         self,
37 |         document_ids: List[str],
38 |         filenames: List[str] = [],
39 |         batch_size: int = 20,
40 |         expr: Optional[str] = None,
41 |         timeout: Optional[int] = None,
42 |     ):
43 |         common_ids = document_ids
44 | 
45 |         if len(filenames) > 0:
46 |             where = {"filename": {"$in": filenames}}
47 |             result = self._collection.get(where=where)
48 |             # keep only the requested IDs that actually match the filenames
49 |             common_ids = [value for value in result.get("ids") if value in document_ids]
50 | 
51 |         return self.delete(
52 |             ids=common_ids,
53 |         )
54 | 
-------------------------------------------------------------------------------- /langchain_emoji/components/vector_store/tencent/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ptonlix/LangChain-Emoji/4c0a926464a7419e48df32b3c5e26794e2f9b312/langchain_emoji/components/vector_store/tencent/__init__.py -------------------------------------------------------------------------------- /langchain_emoji/components/vector_store/tencent/tencent.py: --------------------------------------------------------------------------------
1 | """Wrapper around the Tencent vector database."""
2 | 
3 | from __future__ import annotations
4 | 
5 | import json
6 | import logging
7 | import time
8 | from typing import Any, Dict, Iterable, List, Optional, Tuple
9 | 
10 | import numpy as np
11 | from langchain_core.documents import Document
12 | from langchain_core.embeddings import Embeddings
13 | from langchain_core.utils import guard_import
14 | from langchain_core.vectorstores import VectorStore
15 | 
16 | from langchain.vectorstores.utils import maximal_marginal_relevance
17 | 
18 | 
19 | logger = logging.getLogger(__name__)
20 | 
21 | 
22 | class ConnectionParams:
23 |     """Tencent vector DB connection params.
24 | 
25 |     See the following documentation for details:
26 |     https://cloud.tencent.com/document/product/1709/95820
27 | 
28 |     Attributes:
29 |         url (str): The access address of the vector database server
30 |             that the client needs to connect to.
31 |         key (str): API key for the client to access the vector database server,
32 |             which is used for authentication.
33 |         username (str): Account for the client to access the vector database server.
34 |         timeout (int): Request timeout in seconds.
35 |     """
36 | 
37 |     def __init__(self, url: str, key: str, username: str = "root", timeout: int = 10):
38 |         self.url = url
39 |         self.key = key
40 |         self.username = username
41 |         self.timeout = timeout
42 | 
43 | 
44 | class IndexParams:
45 |     """Tencent vector DB index params.
46 | 
47 |     See the following documentation for details:
48 |     https://cloud.tencent.com/document/product/1709/95826
49 |     """
50 | 
51 |     def __init__(
52 |         self,
53 |         dimension: int,
54 |         shard: int = 1,
55 |         replicas: int = 2,
56 |         index_type: str = "HNSW",
57 |         metric_type: str = "L2",
58 |         params: Optional[Dict] = None,
59 |     ):
60 |         self.dimension = dimension
61 |         self.shard = shard
62 |         self.replicas = replicas
63 |         self.index_type = index_type
64 |         self.metric_type = metric_type
65 |         self.params = params
66 | 
67 | 
68 | class TencentVectorDB(VectorStore):
69 |     """Wrapper around the Tencent vector database.
70 | 
71 |     In order to use this you need to have a database instance.
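    Example (a minimal sketch; the URL, key and ``my_embeddings`` object are
    placeholders, and the index dimension must match your embedding size):
        .. code-block:: python

            params = ConnectionParams(url="http://10.0.0.1", key="your-api-key")
            store = TencentVectorDB(
                connection_params=params,
                index_params=IndexParams(dimension=768),
                embedding=my_embeddings,  # any LangChain Embeddings implementation
            )
            store.add_texts(["hello world"], metadatas=[{"source": "demo"}])
            docs = store.similarity_search("hello", k=1)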
72 |     See the following documentation for details:
73 |     https://cloud.tencent.com/document/product/1709/94951
74 |     """
75 | 
76 |     field_id: str = "id"
77 |     field_vector: str = "vector"
78 |     field_text: str = "text"
79 |     field_metadata: str = "metadata"
80 | 
81 |     def __init__(
82 |         self,
83 |         connection_params: ConnectionParams,
84 |         index_params: IndexParams = IndexParams(128),
85 |         database_name: str = "LangChainDatabase",
86 |         collection_name: str = "LangChainCollection",
87 |         drop_old: Optional[bool] = False,
88 |         embedding: Optional[Embeddings] = None,
89 |     ):
90 |         self.document = guard_import("tcvectordb.model.document")
91 |         tcvectordb = guard_import("tcvectordb")
92 |         self.embedding_func = embedding
93 |         self.index_params = index_params
94 |         self.ebd_own = (
95 |             True if embedding is None else False
96 |         )  # whether to use Tencent Cloud's built-in embedding
97 |         self.vdb_client = tcvectordb.VectorDBClient(
98 |             url=connection_params.url,
99 |             username=connection_params.username,
100 |             key=connection_params.key,
101 |             timeout=connection_params.timeout,
102 |         )
103 |         db_list = self.vdb_client.list_databases()
104 |         db_exist: bool = False
105 |         for db in db_list:
106 |             if database_name == db.database_name:
107 |                 db_exist = True
108 |                 break
109 |         if db_exist:
110 |             self.database = self.vdb_client.database(database_name)
111 |         else:
112 |             self.database = self.vdb_client.create_database(database_name)
113 |         try:
114 |             self.collection = self.database.describe_collection(collection_name)
115 |             if drop_old:
116 |                 self.database.drop_collection(collection_name)
117 |                 self._create_collection(collection_name)
118 |         except tcvectordb.exceptions.VectorDBException:
119 |             self._create_collection(collection_name)
120 | 
121 |     def _create_collection(self, collection_name: str) -> None:
122 |         enum = guard_import("tcvectordb.model.enum")
123 |         vdb_index = guard_import("tcvectordb.model.index")
124 |         coll = guard_import("tcvectordb.model.collection")
125 |         index_type = None
126 |         for k, v in enum.IndexType.__members__.items():
127 |             if k == self.index_params.index_type:
128 |                 index_type = v
129 |         if index_type is None:
130 |             raise ValueError("unsupported index_type")
131 |         metric_type = None
132 |         for k, v in enum.MetricType.__members__.items():
133 |             if k == self.index_params.metric_type:
134 |                 metric_type = v
135 |         if metric_type is None:
136 |             raise ValueError("unsupported metric_type")
137 |         if self.index_params.params is None:
138 |             params = vdb_index.HNSWParams(m=16, efconstruction=200)
139 |         else:
140 |             params = vdb_index.HNSWParams(
141 |                 m=self.index_params.params.get("M", 16),
142 |                 efconstruction=self.index_params.params.get("efConstruction", 200),
143 |             )
144 |         index = vdb_index.Index(
145 |             vdb_index.FilterIndex(
146 |                 self.field_id, enum.FieldType.String, enum.IndexType.PRIMARY_KEY
147 |             ),
148 |             vdb_index.VectorIndex(
149 |                 self.field_vector,
150 |                 self.index_params.dimension,
151 |                 index_type,
152 |                 metric_type,
153 |                 params,
154 |             ),
155 |             vdb_index.FilterIndex(
156 |                 self.field_text, enum.FieldType.String, enum.IndexType.FILTER
157 |             ),
158 |             vdb_index.FilterIndex(
159 |                 self.field_metadata, enum.FieldType.String, enum.IndexType.FILTER
160 |             ),
161 |         )
162 |         # Embedding config
163 |         ebd = None
164 |         if self.ebd_own:
165 |             ebd = coll.Embedding(
166 |                 vector_field=self.field_vector,
167 |                 field=self.field_text,
168 |                 model=enum.EmbeddingModel.BGE_BASE_ZH,
169 |             )
170 | 
171 |         self.collection = self.database.create_collection(
172 |             name=collection_name,
173 |             shard=self.index_params.shard,
174 |             replicas=self.index_params.replicas,
175 |             description="Collection for LangChain",
176 |             index=index,
177 |             embedding=ebd,
178 |         )
179 | 
180 |     @property
181 |     def embeddings(self) -> Embeddings:
182 |         return self.embedding_func
183 | 
184 |     @classmethod
185 |     def from_texts(
186 |         cls,
187 |         texts: List[str],
188 |         metadatas: Optional[List[dict]] = None,
189 |         connection_params: Optional[ConnectionParams] = None,
190 |         index_params: Optional[IndexParams] = None,
191 |         database_name: str = "LangChainDatabase",
192 |         collection_name: str = "LangChainCollection",
193 |         drop_old: Optional[bool] = False,
194 |         embedding: Optional[Embeddings] = None,
195 |         **kwargs: Any,
196 |     ) -> TencentVectorDB:
197 |         """Create a collection, index it with HNSW, and insert data."""
198 |         if len(texts) == 0:
199 |             raise ValueError("texts is empty")
200 |         if connection_params is None:
201 |             raise ValueError("connection_params is empty")
202 |         vector_db = None
203 |         if embedding:
204 |             try:
205 |                 embeddings = embedding.embed_documents(texts[0:1])
206 |             except NotImplementedError:
207 |                 embeddings = [embedding.embed_query(texts[0])]
208 |             dimension = len(embeddings[0])
209 |             if index_params is None:
210 |                 index_params = IndexParams(dimension=dimension)
211 |             else:
212 |                 index_params.dimension = dimension
213 |             vector_db = cls(
214 |                 embedding=embedding,
215 |                 connection_params=connection_params,
216 |                 index_params=index_params,
217 |                 database_name=database_name,
218 |                 collection_name=collection_name,
219 |                 drop_old=drop_old,
220 |             )
221 |             vector_db.add_texts(texts=texts, metadatas=metadatas, **kwargs)
222 |         else:  # use Tencent's server-side embedding
223 |             if index_params is None:
224 |                 index_params = IndexParams(dimension=768)
225 |             vector_db = cls(
226 |                 embedding=embedding,
227 |                 connection_params=connection_params,
228 |                 index_params=index_params,
229 |                 database_name=database_name,
230 |                 collection_name=collection_name,
231 |                 drop_old=drop_old,
232 |             )
233 |             vector_db.add_original_texts(texts=texts, metadatas=metadatas, **kwargs)
234 |         return vector_db
235 | 
236 |     def add_texts(
237 |         self,
238 |         texts: Iterable[str],
239 |         metadatas: Optional[List[dict]] = None,
240 |         timeout: Optional[int] = None,
241 |         batch_size: int = 1000,
242 |         **kwargs: Any,
243 |     ) -> List[str]:
244 |         """Insert text data into TencentVectorDB."""
245 |         texts = list(texts)
246 |         pks: list[str] = []
247 |         try:
248 |             embeddings = self.embedding_func.embed_documents(texts)
249 |         except NotImplementedError:
250 |             embeddings = [self.embedding_func.embed_query(x) for x in texts]
251 |         if len(embeddings) == 0:
252 |             logger.debug("Nothing to insert, skipping.")
253 |             return []
254 |         total_count = len(embeddings)
255 |         for start in range(0, total_count, batch_size):
256 |             # Grab end index
257 |             docs = []
258 |             end = min(start + batch_size, total_count)
259 |             for id in range(start, end, 1):
260 |                 metadata = "{}" if metadatas is None else json.dumps(metadatas[id])
261 |                 # keep the generated ID so it can be returned to the caller
262 |                 vdb_id = "{}-{}-{}".format(time.time_ns(), hash(texts[id]), id)
263 |                 doc = self.document.Document(
264 |                     id=vdb_id,
265 |                     vector=embeddings[id],
266 |                     text=texts[id],
267 |                     metadata=metadata,
268 |                     **kwargs,
269 |                 )
270 |                 docs.append(doc)
271 |                 pks.append(vdb_id)  # the stored document ID, not the loop index
272 |             self.collection.upsert(docs, timeout)
273 |         return pks
274 | 
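    # add_original_texts is the server-side-embedding counterpart of add_texts:
    # raw text is upserted and Tencent VectorDB embeds it with the collection's
    # own model (see ebd_own above), so no local embedding function is needed.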
275 |     def add_original_texts(
276 |         self,
277 |         texts: Iterable[str],
278 |         metadatas: Optional[List[dict]] = None,
279 |         timeout: Optional[int] = None,
280 |         batch_size: int = 1000,
281 |         **kwargs: Any,
282 |     ) -> List[str]:
283 |         """Insert text data into TencentVectorDB."""
284 |         texts = list(texts)
285 |         pks: list[str] = []
286 |         total_count = len(texts)
287 |         for start in range(0, total_count, batch_size):
288 |             # Grab end index
289 |             docs = []
290 |             end = min(start + batch_size, total_count)
291 |             for id in range(start, end, 1):
292 |                 metadata = "{}"
293 |                 if metadatas is not None:
294 |                     metadata = json.dumps(metadatas[id])
295 |                 vdb_id = "{}-{}-{}".format(time.time_ns(), hash(texts[id]), id)
296 |                 doc = self.document.Document(
297 |                     id=vdb_id,
298 |                     text=texts[id],
299 |                     metadata=metadata,
300 |                     **kwargs,
301 |                 )
302 |                 docs.append(doc)
303 |                 pks.append(vdb_id)  # return the ID stored in the vector DB
304 |             self.collection.upsert(docs, timeout)
305 |         return pks
306 | 
307 |     def similarity_search(
308 |         self,
309 |         query: str,
310 |         k: int = 4,
311 |         param: Optional[dict] = None,
312 |         expr: Optional[str] = None,
313 |         timeout: Optional[int] = None,
314 |         **kwargs: Any,
315 |     ) -> List[Document]:
316 |         """Perform a similarity search against the query string."""
317 |         res = self.similarity_search_with_score(
318 |             query=query, k=k, param=param, expr=expr, timeout=timeout, **kwargs
319 |         )
320 |         return [doc for doc, _ in res]
321 | 
322 |     def similarity_search_with_score(
323 |         self,
324 |         query: str,
325 |         k: int = 4,
326 |         param: Optional[dict] = None,
327 |         expr: Optional[str] = None,
328 |         timeout: Optional[int] = None,
329 |         **kwargs: Any,
330 |     ) -> List[Tuple[Document, float]]:
331 |         """Perform a search on a query string and return results with score."""
332 |         res: List[Tuple[Document, float]] = []
333 |         if self.ebd_own:
334 |             res = self.similarity_search_with_score_by_text(
335 |                 text=[query],
336 |                 k=k,
337 |                 param=param,
338 |                 expr=expr,
339 |                 timeout=timeout,
340 |                 **kwargs,
341 |             )
342 |         else:
343 |             # Embed the query text.
344 |             embedding = self.embedding_func.embed_query(query)
345 |             res = self.similarity_search_with_score_by_vector(
346 |                 embedding=embedding,
347 |                 k=k,
348 |                 param=param,
349 |                 expr=expr,
350 |                 timeout=timeout,
351 |                 **kwargs,
352 |             )
353 |         return res
354 | 
355 |     def similarity_search_by_vector(
356 |         self,
357 |         embedding: List[float],
358 |         k: int = 4,
359 |         param: Optional[dict] = None,
360 |         expr: Optional[str] = None,
361 |         timeout: Optional[int] = None,
362 |         **kwargs: Any,
363 |     ) -> List[Document]:
364 |         """Perform a similarity search against an embedding vector."""
365 |         res = self.similarity_search_with_score_by_vector(
366 |             embedding=embedding, k=k, param=param, expr=expr, timeout=timeout, **kwargs
367 |         )
368 |         return [doc for doc, _ in res]
369 | 
370 |     def similarity_search_with_score_by_text(
371 |         self,
372 |         text: List[str],
373 |         k: int = 4,
374 |         param: Optional[dict] = None,
375 |         expr: Optional[str] = None,
376 |         timeout: Optional[int] = None,
377 |         **kwargs: Any,
378 |     ) -> List[Tuple[Document, float]]:
379 |         """Perform a search on a query string and return results with score."""
380 |         filter = None if expr is None else self.document.Filter(expr)
381 |         ef = 10 if param is None else param.get("ef", 10)
382 |         res_data = self.collection.searchByText(
383 |             embeddingItems=text,
384 |             filter=filter,
385 |             params=self.document.HNSWSearchParams(ef=ef),
386 |             retrieve_vector=False,
387 |             limit=k,
388 |             timeout=timeout,
389 |         )
390 |         if "documents" not in res_data:
391 |             raise ValueError(res_data)
392 |         res: List[List[Dict]] = res_data.get("documents")
393 |         # Organize results.
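        # res holds one list of hits per query text; only one query was sent,
        # so res[0] has the hits. Each hit carries the stored text, the
        # JSON-encoded metadata and a similarity score.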
394 | ret: List[Tuple[Document, float]] = [] 395 | if res is None or len(res) == 0: 396 | return ret 397 | for result in res[0]: 398 | meta = result.get(self.field_metadata) 399 | if meta is not None: 400 | meta = json.loads(meta) 401 | doc = Document(page_content=result.get(self.field_text), metadata=meta) 402 | pair = (doc, result.get("score", 0.0)) 403 | ret.append(pair) 404 | return ret 405 | 406 | def similarity_search_with_score_by_vector( 407 | self, 408 | embedding: List[float], 409 | k: int = 4, 410 | param: Optional[dict] = None, 411 | expr: Optional[str] = None, 412 | timeout: Optional[int] = None, 413 | **kwargs: Any, 414 | ) -> List[Tuple[Document, float]]: 415 | """Perform a search on a query string and return results with score.""" 416 | filter = None if expr is None else self.document.Filter(expr) 417 | ef = 10 if param is None else param.get("ef", 10) 418 | res_data = self.collection.search( 419 | vectors=[embedding], 420 | filter=filter, 421 | params=self.document.HNSWSearchParams(ef=ef), 422 | retrieve_vector=False, 423 | limit=k, 424 | timeout=timeout, 425 | ) 426 | if "documents" not in res_data: 427 | raise ValueError(res_data) 428 | res: List[List[Dict]] = res_data.get("documents") # Organize results. 429 | ret: List[Tuple[Document, float]] = [] 430 | if res is None or len(res) == 0: 431 | return ret 432 | for result in res[0]: 433 | meta = result.get(self.field_metadata) 434 | if meta is not None: 435 | meta = json.loads(meta) 436 | doc = Document(page_content=result.get(self.field_text), metadata=meta) 437 | pair = (doc, result.get("score", 0.0)) 438 | ret.append(pair) 439 | return ret 440 | 441 | def max_marginal_relevance_search( 442 | self, 443 | query: str, 444 | k: int = 4, 445 | fetch_k: int = 20, 446 | lambda_mult: float = 0.5, 447 | param: Optional[dict] = None, 448 | expr: Optional[str] = None, 449 | timeout: Optional[int] = None, 450 | **kwargs: Any, 451 | ) -> List[Document]: 452 | """Perform a search and return results that are reordered by MMR.""" 453 | embedding = self.embedding_func.embed_query(query) 454 | return self.max_marginal_relevance_search_by_vector( 455 | embedding=embedding, 456 | k=k, 457 | fetch_k=fetch_k, 458 | lambda_mult=lambda_mult, 459 | param=param, 460 | expr=expr, 461 | timeout=timeout, 462 | **kwargs, 463 | ) 464 | 465 | def max_marginal_relevance_search_by_vector( 466 | self, 467 | embedding: list[float], 468 | k: int = 4, 469 | fetch_k: int = 20, 470 | lambda_mult: float = 0.5, 471 | param: Optional[dict] = None, 472 | expr: Optional[str] = None, 473 | timeout: Optional[int] = None, 474 | **kwargs: Any, 475 | ) -> List[Document]: 476 | """Perform a search and return results that are reordered by MMR.""" 477 | filter = None if expr is None else self.document.Filter(expr) 478 | ef = 10 if param is None else param.get("ef", 10) 479 | res: List[List[Dict]] = self.collection.search( 480 | vectors=[embedding], 481 | filter=filter, 482 | params=self.document.HNSWSearchParams(ef=ef), 483 | retrieve_vector=True, 484 | limit=fetch_k, 485 | timeout=timeout, 486 | ) 487 | # Organize results. 488 | documents = [] 489 | ordered_result_embeddings = [] 490 | for result in res[0]: 491 | meta = result.get(self.field_metadata) 492 | if meta is not None: 493 | meta = json.loads(meta) 494 | doc = Document(page_content=result.get(self.field_text), metadata=meta) 495 | documents.append(doc) 496 | ordered_result_embeddings.append(result.get(self.field_vector)) 497 | # Get the new order of results. 
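        # maximal_marginal_relevance returns the indices of the k picks,
        # trading off similarity to the query against diversity among the
        # already-picked results (lambda_mult=1.0 is pure similarity,
        # 0.0 is pure diversity).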
498 |         new_ordering = maximal_marginal_relevance(
499 |             np.array(embedding), ordered_result_embeddings, k=k, lambda_mult=lambda_mult
500 |         )
501 |         # Reorder the values and return.
502 |         ret = []
503 |         for x in new_ordering:
504 |             # Function can return -1 index
505 |             if x == -1:
506 |                 break
507 |             else:
508 |                 ret.append(documents[x])
509 |         return ret
510 | 
511 |     def delete_texts_by_ids(
512 |         self,
513 |         document_ids: List[str],
514 |         batch_size: int = 20,
515 |         expr: Optional[str] = None,
516 |         timeout: Optional[int] = None,
517 |     ) -> List[dict]:
518 |         filter = None if expr is None else self.document.Filter(expr)
519 |         total_count = len(document_ids)
520 |         ret = []
521 |         for start in range(0, total_count, batch_size):
522 |             # Grab end index
523 |             docs = []
524 |             end = min(start + batch_size, total_count)
525 |             docs = document_ids[start:end]
526 |             res_data = self.collection.delete(
527 |                 document_ids=docs, filter=filter, timeout=timeout
528 |             )
529 | 
530 |             if res_data.get("code") != 0:
531 |                 raise ValueError(f"delete texts failed: {res_data}")
532 |             ret.append(res_data)
533 |         return ret
534 | 
535 | 
536 | class EmojiTencentVectorDB(TencentVectorDB):
537 |     field_filename: str = "filename"
538 | 
539 |     def __init__(
540 |         self,
541 |         connection_params: ConnectionParams,
542 |         index_params: IndexParams = IndexParams(768, replicas=0),  # default to a single instance without replicas
543 |         database_name: str = "DeepReadDatabase",
544 |         collection_name: str = "EmojiCollection",
545 |         embedding: Optional[Embeddings] = None,
546 |         drop_old: Optional[bool] = False,
547 |     ):
548 |         super().__init__(
549 |             embedding=embedding,
550 |             connection_params=connection_params,
551 |             index_params=index_params,
552 |             database_name=database_name,
553 |             collection_name=collection_name,
554 |             drop_old=drop_old,
555 |         )
556 | 
557 |     def _create_collection(self, collection_name: str) -> None:
558 |         enum = guard_import("tcvectordb.model.enum")
559 |         vdb_index = guard_import("tcvectordb.model.index")
560 |         coll = guard_import("tcvectordb.model.collection")
561 |         index_type = None
562 |         for k, v in enum.IndexType.__members__.items():
563 |             if k == self.index_params.index_type:
564 |                 index_type = v
565 |         if index_type is None:
566 |             raise ValueError("unsupported index_type")
567 |         metric_type = None
568 |         for k, v in enum.MetricType.__members__.items():
569 |             if k == self.index_params.metric_type:
570 |                 metric_type = v
571 |         if metric_type is None:
572 |             raise ValueError("unsupported metric_type")
573 |         if self.index_params.params is None:
574 |             params = vdb_index.HNSWParams(m=16, efconstruction=200)
575 |         else:
576 |             params = vdb_index.HNSWParams(
577 |                 m=self.index_params.params.get("M", 16),
578 |                 efconstruction=self.index_params.params.get("efConstruction", 200),
579 |             )
580 |         index = vdb_index.Index(
581 |             vdb_index.FilterIndex(
582 |                 self.field_id, enum.FieldType.String, enum.IndexType.PRIMARY_KEY
583 |             ),
584 |             vdb_index.VectorIndex(
585 |                 self.field_vector,
586 |                 self.index_params.dimension,
587 |                 index_type,
588 |                 metric_type,
589 |                 params,
590 |             ),
591 |             vdb_index.FilterIndex(
592 |                 self.field_text, enum.FieldType.String, enum.IndexType.FILTER
593 |             ),
594 |             vdb_index.FilterIndex(
595 |                 self.field_metadata, enum.FieldType.String, enum.IndexType.FILTER
596 |             ),
597 |             vdb_index.FilterIndex(
598 |                 self.field_filename, enum.FieldType.String, enum.IndexType.FILTER
599 |             ),
600 |         )
601 |         # Embedding config
602 |         ebd = None
603 |         if self.ebd_own:
604 |             ebd = coll.Embedding(
605 |                 vector_field=self.field_vector,
606 |                 field=self.field_text,
607 |                 model=enum.EmbeddingModel.BGE_BASE_ZH,
608 |             )
609 | 
610 |         self.collection = self.database.create_collection(
611 |             name=collection_name,
612 |             shard=self.index_params.shard,
613 |             replicas=self.index_params.replicas,
614 |             description="Collection for DeepRead",
615 |             index=index,
616 |             embedding=ebd,
617 |         )
618 | 
619 |     @classmethod
620 |     def from_texts_with_filename(
621 |         cls,
622 |         filename: str,
623 |         texts: List[str],
624 |         metadatas: Optional[List[dict]] = None,
625 |         connection_params: Optional[ConnectionParams] = None,
626 |         database_name: str = "DeepReadDatabase",
627 |         collection_name: str = "DeepReadCollection",
628 |     ) -> TencentVectorDB:
629 |         return cls.from_texts(
630 |             texts=texts,
631 |             metadatas=metadatas,
632 |             connection_params=connection_params,
633 |             database_name=database_name,
634 |             collection_name=collection_name,
635 |             filename=filename,
636 |         )
637 | 
638 |     def add_original_texts_with_filename(
639 |         self,
640 |         filename: str,
641 |         texts: Iterable[str],
642 |         metadatas: Optional[List[dict]] = None,
643 |         timeout: Optional[int] = None,
644 |         batch_size: int = 1000,
645 |         **kwargs: Any,
646 |     ) -> List[str]:
647 |         return self.add_original_texts(
648 |             texts=texts,
649 |             metadatas=metadatas,
650 |             timeout=timeout,
651 |             batch_size=batch_size,
652 |             filename=filename,
653 |         )
654 | 
655 |     def similarity_search_by_filenames(
656 |         self, query: str, filenames: List[str], k: int = 4
657 |     ) -> List[Document]:
658 |         # tcvectordb filter expressions use quoted string values joined with
659 |         # " or " (including spaces), e.g. filename="a.jpg" or filename="b.jpg"
660 |         expr = None
661 |         if len(filenames) > 0:
662 |             expr = " or ".join(f'{self.field_filename}="{fn}"' for fn in filenames)
663 | 
664 |         return self.similarity_search(query, k=k, expr=expr)
665 | 
666 |     def delete_texts_with_filenames(
667 |         self,
668 |         document_ids: List[str],
669 |         filenames: List[str] = [],
670 |         batch_size: int = 20,
671 |         expr: Optional[str] = None,
672 |         timeout: Optional[int] = None,
673 |     ):
674 |         # Build the same quoted " or " filter as above. When no filenames are
675 |         # given, any caller-provided expr is used as-is rather than being
676 |         # replaced by an empty filter string.
677 |         if len(filenames) > 0:
678 |             expr = " or ".join(f'{self.field_filename}="{fn}"' for fn in filenames)
679 | 
680 |         return self.delete_texts_by_ids(
681 |             document_ids=document_ids,
682 |             batch_size=batch_size,
683 |             expr=expr,
684 |             timeout=timeout,
685 |         )
686 | 
-------------------------------------------------------------------------------- /langchain_emoji/components/vector_store/vector_store_component.py: --------------------------------------------------------------------------------
1 | import logging
2 | import os
3 | from injector import inject, singleton
4 | from pathlib import Path
5 | from langchain_emoji.settings.settings import Settings
6 | from langchain_emoji.components.vector_store.tencent.tencent import (
7 |     EmojiTencentVectorDB,
8 |     ConnectionParams,
9 | )
10 | from langchain_emoji.components.vector_store.chroma.chroma import EmojiChroma
11 | from chromadb.config import Settings as ChromaSettings
12 | 
13 | from langchain_emoji.constants import PROJECT_ROOT_PATH
14 | from langchain_emoji.components.embedding.embedding_component import EmbeddingComponent
15 | 
16 | logger = logging.getLogger(__name__)
17 | 
18 | 
19 | @singleton
20 | class VectorStoreComponent:
21 |     @inject
22 |     def __init__(self, embed: EmbeddingComponent, settings: Settings) -> None:
23 |         self.embedcom = embed
24 |         match settings.vectorstore.database:
25 |             case "tcvectordb":
26 |                 tcvectorconf = settings.vectorstore.tcvectordb
27 |                 connect_params = ConnectionParams(
28 |                     url=tcvectorconf.url, key=tcvectorconf.api_key
29 |                 )
30 |                 self.vector_store = EmojiTencentVectorDB(
31 |                     connection_params=connect_params,
32 |                     database_name=tcvectorconf.database_name,
33 |                     collection_name=tcvectorconf.collection_name,
34 |                 )
35 |             case "chromadb":
36 |                 data_dir = PROJECT_ROOT_PATH / settings.vectorstore.chromadb.persist_dir
37 |                 persist_directory = self.create_persist_directory("chromadb", data_dir)
38 |                 collection_name = settings.vectorstore.chromadb.collection_name
39 |                 self.vector_store = EmojiChroma(
40 |                     collection_name,
41 |                     embed._embedding,
42 |                     client_settings=ChromaSettings(
43 |                         anonymized_telemetry=False,
44 |                         is_persistent=True,
45 |                         persist_directory=persist_directory,
46 |                     ),
47 |                 )
48 |             case _:
49 |                 # Should be unreachable
50 |                 # The settings validator should have caught this
51 |                 raise ValueError(
52 |                     f"Vectorstore database {settings.vectorstore.database} not supported"
53 |                 )
54 | 
55 |     def create_persist_directory(self, vectordb: str, data_dir: Path) -> str:
56 |         if not (os.path.exists(data_dir) and os.path.isdir(data_dir)):
57 |             raise FileNotFoundError(f"{data_dir} does not exist, please check the config")
58 | 
59 |         persist_directory = data_dir / vectordb
60 |         # create the directory if it does not exist yet
61 |         if not os.path.exists(persist_directory):
62 |             os.makedirs(persist_directory)
63 | 
64 |         return str(persist_directory)
65 | 
66 |     def close(self) -> None: ...
67 | 
-------------------------------------------------------------------------------- /langchain_emoji/constants.py: --------------------------------------------------------------------------------
1 | from pathlib import Path
2 | 
3 | PROJECT_ROOT_PATH: Path = Path(__file__).parents[1]
4 | 
-------------------------------------------------------------------------------- /langchain_emoji/di.py: --------------------------------------------------------------------------------
1 | from injector import Injector
2 | 
3 | from langchain_emoji.settings.settings import Settings, unsafe_typed_settings
4 | 
5 | 
6 | def create_application_injector() -> Injector:
7 |     _injector = Injector(auto_bind=True)
8 |     _injector.binder.bind(Settings, to=unsafe_typed_settings)
9 |     return _injector
10 | 
11 | 
12 | """
13 | Global injector for the application.
14 | 
15 | Avoid using this reference, it will make your code harder to test.
16 | 
17 | Instead, use the `request.state.injector` reference, which is bound to every request
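(for example, service = request.state.injector.get(EmojiService) inside a request handler).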
18 | """
19 | global_injector: Injector = create_application_injector()
20 | 
-------------------------------------------------------------------------------- /langchain_emoji/launcher.py: --------------------------------------------------------------------------------
1 | """FastAPI app creation, logger configuration and main API routes."""
2 | 
3 | import logging
4 | from typing import Any
5 | 
6 | from fastapi import Depends, FastAPI, Request
7 | from fastapi.middleware.cors import CORSMiddleware
8 | from fastapi.openapi.utils import get_openapi
9 | from injector import Injector
10 | from langchain_emoji.paths import docs_path
11 | from langchain_emoji.settings.settings import Settings
12 | from langchain_emoji.server.emoji.emoji_router import emoji_router
13 | from langchain_emoji.server.vector_store.vector_store_router import vector_store_router
14 | from langchain_emoji.server.trace.trace_router import trace_router
15 | from langchain_emoji.server.health.health_router import health_router
16 | from langchain_emoji.server.config.config_router import (
17 |     config_router_no_auth,
18 |     config_router,
19 | )
20 | 
21 | logger = logging.getLogger(__name__)
22 | 
23 | 
24 | def create_app(root_injector: Injector) -> FastAPI:
25 |     # Start the API
26 |     with open(docs_path / "description.md") as description_file:
27 |         description = description_file.read()
28 | 
29 |     tags_metadata = [
30 |         {
31 |             "name": "Emoji",
32 |             "description": "AI-enabled Emoji engine",
33 |         },
34 |         {
35 |             "name": "VectorStore",
36 |             "description": "Store and retrieve emojis in the vector store",
37 |         },
38 |         {
39 |             "name": "Config",
40 |             "description": "Obtain and modify project configuration files",
41 |         },
42 |         {
43 |             "name": "Health",
44 |             "description": "Simple health API to make sure the server is up and running.",
45 |         },
46 |     ]
47 | 
48 |     async def bind_injector_to_request(request: Request) -> None:
49 |         request.state.injector = root_injector
50 | 
51 |     app = FastAPI(dependencies=[Depends(bind_injector_to_request)])
52 | 
53 |     def custom_openapi() -> dict[str, Any]:
54 |         if app.openapi_schema:
55 |             return app.openapi_schema
56 |         openapi_schema = get_openapi(
57 |             title="LangChain-Emoji",
58 |             description=description,
59 |             version="0.1.0",
60 |             summary="AI Emoji Argue Agent 🚀 基于LangChain的开源表情包斗图Agent",
61 |             contact={
62 |                 "url": "https://github.com/ptonlix",
63 |             },
64 |             license_info={
65 |                 "name": "Apache 2.0",
66 |                 "url": "https://www.apache.org/licenses/LICENSE-2.0.html",
67 |             },
68 |             routes=app.routes,
69 |             tags=tags_metadata,
70 |         )
71 |         openapi_schema["info"]["x-logo"] = {
72 |             "url": "https://lh3.googleusercontent.com/drive-viewer"
73 |             "/AK7aPaD_iNlMoTquOBsw4boh4tIYxyEuhz6EtEs8nzq3yNkNAK00xGj"
74 |             "E1KUCmPJSk3TYOjcs6tReG6w_cLu1S7L_gPgT9z52iw=s2560"
75 |         }
76 | 
77 |         app.openapi_schema = openapi_schema
78 |         return app.openapi_schema
79 | 
80 |     app.openapi = custom_openapi  # type: ignore[method-assign]
81 | 
82 |     app.include_router(emoji_router)
83 |     app.include_router(trace_router)
84 |     app.include_router(vector_store_router)
85 |     app.include_router(health_router)
86 |     app.include_router(config_router_no_auth)
87 |     app.include_router(config_router)
88 | 
89 |     settings = root_injector.get(Settings)
90 |     if settings.server.cors.enabled:
91 |         logger.debug("Setting up CORS middleware")
92 |         app.add_middleware(
93 |             CORSMiddleware,
94 |             allow_credentials=settings.server.cors.allow_credentials,
95 |             allow_origins=settings.server.cors.allow_origins,
96 |             allow_origin_regex=settings.server.cors.allow_origin_regex,
97 |             allow_methods=settings.server.cors.allow_methods,
98 |             allow_headers=settings.server.cors.allow_headers,
99 |         )
100 | 
101 |     return app
102 | 
-------------------------------------------------------------------------------- /langchain_emoji/main.py: --------------------------------------------------------------------------------
1 | """FastAPI app creation, logger configuration and main API routes."""
2 | 
3 | from langchain_emoji.di import global_injector
4 | from langchain_emoji.launcher import create_app
5 | 
6 | app = create_app(global_injector)
7 | 
-------------------------------------------------------------------------------- /langchain_emoji/paths.py: --------------------------------------------------------------------------------
1 | from pathlib import Path
2 | 
3 | from langchain_emoji.constants import PROJECT_ROOT_PATH
4 | from langchain_emoji.settings.settings import settings
5 | 
6 | 
7 | def _absolute_or_from_project_root(path: str) -> Path:
8 |     if path.startswith("/"):
9 |         return Path(path)
10 |     return PROJECT_ROOT_PATH / path
11 | 
12 | 
13 | docs_path: Path = PROJECT_ROOT_PATH / "docs"
14 | 
15 | local_data_path: Path = _absolute_or_from_project_root(
16 |     settings().data.local_data_folder
17 | )
18 | 
-------------------------------------------------------------------------------- /langchain_emoji/server/config/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ptonlix/LangChain-Emoji/4c0a926464a7419e48df32b3c5e26794e2f9b312/langchain_emoji/server/config/__init__.py -------------------------------------------------------------------------------- /langchain_emoji/server/config/config_router.py: --------------------------------------------------------------------------------
1 | import logging
2 | from fastapi import APIRouter, Depends, Request
3 | from typing import Dict, List
4 | from langchain_emoji.server.utils.auth import authenticated
5 | from langchain_emoji.server.utils.model import (
6 |     RestfulModel,
7 |     SystemErrorCode,
8 | )
9 | from langchain_emoji.settings.settings_loader import (
10 |     get_active_settings,
11 |     save_active_settings,
12 | )
13 | 
14 | 
15 | logger = logging.getLogger(__name__)
16 | 
17 | config_router_no_auth = APIRouter(prefix="/v1")
18 | 
19 | config_router = APIRouter(prefix="/v1", dependencies=[Depends(authenticated)])
20 | 
21 | 
22 | @config_router_no_auth.get(
23 |     "/config",
24 |     response_model=RestfulModel[List[Dict]],
25 |     tags=["Config"],
26 | )
27 | async def get_config(request: Request) -> RestfulModel:
28 | 
29 |     try:
30 |         return RestfulModel(data=get_active_settings())
31 |     except Exception as e:
32 |         return RestfulModel(code=SystemErrorCode, msg=str(e), data=None)
33 | 
34 | 
35 | @config_router.post(
36 |     "/config",
37 |     response_model=RestfulModel[None],
38 |     tags=["Config"],
39 | )
40 | async def edit_config(request: Request, body: List[Dict]) -> RestfulModel:
41 |     try:
42 |         for config_dict in body:
43 |             for profile, config in config_dict.items():
44 |                 logger.debug("Updating settings profile %s", profile)
45 |                 save_active_settings(profile, config)
46 | 
47 |         return RestfulModel(data=None)
48 |     except Exception as e:
49 |         logger.exception(e)
50 |         return RestfulModel(code=SystemErrorCode, msg=str(e), data=None)
51 | 
-------------------------------------------------------------------------------- /langchain_emoji/server/emoji/__init__.py:
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/ptonlix/LangChain-Emoji/4c0a926464a7419e48df32b3c5e26794e2f9b312/langchain_emoji/server/emoji/__init__.py -------------------------------------------------------------------------------- /langchain_emoji/server/emoji/emoji_prompt.py: -------------------------------------------------------------------------------- 1 | RESPONSE_TEMPLATE = """\ 2 | # Role: 一个表情包专家,擅长根据用户描述为用户选取最合适表情包 3 | 4 | ## Language: 中文 5 | 6 | ## Workflow 7 | 1. 学习 ##EmojiList 中给出表情包列表中每一个表情包的含义。其filename属性记录了表情包文件名`filename`,内容则是表情包的含义表示。 8 | 2. 根据 ##UserInput,选取一个最符合的表情包并返回, 一定不要自己构造数据,按照指定的JSON格式结构输出结果, 包含以下2个关键输出字段: `filename`、`content`。 9 | 10 | ## EmojiList 11 |
12 | {context} 13 |
14 | 15 | ## UserInput 16 |
17 | {prompt} 18 |
19 | 20 | ## Output format 21 |
22 | The output should be formatted as a JSON instance that conforms to the JSON schema below. 23 | filename: str 24 | content: str 25 | As an example, for the schema 26 | {{ 27 | "filename":"", 28 | "content":"", 29 | }} 30 | 31 | 输出示例: 32 | ```json 33 | {{ 34 | "filename": "5a122755-9316-4d05-81f4-26da5396c04e.jpg", 35 | "content": "这个表情包中的内容和笑点在于它展示了许多带有悲伤或不满情绪的表情符号,这些表情符号的脸部表情看起来都非常忧郁或不高兴。图片下方的文字“我的世界一片灰色”可能意味着这个表情包的使用者感到沮丧或情绪低落,就像世界失去了颜色一样。这种夸张的表达方式和文字与表情符号的结合,使得这个表情包在传达负面情绪的同时,也带有一定的幽默感。" 36 | }} 37 | ``` 38 |
39 | 40 | ## Start 41 | 作为一个 #Role, 你默认使用的是##Language,你不需要介绍自己,请根据##Workflow开始工作,你必须严格遵守输出格式##Output format,输出格式指定的JSON格式要求。 42 | """ 43 | 44 | # 以下为Prompt 备份 45 | 46 | ZHIPUAI_RESPONSE_TEMPLATE = """ 47 | 表情包列表: 48 | {context} 49 | 50 | 用户描述: 51 | {prompt} 52 | 53 | 请根据以下要求,根据用户描述为用户选取最合适表情包: 54 | 55 | 1. 学习`表情包列表`中每一个表情包的含义, 其中metadata属性记录了表情包文件名,内容则是表情包的含义表示 56 | 57 | 2. 根据`用户描述`,选取一个最符合的表情包,一定不要自己构造数据,请按照指定的JSON格式结构输出结果, 包含以下2个关键输出字段: `filename`、`content`,具体格式如下: 58 | ```json 59 | {{ 60 | "filename": string, 61 | "content": string 62 | }} 63 | ``` 64 | 输出示例: 65 | ```json 66 | {{ 67 | "filename": "5a122755-9316-4d05-81f4-26da5396c04e.jpg", 68 | "content": "这个表情包中的内容和笑点在于它展示了许多带有悲伤或不满情绪的表情符号,这些表情符号的脸部表情看起来都非常忧郁或不高兴。图片下方的文字“我的世界一片灰色”可能意味着这个表情包的使用者感到沮丧或情绪低落,就像世界失去了颜色一样。这种夸张的表达方式和文字与表情符号的结合,使得这个表情包在传达负面情绪的同时,也带有一定的幽默感。" 69 | }} 70 | ``` 71 | 72 | 请严格按照上述要求进行信息提取、格式输出,并遵守输出格式指定的JSON格式要求。 73 | """ 74 | -------------------------------------------------------------------------------- /langchain_emoji/server/emoji/emoji_router.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from fastapi import APIRouter, Depends, Request 3 | from langchain_emoji.server.utils.auth import authenticated 4 | from langchain_emoji.server.emoji.emoji_service import ( 5 | EmojiService, 6 | EmojiRequest, 7 | EmojiResponse, 8 | ) 9 | from langchain_emoji.server.utils.model import ( 10 | RestfulModel, 11 | SystemErrorCode, 12 | ) 13 | 14 | logger = logging.getLogger(__name__) 15 | 16 | emoji_router = APIRouter(prefix="/v1", dependencies=[Depends(authenticated)]) 17 | 18 | 19 | @emoji_router.post( 20 | "/emoji", 21 | response_model=RestfulModel[EmojiResponse | int | None], 22 | tags=["Emoji"], 23 | ) 24 | async def emoji_invoke(request: Request, body: EmojiRequest) -> RestfulModel: 25 | """ 26 | Call directly to return search results 27 | """ 28 | service = request.state.injector.get(EmojiService) 29 | try: 30 | return RestfulModel(data=await service.get_emoji(body)) 31 | except Exception as e: 32 | logger.exception(e) 33 | return RestfulModel(code=SystemErrorCode, msg=str(e), data=None) 34 | -------------------------------------------------------------------------------- /langchain_emoji/server/emoji/emoji_service.py: -------------------------------------------------------------------------------- 1 | from injector import inject 2 | from langchain_emoji.components.llm.llm_component import LLMComponent 3 | from langchain_emoji.components.trace.trace_component import TraceComponent 4 | from langchain_emoji.components.minio.minio_component import MinioComponent 5 | from langchain_emoji.components.vector_store.vector_store_component import ( 6 | VectorStoreComponent, 7 | ) 8 | 9 | from langchain_emoji.server.emoji.emoji_prompt import ( 10 | RESPONSE_TEMPLATE, 11 | ZHIPUAI_RESPONSE_TEMPLATE, 12 | ) 13 | from langchain.schema.output_parser import StrOutputParser 14 | from pydantic import BaseModel, Field 15 | import logging 16 | import tiktoken 17 | from langchain_emoji.settings.settings import Settings 18 | from langchain.schema.document import Document 19 | from langchain.schema.runnable import ( 20 | Runnable, 21 | RunnableLambda, 22 | RunnableBranch, 23 | RunnableMap, 24 | ) 25 | from langchain.schema.language_model import BaseLanguageModel 26 | from langchain.schema.messages import BaseMessage 27 | from langchain.schema.retriever import BaseRetriever 28 | from langchain.prompts import ChatPromptTemplate 29 | from langchain.callbacks.base import AsyncCallbackHandler 
30 | from langchain.schema.runnable import ConfigurableField 31 | from operator import itemgetter 32 | 33 | from langchain_community.callbacks import get_openai_callback 34 | from langchain_emoji.components.llm.custom.zhipuai import get_zhipuai_callback 35 | from langchain_emoji.paths import local_data_path 36 | from uuid import UUID 37 | from json.decoder import JSONDecodeError 38 | import json 39 | import base64 40 | import re 41 | from typing import ( 42 | List, 43 | Optional, 44 | Sequence, 45 | Dict, 46 | Any, 47 | ) 48 | 49 | 50 | logger = logging.getLogger(__name__) 51 | 52 | 53 | class TokenInfo(BaseModel): 54 | model: str 55 | total_tokens: int = 0 56 | prompt_tokens: int = 0 57 | completion_tokens: int = 0 58 | embedding_tokens: int = 0 59 | successful_requests: int = 0 60 | total_cost: float = 0.0 61 | 62 | def clear(self): 63 | self.total_tokens = 0 64 | self.prompt_tokens: int = 0 65 | self.completion_tokens: int = 0 66 | self.successful_requests: int = 0 67 | self.total_cost: float = 0.0 68 | 69 | 70 | class EmojiRequest(BaseModel): 71 | prompt: str 72 | req_id: str 73 | llm: str = Field(default="openai", description="大模型") 74 | 75 | model_config = { 76 | "json_schema_extra": { 77 | "examples": [{"prompt": "xxxx", "req_id": "xxxx", "llm": "xxx"}] 78 | } 79 | } 80 | 81 | 82 | class EmojiInfo(BaseModel): 83 | filename: str 84 | content: str 85 | 86 | 87 | class EmojiDetail(BaseModel): 88 | download_link: Optional[str] = None 89 | base64: str 90 | 91 | 92 | class EmojiResponse(BaseModel): 93 | run_id: UUID 94 | emojiinfo: EmojiInfo 95 | emojidetail: EmojiDetail 96 | token_info: TokenInfo 97 | 98 | 99 | """ 100 | 读取chain run_id的回调 101 | """ 102 | 103 | 104 | class ReadRunIdAsyncHandler(AsyncCallbackHandler): 105 | 106 | def __init__(self): 107 | self.runid: UUID = None 108 | 109 | async def on_chain_start( 110 | self, 111 | serialized: Dict[str, Any], 112 | inputs: Dict[str, Any], 113 | *, 114 | run_id: UUID, 115 | parent_run_id: Optional[UUID] = None, 116 | tags: Optional[List[str]] = None, 117 | metadata: Optional[Dict[str, Any]] = None, 118 | **kwargs: Any, 119 | ) -> None: 120 | """Run when chain starts running.""" 121 | if not self.runid: 122 | self.runid = run_id 123 | 124 | async def on_chat_model_start( 125 | self, 126 | serialized: Dict[str, Any], 127 | messages: List[List[BaseMessage]], 128 | *, 129 | run_id: UUID, 130 | parent_run_id: Optional[UUID] = None, 131 | tags: Optional[List[str]] = None, 132 | metadata: Optional[Dict[str, Any]] = None, 133 | **kwargs: Any, 134 | ) -> Any: ... 135 | 136 | async def on_retriever_end( 137 | self, 138 | documents: Sequence[Document], 139 | *, 140 | run_id: UUID, 141 | parent_run_id: Optional[UUID] = None, 142 | tags: Optional[List[str]] = None, 143 | **kwargs: Any, 144 | ) -> None: 145 | """Run on retriever end.""" 146 | # TODO 目前没有合适过滤选项,获取原文链接还是采用全局变量获取 147 | ... 
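# Usage sketch (illustrative, not part of the service): the handler is passed in
# through the `callbacks` entry of a Runnable config, records the first chain
# run_id it observes, and exposes it via get_runid() once the call returns.
# `emoji_chain` stands in for any Runnable, e.g. the chain built by EmojiService:
#
#   read_runid = ReadRunIdAsyncHandler()
#   await emoji_chain.ainvoke(
#       input={"prompt": "...", "llm": "openai"},
#       config={"callbacks": [read_runid]},
#   )
#   run_id = read_runid.get_runid()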
148 | 149 | def get_runid(self) -> UUID: 150 | return self.runid 151 | 152 | 153 | def fix_json(json_str: str) -> dict: 154 | # 使用正则表达式替换掉重复的逗号 155 | fixed_json_str = re.sub(r",\s*}", "}", json_str) 156 | fixed_json_str = re.sub(r",\s*]", "]", fixed_json_str) 157 | 158 | # 尝试加载修复后的JSON字符串 159 | return json.loads(fixed_json_str) 160 | 161 | 162 | class EmojiService: 163 | 164 | @inject 165 | def __init__( 166 | self, 167 | llm_component: LLMComponent, 168 | vector_component: VectorStoreComponent, 169 | trace_component: TraceComponent, 170 | minio_component: MinioComponent, 171 | settings: Settings, 172 | ) -> None: 173 | self.settings = settings 174 | self.llm_service = llm_component 175 | self.vector_service = vector_component 176 | self.trace_service = trace_component 177 | self.minio_service = minio_component 178 | self.chain = self.create_chain( 179 | self.llm_service.llm, 180 | self.get_vector_retriever(), 181 | ) 182 | 183 | def num_tokens_from_string(self, string: str) -> int: 184 | """Returns the number of tokens in a text string.""" 185 | encoding = tiktoken.encoding_for_model(self.llm_service.modelname) 186 | num_tokens = len(encoding.encode(string)) 187 | return num_tokens 188 | 189 | """ 190 | 防止返回不是json格式,增加校验处理 191 | """ 192 | 193 | def output_handle(self, output: str) -> dict: 194 | try: 195 | # 找到json子串并加载为字典 196 | start_index = output.find("{") 197 | end_index = output.rfind("}") + 1 198 | json_str = output[start_index:end_index] 199 | 200 | # 加载为字典 201 | return json.loads(json_str) 202 | except JSONDecodeError as e: 203 | logger.exception(e) 204 | return fix_json(json_str) 205 | 206 | async def get_emoji(self, body: EmojiRequest) -> EmojiResponse | None: 207 | logger.info(body) 208 | token_callback = ( 209 | get_openai_callback if body.llm == "openai" else get_zhipuai_callback 210 | ) 211 | with token_callback() as cb: 212 | read_runid = ReadRunIdAsyncHandler() # 读取runid回调 213 | result = await self.chain.ainvoke( 214 | input={"prompt": body.prompt, "llm": body.llm}, 215 | config={ 216 | "metadata": { 217 | "req_id": body.req_id, 218 | }, 219 | "configurable": {"llm": body.llm}, 220 | "callbacks": [cb, read_runid], 221 | }, 222 | ) 223 | 224 | logger.info(result) 225 | emojiinfo = EmojiInfo(**result) 226 | 227 | embed_tokens = self.vector_service.embedcom.total_tokens 228 | tokeninfo = TokenInfo( 229 | model=body.llm, 230 | total_tokens=cb.total_tokens + int(embed_tokens / 10), 231 | prompt_tokens=cb.prompt_tokens, 232 | completion_tokens=cb.completion_tokens, 233 | embedding_tokens=int(embed_tokens / 10), 234 | successful_requests=cb.successful_requests, 235 | total_cost=cb.total_cost, 236 | ) 237 | 238 | resobj = EmojiResponse( 239 | run_id=read_runid.get_runid(), 240 | emojiinfo=emojiinfo, 241 | emojidetail=self.get_file_desc(emojiinfo), 242 | token_info=tokeninfo, 243 | ) 244 | return resobj 245 | 246 | def get_file_desc(self, info: EmojiInfo) -> EmojiDetail: 247 | logger.info(self.settings.dataset.mode) 248 | if self.settings.dataset.mode == "local": 249 | emoji_file = ( 250 | local_data_path / self.settings.dataset.name / "emo" / info.filename 251 | ) 252 | with open(emoji_file, "rb") as image_file: 253 | encoded_string = base64.b64encode(image_file.read()) 254 | file_base64 = encoded_string.decode("utf-8") 255 | return EmojiDetail(base64=file_base64) 256 | elif self.settings.dataset.mode == "minio": 257 | file_base64 = self.minio_service.get_file_base64(info.filename) 258 | file_download_link = self.minio_service.get_download_link(info.filename) 259 | return 
EmojiDetail(base64=file_base64, download_link=file_download_link) 260 | 261 | def get_vector_retriever(self): 262 | base_vector = self.vector_service.vector_store.as_retriever().configurable_alternatives( 263 | # This gives this field an id 264 | # When configuring the end runnable, we can then use this id to configure this field 265 | ConfigurableField(id="vectordb"), 266 | default_key=self.settings.vectorstore.database, 267 | ) 268 | 269 | return base_vector.with_config(run_name="VectorRetriever") 270 | 271 | def create_retriever_chain(self, retriever: BaseRetriever) -> Runnable: 272 | return ( 273 | RunnableLambda(itemgetter("prompt")).with_config( 274 | run_name="Itemgetter:prompt" 275 | ) 276 | | retriever 277 | ).with_config(run_name="RetrievalChain") 278 | 279 | def format_docs(self, docs: Sequence[Document]) -> str: 280 | logger.info(docs) 281 | formatted_docs = [] 282 | for doc in docs: 283 | filename = doc.metadata.get("filename") 284 | doc_string = ( 285 | f"filename: {filename}\n{doc.page_content}" 286 | ) 287 | formatted_docs.append(doc_string) 288 | return "\n".join(formatted_docs) 289 | 290 | def create_chain( 291 | self, 292 | llm: BaseLanguageModel, 293 | retriever: BaseRetriever, 294 | ) -> Runnable: 295 | retriever_chain = self.create_retriever_chain(retriever) | RunnableLambda( 296 | self.format_docs 297 | ) 298 | _context = RunnableMap( 299 | { 300 | "context": retriever_chain.with_config(run_name="RetrievalChain"), 301 | "prompt": RunnableLambda(itemgetter("prompt")).with_config( 302 | run_name="Itemgetter:prompt" 303 | ), 304 | "llm": RunnableLambda(itemgetter("llm")).with_config( 305 | run_name="Itemgetter:llm" 306 | ), 307 | } 308 | ) 309 | _prompt = RunnableBranch( 310 | ( 311 | RunnableLambda( 312 | lambda x: bool( 313 | x.get("llm") == "openai" or x.get("llm") == "deepseek" 314 | ) 315 | ).with_config(run_name="CheckLLM"), 316 | ChatPromptTemplate.from_messages( 317 | [ 318 | ("human", RESPONSE_TEMPLATE), 319 | ] 320 | ).with_config(run_name="OpenaiPrompt"), 321 | ), 322 | ( 323 | ChatPromptTemplate.from_messages( 324 | [ 325 | ("human", ZHIPUAI_RESPONSE_TEMPLATE), 326 | ] 327 | ).with_config(run_name="ZhipuaiPrompt") 328 | ), 329 | ).with_config(run_name="ChoiceLLMPrompt") 330 | 331 | response_synthesizer = (_prompt | llm | StrOutputParser()).with_config( 332 | run_name="GenerateResponse", 333 | ) | RunnableLambda(self.output_handle).with_config(run_name="ResponseHandle") 334 | 335 | return ( 336 | { 337 | "prompt": RunnableLambda(itemgetter("prompt")).with_config( 338 | run_name="Itemgetter:prompt" 339 | ), 340 | "llm": RunnableLambda(itemgetter("llm")).with_config( 341 | run_name="Itemgetter:llm" 342 | ), 343 | } 344 | | _context 345 | | response_synthesizer 346 | ).with_config(run_name="EmojiChain") 347 | -------------------------------------------------------------------------------- /langchain_emoji/server/health/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ptonlix/LangChain-Emoji/4c0a926464a7419e48df32b3c5e26794e2f9b312/langchain_emoji/server/health/__init__.py -------------------------------------------------------------------------------- /langchain_emoji/server/health/health_router.py: -------------------------------------------------------------------------------- 1 | from typing import Literal 2 | 3 | from fastapi import APIRouter 4 | from pydantic import BaseModel, Field 5 | from langchain_emoji.server.utils.model import RestfulModel 6 | import logging 7 | 8 | logger =
logging.getLogger(__name__) 9 | 10 | # Not authentication or authorization required to get the health status. 11 | health_router = APIRouter(prefix="/v1") 12 | 13 | 14 | class HealthResponse(BaseModel): 15 | status: Literal["ok"] = Field(default="ok") 16 | 17 | 18 | @health_router.get( 19 | "/health", 20 | tags=["Health"], 21 | response_model=RestfulModel[HealthResponse | None], 22 | ) 23 | def health() -> RestfulModel: 24 | """Return ok if the system is up.""" 25 | return RestfulModel(data=HealthResponse(status="ok")) 26 | -------------------------------------------------------------------------------- /langchain_emoji/server/trace/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ptonlix/LangChain-Emoji/4c0a926464a7419e48df32b3c5e26794e2f9b312/langchain_emoji/server/trace/__init__.py -------------------------------------------------------------------------------- /langchain_emoji/server/trace/trace_router.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from fastapi import APIRouter, Depends, Request 3 | from pydantic import BaseModel 4 | from typing import Union, Optional 5 | from langchain_emoji.server.utils.auth import authenticated 6 | from langchain_emoji.server.emoji.emoji_service import ( 7 | EmojiService, 8 | ) 9 | from langchain_emoji.server.utils.model import ( 10 | RestfulModel, 11 | SystemErrorCode, 12 | ) 13 | from uuid import UUID 14 | 15 | logger = logging.getLogger(__name__) 16 | 17 | trace_router = APIRouter(prefix="/v1", dependencies=[Depends(authenticated)]) 18 | 19 | 20 | class SendFeedbackBody(BaseModel): 21 | run_id: UUID 22 | key: str = "user_score" 23 | 24 | score: Union[float, int, bool, None] = None 25 | feedback_id: Optional[UUID] = None 26 | comment: Optional[str] = None 27 | 28 | 29 | @trace_router.post( 30 | "/feedback", 31 | tags=["Trace"], 32 | ) 33 | async def send_feedback(request: Request, body: SendFeedbackBody): 34 | service = request.state.injector.get(EmojiService) 35 | service.trace_service.trace_client.create_feedback( 36 | body.run_id, 37 | body.key, 38 | score=body.score, 39 | comment=body.comment, 40 | feedback_id=body.feedback_id, 41 | ) 42 | return RestfulModel(data="posted feedback successfully") 43 | 44 | 45 | class GetTraceBody(BaseModel): 46 | run_id: UUID 47 | 48 | 49 | @trace_router.post( 50 | "/get_trace", 51 | tags=["Trace"], 52 | ) 53 | async def get_trace(request: Request, body: GetTraceBody): 54 | service = request.state.injector.get(EmojiService) 55 | 56 | run_id = body.run_id 57 | if run_id is None: 58 | return RestfulModel( 59 | code=SystemErrorCode, msg="No LangSmith run ID provided", data=None 60 | ) 61 | return RestfulModel(data=await service.trace_service.aget_trace_url(str(run_id))) 62 | -------------------------------------------------------------------------------- /langchain_emoji/server/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ptonlix/LangChain-Emoji/4c0a926464a7419e48df32b3c5e26794e2f9b312/langchain_emoji/server/utils/__init__.py -------------------------------------------------------------------------------- /langchain_emoji/server/utils/auth.py: -------------------------------------------------------------------------------- 1 | """Authentication mechanism for the API. 2 | 3 | Define a simple mechanism to authenticate requests. 
4 | More complex authentication mechanisms can be defined here, and be placed in the 5 | `authenticated` method (being a 'bean' injected in fastapi routers). 6 | 7 | Authorization can also be made after the authentication, and depends on 8 | the authentication. Authorization should not be implemented in this file. 9 | 10 | Authorization can be done by following fastapi's guides: 11 | * https://fastapi.tiangolo.com/advanced/security/oauth2-scopes/ 12 | * https://fastapi.tiangolo.com/tutorial/security/ 13 | * https://fastapi.tiangolo.com/tutorial/dependencies/dependencies-in-path-operation-decorators/ 14 | """ 15 | 16 | # mypy: ignore-errors 17 | # Disabled mypy error: All conditional function variants must have identical signatures 18 | # We are changing the implementation of the authenticated method, based on 19 | # the config. If the auth is not enabled, we are not defining the complex method 20 | # with its dependencies. 21 | import logging 22 | import secrets 23 | from typing import Annotated 24 | 25 | from fastapi import Depends, Header, HTTPException 26 | 27 | from langchain_emoji.settings.settings import settings 28 | 29 | # 401 signify that the request requires authentication. 30 | # 403 signify that the authenticated user is not authorized to perform the operation. 31 | NOT_AUTHENTICATED = HTTPException( 32 | status_code=401, 33 | detail="Not authenticated", 34 | headers={"WWW-Authenticate": 'Basic realm="All the API", charset="UTF-8"'}, 35 | ) 36 | 37 | logger = logging.getLogger(__name__) 38 | 39 | 40 | def _simple_authentication(authorization: Annotated[str, Header()] = "") -> bool: 41 | """Check if the request is authenticated.""" 42 | if not secrets.compare_digest(authorization, settings().server.auth.secret): 43 | # If the "Authorization" header is not the expected one, raise an exception. 44 | raise NOT_AUTHENTICATED 45 | return True 46 | 47 | 48 | if not settings().server.auth.enabled: 49 | logger.debug( 50 | "Defining a dummy authentication mechanism for fastapi, always authenticating requests" 51 | ) 52 | 53 | # Define a dummy authentication method that always returns True. 54 | def authenticated() -> bool: 55 | """Check if the request is authenticated.""" 56 | return True 57 | 58 | else: 59 | logger.info("Defining the given authentication mechanism for the API") 60 | 61 | # Method to be used as a dependency to check if the request is authenticated. 
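# For example (illustrative; the header value and port below are the defaults
# from settings.yaml), a client of a protected endpoint must send the
# configured secret verbatim in the Authorization header:
#
#   curl -X POST http://localhost:8003/v1/config \
#        -H 'Authorization: Basic c2VjcmV0OmtleQ==' \
#        -H 'Content-Type: application/json' -d '[]'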
62 | def authenticated( 63 | _simple_authentication: Annotated[bool, Depends(_simple_authentication)] 64 | ) -> bool: 65 | """Check if the request is authenticated.""" 66 | assert settings().server.auth.enabled 67 | if not _simple_authentication: 68 | raise NOT_AUTHENTICATED 69 | return True 70 | -------------------------------------------------------------------------------- /langchain_emoji/server/utils/model.py: -------------------------------------------------------------------------------- 1 | from pydantic import BaseModel 2 | from typing import Generic, TypeVar 3 | 4 | 5 | T = TypeVar("T") # 泛型类型 T 6 | 7 | 8 | class RestfulModel(BaseModel, Generic[T]): 9 | code: int = 0 10 | msg: str = "success" 11 | data: T 12 | 13 | 14 | """ 15 | Server Error Code List 16 | 17 | """ 18 | 19 | """ 20 | system: 10000-10099 21 | """ 22 | SystemErrorCode = 10001 23 | -------------------------------------------------------------------------------- /langchain_emoji/server/vector_store/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ptonlix/LangChain-Emoji/4c0a926464a7419e48df32b3c5e26794e2f9b312/langchain_emoji/server/vector_store/__init__.py -------------------------------------------------------------------------------- /langchain_emoji/server/vector_store/vector_store_router.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from fastapi import APIRouter, Depends, Request 3 | from typing import Any, List 4 | from pydantic import BaseModel, Field 5 | from langchain_emoji.server.utils.auth import authenticated 6 | from langchain_emoji.server.vector_store.vector_store_server import ( 7 | VectorStoreService, 8 | EmojiFragment, 9 | ) 10 | from langchain_emoji.server.utils.model import ( 11 | RestfulModel, 12 | SystemErrorCode, 13 | ) 14 | 15 | logger = logging.getLogger(__name__) 16 | 17 | vector_store_router = APIRouter(prefix="/v1", dependencies=[Depends(authenticated)]) 18 | 19 | 20 | class AddEmojiBody(BaseModel): 21 | content: str = Field(description="表情描述") 22 | filename: str = Field(description="表情文件名") 23 | model_config = { 24 | "json_schema_extra": { 25 | "examples": [ 26 | { 27 | "content": "xxxx", 28 | "filename": "xxxx", 29 | } 30 | ] 31 | } 32 | } 33 | 34 | 35 | class RagEmojiBody(BaseModel): 36 | prompt: str = Field(description="表情描述") 37 | filenames: List[str] = Field(default=[], description="表情文件名称列表") 38 | 39 | model_config = { 40 | "json_schema_extra": { 41 | "examples": [{"prompt": "xxxx", "filenames": ["xxx", "xxx"]}] 42 | } 43 | } 44 | 45 | 46 | class DelEmojiBody(BaseModel): 47 | vdb_ids: List[str] = Field(description="文章向量数据库ID") 48 | filenames: List[str] = Field(default=[], description="表情文件名称列表") 49 | model_config = { 50 | "json_schema_extra": { 51 | "examples": [{"vdb_ids": ["xxxx"], "filenames": ["xxx", "xxx"]}] 52 | } 53 | } 54 | 55 | 56 | @vector_store_router.post( 57 | "/vector_store/add_emoji", 58 | response_model=RestfulModel[Any | None], 59 | tags=["VectorStore"], 60 | ) 61 | def add_emoji(request: Request, body: AddEmojiBody) -> RestfulModel: 62 | """ 63 | New article into vector database 64 | """ 65 | service = request.state.injector.get(VectorStoreService) 66 | try: 67 | return RestfulModel( 68 | data=service.add_emoji( 69 | content=body.content, 70 | filename=body.filename, 71 | ) 72 | ) 73 | except Exception as e: 74 | logger.exception(e) 75 | return RestfulModel(code=SystemErrorCode, msg=str(e), data=None) 76 | 77 | 78 | 
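# Example request (illustrative): POST /v1/vector_store/rag_emoji with a body
# such as {"prompt": "今天心情很好", "filenames": []} returns up to three
# matching EmojiFragment entries (k=3 in VectorStoreService.rag_emoji).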
@vector_store_router.post( 79 | "/vector_store/rag_emoji", 80 | response_model=RestfulModel[List[EmojiFragment] | None], 81 | tags=["VectorStore"], 82 | ) 83 | def rag_emoji(request: Request, body: RagEmojiBody) -> RestfulModel: 84 | """ 85 | Recall articles from vector database 86 | """ 87 | service = request.state.injector.get(VectorStoreService) 88 | try: 89 | return RestfulModel( 90 | data=service.rag_emoji(prompt=body.prompt, filenames=body.filenames) 91 | ) 92 | except Exception as e: 93 | logger.exception(e) 94 | return RestfulModel(code=SystemErrorCode, msg=str(e), data=None) 95 | 96 | 97 | @vector_store_router.post( 98 | "/vector_store/del_emoji", 99 | response_model=RestfulModel[List[dict] | None], 100 | tags=["VectorStore"], 101 | ) 102 | def del_emoji(request: Request, body: DelEmojiBody) -> RestfulModel: 103 | """ 104 | Delete articles from vector database 105 | """ 106 | service = request.state.injector.get(VectorStoreService) 107 | try: 108 | return RestfulModel( 109 | data=service.del_emoji(vdb_ids=body.vdb_ids, filenames=body.filenames) 110 | ) 111 | except Exception as e: 112 | logger.exception(e) 113 | return RestfulModel(code=SystemErrorCode, msg=str(e), data=None) 114 | -------------------------------------------------------------------------------- /langchain_emoji/server/vector_store/vector_store_server.py: -------------------------------------------------------------------------------- 1 | from injector import inject, singleton 2 | from langchain_emoji.components.vector_store import VectorStoreComponent 3 | from pydantic import BaseModel, Field 4 | import logging 5 | from typing import List, Optional 6 | 7 | logger = logging.getLogger(__name__) 8 | 9 | 10 | class EmojiFragment(BaseModel): 11 | content: str = Field(description="表情描述段落") 12 | filename: Optional[str] = Field(default=-1, description="表情文件名") 13 | 14 | 15 | @singleton 16 | class VectorStoreService: 17 | @inject 18 | def __init__( 19 | self, 20 | vector_store: VectorStoreComponent, 21 | ) -> None: 22 | self.client = vector_store.vector_store 23 | 24 | def add_emoji( 25 | self, 26 | content: str, 27 | filename: str, 28 | ) -> List[str]: 29 | metadata = { 30 | "filename": filename, 31 | } 32 | return self.client.add_original_texts_with_filename( 33 | filename=filename, texts=[content], metadatas=[metadata] 34 | ) 35 | 36 | def rag_emoji(self, prompt: str, filenames: List[str] = []) -> List[EmojiFragment]: 37 | fragment_list = self.client.similarity_search_by_filenames( 38 | query=prompt, filenames=filenames, k=3 39 | ) 40 | res = [] 41 | for fragment in fragment_list: 42 | res.append( 43 | EmojiFragment(content=fragment.page_content, **fragment.metadata) 44 | ) 45 | return res 46 | 47 | def del_emoji(self, vdb_ids: List[str], filenames: List[str] = []) -> List[dict]: 48 | return self.client.delete_texts_with_filenames( 49 | document_ids=vdb_ids, filenames=filenames 50 | ) 51 | -------------------------------------------------------------------------------- /langchain_emoji/settings/__init__.py: -------------------------------------------------------------------------------- 1 | """Settings.""" 2 | -------------------------------------------------------------------------------- /langchain_emoji/settings/settings.py: -------------------------------------------------------------------------------- 1 | from typing import Literal, Optional 2 | 3 | from pydantic import BaseModel, Field 4 | 5 | from langchain_emoji.settings.settings_loader import load_active_settings 6 | 7 | 8 | class CorsSettings(BaseModel): 9 | 
"""CORS configuration. 10 | 11 | For more details on the CORS configuration, see: 12 | # * https://fastapi.tiangolo.com/tutorial/cors/ 13 | # * https://developer.mozilla.org/en-US/docs/Web/HTTP/CORS 14 | """ 15 | 16 | enabled: bool = Field( 17 | description="Flag indicating if CORS headers are set or not." 18 | "If set to True, the CORS headers will be set to allow all origins, methods and headers.", 19 | default=False, 20 | ) 21 | allow_credentials: bool = Field( 22 | description="Indicate that cookies should be supported for cross-origin requests", 23 | default=False, 24 | ) 25 | allow_origins: list[str] = Field( 26 | description="A list of origins that should be permitted to make cross-origin requests.", 27 | default=[], 28 | ) 29 | allow_origin_regex: list[str] = Field( 30 | description="A regex string to match against origins that should be permitted to make cross-origin requests.", 31 | default=None, 32 | ) 33 | allow_methods: list[str] = Field( 34 | description="A list of HTTP methods that should be allowed for cross-origin requests.", 35 | default=[ 36 | "GET", 37 | ], 38 | ) 39 | allow_headers: list[str] = Field( 40 | description="A list of HTTP request headers that should be supported for cross-origin requests.", 41 | default=[], 42 | ) 43 | 44 | 45 | class AuthSettings(BaseModel): 46 | """Authentication configuration. 47 | 48 | The implementation of the authentication strategy must 49 | """ 50 | 51 | enabled: bool = Field( 52 | description="Flag indicating if authentication is enabled or not.", 53 | default=False, 54 | ) 55 | secret: str = Field( 56 | description="The secret to be used for authentication. " 57 | "It can be any non-blank string. For HTTP basic authentication, " 58 | "this value should be the whole 'Authorization' header that is expected" 59 | ) 60 | 61 | 62 | class ServerSettings(BaseModel): 63 | env_name: str = Field( 64 | description="Name of the environment (prod, staging, local...)" 65 | ) 66 | port: int = Field( 67 | description="Port of Langchain-DeepRead FastAPI server, defaults to 8001" 68 | ) 69 | cors: CorsSettings = Field( 70 | description="CORS configuration", default=CorsSettings(enabled=False) 71 | ) 72 | auth: AuthSettings = Field( 73 | description="Authentication configuration", 74 | default_factory=lambda: AuthSettings(enabled=False, secret="secret-key"), 75 | ) 76 | 77 | 78 | class LLMSettings(BaseModel): 79 | mode: Literal["local", "openai", "zhipuai", "deepseek", "all", "mock"] 80 | max_new_tokens: int = Field( 81 | 256, 82 | description="The maximum number of token that the LLM is authorized to generate in one completion.", 83 | ) 84 | 85 | 86 | class OpenAISettings(BaseModel): 87 | temperature: float 88 | modelname: str 89 | api_key: str 90 | api_base: str 91 | 92 | 93 | class DeepSeekSettings(BaseModel): 94 | temperature: float 95 | modelname: str 96 | api_key: str 97 | api_base: str 98 | 99 | 100 | class ZhipuAISettings(BaseModel): 101 | temperature: float 102 | top_p: float 103 | modelname: str 104 | api_key: str 105 | 106 | 107 | class LangSmithSettings(BaseModel): 108 | trace_version_v2: bool 109 | langchain_project: str 110 | api_key: str 111 | 112 | 113 | class TvectordbSettings(BaseModel): 114 | url: str 115 | username: str 116 | api_key: str 117 | collection_name: str 118 | database_name: str 119 | 120 | 121 | class EmbeddingSettings(BaseModel): 122 | mode: Literal["local", "openai", "zhipuai", "mock"] 123 | 124 | 125 | class ChromadbSettings(BaseModel): 126 | persist_dir: str 127 | collection_name: str 128 | 129 | 130 | class 
VectorstoreSettings(BaseModel): 131 | database: Literal["tcvectordb", "chromadb"] 132 | tcvectordb: TvectordbSettings 133 | chromadb: ChromadbSettings 134 | 135 | 136 | class DataSettings(BaseModel): 137 | local_data_folder: str = Field( 138 | description="Path to local storage." 139 | "It will be treated as an absolute path if it starts with /" 140 | ) 141 | 142 | 143 | class MinioSettings(BaseModel): 144 | host: str 145 | bucket_name: str 146 | access_key: str 147 | secret_key: str 148 | 149 | 150 | class DatasetSettings(BaseModel): 151 | name: str 152 | google_driver_id: str 153 | mode: Literal["minio", "local"] 154 | 155 | 156 | class Settings(BaseModel): 157 | server: ServerSettings 158 | llm: LLMSettings 159 | openai: OpenAISettings 160 | deepseek: DeepSeekSettings 161 | zhipuai: ZhipuAISettings 162 | langsmith: LangSmithSettings 163 | vectorstore: VectorstoreSettings 164 | embedding: EmbeddingSettings 165 | data: DataSettings 166 | minio: Optional[MinioSettings] = None 167 | dataset: DatasetSettings 168 | 169 | 170 | """ 171 | This is visible just for DI or testing purposes. 172 | 173 | Use dependency injection or `settings()` method instead. 174 | """ 175 | unsafe_settings = load_active_settings() 176 | 177 | """ 178 | This is visible just for DI or testing purposes. 179 | 180 | Use dependency injection or `settings()` method instead. 181 | """ 182 | unsafe_typed_settings = Settings(**unsafe_settings) 183 | 184 | 185 | def settings() -> Settings: 186 | """Get the current loaded settings from the DI container. 187 | 188 | This method exists to keep compatibility with the existing code, 189 | that require global access to the settings. 190 | 191 | For regular components use dependency injection instead. 192 | """ 193 | from langchain_emoji.di import global_injector 194 | 195 | return global_injector.get(Settings) 196 | -------------------------------------------------------------------------------- /langchain_emoji/settings/settings_loader.py: -------------------------------------------------------------------------------- 1 | import functools 2 | import logging 3 | import os 4 | import sys 5 | from collections.abc import Iterable 6 | from pathlib import Path 7 | from typing import Any, List 8 | 9 | from pydantic.v1.utils import deep_update, unique_list 10 | 11 | from langchain_emoji.constants import PROJECT_ROOT_PATH 12 | from langchain_emoji.settings.yaml import ( 13 | load_yaml_with_envvars, 14 | update_yaml_config_file, 15 | ) 16 | 17 | logger = logging.getLogger(__name__) 18 | 19 | _settings_folder = os.environ.get("LE_SETTINGS_FOLDER", PROJECT_ROOT_PATH) 20 | 21 | # if running in unittest, use the test profile 22 | _test_profile = ["test"] if "unittest" in sys.modules else [] 23 | 24 | active_profiles: list[str] = unique_list( 25 | ["default"] 26 | + [ 27 | item.strip() 28 | for item in os.environ.get("LE_PROFILES", "").split(",") 29 | if item.strip() 30 | ] 31 | + _test_profile 32 | ) 33 | 34 | 35 | def merge_settings(settings: Iterable[dict[str, Any]]) -> dict[str, Any]: 36 | return functools.reduce(deep_update, settings, {}) 37 | 38 | 39 | def load_settings_from_profile(profile: str) -> dict[str, Any]: 40 | if profile == "default": 41 | profile_file_name = "settings.yaml" 42 | else: 43 | profile_file_name = f"settings-{profile}.yaml" 44 | 45 | path = Path(_settings_folder) / profile_file_name 46 | with Path(path).open("r") as f: 47 | config = load_yaml_with_envvars(f) 48 | if not isinstance(config, dict): 49 | raise TypeError(f"Config file has no top-level mapping: {path}") 
50 | return config 51 | 52 | 53 | def load_active_settings() -> dict[str, Any]: 54 | """Load active profiles and merge them.""" 55 | logger.info("Starting application with profiles=%s", active_profiles) 56 | loaded_profiles = [ 57 | load_settings_from_profile(profile) for profile in active_profiles 58 | ] 59 | merged: dict[str, Any] = merge_settings(loaded_profiles) 60 | return merged 61 | 62 | 63 | def get_active_settings() -> List[dict[str, Any]]: 64 | """Load each active profile separately, without merging.""" 65 | loaded_profiles = [ 66 | {profile: load_settings_from_profile(profile)} for profile in active_profiles 67 | ] 68 | 69 | return loaded_profiles 70 | 71 | 72 | def save_active_settings(profile: str, config: dict[str, Any]): 73 | 74 | if profile == "default": 75 | profile_file_name = "settings.yaml" 76 | else: 77 | profile_file_name = f"settings-{profile}.yaml" 78 | 79 | path = Path(_settings_folder) / profile_file_name 80 | update_yaml_config_file(path, config) 81 | -------------------------------------------------------------------------------- /langchain_emoji/settings/yaml.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | import typing 4 | from typing import Any, TextIO 5 | 6 | import yaml 7 | from yaml import SafeLoader 8 | 9 | _env_replace_matcher = re.compile(r"\$\{(\w|_)+:?.*}") 10 | 11 | 12 | @typing.no_type_check  # pyaml does not have good hints, everything is Any 13 | def load_yaml_with_envvars( 14 | stream: TextIO, environ: dict[str, Any] = os.environ 15 | ) -> dict[str, Any]: 16 | """Load yaml file with environment variable expansion. 17 | 18 | The pattern ${VAR} or ${VAR:default} will be replaced with 19 | the value of the environment variable. 20 | """ 21 | loader = SafeLoader(stream) 22 | 23 | def load_env_var(_, node) -> str: 24 | """Extract the matched value, expand env variable, and replace the match.""" 25 | value = str(node.value).removeprefix("${").removesuffix("}") 26 | split = value.split(":", 1) 27 | env_var = split[0] 28 | value = environ.get(env_var) 29 | default = None if len(split) == 1 else split[1] 30 | if value is None and default is None: 31 | raise ValueError( 32 | f"Environment variable {env_var} is not set and no default was provided" 33 | ) 34 | return value or default 35 | 36 | loader.add_implicit_resolver("env_var_replacer", _env_replace_matcher, None) 37 | loader.add_constructor("env_var_replacer", load_env_var) 38 | 39 | try: 40 | return loader.get_single_data() 41 | finally: 42 | loader.dispose() 43 | 44 | 45 | @typing.no_type_check  # pyaml does not have good hints, everything is Any 46 | def update_yaml_config_file(file_path: str, update_dict: dict[str, Any]): 47 | """Update YAML configuration file with given key-value pairs.""" 48 | with open(file_path, "r") as file: 49 | config = yaml.safe_load(file) 50 | 51 | # Update the config dictionary with the provided key-value pairs 52 | for key, value in update_dict.items(): 53 | config[key] = value 54 | 55 | # Write the updated config back to the file 56 | with open(file_path, "w") as file: 57 | yaml.safe_dump(config, file, sort_keys=False) 58 | -------------------------------------------------------------------------------- /langchain_emoji/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ptonlix/LangChain-Emoji/4c0a926464a7419e48df32b3c5e26794e2f9b312/langchain_emoji/utils/__init__.py
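A quick self-contained sketch of the `${VAR:default}` expansion implemented in settings/yaml.py above (assuming the package is importable; the YAML snippet is made up for illustration):

    import io
    from langchain_emoji.settings.yaml import load_yaml_with_envvars

    yaml_text = "server:\n  port: ${PORT:8003}\n"
    # With an empty environment the default after ':' is used, so this prints
    # {'server': {'port': '8003'}} (note: expanded values are strings).
    print(load_yaml_with_envvars(io.StringIO(yaml_text), environ={}))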
-------------------------------------------------------------------------------- /langchain_emoji/utils/_compat.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import sys 4 | 5 | from contextlib import suppress 6 | 7 | if sys.version_info < (3, 10): 8 | # compatibility for python <3.10 9 | import importlib_metadata as metadata 10 | else: 11 | from importlib import metadata 12 | 13 | WINDOWS = sys.platform == "win32" 14 | 15 | 16 | def decode(string: bytes | str, encodings: list[str] | None = None) -> str: 17 | if not isinstance(string, bytes): 18 | return string 19 | 20 | encodings = encodings or ["utf-8", "latin1", "ascii"] 21 | 22 | for encoding in encodings: 23 | with suppress(UnicodeEncodeError, UnicodeDecodeError): 24 | return string.decode(encoding) 25 | 26 | return string.decode(encodings[0], errors="ignore") 27 | 28 | 29 | def encode(string: str, encodings: list[str] | None = None) -> bytes: 30 | if isinstance(string, bytes): 31 | return string 32 | 33 | encodings = encodings or ["utf-8", "latin1", "ascii"] 34 | 35 | for encoding in encodings: 36 | with suppress(UnicodeEncodeError, UnicodeDecodeError): 37 | return string.encode(encoding) 38 | 39 | return string.encode(encodings[0], errors="ignore") 40 | 41 | 42 | __all__ = [ 43 | "WINDOWS", 44 | "decode", 45 | "encode", 46 | "metadata", 47 | ] 48 | -------------------------------------------------------------------------------- /local_data/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ptonlix/LangChain-Emoji/4c0a926464a7419e48df32b3c5e26794e2f9b312/local_data/.gitkeep -------------------------------------------------------------------------------- /log/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ptonlix/LangChain-Emoji/4c0a926464a7419e48df32b3c5e26794e2f9b312/log/.gitkeep -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.poetry] 2 | name = "langchain-emoji" 3 | version = "0.1.0" 4 | description = "" 5 | authors = ["ptonlix <260431910@qq.com>"] 6 | readme = "README.md" 7 | 8 | [tool.poetry.dependencies] 9 | python = "~3.10" 10 | uvicorn = "^0.29.0" 11 | langchain = "^0.1.16" 12 | fastapi = "^0.110.1" 13 | injector = "^0.21.0" 14 | tiktoken = "^0.6.0" 15 | openai = "^1.23.1" 16 | zhipuai = "^2.0.1" 17 | watchfiles = "^0.21.0" 18 | tcvectordb = "^1.3.2" 19 | jsonlines = "^4.0.0" 20 | gdown = "^5.1.0" 21 | minio = "^7.2.7" 22 | langchain-chroma = "^0.1.0" 23 | streamlit = "^1.34.0" 24 | onnxruntime = "1.16.3" 25 | 26 | 27 | [build-system] 28 | requires = ["poetry-core"] 29 | build-backend = "poetry.core.masonry.api" 30 | -------------------------------------------------------------------------------- /settings.yaml: -------------------------------------------------------------------------------- 1 | # The default configuration file. 
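# Values written as ${VAR} or ${VAR:default} are expanded from environment
# variables when this file is loaded (see langchain_emoji/settings/yaml.py).
# Extra profiles can be layered on top via LE_PROFILES, e.g. LE_PROFILES=pro
# merges a settings-pro.yaml over this file.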
2 | # More information about configuration can be found in the documentation: https://github.com/ptonlix/LangChain-Emoji 3 | # Syntax in `langchain_emoji/settings/settings.py` 4 | server: 5 | env_name: ${APP_ENV:prod} 6 | port: ${PORT:8003} 7 | cors: 8 | enabled: false 9 | allow_origins: ["*"] 10 | allow_methods: ["*"] 11 | allow_headers: ["*"] 12 | auth: 13 | enabled: false 14 | # python -c 'import base64; print("Basic " + base64.b64encode("secret:key".encode()).decode())' 15 | # 'secret' is the username and 'key' is the password for basic auth by default 16 | # If the auth is enabled, this value must be set in the "Authorization" header of the request. 17 | secret: "Basic c2VjcmV0OmtleQ==" 18 | 19 | llm: 20 | mode: all 21 | 22 | embedding: 23 | mode: zhipuai 24 | 25 | openai: 26 | temperature: 1 27 | modelname: "gpt-3.5-turbo-0125" #"gpt-3.5-turbo-1106" 28 | api_base: ${OPENAI_API_BASE:} 29 | api_key: ${OPENAI_API_KEY:} 30 | 31 | deepseek: 32 | temperature: 1 33 | modelname: "deepseek-chat" 34 | api_base: ${DEEPSEEK_API_BASE:} 35 | api_key: ${DEEPSEEK_API_KEY:} 36 | 37 | zhipuai: 38 | temperature: 0.95 39 | top_p: 0.6 40 | modelname: "glm-3-turbo" 41 | api_key: ${ZHIPUAI_API_KEY:} 42 | 43 | langsmith: 44 | trace_version_v2: true 45 | langchain_project: langchain-emoji 46 | api_key: ${LANGCHAIN_API_KEY:} 47 | 48 | vectorstore: 49 | database: chromadb 50 | tcvectordb: 51 | url: ${TCVERCTORDB_API_HOST:} 52 | username: root 53 | api_key: ${TCVERCTORDB_API_KEY:} 54 | collection_name: EmojiCollection 55 | database_name: DeepReadDatabase 56 | chromadb: 57 | persist_dir: local_data 58 | collection_name: EmojiCollection 59 | 60 | dataset: 61 | name: emo-visual-data 62 | google_driver_id: 1r3uO0wvgQ791M_6iIyBODo_8GekBjPMf 63 | mode: local 64 | 65 | data: 66 | local_data_folder: local_data 67 | 68 | minio: 69 | host: ${MINIO_HOST:} 70 | bucket_name: emoji 71 | access_key: ${MINIO_ACCESS_KEY:} 72 | secret_key: ${MINIO_SECRET_KEY:} 73 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ptonlix/LangChain-Emoji/4c0a926464a7419e48df32b3c5e26794e2f9b312/tests/__init__.py -------------------------------------------------------------------------------- /tools/datainit.py: -------------------------------------------------------------------------------- 1 | import gdown 2 | import zipfile 3 | import os 4 | import sys 5 | import argparse 6 | from tqdm import tqdm 7 | from pathlib import Path 8 | import hashlib 9 | from minio import Minio 10 | from minio.error import MinioException 11 | from langchain_core.vectorstores import VectorStore 12 | 13 | import concurrent.futures 14 | import threading 15 | import jsonlines 16 | from queue import Queue 17 | 18 | import logging 19 | 20 | logging.getLogger("httpx").setLevel(logging.WARNING) 21 | logger = logging.getLogger(__name__) 22 | logger.setLevel(logging.INFO) 23 | 24 | # Option 1: download the dataset from Baidu Netdisk and unpack it manually: 25 | # https://pan.baidu.com/s/11iwqoxLtjV-DOQli81vZ6Q?pwd=tab4 26 | # then place it under local_data 27 | 28 | 29 | # Option 2: download from Google Drive (optional) 30 | # using the Google Drive file ID defined in settings.yaml 31 | def download_and_extract_data(file_id: str, output_dir: Path) -> bool: 32 | """ 33 | Download a file from Google Drive and extract it to the specified directory. 34 | 35 | Parameters: 36 | - file_id (str): Google Drive file ID. 37 | - output_dir (Path): Directory to extract the downloaded file. 38 | 39 | Returns: 40 | - bool: True if successful, False otherwise.
41 | """ 42 | try: 43 | # Define Google Drive download URL 44 | url = f"https://drive.google.com/uc?id={file_id}" 45 | 46 | # Define the output file path 47 | output_path = str(output_dir / "emo-visual-data.zip") 48 | 49 | # Download the file 50 | gdown.download(url, output_path, quiet=False) 51 | 52 | # Create the extraction directory if it doesn't exist 53 | os.makedirs(output_dir, exist_ok=True) 54 | 55 | # Extract the downloaded file 56 | with zipfile.ZipFile(output_path, "r") as zip_ref: 57 | zip_ref.extractall(output_dir) 58 | 59 | # Clean up: remove the downloaded zip file 60 | os.remove(output_path) 61 | 62 | logger.info("File downloaded and extracted successfully.") 63 | return True 64 | 65 | except Exception as e: 66 | logger.exception(f"Error: {e}") 67 | return False 68 | 69 | 70 | def calculate_md5(file_path: str) -> str: 71 | """ 72 | Calculate the MD5 hash of a file. 73 | 74 | Parameters: 75 | - file_path (str): Path to the file. 76 | 77 | Returns: 78 | - str: MD5 hash of the file. 79 | """ 80 | md5_hash = hashlib.md5() 81 | with open(file_path, "rb") as f: 82 | # 以二进制方式读取文件内容并计算MD5值 83 | for chunk in iter(lambda: f.read(4096), b""): 84 | md5_hash.update(chunk) 85 | return md5_hash.hexdigest() 86 | 87 | 88 | def create_minio_bucket(minio_client: Minio, bucket_name: str) -> bool: 89 | """ 90 | Create a MinIO bucket if it doesn't already exist. 91 | 92 | Parameters: 93 | - minio_client (Minio): MinIO client object. 94 | - bucket_name (str): Name of the MinIO bucket to create. 95 | 96 | Returns: 97 | - bool: True if successful or the bucket already exists, False otherwise. 98 | """ 99 | try: 100 | # Check if the bucket already exists 101 | if minio_client.bucket_exists(bucket_name): 102 | logger.info(f"Bucket '{bucket_name}' already exists.") 103 | return True 104 | else: 105 | # Create the bucket 106 | minio_client.make_bucket(bucket_name) 107 | logger.info(f"Bucket '{bucket_name}' created successfully.") 108 | return True 109 | 110 | except MinioException as err: 111 | logger.exception(f"MinIO error: {err}") 112 | return False 113 | 114 | 115 | def upload_file( 116 | minio_client: Minio, 117 | bucket_name: str, 118 | file_path: str, 119 | object_name: str, 120 | file_md5: str, 121 | failed_files: list, 122 | progress_queue: Queue, 123 | ) -> bool: 124 | """ 125 | Upload a single file to MinIO. 126 | 127 | Parameters: 128 | Parameters: 129 | - minio_client (Minio): MinIO client object. 130 | - bucket_name (str): Name of the MinIO bucket. 131 | - file_path (str): Path to the file. 132 | - object_name (str): Object name in MinIO. 133 | - file_md5 (str): MD5 hash of the file. 134 | - total_pbar (tqdm): Total progress bar. 135 | - failed_files (list): List to store failed file names. 136 | 137 | Returns: 138 | - bool: True if successful, False otherwise. 139 | """ 140 | try: 141 | try: 142 | # 检查MinIO中是否已存在相同对象(文件)并且MD5值相同。 143 | object_stat = minio_client.stat_object(bucket_name, object_name) 144 | # 如果对象存在且MD5值相等,跳过此文件。'ETag': '"c204c0f5ab34d3caa8909efd81b24c47"' 145 | if object_stat.metadata["ETag"][1:-1] == file_md5: 146 | logger.debug( 147 | f"Skipping {file_path}: File already exists in MinIO with same MD5." 
148 | ) 149 | progress_queue.put(1) 150 | return True 151 | except MinioException as e: 152 | # 如果对象不存在或MD5值不相等,则继续上传。 153 | if e.code != "NoSuchKey": 154 | logger.error(f"Error checking object: {e}") 155 | return False 156 | 157 | # 上传文件到MinIO。 158 | minio_client.fput_object( 159 | bucket_name, 160 | object_name, 161 | file_path, 162 | ) 163 | logger.debug(f"Uploaded {file_path} to MinIO.") 164 | progress_queue.put(1) 165 | return True 166 | 167 | except MinioException as e: 168 | logger.error(f"Error uploading {file_path} to MinIO: {e}") 169 | failed_files.append(file_path) 170 | progress_queue.put(1) 171 | return False 172 | 173 | 174 | def track_progress(total: int, progress_queue: Queue): 175 | # 创建总进度条。 176 | with tqdm( 177 | total=total, 178 | leave=True, 179 | ncols=100, 180 | file=sys.stdout, 181 | desc=f"Total Files {total}", 182 | unit="files", 183 | ) as total_pbar: 184 | 185 | while True: 186 | progress_queue.get() 187 | total_pbar.update(1) 188 | if total_pbar.n >= total_pbar.total: 189 | break 190 | 191 | 192 | # 上传云端Minio(可选) 193 | def upload_to_minio(minio_client: Minio, source_dir: Path, bucket_name: str) -> bool: 194 | """ 195 | Upload files from a local directory to a MinIO bucket. 196 | 197 | Parameters: 198 | - minio_client (Minio): MinIO client object. 199 | - source_dir (str): Directory containing the files to upload. 200 | - bucket_name (str): Name of the MinIO bucket. 201 | 202 | Returns: 203 | - bool: True if successful, False otherwise. 204 | """ 205 | 206 | # 保存上传失败的文件名的文件路径 207 | failed_files_path = "failed_files.txt" 208 | # 上传失败的文件列表 209 | failed_files = [] 210 | 211 | if not create_minio_bucket(minio_client, bucket_name): 212 | return False 213 | 214 | try: 215 | 216 | # 使用线程池并发执行文件上传任务 217 | with concurrent.futures.ThreadPoolExecutor() as executor: 218 | futures = [] 219 | progress_queue = Queue() 220 | total_files = sum(len(files) for _, _, files in os.walk(source_dir)) 221 | # 遍历源目录中的所有文件。 222 | for root, dirs, files in os.walk(source_dir): 223 | for file in files: 224 | file_path = os.path.join(root, file) 225 | # 计算对象名称(相对于source_dir的相对路径)。 226 | object_name = os.path.relpath(file_path, source_dir) 227 | # 计算文件的MD5值。 228 | file_md5 = calculate_md5(file_path) 229 | # 提交上传任务到线程池。 230 | future = executor.submit( 231 | upload_file, 232 | minio_client, 233 | bucket_name, 234 | file_path, 235 | object_name, 236 | file_md5, 237 | failed_files, 238 | progress_queue, 239 | ) 240 | futures.append(future) 241 | 242 | # Start the progress tracking thread 243 | progress_thread = threading.Thread( 244 | target=track_progress, args=(len(futures), progress_queue), daemon=True 245 | ) 246 | progress_thread.start() 247 | 248 | # 等待所有上传任务完成。 249 | for future in concurrent.futures.as_completed(futures): 250 | if not future.result(): 251 | logger.error(f"result error:{future.result()}") 252 | 253 | logger.info(f"All {total_files} files uploaded to MinIO.") 254 | # 输出上传成功和失败文件数量 255 | success_count = total_files - len(failed_files) 256 | logger.info(f"Total files uploaded successfully: {success_count}") 257 | logger.info(f"Total files failed to upload: {len(failed_files)}") 258 | 259 | # 将上传失败的文件名保存到文件中 260 | if failed_files: 261 | with open(failed_files_path, "w") as f: 262 | f.write("\n".join(failed_files)) 263 | 264 | progress_thread.join() 265 | 266 | return True 267 | 268 | except Exception as err: 269 | logger.exception(f"An error occurred: {err}") 270 | return False 271 | 272 | 273 | # 加载向量数据库 274 | 275 | 276 | def upload_file_vectordb( 277 | client: VectorStore, 278 
| data: dict, 279 | failed_files: list, 280 | progress_queue: Queue, 281 | ) -> bool: 282 | try: 283 | filename = data["filename"] 284 | content = data["content"] 285 | 286 | if isinstance(client, VectorStore): 287 | metadata = { 288 | "filename": filename, 289 | } 290 | 291 | result = client.add_original_texts_with_filename( 292 | filename=filename, texts=[content], metadatas=[metadata] 293 | ) 294 | 295 | progress_queue.put(1) 296 | return bool(result) 297 | 298 | except Exception as e: 299 | logger.error(f"Error uploading {filename} to VectorDB: {e}") 300 | failed_files.append(filename) 301 | progress_queue.put(1) 302 | 303 | 304 | def upload_vectordb(client: VectorStore, dataset_file: str) -> bool: 305 | 306 | # 保存上传失败的文件名的文件路径 307 | failed_files_path = "vector_failed_files.txt" 308 | # 上传失败的文件列表 309 | failed_files = [] 310 | 311 | try: 312 | with jsonlines.open(dataset_file) as reader: 313 | 314 | # 使用线程池并发执行文件上传任务 315 | with concurrent.futures.ThreadPoolExecutor() as executor: 316 | futures = [] 317 | progress_queue = Queue() 318 | for json_line in reader: 319 | future = executor.submit( 320 | upload_file_vectordb, 321 | client, 322 | json_line, 323 | failed_files, 324 | progress_queue, 325 | ) 326 | futures.append(future) 327 | 328 | # Start the progress tracking thread 329 | progress_thread = threading.Thread( 330 | target=track_progress, 331 | args=(len(futures), progress_queue), 332 | daemon=True, 333 | ) 334 | progress_thread.start() 335 | 336 | # 等待所有上传任务完成。 337 | for future in concurrent.futures.as_completed(futures): 338 | if not future.result(): 339 | # logger.error(f"result error:{future.result()}") 340 | ... 341 | 342 | logger.info(f"All {len(futures)} files uploaded to VectorDB.") 343 | # 输出上传成功和失败数量 344 | success_count = len(futures) - len(failed_files) 345 | logger.info(f"Total files uploaded successfully: {success_count}") 346 | logger.info(f"Total files failed to upload: {len(failed_files)}") 347 | 348 | # 将上传失败的文件名保存到文件中 349 | if failed_files: 350 | with open(failed_files_path, "w") as f: 351 | f.write("\n".join(failed_files)) 352 | 353 | progress_thread.join() 354 | 355 | return True 356 | 357 | except Exception as err: 358 | logger.exception(f"An error occurred: {err}") 359 | return False 360 | 361 | 362 | if __name__ == "__main__": 363 | 364 | parser = argparse.ArgumentParser(description="Emoji data initialization tool") 365 | parser.add_argument( 366 | "--download", action="store_true", help="Download and extract emoji data" 367 | ) 368 | parser.add_argument("--upload", action="store_true", help="Upload files to MinIO") 369 | parser.add_argument( 370 | "--vectordb", action="store_true", help="Vector files to Database" 371 | ) 372 | 373 | args = parser.parse_args() 374 | 375 | # 检查是否提供了可选参数 376 | if not (args.download or args.upload or args.vectordb): 377 | print( 378 | "提示: 没有提供可选参数 '--download' '--upload '--vectordb' 请至少指定一个操作。" 379 | ) 380 | parser.print_help() 381 | exit(1) 382 | 383 | if args.download: 384 | 385 | from langchain_emoji.paths import local_data_path 386 | from langchain_emoji.settings.settings import settings 387 | 388 | # Define Google Drive file ID 389 | file_id = settings().dataset.google_driver_id 390 | 391 | # Define the directory to extract the file 392 | extract_dir = local_data_path 393 | 394 | # Download and extract the file 395 | success = download_and_extract_data(file_id, extract_dir) 396 | 397 | if not success: 398 | print("download and extract emoji data failed, exit!") 399 | exit(1) 400 | 401 | if args.upload: 402 | 403 | from 
langchain_emoji.settings.settings import settings 404 | from langchain_emoji.paths import local_data_path 405 | 406 | # MinIO configuration 407 | minio_endpoint = settings().minio.host 408 | minio_access_key = settings().minio.access_key 409 | minio_secret_key = settings().minio.secret_key 410 | secure = False  # Change to True if the MinIO server uses SSL/TLS 411 | 412 | # Initialize MinIO client 413 | minio_client = Minio( 414 | minio_endpoint, 415 | access_key=minio_access_key, 416 | secret_key=minio_secret_key, 417 | secure=secure, 418 | ) 419 | 420 | dataset_name = settings().dataset.name 421 | # Source directory containing files to upload 422 | source_dir = local_data_path / dataset_name / "emo" 423 | if not (os.path.exists(source_dir) and os.path.isdir(source_dir)): 424 | print("emoji dataset does not exist, exit!") 425 | exit(1) 426 | 427 | # Name of the MinIO bucket 428 | bucket_name = "emoji" 429 | 430 | # Upload files to MinIO 431 | success = upload_to_minio(minio_client, source_dir, bucket_name) 432 | 433 | if not success: 434 | print("upload to minio failed, exit!") 435 | exit(1) 436 | 437 | if args.vectordb: 438 | 439 | from langchain_emoji.paths import local_data_path 440 | from langchain_emoji.settings.settings import settings 441 | from langchain_emoji.components.vector_store import VectorStoreComponent 442 | from langchain_emoji.components.embedding.embedding_component import ( 443 | EmbeddingComponent, 444 | ) 445 | 446 | embed = EmbeddingComponent(settings()) 447 | vsc = VectorStoreComponent(embed, settings()) 448 | 449 | dataset_name = settings().dataset.name 450 | dataset_file = local_data_path / dataset_name / "data.jsonl" 451 | if not (os.path.exists(dataset_file) and os.path.isfile(dataset_file)): 452 | print("emoji data.jsonl does not exist, exit!") 453 | exit(1) 454 | 455 | # Load the dataset into the vector database 456 | success = upload_vectordb(vsc.vector_store, dataset_file) 457 | 458 | if not success: 459 | print("upload to vectordb failed, exit!") 460 | exit(1) 461 | 462 | # Done: dataset loaded into the vector database 463 | -------------------------------------------------------------------------------- /tools/json2jsonl.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import jsonlines 4 | import sys 5 | 6 | 7 | def convert_json_to_jsonl(json_file: str, jsonl_file: str) -> None: 8 | with open(json_file, "r") as file: 9 | json_data = json.load(file) 10 | 11 | with jsonlines.open(jsonl_file, "w") as writer: 12 | for item in json_data: 13 | writer.write(item) 14 | 15 | 16 | def main(): 17 | parser = argparse.ArgumentParser(description="Convert JSON to JSONL") 18 | parser.add_argument("input_json", help="Input JSON file") 19 | parser.add_argument("output_jsonl", help="Output JSONL file") 20 | args = parser.parse_args() 21 | 22 | try: 23 | convert_json_to_jsonl(args.input_json, args.output_jsonl) 24 | except FileNotFoundError: 25 | print("Error: Input JSON file not found.", file=sys.stderr) 26 | sys.exit(1) 27 | 28 | 29 | if __name__ == "__main__": 30 | main() 31 | --------------------------------------------------------------------------------
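Typical invocations of the two tools above (illustrative; run from the project root with the environment variables referenced in settings.yaml exported):

    # download the emoji dataset and load it into the configured vector store
    python tools/datainit.py --download --vectordb

    # optionally mirror the dataset images to MinIO
    python tools/datainit.py --upload

    # convert a raw JSON annotation file to the JSONL format datainit.py expects
    python tools/json2jsonl.py data.json data.jsonl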