├── PathRAG ├── PathRAG.py ├── __init__.py ├── __pycache__ │ ├── PathRAG.cpython-39.pyc │ ├── __init__.cpython-39.pyc │ ├── base.cpython-39.pyc │ ├── llm.cpython-39.pyc │ ├── operate.cpython-39.pyc │ ├── prompt.cpython-39.pyc │ ├── storage.cpython-39.pyc │ └── utils.cpython-39.pyc ├── base.py ├── llm.py ├── operate.py ├── prompt.py ├── storage.py └── utils.py ├── README.md ├── requirements.txt └── v1_test.py /PathRAG/PathRAG.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import os 3 | from tqdm.asyncio import tqdm as tqdm_async 4 | from dataclasses import asdict, dataclass, field 5 | from datetime import datetime 6 | from functools import partial 7 | from typing import Type, cast 8 | 9 | 10 | from .llm import ( 11 | gpt_4o_mini_complete, 12 | openai_embedding, 13 | ) 14 | from .operate import ( 15 | chunking_by_token_size, 16 | extract_entities, 17 | kg_query, 18 | ) 19 | 20 | from .utils import ( 21 | EmbeddingFunc, 22 | compute_mdhash_id, 23 | limit_async_func_call, 24 | convert_response_to_json, 25 | logger, 26 | set_logger, 27 | ) 28 | from .base import ( 29 | BaseGraphStorage, 30 | BaseKVStorage, 31 | BaseVectorStorage, 32 | StorageNameSpace, 33 | QueryParam, 34 | ) 35 | 36 | from .storage import ( 37 | JsonKVStorage, 38 | NanoVectorDBStorage, 39 | NetworkXStorage, 40 | ) 41 | 42 | 43 | 44 | 45 | def lazy_external_import(module_name: str, class_name: str): 46 | """Lazily import a class from an external module based on the package of the caller.""" 47 | 48 | 49 | import inspect 50 | 51 | caller_frame = inspect.currentframe().f_back 52 | module = inspect.getmodule(caller_frame) 53 | package = module.__package__ if module else None 54 | 55 | def import_class(*args, **kwargs): 56 | import importlib 57 | 58 | 59 | module = importlib.import_module(module_name, package=package) 60 | 61 | 62 | cls = getattr(module, class_name) 63 | return cls(*args, **kwargs) 64 | 65 | return import_class 66 | 67 | 68 | Neo4JStorage = lazy_external_import(".kg.neo4j_impl", "Neo4JStorage") 69 | OracleKVStorage = lazy_external_import(".kg.oracle_impl", "OracleKVStorage") 70 | OracleGraphStorage = lazy_external_import(".kg.oracle_impl", "OracleGraphStorage") 71 | OracleVectorDBStorage = lazy_external_import(".kg.oracle_impl", "OracleVectorDBStorage") 72 | MilvusVectorDBStorge = lazy_external_import(".kg.milvus_impl", "MilvusVectorDBStorge") 73 | MongoKVStorage = lazy_external_import(".kg.mongo_impl", "MongoKVStorage") 74 | ChromaVectorDBStorage = lazy_external_import(".kg.chroma_impl", "ChromaVectorDBStorage") 75 | TiDBKVStorage = lazy_external_import(".kg.tidb_impl", "TiDBKVStorage") 76 | TiDBVectorDBStorage = lazy_external_import(".kg.tidb_impl", "TiDBVectorDBStorage") 77 | AGEStorage = lazy_external_import(".kg.age_impl", "AGEStorage") 78 | 79 | 80 | def always_get_an_event_loop() -> asyncio.AbstractEventLoop: 81 | """ 82 | Ensure that there is always an event loop available. 83 | 84 | This function tries to get the current event loop. If the current event loop is closed or does not exist, 85 | it creates a new event loop and sets it as the current event loop. 86 | 87 | Returns: 88 | asyncio.AbstractEventLoop: The current or newly created event loop. 89 | """ 90 | try: 91 | 92 | current_loop = asyncio.get_event_loop() 93 | if current_loop.is_closed(): 94 | raise RuntimeError("Event loop is closed.") 95 | return current_loop 96 | 97 | except RuntimeError: 98 | 99 | logger.info("Creating a new event loop in main thread.") 100 | new_loop = asyncio.new_event_loop() 101 | asyncio.set_event_loop(new_loop) 102 | return new_loop 103 | 104 | 105 | @dataclass 106 | class PathRAG: 107 | working_dir: str = field( 108 | default_factory=lambda: f"./PathRAG_cache_{datetime.now().strftime('%Y-%m-%d-%H:%M:%S')}" 109 | ) 110 | 111 | embedding_cache_config: dict = field( 112 | default_factory=lambda: { 113 | "enabled": False, 114 | "similarity_threshold": 0.95, 115 | "use_llm_check": False, 116 | } 117 | ) 118 | kv_storage: str = field(default="JsonKVStorage") 119 | vector_storage: str = field(default="NanoVectorDBStorage") 120 | graph_storage: str = field(default="NetworkXStorage") 121 | 122 | current_log_level = logger.level 123 | log_level: str = field(default=current_log_level) 124 | 125 | 126 | chunk_token_size: int = 1200 127 | chunk_overlap_token_size: int = 100 128 | tiktoken_model_name: str = "gpt-4o-mini" 129 | 130 | 131 | entity_extract_max_gleaning: int = 1 132 | entity_summary_to_max_tokens: int = 500 133 | 134 | 135 | node_embedding_algorithm: str = "node2vec" 136 | node2vec_params: dict = field( 137 | default_factory=lambda: { 138 | "dimensions": 1536, 139 | "num_walks": 10, 140 | "walk_length": 40, 141 | "window_size": 2, 142 | "iterations": 3, 143 | "random_seed": 3, 144 | } 145 | ) 146 | 147 | 148 | embedding_func: EmbeddingFunc = field(default_factory=lambda: openai_embedding) 149 | embedding_batch_num: int = 32 150 | embedding_func_max_async: int = 16 151 | 152 | 153 | llm_model_func: callable = gpt_4o_mini_complete 154 | llm_model_name: str = "meta-llama/Llama-3.2-1B-Instruct" 155 | llm_model_max_token_size: int = 32768 156 | llm_model_max_async: int = 16 157 | llm_model_kwargs: dict = field(default_factory=dict) 158 | 159 | 160 | vector_db_storage_cls_kwargs: dict = field(default_factory=dict) 161 | 162 | enable_llm_cache: bool = True 163 | 164 | 165 | addon_params: dict = field(default_factory=dict) 166 | convert_response_to_json_func: callable = convert_response_to_json 167 | 168 | def __post_init__(self): 169 | log_file = os.path.join("PathRAG.log") 170 | set_logger(log_file) 171 | logger.setLevel(self.log_level) 172 | 173 | logger.info(f"Logger initialized for working directory: {self.working_dir}") 174 | 175 | 176 | self.key_string_value_json_storage_cls: Type[BaseKVStorage] = ( 177 | self._get_storage_class()[self.kv_storage] 178 | ) 179 | self.vector_db_storage_cls: Type[BaseVectorStorage] = self._get_storage_class()[ 180 | self.vector_storage 181 | ] 182 | self.graph_storage_cls: Type[BaseGraphStorage] = self._get_storage_class()[ 183 | self.graph_storage 184 | ] 185 | 186 | if not os.path.exists(self.working_dir): 187 | logger.info(f"Creating working directory {self.working_dir}") 188 | os.makedirs(self.working_dir) 189 | 190 | self.llm_response_cache = ( 191 | self.key_string_value_json_storage_cls( 192 | namespace="llm_response_cache", 193 | global_config=asdict(self), 194 | embedding_func=None, 195 | ) 196 | if self.enable_llm_cache 197 | else None 198 | ) 199 | self.embedding_func = limit_async_func_call(self.embedding_func_max_async)( 200 | self.embedding_func 201 | ) 202 | 203 | 204 | self.full_docs = self.key_string_value_json_storage_cls( 205 | namespace="full_docs", 206 | global_config=asdict(self), 207 | embedding_func=self.embedding_func, 208 | ) 209 | self.text_chunks = self.key_string_value_json_storage_cls( 210 | namespace="text_chunks", 211 | global_config=asdict(self), 212 | embedding_func=self.embedding_func, 213 | ) 214 | self.chunk_entity_relation_graph = self.graph_storage_cls( 215 | namespace="chunk_entity_relation", 216 | global_config=asdict(self), 217 | embedding_func=self.embedding_func, 218 | ) 219 | 220 | 221 | self.entities_vdb = self.vector_db_storage_cls( 222 | namespace="entities", 223 | global_config=asdict(self), 224 | embedding_func=self.embedding_func, 225 | meta_fields={"entity_name"}, 226 | ) 227 | self.relationships_vdb = self.vector_db_storage_cls( 228 | namespace="relationships", 229 | global_config=asdict(self), 230 | embedding_func=self.embedding_func, 231 | meta_fields={"src_id", "tgt_id"}, 232 | ) 233 | self.chunks_vdb = self.vector_db_storage_cls( 234 | namespace="chunks", 235 | global_config=asdict(self), 236 | embedding_func=self.embedding_func, 237 | ) 238 | 239 | self.llm_model_func = limit_async_func_call(self.llm_model_max_async)( 240 | partial( 241 | self.llm_model_func, 242 | hashing_kv=self.llm_response_cache 243 | if self.llm_response_cache 244 | and hasattr(self.llm_response_cache, "global_config") 245 | else self.key_string_value_json_storage_cls( 246 | global_config=asdict(self), 247 | ), 248 | **self.llm_model_kwargs, 249 | ) 250 | ) 251 | 252 | def _get_storage_class(self) -> Type[BaseGraphStorage]: 253 | return { 254 | 255 | "JsonKVStorage": JsonKVStorage, 256 | "OracleKVStorage": OracleKVStorage, 257 | "MongoKVStorage": MongoKVStorage, 258 | "TiDBKVStorage": TiDBKVStorage, 259 | 260 | "NanoVectorDBStorage": NanoVectorDBStorage, 261 | "OracleVectorDBStorage": OracleVectorDBStorage, 262 | "MilvusVectorDBStorge": MilvusVectorDBStorge, 263 | "ChromaVectorDBStorage": ChromaVectorDBStorage, 264 | "TiDBVectorDBStorage": TiDBVectorDBStorage, 265 | 266 | "NetworkXStorage": NetworkXStorage, 267 | "Neo4JStorage": Neo4JStorage, 268 | "OracleGraphStorage": OracleGraphStorage, 269 | "AGEStorage": AGEStorage, 270 | 271 | } 272 | 273 | def insert(self, string_or_strings): 274 | 275 | loop = always_get_an_event_loop() 276 | return loop.run_until_complete(self.ainsert(string_or_strings)) 277 | 278 | async def ainsert(self, string_or_strings): 279 | update_storage = False 280 | try: 281 | if isinstance(string_or_strings, str): 282 | string_or_strings = [string_or_strings] 283 | 284 | new_docs = { 285 | compute_mdhash_id(c.strip(), prefix="doc-"): {"content": c.strip()} 286 | for c in string_or_strings 287 | } 288 | _add_doc_keys = await self.full_docs.filter_keys(list(new_docs.keys())) 289 | new_docs = {k: v for k, v in new_docs.items() if k in _add_doc_keys} 290 | if not len(new_docs): 291 | logger.warning("All docs are already in the storage") 292 | return 293 | update_storage = True 294 | logger.info(f"[New Docs] inserting {len(new_docs)} docs") 295 | 296 | inserting_chunks = {} 297 | for doc_key, doc in tqdm_async( 298 | new_docs.items(), desc="Chunking documents", unit="doc" 299 | ): 300 | chunks = { 301 | compute_mdhash_id(dp["content"], prefix="chunk-"): { 302 | **dp, 303 | "full_doc_id": doc_key, 304 | } 305 | for dp in chunking_by_token_size( 306 | doc["content"], 307 | overlap_token_size=self.chunk_overlap_token_size, 308 | max_token_size=self.chunk_token_size, 309 | tiktoken_model=self.tiktoken_model_name, 310 | ) 311 | } 312 | inserting_chunks.update(chunks) 313 | _add_chunk_keys = await self.text_chunks.filter_keys( 314 | list(inserting_chunks.keys()) 315 | ) 316 | inserting_chunks = { 317 | k: v for k, v in inserting_chunks.items() if k in _add_chunk_keys 318 | } 319 | if not len(inserting_chunks): 320 | logger.warning("All chunks are already in the storage") 321 | return 322 | logger.info(f"[New Chunks] inserting {len(inserting_chunks)} chunks") 323 | 324 | await self.chunks_vdb.upsert(inserting_chunks) 325 | 326 | logger.info("[Entity Extraction]...") 327 | maybe_new_kg = await extract_entities( 328 | inserting_chunks, 329 | knowledge_graph_inst=self.chunk_entity_relation_graph, 330 | entity_vdb=self.entities_vdb, 331 | relationships_vdb=self.relationships_vdb, 332 | global_config=asdict(self), 333 | ) 334 | if maybe_new_kg is None: 335 | logger.warning("No new entities and relationships found") 336 | return 337 | self.chunk_entity_relation_graph = maybe_new_kg 338 | 339 | await self.full_docs.upsert(new_docs) 340 | await self.text_chunks.upsert(inserting_chunks) 341 | finally: 342 | if update_storage: 343 | await self._insert_done() 344 | 345 | async def _insert_done(self): 346 | tasks = [] 347 | for storage_inst in [ 348 | self.full_docs, 349 | self.text_chunks, 350 | self.llm_response_cache, 351 | self.entities_vdb, 352 | self.relationships_vdb, 353 | self.chunks_vdb, 354 | self.chunk_entity_relation_graph, 355 | ]: 356 | if storage_inst is None: 357 | continue 358 | tasks.append(cast(StorageNameSpace, storage_inst).index_done_callback()) 359 | await asyncio.gather(*tasks) 360 | 361 | def insert_custom_kg(self, custom_kg: dict): 362 | loop = always_get_an_event_loop() 363 | return loop.run_until_complete(self.ainsert_custom_kg(custom_kg)) 364 | 365 | async def ainsert_custom_kg(self, custom_kg: dict): 366 | update_storage = False 367 | try: 368 | 369 | all_chunks_data = {} 370 | chunk_to_source_map = {} 371 | for chunk_data in custom_kg.get("chunks", []): 372 | chunk_content = chunk_data["content"] 373 | source_id = chunk_data["source_id"] 374 | chunk_id = compute_mdhash_id(chunk_content.strip(), prefix="chunk-") 375 | 376 | chunk_entry = {"content": chunk_content.strip(), "source_id": source_id} 377 | all_chunks_data[chunk_id] = chunk_entry 378 | chunk_to_source_map[source_id] = chunk_id 379 | update_storage = True 380 | 381 | if self.chunks_vdb is not None and all_chunks_data: 382 | await self.chunks_vdb.upsert(all_chunks_data) 383 | if self.text_chunks is not None and all_chunks_data: 384 | await self.text_chunks.upsert(all_chunks_data) 385 | 386 | 387 | all_entities_data = [] 388 | for entity_data in custom_kg.get("entities", []): 389 | entity_name = f'"{entity_data["entity_name"].upper()}"' 390 | entity_type = entity_data.get("entity_type", "UNKNOWN") 391 | description = entity_data.get("description", "No description provided") 392 | 393 | source_chunk_id = entity_data.get("source_id", "UNKNOWN") 394 | source_id = chunk_to_source_map.get(source_chunk_id, "UNKNOWN") 395 | 396 | 397 | if source_id == "UNKNOWN": 398 | logger.warning( 399 | f"Entity '{entity_name}' has an UNKNOWN source_id. Please check the source mapping." 400 | ) 401 | 402 | 403 | node_data = { 404 | "entity_type": entity_type, 405 | "description": description, 406 | "source_id": source_id, 407 | } 408 | 409 | await self.chunk_entity_relation_graph.upsert_node( 410 | entity_name, node_data=node_data 411 | ) 412 | node_data["entity_name"] = entity_name 413 | all_entities_data.append(node_data) 414 | update_storage = True 415 | 416 | 417 | all_relationships_data = [] 418 | for relationship_data in custom_kg.get("relationships", []): 419 | src_id = f'"{relationship_data["src_id"].upper()}"' 420 | tgt_id = f'"{relationship_data["tgt_id"].upper()}"' 421 | description = relationship_data["description"] 422 | keywords = relationship_data["keywords"] 423 | weight = relationship_data.get("weight", 1.0) 424 | 425 | source_chunk_id = relationship_data.get("source_id", "UNKNOWN") 426 | source_id = chunk_to_source_map.get(source_chunk_id, "UNKNOWN") 427 | 428 | 429 | if source_id == "UNKNOWN": 430 | logger.warning( 431 | f"Relationship from '{src_id}' to '{tgt_id}' has an UNKNOWN source_id. Please check the source mapping." 432 | ) 433 | 434 | 435 | for need_insert_id in [src_id, tgt_id]: 436 | if not ( 437 | await self.chunk_entity_relation_graph.has_node(need_insert_id) 438 | ): 439 | await self.chunk_entity_relation_graph.upsert_node( 440 | need_insert_id, 441 | node_data={ 442 | "source_id": source_id, 443 | "description": "UNKNOWN", 444 | "entity_type": "UNKNOWN", 445 | }, 446 | ) 447 | 448 | 449 | await self.chunk_entity_relation_graph.upsert_edge( 450 | src_id, 451 | tgt_id, 452 | edge_data={ 453 | "weight": weight, 454 | "description": description, 455 | "keywords": keywords, 456 | "source_id": source_id, 457 | }, 458 | ) 459 | edge_data = { 460 | "src_id": src_id, 461 | "tgt_id": tgt_id, 462 | "description": description, 463 | "keywords": keywords, 464 | } 465 | all_relationships_data.append(edge_data) 466 | update_storage = True 467 | 468 | 469 | if self.entities_vdb is not None: 470 | data_for_vdb = { 471 | compute_mdhash_id(dp["entity_name"], prefix="ent-"): { 472 | "content": dp["entity_name"] + dp["description"], 473 | "entity_name": dp["entity_name"], 474 | } 475 | for dp in all_entities_data 476 | } 477 | await self.entities_vdb.upsert(data_for_vdb) 478 | 479 | 480 | if self.relationships_vdb is not None: 481 | data_for_vdb = { 482 | compute_mdhash_id(dp["src_id"] + dp["tgt_id"], prefix="rel-"): { 483 | "src_id": dp["src_id"], 484 | "tgt_id": dp["tgt_id"], 485 | "content": dp["keywords"] 486 | + dp["src_id"] 487 | + dp["tgt_id"] 488 | + dp["description"], 489 | } 490 | for dp in all_relationships_data 491 | } 492 | await self.relationships_vdb.upsert(data_for_vdb) 493 | finally: 494 | if update_storage: 495 | await self._insert_done() 496 | 497 | def query(self, query: str, param: QueryParam = QueryParam()): 498 | loop = always_get_an_event_loop() 499 | return loop.run_until_complete(self.aquery(query, param)) 500 | 501 | async def aquery(self, query: str, param: QueryParam = QueryParam()): 502 | if param.mode in ["hybrid"]: 503 | response= await kg_query( 504 | query, 505 | self.chunk_entity_relation_graph, 506 | self.entities_vdb, 507 | self.relationships_vdb, 508 | self.text_chunks, 509 | param, 510 | asdict(self), 511 | hashing_kv=self.llm_response_cache 512 | if self.llm_response_cache 513 | and hasattr(self.llm_response_cache, "global_config") 514 | else self.key_string_value_json_storage_cls( 515 | global_config=asdict(self), 516 | ), 517 | ) 518 | print("response all ready") 519 | else: 520 | raise ValueError(f"Unknown mode {param.mode}") 521 | await self._query_done() 522 | return response 523 | 524 | 525 | async def _query_done(self): 526 | tasks = [] 527 | for storage_inst in [self.llm_response_cache]: 528 | if storage_inst is None: 529 | continue 530 | tasks.append(cast(StorageNameSpace, storage_inst).index_done_callback()) 531 | await asyncio.gather(*tasks) 532 | 533 | def delete_by_entity(self, entity_name: str): 534 | loop = always_get_an_event_loop() 535 | return loop.run_until_complete(self.adelete_by_entity(entity_name)) 536 | 537 | async def adelete_by_entity(self, entity_name: str): 538 | entity_name = f'"{entity_name.upper()}"' 539 | 540 | try: 541 | await self.entities_vdb.delete_entity(entity_name) 542 | await self.relationships_vdb.delete_relation(entity_name) 543 | await self.chunk_entity_relation_graph.delete_node(entity_name) 544 | 545 | logger.info( 546 | f"Entity '{entity_name}' and its relationships have been deleted." 547 | ) 548 | await self._delete_by_entity_done() 549 | except Exception as e: 550 | logger.error(f"Error while deleting entity '{entity_name}': {e}") 551 | 552 | async def _delete_by_entity_done(self): 553 | tasks = [] 554 | for storage_inst in [ 555 | self.entities_vdb, 556 | self.relationships_vdb, 557 | self.chunk_entity_relation_graph, 558 | ]: 559 | if storage_inst is None: 560 | continue 561 | tasks.append(cast(StorageNameSpace, storage_inst).index_done_callback()) 562 | await asyncio.gather(*tasks) 563 | -------------------------------------------------------------------------------- /PathRAG/__init__.py: -------------------------------------------------------------------------------- 1 | from .PathRAG import PathRAG as PathRAG, QueryParam as QueryParam 2 | 3 | 4 | -------------------------------------------------------------------------------- /PathRAG/__pycache__/PathRAG.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BUPT-GAMMA/PathRAG/d9be0879d7c32dc35f1bd38ffa4b04b11ba9b246/PathRAG/__pycache__/PathRAG.cpython-39.pyc -------------------------------------------------------------------------------- /PathRAG/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BUPT-GAMMA/PathRAG/d9be0879d7c32dc35f1bd38ffa4b04b11ba9b246/PathRAG/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /PathRAG/__pycache__/base.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BUPT-GAMMA/PathRAG/d9be0879d7c32dc35f1bd38ffa4b04b11ba9b246/PathRAG/__pycache__/base.cpython-39.pyc -------------------------------------------------------------------------------- /PathRAG/__pycache__/llm.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BUPT-GAMMA/PathRAG/d9be0879d7c32dc35f1bd38ffa4b04b11ba9b246/PathRAG/__pycache__/llm.cpython-39.pyc -------------------------------------------------------------------------------- /PathRAG/__pycache__/operate.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BUPT-GAMMA/PathRAG/d9be0879d7c32dc35f1bd38ffa4b04b11ba9b246/PathRAG/__pycache__/operate.cpython-39.pyc -------------------------------------------------------------------------------- /PathRAG/__pycache__/prompt.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BUPT-GAMMA/PathRAG/d9be0879d7c32dc35f1bd38ffa4b04b11ba9b246/PathRAG/__pycache__/prompt.cpython-39.pyc -------------------------------------------------------------------------------- /PathRAG/__pycache__/storage.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BUPT-GAMMA/PathRAG/d9be0879d7c32dc35f1bd38ffa4b04b11ba9b246/PathRAG/__pycache__/storage.cpython-39.pyc -------------------------------------------------------------------------------- /PathRAG/__pycache__/utils.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BUPT-GAMMA/PathRAG/d9be0879d7c32dc35f1bd38ffa4b04b11ba9b246/PathRAG/__pycache__/utils.cpython-39.pyc -------------------------------------------------------------------------------- /PathRAG/base.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, field 2 | from typing import TypedDict, Union, Literal, Generic, TypeVar 3 | 4 | import numpy as np 5 | 6 | from .utils import EmbeddingFunc 7 | 8 | TextChunkSchema = TypedDict( 9 | "TextChunkSchema", 10 | {"tokens": int, "content": str, "full_doc_id": str, "chunk_order_index": int}, 11 | ) 12 | 13 | T = TypeVar("T") 14 | 15 | 16 | @dataclass 17 | class QueryParam: 18 | mode: Literal["hybrid"] = "global" 19 | only_need_context: bool = False 20 | only_need_prompt: bool = False 21 | response_type: str = "Multiple Paragraphs" 22 | stream: bool = False 23 | top_k: int =40 24 | max_token_for_text_unit: int = 4000 25 | max_token_for_global_context: int = 3000 26 | max_token_for_local_context: int = 5000 27 | 28 | 29 | @dataclass 30 | class StorageNameSpace: 31 | namespace: str 32 | global_config: dict 33 | 34 | async def index_done_callback(self): 35 | 36 | pass 37 | 38 | async def query_done_callback(self): 39 | 40 | pass 41 | 42 | 43 | @dataclass 44 | class BaseVectorStorage(StorageNameSpace): 45 | embedding_func: EmbeddingFunc 46 | meta_fields: set = field(default_factory=set) 47 | 48 | async def query(self, query: str, top_k: int) -> list[dict]: 49 | raise NotImplementedError 50 | 51 | async def upsert(self, data: dict[str, dict]): 52 | 53 | raise NotImplementedError 54 | 55 | 56 | @dataclass 57 | class BaseKVStorage(Generic[T], StorageNameSpace): 58 | embedding_func: EmbeddingFunc 59 | 60 | async def all_keys(self) -> list[str]: 61 | raise NotImplementedError 62 | 63 | async def get_by_id(self, id: str) -> Union[T, None]: 64 | raise NotImplementedError 65 | 66 | async def get_by_ids( 67 | self, ids: list[str], fields: Union[set[str], None] = None 68 | ) -> list[Union[T, None]]: 69 | raise NotImplementedError 70 | 71 | async def filter_keys(self, data: list[str]) -> set[str]: 72 | 73 | raise NotImplementedError 74 | 75 | async def upsert(self, data: dict[str, T]): 76 | raise NotImplementedError 77 | 78 | async def drop(self): 79 | raise NotImplementedError 80 | 81 | 82 | @dataclass 83 | class BaseGraphStorage(StorageNameSpace): 84 | embedding_func: EmbeddingFunc = None 85 | 86 | async def has_node(self, node_id: str) -> bool: 87 | raise NotImplementedError 88 | 89 | async def has_edge(self, source_node_id: str, target_node_id: str) -> bool: 90 | raise NotImplementedError 91 | 92 | async def node_degree(self, node_id: str) -> int: 93 | raise NotImplementedError 94 | 95 | async def edge_degree(self, src_id: str, tgt_id: str) -> int: 96 | raise NotImplementedError 97 | 98 | async def get_pagerank(self,node_id:str) -> float: 99 | raise NotImplementedError 100 | 101 | async def get_node(self, node_id: str) -> Union[dict, None]: 102 | raise NotImplementedError 103 | 104 | async def get_edge( 105 | self, source_node_id: str, target_node_id: str 106 | ) -> Union[dict, None]: 107 | raise NotImplementedError 108 | 109 | async def get_node_edges( 110 | self, source_node_id: str 111 | ) -> Union[list[tuple[str, str]], None]: 112 | raise NotImplementedError 113 | 114 | async def get_node_in_edges( 115 | self,source_node_id:str 116 | ) -> Union[list[tuple[str,str]],None]: 117 | raise NotImplementedError 118 | async def get_node_out_edges( 119 | self,source_node_id:str 120 | ) -> Union[list[tuple[str,str]],None]: 121 | raise NotImplementedError 122 | 123 | async def upsert_node(self, node_id: str, node_data: dict[str, str]): 124 | raise NotImplementedError 125 | 126 | async def upsert_edge( 127 | self, source_node_id: str, target_node_id: str, edge_data: dict[str, str] 128 | ): 129 | raise NotImplementedError 130 | 131 | async def delete_node(self, node_id: str): 132 | raise NotImplementedError 133 | 134 | async def embed_nodes(self, algorithm: str) -> tuple[np.ndarray, list[str]]: 135 | raise NotImplementedError("Node embedding is not used in PathRag.") 136 | -------------------------------------------------------------------------------- /PathRAG/llm.py: -------------------------------------------------------------------------------- 1 | import base64 2 | import copy 3 | import json 4 | import os 5 | import re 6 | import struct 7 | from functools import lru_cache 8 | from typing import List, Dict, Callable, Any, Union, Optional 9 | import aioboto3 10 | import aiohttp 11 | import numpy as np 12 | import ollama 13 | import torch 14 | import time 15 | from openai import ( 16 | AsyncOpenAI, 17 | APIConnectionError, 18 | RateLimitError, 19 | Timeout, 20 | AsyncAzureOpenAI, 21 | ) 22 | from pydantic import BaseModel, Field 23 | from tenacity import ( 24 | retry, 25 | stop_after_attempt, 26 | wait_exponential, 27 | retry_if_exception_type, 28 | ) 29 | from transformers import AutoTokenizer, AutoModelForCausalLM 30 | 31 | from .utils import ( 32 | wrap_embedding_func_with_attrs, 33 | locate_json_string_body_from_string, 34 | safe_unicode_decode, 35 | logger, 36 | ) 37 | 38 | import sys 39 | 40 | if sys.version_info < (3, 9): 41 | from typing import AsyncIterator 42 | else: 43 | from collections.abc import AsyncIterator 44 | 45 | os.environ["TOKENIZERS_PARALLELISM"] = "false" 46 | 47 | 48 | @retry( 49 | stop=stop_after_attempt(3), 50 | wait=wait_exponential(multiplier=1, min=4, max=10), 51 | retry=retry_if_exception_type((RateLimitError, APIConnectionError, Timeout)), 52 | ) 53 | async def openai_complete_if_cache( 54 | model, 55 | prompt, 56 | system_prompt=None, 57 | history_messages=[], 58 | base_url="https://api.openai.com/v1", 59 | api_key="", 60 | **kwargs, 61 | ) -> str: 62 | if api_key: 63 | os.environ["OPENAI_API_KEY"] = api_key 64 | time.sleep(2) 65 | openai_async_client = ( 66 | AsyncOpenAI() if base_url is None else AsyncOpenAI(base_url=base_url) 67 | ) 68 | kwargs.pop("hashing_kv", None) 69 | kwargs.pop("keyword_extraction", None) 70 | messages = [] 71 | if system_prompt: 72 | messages.append({"role": "system", "content": system_prompt}) 73 | messages.extend(history_messages) 74 | messages.append({"role": "user", "content": prompt}) 75 | 76 | 77 | logger.debug("===== Query Input to LLM =====") 78 | logger.debug(f"Query: {prompt}") 79 | logger.debug(f"System prompt: {system_prompt}") 80 | logger.debug("Full context:") 81 | if "response_format" in kwargs: 82 | response = await openai_async_client.beta.chat.completions.parse( 83 | model=model, messages=messages, **kwargs 84 | ) 85 | else: 86 | response = await openai_async_client.chat.completions.create( 87 | model=model, messages=messages, **kwargs 88 | ) 89 | 90 | if hasattr(response, "__aiter__"): 91 | 92 | async def inner(): 93 | async for chunk in response: 94 | content = chunk.choices[0].delta.content 95 | if content is None: 96 | continue 97 | if r"\u" in content: 98 | content = safe_unicode_decode(content.encode("utf-8")) 99 | yield content 100 | 101 | return inner() 102 | else: 103 | content = response.choices[0].message.content 104 | if r"\u" in content: 105 | content = safe_unicode_decode(content.encode("utf-8")) 106 | return content 107 | 108 | 109 | @retry( 110 | stop=stop_after_attempt(3), 111 | wait=wait_exponential(multiplier=1, min=4, max=10), 112 | retry=retry_if_exception_type((RateLimitError, APIConnectionError, Timeout)), 113 | ) 114 | async def azure_openai_complete_if_cache( 115 | model, 116 | prompt, 117 | system_prompt=None, 118 | history_messages=[], 119 | base_url=None, 120 | api_key=None, 121 | api_version=None, 122 | **kwargs, 123 | ): 124 | if api_key: 125 | os.environ["AZURE_OPENAI_API_KEY"] = api_key 126 | if base_url: 127 | os.environ["AZURE_OPENAI_ENDPOINT"] = base_url 128 | if api_version: 129 | os.environ["AZURE_OPENAI_API_VERSION"] = api_version 130 | 131 | openai_async_client = AsyncAzureOpenAI( 132 | azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT"), 133 | api_key=os.getenv("AZURE_OPENAI_API_KEY"), 134 | api_version=os.getenv("AZURE_OPENAI_API_VERSION"), 135 | ) 136 | kwargs.pop("hashing_kv", None) 137 | messages = [] 138 | if system_prompt: 139 | messages.append({"role": "system", "content": system_prompt}) 140 | messages.extend(history_messages) 141 | if prompt is not None: 142 | messages.append({"role": "user", "content": prompt}) 143 | 144 | response = await openai_async_client.chat.completions.create( 145 | model=model, messages=messages, **kwargs 146 | ) 147 | content = response.choices[0].message.content 148 | 149 | return content 150 | 151 | 152 | class BedrockError(Exception): 153 | """Generic error for issues related to Amazon Bedrock""" 154 | 155 | 156 | @retry( 157 | stop=stop_after_attempt(5), 158 | wait=wait_exponential(multiplier=1, max=60), 159 | retry=retry_if_exception_type((BedrockError)), 160 | ) 161 | async def bedrock_complete_if_cache( 162 | model, 163 | prompt, 164 | system_prompt=None, 165 | history_messages=[], 166 | aws_access_key_id=None, 167 | aws_secret_access_key=None, 168 | aws_session_token=None, 169 | **kwargs, 170 | ) -> str: 171 | os.environ["AWS_ACCESS_KEY_ID"] = os.environ.get( 172 | "AWS_ACCESS_KEY_ID", aws_access_key_id 173 | ) 174 | os.environ["AWS_SECRET_ACCESS_KEY"] = os.environ.get( 175 | "AWS_SECRET_ACCESS_KEY", aws_secret_access_key 176 | ) 177 | os.environ["AWS_SESSION_TOKEN"] = os.environ.get( 178 | "AWS_SESSION_TOKEN", aws_session_token 179 | ) 180 | kwargs.pop("hashing_kv", None) 181 | 182 | messages = [] 183 | for history_message in history_messages: 184 | message = copy.copy(history_message) 185 | message["content"] = [{"text": message["content"]}] 186 | messages.append(message) 187 | 188 | 189 | messages.append({"role": "user", "content": [{"text": prompt}]}) 190 | 191 | 192 | args = {"modelId": model, "messages": messages} 193 | 194 | 195 | if system_prompt: 196 | args["system"] = [{"text": system_prompt}] 197 | 198 | 199 | inference_params_map = { 200 | "max_tokens": "maxTokens", 201 | "top_p": "topP", 202 | "stop_sequences": "stopSequences", 203 | } 204 | if inference_params := list( 205 | set(kwargs) & set(["max_tokens", "temperature", "top_p", "stop_sequences"]) 206 | ): 207 | args["inferenceConfig"] = {} 208 | for param in inference_params: 209 | args["inferenceConfig"][inference_params_map.get(param, param)] = ( 210 | kwargs.pop(param) 211 | ) 212 | 213 | 214 | session = aioboto3.Session() 215 | async with session.client("bedrock-runtime") as bedrock_async_client: 216 | try: 217 | response = await bedrock_async_client.converse(**args, **kwargs) 218 | except Exception as e: 219 | raise BedrockError(e) 220 | 221 | return response["output"]["message"]["content"][0]["text"] 222 | 223 | 224 | @lru_cache(maxsize=1) 225 | def initialize_hf_model(model_name): 226 | hf_tokenizer = AutoTokenizer.from_pretrained( 227 | model_name, device_map="auto", trust_remote_code=True 228 | ) 229 | hf_model = AutoModelForCausalLM.from_pretrained( 230 | model_name, device_map="auto", trust_remote_code=True 231 | ) 232 | if hf_tokenizer.pad_token is None: 233 | hf_tokenizer.pad_token = hf_tokenizer.eos_token 234 | 235 | return hf_model, hf_tokenizer 236 | 237 | 238 | @retry( 239 | stop=stop_after_attempt(3), 240 | wait=wait_exponential(multiplier=1, min=4, max=10), 241 | retry=retry_if_exception_type((RateLimitError, APIConnectionError, Timeout)), 242 | ) 243 | async def hf_model_if_cache( 244 | model, 245 | prompt, 246 | system_prompt=None, 247 | history_messages=[], 248 | **kwargs, 249 | ) -> str: 250 | model_name = model 251 | hf_model, hf_tokenizer = initialize_hf_model(model_name) 252 | messages = [] 253 | if system_prompt: 254 | messages.append({"role": "system", "content": system_prompt}) 255 | messages.extend(history_messages) 256 | messages.append({"role": "user", "content": prompt}) 257 | kwargs.pop("hashing_kv", None) 258 | input_prompt = "" 259 | try: 260 | input_prompt = hf_tokenizer.apply_chat_template( 261 | messages, tokenize=False, add_generation_prompt=True 262 | ) 263 | except Exception: 264 | try: 265 | ori_message = copy.deepcopy(messages) 266 | if messages[0]["role"] == "system": 267 | messages[1]["content"] = ( 268 | "" 269 | + messages[0]["content"] 270 | + "\n" 271 | + messages[1]["content"] 272 | ) 273 | messages = messages[1:] 274 | input_prompt = hf_tokenizer.apply_chat_template( 275 | messages, tokenize=False, add_generation_prompt=True 276 | ) 277 | except Exception: 278 | len_message = len(ori_message) 279 | for msgid in range(len_message): 280 | input_prompt = ( 281 | input_prompt 282 | + "<" 283 | + ori_message[msgid]["role"] 284 | + ">" 285 | + ori_message[msgid]["content"] 286 | + "\n" 289 | ) 290 | 291 | input_ids = hf_tokenizer( 292 | input_prompt, return_tensors="pt", padding=True, truncation=True 293 | ).to("cuda") 294 | inputs = {k: v.to(hf_model.device) for k, v in input_ids.items()} 295 | output = hf_model.generate( 296 | **input_ids, max_new_tokens=512, num_return_sequences=1, early_stopping=True 297 | ) 298 | response_text = hf_tokenizer.decode( 299 | output[0][len(inputs["input_ids"][0]) :], skip_special_tokens=True 300 | ) 301 | 302 | return response_text 303 | 304 | 305 | @retry( 306 | stop=stop_after_attempt(3), 307 | wait=wait_exponential(multiplier=1, min=4, max=10), 308 | retry=retry_if_exception_type((RateLimitError, APIConnectionError, Timeout)), 309 | ) 310 | async def ollama_model_if_cache( 311 | model, 312 | prompt, 313 | system_prompt=None, 314 | history_messages=[], 315 | **kwargs, 316 | ) -> Union[str, AsyncIterator[str]]: 317 | stream = True if kwargs.get("stream") else False 318 | kwargs.pop("max_tokens", None) 319 | host = kwargs.pop("host", None) 320 | timeout = kwargs.pop("timeout", None) 321 | kwargs.pop("hashing_kv", None) 322 | ollama_client = ollama.AsyncClient(host=host, timeout=timeout) 323 | messages = [] 324 | if system_prompt: 325 | messages.append({"role": "system", "content": system_prompt}) 326 | messages.extend(history_messages) 327 | messages.append({"role": "user", "content": prompt}) 328 | 329 | response = await ollama_client.chat(model=model, messages=messages, **kwargs) 330 | if stream: 331 | """cannot cache stream response""" 332 | 333 | async def inner(): 334 | async for chunk in response: 335 | yield chunk["message"]["content"] 336 | 337 | return inner() 338 | else: 339 | return response["message"]["content"] 340 | 341 | 342 | @lru_cache(maxsize=1) 343 | def initialize_lmdeploy_pipeline( 344 | model, 345 | tp=1, 346 | chat_template=None, 347 | log_level="WARNING", 348 | model_format="hf", 349 | quant_policy=0, 350 | ): 351 | from lmdeploy import pipeline, ChatTemplateConfig, TurbomindEngineConfig 352 | 353 | lmdeploy_pipe = pipeline( 354 | model_path=model, 355 | backend_config=TurbomindEngineConfig( 356 | tp=tp, model_format=model_format, quant_policy=quant_policy 357 | ), 358 | chat_template_config=( 359 | ChatTemplateConfig(model_name=chat_template) if chat_template else None 360 | ), 361 | log_level="WARNING", 362 | ) 363 | return lmdeploy_pipe 364 | 365 | 366 | @retry( 367 | stop=stop_after_attempt(3), 368 | wait=wait_exponential(multiplier=1, min=4, max=10), 369 | retry=retry_if_exception_type((RateLimitError, APIConnectionError, Timeout)), 370 | ) 371 | async def lmdeploy_model_if_cache( 372 | model, 373 | prompt, 374 | system_prompt=None, 375 | history_messages=[], 376 | chat_template=None, 377 | model_format="hf", 378 | quant_policy=0, 379 | **kwargs, 380 | ) -> str: 381 | """ 382 | Args: 383 | model (str): The path to the model. 384 | It could be one of the following options: 385 | - i) A local directory path of a turbomind model which is 386 | converted by `lmdeploy convert` command or download 387 | from ii) and iii). 388 | - ii) The model_id of a lmdeploy-quantized model hosted 389 | inside a model repo on huggingface.co, such as 390 | "InternLM/internlm-chat-20b-4bit", 391 | "lmdeploy/llama2-chat-70b-4bit", etc. 392 | - iii) The model_id of a model hosted inside a model repo 393 | on huggingface.co, such as "internlm/internlm-chat-7b", 394 | "Qwen/Qwen-7B-Chat ", "baichuan-inc/Baichuan2-7B-Chat" 395 | and so on. 396 | chat_template (str): needed when model is a pytorch model on 397 | huggingface.co, such as "internlm-chat-7b", 398 | "Qwen-7B-Chat ", "Baichuan2-7B-Chat" and so on, 399 | and when the model name of local path did not match the original model name in HF. 400 | tp (int): tensor parallel 401 | prompt (Union[str, List[str]]): input texts to be completed. 402 | do_preprocess (bool): whether pre-process the messages. Default to 403 | True, which means chat_template will be applied. 404 | skip_special_tokens (bool): Whether or not to remove special tokens 405 | in the decoding. Default to be True. 406 | do_sample (bool): Whether or not to use sampling, use greedy decoding otherwise. 407 | Default to be False, which means greedy decoding will be applied. 408 | """ 409 | try: 410 | import lmdeploy 411 | from lmdeploy import version_info, GenerationConfig 412 | except Exception: 413 | raise ImportError("Please install lmdeploy before initialize lmdeploy backend.") 414 | kwargs.pop("hashing_kv", None) 415 | kwargs.pop("response_format", None) 416 | max_new_tokens = kwargs.pop("max_tokens", 512) 417 | tp = kwargs.pop("tp", 1) 418 | skip_special_tokens = kwargs.pop("skip_special_tokens", True) 419 | do_preprocess = kwargs.pop("do_preprocess", True) 420 | do_sample = kwargs.pop("do_sample", False) 421 | gen_params = kwargs 422 | 423 | version = version_info 424 | if do_sample is not None and version < (0, 6, 0): 425 | raise RuntimeError( 426 | "`do_sample` parameter is not supported by lmdeploy until " 427 | f"v0.6.0, but currently using lmdeloy {lmdeploy.__version__}" 428 | ) 429 | else: 430 | do_sample = True 431 | gen_params.update(do_sample=do_sample) 432 | 433 | lmdeploy_pipe = initialize_lmdeploy_pipeline( 434 | model=model, 435 | tp=tp, 436 | chat_template=chat_template, 437 | model_format=model_format, 438 | quant_policy=quant_policy, 439 | log_level="WARNING", 440 | ) 441 | 442 | messages = [] 443 | if system_prompt: 444 | messages.append({"role": "system", "content": system_prompt}) 445 | 446 | messages.extend(history_messages) 447 | messages.append({"role": "user", "content": prompt}) 448 | 449 | gen_config = GenerationConfig( 450 | skip_special_tokens=skip_special_tokens, 451 | max_new_tokens=max_new_tokens, 452 | **gen_params, 453 | ) 454 | 455 | response = "" 456 | async for res in lmdeploy_pipe.generate( 457 | messages, 458 | gen_config=gen_config, 459 | do_preprocess=do_preprocess, 460 | stream_response=False, 461 | session_id=1, 462 | ): 463 | response += res.response 464 | return response 465 | 466 | 467 | class GPTKeywordExtractionFormat(BaseModel): 468 | high_level_keywords: List[str] 469 | low_level_keywords: List[str] 470 | 471 | 472 | async def openai_complete( 473 | prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs 474 | ) -> Union[str, AsyncIterator[str]]: 475 | keyword_extraction = kwargs.pop("keyword_extraction", None) 476 | if keyword_extraction: 477 | kwargs["response_format"] = "json" 478 | model_name = kwargs["hashing_kv"].global_config["llm_model_name"] 479 | return await openai_complete_if_cache( 480 | model_name, 481 | prompt, 482 | system_prompt=system_prompt, 483 | history_messages=history_messages, 484 | **kwargs, 485 | ) 486 | 487 | 488 | async def gpt_4o_complete( 489 | prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs 490 | ) -> str: 491 | keyword_extraction = kwargs.pop("keyword_extraction", None) 492 | if keyword_extraction: 493 | kwargs["response_format"] = GPTKeywordExtractionFormat 494 | return await openai_complete_if_cache( 495 | "gpt-4o", 496 | prompt, 497 | system_prompt=system_prompt, 498 | history_messages=history_messages, 499 | **kwargs, 500 | ) 501 | 502 | 503 | async def gpt_4o_mini_complete( 504 | prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs 505 | ) -> str: 506 | keyword_extraction = kwargs.pop("keyword_extraction", None) 507 | if keyword_extraction: 508 | kwargs["response_format"] = GPTKeywordExtractionFormat 509 | return await openai_complete_if_cache( 510 | "gpt-4o-mini", 511 | prompt, 512 | system_prompt=system_prompt, 513 | history_messages=history_messages, 514 | **kwargs, 515 | ) 516 | 517 | 518 | async def nvidia_openai_complete( 519 | prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs 520 | ) -> str: 521 | keyword_extraction = kwargs.pop("keyword_extraction", None) 522 | result = await openai_complete_if_cache( 523 | "nvidia/llama-3.1-nemotron-70b-instruct", 524 | prompt, 525 | system_prompt=system_prompt, 526 | history_messages=history_messages, 527 | base_url="https://integrate.api.nvidia.com/v1", 528 | **kwargs, 529 | ) 530 | if keyword_extraction: # TODO: use JSON API 531 | return locate_json_string_body_from_string(result) 532 | return result 533 | 534 | 535 | async def azure_openai_complete( 536 | prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs 537 | ) -> str: 538 | keyword_extraction = kwargs.pop("keyword_extraction", None) 539 | result = await azure_openai_complete_if_cache( 540 | "conversation-4o-mini", 541 | prompt, 542 | system_prompt=system_prompt, 543 | history_messages=history_messages, 544 | **kwargs, 545 | ) 546 | if keyword_extraction: # TODO: use JSON API 547 | return locate_json_string_body_from_string(result) 548 | return result 549 | 550 | 551 | async def bedrock_complete( 552 | prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs 553 | ) -> str: 554 | keyword_extraction = kwargs.pop("keyword_extraction", None) 555 | result = await bedrock_complete_if_cache( 556 | "anthropic.claude-3-haiku-20240307-v1:0", 557 | prompt, 558 | system_prompt=system_prompt, 559 | history_messages=history_messages, 560 | **kwargs, 561 | ) 562 | if keyword_extraction: # TODO: use JSON API 563 | return locate_json_string_body_from_string(result) 564 | return result 565 | 566 | 567 | async def hf_model_complete( 568 | prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs 569 | ) -> str: 570 | keyword_extraction = kwargs.pop("keyword_extraction", None) 571 | model_name = kwargs["hashing_kv"].global_config["llm_model_name"] 572 | result = await hf_model_if_cache( 573 | model_name, 574 | prompt, 575 | system_prompt=system_prompt, 576 | history_messages=history_messages, 577 | **kwargs, 578 | ) 579 | if keyword_extraction: # TODO: use JSON API 580 | return locate_json_string_body_from_string(result) 581 | return result 582 | 583 | 584 | async def ollama_model_complete( 585 | prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs 586 | ) -> Union[str, AsyncIterator[str]]: 587 | keyword_extraction = kwargs.pop("keyword_extraction", None) 588 | if keyword_extraction: 589 | kwargs["format"] = "json" 590 | model_name = kwargs["hashing_kv"].global_config["llm_model_name"] 591 | return await ollama_model_if_cache( 592 | model_name, 593 | prompt, 594 | system_prompt=system_prompt, 595 | history_messages=history_messages, 596 | **kwargs, 597 | ) 598 | 599 | 600 | @retry( 601 | stop=stop_after_attempt(3), 602 | wait=wait_exponential(multiplier=1, min=4, max=10), 603 | retry=retry_if_exception_type((RateLimitError, APIConnectionError, Timeout)), 604 | ) 605 | async def zhipu_complete_if_cache( 606 | prompt: Union[str, List[Dict[str, str]]], 607 | model: str = "glm-4-flashx", 608 | api_key: Optional[str] = None, 609 | system_prompt: Optional[str] = None, 610 | history_messages: List[Dict[str, str]] = [], 611 | **kwargs, 612 | ) -> str: 613 | 614 | try: 615 | from zhipuai import ZhipuAI 616 | except ImportError: 617 | raise ImportError("Please install zhipuai before initialize zhipuai backend.") 618 | 619 | if api_key: 620 | client = ZhipuAI(api_key=api_key) 621 | else: 622 | client = ZhipuAI() 623 | 624 | messages = [] 625 | 626 | if not system_prompt: 627 | system_prompt = "You are a helpful assistant. 注意内容里的敏感词用***替换。" 628 | 629 | 630 | if system_prompt: 631 | messages.append({"role": "system", "content": system_prompt}) 632 | messages.extend(history_messages) 633 | messages.append({"role": "user", "content": prompt}) 634 | 635 | 636 | logger.debug("===== Query Input to LLM =====") 637 | logger.debug(f"Query: {prompt}") 638 | logger.debug(f"System prompt: {system_prompt}") 639 | 640 | 641 | kwargs = { 642 | k: v for k, v in kwargs.items() if k not in ["hashing_kv", "keyword_extraction"] 643 | } 644 | 645 | response = client.chat.completions.create(model=model, messages=messages, **kwargs) 646 | 647 | return response.choices[0].message.content 648 | 649 | 650 | async def zhipu_complete( 651 | prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs 652 | ): 653 | 654 | keyword_extraction = kwargs.pop("keyword_extraction", None) 655 | 656 | if keyword_extraction: 657 | extraction_prompt = """You are a helpful assistant that extracts keywords from text. 658 | Please analyze the content and extract two types of keywords: 659 | 1. High-level keywords: Important concepts and main themes 660 | 2. Low-level keywords: Specific details and supporting elements 661 | 662 | Return your response in this exact JSON format: 663 | { 664 | "high_level_keywords": ["keyword1", "keyword2"], 665 | "low_level_keywords": ["keyword1", "keyword2", "keyword3"] 666 | } 667 | 668 | Only return the JSON, no other text.""" 669 | 670 | 671 | if system_prompt: 672 | system_prompt = f"{system_prompt}\n\n{extraction_prompt}" 673 | else: 674 | system_prompt = extraction_prompt 675 | 676 | try: 677 | response = await zhipu_complete_if_cache( 678 | prompt=prompt, 679 | system_prompt=system_prompt, 680 | history_messages=history_messages, 681 | **kwargs, 682 | ) 683 | 684 | 685 | try: 686 | data = json.loads(response) 687 | return GPTKeywordExtractionFormat( 688 | high_level_keywords=data.get("high_level_keywords", []), 689 | low_level_keywords=data.get("low_level_keywords", []), 690 | ) 691 | except json.JSONDecodeError: 692 | 693 | match = re.search(r"\{[\s\S]*\}", response) 694 | if match: 695 | try: 696 | data = json.loads(match.group()) 697 | return GPTKeywordExtractionFormat( 698 | high_level_keywords=data.get("high_level_keywords", []), 699 | low_level_keywords=data.get("low_level_keywords", []), 700 | ) 701 | except json.JSONDecodeError: 702 | pass 703 | 704 | 705 | logger.warning( 706 | f"Failed to parse keyword extraction response: {response}" 707 | ) 708 | return GPTKeywordExtractionFormat( 709 | high_level_keywords=[], low_level_keywords=[] 710 | ) 711 | except Exception as e: 712 | logger.error(f"Error during keyword extraction: {str(e)}") 713 | return GPTKeywordExtractionFormat( 714 | high_level_keywords=[], low_level_keywords=[] 715 | ) 716 | else: 717 | return await zhipu_complete_if_cache( 718 | prompt=prompt, 719 | system_prompt=system_prompt, 720 | history_messages=history_messages, 721 | **kwargs, 722 | ) 723 | 724 | 725 | @wrap_embedding_func_with_attrs(embedding_dim=1024, max_token_size=8192) 726 | @retry( 727 | stop=stop_after_attempt(3), 728 | wait=wait_exponential(multiplier=1, min=4, max=60), 729 | retry=retry_if_exception_type((RateLimitError, APIConnectionError, Timeout)), 730 | ) 731 | async def zhipu_embedding( 732 | texts: list[str], model: str = "embedding-3", api_key: str = None, **kwargs 733 | ) -> np.ndarray: 734 | 735 | try: 736 | from zhipuai import ZhipuAI 737 | except ImportError: 738 | raise ImportError("Please install zhipuai before initialize zhipuai backend.") 739 | if api_key: 740 | client = ZhipuAI(api_key=api_key) 741 | else: 742 | client = ZhipuAI() 743 | 744 | if isinstance(texts, str): 745 | texts = [texts] 746 | 747 | embeddings = [] 748 | for text in texts: 749 | try: 750 | response = client.embeddings.create(model=model, input=[text], **kwargs) 751 | embeddings.append(response.data[0].embedding) 752 | except Exception as e: 753 | raise Exception(f"Error calling ChatGLM Embedding API: {str(e)}") 754 | 755 | return np.array(embeddings) 756 | 757 | 758 | @wrap_embedding_func_with_attrs(embedding_dim=1536, max_token_size=8192) 759 | @retry( 760 | stop=stop_after_attempt(3), 761 | wait=wait_exponential(multiplier=1, min=4, max=60), 762 | retry=retry_if_exception_type((RateLimitError, APIConnectionError, Timeout)), 763 | ) 764 | async def openai_embedding( 765 | texts: list[str], 766 | model: str = "text-embedding-3-small", 767 | base_url="https://api.openai.com/v1", 768 | api_key="", 769 | ) -> np.ndarray: 770 | if api_key: 771 | os.environ["OPENAI_API_KEY"] = api_key 772 | 773 | openai_async_client = ( 774 | AsyncOpenAI() if base_url is None else AsyncOpenAI(base_url=base_url) 775 | ) 776 | response = await openai_async_client.embeddings.create( 777 | model=model, input=texts, encoding_format="float" 778 | ) 779 | return np.array([dp.embedding for dp in response.data]) 780 | 781 | 782 | async def fetch_data(url, headers, data): 783 | async with aiohttp.ClientSession() as session: 784 | async with session.post(url, headers=headers, json=data) as response: 785 | response_json = await response.json() 786 | data_list = response_json.get("data", []) 787 | return data_list 788 | 789 | 790 | async def jina_embedding( 791 | texts: list[str], 792 | dimensions: int = 1024, 793 | late_chunking: bool = False, 794 | base_url: str = None, 795 | api_key: str = None, 796 | ) -> np.ndarray: 797 | if api_key: 798 | os.environ["JINA_API_KEY"] = api_key 799 | url = "https://api.jina.ai/v1/embeddings" if not base_url else base_url 800 | headers = { 801 | "Content-Type": "application/json", 802 | "Authorization": f"Bearer {os.environ['JINA_API_KEY']}", 803 | } 804 | data = { 805 | "model": "jina-embeddings-v3", 806 | "normalized": True, 807 | "embedding_type": "float", 808 | "dimensions": f"{dimensions}", 809 | "late_chunking": late_chunking, 810 | "input": texts, 811 | } 812 | data_list = await fetch_data(url, headers, data) 813 | return np.array([dp["embedding"] for dp in data_list]) 814 | 815 | 816 | @wrap_embedding_func_with_attrs(embedding_dim=2048, max_token_size=512) 817 | @retry( 818 | stop=stop_after_attempt(3), 819 | wait=wait_exponential(multiplier=1, min=4, max=60), 820 | retry=retry_if_exception_type((RateLimitError, APIConnectionError, Timeout)), 821 | ) 822 | async def nvidia_openai_embedding( 823 | texts: list[str], 824 | model: str = "nvidia/llama-3.2-nv-embedqa-1b-v1", 825 | base_url: str = "https://integrate.api.nvidia.com/v1", 826 | api_key: str = None, 827 | input_type: str = "passage", 828 | trunc: str = "NONE", 829 | encode: str = "float", 830 | ) -> np.ndarray: 831 | if api_key: 832 | os.environ["OPENAI_API_KEY"] = api_key 833 | 834 | openai_async_client = ( 835 | AsyncOpenAI() if base_url is None else AsyncOpenAI(base_url=base_url) 836 | ) 837 | response = await openai_async_client.embeddings.create( 838 | model=model, 839 | input=texts, 840 | encoding_format=encode, 841 | extra_body={"input_type": input_type, "truncate": trunc}, 842 | ) 843 | return np.array([dp.embedding for dp in response.data]) 844 | 845 | 846 | @wrap_embedding_func_with_attrs(embedding_dim=1536, max_token_size=8191) 847 | @retry( 848 | stop=stop_after_attempt(3), 849 | wait=wait_exponential(multiplier=1, min=4, max=10), 850 | retry=retry_if_exception_type((RateLimitError, APIConnectionError, Timeout)), 851 | ) 852 | async def azure_openai_embedding( 853 | texts: list[str], 854 | model: str = "text-embedding-3-small", 855 | base_url: str = None, 856 | api_key: str = None, 857 | api_version: str = None, 858 | ) -> np.ndarray: 859 | if api_key: 860 | os.environ["AZURE_OPENAI_API_KEY"] = api_key 861 | if base_url: 862 | os.environ["AZURE_OPENAI_ENDPOINT"] = base_url 863 | if api_version: 864 | os.environ["AZURE_OPENAI_API_VERSION"] = api_version 865 | 866 | openai_async_client = AsyncAzureOpenAI( 867 | azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT"), 868 | api_key=os.getenv("AZURE_OPENAI_API_KEY"), 869 | api_version=os.getenv("AZURE_OPENAI_API_VERSION"), 870 | ) 871 | 872 | response = await openai_async_client.embeddings.create( 873 | model=model, input=texts, encoding_format="float" 874 | ) 875 | return np.array([dp.embedding for dp in response.data]) 876 | 877 | 878 | @retry( 879 | stop=stop_after_attempt(3), 880 | wait=wait_exponential(multiplier=1, min=4, max=60), 881 | retry=retry_if_exception_type((RateLimitError, APIConnectionError, Timeout)), 882 | ) 883 | async def siliconcloud_embedding( 884 | texts: list[str], 885 | model: str = "netease-youdao/bce-embedding-base_v1", 886 | base_url: str = "https://api.siliconflow.cn/v1/embeddings", 887 | max_token_size: int = 512, 888 | api_key: str = None, 889 | ) -> np.ndarray: 890 | if api_key and not api_key.startswith("Bearer "): 891 | api_key = "Bearer " + api_key 892 | 893 | headers = {"Authorization": api_key, "Content-Type": "application/json"} 894 | 895 | truncate_texts = [text[0:max_token_size] for text in texts] 896 | 897 | payload = {"model": model, "input": truncate_texts, "encoding_format": "base64"} 898 | 899 | base64_strings = [] 900 | async with aiohttp.ClientSession() as session: 901 | async with session.post(base_url, headers=headers, json=payload) as response: 902 | content = await response.json() 903 | if "code" in content: 904 | raise ValueError(content) 905 | base64_strings = [item["embedding"] for item in content["data"]] 906 | 907 | embeddings = [] 908 | for string in base64_strings: 909 | decode_bytes = base64.b64decode(string) 910 | n = len(decode_bytes) // 4 911 | float_array = struct.unpack("<" + "f" * n, decode_bytes) 912 | embeddings.append(float_array) 913 | return np.array(embeddings) 914 | 915 | 916 | 917 | async def bedrock_embedding( 918 | texts: list[str], 919 | model: str = "amazon.titan-embed-text-v2:0", 920 | aws_access_key_id=None, 921 | aws_secret_access_key=None, 922 | aws_session_token=None, 923 | ) -> np.ndarray: 924 | os.environ["AWS_ACCESS_KEY_ID"] = os.environ.get( 925 | "AWS_ACCESS_KEY_ID", aws_access_key_id 926 | ) 927 | os.environ["AWS_SECRET_ACCESS_KEY"] = os.environ.get( 928 | "AWS_SECRET_ACCESS_KEY", aws_secret_access_key 929 | ) 930 | os.environ["AWS_SESSION_TOKEN"] = os.environ.get( 931 | "AWS_SESSION_TOKEN", aws_session_token 932 | ) 933 | 934 | session = aioboto3.Session() 935 | async with session.client("bedrock-runtime") as bedrock_async_client: 936 | if (model_provider := model.split(".")[0]) == "amazon": 937 | embed_texts = [] 938 | for text in texts: 939 | if "v2" in model: 940 | body = json.dumps( 941 | { 942 | "inputText": text, 943 | 944 | "embeddingTypes": ["float"], 945 | } 946 | ) 947 | elif "v1" in model: 948 | body = json.dumps({"inputText": text}) 949 | else: 950 | raise ValueError(f"Model {model} is not supported!") 951 | 952 | response = await bedrock_async_client.invoke_model( 953 | modelId=model, 954 | body=body, 955 | accept="application/json", 956 | contentType="application/json", 957 | ) 958 | 959 | response_body = await response.get("body").json() 960 | 961 | embed_texts.append(response_body["embedding"]) 962 | elif model_provider == "cohere": 963 | body = json.dumps( 964 | {"texts": texts, "input_type": "search_document", "truncate": "NONE"} 965 | ) 966 | 967 | response = await bedrock_async_client.invoke_model( 968 | model=model, 969 | body=body, 970 | accept="application/json", 971 | contentType="application/json", 972 | ) 973 | 974 | response_body = json.loads(response.get("body").read()) 975 | 976 | embed_texts = response_body["embeddings"] 977 | else: 978 | raise ValueError(f"Model provider '{model_provider}' is not supported!") 979 | 980 | return np.array(embed_texts) 981 | 982 | 983 | async def hf_embedding(texts: list[str], tokenizer, embed_model) -> np.ndarray: 984 | device = next(embed_model.parameters()).device 985 | input_ids = tokenizer( 986 | texts, return_tensors="pt", padding=True, truncation=True 987 | ).input_ids.to(device) 988 | with torch.no_grad(): 989 | outputs = embed_model(input_ids) 990 | embeddings = outputs.last_hidden_state.mean(dim=1) 991 | if embeddings.dtype == torch.bfloat16: 992 | return embeddings.detach().to(torch.float32).cpu().numpy() 993 | else: 994 | return embeddings.detach().cpu().numpy() 995 | 996 | 997 | async def ollama_embedding(texts: list[str], embed_model, **kwargs) -> np.ndarray: 998 | """ 999 | Deprecated in favor of `embed`. 1000 | """ 1001 | embed_text = [] 1002 | ollama_client = ollama.Client(**kwargs) 1003 | for text in texts: 1004 | data = ollama_client.embeddings(model=embed_model, prompt=text) 1005 | embed_text.append(data["embedding"]) 1006 | 1007 | return embed_text 1008 | 1009 | 1010 | async def ollama_embed(texts: list[str], embed_model, **kwargs) -> np.ndarray: 1011 | ollama_client = ollama.Client(**kwargs) 1012 | data = ollama_client.embed(model=embed_model, input=texts) 1013 | return data["embeddings"] 1014 | 1015 | 1016 | class Model(BaseModel): 1017 | """ 1018 | This is a Pydantic model class named 'Model' that is used to define a custom language model. 1019 | 1020 | Attributes: 1021 | gen_func (Callable[[Any], str]): A callable function that generates the response from the language model. 1022 | The function should take any argument and return a string. 1023 | kwargs (Dict[str, Any]): A dictionary that contains the arguments to pass to the callable function. 1024 | This could include parameters such as the model name, API key, etc. 1025 | 1026 | Example usage: 1027 | Model(gen_func=openai_complete_if_cache, kwargs={"model": "gpt-4", "api_key": os.environ["OPENAI_API_KEY_1"]}) 1028 | 1029 | In this example, 'openai_complete_if_cache' is the callable function that generates the response from the OpenAI model. 1030 | The 'kwargs' dictionary contains the model name and API key to be passed to the function. 1031 | """ 1032 | 1033 | gen_func: Callable[[Any], str] = Field( 1034 | ..., 1035 | description="A function that generates the response from the llm. The response must be a string", 1036 | ) 1037 | kwargs: Dict[str, Any] = Field( 1038 | ..., 1039 | description="The arguments to pass to the callable function. Eg. the api key, model name, etc", 1040 | ) 1041 | 1042 | class Config: 1043 | arbitrary_types_allowed = True 1044 | 1045 | 1046 | class MultiModel: 1047 | """ 1048 | Distributes the load across multiple language models. Useful for circumventing low rate limits with certain api providers especially if you are on the free tier. 1049 | Could also be used for spliting across diffrent models or providers. 1050 | 1051 | Attributes: 1052 | models (List[Model]): A list of language models to be used. 1053 | 1054 | Usage example: 1055 | ```python 1056 | models = [ 1057 | Model(gen_func=openai_complete_if_cache, kwargs={"model": "gpt-4", "api_key": os.environ["OPENAI_API_KEY_1"]}), 1058 | Model(gen_func=openai_complete_if_cache, kwargs={"model": "gpt-4", "api_key": os.environ["OPENAI_API_KEY_2"]}), 1059 | Model(gen_func=openai_complete_if_cache, kwargs={"model": "gpt-4", "api_key": os.environ["OPENAI_API_KEY_3"]}), 1060 | Model(gen_func=openai_complete_if_cache, kwargs={"model": "gpt-4", "api_key": os.environ["OPENAI_API_KEY_4"]}), 1061 | Model(gen_func=openai_complete_if_cache, kwargs={"model": "gpt-4", "api_key": os.environ["OPENAI_API_KEY_5"]}), 1062 | ] 1063 | multi_model = MultiModel(models) 1064 | rag = LightRAG( 1065 | llm_model_func=multi_model.llm_model_func 1066 | / ..other args 1067 | ) 1068 | ``` 1069 | """ 1070 | 1071 | def __init__(self, models: List[Model]): 1072 | self._models = models 1073 | self._current_model = 0 1074 | 1075 | def _next_model(self): 1076 | self._current_model = (self._current_model + 1) % len(self._models) 1077 | return self._models[self._current_model] 1078 | 1079 | async def llm_model_func( 1080 | self, prompt, system_prompt=None, history_messages=[], **kwargs 1081 | ) -> str: 1082 | kwargs.pop("model", None) 1083 | kwargs.pop("keyword_extraction", None) 1084 | kwargs.pop("mode", None) 1085 | next_model = self._next_model() 1086 | args = dict( 1087 | prompt=prompt, 1088 | system_prompt=system_prompt, 1089 | history_messages=history_messages, 1090 | **kwargs, 1091 | **next_model.kwargs, 1092 | ) 1093 | 1094 | return await next_model.gen_func(**args) 1095 | 1096 | 1097 | if __name__ == "__main__": 1098 | import asyncio 1099 | 1100 | async def main(): 1101 | result = await gpt_4o_mini_complete("How are you?") 1102 | print(result) 1103 | 1104 | asyncio.run(main()) 1105 | -------------------------------------------------------------------------------- /PathRAG/operate.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import json 3 | import re 4 | from tqdm.asyncio import tqdm as tqdm_async 5 | from typing import Union 6 | from collections import Counter, defaultdict 7 | import warnings 8 | import tiktoken 9 | import time 10 | import csv 11 | from .utils import ( 12 | logger, 13 | clean_str, 14 | compute_mdhash_id, 15 | decode_tokens_by_tiktoken, 16 | encode_string_by_tiktoken, 17 | is_float_regex, 18 | list_of_list_to_csv, 19 | pack_user_ass_to_openai_messages, 20 | split_string_by_multi_markers, 21 | truncate_list_by_token_size, 22 | process_combine_contexts, 23 | compute_args_hash, 24 | handle_cache, 25 | save_to_cache, 26 | CacheData, 27 | ) 28 | from .base import ( 29 | BaseGraphStorage, 30 | BaseKVStorage, 31 | BaseVectorStorage, 32 | TextChunkSchema, 33 | QueryParam, 34 | ) 35 | from .prompt import GRAPH_FIELD_SEP, PROMPTS 36 | 37 | 38 | def chunking_by_token_size( 39 | content: str, overlap_token_size=128, max_token_size=1024, tiktoken_model="gpt-4o" 40 | ): 41 | tokens = encode_string_by_tiktoken(content, model_name=tiktoken_model) 42 | results = [] 43 | for index, start in enumerate( 44 | range(0, len(tokens), max_token_size - overlap_token_size) 45 | ): 46 | chunk_content = decode_tokens_by_tiktoken( 47 | tokens[start : start + max_token_size], model_name=tiktoken_model 48 | ) 49 | results.append( 50 | { 51 | "tokens": min(max_token_size, len(tokens) - start), 52 | "content": chunk_content.strip(), 53 | "chunk_order_index": index, 54 | } 55 | ) 56 | return results 57 | 58 | 59 | async def _handle_entity_relation_summary( 60 | entity_or_relation_name: str, 61 | description: str, 62 | global_config: dict, 63 | ) -> str: 64 | use_llm_func: callable = global_config["llm_model_func"] 65 | llm_max_tokens = global_config["llm_model_max_token_size"] 66 | tiktoken_model_name = global_config["tiktoken_model_name"] 67 | summary_max_tokens = global_config["entity_summary_to_max_tokens"] 68 | language = global_config["addon_params"].get( 69 | "language", PROMPTS["DEFAULT_LANGUAGE"] 70 | ) 71 | 72 | tokens = encode_string_by_tiktoken(description, model_name=tiktoken_model_name) 73 | if len(tokens) < summary_max_tokens: 74 | return description 75 | prompt_template = PROMPTS["summarize_entity_descriptions"] 76 | use_description = decode_tokens_by_tiktoken( 77 | tokens[:llm_max_tokens], model_name=tiktoken_model_name 78 | ) 79 | context_base = dict( 80 | entity_name=entity_or_relation_name, 81 | description_list=use_description.split(GRAPH_FIELD_SEP), 82 | language=language, 83 | ) 84 | use_prompt = prompt_template.format(**context_base) 85 | logger.debug(f"Trigger summary: {entity_or_relation_name}") 86 | summary = await use_llm_func(use_prompt, max_tokens=summary_max_tokens) 87 | return summary 88 | 89 | 90 | async def _handle_single_entity_extraction( 91 | record_attributes: list[str], 92 | chunk_key: str, 93 | ): 94 | if len(record_attributes) < 4 or record_attributes[0] != '"entity"': 95 | return None 96 | 97 | entity_name = clean_str(record_attributes[1].upper()) 98 | if not entity_name.strip(): 99 | return None 100 | entity_type = clean_str(record_attributes[2].upper()) 101 | entity_description = clean_str(record_attributes[3]) 102 | entity_source_id = chunk_key 103 | return dict( 104 | entity_name=entity_name, 105 | entity_type=entity_type, 106 | description=entity_description, 107 | source_id=entity_source_id, 108 | ) 109 | 110 | 111 | async def _handle_single_relationship_extraction( 112 | record_attributes: list[str], 113 | chunk_key: str, 114 | ): 115 | if len(record_attributes) < 5 or record_attributes[0] != '"relationship"': 116 | return None 117 | 118 | source = clean_str(record_attributes[1].upper()) 119 | target = clean_str(record_attributes[2].upper()) 120 | edge_description = clean_str(record_attributes[3]) 121 | 122 | edge_keywords = clean_str(record_attributes[4]) 123 | edge_source_id = chunk_key 124 | weight = ( 125 | float(record_attributes[-1]) if is_float_regex(record_attributes[-1]) else 1.0 126 | ) 127 | return dict( 128 | src_id=source, 129 | tgt_id=target, 130 | weight=weight, 131 | description=edge_description, 132 | keywords=edge_keywords, 133 | source_id=edge_source_id, 134 | ) 135 | 136 | 137 | async def _merge_nodes_then_upsert( 138 | entity_name: str, 139 | nodes_data: list[dict], 140 | knowledge_graph_inst: BaseGraphStorage, 141 | global_config: dict, 142 | ): 143 | already_entity_types = [] 144 | already_source_ids = [] 145 | already_description = [] 146 | 147 | already_node = await knowledge_graph_inst.get_node(entity_name) 148 | if already_node is not None: 149 | already_entity_types.append(already_node["entity_type"]) 150 | already_source_ids.extend( 151 | split_string_by_multi_markers(already_node["source_id"], [GRAPH_FIELD_SEP]) 152 | ) 153 | already_description.append(already_node["description"]) 154 | 155 | entity_type = sorted( 156 | Counter( 157 | [dp["entity_type"] for dp in nodes_data] + already_entity_types 158 | ).items(), 159 | key=lambda x: x[1], 160 | reverse=True, 161 | )[0][0] 162 | description = GRAPH_FIELD_SEP.join( 163 | sorted(set([dp["description"] for dp in nodes_data] + already_description)) 164 | ) 165 | source_id = GRAPH_FIELD_SEP.join( 166 | set([dp["source_id"] for dp in nodes_data] + already_source_ids) 167 | ) 168 | description = await _handle_entity_relation_summary( 169 | entity_name, description, global_config 170 | ) 171 | node_data = dict( 172 | entity_type=entity_type, 173 | description=description, 174 | source_id=source_id, 175 | ) 176 | await knowledge_graph_inst.upsert_node( 177 | entity_name, 178 | node_data=node_data, 179 | ) 180 | node_data["entity_name"] = entity_name 181 | return node_data 182 | 183 | 184 | async def _merge_edges_then_upsert( 185 | src_id: str, 186 | tgt_id: str, 187 | edges_data: list[dict], 188 | knowledge_graph_inst: BaseGraphStorage, 189 | global_config: dict, 190 | ): 191 | already_weights = [] 192 | already_source_ids = [] 193 | already_description = [] 194 | already_keywords = [] 195 | 196 | if await knowledge_graph_inst.has_edge(src_id, tgt_id): 197 | already_edge = await knowledge_graph_inst.get_edge(src_id, tgt_id) 198 | already_weights.append(already_edge["weight"]) 199 | already_source_ids.extend( 200 | split_string_by_multi_markers(already_edge["source_id"], [GRAPH_FIELD_SEP]) 201 | ) 202 | already_description.append(already_edge["description"]) 203 | already_keywords.extend( 204 | split_string_by_multi_markers(already_edge["keywords"], [GRAPH_FIELD_SEP]) 205 | ) 206 | 207 | weight = sum([dp["weight"] for dp in edges_data] + already_weights) 208 | description = GRAPH_FIELD_SEP.join( 209 | sorted(set([dp["description"] for dp in edges_data] + already_description)) 210 | ) 211 | keywords = GRAPH_FIELD_SEP.join( 212 | sorted(set([dp["keywords"] for dp in edges_data] + already_keywords)) 213 | ) 214 | source_id = GRAPH_FIELD_SEP.join( 215 | set([dp["source_id"] for dp in edges_data] + already_source_ids) 216 | ) 217 | for need_insert_id in [src_id, tgt_id]: 218 | if not (await knowledge_graph_inst.has_node(need_insert_id)): 219 | await knowledge_graph_inst.upsert_node( 220 | need_insert_id, 221 | node_data={ 222 | "source_id": source_id, 223 | "description": description, 224 | "entity_type": '"UNKNOWN"', 225 | }, 226 | ) 227 | description = await _handle_entity_relation_summary( 228 | f"({src_id}, {tgt_id})", description, global_config 229 | ) 230 | await knowledge_graph_inst.upsert_edge( 231 | src_id, 232 | tgt_id, 233 | edge_data=dict( 234 | weight=weight, 235 | description=description, 236 | keywords=keywords, 237 | source_id=source_id, 238 | ), 239 | ) 240 | 241 | edge_data = dict( 242 | src_id=src_id, 243 | tgt_id=tgt_id, 244 | description=description, 245 | keywords=keywords, 246 | ) 247 | 248 | return edge_data 249 | 250 | 251 | async def extract_entities( 252 | chunks: dict[str, TextChunkSchema], 253 | knowledge_graph_inst: BaseGraphStorage, 254 | entity_vdb: BaseVectorStorage, 255 | relationships_vdb: BaseVectorStorage, 256 | global_config: dict, 257 | ) -> Union[BaseGraphStorage, None]: 258 | time.sleep(20) 259 | use_llm_func: callable = global_config["llm_model_func"] 260 | entity_extract_max_gleaning = global_config["entity_extract_max_gleaning"] 261 | 262 | ordered_chunks = list(chunks.items()) 263 | 264 | language = global_config["addon_params"].get( 265 | "language", PROMPTS["DEFAULT_LANGUAGE"] 266 | ) 267 | entity_types = global_config["addon_params"].get( 268 | "entity_types", PROMPTS["DEFAULT_ENTITY_TYPES"] 269 | ) 270 | example_number = global_config["addon_params"].get("example_number", None) 271 | if example_number and example_number < len(PROMPTS["entity_extraction_examples"]): 272 | examples = "\n".join( 273 | PROMPTS["entity_extraction_examples"][: int(example_number)] 274 | ) 275 | else: 276 | examples = "\n".join(PROMPTS["entity_extraction_examples"]) 277 | 278 | example_context_base = dict( 279 | tuple_delimiter=PROMPTS["DEFAULT_TUPLE_DELIMITER"], 280 | record_delimiter=PROMPTS["DEFAULT_RECORD_DELIMITER"], 281 | completion_delimiter=PROMPTS["DEFAULT_COMPLETION_DELIMITER"], 282 | entity_types=",".join(entity_types), 283 | language=language, 284 | ) 285 | 286 | examples = examples.format(**example_context_base) 287 | 288 | entity_extract_prompt = PROMPTS["entity_extraction"] 289 | context_base = dict( 290 | tuple_delimiter=PROMPTS["DEFAULT_TUPLE_DELIMITER"], 291 | record_delimiter=PROMPTS["DEFAULT_RECORD_DELIMITER"], 292 | completion_delimiter=PROMPTS["DEFAULT_COMPLETION_DELIMITER"], 293 | entity_types=",".join(entity_types), 294 | examples=examples, 295 | language=language, 296 | ) 297 | 298 | continue_prompt = PROMPTS["entiti_continue_extraction"] 299 | if_loop_prompt = PROMPTS["entiti_if_loop_extraction"] 300 | 301 | already_processed = 0 302 | already_entities = 0 303 | already_relations = 0 304 | 305 | async def _process_single_content(chunk_key_dp: tuple[str, TextChunkSchema]): 306 | nonlocal already_processed, already_entities, already_relations 307 | chunk_key = chunk_key_dp[0] 308 | chunk_dp = chunk_key_dp[1] 309 | content = chunk_dp["content"] 310 | hint_prompt = entity_extract_prompt.format( 311 | **context_base, input_text="{input_text}" 312 | ).format(**context_base, input_text=content) 313 | 314 | final_result = await use_llm_func(hint_prompt) 315 | history = pack_user_ass_to_openai_messages(hint_prompt, final_result) 316 | for now_glean_index in range(entity_extract_max_gleaning): 317 | glean_result = await use_llm_func(continue_prompt, history_messages=history) 318 | 319 | history += pack_user_ass_to_openai_messages(continue_prompt, glean_result) 320 | final_result += glean_result 321 | if now_glean_index == entity_extract_max_gleaning - 1: 322 | break 323 | 324 | if_loop_result: str = await use_llm_func( 325 | if_loop_prompt, history_messages=history 326 | ) 327 | if_loop_result = if_loop_result.strip().strip('"').strip("'").lower() 328 | if if_loop_result != "yes": 329 | break 330 | 331 | records = split_string_by_multi_markers( 332 | final_result, 333 | [context_base["record_delimiter"], context_base["completion_delimiter"]], 334 | ) 335 | 336 | maybe_nodes = defaultdict(list) 337 | maybe_edges = defaultdict(list) 338 | for record in records: 339 | record = re.search(r"\((.*)\)", record) 340 | if record is None: 341 | continue 342 | record = record.group(1) 343 | record_attributes = split_string_by_multi_markers( 344 | record, [context_base["tuple_delimiter"]] 345 | ) 346 | if_entities = await _handle_single_entity_extraction( 347 | record_attributes, chunk_key 348 | ) 349 | if if_entities is not None: 350 | maybe_nodes[if_entities["entity_name"]].append(if_entities) 351 | continue 352 | 353 | if_relation = await _handle_single_relationship_extraction( 354 | record_attributes, chunk_key 355 | ) 356 | if if_relation is not None: 357 | maybe_edges[(if_relation["src_id"], if_relation["tgt_id"])].append( 358 | if_relation 359 | ) 360 | already_processed += 1 361 | already_entities += len(maybe_nodes) 362 | already_relations += len(maybe_edges) 363 | now_ticks = PROMPTS["process_tickers"][ 364 | already_processed % len(PROMPTS["process_tickers"]) 365 | ] 366 | print( 367 | f"{now_ticks} Processed {already_processed} chunks, {already_entities} entities(duplicated), {already_relations} relations(duplicated)\r", 368 | end="", 369 | flush=True, 370 | ) 371 | return dict(maybe_nodes), dict(maybe_edges) 372 | 373 | results = [] 374 | for result in tqdm_async( 375 | asyncio.as_completed([_process_single_content(c) for c in ordered_chunks]), 376 | total=len(ordered_chunks), 377 | desc="Extracting entities from chunks", 378 | unit="chunk", 379 | ): 380 | results.append(await result) 381 | 382 | maybe_nodes = defaultdict(list) 383 | maybe_edges = defaultdict(list) 384 | for m_nodes, m_edges in results: 385 | for k, v in m_nodes.items(): 386 | maybe_nodes[k].extend(v) 387 | for k, v in m_edges.items(): 388 | maybe_edges[k].extend(v) 389 | logger.info("Inserting entities into storage...") 390 | all_entities_data = [] 391 | for result in tqdm_async( 392 | asyncio.as_completed( 393 | [ 394 | _merge_nodes_then_upsert(k, v, knowledge_graph_inst, global_config) 395 | for k, v in maybe_nodes.items() 396 | ] 397 | ), 398 | total=len(maybe_nodes), 399 | desc="Inserting entities", 400 | unit="entity", 401 | ): 402 | all_entities_data.append(await result) 403 | 404 | logger.info("Inserting relationships into storage...") 405 | all_relationships_data = [] 406 | for result in tqdm_async( 407 | asyncio.as_completed( 408 | [ 409 | _merge_edges_then_upsert( 410 | k[0], k[1], v, knowledge_graph_inst, global_config 411 | ) 412 | for k, v in maybe_edges.items() 413 | ] 414 | ), 415 | total=len(maybe_edges), 416 | desc="Inserting relationships", 417 | unit="relationship", 418 | ): 419 | all_relationships_data.append(await result) 420 | 421 | if not len(all_entities_data) and not len(all_relationships_data): 422 | logger.warning( 423 | "Didn't extract any entities and relationships, maybe your LLM is not working" 424 | ) 425 | return None 426 | 427 | if not len(all_entities_data): 428 | logger.warning("Didn't extract any entities") 429 | if not len(all_relationships_data): 430 | logger.warning("Didn't extract any relationships") 431 | 432 | if entity_vdb is not None: 433 | data_for_vdb = { 434 | compute_mdhash_id(dp["entity_name"], prefix="ent-"): { 435 | "content": dp["entity_name"] + dp["description"], 436 | "entity_name": dp["entity_name"], 437 | } 438 | for dp in all_entities_data 439 | } 440 | await entity_vdb.upsert(data_for_vdb) 441 | 442 | if relationships_vdb is not None: 443 | data_for_vdb = { 444 | compute_mdhash_id(dp["src_id"] + dp["tgt_id"], prefix="rel-"): { 445 | "src_id": dp["src_id"], 446 | "tgt_id": dp["tgt_id"], 447 | "content": dp["keywords"] 448 | + dp["src_id"] 449 | + dp["tgt_id"] 450 | + dp["description"], 451 | } 452 | for dp in all_relationships_data 453 | } 454 | await relationships_vdb.upsert(data_for_vdb) 455 | 456 | return knowledge_graph_inst 457 | 458 | 459 | 460 | async def kg_query( 461 | query, 462 | knowledge_graph_inst: BaseGraphStorage, 463 | entities_vdb: BaseVectorStorage, 464 | relationships_vdb: BaseVectorStorage, 465 | text_chunks_db: BaseKVStorage[TextChunkSchema], 466 | query_param: QueryParam, 467 | global_config: dict, 468 | hashing_kv: BaseKVStorage = None, 469 | ) -> str: 470 | 471 | use_model_func = global_config["llm_model_func"] 472 | args_hash = compute_args_hash(query_param.mode, query) 473 | cached_response, quantized, min_val, max_val = await handle_cache( 474 | hashing_kv, args_hash, query, query_param.mode 475 | ) 476 | if cached_response is not None: 477 | return cached_response 478 | 479 | example_number = global_config["addon_params"].get("example_number", None) 480 | if example_number and example_number < len(PROMPTS["keywords_extraction_examples"]): 481 | examples = "\n".join( 482 | PROMPTS["keywords_extraction_examples"][: int(example_number)] 483 | ) 484 | else: 485 | examples = "\n".join(PROMPTS["keywords_extraction_examples"]) 486 | language = global_config["addon_params"].get( 487 | "language", PROMPTS["DEFAULT_LANGUAGE"] 488 | ) 489 | 490 | if query_param.mode not in ["hybrid"]: 491 | logger.error(f"Unknown mode {query_param.mode} in kg_query") 492 | return PROMPTS["fail_response"] 493 | 494 | 495 | kw_prompt_temp = PROMPTS["keywords_extraction"] 496 | kw_prompt = kw_prompt_temp.format(query=query, examples=examples, language=language) 497 | result = await use_model_func(kw_prompt, keyword_extraction=True) 498 | logger.info("kw_prompt result:") 499 | print(result) 500 | try: 501 | 502 | match = re.search(r"\{.*\}", result, re.DOTALL) 503 | if match: 504 | result = match.group(0) 505 | keywords_data = json.loads(result) 506 | 507 | hl_keywords = keywords_data.get("high_level_keywords", []) 508 | ll_keywords = keywords_data.get("low_level_keywords", []) 509 | else: 510 | logger.error("No JSON-like structure found in the result.") 511 | return PROMPTS["fail_response"] 512 | 513 | 514 | except json.JSONDecodeError as e: 515 | print(f"JSON parsing error: {e} {result}") 516 | return PROMPTS["fail_response"] 517 | 518 | 519 | if hl_keywords == [] and ll_keywords == []: 520 | logger.warning("low_level_keywords and high_level_keywords is empty") 521 | return PROMPTS["fail_response"] 522 | if ll_keywords == [] and query_param.mode in ["hybrid"]: 523 | logger.warning("low_level_keywords is empty") 524 | return PROMPTS["fail_response"] 525 | else: 526 | ll_keywords = ", ".join(ll_keywords) 527 | if hl_keywords == [] and query_param.mode in ["hybrid"]: 528 | logger.warning("high_level_keywords is empty") 529 | return PROMPTS["fail_response"] 530 | else: 531 | hl_keywords = ", ".join(hl_keywords) 532 | 533 | 534 | keywords = [ll_keywords, hl_keywords] 535 | context= await _build_query_context( 536 | keywords, 537 | knowledge_graph_inst, 538 | entities_vdb, 539 | relationships_vdb, 540 | text_chunks_db, 541 | query_param, 542 | ) 543 | 544 | 545 | 546 | if query_param.only_need_context: 547 | return context 548 | if context is None: 549 | return PROMPTS["fail_response"] 550 | sys_prompt_temp = PROMPTS["rag_response"] 551 | sys_prompt = sys_prompt_temp.format( 552 | context_data=context, response_type=query_param.response_type 553 | ) 554 | if query_param.only_need_prompt: 555 | return sys_prompt 556 | response = await use_model_func( 557 | query, 558 | system_prompt=sys_prompt, 559 | stream=query_param.stream, 560 | ) 561 | if isinstance(response, str) and len(response) > len(sys_prompt): 562 | response = ( 563 | response.replace(sys_prompt, "") 564 | .replace("user", "") 565 | .replace("model", "") 566 | .replace(query, "") 567 | .replace("", "") 568 | .replace("", "") 569 | .strip() 570 | ) 571 | 572 | 573 | await save_to_cache( 574 | hashing_kv, 575 | CacheData( 576 | args_hash=args_hash, 577 | content=response, 578 | prompt=query, 579 | quantized=quantized, 580 | min_val=min_val, 581 | max_val=max_val, 582 | mode=query_param.mode, 583 | ), 584 | ) 585 | return response 586 | 587 | 588 | async def _build_query_context( 589 | query: list, 590 | knowledge_graph_inst: BaseGraphStorage, 591 | entities_vdb: BaseVectorStorage, 592 | relationships_vdb: BaseVectorStorage, 593 | text_chunks_db: BaseKVStorage[TextChunkSchema], 594 | query_param: QueryParam, 595 | ): 596 | ll_entities_context, ll_relations_context, ll_text_units_context = "", "", "" 597 | hl_entities_context, hl_relations_context, hl_text_units_context = "", "", "" 598 | 599 | ll_kewwords, hl_keywrds = query[0], query[1] 600 | if query_param.mode in ["local", "hybrid"]: 601 | if ll_kewwords == "": 602 | ll_entities_context, ll_relations_context, ll_text_units_context = ( 603 | "", 604 | "", 605 | "", 606 | ) 607 | warnings.warn( 608 | "Low Level context is None. Return empty Low entity/relationship/source" 609 | ) 610 | query_param.mode = "global" 611 | else: 612 | ( 613 | ll_entities_context, 614 | ll_relations_context, 615 | ll_text_units_context, 616 | ) = await _get_node_data( 617 | ll_kewwords, 618 | knowledge_graph_inst, 619 | entities_vdb, 620 | text_chunks_db, 621 | query_param, 622 | ) 623 | if query_param.mode in ["hybrid"]: 624 | if hl_keywrds == "": 625 | hl_entities_context, hl_relations_context, hl_text_units_context = ( 626 | "", 627 | "", 628 | "", 629 | ) 630 | warnings.warn( 631 | "High Level context is None. Return empty High entity/relationship/source" 632 | ) 633 | query_param.mode = "local" 634 | else: 635 | ( 636 | hl_entities_context, 637 | hl_relations_context, 638 | hl_text_units_context, 639 | ) = await _get_edge_data( 640 | hl_keywrds, 641 | knowledge_graph_inst, 642 | relationships_vdb, 643 | text_chunks_db, 644 | query_param, 645 | ) 646 | if ( 647 | hl_entities_context == "" 648 | and hl_relations_context == "" 649 | and hl_text_units_context == "" 650 | ): 651 | logger.warn("No high level context found. Switching to local mode.") 652 | query_param.mode = "local" 653 | if query_param.mode == "hybrid": 654 | entities_context, relations_context, text_units_context = combine_contexts( 655 | [hl_entities_context, hl_relations_context], 656 | [ll_entities_context, ll_relations_context], 657 | [hl_text_units_context, ll_text_units_context], 658 | ) 659 | 660 | 661 | return f""" 662 | -----global-information----- 663 | -----high-level entity information----- 664 | ```csv 665 | {hl_entities_context} 666 | ``` 667 | -----high-level relationship information----- 668 | ```csv 669 | {hl_relations_context} 670 | ``` 671 | -----Sources----- 672 | ```csv 673 | {text_units_context} 674 | ``` 675 | -----local-information----- 676 | -----low-level entity information----- 677 | ```csv 678 | {ll_entities_context} 679 | ``` 680 | -----low-level relationship information----- 681 | ```csv 682 | {ll_relations_context} 683 | ``` 684 | """ 685 | 686 | async def _get_node_data( 687 | query, 688 | knowledge_graph_inst: BaseGraphStorage, 689 | entities_vdb: BaseVectorStorage, 690 | text_chunks_db: BaseKVStorage[TextChunkSchema], 691 | query_param: QueryParam, 692 | ): 693 | 694 | results = await entities_vdb.query(query, top_k=query_param.top_k) 695 | if not len(results): 696 | return "", "", "" 697 | 698 | node_datas = await asyncio.gather( 699 | *[knowledge_graph_inst.get_node(r["entity_name"]) for r in results] 700 | ) 701 | if not all([n is not None for n in node_datas]): 702 | logger.warning("Some nodes are missing, maybe the storage is damaged") 703 | 704 | 705 | node_degrees = await asyncio.gather( 706 | *[knowledge_graph_inst.node_degree(r["entity_name"]) for r in results] 707 | ) 708 | node_datas = [ 709 | {**n, "entity_name": k["entity_name"], "rank": d} 710 | for k, n, d in zip(results, node_datas, node_degrees) 711 | if n is not None 712 | ] 713 | use_text_units = await _find_most_related_text_unit_from_entities( 714 | node_datas, query_param, text_chunks_db, knowledge_graph_inst 715 | ) 716 | 717 | 718 | use_relations= await _find_most_related_edges_from_entities3( 719 | node_datas, query_param, knowledge_graph_inst 720 | ) 721 | 722 | logger.info( 723 | f"Local query uses {len(node_datas)} entites, {len(use_relations)} relations, {len(use_text_units)} text units" 724 | ) 725 | 726 | 727 | entites_section_list = [["id", "entity", "type", "description", "rank"]] 728 | for i, n in enumerate(node_datas): 729 | entites_section_list.append( 730 | [ 731 | i, 732 | n["entity_name"], 733 | n.get("entity_type", "UNKNOWN"), 734 | n.get("description", "UNKNOWN"), 735 | n["rank"], 736 | ] 737 | ) 738 | entities_context = list_of_list_to_csv(entites_section_list) 739 | 740 | relations_section_list=[["id","context"]] 741 | for i,e in enumerate(use_relations): 742 | relations_section_list.append([i,e]) 743 | relations_context=list_of_list_to_csv(relations_section_list) 744 | 745 | text_units_section_list = [["id", "content"]] 746 | for i, t in enumerate(use_text_units): 747 | text_units_section_list.append([i, t["content"]]) 748 | text_units_context = list_of_list_to_csv(text_units_section_list) 749 | 750 | return entities_context,relations_context,text_units_context 751 | 752 | 753 | async def _find_most_related_text_unit_from_entities( 754 | node_datas: list[dict], 755 | query_param: QueryParam, 756 | text_chunks_db: BaseKVStorage[TextChunkSchema], 757 | knowledge_graph_inst: BaseGraphStorage, 758 | ): 759 | text_units = [ 760 | split_string_by_multi_markers(dp["source_id"], [GRAPH_FIELD_SEP]) 761 | for dp in node_datas 762 | ] 763 | edges = await asyncio.gather( 764 | *[knowledge_graph_inst.get_node_edges(dp["entity_name"]) for dp in node_datas] 765 | ) 766 | all_one_hop_nodes = set() 767 | for this_edges in edges: 768 | if not this_edges: 769 | continue 770 | all_one_hop_nodes.update([e[1] for e in this_edges]) 771 | 772 | all_one_hop_nodes = list(all_one_hop_nodes) 773 | all_one_hop_nodes_data = await asyncio.gather( 774 | *[knowledge_graph_inst.get_node(e) for e in all_one_hop_nodes] 775 | ) 776 | 777 | 778 | all_one_hop_text_units_lookup = { 779 | k: set(split_string_by_multi_markers(v["source_id"], [GRAPH_FIELD_SEP])) 780 | for k, v in zip(all_one_hop_nodes, all_one_hop_nodes_data) 781 | if v is not None and "source_id" in v 782 | } 783 | 784 | all_text_units_lookup = {} 785 | for index, (this_text_units, this_edges) in enumerate(zip(text_units, edges)): 786 | for c_id in this_text_units: 787 | if c_id not in all_text_units_lookup: 788 | all_text_units_lookup[c_id] = { 789 | "data": await text_chunks_db.get_by_id(c_id), 790 | "order": index, 791 | "relation_counts": 0, 792 | } 793 | 794 | if this_edges: 795 | for e in this_edges: 796 | if ( 797 | e[1] in all_one_hop_text_units_lookup 798 | and c_id in all_one_hop_text_units_lookup[e[1]] 799 | ): 800 | all_text_units_lookup[c_id]["relation_counts"] += 1 801 | 802 | 803 | all_text_units = [ 804 | {"id": k, **v} 805 | for k, v in all_text_units_lookup.items() 806 | if v is not None and v.get("data") is not None and "content" in v["data"] 807 | ] 808 | 809 | if not all_text_units: 810 | logger.warning("No valid text units found") 811 | return [] 812 | 813 | all_text_units = sorted( 814 | all_text_units, key=lambda x: (x["order"], -x["relation_counts"]) 815 | ) 816 | 817 | all_text_units = truncate_list_by_token_size( 818 | all_text_units, 819 | key=lambda x: x["data"]["content"], 820 | max_token_size=query_param.max_token_for_text_unit, 821 | ) 822 | 823 | all_text_units = [t["data"] for t in all_text_units] 824 | return all_text_units 825 | 826 | async def _get_edge_data( 827 | keywords, 828 | knowledge_graph_inst: BaseGraphStorage, 829 | relationships_vdb: BaseVectorStorage, 830 | text_chunks_db: BaseKVStorage[TextChunkSchema], 831 | query_param: QueryParam, 832 | ): 833 | results = await relationships_vdb.query(keywords, top_k=query_param.top_k) 834 | 835 | if not len(results): 836 | return "", "", "" 837 | 838 | edge_datas = await asyncio.gather( 839 | *[knowledge_graph_inst.get_edge(r["src_id"], r["tgt_id"]) for r in results] 840 | ) 841 | 842 | if not all([n is not None for n in edge_datas]): 843 | logger.warning("Some edges are missing, maybe the storage is damaged") 844 | edge_degree = await asyncio.gather( 845 | *[knowledge_graph_inst.edge_degree(r["src_id"], r["tgt_id"]) for r in results] 846 | ) 847 | edge_datas = [ 848 | {"src_id": k["src_id"], "tgt_id": k["tgt_id"], "rank": d, **v} 849 | for k, v, d in zip(results, edge_datas, edge_degree) 850 | if v is not None 851 | ] 852 | edge_datas = sorted( 853 | edge_datas, key=lambda x: (x["rank"], x["weight"]), reverse=True 854 | ) 855 | edge_datas = truncate_list_by_token_size( 856 | edge_datas, 857 | key=lambda x: x["description"], 858 | max_token_size=query_param.max_token_for_global_context, 859 | ) 860 | 861 | use_entities = await _find_most_related_entities_from_relationships( 862 | edge_datas, query_param, knowledge_graph_inst 863 | ) 864 | use_text_units = await _find_related_text_unit_from_relationships( 865 | edge_datas, query_param, text_chunks_db, knowledge_graph_inst 866 | ) 867 | logger.info( 868 | f"Global query uses {len(use_entities)} entites, {len(edge_datas)} relations, {len(use_text_units)} text units" 869 | ) 870 | 871 | relations_section_list = [ 872 | ["id", "source", "target", "description", "keywords", "weight", "rank"] 873 | ] 874 | for i, e in enumerate(edge_datas): 875 | relations_section_list.append( 876 | [ 877 | i, 878 | e["src_id"], 879 | e["tgt_id"], 880 | e["description"], 881 | e["keywords"], 882 | e["weight"], 883 | e["rank"], 884 | ] 885 | ) 886 | relations_context = list_of_list_to_csv(relations_section_list) 887 | 888 | entites_section_list = [["id", "entity", "type", "description", "rank"]] 889 | for i, n in enumerate(use_entities): 890 | entites_section_list.append( 891 | [ 892 | i, 893 | n["entity_name"], 894 | n.get("entity_type", "UNKNOWN"), 895 | n.get("description", "UNKNOWN"), 896 | n["rank"], 897 | ] 898 | ) 899 | entities_context = list_of_list_to_csv(entites_section_list) 900 | 901 | text_units_section_list = [["id", "content"]] 902 | for i, t in enumerate(use_text_units): 903 | text_units_section_list.append([i, t["content"]]) 904 | text_units_context = list_of_list_to_csv(text_units_section_list) 905 | return entities_context, relations_context, text_units_context 906 | 907 | 908 | async def _find_most_related_entities_from_relationships( 909 | edge_datas: list[dict], 910 | query_param: QueryParam, 911 | knowledge_graph_inst: BaseGraphStorage, 912 | ): 913 | entity_names = [] 914 | seen = set() 915 | 916 | for e in edge_datas: 917 | if e["src_id"] not in seen: 918 | entity_names.append(e["src_id"]) 919 | seen.add(e["src_id"]) 920 | if e["tgt_id"] not in seen: 921 | entity_names.append(e["tgt_id"]) 922 | seen.add(e["tgt_id"]) 923 | 924 | node_datas = await asyncio.gather( 925 | *[knowledge_graph_inst.get_node(entity_name) for entity_name in entity_names] 926 | ) 927 | 928 | node_degrees = await asyncio.gather( 929 | *[knowledge_graph_inst.node_degree(entity_name) for entity_name in entity_names] 930 | ) 931 | node_datas = [ 932 | {**n, "entity_name": k, "rank": d} 933 | for k, n, d in zip(entity_names, node_datas, node_degrees) 934 | ] 935 | 936 | node_datas = truncate_list_by_token_size( 937 | node_datas, 938 | key=lambda x: x["description"], 939 | max_token_size=query_param.max_token_for_local_context, 940 | ) 941 | 942 | return node_datas 943 | 944 | 945 | async def _find_related_text_unit_from_relationships( 946 | edge_datas: list[dict], 947 | query_param: QueryParam, 948 | text_chunks_db: BaseKVStorage[TextChunkSchema], 949 | knowledge_graph_inst: BaseGraphStorage, 950 | ): 951 | text_units = [ 952 | split_string_by_multi_markers(dp["source_id"], [GRAPH_FIELD_SEP]) 953 | for dp in edge_datas 954 | ] 955 | all_text_units_lookup = {} 956 | 957 | for index, unit_list in enumerate(text_units): 958 | for c_id in unit_list: 959 | if c_id not in all_text_units_lookup: 960 | chunk_data = await text_chunks_db.get_by_id(c_id) 961 | 962 | if chunk_data is not None and "content" in chunk_data: 963 | all_text_units_lookup[c_id] = { 964 | "data": chunk_data, 965 | "order": index, 966 | } 967 | 968 | if not all_text_units_lookup: 969 | logger.warning("No valid text chunks found") 970 | return [] 971 | 972 | all_text_units = [{"id": k, **v} for k, v in all_text_units_lookup.items()] 973 | all_text_units = sorted(all_text_units, key=lambda x: x["order"]) 974 | 975 | 976 | valid_text_units = [ 977 | t for t in all_text_units if t["data"] is not None and "content" in t["data"] 978 | ] 979 | 980 | if not valid_text_units: 981 | logger.warning("No valid text chunks after filtering") 982 | return [] 983 | 984 | truncated_text_units = truncate_list_by_token_size( 985 | valid_text_units, 986 | key=lambda x: x["data"]["content"], 987 | max_token_size=query_param.max_token_for_text_unit, 988 | ) 989 | 990 | all_text_units: list[TextChunkSchema] = [t["data"] for t in truncated_text_units] 991 | 992 | return all_text_units 993 | 994 | 995 | def combine_contexts(entities, relationships, sources): 996 | 997 | hl_entities, ll_entities = entities[0], entities[1] 998 | hl_relationships, ll_relationships = relationships[0], relationships[1] 999 | hl_sources, ll_sources = sources[0], sources[1] 1000 | 1001 | combined_entities = process_combine_contexts(hl_entities, ll_entities) 1002 | 1003 | combined_relationships = process_combine_contexts( 1004 | hl_relationships, ll_relationships 1005 | ) 1006 | 1007 | combined_sources = process_combine_contexts(hl_sources, ll_sources) 1008 | 1009 | return combined_entities, combined_relationships, combined_sources 1010 | 1011 | 1012 | import networkx as nx 1013 | from collections import defaultdict 1014 | async def find_paths_and_edges_with_stats(graph, target_nodes): 1015 | 1016 | result = defaultdict(lambda: {"paths": [], "edges": set()}) 1017 | path_stats = {"1-hop": 0, "2-hop": 0, "3-hop": 0} 1018 | one_hop_paths = [] 1019 | two_hop_paths = [] 1020 | three_hop_paths = [] 1021 | 1022 | async def dfs(current, target, path, depth): 1023 | 1024 | if depth > 3: 1025 | return 1026 | if current == target: 1027 | result[(path[0], target)]["paths"].append(list(path)) 1028 | for u, v in zip(path[:-1], path[1:]): 1029 | result[(path[0], target)]["edges"].add(tuple(sorted((u, v)))) 1030 | if depth == 1: 1031 | path_stats["1-hop"] += 1 1032 | one_hop_paths.append(list(path)) 1033 | elif depth == 2: 1034 | path_stats["2-hop"] += 1 1035 | two_hop_paths.append(list(path)) 1036 | elif depth == 3: 1037 | path_stats["3-hop"] += 1 1038 | three_hop_paths.append(list(path)) 1039 | return 1040 | neighbors = graph.neighbors(current) 1041 | for neighbor in neighbors: 1042 | if neighbor not in path: 1043 | await dfs(neighbor, target, path + [neighbor], depth + 1) 1044 | 1045 | for node1 in target_nodes: 1046 | for node2 in target_nodes: 1047 | if node1 != node2: 1048 | await dfs(node1, node2, [node1], 0) 1049 | 1050 | for key in result: 1051 | result[key]["edges"] = list(result[key]["edges"]) 1052 | 1053 | return dict(result), path_stats , one_hop_paths, two_hop_paths, three_hop_paths 1054 | def bfs_weighted_paths(G, path, source, target, threshold, alpha): 1055 | results = [] 1056 | edge_weights = defaultdict(float) 1057 | node = source 1058 | follow_dict = {} 1059 | 1060 | for p in path: 1061 | for i in range(len(p) - 1): 1062 | current = p[i] 1063 | next_num = p[i + 1] 1064 | 1065 | if current in follow_dict: 1066 | follow_dict[current].add(next_num) 1067 | else: 1068 | follow_dict[current] = {next_num} 1069 | 1070 | for neighbor in follow_dict[node]: 1071 | edge_weights[(node, neighbor)] += 1/len(follow_dict[node]) 1072 | 1073 | if neighbor == target: 1074 | results.append(([node, neighbor])) 1075 | continue 1076 | 1077 | if edge_weights[(node, neighbor)] > threshold: 1078 | 1079 | for second_neighbor in follow_dict[neighbor]: 1080 | weight = edge_weights[(node, neighbor)] * alpha / len(follow_dict[neighbor]) 1081 | edge_weights[(neighbor, second_neighbor)] += weight 1082 | 1083 | if second_neighbor == target: 1084 | results.append(([node, neighbor, second_neighbor])) 1085 | continue 1086 | 1087 | if edge_weights[(neighbor, second_neighbor)] > threshold: 1088 | 1089 | for third_neighbor in follow_dict[second_neighbor]: 1090 | weight = edge_weights[(neighbor, second_neighbor)] * alpha / len(follow_dict[second_neighbor]) 1091 | edge_weights[(second_neighbor, third_neighbor)] += weight 1092 | 1093 | if third_neighbor == target : 1094 | results.append(([node, neighbor, second_neighbor, third_neighbor])) 1095 | continue 1096 | path_weights = [] 1097 | for p in path: 1098 | path_weight = 0 1099 | for i in range(len(p) - 1): 1100 | edge = (p[i], p[i + 1]) 1101 | path_weight += edge_weights.get(edge, 0) 1102 | path_weights.append(path_weight/(len(p)-1)) 1103 | 1104 | combined = [(p, w) for p, w in zip(path, path_weights)] 1105 | 1106 | return combined 1107 | async def _find_most_related_edges_from_entities3( 1108 | node_datas: list[dict], 1109 | query_param: QueryParam, 1110 | knowledge_graph_inst: BaseGraphStorage, 1111 | ): 1112 | 1113 | G = nx.Graph() 1114 | edges = await knowledge_graph_inst.edges() 1115 | nodes = await knowledge_graph_inst.nodes() 1116 | 1117 | for u, v in edges: 1118 | G.add_edge(u, v) 1119 | G.add_nodes_from(nodes) 1120 | source_nodes = [dp["entity_name"] for dp in node_datas] 1121 | result, path_stats, one_hop_paths, two_hop_paths, three_hop_paths = await find_paths_and_edges_with_stats(G, source_nodes) 1122 | 1123 | 1124 | threshold = 0.3 1125 | alpha = 0.8 1126 | all_results = [] 1127 | 1128 | for node1 in source_nodes: 1129 | for node2 in source_nodes: 1130 | if node1 != node2: 1131 | if (node1, node2) in result: 1132 | sub_G = nx.Graph() 1133 | paths = result[(node1,node2)]['paths'] 1134 | edges = result[(node1,node2)]['edges'] 1135 | sub_G.add_edges_from(edges) 1136 | results = bfs_weighted_paths(G, paths, node1, node2, threshold, alpha) 1137 | all_results+= results 1138 | all_results = sorted(all_results, key=lambda x: x[1], reverse=True) 1139 | seen = set() 1140 | result_edge = [] 1141 | for edge, weight in all_results: 1142 | sorted_edge = tuple(sorted(edge)) 1143 | if sorted_edge not in seen: 1144 | seen.add(sorted_edge) 1145 | result_edge.append((edge, weight)) 1146 | 1147 | 1148 | length_1 = int(len(one_hop_paths)/2) 1149 | length_2 = int(len(two_hop_paths)/2) 1150 | length_3 = int(len(three_hop_paths)/2) 1151 | results = [] 1152 | if one_hop_paths!=[]: 1153 | results = one_hop_paths[0:length_1] 1154 | if two_hop_paths!=[]: 1155 | results = results + two_hop_paths[0:length_2] 1156 | if three_hop_paths!=[]: 1157 | results =results + three_hop_paths[0:length_3] 1158 | 1159 | length = len(results) 1160 | total_edges = 15 1161 | if length < total_edges: 1162 | total_edges = length 1163 | sort_result = [] 1164 | if result_edge: 1165 | if len(result_edge)>total_edges: 1166 | sort_result = result_edge[0:total_edges] 1167 | else : 1168 | sort_result = result_edge 1169 | final_result = [] 1170 | for edge, weight in sort_result: 1171 | final_result.append(edge) 1172 | 1173 | relationship = [] 1174 | 1175 | for path in final_result: 1176 | if len(path) == 4: 1177 | s_name,b1_name,b2_name,t_name = path[0],path[1],path[2],path[3] 1178 | edge0 = await knowledge_graph_inst.get_edge(path[0], path[1]) or await knowledge_graph_inst.get_edge(path[1], path[0]) 1179 | edge1 = await knowledge_graph_inst.get_edge(path[1],path[2]) or await knowledge_graph_inst.get_edge(path[2], path[1]) 1180 | edge2 = await knowledge_graph_inst.get_edge(path[2],path[3]) or await knowledge_graph_inst.get_edge(path[3], path[2]) 1181 | if edge0==None or edge1==None or edge2==None: 1182 | print(path,"边丢失") 1183 | if edge0==None: 1184 | print("edge0丢失") 1185 | if edge1==None: 1186 | print("edge1丢失") 1187 | if edge2==None: 1188 | print("edge2丢失") 1189 | continue 1190 | e1 = "through edge ("+edge0["keywords"]+") to connect to "+s_name+" and "+b1_name+"." 1191 | e2 = "through edge ("+edge1["keywords"]+") to connect to "+b1_name+" and "+b2_name+"." 1192 | e3 = "through edge ("+edge2["keywords"]+") to connect to "+b2_name+" and "+t_name+"." 1193 | s = await knowledge_graph_inst.get_node(s_name) 1194 | s = "The entity "+s_name+" is a "+s["entity_type"]+" with the description("+s["description"]+")" 1195 | b1 = await knowledge_graph_inst.get_node(b1_name) 1196 | b1 = "The entity "+b1_name+" is a "+b1["entity_type"]+" with the description("+b1["description"]+")" 1197 | b2 = await knowledge_graph_inst.get_node(b2_name) 1198 | b2 = "The entity "+b2_name+" is a "+b2["entity_type"]+" with the description("+b2["description"]+")" 1199 | t = await knowledge_graph_inst.get_node(t_name) 1200 | t = "The entity "+t_name+" is a "+t["entity_type"]+" with the description("+t["description"]+")" 1201 | relationship.append([s+e1+b1+"and"+b1+e2+b2+"and"+b2+e3+t]) 1202 | elif len(path) == 3: 1203 | s_name,b_name,t_name = path[0],path[1],path[2] 1204 | edge0 = await knowledge_graph_inst.get_edge(path[0], path[1]) or await knowledge_graph_inst.get_edge(path[1], path[0]) 1205 | edge1 = await knowledge_graph_inst.get_edge(path[1],path[2]) or await knowledge_graph_inst.get_edge(path[2], path[1]) 1206 | if edge0==None or edge1==None: 1207 | print(path,"边丢失") 1208 | continue 1209 | e1 = "through edge("+edge0["keywords"]+") to connect to "+s_name+" and "+b_name+"." 1210 | e2 = "through edge("+edge1["keywords"]+") to connect to "+b_name+" and "+t_name+"." 1211 | s = await knowledge_graph_inst.get_node(s_name) 1212 | s = "The entity "+s_name+" is a "+s["entity_type"]+" with the description("+s["description"]+")" 1213 | b = await knowledge_graph_inst.get_node(b_name) 1214 | b = "The entity "+b_name+" is a "+b["entity_type"]+" with the description("+b["description"]+")" 1215 | t = await knowledge_graph_inst.get_node(t_name) 1216 | t = "The entity "+t_name+" is a "+t["entity_type"]+" with the description("+t["description"]+")" 1217 | relationship.append([s+e1+b+"and"+b+e2+t]) 1218 | elif len(path) == 2: 1219 | s_name,t_name = path[0],path[1] 1220 | edge0 = await knowledge_graph_inst.get_edge(path[0], path[1]) or await knowledge_graph_inst.get_edge(path[1], path[0]) 1221 | if edge0==None: 1222 | print(path,"边丢失") 1223 | continue 1224 | e = "through edge("+edge0["keywords"]+") to connect to "+s_name+" and "+t_name+"." 1225 | s = await knowledge_graph_inst.get_node(s_name) 1226 | s = "The entity "+s_name+" is a "+s["entity_type"]+" with the description("+s["description"]+")" 1227 | t = await knowledge_graph_inst.get_node(t_name) 1228 | t = "The entity "+t_name+" is a "+t["entity_type"]+" with the description("+t["description"]+")" 1229 | relationship.append([s+e+t]) 1230 | 1231 | 1232 | relationship = truncate_list_by_token_size( 1233 | relationship, 1234 | key=lambda x: x[0], 1235 | max_token_size=query_param.max_token_for_local_context, 1236 | ) 1237 | 1238 | reversed_relationship = relationship[::-1] 1239 | return reversed_relationship -------------------------------------------------------------------------------- /PathRAG/prompt.py: -------------------------------------------------------------------------------- 1 | GRAPH_FIELD_SEP = "" 2 | 3 | PROMPTS = {} 4 | 5 | PROMPTS["DEFAULT_LANGUAGE"] = "English" 6 | PROMPTS["DEFAULT_TUPLE_DELIMITER"] = "<|>" 7 | PROMPTS["DEFAULT_RECORD_DELIMITER"] = "##" 8 | PROMPTS["DEFAULT_COMPLETION_DELIMITER"] = "<|COMPLETE|>" 9 | PROMPTS["process_tickers"] = ["⠋", "⠙", "⠹", "⠸", "⠼", "⠴", "⠦", "⠧", "⠇", "⠏"] 10 | 11 | PROMPTS["DEFAULT_ENTITY_TYPES"] = ["organization", "person", "geo", "event", "category"] 12 | 13 | PROMPTS["entity_extraction"] = """-Goal- 14 | Given a text document that is potentially relevant to this activity and a list of entity types, identify all entities of those types from the text and all relationships among the identified entities. 15 | Use {language} as output language. 16 | 17 | -Steps- 18 | 1. Identify all entities. For each identified entity, extract the following information: 19 | - entity_name: Name of the entity, use same language as input text. If English, capitalized the name. 20 | - entity_type: One of the following types: [{entity_types}] 21 | - entity_description: Comprehensive description of the entity's attributes and activities 22 | Format each entity as ("entity"{tuple_delimiter}{tuple_delimiter}{tuple_delimiter}) 23 | 24 | 2. From the entities identified in step 1, identify all pairs of (source_entity, target_entity) that are *clearly related* to each other. 25 | For each pair of related entities, extract the following information: 26 | - source_entity: name of the source entity, as identified in step 1 27 | - target_entity: name of the target entity, as identified in step 1 28 | - relationship_description: explanation as to why you think the source entity and the target entity are related to each other 29 | - relationship_strength: a numeric score indicating strength of the relationship between the source entity and target entity 30 | - relationship_keywords: one or more high-level key words that summarize the overarching nature of the relationship, focusing on concepts or themes rather than specific details 31 | Format each relationship as ("relationship"{tuple_delimiter}{tuple_delimiter}{tuple_delimiter}{tuple_delimiter}{tuple_delimiter}) 32 | 33 | 3. Identify high-level key words that summarize the main concepts, themes, or topics of the entire text. These should capture the overarching ideas present in the document. 34 | Format the content-level key words as ("content_keywords"{tuple_delimiter}) 35 | 36 | 4. Return output in {language} as a single list of all the entities and relationships identified in steps 1 and 2. Use **{record_delimiter}** as the list delimiter. 37 | 38 | 5. When finished, output {completion_delimiter} 39 | 40 | ###################### 41 | -Examples- 42 | ###################### 43 | {examples} 44 | 45 | ############################# 46 | -Real Data- 47 | ###################### 48 | Entity_types: {entity_types} 49 | Text: {input_text} 50 | ###################### 51 | Output: 52 | """ 53 | 54 | PROMPTS["entity_extraction_examples"] = [ 55 | """Example 1: 56 | 57 | Entity_types: [person, technology, mission, organization, location] 58 | Text: 59 | while Alex clenched his jaw, the buzz of frustration dull against the backdrop of Taylor's authoritarian certainty. It was this competitive undercurrent that kept him alert, the sense that his and Jordan's shared commitment to discovery was an unspoken rebellion against Cruz's narrowing vision of control and order. 60 | 61 | Then Taylor did something unexpected. They paused beside Jordan and, for a moment, observed the device with something akin to reverence. “If this tech can be understood..." Taylor said, their voice quieter, "It could change the game for us. For all of us.” 62 | 63 | The underlying dismissal earlier seemed to falter, replaced by a glimpse of reluctant respect for the gravity of what lay in their hands. Jordan looked up, and for a fleeting heartbeat, their eyes locked with Taylor's, a wordless clash of wills softening into an uneasy truce. 64 | 65 | It was a small transformation, barely perceptible, but one that Alex noted with an inward nod. They had all been brought here by different paths 66 | ################ 67 | Output: 68 | ("entity"{tuple_delimiter}"Alex"{tuple_delimiter}"person"{tuple_delimiter}"Alex is a character who experiences frustration and is observant of the dynamics among other characters."){record_delimiter} 69 | ("entity"{tuple_delimiter}"Taylor"{tuple_delimiter}"person"{tuple_delimiter}"Taylor is portrayed with authoritarian certainty and shows a moment of reverence towards a device, indicating a change in perspective."){record_delimiter} 70 | ("entity"{tuple_delimiter}"Jordan"{tuple_delimiter}"person"{tuple_delimiter}"Jordan shares a commitment to discovery and has a significant interaction with Taylor regarding a device."){record_delimiter} 71 | ("entity"{tuple_delimiter}"Cruz"{tuple_delimiter}"person"{tuple_delimiter}"Cruz is associated with a vision of control and order, influencing the dynamics among other characters."){record_delimiter} 72 | ("entity"{tuple_delimiter}"The Device"{tuple_delimiter}"technology"{tuple_delimiter}"The Device is central to the story, with potential game-changing implications, and is revered by Taylor."){record_delimiter} 73 | ("relationship"{tuple_delimiter}"Alex"{tuple_delimiter}"Taylor"{tuple_delimiter}"Alex is affected by Taylor's authoritarian certainty and observes changes in Taylor's attitude towards the device."{tuple_delimiter}"power dynamics, perspective shift"{tuple_delimiter}7){record_delimiter} 74 | ("relationship"{tuple_delimiter}"Alex"{tuple_delimiter}"Jordan"{tuple_delimiter}"Alex and Jordan share a commitment to discovery, which contrasts with Cruz's vision."{tuple_delimiter}"shared goals, rebellion"{tuple_delimiter}6){record_delimiter} 75 | ("relationship"{tuple_delimiter}"Taylor"{tuple_delimiter}"Jordan"{tuple_delimiter}"Taylor and Jordan interact directly regarding the device, leading to a moment of mutual respect and an uneasy truce."{tuple_delimiter}"conflict resolution, mutual respect"{tuple_delimiter}8){record_delimiter} 76 | ("relationship"{tuple_delimiter}"Jordan"{tuple_delimiter}"Cruz"{tuple_delimiter}"Jordan's commitment to discovery is in rebellion against Cruz's vision of control and order."{tuple_delimiter}"ideological conflict, rebellion"{tuple_delimiter}5){record_delimiter} 77 | ("relationship"{tuple_delimiter}"Taylor"{tuple_delimiter}"The Device"{tuple_delimiter}"Taylor shows reverence towards the device, indicating its importance and potential impact."{tuple_delimiter}"reverence, technological significance"{tuple_delimiter}9){record_delimiter} 78 | ("content_keywords"{tuple_delimiter}"power dynamics, ideological conflict, discovery, rebellion"){completion_delimiter} 79 | #############################""", 80 | """Example 2: 81 | 82 | Entity_types: [person, technology, mission, organization, location] 83 | Text: 84 | They were no longer mere operatives; they had become guardians of a threshold, keepers of a message from a realm beyond stars and stripes. This elevation in their mission could not be shackled by regulations and established protocols—it demanded a new perspective, a new resolve. 85 | 86 | Tension threaded through the dialogue of beeps and static as communications with Washington buzzed in the background. The team stood, a portentous air enveloping them. It was clear that the decisions they made in the ensuing hours could redefine humanity's place in the cosmos or condemn them to ignorance and potential peril. 87 | 88 | Their connection to the stars solidified, the group moved to address the crystallizing warning, shifting from passive recipients to active participants. Mercer's latter instincts gained precedence— the team's mandate had evolved, no longer solely to observe and report but to interact and prepare. A metamorphosis had begun, and Operation: Dulce hummed with the newfound frequency of their daring, a tone set not by the earthly 89 | ############# 90 | Output: 91 | ("entity"{tuple_delimiter}"Washington"{tuple_delimiter}"location"{tuple_delimiter}"Washington is a location where communications are being received, indicating its importance in the decision-making process."){record_delimiter} 92 | ("entity"{tuple_delimiter}"Operation: Dulce"{tuple_delimiter}"mission"{tuple_delimiter}"Operation: Dulce is described as a mission that has evolved to interact and prepare, indicating a significant shift in objectives and activities."){record_delimiter} 93 | ("entity"{tuple_delimiter}"The team"{tuple_delimiter}"organization"{tuple_delimiter}"The team is portrayed as a group of individuals who have transitioned from passive observers to active participants in a mission, showing a dynamic change in their role."){record_delimiter} 94 | ("relationship"{tuple_delimiter}"The team"{tuple_delimiter}"Washington"{tuple_delimiter}"The team receives communications from Washington, which influences their decision-making process."{tuple_delimiter}"decision-making, external influence"{tuple_delimiter}7){record_delimiter} 95 | ("relationship"{tuple_delimiter}"The team"{tuple_delimiter}"Operation: Dulce"{tuple_delimiter}"The team is directly involved in Operation: Dulce, executing its evolved objectives and activities."{tuple_delimiter}"mission evolution, active participation"{tuple_delimiter}9){completion_delimiter} 96 | ("content_keywords"{tuple_delimiter}"mission evolution, decision-making, active participation, cosmic significance"){completion_delimiter} 97 | #############################""", 98 | """Example 3: 99 | 100 | Entity_types: [person, role, technology, organization, event, location, concept] 101 | Text: 102 | their voice slicing through the buzz of activity. "Control may be an illusion when facing an intelligence that literally writes its own rules," they stated stoically, casting a watchful eye over the flurry of data. 103 | 104 | "It's like it's learning to communicate," offered Sam Rivera from a nearby interface, their youthful energy boding a mix of awe and anxiety. "This gives talking to strangers' a whole new meaning." 105 | 106 | Alex surveyed his team—each face a study in concentration, determination, and not a small measure of trepidation. "This might well be our first contact," he acknowledged, "And we need to be ready for whatever answers back." 107 | 108 | Together, they stood on the edge of the unknown, forging humanity's response to a message from the heavens. The ensuing silence was palpable—a collective introspection about their role in this grand cosmic play, one that could rewrite human history. 109 | 110 | The encrypted dialogue continued to unfold, its intricate patterns showing an almost uncanny anticipation 111 | ############# 112 | Output: 113 | ("entity"{tuple_delimiter}"Sam Rivera"{tuple_delimiter}"person"{tuple_delimiter}"Sam Rivera is a member of a team working on communicating with an unknown intelligence, showing a mix of awe and anxiety."){record_delimiter} 114 | ("entity"{tuple_delimiter}"Alex"{tuple_delimiter}"person"{tuple_delimiter}"Alex is the leader of a team attempting first contact with an unknown intelligence, acknowledging the significance of their task."){record_delimiter} 115 | ("entity"{tuple_delimiter}"Control"{tuple_delimiter}"concept"{tuple_delimiter}"Control refers to the ability to manage or govern, which is challenged by an intelligence that writes its own rules."){record_delimiter} 116 | ("entity"{tuple_delimiter}"Intelligence"{tuple_delimiter}"concept"{tuple_delimiter}"Intelligence here refers to an unknown entity capable of writing its own rules and learning to communicate."){record_delimiter} 117 | ("entity"{tuple_delimiter}"First Contact"{tuple_delimiter}"event"{tuple_delimiter}"First Contact is the potential initial communication between humanity and an unknown intelligence."){record_delimiter} 118 | ("entity"{tuple_delimiter}"Humanity's Response"{tuple_delimiter}"event"{tuple_delimiter}"Humanity's Response is the collective action taken by Alex's team in response to a message from an unknown intelligence."){record_delimiter} 119 | ("relationship"{tuple_delimiter}"Sam Rivera"{tuple_delimiter}"Intelligence"{tuple_delimiter}"Sam Rivera is directly involved in the process of learning to communicate with the unknown intelligence."{tuple_delimiter}"communication, learning process"{tuple_delimiter}9){record_delimiter} 120 | ("relationship"{tuple_delimiter}"Alex"{tuple_delimiter}"First Contact"{tuple_delimiter}"Alex leads the team that might be making the First Contact with the unknown intelligence."{tuple_delimiter}"leadership, exploration"{tuple_delimiter}10){record_delimiter} 121 | ("relationship"{tuple_delimiter}"Alex"{tuple_delimiter}"Humanity's Response"{tuple_delimiter}"Alex and his team are the key figures in Humanity's Response to the unknown intelligence."{tuple_delimiter}"collective action, cosmic significance"{tuple_delimiter}8){record_delimiter} 122 | ("relationship"{tuple_delimiter}"Control"{tuple_delimiter}"Intelligence"{tuple_delimiter}"The concept of Control is challenged by the Intelligence that writes its own rules."{tuple_delimiter}"power dynamics, autonomy"{tuple_delimiter}7){record_delimiter} 123 | ("content_keywords"{tuple_delimiter}"first contact, control, communication, cosmic significance"){completion_delimiter} 124 | #############################""", 125 | ] 126 | 127 | PROMPTS[ 128 | "summarize_entity_descriptions" 129 | ] = """You are a helpful assistant responsible for generating a comprehensive summary of the data provided below. 130 | Given one or two entities, and a list of descriptions, all related to the same entity or group of entities. 131 | Please concatenate all of these into a single, comprehensive description. Make sure to include information collected from all the descriptions. 132 | If the provided descriptions are contradictory, please resolve the contradictions and provide a single, coherent summary. 133 | Make sure it is written in third person, and include the entity names so we the have full context. 134 | Use {language} as output language. 135 | 136 | ####### 137 | -Data- 138 | Entities: {entity_name} 139 | Description List: {description_list} 140 | ####### 141 | Output: 142 | """ 143 | 144 | PROMPTS[ 145 | "entiti_continue_extraction" 146 | ] = """MANY entities were missed in the last extraction. Add them below using the same format: 147 | """ 148 | 149 | PROMPTS[ 150 | "entiti_if_loop_extraction" 151 | ] = """It appears some entities may have still been missed. Answer YES | NO if there are still entities that need to be added. 152 | """ 153 | 154 | PROMPTS["fail_response"] = "Sorry, I'm not able to provide an answer to that question." 155 | 156 | PROMPTS["rag_response"] = """---Role--- 157 | 158 | You are a helpful assistant responding to questions about data in the tables provided. 159 | 160 | 161 | ---Goal--- 162 | 163 | Generate a response of the target length and format that responds to the user's question, summarizing all information in the input data tables appropriate for the response length and format, and incorporating any relevant general knowledge. 164 | If you don't know the answer, just say so. Do not make anything up. 165 | Do not include information where the supporting evidence for it is not provided. 166 | 167 | ---Target response length and format--- 168 | 169 | {response_type} 170 | 171 | ---Data tables--- 172 | 173 | {context_data} 174 | 175 | Add sections and commentary to the response as appropriate for the length and format. Style the response in markdown. 176 | """ 177 | 178 | PROMPTS["keywords_extraction"] = """---Role--- 179 | 180 | You are a helpful assistant tasked with identifying both high-level and low-level keywords in the user's query. 181 | 182 | ---Goal--- 183 | 184 | Given the query, list both high-level and low-level keywords. High-level keywords focus on overarching concepts or themes, while low-level keywords focus on specific entities, details, or concrete terms. 185 | 186 | ---Instructions--- 187 | 188 | - Output the keywords in JSON format. 189 | - The JSON should have two keys: 190 | - "high_level_keywords" for overarching concepts or themes. 191 | - "low_level_keywords" for specific entities or details. 192 | 193 | ###################### 194 | -Examples- 195 | ###################### 196 | {examples} 197 | 198 | ############################# 199 | -Real Data- 200 | ###################### 201 | Query: {query} 202 | ###################### 203 | The `Output` should be human text, not unicode characters. Keep the same language as `Query`. 204 | Output: 205 | 206 | """ 207 | 208 | PROMPTS["keywords_extraction_examples"] = [ 209 | """Example 1: 210 | 211 | Query: "How does international trade influence global economic stability?" 212 | ################ 213 | Output: 214 | {{ 215 | "high_level_keywords": ["International trade", "Global economic stability", "Economic impact"], 216 | "low_level_keywords": ["Trade agreements", "Tariffs", "Currency exchange", "Imports", "Exports"] 217 | }} 218 | #############################""", 219 | """Example 2: 220 | 221 | Query: "What are the environmental consequences of deforestation on biodiversity?" 222 | ################ 223 | Output: 224 | {{ 225 | "high_level_keywords": ["Environmental consequences", "Deforestation", "Biodiversity loss"], 226 | "low_level_keywords": ["Species extinction", "Habitat destruction", "Carbon emissions", "Rainforest", "Ecosystem"] 227 | }} 228 | #############################""", 229 | """Example 3: 230 | 231 | Query: "What is the role of education in reducing poverty?" 232 | ################ 233 | Output: 234 | {{ 235 | "high_level_keywords": ["Education", "Poverty reduction", "Socioeconomic development"], 236 | "low_level_keywords": ["School access", "Literacy rates", "Job training", "Income inequality"] 237 | }} 238 | #############################""", 239 | ] 240 | 241 | 242 | PROMPTS["naive_rag_response"] = """---Role--- 243 | 244 | You are a helpful assistant responding to questions about documents provided. 245 | 246 | 247 | ---Goal--- 248 | 249 | Generate a response of the target length and format that responds to the user's question, summarizing all information in the input data tables appropriate for the response length and format, and incorporating any relevant general knowledge. 250 | If you don't know the answer, just say so. Do not make anything up. 251 | Do not include information where the supporting evidence for it is not provided. 252 | 253 | ---Target response length and format--- 254 | 255 | {response_type} 256 | 257 | ---Documents--- 258 | 259 | {content_data} 260 | 261 | Add sections and commentary to the response as appropriate for the length and format. Style the response in markdown. 262 | """ 263 | 264 | PROMPTS[ 265 | "similarity_check" 266 | ] = """Please analyze the similarity between these two questions: 267 | 268 | Question 1: {original_prompt} 269 | Question 2: {cached_prompt} 270 | 271 | Please evaluate the following two points and provide a similarity score between 0 and 1 directly: 272 | 1. Whether these two questions are semantically similar 273 | 2. Whether the answer to Question 2 can be used to answer Question 1 274 | Similarity score criteria: 275 | 0: Completely unrelated or answer cannot be reused, including but not limited to: 276 | - The questions have different topics 277 | - The locations mentioned in the questions are different 278 | - The times mentioned in the questions are different 279 | - The specific individuals mentioned in the questions are different 280 | - The specific events mentioned in the questions are different 281 | - The background information in the questions is different 282 | - The key conditions in the questions are different 283 | 1: Identical and answer can be directly reused 284 | 0.5: Partially related and answer needs modification to be used 285 | Return only a number between 0-1, without any additional content. 286 | """ 287 | -------------------------------------------------------------------------------- /PathRAG/storage.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import html 3 | import os 4 | from tqdm.asyncio import tqdm as tqdm_async 5 | from dataclasses import dataclass 6 | from typing import Any, Union, cast 7 | import networkx as nx 8 | import numpy as np 9 | from nano_vectordb import NanoVectorDB 10 | 11 | from .utils import ( 12 | logger, 13 | load_json, 14 | write_json, 15 | compute_mdhash_id, 16 | ) 17 | 18 | from .base import ( 19 | BaseGraphStorage, 20 | BaseKVStorage, 21 | BaseVectorStorage, 22 | ) 23 | 24 | 25 | @dataclass 26 | class JsonKVStorage(BaseKVStorage): 27 | def __post_init__(self): 28 | working_dir = self.global_config["working_dir"] 29 | self._file_name = os.path.join(working_dir, f"kv_store_{self.namespace}.json") 30 | self._data = load_json(self._file_name) or {} 31 | logger.info(f"Load KV {self.namespace} with {len(self._data)} data") 32 | 33 | async def all_keys(self) -> list[str]: 34 | return list(self._data.keys()) 35 | 36 | async def index_done_callback(self): 37 | write_json(self._data, self._file_name) 38 | 39 | async def get_by_id(self, id): 40 | return self._data.get(id, None) 41 | 42 | async def get_by_ids(self, ids, fields=None): 43 | if fields is None: 44 | return [self._data.get(id, None) for id in ids] 45 | return [ 46 | ( 47 | {k: v for k, v in self._data[id].items() if k in fields} 48 | if self._data.get(id, None) 49 | else None 50 | ) 51 | for id in ids 52 | ] 53 | 54 | async def filter_keys(self, data: list[str]) -> set[str]: 55 | return set([s for s in data if s not in self._data]) 56 | 57 | async def upsert(self, data: dict[str, dict]): 58 | left_data = {k: v for k, v in data.items() if k not in self._data} 59 | self._data.update(left_data) 60 | return left_data 61 | 62 | async def drop(self): 63 | self._data = {} 64 | 65 | 66 | @dataclass 67 | class NanoVectorDBStorage(BaseVectorStorage): 68 | cosine_better_than_threshold: float = 0.2 69 | 70 | def __post_init__(self): 71 | self._client_file_name = os.path.join( 72 | self.global_config["working_dir"], f"vdb_{self.namespace}.json" 73 | ) 74 | self._max_batch_size = self.global_config["embedding_batch_num"] 75 | self._client = NanoVectorDB( 76 | self.embedding_func.embedding_dim, storage_file=self._client_file_name 77 | ) 78 | self.cosine_better_than_threshold = self.global_config.get( 79 | "cosine_better_than_threshold", self.cosine_better_than_threshold 80 | ) 81 | 82 | async def upsert(self, data: dict[str, dict]): 83 | logger.info(f"Inserting {len(data)} vectors to {self.namespace}") 84 | if not len(data): 85 | logger.warning("You insert an empty data to vector DB") 86 | return [] 87 | list_data = [ 88 | { 89 | "__id__": k, 90 | **{k1: v1 for k1, v1 in v.items() if k1 in self.meta_fields}, 91 | } 92 | for k, v in data.items() 93 | ] 94 | contents = [v["content"] for v in data.values()] 95 | batches = [ 96 | contents[i : i + self._max_batch_size] 97 | for i in range(0, len(contents), self._max_batch_size) 98 | ] 99 | 100 | async def wrapped_task(batch): 101 | result = await self.embedding_func(batch) 102 | pbar.update(1) 103 | return result 104 | 105 | embedding_tasks = [wrapped_task(batch) for batch in batches] 106 | pbar = tqdm_async( 107 | total=len(embedding_tasks), desc="Generating embeddings", unit="batch" 108 | ) 109 | embeddings_list = await asyncio.gather(*embedding_tasks) 110 | 111 | embeddings = np.concatenate(embeddings_list) 112 | if len(embeddings) == len(list_data): 113 | for i, d in enumerate(list_data): 114 | d["__vector__"] = embeddings[i] 115 | results = self._client.upsert(datas=list_data) 116 | return results 117 | else: 118 | # sometimes the embedding is not returned correctly. just log it. 119 | logger.error( 120 | f"embedding is not 1-1 with data, {len(embeddings)} != {len(list_data)}" 121 | ) 122 | 123 | async def query(self, query: str, top_k=5): 124 | embedding = await self.embedding_func([query]) 125 | embedding = embedding[0] 126 | results = self._client.query( 127 | query=embedding, 128 | top_k=top_k, 129 | better_than_threshold=self.cosine_better_than_threshold, 130 | ) 131 | results = [ 132 | {**dp, "id": dp["__id__"], "distance": dp["__metrics__"]} for dp in results 133 | ] 134 | return results 135 | 136 | @property 137 | def client_storage(self): 138 | return getattr(self._client, "_NanoVectorDB__storage") 139 | 140 | async def delete_entity(self, entity_name: str): 141 | try: 142 | entity_id = [compute_mdhash_id(entity_name, prefix="ent-")] 143 | 144 | if self._client.get(entity_id): 145 | self._client.delete(entity_id) 146 | logger.info(f"Entity {entity_name} have been deleted.") 147 | else: 148 | logger.info(f"No entity found with name {entity_name}.") 149 | except Exception as e: 150 | logger.error(f"Error while deleting entity {entity_name}: {e}") 151 | 152 | async def delete_relation(self, entity_name: str): 153 | try: 154 | relations = [ 155 | dp 156 | for dp in self.client_storage["data"] 157 | if dp["src_id"] == entity_name or dp["tgt_id"] == entity_name 158 | ] 159 | ids_to_delete = [relation["__id__"] for relation in relations] 160 | 161 | if ids_to_delete: 162 | self._client.delete(ids_to_delete) 163 | logger.info( 164 | f"All relations related to entity {entity_name} have been deleted." 165 | ) 166 | else: 167 | logger.info(f"No relations found for entity {entity_name}.") 168 | except Exception as e: 169 | logger.error( 170 | f"Error while deleting relations for entity {entity_name}: {e}" 171 | ) 172 | 173 | async def index_done_callback(self): 174 | self._client.save() 175 | 176 | 177 | @dataclass 178 | class NetworkXStorage(BaseGraphStorage): 179 | @staticmethod 180 | def load_nx_graph(file_name) -> nx.DiGraph: 181 | if os.path.exists(file_name): 182 | return nx.read_graphml(file_name) 183 | return None 184 | # def load_nx_graph(file_name) -> nx.Graph: 185 | # if os.path.exists(file_name): 186 | # return nx.read_graphml(file_name) 187 | # return None 188 | 189 | @staticmethod 190 | def write_nx_graph(graph: nx.DiGraph, file_name): 191 | logger.info( 192 | f"Writing graph with {graph.number_of_nodes()} nodes, {graph.number_of_edges()} edges" 193 | ) 194 | nx.write_graphml(graph, file_name) 195 | 196 | @staticmethod 197 | def stable_largest_connected_component(graph: nx.Graph) -> nx.Graph: 198 | """Refer to https://github.com/microsoft/graphrag/index/graph/utils/stable_lcc.py 199 | Return the largest connected component of the graph, with nodes and edges sorted in a stable way. 200 | """ 201 | from graspologic.utils import largest_connected_component 202 | 203 | graph = graph.copy() 204 | graph = cast(nx.Graph, largest_connected_component(graph)) 205 | node_mapping = { 206 | node: html.unescape(node.upper().strip()) for node in graph.nodes() 207 | } # type: ignore 208 | graph = nx.relabel_nodes(graph, node_mapping) 209 | return NetworkXStorage._stabilize_graph(graph) 210 | 211 | @staticmethod 212 | def _stabilize_graph(graph: nx.Graph) -> nx.Graph: 213 | """Refer to https://github.com/microsoft/graphrag/index/graph/utils/stable_lcc.py 214 | Ensure an undirected graph with the same relationships will always be read the same way. 215 | """ 216 | fixed_graph = nx.DiGraph() if graph.is_directed() else nx.Graph() 217 | 218 | sorted_nodes = graph.nodes(data=True) 219 | sorted_nodes = sorted(sorted_nodes, key=lambda x: x[0]) 220 | 221 | fixed_graph.add_nodes_from(sorted_nodes) 222 | edges = list(graph.edges(data=True)) 223 | 224 | if not graph.is_directed(): 225 | 226 | def _sort_source_target(edge): 227 | source, target, edge_data = edge 228 | if source > target: 229 | temp = source 230 | source = target 231 | target = temp 232 | return source, target, edge_data 233 | 234 | edges = [_sort_source_target(edge) for edge in edges] 235 | 236 | def _get_edge_key(source: Any, target: Any) -> str: 237 | return f"{source} -> {target}" 238 | 239 | edges = sorted(edges, key=lambda x: _get_edge_key(x[0], x[1])) 240 | 241 | fixed_graph.add_edges_from(edges) 242 | return fixed_graph 243 | 244 | def __post_init__(self): 245 | self._graphml_xml_file = os.path.join( 246 | self.global_config["working_dir"], f"graph_{self.namespace}.graphml" 247 | ) 248 | preloaded_graph = NetworkXStorage.load_nx_graph(self._graphml_xml_file) 249 | if preloaded_graph is not None: 250 | logger.info( 251 | f"Loaded graph from {self._graphml_xml_file} with {preloaded_graph.number_of_nodes()} nodes, {preloaded_graph.number_of_edges()} edges" 252 | ) 253 | self._graph = preloaded_graph or nx.DiGraph() 254 | self._node_embed_algorithms = { 255 | "node2vec": self._node2vec_embed, 256 | } 257 | 258 | async def index_done_callback(self): 259 | NetworkXStorage.write_nx_graph(self._graph, self._graphml_xml_file) 260 | 261 | async def has_node(self, node_id: str) -> bool: 262 | return self._graph.has_node(node_id) 263 | 264 | async def has_edge(self, source_node_id: str, target_node_id: str) -> bool: 265 | return self._graph.has_edge(source_node_id, target_node_id) 266 | 267 | async def get_node(self, node_id: str) -> Union[dict, None]: 268 | return self._graph.nodes.get(node_id) 269 | 270 | async def node_degree(self, node_id: str) -> int: 271 | return self._graph.degree(node_id) 272 | 273 | async def edge_degree(self, src_id: str, tgt_id: str) -> int: 274 | return self._graph.degree(src_id) + self._graph.degree(tgt_id) 275 | 276 | async def get_edge( 277 | self, source_node_id: str, target_node_id: str 278 | ) -> Union[dict, None]: 279 | return self._graph.edges.get((source_node_id, target_node_id)) 280 | 281 | async def get_node_edges(self, source_node_id: str): 282 | if self._graph.has_node(source_node_id): 283 | return list(self._graph.edges(source_node_id)) 284 | return None 285 | async def get_node_in_edges(self, source_node_id: str): 286 | if self._graph.has_node(source_node_id): 287 | return list(self._graph.in_edges(source_node_id)) 288 | return None 289 | async def get_node_out_edges(self, source_node_id: str): 290 | if self._graph.has_node(source_node_id): 291 | return list(self._graph.out_edges(source_node_id)) 292 | return None 293 | 294 | async def get_pagerank(self,source_node_id:str): 295 | pagerank_list=nx.pagerank(self._graph) 296 | if source_node_id in pagerank_list: 297 | return pagerank_list[source_node_id] 298 | else: 299 | print("pagerank failed") 300 | 301 | async def upsert_node(self, node_id: str, node_data: dict[str, str]): 302 | self._graph.add_node(node_id, **node_data) 303 | 304 | async def upsert_edge( 305 | self, source_node_id: str, target_node_id: str, edge_data: dict[str, str] 306 | ): 307 | self._graph.add_edge(source_node_id, target_node_id, **edge_data) 308 | 309 | async def delete_node(self, node_id: str): 310 | """ 311 | Delete a node from the graph based on the specified node_id. 312 | 313 | :param node_id: The node_id to delete 314 | """ 315 | if self._graph.has_node(node_id): 316 | self._graph.remove_node(node_id) 317 | logger.info(f"Node {node_id} deleted from the graph.") 318 | else: 319 | logger.warning(f"Node {node_id} not found in the graph for deletion.") 320 | 321 | async def embed_nodes(self, algorithm: str) -> tuple[np.ndarray, list[str]]: 322 | if algorithm not in self._node_embed_algorithms: 323 | raise ValueError(f"Node embedding algorithm {algorithm} not supported") 324 | return await self._node_embed_algorithms[algorithm]() 325 | 326 | # @TODO: NOT USED 327 | async def _node2vec_embed(self): 328 | from graspologic import embed 329 | 330 | embeddings, nodes = embed.node2vec_embed( 331 | self._graph, 332 | **self.global_config["node2vec_params"], 333 | ) 334 | 335 | nodes_ids = [self._graph.nodes[node_id]["id"] for node_id in nodes] 336 | return embeddings, nodes_ids 337 | 338 | async def edges(self): 339 | return self._graph.edges() 340 | async def nodes(self): 341 | return self._graph.nodes() 342 | -------------------------------------------------------------------------------- /PathRAG/utils.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import html 3 | import io 4 | import csv 5 | import json 6 | import logging 7 | import os 8 | import re 9 | from dataclasses import dataclass 10 | from functools import wraps 11 | from hashlib import md5 12 | from typing import Any, Union, List, Optional 13 | import xml.etree.ElementTree as ET 14 | 15 | import numpy as np 16 | import tiktoken 17 | 18 | from PathRAG.prompt import PROMPTS 19 | 20 | 21 | class UnlimitedSemaphore: 22 | 23 | 24 | async def __aenter__(self): 25 | pass 26 | 27 | async def __aexit__(self, exc_type, exc, tb): 28 | pass 29 | 30 | 31 | ENCODER = None 32 | 33 | logger = logging.getLogger("PathRAG") 34 | 35 | 36 | def set_logger(log_file: str): 37 | logger.setLevel(logging.DEBUG) 38 | 39 | file_handler = logging.FileHandler(log_file) 40 | file_handler.setLevel(logging.DEBUG) 41 | 42 | formatter = logging.Formatter( 43 | "%(asctime)s - %(name)s - %(levelname)s - %(message)s" 44 | ) 45 | file_handler.setFormatter(formatter) 46 | 47 | if not logger.handlers: 48 | logger.addHandler(file_handler) 49 | 50 | 51 | @dataclass 52 | class EmbeddingFunc: 53 | embedding_dim: int 54 | max_token_size: int 55 | func: callable 56 | concurrent_limit: int = 16 57 | 58 | def __post_init__(self): 59 | if self.concurrent_limit != 0: 60 | self._semaphore = asyncio.Semaphore(self.concurrent_limit) 61 | else: 62 | self._semaphore = UnlimitedSemaphore() 63 | 64 | async def __call__(self, *args, **kwargs) -> np.ndarray: 65 | async with self._semaphore: 66 | return await self.func(*args, **kwargs) 67 | 68 | 69 | def locate_json_string_body_from_string(content: str) -> Union[str, None]: 70 | 71 | try: 72 | maybe_json_str = re.search(r"{.*}", content, re.DOTALL) 73 | if maybe_json_str is not None: 74 | maybe_json_str = maybe_json_str.group(0) 75 | maybe_json_str = maybe_json_str.replace("\\n", "") 76 | maybe_json_str = maybe_json_str.replace("\n", "") 77 | maybe_json_str = maybe_json_str.replace("'", '"') 78 | 79 | return maybe_json_str 80 | except Exception: 81 | pass 82 | 83 | 84 | return None 85 | 86 | 87 | def convert_response_to_json(response: str) -> dict: 88 | json_str = locate_json_string_body_from_string(response) 89 | assert json_str is not None, f"Unable to parse JSON from response: {response}" 90 | try: 91 | data = json.loads(json_str) 92 | return data 93 | except json.JSONDecodeError as e: 94 | logger.error(f"Failed to parse JSON: {json_str}") 95 | raise e from None 96 | 97 | 98 | def compute_args_hash(*args): 99 | return md5(str(args).encode()).hexdigest() 100 | 101 | 102 | def compute_mdhash_id(content, prefix: str = ""): 103 | return prefix + md5(content.encode()).hexdigest() 104 | 105 | 106 | def limit_async_func_call(max_size: int, waitting_time: float = 0.0001): 107 | 108 | 109 | def final_decro(func): 110 | 111 | __current_size = 0 112 | 113 | @wraps(func) 114 | async def wait_func(*args, **kwargs): 115 | nonlocal __current_size 116 | while __current_size >= max_size: 117 | await asyncio.sleep(waitting_time) 118 | __current_size += 1 119 | result = await func(*args, **kwargs) 120 | __current_size -= 1 121 | return result 122 | 123 | return wait_func 124 | 125 | return final_decro 126 | 127 | 128 | def wrap_embedding_func_with_attrs(**kwargs): 129 | 130 | 131 | def final_decro(func) -> EmbeddingFunc: 132 | new_func = EmbeddingFunc(**kwargs, func=func) 133 | return new_func 134 | 135 | return final_decro 136 | 137 | 138 | def load_json(file_name): 139 | if not os.path.exists(file_name): 140 | return None 141 | with open(file_name, encoding="utf-8") as f: 142 | return json.load(f) 143 | 144 | 145 | def write_json(json_obj, file_name): 146 | with open(file_name, "w", encoding="utf-8") as f: 147 | json.dump(json_obj, f, indent=2, ensure_ascii=False) 148 | 149 | 150 | def encode_string_by_tiktoken(content: str, model_name: str = "gpt-4o-mini"): 151 | global ENCODER 152 | if ENCODER is None: 153 | ENCODER = tiktoken.encoding_for_model(model_name) 154 | tokens = ENCODER.encode(content) 155 | return tokens 156 | 157 | 158 | def decode_tokens_by_tiktoken(tokens: list[int], model_name: str = "gpt-4o-mini"): 159 | global ENCODER 160 | if ENCODER is None: 161 | ENCODER = tiktoken.encoding_for_model(model_name) 162 | content = ENCODER.decode(tokens) 163 | return content 164 | 165 | 166 | def pack_user_ass_to_openai_messages(*args: str): 167 | roles = ["user", "assistant"] 168 | return [ 169 | {"role": roles[i % 2], "content": content} for i, content in enumerate(args) 170 | ] 171 | 172 | 173 | def split_string_by_multi_markers(content: str, markers: list[str]) -> list[str]: 174 | 175 | if not markers: 176 | return [content] 177 | results = re.split("|".join(re.escape(marker) for marker in markers), content) 178 | return [r.strip() for r in results if r.strip()] 179 | 180 | 181 | 182 | def clean_str(input: Any) -> str: 183 | 184 | 185 | if not isinstance(input, str): 186 | return input 187 | 188 | result = html.unescape(input.strip()) 189 | 190 | return re.sub(r"[\x00-\x1f\x7f-\x9f]", "", result) 191 | 192 | 193 | def is_float_regex(value): 194 | return bool(re.match(r"^[-+]?[0-9]*\.?[0-9]+$", value)) 195 | 196 | 197 | def truncate_list_by_token_size(list_data: list, key: callable, max_token_size: int): 198 | 199 | if max_token_size <= 0: 200 | return [] 201 | tokens = 0 202 | for i, data in enumerate(list_data): 203 | tokens += len(encode_string_by_tiktoken(key(data))) 204 | if tokens > max_token_size: 205 | return list_data[:i] 206 | return list_data 207 | 208 | 209 | def list_of_list_to_csv(data: List[List[str]]) -> str: 210 | output = io.StringIO() 211 | writer = csv.writer(output) 212 | writer.writerows(data) 213 | return output.getvalue() 214 | 215 | 216 | def csv_string_to_list(csv_string: str) -> List[List[str]]: 217 | output = io.StringIO(csv_string) 218 | reader = csv.reader(output) 219 | return [row for row in reader] 220 | 221 | 222 | def save_data_to_file(data, file_name): 223 | with open(file_name, "w", encoding="utf-8") as f: 224 | json.dump(data, f, ensure_ascii=False, indent=4) 225 | 226 | 227 | def xml_to_json(xml_file): 228 | try: 229 | tree = ET.parse(xml_file) 230 | root = tree.getroot() 231 | 232 | print(f"Root element: {root.tag}") 233 | print(f"Root attributes: {root.attrib}") 234 | 235 | data = {"nodes": [], "edges": []} 236 | namespace = {"": "http://graphml.graphdrawing.org/xmlns"} 237 | 238 | for node in root.findall(".//node", namespace): 239 | node_data = { 240 | "id": node.get("id").strip('"'), 241 | "entity_type": node.find("./data[@key='d0']", namespace).text.strip('"') 242 | if node.find("./data[@key='d0']", namespace) is not None 243 | else "", 244 | "description": node.find("./data[@key='d1']", namespace).text 245 | if node.find("./data[@key='d1']", namespace) is not None 246 | else "", 247 | "source_id": node.find("./data[@key='d2']", namespace).text 248 | if node.find("./data[@key='d2']", namespace) is not None 249 | else "", 250 | } 251 | data["nodes"].append(node_data) 252 | 253 | for edge in root.findall(".//edge", namespace): 254 | edge_data = { 255 | "source": edge.get("source").strip('"'), 256 | "target": edge.get("target").strip('"'), 257 | "weight": float(edge.find("./data[@key='d3']", namespace).text) 258 | if edge.find("./data[@key='d3']", namespace) is not None 259 | else 0.0, 260 | "description": edge.find("./data[@key='d4']", namespace).text 261 | if edge.find("./data[@key='d4']", namespace) is not None 262 | else "", 263 | "keywords": edge.find("./data[@key='d5']", namespace).text 264 | if edge.find("./data[@key='d5']", namespace) is not None 265 | else "", 266 | "source_id": edge.find("./data[@key='d6']", namespace).text 267 | if edge.find("./data[@key='d6']", namespace) is not None 268 | else "", 269 | } 270 | data["edges"].append(edge_data) 271 | 272 | print(f"Found {len(data['nodes'])} nodes and {len(data['edges'])} edges") 273 | 274 | return data 275 | except ET.ParseError as e: 276 | print(f"Error parsing XML file: {e}") 277 | return None 278 | except Exception as e: 279 | print(f"An error occurred: {e}") 280 | return None 281 | 282 | 283 | def process_combine_contexts(hl, ll): 284 | header = None 285 | list_hl = csv_string_to_list(hl.strip()) 286 | list_ll = csv_string_to_list(ll.strip()) 287 | 288 | if list_hl: 289 | header = list_hl[0] 290 | list_hl = list_hl[1:] 291 | if list_ll: 292 | header = list_ll[0] 293 | list_ll = list_ll[1:] 294 | if header is None: 295 | return "" 296 | 297 | if list_hl: 298 | list_hl = [",".join(item[1:]) for item in list_hl if item] 299 | if list_ll: 300 | list_ll = [",".join(item[1:]) for item in list_ll if item] 301 | 302 | combined_sources = [] 303 | seen = set() 304 | 305 | for item in list_hl + list_ll: 306 | if item and item not in seen: 307 | combined_sources.append(item) 308 | seen.add(item) 309 | 310 | combined_sources_result = [",\t".join(header)] 311 | 312 | for i, item in enumerate(combined_sources, start=1): 313 | combined_sources_result.append(f"{i},\t{item}") 314 | 315 | combined_sources_result = "\n".join(combined_sources_result) 316 | 317 | return combined_sources_result 318 | 319 | 320 | async def get_best_cached_response( 321 | hashing_kv, 322 | current_embedding, 323 | similarity_threshold=0.95, 324 | mode="default", 325 | use_llm_check=False, 326 | llm_func=None, 327 | original_prompt=None, 328 | ) -> Union[str, None]: 329 | 330 | mode_cache = await hashing_kv.get_by_id(mode) 331 | if not mode_cache: 332 | return None 333 | 334 | best_similarity = -1 335 | best_response = None 336 | best_prompt = None 337 | best_cache_id = None 338 | 339 | 340 | for cache_id, cache_data in mode_cache.items(): 341 | if cache_data["embedding"] is None: 342 | continue 343 | 344 | 345 | cached_quantized = np.frombuffer( 346 | bytes.fromhex(cache_data["embedding"]), dtype=np.uint8 347 | ).reshape(cache_data["embedding_shape"]) 348 | cached_embedding = dequantize_embedding( 349 | cached_quantized, 350 | cache_data["embedding_min"], 351 | cache_data["embedding_max"], 352 | ) 353 | 354 | similarity = cosine_similarity(current_embedding, cached_embedding) 355 | if similarity > best_similarity: 356 | best_similarity = similarity 357 | best_response = cache_data["return"] 358 | best_prompt = cache_data["original_prompt"] 359 | best_cache_id = cache_id 360 | 361 | if best_similarity > similarity_threshold: 362 | 363 | if use_llm_check and llm_func and original_prompt and best_prompt: 364 | compare_prompt = PROMPTS["similarity_check"].format( 365 | original_prompt=original_prompt, cached_prompt=best_prompt 366 | ) 367 | 368 | try: 369 | llm_result = await llm_func(compare_prompt) 370 | llm_result = llm_result.strip() 371 | llm_similarity = float(llm_result) 372 | 373 | 374 | best_similarity = llm_similarity 375 | if best_similarity < similarity_threshold: 376 | log_data = { 377 | "event": "llm_check_cache_rejected", 378 | "original_question": original_prompt[:100] + "..." 379 | if len(original_prompt) > 100 380 | else original_prompt, 381 | "cached_question": best_prompt[:100] + "..." 382 | if len(best_prompt) > 100 383 | else best_prompt, 384 | "similarity_score": round(best_similarity, 4), 385 | "threshold": similarity_threshold, 386 | } 387 | logger.info(json.dumps(log_data, ensure_ascii=False)) 388 | return None 389 | except Exception as e: 390 | logger.warning(f"LLM similarity check failed: {e}") 391 | return None 392 | 393 | prompt_display = ( 394 | best_prompt[:50] + "..." if len(best_prompt) > 50 else best_prompt 395 | ) 396 | log_data = { 397 | "event": "cache_hit", 398 | "mode": mode, 399 | "similarity": round(best_similarity, 4), 400 | "cache_id": best_cache_id, 401 | "original_prompt": prompt_display, 402 | } 403 | logger.info(json.dumps(log_data, ensure_ascii=False)) 404 | return best_response 405 | return None 406 | 407 | 408 | def cosine_similarity(v1, v2): 409 | 410 | dot_product = np.dot(v1, v2) 411 | norm1 = np.linalg.norm(v1) 412 | norm2 = np.linalg.norm(v2) 413 | return dot_product / (norm1 * norm2) 414 | 415 | 416 | def quantize_embedding(embedding: np.ndarray, bits=8) -> tuple: 417 | 418 | 419 | min_val = embedding.min() 420 | max_val = embedding.max() 421 | 422 | 423 | scale = (2**bits - 1) / (max_val - min_val) 424 | quantized = np.round((embedding - min_val) * scale).astype(np.uint8) 425 | 426 | return quantized, min_val, max_val 427 | 428 | 429 | def dequantize_embedding( 430 | quantized: np.ndarray, min_val: float, max_val: float, bits=8 431 | ) -> np.ndarray: 432 | 433 | scale = (max_val - min_val) / (2**bits - 1) 434 | return (quantized * scale + min_val).astype(np.float32) 435 | 436 | 437 | async def handle_cache(hashing_kv, args_hash, prompt, mode="default"): 438 | 439 | if hashing_kv is None: 440 | return None, None, None, None 441 | 442 | 443 | if mode == "naive": 444 | mode_cache = await hashing_kv.get_by_id(mode) or {} 445 | if args_hash in mode_cache: 446 | return mode_cache[args_hash]["return"], None, None, None 447 | return None, None, None, None 448 | 449 | 450 | embedding_cache_config = hashing_kv.global_config.get( 451 | "embedding_cache_config", 452 | {"enabled": False, "similarity_threshold": 0.95, "use_llm_check": False}, 453 | ) 454 | is_embedding_cache_enabled = embedding_cache_config["enabled"] 455 | use_llm_check = embedding_cache_config.get("use_llm_check", False) 456 | 457 | quantized = min_val = max_val = None 458 | if is_embedding_cache_enabled: 459 | 460 | embedding_model_func = hashing_kv.global_config["embedding_func"]["func"] 461 | llm_model_func = hashing_kv.global_config.get("llm_model_func") 462 | 463 | current_embedding = await embedding_model_func([prompt]) 464 | quantized, min_val, max_val = quantize_embedding(current_embedding[0]) 465 | best_cached_response = await get_best_cached_response( 466 | hashing_kv, 467 | current_embedding[0], 468 | similarity_threshold=embedding_cache_config["similarity_threshold"], 469 | mode=mode, 470 | use_llm_check=use_llm_check, 471 | llm_func=llm_model_func if use_llm_check else None, 472 | original_prompt=prompt if use_llm_check else None, 473 | ) 474 | if best_cached_response is not None: 475 | return best_cached_response, None, None, None 476 | else: 477 | 478 | mode_cache = await hashing_kv.get_by_id(mode) or {} 479 | if args_hash in mode_cache: 480 | return mode_cache[args_hash]["return"], None, None, None 481 | 482 | return None, quantized, min_val, max_val 483 | 484 | 485 | @dataclass 486 | class CacheData: 487 | args_hash: str 488 | content: str 489 | prompt: str 490 | quantized: Optional[np.ndarray] = None 491 | min_val: Optional[float] = None 492 | max_val: Optional[float] = None 493 | mode: str = "default" 494 | 495 | 496 | async def save_to_cache(hashing_kv, cache_data: CacheData): 497 | if hashing_kv is None or hasattr(cache_data.content, "__aiter__"): 498 | return 499 | 500 | mode_cache = await hashing_kv.get_by_id(cache_data.mode) or {} 501 | 502 | mode_cache[cache_data.args_hash] = { 503 | "return": cache_data.content, 504 | "embedding": cache_data.quantized.tobytes().hex() 505 | if cache_data.quantized is not None 506 | else None, 507 | "embedding_shape": cache_data.quantized.shape 508 | if cache_data.quantized is not None 509 | else None, 510 | "embedding_min": cache_data.min_val, 511 | "embedding_max": cache_data.max_val, 512 | "original_prompt": cache_data.prompt, 513 | } 514 | 515 | await hashing_kv.upsert({cache_data.mode: mode_cache}) 516 | 517 | 518 | def safe_unicode_decode(content): 519 | unicode_escape_pattern = re.compile(r"\\u([0-9a-fA-F]{4})") 520 | def replace_unicode_escape(match): 521 | return chr(int(match.group(1), 16)) 522 | 523 | decoded_content = unicode_escape_pattern.sub( 524 | replace_unicode_escape, content.decode("utf-8") 525 | ) 526 | 527 | return decoded_content 528 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | The code for the paper **"PathRAG: Pruning Graph-based Retrieval Augmented Generation with Relational Paths"**. 2 | ## Install 3 | ```bash 4 | cd PathRAG 5 | pip install -e . 6 | ``` 7 | ## Quick Start 8 | * You can quickly experience this project in the `v1_test.py` file. 9 | * Set OpenAI API key in environment if using OpenAI models: `api_key="sk-...".` in the `v1_test.py` and `llm.py` file 10 | * Prepare your retrieval document "text.txt". 11 | * Use the following Python snippet in the "v1_text.py" file to initialize PathRAG and perform queries. 12 | 13 | ```python 14 | import os 15 | from PathRAG import PathRAG, QueryParam 16 | from PathRAG.llm import gpt_4o_mini_complete 17 | 18 | WORKING_DIR = "./your_working_dir" 19 | api_key="your_api_key" 20 | os.environ["OPENAI_API_KEY"] = api_key 21 | base_url="https://api.openai.com/v1" 22 | os.environ["OPENAI_API_BASE"]=base_url 23 | 24 | 25 | if not os.path.exists(WORKING_DIR): 26 | os.mkdir(WORKING_DIR) 27 | 28 | rag = PathRAG( 29 | working_dir=WORKING_DIR, 30 | llm_model_func=gpt_4o_mini_complete, 31 | ) 32 | 33 | data_file="./text.txt" 34 | question="your_question" 35 | with open(data_file) as f: 36 | rag.insert(f.read()) 37 | 38 | print(rag.query(question, param=QueryParam(mode="hybrid"))) 39 | ``` 40 | ## Parameter modification 41 | You can adjust the relevant parameters in the `base.py` and `operate.py` files. 42 | 43 | ## Batch Insert 44 | ```python 45 | import os 46 | folder_path = "your_folder_path" 47 | 48 | txt_files = [f for f in os.listdir(folder_path) if f.endswith(".txt")] 49 | for file_name in txt_files: 50 | file_path = os.path.join(folder_path, file_name) 51 | with open(file_path, "r", encoding="utf-8") as file: 52 | rag.insert(file.read()) 53 | ``` 54 | 55 | ## Cite 56 | Please cite our paper if you use this code in your own work: 57 | ```python 58 | @article{chen2025pathrag, 59 | title={PathRAG: Pruning Graph-based Retrieval Augmented Generation with Relational Paths}, 60 | author={Chen, Boyu and Guo, Zirui and Yang, Zidan and Chen, Yuluo and Chen, Junze and Liu, Zhenghao and Shi, Chuan and Yang, Cheng}, 61 | journal={arXiv preprint arXiv:2502.14902}, 62 | year={2025} 63 | } 64 | ``` 65 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | accelerate 2 | aioboto3 3 | aiohttp 4 | 5 | # database packages 6 | graspologic 7 | hnswlib 8 | nano-vectordb 9 | neo4j 10 | networkx 11 | ollama 12 | openai 13 | oracledb 14 | psycopg[binary,pool] 15 | pymilvus 16 | pymongo 17 | pymysql 18 | pyvis 19 | # lmdeploy[all] 20 | sqlalchemy 21 | tenacity 22 | 23 | 24 | # LLM packages 25 | tiktoken 26 | torch 27 | transformers 28 | xxhash 29 | -------------------------------------------------------------------------------- /v1_test.py: -------------------------------------------------------------------------------- 1 | import os 2 | from PathRAG import PathRAG, QueryParam 3 | from PathRAG.llm import gpt_4o_mini_complete 4 | 5 | WORKING_DIR = "" 6 | 7 | api_key="" 8 | os.environ["OPENAI_API_KEY"] = api_key 9 | base_url="https://api.openai.com/v1" 10 | os.environ["OPENAI_API_BASE"]=base_url 11 | 12 | 13 | if not os.path.exists(WORKING_DIR): 14 | os.mkdir(WORKING_DIR) 15 | 16 | rag = PathRAG( 17 | working_dir=WORKING_DIR, 18 | llm_model_func=gpt_4o_mini_complete, 19 | ) 20 | 21 | data_file="" 22 | question="" 23 | with open(data_file) as f: 24 | rag.insert(f.read()) 25 | 26 | print(rag.query(question, param=QueryParam(mode="hybrid"))) 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | --------------------------------------------------------------------------------