├── .idea
├── .gitignore
├── FinRAG.iml
├── inspectionProfiles
│ ├── Project_Default.xml
│ └── profiles_settings.xml
├── misc.xml
├── modules.xml
└── vcs.xml
├── Dockerfile
├── LICENSE
├── README.md
├── app
├── __init__.py
├── core
│ ├── bce
│ │ ├── __init__.py
│ │ ├── embedding_client.py
│ │ └── rerank_client.py
│ ├── chat
│ │ ├── __init__.py
│ │ ├── open_chat.py
│ │ └── rag_chat.py
│ ├── loader
│ │ └── pdf_loader.py
│ ├── preprocessor
│ │ └── file_processor.py
│ ├── splitter
│ │ ├── __init__.py
│ │ ├── chinese_text_splitter.py
│ │ └── zh_title_enhance.py
│ └── vectorstore
│ │ ├── __init__.py
│ │ └── customer_milvus_client.py
├── finrag_server.py
├── models
│ ├── __init__.py
│ ├── dialog.py
│ └── status.py
└── oss
│ ├── __init__.py
│ └── download_file.py
├── bin
└── start.sh
├── conf
└── config.py
├── docker
└── docker-compose.yml
├── example
├── test.pdf
├── test1.pdf
├── test2.docx
├── test3.txt
└── ~$test2.docx
├── img.png
├── img_1.png
├── img_2.png
├── main.py
├── requirements.txt
├── setup.py
├── test
├── cz.py
├── http_test
│ ├── notify 2.http
│ ├── notify.http
│ ├── query_1.http
│ ├── query_2.http
│ ├── test.http
│ ├── test2.http
│ ├── test3.http
│ └── test4.http
├── test.json
├── test.py
├── test2.py
├── test3.py
└── test_async.py
└── utils.py
/.idea/.gitignore:
--------------------------------------------------------------------------------
1 | # Default ignored files
2 | /shelf/
3 | /workspace.xml
4 |
--------------------------------------------------------------------------------
/.idea/FinRAG.iml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
--------------------------------------------------------------------------------
/.idea/inspectionProfiles/Project_Default.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
179 |
180 |
181 |
--------------------------------------------------------------------------------
/.idea/inspectionProfiles/profiles_settings.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
--------------------------------------------------------------------------------
/.idea/misc.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
--------------------------------------------------------------------------------
/.idea/modules.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
--------------------------------------------------------------------------------
/.idea/vcs.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
--------------------------------------------------------------------------------
/Dockerfile:
--------------------------------------------------------------------------------
1 | FROM python:3.10-slim
2 | # 镜像元信息
3 | LABEL MAINTAINER=wangjia
4 | # 环境设置
5 | ENV LANG=C.UTF-8
6 | ENV TZ=Asia/Shanghai
7 | WORKDIR /FinRAG
8 | COPY . /FinRAG
9 | RUN pip install -r requirements.txt -i https://mirrors.aliyun.com/pypi/simple
10 | EXPOSE 8000
11 | CMD ["/bin/bash", "/bin/start.sh"]
12 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright 2024 AI4Finance Foundation Inc.
190 | All rights reserved.
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
203 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # FinRAG: Financial Retrieval Augmented Generation
2 |
3 | [](https://discord.gg/trsr8SXpW5)
4 |
5 | 
6 |
7 | ## 1. 准备工作
8 |
9 | ### 1.1 安装miniconda
10 | #### 网址
11 | `https://docs.anaconda.com/free/miniconda/miniconda-other-installer-links/#linux-installers`
12 | #### python3.10.14
13 | `wget https://repo.anaconda.com/miniconda/Miniconda3-py310_24.4.0-0-Linux-x86_64.sh`
14 | #### 配置环境变量
15 | `export PATH=$HOME/miniconda3/bin:$PATH`
16 |
17 | ### 1.2 启动Milvus向量数据库
18 | - 使用docker-compose启动Milvus服务
19 | ```
20 | cd docker #切换至docker配置目录环境
21 | docker-compose up -d #启动项目中的服务,并在后台以守护进程方式运行
22 | ```
23 | 如果对docker不了解,可以看下以下文章: \
24 | [docker-compose快速入门](https://blog.csdn.net/m0_37899908/article/details/131268835) \
25 | [docker-compose命令解读](https://blog.csdn.net/weixin_42494218/article/details/135986248) \
26 | 术语说明: \
27 | 守护进程: 是一类在后台运行的特殊进程,用于执行特定的系统任务,会一直存在。如果以非守护进程启动,服务容易被终止。
28 |
29 | - Milvus 前端展示地址
30 | `http://{ip}:3100/#/` 把ip替换为你所在服务器的ip地址即可
31 |
32 | ### 1.3 Embedding以及Rerank模型下载
33 | - 新建/data/WoLLM 目录
34 | - 将以下两个模型下载到新建的目录中
35 | - Embedding Model 下载:`git clone https://www.modelscope.cn/maidalun/bce-embedding-base_v1.git`
36 | - Rerank Model 下载:`git clone https://www.modelscope.cn/maidalun/bce-reranker-base_v1.git`
37 | 说明: \
38 | Embedding Model: 主要是完成将自然语言文本转化为固定维度向量的工作,主要在知识库的建模,用户查询query表示时会应用。 \
39 | Rerank Model:对结果进行重排操作。 \
40 | 这里采用的都是bce的模型,因为其在RAG上表现较好,可以参考资料,了解一下背景: \
41 | [BCE Embedding技术报告](https://zhuanlan.zhihu.com/p/681370855) \
42 | 下载好两个模型后, 将模型放到指定的位置,并更新项目conf.config.py文件中EMBEDDING_MODEL和RERANK_MODEL对应参数的路径(和模型路径保持一致)。如图:
43 | 
44 |
45 |
46 | ### 1.4 安装依赖及修改配置信息
47 | #### 新建python虚拟环境
48 | `python -m venv .venv`
49 | #### 激活环境
50 | `source .venv/bin/activate`
51 | #### 安装项目依赖
52 | `pip install -r requirements.txt`
53 |
54 | ### 1.5 修改配置文件
55 |
56 | - 配置文件在conf/config.py
57 | - 配置文件的各项信息修改为自己的信息, 每个变量已经加了详细注释.
58 | - 可能需要改动的参数一般就是两个模型文件目录,如图:
59 | 
60 | - 如果你需要修改端口,或者服务器变更,你需要修改docker.docker-compose.yml中的配置参数,一般就是修改ip和端口。
61 | 
62 |
63 | ### 1.6 修改环境变量
64 |
65 | - 将.env_user复制为.env `cp .env_user .env` [重要,重要,重要. 必须复制一下]
66 | - 修改.env的LLM以及OSS相关变量信息(目前只需要复制一下即可, 不需要修改里面的内容了)
67 |
68 | ## 2. 启动App
69 | ### 2.1 第一种方式启动
70 | `python main.py`
71 | ### 2.2 第二种方式启动(为之后打进docker做准备,目前先用第一种方式)
72 | `bin/start.sh`
73 |
74 |
75 | ToDo
76 | - [ ] rag 优化
77 | - [ ] 多级索引优化
78 | - [ ] 多查询优化
79 | - [ ] query优化
80 | - [ ] 解析优化
81 |
--------------------------------------------------------------------------------
/app/__init__.py:
--------------------------------------------------------------------------------
1 | '''
2 | Author: xubing
3 | Date: 2024-05-19 00:03:21
4 | LastEditors: xubing
5 | LastEditTime: 2024-05-19 00:03:22
6 | Description: file content
7 | '''
8 |
--------------------------------------------------------------------------------
/app/core/bce/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AI4Finance-Foundation/FinRAG/790c6158596747a3ea58ea1dae7972ac0f18d3dd/app/core/bce/__init__.py
--------------------------------------------------------------------------------
/app/core/bce/embedding_client.py:
--------------------------------------------------------------------------------
1 | """
2 | Author: wangjia
3 | Date: 2024-05-20 03:31:26
4 | LastEditors: wangjia
5 | LastEditTime: 2024-05-21 11:19:43
6 | Description: file content
7 | """
8 |
9 | import numpy as np
10 | from BCEmbedding import EmbeddingModel
11 |
12 | from conf.config import DEVICE
13 |
14 |
class EmbeddingClient:
    """Thin wrapper around BCEmbedding's ``EmbeddingModel``."""

    def __init__(self, model_name_or_path) -> None:
        # Load the model onto the device configured in conf.config (DEVICE).
        self.model = EmbeddingModel(
            model_name_or_path=model_name_or_path,
            device=DEVICE,
            trust_remote_code=True,
        )

    def get_embedding(self, sentences):
        """Encode *sentences* and return their embedding vectors."""
        return self.model.encode(sentences)
27 |
28 |
if __name__ == "__main__":

    # Smoke test: embed every line of a local text file and report the
    # resulting matrix shape.  Use a context manager so the file handle is
    # always closed (the previous version leaked it).
    with open("test.txt", "r", encoding="utf-8") as f:
        doc = f.read().splitlines()
    embedding_client = EmbeddingClient("/data/WoLLM/bce-embedding-base_v1")
    embedding = embedding_client.get_embedding(doc)
    print(embedding.shape)
--------------------------------------------------------------------------------
/app/core/bce/rerank_client.py:
--------------------------------------------------------------------------------
1 | '''
2 | Author: wangjia
3 | Date: 2024-05-23 20:40:21
4 | LastEditors: wangjia
5 | LastEditTime: 2024-05-23 21:00:07
6 | Description: file content
7 | '''
8 | from BCEmbedding import RerankerModel
9 |
10 |
class RerankClient:
    """Thin wrapper around BCEmbedding's ``RerankerModel``."""

    def __init__(self, rerank_model) -> None:
        self.model = RerankerModel(rerank_model,
                                   trust_remote_code=True,)

    def rerank(self, query, massages):
        """Rerank candidate passages by relevance to *query*.

        ``massages`` (name kept for backward compatibility with keyword
        callers; presumably a typo for "messages") is the list of candidate
        texts.  Returns whatever ``RerankerModel.rerank`` returns.
        """
        # RerankerModel.rerank builds the (query, passage) pairs internally,
        # so the previous manual sentence_pairs construction was dead code.
        rerank_results = self.model.rerank(query, massages)
        return rerank_results
20 |
--------------------------------------------------------------------------------
/app/core/chat/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AI4Finance-Foundation/FinRAG/790c6158596747a3ea58ea1dae7972ac0f18d3dd/app/core/chat/__init__.py
--------------------------------------------------------------------------------
/app/core/chat/open_chat.py:
--------------------------------------------------------------------------------
1 | '''
2 | Author: xubing
3 | Date: 2024-05-22 23:43:54
4 | LastEditors: xubing
5 | LastEditTime: 2024-05-23 19:44:13
6 | Description: file content
7 | '''
8 | import os
9 |
10 | from openai import OpenAI
11 | from utils import logger
12 |
class OpenChat:
    """Plain (non-RAG) chat client backed by DashScope's OpenAI-compatible API."""

    def __init__(self) -> None:
        self.client = OpenAI(
            api_key=os.getenv("DASHSCOPE_API_KEY"),  # 如果您没有配置环境变量,请在此处用您的API Key进行替换
            base_url="https://dashscope.aliyuncs.com/compatible-mode/v1",  # 填写DashScope SDK的base_url
        )

    def chat(self, messages):
        """Send *messages* (OpenAI chat format) to qwen-plus and return the reply text."""
        logger.info(str(messages))
        completion = self.client.chat.completions.create(
            model="qwen-plus",
            messages=messages,
            stream=False
        )
        result = completion.choices[0].message.content
        # Route the reply through the app logger instead of a leftover
        # debug print, matching how the request is logged above.
        logger.info(result)
        return result
29 |
if __name__ == "__main__":
    # Demo: convert raw chat records into the OpenAI message format and
    # send them through OpenChat.
    oc = OpenChat()
    raw_messsages = [
        {"chatMessageId": "m00001", "role": "system", "rawContent": "你是⼀个医疗助⼿"},
        {"chatMessageId": "m00002", "role": "user", "rawContent": "我感冒了吃什么药"},
        {"chatMessageId": "m00003", "role": "assistant", "rawContent": "你要是999感冒灵"},
        {"chatMessageId": "m00004", "role": "user", "rawContent": "我吃了999感觉没什么⽤"},
    ]
    # Keep only the fields the API understands, lower-casing the role.
    messages = []
    for record in raw_messsages:
        messages.append({
            "role": record.get("role").lower(),
            "content": record.get("rawContent"),
        })
    print(messages)
    oc.chat(messages)
63 |
64 |
--------------------------------------------------------------------------------
/app/core/chat/rag_chat.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | from openai import OpenAI
4 |
5 |
class RAGChat:
    """Chat client intended for retrieval-augmented generation.

    NOTE(review): retrieval is not implemented yet — messages are currently
    forwarded to the LLM without any retrieved context.
    """

    def __init__(self) -> None:
        self.client = OpenAI(
            api_key=os.getenv("DASHSCOPE_API_KEY"),  # 如果您没有配置环境变量,请在此处用您的API Key进行替换
            base_url="https://dashscope.aliyuncs.com/compatible-mode/v1",  # 填写DashScope SDK的base_url
        )

    def chat(self, messages):
        """Send *messages* (OpenAI chat format) to qwen-plus and return the reply text."""
        # TODO: embed the latest user query, retrieve supporting documents
        # from the vector store, and inject them into the prompt before the
        # LLM call (the previous `query`/`query_emb` placeholders were dead
        # code and have been removed).
        completion = self.client.chat.completions.create(
            model="qwen-plus",
            messages=messages,
            stream=False
        )
        result = completion.choices[0].message.content
        return result
--------------------------------------------------------------------------------
/app/core/loader/pdf_loader.py:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 | import base64
5 | import os
6 | from typing import Any, Callable, List, Union
7 |
8 | import fitz
9 | import numpy as np
10 | from langchain.document_loaders.unstructured import UnstructuredFileLoader
11 | from paddleocr import PaddleOCR
12 | from tqdm import tqdm
13 | from unstructured.partition.text import partition_text
14 |
15 | ocr_engine = PaddleOCR(use_angle_cls=True, lang="ch", use_gpu=True, show_log=False)
16 |
17 |
class UnstructuredPaddlePDFLoader(UnstructuredFileLoader):
    """Loader that OCRs each PDF page with PaddleOCR, then partitions the text."""

    def __init__(
        self,
        file_path: Union[str, List[str]],
        mode: str = "single",
        **unstructured_kwargs: Any,
    ):
        """Initialize with file path."""
        # Reuse the module-level OCR engine; constructing PaddleOCR per
        # instance would be expensive.
        self.ocr_engine = ocr_engine
        super().__init__(file_path=file_path, mode=mode, **unstructured_kwargs)

    def _get_elements(self) -> List:
        def pdf_ocr_txt(filepath, dir_path="tmp"):
            """Render every page, OCR it, and write the text next to the PDF."""
            full_dir_path = os.path.join(os.path.dirname(filepath), dir_path)
            if not os.path.exists(full_dir_path):
                os.makedirs(full_dir_path)
            txt_file_path = os.path.join(
                full_dir_path, "{}.txt".format(os.path.split(filepath)[-1])
            )
            doc = fitz.open(filepath)
            try:
                with open(txt_file_path, "w", encoding="utf-8") as fout:
                    for i in tqdm(range(doc.page_count)):
                        page = doc.load_page(i)
                        pix = page.get_pixmap()
                        # Pixmap.samples is already the raw pixel buffer;
                        # reshape it directly.  The previous base64
                        # encode/decode round-trip was a byte-for-byte no-op,
                        # and tmp.png was never actually written.
                        img_array = np.frombuffer(pix.samples, dtype=np.uint8).reshape(
                            (pix.h, pix.w, pix.n)
                        )
                        result = self.ocr_engine.ocr(img_array)
                        result = [line for line in result if line]
                        ocr_result = [i[1][0] for line in result for i in line]
                        fout.write("\n".join(ocr_result))
            finally:
                # Always release the document handle, even if OCR fails.
                doc.close()
            return txt_file_path

        txt_file_path = pdf_ocr_txt(self.file_path)
        return partition_text(filename=txt_file_path, **self.unstructured_kwargs)
67 |
68 |
69 | if __name__ == "__main__":
70 |
71 | ...
72 | # from paddleocr import PaddleOCR
73 | # from app.core.text_splitter.chinese_text_splitter import ChineseTextSplitter
74 |
75 | # file_path = "附件4-1:广银理财幸福理财日添利开放式理财计划第2期产品说明书(百信银行B份额).pdf"
76 | # ocr_engine = PaddleOCR(use_angle_cls=True, lang="ch", use_gpu=True, show_log=False)
77 | # loader = UnstructuredPaddlePDFLoader(file_path, ocr_engine)
78 | # texts_splitter = ChineseTextSplitter(pdf=True, sentence_size=250)
79 | # docs = loader.load_and_split(texts_splitter)
80 | # print(docs)
81 |
82 |
--------------------------------------------------------------------------------
/app/core/preprocessor/file_processor.py:
--------------------------------------------------------------------------------
1 |
2 | from langchain.text_splitter import RecursiveCharacterTextSplitter
3 | from langchain_community.document_loaders import (
4 | TextLoader, UnstructuredFileLoader, UnstructuredPDFLoader,
5 | UnstructuredWordDocumentLoader)
6 |
7 | from app.core.splitter import ChineseTextSplitter, zh_title_enhance
8 | from conf import config
9 | from utils import logger
10 |
# Second-stage splitter: re-splits any document that is still longer than
# ~300 characters after the first (ChineseTextSplitter) pass, trying the
# separators below in order — newlines first, then sentence-ending
# punctuation, then clause punctuation, and finally plain spaces.
text_splitter = RecursiveCharacterTextSplitter(
    separators=[
        "\n",
        ".",
        "。",
        "!",
        "!",
        "?",
        "?",
        ";",
        ";",
        "……",
        "…",
        "、",
        ",",
        ",",
        " ",
    ],
    chunk_size=300,
    # length_function=num_tokens,
)
32 |
33 |
class FileProcesser:
    """Splits an input file (txt / pdf / docx) into langchain documents."""

    def __init__(self):
        logger.info(f"Success init file processor")

    def split_file_to_docs(self,
                           file_path,
                           sentence_size=config.SENTENCE_SIZE):
        """Load *file_path*, split it into sentence-sized docs, then re-split long ones."""
        logger.info("开始解析文件,文件越大,解析所需时间越长,大文件请耐心等待...")
        suffix = file_path.split('.')[-1].lower()

        if suffix == "txt":
            loader = TextLoader(file_path, autodetect_encoding=True)
            splitter = ChineseTextSplitter(pdf=False,
                                           sentence_size=sentence_size)
        elif suffix == "pdf":
            loader = UnstructuredPDFLoader(file_path)
            splitter = ChineseTextSplitter(pdf=True,
                                           sentence_size=sentence_size)
        elif suffix == "docx":
            loader = UnstructuredWordDocumentLoader(file_path, mode="elements")
            splitter = ChineseTextSplitter(pdf=False,
                                           sentence_size=sentence_size)
        else:
            raise TypeError("文件类型不支持,目前仅支持:[txt,pdf,docx]")
        docs = loader.load_and_split(splitter)
        # docs = zh_title_enhance(docs)
        # Second pass: break up any doc that is still longer than the
        # module-level text_splitter's chunk size.
        logger.info(f"before 2nd split doc lens: {len(docs)}")
        docs = text_splitter.split_documents(docs)
        logger.info(f"after 2nd split doc lens: {len(docs)}")
        self.docs = docs
        return docs
70 |
--------------------------------------------------------------------------------
/app/core/splitter/__init__.py:
--------------------------------------------------------------------------------
1 | from .chinese_text_splitter import ChineseTextSplitter
2 | from .zh_title_enhance import zh_title_enhance
3 |
--------------------------------------------------------------------------------
/app/core/splitter/chinese_text_splitter.py:
--------------------------------------------------------------------------------
1 | '''
2 | Author: wangjia
3 | Date: 2024-05-20 21:23:22
4 | LastEditors: wangjia
5 | LastEditTime: 2024-05-20 21:25:43
6 | Description: file content
7 | '''
8 | import re
9 | from typing import List
10 |
11 | from langchain.text_splitter import CharacterTextSplitter
12 |
13 |
class ChineseTextSplitter(CharacterTextSplitter):
    """Sentence-level splitter for Chinese text.

    Splits on Chinese/Western sentence-ending punctuation, then keeps
    re-splitting any piece longer than ``sentence_size`` on progressively
    weaker separators (commas, whitespace runs, single spaces).
    """

    def __init__(self, pdf: bool = False, sentence_size: int = 200, **kwargs):
        # pdf: apply extra whitespace normalization typical of PDF-extracted text.
        # sentence_size: maximum length of a returned piece before re-splitting.
        super().__init__(**kwargs)
        self.pdf = pdf
        self.sentence_size = sentence_size

    def split_text1(self, text: str) -> List[str]:
        """Legacy splitter kept for reference: split on terminal punctuation only."""
        if self.pdf:
            text = re.sub(r"\n{3,}", "\n", text)
            text = re.sub('\s', ' ', text)
            text = text.replace("\n\n", "")
        sent_sep_pattern = re.compile('([﹒﹔﹖﹗.。!?]["’”」』]{0,2}|(?=["‘“「『]{1,2}|$))')  # del :;
        sent_list = []
        for ele in sent_sep_pattern.split(text):
            # Separator matches are glued onto the preceding sentence.
            if sent_sep_pattern.match(ele) and sent_list:
                sent_list[-1] += ele
            elif ele:
                sent_list.append(ele)
        return sent_list

    def split_text(self, text: str) -> List[str]:  ## logic here still needs further optimization
        if self.pdf:
            # Normalize PDF extraction artifacts: collapse blank-line runs,
            # turn all whitespace into single spaces, drop double newlines.
            text = re.sub(r"\n{3,}", r"\n", text)
            text = re.sub('\s', " ", text)
            text = re.sub("\n\n", "", text)

        text = re.sub(r'([;;.!?。!?\?])([^”’])', r"\1\n\2", text)  # single-character sentence terminators
        text = re.sub(r'(\.{6})([^"’”」』])', r"\1\n\2", text)  # English ellipsis ("......")
        text = re.sub(r'(\…{2})([^"’”」』])', r"\1\n\2", text)  # Chinese ellipsis ("……")
        text = re.sub(r'([;;!?。!?\?]["’”」』]{0,2})([^;;!?,。!?\?])', r'\1\n\2', text)
        # A closing quote preceded by a terminator ends the sentence, so the
        # split marker \n goes after the quote; note the rules above carefully
        # keep the quote characters intact.
        text = text.rstrip()  # drop any extra trailing \n at the paragraph end
        # Semicolons, dashes and English double quotes are deliberately
        # ignored here; add similar rules if they are ever needed.
        ls = [i for i in text.split("\n") if i]
        for ele in ls:
            if len(ele) > self.sentence_size:
                # Still too long: re-split on commas/periods first.
                ele1 = re.sub(r'([,,.]["’”」』]{0,2})([^,,.])', r'\1\n\2', ele)
                ele1_ls = ele1.split("\n")
                for ele_ele1 in ele1_ls:
                    if len(ele_ele1) > self.sentence_size:
                        # Then on newline or multi-space runs.
                        ele_ele2 = re.sub(r'([\n]{1,}| {2,}["’”」』]{0,2})([^\s])', r'\1\n\2', ele_ele1)
                        ele2_ls = ele_ele2.split("\n")
                        for ele_ele2 in ele2_ls:
                            if len(ele_ele2) > self.sentence_size:
                                # Last resort: split on single spaces.
                                ele_ele3 = re.sub('( ["’”」』]{0,2})([^ ])', r'\1\n\2', ele_ele2)
                                ele2_id = ele2_ls.index(ele_ele2)
                                ele2_ls = ele2_ls[:ele2_id] + [i for i in ele_ele3.split("\n") if i] + ele2_ls[
                                    ele2_id + 1:]
                        ele_id = ele1_ls.index(ele_ele1)
                        ele1_ls = ele1_ls[:ele_id] + [i for i in ele2_ls if i] + ele1_ls[ele_id + 1:]

                # NOTE(review): `id` shadows the builtin; left unchanged here.
                id = ls.index(ele)
                ls = ls[:id] + [i for i in ele1_ls if i] + ls[id + 1:]
        return ls
68 |
--------------------------------------------------------------------------------
/app/core/splitter/zh_title_enhance.py:
--------------------------------------------------------------------------------
1 | import re
2 | from typing import List
3 |
4 | from langchain.docstore.document import Document
5 |
6 |
def under_non_alpha_ratio(text: str, threshold: float = 0.5) -> bool:
    """Checks if the proportion of non-alpha characters in the text snippet exceeds a given
    threshold. This helps prevent text like "-----------BREAK---------" from being tagged
    as a title or narrative text. The ratio does not count spaces.

    Parameters
    ----------
    text
        The input string to test
    threshold
        If the proportion of non-alpha characters exceeds this threshold, the function
        returns False

    Returns
    -------
    bool
        True when the alpha ratio is strictly below *threshold*; False for
        empty or all-whitespace input.
    """
    if not text:
        return False

    # Whitespace characters are excluded from both counts.
    alpha_count = sum(1 for char in text if char.strip() and char.isalpha())
    total_count = sum(1 for char in text if char.strip())
    # Explicit guard instead of the previous bare `except:`, which silently
    # swallowed every error, not just division by zero.
    if total_count == 0:
        return False
    return alpha_count / total_count < threshold
30 |
31 |
def is_possible_title(
    text: str,
    title_max_word_length: int = 20,
    non_alpha_threshold: float = 0.5,
) -> bool:
    """Decide whether *text* looks like a (Chinese) title.

    Parameters
    ----------
    text
        The candidate string.
    title_max_word_length
        Maximum length a title may have (compared against ``len(text)``).
    non_alpha_threshold
        Alpha-ratio threshold forwarded to :func:`under_non_alpha_ratio`.
    """
    # An empty string can never be a title.
    if not text:
        print("Not a title. Text is empty.")
        return False

    # A title should not end in punctuation.
    ends_in_punct_re = re.compile(r"[^\w\s]\Z")
    if ends_in_punct_re.search(text):
        return False

    # Titles are short; reject anything longer than the configured limit.
    # NOTE(robinson) - splitting on spaces here instead of word tokenizing because it
    # is less expensive and actual tokenization doesn't add much value for the length check
    if len(text) > title_max_word_length:
        return False

    # Reject text dominated by non-alphabetic characters (digits, symbols).
    if under_non_alpha_ratio(text, threshold=non_alpha_threshold):
        return False

    # NOTE(robinson) - Prevent flagging salutations like "To My Dearest Friends," as titles
    if text.endswith((",", ".", ",", "。")):
        return False

    if text.isnumeric():
        print(f"Not a title. Text is all numeric:\n\n{text}")  # type: ignore
        return False

    # Expect at least one digit within the first five characters,
    # e.g. a numbered section heading such as "1. 引言".
    prefix = text if len(text) < 5 else text[:5]
    if not any(ch.isnumeric() for ch in prefix):
        return False

    return True
88 |
89 |
def zh_title_enhance(docs: List[Document]) -> List[Document]:
    """Tag title chunks and prefix following chunks with their title.

    Walks *docs* in order; whenever a chunk looks like a title (see
    :func:`is_possible_title`) it is marked with
    ``metadata['category'] = 'cn_Title'`` and remembered. Every
    subsequent non-title chunk is rewritten to mention that title so it
    keeps its context after splitting.

    Parameters
    ----------
    docs
        The split document chunks, in original order.

    Returns
    -------
    List[Document]
        The same list, mutated in place. The original fell through and
        returned ``None`` for empty input, violating the annotated
        return type; an empty list is now returned instead.
    """
    if not docs:
        # Keep the original diagnostic, but still honour the return type.
        print("文件不存在")
        return docs

    title = None
    for doc in docs:
        if is_possible_title(doc.page_content):
            doc.metadata['category'] = 'cn_Title'
            title = doc.page_content
        elif title:
            doc.page_content = f"下文与({title})有关。{doc.page_content}"
    return docs
102 |
--------------------------------------------------------------------------------
/app/core/vectorstore/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AI4Finance-Foundation/FinRAG/790c6158596747a3ea58ea1dae7972ac0f18d3dd/app/core/vectorstore/__init__.py
--------------------------------------------------------------------------------
/app/core/vectorstore/customer_milvus_client.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | from pymilvus import (Collection, CollectionSchema, DataType, FieldSchema,
4 | MilvusClient, connections, utility)
5 |
6 | from app.core.bce.embedding_client import EmbeddingClient
7 | from app.core.bce.rerank_client import RerankClient
8 | from app.core.chat.open_chat import OpenChat
9 | from app.core.preprocessor.file_processor import FileProcesser
10 | from app.oss.download_file import Downloader
11 | from conf.config import (CACHE_DIR, COLLECTION_NAME, EMBEDDING_MODEL,
12 | MILVUS_URI, RAG_PROMPT, RERANK_MODEL, STORAGE_DIR,
13 | STORAGE_TYPE)
14 |
15 | embedding_client = EmbeddingClient(EMBEDDING_MODEL)
16 | rerank_client = RerankClient(RERANK_MODEL)
17 | oss_downloader = Downloader()
18 | file_processer = FileProcesser()
19 | open_chat = OpenChat()
20 | _dim = 768
21 |
22 | from utils import logger
23 |
24 |
class CustomerMilvusClient:
    """Milvus-backed vector store for FinRAG.

    Connects on construction, creating the collection (with an IVF_FLAT
    index on ``embedding``) when it does not exist yet, and offers
    helpers to embed files into the store and to answer questions with
    retrieval-augmented generation.
    """

    def __init__(self):
        # self.client = MilvusClient(
        #     uri=config.milvus_uri
        # )
        connections.connect(uri=MILVUS_URI)
        self.collection_name = COLLECTION_NAME
        self.collection = self.init()
        self.collection.load()

    def init(self):
        """Return the collection, creating it and its index if missing.

        Re-raises on failure: the original swallowed the exception and
        then hit ``UnboundLocalError`` at ``return collection``.
        """
        try:
            if utility.has_collection(self.collection_name):
                collection = Collection(self.collection_name)
                logger.info(f"collection {self.collection_name} exists")
            else:
                schema = CollectionSchema(self.fields)
                logger.info(
                    f"create collection {self.collection_name} {schema}")
                collection = Collection(self.collection_name, schema)
                index_params = {
                    "metric_type": "L2",
                    "index_type": "IVF_FLAT",
                    "params": {
                        "nlist": 2048
                    },
                }
                collection.create_index(field_name="embedding",
                                        index_params=index_params)
                logger.info("初始化成功!")
        except Exception as e:
            logger.error(e)
            raise
        return collection

    def embedding_to_vdb(self, file_details, batch_size=1000):
        """Download, split, embed and insert each file into Milvus.

        Very long documents cannot always be inserted in one call, so
        entities are flushed in batches of ``batch_size``.
        """
        for file_info in file_details:
            fileName = file_info.get("fileName")
            fileSuffix = file_info.get("fileSuffix")
            storagePath = file_info.get("storagePath")
            local_file = CACHE_DIR + "/" + fileName
            logger.info(f"正在将文件【{fileName}】下载到本地缓存...")
            try:
                oss_downloader.get_file(storagePath, local_file)
            except Exception as e:
                # Without the file there is nothing to embed; skip it
                # instead of falling through and crashing on the missing
                # file (as the original did).
                logger.error("下载失败,请检查网络或检查文件是否存在")
                logger.error(e)
                continue
            logger.info("本地文件地址:" + str(local_file))
            docs = file_processer.split_file_to_docs(local_file)
            docs_content = [doc.page_content for doc in docs]
            embeddings = embedding_client.get_embedding(docs_content)
            entities = []
            try:
                for idx, (cont, emb) in enumerate(zip(docs_content,
                                                      embeddings)):
                    entities.append({
                        "parentId": file_info.get("parentId"),
                        "categoryName": file_info.get("categoryName"),
                        "categoryId": file_info.get("categoryId"),
                        "fileId": str(file_info.get("fileId")),
                        "fileName": fileName,
                        "fileSuffix": fileSuffix,
                        "storagePath": storagePath,
                        "chunkId": idx,
                        "chunkContent": cont,
                        "embedding": emb,
                    })
                    if len(entities) == batch_size:
                        self.collection.insert(entities)
                        self.collection.flush()
                        entities = []
                # Final partial batch; the original inserted even when
                # the list was empty.
                if entities:
                    self.collection.insert(entities)
                    self.collection.flush()
                logger.info("存入向量数据库成功!")
            except Exception as e:
                logger.error(str(local_file) + "写入向量数据库失败,请检查!")
                logger.error(e)

    def parse_request(self, data):
        """Flatten ``data.sysCategory`` into a list of file-detail dicts.

        Walks both top-level categories and their ``subCategory``
        children; every ``fileStorages`` entry contributes one dict.
        """
        file_details = []

        for category in data.sysCategory:
            # Files attached directly to the top-level category.
            for file_storage in category.get("fileStorages", []):
                file_details.append({
                    # ``fileId`` was omitted by the original, so every
                    # stored fileId became the string "None".
                    "fileId": file_storage.get("fileId"),
                    "fileName": file_storage["fileName"],
                    "fileSuffix": file_storage["fileSuffix"],
                    "storagePath": file_storage["storagePath"],
                    "categoryName": category["categoryName"],
                    "categoryId": category["categoryId"],
                    "parentId": category["parentId"],
                })

            # Files attached to sub-categories.
            for sub_category in category.get("subCategory", []):
                for file_storage in sub_category.get("fileStorages", []):
                    file_details.append({
                        "fileId": file_storage.get("fileId"),
                        "fileName": file_storage["fileName"],
                        "fileSuffix": file_storage["fileSuffix"],
                        "storagePath": file_storage["storagePath"],
                        "categoryName": sub_category["categoryName"],
                        "categoryId": sub_category["categoryId"],
                        "parentId": sub_category["parentId"],
                    })

        for detail in file_details:
            print(detail)
        return file_details

    @property
    def fields(self):
        """Collection schema: metadata fields plus the 768-dim vector."""
        fields = [
            FieldSchema(
                name="id",
                dtype=DataType.VARCHAR,
                is_primary=True,
                auto_id=True,
                max_length=100,
            ),
            FieldSchema(name="parentId",
                        dtype=DataType.VARCHAR,
                        max_length=256),
            FieldSchema(name="categoryName",
                        dtype=DataType.VARCHAR,
                        max_length=256),
            FieldSchema(name="categoryId",
                        dtype=DataType.VARCHAR,
                        max_length=256),
            FieldSchema(name="fileId", dtype=DataType.VARCHAR, max_length=256),
            FieldSchema(name="fileName",
                        dtype=DataType.VARCHAR,
                        max_length=1024),
            FieldSchema(name="fileSuffix",
                        dtype=DataType.VARCHAR,
                        max_length=256),
            FieldSchema(name="storagePath",
                        dtype=DataType.VARCHAR,
                        max_length=256),
            FieldSchema(name="chunkId", dtype=DataType.INT64),
            FieldSchema(name="chunkContent",
                        dtype=DataType.VARCHAR,
                        max_length=1024),
            # Vector field; dimension must match the embedding model.
            FieldSchema(name="embedding",
                        dtype=DataType.FLOAT_VECTOR,
                        dim=_dim),
        ]
        return fields

    @property
    def output_fields(self):
        """All schema fields (minus the auto id), for full-entity queries."""
        return [
            "parentId",
            "categoryName",
            "categoryId",
            "fileId",
            "fileName",
            "fileSuffix",
            "storagePath",
            "chunkId",
            "chunkContent",
            "embedding",
        ]

    def delete_collection(self):
        """Release and drop the collection — resets the vector store."""
        self.collection.release()
        utility.drop_collection(self.collection_name)
        print("向量数据库重置成功!")

    def get_rag_result(self, initInputs, messages):
        """Answer the latest user message with retrieval-augmented generation.

        When several categoryId groups are supplied, each group is
        retrieved and answered separately and the rerank model picks the
        best answer; the references of all groups are returned together.

        Returns
        -------
        tuple
            ``(answer_text, retrieval_results)`` where each retrieval
            result is ``(score, fileName, chunkContent)``.
        """
        query = messages[-1].get("content")
        logger.info(f"最新的问题是:【{query}】")
        query_emb = embedding_client.get_embedding(query)
        categoryIds = initInputs.get("categoryIds")
        topK = initInputs.get('topK')
        score = initInputs.get('score')
        logger.info("score:" + str(score))
        if len(categoryIds) > 1:
            rag_results = []
            reference_results = []
            original_question = messages[-1]['content']
            for idStr in categoryIds:
                category_ids = idStr.split(',')
                rag_result, retrival_results = self.retrieval_and_generate(
                    query_emb, topK, score, category_ids, messages)

                rag_results.append(rag_result)
                reference_results.extend(retrival_results)
                # retrieval_and_generate rewrites the last message with
                # the RAG prompt; restore the raw question so the next
                # group does not nest prompts inside prompts (the
                # original compounded the prompt on every iteration).
                messages[-1]['content'] = original_question

            # Keep the answer the rerank model ranks first.
            rerank_results = self.rerank(query, rag_results)
            rag_result = rag_results[rerank_results['rerank_ids'].index(0)]
            retrival_results = reference_results
        else:
            category_ids = categoryIds[0].split(',')
            rag_result, retrival_results = self.retrieval_and_generate(
                query_emb, topK, score, category_ids, messages)
        return rag_result, retrival_results

    def retrieval_and_generate(self, query_emb, topK, score, category_ids,
                               messages):
        """Search Milvus within *category_ids* and generate an answer.

        Hits whose score passes the ``score`` filter become the RAG
        context; the last message is rewritten with ``RAG_PROMPT``
        before calling the LLM.

        Returns ``(answer, [(score, fileName, chunkContent), ...])``.
        """
        expr = "categoryId in {}".format(category_ids)
        logger.info(expr)
        search_params = {
            "metric_type": "L2",
            "offset": 0,
            "ignore_growing": False,
            "params": {
                "nprobe": 10
            }
        }

        results = self.collection.search(
            data=query_emb,
            anns_field="embedding",
            param=search_params,
            limit=topK,
            expr=expr,
            output_fields=['fileName', 'chunkContent'],
            consistency_level="Strong")
        relevant_content = []
        for hits in results:
            for hit in hits:
                relevant_content.append((hit.score, hit.entity.get("fileName"),
                                         hit.entity.get('chunkContent')))
        logger.info("检索到相关片段的数量是:%d" % len(relevant_content))
        if len(relevant_content) > 0:
            # NOTE(review): with the L2 metric smaller distances are
            # better, yet hits are kept when x[0] > score — confirm
            # ``score`` is really meant as a lower bound here.
            retrival_results = [x for x in relevant_content if x[0] > score]
            retrival_results_text = [x[2] for x in retrival_results]
            logger.info("检索到的片段:" + str(retrival_results_text))
            retrieval_result_str = '\n\n'.join(retrival_results_text)
        else:
            retrival_results = []
            retrieval_result_str = ""

        messages[-1]['content'] = RAG_PROMPT.format(
            context=retrieval_result_str, question=messages[-1]['content'])
        rag_result = open_chat.chat(messages)

        return rag_result, retrival_results

    def rerank(self, query, multi_result):
        """Rerank candidate answers against *query* with the reranker."""
        logger.info('进入rerank模块')
        rerank_result = rerank_client.rerank(query, multi_result)
        return rerank_result
290 |
--------------------------------------------------------------------------------
/app/finrag_server.py:
--------------------------------------------------------------------------------
1 | """
2 | Author: Lucas
3 | Date: 2024-05-22 11:21:41
4 | LastEditors: Lucas
5 | LastEditTime: 2024-05-22 11:24:10
6 | Description: file content
7 | """
8 |
9 | import time
10 | from typing import Any, Dict
11 | import requests
12 | import httpx
13 | import uvicorn
14 | from fastapi import FastAPI, HTTPException,BackgroundTasks
15 | from pydantic import BaseModel
16 | import asyncio
17 | from app.core.chat.open_chat import OpenChat
18 | from app.core.vectorstore.customer_milvus_client import CustomerMilvusClient
19 | from app.models.status import ErrorMsg, SuccessMsg
20 | from conf import config
21 | from utils import logger
22 | import json
23 |
24 | cmc = CustomerMilvusClient()
25 | open_chat = OpenChat()
26 |
27 |
class Item(BaseModel):
    """Vector-sync request body: a sync id plus the sysCategory file tree."""
    syncId: Any
    sysCategory: Any
32 |
class Query(BaseModel):
    """Chat request payload sent by the backend."""
    chatId: Any
    ownerId: Any
    chatName: Any
    initInputs: Dict  # holds topK, score and categoryIds (read in /chat)
    initOpening: Any
    chatMessages: Any
40 |
41 |
class Notify(BaseModel):
    """Callback payload reporting vector-sync completion status."""
    syncId: Any
    status: Any
46 |
47 |
def notify_another(notify_msg: Notify):
    """POST the embedding-completion status back to the backend.

    Uses ``requests``' own JSON encoding (``json=`` sets the
    Content-Type header and serialises in one step) and adds a timeout
    so an unreachable backend cannot hang the update pipeline forever —
    the original call could block indefinitely.

    Returns the ``requests.Response``.
    """
    logger.info("通知Embedding完成")
    data = {"syncId": notify_msg.syncId, "status": notify_msg.status}
    response = requests.post(config.NOTIFY_URL, json=data, timeout=30)
    print('Status Code:', response.status_code)
    if response.status_code == 200:
        logger.info("消息同步成功!")
    # print('Content:', response.message)
    return response
60 |
61 | app = FastAPI()
62 |
@app.post("/chat")
async def chat(query: Query):
    """Chat endpoint: open-domain QA, or RAG QA when categoryIds are given.

    Builds the message history from ``chatMessages``, obtains an answer
    (optionally grounded in retrieved chunks), asks the model for three
    suggested follow-up questions, and summarises a chat title when the
    default name is still in use.
    """
    logger.info("进入Chat")
    chatId = query.chatId
    ownerId = query.ownerId          # currently unused; kept for parity
    chatName = query.chatName
    initInputs = query.initInputs
    initOpening = query.initOpening  # currently unused; kept for parity
    chatMessages = query.chatMessages
    logger.info("query:\n***********************************" +
                query.model_dump_json() +
                "\n***********************************")
    # Normalise to the OpenAI-style role/content message format.
    messages = [{
        "role": x.get("role").lower(),
        "content": x.get("rawContent")
    } for x in chatMessages]

    try:
        chunks = []
        if len(initInputs.get("categoryIds")) == 0:
            # Open-domain QA: the answer comes entirely from the LLM.
            logger.info("进入开放域知识问答,答案由完全由大模型生成!")
            response = open_chat.chat(messages)
        else:
            # RAG QA: the answer is grounded in the knowledge base.
            logger.info("进入RAG问答,答案由大模型根据知识库生成!")
            response, retrieval_results = cmc.get_rag_result(
                initInputs, messages)
            if len(retrieval_results):
                chunks = [{
                    "index": x[1],
                    "chunk": x[2],
                    "score": x[0]
                } for x in retrieval_results]
        logger.info("chunks" + str(chunks))
        messages.append({"role": "assistant", "content": response})
        messages.append({
            "role": "user",
            "content": "根据上面我们的历史对话,为我推荐三个接下来我可能要问的问题。每个问题以?结尾"
        })
        suggestedQuestions = open_chat.chat(messages)
        try:
            suggestedQuestions = suggestedQuestions.split('?')
            suggestedQuestions = [x.strip() for x in suggestedQuestions[:3]]
        except Exception:
            # Non-splittable model reply: return it as a single suggestion.
            suggestedQuestions = [suggestedQuestions]

        if chatName == "知识问答助手":
            try:
                new_messages = [
                    {
                        "role": "system",
                        "content": "你是一位得力的助手"
                    },
                    {
                        "role": "user",
                        "content": config.DIALOGUE_SUMMARY.format(
                            context=str(messages[:-1]))
                    }
                ]
                chatName = open_chat.chat(new_messages)
                logger.info("大模型总结的chatName:" + str(chatName))
            except Exception:
                # Summarisation is best-effort; keep the default name
                # (the original did the same via ``chatName = chatName``).
                pass

        return {
            "code": "000000",
            "data": {
                "event": "MESSAGE",
                "result": {
                    "chatId": chatId,
                    "chatName": chatName,
                    "answer": response,
                    "suggestedQuestions": suggestedQuestions,
                    "chunks": chunks,
                },
            },
            "message": "调⽤成功",
            "success": True,
            "time": time.time(),
        }
    except Exception:
        # logger.exception records the traceback the bare except threw away.
        logger.exception("RAG问答出错,请检查!")
        return ErrorMsg.to_dict()
156 |
157 |
async def async_update(item: Item):
    """Parse the sync request, embed all files, then notify the backend.

    Runs as a fire-and-forget task, so exceptions are logged rather than
    propagated — but with their traceback, which the original discarded.
    """
    try:
        logger.info("开始更新向量")
        logger.info("syncId:" + str(item.syncId))
        start_time = time.time()
        detail = cmc.parse_request(item)
        cmc.embedding_to_vdb(detail)
        notify_msg = Notify(syncId=item.syncId, status=1)
        end_time = time.time()
        logger.info("Cost time:" + str(end_time - start_time))
        response = notify_another(notify_msg)
    except Exception:
        # The original emitted only a fixed message, hiding the cause.
        logger.exception("更新向量发生错误!")
# Strong references to in-flight update tasks: the event loop keeps only
# a weak reference to tasks, so an un-referenced task can be
# garbage-collected before it finishes (see asyncio.create_task docs).
_pending_updates = set()


@app.post("/update_vector")
async def update_vector(item: Item):
    """Kick off vector updating in the background and return immediately."""
    task = asyncio.create_task(async_update(item))
    _pending_updates.add(task)
    task.add_done_callback(_pending_updates.discard)
    return SuccessMsg.to_dict()
180 |
--------------------------------------------------------------------------------
/app/models/__init__.py:
--------------------------------------------------------------------------------
1 | '''
2 | Author: xubing
3 | Date: 2024-05-19 00:09:54
4 | LastEditors: xubing
5 | LastEditTime: 2024-05-19 00:09:55
6 | Description: file content
7 | '''
8 |
--------------------------------------------------------------------------------
/app/models/dialog.py:
--------------------------------------------------------------------------------
1 | from .status import Status
2 |
3 |
class DialogRequest:
    """Incoming dialog request fields, declared as class-level defaults.

    NOTE(review): these are *class* attributes — ``session = []`` is one
    list shared by every instance that never assigns its own; confirm
    instances always overwrite it before mutating.
    """
    request_id = ""
    query = ""
    session_id = ""
    user_id = ""
    session = []
    stream = False
11 |
12 |
class DialogResponse:
    """Outgoing dialog response fields.

    NOTE(review): ``status = Status`` stores the Status *class* itself,
    not an instance — confirm callers expect that.
    """
    status = Status
    answer = ""
    title = ""
    recommend_topic = ""
    chunks = ""
19 |
--------------------------------------------------------------------------------
/app/models/status.py:
--------------------------------------------------------------------------------
1 | '''
2 | Author: wangjia
3 | Date: 2024-05-19 00:10:00
4 | LastEditors: wangjia
5 | LastEditTime: 2024-05-24 00:02:48
6 | Description: file content
7 | '''
8 | # class Status:
9 | # status_code = ""
10 | # status_msg = ""
11 | import time
12 |
13 |
class ErrorMsg:
    """Canonical failure payload for HTTP responses."""

    code = "xxxxxx"
    data = None
    message = "调⽤失败"
    success = False

    @classmethod
    def to_dict(cls):
        """Build the response dict, stamped with the current time."""
        payload = {
            'code': cls.code,
            'data': cls.data,
            'message': cls.message,
            'success': cls.success,
        }
        payload['time'] = time.time()  # current timestamp
        return payload
class SuccessMsg:
    """Canonical success payload for HTTP responses."""

    code = "000000"
    data = None
    message = "调⽤成功"
    success = True

    @classmethod
    def to_dict(cls):
        """Build the response dict, stamped with the current time."""
        payload = {
            'code': cls.code,
            'data': cls.data,
            'message': cls.message,
            'success': cls.success,
        }
        payload['time'] = time.time()  # current timestamp
        return payload
42 |
if __name__ == '__main__':
    # Smoke-check: build a success payload once.
    SuccessMsg().to_dict()
--------------------------------------------------------------------------------
/app/oss/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AI4Finance-Foundation/FinRAG/790c6158596747a3ea58ea1dae7972ac0f18d3dd/app/oss/__init__.py
--------------------------------------------------------------------------------
/app/oss/download_file.py:
--------------------------------------------------------------------------------
1 | """
2 | Author: Lucas
3 | Date: 2024-05-21 10:25:18
4 | LastEditors: Lucas
5 | LastEditTime: 2024-05-22 12:55:53
6 | Description: file content
7 | """
8 |
9 | import os
10 | import shutil
11 |
12 | import dotenv
13 | import oss2
14 | from oss2.credentials import EnvironmentVariableCredentialsProvider
15 |
16 | from conf.config import STORAGE_DIR, STORAGE_TYPE
17 |
18 | dotenv.load_dotenv()
19 | endpoint = os.getenv("END_POINT")
20 | bucket_name = os.getenv("BUCKET_NAME")
21 |
class Downloader:
    """Fetch source files from local storage or Aliyun OSS into the cache."""

    def __init__(self) -> None:
        # Credentials are read from the environment by the provider.
        auth = oss2.ProviderAuth(EnvironmentVariableCredentialsProvider())
        self.bucket = oss2.Bucket(auth, endpoint, bucket_name)

    def get_file(self, remote, local):
        """Copy or download *remote* to *local*, depending on STORAGE_TYPE."""
        if STORAGE_TYPE == 'local':
            return self.get_local_file(remote, local)
        else:
            return self.get_oss_file(remote, local)

    def get_local_file(self, remote, local):
        """Copy *remote* (relative to STORAGE_DIR) to *local*."""
        shutil.copy(os.path.join(STORAGE_DIR, remote), local)
        return

    def get_oss_file(self, remote, local):
        """Download *remote* from OSS to *local*, reusing a cached copy.

        Fixes two defects in the original: the ``local is None`` default
        was assigned *after* the cache lookup already used ``local``,
        and the lookup compared a full path against the bare file names
        from ``os.listdir('.cache')`` so it could never hit.
        """
        if local is None:
            local = ".cache/" + remote.split("/")[-1]
        # Skip the download when a cached copy already exists.
        if os.path.exists(local):
            return
        self.bucket.get_object_to_file(remote, local)
        return
--------------------------------------------------------------------------------
/bin/start.sh:
--------------------------------------------------------------------------------
1 | uvicorn app.finrag_server:app --host 0.0.0.0 --port 8000 --reload
--------------------------------------------------------------------------------
/conf/config.py:
--------------------------------------------------------------------------------
'''
Author: wangjia
Date: 2024-05-26 23:28:10
LastEditors: wangjia
LastEditTime: 2024-05-30 22:29:58
Description: Central configuration for FinRAG (vector store, models, prompts).
'''

import os

COLLECTION_NAME = "FIN_RAG"  # name of the Milvus collection (vector store)
SENTENCE_SIZE = 500  # maximum length of a split text chunk
DEVICE = "cuda"  # run models on cpu or gpu
EMBEDDING_MODEL = "/data/WoLLM/bce-embedding-base_v1"  # local path of the embedding model
RERANK_MODEL = "/data/WoLLM/bce-reranker-base_v1"  # local path of the rerank model
MILVUS_URI = "http://localhost:19530"  # Milvus server URI
NOTIFY_URL = "http://39.96.174.204/api/medical-assistant/knowledge/file/vector/complete"  # backend callback hit once vector updating completes
CACHE_DIR = ".cache"  # local cache dir; OSS files are downloaded here
STORAGE_TYPE="local" # "local" or "oss": where source files live [new feature]
STORAGE_DIR="/data/storage/" # root directory for local file storage [new feature]
if not os.path.exists(CACHE_DIR):
    os.mkdir(CACHE_DIR)

# Prompt that summarises a dialogue into a short title
DIALOGUE_SUMMARY = """为以下对话内容总结一个标题
{context}

请限制在20个字以内
你的回复:"""

# Core RAG prompt
RAG_PROMPT = """参考信息:
{context}
---
我的问题或指令:
{question}
---
请根据上述参考信息回答我的问题或回复我的指令。
- 我的问题或指令是什么语种,你就用什么语种回复.
- 前面的参考信息可能有用,也可能没用, 请自行判别。
- 如果前面的参考信息与问题有关,你需要从我给出的参考信息中选出与我的问题最相关的那些,来为你的回答提供依据。
- 如果前面的参考信息与问题无关,请回答:根据参考信息,无法得到问题的答案.
- 不要随机编造答案

你的回复:"""

if __name__ == "__main__":
    print(RAG_PROMPT.format(context="Hello", question="world"))
49 |
--------------------------------------------------------------------------------
/docker/docker-compose.yml:
--------------------------------------------------------------------------------
1 | version: '3.5'
2 |
3 | services:
4 | etcd:
5 | container_name: milvus-etcd
6 | image: quay.io/coreos/etcd:v3.5.5
7 | environment:
8 | - ETCD_AUTO_COMPACTION_MODE=revision
9 | - ETCD_AUTO_COMPACTION_RETENTION=1000
10 | - ETCD_QUOTA_BACKEND_BYTES=4294967296
11 | - ETCD_SNAPSHOT_COUNT=50000
12 | volumes:
13 | - /data/volumes/etcd:/etcd
14 | command: etcd -advertise-client-urls=http://127.0.0.1:2379 -listen-client-urls http://0.0.0.0:2379 --data-dir /etcd
15 | healthcheck:
16 | test: ["CMD", "etcdctl", "endpoint", "health"]
17 | interval: 30s
18 | timeout: 20s
19 | retries: 3
20 | minio:
21 | container_name: milvus-minio
22 | image: minio/minio:RELEASE.2023-03-20T20-16-18Z
23 | environment:
24 | MINIO_ACCESS_KEY: minioadmin
25 | MINIO_SECRET_KEY: minioadmin
26 | ports:
27 | - "9001:9001"
28 | - "9002:9000"
29 | volumes:
30 | - /data/volumes/minio:/minio_data
31 | command: minio server /minio_data --console-address ":9001"
32 | healthcheck:
33 | test: ["CMD", "curl", "-f", "http://localhost:9000/minio/health/live"]
34 | interval: 30s
35 | timeout: 20s
36 | retries: 3
37 | standalone:
38 | container_name: milvus-standalone
39 | image: milvusdb/milvus:v2.3.3
40 | command: ["milvus", "run", "standalone"]
41 | security_opt:
42 | - seccomp:unconfined
43 | environment:
44 | ETCD_ENDPOINTS: etcd:2379
45 | MINIO_ADDRESS: minio:9000
46 | volumes:
47 | - /data/volumes/milvus:/var/lib/milvus
48 | healthcheck:
49 | test: ["CMD", "curl", "-f", "http://localhost:9091/healthz"]
50 | interval: 30s
51 | start_period: 90s
52 | timeout: 20s
53 | retries: 3
54 | ports:
55 | - "19530:19530"
56 | - "9091:9091"
57 | depends_on:
58 | - "etcd"
59 | - "minio"
60 | attu:
61 | container_name: milvus-attu
62 | image: zilliz/attu:v2.3.8
63 | ports:
64 | - "3100:3000"
65 | environment:
66 | - MILVUS_URL=http:localhost:19530 # 这里是milvus的UI,这里的URL填写milvus服务器的url
67 |
68 | networks:
69 | default:
70 | name: milvus
--------------------------------------------------------------------------------
/example/test.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AI4Finance-Foundation/FinRAG/790c6158596747a3ea58ea1dae7972ac0f18d3dd/example/test.pdf
--------------------------------------------------------------------------------
/example/test1.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AI4Finance-Foundation/FinRAG/790c6158596747a3ea58ea1dae7972ac0f18d3dd/example/test1.pdf
--------------------------------------------------------------------------------
/example/test2.docx:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AI4Finance-Foundation/FinRAG/790c6158596747a3ea58ea1dae7972ac0f18d3dd/example/test2.docx
--------------------------------------------------------------------------------
/example/~$test2.docx:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AI4Finance-Foundation/FinRAG/790c6158596747a3ea58ea1dae7972ac0f18d3dd/example/~$test2.docx
--------------------------------------------------------------------------------
/img.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AI4Finance-Foundation/FinRAG/790c6158596747a3ea58ea1dae7972ac0f18d3dd/img.png
--------------------------------------------------------------------------------
/img_1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AI4Finance-Foundation/FinRAG/790c6158596747a3ea58ea1dae7972ac0f18d3dd/img_1.png
--------------------------------------------------------------------------------
/img_2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AI4Finance-Foundation/FinRAG/790c6158596747a3ea58ea1dae7972ac0f18d3dd/img_2.png
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
import uvicorn
from app.finrag_server import app

if __name__ == '__main__':
    # Serve the FinRAG FastAPI app on all interfaces, port 8000.
    uvicorn.run(app, host="0.0.0.0", port=8000)
6 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | fastapi==0.111.0
2 | loguru==0.7.2
3 | openai==1.30.1
4 | numpy==1.26.4
5 | BCEmbedding==0.1.5
6 | pymilvus==2.3.7
7 | langchain==0.2.0
8 | langchain-community==0.2.0
9 | unstructured==0.14.0
10 | pdfminer.six==20231228
11 | pillow_heif==0.16.0
12 | opencv-python==4.9.0.80
13 | pdf2image==1.17.0
14 | unstructured_inference==0.7.31
15 | pytesseract==0.3.10
16 | pikepdf==8.15.1
17 | python-docx==1.1.2
18 | dashscope==1.19.1
19 | oss2==2.18.5
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
from setuptools import setup, find_packages

# Read requirements.txt, ignoring inline comments and blank lines.
try:
    with open("requirements.txt", "r") as f:
        # Strip inline comments first, then drop lines that end up empty
        # (the original kept "" entries for comment-only lines).
        REQUIRES = [
            req for req in (line.split("#", 1)[0].strip() for line in f) if req
        ]
except OSError:
    # Only a missing/unreadable file should be tolerated; the original
    # bare ``except:`` would also hide e.g. KeyboardInterrupt.
    print("'requirements.txt' not found!")
    REQUIRES = list()

setup(
    name="FinRAG",
    version="0.0.1",
    include_package_data=True,
    author="AI4Finance Foundation",
    author_email="contact@ai4finance.org",
    url="https://github.com/AI4Finance-Foundation/FinRAG",
    license="MIT",
    packages=find_packages(),
    install_requires=REQUIRES,
    description="FinRAG: Financial Retrieval Augmented Generation",
    long_description="""FinRAG""",
    classifiers=[
        # Trove classifiers
        # Full list: https://pypi.python.org/pypi?%3Aaction=list_classifiers
        "License :: OSI Approved :: MIT License",
        "Programming Language :: Python",
        "Programming Language :: Python :: 3",
        "Programming Language :: Python :: 3.6",
        "Programming Language :: Python :: 3.7",
        "Programming Language :: Python :: 3.8",
        "Programming Language :: Python :: 3.9",
        "Programming Language :: Python :: 3.10",
        "Programming Language :: Python :: 3.11",
        "Programming Language :: Python :: Implementation :: CPython",
        "Programming Language :: Python :: Implementation :: PyPy",
    ],
    keywords="Financial Large Language Models, AI Agents, Retrieval Augmented Generation",
    platforms=["any"],
    python_requires=">=3.10, <3.12",
)
42 |
--------------------------------------------------------------------------------
/test/cz.py:
--------------------------------------------------------------------------------
from app.core.vectorstore.customer_milvus_client import CustomerMilvusClient
if __name__ == '__main__':
    # Destructive: drops the whole Milvus collection (resets the vector store).
    CustomerMilvusClient().delete_collection()
--------------------------------------------------------------------------------
/test/http_test/notify 2.http:
--------------------------------------------------------------------------------
1 |
2 | POST http://39.96.174.204/api/medical-assistant/knowledge/file/vector/complete HTTP/1.1
3 | Content-Type: application/json;charset=utf-8
4 |
5 | {
6 | "syncId": "1",
7 | "status": 2
8 | }
--------------------------------------------------------------------------------
/test/http_test/notify.http:
--------------------------------------------------------------------------------
1 |
2 | POST http://127.0.0.1:8000/notify_another/ HTTP/1.1
3 | Content-Type: application/json;charset=utf-8
4 |
5 | {
6 | "syncId": "",
7 | "status": 3
8 | }
--------------------------------------------------------------------------------
/test/http_test/query_1.http:
--------------------------------------------------------------------------------
1 |
2 | POST http://60.205.147.142:8000/chat/ HTTP/1.1
3 | Content-Type: application/json;charset=utf-8
4 |
5 | {
6 | "chatId": "20046261673390243840005710",
7 | "chatMessages": [
8 | {
9 | "chatMessageId": "m00001",
10 | "role": "SYSTEM",
11 | "rawContent": "你是⼀个医疗助⼿"
12 | },
13 | {
14 | "chatId": "20046261673390243840005710",
15 | "chatMessageId": "20056261711947348049925987",
16 | "cid": "sinodata",
17 | "rawContent": "我感冒了吃什么药",
18 | "role": "USER"
19 | }
20 | ],
21 | "chatName": "知识问答助手",
22 | "initInputs": {
23 | "topK": 1,
24 | "score": 10,
25 | "categoryIds": [
26 | "20026250115495299645444018,1010010101"
27 | ],
28 | "language": "zh"
29 | },
30 | "initOpening": [
31 | "我是知识问答助手",
32 | "你有什么想问我的吗?"
33 | ],
34 | "ownerId": "123213123"
35 | }
--------------------------------------------------------------------------------
/test/http_test/query_2.http:
--------------------------------------------------------------------------------
1 |
2 | POST http://127.0.0.1:8000/chat/ HTTP/1.1
3 | Content-Type: application/json;charset=utf-8
4 |
5 | {"chatId":"20046261673390243840005710",
6 | "chatMessages":[
7 | {
8 | "chatMessageId": "m00001",
9 | "role": "SYSTEM",
10 | "rawContent": "你是⼀个医疗助⼿"
11 | },
12 | {"chatId":"20046261673390243840005710","chatMessageId":"20056261758441543434247063",
13 | "cid":"sinodata","gmtCreate":1716559155000,"gmtModified":1716559155000,"id":"4","isDeleted":1,"rating":1,"rawContent":"我感冒了.我需要吃什么药","role":"USER","tokensUsed":0},
14 | {"chatId":"20046261673390243840005710","chatMessageId":"20056261758454000517126380","cid":"sinodata","gmtCreate":1716559155000,"gmtModified":1716559155000,"id":"5","isDeleted":1,"rating":1,"rawContent":"你好,今天天气很棒啊!","role":"ASSISTANT","tokensUsed":0},
15 | {"chatId":"20046261673390243840005710","chatMessageId":"20056261759779903897609217","cid":"sinodata","rawContent":"病毒的定义及其特点?","role":"USER"}],
16 | "chatName":"知识问答助手","initInputs":
17 | {"topK":1,"score":10,"categoryIds":["20026250122028766986246526,20026250123269316280320966","20026250129949416488964807"],"language":"zh"},
18 | "initOpening":["我是知识问答助手","你有什么想问我的吗?"],"ownerId":"123213123"}
--------------------------------------------------------------------------------
/test/http_test/test.http:
--------------------------------------------------------------------------------
1 |
2 | POST http://127.0.0.1:8000/update_vector/ HTTP/1.1
3 | Content-Type: application/json;charset=utf-8
4 |
5 | {
6 | "sysCategory" : [
7 | {
8 | "fileStorages" : [
9 | {
10 | "fileId" : "file00004",
11 | "fileName" : "py⼊⻔到放弃",
12 | "storagePath" : "stream/78/65/00106250790864562421764687.pdf",
13 | "fileSuffix" : "pdf",
14 | "publicUrl" : "http:\/\/www.baidu.com\/xxx\/xxx\/xxxpy.txt",
15 | "fileSize" : "123456"
16 | },
17 | {
18 | "fileId" : "file00005",
19 | "fileName" : "py⼊⻔到放弃2",
20 | "storagePath" : "stream/78/65/00106250790864562421764687.pdf",
21 | "fileSuffix" : "pdf",
22 | "publicUrl" : "http:\/\/www.baidu.com\/xxx\/xxx\/xxxpy2.txt",
23 | "fileSize" : "123456"
24 | }
25 | ],
26 | "subCategory" : [
27 | {
28 | "fileStorages" : [
29 | {
30 | "fileId" : "file00004",
31 | "fileName" : " 《医学病毒学》知识图谱-231229",
32 | "storagePath" : "stream/40/86/00106250791926853795844449.docx",
33 | "fileSuffix" : "docx",
34 | "publicUrl" : "http:\/\/www.baidu.com\/xxx\/xxx\/xxxpy.txt",
35 | "fileSize" : "123456"
36 | },
37 | {
38 | "fileId" : "file00005",
39 | "fileName" : "《学生手册》2023版",
40 | "storagePath" : "stream/85/11/00106250792497027481606005.pdf",
41 | "fileSuffix" : "pdf",
42 | "publicUrl" : "http:\/\/www.baidu.com\/xxx\/xxx\/xxxpy2.txt",
43 | "fileSize" : "123456"
44 | }
45 | ],
46 | "categoryId" : "234324234",
47 | "parentId" : "1010010101",
48 | "categoryName" : "⼆级分类"
49 | }
50 | ],
51 | "categoryId" : "1010010101",
52 | "parentId" : "0",
53 | "categoryName" : "⼀级分类",
54 | "level" : 1
55 | }
56 | ],
57 | "syncId" : "123"
58 | }
59 |
60 |
--------------------------------------------------------------------------------
/test/http_test/test2.http:
--------------------------------------------------------------------------------
1 |
2 | POST http://127.0.0.1:8000/update_vector/ HTTP/1.1
3 | Content-Type: application/json;charset=utf-8
4 |
5 | {
6 | "sysCategory" : [
7 | {
8 |
9 | "subCategory" : [
10 | {
11 | "fileStorages" : [
12 | {
13 | "fileId" : "file00004",
14 | "fileName" : "《医学病毒学》知识图谱-231229",
15 | "storagePath" : "stream/40/86/00106250791926853795844449.docx",
16 | "fileSuffix" : "docx",
17 | "publicUrl" : "http:\/\/www.baidu.com\/xxx\/xxx\/xxxpy.txt",
18 | "fileSize" : "123456"
19 | }
20 | ],
21 | "categoryId" : "234324234",
22 | "parentId" : "1010010101",
23 | "categoryName" : "⼆级分类"
24 | }
25 | ],
26 | "categoryId" : "1010010101",
27 | "parentId" : "0",
28 | "categoryName" : "⼀级分类",
29 | "level" : 1
30 | }
31 | ],
32 | "syncId" : "123"
33 | }
34 |
35 |
--------------------------------------------------------------------------------
/test/http_test/test3.http:
--------------------------------------------------------------------------------
1 |
2 | POST http://127.0.0.1:8000/update_vector/ HTTP/1.1
3 | Content-Type: application/json;charset=utf-8
4 |
5 | {
6 | "sysCategory": [
7 | {
8 | "categoryId": "20026250115495299645444018",
9 | "categoryName": "基础医学",
10 | "categoryType": "KNOWLEDGE_CATEGORY",
11 | "cid": "sinodata",
12 | "description": "",
13 | "gmtCreate": 1716281565000,
14 | "gmtModified": 1716291331000,
15 | "id": "2",
16 | "isDeleted": 1,
17 | "level": 1,
18 | "parentId": "0",
19 | "path": "/0/20026250115495299645444018",
20 | "sort": 1,
21 | "subCategory": [
22 | {
23 | "categoryId": "20026250120594315018249896",
24 | "categoryName": "解刨学",
25 | "categoryType": "KNOWLEDGE_CATEGORY",
26 | "cid": "sinodata",
27 | "description": "",
28 | "fileStorages": [
29 | {
30 | "categoryId": "20026250120594315018249896",
31 | "cid": "sinodata",
32 | "fileId": "20036265406996757872647591",
33 | "fileMeatData": "null",
34 | "fileName": "临床实地局部解剖学.pdf",
35 | "fileSize": -1,
36 | "fileSuffix": "pdf",
37 | "gmtCreate": 1716646143000,
38 | "gmtModified": 1716646143000,
39 | "id": "1",
40 | "isDeleted": 1,
41 | "publicRead": 2,
42 | "storagePath": "stream/88/15/00106265406971424276483376.pdf",
43 | "storageType": 1
44 | }
45 | ],
46 | "gmtCreate": 1716281687000,
47 | "gmtModified": 1716291353000,
48 | "id": "8",
49 | "isDeleted": 1,
50 | "level": 2,
51 | "parentId": "20026250115495299645444018",
52 | "path": "/0/20026250115495299645444018/20026250120594315018249896",
53 | "sort": 1
54 | }
55 | ]
56 | }
57 | ],
58 | "syncId": "20036268858974255185924634"
59 | }
--------------------------------------------------------------------------------
/test/http_test/test4.http:
--------------------------------------------------------------------------------
1 | GET http://127.0.0.1:8001/async-endpoint/ HTTP/1.1
2 | Content-Type: application/json;charset=utf-8
3 |
4 | {}
--------------------------------------------------------------------------------
/test/test.json:
--------------------------------------------------------------------------------
1 |
2 | {
3 | "sysCategory" : [
4 | {
5 | "fileStorages" : [
6 | {
7 | "fileId" : "file00004",
8 | "fileName" : "py⼊⻔到放弃",
9 | "storagePath" : "stream/78/65/00106250790864562421764687.pdf",
10 | "fileSuffix" : "pdf",
11 | "publicUrl" : "http:\/\/www.baidu.com\/xxx\/xxx\/xxxpy.txt",
12 | "fileSize" : "123456"
13 | },
14 | {
15 | "fileId" : "file00005",
16 | "fileName" : "py⼊⻔到放弃2",
17 | "storagePath" : "\/xxx\/xxx\/xxx2.txt",
18 | "fileSuffix" : "txt",
19 | "publicUrl" : "http:\/\/www.baidu.com\/xxx\/xxx\/xxxpy2.txt",
20 | "fileSize" : "123456"
21 | }
22 | ],
23 | "subCategory" : [
24 | {
25 | "fileStorages" : [
26 | {
27 | "fileId" : "file00004",
28 | "fileName" : " 《医学病毒学》知识图谱-231229",
29 | "storagePath" : "stream/40/86/00106250791926853795844449.docx",
30 | "fileSuffix" : "docx",
31 | "publicUrl" : "http:\/\/www.baidu.com\/xxx\/xxx\/xxxpy.txt",
32 | "fileSize" : "123456"
33 | },
34 | {
35 | "fileId" : "file00005",
36 | "fileName" : "《学生手册》2023版",
37 | "storagePath" : "stream/85/11/00106250792497027481606005.pdf",
38 | "fileSuffix" : "pdf",
39 | "publicUrl" : "http:\/\/www.baidu.com\/xxx\/xxx\/xxxpy2.txt",
40 | "fileSize" : "123456"
41 | }
42 | ],
43 | "categoryId" : "234324234",
44 | "parentId" : "1010010101",
45 | "categoryName" : "⼆级分类"
46 | }
47 | ],
48 | "categoryId" : "1010010101",
49 | "parentId" : "0",
50 | "categoryName" : "⼀级分类",
51 | "level" : 1
52 | }
53 | ],
54 | "syncId" : "123"
55 | }
56 |
57 |
--------------------------------------------------------------------------------
/test/test.py:
--------------------------------------------------------------------------------
1 | """
2 | Author: wangjia
3 | Date: 2024-05-19 22:18:54
4 | LastEditors: wangjia
5 | LastEditTime: 2024-05-20 03:33:35
6 | Description: file content
7 | """
8 |
9 | from pymilvus import Collection, CollectionSchema, DataType, FieldSchema, connections
10 |
11 | from app.core.bce.embedding_client import EmbeddingClient
12 |
13 | connections.connect("default", host="localhost", port="19530")
14 | embedding_client = EmbeddingClient("/data/WoLLM/bce-embedding-base_v1")
15 |
16 | # 读取txt文档,句子列表
17 | f = open("test.txt", "r", encoding="utf-8")
18 | doc = f.read().splitlines()
19 |
20 | embeddings = embedding_client.get_embedding(doc)
21 |
22 | fields = [
23 | FieldSchema(
24 | name="id", dtype=DataType.VARCHAR, is_primary=True, auto_id=True, max_length=100
25 | ),
26 | FieldSchema(name="kb_name", dtype=DataType.VARCHAR, max_length=100),
27 | FieldSchema(name="file_name", dtype=DataType.VARCHAR, max_length=100),
28 | FieldSchema(name="chunk_id", dtype=DataType.INT64),
29 | FieldSchema(name="chunk_content", dtype=DataType.VARCHAR, max_length=100),
30 | FieldSchema(name="embedding", dtype=DataType.FLOAT_VECTOR, dim=768), # 向量字段
31 | ]
32 |
33 | schema = CollectionSchema(fields=fields)
34 | # 创建 collection
35 | # print("client is connected:", client.is_connected())
36 | collection = Collection("test_collection3", schema=schema)
37 | # 添加索引
38 | index_params = {
39 | "metric_type": "L2",
40 | "index_type": "IVF_FLAT",
41 | "params": {"nlist": 128},
42 | }
43 | # 字段 filmVector 创建索引
44 | collection.create_index("embedding", index_params)
45 |
46 | entities = [["test_kb"] * len(doc), doc, range(len(doc)), doc, embeddings]
47 | collection.insert(entities)
48 | # 记得在插入数据后调用 flush 来确保数据被写入到磁盘
49 | collection.flush()
50 | print(collection)
51 | collection.load()
52 |
--------------------------------------------------------------------------------
/test/test2.py:
--------------------------------------------------------------------------------
"""
Author: wangjia
Date: 2024-05-19 22:42:44
LastEditors: wangjia
LastEditTime: 2024-05-19 22:42:47
Description: file content
"""
from pymilvus import connections, utility

# Open the default connection to the local Milvus server.
connections.connect("default", host="localhost", port="19530")

# Probe the connection and report the outcome either way.
try:
    addr = utility.get_connection_addr(alias="default")
except Exception as err:
    print("Something went wrong:", err)
else:
    print("Connected to Milvus server:", addr)
19 |
--------------------------------------------------------------------------------
/test/test3.py:
--------------------------------------------------------------------------------
1 | """
2 | Author: wangjia
3 | Date: 2024-05-21 01:29:51
4 | LastEditors: wangjia
5 | LastEditTime: 2024-05-21 10:53:28
6 | Description: file content
7 | """
8 |
9 | from app.core.bce.embedding_client import EmbeddingClient
10 | from app.core.preprocessor.file_processor import FileProcesser
11 |
12 | if __name__ == "__main__":
13 | file_path = "example/test3.txt"
14 | l1_kb = "test1"
15 | l2_kb = "test2"
16 | embedding_model = "/data/gpu/base_models/bge-large-zh-v1.5"
17 | lf = FileProcesser(l1_kb, l2_kb, file_path)
18 | embedding_client = EmbeddingClient(embedding_model)
19 |
20 | docs = lf.split_file_to_docs()
21 | docs_content = [doc.page_content for doc in docs]
22 | embeddings = embedding_client.get_embedding(docs_content)
23 | print(embeddings[0])
24 | print(embeddings.shape)
25 |
--------------------------------------------------------------------------------
/test/test_async.py:
--------------------------------------------------------------------------------
# main.py
from fastapi import FastAPI
import asyncio
import uvicorn

app = FastAPI()

# asyncio only keeps a weak reference to tasks created with create_task();
# without a strong reference here, a task may be garbage-collected before
# it finishes running. See the asyncio.create_task documentation.
_background_tasks: set = set()


async def print_hello_world():
    """Background coroutine simulating async I/O work."""
    # simulated async I/O
    # await asyncio.sleep(1)

    print("Hello, World!")


@app.get("/async-endpoint")
async def root():
    """Kick off a fire-and-forget background task and return immediately."""
    task = asyncio.create_task(print_hello_world())
    # Hold a strong reference until the task completes, then drop it.
    _background_tasks.add(task)
    task.add_done_callback(_background_tasks.discard)
    return {"message": "Task started"}


if __name__ == '__main__':
    uvicorn.run(app, host="0.0.0.0", port=8001)
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | '''
2 | Author: wangjia
3 | Date: 2024-05-19 00:04:14
4 | LastEditors: wangjia
5 | LastEditTime: 2024-05-26 14:38:05
6 | Description: file content
7 | '''
import functools
import sys
import time
from datetime import datetime

from loguru import logger
12 |
13 |
class Logger:
    """Factory for a pre-configured loguru logger with per-level file sinks."""

    # The sinks form a hierarchy: the DEBUG file records everything, the
    # INFO file records INFO and above, and so on — lower-level files are
    # the most verbose, higher-level files contain only the critical entries.
    _FILE_LEVELS = ("DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL")

    @classmethod
    def get_logger(cls):
        """Configure and return the shared loguru logger.

        Adds one rotating, 30-day-retained file sink per level in
        ``_FILE_LEVELS`` (logs/mylog-<level>.log) plus a colorized stderr
        sink for CRITICAL records.

        Returns:
            The configured ``loguru.logger`` instance.
        """
        folder_ = "logs/"
        prefix_ = "mylog-"
        rotation_ = "10 MB"
        retention_ = "30 days"
        encoding_ = "utf-8"
        backtrace_ = True
        diagnose_ = True

        # timestamp | level | location | message
        format_ = (
            "{time:YYYY-MM-DD HH:mm:ss.SSS}"
            "|{level: <7}"
            "|{name}:{function}:{line}"
            "|{message}")

        # One file sink per severity level; the loop replaces five
        # copy-pasted logger.add() calls from the original.
        for level_name in cls._FILE_LEVELS:
            logger.add(
                folder_ + prefix_ + level_name.lower() + ".log",
                level=level_name,
                backtrace=backtrace_,
                diagnose=diagnose_,
                format=format_,
                colorize=False,
                rotation=rotation_,
                retention=retention_,
                encoding=encoding_,
                # Bind level_name as a default argument to avoid the
                # late-binding closure pitfall inside the loop.
                filter=lambda record, lvl=level_name: record["level"].no
                >= logger.level(lvl).no,
            )

        # Mirror CRITICAL records to stderr with color.
        logger.add(
            sys.stderr,
            level="CRITICAL",
            backtrace=backtrace_,
            diagnose=diagnose_,
            format=format_,
            colorize=True,
            filter=lambda record: record["level"].no >= logger.level(
                "CRITICAL").no,
        )

        return logger
121 |
122 |
# Module-level logger, configured once at import time.
logger = Logger.get_logger()

# Snapshot of the current time, formatted as date and datetime strings.
# NOTE(review): these are frozen at import time, not refreshed per use.
current_time = datetime.now()
current_date = current_time.strftime("%Y-%m-%d")
formatted_time = current_time.strftime("%Y-%m-%d %H:%M:%S")
130 |
131 |
# Timing decorator for quick, ad-hoc profiling of a function call.
def timeit(func):
    """Decorator that prints the wall-clock run time of each call to *func*.

    The wrapped function's return value is passed through unchanged.
    ``functools.wraps`` preserves the original ``__name__``/``__doc__`` so
    stack traces and introspection still point at *func* (the original
    wrapper lost them).
    """

    @functools.wraps(func)
    def wrapper(*args, **kw):
        start_time = time.time()
        result = func(*args, **kw)
        cost_time = time.time() - start_time
        print("==" * 25)
        print("Current Function [%s] run time is %s s" %
              (func.__name__, cost_time))
        print("==" * 25)
        return result

    return wrapper
146 |
--------------------------------------------------------------------------------