├── 01.png ├── 02.png ├── 03.png ├── README.md ├── basicTest ├── .env ├── chromaPipelineTest1.py ├── chromaPipelineTest2.py ├── componentTest.py └── inMemoryPipelineTest.py ├── ragTest ├── .env ├── apiTest.py ├── input │ ├── llama2.pdf │ └── 健康档案.pdf ├── main.py ├── queryTest.py ├── tools │ ├── __pycache__ │ │ ├── pdfSplitTest_Ch.cpython-311.pyc │ │ └── pdfSplitTest_En.cpython-311.pyc │ ├── pdfSplitTest_Ch.py │ └── pdfSplitTest_En.py └── vectorSaveTest.py └── requirements.txt /01.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NanGePlus/HaystackTest/c762261395178f1e77c8ff537a788fbdf21bde53/01.png -------------------------------------------------------------------------------- /02.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NanGePlus/HaystackTest/c762261395178f1e77c8ff537a788fbdf21bde53/02.png -------------------------------------------------------------------------------- /03.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NanGePlus/HaystackTest/c762261395178f1e77c8ff537a788fbdf21bde53/03.png -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # 1、介绍 2 | ## 1.1 主要内容 3 | 使用Haystack开源框架实现RAG(Retrieval Augmented Generation 检索增强生成)应用,核心思想:人找知识,会查资料;LLM找知识,会查向量数据库 4 | 使用LangChain开源框架实现RAG应用项目地址:https://github.com/NanGePlus/RagLangChainTest 5 | 本次应用案例实现功能为: 6 | **(1)离线步骤(构建索引):** 文档加载->文档切分->向量化->灌入向量数据库 7 | **(2)在线步骤(检索增强生成):** 获取用户问题->用户问题向量化->检索向量数据库->将检索结果和用户问题填入prompt模版->用最终的prompt调用LLM->由LLM生成回复 8 | 9 | 相关视频: 10 | https://www.bilibili.com/video/BV1nXkxYQEVp/ 11 | https://youtu.be/sI_vxzGeOUY 12 | 13 | ## 1.2 Haystack框架 14 | Haystack是一个开源框架,它由deepset开发,用于构建强大的QA问答、检索增强生成RAG等AI应用,支持构建从小型本地化应用到大规模生产级应用多种场景 15 | 官方网址:https://haystack.deepset.ai/ 16 | Github地址:https://github.com/deepset-ai/haystack 17 | **核心概念介绍:** 18 | **DocumentStore 文档存储:** 文档存储库是一个数据库,它存储你的数据,并在查询时将数据提供给检索器(Retriever) 19 | 20 | **Components 组件:** 是构成pipeline的重要组成部分,组件负责处理相关的业务逻辑 21 | **Pipelines 流水线:** 将不同的功能组件进行逻辑编排,形成一个有向无环图,最后运行pipeline 22 | 23 | 24 | 25 | # 2、前期准备工作 26 | ## 2.1 开发环境搭建:anaconda、pycharm 27 | anaconda:提供python虚拟环境,官网下载对应系统版本的安装包安装即可 28 | pycharm:提供集成开发环境,官网下载社区版本安装包安装即可 29 | **可参考如下视频:** 30 | 集成开发环境搭建Anaconda+PyCharm 31 | https://www.bilibili.com/video/BV1q9HxeEEtT/?vd_source=30acb5331e4f5739ebbad50f7cc6b949 32 | https://youtu.be/myVgyitFzrA 33 | 34 | ## 2.2 大模型相关配置 35 | (1)GPT大模型使用方案(第三方代理方式) 36 | (2)非GPT大模型(阿里通义千问、讯飞星火、智谱等大模型)使用方案(OneAPI方式) 37 | (3)本地开源大模型使用方案(Ollama方式) 38 | **可参考如下视频:** 39 | 提供一种LLM集成解决方案,一份代码支持快速同时支持gpt大模型、国产大模型(通义千问、文心一言、百度千帆、讯飞星火等)、本地开源大模型(Ollama) 40 | https://www.bilibili.com/video/BV12PCmYZEDt/?vd_source=30acb5331e4f5739ebbad50f7cc6b949 41 | https://youtu.be/CgZsdK43tcY 42 | 43 | 44 | # 3、项目初始化 45 | ## 3.1 下载源码 46 | GitHub或Gitee中下载工程文件到本地,下载地址如下: 47 | https://github.com/NanGePlus/HaystackTest 48 | https://gitee.com/NanGePlus/HaystackTest 49 | 50 | ## 3.2 构建项目 51 | 使用pycharm构建一个项目,为项目配置虚拟python环境 52 | 项目名称:HaystackTest 53 | 虚拟环境名称保持与项目名称一致 54 | 55 | ## 3.3 将相关代码拷贝到项目工程中 56 | 将下载的代码文件夹中的文件全部拷贝到新建的项目根目录下 57 | 58 | ## 3.4 安装项目依赖 59 | 新建命令行终端,在终端中运行 pip install -r requirements.txt 安装依赖 60 | **注意:** 建议先使用要求的对应版本进行本项目测试,避免因版本升级造成的代码不兼容。测试通过后,可进行升级测试 61 | 62 | 63 | # 4、功能测试 64 | ### (1)测试文档准备 65 | 这里以pdf文件为例,在input文件夹下准备了两份pdf文件 66 | 健康档案.pdf:测试中文pdf文档处理 67 | llama2.pdf:测试英文pdf文档处理 68 | ### (2)大模型准备 69 | **gpt大模型(使用代理方案):** 70 | OPENAI_BASE_URL=https://yunwu.ai/v1 71 | OPENAI_API_KEY=sk-5tKSZtEo4WsXKZJE8v4JeFqV8eNf6GwYwJFgT5JFJ42DP7qe 72 | OPENAI_CHAT_MODEL=gpt-4o-mini 73 | OPENAI_EMBEDDING_MODEL=text-embedding-3-small 74 | **非gpt大模型(使用OneAPI方案):** 75 | OPENAI_BASE_URL=http://139.224.72.218:3000/v1 76 | OPENAI_API_KEY=sk-VIm8DGiCtF5Dc46pEd393967Bf554e7a8dA5A8AeFfDcCd75 77 | OPENAI_CHAT_MODEL=qwen-plus 78 | OPENAI_EMBEDDING_MODEL=text-embedding-v1 79 | ### (3)离线步骤(构建索引) 80 | 文档加载->文档切分->向量化->灌入向量数据库 81 | 在根目录下的ragTest/tools文件夹下提供了pdfSplitTest_Ch.py脚本工具用来处理中文文档、pdfSplitTest_En.py脚本工具用来处理英文文档 82 | 在根目录下的ragTest/vectorSaveTest.py脚本执行调用tools中的工具进行文档预处理后进行向量计算及灌库(Chroma向量数据库) 83 | 打开命令行终端,进入脚本所在目录,运行 python vectorSaveTest.py 命令 84 | ### (4-1)在线步骤(检索增强生成),测试demo 85 | 获取用户问题->用户问题向量化->检索向量数据库->将检索结果和用户问题填入prompt模版->用最终的prompt调用LLM->由LLM生成回复 86 | 在根目录下的ragTest/queryTest.py脚本实现核心业务逻辑 87 | 打开命令行终端,进入脚本所在目录,运行 python queryTest.py 命令 88 | ### (4-2)在线步骤(检索增强生成),封装为API接口对外提供服务 89 | 获取用户问题->用户问题向量化->检索向量数据库->将检索结果和用户问题填入prompt模版->用最终的prompt调用LLM->由LLM生成回复 90 | 在根目录下的ragTest/main.py脚本实现核心业务逻辑并封装为API接口对外提供服务 91 | 在根目录下的ragTest/apiTest.py脚本实现POST请求,调用main服务API接口进行检索增强生成 92 | 打开命令行终端,进入脚本所在目录,首先运行 python main.py 命令启动API接口服务 93 | 再新打开一个命令行终端,进入脚本所在目录,运行 python apiTest.py 命令进行POST请求 94 | -------------------------------------------------------------------------------- /basicTest/.env: -------------------------------------------------------------------------------- 1 | OPENAI_BASE_URL=https://yunwu.ai/v1 2 | OPENAI_API_KEY=sk-5tKSZtEo4WsXKZJE8v4JeFqV8eNf6GwYwJFgT5JFJ42DP7qe 3 | OPENAI_CHAT_MODEL=gpt-4o-mini 4 | OPENAI_EMBEDDING_MODEL = text-embedding-3-small 5 | -------------------------------------------------------------------------------- /basicTest/chromaPipelineTest1.py: -------------------------------------------------------------------------------- 1 | import os 2 | from dotenv import load_dotenv 3 | from haystack import Pipeline, Document 4 | from haystack.components.generators import OpenAIGenerator 5 | from haystack.components.builders.prompt_builder import PromptBuilder 6 | from haystack.utils import Secret 7 | from haystack.components.writers import DocumentWriter 8 | from haystack_integrations.document_stores.chroma import ChromaDocumentStore 9 | from haystack_integrations.components.retrievers.chroma import ChromaQueryTextRetriever 10 | 11 | 12 | 13 | 14 | # 加载环境变量参数 15 | load_dotenv() 16 | 17 | # 创建一个Chroma中的文档存储实例 18 | document_store = ChromaDocumentStore(persist_path="ChromDB001") 19 | 20 | 21 | # 灌入向量数据库 22 | # 测试文档 23 | documents=[ 24 | Document(content="My name is Jean and I live in Paris.", meta={"title": "one"}), 25 | Document(content="My name is Mark and I live in Berlin.", meta={"title": "two"}), 26 | Document(content="My name is Giorgio and I live in Rome.", meta={"title": "three"}) 27 | ] 28 | 29 | # 写入向量数据库 30 | writer = DocumentWriter(document_store) 31 | 32 | # 创建一个新的流水线对象 33 | indexing_pipeline = Pipeline() 34 | # 添加组件 name:组件名称 instance:组件实例 35 | indexing_pipeline.add_component("writer", writer) 36 | # 运行流水线,并传入每个组件的初始输入 37 | results = indexing_pipeline.run( 38 | data={ 39 | "writer": {"documents": documents} 40 | }, 41 | include_outputs_from={"writer"} 42 | ) 43 | 44 | # 运行结果,结果是一个嵌套字典 45 | print(f"results:{results}\n") 46 | 47 | 48 | # 检索 49 | # 定义prompt模版 使用Jinja2 模板语法 50 | prompt_template = """ 51 | Given these documents, answer the question. 52 | Documents: 53 | {% for doc in documents %} 54 | {{ doc.content }} 55 | {% endfor %} 56 | Question: {{question}} 57 | Answer: 58 | """ 59 | 60 | # 用于从文档存储中根据查询找到最相关的文档 61 | # ChromaQueryTextRetriever 使用默认的Embedding模型 62 | # 模型所在位置:/Users/username/.cache/chroma/onnx_models/all-MiniLM-L6-v2 63 | retriever = ChromaQueryTextRetriever(document_store=document_store) 64 | 65 | # 使用prompt模板构建自定义prompt 66 | prompt_builder = PromptBuilder(template=prompt_template) 67 | 68 | # 设置调用 OpenAI Chat模型 生成内容 69 | llm = OpenAIGenerator( 70 | api_base_url=os.getenv("OPENAI_BASE_URL"), 71 | api_key=Secret.from_env_var("OPENAI_API_KEY"), 72 | model=os.getenv("OPENAI_CHAT_MODEL") 73 | ) 74 | 75 | # 创建一个新的流水线对象 76 | query_pipeline = Pipeline() 77 | 78 | # 添加组件 name:组件名称 instance:组件实例 79 | query_pipeline.add_component("retriever", retriever) 80 | query_pipeline.add_component("prompt_builder", prompt_builder) 81 | query_pipeline.add_component("llm", llm) 82 | 83 | # 连接组件 84 | query_pipeline.connect("retriever", "prompt_builder.documents") 85 | query_pipeline.connect("prompt_builder", "llm") 86 | 87 | # 定义问题 88 | question = "Who lives in Paris?" 89 | # 运行流水线,并传入每个组件的初始输入 90 | results = query_pipeline.run( 91 | data={ 92 | "retriever": {"query": question, "top_k": 3}, 93 | "prompt_builder": {"question": question}, 94 | }, 95 | include_outputs_from={"retriever","prompt_builder"} 96 | ) 97 | 98 | # 运行结果,结果是一个嵌套字典 99 | print(f"results:{results}\n") 100 | 101 | # 从嵌套字典中取出结果 102 | response = results["llm"]["replies"] 103 | print(f"response:{response}\n") 104 | 105 | 106 | 107 | 108 | 109 | 110 | -------------------------------------------------------------------------------- /basicTest/chromaPipelineTest2.py: -------------------------------------------------------------------------------- 1 | import os 2 | from dotenv import load_dotenv 3 | from haystack import Pipeline, Document 4 | from haystack.components.generators import OpenAIGenerator 5 | from haystack.components.builders.prompt_builder import PromptBuilder 6 | from haystack.utils import Secret 7 | from haystack.components.writers import DocumentWriter 8 | from haystack_integrations.document_stores.chroma import ChromaDocumentStore 9 | from haystack_integrations.components.retrievers.chroma import ChromaEmbeddingRetriever 10 | from haystack.components.embedders import OpenAITextEmbedder, OpenAIDocumentEmbedder 11 | 12 | 13 | 14 | 15 | # 加载环境变量参数 16 | load_dotenv() 17 | 18 | # 创建一个Chroma中的文档存储实例 19 | document_store = ChromaDocumentStore(persist_path="ChromDB002") 20 | 21 | 22 | # 灌入向量数据库 23 | # 测试文档 24 | documents=[ 25 | Document(content="My name is Jean and I live in Paris.", meta={"title": "one"}), 26 | Document(content="My name is Mark and I live in Berlin.", meta={"title": "two"}), 27 | Document(content="My name is Giorgio and I live in Rome.", meta={"title": "three"}) 28 | ] 29 | 30 | # 写入向量数据库 31 | writer = DocumentWriter(document_store) 32 | 33 | # 设置调用 OpenAI Embedding模型 进行向量处理 34 | # 这里注意,构建向量索引使用OpenAIDocumentEmbedder 35 | index_embedder = OpenAIDocumentEmbedder( 36 | api_base_url=os.getenv("OPENAI_BASE_URL"), 37 | api_key=Secret.from_env_var("OPENAI_API_KEY"), 38 | model=os.getenv("OPENAI_EMBEDDING_MODEL") 39 | ) 40 | 41 | # 创建一个新的流水线对象 42 | indexing_pipeline = Pipeline() 43 | # 添加组件 name:组件名称 instance:组件实例 44 | indexing_pipeline.add_component("index_embedder", index_embedder) 45 | indexing_pipeline.add_component("writer", writer) 46 | indexing_pipeline.connect("index_embedder.documents", "writer.documents") 47 | 48 | # 运行流水线,并传入每个组件的初始输入 49 | results = indexing_pipeline.run( 50 | data={ 51 | "index_embedder": {"documents": documents} 52 | }, 53 | include_outputs_from={"writer"} 54 | ) 55 | # 运行结果,结果是一个嵌套字典 56 | print(f"results:{results}\n") 57 | 58 | 59 | # 检索 60 | # 定义prompt模版 使用Jinja2 模板语法 61 | prompt_template = """ 62 | Given these documents, answer the question. 63 | Documents: 64 | {% for doc in documents %} 65 | {{ doc.content }} 66 | {% endfor %} 67 | Question: {{question}} 68 | Answer: 69 | """ 70 | 71 | # 设置调用 OpenAI Embedding模型 进行向量处理 72 | # 这里注意,查询使用OpenAITextEmbedder 73 | text_embedder = OpenAITextEmbedder( 74 | api_base_url=os.getenv("OPENAI_BASE_URL"), 75 | api_key=Secret.from_env_var("OPENAI_API_KEY"), 76 | model=os.getenv("OPENAI_EMBEDDING_MODEL") 77 | ) 78 | 79 | # 设置调用 OpenAI Chat模型 生成内容 80 | llm = OpenAIGenerator( 81 | api_base_url=os.getenv("OPENAI_BASE_URL"), 82 | api_key=Secret.from_env_var("OPENAI_API_KEY"), 83 | model=os.getenv("OPENAI_CHAT_MODEL") 84 | ) 85 | 86 | # 用于从文档存储中根据查询找到最相关的文档 87 | retriever = ChromaEmbeddingRetriever(document_store=document_store) 88 | 89 | # 使用prompt模板构建自定义prompt 90 | prompt_builder = PromptBuilder(template=prompt_template) 91 | 92 | # 创建一个新的流水线对象 93 | query_pipeline = Pipeline() 94 | 95 | # 添加组件 name:组件名称 instance:组件实例 96 | query_pipeline.add_component("text_embedder", text_embedder) 97 | query_pipeline.add_component("retriever", retriever) 98 | query_pipeline.add_component("prompt_builder", prompt_builder) 99 | query_pipeline.add_component("llm", llm) 100 | 101 | # 连接组件 102 | query_pipeline.connect("text_embedder.embedding", "retriever.query_embedding") 103 | query_pipeline.connect("retriever", "prompt_builder.documents") 104 | query_pipeline.connect("prompt_builder", "llm") 105 | 106 | # 定义问题 107 | question = "Who lives in Paris?" 108 | 109 | # 运行流水线,并传入每个组件的初始输入 110 | results = query_pipeline.run( 111 | data={ 112 | "text_embedder": {"text": question}, 113 | "retriever": {"top_k": 2}, 114 | "prompt_builder": {"question": question}, 115 | }, 116 | include_outputs_from={"retriever","prompt_builder"} 117 | ) 118 | 119 | # 运行结果,结果是一个嵌套字典 120 | print(f"results:{results}\n") 121 | 122 | # 从嵌套字典中取出最终结果 123 | response = results["llm"]["replies"] 124 | print(f"response:{response}\n") 125 | 126 | 127 | 128 | 129 | -------------------------------------------------------------------------------- /basicTest/componentTest.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | from haystack import component, Pipeline 3 | 4 | 5 | 6 | # 声明这是一个 Haystack 组件,能够集成到 Pipeline 中 7 | # 定义了一个类 生成个性化的欢迎信息并将其转为大写 8 | @component 9 | class WelcomeTextGenerator: 10 | # 定义组件的输出类型 11 | # welcome_text 是一个字符串,用于存储欢迎消息 12 | # note 是一个字符串,用于存储注释信息 13 | @component.output_types(welcome_text=str, note=str) 14 | # 定义组件的核心逻辑。name 是方法的输入参数,用于接受用户输入的名字 15 | # 返回一个字典,包含 welcome_text 和 note 两个键 16 | def run(self, name: str): 17 | return {"welcome_text": ('Hello {name}, welcome to Haystack!'.format(name=name)).upper(), 18 | "note": "welcome message is ready"} 19 | 20 | 21 | # 声明这是一个 Haystack 组件,能够集成到 Pipeline 中 22 | # 定义了一个类 根据空格拆分文本 23 | @component 24 | class WhitespaceSplitter: 25 | # 定义输出类型 26 | # splitted_text 是一个字符串列表,存储拆分后的文本 27 | @component.output_types(splitted_text=List[str]) 28 | # 定义组件的核心逻辑。text 是方法的输入参数 29 | # 返回一个字典,包含键 splitted_text 30 | def run(self, text: str): 31 | return {"splitted_text": text.split()} 32 | 33 | 34 | # 实例化一个Pipeline 对象 35 | text_pipeline = Pipeline() 36 | # 向流水线中添加组件 37 | # name 是组件的标识符,用于引用组件 38 | # instance 是组件的实例 39 | text_pipeline.add_component(name="welcome_text_generator", instance=WelcomeTextGenerator()) 40 | text_pipeline.add_component(name="splitter", instance=WhitespaceSplitter()) 41 | 42 | # 连接流水线中的组件 将 welcome_text_generator 的输出字段 welcome_text 连接到 splitter 的输入字段 text 43 | # sender 表示发送数据的组件和其输出字段 44 | # receiver 表示接收数据的组件和其输入字段 45 | text_pipeline.connect(sender="welcome_text_generator.welcome_text", receiver="splitter.text") 46 | 47 | # 运行流水线 48 | # 一个字典,用于为每个组件提供初始输入 49 | result = text_pipeline.run({"welcome_text_generator": {"name": "Bilge"}}) 50 | 51 | 52 | # 运行结果,结果是一个嵌套字典 53 | print(f"result:{result}\n") 54 | # 从嵌套字典中取出结果 55 | response = result["splitter"]["splitted_text"] 56 | print(f"response:{response}\n") -------------------------------------------------------------------------------- /basicTest/inMemoryPipelineTest.py: -------------------------------------------------------------------------------- 1 | import os 2 | from dotenv import load_dotenv 3 | from idlelib.rpc import response_queue 4 | from haystack import Pipeline, Document 5 | from haystack.document_stores.in_memory import InMemoryDocumentStore 6 | from haystack.components.retrievers import InMemoryBM25Retriever 7 | from haystack.components.generators import OpenAIGenerator 8 | from haystack.components.builders.prompt_builder import PromptBuilder 9 | from haystack.utils import Secret 10 | 11 | 12 | 13 | 14 | # 加载环境变量参数 15 | load_dotenv() 16 | 17 | # 创建一个内存中的文档存储实例 18 | # 文档将存储在内存中,不需要外部数据库支持 19 | document_store = InMemoryDocumentStore() 20 | 21 | # 向文档存储中添加文档 22 | # 每个文档包含一段内容 content 23 | document_store.write_documents([ 24 | Document(content="My name is Jean and I live in Paris."), 25 | Document(content="My name is Mark and I live in Berlin."), 26 | Document(content="My name is Giorgio and I live in Rome.") 27 | ]) 28 | 29 | # 定义prompt模版 使用Jinja2 模板语法 30 | prompt_template = """ 31 | Given these documents, answer the question. 32 | Documents: 33 | {% for doc in documents %} 34 | {{ doc.content }} 35 | {% endfor %} 36 | Question: {{question}} 37 | Answer: 38 | """ 39 | 40 | # 基于 BM25 算法的检索器 41 | # 用于从文档存储中根据查询找到最相关的文档 42 | retriever = InMemoryBM25Retriever(document_store=document_store) 43 | 44 | # 使用prompt模板构建自定义prompt 45 | prompt_builder = PromptBuilder(template=prompt_template) 46 | 47 | # 设置调用 OpenAI Chat模型 生成内容 48 | # Secret用于安全地管理敏感信息 49 | llm = OpenAIGenerator( 50 | api_base_url=os.getenv("OPENAI_BASE_URL"), 51 | api_key=Secret.from_token(os.getenv("OPENAI_API_KEY")), 52 | # api_key=Secret.from_env_var("OPENAI_API_KEY"), 53 | model=os.getenv("OPENAI_CHAT_MODEL") 54 | ) 55 | 56 | # 创建一个新的流水线对象 57 | pipeline = Pipeline() 58 | 59 | # 添加组件 name:组件名称 instance:组件实例 60 | pipeline.add_component("retriever", retriever) 61 | pipeline.add_component("prompt_builder", prompt_builder) 62 | pipeline.add_component("llm", llm) 63 | 64 | # 连接组件 65 | # retriever 的输出(相关文档列表)作为 prompt_builder.documents 的输入 66 | pipeline.connect("retriever", "prompt_builder.documents") 67 | # prompt_builder 的输出(生成的提示)作为 llm 的输入 68 | pipeline.connect("prompt_builder", "llm") 69 | 70 | # 定义问题 71 | question = "Who lives in Paris?" 72 | 73 | # 1、运行流水线,并传入每个组件的初始输入 74 | results = pipeline.run( 75 | data={ 76 | "retriever": {"query": question}, 77 | "prompt_builder": {"question": question}, 78 | }, 79 | # include_outputs_from={"retriever","prompt_builder"} 80 | ) 81 | # 运行结果,结果是一个嵌套字典 82 | print(f"results:{results}\n") 83 | # 从嵌套字典中取出结果 84 | response = results["llm"]["replies"] 85 | print(f"response:{response}\n") 86 | 87 | # # 2、流水线可视化 保存为图片 88 | # pipeline.draw(path="test.png") 89 | # 90 | # # 3、序列化 保存到YAML文件 91 | # print(pipeline.dumps()) 92 | # with open("test.yml", "w") as file: 93 | # pipeline.dump(file) 94 | 95 | 96 | 97 | 98 | 99 | 100 | -------------------------------------------------------------------------------- /ragTest/.env: -------------------------------------------------------------------------------- 1 | OPENAI_BASE_URL=https://yunwu.ai/v1 2 | OPENAI_API_KEY=sk-5tKSZtEo4WsXKZJE8v4JeFqV8eNf6GwYwJFgT5JFJ42DP7qe 3 | OPENAI_CHAT_MODEL=gpt-4o-mini 4 | OPENAI_EMBEDDING_MODEL = text-embedding-3-small 5 | -------------------------------------------------------------------------------- /ragTest/apiTest.py: -------------------------------------------------------------------------------- 1 | import requests 2 | import json 3 | import logging 4 | 5 | 6 | # 设置日志模版 7 | logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s') 8 | logger = logging.getLogger(__name__) 9 | 10 | 11 | url = "http://localhost:8012/v1/chat/completions" 12 | headers = {"Content-Type": "application/json"} 13 | 14 | # 默认非流式输出 True or False 15 | stream_flag = False 16 | 17 | # 测试1 18 | data = { 19 | "messages": [{"role": "user", "content": "张三九的基本信息是什么"}], 20 | "stream": stream_flag, 21 | } 22 | 23 | # # 测试2 24 | # data = { 25 | # "messages": [{"role": "user", "content": "张三九的配偶是谁以及其联系方式"}], 26 | # "stream": stream_flag, 27 | # } 28 | 29 | # # 测试3 测试幻觉 30 | # data = { 31 | # "messages": [{"role": "user", "content": "LangChain是什么,详细介绍下?"}], 32 | # "stream": stream_flag, 33 | # } 34 | 35 | # 接收流式输出 36 | if stream_flag: 37 | try: 38 | with requests.post(url, stream=True, headers=headers, data=json.dumps(data)) as response: 39 | for line in response.iter_lines(): 40 | if line: 41 | json_str = line.decode('utf-8').strip("data: ") 42 | # 检查是否为空或不合法的字符串 43 | if not json_str: 44 | logger.info(f"收到空字符串,跳过...") 45 | continue 46 | # 确保字符串是有效的JSON格式 47 | if json_str.startswith('{') and json_str.endswith('}'): 48 | try: 49 | data = json.loads(json_str) 50 | if data['choices'][0]['finish_reason'] == "stop": 51 | logger.info(f"接收JSON数据结束") 52 | else: 53 | logger.info(f"流式输出,响应内容是: {data['choices'][0]['delta']['content']}") 54 | except json.JSONDecodeError as e: 55 | logger.info(f"JSON解析错误: {e}") 56 | else: 57 | print(f"无效JSON格式: {json_str}") 58 | except Exception as e: 59 | print(f"Error occurred: {e}") 60 | 61 | # 接收非流式输出处理 62 | else: 63 | # 发送post请求 64 | response = requests.post(url, headers=headers, data=json.dumps(data)) 65 | # logger.info(f"接收到返回的响应原始内容: {response.json()}\n") 66 | content = response.json()['choices'][0]['message']['content'] 67 | logger.info(f"非流式输出,响应内容是: {content}\n") -------------------------------------------------------------------------------- /ragTest/input/llama2.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NanGePlus/HaystackTest/c762261395178f1e77c8ff537a788fbdf21bde53/ragTest/input/llama2.pdf -------------------------------------------------------------------------------- /ragTest/input/健康档案.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NanGePlus/HaystackTest/c762261395178f1e77c8ff537a788fbdf21bde53/ragTest/input/健康档案.pdf -------------------------------------------------------------------------------- /ragTest/main.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | import json 4 | import asyncio 5 | import uuid 6 | import time 7 | import logging 8 | from contextlib import asynccontextmanager 9 | from pydantic import BaseModel, Field 10 | from typing import List, Optional, Dict 11 | from dotenv import load_dotenv 12 | from haystack import Pipeline, Document 13 | from haystack.components.generators import OpenAIGenerator 14 | from haystack.components.builders.prompt_builder import PromptBuilder 15 | from haystack.utils import Secret 16 | from haystack_integrations.document_stores.chroma import ChromaDocumentStore 17 | from haystack_integrations.components.retrievers.chroma import ChromaEmbeddingRetriever 18 | from haystack.components.embedders import OpenAITextEmbedder 19 | from fastapi import FastAPI, HTTPException, Request 20 | from fastapi.responses import JSONResponse, StreamingResponse 21 | import uvicorn 22 | 23 | 24 | 25 | # 设置日志模版 26 | logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s') 27 | logger = logging.getLogger(__name__) 28 | 29 | # 加载环境变量参数 30 | load_dotenv() 31 | 32 | # 全局变量 33 | document_store = None 34 | text_embedder = None 35 | llm = None 36 | retriever = None 37 | prompt_builder = None 38 | query_pipeline = None 39 | 40 | # 指定向量数据库chromaDB的存储位置和集合 根据自己的实际情况进行调整 41 | CHROMADB_DIRECTORY = "chromaDB" # chromaDB向量数据库的持久化路径 42 | CHROMADB_COLLECTION_NAME = "demo001" # 待查询的chromaDB向量数据库的集合名称 43 | 44 | # API服务设置相关 根据自己的实际情况进行调整 45 | PORT = 8012 # 服务访问的端口 46 | 47 | 48 | 49 | # 定义Message类 50 | class Message(BaseModel): 51 | role: str 52 | content: str 53 | 54 | # 定义ChatCompletionRequest类 55 | class ChatCompletionRequest(BaseModel): 56 | messages: List[Message] 57 | stream: Optional[bool] = False 58 | 59 | # 定义ChatCompletionResponseChoice类 60 | class ChatCompletionResponseChoice(BaseModel): 61 | index: int 62 | message: Message 63 | finish_reason: Optional[str] = None 64 | 65 | # 定义ChatCompletionResponse类 66 | class ChatCompletionResponse(BaseModel): 67 | id: str = Field(default_factory=lambda: f"chatcmpl-{uuid.uuid4().hex}") 68 | object: str = "chat.completion" 69 | created: int = Field(default_factory=lambda: int(time.time())) 70 | choices: List[ChatCompletionResponseChoice] 71 | system_fingerprint: Optional[str] = None 72 | 73 | 74 | # 格式化响应,对输入的文本进行段落分隔、添加适当的换行符,以及在代码块中增加标记,以便生成更具可读性的输出 75 | def format_response(response): 76 | # 使用正则表达式 \n{2, }将输入的response按照两个或更多的连续换行符进行分割。这样可以将文本分割成多个段落,每个段落由连续的非空行组成 77 | paragraphs = re.split(r'\n{2,}', response) 78 | # 空列表,用于存储格式化后的段落 79 | formatted_paragraphs = [] 80 | # 遍历每个段落进行处理 81 | for para in paragraphs: 82 | # 检查段落中是否包含代码块标记 83 | if '```' in para: 84 | # 将段落按照```分割成多个部分,代码块和普通文本交替出现 85 | parts = para.split('```') 86 | for i, part in enumerate(parts): 87 | # 检查当前部分的索引是否为奇数,奇数部分代表代码块 88 | if i % 2 == 1: # 这是代码块 89 | # 将代码块部分用换行符和```包围,并去除多余的空白字符 90 | parts[i] = f"\n```\n{part.strip()}\n```\n" 91 | # 将分割后的部分重新组合成一个字符串 92 | para = ''.join(parts) 93 | else: 94 | # 否则,将句子中的句点后面的空格替换为换行符,以便句子之间有明确的分隔 95 | para = para.replace('. ', '.\n') 96 | # 将格式化后的段落添加到formatted_paragraphs列表 97 | # strip()方法用于移除字符串开头和结尾的空白字符(包括空格、制表符 \t、换行符 \n等) 98 | formatted_paragraphs.append(para.strip()) 99 | # 将所有格式化后的段落用两个换行符连接起来,以形成一个具有清晰段落分隔的文本 100 | return '\n\n'.join(formatted_paragraphs) 101 | 102 | 103 | # 定义了一个异步函数 lifespan,它接收一个FastAPI应用实例app作为参数。这个函数将管理应用的生命周期,包括启动和关闭时的操作 104 | # 函数在应用启动时执行一些初始化操作,如设置搜索引擎、加载上下文数据、以及初始化问题生成器 105 | # 函数在应用关闭时执行一些清理操作 106 | # @asynccontextmanager 装饰器用于创建一个异步上下文管理器,它允许你在 yield 之前和之后执行特定的代码块,分别表示启动和关闭时的操作 107 | @asynccontextmanager 108 | async def lifespan(app: FastAPI): 109 | # 启动时执行 110 | # 申明引用全局变量,在函数中被初始化,并在整个应用中使用 111 | global CHROMADB_DIRECTORY, CHROMADB_COLLECTION_NAME 112 | global document_store, text_embedder, llm, retriever, prompt_builder, query_pipeline 113 | try: 114 | logger.info("正在初始化...") 115 | # 创建一个Chroma中的文档存储实例 116 | document_store = ChromaDocumentStore(persist_path=CHROMADB_DIRECTORY, collection_name=CHROMADB_COLLECTION_NAME) 117 | # 定义prompt模版 118 | prompt_template = """ 119 | 你是一个针对健康档案进行问答的机器人。 120 | 你的任务是根据下述给定的已知信息回答用户问题。 121 | 122 | 已知信息: 123 | {% for doc in documents %} 124 | {{ doc.content }} 125 | {% endfor %} 126 | 127 | 用户问: 128 | {{question}} 129 | 130 | 如果已知信息不包含用户问题的答案,或者已知信息不足以回答用户的问题,请直接回复"我无法回答您的问题"。 131 | 请不要输出已知信息中不包含的信息或答案。 132 | 请不要输出已知信息中不包含的信息或答案。 133 | 请不要输出已知信息中不包含的信息或答案。 134 | 请用中文回答用户问题。 135 | """ 136 | # 设置调用 OpenAI Embedding模型 进行向量处理 137 | text_embedder = OpenAITextEmbedder( 138 | api_base_url=os.getenv("OPENAI_BASE_URL"), 139 | api_key=Secret.from_env_var("OPENAI_API_KEY"), 140 | model=os.getenv("OPENAI_EMBEDDING_MODEL") 141 | ) 142 | # 设置调用 OpenAI Chat模型 生成内容 143 | llm = OpenAIGenerator( 144 | api_base_url=os.getenv("OPENAI_BASE_URL"), 145 | api_key=Secret.from_env_var("OPENAI_API_KEY"), 146 | model=os.getenv("OPENAI_CHAT_MODEL") 147 | ) 148 | # 用于从文档存储中根据查询找到最相关的文档 149 | retriever = ChromaEmbeddingRetriever(document_store=document_store) 150 | # 使用prompt模板构建自定义prompt 151 | prompt_builder = PromptBuilder(template=prompt_template) 152 | # # 创建一个新的流水线对象 153 | query_pipeline = Pipeline() 154 | # 添加组件 name:组件名称 instance:组件实例 155 | query_pipeline.add_component("text_embedder", text_embedder) 156 | query_pipeline.add_component("retriever", retriever) 157 | query_pipeline.add_component("prompt_builder", prompt_builder) 158 | query_pipeline.add_component("llm", llm) 159 | # 连接组件 160 | query_pipeline.connect("text_embedder.embedding", "retriever.query_embedding") 161 | query_pipeline.connect("retriever", "prompt_builder.documents") 162 | query_pipeline.connect("prompt_builder", "llm") 163 | 164 | logger.info("初始化完成!") 165 | 166 | except Exception as e: 167 | logger.error(f"初始化过程中出错: {str(e)}") 168 | # raise 关键字重新抛出异常,以确保程序不会在错误状态下继续运行 169 | raise 170 | 171 | # yield 关键字将控制权交还给FastAPI框架,使应用开始运行 172 | # 分隔了启动和关闭的逻辑。在yield 之前的代码在应用启动时运行,yield 之后的代码在应用关闭时运行 173 | yield 174 | # 关闭时执行 175 | logger.info("正在关闭...") 176 | 177 | 178 | # lifespan 参数用于在应用程序生命周期的开始和结束时执行一些初始化或清理工作 179 | app = FastAPI(lifespan=lifespan) 180 | 181 | 182 | # POST请求接口,与大模型进行知识问答 183 | @app.post("/v1/chat/completions") 184 | async def chat_completions(request: ChatCompletionRequest): 185 | # 判断初始化是否完成 186 | if not document_store or not text_embedder or not retriever or not prompt_builder or not llm or not query_pipeline: 187 | logger.error("服务未初始化") 188 | raise HTTPException(status_code=500, detail="服务未初始化") 189 | 190 | try: 191 | logger.info(f"收到聊天完成请求: {request}") 192 | question = request.messages[-1].content 193 | logger.info(f"用户问题是: {question}") 194 | # 运行流水线,并传入每个组件的初始输入 195 | results = query_pipeline.run( 196 | data={ 197 | "text_embedder": {"text": question}, 198 | "retriever": {"top_k": 2}, 199 | "prompt_builder": {"question": question}, 200 | }, 201 | include_outputs_from={"retriever", "prompt_builder"} 202 | ) 203 | # 对结果进行格式化处理 204 | formatted_response = str(format_response(results["llm"]["replies"][0])) 205 | logger.info(f"格式化的搜索结果: {formatted_response}") 206 | 207 | # 处理流式响应 208 | if request.stream: 209 | # 定义一个异步生成器函数,用于生成流式数据 210 | async def generate_stream(): 211 | # 为每个流式数据片段生成一个唯一的chunk_id 212 | chunk_id = f"chatcmpl-{uuid.uuid4().hex}" 213 | # 将格式化后的响应按行分割 214 | lines = formatted_response.split('\n') 215 | # 历每一行,并构建响应片段 216 | for i, line in enumerate(lines): 217 | # 创建一个字典,表示流式数据的一个片段 218 | chunk = { 219 | "id": chunk_id, 220 | "object": "chat.completion.chunk", 221 | "created": int(time.time()), 222 | # "model": request.model, 223 | "choices": [ 224 | { 225 | "index": 0, 226 | "delta": {"content": line + '\n'}, # if i > 0 else {"role": "assistant", "content": ""}, 227 | "finish_reason": None 228 | } 229 | ] 230 | } 231 | # 将片段转换为JSON格式并生成 232 | yield f"{json.dumps(chunk)}\n" 233 | # 每次生成数据后,异步等待0.5秒 234 | await asyncio.sleep(0.5) 235 | # 生成最后一个片段,表示流式响应的结束 236 | final_chunk = { 237 | "id": chunk_id, 238 | "object": "chat.completion.chunk", 239 | "created": int(time.time()), 240 | "choices": [ 241 | { 242 | "index": 0, 243 | "delta": {}, 244 | "finish_reason": "stop" 245 | } 246 | ] 247 | } 248 | yield f"{json.dumps(final_chunk)}\n" 249 | 250 | # 返回fastapi.responses中StreamingResponse对象,流式传输数据 251 | # media_type设置为text/event-stream以符合SSE(Server-SentEvents) 格式 252 | return StreamingResponse(generate_stream(), media_type="text/event-stream") 253 | 254 | # 处理非流式响应处理 255 | else: 256 | response = ChatCompletionResponse( 257 | choices=[ 258 | ChatCompletionResponseChoice( 259 | index=0, 260 | message=Message(role="assistant", content=formatted_response), 261 | finish_reason="stop" 262 | ) 263 | ] 264 | ) 265 | logger.info(f"发送响应内容: \n{response}") 266 | # 返回fastapi.responses中JSONResponse对象 267 | # model_dump()方法通常用于将Pydantic模型实例的内容转换为一个标准的Python字典,以便进行序列化 268 | return JSONResponse(content=response.model_dump()) 269 | 270 | except Exception as e: 271 | logger.error(f"处理聊天完成时出错:\n\n {str(e)}") 272 | raise HTTPException(status_code=500, detail=str(e)) 273 | 274 | 275 | 276 | 277 | if __name__ == "__main__": 278 | logger.info(f"在端口 {PORT} 上启动服务器") 279 | # uvicorn是一个用于运行ASGI应用的轻量级、超快速的ASGI服务器实现 280 | # 用于部署基于FastAPI框架的异步PythonWeb应用程序 281 | uvicorn.run(app, host="0.0.0.0", port=PORT) 282 | 283 | 284 | -------------------------------------------------------------------------------- /ragTest/queryTest.py: -------------------------------------------------------------------------------- 1 | import os 2 | from dotenv import load_dotenv 3 | from haystack import Pipeline, Document 4 | from haystack.components.generators import OpenAIGenerator 5 | from haystack.components.builders.prompt_builder import PromptBuilder 6 | from haystack.utils import Secret 7 | from haystack_integrations.document_stores.chroma import ChromaDocumentStore 8 | from haystack_integrations.components.retrievers.chroma import ChromaEmbeddingRetriever 9 | from haystack.components.embedders import OpenAITextEmbedder 10 | 11 | 12 | 13 | 14 | # 指定向量数据库chromaDB的存储位置和集合 根据自己的实际情况进行调整 15 | CHROMADB_DIRECTORY = "chromaDB" # chromaDB向量数据库的持久化路径 16 | CHROMADB_COLLECTION_NAME = "demo001" # 待查询的chromaDB向量数据库的集合名称 17 | 18 | # 加载环境变量参数 19 | load_dotenv() 20 | 21 | # 创建一个Chroma中的文档存储实例 22 | document_store = ChromaDocumentStore(persist_path=CHROMADB_DIRECTORY,collection_name=CHROMADB_COLLECTION_NAME) 23 | 24 | # 定义prompt模版 使用Jinja2 模板语法 25 | prompt_template = """ 26 | 你是一个针对健康档案进行问答的机器人。 27 | 你的任务是根据下述给定的已知信息回答用户问题。 28 | 29 | 已知信息: 30 | {% for doc in documents %} 31 | {{ doc.content }} 32 | {% endfor %} 33 | 34 | 用户问: 35 | {{question}} 36 | 37 | 如果已知信息不包含用户问题的答案,或者已知信息不足以回答用户的问题,请直接回复"我无法回答您的问题"。 38 | 请不要输出已知信息中不包含的信息或答案。 39 | 请不要输出已知信息中不包含的信息或答案。 40 | 请不要输出已知信息中不包含的信息或答案。 41 | 请用中文回答用户问题。 42 | """ 43 | 44 | # 设置调用 OpenAI Embedding模型 进行向量处理 45 | # 这里注意,查询使用OpenAITextEmbedder 46 | text_embedder = OpenAITextEmbedder( 47 | api_base_url=os.getenv("OPENAI_BASE_URL"), 48 | api_key=Secret.from_env_var("OPENAI_API_KEY"), 49 | model=os.getenv("OPENAI_EMBEDDING_MODEL") 50 | ) 51 | 52 | # 设置调用 OpenAI Chat模型 生成内容 53 | llm = OpenAIGenerator( 54 | api_base_url=os.getenv("OPENAI_BASE_URL"), 55 | api_key=Secret.from_env_var("OPENAI_API_KEY"), 56 | model=os.getenv("OPENAI_CHAT_MODEL") 57 | ) 58 | 59 | # 用于从文档存储中根据查询找到最相关的文档 60 | retriever = ChromaEmbeddingRetriever(document_store=document_store) 61 | 62 | # 使用prompt模板构建自定义prompt 63 | prompt_builder = PromptBuilder(template=prompt_template) 64 | 65 | # 创建一个新的流水线对象 66 | query_pipeline = Pipeline() 67 | 68 | # 添加组件 name:组件名称 instance:组件实例 69 | query_pipeline.add_component("text_embedder", text_embedder) 70 | query_pipeline.add_component("retriever", retriever) 71 | query_pipeline.add_component("prompt_builder", prompt_builder) 72 | query_pipeline.add_component("llm", llm) 73 | 74 | # 连接组件 75 | query_pipeline.connect("text_embedder.embedding", "retriever.query_embedding") 76 | query_pipeline.connect("retriever", "prompt_builder.documents") 77 | query_pipeline.connect("prompt_builder", "llm") 78 | 79 | # 定义问题 80 | question = "张三九的基本信息是什么" 81 | 82 | # 运行流水线,并传入每个组件的初始输入 83 | results = query_pipeline.run( 84 | data={ 85 | "text_embedder": {"text": question}, 86 | "retriever": {"top_k": 2}, 87 | "prompt_builder": {"question": question}, 88 | }, 89 | include_outputs_from={"retriever","prompt_builder"} 90 | ) 91 | 92 | # 运行结果,结果是一个嵌套字典 93 | print(f"results:{results}\n") 94 | 95 | # 从嵌套字典中取出最终结果 96 | response = results["llm"]["replies"] 97 | print(f"response:{response}\n") 98 | 99 | 100 | 101 | 102 | -------------------------------------------------------------------------------- /ragTest/tools/__pycache__/pdfSplitTest_Ch.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NanGePlus/HaystackTest/c762261395178f1e77c8ff537a788fbdf21bde53/ragTest/tools/__pycache__/pdfSplitTest_Ch.cpython-311.pyc -------------------------------------------------------------------------------- /ragTest/tools/__pycache__/pdfSplitTest_En.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NanGePlus/HaystackTest/c762261395178f1e77c8ff537a788fbdf21bde53/ragTest/tools/__pycache__/pdfSplitTest_En.cpython-311.pyc -------------------------------------------------------------------------------- /ragTest/tools/pdfSplitTest_Ch.py: -------------------------------------------------------------------------------- 1 | # 功能说明:将PDF文件进行文本预处理,适用中文 2 | # 准备工作:安装相关包 3 | # pip install pdfminer.six 4 | 5 | # 导入相关库 6 | import logging 7 | from pdfminer.high_level import extract_pages 8 | from pdfminer.layout import LTTextContainer 9 | import re 10 | 11 | 12 | # 设置日志模版 13 | logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s') 14 | logger = logging.getLogger(__name__) 15 | 16 | # 当处理中文文本时,按照标点进行断句 17 | def sent_tokenize(input_string): 18 | sentences = re.split(r'(?<=[。!?;?!])', input_string) 19 | # 去掉空字符串 20 | return [sentence for sentence in sentences if sentence.strip()] 21 | 22 | 23 | # PDF文档处理函数,从PDF文件中按指定页码提取文字 24 | def extract_text_from_pdf(filename, page_numbers, min_line_length): 25 | # 申明变量 26 | paragraphs = [] 27 | buffer = '' 28 | full_text = '' 29 | # 提取全部文本并按照一行一行进行截取,并在每一行后面加上换行符 30 | for i, page_layout in enumerate(extract_pages(filename)): 31 | # 如果指定了页码范围,跳过范围外的页 32 | if page_numbers is not None and i not in page_numbers: 33 | continue 34 | for element in page_layout: 35 | if isinstance(element, LTTextContainer): 36 | full_text += element.get_text() + '\n' 37 | # full_text:将文件按照一行一行进行截取,并在每一行后面加上换行符 38 | # logger.info(f"full_text: {full_text}") 39 | 40 | 41 | # 按空行分隔,将文本重新组织成段落 42 | # lines:将full_text按照换行符进行切割,此时空行则为空(‘’) 43 | lines = full_text.split('\n') 44 | # logger.info(f"lines: {lines}") 45 | 46 | # 将lines进行循环,取出每一个片段(text)进行处理合并成段落,处理逻辑为: 47 | # (1)首先判断text的最小行的长度是否大于min_line_length设置的值 48 | # (2)如果大于min_line_length,则将该text拼接在buffer后面,如果该text不是以连字符“-”结尾,则在行前加上一个空格;如果该text是以连字符“-”结尾,则去掉连字符) 49 | # (3)如果小于min_line_length且buffer中有内容,则将其添加到 paragraphs 列表中 50 | # (4)最后,处理剩余的缓冲区内容,在遍历结束后,如果 buffer 中仍有内容,则将其添加到 paragraphs 列表中 51 | for text in lines: 52 | if len(text) >= min_line_length: 53 | buffer += (' '+text) if not text.endswith('-') else text.strip('-') 54 | elif buffer: 55 | paragraphs.append(buffer) 56 | buffer = '' 57 | if buffer: 58 | paragraphs.append(buffer) 59 | # logger.info(f"paragraphs: {paragraphs[:10]}") 60 | 61 | # 其返回值为划分段落的文本列表 62 | return paragraphs 63 | 64 | 65 | # 将PDF文档处理函数得到的文本列表再按一定粒度,部分重叠式的切割文本,使上下文更完整 66 | # chunk_size:每个文本块的目标大小(以字符为单位),默认为 800 67 | # overlap_size:块之间的重叠大小(以字符为单位),默认为 200 68 | def split_text(paragraphs, chunk_size=800, overlap_size=200): 69 | # 按指定 chunk_size 和 overlap_size 交叠割文本 70 | sentences = [s.strip() for p in paragraphs for s in sent_tokenize(p)] 71 | chunks = [] 72 | i = 0 73 | while i < len(sentences): 74 | chunk = sentences[i] 75 | overlap = '' 76 | prev_len = 0 77 | prev = i - 1 78 | # 向前计算重叠部分 79 | while prev >= 0 and len(sentences[prev])+len(overlap) <= overlap_size: 80 | overlap = sentences[prev] + ' ' + overlap 81 | prev -= 1 82 | chunk = overlap+chunk 83 | next = i + 1 84 | # 向后计算当前chunk 85 | while next < len(sentences) and len(sentences[next])+len(chunk) <= chunk_size: 86 | chunk = chunk + ' ' + sentences[next] 87 | next += 1 88 | chunks.append(chunk) 89 | i = next 90 | # logger.info(f"chunks: {chunks[0:10]}") 91 | return chunks 92 | 93 | 94 | def getParagraphs(filename, page_numbers, min_line_length): 95 | paragraphs = extract_text_from_pdf(filename, page_numbers, min_line_length) 96 | chunks = split_text(paragraphs, 800, 200) 97 | return chunks 98 | 99 | 100 | if __name__ == "__main__": 101 | # 测试 PDF文档按一定条件处理成文本数据 102 | paragraphs = getParagraphs( 103 | "../input/健康档案.pdf", 104 | # page_numbers=[2, 3], # 指定页面 105 | page_numbers=None, # 加载全部页面 106 | min_line_length=1 107 | ) 108 | # 测试前3条文本 109 | logger.info(f"只展示3段截取片段:") 110 | logger.info(f"截取的片段1: {paragraphs[0]}") 111 | logger.info(f"截取的片段2: {paragraphs[2]}") 112 | logger.info(f"截取的片段3: {paragraphs[3]}") 113 | 114 | 115 | 116 | -------------------------------------------------------------------------------- /ragTest/tools/pdfSplitTest_En.py: -------------------------------------------------------------------------------- 1 | # 功能说明:将PDF文件进行文本预处理,适用英文 2 | # 准备工作:安装相关包 3 | # pip install pdfminer.six 4 | # pip install nltk 5 | 6 | # 导入相关库 7 | import logging 8 | from pdfminer.high_level import extract_pages 9 | from pdfminer.layout import LTTextContainer 10 | import nltk 11 | 12 | 13 | # 设置日志模版 14 | logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s') 15 | logger = logging.getLogger(__name__) 16 | 17 | 18 | # 当处理英文文本时,按照该条件进行断句 19 | from nltk.tokenize import sent_tokenize 20 | # # 运行后直接下载使用 21 | # nltk.download('punkt_tab') 22 | # 也可从本地加载punk_tab 23 | nltk.data.path.append('../other/punkt_tab') 24 | 25 | 26 | # PDF文档处理函数,从PDF文件中按指定页码提取文字 27 | def extract_text_from_pdf(filename, page_numbers, min_line_length): 28 | # 申明变量 29 | paragraphs = [] 30 | buffer = '' 31 | full_text = '' 32 | # 提取全部文本并按照一行一行进行截取,并在每一行后面加上换行符 33 | for i, page_layout in enumerate(extract_pages(filename)): 34 | # 如果指定了页码范围,跳过范围外的页 35 | if page_numbers is not None and i not in page_numbers: 36 | continue 37 | for element in page_layout: 38 | if isinstance(element, LTTextContainer): 39 | full_text += element.get_text() + '\n' 40 | # full_text:将文件按照一行一行进行截取,并在每一行后面加上换行符 41 | # logger.info(f"full_text: {full_text}") 42 | 43 | # 按空行分隔,将文本重新组织成段落 44 | # lines:将full_text按照换行符进行切割,此时空行则为空(‘’) 45 | lines = full_text.split('\n') 46 | # logger.info(f"lines: {lines}") 47 | 48 | # 将lines进行循环,取出每一个片段(text)进行处理合并成段落,处理逻辑为: 49 | # (1)首先判断text的最小行的长度是否大于min_line_length设置的值 50 | # (2)如果大于min_line_length,则将该text拼接在buffer后面,如果该text不是以连字符“-”结尾,则在行前加上一个空格;如果该text是以连字符“-”结尾,则去掉连字符) 51 | # (3)如果小于min_line_length且buffer中有内容,则将其添加到 paragraphs 列表中 52 | # (4)最后,处理剩余的缓冲区内容,在遍历结束后,如果 buffer 中仍有内容,则将其添加到 paragraphs 列表中 53 | for text in lines: 54 | if len(text) >= min_line_length: 55 | buffer += (' '+text) if not text.endswith('-') else text.strip('-') 56 | elif buffer: 57 | paragraphs.append(buffer) 58 | buffer = '' 59 | if buffer: 60 | paragraphs.append(buffer) 61 | # logger.info(f"paragraphs: {paragraphs[:10]}") 62 | 63 | # 其返回值为划分段落的文本列表 64 | return paragraphs 65 | 66 | 67 | # 将PDF文档处理函数得到的文本列表再按一定粒度,部分重叠式的切割文本,使上下文更完整 68 | # chunk_size:每个文本块的目标大小(以字符为单位),默认为 800 69 | # overlap_size:块之间的重叠大小(以字符为单位),默认为 200 70 | def split_text(paragraphs, chunk_size=800, overlap_size=200): 71 | # 按指定 chunk_size 和 overlap_size 交叠割文本 72 | sentences = [s.strip() for p in paragraphs for s in sent_tokenize(p)] 73 | chunks = [] 74 | i = 0 75 | while i < len(sentences): 76 | chunk = sentences[i] 77 | overlap = '' 78 | prev_len = 0 79 | prev = i - 1 80 | # 向前计算重叠部分 81 | while prev >= 0 and len(sentences[prev])+len(overlap) <= overlap_size: 82 | overlap = sentences[prev] + ' ' + overlap 83 | prev -= 1 84 | chunk = overlap+chunk 85 | next = i + 1 86 | # 向后计算当前chunk 87 | while next < len(sentences) and len(sentences[next])+len(chunk) <= chunk_size: 88 | chunk = chunk + ' ' + sentences[next] 89 | next += 1 90 | chunks.append(chunk) 91 | i = next 92 | # logger.info(f"chunks: {chunks[0:10]}") 93 | return chunks 94 | 95 | 96 | def getParagraphs(filename, page_numbers, min_line_length): 97 | paragraphs = extract_text_from_pdf(filename, page_numbers, min_line_length) 98 | chunks = split_text(paragraphs, 800, 200) 99 | return chunks 100 | 101 | 102 | if __name__ == "__main__": 103 | # 测试 PDF文档按一定条件处理成文本数据 104 | paragraphs = getParagraphs( 105 | "../input/llama2.pdf", 106 | page_numbers=[2, 3],# 指定页面 107 | # page_numbers=None,#加载全部页面 108 | min_line_length=1 109 | ) 110 | 111 | # 测试前3条文本 112 | logger.info(f"只展示3段截取片段:") 113 | logger.info(f"截取的片段1: {paragraphs[0]}") 114 | logger.info(f"截取的片段2: {paragraphs[2]}") 115 | logger.info(f"截取的片段3: {paragraphs[3]}") 116 | 117 | 118 | 119 | -------------------------------------------------------------------------------- /ragTest/vectorSaveTest.py: -------------------------------------------------------------------------------- 1 | # 功能说明:将PDF文件进行向量计算并持久化存储到向量数据库(chroma) 2 | import logging 3 | import os 4 | from openai import OpenAI 5 | import chromadb 6 | import uuid 7 | import numpy as np 8 | from dotenv import load_dotenv 9 | from tools import pdfSplitTest_Ch 10 | from tools import pdfSplitTest_En 11 | 12 | 13 | 14 | # 设置日志模版 15 | logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s') 16 | logger = logging.getLogger(__name__) 17 | 18 | # 加载环境变量参数 19 | load_dotenv() 20 | 21 | # 设置测试文本类型 22 | TEXT_LANGUAGE = 'Chinese' #Chinese 或 English 23 | # TEXT_LANGUAGE = 'English' #Chinese 或 English 24 | 25 | # 测试的pdf文件路径 26 | INPUT_PDF = "input/健康档案.pdf" 27 | # INPUT_PDF = "input/llama2.pdf" 28 | 29 | # 指定文件中待处理的页码,全部页码则填None 30 | PAGE_NUMBERS=None 31 | # PAGE_NUMBERS=[2, 3] 32 | 33 | # 指定向量数据库chromaDB的存储位置和集合 根据自己的实际情况进行调整 34 | CHROMADB_DIRECTORY = "chromaDB" # chromaDB向量数据库的持久化路径 35 | CHROMADB_COLLECTION_NAME = "demo001" # 待查询的chromaDB向量数据库的集合名称 36 | 37 | 38 | # get_embeddings方法计算向量 39 | def get_embeddings(texts): 40 | try: 41 | # 初始化OpenAI的Embedding模型 42 | client = OpenAI( 43 | base_url=os.getenv("OPENAI_BASE_URL"), 44 | api_key=os.getenv("OPENAI_API_KEY") 45 | ) 46 | data = client.embeddings.create(input=texts,model=os.getenv("OPENAI_EMBEDDING_MODEL")).data 47 | return [x.embedding for x in data] 48 | except Exception as e: 49 | logger.info(f"生成向量时出错: {e}") 50 | return [] 51 | 52 | 53 | # 对文本按批次进行向量计算 54 | def generate_vectors(data, max_batch_size=25): 55 | results = [] 56 | for i in range(0, len(data), max_batch_size): 57 | batch = data[i:i + max_batch_size] 58 | # 调用向量生成get_embeddings方法 根据调用的API不同进行选择 59 | response = get_embeddings(batch) 60 | results.extend(response) 61 | return results 62 | 63 | 64 | # 封装向量数据库chromadb类,提供两种方法 65 | class MyVectorDBConnector: 66 | def __init__(self, collection_name, embedding_fn): 67 | # 申明使用全局变量 68 | global CHROMADB_DIRECTORY 69 | # 实例化一个chromadb对象 70 | # 设置一个文件夹进行向量数据库的持久化存储 路径为当前文件夹下chromaDB文件夹 71 | chroma_client = chromadb.PersistentClient(path=CHROMADB_DIRECTORY) 72 | # 创建一个collection数据集合 73 | # get_or_create_collection()获取一个现有的向量集合,如果该集合不存在,则创建一个新的集合 74 | self.collection = chroma_client.get_or_create_collection( 75 | name=collection_name) 76 | # embedding处理函数 77 | self.embedding_fn = embedding_fn 78 | 79 | # 添加文档到集合 80 | # 文档通常包括文本数据和其对应的向量表示,这些向量可以用于后续的搜索和相似度计算 81 | def add_documents(self, documents): 82 | self.collection.add( 83 | embeddings=self.embedding_fn(documents), # 调用函数计算出文档中文本数据对应的向量 84 | documents=documents, # 文档的文本数据 85 | ids=[str(uuid.uuid4()) for i in range(len(documents))] # 文档的唯一标识符 自动生成uuid,128位 86 | ) 87 | 88 | # 检索向量数据库,返回包含查询结果的对象或列表,这些结果包括最相似的向量及其相关信息 89 | # query:查询文本 90 | # top_n:返回与查询向量最相似的前 n 个向量 91 | def search(self, query, top_n): 92 | try: 93 | results = self.collection.query( 94 | # 计算查询文本的向量,然后将查询文本生成的向量在向量数据库中进行相似度检索 95 | query_embeddings=self.embedding_fn([query]), 96 | n_results=top_n 97 | ) 98 | return results 99 | except Exception as e: 100 | logger.info(f"检索向量数据库时出错: {e}") 101 | return [] 102 | 103 | 104 | # 封装文本预处理及灌库方法 提供外部调用 105 | def vectorStoreSave(): 106 | global TEXT_LANGUAGE, CHROMADB_COLLECTION_NAME, INPUT_PDF, PAGE_NUMBERS 107 | # 测试中文文本 108 | if TEXT_LANGUAGE == 'Chinese': 109 | # 1、获取处理后的文本数据 110 | # 演示测试对指定的全部页进行处理,其返回值为划分为段落的文本列表 111 | paragraphs = pdfSplitTest_Ch.getParagraphs( 112 | filename=INPUT_PDF, 113 | page_numbers=PAGE_NUMBERS, 114 | min_line_length=1 115 | ) 116 | # 2、将文本片段灌入向量数据库 117 | # 实例化一个向量数据库对象 118 | # 其中,传参collection_name为集合名称, embedding_fn为向量处理函数 119 | vector_db = MyVectorDBConnector(CHROMADB_COLLECTION_NAME, generate_vectors) 120 | # 向向量数据库中添加文档(文本数据、文本数据对应的向量数据) 121 | vector_db.add_documents(paragraphs) 122 | # 3、封装检索接口进行检索测试 123 | user_query = "张三九的基本信息是什么" 124 | # 将检索出的5个近似的结果 125 | search_results = vector_db.search(user_query, 5) 126 | logger.info(f"检索向量数据库的结果: {search_results}") 127 | 128 | 129 | # 测试英文文本 130 | elif TEXT_LANGUAGE == 'English': 131 | # 1、获取处理后的文本数据 132 | # 演示测试对指定的全部页进行处理,其返回值为划分为段落的文本列表 133 | paragraphs = pdfSplitTest_En.getParagraphs( 134 | filename=INPUT_PDF, 135 | page_numbers=PAGE_NUMBERS, 136 | min_line_length=1 137 | ) 138 | # 2、将文本片段灌入向量数据库 139 | # 实例化一个向量数据库对象 140 | # 其中,传参collection_name为集合名称, embedding_fn为向量处理函数 141 | vector_db = MyVectorDBConnector(CHROMADB_COLLECTION_NAME, generate_vectors) 142 | # 向向量数据库中添加文档(文本数据、文本数据对应的向量数据) 143 | vector_db.add_documents(paragraphs) 144 | # 3、封装检索接口进行检索测试 145 | user_query = "llama2安全性如何" 146 | # 将检索出的5个近似的结果 147 | search_results = vector_db.search(user_query, 5) 148 | logger.info(f"检索向量数据库的结果: {search_results}") 149 | 150 | 151 | if __name__ == "__main__": 152 | # 测试文本预处理及灌库 153 | vectorStoreSave() 154 | 155 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | haystack-ai==2.8.0 2 | chroma-haystack==1.0.0 3 | python-dotenv==1.0.1 4 | pdfminer.six 5 | nltk==3.9.1 6 | 7 | 8 | --------------------------------------------------------------------------------