├── .gitattributes ├── .github └── workflows │ └── package.yml ├── .gitignore ├── LICENSE ├── NodeRAG ├── LLM │ ├── LLM.py │ ├── LLM_base.py │ ├── LLM_route.py │ ├── LLM_state.py │ └── __init__.py ├── Vis │ ├── __init__.py │ └── html │ │ ├── __init__.py │ │ ├── __main__.py │ │ └── visual_html.py ├── WebUI │ ├── __init__ .py │ ├── __main__.py │ └── app.py ├── __init__.py ├── __main__.py ├── build │ ├── Node.py │ ├── __init__.py │ ├── __main__.py │ ├── component │ │ ├── __init__.py │ │ ├── attribute.py │ │ ├── community.py │ │ ├── document.py │ │ ├── entity.py │ │ ├── relationship.py │ │ ├── semantic_unit.py │ │ ├── text_unit.py │ │ └── unit.py │ └── pipeline │ │ ├── HNSW_graph.py │ │ ├── INIT_pipeline.py │ │ ├── Insert_text.py │ │ ├── __init__.py │ │ ├── attribute_generation.py │ │ ├── document_pipeline.py │ │ ├── embedding.py │ │ ├── graph_pipeline.py │ │ ├── summary_generation.py │ │ └── text_pipeline.py ├── config │ ├── Node_config.py │ ├── Node_config.yaml │ ├── __init__.py │ └── __main__.py ├── logging │ ├── __init__.py │ ├── error.py │ ├── info_timer.py │ └── logger.py ├── search │ ├── Answer_base.py │ ├── __init__.py │ ├── __main__.py │ └── search.py ├── storage │ ├── __init__.py │ ├── genid.py │ ├── graph_mapping.py │ └── storage.py └── utils │ ├── HNSW.py │ ├── PPR.py │ ├── __init__.py │ ├── graph_operator.py │ ├── lazy_import.py │ ├── observation.py │ ├── prompt │ ├── __init__.py │ ├── answer.py │ ├── attribute_generation_prompt.py │ ├── community_summary.py │ ├── decompose.py │ ├── json_format.py │ ├── prompt_manager.py │ ├── relationship_reconstraction.py │ ├── text_decomposition.py │ └── translation.py │ ├── readable_index.py │ ├── text_spliter.py │ ├── token_utils.py │ └── yaml_operation.py ├── README.md ├── asset ├── NodeGraph_Figure2.png ├── Node_background.jpg ├── performance.png └── system_performance.png ├── pyproject.toml ├── requirements.in ├── requirements.txt └── uv.lock /.gitattributes: -------------------------------------------------------------------------------- 1 | # Auto detect text files and perform LF normalization 2 | * text=auto 3 | -------------------------------------------------------------------------------- /.github/workflows/package.yml: -------------------------------------------------------------------------------- 1 | name: Build and Publish to PyPI (Multi-Python) 2 | 3 | on: 4 | push: 5 | tags: 6 | - 'v*.*.*' # add tag to trigger the workflow 7 | 8 | jobs: 9 | build-and-publish: 10 | strategy: 11 | matrix: 12 | python-version: ["3.10", "3.11"] 13 | 14 | runs-on: ubuntu-latest 15 | 16 | steps: 17 | - uses: actions/checkout@v3 18 | 19 | - name: Set up Python ${{ matrix.python-version }} 20 | uses: actions/setup-python@v4 21 | with: 22 | python-version: ${{ matrix.python-version }} 23 | 24 | - name: Install build tools 25 | run: | 26 | python -m pip install --upgrade pip build twine 27 | 28 | - name: Build package 29 | run: | 30 | python -m build 31 | 32 | - name: Upload artifact 33 | uses: actions/upload-artifact@v4 34 | with: 35 | name: NodeRAG-dist-py${{ matrix.python-version }} 36 | path: dist/* 37 | 38 | - name: Upload to PyPI (only on Python 3.10) 39 | if: matrix.python-version == '3.10' 40 | env: 41 | TWINE_USERNAME: __token__ 42 | TWINE_PASSWORD: ${{ secrets.PYPI_API_TOKEN }} 43 | run: | 44 | twine upload dist/* 45 | 46 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__/ 2 | *.py[cod] 3 | *$py.class 4 | 
*.log 5 | .venv/ 6 | *.egg-info/ 7 | dist/ 8 | lib/ 9 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2025 Terry-Xu-666 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /NodeRAG/LLM/LLM.py: -------------------------------------------------------------------------------- 1 | import os 2 | import backoff 3 | from ..utils.lazy_import import LazyImport 4 | from json import JSONDecodeError 5 | import json 6 | 7 | 8 | from ..logging.error import ( 9 | error_handler, 10 | error_handler_async 11 | ) 12 | 13 | from ..LLM.LLM_base import ( 14 | LLM_message, 15 | ModelConfig, 16 | LLMOutput, 17 | Embedding_message, 18 | Embedding_output, 19 | LLMBase, 20 | OpenAI_message, 21 | Gemini_content 22 | ) 23 | 24 | 25 | from openai import ( 26 | RateLimitError, 27 | Timeout, 28 | APIConnectionError, 29 | ) 30 | 31 | from google.api_core.exceptions import ( 32 | ResourceExhausted, 33 | TooManyRequests, 34 | InternalServerError 35 | ) 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | OpenAI = LazyImport('openai','OpenAI') 46 | AzureOpenAI = LazyImport('openai','AzureOpenAI') 47 | AsyncOpenAI = LazyImport('openai','AsyncOpenAI') 48 | AsyncAzureOpenAI = LazyImport('openai','AsyncAzureOpenAI') 49 | genai = LazyImport("google.genai") 50 | # Together = LazyImport('together','Together') 51 | # AsyncTogether = LazyImport('together','AsyncTogether') 52 | 53 | 54 | class LLM(LLMBase): 55 | 56 | def __init__(self, 57 | model_name: str, 58 | api_keys: str | None, 59 | config: ModelConfig | None = None) -> None: 60 | 61 | super().__init__(model_name, api_keys, config) 62 | 63 | def extract_config(self, config: ModelConfig) -> ModelConfig: 64 | return config 65 | 66 | def predict(self, input: LLM_message) -> LLMOutput: 67 | response = self.API_client(input) 68 | return response 69 | 70 | 71 | 72 | async def predict_async(self, input: LLM_message) -> LLMOutput: 73 | response = await self.API_client_async(input) 74 | return response 75 | 76 | def API_client(self, input: LLM_message) -> LLMOutput: 77 | pass 78 | 79 | async def API_client_async(self, input: LLM_message) -> LLMOutput: 80 | pass 81 | 82 | 83 | class OPENAI(LLM): 84 | 85 | def __init__(self, 86 | model_name: str, 87 | api_keys: str | None, 88 | Config: ModelConfig|None=None) -> 
None: 89 | 90 | super().__init__(model_name, api_keys, Config) 91 | 92 | if self.api_keys is None: 93 | self.api_keys = os.getenv("OPENAI_API_KEY") 94 | 95 | self.client = OpenAI(api_key=self.api_keys) 96 | self.client_async = AsyncOpenAI(api_key=self.api_keys) 97 | self.config = self.extract_config(Config) 98 | 99 | 100 | def extract_config(self, config: ModelConfig) -> ModelConfig: 101 | options = { 102 | "max_tokens": config.get("max_tokens", 10000), # Default value if not provided 103 | "temperature": config.get("temperature", 0.0), # Default value if not provided 104 | } 105 | return options 106 | 107 | 108 | @backoff.on_exception(backoff.expo, 109 | [RateLimitError, Timeout, APIConnectionError,JSONDecodeError], 110 | max_time=30, 111 | max_tries=4) 112 | def _create_completion(self, messages, response_format=None): 113 | params = { 114 | "model": self.model_name, 115 | "messages": messages, 116 | **self.config 117 | } 118 | 119 | if response_format: 120 | 121 | params["response_format"] = response_format 122 | response = self.client.beta.chat.completions.parse(**params) 123 | json_response = response.choices[0].message.parsed.model_dump_json() 124 | json_response = json.loads(json_response) 125 | 126 | return json_response 127 | 128 | else: 129 | response = self.client.chat.completions.create(**params) 130 | return response.choices[0].message.content.strip() 131 | 132 | 133 | @backoff.on_exception(backoff.expo, 134 | [RateLimitError, Timeout, APIConnectionError,JSONDecodeError], 135 | max_time=30, 136 | max_tries=4) 137 | async def _create_completion_async(self, messages, response_format=None): 138 | params = { 139 | "model": self.model_name, 140 | "messages": messages, 141 | **self.config 142 | } 143 | if response_format: 144 | params["response_format"] = response_format 145 | response = await self.client_async.beta.chat.completions.parse(**params) 146 | json_response = response.choices[0].message.parsed.model_dump_json() 147 | json_response = json.loads(json_response) 148 | return json_response 149 | else: 150 | 151 | response = await self.client_async.chat.completions.create(**params) 152 | return response.choices[0].message.content.strip() 153 | 154 | 155 | @error_handler 156 | def API_client(self, input: LLM_message) -> LLMOutput: 157 | messages = self.messages(input) 158 | response = self._create_completion( 159 | messages, 160 | input.get('response_format') 161 | ) 162 | return response 163 | 164 | @error_handler_async 165 | async def API_client_async(self, input: LLM_message) -> LLMOutput: 166 | messages = self.messages(input) 167 | response = await self._create_completion_async( 168 | messages, 169 | input.get('response_format') 170 | ) 171 | 172 | return response 173 | 174 | def stream_chat(self,input:LLM_message): 175 | messages = self.messages(input) 176 | response = self.client.chat.completions.create( 177 | model=self.model_name, 178 | messages=messages, 179 | stream=True 180 | ) 181 | for chunk in response: 182 | if chunk.choices[0].delta.content is not None: 183 | yield chunk.choices[0].delta.content 184 | 185 | def messages(self, input: LLM_message) -> OpenAI_message: 186 | 187 | messages = [] 188 | if input.get("system_prompt"): 189 | messages.append({ 190 | "role": "system", 191 | "content": input["system_prompt"] 192 | }) 193 | content =[{"type": "text","text": input["query"]}] 194 | 195 | messages.append({"role": "user","content": content}) 196 | 197 | return messages 198 | 199 | 200 | class OpenAI_Embedding(LLM): 201 | 202 | def __init__(self, 203 | model_name: 
str, 204 | api_keys: str | None, 205 | Config: ModelConfig|None) -> None: 206 | 207 | super().__init__(model_name, api_keys,Config) 208 | 209 | if api_keys is None: 210 | api_keys = os.getenv("OPENAI_API_KEY") 211 | self.client = OpenAI(api_key=api_keys) 212 | self.client_async = AsyncOpenAI(api_key=api_keys) 213 | 214 | @backoff.on_exception(backoff.expo, 215 | [RateLimitError, Timeout, APIConnectionError], 216 | max_time=30, 217 | max_tries=4) 218 | def _create_embedding(self, input: Embedding_message) -> Embedding_output: 219 | response = self.client.embeddings.create( 220 | model=self.model_name, 221 | input=input 222 | ) 223 | return [res.embedding for res in response.data] 224 | 225 | @error_handler 226 | def API_client(self, input: Embedding_message) -> Embedding_output: 227 | response = self._create_embedding(input) 228 | 229 | return response 230 | 231 | @backoff.on_exception(backoff.expo, 232 | [RateLimitError, Timeout, APIConnectionError], 233 | max_time=30, 234 | max_tries=4) 235 | async def _create_embedding_async(self, input: Embedding_message) -> Embedding_output: 236 | response = await self.client_async.embeddings.create( 237 | model=self.model_name, 238 | input=input 239 | ) 240 | return [res.embedding for res in response.data] 241 | 242 | @error_handler_async 243 | async def API_client_async(self, input: Embedding_message) -> Embedding_output: 244 | response = await self._create_embedding_async(input) 245 | return response 246 | 247 | 248 | 249 | 250 | 251 | 252 | 253 | class Gemini(LLM): 254 | 255 | def __init__(self, 256 | model_name: str, 257 | api_keys: str | None, 258 | Config: ModelConfig|None) -> None: 259 | 260 | super().__init__(model_name, api_keys) 261 | if api_keys is None: 262 | api_keys = os.getenv('GOOGLE_API_KEY') 263 | 264 | 265 | 266 | self.client = genai.Client(api_key=api_keys) 267 | self.config = self.extract_config(Config) 268 | 269 | 270 | 271 | def extract_config(self, config: ModelConfig) -> ModelConfig: 272 | options = { 273 | "max_tokens": config.get("max_tokens", 10000), # Default value if not provided 274 | "temperature": config.get("temperature", 0.0), # Default value if not provided 275 | } 276 | return options 277 | 278 | @backoff.on_exception(backoff.expo, 279 | [ResourceExhausted,TooManyRequests, InternalServerError,JSONDecodeError], 280 | max_time=30, 281 | max_tries=4) 282 | def _create_completion(self, messages, response_format=None): 283 | 284 | 285 | params = { 286 | "model": self.model_name, 287 | "contents": messages, 288 | } 289 | if response_format: 290 | 291 | config = genai.types.GenerateContentConfig( 292 | temperature=self.config.get("temperature", 0.0), 293 | max_output_tokens=self.config.get("max_tokens", 10000), 294 | response_mime_type="application/json", 295 | response_schema=response_format, 296 | ) 297 | response = self.client.models.generate_content(**params,config=config) 298 | json_response = response.text 299 | json_response = json.loads(json_response) 300 | return json_response 301 | else: 302 | 303 | 304 | config = genai.types.GenerateContentConfig( 305 | temperature=self.config.get("temperature", 0.0), 306 | max_output_tokens=self.config.get("max_tokens", 10000), 307 | ) 308 | response = self.client.models.generate_content(**params,config=config) 309 | return response.text 310 | 311 | @backoff.on_exception(backoff.expo, 312 | (ResourceExhausted,TooManyRequests,InternalServerError,JSONDecodeError), 313 | max_time=30, 314 | max_tries=4) 315 | async def _create_completion_async(self, messages, 
response_format=None): 316 | 317 | params = { 318 | "model": self.model_name, 319 | "contents": messages, 320 | } 321 | if response_format: 322 | config = genai.types.GenerateContentConfig( 323 | temperature=self.config.get("temperature", 0.0), 324 | max_output_tokens=self.config.get("max_tokens", 10000), 325 | response_mime_type="application/json", 326 | response_schema=response_format, 327 | ) 328 | response = await self.client.aio.models.generate_content(**params,config=config) 329 | json_response = response.text 330 | json_response = json.loads(json_response) 331 | return json_response 332 | 333 | 334 | else: 335 | config = genai.types.GenerateContentConfig( 336 | temperature=self.config.get("temperature", 0.0), 337 | max_output_tokens=self.config.get("max_tokens", 10000), 338 | ) 339 | response = await self.client.aio.models.generate_content(**params,config=config) 340 | return response.text 341 | 342 | 343 | 344 | @error_handler 345 | def API_client(self, input: LLM_message) -> LLMOutput: 346 | 347 | 348 | messages = self.messages(input) 349 | response = self._create_completion( 350 | messages, 351 | input.get('response_format') 352 | ) 353 | 354 | 355 | 356 | return response 357 | 358 | 359 | @error_handler_async 360 | async def API_client_async(self, input: LLM_message) -> LLMOutput: 361 | 362 | 363 | 364 | messages = self.messages(input) 365 | 366 | response = await self._create_completion_async( 367 | messages, 368 | input.get('response_format') 369 | ) 370 | 371 | return response 372 | 373 | def messages(self, input: LLM_message) -> Gemini_content: 374 | 375 | query = '' 376 | if input.get("system_prompt"): 377 | query = 'system_prompt:\n'+input["system_prompt"] 378 | query = query + '\nquery:\n'+input["query"] 379 | content = [query] 380 | return content 381 | 382 | def stream_chat(self,input:LLM_message): 383 | messages = self.messages(input) 384 | for chunk in self.client.models.generate_content_stream(model=self.model_name, contents=messages): 385 | yield chunk.text 386 | 387 | 388 | class Gemini_Embedding(LLM): 389 | 390 | def __init__(self, 391 | model_name: str, 392 | api_keys: str | None, 393 | Config: ModelConfig|None) -> None: 394 | super().__init__(model_name, api_keys,Config) 395 | if api_keys is None: 396 | api_keys = os.getenv('GOOGLE_API_KEY') 397 | self.client = genai.Client(api_key=api_keys) 398 | 399 | @backoff.on_exception(backoff.expo, 400 | [ResourceExhausted, TooManyRequests, InternalServerError], 401 | max_time=30, 402 | max_tries=4) 403 | def _create_embedding(self, input: Embedding_message) -> Embedding_output: 404 | response = self.client.models.embed_content( 405 | model=self.model_name, 406 | contents=input 407 | ) 408 | return [res.values for res in response.embeddings] 409 | 410 | @error_handler 411 | def API_client(self, input: Embedding_message) -> Embedding_output: 412 | 413 | response = self._create_embedding(input) 414 | return response 415 | 416 | @backoff.on_exception(backoff.expo, 417 | [ResourceExhausted, TooManyRequests, InternalServerError], 418 | max_time=30, 419 | max_tries=4) 420 | async def _create_embedding_async(self, input: Embedding_message) -> Embedding_output: 421 | response = await self.client.aio.models.embed_content( 422 | model=self.model_name, 423 | contents=input 424 | ) 425 | return [res.values for res in response.embeddings] 426 | 427 | @error_handler_async 428 | async def API_client_async(self, input: Embedding_message) -> Embedding_output: 429 | response = await self._create_embedding_async(input) 430 | return response 
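
A hedged usage sketch (not part of LLM.py; the model names and config values below are illustrative assumptions): each client above is built from a model name, an optional API key that falls back to the OPENAI_API_KEY or GOOGLE_API_KEY environment variable, and a config dict from which only max_tokens and temperature are read.

from NodeRAG.LLM.LLM import OPENAI, OpenAI_Embedding

# Chat completion: with no response_format, predict returns the stripped text content.
llm = OPENAI("gpt-4o-mini", api_keys=None,  # None -> falls back to OPENAI_API_KEY
             Config={"max_tokens": 2000, "temperature": 0.0})
answer = llm.predict({"system_prompt": "You are concise.",
                      "query": "What is a heterogeneous graph?"})

# Embeddings: the message is passed straight through to embeddings.create,
# so a list of strings yields one vector per string.
emb = OpenAI_Embedding("text-embedding-3-small", api_keys=None, Config={})
vectors = emb.predict(["NodeRAG builds a heterogeneous graph."])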
431 | 432 | 433 | -------------------------------------------------------------------------------- /NodeRAG/LLM/LLM_base.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, List, TypeAlias, Union,TypeVar,Generic 2 | from typing_extensions import NotRequired, TypedDict 3 | from pydantic import BaseModel 4 | from abc import ABC, abstractmethod 5 | 6 | # Type Aliases for improved readability and maintainability 7 | LLMOutput: TypeAlias = str 8 | LLMQuery: TypeAlias = str 9 | LLMPrompt: TypeAlias = str 10 | ModelConfig: TypeAlias = Dict[str, Any] 11 | JSONSchema: TypeAlias = BaseModel 12 | EmbeddingInput: TypeAlias = str 13 | EmbeddingOutput: TypeAlias = float 14 | OpenAI_message: TypeAlias = List[dict] 15 | Gemini_content: TypeAlias = List[str] 16 | 17 | 18 | 19 | class LLM_message(TypedDict): 20 | """TypedDict for LLM input parameters. 21 | 22 | Attributes: 23 | system_prompt (str, optional): System-level instructions for the LLM. 24 | query (str): The main prompt or question for the LLM. 25 | response_format (JSONSchema, optional): Expected response format schema. 26 | """ 27 | system_prompt: NotRequired[LLMPrompt] 28 | query: LLMQuery 29 | response_format: NotRequired[JSONSchema] 30 | 31 | class Embedding_message(TypedDict): 32 | """TypedDict for Embedding input parameters. 33 | 34 | Attributes: 35 | input (str): The input text for embedding. 36 | """ 37 | input: Union[List[EmbeddingInput], EmbeddingInput] 38 | 39 | class Embedding_output(TypedDict): 40 | """TypedDict for Embedding output parameters. 41 | 42 | Attributes: 43 | embedding (List[float]): The embedding output. 44 | """ 45 | embedding: List[EmbeddingOutput] 46 | 47 | I = TypeVar('I', bound=Union[LLM_message, Embedding_message]) 48 | O = TypeVar('O', bound=Union[LLMOutput, Embedding_output]) 49 | 50 | class LLMBase(ABC,Generic[I,O]): 51 | def __init__(self, 52 | model_name: str, 53 | api_keys: str | None, 54 | config: ModelConfig | None = None) -> None: 55 | """ 56 | Initializes the LLMBase instance with the specified model name, API keys, and configuration. 57 | 58 | Args: 59 | model_name (str): The name of the model to be used. 60 | api_keys (str | None): The API keys for authentication, if applicable. 61 | config (ModelConfig | None): Optional configuration settings for the model. 62 | """ 63 | self.model_name = model_name 64 | self.api_keys = api_keys 65 | self.config = config 66 | 67 | @abstractmethod 68 | def extract_config(self, config: ModelConfig) -> ModelConfig: 69 | """ 70 | Abstract method to extract the configuration from the provided config. 71 | """ 72 | pass 73 | 74 | @abstractmethod 75 | def predict(self, input: I) -> O: 76 | """ 77 | Abstract method to predict the output based on the provided input. 78 | 79 | Args: 80 | input (LLM_message): The input message for the prediction. 81 | 82 | Returns: 83 | LLMOutput: The predicted output. 84 | """ 85 | pass 86 | 87 | @abstractmethod 88 | async def predict_async(self, input: I) -> O: 89 | """ 90 | Abstract method to asynchronously predict the output based on the provided input. 91 | 92 | Args: 93 | input (LLM_message): The input message for the prediction. 94 | 95 | Returns: 96 | LLMOutput: The predicted output. 97 | """ 98 | pass 99 | 100 | @abstractmethod 101 | def API_client(self, input: I) -> O: 102 | """ 103 | Abstract method to set the API client. 
104 | """ 105 | pass 106 | 107 | @abstractmethod 108 | async def API_client_async(self, input: I) -> O: 109 | """ 110 | Abstract method to set the API client. 111 | """ 112 | pass 113 | 114 | 115 | -------------------------------------------------------------------------------- /NodeRAG/LLM/LLM_route.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | from .LLM_base import I,O,Dict 3 | from .LLM import * 4 | 5 | from ..logging.error import ( 6 | cache_error, 7 | cache_error_async 8 | ) 9 | 10 | 11 | 12 | def LLM_route(config : ModelConfig) -> LLM: 13 | 14 | '''Route the request to the appropriate LLM service provider''' 15 | 16 | 17 | 18 | service_provider = config.get("service_provider") 19 | model_name = config.get("model_name") 20 | embedding_model_name = config.get("embedding_model_name",None) 21 | api_keys = config.get("api_keys",None) 22 | 23 | match service_provider: 24 | case "openai": 25 | return OPENAI(model_name, api_keys, config) 26 | case "openai_embedding": 27 | return OpenAI_Embedding(embedding_model_name, api_keys, config) 28 | case "gemini": 29 | return Gemini(model_name, api_keys, config) 30 | case "gemini_embedding": 31 | return Gemini_Embedding(embedding_model_name, api_keys, config) 32 | case _: 33 | raise ValueError("Service provider not supported") 34 | 35 | 36 | 37 | class API_client(): 38 | 39 | def __init__(self, 40 | config : ModelConfig) -> None: 41 | 42 | self.llm = LLM_route(config) 43 | self.rate_limit = config.get("rate_limit",50) 44 | self.semaphore = asyncio.Semaphore(self.rate_limit) 45 | 46 | 47 | 48 | 49 | @cache_error_async 50 | async def __call__(self, input: I, *,cache_path:str|None=None,meta_data:Dict|None=None) -> O: 51 | 52 | async with self.semaphore: 53 | response = await self.llm.predict_async(input) 54 | 55 | 56 | return response 57 | 58 | @cache_error 59 | def request(self, input:I, *,cache_path:str|None=None,meta_data:Dict|None=None) -> O: 60 | 61 | 62 | response = self.llm.predict(input) 63 | 64 | 65 | return response 66 | 67 | def stream_chat(self,input:I): 68 | yield from self.llm.stream_chat(input) 69 | 70 | 71 | -------------------------------------------------------------------------------- /NodeRAG/LLM/LLM_state.py: -------------------------------------------------------------------------------- 1 | from .LLM_base import LLMBase 2 | 3 | 4 | api_client = None 5 | embedding_client = None 6 | 7 | def set_api_client(client:LLMBase|None): 8 | if client is None: 9 | raise ValueError("Please provide a valid API client information") 10 | global api_client 11 | api_client = client 12 | return api_client 13 | 14 | def get_api_client(): 15 | return api_client 16 | 17 | def set_embedding_client(client:LLMBase|None): 18 | if client is None: 19 | raise ValueError("Please provide a valid API client information") 20 | global embedding_client 21 | embedding_client = client 22 | return embedding_client 23 | 24 | def get_embedding_client(): 25 | return embedding_client -------------------------------------------------------------------------------- /NodeRAG/LLM/__init__.py: -------------------------------------------------------------------------------- 1 | from .LLM_base import ( 2 | LLM_message, 3 | LLMBase, 4 | LLMOutput, 5 | LLMQuery, 6 | LLMPrompt, 7 | ModelConfig, 8 | JSONSchema, 9 | EmbeddingInput, 10 | EmbeddingOutput, 11 | OpenAI_message, 12 | Embedding_message, 13 | Embedding_output, 14 | I, 15 | O 16 | ) 17 | 18 | from .LLM_route import API_client 19 | 20 | from .LLM_state import ( 21 | 
get_api_client, 22 | get_embedding_client, 23 | set_api_client, 24 | set_embedding_client 25 | ) 26 | 27 | __all__ = [ 28 | 'LLM_message', 29 | 'LLMBase', 30 | 'LLMOutput', 31 | 'LLMQuery', 32 | 'LLMPrompt', 33 | 'ModelConfig', 34 | 'JSONSchema', 35 | 'EmbeddingInput', 36 | 'EmbeddingOutput', 37 | 'OpenAI_message', 38 | 'Embedding_message', 39 | 'Embedding_output', 40 | 'I', 41 | 'O', 42 | 'API_client', 43 | 'get_api_client', 44 | 'get_embedding_client', 45 | 'set_api_client', 46 | 'set_embedding_client' 47 | ] 48 | 49 | -------------------------------------------------------------------------------- /NodeRAG/Vis/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Terry-Xu-666/NodeRAG/f77dd6adb34cf4dda1d88b30b2bf0b17d14480a9/NodeRAG/Vis/__init__.py -------------------------------------------------------------------------------- /NodeRAG/Vis/html/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Terry-Xu-666/NodeRAG/f77dd6adb34cf4dda1d88b30b2bf0b17d14480a9/NodeRAG/Vis/html/__init__.py -------------------------------------------------------------------------------- /NodeRAG/Vis/html/__main__.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from .visual_html import visualize 3 | from rich import console 4 | 5 | console = console.Console() 6 | 7 | parser = argparse.ArgumentParser() 8 | parser.add_argument('-f', "--main_folder", type=str, required=True) 9 | parser.add_argument('-n', "--nodes_num", type=int, default=500,help="nodes number") 10 | args = parser.parse_args() 11 | 12 | console.print(f"Visualizing {args.main_folder} with nodes number {args.nodes_num}") 13 | visualize(args.main_folder, args.nodes_num) 14 | 15 | -------------------------------------------------------------------------------- /NodeRAG/Vis/html/visual_html.py: -------------------------------------------------------------------------------- 1 | from pyvis.network import Network 2 | import pickle 3 | from NodeRAG.storage.graph_mapping import Mapper 4 | from NodeRAG.utils.PPR import sparse_PPR 5 | import os 6 | import math 7 | from tqdm import tqdm 8 | from rich.console import Console 9 | from rich.text import Text 10 | import networkx as nx 11 | console = Console() 12 | 13 | def load_graph(cache_folder): 14 | with open(os.path.join(cache_folder, 'graph.pkl'), 'rb') as f: 15 | return pickle.load(f) 16 | 17 | def initialize_mapper(cache_folder, storage): 18 | return Mapper([os.path.join(cache_folder, s) for s in storage]) 19 | 20 | def create_network(): 21 | return Network(height='100vh', width='100vw', bgcolor='#222222', font_color='white') 22 | 23 | def filter_nodes(graph,nodes_num=2000): 24 | 25 | page_rank = sparse_PPR(graph).PR() 26 | nodes = [node for node,score in page_rank[:nodes_num]] 27 | subgraph = graph.subgraph(nodes).copy() 28 | if not nx.is_connected(subgraph): 29 | console.print(f"subgraph is not connected") 30 | additional_nodes = set() 31 | for i in range(len(nodes)): 32 | for j in range(i+1,len(nodes)): 33 | if not nx.has_path(subgraph,nodes[i],nodes[j]): 34 | path_length,path_nodes = nx.bidirectional_dijkstra(graph,nodes[i],nodes[j]) 35 | additional_nodes.update(set(path_nodes)) 36 | final_nodes = set(nodes) | additional_nodes 37 | subgraph = graph.subgraph(final_nodes).copy() 38 | 39 | console.print(f"final nodes: {len(final_nodes)}") 40 | weighted_nodes = {node:score for node,score in 
page_rank}
41 |     return subgraph,weighted_nodes
42 | 
43 | 
44 | 
45 | 
46 | 
47 | def add_nodes_to_network(net, subgraph, mapper,weighted_nodes):
48 |     for node in tqdm(subgraph.nodes, total=len(subgraph.nodes)):
49 |         node_dict = subgraph.nodes[node]
50 |         node_type = node_dict['type']
51 |         color = get_node_color(node_type)
52 |         net.add_node(node, label=node_type, title=mapper.get(node, 'context'), color=color, size=20 * weighted_nodes[node] + 20)
53 | 
54 | def get_node_color(node_type):
55 |     match node_type:
56 |         case 'entity':
57 |             return '#ADD8E6'
58 |         case 'attribute':
59 |             return '#FFD700'
60 |         case 'relationship':
61 |             return '#FF7F50'
62 |         case 'high_level_element':
63 |             return '#98FB98'
64 |         case 'semantic_unit':
65 |             return '#D8BFD8'
66 | 
67 | def add_edges_to_network(net, subgraph):
68 |     for edge in tqdm(subgraph.edges, total=len(subgraph.edges)):
69 |         net.add_edge(edge[0], edge[1])
70 | 
71 | def set_network_options(net):
72 |     net.set_options("""
73 |     var options = {
74 |         "nodes": {
75 |             "hover": true,
76 |             "title": "Node Information",
77 |             "label": {
78 |                 "enabled": true
79 |             }
80 |         },
81 |         "edges": {
82 |             "hover": true,
83 |             "title": "Edge Information"
84 |         },
85 |         "physics": {
86 |             "forceAtlas2Based": {
87 |                 "springLength": 1
88 |             },
89 |             "minVelocity": 0.1,
90 |             "solver": "forceAtlas2Based",
91 |             "timestep": 0.1,
92 |             "stabilization": {
93 |                 "enabled": true
94 |             }
95 |         }
96 |     }
97 |     """)
98 | 
99 | def visualize(main_folder,nodes_num=2000):
100 |     cache_folder = os.path.join(main_folder, 'cache')
101 |     graph = load_graph(cache_folder)
102 | 
103 |     storage = ['attributes.parquet', 'entities.parquet', 'relationship.parquet', 'high_level_elements.parquet', 'semantic_units.parquet','text.parquet','high_level_elements_titles.parquet']
104 |     mapper = initialize_mapper(cache_folder, storage)
105 | 
106 |     net = create_network()
107 |     subgraph,weighted_nodes = filter_nodes(graph,nodes_num)
108 | 
109 |     add_nodes_to_network(net, subgraph, mapper,weighted_nodes)
110 |     add_edges_to_network(net, subgraph)
111 | 
112 |     set_network_options(net)
113 | 
114 |     console.print(Text(f"edges_count: {len(subgraph.edges)}", style="bold green"))
115 |     console.print(Text(f"nodes_count: {len(subgraph.nodes)}", style="bold green"))
116 | 
117 |     net.show(os.path.join(main_folder, 'index.html'), notebook=False)
118 | 
119 | 
120 | 
--------------------------------------------------------------------------------
/NodeRAG/WebUI/__init__ .py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Terry-Xu-666/NodeRAG/f77dd6adb34cf4dda1d88b30b2bf0b17d14480a9/NodeRAG/WebUI/__init__ .py
--------------------------------------------------------------------------------
/NodeRAG/WebUI/__main__.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import sys
4 | import streamlit.web.cli as stcli
5 | 
6 | def main():
7 |     parser = argparse.ArgumentParser()
8 |     current_dir = os.path.dirname(os.path.abspath(__file__))
9 | 
10 |     parser.add_argument("-f", "--main_folder", type=str, help="main folder path")
11 |     args = parser.parse_args()
12 | 
13 | 
14 |     app_path = os.path.join(current_dir, "app.py")
15 | 
16 |     # Use "--" to forward the argument through to the Streamlit app
17 |     sys.argv = [
18 |         "streamlit",
19 |         "run",
20 |         app_path,
21 |         "--",
22 |         f"--main_folder={args.main_folder}"
23 |     ]
24 | 
25 |     sys.exit(stcli.main())
26 | 
27 | if __name__ == "__main__":
28 |     main()
--------------------------------------------------------------------------------
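A hedged invocation sketch (the ./my_project path is illustrative): visualize expects a built project whose cache/ folder already holds graph.pkl and the parquet stores listed above, and it writes the rendered network to index.html inside the main folder.

from NodeRAG.Vis.html.visual_html import visualize

visualize('./my_project', nodes_num=300)  # same effect as: python -m NodeRAG.Vis.html -f ./my_project -n 300

The web UI follows the same pattern from the command line: python -m NodeRAG.WebUI -f ./my_project re-invokes Streamlit on app.py, forwarding --main_folder after the "--" separator as shown in __main__.py above.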
/NodeRAG/__init__.py: -------------------------------------------------------------------------------- 1 | from .build.Node import NodeRag 2 | from .config import NodeConfig 3 | from .search import NodeSearch 4 | 5 | 6 | 7 | __all__ = ['NodeRag','NodeConfig','NodeSearch'] 8 | 9 | -------------------------------------------------------------------------------- /NodeRAG/__main__.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import requests 3 | import yaml 4 | 5 | parser = argparse.ArgumentParser(description='TGRAG search') 6 | parser.add_argument('-q','--question', type=str, help='The question to ask the search engine') 7 | parser.add_argument('-f','--folder', type=str, help='The main folder of the project') 8 | parser.add_argument('-r','--retrieval', action='store_true', help='Whether to return the retrieval') 9 | parser.add_argument('-a','--answer', action='store_true', help='Whether to return the answer') 10 | 11 | args = parser.parse_args() 12 | 13 | data = {'question':args.question} 14 | 15 | with open(args.folder+'/Node_config.yaml', 'r') as f: 16 | args.config = yaml.safe_load(f) 17 | config = args.config['config'] 18 | 19 | url = config.get('url','127.0.0.1') 20 | port = config.get('port',5000) 21 | 22 | url = f'http://{url}:{port}' 23 | 24 | if not args.answer and not args.retrieval: 25 | response = requests.post(url+'/answer', json=data) 26 | print(response.json()['answer']) 27 | 28 | elif args.answer and not args.retrieval: 29 | response = requests.post(url+'/answer', json=data) 30 | print(response.json()['answer']) 31 | 32 | elif not args.answer and args.retrieval: 33 | response = requests.post(url+'/retrieval', json=data) 34 | print(response.json()['retrieval']) 35 | 36 | else: 37 | response = requests.post(url+'/answer_retrieval', json=data) 38 | print({'answer':response.json()['answer'], 'retrieval':response.json()['retrieval']}) 39 | 40 | 41 | 42 | -------------------------------------------------------------------------------- /NodeRAG/build/Node.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | import os 3 | import json 4 | from rich.tree import Tree 5 | import asyncio 6 | import sys 7 | 8 | from ..config import NodeConfig 9 | 10 | from .pipeline import ( 11 | INIT_pipeline, 12 | document_pipline, 13 | text_pipline, 14 | Graph_pipeline, 15 | Attribution_generation_pipeline, 16 | Embedding_pipeline, 17 | SummaryGeneration, 18 | Insert_text, 19 | HNSW_pipeline 20 | ) 21 | 22 | 23 | class State(Enum): 24 | INIT = "INIT" 25 | DOCUMENT_PIPELINE = "Document pipeline" 26 | TEXT_PIPELINE = "Text pipeline" 27 | GRAPH_PIPELINE = "Graph pipeline" 28 | ATTRIBUTE_PIPELINE = "Attribute pipeline" 29 | EMBEDDING_PIPELINE = "Embedding pipeline" 30 | SUMMARY_PIPELINE = "Summary pipeline" 31 | INSERT_TEXT = "Insert text pipeline" 32 | HNSW_PIPELINE = "HNSW pipeline" 33 | FINISHED = "FINISHED" 34 | ERROR = "ERROR" 35 | ERROR_LOG = "ERROR_LOG" 36 | ERROR_CACHE = "ERROR_CACHE" 37 | NO_ERROR = "NO_ERROR" 38 | 39 | class NodeRag(): 40 | def __init__(self,config:NodeConfig,web_ui:bool=False): 41 | self._Current_state=State.INIT 42 | self.Error_type=State.NO_ERROR 43 | self.Is_incremental=False 44 | self.config=config 45 | self.console = self.config.console 46 | self.config.config_integrity() 47 | self._documents = None 48 | self._hash_ids = None 49 | self.observers = [] 50 | self.web_ui = web_ui 51 | 52 | 53 | 54 | # define the state to pipeline mapping 55 | 
self.state_pipeline_map = { 56 | State.DOCUMENT_PIPELINE: document_pipline, 57 | State.TEXT_PIPELINE: text_pipline, 58 | State.GRAPH_PIPELINE: Graph_pipeline, 59 | State.ATTRIBUTE_PIPELINE: Attribution_generation_pipeline, 60 | State.EMBEDDING_PIPELINE: Embedding_pipeline, 61 | State.SUMMARY_PIPELINE: SummaryGeneration, 62 | State.INSERT_TEXT: Insert_text, 63 | State.HNSW_PIPELINE: HNSW_pipeline 64 | } 65 | 66 | # define the state sequence 67 | self.state_sequence = [ 68 | State.INIT, 69 | State.DOCUMENT_PIPELINE, 70 | State.TEXT_PIPELINE, 71 | State.GRAPH_PIPELINE, 72 | State.ATTRIBUTE_PIPELINE, 73 | State.EMBEDDING_PIPELINE, 74 | State.SUMMARY_PIPELINE, 75 | State.INSERT_TEXT, 76 | State.HNSW_PIPELINE, 77 | State.FINISHED 78 | ] 79 | 80 | @property 81 | def state_dict(self): 82 | return {'Current_state':self.Current_state.value, 83 | 'Error_type':self.Error_type.value, 84 | 'Is_incremental':self.Is_incremental} 85 | 86 | @property 87 | def Current_state(self): 88 | return self._Current_state 89 | 90 | @Current_state.setter 91 | def Current_state(self,state:State): 92 | self._Current_state = state 93 | self.notify_state_change() 94 | 95 | def notify_state_change(self): 96 | 97 | for observer in self.observers: 98 | observer.update(self.Current_state.value) 99 | 100 | def add_observer(self,observer): 101 | 102 | self.observers.append(observer) 103 | 104 | 105 | def set_state(self,state:State): 106 | 107 | self.Current_state = state 108 | 109 | def get_state(self): 110 | 111 | return self.Current_state 112 | 113 | async def state_transition(self): 114 | 115 | 116 | try: 117 | while True: 118 | self.update_state_tree() 119 | index = self.state_sequence.index(self.Current_state) 120 | if self.Current_state != State.FINISHED: 121 | self.Current_state = self.state_sequence[index+1] 122 | 123 | if self.Current_state == State.FINISHED: 124 | if self.Is_incremental: 125 | if self.web_ui: 126 | self.console.print("[bold green]Detected incremental file, Continue building.[/bold green]") 127 | self.Current_state = State.DOCUMENT_PIPELINE 128 | self.Is_incremental = False 129 | else: 130 | user_input = self.console.input("[bold green]Detected incremental file, Please enter 'y' to continue. Any other input will cancel the pipeline.[/bold green]") 131 | if user_input.lower() == 'y': 132 | self.console.print("[bold green]Pipeline finished. No incremental mode.[/bold green]") 133 | self.Current_state = State.DOCUMENT_PIPELINE 134 | self.Is_incremental = False 135 | else: 136 | self.console.print("[bold red]Pipeline cancelled by user.[/bold red]") 137 | sys.exit() 138 | 139 | else: 140 | self.console.print("[bold green]Pipeline finished. 
No incremental mode.[/bold green]")
141 |                         self.store_state()
142 |                         self.config.whole_time()
143 |                         return
144 | 
145 |                 self.config.console.print(f"[bold green]Processing {self.Current_state.value} pipeline...[/bold green]")
146 |                 await self.state_pipeline_map[self.Current_state](self.config).main()
147 |                 self.config.console.print(f"[bold green]Processing {self.Current_state.value} pipeline finished.[/bold green]")
148 | 
149 |         except Exception as e:
150 |             error_message = str(e)
151 |             if 'Error cached' in error_message:
152 |                 self.Error_type = State.ERROR_CACHE
153 |             elif 'error log' in error_message:
154 |                 self.Error_type = State.ERROR_LOG
155 |             else:
156 |                 self.Error_type = State.ERROR
157 |             self.store_state()
158 |             raise Exception(f'Error happened in {self.Current_state.value}.{e}')
159 |         except KeyboardInterrupt:
160 |             self.store_state()
161 |             self.config.console.print("\n[bold red]Pipeline interrupted by user.[/bold red]")
162 |             sys.exit(0)
163 | 
164 |     def load_state(self):
165 | 
166 |         if os.path.exists(self.config.state_path):
167 |             state_dict = json.load(open(self.config.state_path,'r'))
168 |             self.Current_state = State(state_dict['Current_state'])
169 |             self.Error_type = State(state_dict['Error_type'])
170 |             self.Is_incremental = state_dict['Is_incremental']
171 | 
172 | 
173 |     def store_state(self):
174 | 
175 |         json.dump(self.state_dict,open(self.config.state_path,'w'))
176 | 
177 | 
178 | 
179 |     def display_state_tree(self):
180 | 
181 |         tree = Tree("[bold cyan]🌳 State Tree[/bold cyan]")
182 | 
183 |         for state in self.state_sequence:
184 |             tree.add(f"{state.value}", style="bright_green")
185 | 
186 |         self.console.print(tree)
187 |         self.console.print(f"[bold green]Current working directory: {self.config.main_folder}[/bold green]")
188 | 
189 |         if self.web_ui:
190 |             return
191 | 
192 |         while True:
193 |             user_input = self.console.input("[bold blue]\nDo you want to start the pipeline? (y/n)[/bold blue] ")
194 |             if user_input.lower() == 'y':
195 |                 self.console.clear()
196 |                 break
197 |             elif user_input.lower() == 'n':
198 |                 self.console.print("[bold red]Pipeline cancelled by user.[/bold red]")
199 |                 sys.exit()
200 |             else:
201 |                 self.console.input("[bold red]Invalid input. Please enter 'y' or 'n'.[/bold red]")
202 | 
203 |     def update_state_tree(self):
204 | 
205 |         self.console.clear()
206 |         tree = Tree("[bold cyan]🚀 Processing Pipeline[/bold cyan]")
207 |         index = self.state_sequence.index(self.Current_state)
208 | 
209 |         for i in range(index+1):
210 |             tree.add(f"[green]{self.state_sequence[i].value} Done[/green]")
211 | 
212 |         self.console.print(tree)
213 | 
214 |     async def error_handler(self):
215 | 
216 |         self.update_state_tree()
217 | 
218 |         if self.Error_type == State.ERROR_LOG or self.Error_type == State.ERROR:
219 |             self.console.print("[bold red]Error logged. Rerun the pipeline from the current state.[/bold red]")
220 | 
221 |             try:
222 |                 await self.state_pipeline_map[self.Current_state](self.config).main()
223 | 
224 |             except Exception as e:
225 |                 self.Error_type = State.ERROR
226 |                 self.store_state()
227 |                 raise Exception(f'Error happened in {self.Current_state.value} pipeline, please check the error log. {e}')
228 | 
229 |         if self.Error_type == State.ERROR_CACHE:
230 | 
231 |             self.console.print("[bold red]Error cached. Rerun the pipeline from the current state.[/bold red]")
232 | 
233 |             try:
234 |                 await self.state_pipeline_map[self.Current_state](self.config).rerun()
235 | 
236 |             except Exception as e:
237 |                 self.Error_type = State.ERROR_CACHE
238 |                 self.store_state()
239 |                 raise Exception(f'Error happened in {self.Current_state.value} pipeline, please check the error log. {e}')
240 | 
241 |         self.Error_type = State.NO_ERROR
242 | 
243 |     async def _run_async(self):
244 | 
245 |         self.load_state()
246 | 
247 |         self.Is_incremental = await INIT_pipeline(self.config).main()
248 | 
249 |         if self.Current_state == State.INIT:
250 |             self.display_state_tree()
251 | 
252 |         if self.Error_type != State.NO_ERROR:
253 |             await self.error_handler()
254 | 
255 |         if self.Error_type == State.NO_ERROR:
256 |             await self.state_transition()
257 | 
258 |     def run(self):
259 |         asyncio.run(self._run_async())
260 | 
261 | 
262 | 
--------------------------------------------------------------------------------
/NodeRAG/build/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Terry-Xu-666/NodeRAG/f77dd6adb34cf4dda1d88b30b2bf0b17d14480a9/NodeRAG/build/__init__.py
--------------------------------------------------------------------------------
/NodeRAG/build/__main__.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import yaml
3 | import os
4 | 
5 | from .Node import NodeRag,NodeConfig
6 | 
7 | parser = argparse.ArgumentParser(description='TGRAG Build')
8 | parser.add_argument('-f','--folder_path', type=str, help='The folder path of the document')
9 | 
10 | 
11 | args = parser.parse_args()
12 | 
13 | config_path = os.path.join(args.folder_path,'Node_config.yaml')
14 | if not os.path.exists(config_path):
15 |     config = NodeConfig.from_main_folder(args.folder_path)
16 |     print("Please modify the config file and run the command again")
17 |     exit(0)
18 | else:
19 |     with open(config_path, 'r') as f:
20 |         config = yaml.safe_load(f)
21 |     config = NodeConfig(config)
22 | 
23 | 
24 | 
25 | 
26 | 
27 | ng = NodeRag(config)
28 | ng.run()
29 | 
30 | 
31 | 
32 | 
33 | 
34 | 
35 | 
36 | 
37 | 
38 | 
--------------------------------------------------------------------------------
/NodeRAG/build/component/__init__.py:
--------------------------------------------------------------------------------
1 | from .semantic_unit import Semantic_unit,semantic_unit_index_counter
2 | from .entity import Entity,entity_index_counter
3 | from .relationship import Relationship,relation_index_counter
4 | from .attribute import Attribute,attribute_index_counter
5 | from .document import document,document_index_counter
6 | from .text_unit import Text_unit,text_unit_index_counter
7 | 
8 | 
9 | 
10 | from .community import (
11 |     Community_summary,
12 |     high_level_element_index_counter,
13 |     community_summary_index_counter,
14 |     High_level_elements,
15 | )
16 | 
17 | 
18 | 
19 | __all__ = [
20 |     'Semantic_unit',
21 |     'Entity',
22 |     'Relationship',
23 |     'Attribute',
24 |     'Community_summary',
25 |     'document',
26 |     'Text_unit',
27 |     'semantic_unit_index_counter',
28 |     'entity_index_counter',
29 |     'relation_index_counter',
30 |     'attribute_index_counter',
31 |     'community_summary_index_counter',
32 |     'high_level_element_index_counter',
33 |     'document_index_counter',
34 |     'text_unit_index_counter',
35 |     'High_level_elements'
36 | ]
37 | 
--------------------------------------------------------------------------------
/NodeRAG/build/component/attribute.py:
-------------------------------------------------------------------------------- 1 | 2 | from ...storage import genid 3 | from ...utils.readable_index import attribute_index 4 | from .unit import Unit_base 5 | 6 | 7 | attribute_index_counter = attribute_index() 8 | 9 | 10 | class Attribute(Unit_base): 11 | def __init__(self, raw_context:str = None,node:str = None): 12 | self.node = node 13 | self.raw_context = raw_context 14 | self._hash_id = None 15 | self._human_readable_id = None 16 | 17 | 18 | @property 19 | def hash_id(self): 20 | if not self._hash_id: 21 | self._hash_id = genid([self.raw_context],"sha256") 22 | return self._hash_id 23 | @property 24 | def human_readable_id(self): 25 | if not self._human_readable_id: 26 | self._human_readable_id = attribute_index_counter.increment() 27 | return self._human_readable_id -------------------------------------------------------------------------------- /NodeRAG/build/component/community.py: -------------------------------------------------------------------------------- 1 | import networkx as nx 2 | from sortedcontainers import SortedDict 3 | import json 4 | import backoff 5 | from json.decoder import JSONDecodeError 6 | 7 | from ...storage import genid 8 | from ...utils.readable_index import community_summary_index,high_level_element_index 9 | from .unit import Unit_base 10 | from ...storage.graph_mapping import Mapper 11 | 12 | community_summary_index_counter = community_summary_index() 13 | high_level_element_index_counter = high_level_element_index() 14 | 15 | 16 | class Community_summary(Unit_base): 17 | 18 | def __init__(self,community_node:str|None,mapper:Mapper,graph:nx.MultiGraph,config): 19 | 20 | self.community_node = community_node 21 | self.client = config.API_client 22 | self.mapper = mapper 23 | self.graph = graph 24 | self._used_unit = None 25 | self.prompt = config.prompt_manager 26 | self.token_counter = config.token_counter 27 | self._hash_id = None 28 | self._human_readable_id = None 29 | 30 | @property 31 | def hash_id(self): 32 | if not self._hash_id: 33 | self._hash_id = genid(self.community_node,"sha256") 34 | return self._hash_id 35 | @property 36 | def human_readable_id(self): 37 | if not self._human_readable_id: 38 | self._human_readable_id = community_summary_index_counter.increment() 39 | return self._human_readable_id 40 | 41 | @property 42 | def used_unit(self): 43 | if self._used_unit is None: 44 | self._used_unit = [] 45 | for node in self.community_node: 46 | if self.graph.nodes[node]['type'] == 'semantic_unit': 47 | self._used_unit.append(node) 48 | elif self.graph.nodes.get(node, {}).get('type') == 'attribute': 49 | self._used_unit.append(node) 50 | elif self.graph.nodes[node].get('attribute') == 1: 51 | for neighbour in self.graph.neighbors(node): 52 | if self.graph.nodes[neighbour]['type'] == 'attribute': 53 | self._used_unit.append(neighbour) 54 | return self._used_unit 55 | 56 | def get_normal_query(self): 57 | content = '' 58 | for node in self.used_unit: 59 | content += self.mapper.get(node,'context')+'\n' 60 | query = self.prompt.community_summary.format(content = content) 61 | return query 62 | 63 | def get_important_node_query(self): 64 | weights_dict = SortedDict() 65 | for name in self.used_unit: 66 | weight = 0 67 | for neighbour in self.graph.neighbors(name): 68 | weight += self.graph[neighbour]['weight'] 69 | weights_dict[name] = weight 70 | weights_dict = reversed(weights_dict) 71 | query_old = '' 72 | for i in range(len(weights_dict)+1): 73 | query = 
self.get_query(weights_dict.keys()[:i]) 74 | if self.token_counter.token_limit(query): 75 | return query_old 76 | query_old = query 77 | 78 | def get_query(self): 79 | query = self.get_normal_query() 80 | if self.token_counter.token_limit(query): 81 | return self.get_important_node_query() 82 | return query 83 | 84 | @backoff.on_exception(backoff.expo, 85 | (JSONDecodeError,), 86 | max_tries=3, 87 | max_time=15) 88 | async def generate_community_summary(self): 89 | query = self.get_query() 90 | input = {'query':query,'response_format':self.prompt.high_level_element_json} 91 | self.response = await self.client(input) 92 | 93 | 94 | 95 | class High_level_elements(Unit_base): 96 | def __init__(self,context:str,title:str,config): 97 | self.context = context 98 | self.title = title 99 | self.embedding_client = config.embedding_client 100 | self._hash_id = None 101 | self._title_hash_id = None 102 | self._human_readable_id = None 103 | self.embedding = None 104 | 105 | @property 106 | def hash_id(self): 107 | if not self._hash_id: 108 | self._hash_id = genid([self.context],"sha256") 109 | return self._hash_id 110 | 111 | @property 112 | def title_hash_id(self): 113 | if not self._title_hash_id: 114 | self._title_hash_id = genid([self.title],"sha256") 115 | return self._title_hash_id 116 | 117 | @property 118 | def human_readable_id(self): 119 | if not self._human_readable_id: 120 | self._human_readable_id = high_level_element_index_counter.increment() 121 | return self._human_readable_id 122 | 123 | def store_embedding(self,embedding:list[float]): 124 | self.embedding = embedding 125 | 126 | def related_node(self,nodes:list[str]): 127 | self.related_node = nodes 128 | 129 | 130 | 131 | 132 | 133 | -------------------------------------------------------------------------------- /NodeRAG/build/component/document.py: -------------------------------------------------------------------------------- 1 | from ...utils.text_spliter import SemanticTextSplitter 2 | from ...storage import genid 3 | from ...utils.readable_index import document_index 4 | from .unit import Unit_base 5 | from .text_unit import Text_unit 6 | 7 | 8 | document_index_counter = document_index() 9 | 10 | 11 | class document(Unit_base): 12 | def __init__(self, raw_context:str = None,path:str = None,splitter:SemanticTextSplitter = None): 13 | 14 | self.path = path 15 | self.raw_context = raw_context 16 | self._processed_context = False 17 | self._hash_id = None 18 | self._human_readable_id = None 19 | self.text_units = None 20 | self.text_hash_id = None 21 | self.text_human_readable_id = None 22 | self.splitter = splitter 23 | 24 | @property 25 | def hash_id(self): 26 | if not self._hash_id: 27 | self._hash_id = genid([self.raw_context],"sha256") 28 | return self._hash_id 29 | 30 | @property 31 | def human_readable_id(self): 32 | if not self._human_readable_id: 33 | self._human_readable_id = document_index_counter.increment() 34 | return self._human_readable_id 35 | 36 | def split(self) -> None: 37 | if not self._processed_context: 38 | self._processed_context = True 39 | texts = self.splitter.split(self.raw_context) 40 | self.text_units = [Text_unit(text) for text in texts] 41 | self.text_hash_id = [text.hash_id for text in self.text_units] 42 | self.text_human_readable_id = [text.human_readable_id for text in self.text_units] 43 | 44 | 45 | 46 | 47 | 48 | 49 | -------------------------------------------------------------------------------- /NodeRAG/build/component/entity.py: 
-------------------------------------------------------------------------------- 1 | from .unit import Unit_base 2 | from ...storage import genid 3 | from ...utils.readable_index import entity_index 4 | 5 | entity_index_counter = entity_index() 6 | 7 | class Entity(Unit_base): 8 | 9 | def __init__(self, raw_context:str,text_hash_id:str = None): 10 | self.raw_context = raw_context 11 | self.text_hash_id = text_hash_id 12 | self._hash_id = None 13 | self._human_readable_id = None 14 | @property 15 | def hash_id(self): 16 | if not self._hash_id: 17 | self._hash_id = genid([self.raw_context],"sha256") 18 | return self._hash_id 19 | @property 20 | def human_readable_id(self): 21 | if not self._human_readable_id: 22 | self._human_readable_id = entity_index_counter.increment() 23 | return self._human_readable_id -------------------------------------------------------------------------------- /NodeRAG/build/component/relationship.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | from .unit import Unit_base 3 | from ...storage import genid 4 | from ...utils.readable_index import relation_index 5 | from .entity import Entity 6 | 7 | 8 | 9 | relation_index_counter = relation_index() 10 | 11 | class Relationship(Unit_base): 12 | 13 | def __init__(self, relationship_tuple: List[str] = None, text_hash_id: str = None, 14 | frozen_set: frozenset = None, context: str = None,human_readable_id:int = None): 15 | if relationship_tuple: 16 | self.relationship_tuple = relationship_tuple 17 | self.source = Entity(relationship_tuple[0], text_hash_id) 18 | self.target = Entity(relationship_tuple[2], text_hash_id) 19 | self.unique_relationship = frozenset((self.source.hash_id,self.target.hash_id)) 20 | self.raw_context = " ".join(self.relationship_tuple) 21 | self._human_readable_id = None 22 | 23 | elif frozen_set: 24 | self.unique_relationship = frozenset(frozen_set) 25 | self.raw_context = context 26 | self._human_readable_id = human_readable_id 27 | else: 28 | raise ValueError("Must provide either relationship_tuple or (frozen_set and context)") 29 | 30 | self.text_hash_id = text_hash_id 31 | self._hash_id = None 32 | 33 | 34 | @property 35 | def hash_id(self): 36 | if not self._hash_id: 37 | self._hash_id = genid(list(self.unique_relationship),"sha256") 38 | return self._hash_id 39 | 40 | @property 41 | def human_readable_id(self): 42 | if not self._human_readable_id: 43 | self._human_readable_id = relation_index_counter.increment() 44 | return self._human_readable_id 45 | 46 | def __eq__(self, other): 47 | if isinstance(other, frozenset): 48 | return self.unique_relationship == other 49 | elif isinstance(other, Relationship): 50 | return self.unique_relationship == other.unique_relationship 51 | return False 52 | 53 | def __hash__(self): 54 | return hash(self.unique_relationship) 55 | 56 | def add(self,relationship_tuple:List[str]): 57 | raw_context = " ".join(relationship_tuple) 58 | self.raw_context = self.raw_context + "\t" + raw_context 59 | 60 | def __str__(self): 61 | return self.raw_context 62 | 63 | @classmethod 64 | def from_df_row(cls,row): 65 | return cls(frozen_set=row['unique_relationship'],context=row['context'],human_readable_id=row['human_readable_id']) 66 | -------------------------------------------------------------------------------- /NodeRAG/build/component/semantic_unit.py: -------------------------------------------------------------------------------- 1 | from .unit import Unit_base 2 | from ...storage import genid 3 | 
from ...utils.readable_index import semantic_unit_index 4 | semantic_unit_index_counter = semantic_unit_index() 5 | 6 | class Semantic_unit(Unit_base): 7 | def __init__(self, raw_context:str,text_hash_id:str = None): 8 | self.raw_context = raw_context 9 | self.text_hash_id = text_hash_id 10 | self._hash_id = None 11 | self._human_readable_id = None 12 | 13 | @property 14 | def hash_id(self): 15 | if not self._hash_id: 16 | self._hash_id = genid([self.raw_context],"sha256") 17 | return self._hash_id 18 | @property 19 | def human_readable_id(self): 20 | if not self._human_readable_id: 21 | self._human_readable_id = semantic_unit_index_counter.increment() 22 | return self._human_readable_id 23 | 24 | 25 | -------------------------------------------------------------------------------- /NodeRAG/build/component/text_unit.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | 4 | from .unit import Unit_base 5 | from ...storage import genid 6 | from ...utils.readable_index import text_unit_index 7 | 8 | 9 | 10 | 11 | text_unit_index_counter = text_unit_index() 12 | 13 | 14 | class Text_unit(Unit_base): 15 | def __init__(self, raw_context:str = None,hash_id:str = None,human_readable_id:int = None,semantic_units:list = []): 16 | self.raw_context = raw_context 17 | self._hash_id = hash_id 18 | self._human_readable_id = human_readable_id 19 | 20 | @property 21 | def hash_id(self): 22 | if not self._hash_id: 23 | self._hash_id = genid([self.raw_context],"sha256") 24 | return self._hash_id 25 | 26 | @property 27 | def human_readable_id(self): 28 | if not self._human_readable_id: 29 | self._human_readable_id = text_unit_index_counter.increment() 30 | return self._human_readable_id 31 | 32 | 33 | 34 | 35 | async def text_decomposition(self,config) -> None: 36 | 37 | cache_path = config.text_decomposition_path 38 | prompt = config.prompt_manager.text_decomposition.format(text=self.raw_context) 39 | json_format = config.prompt_manager.text_decomposition_json 40 | input_data = {'query':prompt,'response_format':json_format} 41 | meta_data = {'text_hash_id':self.hash_id,'text_id':self.human_readable_id} 42 | 43 | 44 | response = await config.API_client(input_data,cache_path =config.LLM_error_cache,meta_data = meta_data) 45 | 46 | if response == 'Error cached': 47 | config.tracker.update() 48 | return None 49 | 50 | 51 | with open(cache_path, 'a',encoding='utf-8') as f: 52 | data = {**meta_data,'response':response} 53 | f.write(json.dumps(data,ensure_ascii=False)+'\n') 54 | config.tracker.update() 55 | # return response 56 | 57 | 58 | 59 | -------------------------------------------------------------------------------- /NodeRAG/build/component/unit.py: -------------------------------------------------------------------------------- 1 | from abc import ABC,abstractmethod 2 | 3 | class Unit_base(ABC): 4 | 5 | @property 6 | @abstractmethod 7 | def hash_id(self): 8 | ... 9 | @property 10 | @abstractmethod 11 | def human_readable_id(self): 12 | ... 
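    # Note (editorial, grounded in the sibling modules): every concrete unit in this
    # package (document, Text_unit, Semantic_unit, Entity, Relationship, Attribute,
    # High_level_elements) realizes these two properties the same way -- hash_id is a
    # content-addressed sha256 computed by genid(...) over the unit's raw context,
    # while human_readable_id lazily increments a per-type readable_index counter, so
    # hashes stay stable across runs and readable ids reflect creation order.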
15 | 
16 |     def call_action(self,action:str,*args, **kwargs) -> None:
17 |         method = getattr(self,action,None)
18 | 
19 |         if callable(method):
20 |             method(*args, **kwargs)
21 |         else:
22 |             raise ValueError(f"Action {action} not found")
23 | 
24 | 
--------------------------------------------------------------------------------
/NodeRAG/build/pipeline/HNSW_graph.py:
--------------------------------------------------------------------------------
1 | import os
2 | from ...utils.HNSW import HNSW
3 | from ...storage import Mapper
4 | from ...config import NodeConfig
5 | from ...logging import info_timer
6 | 
7 | 
8 | 
9 | class HNSW_pipeline():
10 | 
11 |     def __init__(self,config:NodeConfig):
12 | 
13 |         self.config = config
14 |         self.mapper = self.load_mapper()
15 |         self.hnsw = self.load_hnsw()
16 | 
17 |     def load_mapper(self) -> Mapper:
18 | 
19 |         mapping_list = [self.config.semantic_units_path,
20 |                         self.config.attributes_path,
21 |                         self.config.high_level_elements_path,
22 |                         self.config.text_path]
23 | 
24 |         # filter rather than pop while iterating by index, which skips entries and can raise IndexError
25 |         mapping_list = [path for path in mapping_list if os.path.exists(path)]
26 | 
27 | 
28 |         mapper = Mapper(mapping_list)
29 |         if os.path.exists(self.config.embedding):
30 |             mapper.add_embedding(self.config.embedding)
31 | 
32 |         return mapper
33 | 
34 |     def load_hnsw(self) -> HNSW:
35 | 
36 |         hnsw = HNSW(self.config)
37 | 
38 |         if os.path.exists(self.config.HNSW_path):
39 | 
40 |             hnsw.load_HNSW(self.config.HNSW_path)
41 |             return hnsw
42 | 
43 |         elif self.mapper.embeddings is not None:
44 |             return hnsw
45 |         else:
46 |             raise Exception('No embeddings found')
47 | 
48 | 
49 |     def generate_HNSW(self):
50 |         unHNSW = self.mapper.find_non_HNSW()
51 | 
52 |         self.config.console.print(f'[yellow]Generating HNSW graph for {len(unHNSW)} nodes[/yellow]')
53 |         self.hnsw.add_nodes(unHNSW)
54 |         self.config.console.print('[green]Nodes added to the HNSW graph[/green]')
55 |         self.config.tracker.set(len(unHNSW),desc="storing HNSW graph")
56 |         for node_id,_ in unHNSW:
57 |             self.mapper.add_attribute(node_id,'embedding','HNSW')
58 |             self.config.tracker.update()
59 |         self.config.tracker.close()
60 |         self.config.console.print(f'[green]HNSW graph generated for {len(unHNSW)} nodes[/green]')
61 | 
62 |     def delete_embedding(self):
63 | 
64 |         if os.path.exists(self.config.embedding):
65 |             os.remove(self.config.embedding)
66 | 
67 |     @info_timer(message='HNSW graph generation')
68 |     async def main(self):
69 |         if os.path.exists(self.config.embedding):
70 |             self.generate_HNSW()
71 |             self.hnsw.save_HNSW()
72 |             self.mapper.update_save()
73 |             self.delete_embedding()
74 |             self.config.console.print('[green]HNSW graph saved[/green]')
75 | 
76 | 
77 | 
--------------------------------------------------------------------------------
/NodeRAG/build/pipeline/INIT_pipeline.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | from ...config import NodeConfig
4 | from ...logging import info_timer
5 | from ...storage import genid
6 | 
7 | 
8 | 
9 | 
10 | class INIT_pipeline():
11 |     def __init__(self, config:NodeConfig):
12 | 
13 |         self.config = config
14 |         self.documents_path = []
15 | 
16 |     @property
17 |     def document_path_hash(self):
18 |         if not self.documents_path:
19 |             raise ValueError('Document path is not loaded')
20 |         else:
21 |             return genid(''.join(self.documents_path),"sha256")
22 | 
23 |     def check_folder_structure(self):
24 |         if not os.path.exists(self.config.main_folder):
25 |             raise ValueError(f'Main folder {self.config.main_folder} does
not exist') 26 | 27 | if not os.path.exists(self.config.input_folder): 28 | raise ValueError(f'Input folder {self.config.input_folder} does not exist') 29 | 30 | def load_files(self): 31 | 32 | 33 | if self.config.docu_type == 'mixed': 34 | for file in os.listdir(self.config.input_folder): 35 | if file.endswith('.txt') or file.endswith('.md'): 36 | file_path = os.path.join(self.config.input_folder, file) 37 | self.documents_path.append(file_path) 38 | else: 39 | for file in os.listdir(self.config.input_folder): 40 | if file.endswith(f'.{self.config.docu_type}'): 41 | file_path = os.path.join(self.config.input_folder, file) 42 | self.documents_path.append(file_path) 43 | 44 | if len(self.documents_path) == 0: 45 | raise ValueError(f'No files found in {self.config.input_folder}') 46 | 47 | def check_increment(self): 48 | if not os.path.exists(self.config.document_hash_path): 49 | self.save_document_hash() 50 | return False 51 | else: 52 | with open(self.config.document_hash_path,'r') as f: 53 | file = json.load(f) 54 | previous_hash = file['document_path_hash'] 55 | if previous_hash == self.document_path_hash: 56 | return False 57 | else: 58 | return True 59 | 60 | def save_document_hash(self): 61 | with open(self.config.document_hash_path,'w') as f: 62 | json.dump({'document_path_hash':self.document_path_hash,'document_path':self.documents_path},f) 63 | 64 | 65 | @info_timer(message='Init Pipeline') 66 | async def main(self): 67 | self.check_folder_structure() 68 | self.load_files() 69 | if self.check_increment(): 70 | self.save_document_hash() 71 | return True 72 | else: 73 | return False 74 | 75 | 76 | 77 | -------------------------------------------------------------------------------- /NodeRAG/build/pipeline/Insert_text.py: -------------------------------------------------------------------------------- 1 | import os 2 | from ...storage import storage 3 | from ...config import NodeConfig 4 | from ...utils import MultigraphConcat 5 | from ...logging import info_timer 6 | 7 | 8 | 9 | class Insert_text: 10 | 11 | def __init__(self,config:NodeConfig): 12 | self.config = config 13 | self.G = storage.load(self.config.graph_path) 14 | self.base_G = self.load_base_graph(self.config.base_graph_path) 15 | self.semantic_units = storage.load(self.config.semantic_units_path) 16 | 17 | def insert_text(self): 18 | self.config.tracker.set(len(self.semantic_units),'Inserting text') 19 | for id,row in self.semantic_units.iterrows(): 20 | if row['insert'] is None: 21 | semantic_unit_hash_id = row['hash_id'] 22 | text_unit_hash_id = row['text_hash_id'] 23 | if not self.G.has_node(text_unit_hash_id): 24 | self.G.add_node(text_unit_hash_id,type='text',weight=1) 25 | if not self.G.has_edge(semantic_unit_hash_id,text_unit_hash_id): 26 | self.G.add_edge(semantic_unit_hash_id,text_unit_hash_id,type='text',weight=1) 27 | self.semantic_units.at[id,'insert'] = True 28 | self.config.tracker.update() 29 | self.config.tracker.close() 30 | storage(self.semantic_units).save_parquet(self.config.semantic_units_path) 31 | 32 | def concatenate_graph(self): 33 | 34 | self.base_G = MultigraphConcat(self.base_G).concat(self.G) 35 | storage(self.base_G).save_pickle(self.config.base_graph_path) 36 | os.remove(self.config.graph_path) 37 | self.config.console.print('[bold green]Graph has been concatenated, stored in base graph[/bold green]') 38 | 39 | def load_base_graph(self,base_graph_path:str): 40 | if os.path.exists(base_graph_path): 41 | return storage.load(base_graph_path) 42 | else: 43 | return None 44 | 
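    # main() only runs when a freshly built delta graph exists at graph_path; otherwise this step is a no-op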
45 |     @info_timer(message="Insert text and concatenate graph")
46 |     async def main(self):
47 |         if os.path.exists(self.config.graph_path):
48 |             self.insert_text()
49 |             self.concatenate_graph()
50 | 
--------------------------------------------------------------------------------
/NodeRAG/build/pipeline/__init__.py:
--------------------------------------------------------------------------------
1 | from .INIT_pipeline import INIT_pipeline
2 | from .document_pipeline import document_pipline
3 | from .text_pipeline import text_pipline
4 | from .graph_pipeline import Graph_pipeline
5 | from .attribute_generation import Attribution_generation_pipeline
6 | from .embedding import Embedding_pipeline
7 | from .summary_generation import SummaryGeneration
8 | from .Insert_text import Insert_text
9 | from .HNSW_graph import HNSW_pipeline
10 | 
11 | 
12 | __all__ = ['INIT_pipeline',
13 |            'document_pipline',
14 |            'text_pipline',
15 |            'Graph_pipeline',
16 |            'Attribution_generation_pipeline',
17 |            'Embedding_pipeline',
18 |            'SummaryGeneration',
19 |            'Insert_text',
20 |            'HNSW_pipeline'
21 |            ]
--------------------------------------------------------------------------------
/NodeRAG/build/pipeline/attribute_generation.py:
--------------------------------------------------------------------------------
1 | import networkx as nx
2 | import numpy as np
3 | import math
4 | import asyncio
5 | import os
6 | from sortedcontainers import SortedDict
7 | from rich.console import Console
8 | 
9 | 
10 | from ...storage import (
11 |     Mapper,
12 |     storage
13 | )
14 | from ..component import Attribute
15 | from ...config import NodeConfig
16 | from ...logging import info_timer
17 | 
18 | 
19 | 
20 | class NodeImportance:
21 | 
22 |     def __init__(self,graph:nx.Graph,console:Console):
23 |         self.G = graph
24 |         self.important_nodes = []
25 |         self.console = console
26 | 
27 |     def K_core(self,k:int|None = None):
28 | 
29 |         if k is None:
30 |             k = self.default_k()
31 | 
32 |         self.k_subgraph = nx.k_core(self.G,k=k)
33 | 
34 |         for nodes in self.k_subgraph.nodes():
35 |             if self.G.nodes[nodes]['type'] == 'entity' and self.G.nodes[nodes]['weight'] > 1:
36 |                 self.important_nodes.append(nodes)
37 | 
38 |     def average_degree(self):
39 |         average_degree = sum(dict(self.G.degree()).values())/self.G.number_of_nodes()
40 |         return average_degree
41 | 
42 |     def default_k(self):
43 |         k = round(np.log(self.G.number_of_nodes())*self.average_degree()**(1/2))
44 |         return k
45 | 
46 |     def betweenness_centrality(self):
47 | 
48 |         self.betweenness = nx.betweenness_centrality(self.G,k=10)
49 |         average_betweenness = sum(self.betweenness.values())/len(self.betweenness)
50 |         scale = round(math.log10(len(self.betweenness)))
51 | 
52 |         for node in self.betweenness:
53 |             if self.betweenness[node] > average_betweenness*scale:
54 |                 if self.G.nodes[node]['type'] == 'entity' and self.G.nodes[node]['weight'] > 1:
55 |                     self.important_nodes.append(node)
56 | 
57 |     def main(self):
58 |         self.K_core()
59 |         self.console.print('[bold green]K_core done[/bold green]')
60 |         self.betweenness_centrality()
61 |         self.console.print('[bold green]Betweenness done[/bold green]')
62 |         self.important_nodes = list(set(self.important_nodes))
63 |         return self.important_nodes
64 | 
65 | 
66 | 
67 | class Attribution_generation_pipeline:
68 | 
69 |     def __init__(self,config:NodeConfig):
70 | 
71 | 
72 |         self.config = config
73 |         self.prompt_manager = config.prompt_manager
74 |         self.indices = config.indices
75 |         self.console = config.console
76 |         self.API_client = config.API_client
77 |         self.token_counter = config.token_counter
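78 |         # accumulators: nodes selected by NodeImportance and the Attribute objects generated for them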
79 |         self.important_nodes = []
80 |         self.attributes = []
81 | 
82 | 
83 |         self.mapper = Mapper([self.config.entities_path,self.config.relationship_path,self.config.semantic_units_path])
84 |         self.G = storage.load(self.config.graph_path)
85 | 
86 |     def get_important_nodes(self):
87 | 
88 |         node_importance = NodeImportance(self.G,self.config.console)
89 |         important_nodes = node_importance.main()
90 | 
91 |         if os.path.exists(self.config.attributes_path):
92 |             attributes = storage.load(self.config.attributes_path)
93 |             existing_nodes = attributes['node'].tolist()
94 |             important_nodes = [node for node in important_nodes if node not in existing_nodes]
95 | 
96 |         self.important_nodes = important_nodes
97 |         self.console.print('[bold green]Important nodes found[/bold green]')
98 | 
99 |     def get_neighbours_material(self,node:str):
100 | 
101 |         entity = self.mapper.get(node,'context')
102 |         semantic_neighbours = ''+'\n'
103 |         relationship_neighbours = ''+'\n'
104 | 
105 |         for neighbour in self.G.neighbors(node):
106 |             if self.G.nodes[neighbour]['type'] == 'semantic_unit':
107 |                 semantic_neighbours += f'{self.mapper.get(neighbour,"context")}\n'
108 |             elif self.G.nodes[neighbour]['type'] == 'relationship':
109 |                 relationship_neighbours += f'{self.mapper.get(neighbour,"context")}\n'
110 | 
111 |         query = self.prompt_manager.attribute_generation.format(entity = entity,semantic_units = semantic_neighbours,relationships = relationship_neighbours)
112 |         return query
113 | 
114 | 
115 |     def get_important_neighbours_material(self,node:str):
116 | 
117 |         entity = self.mapper.get(node,'context')
118 |         semantic_neighbours = ''+'\n'
119 |         relationship_neighbours = ''+'\n'
120 |         sorted_neighbours = SortedDict()
121 | 
122 |         for neighbour in self.G.neighbors(node):
123 |             value = 0
124 |             for neighbour_neighbour in self.G.neighbors(neighbour):
125 |                 value += self.G.nodes[neighbour_neighbour]['weight']
126 |             sorted_neighbours[neighbour] = value
127 | 
128 |         query = self.prompt_manager.attribute_generation.format(entity = entity,semantic_units = semantic_neighbours,relationships = relationship_neighbours)
129 |         # add neighbours from highest to lowest aggregated weight; reversed() alone would iterate by key order, not by weight
130 |         for neighbour in sorted(sorted_neighbours, key=sorted_neighbours.get, reverse=True):
131 |             if self.G.nodes[neighbour]['type'] == 'semantic_unit':
132 |                 semantic_neighbours += f'{self.mapper.get(neighbour,"context")}\n'
133 |             elif self.G.nodes[neighbour]['type'] == 'relationship':
134 |                 relationship_neighbours += f'{self.mapper.get(neighbour,"context")}\n'
135 |             candidate = self.prompt_manager.attribute_generation.format(entity = entity,semantic_units = semantic_neighbours,relationships = relationship_neighbours)
136 |             if self.token_counter.token_limit(candidate):
137 |                 return query
138 |             query = candidate
139 | 
140 |         return query
141 | 
142 |     async def generate_attribution_main(self):
143 | 
144 |         tasks = []
145 |         self.config.tracker.set(len(self.important_nodes),desc="Generating attributes")
146 | 
147 |         for node in self.important_nodes:
148 |             tasks.append(self.generate_attribution(node))
149 | 
150 |         await asyncio.gather(*tasks)
151 | 
152 |         self.config.tracker.close()
153 | 
154 |     async def generate_attribution(self,node:str):
155 |         query = self.get_neighbours_material(node)
156 | 
157 |         if self.token_counter.token_limit(query):
158 |             query = self.get_important_neighbours_material(node)
159 | 
160 |         response = await self.API_client({'query':query})
161 |         if response is not None:
162 |             attribute = Attribute(response,node)
163 | 
164 |             self.attributes.append(attribute)
165 |             self.G.nodes[node]['attributes'] = [attribute.hash_id]
166 |             self.G.add_node(attribute.hash_id,type='attribute',weight=1)
167 |             self.G.add_edge(node,attribute.hash_id,weight=1)
168 |         self.config.tracker.update()
169 | 
170 |     def save_attributes(self):
171 | 
172 | 
attributes = [] 173 | 174 | for attribute in self.attributes: 175 | attributes.append({'node':attribute.node, 176 | 'type':'attribute', 177 | 'context':attribute.raw_context, 178 | 'hash_id':attribute.hash_id, 179 | 'human_readable_id':attribute.human_readable_id, 180 | 'weight':self.G.nodes[attribute.node]['weight'], 181 | 'embedding':None}) 182 | 183 | storage(attributes).save_parquet(self.config.attributes_path,append= os.path.exists(self.config.attributes_path)) 184 | self.config.console.print('[bold green]Attributes stored[/bold green]') 185 | 186 | 187 | def save_graph(self): 188 | 189 | storage(self.G).save_pickle(self.config.graph_path) 190 | self.config.console.print('Graph stored') 191 | 192 | @info_timer(message='Attribute Generation') 193 | async def main(self): 194 | 195 | if os.path.exists(self.config.graph_path): 196 | 197 | self.get_important_nodes() 198 | await self.generate_attribution_main() 199 | self.save_attributes() 200 | self.save_graph() 201 | self.indices.store_all_indices(self.config.indices_path) 202 | 203 | 204 | 205 | 206 | 207 | 208 | 209 | 210 | 211 | 212 | -------------------------------------------------------------------------------- /NodeRAG/build/pipeline/document_pipeline.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | 4 | from ...config import NodeConfig 5 | from ...storage.storage import storage 6 | from ..component.document import document 7 | from ...logging import info_timer 8 | 9 | 10 | class document_pipline(): 11 | 12 | def __init__(self, config:NodeConfig): 13 | 14 | 15 | self.config = config 16 | self.documents_path = self.load_document_path() 17 | self.indices = self.config.indices 18 | self._documents = None 19 | self._hash_ids = None 20 | self._human_readable_id = None 21 | 22 | 23 | def integrity_check(self): 24 | if not os.path.exists(self.config.cache): 25 | os.makedirs(self.config.cache) 26 | elif self.cache_completion_check(): 27 | pass 28 | else: 29 | self.delete_cache() 30 | 31 | def load_document_path(self): 32 | with open(self.config.document_hash_path,'r') as f: 33 | return json.load(f)['document_path'] 34 | 35 | @property 36 | def documents(self): 37 | if self._documents is None: 38 | self._documents = [] 39 | for path in self.documents_path: 40 | with open(path, 'r', encoding='utf-8') as f: 41 | raw_context = f.read() 42 | self._documents.append(document(raw_context,path,self.config.semantic_text_splitter)) 43 | return self._documents 44 | 45 | @property 46 | def hash_ids(self): 47 | if not self._hash_ids: 48 | self._hash_ids = [doc.hash_id for doc in self.documents] 49 | return self._hash_ids 50 | 51 | @property 52 | def human_readable_ids(self): 53 | if not self._human_readable_id: 54 | self._human_readable_id = [doc.human_readable_id for doc in self.documents] 55 | return self._human_readable_id 56 | 57 | 58 | def store_documents_data(self): 59 | doc_list = [] 60 | for doc in self.documents: 61 | doc_list.append({'doc_id':doc.human_readable_id, 62 | 'doc_hash_id':doc.hash_id, 63 | 'text_id':doc.text_human_readable_id, 64 | 'text_hash_id':doc.text_hash_id, 65 | 'path':doc.path}) 66 | storage(doc_list).save_parquet(self.config.documents_path,append= os.path.exists(self.config.documents_path)) 67 | self.config.console.print('[green]Documents stored[/green]') 68 | 69 | def store_text_data(self): 70 | text_list = [] 71 | 72 | self.config.tracker.set(len(self.documents),desc="Processing text") 73 | for doc in self.documents: 74 | doc.split() 75 | for text in 
doc.text_units: 76 | text_list.append({'text_id':text.human_readable_id, 77 | 'hash_id':text.hash_id, 78 | 'type':'text', 79 | 'context':text.raw_context, 80 | 'doc_id':doc.human_readable_id, 81 | 'doc_hash_id':doc.hash_id, 82 | 'embedding':None,}) 83 | self.config.tracker.update() 84 | self.config.tracker.close() 85 | storage(text_list).save_parquet(self.config.text_path,append= os.path.exists(self.config.text_path)) 86 | self.config.console.print('[green]Texts stored[/green]') 87 | 88 | def store_readable_index(self) -> None: 89 | 90 | self.indices.store_all_indices(self.config.indices_path) 91 | 92 | def cache_completion_check(self) -> bool: 93 | files_name = ['documents.parquet','text.parquet','indices.json'] 94 | files = os.listdir(self.config.cache) 95 | return all([file in files for file in files_name]) 96 | 97 | def delete_cache(self) -> None: 98 | for file in os.listdir(self.config.cache): 99 | os.remove(os.path.join(self.config.cache,file)) 100 | self.config.console.print('[red]There exist incomplete cache,deleted[/red]') 101 | 102 | def increment_doc(self) -> None: 103 | if os.path.exists(self.config.documents_path): 104 | exist_doc_id = storage.load_parquet(self.config.documents_path)['doc_hash_id'].tolist() 105 | increment_doc_id = list(set(self.hash_ids) - set(exist_doc_id)) 106 | self._documents = [doc for doc in self.documents if doc.hash_id in increment_doc_id] 107 | else: 108 | self._documents = self.documents 109 | self.documents_path = [doc.path for doc in self.documents] 110 | self._hash_ids = None 111 | 112 | 113 | @info_timer(message='Document Pipeline') 114 | async def main(self): 115 | self.integrity_check() 116 | self.increment_doc() 117 | self.store_text_data() 118 | self.store_documents_data() 119 | self.store_readable_index() 120 | 121 | 122 | 123 | 124 | 125 | 126 | 127 | 128 | -------------------------------------------------------------------------------- /NodeRAG/build/pipeline/embedding.py: -------------------------------------------------------------------------------- 1 | from typing import Dict 2 | import os 3 | import asyncio 4 | import json 5 | import math 6 | 7 | 8 | from ...config import NodeConfig 9 | from ...LLM import Embedding_message 10 | 11 | 12 | from ...storage import ( 13 | Mapper, 14 | storage 15 | ) 16 | 17 | from ...logging import info_timer 18 | 19 | class Embedding_pipeline(): 20 | 21 | def __init__(self,config:NodeConfig): 22 | self.config = config 23 | self.embedding_client = self.config.embedding_client 24 | self.mapper = self.load_mapper() 25 | 26 | 27 | 28 | def load_mapper(self) -> Mapper: 29 | mapping_list = [self.config.text_path, 30 | self.config.semantic_units_path, 31 | self.config.attributes_path] 32 | mapping_list = [path for path in mapping_list if os.path.exists(path)] 33 | return Mapper(mapping_list) 34 | 35 | async def get_embeddings(self,context_dict:Dict[str,Embedding_message]): 36 | 37 | empty_ids = [key for key, value in context_dict.items() if value == ""] 38 | 39 | if len(empty_ids) > 0: 40 | 41 | context_dict = {key: value for key, value in context_dict.items() if value != ""} 42 | 43 | for empty_id in empty_ids: 44 | self.mapper.delete(empty_id) 45 | 46 | 47 | embedding_input = list(context_dict.values()) 48 | 49 | ids = list(context_dict.keys()) 50 | 51 | embedding_output = await self.embedding_client(embedding_input,cache_path=self.config.LLM_error_cache,meta_data = {'ids':ids}) 52 | 53 | if embedding_output == 'Error cached': 54 | return 55 | 56 | 57 | with 
open(self.config.embedding_cache,'a',encoding='utf-8') as f:
58 | 
59 |             for i in range(len(ids)):
60 |                 line = {'hash_id':ids[i],'embedding':embedding_output[i]}
61 |                 f.write(json.dumps(line)+'\n')
62 | 
63 |         self.config.tracker.update()
64 | 
65 |     def delete_embedding_cache(self):
66 | 
67 |         if os.path.exists(self.config.embedding_cache):
68 |             os.remove(self.config.embedding_cache)
69 | 
70 | 
71 | 
72 |     async def generate_embeddings(self):
73 |         tasks = []
74 |         none_embedding_ids = self.mapper.find_none_embeddings()
75 |         self.config.tracker.set(math.ceil(len(none_embedding_ids)/self.config.embedding_batch_size),desc='Generating embeddings')
76 |         for i in range(0,len(none_embedding_ids),self.config.embedding_batch_size):
77 |             context_dict = {}
78 |             for node_id in none_embedding_ids[i:i+self.config.embedding_batch_size]:
79 |                 context_dict[node_id] = self.mapper.get(node_id,'context')
80 |             tasks.append(self.get_embeddings(context_dict))
81 |         await asyncio.gather(*tasks)
82 |         self.config.tracker.close()
83 | 
84 |     def insert_embeddings(self):
85 | 
86 |         if not os.path.exists(self.config.embedding_cache):
87 |             return None
88 | 
89 |         with open(self.config.embedding_cache,'r',encoding='utf-8') as f:
90 |             lines = []
91 |             for line in f:
92 |                 line = json.loads(line.strip())
93 |                 if isinstance(line['embedding'],str):
94 |                     continue
95 |                 self.mapper.add_attribute(line['hash_id'],'embedding','done')
96 |                 lines.append(line)
97 | 
98 |         storage(lines).save_parquet(self.config.embedding,append=os.path.exists(self.config.embedding))
99 |         self.mapper.update_save()
100 | 
101 |     def check_error_cache(self) -> None:
102 | 
103 |         if os.path.exists(self.config.LLM_error_cache):
104 |             num = 0
105 | 
106 |             with open(self.config.LLM_error_cache,'r',encoding='utf-8') as f:
107 |                 for line in f:
108 |                     num += 1
109 | 
110 |             if num > 0:
111 |                 self.config.console.print(f"[red]LLM error detected, there are {num} errors")
112 |                 self.config.console.print("[red]Please check the error log")
113 |                 self.config.console.print("[red]The error cache is named LLM_error.jsonl, stored in the cache folder")
114 |                 self.config.console.print("[red]Please fix the errors and run the pipeline again")
115 |                 raise Exception("Error happened in embedding pipeline, Error cached.")
116 | 
117 |     async def rerun(self):
118 | 
119 |         with open(self.config.LLM_error_cache,'r',encoding='utf-8') as f:
120 |             LLM_store = []
121 | 
122 |             for line in f:
123 |                 line = json.loads(line)
124 |                 LLM_store.append(line)
125 | 
126 |         tasks = []
127 | 
128 | 
129 |         self.config.tracker.set(len(LLM_store),desc='Rerun embedding')
130 | 
131 |         for store in LLM_store:
132 |             input_data = store['input_data']
133 |             meta_data = store['meta_data']
134 |             # resend the failed batch with its original meta data, which carries the hash ids
135 |             tasks.append(self.request_save(input_data,meta_data,self.config))
136 | 
137 | 
138 |         await asyncio.gather(*tasks)
139 |         self.config.tracker.close()
140 |         self.insert_embeddings()
141 |         self.delete_embedding_cache()
142 |         self.check_error_cache()
143 |         await self.main()
144 | 
145 |     async def request_save(self,
146 |                            input_data:Embedding_message,
147 |                            meta_data:Dict,
148 |                            config:NodeConfig) -> None:
149 | 
150 |         response = await config.embedding_client(input_data,cache_path=config.LLM_error_cache,meta_data = meta_data)
151 | 
152 |         if response == 'Error cached':
153 |             return
154 | 
155 |         with open(self.config.embedding_cache,'a',encoding='utf-8') as f:
156 |             for i in range(len(meta_data['ids'])):
157 |                 line = {'hash_id':meta_data['ids'][i],'embedding':response[i]}
158 |                 f.write(json.dumps(line)+'\n')
159 | 
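160 |     # the embedding cache is a JSONL file of {"hash_id", "embedding"} records; insert_embeddings folds it into the embedding parquet and marks each id as done in the mapper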
161 | 
162 |     def check_embedding_cache(self):
163 |         if os.path.exists(self.config.embedding_cache):
164 |             self.insert_embeddings()
165 |             self.delete_embedding_cache()
166 | 
167 |     @info_timer(message='Embedding Pipeline')
168 |     async def main(self):
169 |         self.check_embedding_cache()
170 |         await self.generate_embeddings()
171 |         self.insert_embeddings()
172 |         self.delete_embedding_cache()
173 |         self.check_error_cache()
174 | 
175 | 
176 | 
--------------------------------------------------------------------------------
/NodeRAG/build/pipeline/graph_pipeline.py:
--------------------------------------------------------------------------------
1 | import networkx as nx
2 | from typing import List,Dict
3 | import json
4 | import os
5 | import asyncio
6 | 
7 | from ...LLM import LLMOutput
8 | 
9 | from ..component import (
10 |     Semantic_unit,
11 |     Entity,
12 |     Relationship
13 | )
14 | 
15 | from ...storage import storage
16 | from ...config import NodeConfig
17 | from ...logging import info_timer
18 | class Graph_pipeline:
19 | 
20 | 
21 |     def __init__(self,config:NodeConfig):
22 | 
23 |         self.config = config
24 |         self.G = self.load_graph()
25 |         self.indices = self.config.indices
26 |         self.data, self.processed_data = self.load_data()
27 |         self.API_request = self.config.API_client
28 |         self.prompt_manager = self.config.prompt_manager
29 |         self.semantic_units = []
30 |         self.entities = []
31 |         self.relationship, self.relationship_lookup = self.load_relationship()
32 |         self.relationship_nodes = []
33 |         self.console = self.config.console
34 | 
35 | 
36 |     def check_processed(self,data:Dict)->bool:
37 |         # returns True when the record has not been processed yet
38 |         if data.get('processed'):
39 |             return False
40 |         return True
41 | 
42 |     def load_graph(self) -> nx.Graph:
43 |         if os.path.exists(self.config.graph_path):
44 |             return storage.load_pickle(self.config.graph_path)
45 |         return nx.Graph()
46 | 
47 |     def load_data(self)->tuple[List[LLMOutput],List[LLMOutput]]:
48 |         data_list = []
49 |         processed_data = []
50 |         with open(self.config.text_decomposition_path, 'r', encoding='utf-8') as f:
51 |             for line in f:
52 |                 data = json.loads(line)
53 |                 if self.check_processed(data):
54 |                     data_list.append(data)
55 |                 else:
56 |                     processed_data.append(data)
57 |         return data_list,processed_data
58 | 
59 |     def load_relationship(self)->tuple[List[Relationship],Dict[str,Relationship]]:
60 | 
61 |         if os.path.exists(self.config.relationship_path):
62 |             df = storage.load(self.config.relationship_path)
63 |             # iterate as dicts so from_df_row can index columns by name (itertuples rows do not support row['...'])
64 |             relationship = [Relationship.from_df_row(row) for row in df.to_dict('records')]
65 |             relationship_lookup = {relationship.hash_id: relationship for relationship in relationship}
66 |             return relationship,relationship_lookup
67 |         return [],{}
68 | 
69 |     async def build_graph(self):
70 | 
71 |         self.config.tracker.set(len(self.data),desc="Building graph")
72 |         tasks = []
73 | 
74 |         for data in self.data:
75 |             tasks.append(self.graph_tasks(data))
76 |         await asyncio.gather(*tasks)
77 |         self.config.tracker.close()
78 | 
79 |     async def graph_tasks(self,data:Dict):
80 |         text_hash_id = data.get('text_hash_id')
81 |         response = data.get('response')
82 | 
83 |         if isinstance(response,dict):
84 |             Output = response.get('Output')
85 | 
86 |             for output in Output:
87 |                 semantic_unit = output.get('semantic_unit')
88 |                 entities = output.get('entities')
89 |                 relationships = output.get('relationships')
90 | 
91 |                 semantic_unit_hash_id = self.add_semantic_unit(semantic_unit,text_hash_id)
92 |                 entities_hash_id = self.add_entities(entities,text_hash_id)
93 | 
94 |                 entities_hash_id_re = await self.add_relationships(relationships,text_hash_id)
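95 |                 # entities that appear only inside relationship triples also belong to this semantic unit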
96 |                 entities_hash_id.extend(entities_hash_id_re)
97 |                 self.add_semantic_belongings(semantic_unit_hash_id,entities_hash_id)
98 |             data['processed'] = True
99 |         self.config.tracker.update()
100 | 
101 |     def save_data(self):
102 |         with open(self.config.text_decomposition_path, 'w', encoding='utf-8') as f:
103 |             self.processed_data.extend(self.data)
104 |             for data in self.processed_data:
105 |                 f.write(json.dumps(data, ensure_ascii=False) + '\n')
106 | 
107 | 
108 |     def add_semantic_unit(self,semantic_unit:Dict,text_hash_id:str):
109 | 
110 |         semantic_unit = Semantic_unit(semantic_unit,text_hash_id)
111 |         if self.G.has_node(semantic_unit.hash_id):
112 |             self.G.nodes[semantic_unit.hash_id]['weight'] += 1
113 |         else:
114 |             self.G.add_node(semantic_unit.hash_id,type ='semantic_unit',weight = 1)
115 |             self.semantic_units.append(semantic_unit)
116 |         return semantic_unit.hash_id
117 | 
118 |     def add_entities(self,entities:List[Dict],text_hash_id:str):
119 | 
120 |         entities_hash_id = []
121 | 
122 |         for entity in entities:
123 | 
124 |             entity = Entity(entity,text_hash_id)
125 |             entities_hash_id.append(entity.hash_id)
126 | 
127 |             if self.G.has_node(entity.hash_id):
128 |                 self.G.nodes[entity.hash_id]['weight'] += 1
129 | 
130 |             else:
131 |                 self.G.add_node(entity.hash_id,type = 'entity',weight = 1)
132 |                 self.entities.append(entity)
133 | 
134 |         return entities_hash_id
135 | 
136 |     def add_semantic_belongings(self, semantic_unit_hash_id: str, hash_id: List[str]):
137 |         for entity_hash_id in hash_id:
138 | 
139 | 
140 |             if self.G.has_edge(semantic_unit_hash_id,entity_hash_id):
141 |                 self.G[semantic_unit_hash_id][entity_hash_id]['weight'] += 1
142 |             else:
143 |                 self.G.add_edge(semantic_unit_hash_id,entity_hash_id,weight = 1)
144 | 
145 |     async def add_relationships(self,relationships:List[str],text_hash_id:str):
146 | 
147 |         entities_hash_id = []
148 |         for relationship in relationships:
149 | 
150 |             relationship = relationship.split(',')
151 |             relationship = [i.strip() for i in relationship]
152 | 
153 |             if len(relationship) != 3:
154 |                 relationship = await self.reconstruct_relationship(relationship)
155 | 
156 |             relationship = Relationship(relationship,text_hash_id)
157 |             hash_id = relationship.hash_id
158 |             if hash_id in self.relationship_lookup:
159 |                 Re = self.relationship_lookup[hash_id]
160 |                 # add() expects the triple itself; passing raw_context would join it character by character
161 |                 Re.add(relationship.relationship_tuple)
162 |                 continue
163 | 
164 |             self.relationship.append(relationship)
165 |             self.relationship_lookup[hash_id] = relationship
166 | 
167 | 
168 |             for node in [relationship.source, relationship.target, relationship]:
169 |                 if not self.G.has_node(node.hash_id):
170 |                     self.G.add_node(node.hash_id, type='entity' if node in [relationship.source, relationship.target] else 'relationship', weight=1)
171 |                     if node in [relationship.source, relationship.target]:
172 |                         self.relationship_nodes.append(node)
173 |                 entities_hash_id.append(node.hash_id)
174 | 
175 | 
176 |             for edge in [(relationship.source.hash_id, relationship.hash_id), (relationship.hash_id, relationship.target.hash_id)]:
177 |                 if not self.G.has_edge(*edge):
178 |                     self.G.add_edge(*edge, weight=1)
179 |                 else:
180 |                     self.G[edge[0]][edge[1]]['weight'] += 1
181 |         return entities_hash_id
182 | 
183 |     async def reconstruct_relationship(self,relationship:List[str])->List[str]:
184 | 
185 |         query = self.prompt_manager.relationship_reconstraction.format(relationship=relationship)
186 |         json_format = self.prompt_manager.relationship_reconstraction_json
187 |         input_data = {'query':query,'response_format':json_format}
188 |         response = await
self.API_request(input_data) 189 | return [response.get('source'),response.get('relationship'),response.get('target')] 190 | 191 | 192 | 193 | 194 | def save_semantic_units(self): 195 | semantic_units = [] 196 | for semantic_unit in self.semantic_units: 197 | semantic_units.append({'hash_id':semantic_unit.hash_id, 198 | 'human_readable_id':semantic_unit.human_readable_id, 199 | 'type':'semantic_unit', 200 | 'context':semantic_unit.raw_context, 201 | 'text_hash_id':semantic_unit.text_hash_id, 202 | 'weight':self.G.nodes[semantic_unit.hash_id]['weight'], 203 | 'embedding':None, 204 | 'insert':None}) 205 | G_semantic_units = [node for node in self.G.nodes if self.G.nodes[node]['type'] == 'semantic_unit'] 206 | assert len(semantic_units) == len(G_semantic_units), f"The number of semantic units is not equal to the number of nodes in the graph. {len(semantic_units)} != {len(G_semantic_units)}" 207 | return semantic_units 208 | 209 | 210 | def save_entities(self): 211 | entities = [] 212 | 213 | for entity in self.entities: 214 | entities.append({'hash_id':entity.hash_id, 215 | 'human_readable_id':entity.human_readable_id, 216 | 'type':'entity', 217 | 'context':entity.raw_context, 218 | 'text_hash_id':entity.text_hash_id, 219 | 'weight':self.G.nodes[entity.hash_id]['weight']}) 220 | for node in self.relationship_nodes: 221 | entities.append({'hash_id':node.hash_id, 222 | 'human_readable_id':node.human_readable_id, 223 | 'type':'entity', 224 | 'context':node.raw_context, 225 | 'text_hash_id':node.text_hash_id, 226 | 'weight':self.G.nodes[node.hash_id]['weight']}) 227 | G_entities = [node for node in self.G.nodes if self.G.nodes[node]['type'] == 'entity'] 228 | assert len(entities) == len(G_entities), f"The number of entities is not equal to the number of nodes in the graph. {len(entities)} != {len(G_entities)}" 229 | return entities 230 | 231 | 232 | def save_relationships(self): 233 | relationships = [] 234 | for relationship in self.relationship: 235 | relationships.append({'hash_id':relationship.hash_id, 236 | 'human_readable_id':relationship.human_readable_id, 237 | 'type':'relationship', 238 | 'unique_relationship':list(relationship.unique_relationship), 239 | 'context':relationship.raw_context, 240 | 'text_hash_id':relationship.text_hash_id, 241 | 'weight':self.G.nodes[relationship.hash_id]['weight']}) 242 | relation_nodes = [node for node in self.G.nodes if self.G.nodes[node]['type'] == 'relationship'] 243 | assert len(relationships) == len(relation_nodes), f"The number of relationships is not equal to the number of edges in the graph. 
{len(relationships)} != {len(relation_nodes)}" 244 | return relationships 245 | 246 | 247 | def save(self): 248 | semantic_units = self.save_semantic_units() 249 | entities = self.save_entities() 250 | relationships = self.save_relationships() 251 | storage(semantic_units).save_parquet(self.config.semantic_units_path,append= os.path.exists(self.config.semantic_units_path)) 252 | storage(entities).save_parquet(self.config.entities_path,append= os.path.exists(self.config.entities_path)) 253 | storage(relationships).save_parquet(self.config.relationship_path,append= os.path.exists(self.config.relationship_path)) 254 | self.console.print('[green]Semantic units, entities and relationships stored[/green]') 255 | 256 | def save_graph(self): 257 | if self.data == []: 258 | return None 259 | storage(self.G).save_pickle(self.config.graph_path) 260 | self.console.print('[green]Graph stored[/green]') 261 | 262 | @info_timer(message='Graph Pipeline') 263 | async def main(self): 264 | await self.build_graph() 265 | self.save() 266 | self.save_graph() 267 | self.indices.store_all_indices(self.config.indices_path) 268 | self.save_data() 269 | 270 | 271 | 272 | 273 | 274 | 275 | 276 | -------------------------------------------------------------------------------- /NodeRAG/build/pipeline/summary_generation.py: -------------------------------------------------------------------------------- 1 | import leidenalg as la 2 | import os 3 | import json 4 | import asyncio 5 | import faiss 6 | import math 7 | import numpy as np 8 | 9 | from ...storage import ( 10 | Mapper, 11 | storage 12 | ) 13 | 14 | from ..component import ( 15 | Community_summary, 16 | High_level_elements 17 | ) 18 | from ...config import NodeConfig 19 | 20 | 21 | from ...utils import ( 22 | IGraph, 23 | ) 24 | 25 | from ...logging import info_timer 26 | 27 | class SummaryGeneration: 28 | 29 | def __init__(self,config:NodeConfig): 30 | 31 | self.config = config 32 | self.indices = self.config.indices 33 | self.communities = [] 34 | self.high_level_elements = [] 35 | 36 | if os.path.exists(self.config.graph_path): 37 | 38 | self.mapper = Mapper([self.config.semantic_units_path, 39 | self.config.attributes_path]) 40 | self.mapper.add_embedding(self.config.embedding) 41 | self.G = storage.load(self.config.graph_path) 42 | self.G_ig = IGraph(self.G).to_igraph() 43 | self.nodes_high_level_elements_group = [] 44 | self.nodes_high_level_elements_match = [] 45 | 46 | 47 | 48 | 49 | def partition(self): 50 | 51 | partition = la.find_partition(self.G_ig,la.ModularityVertexPartition) 52 | 53 | for i,community in enumerate(partition): 54 | community_name = [self.G_ig.vs[node]['name'] for node in community if self.G_ig.vs[node]['name'] in self.mapper.embeddings] 55 | self.communities.append(Community_summary(community_name,self.mapper,self.G,self.config)) 56 | 57 | async def generate_community_summary(self,community:Community_summary): 58 | 59 | await community.generate_community_summary() 60 | if isinstance(community.response,str): 61 | self.config.tracker.update() 62 | return 63 | 64 | community_dict = {'community':community.community_node, 65 | 'response':community.response, 66 | 'hash_id':community.hash_id, 67 | 'human_readable_id':community.human_readable_id} 68 | 69 | with open(self.config.summary_path,'a',encoding='utf-8') as f: 70 | f.write(json.dumps(community_dict,ensure_ascii=False)+'\n') 71 | 72 | self.config.tracker.update() 73 | 74 | 75 | 76 | async def generate_high_level_element_summary(self): 77 | 78 | self.partition() 79 | 80 | tasks = [] 
81 | 82 | self.config.tracker.set(len(self.communities),'Community Summary') 83 | for community in self.communities: 84 | tasks.append(self.generate_community_summary(community)) 85 | 86 | await asyncio.gather(*tasks) 87 | 88 | self.config.tracker.close() 89 | 90 | 91 | async def get_summary_embedding(self): 92 | tasks = [] 93 | self.config.tracker.set(math.ceil(len(self.high_level_elements)/self.config.embedding_batch_size),'High Level Element Embedding') 94 | 95 | for i in range(0,len(self.high_level_elements),self.config.embedding_batch_size): 96 | high_level_element_batch = self.high_level_elements[i:i+self.config.embedding_batch_size] 97 | tasks.append(self.embedding_store(high_level_element_batch)) 98 | await asyncio.gather(*tasks) 99 | self.config.tracker.close() 100 | 101 | async def embedding_store(self,high_level_element_batch:list[High_level_elements]): 102 | 103 | context = [high_level_element.context for high_level_element in high_level_element_batch] 104 | embedding = await self.config.embedding_client(context) 105 | 106 | for i in range(len(high_level_element_batch)): 107 | high_level_element_batch[i].store_embedding(embedding[i]) 108 | self.config.tracker.update() 109 | 110 | 111 | async def high_level_element_summary(self): 112 | results = [] 113 | 114 | with open(self.config.summary_path, 'r', encoding='utf-8') as f: 115 | for line in f: 116 | line = json.loads(line) 117 | results.append(line) 118 | 119 | All_nodes = [] 120 | self.config.tracker.set(len(results),'High Level Element Summary') 121 | for result in results: 122 | high_level_elements = [] 123 | node_names = result['community'] 124 | for high_level_element in result['response']['high_level_elements']: 125 | he = High_level_elements(high_level_element['description'],high_level_element['title'],self.config) 126 | he.related_node(node_names) 127 | if self.G.has_node(he.hash_id): 128 | self.G.nodes[he.hash_id]['weight'] += 1 129 | if self.G.has_node(he.title_hash_id): 130 | self.G.nodes[he.title_hash_id]['weight'] += 1 131 | else: 132 | continue 133 | 134 | else: 135 | self.G.add_node(he.hash_id, type='high_level_element', weight=1) 136 | self.G.add_node(he.title_hash_id, type='high_level_element_title', weight=1, related_node=he.hash_id) 137 | high_level_elements.append(he) 138 | 139 | edge = (he.hash_id,he.title_hash_id) 140 | 141 | if not self.G.has_edge(*edge): 142 | self.G.add_edge(*edge,weight=1) 143 | 144 | All_nodes.extend(node_names) 145 | self.high_level_elements.extend(high_level_elements) 146 | self.config.tracker.update() 147 | self.config.tracker.close() 148 | await self.get_summary_embedding() 149 | 150 | 151 | centroids = math.ceil(math.sqrt(len(All_nodes)+len(self.high_level_elements))) 152 | threshold = (len(All_nodes)+len(self.high_level_elements))/centroids 153 | n=0 154 | if threshold > self.config.Hcluster_size: 155 | embedding_list = np.array([self.mapper.embeddings[node] for node in All_nodes], dtype=np.float32) 156 | high_level_element_embedding = np.array([he.embedding for he in self.high_level_elements], dtype=np.float32) 157 | all_embeddings = np.vstack([high_level_element_embedding, embedding_list]) 158 | 159 | kmeans = faiss.Kmeans(d=all_embeddings.shape[1], k=centroids) 160 | kmeans.train(all_embeddings.astype(np.float32)) 161 | _, cluster_labels = kmeans.assign(all_embeddings.astype(np.float32)) 162 | high_level_element_cluster_labels = cluster_labels[:len(self.high_level_elements)] 163 | embedding_cluster_labels = cluster_labels[len(self.high_level_elements):] 164 | 
self.config.console.print(f'[bold green]KMeans Clustering with {centroids} centroids[/bold green]') 165 | 166 | self.config.tracker.set(len(self.high_level_elements),'Adding High Level Element Summary') 167 | for i in range(len(self.high_level_elements)): 168 | for j in range(len(All_nodes)): 169 | if high_level_element_cluster_labels[i] == embedding_cluster_labels[j] and All_nodes[j] in self.high_level_elements[i].related_node: 170 | self.G.add_edge(All_nodes[j],self.high_level_elements[i].hash_id,weight=1) 171 | n+=1 172 | self.config.tracker.update() 173 | 174 | 175 | 176 | else: 177 | self.config.tracker.set(len(self.high_level_elements),'Adding High Level Element Summary') 178 | for he in self.high_level_elements: 179 | for node in he.related_node: 180 | self.G.add_edge(node,he.hash_id,weight=1) 181 | n+=1 182 | self.config.tracker.update() 183 | 184 | self.config.tracker.close() 185 | self.config.console.print(f'[bold green]Added {n} edges[/bold green]') 186 | 187 | 188 | 189 | 190 | 191 | 192 | 193 | 194 | 195 | 196 | 197 | 198 | def store_graph(self): 199 | storage(self.G).save_pickle(self.config.graph_path) 200 | self.config.console.print('[bold green]Graph stored[/bold green]') 201 | 202 | def delete_community_cache(self): 203 | os.remove(self.config.summary_path) 204 | 205 | def store_high_level_elements(self): 206 | 207 | high_level_elements = [] 208 | titles = [] 209 | embedding_list = [] 210 | for high_level_element in self.high_level_elements: 211 | high_level_elements.append({'type':'high_level_element', 212 | 'title_hash_id':high_level_element.title_hash_id, 213 | 'context':high_level_element.context, 214 | 'hash_id':high_level_element.hash_id, 215 | 'human_readable_id':high_level_element.human_readable_id, 216 | 'related_nodes':list(self.G.neighbors(high_level_element.hash_id)), 217 | 'embedding':'done'}) 218 | 219 | titles.append({'type':'high_level_element_title', 220 | 'hash_id':high_level_element.title_hash_id, 221 | 'context':high_level_element.title, 222 | 'human_readable_id':high_level_element.human_readable_id}) 223 | 224 | embedding_list.append({'hash_id':high_level_element.hash_id, 225 | 'embedding':high_level_element.embedding}) 226 | G_high_level_elements = [node for node in self.G.nodes if self.G.nodes[node].get('type') == 'high_level_element'] 227 | assert len(high_level_elements) == len(G_high_level_elements), f"The number of high level elements is not equal to the number of nodes in the graph. 
{len(high_level_elements)} != {len(G_high_level_elements)}"
228 | 
229 |         storage(high_level_elements).save_parquet(self.config.high_level_elements_path,append = os.path.exists(self.config.high_level_elements_path))
230 |         storage(titles).save_parquet(self.config.high_level_elements_titles_path,append = os.path.exists(self.config.high_level_elements_titles_path))
231 |         storage(embedding_list).save_parquet(self.config.embedding,append = os.path.exists(self.config.embedding))
232 |         self.config.console.print('[bold green]High level elements stored[/bold green]')
233 | 
234 |     @info_timer(message='Summary Generation Pipeline')
235 |     async def main(self):
236 |         if os.path.exists(self.config.graph_path):
237 |             if os.path.exists(self.config.summary_path):
238 |                 os.remove(self.config.summary_path)
239 |             await self.generate_high_level_element_summary()
240 |             await self.high_level_element_summary()
241 |             self.store_high_level_elements()
242 |             self.store_graph()
243 |             self.indices.store_all_indices(self.config.indices_path)
244 |             self.delete_community_cache()
245 | 
246 | 
247 | 
--------------------------------------------------------------------------------
/NodeRAG/build/pipeline/text_pipeline.py:
--------------------------------------------------------------------------------
1 | from typing import Dict,List
2 | import pandas as pd
3 | import asyncio
4 | import os
5 | import json
6 | 
7 | from ...config import NodeConfig
8 | from ...LLM import LLM_message
9 | from ...storage import storage
10 | from ..component import Text_unit
11 | from ...logging.error import clear_cache
12 | from ...logging import info_timer
13 | 
14 | class text_pipline():
15 | 
16 |     def __init__(self, config:NodeConfig)-> None:
17 | 
18 |         self.config = config
19 |         self.texts = self.load_texts()
20 | 
21 | 
22 |     def load_texts(self) -> pd.DataFrame:
23 | 
24 |         texts = storage.load_parquet(self.config.text_path)
25 |         return texts
26 | 
27 |     async def text_decomposition_pipline(self) -> None:
28 | 
29 |         async_task = []
30 |         self.config.tracker.set(len(self.texts),'Text Decomposition')
31 | 
32 |         for index, row in self.texts.iterrows():
33 |             text = Text_unit(row['context'],row['hash_id'],row['text_id'])
34 |             async_task.append(text.text_decomposition(self.config))
35 |         await asyncio.gather(*async_task)
36 | 
37 | 
38 |     def increment(self) -> None:
39 | 
40 |         exist_hash_id = []
41 |         # the decomposition cache stores the text hash under 'text_hash_id' (see Text_unit.text_decomposition)
42 |         with open(self.config.text_decomposition_path,'r',encoding='utf-8') as f:
43 |             for line in f:
44 |                 line = json.loads(line)
45 |                 exist_hash_id.append(line['text_hash_id'])
46 |         self.texts = self.texts[~self.texts['hash_id'].isin(exist_hash_id)]
47 | 
48 |     async def rerun(self) -> None:
49 | 
50 |         self.texts = self.load_texts()
51 | 
52 |         with open(self.config.LLM_error_cache,'r',encoding='utf-8') as f:
53 |             LLM_store = []
54 |             for line in f:
55 |                 line = json.loads(line)
56 |                 LLM_store.append(line)
57 | 
58 |         clear_cache(self.config.LLM_error_cache)
59 | 
60 |         await self.rerun_request(LLM_store)
61 |         self.config.tracker.close()
62 |         await self.text_decomposition_pipline()
63 | 
64 |     async def rerun_request(self,LLM_store:List[Dict]) -> None:
65 |         tasks = []
66 | 
67 |         self.config.tracker.set(len(LLM_store),'Rerun LLM on error cache of text decomposition pipeline')
68 | 
69 |         for store in LLM_store:
70 |             input_data = store['input_data']
71 |             store.pop('input_data')
72 |             input_data.update({'response_format':self.config.prompt_manager.text_decomposition_json})  # the JSON schema, not the prompt template
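73 |             # after popping input_data, the remaining record is the meta data (text_hash_id, text_id) used for error tracking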
74 |             tasks.append(self.request_save(input_data,store))
75 |         await asyncio.gather(*tasks)
76 | 
77 |     async def request_save(self,
78 |                            input_data:LLM_message,
79 |                            meta_data:Dict) -> None:
80 | 
81 |         response = await self.config.API_client(input_data,cache_path=self.config.LLM_error_cache,meta_data = meta_data)
82 |         if response == 'Error cached':
83 |             return
84 | 
85 |         with open(self.config.text_decomposition_path,'a',encoding='utf-8') as f:
86 |             # keep the same record format that Text_unit.text_decomposition writes
87 |             f.write(json.dumps({**meta_data,'response':response},ensure_ascii=False)+'\n')
88 | 
89 |         self.config.tracker.update()
90 | 
91 |     def check_error_cache(self) -> None:
92 | 
93 |         if os.path.exists(self.config.LLM_error_cache):
94 |             num = 0
95 | 
96 |             with open(self.config.LLM_error_cache,'r',encoding='utf-8') as f:
97 |                 for line in f:
98 |                     num += 1
99 | 
100 |             if num > 0:
101 |                 self.config.console.print(f"[red]LLM error detected, there are {num} errors")
102 |                 self.config.console.print("[red]Please check the error log")
103 |                 self.config.console.print("[red]The error cache is named LLM_error.jsonl, stored in the cache folder")
104 |                 self.config.console.print("[red]Please fix the errors and run the pipeline again")
105 |                 raise Exception("Error happened in text decomposition pipeline, Error cached.")
106 | 
107 |     @info_timer(message='Text Pipeline')
108 |     async def main(self) -> None:
109 | 
110 |         if os.path.exists(self.config.text_decomposition_path):
111 |             if os.path.getsize(self.config.text_decomposition_path) > 0:
112 |                 self.increment()
113 | 
114 |         await self.text_decomposition_pipline()
115 |         self.config.tracker.close()
116 |         self.check_error_cache()
117 | 
118 | 
--------------------------------------------------------------------------------
/NodeRAG/config/Node_config.py:
--------------------------------------------------------------------------------
1 | import time
2 | import os
3 | from ..logging import setup_logger
4 | import shutil
5 | import yaml
6 | from typing import Dict,Any
7 | 
8 | 
9 | from ..utils import (
10 |     index_manager,
11 |     prompt_manager,
12 |     YamlHandler
13 | )
14 | 
15 | 
16 | from ..utils import (
17 |     Tracker,
18 |     rich_console,
19 |     SemanticTextSplitter
20 | )
21 | from ..LLM import (
22 |     set_api_client,
23 |     set_embedding_client,
24 |     API_client
25 | )
26 | 
27 | from ..build.component import text_unit_index_counter
28 | from ..build.component import document_index_counter
29 | from ..build.component import semantic_unit_index_counter
30 | from ..build.component import entity_index_counter
31 | from ..build.component import relation_index_counter
32 | from ..build.component import attribute_index_counter
33 | from ..build.component import community_summary_index_counter,high_level_element_index_counter
34 | 
35 | 
36 | 
37 | class NodeConfig():
38 |     _instance = None
39 | 
40 |     def __new__(cls,config:dict):
41 |         if cls._instance is None:
42 |             cls._instance = super(NodeConfig,cls).__new__(cls)
43 |             cls._instance.config = config
44 |         return cls._instance
45 | 
46 | 
47 | 
48 |     def __init__(self,config:Dict[str,Any]):
49 | 
50 | 
51 |         self.config = config['config']
52 |         self.main_folder = self.config.get('main_folder')
53 |         if self.main_folder is None:
54 |             raise ValueError('main_folder is not set')
55 | 
56 |         if not os.path.exists(self.main_folder):
57 |             raise ValueError(f'main_folder {self.main_folder} does not exist')
58 | 
59 |         self.input_folder = self.main_folder + '/input'
60 |         self.cache = self.main_folder + '/cache'
61 |         self.info = self.main_folder + '/info'
62 | 
63 |         self.embedding_path = self.cache + '/embedding.parquet'
64 |         self.text_path = self.cache + '/text.parquet'
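65 |         # every intermediate artifact lives under <main_folder>/cache; these paths are the contract between pipeline stages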
66 |         self.documents_path = self.cache + '/documents.parquet'
67 |         self.text_decomposition_path = self.cache + '/text_decomposition.jsonl'
68 |         self.semantic_units_path = self.cache + '/semantic_units.parquet'
69 |         self.entities_path = self.cache + '/entities.parquet'
70 |         self.relationship_path = self.cache + '/relationship.parquet'
71 |         self.graph_path = self.cache + '/new_graph.pkl'
72 |         self.attributes_path = self.cache + '/attributes.parquet'
73 |         self.embedding_cache = self.cache + '/embedding_cache.jsonl'
74 |         self.embedding = self.cache + '/embedding.parquet'
75 |         self.base_graph_path = self.cache + '/graph.pkl'
76 |         self.summary_path = self.cache + '/community_summary.jsonl'
77 |         self.high_level_elements_path = self.cache + '/high_level_elements.parquet'
78 |         self.high_level_elements_titles_path = self.cache + '/high_level_elements_titles.parquet'
79 |         self.HNSW_path = self.cache + '/HNSW.bin'
80 |         self.hnsw_graph_path = self.cache + '/hnsw_graph.pkl'
81 |         self.id_map_path = self.cache + '/id_map.parquet'
82 |         self.LLM_error_cache = self.cache + '/LLM_error.jsonl'
83 | 
84 | 
85 |         self.embedding_batch_size = self.config.get('embedding_batch_size',50)
86 |         self._m = self.config.get('m',5)
87 |         self._ef = self.config.get('ef',200)
88 |         self._m0 = self.config.get('m0',None)
89 |         self.space = self.config.get('space','l2')
90 |         self.dim = self.config.get('dim',1536)
91 |         self.docu_type = self.config.get('docu_type','mixed')
92 | 
93 |         self.Hcluster_size = self.config.get('Hcluster_size',39)
94 |         self.cross_node = self.config.get('cross_node',10)
95 |         self.Enode = self.config.get('Enode',10)
96 |         self.Rnode = self.config.get('Rnode',10)
97 |         self.Hnode = self.config.get('Hnode',10)
98 | 
99 |         self.HNSW_results = self.config.get('HNSW_results',10)
100 |         self.similarity_weight = self.config.get('similarity_weight',1)
101 |         self.accuracy_weight = self.config.get('accuracy_weight',10)
102 |         self.ppr_alpha = self.config.get('ppr_alpha',0.5)
103 |         self.ppr_max_iter = self.config.get('ppr_max_iter',8)
104 |         self.unbalance_adjust = self.config.get('unbalance_adjust',False)
105 | 
106 | 
107 |         self.indices_path = self.info + '/indices.json'
108 |         self.state_path = self.info + '/state.json'
109 |         self.document_hash_path = self.info + '/document_hash.json'
110 |         self.info_path = self.info + '/info.log'
111 |         if not os.path.exists(self.info):
112 |             os.makedirs(self.info)
113 |         if not os.path.exists(self.info_path):
114 |             with open(self.info_path,'w') as f:
115 |                 f.write('')
116 |         self.info_logger = setup_logger('info_logger',self.info_path)
117 |         self.timer = []
118 |         self.tracker = Tracker(self.cache,use_rich=True)
119 |         self.rich_console = rich_console()
120 |         self.console = self.rich_console.console
121 |         self.indices = self.load_indices()
122 | 
123 | 
124 | 
125 |         self._model_config = config['model_config']
126 |         self._embedding_config = config['embedding_config']
127 |         self._language = self.config['language']
128 | 
129 |         try:
130 |             self.API_client = set_api_client(API_client(self.model_config))
131 |         except:
132 |             self.API_client = None
133 | 
134 |         try:
135 |             self.embedding_client = set_embedding_client(API_client(self.embedding_config))
136 |         except:
137 |             self.embedding_client = None
138 | 
139 | 
140 | 
141 | 
142 | 
143 | 
144 |         self.semantic_text_splitter = SemanticTextSplitter(self.config['chunk_size'],self.model_config['model_name'])
145 |         self.token_counter = 
self.semantic_text_splitter.token_counter 146 | 147 | 148 | 149 | 150 | 151 | self.prompt_manager = prompt_manager(self._language) 152 | 153 | 154 | 155 | 156 | @property 157 | def model_config(self): 158 | return self._model_config 159 | 160 | @property 161 | def embedding_config(self): 162 | return self._embedding_config 163 | 164 | @embedding_config.setter 165 | def embedding_config(self,embedding_config:dict): 166 | self._embedding_config = embedding_config 167 | try: 168 | self.embedding_client = set_embedding_client(API_client(self.embedding_config)) 169 | except: 170 | self.embedding_client = None 171 | self.console.print(f'warning: embedding_config is not valid') 172 | 173 | 174 | @model_config.setter 175 | def model_config(self,model_config:dict): 176 | self._model_config = model_config 177 | try: 178 | self.API_client = set_api_client(API_client(self.model_config)) 179 | self.semantic_text_splitter = SemanticTextSplitter(self.config['chunk_size'],self.model_config['model_name']) 180 | self.token_counter = self.semantic_text_splitter.token_counter 181 | except: 182 | self.API_client = None 183 | self.semantic_text_splitter = None 184 | self.token_counter = None 185 | self.console.print(f'warning: model_config is not valid') 186 | 187 | @property 188 | def language(self): 189 | return self._language 190 | 191 | @language.setter 192 | def language(self,language:str): 193 | self._language = language 194 | self.prompt_manager = prompt_manager(self._language) 195 | self.console.print(f'language set to {self._language}') 196 | 197 | 198 | def load_indices(self) -> index_manager: 199 | if os.path.exists(self.indices_path): 200 | return index_manager.load_indices(self.indices_path,self.console) 201 | else: 202 | return index_manager([document_index_counter, 203 | text_unit_index_counter, 204 | semantic_unit_index_counter, 205 | entity_index_counter, 206 | relation_index_counter, 207 | attribute_index_counter, 208 | community_summary_index_counter, 209 | high_level_element_index_counter],self.console) 210 | 211 | 212 | def store_readable_index(self) -> None: 213 | 214 | self.indices.store_all_indices(self.indices_path) 215 | 216 | 217 | def update_model_config(self,model_config:dict): 218 | self.model_config.update(model_config) 219 | 220 | def update_embedding_config(self,embedding_config:dict): 221 | self.embedding_config.update(embedding_config) 222 | 223 | def update_settings(self,settings:dict): 224 | self.config.update(settings) 225 | 226 | def config_integrity(self): 227 | if self.API_client is None: 228 | print(self.model_config) 229 | raise ValueError('API_client is not set properly') 230 | if self.embedding_client is None: 231 | raise ValueError('embedding_client is not set properly') 232 | if self.semantic_text_splitter is None: 233 | raise ValueError('semantic_text_splitter is not set properly') 234 | if not os.path.exists(self.main_folder): 235 | raise ValueError('main_folder does not exist') 236 | 237 | def record_info(self,message:str) -> None: 238 | 239 | self.info_logger.info(message) 240 | 241 | def start_timer(self,message:str): 242 | 243 | self.timer.append(time.time()) 244 | self.info_logger.info(message) 245 | 246 | def time_record(self): 247 | 248 | now = time.time() 249 | time_spent = now - self.timer[-1] 250 | self.timer.append(now) 251 | 252 | return time_spent 253 | 254 | def whole_time(self): 255 | 256 | if len(self.timer) > 1: 257 | self.record_info(f'Total time spent: {self.timer[-1] - self.timer[0]} seconds') 258 | 259 | else: 260 | self.record_info('No time 
record') 261 | 262 | def record_message_with_time(self,message:str): 263 | 264 | time_spent = self.time_record() 265 | self.record_info(f'{message}, Time spent: {time_spent} seconds') 266 | 267 | @staticmethod 268 | def create_config_file(main_folder:str): 269 | 270 | 271 | config_path = os.path.join(main_folder,'Node_config.yaml') 272 | if not os.path.exists(config_path): 273 | shutil.copyfile(os.path.join(os.path.dirname(__file__),'Node_config.yaml'),config_path) 274 | yaml_handler = YamlHandler(config_path) 275 | yaml_handler.update_config(['config','main_folder'],main_folder) 276 | yaml_handler.save() 277 | print(f'Config file created at {config_path}') 278 | else: 279 | print(f'Config file already exists at {config_path}') 280 | 281 | return config_path 282 | 283 | 284 | 285 | @classmethod 286 | def from_main_folder(cls, main_folder: str): 287 | 288 | config_path = cls.create_config_file(main_folder) 289 | 290 | 291 | with open(config_path,'r') as f: 292 | config = yaml.safe_load(f) 293 | 294 | return cls(config) 295 | 296 | 297 | 298 | 299 | 300 | 301 | 302 | 303 | 304 | 305 | 306 | 307 | 308 | 309 | -------------------------------------------------------------------------------- /NodeRAG/config/Node_config.yaml: -------------------------------------------------------------------------------- 1 | #============================================================================== 2 | # AI Model Configuration 3 | #============================================================================== 4 | model_config: 5 | service_provider: openai # AI service provider (e.g., openai, gemini) 6 | model_name: gpt-4o-mini # Model name for text generation 7 | api_keys: ~ # Your API key (optional) 8 | temperature: 0 # Temperature parameter for text generation 9 | max_tokens: 10000 # Maximum tokens to generate 10 | rate_limit: 40 # API rate limit (requests per second) 11 | 12 | embedding_config: 13 | service_provider: openai_embedding # Embedding service provider 14 | embedding_model_name: text-embedding-3-small # Model name for text embeddings 15 | api_keys: ~ # Your API key (optional) 16 | rate_limit: 20 # Rate limit for embedding requests 17 | 18 | 19 | #============================================================================== 20 | # Document Processing Configuration 21 | #============================================================================== 22 | config: 23 | # Basic Settings 24 | main_folder: ~ # Root folder for document processing 25 | language: English # Document processing language 26 | docu_type: mixed # Document type (mixed, pdf, txt, etc.) 
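  # Example (illustrative values only; adjust to your own setup):
  #   main_folder: /path/to/your/project_folder
  #   language: Chinese
  #   docu_type: txt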
27 | 28 | # Chunking Settings 29 | chunk_size: 1048 # Size of text chunks for processing 30 | embedding_batch_size: 50 # Batch size for embedding processing 31 | 32 | # UI Settings 33 | use_tqdm: False # Enable/disable progress bars 34 | use_rich: True # Enable/disable rich text formatting 35 | 36 | # HNSW Index Settings 37 | space: l2 # Distance metric for HNSW (l2, cosine) 38 | dim: 1536 # Embedding dimension (must match embedding model) 39 | m: 50 # Number of connections per layer in HNSW 40 | ef: 200 # Size of dynamic candidate list in HNSW 41 | m0: ~ # Number of bi-directional links in HNSW 42 | 43 | # Summary Settings 44 | Hcluster_size: 39 # Number of clusters for high-level element matching 45 | 46 | # Search Server Settings 47 | url: '127.0.0.1' # Server URL for search service 48 | port: 5000 # Server port number 49 | unbalance_adjust: True # Enable adjustment for unbalanced data 50 | cross_node: 10 # Number of cross nodes to return 51 | Enode: 10 # Number of entity nodes to return 52 | Rnode: 30 # Number of relationship nodes to return 53 | Hnode: 10 # Number of high-level nodes to return 54 | HNSW_results: 10 # Number of HNSW search results 55 | similarity_weight: 1 # Weight for similarity in personalized PageRank 56 | accuracy_weight: 1 # Weight for accuracy in personalized PageRank 57 | ppr_alpha: 0.5 # Damping factor for personalized PageRank 58 | ppr_max_iter: 2 # Maximum iterations for personalized PageRank 59 | -------------------------------------------------------------------------------- /NodeRAG/config/__init__.py: -------------------------------------------------------------------------------- 1 | from .Node_config import NodeConfig 2 | 3 | __all__ = ['NodeConfig'] -------------------------------------------------------------------------------- /NodeRAG/config/__main__.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import shutil 4 | from ..utils import YamlHandler 5 | 6 | args = argparse.ArgumentParser() 7 | args.add_argument('-f','--folder',type=str,required=True) 8 | 9 | 10 | args = args.parse_args() 11 | 12 | config_path = os.path.join(args.folder,'Node_config.yaml') 13 | input_folder = os.path.join(args.folder,'input') 14 | if not os.path.exists(input_folder): 15 | os.makedirs(input_folder) 16 | 17 | if not os.path.exists(config_path): 18 | 19 | 20 | shutil.copyfile(os.path.join(os.path.dirname(__file__),'Node_config.yaml'),config_path) 21 | yaml_handler = YamlHandler(config_path) 22 | yaml_handler.update_config(['config','main_folder'],args.folder) 23 | yaml_handler.save() 24 | print(f'Config file created at {config_path}') 25 | 26 | 27 | else: 28 | print(f'Config file already exists at {config_path}') 29 | -------------------------------------------------------------------------------- /NodeRAG/logging/__init__.py: -------------------------------------------------------------------------------- 1 | from .logger import setup_logger 2 | from .info_timer import info_timer 3 | 4 | __all__ = ['setup_logger','info_timer'] 5 | -------------------------------------------------------------------------------- /NodeRAG/logging/error.py: -------------------------------------------------------------------------------- 1 | from functools import wraps 2 | from .logger import setup_logger 3 | import json 4 | import os 5 | 6 | error_logger = setup_logger(__name__,os.path.join(os.getcwd(),'error.log')) 7 | 8 | def error_handler(func): 9 | @wraps(func) 10 | def wrapper(*args, **kwargs): 11 | try: 12 | 
return func(*args, **kwargs) 13 | except Exception as e: 14 | return str(e) 15 | return wrapper 16 | 17 | def error_handler_async(func): 18 | @wraps(func) 19 | async def wrapper(*args, **kwargs): 20 | try: 21 | return await func(*args, **kwargs) 22 | except Exception as e: 23 | return str(e) 24 | return wrapper 25 | 26 | def cache_error(func): 27 | @wraps(func) 28 | def wrapper(*args, **kwargs): 29 | response = func(*args, **kwargs) 30 | 31 | if isinstance(response, list): 32 | return response 33 | 34 | if isinstance(response, str): 35 | if kwargs.get('cache_path'): 36 | if "'error':" in response.lower(): 37 | print(f'Error happened: {response}') 38 | error_logger.error(response) 39 | 40 | meta_data = kwargs.get('meta_data',None) 41 | 42 | if meta_data is not None: 43 | cache_path = kwargs.get('cache_path') 44 | 45 | input_data = args[1] 46 | if input_data is None: 47 | input_data = kwargs.get('input',None) 48 | if isinstance(input_data,dict): 49 | if input_data.get('response_format',None) is not None: 50 | input_data.pop('response_format') 51 | LLM_store = {'input':input_data,'meta_data':meta_data} 52 | with open(cache_path,'a') as f: 53 | f.write(json.dumps(LLM_store)+'\n') 54 | response = 'Error cached' 55 | if response == 'Error cached': 56 | return response 57 | else: 58 | raise Exception('Error happened, please check the error log.') 59 | return response 60 | 61 | return wrapper 62 | 63 | def cache_error_async(func): 64 | @wraps(func) 65 | async def wrapper(*args, **kwargs): 66 | response = await func(*args, **kwargs) 67 | if isinstance(response, str): 68 | if kwargs.get('cache_path'): 69 | if "'error':" in response.lower(): 70 | # log errors 71 | error_logger.error(response) 72 | 73 | 74 | meta_data = kwargs.get('meta_data',None) 75 | 76 | if meta_data is not None: 77 | if kwargs.get('cache_path',None) is not None: 78 | cache_path = kwargs.get('cache_path') 79 | 80 | input_data = args[1] 81 | if input_data is None: 82 | input_data = kwargs.get('input',None) 83 | if isinstance(input_data,dict) and input_data.get('response_format',None) is not None: 84 | input_data.pop('response_format') 85 | LLM_store = {'input':input_data,'meta_data':meta_data} 86 | with open(cache_path,'a') as f: 87 | f.write(json.dumps(LLM_store)+'\n') 88 | response = 'Error cached' 89 | if response == 'Error cached': 90 | return response 91 | else: 92 | raise Exception('Error happened, please check the error log.') 93 | return response 94 | 95 | return wrapper 96 | 97 | def clear_cache(path:str) -> str: 98 | with open(path,'w') as f: 99 | f.write('') 100 | return 'cache cleared' -------------------------------------------------------------------------------- /NodeRAG/logging/info_timer.py: -------------------------------------------------------------------------------- 1 | from functools import wraps 2 | 3 | def info_timer(message:str): 4 | def decorator(func): 5 | @wraps(func) 6 | async def wrapper(self,*args,**kwargs): 7 | self.config.start_timer(f'{message} Started') 8 | result = await func(self,*args,**kwargs) 9 | self.config.record_message_with_time(f'{message} Finished') 10 | return result 11 | return wrapper 12 | return decorator 13 | -------------------------------------------------------------------------------- /NodeRAG/logging/logger.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | def setup_logger(name, log_file, level=logging.INFO): 4 | """Function to set up a file logger""" 5 | logger = logging.getLogger(name) 6 | logger.setLevel(level) 7 | 8 | 9 | file_handler = 
logging.FileHandler(log_file) 10 | file_handler.setLevel(level) 11 | 12 | 13 | formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') 14 | file_handler.setFormatter(formatter) 15 | 16 | 17 | if not logger.handlers: 18 | logger.addHandler(file_handler) 19 | 20 | return logger -------------------------------------------------------------------------------- /NodeRAG/search/Answer_base.py: -------------------------------------------------------------------------------- 1 | from ..config import NodeConfig 2 | 3 | class Retrieval(): 4 | 5 | def __init__(self,config:NodeConfig,id_to_text:dict,accurate_id_to_text:dict,id_to_type:dict): 6 | 7 | self.config = config 8 | self.HNSW_results_with_distance = None 9 | self._HNSW_results = None 10 | self.id_to_text = id_to_text 11 | self.accurate_id_to_text = accurate_id_to_text 12 | self.accurate_results = None 13 | self.search_list = [] 14 | self.unique_search_list = set() 15 | self.id_to_type = id_to_type 16 | self.relationship_list = None 17 | self._retrieved_list = None 18 | self._structured_prompt = None 19 | self._unstructured_prompt = None 20 | 21 | 22 | 23 | @property 24 | def HNSW_results(self): 25 | if self._HNSW_results is None: 26 | self._HNSW_results = [id for distance,id in self.HNSW_results_with_distance] 27 | self.search_list.extend(self._HNSW_results) 28 | self.unique_search_list.update(self._HNSW_results) 29 | return self._HNSW_results 30 | 31 | @property 32 | def model_name(self): 33 | return self.config.API_client.llm.model_name 34 | 35 | @property 36 | def HNSW_results_str(self): 37 | return [self.id_to_text[id] for id in self.HNSW_results] 38 | 39 | @property 40 | def accurate_results_str(self): 41 | return [self.accurate_id_to_text[id] for id in self.accurate_results] 42 | 43 | @property 44 | def retrieved_list(self): 45 | if self._retrieved_list is None: 46 | self._retrieved_list = [(self.id_to_text[id],self.id_to_type[id]) for id in self.search_list]+ [(self.id_to_text[id],'relationship') for id in self.relationship_list] 47 | return self._retrieved_list 48 | 49 | @property 50 | def structured_prompt(self): 51 | if self._structured_prompt is None: 52 | self._structured_prompt = self.types_info() 53 | return self._structured_prompt 54 | 55 | @property 56 | def unstructured_prompt(self)->str: 57 | if self._unstructured_prompt is None: 58 | self._unstructured_prompt = '\n'.join([content for content,_ in self.retrieved_list]) 59 | return self._unstructured_prompt 60 | 61 | @property 62 | def retrieval_info(self)->str: 63 | return self.structured_prompt 64 | 65 | def types_info(self)->str: 66 | types = set([type for _,type in self.retrieved_list]) 67 | prompt = '' 68 | for type in types: 69 | prompt += f'------------{type}-------------\n' 70 | n=1 71 | for content,typed in self.retrieved_list: 72 | if typed == type: 73 | prompt += f'{n}. 
{content}\n' 74 | n+=1 75 | prompt += '\n\n' 76 | return prompt 77 | 78 | def __str__(self): 79 | return self.retrieval_info 80 | 81 | 82 | 83 | class Answer(): 84 | 85 | def __init__(self,query:str,retrieval:Retrieval): 86 | self.query = query 87 | self.retrieval = retrieval 88 | self.response = None 89 | 90 | @property 91 | def retrieval_info(self): 92 | return self.retrieval.retrieval_info 93 | 94 | @property 95 | def structured_prompt(self): 96 | return self.retrieval.structured_prompt 97 | 98 | @property 99 | def unstructured_prompt(self): 100 | return self.retrieval.unstructured_prompt 101 | 102 | @property 103 | def retrieval_tokens(self): 104 | return self.retrieval.config.token_counter(self.retrieval_info) 105 | 106 | @property 107 | def response_tokens(self): 108 | return self.retrieval.config.token_counter(self.response) 109 | 110 | def __str__(self): 111 | return self.response 112 | 113 | 114 | -------------------------------------------------------------------------------- /NodeRAG/search/__init__.py: -------------------------------------------------------------------------------- 1 | from .search import NodeSearch 2 | 3 | __all__ = ['NodeSearch'] -------------------------------------------------------------------------------- /NodeRAG/search/__main__.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from flask import Flask, request, jsonify 3 | import yaml 4 | from .search import NodeSearch 5 | from ..config import NodeConfig 6 | import os 7 | parser = argparse.ArgumentParser(description='NodeRAG search engine') 8 | parser.add_argument('-f','--folder_path', type=str, help='The folder path of the document') 9 | args = parser.parse_args() 10 | 11 | config_path = os.path.join(args.folder_path, 'Node_config.yaml') 12 | 13 | with open(config_path, 'r') as f: 14 | args.config = yaml.safe_load(f) 15 | 16 | model_config = args.config['model_config'] 17 | embedding_config = args.config['embedding_config'] 18 | 19 | 20 | document_config = args.config['config'] 21 | path = args.folder_path 22 | url = document_config.get('url','127.0.0.1') 23 | port = document_config.get('port',5000) 24 | 25 | Search_engine = NodeSearch(NodeConfig(args.config)) 26 | app = Flask(__name__) 27 | 28 | 29 | 30 | @app.route('/answer', methods=['POST']) 31 | def answer(): 32 | question = request.json['question'] 33 | answer = Search_engine.answer(question) 34 | return jsonify({'answer':answer.response}) 35 | 36 | @app.route('/answer_retrieval', methods=['POST']) 37 | def answer_retrieval(): 38 | question = request.json['question'] 39 | answer = Search_engine.answer(question) 40 | return jsonify({'answer':answer.response, 'retrieval':answer.retrieval_info}) 41 | 42 | @app.route('/retrieval', methods=['POST']) 43 | def search(): 44 | question = request.json['question'] 45 | retrieval = Search_engine.search(question) 46 | return jsonify({'retrieval':retrieval.retrieval_info}) 47 | 48 | if __name__ == '__main__': 49 | app.run(host=url, port=port,debug=False,threaded=True) 50 | 51 | -------------------------------------------------------------------------------- /NodeRAG/search/search.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import Dict,List,Tuple 3 | import numpy as np 4 | import re 5 | 6 | 7 | from ..storage import Mapper 8 | from ..utils import HNSW 9 | from ..storage import storage 10 | from ..utils.graph_operator import GraphConcat 11 | from ..config import NodeConfig 12 | from 
..utils.PPR import sparse_PPR 13 | from .Answer_base import Answer,Retrieval 14 | 15 | 16 | 17 | 18 | class NodeSearch(): 19 | 20 | def __init__(self,config:NodeConfig): 21 | 22 | 23 | self.config = config 24 | self.hnsw = self.load_hnsw() 25 | self.mapper = self.load_mapper() 26 | self.G = self.load_graph() 27 | self.id_to_type = {id:self.G.nodes[id].get('type') for id in self.G.nodes} 28 | self.id_to_text,self.accurate_id_to_text = self.mapper.generate_id_to_text(['entity','high_level_element_title']) 29 | self.sparse_PPR = sparse_PPR(self.G) 30 | self._semantic_units = None 31 | 32 | 33 | def load_mapper(self) -> Mapper: 34 | 35 | mapping_list = [self.config.semantic_units_path, 36 | self.config.entities_path, 37 | self.config.relationship_path, 38 | self.config.attributes_path, 39 | self.config.high_level_elements_path, 40 | self.config.text_path, 41 | self.config.high_level_elements_titles_path] 42 | 43 | for path in mapping_list: 44 | if not os.path.exists(path): 45 | raise Exception(f'{path} not found. Please check cache integrity; you may need to rebuild the database due to the loss of cache files.') 46 | 47 | mapper = Mapper(mapping_list) 48 | 49 | return mapper 50 | 51 | def load_hnsw(self) -> HNSW: 52 | if os.path.exists(self.config.HNSW_path): 53 | hnsw = HNSW(self.config) 54 | hnsw.load_HNSW() 55 | return hnsw 56 | else: 57 | raise Exception('No HNSW data found.') 58 | 59 | def load_graph(self): 60 | 61 | if os.path.exists(self.config.base_graph_path): 62 | G = storage.load(self.config.base_graph_path) 63 | else: 64 | raise Exception('No base graph found.') 65 | 66 | if os.path.exists(self.config.hnsw_graph_path): 67 | HNSW_graph = storage.load(self.config.hnsw_graph_path) 68 | else: 69 | raise Exception('No HNSW graph found.') 70 | 71 | if self.config.unbalance_adjust: 72 | G = GraphConcat(G).concat(HNSW_graph) 73 | return GraphConcat.unbalance_adjust(G) 74 | 75 | return GraphConcat(G).concat(HNSW_graph) 76 | 77 | 78 | def search(self,query:str): 79 | 80 | retrieval = Retrieval(self.config,self.id_to_text,self.accurate_id_to_text,self.id_to_type) 81 | 82 | 83 | # HNSW search for entry points by cosine similarity 84 | query_embedding = np.array(self.config.embedding_client.request(query),dtype=np.float32) 85 | HNSW_results = self.hnsw.search(query_embedding,HNSW_results=self.config.HNSW_results) 86 | retrieval.HNSW_results_with_distance = HNSW_results 87 | 88 | 89 | 90 | # Decompose the query into entities and run an accurate (exact-match) search for short, word-level items. 
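        # (e.g. the query "Who partnered with Dr. Emily Roberts in Paris?" might decompose
        # into entities such as ["DR. EMILY ROBERTS", "PARIS"], which accurate_search below
        # then matches verbatim, as whole phrases, against entity and title texts)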
91 | decomposed_entities = self.decompose_query(query) 92 | 93 | accurate_results = self.accurate_search(decomposed_entities) 94 | retrieval.accurate_results = accurate_results 95 | 96 | # Personalization for graph search 97 | personalization = {ids:self.config.similarity_weight for ids in retrieval.HNSW_results} 98 | personalization.update({id:self.config.accuracy_weight for id in retrieval.accurate_results}) 99 | 100 | weighted_nodes = self.graph_search(personalization) 101 | 102 | retrieval = self.post_process_top_k(weighted_nodes,retrieval) 103 | 104 | return retrieval 105 | 106 | def decompose_query(self,query:str): 107 | 108 | query = self.config.prompt_manager.decompose_query.format(query=query) 109 | response = self.config.API_client.request({'query':query,'response_format':self.config.prompt_manager.decomposed_text_json}) 110 | return response['elements'] 111 | 112 | 113 | def accurate_search(self, entities: List[str]) -> List[str]: 114 | accurate_results = [] 115 | 116 | for entity in entities: 117 | # Split entity into words and create a pattern to match the whole phrase 118 | words = entity.lower().split() 119 | pattern = re.compile(r'\b' + r'\s+'.join(map(re.escape, words)) + r'\b') 120 | result = [id for id, text in self.accurate_id_to_text.items() if pattern.search(text.lower())] 121 | if result: 122 | accurate_results.extend(result) 123 | 124 | return accurate_results 125 | 126 | 127 | def answer(self,query:str,id_type:bool=True): 128 | 129 | 130 | retrieval = self.search(query) 131 | 132 | ans = Answer(query,retrieval) 133 | 134 | if id_type: 135 | retrieved_info = ans.structured_prompt 136 | else: 137 | retrieved_info = ans.unstructured_prompt 138 | 139 | query = self.config.prompt_manager.answer.format(info=retrieved_info,query=query) 140 | ans.response = self.config.API_client.request({'query':query}) 141 | 142 | return ans 143 | 144 | 145 | 146 | async def answer_async(self,query:str,id_type:bool=True): 147 | 148 | 149 | retrieval = self.search(query) 150 | 151 | ans = Answer(query,retrieval) 152 | 153 | if id_type: 154 | retrieved_info = ans.structured_prompt 155 | else: 156 | retrieved_info = ans.unstructured_prompt 157 | 158 | query = self.config.prompt_manager.answer.format(info=retrieved_info,query=query) 159 | 160 | ans.response = await self.config.API_client({'query':query}) 161 | 162 | return ans 163 | 164 | 165 | def stream_answer(self,query:str,retrieved_info:str): 166 | 167 | query = self.config.prompt_manager.answer.format(info=retrieved_info,query=query) 168 | response = self.config.API_client.stream_chat({'query':query}) 169 | yield from response 170 | 171 | 172 | def graph_search(self,personalization:Dict[str,float]) -> List[str]: 173 | 174 | page_rank_scores = self.sparse_PPR.PPR(personalization,alpha=self.config.ppr_alpha,max_iter=self.config.ppr_max_iter) 175 | 176 | 177 | return [id for id,score in page_rank_scores] 178 | 179 | 180 | def post_process_top_k(self,weighted_nodes:List[str],retrieval:Retrieval)->Retrieval: 181 | 182 | 183 | entity_list = [] 184 | high_level_element_title_list = [] 185 | relationship_list = [] 186 | 187 | addition_node = 0 188 | 189 | for node in weighted_nodes: 190 | if node not in retrieval.search_list: 191 | type = self.G.nodes[node].get('type') 192 | match type: 193 | case 'entity': 194 | if node not in entity_list and len(entity_list) < self.config.Enode: 195 | entity_list.append(node) 196 | case 'relationship': 197 | if node not in relationship_list and len(relationship_list) < self.config.Rnode: 198 | 
relationship_list.append(node) 199 | case 'high_level_element_title': 200 | if node not in high_level_element_title_list and len(high_level_element_title_list) < self.config.Hnode: 201 | high_level_element_title_list.append(node) 202 | 203 | case _: 204 | if addition_node < self.config.cross_node: 205 | if node not in retrieval.unique_search_list: 206 | retrieval.search_list.append(node) 207 | retrieval.unique_search_list.add(node) 208 | addition_node += 1 209 | 210 | if (addition_node >= self.config.cross_node 211 | and len(entity_list) >= self.config.Enode 212 | and len(relationship_list) >= self.config.Rnode 213 | and len(high_level_element_title_list) >= self.config.Hnode): 214 | break 215 | 216 | for entity in entity_list: 217 | attributes = self.G.nodes[entity].get('attributes') 218 | if attributes: 219 | for attribute in attributes: 220 | if attribute not in retrieval.unique_search_list: 221 | retrieval.search_list.append(attribute) 222 | retrieval.unique_search_list.add(attribute) 223 | 224 | 225 | 226 | for high_level_element_title in high_level_element_title_list: 227 | related_node = self.G.nodes[high_level_element_title].get('related_node') 228 | if related_node not in retrieval.unique_search_list: 229 | retrieval.search_list.append(related_node) 230 | retrieval.unique_search_list.add(related_node) 231 | 232 | 233 | 234 | retrieval.relationship_list = list(set(relationship_list)) 235 | 236 | return retrieval 237 | 238 | -------------------------------------------------------------------------------- /NodeRAG/storage/__init__.py: -------------------------------------------------------------------------------- 1 | from .genid import genid 2 | from .storage import storage 3 | from .graph_mapping import Mapper 4 | 5 | __all__ = ['genid','storage','Mapper'] -------------------------------------------------------------------------------- /NodeRAG/storage/genid.py: -------------------------------------------------------------------------------- 1 | import uuid 2 | from hashlib import md5, sha256 3 | from typing import List 4 | from random import getrandbits 5 | 6 | def genid(input: List[str],type:str) -> str: 7 | match type: 8 | case "md5": 9 | return md5_hash(input) 10 | case "sha256": 11 | return sha256_hash(input) 12 | case "uuid": 13 | return uuid_hash(input) 14 | case _: 15 | raise ValueError("Type not supported") 16 | 17 | def md5_hash(input: List[str]) -> str: 18 | hashed = md5("".join(input).encode('utf-8')).hexdigest() 19 | return f'{hashed}' 20 | 21 | def sha256_hash(input: List[str]) -> str: 22 | hashed = sha256("".join(input).encode('utf-8')).hexdigest() 23 | return f'{hashed}' 24 | 25 | def uuid_hash(input: List[str]) -> str: 26 | return uuid.UUID(int=getrandbits(128), version=4).hex # random id; the input is ignored -------------------------------------------------------------------------------- /NodeRAG/storage/graph_mapping.py: -------------------------------------------------------------------------------- 1 | from typing import List, Dict, Any,Tuple 2 | import pandas as pd 3 | import numpy as np 4 | from .storage import storage 5 | 6 | class Mapper(): 7 | 8 | def __init__(self,path:List[str]|str) -> None: 9 | 10 | self.path = path 11 | self.mapping = dict() 12 | self.datasources = self.load_datasource() 13 | self.embeddings = {} 14 | 15 | def load_datasource(self) -> List[pd.DataFrame]: 16 | 17 | self.datasources = [] 18 | 19 | if isinstance(self.path,str): 20 | self.datasources.append(storage.load(self.path)) 21 | else: 22 | for path in self.path: 23 | self.datasources.append(storage.load(path)) 24 | 25 | for i,datasource in 
enumerate(self.datasources): 26 | self.generate_mapping(datasource,i) 27 | return self.datasources 28 | 29 | def generate_mapping(self,datasource:pd.DataFrame,datasource_id:int) -> None: 30 | 31 | for index,row in datasource.iterrows(): 32 | self.mapping[row['hash_id']] = [datasource_id,index] 33 | 34 | def add_datasource(self,path:str) -> None: 35 | 36 | if isinstance(self.path,str): 37 | if path in self.path: 38 | print(f'Datasource {path} already loaded') 39 | return None 40 | self.path = [self.path,path] 41 | 42 | else: 43 | self.path.append(path) 44 | 45 | self.datasources.append(storage.load(path)) 46 | self.generate_mapping(self.datasources[-1],len(self.datasources)-1) 47 | 48 | 49 | def add_datasources(self,paths:List[str]) -> None: 50 | for path in paths: 51 | self.add_datasource(path) 52 | 53 | def delete(self,id): 54 | 55 | datasource_id,index = self.mapping[id] 56 | self.datasources[datasource_id] = self.datasources[datasource_id].drop(index) 57 | self.mapping.pop(id) 58 | 59 | 60 | 61 | def get(self,hash_id:str,column:str|None=None) -> Dict[str,Any]|Any: 62 | 63 | datasource_id,index = self.mapping[hash_id] 64 | 65 | if column: 66 | return self.datasources[datasource_id].loc[index,column] 67 | 68 | else: 69 | return self.datasources[datasource_id].iloc[index].to_dict() 70 | 71 | def add_attribute(self,hash_id:str,column:str,value:Any) -> None: 72 | 73 | datasource_id,index = self.mapping[hash_id] 74 | self.datasources[datasource_id].loc[index,column] = value 75 | 76 | def update_save(self,numpy:bool=False) -> None: 77 | 78 | for i,datasource in enumerate(self.datasources): 79 | 80 | if numpy: 81 | datasource['embedding'] = datasource['embedding'].apply(lambda x: np.array(x.tolist(),dtype=np.float32)) 82 | 83 | storage(datasource).save_parquet(self.path[i]) 84 | 85 | def add_embedding(self,path) -> None: 86 | 87 | embeddings = storage.load(path) 88 | 89 | for index,row in embeddings.iterrows(): 90 | 91 | if row['hash_id'] in self.mapping: 92 | self.embeddings[row['hash_id']] = np.array(row['embedding'],dtype=np.float32) 93 | 94 | def add_embeddings_from_tuple(self,embeddings:List[Tuple[str,np.ndarray]]) -> None: 95 | 96 | for hash_id,embedding in embeddings: 97 | self.embeddings[hash_id] = embedding 98 | 99 | 100 | def find_non_HNSW(self) -> List[Tuple[str,np.ndarray]]: 101 | 102 | embeddings = [] 103 | 104 | for datasource in self.datasources: 105 | if 'embedding' in datasource.columns: 106 | for index,row in datasource.iterrows(): 107 | if row['embedding'] == 'done': 108 | embeddings.append((row['hash_id'],self.embeddings[row['hash_id']])) 109 | 110 | return embeddings 111 | 112 | def find_none_embeddings(self) -> List[str]: 113 | 114 | none_embedding_ids = [] 115 | 116 | for i,datasource in enumerate(self.datasources): 117 | if 'embedding' in datasource.columns: 118 | for index,row in datasource.iterrows(): 119 | if row['embedding'] is None: 120 | none_embedding_ids.append(row['hash_id']) 121 | 122 | return none_embedding_ids 123 | 124 | def generate_id_to_text(self,types:List[str]) -> Tuple[Dict[str,str],Dict[str,str]]: 125 | 126 | self.id_to_text = {} 127 | self.accurate_id_to_text = {} 128 | self.relationships = [] 129 | 130 | for id in self.mapping: 131 | self.id_to_text[id] = self.get(id,'context') 132 | if self.get(id,'type') in types: 133 | self.accurate_id_to_text[id] = self.get(id,'context') 134 | 135 | 136 | return self.id_to_text,self.accurate_id_to_text 137 | 138 | -------------------------------------------------------------------------------- 
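A minimal usage sketch for the Mapper class above, assuming two hypothetical parquet caches whose rows carry hash_id, context and type columns; the file paths below are illustrative only, since the real ones come from NodeConfig attributes such as semantic_units_path and entities_path:

from NodeRAG.storage import Mapper

# Hypothetical cache files produced by the build pipeline.
mapper = Mapper(['cache/semantic_units.parquet', 'cache/entities.parquet'])

some_id = next(iter(mapper.mapping))          # any known hash_id
row = mapper.get(some_id)                     # the full row as a dict
context = mapper.get(some_id, 'context')      # a single column value

mapper.add_attribute(some_id, 'weight', 1.0)  # mutate one cell in place
mapper.update_save()                          # write the edited frames back to parquet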
/NodeRAG/storage/storage.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Any, List 2 | import pandas as pd 3 | import json 4 | import pickle 5 | import os 6 | 7 | class storage(): 8 | 9 | def __init__(self,content:Dict[str,Any]|List[Dict[str,Any]]) -> None: 10 | self.content = content 11 | 12 | def save_json(self,path:str,append=False) -> None: 13 | if append: 14 | self.append_json(self.content,path) 15 | else: 16 | with open(path,'w') as f: 17 | json.dump(self.content,f,indent=4) 18 | 19 | def append_json(self,content:Dict[str,Any]|List[Dict[str,Any]],path:str) -> None: 20 | # read the existing content first; opening with 'w' here would truncate the file 21 | with open(path,'r') as f: exist_content = json.load(f) 22 | if isinstance(exist_content,dict): 23 | exist_content.update(content) 24 | elif isinstance(exist_content,list): 25 | exist_content.append(content) 26 | with open(path,'w') as f: json.dump(exist_content,f,indent=4) 27 | 28 | def save_parquet(self,path:str,append=False) -> None: 29 | if append: 30 | self.append_parquet(self.content,path) 31 | else: 32 | if isinstance(self.content,list): 33 | df = pd.DataFrame(self.content) 34 | elif isinstance(self.content,dict): 35 | df = pd.DataFrame(self.content) 36 | else: 37 | df = self.content 38 | df.to_parquet(path) 39 | 40 | def append_parquet(self,content,path:str) -> None: 41 | df = self.load_parquet(path) 42 | df = pd.concat([df,pd.DataFrame(content)],ignore_index=True) 43 | df.to_parquet(path) 44 | 45 | def save_pickle(self,path:str) -> None: 46 | with open(path,'wb') as f: 47 | pickle.dump(self.content,f) 48 | 49 | @staticmethod 50 | def load_pickle(path:str) -> Any: 51 | with open(path,'rb') as f: 52 | return pickle.load(f) 53 | 54 | @staticmethod 55 | def load_parquet(path:str) -> pd.DataFrame: 56 | return pd.read_parquet(path) 57 | 58 | @staticmethod 59 | def load_json(path:str) -> Dict[str,Any]: 60 | with open(path) as f: 61 | return json.load(f) 62 | 63 | @staticmethod 64 | def load_jsonl(path:str) -> List[Dict[str,Any]]: 65 | with open(path) as f: 66 | return [json.loads(line) for line in f] 67 | 68 | @staticmethod 69 | def load_csv(path:str) -> pd.DataFrame: 70 | return pd.read_csv(path) 71 | 72 | @staticmethod 73 | def load_excel(path:str) -> pd.DataFrame: 74 | return pd.read_excel(path) 75 | 76 | @staticmethod 77 | def load_file(path:str) -> str: 78 | with open(path) as f: 79 | return f.read() 80 | 81 | @staticmethod 82 | def load_tsv(path:str) -> pd.DataFrame: 83 | return pd.read_csv(path,sep='\t') 84 | 85 | 86 | @staticmethod 87 | def load(path:str) -> Any: 88 | if not os.path.exists(path): 89 | return None 90 | if path.endswith('.json'): 91 | return storage.load_json(path) 92 | elif path.endswith('.jsonl'): 93 | return storage.load_jsonl(path) 94 | elif path.endswith('.parquet'): 95 | return storage.load_parquet(path) 96 | elif path.endswith('.pkl'): 97 | return storage.load_pickle(path) 98 | elif path.endswith('.md') or path.endswith('.txt'): 99 | return storage.load_file(path) 100 | elif path.endswith('.csv'): 101 | return storage.load_csv(path) 102 | elif path.endswith('.tsv'): 103 | return storage.load_tsv(path) 104 | elif path.endswith('.xlsx'): 105 | return storage.load_excel(path) 106 | 107 | 108 | 109 | -------------------------------------------------------------------------------- /NodeRAG/utils/HNSW.py: -------------------------------------------------------------------------------- 1 | import hnswlib_noderag 2 | import networkx as nx 3 | import numpy as np 4 | from typing import Tuple,List 5 | from heapq import nsmallest 6 | 7 | import os 8 | from 
..storage import storage 9 | 10 | 11 | class HNSW: 12 | 13 | def __init__(self,config): 14 | 15 | self.config = config 16 | 17 | self.id_map = self.load_id_map() 18 | self.load_HNSW() 19 | self._nxgraphs = None 20 | 21 | 22 | 23 | @property 24 | def nxgraphs(self): 25 | graph_layer_0 = self.hnsw.get_layer_graph(0) 26 | if graph_layer_0 is not None: 27 | if self._nxgraphs is None: 28 | self._nxgraphs = nx.Graph() 29 | for id,neighbors in graph_layer_0.items(): 30 | for neighbor in neighbors: 31 | self._nxgraphs.add_edge(self.id_map[id],self.id_map[neighbor]) 32 | return self._nxgraphs 33 | else: 34 | return None 35 | 36 | def add_nodes(self, nodes: List[Tuple[str, np.ndarray]]): 37 | current_length = len(self.id_map) 38 | id_list = [] 39 | embedding_list = [] 40 | for idx, (node_id, embedding) in enumerate(nodes): 41 | new_id = current_length + idx 42 | self.id_map[new_id] = node_id 43 | id_list.append(new_id) 44 | embedding_list.append(embedding) 45 | self.hnsw.resize_index(len(id_list)+current_length) 46 | self.hnsw.add_items(np.array(embedding_list).astype(np.float32),id_list) 47 | 48 | def search(self,query:np.ndarray,HNSW_results:int=None): 49 | 50 | if HNSW_results is None: 51 | HNSW_results = self.config.top_k 52 | 53 | idx,dist = self.hnsw.knn_query(query,HNSW_results) 54 | idx = idx.flatten() 55 | dist = dist.flatten() 56 | node_list = [self.id_map[idx[i]] for i in range(len(idx))] 57 | dist_list = list(dist) 58 | results = zip(dist_list,node_list) 59 | return results 60 | 61 | def search_list(self,query_list:List[np.ndarray],HNSW_results:int=None): 62 | 63 | if HNSW_results is None: 64 | HNSW_results = self.config.top_k 65 | 66 | idx,dist = self.hnsw.knn_query(np.array(query_list).astype(np.float32),HNSW_results) 67 | idx = idx.flatten() 68 | dist = dist.flatten() 69 | node_list = [] 70 | dist_list = [] 71 | for i in range(len(idx)): 72 | if self.id_map[idx[i]] not in node_list: 73 | node_list.append(self.id_map[idx[i]]) 74 | dist_list.append(dist[i]) 75 | else: 76 | dist_list[node_list.index(self.id_map[idx[i]])] = 0.9*min(dist_list[node_list.index(self.id_map[idx[i]])],dist[i]) 77 | results = zip(dist_list,node_list) 78 | 79 | return nsmallest(HNSW_results,results) 80 | 81 | 82 | def load_id_map(self): 83 | 84 | if os.path.exists(self.config.id_map_path): 85 | id_map = storage.load(self.config.id_map_path) 86 | return dict(zip(id_map['id'],id_map['node'])) 87 | 88 | else: 89 | return {} 90 | 91 | def load_HNSW(self): 92 | 93 | self.hnsw = hnswlib_noderag.Index(space=self.config.space, dim=self.config.dim) 94 | if os.path.exists(self.config.HNSW_path): 95 | self.hnsw.load_index(self.config.HNSW_path) 96 | 97 | else: 98 | self.hnsw.init_index(max_elements=len(self.id_map), ef_construction=self.config._ef, M=self.config._m) 99 | 100 | 101 | 102 | 103 | def save_HNSW(self): 104 | 105 | self.hnsw.save_index(self.config.HNSW_path) 106 | storage({'id':list(self.id_map.keys()),'node':list(self.id_map.values())}).save_parquet(self.config.id_map_path) 107 | storage(self.nxgraphs).save_pickle(self.config.hnsw_graph_path) 108 | 109 | def get_layer_graph(self,layer:int): 110 | return self.hnsw.get_layer_graph(layer) 111 | 112 | def get_embeddings(self): 113 | ids = self.hnsw.get_ids_list() 114 | embeddings = self.hnsw.get_items(ids,return_type='numpy') 115 | return zip([self.id_map[id] for id in ids],embeddings) 116 | 117 | -------------------------------------------------------------------------------- /NodeRAG/utils/PPR.py: 
-------------------------------------------------------------------------------- 1 | import networkx as nx 2 | import numpy as np 3 | import scipy.sparse as sp 4 | from operator import itemgetter 5 | 6 | class sparse_PPR(): 7 | 8 | def __init__(self,graph:nx.Graph,modified = True,weight = 'weight'): 9 | 10 | self.graph = graph 11 | self.nodes = list(self.graph.nodes()) 12 | self.modified = modified 13 | self.weight = weight 14 | self.n_nodes = len(self.nodes) 15 | self.trans_matrix = self.generate_sparse_transition_matrix() 16 | 17 | def generate_sparse_transition_matrix(self): 18 | 19 | 20 | 21 | adjacency_matrix = nx.adjacency_matrix(self.graph,weight = self.weight) 22 | adjacency_matrix = (adjacency_matrix+adjacency_matrix.T)/2 23 | 24 | if self.modified: 25 | out_degree = adjacency_matrix.sum(1) 26 | adjacency_matrix = sp.lil_matrix(adjacency_matrix) 27 | adjacency_matrix[out_degree==0,:] = np.ones(self.n_nodes) 28 | adjacency_matrix.setdiag(0) 29 | adjacency_matrix = sp.csc_matrix(adjacency_matrix) 30 | out_degree = adjacency_matrix.sum(1) 31 | transition_matrix = adjacency_matrix.multiply(1/out_degree) 32 | # the transpose of the out-matrix is the in-matrix 33 | transition_matrix = transition_matrix.T 34 | 35 | 36 | return sp.csc_matrix(transition_matrix) 37 | 38 | def PPR(self, 39 | personalization:dict[str,float], 40 | alpha:float=0.85, 41 | max_iter:int=100, 42 | epsilons:float=1e-5): 43 | 44 | probs = np.zeros(len(self.nodes)) 45 | 46 | for node,prob in personalization.items(): 47 | probs[self.nodes.index(node)] = prob 48 | 49 | probs = probs/np.sum(probs) 50 | 51 | for i in range(max_iter): 52 | probs_old = probs.copy() 53 | probs = alpha*self.trans_matrix.dot(probs) + (1-alpha)*probs 54 | if np.linalg.norm(probs-probs_old) < epsilons: 55 | break 56 | 57 | return sorted(zip(self.nodes,probs),key=itemgetter(1),reverse=True) -------------------------------------------------------------------------------- /NodeRAG/utils/graph_operator.py: -------------------------------------------------------------------------------- 87 | if degree > 0: 88 | 89 | weight_factor = 1 / degree 90 | 91 | for neighbor in graph.neighbors(node): 92 | 93 | if graph[node][neighbor]['weight'] > weight_factor: 94 | graph[node][neighbor]['weight'] = weight_factor 95 | 96 | return graph 97 | 98 | 99 | 100 | -------------------------------------------------------------------------------- /NodeRAG/utils/lazy_import.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | import sys 3 | 4 | class LazyImport: 5 | def __init__(self, module_name, class_name=None): 6 | self.module_name = module_name 7 | self.class_name = class_name 8 | self._module = None 9 | self._class = None 10 | 11 | def _import(self): 12 | if self._module is None: 13 | if self.module_name in sys.modules: 14 | self._module = sys.modules[self.module_name] 15 | else: 16 | self._module = importlib.import_module(self.module_name) 17 | 18 | if self.class_name and self._class is None: 19 | self._class = getattr(self._module, self.class_name) 20 | return self._class if self.class_name else self._module 21 | 22 | def __call__(self, *args, **kwargs): 23 | cls = self._import() 24 | return cls(*args, **kwargs) 25 | 26 | def __getattr__(self, item): 27 | cls = self._import() 28 | return getattr(cls, item) -------------------------------------------------------------------------------- /NodeRAG/utils/observation.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from typing import List 3 | from tqdm import tqdm 4 | from rich.progress import Progress, SpinnerColumn, BarColumn, TextColumn, TimeRemainingColumn 5 | from rich.console import Console 6 | 7 | 8 | class Observer(ABC): 9 | @abstractmethod 10 | def update(self, process_state): 11 | pass 12 | @abstractmethod 13 | def 
reset(self,total_tasks:int|List[str],desc:str=""): 14 | pass 15 | @abstractmethod 16 | def close(self): 17 | pass 18 | 19 | class ProcessState(): 20 | total_tasks:int = 0 21 | current_task:int = 0 22 | desc:str = "" 23 | 24 | def __init__(self): 25 | self.observers = [] 26 | 27 | def add_observer(self, observer: Observer): 28 | self.observers.append(observer) 29 | 30 | def remove_observer(self, observer: Observer): 31 | self.observers.remove(observer) 32 | 33 | def notify(self): 34 | for observer in self.observers: 35 | observer.update(self) 36 | 37 | def reset(self,total_tasks:int,desc:str=""): 38 | self.total_tasks = total_tasks 39 | self._current_task = 0 40 | self.desc = desc 41 | for observer in self.observers: 42 | observer.reset(total_tasks,desc) 43 | 44 | def close(self): 45 | for observer in self.observers: 46 | observer.close() 47 | 48 | @property 49 | def current_task(self): 50 | return self._current_task 51 | 52 | @current_task.setter 53 | def current_task(self,value): 54 | self._current_task = value 55 | self.notify() 56 | 57 | 58 | 59 | class Tracker(): 60 | _instance = None 61 | def __new__(cls, *args, **kwargs): 62 | if cls._instance is None: 63 | cls._instance = super().__new__(cls) 64 | return cls._instance 65 | 66 | def __init__(self,use_tqdm:bool=True,use_rich:bool=False): 67 | self.process_state = ProcessState() 68 | if use_rich: 69 | self.process_state.add_observer(RichObserver()) 70 | elif use_tqdm: 71 | self.process_state.add_observer(tqdm_observer()) 72 | else: 73 | raise Exception("No observer selected") 74 | 75 | def set(self,total_task:int,desc:str=""): 76 | self.process_state.reset(total_task,desc) 77 | 78 | def update(self): 79 | self.process_state.current_task += 1 80 | 81 | def close(self): 82 | self.process_state.close() 83 | 84 | 85 | class tqdm_observer(Observer): 86 | def __init__(self): 87 | self.tqdm_instance = None 88 | 89 | def reset(self,total_task:int,desc:str=""): 90 | if desc == "": 91 | self.tqdm_instance = tqdm(total=total_task, 92 | bar_format="{l_bar}\033[92m{bar}\033[0m| \033[92m{n_fmt}/{total_fmt}\033[0m [\033[92m{elapsed}\033[0m<\033[92m{remaining}\033[0m]", 93 | ascii="░▒▓█", 94 | ncols=80) 95 | else: 96 | self.tqdm_instance = tqdm(total=total_task, 97 | desc='\033[92m' + desc + '\033[0m', 98 | bar_format="{l_bar}\033[92m{bar}\033[0m| \033[92m{n_fmt}/{total_fmt}\033[0m [\033[92m{elapsed}\033[0m<\033[92m{remaining}\033[0m]", 99 | ascii="░▒▓█", 100 | ncols=80) 101 | 102 | def update(self,process_state:ProcessState): 103 | self.tqdm_instance.n = process_state.current_task 104 | self.tqdm_instance.refresh() 105 | 106 | def close(self): 107 | self.tqdm_instance.close() 108 | 109 | 110 | class rich_console(): 111 | _instance = None 112 | def __new__(cls, *args, **kwargs): 113 | if cls._instance is None: 114 | cls._instance = super().__new__(cls) 115 | return cls._instance 116 | 117 | def __init__(self): 118 | self.console = Console() 119 | 120 | 121 | 122 | 123 | class RichObserver(Observer): 124 | def __init__(self): 125 | self.progress = None 126 | self.task = None 127 | self.console = rich_console().console 128 | 129 | def reset(self, total_task: int, desc: str = "Processing"): 130 | # create a progress bar with multiple display columns 131 | self.progress = Progress( 132 | SpinnerColumn(), 133 | TextColumn("[progress.description]{task.description}"), 134 | BarColumn(), 135 | TextColumn("[progress.percentage]{task.percentage:>3.0f}%"), 136 | TimeRemainingColumn(), 137 | ) 138 | self.progress.start() 139 | # add the task to the progress bar 140 | self.task = self.progress.add_task(desc, total=total_task) 141 
| 142 | def update(self, process_state: ProcessState): 143 | if self.progress: 144 | self.progress.update(self.task, completed=process_state.current_task) 145 | 146 | def close(self): 147 | if self.progress: 148 | self.progress.stop() 149 | self.console.print('',end='\r') 150 | -------------------------------------------------------------------------------- /NodeRAG/utils/prompt/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Terry-Xu-666/NodeRAG/f77dd6adb34cf4dda1d88b30b2bf0b17d14480a9/NodeRAG/utils/prompt/__init__.py -------------------------------------------------------------------------------- /NodeRAG/utils/prompt/answer.py: -------------------------------------------------------------------------------- 1 | answer_prompt = """ 2 | ---Role--- 3 | 4 | You are a thorough assistant responding to questions based on retrieved information. 5 | 6 | 7 | ---Goal--- 8 | 9 | Provide a clear and accurate response. Carefully review and verify the retrieved data, and integrate any relevant and necessary knowledge to comprehensively address the user's question. 10 | If you are unsure of the answer, just say so. Do not fabricate information. 11 | Do not include details not supported by the provided evidence. 12 | 13 | 14 | ---Target response length and format--- 15 | 16 | Multiple Paragraphs 17 | 18 | ---Retrieved Context--- 19 | 20 | {info} 21 | 22 | ---Query--- 23 | 24 | {query} 25 | """ 26 | 27 | 28 | answer_prompt_Chinese = ''' 29 | ---角色--- 30 | 你是一个根据检索到的信息回答问题的细致助手。 31 | 32 | ---目标--- 33 | 提供清晰且准确的回答。仔细审查和验证检索到的数据,并结合任何相关的必要知识,全面地解决用户的问题。 34 | 如果你不确定答案,请直接说明——不要编造信息。 35 | 不要包含没有提供支持证据的细节。 36 | 37 | ---输入--- 38 | 检索到的信息:{info} 39 | 40 | 用户问题:{query} 41 | ''' -------------------------------------------------------------------------------- /NodeRAG/utils/prompt/attribute_generation_prompt.py: -------------------------------------------------------------------------------- 1 | attribute_generation_prompt = """ 2 | Generate a concise summary of the given entity, capturing its essential attributes and important relevant relationships. The summary should read like a character sketch in a novel or a product description, providing an engaging yet precise overview. Ensure the output only includes the summary of the entity without any additional explanations or metadata. The length must not exceed 2000 words but can be shorter if the input material is limited. Focus on distilling the most important insights with a smooth narrative flow, highlighting the entity’s core traits and meaningful connections. 3 | Entity: {entity} 4 | Related Semantic Units: {semantic_units} 5 | Related Relationships: {relationships} 6 | """ 7 | attribute_generation_prompt_Chinese = """ 8 | 生成所给实体的简明总结,涵盖其基本属性和重要相关关系。该总结应像小说中的人物简介或产品描述一样,提供引人入胜且精准的概览。确保输出只包含该实体的总结,不包含任何额外的解释或元数据。字数不得超过2000字,但如果输入材料有限,可以少于2000字。重点在于通过流畅的叙述提炼出最重要的见解,突出实体的核心特征及重要关系。 9 | 实体: {entity} 10 | 相关语义单元: {semantic_units} 11 | 相关关系: {relationships} 12 | """ 13 | -------------------------------------------------------------------------------- /NodeRAG/utils/prompt/community_summary.py: -------------------------------------------------------------------------------- 1 | community_summary = """You will receive a set of text data from the same cluster. Your task is to extract distinct categories of high-level information, such as concepts, themes, relevant theories, potential impacts, and key insights. 
Each piece of information should include a concise title and a corresponding description, reflecting the unique perspectives within the text cluster. 2 | Please do not attempt to include all possible information; instead, select the elements that have the most significance and diversity in this cluster. Avoid redundant information—if there are highly similar elements, combine them into a single, comprehensive entry. Ensure that the high-level information reflects the varied dimensions within the text, providing a well-rounded overview. 3 | clustered text data: 4 | {content} 5 | """ 6 | 7 | community_summary_Chinese = """你将收到来自同一聚类的一组文本数据。你的任务是从文本数据中提取不同类别的高层次信息,例如概念、主题、相关理论、潜在影响和关键见解。每条信息应包含一个简洁的标题和相应的描述,以反映该聚类文本中的独特视角。 8 | 请不要试图包含所有可能的信息;相反,选择在该聚类中最具重要性和多样性的元素。避免冗余信息——如果有高度相似的内容,请将它们合并为一个综合条目。确保提取的高层次信息反映文本中的多维度内容,提供全面的概览。 9 | 聚类文本数据: 10 | {content} 11 | """ -------------------------------------------------------------------------------- /NodeRAG/utils/prompt/decompose.py: -------------------------------------------------------------------------------- 1 | decompos_query = ''' 2 | Please break down the following query into a single list. Each item in the list should be a main entity (such as a key noun or object). If you have high confidence about the user's intent or domain knowledge, you may also include closely related terms. If uncertain, please only extract entities and semantic chunks directly from the query. Please try to reduce the number of common nouns in the list. Ensure all elements are organized within one unified list. 3 | Query:{query} 4 | ''' 5 | 6 | decompos_query_Chinese = ''' 7 | 请将以下问题分解为一个 list,其中每一项是句子的主要实体(如关键名词或对象)。如果你对用户的意图或相关领域知识有充分把握,也可以包含密切相关的术语。如果不确定,请仅从问题中提取实体。请尽量减少囊括常见的名词,请将这些元素整合在一个单一的 list 中输出。 8 | 问题:{query} 9 | ''' -------------------------------------------------------------------------------- /NodeRAG/utils/prompt/json_format.py: -------------------------------------------------------------------------------- 1 | from pydantic import BaseModel 2 | 3 | class semantic_group(BaseModel): 4 | semantic_unit:str 5 | entities:list[str] 6 | relationships:list[str] 7 | 8 | class text_decomposition(BaseModel): 9 | Output:list[semantic_group] 10 | 11 | class relationship_reconstraction(BaseModel): 12 | source:str 13 | relationship:str 14 | target:str 15 | 16 | 17 | 18 | class elements(BaseModel): 19 | title:str 20 | description:str 21 | 22 | 23 | class High_level_element(BaseModel): 24 | high_level_elements:list[elements] 25 | 26 | class decomposed_text(BaseModel): 27 | elements:list[str] -------------------------------------------------------------------------------- /NodeRAG/utils/prompt/prompt_manager.py: -------------------------------------------------------------------------------- 1 | from .text_decomposition import text_decomposition_prompt, text_decomposition_prompt_Chinese 2 | from .json_format import text_decomposition, relationship_reconstraction, High_level_element,decomposed_text 3 | from .translation import translate_prompt 4 | from .relationship_reconstraction import relationship_reconstraction_prompt, relationship_reconstraction_prompt_Chinese 5 | from .attribute_generation_prompt import attribute_generation_prompt, attribute_generation_prompt_Chinese 6 | from .community_summary import community_summary, community_summary_Chinese 7 | from .decompose import decompos_query,decompos_query_Chinese 8 | from .answer import answer_prompt, answer_prompt_Chinese 9 | from ...LLM.LLM_state import get_api_client 10 | 11 | 12 | API_request = 
get_api_client() 13 | 14 | class prompt_manager(): 15 | 16 | def __init__(self, language:str): 17 | self.language = language 18 | 19 | @property 20 | def text_decomposition(self): 21 | match self.language: 22 | case 'English': 23 | return text_decomposition_prompt 24 | case "Chinese": 25 | return text_decomposition_prompt_Chinese 26 | case _: 27 | return self.translate(text_decomposition_prompt) 28 | @property 29 | def relationship_reconstraction(self): 30 | match self.language: 31 | case 'English': 32 | return relationship_reconstraction_prompt 33 | case "Chinese": 34 | return relationship_reconstraction_prompt_Chinese 35 | case _: 36 | return self.translate(relationship_reconstraction_prompt) 37 | 38 | @property 39 | def attribute_generation(self): 40 | match self.language: 41 | case 'English': 42 | return attribute_generation_prompt 43 | case "Chinese": 44 | return attribute_generation_prompt_Chinese 45 | case _: 46 | return self.translate(attribute_generation_prompt) 47 | 48 | @property 49 | def community_summary(self): 50 | match self.language: 51 | case 'English': 52 | return community_summary 53 | case "Chinese": 54 | return community_summary_Chinese 55 | case _: 56 | return self.translate(community_summary) 57 | @property 58 | def decompose_query(self): 59 | match self.language: 60 | case 'English': 61 | return decompos_query 62 | case "Chinese": 63 | return decompos_query_Chinese 64 | case _: 65 | return self.translate(decompos_query) 66 | @property 67 | def answer(self): 68 | match self.language: 69 | case 'English': 70 | return answer_prompt 71 | case "Chinese": 72 | return answer_prompt_Chinese 73 | case _: 74 | return self.translate(answer_prompt) 75 | 76 | 77 | def translate(self,prompt:str): 78 | prompt = translate_prompt.format(language = self.language, prompt = prompt) 79 | input_dict = {'prompt':prompt} 80 | response = API_request.request(input_dict) 81 | return response 82 | 83 | @property 84 | def text_decomposition_json(self): 85 | return text_decomposition 86 | 87 | @property 88 | def relationship_reconstraction_json(self): 89 | return relationship_reconstraction 90 | 91 | @property 92 | def high_level_element_json(self): 93 | return High_level_element 94 | 95 | @property 96 | def decomposed_text_json(self): 97 | return decomposed_text 98 | 99 | 100 | 101 | 102 | -------------------------------------------------------------------------------- /NodeRAG/utils/prompt/relationship_reconstraction.py: -------------------------------------------------------------------------------- 1 | relationship_reconstraction_prompt = """ 2 | You will be given a string containing tuples representing relationships between entities. The format of these relationships is incorrect and needs to be reconstructed. The correct format should be: 'ENTITY_A,RELATION_TYPE,ENTITY_B', where each tuple contains three elements: two entities and a relationship type. Your task is to reconstruct each relationship in the following format: {{'source': 'ENTITY_A', 'relation': 'RELATION_TYPE', 'target': 'ENTITY_B'}}. Please ensure the output follows this structure, accurately mapping the entities and relationships provided. 
3 | Incorrect relationships tuple string:{relationship} 4 | """ 5 | 6 | relationship_reconstraction_prompt_Chinese = """ 7 | 你将获得一个包含实体之间关系的元组字符串。这些关系的格式是错误的,需要被重新构建。正确的格式应为:'实体A,关系类型,实体B',每个元组应包含三个元素:两个实体和一个关系类型。你的任务是将每个关系重新构建为以下格式:{{'source': '实体A', 'relation': '关系类型', 'target': '实体B'}}。请确保输出遵循此结构,准确映射提供的实体和关系。 8 | 错误的关系元组:{relationship} 9 | """ 10 | -------------------------------------------------------------------------------- /NodeRAG/utils/prompt/text_decomposition.py: -------------------------------------------------------------------------------- 1 | text_decomposition_prompt = """ 2 | Goal: Given a text, segment it into multiple semantic units, each containing detailed descriptions of specific events or activities. 3 | Perform the following tasks: 4 | 1. Provide a summary for each semantic unit while retaining all crucial details relevant to the original context. 5 | 2. Extract all entities directly from the original text of each semantic unit, not from the paraphrased summary. Format each entity name in UPPERCASE. You should extract all entities, including times, locations, people, organizations, and other kinds of entities. 6 | 3. From the entities extracted in Step 2, list all relationships within the semantic unit and the corresponding original context in the form of a string separated by commas: "ENTITY_A, RELATION_TYPE, ENTITY_B". The RELATION_TYPE could be a descriptive sentence, while the entities involved in the relationship must come from the entity names extracted in Step 2. Please make sure the string contains three elements representing two entities and the relationship type. 7 | 8 | Requirements: 9 | 1. Temporal Entities: Represent time entities based on the available details without filling in missing parts. Use specific formats based on what parts of the date or time are mentioned in the text. 10 | 11 | Each semantic unit should be represented as a dictionary containing three keys: semantic_unit (a paraphrased summary of each semantic unit), entities (a list of entities extracted directly from the original text of each semantic unit, formatted in UPPERCASE), and relationships (a list of extracted relationship strings that contain three elements, where the relationship type is a descriptive sentence). All these dictionaries should be stored in a list to facilitate management and access. 12 | 13 | 14 | Example: 15 | 16 | Text: In September 2024, Dr. Emily Roberts traveled to Paris to attend the International Conference on Renewable Energy. During her visit, she explored partnerships with several European companies and presented her latest research on solar panel efficiency improvements. Meanwhile, on the other side of the world, her colleague, Dr. John Miller, was conducting fieldwork in the Amazon Rainforest. He documented several new species and observed the effects of deforestation on the local wildlife. Both scholars' work is essential in their respective fields and contributes significantly to environmental conservation efforts. 17 | Output: 18 | [ 19 | {{ 20 | "semantic_unit": "In September 2024, Dr. Emily Roberts attended the International Conference on Renewable Energy in Paris, where she presented her research on solar panel efficiency improvements and explored partnerships with European companies.", 21 | "entities": ["DR. EMILY ROBERTS", "2024-09", "PARIS", "INTERNATIONAL CONFERENCE ON RENEWABLE ENERGY", "EUROPEAN COMPANIES", "SOLAR PANEL EFFICIENCY"], 22 | "relationships": [ 23 | "DR. 
23 |     "DR. EMILY ROBERTS, attended, INTERNATIONAL CONFERENCE ON RENEWABLE ENERGY",
24 |     "DR. EMILY ROBERTS, explored partnerships with, EUROPEAN COMPANIES",
25 |     "DR. EMILY ROBERTS, presented research on, SOLAR PANEL EFFICIENCY"
26 |     ]
27 | }},
28 | {{
29 |     "semantic_unit": "Dr. John Miller conducted fieldwork in the Amazon Rainforest, documenting several new species and observing the effects of deforestation on local wildlife.",
30 |     "entities": ["DR. JOHN MILLER", "AMAZON RAINFOREST", "NEW SPECIES", "DEFORESTATION", "LOCAL WILDLIFE"],
31 |     "relationships": [
32 |     "DR. JOHN MILLER, conducted fieldwork in, AMAZON RAINFOREST",
33 |     "DR. JOHN MILLER, documented, NEW SPECIES",
34 |     "DR. JOHN MILLER, observed the effects of, DEFORESTATION on LOCAL WILDLIFE"
35 |     ]
36 | }},
37 | {{
38 |     "semantic_unit": "The work of both Dr. Emily Roberts and Dr. John Miller is crucial in their respective fields and contributes significantly to environmental conservation efforts.",
39 |     "entities": ["DR. EMILY ROBERTS", "DR. JOHN MILLER", "ENVIRONMENTAL CONSERVATION"],
40 |     "relationships": [
41 |     "DR. EMILY ROBERTS, contributes to, ENVIRONMENTAL CONSERVATION",
42 |     "DR. JOHN MILLER, contributes to, ENVIRONMENTAL CONSERVATION"
43 |     ]
44 | }}
45 | ]
46 | 
47 | 
48 | #########
49 | Real_Data:
50 | #########
51 | Text:{text}
52 | 
53 | """
54 | 
55 | text_decomposition_prompt_Chinese = """
56 | 目标:给定一个文本,将该文本划分为多个语义单元,每个单元包含对特定事件或活动的详细描述。
57 | 执行以下任务:
58 | 1.为每个语义单元提供总结,同时保留与原始上下文相关的所有关键细节。
59 | 2.直接从每个语义单元的原始文本中提取所有实体,而不是从改写的总结中提取。
60 | 3.从第2步中提取的实体中列出语义单元内的所有关系,其中关系类型可以是描述性句子。使用格式"ENTITY_A,RELATION_TYPE,ENTITY_B",请确保字符串中包含三个元素,分别表示两个实体和关系类型。
61 | 
62 | 要求:
63 | 
64 | 时间实体:根据文本中提到的日期或时间的具体部分来表示时间实体,不填补缺失部分。
65 | 
66 | 每个语义单元应以一个字典表示,包含三个键:semantic_unit(每个语义单元的概括性总结)、entities(直接从每个语义单元的原始文本中提取的实体列表,实体名格式为大写)、relationships(描述性句子形式的提取关系字符串三元组列表)。所有这些字典应存储在一个列表中,以便管理和访问。
67 | 
68 | 示例:
69 | 
70 | 文本:2024年9月,艾米莉·罗伯茨博士前往巴黎参加国际可再生能源会议。在她的访问期间,她与几家欧洲公司探讨了合作并介绍了她在提高太阳能板效率方面的最新研究。与此同时,在世界的另一边,她的同事约翰·米勒博士在亚马逊雨林进行实地工作。他记录了几种新物种,并观察了森林砍伐对当地野生动物的影响。两位学者的工作在各自的领域内至关重要,对环境保护工作做出了重大贡献。
71 | 输出:
72 | [
73 | {{
74 |     "semantic_unit": "2024年9月,艾米莉·罗伯茨博士参加了在巴黎举行的国际可再生能源会议,她在会上介绍了她关于太阳能板效率提高的研究并探讨了与欧洲公司的合作。",
75 |     "entities": ["艾米莉·罗伯茨博士", "2024-09", "巴黎", "国际可再生能源会议", "欧洲公司", "太阳能板效率"],
76 |     "relationships": [
77 |     "艾米莉·罗伯茨博士, 参加了, 国际可再生能源会议",
78 |     "艾米莉·罗伯茨博士, 探讨了合作, 欧洲公司",
79 |     "艾米莉·罗伯茨博士, 介绍了研究, 太阳能板效率"
80 |     ]
81 | }},
82 | {{
83 |     "semantic_unit": "约翰·米勒博士在亚马逊雨林进行实地工作,记录了几种新物种并观察了森林砍伐对当地野生动物的影响。",
84 |     "entities": ["约翰·米勒博士", "亚马逊雨林", "新物种", "森林砍伐", "当地野生动物"],
85 |     "relationships": [
86 |     "约翰·米勒博士, 在, 亚马逊雨林进行实地工作",
87 |     "约翰·米勒博士, 记录了, 新物种",
88 |     "约翰·米勒博士, 观察了, 森林砍伐对当地野生动物的影响"
89 |     ]
90 | }},
91 | {{
92 |     "semantic_unit": "艾米莉·罗伯茨博士和约翰·米勒博士的工作在各自的领域内至关重要,对环境保护工作做出了重大贡献。",
93 |     "entities": ["艾米莉·罗伯茨博士", "约翰·米勒博士", "环境保护"],
94 |     "relationships": [
95 |     "艾米莉·罗伯茨博士, 贡献于, 环境保护",
96 |     "约翰·米勒博士, 贡献于, 环境保护"
97 |     ]
98 | }}
99 | ]
100 | 
101 | ##########
102 | 实际数据:
103 | ##########
104 | 文本:{text}
105 | """
106 | 
107 | 
--------------------------------------------------------------------------------
/NodeRAG/utils/prompt/translation.py:
--------------------------------------------------------------------------------
1 | translate_prompt = """
2 | Goal: Translate the given prompt into {language}.
3 | You are provided with a prompt in English. Your task is to translate the prompt from English to {language}. Please ensure the translation is accurate and maintains the original meaning, context, and format.
4 | Original prompt text: {prompt}
5 | Output:
6 | """
--------------------------------------------------------------------------------
/NodeRAG/utils/readable_index.py:
--------------------------------------------------------------------------------
1 | from typing import List, Dict
2 | import json
3 | from rich.console import Console
4 | 
5 | class readable_index:
6 | 
7 |     def __new__(cls, *args, **kwargs):
8 |         if '_instance' not in cls.__dict__:  # check this class's own __dict__, not inherited attributes, so each subclass keeps its own singleton
9 |             cls._instance = super().__new__(cls)
10 |         return cls._instance
11 | 
12 |     def __init__(self, initial_value:int = 0):
13 |         if not hasattr(self, '_initialized') or not self._initialized:
14 |             self._counter = initial_value
15 |             self._initialized = True
16 | 
17 |     def increment(self):
18 |         self._counter += 1
19 |         return self._counter
20 | 
21 |     @property
22 |     def counter(self):
23 |         return self._counter
24 | 
25 |     def reset(self,num:int = 0):
26 |         self._counter = num
27 |         return self
28 | 
29 | class document_index(readable_index):
30 | 
31 |     pass
32 | 
33 | class text_unit_index(readable_index):
34 | 
35 |     pass
36 | 
37 | class semantic_unit_index(readable_index):
38 | 
39 |     pass
40 | 
41 | class entity_index(readable_index):
42 | 
43 |     pass
44 | 
45 | class relation_index(readable_index):
46 | 
47 |     pass
48 | 
49 | class attribute_index(readable_index):
50 | 
51 |     pass
52 | 
53 | class community_summary_index(readable_index):
54 |     pass
55 |     # _instance = {}
56 | 
57 |     # def __new__(cls, level, *args, **kwargs):
58 |     #     if level not in cls._instance:
59 |     #         cls._instance[level] = super().__new__(cls)
60 |     #     return cls._instance[level]
61 | 
62 |     # def __init__(self, initial_value: int = 0,level:int = 0):
63 |     #     if not hasattr(self, '_initialized') or not self._initialized:
64 |     #         self._counter = initial_value
65 |     #         self._initialized = True
66 |     #         self.level = level
67 | class high_level_element_index(readable_index):
68 | 
69 |     pass
70 | class index_manager():
71 | 
72 |     def __init__(self,indexers:List[readable_index],console:Console) -> None:
73 |         self.indexer_dict = {}
74 |         self.console = console
75 |         for index in indexers:
76 |             self.add_index(index)
77 | 
78 |     def get_index(self,index_name:str|int) -> Dict[str,int]:
79 | 
80 |         if isinstance(index_name,int):
81 |             indexer_name = list(self.indexer_dict.keys())[index_name]
82 |             return {indexer_name:self.indexer_dict[indexer_name].counter}
83 | 
84 |         elif isinstance(index_name,str):
85 |             if index_name in self.indexer_dict:
86 |                 return {index_name:self.indexer_dict[index_name].counter}
87 |             else:
88 |                 raise ValueError(f"Index {index_name} not found")
89 |         else:
90 |             raise ValueError(f"Invalid index name {index_name}")
91 | 
92 |     def add_index(self,index:readable_index) -> None:
93 | 
94 |         if index.__class__.__name__ not in self.indexer_dict:
95 |             self.indexer_dict[index.__class__.__name__] = index
96 | 
97 |     def add_indices(self,indexers:List[readable_index]) -> None:
98 |         for index in indexers:
99 |             self.add_index(index)
100 | 
101 |     def store_all_indices(self,path:str) -> None:
102 |         current_counter = {}
103 |         for name, indexer in self.indexer_dict.items():
104 |             current_counter[name] = indexer.counter
105 | 
106 |         with open(path,'w') as f:
107 |             json.dump(current_counter,f,indent=2)
108 |         self.console.print(f"Indices stored in {path}",style="bold green")
109 | 
110 |     @classmethod
111 |     def load_indices(cls,path:str,console:Console) -> 'index_manager':
112 |         with open(path,'r') as f:
113 |             indices = json.load(f)
114 |         indexers = []
115 |         for name, counter in indices.items():
116 | 
117 |             indexer = globals()[name]().reset(counter)
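            # The line above resolves the stored name to its readable_index
            # subclass via globals(), so the keys persisted in the JSON file
            # must match class names defined in this module.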
118 |             indexers.append(indexer)
119 | 
120 | 
121 |         return cls(indexers,console)
122 | 
123 | 
124 | 
125 | 
126 | 
127 | 
128 | 
129 | 
130 | 
131 | 
132 | 
133 | 
--------------------------------------------------------------------------------
/NodeRAG/utils/text_spliter.py:
--------------------------------------------------------------------------------
1 | from typing import List
2 | from .token_utils import get_token_counter
3 | 
4 | class SemanticTextSplitter:
5 |     def __init__(self, chunk_size: int = 1048, model_name: str = "gpt-4o-mini"):
6 |         """
7 |         Initialize the text splitter with chunk size and model name parameters.
8 | 
9 |         Args:
10 |             chunk_size (int): Maximum number of tokens per chunk
11 |             model_name (str): Model name for token counting
12 |         """
13 |         self.chunk_size = chunk_size
14 |         self.token_counter = get_token_counter(model_name)
15 | 
16 |     def split(self, text: str) -> List[str]:
17 |         """
18 |         Split text into chunks based on both token count and semantic boundaries.
19 |         """
20 |         chunks = []
21 |         start = 0
22 |         text_len = len(text)
23 | 
24 |         while start < text_len:
25 |             # advance roughly chunk_size * 4 characters past the start position
26 |             end = start + self.chunk_size * 4  # assume each token is ~4 characters
27 |             if end > text_len:
28 |                 end = text_len
29 | 
30 |             # get the current text fragment
31 |             current_chunk = text[start:end]
32 | 
33 |             # if the fragment exceeds the token limit, find a split point
34 |             while self.token_counter(current_chunk) > self.chunk_size and start < end:
35 |                 # look for a semantic boundary in the current range
36 |                 boundaries = ['\n\n', '\n', '。', '.', '!', '!', '?', '?', ';', ';']
37 |                 semantic_end = end
38 | 
39 |                 for boundary in boundaries:
40 |                     boundary_pos = current_chunk.rfind(boundary)
41 |                     if boundary_pos != -1:
42 |                         semantic_end = start + boundary_pos + len(boundary)
43 |                         break
44 | 
45 |                 # if a semantic boundary was found, use it; otherwise force truncation by character count
46 |                 if semantic_end < end:
47 |                     end = semantic_end
48 |                 else:
49 |                     # no suitable semantic boundary found; shrink the window until it fits the token limit
50 |                     end = start + int(len(current_chunk) // 1.2)
51 | 
52 |                 current_chunk = text[start:end]
53 | 
54 |             # append the processed chunk
55 |             chunk = current_chunk.strip()
56 |             if chunk:
57 |                 chunks.append(chunk)
58 | 
59 |             # move to the next start position
60 |             start = end
61 | 
62 |         return chunks
63 | 
--------------------------------------------------------------------------------
/NodeRAG/utils/token_utils.py:
--------------------------------------------------------------------------------
1 | import tiktoken
2 | from typing import Protocol, List
3 | # from transformers import AutoTokenizer
4 | 
5 | class token_counter(Protocol):
6 | 
7 |     def __init__(self,model_name:str):
8 |         self.model_name = model_name
9 | 
10 |     def __call__(self, text:str) -> int:
11 |         ...
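# token_counter is a structural interface: any object callable on a string and
# returning a token count satisfies it. The concrete counters below also
# subclass it directly to inherit the shared signature.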
12 | 
13 | class tiktoken_counter(token_counter):
14 | 
15 |     def __init__(self,model_name:str):
16 |         self.tokenizer = tiktoken.encoding_for_model(model_name)
17 |         self.token_limit_bound = 128000
18 | 
19 | 
20 |     def encode(self, text:str) -> List[int]:
21 |         return self.tokenizer.encode(text)
22 | 
23 |     def token_limit(self, text:str) -> bool:
24 | 
25 |         return len(self.encode(text)) > self.token_limit_bound
26 | 
27 | 
28 | 
29 |     def __call__(self, text: str) -> int:
30 |         return len(self.encode(text))
31 | 
32 | # class deepseek_counter(token_counter):
33 | 
34 | #     def __init__(self,model_name:str):
35 | #         self.tokenizer = AutoTokenizer.from_pretrained('deepseek-ai/DeepSeek-V3')
36 | #         self.token_limit_bound = 640000
37 | 
38 | 
39 | #     def encode(self, text:str) -> List[int]:
40 | #         return self.tokenizer.encode(text)
41 | 
42 | #     def token_limit(self, text:str) -> bool:
43 | #         return len(self.encode(text)) > self.token_limit_bound
44 | 
45 | #     def __call__(self, text:str) -> int:
46 | #         return len(self.encode(text))
47 | 
48 | 
49 | 
50 | 
51 | 
52 | 
53 | def get_token_counter(model_name:str) -> token_counter:
54 | 
55 | 
56 |     model_name = model_name.lower()
57 | 
58 |     if 'gpt' in model_name:
59 |         return tiktoken_counter(model_name)
60 |     elif 'gemini' in model_name:
61 |         token = tiktoken_counter('gpt-4o')  # no tiktoken encoding exists for Gemini; reuse GPT-4o's as an approximation
62 |         token.token_limit_bound = 1280000  # widen the bound to reflect Gemini's larger context window
63 |         return token
64 |     # elif 'deepseek' in model_name:
65 |     #     return deepseek_counter(model_name)
66 |     else:
67 |         raise ValueError(f"Unsupported model {model_name}")
68 | 
69 | 
70 | 
71 | 
72 | 
--------------------------------------------------------------------------------
/NodeRAG/utils/yaml_operation.py:
--------------------------------------------------------------------------------
1 | from ruamel.yaml import YAML
2 | import os
3 | 
4 | class YamlHandler:
5 |     def __init__(self, file_path):
6 |         self.file_path = file_path
7 |         self.yaml = YAML()
8 |         self.yaml.preserve_quotes = True
9 |         self.yaml.indent(mapping=2, sequence=4, offset=2)
10 |         if os.path.exists(self.file_path):
11 |             with open(self.file_path, 'r') as f:
12 |                 self.data = self.yaml.load(f)
13 |         else:
14 |             raise FileNotFoundError(f"File {self.file_path} does not exist.")
15 | 
16 |     def save(self):
17 |         if self.data is not None:
18 |             with open(self.file_path, 'w') as f:
19 |                 self.yaml.dump(self.data, f)
20 | 
21 |     def update_config(self, key_path, value):
22 |         """
23 |         Update the value of a nested key in the YAML data.
24 | 
25 |         :param key_path: List of keys representing the path to the target key.
26 |         :param value: The new value to set.
27 |         """
28 |         data = self.data
29 |         for key in key_path[:-1]:
30 |             data = data.setdefault(key, {})  # setdefault keeps a newly created nested dict attached to self.data; get() would return a detached dict and silently drop the update
31 |         data[key_path[-1]] = value
32 | 
33 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # NodeRAG: Structuring Graph-based RAG with Heterogeneous Nodes
2 | 
3 | 
4 | ![NodeRAG Logo](./asset/Node_background.jpg)
5 | 
6 | 
7 | *Badges: arXiv · PyPI · License: MIT · GitHub issues · Python · Website · GitHub stars*
8 | 
9 | 
10 | 
11 | 
12 | 
13 | 
14 | 
15 | 
16 | 
17 | ## 📢 News
18 | 
19 | - **[2025-03-18]** 🚀 **NodeRAG v0.1.0 Released!** The first stable version is now available on [PyPI](https://pypi.org/project/NodeRAG/). Install it with `pip install NodeRAG`.
20 | 
21 | - **[2025-03-18]** 🌐 **Official Website Launched!** Visit [NodeRAG_web](https://terry-xu-666.github.io/NodeRAG_web/) for comprehensive documentation, tutorials, and examples.
22 | 
23 | ---
24 | 
25 | 🚀 NodeRAG is a retrieval-augmented generation (RAG) system built on a heterogeneous graph, and it can be installed and used in several ways. 🖥️ We also provide a locally deployable user interface and convenient tools for generating visualizations. Read our [paper](#) 📄 to learn more, and see our [blog posts](https://terry-xu-666.github.io/NodeRAG_web/blog/) 📝 for experimental discussions.
26 | 
27 | ---
28 | 
29 | ## 🚀 Quick Start
30 | 
31 | 📖 View our official website for comprehensive documentation and tutorials:
32 | 👉 [NodeRAG_web](https://terry-xu-666.github.io/NodeRAG_web/)
33 | 
34 | ### 🧩 Workflow
35 | 
36 | 
37 | ![NodeRAG Workflow](./asset/NodeGraph_Figure2.png)
38 | 
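The workflow above runs in two phases: an indexing phase that builds the heterogeneous graph from your documents (the `NodeRAG.build` pipeline) and an answering phase that retrieves from it (the `NodeRAG.search` module). The sketch below shows what driving both phases from Python might look like. It is a minimal sketch, not the authoritative API: the `NodeConfig`/`NodeSearch` names and their methods are assumptions inferred from the package layout (`NodeRAG/config`, `NodeRAG/search`); the [Indexing](https://terry-xu-666.github.io/NodeRAG_web/docs/indexing/) and [Answering](https://terry-xu-666.github.io/NodeRAG_web/docs/answer/) docs define the real entry points.

```python
# Minimal usage sketch — names marked "hypothetical" are assumptions, not verified API.
from NodeRAG import NodeConfig, NodeSearch  # hypothetical top-level exports

# Load the YAML-backed configuration for a project folder (see NodeRAG/config/Node_config.yaml)
config = NodeConfig.from_main_folder("./my_project")  # hypothetical constructor

# Ask a question against the previously built heterogeneous graph index
search = NodeSearch(config)  # hypothetical class
answer = search.answer("What is NodeRAG?")  # hypothetical method
print(answer)
```

Both phases can also be launched as modules (for example `python -m NodeRAG.build`), which the shipped `NodeRAG/build/__main__.py` and `NodeRAG/search/__main__.py` entry points suggest; the same pattern applies to the `NodeRAG.WebUI` interface.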
39 | 
40 | ---
41 | 
42 | ## NodeRAG
43 | 
44 | ### Conda Setup
45 | 
46 | Create and activate a virtual environment for NodeRAG:
47 | 
48 | ```bash
49 | conda create -n NodeRAG python=3.10
50 | conda activate NodeRAG
51 | ```
52 | 
53 | ---
54 | 
55 | ### Install `uv` (Optional: Faster Package Installation)
56 | 
57 | To speed up package installation, use [`uv`](https://github.com/astral-sh/uv):
58 | 
59 | ```bash
60 | pip install uv
61 | ```
62 | 
63 | ---
64 | 
65 | ### Install NodeRAG
66 | 
67 | Install NodeRAG with `uv` for faster installation:
68 | 
69 | ```bash
70 | uv pip install NodeRAG
71 | ```
72 | 
73 | ### Next
74 | > For the indexing and answering processes, please refer to our website: [Indexing](https://terry-xu-666.github.io/NodeRAG_web/docs/indexing/) and [Answering](https://terry-xu-666.github.io/NodeRAG_web/docs/answer/)
75 | 
76 | 
77 | ## ✨ Features
78 | 
79 | 
80 | 
81 | #### 🔗 Enhancing Graph Structure for RAG
82 | NodeRAG introduces a heterogeneous graph structure that strengthens the foundation of graph-based Retrieval-Augmented Generation (RAG).
83 | 
84 | 
85 | 
86 | #### 🔍 Fine-Grained and Explainable Retrieval
87 | NodeRAG leverages HeteroGraphs to enable functionally distinct nodes, ensuring precise and context-aware retrieval while improving interpretability.
88 | 
89 | #### 🧱 Unified Information Retrieval
90 | Instead of treating extracted insights and raw data as separate layers, NodeRAG integrates them as interconnected nodes, creating a seamless and adaptable retrieval system.
91 | 
92 | 
93 | #### ⚡ Optimized Performance and Speed
94 | NodeRAG achieves faster graph construction and retrieval through unified algorithms and optimized implementations.
95 | 
96 | 
97 | #### 🔄 Incremental Graph Updates
98 | NodeRAG supports incremental updates within heterogeneous graphs using graph connectivity mechanisms.
99 | 
100 | 
101 | 
102 | #### 📊 Visualization and User Interface
103 | NodeRAG offers a user-friendly visualization system. Together with a fully developed Web UI, it lets users explore, analyze, and manage the graph structure with ease.
104 | 
105 | ## ⚙️ Performance
106 | 
107 | ### 📊 Benchmark Performance
108 | 
109 | 
110 | ![Benchmark Performance](./asset/performance.png)
111 | 
112 | 113 | *NodeRAG demonstrates strong performance across multiple benchmark tasks, showcasing efficiency and retrieval quality.* 114 | 115 | --- 116 | 117 | ### 🖥️ System Performance 118 | 119 |
120 | ![System Performance](./asset/system_performance.png)
121 | 
122 | 123 | *Optimized for speed and scalability, NodeRAG achieves fast indexing and query response times even on large datasets.* 124 | 125 | 126 | -------------------------------------------------------------------------------- /asset/NodeGraph_Figure2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Terry-Xu-666/NodeRAG/f77dd6adb34cf4dda1d88b30b2bf0b17d14480a9/asset/NodeGraph_Figure2.png -------------------------------------------------------------------------------- /asset/Node_background.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Terry-Xu-666/NodeRAG/f77dd6adb34cf4dda1d88b30b2bf0b17d14480a9/asset/Node_background.jpg -------------------------------------------------------------------------------- /asset/performance.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Terry-Xu-666/NodeRAG/f77dd6adb34cf4dda1d88b30b2bf0b17d14480a9/asset/performance.png -------------------------------------------------------------------------------- /asset/system_performance.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Terry-Xu-666/NodeRAG/f77dd6adb34cf4dda1d88b30b2bf0b17d14480a9/asset/system_performance.png -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "NodeRAG" 3 | version = "0.1.0" 4 | description = "NodeRAG, a graph-centric framework introducing heterogeneous graph structures that enable the seamless and holistic integration of graph-based methodologies into the RAG workflow" 5 | readme = "README.md" 6 | requires-python = ">=3.10" 7 | license = {text = "MIT"} 8 | dependencies = [ 9 | "aiohappyeyeballs==2.6.1", 10 | "aiohttp==3.11.13", 11 | "aiosignal==1.3.2", 12 | "altair==5.5.0", 13 | "annotated-types==0.7.0", 14 | "anyio==4.8.0", 15 | "asttokens==3.0.0", 16 | "async-timeout==5.0.1", 17 | "attrs==25.3.0", 18 | "backoff==2.2.1", 19 | "blinker==1.9.0", 20 | "cachetools==5.5.2", 21 | "certifi==2025.1.31", 22 | "charset-normalizer==3.4.1", 23 | "click==8.1.8", 24 | "colorama==0.4.6", 25 | "decorator==5.2.1", 26 | "distro==1.9.0", 27 | "exceptiongroup==1.2.2", 28 | "executing==2.2.0", 29 | "faiss-cpu==1.10.0", 30 | "filelock==3.17.0", 31 | "flask==3.1.0", 32 | "frozenlist==1.5.0", 33 | "fsspec==2025.3.0", 34 | "gitdb==4.0.12", 35 | "gitpython==3.1.44", 36 | "google-api-core>=2.24.2", 37 | "h11==0.14.0", 38 | "hnswlib-noderag==0.8.2", 39 | "httpcore==1.0.7", 40 | "httpx==0.28.1", 41 | "idna==3.10", 42 | "igraph==0.11.8", 43 | "ipython==8.34.0", 44 | "itsdangerous==2.2.0", 45 | "jedi==0.19.2", 46 | "jinja2==3.1.6", 47 | "jiter==0.9.0", 48 | "jsonpickle==4.0.2", 49 | "jsonschema==4.23.0", 50 | "jsonschema-specifications==2024.10.1", 51 | "leidenalg==0.10.2", 52 | "markdown-it-py==3.0.0", 53 | "markupsafe==3.0.2", 54 | "matplotlib-inline==0.1.7", 55 | "mdurl==0.1.2", 56 | "mpmath==1.3.0", 57 | "multidict==6.1.0", 58 | "narwhals==1.30.0", 59 | "networkx==3.4.2", 60 | "numpy==1.26.4", 61 | "openai==1.66.3", 62 | "packaging==24.2", 63 | "pandas==2.2.3", 64 | "parso==0.8.4", 65 | "pillow==11.1.0", 66 | "prompt-toolkit==3.0.50", 67 | "propcache==0.3.0", 68 | "protobuf==5.29.3", 69 | "pure-eval==0.2.3", 70 | "pyarrow==19.0.1", 71 | "pydantic==2.10.6", 72 | 
"pydantic-core==2.27.2", 73 | "pydeck==0.9.1", 74 | "pygments==2.19.1", 75 | "python-dateutil==2.9.0.post0", 76 | "pytz==2025.1", 77 | "pyvis==0.3.2", 78 | "pyyaml==6.0.2", 79 | "referencing==0.36.2", 80 | "regex==2024.11.6", 81 | "requests==2.32.3", 82 | "rich==13.9.4", 83 | "rpds-py==0.23.1", 84 | "ruamel-yaml>=0.18.10", 85 | "scipy==1.12.0", 86 | "six==1.17.0", 87 | "smmap==5.0.2", 88 | "sniffio==1.3.1", 89 | "sortedcontainers==2.4.0", 90 | "stack-data==0.6.3", 91 | "streamlit==1.43.2", 92 | "sympy==1.13.1", 93 | "tenacity==9.0.0", 94 | "texttable==1.7.0", 95 | "tiktoken==0.9.0", 96 | "toml==0.10.2", 97 | "tornado==6.4.2", 98 | "tqdm==4.67.1", 99 | "traitlets==5.14.3", 100 | "typing-extensions==4.12.2", 101 | "tzdata==2025.1", 102 | "urllib3==2.3.0", 103 | "watchdog==6.0.0", 104 | "wcwidth==0.2.13", 105 | "werkzeug==3.1.3", 106 | "yarl==1.18.3", 107 | ] 108 | 109 | [build-system] 110 | requires = ["setuptools>=45", "wheel"] 111 | build-backend = "setuptools.build_meta" 112 | 113 | [tool.setuptools.packages.find] 114 | where = ["."] 115 | 116 | [tool.setuptools.package-data] 117 | NodeRAG = ["/config/*.yaml"] 118 | 119 | -------------------------------------------------------------------------------- /requirements.in: -------------------------------------------------------------------------------- 1 | aiohttp 2 | backoff 3 | faiss-cpu 4 | flask 5 | hnswlib-noderag 6 | igraph 7 | leidenalg 8 | networkx 9 | numpy 10 | openai 11 | google-genai 12 | google-api-core 13 | pandas 14 | pydantic 15 | pyvis 16 | pyyaml 17 | requests 18 | rich 19 | ruamel.yaml 20 | scipy~=1.12.0 21 | sortedcontainers 22 | streamlit 23 | tiktoken 24 | transformers 25 | tqdm -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Terry-Xu-666/NodeRAG/f77dd6adb34cf4dda1d88b30b2bf0b17d14480a9/requirements.txt --------------------------------------------------------------------------------