├── tinygraph ├── llm │ ├── __init__.py │ ├── groq.py │ ├── zhipu.py │ └── base.py ├── embedding │ ├── __init__.py │ ├── zhipu.py │ └── base.py ├── utils.py ├── prompt.py └── graph.py ├── requirements.txt ├── images ├── 运行结果.png ├── 图数据库示例.png ├── Tiny-Graphrag流程图.png └── Learning-Algorithms节点的详细信息.png ├── .gitignore ├── README.md ├── example └── data.md ├── Tiny-Graphrag_test.ipynb └── Tiny-Graphrag_User_Guide_and_Code_Documentation.md /tinygraph/llm/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tinygraph/embedding/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | neo4j 2 | numpy 3 | tqdm 4 | zhipuai 5 | -------------------------------------------------------------------------------- /images/运行结果.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/limafang/tiny-graphrag/HEAD/images/运行结果.png -------------------------------------------------------------------------------- /images/图数据库示例.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/limafang/tiny-graphrag/HEAD/images/图数据库示例.png -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Compiled python modules 2 | *.pyc 3 | __pycache__/ 4 | data_info.txt 5 | workspace/ 6 | .vscode/ -------------------------------------------------------------------------------- /images/Tiny-Graphrag流程图.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/limafang/tiny-graphrag/HEAD/images/Tiny-Graphrag流程图.png -------------------------------------------------------------------------------- /images/Learning-Algorithms节点的详细信息.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/limafang/tiny-graphrag/HEAD/images/Learning-Algorithms节点的详细信息.png -------------------------------------------------------------------------------- /tinygraph/embedding/zhipu.py: -------------------------------------------------------------------------------- 1 | from zhipuai import ZhipuAI 2 | from typing import List 3 | from .base import BaseEmb 4 | 5 | 6 | class zhipuEmb(BaseEmb): 7 | def __init__(self, model_name: str, api_key: str, **kwargs): 8 | super().__init__(model_name=model_name, **kwargs) 9 | self.client = ZhipuAI(api_key=api_key) 10 | 11 | def get_emb(self, text: str) -> List[float]: 12 | emb = self.client.embeddings.create( 13 | model=self.model_name, 14 | input=text, 15 | ) 16 | return emb.data[0].embedding 17 | -------------------------------------------------------------------------------- /tinygraph/embedding/base.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from typing import List, Any, Optional 3 | 4 | 5 | class BaseEmb(ABC): 6 | def __init__( 7 | self, 8 | model_name: str, 9 | model_params: Optional[dict[str, Any]] = None, 10 | **kwargs: Any, 11 | ): 12 | self.model_name = model_name 13 | self.model_params = model_params or {} 14 | 15 | 
@abstractmethod 16 | def get_emb(self, input: str) -> List[float]: 17 | """Sends a text input to the embedding model and retrieves the embedding. 18 | 19 | Args: 20 | input (str): Text sent to the embedding model 21 | 22 | Returns: 23 | List[float]: The embedding vector from the model. 24 | """ 25 | pass 26 | -------------------------------------------------------------------------------- /tinygraph/llm/groq.py: -------------------------------------------------------------------------------- 1 | from groq import Groq 2 | from typing import Any, Optional 3 | from .base import BaseLLM 4 | 5 | 6 | class groqLLM(BaseLLM): 7 | """Implementation of the BaseLLM interface using zhipuai.""" 8 | 9 | def __init__( 10 | self, 11 | model_name: str, 12 | api_key: str, 13 | model_params: Optional[dict[str, Any]] = None, 14 | **kwargs: Any, 15 | ): 16 | super().__init__(model_name, model_params, **kwargs) 17 | self.client = Groq(api_key=api_key) 18 | 19 | def predict(self, input: str) -> str: 20 | """Sends a text input to the zhipuai model and retrieves a response. 21 | 22 | Args: 23 | input (str): Text sent to the zhipuai model 24 | 25 | Returns: 26 | str: The response from the zhipuai model. 27 | """ 28 | response = self.client.chat.completions.create( 29 | model=self.model_name, 30 | messages=[{"role": "user", "content": input}], 31 | ) 32 | return response.choices[0].message.content 33 | -------------------------------------------------------------------------------- /tinygraph/llm/zhipu.py: -------------------------------------------------------------------------------- 1 | from zhipuai import ZhipuAI 2 | from typing import Any, Optional 3 | from .base import BaseLLM 4 | 5 | 6 | class zhipuLLM(BaseLLM): 7 | """Implementation of the BaseLLM interface using zhipuai.""" 8 | 9 | def __init__( 10 | self, 11 | model_name: str, 12 | api_key: str, 13 | model_params: Optional[dict[str, Any]] = None, 14 | **kwargs: Any, 15 | ): 16 | super().__init__(model_name, model_params, **kwargs) 17 | self.client = ZhipuAI(api_key=api_key) 18 | 19 | def predict(self, input: str) -> str: 20 | """Sends a text input to the zhipuai model and retrieves a response. 21 | 22 | Args: 23 | input (str): Text sent to the zhipuai model 24 | 25 | Returns: 26 | str: The response from the zhipuai model. 27 | """ 28 | response = self.client.chat.completions.create( 29 | model=self.model_name, 30 | messages=[{"role": "user", "content": input}], 31 | ) 32 | return response.choices[0].message.content 33 | -------------------------------------------------------------------------------- /tinygraph/llm/base.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from typing import Any, Optional 3 | 4 | 5 | class BaseLLM(ABC): 6 | """Interface for large language models. 7 | 8 | Args: 9 | model_name (str): The name of the language model. 10 | model_params (Optional[dict[str, Any]], optional): Additional parameters passed to the model when text is sent to it. Defaults to None. 11 | **kwargs (Any): Arguments passed to the model when for the class is initialised. Defaults to None. 12 | """ 13 | 14 | def __init__( 15 | self, 16 | model_name: str, 17 | model_params: Optional[dict[str, Any]] = None, 18 | **kwargs: Any, 19 | ): 20 | self.model_name = model_name 21 | self.model_params = model_params or {} 22 | 23 | @abstractmethod 24 | def predict(self, input: str) -> str: 25 | """Sends a text input to the LLM and retrieves a response. 
26 | 27 | Args: 28 | input (str): Text sent to the LLM 29 | 30 | Returns: 31 | str: The response from the LLM. 32 | """ 33 | -------------------------------------------------------------------------------- /tinygraph/utils.py: -------------------------------------------------------------------------------- 1 | import re 2 | import numpy as np 3 | from typing import List, Tuple 4 | from hashlib import md5 5 | import json 6 | import os 7 | 8 | 9 | def get_text_inside_tag(html_string: str, tag: str): 10 | # html_string 为待解析文本,tag为查找标签 11 | pattern = f"<{tag}>(.*?)<\/{tag}>" 12 | try: 13 | result = re.findall(pattern, html_string, re.DOTALL) 14 | return result 15 | except SyntaxError as e: 16 | raise ("Json Decode Error: {error}".format(error=e)) 17 | 18 | 19 | def read_json_file(file_path): 20 | try: 21 | with open(file_path, "r", encoding="utf-8") as file: 22 | return json.load(file) 23 | except: 24 | return {} 25 | 26 | 27 | def write_json_file(data, file_path): 28 | with open(file_path, "w", encoding="utf-8") as file: 29 | json.dump(data, file, indent=4, ensure_ascii=False) 30 | 31 | 32 | def compute_mdhash_id(content, prefix: str = ""): 33 | return prefix + md5(content.encode()).hexdigest() 34 | 35 | 36 | def save_triplets_to_txt(triplets, file_path): 37 | with open(file_path, "a", encoding="utf-8") as file: 38 | file.write(f"{triplets[0]},{triplets[1]},{triplets[2]}\n") 39 | 40 | 41 | def cosine_similarity(vector1: List[float], vector2: List[float]) -> float: 42 | """ 43 | calculate cosine similarity between two vectors 44 | """ 45 | dot_product = np.dot(vector1, vector2) 46 | magnitude = np.linalg.norm(vector1) * np.linalg.norm(vector2) 47 | if not magnitude: 48 | return 0 49 | return dot_product / magnitude 50 | 51 | 52 | def create_file_if_not_exists(file_path: str): 53 | if not os.path.exists(file_path): 54 | with open(file_path, "w") as f: 55 | f.write("") 56 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Tiny-Graphrag 2 | 3 | Tiny-Graphrag 是一个简洁版本的 GraphRAG 实现,旨在提供一个最简单的 GraphRAG 系统,包含所有必要的功能。我们实现了添加文档的全部流程,以及本地查询和全局查询的功能。 4 | 5 | ## 安装 6 | 7 | Tiny-Graphrag 需要以下版本的 Neo4j 和 JDK,以及 GDS 插件: 8 | 9 | - Neo4j: 5.24.0 10 | - OpenJDK: 17.0.12 11 | - GDS: 2.10.1 12 | 13 | ## 快速开始 14 | 15 | 首先克隆仓库: 16 | 17 | ```shell 18 | git clone https://github.com/limafang/tiny-graphrag.git 19 | cd tiny-graphrag 20 | ``` 21 | 22 | 安装必要依赖: 23 | 24 | ```shell 25 | pip install -r requirements.txt 26 | ``` 27 | 28 | 接下来,你需要配置使用的 LLM 和 Embedding 服务。目前我们只支持 zhipu 的 LLM 和 Embedding 服务: 29 | 30 | ```python 31 | from tinygraph.graph import TinyGraph 32 | from tinygraph.embedding.zhipu import zhipuEmb 33 | from tinygraph.llm.zhipu import zhipuLLM 34 | 35 | emb = zhipuEmb("model name", "your key") 36 | llm = zhipuLLM("model name", "your key") 37 | graph = TinyGraph( 38 | url="your url", 39 | username="neo4j name", 40 | password="neo4j password", 41 | llm=llm, 42 | emb=emb, 43 | ) 44 | ``` 45 | 46 | 使用 TinyGraph 添加文档。目前支持所有文本格式的文件。这一步的时间可能较长,结束后,在当前目录下会生成一个 `workspace` 文件夹,包含 `community`、`chunk` 和 `doc` 信息: 47 | 48 | ```python 49 | graph.add_document("example/data.md") 50 | ``` 51 | 52 | 完成文档添加后,可以使用 TinyGraph 进行查询。TinyGraph 支持本地查询和全局查询: 53 | 54 | ```python 55 | local_res = graph.local_query("what is ML") 56 | print(local_res) 57 | global_res = graph.global_query("what is ML") 58 | print(global_res) 59 | ``` 60 | 61 | 通过以上步骤,你可以快速上手 Tiny-Graphrag,体验其强大的文档管理和查询功能。 62 | 63 | ## 代码解读 64 
| 本仓库提供了Tiny-Graphrag项目核心代码的解读文档,用于帮助新手快速理解整个项目,详情见: 65 | - Tiny-Graphrag_User_Guide_and_Code_Documentation.md 66 | 67 | ## 致谢 68 | 69 | 编写 Tiny-Graphrag 的过程中,我们参考了以下项目: 70 | 71 | [GraphRAG](https://github.com/microsoft/graphrag) 72 | 73 | [nano-graphrag](https://github.com/gusye1234/nano-graphrag) 74 | 75 | 需要说明的是,Tiny-Graphrag 是一个简化版本的 GraphRAG 实现,并不适用于生产环境,如果你需要一个更完整的 GraphRAG 实现,我们建议你使用上述项目。 76 | -------------------------------------------------------------------------------- /example/data.md: -------------------------------------------------------------------------------- 1 | # Introduction 2 | 3 | ## 1.1 Introduction 4 | 5 | Following a drizzling, we take a walk on the wet street. Feeling the gentle breeze and seeing the sunset glow, we bet the weather must be nice tomorrow. Walking to a fruit stand, we pick up a green watermelon with curly root and muffled sound; while hoping the watermelon is ripe, we also expect some good aca- demic marks this semester after all the hard work on studies. We wish readers to share the same confidence in their studies, but to begin with, let us take an informal discussion on what is machine learning . 6 | 7 | Taking a closer look at the scenario described above, we notice that it involves many experience-based predictions. For example, why would we expect beautiful weather tomorrow after observing the gentle breeze and sunset glow? We expect this beautiful weather because,from our experience,theweather on the following day is often beautiful when we experience such a scene in the present day. Also, why do we pick the watermelon with green color, curly root, and muffled sound? It is because we have eaten and enjoyed many watermelons, and those sat- isfying the above criteria are usually ripe. Similarly, our learn- ing experience tells us that hard work leads to good academic marks. We are confident in our predictions because we learned from experience and made experience-based decisions. 8 | 9 | Mitchell ( 1997 ) provides a more formal definition: ‘‘A computer program is said to learn from experience $E$ for some class of tasks $T$ and performance measure $P$ , if its performance at tasks in $T$ , as measured by $P$ , improves with experience $E$ .’’ 10 | 11 | E.g., Hand et al. ( 2001 ). 12 | 13 | While humans learn from experience, can computers do the same? The answer is ‘‘yes’’, and machine learning is what we need. Machine learning is the technique that improves system performance by learning from experience via computational methods. In computer systems, experience exists in the form of data, and the main task of machine learning is to develop learning algorithms that build models from data. By feeding the learning algorithm with experience data, we obtain a model that can make predictions (e.g., the watermelon is ripe) on new observations (e.g., an uncut watermelon). If we consider com- puter science as the subject of algorithms, then machine learn- ing is the subject of learning algorithms . 14 | 15 | In this book, we use ‘‘model’’ as a general term for the out- come learned from data. In some other literature, the term ‘‘model’’may refer to the global outcome (e.g., a decision tree), while the term ‘‘pattern’’ refers to the local outcome (e.g., a single rule). 
-------------------------------------------------------------------------------- /Tiny-Graphrag_test.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [ 8 | { 9 | "name": "stdout", 10 | "output_type": "stream", 11 | "text": [ 12 | "/home/calvin-lucas/Documents/DataWhale_Learning_Material/tiny-graphrag\n" 13 | ] 14 | } 15 | ], 16 | "source": [ 17 | "# 注意:重新运行前需要:重启整个内核\n", 18 | "import os\n", 19 | "import sys\n", 20 | "sys.path.append('.') # 添加当前目录到 Python 路径\n", 21 | "print(os.getcwd()) # 验证下当前工作路径" 22 | ] 23 | }, 24 | { 25 | "cell_type": "code", 26 | "execution_count": 2, 27 | "metadata": {}, 28 | "outputs": [], 29 | "source": [ 30 | "# 导入模块\n", 31 | "from tinygraph.graph import TinyGraph\n", 32 | "from tinygraph.embedding.zhipu import zhipuEmb\n", 33 | "from tinygraph.llm.zhipu import zhipuLLM\n", 34 | "\n", 35 | "from neo4j import GraphDatabase\n", 36 | "from dotenv import load_dotenv # 用于加载环境变量" 37 | ] 38 | }, 39 | { 40 | "cell_type": "code", 41 | "execution_count": 3, 42 | "metadata": {}, 43 | "outputs": [], 44 | "source": [ 45 | "# 配置使用的 LLM 和 Embedding 服务,现在只支持 ZhipuAI\n", 46 | "# 加载 .env文件, 从而导入api_key\n", 47 | "load_dotenv() # 加载工作目录下的 .env 文件\n", 48 | "\n", 49 | "emb = zhipuEmb(\n", 50 | " model_name=\"embedding-2\", # 嵌入模型\n", 51 | " api_key=os.getenv('API_KEY')\n", 52 | ")\n", 53 | "llm = zhipuLLM(\n", 54 | " model_name=\"glm-3-turbo\", # LLM 模型\n", 55 | " api_key=os.getenv('API_KEY')\n", 56 | ")\n", 57 | "graph = TinyGraph(\n", 58 | " url=\"neo4j://localhost:7687\",\n", 59 | " username=\"neo4j\",\n", 60 | " password=\"neo4j-passwordTGR\", # 初次登陆的默认密码为neo4j,此后需修改再使用\n", 61 | " llm=llm,\n", 62 | " emb=emb,\n", 63 | ")\n" 64 | ] 65 | }, 66 | { 67 | "cell_type": "code", 68 | "execution_count": 4, 69 | "metadata": {}, 70 | "outputs": [ 71 | { 72 | "name": "stdout", 73 | "output_type": "stream", 74 | "text": [ 75 | "Document 'example/data.md' has already been loaded, skipping import process.\n" 76 | ] 77 | } 78 | ], 79 | "source": [ 80 | "# 使用 TinyGraph 添加文档。目前支持所有文本格式的文件。这一步的时间可能较长;\n", 81 | "# 结束后,在当前目录下会生成一个 `workspace` 文件夹,包含 `community`、`chunk` 和 `doc` 信息\n", 82 | "graph.add_document(\"example/data.md\")" 83 | ] 84 | }, 85 | { 86 | "cell_type": "code", 87 | "execution_count": 5, 88 | "metadata": {}, 89 | "outputs": [ 90 | { 91 | "name": "stdout", 92 | "output_type": "stream", 93 | "text": [ 94 | "数据库连接正常,节点数量: 29\n" 95 | ] 96 | } 97 | ], 98 | "source": [ 99 | "# 再次验证数据库连接\n", 100 | "with graph.driver.session() as session:\n", 101 | " result = session.run(\"MATCH (n) RETURN count(n) as count\")\n", 102 | " count = result.single()[\"count\"]\n", 103 | " print(f\"数据库连接正常,节点数量: {count}\")" 104 | ] 105 | }, 106 | { 107 | "cell_type": "code", 108 | "execution_count": 6, 109 | "metadata": {}, 110 | "outputs": [ 111 | { 112 | "name": "stdout", 113 | "output_type": "stream", 114 | "text": [ 115 | "\n", 116 | "本地查询结果:\n", 117 | "The term \"dl\" is not explicitly defined in the provided context. However, based on the context's focus on machine learning, \"dl\" might commonly be interpreted as an abbreviation for \"deep learning,\" which is a subset of machine learning that involves neural networks with many layers (hence \"deep\"). 
Deep learning has become a prominent field, particularly in the realm of artificial intelligence, where it is used to recognize patterns and make predictions from large datasets.\n", 118 | "\n", 119 | "If \"dl\" refers to something else in the context of the user query, there would be no information to discern its meaning without further clarification or additional context.\n" 120 | ] 121 | } 122 | ], 123 | "source": [ 124 | "# 执行局部查询测试\n", 125 | "local_res = graph.local_query(\"what is dl?\")\n", 126 | "print(\"\\n本地查询结果:\")\n", 127 | "print(local_res)\n" 128 | ] 129 | }, 130 | { 131 | "cell_type": "code", 132 | "execution_count": 7, 133 | "metadata": {}, 134 | "outputs": [ 135 | { 136 | "name": "stdout", 137 | "output_type": "stream", 138 | "text": [ 139 | "\n", 140 | "全局查询结果:\n", 141 | "The term 'dl' is not explicitly mentioned in the provided data tables. Therefore, I don't know what 'dl' refers to in the context of the user's question. If 'dl' stands for 'Deep Learning,' it is a subset of machine learning that uses neural networks with many layers for feature extraction and modeling. However, this context is not provided in the data tables.\n" 142 | ] 143 | } 144 | ], 145 | "source": [ 146 | "\n", 147 | "# 执行全局查询测试\n", 148 | "global_res = graph.global_query(\"what is dl?\")\n", 149 | "print(\"\\n全局查询结果:\")\n", 150 | "print(global_res)" 151 | ] 152 | } 153 | ], 154 | "metadata": { 155 | "kernelspec": { 156 | "display_name": "TinyGraphRAG_2025-04-08", 157 | "language": "python", 158 | "name": "python3" 159 | }, 160 | "language_info": { 161 | "codemirror_mode": { 162 | "name": "ipython", 163 | "version": 3 164 | }, 165 | "file_extension": ".py", 166 | "mimetype": "text/x-python", 167 | "name": "python", 168 | "nbconvert_exporter": "python", 169 | "pygments_lexer": "ipython3", 170 | "version": "3.10.16" 171 | } 172 | }, 173 | "nbformat": 4, 174 | "nbformat_minor": 2 175 | } 176 | -------------------------------------------------------------------------------- /tinygraph/prompt.py: -------------------------------------------------------------------------------- 1 | GEN_NODES = """ 2 | ## Goal 3 | Please identify and extract triplet information from the provided article, focusing only on entities and relationships related to significant knowledge points. 4 | Each triplet should be in the form of (Subject, Predicate, Object). 5 | Follow these guidelines: 6 | 7 | 1. **Subject:** Concepts in Bayesian Optimization 8 | 2. **Predicate:** The action or relationship that links the subject to the object. 9 | 3. **Object:** Concepts in Bayesian Optimization that is affected by or related to the action of the subject. 10 | 11 | ## Example 12 | For the sentence "Gaussian Processes are used to model the objective function in Bayesian Optimization" the triplet would be: 13 | 14 | Gaussian Processesare used to model the objective function inBayesian Optimization 15 | 16 | For the sentence "John read a book on the weekend," which is not related to any knowledge points, no triplet should be extracted. 17 | 18 | ## Instructions 19 | 1. Read through the article carefully. 20 | 2. Think step by step. Try to find some useful knowledge points from the article. You need to reorganize the content of the sentence into corresponding knowledge points. 21 | 3. Identify key sentences that contain relevant triplet information related to significant knowledge points. 22 | 4. Extract and format the triplets as per the given example, excluding any information that is not relevant to significant knowledge points. 
23 | 24 | ## Output Format 25 | For each identified triplet, provide: 26 | [Entity]The action or relationshipThe entity 27 | 28 | ## Article 29 | 30 | {text} 31 | 32 | ## Your response 33 | """ 34 | 35 | GET_ENTITY = """ 36 | ## Goal 37 | 38 | You are an experienced machine learning teacher. 39 | You need to identify the key concepts related to machine learning that the article requires students to master. For each concept, provide a brief description that explains its relevance and importance in the context of the article. 40 | 41 | ## Example 42 | 43 | article: 44 | "In the latest study, we explored the potential of using machine learning algorithms for disease prediction. We used support vector machines (SVM) and random forest algorithms to analyze medical data. The results showed that these models performed well in predicting disease risk through feature selection and cross-validation. In particular, the random forest model showed better performance in dealing with overfitting problems. In addition, we discussed the application of deep learning in medical image analysis." 45 | 46 | response: 47 | 48 | Support Vector Machine (SVM) 49 | A supervised learning model used for classification and regression tasks, particularly effective in high-dimensional spaces. 50 | 51 | 52 | Random Forest Algorithm 53 | An ensemble learning method that builds multiple decision trees and merges them together to get a more accurate and stable prediction, often used to reduce overfitting. 54 | 55 | 56 | Feature Selection 57 | The process of selecting a subset of relevant features for use in model construction, crucial for improving model performance and reducing complexity. 58 | 59 | 60 | Overfitting 61 | A common issue where a model learns the details and noise in the training data to the extent that it negatively impacts the model's performance on new data. 62 | 63 | 64 | Deep Learning 65 | A subset of machine learning that uses neural networks with many layers to model complex patterns in large datasets, often applied in image and speech recognition tasks. 66 | 67 | 68 | ## Format 69 | 70 | Wrap each concept in the HTML tag , and include the name of the concept in the tag and its description in the tag. 71 | 72 | ## Article 73 | 74 | {text} 75 | 76 | ## Your response 77 | """ 78 | 79 | 80 | ENTITY_DISAMBIGUATION = """ 81 | ## Goal 82 | Given multiple entities with the same name, determine if they can be merged into a single entity. If merging is possible, provide the transformation from entity id to entity id. 83 | 84 | ## Guidelines 85 | 1. **Entities:** A list of entities with the same name. 86 | 2. **Merge:** Determine if the entities can be merged into a single entity. 87 | 3. **Transformation:** If merging is possible, provide the transformation from entity id to entity id. 88 | 89 | ## Example 90 | 1. Entities: 91 | [ 92 | {"name": "Entity A", "entity id": "entity-1"}, 93 | {"name": "Entity A", "entity id": "entity-2"}, 94 | {"name": "Entity A", "entity id": "entity-3"} 95 | ] 96 | 97 | Your response should be: 98 | 99 | {"entity-2": "entity-1", "entity-3": "entity-1"} 100 | 101 | 102 | 2. Entities: 103 | [ 104 | {"name": "Entity B", "entity id": "entity-4"}, 105 | {"name": "Entity C", "entity id": "entity-5"}, 106 | {"name": "Entity B", "entity id": "entity-6"} 107 | ] 108 | 109 | Your response should be: 110 | 111 | None 112 | 113 | ## Output Format 114 | Provide the following information: 115 | - Transformation: A dictionary mapping entity ids to the final entity id after merging. 
116 | 117 | ## Given Entities 118 | {entities} 119 | 120 | ## Your response 121 | """ 122 | 123 | GET_TRIPLETS = """ 124 | ## Goal 125 | Identify and extract all the relationships between the given concepts from the provided text. 126 | Identify as many relationships between the concepts as possible. 127 | The relationship in the triple should accurately reflect the interaction or connection between the two concepts. 128 | 129 | ## Guidelines: 130 | 1. **Subject:** The first entity from the given entities. 131 | 2. **Predicate:** The action or relationship linking the subject to the object. 132 | 3. **Object:** The second entity from the given entities. 133 | 134 | ## Example: 135 | 1. Article : 136 | "Gaussian Processes are used to model the objective function in Bayesian Optimization" 137 | Given entities: 138 | [{{"name": "Gaussian Processes", "entity id": "entity-1"}}, {{"name": "Bayesian Optimization", "entity id": "entity-2"}}] 139 | Output: 140 | Gaussian Processesentity-1are used to model the objective function inBayesian Optimizationentity-2 141 | 142 | 2. Article : 143 | "Hydrogen is a colorless, odorless, non-toxic gas and is the lightest and most abundant element in the universe. Oxygen is a gas that supports combustion and is widely present in the Earth's atmosphere. Water is a compound made up of hydrogen and oxygen, with the chemical formula H2O." 144 | Given entities: 145 | [{{"name": "Hydrogen", "entity id": "entity-3"}}, {{"name": "Oxygen", "entity id": "entity-4"}}, {{"name": "Water", "entity id": "entity-5"}}] 146 | Output: 147 | Hydrogenentity-3is a component ofWaterentity-5 148 | 3. Article : 149 | "John read a book on the weekend" 150 | Given entities: 151 | [] 152 | Output: 153 | None 154 | 155 | ## Format: 156 | For each identified triplet, provide: 157 | **the entity should just from "Given Entities"** 158 | [Entity][Entity ID][The action or relationship][Entity][Entity ID] 159 | 160 | ## Given Entities: 161 | {entity} 162 | 163 | ### Article: 164 | {text} 165 | 166 | ## Additional Instructions: 167 | - Before giving your response, you should analyze and think about it sentence by sentence. 168 | - Both the subject and object must be selected from the given entities and cannot change their content. 169 | - If no relevant triplet involving both entities is found, no triplet should be extracted. 170 | - If there are similar concepts, please rewrite them into a form that suits our requirements. 171 | 172 | ## Your response: 173 | """ 174 | 175 | TEST_PROMPT = """ 176 | ## Foundation of students 177 | {state} 178 | ## Gole 179 | You will help students solve question through multiple rounds of dialogue. 180 | Please follow the steps below to help students solve the question: 181 | 1. Explain the basic knowledge and principles behind the question and make sure the other party understands these basic concepts. 182 | 2. Don't give a complete answer directly, but guide the student to think about the key steps of the question. 183 | 3. After guiding the student to think, let them try to solve the question by themselves. Give appropriate hints and feedback to help them correct their mistakes and further improve their solutions. 184 | 4. Return to TERMINATE after solving the problem 185 | """ 186 | 187 | GEN_COMMUNITY_REPORT = """ 188 | ## Role 189 | You are an AI assistant that helps a human analyst to perform general information discovery. 
190 | Information discovery is the process of identifying and assessing relevant information associated with certain entities (e.g., organizations and individuals) within a network. 191 | 192 | ## Goal 193 | Write a comprehensive report of a community. 194 | Given a list of entities that belong to the community as well as their relationships and optional associated claims. The report will be used to inform decision-makers about information associated with the community and their potential impact. 195 | The content of this report includes an overview of the community's key entities, their legal compliance, technical capabilities, reputation, and noteworthy claims. 196 | 197 | ## Report Structure 198 | 199 | The report should include the following sections: 200 | 201 | - TITLE: community's name that represents its key entities - title should be short but specific. When possible, include representative named entities in the title. 202 | - SUMMARY: An executive summary of the community's overall structure, how its entities are related to each other, and significant information associated with its entities. 203 | - DETAILED FINDINGS: A list of 5-10 key insights about the community. Each insight should have a short summary followed by multiple paragraphs of explanatory text grounded according to the grounding rules below. Be comprehensive. 204 | 205 | Return output as a well-formed JSON-formatted string with the following format: 206 | {{ 207 | "title": , 208 | "summary": , 209 | "findings": [ 210 | {{ 211 | "summary":, 212 | "explanation": 213 | }}, 214 | {{ 215 | "summary":, 216 | "explanation": 217 | }} 218 | ... 219 | ] 220 | }} 221 | 222 | ## Grounding Rules 223 | Do not include information where the supporting evidence for it is not provided. 224 | 225 | ## Example Input 226 | ----------- 227 | Text: 228 | ``` 229 | Entities: 230 | ```csv 231 | entity,description 232 | VERDANT OASIS PLAZA,Verdant Oasis Plaza is the location of the Unity March 233 | HARMONY ASSEMBLY,Harmony Assembly is an organization that is holding a march at Verdant Oasis Plaza 234 | ``` 235 | Relationships: 236 | ```csv 237 | source,target,description 238 | VERDANT OASIS PLAZA,UNITY MARCH,Verdant Oasis Plaza is the location of the Unity March 239 | VERDANT OASIS PLAZA,HARMONY ASSEMBLY,Harmony Assembly is holding a march at Verdant Oasis Plaza 240 | VERDANT OASIS PLAZA,UNITY MARCH,The Unity March is taking place at Verdant Oasis Plaza 241 | VERDANT OASIS PLAZA,TRIBUNE SPOTLIGHT,Tribune Spotlight is reporting on the Unity march taking place at Verdant Oasis Plaza 242 | VERDANT OASIS PLAZA,BAILEY ASADI,Bailey Asadi is speaking at Verdant Oasis Plaza about the march 243 | HARMONY ASSEMBLY,UNITY MARCH,Harmony Assembly is organizing the Unity March 244 | ``` 245 | ``` 246 | Output: 247 | {{ 248 | "title": "Verdant Oasis Plaza and Unity March", 249 | "summary": "The community revolves around the Verdant Oasis Plaza, which is the location of the Unity March. The plaza has relationships with the Harmony Assembly, Unity March, and Tribune Spotlight, all of which are associated with the march event.", 250 | "findings": [ 251 | {{ 252 | "summary": "Verdant Oasis Plaza as the central location", 253 | "explanation": "Verdant Oasis Plaza is the central entity in this community, serving as the location for the Unity March. This plaza is the common link between all other entities, suggesting its significance in the community. 
The plaza's association with the march could potentially lead to issues such as public disorder or conflict, depending on the nature of the march and the reactions it provokes." 254 | }}, 255 | {{ 256 | "summary": "Harmony Assembly's role in the community", 257 | "explanation": "Harmony Assembly is another key entity in this community, being the organizer of the march at Verdant Oasis Plaza. The nature of Harmony Assembly and its march could be a potential source of threat, depending on their objectives and the reactions they provoke. The relationship between Harmony Assembly and the plaza is crucial in understanding the dynamics of this community." 258 | }}, 259 | {{ 260 | "summary": "Unity March as a significant event", 261 | "explanation": "The Unity March is a significant event taking place at Verdant Oasis Plaza. This event is a key factor in the community's dynamics and could be a potential source of threat, depending on the nature of the march and the reactions it provokes. The relationship between the march and the plaza is crucial in understanding the dynamics of this community." 262 | }}, 263 | {{ 264 | "summary": "Role of Tribune Spotlight", 265 | "explanation": "Tribune Spotlight is reporting on the Unity March taking place in Verdant Oasis Plaza. This suggests that the event has attracted media attention, which could amplify its impact on the community. The role of Tribune Spotlight could be significant in shaping public perception of the event and the entities involved." 266 | }} 267 | ] 268 | }} 269 | 270 | ## Real Data 271 | Use the following text for your answer. Do not make anything up in your answer. 272 | 273 | Text: 274 | ``` 275 | {input_text} 276 | ``` 277 | 278 | The report should include the following sections: 279 | 280 | - TITLE: community's name that represents its key entities - title should be short but specific. When possible, include representative named entities in the title. 281 | - SUMMARY: An executive summary of the community's overall structure, how its entities are related to each other, and significant information associated with its entities. 282 | - DETAILED FINDINGS: A list of 5-10 key insights about the community. Each insight should have a short summary followed by multiple paragraphs of explanatory text grounded according to the grounding rules below. Be comprehensive. 283 | 284 | Return output as a well-formed JSON-formatted string with the following format: 285 | {{ 286 | "title": , 287 | "summary": , 288 | "rating": , 289 | "rating_explanation": , 290 | "findings": [ 291 | {{ 292 | "summary":, 293 | "explanation": 294 | }}, 295 | {{ 296 | "summary":, 297 | "explanation": 298 | }} 299 | ... 300 | ] 301 | }} 302 | 303 | ## Grounding Rules 304 | Do not include information where the supporting evidence for it is not provided. 305 | 306 | Output: 307 | """ 308 | 309 | GLOBAL_MAP_POINTS = """ 310 | You are a helpful assistant responding to questions about data in the tables provided. 311 | 312 | 313 | ---Goal--- 314 | 315 | Generate a response consisting of a list of key points that responds to the user's question, summarizing all relevant information in the input data tables. 316 | 317 | You should use the data provided in the data tables below as the primary context for generating the response. 318 | If you don't know the answer or if the input data tables do not contain sufficient information to provide an answer, just say so. Do not make anything up. 
319 | 320 | Each key point in the response should have the following element: 321 | - Description: A comprehensive description of the point. 322 | - Importance Score: An integer score between 0-100 that indicates how important the point is in answering the user's question. An 'I don't know' type of response should have a score of 0. 323 | 324 | The response should be HTML formatted as follows: 325 | 326 | 327 | "Description of point 1..."score_value 328 | "Description of point 2..."score_value 329 | 330 | 331 | The response shall preserve the original meaning and use of modal verbs such as "shall", "may" or "will". 332 | Do not include information where the supporting evidence for it is not provided. 333 | 334 | 335 | ---Data tables--- 336 | 337 | {context_data} 338 | 339 | ---User query--- 340 | 341 | {query} 342 | 343 | ---Goal--- 344 | 345 | Generate a response consisting of a list of key points that responds to the user's question, summarizing all relevant information in the input data tables. 346 | 347 | You should use the data provided in the data tables below as the primary context for generating the response. 348 | If you don't know the answer or if the input data tables do not contain sufficient information to provide an answer, just say so. Do not make anything up. 349 | 350 | Each key point in the response should have the following element: 351 | - Description: A comprehensive description of the point. 352 | - Importance Score: An integer score between 0-100 that indicates how important the point is in answering the user's question. An 'I don't know' type of response should have a score of 0. 353 | 354 | The response shall preserve the original meaning and use of modal verbs such as "shall", "may" or "will". 355 | Do not include information where the supporting evidence for it is not provided. 356 | 357 | The response should be HTML formatted as follows: 358 | 359 | "Description of point 1..."score_value 360 | "Description of point 2..."score_value 361 | 362 | 363 | """ 364 | 365 | LOCAL_QUERY = """ 366 | ## User Query 367 | {query} 368 | ## Context 369 | {context} 370 | ## Task 371 | Based on given context, please provide a response to the user query. 372 | ## Your Response 373 | """ 374 | 375 | GLOBAL_QUERY = """ 376 | ## User Query 377 | {query} 378 | ## Context 379 | {context} 380 | ## Task 381 | Based on given context, please provide a response to the user query. 382 | ## Your Response 383 | """ 384 | -------------------------------------------------------------------------------- /Tiny-Graphrag_User_Guide_and_Code_Documentation.md: -------------------------------------------------------------------------------- 1 | # Tiny-Graphrag使用指南与代码解读 2 | >此README包括两部分:1.引言;2.正文 3 | ## 引言: 4 | - Tiny-Graphrag是一个基于Graphrag的简化版本,包含了Graphrag的核心功能: 1.知识图谱构建;2.图检索优化;3.生成增强。创建Graphrag项目的目的是帮助大家理解Graphrag的原理并提供Demo来实现。 5 | - 本项目中信息传输的总体流程如下所示: 6 | 7 |
8 | ![Tiny-Graphrag流程图](images/Tiny-Graphrag流程图.png)
10 | 11 | - 用通俗语言来描述就是:**输入问题后,通过图结构运算层的计算,将得到的上下文交给一个“聪明的学生”(即大语言模型 LLM),让它基于这些上下文进行推理和回答问题。** 12 | ## 正文: 13 | >正文包括三部分:1.Graphrag简要介绍;2.Tiny-Graphrag 使用方法;3.Tiny-Graphrag代码解读 14 | ### Graphrag简要介绍 15 | --- 16 | - 是什么? 17 | - 基于知识图谱的检索增强生成技术,通过显式建模实体关系提升rag的多跳推理能力。 18 | - 提出时能够解决什么问题? 19 | - 传统rag的局限:单跳检索(无法回答"特朗普和拜登的母校有何共同点?"类问题) 语义相似度≠逻辑相关性 20 | - Graphrag的改进:通过图路径实现多跳推理(如"特朗普→宾大→法学←拜登") 21 | - 以微软Graphrag为例,其核心功能如下表所示: 22 | 23 | | 模块 | 模块描述 | 24 | |:------|:-----| 25 | | 知识图谱构建 | 核心功能之一,将文本或结构化数据转化为图结构(节点和边)。 | 26 | | 图检索优化 | 基于图谱的拓扑关系(如多跳路径、子图匹配)改进传统向量检索。 | 27 | | 生成增强 | 利用检索到的图结构(如子图、路径)增强大模型的生成逻辑性和准确性。 | 28 | 29 | ### Tiny-Graphrag 使用方法 30 | --- 31 | - 本项目给出了Tiny-Graphrag使用方法,初学者可以先直接跑通这个程序,然后再继续了解具体原理。这样的学习曲线更缓和,能有效防止卡在代码理解层面而对代码的整体作用缺少理解,难以应用。下面给出Tiny-Graphrag使用的具体方法。 32 | - Tiny-Graphrag 使用方法 33 | - 个人主机环境:ubuntu24.04 34 | - 代码下载 35 | ```bash 36 | git clone https://github.com/limafang/tiny-graphrag.git 37 | cd tiny-graphrag 38 | ``` 39 | - 主机环境配置 40 | 1. 安装:`neo4j --version 5.26.5`,可使用wget配合pip来完成 41 | 2. 安装插件:`GDS`。 可从github上找到与`neo4j 5.26.5`**兼容**的`GDS 2.13.2`,将这个.jar文件放到neo4j的插件文件夹里。 42 | 3. 安装:`OpenJDK-21`。命令行`sudo apt install openjdk-21-jre-headless` 43 | - 使用conda创建虚拟环境(虚拟环境创建此处仅作参考,学习者可以使用自己常用的开发环境来运行) 44 | ```bash 45 | conda create --name Tiny-Graphrag_2025-04 python=3.10 -y # 虚拟环境创建 46 | conda activate TinyEval_2025-04 # 命令行激活虚拟环境 47 | conda install pip -y # 在conda环境内安装pip包管理工具 48 | ``` 49 | - 环境中安装requirements.txt中的依赖,对应命令行为`pip install -r requirements.txt` 50 | - 先运行Neo4j,命令行为:`sudo neo4j start`,然后在浏览器中登陆到neo4j,默认网址为:http://localhost:7474 51 | - 运行`Tiny-Graphrag_test.ipynb` 52 | - 注意每次全部重新运行都需要重启内核,否则在本地查询等步骤会报错 53 | - 使用本电脑首次运行完成耗时15分钟 54 | - 对于非首次运行的打开过程为: 55 | 1. 激活当前项目的对应虚拟环境 56 | 2. 打开neo4j 57 | 3. 运行`Tiny-Graphrag_test.ipynb` 58 | - 其他要求: 59 | - 本项目以zhipuAI作为调用的大模型,需要调用其API,所以需要注册智谱API的帐号,从而获得API 60 | ### Tiny-Graphrag代码解读 61 | --- 62 | >下面将按照Graphrag的三个核心功能来介绍本项目的代码: 63 | #### 1. 
知识图谱构建 64 | - 运行代码前需要启动neo4j客户端。 65 | - 模块导入,并添加API,其中API可以手动添加,也可以通过将API设置为环境变量的方法添加,本项目采用后者。 66 | ```python 67 | # 导入模块 68 | import os 69 | import sys 70 | 71 | from Tiny-Graph.graph import Tiny-Graph 72 | from Tiny-Graph.embedding.zhipu import zhipuEmb 73 | from Tiny-Graph.llm.zhipu import zhipuLLM 74 | 75 | from neo4j import GraphDatabase 76 | from dotenv import load_dotenv # 用于加载环境变量 77 | 78 | sys.path.append('.') # 添加当前目录到 Python 路径 79 | print(os.getcwd()) # 验证下当前工作路径 80 | 81 | # 加载 .env文件, 从而导入api_key 82 | load_dotenv() # 加载工作目录下的 .env 文件 83 | ``` 84 | ##### 1.1 emb、llm类的实例化 85 | - 将zhipuAi的嵌入模型(zhipuEmb)、zhipuLLM以及Tiny-Graph类分别实例化: 86 | - llm以及模型的embedding服务,依次完成实例化。其中的llm以及embedding可以根据自己的需要再调整,此处作为示例用,两者分别传入了嵌入模型 / LLM模型的名称以及API_KEY 87 | - 对应代码 88 | ```python 89 | emb = zhipuEmb( 90 | model_name="embedding-2", # 嵌入模型 91 | api_key=os.getenv('API_KEY') 92 | ) 93 | llm = zhipuLLM( 94 | model_name="glm-3-turbo", # LLM 95 | api_key=os.getenv('API_KEY') 96 | ) 97 | ``` 98 | - 以`zhipuEmb`为例,分析下类的继承关系。此处的`zhipuEmb`类是继承于`BaseEmb`类,在类实例化的过程(此处为`emb = zhipuEmb`)中会先调用`__init__`方法; 99 | ```python 100 | class zhipuEmb(BaseEmb): 101 | def __init__(self, model_name: str, api_key: str, **kwargs): 102 | super().__init__(model_name=model_name, **kwargs) 103 | self.client = ZhipuAI(api_key=api_key) # 创建 ZhipuAI 客户端,self.client 是zhipuEmb类的一个属性 104 | 105 | def get_emb(self, text: str) -> List[float]: 106 | emb = self.client.embeddings.create( 107 | model=self.model_name, 108 | input=text, 109 | ) 110 | return emb.data[0].embedding 111 | ``` 112 | - 为了调用`zhipuEmb`继承的`BaseEmb`类的属性,使用`super().__init__(model_name=model_name, **kwargs)`将模型名称传入`zhipuEmb`继承的`BaseEmb`类; 113 | - 而`BaseEmb`类继承自`ABC`类(`Abstract Base Class`,抽象基类) 114 | - `zhipuLLM`的实例化过程与此类似。 115 | ##### 1.2 Tiny-Graph类的实例化 116 | - 传入了neo4j的默认网址、用户名、密码、llm、emb。 117 | - 对应代码 118 | ```python 119 | graph = Tiny-Graph( 120 | url="neo4j://localhost:7687", 121 | username="neo4j", 122 | password="neo4j-passwordTGR", 123 | llm=llm, 124 | emb=emb, 125 | ) 126 | ``` 127 | - 实例化过程自动调用的`__init__`方法完成了创建Neo4j数据库驱动、设置语言模型、设置嵌入模型、设置工作目录等工作,详细注释见下方代码: 128 | ```python 129 | class Tiny-Graph: 130 | """ 131 | 一个用于处理图数据库和语言模型的类。 132 | 133 | 该类通过连接到Neo4j图数据库,并使用语言模型(LLM)和嵌入模型(Embedding)来处理文档和图数据。 134 | 它还管理一个工作目录,用于存储文档、文档块和社区数据。 135 | """ 136 | 137 | def __init__( 138 | self, 139 | url: str, # Neo4j数据库的URL 140 | username: str, # Neo4j数据库的用户名 141 | password: str, # Neo4j数据库的密码 142 | llm: BaseLLM, # 语言模型(LLM)实例 143 | emb: BaseLLM, # 嵌入模型(Embedding)实例 144 | working_dir: str = "workspace", # 工作目录,默认为"workspace" 145 | ): 146 | """ 147 | 初始化Tiny-Graph类。 148 | 149 | 参数: 150 | - url: Neo4j数据库的URL 151 | - username: Neo4j数据库的用户名 152 | - password: Neo4j数据库的密码 153 | - llm: 语言模型(LLM)实例 154 | - emb: 嵌入模型(Embedding)实例 155 | - working_dir: 工作目录,默认为"workspace" 156 | """ 157 | self.driver = driver = GraphDatabase.driver( 158 | url, auth=(username, password) 159 | ) # 创建Neo4j数据库驱动 160 | self.llm = llm # 设置语言模型 161 | self.embedding = emb # 设置嵌入模型 162 | self.working_dir = working_dir # 设置工作目录 163 | os.makedirs(self.working_dir, exist_ok=True) # 创建工作目录(如果不存在) 164 | 165 | # 定义文档、文档块和社区数据的文件路径 166 | self.doc_path = os.path.join(working_dir, "doc.txt") 167 | self.chunk_path = os.path.join(working_dir, "chunk.json") 168 | self.community_path = os.path.join(working_dir, "community.json") 169 | 170 | # 创建文件(如果不存在) 171 | create_file_if_not_exists(self.doc_path) 172 | create_file_if_not_exists(self.chunk_path) 173 | create_file_if_not_exists(self.community_path) 174 | 175 | # 加载已加载的文档 176 | 
self.loaded_documents = self.get_loaded_documents() 177 | ``` 178 | ##### 1.3 添加文档到图数据库 179 | - 使用Tiny-Graph类下的`add_document`方法来将指定路径的文档添加到图数据库中`graph.add_document("example/data.md")`。该方法会自动处理文档的分块和嵌入生成,并将结果存储在图数据库中。这里的路径是相对路径,指向当前工作目录下的example/data.md文件。其主要功能如下: 180 | ###### 1.3.1 检查文档是否已经分块; 181 | - 对应代码 182 | ```python 183 | # ================ Check if the document has been loaded ================ 184 | if filepath in self.get_loaded_documents(): 185 | print( 186 | f"Document '{filepath}' has already been loaded, skipping import process." 187 | ) 188 | return # 在这段代码中,return 的作用是 终止函数的执行,并返回到调用该函数的地方 189 | ``` 190 | - 功能:检查指定文档是否已经被加载过,避免重复处理。 191 | - 实现步骤: 192 | 1. 调用`self.get_loaded_documents()`方法,读取已加载文档的缓存文件(doc.txt),返回一个包含已加载文档路径的集合 193 | 2. 检查文档路径是否已经存在,如果已经存在,则打印提示信息 194 | 3. 中止函数的执行,return在此段代码中的作用是中止函数的执行,并返回到调用该函数的地方。 195 | 2. 将文档分割成块(此处就是分割为json格式的文件); 196 | 197 | 198 | ###### 1.3.2. 将文档分割成块(此处就是分割为json格式的文件) 199 | - 对应代码 200 | ```python 201 | # ================ Chunking ================ 202 | chunks = self.split_text(filepath) 203 | existing_chunks = read_json_file(self.chunk_path) 204 | 205 | # Filter out chunks that are already in storage 206 | new_chunks = {k: v for k, v in chunks.items() if k not in existing_chunks} 207 | 208 | if not new_chunks: 209 | print("All chunks are already in the storage.") 210 | return 211 | 212 | # Merge new chunks with existing chunks 213 | all_chunks = {**existing_chunks, **new_chunks} 214 | write_json_file(all_chunks, self.chunk_path) 215 | print(f"Document '{filepath}' has been chunked.") 216 | ``` 217 | - 功能:将文档分割成多个小块(chunks),并将这些分块存储到一个JSON文件中,避免重复存储已经存在的分块。 218 | - 实现步骤: 219 | 1. 分割文档:调用`chunks = self.split_text(filepath)`方法,将文档分割成多个小块,并且相邻小块之间有一定重叠,返回值chunks是一个字典,键是分块的唯一ID,值是分块的内容。 220 | 2. 读取已经存储的分块:`existing_chunks = read_json_file(self.chunk_path)`,调用该方法从chunk.json中读取已经存储的分块,返回值existing_chunks是一个字典,包含所有已经存储的分块。 221 | 3. 过渡新分块:`new_chunks = {k: v for k, v in chunks.items() if k not in existing_chunks}`,使用字典推导式过滤出新的分块,返回值new_chunks是一个字典,包含所有新的分块。 222 | 4. 检查是否有新的分块,如果new_chunks为空,也就是没有新的分块需要存储的话,打印提示信息并终止函数执行。 223 | ```python 224 | if not new_chunks: 225 | print("All chunks are already in the storage.") 226 | return 227 | ``` 228 | 5. 合并分块:`all_chunks = {**existing_chunks, **new_chunks}`,使用字典包语法将existing_chunks和new_chunks合并为一个新的字典。 229 | 6. 写入JSON文件:`write_json_file(all_chunks, self.chunk_path)`,将合并后的分块写入chunk.json文件。 230 | 7. 打印提示信息。 231 | 232 | ###### 1.3.3 从块中提取实体(entities)和三元组(triplets); 233 | - 对应代码 234 | ```python 235 | # ================ Entity Extraction ================ 236 | all_entities = []# 用于存储从文档块中提取的实体 237 | all_triplets = []# 用于存储从所有文档中提取的三元组 238 | 239 | # 遍历文档块,每个分块有一个唯一的chunk_id和对应的内容chunk_content 240 | for chunk_id, chunk_content in tqdm( 241 | new_chunks.items(), desc=f"Processing '{filepath}'" 242 | ): 243 | try: 244 | # 从当前分块中提取实体,每个实体包含名称、描述、关联的分块ID以及唯一的实体ID 245 | entities = self.get_entity(chunk_content, chunk_id=chunk_id) 246 | all_entities.extend(entities) 247 | # 从当前分块中提取三元组,每个三元组由主语(subject)、谓语(predicate)和宾语(object)组成,表示实体之间的关系 248 | triplets = self.get_triplets(chunk_content, entities) 249 | all_triplets.extend(triplets) 250 | except: 251 | print( 252 | f"An error occurred while processing chunk '{chunk_id}'. SKIPPING..." 253 | ) 254 | 255 | print( 256 | f"{len(all_entities)} entities and {len(all_triplets)} triplets have been extracted." 257 | ) 258 | ``` 259 | - 功能:遍历文档块以及从当前分块中提取实体和三元组,其中提取实体和三元组,均使用llm来完成。下面分析这代代码的实现步骤,再简单解释下实体和三元组的定义与结构。 260 | - 实现步骤: 261 | 1. 初始化存储容器; 262 | 2. 
遍历文档块,遍历new_chunks字典,其中每一块有一个chunk_id和对应的内容chunk_content。 263 | 3. 提取实体和三元组:首先调用self.get_entity(chunk_conten, chunk_id= chunk_id)方法,从当前分块中提取实体,将提取到的实体追加到all_entities列表中;然后调用self.get_triplets(chunk_content, entities)方法,从当前分块中提取三元组,将提取到的三元组追加到all_triplets列表中。如果在处理过程中出现错误,打印错误信息并跳过该分块。 264 | 4. 打印提取的实体和三元组综述,便于检查和提取结果。 265 | - 实体的定义与结构 266 | - 定义:实体是文档中提取的关键概念或对象,通常是名词或专有名词。 267 | - 结构示意 268 | ```python 269 | { 270 | "name": "Entity Name", # 实体名称 271 | "description": "Entity Description", # 实体描述 272 | "chunks id": ["chunk-1a2b3c"], # 关联的文档块 ID 273 | "entity id": "entity-123456" # 实体的唯一标识符 274 | } 275 | ``` 276 | - 三元组的定义与结构 277 | - 定义:三元组是描述实体之间关系的结构,包含主语(subject)、谓语(predicate)和宾语(object)。 278 | - 结构示意 279 | ```python 280 | { 281 | "subject": "Subject Name", # 主语名称 282 | "subject_id": "entity-123456", # 主语的唯一标识符 283 | "predicate": "Predicate Name", # 谓语(关系名称) 284 | "object": "Object Name", # 宾语名称 285 | "object_id": "entity-654321" # 宾语的唯一标识符 286 | } 287 | ``` 288 | 289 | - 实体(Entities)是图数据库中的节点,表示文档中的关键概念。三元组(Triplets)是+图数据库中的边,表示实体之间的关系,Neo4j中的节点与三元组关系如下所示: 290 | 291 |
292 | ![图数据库示例](images/图数据库示例.png)
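- 为了更直观地理解实体与三元组是如何落到图数据库中的,下面给出一个最小示意(非项目源码,实体与关系内容均为虚构示例,并假设本地 Neo4j 服务可用)。其中的 Cypher 写法与 graph.py 中 create_triplet 方法的思路一致,但省略了 description、chunks_id 等属性:
```python
# 最小示意:把一条抽取出的三元组 (subject, predicate, object) 以
# "Entity 节点 + Relationship 边" 的形式 MERGE 进 Neo4j(示例数据为虚构)
from neo4j import GraphDatabase

driver = GraphDatabase.driver("neo4j://localhost:7687", auth=("neo4j", "your-password"))

subject = {"name": "Machine Learning", "entity id": "entity-123456"}
object_ = {"name": "Learning Algorithms", "entity id": "entity-654321"}
predicate = "is the subject of"

query = (
    "MERGE (a:Entity {name: $s_name, entity_id: $s_id}) "
    "MERGE (b:Entity {name: $o_name, entity_id: $o_id}) "
    "MERGE (a)-[r:Relationship {name: $predicate}]->(b) "
    "RETURN a, b, r"
)
with driver.session() as session:
    session.run(
        query,
        s_name=subject["name"], s_id=subject["entity id"],
        o_name=object_["name"], o_id=object_["entity id"],
        predicate=predicate,
    )
driver.close()
```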
294 | 295 | ###### 1.3.4 执行实体消歧和三元组更新。实体消歧有两种方法可以选择,默认将同名实体认为是同一实体 296 | - 对应代码 297 | ```python 298 | # ================ Entity Disambiguation ================ 299 | entity_names = list(set(entity["name"] for entity in all_entities)) 300 | 301 | if use_llm_deambiguation: 302 | entity_id_mapping = {} 303 | for name in entity_names: 304 | same_name_entities = [ 305 | entity for entity in all_entities if entity["name"] == name 306 | ] 307 | transform_text = self.llm.predict( 308 | ENTITY_DISAMBIGUATION.format(same_name_entities) 309 | ) 310 | entity_id_mapping.update( 311 | get_text_inside_tag(transform_text, "transform") 312 | ) 313 | else: 314 | entity_id_mapping = {} 315 | for entity in all_entities: 316 | entity_name = entity["name"] 317 | if entity_name not in entity_id_mapping: 318 | entity_id_mapping[entity_name] = entity["entity id"] 319 | 320 | for entity in all_entities: 321 | entity["entity id"] = entity_id_mapping.get( 322 | entity["name"], entity["entity id"] 323 | ) 324 | 325 | triplets_to_remove = [ 326 | triplet 327 | for triplet in all_triplets 328 | if entity_id_mapping.get(triplet["subject"], triplet["subject_id"]) is None 329 | or entity_id_mapping.get(triplet["object"], triplet["object_id"]) is None 330 | ] 331 | 332 | updated_triplets = [ 333 | { 334 | **triplet, 335 | "subject_id": entity_id_mapping.get( 336 | triplet["subject"], triplet["subject_id"] 337 | ), 338 | "object_id": entity_id_mapping.get( 339 | triplet["object"], triplet["object_id"] 340 | ), 341 | } 342 | for triplet in all_triplets 343 | if triplet not in triplets_to_remove 344 | ] 345 | all_triplets = updated_triplets 346 | ``` 347 | - 对于实体消歧(Entity Disambiguation)部分 348 | - 功能: 349 | - 解决同名实体歧义的问题,确保每个实体都有唯一的entity_id。如果启用了LLM消歧(use_llm_deambiguation=True),则默认将同名实体视为同一实体;如果未启用LLM消歧,则默认将同名实体视为同一实体。本项目采用后者。 350 | - 实现步骤: 351 | 1. 提取实体的名称存储到entity_names中; 352 | 2. 使用默认方法消歧义 353 | 3. 更新实体ID 354 | - 对于三元组更新(Triplet Update)部分 355 | - 功能: 356 | - 根据消歧后的实体ID更新三元组,并移除无效的三元组。 357 | - 实现步骤: 358 | 1. 移除所有无效的三元组(如果三元组的主语或者宾语的实体ID无法在entity_id_mapping中找到,则将其标记为无效); 359 | 2. 更新三元组(对于有效的三元组,更新其主语和宾语的实体ID) 360 | 3. 保存更新后的三元组(将更新后的三元组列表保存到all_triplets中) 361 | ###### 1.3.5 合并实体和三元组 362 | - 对应代码 363 | ```python 364 | # ================ Merge Entities ================ 365 | entity_map = {} 366 | 367 | for entity in all_entities: 368 | entity_id = entity["entity id"] 369 | if entity_id not in entity_map: 370 | entity_map[entity_id] = { 371 | "name": entity["name"], 372 | "description": entity["description"], 373 | "chunks id": [], 374 | "entity id": entity_id, 375 | } 376 | else: 377 | entity_map[entity_id]["description"] += " " + entity["description"] 378 | 379 | entity_map[entity_id]["chunks id"].extend(entity["chunks id"]) 380 | ``` 381 | - 功能: 382 | - 将所有提取的实体(all_entities)按照其唯一标识符(entity_id)进行归并,确保同一个实体的描述和关联的文档块ID被整合到一起 383 | - 实现步骤: 384 | - 使用一个字典entity_map,以entity_id作为键,存储每个实体的合并信息。如果某个实体entity_id已经存在于entity_map中,则将其描述和文档块ID合并到已有的实体中。 385 | ###### 1.3.6 将合并的实体和三元组存储到Neo4j的图数据库中 386 | - 对应代码 387 | ```python 388 | # ================ Store Data in Neo4j ================ 389 | for triplet in all_triplets: 390 | subject_id = triplet["subject_id"] 391 | object_id = triplet["object_id"] 392 | 393 | subject = entity_map.get(subject_id) 394 | object = entity_map.get(object_id) 395 | if subject and object: 396 | self.create_triplet(subject, triplet["predicate"], object) 397 | ``` 398 | - 功能:将提取的三元组(triplets)存储到Neo4j图数据库中 399 | - 实现步骤: 400 | 1. 遍历all_triplets列表,逐个处理每个三元组 401 | 2. 
根据三元组中的subject_id和object_id,从entity_map中获取对应的实体信息 402 | 3. 如果主语和宾语实体都存在,则调用self.create_triplet方法,将三元组存储到Neo4j中。其中的create_triplet方法能够通过Cypher查询语句将实体和关系插入到数据库中。 403 | ###### 1.3.7 生成社区内容 404 | - 对应代码 405 | ```python 406 | # ================ communities ================ 407 | self.gen_community() 408 | self.generate_community_report() 409 | ``` 410 | - 功能: 411 | - 生成社区:通过图算法(本项目为 Leiden 算法)检测图中的社区结构。 412 | - 生成社区报告:借助大语言模型为每个社区生成详细的报告,描述社区中的实体和关系。 413 | - 实现步骤: 414 | 1. 对于生成社区功能,调用 self.gen_community() 方法: 415 | - 使用 Neo4j 的图算法(如 gds.leiden.write)检测社区。 416 | - 生成社区架构(community schema),包括社区的层级、节点、边等信息。 417 | - 将社区架构存储到 community.json 文件中。 418 | 2. 对于生成社区报告功能,调用 self.generate_community_report() 方法: 419 | - 遍历每个社区,生成包含实体和关系的报告。 420 | - 报告通过大语言模型(LLM)生成,描述社区的结构和内容。 421 | ###### 1.3.8 生成嵌入式向量 422 | - 对应代码 423 | ```python 424 | # ================ embedding ================ 425 | self.add_embedding_for_graph() 426 | self.add_loaded_documents(filepath) 427 | print(f"doc '{filepath}' has been loaded.") 428 | ``` 429 | - 功能: 430 | - 为图数据库中的每个实体节点生成嵌入向量(embedding),用于计算相似度(本项目采用余弦相似度)和查询。 431 | - 将处理过的文档路径记录到缓存文件中,避免重复处理。 432 | - 实现步骤: 433 | 1. 生成嵌入:调用 self.add_embedding_for_graph() 方法:遍历图数据库中的每个实体节点;使用嵌入模型(self.embedding)计算节点描述的嵌入向量;将嵌入向量存储到节点的 embedding 属性中。 434 | 2. 记录文档路径:调用 self.add_loaded_documents(filepath) 方法:将当前文路径添加到缓存文件中,避免重复加载。 435 | - 最终生成的图数据信息如下所示: 436 | 437 |
438 | ![Learning-Algorithms节点的详细信息](images/Learning-Algorithms节点的详细信息.png)
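- 为了说明这些嵌入向量在后续检索中的用途,下面给出一个最小示意(非项目源码):利用 tinygraph/utils.py 中的 cosine_similarity,把查询向量与各实体向量逐一比较,取相似度最高的 k 个。示意中的 entities 数据结构是假设的,项目的实际实现(local_query 中的 get_topk_similar_entities)是从 Neo4j 节点的 embedding 属性中读取向量:
```python
# 最小示意:基于余弦相似度的 top-k 实体检索(entities 的结构为假设)
from tinygraph.utils import cosine_similarity

def topk_similar_entities(query_emb, entities, k=3):
    # entities 形如 [{"name": ..., "description": ..., "embedding": [...]}, ...]
    scored = [(e, cosine_similarity(query_emb, e["embedding"])) for e in entities]
    scored.sort(key=lambda pair: pair[1], reverse=True)  # 按相似度降序
    return scored[:k]
```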
440 | 441 | ###### 1.3.9 验证下数据库连接是否正常(当然,此步也可省略) 442 | - 对应代码 443 | ```python 444 | with graph.driver.session() as session: 445 | result = session.run("MATCH (n) RETURN count(n) as count") 446 | count = result.single()["count"] 447 | print(f"数据库连接正常,节点数量: {count}") 448 | ``` 449 | #### 2. 图检索优化 450 | ##### 2.1 两种图检索方法 451 | - 按照Tiny-Graphrag demo代码的执行过程,图检索优化过程有两种:分别为Tiny-Graph类中`local_query`方法和`global_query`方法。这两个方法通俗来讲就是根据问题(本项目中为"what is dl?"),得到了局部检索和全局检索的两种上下文,然后交给大模型来处理。 452 | - 全局查询和局部查询的特点如下表所示: 453 | 454 | | 查询类型 | 特点 | 适用场景 | 455 | |----------|------|----------| 456 | | 全局查询(global_query) | • 基于社区层级评分
• 筛选候选社区 • 返回排序列表 | • 高层次理解 • 全局视角分析 | 457 | | 局部查询(local_query) | • 基于直接关联上下文 • 提取精确实体/关系 • 返回多部分结果 | • 精确定位
• 深度分析 | 458 | - 下面依次分析下`local_query`方法和`global_query`方法的具体实现过程。 459 | ##### 2.2 local_query方法 460 | - 在Tiny_Graphrag_test.ipynb中,执行局部查询测试时,使用的是local_query方法 461 | - 具体代码为:`local_res = graph.local_query("what is dl?")` 462 | - 其中调用的方法`local_query("what is dl?")`,将"what is dl?"传递给`local_query()`方法,以下是`local_query()`方法的代码内容和代码解读 463 | - 代码内容 464 | ```python 465 | def local_query(self, query): 466 | context = self.build_local_query_context(query) # 分别包括社区、实体、关系、文档块这四部分 467 | prompt = LOCAL_QUERY.format(query=query, context=context) # 需要的参数context以及query都在该方法内得到了 468 | response = self.llm.predict(prompt) 469 | return response 470 | ``` 471 | - 代码解读 472 | - 执行`context = self.build_local_query_context(query)`后,根据用户问题(本项目中是"what is dl?")得到的包括社区、实体、关系、文档块的这四部分内容的上下文(context)。得到上下文的具体方法为:`build_local_query_context(self, query)`,该方法内的代码执行顺序是: 473 | 1. 获得输入文本的嵌入向量。对应代码为:`query_emb = self.embedding.get_emb(query)` 474 | 2. 获得前k个最相似的实体,相似判断的依据是余弦相似度。对应代码为:`topk_similar_entities_context = self.get_topk_similar_entities(query_emb)` 475 | 3. 获得前k个最相似的社区,依据的方法是: 476 | - 利用上面得到的最相似实体; 477 | - 只要包含上述的任意topk节点,就认为是相似社区(社区:community,由相互关联的节点组成的集合)。对应代码为: 478 | ```python 479 | topk_similar_communities_context = self.get_communities( 480 | topk_similar_entities_context 481 | ) 482 | ``` 483 | 4. 获得前k个最相似的关系,依据的方法是:在`get_relations`方法中调用`get_node_edgs`方法,获取该实体的所有关系边,认为这些边就是similar relation。对应代码为: 484 | ```python 485 | topk_similar_relations_context = self.get_relations( 486 | topk_similar_entities_context, query 487 | ) 488 | ``` 489 | 5. 获得前k个最相似的文档块,依据是方法是:在`get_chunks`方法中调用`get_node_chunks`方法,获取该实体关联的文档块,认为这些文档块就是similar chunks。对应代码为: 490 | ```python 491 | topk_similar_chunks_context = self.get_chunks( 492 | topk_similar_entities_context, query 493 | ) 494 | ``` 495 | 6. `build_local_query_context()`方法最终返回的是一个多行字符串,包括: 496 | - Reports:社区报告; 497 | - Entities:与查询最相似的实体; 498 | - Relationships:这些实体之间的关系; 499 | - Sources:这些实体关联的文档块。对应的代码为: 500 | ```python 501 | return f""" 502 | -----Reports----- 503 | ```csv 504 | {topk_similar_communities_context} 505 | ``` 506 | -----Entities----- 507 | ```csv 508 | {topk_similar_entities_context} 509 | ``` 510 | -----Relationships----- 511 | ```csv 512 | {topk_similar_relations_context} 513 | ``` 514 | -----Sources----- 515 | ```csv 516 | {topk_similar_chunks_context} 517 | ``` 518 | """ 519 | ``` 520 | - 之后的`prompt = LOCAL_QUERY.format(query=query, context=context)`可以理解为根据刚刚生成的context作为上下文,生成prompt为大模型使用。 521 | - 最后 ,`response = self.llm.predict(prompt)`是将上文得到的prompt传输给大模型,从而让大模型做推理和回答,然后该方法返回到`response(return response)`作为大模型的回答结果。 522 | ###### 2.3 global_query方法 523 | - 在`Tiny_Graphrag_test.ipynb`中,执行全局查询测试时,使用的是`global_query`方法 524 | - 具体代码为:`global_res = graph.global_query("what is dl?")` 525 | - 其中调用的方法`global_query("what is dl?")`,将"what is dl?"传递给`global_query()`方法,以下是`global_query()`方法的代码内容和代码解读 526 | - 代码内容: 527 | ```python 528 | def global_query(self, query, level=1): 529 | context = self.build_global_query_context(query, level) # 得到的是一个列表,包含社区的描述和分数 530 | prompt = GLOBAL_QUERY.format(query=query, context=context)# 将得到的context传入到prompt中 531 | response = self.llm.predict(prompt)# 将prompt传入到llm中,得到最终的结果,也就是将包含描述和分数的列表传入到llm中 532 | return response 533 | ``` 534 | - 代码解读: 535 | - 运行`context = self.build_global_query_context(query, level)`时,会根据用户问题(本项目中是“what is dl?")得到的包含社区描述和分数的上下文(context)。对应代码为:`context = self.build_global_query_context(query, level)`,该方法内的代码执行顺序是: 536 | 1. 
设定空的候选社区(字典)以及空的社区评分列表(列表),并筛选符合层级要求的社区。对应代码为: 537 | ```python 538 | communities_schema = self.read_community_schema() 539 | candidate_community = {} # 候选社区 字典 540 | points = [] # 社区评分列表 列表 541 | # 筛选符合层级要求的社区 542 | for communityid, community_info in communities_schema.items(): 543 | if community_info["level"] < level: 544 | candidate_community.update({communityid: community_info}) 545 | ``` 546 | 2. 计算候选的社区的评分,通过调用`map_community_points`函数,结合社区报告和大语言模型的能力,为每个候选社区生成与查询内容(如 "What is DL?")相关程度的评分。对应的代码为: 547 | ```python 548 | for communityid, community_info in candidate_community.items(): 549 | points.extend(self.map_community_points(community_info["report"], query)) 550 | ``` 551 | 3. 按照评分降序排序,得到包含描述和分数的列表。描述是社区的描述,分数是查询的相关性得分。对应代码为: 552 | ```python 553 | points = sorted(points, key=lambda x: x[-1], reverse=True) 554 | return points # 得到包含描述和分数的列表,描述是社区的描述,分数是查询的相关性得分 555 | ``` 556 | 4. 之后的`prompt = GLOBAL_QUERY.format(query=query, context=context)`可以理解为根据刚刚生成的context作为上下文,生成prompt给大模型使用。 557 | 5. 最后,`response=self.llm.predict(prompt)`将上文得到的prompt传输给大模型。`return response`作为大模型的回答结果。 558 | ##### 2.4 生成增强 559 | 1. 通俗来讲就是:将得到的上下文输入给大模型,基于此上下文,大模型作推理和回答 560 | 2. 在本项目代码中,`local_query`方法和`global_query`方法的将各自得到的上下文传输给大模型将是生成增强的过程。 561 | - 局部查询和全局查询成功运行的示例: 562 | 563 |
564 | 565 |
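For reference, here is a minimal end-to-end usage sketch. The `local_query`/`global_query` calls mirror the ones quoted from `Tiny_Graphrag_test.ipynb` above; the setup lines are assumptions rather than project code: the ZhipuAI-backed wrapper class names (`zhipuLLM`, `zhipuEmb`), the model names, the Neo4j credentials, and the API key are all placeholders, so check `tinygraph/llm/zhipu.py` and `tinygraph/embedding/zhipu.py` for the exact class names before running it.

```python
# Minimal end-to-end sketch (assumed class/model names and placeholder credentials).
from tinygraph.graph import TinyGraph
from tinygraph.llm.zhipu import zhipuLLM          # assumed class name
from tinygraph.embedding.zhipu import zhipuEmb    # assumed class name

llm = zhipuLLM(model_name="glm-4", api_key="your-zhipu-api-key")        # placeholder model/key
emb = zhipuEmb(model_name="embedding-2", api_key="your-zhipu-api-key")  # placeholder model/key

graph = TinyGraph(
    url="bolt://localhost:7687",   # Neo4j bolt endpoint (placeholder)
    username="neo4j",
    password="your-password",
    llm=llm,
    emb=emb,
    working_dir="workspace",
)

# Ingest a document: chunking, entity/triplet extraction, community detection,
# community reports and node embeddings all happen inside add_document.
graph.add_document("example/data.md")

# Retrieval-augmented generation over the graph.
local_res = graph.local_query("what is dl?")     # entity-centred local search
global_res = graph.global_query("what is dl?")   # community-report-based global search
print(local_res)
print(global_res)
```

Note that `add_document` eventually calls `gds.graph.project` and `gds.leiden.write`, so the Neo4j instance must have the Graph Data Science plugin installed for community detection to work.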
566 | 567 | -------------------------------------------------------------------------------- /tinygraph/graph.py: -------------------------------------------------------------------------------- 1 | from neo4j import GraphDatabase 2 | import os 3 | from tqdm import tqdm 4 | from .utils import ( 5 | get_text_inside_tag, 6 | cosine_similarity, 7 | compute_mdhash_id, 8 | read_json_file, 9 | write_json_file, 10 | create_file_if_not_exists, 11 | ) 12 | from .llm.base import BaseLLM 13 | from .embedding.base import BaseEmb 14 | from .prompt import * 15 | from typing import Dict, List, Optional, Tuple, Union 16 | import numpy as np 17 | from collections import defaultdict 18 | import json 19 | 20 | from dataclasses import dataclass 21 | 22 | 23 | @dataclass 24 | class Node: 25 | name: str 26 | desc: str 27 | chunks_id: list 28 | entity_id: str 29 | similarity: float 30 | 31 | 32 | class TinyGraph: 33 | """ 34 | 一个用于处理图数据库和语言模型的类。 35 | 36 | 该类通过连接到Neo4j图数据库,并使用语言模型(LLM)和嵌入模型(Embedding)来处理文档和图数据。 37 | 它还管理一个工作目录,用于存储文档、文档块和社区数据。 38 | """ 39 | 40 | def __init__( 41 | self, 42 | url: str, # Neo4j数据库的URL 43 | username: str, # Neo4j数据库的用户名 44 | password: str, # Neo4j数据库的密码 45 | llm: BaseLLM, # 语言模型(LLM)实例 46 | emb: BaseLLM, # 嵌入模型(Embedding)实例 47 | working_dir: str = "workspace", # 工作目录,默认为"workspace" 48 | ): 49 | """ 50 | 初始化TinyGraph类。 51 | 52 | 参数: 53 | - url: Neo4j数据库的URL 54 | - username: Neo4j数据库的用户名 55 | - password: Neo4j数据库的密码 56 | - llm: 语言模型(LLM)实例 57 | - emb: 嵌入模型(Embedding)实例 58 | - working_dir: 工作目录,默认为"workspace" 59 | """ 60 | self.driver = driver = GraphDatabase.driver( 61 | url, auth=(username, password) 62 | ) # 创建Neo4j数据库驱动 63 | self.llm = llm # 设置语言模型 64 | self.embedding = emb # 设置嵌入模型 65 | self.working_dir = working_dir # 设置工作目录 66 | os.makedirs(self.working_dir, exist_ok=True) # 创建工作目录(如果不存在) 67 | 68 | # 定义文档、文档块和社区数据的文件路径 69 | self.doc_path = os.path.join(working_dir, "doc.txt") 70 | self.chunk_path = os.path.join(working_dir, "chunk.json") 71 | self.community_path = os.path.join(working_dir, "community.json") 72 | 73 | # 创建文件(如果不存在) 74 | create_file_if_not_exists(self.doc_path) 75 | create_file_if_not_exists(self.chunk_path) 76 | create_file_if_not_exists(self.community_path) 77 | 78 | # 加载已加载的文档 79 | self.loaded_documents = self.get_loaded_documents() 80 | 81 | def create_triplet(self, subject: dict, predicate, object: dict) -> None: 82 | """ 83 | 创建一个三元组(Triplet)并将其存储到Neo4j数据库中。 84 | 85 | 参数: 86 | - subject: 主题实体的字典,包含名称、描述、块ID和实体ID 87 | - predicate: 关系名称 88 | - object: 对象实体的字典,包含名称、描述、块ID和实体ID 89 | 90 | 返回: 91 | - 查询结果 92 | """ 93 | # 定义Cypher查询语句,用于创建或合并实体节点和关系 94 | query = ( 95 | "MERGE (a:Entity {name: $subject_name, description: $subject_desc, chunks_id: $subject_chunks_id, entity_id: $subject_entity_id}) " 96 | "MERGE (b:Entity {name: $object_name, description: $object_desc, chunks_id: $object_chunks_id, entity_id: $object_entity_id}) " 97 | "MERGE (a)-[r:Relationship {name: $predicate}]->(b) " 98 | "RETURN a, b, r" 99 | ) 100 | 101 | # 使用数据库会话执行查询 102 | with self.driver.session() as session: 103 | result = session.run( 104 | query, 105 | subject_name=subject["name"], 106 | subject_desc=subject["description"], 107 | subject_chunks_id=subject["chunks id"], 108 | subject_entity_id=subject["entity id"], 109 | object_name=object["name"], 110 | object_desc=object["description"], 111 | object_chunks_id=object["chunks id"], 112 | object_entity_id=object["entity id"], 113 | predicate=predicate, 114 | ) 115 | 116 | return 117 | 118 | def split_text(self,file_path:str, 
segment_length=300, overlap_length=50) -> Dict: 119 | """ 120 | 将文本文件分割成多个片段,每个片段的长度为segment_length,相邻片段之间有overlap_length的重叠。 121 | 122 | 参数: 123 | - file_path: 文本文件的路径 124 | - segment_length: 每个片段的长度,默认为300 125 | - overlap_length: 相邻片段之间的重叠长度,默认为50 126 | 127 | 返回: 128 | - 包含片段ID和片段内容的字典 129 | """ 130 | chunks = {} # 用于存储片段的字典 131 | with open(file_path, "r", encoding="utf-8") as file: 132 | content = file.read() # 读取文件内容 133 | 134 | text_segments = [] # 用于存储分割后的文本片段 135 | start_index = 0 # 初始化起始索引 136 | 137 | # 循环分割文本,直到剩余文本长度不足以形成新的片段 138 | while start_index + segment_length <= len(content): 139 | text_segments.append(content[start_index : start_index + segment_length]) 140 | start_index += segment_length - overlap_length # 更新起始索引,考虑重叠长度 141 | 142 | # 处理剩余的文本,如果剩余文本长度小于segment_length但大于0 143 | if start_index < len(content): 144 | text_segments.append(content[start_index:]) 145 | 146 | # 为每个片段生成唯一的ID,并将其存储在字典中 147 | for segement in text_segments: 148 | chunks.update({compute_mdhash_id(segement, prefix="chunk-"): segement}) 149 | 150 | return chunks 151 | 152 | def get_entity(self, text: str, chunk_id: str) -> List[Dict]: 153 | """ 154 | 从给定的文本中提取实体,并为每个实体生成唯一的ID和描述。 155 | 156 | 参数: 157 | - text: 输入的文本 158 | - chunk_id: 文本块的ID 159 | 160 | 返回: 161 | - 包含提取的实体信息的列表 162 | """ 163 | # 使用语言模型预测实体信息 164 | data = self.llm.predict(GET_ENTITY.format(text=text)) 165 | concepts = [] # 用于存储提取的实体信息 166 | 167 | # 从预测结果中提取实体信息 168 | for concept_html in get_text_inside_tag(data, "concept"): 169 | concept = {} 170 | concept["name"] = get_text_inside_tag(concept_html, "name")[0].strip() 171 | concept["description"] = get_text_inside_tag(concept_html, "description")[ 172 | 0 173 | ].strip() 174 | concept["chunks id"] = [chunk_id] 175 | concept["entity id"] = compute_mdhash_id( 176 | concept["description"], prefix="entity-" 177 | ) 178 | concepts.append(concept) 179 | 180 | return concepts 181 | 182 | def get_triplets(self, content, entity: list) -> List[Dict]: 183 | """ 184 | 从给定的内容中提取三元组(Triplet)信息,并返回包含这些三元组信息的列表。 185 | 186 | 参数: 187 | - content: 输入的内容 188 | - entity: 实体列表 189 | 190 | 返回: 191 | - 包含提取的三元组信息的列表 192 | """ 193 | try: 194 | # 使用语言模型预测三元组信息 195 | data = self.llm.predict(GET_TRIPLETS.format(text=content, entity=entity)) 196 | data = get_text_inside_tag(data, "triplet") 197 | except Exception as e: 198 | print(f"Error predicting triplets: {e}") 199 | return [] 200 | 201 | res = [] # 用于存储提取的三元组信息 202 | 203 | # 从预测结果中提取三元组信息 204 | for triplet_data in data: 205 | try: 206 | subject = get_text_inside_tag(triplet_data, "subject")[0] 207 | subject_id = get_text_inside_tag(triplet_data, "subject_id")[0] 208 | predicate = get_text_inside_tag(triplet_data, "predicate")[0] 209 | object = get_text_inside_tag(triplet_data, "object")[0] 210 | object_id = get_text_inside_tag(triplet_data, "object_id")[0] 211 | res.append( 212 | { 213 | "subject": subject, 214 | "subject_id": subject_id, 215 | "predicate": predicate, 216 | "object": object, 217 | "object_id": object_id, 218 | } 219 | ) 220 | except Exception as e: 221 | print(f"Error extracting triplet: {e}") 222 | continue 223 | 224 | return res 225 | 226 | def add_document(self, filepath, use_llm_deambiguation=False) -> None: 227 | """ 228 | 将文档添加到系统中,执行以下步骤: 229 | 1. 检查文档是否已经加载。 230 | 2. 将文档分割成块。 231 | 3. 从块中提取实体和三元组。 232 | 4. 执行实体消岐,有两种方法可选,默认将同名实体认为即为同一实体。 233 | 5. 合并实体和三元组。 234 | 6. 
将合并的实体和三元组存储到Neo4j数据库中。 235 | 236 | 参数: 237 | - filepath: 要添加的文档的路径 238 | - use_llm_deambiguation: 是否使用LLM进行实体消岐 239 | """ 240 | # ================ Check if the document has been loaded ================ 241 | if filepath in self.get_loaded_documents(): 242 | print( 243 | f"Document '{filepath}' has already been loaded, skipping import process." 244 | ) 245 | return 246 | 247 | # ================ Chunking ================ 248 | chunks = self.split_text(filepath) 249 | existing_chunks = read_json_file(self.chunk_path) 250 | 251 | # Filter out chunks that are already in storage 252 | new_chunks = {k: v for k, v in chunks.items() if k not in existing_chunks} 253 | 254 | if not new_chunks: 255 | print("All chunks are already in the storage.") 256 | return 257 | 258 | # Merge new chunks with existing chunks 259 | all_chunks = {**existing_chunks, **new_chunks} 260 | write_json_file(all_chunks, self.chunk_path) 261 | print(f"Document '{filepath}' has been chunked.") 262 | 263 | # ================ Entity Extraction ================ 264 | all_entities = [] 265 | all_triplets = [] 266 | 267 | for chunk_id, chunk_content in tqdm( 268 | new_chunks.items(), desc=f"Processing '{filepath}'" 269 | ): 270 | try: 271 | entities = self.get_entity(chunk_content, chunk_id=chunk_id) 272 | all_entities.extend(entities) 273 | triplets = self.get_triplets(chunk_content, entities) 274 | all_triplets.extend(triplets) 275 | except: 276 | print( 277 | f"An error occurred while processing chunk '{chunk_id}'. SKIPPING..." 278 | ) 279 | 280 | print( 281 | f"{len(all_entities)} entities and {len(all_triplets)} triplets have been extracted." 282 | ) 283 | # ================ Entity Disambiguation ================ 284 | entity_names = list(set(entity["name"] for entity in all_entities)) 285 | 286 | if use_llm_deambiguation: 287 | entity_id_mapping = {} 288 | for name in entity_names: 289 | same_name_entities = [ 290 | entity for entity in all_entities if entity["name"] == name 291 | ] 292 | transform_text = self.llm.predict( 293 | ENTITY_DISAMBIGUATION.format(same_name_entities) 294 | ) 295 | entity_id_mapping.update( 296 | get_text_inside_tag(transform_text, "transform") 297 | ) 298 | else: 299 | entity_id_mapping = {} 300 | for entity in all_entities: 301 | entity_name = entity["name"] 302 | if entity_name not in entity_id_mapping: 303 | entity_id_mapping[entity_name] = entity["entity id"] 304 | 305 | for entity in all_entities: 306 | entity["entity id"] = entity_id_mapping.get( 307 | entity["name"], entity["entity id"] 308 | ) 309 | 310 | triplets_to_remove = [ 311 | triplet 312 | for triplet in all_triplets 313 | if entity_id_mapping.get(triplet["subject"], triplet["subject_id"]) is None 314 | or entity_id_mapping.get(triplet["object"], triplet["object_id"]) is None 315 | ] 316 | 317 | updated_triplets = [ 318 | { 319 | **triplet, 320 | "subject_id": entity_id_mapping.get( 321 | triplet["subject"], triplet["subject_id"] 322 | ), 323 | "object_id": entity_id_mapping.get( 324 | triplet["object"], triplet["object_id"] 325 | ), 326 | } 327 | for triplet in all_triplets 328 | if triplet not in triplets_to_remove 329 | ] 330 | all_triplets = updated_triplets 331 | 332 | # ================ Merge Entities ================ 333 | entity_map = {} 334 | 335 | for entity in all_entities: 336 | entity_id = entity["entity id"] 337 | if entity_id not in entity_map: 338 | entity_map[entity_id] = { 339 | "name": entity["name"], 340 | "description": entity["description"], 341 | "chunks id": [], 342 | "entity id": entity_id, 343 | } 344 | 
else: 345 | entity_map[entity_id]["description"] += " " + entity["description"] 346 | 347 | entity_map[entity_id]["chunks id"].extend(entity["chunks id"]) 348 | # ================ Store Data in Neo4j ================ 349 | for triplet in all_triplets: 350 | subject_id = triplet["subject_id"] 351 | object_id = triplet["object_id"] 352 | 353 | subject = entity_map.get(subject_id) 354 | object = entity_map.get(object_id) 355 | if subject and object: 356 | self.create_triplet(subject, triplet["predicate"], object) 357 | # ================ communities ================ 358 | self.gen_community() 359 | self.generate_community_report() 360 | # ================ embedding ================ 361 | self.add_embedding_for_graph() 362 | self.add_loaded_documents(filepath) 363 | print(f"doc '{filepath}' has been loaded.") 364 | 365 | def detect_communities(self) -> None: 366 | query = """ 367 | CALL gds.graph.project( 368 | 'graph_help', 369 | ['Entity'], 370 | { 371 | Relationship: { 372 | orientation: 'UNDIRECTED' 373 | } 374 | } 375 | ) 376 | """ 377 | with self.driver.session() as session: 378 | result = session.run(query) 379 | 380 | query = """ 381 | CALL gds.leiden.write('graph_help', { 382 | writeProperty: 'communityIds', 383 | includeIntermediateCommunities: True, 384 | maxLevels: 10, 385 | tolerance: 0.0001, 386 | gamma: 1.0, 387 | theta: 0.01 388 | }) 389 | YIELD communityCount, modularity, modularities 390 | """ 391 | with self.driver.session() as session: 392 | result = session.run(query) 393 | for record in result: 394 | print( 395 | f"社区数量: {record['communityCount']}, 模块度: {record['modularity']}" 396 | ) 397 | session.run("CALL gds.graph.drop('graph_help')") 398 | 399 | def get_entity_by_name(self, name): 400 | query = """ 401 | MATCH (n:Entity {name: $name}) 402 | RETURN n 403 | """ 404 | with self.driver.session() as session: 405 | result = session.run(query, name=name) 406 | entities = [record["n"].get("name") for record in result] 407 | return entities[0] 408 | 409 | def get_node_edgs(self, node: Node): 410 | query = """ 411 | MATCH (n)-[r]-(m) 412 | WHERE n.entity_id = $id 413 | RETURN n.name AS n,r.name AS r,m.name AS m 414 | """ 415 | with self.driver.session() as session: 416 | result = session.run(query, id=node.entity_id) 417 | edges = [(record["n"], record["r"], record["m"]) for record in result] 418 | return edges 419 | 420 | def get_node_chunks(self, node): 421 | existing_chunks = read_json_file(self.chunk_path) 422 | chunks = [existing_chunks[i] for i in node.chunks_id] 423 | return chunks 424 | 425 | def add_embedding_for_graph(self): 426 | query = """ 427 | MATCH (n) 428 | RETURN n 429 | """ 430 | with self.driver.session() as session: 431 | result = session.run(query) 432 | for record in result: 433 | node = record["n"] 434 | description = node["description"] 435 | id = node["entity_id"] 436 | embedding = self.embedding.get_emb(description) 437 | # 更新节点,添加新的 embedding 属性 438 | update_query = """ 439 | MATCH (n {entity_id: $id}) 440 | SET n.embedding = $embedding 441 | """ 442 | session.run(update_query, id=id, embedding=embedding) 443 | 444 | def get_topk_similar_entities(self, input_emb, k=1) -> List[Node]: 445 | res = [] 446 | query = """ 447 | MATCH (n) 448 | RETURN n 449 | """ 450 | with self.driver.session() as session: 451 | result = session.run(query) 452 | # 如果遇到报错:ResultConsumedError: The result has been consumed. 
Fetch all needed records before calling Result.consume().可将result = session.run(query)修改为result = list(session.run(query)) 453 | for record in result: 454 | node = record["n"] 455 | if node["embedding"] is not None: 456 | similarity = cosine_similarity(input_emb, node["embedding"]) 457 | node = Node( 458 | name=node["name"], 459 | desc=node["description"], 460 | chunks_id=node["chunks_id"], 461 | entity_id=node["entity_id"], 462 | similarity=similarity, 463 | ) 464 | res.append(node) 465 | return sorted(res, key=lambda x: x.similarity, reverse=True)[:k] 466 | 467 | def get_communities(self, nodes: List[Node]): 468 | communities_schema = self.read_community_schema() 469 | res = [] 470 | nodes_ids = [i.entity_id for i in nodes] 471 | for community_id, community_info in communities_schema.items(): 472 | if set(nodes_ids) & set(community_info["nodes"]): 473 | res.append( 474 | { 475 | "community_id": community_id, 476 | "community_info": community_info["report"], 477 | } 478 | ) 479 | return res 480 | 481 | def get_relations(self, nodes: List, input_emb): 482 | res = [] 483 | for i in nodes: 484 | res.append(self.get_node_edgs(i)) 485 | return res 486 | 487 | def get_chunks(self, nodes, input_emb): 488 | chunks = [] 489 | for i in nodes: 490 | chunks.append(self.get_node_chunks(i)) 491 | return chunks 492 | 493 | def gen_community_schema(self) -> dict[str, dict]: 494 | results = defaultdict( 495 | lambda: dict( 496 | level=None, 497 | title=None, 498 | edges=set(), 499 | nodes=set(), 500 | chunk_ids=set(), 501 | sub_communities=[], 502 | ) 503 | ) 504 | 505 | with self.driver.session() as session: 506 | # Fetch community data 507 | result = session.run( 508 | f""" 509 | MATCH (n:Entity) 510 | WITH n, n.communityIds AS communityIds, [(n)-[]-(m:Entity) | m.entity_id] AS connected_nodes 511 | RETURN n.entity_id AS node_id, 512 | communityIds AS cluster_key, 513 | connected_nodes 514 | """ 515 | ) 516 | 517 | max_num_ids = 0 518 | for record in result: 519 | for index, c_id in enumerate(record["cluster_key"]): 520 | node_id = str(record["node_id"]) 521 | level = index 522 | cluster_key = str(c_id) 523 | connected_nodes = record["connected_nodes"] 524 | 525 | results[cluster_key]["level"] = level 526 | results[cluster_key]["title"] = f"Cluster {cluster_key}" 527 | results[cluster_key]["nodes"].add(node_id) 528 | results[cluster_key]["edges"].update( 529 | [ 530 | tuple(sorted([node_id, str(connected)])) 531 | for connected in connected_nodes 532 | if connected != node_id 533 | ] 534 | ) 535 | for k, v in results.items(): 536 | v["edges"] = [list(e) for e in v["edges"]] 537 | v["nodes"] = list(v["nodes"]) 538 | v["chunk_ids"] = list(v["chunk_ids"]) 539 | for cluster in results.values(): 540 | cluster["sub_communities"] = [ 541 | sub_key 542 | for sub_key, sub_cluster in results.items() 543 | if sub_cluster["level"] > cluster["level"] 544 | and set(sub_cluster["nodes"]).issubset(set(cluster["nodes"])) 545 | ] 546 | 547 | return dict(results) 548 | 549 | def gen_community(self): 550 | self.detect_communities() 551 | community_schema = self.gen_community_schema() 552 | with open(self.community_path, "w", encoding="utf-8") as file: 553 | json.dump(community_schema, file, indent=4) 554 | 555 | def read_community_schema(self) -> dict: 556 | try: 557 | with open(self.community_path, "r", encoding="utf-8") as file: 558 | community_schema = json.load(file) 559 | except: 560 | raise FileNotFoundError( 561 | "Community schema not found. Please make sure to generate it first." 
562 | ) 563 | return community_schema 564 | 565 | def get_loaded_documents(self): 566 | try: 567 | with open(self.doc_path, "r", encoding="utf-8") as file: 568 | lines = file.readlines() 569 | return set(line.strip() for line in lines) 570 | except: 571 | raise FileNotFoundError("Cache file not found.") 572 | 573 | def add_loaded_documents(self, file_path): 574 | if file_path in self.loaded_documents: 575 | print( 576 | f"Document '{file_path}' has already been loaded, skipping addition to cache." 577 | ) 578 | return 579 | with open(self.doc_path, "a", encoding="utf-8") as file: 580 | file.write(file_path + "\n") 581 | self.loaded_documents.add(file_path) 582 | 583 | def get_node_by_id(self, node_id): 584 | query = """ 585 | MATCH (n:Entity {entity_id: $node_id}) 586 | RETURN n 587 | """ 588 | with self.driver.session() as session: 589 | result = session.run(query, node_id=node_id) 590 | nodes = [record["n"] for record in result] 591 | return nodes[0] 592 | 593 | def get_edges_by_id(self, src, tar): 594 | query = """ 595 | MATCH (n:Entity {entity_id: $src})-[r]-(m:Entity {entity_id: $tar}) 596 | RETURN {src: n.name, r: r.name, tar: m.name} AS R 597 | """ 598 | with self.driver.session() as session: 599 | result = session.run(query, {"src": src, "tar": tar}) 600 | edges = [record["R"] for record in result] 601 | return edges[0] 602 | 603 | def gen_single_community_report(self, community: dict): 604 | nodes = community["nodes"] 605 | edges = community["edges"] 606 | nodes_describe = [] 607 | edges_describe = [] 608 | for i in nodes: 609 | node = self.get_node_by_id(i) 610 | nodes_describe.append({"name": node["name"], "desc": node["description"]}) 611 | for i in edges: 612 | edge = self.get_edges_by_id(i[0], i[1]) 613 | edges_describe.append( 614 | {"source": edge["src"], "target": edge["tar"], "desc": edge["r"]} 615 | ) 616 | nodes_csv = "entity,description\n" 617 | for node in nodes_describe: 618 | nodes_csv += f"{node['name']},{node['desc']}\n" 619 | edges_csv = "source,target,description\n" 620 | for edge in edges_describe: 621 | edges_csv += f"{edge['source']},{edge['target']},{edge['desc']}\n" 622 | data = f""" 623 | Text: 624 | -----Entities----- 625 | ```csv 626 | {nodes_csv} 627 | ``` 628 | -----Relationships----- 629 | ```csv 630 | {edges_csv} 631 | ```""" 632 | prompt = GEN_COMMUNITY_REPORT.format(input_text=data) 633 | report = self.llm.predict(prompt) 634 | return report 635 | 636 | def generate_community_report(self): 637 | communities_schema = self.read_community_schema() 638 | for community_key, community in tqdm( 639 | communities_schema.items(), desc="generating community report" 640 | ): 641 | community["report"] = self.gen_single_community_report(community) 642 | with open(self.community_path, "w", encoding="utf-8") as file: 643 | json.dump(communities_schema, file, indent=4) 644 | print("All community report has been generated.") 645 | 646 | def build_local_query_context(self, query): 647 | query_emb = self.embedding.get_emb(query) 648 | topk_similar_entities_context = self.get_topk_similar_entities(query_emb) 649 | topk_similar_communities_context = self.get_communities( 650 | topk_similar_entities_context 651 | ) 652 | topk_similar_relations_context = self.get_relations( 653 | topk_similar_entities_context, query 654 | ) 655 | topk_similar_chunks_context = self.get_chunks( 656 | topk_similar_entities_context, query 657 | ) 658 | return f""" 659 | -----Reports----- 660 | ```csv 661 | {topk_similar_communities_context} 662 | ``` 663 | -----Entities----- 664 | ```csv 665 
| {topk_similar_entities_context} 666 | ``` 667 | -----Relationships----- 668 | ```csv 669 | {topk_similar_relations_context} 670 | ``` 671 | -----Sources----- 672 | ```csv 673 | {topk_similar_chunks_context} 674 | ``` 675 | """ 676 | 677 | def map_community_points(self, community_info, query): 678 | points_html = self.llm.predict( 679 | GLOBAL_MAP_POINTS.format(context_data=community_info, query=query) 680 | ) 681 | points = get_text_inside_tag(points_html, "point") 682 | res = [] 683 | for point in points: 684 | try: 685 | score = get_text_inside_tag(point, "score")[0] 686 | desc = get_text_inside_tag(point, "description")[0] 687 | res.append((desc, score)) 688 | except: 689 | continue 690 | return res 691 | 692 | def build_global_query_context(self, query, level=1): 693 | communities_schema = self.read_community_schema() 694 | candidate_community = {} 695 | points = [] 696 | for communityid, community_info in communities_schema.items(): 697 | if community_info["level"] < level: 698 | candidate_community.update({communityid: community_info}) 699 | for communityid, community_info in candidate_community.items(): 700 | points.extend(self.map_community_points(community_info["report"], query)) 701 | points = sorted(points, key=lambda x: x[-1], reverse=True) 702 | return points 703 | 704 | def local_query(self, query): 705 | context = self.build_local_query_context(query) 706 | prompt = LOCAL_QUERY.format(query=query, context=context) 707 | response = self.llm.predict(prompt) 708 | return response 709 | 710 | def global_query(self, query, level=1): 711 | context = self.build_global_query_context(query, level) 712 | prompt = GLOBAL_QUERY.format(query=query, context=context) 713 | response = self.llm.predict(prompt) 714 | return response 715 | --------------------------------------------------------------------------------
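A note on the score sort in `build_global_query_context`: `map_community_points` extracts each `<score>` value with `get_text_inside_tag`, which (as its other uses in `graph.py` suggest, e.g. results being `.strip()`-ed) returns strings, so `sorted(points, key=lambda x: x[-1], reverse=True)` orders the scores lexicographically rather than numerically. The sketch below is a hedged, illustrative variant (the helper name is invented, not project code) that converts scores to floats before sorting.

```python
# Hedged sketch: sort (description, score) pairs numerically even when the
# LLM-produced score arrives as a string. Not the project's code.
def sort_points_numerically(points):
    """points: list of (description, score) pairs where score may be a string."""
    def to_float(raw):
        try:
            return float(str(raw).strip())
        except ValueError:
            return 0.0  # unparsable score -> rank last
    return sorted(points, key=lambda p: to_float(p[-1]), reverse=True)


# Example: string scores sort incorrectly; numeric conversion restores the order.
points = [("community A", "9"), ("community B", "10"), ("community C", "2")]
print(sorted(points, key=lambda x: x[-1], reverse=True))  # "A" ranks above "B" ("9" > "10" as strings)
print(sort_points_numerically(points))                    # "B" ranks above "A" (10 > 9)
```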