├── .idea ├── .gitignore ├── FinRAG.iml ├── inspectionProfiles │ ├── Project_Default.xml │ └── profiles_settings.xml ├── misc.xml ├── modules.xml └── vcs.xml ├── Dockerfile ├── LICENSE ├── README.md ├── app ├── __init__.py ├── core │ ├── bce │ │ ├── __init__.py │ │ ├── embedding_client.py │ │ └── rerank_client.py │ ├── chat │ │ ├── __init__.py │ │ ├── open_chat.py │ │ └── rag_chat.py │ ├── loader │ │ └── pdf_loader.py │ ├── preprocessor │ │ └── file_processor.py │ ├── splitter │ │ ├── __init__.py │ │ ├── chinese_text_splitter.py │ │ └── zh_title_enhance.py │ └── vectorstore │ │ ├── __init__.py │ │ └── customer_milvus_client.py ├── finrag_server.py ├── models │ ├── __init__.py │ ├── dialog.py │ └── status.py └── oss │ ├── __init__.py │ └── download_file.py ├── bin └── start.sh ├── conf └── config.py ├── docker └── docker-compose.yml ├── example ├── test.pdf ├── test1.pdf ├── test2.docx ├── test3.txt └── ~$test2.docx ├── img.png ├── img_1.png ├── img_2.png ├── main.py ├── requirements.txt ├── setup.py ├── test ├── cz.py ├── http_test │ ├── notify 2.http │ ├── notify.http │ ├── query_1.http │ ├── query_2.http │ ├── test.http │ ├── test2.http │ ├── test3.http │ └── test4.http ├── test.json ├── test.py ├── test2.py ├── test3.py └── test_async.py └── utils.py /.idea/.gitignore: -------------------------------------------------------------------------------- 1 | # Default ignored files 2 | /shelf/ 3 | /workspace.xml 4 | -------------------------------------------------------------------------------- /.idea/FinRAG.iml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 12 | -------------------------------------------------------------------------------- /.idea/inspectionProfiles/Project_Default.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 181 | -------------------------------------------------------------------------------- 
/.idea/inspectionProfiles/profiles_settings.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 6 | -------------------------------------------------------------------------------- /.idea/misc.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 6 | 7 | -------------------------------------------------------------------------------- /.idea/modules.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /.idea/vcs.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM python:3.10-slim 2 | # Image metadata 3 | LABEL MAINTAINER=wangjia 4 | # Environment settings 5 | ENV LANG=C.UTF-8 6 | ENV TZ=Asia/Shanghai 7 | WORKDIR /FinRAG 8 | COPY . /FinRAG 9 | RUN pip install -r requirements.txt -i https://mirrors.aliyun.com/pypi/simple 10 | EXPOSE 8000 11 | # start.sh ships with the project copied into /FinRAG (COPY above), not in the base image's /bin 12 | CMD ["/bin/bash", "/FinRAG/bin/start.sh"] 13 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License.
14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 
47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. 
Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. 
In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. 
We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright 2024 AI4Finance Foundation Inc. 190 | All rights reserved. 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | 203 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # FinRAG: Financial Retrieval Augmented Generation 2 | 3 | [![](https://dcbadge.vercel.app/api/server/trsr8SXpW5)](https://discord.gg/trsr8SXpW5) 4 | 5 | ![Visitors](https://api.visitorbadge.io/api/VisitorHit?user=AI4Finance-Foundation&repo=FinRAG&countColor=%23B17A) 6 | 7 | ## 1. 
准备工作 8 | 9 | ### 1.1 安装miniconda 10 | #### 网址 11 | `https://docs.anaconda.com/free/miniconda/miniconda-other-installer-links/#linux-installers` 12 | #### python3.10.14 13 | `wget https://repo.anaconda.com/miniconda/Miniconda3-py310_24.4.0-0-Linux-x86_64.sh` 14 | #### 配置环境变量 15 | `export PATH=$HOME/miniconda3/bin:$PATH` 16 | 17 | ### 1.2 启动Milvus向量数据库 18 | - 使用docker-compose启动Milvus服务 19 | ``` 20 | cd docker #切换至docker配置目录环境 21 | docker-compose up -d #启动项目中的服务,并在后台以守护进程方式运行 22 | ``` 23 | 如果对docker不了解,可以看下以下文章: \ 24 | [docker-compose快速入门](https://blog.csdn.net/m0_37899908/article/details/131268835) \ 25 | [docker-compose命令解读](https://blog.csdn.net/weixin_42494218/article/details/135986248) \ 26 | 术语说明: \ 27 | 守护进程: 是一类在后台运行的特殊进程,用于执行特定的系统任务,会一直存在。如果以非守护进程启动,服务容易被终止。 28 | 29 | - Milvus 前端展示地址 30 | `http://{ip}:3100/#/` 把ip替换为你所在服务器的ip地址即可 31 | 32 | ### 1.3 Embedding以及Rerank模型下载 33 | - 新建/data/WoLLM 目录 34 | - 将以下两个模型下载到新建的目录中 35 | - Embedding Model 下载:`git clone https://www.modelscope.cn/maidalun/bce-embedding-base_v1.git` 36 | - Rerank Model 下载:`git clone https://www.modelscope.cn/maidalun/bce-reranker-base_v1.git` 37 | 说明: \ 38 | Embedding Model: 主要是完成将自然语言文本转化为固定维度向量的工作,主要在知识库的建模,用户查询query表示时会应用。 \ 39 | Rerank Model:对结果进行重排操作。 \ 40 | 这里采用的都是bce的模型,因为其在RAG上表现较好,可以参考资料,了解一下背景: \ 41 | [BCE Embedding技术报告](https://zhuanlan.zhihu.com/p/681370855) \ 42 | 下载好两个模型后, 将模型放到指定的位置,并更新项目conf/config.py文件中EMBEDDING_MODEL和RERANK_MODEL对应参数的路径(和模型路径保持一致)。如图: 43 | ![img.png](img.png) 44 | 45 | 46 | ### 1.4 安装依赖及修改配置信息 47 | #### 新建python虚拟环境 48 | `python -m venv .venv` 49 | #### 激活环境 50 | `source .venv/bin/activate` 51 | #### 安装项目依赖 52 | `pip install -r requirements.txt` 53 | 54 | ### 1.5 修改配置文件 55 | 56 | - 配置文件在conf/config.py 57 | - 配置文件的各项信息修改为自己的信息, 每个变量已经加了详细注释.
58 | - 可能需要改动的参数一般就是两个模型文件目录,如图: 59 | ![img_1.png](img_1.png) 60 | - 如果你需要修改端口,或者服务器变更,你需要修改docker.docker-compose.yml中的配置参数,一般就是修改ip和端口。 61 | ![img_2.png](img_2.png) 62 | 63 | ### 1.6 修改环境变量 64 | 65 | - 将.env_user复制为.env `cp .env_user .env` [重要,重要,重要. 必须复制一下] 66 | - 修改.env的LLM以及OSS相关变量信息(目前只需要复制一下即可, 不需要修改里面的内容了) 67 | 68 | ## 2. 启动App 69 | ### 2.1 第一种方式启动 70 | `python main.py` 71 | ### 2.2 第二种方式启动(为之后打进docker做准备,目前先用第一种方式) 72 | `bin/start.sh` 73 | 74 | 75 | ToDo 76 | - [ ] rag 优化 77 | - [ ] 多级索引优化 78 | - [ ] 多查询优化 79 | - [ ] query优化 80 | - [ ] 解析优化 81 | -------------------------------------------------------------------------------- /app/__init__.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Author: xubing 3 | Date: 2024-05-19 00:03:21 4 | LastEditors: xubing 5 | LastEditTime: 2024-05-19 00:03:22 6 | Description: file content 7 | ''' 8 | -------------------------------------------------------------------------------- /app/core/bce/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AI4Finance-Foundation/FinRAG/790c6158596747a3ea58ea1dae7972ac0f18d3dd/app/core/bce/__init__.py -------------------------------------------------------------------------------- /app/core/bce/embedding_client.py: -------------------------------------------------------------------------------- 1 | """ 2 | Author: wangjia 3 | Date: 2024-05-20 03:31:26 4 | LastEditors: wangjia 5 | LastEditTime: 2024-05-21 11:19:43 6 | Description: file content 7 | """ 8 | 9 | import numpy as np 10 | from BCEmbedding import EmbeddingModel 11 | 12 | from conf.config import DEVICE 13 | 14 | 15 | class EmbeddingClient: 16 | 17 | def __init__(self, model_name_or_path) -> None: 18 | self.model = EmbeddingModel( 19 | model_name_or_path=model_name_or_path, 20 | device=DEVICE, 21 | trust_remote_code=True, 22 | ) 23 | 24 | def get_embedding(self, sentences): 25 | embeddings = 
self.model.encode(sentences) 26 | return embeddings 27 | 28 | 29 | if __name__ == "__main__": 30 | 31 | # 读取txt文档,句子列表 32 | f = open("test.txt", "r", encoding="utf-8") 33 | doc = f.read().splitlines() 34 | embedding_client = EmbeddingClient("/data/WoLLM/bce-embedding-base_v1") 35 | embedding = embedding_client.get_embedding(doc) 36 | print(embedding.shape) 37 | -------------------------------------------------------------------------------- /app/core/bce/rerank_client.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Author: wangjia 3 | Date: 2024-05-23 20:40:21 4 | LastEditors: wangjia 5 | LastEditTime: 2024-05-23 21:00:07 6 | Description: file content 7 | ''' 8 | from BCEmbedding import RerankerModel 9 | 10 | 11 | class RerankClient: 12 | def __init__(self,rerank_model) -> None: 13 | self.model = RerankerModel(rerank_model, 14 | trust_remote_code=True,) 15 | def rerank(self,query,massages): 16 | sentence_pairs = [[query, massage] for massage in massages] 17 | #scores = self.model.compute_score(sentence_pairs) 18 | rerank_results = self.model.rerank(query, massages) 19 | return rerank_results 20 | -------------------------------------------------------------------------------- /app/core/chat/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AI4Finance-Foundation/FinRAG/790c6158596747a3ea58ea1dae7972ac0f18d3dd/app/core/chat/__init__.py -------------------------------------------------------------------------------- /app/core/chat/open_chat.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Author: xubing 3 | Date: 2024-05-22 23:43:54 4 | LastEditors: xubing 5 | LastEditTime: 2024-05-23 19:44:13 6 | Description: file content 7 | ''' 8 | import os 9 | 10 | from openai import OpenAI 11 | from utils import logger 12 | 13 | class OpenChat: 14 | def __init__(self) -> None: 15 | self.client = OpenAI( 16 
| api_key=os.getenv("DASHSCOPE_API_KEY"), # 如果您没有配置环境变量,请在此处用您的API Key进行替换 17 | base_url="https://dashscope.aliyuncs.com/compatible-mode/v1", # 填写DashScope SDK的base_url 18 | ) 19 | def chat(self,messages): 20 | logger.info(str(messages)) 21 | completion = self.client.chat.completions.create( 22 | model="qwen-plus", 23 | messages=messages, 24 | stream=False 25 | ) 26 | result = completion.choices[0].message.content 27 | print(result) 28 | return result 29 | 30 | if __name__=="__main__": 31 | oc = OpenChat() 32 | raw_messsages = [ 33 | { 34 | "chatMessageId": "m00001", 35 | "role": "system", 36 | "rawContent": "你是⼀个医疗助⼿" 37 | }, 38 | { 39 | "chatMessageId": "m00002", 40 | "role": "user", 41 | "rawContent": "我感冒了吃什么药" 42 | }, 43 | { 44 | "chatMessageId": "m00003", 45 | "role": "assistant", 46 | "rawContent": "你要是999感冒灵" 47 | }, 48 | { 49 | "chatMessageId": "m00004", 50 | "role": "user", 51 | "rawContent": "我吃了999感觉没什么⽤" 52 | } 53 | ] 54 | messages = [ 55 | { 56 | "role": x.get('role').lower(), 57 | "content": x.get("rawContent") 58 | } 59 | for x in raw_messsages 60 | ] 61 | print(messages) 62 | oc.chat(messages) 63 | 64 | -------------------------------------------------------------------------------- /app/core/chat/rag_chat.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from openai import OpenAI 4 | 5 | 6 | class RAGChat: 7 | def __init__(self) -> None: 8 | self.client = OpenAI( 9 | api_key=os.getenv("DASHSCOPE_API_KEY"), # 如果您没有配置环境变量,请在此处用您的API Key进行替换 10 | base_url="https://dashscope.aliyuncs.com/compatible-mode/v1", # 填写DashScope SDK的base_url 11 | ) 12 | def chat(self,messages): 13 | query = messages[-1].get("content") 14 | query_emb = '' 15 | 16 | completion = self.client.chat.completions.create( 17 | model="qwen-plus", 18 | messages=messages, 19 | stream=False 20 | ) 21 | result = completion.choices[0].message.content 22 | print(result) 23 | return result 
-------------------------------------------------------------------------------- /app/core/loader/pdf_loader.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | import base64 5 | import os 6 | from typing import Any, Callable, List, Union 7 | 8 | import fitz 9 | import numpy as np 10 | from langchain.document_loaders.unstructured import UnstructuredFileLoader 11 | from paddleocr import PaddleOCR 12 | from tqdm import tqdm 13 | from unstructured.partition.text import partition_text 14 | 15 | ocr_engine = PaddleOCR(use_angle_cls=True, lang="ch", use_gpu=True, show_log=False) 16 | 17 | 18 | class UnstructuredPaddlePDFLoader(UnstructuredFileLoader): 19 | """Loader that uses unstructured to load image files, such as PNGs and JPGs.""" 20 | 21 | def __init__( 22 | self, 23 | file_path: Union[str, List[str]], 24 | # ocr_engine: Callable, 25 | mode: str = "single", 26 | **unstructured_kwargs: Any, 27 | ): 28 | """Initialize with file path.""" 29 | self.ocr_engine = ocr_engine 30 | super().__init__(file_path=file_path, mode=mode, **unstructured_kwargs) 31 | 32 | def _get_elements(self) -> List: 33 | def pdf_ocr_txt(filepath, dir_path="tmp"): 34 | full_dir_path = os.path.join(os.path.dirname(filepath), dir_path) 35 | if not os.path.exists(full_dir_path): 36 | os.makedirs(full_dir_path) 37 | doc = fitz.open(filepath) 38 | txt_file_path = os.path.join( 39 | full_dir_path, "{}.txt".format(os.path.split(filepath)[-1]) 40 | ) 41 | img_name = os.path.join(full_dir_path, "tmp.png") 42 | with open(txt_file_path, "w", encoding="utf-8") as fout: 43 | for i in tqdm(range(doc.page_count)): 44 | page = doc.load_page(i) 45 | pix = page.get_pixmap() 46 | img = np.frombuffer(pix.samples, dtype=np.uint8).reshape( 47 | (pix.h, pix.w, pix.n) 48 | ) 49 | img_file = base64.b64encode(img).decode("utf-8") 50 | height, width, channels = pix.h, pix.w, pix.n 51 | 52 | binary_data = base64.b64decode(img_file) 53 | img_array = np.frombuffer(binary_data, 
dtype=np.uint8).reshape( 54 | (height, width, channels) 55 | ) 56 | # result = self.ocr_engine(img_array) 57 | result = self.ocr_engine.ocr(img_array) 58 | result = [line for line in result if line] 59 | ocr_result = [i[1][0] for line in result for i in line] 60 | fout.write("\n".join(ocr_result)) 61 | if os.path.exists(img_name): 62 | os.remove(img_name) 63 | return txt_file_path 64 | 65 | txt_file_path = pdf_ocr_txt(self.file_path) 66 | return partition_text(filename=txt_file_path, **self.unstructured_kwargs) 67 | 68 | 69 | if __name__ == "__main__": 70 | 71 | ... 72 | # from paddleocr import PaddleOCR 73 | # from app.core.text_splitter.chinese_text_splitter import ChineseTextSplitter 74 | 75 | # file_path = "附件4-1:广银理财幸福理财日添利开放式理财计划第2期产品说明书(百信银行B份额).pdf" 76 | # ocr_engine = PaddleOCR(use_angle_cls=True, lang="ch", use_gpu=True, show_log=False) 77 | # loader = UnstructuredPaddlePDFLoader(file_path, ocr_engine) 78 | # texts_splitter = ChineseTextSplitter(pdf=True, sentence_size=250) 79 | # docs = loader.load_and_split(texts_splitter) 80 | # print(docs) 81 | 82 | -------------------------------------------------------------------------------- /app/core/preprocessor/file_processor.py: -------------------------------------------------------------------------------- 1 | 2 | from langchain.text_splitter import RecursiveCharacterTextSplitter 3 | from langchain_community.document_loaders import ( 4 | TextLoader, UnstructuredFileLoader, UnstructuredPDFLoader, 5 | UnstructuredWordDocumentLoader) 6 | 7 | from app.core.splitter import ChineseTextSplitter, zh_title_enhance 8 | from conf import config 9 | from utils import logger 10 | 11 | text_splitter = RecursiveCharacterTextSplitter( 12 | separators=[ 13 | "\n", 14 | ".", 15 | "。", 16 | "!", 17 | "!", 18 | "?", 19 | "?", 20 | ";", 21 | ";", 22 | "……", 23 | "…", 24 | "、", 25 | ",", 26 | ",", 27 | " ", 28 | ], 29 | chunk_size=300, 30 | # length_function=num_tokens, 31 | ) 32 | 33 | 34 | class FileProcesser: 35 | 36 | def 
__init__(self): 37 | logger.info(f"Success init file processor") 38 | 39 | def split_file_to_docs(self, 40 | file_path, 41 | sentence_size=config.SENTENCE_SIZE): 42 | logger.info("开始解析文件,文件越大,解析所需时间越长,大文件请耐心等待...") 43 | file_type = file_path.split('.')[-1].lower() 44 | 45 | if file_type == "txt": 46 | loader = TextLoader(file_path, autodetect_encoding=True) 47 | texts_splitter = ChineseTextSplitter(pdf=False, 48 | sentence_size=sentence_size) 49 | docs = loader.load_and_split(texts_splitter) 50 | elif file_type == "pdf": 51 | loader = UnstructuredPDFLoader(file_path) 52 | texts_splitter = ChineseTextSplitter(pdf=True, 53 | sentence_size=sentence_size) 54 | docs = loader.load_and_split(texts_splitter) 55 | elif file_type == "docx": 56 | loader = UnstructuredWordDocumentLoader(file_path, mode="elements") 57 | texts_splitter = ChineseTextSplitter(pdf=False, 58 | sentence_size=sentence_size) 59 | docs = loader.load_and_split(texts_splitter) 60 | else: 61 | raise TypeError("文件类型不支持,目前仅支持:[txt,pdf,docx]") 62 | # docs = zh_title_enhance(docs) 63 | # 重构docs,如果doc的文本长度大于800tokens,则利用text_splitter将其拆分成多个doc 64 | # text_splitter: RecursiveCharacterTextSplitter 65 | logger.info(f"before 2nd split doc lens: {len(docs)}") 66 | docs = text_splitter.split_documents(docs) 67 | logger.info(f"after 2nd split doc lens: {len(docs)}") 68 | self.docs = docs 69 | return docs 70 | -------------------------------------------------------------------------------- /app/core/splitter/__init__.py: -------------------------------------------------------------------------------- 1 | from .chinese_text_splitter import ChineseTextSplitter 2 | from .zh_title_enhance import zh_title_enhance 3 | -------------------------------------------------------------------------------- /app/core/splitter/chinese_text_splitter.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Author: wangjia 3 | Date: 2024-05-20 21:23:22 4 | LastEditors: wangjia 5 | LastEditTime: 
2024-05-20 21:25:43 6 | Description: file content 7 | ''' 8 | import re 9 | from typing import List 10 | 11 | from langchain.text_splitter import CharacterTextSplitter 12 | 13 | 14 | class ChineseTextSplitter(CharacterTextSplitter): 15 | def __init__(self, pdf: bool = False, sentence_size: int = 200, **kwargs): 16 | super().__init__(**kwargs) 17 | self.pdf = pdf 18 | self.sentence_size = sentence_size 19 | 20 | def split_text1(self, text: str) -> List[str]: 21 | if self.pdf: 22 | text = re.sub(r"\n{3,}", "\n", text) 23 | text = re.sub('\s', ' ', text) 24 | text = text.replace("\n\n", "") 25 | sent_sep_pattern = re.compile('([﹒﹔﹖﹗.。!?]["’”」』]{0,2}|(?=["‘“「『]{1,2}|$))') # del :; 26 | sent_list = [] 27 | for ele in sent_sep_pattern.split(text): 28 | if sent_sep_pattern.match(ele) and sent_list: 29 | sent_list[-1] += ele 30 | elif ele: 31 | sent_list.append(ele) 32 | return sent_list 33 | 34 | def split_text(self, text: str) -> List[str]: ##此处需要进一步优化逻辑 35 | if self.pdf: 36 | text = re.sub(r"\n{3,}", r"\n", text) 37 | text = re.sub('\s', " ", text) 38 | text = re.sub("\n\n", "", text) 39 | 40 | text = re.sub(r'([;;.!?。!?\?])([^”’])', r"\1\n\2", text) # 单字符断句符 41 | text = re.sub(r'(\.{6})([^"’”」』])', r"\1\n\2", text) # 英文省略号 42 | text = re.sub(r'(\…{2})([^"’”」』])', r"\1\n\2", text) # 中文省略号 43 | text = re.sub(r'([;;!?。!?\?]["’”」』]{0,2})([^;;!?,。!?\?])', r'\1\n\2', text) 44 | # 如果双引号前有终止符,那么双引号才是句子的终点,把分句符\n放到双引号后,注意前面的几句都小心保留了双引号 45 | text = text.rstrip() # 段尾如果有多余的\n就去掉它 46 | # 很多规则中会考虑分号;,但是这里我把它忽略不计,破折号、英文双引号等同样忽略,需要的再做些简单调整即可。 47 | ls = [i for i in text.split("\n") if i] 48 | for ele in ls: 49 | if len(ele) > self.sentence_size: 50 | ele1 = re.sub(r'([,,.]["’”」』]{0,2})([^,,.])', r'\1\n\2', ele) 51 | ele1_ls = ele1.split("\n") 52 | for ele_ele1 in ele1_ls: 53 | if len(ele_ele1) > self.sentence_size: 54 | ele_ele2 = re.sub(r'([\n]{1,}| {2,}["’”」』]{0,2})([^\s])', r'\1\n\2', ele_ele1) 55 | ele2_ls = ele_ele2.split("\n") 56 | for ele_ele2 in ele2_ls: 57 | if len(ele_ele2) 
> self.sentence_size: 58 | ele_ele3 = re.sub('( ["’”」』]{0,2})([^ ])', r'\1\n\2', ele_ele2) 59 | ele2_id = ele2_ls.index(ele_ele2) 60 | ele2_ls = ele2_ls[:ele2_id] + [i for i in ele_ele3.split("\n") if i] + ele2_ls[ 61 | ele2_id + 1:] 62 | ele_id = ele1_ls.index(ele_ele1) 63 | ele1_ls = ele1_ls[:ele_id] + [i for i in ele2_ls if i] + ele1_ls[ele_id + 1:] 64 | 65 | id = ls.index(ele) 66 | ls = ls[:id] + [i for i in ele1_ls if i] + ls[id + 1:] 67 | return ls 68 | -------------------------------------------------------------------------------- /app/core/splitter/zh_title_enhance.py: -------------------------------------------------------------------------------- 1 | import re 2 | from typing import List 3 | 4 | from langchain.docstore.document import Document 5 | 6 | 7 | def under_non_alpha_ratio(text: str, threshold: float = 0.5): 8 | """Checks if the proportion of non-alpha characters in the text snippet exceeds a given 9 | threshold. This helps prevent text like "-----------BREAK---------" from being tagged 10 | as a title or narrative text. The ratio does not count spaces. 11 | 12 | Parameters 13 | ---------- 14 | text 15 | The input string to test 16 | threshold 17 | If the proportion of non-alpha characters exceeds this threshold, the function 18 | returns False 19 | """ 20 | if len(text) == 0: 21 | return False 22 | 23 | alpha_count = len([char for char in text if char.strip() and char.isalpha()]) 24 | total_count = len([char for char in text if char.strip()]) 25 | try: 26 | ratio = alpha_count / total_count 27 | return ratio < threshold 28 | except: 29 | return False 30 | 31 | 32 | def is_possible_title( 33 | text: str, 34 | title_max_word_length: int = 20, 35 | non_alpha_threshold: float = 0.5, 36 | ) -> bool: 37 | """Checks to see if the text passes all of the checks for a valid title. 
38 | 39 | Parameters 40 | ---------- 41 | text 42 | The input text to check 43 | title_max_word_length 44 | The maximum number of words a title can contain 45 | non_alpha_threshold 46 | The minimum number of alpha characters the text needs to be considered a title 47 | """ 48 | 49 | # 文本长度为0的话,肯定不是title 50 | if len(text) == 0: 51 | print("Not a title. Text is empty.") 52 | return False 53 | 54 | # 文本中有标点符号,就不是title 55 | ENDS_IN_PUNCT_PATTERN = r"[^\w\s]\Z" 56 | ENDS_IN_PUNCT_RE = re.compile(ENDS_IN_PUNCT_PATTERN) 57 | if ENDS_IN_PUNCT_RE.search(text) is not None: 58 | return False 59 | 60 | # 文本长度不能超过设定值,默认20 61 | # NOTE(robinson) - splitting on spaces here instead of word tokenizing because it 62 | # is less expensive and actual tokenization doesn't add much value for the length check 63 | if len(text) > title_max_word_length: 64 | return False 65 | 66 | # 文本中数字的占比不能太高,否则不是title 67 | if under_non_alpha_ratio(text, threshold=non_alpha_threshold): 68 | return False 69 | 70 | # NOTE(robinson) - Prevent flagging salutations like "To My Dearest Friends," as titles 71 | if text.endswith((",", ".", ",", "。")): 72 | return False 73 | 74 | if text.isnumeric(): 75 | print(f"Not a title. 
Text is all numeric:\n\n{text}") # type: ignore 76 | return False 77 | 78 | # 开头的字符内应该有数字,默认5个字符内 79 | if len(text) < 5: 80 | text_5 = text 81 | else: 82 | text_5 = text[:5] 83 | alpha_in_text_5 = sum(list(map(lambda x: x.isnumeric(), list(text_5)))) 84 | if not alpha_in_text_5: 85 | return False 86 | 87 | return True 88 | 89 | 90 | def zh_title_enhance(docs: List[Document]) -> List[Document]: 91 | title = None 92 | if len(docs) > 0: 93 | for doc in docs: 94 | if is_possible_title(doc.page_content): 95 | doc.metadata['category'] = 'cn_Title' 96 | title = doc.page_content 97 | elif title: 98 | doc.page_content = f"下文与({title})有关。{doc.page_content}" 99 | return docs 100 | else: 101 | print("文件不存在") 102 | -------------------------------------------------------------------------------- /app/core/vectorstore/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AI4Finance-Foundation/FinRAG/790c6158596747a3ea58ea1dae7972ac0f18d3dd/app/core/vectorstore/__init__.py -------------------------------------------------------------------------------- /app/core/vectorstore/customer_milvus_client.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from pymilvus import (Collection, CollectionSchema, DataType, FieldSchema, 4 | MilvusClient, connections, utility) 5 | 6 | from app.core.bce.embedding_client import EmbeddingClient 7 | from app.core.bce.rerank_client import RerankClient 8 | from app.core.chat.open_chat import OpenChat 9 | from app.core.preprocessor.file_processor import FileProcesser 10 | from app.oss.download_file import Downloader 11 | from conf.config import (CACHE_DIR, COLLECTION_NAME, EMBEDDING_MODEL, 12 | MILVUS_URI, RAG_PROMPT, RERANK_MODEL, STORAGE_DIR, 13 | STORAGE_TYPE) 14 | 15 | embedding_client = EmbeddingClient(EMBEDDING_MODEL) 16 | rerank_client = RerankClient(RERANK_MODEL) 17 | oss_downloader = Downloader() 18 | file_processer = 
FileProcesser() 19 | open_chat = OpenChat() 20 | _dim = 768 21 | 22 | from utils import logger 23 | 24 | 25 | class CustomerMilvusClient: 26 | 27 | def __init__(self): 28 | # self.client = MilvusClient( 29 | # uri=config.milvus_uri 30 | # ) 31 | connections.connect(uri=MILVUS_URI) 32 | self.collection_name = COLLECTION_NAME 33 | self.collection = self.init() 34 | self.collection.load() 35 | 36 | def init(self): 37 | try: 38 | if utility.has_collection(self.collection_name): 39 | collection = Collection(self.collection_name) 40 | logger.info(f"collection {self.collection_name} exists") 41 | else: 42 | schema = CollectionSchema(self.fields) 43 | logger.info( 44 | f"create collection {self.collection_name} {schema}") 45 | collection = Collection(self.collection_name, schema) 46 | index_params = { 47 | "metric_type": "L2", 48 | "index_type": "IVF_FLAT", 49 | "params": { 50 | "nlist": 2048 51 | }, 52 | } 53 | collection.create_index(field_name="embedding", 54 | index_params=index_params) 55 | logger.info("初始化成功!") 56 | except Exception as e: 57 | logger.error(e) 58 | return collection 59 | 60 | def embedding_to_vdb(self, file_details, batch_size=1000): 61 | """ 62 | 当文档过长时,会出现无法一次性插入向量数据库的情况,因此需要分批插入 63 | """ 64 | for file_info in file_details: 65 | fileName = file_info.get("fileName") 66 | fileSuffix = file_info.get("fileSuffix") 67 | storagePath = file_info.get("storagePath") 68 | local_file = CACHE_DIR + "/" + fileName 69 | logger.info(f"正在将文件【{fileName}】下载到本地缓存...") 70 | try: 71 | oss_downloader.get_file(storagePath, local_file) 72 | except: 73 | logger.error("下载失败,请检查网络或检查文件是否存在") 74 | logger.info("本地文件地址:" + str(local_file)) 75 | docs = file_processer.split_file_to_docs(local_file) 76 | docs_content = [doc.page_content for doc in docs] 77 | embeddings = embedding_client.get_embedding(docs_content) 78 | entities = [] 79 | try: 80 | for idx, cont, emb in zip(range(len(docs)), docs_content, 81 | embeddings): 82 | entity = { 83 | "parentId": file_info.get("parentId"), 
84 | "categoryName": file_info.get("categoryName"), 85 | "categoryId": file_info.get("categoryId"), 86 | "fileId": str(file_info.get("fileId")), 87 | "fileName": fileName, 88 | "fileSuffix": fileSuffix, 89 | "storagePath": storagePath, 90 | "chunkId": idx, 91 | "chunkContent": cont, 92 | "embedding": emb, 93 | } 94 | entities.append(entity) 95 | if len(entities) == batch_size: 96 | self.collection.insert(entities) 97 | self.collection.flush() 98 | entities = [] 99 | self.collection.insert(entities) 100 | self.collection.flush() 101 | logger.info("存入向量数据库成功!") 102 | except: 103 | logger.error(str(local_file) + "写入向量数据库失败,请检查!") 104 | 105 | def parse_request(self, data): 106 | # # 解析JSON字符串 107 | # data = json.loads(json_str) 108 | # 初始化一个空列表来存储结果 109 | file_details = [] 110 | 111 | # 遍历JSON数据结构 112 | for category in data.sysCategory: 113 | # 首先获取顶级分类中的文件存储信息 114 | for file_storage in category.get("fileStorages", []): 115 | file_details.append({ 116 | "fileName": file_storage["fileName"], 117 | "fileSuffix": file_storage["fileSuffix"], 118 | "storagePath": file_storage["storagePath"], 119 | "categoryName": category["categoryName"], 120 | "categoryId": category["categoryId"], 121 | "parentId": category["parentId"], 122 | }) 123 | 124 | # 然后获取子分类中的文件存储信息 125 | for sub_category in category.get("subCategory", []): 126 | for file_storage in sub_category.get("fileStorages", []): 127 | file_details.append({ 128 | "fileName": 129 | file_storage["fileName"], 130 | "fileSuffix": 131 | file_storage["fileSuffix"], 132 | "storagePath": 133 | file_storage["storagePath"], 134 | "categoryName": 135 | sub_category["categoryName"], 136 | "categoryId": 137 | sub_category["categoryId"], 138 | "parentId": 139 | sub_category["parentId"], 140 | }) 141 | 142 | # 打印结果 143 | for detail in file_details: 144 | print(detail) 145 | return file_details 146 | 147 | @property 148 | def fields(self): 149 | fields = [ 150 | FieldSchema( 151 | name="id", 152 | dtype=DataType.VARCHAR, 153 | 
is_primary=True, 154 | auto_id=True, 155 | max_length=100, 156 | ), 157 | FieldSchema(name="parentId", 158 | dtype=DataType.VARCHAR, 159 | max_length=256), 160 | FieldSchema(name="categoryName", 161 | dtype=DataType.VARCHAR, 162 | max_length=256), 163 | FieldSchema(name="categoryId", 164 | dtype=DataType.VARCHAR, 165 | max_length=256), 166 | FieldSchema(name="fileId", dtype=DataType.VARCHAR, max_length=256), 167 | FieldSchema(name="fileName", 168 | dtype=DataType.VARCHAR, 169 | max_length=1024), 170 | FieldSchema(name="fileSuffix", 171 | dtype=DataType.VARCHAR, 172 | max_length=256), 173 | FieldSchema(name="storagePath", 174 | dtype=DataType.VARCHAR, 175 | max_length=256), 176 | FieldSchema(name="chunkId", dtype=DataType.INT64), 177 | FieldSchema(name="chunkContent", 178 | dtype=DataType.VARCHAR, 179 | max_length=1024), 180 | FieldSchema(name="embedding", 181 | dtype=DataType.FLOAT_VECTOR, 182 | dim=_dim), # 向量字段 183 | ] 184 | return fields 185 | 186 | @property 187 | def output_fields(self): 188 | return [ 189 | "parentId", 190 | "categoryName", 191 | "categoryId", 192 | "fileId", 193 | "fileName", 194 | "fileSuffix", 195 | "storagePath", 196 | "chunkId", 197 | "chunkContent", 198 | "embedding", 199 | ] 200 | 201 | def delete_collection(self): 202 | self.collection.release() 203 | utility.drop_collection(self.collection_name) 204 | print("向量数据库重置成功!") 205 | 206 | def get_rag_result(self, initInputs, messages): 207 | query = messages[-1].get("content") 208 | logger.info(f"最新的问题是:【{query}】") 209 | query_emb = embedding_client.get_embedding(query) 210 | categoryIds = initInputs.get("categoryIds") 211 | topK = initInputs.get('topK') 212 | score = initInputs.get('score') 213 | logger.info("score:" + str(score)) 214 | if len(categoryIds) > 1: 215 | # rerank 216 | rag_results = [] 217 | reference_results = [] 218 | for idStr in categoryIds: 219 | category_ids = idStr.split(',') 220 | rag_result, retrival_results = self.retrieval_and_generate( 221 | query_emb, topK, 
score, category_ids, messages) 222 | 223 | rag_results.append(rag_result) 224 | reference_results.extend(retrival_results) 225 | 226 | rereank_results = self.rerank(query, rag_results) 227 | rag_result = rag_results[rereank_results['rerank_ids'].index(0)] 228 | retrival_results = reference_results 229 | else: 230 | category_ids = categoryIds[0].split(',') 231 | rag_result, retrival_results = self.retrieval_and_generate( 232 | query_emb, topK, score, category_ids, messages) 233 | return rag_result, retrival_results 234 | 235 | def retrieval_and_generate(self, query_emb, topK, score, category_ids, 236 | messages): 237 | expr = "categoryId in {}".format(category_ids) 238 | logger.info(expr) 239 | search_params = { 240 | "metric_type": "L2", 241 | "offset": 0, 242 | "ignore_growing": False, 243 | "params": { 244 | "nprobe": 10 245 | } 246 | } 247 | 248 | results = self.collection.search( 249 | data=query_emb, 250 | anns_field="embedding", 251 | param=search_params, 252 | limit=topK, 253 | expr=expr, 254 | output_fields=['fileName', 'chunkContent'], 255 | consistency_level="Strong") 256 | relevant_content = [] 257 | for hits in results: 258 | for hit in hits: 259 | print(hit) 260 | print(f"ID: {hit.id}, score: {hit.score}") 261 | print(f"chunkContent: {hit.entity.get('chunkContent')})") 262 | relevant_content.append((hit.score, hit.entity.get("fileName"), 263 | hit.entity.get('chunkContent'))) 264 | logger.info("检索到相关片段的数量是:%d" % len(relevant_content)) 265 | if len(relevant_content) > 0: 266 | retrival_results = [x for x in relevant_content if x[0] > score] 267 | retrival_results_text = [x[2] for x in retrival_results] 268 | logger.info("检索到的片段:" + str(retrival_results_text)) 269 | retrieval_result_str = '\n\n'.join(retrival_results_text) 270 | else: 271 | retrival_results = [] 272 | retrieval_result_str = "" 273 | 274 | messages[-1]['content'] = RAG_PROMPT.format( 275 | context=retrieval_result_str, question=messages[-1]['content']) 276 | rag_result = 
open_chat.chat(messages) 277 | 278 | return rag_result, retrival_results 279 | 280 | # def generate(self, messages, retrieval_result): 281 | # messages[-1]['content'] = RAG_PROMPT.format( 282 | # context=retrieval_result, question=messages[-1]['content']) 283 | # ans = open_chat.chat(messages) 284 | # return ans 285 | 286 | def rerank(self, query, multi_result): 287 | logger.info('进入rerank模块') 288 | rerank_result = rerank_client.rerank(query, multi_result) 289 | return rerank_result 290 | -------------------------------------------------------------------------------- /app/finrag_server.py: -------------------------------------------------------------------------------- 1 | """ 2 | Author: Lucas 3 | Date: 2024-05-22 11:21:41 4 | LastEditors: Lucas 5 | LastEditTime: 2024-05-22 11:24:10 6 | Description: file content 7 | """ 8 | 9 | import time 10 | from typing import Any, Dict 11 | import requests 12 | import httpx 13 | import uvicorn 14 | from fastapi import FastAPI, HTTPException,BackgroundTasks 15 | from pydantic import BaseModel 16 | import asyncio 17 | from app.core.chat.open_chat import OpenChat 18 | from app.core.vectorstore.customer_milvus_client import CustomerMilvusClient 19 | from app.models.status import ErrorMsg, SuccessMsg 20 | from conf import config 21 | from utils import logger 22 | import json 23 | 24 | cmc = CustomerMilvusClient() 25 | open_chat = OpenChat() 26 | 27 | 28 | class Item(BaseModel): 29 | syncId: Any 30 | sysCategory: Any 31 | 32 | 33 | class Query(BaseModel): 34 | chatId: Any 35 | ownerId: Any 36 | chatName: Any 37 | initInputs: Dict 38 | initOpening: Any 39 | chatMessages: Any 40 | 41 | 42 | class Notify(BaseModel): 43 | syncId: Any 44 | status: Any 45 | 46 | 47 | 48 | def notify_another(notify_msg: Notify): 49 | logger.info("通知Embedding完成") 50 | data = {"syncId": notify_msg.syncId, "status": notify_msg.status} 51 | response = requests.post(config.NOTIFY_URL, 52 | headers={'content-type':'application/json'}, 53 | 
data=json.dumps(data)) 54 | # 打印响应的状态码 55 | print('Status Code:', response.status_code) 56 | if response.status_code==200: 57 | logger.info("消息同步成功!") 58 | # print('Content:', response.message) 59 | return response 60 | 61 | app = FastAPI() 62 | 63 | @app.post("/chat") 64 | async def chat(query: Query): 65 | logger.info("进入Chat") 66 | # logger.info("query:"+str(query.to_dict())) 67 | chatId = query.chatId 68 | ownerId = query.ownerId 69 | chatName = query.chatName 70 | initInputs = query.initInputs 71 | initOpening = query.initOpening 72 | chatMessages = query.chatMessages 73 | # logger.info( 74 | # str( 75 | # {"chatName":chatName, 76 | # "initInputs":initInputs, 77 | # "chatMessages":chatMessages 78 | # } 79 | # ) 80 | # ) 81 | logger.info("query:\n***********************************"+ 82 | query.model_dump_json()+ 83 | "\n***********************************") 84 | messages = [{ 85 | "role": x.get("role").lower(), 86 | "content": x.get("rawContent") 87 | } for x in chatMessages] 88 | 89 | try: 90 | chunks = [] 91 | if len(initInputs.get("categoryIds")) == 0: 92 | # 开放问答 93 | logger.info("进入开放域知识问答,答案由完全由大模型生成!") 94 | response = open_chat.chat(messages) 95 | 96 | else: 97 | logger.info("进入RAG问答,答案由大模型根据知识库生成!") 98 | response, retrieval_results = cmc.get_rag_result( 99 | initInputs, messages) 100 | if len(retrieval_results): 101 | chunks = [{ 102 | "index": x[1], 103 | "chunk": x[2], 104 | "score": x[0] 105 | } for x in retrieval_results] 106 | logger.info("chunks"+str(chunks)) 107 | messages.append({"role": "assistant", "content": response}) 108 | messages.append({ 109 | "role": "user", 110 | "content": "根据上面我们的历史对话,为我推荐三个接下来我可能要问的问题。每个问题以?结尾" 111 | }) 112 | suggestedQuestions = open_chat.chat(messages) 113 | try: 114 | suggestedQuestions = suggestedQuestions.split('?') 115 | suggestedQuestions=[x.strip() for x in suggestedQuestions[:3]] 116 | except: 117 | suggestedQuestions = [suggestedQuestions] 118 | 119 | if chatName == "知识问答助手": 120 | try: 121 | new_messages 
= [ 122 | { 123 | "role":"system", 124 | "content":"你是一位得力的助手" 125 | }, 126 | { 127 | "role":"user", 128 | "content":config.DIALOGUE_SUMMARY.format(context=str(messages[:-1])) 129 | } 130 | ] 131 | chatName = open_chat.chat(new_messages) 132 | 133 | logger.info("大模型总结的chatName:"+str(chatName)) 134 | except: 135 | chatName = chatName 136 | 137 | return { 138 | "code": "000000", 139 | "data": { 140 | "event": "MESSAGE", 141 | "result": { 142 | "chatId": chatId, 143 | "chatName": chatName, 144 | "answer": response, 145 | "suggestedQuestions": suggestedQuestions, 146 | "chunks": chunks, 147 | }, 148 | }, 149 | "message": "调⽤成功", 150 | "success": True, 151 | "time": time.time(), 152 | } 153 | except: 154 | logger.error("RAG问答出错,请检查!") 155 | return ErrorMsg.to_dict() 156 | 157 | 158 | async def async_update(item:Item): 159 | try: 160 | logger.info("开始更新向量") 161 | logger.info("syncId:" + str(item.syncId)) 162 | start_time = time.time() 163 | detail = cmc.parse_request(item) 164 | cmc.embedding_to_vdb(detail) 165 | notify_msg = Notify(syncId=item.syncId, status=1) 166 | end_time = time.time() 167 | logger.info("Cost time:" + str(end_time - start_time)) 168 | response = notify_another(notify_msg) 169 | # print(response.message) 170 | # return SuccessMsg.to_dict() 171 | except Exception as e: 172 | # 如果在处理请求时发生错误,返回HTTP 400错误 173 | # raise HTTPException(status_code=400, detail=f"An error occurred: {str(e)}.") 174 | logger.error("更新向量发生错误!") 175 | # return ErrorMsg.to_dict() 176 | @app.post("/update_vector") 177 | async def update_vector(item: Item): 178 | asyncio.create_task(async_update(item)) 179 | return SuccessMsg.to_dict() 180 | -------------------------------------------------------------------------------- /app/models/__init__.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Author: xubing 3 | Date: 2024-05-19 00:09:54 4 | LastEditors: xubing 5 | LastEditTime: 2024-05-19 00:09:55 6 | Description: file content 7 | ''' 8 | 
-------------------------------------------------------------------------------- /app/models/dialog.py: -------------------------------------------------------------------------------- 1 | from .status import Status 2 | 3 | 4 | class DialogRequest: 5 | request_id = "" 6 | query = "" 7 | session_id = "" 8 | user_id = "" 9 | session = [] 10 | stream = False 11 | 12 | 13 | class DialogResponse: 14 | status = Status 15 | answer = "" 16 | title = "" 17 | recommend_topic = "" 18 | chunks = "" 19 | -------------------------------------------------------------------------------- /app/models/status.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Author: wangjia 3 | Date: 2024-05-19 00:10:00 4 | LastEditors: wangjia 5 | LastEditTime: 2024-05-24 00:02:48 6 | Description: file content 7 | ''' 8 | # class Status: 9 | # status_code = "" 10 | # status_msg = "" 11 | import time 12 | 13 | 14 | class ErrorMsg: 15 | code = "xxxxxx" 16 | data = None 17 | message= "调⽤失败" 18 | success= False 19 | @classmethod 20 | def to_dict(cls): 21 | return { 22 | 'code': cls.code, 23 | 'data': cls.data, 24 | 'message': cls.message, 25 | 'success': cls.success, 26 | 'time': time.time(), # 获取当前时间 27 | } 28 | class SuccessMsg(): 29 | code = "000000" 30 | data = None 31 | message= "调⽤成功" 32 | success= True 33 | @classmethod 34 | def to_dict(cls): 35 | return { 36 | 'code': cls.code, 37 | 'data': cls.data, 38 | 'message': cls.message, 39 | 'success': cls.success, 40 | 'time': time.time(), # 获取当前时间 41 | } 42 | 43 | if __name__ == '__main__': 44 | sm = SuccessMsg() 45 | sm.to_dict() -------------------------------------------------------------------------------- /app/oss/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AI4Finance-Foundation/FinRAG/790c6158596747a3ea58ea1dae7972ac0f18d3dd/app/oss/__init__.py 
-------------------------------------------------------------------------------- /app/oss/download_file.py: -------------------------------------------------------------------------------- 1 | """ 2 | Author: Lucas 3 | Date: 2024-05-21 10:25:18 4 | LastEditors: Lucas 5 | LastEditTime: 2024-05-22 12:55:53 6 | Description: file content 7 | """ 8 | 9 | import os 10 | import shutil 11 | 12 | import dotenv 13 | import oss2 14 | from oss2.credentials import EnvironmentVariableCredentialsProvider 15 | 16 | from conf.config import STORAGE_DIR, STORAGE_TYPE 17 | 18 | dotenv.load_dotenv() 19 | endpoint = os.getenv("END_POINT") 20 | bucket_name = os.getenv("BUCKET_NAME") 21 | 22 | class Downloader: 23 | def __init__(self) -> None: 24 | auth = oss2.ProviderAuth(EnvironmentVariableCredentialsProvider()) 25 | self.bucket = oss2.Bucket(auth, endpoint, bucket_name) 26 | 27 | def get_file(self, remote, local): 28 | if STORAGE_TYPE=='local': 29 | return self.get_local_file(remote, local) 30 | else: 31 | return self.get_oss_file(remote, local) 32 | 33 | def get_local_file(self,remote,local): 34 | shutil.copy(os.path.join(STORAGE_DIR,remote),local) 35 | return 36 | def get_oss_file(self,remote,local): 37 | # 首先判断本地是否有缓存, 如果有, 跳过; 没有,下载 38 | if local in os.listdir(".cache"): 39 | return 40 | if local is None: 41 | local = ".cache/" + remote.split("/")[-1] 42 | self.bucket.get_object_to_file(remote, local) 43 | return -------------------------------------------------------------------------------- /bin/start.sh: -------------------------------------------------------------------------------- 1 | uvicorn app.finrag_server:app --host 0.0.0.0 --port 8000 --reload -------------------------------------------------------------------------------- /conf/config.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Author: wangjia 3 | Date: 2024-05-26 23:28:10 4 | LastEditors: wangjia 5 | LastEditTime: 2024-05-30 22:29:58 6 | Description: file content 7 
| ''' 8 | 9 | import os 10 | 11 | COLLECTION_NAME = "FIN_RAG" # 向量数据库的名称 12 | SENTENCE_SIZE = 500 # 分割的文本长度 13 | DEVICE = "cuda" # 使用cpu还是gpu 14 | EMBEDDING_MODEL = "/data/WoLLM/bce-embedding-base_v1" # embedding 模型的本地路径 15 | RERANK_MODEL = "/data/WoLLM/bce-reranker-base_v1" # rerank 模型的本地路径 16 | MILVUS_URI = "http://localhost:19530" # milvus的uri 17 | NOTIFY_URL = "http://39.96.174.204/api/medical-assistant/knowledge/file/vector/complete" # 向量更新完成后,向后端发送消息的url 18 | CACHE_DIR = ".cache" # 本地缓存路径. 将oss的文件下载这个缓存文件夹里 19 | STORAGE_TYPE="local" # oss # 使用loca还是osss[新增功能] 20 | STORAGE_DIR="/data/storage/" # 本地文件 [新增功能] 21 | if not os.path.exists(CACHE_DIR): 22 | os.mkdir(CACHE_DIR) 23 | 24 | # 对话内容总结标题的prompt 25 | DIALOGUE_SUMMARY = """为以下对话内容总结一个标题 26 | {context} 27 | 28 | 请限制在20个字以内 29 | 你的回复:""" 30 | 31 | # RAG的核心prompt 32 | RAG_PROMPT = """参考信息: 33 | {context} 34 | --- 35 | 我的问题或指令: 36 | {question} 37 | --- 38 | 请根据上述参考信息回答我的问题或回复我的指令。 39 | - 我的问题或指令是什么语种,你就用什么语种回复. 40 | - 前面的参考信息可能有用,也可能没用, 请自行判别。 41 | - 如果前面的参考信息与问题有关,你需要从我给出的参考信息中选出与我的问题最相关的那些,来为你的回答提供依据。 42 | - 如果前面的参考信息与问题无关,请回答:根据参考信息,无法得到问题的答案. 
43 | - 不要随机编造答案 44 | 45 | 你的回复:""" 46 | 47 | if __name__ == "__main__": 48 | print(RAG_PROMPT.format(context="Hello", question="world")) 49 | -------------------------------------------------------------------------------- /docker/docker-compose.yml: -------------------------------------------------------------------------------- 1 | version: '3.5' 2 | 3 | services: 4 | etcd: 5 | container_name: milvus-etcd 6 | image: quay.io/coreos/etcd:v3.5.5 7 | environment: 8 | - ETCD_AUTO_COMPACTION_MODE=revision 9 | - ETCD_AUTO_COMPACTION_RETENTION=1000 10 | - ETCD_QUOTA_BACKEND_BYTES=4294967296 11 | - ETCD_SNAPSHOT_COUNT=50000 12 | volumes: 13 | - /data/volumes/etcd:/etcd 14 | command: etcd -advertise-client-urls=http://127.0.0.1:2379 -listen-client-urls http://0.0.0.0:2379 --data-dir /etcd 15 | healthcheck: 16 | test: ["CMD", "etcdctl", "endpoint", "health"] 17 | interval: 30s 18 | timeout: 20s 19 | retries: 3 20 | minio: 21 | container_name: milvus-minio 22 | image: minio/minio:RELEASE.2023-03-20T20-16-18Z 23 | environment: 24 | MINIO_ACCESS_KEY: minioadmin 25 | MINIO_SECRET_KEY: minioadmin 26 | ports: 27 | - "9001:9001" 28 | - "9002:9000" 29 | volumes: 30 | - /data/volumes/minio:/minio_data 31 | command: minio server /minio_data --console-address ":9001" 32 | healthcheck: 33 | test: ["CMD", "curl", "-f", "http://localhost:9000/minio/health/live"] 34 | interval: 30s 35 | timeout: 20s 36 | retries: 3 37 | standalone: 38 | container_name: milvus-standalone 39 | image: milvusdb/milvus:v2.3.3 40 | command: ["milvus", "run", "standalone"] 41 | security_opt: 42 | - seccomp:unconfined 43 | environment: 44 | ETCD_ENDPOINTS: etcd:2379 45 | MINIO_ADDRESS: minio:9000 46 | volumes: 47 | - /data/volumes/milvus:/var/lib/milvus 48 | healthcheck: 49 | test: ["CMD", "curl", "-f", "http://localhost:9091/healthz"] 50 | interval: 30s 51 | start_period: 90s 52 | timeout: 20s 53 | retries: 3 54 | ports: 55 | - "19530:19530" 56 | - "9091:9091" 57 | depends_on: 58 | - "etcd" 59 | - "minio" 60 | 
attu: 61 | container_name: milvus-attu 62 | image: zilliz/attu:v2.3.8 63 | ports: 64 | - "3100:3000" 65 | environment: 66 | - MILVUS_URL=http:localhost:19530 # 这里是milvus的UI,这里的URL填写milvus服务器的url 67 | 68 | networks: 69 | default: 70 | name: milvus -------------------------------------------------------------------------------- /example/test.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AI4Finance-Foundation/FinRAG/790c6158596747a3ea58ea1dae7972ac0f18d3dd/example/test.pdf -------------------------------------------------------------------------------- /example/test1.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AI4Finance-Foundation/FinRAG/790c6158596747a3ea58ea1dae7972ac0f18d3dd/example/test1.pdf -------------------------------------------------------------------------------- /example/test2.docx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AI4Finance-Foundation/FinRAG/790c6158596747a3ea58ea1dae7972ac0f18d3dd/example/test2.docx -------------------------------------------------------------------------------- /example/~$test2.docx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AI4Finance-Foundation/FinRAG/790c6158596747a3ea58ea1dae7972ac0f18d3dd/example/~$test2.docx -------------------------------------------------------------------------------- /img.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AI4Finance-Foundation/FinRAG/790c6158596747a3ea58ea1dae7972ac0f18d3dd/img.png -------------------------------------------------------------------------------- /img_1.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/AI4Finance-Foundation/FinRAG/790c6158596747a3ea58ea1dae7972ac0f18d3dd/img_1.png -------------------------------------------------------------------------------- /img_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AI4Finance-Foundation/FinRAG/790c6158596747a3ea58ea1dae7972ac0f18d3dd/img_2.png -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import uvicorn 2 | from app.finrag_server import app 3 | 4 | if __name__ == '__main__': 5 | uvicorn.run(app, host="0.0.0.0", port=8000) 6 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | fastapi==0.111.0 2 | loguru==0.7.2 3 | openai==1.30.1 4 | numpy==1.26.4 5 | BCEmbedding==0.1.5 6 | pymilvus==2.3.7 7 | langchain==0.2.0 8 | langchain-community==0.2.0 9 | unstructured==0.14.0 10 | pdfminer.six==20231228 11 | pillow_heif==0.16.0 12 | opencv-python==4.9.0.80 13 | pdf2image==1.17.0 14 | unstructured_inference==0.7.31 15 | pytesseract==0.3.10 16 | pikepdf==8.15.1 17 | python-docx==1.1.2 18 | dashscope==1.19.1 19 | oss2==2.18.5 -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | # Read requirements.txt, ignore comments 4 | try: 5 | with open("requirements.txt", "r") as f: 6 | REQUIRES = [line.split("#", 1)[0].strip() for line in f if line.strip()] 7 | except: 8 | print("'requirements.txt' not found!") 9 | REQUIRES = list() 10 | 11 | setup( 12 | name="FinRAG", 13 | version="0.0.1", 14 | include_package_data=True, 15 | author="AI4Finance Foundation", 16 | author_email="contact@ai4finance.org", 17 
| url="https://github.com/AI4Finance-Foundation/FinRAG", 18 | license="MIT", 19 | packages=find_packages(), 20 | install_requires=REQUIRES, 21 | description="FinRAG: Financial Retrieval Augmented Generation", 22 | long_description="""FinRAG""", 23 | classifiers=[ 24 | # Trove classifiers 25 | # Full list: https://pypi.python.org/pypi?%3Aaction=list_classifiers 26 | "License :: OSI Approved :: MIT License", 27 | "Programming Language :: Python", 28 | "Programming Language :: Python :: 3", 29 | "Programming Language :: Python :: 3.6", 30 | "Programming Language :: Python :: 3.7", 31 | "Programming Language :: Python :: 3.8", 32 | "Programming Language :: Python :: 3.9", 33 | "Programming Language :: Python :: 3.10", 34 | "Programming Language :: Python :: 3.11", 35 | "Programming Language :: Python :: Implementation :: CPython", 36 | "Programming Language :: Python :: Implementation :: PyPy", 37 | ], 38 | keywords="Financial Large Language Models, AI Agents, Retrieval Augmented Generation", 39 | platforms=["any"], 40 | python_requires=">=3.10, <3.12", 41 | ) 42 | -------------------------------------------------------------------------------- /test/cz.py: -------------------------------------------------------------------------------- 1 | from app.core.vectorstore.customer_milvus_client import CustomerMilvusClient 2 | if __name__ == '__main__': 3 | CustomerMilvusClient().delete_collection() -------------------------------------------------------------------------------- /test/http_test/notify 2.http: -------------------------------------------------------------------------------- 1 | 2 | POST http://39.96.174.204/api/medical-assistant/knowledge/file/vector/complete HTTP/1.1 3 | Content-Type: application/json;charset=utf-8 4 | 5 | { 6 | "syncId": "1", 7 | "status": 2 8 | } -------------------------------------------------------------------------------- /test/http_test/notify.http: -------------------------------------------------------------------------------- 1 | 2 | 
POST http://127.0.0.1:8000/notify_another/ HTTP/1.1 3 | Content-Type: application/json;charset=utf-8 4 | 5 | { 6 | "syncId": "", 7 | "status": 3 8 | } -------------------------------------------------------------------------------- /test/http_test/query_1.http: -------------------------------------------------------------------------------- 1 | 2 | POST http://60.205.147.142:8000/chat/ HTTP/1.1 3 | Content-Type: application/json;charset=utf-8 4 | 5 | { 6 | "chatId": "20046261673390243840005710", 7 | "chatMessages": [ 8 | { 9 | "chatMessageId": "m00001", 10 | "role": "SYSTEM", 11 | "rawContent": "你是⼀个医疗助⼿" 12 | }, 13 | { 14 | "chatId": "20046261673390243840005710", 15 | "chatMessageId": "20056261711947348049925987", 16 | "cid": "sinodata", 17 | "rawContent": "我感冒了吃什么药", 18 | "role": "USER" 19 | } 20 | ], 21 | "chatName": "知识问答助手", 22 | "initInputs": { 23 | "topK": 1, 24 | "score": 10, 25 | "categoryIds": [ 26 | "20026250115495299645444018,1010010101" 27 | ], 28 | "language": "zh" 29 | }, 30 | "initOpening": [ 31 | "我是知识问答助手", 32 | "你有什么想问我的吗?" 
33 | ], 34 | "ownerId": "123213123" 35 | } -------------------------------------------------------------------------------- /test/http_test/query_2.http: -------------------------------------------------------------------------------- 1 | 2 | POST http://127.0.0.1:8000/chat/ HTTP/1.1 3 | Content-Type: application/json;charset=utf-8 4 | 5 | {"chatId":"20046261673390243840005710", 6 | "chatMessages":[ 7 | { 8 | "chatMessageId": "m00001", 9 | "role": "SYSTEM", 10 | "rawContent": "你是⼀个医疗助⼿" 11 | }, 12 | {"chatId":"20046261673390243840005710","chatMessageId":"20056261758441543434247063", 13 | "cid":"sinodata","gmtCreate":1716559155000,"gmtModified":1716559155000,"id":"4","isDeleted":1,"rating":1,"rawContent":"我感冒了.我需要吃什么药","role":"USER","tokensUsed":0}, 14 | {"chatId":"20046261673390243840005710","chatMessageId":"20056261758454000517126380","cid":"sinodata","gmtCreate":1716559155000,"gmtModified":1716559155000,"id":"5","isDeleted":1,"rating":1,"rawContent":"你好,今天天气很棒啊!","role":"ASSISTANT","tokensUsed":0}, 15 | {"chatId":"20046261673390243840005710","chatMessageId":"20056261759779903897609217","cid":"sinodata","rawContent":"病毒的定义及其特点?","role":"USER"}], 16 | "chatName":"知识问答助手","initInputs": 17 | {"topK":1,"score":10,"categoryIds":["20026250122028766986246526,20026250123269316280320966","20026250129949416488964807"],"language":"zh"}, 18 | "initOpening":["我是知识问答助手","你有什么想问我的吗?"],"ownerId":"123213123"} -------------------------------------------------------------------------------- /test/http_test/test.http: -------------------------------------------------------------------------------- 1 | 2 | POST http://127.0.0.1:8000/update_vector/ HTTP/1.1 3 | Content-Type: application/json;charset=utf-8 4 | 5 | { 6 | "sysCategory" : [ 7 | { 8 | "fileStorages" : [ 9 | { 10 | "fileId" : "file00004", 11 | "fileName" : "py⼊⻔到放弃", 12 | "storagePath" : "stream/78/65/00106250790864562421764687.pdf", 13 | "fileSuffix" : "pdf", 14 | "publicUrl" : "http:\/\/www.baidu.com\/xxx\/xxx\/xxxpy.txt", 
15 | "fileSize" : "123456" 16 | }, 17 | { 18 | "fileId" : "file00005", 19 | "fileName" : "py⼊⻔到放弃2", 20 | "storagePath" : "stream/78/65/00106250790864562421764687.pdf", 21 | "fileSuffix" : "pdf", 22 | "publicUrl" : "http:\/\/www.baidu.com\/xxx\/xxx\/xxxpy2.txt", 23 | "fileSize" : "123456" 24 | } 25 | ], 26 | "subCategory" : [ 27 | { 28 | "fileStorages" : [ 29 | { 30 | "fileId" : "file00004", 31 | "fileName" : " 《医学病毒学》知识图谱-231229", 32 | "storagePath" : "stream/40/86/00106250791926853795844449.docx", 33 | "fileSuffix" : "docx", 34 | "publicUrl" : "http:\/\/www.baidu.com\/xxx\/xxx\/xxxpy.txt", 35 | "fileSize" : "123456" 36 | }, 37 | { 38 | "fileId" : "file00005", 39 | "fileName" : "《学生手册》2023版", 40 | "storagePath" : "stream/85/11/00106250792497027481606005.pdf", 41 | "fileSuffix" : "pdf", 42 | "publicUrl" : "http:\/\/www.baidu.com\/xxx\/xxx\/xxxpy2.txt", 43 | "fileSize" : "123456" 44 | } 45 | ], 46 | "categoryId" : "234324234", 47 | "parentId" : "1010010101", 48 | "categoryName" : "⼆级分类" 49 | } 50 | ], 51 | "categoryId" : "1010010101", 52 | "parentId" : "0", 53 | "categoryName" : "⼀级分类", 54 | "level" : 1 55 | } 56 | ], 57 | "syncId" : "123" 58 | } 59 | 60 | -------------------------------------------------------------------------------- /test/http_test/test2.http: -------------------------------------------------------------------------------- 1 | 2 | POST http://127.0.0.1:8000/update_vector/ HTTP/1.1 3 | Content-Type: application/json;charset=utf-8 4 | 5 | { 6 | "sysCategory" : [ 7 | { 8 | 9 | "subCategory" : [ 10 | { 11 | "fileStorages" : [ 12 | { 13 | "fileId" : "file00004", 14 | "fileName" : "《医学病毒学》知识图谱-231229", 15 | "storagePath" : "stream/40/86/00106250791926853795844449.docx", 16 | "fileSuffix" : "docx", 17 | "publicUrl" : "http:\/\/www.baidu.com\/xxx\/xxx\/xxxpy.txt", 18 | "fileSize" : "123456" 19 | } 20 | ], 21 | "categoryId" : "234324234", 22 | "parentId" : "1010010101", 23 | "categoryName" : "⼆级分类" 24 | } 25 | ], 26 | "categoryId" : "1010010101", 27 | 
"parentId" : "0", 28 | "categoryName" : "⼀级分类", 29 | "level" : 1 30 | } 31 | ], 32 | "syncId" : "123" 33 | } 34 | 35 | -------------------------------------------------------------------------------- /test/http_test/test3.http: -------------------------------------------------------------------------------- 1 | 2 | POST http://127.0.0.1:8000/update_vector/ HTTP/1.1 3 | Content-Type: application/json;charset=utf-8 4 | 5 | { 6 | "sysCategory": [ 7 | { 8 | "categoryId": "20026250115495299645444018", 9 | "categoryName": "基础医学", 10 | "categoryType": "KNOWLEDGE_CATEGORY", 11 | "cid": "sinodata", 12 | "description": "", 13 | "gmtCreate": 1716281565000, 14 | "gmtModified": 1716291331000, 15 | "id": "2", 16 | "isDeleted": 1, 17 | "level": 1, 18 | "parentId": "0", 19 | "path": "/0/20026250115495299645444018", 20 | "sort": 1, 21 | "subCategory": [ 22 | { 23 | "categoryId": "20026250120594315018249896", 24 | "categoryName": "解刨学", 25 | "categoryType": "KNOWLEDGE_CATEGORY", 26 | "cid": "sinodata", 27 | "description": "", 28 | "fileStorages": [ 29 | { 30 | "categoryId": "20026250120594315018249896", 31 | "cid": "sinodata", 32 | "fileId": "20036265406996757872647591", 33 | "fileMeatData": "null", 34 | "fileName": "临床实地局部解剖学.pdf", 35 | "fileSize": -1, 36 | "fileSuffix": "pdf", 37 | "gmtCreate": 1716646143000, 38 | "gmtModified": 1716646143000, 39 | "id": "1", 40 | "isDeleted": 1, 41 | "publicRead": 2, 42 | "storagePath": "stream/88/15/00106265406971424276483376.pdf", 43 | "storageType": 1 44 | } 45 | ], 46 | "gmtCreate": 1716281687000, 47 | "gmtModified": 1716291353000, 48 | "id": "8", 49 | "isDeleted": 1, 50 | "level": 2, 51 | "parentId": "20026250115495299645444018", 52 | "path": "/0/20026250115495299645444018/20026250120594315018249896", 53 | "sort": 1 54 | } 55 | ] 56 | } 57 | ], 58 | "syncId": "20036268858974255185924634" 59 | } -------------------------------------------------------------------------------- /test/http_test/test4.http: 
-------------------------------------------------------------------------------- 1 | GET http://127.0.0.1:8001/async-endpoint/ HTTP/1.1 2 | Content-Type: application/json;charset=utf-8 3 | 4 | {} -------------------------------------------------------------------------------- /test/test.json: -------------------------------------------------------------------------------- 1 | 2 | { 3 | "sysCategory" : [ 4 | { 5 | "fileStorages" : [ 6 | { 7 | "fileId" : "file00004", 8 | "fileName" : "py⼊⻔到放弃", 9 | "storagePath" : "stream/78/65/00106250790864562421764687.pdf", 10 | "fileSuffix" : "pdf", 11 | "publicUrl" : "http:\/\/www.baidu.com\/xxx\/xxx\/xxxpy.txt", 12 | "fileSize" : "123456" 13 | }, 14 | { 15 | "fileId" : "file00005", 16 | "fileName" : "py⼊⻔到放弃2", 17 | "storagePath" : "\/xxx\/xxx\/xxx2.txt", 18 | "fileSuffix" : "txt", 19 | "publicUrl" : "http:\/\/www.baidu.com\/xxx\/xxx\/xxxpy2.txt", 20 | "fileSize" : "123456" 21 | } 22 | ], 23 | "subCategory" : [ 24 | { 25 | "fileStorages" : [ 26 | { 27 | "fileId" : "file00004", 28 | "fileName" : " 《医学病毒学》知识图谱-231229", 29 | "storagePath" : "stream/40/86/00106250791926853795844449.docx", 30 | "fileSuffix" : "docx", 31 | "publicUrl" : "http:\/\/www.baidu.com\/xxx\/xxx\/xxxpy.txt", 32 | "fileSize" : "123456" 33 | }, 34 | { 35 | "fileId" : "file00005", 36 | "fileName" : "《学生手册》2023版", 37 | "storagePath" : "stream/85/11/00106250792497027481606005.pdf", 38 | "fileSuffix" : "pdf", 39 | "publicUrl" : "http:\/\/www.baidu.com\/xxx\/xxx\/xxxpy2.txt", 40 | "fileSize" : "123456" 41 | } 42 | ], 43 | "categoryId" : "234324234", 44 | "parentId" : "1010010101", 45 | "categoryName" : "⼆级分类" 46 | } 47 | ], 48 | "categoryId" : "1010010101", 49 | "parentId" : "0", 50 | "categoryName" : "⼀级分类", 51 | "level" : 1 52 | } 53 | ], 54 | "syncId" : "123" 55 | } 56 | 57 | -------------------------------------------------------------------------------- /test/test.py: -------------------------------------------------------------------------------- 1 | """ 2 | 
Author: wangjia 3 | Date: 2024-05-19 22:18:54 4 | LastEditors: wangjia 5 | LastEditTime: 2024-05-20 03:33:35 6 | Description: file content 7 | """ 8 | 9 | from pymilvus import Collection, CollectionSchema, DataType, FieldSchema, connections 10 | 11 | from app.core.bce.embedding_client import EmbeddingClient 12 | 13 | connections.connect("default", host="localhost", port="19530") 14 | embedding_client = EmbeddingClient("/data/WoLLM/bce-embedding-base_v1") 15 | 16 | # 读取txt文档,句子列表 17 | f = open("test.txt", "r", encoding="utf-8") 18 | doc = f.read().splitlines() 19 | 20 | embeddings = embedding_client.get_embedding(doc) 21 | 22 | fields = [ 23 | FieldSchema( 24 | name="id", dtype=DataType.VARCHAR, is_primary=True, auto_id=True, max_length=100 25 | ), 26 | FieldSchema(name="kb_name", dtype=DataType.VARCHAR, max_length=100), 27 | FieldSchema(name="file_name", dtype=DataType.VARCHAR, max_length=100), 28 | FieldSchema(name="chunk_id", dtype=DataType.INT64), 29 | FieldSchema(name="chunk_content", dtype=DataType.VARCHAR, max_length=100), 30 | FieldSchema(name="embedding", dtype=DataType.FLOAT_VECTOR, dim=768), # 向量字段 31 | ] 32 | 33 | schema = CollectionSchema(fields=fields) 34 | # 创建 collection 35 | # print("client is connected:", client.is_connected()) 36 | collection = Collection("test_collection3", schema=schema) 37 | # 添加索引 38 | index_params = { 39 | "metric_type": "L2", 40 | "index_type": "IVF_FLAT", 41 | "params": {"nlist": 128}, 42 | } 43 | # 字段 filmVector 创建索引 44 | collection.create_index("embedding", index_params) 45 | 46 | entities = [["test_kb"] * len(doc), doc, range(len(doc)), doc, embeddings] 47 | collection.insert(entities) 48 | # 记得在插入数据后调用 flush 来确保数据被写入到磁盘 49 | collection.flush() 50 | print(collection) 51 | collection.load() 52 | -------------------------------------------------------------------------------- /test/test2.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Author: wangjia 3 | Date: 2024-05-19 
22:42:44 4 | LastEditors: wangjia 5 | LastEditTime: 2024-05-19 22:42:47 6 | Description: file content 7 | ''' 8 | from pymilvus import connections, utility 9 | 10 | # 创建连接 11 | connections.connect("default", host='localhost', port='19530') 12 | 13 | # 检查连接是否正常 14 | try: 15 | status = utility.get_connection_addr(alias="default") 16 | print("Connected to Milvus server:", status) 17 | except Exception as e: 18 | print("Something went wrong:", e) 19 | -------------------------------------------------------------------------------- /test/test3.py: -------------------------------------------------------------------------------- 1 | """ 2 | Author: wangjia 3 | Date: 2024-05-21 01:29:51 4 | LastEditors: wangjia 5 | LastEditTime: 2024-05-21 10:53:28 6 | Description: file content 7 | """ 8 | 9 | from app.core.bce.embedding_client import EmbeddingClient 10 | from app.core.preprocessor.file_processor import FileProcesser 11 | 12 | if __name__ == "__main__": 13 | file_path = "example/test3.txt" 14 | l1_kb = "test1" 15 | l2_kb = "test2" 16 | embedding_model = "/data/gpu/base_models/bge-large-zh-v1.5" 17 | lf = FileProcesser(l1_kb, l2_kb, file_path) 18 | embedding_client = EmbeddingClient(embedding_model) 19 | 20 | docs = lf.split_file_to_docs() 21 | docs_content = [doc.page_content for doc in docs] 22 | embeddings = embedding_client.get_embedding(docs_content) 23 | print(embeddings[0]) 24 | print(embeddings.shape) 25 | -------------------------------------------------------------------------------- /test/test_async.py: -------------------------------------------------------------------------------- 1 | # main.py 2 | from fastapi import FastAPI 3 | import asyncio 4 | import uvicorn 5 | app = FastAPI() 6 | 7 | async def print_hello_world(): 8 | # 模拟异步 I/O 9 | # await asyncio.sleep(1) 10 | 11 | print("Hello, World!") 12 | 13 | @app.get("/async-endpoint") 14 | async def root(): 15 | # 模拟异步 I/O,这里可以是任何 I/O 操作,例如读写数据库、文件等 16 | asyncio.create_task(print_hello_world()) 17 | return 
{"message": "Task started"} 18 | 19 | if __name__ == '__main__': 20 | uvicorn.run(app, host="0.0.0.0", port=8001) -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Author: wangjia 3 | Date: 2024-05-19 00:04:14 4 | LastEditors: wangjia 5 | LastEditTime: 2024-05-26 14:38:05 6 | Description: file content 7 | ''' 8 | from loguru import logger 9 | from datetime import datetime 10 | import time 11 | import sys 12 | 13 | 14 | class Logger: 15 | 16 | @classmethod 17 | def get_logger(self): 18 | folder_ = "logs/" 19 | prefix_ = "mylog-" 20 | rotation_ = "10 MB" 21 | retention_ = "30 days" 22 | encoding_ = "utf-8" 23 | backtrace_ = True 24 | diagnose_ = True 25 | 26 | # 格式里面添加了process和thread记录,方便查看多进程和线程程序 27 | format_ = ( 28 | "{time:YYYY-MM-DD HH:mm:ss.SSS}" 29 | "|{level: <7}" 30 | "|{name}:{function}:{line}" 31 | "|{message}") 32 | 33 | # 这里面采用了层次式的日志记录方式,就是低级日志文件会记录比他高的所有级别日志,这样可以做到低等级日志最丰富,高级别日志更少更关键 34 | # debug 35 | logger.add( 36 | folder_ + prefix_ + "debug.log", 37 | level="DEBUG", 38 | backtrace=backtrace_, 39 | diagnose=diagnose_, 40 | format=format_, 41 | colorize=False, 42 | rotation=rotation_, 43 | retention=retention_, 44 | encoding=encoding_, 45 | filter=lambda record: record["level"].no >= logger.level("DEBUG"). 46 | no, 47 | ) 48 | 49 | # info 50 | logger.add( 51 | folder_ + prefix_ + "info.log", 52 | level="INFO", 53 | backtrace=backtrace_, 54 | diagnose=diagnose_, 55 | format=format_, 56 | colorize=False, 57 | rotation=rotation_, 58 | retention=retention_, 59 | encoding=encoding_, 60 | filter=lambda record: record["level"].no >= logger.level("INFO"). 
61 | no, 62 | ) 63 | 64 | # warning 65 | logger.add( 66 | folder_ + prefix_ + "warning.log", 67 | level="WARNING", 68 | backtrace=backtrace_, 69 | diagnose=diagnose_, 70 | format=format_, 71 | colorize=False, 72 | rotation=rotation_, 73 | retention=retention_, 74 | encoding=encoding_, 75 | filter=lambda record: record["level"].no >= logger.level("WARNING") 76 | .no, 77 | ) 78 | 79 | # error 80 | logger.add( 81 | folder_ + prefix_ + "error.log", 82 | level="ERROR", 83 | backtrace=backtrace_, 84 | diagnose=diagnose_, 85 | format=format_, 86 | colorize=False, 87 | rotation=rotation_, 88 | retention=retention_, 89 | encoding=encoding_, 90 | filter=lambda record: record["level"].no >= logger.level("ERROR"). 91 | no, 92 | ) 93 | 94 | # critical 95 | logger.add( 96 | folder_ + prefix_ + "critical.log", 97 | level="CRITICAL", 98 | backtrace=backtrace_, 99 | diagnose=diagnose_, 100 | format=format_, 101 | colorize=False, 102 | rotation=rotation_, 103 | retention=retention_, 104 | encoding=encoding_, 105 | filter=lambda record: record["level"].no >= logger.level("CRITICAL" 106 | ).no, 107 | ) 108 | 109 | logger.add( 110 | sys.stderr, 111 | level="CRITICAL", 112 | backtrace=backtrace_, 113 | diagnose=diagnose_, 114 | format=format_, 115 | colorize=True, 116 | filter=lambda record: record["level"].no >= logger.level("CRITICAL" 117 | ).no, 118 | ) 119 | 120 | return logger 121 | 122 | 123 | # 自定义日志输出 124 | logger = Logger.get_logger() 125 | 126 | # 获取当前时间,并格式化日期时间 127 | current_time = datetime.now() 128 | current_date = current_time.strftime("%Y-%m-%d") 129 | formatted_time = current_time.strftime("%Y-%m-%d %H:%M:%S") 130 | 131 | 132 | # 计算时间函数 133 | def timeit(func): 134 | 135 | def wrapper(*args, **kw): 136 | start_time = time.time() 137 | result = func(*args, **kw) 138 | cost_time = time.time() - start_time 139 | print("==" * 25) 140 | print("Current Function [%s] run time is %s s" % 141 | (func.__name__, cost_time)) 142 | print("==" * 25) 143 | return result 144 | 145 | 
return wrapper 146 | --------------------------------------------------------------------------------